Class: DNN::Layers::BatchNormalization

Inherits:
TrainableLayer
Includes:
LayerNode
Defined in:
lib/dnn/core/layers/normalizations.rb

Instance Attribute Summary

Attributes inherited from TrainableLayer

#trainable

Attributes inherited from Layer

#input_shape, #output_shape

Instance Method Summary

Methods included from LayerNode

#forward

Methods inherited from TrainableLayer

#clean

Methods inherited from Layer

#<<, #built?, #call, call, #clean, #compute_output_shape, #forward, from_hash

Constructor Details

#initialize(axis: 0, momentum: 0.9, eps: 1e-7) ⇒ BatchNormalization

Returns a new instance of BatchNormalization.

Parameters:

  • axis (Integer) (defaults to: 0)

    The axis along which to normalize (batch statistics are computed over this axis).

  • momentum (Float) (defaults to: 0.9)

    Momentum (decay factor) of the exponential moving averages that track the running mean and variance.

  • eps (Float) (defaults to: 1e-7)

    Small constant added to the variance to avoid division by zero.



18
19
20
21
22
23
24
25
26
27
# File 'lib/dnn/core/layers/normalizations.rb', line 18

# Creates a batch normalization layer.
#
# The parameter objects are created here as placeholders; #build replaces
# their data once the layer's output shape is known.
#
# @param axis [Integer] axis over which batch statistics are computed.
# @param momentum [Float] decay factor of the running mean/variance moving averages.
# @param eps [Float] small constant added to the variance to avoid division by zero.
def initialize(axis: 0, momentum: 0.9, eps: 1e-7)
  super()
  # Learnable scale/shift and running statistics; real storage is allocated in build.
  @gamma = Param.new(nil, Xumo::SFloat[0])
  @beta = Param.new(nil, Xumo::SFloat[0])
  @running_mean = Param.new
  @running_var = Param.new
  # Hyperparameters.
  @axis = axis
  @momentum = momentum
  @eps = eps
end

Instance Attribute Details

#axis ⇒ Object (readonly)

Returns the value of attribute axis.



11
12
13
# File 'lib/dnn/core/layers/normalizations.rb', line 11

# Axis along which forward_node and backward_node reduce (mean/sum with
# keepdims: true) when computing batch statistics and gradients.
# @return [Integer]
def axis
  @axis
end

#beta ⇒ Object (readonly)

Returns the value of attribute beta.



8
9
10
# File 'lib/dnn/core/layers/normalizations.rb', line 8

# Learnable shift parameter; its data is added to the scaled, normalized
# input in forward_node.
# @return [Param]
def beta
  @beta
end

#eps ⇒ Object

Returns the value of attribute eps.



13
14
15
# File 'lib/dnn/core/layers/normalizations.rb', line 13

# Small constant added to the variance before taking the square root,
# so normalization never divides by zero.
# @return [Float]
def eps
  @eps
end

#gamma ⇒ Object (readonly)

Returns the value of attribute gamma.



7
8
9
# File 'lib/dnn/core/layers/normalizations.rb', line 7

# Learnable scale parameter; its data multiplies the normalized input in
# forward_node.
# @return [Param]
def gamma
  @gamma
end

#momentum ⇒ Object

Returns the value of attribute momentum.



12
13
14
# File 'lib/dnn/core/layers/normalizations.rb', line 12

# Decay factor of the exponential moving averages that forward_node updates
# for the running mean and variance during training.
# @return [Float]
def momentum
  @momentum
end

#running_mean ⇒ Object (readonly)

Returns the value of attribute running_mean.



9
10
11
# File 'lib/dnn/core/layers/normalizations.rb', line 9

# Exponential moving average of batch means. forward_node uses it instead
# of the batch mean when DNN.learning_phase is false.
# @return [Param]
def running_mean
  @running_mean
end

#running_var ⇒ Object (readonly)

Returns the value of attribute running_var.



10
11
12
# File 'lib/dnn/core/layers/normalizations.rb', line 10

# Exponential moving average of (biased) batch variances. forward_node uses
# it instead of the batch variance when DNN.learning_phase is false.
# @return [Param]
def running_var
  @running_var
end

Instance Method Details

#backward_node(dy) ⇒ Object



54
55
56
57
58
59
60
61
62
63
64
65
66
67
# File 'lib/dnn/core/layers/normalizations.rb', line 54

