Module: TensorStream::OpHelper

Overview

A module that contains helper functions useful for ops.

Instance Method Summary collapse

Instance Method Details

#_op(code, *args) ⇒ Object



4
5
6
7
8
9
10
11
# File 'lib/tensor_stream/helpers/op_helper.rb', line 4

# Creates an Operation for +code+ and, when the default graph currently has a
# dependency scope, wraps it in an internal :identity op tied to that scope.
#
# @param code [#to_sym] operation code; converted with +to_sym+
# @param args [Array] arguments forwarded unchanged to Operation.new
# @return [Object] the new Operation, or the result of the wrapping +i_op+ call
def _op(code, *args)
  op = Operation.new(code.to_sym, *args)
  return op if TensorStream.get_default_graph.get_dependency_scope.nil?

  # The scope is fetched again (rather than cached from the guard above) so
  # the graph is queried exactly as often as before.
  dependency_scope = TensorStream.get_default_graph.get_dependency_scope
  identity_name = [op.name, 'tuple', 'control_dependency'].join('/')
  i_op(:identity, op, dependency_scope, name: identity_name)
end

#cons(value, options = {}) ⇒ Object



25
26
27
# File 'lib/tensor_stream/helpers/op_helper.rb', line 25

# Shorthand that delegates to TensorStream.constant.
#
# @param value [Object] value passed through unchanged to TensorStream.constant
# @param options [Hash] options forwarded as the second positional argument
# @return [Object] whatever TensorStream.constant returns
def cons(value, options = {})
  TensorStream.constant(value, options)
end

#format_source(trace) ⇒ Object



56
57
58
59
60
# File 'lib/tensor_stream/helpers/op_helper.rb', line 56

# Condenses a backtrace into at most two lines: the first frame located in the
# library's math_gradients code (if any) followed by the first frame located
# outside lib/tensor_stream (if any).
#
# @param trace [Enumerable<#to_s>] backtrace frames, in order
# @return [String] the selected frames joined by a newline; empty when neither
#   kind of frame is present
def format_source(trace)
  # Path fragments are loop-invariant, so build them once instead of per frame.
  library_path = File.join('lib', 'tensor_stream')
  gradients_path = File.join(library_path, 'math_gradients')

  # `find` stops at the first match instead of filtering the entire trace.
  grad_source = trace.find { |frame| frame.to_s.include?(gradients_path) }
  source = trace.find { |frame| !frame.to_s.include?(library_path) }

  [grad_source, source].compact.join("\n")
end

#fp_type?(type) ⇒ Boolean

Returns:



48
49
50
# File 'lib/tensor_stream/helpers/op_helper.rb', line 48

# Tells whether +type+ is listed in TensorStream::Ops::FLOATING_POINT_TYPES.
#
# @param type [Object] data type identifier to look up
# @return [Boolean] result of the membership check
def fp_type?(type)
  floating_point_types = TensorStream::Ops::FLOATING_POINT_TYPES
  floating_point_types.include?(type)
end

#i_cons(value, options = {}) ⇒ Object



29
30
31
# File 'lib/tensor_stream/helpers/op_helper.rb', line 29

# Delegates to TensorStream.constant with +internal: true+ merged into the
# options. The caller's hash is not modified (a merged copy is passed on).
#
# @param value [Object] value passed through unchanged to TensorStream.constant
# @param options [Hash] extra options; any +:internal+ entry is overridden
# @return [Object] whatever TensorStream.constant returns
def i_cons(value, options = {})
  internal_options = options.merge(internal: true)
  TensorStream.constant(value, internal_options)
end

#i_op(code, *args) ⇒ Object

Same as #_op, but flags the resulting operation as internally generated (sets `internal: true`).



14
15
16
17
18
19
20
21
22
23
# File 'lib/tensor_stream/helpers/op_helper.rb', line 14

# Creates an Operation for +code+ whose trailing options hash always carries
# +internal: true+.
#
# A Hash in the last position of +args+ is taken as the options; otherwise an
# empty options hash is used. The merged options are passed as the final
# argument to Operation.new.
#
# @param code [#to_sym] operation code; converted with +to_sym+
# @param args [Array] operation inputs, optionally ending with an options Hash
# @return [Operation] the newly created operation
def i_op(code, *args)
  # `args` is this call's own splat array, so popping does not affect callers.
  given_options = args.last.is_a?(Hash) ? args.pop : {}

  Operation.new(code.to_sym, *args, given_options.merge(internal: true))
end

#int_type?(type) ⇒ Boolean

Returns:



52
53
54
# File 'lib/tensor_stream/helpers/op_helper.rb', line 52

