Class: DNN::Layers::GRUCell
- Defined in:
- lib/dnn/core/layers/rnn_layers.rb
Instance Attribute Summary
Attributes inherited from RNNCell
Instance Method Summary collapse
- #backward(dh2) ⇒ Object
- #forward(x, h) ⇒ Object
- #initialize(weight, recurrent_weight, bias) ⇒ GRUCell (constructor): A new instance of GRUCell.
Constructor Details
#initialize(weight, recurrent_weight, bias) ⇒ GRUCell
Returns a new instance of GRUCell.
370 371 372 373 374 375 |
# File 'lib/dnn/core/layers/rnn_layers.rb', line 370

# Sets up a GRU cell: hands the three parameters to the parent initializer
# and creates the activation layers used by the cell (one sigmoid for each
# of the update and reset gates, plus a tanh).
#
# @param weight forwarded unchanged to the parent initializer
# @param recurrent_weight forwarded unchanged to the parent initializer
# @param bias forwarded unchanged to the parent initializer
def initialize(weight, recurrent_weight, bias)
  # Bare `super` passes weight, recurrent_weight and bias through as-is.
  super
  # Each gate gets its own Sigmoid instance.
  @update_sigmoid, @reset_sigmoid = Array.new(2) { Layers::Sigmoid.new }
  @tanh = Layers::Tanh.new
end
Instance Method Details
#backward(dh2) ⇒ Object
400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 |
# File 'lib/dnn/core/layers/rnn_layers.rb', line 400

# Backward pass for one GRU step. Uses the state cached by the preceding
# #forward call (@x, @h, @update, @reset, @tanh_h and the weight slices).
#
# When @trainable is set, parameter gradients are accumulated (+=) into
# @weight.grad, @recurrent_weight.grad and, if a bias exists, @bias.grad.
#
# @param dh2 upstream gradient for the value returned by #forward
# @return [Array] [dx, dh], the gradients for the x and h passed to #forward
def backward(dh2)
  # forward computed: h2 = (1 - update) * tanh_h + update * h
  dtanh_h = @tanh.backward_node(dh2 * (1 - @update))
  dh = dh2 * @update

  if @trainable
    dweight_h = @x.transpose.dot(dtanh_h)
    dweight2_h = (@h * @reset).transpose.dot(dtanh_h)
    dbias_h = dtanh_h.sum(0) if @bias
  end
  dx = dtanh_h.dot(@weight_h.transpose)

  # Gradient flowing into (h * reset). Computed once and shared by the
  # dh and dreset terms instead of repeating the matrix product.
  dh_reset = dtanh_h.dot(@weight2_h.transpose)
  dh += dh_reset * @reset
  dreset = @reset_sigmoid.backward_node(dh_reset * @h)
  dupdate = @update_sigmoid.backward_node(dh2 * @h - dh2 * @tanh_h)

  # Column order mirrors the forward slicing: update gate first, then reset.
  da = Xumo::SFloat.hstack([dupdate, dreset])
  if @trainable
    dweight_a = @x.transpose.dot(da)
    dweight2_a = @h.transpose.dot(da)
    dbias_a = da.sum(0) if @bias
  end
  dx += da.dot(@weight_a.transpose)
  dh += da.dot(@weight2_a.transpose)

  if @trainable
    # Gate columns come first, candidate columns last, matching #forward.
    @weight.grad += Xumo::SFloat.hstack([dweight_a, dweight_h])
    @recurrent_weight.grad += Xumo::SFloat.hstack([dweight2_a, dweight2_h])
    @bias.grad += Xumo::SFloat.hstack([dbias_a, dbias_h]) if @bias
  end
  [dx, dh]
end
#forward(x, h) ⇒ Object
377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 |
# File 'lib/dnn/core/layers/rnn_layers.rb', line 377

# Forward pass for one GRU step. Stores the inputs, the gate activations and
# the weight slices in instance variables so that #backward can reuse them.
#
# The first 2 * num_units columns of the weights (and entries of the bias)
# feed the update and reset gates; the remaining ones feed the tanh branch.
#
# @param x input, multiplied with slices of @weight.data
# @param h hidden-state input; h.shape[1] is taken as the number of units
# @return (1 - update) * tanh_h + update * h
def forward(x, h)
  @x = x
  @h = h
  num_units = h.shape[1]
  gate_cols = 0...(num_units * 2)
  cand_cols = (num_units * 2)..-1

  # Update and reset gates share one affine transform, then get split.
  @weight_a = @weight.data[true, gate_cols]
  @weight2_a = @recurrent_weight.data[true, gate_cols]
  gates = x.dot(@weight_a) + h.dot(@weight2_a)
  gates += @bias.data[gate_cols] if @bias
  @update = @update_sigmoid.forward_node(gates[true, 0...num_units])
  @reset = @reset_sigmoid.forward_node(gates[true, num_units..-1])

  # Tanh branch: the reset gate scales h before the recurrent product.
  @weight_h = @weight.data[true, cand_cols]
  @weight2_h = @recurrent_weight.data[true, cand_cols]
  pre_tanh = x.dot(@weight_h) + (h * @reset).dot(@weight2_h)
  pre_tanh += @bias.data[cand_cols] if @bias
  @tanh_h = @tanh.forward_node(pre_tanh)

  (1 - @update) * @tanh_h + @update * h
end