Class: TensorFlow::Tensor
- Inherits: Object
  - Object
    - TensorFlow::Tensor
- Defined in: lib/tensorflow/tensor.rb
Class Method Summary
- .finalize(pointer, status, tensor) ⇒ Object
Instance Method Summary
- #%(other) ⇒ Object
- #*(other) ⇒ Object
- #+(other) ⇒ Object
- #-(other) ⇒ Object
- #-@ ⇒ Object
- #/(other) ⇒ Object
- #dtype ⇒ Object
- #initialize(value = nil, dtype: nil, shape: nil, pointer: nil) ⇒ Tensor (constructor): A new instance of Tensor.
- #inspect ⇒ Object
- #numo ⇒ Object
- #shape ⇒ Object
- #to_a ⇒ Object
- #to_i ⇒ Object
- #to_ptr ⇒ Object
- #to_s ⇒ Object
- #value ⇒ Object
Constructor Details
#initialize(value = nil, dtype: nil, shape: nil, pointer: nil) ⇒ Tensor
Returns a new instance of Tensor.
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
# File 'lib/tensorflow/tensor.rb', line 3
# Creates a tensor, either by wrapping an existing eager tensor handle or by
# copying +value+ into a freshly allocated TF_Tensor.
#
# @param value [Object] scalar, Array, or Numo::NArray to copy; unused when +pointer+ is given
# @param dtype [Symbol, nil] element type; inferred with Utils.infer_type when nil
# @param shape [Array<Integer>, nil] dimensions; computed with calculate_shape(value) when nil
# @param pointer [Object, nil] existing TFE tensor handle to wrap instead of copying +value+
# @raise [RuntimeError] "Unknown type: ..." when +dtype+ has no encoder below;
#   check_status may also raise if TFE_NewTensorHandle reports an error
def initialize(value = nil, dtype: nil, shape: nil, pointer: nil)
  @status = FFI.TF_NewStatus
  # Only set when this object creates its own TF_Tensor; the finalizer skips nil.
  tensor = nil

  if pointer
    @pointer = pointer
  else
    data = value
    data = Array(data) unless data.is_a?(Array) || data.is_a?(Numo::NArray)
    # The shape comes from the original +value+, not the Array()-wrapped copy
    # (presumably so a scalar is not reported as shape [1] — confirm in calculate_shape).
    shape ||= calculate_shape(value)

    # A rank-0 tensor is described with a null dims pointer.
    if shape.size > 0
      dims_ptr = ::FFI::MemoryPointer.new(:int64, shape.size)
      dims_ptr.write_array_of_int64(shape)
    else
      dims_ptr = nil
    end

    if data.is_a?(Numo::NArray)
      dtype ||= Utils.infer_type(data)
      # TODO use Numo read pointer?
      data_ptr = ::FFI::MemoryPointer.new(:uchar, data.byte_size)
      data_ptr.write_bytes(data.to_string)
    else
      data = data.flatten
      dtype ||= Utils.infer_type(data)
      case dtype
      when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
        # FFI writers are named after the dtype symbol (write_array_of_float, ...).
        data_ptr = ::FFI::MemoryPointer.new(dtype, data.size)
        data_ptr.send("write_array_of_#{dtype}", data)
      when :bfloat16
        # https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
        # Keeps the first two bytes of each value's big-endian float32 encoding ("g").
        data_ptr = ::FFI::MemoryPointer.new(:int8, data.size * 2)
        data_ptr.write_bytes(data.map { |v| [v].pack("g")[0..1] }.join)
      when :complex64
        # Stored as interleaved (real, imaginary) float pairs.
        data_ptr = ::FFI::MemoryPointer.new(:float, data.size * 2)
        data_ptr.write_array_of_float(data.flat_map { |v| [v.real, v.imaginary] })
      when :complex128
        # Stored as interleaved (real, imaginary) double pairs.
        data_ptr = ::FFI::MemoryPointer.new(:double, data.size * 2)
        data_ptr.write_array_of_double(data.flat_map { |v| [v.real, v.imaginary] })
      when :string
        data_ptr = string_ptr(data)
      when :bool
        # One byte per element: 1 for truthy, 0 for falsy.
        data_ptr = ::FFI::MemoryPointer.new(:int8, data.size)
        data_ptr.write_array_of_int8(data.map { |v| v ? 1 : 0 })
      else
        raise "Unknown type: #{dtype}"
      end
    end

    type = FFI::DataType[dtype]
    # Deallocator handed to TF_NewTensor. It is intentionally a no-op: the buffer
    # is a Ruby-owned MemoryPointer kept alive through @data_ptr below. Its
    # parameters are unused, hence the underscores (no shadowing of +data+).
    callback = ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |_data, _len, _arg|
      # FFI handles deallocation
    end

    # keep data pointer alive for duration of object
    @data_ptr = data_ptr
    @dims_ptr = dims_ptr
    @callback = callback

    tensor = FFI.TF_NewTensor(type, dims_ptr, shape.size, data_ptr, data_ptr.size, callback, nil)
    @pointer = FFI.TFE_NewTensorHandle(tensor, @status)
    check_status @status
  end

  ObjectSpace.define_finalizer(self, self.class.finalize(@pointer, @status, tensor))
