Class: TensorFlow::Tensor
- Inherits: Object
  - Object
    - TensorFlow::Tensor
- Defined in:
- lib/tensorflow/tensor.rb
Class Method Summary
- .finalize(pointer, status, tensor) ⇒ Object
Instance Method Summary
- #%(other) ⇒ Object
- #*(other) ⇒ Object
- #+(other) ⇒ Object
- #-(other) ⇒ Object
- #-@ ⇒ Object
- #/(other) ⇒ Object
- #dtype ⇒ Object
- #initialize(value = nil, dtype: nil, shape: nil, pointer: nil) ⇒ Tensor (constructor): Returns a new instance of Tensor.
- #inspect ⇒ Object
- #numo ⇒ Object
- #shape ⇒ Object
- #to_a ⇒ Object
- #to_i ⇒ Object
- #to_ptr ⇒ Object
- #to_s ⇒ Object
- #value ⇒ Object
Constructor Details
#initialize(value = nil, dtype: nil, shape: nil, pointer: nil) ⇒ Tensor
Returns a new instance of Tensor.
(lib/tensorflow/tensor.rb, lines 3–69)
# File 'lib/tensorflow/tensor.rb', line 3
# Creates a tensor, either by wrapping an existing eager tensor handle
# (+pointer:+) or by copying +value+ into newly allocated native memory and
# building a TF tensor plus an eager tensor handle around it.
#
# @param value [Object] data to copy; anything that is not an Array or a
#   Numo::NArray is wrapped with Array() first
# @param dtype [Symbol, nil] element type; inferred with Utils.infer_type when nil
# @param shape [Array<Integer>, nil] dimensions; computed with
#   calculate_shape(value) when nil
# @param pointer [FFI::Pointer, nil] existing eager tensor handle to wrap; when
#   given, value, dtype and shape are not used
# @raise [RuntimeError] "Unknown type: ..." when dtype is not supported
def initialize(value = nil, dtype: nil, shape: nil, pointer: nil)
  # One status object per tensor; later calls (e.g. #shape, #value) reuse it and
  # the finalizer registered below frees it.
  @status = FFI.TF_NewStatus
  if pointer
    @pointer = pointer
  else
    data = value
    # Wrap scalars (and nil) so the code below can treat the input as an array.
    data = Array(data) unless data.is_a?(Array) || data.is_a?(Numo::NArray)
    # Note: the shape is computed from the original value, not the wrapped data.
    shape ||= calculate_shape(value)
    if shape.size > 0
      dims_ptr = ::FFI::MemoryPointer.new(:int64, shape.size)
      dims_ptr.write_array_of_int64(shape)
    else
      # An empty shape is passed to TF_NewTensor as a NULL dims pointer.
      dims_ptr = nil
    end
    if data.is_a?(Numo::NArray)
      dtype ||= Utils.infer_type(data)
      # TODO use Numo read pointer?
      # The NArray's raw bytes are copied as-is into native memory.
      data_ptr = ::FFI::MemoryPointer.new(:uchar, data.byte_size)
      data_ptr.write_bytes(data.to_string)
    else
      data = data.flatten
      dtype ||= Utils.infer_type(data)
      case dtype
      when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
        # These dtype symbols double as FFI type names (write_array_of_<dtype>).
        data_ptr = ::FFI::MemoryPointer.new(dtype, data.size)
        data_ptr.send("write_array_of_#{dtype}", data)
      when :bfloat16
        # https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
        # Each element is packed as a big-endian float32 ("g") and truncated to
        # its 2 high-order bytes.
        # NOTE(review): TF keeps bfloat16 as a native-endian uint16, so on
        # little-endian hosts these two bytes may look swapped to TF ops. #value
        # decodes with the same big-endian convention, so Ruby round-trips agree.
        # Verify against a real TF op before relying on this path.
        data_ptr = ::FFI::MemoryPointer.new(:int8, data.size * 2)
        data_ptr.write_bytes(data.map { |v| [v].pack("g")[0..1] }.join)
      when :complex64
        # Stored as interleaved (real, imaginary) float pairs.
        data_ptr = ::FFI::MemoryPointer.new(:float, data.size * 2)
        data_ptr.write_array_of_float(data.flat_map { |v| [v.real, v.imaginary] })
      when :complex128
        # Stored as interleaved (real, imaginary) double pairs.
        data_ptr = ::FFI::MemoryPointer.new(:double, data.size * 2)
        data_ptr.write_array_of_double(data.flat_map { |v| [v.real, v.imaginary] })
      when :string
        # Encoding is delegated to string_ptr (defined elsewhere in this class).
        data_ptr = string_ptr(data)
      when :bool
        # One byte per element: truthy -> 1, falsy -> 0.
        data_ptr = ::FFI::MemoryPointer.new(:int8, data.size)
        data_ptr.write_array_of_int8(data.map { |v| v ? 1 : 0 })
      else
        raise "Unknown type: #{dtype}"
      end
    end
    # Presumably maps the dtype symbol to the TF_DataType enum value.
    type = FFI::DataType[dtype]
    # Deliberately a no-op deallocator: the buffer is owned by the Ruby
    # MemoryPointer above. (The block's |data| parameter shadows the outer local
    # and does not modify it.)
    # NOTE(review): the buffer therefore lives only as long as this Ruby object;
    # confirm TF never keeps a reference to it beyond that.
    callback = ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |data, len, arg|
      # FFI handles deallocation
    end
    # keep data pointer alive for duration of object
    @data_ptr = data_ptr
    @dims_ptr = dims_ptr
    @callback = callback
    tensor = FFI.TF_NewTensor(type, dims_ptr, shape.size, data_ptr, data_ptr.size, callback, nil)
    @pointer = FFI.TFE_NewTensorHandle(tensor, @status)
    check_status @status
  end
  # In the pointer: branch `tensor` is never assigned, so it is nil here (Ruby
  # declares locals at parse time) and finalize receives nil for it.
  # NOTE(review): the finalizer is registered only after everything above
  # succeeds. If "Unknown type" is raised, or check_status raises, @status (and
  # the TF tensor) are never released.
  ObjectSpace.define_finalizer(self, self.class.finalize(@pointer, @status, tensor))
