Class: TensorStream::Evaluator::RubyEvaluator

Inherits:
Object
  • Object
show all
Includes:
ArrayOpsHelper, MathHelper, OpHelper
Defined in:
lib/tensor_stream/evaluator/ruby_evaluator.rb

Overview

Pure-Ruby evaluator used for testing and development.

Instance Attribute Summary

Instance Method Summary

Methods included from MathHelper

#sigmoid

Methods included from ArrayOpsHelper

#broadcast, #broadcast_dimensions, #get_rank, #process_function_op, #reduced_shape, #shape_diff, #slice_tensor, #tile_arr, #truncate, #vector_op

Methods included from OpHelper

#_op, #cons, #dtype_eval, #format_source, #fp_type?, #i_cons, #i_op, #shape_eval, #val_to_dtype

Constructor Details

#initialize(session, context, thread_pool: nil, log_intermediates: false) ⇒ RubyEvaluator

Returns a new instance of RubyEvaluator.



(source lines 33–41)
# File 'lib/tensor_stream/evaluator/ruby_evaluator.rb', line 33

# Builds an evaluator bound to a session and an execution context.
#
# @param session [Object] owning session; stored as @session
# @param context [Hash] hash-like execution context (TODO confirm concrete type).
#   Its +:retain+ entry, when present, seeds #retain. Its +:compute_history+
#   entry is reset to a new empty Array when +log_intermediates+ is truthy.
# @param thread_pool [Object, nil] executor to keep in @thread_pool; when
#   nil (or false) a fresh Concurrent::ImmediateExecutor is created instead
# @param log_intermediates [Boolean] stored as @log_intermediates
def initialize(session, context, thread_pool: nil, log_intermediates: false)
  @retain = context[:retain] || []
  @session = session
  @context = context
  @thread_pool = thread_pool || Concurrent::ImmediateExecutor.new
  @log_intermediates = log_intermediates

  # The history lives on the caller's context object (not a copy), so the
  # caller sees it. Only start it when logging was requested.
  context[:compute_history] = [] if log_intermediates
end

Instance Attribute Details

#retain ⇒ Object

Returns the value of attribute retain.



(source lines 27–29)
# File 'lib/tensor_stream/evaluator/ruby_evaluator.rb', line 27

# Collection of tensors that #run hands back unevaluated.
# The constructor seeds it from context[:retain], falling back to [].
#
# @return [Object] the retained-tensor collection
def retain
  instance_variable_get(:@retain)
end

Instance Method Details

#complete_eval(tensor, context) ⇒ Object



(source lines 64–74)
# File 'lib/tensor_stream/evaluator/ruby_evaluator.rb', line 64

# Evaluates +tensor+ with #run over and over until it reaches a final value.
#
# Each pass feeds the previous result back into #run. The loop ends when:
# * #run returns the very same object it was given (for example, a
#   tensor listed in #retain), or
# * the result is no longer a Tensor.
# When a pass yields an Array whose first element is a Tensor, each element
# is fully evaluated recursively before those exit checks run.
#
# @param tensor [Object] tensor, array, or plain value to evaluate
# @param context [Object] execution context, passed through to #run unchanged
# @return [Object] the final evaluated result
def complete_eval(tensor, context)
  current = tensor
  Kernel.loop do
    previous = current
    current = run(previous, context)

    # An empty array has no first element (nil), so it is left alone here.
    if current.is_a?(Array) && current.first.is_a?(Tensor)
      current = current.map { |item| complete_eval(item, context) }
    end

    return current if previous.equal?(current) || !current.is_a?(Tensor)
  end
end

#run(tensor, execution_context) ⇒ Object



(source lines 43–62)
# File 'lib/tensor_stream/evaluator/ruby_evaluator.rb', line 43

# Evaluates a single pass of +tensor+ within +execution_context+.
#
# Processing order:
# 1. An Array is evaluated element-wise, and every element shares
#    +execution_context+.
# 2. Anything listed in #retain is returned untouched.
# 3. A Proc is called, and its return value is what gets evaluated.
# 4. The value is dispatched on its class (Operation, Variable,
#    Placeholder, or anything else), using a shallow copy of the context.
#
# Afterwards the copy's +:returns+ entry is deep-merged back into
# +execution_context+.
#
# @param tensor [Object] an Operation, Variable, Placeholder, Proc, Array, or other value
# @param execution_context [Hash] hash-like context responding to dup/[]/deep_merge!;
#   mutated by the +:returns+ merge
# @return [Object] the evaluated result, or +tensor+ itself when retained
def run(tensor, execution_context)
  if tensor.is_a?(Array)
    return tensor.map { |element| run(element, execution_context) }
  end

  # Retained tensors are deliberately handed back without evaluation.
  return tensor if retain.include?(tensor)

  target = tensor.is_a?(Proc) ? tensor.call : tensor
  scoped_context = execution_context.dup

  result =
    case target
    when Operation then eval_operation(target, scoped_context)
    when Variable then eval_variable(target, scoped_context)
    when Placeholder then resolve_placeholder(target, scoped_context)
    else eval_tensor(target, scoped_context)
    end

  execution_context.deep_merge!(returns: scoped_context[:returns])
  result
end