Class: Torch::Tensor

Inherits:
Object
  • Object
show all
Includes:
Comparable, Inspector
Defined in:
lib/torch/tensor.rb

Direct Known Subclasses

NN::Parameter

Class Method Summary

Instance Method Summary

Methods included from Inspector

#inspect

Class Method Details

.new(*args) ⇒ Object



8
9
10
# File 'lib/torch/tensor.rb', line 8

# Constructing a plain Tensor actually builds a FloatTensor: every
# argument is handed straight through to FloatTensor.new.
#
# @param args arguments forwarded unchanged to FloatTensor.new
# @return [FloatTensor] the newly constructed tensor
def self.new(*args)
  FloatTensor.new(*args)
end

Instance Method Details

#%(other) ⇒ Object



134
135
136
# File 'lib/torch/tensor.rb', line 134

# Remainder operator: `tensor % other` is syntactic sugar for #remainder.
#
# @param other operand forwarded to #remainder
# @return the result of #remainder
def %(other)
  remainder(other)
end

#*(other) ⇒ Object



126
127
128
# File 'lib/torch/tensor.rb', line 126

# Multiplication operator: `tensor * other` is syntactic sugar for #mul.
#
# @param other operand forwarded to #mul
# @return the result of #mul
def *(other)
  mul(other)
end

#**(other) ⇒ Object



138
139
140
# File 'lib/torch/tensor.rb', line 138

# Power operator: `tensor ** other` is syntactic sugar for #pow.
#
# @param other exponent forwarded to #pow
# @return the result of #pow
def **(other)
  pow(other)
end

#+(other) ⇒ Object



118
119
120
# File 'lib/torch/tensor.rb', line 118

# Addition operator: `tensor + other` is syntactic sugar for #add.
#
# @param other operand forwarded to #add
# @return the result of #add
def +(other)
  add(other)
end

#-(other) ⇒ Object



122
123
124
# File 'lib/torch/tensor.rb', line 122

# Subtraction operator: `tensor - other` is syntactic sugar for #sub.
#
# @param other operand forwarded to #sub
# @return the result of #sub
def -(other)
  sub(other)
end

#-@ ⇒ Object



142
143
144
# File 'lib/torch/tensor.rb', line 142

# Unary minus: `-tensor` is syntactic sugar for #neg.
#
# @return the result of #neg
def -@
  neg
end

#/(other) ⇒ Object



130
131
132
# File 'lib/torch/tensor.rb', line 130

# Division operator: `tensor / other` is syntactic sugar for #div.
#
# @param other divisor forwarded to #div
# @return the result of #div
def /(other)
  div(other)
end

#<=>(other) ⇒ Object



146
147
148
# File 'lib/torch/tensor.rb', line 146

# Comparison hook used by Comparable: compares this tensor's scalar value
# (#item) with +other+.
#
# Only one-element tensors can be compared, because #item raises Error
# otherwise. When +other+ is itself a Tensor it is unwrapped with #item as
# well. Without that, `item <=> tensor` yields nil and two scalar tensors
# could not be compared.
#
# @param other [Tensor, Object] value to compare against
# @return the result of `item <=> other` (after unwrapping a Tensor)
# @raise [Error] if either tensor has more than one element
def <=>(other)
  other = other.item if other.is_a?(Tensor)
  item <=> other
end

#[](*indexes) ⇒ Object

based on python_variable_indexing.cpp



151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# File 'lib/torch/tensor.rb', line 151

