Class: TensorFlow::Tensor
- Inherits: Object
  - Object
    - TensorFlow::Tensor
- Defined in: lib/tensorflow/tensor.rb
Class Method Summary
- .finalize(pointer) ⇒ Object
Instance Method Summary
- #%(other) ⇒ Object
- #*(other) ⇒ Object
- #+(other) ⇒ Object
- #-(other) ⇒ Object
- #/(other) ⇒ Object
- #data_pointer ⇒ Object
- #dtype ⇒ Object
- #element_count ⇒ Object
- #initialize(value = nil, pointer: nil, dtype: nil, shape: nil) ⇒ Tensor — constructor: returns a new instance of Tensor.
- #inspect ⇒ Object
- #num_dims ⇒ Object
- #shape ⇒ Object
- #to_a ⇒ Object
- #to_i ⇒ Object
- #to_ptr ⇒ Object
- #to_s ⇒ Object
- #value ⇒ Object
Constructor Details
#initialize(value = nil, pointer: nil, dtype: nil, shape: nil) ⇒ Tensor
(source: lib/tensorflow/tensor.rb, lines 3–47)
# File 'lib/tensorflow/tensor.rb', line 3
# Creates a tensor, either by wrapping an existing eager tensor handle
# (+pointer:+) or by copying +value+ into a newly allocated native buffer
# and building a TFE tensor handle around it.
#
# @param value [Object] scalar or (nested) Array of elements; ignored when +pointer:+ is given
# @param pointer [Object, nil] existing tensor handle to wrap as-is
# @param dtype [Symbol, nil] :string, :float or :int32; inferred with Utils.infer_type(value) when nil
# @param shape [Array<Integer>, nil] dimensions; computed with calculate_shape(value) when nil
# @raise [RuntimeError] "Unknown type: ..." for any other dtype
def initialize(value = nil, pointer: nil, dtype: nil, shape: nil)
  @status = FFI.TF_NewStatus
  if pointer
    @pointer = pointer
  else
    data = Array(value)
    shape ||= calculate_shape(value)
    if shape.empty?
      # scalar: TF_NewTensor receives no dims array
      dims_ptr = nil
    else
      dims_ptr = ::FFI::MemoryPointer.new(:int64, shape.size)
      dims_ptr.write_array_of_int64(shape)
    end
    data = data.flatten
    dtype ||= Utils.infer_type(value)
    type = FFI::DataType[dtype]
    case dtype
    when :string
      data_ptr = string_ptr(data)
    when :float
      data_ptr = ::FFI::MemoryPointer.new(:float, data.size)
      data_ptr.write_array_of_float(data)
    when :int32
      data_ptr = ::FFI::MemoryPointer.new(:int32, data.size)
      data_ptr.write_array_of_int32(data)
    else
      raise "Unknown type: #{dtype}"
    end
    # TF_NewTensor may keep referencing data_ptr and calls the deallocator
    # later. Both are Ruby-managed objects, so hold them in ivars: otherwise
    # GC can free the buffer, or the callback trampoline, while TensorFlow
    # still uses them.
    @data_ptr = data_ptr
    @deallocator = ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |_data, _len, _arg|
      # no-op: the buffer belongs to @data_ptr and is freed by Ruby GC
    end
    tensor = FFI.TF_NewTensor(type, dims_ptr, shape.size, data_ptr, data_ptr.size, @deallocator, nil)
    @pointer = FFI.TFE_NewTensorHandle(tensor, @status)
    check_status @status
    # NOTE(review): +tensor+ is never passed to TF_DeleteTensor here; the TF
    # C API examples delete it after TFE_NewTensorHandle. Confirm that the FFI
    # bindings expose TF_DeleteTensor before adding that call.
  end

  # TODO fix segfault
  # ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
