Class: TensorStream::TensorShape
- Inherits: Object
  - Object
    - TensorStream::TensorShape
- Defined in: lib/tensor_stream/tensor_shape.rb
Overview
Class that defines a tensor shape, provided for TensorFlow compatibility.
Instance Attribute Summary
- #rank ⇒ Object
  Returns the value of attribute rank.
- #shape ⇒ Object
  Returns the value of attribute shape.
Class Method Summary
- .fix_inferred_elements(shape, total_size) ⇒ Object
- .infer_shape(shape_a, shape_b) ⇒ Object
- .reshape(arr, new_shape) ⇒ Object
Instance Method Summary
- #[](index) ⇒ Object
- #initialize(shape, rank = nil) ⇒ TensorShape
  Constructor: returns a new instance of TensorShape.
- #known? ⇒ Boolean
- #ndims ⇒ Object
- #to_s ⇒ Object
Constructor Details
#initialize(shape, rank = nil) ⇒ TensorShape
Returns a new instance of TensorShape.
6 7 8 9 |
# File 'lib/tensor_stream/tensor_shape.rb', line 6
# Builds a shape descriptor.
#
# @param shape [Array, nil] list of dimensions; may be nil when the shape is unknown
# @param rank [Integer, nil] explicit rank; when omitted (nil) and a shape is
#   given, the rank is taken from +shape.size+, otherwise it stays as passed
def initialize(shape, rank = nil)
  @shape = shape
  @rank = rank
  # Derive the rank only when the caller gave none and there is a shape to measure.
  @rank = shape.size if rank.nil? && shape
end
Instance Attribute Details
#rank ⇒ Object
Returns the value of attribute rank.
4 5 6 |
# File 'lib/tensor_stream/tensor_shape.rb', line 4
# Reader for the rank stored at construction time (presumably generated by
# +attr_reader+ in the source file).
#
# @return [Integer, nil] the stored rank, or nil if none was recorded
def rank
  @rank
end
#shape ⇒ Object
Returns the value of attribute shape.
4 5 6 |
# File 'lib/tensor_stream/tensor_shape.rb', line 4
# Reader for the dimension list given at construction time (presumably
# generated by +attr_reader+ in the source file).
#
# @return [Array, nil] the stored dimensions, or nil for an unknown shape
def shape
  @shape
end
Class Method Details
.fix_inferred_elements(shape, total_size) ⇒ Object
70 71 72 73 74 75 76 |
# File 'lib/tensor_stream/tensor_shape.rb', line 70
# Replaces every -1 placeholder in +shape+ with the size implied by
# +total_size+.
#
# The implied size is +total_size+ integer-divided by the product of the
# strictly positive dimensions; zero and negative entries are left out of that
# product. Every -1 receives the same inferred value.
#
# @param shape [Array<Integer>] dimensions, possibly containing -1
# @param total_size [Integer] total number of elements the shape must hold
# @return [Array<Integer>] a new array with -1 entries filled in, or +shape+
#   itself when it is empty
def self.fix_inferred_elements(shape, total_size)
  return shape if shape.empty?

  known_size = shape.select { |dim| dim > 0 }.reduce(1, :*)
  inferred = total_size / known_size
  shape.map { |dim| dim == -1 ? inferred : dim }
