Class: TensorStream::Session
- Inherits: Object
  - Ancestry: Object > TensorStream::Session
- Includes:
- StringHelper
- Defined in:
- lib/tensor_stream/session.rb
Overview
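Defines a TensorStream session. A session resolves an evaluator class when it is constructed and uses that evaluator to compute tensor values in #run. It also keeps per-session state such as a cache and the context of the last run.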
Instance Attribute Summary
-
#closed ⇒ Object
readonly
Returns the value of attribute closed.
-
#last_session_context ⇒ Object
readonly
Returns the value of attribute last_session_context.
-
#randomizer ⇒ Object
Returns the value of attribute randomizer.
-
#session_cache ⇒ Object
readonly
Returns the value of attribute session_cache.
-
#target ⇒ Object
readonly
Returns the value of attribute target.
Class Method Summary
- .default_session ⇒ Object
Instance Method Summary
- #clear_session_cache ⇒ Object
- #close ⇒ Object
- #closed? ⇒ Boolean
- #dump_internal_ops(tensor) ⇒ Object
- #dump_ops(tensor, selector) ⇒ Object
- #dump_user_ops(tensor) ⇒ Object
- #graph_ml(tensor, filename) ⇒ Object
- #initialize(evaluator = :ruby_evaluator, thread_pool_class: Concurrent::ImmediateExecutor, evaluator_options: {}) ⇒ Session (constructor)
  Returns a new instance of Session.
- #list_devices ⇒ Object
- #run(*args) ⇒ Object
Methods included from StringHelper
Constructor Details
#initialize(evaluator = :ruby_evaluator, thread_pool_class: Concurrent::ImmediateExecutor, evaluator_options: {}) ⇒ Session
Returns a new instance of Session.
Source: lib/tensor_stream/session.rb, lines 9–16
# File 'lib/tensor_stream/session.rb', line 9
# Creates a session backed by the evaluator named +evaluator+.
#
# @param evaluator [Symbol, String] evaluator name. It is passed through
#   +camelize+ and resolved as a constant under TensorStream::Evaluator
#   (presumably :ruby_evaluator -> TensorStream::Evaluator::RubyEvaluator;
#   depends on StringHelper#camelize).
# @param thread_pool_class [Class] instantiated with no arguments to build this
#   session's thread pool.
# @param evaluator_options [Hash] options kept for #run to hand to the evaluator.
# @raise [NameError] if the evaluator constant cannot be resolved.
def initialize(evaluator = :ruby_evaluator, thread_pool_class: Concurrent::ImmediateExecutor, evaluator_options: {})
  @evaluator_class = Object.const_get("TensorStream::Evaluator::#{camelize(evaluator.to_s)}")
  @thread_pool = thread_pool_class.new
  @closed = false
  @session_cache = {}
  @randomizer = {}
  # This assignment was lost from the collapsed listing. #run reads @evaluator_options.
  @evaluator_options = evaluator_options
end
Instance Attribute Details
#closed ⇒ Object (readonly)
Returns the value of attribute closed.
Source: lib/tensor_stream/session.rb, lines 6–8
# File 'lib/tensor_stream/session.rb', line 6
# Reader for the closed flag.
#
# @return [Boolean] false after #initialize, true once #close has been called.
def closed
  @closed
end
#last_session_context ⇒ Object (readonly)
Returns the value of attribute last_session_context.
Source: lib/tensor_stream/session.rb, lines 6–8
# File 'lib/tensor_stream/session.rb', line 6
# Reader for the context Hash built by the most recent #run.
#
# @return [Hash, nil] nil until #run has been called at least once.
def last_session_context
  @last_session_context
end
#randomizer ⇒ Object
Returns the value of attribute randomizer.
Source: lib/tensor_stream/session.rb, lines 7–9
# File 'lib/tensor_stream/session.rb', line 7
# Reader for the randomizer state.
#
# @return [Hash] initialized to an empty Hash by #initialize. Its contents are
#   managed elsewhere (not visible here).
def randomizer
  @randomizer
end
#session_cache ⇒ Object (readonly)
Returns the value of attribute session_cache.
Source: lib/tensor_stream/session.rb, lines 6–8
# File 'lib/tensor_stream/session.rb', line 6
# Reader for the session cache.
#
# #run passes this Hash to the evaluator as context[:_cache], and
# #clear_session_cache replaces it with a new empty Hash.
#
# @return [Hash]
def session_cache
  @session_cache
end
#target ⇒ Object (readonly)
Returns the value of attribute target.
Source: lib/tensor_stream/session.rb, lines 6–8
# File 'lib/tensor_stream/session.rb', line 6
# Reader for the target attribute.
#
# NOTE(review): nothing in the Session code shown here assigns @target, so
# this returns nil unless it is set elsewhere. TODO confirm.
#
# @return [Object, nil]
def target
  @target
end
Class Method Details
.default_session ⇒ Object
Source: lib/tensor_stream/session.rb, lines 22–24
# File 'lib/tensor_stream/session.rb', line 22
# Returns the shared default session, creating it with default arguments on
# first use.
#
# The instance is memoized with ||=, so every call returns the same object.
# That includes calls made after the session has been closed with #close.
#
# @return [Session]
def self.default_session
  @session ||= Session.new
end
Instance Method Details
#clear_session_cache ⇒ Object
Source: lib/tensor_stream/session.rb, lines 18–20
# File 'lib/tensor_stream/session.rb', line 18
# Replaces the session cache with a new empty Hash.
#
# #run passes the cache to evaluators as context[:_cache]. Because a fresh
# Hash is assigned rather than clearing the old one in place, any object
# still holding the previous cache keeps its contents.
#
# @return [Hash] the new, empty cache.
def clear_session_cache
  @session_cache = {}
end
#close ⇒ Object
Source: lib/tensor_stream/session.rb, lines 59–61
# File 'lib/tensor_stream/session.rb', line 59
# Marks the session as closed.
#
# NOTE(review): this only sets the flag. #run, as shown here, does not check
# it, and no resources (e.g. the thread pool) are released.
#
# @return [true]
def close
  @closed = true
end
#closed? ⇒ Boolean
Source: lib/tensor_stream/session.rb, lines 63–65
# File 'lib/tensor_stream/session.rb', line 63
# Reports whether #close has been called on this session.
#
# @return [Boolean]
def closed?
  @closed
end
#dump_internal_ops(tensor) ⇒ Object
Source: lib/tensor_stream/session.rb, lines 67–69
# File 'lib/tensor_stream/session.rb', line 67
# Formats the last-run values of graph nodes that are Tensors whose
# +internal?+ returns true.
#
# @param tensor [Tensor] any tensor whose graph should be inspected.
# @return [Array<String>] see #dump_ops.
def dump_internal_ops(tensor)
  dump_ops(tensor, ->(_k, n) { n.is_a?(Tensor) && n.internal? })
end
#dump_ops(tensor, selector) ⇒ Object
Source: lib/tensor_stream/session.rb, lines 75–81
# File 'lib/tensor_stream/session.rb', line 75
# Formats each graph node accepted by +selector+ together with the value
# recorded for it in the context of the most recent #run.
#
# Nodes with no (or a falsy) value in that context are skipped. The context
# is the Hash #run passes to the evaluator; presumably the evaluator fills in
# node values there. TODO confirm.
#
# @param tensor [Tensor] any tensor whose +graph+ should be inspected.
# @param selector [#call] predicate called as +selector.call(key, node)+.
# @return [Array<String>] one "key math = value" line per selected node.
#   Empty if #run has not been called yet.
def dump_ops(tensor, selector)
  # Before the first #run there is no context to read from.
  return [] unless @last_session_context

  graph = tensor.graph
  graph.nodes.select { |k, v| selector.call(k, v) }.map do |k, node|
    value = @last_session_context[node.name]
    next unless value

    "#{k} #{node.to_math(true, 1)} = #{value}"
  end.compact
end
#dump_user_ops(tensor) ⇒ Object
Source: lib/tensor_stream/session.rb, lines 71–73
# File 'lib/tensor_stream/session.rb', line 71
# Formats the last-run values of graph nodes that are Tensors whose
# +internal?+ returns false.
#
# @param tensor [Tensor] any tensor whose graph should be inspected.
# @return [Array<String>] see #dump_ops.
def dump_user_ops(tensor)
  dump_ops(tensor, ->(_k, n) { n.is_a?(Tensor) && !n.internal? })
end
#graph_ml(tensor, filename) ⇒ Object
Source: lib/tensor_stream/session.rb, lines 83–85
# File 'lib/tensor_stream/session.rb', line 83
# Delegates to TensorStream::Graphml#serialize for this session. Presumably
# this writes a GraphML description of +tensor+'s graph to +filename+;
# confirm against Graphml.
#
# @param tensor [Tensor]
# @param filename [String]
# @return [Object] whatever Graphml#serialize returns.
def graph_ml(tensor, filename)
  TensorStream::Graphml.new(self).serialize(tensor, filename)
end
#list_devices ⇒ Object
Source: lib/tensor_stream/session.rb, lines 55–57
# File 'lib/tensor_stream/session.rb', line 55
# Lists the devices available to this session.
#
# @return [Array<Device>] always a single "cpu" device in this implementation.
def list_devices
  [Device.new("cpu")]
end
#run(*args) ⇒ Object
Source: lib/tensor_stream/session.rb, lines 26–53
# File 'lib/tensor_stream/session.rb', line 26
# Evaluates one or more tensors/ops with a fresh evaluator and returns their values.
#
# A trailing Hash argument is treated as run options:
#   :feed_dict         - Hash mapping Placeholder objects to the values to feed.
#                        Keys that are not Placeholders are ignored.
#   :log_intermediates - forwarded in the evaluator options.
#   :retain            - stored in the evaluation context under :retain.
#
# @param args [Array] the items to evaluate, optionally followed by an options Hash.
# @return [Object, Array] the single result when exactly one item is evaluated,
#   otherwise an Array of results in argument order.
def run(*args)
  options = args.last.is_a?(Hash) ? args.pop : {}
  context = { _cache: @session_cache }

  # scan for placeholders and assign value
  if (feed_dict = options[:feed_dict])
    feed_dict.each do |key, value|
      context[key.name.to_sym] = value if key.is_a?(Placeholder)
    end
  end

  # Build a per-run copy instead of writing into @evaluator_options. That Hash
  # is the one the caller passed to #initialize and must not be mutated.
  run_options = @evaluator_options.merge(thread_pool: @thread_pool,
                                         log_intermediates: options[:log_intermediates])

  evaluator = @evaluator_class.new(self, context.merge!(retain: options[:retain]), run_options)

  execution_context = {}
  # Kept (including :retain and fed placeholders) so #dump_ops can report values.
  @last_session_context = context
  result = args.map { |e| evaluator.run(e, execution_context) }
  result.size == 1 ? result.first : result
end