Class: RubyZero::Core::Functions::Repeat
Defined in: lib/rubyzero/core/functions/tensor_functions.rb
Instance Attribute Summary
Attributes inherited from Function
Instance Method Summary
- #backward(dy) ⇒ Object
- #forward(x1) ⇒ Object
- #initialize(axis, repeats) ⇒ Repeat (constructor): A new instance of Repeat.
Methods inherited from Function
Constructor Details
#initialize(axis, repeats) ⇒ Repeat
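Returns a new instance of Repeat. `axis` is the position of the new axis that holds the copies, and `repeats` is the number of copies.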
    # File 'lib/rubyzero/core/functions/tensor_functions.rb', line 32
    def initialize(axis, repeats)
      @axis = axis
      @repeats = repeats
    end
Instance Method Details
#backward(dy) ⇒ Object
    # File 'lib/rubyzero/core/functions/tensor_functions.rb', line 44
    def backward(dy)
      # one gradient per input: sum the upstream gradient over the repeated axis
      return [ dy.sum(axis: @axis) ]
    end
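The sum in `backward` follows from the chain rule: `forward` copies each input element into `repeats` output positions, so the gradient of an input element is the sum of the upstream gradients at those positions. Below is a minimal pure-Ruby sketch of that argument, assuming a 1-D input, `axis = 0`, and plain nested Arrays in place of RubyZero tensors; `repeat_forward`, `repeat_backward`, and `loss` are hypothetical names, not RubyZero API. A central finite-difference check compares the summed gradient against a numerical one.

```ruby
# Hypothetical helpers illustrating Repeat with axis = 0 on a 1-D input.
# Plain Ruby Arrays stand in for the tensor's underlying array data.

# Forward: stack `repeats` copies of x along a new leading axis,
# so shape [n] becomes [repeats, n].
def repeat_forward(x, repeats)
  Array.new(repeats) { x.dup }
end

# Backward: x[i] was copied into y[k][i] for every k, so the chain rule
# gives dL/dx[i] = sum over k of dy[k][i], i.e. a sum over axis 0.
def repeat_backward(dy)
  dy.transpose.map(&:sum)
end

# Scalar loss L(x) = sum over k, i of dy[k][i] * y[k][i]; its exact
# gradient with respect to x is repeat_backward(dy).
def loss(x, dy)
  y = repeat_forward(x, dy.size)
  y.zip(dy).sum { |y_row, dy_row| y_row.zip(dy_row).sum { |a, b| a * b } }
end

x = [1.0, 2.0, 3.0]
dy = [[0.5, 1.0, 1.5], [2.0, 2.0, 2.0]] # upstream gradient, same shape as y

analytic = repeat_backward(dy) # => [2.5, 3.0, 3.5]

# Central finite differences as an independent check of the analytic gradient.
eps = 1e-6
numeric = x.each_index.map do |i|
  xp = x.dup
  xp[i] += eps
  xm = x.dup
  xm[i] -= eps
  (loss(xp, dy) - loss(xm, dy)) / (2 * eps)
end

p analytic
p numeric
```

Because `loss` is linear in `x`, the two gradients agree up to floating-point rounding.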
#forward(x1) ⇒ Object
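    # File 'lib/rubyzero/core/functions/tensor_functions.rb', line 36
    def forward(x1)
      arr = x1.data
      # prepend a length-1 axis: shape [*s] -> [1, *s]
      arr = arr.reshape(*([1] + arr.shape))
      # repeat along that axis: [1, *s] -> [repeats, *s]
      arr = arr.repeat(@repeats, axis:0)
      # exchange axis 0 and @axis so the copies sit at position @axis
      arr = arr.swapaxes(0, @axis)
      new_t = RubyZero::Core::Tensor.new(arr, device: x1.device)
      return new_t
    end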
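The reshape, repeat, swapaxes sequence in `forward` can be traced on shapes alone. This is a minimal pure-Ruby sketch, assuming that `reshape`, `repeat(n, axis: 0)`, and `swapaxes(0, k)` on `x1.data` follow the usual NumPy-style semantics (prepend a length-1 axis, repeat it `n` times, exchange two axes). `forward_shape` and `inserted_shape` are hypothetical helpers, not RubyZero API; the second one computes the shape produced by inserting a new axis of length `repeats` at position `axis`, for comparison.

```ruby
# Hypothetical helper: trace the output shape of Repeat#forward,
# assuming NumPy-style reshape / repeat / swapaxes semantics.
def forward_shape(shape, axis, repeats)
  s = [1] + shape               # arr.reshape(*([1] + arr.shape))
  s[0] *= repeats               # arr.repeat(@repeats, axis: 0)
  s[0], s[axis] = s[axis], s[0] # arr.swapaxes(0, @axis)
  s
end

# Hypothetical helper: shape after inserting a new axis of length
# `repeats` at position `axis`, leaving the other axes in order.
def inserted_shape(shape, axis, repeats)
  shape.dup.insert(axis, repeats)
end

p forward_shape([2, 3], 0, 4)  # [4, 2, 3]
p forward_shape([2, 3], 1, 4)  # [2, 4, 3]
p forward_shape([2, 3], 2, 4)  # [3, 2, 4]
p inserted_shape([2, 3], 2, 4) # [2, 3, 4]
```

Under that assumption the two helpers agree when `axis` is 0 or 1: the output is the input with a new axis of length `repeats` at position `axis`, and `sum(axis: @axis)` in `backward` recovers the input shape. For `axis` of 2 or more, `swapaxes(0, @axis)` also moves input axis `axis - 1` to the front, so the output is not a plain insertion and the summed gradient's shape can differ from the input's (here `[3, 2]` instead of `[2, 3]`).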