Class: NN::BatchNorm

Inherits:
Object
Includes:
Numo
Defined in:
lib/nn.rb

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(nn, index) ⇒ BatchNorm

Returns a new instance of BatchNorm.



# File 'lib/nn.rb', lines 412-415

# Binds this layer to its owning network and to its slot in that network's
# per-layer parameter lists (the layer reads @nn.gammas[@index] and
# @nn.betas[@index]).
#
# @param nn [Object] owning network; must respond to #gammas and #betas
# @param index [Integer] position of this layer's gamma/beta in those lists
def initialize(nn, index)
  @nn, @index = nn, index
end

Instance Attribute Details

#d_beta ⇒ Object (readonly)

Returns the value of attribute d_beta.



# File 'lib/nn.rb', lines 410-412

# Beta gradient stored by #backward; nil until #backward has run.
def d_beta; @d_beta; end

#d_gamma ⇒ Object (readonly)

Returns the value of attribute d_gamma.



# File 'lib/nn.rb', lines 409-411

# Gamma gradient stored by #backward; nil until #backward has run.
def d_gamma; @d_gamma; end

Instance Method Details

#backward(dout) ⇒ Object



# File 'lib/nn.rb', lines 428-439

# Back-propagates +dout+ through the batch-normalization computed by #forward.
#
# Stores the parameter gradients in @d_beta and @d_gamma and returns the
# gradient with respect to the input that #forward received. Requires #forward
# to have run first: it reads the cached @x, @xc, @std and @xn.
#
# The per-column sums of dout and xn*dout are averaged across columns with
# .mean, giving one value each. NOTE(review): this matches gamma/beta being a
# single scalar per layer -- confirm @nn.gammas[i] / @nn.betas[i] are scalars.
#
# @param dout [Numo::NArray] upstream gradient (presumably the same shape as
#   the forward output -- confirm against callers)
# @return [Numo::NArray] gradient w.r.t. the forward input, reshaped to its shape
def backward(dout)
  # N must be the number of rows #forward averaged over (axis 0 of @x).
  # @nn.batch_size was used here before, which gives a wrong gradient whenever
  # #forward was called with a different number of rows.
  n = @x.shape[0]
  @d_beta = dout.sum(0).mean
  @d_gamma = (@xn * dout).sum(0).mean
  dxn = @nn.gammas[@index] * dout
  dxc = dxn / @std
  dstd = -((dxn * @xc) / (@std ** 2)).sum(0)
  dvar = 0.5 * dstd / @std
  # Variance term: var = mean(xc**2), so d var / d xc = 2 * xc / N.
  dxc += (2.0 / n) * @xc * dvar
  # Mean term: each row contributed 1/N to the column mean that was subtracted.
  dmean = dxc.sum(0)
  dx = dxc - dmean / n
  dx.reshape(*@x.shape)
end

#forward(x) ⇒ Object



# File 'lib/nn.rb', lines 417-426

# Normalizes +x+ with statistics taken over axis 0, then applies this layer's
# scale (gamma) and shift (beta) from @nn.gammas / @nn.betas at @index.
#
# Axis 0 is presumably the batch axis of a (batch, features) matrix -- TODO
# confirm against callers. A constant 1e-7 is added to the variance before the
# square root, so a zero-variance column does not cause division by zero.
#
# Side effect: caches @x, @mean, @xc, @var, @std and @xn; #backward reads
# @x, @xc, @std and @xn.
#
# @param x [Numo::NArray] input activations (type assumed from NMath usage)
# @return [Numo::NArray] gamma * normalized(x) + beta, reshaped to x's shape
def forward(x)
  @x = x
  @mean = x.mean(0)
  @xc = x - @mean
  @var = (@xc ** 2).mean(0)
  @std = NMath.sqrt(@var + 1e-7)
  @xn = @xc / @std
  scaled = @nn.gammas[@index] * @xn
  shifted = scaled + @nn.betas[@index]
  shifted.reshape(*@x.shape)
end