Class: Torch::Tensor

Inherits: Object
Includes:
Comparable, Inspector
Defined in:
lib/torch/tensor.rb

Direct Known Subclasses

NN::Parameter

Class Method Summary

Instance Method Summary

Methods included from Inspector

#inspect

Class Method Details

.new(*size) ⇒ Object



8
9
10
11
12
13
14
# File 'lib/torch/tensor.rb', line 8

# Builds a tensor from +size+.
#
# If the first argument is already a Tensor it is returned unchanged (the
# same object, not a copy); otherwise every argument is forwarded to
# Torch.rand as dimensions.
def self.new(*size)
  existing = size.first
  return existing if existing.is_a?(Tensor)

  Torch.rand(*size)
end

Instance Method Details

#%(other) ⇒ Object



114
115
116
# File 'lib/torch/tensor.rb', line 114

# Remainder operator: +tensor % divisor+ delegates to #remainder.
def %(divisor)
  remainder(divisor)
end

#*(other) ⇒ Object



106
107
108
# File 'lib/torch/tensor.rb', line 106

# Multiplication operator: +tensor * factor+ delegates to #mul.
def *(factor)
  mul(factor)
end

#**(other) ⇒ Object



118
119
120
# File 'lib/torch/tensor.rb', line 118

# Power operator: +tensor ** exponent+ delegates to #pow.
def **(exponent)
  pow(exponent)
end

#+(other) ⇒ Object



98
99
100
# File 'lib/torch/tensor.rb', line 98

# Addition operator: +tensor + addend+ delegates to #add.
def +(addend)
  add(addend)
end

#-(other) ⇒ Object



102
103
104
# File 'lib/torch/tensor.rb', line 102

# Subtraction operator: +tensor - subtrahend+ delegates to #sub.
def -(subtrahend)
  sub(subtrahend)
end

#-@ ⇒ Object



122
123
124
# File 'lib/torch/tensor.rb', line 122

# Unary minus: +-tensor+ delegates to #neg.
def -@
  neg
end

#/(other) ⇒ Object



110
111
112
# File 'lib/torch/tensor.rb', line 110

# Division operator: +tensor / divisor+ delegates to #div.
def /(divisor)
  div(divisor)
end

#<=>(other) ⇒ Object



126
127
128
# File 'lib/torch/tensor.rb', line 126

# Compares this tensor's scalar value (#item) with +other+.
#
# This backs the Comparable mixin (<, >, between?, clamp, ==). When +other+
# is also a Tensor it is unwrapped with #item first. Two scalar tensors
# therefore compare by value, rather than depending on Numeric#<=> accepting
# a Tensor argument.
#
# @param other [Tensor, Numeric]
# @return [Integer, nil] result of the scalar comparison
# @raise [Error] (from #item) if either tensor has more than one element
def <=>(other)
  other = other.item if other.is_a?(Tensor)
  item <=> other
end

#[](*indexes) ⇒ Object



130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# File 'lib/torch/tensor.rb', line 130

# Indexes the tensor along successive dimensions.
#
# Each argument consumes one dimension:
# * a Numeric selects one position via +_select+. That dimension disappears
#   from the result, so the dimension counter is not advanced.
# * a Range slices via +_slice+ with step 1 and keeps the dimension. A
#   missing begin means 0 and a missing end means the full length. A
#   negative end is normalized against the current dimension length before
#   an inclusive end is bumped by one, so +0..-1+ covers the whole dimension
#   instead of producing an empty slice.
#
# @param indexes [Array<Numeric, Range>]
# @return [Tensor]
# @raise [Error] if an index is neither Numeric nor Range
def [](*indexes)
  result = self
  dim = 0
  indexes.each do |index|
    if index.is_a?(Numeric)
      result = result._select(dim, index)
    elsif index.is_a?(Range)
      start = index.begin || 0
      finish = index.end
      if finish.nil?
        finish = result.size(dim)
      else
        # Normalize first: -1 inclusive must become size(dim), not 0.
        finish += result.size(dim) if finish < 0
        finish += 1 unless index.exclude_end?
      end
      result = result._slice(dim, start, finish, 1)
      dim += 1
    else
      raise Error, "Unsupported index type"
    end
  end
  result
end

#backward(gradient = nil) ⇒ Object



57
58
59
60
61
62
63
# File 'lib/torch/tensor.rb', line 57

# Starts the backward pass from this tensor.
#
# @param gradient [Object, nil] when truthy, passed to +_backward_gradient+;
#   otherwise +_backward+ is called with no argument.
def backward(gradient = nil)
  return _backward unless gradient

  _backward_gradient(gradient)
end

#dtype ⇒ Object

Raises: Error



16
17
18
19
20
# File 'lib/torch/tensor.rb', line 16