# Tells whether +type+ is listed in TensorStream::Ops::INTEGER_TYPES.
#
# @param type [Object] data type identifier to look up
# @return [Boolean] result of the membership check
def int_type?(type)
  integer_types = TensorStream::Ops::INTEGER_TYPES
  integer_types.include?(type)
end

#reduced_shape(input_shape, axes) ⇒ Object



77
78
79
80
81
82
83
84
85
86
# File 'lib/tensor_stream/helpers/op_helper.rb', line 77

# Builds the ops that stitch +input_shape+ together with a tensor of ones
# placed at the (wrapped) +axes+ positions.
#
# Steps: both arguments are converted to tensors, the rank is taken as the
# size of the shape tensor, axes are wrapped with (axes + rank) % rank, and
# TensorStream.dynamic_stitch combines [range(0, rank), axes] with
# [input_shape, fill(shape(axes), 1)].
# Presumably this yields the input shape with every reduced axis set to 1;
# that depends on dynamic_stitch semantics not visible here - confirm.
#
# NOTE(review): +axes+ is tested for nil only after convert_to_tensor has been
# applied; whether that conversion can return nil is not visible here, so the
# "all axes" fallback may be unreachable - confirm.
#
# @param input_shape [Object] shape of the tensor being reduced
# @param axes [Object] axes being reduced
# @return [Object] result of TensorStream.dynamic_stitch
def reduced_shape(input_shape, axes)
  shape_tensor = TensorStream.convert_to_tensor(input_shape)
  axes_tensor = TensorStream.convert_to_tensor(axes)
  rank = i_op(:size, shape_tensor)
  axes_tensor = TensorStream.range(0, rank) if axes_tensor.nil?

  wrapped_axes = (axes_tensor + rank) % rank
  axes_shape = i_op(:shape, wrapped_axes)

  # Ops are created in the same order as before: range first, then fill.
  stitch_indices = [TensorStream.range(0, rank), wrapped_axes]
  stitch_data = [shape_tensor, i_op(:fill, axes_shape, 1)]
  TensorStream.dynamic_stitch(stitch_indices, stitch_data)
end

#shape_eval(input, output_type = :int32) ⇒ Object



33
34
35
36
37
38
39
40
41
42
43
44
45
46
# File 'lib/tensor_stream/helpers/op_helper.rb', line 33

# Derives the shape of a nested array by following the first element at each
# nesting level.
#
# Only element 0 of every level is inspected, so ragged arrays are not
# detected. Sizes are emitted as Floats when +output_type+ is listed in
# TensorStream::Ops::FLOATING_POINT_TYPES, otherwise as Integers.
#
# @param input [Object] value to measure; anything but an Array yields []
# @param output_type [Symbol] type identifier deciding Integer vs Float sizes
# @return [Array<Numeric>] one size per nesting level, outermost first
def shape_eval(input, output_type = :int32)
  return [] unless input.is_a?(Array)

  float_sizes = TensorStream::Ops::FLOATING_POINT_TYPES.include?(output_type)
  dims = []
  level = input

  # `input` is known to be an Array here, so the body runs at least once.
  while level.is_a?(Array)
    dims << (float_sizes ? level.size.to_f : level.size)
    level = level[0]
  end

  dims
end

#shape_full_specified(tensor) ⇒ Object



69
70
71
72
73
74
75
# File 'lib/tensor_stream/helpers/op_helper.rb', line 69

# Tells whether every dimension of +tensor+'s shape is known.
#
# @param tensor [#shape] object whose +shape+ responds to +shape+ with the
#   list of dimensions
# @return [Boolean] false when +tensor.shape+ or +tensor.shape.shape+ is nil,
#   or when any dimension is nil; true otherwise (including an empty list)
def shape_full_specified(tensor)
  return false if tensor.shape.nil? || tensor.shape.shape.nil?

  tensor.shape.shape.none?(&:nil?)
end

#shapes_fully_specified_and_equal(x, y) ⇒ Object



62
63
64
65
66
67
# File 'lib/tensor_stream/helpers/op_helper.rb', line 62

# Tells whether +x+ and +y+ both have fully specified shapes and those shapes
# have identical dimension lists.
#
# @param x [#shape] first tensor-like object
# @param y [#shape] second tensor-like object
# @return [Boolean] true only when both shapes pass +shape_full_specified+
#   and +x.shape.shape+ equals +y.shape.shape+
def shapes_fully_specified_and_equal(x, y)
  return false unless shape_full_specified(x) && shape_full_specified(y)

  x.shape.shape == y.shape.shape
end