Class: TensorStream::Pbtext

Inherits:
Serializer
Includes:
OpHelper, StringHelper
Defined in:
lib/tensor_stream/graph_serializers/pbtext.rb

Overview

Serializes a graph into pbtext (protobuf text format).

Instance Method Summary

Methods included from OpHelper

#_op, #cons, #format_source, #fp_type?, #i_cons, #i_op, #i_var, #int_type?, #reduced_shape, #shape_eval, #shape_full_specified, #shapes_fully_specified_and_equal

Methods included from StringHelper

#camelize, #constantize, #symbolize_keys, #underscore

Methods inherited from Serializer

#serialize

Instance Method Details

#get_string(tensor_or_graph, session = nil, graph_keys = nil) ⇒ Object


7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# File 'lib/tensor_stream/graph_serializers/pbtext.rb', line 7

# Renders a graph as pbtext: one "node { ... }" section per selected node,
# followed by a fixed "versions { producer: 26 }" trailer.
#
# The output is accumulated in +@lines+, which the attribute helpers
# (+pb_attr+, +process_options+) also append to. Nested arrays pushed by
# those helpers are flattened before joining.
#
# @param tensor_or_graph [Tensor, Object] a Tensor (its +graph+ is serialized)
#   or a graph object responding to +node_keys+ / +get_tensor_by_name+
# @param session [Object, nil] accepted for interface compatibility; not read here
# @param graph_keys [Array, nil] when non-nil, only node keys contained in it are
#   emitted (order still follows +graph.node_keys+)
# @yield [graph, key] optional resolver returning the node for +key+; without a
#   block, +graph.get_tensor_by_name(key)+ is used
# @return [String] the pbtext lines joined with "\n"
def get_string(tensor_or_graph, session = nil, graph_keys = nil)
  graph = if tensor_or_graph.is_a?(Tensor)
            tensor_or_graph.graph
          else
            tensor_or_graph
          end
  @lines = []

  selected_keys = graph.node_keys
  selected_keys = selected_keys.select { |key| graph_keys.include?(key) } unless graph_keys.nil?

  selected_keys.each do |key|
    current = block_given? ? yield(graph, key) : graph.get_tensor_by_name(key)

    @lines << "node {"
    @lines << "  name: #{current.name.to_json}"

    # Only operations carry op/input/attr entries; other nodes get just a name.
    if current.is_a?(TensorStream::Operation)
      op_name = current.operation.to_s
      @lines << "  op: #{camelize(op_name).to_json}"
      # Missing (nil/false) inputs are skipped rather than serialized.
      current.inputs.each { |src| @lines << "  input: #{src.name.to_json}" if src }

      # The dtype is emitted under the "T" attribute for every operation.
      pb_attr("T", "type: #{sym_to_protobuf_type(current.data_type)}")

      # const nodes embed their value; variable_v2 nodes record their shape.
      if op_name == "const"
        pb_attr("value", tensor_value(current))
      elsif op_name == "variable_v2"
        pb_attr("shape", shape_buf(current, "shape"))
      end
      process_options(current)
    end

    @lines << "}"
  end

  @lines.push("versions {", "  producer: 26", "}")
  @lines.flatten.join("\n")
end