Class: Secryst::Translator

Inherits:
Object
Defined in:
lib/secryst/translator.rb

Instance Method Summary

Constructor Details

#initialize(model:, vocabs_dir:, hyperparameters:, model_file:) ⇒ Translator

Returns a new instance of Translator.



3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# File 'lib/secryst/translator.rb', line 3

# Builds a translator backed by a trained model.
#
# Vocabularies are loaded first via +load_vocabs+ (defined elsewhere in this
# class; presumably it reads +@vocabs_dir+ and sets +@input_vocab+ /
# +@target_vocab+ — both are read right after it). A Secryst::Transformer is
# then built with +hyperparameters+ plus the two vocab sizes. Its weights are
# restored from +model_file+, and the model is switched to eval mode.
#
# @param model [String] architecture name; only 'transformer' is accepted.
# @param vocabs_dir [String] location handed to +load_vocabs+ via @vocabs_dir.
# @param hyperparameters [Hash] extra options for Secryst::Transformer.new.
# @param model_file [String] path to a state dict readable by Torch.load.
# @raise [ArgumentError] if +model+ is not 'transformer'.
def initialize(model:, vocabs_dir:, hyperparameters:, model_file:)
  @vocabs_dir = vocabs_dir
  @device = "cpu"
  load_vocabs

  raise ArgumentError, 'Only transformer model is currently supported' unless model == 'transformer'

  # The vocab sizes always come from the loaded vocabularies and override any
  # same-named keys in the caller's hyperparameters.
  vocab_sizes = {
    input_vocab_size: @input_vocab.length,
    target_vocab_size: @target_vocab.length,
  }
  @model = Secryst::Transformer.new(hyperparameters.merge(vocab_sizes))

  weights = Torch.load(model_file)
  @model.load_state_dict(weights)
  @model.eval
end

Instance Method Details

#translate(phrase, max_seq_length: 100) ⇒ Object



22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# File 'lib/secryst/translator.rb', line 22

# Translates +phrase+ with greedy autoregressive decoding.
#
# The phrase is split into characters, wrapped in '<sos>' / '<eos>', and
# encoded with the input vocabulary. Starting from the target '<sos>' token,
# the model is called once per step. Each step appends the highest-scoring
# token for the newest position, stopping at '<eos>' or after
# +max_seq_length+ generated tokens.
#
# @param phrase [String] text to translate (tokenized per character).
# @param max_seq_length [Integer] upper bound on generated target tokens.
# @return [String] the decoded translation. It is also printed to stdout,
#   as before; previously the method returned nil.
def translate(phrase, max_seq_length: 100)
  tokens = ['<sos>', *phrase.chars, '<eos>']
  # NOTE(review): a character absent from the input vocab maps to whatever
  # @input_vocab.stoi returns for an unknown key — confirm with the Vocab class.
  input = Torch.tensor([tokens.map { |token| @input_vocab.stoi[token] }]).t
  output = Torch.tensor([[@target_vocab.stoi['<sos>']]])
  # Id 1 is treated as padding (presumably '<pad>' — TODO confirm).
  src_key_padding_mask = input.t.eq(1)
  generated_ids = []

  # Pure inference: no gradients are needed, so skip autograd bookkeeping.
  Torch.no_grad do
    max_seq_length.times do |step|
      tgt_key_padding_mask = output.t.eq(1)
      # True strictly above the diagonal, so position k cannot attend to > k.
      tgt_mask = Torch.triu(Torch.ones(step + 1, step + 1)).eq(0).transpose(0, 1)
      opts = {
        tgt_mask: tgt_mask,
        src_key_padding_mask: src_key_padding_mask,
        tgt_key_padding_mask: tgt_key_padding_mask,
        memory_key_padding_mask: src_key_padding_mask,
      }
      logits = @model.call(input, output, opts)
      # Only the newest position (index +step+) determines the next token.
      next_id = logits[step].argmax.item
      break if @target_vocab.itos[next_id] == '<eos>'

      generated_ids << next_id
      output = Torch.cat([output, Torch.tensor([[next_id]])])
    end
  end

  translation = generated_ids.map { |id| @target_vocab.itos[id] }.join
  puts translation
  translation
end