Class: TensorStream::Tensor

Inherits:
Object
  • Object
show all
Includes:
OpHelper
Defined in:
lib/tensor_stream/tensor.rb

Overview

Base class that defines a tensor-like interface.

Direct Known Subclasses

Operation, Placeholder, Variable

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Methods included from OpHelper

#_op, #cons, #format_source, #fp_type?, #i_cons, #i_op, #int_type?, #reduced_shape, #shape_eval, #shape_full_specified, #shapes_fully_specified_and_equal

Constructor Details

#initialize(data_type, rank, shape, options = {}) ⇒ Tensor

Returns a new instance of Tensor.



12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# File 'lib/tensor_stream/tensor.rb', line 12

# Builds a tensor node, optionally stores a literal value, and registers the node
# with @graph.
#
# @param data_type [Symbol, nil] element dtype; scalar literal values are cast to it
#   with Tensor.cast_dtype
# @param rank [Integer] rank handed to TensorShape
# @param shape [Array] dimension sizes; also used to reshape literal values
# @param options [Hash] :const, :internal, :name and :value are read here; the whole
#   hash is also passed to setup_initial_state. The hash is not modified.
def initialize(data_type, rank, shape, options = {})
  # @graph is read below but not assigned in this method; presumably
  # setup_initial_state sets it (TODO confirm).
  setup_initial_state(options)
  @data_type = data_type
  @rank = rank
  @breakpoint = false
  @shape = TensorShape.new(shape, rank)
  @value = nil

  @is_const = options[:const] || false
  @internal = options[:internal]
  @name = [@graph.get_name_scope, options[:name] || build_name].compact.reject(&:empty?).join('/')
  @given_name = @name

  # Work on a local so the caller's options hash is never mutated.
  value = options[:value]
  # Compare against nil (not truthiness) so a literal `false` is still stored.
  unless value.nil?
    if value.is_a?(Array)
      # A flat (single dimension) array was passed for a multi-dimensional shape: reshape it.
      value = reshape(value, shape.reverse.dup) if shape.size >= 2 && !value.empty? && !value[0].is_a?(Array)

      # NOTE(review): Tensor.cast_dtype returns Tensors unchanged, so this copies the
      # array without casting any element. Confirm whether non-Tensor elements were
      # meant to be cast instead.
      @value = value.collect do |v|
        v.is_a?(Tensor) ? Tensor.cast_dtype(v, @data_type) : v
      end
    elsif !shape.empty?
      @value = reshape(Tensor.cast_dtype(value, @data_type), shape.dup)
    else
      @value = Tensor.cast_dtype(value, @data_type)
    end
  end

  @graph.add_node(self)
end

Instance Attribute Details

#breakpoint ⇒ Object

Returns the value of attribute breakpoint.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for the breakpoint flag (set to false by the constructor).
#
# @return [Object, nil] the current @breakpoint, or nil if never assigned
def breakpoint
  defined?(@breakpoint) ? @breakpoint : nil
end

#consumers ⇒ Object

