Class: Tensorflow::Graph::OperationDescription

Inherits:
Object
  • Object
show all
Defined in:
lib/tensorflow/graph/operation_description.rb

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(graph, op_type, inputs, attrs) ⇒ OperationDescription

Returns a new instance of OperationDescription.



6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# File 'lib/tensorflow/graph/operation_description.rb', line 6

# Starts building a new operation in +graph+.
#
# @param graph the graph the operation is added to; it is also handed to FFI.TF_NewOperation
# @param op_type [Function, Object] either a Function (its function_def signature becomes the op
#   definition) or an op type that is looked up through graph.op_def
# @param inputs a single input or a list of inputs (normalized with Array())
# @param attrs [Hash] attributes keyed by symbol. An optional :name entry names the operation and is
#   not applied as an attribute. The hash is copied first, so the caller's hash is not modified.
# @raise [Error::InvalidArgumentError] if no op definition is found for op_type
def initialize(graph, op_type, inputs, attrs)
  @graph = graph
  @op_def = op_type.is_a?(Function) ? op_type.function_def.signature : graph.op_def(op_type)
  raise(Error::InvalidArgumentError, "Invalid op type: #{op_type}") unless @op_def

  # Work on a copy: extracting :name must not mutate the hash the caller passed in
  attrs = attrs.dup
  raw_name = attrs.delete(:name)&.to_s || op_def.name
  @name = graph.scoped_name(raw_name)
  @pointer = FFI.TF_NewOperation(graph, op_def.name, @name)

  setup_inputs(Array(inputs), attrs)
  setup_control_inputs(graph.control_inputs)
  setup_attrs(**attrs)
end

Instance Attribute Details

#graph ⇒ Object (readonly)

Returns the value of attribute graph.



4
5
6
# File 'lib/tensorflow/graph/operation_description.rb', line 4

# The graph this operation description belongs to.
#
# @return the graph object given to the constructor
def graph
  @graph
end

#name ⇒ Object (readonly)

Returns the value of attribute name.



4
5
6
# File 'lib/tensorflow/graph/operation_description.rb', line 4

# The scoped name assigned to this operation when it was created.
#
# @return the value produced by graph.scoped_name in the constructor
def name
  @name
end

#op_def ⇒ Object (readonly)

Returns the value of attribute op_def.



4
5
6
# File 'lib/tensorflow/graph/operation_description.rb', line 4

# The op definition (or function signature) describing this operation's inputs and attributes.
#
# @return the definition resolved in the constructor
def op_def
  @op_def
end

Instance Method Details

#add_input(operation) ⇒ Object



174
175
176
177
178
179
180
181
182
183
184
185
# File 'lib/tensorflow/graph/operation_description.rb', line 174

# Adds one input to the operation being built.
#
# An OperationOutput is added as-is. An operation with more than one output is first packed with
# Tensorflow.pack so that it occupies a single input slot. Otherwise the operation's first output is used.
#
# @param operation [OperationOutput, Object] the value to connect
# @return the result of FFI.TF_AddInput
def add_input(operation)
  return FFI.TF_AddInput(self, operation) if operation.is_a?(OperationOutput)

  source = if operation.num_outputs > 1
             Tensorflow.pack(operation, n: operation.num_outputs)
           else
             operation
           end
  FFI.TF_AddInput(self, source.outputs.first)
end

#add_input_list(operations) ⇒ Object



187
188
189
190
191
192
# File 'lib/tensorflow/graph/operation_description.rb', line 187

# Adds a list-typed input to the operation being built.
#
# Each entry is either an OperationOutput (a single tensor, which check_input returns unchanged) or an
# operation whose outputs are all appended. One operation can therefore stand for several list items,
# e.g. a SPLIT with multiple outputs. OperationOutput entries are used directly, matching how add_input
# treats them.
#
# @param operations a single entry or an Array of entries
# @return the result of FFI.TF_AddInputList
def add_input_list(operations)
  outputs = Array(operations).map do |operation|
    operation.is_a?(OperationOutput) ? operation : operation.outputs
  end.flatten
  outputs_ptr = FFI::Output.array_to_ptr(outputs.map(&:output))
  FFI.TF_AddInputList(self, outputs_ptr, outputs.length)
