Class: DNN::Losses::MeanSquaredError

Inherits:
Loss
  • Object
Defined in:
lib/dnn/core/losses.rb

Instance Method Summary

#backward(y) ⇒ Object
#forward(out, y) ⇒ Object

Methods inherited from Loss

#d_regularize, #regularize, #to_hash

Instance Method Details

#backward(y) ⇒ Object

# File 'lib/dnn/core/losses.rb', line 37

def backward(y)
  @out - y
end
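`backward` returns the element-wise difference between the output cached in `@out` by `#forward` and the targets `y`. Because it reads `@out`, `backward` is only meaningful after `forward` has been called. The sketch below mirrors that computation using plain Ruby nested arrays instead of the Numo::NArray values the library operates on. The class name `MSEBackwardSketch` and the row-per-sample layout are assumptions made for illustration. As in the source, the result keeps the shape of `y` and is not divided by the batch size.

```ruby
# Hypothetical plain-Ruby mirror of MeanSquaredError#backward, using nested
# arrays (rows = samples, columns = output units) instead of Numo::NArray.
class MSEBackwardSketch
  # Stands in for the @out that #forward caches in the library.
  def initialize(out)
    @out = out
  end

  # Element-wise @out - y; same shape as y, not divided by the batch size.
  def backward(y)
    @out.zip(y).map do |out_row, y_row|
      out_row.zip(y_row).map { |o, t| o - t }
    end
  end
end

sketch = MSEBackwardSketch.new([[1.0, 2.0], [3.0, 5.0]])
p sketch.backward([[0.5, 2.0], [3.0, 4.0]]) # prints [[0.5, 0.0], [0.0, 1.0]]
```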

#forward(out, y) ⇒ Object

# File 'lib/dnn/core/losses.rb', line 31

def forward(out, y)
  @out = out
  batch_size = y.shape[0]
  0.5 * ((out - y)**2).sum / batch_size
end
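`forward` caches `out` and returns half the sum of squared errors divided by the batch size `N = y.shape[0]`, that is `0.5 * Σ(out - y)² / N`. The sketch below reproduces that arithmetic with plain nested arrays rather than Numo::NArray. The helper `mse_forward` is a hypothetical name, not part of the library. It also checks the loss with a central finite difference. The derivative of the returned value with respect to one output element is `(out - y) / N`, so the unscaled `out - y` that `backward` returns is N times that derivative.

```ruby
# Hypothetical plain-Ruby mirror of MeanSquaredError#forward, using nested
# arrays (rows = samples) instead of Numo::NArray.
def mse_forward(out, y)
  batch_size = y.size # y.shape[0] in the library
  sq_sum = out.flatten.zip(y.flatten).sum { |o, t| (o - t)**2 }
  0.5 * sq_sum / batch_size
end

out = [[1.0, 2.0], [3.0, 5.0]]
y   = [[0.5, 2.0], [3.0, 4.0]]
loss = mse_forward(out, y)
p loss # prints 0.3125

# Central finite difference of the loss with respect to out[1][1].
eps = 1e-6
plus  = out.map(&:dup).tap { |a| a[1][1] += eps }
minus = out.map(&:dup).tap { |a| a[1][1] -= eps }
numeric = (mse_forward(plus, y) - mse_forward(minus, y)) / (2 * eps)

# Analytic derivative: (out - y) / batch_size = (5.0 - 4.0) / 2 = 0.5.
analytic = (out[1][1] - y[1][1]) / y.size
p((numeric - analytic).abs < 1e-6) # prints true
```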