Module: TensorStream::Train::Utils

Included in:
TensorStream::Trainer
Defined in:
lib/tensor_stream/train/utils.rb

Overview

Convenience methods used for training.

Instance Method Summary

Instance Method Details

#create_global_step(graph = nil) ⇒ Object



(source lines 5–16)
# File 'lib/tensor_stream/train/utils.rb', line 5

# Creates the global step variable for +graph+ (or the default graph).
#
# The variable is requested from TensorStream.variable_scope under the name
# TensorStream::GraphKeys::GLOBAL_STEP with shape [], dtype :int64, a
# TensorStream.zeros_initializer, trainable: false, and membership in both the
# GLOBAL_VARIABLES and GLOBAL_STEP collections.
#
# @param graph [Object, nil] graph checked for an existing global step;
#   falls back to TensorStream.get_default_graph when nil
# @return [Object] the result of TensorStream.variable_scope.get_variable
# @raise [TensorStream::ValueError] if get_global_step already finds a global
#   step in the target graph
#
# NOTE(review): +graph+ is only used for the existence check. The variable is
# created through TensorStream.variable_scope, which is not handed +graph+.
# Confirm it lands in +graph+ when a non-default graph is passed.
# NOTE(review): get_global_step returns nil when the GLOBAL_STEP collection
# holds more than one tensor, so this guard does not fire in that case.
def create_global_step(graph = nil)
  graph_for_step = graph || TensorStream.get_default_graph
  already_present = get_global_step(graph_for_step)
  raise TensorStream::ValueError, '"global_step" already exists.' unless already_present.nil?

  step_name = TensorStream::GraphKeys::GLOBAL_STEP
  step_collections = [TensorStream::GraphKeys::GLOBAL_VARIABLES, TensorStream::GraphKeys::GLOBAL_STEP]

  TensorStream.variable_scope.get_variable(
    step_name,
    shape: [],
    dtype: :int64,
    initializer: TensorStream.zeros_initializer,
    trainable: false,
    collections: step_collections
  )
end

#get_global_step(graph = nil) ⇒ Object



(source lines 18–34)
# File 'lib/tensor_stream/train/utils.rb', line 18

# Looks up the global step tensor for +graph+ (or the default graph).
#
# Resolution order:
# 1. If the GLOBAL_STEP collection is nil or empty, look up the tensor named
#    'global_step:0'. If that lookup raises TensorStream::KeyError, return nil.
# 2. If the collection holds exactly one entry, return that entry.
# 3. Otherwise (several entries), log an error via TensorStream.logger and
#    return nil.
#
# @param graph [Object, nil] graph to search; falls back to
#   TensorStream.get_default_graph when nil
# @return [Object, nil] the global step tensor, or nil if none could be
#   determined unambiguously
def get_global_step(graph = nil)
  lookup_graph = graph || TensorStream.get_default_graph
  candidates = lookup_graph.get_collection(TensorStream::GraphKeys::GLOBAL_STEP)
  nothing_registered = candidates.nil? || candidates.empty?

  if nothing_registered
    # Nothing in the collection: fall back to the conventional tensor name.
    begin
      lookup_graph.get_tensor_by_name('global_step:0')
    rescue TensorStream::KeyError
      nil
    end
  elsif candidates.size == 1
    candidates[0]
  else
    # Ambiguous: more than one tensor claims to be the global step.
    TensorStream.logger.error("Multiple tensors in global_step collection.")
    nil
  end
end