Class: DNN::Layers::GlobalAvgPool2D

Inherits:
Layer
  • Object
Defined in:
lib/dnn/core/layers/cnn_layers.rb

Instance Attribute Summary

Attributes inherited from Layer

#input_shape, #output_shape

Instance Method Summary

Methods inherited from Layer

#<<, #built?, call, #call, #clean, #compute_output_shape, from_hash, #initialize, #load_hash, #to_hash

Constructor Details

This class inherits a constructor from DNN::Layers::Layer

Instance Method Details

#build(input_shape) ⇒ Object



# File 'lib/dnn/core/layers/cnn_layers.rb', line 423

def build(input_shape)
  unless input_shape.length == 3
    raise DNNShapeError, "Input shape is #{input_shape}. But input shape must be 3 dimensional."
  end
  super
end
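
The shape check in #build can be tried out without loading the library. The following is a minimal sketch, assuming only what the listing above shows: `ShapeError` and `check_input_shape` are hypothetical stand-ins for `DNNShapeError` and the guard inside `build`, and the call to `super` is omitted.

```ruby
# Hypothetical stand-in for the library's DNNShapeError, so that this
# sketch runs without ruby-dnn installed.
class ShapeError < StandardError; end

# Mirrors the guard in #build: only 3-element shapes are accepted.
def check_input_shape(input_shape)
  unless input_shape.length == 3
    raise ShapeError, "Input shape is #{input_shape}. But input shape must be 3 dimensional."
  end
  input_shape
end

# A 3-element shape passes through unchanged.
accepted = check_input_shape([28, 28, 3])

# A flattened, 1-element shape is rejected.
rejected = begin
  check_input_shape([784])
  false
rescue ShapeError
  true
end
```

The shape excludes the batch dimension, which is why an image input is described by three values rather than four.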

#forward(x) ⇒ Object



# File 'lib/dnn/core/layers/cnn_layers.rb', line 430

def forward(x)
  Flatten.(AvgPool2D.(x, @input_shape[0..1]))
end
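
Because #forward uses the first two entries of `@input_shape` as the pooling window, the window covers the whole spatial extent and each channel collapses to a single average. The sketch below illustrates that computation on plain nested Ruby arrays. It is not the library's implementation: `global_avg_pool_2d` is a hypothetical helper, and the `[batch][height][width][channel]` layout is an assumption inferred from the listing above.

```ruby
# Hypothetical helper, not part of ruby-dnn. Averages every channel over
# all spatial positions. Input layout is assumed to be
# [batch][height][width][channel]; the result is [batch][channel].
def global_avg_pool_2d(batch)
  batch.map do |image|
    height = image.length
    width = image.first.length
    channels = image.first.first.length

    # Accumulate a per-channel sum across every (row, column) position.
    sums = Array.new(channels, 0.0)
    image.each do |row|
      row.each do |pixel|
        pixel.each_with_index { |value, c| sums[c] += value }
      end
    end

    # Divide by the number of spatial positions to get the mean.
    sums.map { |s| s / (height * width) }
  end
end

# One 2x2 image with 2 channels.
x = [[[[1.0, 10.0], [2.0, 20.0]],
      [[3.0, 30.0], [4.0, 40.0]]]]

result = global_avg_pool_2d(x)
# result == [[2.5, 25.0]]
```

The output has one row per batch item and one value per channel, which matches the flattened result that #forward produces by passing the pooled output through Flatten.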