Class: TensorStream::Graph

Inherits:
Object
Defined in:
lib/tensor_stream/graph.rb

Overview

A class that defines a TensorStream graph.

Instance Attribute Summary

Class Method Summary

Instance Method Summary

Constructor Details

#initialize ⇒ Graph

Returns a new instance of Graph.



6
7
8
9
10
11
12
13
14
15
# File 'lib/tensor_stream/graph.rb', line 6

# Creates an empty graph.
#
# Eager execution starts disabled. There are no nodes or constants yet. The
# global-variables and trainable-variables collections are pre-created, keyed
# by their GraphKeys names converted to Symbols, each holding an empty Array.
def initialize
  @eager_execution = false
  @nodes = {}
  @node_keys = []
  @collections = [GraphKeys::GLOBAL_VARIABLES, GraphKeys::TRAINABLE_VARIABLES]
                 .each_with_object({}) { |graph_key, by_name| by_name[:"#{graph_key}"] = [] }
  @constants = {}
end

Instance Attribute Details

#collections ⇒ Object

Returns the value of attribute collections.



4
5
6
# File 'lib/tensor_stream/graph.rb', line 4

# Named node collections, keyed by Symbol (see #add_to_collection), each
# holding an Array of members.
# @return [Hash{Symbol=>Array}]
def collections; @collections; end

#constants ⇒ Object

Returns the value of attribute constants.



4
5
6
# File 'lib/tensor_stream/graph.rb', line 4

# Nodes whose +is_const+ was truthy when they were added via #add_node,
# keyed by node name.
# @return [Hash]
def constants; @constants; end

#eager_execution ⇒ Object

Returns the value of attribute eager_execution.



4
5
6
# File 'lib/tensor_stream/graph.rb', line 4

# Raw eager-execution flag, toggled by #enable_eager_execution and
# #disable_eager_execution.
def eager_execution; @eager_execution; end

#node_keys ⇒ Object

Returns the value of attribute node_keys.



4
5
6
# File 'lib/tensor_stream/graph.rb', line 4

# Names of nodes registered through #add_node, in insertion order
# (#add_node! does not append here).
# @return [Array]
def node_keys; @node_keys; end

#nodes ⇒ Object

Returns the value of attribute nodes.



4
5
6
# File 'lib/tensor_stream/graph.rb', line 4

# All registered nodes, keyed by node name.
# @return [Hash]
def nodes; @nodes; end

#random_seed ⇒ Object

Returns the value of attribute random_seed.



4
5
6
# File 'lib/tensor_stream/graph.rb', line 4

# Graph-level random seed. #reset clears it to nil; it is not assigned
# anywhere else in this excerpt.
def random_seed; @random_seed; end

Class Method Details

.create_default ⇒ Object



67
68
69
# File 'lib/tensor_stream/graph.rb', line 67

# Creates a brand-new Graph and installs it as the current thread's default
# graph, replacing whatever was there.
# @return [TensorStream::Graph] the newly created graph
def self.create_default
  fresh_graph = TensorStream::Graph.new
  Thread.current[:tensor_stream_current_graph] = fresh_graph
  fresh_graph
end

.get_default_graph ⇒ Object



63
64
65
# File 'lib/tensor_stream/graph.rb', line 63

# Returns the current thread's default graph, lazily creating one via
# .create_default when none has been installed yet.
# @return [TensorStream::Graph]
def self.get_default_graph
  installed = Thread.current[:tensor_stream_current_graph]
  return installed if installed

  create_default
end

.parse_from_string(buffer) ⇒ Object



216
217
218
219
# File 'lib/tensor_stream/graph.rb', line 216

# Builds a new Graph from a serialized +buffer+ by handing an empty Graph to
# TensorStream::GraphBuilder and returning whatever its #build returns.
# @param buffer [String] serialized graph text (format defined by GraphBuilder)
def self.parse_from_string(buffer)
  TensorStream::GraphBuilder.new(Graph.new).build(buffer)
end

Instance Method Details

#add_node(node) ⇒ Object



80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# File 'lib/tensor_stream/graph.rb', line 80

