Class: DNN::Layers::RNN
- Inherits:
- HasParamLayer
- Inheritance chain: Object < Layer < HasParamLayer < DNN::Layers::RNN
- Includes:
- Activations, Initializers
- Defined in:
- lib/dnn/core/rnn_layers.rb
Overview
Superclass of all RNN layer classes.
Instance Attribute Summary collapse
-
#h ⇒ Object
Returns the value of attribute h.
-
#num_nodes ⇒ Object
readonly
Returns the value of attribute num_nodes.
-
#stateful ⇒ Object
readonly
Returns the value of attribute stateful.
-
#weight_decay ⇒ Object
readonly
Returns the value of attribute weight_decay.
Attributes inherited from HasParamLayer
Instance Method Summary collapse
- #backward(dh2s) ⇒ Object
- #forward(xs) ⇒ Object
- #init_params ⇒ Object
-
#initialize(num_nodes, stateful: false, return_sequences: true, weight_initializer: nil, bias_initializer: nil, weight_decay: 0) ⇒ RNN
constructor
A new instance of RNN.
- #ridge ⇒ Object
- #shape ⇒ Object
- #to_hash(merge_hash = nil) ⇒ Object
Methods inherited from HasParamLayer
Methods inherited from Layer
Constructor Details
#initialize(num_nodes, stateful: false, return_sequences: true, weight_initializer: nil, bias_initializer: nil, weight_decay: 0) ⇒ RNN
Returns a new instance of RNN.
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
# File 'lib/dnn/core/rnn_layers.rb', line 14
# Builds an RNN layer and stores its configuration.
#
# @param num_nodes [Object] stored as +@num_nodes+; used as the trailing
#   dimension of the hidden state and of the reported output shape.
# @param stateful [Boolean] stored as +@stateful+; when truthy, #forward
#   resumes from the previously stored hidden state if there is one.
# @param return_sequences [Boolean] stored as +@return_sequences+; selects
#   whether the full sequence or only the last hidden state is returned.
# @param weight_initializer [Object, nil] falls back to +RandomNormal.new+
#   when nil/false.
# @param bias_initializer [Object, nil] falls back to +Zeros.new+ when
#   nil/false.
# @param weight_decay [Numeric] coefficient read by #ridge.
def initialize(num_nodes,
               stateful: false,
               return_sequences: true,
               weight_initializer: nil,
               bias_initializer: nil,
               weight_decay: 0)
  super()

  # Plain configuration values.
  @num_nodes = num_nodes
  @stateful = stateful
  @return_sequences = return_sequences
  @weight_decay = weight_decay

  # Default initializers are only instantiated when none was supplied.
  @weight_initializer = weight_initializer || RandomNormal.new
  @bias_initializer = bias_initializer || Zeros.new

  # Runtime state: per-time-step layers and the carried hidden state.
  @layers = []
  @h = nil
end
Instance Attribute Details
#h ⇒ Object
Returns the value of attribute h.
9 10 11 |
# File 'lib/dnn/core/rnn_layers.rb', line 9
# Hidden state most recently stored in +@h+ (nil until one is stored).
#
# @return [Object, nil] the value of +@h+
def h
  @h
end
#num_nodes ⇒ Object (readonly)
Returns the value of attribute num_nodes.
10 11 12 |
# File 'lib/dnn/core/rnn_layers.rb', line 10
# Read-only access to the node count given to the constructor.
#
# @return [Object] the value of +@num_nodes+
def num_nodes
  @num_nodes
end
#stateful ⇒ Object (readonly)
Returns the value of attribute stateful.
11 12 13 |
# File 'lib/dnn/core/rnn_layers.rb', line 11
# Read-only access to the +stateful+ flag given to the constructor.
#
# @return [Object] the value of +@stateful+
def stateful
  @stateful
end
#weight_decay ⇒ Object (readonly)
Returns the value of attribute weight_decay.
12 13 14 |
# File 'lib/dnn/core/rnn_layers.rb', line 12
# Read-only access to the weight-decay coefficient.
#
# @return [Object] the value of +@weight_decay+
def weight_decay
  @weight_decay
