Class: Torch::Tensor

Inherits:
Object
  • Object
show all
Includes:
Comparable, Enumerable, Inspector
Defined in:
lib/torch/tensor.rb

Direct Known Subclasses

NN::Parameter

Constant Summary

Constants included from Inspector

Inspector::PRINT_OPTS

Class Method Summary

Instance Method Summary

Methods included from Inspector

#inspect

Class Method Details

.new(*args) ⇒ Object



28
29
30
# File 'lib/torch/tensor.rb', line 28

# Factory entry point: constructing a plain Tensor forwards every positional
# argument to FloatTensor.new and returns whatever that constructor returns.
def self.new(*ctor_args)
  FloatTensor.new(*ctor_args)
end

Instance Method Details

#<=>(other) ⇒ Object

TODO better compare?



156
157
158
# File 'lib/torch/tensor.rb', line 156

# Orders a single-element tensor by its scalar value (see #item).
# TODO better compare?
#
# @param other [Numeric, Tensor] value to compare against; a Tensor operand is
#   unwrapped to its scalar first
# @return the result of the scalar's own <=> (nil when not comparable)
# @raise [Error] (via #item) when either tensor does not hold exactly one element
def <=>(other)
  # Unwrap tensors explicitly. Left on the right-hand side, a Tensor makes the
  # scalar's <=> fall back to Tensor#coerce, which wraps the scalar in a new
  # Tensor and calls back into this method, recursing without end.
  other = other.item if other.is_a?(Tensor)
  item <=> other
end

#[](*indexes) ⇒ Object

based on python_variable_indexing.cpp and pytorch.org/cppdocs/notes/tensor_indexing.html



162
163
164
# File 'lib/torch/tensor.rb', line 162

# Reads the elements selected by one or more index arguments. Indexing rules
# follow python_variable_indexing.cpp and
# https://pytorch.org/cppdocs/notes/tensor_indexing.html; the native _index
# does the actual work and receives all index arguments as a single array.
def [](*index_args)
  _index(index_args)
end

#[]=(*indexes, value) ⇒ Object

based on python_variable_indexing.cpp and pytorch.org/cppdocs/notes/tensor_indexing.html

Raises:

  • (ArgumentError)


168
169
170
171
172
# File 'lib/torch/tensor.rb', line 168

# Writes value into the positions selected by the index arguments (same
# indexing rules as #[]).
#
# @param value [Tensor, Object] source data; anything that is not already a
#   Tensor is converted with Torch.tensor using this tensor's dtype
# @raise [ArgumentError] when value is nil (deleting items is not supported)
def []=(*positions, value)
  if value.nil?
    raise ArgumentError, "Tensor does not support deleting items"
  end

  source = value.is_a?(Tensor) ? value : Torch.tensor(value, dtype: dtype)
  _index_put_custom(positions, source)
end

#coerce(other) ⇒ Object



198
199
200
201
202
203
204
# File 'lib/torch/tensor.rb', line 198

# Ruby's numeric coercion hook: when a Numeric is combined with a tensor
# (e.g. `2 * tensor`), wrap the Numeric with Torch.tensor so both operands
# are tensors.
#
# @param other [Object] the operand Ruby is trying to combine with self
# @return [Array] the wrapped operand followed by self
# @raise [TypeError] when other is not a Numeric
def coerce(other)
  raise TypeError, "#{self.class} can't be coerced into #{other.class}" unless other.is_a?(Numeric)

  [Torch.tensor(other), self]
end

#cpu ⇒ Object



83
84
85
# File 'lib/torch/tensor.rb', line 83

# Shorthand for #to with the device string "cpu".
def cpu
  target_device = "cpu"
  to(target_device)
end

#cuda ⇒ Object



87
88
89
# File 'lib/torch/tensor.rb', line 87

# Shorthand for #to with the device string "cuda".
def cuda
  target_device = "cuda"
  to(target_device)
end

#dtype ⇒ Object

Raises:

  • (Error)



32
33
34
35
36
# File 'lib/torch/tensor.rb', line 32

# Translates the native dtype enum (from _dtype) into its entry in
# ENUM_TO_DTYPE.
#
# @raise [Error] when the enum has no entry in ENUM_TO_DTYPE
def dtype
  native_enum = _dtype
  ENUM_TO_DTYPE[native_enum] || raise(Error, "Unknown type: #{native_enum}")
end

#dup ⇒ Object



180
181
182
183
184
# File 'lib/torch/tensor.rb', line 180

