Class: TensorStream::Session

Inherits:
Object
  • Object
  • TensorStream::Session
Includes:
StringHelper
Defined in:
lib/tensor_stream/session.rb

Overview

TensorStream class that defines a session

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Methods included from StringHelper

#camelize

Constructor Details

#initialize(evaluator = :ruby_evaluator, thread_pool_class: Concurrent::ImmediateExecutor, evaluator_options: {}) ⇒ Session

Returns a new instance of Session.



9
10
11
12
13
14
15
16
# File 'lib/tensor_stream/session.rb', line 9

# Builds a session bound to one evaluator implementation.
#
# @param evaluator [Symbol, #to_s] underscored evaluator name; it is camelized
#   and resolved as a constant under TensorStream::Evaluator
# @param thread_pool_class [Class] instantiated with no arguments to obtain
#   the session's thread pool
# @param evaluator_options [Hash] stored as-is (not copied) for later use
# @raise [NameError] if no constant matches the camelized evaluator name
def initialize(evaluator = :ruby_evaluator, thread_pool_class: Concurrent::ImmediateExecutor, evaluator_options: {})
  evaluator_name = camelize(evaluator.to_s)
  # Resolve the evaluator before creating the pool so a bad name fails first.
  @evaluator_class = Object.const_get("TensorStream::Evaluator::#{evaluator_name}")
  @thread_pool = thread_pool_class.new
  @evaluator_options = evaluator_options
  @session_cache = {}
  @randomizer = {}
  @closed = false
end

Instance Attribute Details

#closed ⇒ Object (readonly)

Returns the value of attribute closed.



6
7
8
# File 'lib/tensor_stream/session.rb', line 6

# Reader for the session's closed flag.
#
# The constructor sets the flag to false and #close sets it to true.
#
# @return [Boolean] the current value of @closed
def closed
  @closed
end

#last_session_context ⇒ Object (readonly)

Returns the value of attribute last_session_context.



6
7
8
# File 'lib/tensor_stream/session.rb', line 6

# Reader for the context hash recorded by the most recent #run.
#
# NOTE(review): the constructor shown for this class does not assign
# @last_session_context, so this presumably returns nil until #run has been
# called at least once — confirm.
#
# @return [Hash, nil] the context built by the last #run
def last_session_context
  @last_session_context
end

#randomizer ⇒ Object

Returns the value of attribute randomizer.



7
8
9
# File 'lib/tensor_stream/session.rb', line 7

# Reader for the session's randomizer store.
#
# The constructor initializes it to an empty Hash.
# NOTE(review): what gets stored in it is not visible here — confirm with the
# code that writes to it.
#
# @return [Hash] the current value of @randomizer
def randomizer
  @randomizer
end

#session_cache ⇒ Object (readonly)

Returns the value of attribute session_cache.



6
7
8
# File 'lib/tensor_stream/session.rb', line 6

# Reader for the session-wide cache Hash.
#
# The constructor starts it empty, #run hands it to the evaluator as the
# :_cache entry of the context, and #clear_session_cache replaces it with a
# new empty Hash.
#
# @return [Hash] the current value of @session_cache
def session_cache
  @session_cache
end

#target ⇒ Object (readonly)

Returns the value of attribute target.



6
7
8
# File 'lib/tensor_stream/session.rb', line 6

# Reader for @target.
#
# NOTE(review): none of the methods shown for this class assign @target, so
# this presumably always returns nil — confirm whether it is set elsewhere.
#
# @return [Object, nil] the current value of @target
def target
  @target
end

Class Method Details

.default_session ⇒ Object



22
23
24
# File 'lib/tensor_stream/session.rb', line 22

# Returns a lazily created, memoized Session.
#
# On the first call a Session is built with Session.new (no arguments, so all
# constructor defaults apply) and stored in the receiver's @session instance
# variable; later calls return that same object.
#
# @return [Session] the memoized session
def self.default_session
  @session ||= Session.new
end

Instance Method Details

#clear_session_cache ⇒ Object



18
19
20
# File 'lib/tensor_stream/session.rb', line 18

# Replaces the session cache with a new, empty Hash.
#
# The previous Hash object is not emptied, only dropped from @session_cache,
# so anything still holding a reference to it keeps its contents.
#
# @return [Hash] the new empty cache
def clear_session_cache
  @session_cache = {}
end

#close ⇒ Object



59
60
61
# File 'lib/tensor_stream/session.rb', line 59

# Marks the session as closed.
#
# Only the @closed flag is set; nothing else is released or shut down here.
# NOTE(review): #run as shown does not check this flag — confirm whether a
# closed session is meant to reject further runs.
#
# @return [true]
def close
  @closed = true
end

#closed? ⇒ Boolean

Returns:

  • (Boolean)


63
64
65
# File 'lib/tensor_stream/session.rb', line 63

