Class: Torch::Tensor

Inherits:
Object
  • Object
show all
Includes:
Comparable, Inspector
Defined in:
lib/torch/tensor.rb

Direct Known Subclasses

NN::Parameter

Class Method Summary collapse

Instance Method Summary collapse

Methods included from Inspector

#inspect

Class Method Details

.new(*args) ⇒ Object



10
11
12
# File 'lib/torch/tensor.rb', line 10

# Tensor.new is a factory: every argument is handed unchanged to
# FloatTensor.new, so the object returned is a FloatTensor.
#
# @param args [Array] forwarded verbatim to FloatTensor.new
# @return whatever FloatTensor.new returns
def self.new(*args)
  FloatTensor.new(*args)
end

Instance Method Details

#%(other) ⇒ Object



144
145
146
# File 'lib/torch/tensor.rb', line 144

# Modulo operator (`tensor % other`).
#
# @param other right-hand operand, passed straight to #remainder
# @return the result of #remainder
def %(other)
  remainder(other)
end

#*(other) ⇒ Object



136
137
138
# File 'lib/torch/tensor.rb', line 136

# Multiplication operator (`tensor * other`).
#
# @param other right-hand operand, passed straight to #mul
# @return the result of #mul
def *(other)
  mul(other)
end

#**(other) ⇒ Object



148
149
150
# File 'lib/torch/tensor.rb', line 148

# Exponentiation operator (`tensor ** other`).
#
# @param other exponent, passed straight to #pow
# @return the result of #pow
def **(other)
  pow(other)
end

#+(other) ⇒ Object



128
129
130
# File 'lib/torch/tensor.rb', line 128

# Addition operator (`tensor + other`).
#
# @param other right-hand operand, passed straight to #add
# @return the result of #add
def +(other)
  add(other)
end

#-(other) ⇒ Object



132
133
134
# File 'lib/torch/tensor.rb', line 132

# Subtraction operator (`tensor - other`).
#
# @param other right-hand operand, passed straight to #sub
# @return the result of #sub
def -(other)
  sub(other)
end

#-@ ⇒ Object



152
153
154
# File 'lib/torch/tensor.rb', line 152

# Unary minus (`-tensor`).
#
# @return the result of #neg
def -@
  neg
end

#/(other) ⇒ Object



140
141
142
# File 'lib/torch/tensor.rb', line 140

# Division operator (`tensor / other`).
#
# @param other divisor, passed straight to #div
# @return the result of #div
def /(other)
  div(other)
end

#<=>(other) ⇒ Object

TODO better compare?



157
158
159
# File 'lib/torch/tensor.rb', line 157

# Ordering used by Comparable: compares this tensor's scalar value
# (#item) with `other`. #item raises Error unless the tensor has exactly
# one element, so this only works for one-element tensors.
#
# TODO better compare?
#
# @param other [Numeric, Tensor] value to compare against; a Tensor is
#   unwrapped with #item first
# @return [Integer, nil] -1, 0 or 1, or nil when the values are not comparable
def <=>(other)
  # Numeric#<=> returns nil for a Tensor operand, which made tensor-vs-tensor
  # comparisons (t1 < t2) raise; compare the underlying scalars instead.
  other = other.item if other.is_a?(Tensor)
  item <=> other
end

#[](*indexes) ⇒ Object

based on python_variable_indexing.cpp



162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# File 'lib/torch/tensor.rb', line 162

