Class: Rumale::Optimizer::Nadam

Inherits:
Object
Includes:
Base::BaseEstimator
Defined in:
lib/rumale/optimizer/nadam.rb

Overview

Nadam is a class that implements the Nadam optimizer.

Reference

    1. Dozat, T., “Incorporating Nesterov Momentum into Adam,” Tech. Rep., Stanford University, 2015.

Examples:

optimizer = Rumale::Optimizer::Nadam.new(learning_rate: 0.01, decay1: 0.9, decay2: 0.999)
estimator = Rumale::LinearModel::LinearRegression.new(optimizer: optimizer, random_seed: 1)
estimator.fit(samples, values)

Instance Attribute Summary

Attributes included from Base::BaseEstimator

#params

Instance Method Summary

Constructor Details

#initialize(learning_rate: 0.01, decay1: 0.9, decay2: 0.999) ⇒ Nadam

Create a new optimizer with Nadam.

Parameters:

  • learning_rate (Float) (defaults to: 0.01)

    The initial value of learning rate.

  • decay1 (Float) (defaults to: 0.9)

    The smoothing parameter for the first moment.

  • decay2 (Float) (defaults to: 0.999)

    The smoothing parameter for the second moment.



# File 'lib/rumale/optimizer/nadam.rb', line 27

def initialize(learning_rate: 0.01, decay1: 0.9, decay2: 0.999)
  # Validate that all hyperparameters are positive floats.
  check_params_float(learning_rate: learning_rate, decay1: decay1, decay2: decay2)
  check_params_positive(learning_rate: learning_rate, decay1: decay1, decay2: decay2)
  @params = {}
  @params[:learning_rate] = learning_rate
  @params[:decay1] = decay1
  @params[:decay2] = decay2
  # Moment estimates are allocated lazily on the first call to #call.
  @fst_moment = nil
  @sec_moment = nil
  # Running product of the scheduled first-moment decay rates.
  @decay1_prod = 1.0
  @iter = 0
end
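
As a usage sketch, the validation calls reject malformed hyperparameters at construction time; the exact exception class raised by Rumale::Validation is an assumption here, not part of this page:

# Valid: all hyperparameters are positive floats.
Rumale::Optimizer::Nadam.new(learning_rate: 0.005, decay1: 0.95)

# Invalid: a non-positive learning rate fails check_params_positive.
begin
  Rumale::Optimizer::Nadam.new(learning_rate: -0.01)
rescue ArgumentError => e # assumed exception class
  puts e.message
end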

Instance Method Details

#call(weight, gradient) ⇒ Numo::DFloat

Calculate the updated weight with the Nadam adaptive learning rate.

Parameters:

  • weight (Numo::DFloat)

    (shape: [n_features]) The weight to be updated.

  • gradient (Numo::DFloat)

    (shape: [n_features]) The gradient for updating the weight.

Returns:

  • (Numo::DFloat)

    (shape: [n_features]) The updated weight.



# File 'lib/rumale/optimizer/nadam.rb', line 45

def call(weight, gradient)
  # Lazily allocate the first and second moment estimates.
  @fst_moment ||= Numo::DFloat.zeros(weight.shape[0])
  @sec_moment ||= Numo::DFloat.zeros(weight.shape[0])

  @iter += 1

  # Warm-up schedule for the momentum coefficient at the current and next steps.
  decay1_curr = @params[:decay1] * (1.0 - 0.5 * 0.96**(@iter * 0.004))
  decay1_next = @params[:decay1] * (1.0 - 0.5 * 0.96**((@iter + 1) * 0.004))
  decay1_prod_curr = @decay1_prod * decay1_curr
  decay1_prod_next = @decay1_prod * decay1_curr * decay1_next
  @decay1_prod = decay1_prod_curr

  # Update the exponential moving averages of the gradient and its square.
  @fst_moment = @params[:decay1] * @fst_moment + (1.0 - @params[:decay1]) * gradient
  @sec_moment = @params[:decay2] * @sec_moment + (1.0 - @params[:decay2]) * gradient**2
  # Bias-corrected gradient and moment estimates.
  nm_gradient = gradient / (1.0 - decay1_prod_curr)
  nm_fst_moment = @fst_moment / (1.0 - decay1_prod_next)
  nm_sec_moment = @sec_moment / (1.0 - @params[:decay2]**@iter)

  # Nesterov-style update: mix the bias-corrected gradient and first moment.
  weight - (@params[:learning_rate] / (nm_sec_moment**0.5 + 1e-8)) * ((1 - decay1_curr) * nm_gradient + decay1_next * nm_fst_moment)
end
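
For illustration, a minimal single-step sketch, assuming the numo-narray and rumale gems are installed; the array values are arbitrary toy data:

require 'numo/narray'
require 'rumale'

optimizer = Rumale::Optimizer::Nadam.new(learning_rate: 0.01)
weight   = Numo::DFloat[0.5, -0.3, 1.2]  # shape: [n_features]
gradient = Numo::DFloat[0.1, -0.2, 0.05] # loss gradient w.r.t. weight

updated = optimizer.call(weight, gradient)
# updated has the same shape as weight; each component moves against its
# gradient, scaled by the bias-corrected moment estimates.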

#marshal_dump ⇒ Hash

Dump marshal data.

Returns:

  • (Hash)

    The marshal data.



# File 'lib/rumale/optimizer/nadam.rb', line 68

def marshal_dump
  { params: @params,
    fst_moment: @fst_moment,
    sec_moment: @sec_moment,
    decay1_prod: @decay1_prod,
    iter: @iter }
end
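
For a freshly constructed optimizer, the dumped hash mirrors the defaults set in #initialize; a sketch:

optimizer = Rumale::Optimizer::Nadam.new
optimizer.marshal_dump
# => { params: { learning_rate: 0.01, decay1: 0.9, decay2: 0.999 },
#      fst_moment: nil, sec_moment: nil, decay1_prod: 1.0, iter: 0 }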

#marshal_load(obj) ⇒ nil

Load marshal data.

Returns:

  • (nil)


# File 'lib/rumale/optimizer/nadam.rb', line 78

def marshal_load(obj)
  @params = obj[:params]
  @fst_moment = obj[:fst_moment]
  @sec_moment = obj[:sec_moment]
  @decay1_prod = obj[:decay1_prod]
  @iter = obj[:iter]
  nil
end
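
Since #marshal_dump and #marshal_load implement Ruby's Marshal hooks, the optimizer's full state (hyperparameters, moment estimates, decay product, and iteration count) survives serialization; a minimal sketch:

optimizer = Rumale::Optimizer::Nadam.new(learning_rate: 0.01)
# ... run some updates so the internal state is non-trivial ...
copy = Marshal.load(Marshal.dump(optimizer))
# copy resumes from the same iteration count and moment estimates.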