Class: Secryst::TransformerDecoder
- Inherits: Torch::NN::Module
  (Secryst::TransformerDecoder < Torch::NN::Module < Object)
- Defined in: lib/secryst/transformer.rb
Instance Method Summary
- #forward(tgt, memory, tgt_mask: nil, memory_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil) ⇒ Object
  Pass the inputs (and masks) through each decoder layer in turn.
- #initialize(decoder_layers, norm = nil, d_model, vocab_size, dropout) ⇒ TransformerDecoder (constructor)
  TransformerDecoder is a stack of N decoder layers.
Constructor Details
#initialize(decoder_layers, norm = nil, d_model, vocab_size, dropout) ⇒ TransformerDecoder
TransformerDecoder is a stack of N decoder layers.
Args:
decoder_layers: an array of instances of the TransformerDecoderLayer class (required).
norm: the layer normalization component (optional).
d_model: the number of expected features in the encoder/decoder inputs.
vocab_size: the size of the vocabulary (the number of distinct possible tokens).
dropout: the dropout probability used by the positional encoding.
- Examples
>>> decoder_layers = 6.times.map { TransformerDecoderLayer.new(512, 8) }
>>> transformer_decoder = TransformerDecoder.new(decoder_layers, nil, 512, 72, 0.1)
>>> memory = Torch.rand(10, 32, 512)
>>> tgt = Torch.randint(0, 72, [20, 32])
>>> out = transformer_decoder.call(tgt, memory)
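Note that, unlike the PyTorch TransformerDecoder this class is modelled on, it embeds and positionally encodes its input itself: tgt must therefore be a tensor of integer token indices (here drawn from a vocabulary of size 72), not pre-embedded features, while memory is the already-encoded output of the encoder.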
# File 'lib/secryst/transformer.rb', line 234

def initialize(decoder_layers, norm=nil, d_model, vocab_size, dropout)
  super()
  @d_model = d_model
  # Register each layer under its own instance variable so that
  # Torch::NN::Module tracks it (and its parameters) as a submodule.
  decoder_layers.each.with_index do |l, i|
    instance_variable_set("@layer#{i}", l)
  end
  @layers = decoder_layers.length.times.map { |i| instance_variable_get("@layer#{i}") }
  @num_layers = decoder_layers.length
  # Token embedding and positional encoding applied to tgt in #forward.
  @embedding = Torch::NN::Embedding.new(vocab_size, d_model)
  @pos_encoder = PositionalEncoding.new(d_model, dropout: dropout)
  @norm = norm
end
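Registering each layer under a dedicated instance variable (@layer0, @layer1, …) is what lets torch.rb discover the submodules: Torch::NN::Module collects child modules and their parameters by walking instance variables, so layers held only inside a plain Ruby array would not be registered for training.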
Instance Method Details
#forward(tgt, memory, tgt_mask: nil, memory_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil) ⇒ Object
Pass the inputs (and masks) through each decoder layer in turn.
Args:
tgt: the sequence to the decoder (required).
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).
Shape:
see the docs in the Transformer class.
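For training, tgt_mask is typically a causal ("subsequent") mask so that each target position can only attend to earlier positions. A minimal sketch of building one by hand with torch.rb, mirroring PyTorch's generate_square_subsequent_mask; transformer_decoder, tgt, and memory are assumed from the constructor example above:

tgt_len = tgt.size(0)
# Lower-triangular matrix of ones: position i may attend to positions 0..i.
mask = Torch.triu(Torch.ones(tgt_len, tgt_len)).transpose(0, 1)
# Additive attention mask: -Infinity blocks a position, 0.0 allows it.
mask = mask.float.masked_fill(mask.eq(0), -Float::INFINITY).masked_fill(mask.eq(1), 0.0)
out = transformer_decoder.call(tgt, memory, tgt_mask: mask)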
# File 'lib/secryst/transformer.rb', line 257

def forward(tgt, memory, tgt_mask: nil, memory_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil)
  # Embed the target tokens, scale by sqrt(d_model), and add positional encoding.
  output = @embedding.call(tgt) * Math.sqrt(@d_model)
  output = @pos_encoder.call(output)

  # Run the result through each decoder layer in turn.
  @layers.each do |mod|
    output = mod.call(output, memory,
      tgt_mask: tgt_mask,
      memory_mask: memory_mask,
      tgt_key_padding_mask: tgt_key_padding_mask,
      memory_key_padding_mask: memory_key_padding_mask)
  end

  # Apply the final layer normalization, if one was provided.
  output = @norm.call(output) if @norm

  output
end