Module: TensorStream::Utils

Included in:
TensorStream
Defined in:
lib/tensor_stream/utils.rb

Instance Method Summary collapse

Instance Method Details

#__v_scope_name ⇒ Object



107
108
109
# File 'lib/tensor_stream/utils.rb', line 107

# Builds the slash-separated name of the current variable-scope stack.
#
# Reads the stack stored in Thread.current[:tensor_stream_variable_scope],
# collects each entry's +name+, drops nil and empty names and joins the
# remainder with '/'.
#
# @return [String] the joined scope name; '' when the thread has no stack yet
def __v_scope_name
  # The thread-local may not have been initialised on this thread; treat a
  # missing stack as empty instead of calling #map on nil.
  scopes = Thread.current[:tensor_stream_variable_scope] || []
  scopes.map(&:name).compact.reject(&:empty?).join('/')
end

#assign(ref, value, name: nil) ⇒ Object



164
165
166
167
# File 'lib/tensor_stream/utils.rb', line 164

# Assigns +value+ to the variable +ref+ by delegating to ref.assign.
#
# @param ref the target of the assignment; must be a Variable
# @param value the value handed to ref.assign
# @param name [String, nil] forwarded to ref.assign as the +name:+ keyword
# @return whatever ref.assign returns
# @raise [RuntimeError] when +ref+ is not a Variable
def assign(ref, value, name: nil)
  unless ref.is_a?(Variable)
    # Arbitrary objects (numbers, arrays, ...) need not respond to #name;
    # fall back to #inspect so the intended error is raised rather than a
    # NoMethodError from building the message.
    label = ref.respond_to?(:name) ? ref.name : ref.inspect
    raise "#{label} not a variable"
  end

  ref.assign(value, name: name)
end

#check_allowed_types(input, types) ⇒ Object



205
206
207
208
209
210
# File 'lib/tensor_stream/utils.rb', line 205

# Validates that a tensor's data type is one of +types+.
#
# Non-Tensor inputs and tensors whose data_type is nil are passed through
# without any check.
#
# @param input the value to validate
# @param types [Array<Symbol>] the permitted data types
# @return +input+ itself, on every non-raising path
# @raise [RuntimeError] when +input+ is a Tensor whose data type is not in +types+
def check_allowed_types(input, types)
  return input unless input.is_a?(Tensor)
  return input if input.data_type.nil?

  unless types.include?(input.data_type.to_sym)
    raise "#{input.source}: Parameter data type #{input.data_type} passed not in #{types.join(',')}"
  end

  # Return the input on the validated path too, matching the early exits above.
  input
end

#check_data_types(input_a, input_b) ⇒ Object



212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
# File 'lib/tensor_stream/utils.rb', line 212

# Converts both operands to tensors and verifies their data types agree.
#
# When exactly one operand is already a Tensor, the other is converted using
# that tensor's data type; otherwise both go through convert_to_tensor with
# no explicit dtype. Data types are compared after norm_dtype.
#
# @param input_a first operand
# @param input_b second operand
# @return [Array] the two-element array [input_a, input_b] after conversion
# @raise [TensorStream::ValueError] when the normalised data types differ
def check_data_types(input_a, input_b)
  a_is_tensor = input_a.is_a?(Tensor)
  b_is_tensor = input_b.is_a?(Tensor)

  if b_is_tensor && !a_is_tensor
    input_a = convert_to_tensor(input_a, dtype: input_b.data_type)
  elsif a_is_tensor && !b_is_tensor
    input_b = convert_to_tensor(input_b, dtype: input_a.data_type)
  else
    input_a = convert_to_tensor(input_a)
    input_b = convert_to_tensor(input_b)
  end

  return [input_a, input_b] if norm_dtype(input_a.data_type) == norm_dtype(input_b.data_type)

  raise TensorStream::ValueError, "Value Error: Tensor conversion requested dtype #{input_a.data_type} for tensor type #{input_b.data_type}"
end

#constant(value, dtype: nil, shape: nil, internal: false, name: 'Const') ⇒ Object



126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# File 'lib/tensor_stream/utils.rb', line 126

