Class: TensorStream::Evaluator::RubyEvaluator

Inherits: Object
Includes: OpHelper
Defined in: lib/tensor_stream/evaluator/ruby_evaluator.rb

Overview

Pure-Ruby evaluator, used for testing and development.

Instance Attribute Summary

Instance Method Summary

Methods included from OpHelper

#cons, #dtype_eval, #fp_type?, #i_cons, #i_op, #op, #shape_eval, #val_to_dtype

Constructor Details

#initialize(session, context, thread_pool: nil) ⇒ RubyEvaluator

Returns a new instance of RubyEvaluator.



# File 'lib/tensor_stream/evaluator/ruby_evaluator.rb', lines 29-34

# Builds an evaluator bound to a session and an evaluation context.
#
# @param session [Object] the owning session, stored as-is
# @param context [Hash] evaluation context; its +:retain+ entry (if present)
#   lists tensors that #run should return without evaluating
# @param thread_pool [Object, nil] executor to use; falls back to a
#   Concurrent::ImmediateExecutor when nil (or false)
def initialize(session, context, thread_pool: nil)
  # Tensors named here are short-circuited by #run instead of evaluated.
  @retain = context[:retain] || []
  @thread_pool = thread_pool || Concurrent::ImmediateExecutor.new
  @context = context
  @session = session
end

Instance Attribute Details

#retain ⇒ Object

Returns the value of attribute retain.



# File 'lib/tensor_stream/evaluator/ruby_evaluator.rb', lines 25-27

# Tensors that #run hands back unevaluated instead of computing.
#
# @return [Object] the value stored in @retain (set by #initialize from the
#   context's +:retain+ entry, defaulting to +[]+)
def retain; @retain; end

Instance Method Details

#complete_eval(tensor, context) ⇒ Object



# File 'lib/tensor_stream/evaluator/ruby_evaluator.rb', lines 55-65

# Repeatedly evaluates +tensor+ with #run until it reaches a fixed point.
#
# Iteration stops when #run hands back the very same object (for example a
# tensor listed in #retain) or produces something that is not a Tensor.
# When #run yields a non-empty Array whose first element is a Tensor, each
# element of that Array is completed recursively.
#
# NOTE(review): only the first element decides whether an Array is treated
# as a list of tensors, and there is no iteration cap. A #run that keeps
# returning fresh Tensor objects would loop forever. Confirm both are intended.
#
# @param tensor [Object, Array] node (or nodes) to evaluate
# @param context [Object] execution context forwarded to #run
# @return [Object] the fully evaluated value
def complete_eval(tensor, context)
  current = tensor
  Kernel.loop do
    previous = current
    current = run(current, context)

    holds_tensors = current.is_a?(Array) && !current.empty? && current.first.is_a?(Tensor)
    current = current.map { |element| complete_eval(element, context) } if holds_tensors

    # fixed point: nothing changed, or the result is no longer a Tensor
    break if previous.equal?(current) || !current.is_a?(Tensor)
  end
  current
end

#run(tensor, execution_context) ⇒ Object



# File 'lib/tensor_stream/evaluator/ruby_evaluator.rb', lines 36-53

# Evaluates a single node, or each element of an Array of nodes.
#
# Nodes contained in #retain are returned untouched. Any other node is
# dispatched by class (Operation, Variable, Placeholder, otherwise a plain
# tensor) and evaluated against a #dup of +execution_context+. The copy's
# +:returns+ entry is then merged back via +deep_merge!+.
#
# @param tensor [Object, Array] node to evaluate, or an Array of nodes
# @param execution_context [Object] presumably a Hash (must support #dup,
#   #[] and #deep_merge!); it is mutated by the merge of +:returns+
# @return [Object] the evaluated value (an Array when +tensor+ is an Array)
def run(tensor, execution_context)
  if tensor.is_a?(Array)
    return tensor.map { |element| run(element, execution_context) }
  end

  # retained nodes are deliberately left unevaluated
  return tensor if retain.include?(tensor)

  scoped_context = execution_context.dup
  result =
    case tensor
    when Operation then eval_operation(tensor, scoped_context)
    when Variable then eval_variable(tensor, scoped_context)
    when Placeholder then resolve_placeholder(tensor, scoped_context)
    else eval_tensor(tensor, scoped_context)
    end

  execution_context.deep_merge!(returns: scoped_context[:returns])
  result
end