Class: TensorStream::Tensor
- Inherits: Object
- Ancestry: Object > TensorStream::Tensor
- Includes:
- OpHelper
- Defined in:
- lib/tensor_stream/tensor.rb
Overview
Base class that defines a tensor-like interface.
Instance Attribute Summary
Class Method Summary
Instance Method Summary
Methods included from OpHelper
#cons, #dtype_eval, #fp_type?, #i_cons, #i_op, #op, #shape_eval, #val_to_dtype
Constructor Details
#initialize(data_type, rank, shape, options = {}) ⇒ Tensor
Returns a new instance of Tensor.
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
|
# File 'lib/tensor_stream/tensor.rb', line 10
# Builds a tensor node, optionally stores a literal value, and hands the node
# to its graph via add_node.
#
# @param data_type [Symbol] data type of the tensor; options[:value] is cast to it.
# @param rank [Object] passed to TensorShape together with +shape+
#   (presumably the number of dimensions — TODO confirm).
# @param shape [Array] dimensions of the tensor.
# @param options [Hash]
#   :const    - stored as is_const (defaults to false)
#   :internal - stored as internal
#   :graph    - graph to register with (defaults to TensorStream.get_default_graph)
#   :name     - node name (defaults to build_name); also copied to given_name
#   :value    - optional literal value, cast to +data_type+ (see branches below)
def initialize(data_type, rank, shape, options = {})
@data_type = data_type
@rank = rank
# Stays false until breakpoint! stores a block here.
@breakpoint = false
@shape = TensorShape.new(shape, rank)
@value = nil
# caller_locations is taken directly inside initialize. Keep this call here:
# moving it into a helper would change which stack frames are recorded.
@source = format_source(caller_locations)
@is_const = options[:const] || false
@internal = options[:internal]
@graph = options[:graph] || TensorStream.get_default_graph
@name = options[:name] || build_name
@given_name = @name
# NOTE(review): this is a truthiness test, so value: false (e.g. a boolean
# constant) is skipped and @value stays nil — confirm this is intended.
if options[:value]
if options[:value].is_a?(Array)
# A flat array given for a shape with 2+ dims is reshaped first; reshape
# (defined elsewhere) receives the dims in reversed order.
# NOTE(review): this writes the reshaped array back into the caller's
# options hash.
options[:value] = reshape(options[:value], shape.reverse.dup) if shape.size >= 2 && !options[:value].empty? && !options[:value][0].is_a?(Array)
# NOTE(review): only Tensor elements are passed to cast_dtype, and
# cast_dtype returns Tensors unchanged, so array elements are effectively
# stored as-is (not cast to data_type) — confirm this is intended.
@value = options[:value].collect do |v|
v.is_a?(Tensor) ? Tensor.cast_dtype(v, data_type) : v
end
elsif !shape.empty?
# Scalar value with a non-empty shape: cast, then reshape to the shape.
@value = reshape(Tensor.cast_dtype(options[:value], @data_type), shape.dup)
else
# Scalar value for a scalar (empty) shape.
@value = Tensor.cast_dtype(options[:value], @data_type)
end
end
@graph.add_node(self)
end
|
Instance Attribute Details
#breakpoint ⇒ Object
Returns the value of attribute breakpoint.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
# Reader for the breakpoint slot: false after construction, or whatever
# breakpoint! last stored (the given block, or nil without one).
# @return [Proc, false, nil]
def breakpoint; @breakpoint; end
|
#data_type ⇒ Object
Returns the value of attribute data_type.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
# Reader for the data type symbol supplied to the constructor.
# @return [Symbol]
def data_type; @data_type; end
|
#given_name ⇒ Object
Returns the value of attribute given_name.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
# Reader for the name recorded at construction time (a copy of the initial name).
# @return [Object]
def given_name; @given_name; end
|
#graph ⇒ Object
Returns the value of attribute graph.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
# Reader for the graph this tensor was registered with
# (options[:graph] or the default graph).
# @return [Object]
def graph; @graph; end
|
#internal ⇒ Object
Returns the value of attribute internal.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
# Reader for the raw :internal option passed at construction (not coerced to a boolean).
# @return [Object]
def internal; @internal; end
|
#is_const ⇒ Object
Returns the value of attribute is_const.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
# Reader for the constant flag (options[:const], defaulting to false).
# @return [Object]
def is_const; @is_const; end
|
#name ⇒ Object
Returns the value of attribute name.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
# Reader for the node name (options[:name] or a generated one).
# @return [Object]
def name; @name; end
|
#native_buffer ⇒ Object
Returns the value of attribute native_buffer.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
# Reader for @native_buffer. Nothing in the code shown here assigns it;
# presumably an evaluation backend sets it — TODO confirm.
# @return [Object, nil]
def native_buffer; @native_buffer; end
|
#rank ⇒ Object
Returns the value of attribute rank.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
# Reader for the rank value supplied to the constructor.
# @return [Object]
def rank; @rank; end
|
#shape ⇒ Object
Returns the value of attribute shape.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
# Reader for the TensorShape built from the constructor's shape and rank.
# @return [Object]
def shape; @shape; end
|
#source ⇒ Object
Returns the value of attribute source.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
# Reader for the creation-site description produced by format_source from
# caller_locations at construction.
# @return [Object]
def source; @source; end
|
#value ⇒ Object
Returns the value of attribute value.
8
9
10
|
# File 'lib/tensor_stream/tensor.rb', line 8
# Reader for the literal value stored at construction (nil when none was given).
# @return [Object, nil]
def value; @value; end
|
Class Method Details
.cast_dtype(val, dtype) ⇒ Object
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
|
# File 'lib/tensor_stream/tensor.rb', line 173
# Casts +val+ to the Ruby representation of +dtype+, recursing into nested arrays.
#
# Floating dtypes (:float16, :float32, :float64, :float) produce Floats, and
# integer dtypes (:int8 .. :int64, :uint8 .. :uint64, :int) produce Integers.
# For both, true/false map to 1/0. The names :float16, :float64, :int8, :int64,
# :uint* and :int were previously rejected and are now accepted.
#
# @param val [Object] scalar, (nested) Array, or Tensor. Tensors are returned unchanged.
# @param dtype [Symbol, String] target data type (anything responding to to_sym).
# @return [Object] Float, Integer, String, true/false, +val+ itself (:unknown),
#   or a nested Array of those.
# @raise [RuntimeError] when +dtype+ is not a recognised data type.
def self.cast_dtype(val, dtype)
  return val if val.is_a?(Tensor)
  return val.collect { |v| cast_dtype(v, dtype) } if val.is_a?(Array)

  # True only for the literals true/false: TrueClass/FalseClass#== is identity-based.
  boolean = !!val == val

  case dtype.to_sym
  when :float16, :float32, :float64, :float
    if boolean
      val ? 1.0 : 0.0
    else
      val.to_f
    end
  when :string
    val.to_s
  when :int8, :int16, :int32, :int64, :uint8, :uint16, :uint32, :uint64, :int
    if boolean
      val ? 1 : 0
    else
      val.to_i
    end
  when :boolean
    !!val
  when :unknown
    val
  else
    raise "unknown data_type #{dtype} passed"
  end