# Registers +node+ in the graph and wires it into the graph's bookkeeping.
#
# If the name is already taken, it is replaced with +uniqunify(node.name)+.
# The node is then:
# - tagged with the innermost active device scope,
# - appended to +node_keys+,
# - indexed in +nodes+ (and in +constants+ when +node.is_const+ is truthy),
# - asked to propagate its outputs and consumers.
#
# With eager execution enabled, the node is evaluated immediately and the
# result is stored in +node.value+.
#
# @param node [Object] the tensor/operation to register
# @return [Object] +node+ itself, whether or not eager execution is on
#   (matching #add_node!, which also returns the node)
# @raise [RuntimeError] if +node+ is a Placeholder while eager execution is enabled
def add_node(node)
  raise 'Placeholder cannot be used when eager_execution is enabled' if @eager_execution && node.is_a?(Placeholder)

  # Never clobber an existing entry: derive a fresh name on collision.
  node.name = @nodes[node.name] ? uniqunify(node.name) : node.name

  node.device = get_device_scope
  @node_keys << node.name
  @nodes[node.name] = node
  @constants[node.name] = node if node.is_const

  node.send(:propagate_outputs)
  node.send(:propagate_consumer, node)

  node.value = node.eval if @eager_execution
  node
end

#add_node!(name, node) ⇒ Object



113
114
115
116
# File 'lib/tensor_stream/graph.rb', line 113

# Stores +node+ under +name+ verbatim, overwriting any existing entry.
# Unlike #add_node, there is no renaming, device tagging, node_keys/constants
# bookkeeping, or output propagation.
# @return [Object] +node+
def add_node!(name, node)
  node.tap { |registered| @nodes[name] = registered }
end

#add_to_collection(collection_name, val) ⇒ Object



75
76
77
78
# File 'lib/tensor_stream/graph.rb', line 75

# Appends +val+ to the collection named +collection_name+ (normalized to a
# Symbol), creating the collection on first use.
# @return [Array] the collection's member array
def add_to_collection(collection_name, val)
  key = collection_name.to_sym
  members = (@collections[key] ||= [])
  members << val
end

#add_variable(node, options = {}) ⇒ Object



118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# File 'lib/tensor_stream/graph.rb', line 118

# Registers a variable node.
#
# If a node with the same name already exists, it is returned when the
# current variable scope allows reuse; otherwise an error is raised. New
# variables must declare a shape. They are added to any extra collections in
# +options[:collections]+ (a single name or an Array of names), always to the
# global-variables collection, and to the trainable-variables collection when
# +node.trainable?+. Finally they are registered through #add_node.
#
# The caller's +options+ hash is not modified.
#
# @param node [Object] the variable to register
# @param options [Hash] may contain +:collections+
# @return [Object] the registered node (or the existing node when reused)
# @raise [RuntimeError] on a duplicate name without reuse, or a nil shape
def add_variable(node, options = {})
  scope = _variable_scope

  raise "duplicate variable detected #{node.name} and reuse=false in current scope" if @nodes[node.name] && !scope.reuse
  return @nodes[node.name] if @nodes[node.name]
  raise "shape is not declared for #{node.name}" if node.shape.nil?

  extra = options[:collections]
  if extra && !extra.empty?
    # Normalize locally instead of writing back into the caller's hash.
    (extra.is_a?(Array) ? extra : [extra]).each { |coll| add_to_collection(coll, node) }
  end

  add_to_collection(GraphKeys::GLOBAL_VARIABLES, node)
  add_to_collection(GraphKeys::TRAINABLE_VARIABLES, node) if node.trainable?
  add_node(node)
  # Return the node on every path, consistent with the reuse branch above.
  node
end

#as_default {|_self| ... } ⇒ Object

Yields:

  • (_self)

Yield Parameters:

  • _self (TensorStream::Graph) — this graph, already installed as the thread's default graph



32
33
34
35
36
# File 'lib/tensor_stream/graph.rb', line 32

# Installs this graph as the current thread's default graph. If a block is
# given, the graph is yielded to it. The graph is returned either way.
# The previous default is not restored afterwards.
# @yieldparam graph [TensorStream::Graph] this graph
# @return [TensorStream::Graph] self
def as_default
  tap do |graph|
    Thread.current[:tensor_stream_current_graph] = graph
    yield(graph) if block_given?
  end
end

#as_graph_def ⇒ Object



212
213
214
# File 'lib/tensor_stream/graph.rb', line 212

# Serializes this graph with TensorStream::Pbtext#get_string and returns its
# result (presumably protobuf text format; see Pbtext).
def as_graph_def
  serializer = TensorStream::Pbtext.new
  serializer.get_string(self)
end

#control_dependencies(control_inputs = []) ⇒ Object



135
136
137
138
139
140
141
142
143
144
# File 'lib/tensor_stream/graph.rb', line 135

