Class: TensorFlow::Tensor

Inherits:
Object
Defined in:
lib/tensorflow/tensor.rb

Class Method Summary

Instance Method Summary

Constructor Details

#initialize(value = nil, dtype: nil, shape: nil, pointer: nil) ⇒ Tensor

Returns a new instance of Tensor.



(source lines 3–69)
# File 'lib/tensorflow/tensor.rb', line 3

# Creates a tensor, either by wrapping an existing eager tensor handle
# (+pointer:+) or by copying +value+ into a newly allocated TF_Tensor and
# wrapping that in a new eager tensor handle.
#
# @param value [Object, Array, Numo::NArray] scalar, (nested) Array or Numo
#   array with the contents; ignored when +pointer+ is given
# @param dtype [Symbol, nil] element type; inferred with Utils.infer_type when nil
# @param shape [Array<Integer>, nil] dimensions; computed with
#   calculate_shape(value) when nil
# @param pointer [FFI::Pointer, nil] an existing tensor handle to wrap
# @raise [RuntimeError] "Unknown type: ..." when +dtype+ is not supported;
#   check_status presumably raises when TFE_NewTensorHandle reports an error
def initialize(value = nil, dtype: nil, shape: nil, pointer: nil)
  @status = FFI.TF_NewStatus

  if pointer
    # NOTE(review): the finalizer registered below deletes this handle, so the
    # caller appears to hand ownership of it to this Tensor — confirm.
    @pointer = pointer
  else
    data = value
    data = Array(data) unless data.is_a?(Array) || data.is_a?(Numo::NArray)
    shape ||= calculate_shape(value)

    # Rank-0 (scalar) tensors are created with no dims buffer at all.
    if shape.size > 0
      dims_ptr = ::FFI::MemoryPointer.new(:int64, shape.size)
      dims_ptr.write_array_of_int64(shape)
    else
      dims_ptr = nil
    end

    if data.is_a?(Numo::NArray)
      dtype ||= Utils.infer_type(data)
      # TODO use Numo read pointer?
      # Copy the Numo array's raw element bytes into FFI-owned memory.
      data_ptr = ::FFI::MemoryPointer.new(:uchar, data.byte_size)
      data_ptr.write_bytes(data.to_string)
    else
      data = data.flatten
      dtype ||= Utils.infer_type(data)
      case dtype
      when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
        # These dtype symbols double as FFI type names (write_array_of_<dtype>).
        data_ptr = ::FFI::MemoryPointer.new(dtype, data.size)
        data_ptr.public_send("write_array_of_#{dtype}", data)
      when :bfloat16
        # https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
        # Keeps the first two bytes of each big-endian float32, i.e. the upper
        # 16 bits (truncation, no rounding).
        # NOTE(review): these bytes land in memory in big-endian order; if TF
        # treats bfloat16 as a native uint16, little-endian hosts would see them
        # swapped — confirm. #value decodes with the same big-endian convention,
        # so both sides must change together.
        data_ptr = ::FFI::MemoryPointer.new(:int8, data.size * 2)
        data_ptr.write_bytes(data.map { |v| [v].pack("g")[0..1] }.join)
      when :complex64
        # Interleaved (real, imaginary) float pairs.
        data_ptr = ::FFI::MemoryPointer.new(:float, data.size * 2)
        data_ptr.write_array_of_float(data.flat_map { |v| [v.real, v.imaginary] })
      when :complex128
        # Interleaved (real, imaginary) double pairs.
        data_ptr = ::FFI::MemoryPointer.new(:double, data.size * 2)
        data_ptr.write_array_of_double(data.flat_map { |v| [v.real, v.imaginary] })
      when :string
        data_ptr = string_ptr(data)
      when :bool
        # One byte per element: 1 for truthy, 0 for falsy.
        data_ptr = ::FFI::MemoryPointer.new(:int8, data.size)
        data_ptr.write_array_of_int8(data.map { |v| v ? 1 : 0 })
      else
        raise "Unknown type: #{dtype}"
      end
    end

    type = FFI::DataType[dtype]

    # Deallocator handed to TF_NewTensor. It is deliberately a no-op: the
    # buffer is a Ruby-owned FFI::MemoryPointer (kept alive via @data_ptr).
    # Parameters are unused, hence the underscores (also avoids shadowing `data`).
    callback = ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |_data, _len, _arg|
      # FFI handles deallocation
    end

    # keep data pointer alive for duration of object
    @data_ptr = data_ptr
    @dims_ptr = dims_ptr
    @callback = callback

    tensor = FFI.TF_NewTensor(type, dims_ptr, shape.size, data_ptr, data_ptr.size, callback, nil)
    @pointer = FFI.TFE_NewTensorHandle(tensor, @status)
    check_status @status
  end

  # `tensor` is nil when an existing handle was wrapped (branch above not taken).
  ObjectSpace.define_finalizer(self, self.class.finalize(@pointer, @status, tensor))
