Class: RubyZero::Core::Functions::DotProduct

Inherits:
Function < Object
Defined in:
lib/rubyzero/core/functions/tensor_functions.rb

Instance Attribute Summary

Attributes inherited from Function

#inputs, #output

Instance Method Summary

#backward(dy) ⇒ Object
#forward(x1, x2) ⇒ Object
#initialize ⇒ DotProduct

Methods inherited from Function

#call, #inspect, plot

Constructor Details

#initialize ⇒ DotProduct

Returns a new instance of DotProduct.



(source lines 82–83)
# File 'lib/rubyzero/core/functions/tensor_functions.rb', line 82

# Builds a DotProduct node. There is no per-instance state to set up, so the
# body is intentionally empty (note: the superclass initializer is not invoked).
def initialize; end

Instance Method Details

#backward(dy) ⇒ Object



(source lines 89–93)
# File 'lib/rubyzero/core/functions/tensor_functions.rb', line 89

# Backpropagates the upstream gradient through the product y = x1.dot(x2).
#
# Applies the matrix-product rule
#   dL/dx1 = dy . x2^T    and    dL/dx2 = x1^T . dy
# where the transpose is taken with swapaxes(0, 1).
# NOTE(review): swapaxes(0, 1) only exchanges the first two axes, so this
# presumably assumes 2-D operands -- confirm behavior for 1-D or batched inputs.
#
# @param dy [Object] upstream gradient with respect to this function's output
# @return [Array(Object, Object)] gradients with respect to x1 and x2, in input order
def backward(dy)
  x1, x2 = @inputs
  # Keep the results under their own names so the incoming gradient `dy`
  # is never rebound to mean something else.
  dx1 = dy.dot(x2.swapaxes(0, 1))
  dx2 = x1.swapaxes(0, 1).dot(dy)
  [dx1, dx2]
end

#forward(x1, x2) ⇒ Object



(source lines 84–88)
# File 'lib/rubyzero/core/functions/tensor_functions.rb', line 84

# Computes the product of the two operands' underlying data and wraps it in
# a new tensor.
#
# The product is x1.data.dot(x2.data); the resulting tensor is placed on the
# first operand's device.
#
# @param x1 [Object] left operand (read via #data and #device)
# @param x2 [Object] right operand (read via #data)
# @return [RubyZero::Core::Tensor] tensor holding the product
def forward(x1, x2)
  product = x1.data.dot(x2.data)
  RubyZero::Core::Tensor.new(product, device: x1.device)
end