Class: Secryst::TransformerEncoder

Inherits:
Torch::NN::Module
  • Object
Defined in:
lib/secryst/transformer.rb

Instance Method Summary

#initialize(encoder_layers, norm = nil, d_model, vocab_size, dropout) ⇒ TransformerEncoder
#forward(src, mask: nil, src_key_padding_mask: nil) ⇒ Object

Constructor Details

#initialize(encoder_layers, norm = nil, d_model, vocab_size, dropout) ⇒ TransformerEncoder

TransformerEncoder is a stack of N encoder layers.

Args:

encoder_layers: an array of instances of the TransformerEncoderLayer class (required).
norm: the layer normalization component (optional).
d_model: the number of expected features in the encoder inputs (the embedding dimension).
vocab_size: size of the vocabulary (number of distinct possible tokens).
dropout: the dropout probability used by the positional encoding.
Examples

>>> encoder_layers = 6.times.map {|i| TransformerEncoderLayer.new(512, 8) }
>>> transformer_encoder = TransformerEncoder.new(encoder_layers, nil, 512, 72, 0.1)
>>> src = Torch.randint(0, 72, [10, 32])   # integer token indices, shape (S, N)
>>> out = transformer_encoder.call(src)    # => tensor of shape (10, 32, 512)
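
To attach a final layer normalization via the optional norm argument, a sketch — the Torch::NN::LayerNorm call and the 512 feature size (matching d_model above) are assumptions, not taken from the Secryst sources:

>>> norm = Torch::NN::LayerNorm.new(512)
>>> transformer_encoder = TransformerEncoder.new(encoder_layers, norm, 512, 72, 0.1)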



# File 'lib/secryst/transformer.rb', line 131

def initialize(encoder_layers, norm=nil, d_model, vocab_size, dropout)
  super()
  @d_model = d_model
  # Store each layer in its own instance variable so that Torch::NN::Module
  # picks it up as a child module and tracks its parameters.
  encoder_layers.each.with_index do |l, i|
    instance_variable_set("@layer#{i}", l)
  end
  @layers = encoder_layers.length.times.map {|i| instance_variable_get("@layer#{i}") }
  @num_layers = encoder_layers.length
  # Token embedding and positional encoding (with dropout) applied before the layer stack.
  @embedding = Torch::NN::Embedding.new(vocab_size, d_model)
  @pos_encoder = PositionalEncoding.new(d_model, dropout: dropout)
  @norm = norm
end
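
Because the layers, the embedding, and the positional encoder all live in instance variables, the whole stack is visible to the usual Torch::NN::Module introspection. A quick check — a sketch reusing the transformer_encoder built in the example above; Module#parameters is torch.rb's standard accessor:

>>> transformer_encoder.parameters.length   # embedding weights plus every layer's weights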

Instance Method Details

#forward(src, mask: nil, src_key_padding_mask: nil) ⇒ Object

Pass the input through the encoder layers in turn.

Args:

src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).

Shape:

see the docs in the Transformer class.
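
As an illustration (not from the library docs), reusing the transformer_encoder from the constructor example above and assuming the PyTorch-style conventions of (S, N) integer token indices for src and an (N, S) boolean src_key_padding_mask where true marks positions to ignore:

>>> src = Torch.randint(0, 72, [10, 32])              # (S, N) token indices
>>> padding_mask = Torch.zeros(32, 10, dtype: :bool)  # (N, S); all false, i.e. nothing masked
>>> out = transformer_encoder.call(src, src_key_padding_mask: padding_mask)
>>> out.shape   # => [10, 32, 512]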


# File 'lib/secryst/transformer.rb', line 151

def forward(src, mask: nil, src_key_padding_mask: nil)
  # Embed the token indices and scale by sqrt(d_model), then add positional encoding.
  output = @embedding.call(src) * Math.sqrt(@d_model)
  output = @pos_encoder.call(output)

  # Run the result through each encoder layer in turn.
  @layers.each { |mod|
    output = mod.call(output, src_mask: mask, src_key_padding_mask: src_key_padding_mask)
  }

  # Apply the final layer normalization, if one was given.
  if @norm
    output = @norm.call(output)
  end

  return output
end