Returns the value of attribute consumers.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for @consumers (not assigned in this class's visible code).
#
# @return [Object, nil] the current @consumers, or nil if never assigned
def consumers
  defined?(@consumers) ? @consumers : nil
end

#data_type ⇒ Object

Returns the value of attribute data_type.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for the element dtype given to the constructor.
#
# @return [Object, nil] the current @data_type, or nil if never assigned
def data_type
  defined?(@data_type) ? @data_type : nil
end

#device ⇒ Object

Returns the value of attribute device.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for @device (not assigned in this class's visible code).
#
# @return [Object, nil] the current @device, or nil if never assigned
def device
  defined?(@device) ? @device : nil
end

#given_name ⇒ Object

Returns the value of attribute given_name.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for the name recorded at construction time (a copy of @name).
#
# @return [Object, nil] the current @given_name, or nil if never assigned
def given_name
  defined?(@given_name) ? @given_name : nil
end

#graph ⇒ Object

Returns the value of attribute graph.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for the graph this tensor was registered with.
#
# @return [Object, nil] the current @graph, or nil if never assigned
def graph
  defined?(@graph) ? @graph : nil
end

#internal ⇒ Object

Returns the value of attribute internal.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for the raw :internal option given to the constructor.
#
# @return [Object, nil] the current @internal, or nil if never assigned
def internal
  defined?(@internal) ? @internal : nil
end

#is_const ⇒ Object

Returns the value of attribute is_const.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for the constness flag (the constructor's :const option, defaulting to false).
#
# @return [Object, nil] the current @is_const, or nil if never assigned
def is_const
  defined?(@is_const) ? @is_const : nil
end

#name ⇒ Object

Returns the value of attribute name.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for the tensor's scoped name (name scope and base name joined with '/').
#
# @return [Object, nil] the current @name, or nil if never assigned
def name
  defined?(@name) ? @name : nil
end

#native_buffer ⇒ Object

Returns the value of attribute native_buffer.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for @native_buffer (not assigned in this class's visible code).
#
# @return [Object, nil] the current @native_buffer, or nil if never assigned
def native_buffer
  defined?(@native_buffer) ? @native_buffer : nil
end

#outputs ⇒ Object

Returns the value of attribute outputs.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for @outputs (not assigned in this class's visible code).
#
# @return [Object, nil] the current @outputs, or nil if never assigned
def outputs
  defined?(@outputs) ? @outputs : nil
end

#rank ⇒ Object

Returns the value of attribute rank.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for the rank given to the constructor.
#
# @return [Object, nil] the current @rank, or nil if never assigned
def rank
  defined?(@rank) ? @rank : nil
end

#shape ⇒ Object

Returns the value of attribute shape.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for the TensorShape built in the constructor.
#
# @return [Object, nil] the current @shape, or nil if never assigned
def shape
  defined?(@shape) ? @shape : nil
end

#source ⇒ Object

Returns the value of attribute source.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for @source (not assigned in this class's visible code).
#
# @return [Object, nil] the current @source, or nil if never assigned
def source
  defined?(@source) ? @source : nil
end

#value ⇒ Object

Returns the value of attribute value.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for the literal value stored by the constructor (nil when none was given).
#
# @return [Object, nil] the current @value, or nil if never assigned
def value
  defined?(@value) ? @value : nil
end

Class Method Details

.cast_dtype(val, dtype) ⇒ Object



242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
# File 'lib/tensor_stream/tensor.rb', line 242

# Casts a Ruby literal (scalar or nested Array) to the Ruby type that backs +dtype+.
#
# Tensors are returned untouched, and Arrays are cast element-wise (recursively).
# Booleans become 1.0/0.0 for float dtypes and 1/0 for integer dtypes.
#
# @param val [Object] value to cast
# @param dtype [Symbol, String, Hash, nil] target dtype, or a Hash carrying it under
#   :dtype. nil, or a Hash without :dtype, leaves the value unchanged.
# @return [Object] the cast value
# @raise [RuntimeError] when dtype is not a recognised data type
def self.cast_dtype(val, dtype)
  # Normalise an options-style Hash once, before recursing, so a Hash without
  # :dtype behaves like nil instead of crashing on nil.to_sym.
  dtype = dtype[:dtype] if dtype.is_a?(Hash)
  return val if dtype.nil?
  return val if val.is_a?(Tensor)
  return val.collect { |v| cast_dtype(v, dtype) } if val.is_a?(Array)

  case dtype.to_sym
  when :float64, :float32, :float16, :float
    case val
    when true then 1.0
    when false then 0.0
    else val.to_f
    end
  when :string
    val.to_s
  when :int64, :int32, :int16, :int8, :int
    case val
    when true then 1
    when false then 0
    else val.to_i
    end
  when :boolean
    !!val
  when :unknown
    val
  else
    raise "unknown data_type #{dtype} passed"
  end
end

.detect_type(value) ⇒ Object



224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
# File 'lib/tensor_stream/tensor.rb', line 224

# Infers a dtype symbol from a Ruby value.
#
# An Array is classified by its first element (recursively); a value of any
# unrecognised class, including nil, defaults to :float32.
#
# @param value [Object] a literal, an Array, or a Tensor
# @return [Symbol] the inferred dtype
def self.detect_type(value)
  case value
  when true, false then :boolean
  when String then :string
  when Float then :float32
  when Integer then :int32
  when Array then detect_type(value.first)
  when Tensor then value.data_type
  else :float32
  end
end

.reset_counters ⇒ Object



51
52
53
54
55
# File 'lib/tensor_stream/tensor.rb', line 51

# Sets the class-level constant, variable and placeholder counters back to zero.
#
# @return [Integer] 0
def self.reset_counters
  @const_counter = @var_counter = @placeholder_counter = 0
end

Instance Method Details

#!=(other) ⇒ Object



117
118
119
120
# File 'lib/tensor_stream/tensor.rb', line 117

# Builds a :not_equal op between this tensor and +other+.
#
# Both operands go through TensorStream.check_data_types first; only the second
# element of its result is used.
def !=(other)
  _, rhs = TensorStream.check_data_types(self, other)
  _op(:not_equal, self, rhs)
end

#%(other) ⇒ Object



90
91
92
93
# File 'lib/tensor_stream/tensor.rb', line 90

# Builds a :mod op between this tensor and +other+.
#
# +other+ goes through TensorStream.check_data_types and is then converted to a
# tensor of this tensor's data_type.
def %(other)
  _, checked = TensorStream.check_data_types(self, other)
  rhs = TensorStream.convert_to_tensor(checked, dtype: data_type)
  _op(:mod, self, rhs)
end

#*(other) ⇒ Object



66
67
68
69
# File 'lib/tensor_stream/tensor.rb', line 66

# Builds a :mul op between this tensor and +other+.
#
# +other+ goes through TensorStream.check_data_types and is then converted to a
# tensor of this tensor's data_type.
def *(other)
  _, checked = TensorStream.check_data_types(self, other)
  rhs = TensorStream.convert_to_tensor(checked, dtype: data_type)
  _op(:mul, self, rhs)
end

#**(other) ⇒ Object



71
72
73
74
# File 'lib/tensor_stream/tensor.rb', line 71

# Builds a :pow op raising this tensor to +other+.
#
# +other+ goes through TensorStream.check_data_types and is then converted to a
# tensor of this tensor's data_type.
def **(other)
  _, checked = TensorStream.check_data_types(self, other)
  exponent = TensorStream.convert_to_tensor(checked, dtype: data_type)
  _op(:pow, self, exponent)
end

#+(other) ⇒ Object



57
58
59
60
# File 'lib/tensor_stream/tensor.rb', line 57

# Builds an :add op between this tensor and +other+.
#
# Both operands go through TensorStream.check_data_types first; only the second
# element of its result is used.
def +(other)
  _, rhs = TensorStream.check_data_types(self, other)
  _op(:add, self, rhs)
end

#-(other) ⇒ Object



81
82
83
84
# File 'lib/tensor_stream/tensor.rb', line 81

# Builds a :sub op subtracting +other+ from this tensor.
#
# +other+ goes through TensorStream.check_data_types and is then converted to a
# tensor of this tensor's data_type.
def -(other)
  _, checked = TensorStream.check_data_types(self, other)
  subtrahend = TensorStream.convert_to_tensor(checked, dtype: data_type)
  _op(:sub, self, subtrahend)
end

#-@ ⇒ Object



86
87
88
# File 'lib/tensor_stream/tensor.rb', line 86

# Unary minus: builds a :negate op on this tensor (no second operand).
def -@
  operand = self
  _op(:negate, operand, nil)
end

#/(other) ⇒ Object



76
77
78
79
# File 'lib/tensor_stream/tensor.rb', line 76

# Builds a :div op dividing this tensor by +other+.
#
# +other+ goes through TensorStream.check_data_types and is then converted to a
# tensor of this tensor's data_type.
def /(other)
  _, checked = TensorStream.check_data_types(self, other)
  divisor = TensorStream.convert_to_tensor(checked, dtype: data_type)
  _op(:div, self, divisor)
end

#<(other) ⇒ Object



112
113
114
115
# File 'lib/tensor_stream/tensor.rb', line 112

# Builds a :less comparison op between this tensor and +other+.
#
# Both operands go through TensorStream.check_data_types first; only the second
# element of its result is used.
def <(other)
  _, rhs = TensorStream.check_data_types(self, other)
  _op(:less, self, rhs)
end

#<=(other) ⇒ Object



132
133
134
135
# File 'lib/tensor_stream/tensor.rb', line 132

# Builds a :less_equal comparison op between this tensor and +other+.
#
# Both operands go through TensorStream.check_data_types first; only the second
# element of its result is used.
def <=(other)
  _, rhs = TensorStream.check_data_types(self, other)
  _op(:less_equal, self, rhs)
end

#==(other) ⇒ Object



107
108
109
110
# File 'lib/tensor_stream/tensor.rb', line 107

# Builds an :equal comparison op between this tensor and +other+.
#
# This returns a graph op, not a Ruby boolean. Both operands go through
# TensorStream.check_data_types first; only the second element of its result is used.
def ==(other)
  _, rhs = TensorStream.check_data_types(self, other)
  _op(:equal, self, rhs)
end

#>(other) ⇒ Object



122
123
124
125
# File 'lib/tensor_stream/tensor.rb', line 122

# Builds a :greater comparison op between this tensor and +other+.
#
# Both operands go through TensorStream.check_data_types first; only the second
# element of its result is used.
def >(other)
  _, rhs = TensorStream.check_data_types(self, other)
  _op(:greater, self, rhs)
end

#>=(other) ⇒ Object



127
128
129
130
# File 'lib/tensor_stream/tensor.rb', line 127

# Builds a :greater_equal comparison op between this tensor and +other+.
#
# Both operands go through TensorStream.check_data_types first; only the second
# element of its result is used.
def >=(other)
  _, rhs = TensorStream.check_data_types(self, other)
  _op(:greater_equal, self, rhs)
end

#[](index) ⇒ Object



62
63
64
# File 'lib/tensor_stream/tensor.rb', line 62

# Builds an :index op selecting +index+ from this tensor.
def [](index)
  source_tensor = self
  _op(:index, source_tensor, index)
end

#and(other) ⇒ Object



137
138
139
140
# File 'lib/tensor_stream/tensor.rb', line 137

# Builds a :logical_and op between this tensor and +other+.
#
# Both operands go through TensorStream.check_data_types first; only the second
# element of its result is used.
def and(other)
  _, rhs = TensorStream.check_data_types(self, other)
  _op(:logical_and, self, rhs)
end

#auto_math(tensor, name_only = false, max_depth = 99, cur_depth = 0) ⇒ Object



220
221
222
# File 'lib/tensor_stream/tensor.rb', line 220

# Renders +tensor+ via Tensor#to_math when it is a Tensor; any other value is
# returned unchanged.
#
# @param tensor [Object] a Tensor or a plain value
# @param name_only [Boolean] forwarded to to_math
# @param max_depth [Integer] forwarded to to_math
# @param cur_depth [Integer] forwarded to to_math
# @return [Object] the rendered form, or +tensor+ itself
def auto_math(tensor, name_only = false, max_depth = 99, cur_depth = 0)
  return tensor unless tensor.is_a?(Tensor)

  tensor.to_math(name_only, max_depth, cur_depth)
end

#breakpoint!(&_block) ⇒ Object



278
279
280
# File 'lib/tensor_stream/tensor.rb', line 278

# Accepts a block but currently ignores it and returns the receiver unchanged.
# NOTE(review): @breakpoint is not touched here — confirm this no-op is intentional.
#
# @return [self]
def breakpoint!(&_ignored_block)
  self
end

#ceil ⇒ Object



99
100
101
# File 'lib/tensor_stream/tensor.rb', line 99

# Delegates to TensorStream.ceil with this tensor as the argument.
def ceil
  input = self
  TensorStream.ceil(input)
end

#collect(&block) ⇒ Object



168
169
170
# File 'lib/tensor_stream/tensor.rb', line 168

# Maps the given block over the stored literal @value (not over graph outputs).
#
# @return [Object] whatever @value.collect returns
def collect(&mapper)
  stored = @value
  stored.collect(&mapper)
end

#dot(other) ⇒ Object



147
148
149
150
# File 'lib/tensor_stream/tensor.rb', line 147

# Builds a :mat_mul op between this tensor and +other+ (same op as #matmul).
#
# Both operands go through TensorStream.check_data_types first; only the second
# element of its result is used.
def dot(other)
  _, rhs = TensorStream.check_data_types(self, other)
  _op(:mat_mul, self, rhs)
end

#dtype ⇒ Object



47
48
49
# File 'lib/tensor_stream/tensor.rb', line 47

# Returns the element dtype; reads the same instance variable as #data_type.
#
# @return [Object, nil] the current @data_type, or nil if never assigned
def dtype
  defined?(@data_type) ? @data_type : nil
end

#eval(options = {}) ⇒ Object



180
181
182
# File 'lib/tensor_stream/tensor.rb', line 180

# Evaluates this tensor in Session.default_session.
#
# @param options [Hash] forwarded unchanged to the session's run
# @return [Object] whatever the session's run returns
def eval(options = {})
  session = Session.default_session
  session.run(self, options)
end

#first ⇒ Object



206
207
208
# File 'lib/tensor_stream/tensor.rb', line 206

# Builds an :index op selecting position 0 of this tensor.
def first
  leading_index = 0
  _op(:index, self, leading_index)
end

#floor ⇒ Object



95
96
97
# File 'lib/tensor_stream/tensor.rb', line 95

# Delegates to TensorStream.floor with this tensor as the argument.
def floor
  input = self
  TensorStream.floor(input)
end

#internal? ⇒ Boolean

Returns:

  • (Boolean)


43
44
45
# File 'lib/tensor_stream/tensor.rb', line 43

# Whether the :internal option was truthy.
#
# @return [Boolean]
def internal?
  @internal ? true : false
end

#matmul(other) ⇒ Object



142
143
144
145
# File 'lib/tensor_stream/tensor.rb', line 142

# Builds a :mat_mul op between this tensor and +other+.
#
# Both operands go through TensorStream.check_data_types first; only the second
# element of its result is used.
def matmul(other)
  _, rhs = TensorStream.check_data_types(self, other)
  _op(:mat_mul, self, rhs)
end

#op ⇒ Object



176
177
178
# File 'lib/tensor_stream/tensor.rb', line 176

# Wraps this tensor in a :const op when is_const is truthy, otherwise in a
# :variable op; either way the op is named after this tensor.
def op
  op_type = is_const ? :const : :variable
  _op(op_type, self, nil, name: name)
end

#print!(message) ⇒ Object



282
283
284
# File 'lib/tensor_stream/tensor.rb', line 282

# Builds a :print op that takes this tensor as both of its inputs and carries
# +message+ as an option.
def print!(message)
  subject = self
  _op(:print, subject, subject, message: message)
end

#reduce(op_type) ⇒ Object

Applies a reduction to the tensor: `:+` maps to a sum and `:*` to a product.



154
155
156
157
158
159
160
161
162
163
164
165
166
# File 'lib/tensor_stream/tensor.rb', line 154

# Builds a reduction op over this tensor.
#
# @param op_type [Symbol, String] :+ (sum) or :* (prod)
# @raise [RuntimeError] for any other op_type, or when a block is given
#   (the op_type check runs first)
def reduce(op_type)
  reductions = { :+ => :sum, :* => :prod }
  reduce_op = reductions.fetch(op_type.to_sym) { raise "unsupported reduce op type #{op_type}" }
  raise "blocks are not supported for tensors" if block_given?

  _op(reduce_op, self, nil)
end

#to_a ⇒ Object



198
199
200
# File 'lib/tensor_stream/tensor.rb', line 198

# Returns the stored literal @value as-is; no Array conversion is performed.
#
# @return [Object, nil]
def to_a
  defined?(@value) ? @value : nil
end

#to_f ⇒ Object



202
203
204
# File 'lib/tensor_stream/tensor.rb', line 202

# Returns the stored literal @value as-is. No Float conversion is performed, so
# an Array value comes back as an Array.
#
# @return [Object, nil]
def to_f
  defined?(@value) ? @value : nil
end

#to_h ⇒ Object



184
185
186
187
188
189
190
191
192
# File 'lib/tensor_stream/tensor.rb', line 184

# Serializes the tensor's metadata and value into a Hash.
#
# @return [Hash] keys :name, :value (via hashify_tensor), :dtype, :shape, :const (Boolean)
def to_h
  serialized_value = hashify_tensor(@value)
  {
    name: @name,
    value: serialized_value,
    dtype: @data_type,
    shape: @shape,
    const: is_const ? true : false,
  }
end

#to_i ⇒ Object



194
195
196
# File 'lib/tensor_stream/tensor.rb', line 194

# Returns the stored literal @value as-is. No Integer conversion is performed, so
# an Array value comes back as an Array.
#
# @return [Object, nil]
def to_i
  defined?(@value) ? @value : nil
end

#to_math(name_only = false, max_depth = 99, _unused = 0) ⇒ Object



210
211
212
213
214
215
216
217
218
# File 'lib/tensor_stream/tensor.rb', line 210

# Renders the tensor for display.
#
# Returns @name when the depth budget is exhausted, when name_only is set, or when
# there is no literal value. An Array value is rendered element-wise, with nested
# Tensors rendered at one less depth. A scalar value is shown only for constants;
# otherwise @name is returned.
#
# @param name_only [Boolean]
# @param max_depth [Integer] remaining nesting budget
# @param _unused [Integer] ignored
def to_math(name_only = false, max_depth = 99, _unused = 0)
  return @name if max_depth.zero? || name_only || @value.nil?
  return (is_const ? @value : @name) unless @value.is_a?(Array)

  @value.map { |item| item.is_a?(Tensor) ? item.to_math(name_only, max_depth - 1) : item }
end

#to_s ⇒ Object



172
173
174
# File 'lib/tensor_stream/tensor.rb', line 172

# String form of the tensor: its scoped @name.
#
# @return [Object, nil] the current @name, or nil if never assigned
def to_s
  defined?(@name) ? @name : nil
end

#zero? ⇒ Boolean

Returns:

  • (Boolean)


103
104
105
# File 'lib/tensor_stream/tensor.rb', line 103

# Builds an :equal op comparing this tensor with a zero constant of the same
# data_type. This returns a graph op, not a Ruby boolean.
def zero?
  zero_const = TensorStream.constant(0, dtype: data_type, name: 'equal/is_zero?')
  _op(:equal, self, zero_const)
end