Class: Torch::Tensor

Inherits:
Object
  • Object
show all
Includes:
Comparable, Inspector
Defined in:
lib/torch/tensor.rb

Direct Known Subclasses

NN::Parameter

Class Method Summary collapse

Instance Method Summary collapse

Methods included from Inspector

#inspect

Class Method Details

.new(*args) ⇒ Object



8
9
10
# File 'lib/torch/tensor.rb', line 8

# Construct a tensor using the default FloatTensor class, so that
# Tensor.new(...) behaves exactly like FloatTensor.new(...).
#
# @param ctor_args [Array] forwarded unchanged to FloatTensor.new
# @return [Object] whatever FloatTensor.new returns
def self.new(*ctor_args)
  FloatTensor.new(*ctor_args)
end

Instance Method Details

#%(other) ⇒ Object



126
127
128
# File 'lib/torch/tensor.rb', line 126

# Operator form of #remainder: `a % b` is `a.remainder(b)`.
#
# @param divisor [Object] right-hand operand passed to #remainder
# @return [Object] the result of #remainder
def %(divisor)
  remainder(divisor)
end

#*(other) ⇒ Object



118
119
120
# File 'lib/torch/tensor.rb', line 118

# Operator form of #mul: `a * b` is `a.mul(b)`.
#
# @param factor [Object] right-hand operand passed to #mul
# @return [Object] the result of #mul
def *(factor)
  mul(factor)
end

#**(other) ⇒ Object



130
131
132
# File 'lib/torch/tensor.rb', line 130

# Operator form of #pow: `a ** b` is `a.pow(b)`.
#
# @param exponent [Object] right-hand operand passed to #pow
# @return [Object] the result of #pow
def **(exponent)
  pow(exponent)
end

#+(other) ⇒ Object



110
111
112
# File 'lib/torch/tensor.rb', line 110

# Operator form of #add: `a + b` is `a.add(b)`.
#
# @param addend [Object] right-hand operand passed to #add
# @return [Object] the result of #add
def +(addend)
  add(addend)
end

#-(other) ⇒ Object



114
115
116
# File 'lib/torch/tensor.rb', line 114

# Operator form of #sub: `a - b` is `a.sub(b)`.
#
# @param subtrahend [Object] right-hand operand passed to #sub
# @return [Object] the result of #sub
def -(subtrahend)
  sub(subtrahend)
end

#-@ ⇒ Object



134
135
136
# File 'lib/torch/tensor.rb', line 134

# Unary minus: `-a` is `a.neg`.
#
# @return [Object] the result of #neg
def -@
  neg()
end

#/(other) ⇒ Object



122
123
124
# File 'lib/torch/tensor.rb', line 122

# Operator form of #div: `a / b` is `a.div(b)`.
#
# @param divisor [Object] right-hand operand passed to #div
# @return [Object] the result of #div
def /(divisor)
  div(divisor)
end

#<=>(other) ⇒ Object



138
139
140
# File 'lib/torch/tensor.rb', line 138

# Compare a one-element tensor with another value through its Ruby scalar.
# This is the hook used by the Comparable operators (<, <=, >, >=, ==, ...).
#
# @param other [Object] a Ruby value, or another one-element Tensor (which is
#   unwrapped with #item so tensor-to-tensor comparisons work)
# @return [Integer, nil] -1, 0 or 1, or nil when the values are not comparable
# @raise [Error] if this tensor does not have exactly one element (via #item)
def <=>(other)
  # A scalar's <=> returns nil for a Tensor argument, which made e.g.
  # `a < b` raise for two scalar tensors. Only unwrap one-element tensors so
  # comparing against a multi-element tensor still yields nil, not an error.
  other = other.item if other.is_a?(Tensor) && other.numel == 1
  item <=> other
end

#[](*indexes) ⇒ Object

Based on PyTorch's python_variable_indexing.cpp.



143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# File 'lib/torch/tensor.rb', line 143