end
Class Method Details
.finalize(pointer, status, tensor) ⇒ Object
(lib/tensorflow/tensor.rb, lines 175–182)
# File 'lib/tensorflow/tensor.rb', line 175
# Builds the GC finalizer that releases the native resources owned by a Tensor.
#
# The TensorFlow C API pairs TF_NewStatus with TF_DeleteStatus and TF_NewTensor
# with TF_DeleteTensor. It has no TFE_DeleteStatus / TFE_DeleteTensor functions,
# so calling those names could not free the status or the tensor. Only the
# tensor handle uses the eager (TFE_) API.
#
# @param pointer [FFI::Pointer] eager tensor handle to delete
# @param status [FFI::Pointer] TF_Status to delete
# @param tensor [FFI::Pointer, nil] TF_Tensor to delete, or nil when there is
#   no TF tensor owned by this object
# @return [Proc] callable suitable for ObjectSpace.define_finalizer
def self.finalize(pointer, status, tensor)
  # must use proc instead of stabby lambda: Ruby calls finalizers with the
  # object id, and a zero-arity lambda would raise ArgumentError on that argument
  proc do
    FFI.TFE_DeleteTensorHandle(pointer)
    FFI.TF_DeleteTensor(tensor) if tensor
    FFI.TF_DeleteStatus(status)
  end
end
Instance Method Details
#%(other) ⇒ Object
(lib/tensorflow/tensor.rb, lines 87–89)
# File 'lib/tensorflow/tensor.rb', line 87
# Floor-modulo operator: `tensor % other`.
#
# Delegates to Math.floormod, which resolves to the enclosing namespace's Math
# module (Ruby's core ::Math has no floormod).
#
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] whatever Math.floormod returns
def %(other)
  Math.floormod(self, other)
end
#*(other) ⇒ Object
(lib/tensorflow/tensor.rb, lines 79–81)
# File 'lib/tensorflow/tensor.rb', line 79
# Multiplication operator: `tensor * other`.
#
# Delegates to Math.multiply in the enclosing namespace.
#
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] whatever Math.multiply returns
def *(other)
  Math.multiply(self, other)
end
#+(other) ⇒ Object
(lib/tensorflow/tensor.rb, lines 71–73)
# File 'lib/tensorflow/tensor.rb', line 71
# Addition operator: `tensor + other`.
#
# Delegates to Math.add in the enclosing namespace.
#
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] whatever Math.add returns
def +(other)
  Math.add(self, other)
end
#-(other) ⇒ Object
(lib/tensorflow/tensor.rb, lines 75–77)
# File 'lib/tensorflow/tensor.rb', line 75
# Subtraction operator: `tensor - other`.
#
# Delegates to Math.subtract in the enclosing namespace.
#
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] whatever Math.subtract returns
def -(other)
  Math.subtract(self, other)
end
#/(other) ⇒ Object
(lib/tensorflow/tensor.rb, lines 83–85)
# File 'lib/tensorflow/tensor.rb', line 83
# Division operator: `tensor / other`.
#
# Delegates to Math.divide in the enclosing namespace.
#
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] whatever Math.divide returns
def /(other)
  Math.divide(self, other)
end
#dtype ⇒ Object
(lib/tensorflow/tensor.rb, lines 133–135)
# File 'lib/tensorflow/tensor.rb', line 133
# Element type of the tensor.
#
# The handle's TF data type is looked up through FFI::DataType on first use and
# cached. A nil lookup result is not cached and is retried on the next call.
#
# @return [Object] the FFI::DataType entry (a symbol such as :float, judging by #value)
def dtype
  return @dtype if @dtype

  @dtype = FFI::DataType[FFI.TFE_TensorHandleDataType(@pointer)]
