Class: TensorStream::Session

Inherits:
Object
Includes:
StringHelper
Defined in:
lib/tensor_stream/session.rb

Overview

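TensorStream class that defines a session: it evaluates tensors through one or more evaluators and keeps per-session state such as a cache shared across runs.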

Instance Attribute Summary

Class Method Summary

Instance Method Summary

Methods included from StringHelper

#camelize

Constructor Details

#initialize(evaluator = nil, thread_pool_class: Concurrent::ImmediateExecutor, log_device_placement: false, evaluator_options: {}) ⇒ Session

Returns a new instance of Session.



9
10
11
12
13
14
15
16
17
18
# File 'lib/tensor_stream/session.rb', line 9

# Creates a new session.
#
# @param evaluator [nil, Symbol, String, Array] evaluator name(s) resolved by
#   #get_evaluator_classes; nil or an empty Array selects
#   TensorStream::Evaluator.default_evaluators.
# @param thread_pool_class [Class] executor class; instantiated once here and
#   stored into the evaluator options on every #run.
# @param log_device_placement [Boolean] when true, #run prints the device
#   chosen for each placed tensor.
# @param evaluator_options [Hash] extra evaluator options (presumably forwarded
#   to evaluators elsewhere — not visible here).
def initialize(evaluator = nil, thread_pool_class: Concurrent::ImmediateExecutor, log_device_placement: false, evaluator_options: {})
  @thread_pool = thread_pool_class.new
  @closed = false
  @session_cache = {}
  @randomizer = {}
  @log_device_placement = log_device_placement
  # #run writes :thread_pool and :log_intermediates into this Hash on every
  # call; copy it so the caller's Hash is neither mutated nor required to be
  # unfrozen.
  @evaluator_options = evaluator_options.dup
  get_evaluator_classes(evaluator)
  @evaluators = {}
end

Instance Attribute Details

#closed ⇒ Object (readonly)

Returns the value of attribute closed.



6
7
8
# File 'lib/tensor_stream/session.rb', line 6

# Reader for the session's closed flag.
#
# @return [Boolean] false after construction, true once #close has run
def closed
  @closed
end

#last_session_context ⇒ Object (readonly)

Returns the value of attribute last_session_context.



6
7
8
# File 'lib/tensor_stream/session.rb', line 6

# The context Hash built by the most recent #run call. That Hash holds the
# session cache under :_cache plus any fed placeholder values.
#
# @return [Hash, nil] nil until #run has been called at least once
def last_session_context
  @last_session_context
end

#randomizer ⇒ Object

Returns the value of attribute randomizer.



7
8
9
# File 'lib/tensor_stream/session.rb', line 7

# Per-session randomizer state. #initialize sets it to an empty Hash; how it
# is populated is not visible here.
#
# @return [Hash]
def randomizer
  @randomizer
end

#session_cache ⇒ Object (readonly)

Returns the value of attribute session_cache.



6
7
8
# File 'lib/tensor_stream/session.rb', line 6

# Cache Hash shared by every #run of this session; each run exposes it in its
# context under :_cache. #clear_session_cache replaces it with an empty Hash.
#
# @return [Hash]
def session_cache
  @session_cache
end

#target ⇒ Object (readonly)

Returns the value of attribute target.



6
7
8
# File 'lib/tensor_stream/session.rb', line 6

# Reader for @target.
# NOTE(review): @target is not assigned anywhere in the code shown here, so
# this returns nil unless something else sets it — confirm before relying on it.
def target
  @target
end

Class Method Details

.default_session ⇒ Object



38
39
40
# File 'lib/tensor_stream/session.rb', line 38

# Shared default session, built on first use and memoized in a class-level
# instance variable so later calls return the same object.
#
# @return [Session]
def self.default_session
  return @session if @session

  @session = Session.new
end

Instance Method Details

#clear_session_cache ⇒ Object



34
35
36
# File 'lib/tensor_stream/session.rb', line 34

# Discards cached session data by swapping in a brand-new empty Hash.
# The previous Hash object is replaced rather than cleared in place, so any
# earlier run context that still references it keeps its contents.
#
# @return [Hash] the new, empty cache
def clear_session_cache
  @session_cache = {}
end

#close ⇒ Object



87
88
89
# File 'lib/tensor_stream/session.rb', line 87

# Flags the session as closed.
# Only the flag is set: the thread pool is not shut down here, and #run does
# not consult the flag in the code shown.
#
# @return [true]
def close
  @closed = true
end

#closed? ⇒ Boolean

Returns:

  • (Boolean)


91
92
93
# File 'lib/tensor_stream/session.rb', line 91

# Predicate form of the closed flag.
#
# @return [Boolean] true once #close has been called
def closed?
  @closed
end

#delegate_to_evaluator(tensor_arr, session_context, context) ⇒ Object



115
116
117
118
119
120
121
# File 'lib/tensor_stream/session.rb', line 115

# Runs each tensor on the evaluator it was placed on.
#
# The evaluator is element 1 of session_context[:_cache][:placement][tensor.name].
# Placement entries look like [device, evaluator] pairs, since #run logs
# entry[0].name — TODO confirm.
#
# @param tensor_arr [Object, Array] a single tensor or an Array of tensors
# @param session_context [Hash] run context holding the :_cache placement table
# @param context [Hash] extra context forwarded to run_with_buffer
# @return [Object, Array] the lone result when exactly one tensor was run,
#   otherwise an Array of results
def delegate_to_evaluator(tensor_arr, session_context, context)
  tensors = tensor_arr.is_a?(Array) ? tensor_arr : [tensor_arr]

  outputs = tensors.map do |tensor|
    evaluator = session_context[:_cache][:placement][tensor.name][1]
    evaluator.run_with_buffer(tensor, session_context, context)
  end

  return outputs.first if outputs.size == 1

  outputs
end

#dump_internal_ops(tensor) ⇒ Object



95
96
97
# File 'lib/tensor_stream/session.rb', line 95

# Formats the graph nodes that are internal Tensors, via #dump_ops.
#
# @param tensor [Tensor] any tensor whose graph should be dumped
# @return [Array<String>]
def dump_internal_ops(tensor)
  internal_only = lambda do |_key, node|
    node.is_a?(Tensor) && node.internal?
  end
  dump_ops(tensor, internal_only)
end

#dump_ops(tensor, selector) ⇒ Object



103
104
105
106
107
108
109
# File 'lib/tensor_stream/session.rb', line 103

# Formats the selected nodes of +tensor+'s graph together with the values
# recorded for them by the most recent #run.
#
# Each line has the form "<key> <node.to_math(true, 1)> = <value>". Nodes with
# no recorded value (nil) are skipped. Nodes whose value is +false+ are kept;
# the previous truthiness check dropped boolean results.
#
# @param tensor [Tensor] any tensor whose graph should be dumped
# @param selector [#call] called with (key, node); truthy keeps the node
# @return [Array<String>] empty when #run has not been called yet
def dump_ops(tensor, selector)
  # Nothing has been evaluated yet; avoid calling #[] on nil.
  return [] if @last_session_context.nil?

  tensor.graph.nodes.each_with_object([]) do |(key, node), lines|
    next unless selector.call(key, node)

    value = @last_session_context[node.name]
    next if value.nil?

    lines << "#{key} #{node.to_math(true, 1)} = #{value}"
  end
end

#dump_user_ops(tensor) ⇒ Object



99
100
101
# File 'lib/tensor_stream/session.rb', line 99

# Formats the graph nodes that are user-defined (non-internal) Tensors, via
# #dump_ops.
#
# @param tensor [Tensor] any tensor whose graph should be dumped
# @return [Array<String>]
def dump_user_ops(tensor)
  user_only = lambda do |_key, node|
    node.is_a?(Tensor) && !node.internal?
  end
  dump_ops(tensor, user_only)
end

#get_evaluator_classes(evaluators) ⇒ Object



20
21
22
23
24
25
26
27
28
29
30
31
32
# File 'lib/tensor_stream/session.rb', line 20

# Resolves evaluator names to classes under TensorStream::Evaluator and stores
# them in @evaluator_classes.
#
# - nil or an empty Array selects TensorStream::Evaluator.default_evaluators.
# - A non-empty Array resolves each name.
# - Any other value is resolved as a single name.
#
# Names are camelized, so :ruby_evaluator becomes TensorStream::Evaluator::RubyEvaluator.
#
# @param evaluators [nil, Symbol, String, Array]
# @return [Array<Class>] the resolved list (also assigned to @evaluator_classes)
# @raise [NameError] when a name does not match a defined evaluator constant
def get_evaluator_classes(evaluators)
  lookup = lambda do |name|
    Object.const_get("TensorStream::Evaluator::#{camelize(name.to_s)}")
  end

  @evaluator_classes =
    case evaluators
    when nil
      TensorStream::Evaluator.default_evaluators
    when Array
      if evaluators.empty?
        TensorStream::Evaluator.default_evaluators
      else
        evaluators.map { |name| lookup.call(name) }
      end
    else
      [lookup.call(evaluators)]
    end
end

#graph_ml(tensor, filename) ⇒ Object



111
112
113
# File 'lib/tensor_stream/session.rb', line 111

# Writes +tensor+'s graph out via TensorStream::Graphml#serialize. The
# serializer is constructed with this session.
#
# @param tensor [Tensor]
# @param filename [String] destination passed straight to the serializer
# @return [Object] whatever Graphml#serialize returns
def graph_ml(tensor, filename)
  serializer = TensorStream::Graphml.new(self)
  serializer.serialize(tensor, filename)
end

#list_devices ⇒ Object



79
80
81
82
83
84
85
# File 'lib/tensor_stream/session.rb', line 79

# Lists every device reported by every registered evaluator class.
#
# Each evaluator entry's :class is asked for query_supported_devices. That
# result is mapped through #itself to normalize it to an Array, and the
# per-evaluator lists are flattened into one.
#
# @return [Array] devices from all evaluators, in registration order
def list_devices
  per_evaluator = TensorStream::Evaluator.evaluators.map do |_name, entry|
    entry[:class].query_supported_devices.map(&:itself)
  end
  per_evaluator.flatten
end

#run(*args) ⇒ Object



42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# File 'lib/tensor_stream/session.rb', line 42

# Evaluates one or more fetches and returns their values. Values that respond
# to #to_ruby are converted with it.
#
# Call as run(fetch, ...) or run(fetch, ..., feed_dict: {...}, log_intermediates: flag).
# A trailing Hash argument is treated as the options Hash.
#
# @param args [Array] fetches (tensors or arrays of tensors), optionally
#   followed by an options Hash:
#   - :feed_dict — Hash whose Placeholder keys are fed by symbolized name;
#     non-Placeholder keys are ignored.
#   - :log_intermediates — copied into the evaluator options.
# @return [Object, Array] the single value when exactly one fetch was given,
#   otherwise an Array with one entry per fetch ([] when there are none)
def run(*args)
  options = args.last.is_a?(Hash) ? args.pop : {}
  context = { _cache: @session_cache }

  # Scan for placeholders and assign their fed values.
  feed_dict = options[:feed_dict]
  if feed_dict
    feed_dict.each do |key, value|
      context[key.name.to_sym] = value if key.is_a?(Placeholder)
    end
  end

  @evaluator_options[:thread_pool] = @thread_pool
  @evaluator_options[:log_intermediates] = options[:log_intermediates]

  args.each { |fetch| prepare_evaluators(fetch, context) }
  @last_session_context = context

  if @log_device_placement
    # :placement may be absent, e.g. run(feed_dict: {}) on a fresh session
    # never prepares anything. Fall back to an empty Hash instead of calling
    # #each on nil.
    (context[:_cache][:placement] || {}).each do |name, placement|
      puts "#{name} : #{placement[0].name}"
    end
  end

  results = args.map do |fetch|
    value = delegate_to_evaluator(fetch, context, {})
    value.respond_to?(:to_ruby) ? value.to_ruby : value
  end
  results.size == 1 ? results.first : results
end