# Basic indexing, modeled on PyTorch's python_variable_indexing.cpp.
# Each index consumes (or inserts) one dimension, left to right:
#   * Numeric       - selects that position; the dimension is removed
#   * Range         - slices with step 1; a missing begin means 0, a missing
#                     end means "to the end of the dimension", and inclusive
#                     negative ends (e.g. 0..-1) are resolved against the
#                     dimension size
#   * nil / true    - inserts a new size-1 dimension
# TODO handle false
#
# @param indexes [Array<Numeric, Range, nil, true>]
# @return the indexed tensor
# @raise [Error] for any other index type
def [](*indexes)
  result = self
  dim = 0
  indexes.each do |index|
    case index
    when Numeric
      # select drops the dimension, so the next index targets the same dim
      result = result._select_int(dim, index)
    when Range
      start = index.begin || 0
      finish = index.end
      if finish.nil?
        finish = result.size(dim)
      elsif !index.exclude_end?
        # convert an inclusive end to an exclusive one; a bare `+ 1` would
        # turn -1 into 0 and produce an empty slice
        finish += result.size(dim) if finish < 0
        finish += 1
      end
      result = result._slice_tensor(dim, start, finish, 1)
      dim += 1
    when nil, true
      # PyTorch advances past the inserted axis for both None and True
      result = result.unsqueeze(dim)
      dim += 1
    else
      raise Error, "Unsupported index type: #{index.class.name}"
    end
  end
  result
end

#[]=(index, value) ⇒ Object

TODO based on python_variable_indexing.cpp

Raises:

  • (ArgumentError)


188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# File 'lib/torch/tensor.rb', line 188

# Assigns `value` into the first dimension of this tensor.
# TODO based on python_variable_indexing.cpp (only a single Numeric or
# Range index is supported so far).
#
# Range handling matches #[]: a missing begin means 0, a missing end means
# "to the end", and inclusive negative ends (e.g. 0..-1) are resolved
# against the size of dimension 0.
#
# @param index [Numeric, Range]
# @param value [Tensor, Object] non-Tensor values are wrapped with Torch.tensor
# @raise [ArgumentError] when value is nil (deletion is not supported)
# @raise [Error] for unsupported index types
def []=(index, value)
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?

  value = Torch.tensor(value) unless value.is_a?(Tensor)

  if index.is_a?(Numeric)
    copy_to(_select_int(0, index), value)
  elsif index.is_a?(Range)
    start = index.begin || 0
    finish = index.end
    if finish.nil?
      finish = size(0)
    elsif !index.exclude_end?
      # a bare `+ 1` would turn an inclusive -1 into 0 (an empty slice)
      finish += size(0) if finish < 0
      finish += 1
    end
    copy_to(_slice_tensor(0, start, finish, 1), value)
  else
    raise Error, "Unsupported index type: #{index.class.name}"
  end
end

#add!(value = 1, other) ⇒ Object

Note: for some methods the argument order is swapped — the optional `value` comes before `other`.



207
208
209
210
211
212
213
# File 'lib/torch/tensor.rb', line 207

# In-place add. Unlike most methods, the optional `value` (default 1) comes
# BEFORE `other`. It is forwarded as the second argument of the native
# call (presumably PyTorch's `alpha` scale -- confirm against the binding).
#
# @param value [Numeric] forwarded to the native call, default 1
# @param other [Numeric, Tensor] addend; selects the scalar or tensor overload
# @return the result of the chosen native call
def add!(value = 1, other)
  other.is_a?(Numeric) ? _add__scalar(other, value) : _add__tensor(other, value)
end

#backward(gradient = nil) ⇒ Object



92
93
94
# File 'lib/torch/tensor.rb', line 92

# Runs the backward pass starting from this tensor.
#
# @param gradient optional gradient argument (nil by default), handed
#   unchanged to the native #_backward
# @return the result of #_backward
def backward(gradient = nil)
  _backward(gradient)
end

#cpu ⇒ Object



47
48
49
# File 'lib/torch/tensor.rb', line 47

# Shorthand for `to("cpu")`.
#
# @return the result of #to
def cpu
  to("cpu")
end

#cuda ⇒ Object



51
52
53
# File 'lib/torch/tensor.rb', line 51

# Shorthand for `to("cuda")`.
#
# @return the result of #to
def cuda
  to("cuda")
end

#dtype ⇒ Object

Raises:

  • (Error)



14
15
16
17
18
# File 'lib/torch/tensor.rb', line 14

# Maps the native dtype enum (#_dtype) to its entry in ENUM_TO_DTYPE.
#
# @return the ENUM_TO_DTYPE value for this tensor's enum
# @raise [Error] when the enum has no (truthy) mapping
def dtype
  ENUM_TO_DTYPE[_dtype] || raise(Error, "Unknown type: #{_dtype}")
