Class: TensorStream::Tensor
- Inherits:
-
Object
- Object
- TensorStream::Tensor
show all
- Includes:
- OpHelper
- Defined in:
- lib/tensor_stream/tensor.rb
Overview
Base class that defines a tensor like interface
Instance Attribute Summary collapse
Class Method Summary
collapse
Instance Method Summary
collapse
Methods included from OpHelper
#_op, #cons, #dtype_eval, #format_source, #fp_type?, #i_cons, #i_op, #shape_eval, #val_to_dtype
Constructor Details
#initialize(data_type, rank, shape, options = {}) ⇒ Tensor
Returns a new instance of Tensor.
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
|
# File 'lib/tensor_stream/tensor.rb', line 12
def initialize(data_type, rank, shape, options = {})
@data_type = data_type
@rank = rank
@breakpoint = false
@shape = TensorShape.new(shape, rank)
@value = nil
@source = format_source(caller_locations)
@is_const = options[:const] || false
@internal = options[:internal]
@graph = options[:graph] || TensorStream.get_default_graph
@name = [@graph.get_name_scope, options[:name] || build_name].compact.reject(&:empty?).join('/')
@given_name = @name
if options[:value]
if options[:value].is_a?(Array)
options[:value] = reshape(options[:value], shape.reverse.dup) if shape.size >= 2 && !options[:value].empty? && !options[:value][0].is_a?(Array)
@value = options[:value].collect do |v|
v.is_a?(Tensor) ? Tensor.cast_dtype(v, data_type) : v
end
elsif !shape.empty?
@value = reshape(Tensor.cast_dtype(options[:value], @data_type), shape.dup)
else
@value = Tensor.cast_dtype(options[:value], @data_type)
end
end
@graph.add_node(self)
end
|
Instance Attribute Details
#breakpoint ⇒ Object
Returns the value of attribute breakpoint.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
def breakpoint
@breakpoint
end
|
#consumers ⇒ Object
Returns the value of attribute consumers.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
def consumers
@consumers
end
|
#data_type ⇒ Object
Returns the value of attribute data_type.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
def data_type
@data_type
end
|
#given_name ⇒ Object
Returns the value of attribute given_name.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
def given_name
@given_name
end
|
#graph ⇒ Object
Returns the value of attribute graph.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
def graph
@graph
end
|
#internal ⇒ Object
Returns the value of attribute internal.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
def internal
@internal
end
|
#is_const ⇒ Object
Returns the value of attribute is_const.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
def is_const
@is_const
end
|
#name ⇒ Object
Returns the value of attribute name.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
def name
@name
end
|
#native_buffer ⇒ Object
Returns the value of attribute native_buffer.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
def native_buffer
@native_buffer
end
|
#rank ⇒ Object
Returns the value of attribute rank.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
def rank
@rank
end
|
#shape ⇒ Object
Returns the value of attribute shape.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
def shape
@shape
end
|
#source ⇒ Object
Returns the value of attribute source.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
def source
@source
end
|
#value ⇒ Object
Returns the value of attribute value.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
def value
@value
end
|
Class Method Details
.cast_dtype(val, dtype) ⇒ Object
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
|
# File 'lib/tensor_stream/tensor.rb', line 192
def self.cast_dtype(val, dtype)
return val if dtype.nil?
return val if val.is_a?(Tensor)
if val.is_a?(Array)
return val.collect do |v|
cast_dtype(v, dtype)
end
end
case dtype.to_sym
when :float32, :float
if !!val == val
val ? 1.0 : 0.0
else
val.to_f
end
when :string
val.to_s
when :int32, :int16
if !!val == val
val ? 1 : 0
else
val.to_i
end
when :boolean
!!val
when :unknown
val
else
raise "unknown data_type #{dtype} passed"
end
end
|
.detect_type(value) ⇒ Object
178
179
180
181
182
183
184
185
186
187
188
189
190
|
# File 'lib/tensor_stream/tensor.rb', line 178
def self.detect_type(value)
if value.is_a?(String)
:string
elsif value.is_a?(Float)
:float32
elsif value.is_a?(Integer)
:int32
elsif value.is_a?(Array)
:array
else
:float32
end
end
|
.reset_counters ⇒ Object
51
52
53
54
55
|
# File 'lib/tensor_stream/tensor.rb', line 51
def self.reset_counters
@const_counter = 0
@var_counter = 0
@placeholder_counter = 0
end
|
Instance Method Details
#!=(other) ⇒ Object
93
94
95
|
# File 'lib/tensor_stream/tensor.rb', line 93
def !=(other)
_op(:not_equal, self, other)
end
|
#*(other) ⇒ Object
65
66
67
|
# File 'lib/tensor_stream/tensor.rb', line 65
def *(other)
TensorStream::Operation.new(:mul, self, TensorStream.convert_to_tensor(other, dtype: data_type))
end
|
#**(other) ⇒ Object
69
70
71
|
# File 'lib/tensor_stream/tensor.rb', line 69
def **(other)
TensorStream::Operation.new(:pow, self, TensorStream.convert_to_tensor(other, dtype: data_type))
end
|
#+(other) ⇒ Object
57
58
59
|
# File 'lib/tensor_stream/tensor.rb', line 57
def +(other)
TensorStream::Operation.new(:add, self, TensorStream.convert_to_tensor(other, dtype: data_type))
end
|
#-(other) ⇒ Object
77
78
79
|
# File 'lib/tensor_stream/tensor.rb', line 77
def -(other)
TensorStream::Operation.new(:sub, self, TensorStream.convert_to_tensor(other, dtype: data_type))
end
|
#-@ ⇒ Object
81
82
83
|
# File 'lib/tensor_stream/tensor.rb', line 81
def -@
TensorStream::Operation.new(:negate, self, nil)
end
|
#/(other) ⇒ Object
73
74
75
|
# File 'lib/tensor_stream/tensor.rb', line 73
def /(other)
TensorStream::Operation.new(:div, self, TensorStream.convert_to_tensor(other, dtype: data_type))
end
|
#<(other) ⇒ Object
89
90
91
|
# File 'lib/tensor_stream/tensor.rb', line 89
def <(other)
_op(:less, self, other)
end
|
#<=(other) ⇒ Object
105
106
107
|
# File 'lib/tensor_stream/tensor.rb', line 105
def <=(other)
_op(:less_equal, self, other)
end
|
#==(other) ⇒ Object
85
86
87
|
# File 'lib/tensor_stream/tensor.rb', line 85
def ==(other)
_op(:equal, self, other)
end
|
#>(other) ⇒ Object
97
98
99
|
# File 'lib/tensor_stream/tensor.rb', line 97
def >(other)
_op(:greater, self, other)
end
|
#>=(other) ⇒ Object
101
102
103
|
# File 'lib/tensor_stream/tensor.rb', line 101
def >=(other)
_op(:greater_equal, self, other)
end
|
#[](index) ⇒ Object
61
62
63
|
# File 'lib/tensor_stream/tensor.rb', line 61
def [](index)
TensorStream::Operation.new(:index, self, index)
end
|
#and(other) ⇒ Object
109
110
111
|
# File 'lib/tensor_stream/tensor.rb', line 109
def and(other)
_op(:logical_and, self, other)
end
|
#auto_math(tensor, name_only = false, max_depth = 99, _cur_depth = 0) ⇒ Object
174
175
176
|
# File 'lib/tensor_stream/tensor.rb', line 174
def auto_math(tensor, name_only = false, max_depth = 99, _cur_depth = 0)
tensor.is_a?(Tensor) ? tensor.to_math(name_only, max_depth, _cur_depth) : tensor
end
|
#breakpoint!(&block) ⇒ Object
226
227
228
|
# File 'lib/tensor_stream/tensor.rb', line 226
def breakpoint!(&block)
self
end
|
#collect(&block) ⇒ Object
122
123
124
|
# File 'lib/tensor_stream/tensor.rb', line 122
def collect(&block)
@value.collect(&block)
end
|
#dot(other) ⇒ Object
118
119
120
|
# File 'lib/tensor_stream/tensor.rb', line 118
def dot(other)
_op(:matmul, self, other)
end
|
#dtype ⇒ Object
47
48
49
|
# File 'lib/tensor_stream/tensor.rb', line 47
def dtype
@data_type
end
|
#eval(options = {}) ⇒ Object
134
135
136
|
# File 'lib/tensor_stream/tensor.rb', line 134
def eval(options = {})
Session.default_session.run(self, options)
end
|
#first ⇒ Object
160
161
162
|
# File 'lib/tensor_stream/tensor.rb', line 160
def first
_op(:index, self, 0)
end
|
#internal? ⇒ Boolean
43
44
45
|
# File 'lib/tensor_stream/tensor.rb', line 43
def internal?
!!@internal
end
|
#matmul(other) ⇒ Object
113
114
115
|
# File 'lib/tensor_stream/tensor.rb', line 113
def matmul(other)
_op(:matmul, self, other)
end
|
#op ⇒ Object
130
131
132
|
# File 'lib/tensor_stream/tensor.rb', line 130
def op
is_const ? _op(:const, self, nil, name: self.name) : _op(:variable, self, nil, name: self.name)
end
|
#print!(message) ⇒ Object
230
231
232
|
# File 'lib/tensor_stream/tensor.rb', line 230
def print!(message)
_op(:print, self, self, message: message)
end
|
#to_a ⇒ Object
152
153
154
|
# File 'lib/tensor_stream/tensor.rb', line 152
def to_a
@value
end
|
#to_f ⇒ Object
156
157
158
|
# File 'lib/tensor_stream/tensor.rb', line 156
def to_f
@value
end
|
#to_h ⇒ Object
138
139
140
141
142
143
144
145
146
|
# File 'lib/tensor_stream/tensor.rb', line 138
def to_h
{
name: @name,
value: hashify_tensor(@value),
dtype: @data_type,
shape: @shape,
const: !!is_const,
}
end
|
#to_i ⇒ Object
148
149
150
|
# File 'lib/tensor_stream/tensor.rb', line 148
def to_i
@value
end
|
#to_math(name_only = false, max_depth = 99, _unused = 0) ⇒ Object
164
165
166
167
168
169
170
171
172
|
# File 'lib/tensor_stream/tensor.rb', line 164
def to_math(name_only = false, max_depth = 99, _unused = 0)
return @name if max_depth.zero? || name_only || @value.nil?
if @value.is_a?(Array)
@value.collect { |v| v.is_a?(Tensor) ? v.to_math(name_only, max_depth - 1) : v }
else
is_const ? @value : @name
end
end
|
#to_s ⇒ Object
126
127
128
|
# File 'lib/tensor_stream/tensor.rb', line 126
def to_s
@name
end
|