Class: TensorStream::TensorShape
- Inherits:
-
Object
- Object
- TensorStream::TensorShape
- Defined in:
- lib/tensor_stream/tensor_shape.rb
Overview
Class that describes the shape of a tensor, for compatibility with TensorFlow.
Instance Attribute Summary collapse
-
#rank ⇒ Object
Returns the value of attribute rank.
-
#shape ⇒ Object
Returns the value of attribute shape.
Class Method Summary collapse
- .fix_inferred_elements(shape, total_size) ⇒ Object
- .infer_shape(shape_a, shape_b) ⇒ Object
- .reshape(arr, new_shape) ⇒ Object
Instance Method Summary collapse
- #[](index) ⇒ Object
- #fully_defined? ⇒ Boolean
-
#initialize(shape, rank = nil) ⇒ TensorShape
constructor
A new instance of TensorShape.
- #known? ⇒ Boolean
- #ndims ⇒ Object
- #scalar? ⇒ Boolean
- #to_s ⇒ Object
Constructor Details
#initialize(shape, rank = nil) ⇒ TensorShape
6 7 8 9 |
# File 'lib/tensor_stream/tensor_shape.rb', line 6
# Builds a shape descriptor.
#
# @param shape [Array, nil] dimensions of the shape (usually an Array);
#   nil means the shape is unknown
# @param rank [Integer, nil] explicit rank; when nil and +shape+ is given,
#   the rank is taken from +shape.size+
def initialize(shape, rank = nil)
  @shape = shape
  @rank =
    if rank.nil? && shape
      shape.size
    else
      rank
    end
end
Instance Attribute Details
#rank ⇒ Object
Returns the value of attribute rank.
4 5 6 |
# File 'lib/tensor_stream/tensor_shape.rb', line 4
# Reader for the rank stored by the constructor.
#
# @return [Integer, nil] the rank, or nil when neither a rank nor a shape
#   was supplied
def rank
  @rank
end
#shape ⇒ Object
Returns the value of attribute shape.
4 5 6 |
# File 'lib/tensor_stream/tensor_shape.rb', line 4
# Reader for the raw dimensions passed to the constructor.
#
# @return [Array, nil] the dimensions, or nil for an unknown shape
def shape
  @shape
end
Class Method Details
.fix_inferred_elements(shape, total_size) ⇒ Object
90 91 92 93 94 95 96 97 |
# File 'lib/tensor_stream/tensor_shape.rb', line 90
# Resolves -1 entries in +shape+ against a known total element count.
#
# The product of the strictly positive dimensions is divided into
# +total_size+, and that quotient replaces every -1 entry.
#
# @param shape [Array] dimensions, possibly containing -1
# @param total_size [Integer, nil] number of elements the shape must hold;
#   nil leaves -1 entries unresolved (they become nil)
# @return [Array, nil] a new shape with -1 entries replaced, the input itself
#   when it is empty, or nil when any dimension is a Tensor
def self.fix_inferred_elements(shape, total_size)
  return shape if shape.empty?
  # Symbolic (Tensor) dimensions cannot take part in the arithmetic below,
  # wherever they appear -- not only in the first position.
  return nil if shape.any? { |s| s.is_a?(Tensor) }

  # An unknown (nil) dimension makes the inferred size unknowable; it would
  # also raise NoMethodError on the `n > 0` comparison.
  inferred_size =
    if total_size.nil? || shape.any?(&:nil?)
      nil
    else
      known_size = shape.select { |n| n > 0 }.inject(1, :*)
      # NOTE(review): integer division truncates silently when total_size is
      # not a multiple of known_size.
      total_size / known_size
    end

  shape.map { |s| s == -1 ? inferred_size : s }
end
.infer_shape(shape_a, shape_b) ⇒ Object
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
# File 'lib/tensor_stream/tensor_shape.rb', line 45
# Infers the result shape of combining two shapes element-wise.
#
# Shapes are aligned on their trailing axes (NumPy-style broadcasting), so the
# extra leading axes of the higher-rank shape are carried through unchanged.
# For each aligned pair the larger dimension is chosen; if either side is
# unknown (nil) or a Tensor, the resulting dimension is nil.
# NOTE(review): incompatible pairs such as 2 vs 3 are not rejected here.
#
# @param shape_a [Array, nil]
# @param shape_b [Array, nil]
# @return [Array, nil] nil when either input is nil; the other shape when one
#   of them is empty (scalar); otherwise the merged shape
def self.infer_shape(shape_a, shape_b)
  return nil if shape_a.nil? || shape_b.nil?
  return shape_a if shape_b.empty?
  return shape_b if shape_a.empty?
  return shape_a if shape_a == shape_b

  # Differing ranks must be merged axis by axis: returning the higher-rank
  # shape verbatim would give [3, 1] for [3, 1] vs [4] instead of [3, 4].
  longer, shorter = shape_a.size >= shape_b.size ? [shape_a, shape_b] : [shape_b, shape_a]
  reversed_longer = longer.reverse
  reversed_shorter = shorter.reverse

  reversed_longer.each_with_index.map { |dim, index|
    # Leading axes present only in the longer shape pass through unchanged.
    next dim if index >= reversed_shorter.size

    other = reversed_shorter[index]
    next nil if dim.nil? || other.nil?
    next nil if dim.is_a?(Tensor) || other.is_a?(Tensor)

    other > dim ? other : dim
  }.reverse
