Class: Tensorflow::Variable
- Inherits: Object
  - Object
    - Tensorflow::Variable
- Includes:
- Operators
- Defined in:
- lib/tensorflow/variable.rb
Instance Attribute Summary
Instance Method Summary
Methods included from Operators
#%, #*, #**, #+, #-, #-@, #/
Constructor Details
#initialize(initial_value = nil, dtype: nil, shape: nil, shared_name: nil, name: 'Variable', trainable: true) ⇒ Variable
Returns a new instance of Variable.
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
|
# File 'lib/tensorflow/variable.rb', line 7
# Creates a resource variable.
#
# The dtype and shape are resolved from +initial_value+:
# - nil: +dtype:+ is used as given; +shape:+ is used as given and only
#   defaults to a scalar shape ([]) when omitted
# - Graph::Operation: explicit +dtype:+ / +shape:+ win, otherwise the op's
#   dtype / first output shape are used
# - Tensor: the tensor's dtype is used; an explicit +shape:+ wins over the
#   tensor's own shape
# - anything else: converted with Tensor.from_value (honouring +dtype:+),
#   and the resulting tensor's dtype and shape are used
#
# The variable is added to the GLOBAL_VARIABLES collection (and to
# TRAINABLE_VARIABLES when +trainable+ is truthy), a VarHandleOp is created,
# and, when an initial value is present, it is assigned through #value=.
#
# @param initial_value [Object, nil] initial contents of the variable
# @param dtype [Object, nil] element type (see resolution rules above)
# @param shape [Object, nil] variable shape (see resolution rules above)
# @param shared_name [String, Symbol, nil] shared name for the handle;
#   defaults to the unique name generated for this variable
# @param name [String, Symbol, nil] base name handed to unique_name
# @param trainable [Boolean] whether to register as a trainable variable
def initialize(initial_value = nil, dtype: nil, shape: nil, shared_name: nil, name: 'Variable', trainable: true)
  initial_value =
    case initial_value
    when NilClass
      @dtype = dtype
      # Keep a caller-supplied shape; previously it was always overwritten
      # with [], so Variable.new(dtype: ..., shape: [n]) produced a scalar.
      shape ||= []
      initial_value
    when Graph::Operation
      @dtype = dtype || initial_value.dtype
      shape ||= initial_value.output_shapes.first
      initial_value
    when Tensor
      @dtype = initial_value.dtype
      shape ||= initial_value.shape
      initial_value
    else
      tensor = Tensor.from_value(initial_value, dtype: dtype)
      @dtype = tensor.dtype
      shape = tensor.shape
      tensor
    end

  name = name&.to_s
  shared_name = shared_name&.to_s
  unique_name = ExecutionContext.current.unique_name(name || shared_name)
  shared_name ||= unique_name
  @name = unique_name

  collections = [Graph::GraphKeys::GLOBAL_VARIABLES]
  collections << Graph::GraphKeys::TRAINABLE_VARIABLES if trainable
  ExecutionContext.current.add_to_collections(collections, self)

  @handle = RawOps.var_handle_op(dtype: @dtype, shape: shape, shared_name: shared_name, name: unique_name)
  self.value = initial_value if initial_value
end
|
Instance Attribute Details
#dtype ⇒ Object
Returns the value of attribute dtype.
5
6
7
|
# File 'lib/tensorflow/variable.rb', line 5
# Element type of the variable, as resolved by the constructor.
#
# @return [Object, nil] the stored dtype (nil if never set)
def dtype; @dtype; end
|
#handle ⇒ Object
Returns the value of attribute handle.
5
6
7
|
# File 'lib/tensorflow/variable.rb', line 5
# Resource handle of the variable, created by the constructor via
# RawOps.var_handle_op.
#
# @return [Object, nil] the stored handle (nil if never set)
def handle; @handle; end
|
#name ⇒ Object
Returns the value of attribute name.
5
6
7
|
# File 'lib/tensorflow/variable.rb', line 5
# Unique name of the variable, generated by the constructor through
# ExecutionContext.current.unique_name.
#
# @return [String, nil] the stored name (nil if never set)
def name; @name; end
|
Instance Method Details
#assign_add(value, dtype: nil) ⇒ Object
101
102
103
104
105
106
|
# File 'lib/tensorflow/variable.rb', line 101
# Adds +value+ to the variable in place.
#
# +value+ is converted with Tensor.from_value (honouring +dtype:+ when
# given) and then cast to the variable's own dtype before the add op is
# issued. The memoized read handle is cleared so the next #value_handle
# call issues a fresh read.
#
# @param value [Object] amount to add
# @param dtype [Object, nil] dtype used when converting +value+ to a tensor
# @return [Object] whatever RawOps.assign_add_variable_op returns
def assign_add(value, dtype: nil)
  @value_handle = nil
  tensor = Tensor.from_value(value, dtype: dtype)
  # `dtype` is the keyword argument here, so the variable's dtype needs `self.`.
  tensor = Tensorflow.cast(tensor, self.dtype)
  # Pass the cast tensor, not the raw argument, so the operand really has
  # the dtype declared for the op.
  RawOps.assign_add_variable_op(handle, tensor, dtype: tensor.dtype)