# Copies the tensor with #clone, run inside Torch.no_grad (presumably so the
# copy is not recorded for autograd — confirm against Torch.no_grad).
# Returns whatever Torch.no_grad returns for the block.
def dup
  Torch.no_grad { clone }
end

#each ⇒ Object



46
47
48
49
50
51
52
# File 'lib/torch/tensor.rb', line 46

# Yields self[0], self[1], ... up to size(0) - 1, i.e. one slice per entry of
# the first dimension. Without a block, returns an Enumerator instead.
#
# @return [Integer, Enumerator] the result of Integer#times (size(0)) when a
#   block is given, otherwise an Enumerator
def each
  return to_enum(__method__) unless block_given?

  size(0).times { |index| yield self[index] }
end

#imag ⇒ Object

Not a method in native_functions.yaml; in Python this is an attribute rather than a method.



188
189
190
# File 'lib/torch/tensor.rb', line 188

# Imaginary part, delegated to Torch.imag. Defined by hand because in Python it
# is an attribute, not a method listed in native_functions.yaml.
def imag
  source = self
  Torch.imag(source)
end

#item ⇒ Object



113
114
115
116
117
118
# File 'lib/torch/tensor.rb', line 113

# Returns the single value held by a one-element tensor as a Ruby scalar.
#
# @return [Object] the only element of the flat data
# @raise [Error] when the tensor does not contain exactly one element
def item
  if numel != 1
    raise Error, "only one element tensors can be converted to Ruby scalars"
  end

  # Read the flat buffer rather than #to_a: to_a re-nests the data by shape,
  # so for a [1, 1] tensor to_a.first would be [x] instead of x.
  _flat_data.first
end

#layout ⇒ Object



38
39
40
# File 'lib/torch/tensor.rb', line 38

# Layout name as a lowercase symbol, derived from the native _layout value
# (e.g. a value of "Strided" would become :strided).
def layout
  lowered = _layout.downcase
  lowered.to_sym
end

#length ⇒ Object

Mirrors Python's len().



108
109
110
# File 'lib/torch/tensor.rb', line 108

# Number of entries along the first dimension, mirroring Python's len().
def length
  first_axis = 0
  size(first_axis)
end

#new ⇒ Object

Unsure whether this is correct.



129
130
131
# File 'lib/torch/tensor.rb', line 129

# Returns a zero-element tensor from Torch.empty with this tensor's dtype.
# (Original note: unsure if this is correct.)
# NOTE(review): PyTorch's Tensor.new() also keeps the source tensor's device;
# only the dtype is carried over here — confirm whether that matters.
def new
  element_type = dtype
  Torch.empty(0, dtype: element_type)
end

#numo ⇒ Object

TODO: read the data directly from memory.

Raises:

  • (Error)



134
135
136
137
138
# File 'lib/torch/tensor.rb', line 134

# Converts the tensor to a Numo array. The target class comes from
# Torch._dtype_to_numo; it is built from _data_str (presumably the raw bytes)
# and reshaped to this tensor's shape.
# TODO read directly from memory
#
# @raise [Error] when Torch._dtype_to_numo has no class for this dtype
def numo
  numo_class = Torch._dtype_to_numo[dtype]
  raise Error, "Cannot convert #{dtype} to Numo" unless numo_class

  flat = numo_class.from_string(_data_str)
  flat.reshape(*shape)
end

#random!(*args) ⇒ Object

The argument parser can’t handle the overlapping signatures, so they are dispatched manually.



175
176
177
178
# File 'lib/torch/tensor.rb', line 175

# Calls the native _random!. Its overloads overlap and the argument parser
# cannot tell them apart, so a single argument gets 0 prepended (presumably
# the lower bound); any other arity is forwarded unchanged.
def random!(*args)
  case args.length
  when 1 then _random!(0, *args)
  else _random!(*args)
  end
end

#real ⇒ Object

Not a method in native_functions.yaml; in Python this is an attribute rather than a method.



194
195
196
# File 'lib/torch/tensor.rb', line 194

# Real part, delegated to Torch.real. Defined by hand because in Python it is
# an attribute, not a method listed in native_functions.yaml.
def real
  source = self
  Torch.real(source)
end

#requires_grad=(requires_grad) ⇒ Object



140
141
142
# File 'lib/torch/tensor.rb', line 140

# Attribute-style setter (`tensor.requires_grad = true`) that forwards the
# flag to the native _requires_grad! and returns its result.
def requires_grad=(flag)
  _requires_grad!(flag)
