Class: TensorStream::Operation

Inherits:
Tensor
  • Object
show all
Includes:
OpHelper
Defined in:
lib/tensor_stream/operation.rb

Overview

TensorStream class that defines an operation

Direct Known Subclasses

ControlFlow, DynamicStitch

Instance Attribute Summary collapse

Attributes inherited from Tensor

#given_name, #graph, #internal, #native_buffer, #source, #value

Instance Method Summary collapse

Methods included from OpHelper

#_op, #cons, #format_source, #fp_type?, #i_cons, #i_op, #i_var, #int_type?, #reduced_shape, #shape_eval, #shape_full_specified, #shapes_fully_specified_and_equal

Methods inherited from Tensor

#auto_math, #breakpoint!, cast_dtype, #collect, detect_type, #dtype, #eval, #first, #internal?, #print!, reset_counters, #to_a, #to_f, #to_i

Methods included from TensorMixins

#!=, #%, #*, #**, #+, #-, #-@, #/, #<, #<=, #==, #>, #>=, #[], #and, #cast, #ceil, #dot, #floor, #log, #matmul, #reduce, #reshape, #round, #var, #zero?

Constructor Details

#initialize(graph, inputs: [], options: {}) ⇒ Operation

Returns a new instance of Operation.



10
11
12
13
14
15
16
17
# File 'lib/tensor_stream/operation.rb', line 10

# Builds a new operation node.
#
# @param graph [Object] the graph this operation belongs to
# @param inputs [Array] input tensors/operations (defaults to a fresh empty Array)
# @param options [Hash] per-operation options (defaults to a fresh empty Hash);
#   the hash is stored by reference, not copied
def initialize(graph, inputs: [], options: {})
  @graph = graph
  @inputs = inputs
  @options = options
  # An Operation acts as its own op.
  @op = self
  # Output list and consumer set start out empty.
  @outputs = []
  @consumers = Set.new
end

Instance Attribute Details

#breakpoint ⇒ Object

Returns the value of attribute breakpoint.



7
8
9
# File 'lib/tensor_stream/operation.rb', line 7

# Returns the value of attribute breakpoint.
#
# When truthy, #infer_const reports this operation as non-constant.
# It is not assigned in the constructor, so it is nil unless set elsewhere.
#
# @return [Object]
def breakpoint
  @breakpoint
end

#consumers ⇒ Object

Returns the value of attribute consumers.



7
8
9
# File 'lib/tensor_stream/operation.rb', line 7

# Returns the value of attribute consumers.
#
# The constructor initializes it to an empty Set. Presumably it holds the
# operations that consume this one's output (TODO: confirm against callers).
#
# @return [Set]
def consumers
  @consumers
end

#data_type ⇒ Object (readonly)

Returns the value of attribute data_type.



8
9
10
# File 'lib/tensor_stream/operation.rb', line 8

# Returns the value of attribute data_type.
#
# #set_input assigns it from the result of #set_data_type.
#
# @return [Object]
def data_type
  @data_type
end

#device ⇒ Object

Returns the value of attribute device.



7
8
9
# File 'lib/tensor_stream/operation.rb', line 7

# Returns the value of attribute device.
#
# Nothing in this class assigns it, so it is nil unless set externally.
#
# @return [Object]
def device
  @device
end

#inputs ⇒ Object

Returns the value of attribute inputs.



7
8
9
# File 'lib/tensor_stream/operation.rb', line 7

# Returns the value of attribute inputs.
#
# This is the Array passed to the constructor (default []). #set_input
# replaces individual slots. Entries may be nil; #infer_const compacts them
# before inspecting.
#
# @return [Array]
def inputs
  @inputs
end

#is_const ⇒ Object (readonly)

Returns the value of attribute is_const.



8
9
10
# File 'lib/tensor_stream/operation.rb', line 8

# Returns the value of attribute is_const.
#
# #set_input computes it via #infer_const.
#
# @return [Boolean, nil] nil until #set_input has run
def is_const
  @is_const
end

#name ⇒ Object

Returns the value of attribute name.



7
8
9
# File 'lib/tensor_stream/operation.rb', line 7

# Returns the value of attribute name.
#
# It is also returned by #to_s and embedded in #inspect and #to_h.
#
# @return [Object]
def name
  @name
end

