Class: RubyZero::Core::Functions::Sum

Inherits:
Function
  • Object
Defined in:
lib/rubyzero/core/functions/tensor_functions.rb

Instance Attribute Summary

Attributes inherited from Function

#inputs, #output

Instance Method Summary

Methods inherited from Function

#call, #inspect, plot

Constructor Details

#initialize(axis) ⇒ Sum

Returns a new instance of Sum.



Source lines 50–52:
# File 'lib/rubyzero/core/functions/tensor_functions.rb', line 50

# Stores the axis along which #forward sums its input; #backward repeats
# the incoming gradient along this same axis.
#
# @param axis [Integer] axis index (presumably an Integer — it is used to
#   index x1.shape and passed as the `axis:` keyword to the array's #sum)
def initialize(axis)
    @axis = axis
end

Instance Method Details

#backward(dy) ⇒ Object



Source lines 60–62:
# File 'lib/rubyzero/core/functions/tensor_functions.rb', line 60

# Gradient of Sum w.r.t. its single input.
#
# Every input element enters the sum with coefficient 1, so dL/dx is just
# dy copied @repeats times along @axis, with no scaling. (Dividing by
# @repeats would be the gradient of a mean, not a sum.)
#
# Relies on @repeats having been recorded by #forward.
#
# @param dy upstream gradient w.r.t. the summed output (presumably a
#   RubyZero::Core::Tensor; it must respond to #repeat(n, axis:))
# @return [Array] one-element list holding the gradient w.r.t. x1
def backward(dy)
    [dy.repeat(@repeats, axis: @axis)]
end

#forward(x1) ⇒ Object



Source lines 53–59:
# File 'lib/rubyzero/core/functions/tensor_functions.rb', line 53

# Sums +x1+ along @axis and wraps the result in a new tensor.
#
# Also records the extent of the reduced axis in @repeats, which #backward
# uses to repeat the incoming gradient back out along @axis.
#
# @param x1 input tensor (must respond to #shape, #data and #device)
# @return [RubyZero::Core::Tensor] new tensor holding the summed array,
#   placed on the same device as x1
def forward(x1)
    @repeats = x1.shape[@axis]
    summed = x1.data.sum(axis: @axis)
    RubyZero::Core::Tensor.new(summed, device: x1.device)
end