end
Class Method Details
.finalize(pointer, status, tensor) ⇒ Object
165 166 167 168 169 170 171 172 |
# File 'lib/tensorflow/tensor.rb', line 165
# Builds the finalizer that #initialize registers with ObjectSpace.define_finalizer.
#
# Releases the eager tensor handle, then the status object, then (only when the
# instance created one) the underlying TF_Tensor. Each destructor is the C API
# counterpart of the constructor used in #initialize:
# TF_NewStatus -> TF_DeleteStatus, TF_NewTensor -> TF_DeleteTensor,
# TFE_NewTensorHandle -> TFE_DeleteTensorHandle.
#
# @param pointer [Object] tensor handle to delete with TFE_DeleteTensorHandle
# @param status [Object] status object to delete with TF_DeleteStatus
# @param tensor [Object, nil] TF_Tensor to delete, or nil when an existing handle was wrapped
# @return [Proc] finalizer proc
def self.finalize(pointer, status, tensor)
  # must use proc instead of stabby lambda: Ruby calls finalizers with the
  # object id, and a proc (unlike a lambda) ignores that extra argument.
  # Building it in a class method keeps the proc from capturing the instance,
  # which would otherwise prevent the object from ever being collected.
  proc do
    FFI.TFE_DeleteTensorHandle(pointer)
    FFI.TF_DeleteStatus(status)
    FFI.TF_DeleteTensor(tensor) if tensor
  end
end
Instance Method Details
#%(other) ⇒ Object
87 88 89 |
# File 'lib/tensorflow/tensor.rb', line 87
# Modulo operator: forwards to Math.floormod with this tensor on the left.
#
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] the result of Math.floormod(self, other)
def %(other)
  Math.floormod(self, other)
end
#*(other) ⇒ Object
79 80 81 |
# File 'lib/tensorflow/tensor.rb', line 79
# Multiplication operator: forwards to Math.multiply with this tensor on the left.
#
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] the result of Math.multiply(self, other)
def *(other)
  Math.multiply(self, other)
end
#+(other) ⇒ Object
71 72 73 |
# File 'lib/tensorflow/tensor.rb', line 71
# Addition operator: forwards to Math.add with this tensor on the left.
#
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] the result of Math.add(self, other)
def +(other)
  Math.add(self, other)
end
#-(other) ⇒ Object
75 76 77 |
# File 'lib/tensorflow/tensor.rb', line 75
# Subtraction operator: forwards to Math.subtract with this tensor on the left.
#
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] the result of Math.subtract(self, other)
def -(other)
  Math.subtract(self, other)
end
#-@ ⇒ Object
91 92 93 |
# File 'lib/tensorflow/tensor.rb', line 91
# Unary minus: forwards to Math.negative.
#
# @return [Object] the result of Math.negative(self)
def -@
  Math.negative(self)
end
#/(other) ⇒ Object
83 84 85 |
# File 'lib/tensorflow/tensor.rb', line 83
# Division operator: forwards to Math.divide with this tensor on the left.
#
# @param other [Object] right-hand operand, passed through unchanged
# @return [Object] the result of Math.divide(self, other)
def /(other)
  Math.divide(self, other)
