Class: DNN::Layers::UnPool2D

Inherits:
Layer
  • Object
show all
Defined in:
lib/dnn/core/cnn_layers.rb

Instance Attribute Summary collapse

Attributes inherited from Layer

#input_shape

Class Method Summary collapse

Instance Method Summary collapse

Methods inherited from Layer

#built?

Constructor Details

#initialize(unpool_size) ⇒ UnPool2D

Returns a new instance of UnPool2D.



245
246
247
248
# File 'lib/dnn/core/cnn_layers.rb', line 245

# Creates the layer and stores the unpooling factors.
#
# @param unpool_size [Integer, Object] a single Integer is expanded to a
#   two-element array holding that value twice; any other value is stored
#   unchanged (presumably a [height, width] pair -- confirm with callers).
def initialize(unpool_size)
  super()
  @unpool_size =
    if unpool_size.is_a?(Integer)
      # One scalar factor applies to both entries of the pair.
      [unpool_size, unpool_size]
    else
      unpool_size
    end
end

Instance Attribute Details

#unpool_size ⇒ Object (readonly)

Returns the value of attribute unpool_size.



243
244
245
# File 'lib/dnn/core/cnn_layers.rb', line 243

# Reader for the unpooling factors stored by the constructor.
#
# @return [Object] the current value of @unpool_size
def unpool_size
  @unpool_size
end

Class Method Details

.load_hash(hash) ⇒ Object



250
251
252
# File 'lib/dnn/core/cnn_layers.rb', line 250

# Builds an UnPool2D from a hash by passing its :unpool_size entry to the
# constructor.
#
# @param hash [Hash] must respond to #[]; only the :unpool_size key is read
# @return [UnPool2D] a new layer instance
def self.load_hash(hash)
  size = hash[:unpool_size]
  UnPool2D.new(size)
end

Instance Method Details

#backward(dout) ⇒ Object



272
273
274
275
276
# File 'lib/dnn/core/cnn_layers.rb', line 272

# Backward pass: reshapes the incoming gradient into a six-axis view and
# returns a copy of the slice at index 0 of the two unpool axes -- the same
# position #forward writes its input into.
#
# Relies on @x_shape and @num_channel, which are set by #forward and #build.
#
# @param dout [Object] gradient array; must respond to #shape and #reshape
#   (presumably a Xumo array -- confirm)
# @return [Object] a clone of the selected slice
def backward(dout)
  pool_h, pool_w = @unpool_size
  leading = dout.shape[0]
  in_h = @x_shape[1]
  in_w = @x_shape[2]
  cells = dout.reshape(leading, in_h, pool_h, in_w, pool_w, @num_channel)
  # Index 0 on axes 2 and 4 picks the first element of every unpool cell.
  cells[true, true, 0, true, 0, true].clone
end

#build(input_shape) ⇒ Object



254
255
256
257
258
259
260
261
262
# File 'lib/dnn/core/cnn_layers.rb', line 254

# Derives the output size and channel count from the input shape.
#
# The first two entries of input_shape are each multiplied by the matching
# entry of @unpool_size; the third entry is kept as the channel count.
#
# @param input_shape [Array] indexable shape; entries 0..2 are read
# @return [Object] the value assigned to @num_channel
def build(input_shape)
  super
  in_h, in_w = input_shape[0..1]
  scale_h, scale_w = @unpool_size
  @out_size = [in_h * scale_h, in_w * scale_w]
  @num_channel = input_shape[2]
end

#forward(x) ⇒ Object



264
265
266
267
268
269
270
# File 'lib/dnn/core/cnn_layers.rb', line 264

# Forward pass: allocates a zero-filled six-axis Xumo::SFloat array, writes
# x into index 0 of the two unpool axes (every other position stays zero),
# and reshapes the result to four axes using @out_size.
#
# Records x.shape in @x_shape for use by #backward.
#
# @param x [Object] input array; must respond to #shape with at least four
#   entries (presumably a Xumo array -- confirm)
# @return [Object] the reshaped array
def forward(x)
  @x_shape = x.shape
  leading, in_h, in_w, depth = @x_shape
  scale_h, scale_w = @unpool_size
  expanded = Xumo::SFloat.zeros(leading, in_h, scale_h, in_w, scale_w, @num_channel)
  # Only the first slot of each unpool cell receives a value.
  expanded[true, true, 0, true, 0, true] = x
  expanded.reshape(leading, *@out_size, depth)
end

#output_shape ⇒ Object



278
279
280
# File 'lib/dnn/core/cnn_layers.rb', line 278

# Shape produced by this layer: the entries of @out_size followed by
# @num_channel, returned as a fresh array (the stored @out_size is not
# modified).
#
# @return [Array] output shape
def output_shape
  [*@out_size] << @num_channel
end

#to_hash ⇒ Object



282
283
284
# File 'lib/dnn/core/cnn_layers.rb', line 282

# Serializes this layer by handing its unpool size to the superclass
# implementation of to_hash.
#
# @return [Object] whatever the superclass to_hash returns (presumably a
#   Hash that includes :unpool_size -- confirm against Layer#to_hash)
def to_hash
  params = { unpool_size: @unpool_size }
  super(params)
end