Class: Torch::Tensor
- Inherits: Object
  - Object
    - Torch::Tensor
- Includes: Comparable, Inspector
- Defined in: lib/torch/tensor.rb
Direct Known Subclasses
Class Method Summary
Instance Method Summary
- #%(other) ⇒ Object
- #*(other) ⇒ Object
- #**(other) ⇒ Object
- #+(other) ⇒ Object
- #-(other) ⇒ Object
- #-@ ⇒ Object
- #/(other) ⇒ Object
- #<=>(other) ⇒ Object
- #[](*indexes) ⇒ Object
- #backward(gradient = nil) ⇒ Object
- #dtype ⇒ Object
- #item ⇒ Object
- #layout ⇒ Object
- #new_ones(*size, **options) ⇒ Object
- #numo ⇒ Object
  TODO: read directly from memory.
- #requires_grad!(requires_grad = true) ⇒ Object
- #shape ⇒ Object
- #size(dim = nil) ⇒ Object
- #to_a ⇒ Object
- #to_s ⇒ Object
- #type(dtype) ⇒ Object
- #view(*size) ⇒ Object
Methods included from Inspector
Class Method Details
Instance Method Details
#%(other) ⇒ Object
114 115 116 |
# File 'lib/torch/tensor.rb', line 114
# Remainder operator (+self % other+).
#
# @param other operand handed unchanged to #remainder
# @return the result of #remainder
def %(other)
  remainder(other)
end
#*(other) ⇒ Object
106 107 108 |
# File 'lib/torch/tensor.rb', line 106
# Multiplication operator (+self * other+).
#
# @param other operand handed unchanged to #mul
# @return the result of #mul
def *(other)
  mul(other)
end
#**(other) ⇒ Object
118 119 120 |
# File 'lib/torch/tensor.rb', line 118
# Power operator (+self ** other+).
#
# @param other exponent handed unchanged to #pow
# @return the result of #pow
def **(other)
  pow(other)
end
#+(other) ⇒ Object
98 99 100 |
# File 'lib/torch/tensor.rb', line 98
# Addition operator (+self + other+).
#
# @param other operand handed unchanged to #add
# @return the result of #add
def +(other)
  add(other)
end
#-(other) ⇒ Object
102 103 104 |
# File 'lib/torch/tensor.rb', line 102
# Subtraction operator (+self - other+).
#
# @param other operand handed unchanged to #sub
# @return the result of #sub
def -(other)
  sub(other)
end
#-@ ⇒ Object
122 123 124 |
# File 'lib/torch/tensor.rb', line 122
# Unary minus (+-self+).
#
# @return the result of #neg
def -@
  neg
end
#/(other) ⇒ Object
110 111 112 |
# File 'lib/torch/tensor.rb', line 110
# Division operator (+self / other+).
#
# @param other divisor handed unchanged to #div
# @return the result of #div
def /(other)
  div(other)
end
#<=>(other) ⇒ Object
126 127 128 |
# File 'lib/torch/tensor.rb', line 126
# Three-way comparison of this tensor's scalar value (#item) with +other+;
# the operators supplied by the included Comparable module are built on it.
#
# If +other+ responds to #item (e.g. another single-element tensor), its
# scalar value is compared rather than the object itself, so two
# single-element tensors can be compared directly instead of yielding nil.
#
# @param other [Object] a plain Ruby value, or an object responding to #item
# @return [Integer, nil] -1, 0 or 1, or nil when the values are not comparable
# @raise [Error] (from #item) if this tensor, or a tensor passed as +other+,
#   does not hold exactly one element
def <=>(other)
  # Evaluate self first, as before, so errors from self surface first.
  mine = item
  theirs = other.respond_to?(:item) ? other.item : other
  mine <=> theirs
end
#[](*indexes) ⇒ Object
130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
# File 'lib/torch/tensor.rb', line 130
# Index the tensor with one index per leading dimension, e.g. t[0], t[1..2, 0].
#
# A Numeric index selects a single position with #_select, which removes that
# dimension from the result. A Range index slices with #_slice (step 1) and
# keeps the dimension. Ranges follow Ruby's conventions:
# - inclusive (a..b) and exclusive (a...b) ends,
# - a beginless range (..b) starts at 0,
# - an endless range (a..) runs to the end of the dimension,
# - a negative end counts back from the dimension's size, so 0..-1 covers the
#   whole dimension rather than producing an empty slice.
# A negative begin is passed to #_slice unchanged.
#
# @param indexes [Array<Numeric, Range>] one index per dimension
# @return the result of the last _select/_slice, or self when no index is given
# @raise [Error] if an index is neither Numeric nor Range
def [](*indexes)
  result = self
  dim = 0
  indexes.each do |index|
    case index
    when Numeric
      # The selected dimension disappears, so +dim+ already points at the next one.
      result = result._select(dim, index)
    when Range
      start = index.begin || 0
      finish = index.end
      if finish.nil?
        finish = result.size(dim)
      else
        # Resolve negative ends before the inclusive +1, otherwise -1 + 1 == 0.
        finish += result.size(dim) if finish.negative?
        finish += 1 unless index.exclude_end?
      end
      result = result._slice(dim, start, finish, 1)
      dim += 1
    else
      raise Error, "Unsupported index type"
    end
  end
  result
