Class: Informers::TextGenerationPipeline

Inherits:
Informers::Pipeline < Object
Defined in:
lib/informers/pipelines.rb

Instance Method Summary

Methods inherited from Pipeline

#initialize

Constructor Details

This class inherits a constructor from Informers::Pipeline

Instance Method Details

#call(texts, **generate_kwargs) ⇒ Object



Source: lib/informers/pipelines.rb, lines 351–403
# File 'lib/informers/pipelines.rb', line 351

# Generates text that continues the given prompt.
#
# Only a single String prompt is currently supported; any other input
# (e.g. an array of prompts or chat messages) raises Todo.
#
# @param texts [String] the prompt to continue
# @param generate_kwargs [Hash] options forwarded as-is to @model.generate.
#   Two keys are also read here:
#   - :add_special_tokens (default false) - passed to the tokenizer
#   - :return_full_text (default true) - when false, the decoded prompt is
#     removed from the front of each generated string
# @return [Array<Hash>] one +{generated_text: String}+ hash per generated
#   sequence; for a non-batched call the outer per-prompt array is unwrapped
# @raise [Todo] if +texts+ is not a String
def call(texts, **generate_kwargs)
  is_batched = false
  is_chat_input = false

  # Normalize inputs
  if texts.is_a?(String)
    texts = [texts]
    inputs = texts
  else
    raise Todo
  end

  # By default, do not add special tokens
  add_special_tokens = generate_kwargs[:add_special_tokens] || false

  # By default, return full text. Only nil falls back to the default;
  # `|| true` would also turn an explicit `false` into true.
  return_full_text = generate_kwargs[:return_full_text]
  return_full_text = true if return_full_text.nil?
  return_full_text = false if is_chat_input

  @tokenizer.padding_side = "left"
  input_ids, attention_mask =
    @tokenizer.(inputs, add_special_tokens:, padding: true, truncation: true)
      .values_at(:input_ids, :attention_mask)

  output_token_ids =
    @model.generate(
      input_ids, generate_kwargs, nil, inputs_attention_mask: attention_mask
    )

  decoded = @tokenizer.batch_decode(output_token_ids, skip_special_tokens: true)

  # Character length of each decoded prompt, used to strip the prompt from the
  # decoded output when only the newly generated text is wanted.
  if !return_full_text && Utils.dims(input_ids)[-1] > 0
    prompt_lengths = @tokenizer.batch_decode(input_ids, skip_special_tokens: true).map(&:length)
  end

  to_return = Array.new(texts.length) { [] }
  decoded.each_with_index do |text, i|
    # Map output sequence i back to the prompt it came from. Multiply before
    # dividing: with integer division, `i / decoded.length * n` is always 0.
    text_index = i * texts.length / decoded.length

    if prompt_lengths
      # NOTE(review): assumes the decoded output begins with the decoded prompt
      # text (as for decoder-only models) - confirm for other architectures.
      text = text[prompt_lengths[text_index]..] || ""
    end
    # TODO is_chat_input
    to_return[text_index] << {
      generated_text: text
    }
  end
  !is_batched && to_return.length == 1 ? to_return[0] : to_return
end