#operation ⇒ Object

Returns the value of attribute operation.



7
8
9
# File 'lib/tensor_stream/operation.rb', line 7

# Returns the value of attribute operation.
#
# This is the operation kind. Code in this class compares it against
# symbols such as :const, :add and :where.
#
# @return [Object]
def operation
  @operation
end

#options ⇒ Object (readonly)

Returns the value of attribute options.



8
9
10
# File 'lib/tensor_stream/operation.rb', line 8

# Returns the value of attribute options.
#
# This is the options Hash given to the constructor (default {}).
# #set_option mutates it in place.
#
# @return [Hash]
def options
  @options
end

#outputs ⇒ Object (readonly)

Returns the value of attribute outputs.



8
9
10
# File 'lib/tensor_stream/operation.rb', line 8

# Returns the value of attribute outputs.
#
# The constructor initializes it to an empty Array.
#
# @return [Array]
def outputs
  @outputs
end

#rank ⇒ Object

Returns the value of attribute rank.



7
8
9
# File 'lib/tensor_stream/operation.rb', line 7

# Returns the value of attribute rank.
#
# #set_input sets it from the rank of the freshly inferred shape.
#
# @return [Object]
def rank
  @rank
end

#shape ⇒ Object (readonly)

Returns the value of attribute shape.



8
9
10
# File 'lib/tensor_stream/operation.rb', line 8

# Returns the value of attribute shape.
#
# #set_input assigns a TensorShape here. #inspect prints "?" while it is nil.
#
# @return [Object]
def shape
  @shape
end

Instance Method Details

#const_value ⇒ Object



37
38
39
# File 'lib/tensor_stream/operation.rb', line 37

# Returns the :value entry of this operation's options.
#
# @return [Object, nil] the stored value, or nil when there is no options
#   hash or no :value key
def const_value
  @options[:value] if @options
end

#container_buffer ⇒ Object



41
42
43
# File 'lib/tensor_stream/operation.rb', line 41

# Returns the buffer of the container stored under options[:container].
#
# @return [Object, nil] the container's buffer, or nil when no container is set
def container_buffer
  container = @options[:container]
  container.buffer if container
end

#infer_const ⇒ Object



57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# File 'lib/tensor_stream/operation.rb', line 57

# Decides whether this operation can be treated as a constant.
#
# An operation with a breakpoint is never constant. Random generators,
# :print, :check_numerics, placeholders, variables and assignments are never
# constant. A :const is always constant. Any other operation is constant
# exactly when every non-nil input reports is_const.
#
# @return [Boolean]
def infer_const
  return false if breakpoint

  case operation
  when :const
    true
  when :random_standard_normal, :random_uniform, :truncated_normal, :glorot_uniform, :print, :check_numerics,
       :placeholder,
       :variable_v2, :assign, :assign_add, :assign_sub
    false
  else
    @inputs.compact.all?(&:is_const)
  end
end

#inspect ⇒ Object



19
20
21
# File 'lib/tensor_stream/operation.rb', line 19

# Debug representation showing the op kind, name, shape and data type.
# The shape is printed as "?" while it has not been inferred.
#
# @return [String]
def inspect
  shape_text = @shape || "?"
  "Op(#{operation} name: #{name} shape: #{shape_text} data_type: #{data_type})"
end

#op ⇒ Object



244
245
246
# File 'lib/tensor_stream/operation.rb', line 244

# Returns the receiver itself; an Operation serves as its own op.
#
# @return [Operation]
def op
  self
end

#run ⇒ Object



240
241
242
# File 'lib/tensor_stream/operation.rb', line 240

# Evaluates this operation by delegating to #eval, which the class listing
# shows as inherited from Tensor.
#
# @return [Object] whatever #eval returns
def run
  eval
end

#set_data_type(passed_data_type) ⇒ Object



79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# File 'lib/tensor_stream/operation.rb', line 79