end

#capture(operation) ⇒ Object



112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# File 'lib/tensorflow/graph/operation_description.rb', line 112

# Copies +operation+ (from another graph) into this description's graph by value, recursively
# capturing its inputs.
#
# The node being captured is validated, not the operation currently being built. Stateful nodes and
# placeholders cannot be captured by value.
#
# @param operation the foreign operation to copy
# @return the result of graph.create_operation for the copy
# @raise [Error::InvalidArgumentError] if +operation+ is stateful or is a Placeholder
def capture(operation)
  if operation.op_def.is_stateful
    raise(Error::InvalidArgumentError, "Cannot capture a stateful node (name: #{operation.name}, type: #{operation.op_type})")
  elsif operation.op_type == "Placeholder"
    raise(Error::InvalidArgumentError, "Cannot capture a placeholder by value (name: #{operation.name}, type: #{operation.op_type})")
  end

  attrs = operation.attributes.each_with_object({}) do |attr, hash|
    hash[attr.name.to_sym] = attr.value
  end
  attrs[:name] = operation.name

  captured_inputs = capture_inputs(operation, attrs)
  graph.create_operation(operation.op_type, captured_inputs, **attrs)
end

#capture_inputs(operation, attrs) ⇒ Object



79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# File 'lib/tensorflow/graph/operation_description.rb', line 79

# Captures every input of +operation+ and regroups the flat list to match the op's input arguments.
#
# @param operation the operation whose inputs are being captured
# @param attrs [Hash] the operation's attributes keyed by symbol, used to read list lengths
# @return [Array] one entry per input argument: a single captured input, or an Array of them for
#   list arguments (number_attr / type_list_attr)
def capture_inputs(operation, attrs)
  # First capture the inputs
  inputs = operation.inputs.map { |input| capture(input.operation) }

  # Next, group the inputs per argument. For example, a TensorSlice dataset has one input argument,
  # which is a list. The operation reports one input per list item, though, so there are usually more
  # inputs than arguments. The items are grouped back into one array so the operation can be called
  # to create a captured copy.
  offset = 0
  operation.op_def.input_arg.each_with_object([]) do |input_arg, grouped|
    if !input_arg.number_attr.empty?
      count = attrs[input_arg.number_attr.to_sym]
      # start/length slice: exactly +count+ items (an inclusive range would take one too many)
      grouped << inputs[offset, count]
    elsif !input_arg.type_list_attr.empty?
      count = attrs[input_arg.type_list_attr.to_sym].length
      grouped << inputs[offset, count]
    else
      count = 1
      grouped << inputs[offset]
    end
    offset += count
  end
end

#check_input(arg_def, input, dtype) ⇒ Object



129
130
131
132
133
134
135
136
137
138
139
140
141
# File 'lib/tensorflow/graph/operation_description.rb', line 129

# Converts a user-supplied input into something that can be wired into this operation.
#
# - An Operation from this graph is used as-is; one from another graph is captured by value.
# - An OperationOutput is returned unchanged.
# - A Variable yields its handle for DT_RESOURCE arguments and its value_handle otherwise.
# - Anything else becomes a constant named "<op name>/<arg name>" with the given dtype.
#
# @param arg_def the op definition's input argument this value is for
# @param input the raw input value
# @param dtype the dtype used when a constant has to be created (may be nil)
def check_input(arg_def, input, dtype)
  if input.is_a?(Operation)
    graph.equal?(input.graph) ? input : capture(input)
  elsif input.is_a?(OperationOutput)
    input
  elsif input.is_a?(Variable)
    if arg_def.type == :DT_RESOURCE
      input.handle
    else
      input.value_handle
    end
  else
    Tensorflow.constant(input, name: "#{name}/#{arg_def.name}", dtype: dtype)
  end
end

#device=(value) ⇒ Object



56
57
58
# File 'lib/tensorflow/graph/operation_description.rb', line 56

# Sets the device for the operation being built by forwarding to FFI.TF_SetDevice.
#
# @param device_spec the device specification; it is passed through unchanged (presumably a
#   device string — not validated here)
# @return the result of FFI.TF_SetDevice
def device=(device_spec)
  FFI.TF_SetDevice(self, device_spec)
