Class: Tensorflow::Variable

Inherits:
Object
  • Object
show all
Includes:
Operators
Defined in:
lib/tensorflow/variable.rb

Instance Attribute Summary collapse

Instance Method Summary collapse

Methods included from Operators

#%, #*, #**, #+, #-, #-@, #/

Constructor Details

#initialize(initial_value = nil, dtype: nil, shape: nil, shared_name: nil, name: 'Variable', trainable: true) ⇒ Variable

Returns a new instance of Variable.



7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# File 'lib/tensorflow/variable.rb', line 7

# Creates a variable handle, registers the variable with the current
# execution context and, when an initial value is given, assigns it.
#
# @param initial_value [nil, Graph::Operation, Tensor, Object] value assigned
#   once the handle exists. Anything that is not nil, a Graph::Operation or a
#   Tensor is converted with Tensor.from_value. When nil, nothing is assigned.
# @param dtype [Object, nil] requested element type. Used as-is when there is
#   no initial value, preferred over the operation's dtype for operations,
#   forwarded to Tensor.from_value for raw values, and ignored for Tensors
#   (the tensor's own dtype is used).
# @param shape [Object, nil] shape handed to RawOps.var_handle_op. When nil it
#   defaults to the initial value's shape, or to [] when there is no initial
#   value.
# @param shared_name [#to_s, nil] defaults to the generated unique name.
# @param name [#to_s, nil] base for the unique name; shared_name is used as the
#   base when name is nil.
# @param trainable [Boolean] when truthy the variable is also added to the
#   TRAINABLE_VARIABLES collection.
def initialize(initial_value = nil, dtype: nil, shape: nil, shared_name: nil, name: 'Variable', trainable: true)
  # Every branch uses `shape ||=` so an explicit shape: argument is never
  # silently replaced by a default.
  initial_value = case initial_value
                  when NilClass
                    @dtype = dtype
                    shape ||= []
                    initial_value
                  when Graph::Operation
                    @dtype = dtype || initial_value.dtype
                    shape ||= initial_value.output_shapes.first
                    initial_value
                  when Tensor
                    @dtype = initial_value.dtype
                    shape ||= initial_value.shape
                    initial_value
                  else
                    tensor = Tensor.from_value(initial_value, dtype: dtype)
                    @dtype = tensor.dtype
                    shape ||= tensor.shape
                    tensor
                  end

  name = name&.to_s
  shared_name = shared_name&.to_s
  unique_name = ExecutionContext.current.unique_name(name || shared_name)
  shared_name ||= unique_name
  @name = unique_name

  collections = [Graph::GraphKeys::GLOBAL_VARIABLES]
  collections << Graph::GraphKeys::TRAINABLE_VARIABLES if trainable

  ExecutionContext.current.add_to_collections(collections, self)

  @handle = RawOps.var_handle_op(dtype: @dtype, shape: shape, shared_name: shared_name, name: unique_name)
  # Goes through value= so the assignment op is remembered as the initializer.
  self.value = initial_value if initial_value
end

Instance Attribute Details

#dtype ⇒ Object (readonly)

Returns the value of attribute dtype.



5
6
7
# File 'lib/tensorflow/variable.rb', line 5

# Reader for @dtype, which the constructor sets from the explicit dtype:
# argument or from the initial value's own dtype.
#
# @return [Object] the stored dtype (may be nil when none was supplied and
#   there was no initial value)
def dtype
  @dtype
end

#handle ⇒ Object (readonly)

Returns the value of attribute handle.



5
6
7
# File 'lib/tensorflow/variable.rb', line 5

# Reader for @handle, the result of RawOps.var_handle_op stored by the
# constructor.
#
# @return [Object] the variable's resource handle
def handle
  @handle
end

#name ⇒ Object (readonly)

Returns the value of attribute name.



5
6
7
# File 'lib/tensorflow/variable.rb', line 5

# Reader for @name, the unique name the constructor obtained from
# ExecutionContext.current.unique_name.
#
# @return [Object] the variable's unique name
def name
  @name
end

Instance Method Details

#assign_add(value, dtype: nil) ⇒ Object



101
102
103
104
105
106
# File 'lib/tensorflow/variable.rb', line 101

# Adds +value+ to the variable through RawOps.assign_add_variable_op.
#
# @param value [Object] amount to add; converted with Tensor.from_value and
#   then cast to the variable's dtype.
# @param dtype [Object, nil] dtype hint forwarded to Tensor.from_value.
# @return [Object] whatever RawOps.assign_add_variable_op returns.
def assign_add(value, dtype: nil)
  # Drop the memoized read so the next value_handle call reads the variable again.
  @value_handle = nil
  tensor = Tensor.from_value(value, dtype: dtype)
  tensor = Tensorflow.cast(tensor, self.dtype)
  # Hand the op the cast tensor rather than the raw value, so the operand
  # really has the dtype that is declared alongside it.
  RawOps.assign_add_variable_op(self.handle, tensor, dtype: tensor.dtype)
end

#assign_sub(value) ⇒ Object



108
109
110
111
112
113
# File 'lib/tensorflow/variable.rb', line 108

# Subtracts +value+ from the variable through RawOps.assign_sub_variable_op.
#
# @param value [Object] amount to subtract; converted with Tensor.from_value
#   using the variable's dtype and then cast to that dtype.
# @return [Object] whatever RawOps.assign_sub_variable_op returns.
def assign_sub(value)
  # Drop the memoized read so the next value_handle call reads the variable again.
  @value_handle = nil
  # There is no dtype parameter here, so the dtype is the variable's own.
  tensor = Tensor.from_value(value, dtype: self.dtype)
  tensor = Tensorflow.cast(tensor, self.dtype)
  # Hand the op the cast tensor rather than the raw value, so the operand
  # really has the dtype that is declared alongside it.
  RawOps.assign_sub_variable_op(self.handle, tensor, dtype: tensor.dtype)
