Class: DNN::Layers::Pool2D

Inherits:
Layer
  • Object
show all
Includes:
Conv2DModule
Defined in:
lib/dnn/core/cnn_layers.rb

Overview

Superclass of all 2D pooling layer classes.

Direct Known Subclasses

AvgPool2D, MaxPool2D

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Methods inherited from Layer

#backward, #built?, #forward, #prev_layer

Constructor Details

#initialize(pool_size, strides: nil, padding: false) ⇒ Pool2D

Returns a new instance of Pool2D.



156
157
158
159
160
161
162
163
164
165
# File 'lib/dnn/core/cnn_layers.rb', line 156

# Sets up the pooling window, the strides and the padding flag.
#
# @param pool_size [Integer, Object] an Integer n is expanded to [n, n];
#   any other value is stored as given.
# @param strides [Integer, Object, nil] an Integer n is expanded to [n, n];
#   a falsy value falls back to a copy of the pool size; any other value is
#   stored as given.
# @param padding [Object] stored as-is; tested for truthiness in #build.
def initialize(pool_size, strides: nil, padding: false)
  super()
  @pool_size = pool_size
  @pool_size = [pool_size, pool_size] if pool_size.is_a?(Integer)
  @strides = strides.is_a?(Integer) ? [strides, strides] : strides
  # No strides supplied: step by a copy of the window size.
  @strides ||= @pool_size.clone
  @padding = padding
end

Instance Attribute Details

#pool_size ⇒ Object (readonly)

Returns the value of attribute pool_size.



149
150
151
# File 'lib/dnn/core/cnn_layers.rb', line 149

# Pooling window size as stored by the constructor: an Integer argument n is
# kept as [n, n]; any other argument is kept as given.
#
# @return [Object] the value of @pool_size
def pool_size
  @pool_size
end

#strides ⇒ Object (readonly)

Returns the value of attribute strides.



150
151
152
# File 'lib/dnn/core/cnn_layers.rb', line 150

# Strides as stored by the constructor: an Integer argument n is kept as
# [n, n]; when no strides were given this is a copy of the pool size.
#
# @return [Object] the value of @strides
def strides
  @strides
end

Class Method Details

.load_hash(pool2d_class, hash) ⇒ Object



152
153
154
# File 'lib/dnn/core/cnn_layers.rb', line 152

# Instantiates +pool2d_class+ from a hash of saved settings.
#
# @param pool2d_class [Class] class to instantiate; it is called with one
#   positional argument and the keywords +strides:+ and +padding:+.
# @param hash [Hash] read at the keys :pool_size, :strides and :padding.
# @return [Object] the new instance
def self.load_hash(pool2d_class, hash)
  pool_size = hash[:pool_size]
  options = { strides: hash[:strides], padding: hash[:padding] }
  pool2d_class.new(pool_size, **options)
end

Instance Method Details

#build(model) ⇒ Object



167
168
169
170
171
172
173
174
175
176
177
# File 'lib/dnn/core/cnn_layers.rb', line 167

# Derives the layer's output geometry from the previous layer's shape.
#
# Reads the first two entries of prev_layer.shape as (height, width) and the
# third as the channel count, then asks out_size for the pooled size. When
# @padding is truthy, @pad holds the per-axis difference between the input
# and the pooled size, and the reported output size equals the input size.
#
# The height/width pair is destructured in the same order in which it is
# passed to out_size and stored in @out_size, so the spatial axes are not
# transposed for non-square inputs.
# NOTE(review): assumes out_size returns [out_h, out_w], mirroring its
# (height, width) argument order - confirm against its definition.
#
# @param model [Object] forwarded unchanged to the superclass build.
# @return [Array, nil] the padded output size, or nil when padding is off
def build(model)
  super
  prev_h, prev_w = prev_layer.shape[0..1]
  @num_channel = prev_layer.shape[2]
  @out_size = out_size(prev_h, prev_w, *@pool_size, @strides)
  return unless @padding

  out_h, out_w = @out_size
  @pad = [prev_h - out_h, prev_w - out_w]
  @out_size = [prev_h, prev_w]
end

#shape ⇒ Object



179
180
181
# File 'lib/dnn/core/cnn_layers.rb', line 179

# Output shape of this layer: the entries of @out_size followed by
# @num_channel.
#
# @return [Array] a new array on every call
def shape
  dims = [*@out_size]
  dims << @num_channel
end

#to_hash ⇒ Object



183
184
185
186
187
# File 'lib/dnn/core/cnn_layers.rb', line 183

# Serializes the constructor settings of this layer.
#
# Passes a hash with the keys :pool_size, :strides and :padding to the
# superclass implementation and returns whatever that returns.
#
# @return [Object] result of the superclass to_hash
def to_hash
  settings = { pool_size: @pool_size, strides: @strides, padding: @padding }
  super(settings)
end