Class: Secryst::ClipGradNorm

Inherits: Torch::NN::F (ancestor chain: Object → Torch::NN::F → Secryst::ClipGradNorm)
Defined in:
lib/secryst/clip_grad_norm.rb

Class Method Summary
  .clip_grad_norm(parameters, max_norm:, norm_type: 2) ⇒ Object

Class Method Details

.clip_grad_norm(parameters, max_norm:, norm_type: 2) ⇒ Object



(Source listing: lines 5–23 of lib/secryst/clip_grad_norm.rb)
# File 'lib/secryst/clip_grad_norm.rb', line 5

# Rescales the gradients of +parameters+ so that their combined norm does
# not exceed +max_norm+ (a Ruby port of PyTorch's
# +torch.nn.utils.clip_grad_norm_+). The total norm is computed as if every
# gradient were flattened and concatenated into one vector.
#
# When clipping is needed, each gradient is replaced by a detached,
# rescaled copy; otherwise gradients are left untouched.
#
# @param parameters [Enumerable] objects responding to +grad+ / +grad=+;
#   entries whose +grad+ is nil are ignored
# @param max_norm [Numeric] maximum allowed total gradient norm
# @param norm_type [Numeric] order p of the norm; pass +Float::INFINITY+
#   for the max-absolute-value norm
# @return [Float, Torch::Tensor] the total gradient norm before clipping, or
#   +Torch.tensor(0.0)+ when no parameter has a gradient (kept as-is for
#   backward compatibility with existing callers)
def self.clip_grad_norm(parameters, max_norm:, norm_type: 2)
  parameters = parameters.select(&:grad)
  return Torch.tensor(0.0) if parameters.empty?

  max_norm = max_norm.to_f
  grads = parameters.map { |p| p.grad.detach }

  total_norm =
    if norm_type == Float::INFINITY
      # Max-abs norm: the largest absolute entry over all gradients.
      grads.map { |g| g.abs.max.item }.max.to_f
    else
      # Sum |g|^p over every element of every gradient, then take the p-th
      # root. This equals the p-norm of the concatenated, flattened gradients.
      # Element-wise tensor ops are used deliberately: a matrix norm (e.g.
      # Numo::Linalg.norm on a 2-D array) would compute a different quantity,
      # and tensor ops also work regardless of the tensor's device.
      grads.sum { |g| g.abs.pow(norm_type).sum.item }**(1.0 / norm_type)
    end

  # The epsilon avoids division by zero when every gradient is zero.
  clip_coef = max_norm / (total_norm + 1e-6)
  parameters.each { |p| p.grad = p.grad.detach * clip_coef } if clip_coef < 1

  total_norm
end