Class: Chainer::Functions::Normalization::BatchNormalizationFunction

Inherits: Chainer::Function
  Object > Chainer::Function > Chainer::Functions::Normalization::BatchNormalizationFunction

Defined in: lib/chainer/functions/normalization/batch_normalization.rb
Instance Attribute Summary

- #running_mean ⇒ Object (readonly)
  Returns the value of attribute running_mean.
- #running_var ⇒ Object (readonly)
  Returns the value of attribute running_var.

Attributes inherited from Chainer::Function:
  #inputs, #output_data, #outputs, #rank, #retain_after_backward
Class Method Summary

- .fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5) ⇒ Object
  Batch normalization function with fixed statistics.
Instance Method Summary

- #backward(inputs, grad_outputs) ⇒ Object
- #forward(inputs) ⇒ Object
- #initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9) ⇒ BatchNormalizationFunction (constructor)
  A new instance of BatchNormalizationFunction.

Methods inherited from Chainer::Function:
  #call, #forward_cpu, #retain_inputs, #retain_outputs
Constructor Details
#initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9) ⇒ BatchNormalizationFunction
Returns a new instance of BatchNormalizationFunction.
# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 26

def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
  @running_mean = mean
  @running_var = var
  @eps = eps
  @mean_cache = nil
  @decay = decay
end
Instance Attribute Details
#running_mean ⇒ Object (readonly)
Returns the value of attribute running_mean.
# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 5

def running_mean
  @running_mean
end
#running_var ⇒ Object (readonly)
Returns the value of attribute running_var.
# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 5

def running_var
  @running_var
end
Class Method Details
.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5) ⇒ Object
Batch normalization function with fixed statistics. This is a variant of batch normalization where the mean and variance statistics are given by the caller as fixed variables. It is used in test mode of the batch normalization layer, where batch statistics cannot be used, for prediction consistency.
# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 18

def self.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
  old_train = Chainer.configuration.train
  Chainer.configuration.train = false
  norm = self.new(eps: eps, mean: nil, var: nil, decay: 0.0).(x, gamma, beta, mean, var)
  Chainer.configuration.train = old_train
  norm
end
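In formula terms, the fixed variant computes y = gamma * (x - mean) / sqrt(var + eps) + beta elementwise. The sketch below illustrates that formula in plain Ruby with scalar statistics and a flat Array; the helper name `fixed_bn` is illustrative only, and the real API takes Numo::NArray / Chainer::Variable objects, not plain Arrays:

```ruby
# Illustrative sketch of fixed-statistics batch normalization on a flat array.
# gamma, beta, mean, var are scalars here for simplicity.
def fixed_bn(x, gamma, beta, mean, var, eps: 2e-5)
  std = Math.sqrt(var + eps)
  x.map { |v| gamma * ((v - mean) / std) + beta }
end

y = fixed_bn([1.0, 2.0, 3.0], 1.0, 0.0, 2.0, 1.0)
# With gamma=1, beta=0, mean=2, var=1 the result is roughly [-1.0, 0.0, 1.0]
```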
Instance Method Details
#backward(inputs, grad_outputs) ⇒ Object
# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 93

def backward(inputs, grad_outputs)
  x, gamma = inputs[0], inputs[1]
  gy = grad_outputs[0]
  head_ndim = gamma.ndim + 1
  m = gamma.class[x.size.div(gamma.size)][0]
  axis = [0] + (head_ndim...(x.ndim)).to_a

  if inputs.size == 5
    # Fixed-statistics mode: mean and var were passed in as inputs.
    mean = inputs[3]
    var = inputs[4]
    std = Numo::NMath.sqrt(var)
    gs = gamma / std
    gbeta = gy.sum(axis: axis)
    mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
    x_mu = x - mean.reshape(*mean_expander)
    std_expander = [1] + std.shape + [1] * (x.ndim - head_ndim)
    x_mu /= std.reshape(*std_expander)
    x_hat = x_mu
    ggamma = (gy * x_hat).sum(axis: axis)
    gmean = -gs * gbeta
    gvar = -0.5 * gamma / var * ggamma
    gs_expander = [1] + gs.shape + [1] * (x.ndim - head_ndim)
    gx = gs.reshape(*gs_expander) * gy
    return [gx, ggamma, gbeta, gmean, gvar]
  end

  gbeta = gy.sum(axis: axis)
  ggamma = (gy * @x_hat).sum(axis: axis)
  tmp = (gamma / @std)
  tmp_expander = [1] + tmp.shape + [1] * (x.ndim - head_ndim)
  tmp = tmp.reshape(*tmp_expander)
  ggamma_expander = [1] + ggamma.shape + [1] * (x.ndim - head_ndim)
  gbeta_expander = [1] + gbeta.shape + [1] * (x.ndim - head_ndim)
  gx = tmp * (gy - (@x_hat * ggamma.reshape(*ggamma_expander) + gbeta.reshape(*gbeta_expander)) / m)
  [gx, ggamma, gbeta]
end
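In the training branch (after the `inputs.size == 5` early return), the input gradient follows the standard batch-normalization backward formula gx = (gamma / std) * (gy - (x_hat * ggamma + gbeta) / m). The sketch below replays that formula over a 1-D batch in plain Ruby; the helper name `bn_backward_1d` is illustrative and not part of the library:

```ruby
# Training-mode BN backward for a 1-D batch of scalars (illustrative only).
# Returns [gx, ggamma, gbeta], mirroring the formula in the method above.
def bn_backward_1d(x, gy, gamma, eps: 2e-5)
  m = x.size.to_f
  mean = x.sum / m
  var = x.map { |v| (v - mean)**2 }.sum / m + eps
  std = Math.sqrt(var)
  x_hat = x.map { |v| (v - mean) / std }               # normalized inputs
  gbeta  = gy.sum                                       # d(loss)/d(beta)
  ggamma = x_hat.zip(gy).map { |h, g| h * g }.sum       # d(loss)/d(gamma)
  gx = x_hat.zip(gy).map { |h, g| (gamma / std) * (g - (h * ggamma + gbeta) / m) }
  [gx, ggamma, gbeta]
end
```

One property worth noting: because x_hat has zero mean over the batch, gx always sums to zero across the batch axis, which is a handy sanity check for any BN backward implementation.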
#forward(inputs) ⇒ Object
# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 34

def forward(inputs)
  x, gamma, beta = inputs[0], inputs[1], inputs[2]
  if Chainer.configuration.train
    if @running_mean.nil?
      @running_mean = Numo::NArray[*gamma].new_zeros
      @running_var = Numo::NArray[*gamma].new_zeros
    else
      @running_mean = Numo::NArray[*@running_mean]
      @running_var = Numo::NArray[*@running_var]
    end
  elsif inputs.size == 5
    @fixed_mean = inputs[3]
    @fixed_var = inputs[4]
  end

  head_ndim = gamma.ndim + 1
  gamma_expander = [1] + gamma.shape + [1] * (x.ndim - head_ndim)
  gamma = gamma.reshape(*gamma_expander)
  beta_expander = [1] + beta.shape + [1] * (x.ndim - head_ndim)
  beta = beta.reshape(*beta_expander)

  if Chainer.configuration.train
    axis = [0] + (head_ndim...(x.ndim)).to_a
    mean = x.mean(axis: axis)
    # FIXME: numpy.var
    var = x.var(axis: axis)
    var += @eps
  else
    mean = @fixed_mean
    var = @fixed_var + @eps
  end

  @std = Numo::NMath.sqrt(var)
  mean_expander = [1] + mean.shape + [1] * (x.ndim - head_ndim)
  x_mu = x - mean.reshape(*mean_expander)
  std_expander = [1] + @std.shape + [1] * (x.ndim - head_ndim)
  x_mu /= @std.reshape(*std_expander)
  @x_hat = x_mu
  y = gamma * @x_hat
  y += beta

  if Chainer.configuration.train
    m = x.size.div(gamma.size)
    adjust = m / [m - 1.0, 1.0].max
    @running_mean *= @decay
    temp_ar = Numo::NArray[*mean]
    temp_ar *= (1 - @decay)
    @running_mean += temp_ar
    @running_var *= @decay
    temp_ar = Numo::NArray[*var]
    temp_ar *= ((1 - @decay) * adjust)
    @running_var += temp_ar
  end
  [y]
end
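At the end of forward, training mode updates the running statistics as an exponential moving average, scaling the batch variance by adjust = m / max(m - 1, 1) to turn the biased batch estimate into an unbiased one. A scalar plain-Ruby sketch of that update rule (the helper name `update_running` is illustrative, not part of the library):

```ruby
# EMA update of running mean/variance, as done at the end of forward.
# Scalars stand in for the per-channel Numo::NArray statistics.
def update_running(running_mean, running_var, batch_mean, batch_var, m, decay: 0.9)
  adjust = m / [m - 1.0, 1.0].max          # unbiased-variance correction
  new_mean = running_mean * decay + batch_mean * (1 - decay)
  new_var  = running_var  * decay + batch_var * (1 - decay) * adjust
  [new_mean, new_var]
end

update_running(0.0, 1.0, 1.0, 4.0, 10)
# mean moves 10% toward the batch mean; variance additionally gets the m/(m-1) factor
```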