end

#item ⇒ Object



72
73
74
75
76
77
# File 'lib/torch/tensor.rb', line 72

# Returns the single element of a one-element tensor as a Ruby scalar.
#
# @return the sole element of the flat data
# @raise [Error] unless numel == 1
def item
  if numel != 1
    raise Error, "only one element tensors can be converted to Ruby scalars"
  end
  # Read the flat buffer directly: to_a nests by shape, so for shapes
  # like [1, 1] to_a.first was an Array rather than the scalar.
  _flat_data.first
end

#layout ⇒ Object



20
21
22
# File 'lib/torch/tensor.rb', line 20

# Tensor layout as a lowercase Symbol built from the native #_layout name.
#
# @return [Symbol]
def layout
  name = _layout.downcase
  name.to_sym
end

#length ⇒ Object

Mirrors Python's len().



68
69
70
# File 'lib/torch/tensor.rb', line 68

# Size of the first dimension; mirrors Python's len().
#
# @return the result of size(0)
def length
  size(0)
end

#new ⇒ Object

unsure if this is correct



88
89
90
# File 'lib/torch/tensor.rb', line 88

# Creates an empty (zero-element) tensor with this tensor's dtype.
# NOTE: unsure if this is correct (the device is not carried over).
#
# @return the result of Torch.empty(0, dtype: dtype)
def new
  Torch.empty(0, dtype: dtype)
end

#new_ones(*size, **options) ⇒ Object



103
104
105
# File 'lib/torch/tensor.rb', line 103

# Returns a tensor of ones with the given size. As in PyTorch's
# Tensor.new_ones, the dtype defaults to this tensor's dtype; an explicit
# `dtype:` in options still wins.
# NOTE(review): PyTorch also inherits the device; not done here.
#
# @param size dimensions, forwarded to Torch.empty
# @param options forwarded to Torch.ones_like (dtype defaults to #dtype)
# @return the result of Torch.ones_like
def new_ones(*size, **options)
  Torch.ones_like(Torch.empty(*size), **{ dtype: dtype }.merge(options))
end

#numo ⇒ Object

TODO read directly from memory

Raises:

  • (Error)



97
98
99
100
101
# File 'lib/torch/tensor.rb', line 97

# Converts this tensor into a Numo array of the matching class, reshaped
# to #shape.
# TODO read directly from memory (currently goes through #_data_str).
#
# @return the Numo array
# @raise [Error] when Torch._dtype_to_numo has no class for #dtype
def numo
  numo_class = Torch._dtype_to_numo[dtype] || raise(Error, "Cannot convert #{dtype} to Numo")
  flat = numo_class.from_string(_data_str)
  flat.reshape(*shape)
end

#random!(*args) ⇒ Object

The native function overloads overlap, so dispatch is handled manually based on the number of arguments.



216
217
218
219
220
221
222
223
224
225
# File 'lib/torch/tensor.rb', line 216

# Fills the tensor with random values in place. The native overloads
# overlap, so dispatch is done by hand on the argument count:
#   one argument  -> _random__to
#   two arguments -> _random__from_to
#   otherwise     -> _random_
#
# @param args forwarded unchanged to the chosen native method
def random!(*args)
  arg_count = args.size
  if arg_count == 1
    _random__to(*args)
  elsif arg_count == 2
    _random__from_to(*args)
  else
    _random_(*args)
  end
end

#requires_grad!(requires_grad = true) ⇒ Object



107
108
109
# File 'lib/torch/tensor.rb', line 107

# Sets the requires-grad flag through the native #_requires_grad!.
#
# @param requires_grad flag value, true by default
# @return the result of #_requires_grad!
def requires_grad!(requires_grad = true)
  _requires_grad!(requires_grad)
end

#reshape(*size) ⇒ Object



117
118
119
120
121
# File 'lib/torch/tensor.rb', line 117

# Reshapes the tensor. Accepts either separate dimensions
# (`reshape(2, 3)`) or a single Array (`reshape([2, 3])`).
# Python doesn't check if size == 1 (it just ignores later arguments);
# here an Array is unwrapped only when it is the sole argument.
#
# @param size dimensions, or one Array of dimensions
# @return the result of #_reshape
def reshape(*size)
  dims = (size.length == 1 && Array === size[0]) ? size[0] : size
  _reshape(dims)
