Method: Secryst::TransformerDecoderLayer#forward

Defined in:
lib/secryst/transformer.rb

#forward(tgt, memory, tgt_mask: nil, memory_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil) ⇒ Object

Pass the inputs (and masks) through the decoder layer. Args:

tgt: the sequence to the decoder layer (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).

Shape:

see the docs for the Transformer class.
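The layer applies three residual sub-layers in sequence: masked self-attention, encoder-decoder attention, then a feed-forward block, each followed by dropout, a residual add, and layer norm. A minimal plain-Ruby sketch of that residual pattern, using lambdas and numbers as stand-ins for the real attention/norm modules and tensors (none of these names are part of the Secryst API):

```ruby
# Sketch of the decoder layer's residual pattern. All callables here are
# illustrative stand-ins, not Secryst modules.

norm = ->(x) { x }                   # stands in for layer norm (no-op here)

# Each sub-layer has the same shape: out = norm(x + sublayer(x))
residual = lambda do |x, sublayer|
  norm.call(x + sublayer.call(x))
end

self_attn  = ->(x) { x * 2 }         # pretend masked self-attention
cross_attn = ->(x) { x + 3 }         # pretend encoder-decoder attention
ffn        = ->(x) { x - 1 }         # pretend feed-forward block

tgt = 1
tgt = residual.call(tgt, self_attn)  # 1 + 2 = 3
tgt = residual.call(tgt, cross_attn) # 3 + 6 = 9
tgt = residual.call(tgt, ffn)        # 9 + 8 = 17
puts tgt                             # => 17
```

In the real layer, `x + sublayer(x)` operates on tensors and the sublayer output passes through dropout before the residual add, but the composition order is the same.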


# File 'lib/secryst/transformer.rb', line 327

def forward(tgt, memory, tgt_mask: nil, memory_mask: nil,
          tgt_key_padding_mask: nil, memory_key_padding_mask: nil)

  # Sub-layer 1: masked self-attention over the target sequence,
  # followed by dropout, a residual connection, and layer norm.
  tgt2 = @self_attn.call(tgt, tgt, tgt, attn_mask: tgt_mask,
                        key_padding_mask: tgt_key_padding_mask)[0]
  tgt = tgt + @dropout1.call(tgt2)
  tgt = @norm1.call(tgt)

  # Sub-layer 2: encoder-decoder (cross) attention; queries come from
  # the target, keys and values from the encoder memory.
  tgt2 = @multihead_attn.call(tgt, memory, memory, attn_mask: memory_mask,
                             key_padding_mask: memory_key_padding_mask)[0]
  tgt = tgt + @dropout2.call(tgt2)
  tgt = @norm2.call(tgt)

  # Sub-layer 3: position-wise feed-forward network.
  tgt2 = @linear2.call(@dropout.call(@activation.call(@linear1.call(tgt))))
  tgt = tgt + @dropout3.call(tgt2)
  tgt = @norm3.call(tgt)
  return tgt
end
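`tgt_mask` is typically a causal ("subsequent") mask that prevents each target position from attending to later positions. A plain-Ruby sketch of the common additive-float convention, where attendable positions hold 0.0 and blocked positions hold -Infinity (Secryst's exact mask construction may differ; in practice this would be a Torch tensor):

```ruby
# Build a causal mask for a target sequence of length 4.
# Row i allows attention to positions j <= i only.
size = 4
mask = Array.new(size) do |i|
  Array.new(size) { |j| j <= i ? 0.0 : -Float::INFINITY }
end
mask.each { |row| p row }
```

Added to the attention logits before softmax, the -Infinity entries zero out the corresponding attention weights.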