Class: Rumale::Optimizer::YellowFin

Inherits:
Object
Includes:
Base::BaseEstimator
Defined in:
lib/rumale/optimizer/yellow_fin.rb

Overview

YellowFin is a class that implements the YellowFin optimizer.

Reference

    1. J. Zhang and I. Mitliagkas, “YellowFin and the Art of Momentum Tuning,” CoRR abs/1706.03471, 2017.

Examples:

optimizer = Rumale::Optimizer::YellowFin.new(learning_rate: 0.01, momentum: 0.9, decay: 0.999, window_width: 20)
estimator = Rumale::LinearModel::LinearRegression.new(optimizer: optimizer, random_seed: 1)
estimator.fit(samples, values)

Instance Attribute Summary

Attributes included from Base::BaseEstimator

#params

Instance Method Summary

Constructor Details

#initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.999, window_width: 20) ⇒ YellowFin

Create a new optimizer with YellowFin.

Parameters:

  • learning_rate (Float) (defaults to: 0.01)

    The initial value of the learning rate.

  • momentum (Float) (defaults to: 0.9)

    The initial value of momentum.

  • decay (Float) (defaults to: 0.999)

    The smoothing parameter.

  • window_width (Integer) (defaults to: 20)

    The sliding window width for searching curvature range.



# File 'lib/rumale/optimizer/yellow_fin.rb', line 27

def initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.999, window_width: 20)
  # Validate the types and positivity of the hyperparameters.
  check_params_float(learning_rate: learning_rate, momentum: momentum, decay: decay)
  check_params_integer(window_width: window_width)
  check_params_positive(learning_rate: learning_rate, momentum: momentum, decay: decay, window_width: window_width)
  # Store the hyperparameters.
  @params = {}
  @params[:learning_rate] = learning_rate
  @params[:momentum] = momentum
  @params[:decay] = decay
  @params[:window_width] = window_width
  # Initialize the internal state used for adaptive momentum and learning rate tuning.
  @smth_learning_rate = learning_rate
  @smth_momentum = momentum
  @grad_norms = nil
  @grad_norm_min = 0.0
  @grad_norm_max = 0.0
  @grad_mean_sqr = 0.0
  @grad_mean = 0.0
  @grad_var = 0.0
  @grad_norm_mean = 0.0
  @curve_mean = 0.0
  @distance_mean = 0.0
  @update = nil
end
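
The stored hyperparameters can be read back through the #params reader included from Base::BaseEstimator. A minimal sketch, assuming the rumale gem is installed:

require 'rumale'

# Build an optimizer with a custom initial learning rate;
# the remaining hyperparameters keep their defaults.
optimizer = Rumale::Optimizer::YellowFin.new(learning_rate: 0.005)

optimizer.params[:learning_rate] # => 0.005
optimizer.params[:momentum]      # => 0.9
optimizer.params[:window_width]  # => 20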

Instance Method Details

#call(weight, gradient) ⇒ Numo::DFloat

Calculate the updated weight with adaptive momentum coefficient and learning rate.

Parameters:

  • weight (Numo::DFloat)

    (shape: [n_features]) The weight to be updated.

  • gradient (Numo::DFloat)

    (shape: [n_features]) The gradient for updating the weight.

Returns:

  • (Numo::DFloat)

    (shape: [n_features]) The updated weight.



# File 'lib/rumale/optimizer/yellow_fin.rb', line 55

def call(weight, gradient)
  @update ||= Numo::DFloat.zeros(weight.shape[0])
  # Update the running estimates of the curvature range, gradient variance,
  # and distance to the optimum from the current gradient.
  curvature_range(gradient)
  gradient_variance(gradient)
  distance_to_optimum(gradient)
  # Smooth the tuned momentum and learning rate with an exponential moving average.
  @smth_momentum = @params[:decay] * @smth_momentum + (1 - @params[:decay]) * current_momentum
  @smth_learning_rate = @params[:decay] * @smth_learning_rate + (1 - @params[:decay]) * current_learning_rate
  # Apply the momentum update and return the new weight.
  @update = @smth_momentum * @update - @smth_learning_rate * gradient
  weight + @update
end
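
Although #call is normally invoked internally by the linear model estimators, it can also be driven directly. A minimal sketch of a hand-rolled gradient descent loop on a least-squares objective, assuming the rumale and numo-narray gems are available; the data and iteration count here are made up for illustration:

require 'rumale'
require 'numo/narray'

# Toy least-squares problem: find w minimizing ||x.dot(w) - y||^2.
x = Numo::DFloat[[1.0, 2.0], [2.0, 1.0], [3.0, 4.0]]
y = Numo::DFloat[5.0, 4.0, 11.0]

optimizer = Rumale::Optimizer::YellowFin.new(learning_rate: 0.01)
weight = Numo::DFloat.zeros(2)

100.times do
  # Gradient of the mean squared error with respect to the weight vector.
  gradient = 2.0 * x.transpose.dot(x.dot(weight) - y) / x.shape[0]
  # YellowFin returns the updated weight using its tuned momentum and learning rate.
  weight = optimizer.call(weight, gradient)
end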