Class: Rumale::Optimizer::SGD

Inherits:
Object
Includes:
Base::BaseEstimator
Defined in:
lib/rumale/optimizer/sgd.rb

Overview

SGD is a class that implements the stochastic gradient descent (SGD) optimizer, with optional momentum and learning rate decay.

Examples:

optimizer = Rumale::Optimizer::SGD.new(learning_rate: 0.01, momentum: 0.9, decay: 0.9)
estimator = Rumale::LinearModel::LinearRegression.new(optimizer: optimizer, random_seed: 1)
estimator.fit(samples, values)

Instance Attribute Summary

Attributes included from Base::BaseEstimator

#params

Instance Method Summary

Constructor Details

#initialize(learning_rate: 0.01, momentum: 0.0, decay: 0.0) ⇒ SGD

Create a new optimizer with SGD.

Parameters:

  • learning_rate (Float) (defaults to: 0.01)

    The initial value of the learning rate.

  • momentum (Float) (defaults to: 0.0)

    The momentum coefficient applied to the previous update.

  • decay (Float) (defaults to: 0.0)

    The smoothing parameter, which controls how quickly the learning rate decays with the iteration count.



# File 'lib/rumale/optimizer/sgd.rb', line 23

def initialize(learning_rate: 0.01, momentum: 0.0, decay: 0.0)
  check_params_float(learning_rate: learning_rate, momentum: momentum, decay: decay)
  check_params_positive(learning_rate: learning_rate, momentum: momentum, decay: decay)
  @params = {}
  @params[:learning_rate] = learning_rate
  @params[:momentum] = momentum
  @params[:decay] = decay
  @iter = 0
  @update = nil
end
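
The effect of the decay parameter can be sketched in plain Ruby. The helper below, effective_learning_rate, is a hypothetical name and not part of Rumale; it only mirrors the learning rate expression that appears in #call, so treat it as an illustration under that assumption.

```ruby
# Hypothetical helper (not part of Rumale): mirrors the expression
# learning_rate / (1.0 + decay * iter) used inside SGD#call.
def effective_learning_rate(learning_rate, decay, iter)
  learning_rate / (1.0 + decay * iter)
end

# Learning rates for the first four calls with learning_rate: 0.01, decay: 0.9.
rates = (0...4).map { |iter| effective_learning_rate(0.01, 0.9, iter) }

# With the default decay of 0.0 the learning rate stays constant.
constant_rates = (0...4).map { |iter| effective_learning_rate(0.01, 0.0, iter) }

puts rates.inspect
puts constant_rates.inspect
```

The first call uses the initial learning rate unchanged because the iteration counter starts at 0; every later call uses a smaller value.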

Instance Method Details

#call(weight, gradient) ⇒ Numo::DFloat

Calculate the updated weight with SGD.

Parameters:

  • weight (Numo::DFloat)

    (shape: [n_features]) The weight to be updated.

  • gradient (Numo::DFloat)

    (shape: [n_features]) The gradient for updating the weight.

Returns:

  • (Numo::DFloat)

    (shape: [n_features]) The updated weight.



# File 'lib/rumale/optimizer/sgd.rb', line 39

def call(weight, gradient)
  @update ||= Numo::DFloat.zeros(weight.shape[0])
  current_learning_rate = @params[:learning_rate] / (1.0 + @params[:decay] * @iter)
  @iter += 1
  @update = @params[:momentum] * @update - current_learning_rate * gradient
  weight + @update
end
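
To make the update rule concrete, here is a minimal sketch that follows the same steps with plain Ruby Arrays in place of Numo::DFloat, so it runs without any gems. PlainSGD is a hypothetical stand-in written for this illustration, not the library class, and the quadratic objective is an assumption chosen only because its gradient is trivial.

```ruby
# Hypothetical stand-in for Rumale::Optimizer::SGD that uses plain Arrays
# instead of Numo::DFloat. It is an illustration, not the library class.
class PlainSGD
  def initialize(learning_rate: 0.01, momentum: 0.0, decay: 0.0)
    @learning_rate = learning_rate
    @momentum = momentum
    @decay = decay
    @iter = 0
    @update = nil
  end

  def call(weight, gradient)
    # The previous update starts as a zero vector on the first call.
    @update ||= Array.new(weight.size, 0.0)
    current_learning_rate = @learning_rate / (1.0 + @decay * @iter)
    @iter += 1
    # Momentum term minus the scaled gradient, element by element.
    @update = @update.zip(gradient).map { |u, g| @momentum * u - current_learning_rate * g }
    weight.zip(@update).map { |w, u| w + u }
  end
end

# Minimize f(w) = 0.5 * ||w||^2, whose gradient is w itself.
optimizer = PlainSGD.new(learning_rate: 0.1, momentum: 0.9)
weight = [1.0, -2.0]
first_step = optimizer.call(weight, weight)
weight = first_step
200.times { weight = optimizer.call(weight, weight) }

puts first_step.inspect
puts weight.inspect
```

On the first call the momentum term contributes nothing, so the step is a plain gradient step. Later calls reuse the stored update, which is why the optimizer object carries state between calls.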

#marshal_dump ⇒ Hash

Dump marshal data.

Returns:

  • (Hash)

    The marshal data.



# File 'lib/rumale/optimizer/sgd.rb', line 49

def marshal_dump
  { params: @params,
    iter: @iter,
    update: @update }
end

#marshal_load(obj) ⇒ nil

Load marshal data.

Returns:

  • (nil)


# File 'lib/rumale/optimizer/sgd.rb', line 57

def marshal_load(obj)
  @params = obj[:params]
  @iter = obj[:iter]
  @update = obj[:update]
  nil
end
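
The two marshal methods let Ruby's Marshal module save and restore the optimizer state, including the iteration counter and the previous update. The sketch below shows the same pattern on a hypothetical class, TinyState, which is not part of Rumale and stands in for the optimizer so that the example runs without any gems.

```ruby
# Hypothetical class (not part of Rumale) that follows the same
# marshal_dump / marshal_load pattern as the optimizer.
class TinyState
  attr_reader :params, :iter, :update

  def initialize(learning_rate: 0.01)
    @params = { learning_rate: learning_rate }
    @iter = 0
    @update = nil
  end

  # Advance the internal state so there is something worth saving.
  def step
    @iter += 1
    @update = [0.5 * @iter]
    self
  end

  # Called by Marshal.dump: return the state as a Hash.
  def marshal_dump
    { params: @params, iter: @iter, update: @update }
  end

  # Called by Marshal.load on a freshly allocated object.
  def marshal_load(obj)
    @params = obj[:params]
    @iter = obj[:iter]
    @update = obj[:update]
    nil
  end
end

original = TinyState.new(learning_rate: 0.1).step.step
restored = Marshal.load(Marshal.dump(original))

puts restored.iter
puts restored.update.inspect
```

Because the counter and the previous update are both part of the dumped Hash, a restored object continues from the point where the original stopped rather than starting over.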