Class: DNN::Layers::LSTM_Dense
- Inherits: Object (ancestry: Object → DNN::Layers::LSTM_Dense)
- Defined in: lib/dnn/core/rnn_layers.rb
Instance Method Summary
- #backward(dh2, dcell2) ⇒ Object
- #forward(x, h, cell) ⇒ Object
- #initialize(params, grads) ⇒ LSTM_Dense (constructor): Returns a new instance of LSTM_Dense.
Constructor Details
#initialize(params, grads) ⇒ LSTM_Dense
Returns a new instance of LSTM_Dense.
Source (lib/dnn/core/rnn_layers.rb, lines 171–179):
# File 'lib/dnn/core/rnn_layers.rb', line 171
# Builds one LSTM time step that works against the caller's parameter and
# gradient tables (the same objects are kept, not copied).
#
# @param params hash-like table; #forward/#backward read :weight, :weight2, :bias
# @param grads hash-like table; #backward adds into :weight, :weight2, :bias
def initialize(params, grads)
  @params = params
  @grads = grads
  # One activation object per gate. #backward calls each one with only the
  # upstream gradient, so each presumably caches state from its own forward
  # call — hence separate instances rather than a shared one.
  @forget_sigmoid = Sigmoid.new
  @in_sigmoid = Sigmoid.new
  @out_sigmoid = Sigmoid.new
  @g_tanh = Tanh.new
  @tanh = Tanh.new
end
Instance Method Details
#backward(dh2, dcell2) ⇒ Object
Source (lib/dnn/core/rnn_layers.rb, lines 199–217):
# File 'lib/dnn/core/rnn_layers.rb', line 199
# Backpropagates through the step most recently run by #forward, using the
# values that #forward stored in instance variables.
#
# Parameter gradients are accumulated into @grads (added, not overwritten).
#
# @param dh2 gradient w.r.t. the hidden state returned by #forward
# @param dcell2 gradient w.r.t. the cell state returned by #forward
# @return [Array] [dx, dh, dcell] — gradients w.r.t. the x, h and cell
#   arguments of #forward
def backward(dh2, dcell2)
  # h2 = out * tanh(cell2): route dh2 to the out gate and into the cell.
  d_out_act = @tanh_cell2 * dh2
  d_cell = @tanh.backward(@out * dh2) + dcell2

  # cell2 = forget * cell + g * in
  d_forget = @forget_sigmoid.backward(d_cell * @cell)
  d_g = @g_tanh.backward(d_cell * @in)
  d_in = @in_sigmoid.backward(d_cell * @g)
  d_out = @out_sigmoid.backward(d_out_act)

  # Column order must match the slices taken in #forward: forget | g | in | out.
  d_pre = Xumo::SFloat.hstack([d_forget, d_g, d_in, d_out])

  @grads[:weight] += @x.transpose.dot(d_pre)
  @grads[:weight2] += @h.transpose.dot(d_pre)
  @grads[:bias] += d_pre.sum(0)

  [
    d_pre.dot(@params[:weight].transpose),
    d_pre.dot(@params[:weight2].transpose),
    d_cell * @forget
  ]
end
#forward(x, h, cell) ⇒ Object
Source (lib/dnn/core/rnn_layers.rb, lines 181–197):
# File 'lib/dnn/core/rnn_layers.rb', line 181
# Runs one LSTM step and stores the intermediates that #backward needs.
#
# The pre-activation matrix is split by columns in the order
# forget | g | in | out. The first three blocks are `units` columns wide and
# the out block takes every remaining column. #backward stacks its gate
# gradients in the same order.
#
# @param x input for this step (assumed batch-major, TODO confirm shape)
# @param h previous hidden state; h.shape[1] sets the unit count
# @param cell previous cell state
# @return [Array] [h2, cell2] — next hidden state and next cell state
def forward(x, h, cell)
  @x, @h, @cell = x, h, cell
  units = h.shape[1]
  pre = x.dot(@params[:weight]) + h.dot(@params[:weight2]) + @params[:bias]

  gate_cols = ->(idx) { pre[true, (units * idx)...(units * (idx + 1))] }
  @forget = @forget_sigmoid.forward(gate_cols.call(0))
  @g = @g_tanh.forward(gate_cols.call(1))
  @in = @in_sigmoid.forward(gate_cols.call(2))
  # Out gate: everything from column 3 * units to the end.
  @out = @out_sigmoid.forward(pre[true, (units * 3)..-1])

  next_cell = @forget * cell + @g * @in
  @tanh_cell2 = @tanh.forward(next_cell)
  [@out * @tanh_cell2, next_cell]
end