# Builds a constant TensorStream::Tensor from a Ruby value.
#
# Float, Integer, String and true/false values become rank-0 tensors whose
# default dtypes are :float32, :int32, :string and :boolean respectively.
# Arrays take their dimensions from +shape+ or shape_eval(value); their dtype
# is +dtype+ or the detected type of the last flattened element, and the
# elements are cast only when +dtype+ was given explicitly.
#
# @param value the Ruby value to wrap
# @param dtype [Symbol, nil] explicit data type overriding the default
# @param shape [Array, nil] explicit shape overriding the default
# @param internal [Boolean] stored in the options hash under :internal
# @param name [String] stored in the options hash under :name
# @return [TensorStream::Tensor, nil] nil when +value+ is none of the handled types
def constant(value, dtype: nil, shape: nil, internal: false, name: 'Const')
  options = { const: true, value: value, name: name, internal: internal }

  scalar_default =
    case value
    when Float then :float32
    when Integer then :int32
    when String then :string
    when true, false then :boolean
    end
  return TensorStream::Tensor.new(dtype || scalar_default, 0, shape || [], options) if scalar_default

  return unless value.is_a?(Array)

  dims = shape || shape_eval(value)
  rank = dims.size
  element_dtype = dtype || Tensor.detect_type(value.flatten.last)
  # Elements are only cast when the caller asked for a specific dtype.
  options[:value] = dtype.nil? ? value : Tensor.cast_dtype(value, element_dtype)

  TensorStream::Tensor.new(element_dtype, rank, dims, options)
end

#control_dependencies(control_inputs, &block) ⇒ Object



191
192
193
# File 'lib/tensor_stream/utils.rb', line 191

# Forwards +control_inputs+ and the given block to the default graph's
# control_dependencies.
#
# @param control_inputs passed through unchanged
# @return whatever the graph's control_dependencies returns
def control_dependencies(control_inputs, &block)
  default_graph = TensorStream.get_default_graph
  default_graph.control_dependencies(control_inputs, &block)
end

#convert_to_tensor(value, dtype: nil, name: nil, preferred_dtype: nil) ⇒ Object



195
196
197
198
199
200
201
202
203
# File 'lib/tensor_stream/utils.rb', line 195

# Converts +value+ into a tensor.
#
# A Proc is called and its result converted with the same options. A value
# that is already a Tensor is returned unchanged. Anything else is wrapped
# via i_cons, using +dtype+ or, when that is nil, Tensor.detect_type(value).
#
# @param value the value, Tensor or Proc to convert
# @param dtype [Symbol, nil] explicit data type for newly created tensors
# @param name [String, nil] name handed to i_cons
# @param preferred_dtype accepted and forwarded on recursion, but not
#   otherwise used by this method
# @return the resulting tensor
def convert_to_tensor(value, dtype: nil, name: nil, preferred_dtype: nil)
  if value.is_a?(Proc)
    # Forward the caller's options; dropping them here would make a lazily
    # supplied value ignore the requested dtype and name.
    return convert_to_tensor(value.call, dtype: dtype, name: name, preferred_dtype: preferred_dtype)
  end

  return value if value.is_a?(Tensor)

  i_cons(value, dtype: dtype || Tensor.detect_type(value), name: name)
end

#device(device_uri, &block) ⇒ Object



87
88
89
# File 'lib/tensor_stream/utils.rb', line 87

# Forwards +device_uri+ and the given block to the default graph's +device+.
#
# @param device_uri passed through unchanged
# @return whatever the graph's +device+ returns
def device(device_uri, &block)
  default_graph = get_default_graph
  default_graph.device(device_uri, &block)
end

#disable_eager_execution ⇒ Object



23
24
25
# File 'lib/tensor_stream/utils.rb', line 23

# Calls disable_eager_execution on the default graph.
#
# @return whatever the graph's disable_eager_execution returns
def disable_eager_execution
  default_graph = TensorStream::Graph.get_default_graph
  default_graph.disable_eager_execution
end

#dynamic_stitch(indices, data, name: nil) ⇒ Object



152
153
154
# File 'lib/tensor_stream/utils.rb', line 152

# Builds a TensorStream::DynamicStitch for the :dynamic_stitch operation.
#
# @param indices first element of the two-element input array
# @param data second element of the two-element input array
# @param name [String, nil] passed to the constructor as +name:+
# @return [TensorStream::DynamicStitch] the newly constructed object
def dynamic_stitch(indices, data, name: nil)
  TensorStream::DynamicStitch.new(:dynamic_stitch, [indices, data], name: name)
end

#enable_eager_execution ⇒ Object



19
20
21
# File 'lib/tensor_stream/utils.rb', line 19