# Index into the tensor, consuming one dimension per argument
# (modeled on PyTorch's python_variable_indexing.cpp).
#
# Supported index types:
# - Numeric: selects one position via _select_int. The dimension counter is
#   not advanced, since select removes that dimension.
# - Range: slices via _slice_tensor with step 1. Inclusive, exclusive,
#   negative-ended, beginless (..n) and endless (n..) ranges are supported.
# - nil: inserts a size-1 dimension (like Python's None).
# - true: inserts a size-1 dimension; later indexes apply after it.
#
# @param indexes [Array<Numeric, Range, nil, true>]
# @return [Tensor] the indexed result
# @raise [Error] for unsupported index types
def [](*indexes)
  result = self
  dim = 0
  indexes.each do |index|
    # case/when uses Class#=== / NilClass#=== / TrueClass#===, so a Tensor
    # index is never compared through Tensor#<=> (which would call #item).
    case index
    when Numeric
      result = result._select_int(dim, index)
    when Range
      start = index.begin || 0
      finish = index.end
      if finish.nil?
        # endless range (n..) runs to the end of the dimension
        finish = result.size(dim)
      else
        # normalize a negative end first so an inclusive end of -1
        # means "through the last element" instead of becoming 0
        finish += result.size(dim) if finish < 0
        finish += 1 unless index.exclude_end?
      end
      result = result._slice_tensor(dim, start, finish, 1)
      dim += 1
    when nil, true
      # both add a new size-1 dimension; PyTorch advances past it for True too
      result = result.unsqueeze(dim)
      dim += 1
    # TODO handle false
    else
      raise Error, "Unsupported index type: #{index.class.name}"
    end
  end
  result
end

#[]=(index, value) ⇒ Object

TODO based on python_variable_indexing.cpp

Raises:

  • (ArgumentError)
  • (Error)


169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# File 'lib/torch/tensor.rb', line 169

# Copy +value+ into part of the tensor along dimension 0
# (partial port of PyTorch's python_variable_indexing.cpp).
#
# @param index [Numeric, Range] a position or range along dimension 0.
#   Ranges may be inclusive or exclusive, negative-ended, beginless or endless.
# @param value [Object] copied into the selected region; non-Tensor values
#   are converted with Torch.tensor first
# @raise [ArgumentError] if value is nil (deleting items is unsupported)
# @raise [Error] for unsupported index types
def []=(index, value)
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?

  value = Torch.tensor(value) unless value.is_a?(Tensor)

  target =
    case index
    when Numeric
      _select_int(0, index)
    when Range
      start = index.begin || 0
      finish = index.end
      if finish.nil?
        # endless range (n..) runs to the end of dimension 0
        finish = size(0)
      else
        # normalize a negative end first so an inclusive end of -1
        # covers the last element instead of producing an empty slice
        finish += size(0) if finish < 0
        finish += 1 unless index.exclude_end?
      end
      _slice_tensor(0, start, finish, 1)
    else
      raise Error, "Unsupported index type: #{index.class.name}"
    end

  copy_to(target, value)
end

#add!(value = 1, other) ⇒ Object

Note: for some methods the value and other arguments are in swapped order (value comes first).



102
103
104
105
106
107
108
# File 'lib/torch/tensor.rb', line 102

# Bang variant of add. The optional scale +value+ comes first and +other+
# last (the argument order is swapped relative to the native helpers).
# Presumably in-place; the native _add__* helpers are not visible here.
#
# @param value [Object] scale passed as the helpers' second argument (default 1)
# @param other [Numeric, Object] Numeric goes to _add__scalar, anything else
#   to _add__tensor
# @return [Object] the helper's return value
def add!(value = 1, other)
  other.is_a?(Numeric) ? _add__scalar(other, value) : _add__tensor(other, value)
end

#backward(gradient = nil) ⇒ Object



65
66
67
# File 'lib/torch/tensor.rb', line 65

