Class: DNN::Layers::LSTMDense

Inherits: Object
Defined in:
lib/dnn/core/rnn_layers.rb

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(weight, recurrent_weight, bias) ⇒ LSTMDense

Returns a new instance of LSTMDense.



# File 'lib/dnn/core/rnn_layers.rb', lines 226–236

# Builds one LSTM cell step around shared parameter objects.
#
# @param weight [Object] input-to-gates parameter; #forward reads its +data+
#   and #backward accumulates into its +grad+
# @param recurrent_weight [Object] hidden-to-gates parameter, used the same
#   way for the previous hidden state
# @param bias [Object, nil] optional gate bias parameter; skipped when nil
def initialize(weight, recurrent_weight, bias)
  @weight, @recurrent_weight, @bias = weight, recurrent_weight, bias
  # A separate activation instance per nonlinearity: #forward and #backward
  # call the same instance for each gate (presumably so each can cache its
  # own forward output).
  @tanh = Layers::Tanh.new
  @g_tanh = Layers::Tanh.new
  @forget_sigmoid = Layers::Sigmoid.new
  @in_sigmoid = Layers::Sigmoid.new
  @out_sigmoid = Layers::Sigmoid.new
  # Gradients are accumulated in #backward only while this is truthy.
  @trainable = true
end

Instance Attribute Details

#trainable ⇒ Object

Returns the value of attribute trainable.



# File 'lib/dnn/core/rnn_layers.rb', lines 224–226

def trainable
  @trainable
end

Instance Method Details

#backward(dh2, dc2) ⇒ Object



# File 'lib/dnn/core/rnn_layers.rb', lines 257–277

# Back-propagates one LSTM time step using the values cached by #forward.
#
# When the cell is trainable, this also adds into weight.grad,
# recurrent_weight.grad and, if a bias is present, bias.grad.
#
# @param dh2 gradient w.r.t. the hidden state returned by #forward
# @param dc2 gradient w.r.t. the cell state returned by #forward
# @return [Array] [dx, dh, dc]: gradients w.r.t. the step's input x,
#   previous hidden state h and previous cell state c
def backward(dh2, dc2)
  # h2 = out * tanh(c2): route dh2 to the output gate and into the cell.
  d_out_act = @tanh_c2 * dh2
  d_cell = @tanh.backward(@out * dh2) + dc2

  # c2 = forget * c + g * in: each gate receives d_cell times the other
  # factor of its product term.
  d_out = @out_sigmoid.backward(d_out_act)
  d_in = @in_sigmoid.backward(d_cell * @g)
  d_g = @g_tanh.backward(d_cell * @in)
  d_forget = @forget_sigmoid.backward(d_cell * @c)

  # Re-pack in the [forget | g | in | out] column order #forward slices.
  d_gates = Xumo::SFloat.hstack([d_forget, d_g, d_in, d_out])

  if @trainable
    @weight.grad += @x.transpose.dot(d_gates)
    @recurrent_weight.grad += @h.transpose.dot(d_gates)
    @bias.grad += d_gates.sum(0) if @bias
  end

  d_x = d_gates.dot(@weight.data.transpose)
  d_h = d_gates.dot(@recurrent_weight.data.transpose)
  [d_x, d_h, d_cell * @forget]
end

#forward(x, h, c) ⇒ Object



# File 'lib/dnn/core/rnn_layers.rb', lines 238–255

# Runs one LSTM time step.
#
# The input and previous hidden state are projected and summed (plus the bias
# when present). The result is sliced column-wise into gate blocks of width
# h.shape[1] in the order forget, candidate g, input, output; the output
# block takes all remaining columns. Inputs and activations are cached on
# the instance for #backward.
#
# @param x input for this step (assumed 2-D, batch x features — TODO confirm)
# @param h previous hidden state; its second dimension is the hidden width
# @param c previous cell state
# @return [Array] [h2, c2]: the next hidden state and the next cell state
def forward(x, h, c)
  @x, @h, @c = x, h, c
  width = h.shape[1]
  pre = x.dot(@weight.data) + h.dot(@recurrent_weight.data)
  pre += @bias.data if @bias

  @forget = @forget_sigmoid.forward(pre[true, 0...width])
  @g = @g_tanh.forward(pre[true, width...(2 * width)])
  @in = @in_sigmoid.forward(pre[true, (2 * width)...(3 * width)])
  @out = @out_sigmoid.forward(pre[true, (3 * width)..-1])

  c_next = @forget * c + @g * @in
  @tanh_c2 = @tanh.forward(c_next)
  [@out * @tanh_c2, c_next]
end