end
#inspect ⇒ Object
(lib/tensorflow/tensor.rb, lines 170–173)
# File 'lib/tensorflow/tensor.rb', line 170
# Debug representation, e.g. "#<TensorFlow::Tensor numo: ..., shape: ..., dtype: ...>".
#
# NOTE(review): this calls #numo, which raises "Unknown type" for dtypes missing
# from Utils::NUMO_TYPE_MAP. Check whether inspect can raise for such tensors.
#
# @return [String]
def inspect
  fields = %w(numo shape dtype).map do |attr|
    "#{attr}: #{send(attr).inspect}"
  end
  "#<#{self.class} #{fields.join(", ")}>"
end
#numo ⇒ Object
(lib/tensorflow/tensor.rb, lines 164–168)
# File 'lib/tensorflow/tensor.rb', line 164
# Converts #value into a Numo array of the class registered for #dtype.
#
# @return [Object] result of <numo class>.cast(value)
# @raise [RuntimeError] "Unknown type: ..." when Utils::NUMO_TYPE_MAP has no
#   (truthy) entry for the dtype
def numo
  if (numo_class = Utils::NUMO_TYPE_MAP[dtype])
    numo_class.cast(value)
  else
    raise "Unknown type: #{dtype}"
  end
end
#shape ⇒ Object
(lib/tensorflow/tensor.rb, lines 137–146)
# File 'lib/tensorflow/tensor.rb', line 137
# Dimensions of the tensor, queried from the eager handle once and memoized.
#
# Each dimension is read with TFE_TensorHandleDim, and the shared status is
# checked after every call. When num_dims is 0 the result is an empty array.
#
# @return [Array] one extent per dimension, in axis order
def shape
  return @shape if @shape

  @shape = num_dims.times.map do |axis|
    extent = FFI.TFE_TensorHandleDim(@pointer, axis, @status)
    check_status @status
    extent
  end
end
#to_a ⇒ Object
(lib/tensorflow/tensor.rb, lines 156–158)
# File 'lib/tensorflow/tensor.rb', line 156
# Returns #value unchanged.
#
# Note that for :resource / :variant tensors #value returns a raw pointer, so
# the result is not necessarily an Array.
#
# @return [Object]
def to_a
  self.value
end
#to_i ⇒ Object
(lib/tensorflow/tensor.rb, lines 152–154)
# File 'lib/tensorflow/tensor.rb', line 152
# Integer conversion of #value, via the value's own #to_i.
#
# @return [Integer]
def to_i
  scalar = value
  scalar.to_i
end
#to_ptr ⇒ Object
(lib/tensorflow/tensor.rb, lines 160–162)
# File 'lib/tensorflow/tensor.rb', line 160
# The eager tensor handle backing this tensor: either the handle created by
# the constructor or the one passed to it via pointer:.
#
# @return [FFI::Pointer]
def to_ptr
  @pointer
end
#to_s ⇒ Object
(lib/tensorflow/tensor.rb, lines 148–150)
# File 'lib/tensorflow/tensor.rb', line 148
# String form of the tensor; identical to #inspect.
#
# @return [String]
def to_s
  self.inspect
end
#value ⇒ Object
(lib/tensorflow/tensor.rb, lines 95–131)
# File 'lib/tensorflow/tensor.rb', line 95 def value value = case dtype when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64 data_pointer.send("read_array_of_#{dtype}", element_count) when :bfloat16 byte_str = data_pointer.read_bytes(element_count * 2) element_count.times.map { |i| "#{byte_str[(2 * i)..(2 * i + 1)]}\x00\x00".unpack1("g") } when :complex64 data_pointer.read_array_of_float(element_count * 2).each_slice(2).map { |v| Complex(*v) } when :complex128 data_pointer.read_array_of_double(element_count * 2).each_slice(2).map { |v| Complex(*v) } when :string # string tensor format # https://github.com/tensorflow/tensorflow/blob/5453aee48858fd375172d7ae22fad1557e8557d6/tensorflow/c/tf_tensor.h#L57 start_offset_size = element_count * 8 offsets = data_pointer.read_array_of_uint64(element_count) byte_size = FFI.TF_TensorByteSize(tensor_pointer) element_count.times.map do |i| str_len = (offsets[i + 1] || (byte_size - start_offset_size)) - offsets[i] str = (data_pointer + start_offset_size + offsets[i]).read_bytes(str_len) dst = ::FFI::MemoryPointer.new(:char, str.bytesize + 100) dst_len = ::FFI::MemoryPointer.new(:size_t) FFI.TF_StringDecode(str, str.bytesize, dst, dst_len, @status) check_status @status dst.read_pointer.read_bytes(dst_len.read_int32) end when :bool data_pointer.read_array_of_int8(element_count).map { |v| v == 1 } when :resource, :variant return data_pointer else raise "Unknown type: #{dtype}" end reshape(value, shape) end |