Class: TensorStream::TensorShape

Inherits:
Object
  • Object
show all
Defined in:
lib/tensor_stream/tensor_shape.rb

Overview

Class that defines a tensor shape for TensorFlow compatibility.

Instance Attribute Summary

Class Method Summary

Instance Method Summary

Constructor Details

#initialize(shape, rank = nil) ⇒ TensorShape

Returns a new instance of TensorShape.



6
7
8
9
# File 'lib/tensor_stream/tensor_shape.rb', line 6

# Builds a shape from a list of dimension sizes.
#
# @param shape [Array, nil] dimension sizes (individual entries may be nil
#   when a dimension is unknown); nil when nothing is known about the shape
# @param rank [Integer, nil] explicit rank; when omitted it is derived from
#   shape.size, and it stays nil if shape is nil as well
def initialize(shape, rank = nil)
  @shape = shape
  @rank =
    if rank.nil? && shape
      shape.size
    else
      rank
    end
end

Instance Attribute Details

#rank ⇒ Object

Returns the value of attribute rank.



4
5
6
# File 'lib/tensor_stream/tensor_shape.rb', line 4

# The rank recorded for this shape: either the explicit rank given to the
# constructor or, when none was given, the number of dimensions in the shape.
#
# @return [Integer, nil]
def rank; @rank; end

#shape ⇒ Object

Returns the value of attribute shape.



4
5
6
# File 'lib/tensor_stream/tensor_shape.rb', line 4

# The raw list of dimension sizes this object wraps (nil when unknown).
#
# @return [Array, nil]
def shape; @shape; end

Class Method Details

.fix_inferred_elements(shape, total_size) ⇒ Object



70
71
72
73
74
75
76
# File 'lib/tensor_stream/tensor_shape.rb', line 70

# Resolves a -1 placeholder dimension from the total element count.
#
# The product of the strictly positive dimensions is divided (integer
# division) into +total_size+. Every -1 entry is replaced by that quotient;
# all other entries are kept as-is.
#
# @param shape [Array<Integer>] requested dimensions, possibly containing -1
# @param total_size [Integer] total number of elements available
# @return [Array<Integer>] +shape+ itself when empty, otherwise a new array
def self.fix_inferred_elements(shape, total_size)
  return shape if shape.empty?

  known_product = shape.select { |dim| dim > 0 }.inject(1, :*)
  inferred = total_size / known_product

  shape.map do |dim|
    next inferred if dim == -1

    dim
  end
end

.infer_shape(shape_a, shape_b) ⇒ Object



33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# File 'lib/tensor_stream/tensor_shape.rb', line 33

# Infers the shape produced by broadcasting two shapes together.
#
# Dimensions are aligned from the right (trailing axes first). Leading
# dimensions of the longer shape that have no counterpart in the shorter one
# are copied through unchanged. For each aligned pair the result is nil when
# either side is nil (unknown) or a Tensor (dynamic); otherwise it is the
# larger of the two values. Incompatible sizes are not reported here.
#
# @param shape_a [Array, nil] first shape; nil means "no information"
# @param shape_b [Array, nil] second shape; nil means "no information"
# @return [Array, nil] the inferred broadcast shape
def self.infer_shape(shape_a, shape_b)
  return shape_a if shape_b.nil?
  return shape_b if shape_a.nil?
  return shape_a if shape_a == shape_b

  # Walk the longer shape so every output axis is visited. Returning the longer
  # shape outright would ignore the shorter shape's trailing axes, e.g.
  # [3] vs [2, 1] must give [2, 3], not [2, 1]. On a tie shape_a leads.
  longer, shorter = shape_b.size > shape_a.size ? [shape_b, shape_a] : [shape_a, shape_b]
  reversed_longer = longer.reverse
  reversed_shorter = shorter.reverse

  reversed_longer.each_with_index.map do |s, index|
    next s if index >= reversed_shorter.size

    other = reversed_shorter[index]
    next nil if s.nil? || other.nil?
    next nil if s.is_a?(Tensor) || other.is_a?(Tensor)
    next other if other > s

    s
  end.reverse