# Indexes into the tensor, loosely following PyTorch's Python indexing
# (based on python_variable_indexing.cpp).
#
# Each index is applied in turn to the current dimension:
# - Numeric: selects one position. The dimension counter is not advanced,
#   because selection drops that dimension.
# - Range: slices the dimension, with Ruby range semantics. Exclusive and
#   inclusive ranges, beginless/endless ranges, and an inclusive end of -1
#   (`0..-1` means "through the last element") are supported.
# - nil / true: inserts a new dimension of size 1 and moves past it.
#
# @param indexes [Array<Numeric, Range, nil, true>]
# @return [Tensor]
# @raise [Error] for unsupported index types
def [](*indexes)
  result = self
  dim = 0
  indexes.each do |index|
    case index
    when Numeric
      result = result._select_int(dim, index)
    when Range
      start = index.begin || 0
      finish = index.end
      if finish.nil? || (finish == -1 && !index.exclude_end?)
        # open-ended or "through the last element": -1 + 1 would wrap to 0
        # and produce an empty slice, so slice up to the dimension's length
        finish = result.size(dim)
      elsif !index.exclude_end?
        finish += 1
      end
      result = result._slice_tensor(dim, start, finish, 1)
      dim += 1
    when nil, true
      # Python treats True like None here: a new size-1 dimension that
      # the following indexes must skip past
      result = result.unsqueeze(dim)
      dim += 1
      # TODO handle false
    else
      raise Error, "Unsupported index type: #{index.class.name}"
    end
  end
  result
end

#[]=(index, value) ⇒ Object

TODO based on python_variable_indexing.cpp

Raises:

  • (ArgumentError)
  • (Error)


177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
# File 'lib/torch/tensor.rb', line 177

# Assigns +value+ into the part of the tensor selected by +index+ along
# the first dimension.
# TODO: extend to the full python_variable_indexing.cpp semantics.
#
# Ranges follow Ruby semantics. Exclusive and inclusive ranges,
# beginless/endless ranges, and an inclusive end of -1 (`0..-1` means
# "through the last element") are supported.
#
# @param index [Numeric, Range] position or range along dimension 0
# @param value [Tensor, Object] non-Tensor values are converted with Torch.tensor
# @raise [ArgumentError] if +value+ is nil (deleting items is unsupported)
# @raise [Error] for unsupported index types
def []=(index, value)
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?

  value = Torch.tensor(value) unless value.is_a?(Tensor)

  case index
  when Numeric
    copy_to(_select_int(0, index), value)
  when Range
    start = index.begin || 0
    finish = index.end
    if finish.nil? || (finish == -1 && !index.exclude_end?)
      # open-ended or "through the last element": -1 + 1 would wrap to 0
      # and select nothing, so slice up to the dimension's length
      finish = size(0)
    elsif !index.exclude_end?
      finish += 1
    end
    copy_to(_slice_tensor(0, start, finish, 1), value)
  else
    raise Error, "Unsupported index type: #{index.class.name}"
  end
end

#add!(value = 1, other) ⇒ Object

Note: the Ruby signature takes value before other, but the native calls receive (other, value); the argument order is swapped for some methods.



110
111
112
113
114
115
116
# File 'lib/torch/tensor.rb', line 110

# In-place add. Dispatches to the scalar or tensor native variant,
# depending on whether +other+ is Numeric.
#
# The Ruby argument order (value, other) is the reverse of the native
# order (other, value).
#
# @param value forwarded as the second native argument (defaults to 1);
#   presumably the alpha multiplier — TODO confirm
# @param other [Numeric, Tensor] the operand to add
# @return the result of the native _add__scalar / _add__tensor call
def add!(value = 1, other)
  other.is_a?(Numeric) ? _add__scalar(other, value) : _add__tensor(other, value)
end

#backward(gradient = nil) ⇒ Object



73
74
75
# File 'lib/torch/tensor.rb', line 73

# Forwards to the native _backward.
#
# @param gradient passed through unchanged (nil by default)
# @return the result of _backward
def backward(gradient = nil)
  _backward(gradient)
end

#cpu ⇒ Object



36
37
38
# File 'lib/torch/tensor.rb', line 36

# Shorthand for moving the tensor to the CPU.
#
# @return the result of #to with the "cpu" device name
def cpu
  to("cpu")
end

#cuda ⇒ Object



40
41
42
# File 'lib/torch/tensor.rb', line 40

