Method: Transformers::Bert::BertPreTrainedModel#_init_weights

Defined in:
lib/transformers/models/bert/modeling_bert.rb

#_init_weights(mod) ⇒ Object



Lines 519–534
# File 'lib/transformers/models/bert/modeling_bert.rb', line 519

# Initializes the parameters of a single submodule in place.
#
# - Torch::NN::Linear: weight is drawn from N(0, initializer_range); the bias,
#   when present, is zeroed.
# - Torch::NN::Embedding: weight is drawn from N(0, initializer_range); the row
#   at the padding index, when one is set, is zeroed.
# - Torch::NN::LayerNorm: bias is zeroed and weight is filled with 1.0, each
#   only when that parameter is present.
# Any other module is left untouched.
#
# @param mod [Object] the submodule to initialize
# @return [Object, nil] result of the last in-place tensor call (not meaningful
#   to callers); nil when nothing was done for the final step
def _init_weights(mod)
  case mod
  when Torch::NN::Linear
    mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
    mod.bias.data.zero! unless mod.bias.nil?
  when Torch::NN::Embedding
    mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
    # Read via instance_variable_get — presumably because the Embedding class
    # exposes no public reader for padding_idx (NOTE(review): verify).
    padding_idx = mod.instance_variable_get(:@padding_idx)
    mod.weight.data[padding_idx].zero! unless padding_idx.nil?
  when Torch::NN::LayerNorm
    # A LayerNorm built without affine parameters may carry nil weight/bias
    # (assumption: e.g. elementwise_affine: false — confirm against torch.rb);
    # skip those instead of raising NoMethodError.
    mod.bias.data.zero! unless mod.bias.nil?
    mod.weight.data.fill!(1.0) unless mod.weight.nil?
  end
end