end
|
.detect_type(value) ⇒ Object
159
160
161
162
163
164
165
166
167
168
169
170
171
|
# File 'lib/tensor_stream/tensor.rb', line 159
# Infers a data type symbol from a Ruby value.
#
# @param value [Object]
# @return [Symbol] :string for String, :float32 for Float, :int32 for Integer,
#   :array for Array; every other value (including true/false/nil) yields :float32.
def self.detect_type(value)
  case value
  when String then :string
  when Float then :float32
  when Integer then :int32
  when Array then :array
  else :float32
  end
end
|
.reset_counters ⇒ Object
49
50
51
52
53
|
# File 'lib/tensor_stream/tensor.rb', line 49
# Resets the class-level constant, variable and placeholder counters to zero.
# @return [Integer] 0
def self.reset_counters
  @const_counter = @var_counter = @placeholder_counter = 0
end
|
Instance Method Details
#!=(other) ⇒ Object
91
92
93
|
# File 'lib/tensor_stream/tensor.rb', line 91
# Builds a :not_equal comparison node between self and +other+ via op.
# Returns a graph operation, not a boolean.
def !=(other)
  operands = [self, other]
  op(:not_equal, *operands)
end
|
#*(other) ⇒ Object
63
64
65
|
# File 'lib/tensor_stream/tensor.rb', line 63
# Builds a :mul operation on self and +other+ (passed through auto_wrap first).
def *(other)
  rhs = auto_wrap(other)
  TensorStream::Operation.new(:mul, self, rhs)