# Runs the block with a +:no_op+ Operation over +control_inputs+ pushed onto
# this graph's control-dependency stack. The stack lives in fiber-local
# Thread.current storage. The entry is popped when the block finishes,
# even if it raises.
# @param control_inputs [Array] inputs the no-op depends on
# @return [Object] the block's value
def control_dependencies(control_inputs = [])
  key = "ts_graph_#{object_id}"
  pending = ((Thread.current[key] ||= {})[:control_dependencies] ||= [])
  pending.push(Operation.new(:no_op, *control_inputs))
  begin
    yield
  ensure
    Thread.current[key][:control_dependencies].pop
  end
end

#device(device_name) ⇒ Object

Returns a context manager that specifies the default device to use.



52
53
54
55
56
57
58
59
60
61
# File 'lib/tensor_stream/graph.rb', line 52

# Runs the block with +device_name+ pushed onto this graph's device stack.
# The stack lives in fiber-local Thread.current storage. Nodes added inside
# the block are tagged with it via #get_device_scope. The entry is popped
# even if the block raises.
# @param device_name [Object] device identifier, stored as given
# @return [Object] the block's value
def device(device_name)
  key = "ts_graph_#{object_id}"
  devices = ((Thread.current[key] ||= {})[:default_device] ||= [])
  devices.push(device_name)
  begin
    yield
  ensure
    Thread.current[key][:default_device].pop
  end
end

#disable_eager_execution ⇒ Object



150
151
152
# File 'lib/tensor_stream/graph.rb', line 150

# Turns eager execution off: nodes added afterwards are no longer evaluated
# on insertion.
# @return [false]
def disable_eager_execution; @eager_execution = false; end

#enable_eager_execution ⇒ Object



146
147
148
# File 'lib/tensor_stream/graph.rb', line 146

# Turns eager execution on: #add_node evaluates each new node immediately
# and rejects Placeholders.
# @return [true]
def enable_eager_execution; @eager_execution = true; end

#executing_eagerly? ⇒ Boolean

Returns:

  • (Boolean)


154
155
156
# File 'lib/tensor_stream/graph.rb', line 154

# Whether eager execution is currently enabled for this graph. Returns the
# stored flag as-is.
# @return [Boolean]
def executing_eagerly?; @eager_execution; end

#get_collection(name, _options = {}) ⇒ Object



71
72
73
# File 'lib/tensor_stream/graph.rb', line 71

# Looks up a collection by name (String or Symbol).
# @param name [#to_sym] collection name
# @param _options [Hash] accepted for API compatibility; ignored
# @return [Array, nil] the collection's members, or nil if it does not exist
def get_collection(name, _options = {})
  collection_key = name.to_sym
  @collections[collection_key]
end

#get_const_counter ⇒ Object



184
185
186
187
188
189
190
191
# File 'lib/tensor_stream/graph.rb', line 184

# Returns a suffix for the next auto-named constant and advances the counter.
# Successive calls yield '', '_1', '_2', ...
# @return [String]
def get_const_counter
  issued = @const_counter || 0
  @const_counter = issued + 1
  issued.zero? ? '' : "_#{issued}"
end

#get_dependency_scope ⇒ Object



200
201
202
203
204
# File 'lib/tensor_stream/graph.rb', line 200

# Returns the innermost no-op pushed by an active #control_dependencies block
# for the current thread, or nil when none is active.
def get_dependency_scope
  storage = Thread.current["ts_graph_#{object_id}"]
  return unless storage

  pending = storage[:control_dependencies]
  pending ? pending.last : nil
end

#get_device_scope ⇒ Object



206
207
208
209
210
# File 'lib/tensor_stream/graph.rb', line 206

# Returns the device named by the innermost active #device block for the
# current thread, or +:default+ when no device block is active.
#
# The stack remains in Thread.current as an empty Array after a #device block
# exits. An empty stack must therefore be treated like a missing one.
# Otherwise nodes added later (#add_node assigns node.device from here) would
# get +nil+ instead of +:default+.
def get_device_scope
  graph_thread_storage = Thread.current["ts_graph_#{object_id}"]
  devices = graph_thread_storage && graph_thread_storage[:default_device]
  return :default if devices.nil? || devices.empty?

  devices.last
end

#get_name_scope ⇒ Object



193
194
195
196
197
198
# File 'lib/tensor_stream/graph.rb', line 193

# Returns the current thread's active name scopes joined with '/'
# (outermost first), or nil when no #name_scope block is active.
#
# The stack remains in Thread.current as an empty Array after a #name_scope
# block exits. An empty stack is treated like a missing one, so callers see
# nil in both cases instead of nil before the first scope and "" after it.
# @return [String, nil]
def get_name_scope
  graph_thread_storage = Thread.current["ts_graph_#{object_id}"]
  scopes = graph_thread_storage && graph_thread_storage[:current_scope]
  return nil if scopes.nil? || scopes.empty?

  scopes.join('/')
