Module: TensorStream::OpHelper

Included in:
TensorStream, Evaluator::RubyEvaluator, MathGradients, Tensor
Defined in:
lib/tensor_stream/helpers/op_helper.rb

Overview

Module containing helper functions used when building ops.

Instance Method Summary

Instance Method Details

#cons(value, options = {}) ⇒ Object



13
14
15
# File 'lib/tensor_stream/helpers/op_helper.rb', line 13

# Builds a constant by delegating to TensorStream.constant.
#
# @param value [Object] the value handed to TensorStream.constant
# @param options [Hash] forwarded to TensorStream.constant as-is
# @return [Object] whatever TensorStream.constant returns
def cons(value, options = {})
  TensorStream.constant(
    value,
    options
  )
end

#dtype_eval(rank, value) ⇒ Object



36
37
38
39
40
41
42
# File 'lib/tensor_stream/helpers/op_helper.rb', line 36

# Inspects the first element of +value+ and reports its detected type.
#
# If Tensor.detect_type classifies that element as :array, the returned rank
# is one higher than the incoming +rank+ (presumably because the caller is
# descending one nesting level — confirm against callers).
#
# @param rank [Integer] the rank accumulated so far
# @param value [Array] the collection whose first element is inspected
# @return [Array] [detected_type, updated_rank, first_element, value.size]
def dtype_eval(rank, value)
  head = value[0]
  detected = Tensor.detect_type(head)
  next_rank = detected == :array ? rank + 1 : rank

  [detected, next_rank, head, value.size]
end

#fp_type?(type) ⇒ Boolean

Returns:

  • (Boolean)


58
59
60
# File 'lib/tensor_stream/helpers/op_helper.rb', line 58

# Reports whether +type+ is listed in TensorStream::Ops::FLOATING_POINT_TYPES.
#
# @param type [Object] a dtype identifier
# @return [Boolean]
def fp_type?(type)
  float_types = TensorStream::Ops::FLOATING_POINT_TYPES
  float_types.include?(type)
end

#i_cons(value, options = {}) ⇒ Object



17
18
19
# File 'lib/tensor_stream/helpers/op_helper.rb', line 17

# Same as cons, but adds internal: true to the options passed along.
#
# Hash#merge returns a new hash, so the caller's +options+ is left untouched.
#
# @param value [Object] the value handed to TensorStream.constant
# @param options [Hash] extra options; internal: true overrides any given value
# @return [Object] whatever TensorStream.constant returns
def i_cons(value, options = {})
  internal_options = options.merge(internal: true)
  TensorStream.constant(value, internal_options)
end

#i_op(code, t_a, t_b = nil, options = {}) ⇒ Object

Same as #op, but marks the operation as internally generated.



9
10
11
# File 'lib/tensor_stream/helpers/op_helper.rb', line 9

# Same as op, but builds the Operation with internal: true merged into its options.
#
# @param code [#to_sym] the operation code; converted with to_sym
# @param t_a [Object] first operand
# @param t_b [Object, nil] second operand (nil when omitted — presumably for unary ops)
# @param options [Hash] extra options; a merged copy is passed, the caller's hash is not modified
# @return [Operation]
def i_op(code, t_a, t_b = nil, options = {})
  op_code = code.to_sym
  flagged_options = options.merge(internal: true)
  Operation.new(op_code, t_a, t_b, flagged_options)
end

#op(code, t_a, t_b = nil, options = {}) ⇒ Object



4
5
6
# File 'lib/tensor_stream/helpers/op_helper.rb', line 4

# Builds an Operation from an op code and up to two operands.
#
# @param code [#to_sym] the operation code; converted with to_sym
# @param t_a [Object] first operand
# @param t_b [Object, nil] second operand (nil when omitted — presumably for unary ops)
# @param options [Hash] passed to Operation.new unchanged
# @return [Operation]
def op(code, t_a, t_b = nil, options = {})
  op_code = code.to_sym
  Operation.new(op_code, t_a, t_b, options)
end

#shape_eval(input, output_type = :int32) ⇒ Object



21
22
23
24
25
26
27
28
29
30
31
32
33
34
# File 'lib/tensor_stream/helpers/op_helper.rb', line 21

# Computes the shape of a (possibly nested) Ruby array.
#
# At each nesting level the size is recorded, then the walk descends into the
# first element only. Jagged arrays are therefore not validated; the shape
# reflects the first-element chain.
#
# @param input [Object] an Array (possibly nested); any non-Array yields []
# @param output_type [Symbol] if listed in TensorStream::Ops::FLOATING_POINT_TYPES,
#   the dimensions are returned as Floats; otherwise as Integers
# @return [Array<Integer>, Array<Float>] one entry per nesting level
def shape_eval(input, output_type = :int32)
  return [] unless input.is_a?(Array)

  # The output type does not change per dimension, so decide once up front.
  as_float = TensorStream::Ops::FLOATING_POINT_TYPES.include?(output_type)

  shape = []
  level = input
  while level.is_a?(Array)
    shape << (as_float ? level.size.to_f : level.size)
    level = level[0]
  end

  shape
end

#val_to_dtype(value) ⇒ Object



44
45
46
47
48
49
50
51
52
53
54
55
56
# File 'lib/tensor_stream/helpers/op_helper.rb', line 44

# Maps a Ruby value to a dtype symbol based on its class.
#
# String => :string, Float => :float32, Integer => :int32, Array => :array.
# Anything else (including nil, true and false) falls back to :float32.
#
# @param value [Object]
# @return [Symbol]
def val_to_dtype(value)
  return :string if value.is_a?(String)
  return :float32 if value.is_a?(Float)
  return :int32 if value.is_a?(Integer)
  return :array if value.is_a?(Array)

  :float32
end