Class: Informers::FillMaskPipeline

Inherits:
Pipeline
  • Object
show all
Defined in:
lib/informers/pipelines.rb

Instance Method Summary collapse

Methods inherited from Pipeline

#initialize

Constructor Details

This class inherits a constructor from Informers::Pipeline

Instance Method Details

#call(texts, top_k: 5) ⇒ Object



267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
# File 'lib/informers/pipelines.rb', line 267

# Fills in the mask token of each input text with the highest-scoring
# candidate tokens.
#
# The texts are tokenized (padding and truncation enabled) and passed to the
# model. For every encoded row, the position of the tokenizer's mask token id
# is looked up with +index+, the logits at that position are run through
# Utils.softmax, and Utils.get_top_items selects the candidates.
#
# @param texts [String, Array<String>] the text(s) containing the mask token
# @param top_k [Integer] forwarded to Utils.get_top_items as the candidate limit
# @return [Array<Hash>, Array<Array<Hash>>] one list of candidate hashes
#   (+:score+, +:token+, +:token_str+, +:sequence+) per input when +texts+ is
#   an Array; otherwise only the list for the first encoded row
# @raise [ArgumentError] if an encoded row does not contain the mask token id
def call(texts, top_k: 5)
  encoded = @tokenizer.(texts, padding: true, truncation: true)
  model_output = @model.(encoded)

  predictions =
    encoded[:input_ids].each_with_index.map do |token_ids, row|
      mask_pos = token_ids.index(@tokenizer.mask_token_id)
      raise ArgumentError, "Mask token (#{@tokenizer.mask_token}) not found in text." if mask_pos.nil?

      # Only the logits at the masked position matter for the prediction.
      mask_logits = model_output.logits[row][mask_pos]
      probabilities = Utils.softmax(mask_logits)

      Utils.get_top_items(probabilities, top_k).map do |candidate|
        # Each candidate is indexed as [token id, score].
        candidate_score = candidate[1]
        candidate_id = candidate[0]

        # Substitute the candidate into a copy so the encoded row stays intact
        # for the remaining candidates.
        filled = token_ids.dup
        filled[mask_pos] = candidate_id

        {
          score: candidate_score,
          token: candidate_id,
          token_str: @tokenizer.id_to_token(candidate_id),
          sequence: @tokenizer.decode(filled, skip_special_tokens: true)
        }
      end
    end

  # Mirror the input shape: a single text yields a single candidate list.
  return predictions if texts.is_a?(Array)

  predictions[0]
end