# Shorthand for moving the tensor to a CUDA device.
#
# @return the result of #to with the "cuda" device name
def cuda
  to("cuda")
end

#dtype ⇒ Object

Raises:

  • (Error)



12
13
14
15
16
# File 'lib/torch/tensor.rb', line 12

# Translates the native dtype enum (_dtype) into its ENUM_TO_DTYPE entry.
#
# @return the ENUM_TO_DTYPE value for this tensor's native dtype enum
# @raise [Error] if the enum has no (truthy) ENUM_TO_DTYPE entry
def dtype
  enum = _dtype
  ENUM_TO_DTYPE[enum] || raise(Error, "Unknown type: #{enum}")
end

#item ⇒ Object



61
62
63
64
65
66
# File 'lib/torch/tensor.rb', line 61

# Extracts the single value held by a one-element tensor.
#
# @return the first (and only) entry of _flat_data
# @raise [Error] unless numel is exactly 1
def item
  raise Error, "only one element tensors can be converted to Ruby scalars" unless numel == 1

  _flat_data.first
end

#layout ⇒ Object



18
19
20
# File 'lib/torch/tensor.rb', line 18

# Returns the native layout name as a lowercase Symbol.
#
# @return [Symbol] the downcased _layout value, converted with to_sym
def layout
  layout_name = _layout
  layout_name.downcase.to_sym
end

#length ⇒ Object

Mirrors Python's len(): returns the size of the first dimension.



57
58
59
# File 'lib/torch/tensor.rb', line 57

# Mirrors Python's len(): the size of the first dimension.
#
# @return the result of size(0)
def length
  size(0)
end

#new ⇒ Object

Unverified: this may not exactly match PyTorch's semantics.



69
70
71
# File 'lib/torch/tensor.rb', line 69

# Builds an empty (zero-length) tensor with this tensor's dtype.
# NOTE(review): unsure whether this matches PyTorch's Tensor.new exactly.
#
# @return the result of Torch.empty(0, dtype: dtype)
def new
  Torch.empty(0, dtype: dtype)
end

#new_ones(*size, **options) ⇒ Object



84
85
86
# File 'lib/torch/tensor.rb', line 84

# Returns a tensor of ones with the given size.
#
# Like PyTorch's Tensor.new_ones, the result defaults to this tensor's
# dtype; an explicit dtype: in +options+ overrides it. Previously the
# receiver's dtype was ignored entirely.
# NOTE(review): the device is not inherited yet — TODO confirm whether
# this tensor's device should also be passed through.
#
# @param size sizes forwarded to Torch.empty (separate integers or one Array)
# @param options [Hash] options forwarded to Torch.ones_like
# @return the result of Torch.ones_like
def new_ones(*size, **options)
  ones_options = { dtype: dtype }.merge(options)
  Torch.ones_like(Torch.empty(*size), **ones_options)
end

#numo ⇒ Object

TODO read directly from memory

Raises:

  • (Error)



78
79
80
81
82
# File 'lib/torch/tensor.rb', line 78

# Converts the tensor to a Numo array of the matching class and shape.
# TODO read directly from memory instead of going through _flat_data.
#
# @return a Numo array built with cast(_flat_data) and reshaped to #shape
# @raise [Error] if Torch._dtype_to_numo has no class for this dtype
def numo
  numo_class = Torch._dtype_to_numo[dtype]
  raise Error, "Cannot convert #{dtype} to Numo" unless numo_class

  flat = numo_class.cast(_flat_data)
  flat.reshape(*shape)
end

#random!(from = 0, to) ⇒ Object



193
194
195
# File 'lib/torch/tensor.rb', line 193

# In-place random fill between +from+ and +to+, via the native
# _random__from_to.
#
# @param from lower bound (defaults to 0 when only one argument is given)
# @param to upper bound
# @return the result of _random__from_to
def random!(from = 0, to)
  _random__from_to(from, to)
end

#requires_grad!(requires_grad = true) ⇒ Object