end

.reshape(arr, new_shape) ⇒ Object



52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# File 'lib/tensor_stream/tensor_shape.rb', line 52

# Reshapes a flat array into nested arrays following +new_shape+.
#
# The leading dimension splits +arr+ into that many equal slices, and each
# slice is reshaped recursively with the remaining dimensions. An empty
# +new_shape+ returns +arr+ unchanged. When +arr+ has fewer elements than the
# leading dimension (e.g. an empty array), it is returned unchanged.
#
# @param arr [Array] the elements to reshape (normally already flattened)
# @param new_shape [Array<Integer>] target dimensions; not mutated
# @return [Array] the nested array
# @raise [RuntimeError] "reshape dimen mismatch ..." when the element count
#   does not fit +new_shape+
def self.reshape(arr, new_shape)
  return arr if new_shape.empty?

  new_shape = new_shape.dup
  s = new_shape.shift

  if new_shape.size.zero?
    raise "reshape dimen mismatch #{arr.size} != #{s}" if arr.size != s
    return arr
  end

  # A zero-length dimension can only hold an empty array; checking it here
  # also avoids a ZeroDivisionError in the split below.
  if s == 0
    raise "reshape dimen mismatch #{arr.size} != 0" unless arr.empty?
    return arr
  end

  dim, remainder = arr.size.divmod(s)
  return arr if dim.zero?
  # Leftover elements would otherwise form extra slices, e.g. 8 elements
  # reshaped to [3, 2] silently produced a 4x2 result.
  raise "reshape dimen mismatch #{arr.size} not divisible by #{s}" unless remainder.zero?

  # The recursive call dups new_shape itself, so no copy is needed here.
  arr.each_slice(dim).map { |slice| reshape(slice, new_shape) }
end

Instance Method Details

#[](index) ⇒ Object



18
19
20
# File 'lib/tensor_stream/tensor_shape.rb', line 18

# Looks up an entry of the wrapped dimension list, exactly as Array#[] would
# (negative indices count from the end; out-of-range yields nil).
#
# @return [Integer, nil] the dimension size, or nil if unknown or absent
def [](dim_index)
  @shape[dim_index]
end

#known? ⇒ Boolean

Returns:

  • (Boolean)


26
27
28
29
30
31
# File 'lib/tensor_stream/tensor_shape.rb', line 26

# Whether every dimension of this shape is statically known.
#
# @return [Boolean] false when the shape itself is nil or any of its
#   dimensions is nil; true otherwise (including the empty, scalar shape)
def known?
  dims = shape
  return false if dims.nil?

  dims.none?(&:nil?)
end

#ndims ⇒ Object



22
23
24
# File 'lib/tensor_stream/tensor_shape.rb', line 22

# Number of dimensions in this shape.
#
# When the dimension list is unknown (shape is nil), the explicitly stored
# rank is returned instead of raising NoMethodError. That value is nil when
# no rank was given either.
#
# @return [Integer, nil] the dimension count, or nil when it is unknown
def ndims
  return rank if shape.nil?

  shape.size
end

#to_s ⇒ Object



11
12
13
14
15
16
# File 'lib/tensor_stream/tensor_shape.rb', line 11

# Human-readable representation in a TensorFlow-like format, e.g.
# "TensorShape([Dimension(2),Dimension(3)])". An unknown (nil) dimension
# renders as "Dimension()".
#
# A shape whose dimension list is unknown (@shape is nil) renders as
# "TensorShape(None)" rather than raising NoMethodError.
#
# @return [String]
def to_s
  return 'TensorShape(None)' if @shape.nil?

  dimensions = @shape.map { |r| "Dimension(#{r})" }.join(',')
  "TensorShape([#{dimensions}])"
end