end

#shape ⇒ Object



63
64
65
# File 'lib/torch/tensor.rb', line 63

# The size of every dimension, in order.
#
# @return [Array] size(i) for each i in 0...dim
def shape
  (0...dim).map { |axis| size(axis) }
end

#size(dim = nil) ⇒ Object



55
56
57
58
59
60
61
# File 'lib/torch/tensor.rb', line 55

# Size of a single dimension, or the full shape when no dimension is given.
# (#shape calls size(i) with an index, so the two never recurse endlessly.)
#
# @param dim [Integer, nil] dimension index; nil returns #shape
# @return the native #_size_int(dim) result, or #shape
def size(dim = nil)
  dim ? _size_int(dim) : shape
end

#to(device, non_blocking: false, copy: false) ⇒ Object

TODO support dtype



42
43
44
45
# File 'lib/torch/tensor.rb', line 42

# Moves/copies the tensor to `device`, keeping its current dtype.
# TODO support dtype
#
# @param device [String, Device] a String is wrapped with Device.new
# @param non_blocking [Boolean] forwarded to #_to
# @param copy [Boolean] forwarded to #_to
# @return the result of #_to
def to(device, non_blocking: false, copy: false)
  target = device.is_a?(String) ? Device.new(device) : device
  _to(target, _dtype, non_blocking, copy)
end

#to_a ⇒ Object

TODO make more performant



29
30
31
32
33
34
35
36
37
38
39
# File 'lib/torch/tensor.rb', line 29

# Converts the tensor into nested Ruby Arrays that follow #shape.
# For a 0-dim tensor (empty shape) the flat data array is returned as is.
# TODO make more performant
#
# @return [Array]
def to_a
  dims = shape # computed once; #shape queries every dimension
  arr = _flat_data
  return arr if dims.empty?

  # Group from the innermost dimension outwards.
  (dims.size - 1).downto(1) do |axis|
    step = dims[axis]
    arr =
      if step.zero?
        # each_slice(0) raises, and a zero-sized axis has no data anyway:
        # emit one empty Array per position of the enclosing dimensions.
        Array.new(dims[0...axis].reduce(1, :*)) { [] }
      else
        arr.each_slice(step).to_a
      end
  end
  arr
end

#to_f ⇒ Object



83
84
85
# File 'lib/torch/tensor.rb', line 83

# Float conversion of the tensor's single element (see #item).
#
# @return [Float]
def to_f
  item.to_f
end

#to_i ⇒ Object



79
80
81
# File 'lib/torch/tensor.rb', line 79

# Integer conversion of the tensor's single element (see #item).
#
# @return [Integer]
def to_i
  item.to_i
end

#to_s ⇒ Object



24
25
26
# File 'lib/torch/tensor.rb', line 24

# String form of the tensor; identical to #inspect.
#
# @return [String]
def to_s
  inspect
end

#type(dtype) ⇒ Object

Raises:

  • (Error)



111
112
113
114
115
# File 'lib/torch/tensor.rb', line 111

# Converts the tensor to another dtype via the native #_type.
#
# @param dtype key looked up in DTYPE_TO_ENUM
# @return the result of #_type
# @raise [Error] when dtype has no (truthy) DTYPE_TO_ENUM entry
def type(dtype)
  enum_value = DTYPE_TO_ENUM[dtype] || raise(Error, "Unknown type: #{dtype}")
  _type(enum_value)
end

#view(*size) ⇒ Object



123
124
125
126
# File 'lib/torch/tensor.rb', line 123

# Returns a view with new dimensions. Accepts separate dimensions
# (`view(2, 3)`) or a single Array (`view([2, 3])`); an Array is
# unwrapped only when it is the sole argument.
#
# @param size dimensions, or one Array of dimensions
# @return the result of #_view
def view(*size)
  dims = (size.length == 1 && Array === size[0]) ? size[0] : size
  _view(dims)
end