88
89
90
# File 'lib/torch/tensor.rb', line 88

# Sets the requires_grad flag through the native _requires_grad!.
#
# @param requires_grad flag value (defaults to true)
# @return the result of _requires_grad!
def requires_grad!(requires_grad = true)
  _requires_grad!(requires_grad)
end

#reshape(*size) ⇒ Object



98
99
100
101
102
# File 'lib/torch/tensor.rb', line 98

# Forwards the requested shape to the native _reshape.
# Sizes may be given as separate arguments or as one Array.
#
# @param size integers, or a single Array of integers
# @return the result of _reshape
def reshape(*size)
  # Only a lone Array is unwrapped; Python instead ignores any arguments
  # that follow a leading sequence without checking the count.
  dims = (size.length == 1 && size.first.is_a?(Array)) ? size.first : size
  _reshape(dims)
end

#shape ⇒ Object



52
53
54
# File 'lib/torch/tensor.rb', line 52

# Lists the size of every dimension, in order.
#
# @return [Array] size(i) for each i in 0...dim
def shape
  Array.new(dim) { |axis| size(axis) }
end

#size(dim = nil) ⇒ Object



44
45
46
47
48
49
50
# File 'lib/torch/tensor.rb', line 44

# Returns the size of one dimension, or the full shape when no
# dimension is given.
#
# @param dim dimension index, or nil for all dimensions
# @return the result of _size_int(dim), or #shape when dim is nil
def size(dim = nil)
  dim ? _size_int(dim) : shape
end

#to(device, non_blocking: false, copy: false) ⇒ Object

TODO support dtype



31
32
33
34
# File 'lib/torch/tensor.rb', line 31

# Converts the tensor to +device+ via the native _to, keeping the
# current native dtype.
# TODO support dtype
#
# @param device [Device, String, Symbol] target device. Device names given
#   as a String or Symbol (e.g. "cuda" or :cuda) are wrapped in a Device;
#   anything else is passed through unchanged.
# @param non_blocking [Boolean] forwarded to _to
# @param copy [Boolean] forwarded to _to
# @return the result of _to
def to(device, non_blocking: false, copy: false)
  # Symbols are accepted too, so `to(:cuda)` behaves like `to("cuda")`
  device = Device.new(device.to_s) if device.is_a?(String) || device.is_a?(Symbol)
  _to(device, _dtype, non_blocking, copy)
end

#to_a ⇒ Object



26
27
28
# File 'lib/torch/tensor.rb', line 26

# Converts the tensor to a nested Ruby Array shaped like the tensor.
#
# @return the result of reshape_arr applied to the flat data and #shape
def to_a
  flat = _flat_data
  reshape_arr(flat, shape)
end

#to_s ⇒ Object



22
23
24
# File 'lib/torch/tensor.rb', line 22

# String form of the tensor; identical to #inspect.
#
# @return the result of #inspect
def to_s
  inspect
end

#type(dtype) ⇒ Object

Raises:

  • (Error)



92
93
94
95
96
# File 'lib/torch/tensor.rb', line 92

# Casts the tensor to +dtype+ through the native _type, after translating
# the dtype via DTYPE_TO_ENUM.
#
# @param dtype key looked up in DTYPE_TO_ENUM
# @return the result of _type
# @raise [Error] if +dtype+ has no (truthy) DTYPE_TO_ENUM entry
def type(dtype)
  enum = DTYPE_TO_ENUM[dtype]
  return _type(enum) if enum

  raise Error, "Unknown type: #{dtype}"
end

#view(*size) ⇒ Object



104
105
106
107
# File 'lib/torch/tensor.rb', line 104

# Forwards the requested shape to the native _view.
# Sizes may be given as separate arguments or as one Array; as in
# #reshape, only a lone Array argument is unwrapped.
#
# @param size integers, or a single Array of integers
# @return the result of _view
def view(*size)
  lone = size.first if size.length == 1
  _view(lone.is_a?(Array) ? lone : size)
end