# Infers the data type produced by this operation.
#
# Several operations take their type from one of their inputs. A few have a
# fixed result type or read it from the options hash. Every other operation
# is delegated to OpMaker.infer_data_type.
#
# @param passed_data_type [Object, nil] type requested by the caller; only
#   consulted by the random initializers and by the OpMaker fallback
# @return [Object] the inferred data type (:unknown when indexing a
#   ControlFlow with a non-constant index)
def set_data_type(passed_data_type)
  case operation
  when :where, :fill, :concat, :conv2d_backprop_input
    @inputs[1].data_type
  when :case
    # The optional third input takes precedence over the second.
    (@inputs[2] || @inputs[1]).data_type
  when :case_grad
    @inputs[2].data_type
  when :placeholder, :variable_v2, :const
    options[:data_type]
  when :logical_and
    :boolean
  when :shape, :rank, :shape_n
    options[:out_type] || :int32
  when :zeros, :ones
    options[:dtype] || :float32
  when :random_standard_normal, :random_uniform, :glorot_uniform, :truncated_normal
    passed_data_type || :float32
  when :index
    container = @inputs[0]
    selector = @inputs[1]
    if container.is_a?(ControlFlow)
      # Only a constant index lets us pick the selected branch's type statically.
      selector.is_const ? container.inputs[selector.const_value].data_type : :unknown
    else
      container.data_type
    end
  else
    OpMaker.infer_data_type(self, self, passed_data_type)
  end
end

#set_input(index, value) ⇒ Object



45
46
47
48
49
50
51
# File 'lib/tensor_stream/operation.rb', line 45

# Replaces the input at +index+ and recomputes everything derived from the
# inputs: shape, rank, constness and data type.
#
# The order matters. The new input must be in place before shape inference
# runs, and @shape must exist before @rank is read from it.
#
# @param index [Integer] position in @inputs to overwrite
# @param value [Object] the new input tensor/operation
# @return [Object] the newly inferred data type
def set_input(index, value)
  @inputs[index] = value
  inferred_dims = TensorStream::InferShape.infer_shape(self)
  @shape = TensorShape.new(inferred_dims)
  @rank = @shape.rank
  @is_const = infer_const
  requested_type = @options[:data_type]
  @data_type = set_data_type(requested_type)
end

#set_name ⇒ Object



75
76
77
# File 'lib/tensor_stream/operation.rb', line 75

# Derives a default name from the operation kind.
#
# @return [String] the operation converted to a string ("" when unset)
def set_name
  String(@operation)
end

#set_option(key, value) ⇒ Object



53
54
55
# File 'lib/tensor_stream/operation.rb', line 53

# Stores +value+ under +key+ (converted to a Symbol) in the options hash,
# mutating it in place.
#
# @param key [#to_sym] option name
# @param value [Object] option value
# @return [Hash] the (same) options hash
def set_option(key, value)
  @options[key.to_sym] = value
  @options
end

#to_h ⇒ Object



27
28
29
30
31
32
33
34
35
# File 'lib/tensor_stream/operation.rb', line 27

# Serializes this operation into a plain Hash.
#
# @return [Hash] with keys :op, :name, :data_type, :inputs (the names of the
#   inputs, in order) and :attrs (the result of serialize_options)
def to_h
  serialized = {}
  serialized[:op] = operation.to_s
  serialized[:name] = name.to_s
  serialized[:data_type] = @data_type
  serialized[:inputs] = @inputs.map { |input| input.name }
  serialized[:attrs] = serialize_options
  serialized
end

#to_math(name_only = false, max_depth = 99, cur_depth = 0) ⇒ Object



122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
# File 'lib/tensor_stream/operation.rb', line 122

