Class: DNN::Layers::SimpleRNN

Inherits:
RNN show all
Defined in:
lib/dnn/core/rnn_layers.rb

Instance Attribute Summary

Attributes inherited from RNN

#h, #num_nodes, #stateful, #weight_decay

Attributes inherited from HasParamLayer

#grads, #params, #trainable

Class Method Summary collapse

Instance Method Summary collapse

Methods inherited from RNN

#backward, #forward, #ridge, #shape

Methods inherited from HasParamLayer

#build, #update

Methods inherited from Layer

#backward, #build, #built?, #forward, #prev_layer, #shape

Constructor Details

#initialize(num_nodes, stateful: false, return_sequences: true, activation: nil, weight_initializer: nil, bias_initializer: nil, weight_decay: 0) ⇒ SimpleRNN

Returns a new instance of SimpleRNN.



132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# File 'lib/dnn/core/rnn_layers.rb', line 132

# Creates a SimpleRNN layer.
#
# Every argument except +activation+ is handed unchanged to the parent
# constructor, which defines what those arguments mean.
#
# @param num_nodes forwarded to the superclass constructor
# @param stateful forwarded to the superclass constructor
# @param return_sequences forwarded to the superclass constructor
# @param activation object stored in @activation; a fresh Tanh instance is
#   used when this is nil (or false)
# @param weight_initializer forwarded to the superclass constructor
# @param bias_initializer forwarded to the superclass constructor
# @param weight_decay forwarded to the superclass constructor
def initialize(num_nodes,
               stateful: false,
               return_sequences: true,
               activation: nil,
               weight_initializer: nil,
               bias_initializer: nil,
               weight_decay: 0)
  # Options handled by the parent class, gathered once and splatted below.
  parent_options = {
    stateful: stateful,
    return_sequences: return_sequences,
    weight_initializer: weight_initializer,
    bias_initializer: bias_initializer,
    weight_decay: weight_decay
  }
  super(num_nodes, **parent_options)

  # Fall back to Tanh only after the parent has finished initializing.
  @activation = activation || Tanh.new
end

Class Method Details

.load_hash(hash) ⇒ Object



122
123
124
125
126
127
128
129
130
# File 'lib/dnn/core/rnn_layers.rb', line 122

# Builds a new instance from a hash of saved settings.
#
# The :activation, :weight_initializer and :bias_initializer entries are each
# passed through Util.load_hash first; the remaining entries are handed to
# the constructor as they are.
#
# @param hash hash-like object read via the keys :num_nodes, :stateful,
#   :return_sequences, :activation, :weight_initializer, :bias_initializer
#   and :weight_decay
# @return the object returned by +new+
def self.load_hash(hash)
  # Rebuild the nested objects first, in the same order the constructor
  # lists them.
  activation = Util.load_hash(hash[:activation])
  weight_initializer = Util.load_hash(hash[:weight_initializer])
  bias_initializer = Util.load_hash(hash[:bias_initializer])

  new(hash[:num_nodes],
      stateful: hash[:stateful],
      return_sequences: hash[:return_sequences],
      activation: activation,
      weight_initializer: weight_initializer,
      bias_initializer: bias_initializer,
      weight_decay: hash[:weight_decay])
end

Instance Method Details

#to_hash ⇒ Object



148
149
150
# File 'lib/dnn/core/rnn_layers.rb', line 148

# Serializes this layer by delegating to the parent's to_hash, passing along
# a hash that carries the activation's own to_hash output under :activation.
#
# @return whatever the superclass implementation returns
def to_hash
  layer_specific = { activation: @activation.to_hash }
  super(layer_specific)
end