Class: TensorStream::TensorShape

Inherits:
Object
Defined in:
lib/tensor_stream/tensor_shape.rb

Overview

Class that defines a tensor shape, for compatibility with TensorFlow.

Instance Attribute Summary

Class Method Summary

Instance Method Summary

Constructor Details

#initialize(shape, rank = nil) ⇒ TensorShape

Returns a new instance of TensorShape.



6
7
8
9
# File 'lib/tensor_stream/tensor_shape.rb', line 6

# Builds a shape from a dimension list.
#
# @param shape [Object] dimension list (nil when the shape is unknown)
# @param rank [Integer, nil] explicit rank; when omitted and +shape+ is given,
#   the rank is taken from +shape.size+
def initialize(shape, rank = nil)
  @shape = shape
  @rank = if rank.nil? && shape
            shape.size
          else
            rank
          end
end

Instance Attribute Details

#rank ⇒ Object

Returns the value of attribute rank.



4
5
6
# File 'lib/tensor_stream/tensor_shape.rb', line 4

# Rank recorded at construction time. It is nil when neither a rank nor a
# shape was supplied.
def rank
  defined?(@rank) ? @rank : nil
end

#shape ⇒ Object

Returns the value of attribute shape.



4
5
6
# File 'lib/tensor_stream/tensor_shape.rb', line 4

# Dimensions passed to the constructor (typically an Array; nil when the
# shape is unknown).
def shape
  defined?(@shape) ? @shape : nil
end

Class Method Details

.fix_inferred_elements(shape, total_size) ⇒ Object



121
122
123
124
125
126
127
128
# File 'lib/tensor_stream/tensor_shape.rb', line 121

# Replaces every -1 ("infer this dimension") entry in +shape+ with the size
# implied by +total_size+ and the remaining positive dimensions.
#
# @param shape [Array] requested dimensions; -1 marks a dimension to infer
# @param total_size [Integer, nil] total element count, or nil if unknown
# @return [Array, nil] the shape with -1 entries filled in (filled with nil
#   when the size cannot be derived), or nil when any dimension is a Tensor
def self.fix_inferred_elements(shape, total_size)
  return shape if shape.empty?
  # A symbolic (Tensor-valued) dimension anywhere in the shape cannot be
  # resolved statically, not just one in the leading position.
  return nil if shape.any? { |dim| dim.is_a?(Tensor) }

  inferred_size =
    if total_size.nil? || shape.any?(&:nil?)
      # Unknown total or an unknown (nil) dimension: the -1 entry is not derivable.
      nil
    else
      # Only positive dims contribute, so -1 and 0 never zero out the product.
      known_size = shape.inject(1) { |product, dim| dim > 0 ? product * dim : product }
      total_size / known_size
    end

  shape.map { |dim| dim == -1 ? inferred_size : dim }
end

.infer_shape(shape_a, shape_b) ⇒ Object



76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# File 'lib/tensor_stream/tensor_shape.rb', line 76

# Infers the shape that results from combining (broadcasting) two shapes.
#
# Dimensions are aligned from the right, so the leading dimensions of the
# higher-rank shape pass through unchanged. For each aligned pair the larger
# size wins. The result dimension is nil when either side is unknown (nil) or
# symbolic (a Tensor).
#
# @param shape_a [Array, nil]
# @param shape_b [Array, nil]
# @return [Array, nil] nil when either input is nil
def self.infer_shape(shape_a, shape_b)
  return nil if shape_a.nil? || shape_b.nil?
  return shape_a if shape_b.empty?
  return shape_b if shape_a.empty?
  return shape_a if shape_a == shape_b

  # Ranks may differ. Align from the right instead of returning the
  # higher-rank shape wholesale; that would turn [2, 1] + [3] into [2, 1]
  # rather than the broadcast [2, 3].
  longer, shorter = shape_a.size >= shape_b.size ? [shape_a, shape_b] : [shape_b, shape_a]
  offset = longer.size - shorter.size

  longer.each_with_index.map do |dim, index|
    # Leading dimensions with no counterpart in the shorter shape carry through.
    next dim if index < offset

    other = shorter[index - offset]
    next nil if dim.nil? || other.nil?
    next nil if dim.is_a?(Tensor) || other.is_a?(Tensor)

    [dim, other].max
  end
end

.reshape(arr, new_shape) ⇒ Object



97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# File 'lib/tensor_stream/tensor_shape.rb', line 97

# Reshapes +arr+ (flattened first) into nested Arrays following +new_shape+.
#
# @param arr [Object] an Array of any nesting, or a single scalar value
# @param new_shape [Array, TensorShape] target dimensions; -1 entries are
#   inferred through TensorShape.fix_inferred_elements
# @return [Object] nested Arrays matching +new_shape+, or the lone element
#   when +arr+ has exactly one element and +new_shape+ is empty
# @raise [RuntimeError] when the innermost dimension does not equal the
#   number of elements that reach it
def self.reshape(arr, new_shape)
  arr = arr.is_a?(Array) ? arr.flatten : [arr]
  new_shape = new_shape.is_a?(TensorShape) ? new_shape.shape : new_shape
  new_shape = TensorShape.fix_inferred_elements(new_shape, arr.size)
  return arr[0] if arr.size == 1 && new_shape.empty?

  # Copy before #shift so the caller's shape array is never mutated.
  remaining = new_shape.dup
  s = remaining.shift

  if remaining.empty?
    raise "reshape dimen mismatch #{arr.size} != #{s}" if arr.size != s

    return arr
  end

  # A zero-length dimension (or fewer elements than rows) leaves nothing to
  # slice. Testing s first also avoids a ZeroDivisionError below.
  return arr if s.zero? || arr.size < s

  # The recursive call dups its own shape, so sharing `remaining` is safe.
  arr.each_slice(arr.size / s).map { |slice| reshape(slice, remaining) }
