Class: TensorStream::Session
- Inherits: Object (ancestors: Object → TensorStream::Session)
- Includes: StringHelper
- Defined in: lib/tensor_stream/session.rb
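Overview
TensorStream class that defines a session: it resolves the evaluator classes to use, feeds placeholder values, and evaluates requested tensors via #run.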
Instance Attribute Summary
- #closed ⇒ Object (readonly): Returns the value of attribute closed.
- #last_session_context ⇒ Object (readonly): Returns the value of attribute last_session_context.
- #randomizer ⇒ Object: Returns the value of attribute randomizer.
- #session_cache ⇒ Object (readonly): Returns the value of attribute session_cache.
- #target ⇒ Object (readonly): Returns the value of attribute target.
Class Method Summary
Instance Method Summary
- #clear_session_cache ⇒ Object
- #close ⇒ Object
- #closed? ⇒ Boolean
- #delegate_to_evaluator(tensor_arr, session_context, context) ⇒ Object
- #dump_internal_ops(tensor) ⇒ Object
- #dump_ops(tensor, selector) ⇒ Object
- #dump_user_ops(tensor) ⇒ Object
- #get_evaluator_classes(evaluators) ⇒ Object
- #graph_ml(tensor, filename) ⇒ Object
- #initialize(evaluator = nil, thread_pool_class: Concurrent::ImmediateExecutor, log_device_placement: false, evaluator_options: {}) ⇒ Session (constructor): A new instance of Session.
- #list_devices ⇒ Object
- #run(*args) ⇒ Object
Methods included from StringHelper
Constructor Details
#initialize(evaluator = nil, thread_pool_class: Concurrent::ImmediateExecutor, log_device_placement: false, evaluator_options: {}) ⇒ Session
Returns a new instance of Session.
Source: lib/tensor_stream/session.rb, lines 9–18
# File 'lib/tensor_stream/session.rb', line 9
# Creates a new session.
#
# @param evaluator [nil, Array, #to_s] evaluator name(s) handed to
#   #get_evaluator_classes; nil or an empty Array selects the default evaluators
# @param thread_pool_class [Class] instantiated with no arguments to build the
#   session's thread pool
# @param log_device_placement [Boolean] when true, #run prints every node's
#   placement before evaluating
# @param evaluator_options [Hash] evaluator options; #run stores :thread_pool and
#   :log_intermediates in (a copy of) this hash
def initialize(evaluator = nil, thread_pool_class: Concurrent::ImmediateExecutor, log_device_placement: false, evaluator_options: {})
  @thread_pool = thread_pool_class.new
  @closed = false
  @session_cache = {}
  @randomizer = {}
  @log_device_placement = log_device_placement
  # Keep the caller's options. Copy them so the keys #run writes do not leak
  # back into the caller's hash.
  @evaluator_options = (evaluator_options || {}).dup
  # Fills @evaluator_classes. Its return value (an Array of classes) must not
  # replace the options hash, which #run indexes with Symbol keys.
  get_evaluator_classes(evaluator)
  @evaluators = {}
end
Instance Attribute Details
#closed ⇒ Object (readonly)
Returns the value of attribute closed.
Source: lib/tensor_stream/session.rb, lines 6–8
# File 'lib/tensor_stream/session.rb', line 6
# Reader for the session's closed flag: false after construction, and true once
# #close has been called.
def closed
  @closed
end
#last_session_context ⇒ Object (readonly)
Returns the value of attribute last_session_context.
Source: lib/tensor_stream/session.rb, lines 6–8
# File 'lib/tensor_stream/session.rb', line 6
# Reader for the context Hash built by the most recent #run (nil before the
# first run).
def last_session_context
  @last_session_context
end
#randomizer ⇒ Object
Returns the value of attribute randomizer.
Source: lib/tensor_stream/session.rb, lines 7–9
# File 'lib/tensor_stream/session.rb', line 7
# Reader for the randomizer Hash; the constructor initializes it to {}.
def randomizer
  @randomizer
end
#session_cache ⇒ Object (readonly)
Returns the value of attribute session_cache.
Source: lib/tensor_stream/session.rb, lines 6–8
# File 'lib/tensor_stream/session.rb', line 6
# Reader for the cache Hash that #run installs as context[:_cache]; it is
# replaced by #clear_session_cache.
def session_cache
  @session_cache
end
#target ⇒ Object (readonly)
Returns the value of attribute target.
Source: lib/tensor_stream/session.rb, lines 6–8
# File 'lib/tensor_stream/session.rb', line 6
# Reader for the target attribute.
# NOTE(review): nothing in the visible code assigns @target, so this presumably
# returns nil unless it is set elsewhere.
def target
  @target
end
Class Method Details
Instance Method Details
#clear_session_cache ⇒ Object
Source: lib/tensor_stream/session.rb, lines 34–36
# File 'lib/tensor_stream/session.rb', line 34
# Discards every cached entry by swapping in a brand-new empty Hash. Later runs
# start from a clean context[:_cache]. Any reference that still holds the old
# hash (e.g. a previous #last_session_context) is left untouched.
#
# @return [Hash] the new, empty cache
def clear_session_cache
  @session_cache = {}
end
#close ⇒ Object
Source: lib/tensor_stream/session.rb, lines 87–89
# File 'lib/tensor_stream/session.rb', line 87
# Flags the session as closed; #closed? reports true afterwards.
# NOTE(review): #run does not consult this flag in the visible code.
#
# @return [true]
def close
  @closed = true
end
#closed? ⇒ Boolean
Source: lib/tensor_stream/session.rb, lines 91–93
# File 'lib/tensor_stream/session.rb', line 91
# Predicate form of the closed flag.
#
# @return [Boolean] false after construction, true once #close has run
def closed?
  @closed
end
#delegate_to_evaluator(tensor_arr, session_context, context) ⇒ Object
Source: lib/tensor_stream/session.rb, lines 115–121
# File 'lib/tensor_stream/session.rb', line 115
# Evaluates each tensor with the evaluator recorded for it during placement.
#
# @param tensor_arr [Object, Array] one tensor or an Array of tensors
# @param session_context [Hash] run context; session_context[:_cache][:placement]
#   is indexed by tensor name, and element [1] of each entry must respond to
#   #run_with_buffer (element [0] is presumably the device — see #run's logging)
# @param context [Hash] extra per-call context forwarded to #run_with_buffer
# @return the sole result when exactly one tensor was evaluated, otherwise an
#   Array of results in input order
def delegate_to_evaluator(tensor_arr, session_context, context)
  tensors = tensor_arr.is_a?(Array) ? tensor_arr : [tensor_arr]

  outputs = tensors.map do |tensor|
    evaluator = session_context[:_cache][:placement][tensor.name][1]
    evaluator.run_with_buffer(tensor, session_context, context)
  end

  return outputs.first if outputs.size == 1

  outputs
end
#dump_internal_ops(tensor) ⇒ Object
Source: lib/tensor_stream/session.rb, lines 95–97
# File 'lib/tensor_stream/session.rb', line 95
# Dumps the recorded values of the graph's Tensor nodes whose #internal? is
# true, formatted by #dump_ops.
#
# @param tensor [Object] any tensor belonging to the graph to inspect
# @return [Array<String>] the lines produced by #dump_ops
def dump_internal_ops(tensor)
  internal_only = lambda do |_key, node|
    node.is_a?(Tensor) && node.internal?
  end
  dump_ops(tensor, internal_only)
end
#dump_ops(tensor, selector) ⇒ Object
Source: lib/tensor_stream/session.rb, lines 103–109
# File 'lib/tensor_stream/session.rb', line 103
# Formats the graph nodes picked by +selector+ together with the values recorded
# for them in the most recent #run's context.
#
# @param tensor [Object] any tensor of the graph to inspect; only #graph is used
# @param selector [#call] called as selector.call(key, node) for every entry of
#   graph.nodes; a truthy result selects the node
# @return [Array<String>] one "key math = value" line per selected node whose
#   recorded value is truthy; empty when #run has not been called yet
def dump_ops(tensor, selector)
  # Before the first #run there is no context to read values from.
  return [] unless @last_session_context

  selected = tensor.graph.nodes.select { |key, node| selector.call(key, node) }
  lines = selected.map do |key, node|
    value = @last_session_context[node.name]
    # Nodes with no (or a falsy) recorded value are omitted.
    "#{key} #{node.to_math(true, 1)} = #{value}" if value
  end
  lines.compact
end
#dump_user_ops(tensor) ⇒ Object
Source: lib/tensor_stream/session.rb, lines 99–101
# File 'lib/tensor_stream/session.rb', line 99
# Dumps the recorded values of the graph's Tensor nodes whose #internal? is
# false, formatted by #dump_ops.
#
# @param tensor [Object] any tensor belonging to the graph to inspect
# @return [Array<String>] the lines produced by #dump_ops
def dump_user_ops(tensor)
  user_defined = lambda do |_key, node|
    node.is_a?(Tensor) && !node.internal?
  end
  dump_ops(tensor, user_defined)
end
#get_evaluator_classes(evaluators) ⇒ Object
Source: lib/tensor_stream/session.rb, lines 20–32
# File 'lib/tensor_stream/session.rb', line 20
# Resolves evaluator names to evaluator classes and stores them in
# @evaluator_classes.
#
# @param evaluators [nil, Array, #to_s] nil or an empty Array selects
#   TensorStream::Evaluator.default_evaluators; otherwise each name is camelized
#   and looked up as a constant under TensorStream::Evaluator
# @return [Array] the resolved evaluators (also assigned to @evaluator_classes)
# @raise [NameError] when a name has no matching constant
def get_evaluator_classes(evaluators)
  lookup = lambda do |name|
    Object.const_get("TensorStream::Evaluator::#{camelize(name.to_s)}")
  end

  @evaluator_classes =
    case evaluators
    when nil
      TensorStream::Evaluator.default_evaluators
    when Array
      evaluators.empty? ? TensorStream::Evaluator.default_evaluators : evaluators.map(&lookup)
    else
      [lookup.call(evaluators)]
    end
end
#graph_ml(tensor, filename) ⇒ Object
Source: lib/tensor_stream/session.rb, lines 111–113
# File 'lib/tensor_stream/session.rb', line 111
# Serializes +tensor+'s graph with TensorStream::Graphml (constructed with this
# session) and writes it to +filename+.
#
# @param tensor [Object] tensor passed to Graphml#serialize
# @param filename [String] destination passed to Graphml#serialize
# @return whatever Graphml#serialize returns
def graph_ml(tensor, filename)
  serializer = TensorStream::Graphml.new(self)
  serializer.serialize(tensor, filename)
end
#list_devices ⇒ Object
Source: lib/tensor_stream/session.rb, lines 79–85
# File 'lib/tensor_stream/session.rb', line 79
# Lists the devices reported by every evaluator in
# TensorStream::Evaluator.evaluators.
#
# @return [Array] the concatenation of each entry's
#   entry[:class].query_supported_devices, fully flattened
def list_devices
  TensorStream::Evaluator.evaluators
                         .map { |_name, entry| entry[:class].query_supported_devices }
                         .flatten
end
#run(*args) ⇒ Object
Source: lib/tensor_stream/session.rb, lines 42–77
# File 'lib/tensor_stream/session.rb', line 42
# Evaluates the given tensors and returns their values.
#
# A trailing Hash argument is taken as run options:
#   :feed_dict         - Hash of Placeholder => value; non-Placeholder keys are ignored
#   :log_intermediates - copied into the evaluator options as :log_intermediates
#
# @param args tensors to evaluate, optionally followed by the options Hash
# @return the single value when one tensor is requested, otherwise an Array of
#   values in argument order; values responding to #to_ruby are converted first
def run(*args)
  options = args.last.is_a?(Hash) ? args.pop : {}
  context = { _cache: @session_cache }

  # Expose fed placeholder values in the context, keyed by the placeholder's name as a Symbol.
  (options[:feed_dict] || {}).each do |key, value|
    context[key.name.to_sym] = value if key.is_a?(Placeholder)
  end

  # NOTE(review): presumably read when evaluators are prepared; confirm in prepare_evaluators.
  @evaluator_options[:thread_pool] = @thread_pool
  @evaluator_options[:log_intermediates] = options[:log_intermediates]

  args.each { |tensor| prepare_evaluators(tensor, context) }
  @last_session_context = context

  if @log_device_placement
    context[:_cache][:placement].each do |k, v|
      puts "#{k} : #{v[0].name}"
    end
  end

  results = args.map do |tensor|
    value = delegate_to_evaluator(tensor, context, {})
    value.respond_to?(:to_ruby) ? value.to_ruby : value
  end
  results.size == 1 ? results.first : results
end