end
.infer_shape(shape_a, shape_b) ⇒ Object
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
# File 'lib/tensor_stream/tensor_shape.rb', line 33
# Infers the result shape of an element-wise (broadcasting) operation between
# two operands.
#
# Dimensions are aligned from the trailing end.
# - Where both operands have a dimension, the larger one is kept.
# - A nil dimension, or one given as a Tensor, on either side makes that
#   result dimension nil (unknown).
# - Leading dimensions that exist only in the higher-rank operand are copied
#   through unchanged.
# No compatibility check is performed (e.g. 2 vs 3 simply yields 3).
#
# @param shape_a [Array, nil] first operand's dimensions (nil = unknown)
# @param shape_b [Array, nil] second operand's dimensions (nil = unknown)
# @return [Array, nil] the inferred dimensions; nil only when both inputs are nil
def self.infer_shape(shape_a, shape_b)
  return shape_a if shape_b.nil?
  return shape_b if shape_a.nil?
  return shape_a if shape_a == shape_b

  # Iterate over the higher-rank shape so that dimensions from both operands are
  # merged instead of returning the longer shape wholesale. On equal rank,
  # shape_a is the base and ties keep its value.
  longer, shorter = shape_a.size >= shape_b.size ? [shape_a, shape_b] : [shape_b, shape_a]
  reversed_shorter = shorter.reverse

  longer.reverse.each_with_index.map do |dim, index|
    # Dimensions beyond the shorter operand's rank pass through untouched.
    next dim if index >= reversed_shorter.size

    other = reversed_shorter[index]
    next nil if dim.nil? || other.nil?
    next nil if dim.is_a?(Tensor) || other.is_a?(Tensor)

    other > dim ? other : dim
  end.reverse
end
.reshape(arr, new_shape) ⇒ Object
52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
# File 'lib/tensor_stream/tensor_shape.rb', line 52
# Recursively regroups an array into nested arrays following +new_shape+.
#
# Each level slices the current array into chunks of
# +arr.size / leading_dim+ elements and reshapes every chunk with the remaining
# dimensions. At the innermost dimension the element count must match exactly.
#
# NOTE(review): assumes positive integer dimensions (-1 presumably resolved
# beforehand — TODO confirm). A leading 0 with further dimensions raises
# ZeroDivisionError.
#
# @param arr [Array] elements to regroup (presumably flat at the top level)
# @param new_shape [Array<Integer>] target dimensions; not mutated
# @return [Array] the nested array. +arr+ itself is returned when +new_shape+ is
#   empty, or when a leading dimension exceeds the number of elements.
# @raise [RuntimeError] when the innermost dimension differs from the number of
#   remaining elements
def self.reshape(arr, new_shape)
  return arr if new_shape.empty?

  head, *rest = new_shape
  if rest.empty?
    raise "reshape dimen mismatch #{arr.size} != #{head}" unless arr.size == head
    return arr
  end

  slice_size = arr.size / head
  return arr if slice_size.zero?

  # +rest+ is never mutated by the recursion, so every chunk can share it.
  arr.each_slice(slice_size).map { |chunk| reshape(chunk, rest) }
end
Instance Method Details
#[](index) ⇒ Object
18 19 20 |
# File 'lib/tensor_stream/tensor_shape.rb', line 18
# Returns the dimension stored at +index+ by delegating to the underlying
# dimension list's +[]+. Negative indices count from the end, and out-of-range
# indices give nil, as with Array.
#
# @param index [Integer] position of the dimension
# @return [Object, nil] the dimension value at +index+
def [](index)
  dimensions = @shape
  dimensions[index]
end
#known? ⇒ Boolean
26 27 28 29 30 31 |
# File 'lib/tensor_stream/tensor_shape.rb', line 26
# Whether the shape is fully specified: a dimension list exists and none of
# its entries is nil.
#
# @return [Boolean]
def known?
  !shape.nil? && shape.none?(&:nil?)
end
#ndims ⇒ Object
22 23 24 |
# File 'lib/tensor_stream/tensor_shape.rb', line 22
# Number of dimensions of the shape.
#
# When the dimension list is nil (unknown shape), there is nothing to count.
# In that case the rank recorded at construction is returned instead, which may
# itself be nil. Previously this case raised NoMethodError.
#
# @return [Integer, nil] +shape.size+, or +rank+ when the shape is nil
def ndims
  return rank if shape.nil?

  shape.size
end
#to_s ⇒ Object
11 12 13 14 15 16 |
# File 'lib/tensor_stream/tensor_shape.rb', line 11
# Renders the shape, e.g. "TensorShape([Dimension(2),Dimension(3)])".
# A nil dimension renders as "Dimension()".
#
# NOTE(review): raises NoMethodError when the shape itself is nil.
#
# @return [String]
def to_s
  rendered = @shape.map { |dim| "Dimension(#{dim})" }
  "TensorShape([#{rendered.join(',')}])"
end