Class: DNN::Layers::TrainableLayer

Inherits:
Layer < Object
Defined in:
lib/dnn/core/layers/basic_layers.rb

Overview

This class is the superclass of all layers that have learnable parameters.

Direct Known Subclasses

BatchNormalization, Connection, Embedding
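Because the hierarchy separates layers that own parameters from those that do not, code that needs every learnable parameter can filter on TrainableLayer. The sketch below uses hypothetical stand-ins (Layer, TrainableLayer, ToyDense, ToyReLU, collect_params), not the ruby-dnn classes. It assumes get_params returns a Hash of parameter name to parameter, which matches how #clean destructures the result as (key, param) pairs.

```ruby
# Hypothetical stand-ins (not the ruby-dnn classes) mirroring the hierarchy:
# Layer <- TrainableLayer <- concrete layers with parameters.
class Layer
end

class TrainableLayer < Layer
  attr_accessor :trainable

  def initialize
    super()
    @trainable = true
  end

  def get_params
    raise NotImplementedError, "Class '#{self.class.name}' has to implement method 'get_params'"
  end
end

# Hypothetical parameterized layer owning a weight and a bias.
class ToyDense < TrainableLayer
  def initialize
    super()
    @weight = [0.5, -0.5]
    @bias = [0.0]
  end

  def get_params
    { weight: @weight, bias: @bias }
  end
end

# Hypothetical parameter-free layer: derives from Layer directly.
class ToyReLU < Layer
end

# Only TrainableLayer instances contribute parameters.
def collect_params(layers)
  layers.grep(TrainableLayer).flat_map { |layer| layer.get_params.values }
end

layers = [ToyDense.new, ToyReLU.new, ToyDense.new]
p collect_params(layers).size  # => 4
```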

Instance Attribute Summary

Attributes inherited from Layer

#input_shape, #output_shape

Instance Method Summary

Methods inherited from Layer

#<<, #build, #built?, #call, .call, #compute_output_shape, #forward, .from_hash, #load_hash, #to_hash

Constructor Details

#initialize ⇒ TrainableLayer

Returns a new instance of TrainableLayer.



# File 'lib/dnn/core/layers/basic_layers.rb', line 114

def initialize
  super()
  @trainable = true
end
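Because the constructor is where the flag gets its default, a subclass that defines its own initialize needs to call super(). The empty parentheses matter when the subclass constructor takes arguments: a bare super would forward them to a constructor that accepts none and raise ArgumentError. The classes below are hypothetical stand-ins, with the base reduced to the constructor and attribute documented on this page.

```ruby
# Hypothetical stand-in for the base class, reduced to its constructor and
# the trainable attribute.
class TrainableLayer
  attr_accessor :trainable

  def initialize
    @trainable = true  # every trainable layer starts out trainable
  end
end

# A subclass that calls super() inherits the default flag.
class GoodLayer < TrainableLayer
  def initialize(units)
    super()  # empty parens: do not forward `units` to the base constructor
    @units = units
  end
end

# A subclass that skips super() never runs the base constructor.
class ForgetfulLayer < TrainableLayer
  def initialize(units)
    @units = units
  end
end

p GoodLayer.new(8).trainable       # => true
p ForgetfulLayer.new(8).trainable  # => nil
```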

Instance Attribute Details

#trainable ⇒ Boolean

Returns whether the layer's parameters are learned. Setting it to false prevents learning of the parameters.

Returns:

  • (Boolean)

    Setting it to false prevents the layer's parameters from being learned.



# File 'lib/dnn/core/layers/basic_layers.rb', line 112

def trainable
  @trainable
end
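The description above says that setting the flag to false prevents learning. One way a training loop could honor it is to skip frozen layers when applying updates. This is a sketch, not the library's optimizer: sgd_step, SketchParam, and FreezableLayer are hypothetical, and the sketch assumes a writable trainable attribute and a get_params that returns name/parameter pairs.

```ruby
# Hypothetical parameter holder with data and grad.
SketchParam = Struct.new(:data, :grad)

# Hypothetical layer exposing a trainable flag and get_params.
class FreezableLayer
  attr_accessor :trainable

  def initialize(w)
    @trainable = true
    @weight = SketchParam.new([w], [1.0])
  end

  def get_params
    { weight: @weight }
  end
end

# Plain SGD step (an assumption, not the library's optimizer): layers whose
# trainable flag is false are skipped, so their parameters stay unchanged.
def sgd_step(layers, lr)
  layers.each do |layer|
    next unless layer.trainable
    layer.get_params.each_value do |param|
      param.data = param.data.zip(param.grad).map { |d, g| d - lr * g }
    end
  end
end

frozen = FreezableLayer.new(1.0)
frozen.trainable = false
active = FreezableLayer.new(1.0)

sgd_step([frozen, active], 0.5)
p frozen.get_params[:weight].data  # => [1.0]
p active.get_params[:weight].data  # => [0.5]
```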

Instance Method Details

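#clean ⇒ Object

Resets the layer to a freshly built state while keeping its parameter objects. The method records the input shape, the layer's hash, and its parameters; clears every instance variable; reloads the layer from the hash and rebuilds it with the recorded input shape; and then reattaches the original parameter objects with their data set to nil and any existing gradient reset to Xumo::SFloat[0].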



# File 'lib/dnn/core/layers/basic_layers.rb', line 124

def clean
  input_shape = @input_shape
  hash = to_hash
  params = get_params
  instance_variables.each do |ivar|
    instance_variable_set(ivar, nil)
  end
  load_hash(hash)
  build(input_shape)
  params.each do |(key, param)|
    param.data = nil
    param.grad = Xumo::SFloat[0] if param.grad
    instance_variable_set("@#{key}", param)
  end
end
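To see what #clean leaves behind, the sketch below runs the same sequence on a hypothetical CleanableLayer that implements just enough of the interface (to_hash, load_hash, build, get_params). ToyParam stands in for the library's parameter class, and Xumo::SFloat[0] is replaced by the plain array [0.0] so the example runs without Numo.

```ruby
# Hypothetical parameter holder standing in for the library's parameter class.
ToyParam = Struct.new(:data, :grad)

# Hypothetical layer with just enough of the Layer interface to run #clean.
class CleanableLayer
  attr_reader :units

  def initialize(units)
    @units = units
  end

  def to_hash
    { units: @units }
  end

  # Re-run the constructor from the serialized settings.
  def load_hash(hash)
    initialize(hash[:units])
  end

  # Allocate fresh parameters for the given input shape.
  def build(input_shape)
    @input_shape = input_shape
    @weight = ToyParam.new(Array.new(input_shape[0] * @units, 0.1), nil)
    @bias = ToyParam.new(Array.new(@units, 0.0), nil)
  end

  def get_params
    { weight: @weight, bias: @bias }
  end

  # Same sequence as TrainableLayer#clean, with Xumo::SFloat[0] replaced by
  # the plain array [0.0].
  def clean
    input_shape = @input_shape
    hash = to_hash
    params = get_params
    instance_variables.each { |ivar| instance_variable_set(ivar, nil) }
    load_hash(hash)
    build(input_shape)
    params.each do |(key, param)|
      param.data = nil
      param.grad = [0.0] if param.grad
      instance_variable_set("@#{key}", param)
    end
  end
end

layer = CleanableLayer.new(2)
layer.build([3])
weight_before = layer.get_params[:weight]
weight_before.grad = Array.new(6, 1.0)  # pretend a backward pass filled it in

layer.clean
p layer.units                                      # => 2
p layer.get_params[:weight].equal?(weight_before)  # => true
p layer.get_params[:weight].data                   # => nil
p layer.get_params[:weight].grad                   # => [0.0]
p layer.get_params[:bias].grad                     # => nil
```

The parameters freshly allocated by build are discarded: the final loop reassigns @weight and @bias to the original objects, now emptied. A gradient that was never set stays nil.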

#get_params ⇒ Array

Returns the parameters of the layer.

Returns:

  • (Array)

    The parameters of the layer.

Raises:

  • (NotImplementedError)

    If a subclass does not override this method.


# File 'lib/dnn/core/layers/basic_layers.rb', line 120

def get_params
  raise NotImplementedError, "Class '#{self.class.name}' has to implement method 'get_params'"
end
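Subclasses satisfy this contract by overriding get_params. Since #clean destructures each element as (key, param), the sketch assumes a Hash of parameter name to parameter. NotImplementedError descends from ScriptError rather than StandardError, so a bare rescue clause does not catch it; it has to be named explicitly. All classes here are hypothetical stand-ins.

```ruby
# Hypothetical stand-in reduced to the abstract method documented above.
class TrainableLayer
  def get_params
    raise NotImplementedError, "Class '#{self.class.name}' has to implement method 'get_params'"
  end
end

# Hypothetical subclass that forgets to override get_params.
class IncompleteLayer < TrainableLayer
end

# Hypothetical subclass that fulfils the contract with name/parameter pairs,
# the shape #clean destructures with |(key, param)|.
class ScaleLayer < TrainableLayer
  def initialize
    @scale = [1.0]
  end

  def get_params
    { scale: @scale }
  end
end

p ScaleLayer.new.get_params.keys  # => [:scale]

begin
  IncompleteLayer.new.get_params
rescue NotImplementedError => e  # a bare `rescue` would let this through
  puts "caught #{e.class}"       # prints "caught NotImplementedError"
end
```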