Class: DNN::Optimizers::Optimizer

Inherits:
Object
Defined in:
lib/dnn/core/optimizers.rb

Overview

Superclass of all optimizer classes.

Direct Known Subclasses

AdaDelta, AdaGrad, Adam, RMSProp, RMSPropGraves, SGD

Instance Attribute Summary

Class Method Summary

Instance Method Summary

Constructor Details

#initialize(clip_norm: nil) ⇒ Optimizer



# File 'lib/dnn/core/optimizers.rb', line 30

# Base constructor shared by all optimizers.
#
# @param clip_norm [Object, nil] gradient-clipping threshold (presumably
#   numeric — TODO confirm). When non-nil, #update calls clip_grads before
#   update_params; nil disables clipping.
# NOTE(review): @status is not assigned here, so #status / #dump / .load
# presumably rely on each subclass setting it — confirm.
def initialize(clip_norm: nil)
  @clip_norm = clip_norm
end

Instance Attribute Details

#clip_norm ⇒ Object

Returns the value of attribute clip_norm.



# File 'lib/dnn/core/optimizers.rb', line 7

# Reader for the gradient-clipping threshold given to #initialize.
# A nil value means #update skips gradient clipping.
#
# @return [Object, nil] the value of @clip_norm
def clip_norm
  @clip_norm
end

#status ⇒ Object (readonly)

Returns the value of attribute status.



# File 'lib/dnn/core/optimizers.rb', line 6

# Reader for the optimizer's internal state container.
#
# #dump returns this object (by reference) under :status, and .load writes
# restored entries into it with status[key] = value. The base #initialize
# never assigns @status, so it is presumably populated by subclasses
# (likely a Hash) — TODO confirm.
#
# @return [Object, nil] the value of @status
def status
  @status
end

Class Method Details

.from_hash(hash) ⇒ Object

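Raises:

  • DNN_Error — if the class named by hash[:class] is neither the receiver nor one of its subclasses.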



# File 'lib/dnn/core/optimizers.rb', line 9

# Rebuilds an optimizer from a hash produced by #to_hash.
#
# The target class is looked up by name under the DNN namespace and is
# validated *before* anything is allocated. A name that resolves to a
# non-class constant (or to an unrelated class) is therefore reported as
# DNN_Error, rather than surfacing as a NoMethodError/TypeError from
# +allocate+. The instance is created with +allocate+ (bypassing
# #initialize) and then populated via #load_hash.
#
# @param hash [Hash, nil] serialized optimizer; +hash[:class]+ holds the
#   class name. A nil hash yields nil.
# @return [Object, nil] the restored optimizer, or nil when +hash+ is nil.
# @raise [DNN_Error] if +hash[:class]+ does not name the receiver or one of
#   its subclasses.
def self.from_hash(hash)
  return nil unless hash
  # hash[:class] comes from serialized data, so only the receiver or its
  # subclasses may be instantiated; anything else is rejected up front.
  optimizer_class = DNN.const_get(hash[:class])
  unless optimizer_class.is_a?(Class) && optimizer_class <= self
    raise DNN_Error, "#{optimizer_class} is not an instance of #{self} class."
  end
  optimizer = optimizer_class.allocate
  optimizer.load_hash(hash)
  optimizer
end

.load(dumped) ⇒ Object



# File 'lib/dnn/core/optimizers.rb', line 18

# Restores an optimizer from the structure returned by #dump.
#
# The optimizer is rebuilt with .from_hash(dumped[:hash]). When
# dumped[:status] is present, each entry is cloned once. That copy is stored
# both in the optimizer's #status container and in the instance variable of
# the same name (e.g. key :v -> @v), so the two references share one object.
# NOTE(review): +clone+ is shallow — objects nested inside a state entry are
# still shared with +dumped+.
#
# @param dumped [Hash] with key :hash and, optionally, :status
# @return [Object, nil] the restored optimizer
def self.load(dumped)
  optimizer = from_hash(dumped[:hash])
  saved_status = dumped[:status]
  if saved_status
    saved_status.each do |name, value|
      copy = value.clone
      optimizer.status[name] = copy
      optimizer.instance_variable_set("@#{name}", copy)
    end
  end
  optimizer
end

Instance Method Details

#dump(require_status = true) ⇒ Object



# File 'lib/dnn/core/optimizers.rb', line 46

# Serializes the optimizer for later restoration with .load.
#
# @param require_status [Boolean] when truthy, the internal state (@status)
#   is included by reference (not copied); otherwise :status is nil.
# @return [Hash] a hash of the form { hash: to_hash, status: @status or nil }
def dump(require_status = true)
  saved_status = (@status if require_status)
  { hash: to_hash, status: saved_status }
end

#load_hash(hash) ⇒ Object



# File 'lib/dnn/core/optimizers.rb', line 72

# Re-runs #initialize with the settings stored by #to_hash. .from_hash calls
# this on an instance created with +allocate+.
#
# @param hash [Hash] serialized optimizer settings; only :clip_norm is read
#   here (a missing key yields nil).
def load_hash(hash)
  stored_clip_norm = hash[:clip_norm]
  initialize(clip_norm: stored_clip_norm)
end

#to_hash(merge_hash = nil) ⇒ Object



# File 'lib/dnn/core/optimizers.rb', line 51

# Serializes the optimizer's configuration (not its internal state).
#
# @param merge_hash [Hash, nil] extra entries (presumably supplied by
#   subclasses); on key conflicts they override the base entries.
# @return [Hash] { class: <class name>, clip_norm: @clip_norm } merged with
#   +merge_hash+ when given
def to_hash(merge_hash = nil)
  base = { class: self.class.name, clip_norm: @clip_norm }
  merge_hash ? base.merge(merge_hash) : base
end

#update(layers) ⇒ Object

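Updates the parameters of every trainable layer that has parameters: gradients are optionally clipped, the parameters are updated, and their gradients are then reset to zero.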



# File 'lib/dnn/core/optimizers.rb', line 35

# Applies one optimization step to every trainable layer that has params.
#
# Steps:
# 1. Collect the params of each Layers::HasParamLayer whose +trainable+ flag
#    is truthy.
# 2. Drop nil entries and params without a gradient.
# 3. Clip the gradients, only when clip_norm is set.
# 4. Delegate the actual update to update_params (presumably implemented
#    per subclass).
# 5. Reset each updated param's gradient to Xumo::SFloat[0].
#
# @param layers [Enumerable] candidate layers; entries that are not
#   Layers::HasParamLayer are ignored
# @return [Array] the params that were updated
def update(layers)
  trainable_layers = layers.select do |layer|
    layer.is_a?(Layers::HasParamLayer) && layer.trainable
  end
  candidate_params = trainable_layers.map { |layer| layer.get_params.values }
  # flatten (not flat_map) keeps the original deep-flatten behaviour.
  target_params = candidate_params.flatten.compact.select(&:grad)
  clip_grads(target_params) if @clip_norm
  update_params(target_params)
  target_params.each { |param| param.grad = Xumo::SFloat[0] }
end