Module: TensorStream::ArrayOpsHelper
- Included in:
- Evaluator::RubyEvaluator
- Defined in:
- lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb
Overview
Various utility functions for array processing.
Instance Method Summary collapse
- #broadcast(input_a, input_b) ⇒ Object
-
#broadcast_dimensions(input, dims = []) ⇒ Object
explicit broadcasting helper.
- #shape_diff(shape_a, shape_b) ⇒ Object
-
#vector_op(vector, vector2, op = ->(a, b) { a + b }, switch = false) ⇒ Object
handle 2 tensor math operations.
Instance Method Details
#broadcast(input_a, input_b) ⇒ Object
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
# File 'lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb', line 4

# Expands whichever of the two inputs has the smaller shape so that both can
# be fed to an element-wise operation.
#
# Inputs whose shapes are already equal are returned untouched. An input
# whose evaluated shape is empty is first wrapped in a one-element array
# and treated as shape [1].
#
# @param input_a [Object] first operand (nested Array or bare value)
# @param input_b [Object] second operand (nested Array or bare value)
# @return [Array] two-element array `[input_a, input_b]` after expansion
# @raise [RuntimeError] when shape_diff yields nil in both directions
def broadcast(input_a, input_b)
  sa = shape_eval(input_a)
  sb = shape_eval(input_b)

  return [input_a, input_b] if sa == sb

  # descalar: wrap shapeless operands so they carry an explicit shape of [1]
  if sa.empty?
    input_a = [input_a]
    sa = [1]
  end

  if sb.empty?
    input_b = [input_b]
    sb = [1]
  end

  # First try to grow input_b up to input_a's shape ...
  target_shape = shape_diff(sa, sb)

  if target_shape
    input_b = broadcast_dimensions(input_b, target_shape)
  else
    # ... otherwise try growing input_a up to input_b's shape.
    target_shape = shape_diff(sb, sa)
    if target_shape.nil?
      # Report BOTH operand shapes (previously input_a was printed twice).
      raise "Incompatible shapes for op #{shape_eval(input_a)} vs #{shape_eval(input_b)}"
    end

    input_a = broadcast_dimensions(input_a, target_shape)
  end

  [input_a, input_b]
end
#broadcast_dimensions(input, dims = []) ⇒ Object
explicit broadcasting helper
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
# File 'lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb', line 36

# Explicit broadcasting helper: repeats +input+ according to the per-axis
# counts in +dims+ (outermost axis first).
#
# The first entry of +dims+ is consumed at each level of recursion:
# * when +input+ is an Array whose rank minus one equals the number of
#   remaining entries, its elements are broadcast with the remaining entries
#   and the resulting row is followed by +d+ additional copies of itself;
# * when +input+ is any other Array, the result is +d+ copies of +input+
#   broadcast with the remaining entries;
# * when +input+ is not an Array, the result is +d + 1+ copies of it.
#
# The caller's +dims+ array is never modified.
#
# @param input [Object] nested Array or bare value to expand
# @param dims [Array<Integer>] repeat counts, outermost axis first
# @return [Object] +input+ itself when +dims+ is empty, otherwise a new Array
def broadcast_dimensions(input, dims = [])
  return input if dims.empty?

  # Non-destructive split; Array#shift here used to mutate the caller's array.
  d, *rest = dims

  if input.is_a?(Array) && (get_rank(input) - 1) == rest.size
    row_to_dup = input.map { |item| broadcast_dimensions(item, rest) }
    row_to_dup + Array.new(d) { row_to_dup }.flatten(1)
  elsif input.is_a?(Array)
    Array.new(d) { broadcast_dimensions(input, rest) }
  else
    Array.new(d + 1) { input }
  end
end
#shape_diff(shape_a, shape_b) ⇒ Object
87 88 89 90 91 92 93 94 95 96 97 98 |
# File 'lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb', line 87

# Computes, axis by axis, how much larger +shape_a+ is than +shape_b+ when
# the two shapes are aligned on their trailing dimensions.
#
# Leading axes of +shape_a+ that have no counterpart in +shape_b+ are
# carried over unchanged.
#
# @param shape_a [Array<Integer>] the larger (target) shape
# @param shape_b [Array<Integer>] the shape to be grown
# @return [Array<Integer>, nil] per-axis differences in +shape_a+'s axis
#   order, or nil when +shape_b+ has more axes than +shape_a+ or any aligned
#   axis of +shape_b+ exceeds the matching axis of +shape_a+
def shape_diff(shape_a, shape_b)
  return nil if shape_a.size < shape_b.size

  trailing_b = shape_b.reverse

  # Walk from the innermost axis outwards so both shapes line up at the end.
  diff = shape_a.reverse.each_with_index.map do |dim_a, pos|
    if pos < trailing_b.size
      dim_b = trailing_b[pos]
      return nil if dim_b > dim_a

      dim_a - dim_b
    else
      dim_a
    end
  end

  diff.reverse
end
#vector_op(vector, vector2, op = ->(a, b) { a + b }, switch = false) ⇒ Object
handle 2 tensor math operations
55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
# File 'lib/tensor_stream/evaluator/operation_helpers/array_ops_helper.rb', line 55

# Applies a binary operation element-wise across two (possibly nested)
# operands.
#
# When +vector+ has a lower rank than +vector2+ it is first repeated
# +vector2.size+ times to raise its rank. When +vector2+ is shorter than
# +vector+ along an axis it must have exactly one element there, which is
# then reused for every position.
#
# @param vector [Object] left operand (nested Array or bare value)
# @param vector2 [Object] right operand (nested Array or bare value)
# @param op [#call] two-argument callable; defaults to addition
# @param switch [Boolean] when true the operands are handed to +op+ in
#   reverse order, i.e. +op.call(vector2_element, vector_element)+
# @return [Object] nested Array of results, or a single result when both
#   operands are bare values
# @raise [RuntimeError] when +vector2+ is shorter than +vector+ along an
#   axis and does not have exactly one element there
def vector_op(vector, vector2, op = ->(a, b) { a + b }, switch = false)
  rank_a = get_rank(vector)
  rank_b = get_rank(vector2)

  if rank_a < rank_b # upgrade rank of A
    duplicated = Array.new(vector2.size) { vector }
    return vector_op(duplicated, vector2, op, switch)
  end

  unless vector.is_a?(Array)
    # Bare-value base case: honour +switch+ just like the per-element branch
    # below does (it was previously ignored here).
    return switch ? op.call(vector2, vector) : op.call(vector, vector2)
  end

  # Loop-invariant: does A still have more axes than B at this level?
  descend = rank_a > rank_b

  vector.each_with_index.map do |item, index|
    # B applies as a whole to every sub-array of A until the ranks match.
    next vector_op(item, vector2, op, switch) if descend && item.is_a?(Array)

    z = if vector2.is_a?(Array)
          if index < vector2.size
            vector2[index]
          else
            # B ran out of elements: only a single-element axis may be reused.
            raise 'incompatible tensor shapes used during op' if vector2.size != 1

            vector2[0]
          end
        else
          vector2
        end

    if item.is_a?(Array)
      vector_op(item, z, op, switch)
    else
      switch ? op.call(z, item) : op.call(item, z)
    end
  end
end