end
#dtype ⇒ Object
123 124 125 |
# File 'lib/tensorflow/tensor.rb', line 123
# Element type of the tensor. It is looked up from the handle on first use and
# cached in @dtype; a falsy lookup result is not cached.
#
# @return [Object] FFI::DataType entry for TFE_TensorHandleDataType(@pointer);
#   #value compares it against symbols such as :float and :int32
def dtype
  return @dtype if @dtype

  @dtype = FFI::DataType[FFI.TFE_TensorHandleDataType(@pointer)]
end
#inspect ⇒ Object
160 161 162 163 |
# File 'lib/tensorflow/tensor.rb', line 160
# Debug representation listing the numo, shape, and dtype of the tensor,
# e.g. "#<TensorFlow::Tensor numo: ..., shape: [...], dtype: ...>".
#
# @return [String]
def inspect
  fields = %w(numo shape dtype).map do |attr|
    "#{attr}: #{send(attr).inspect}"
  end
  "#<#{self.class} #{fields.join(", ")}>"
end
#numo ⇒ Object
154 155 156 157 158 |
# File 'lib/tensorflow/tensor.rb', line 154
# Converts the tensor's #value into the Numo array class mapped to its #dtype.
#
# @return [Object] result of the mapped class's .cast(value)
# @raise [RuntimeError] "Unknown type: ..." when Utils::NUMO_TYPE_MAP has no entry for #dtype
def numo
  numo_class = Utils::NUMO_TYPE_MAP[dtype]
  return numo_class.cast(value) if numo_class

  raise "Unknown type: #{dtype}"
end
#shape ⇒ Object
127 128 129 130 131 132 133 134 135 136 |
# File 'lib/tensorflow/tensor.rb', line 127
# Dimensions of the tensor, read once from the handle and cached in @shape.
# Each TFE_TensorHandleDim call is followed by check_status, so a failure on
# any axis raises before anything is cached.
#
# @return [Array] one entry per dimension, in axis order
def shape
  return @shape if @shape

  dims = num_dims.times.map do |axis|
    dim = FFI.TFE_TensorHandleDim(@pointer, axis, @status)
    check_status @status
    dim
  end
  @shape = dims
end
#to_a ⇒ Object
146 147 148 |
# File 'lib/tensorflow/tensor.rb', line 146
# Alias-style accessor: returns exactly what #value returns.
#
# @return [Object] the decoded tensor contents
def to_a
  value
end
#to_i ⇒ Object
142 143 144 |
# File 'lib/tensorflow/tensor.rb', line 142
# Integer conversion of the tensor's #value (via the value's own #to_i).
#
# @return [Integer]
def to_i
  scalar = value
  scalar.to_i
end
#to_ptr ⇒ Object
150 151 152 |
# File 'lib/tensorflow/tensor.rb', line 150
# Returns the wrapped tensor handle stored in @pointer (presumably so a Tensor
# can be passed directly where Ruby-FFI expects a pointer argument).
#
# @return [Object] the handle pointer
def to_ptr
  @pointer
end
#to_s ⇒ Object
138 139 140 |
# File 'lib/tensorflow/tensor.rb', line 138
# String form of the tensor; identical to #inspect.
#
# @return [String]
def to_s
  inspect
end
#value ⇒ Object
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
# File 'lib/tensorflow/tensor.rb', line 95 def value value = case dtype when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64 data_pointer.send("read_array_of_#{dtype}", element_count) when :bfloat16 byte_str = data_pointer.read_bytes(element_count * 2) element_count.times.map { |i| "#{byte_str[(2 * i)..(2 * i + 1)]}\x00\x00".unpack1("g") } when :complex64 data_pointer.read_array_of_float(element_count * 2).each_slice(2).map { |v| Complex(*v) } when :complex128 data_pointer.read_array_of_double(element_count * 2).each_slice(2).map { |v| Complex(*v) } when :string tf_string_size = 24 element_count.times.map do |i| FFI.TF_StringGetDataPointer(data_pointer + i * tf_string_size) end when :bool data_pointer.read_array_of_int8(element_count).map { |v| v == 1 } when :resource, :variant return data_pointer else raise "Unknown type: #{dtype}" end reshape(value, shape) end |