Class: Chainer::Functions::Normalization::BatchNormalizationFunction

Inherits:
Chainer::Function
  • Object
show all
Defined in:
lib/chainer/functions/normalization/batch_normalization.rb

Instance Attribute Summary collapse

Attributes inherited from Chainer::Function

#inputs, #output_data, #outputs, #rank, #retain_after_backward

Class Method Summary collapse

Instance Method Summary collapse

Methods inherited from Chainer::Function

#call, #forward_cpu, #retain_inputs, #retain_outputs

Constructor Details

#initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9) ⇒ BatchNormalizationFunction

Returns a new instance of BatchNormalizationFunction.



26
27
28
29
30
31
32
# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 26

# Sets up the batch-normalization hyper-parameters and the holders for
# the running statistics.
#
# @param eps [Float] small constant added to the variance for numerical stability
# @param mean [Object, nil] externally supplied running mean, if any
# @param var [Object, nil] externally supplied running variance, if any
# @param decay [Float] exponential-decay rate for the running statistics
def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
  @eps = eps
  @decay = decay
  @running_mean = mean
  @running_var = var
  @mean_cache = nil
end

Instance Attribute Details

#running_mean ⇒ Object (readonly)

Returns the value of attribute running_mean.



5
6
7
# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 5

def running_mean
  @running_mean
end

#running_var ⇒ Object (readonly)

Returns the value of attribute running_var.



5
6
7
# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 5

def running_var
  @running_var
end

Class Method Details

.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5) ⇒ Object

Batch normalization function with fixed statistics. This is a variant of batch normalization, where the mean and variance statistics are given by the caller as fixed variables. This is used on testing mode of the batch normalization layer, where batch statistics cannot be used for prediction consistency.

Parameters:



18
19
20
21
22
23
24
# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 18

# Batch normalization with fixed statistics (test-time variant).
#
# Temporarily switches the global train flag off so the function uses the
# caller-supplied mean/var instead of batch statistics.
#
# @param x [Object] input array
# @param gamma [Object] scale parameter
# @param beta [Object] shift parameter
# @param mean [Object] fixed mean to normalize with
# @param var [Object] fixed variance to normalize with
# @param eps [Float] stability constant added to the variance
# @return [Object] the normalized output
def self.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
  old_train = Chainer.configuration.train
  Chainer.configuration.train = false
  # `ensure` guarantees the global train flag is restored even when the
  # normalization call raises (the original leaked `train = false` on error).
  self.new(eps: eps, mean: nil, var: nil, decay: 0.0).(x, gamma, beta, mean, var)
ensure
  Chainer.configuration.train = old_train
end

Instance Method Details

#backward(inputs, grad_outputs) ⇒ Object



93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 93

# Computes the gradients of batch normalization.
#
# @param inputs [Array] [x, gamma, beta] in training mode, or
#   [x, gamma, beta, mean, var] (5 inputs) in fixed-statistics mode
# @param grad_outputs [Array] [gy], gradient w.r.t. the output y
# @return [Array] [gx, ggamma, gbeta] or, in fixed-statistics mode,
#   [gx, ggamma, gbeta, gmean, gvar]
def backward(inputs, grad_outputs)
  x, gamma = inputs[0], inputs[1]
  gy = grad_outputs[0]
  head_ndim = gamma.ndim + 1
  # Elements normalized per channel (batch * trailing spatial dims),
  # wrapped as a scalar of gamma's array class.
  m = gamma.class[x.size.div(gamma.size)][0]
  # Reduction axes: batch axis plus all axes after the channel dims.
  axis = [0] + (head_ndim...(x.ndim)).to_a

  if inputs.size == 5
    # Fixed-statistics mode: mean/var were provided by the caller.
    mean = inputs[3]
    var = inputs[4]
    std = Numo::NMath.sqrt(var)
    gs = gamma / std
    gbeta = gy.sum(axis: axis)

    mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
    x_mu = x - mean.reshape(*mean_expander)
    std_expander = [1] + std.shape + [1] * (x.ndim - head_ndim)
    x_mu /= std.reshape(*std_expander)
    x_hat = x_mu
    ggamma = (gy * x_hat).sum(axis: axis)
    gmean = -gs * gbeta
    gvar = -0.5 * gamma / var * ggamma
    gs_expander = [1] + gs.shape + [1] * (x.ndim - head_ndim)
    # BUG FIX: gx must be (gamma / std) broadcast and multiplied by gy;
    # the original returned the broadcast scale alone, dropping `* gy`
    # (cf. Chainer's reference implementation: gx = expander(gs) * gy).
    gx = gs.reshape(*gs_expander) * gy
    return [gx, ggamma, gbeta, gmean, gvar]
  end

  # Training mode: reuse @x_hat and @std cached by forward.
  gbeta = gy.sum(axis: axis)
  ggamma = (gy * @x_hat).sum(axis: axis)
  tmp = (gamma / @std)
  tmp_expander = [1] + tmp.shape + [1] * (x.ndim - head_ndim)
  tmp = tmp.reshape(*tmp_expander)

  ggamma_expander = [1] + ggamma.shape + [1] * (x.ndim - head_ndim)
  gbeta_expander = [1] + gbeta.shape + [1] * (x.ndim - head_ndim)

  # Standard BN backward: gx = (gamma/std) * (gy - (x_hat*ggamma + gbeta)/m)
  gx = tmp * (gy - (@x_hat * ggamma.reshape(*ggamma_expander) + gbeta.reshape(*gbeta_expander)) / m)

  [gx, ggamma, gbeta]