end
|
#assign_sub(value) ⇒ Object
#consumers ⇒ Object
These methods mirror the operation API so that variables can be used with gradients and sessions.
71
72
73
|
# File 'lib/tensorflow/variable.rb', line 71
# Forwards to the underlying handle's consumers, so a variable can stand
# in for an operation.
#
# @return [Object] the handle's consumers
def consumers
  handle.consumers
end
|
#initializer ⇒ Object
62
63
64
|
# File 'lib/tensorflow/variable.rb', line 62
# Op produced by the most recent #value= assignment
# (RawOps.assign_variable_op).
#
# @return [Object, nil] the assignment op, or nil if nothing was assigned
def initializer; @initializer; end
|
#inspect ⇒ Object
119
120
121
122
123
124
125
|
# File 'lib/tensorflow/variable.rb', line 119
# Human-readable summary of the variable: the handle's name (only when the
# handle responds to #name), followed by the shape and dtype of the read
# handle.
#
# Note: this goes through #value_handle, which memoizes a read op, so
# inspecting a variable may create that read op.
#
# @return [String] e.g. "#<Tensorflow::Variable name: ..., shape: ..., dtype: ...>"
def inspect
  inspection = []
  inspection << "name: #{handle.name}" if handle.respond_to?(:name)
  inspection << "shape: #{value_handle.shape}"
  inspection << "dtype: #{value_handle.dtype}"
  "#<#{self.class} #{inspection.join(', ')}>"
end
|
#outputs ⇒ Object
Enables running a variable in a session to fetch its value.
#rank ⇒ Object
93
94
95
|
# File 'lib/tensorflow/variable.rb', line 93
# Number of dimensions of the variable, i.e. the size of #shape.
#
# @return [Integer]
def rank
  dims = shape
  dims.size
end
|
#reshape(shape) ⇒ Object
97
98
99
|
# File 'lib/tensorflow/variable.rb', line 97
# Builds a Reshape op with this variable as its input.
#
# @param shape [Object] target shape, forwarded unchanged to RawOps.reshape
# @return [Object] whatever RawOps.reshape returns
def reshape(shape)
  source = self
  RawOps.reshape(source, shape)
end
|
#shape ⇒ Object
84
85
86
|
# File 'lib/tensorflow/variable.rb', line 84
# Shape of the variable, taken from the memoized read handle.
#
# @return [Object] value_handle.shape
def shape
  value_handle.shape
end
|
#to_ptr ⇒ Object
80
81
82
|
# File 'lib/tensorflow/variable.rb', line 80
# Native pointer of the underlying resource handle.
#
# @return [Object] handle.to_ptr
def to_ptr
  handle.to_ptr
end
|
#to_s ⇒ Object
115
116
117
|
# File 'lib/tensorflow/variable.rb', line 115
# String form of the variable; identical to #inspect.
#
# @return [String]
def to_s
  self.inspect
end
|
#value ⇒ Object
49
50
51
52
53
54
55
56
|
# File 'lib/tensorflow/variable.rb', line 49
# Current value of the variable.
#
# - When the read handle is an Eager::TensorHandle, its #value is returned.
# - When it is a Graph::Operation, the op itself is returned, so it can be
#   run later, e.g. in a session.
# - Otherwise the result is nil.
#
# @return [Object, nil]
def value
  current = value_handle
  case current
  when Eager::TensorHandle then current.value
  when Graph::Operation then current
  end
end
|
#value=(value) ⇒ Object
58
59
60
|
# File 'lib/tensorflow/variable.rb', line 58
# Assigns +value+ to the variable via RawOps.assign_variable_op and stores
# the resulting op as #initializer.
#
# The memoized read handle is cleared, as #assign_add already does. A
# later #value / #shape / #inspect then issues a fresh read instead of
# reusing a handle cached before this assignment, which may be stale
# (NOTE(review): at least in eager mode).
#
# @param value [Object] new contents of the variable
# @return [Object] the assignment op
def value=(value)
  @value_handle = nil
  @initializer = RawOps.assign_variable_op(handle, value, dtype: @dtype)
end
|
#value_handle ⇒ Object
45
46
47
|
# File 'lib/tensorflow/variable.rb', line 45
# Read handle for the variable, created on first use with
# RawOps.read_variable_op and memoized in @value_handle. The memo is
# reused until something resets @value_handle.
#
# @return [Object] the (possibly cached) read handle
def value_handle
  return @value_handle if @value_handle

  @value_handle = RawOps.read_variable_op(handle, dtype: @dtype)
end
|