end

#figure_dtype(attrs, inputs) ⇒ Object



26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# File 'lib/tensorflow/graph/operation_description.rb', line 26

# Works out the dtype for this operation.
#
# The explicit value of the op definition's first attribute of type 'type' is preferred. Otherwise the
# dtype is taken from the first input that is an Operation (its first output type) or a Variable
# (its dtype).
#
# @param attrs [Hash] attributes keyed by symbol
# @param inputs [Array] candidate inputs to infer the dtype from
# @return the dtype, or the (nil/false) attribute value when nothing could be determined
def figure_dtype(attrs, inputs)
  type_attr = op_def.attr.find { |candidate| candidate.type == 'type' }
  explicit = attrs[type_attr.name.to_sym] if type_attr
  return explicit if explicit

  inputs.each do |input|
    return input.output_types.first if input.is_a?(Operation)
    return input.dtype if input.is_a?(Variable)
  end
  explicit
end

#save ⇒ Object



49
50
51
52
53
54
# File 'lib/tensorflow/graph/operation_description.rb', line 49

# Finishes the description via FFI.TF_FinishOperation inside Status.check and wraps the resulting
# pointer.
#
# @return [Operation] the finished operation, bound to this description's graph
def save
  Status.check do |status|
    finished = FFI.TF_FinishOperation(self, status)
    Operation.new(graph, finished)
  end
end

#setup_attr(name, value) ⇒ Object



200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
# File 'lib/tensorflow/graph/operation_description.rb', line 200

# Applies a single attribute to the operation being built.
#
# The attribute's declared type in the op definition selects the matching FFI.TF_SetAttr* call.
#
# @param name [Symbol, String] attribute name; must match an attribute in op_def
# @param value the attribute value in the form expected for that attribute type
# @raise [Error::UnknownError] if the op definition declares no attribute with this name
# @raise [Error::UnimplementedError] if the attribute type is not supported here
def setup_attr(name, value)
  attr_def = op_def.attr.find { |candidate| candidate.name == name.to_s }
  raise(Error::UnknownError, "Unknown attribute: #{name}") unless attr_def

  case attr_def.type
  when 'bool'
    FFI.TF_SetAttrBool(self, attr_def.name, value ? 1 : 0)
  when 'int'
    FFI.TF_SetAttrInt(self, attr_def.name, value)
  when 'float'
    FFI.TF_SetAttrFloat(self, attr_def.name, value)
  when 'func'
    function_name = (value.is_a?(Function) ? value.name : value).to_s
    # The C API takes a length in bytes, which differs from the character count for non-ASCII text
    FFI.TF_SetAttrFuncName(self, attr_def.name, function_name, function_name.bytesize)
  when 'shape'
    pointer = ::FFI::MemoryPointer.new(:int64, value.length)
    pointer.write_array_of_int64(value)
    FFI.TF_SetAttrShape(self, attr_def.name, pointer, value.length)
  when 'list(shape)'
    dims_pointer = ::FFI::MemoryPointer.new(:pointer, value.length)
    num_dims_pointer = ::FFI::MemoryPointer.new(:int32, value.length)
    value.each_with_index do |shape, i|
      dim_pointer = ::FFI::MemoryPointer.new(:int64, shape.length)
      dim_pointer.write_array_of_int64(shape)
      dims_pointer.put_pointer(i * ::FFI.type_size(:pointer), dim_pointer)
      num_dims_pointer.put_int32(i * ::FFI.type_size(:int32), shape.length)
    end
    FFI.TF_SetAttrShapeList(self, attr_def.name, dims_pointer, num_dims_pointer, value.length)
  when 'string'
    string_value = value.to_s
    # Byte length, not character count (see 'func' above)
    FFI.TF_SetAttrString(self, attr_def.name, string_value, string_value.bytesize)
  when 'list(string)'
    # NOTE(review): list(string) attributes are not implemented and are silently ignored, as before.
    # TODO: implement with TF_SetAttrStringList once it is bound in FFI.
    nil
  when 'tensor'
    Status.check do |status|
      FFI.TF_SetAttrTensor(self, attr_def.name, value, status)
    end
  when 'type'
    FFI.TF_SetAttrType(self, attr_def.name, value)
  when 'list(type)'
    value_ptr = ::FFI::MemoryPointer.new(FFI::DataType.native_type.size, value.count)
    value.each_with_index do |a_value, i|
      value_ptr.put_int32(i * FFI::DataType.native_type.size, FFI::DataType[a_value])
    end
    FFI.TF_SetAttrTypeList(self, attr_def.name, value_ptr, value.count)
  else
    raise(Error::UnimplementedError, "Unsupported attribute. #{op_def.name} - #{attr_def.name}")
  end
