Class: Torch::NN::BatchNorm

Inherits: Module • Object
Defined in:
lib/torch/nn/batch_norm.rb

Direct Known Subclasses

BatchNorm1d, BatchNorm2d, BatchNorm3d, InstanceNorm

Instance Attribute Summary

Attributes inherited from Module

#training

Instance Method Summary

Methods inherited from Module

#_apply, #add_module, #apply, #buffers, #call, #children, #cpu, #cuda, #deep_dup, #double, #eval, #float, #half, #inspect, #load_state_dict, #method_missing, #modules, #named_buffers, #named_children, #named_modules, #named_parameters, #parameters, #register_buffer, #register_parameter, #requires_grad!, #respond_to?, #share_memory, #state_dict, #to, #train, #type, #zero_grad

Methods included from Utils

#_activation_fn, #_clones, #_ntuple, #_pair, #_quadrupal, #_single, #_triple

Constructor Details

#initialize(num_features, eps: 1e-5, momentum: 0.1, affine: true, track_running_stats: true) ⇒ BatchNorm

Returns a new instance of BatchNorm.



# File 'lib/torch/nn/batch_norm.rb', line 4

def initialize(num_features, eps: 1e-5, momentum: 0.1, affine: true, track_running_stats: true)
  super()
  @num_features = num_features
  @eps = eps
  @momentum = momentum
  @affine = affine
  @track_running_stats = track_running_stats
  if @affine
    # learnable per-feature scale (gamma) and shift (beta)
    @weight = Parameter.new(Torch::Tensor.new(num_features))
    @bias = Parameter.new(Torch::Tensor.new(num_features))
  else
    register_parameter("weight", nil)
    register_parameter("bias", nil)
  end
  if track_running_stats
    # buffers, not parameters: updated in training mode, used to normalize in eval mode
    register_buffer("running_mean", Torch.zeros(num_features))
    register_buffer("running_var", Torch.ones(num_features))
    register_buffer("num_batches_tracked", Torch.tensor(0, dtype: :long))
  else
    register_parameter("running_mean", nil)
    register_parameter("running_var", nil)
    register_parameter("num_batches_tracked", nil)
  end
  reset_parameters
end
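
BatchNorm is effectively an abstract base: the input-dimension check (_check_input_dim, called in #forward below) is supplied by the subclasses listed above, so in practice you instantiate one of those. A minimal usage sketch, assuming the torch-rb gem is loaded:

require "torch"

# batch norm over 100 feature channels, with the default options shown above
m = Torch::NN::BatchNorm2d.new(100)

# without the learnable affine transform or running statistics
m = Torch::NN::BatchNorm2d.new(100, affine: false, track_running_stats: false)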

Dynamic Method Handling

This class handles dynamic methods through the method_missing method in the class Torch::NN::Module.
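
For example, parameters and buffers registered in the constructor become readable by name; a hedged sketch:

m = Torch::NN::BatchNorm2d.new(3)
m.running_mean  # resolved via Torch::NN::Module#method_missing
m.weight        # the learnable gamma Parameter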

Instance Method Details

#extra_inspect ⇒ Object



# File 'lib/torch/nn/batch_norm.rb', line 74

def extra_inspect
  s = "%{num_features}, eps: %{eps}, momentum: %{momentum}, affine: %{affine}, track_running_stats: %{track_running_stats}"
  format(s, **dict)
end
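
extra_inspect supplies the argument summary that Module#inspect embeds. For a subclass with default options, the rendered string should look roughly like the following (the exact float formatting is an assumption):

Torch::NN::BatchNorm2d.new(100).inspect
# => "BatchNorm2d(100, eps: 1.0e-05, momentum: 0.1, affine: true, track_running_stats: true)"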

#forward(input) ⇒ Object



# File 'lib/torch/nn/batch_norm.rb', line 46

def forward(input)
  _check_input_dim(input)

  if @momentum.nil?
    exponential_average_factor = 0.0
  else
    exponential_average_factor = @momentum
  end

  if @training && @track_running_stats
    unless @num_batches_tracked.nil?
      @num_batches_tracked += 1
      if @momentum.nil?
        # use a cumulative moving average
        exponential_average_factor = 1.0 / @num_batches_tracked.to_f
      else
        # use an exponential moving average
        exponential_average_factor = @momentum
      end
    end
  end

  F.batch_norm(
    input, @running_mean, @running_var,
    weight: @weight, bias: @bias,
    training: @training || !@track_running_stats,
    momentum: exponential_average_factor, eps: @eps
  )
end
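
A forward pass through a subclass, as a hedged sketch (Torch.randn and Module#call are assumed from the inherited API listed above):

m = Torch::NN::BatchNorm2d.new(3)
input = Torch.randn(2, 3, 4, 4)  # (batch, channels, height, width)
output = m.call(input)           # normalizes per channel; updates running stats in training mode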

#reset_parameters ⇒ Object



# File 'lib/torch/nn/batch_norm.rb', line 38

def reset_parameters
  reset_running_stats
  if @affine
    Init.ones!(@weight)
    Init.zeros!(@bias)
  end
end
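
After initialization (or a reset), the affine parameters start as an identity transform; a hedged sketch using the dynamic attribute access described above:

m = Torch::NN::BatchNorm2d.new(4)
m.weight  # gamma, filled with ones by Init.ones!
m.bias    # beta, filled with zeros by Init.zeros!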

#reset_running_stats ⇒ Object



# File 'lib/torch/nn/batch_norm.rb', line 30

def reset_running_stats
  if @track_running_stats
    @running_mean.zero!
    @running_var.fill!(1)
    @num_batches_tracked.zero!
  end
end
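
When track_running_stats is true, this restores the buffers to their initial state (mean 0, variance 1, batch counter 0), for example before reusing a module on data with a different distribution. A hedged sketch:

m = Torch::NN::BatchNorm2d.new(3)
# ... training updates m.running_mean and m.running_var ...
m.reset_running_stats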