Class: DNN::Layers::TrainableLayer

Inherits:
Layer
  • Object
Defined in:
lib/dnn/core/layers/basic_layers.rb

Overview

This class is the superclass of all layer classes that have learnable parameters.

Direct Known Subclasses

BatchNormalization, Connection, Embedding

Instance Attribute Summary

Attributes inherited from Layer

#input_shape

Instance Method Summary

Methods inherited from Layer

#backward, #build, #built?, #call, call, #forward, from_hash, #load_hash, #output_shape, #to_hash

Constructor Details

#initialize ⇒ TrainableLayer

Returns a new instance of TrainableLayer.



# File 'lib/dnn/core/layers/basic_layers.rb', line 95

def initialize
  super()
  @trainable = true
end
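Because `@trainable` is set in `TrainableLayer#initialize`, a subclass that defines its own constructor has to call `super()` for the layer to start out trainable. The sketch below illustrates this with minimal stand-in classes so that it runs without ruby-dnn. The body of the stand-in `Layer#initialize` and the `Scale` subclass are hypothetical.

```ruby
# Minimal stand-ins for illustration only; these are NOT the ruby-dnn classes.
class Layer
  def initialize
    @input_shape = nil # hypothetical base-class state
  end
end

class TrainableLayer < Layer
  attr_accessor :trainable

  def initialize
    super()
    @trainable = true # every trainable layer starts out trainable
  end
end

# Hypothetical subclass: omitting `super()` here would leave @trainable nil.
class Scale < TrainableLayer
  def initialize(factor)
    super() # runs TrainableLayer#initialize, which sets @trainable = true
    @factor = factor
  end
end

layer = Scale.new(2.0)
puts layer.trainable # prints true
```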

Instance Attribute Details

#trainable ⇒ Boolean

Returns whether the layer's parameters are learned during training. Setting it to false prevents the parameters from being learned.

Returns:

  • (Boolean)

    true (the default set by #initialize) if the parameters are learned; setting it to false prevents them from being learned.



# File 'lib/dnn/core/layers/basic_layers.rb', line 93

def trainable
  @trainable
end
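The sketch below shows one way a training step could honour this flag. `ToyLayer`, the `Param` struct and `sgd_step` are hypothetical stand-ins, not ruby-dnn's optimizer API. Only the behaviour of skipping layers whose `trainable` is false comes from the attribute description above. The update rule is plain gradient descent, `data -= lr * grad`.

```ruby
# Hypothetical stand-ins: Param, ToyLayer and sgd_step are not ruby-dnn classes.
Param = Struct.new(:data, :grad)

class ToyLayer
  attr_accessor :trainable

  def initialize(weights)
    @trainable = true
    # Gradient fixed at 1.0 per element so the update is easy to follow.
    @weight = Param.new(weights, Array.new(weights.size, 1.0))
  end

  def get_params
    { weight: @weight }
  end
end

# Plain SGD (data -= lr * grad), applied only to layers that are trainable.
def sgd_step(layers, lr)
  layers.each do |layer|
    next unless layer.trainable
    layer.get_params.each_value do |param|
      param.data = param.data.zip(param.grad).map { |d, g| d - lr * g }
    end
  end
end

frozen = ToyLayer.new([1.0, 2.0])
frozen.trainable = false
active = ToyLayer.new([1.0, 2.0])

sgd_step([frozen, active], 0.5)
p frozen.get_params[:weight].data # prints [1.0, 2.0]
p active.get_params[:weight].data # prints [0.5, 1.5]
```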

Instance Method Details

#clean ⇒ Object



# File 'lib/dnn/core/layers/basic_layers.rb', line 105

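def clean
  # Keep the input shape, the serialized configuration and the current
  # parameter objects before wiping the layer.
  input_shape = @input_shape
  hash = to_hash
  params = get_params
  # Reset every instance variable to nil.
  instance_variables.each do |ivar|
    instance_variable_set(ivar, nil)
  end
  # Restore the configuration and rebuild for the saved input shape.
  load_hash(hash)
  build(input_shape)
  # Re-attach the original parameter objects with their data dropped and
  # any existing gradient reset to Xumo::SFloat[0].
  params.each do |(key, param)|
    param.data = nil
    param.grad = Xumo::SFloat[0] if param.grad
    instance_variable_set("@#{key}", param)
  end
end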
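The toy re-creation below runs without ruby-dnn or a numeric library and shows the observable effect of the steps in `#clean`. The configuration and input shape are restored, and the same parameter object is re-attached, but its data is dropped and its gradient is reset. `ToyDense`, its `to_hash`/`load_hash`/`build` methods and the `Param` struct are hypothetical. `Xumo::SFloat[0]` is replaced with the plain array `[0.0]`.

```ruby
# Hypothetical stand-ins; only the body of #clean mirrors the method above.
Param = Struct.new(:data, :grad)

class ToyDense
  attr_reader :units, :input_shape, :weight

  def initialize(units = nil)
    @units = units
  end

  def build(input_shape)
    @input_shape = input_shape
    # Fresh zero weights of size input * units.
    @weight = Param.new(Array.new(input_shape[0] * @units, 0.0), nil)
  end

  def to_hash
    { units: @units }
  end

  def load_hash(hash)
    initialize(hash[:units])
  end

  def get_params
    { weight: @weight }
  end

  # Same steps as TrainableLayer#clean, with Xumo::SFloat[0] replaced by [0.0].
  def clean
    input_shape = @input_shape
    hash = to_hash
    params = get_params
    instance_variables.each { |ivar| instance_variable_set(ivar, nil) }
    load_hash(hash)
    build(input_shape)
    params.each do |(key, param)|
      param.data = nil
      param.grad = [0.0] if param.grad
      instance_variable_set("@#{key}", param)
    end
  end
end

layer = ToyDense.new(3)
layer.build([2])
original = layer.weight
original.data = [0.1] * 6
original.grad = [0.5] * 6

layer.clean
p layer.units                   # prints 3
p layer.input_shape             # prints [2]
p layer.weight.equal?(original) # prints true (same Param object re-attached)
p layer.weight.data             # prints nil
p layer.weight.grad             # prints [0.0]
```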

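#get_params ⇒ Hash

Returns the parameters of the layer.

Returns:

  • (Hash)

    The parameters of the layer, keyed by name. #clean iterates over them as (key, param) pairs and reassigns each one to the instance variable @key.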

Raises:

  • (NotImplementedError) if the subclass does not override this method.


# File 'lib/dnn/core/layers/basic_layers.rb', line 101

def get_params
  raise NotImplementedError, "Class '#{self.class.name}' has not implemented method 'get_params'"
end
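Because `#clean` iterates over `get_params` as `(key, param)` pairs and assigns each value to `@key`, a subclass override that fits this contract returns a Hash. Each key in that Hash names the instance variable that holds the parameter. The sketch below uses a hypothetical stand-in base class and hypothetical subclasses (`ToyAffine`, `Forgetful`) so that it runs without ruby-dnn. It also shows that a subclass which does not override the method raises `NotImplementedError`.

```ruby
# Stand-in for the abstract base; not the real ruby-dnn class.
class TrainableLayer
  def get_params
    raise NotImplementedError, "Class '#{self.class.name}' has not implemented method 'get_params'"
  end
end

Param = Struct.new(:data, :grad)

# Hypothetical subclass: keys match the instance variables (@weight, @bias),
# which is what #clean relies on when it re-attaches parameters.
class ToyAffine < TrainableLayer
  def initialize
    @weight = Param.new([0.0, 0.0], nil)
    @bias = Param.new([0.0], nil)
  end

  def get_params
    { weight: @weight, bias: @bias }
  end
end

# Hypothetical subclass that forgets to override get_params.
class Forgetful < TrainableLayer; end

p ToyAffine.new.get_params.keys # prints [:weight, :bias]

begin
  Forgetful.new.get_params
rescue NotImplementedError => e
  # NotImplementedError is not a StandardError, so it must be rescued by name.
  puts e.message
end
```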