# Delegates to the native _backward binding, passing +gradient+ through
# unchanged.
#
# @param gradient [Object, nil] optional gradient argument (nil by default)
# @return [Object] whatever _backward returns
def backward(gradient = nil)
  _backward(gradient)
end

#dtype ⇒ Object

Raises:

  • (Error)



12
13
14
15
16
# File 'lib/torch/tensor.rb', line 12

# The tensor's dtype, translated from the native enum (_dtype) through
# the ENUM_TO_DTYPE table.
#
# @return [Object] the ENUM_TO_DTYPE entry for _dtype
# @raise [Error] if ENUM_TO_DTYPE has no truthy entry for the enum
def dtype
  ENUM_TO_DTYPE[_dtype] || raise(Error, "Unknown type: #{_dtype}")
end

#item ⇒ Object



53
54
55
56
57
58
# File 'lib/torch/tensor.rb', line 53

# Extract the single value of a one-element tensor as a Ruby object.
#
# @return [Object] the first (and only) entry of _flat_data
# @raise [Error] unless numel is exactly 1
def item
  raise Error, "only one element tensors can be converted to Ruby scalars" unless numel == 1

  _flat_data.first
end

#layout ⇒ Object



18
19
20
# File 'lib/torch/tensor.rb', line 18

# Memory layout as a lowercase symbol derived from the native _layout name.
#
# @return [Symbol] e.g. "Strided" becomes :strided
def layout
  native_name = _layout
  native_name.downcase.to_sym
end

#length ⇒ Object

Mirrors Python's len().



49
50
51
# File 'lib/torch/tensor.rb', line 49

# Size of the first dimension, mirroring Python's len(tensor).
#
# @return [Object] the result of size(0)
def length
  first_dim = 0
  size(first_dim)
end

#new ⇒ Object

Note: unsure whether this matches PyTorch's behavior.



61
62
63
# File 'lib/torch/tensor.rb', line 61

# Return an empty (zero-element) tensor with this tensor's dtype.
# NOTE(review): unsure whether this matches PyTorch's Tensor.new(). Only the
# dtype is carried over here (no device) — confirm.
#
# @return [Object] result of Torch.empty(0, dtype: dtype)
def new
  empty_options = { dtype: dtype }
  Torch.empty(0, **empty_options)
end

#new_ones(*size, **options) ⇒ Object



76
77
78
# File 'lib/torch/tensor.rb', line 76

# Create a tensor of ones with the given size. As in PyTorch's
# Tensor.new_ones, the dtype defaults to this tensor's dtype; an explicit
# dtype: (or any other option) in +options+ takes precedence.
#
# @param size [Array] dimensions of the new tensor, forwarded to Torch.empty
# @param options [Hash] forwarded to Torch.ones_like (e.g. dtype:)
# @return [Object] the result of Torch.ones_like
def new_ones(*size, **options)
  # Without a dtype default, the result followed Torch.empty's dtype instead
  # of self's, so e.g. an integer tensor produced float ones.
  ones_options = { dtype: dtype }.merge(options)
  Torch.ones_like(Torch.empty(*size), **ones_options)
end

#numo ⇒ Object

TODO: read directly from memory.

Raises:

  • (Error)



70
71
72
73
74
# File 'lib/torch/tensor.rb', line 70

# Convert to a Numo array. The values are copied through _flat_data and
# reshaped to #shape.
# TODO: read directly from memory.
#
# @return [Object] a Numo array of the class mapped to this dtype
# @raise [Error] if Torch._dtype_to_numo has no class for this dtype
def numo
  numo_class = Torch._dtype_to_numo[dtype] || raise(Error, "Cannot convert #{dtype} to Numo")
  flat = numo_class.cast(_flat_data)
  flat.reshape(*shape)
end

#requires_grad!(requires_grad = true) ⇒ Object



80
81
82
# File 'lib/torch/tensor.rb', line 80

