Class: Secryst::TransformerDecoder

Inherits:
Torch::NN::Module
  • Object
Defined in:
lib/secryst/transformer.rb

Instance Method Summary

Constructor Details

#initialize(decoder_layers, norm = nil, d_model, vocab_size, dropout) ⇒ TransformerDecoder

TransformerDecoder is a stack of N decoder layers; the target sequence is embedded and positionally encoded before entering the layer stack. Args:

decoder_layers: an array of instances of the TransformerDecoderLayer class (required).
norm: the layer normalization component (optional).
d_model: the number of expected features in the encoder/decoder inputs (required).
vocab_size: the size of the vocabulary (the number of distinct tokens) (required).
dropout: the dropout probability used by the positional encoding (required).
Examples

>>> decoder_layers = 6.times.map { TransformerDecoderLayer.new(512, 8) }
>>> transformer_decoder = TransformerDecoder.new(decoder_layers, nil, 512, 72, 0.1)
>>> memory = Torch.rand(10, 32, 512)
>>> tgt = Torch.randint(0, 72, [20, 32])  # token indices, since the decoder embeds tgt itself
>>> out = transformer_decoder.call(tgt, memory)



# File 'lib/secryst/transformer.rb', line 234

def initialize(decoder_layers, norm=nil, d_model, vocab_size, dropout)
  super()
  @d_model = d_model
  # Assign each layer to its own instance variable so that
  # Torch::NN::Module registers it as a child module.
  decoder_layers.each.with_index do |l, i|
    instance_variable_set("@layer#{i}", l)
  end
  @layers = decoder_layers.length.times.map {|i| instance_variable_get("@layer#{i}") }
  @num_layers = decoder_layers.length
  @embedding = Torch::NN::Embedding.new(vocab_size, d_model)
  @pos_encoder = PositionalEncoding.new(d_model, dropout: dropout)
  @norm = norm
end
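
The per-layer instance variables above appear to be what lets torch.rb's Torch::NN::Module discover each layer as a child module (module tracking works through instance variables, so keeping the layers only in an array would hide their weights). A minimal sketch of checking this, assuming the standard Torch::NN::Module#parameters accessor from torch.rb:

>>> decoder_layers = 2.times.map { TransformerDecoderLayer.new(512, 8) }
>>> decoder = TransformerDecoder.new(decoder_layers, nil, 512, 72, 0.1)
>>> # Weights of both layers, the embedding, and the positional encoding
>>> # should all be visible to an optimizer:
>>> decoder.parameters.length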

Instance Method Details

#forward(tgt, memory, tgt_mask: nil, memory_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil) ⇒ Object

Pass the inputs (and masks) through the decoder layers in turn. Args:

tgt: the sequence of target token indices for the decoder (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).

Shape:

see the docs in the Transformer class.
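
For autoregressive decoding, tgt_mask is typically a square "subsequent" (causal) mask that blocks attention to future positions. A minimal sketch of building one by hand, assuming torch.rb exposes Torch.triu, Tensor#masked_fill, and Tensor#eq with their usual PyTorch semantics (if the Transformer class ships its own mask helper, prefer that):

>>> sz = 20  # target sequence length
>>> mask = Torch.triu(Torch.ones(sz, sz)).transpose(0, 1)
>>> tgt_mask = mask.float.masked_fill(mask.eq(0), -Float::INFINITY).masked_fill(mask.eq(1), 0.0)
>>> out = transformer_decoder.call(tgt, memory, tgt_mask: tgt_mask)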


# File 'lib/secryst/transformer.rb', line 257

def forward(tgt, memory, tgt_mask: nil,
          memory_mask: nil, tgt_key_padding_mask: nil,
          memory_key_padding_mask: nil)

  # Embed the target tokens, scale by sqrt(d_model), and add positional encoding
  output = @embedding.call(tgt) * Math.sqrt(@d_model)
  output = @pos_encoder.call(output)

  # Run the result through each decoder layer in turn
  @layers.each { |mod|
    output = mod.call(output, memory, tgt_mask: tgt_mask,
                   memory_mask: memory_mask,
                   tgt_key_padding_mask: tgt_key_padding_mask,
                   memory_key_padding_mask: memory_key_padding_mask)
  }

  # Apply the optional final layer normalization
  if @norm
    output = @norm.call(output)
  end

  return output
end
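
A hedged end-to-end sketch of calling forward: tgt holds integer token indices of shape (T, N) (the decoder embeds them itself), memory is the (S, N, E) encoder output, and a key-padding mask can be derived from a padding index. The padding index 0 and the use of Torch.randint here are illustrative assumptions, not fixed by the library:

>>> tgt = Torch.randint(0, 72, [20, 32])              # (T, N) token indices
>>> memory = Torch.rand(10, 32, 512)                  # (S, N, E) encoder output
>>> tgt_key_padding_mask = tgt.eq(0).transpose(0, 1)  # (N, T); true marks padding
>>> out = transformer_decoder.call(tgt, memory, tgt_key_padding_mask: tgt_key_padding_mask)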