Class: Informers::FillMaskPipeline
- Defined in: lib/informers/pipelines.rb
Instance Method Summary
Methods inherited from Pipeline
Constructor Details
This class inherits a constructor from Informers::Pipeline
Instance Method Details
#call(texts, top_k: 5) ⇒ Object
267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 |
# File 'lib/informers/pipelines.rb', line 267

# Fills in the mask token of each input text with candidate tokens.
#
# The texts are tokenized (with padding and truncation enabled), the model is
# run once on the whole batch, and for every sequence the logits at the first
# mask-token position are passed through Utils.softmax and
# Utils.get_top_items to obtain the candidates.
#
# @param texts [String, Array<String>] one text or a batch of texts; each is
#   expected to contain the tokenizer's mask token.
# @param top_k [Integer] forwarded unchanged to Utils.get_top_items
#   (presumably the number of candidates kept per text -- confirm against Utils).
# @return [Array<Hash>, Array<Array<Hash>>] when +texts+ is an Array, one list
#   of candidates per text; otherwise the candidate list of the single text.
#   Each candidate is a Hash with the keys :score, :token, :token_str and
#   :sequence.
# @raise [ArgumentError] if a tokenized text does not contain the mask token id.
def call(texts, top_k: 5)
  model_inputs = @tokenizer.(texts, padding: true, truncation: true)
  outputs = @model.(model_inputs)

  predictions = model_inputs[:input_ids].each_with_index.map do |ids, i|
    # Only the first occurrence of the mask token in a sequence is filled.
    mask_token_index = ids.index(@tokenizer.mask_token_id)
    raise ArgumentError, "Mask token (#{@tokenizer.mask_token}) not found in text." if mask_token_index.nil?

    # outputs.logits is indexed by batch item first, then by token position.
    item_logits = outputs.logits[i][mask_token_index]
    candidates = Utils.get_top_items(Utils.softmax(item_logits), top_k)

    candidates.map do |candidate|
      # Each candidate is read positionally: index 0 is used as the token id,
      # index 1 as its score.
      token_id = candidate[0]
      score = candidate[1]

      # Work on a copy so every candidate starts from the original ids.
      filled = ids.dup
      filled[mask_token_index] = token_id

      {
        score: score,
        token: token_id,
        token_str: @tokenizer.id_to_token(token_id),
        sequence: @tokenizer.decode(filled, skip_special_tokens: true)
      }
    end
  end

  # A non-Array input yields only that single text's candidates.
  texts.is_a?(Array) ? predictions : predictions[0]
end