Class: Rumale::Optimizer::RMSProp

Inherits:
Object
Includes:
Base::BaseEstimator
Defined in:
lib/rumale/optimizer/rmsprop.rb

Overview

RMSProp is a class that implements the RMSProp optimizer.

Reference

    1. I. Sutskever, J. Martens, G. Dahl, and G. Hinton, “On the importance of initialization and momentum in deep learning,” Proc. ICML’ 13, pp. 1139–1147, 2013.

    2. G. Hinton, N. Srivastava, and K. Swersky, “Lecture 6e rmsprop,” Neural Networks for Machine Learning, 2012.

Examples:

optimizer = Rumale::Optimizer::RMSProp.new(learning_rate: 0.01, momentum: 0.9, decay: 0.9)
estimator = Rumale::LinearModel::LinearRegression.new(optimizer: optimizer, random_seed: 1)
estimator.fit(samples, values)

Instance Attribute Summary

Attributes included from Base::BaseEstimator

#params

Instance Method Summary

Constructor Details

#initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.9) ⇒ RMSProp

Create a new optimizer with RMSProp.

Parameters:

  • learning_rate (Float) (defaults to: 0.01)

    The initial value of the learning rate.

  • momentum (Float) (defaults to: 0.9)

    The initial value of momentum.

  • decay (Float) (defaults to: 0.9)

    The smoothing parameter.



# File 'lib/rumale/optimizer/rmsprop.rb', line 27

def initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.9)
  check_params_float(learning_rate: learning_rate, momentum: momentum, decay: decay)
  check_params_positive(learning_rate: learning_rate, momentum: momentum, decay: decay)
  @params = {}
  @params[:learning_rate] = learning_rate
  @params[:momentum] = momentum
  @params[:decay] = decay
  @moment = nil
  @update = nil
end

Instance Method Details

#call(weight, gradient) ⇒ Numo::DFloat

Calculate the updated weight with the RMSProp adaptive learning rate.

Parameters:

  • weight (Numo::DFloat)

    (shape: [n_features]) The weight to be updated.

  • gradient (Numo::DFloat)

    (shape: [n_features]) The gradient for updating the weight.

Returns:

  • (Numo::DFloat)

    (shape: [n_features]) The updated weight.



# File 'lib/rumale/optimizer/rmsprop.rb', line 43

def call(weight, gradient)
  @moment ||= Numo::DFloat.zeros(weight.shape[0])
  @update ||= Numo::DFloat.zeros(weight.shape[0])
  @moment = @params[:decay] * @moment + (1.0 - @params[:decay]) * gradient**2
  @update = @params[:momentum] * @update - (@params[:learning_rate] / (@moment**0.5 + 1.0e-8)) * gradient
  weight + @update
end
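The update performed by #call can be traced by hand. The sketch below reproduces one RMSProp step in pure Ruby on plain Float arrays instead of Numo::DFloat vectors (variable names mirror the implementation above; the values are made up for illustration):

```ruby
# One RMSProp update step on plain Ruby arrays, mirroring #call above.
decay         = 0.9     # smoothing parameter for the squared-gradient average
momentum      = 0.9     # momentum coefficient for the update vector
learning_rate = 0.01
eps           = 1.0e-8  # matches the 1.0e-8 stabilizer in #call

weight   = [0.5, -0.3]  # example weight vector (illustrative values)
gradient = [0.2, -0.1]  # example gradient
moment   = [0.0, 0.0]   # running average of squared gradients (@moment)
update   = [0.0, 0.0]   # momentum-smoothed step (@update)

# @moment = decay * @moment + (1 - decay) * gradient**2
moment = moment.each_index.map { |i| decay * moment[i] + (1.0 - decay) * gradient[i]**2 }
# @update = momentum * @update - (lr / (sqrt(@moment) + eps)) * gradient
update = update.each_index.map { |i| momentum * update[i] - (learning_rate / (Math.sqrt(moment[i]) + eps)) * gradient[i] }
# weight + @update
weight = weight.each_index.map { |i| weight[i] + update[i] }
```

Note that the effective step size per feature is the learning rate divided by the root of the moving average of squared gradients, which is what makes the learning rate adaptive.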

#marshal_dump ⇒ Hash

Dump marshal data.

Returns:

  • (Hash)

    The marshal data.



# File 'lib/rumale/optimizer/rmsprop.rb', line 53

def marshal_dump
  { params: @params,
    moment: @moment,
    update: @update }
end

#marshal_load(obj) ⇒ nil

Load marshal data.

Parameters:

  • obj (Hash)

    The marshal data.

Returns:

  • (nil)


# File 'lib/rumale/optimizer/rmsprop.rb', line 61

def marshal_load(obj)
  @params = obj[:params]
  @moment = obj[:moment]
  @update = obj[:update]
  nil
end
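The marshal_dump / marshal_load pair above implements Ruby's Marshal protocol, so the optimizer's internal state survives serialization. The minimal stand-in class below (hypothetical, for illustration only; the real class stores Numo::DFloat vectors) shows how the two hooks cooperate in a round trip:

```ruby
# Hypothetical stand-in illustrating the marshal_dump/marshal_load protocol
# used by Rumale::Optimizer::RMSProp; not the actual Rumale class.
class TinyOptimizer
  attr_reader :params, :moment

  def initialize(learning_rate: 0.01)
    @params = { learning_rate: learning_rate }
    @moment = [0.004, 0.001]  # example accumulated state
  end

  # Called by Marshal.dump: return the state to serialize.
  def marshal_dump
    { params: @params, moment: @moment }
  end

  # Called by Marshal.load on a bare allocated object: restore the state.
  def marshal_load(obj)
    @params = obj[:params]
    @moment = obj[:moment]
    nil
  end
end

restored = Marshal.load(Marshal.dump(TinyOptimizer.new(learning_rate: 0.05)))
```

Because Marshal.load allocates the object without calling initialize, everything the optimizer needs must appear in the dumped Hash, which is why marshal_dump includes the accumulated @moment and @update state alongside @params.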