end
|
#**(other) ⇒ Object
67
68
69
|
# File 'lib/tensor_stream/tensor.rb', line 67
# Builds a :pow operation on self and +other+ (passed through auto_wrap first).
def **(other)
  exponent = auto_wrap(other)
  TensorStream::Operation.new(:pow, self, exponent)
end
|
#+(other) ⇒ Object
55
56
57
|
# File 'lib/tensor_stream/tensor.rb', line 55
# Builds an :add operation on self and +other+ (passed through auto_wrap first).
def +(other)
  rhs = auto_wrap(other)
  TensorStream::Operation.new(:add, self, rhs)
end
|
#-(other) ⇒ Object
75
76
77
|
# File 'lib/tensor_stream/tensor.rb', line 75
# Builds a :sub operation on self and +other+ (passed through auto_wrap first).
def -(other)
  subtrahend = auto_wrap(other)
  TensorStream::Operation.new(:sub, self, subtrahend)
end
|
#/(other) ⇒ Object
71
72
73
|
# File 'lib/tensor_stream/tensor.rb', line 71
# Builds a :div operation on self and +other+ (passed through auto_wrap first).
def /(other)
  divisor = auto_wrap(other)
  TensorStream::Operation.new(:div, self, divisor)
end
|
#<(other) ⇒ Object
87
88
89
|
# File 'lib/tensor_stream/tensor.rb', line 87
# Builds a :less comparison node between self and +other+ via op.
def <(other)
  operands = [self, other]
  op(:less, *operands)
end
|
#<=(other) ⇒ Object
103
104
105
|
# File 'lib/tensor_stream/tensor.rb', line 103
# Builds a :less_equal comparison node between self and +other+ via op.
def <=(other)
  operands = [self, other]
  op(:less_equal, *operands)
end
|
#==(other) ⇒ Object
83
84
85
|
# File 'lib/tensor_stream/tensor.rb', line 83
# Builds an :equal comparison node between self and +other+ via op.
# Returns a graph operation rather than a boolean.
def ==(other)
  operands = [self, other]
  op(:equal, *operands)
end
|
#>(other) ⇒ Object
95
96
97
|
# File 'lib/tensor_stream/tensor.rb', line 95
# Builds a :greater comparison node between self and +other+ via op.
def >(other)
  operands = [self, other]
  op(:greater, *operands)
end
|
#>=(other) ⇒ Object
99
100
101
|
# File 'lib/tensor_stream/tensor.rb', line 99
# Builds a :greater_equal comparison node between self and +other+ via op.
def >=(other)
  operands = [self, other]
  op(:greater_equal, *operands)
end
|
#[](index) ⇒ Object
59
60
61
|
# File 'lib/tensor_stream/tensor.rb', line 59
# Builds an :index operation selecting +index+ from self.
# Unlike the arithmetic operators, +index+ is passed through without auto_wrap.
def [](index)
  operation_class = TensorStream::Operation
  operation_class.new(:index, self, index)
end
|
#auto_math(tensor, name_only = false, max_depth = 99) ⇒ Object
155
156
157
|
# File 'lib/tensor_stream/tensor.rb', line 155
# Renders +tensor+ via its to_math when it is a Tensor; any other value is returned as-is.
#
# @param tensor [Object]
# @param name_only [Boolean] forwarded to Tensor#to_math
# @param max_depth [Integer] forwarded to Tensor#to_math
# @return [Object]
def auto_math(tensor, name_only = false, max_depth = 99)
  return tensor unless tensor.is_a?(Tensor)

  tensor.to_math(name_only, max_depth)