# Calls enable_eager_execution on the default graph.
#
# @return whatever the graph's enable_eager_execution returns
def enable_eager_execution
  default_graph = TensorStream::Graph.get_default_graph
  default_graph.enable_eager_execution
end

#executing_eagerly? ⇒ Boolean

Returns:

  • (Boolean)


27
28
29
# File 'lib/tensor_stream/utils.rb', line 27

# Asks the default graph whether it is executing eagerly.
#
# @return whatever the graph's executing_eagerly? returns
def executing_eagerly?
  default_graph = TensorStream::Graph.get_default_graph
  default_graph.executing_eagerly?
end

#float32 ⇒ Object



3
4
5
# File 'lib/tensor_stream/utils.rb', line 3

# Shortcut that returns the result of Types.float32.
#
# @return whatever Types.float32 returns
def float32
  Types.float32
end

#get_collection(name, options = {}) ⇒ Object



160
161
162
# File 'lib/tensor_stream/utils.rb', line 160

# Looks up a collection on the default graph.
#
# @param name the collection key, passed through unchanged
# @param options [Hash] passed through as the second positional argument
# @return whatever the graph's get_collection returns
def get_collection(name, options = {})
  default_graph = Graph.get_default_graph
  default_graph.get_collection(name, options)
end

#get_default_graph ⇒ Object



11
12
13
# File 'lib/tensor_stream/utils.rb', line 11

# Returns the result of TensorStream::Graph.get_default_graph.
#
# @return whatever TensorStream::Graph.get_default_graph returns
def get_default_graph
  TensorStream::Graph.get_default_graph
end

#get_variable(name, dtype: nil, shape: nil, initializer: nil, trainable: true, collections: nil) ⇒ Object



156
157
158
# File 'lib/tensor_stream/utils.rb', line 156

# Delegates to get_variable on the scope returned by get_variable_scope,
# forwarding every keyword argument unchanged.
#
# @param name the variable name
# @param dtype forwarded as +dtype:+
# @param shape forwarded as +shape:+
# @param initializer forwarded as +initializer:+
# @param trainable [Boolean] forwarded as +trainable:+
# @param collections forwarded as +collections:+
# @return whatever the scope's get_variable returns
def get_variable(name, dtype: nil, shape: nil, initializer: nil, trainable: true, collections: nil)
  current_scope = get_variable_scope
  current_scope.get_variable(
    name,
    dtype: dtype,
    shape: shape,
    initializer: initializer,
    trainable: trainable,
    collections: collections
  )
end

#get_variable_scope ⇒ Object



102
103
104
105
# File 'lib/tensor_stream/utils.rb', line 102

# Returns the innermost variable scope for the current thread.
#
# Reads the stack stored in Thread.current[:tensor_stream_variable_scope]
# and returns its last entry. A new VariableScope is returned when the stack
# is missing or its last entry is nil (for example when it is empty).
#
# @return the last scope on the stack, or a fresh VariableScope
def get_variable_scope
  stack = Thread.current[:tensor_stream_variable_scope]
  innermost = stack ? stack.last : nil
  innermost || VariableScope.new
end

#global_variables_initializer ⇒ Object



175
176
177
# File 'lib/tensor_stream/utils.rb', line 175

# Returns the result of TensorStream::Variable.global_variables_initializer.
#
# @return whatever TensorStream::Variable.global_variables_initializer returns
def global_variables_initializer
  TensorStream::Variable.global_variables_initializer
end

#graph ⇒ Object



7
8
9
# File 'lib/tensor_stream/utils.rb', line 7

# Creates a new TensorStream::Graph. Every call constructs a fresh instance;
# this does not return the default graph.
#
# @return [TensorStream::Graph] the newly constructed graph
def graph
  TensorStream::Graph.new
end

#group(inputs, name: nil) ⇒ Object



148
149
150
# File 'lib/tensor_stream/utils.rb', line 148

# Builds a TensorStream::ControlFlow for the :group operation.
#
# @param inputs passed as the second constructor argument
# @param name [String, nil] passed to the constructor as +name:+
# @return [TensorStream::ControlFlow] the newly constructed object
def group(inputs, name: nil)
  # The third positional argument is always nil here.
  TensorStream::ControlFlow.new(:group, inputs, nil, name: name)
end

#layers ⇒ Object



122
123
124
# File 'lib/tensor_stream/utils.rb', line 122

# Shortcut that returns the TensorStream::Layers constant.
#
# @return the TensorStream::Layers constant
def layers
  TensorStream::Layers
