Class: DNN::Layers::BatchNormalization
- Inherits: HasParamLayer
  (ancestry, from root to this class: Object → Layer → HasParamLayer → DNN::Layers::BatchNormalization)
- Defined in:
- lib/dnn/core/normalizations.rb
Instance Attribute Summary
-
#axis ⇒ Object
readonly
Returns the value of attribute axis.
-
#beta ⇒ Object
readonly
Returns the value of attribute beta.
-
#eps ⇒ Object
Returns the value of attribute eps.
-
#gamma ⇒ Object
readonly
Returns the value of attribute gamma.
-
#momentum ⇒ Object
Returns the value of attribute momentum.
-
#running_mean ⇒ Object
readonly
Returns the value of attribute running_mean.
-
#running_var ⇒ Object
readonly
Returns the value of attribute running_var.
Attributes inherited from HasParamLayer
Attributes inherited from Layer
Instance Method Summary
- #backward(dy) ⇒ Object
- #build(input_shape) ⇒ Object
- #forward(x) ⇒ Object
- #get_params ⇒ Object
-
#initialize(axis: 0, momentum: 0.9, eps: 1e-7) ⇒ BatchNormalization
constructor
A new instance of BatchNormalization.
- #load_hash(hash) ⇒ Object
- #to_hash ⇒ Object
Methods inherited from Layer
#built?, #call, call, from_hash, #output_shape
Constructor Details
#initialize(axis: 0, momentum: 0.9, eps: 1e-7) ⇒ BatchNormalization
(source lines 16–21)
# File 'lib/dnn/core/normalizations.rb', line 16
# Creates a batch-normalization layer; its parameters are allocated later by #build.
#
# @param axis [Integer] axis along which the batch mean and variance are computed.
# @param momentum [Float] weight kept from the previous running statistics on each update.
# @param eps [Float] constant added to the variance before taking its square root.
def initialize(axis: 0, momentum: 0.9, eps: 1e-7)
  super()
  @momentum = momentum
  @axis = axis
  # Assigned last so the method still evaluates to eps; #load_hash returns
  # whatever initialize returns.
  @eps = eps
end
Instance Attribute Details
#axis ⇒ Object (readonly)
Returns the value of attribute axis.
(source lines 9–11)
# File 'lib/dnn/core/normalizations.rb', line 9
# Axis over which #forward computes the batch mean and variance, and along which
# #backward reads the batch size.
def axis
  @axis
end
#beta ⇒ Object (readonly)
Returns the value of attribute beta.
(source lines 6–8)
# File 'lib/dnn/core/normalizations.rb', line 6
# Learnable shift parameter (a Param created in #build); its data is added to the
# scaled, normalized input in #forward.
def beta
  @beta
end
#eps ⇒ Object
Returns the value of attribute eps.
(source lines 11–13)
# File 'lib/dnn/core/normalizations.rb', line 11
# Small constant added to the variance before the square root in #forward, keeping
# the normalizing divisor away from zero.
def eps
  @eps
end
#gamma ⇒ Object (readonly)
Returns the value of attribute gamma.
(source lines 5–7)
# File 'lib/dnn/core/normalizations.rb', line 5
# Learnable scale parameter (a Param created in #build); its data multiplies the
# normalized input in #forward.
def gamma
  @gamma
end
#momentum ⇒ Object
Returns the value of attribute momentum.
(source lines 10–12)
# File 'lib/dnn/core/normalizations.rb', line 10
# Weight kept from the previous running mean/variance when #forward updates them
# during the learning phase; the current batch statistic gets (1 - momentum).
def momentum
  @momentum
end
#running_mean ⇒ Object (readonly)
Returns the value of attribute running_mean.
(source lines 7–9)
# File 'lib/dnn/core/normalizations.rb', line 7
# Param holding the moving average of batch means; outside the learning phase
# #forward subtracts its data instead of the batch mean.
def running_mean
  @running_mean
end
#running_var ⇒ Object (readonly)
Returns the value of attribute running_var.
(source lines 8–10)
# File 'lib/dnn/core/normalizations.rb', line 8
# Param holding the moving average of batch variances; outside the learning phase
# #forward normalizes with sqrt(running_var + eps) instead of the batch statistics.
def running_var
  @running_var
