Class: DNN::Layers::LSTM

Inherits:
RNN
Defined in:
lib/dnn/core/layers/rnn_layers.rb

Instance Attribute Summary

Attributes inherited from RNN

#hidden, #num_units, #recurrent_weight, #recurrent_weight_initializer, #recurrent_weight_regularizer, #return_sequences, #stateful

Attributes inherited from Connection

#bias, #bias_initializer, #bias_regularizer, #weight, #weight_initializer, #weight_regularizer

Attributes inherited from TrainableLayer

#trainable

Attributes inherited from Layer

#input_shape, #output_shape

Instance Method Summary

Methods inherited from RNN

#compute_output_shape, #load_hash, #regularizers, #to_hash

Methods included from LayerNode

#forward

Methods inherited from Connection

#regularizers, #to_hash, #use_bias

Methods inherited from TrainableLayer

#clean

Methods inherited from Layer

#<<, #built?, #call, call, #clean, #compute_output_shape, #forward, from_hash, #load_hash, #to_hash

Constructor Details

#initialize(num_units, stateful: false, return_sequences: true, weight_initializer: Initializers::RandomNormal.new, recurrent_weight_initializer: Initializers::RandomNormal.new, bias_initializer: Initializers::Zeros.new, weight_regularizer: nil, recurrent_weight_regularizer: nil, bias_regularizer: nil, use_bias: true) ⇒ LSTM

Returns a new instance of LSTM.



Source lines 292–304
# File 'lib/dnn/core/layers/rnn_layers.rb', line 292

# Creates an LSTM layer.
#
# Every argument is handed unchanged to the RNN superclass constructor; the
# only LSTM-specific state added here is a separate parameter holder for the
# cell state (alongside the hidden state managed by RNN).
#
# @param num_units [Integer] width of the hidden and cell state vectors
# @param stateful [Boolean] when true, forward_node resumes from the hidden
#   and cell state left by the previous call instead of starting from zeros
# @param return_sequences [Boolean] when true the layer outputs every time
#   step; otherwise only the final hidden state
# @param weight_initializer initializer for the input-to-hidden weight
# @param recurrent_weight_initializer initializer for the hidden-to-hidden weight
# @param bias_initializer initializer for the bias
# @param weight_regularizer optional regularizer for the input weight
# @param recurrent_weight_regularizer optional regularizer for the recurrent weight
# @param bias_regularizer optional regularizer for the bias
# @param use_bias [Boolean] whether the layer owns a bias parameter
def initialize(num_units,
               stateful: false,
               return_sequences: true,
               weight_initializer: Initializers::RandomNormal.new,
               recurrent_weight_initializer: Initializers::RandomNormal.new,
               bias_initializer: Initializers::Zeros.new,
               weight_regularizer: nil,
               recurrent_weight_regularizer: nil,
               bias_regularizer: nil,
               use_bias: true)
  super(num_units,
        stateful: stateful,
        return_sequences: return_sequences,
        weight_initializer: weight_initializer,
        recurrent_weight_initializer: recurrent_weight_initializer,
        bias_initializer: bias_initializer,
        weight_regularizer: weight_regularizer,
        recurrent_weight_regularizer: recurrent_weight_regularizer,
        bias_regularizer: bias_regularizer,
        use_bias: use_bias)
  # Cell-state holder: forward_node stores the last cell state in its data,
  # reset_state zero-fills it.
  @cell = Param.new
end

Instance Attribute Details

#cell ⇒ Object (readonly)

Returns the value of attribute cell.



Source lines 290–292
# File 'lib/dnn/core/layers/rnn_layers.rb', line 290

# Reader for the cell-state parameter holder.
#
# @return [Object] the value of @cell (a Param created in the constructor,
#   whose data holds the most recent cell state after a forward pass)
def cell; @cell; end

Instance Method Details

#backward_node(dh2s) ⇒ Object



Source lines 342–357
# File 'lib/dnn/core/layers/rnn_layers.rb', line 342

# Backpropagates through the unrolled LSTM cells, from the last time step to
# the first.
#
# @param dh2s [Xumo::SFloat] gradient of the loss w.r.t. this layer's output.
#   With return_sequences it is indexed per time step on axis 1. Without it,
#   it is the gradient of the final hidden state only. It is then widened to a
#   zero-filled sequence of @time_length steps whose last step carries it.
# @return [Xumo::SFloat] gradient w.r.t. the input sequence, shaped like the
#   input seen by forward_node (@xs_shape)
def backward_node(dh2s)
  step_grads =
    if @return_sequences
      dh2s
    else
      last_grad = dh2s
      Xumo::SFloat.zeros(last_grad.shape[0], @time_length, last_grad.shape[1]).tap do |expanded|
        expanded[true, -1, false] = last_grad
      end
    end
  dxs = Xumo::SFloat.zeros(@xs_shape)
  # Gradients flowing backwards from step t+1 into step t; scalar 0 before the last step.
  carry_h = 0
  carry_c = 0
  (0...step_grads.shape[1]).reverse_each do |step|
    # The hidden-state gradient at a step is its output gradient plus what the later step passed back.
    dx, carry_h, carry_c = @hidden_layers[step].backward(step_grads[true, step, false] + carry_h, carry_c)
    dxs[true, step, false] = dx
  end
  dxs
