Class: Secryst::Translator
- Inherits: Object
- Defined in: lib/secryst/translator.rb
Instance Method Summary
- #initialize(model:, vocabs_dir:, hyperparameters:, model_file:) ⇒ Translator (constructor)
  A new instance of Translator.
- #translate(phrase, max_seq_length: 100) ⇒ Object
Constructor Details
#initialize(model:, vocabs_dir:, hyperparameters:, model_file:) ⇒ Translator
Returns a new instance of Translator.
# File 'lib/secryst/translator.rb', line 3

def initialize(model:, vocabs_dir:, hyperparameters:, model_file:)
  @device = "cpu"
  @vocabs_dir = vocabs_dir
  load_vocabs

  if model == 'transformer'
    @model = Secryst::Transformer.new(hyperparameters.merge({
      input_vocab_size: @input_vocab.length,
      target_vocab_size: @target_vocab.length,
    }))
  else
    raise ArgumentError, 'Only transformer model is currently supported'
  end

  @model.load_state_dict(Torch.load(model_file))
  @model.eval
end
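The constructor builds a Secryst::Transformer from the supplied hyperparameters hash, filling in the input and target vocabulary sizes from the vocabularies loaded out of vocabs_dir, and then restores the trained weights from model_file. A minimal construction sketch follows; the paths and hyperparameter values are illustrative placeholders, and the hyperparameter keys (d_model, nhead, and so on) are assumed to mirror the usual transformer arguments rather than confirmed defaults:

require 'secryst'

# Illustrative values only; adjust to match how the model was trained.
translator = Secryst::Translator.new(
  model: 'transformer',                 # the only architecture currently supported
  vocabs_dir: 'data/vocabs',            # placeholder: directory read by load_vocabs
  model_file: 'checkpoints/model.pth',  # placeholder: state dict saved with Torch.save
  hyperparameters: {
    d_model: 64,                        # assumed keys, mirroring typical transformer
    nhead: 8,                           # settings; check your own training setup
    num_encoder_layers: 4,
    num_decoder_layers: 4,
    dim_feedforward: 256,
    dropout: 0.05
  }
)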
Instance Method Details
#translate(phrase, max_seq_length: 100) ⇒ Object
# File 'lib/secryst/translator.rb', line 22

def translate(phrase, max_seq_length: 100)
  input = ['<sos>'] + phrase.chars + ['<eos>']
  input = Torch.tensor([input.map { |i| @input_vocab.stoi[i] }]).t
  output = Torch.tensor([[@target_vocab.stoi['<sos>']]])
  src_key_padding_mask = input.t.eq(1)

  max_seq_length.times do |i|
    tgt_key_padding_mask = output.t.eq(1)
    tgt_mask = Torch.triu(Torch.ones(i+1, i+1)).eq(0).transpose(0, 1)
    opts = {
      tgt_mask: tgt_mask,
      src_key_padding_mask: src_key_padding_mask,
      tgt_key_padding_mask: tgt_key_padding_mask,
      memory_key_padding_mask: src_key_padding_mask,
    }
    prediction = @model.call(input, output, opts).map { |i| i.argmax.item }
    break if @target_vocab.itos[prediction[i]] == '<eos>'
    output = Torch.cat([output, Torch.tensor([[prediction[i]]])])
  end

  puts "#{output[1..-1].map { |i| @target_vocab.itos[i.item] }.join('')}"
end
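#translate performs greedy, character-level decoding: the phrase is wrapped in <sos>/<eos> markers and mapped to vocabulary indices, a causal tgt_mask is rebuilt each step, and the decoder output is extended by one argmax token per iteration until <eos> is produced or max_seq_length steps have elapsed. The key padding masks flag positions equal to index 1, which is presumably the <pad> token in both vocabularies. Note that the decoded string is written to stdout with puts rather than returned. A short usage sketch, assuming the translator built in the constructor example above (the phrase is arbitrary):

# Prints the decoded target sequence; the return value is nil (from puts).
translator.translate('hello', max_seq_length: 50)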