Class: SVMKit::Optimizer::Nadam

Inherits:
Object

Defined in:
lib/svmkit/optimizer/nadam.rb

Overview

Nadam implements the Nadam optimizer (Adam with Nesterov momentum). This class is intended for internal use.

Reference

    1. Dozat, “Incorporating Nesterov Momentum into Adam,” Tech. Rep., Stanford University, 2015.

Instance Method Summary

Constructor Details

#initialize(learning_rate: 0.01, momentum: 0.9, decay1: 0.9, decay2: 0.999) ⇒ Nadam

Creates a new Nadam optimizer.

Parameters:

  • learning_rate (Float) (defaults to: 0.01)

    The initial value of learning rate.

  • momentum (Float) (defaults to: 0.9)

    The initial value of momentum.

  • decay1 (Float) (defaults to: 0.9)

    The smoothing parameter for the first moment.

  • decay2 (Float) (defaults to: 0.999)

    The smoothing parameter for the second moment.

Note: the schedule decay used in the momentum schedule is not a constructor parameter; it is fixed internally at 0.004 (the 0.96**(t * 0.004) factor in the implementation below).



# File 'lib/svmkit/optimizer/nadam.rb', line 23

def initialize(learning_rate: 0.01, momentum: 0.9, decay1: 0.9, decay2: 0.999)
  check_params_float(learning_rate: learning_rate, momentum: momentum, decay1: decay1, decay2: decay2)
  check_params_positive(learning_rate: learning_rate, momentum: momentum, decay1: decay1, decay2: decay2)
  @params = {}
  @params[:learning_rate] = learning_rate
  @params[:momentum] = momentum
  @params[:decay1] = decay1
  @params[:decay2] = decay2
  @fst_moment = nil
  @sec_moment = nil
  @decay1_prod = 1.0
  @iter = 0
end
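The check_params_float and check_params_positive helpers come from SVMKit's validation mixin. A minimal, self-contained sketch of the checks they perform (these standalone methods are illustrative, not SVMKit's actual implementation):

```ruby
# Illustrative stand-ins for SVMKit's validation helpers.
# Each raises if a keyword argument fails the check.
def check_params_float(**params)
  params.each do |name, val|
    raise TypeError, "expect #{name} to be Float" unless val.is_a?(Float)
  end
end

def check_params_positive(**params)
  params.each do |name, val|
    raise ArgumentError, "expect #{name} to be positive" unless val > 0
  end
end

# Valid hyperparameters pass silently.
check_params_float(learning_rate: 0.01, momentum: 0.9)
check_params_positive(learning_rate: 0.01, momentum: 0.9)
```

Passing a non-Float (e.g. learning_rate: 1) raises TypeError, and a non-positive value raises ArgumentError, mirroring how the constructor rejects bad hyperparameters before storing them in @params.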

Instance Method Details

#call(weight, gradient) ⇒ Numo::DFloat

Calculates the updated weight using the Nadam adaptive learning rate.

Parameters:

  • weight (Numo::DFloat)

    (shape: [n_features]) The weight to be updated.

  • gradient (Numo::DFloat)

    (shape: [n_features]) The gradient for updating the weight.

Returns:

  • (Numo::DFloat)

    (shape: [n_features]) The updated weight.



# File 'lib/svmkit/optimizer/nadam.rb', line 42

def call(weight, gradient)
  @fst_moment ||= Numo::DFloat.zeros(weight.shape[0])
  @sec_moment ||= Numo::DFloat.zeros(weight.shape[0])

  @iter += 1

  decay1_curr = @params[:decay1] * (1.0 - 0.5 * 0.96**(@iter * 0.004))
  decay1_next = @params[:decay1] * (1.0 - 0.5 * 0.96**((@iter + 1) * 0.004))
  decay1_prod_curr = @decay1_prod * decay1_curr
  decay1_prod_next = @decay1_prod * decay1_curr * decay1_next
  @decay1_prod = decay1_prod_curr

  @fst_moment = @params[:decay1] * @fst_moment + (1.0 - @params[:decay1]) * gradient
  @sec_moment = @params[:decay2] * @sec_moment + (1.0 - @params[:decay2]) * gradient**2
  nm_gradient = gradient / (1.0 - decay1_prod_curr)
  nm_fst_moment = @fst_moment / (1.0 - decay1_prod_next)
  nm_sec_moment = @sec_moment / (1.0 - @params[:decay2]**@iter)

  weight - (@params[:learning_rate] / (nm_sec_moment**0.5 + 1e-8)) * ((1 - decay1_curr) * nm_gradient + decay1_next * nm_fst_moment)
end
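To see the update rule in action outside the library, here is a sketch that mirrors #call on a single scalar weight, minimizing (w - 3)**2. Plain Floats stand in for Numo::DFloat, and the method and state-hash names (nadam_step, :decay1_prod, etc.) are invented for this example:

```ruby
# Scalar sketch of the Nadam update performed by #call above.
# `state` plays the role of the instance variables (@fst_moment, @sec_moment, ...).
def nadam_step(w, grad, state, lr: 0.01, decay1: 0.9, decay2: 0.999)
  state[:iter] += 1
  t = state[:iter]

  # Momentum schedule from Dozat's report (schedule decay fixed at 0.004).
  d1_curr = decay1 * (1.0 - 0.5 * 0.96**(t * 0.004))
  d1_next = decay1 * (1.0 - 0.5 * 0.96**((t + 1) * 0.004))
  prod_curr = state[:decay1_prod] * d1_curr
  prod_next = prod_curr * d1_next
  state[:decay1_prod] = prod_curr

  # Exponential moving averages of the gradient and its square.
  state[:m] = decay1 * state[:m] + (1.0 - decay1) * grad
  state[:v] = decay2 * state[:v] + (1.0 - decay2) * grad**2

  # Bias-corrected estimates.
  nm_g = grad / (1.0 - prod_curr)
  nm_m = state[:m] / (1.0 - prod_next)
  nm_v = state[:v] / (1.0 - decay2**t)

  w - (lr / (Math.sqrt(nm_v) + 1e-8)) * ((1.0 - d1_curr) * nm_g + d1_next * nm_m)
end

state = { iter: 0, decay1_prod: 1.0, m: 0.0, v: 0.0 }
w = 0.0
3000.times do
  g = 2.0 * (w - 3.0) # gradient of (w - 3)**2
  w = nadam_step(w, g, state)
end
# w approaches the minimizer at 3.0
```

The mixture of the bias-corrected gradient (weighted by 1 - d1_curr) and the bias-corrected first moment (weighted by d1_next) is what distinguishes Nadam from plain Adam: it applies the next step's momentum coefficient to the moment estimate, giving the Nesterov-style lookahead.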