Class: TensorFlow::Tensor

Inherits:
Object
  • Object
show all
Defined in:
lib/tensorflow/tensor.rb

Class Method Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(value = nil, dtype: nil, shape: nil, pointer: nil) ⇒ Tensor

Returns a new instance of Tensor.



3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# File 'lib/tensorflow/tensor.rb', line 3

# Creates a tensor either by wrapping an existing tensor handle or by copying
# Ruby data into native memory and building a new handle from it.
#
# @param value [Object, nil] data to store; anything that is not an Array or
#   Numo::NArray is wrapped with Array()
# @param dtype [Symbol, nil] element type; inferred with Utils.infer_type
#   when omitted
# @param shape [Array<Integer>, nil] dimensions; computed with
#   calculate_shape(value) when omitted
# @param pointer [Object, nil] existing tensor handle; when given, value,
#   dtype and shape are not used
# @raise [RuntimeError] "Unknown type: ..." when dtype is not one of the
#   types handled below
def initialize(value = nil, dtype: nil, shape: nil, pointer: nil)
  @status = FFI.TF_NewStatus
  # Stays nil when an existing handle is wrapped, so the finalizer knows
  # there is no tensor of our own to release.
  tensor = nil

  if pointer
    @pointer = pointer
  else
    data = value
    data = Array(data) unless data.is_a?(Array) || data.is_a?(Numo::NArray)
    shape ||= calculate_shape(value)

    # An empty shape is passed to TF_NewTensor as a nil dims pointer with
    # a dimension count of 0.
    if shape.size > 0
      dims_ptr = ::FFI::MemoryPointer.new(:int64, shape.size)
      dims_ptr.write_array_of_int64(shape)
    else
      dims_ptr = nil
    end

    if data.is_a?(Numo::NArray)
      dtype ||= Utils.infer_type(data)
      # TODO use Numo read pointer?
      data_ptr = ::FFI::MemoryPointer.new(:uchar, data.byte_size)
      data_ptr.write_bytes(data.to_string)
    else
      data = data.flatten
      dtype ||= Utils.infer_type(data)
      case dtype
      when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
        # dtype doubles as the FFI type name and the writer-method suffix
        data_ptr = ::FFI::MemoryPointer.new(dtype, data.size)
        data_ptr.send("write_array_of_#{dtype}", data)
      when :bfloat16
        # https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
        # keeps the first two bytes of each big-endian single-precision value
        data_ptr = ::FFI::MemoryPointer.new(:int8, data.size * 2)
        data_ptr.write_bytes(data.map { |v| [v].pack("g")[0..1] }.join)
      when :complex64
        # stored as interleaved (real, imaginary) floats
        data_ptr = ::FFI::MemoryPointer.new(:float, data.size * 2)
        data_ptr.write_array_of_float(data.flat_map { |v| [v.real, v.imaginary] })
      when :complex128
        # stored as interleaved (real, imaginary) doubles
        data_ptr = ::FFI::MemoryPointer.new(:double, data.size * 2)
        data_ptr.write_array_of_double(data.flat_map { |v| [v.real, v.imaginary] })
      when :string
        data_ptr = string_ptr(data)
      when :bool
        # one byte per element: 1 for truthy, 0 otherwise
        data_ptr = ::FFI::MemoryPointer.new(:int8, data.size)
        data_ptr.write_array_of_int8(data.map { |v| v ? 1 : 0 })
      else
        raise "Unknown type: #{dtype}"
      end
    end

    type = FFI::DataType[dtype]

    # No-op deallocator; the parameters are unused and underscore-prefixed so
    # they do not shadow the local `data` above.
    callback = ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |_data, _len, _arg|
      # FFI handles deallocation
    end

    # keep data pointer alive for duration of object
    @data_ptr = data_ptr
    @dims_ptr = dims_ptr
    @callback = callback

    tensor = FFI.TF_NewTensor(type, dims_ptr, shape.size, data_ptr, data_ptr.size, callback, nil)
    @pointer = FFI.TFE_NewTensorHandle(tensor, @status)
    check_status @status
  end

  ObjectSpace.define_finalizer(self, self.class.finalize(@pointer, @status, tensor))