end
|
#breakpoint!(&block) ⇒ Object
206
207
208
209
|
# File 'lib/tensor_stream/tensor.rb', line 206
# Stores the given block (or nil when called without one) as this tensor's
# breakpoint and returns self so the call can be chained.
# @return [self]
def breakpoint!(&block)
  tap { @breakpoint = block }
end
|
#collect(&block) ⇒ Object
107
108
109
|
# File 'lib/tensor_stream/tensor.rb', line 107
# Maps the stored value with the given block (delegates to @value's map/collect).
# @return [Array, Enumerator]
def collect(&block)
  @value.map(&block)
end
|
#dtype ⇒ Object
45
46
47
|
# File 'lib/tensor_stream/tensor.rb', line 45
# Returns the data type symbol; same value as the data_type reader.
# @return [Symbol]
def dtype; @data_type; end
|
#eval(options = {}) ⇒ Object
115
116
117
|
# File 'lib/tensor_stream/tensor.rb', line 115
# Runs this tensor in the default session and returns whatever the session's run returns.
# @param options [Hash] forwarded unchanged to the session's run
def eval(options = {})
  session = Session.default_session
  session.run(self, options)
end
|
#first ⇒ Object
141
142
143
|
# File 'lib/tensor_stream/tensor.rb', line 141
# Builds an :index operation selecting position 0 of self (via op).
def first
  leading_position = 0
  op(:index, self, leading_position)
end
|
#internal? ⇒ Boolean
41
42
43
|
# File 'lib/tensor_stream/tensor.rb', line 41
# Whether the :internal option was truthy at construction.
# @return [Boolean]
def internal?
  @internal ? true : false
end
|
#to_a ⇒ Object
133
134
135
|
# File 'lib/tensor_stream/tensor.rb', line 133
# Returns the stored value unchanged; no conversion to Array is performed.
def to_a; @value; end
|
#to_f ⇒ Object
137
138
139
|
# File 'lib/tensor_stream/tensor.rb', line 137
# Returns the stored value unchanged; no conversion to Float is performed.
def to_f; @value; end
|
#to_h ⇒ Object
119
120
121
122
123
124
125
126
127
|
# File 'lib/tensor_stream/tensor.rb', line 119
# Serialises the tensor's metadata into a Hash.
#
# @return [Hash] keys, in order: :name, :value (serialised via hashify_tensor,
#   defined elsewhere), :dtype, :shape, and :const (the is_const flag coerced to a boolean).
def to_h
  serialized_value = hashify_tensor(@value)
  const_flag = is_const ? true : false
  {
    name: @name,
    value: serialized_value,
    dtype: @data_type,
    shape: @shape,
    const: const_flag
  }
end
|
#to_i ⇒ Object
129
130
131
|
# File 'lib/tensor_stream/tensor.rb', line 129
# Returns the stored value unchanged; no conversion to Integer is performed.
def to_i; @value; end
|
#to_math(name_only = false, max_depth = 99) ⇒ Object
145
146
147
148
149
150
151
152
153
|
# File 'lib/tensor_stream/tensor.rb', line 145
# Produces a readable representation of this tensor for math-style output.
#
# The name is returned when +max_depth+ is zero, when +name_only+ is set, or
# when there is no stored value. An Array value is mapped element-wise, and
# Tensor elements are rendered one level deeper. Any other value is returned
# for constants; non-constants return their name.
#
# @param name_only [Boolean]
# @param max_depth [Integer] remaining recursion depth
# @return [Object]
def to_math(name_only = false, max_depth = 99)
  return @name if max_depth.zero? || name_only || @value.nil?

  case @value
  when Array
    @value.map do |element|
      next element unless element.is_a?(Tensor)

      element.to_math(name_only, max_depth - 1)
    end
  else
    is_const ? @value : @name
  end
end
|
#to_s ⇒ Object
111
112
113
|
# File 'lib/tensor_stream/tensor.rb', line 111
# String form of the tensor: its name, returned as stored.
def to_s; @name; end
|