Module: TensorStream::Train::Utils

Included in:
TensorStream::Trainer
Defined in:
lib/tensor_stream/train/utils.rb

Overview

Convenience methods used for training.

Instance Method Summary

Instance Method Details

#create_global_step(graph = nil) ⇒ Object



5
6
7
8
9
10
11
12
13
14
15
# File 'lib/tensor_stream/train/utils.rb', line 5

# Creates the global step variable through the current variable scope.
#
# The variable is requested under the GraphKeys::GLOBAL_STEP name as a
# non-trainable :int64 variable with shape [] and a zeros initializer, and is
# registered in both the GLOBAL_VARIABLES and GLOBAL_STEP collections.
#
# NOTE(review): +graph+ is only consulted for the "already exists" check; the
# variable itself is created via TensorStream.variable_scope. Confirm that this
# targets the same graph when a non-default graph is passed.
#
# @param graph [Object, nil] graph to check for an existing global step;
#   falls back to TensorStream.get_default_graph when nil
# @return [Object] whatever the variable scope's get_variable returns
# @raise [TensorStream::ValueError] if get_global_step already finds a global
#   step for the checked graph
def create_global_step(graph = nil)
  graph_to_check = graph || TensorStream.get_default_graph
  existing_step = get_global_step(graph_to_check)
  unless existing_step.nil?
    raise TensorStream::ValueError, '"global_step" already exists.'
  end

  scope = TensorStream.variable_scope
  scope.get_variable(
    TensorStream::GraphKeys::GLOBAL_STEP,
    shape: [],
    dtype: :int64,
    initializer: TensorStream.zeros_initializer,
    trainable: false,
    collections: [
      TensorStream::GraphKeys::GLOBAL_VARIABLES,
      TensorStream::GraphKeys::GLOBAL_STEP
    ]
  )
end

#get_global_step(graph = nil) ⇒ Object



17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# File 'lib/tensor_stream/train/utils.rb', line 17

# Looks up the global step tensor for a graph.
#
# Resolution order:
# 1. If the graph's GraphKeys::GLOBAL_STEP collection is nil or empty, try the
#    tensor named "global_step:0"; a TensorStream::KeyError from that lookup
#    yields nil.
# 2. If the collection holds exactly one entry, return it.
# 3. Otherwise log an error and return nil.
#
# @param graph [Object, nil] graph to search; falls back to
#   TensorStream.get_default_graph when nil
# @return [Object, nil] the global step tensor, or nil when none (or more than
#   one) is found
def get_global_step(graph = nil)
  graph_in_use = graph || TensorStream.get_default_graph
  candidates = graph_in_use.get_collection(TensorStream::GraphKeys::GLOBAL_STEP)

  if candidates.nil? || candidates.empty?
    # Nothing registered in the collection: fall back to a lookup by name.
    begin
      return graph_in_use.get_tensor_by_name("global_step:0")
    rescue TensorStream::KeyError
      return nil
    end
  end

  return candidates[0] if candidates.size == 1

  # More than one candidate is ambiguous; report it and give up.
  TensorStream.logger.error("Multiple tensors in global_step collection.")
  nil
end