Class: TensorFlow::Tensor

Inherits:
Object
  • Object
show all
Defined in:
lib/tensorflow/tensor.rb

Class Method Summary

Instance Method Summary

Constructor Details

#initialize(value = nil, pointer: nil, dtype: nil, shape: nil) ⇒ Tensor



3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# File 'lib/tensorflow/tensor.rb', line 3

# Builds a tensor. When +pointer:+ is given it is wrapped as-is; otherwise
# +value+ is copied into a new TF_Tensor and an eager handle is created for it
# with TFE_NewTensorHandle.
#
# @param value [Object] scalar or (nested) Array of elements; unused when +pointer:+ is given
# @param pointer [FFI::Pointer, nil] existing tensor handle to wrap
# @param dtype [Symbol, nil] :string, :float or :int32; inferred via Utils.infer_type when nil
# @param shape [Array<Integer>, nil] dimensions; computed via calculate_shape when nil
# @raise [RuntimeError] "Unknown type: ..." for any other dtype
def initialize(value = nil, pointer: nil, dtype: nil, shape: nil)
  @status = FFI.TF_NewStatus

  if pointer
    @pointer = pointer
  else
    data = Array(value)
    shape ||= calculate_shape(value)

    # Rank-0 (scalar) tensors are created without a dims array.
    if shape.size > 0
      dims_ptr = ::FFI::MemoryPointer.new(:int64, shape.size)
      dims_ptr.write_array_of_int64(shape)
    else
      dims_ptr = nil
    end

    data = data.flatten

    dtype ||= Utils.infer_type(value)
    type = FFI::DataType[dtype]
    case dtype
    when :string
      data_ptr = string_ptr(data)
    when :float
      data_ptr = ::FFI::MemoryPointer.new(:float, data.size)
      data_ptr.write_array_of_float(data)
    when :int32
      data_ptr = ::FFI::MemoryPointer.new(:int32, data.size)
      data_ptr.write_array_of_int32(data)
    else
      raise "Unknown type: #{dtype}"
    end

    callback = ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |_data, _len, _arg|
      # FFI handles deallocation
    end

    # TF_NewTensor may use data_ptr's memory in place and invoke the
    # deallocator later, after this method has returned. Keep both reachable
    # from this object so Ruby's GC cannot free the buffer or the
    # FFI::Function while TensorFlow still points at them.
    # NOTE(review): this only covers the lifetime of this Ruby object; confirm
    # no TF handle outlives it while still sharing the buffer.
    @data_ptr = data_ptr
    @deallocator = callback

    tensor = FFI.TF_NewTensor(type, dims_ptr, shape.size, data_ptr, data_ptr.size, callback, nil)
    @pointer = FFI.TFE_NewTensorHandle(tensor, @status)
    check_status @status
  end

  # TODO fix segfault
  # ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
end

Class Method Details

.finalize(pointer) ⇒ Object



147
148
149
150
# File 'lib/tensorflow/tensor.rb', line 147

def self.finalize(pointer)
  # Returns a non-lambda Proc for ObjectSpace.define_finalizer. Finalizers
  # are invoked with the object id as an argument, and a zero-arity lambda
  # would reject it, so a plain Proc (not a stabby lambda) is required.
  Proc.new do
    FFI.TFE_DeleteTensorHandle(pointer)
  end
end

Instance Method Details

#%(other) ⇒ Object



65
66
67
# File 'lib/tensorflow/tensor.rb', line 65

# Operator form of TensorFlow.floormod: `tensor % rhs`.
def %(rhs)
  TensorFlow.floormod(self, rhs)
end

#*(other) ⇒ Object



57
58
59
# File 'lib/tensorflow/tensor.rb', line 57

# Operator form of TensorFlow.multiply: `tensor * rhs`.
def *(rhs)
  TensorFlow.multiply(self, rhs)
end

#+(other) ⇒ Object



49
50
51
# File 'lib/tensorflow/tensor.rb', line 49

# Operator form of TensorFlow.add: `tensor + rhs`.
def +(rhs)
  TensorFlow.add(self, rhs)
end

#-(other) ⇒ Object



53
54
55
# File 'lib/tensorflow/tensor.rb', line 53

# Operator form of TensorFlow.subtract: `tensor - rhs`.
def -(rhs)
  TensorFlow.subtract(self, rhs)
end

#/(other) ⇒ Object



61
62
63
# File 'lib/tensorflow/tensor.rb', line 61

# Operator form of TensorFlow.divide: `tensor / rhs`.
def /(rhs)
  TensorFlow.divide(self, rhs)
end

#data_pointer ⇒ Object



96
97
98
99
100
# File 'lib/tensorflow/tensor.rb', line 96

# Resolves the eager handle with TFE_TensorHandleResolve (raising through
# check_status on failure) and returns TF_TensorData of the result, i.e. a
# pointer to the tensor's raw element buffer.
# NOTE(review): the resolved TF_Tensor is never deleted here; confirm whether
# that leaks on every call.
def data_pointer
  resolved = FFI.TFE_TensorHandleResolve(@pointer, @status)
  check_status(@status)
  FFI.TF_TensorData(resolved)