end
.reshape(arr, new_shape) ⇒ Object
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
# File 'lib/tensor_stream/tensor_shape.rb', line 66
# Rearranges the scalars of +arr+ into nested Arrays matching +new_shape+.
#
# +arr+ is flattened first (a non-Array value is treated as a one-element
# list), and any -1 entry in +new_shape+ is resolved through
# TensorShape.fix_inferred_elements using the flattened element count.
#
# @param arr [Object] nested Array of values, or a single value
# @param new_shape [TensorShape, Array] target dimensions
# @return [Object] nested Arrays, or the lone value when reshaping one
#   element to the empty (scalar) shape
# @raise [RuntimeError] when the last dimension does not match the number of
#   remaining elements
def self.reshape(arr, new_shape)
  flat = arr.is_a?(Array) ? arr.flatten : [arr]
  dims = new_shape.is_a?(TensorShape) ? new_shape.shape : new_shape
  dims = TensorShape.fix_inferred_elements(dims, flat.size)

  # A single value reshaped to [] comes back as the bare value.
  return flat[0] if flat.size == 1 && dims.empty?

  inner = dims.dup
  outer = inner.shift

  # Last axis: the remaining elements must fill it exactly.
  if inner.empty?
    raise "reshape dimen mismatch #{flat.size} != #{outer}" unless flat.size == outer

    return flat
  end

  # NOTE(review): an outer dimension of 0 raises ZeroDivisionError here.
  chunk = flat.size / outer
  # NOTE(review): fewer elements than +outer+ returns the flat list unchanged.
  return flat if chunk.zero?

  flat.each_slice(chunk).map { |slice| reshape(slice, inner.dup) }
end
Instance Method Details
#[](index) ⇒ Object
20 21 22 |
# File 'lib/tensor_stream/tensor_shape.rb', line 20
# Indexes into the stored dimensions.
#
# @param index [Integer, Range] position(s) to read, as accepted by the
#   underlying dimension list's +[]+
# @return [Object] the selected dimension(s)
def [](index)
  @shape[index]
end
#fully_defined? ⇒ Boolean
41 42 43 |
# File 'lib/tensor_stream/tensor_shape.rb', line 41
# TensorFlow-style alias for #known?.
#
# @return [Boolean] whatever #known? returns
def fully_defined?
  known?
end
#known? ⇒ Boolean
32 33 34 35 36 37 38 39 |
# File 'lib/tensor_stream/tensor_shape.rb', line 32
# Whether every dimension of this shape is known.
#
# A nil shape is unknown. A non-Array shape is treated as a single dimension.
# Any nil or negative dimension makes the shape unknown.
#
# @return [Boolean]
def known?
  return false if shape.nil?

  dims = shape.is_a?(Array) ? shape : [shape]
  dims.none? { |dim| dim.nil? || dim < 0 }
end
#ndims ⇒ Object
24 25 26 |
# File 'lib/tensor_stream/tensor_shape.rb', line 24
# Number of dimensions, computed from the current shape.
#
# @return [Integer, nil] +shape.size+, or nil when the shape is unknown
def ndims
  return nil unless shape

  shape.size
end
#scalar? ⇒ Boolean
28 29 30 |
# File 'lib/tensor_stream/tensor_shape.rb', line 28
# Whether this is a fully known shape with zero dimensions.
#
# @return [Boolean]
def scalar?
  return false unless known?

  shape.size.zero?
end
#to_s ⇒ Object
11 12 13 14 15 16 17 18 |
# File 'lib/tensor_stream/tensor_shape.rb', line 11
# Renders the shape in a TensorFlow-like form, e.g.
# "TensorShape([Dimension(2),Dimension(3)])".
#
# An unknown (nil) shape renders as "?". An unknown (nil) dimension renders
# as "Dimension()". A non-Array shape is treated as a single dimension, the
# same way #known? treats it; previously that case raised NoMethodError.
#
# @return [String]
def to_s
  return "?" if @shape.nil?

  dims = @shape.is_a?(Array) ? @shape : [@shape]
  dimensions = dims.map { |r| "Dimension(#{r})" }.join(",")
  "TensorShape([#{dimensions}])"
end