Class: TensorStream::Tensor

Inherits:
Object
  • Object
show all
Includes:
OpHelper
Defined in:
lib/tensor_stream/tensor.rb

Overview

Base class that defines a tensor-like interface.

Direct Known Subclasses

Operation, Placeholder, Variable

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Methods included from OpHelper

#cons, #dtype_eval, #fp_type?, #i_cons, #i_op, #op, #shape_eval, #val_to_dtype

Constructor Details

#initialize(data_type, rank, shape, options = {}) ⇒ Tensor

Returns a new instance of Tensor.



10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# File 'lib/tensor_stream/tensor.rb', line 10

# Creates a tensor node and registers it with its graph.
#
# @param data_type [Symbol, String] element dtype; used with Tensor.cast_dtype
#   when an initial :value is supplied.
# @param rank forwarded to TensorShape.new together with +shape+.
# @param shape [Array] dimensions; +size+, +reverse+, +empty?+ and +dup+ are
#   called on it below, so an Array-like is expected.
# @param options [Hash] recognised keys:
#   :const    - stored in @is_const (false when absent/falsy)
#   :internal - stored as-is in @internal (see #internal?)
#   :graph    - graph to register with (default TensorStream.get_default_graph)
#   :name     - explicit name (default: build_name, defined elsewhere)
#   :value    - optional initial value, processed as described below
def initialize(data_type, rank, shape, options = {})
  @data_type = data_type
  @rank = rank
  @breakpoint = false
  @shape = TensorShape.new(shape, rank)
  @value = nil
  # Captures the current call stack (caller_locations) via format_source,
  # which is defined elsewhere.
  @source = format_source(caller_locations)
  @is_const = options[:const] || false
  @internal = options[:internal]
  @graph = options[:graph] || TensorStream.get_default_graph
  @name = options[:name] || build_name
  @given_name = @name

  # NOTE(review): this is a truthiness test, so value: false (e.g. for a
  # boolean tensor) is treated like "no value" and @value stays nil — confirm
  # that is intended.
  if options[:value]
    if options[:value].is_a?(Array)
      # check if a single-dimension array is passed; if so, and the shape has
      # two or more dims, reshape it using the reversed shape (reshape is
      # defined elsewhere).
      # NOTE(review): this reassigns options[:value], mutating the hash the
      # caller passed in.
      options[:value] = reshape(options[:value], shape.reverse.dup) if shape.size >= 2 && !options[:value].empty? && !options[:value][0].is_a?(Array)

      # NOTE(review): Tensor.cast_dtype returns Tensor arguments unchanged and
      # non-Tensor elements are never passed to it, so this collect is in
      # effect a shallow copy of the array — verify whether casting plain
      # elements to data_type was intended.
      @value = options[:value].collect do |v|
        v.is_a?(Tensor) ? Tensor.cast_dtype(v, data_type) : v
      end
    elsif !shape.empty?
      # Non-array value with a non-empty shape: cast, then reshape.
      @value = reshape(Tensor.cast_dtype(options[:value], @data_type), shape.dup)
    else
      # Scalar value with an empty shape: cast only.
      @value = Tensor.cast_dtype(options[:value], @data_type)
    end
  end

  # Register this node with the graph chosen above.
  @graph.add_node(self)
end

Instance Attribute Details

#breakpoint ⇒ Object

Returns the value of attribute breakpoint.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Breakpoint hook: false right after construction, otherwise the block
# (possibly nil) most recently passed to #breakpoint!.
def breakpoint
  @breakpoint
end

#data_type ⇒ Object

