Class: Informers::TextGenerationPipeline

Inherits:
Pipeline < Object
(show all)
Defined in:
lib/informers/pipelines.rb

Instance Method Summary collapse

Methods inherited from Pipeline

#initialize

Constructor Details

This class inherits a constructor from Informers::Pipeline

Instance Method Details

#call(texts, **generate_kwargs) ⇒ Object



365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
# File 'lib/informers/pipelines.rb', line 365

# Generate text continuations for the given prompt.
#
# @param texts [String] the input prompt (currently only a single String is
#   supported; any other input shape raises Todo)
# @param generate_kwargs [Hash] generation options forwarded to the model.
#   Recognized locally:
#   - :add_special_tokens [Boolean] add special tokens during tokenization
#     (default: false)
#   - :return_full_text [Boolean] include the prompt in the returned text
#     (default: true; false is not implemented yet and raises Todo)
# @return [Array<Hash>] one +{generated_text: String}+ hash per generated
#   sequence (unwrapped from the outer batch array for single-prompt input)
# @raise [Todo] for unsupported input shapes or unimplemented options
def call(texts, **generate_kwargs)
  is_batched = false
  is_chat_input = false

  # Normalize inputs to an array of prompts
  if texts.is_a?(String)
    texts = [texts]
    inputs = texts
  else
    # Batched / chat-style inputs are not implemented yet
    raise Todo
  end

  # By default, do not add special tokens
  add_special_tokens = generate_kwargs[:add_special_tokens] || false

  # By default, return full text (prompt + completion). Use fetch so an
  # explicit `return_full_text: false` is honored — `|| true` would coerce
  # false back to true and make the option impossible to disable.
  return_full_text =
    if is_chat_input
      false
    else
      generate_kwargs.fetch(:return_full_text, true)
    end

  # Left padding so generated tokens line up at the end of each sequence
  @tokenizer.padding_side = "left"
  input_ids, attention_mask =
    @tokenizer.(inputs, add_special_tokens:, padding: true, truncation: true)
      .values_at(:input_ids, :attention_mask)

  output_token_ids =
    @model.generate(
      input_ids, generate_kwargs, nil, inputs_attention_mask: attention_mask
    )

  decoded = @tokenizer.batch_decode(output_token_ids, skip_special_tokens: true)

  # Decoded prompt lengths, needed to strip the prompt from the output
  # when return_full_text is false
  if !return_full_text && Utils.dims(input_ids)[-1] > 0
    prompt_lengths = @tokenizer.batch_decode(input_ids, skip_special_tokens: true).map { |x| x.length }
  end

  to_return = Array.new(texts.length) { [] }
  decoded.length.times do |i|
    # Map each generated sequence back to its source prompt. fdiv keeps the
    # ratio in floating point — integer division (i / length) always floors
    # to 0 for i < length, breaking the mapping for multi-prompt batches.
    text_index = (i.fdiv(output_token_ids.length) * texts.length).floor

    if prompt_lengths
      # Prompt stripping (return_full_text: false) is not implemented yet
      raise Todo
    end
    # TODO is_chat_input
    to_return[text_index] << {
      generated_text: decoded[i]
    }
  end
  !is_batched && to_return.length == 1 ? to_return[0] : to_return
end