Class: RubyZero::Core::Functions::Repeat

Inherits:
Function (superclass chain: Repeat < Function < Object)
Defined in:
lib/rubyzero/core/functions/tensor_functions.rb

Instance Attribute Summary

Attributes inherited from Function

#inputs, #output

Instance Method Summary

#backward(dy) ⇒ Object
#forward(x1) ⇒ Object

Methods inherited from Function

#call, #inspect, plot

Constructor Details

#initialize(axis, repeats) ⇒ Repeat

Returns a new instance of Repeat.



Source lines 32–35:
# File 'lib/rubyzero/core/functions/tensor_functions.rb', line 32

# Builds a Repeat function node.
#
# @param axis position in the output at which the repeated copies are laid
#   out (forward places the repeated axis there; backward sums over it)
# @param repeats number of copies forward produces along that axis
#   (passed straight to the array's +repeat+)
def initialize(axis, repeats)
    @axis, @repeats = axis, repeats
end

Instance Method Details

#backward(dy) ⇒ Object



Source lines 44–46:
# File 'lib/rubyzero/core/functions/tensor_functions.rb', line 44

# Gradient of Repeat: forward copied the input along +@axis+, so the
# upstream gradient is reduced back by summing over that same axis.
#
# @param dy upstream gradient; must respond to +sum(axis:)+
# @return [Array] one-element array holding the gradient for the single input
def backward(dy)
    grad_x1 = dy.sum(axis: @axis)
    [grad_x1]
end

#forward(x1) ⇒ Object



Source lines 36–43:
# File 'lib/rubyzero/core/functions/tensor_functions.rb', line 36

# Repeats +x1+ +@repeats+ times along a new axis inserted at position +@axis+.
#
# For an input of shape (d0, ..., d(n-1)) the result has shape
# (d0, ..., d(@axis-1), @repeats, d(@axis), ..., d(n-1)): every input axis
# keeps its relative order, and summing the result over +@axis+ (as #backward
# does) yields a gradient with the input's shape.
#
# @param x1 [RubyZero::Core::Tensor] input tensor; its +data+ is assumed to be a
#   Numo::NArray-like array (responds to +shape+, +reshape+, +repeat+) — TODO confirm
# @return [RubyZero::Core::Tensor] new tensor on the same device as +x1+
def forward(x1)
    arr = x1.data
    # Insert the length-1 axis directly at @axis and repeat along it.
    # Prepending it and then calling swapaxes(0, @axis) would also move input
    # axis @axis-1 to the front, scrambling the layout whenever @axis >= 2.
    expanded_shape = arr.shape.dup.insert(@axis, 1)
    repeated = arr.reshape(*expanded_shape).repeat(@repeats, axis: @axis)
    RubyZero::Core::Tensor.new(repeated, device: x1.device)
end