Class: DNN::Layers::TrainableLayer
Defined in: lib/dnn/core/layers/basic_layers.rb
Overview
This class is the superclass of all layers that have learnable parameters.
Direct Known Subclasses
Instance Attribute Summary
- #trainable ⇒ Boolean
  Whether the layer's parameters are learned. Setting it to false prevents them from being updated during training.
Attributes inherited from Layer
Instance Method Summary
- #clean ⇒ Object
- #get_params ⇒ Array
  The parameters of the layer.
- #initialize ⇒ TrainableLayer (constructor)
  A new instance of TrainableLayer.
Methods inherited from Layer
#backward, #build, #built?, #call, call, #forward, from_hash, #load_hash, #output_shape, #to_hash
Constructor Details
#initialize ⇒ TrainableLayer
Returns a new instance of TrainableLayer.
# File 'lib/dnn/core/layers/basic_layers.rb', line 95

def initialize
  super()
  @trainable = true
end
Instance Attribute Details
#trainable ⇒ Boolean
Returns whether the layer's parameters are learned. Setting it to false prevents them from being updated during training.
# File 'lib/dnn/core/layers/basic_layers.rb', line 93

def trainable
  @trainable
end
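To illustrate what the flag is for, here is a hypothetical update loop that skips layers whose trainable is false. FakeLayer, Param, and sgd_step are made-up names for this sketch, not ruby-dnn API, and ruby-dnn's own optimizers may consult the flag differently.

```ruby
# Hypothetical stand-in for a parameter: a value plus its gradient.
Param = Struct.new(:data, :grad)

# Hypothetical layer exposing the two members documented on this page:
# a `trainable` accessor and `get_params` returning key/param pairs.
class FakeLayer
  attr_accessor :trainable

  def initialize(weight)
    @trainable = true
    @weight = Param.new(weight, 1.0) # gradient fixed at 1.0 for the demo
  end

  def get_params
    { weight: @weight }
  end
end

# Plain SGD step that only updates parameters of trainable layers.
def sgd_step(layers, lr)
  layers.select(&:trainable).each do |layer|
    layer.get_params.each_value do |param|
      param.data -= lr * param.grad
    end
  end
end

frozen = FakeLayer.new(2.0)
frozen.trainable = false
active = FakeLayer.new(2.0)
sgd_step([frozen, active], 0.5)
p frozen.get_params[:weight].data # prints 2.0 (skipped)
p active.get_params[:weight].data # prints 1.5 (2.0 - 0.5 * 1.0)
```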
Instance Method Details
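#clean ⇒ Object
Resets the layer while keeping its parameter objects. It saves the input shape, the configuration (to_hash), and the parameters (get_params); sets every instance variable to nil; restores the configuration with load_hash and rebuilds the layer with build; then reattaches each parameter under its original instance variable, with its data set to nil and its gradient, if present, replaced by Xumo::SFloat[0].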
# File 'lib/dnn/core/layers/basic_layers.rb', line 105

def clean
  input_shape = @input_shape
  hash = to_hash
  params = get_params
  instance_variables.each do |ivar|
    instance_variable_set(ivar, nil)
  end
  load_hash(hash)
  build(input_shape)
  params.each do |(key, param)|
    param.data = nil
    param.grad = Xumo::SFloat[0] if param.grad
    instance_variable_set("@#{key}", param)
  end
end
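To observe the effect of #clean without installing ruby-dnn or Numo, the sketch below replays the steps of the source above on a hypothetical SketchLayer. SketchLayer and Param are illustrative names, not ruby-dnn classes; having load_hash re-run the constructor is an assumption of the sketch, and a scalar 0 stands in for Xumo::SFloat[0].

```ruby
# Hypothetical stand-in for a parameter object (data plus gradient).
Param = Struct.new(:data, :grad)

# Hypothetical layer that follows the same steps as TrainableLayer#clean.
class SketchLayer
  attr_reader :units, :input_shape, :weight

  def initialize(units)
    @trainable = true
    @units = units
  end

  # Configuration only; the weight values are deliberately not included.
  def to_hash
    { units: @units }
  end

  # Restore the configuration by re-running the constructor (an assumption).
  def load_hash(hash)
    initialize(hash[:units])
  end

  # Create the weight only if it does not exist yet.
  def build(input_shape)
    @input_shape = input_shape
    @weight ||= Param.new(nil, nil)
  end

  def get_params
    { weight: @weight }
  end

  # Same sequence as the source above; 0 replaces Xumo::SFloat[0].
  def clean
    input_shape = @input_shape
    hash = to_hash
    params = get_params
    instance_variables.each { |ivar| instance_variable_set(ivar, nil) }
    load_hash(hash)
    build(input_shape)
    params.each do |(key, param)|
      param.data = nil
      param.grad = 0 if param.grad
      instance_variable_set("@#{key}", param)
    end
  end
end

layer = SketchLayer.new(3)
layer.build([4])
original = layer.weight
original.data = [0.1, 0.2]
original.grad = [0.5, 0.5]
layer.clean
p layer.weight.equal?(original)    # prints true: the same Param object is reattached
p layer.weight.data                # prints nil: the values are dropped
p layer.weight.grad                # prints 0: the gradient is reset
p [layer.units, layer.input_shape] # prints [3, [4]]
```

After cleaning, the layer keeps its configuration, its input shape, and its parameter objects, but no weight values.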
#get_params ⇒ Array
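Returns the parameters of the layer. Subclasses must override this method; the base implementation raises NotImplementedError. #clean iterates the result as (key, param) pairs and assigns each param back to the instance variable @key.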
# File 'lib/dnn/core/layers/basic_layers.rb', line 101

def get_params
  raise NotImplementedError, "Class '#{self.class.name}' has implement method 'get_params'"
end
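Because the base get_params only raises, every concrete subclass has to supply it. Judging from #clean, which restores each (key, param) pair into @key, a Hash from instance-variable names to parameter objects fits. The sketch below uses a stand-in TrainableLayer modeled on the source on this page, without the DNN::Layers::Layer superclass so it runs without ruby-dnn; ScaleLayer, MissingParamsLayer, and Param are hypothetical names.

```ruby
# Hypothetical stand-in for a parameter object (data plus gradient).
Param = Struct.new(:data, :grad)

# Stand-in modeled on the TrainableLayer source shown on this page.
class TrainableLayer
  attr_accessor :trainable

  def initialize
    @trainable = true
  end

  def get_params
    raise NotImplementedError, "Class '#{self.class.name}' has to implement method 'get_params'"
  end
end

# Hypothetical subclass with a single learnable scale factor.
class ScaleLayer < TrainableLayer
  def initialize
    super() # sets @trainable = true
    @scale = Param.new(1.0, nil)
  end

  # The key :scale names the instance variable @scale, the pairing
  # that TrainableLayer#clean relies on when it restores parameters.
  def get_params
    { scale: @scale }
  end
end

# Hypothetical subclass that forgets to override get_params.
class MissingParamsLayer < TrainableLayer; end

layer = ScaleLayer.new
p layer.trainable       # prints true
p layer.get_params.keys # prints [:scale]

begin
  MissingParamsLayer.new.get_params
rescue NotImplementedError => e
  puts e.message
end
```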