end

#forward(inputs) ⇒ Object



34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 34

# Computes the batch-normalized output.
#
# inputs: [x, gamma, beta] in training mode, or
#         [x, gamma, beta, fixed_mean, fixed_var] (5 inputs) in test mode.
# Returns a one-element array containing the normalized output y.
# Side effects: in training mode, updates @running_mean/@running_var and
# caches @std and @x_hat for use by backward.
def forward(inputs)
  x, gamma, beta = inputs[0], inputs[1], inputs[2]
  if Chainer.configuration.train
    if @running_mean.nil?
      # First call: initialize running statistics as zeros shaped like gamma.
      @running_mean = Numo::NArray[*gamma].new_zeros
      @running_var = Numo::NArray[*gamma].new_zeros
    else
      # Re-wrap caller-supplied stats so the in-place updates below work
      # on Numo arrays.
      @running_mean = Numo::NArray[*@running_mean]
      @running_var = Numo::NArray[*@running_var]
    end
  elsif inputs.size == 5
    # Test mode with caller-provided fixed statistics.
    @fixed_mean = inputs[3]
    @fixed_var = inputs[4]
  end

  head_ndim = gamma.ndim + 1
  # Reshape gamma/beta to [1, *gamma.shape, 1, ...] so they broadcast over
  # the batch axis and any trailing spatial axes of x.
  gamma_expander = [1] + gamma.shape + [1] * (x.ndim - head_ndim)
  gamma = gamma.reshape(*gamma_expander)
  beta_expander = [1] + beta.shape + [1] * (x.ndim - head_ndim)
  beta = beta.reshape(*beta_expander)

  if Chainer.configuration.train
    # Batch statistics over the batch axis (0) and all trailing axes.
    axis = [0] + (head_ndim...(x.ndim)).to_a
    mean = x.mean(axis: axis)
    # FIXME: numpy.var
    var = x.var(axis: axis)
    var += @eps
  else
    mean = @fixed_mean
    var = @fixed_var + @eps
  end

  @std = Numo::NMath.sqrt(var)

  # x_hat = (x - mean) / std; cached in @x_hat for backward.
  mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
  x_mu = x - mean.reshape(*mean_expander)
  std_expander = [1] + @std.shape + [1] * (x.ndim - head_ndim)
  x_mu /= @std.reshape(*std_expander)
  @x_hat = x_mu
  y = gamma * @x_hat
  y += beta

  if Chainer.configuration.train
    # Exponential-moving-average update of the running statistics.
    # NOTE(review): `var` already includes @eps at this point, so
    # @running_var tracks var + eps — confirm this matches the intended
    # reference behavior before changing.
    m = x.size.div(gamma.size)
    # Bessel-style correction factor m/(m-1) for the unbiased variance.
    adjust = m / [m - 1.0, 1.0].max
    @running_mean *= @decay
    temp_ar = Numo::NArray[*mean]
    temp_ar *= (1 - @decay)
    @running_mean += temp_ar
    
    @running_var *= @decay
    temp_ar = Numo::NArray[*var]
    temp_ar *= ((1 - @decay) * adjust)
    @running_var += temp_ar
  end

  [y,]
end