Class: Secryst::PositionalEncoding
- Inherits: Torch::NN::Module
  - Object
    - Torch::NN::Module
      - Secryst::PositionalEncoding
- Defined in: lib/secryst/transformer.rb
Instance Method Summary
- #forward(x) ⇒ Object
- #initialize(d_model, dropout: 0.1, max_len: 5000) ⇒ PositionalEncoding (constructor)
  PositionalEncoding module injects information about the relative or absolute position of the tokens in the sequence.
Constructor Details
#initialize(d_model, dropout: 0.1, max_len: 5000) ⇒ PositionalEncoding
The PositionalEncoding module injects information about the relative or absolute position of the tokens in the sequence. The positional encodings have the same dimension as the embeddings so that the two can be summed. Here, sine and cosine functions of different frequencies are used.
# File 'lib/secryst/transformer.rb', line 347

def initialize(d_model, dropout: 0.1, max_len: 5000)
  super()
  @dropout = Torch::NN::Dropout.new(p: dropout)

  pe = Torch.zeros(max_len, d_model)
  position = Torch.arange(0, max_len, dtype: :float).unsqueeze(1)
  div_term = Torch.exp(Torch.arange(0, d_model, 2).float() * (-Math.log(10000.0) / d_model))
  sin = Torch.sin(position * div_term).t
  cos = Torch.cos(position * div_term).t
  pe.t!
  pe.each.with_index do |row, i|
    pe[i] = sin[i / 2] if i % 2 == 0
    pe[i] = cos[(i - 1) / 2] if i % 2 != 0
  end
  pe.t!
  pe = pe.unsqueeze(0).transpose(0, 1)
  register_buffer('pe', pe)
end
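The table built above follows the standard sinusoidal scheme: PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). As a minimal sketch of those values in pure Ruby, without the torch-rb dependency (the `positional_encoding` helper below is hypothetical and is not part of Secryst):

```ruby
# Pure-Ruby sketch of the sinusoidal table built in #initialize.
# positional_encoding is a hypothetical helper, shown only to
# illustrate the formula; Secryst builds the same values with Torch.
def positional_encoding(max_len, d_model)
  Array.new(max_len) do |pos|
    Array.new(d_model) do |i|
      # Even dimensions use sin, odd dimensions use cos, and each
      # sin/cos pair (2i, 2i+1) shares the same frequency.
      angle = pos / (10000.0**((i - i % 2).to_f / d_model))
      i.even? ? Math.sin(angle) : Math.cos(angle)
    end
  end
end

pe = positional_encoding(4, 6)
# Row 0 (pos = 0) is alternating sin(0) = 0.0 and cos(0) = 1.0:
# pe[0] == [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

Because the frequencies fall off geometrically across dimensions, nearby positions get similar encodings in the low-index dimensions while distant positions are still distinguishable.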
Instance Method Details
#forward(x) ⇒ Object
# File 'lib/secryst/transformer.rb', line 366

def forward(x)
  x = x + pe.narrow(0, 0, x.size(0))
  return x
end
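Here `pe.narrow(0, 0, x.size(0))` slices the precomputed table down to the input's sequence length before the elementwise add; rows of `pe` beyond the sequence length are simply ignored. A sketch of that slice-and-sum step using plain nested arrays in place of tensors (the `add_positional` name is illustrative only, not Secryst API):

```ruby
# Sketch of forward's x + pe.narrow(0, 0, x.size(0)) with plain
# arrays: add the first seq_len rows of the encoding table to the
# input, elementwise. add_positional is a hypothetical helper.
def add_positional(x, pe)
  x.each_with_index.map do |row, pos|
    row.each_with_index.map { |v, i| v + pe[pos][i] }
  end
end

x  = [[0.1, 0.2], [0.3, 0.4]]             # seq_len = 2, d_model = 2
pe = [[0.0, 1.0], [0.5, 0.5], [0.9, 0.1]] # max_len = 3 table; only 2 rows used
add_positional(x, pe)
# => [[0.1, 1.2], [0.8, 0.9]] (up to float rounding)
```

This is why `max_len` only needs to be an upper bound on sequence length: the same buffer serves any input whose first dimension is at most `max_len`.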