Class: Tensorflow::Graph::FunctionDef

Inherits:
Object
Defined in:
lib/tensorflow/graph/function_def.rb

Defined Under Namespace

Classes: Signature

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(ruby_method, input_signatures = []) ⇒ FunctionDef

Returns a new instance of FunctionDef.



8
9
10
11
12
# File 'lib/tensorflow/graph/function_def.rb', line 8

# Wraps +ruby_method+ so that calling it builds a graph function.
#
# Stores the method, validates/records +input_signatures+ against the
# method's parameter list, and then redefines the method on its owner.
#
# @param ruby_method [UnboundMethod] presumably an UnboundMethod (it is later
#   bound with #bind) — TODO confirm against callers
# @param input_signatures [Array] one [dtype, shape] entry per method parameter
def initialize(ruby_method, input_signatures = [])
  @ruby_method = ruby_method
  # Validation happens before the owner is modified, so a bad signature list
  # raises without touching the wrapped class.
  process_signatures(ruby_method, input_signatures)
  wrap_ruby_method
end

Instance Attribute Details

#ruby_method ⇒ Object (readonly)

Returns the value of attribute ruby_method.



4
5
6
# File 'lib/tensorflow/graph/function_def.rb', line 4

# The method being wrapped, exactly as it was passed to the constructor.
#
# @return [Object] the stored @ruby_method
def ruby_method; @ruby_method; end

#signatures ⇒ Object (readonly)

Returns the value of attribute signatures.



4
5
6
# File 'lib/tensorflow/graph/function_def.rb', line 4

# The Signature objects recorded by #process_signatures, one per parameter.
#
# @return [Object] the stored @signatures
def signatures; @signatures; end

Instance Method Details

#aliased_name ⇒ Object



24
25
26
# File 'lib/tensorflow/graph/function_def.rb', line 24

# Name under which the untraced implementation is kept: the wrapped method's
# original name with an "_original" suffix.
#
# @return [String] e.g. "predict_original" for a method named +predict+
def aliased_name
  base_name = ruby_method.original_name
  "#{base_name}_original"
end

#build_function(object) ⇒ Object



46
47
48
49
50
51
52
53
54
55
56
57
58
59
# File 'lib/tensorflow/graph/function_def.rb', line 46

# Traces the wrapped Ruby method into a fresh graph and converts that graph
# into a function.
#
# One placeholder is created per method parameter, using the dtype/shape of
# the signature at the same position and the parameter's name. The original
# method is then bound to +object+ and called with those placeholders; its
# return value (coerced with Array()) becomes the function's outputs.
#
# @param object [Object] receiver the original method is bound to while tracing
# @return [Object] whatever Graph#as_default returns for the block (presumably
#   the result of graph.to_function — TODO confirm)
def build_function(object)
  Graph.new.as_default do |graph|
    placeholders = ruby_method.parameters.zip(signatures).map do |param, signature|
      # Method#parameters yields [kind, name] pairs; the last element is the name.
      Tensorflow.placeholder(signature.dtype, name: param.last, shape: signature.shape)
    end

    # Call the original ruby_method to record its operations into +graph+.
    result = ruby_method.bind(object).call(*placeholders)

    # NOTE(review): Kernel#Array calls to_ary/to_a on +result+; if the returned
    # tensor type defines either, a single tensor would be expanded rather than
    # wrapped in a one-element array — confirm.
    graph.to_function(ruby_method.original_name.to_s, nil, placeholders, Array(result))
  end
end

#process_signatures(ruby_method, input_signatures) ⇒ Object



14
15
16
17
18
19
20
21
22
# File 'lib/tensorflow/graph/function_def.rb', line 14

# Checks that exactly one input signature was given per parameter of
# +ruby_method+, then stores them as Signature objects in @signatures.
#
# @param ruby_method [Method, UnboundMethod] method whose #parameters are counted
# @param input_signatures [Array] entries destructured as [dtype, shape]
# @return [Array<Signature>] the newly assigned @signatures
# @raise [Error::InvalidArgumentError] when the two counts differ
def process_signatures(ruby_method, input_signatures)
  unless input_signatures.length == ruby_method.parameters.length
    raise Error::InvalidArgumentError, "Must specify input signature for each method parameter"
  end

  @signatures = input_signatures.map { |dtype, shape| Signature.new(dtype, shape) }
end

#wrap_ruby_method ⇒ Object



28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# File 'lib/tensorflow/graph/function_def.rb', line 28

# Replaces ruby_method on its owner with a wrapper that builds a graph function.
#
# The untraced implementation is first kept reachable under #aliased_name.
# The method is then redefined under its original name; each call of the new
# method runs #build_function with the receiving instance, registers the
# result with ExecutionContext.current, and returns it.
#
# NOTE(review): a new function is built and registered on every call; nothing
# here caches it.
def wrap_ruby_method
  alias_name = aliased_name
  method_name = ruby_method.original_name
  # Captured because +self+ inside define_method's block is the receiving instance.
  function_def = self

  ruby_method.owner.class_eval do
    alias_method(alias_name, method_name)

    # Call arguments are accepted but not used by the wrapper.
    define_method(method_name) do |*_args|
      function = function_def.build_function(self)
      ExecutionContext.current.add_function(function)
      function
    end
  end
end