Class: TensorStream::Train::Saver
- Inherits: Object
- Class hierarchy: Object > TensorStream::Train::Saver
- Defined in:
- lib/tensor_stream/train/saver.rb
Overview
High-level class for saving the values of the default graph's global variables to a JSON checkpoint file and restoring them from one.
Instance Method Summary
- #restore(_session, inputfile) ⇒ Object
- #save(session, outputfile, global_step: nil, latest_filename: nil, meta_graph_suffix: 'meta', write_meta_graph: true, write_state: true, strip_default_attrs: false) ⇒ Object
Instance Method Details
#restore(_session, inputfile) ⇒ Object
Source (lib/tensor_stream/train/saver.rb, lines 37–44):
# File 'lib/tensor_stream/train/saver.rb', line 37
#
# Loads variable values from a JSON checkpoint (as written by #save) and
# assigns them, matched by variable name, to the variables in the default
# graph's GLOBAL_VARIABLES collection.
#
# Variables that have no entry in the checkpoint keep their current value
# rather than being overwritten with nil.
#
# @param _session unused
# @param inputfile [String] path of the checkpoint file to read. Note that
#   #save appends "-<global step>" to the basename when a step is given, so
#   this must be the full written file name (e.g. "model-100").
# @return [Array] the variables collection that was iterated
# @raise [KeyError] if the checkpoint has no "variables" entry
# @raise [JSON::ParserError] if the file is not valid JSON
def restore(_session, inputfile)
  input_dump = JSON.parse(File.read(inputfile))
  # Fail with a clear error on a malformed checkpoint instead of a NoMethodError on nil.
  saved_values = input_dump.fetch('variables')
  vars = TensorStream::Graph.get_default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES)
  vars.each do |variable|
    # A variable missing from the checkpoint (e.g. added to the graph after
    # saving) must not have its value clobbered with nil.
    next unless saved_values.key?(variable.name)

    variable.value = saved_values[variable.name]
  end
end
#save(session, outputfile, global_step: nil, latest_filename: nil, meta_graph_suffix: 'meta', write_meta_graph: true, write_state: true, strip_default_attrs: false) ⇒ Object
Source (lib/tensor_stream/train/saver.rb, lines 7–35):
# File 'lib/tensor_stream/train/saver.rb', line 7
#
# Serializes the current value of every variable in the default graph's
# GLOBAL_VARIABLES collection to a JSON checkpoint file.
#
# The JSON document has three keys: "variables" (name => value read via
# +read_value+), "graph" (always an empty hash here) and "global_step".
# When the step resolved by +eval_global_step+ is non-nil, "-<step>" is
# appended to the basename of +outputfile+ (e.g. "dir/model" -> "dir/model-100").
#
# @param session passed through to +eval_global_step+
# @param outputfile [String] checkpoint path prefix
# @param global_step passed through to +eval_global_step+
# @param latest_filename accepted but not used by this implementation
# @param meta_graph_suffix accepted but not used by this implementation
# @param write_meta_graph accepted but not used by this implementation
# @param write_state accepted but not used by this implementation
# @param strip_default_attrs accepted but not used by this implementation
# @return [String] the directory the checkpoint was written into.
#   NOTE(review): this is the directory, not the written file's path — confirm
#   whether callers expect the full checkpoint path instead.
def save(session, outputfile, global_step: nil, latest_filename: nil, meta_graph_suffix: 'meta', write_meta_graph: true, write_state: true, strip_default_attrs: false)
  graph_vars = TensorStream::Graph.get_default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES)
  step = eval_global_step(session, global_step)

  dumped_values = graph_vars.each_with_object({}) do |var, acc|
    acc[var.name] = var.read_value
  end
  payload = { variables: dumped_values, graph: {}, global_step: step }

  target_dir = File.dirname(outputfile)
  # compact drops a nil step so no "-" suffix is added in that case.
  file_name = [File.basename(outputfile), step].compact.join('-')
  File.write(File.join(target_dir, file_name), payload.to_json)

  target_dir
end