Class: Torch::Optim::Optimizer

Inherits:
Object
  • Object
show all
Defined in:
lib/torch/optim/optimizer.rb

Direct Known Subclasses

ASGD, Adadelta, Adagrad, Adam, AdamW, Adamax, RMSprop, Rprop, SGD

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(params, defaults) ⇒ Optimizer

Returns a new instance of Optimizer.



7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# File 'lib/torch/optim/optimizer.rb', line 7

# Builds an optimizer over the given parameters.
#
# params   - an enumerable of parameters, or an Array of param-group Hashes
# defaults - Hash of default hyperparameters merged into every param group
#
# Raises ArgumentError when the parameter list is empty.
def initialize(params, defaults)
  @defaults = defaults
  # Lazily-created per-parameter state: each new key gets its own Hash.
  @state = Hash.new { |hash, key| hash[key] = {} }
  @param_groups = []

  groups = params
  raise ArgumentError, "optimizer got an empty parameter list" if groups.empty?
  # A bare list of parameters becomes a single implicit param group.
  groups = [{params: groups}] unless groups[0].is_a?(Hash)

  groups.each { |group| add_param_group(group) }
end

Instance Attribute Details

#param_groupsObject (readonly)

Returns the value of attribute param_groups.



5
6
7
# File 'lib/torch/optim/optimizer.rb', line 5

# Returns the Array of param-group Hashes managed by this optimizer (read-only).
def param_groups
  @param_groups
end

Instance Method Details

#add_param_group(param_group) ⇒ Object



25
26
27
28
# File 'lib/torch/optim/optimizer.rb', line 25

# Appends a param group, filling any unspecified options from @defaults.
def add_param_group(param_group)
  # TODO more advanced logic
  group = @defaults.merge(param_group)
  @param_groups.push(group)
end

#load_state_dict(state_dict) ⇒ Object

Raises:
  • NotImplementedYet — loading optimizer state is not yet implemented



30
31
32
# File 'lib/torch/optim/optimizer.rb', line 30

# Restores optimizer state from a state Hash (the inverse of #state_dict).
# Not implemented yet; always raises NotImplementedYet.
def load_state_dict(state_dict)
  raise NotImplementedYet
end

#state_dictObject

Raises:
  • NotImplementedYet — state serialization is not yet implemented



34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# File 'lib/torch/optim/optimizer.rb', line 34

# Returns a serializable Hash of the optimizer's state and param groups.
# NOTE(review): the method raises immediately, so everything below the
# raise is unreachable draft code, apparently kept as a sketch of the
# intended output format.
def state_dict
  raise NotImplementedYet

  # Converts one param group to string keys, replacing the tensors under
  # :params with their object ids.
  pack_group = lambda do |group|
    packed = group.select { |k, _| k != :params }.map { |k, v| [k.to_s, v] }.to_h
    packed["params"] = group[:params].map { |p| p.object_id }
    packed
  end

  param_groups = @param_groups.map { |g| pack_group.call(g) }
  # Re-keys @state by tensor object_id so the result holds no Tensor objects.
  # NOTE(review): object_id is not stable across processes, so this format
  # presumably cannot round-trip through a file — confirm before finishing.
  packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h

  {
    "state" => packed_state,
    "param_groups" => param_groups
  }
end

#zero_gradObject



52
53
54
55
56
57
58
59
60
61
# File 'lib/torch/optim/optimizer.rb', line 52

# Detaches and zeroes the gradient of every optimized parameter.
def zero_grad
  @param_groups.each do |group|
    group[:params].each do |param|
      grad = param.grad
      next unless grad
      grad.detach!
      grad.zero!
    end
  end
end