Class: DNN::Layers::Reshape

Inherits:
Layer
  • Object
show all
Defined in:
lib/dnn/core/layers.rb

Instance Attribute Summary

Attributes inherited from Layer

#input_shape

Class Method Summary

Instance Method Summary

Methods inherited from Layer

#build, #built?

Constructor Details

#initialize(output_shape) ⇒ Reshape

Returns a new instance of Reshape.



245
246
247
248
# File 'lib/dnn/core/layers.rb', line 245

# Creates a Reshape layer.
#
# @param output_shape [Object] target shape for every sample, excluding the
#   leading axis. It is splatted into +reshape+ by #forward, so it is presumably
#   an Array of Integers (TODO confirm against callers).
def initialize(output_shape)
  # Run Layer's initializer first, with no arguments; it does not receive output_shape.
  super()
  @output_shape = output_shape
end

Class Method Details

.load_hash(hash) ⇒ Object



241
242
243
# File 'lib/dnn/core/layers.rb', line 241

# Rebuilds a Reshape layer from a hash of its settings.
#
# This is the counterpart of #to_hash, which stores the target shape under the
# +:output_shape+ key.
#
# @param hash [Hash] serialized layer settings; only +:output_shape+ is read
# @return [Object] whatever +new+ returns for the stored output shape
def self.load_hash(hash)
  stored_shape = hash[:output_shape]
  new(stored_shape)
end

Instance Method Details

#backward(dout) ⇒ Object



254
255
256
# File 'lib/dnn/core/layers.rb', line 254

# Undoes the forward reshape for the incoming gradient.
#
# The size of the gradient's first axis is kept, presumably the batch axis.
# The remaining axes are folded back into the shape the layer received as
# input (+@input_shape+, an attribute inherited from Layer).
#
# @param dout [Object] upstream gradient; assumed to respond to +shape+ and
#   +reshape+ like a Numo::NArray (TODO confirm)
# @return [Object] the result of +dout.reshape+
def backward(dout)
  batch_size = dout.shape.first
  dout.reshape(batch_size, *@input_shape)
end

#forward(x) ⇒ Object



250
251
252
# File 'lib/dnn/core/layers.rb', line 250

# Reshapes each sample in +x+ to the configured output shape.
#
# The size of the first axis is preserved, presumably the batch axis. Every
# other axis is replaced by the entries of +@output_shape+.
#
# @param x [Object] input data; assumed to respond to +shape+ and +reshape+
#   like a Numo::NArray (TODO confirm)
# @return [Object] the result of +x.reshape+
def forward(x)
  batch_size = x.shape.first
  x.reshape(batch_size, *@output_shape)
end

#output_shape ⇒ Object



258
259
260
# File 'lib/dnn/core/layers.rb', line 258

# Target per-sample shape given to the constructor.
#
# This is the same value #forward splats after the leading axis.
#
# @return [Object] the stored +@output_shape+, unchanged
def output_shape
  @output_shape
end

#to_hash ⇒ Object



262
263
264
# File 'lib/dnn/core/layers.rb', line 262

# Serializes the layer's settings.
#
# Passes this layer's own field (+:output_shape+) to Layer#to_hash, which
# presumably merges it with the base layer's data (TODO confirm in Layer).
# .load_hash reads the same +:output_shape+ key back.
#
# @return [Object] whatever Layer#to_hash returns
def to_hash
  own_fields = { output_shape: @output_shape }
  super(own_fields)
end