end
#backward(gradient = nil) ⇒ Object
57 58 59 60 61 62 63 |
# File 'lib/torch/tensor.rb', line 57
# Run the backward pass from this tensor.
#
# Without an argument (or with a falsy one) this calls _backward; otherwise
# +gradient+ is handed to _backward_gradient.
#
# @param gradient [Object, nil] optional gradient passed to _backward_gradient
# @return the result of the native call that was made
def backward(gradient = nil)
  return _backward unless gradient

  _backward_gradient(gradient)
end
#dtype ⇒ Object
16 17 18 19 20 |
# File 'lib/torch/tensor.rb', line 16
# Translate this tensor's native type code (_dtype) into its Ruby dtype.
#
# @return the ENUM_TO_DTYPE entry for _dtype
# @raise [Error] if ENUM_TO_DTYPE has no truthy entry for _dtype
def dtype
  mapped = ENUM_TO_DTYPE[_dtype]
  return mapped if mapped

  raise Error, "Unknown type: #{_dtype}"
end
#item ⇒ Object
50 51 52 53 54 55 |
# File 'lib/torch/tensor.rb', line 50
# Extract the value of a one-element tensor as a Ruby scalar.
#
# @return the first (and only) entry of _data
# @raise [Error] unless numel is exactly 1
def item
  raise Error, "only one element tensors can be converted to Ruby scalars" unless numel == 1

  _data.first
end
#layout ⇒ Object
22 23 24 |
# File 'lib/torch/tensor.rb', line 22
# The tensor's layout as a lowercase symbol.
#
# @return [Symbol] the value of _layout, downcased and symbolized
def layout
  layout_name = _layout
  layout_name.downcase.to_sym
end
#new_ones(*size, **options) ⇒ Object
73 74 75 |
# File 'lib/torch/tensor.rb', line 73
# Create a tensor of ones with the given size.
#
# A tensor built by Torch.empty(*size) serves as the template for
# Torch.ones_like; all keyword +options+ are forwarded unchanged.
#
# @param size [Array] dimensions, forwarded to Torch.empty
# @param options [Hash] keyword options passed through to Torch.ones_like
# @return the result of Torch.ones_like
def new_ones(*size, **options)
  template = Torch.empty(*size)
  Torch.ones_like(template, **options)
end
#numo ⇒ Object
TODO: read directly from memory.
66 67 68 69 70 71 |
# File 'lib/torch/tensor.rb', line 66
# Convert the tensor into a Numo array with the same shape.
#
# TODO: read directly from memory (this currently goes through _data).
#
# @return a Numo array of the class mapped from #dtype, reshaped to #shape
# @raise [Error] if Numo::NArray is not loaded, or #dtype has no Numo class
def numo
  raise Error, "Numo not found" unless defined?(Numo::NArray)

  numo_class = Torch._dtype_to_numo[dtype]
  raise Error, "Cannot convert #{dtype} to Numo" unless numo_class

  flat = numo_class.cast(_data)
  flat.reshape(*shape)
end
#requires_grad!(requires_grad = true) ⇒ Object
77 78 79 |
# File 'lib/torch/tensor.rb', line 77
# Set the tensor's requires-grad flag through _requires_grad!.
#
# @param requires_grad [Boolean] flag passed to _requires_grad! (defaults to true)
# @return the result of _requires_grad!
def requires_grad!(requires_grad = true)
  _requires_grad!(requires_grad)
end
#shape ⇒ Object
42 43 44 |
# File 'lib/torch/tensor.rb', line 42
# Sizes of all dimensions, in order.
#
# @return [Array] size(axis) for every axis from 0 up to dim - 1
def shape
  (0...dim).map { |axis| size(axis) }
end
#size(dim = nil) ⇒ Object
34 35 36 37 38 39 40 |
# File 'lib/torch/tensor.rb', line 34
# Size of a single dimension, or of all dimensions.
#
# @param dim [Integer, nil] dimension to query; when nil (or false) the full #shape is returned
# @return [Integer, Array] _size(dim) when +dim+ is given, otherwise #shape
def size(dim = nil)
  return shape unless dim

  _size(dim)
end
#to_a ⇒ Object
30 31 32 |
# File 'lib/torch/tensor.rb', line 30
# Convert the tensor to a Ruby array.
#
# The element data from _data and the tensor's #shape are handed to
# reshape_arr (defined elsewhere), which presumably nests the flat data
# to match the shape.
#
# @return the result of reshape_arr
def to_a
  flat = _data
  reshape_arr(flat, shape)
end
#to_s ⇒ Object
26 27 28 |
# File 'lib/torch/tensor.rb', line 26
# String form of the tensor; the same text as #inspect.
#
# @return the result of #inspect
def to_s
  inspect
end
#type(dtype) ⇒ Object
81 82 83 84 85 |
# File 'lib/torch/tensor.rb', line 81
# Request a tensor of type +dtype+ through _type.
#
# @param dtype key looked up in DTYPE_TO_ENUM
# @return the result of _type called with the mapped enum value
# @raise [Error] if DTYPE_TO_ENUM has no truthy entry for +dtype+
def type(dtype)
  if (code = DTYPE_TO_ENUM[dtype])
    _type(code)
  else
    raise Error, "Unknown type: #{dtype}"
  end
end
#view(*size) ⇒ Object
46 47 48 |
# File 'lib/torch/tensor.rb', line 46
# Reshape through _view, which receives the requested sizes as one array.
#
# @param size [Array] requested dimensions, given as separate arguments, e.g. view(2, 3)
# @return the result of _view
def view(*size)
  requested = size
  _view(requested)
end