end
Instance Method Details
#backward(dh2s) ⇒ Object
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
# File 'lib/dnn/core/rnn_layers.rb', line 44
# Backward pass: walks the time steps from last to first, feeding each
# per-step layer the incoming gradient for that step plus the hidden-state
# gradient handed back by the following step.
#
# @param dh2s [Object] upstream gradient. When +@return_sequences+ is
#   truthy it is indexed along axis 1 by time step; otherwise it is treated
#   as the gradient of the last step only and is expanded below.
#   (assumes axis 0 is the batch axis — TODO confirm against callers)
# @return [Object] +dxs+, an array shaped like the +xs+ last seen by
#   #forward (+@xs_shape+), filled step by step with each layer's +dx+.
def backward(dh2s)
  # Reset the parameter gradients to zero arrays of the parameters' shapes.
  # NOTE(review): presumably the per-step layers accumulate into these —
  # that code is not visible here.
  %i[weight weight2 bias].each do |key|
    @grads[key] = Xumo::SFloat.zeros(*@params[key].shape)
  end

  unless @return_sequences
    # Only the final hidden state was returned by #forward, so place its
    # gradient at the last time step of an otherwise-zero sequence.
    last_dh = dh2s
    dh2s = Xumo::SFloat.zeros(last_dh.shape[0], @time_length, last_dh.shape[1])
    dh2s[true, -1, false] = last_dh
  end

  dxs = Xumo::SFloat.zeros(@xs_shape)
  dh = 0 # no gradient flows in from beyond the last time step
  # Count down directly instead of materialising and reversing an index array.
  (dh2s.shape[1] - 1).downto(0) do |t|
    dx, dh = @layers[t].backward(dh2s[true, t, false] + dh)
    dxs[true, t, false] = dx
  end
  dxs
end
#forward(xs) ⇒ Object
31 32 33 34 35 36 37 38 39 40 41 42 |
# File 'lib/dnn/core/rnn_layers.rb', line 31
# Forward pass: feeds each slice of +xs+ along axis 1 through the matching
# per-step layer, threading the hidden state from one step to the next.
#
# @param xs [Object] input indexed as +xs[true, t, false]+ for step +t+;
#   its shape is remembered in +@xs_shape+ for #backward.
#   (assumes axis 0 is the batch axis — TODO confirm against callers)
# @return [Object] every hidden state (+hs+) when +@return_sequences+ is
#   truthy, otherwise only the final hidden state.
def forward(xs)
  @xs_shape = xs.shape
  batch_size = xs.shape[0]
  hs = Xumo::SFloat.zeros(batch_size, @time_length, @num_nodes)

  # A stateful layer continues from the stored state when one exists;
  # otherwise the recurrence starts from zeros.
  h = if @stateful && @h
        @h
      else
        Xumo::SFloat.zeros(batch_size, @num_nodes)
      end

  (0...xs.shape[1]).each do |t|
    h = @layers[t].forward(xs[true, t, false], h)
    hs[true, t, false] = h
  end

  # Keep the final state so a later call (or a reader of #h) can use it.
  @h = h
  return hs if @return_sequences

  h
end
#init_params ⇒ Object
89 90 91 |
# File 'lib/dnn/core/rnn_layers.rb', line 89
# Records the sequence length, taken from the first entry of the previous
# layer's shape, in +@time_length+.
#
# @return [Object] the stored time length
def init_params
  input_shape = prev_layer.shape
  @time_length = input_shape[0]
end
#ridge ⇒ Object
81 82 83 84 85 86 87 |
# File 'lib/dnn/core/rnn_layers.rb', line 81
# Weight-decay penalty: half the decay coefficient times the summed squares
# of the +:weight+ and +:weight2+ parameters. The +:bias+ parameter is not
# included.
#
# @return [Numeric] the penalty, or the Integer 0 when +@weight_decay+ is
#   not positive
def ridge
  return 0 unless @weight_decay > 0

  squared_sum = (@params[:weight]**2).sum + (@params[:weight2]**2).sum
  0.5 * (@weight_decay * squared_sum)
end
#shape ⇒ Object
77 78 79 |
# File 'lib/dnn/core/rnn_layers.rb', line 77
# Output shape of the layer: +[@time_length, @num_nodes]+ when the full
# sequence is returned, otherwise just +[@num_nodes]+.
#
# @return [Array] the shape as a new array
def shape
  return [@num_nodes] unless @return_sequences

  [@time_length, @num_nodes]
end
#to_hash(merge_hash = nil) ⇒ Object
63 64 65 66 67 68 69 70 71 72 73 74 75 |
# File 'lib/dnn/core/rnn_layers.rb', line 63
# Serialises the layer configuration into a Hash.
#
# @param merge_hash [Hash, nil] extra entries merged over the base
#   configuration; on a key clash the entry from +merge_hash+ wins.
# @return [Hash] a fresh hash with the class name, node count, flags,
#   both initializers' own +to_hash+ output, and the weight decay.
def to_hash(merge_hash = nil)
  config = {
    class: self.class.name,
    num_nodes: @num_nodes,
    stateful: @stateful,
    return_sequences: @return_sequences,
    weight_initializer: @weight_initializer.to_hash,
    bias_initializer: @bias_initializer.to_hash,
    weight_decay: @weight_decay,
  }
  merge_hash ? config.merge(merge_hash) : config
end