end

Class Method Details

.finalize(pointer, status, tensor) ⇒ Object



165
166
167
168
169
170
171
172
# File 'lib/tensorflow/tensor.rb', line 165

# Builds the finalizer that releases the native resources owned by a tensor.
#
# The status is created with FFI.TF_NewStatus and the tensor with
# FFI.TF_NewTensor, so they are released with the matching TF_Delete*
# functions; only the handle belongs to the eager (TFE_) API.
# NOTE(review): assumes the project FFI module binds TF_DeleteStatus and
# TF_DeleteTensor - confirm against the bindings file.
#
# @param pointer [Object] tensor handle, released with TFE_DeleteTensorHandle
# @param status [Object] status object, released with TF_DeleteStatus
# @param tensor [Object, nil] underlying tensor, released with TF_DeleteTensor
#   when present; nil is skipped
# @return [Proc] finalizer suitable for ObjectSpace.define_finalizer
def self.finalize(pointer, status, tensor)
  # must use proc instead of stabby lambda
  # (the finalizer is invoked with the object id, which a zero-arity lambda
  # would reject)
  proc do
    FFI.TFE_DeleteTensorHandle(pointer)
    FFI.TF_DeleteStatus(status)
    FFI.TF_DeleteTensor(tensor) if tensor
  end
end

Instance Method Details

#%(other) ⇒ Object



87
88
89
# File 'lib/tensorflow/tensor.rb', line 87

# Modulo operator: delegates to Math.floormod with this tensor as the left
# operand.
# NOTE(review): Math is presumably the library's own Math module rather than
# Ruby's ::Math - confirm.
#
# @param other [Object] right-hand operand, passed through to Math.floormod
# @return [Object] whatever Math.floormod returns
def %(other)
  Math.floormod(self, other)
end

#*(other) ⇒ Object



79
80
81
# File 'lib/tensorflow/tensor.rb', line 79

# Multiplication operator: delegates to Math.multiply with this tensor as the
# left operand.
#
# @param other [Object] right-hand operand, passed through to Math.multiply
# @return [Object] whatever Math.multiply returns
def *(other)
  Math.multiply(self, other)
end

#+(other) ⇒ Object



71
72
73
# File 'lib/tensorflow/tensor.rb', line 71

# Addition operator: delegates to Math.add with this tensor as the left
# operand.
#
# @param other [Object] right-hand operand, passed through to Math.add
# @return [Object] whatever Math.add returns
def +(other)
  Math.add(self, other)
end

#-(other) ⇒ Object



75
76
77
# File 'lib/tensorflow/tensor.rb', line 75

# Subtraction operator: delegates to Math.subtract with this tensor as the
# left operand.
#
# @param other [Object] right-hand operand, passed through to Math.subtract
# @return [Object] whatever Math.subtract returns
def -(other)
  Math.subtract(self, other)
end

#-@ ⇒ Object



91
92
93
# File 'lib/tensorflow/tensor.rb', line 91

# Unary minus: delegates to Math.negative with this tensor as the argument.
#
# @return [Object] whatever Math.negative returns
def -@
  Math.negative(self)
end

#/(other) ⇒ Object



83
84
85
# File 'lib/tensorflow/tensor.rb', line 83

# Division operator: delegates to Math.divide with this tensor as the left
# operand.
#
# @param other [Object] right-hand operand, passed through to Math.divide
# @return [Object] whatever Math.divide returns
def /(other)
  Math.divide(self, other)
end

#dtype ⇒ Object



123
124
125
# File 'lib/tensorflow/tensor.rb', line 123

# Element type of the tensor, looked up from the handle and memoized in
# @dtype after the first successful lookup.
#
# @return [Object] the FFI::DataType entry for the handle's data type
def dtype
  @dtype ||= begin
    type_code = FFI.TFE_TensorHandleDataType(@pointer)
    FFI::DataType[type_code]
  end
end

#inspect ⇒ Object



160
161
162
163
# File 'lib/tensorflow/tensor.rb', line 160