end

#size(dim = nil) ⇒ Object



91
92
93
94
95
96
97
# File 'lib/torch/tensor.rb', line 91

# With a dimension, returns that dimension's extent (via _size); without one
# (nil or false), returns the whole shape.
def size(dim = nil)
  return shape unless dim

  _size(dim)
end

#stride(dim = nil) ⇒ Object



99
100
101
102
103
104
105
# File 'lib/torch/tensor.rb', line 99

# With a dimension, returns that dimension's stride (via _stride); without one
# (nil or false), returns all strides (via _strides).
def stride(dim = nil)
  dim ? _stride(dim) : _strides
end

#to(device = nil, dtype: nil, non_blocking: false, copy: false) ⇒ Object

Raises:

  • (Error)



67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# File 'lib/torch/tensor.rb', line 67

# Converts the tensor to a device and/or dtype through the native _to.
#
# A Symbol given as the positional argument with no dtype: keyword is treated
# as a dtype, so `to(:float64)` changes the type but keeps the device.
# Device strings are wrapped in Device.new. A missing device or dtype defaults
# to the tensor's current one.
#
# @param device [String, Device, Symbol, nil] target device (or dtype shorthand)
# @param dtype [Symbol, nil] target dtype name, looked up in DTYPE_TO_ENUM
# @param non_blocking [Boolean] forwarded to _to
# @param copy [Boolean] forwarded to _to
# @raise [Error] when the dtype has no DTYPE_TO_ENUM entry
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
  if !dtype && device.is_a?(Symbol)
    device, dtype = nil, device
  end

  target_device = device || self.device
  target_device = Device.new(target_device) if target_device.is_a?(String)

  target_dtype = dtype || self.dtype
  dtype_enum = DTYPE_TO_ENUM[target_dtype]
  raise Error, "Unknown type: #{target_dtype}" unless dtype_enum

  _to(target_device, dtype_enum, non_blocking, copy)
end

#to_a ⇒ Object

TODO: make this more performant.



55
56
57
58
59
60
61
62
63
64
65
# File 'lib/torch/tensor.rb', line 55

# Converts the tensor into nested Ruby arrays, one nesting level per
# dimension. A 0-dimensional tensor (empty shape) returns the flat data as-is.
# TODO make more performant
#
# @return [Array] the tensor data nested according to #shape
def to_a
  flat = _flat_data
  dims = shape
  return flat if dims.empty?

  # Group from the innermost dimension outward. each_slice(0) raises
  # ArgumentError, so a zero-sized inner dimension is built explicitly: one
  # empty array for every combination of the dimensions outside it.
  nested = flat
  (dims.size - 1).downto(1) do |axis|
    extent = dims[axis]
    nested =
      if extent.zero?
        Array.new(dims[0...axis].reduce(1, :*)) { [] }
      else
        nested.each_slice(extent).to_a
      end
  end
  nested.to_a
end

#to_f ⇒ Object



124
125
126
# File 'lib/torch/tensor.rb', line 124

# Scalar value as a Float, taken from #item (so the tensor must hold exactly
# one element).
def to_f
  scalar = item
  scalar.to_f
end

#to_i ⇒ Object



120
121
122
# File 'lib/torch/tensor.rb', line 120

# Scalar value as an Integer, taken from #item (so the tensor must hold
# exactly one element).
def to_i
  scalar = item
  scalar.to_i
end

#to_s ⇒ Object



42
43
44
# File 'lib/torch/tensor.rb', line 42

# String form of the tensor: delegates to #inspect, so `puts tensor` and
# string interpolation print the same text as `p tensor`.
def to_s
  inspect
end

#type(dtype) ⇒ Object



144
145
146
147
148
149
150
151
152
153
# File 'lib/torch/tensor.rb', line 144

# Converts the tensor to another type.
#
# @param dtype [Class, Object] either one of TENSOR_TYPE_CLASSES (instantiated
#   with self) or a dtype name looked up in DTYPE_TO_ENUM and passed to the
#   native _type
# @raise [Error] when the class or dtype name is not recognised
def type(dtype)
  case dtype
  when Class
    known_class = TENSOR_TYPE_CLASSES.include?(dtype)
    raise Error, "Invalid type: #{dtype}" unless known_class

    dtype.new(self)
  else
    type_enum = DTYPE_TO_ENUM[dtype]
    raise Error, "Invalid type: #{dtype}" unless type_enum

    _type(type_enum)
  end
end