Returns the value of attribute data_type.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Element dtype given to the constructor (the same value #dtype returns).
def data_type
  @data_type
end

#given_name ⇒ Object

Returns the value of attribute given_name.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Name assigned at construction (the :name option, or build_name when absent).
# #initialize sets it equal to @name.
def given_name
  @given_name
end

#graph ⇒ Object

Returns the value of attribute graph.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Graph this tensor was registered with: the :graph option, or
# TensorStream.get_default_graph when none was given.
def graph
  @graph
end

#internal ⇒ Object

Returns the value of attribute internal.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Raw :internal option passed to the constructor (nil when absent);
# #internal? gives a strict boolean view of it.
def internal
  @internal
end

#is_const ⇒ Object

Returns the value of attribute is_const.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# The :const option as given to the constructor, or false when it was
# absent/falsy. Not coerced to a strict boolean.
def is_const
  @is_const
end

#name ⇒ Object

Returns the value of attribute name.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Tensor name: the :name option, or the result of build_name (defined
# elsewhere) when none was given.
def name
  @name
end

#native_buffer ⇒ Object

Returns the value of attribute native_buffer.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Reader for @native_buffer. Nothing in this class's visible code assigns it,
# so it is nil unless set elsewhere (presumably by an evaluator/backend —
# TODO confirm).
def native_buffer
  @native_buffer
end

#rank ⇒ Object

Returns the value of attribute rank.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Rank passed to the constructor.
def rank
  @rank
end

#shape ⇒ Object

Returns the value of attribute shape.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# TensorShape built from the constructor's +shape+ and +rank+ (not the raw
# shape array).
def shape
  @shape
end

#source ⇒ Object

Returns the value of attribute source.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Call-stack description captured at construction time via
# format_source(caller_locations).
def source
  @source
end

#value ⇒ Object

Returns the value of attribute value.



8
9
10
# File 'lib/tensor_stream/tensor.rb', line 8

# Stored value: nil unless a truthy :value option was given to the
# constructor, in which case it holds the (possibly cast/reshaped) value.
def value
  @value
end

Class Method Details

.cast_dtype(val, dtype) ⇒ Object



173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# File 'lib/tensor_stream/tensor.rb', line 173

# Converts +val+ to the Ruby representation of +dtype+.
#
# Tensors are returned untouched. Arrays are converted element-wise
# (recursively), so nested arrays keep their structure.
#
# Float dtypes (:float16, :float32, :float64, :float) produce Float values.
# Integer dtypes (:int8, :int16, :int32, :int64, :int) produce Integer values.
# Ruby Integers do not overflow, so the narrower int widths are not
# range-checked (same as the original :int16 handling). The literal booleans
# map to 1.0/0.0 and 1/0; other values go through to_f / to_i.
#
# @param val [Object] scalar, (nested) Array, or Tensor
# @param dtype [Symbol, String] target dtype (anything responding to #to_sym)
# @return [Object] the converted value; +val+ unchanged for :unknown
# @raise [RuntimeError] when +dtype+ is not a supported dtype
def self.cast_dtype(val, dtype)
  return val if val.is_a?(Tensor)
  return val.map { |v| cast_dtype(v, dtype) } if val.is_a?(Array)

  # Identity check: only the true/false singletons get the 1/0 treatment.
  boolean = val.equal?(true) || val.equal?(false)

  case dtype.to_sym
  when :float32, :float, :float16, :float64
    if boolean
      val ? 1.0 : 0.0
    else
      val.to_f
    end
  when :string
    val.to_s
  when :int32, :int16, :int8, :int64, :int
    if boolean
      val ? 1 : 0
    else
      val.to_i
    end
  when :boolean
    !!val
  when :unknown
    val
  else
    raise "unknown data_type #{dtype} passed"
  end
end

.detect_type(value) ⇒ Object



159
160
161
162
163
164
165
166
167
168
169
170
171
# File 'lib/tensor_stream/tensor.rb', line 159

# Infers a dtype symbol from a plain Ruby value.
#
# @param value [Object]
# @return [Symbol] :string for String, :int32 for Integer, :array for Array;
#   :float32 for Float and for every other value (including true/false/nil)
def self.detect_type(value)
  case value
  when String  then :string
  when Integer then :int32
  when Array   then :array
  else :float32 # Float and any unrecognized type
  end
end

.reset_counters ⇒ Object



49
50
51
52
53
# File 'lib/tensor_stream/tensor.rb', line 49

# Zeroes the class-level @const_counter, @var_counter and
# @placeholder_counter instance variables. Returns 0.
def self.reset_counters
  @const_counter = @var_counter = @placeholder_counter = 0
end

Instance Method Details

#!=(other) ⇒ Object



91
92
93
# File 'lib/tensor_stream/tensor.rb', line 91

# Element-wise :not_equal comparison, built through OpHelper#op. The result
# is whatever #op returns (presumably a graph Operation, not a Ruby
# Boolean — confirm).
def !=(other)
  op(:not_equal, self, other)
end

#*(other) ⇒ Object



63
64
65
# File 'lib/tensor_stream/tensor.rb', line 63

# Builds a :mul Operation with this tensor as the left operand. +other+ is
# passed through auto_wrap (defined elsewhere; presumably lifts plain Ruby
# values into tensors — confirm).
def *(other)
  TensorStream::Operation.new(:mul, self, auto_wrap(other))
end

#**(other) ⇒ Object



67
68
69
# File 'lib/tensor_stream/tensor.rb', line 67

# Builds a :pow Operation (self raised to +other+). +other+ is passed through
# auto_wrap, which is defined elsewhere.
def **(other)
  TensorStream::Operation.new(:pow, self, auto_wrap(other))
end

#+(other) ⇒ Object



55
56
57
# File 'lib/tensor_stream/tensor.rb', line 55

# Builds an :add Operation with this tensor as the left operand. +other+ is
# passed through auto_wrap, which is defined elsewhere.
def +(other)
  TensorStream::Operation.new(:add, self, auto_wrap(other))
end

#-(other) ⇒ Object



75
76
77
# File 'lib/tensor_stream/tensor.rb', line 75

# Builds a :sub Operation (self minus +other+). +other+ is passed through
# auto_wrap, which is defined elsewhere.
def -(other)
  TensorStream::Operation.new(:sub, self, auto_wrap(other))
end

#-@ ⇒ Object



79
80
81
# File 'lib/tensor_stream/tensor.rb', line 79

# Unary minus: builds a :negate Operation on this tensor (no second operand,
# hence the nil).
def -@
  TensorStream::Operation.new(:negate, self, nil)
end

#/(other) ⇒ Object



71
72
73
# File 'lib/tensor_stream/tensor.rb', line 71

# Builds a :div Operation (self divided by +other+). +other+ is passed through
# auto_wrap, which is defined elsewhere.
def /(other)
  TensorStream::Operation.new(:div, self, auto_wrap(other))
end

#<(other) ⇒ Object



87
88
89
# File 'lib/tensor_stream/tensor.rb', line 87

# Element-wise :less comparison (self < other), built through OpHelper#op.
def <(other)
  op(:less, self, other)
end

#<=(other) ⇒ Object



103
104
105
# File 'lib/tensor_stream/tensor.rb', line 103

# Element-wise :less_equal comparison (self <= other), built through
# OpHelper#op.
def <=(other)
  op(:less_equal, self, other)
end

#==(other) ⇒ Object



83
84
85
# File 'lib/tensor_stream/tensor.rb', line 83

# Element-wise :equal comparison, built through OpHelper#op.
#
# NOTE(review): because this overrides ==, Ruby code that relies on object
# equality (e.g. Array#include?, Array#index, Array#delete) receives whatever
# #op returns instead of true/false. If that is an Operation object it is
# always truthy — confirm callers never compare tensors that way.
def ==(other)
  op(:equal, self, other)
end

#>(other) ⇒ Object



95
96
97
# File 'lib/tensor_stream/tensor.rb', line 95

# Element-wise :greater comparison (self > other), built through OpHelper#op.
def >(other)
  op(:greater, self, other)
end

#>=(other) ⇒ Object



99
100
101
# File 'lib/tensor_stream/tensor.rb', line 99

# Element-wise :greater_equal comparison (self >= other), built through
# OpHelper#op.
def >=(other)
  op(:greater_equal, self, other)
end

#[](index) ⇒ Object



59
60
61
# File 'lib/tensor_stream/tensor.rb', line 59

# Builds an :index Operation selecting +index+ from this tensor. Unlike the
# arithmetic operators, +index+ is passed as-is (not through auto_wrap).
def [](index)
  TensorStream::Operation.new(:index, self, index)
end

#auto_math(tensor, name_only = false, max_depth = 99) ⇒ Object



155
156
157
# File 'lib/tensor_stream/tensor.rb', line 155

# Renders +tensor+ with Tensor#to_math when it is a Tensor; any other value
# is returned unchanged.
#
# @param tensor [Object] a Tensor or a plain value
# @param name_only [Boolean] forwarded to #to_math
# @param max_depth [Integer] forwarded to #to_math
def auto_math(tensor, name_only = false, max_depth = 99)
  return tensor unless tensor.is_a?(Tensor)

  tensor.to_math(name_only, max_depth)
end

#breakpoint!(&block) ⇒ Object



206
207
208
209
# File 'lib/tensor_stream/tensor.rb', line 206

# Installs +block+ as this tensor's breakpoint hook and returns self so the
# call can be chained. Calling it without a block stores nil.
def breakpoint!(&block)
  tap { @breakpoint = block }
end

#collect(&block) ⇒ Object



107
108
109
# File 'lib/tensor_stream/tensor.rb', line 107

# Delegates #collect to the stored @value. This only works when @value
# responds to #collect (e.g. an Array); without a block, whatever @value's
# #collect returns is passed back.
def collect(&block)
  @value.collect(&block)
end

#dtype ⇒ Object



45
46
47
# File 'lib/tensor_stream/tensor.rb', line 45

# Element dtype of this tensor (the same value as #data_type).
def dtype
  @data_type
end

#eval(options = {}) ⇒ Object



115
116
117
# File 'lib/tensor_stream/tensor.rb', line 115

# Evaluates this tensor with Session.default_session (defined elsewhere) and
# returns the result of Session#run. +options+ is forwarded unchanged.
# Note: on Tensor instances this shadows Kernel#eval.
def eval(options = {})
  Session.default_session.run(self, options)
end

#first ⇒ Object



141
142
143
# File 'lib/tensor_stream/tensor.rb', line 141

# Symbolic :index op selecting element 0, built through OpHelper#op (whereas
# #[] constructs the Operation directly).
def first
  op(:index, self, 0)
end

#internal? ⇒ Boolean

Returns:

  • (Boolean)


41
42
43
# File 'lib/tensor_stream/tensor.rb', line 41

# Whether the tensor was created with a truthy :internal option.
#
# @return [Boolean] strictly true or false
def internal?
  @internal ? true : false
end

#to_a ⇒ Object



133
134
135
# File 'lib/tensor_stream/tensor.rb', line 133

# Returns the stored @value as-is.
# NOTE(review): no conversion is performed, so this may return a scalar or
# nil rather than an Array.
def to_a
  @value
end

#to_f ⇒ Object



137
138
139
# File 'lib/tensor_stream/tensor.rb', line 137

# Returns the stored @value as-is.
# NOTE(review): no Float conversion is performed, so this can return an
# Integer, Array or nil. Kernel#Float(tensor) would then raise TypeError
# because to_f does not return a Float.
def to_f
  @value
end

#to_h ⇒ Object



119
120
121
122
123
124
125
126
127
# File 'lib/tensor_stream/tensor.rb', line 119

# Serializes the tensor's metadata into a Hash with keys, in order:
# :name, :value (passed through hashify_tensor, defined elsewhere), :dtype,
# :shape (the TensorShape object itself) and :const (a strict boolean).
def to_h
  serialized_value = hashify_tensor(@value)
  constant = is_const ? true : false

  {
    name: @name,
    value: serialized_value,
    dtype: @data_type,
    shape: @shape,
    const: constant,
  }
end

#to_i ⇒ Object



129
130
131
# File 'lib/tensor_stream/tensor.rb', line 129

# Returns the stored @value as-is.
# NOTE(review): no Integer conversion is performed, so this can return a
# Float, Array or nil.
def to_i
  @value
end

#to_math(name_only = false, max_depth = 99) ⇒ Object



145
146
147
148
149
150
151
152
153
# File 'lib/tensor_stream/tensor.rb', line 145

# Produces a printable representation of the tensor.
#
# Returns the tensor's name when +max_depth+ is zero, when +name_only+ is
# set, or when there is no value. An Array value is rendered element-wise:
# nested Tensors recurse with one less level of depth, other elements are
# kept as-is. A scalar value is shown only for constant tensors; otherwise
# the name is returned.
#
# @param name_only [Boolean]
# @param max_depth [Integer] recursion budget for nested tensors
def to_math(name_only = false, max_depth = 99)
  return @name if max_depth.zero? || name_only || @value.nil?
  return (is_const ? @value : @name) unless @value.is_a?(Array)

  @value.map do |element|
    element.is_a?(Tensor) ? element.to_math(name_only, max_depth - 1) : element
  end
end

#to_s ⇒ Object



111
112
113
# File 'lib/tensor_stream/tensor.rb', line 111

# String form of a tensor is its name (the :name option or build_name).
def to_s
  @name
end