end

#list_local_devices ⇒ Object

Lists the available evaluators and devices in the current local environment. Returns:

  • An array containing the names of those devices



35
36
37
38
39
40
41
42
# File 'lib/tensor_stream/utils.rb', line 35

# Lists the devices reported by every registered evaluator.
#
# For each entry of TensorStream::Evaluator.evaluators, the entry's :class is
# asked for query_supported_devices, and each device is rendered as
# "job:localhost/ts:<evaluator key>:<device name>".
#
# @return [Array<String>] a flat array of device names
def list_local_devices
  prefix = 'job:localhost'
  TensorStream::Evaluator.evaluators.flat_map do |evaluator_key, entry|
    entry[:class].query_supported_devices.map do |device|
      "#{prefix}/ts:#{evaluator_key}:#{device.name}"
    end
  end
end

#name_scope(name, default: nil, values: nil) ⇒ Object



91
92
93
94
95
96
97
98
99
100
# File 'lib/tensor_stream/utils.rb', line 91

# Opens a name scope on the default graph and yields it to the given block.
#
# When +values+ is given, every Tensor among them must belong to the same
# graph (graphs are compared by object_id); non-Tensor entries are ignored.
#
# @param name the scope name; +default+ is used when this is nil or false
# @param default fallback scope name
# @param values [Enumerable, nil] optional values whose graphs are checked
# @yield [scope] the scope yielded by the graph's name_scope, when a block is given
# @return whatever the graph's name_scope returns
# @raise [RuntimeError] when the tensors in +values+ span more than one graph
def name_scope(name, default: nil, values: nil)
  if values
    tensors = values.select { |candidate| candidate.is_a?(Tensor) }
    distinct_graphs = tensors.map { |tensor| tensor.graph.object_id }.uniq
    raise "values are not on the same graph" if distinct_graphs.size > 1
  end

  default_graph = get_default_graph
  default_graph.name_scope(name || default) { |scope| yield scope if block_given? }
end

#norm_dtype(dtype) ⇒ Object



229
230
231
232
233
234
235
236
237
238
239
# File 'lib/tensor_stream/utils.rb', line 229

# Normalises a data type name to a symbol, expanding the short aliases
# :int to :int32 and :float to :float32. Any other type is returned as a
# symbol unchanged.
#
# @param dtype [Symbol, String] anything responding to to_sym
# @return [Symbol] the normalised data type
def norm_dtype(dtype)
  key = dtype.to_sym
  aliases = { int: :int32, float: :float32 }
  aliases.fetch(key, key)
end

#placeholder(dtype, shape: nil, name: nil) ⇒ Object

Inserts a placeholder for a tensor that will always be fed.



171
172
173
# File 'lib/tensor_stream/utils.rb', line 171

# Builds a TensorStream::Placeholder.
#
# @param dtype passed as the first constructor argument
# @param shape passed as the third constructor argument
# @param name [String, nil] passed to the constructor as +name:+
# @return [TensorStream::Placeholder] the newly constructed placeholder
def placeholder(dtype, shape: nil, name: nil)
  # The second positional argument is always nil here.
  TensorStream::Placeholder.new(dtype, nil, shape, name: name)
end

#program(&block) ⇒ Object



118
119
120
# File 'lib/tensor_stream/utils.rb', line 118

# Calls the given block with +self+ as its only argument.
#
# NOTE(review): +block+ is nil when no block is passed, so calling this
# without a block raises NoMethodError — confirm a block is always expected.
#
# @yield [self] the receiver of this method
# @return whatever the block returns
def program(&block)
  block.call(self)
end

#reset_default_graph ⇒ Object



15
16
17
# File 'lib/tensor_stream/utils.rb', line 15

# Calls +reset+ on the default graph.
#
# @return whatever the graph's +reset+ returns
def reset_default_graph
  default_graph = TensorStream::Graph.get_default_graph
  default_graph.reset
end

#session(evaluator = nil, thread_pool_class: Concurrent::ImmediateExecutor, log_device_placement: false) {|session| ... } ⇒ Object

Yields:



111
112
113
114
115
116
# File 'lib/tensor_stream/utils.rb', line 111