# Renders this operation and, recursively, its inputs as a math-like
# expression string.
#
# @param name_only [Boolean] forwarded to auto_math when rendering inputs
# @param max_depth [Integer] remaining recursion budget; when it is zero the
#   op's name is returned instead of an expression
# @param cur_depth [Integer] current nesting level, used for indentation
# @return [String] the expression, prefixed with a newline and cur_depth + 1
#   spaces (or just @name when max_depth is zero)
def to_math(name_only = false, max_depth = 99, cur_depth = 0)
  return @name if max_depth.zero?

  sub_input = auto_math(inputs[0], name_only, max_depth - 1, cur_depth + 1)
  sub_input2 = auto_math(inputs[1], name_only, max_depth - 1, cur_depth + 1) if inputs[1]

  out = case operation
        when :argmax
          "argmax(#{sub_input},#{options[:axis]})"
        when :negate
          "-#{sub_input}"
        when :index, :slice
          "#{sub_input}[#{sub_input2}]"
        when :assign_sub
          "(#{inputs[0] ? inputs[0].name : "self"} -= #{auto_math(inputs[1], name_only, 1)})"
        when :assign_add
          "(#{inputs[0] ? inputs[0].name : "self"} += #{auto_math(inputs[1], name_only, 1)})"
        when :assign
          "(#{inputs[0] ? inputs[0].name : "self"} = #{auto_math(inputs[1], name_only, 1)})"
        when :sin, :cos, :tanh
          "#{operation}(#{sub_input})"
        when :add
          "(#{sub_input} + #{sub_input2})"
        when :sub
          "(#{sub_input} - #{sub_input2})"
        when :pow
          "(#{sub_input}^#{sub_input2})"
        when :div
          "(#{sub_input} / #{sub_input2})"
        when :mul
          # Drop a multiplication by 1 to keep the rendering readable.
          if auto_math(inputs[0]) == 1
            sub_input2
          elsif auto_math(inputs[1]) == 1
            sub_input
          else
            "(#{sub_input} * #{sub_input2})"
          end
        when :sum
          "sum(|#{sub_input}|,  axis=#{sub_input2})"
        when :mean
          "mean(|#{sub_input}|, axis=#{sub_input2})"
        when :prod
          "prod(|#{sub_input}|,  axis=#{sub_input2})"
        when :gradients
          "gradient(#{sub_input})"
        when :stop_gradient
          sub_input
        when :mat_mul
          "#{sub_input}.matmul(#{sub_input2})"
        when :eye
          "eye(#{sub_input})"
        when :transpose
          "transpose(#{sub_input})"
        when :shape
          "#{sub_input}.shape"
        when :exp
          # The exponent is parenthesized; the old template had a stray ")".
          "e^(#{sub_input})"
        when :ones
          "ones(#{sub_input})"
        when :ones_like
          "ones_like(#{sub_input})"
        when :flow_group
          "flow_group(#{inputs.collect { |i| auto_math(i, name_only, max_depth - 1, cur_depth) }.join(",")})"
        when :zeros
          "zeros(#{sub_input})"
        when :reshape
          "reshape(#{sub_input},#{sub_input2})"
        when :rank
          "#{sub_input}.rank"
        when :less
          "#{sub_input} < #{sub_input2}"
        when :less_equal
          "#{sub_input} <= #{sub_input2}"
        when :greater
          "#{sub_input} > #{sub_input2}"
        when :greater_equal
          "#{sub_input} >= #{sub_input2}"
        when :square
          "#{sub_input}\u00B2"
        when :log
          "log(#{sub_input})"
        when :identity
          "identity(#{sub_input})"
        when :print
          "print(#{sub_input})"
        when :pad
          "pad(#{sub_input},#{auto_math(options[:paddings])})"
        when :equal
          "#{sub_input} == #{sub_input2}"
        when :not_equal
          "#{sub_input} != #{sub_input2}"
        when :logical_and
          "#{sub_input} && #{sub_input2}"
        when :sqrt
          "sqrt(#{sub_input})"
        when :log1p
          "log1p(#{sub_input})"
        when :zeros_like
          "zeros_like(#{sub_input})"
        when :where
          "where(#{auto_math(options[:pred], name_only, max_depth - 1, cur_depth)}, #{sub_input}, #{sub_input2})"
        when :max
          "max(#{sub_input},#{sub_input2})"
        when :cast
          "cast(#{sub_input}, #{data_type})"
        when :broadcast_transform
          "broadcast_transform(#{sub_input},#{sub_input2})"
        when :broadcast_gradient_args
          "broadcast_gradient_args(#{sub_input},#{sub_input2})"
        else
          # Generic fallback. The two forms must be alternatives: as separate
          # modifier-if statements, the unary result was discarded and
          # single-input ops rendered as nil.
          if sub_input && sub_input2
            "#{operation}(#{sub_input}, #{sub_input2})"
          elsif sub_input
            "#{operation}(#{sub_input})"
          end
  end
  ["\n", Array.new(cur_depth + 1) { " " }, out].flatten.join
end

#to_s ⇒ Object



23
24
25
# File 'lib/tensor_stream/operation.rb', line 23

# String form of the operation: its name (@name), which may be nil if no
# name has been assigned.
#
# @return [Object]
def to_s
  @name
end