Class: Secryst::PositionalEncoding

Inherits: Torch::NN::Module < Object

Defined in: lib/secryst/transformer.rb

Instance Method Summary

Constructor Details

#initialize(d_model, dropout: 0.1, max_len: 5000) ⇒ PositionalEncoding

The PositionalEncoding module injects information about the relative or absolute position of the tokens in the sequence. The positional encodings have the same dimension as the embeddings so that the two can be summed. Here, sine and cosine functions of different frequencies are used.



# File 'lib/secryst/transformer.rb', line 347

def initialize(d_model, dropout: 0.1, max_len: 5000)
  super()
  @dropout = Torch::NN::Dropout.new(p: dropout)

  # Precompute the encoding table: one row per position, one column per dimension.
  pe = Torch.zeros(max_len, d_model)
  position = Torch.arange(0, max_len, dtype: :float).unsqueeze(1)
  # Frequencies 1 / 10000^(2i / d_model), one per even dimension index.
  div_term = Torch.exp(Torch.arange(0, d_model, 2).float * (-Math.log(10000.0) / d_model))
  sin = Torch.sin(position * div_term).t
  cos = Torch.cos(position * div_term).t
  # Interleave along the dimension axis: even dimensions get sine,
  # odd dimensions get cosine. Transpose so rows are dimensions.
  pe.t!
  pe.each.with_index do |row, i|
    pe[i] = sin[i / 2] if i % 2 == 0
    pe[i] = cos[(i - 1) / 2] if i % 2 != 0
  end
  pe.t!
  # Reshape to [max_len, 1, d_model] so it broadcasts over the batch dimension.
  pe = pe.unsqueeze(0).transpose(0, 1)
  # A buffer is saved with the model's state but excluded from training.
  register_buffer('pe', pe)
end
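The table built above can be sketched in plain Ruby, without torch-rb. This is an illustration of the sinusoidal formula only, not secryst's API; the method name `positional_encoding` is hypothetical:

```ruby
# Pure-Ruby sketch of the sinusoidal positional encoding table:
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
def positional_encoding(max_len, d_model)
  Array.new(max_len) do |pos|
    Array.new(d_model) do |i|
      # (i - i % 2) maps both members of a sin/cos pair to the same 2i.
      angle = pos / (10000.0**((i - i % 2).to_f / d_model))
      i.even? ? Math.sin(angle) : Math.cos(angle)
    end
  end
end

pe = positional_encoding(4, 6)
# Row 0 alternates sin(0) = 0.0 and cos(0) = 1.0.
```

Because neighbouring dimensions share a frequency, each sin/cos pair encodes a position as a point on a circle whose period grows geometrically with the dimension index.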

Instance Method Details

#forward(x) ⇒ Object



# File 'lib/secryst/transformer.rb', line 366

def forward(x)
  # Add the first seq_len rows of the precomputed table to the input
  # embeddings (broadcast over the batch dimension), then apply the
  # dropout layer constructed in #initialize.
  x = x + pe.narrow(0, 0, x.size(0))
  @dropout.call(x)
end
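The core of `#forward` (before dropout) can be illustrated in plain Ruby: take the first seq_len rows of the precomputed table and add them elementwise to the embedding matrix. The helper name `add_positional_encoding` is hypothetical, not part of secryst:

```ruby
# Elementwise sum of embeddings x (seq_len x d_model) with the matching
# rows of a positional-encoding table pe; extra pe rows are ignored,
# mirroring pe.narrow(0, 0, x.size(0)).
def add_positional_encoding(x, pe)
  x.each_with_index.map do |row, pos|
    row.each_with_index.map { |v, dim| v + pe[pos][dim] }
  end
end

x  = [[0.5, 0.5], [0.25, 0.75]]                         # seq_len=2, d_model=2
pe = [[0.0, 1.0], [Math.sin(1.0), Math.cos(1.0)], [9.9, 9.9]]
y  = add_positional_encoding(x, pe)                     # third pe row unused
```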