end

Instance Method Details

#[](index) ⇒ Object



20
21
22
23
# File 'lib/tensor_stream/tensor_shape.rb', line 20

# Returns the dimension(s) at +index+ wrapped in a new TensorShape.
#
# @param index [Integer, Range] position (or range of positions) to select
# @return [TensorShape] wraps nil when this shape is unknown
def [](index)
  # An unknown shape has unknown dimensions. Propagate nil instead of raising
  # NoMethodError on nil[].
  TensorShape.new(@shape.nil? ? nil : @shape[index])
end

#as_dimension(value) ⇒ Object



62
63
64
# File 'lib/tensor_stream/tensor_shape.rb', line 62

# Unwraps a TensorShape into its raw dimension list. Any other value is
# returned untouched.
#
# @param value [TensorShape, Object]
# @return [Object]
def as_dimension(value)
  case value
  when TensorShape then value.shape
  else value
  end
end

#assert_compatible_with(other) ⇒ Object

Raises a TensorStream::ValueError if `other` is not compatible with this shape.



72
73
74
# File 'lib/tensor_stream/tensor_shape.rb', line 72

# Guards that +other+ is compatible with this shape.
#
# @param other [TensorShape, Object] shape to compare against
# @return [nil]
# @raise [TensorStream::ValueError] when compatible_with? returns false
def assert_compatible_with(other)
  return if compatible_with?(other)

  raise TensorStream::ValueError, "Dimensions #{self} and #{other} are not compatible"
end

#compatible_with?(other) ⇒ Boolean

Returns:

  • (Boolean)


56
57
58
59
60
# File 'lib/tensor_stream/tensor_shape.rb', line 56

# Tests whether +other+ could describe the same tensor as this shape.
#
# An unknown shape (nil) is compatible with anything. Two dimension Arrays
# are compatible when they have the same rank and every aligned pair is
# either equal or unknown (nil), as with TensorFlow's
# TensorShape.is_compatible_with. Non-Array shapes fall back to plain equality.
#
# @param other [TensorShape, Object]
# @return [Boolean]
def compatible_with?(other)
  other = as_dimension(other)
  return true if shape.nil? || other.nil?
  return shape == other unless shape.is_a?(Array) && other.is_a?(Array)
  return false unless shape.size == other.size

  shape.zip(other).all? { |mine, theirs| mine.nil? || theirs.nil? || mine == theirs }
end

#fully_defined? ⇒ Boolean

Returns:

  • (Boolean)


42
43
44
# File 'lib/tensor_stream/tensor_shape.rb', line 42

# TensorFlow-style name for known?: true only when every dimension is known.
#
# @return [Boolean]
def fully_defined?
  return false unless known?

  true
end

#known? ⇒ Boolean

Returns:

  • (Boolean)


33
34
35
36
37
38
39
40
# File 'lib/tensor_stream/tensor_shape.rb', line 33

# Reports whether every dimension of this shape is known.
#
# A nil shape is unknown. A non-Array shape is treated as a single dimension.
# A dimension counts as unknown when it is nil or negative.
#
# @return [Boolean]
def known?
  return false if shape.nil?

  dims = shape.is_a?(Array) ? shape : [shape]
  dims.none? { |dim| dim.nil? || dim < 0 }
end

#merge_with(other) ⇒ Object



46
47
48
49
50
51
52
53
54
# File 'lib/tensor_stream/tensor_shape.rb', line 46

# Combines this shape with a compatible +other+ shape into a new TensorShape.
#
# When this shape is unknown, the result takes +other+'s dimensions; a
# TensorShape argument is unwrapped rather than nested. When both are
# same-rank dimension Arrays, each unknown (nil) dimension here is filled in
# from +other+. Otherwise this shape's dimensions are kept.
#
# @param other [TensorShape, Object]
# @return [TensorShape]
# @raise [TensorStream::ValueError] via assert_compatible_with when incompatible
def merge_with(other)
  assert_compatible_with(other)

  other_shape = as_dimension(other)
  return TensorShape.new(other_shape) if @shape.nil?

  mergeable = @shape.is_a?(Array) && other_shape.is_a?(Array) && @shape.size == other_shape.size
  return TensorShape.new(@shape) unless mergeable

  TensorShape.new(@shape.zip(other_shape).map { |mine, theirs| mine.nil? ? theirs : mine })
end

#ndims ⇒ Object



25
26
27
# File 'lib/tensor_stream/tensor_shape.rb', line 25

# Number of dimensions in the shape, or nil when the shape is unknown.
#
# @return [Integer, nil]
def ndims
  return nil unless shape

  shape.size
end

#scalar? ⇒ Boolean

Returns:

  • (Boolean)


29
30
31
# File 'lib/tensor_stream/tensor_shape.rb', line 29

# True when the shape is fully known and has zero dimensions.
#
# @return [Boolean]
def scalar?
  return false unless known?

  shape.size.zero?
end

#to_s ⇒ Object



11
12
13
14
15
16
17
18
# File 'lib/tensor_stream/tensor_shape.rb', line 11

# Human-readable form modeled on TensorFlow's printout, e.g.
# "TensorShape([Dimension(2),Dimension(3)])". An unknown shape renders as "?".
#
# @return [String]
def to_s
  return "?" if @shape.nil?

  dims = @shape.map { |dim| "Dimension(#{dim})" }
  "TensorShape([#{dims.join(',')}])"
end

#value ⇒ Object



66
67
68
# File 'lib/tensor_stream/tensor_shape.rb', line 66

# Alias for #shape: the raw dimensions (nil when the shape is unknown).
def value; shape; end