Class: TensorStream::Train::Saver

Inherits:
Object
Defined in:
lib/tensor_stream/train/saver.rb

Overview

High-level class used for saving and loading variables.

Instance Method Summary

Instance Method Details

#restore(_session, inputfile) ⇒ Object



(source lines 37–44)
# File 'lib/tensor_stream/train/saver.rb', line 37

# Restores the values of all global variables in the default graph from a
# JSON checkpoint file (the format written by #save).
#
# Variables that have no entry in the checkpoint keep their current value
# instead of being overwritten with nil. A checkpoint without a "variables"
# section therefore restores nothing rather than raising NoMethodError.
#
# @param _session [Object] unused; kept for API symmetry with #save
# @param inputfile [String] path of the checkpoint file to read. This must be
#   the exact file name that #save wrote, which includes a "-<global_step>"
#   suffix when a global step was available.
# @return [Object] the global-variable collection that was iterated
def restore(_session, inputfile)
  input_dump = JSON.parse(File.read(inputfile))
  # Tolerate a missing/null "variables" section instead of calling [] on nil.
  saved_values = input_dump['variables'] || {}

  vars = TensorStream::Graph.get_default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES)
  vars.each do |variable|
    # Only overwrite variables actually present in the checkpoint; a missing
    # key would otherwise assign nil and clobber the initialized value.
    next unless saved_values.key?(variable.name)

    variable.value = saved_values[variable.name]
  end
end

#save(session, outputfile, global_step: nil, latest_filename: nil, meta_graph_suffix: 'meta', write_meta_graph: true, write_state: true, strip_default_attrs: false) ⇒ Object



(source lines 7–35)
# File 'lib/tensor_stream/train/saver.rb', line 7

# Serializes every global variable of the default graph into a JSON checkpoint.
#
# The file is written into the directory of +outputfile+. Its name is the
# base name of +outputfile+, followed by "-<step>" when the evaluated global
# step is non-nil. The JSON document has the keys "variables" (variable name
# => value from #read_value), "graph" (currently always empty) and
# "global_step".
#
# @param session [Object] passed to eval_global_step (defined elsewhere) together
#   with +global_step+ to obtain the step value used in the file name
# @param outputfile [String] target path (directory plus base file name)
# @param global_step [Object, nil] forwarded to eval_global_step
# @param latest_filename [Object] accepted for interface compatibility; unused here
# @param meta_graph_suffix [String] accepted for interface compatibility; unused here
# @param write_meta_graph [Boolean] accepted for interface compatibility; unused here
# @param write_state [Boolean] accepted for interface compatibility; unused here
# @param strip_default_attrs [Boolean] accepted for interface compatibility; unused here
# @return [String] the directory portion of +outputfile+ (not the written file name)
def save(session, outputfile, global_step: nil,
         latest_filename: nil,
         meta_graph_suffix: 'meta',
         write_meta_graph: true,
         write_state: true,
         strip_default_attrs: false)
  global_vars = TensorStream::Graph.get_default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES)
  step = eval_global_step(session, global_step)

  saved_values = global_vars.each_with_object({}) do |var, acc|
    acc[var.name] = var.read_value
  end

  checkpoint = {
    variables: saved_values,
    graph: {},
    global_step: step
  }

  directory = File.dirname(outputfile)
  # A nil step is dropped by #compact, so no "-" suffix is appended.
  file_name = [File.basename(outputfile), step].compact.join('-')
  File.write(File.join(directory, file_name), checkpoint.to_json)

  directory
end