end

#setup_attrs(**attrs) ⇒ Object



194
195
196
197
198
# File 'lib/tensorflow/graph/operation_description.rb', line 194

# Applies every attribute in +attrs+ through setup_attr, in hash order.
#
# @param attrs attribute values keyed by attribute name
# @return [Hash] the attrs hash
def setup_attrs(**attrs)
  attrs.each_pair { |key, val| setup_attr(key, val) }
end

#setup_control_input(control_input) ⇒ Object



66
67
68
69
70
71
72
73
74
75
76
77
# File 'lib/tensorflow/graph/operation_description.rb', line 66

# Adds one control dependency to the operation being built.
#
# @param control_input [Operation, Variable] an Operation is used directly; a Variable contributes
#   its handle
# @return the result of FFI.TF_AddControlInput
# @raise [Error::InvalidArgumentError] for any other kind of value
def setup_control_input(control_input)
  dependency =
    if control_input.is_a?(Operation)
      control_input
    elsif control_input.is_a?(Variable)
      control_input.handle
    else
      raise(Error::InvalidArgumentError, "Invalid control input")
    end

  FFI.TF_AddControlInput(self, dependency)
end

#setup_control_inputs(control_inputs) ⇒ Object



60
61
62
63
64
# File 'lib/tensorflow/graph/operation_description.rb', line 60

# Adds each entry of +control_inputs+ as a control dependency via setup_control_input.
#
# @param control_inputs [Enumerable] the control dependencies to add
# @return the control_inputs collection
def setup_control_inputs(control_inputs)
  control_inputs.each(&method(:setup_control_input))
end

#setup_input(index, value, attrs) ⇒ Object



149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# File 'lib/tensorflow/graph/operation_description.rb', line 149

# Wires the input at position +index+ of the op definition to +value+.
#
# For list arguments (number_attr or type_list_attr set), a Ruby Array value has each element checked
# individually, and the result is added with add_input_list. Other arguments are added with add_input.
# A single value can also feed a list argument, e.g. a SPLIT operation with several outputs feeding PACK.
#
# @param index [Integer] position of the argument in op_def.input_arg
# @param value the raw input value(s)
# @param attrs [Hash] attributes keyed by symbol; the argument's type_attr entry supplies the dtype
def setup_input(index, value, attrs)
  arg_def = op_def.input_arg[index]
  dtype = attrs[arg_def.type_attr.to_sym]
  list_arg = !(arg_def.number_attr.empty? && arg_def.type_list_attr.empty?)

  checked = if list_arg && value.is_a?(Array)
              value.map { |item| check_input(arg_def, item, dtype) }
            else
              check_input(arg_def, value, dtype)
            end

  list_arg ? add_input_list(checked) : add_input(checked)
end

#setup_inputs(inputs, attrs) ⇒ Object



143
144
145
146
147
# File 'lib/tensorflow/graph/operation_description.rb', line 143

# Wires each element of +inputs+ to the op-definition argument at the same position.
#
# @param inputs [Array] input values in argument order
# @param attrs [Hash] attributes keyed by symbol, forwarded to setup_input
# @return [Array] the inputs array
def setup_inputs(inputs, attrs)
  inputs.each.with_index do |input_value, position|
    setup_input(position, input_value, attrs)
  end
end

#to_ptr ⇒ Object



45
46
47
# File 'lib/tensorflow/graph/operation_description.rb', line 45

# The native handle returned by FFI.TF_NewOperation in the constructor. It is passed as +self+
# to the FFI calls in this class.
#
# @return the stored native pointer
def to_ptr
  @pointer
end