Class: TensorStream::Train::Saver
- Inherits: Object (hierarchy: Object > TensorStream::Train::Saver)
- Includes: OpHelper
- Defined in: lib/tensor_stream/train/saver.rb
Overview
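High-level class for saving variables to a file and loading them back. Both methods operate on the GLOBAL_VARIABLES collection of the default graph and use a JSON file format.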
Instance Method Summary
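- #restore(_session, inputfile) ⇒ Object
  Restores the default graph's global variables from a JSON file written by #save.
- #save(session, outputfile, global_step: nil, latest_filename: nil, meta_graph_suffix: 'meta', write_meta_graph: true, write_state: true, strip_default_attrs: false) ⇒ Object
  Writes the default graph's global variables to a JSON file, optionally suffixed with the global step.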
Methods included from OpHelper
#_op, #cons, #format_source, #fp_type?, #i_cons, #i_op, #int_type?, #reduced_shape, #shape_eval, #shape_full_specified, #shapes_fully_specified_and_equal
Instance Method Details
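#restore(_session, inputfile) ⇒ Object
Reads the JSON file at inputfile and assigns each stored value to the global variable of the same name in the default graph. The session argument is ignored. A variable whose name does not appear in the file is set to nil. Because #save may append the global step to the file name, inputfile must be the name that was actually written (for example model.ckpt-100 rather than model.ckpt).

# File 'lib/tensor_stream/train/saver.rb', line 39

def restore(_session, inputfile)
  input_dump = JSON.parse(File.read(inputfile))
  vars = TensorStream::Graph.get_default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES)
  vars.each do |variable|
    # Look each variable up by name; a missing key yields nil.
    variable.value = input_dump['variables'][variable.name]
  end
end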
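The restore logic can be tried without TensorStream using a small standalone sketch. It assumes a hypothetical FakeVariable struct in place of a graph variable and a hypothetical restore_variables helper that receives the variable list directly instead of reading GraphKeys::GLOBAL_VARIABLES from the default graph. It illustrates the lookup-by-name behaviour described above and is not part of the library's API.

```ruby
require 'json'
require 'tempfile'

# Hypothetical stand-in for a TensorStream variable: only the two
# attributes that Saver#restore touches (name and a writable value).
FakeVariable = Struct.new(:name, :value)

# Mirrors the body of Saver#restore, but takes the variables as an
# argument instead of reading the default graph's GLOBAL_VARIABLES.
def restore_variables(variables, inputfile)
  input_dump = JSON.parse(File.read(inputfile))
  variables.each do |variable|
    # Hash#[] returns nil for a name absent from the file, so such a
    # variable is silently overwritten with nil, as in the original.
    variable.value = input_dump['variables'][variable.name]
  end
end
```

A variable that is missing from the checkpoint does not raise an error; it simply ends up as nil, which is worth checking for after a restore.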
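#save(session, outputfile, global_step: nil, latest_filename: nil, meta_graph_suffix: 'meta', write_meta_graph: true, write_state: true, strip_default_attrs: false) ⇒ Object
Writes the current value of every global variable in the default graph to a JSON file with the keys variables, graph (always an empty hash in this version) and global_step. The step value comes from eval_global_step(session, global_step), a helper not shown on this page; when it is non-nil it is appended to the file name with a hyphen. The keyword arguments latest_filename, meta_graph_suffix, write_meta_graph, write_state and strip_default_attrs are accepted but not used. The return value is the directory of outputfile, not the path of the file that was written.

# File 'lib/tensor_stream/train/saver.rb', line 9

def save(session, outputfile, global_step: nil, latest_filename: nil, meta_graph_suffix: 'meta',
         write_meta_graph: true, write_state: true, strip_default_attrs: false)
  vars = TensorStream::Graph.get_default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES)

  variables = {}
  graph = {}
  gs = eval_global_step(session, global_step)
  # output_dump holds a reference to variables, which is filled in below.
  output_dump = {
    variables: variables,
    graph: graph,
    global_step: gs
  }

  vars.each do |variable|
    variables[variable.name] = variable.read_value
  end

  basename = File.basename(outputfile)
  path = File.dirname(outputfile)

  # compact drops a nil step, so the suffix is only added when gs is set.
  new_filename = File.join(path, [basename, gs].compact.join('-'))
  File.write(new_filename, output_dump.to_json)

  # Returns the directory, not new_filename.
  path
end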
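The file naming and dump layout used by #save can be reproduced in a standalone sketch. It assumes a hypothetical SavedVariable struct whose value attribute plays the role of read_value, a hypothetical checkpoint_filename helper that copies the File.basename / File.dirname / compact.join('-') logic, and a hypothetical save_variables helper that takes the step directly instead of calling eval_global_step. Unlike the original method, save_variables returns the full file name so the result is easy to check.

```ruby
require 'json'

# Hypothetical stand-in for a graph variable; value plays the role of
# the read_value call made by Saver#save.
SavedVariable = Struct.new(:name, :value)

# Builds the file name the way Saver#save does: a non-nil step is joined
# to the base name with '-', and compact drops a nil step entirely.
def checkpoint_filename(outputfile, step)
  basename = File.basename(outputfile)
  path = File.dirname(outputfile)
  File.join(path, [basename, step].compact.join('-'))
end

# Writes the same JSON layout as Saver#save: a name => value map, an
# empty graph hash and the global step. Returns the written file name
# (the original returns the directory instead).
def save_variables(variables, outputfile, step = nil)
  variables_hash = variables.to_h { |v| [v.name, v.value] }
  output_dump = { variables: variables_hash, graph: {}, global_step: step }
  new_filename = checkpoint_filename(outputfile, step)
  File.write(new_filename, output_dump.to_json)
  new_filename
end
```

Note that a bare relative name such as model.ckpt has the directory ".", so with step 3 the file written is ./model.ckpt-3; that full name is what #restore must later be given.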