# Backpropagates the upstream gradient +dy+ through batch normalization.
#
# Relies on @xn, @xc and @std, which are only set by a training-mode
# forward_node call (DNN.learning_phase true). It is therefore only valid
# after such a forward pass.
# When @trainable is set, it also stores the gradients of beta and gamma.
#
# @param dy gradient of the loss with respect to this layer's output.
# @return gradient of the loss with respect to this layer's input.
def backward_node(dy)
  # Number of elements reduced over along the normalization axis.
  batch_size = dy.shape[@axis]
  if @trainable
    @beta.grad = dy.sum(axis: @axis, keepdims: true)
    @gamma.grad = (@xn * dy).sum(axis: @axis, keepdims: true)
  end
  # Gradient w.r.t. the normalized input xn = xc / std.
  dxn = @gamma.data * dy
  # Direct path: xc flows through the division by std.
  dxc = dxn / @std
  # d(xc / std)/d(std) = -xc / std**2, reduced over the axis.
  dstd = -((dxn * @xc) / (@std**2)).sum(axis: @axis, keepdims: true)
  # std = sqrt(var + eps)  =>  d(std)/d(var) = 0.5 / std.
  dvar = 0.5 * dstd / @std
  # var = mean(xc**2)  =>  d(var)/d(xc) = 2 * xc / batch_size.
  dxc += (2.0 / batch_size) * @xc * dvar
  # xc = x - mean: remove the share of the gradient routed through the mean.
  dmean = dxc.sum(axis: @axis, keepdims: true)
  dxc - dmean / batch_size
end

#build(input_shape) ⇒ Object



29
30
31
32
33
34
35
# File 'lib/dnn/core/layers/normalizations.rb', line 29

# Allocates parameter storage once the input shape is known.
#
# The superclass build is expected to set @output_shape, which every
# parameter below is sized to.
# - gamma starts at 1 and beta at 0, so the affine step is initially the identity.
# - The running mean starts at 0.
# - The running variance starts at 1, not 0. A zero variance would make
#   inference before the moving averages warm up divide by sqrt(eps),
#   amplifying inputs by roughly 1 / sqrt(eps). It would also bias the
#   early moving-average estimates toward zero variance.
#
# @param input_shape forwarded unchanged to the superclass build.
def build(input_shape)
  super
  @gamma.data = Xumo::SFloat.ones(*@output_shape)
  @beta.data = Xumo::SFloat.zeros(*@output_shape)
  @running_mean.data = Xumo::SFloat.zeros(*@output_shape)
  @running_var.data = Xumo::SFloat.ones(*@output_shape)
end

#forward_node(x) ⇒ Object



37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# File 'lib/dnn/core/layers/normalizations.rb', line 37

# Normalizes +x+ along @axis, then applies the learnable scale and shift.
#
# In training mode (DNN.learning_phase true) it:
# - uses the batch mean and (biased) variance,
# - caches @xc, @std and @xn for backward_node,
# - folds the batch statistics into the running averages using @momentum.
# Otherwise it normalizes with the stored running mean and variance.
#
# @param x input tensor.
# @return gamma * normalized(x) + beta.
def forward_node(x)
  normalized =
    if DNN.learning_phase
      batch_mean = x.mean(axis: @axis, keepdims: true)
      @xc = x - batch_mean
      batch_var = (@xc**2).mean(axis: @axis, keepdims: true)
      @std = Xumo::NMath.sqrt(batch_var + @eps)
      @xn = @xc / @std
      # Exponential moving averages used later at inference time.
      @running_mean.data = @momentum * @running_mean.data + (1 - @momentum) * batch_mean
      @running_var.data = @momentum * @running_var.data + (1 - @momentum) * batch_var
      @xn
    else
      centered = x - @running_mean.data
      centered / Xumo::NMath.sqrt(@running_var.data + @eps)
    end
  @gamma.data * normalized + @beta.data
end

#get_params ⇒ Object



77
78
79
# File 'lib/dnn/core/layers/normalizations.rb', line 77

# Returns every parameter object this layer owns, keyed by name.
# This includes the non-learnable running statistics, so they are saved
# and restored along with gamma and beta.
# @return [Hash{Symbol => Param}]
def get_params
  {
    gamma: @gamma,
    beta: @beta,
    running_mean: @running_mean,
    running_var: @running_var
  }
end

#load_hash(hash) ⇒ Object



73
74
75
# File 'lib/dnn/core/layers/normalizations.rb', line 73

# Re-initializes this layer from a hash produced by #to_hash.
#
# #to_hash serializes axis, momentum and eps, so eps must be restored too;
# otherwise a deserialized layer silently reverts to the default eps.
# Hashes written before eps was serialized have no :eps key. For those,
# the key is omitted so the constructor's own default applies.
#
# @param hash [Hash] serialized hyperparameters (:axis, :momentum, optional :eps).
def load_hash(hash)
  kwargs = { axis: hash[:axis], momentum: hash[:momentum] }
  kwargs[:eps] = hash[:eps] if hash.key?(:eps)
  initialize(**kwargs)
end

#to_hash ⇒ Object



69
70
71
# File 'lib/dnn/core/layers/normalizations.rb', line 69

# Serializes this layer's hyperparameters (axis, momentum, eps) by merging
# them into the hash built by the superclass to_hash.
# @return [Hash]
def to_hash
  hyperparams = { axis: @axis, momentum: @momentum, eps: @eps }
  super(**hyperparams)
end