Class: DNN::Layers::AvgPool2D

Inherits:
Pool2D (which inherits from Layer)
Defined in:
lib/dnn/core/cnn_layers.rb

Instance Attribute Summary

Attributes inherited from Pool2D

#pool_size, #strides

Class Method Summary

.load_hash(hash) ⇒ Object

Instance Method Summary

#backward(dout) ⇒ Object
#forward(x) ⇒ Object

Methods inherited from Pool2D

#build, #initialize, #shape, #to_hash

Methods inherited from Layer

#build, #built?, #initialize, #prev_layer, #shape, #to_hash

Constructor Details

This class inherits its constructor from DNN::Layers::Pool2D.

Class Method Details

.load_hash(hash) ⇒ Object



(source lines 242–244)
# File 'lib/dnn/core/cnn_layers.rb', line 242

# Reconstructs a layer of this class from a serialized hash.
#
# All of the work is delegated to Pool2D.load_hash. The receiver (this
# class) is passed along as the class argument — presumably so the shared
# Pool2D loader instantiates an AvgPool2D rather than a generic Pool2D;
# NOTE(review): confirm against Pool2D.load_hash.
#
# @param hash serialized layer data (presumably as produced by #to_hash).
# @return whatever Pool2D.load_hash returns.
def self.load_hash(hash)
  layer_class = self
  Pool2D.load_hash(layer_class, hash)
end

Instance Method Details

#backward(dout) ⇒ Object



(source lines 251–260)
# File 'lib/dnn/core/cnn_layers.rb', line 251

# Backpropagates through average pooling.
#
# #forward averages each row of the column matrix (col.mean(1)), so every
# entry of a row contributed 1 / window_area to that row's output. The
# incoming gradient is therefore divided by the window area and copied into
# every one of the window_area columns. The result is reshaped to
# (shape[0] * shape[1] * shape[2], shape[3] * window_area) and handed to
# Pool2D#backward via super. NOTE(review): that method presumably scatters
# the columns back to the input layout — confirm the expected column order
# in Pool2D.
#
# @param dout gradient w.r.t. this layer's output. It is indexed here with
#   shape[0..2] and shape[3], so it is assumed to be 4-D with channels on
#   the last axis — TODO confirm.
# @return whatever Pool2D#backward returns for the column matrix.
def backward(dout)
  window_area = @pool_size.reduce(:*)
  # Each window element receives an equal share of the output gradient.
  grad_share = (dout / window_area).flatten
  spread = Xumo::SFloat.zeros(dout.size, window_area)
  (0...window_area).each { |col| spread[true, col] = grad_share }
  rows = dout.shape[0..2].reduce(:*)
  super(spread.reshape(rows, dout.shape[3] * window_area))
end

#forward(x) ⇒ Object



(source lines 246–249)
# File 'lib/dnn/core/cnn_layers.rb', line 246

# Average-pools the input.
#
# super (Pool2D#forward) returns a column matrix. Averaging along axis 1
# collapses each row to a single value; each row is presumably one pooling
# window for one channel — NOTE(review): confirm in Pool2D. The averages are
# then reshaped to (x.shape[0], *@out_size, x.shape[3]).
#
# @param x input array; its shape[0] and shape[3] are carried through to the
#   output shape (presumably batch and channel axes — TODO confirm).
# @return the averaged values reshaped as described above.
def forward(x)
  windows = super(x)
  averaged = windows.mean(1)
  averaged.reshape(x.shape[0], *@out_size, x.shape[3])
end