end
Class Method Details
.finalize(pointer) ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 147–150)
# File 'lib/tensorflow/tensor.rb', line 147
# Builds a finalizer that deletes the given eager tensor handle.
#
# ObjectSpace calls finalizers with the object id as an argument. A lambda
# would reject that argument, so this returns a plain (non-lambda) Proc.
# Because this is a class method, the Proc closes over +pointer+ only and
# not over the Tensor instance.
#
# @param pointer [Object] tensor handle passed to FFI.TFE_DeleteTensorHandle
# @return [Proc]
def self.finalize(pointer)
  Proc.new { |_object_id| FFI.TFE_DeleteTensorHandle(pointer) }
end
Instance Method Details
#%(other) ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 65–67)
# File 'lib/tensorflow/tensor.rb', line 65
# Floor modulus of this tensor by +other+.
#
# @param other right-hand operand, forwarded unchanged
# @return the result of TensorFlow.floormod(self, other)
def %(other)
  TensorFlow.floormod(self, other)
end
#*(other) ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 57–59)
# File 'lib/tensorflow/tensor.rb', line 57
# Product of this tensor and +other+.
#
# @param other right-hand operand, forwarded unchanged
# @return the result of TensorFlow.multiply(self, other)
def *(other)
  TensorFlow.multiply(self, other)
end
#+(other) ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 49–51)
# File 'lib/tensorflow/tensor.rb', line 49
# Sum of this tensor and +other+.
#
# @param other right-hand operand, forwarded unchanged
# @return the result of TensorFlow.add(self, other)
def +(other)
  TensorFlow.add(self, other)
end
#-(other) ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 53–55)
# File 'lib/tensorflow/tensor.rb', line 53
# Difference of this tensor and +other+.
#
# @param other right-hand operand, forwarded unchanged
# @return the result of TensorFlow.subtract(self, other)
def -(other)
  TensorFlow.subtract(self, other)
end
#/(other) ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 61–63)
# File 'lib/tensorflow/tensor.rb', line 61
# Quotient of this tensor by +other+.
#
# @param other right-hand operand, forwarded unchanged
# @return the result of TensorFlow.divide(self, other)
def /(other)
  TensorFlow.divide(self, other)
end
#data_pointer ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 96–100)
# File 'lib/tensorflow/tensor.rb', line 96
# Resolves the eager handle to a TF_Tensor and returns its data address.
#
# NOTE(review): the TF_Tensor from TFE_TensorHandleResolve is never passed to
# TF_DeleteTensor here, so each call presumably leaks that wrapper. Confirm
# this against the TF eager C API.
#
# @return pointer to the raw element buffer (layout depends on #dtype)
# @raise whatever check_status raises when resolution fails
def data_pointer
  resolved = FFI.TFE_TensorHandleResolve(@pointer, @status).tap { check_status @status }
  FFI.TF_TensorData(resolved)
end
#dtype ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 75–77)
# File 'lib/tensorflow/tensor.rb', line 75
# Element type of the tensor as looked up in FFI::DataType (e.g. :float,
# :int32). It is queried from the handle once and then cached.
def dtype
  return @dtype if @dtype

  @dtype = FFI::DataType[FFI.TFE_TensorHandleDataType(@pointer)]
end
#element_count ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 79–83)
# File 'lib/tensorflow/tensor.rb', line 79
# Total number of elements in the tensor.
#
# @return [Integer] as reported by TFE_TensorHandleNumElements
# @raise whatever check_status raises when the query fails
def element_count
  FFI.TFE_TensorHandleNumElements(@pointer, @status).tap { check_status @status }
end
#inspect ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 142–145)
# File 'lib/tensorflow/tensor.rb', line 142
# Debug representation listing the tensor's value, shape and dtype, e.g.
# "#<TensorFlow::Tensor value: [1, 2], shape: [2], dtype: :int32>".
def inspect
  fields = %w(value shape dtype).map do |attribute|
    "#{attribute}: #{send(attribute).inspect}"
  end
  "#<#{self.class} #{fields.join(", ")}>"
