Module: TensorStream::CheckOps

Included in:
Evaluator::RubyEvaluator
Defined in:
lib/tensor_stream/evaluator/ruby/check_ops.rb

Class Method Summary

Class Method Details

.included(klass) ⇒ Object



3
4
5
6
7
8
9
10
11
12
13
14
15
# File 'lib/tensor_stream/evaluator/ruby/check_ops.rb', line 3

# Hook invoked when this module is included into a class. Registers the
# +:assert_equal+ op on +klass+.
#
# The registered op compares its two inputs element-wise via
# +call_vector_op+ and raises unless every compared element is exactly
# +true+. Inputs with no elements to compare (empty tensors) are treated as
# trivially equal.
#
# @param klass [Class] the including class; must respond to +register_op+,
#   and the op body calls +call_vector_op+ on the evaluating object.
#   NOTE(review): requirements inferred from the calls below — confirm
#   against the evaluator.
# @raise [TensorStream::InvalidArgumentError] (raised by the registered op,
#   not by this hook) when any element pair differs. The message is the
#   optional +:message+ option, a space, then "<input0 name> != <input1 name>".
def self.included(klass)
  klass.class_eval do
    register_op :assert_equal do |context, tensor, inputs|
      result = call_vector_op(tensor, :equal, inputs[0], inputs[1], context) { |t, u| t == u }

      # call_vector_op yields a (possibly nested) Array for tensors and a bare
      # value for scalars; normalise to a flat list of per-element results.
      results = result.is_a?(Array) ? result.flatten : [result]

      # `all?` is vacuously true for an empty list, so empty tensors pass.
      # `== true` (rather than truthiness) keeps the original strictness.
      unless results.all? { |v| v == true }
        prefix = tensor.options[:message] || ""
        raise TensorStream::InvalidArgumentError, "#{prefix} #{tensor.inputs[0].name} != #{tensor.inputs[1].name}"
      end

      nil
    end
  end
end