# Creates a TensorStream::Session, yields it to the block when one is given,
# and returns it.
#
# @param evaluator passed as the first constructor argument
# @param thread_pool_class forwarded as +thread_pool_class:+
# @param log_device_placement [Boolean] forwarded as +log_device_placement:+
# @yield [session] the newly created session
# @return [TensorStream::Session] the session, regardless of the block's result
def session(evaluator = nil, thread_pool_class: Concurrent::ImmediateExecutor, log_device_placement: false)
  TensorStream::Session.new(
    evaluator,
    thread_pool_class: thread_pool_class,
    log_device_placement: log_device_placement
  ).tap { |new_session| yield new_session if block_given? }
end

#set_random_seed(seed) ⇒ Object



187
188
189
# File 'lib/tensor_stream/utils.rb', line 187

# Stores +seed+ as the default graph's random_seed.
#
# @param seed the seed value to assign
# @return +seed+ (the value of the assignment expression)
def set_random_seed(seed)
  default_graph = TensorStream.get_default_graph
  default_graph.random_seed = seed
end

#train ⇒ Object



179
180
181
# File 'lib/tensor_stream/utils.rb', line 179

# Shortcut that returns the TensorStream::Trainer constant.
#
# @return the TensorStream::Trainer constant
def train
  TensorStream::Trainer
end

#trainable_variables ⇒ Object



183
184
185
# File 'lib/tensor_stream/utils.rb', line 183

# Fetches the default graph's collection stored under
# TensorStream::GraphKeys::TRAINABLE_VARIABLES.
#
# @return whatever the graph's get_collection returns for that key
def trainable_variables
  default_graph = TensorStream.get_default_graph
  default_graph.get_collection(TensorStream::GraphKeys::TRAINABLE_VARIABLES)
end

#variable(value, name: nil, initializer: nil, graph: nil, dtype: nil, trainable: true) ⇒ Object

Creates a variable. A variable maintains state across sessions.



47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# File 'lib/tensor_stream/utils.rb', line 47

# Creates a TensorStream::Variable initialised from +value+.
#
# An :assign Operation is built for +value+ and used as the initializer
# unless +initializer+ is given; its first input is set to the new variable
# once the variable exists. String, Integer and Float values default to
# dtypes :string, :int32 and :float32 with an empty shape; any other value
# defaults to :float32 with a nil shape.
#
# @param value the initial value
# @param name [String, nil] stored in the options hash under :name
# @param initializer overrides the generated assign operation when given
# @param graph stored in the options hash under :graph
# @param dtype [Symbol, nil] explicit data type overriding the default
# @param trainable [Boolean] stored in the options hash under :trainable
# @return [TensorStream::Variable] the new variable
def variable(value, name: nil, initializer: nil, graph: nil, dtype: nil, trainable: true)
  assign_op = Operation.new(:assign, nil, value)
  options = {
    initializer: initializer || assign_op,
    name: name,
    graph: graph,
    dtype: dtype,
    trainable: trainable
  }

  fallback_dtype, shape =
    case value
    when String then [:string, []]
    when Integer then [:int32, []]
    when Float then [:float32, []]
    else [:float32, nil]
    end

  TensorStream::Variable.new(dtype || fallback_dtype, 0, shape, get_variable_scope, options).tap do |var|
    # The assign operation was created before the variable existed; point its
    # first input at the variable now.
    assign_op.inputs[0] = var
  end
end

#variable_scope(scope = nil, reuse: nil, initializer: nil) ⇒ Object



69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# File 'lib/tensor_stream/utils.rb', line 69

# Pushes a new VariableScope onto the current thread's scope stack.
#
# With a block: opens a matching name scope on the default graph, yields the
# joined scope name (from __v_scope_name) and pops the scope again afterwards,
# even if the block raises. Without a block: returns the new VariableScope
# and leaves it on the stack.
#
# @param scope [String, nil] name of the new scope
# @param reuse forwarded to VariableScope.new as +reuse:+
# @param initializer forwarded to VariableScope.new as +initializer:+
# @yield [scope_name] the slash-joined name of the active scope stack
# @return the result of the graph's name_scope call when a block is given,
#   otherwise the new VariableScope
def variable_scope(scope = nil, reuse: nil, initializer: nil)
  stack = (Thread.current[:tensor_stream_variable_scope] ||= [])
  new_scope = VariableScope.new(name: scope, reuse: reuse, initializer: initializer)
  stack.push(new_scope)
  full_name = __v_scope_name
  return new_scope unless block_given?

  begin
    TensorStream.get_default_graph.name_scope(scope) { yield(full_name) }
  ensure
    Thread.current[:tensor_stream_variable_scope].pop
  end
end