# Predicate form of the closed flag.
#
# @return [Boolean] false after construction, true once #close has been called
def closed?
  @closed
end

#dump_internal_ops(tensor) ⇒ Object



67
68
69
# File 'lib/tensor_stream/session.rb', line 67

# Dumps the nodes of +tensor+'s graph that are Tensors flagged as internal.
#
# Delegates to #dump_ops with a selector that ignores the node key.
#
# @param tensor [Object] passed straight through to #dump_ops
# @return [Object] whatever #dump_ops returns for the internal-only selector
def dump_internal_ops(tensor)
  internal_only = lambda do |_key, node|
    node.is_a?(Tensor) && node.internal?
  end
  dump_ops(tensor, internal_only)
end

#dump_ops(tensor, selector) ⇒ Object



75
76
77
78
79
80
81
# File 'lib/tensor_stream/session.rb', line 75

# Lists the graph nodes accepted by +selector+ together with the value that
# the context of the most recent #run holds for each of them.
#
# Nodes whose context entry is missing or falsy are left out.
#
# @param tensor [#graph] object whose graph's +nodes+ are scanned
# @param selector [#call] called with (key, node); a truthy result keeps the node
# @return [Array<String>] one "<key> <node.to_math(true, 1)> = <value>" line per
#   reported node; empty when no run has been recorded yet
def dump_ops(tensor, selector)
  # @last_session_context is only assigned by #run, so before the first run
  # there is nothing to report (previously this raised NoMethodError on nil).
  context = @last_session_context
  return [] unless context

  selected = tensor.graph.nodes.select { |key, node| selector.call(key, node) }
  selected.each_with_object([]) do |(key, node), lines|
    value = context[node.name]
    next unless value

    lines << "#{key} #{node.to_math(true, 1)} = #{value}"
  end
end

#dump_user_ops(tensor) ⇒ Object



71
72
73
# File 'lib/tensor_stream/session.rb', line 71

# Dumps the nodes of +tensor+'s graph that are Tensors not flagged as internal.
#
# Delegates to #dump_ops with a selector that ignores the node key.
#
# @param tensor [Object] passed straight through to #dump_ops
# @return [Object] whatever #dump_ops returns for the user-only selector
def dump_user_ops(tensor)
  user_defined = lambda do |_key, node|
    next false unless node.is_a?(Tensor)

    !node.internal?
  end
  dump_ops(tensor, user_defined)
end

#graph_ml(tensor, filename) ⇒ Object



83
84
85
# File 'lib/tensor_stream/session.rb', line 83

# Serializes +tensor+ through a TensorStream::Graphml instance built for this
# session.
#
# @param tensor [Object] forwarded to Graphml#serialize
# @param filename [Object] forwarded to Graphml#serialize
# @return [Object] whatever Graphml#serialize returns
def graph_ml(tensor, filename)
  serializer = TensorStream::Graphml.new(self)
  serializer.serialize(tensor, filename)
end

#list_devices ⇒ Object



55
56
57
# File 'lib/tensor_stream/session.rb', line 55

# Lists the devices this session reports.
#
# @return [Array<Device>] a one-element array holding a Device built from "cpu"
def list_devices
  cpu = Device.new("cpu")
  [cpu]
end

#run(*args) ⇒ Object



26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# File 'lib/tensor_stream/session.rb', line 26

# Evaluates the given tensors with a fresh evaluator instance.
#
# A trailing Hash argument is removed from +args+ and treated as options:
#   :feed_dict         - Hash; entries keyed by a Placeholder are copied into
#                        the context under the placeholder's name as a Symbol,
#                        other keys are ignored
#   :retain            - stored in the context under :retain
#   :log_intermediates - stored in the evaluator options
#
# Side effects: writes :thread_pool and :log_intermediates into
# @evaluator_options and records the context in @last_session_context.
#
# @param args [Array] tensors to evaluate, optionally followed by an options Hash
# @return [Object, Array] the single result when exactly one tensor was given,
#   otherwise the array of results in argument order
def run(*args)
  options = args.last.is_a?(Hash) ? args.pop : {}
  context = { _cache: @session_cache }

  # Bind placeholder values; anything not keyed by a Placeholder is skipped.
  feeds = options[:feed_dict]
  if feeds
    feeds.each_key do |key|
      context[key.name.to_sym] = feeds[key] if key.is_a?(Placeholder)
    end
  end

  @evaluator_options[:thread_pool] = @thread_pool
  @evaluator_options[:log_intermediates] = options[:log_intermediates]
  context[:retain] = options[:retain]
  evaluator = @evaluator_class.new(self, context, @evaluator_options)

  @last_session_context = context
  # One execution context is shared by every tensor evaluated in this call.
  execution_context = {}
  outputs = args.map { |tensor| evaluator.run(tensor, execution_context) }
  outputs.size == 1 ? outputs.first : outputs
end