# Looks up the internal dtype enum (+_dtype+) in ENUM_TO_DTYPE.
#
# @return [Object] the ENUM_TO_DTYPE entry for this tensor
# @raise [Error] if ENUM_TO_DTYPE has no truthy entry for +_dtype+
def dtype
  ENUM_TO_DTYPE[_dtype] || raise(Error, "Unknown type: #{_dtype}")
end

#item ⇒ Object



50
51
52
53
54
55
# File 'lib/torch/tensor.rb', line 50

# Returns the single value held by a one-element tensor (first of +_data+).
#
# @raise [Error] unless #numel is exactly 1
def item
  raise Error, "only one element tensors can be converted to Ruby scalars" unless numel == 1

  _data.first
end

#layout ⇒ Object



22
23
24
# File 'lib/torch/tensor.rb', line 22

# Returns the layout name from +_layout+ as a lowercase Symbol
# (e.g. a +_layout+ of "Strided" would yield :strided).
def layout
  name = _layout.downcase
  name.to_sym
end

#new_ones(*size, **options) ⇒ Object



73
74
75
# File 'lib/torch/tensor.rb', line 73

# Builds a new tensor of the given +size+ via Torch.ones_like.
#
# A throwaway tensor from Torch.empty serves only as the template passed to
# Torch.ones_like; +options+ are forwarded to ones_like unchanged.
# NOTE(review): dtype/device of +self+ are not propagated, unlike PyTorch's
# Tensor.new_ones. Confirm this is intended.
def new_ones(*size, **options)
  template = Torch.empty(*size)
  Torch.ones_like(template, **options)
end

#numo ⇒ Object

TODO read directly from memory

Raises: Error



66
67
68
69
70
71
# File 'lib/torch/tensor.rb', line 66

# Converts the tensor to a Numo array.
#
# The flat +_data+ is cast with the class that Torch._dtype_to_numo maps
# #dtype to, and the result is then reshaped to #shape.
# TODO read directly from memory
#
# @raise [Error] if Numo::NArray is not loaded, or #dtype has no Numo mapping
def numo
  raise Error, "Numo not found" unless defined?(Numo::NArray)

  numo_class = Torch._dtype_to_numo[dtype]
  raise Error, "Cannot convert #{dtype} to Numo" unless numo_class

  flat = numo_class.cast(_data)
  flat.reshape(*shape)
end

#requires_grad!(requires_grad = true) ⇒ Object



77
78
79
# File 'lib/torch/tensor.rb', line 77

# Sets the requires-grad flag through +_requires_grad!+.
#
# @param requires_grad [Boolean] flag to set (defaults to true)
def requires_grad!(requires_grad = true)
  flag = requires_grad
  _requires_grad!(flag)
end

#shape ⇒ Object



42
43
44
# File 'lib/torch/tensor.rb', line 42

# Returns the length of every dimension, in order, as an Array
# (one #size(axis) call per axis, for #dim axes).
def shape
  Array.new(dim) { |axis| size(axis) }
end

#size(dim = nil) ⇒ Object



34
35
36
37
38
39
40
# File 'lib/torch/tensor.rb', line 34

# Returns the length of dimension +dim+ (via +_size+) or, when +dim+ is
# nil/false, the full #shape Array.
def size(dim = nil)
  dim ? _size(dim) : shape
end

#to_a ⇒ Object



30
31
32
# File 'lib/torch/tensor.rb', line 30

# Returns the flat +_data+ rearranged to #shape by +reshape_arr+.
def to_a
  flat = _data
  reshape_arr(flat, shape)
end

#to_s ⇒ Object



26
27
28
# File 'lib/torch/tensor.rb', line 26

# String form of the tensor; identical to #inspect.
def to_s
  inspect
end

#type(dtype) ⇒ Object

Raises: Error



81
82
83
84
85
# File 'lib/torch/tensor.rb', line 81

# Converts the tensor to +dtype+ by passing its DTYPE_TO_ENUM code to +_type+.
#
# @param dtype [Object] a key of DTYPE_TO_ENUM
# @raise [Error] if DTYPE_TO_ENUM has no truthy entry for +dtype+
def type(dtype)
  enum_code = DTYPE_TO_ENUM[dtype] || raise(Error, "Unknown type: #{dtype}")
  _type(enum_code)
end

#view(*size) ⇒ Object



46
47
48
# File 'lib/torch/tensor.rb', line 46

# Returns a view of the tensor with the given dimensions (via +_view+).
#
# Dimensions may be given as separate arguments (+view(2, 3)+) or as a
# single Array (+view([2, 3])+), mirroring the two call forms of PyTorch's
# Tensor.view. The single-Array form is unwrapped so +_view+ always
# receives a flat Array of dimensions.
def view(*size)
  size = size.first if size.length == 1 && size.first.is_a?(Array)
  _view(size)
end