Class: TensorStream::Freezer

Inherits:
Object
  • Object
show all
Includes:
OpHelper
Defined in:
lib/tensor_stream/utils/freezer.rb

Instance Method Summary

Methods included from OpHelper

#_op, #cons, #format_source, #fp_type?, #i_cons, #i_op, #i_var, #int_type?, #reduced_shape, #shape_eval, #shape_full_specified, #shapes_fully_specified_and_equal

Instance Method Details

#convert(session, checkpoint_folder, output_file) ⇒ Object

Converts the variables of a saved checkpoint into constants and writes the resulting frozen graph to `output_file`, for production deployment.



(Source lines 8–47 of lib/tensor_stream/utils/freezer.rb.)
# File 'lib/tensor_stream/utils/freezer.rb', line 8

# Freezes a saved model: every variable is replaced by a constant holding
# the value restored from the checkpoint, and the result is written as YAML.
#
# Steps:
# 1. Load the graph definition from +model.yaml+ inside +checkpoint_folder+
#    into the default graph.
# 2. Restore the checkpoint into +session+ via TensorStream::Train::Saver.
# 3. Serialize the graph to +output_file+, with these changes:
#    - each +:variable_v2+ node becomes a +:const+ node with the same name,
#      data type and restored value;
#    - +:assign+ nodes, and every node listed as a consumer of an +:assign+
#      node, are mapped to +nil+. Presumably TensorStream::Yaml#get_string
#      omits nodes for which the block returns nil (TODO confirm).
#
# @param session the session the checkpoint is restored into
# @param checkpoint_folder [String] directory containing +model.yaml+ and the
#   checkpoint data read by TensorStream::Train::Saver#restore
# @param output_file [String] path the frozen graph YAML is written to
# @return [Object] whatever TensorStream.graph.as_default returns for the
#   block; presumably not meaningful to callers (TODO confirm)
# @raise [RuntimeError] if a variable has no value after the restore
# @raise [Errno::ENOENT] if +model.yaml+ does not exist in +checkpoint_folder+
def convert(session, checkpoint_folder, output_file)
  model_file = File.join(checkpoint_folder, "model.yaml")
  TensorStream.graph.as_default do |current_graph|
    YamlLoader.new.load_from_string(File.read(model_file))
    TensorStream::Train::Saver.new.restore(session, checkpoint_folder)

    # Once variables are frozen, :assign ops have nothing to write to, so the
    # assigns and everything that consumes them are dropped from the output.
    assign_ops = current_graph.nodes.values.select do |op|
      op.is_a?(TensorStream::Operation) && op.operation == :assign
    end
    remove_nodes = Set.new(assign_ops.flat_map { |op| op.consumers.to_a })

    output_buffer = TensorStream::Yaml.new.get_string(current_graph) do |graph, node_key|
      node = graph.get_tensor_by_name(node_key)
      case node.operation
      when :variable_v2
        var_name = node.options[:var_name]
        value = Evaluator.read_variable(node.graph, var_name)
        raise "#{var_name} has no value" if value.nil?

        # Compute the shape once. The original walked the (possibly large)
        # value twice.
        value_shape = shape_eval(value)
        options = {
          value: value,
          data_type: node.data_type,
          shape: value_shape,
        }
        # The const keeps the variable's name, so nodes that refer to the
        # variable by name presumably resolve to the constant instead.
        const_op = TensorStream::Operation.new(current_graph, inputs: [], options: options)
        const_op.name = node.name
        const_op.operation = :const
        const_op.data_type = node.data_type
        const_op.shape = TensorShape.new(value_shape)
        const_op
      when :assign
        nil
      else
        node unless remove_nodes.include?(node.name)
      end
    end
    File.write(output_file, output_buffer)
  end
end