end

#get_node(name) ⇒ Object



104
105
106
# File 'lib/tensor_stream/graph.rb', line 104

# Looks up a node by exact name.
# @return [Object, nil] the node, or nil when no node has that name
def get_node(name)
  nodes_by_name = @nodes
  nodes_by_name[name]
end

#get_operation_counter ⇒ Object



158
159
160
161
162
163
164
165
166
# File 'lib/tensor_stream/graph.rb', line 158

# Returns a suffix for the next auto-named operation and advances the counter.
# Successive calls yield '', '_1', '_2', ...
# @return [String]
def get_operation_counter
  issued = @op_counter || 0
  @op_counter = issued + 1
  return '' if issued.zero?

  "_#{issued}"
end

#get_placeholder_counter ⇒ Object



168
169
170
171
172
173
174
# File 'lib/tensor_stream/graph.rb', line 168

# Advances the placeholder counter and returns the suffix for the next
# auto-named placeholder.
#
# Successive calls yield '', '_2', '_3', ... with no '_1'. This differs from
# #get_const_counter and #get_operation_counter, which yield '', '_1', '_2'.
# NOTE(review): possibly unintended. It is kept as-is because changing it
# would rename generated placeholders.
# @return [String]
def get_placeholder_counter
  @placeholder_counter = (@placeholder_counter || 0) + 1
  @placeholder_counter == 1 ? '' : "_#{@placeholder_counter}"
end

#get_tensor_by_name(name) ⇒ Object



108
109
110
111
# File 'lib/tensor_stream/graph.rb', line 108

# Strict variant of #get_node: returns the node registered under +name+.
# @raise [TensorStream::KeyError] when no node has that name
def get_tensor_by_name(name)
  return get_node(name) if @nodes.key?(name)

  raise TensorStream::KeyError, "#{name} not found"
end

#get_var_counter ⇒ Object



176
177
178
179
180
181
182
# File 'lib/tensor_stream/graph.rb', line 176

# Advances the variable counter and returns the suffix for the next
# auto-named variable.
#
# Successive calls yield '', '_2', '_3', ... with no '_1', the same scheme as
# #get_placeholder_counter. NOTE(review): possibly unintended. It is kept
# as-is because changing it would rename generated variables.
# @return [String]
def get_var_counter
  @var_counter = (@var_counter || 0) + 1
  return '' if @var_counter == 1

  "_#{@var_counter}"
end

#graph_def_versions ⇒ Object



221
222
223
# File 'lib/tensor_stream/graph.rb', line 221

# Fixed version stanza for serialized graph definitions.
# @return [String]
def graph_def_versions
  'producer: 26'
end

#name_scope(name = nil) ⇒ Object



38
39
40
41
42
43
44
45
46
47
48
# File 'lib/tensor_stream/graph.rb', line 38

# Pushes +name+ onto this graph's name-scope stack, which lives in
# fiber-local Thread.current storage. If a block is given, it is yielded the
# resulting joined scope (#get_name_scope). The entry is popped afterwards,
# even if the block raises.
# @param name [String, nil] scope segment to push
# @return [Object, nil] the block's value, or nil without a block
def name_scope(name = nil)
  key = "ts_graph_#{object_id}"
  scopes = ((Thread.current[key] ||= {})[:current_scope] ||= [])
  scopes.push(name)

  begin
    yield get_name_scope if block_given?
  ensure
    Thread.current[key][:current_scope].pop
  end
end

#node_added?(name) ⇒ Boolean

Returns:

  • (Boolean)


100
101
102
# File 'lib/tensor_stream/graph.rb', line 100

# Whether a node is registered under exactly +name+.
# @return [Boolean]
def node_added?(name)
  @nodes.include?(name)
end

#reset ⇒ Object



17
18
19
20
21
22
23
24
25
26
27
28
29
30
# File 'lib/tensor_stream/graph.rb', line 17

# Clears the graph back to an empty state.
#
# All auto-naming counters restart at zero, the random seed is cleared, and
# nodes, node keys and constants are emptied. Collections are rebuilt with
# only the empty global and trainable variable collections. The
# eager-execution flag is left untouched.
# @return [Hash] the new, empty constants table
def reset
  @placeholder_counter = @const_counter = @var_counter = @op_counter = 0
  @random_seed = nil
  @nodes = {}
  @node_keys = []
  @collections = [GraphKeys::GLOBAL_VARIABLES, GraphKeys::TRAINABLE_VARIABLES]
                 .each_with_object({}) { |graph_key, by_name| by_name[:"#{graph_key}"] = [] }
  @constants = {}
end