Class: DNN::Layers::BatchNormalization

Inherits:
HasParamLayer
Defined in:
lib/dnn/core/layers.rb

Instance Attribute Summary

Attributes inherited from HasParamLayer

#params, #trainable

Class Method Summary

Instance Method Summary

Methods inherited from HasParamLayer

#build, #update

Methods inherited from Layer

#build, #built?, #prev_layer, #shape

Constructor Details

#initialize(momentum: 0.9) ⇒ BatchNormalization

Returns a new instance of BatchNormalization.



# File 'lib/dnn/core/layers.rb', lines 339–342

# Creates a batch-normalization layer.
#
# @param momentum [Float] weight kept from the previous running mean/variance
#   each time they are updated during a training-mode #forward (default 0.9).
def initialize(momentum: 0.9)
  # Empty parens are deliberate: a bare `super` would forward `momentum:`
  # to the parent constructor.
  super()
  @momentum = momentum
end

Instance Attribute Details

#momentum ⇒ Object (readonly)

Returns the value of attribute momentum.



# File 'lib/dnn/core/layers.rb', lines 333–335

def momentum
  @momentum
end

Class Method Details

.load_hash(hash) ⇒ Object



# File 'lib/dnn/core/layers.rb', lines 335–337

# Rebuilds a BatchNormalization layer from a serialized config hash
# (presumably the output of #to_hash — confirm against the model loader).
#
# @param hash [Hash] layer config; its +:momentum+ entry is used when present
# @return [BatchNormalization] a new, unbuilt layer
def self.load_hash(hash)
  # A missing key would otherwise become momentum: nil, leaving @momentum nil
  # and making #forward raise later; defer to the constructor's default.
  return new unless hash.key?(:momentum)

  new(momentum: hash[:momentum])
end

Instance Method Details

#backward(dout) ⇒ Object



# File 'lib/dnn/core/layers.rb', lines 361–372

# Back-propagates +dout+ (the gradient w.r.t. this layer's output) through the
# batch normalization computed by the most recent training-mode #forward.
#
# Side effects: writes the gradient of beta to @beta.grad and the gradient of
# gamma to @gamma.grad.
#
# Relies on @xc (centred input), @std (batch standard deviation) and @xn
# (normalized input) cached by #forward. Those are only assigned in the
# training branch, so calling this after an inference-mode forward uses stale
# or missing values.
#
# @param dout gradient of the loss w.r.t. the output; the batch is along
#   axis 0 (batch_size is dout.shape[0] and all reductions use axis 0)
# @return gradient of the loss w.r.t. the layer input x
def backward(dout)
  batch_size = dout.shape[0]
  # out = gamma * xn + beta
  @beta.grad = dout.sum(0)
  @gamma.grad = (@xn * dout).sum(0)
  dxn = @gamma.data * dout
  # xn = xc / std: the direct path from xn back into xc ...
  dxc = dxn / @std
  # ... plus the path through std, since d(xc / std)/d(std) = -xc / std**2.
  dstd = -((dxn * @xc) / (@std**2)).sum(0)
  # std = sqrt(var + eps)  =>  d(std)/d(var) = 0.5 / std
  dvar = 0.5 * dstd / @std
  # var = mean(xc**2)  =>  d(var)/d(xc) = 2 * xc / batch_size
  dxc += (2.0 / batch_size) * @xc * dvar
  # xc = x - mean(x): remove the batch mean of dxc from every sample.
  dmean = dxc.sum(0)
  dxc - dmean / batch_size
end

#forward(x) ⇒ Object



# File 'lib/dnn/core/layers.rb', lines 344–359

# Normalizes +x+ per feature, then applies the learned scale (gamma) and
# shift (beta).
#
# Training mode: statistics come from the current batch (reduced over
# axis 0). The centred input, standard deviation and normalized input are
# cached in @xc, @std and @xn for #backward. The running mean/variance are
# updated as an exponential moving average weighted by @momentum.
#
# Otherwise: the stored running mean/variance are used and nothing is cached.
#
# @param x input batch (assumes samples along axis 0 — TODO confirm)
# @return gamma * normalized_x + beta
def forward(x)
  unless @model.training?
    # Inference: normalize with the accumulated running statistics.
    centered = x - @params[:running_mean]
    normalized = centered / Xumo::NMath.sqrt(@params[:running_var] + 1e-7)
    return @gamma.data * normalized + @beta.data
  end

  batch_mean = x.mean(0)
  @xc = x - batch_mean
  batch_var = (@xc**2).mean(0)
  # The small epsilon keeps the divisor below away from zero.
  @std = Xumo::NMath.sqrt(batch_var + 1e-7)
  @xn = @xc / @std
  # Keep @momentum of the old running value, blend in the rest from this batch.
  @params[:running_mean] = @momentum * @params[:running_mean] + (1 - @momentum) * batch_mean
  @params[:running_var] = @momentum * @params[:running_var] + (1 - @momentum) * batch_var
  @gamma.data * @xn + @beta.data
end

#to_hash ⇒ Object



# File 'lib/dnn/core/layers.rb', lines 374–376

# Serializes this layer's configuration.
#
# Passes +{ momentum: @momentum }+ as a positional Hash to the superclass's
# #to_hash, presumably to be merged into the base serialized form. Confirm
# this in HasParamLayer/Layer.
#
# @return whatever the superclass #to_hash returns for that extra hash
def to_hash
  extra = { momentum: @momentum }
  super(extra)
end