Class: DNN::Layers::LSTM_Dense
- Inherits: Object
  - Object
    - DNN::Layers::LSTM_Dense
- Defined in:
- lib/dnn/core/rnn_layers.rb
Instance Method Summary
- #backward(dh2, dc2) ⇒ Object
- #forward(x, h, c) ⇒ Object
- #initialize(rnn) ⇒ LSTM_Dense (constructor): returns a new instance of LSTM_Dense.
Constructor Details
#initialize(rnn) ⇒ LSTM_Dense
Returns a new instance of LSTM_Dense.
Source: lib/dnn/core/rnn_layers.rb, lines 213–220
# File 'lib/dnn/core/rnn_layers.rb', line 213
#
# Builds the single-timestep LSTM cell helper bound to its owning layer.
#
# @param rnn [Object] the owning recurrent layer; its weight, weight2, bias and
#   l1_lambda/l2_lambda are read and updated by #forward and #backward.
def initialize(rnn)
  @rnn = rnn
  # Each activation gets its own instance. #backward passes only the upstream
  # gradient to them, so each one presumably caches state from its #forward call.
  # Creation order matches the original: cell tanh, g-gate tanh, then the
  # forget/in/out sigmoids.
  @tanh, @g_tanh = Array.new(2) { Tanh.new }
  @forget_sigmoid, @in_sigmoid, @out_sigmoid = Array.new(3) { Sigmoid.new }
end
Instance Method Details
#backward(dh2, dc2) ⇒ Object
Source: lib/dnn/core/rnn_layers.rb, lines 240–265
# File 'lib/dnn/core/rnn_layers.rb', line 240
#
# Backpropagates one LSTM timestep through the state cached by #forward.
# It accumulates gradients into @rnn.weight, @rnn.weight2 and @rnn.bias.
#
# @param dh2 gradient with respect to the h2 output of #forward
# @param dc2 gradient with respect to the c2 output of #forward
# @return [Array] [dx, dh, dc]: gradients with respect to the x, h and c
#   inputs of #forward
def backward(dh2, dc2)
  # h2 = out * tanh(c2): split dh2 into its out-gate part and its c2 part.
  dh2_tmp = @tanh_c2 * dh2
  # Total gradient on c2: the path through h2 plus the gradient flowing in directly.
  dc2_tmp = @tanh.backward(@out * dh2) + dc2

  # c2 = forget * c + g * in
  dout = @out_sigmoid.backward(dh2_tmp)
  din = @in_sigmoid.backward(dc2_tmp * @g)
  dg = @g_tanh.backward(dc2_tmp * @in)
  dforget = @forget_sigmoid.backward(dc2_tmp * @c)

  # Same column order that #forward slices the pre-activation into.
  da = Xumo::SFloat.hstack([dforget, dg, din, dout])

  @rnn.weight.grad += @x.transpose.dot(da)
  @rnn.weight2.grad += @h.transpose.dot(da)

  # The original referenced bare dlasso/dlasso2/dridge/dridge2, which this class
  # does not define. That raised NameError whenever regularization was enabled.
  # The penalty gradients are now computed here from the layer's own weights.
  # NOTE(review): assumes an L1 penalty of l1_lambda * |w| (sign gradient, with
  # w == 0 treated as +1) and an L2 penalty of 0.5 * l2_lambda * w**2. Confirm
  # against the loss-side regularization term. This also runs once per timestep,
  # so also confirm that is intended.
  lasso_grad = lambda do |w|
    sign = Xumo::SFloat.ones(*w.shape)
    sign[w < 0] = -1
    @rnn.l1_lambda * sign
  end
  if @rnn.l1_lambda > 0
    @rnn.weight.grad += lasso_grad.call(@rnn.weight.data)
    @rnn.weight2.grad += lasso_grad.call(@rnn.weight2.data)
  elsif @rnn.l2_lambda > 0
    @rnn.weight.grad += @rnn.l2_lambda * @rnn.weight.data
    @rnn.weight2.grad += @rnn.l2_lambda * @rnn.weight2.data
  end

  @rnn.bias.grad += da.sum(0)
  dx = da.dot(@rnn.weight.data.transpose)
  dh = da.dot(@rnn.weight2.data.transpose)
  dc = dc2_tmp * @forget
  [dx, dh, dc]
end
#forward(x, h, c) ⇒ Object
Source: lib/dnn/core/rnn_layers.rb, lines 222–238
# File 'lib/dnn/core/rnn_layers.rb', line 222
#
# Runs one LSTM timestep and caches the inputs and gate activations for #backward.
#
# @param x input at this timestep
# @param h previous hidden state; its second dimension is the unit count
# @param c previous cell state
# @return [Array] [h2, c2]: the next hidden state and the next cell state
def forward(x, h, c)
  @x, @h, @c = x, h, c
  units = h.shape[1]
  gates = x.dot(@rnn.weight.data) + h.dot(@rnn.weight2.data) + @rnn.bias.data
  # Column window for gate i. The first three gates are units-wide and laid out
  # as forget, g, in. The out gate takes every remaining column.
  window = ->(i) { gates[true, (units * i)...(units * (i + 1))] }

  @forget = @forget_sigmoid.forward(window.call(0))
  @g = @g_tanh.forward(window.call(1))
  @in = @in_sigmoid.forward(window.call(2))
  @out = @out_sigmoid.forward(gates[true, (units * 3)..-1])

  next_cell = @forget * c + @g * @in
  @tanh_c2 = @tanh.forward(next_cell)
  [@out * @tanh_c2, next_cell]
end