Class: RubyZero::Core::Functions::SwapAxes

Inherits: Function < Object
Defined in: lib/rubyzero/core/functions/tensor_functions.rb

Instance Attribute Summary

Attributes inherited from Function

#inputs, #output

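Instance Method Summary

#backward(dy) ⇒ Object: returns the input gradient, a one-element Array holding dy with axes axis1 and axis2 swapped.
#forward(x1) ⇒ Object: returns a new Tensor whose axes axis1 and axis2 are exchanged relative to x1.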

Methods inherited from Function

#call, #inspect, #plot

Constructor Details

#initialize(axis1, axis2) ⇒ SwapAxes

Returns a new instance of SwapAxes that exchanges axes axis1 and axis2 of its input.



# File 'lib/rubyzero/core/functions/tensor_functions.rb', line 17

def initialize(axis1, axis2)
    # Indices of the two axes that forward exchanges.
    @axis1 = axis1
    @axis2 = axis2
end
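The constructor only stores the two axis indices; they take effect on the input's shape. As a minimal sketch (the helper `swapped_shape` is hypothetical and not part of RubyZero), the output shape is the input shape with entries `axis1` and `axis2` exchanged, and applying the same swap twice restores the original shape:

```ruby
# Hypothetical helper (not part of RubyZero): the shape that results from
# exchanging axes axis1 and axis2 of a tensor with the given shape.
def swapped_shape(shape, axis1, axis2)
  out = shape.dup
  out[axis1], out[axis2] = out[axis2], out[axis1]
  out
end

p swapped_shape([2, 3, 4], 0, 2)                      # prints [4, 3, 2]
p swapped_shape(swapped_shape([2, 3, 4], 0, 2), 0, 2) # prints [2, 3, 4]
```

Because the swap is its own inverse, no extra bookkeeping is needed to undo it in the backward pass.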

Instance Method Details

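#backward(dy) ⇒ Object

Swapping two axes only moves elements around, and swapping the same pair again moves them back. The gradient with respect to x1 is therefore dy with axes axis1 and axis2 swapped, returned as a one-element Array (one gradient per input).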



# File 'lib/rubyzero/core/functions/tensor_functions.rb', line 27

def backward(dy)
    # Undo the forward swap on the incoming gradient; wrap it for the single input.
    return [dy.swapaxes(@axis1, @axis2)]
end
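The claim that backward only needs to repeat the swap can be checked numerically. This sketch is restricted to 2-D nested Arrays, where swapping axes 0 and 1 is exactly `Array#transpose`, and uses a toy loss L = Σ w ⊙ swap(x), for which dL/dy = w. The loss, the weights `w`, and the helper names `loss` and `numerical_grad` are illustrative assumptions, not RubyZero code:

```ruby
# Toy scalar loss L(x) = sum_{i,j} w[i][j] * y[i][j], where y swaps axes 0 and 1
# of x (for 2-D nested Arrays this is exactly Array#transpose).
# Because L is linear in y, the upstream gradient dL/dy is simply w.
def loss(x, w)
  y = x.transpose
  total = 0.0
  y.each_index do |i|
    y[i].each_index { |j| total += w[i][j] * y[i][j] }
  end
  total
end

# Central finite-difference estimate of dL/dx, one element at a time.
def numerical_grad(x, w, eps = 1e-6)
  x.each_index.map do |i|
    x[i].each_index.map do |j|
      plus = x.map(&:dup)
      minus = x.map(&:dup)
      plus[i][j] += eps
      minus[i][j] -= eps
      (loss(plus, w) - loss(minus, w)) / (2 * eps)
    end
  end
end

x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]      # input, shape [2, 3]
w = [[0.5, -1.0], [2.0, 0.25], [-3.0, 1.5]] # dL/dy, shape [3, 2]

# The backward rule: apply the same swap to dy (here, transpose w).
analytic = w.transpose
numeric = numerical_grad(x, w)

max_err = analytic.flatten.zip(numeric.flatten).map { |a, n| (a - n).abs }.max
puts "max |analytic - numeric| = #{max_err}"
```

The two gradients agree up to floating-point rounding, and `analytic` has the input's shape [2, 3] rather than the output's [3, 2].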

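#forward(x1) ⇒ Object

Swaps axes axis1 and axis2 of x1's underlying array, records x1's shape in @prev_shape, and returns the result wrapped in a new RubyZero::Core::Tensor on the same device as x1.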



# File 'lib/rubyzero/core/functions/tensor_functions.rb', line 21

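def forward(x1)
    # Exchange the two axes of the underlying array.
    new_arr = x1.data.swapaxes(@axis1, @axis2)
    # Remember the input shape (backward above does not read it).
    @prev_shape = x1.shape
    # Wrap the result in a new Tensor on the same device as the input.
    new_t = RubyZero::Core::Tensor.new(new_arr, device: x1.device)
    return new_t
end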
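The forward/backward pair can be mimicked end to end in plain Ruby. In this sketch `ToyTensor` and `ToySwapAxes` are hypothetical stand-ins: nested Arrays replace the array behind Tensor#data, there is no device handling or autograd graph, and the `prev_shape` reader exists only so the example can inspect `@prev_shape`. It shows that forward exchanges both the shape entries and the element positions, and that backward maps a gradient of the output's shape back to the recorded input shape:

```ruby
# Hypothetical stand-in for a tensor: wraps a rectangular nested Array.
ToyTensor = Struct.new(:data) do
  # Shape of the nested Array, e.g. [[1, 2, 3], [4, 5, 6]] -> [2, 3].
  def shape
    dims = []
    arr = data
    while arr.is_a?(Array)
      dims << arr.size
      arr = arr.first
    end
    dims
  end

  # New ToyTensor with axes a1 and a2 exchanged:
  # out[..., j, ..., i, ...] = data[..., i, ..., j, ...].
  def swapaxes(a1, a2)
    out_shape = shape
    out_shape[a1], out_shape[a2] = out_shape[a2], out_shape[a1]
    out = build(out_shape, []) do |idx|
      src = idx.dup
      src[a1], src[a2] = src[a2], src[a1]
      src.reduce(data) { |sub, i| sub[i] }
    end
    ToyTensor.new(out)
  end

  private

  # Build a nested Array with the given dims by calling the block per multi-index.
  def build(dims, prefix, &blk)
    return blk.call(prefix) if prefix.size == dims.size
    Array.new(dims[prefix.size]) { |i| build(dims, prefix + [i], &blk) }
  end
end

# Hypothetical mirror of the SwapAxes forward/backward contract.
class ToySwapAxes
  attr_reader :prev_shape # added only for inspection in this sketch

  def initialize(axis1, axis2)
    @axis1 = axis1
    @axis2 = axis2
  end

  def forward(x1)
    @prev_shape = x1.shape
    x1.swapaxes(@axis1, @axis2)
  end

  # One gradient per input, hence the one-element Array.
  def backward(dy)
    [dy.swapaxes(@axis1, @axis2)]
  end
end

x = ToyTensor.new(
  Array.new(2) { |i| Array.new(3) { |j| Array.new(4) { |k| 100 * i + 10 * j + k } } }
)
f = ToySwapAxes.new(0, 2)
y = f.forward(x)
p y.shape                  # prints [4, 3, 2]
p y.data[3][1][0]          # prints 13, taken from x.data[0][1][3]
dx, = f.backward(y)
p dx.shape == f.prev_shape # prints true
```

Passing the forward output itself as the "gradient" is just a convenient way to show that the backward swap restores the input layout exactly.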