Class: TensorStream::Graph
- Inherits: Object
  - Object
    - TensorStream::Graph
- Defined in:
- lib/tensor_stream/graph.rb
Overview
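A class that defines a TensorStream computation graph. It stores the graph's nodes (by name and in insertion order), named collections, and constant nodes, and keeps per-thread name, device and control-dependency scopes.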
Instance Attribute Summary
-
#collections ⇒ Object
Returns the value of attribute collections.
-
#constants ⇒ Object
Returns the value of attribute constants.
-
#eager_execution ⇒ Object
Returns the value of attribute eager_execution.
-
#node_keys ⇒ Object
Returns the value of attribute node_keys.
-
#nodes ⇒ Object
Returns the value of attribute nodes.
-
#random_seed ⇒ Object
Returns the value of attribute random_seed.
Class Method Summary
- .create_default ⇒ Object
- .get_default_graph ⇒ Object
- .parse_from_string(buffer) ⇒ Object
Instance Method Summary
- #add_node(node) ⇒ Object
- #add_node!(name, node) ⇒ Object
- #add_to_collection(collection_name, val) ⇒ Object
- #add_variable(node, options = {}) ⇒ Object
- #as_default {|_self| ... } ⇒ Object
- #as_graph_def ⇒ Object
- #control_dependencies(control_inputs = []) ⇒ Object
-
#device(device_name) ⇒ Object
Runs the given block with device_name pushed as the default device; nodes added inside the block are assigned that device.
- #disable_eager_execution ⇒ Object
- #enable_eager_execution ⇒ Object
- #executing_eagerly? ⇒ Boolean
- #get_collection(name, _options = {}) ⇒ Object
- #get_const_counter ⇒ Object
- #get_dependency_scope ⇒ Object
- #get_device_scope ⇒ Object
- #get_name_scope ⇒ Object
- #get_node(name) ⇒ Object
- #get_operation_counter ⇒ Object
- #get_placeholder_counter ⇒ Object
- #get_tensor_by_name(name) ⇒ Object
- #get_var_counter ⇒ Object
- #graph_def_versions ⇒ Object
-
#initialize ⇒ Graph (constructor)
Returns a new instance of Graph.
- #name_scope(name = nil) ⇒ Object
- #node_added?(name) ⇒ Boolean
- #reset ⇒ Object
Constructor Details
#initialize ⇒ Graph
Returns a new instance of Graph.
6 7 8 9 10 11 12 13 14 15 |
# File 'lib/tensor_stream/graph.rb', line 6
# Builds an empty graph with eager execution switched off.
#
# Node storage, the insertion-order key list and the constants index all start
# empty; the collections hash is pre-seeded with empty lists for the global and
# trainable variable collections, keyed by the Symbol form of the GraphKeys names.
def initialize
  @eager_execution = false
  @nodes = {}
  @node_keys = []
  global_key = :"#{GraphKeys::GLOBAL_VARIABLES}"
  trainable_key = :"#{GraphKeys::TRAINABLE_VARIABLES}"
  @collections = { global_key => [], trainable_key => [] }
  @constants = {}
end
Instance Attribute Details
#collections ⇒ Object
Returns the value of attribute collections.
4 5 6 |
# File 'lib/tensor_stream/graph.rb', line 4
# Reader for the graph's named collections.
#
# @return [Hash{Symbol=>Array}] collection name (as a Symbol) mapped to its members
def collections
  @collections
end
#constants ⇒ Object
Returns the value of attribute constants.
4 5 6 |
# File 'lib/tensor_stream/graph.rb', line 4
# Reader for the index of constant nodes.
#
# @return [Hash] node name mapped to the node, for every added node whose
#   +is_const+ is truthy
def constants
  @constants
end
#eager_execution ⇒ Object
Returns the value of attribute eager_execution.
4 5 6 |
# File 'lib/tensor_stream/graph.rb', line 4
# Reader for the eager-execution flag.
#
# @return [Boolean] false after construction; toggled by
#   enable_eager_execution / disable_eager_execution
def eager_execution
  @eager_execution
end
#node_keys ⇒ Object
Returns the value of attribute node_keys.
4 5 6 |
# File 'lib/tensor_stream/graph.rb', line 4
# Reader for the list of node names in the order add_node registered them.
#
# @return [Array] node names
def node_keys
  @node_keys
end
#nodes ⇒ Object
Returns the value of attribute nodes.
4 5 6 |
# File 'lib/tensor_stream/graph.rb', line 4
# Reader for the node table.
#
# @return [Hash] node name mapped to the node object
def nodes
  @nodes
end
#random_seed ⇒ Object
Returns the value of attribute random_seed.
4 5 6 |
# File 'lib/tensor_stream/graph.rb', line 4
# Reader for the graph-level random seed.
#
# reset sets this to nil; NOTE(review): presumably assigned through a writer
# defined elsewhere — confirm.
#
# @return [Object, nil] the seed, or nil when unset
def random_seed
  @random_seed
end
Class Method Details
.create_default ⇒ Object
67 68 69 |
# File 'lib/tensor_stream/graph.rb', line 67
# Creates a brand-new Graph and stores it in the
# Thread.current[:tensor_stream_current_graph] slot, replacing whatever was there.
#
# @return [TensorStream::Graph] the newly created graph
def self.create_default
  fresh_graph = TensorStream::Graph.new
  Thread.current[:tensor_stream_current_graph] = fresh_graph
end
.get_default_graph ⇒ Object
63 64 65 |
# File 'lib/tensor_stream/graph.rb', line 63
# Returns the graph stored in Thread.current[:tensor_stream_current_graph],
# lazily creating (and storing) one via create_default when the slot is empty.
#
# @return [TensorStream::Graph] the current default graph
def self.get_default_graph
  current = Thread.current[:tensor_stream_current_graph]
  return current if current

  create_default
end
.parse_from_string(buffer) ⇒ Object
216 217 218 219 |
# File 'lib/tensor_stream/graph.rb', line 216
# Builds a new graph from a serialized buffer.
#
# A fresh, empty Graph is handed to TensorStream::GraphBuilder, whose +build+
# result is returned unchanged.
#
# @param buffer [String] serialized graph text (format defined by GraphBuilder)
# @return [Object] whatever GraphBuilder#build returns
def self.parse_from_string(buffer)
  empty_graph = Graph.new
  TensorStream::GraphBuilder.new(empty_graph).build(buffer)
end
Instance Method Details
#add_node(node) ⇒ Object
80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
# File 'lib/tensor_stream/graph.rb', line 80
# Registers +node+ in this graph.
#
# If a node with the same name is already present, the node is renamed with
# +uniqunify+ before insertion. The node is then stamped with the current
# device scope, appended to +node_keys+, stored in +nodes+, and also indexed in
# +constants+ when +is_const+ is truthy. Output/consumer bookkeeping is done via
# the node's private +propagate_outputs+ / +propagate_consumer+ methods. In
# eager mode the node is evaluated immediately and the result stored on it.
#
# @param node [Object] the tensor/operation to add (project type)
# @raise [RuntimeError] when eager execution is enabled and +node+ is a Placeholder
# @return [Object, nil] the evaluated value in eager mode, otherwise nil
def add_node(node)
  if @eager_execution && node.is_a?(Placeholder)
    raise 'Placeholder cannot be used when eager_execution is enabled'
  end

  name_taken = @nodes[node.name]
  node.name = name_taken ? uniqunify(node.name) : node.name
  node.device = get_device_scope

  @node_keys << node.name
  @nodes[node.name] = node
  @constants[node.name] = node if node.is_const

  # send is used because these hooks are not part of the node's public API.
  node.send(:propagate_outputs)
  node.send(:propagate_consumer, node)

  node.value = node.eval if @eager_execution
end
#add_node!(name, node) ⇒ Object
113 114 115 116 |
# File 'lib/tensor_stream/graph.rb', line 113
# Stores +node+ under +name+ verbatim.
#
# Unlike add_node, this does not rename on collision, stamp a device, update
# node_keys/constants, or propagate outputs — it only writes the node table.
#
# @param name [Object] key to store the node under
# @param node [Object] the node to store
# @return [Object] +node+ itself
def add_node!(name, node)
  node.tap { |registered| @nodes[name] = registered }
end
#add_to_collection(collection_name, val) ⇒ Object
75 76 77 78 |
# File 'lib/tensor_stream/graph.rb', line 75
# Appends +val+ to the named collection, creating the collection on first use.
#
# @param collection_name [#to_sym] collection name; String and Symbol names
#   refer to the same collection
# @param val [Object] item to append
# @return [Array] the collection after appending
def add_to_collection(collection_name, val)
  key = collection_name.to_sym
  (@collections[key] ||= []) << val
end
#add_variable(node, options = {}) ⇒ Object
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# File 'lib/tensor_stream/graph.rb', line 118
# Registers a variable node with the graph.
#
# If a node with the same name already exists, it is returned as-is when the
# current variable scope allows reuse; otherwise an error is raised. A new
# variable is added to any extra collections named in options[:collections],
# to the global-variables collection, to the trainable-variables collection
# when +node.trainable?+ is true, and finally to the graph via add_node.
#
# @param node [Object] the variable node (project type); must respond to
#   +name+, +shape+ and +trainable?+
# @param options [Hash]
# @option options [Array, String, Symbol] :collections extra collection name(s);
#   a single name is treated as a one-element list. The caller's hash is left
#   unmodified.
# @raise [RuntimeError] on a duplicate name when the scope does not allow reuse,
#   or when +node.shape+ is nil
# @return [Object] the existing node on reuse, otherwise whatever add_node returns
def add_variable(node, options = {})
  scope = _variable_scope
  existing = @nodes[node.name]
  raise "duplicate variable detected #{node.name} and reuse=false in current scope" if existing && !scope.reuse
  return existing if existing

  raise "shape is not declared for #{node.name}" if node.shape.nil?

  extra_collections = options[:collections]
  unless extra_collections.nil? || extra_collections.empty?
    # Wrap a single name locally instead of rewriting the caller's options hash.
    extra_collections = [extra_collections] unless extra_collections.is_a?(Array)
    extra_collections.each { |coll| add_to_collection(coll, node) }
  end

  add_to_collection(GraphKeys::GLOBAL_VARIABLES, node)
  add_to_collection(GraphKeys::TRAINABLE_VARIABLES, node) if node.trainable?
  add_node(node)
end
#as_default {|_self| ... } ⇒ Object
32 33 34 35 36 |
# File 'lib/tensor_stream/graph.rb', line 32
# Makes this graph the current thread's default graph (the
# Thread.current[:tensor_stream_current_graph] slot read by get_default_graph).
#
# The block, if given, receives this graph. The previous default graph is NOT
# restored afterwards.
#
# @yieldparam graph [TensorStream::Graph] this graph
# @return [TensorStream::Graph] self
def as_default
  tap do |graph|
    Thread.current[:tensor_stream_current_graph] = graph
    yield graph if block_given?
  end
end
#as_graph_def ⇒ Object
212 213 214 |
# File 'lib/tensor_stream/graph.rb', line 212
# Serializes this graph using TensorStream::Pbtext.
#
# @return [String] whatever Pbtext#get_string produces for this graph
def as_graph_def
  printer = TensorStream::Pbtext.new
  printer.get_string(self)
end
#control_dependencies(control_inputs = []) ⇒ Object
135 136 137 138 139 140 141 142 143 144 |
# File 'lib/tensor_stream/graph.rb', line 135
# Runs the block with a control-dependency scope active.
#
# A :no_op Operation built from +control_inputs+ is pushed onto this graph's
# per-thread control-dependency stack (read back by get_dependency_scope) and
# popped when the block finishes, even if it raises. A block is required:
# without one, +yield+ raises LocalJumpError (the pushed entry is still popped).
#
# @param control_inputs [Array] inputs for the :no_op operation
# @return [Object] the block's value
def control_dependencies(control_inputs = [])
  storage_key = "ts_graph_#{object_id}"
  storage = (Thread.current[storage_key] ||= {})
  dependency_stack = (storage[:control_dependencies] ||= [])
  dependency_stack << Operation.new(:no_op, *control_inputs)
  begin
    yield
  ensure
    Thread.current[storage_key][:control_dependencies].pop
  end
end
#device(device_name) ⇒ Object
Runs the given block with device_name pushed as the default device; nodes added inside the block are assigned that device.
52 53 54 55 56 57 58 59 60 61 |
# File 'lib/tensor_stream/graph.rb', line 52
# Runs the block with +device_name+ as the default device.
#
# The name is pushed onto this graph's per-thread device stack (read back by
# get_device_scope, which add_node uses to stamp new nodes) and popped when the
# block finishes, even if it raises. A block is required: without one, +yield+
# raises LocalJumpError (the pushed name is still popped).
#
# @param device_name [Object] device identifier to make current
# @return [Object] the block's value
def device(device_name)
  storage_key = "ts_graph_#{object_id}"
  storage = (Thread.current[storage_key] ||= {})
  device_stack = (storage[:default_device] ||= [])
  device_stack << device_name
  begin
    yield
  ensure
    Thread.current[storage_key][:default_device].pop
  end
end
#disable_eager_execution ⇒ Object
150 151 152 |
# File 'lib/tensor_stream/graph.rb', line 150
# Switches the graph back to deferred (non-eager) execution.
#
# @return [false]
def disable_eager_execution
  @eager_execution = false
end
#enable_eager_execution ⇒ Object
146 147 148 |
# File 'lib/tensor_stream/graph.rb', line 146
# Switches the graph to eager execution: add_node evaluates nodes as they are
# added and rejects Placeholders.
#
# @return [true]
def enable_eager_execution
  @eager_execution = true
end
#executing_eagerly? ⇒ Boolean
154 155 156 |
# File 'lib/tensor_stream/graph.rb', line 154
# Reports whether eager execution is currently enabled.
#
# @return [Boolean] the raw value of the eager-execution flag
def executing_eagerly?
  @eager_execution
end
#get_collection(name, _options = {}) ⇒ Object
71 72 73 |
# File 'lib/tensor_stream/graph.rb', line 71
# Looks up a named collection.
#
# The rendered listing had lost the parameter name (`def get_collection(name, = {})`,
# a syntax error); it is restored here as +_options+, matching the documented
# signature.
#
# @param name [#to_sym] collection name; String and Symbol names are equivalent
# @param _options [Hash] accepted for API compatibility; currently unused
# @return [Array, nil] the collection's members, or nil if it was never created
def get_collection(name, _options = {})
  @collections[name.to_sym]
end
#get_const_counter ⇒ Object
184 185 186 187 188 189 190 191 |
# File 'lib/tensor_stream/graph.rb', line 184
# Returns the next name suffix for constant nodes and advances the counter.
#
# Produces '', '_1', '_2', ... on successive calls (restarting after reset).
#
# @return [String] the suffix
def get_const_counter
  current = (@const_counter ||= 0)
  @const_counter = current + 1
  current.zero? ? '' : "_#{current}"
end
#get_dependency_scope ⇒ Object
200 201 202 203 204 |
# File 'lib/tensor_stream/graph.rb', line 200
# Returns the innermost active control-dependency operation for this graph on
# the current thread.
#
# @return [Object, nil] the top of the control-dependency stack, or nil when no
#   control_dependencies block is active
def get_dependency_scope
  dependency_stack = Thread.current["ts_graph_#{object_id}"]&.dig(:control_dependencies)
  dependency_stack&.last
end
#get_device_scope ⇒ Object
206 207 208 209 210 |
# File 'lib/tensor_stream/graph.rb', line 206
# Returns the device that newly added nodes should be assigned to.
#
# The innermost active +device+ block wins. When no device block is active, the
# result is :default. This includes the case where a device block has already
# exited and left its stack empty: previously that case returned nil, so nodes
# added after any device block got `device = nil` instead of :default.
#
# @return [Object] the current device name, or :default
def get_device_scope
  device_stack = Thread.current["ts_graph_#{object_id}"]&.dig(:default_device)
  return :default if device_stack.nil? || device_stack.empty?

  device_stack.last
end
#get_name_scope ⇒ Object
193 194 195 196 197 198 |
# File 'lib/tensor_stream/graph.rb', line 193
# Returns the current name scope for this graph on the current thread, formed
# by joining the active name_scope names with '/'.
#
# @return [String, nil] the joined scope ('' once all scopes have exited), or
#   nil if no name_scope has ever been entered on this thread
def get_name_scope
  scope_stack = Thread.current["ts_graph_#{object_id}"]&.dig(:current_scope)
  scope_stack&.join('/')
end
#get_node(name) ⇒ Object
104 105 106 |
# File 'lib/tensor_stream/graph.rb', line 104
# Looks up a node by name without raising.
#
# @param name [Object] node name
# @return [Object, nil] the node, or nil when absent
#   (see get_tensor_by_name for a raising variant)
def get_node(name)
  @nodes[name]
end
#get_operation_counter ⇒ Object
158 159 160 161 162 163 164 165 166 |
# File 'lib/tensor_stream/graph.rb', line 158
# Returns the next name suffix for operation nodes and advances the counter.
#
# Produces '', '_1', '_2', ... on successive calls (restarting after reset).
#
# @return [String] the suffix
def get_operation_counter
  current = (@op_counter ||= 0)
  @op_counter = current + 1
  current.zero? ? '' : "_#{current}"
end
#get_placeholder_counter ⇒ Object
168 169 170 171 172 173 174 |
# File 'lib/tensor_stream/graph.rb', line 168
# Returns the next name suffix for placeholder nodes and advances the counter.
#
# The counter is incremented before formatting, so the sequence is
# '', '_2', '_3', ... — note that '_1' is never produced (unlike the constant
# and operation counters).
#
# @return [String] the suffix
def get_placeholder_counter
  @placeholder_counter = (@placeholder_counter || 0) + 1
  @placeholder_counter == 1 ? '' : "_#{@placeholder_counter}"
end
#get_tensor_by_name(name) ⇒ Object
108 109 110 111 |
# File 'lib/tensor_stream/graph.rb', line 108
# Looks up a node by name, raising when it is absent.
#
# @param name [Object] node name
# @return [Object] the node (via get_node)
# @raise [TensorStream::KeyError] if no node with that name has been added
def get_tensor_by_name(name)
  return get_node(name) if @nodes.key?(name)

  raise TensorStream::KeyError, "#{name} not found"
end
#get_var_counter ⇒ Object
176 177 178 179 180 181 182 |
# File 'lib/tensor_stream/graph.rb', line 176
# Returns the next name suffix for variable nodes and advances the counter.
#
# The counter is incremented before formatting, so the sequence is
# '', '_2', '_3', ... — '_1' is never produced (same as the placeholder counter).
#
# @return [String] the suffix
def get_var_counter
  @var_counter = (@var_counter || 0) + 1
  @var_counter == 1 ? '' : "_#{@var_counter}"
end
#graph_def_versions ⇒ Object
221 222 223 |
# File 'lib/tensor_stream/graph.rb', line 221
# Returns the fixed graph-def version line emitted for serialized graphs.
#
# @return [String] 'producer: 26'
def graph_def_versions
  'producer: 26'
end
#name_scope(name = nil) ⇒ Object
38 39 40 41 42 43 44 45 46 47 48 |
# File 'lib/tensor_stream/graph.rb', line 38
# Pushes +name+ onto this graph's per-thread name-scope stack for the duration
# of the block.
#
# The block (optional) receives the joined scope string from get_name_scope.
# The name is popped afterwards, even if the block raises.
#
# @param name [String, nil] scope component to push
# @yieldparam scope [String] the full '/'-joined scope including +name+
# @return [Object, nil] the block's value, or nil when no block is given
def name_scope(name = nil)
  storage_key = "ts_graph_#{object_id}"
  storage = (Thread.current[storage_key] ||= {})
  (storage[:current_scope] ||= []) << name
  begin
    yield get_name_scope if block_given?
  ensure
    Thread.current[storage_key][:current_scope].pop
  end
end
#node_added?(name) ⇒ Boolean
100 101 102 |
# File 'lib/tensor_stream/graph.rb', line 100
# Reports whether a node has been registered under +name+.
#
# @param name [Object] node name
# @return [Boolean]
def node_added?(name)
  @nodes.include?(name)
end
#reset ⇒ Object
17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
# File 'lib/tensor_stream/graph.rb', line 17
# Clears the graph back to an empty state.
#
# All name counters restart at zero and the random seed is cleared. Nodes,
# node_keys and constants are emptied, and the collections hash is rebuilt with
# only the empty global and trainable variable collections. The eager-execution
# flag and the per-thread scope stacks are left untouched.
#
# @return [Hash] the new, empty constants hash
def reset
  # Name-suffix counters.
  @placeholder_counter = 0
  @const_counter = 0
  @var_counter = 0
  @op_counter = 0

  @random_seed = nil
  @nodes = {}
  @node_keys = []

  global_key = :"#{GraphKeys::GLOBAL_VARIABLES}"
  trainable_key = :"#{GraphKeys::TRAINABLE_VARIABLES}"
  @collections = { global_key => [], trainable_key => [] }
  @constants = {}
end