end

Class Method Details

.finalize(pointer, status, tensor) ⇒ Object



(source lines 175–182)
# File 'lib/tensorflow/tensor.rb', line 175

# Builds the finalizer proc that releases a Tensor's native resources.
#
# This is a class method so the returned proc does not capture the Tensor
# instance itself; a finalizer that references its own object would keep that
# object from ever being collected.
#
# @param pointer [FFI::Pointer] tensor handle, freed with TFE_DeleteTensorHandle
# @param status [FFI::Pointer] status created with TF_NewStatus
# @param tensor [FFI::Pointer, nil] TF_Tensor created with TF_NewTensor, or nil
#   when the Tensor wrapped an existing handle
# @return [Proc] callable registered with ObjectSpace.define_finalizer
def self.finalize(pointer, status, tensor)
  # must use proc instead of stabby lambda: ObjectSpace calls the finalizer with
  # the object id, which a zero-arity lambda would reject with ArgumentError.
  proc do
    FFI.TFE_DeleteTensorHandle(pointer)
    # TF_Tensor and TF_Status are core C API objects. They are released with
    # TF_DeleteTensor / TF_DeleteStatus; there are no TFE_* deleters for them.
    FFI.TF_DeleteTensor(tensor) if tensor
    FFI.TF_DeleteStatus(status)
  end
end

Instance Method Details

#%(other) ⇒ Object



(source lines 87–89)
# File 'lib/tensorflow/tensor.rb', line 87

# Modulo operator; delegates to Math.floormod with this tensor on the left.
#
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] the result of Math.floormod(self, other)
def %(other)
  Math.floormod(self, other)
end

#*(other) ⇒ Object



(source lines 79–81)
# File 'lib/tensorflow/tensor.rb', line 79

# Multiplication operator; delegates to Math.multiply with this tensor on the left.
#
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] the result of Math.multiply(self, other)
def *(other)
  Math.multiply(self, other)
end

#+(other) ⇒ Object



(source lines 71–73)
# File 'lib/tensorflow/tensor.rb', line 71

# Addition operator; delegates to Math.add with this tensor on the left.
#
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] the result of Math.add(self, other)
def +(other)
  Math.add(self, other)
end

#-(other) ⇒ Object



(source lines 75–77)
# File 'lib/tensorflow/tensor.rb', line 75

# Subtraction operator; delegates to Math.subtract with this tensor on the left.
#
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] the result of Math.subtract(self, other)
def -(other)
  Math.subtract(self, other)
end

#-@ ⇒ Object



(source lines 91–93)
# File 'lib/tensorflow/tensor.rb', line 91

# Unary minus; delegates to Math.negative.
#
# @return [Object] the result of Math.negative(self)
def -@
  Math.negative(self)
end

#/(other) ⇒ Object



(source lines 83–85)
# File 'lib/tensorflow/tensor.rb', line 83

# Division operator; delegates to Math.divide with this tensor on the left.
#
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] the result of Math.divide(self, other)
def /(other)
  Math.divide(self, other)
end

#dtype ⇒ Object



(source lines 133–135)
# File 'lib/tensorflow/tensor.rb', line 133

# Element type of the tensor as a symbol (such as :float or :int32).
#
# The handle's data-type code is looked up in FFI::DataType on the first call
# that yields a truthy result; later calls return the cached symbol.
def dtype
  return @dtype if @dtype

  @dtype = FFI::DataType[FFI.TFE_TensorHandleDataType(@pointer)]
end

#inspect ⇒ Object



(source lines 170–173)
# File 'lib/tensorflow/tensor.rb', line 170

# Debug representation listing the tensor's contents (via #numo), #shape and
# #dtype, e.g. "#<TensorFlow::Tensor numo: ..., shape: [...], dtype: :float>".
def inspect
  details = %w(numo shape dtype).map do |attr|
    "#{attr}: #{send(attr).inspect}"
  end
  "#<#{self.class} #{details.join(", ")}>"
end

#numo ⇒ Object



(source lines 164–168)
# File 'lib/tensorflow/tensor.rb', line 164

# The tensor's contents as a Numo array.
#
# The Numo class is chosen from Utils::NUMO_TYPE_MAP by #dtype and built
# with its `cast` applied to #value.
#
# @raise [RuntimeError] "Unknown type: ..." when the dtype has no Numo mapping
def numo
  numo_class = Utils::NUMO_TYPE_MAP[dtype]
  return numo_class.cast(value) if numo_class

  raise "Unknown type: #{dtype}"
end

#shape ⇒ Object



(source lines 137–146)
# File 'lib/tensorflow/tensor.rb', line 137

