Class: Torch::Tensor

Inherits:
Object
  • Object
show all
Includes:
Comparable, Enumerable, Inspector
Defined in:
lib/torch/tensor.rb

Direct Known Subclasses

NN::Parameter

Constant Summary

Constants included from Inspector

Inspector::PRINT_OPTS

Class Method Summary collapse

Instance Method Summary collapse

Methods included from Inspector

#inspect

Class Method Details

.new(*args) ⇒ Object



23
24
25
# File 'lib/torch/tensor.rb', line 23

# Tensor.new is a thin factory: every argument is handed, unchanged and in
# order, to FloatTensor.new and that object is returned.
# NOTE(review): FloatTensor is defined elsewhere; presumably it fixes a float
# dtype — confirm.
def self.new(*params)
  FloatTensor.new(*params)
end

Instance Method Details

#<=>(other) ⇒ Object

TODO: better comparison? (Currently only the scalar #item value is compared.)



154
155
156
# File 'lib/torch/tensor.rb', line 154

# Three-way comparison used by Comparable: compares this tensor's scalar
# value (#item) with +other+.
#
# +other+ may be a plain Ruby value or another Tensor. A Tensor operand is
# unwrapped with #item too. Otherwise e.g. Integer#<=> would get a Tensor,
# return nil, and Comparable would raise "comparison ... failed".
#
# @param other [Object, Tensor] value to compare against
# @return [Integer, nil] -1, 0, 1, or nil when the values are not comparable
# @raise [Error] (from #item) unless the tensor(s) hold exactly one element
def <=>(other)
  other = other.item if other.is_a?(Tensor)
  item <=> other
end

#[](*indexes) ⇒ Object

Based on python_variable_indexing.cpp and https://pytorch.org/cppdocs/notes/tensor_indexing.html



160
161
162
# File 'lib/torch/tensor.rb', line 160

# Indexed read access. All index arguments are packed into one array and
# passed to the native _index binding, which does the actual lookup.
def [](*indices)
  _index(indices)
end

#[]=(*indexes, value) ⇒ Object

Based on python_variable_indexing.cpp and https://pytorch.org/cppdocs/notes/tensor_indexing.html

Raises:

  • (ArgumentError)


166
167
168
169
170
# File 'lib/torch/tensor.rb', line 166

# Indexed write access. A nil value is rejected, because Tensor has no
# notion of deleting elements. A non-Tensor value is first wrapped with
# Torch.tensor using this tensor's dtype. The indices (as an array) and the
# tensor value then go to the native _index_put_custom binding.
#
# @raise [ArgumentError] when +value+ is nil
def []=(*indices, value)
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?

  tensor_value = value.is_a?(Tensor) ? value : Torch.tensor(value, dtype: dtype)
  _index_put_custom(indices, tensor_value)
end

#cpu ⇒ Object



78
79
80
# File 'lib/torch/tensor.rb', line 78

# Shorthand for #to with the "cpu" device name; returns whatever #to returns.
def cpu
  target = "cpu"
  to(target)
end

#cuda ⇒ Object



82
83
84
# File 'lib/torch/tensor.rb', line 82

# Shorthand for #to with the "cuda" device name; returns whatever #to returns.
def cuda
  target = "cuda"
  to(target)
end

#dtype ⇒ Object

Raises:

  • (Error)



27
28
29
30
31
# File 'lib/torch/tensor.rb', line 27

# Maps the native dtype enum (from _dtype) to its Ruby dtype via
# ENUM_TO_DTYPE.
#
# @raise [Error] when the enum has no (truthy) entry in ENUM_TO_DTYPE
def dtype
  ENUM_TO_DTYPE[_dtype] || raise(Error, "Unknown type: #{_dtype}")
end

#each ⇒ Object



41
42
43
44
45
46
47
# File 'lib/torch/tensor.rb', line 41

# Iterates over the first dimension, yielding self[i] for each i in
# 0...size(0). Without a block it returns an Enumerator; with a block it
# returns size(0), the value returned by Integer#times.
def each
  return enum_for(:each) unless block_given?

  rows = size(0)
  rows.times { |row| yield self[row] }
end

#item ⇒ Object



107
108
109
110
111
112
# File 'lib/torch/tensor.rb', line 107

# Returns the single element of a one-element tensor as a Ruby scalar.
#
# #to_a nests according to shape, so e.g. a shape [1, 1] tensor gives [[x]].
# Flattening before taking the first element returns x for every rank,
# instead of a nested array.
#
# @return [Object] the tensor's only element
# @raise [Error] when the tensor does not have exactly one element
def item
  if numel != 1
    raise Error, "only one element tensors can be converted to Ruby scalars"
  end
  to_a.flatten.first
end

#layout ⇒ Object



33
34
35
# File 'lib/torch/tensor.rb', line 33

# Returns the native layout name (from _layout) as a lowercase symbol.
def layout
  native_name = _layout
  native_name.downcase.to_sym
end

#length ⇒ Object

Mirrors Python's len().



103
104
105
# File 'lib/torch/tensor.rb', line 103

# Python-style len(): the extent of the leading dimension, i.e. size(0).
def length
  size(0)
end

#new ⇒ Object

Unsure whether this is correct.



123
124
125
# File 'lib/torch/tensor.rb', line 123

# Returns a new empty tensor (zero elements along one dimension) that has
# this tensor's dtype, built with Torch.empty.
def new
  current_dtype = dtype
  Torch.empty(0, dtype: current_dtype)
end

#new_ones(*size, **options) ⇒ Object



134
135
136
# File 'lib/torch/tensor.rb', line 134

# Returns a tensor of ones with the given +size+.
#
# Like #new, the result defaults to this tensor's dtype. Torch.empty(*size)
# is called without options, so without this default the dtype was not
# inherited. A +dtype:+ passed in +options+ still takes precedence.
# NOTE(review): the device is not inherited; pass +device:+ explicitly if
# required.
#
# @param size [Array<Integer>] dimensions of the new tensor
# @param options [Hash] forwarded to Torch.ones_like (e.g. dtype:)
def new_ones(*size, **options)
  options = {dtype: dtype}.merge(options)
  Torch.ones_like(Torch.empty(*size), **options)
end

#numo ⇒ Object

TODO: read directly from memory.

Raises:

  • (Error)



128
129
130
131
132
# File 'lib/torch/tensor.rb', line 128

# Converts the tensor to a Numo array. The Numo class for this dtype comes
# from Torch._dtype_to_numo; it is filled from the raw byte string
# (_data_str) and reshaped to this tensor's shape.
#
# @raise [Error] when there is no Numo class for the dtype
def numo
  numo_class = Torch._dtype_to_numo[dtype]
  raise Error, "Cannot convert #{dtype} to Numo" unless numo_class

  flat = numo_class.from_string(_data_str)
  flat.reshape(*shape)
end

#random!(*args) ⇒ Object

The argument parser can’t handle the overlapping signatures, so they are handled manually.



173
174
175
176
# File 'lib/torch/tensor.rb', line 173

# In-place random fill. A single argument is treated as the upper bound,
# with 0 as the lower bound. Any other argument list is forwarded to
# _random! as-is, because the native parser can't disambiguate the
# overlapping signatures.
def random!(*args)
  bounds = args.length == 1 ? [0, *args] : args
  _random!(*bounds)
end

#requires_grad!(requires_grad = true) ⇒ Object



138
139
140
# File 'lib/torch/tensor.rb', line 138

# Sets the requires-grad flag (default true) through the native
# _requires_grad! binding and returns its result.
def requires_grad!(requires_grad = true)
  flag = requires_grad
  _requires_grad!(flag)
end

#size(dim = nil) ⇒ Object



86
87
88
89
90
91
92
# File 'lib/torch/tensor.rb', line 86

# With no argument, returns the full shape. With +dim+ (any non-nil value,
# including 0), returns that dimension's extent via _size.
def size(dim = nil)
  return shape unless dim

  _size(dim)
end

#stft(*args) ⇒ Object

center option



179
180
181
# File 'lib/torch/tensor.rb', line 179

# Short-time Fourier transform of this tensor: x.stft(n_fft, ...) is
# equivalent to Torch.stft(x, n_fft, ...).
#
# The receiver is passed as Torch.stft's input. Previously it was omitted,
# so x.stft(n_fft) called Torch.stft(n_fft). For backward compatibility,
# a call whose first argument is already a Tensor (e.g. x.stft(x, n_fft))
# is forwarded unchanged. n_fft is never a Tensor, so this is unambiguous.
def stft(*args)
  return Torch.stft(*args) if args.first.is_a?(Tensor)

  Torch.stft(self, *args)
end

#stride(dim = nil) ⇒ Object



94
95
96
97
98
99
100
# File 'lib/torch/tensor.rb', line 94

# With +dim+ (any non-nil value), returns that dimension's stride via
# _stride. Otherwise returns all strides via _strides.
def stride(dim = nil)
  dim ? _stride(dim) : _strides
end

#to(device = nil, dtype: nil, non_blocking: false, copy: false) ⇒ Object

Raises:

  • (Error)



62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# File 'lib/torch/tensor.rb', line 62

# Converts/moves the tensor to a device and/or dtype via the native _to.
#
# A lone Symbol positional argument (with no dtype: keyword) is treated as
# a dtype, so to(:float64) works. A missing device or dtype falls back to
# the tensor's current one, and a String device is wrapped in Device.new.
#
# @raise [Error] when the resolved dtype has no entry in DTYPE_TO_ENUM
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
  if device.is_a?(Symbol) && !dtype
    device, dtype = nil, device
  end

  target_device = device || self.device
  target_device = Device.new(target_device) if target_device.is_a?(String)

  target_dtype = dtype || self.dtype
  enum = DTYPE_TO_ENUM[target_dtype]
  raise Error, "Unknown type: #{target_dtype}" unless enum

  _to(target_device, enum, non_blocking, copy)
end

#to_a ⇒ Object

TODO: make more performant.



50
51
52
53
54
55
56
57
58
59
60
# File 'lib/torch/tensor.rb', line 50

# Converts the tensor to nested Ruby arrays that follow its shape.
#
# The flat element list from _flat_data is regrouped with each_slice, from
# the innermost dimension outwards. When shape is empty (0-dim), the flat
# list is returned as-is.
#
# each_slice(0) raises ArgumentError. So when any non-leading dimension is
# zero (the tensor then has no elements), the result is built directly as
# nested empty arrays, e.g. shape [2, 0] => [[], []].
def to_a
  arr = _flat_data
  dims = shape
  return arr if dims.empty?

  inner = dims.drop(1)
  if inner.include?(0)
    # recursion stops at the first zero-size dimension, which is guaranteed to exist here
    build = lambda do |ds|
      ds.first.zero? ? [] : Array.new(ds.first) { build.call(ds.drop(1)) }
    end
    return build.call(dims)
  end

  inner.reverse_each { |dim| arr = arr.each_slice(dim) }
  arr.to_a
end

#to_f ⇒ Object



118
119
120
# File 'lib/torch/tensor.rb', line 118

# Float value of the tensor's single element (see #item).
def to_f
  scalar = item
  scalar.to_f
end

#to_i ⇒ Object



114
115
116
# File 'lib/torch/tensor.rb', line 114

# Integer value of the tensor's single element (see #item).
def to_i
  scalar = item
  scalar.to_i
end

#to_s ⇒ Object



37
38
39
# File 'lib/torch/tensor.rb', line 37

# String form of the tensor: identical to #inspect.
def to_s
  text = inspect
  text
end

#type(dtype) ⇒ Object



142
143
144
145
146
147
148
149
150
151
# File 'lib/torch/tensor.rb', line 142

# Casts the tensor.
#
# Given a Class, it must be listed in TENSOR_TYPE_CLASSES, and the result is
# that class instantiated with this tensor. Given anything else, the value
# is looked up in DTYPE_TO_ENUM and the cast goes through the native _type.
#
# @raise [Error] for an unlisted class or an unknown dtype
def type(dtype)
  case dtype
  when Class
    raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
    dtype.new(self)
  else
    enum = DTYPE_TO_ENUM[dtype]
    raise Error, "Invalid type: #{dtype}" unless enum
    _type(enum)
  end
end