Module: TensorStream::OpHelper

Included in:
TensorStream, Evaluator::RubyEvaluator, MathGradients, Tensor
Defined in:
lib/tensor_stream/helpers/op_helper.rb

Instance Method Summary

Instance Method Details

#cons(value, options = {}) ⇒ Object



12
13
14
# File 'lib/tensor_stream/helpers/op_helper.rb', line 12

# Builds a constant tensor by forwarding to TensorStream.constant.
#
# @param value [Object] the value to wrap as a constant
# @param options [Hash] options passed through unchanged
# @return [Object] whatever TensorStream.constant returns
def cons(value, options = {})
  constant_options = options
  TensorStream.constant(value, constant_options)
end

#dtype_eval(dtype, rank, value) ⇒ Object



35
36
37
38
39
40
# File 'lib/tensor_stream/helpers/op_helper.rb', line 35

# Detects the data type of the first element of +value+ and bumps +rank+
# when that element is itself an array.
#
# The first positional argument is accepted for backward compatibility but
# ignored: the type is always re-detected from +value+. (It used to be named
# +dtype+ and was immediately overwritten.)
#
# @param _dtype [Object] ignored
# @param rank [Integer] current rank; incremented by one when the first
#   element is detected as :array (presumably nesting depth — confirm with callers)
# @param value [Array] collection whose first element is inspected
# @return [Array] four-element array: [dtype, rank, first_element, value.size]
#
# NOTE(review): for an empty +value+ the first element is nil, so
# Tensor.detect_type receives nil — confirm that is handled upstream.
def dtype_eval(_dtype, rank, value)
  first = value[0]
  dtype = Tensor.detect_type(first)
  rank += 1 if dtype == :array

  [dtype, rank, first, value.size]
end

#i_cons(value, options = {}) ⇒ Object



16
17
18
# File 'lib/tensor_stream/helpers/op_helper.rb', line 16

# Builds a constant tensor flagged as internally generated.
#
# The caller's +options+ hash is not mutated; a merged copy with
# +internal: true+ is passed on instead.
#
# @param value [Object] the value to wrap as a constant
# @param options [Hash] extra options for TensorStream.constant
# @return [Object] whatever TensorStream.constant returns
def i_cons(value, options = {})
  internal_options = options.merge(internal: true)
  TensorStream.constant(value, internal_options)
end

#i_op(code, a, b = nil, options = {}) ⇒ Object

Same as `op`, but marks the resulting operation as internally generated.



8
9
10
# File 'lib/tensor_stream/helpers/op_helper.rb', line 8

# Same as #op, but tags the operation with +internal: true+.
#
# @param code [String, Symbol] operation name; converted with #to_sym
# @param a [Object] first operand
# @param b [Object, nil] optional second operand
# @param options [Hash] extra options; not mutated (a merged copy is used)
# @return [Operation] the newly constructed operation
def i_op(code, a, b = nil, options = {})
  op_name = code.to_sym
  internal_options = options.merge(internal: true)
  Operation.new(op_name, a, b, internal_options)
end

#op(code, a, b = nil, options = {}) ⇒ Object



3
4
5
# File 'lib/tensor_stream/helpers/op_helper.rb', line 3

# Constructs an Operation node.
#
# @param code [String, Symbol] operation name; converted with #to_sym
# @param a [Object] first operand
# @param b [Object, nil] optional second operand
# @param options [Hash] options passed through unchanged
# @return [Operation] the newly constructed operation
def op(code, a, b = nil, options = {})
  op_name = code.to_sym
  Operation.new(op_name, a, b, options)
end

#shape_eval(input) ⇒ Object



20
21
22
23
24
25
26
27
28
29
30
31
32
33
# File 'lib/tensor_stream/helpers/op_helper.rb', line 20

# Computes the shape of a (possibly nested) array.
#
# Only the first element at each level is followed, so the result is
# meaningful only for rectangular input. Non-array input yields [].
#
# @param input [Object] a scalar or nested Array
# @return [Array<Integer>] sizes of each nesting level, outermost first
#   (e.g. [[1, 2, 3], [4, 5, 6]] => [2, 3]; [] => [0])
def shape_eval(input)
  return [] unless input.is_a?(Array)

  dims = []
  node = input
  while node.is_a?(Array)
    dims << node.size
    node = node.first
  end
  dims
end

#val_to_dtype(value, rank = 0) ⇒ Object



42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# File 'lib/tensor_stream/helpers/op_helper.rb', line 42

# Maps a single Ruby value to a dtype symbol.
#
# The second argument is accepted for backward compatibility only. The
# previous implementation incremented it for arrays, but the result was
# never returned, so it never had any effect.
#
# @param value [Object] the value to classify
# @param _rank [Integer] ignored
# @return [Symbol] :string for String, :int32 for Integer, :array for Array,
#   and :float32 for Float and for anything else (nil, booleans, symbols, ...)
def val_to_dtype(value, _rank = 0)
  case value
  when String then :string
  when Float then :float32
  when Integer then :int32
  when Array then :array
  else :float32
  end
end