# Dimension sizes of the tensor, one entry per axis (empty for a scalar).
#
# Each axis is queried from the handle with TFE_TensorHandleDim, and the
# status is checked after every query. The result is cached after the first
# call.
def shape
  return @shape if @shape

  @shape = num_dims.times.map do |axis|
    dim = FFI.TFE_TensorHandleDim(@pointer, axis, @status)
    check_status @status
    dim
  end
end

#to_a ⇒ Object



(source lines 156–158)
# File 'lib/tensorflow/tensor.rb', line 156

# Array conversion; returns exactly what #value returns.
def to_a
  value
end

#to_i ⇒ Object



(source lines 152–154)
# File 'lib/tensorflow/tensor.rb', line 152

# Integer conversion: calls `to_i` on the result of #value. This is presumably
# meant for scalar tensors; verify with callers.
def to_i
  value.to_i
end

#to_ptr ⇒ Object



(source lines 160–162)
# File 'lib/tensorflow/tensor.rb', line 160

# The underlying tensor handle pointer, so the Tensor can be passed
# directly wherever an FFI pointer is expected.
def to_ptr
  @pointer
end

#to_s ⇒ Object



(source lines 148–150)
# File 'lib/tensorflow/tensor.rb', line 148

# String form of the tensor; identical to #inspect.
def to_s
  inspect
end

#value ⇒ Object



(source lines 95–131)
# File 'lib/tensorflow/tensor.rb', line 95

# Reads the tensor's contents back into Ruby objects.
#
# Elements are decoded from the tensor's data buffer according to #dtype:
# - numeric dtypes become Integers/Floats;
# - complex dtypes become Complex;
# - :string becomes Strings;
# - :bool becomes true/false.
# The flat element list is then passed to reshape (defined elsewhere) with
# #shape. For :resource and :variant tensors the raw data pointer is returned
# as-is, without reshaping.
#
# @return [Object] presumably a nested Array matching #shape (depends on
#   reshape), or the data pointer for :resource/:variant tensors
# @raise [RuntimeError] "Unknown type: ..." for unsupported dtypes
def value
  # Local `value` shadows this method's name; it holds the flat element list.
  value =
    case dtype
    when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
      # These dtype symbols double as FFI reader names (read_array_of_<dtype>).
      data_pointer.send("read_array_of_#{dtype}", element_count)
    when :bfloat16
      # Two bytes per element: pad each pair with two zero bytes and decode it
      # as a big-endian float32 (mirrors the encoding used in #initialize).
      # NOTE(review): this assumes the pair is stored big-endian; if TF stores
      # bfloat16 as a native uint16, values produced by TF kernels would decode
      # wrongly on little-endian hosts — confirm.
      byte_str = data_pointer.read_bytes(element_count * 2)
      element_count.times.map { |i| "#{byte_str[(2 * i)..(2 * i + 1)]}\x00\x00".unpack1("g") }
    when :complex64
      # Interleaved (real, imaginary) float pairs.
      data_pointer.read_array_of_float(element_count * 2).each_slice(2).map { |v| Complex(*v) }
    when :complex128
      # Interleaved (real, imaginary) double pairs.
      data_pointer.read_array_of_double(element_count * 2).each_slice(2).map { |v| Complex(*v) }
    when :string
      # string tensor format
      # https://github.com/tensorflow/tensorflow/blob/5453aee48858fd375172d7ae22fad1557e8557d6/tensorflow/c/tf_tensor.h#L57
      # Buffer = element_count uint64 offsets, followed by the encoded strings;
      # each offset is relative to the end of that offset table.
      start_offset_size = element_count * 8
      offsets = data_pointer.read_array_of_uint64(element_count)
      byte_size = FFI.TF_TensorByteSize(tensor_pointer)
      element_count.times.map do |i|
        # An element's length is the gap to the next offset; the last element
        # runs to the end of the buffer.
        str_len = (offsets[i + 1] || (byte_size - start_offset_size)) - offsets[i]
        str = (data_pointer + start_offset_size + offsets[i]).read_bytes(str_len)
        # TF_StringDecode stores a pointer to the decoded bytes in dst and their
        # length in dst_len.
        # NOTE(review): dst is only read as a pointer, so bytesize + 100 chars
        # looks oversized; dst_len is a size_t but is read as int32, which is
        # only correct on little-endian hosts for lengths < 2 GiB — confirm.
        dst = ::FFI::MemoryPointer.new(:char, str.bytesize + 100)
        dst_len = ::FFI::MemoryPointer.new(:size_t)
        FFI.TF_StringDecode(str, str.bytesize, dst, dst_len, @status)
        check_status @status
        dst.read_pointer.read_bytes(dst_len.read_int32)
      end
    when :bool
      # One byte per element; only a byte equal to 1 maps to true.
      data_pointer.read_array_of_int8(element_count).map { |v| v == 1 }
    when :resource, :variant
      # Opaque types: hand back the raw pointer and skip reshaping.
      return data_pointer
    else
      raise "Unknown type: #{dtype}"
    end

  reshape(value, shape)
end