# Set the requires-grad flag through the native _requires_grad! binding.
#
# @param requires_grad [Object] flag passed through unchanged (true by default)
# @return [Object] whatever _requires_grad! returns
def requires_grad!(requires_grad = true)
  flag = requires_grad
  _requires_grad!(flag)
end

#reshape(*size) ⇒ Object



90
91
92
93
94
# File 'lib/torch/tensor.rb', line 90

# Reshape to the given dimensions; accepts both reshape(2, 3) and
# reshape([2, 3]).
#
# NOTE: Python unwraps a leading sequence without checking the argument
# count (later arguments are ignored). Here an Array is only unwrapped
# when it is the sole argument.
#
# @param size [Array] dimensions, or a single Array of dimensions
# @return [Object] the result of _reshape
def reshape(*size)
  dims = size.length == 1 && size.first.is_a?(Array) ? size.first : size
  _reshape(dims)
end

#shape ⇒ Object



44
45
46
# File 'lib/torch/tensor.rb', line 44

# Sizes of every dimension, in order.
#
# @return [Array] [size(0), size(1), ..., size(dim - 1)]
def shape
  (0...dim).map { |axis| size(axis) }
end

#size(dim = nil) ⇒ Object



36
37
38
39
40
41
42
# File 'lib/torch/tensor.rb', line 36

# Size of one dimension, or the full shape when no dimension is given.
#
# @param dim [Integer, nil] dimension to query; nil returns #shape
#   (0 is truthy in Ruby, so size(0) queries dimension 0)
# @return [Object] _size_int(dim), or #shape when dim is nil/false
def size(dim = nil)
  return shape unless dim

  _size_int(dim)
end

#to(device, non_blocking: false, copy: false) ⇒ Object

TODO: support dtype.



31
32
33
34
# File 'lib/torch/tensor.rb', line 31

# Convert to +device+, keeping the current dtype (the native _dtype is
# passed through). TODO: support dtype.
#
# @param device [String, Object] a String is wrapped with Device.new;
#   any other value is passed as-is
# @param non_blocking [Boolean] forwarded to _to
# @param copy [Boolean] forwarded to _to
# @return [Object] the result of _to
def to(device, non_blocking: false, copy: false)
  target = device.is_a?(String) ? Device.new(device) : device
  _to(target, _dtype, non_blocking, copy)
end

#to_a ⇒ Object



26
27
28
# File 'lib/torch/tensor.rb', line 26

# Convert to a Ruby Array. The flat data is handed to reshape_arr together
# with #shape (presumably producing nested arrays — see reshape_arr).
#
# @return [Object] the result of reshape_arr(_flat_data, shape)
def to_a
  flat = _flat_data
  reshape_arr(flat, shape)
end

#to_s ⇒ Object



22
23
24
# File 'lib/torch/tensor.rb', line 22

# String form of the tensor; identical to #inspect.
#
# @return [String] the result of #inspect
def to_s
  self.inspect
end

#type(dtype) ⇒ Object

Raises:

  • (Error)



84
85
86
87
88
# File 'lib/torch/tensor.rb', line 84

# Cast to +dtype+ by looking up its native enum in DTYPE_TO_ENUM.
#
# @param dtype [Object] a key of DTYPE_TO_ENUM
# @return [Object] the result of _type
# @raise [Error] if DTYPE_TO_ENUM has no truthy entry for dtype
def type(dtype)
  enum = DTYPE_TO_ENUM[dtype] || raise(Error, "Unknown type: #{dtype}")
  _type(enum)
end

#view(*size) ⇒ Object



96
97
98
99
# File 'lib/torch/tensor.rb', line 96

# View with the given dimensions; accepts both view(2, 3) and view([2, 3]).
# A lone Array argument is unwrapped; otherwise the arguments are used as-is.
#
# @param size [Array] dimensions, or a single Array of dimensions
# @return [Object] the result of _view
def view(*size)
  if size.length == 1 && size[0].is_a?(Array)
    _view(size[0])
  else
    _view(size)
  end
end