end
#num_dims ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 69–73)
# File 'lib/tensorflow/tensor.rb', line 69
# Rank of the tensor (number of dimensions).
#
# @return [Integer] as reported by TFE_TensorHandleNumDims
# @raise whatever check_status raises when the query fails
def num_dims
  FFI.TFE_TensorHandleNumDims(@pointer, @status).tap { check_status @status }
end
#shape ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 85–94)
# File 'lib/tensorflow/tensor.rb', line 85
# Size of each dimension, in axis order. It is queried once per axis on
# first use and then cached (an empty Array for a scalar is cached too).
#
# @return [Array<Integer>]
def shape
  return @shape if @shape

  dims = num_dims.times.map do |axis|
    FFI.TFE_TensorHandleDim(@pointer, axis, @status).tap { check_status @status }
  end
  @shape = dims
end
#to_a ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 138–140)
# File 'lib/tensorflow/tensor.rb', line 138
# Ruby representation of the tensor; identical to #value.
def to_a
  value
end
#to_i ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 134–136)
# File 'lib/tensorflow/tensor.rb', line 134
# Integer conversion of #value (meaningful for scalar tensors, whose
# value responds to +to_i+).
def to_i
  scalar = value
  scalar.to_i
end
#to_ptr ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 102–104)
# File 'lib/tensorflow/tensor.rb', line 102
# The underlying tensor handle. Presumably this lets a Tensor be passed
# directly where an FFI pointer argument is expected; confirm with callers.
def to_ptr
  @pointer
end
#to_s ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 130–132)
# File 'lib/tensorflow/tensor.rb', line 130
# String form of the tensor; identical to #inspect.
def to_s
  inspect
end
#value ⇒ Object
(source: lib/tensorflow/tensor.rb, lines 106–128)
# File 'lib/tensorflow/tensor.rb', line 106
# Copies the tensor's elements into Ruby objects, reshaped to #shape.
#
# The handle is resolved once, and the element count is read once, for the
# whole call. Earlier code re-resolved for every string element.
#
# @return [Array, Object] nested Array (or scalar, depending on reshape) of
#   Floats, Integers, Strings or booleans; for :resource tensors, the raw
#   data pointer is returned unshaped
# @raise [RuntimeError] "Unknown type: ..." for unsupported dtypes
def value
  kind = dtype
  ptr = data_pointer
  return ptr if kind == :resource

  count = element_count
  result =
    case kind
    when :float
      ptr.read_array_of_float(count)
    when :int32
      ptr.read_array_of_int32(count)
    when :string
      read_string_elements(ptr, count)
    when :bool
      ptr.read_array_of_int8(count).map { |v| v == 1 }
    else
      raise "Unknown type: #{kind}"
    end
  reshape(result, shape)
end

# Decodes a TF_STRING buffer of +count+ elements.
# string tensor format
# https://github.com/tensorflow/tensorflow/blob/5453aee48858fd375172d7ae22fad1557e8557d6/tensorflow/c/tf_tensor.h#L57
# The buffer is +count+ uint64 offsets followed by the data section.
# Element i starts at data[offsets[i]].
def read_string_elements(ptr, count)
  offsets = ptr.read_array_of_uint64(count)
  data_start = ptr + count * 8
  offsets.map { |offset| decode_varint_string(data_start + offset) }
end

# Reads one element encoded as a varint64 byte length followed by exactly
# that many bytes. The data is not NUL-terminated, so read_string would
# include the length byte and run into the next elements.
def decode_varint_string(ptr)
  length = 0
  shift = 0
  pos = 0
  loop do
    byte = ptr.get_uint8(pos)
    pos += 1
    length |= (byte & 0x7f) << shift
    break if byte < 0x80

    shift += 7
    raise "Malformed string tensor element" if shift > 63
  end
  ptr.get_bytes(pos, length)
end
private :read_string_elements, :decode_varint_string