end
Instance Method Details
#backward(dy) ⇒ Object
(source lines 48–61)
# File 'lib/dnn/core/normalizations.rb', line 48
# Back-propagates the upstream gradient through the normalization computed by the
# most recent learning-phase #forward call (it reads @xn, @xc and @std cached there).
#
# When @trainable is set, the gradients of beta and gamma are also stored on their
# Params.
#
# @param dy gradient of the loss with respect to this layer's output.
# @return gradient of the loss with respect to this layer's input.
def backward(dy)
  n = dy.shape[@axis]
  if @trainable
    @beta.grad = dy.sum(axis: @axis, keepdims: true)
    @gamma.grad = (@xn * dy).sum(axis: @axis, keepdims: true)
  end
  d_norm = @gamma.data * dy
  # Direct path through xn = xc / std, holding std fixed.
  d_centered = d_norm / @std
  # Indirect path through std = sqrt(var + eps), where var = mean(xc**2).
  d_std = -((d_norm * @xc) / (@std**2)).sum(axis: @axis, keepdims: true)
  d_var = 0.5 * d_std / @std
  d_centered += (2.0 / n) * @xc * d_var
  # Finally through xc = x - mean.
  d_mean = d_centered.sum(axis: @axis, keepdims: true)
  d_centered - d_mean / n
end
#build(input_shape) ⇒ Object
(source lines 23–29)
# File 'lib/dnn/core/normalizations.rb', line 23
# Allocates the layer's parameters once the input shape is known.
#
# gamma (scale) starts at one and beta (shift) at zero, so #forward initially
# returns the normalized input unchanged. The running statistics that #forward
# uses outside the learning phase start at mean 0 and variance 1.
#
# @param input_shape shape passed through to the parent's #build.
def build(input_shape)
  super
  @gamma = Param.new(Xumo::SFloat.ones(*output_shape), Xumo::SFloat[0])
  @beta = Param.new(Xumo::SFloat.zeros(*output_shape), Xumo::SFloat[0])
  @running_mean = Param.new(Xumo::SFloat.zeros(*output_shape))
  # Start the running variance at one, as PyTorch and Keras do, not zero. A zero
  # start makes inference before the moving average has warmed up divide by
  # sqrt(eps), which scales activations by roughly 1/sqrt(eps).
  @running_var = Param.new(Xumo::SFloat.ones(*output_shape))
end
#forward(x) ⇒ Object
(source lines 31–46)
# File 'lib/dnn/core/normalizations.rb', line 31
# Normalizes x along @axis, then applies the learned scale (gamma) and shift (beta).
#
# In the learning phase (DNN.learning_phase truthy), the batch mean and variance are
# used. The intermediates @xc, @std and @xn are cached for #backward, and the running
# mean/variance are updated as exponential moving averages weighted by @momentum.
# Otherwise the stored running statistics are used instead of batch statistics.
#
# @param x input tensor.
# @return gamma * normalized(x) + beta.
def forward(x)
  x_hat =
    if DNN.learning_phase
      batch_mean = x.mean(axis: @axis, keepdims: true)
      @xc = x - batch_mean
      batch_var = (@xc**2).mean(axis: @axis, keepdims: true)
      @std = Xumo::NMath.sqrt(batch_var + @eps)
      @xn = @xc / @std
      # @momentum is the share kept from the previous running value.
      @running_mean.data = @momentum * @running_mean.data + (1 - @momentum) * batch_mean
      @running_var.data = @momentum * @running_var.data + (1 - @momentum) * batch_var
      @xn
    else
      (x - @running_mean.data) / Xumo::NMath.sqrt(@running_var.data + @eps)
    end
  @gamma.data * x_hat + @beta.data
end
#get_params ⇒ Object
(source lines 71–73)
# File 'lib/dnn/core/normalizations.rb', line 71
# Returns this layer's Params keyed by name. The running statistics are included
# alongside gamma and beta so that they are captured together with the learnable
# parameters.
#
# @return [Hash{Symbol => Object}] :gamma, :beta, :running_mean, :running_var.
def get_params
  {
    gamma: @gamma,
    beta: @beta,
    running_mean: @running_mean,
    running_var: @running_var
  }
end
#load_hash(hash) ⇒ Object
(source lines 67–69)
# File 'lib/dnn/core/normalizations.rb', line 67
# Re-initializes this layer from a hash produced by #to_hash.
#
# #to_hash serializes axis, momentum and eps, so all three must be passed back.
# If eps were omitted, the constructor default would silently replace a customized
# value. Only keys present in the hash are forwarded, so a missing key falls back
# to the constructor default rather than being passed as nil.
#
# @param hash [Hash] serialized configuration with symbol keys :axis, :momentum, :eps.
# @return whatever #initialize returns.
def load_hash(hash)
  initialize(**hash.slice(:axis, :momentum, :eps))
end
#to_hash ⇒ Object
(source lines 63–65)
# File 'lib/dnn/core/normalizations.rb', line 63
# Serializes this layer's configuration (axis, momentum, eps) by handing it to the
# parent's #to_hash.
#
# @return the hash built by the parent's #to_hash.
def to_hash
  config = { axis: @axis, momentum: @momentum, eps: @eps }
  super(**config)
end