end

#build(input_shape) ⇒ Object



Source lines 306–313
# File 'lib/dnn/core/layers/rnn_layers.rb', line 306

# Allocates the weight, recurrent weight and (optionally) bias tensors, then
# initializes them.
#
# All three are @num_units * 4 wide along their last axis. This is presumably
# one @num_units-wide slice per LSTM gate; the exact slice layout is decided
# by LSTMCell, which is not shown here. TODO confirm.
#
# @param input_shape [Array<Integer>] per-sample input shape; index 1 is read
#   as the input feature width (index 0 is presumably the time length — confirm in RNN#build)
def build(input_shape)
  super
  in_features = input_shape[1]
  gate_width = @num_units * 4
  @weight.data = Xumo::SFloat.new(in_features, gate_width)
  @recurrent_weight.data = Xumo::SFloat.new(@num_units, gate_width)
  # The bias holder is absent when use_bias is false.
  @bias.data = Xumo::SFloat.new(gate_width) if @bias
  init_weight_and_bias
end

#create_hidden_layer ⇒ Object



Source lines 315–317
# File 'lib/dnn/core/layers/rnn_layers.rb', line 315

# Unrolls the layer into one LSTMCell per time step.
#
# Every cell shares the same weight, recurrent weight and bias objects, so
# all steps train a single set of parameters.
#
# @return [Array] the freshly built cells (also stored in @hidden_layers)
def create_hidden_layer
  shared_params = [@weight, @recurrent_weight, @bias]
  @hidden_layers = Array.new(@time_length) { LSTMCell.new(*shared_params) }
end

#forward_node(xs) ⇒ Object



Source lines 319–340
# File 'lib/dnn/core/layers/rnn_layers.rb', line 319

# Runs the input sequence through a freshly unrolled chain of LSTM cells.
#
# The indexing xs[true, t, false] implies a rank-3 input. Axis 1 is the time
# step; axis 0 is presumably the batch (confirm with callers).
#
# @param xs [Xumo::SFloat] input sequence
# @return [Xumo::SFloat] all hidden states (axis 1 = time, width @num_units)
#   when return_sequences is set, otherwise only the last hidden state
def forward_node(xs)
  create_hidden_layer
  @xs_shape = xs.shape
  batch_size = xs.shape[0]
  hs = Xumo::SFloat.zeros(batch_size, @time_length, @num_units)
  # Stateful layers resume from the last stored state; missing state falls back to zeros.
  h = @stateful ? @hidden.data : nil
  c = @stateful ? @cell.data : nil
  h ||= Xumo::SFloat.zeros(batch_size, @num_units)
  c ||= Xumo::SFloat.zeros(batch_size, @num_units)
  xs.shape[1].times do |t|
    step_cell = @hidden_layers[t]
    step_cell.trainable = @trainable
    h, c = step_cell.forward(xs[true, t, false], h, c)
    hs[true, t, false] = h
  end
  # Keep the final states so a stateful layer can continue from them next call.
  @hidden.data = h
  @cell.data = c
  @return_sequences ? hs : h
end

#get_params ⇒ Object



Source lines 364–366
# File 'lib/dnn/core/layers/rnn_layers.rb', line 364

# Collects this layer's parameter holders by name.
#
# @return [Hash{Symbol => Object}] keys :weight, :recurrent_weight, :bias,
#   :hidden and :cell, in that order. The :bias value is whatever @bias holds,
#   which may be nil.
def get_params
  params = { weight: @weight, recurrent_weight: @recurrent_weight, bias: @bias }
  params[:hidden] = @hidden
  params[:cell] = @cell
  params
end

#reset_state ⇒ Object



Source lines 359–362
# File 'lib/dnn/core/layers/rnn_layers.rb', line 359

# Clears the recurrent state.
#
# Delegates to the superclass reset first, then zero-fills the stored cell
# state in place when one exists.
#
# @return [Object, nil] the zero-filled cell data, or nil when no cell state
#   has been stored yet
def reset_state
  super()
  stored = @cell.data
  return unless stored
  @cell.data = stored.fill(0)
end