Class: TensorStream::Protobuf

Inherits:
Object
  • Object
show all
Defined in:
lib/tensor_stream/graph_deserializers/protobuf.rb

Overview

A .pb graph deserializer

Instance Method Summary collapse

Constructor Details

#initialize ⇒ Protobuf

Returns a new instance of Protobuf.



6
7
# File 'lib/tensor_stream/graph_deserializers/protobuf.rb', line 6

# Creates a new deserializer. The constructor takes no arguments and sets
# no instance state.
def initialize; end

Instance Method Details

#evaluate_tensor_node(node) ⇒ Object



31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# File 'lib/tensor_stream/graph_deserializers/protobuf.rb', line 31

# Converts a parsed tensor node into a Ruby value.
#
# If the node has a non-empty "shape" and a "tensor_content" entry, the
# content is decoded from its escaped text form into raw bytes, unpacked
# according to "dtype" and reshaped with TensorShape.reshape. Otherwise the
# scalar "float_val" / "int_val" / "string_val" field is used, wrapped in a
# one-element Array when the shape is exactly [1].
#
# @param node [Hash] tensor node; reads "shape", "dtype", "tensor_content",
#   "float_val", "int_val" and "string_val"
# @return [Object] the decoded value (reshaped Array, scalar, [] when the
#   scalar field is absent, or the raw "string_val")
# @raise [RuntimeError] if "dtype" is not DT_FLOAT, DT_INT32 or DT_STRING
def evaluate_tensor_node(node)
  shape = node["shape"]
  dtype = node["dtype"]

  # A missing shape is treated like an empty (scalar) shape instead of
  # crashing with NoMethodError on nil.
  if shape && !shape.empty? && node["tensor_content"]
    return node["string_val"] if dtype == "DT_STRING"

    bytes = unescape_tensor_content(node["tensor_content"])
    # Explicit little-endian directives make the result independent of the
    # host byte order (identical to the old native "f*"/"l*" on
    # little-endian hosts). NOTE(review): assumes the producer wrote
    # little-endian data.
    case dtype
    when "DT_FLOAT" then TensorShape.reshape(bytes.unpack("e*"), shape)
    when "DT_INT32" then TensorShape.reshape(bytes.unpack("l<*"), shape)
    else raise "unknown dtype #{dtype}"
    end
  else
    val =
      case dtype
      when "DT_FLOAT" then node["float_val"] ? node["float_val"].to_f : []
      when "DT_INT32" then node["int_val"] ? node["int_val"].to_i : []
      when "DT_STRING" then node["string_val"]
      else raise "unknown dtype #{dtype}"
      end

    shape == [1] ? [val] : val
  end
end

# Decodes the C-style escapes used for bytes fields in protobuf text format
# (\n, \t, \\, \", octal \NNN, hex \xHH, ...) into a binary String.
#
# This replaces a former eval(%("...")) call, which executed any #{...}
# interpolation embedded in a (possibly untrusted) .pb file.
#
# @param content [#to_s] escaped text as read from the file
# @return [String] ASCII-8BIT string holding the raw bytes
private def unescape_tensor_content(content)
  content.to_s.b.gsub(/\\(?:(?<oct>[0-7]{1,3})|x(?<hex>\h{1,2})|(?<chr>.))/m) do
    match = Regexp.last_match
    if match[:oct]
      # Like Ruby string literals, keep only the low 8 bits of \NNN.
      (match[:oct].to_i(8) & 0xFF).chr
    elsif match[:hex]
      match[:hex].to_i(16).chr
    else
      case match[:chr]
      when "a" then "\a"
      when "b" then "\b"
      when "f" then "\f"
      when "n" then "\n"
      when "r" then "\r"
      when "t" then "\t"
      when "v" then "\v"
      else match[:chr] # \\, \", \', \? and unknown escapes yield the char itself
      end
    end
  end
end

#load(pbfile) ⇒ Object

Parses a protobuf file and returns a Ruby hash.



16
17
18
19
20
21
22
23
# File 'lib/tensor_stream/graph_deserializers/protobuf.rb', line 16

# Parses a protobuf text file by reading it line by line, stripping
# surrounding whitespace from each line, and handing the lines to
# evaluate_lines.
#
# @param pbfile [String] path to the file to read
# @return [Object] whatever evaluate_lines returns for those lines
#   (documented elsewhere as a Ruby hash)
def load(pbfile)
  # File.readlines opens and closes the file itself; the previous
  # File.new handle was never closed and leaked a descriptor per call.
  evaluate_lines(File.readlines(pbfile).map(&:strip))
end

#load_from_string(buffer) ⇒ Object



9
10
11
# File 'lib/tensor_stream/graph_deserializers/protobuf.rb', line 9

# Parses protobuf text held in memory.
#
# The buffer is split on "\n" (String#split drops trailing empty lines),
# each piece is stripped of surrounding whitespace, and the resulting
# lines are passed to evaluate_lines.
#
# @param buffer [String] protobuf text content
# @return [Object] whatever evaluate_lines returns for those lines
def load_from_string(buffer)
  lines = buffer.split("\n")
  lines.map!(&:strip)
  evaluate_lines(lines)
end

#map_type_to_ts(attr_value) ⇒ Object



65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# File 'lib/tensor_stream/graph_deserializers/protobuf.rb', line 65

# Translates a TensorFlow dtype name into the matching TensorStream type
# symbol.
#
# @param attr_value [String] dtype name such as "DT_FLOAT"
# @return [Symbol] one of :float32, :int32, :int64, :string, :boolean
# @raise [RuntimeError] for any dtype name not listed below
def map_type_to_ts(attr_value)
  dtype_table = {
    "DT_FLOAT" => :float32,
    "DT_INT32" => :int32,
    "DT_INT64" => :int64,
    "DT_STRING" => :string,
    "DT_BOOL" => :boolean
  }
  dtype_table.fetch(attr_value) { raise "unknown type #{attr_value}" }
end

#options_evaluator(node) ⇒ Object



82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# File 'lib/tensor_stream/graph_deserializers/protobuf.rb', line 82

# Builds an options Hash from a node's "attributes" list, keyed by each
# attribute's "key".
#
# Each attribute's "value" is read as its first type => value pair:
# "tensor" values are decoded with evaluate_tensor_node, "type" values are
# mapped with map_type_to_ts, "b" values become true only for the string
# "true", and any other type's value is kept unchanged.
#
# @param node [Hash] node that may carry an "attributes" list
# @return [Hash] attribute key => converted value ({} when "attributes" is nil)
def options_evaluator(node)
  attributes = node["attributes"]
  return {} if attributes.nil?

  attributes.each_with_object({}) do |attribute, options|
    kind, raw = attribute["value"].first

    converted =
      case kind
      when "tensor" then evaluate_tensor_node(raw)
      when "type" then map_type_to_ts(raw)
      when "b" then raw == "true"
      else raw
      end

    options[attribute["key"]] = converted
  end
end

#parse_value(value_node) ⇒ Object



25
26
27
28
29
# File 'lib/tensor_stream/graph_deserializers/protobuf.rb', line 25

# Decodes the "tensor" entry of a value node, if it has one.
#
# @param value_node [Hash] node that may carry a "tensor" entry
# @return [Object, nil] evaluate_tensor_node's result, or nil when the
#   "tensor" entry is missing or falsy
def parse_value(value_node)
  tensor = value_node["tensor"]
  tensor ? evaluate_tensor_node(tensor) : nil
end