end

#consumers ⇒ Object

These methods mirror the Operation API so that variables can be used with gradients and sessions.



71
72
73
# File 'lib/tensorflow/variable.rb', line 71

# Forwards to the handle's #consumers so a variable can be used where an
# operation-like object is expected.
#
# @return [Object] whatever the handle reports as its consumers
def consumers
  handle.consumers
end

#initialized? ⇒ Boolean

Returns:

  • (Boolean)


66
67
68
# File 'lib/tensorflow/variable.rb', line 66

# Asks RawOps.var_is_initialized_op about this variable's handle.
#
# @return [Object] the result of the op (documented by YARD as Boolean)
def initialized?
  RawOps.var_is_initialized_op(handle)
end

#initializer ⇒ Object



62
63
64
# File 'lib/tensorflow/variable.rb', line 62

# Reader for @initializer, which value= sets to the result of
# RawOps.assign_variable_op.
#
# @return [Object, nil] the most recent assignment result, or nil if value=
#   has not been called
def initializer
  @initializer
end

#inspect ⇒ Object



119
120
121
122
123
124
125
# File 'lib/tensorflow/variable.rb', line 119

# Builds a short description of the form
# "#<Class name: ..., shape: ..., dtype: ...>". The name part is included only
# when the handle responds to #name; shape and dtype come from value_handle.
#
# @return [String]
def inspect
  target = handle
  parts = []
  parts << "name: #{target.name}" if target.respond_to?(:name)
  parts.push("shape: #{value_handle.shape}", "dtype: #{value_handle.dtype}")
  "#<#{self.class} #{parts.join(", ")}>"
end

#outputs ⇒ Object

This enables executing variables to get the values in a session



76
77
78
# File 'lib/tensorflow/variable.rb', line 76

# Wraps output 0 of the variable's read (value_handle) in a one-element array.
#
# @return [Array] a single Graph::OperationOutput built with from_index
def outputs
  first_output = Graph::OperationOutput.from_index(value_handle, 0)
  [first_output]
end

#rank ⇒ Object



93
94
95
# File 'lib/tensorflow/variable.rb', line 93

# Number of entries in #shape.
#
# @return [Object] the result of calling #size on the shape
def rank
  shape.size
end

#reshape(shape) ⇒ Object



97
98
99
# File 'lib/tensorflow/variable.rb', line 97

# Passes this variable and the requested shape to RawOps.reshape.
#
# @param shape [Object] target shape, forwarded unchanged
# @return [Object] whatever RawOps.reshape returns
def reshape(shape)
  RawOps.reshape(self, shape)
end

#shape ⇒ Object



84
85
86
# File 'lib/tensorflow/variable.rb', line 84

# Shape of the variable as reported by its read handle (value_handle).
#
# @return [Object] value_handle.shape
def shape
  value_handle.shape
end

#tensor ⇒ Object



88
89
90
91
# File 'lib/tensorflow/variable.rb', line 88

# Returns the tensor behind the variable's read handle.
#
# @return [Object] value_handle.tensor
# @raise [Error::UnavailableError] when Tensorflow.execution_mode equals
#   Tensorflow::GRAPH_MODE
def tensor
  if Tensorflow.execution_mode == Tensorflow::GRAPH_MODE
    raise Error::UnavailableError, "Only supported in eager execution mode"
  end

  value_handle.tensor
end

#to_ptr ⇒ Object



80
81
82
# File 'lib/tensorflow/variable.rb', line 80

# Forwards to the handle's #to_ptr.
#
# @return [Object] handle.to_ptr
def to_ptr
  handle.to_ptr
end

#to_s ⇒ Object



115
116
117
# File 'lib/tensorflow/variable.rb', line 115

# String form of the variable; identical to #inspect.
#
# @return [String]
def to_s
  inspect
end

#value ⇒ Object



49
50
51
52
53
54
55
56
# File 'lib/tensorflow/variable.rb', line 49

# Current value of the variable, depending on what value_handle is:
# an Eager::TensorHandle yields its #value, a Graph::Operation is returned
# as-is, and anything else yields nil.
#
# @return [Object, nil]
def value
  current = value_handle
  case current
  when Eager::TensorHandle then current.value
  when Graph::Operation then current
  end
end

#value=(value) ⇒ Object



58
59
60
# File 'lib/tensorflow/variable.rb', line 58

# Assigns +value+ to the variable through RawOps.assign_variable_op and keeps
# the result as the initializer.
#
# @param value [Object] new value, forwarded to the op with the variable's dtype
# @return [Object] the result of RawOps.assign_variable_op
def value=(value)
  # Clear the memoized read, as assign_add/assign_sub do, so a later
  # value_handle call does not return the read made before this assignment.
  @value_handle = nil
  @initializer = RawOps.assign_variable_op(self.handle, value, dtype: @dtype)
end

#value_handle ⇒ Object



45
46
47
# File 'lib/tensorflow/variable.rb', line 45

# Memoized read of the variable: the first call runs RawOps.read_variable_op
# on the handle with the stored dtype, later calls reuse @value_handle until
# it is reset to nil.
#
# @return [Object] the result of RawOps.read_variable_op
def value_handle
  @value_handle ||= RawOps.read_variable_op(handle, dtype: @dtype)
end