end

#dtype ⇒ Object



75
76
77
# File 'lib/tensorflow/tensor.rb', line 75

# Element type symbol (e.g. :float, :int32, :string), looked up in
# FFI::DataType from the handle's TF data type. Memoized after the first
# truthy lookup.
def dtype
  return @dtype if @dtype

  raw_type = FFI.TFE_TensorHandleDataType(@pointer)
  @dtype = FFI::DataType[raw_type]
end

#element_count ⇒ Object



79
80
81
82
83
# File 'lib/tensorflow/tensor.rb', line 79

# Total number of elements, from TFE_TensorHandleNumElements; the status is
# checked before the count is returned.
def element_count
  FFI.TFE_TensorHandleNumElements(@pointer, @status).tap { check_status(@status) }
end

#inspect ⇒ Object



142
143
144
145
# File 'lib/tensorflow/tensor.rb', line 142

# Debug representation, e.g. "#<TensorFlow::Tensor value: ..., shape: ..., dtype: ...>".
def inspect
  details = %i[value shape dtype].map { |attr| format("%s: %s", attr, send(attr).inspect) }.join(", ")
  "#<#{self.class} #{details}>"
end

#num_dims ⇒ Object



69
70
71
72
73
# File 'lib/tensorflow/tensor.rb', line 69

# Rank of the tensor, from TFE_TensorHandleNumDims; raises via check_status
# if TensorFlow reports an error.
def num_dims
  rank = FFI.TFE_TensorHandleNumDims(@pointer, @status)
  check_status(@status)
  rank
end

#shape ⇒ Object



85
86
87
88
89
90
91
92
93
94
# File 'lib/tensorflow/tensor.rb', line 85

# Dimension sizes, one TFE_TensorHandleDim query per axis (status checked
# after each). The resulting Array is memoized.
def shape
  @shape ||= num_dims.times.map do |axis|
    extent = FFI.TFE_TensorHandleDim(@pointer, axis, @status)
    check_status @status
    extent
  end
end

#to_a ⇒ Object



138
139
140
# File 'lib/tensorflow/tensor.rb', line 138

# Returns exactly what #value returns (the reshaped elements).
# NOTE(review): if reshape yields a bare scalar for rank-0 tensors, this
# returns a non-Array, and Kernel#Array(tensor) / splatting would raise
# TypeError — confirm against reshape.
def to_a
  value
end

#to_i ⇒ Object



134
135
136
# File 'lib/tensorflow/tensor.rb', line 134

# Integer conversion of #value (via the value's own #to_i).
def to_i
  scalar = value
  scalar.to_i
end

#to_ptr ⇒ Object



102
103
104
# File 'lib/tensorflow/tensor.rb', line 102

# The underlying tensor handle pointer. Ruby-FFI calls #to_ptr on objects
# passed where a :pointer argument is expected, so a Tensor can be handed
# straight to FFI functions.
def to_ptr
  @pointer
end

#to_s ⇒ Object



130
131
132
# File 'lib/tensorflow/tensor.rb', line 130

# String form; deliberately the same text as #inspect.
def to_s
  inspect
end

#value ⇒ Object



106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# File 'lib/tensorflow/tensor.rb', line 106

# Reads the tensor's elements back into Ruby, reshaped by #shape.
#
# @return [Object] elements passed through +reshape+: Floats for :float,
#   Integers for :int32, (binary) Strings for :string, true/false for :bool.
#   For :resource tensors the raw data pointer is returned without reshaping.
# @raise [RuntimeError] "Unknown type: ..." for any other dtype
def value
  elements =
    case dtype
    when :float
      data_pointer.read_array_of_float(element_count)
    when :int32
      data_pointer.read_array_of_int32(element_count)
    when :string
      # Fetch the pointer and count once: each data_pointer call re-runs
      # TFE_TensorHandleResolve.
      read_string_elements(data_pointer, element_count)
    when :bool
      data_pointer.read_array_of_int8(element_count).map { |v| v == 1 }
    when :resource
      return data_pointer
    else
      raise "Unknown type: #{dtype}"
    end

  reshape(elements, shape)
end

# Decodes +count+ elements of a TF_STRING buffer. Layout (tf_tensor.h):
#   start_offset: array[uint64] of +count+ entries
#   data:         byte[...]; at data[start_offset[i]] the element's byte
#                 length is stored as a varint64, followed by the bytes.
# Elements are not NUL-terminated, so the length prefix must be decoded
# rather than reading up to a NUL byte.
private def read_string_elements(ptr, count)
  data_start = count * 8
  ptr.read_array_of_uint64(count).map do |offset|
    pos = data_start + offset
    length = 0
    shift = 0
    loop do
      byte = ptr.get_uint8(pos)
      pos += 1
      length |= (byte & 0x7f) << shift
      break if byte < 0x80

      shift += 7
    end
    ptr.get_bytes(pos, length)
  end
end