Class: TensorStream::Session

Inherits: Object
Defined in:
lib/tensor_stream/session.rb

Class Method Summary

Instance Method Summary

Constructor Details

#initialize(evaluator = :ruby_evaluator, thread_pool_class: Concurrent::ImmediateExecutor) ⇒ Session

Returns a new instance of Session.



3
4
5
6
# File 'lib/tensor_stream/session.rb', line 3

# Creates a session bound to one evaluator implementation.
#
# @param evaluator [Symbol, String] evaluator name (e.g. :ruby_evaluator); its
#   string form is passed through +camelize+ and looked up as a constant under
#   TensorStream::Evaluator
# @param thread_pool_class [Class] instantiated once, with no arguments; the
#   resulting pool is handed to each evaluator that #run creates
# @raise [NameError] from Object.const_get when no matching evaluator constant exists
def initialize(evaluator = :ruby_evaluator, thread_pool_class: Concurrent::ImmediateExecutor)
  evaluator_const = "TensorStream::Evaluator::#{camelize(evaluator.to_s)}"
  @evaluator_class = Object.const_get(evaluator_const)
  @thread_pool = thread_pool_class.new
end

Class Method Details

.default_session ⇒ Object



8
9
10
# File 'lib/tensor_stream/session.rb', line 8

# Returns a Session built with the default constructor arguments.
# The Session is created lazily on the first call, cached on the class, and
# the same object is returned by every later call.
def self.default_session
  return @session if @session

  @session = Session.new
end

Instance Method Details

#dump_internal_ops(tensor) ⇒ Object



39
40
41
# File 'lib/tensor_stream/session.rb', line 39

# Lists the recorded values of internal graph nodes (Tensors whose
# +internal?+ is truthy) from the last #run, via #dump_ops.
#
# @param tensor [Tensor] any tensor of the graph to inspect
# @return [Array<String>] the formatted lines produced by #dump_ops
def dump_internal_ops(tensor)
  internal_only = lambda do |_name, node|
    node.is_a?(Tensor) && node.internal?
  end
  dump_ops(tensor, internal_only)
end

#dump_ops(tensor, selector) ⇒ Object



47
48
49
50
51
52
53
# File 'lib/tensor_stream/session.rb', line 47

# Formats the values recorded by the most recent #run for selected graph nodes.
#
# @param tensor [Tensor] any tensor; only its +graph+ is used
# @param selector [#call] predicate called with (key, node) for every entry of
#   +graph.nodes+; nodes for which it returns truthy are reported
# @return [Array<String>] one "key math = value" line per selected node that
#   has a truthy value in the last session context (looked up by +node.name+).
#   Returns an empty array if #run has not been called yet.
def dump_ops(tensor, selector)
  context = @last_session_context
  # Before the first #run there is no context; indexing nil would raise.
  return [] unless context

  selected = tensor.graph.nodes.select { |key, node| selector.call(key, node) }
  selected.map do |key, node|
    value = context[node.name]
    # Nodes with no (or a falsy) recorded value are omitted.
    next unless value

    "#{key} #{node.to_math(true, 1)} = #{value}"
  end.compact
end

#dump_user_ops(tensor) ⇒ Object



43
44
45
# File 'lib/tensor_stream/session.rb', line 43

# Lists the recorded values of user-defined graph nodes (Tensors whose
# +internal?+ is falsy) from the last #run, via #dump_ops.
#
# @param tensor [Tensor] any tensor of the graph to inspect
# @return [Array<String>] the formatted lines produced by #dump_ops
def dump_user_ops(tensor)
  user_only = lambda do |_name, node|
    node.is_a?(Tensor) && !node.internal?
  end
  dump_ops(tensor, user_only)
end

#last_session_context ⇒ Object



12
13
14
# File 'lib/tensor_stream/session.rb', line 12

# The context hash from the most recent #run, or nil if #run has not been
# called on this session yet.
def last_session_context
  @last_session_context if defined?(@last_session_context)
end

#run(*args) ⇒ Object



16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# File 'lib/tensor_stream/session.rb', line 16

# Evaluates one or more tensors with a freshly built evaluator.
#
# If the last argument is a Hash, it is taken as options rather than as a tensor:
# * +:feed_dict+ - maps Placeholder objects to their values for this run.
#   Each value is stored in the context under the placeholder's name as a
#   Symbol. Keys that are not Placeholders are ignored.
# * +:retain+    - stored in the context under the +:retain+ key (nil if absent)
#
# The context hash given to the evaluator is kept afterwards and exposed by
# #last_session_context.
#
# @return the single result when exactly one tensor was given, otherwise an
#   Array of results in argument order (empty when no tensors were given)
def run(*args)
  options = args.last.is_a?(Hash) ? args.pop : {}

  context = {}
  feeds = options[:feed_dict]
  if feeds
    placeholders = feeds.keys.select { |key| key.is_a?(Placeholder) }
    placeholders.each { |ph| context[ph.name.to_sym] = feeds[ph] }
  end
  context[:retain] = options[:retain]

  evaluator = @evaluator_class.new(self, context, thread_pool: @thread_pool)

  # A single execution context is shared by all tensors evaluated in this call.
  shared_execution_context = {}
  results = args.map { |target| evaluator.run(target, shared_execution_context) }

  @last_session_context = context
  return results.first if results.size == 1

  results
end