def inspect
  inspection = %w(numo shape dtype).map { |v| "#{v}: #{send(v).inspect}"}
  "#<#{self.class} #{inspection.join(", ")}>"
end

#numo ⇒ Object



154
155
156
157
158
# File 'lib/tensorflow/tensor.rb', line 154

# Converts the tensor's value with the class registered for its dtype in
# Utils::NUMO_TYPE_MAP.
#
# @return [Object] result of calling cast(value) on the mapped class
# @raise [RuntimeError] "Unknown type: ..." when the dtype has no mapping
def numo
  narray_class = Utils::NUMO_TYPE_MAP[dtype]
  if narray_class
    narray_class.cast(value)
  else
    raise "Unknown type: #{dtype}"
  end
end

#shape ⇒ Object



127
128
129
130
131
132
133
134
135
136
# File 'lib/tensorflow/tensor.rb', line 127

# Dimensions of the tensor, read one axis at a time from the handle and
# memoized in @shape.
#
# @return [Array] one entry per dimension, in axis order
def shape
  @shape ||= num_dims.times.map do |axis|
    extent = FFI.TFE_TensorHandleDim(@pointer, axis, @status)
    # validate the status after every lookup, before moving to the next axis
    check_status @status
    extent
  end
end

#to_a ⇒ Object



146
147
148
# File 'lib/tensorflow/tensor.rb', line 146

# Returns the tensor's contents exactly as produced by #value.
#
# @return [Object] the result of #value
def to_a
  value
end

#to_i ⇒ Object



142
143
144
# File 'lib/tensorflow/tensor.rb', line 142

# Integer conversion: calls to_i on the result of #value.
# NOTE(review): presumably only meaningful for scalar tensors, where #value
# is a single number - confirm what #value returns for an empty shape.
#
# @return [Integer]
def to_i
  value.to_i
end

#to_ptr ⇒ Object



150
151
152
# File 'lib/tensorflow/tensor.rb', line 150

# Exposes the underlying tensor handle stored in @pointer.
#
# @return [Object] the handle held by this tensor
def to_ptr
  @pointer
end

#to_s ⇒ Object



138
139
140
# File 'lib/tensorflow/tensor.rb', line 138

# String form of the tensor; identical to #inspect.
#
# @return [String]
def to_s
  inspect
end

#value ⇒ Object



95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# File 'lib/tensorflow/tensor.rb', line 95

# Reads the tensor's contents out of native memory and converts them to Ruby
# objects according to dtype, then arranges them with reshape(values, shape).
#
# @return [Object] the reshaped values; for :resource and :variant tensors
#   the raw data pointer is returned instead
# @raise [RuntimeError] "Unknown type: ..." for any other dtype
def value
  flat_values =
    case dtype
    when :resource, :variant
      # opaque types: hand back the pointer untouched, skipping reshape
      return data_pointer
    when :bool
      data_pointer.read_array_of_int8(element_count).map { |byte| byte == 1 }
    when :string
      # byte distance between consecutive string entries
      # NOTE(review): presumably the size of the native string struct - confirm
      string_stride = 24
      element_count.times.map do |index|
        FFI.TF_StringGetDataPointer(data_pointer + index * string_stride)
      end
    when :complex128
      doubles = data_pointer.read_array_of_double(element_count * 2)
      doubles.each_slice(2).map { |pair| Complex(*pair) }
    when :complex64
      floats = data_pointer.read_array_of_float(element_count * 2)
      floats.each_slice(2).map { |pair| Complex(*pair) }
    when :bfloat16
      packed = data_pointer.read_bytes(element_count * 2)
      element_count.times.map do |index|
        # pad the two stored bytes back out to a four-byte big-endian float
        high_bytes = packed[(2 * index)..(2 * index + 1)]
        "#{high_bytes}\x00\x00".unpack1("g")
      end
    when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
      # dtype doubles as the suffix of the pointer's reader method
      data_pointer.send("read_array_of_#{dtype}", element_count)
    else
      raise "Unknown type: #{dtype}"
    end

  reshape(flat_values, shape)
end