Class: Informers::Text2TextGenerationPipeline

Inherits:
Pipeline
  • Object
show all
Defined in:
lib/informers/pipelines.rb

Direct Known Subclasses

SummarizationPipeline, TranslationPipeline

Constant Summary collapse

KEY =
:generated_text

Instance Method Summary collapse

Methods inherited from Pipeline

#initialize

Constructor Details

This class inherits a constructor from Informers::Pipeline

Instance Method Details

#call(texts, **generate_kwargs) ⇒ Object



303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
# File 'lib/informers/pipelines.rb', line 303

# Runs text-to-text generation over one input or a batch of inputs.
#
# A non-Array argument is wrapped into a single-element batch. When the model
# config carries a global prefix, and/or a prefix under the task-specific
# params for the current task, each is prepended to every input (global
# first, so the task prefix ends up outermost). The inputs are then tokenized
# with padding and truncation enabled, passed to the model's +generate+, and
# the resulting token ids are decoded with special tokens skipped.
#
# @param texts [Object, Array] one input or a batch of inputs; elements are
#   presumably Strings, since prefixes are concatenated onto them
# @param generate_kwargs [Hash] extra options, handed to +@model.generate+
#   as a single Hash argument (and to the tokenizer's translation-input
#   builder when that path is taken)
# @return [Array<Hash>] one Hash per decoded output, mapping this class's
#   +KEY+ constant to the generated text
def call(texts, **generate_kwargs)
  texts = [texts] unless texts.is_a?(Array)

  # Model-wide prefix, applied to every input when configured.
  global_prefix = @model.config[:prefix]
  texts = texts.map { |text| global_prefix + text } if global_prefix

  # Task-level settings; only the prefix is honoured at the moment.
  all_task_params = @model.config[:task_specific_params]
  task_params = all_task_params && all_task_params[@task]
  if task_params
    task_prefix = task_params["prefix"]
    texts = texts.map { |text| task_prefix + text } if task_prefix

    # TODO update generation config
  end

  tokenizer = @tokenizer
  tokenizer_options = {padding: true, truncation: true}

  # Translation pipelines use the tokenizer's dedicated input builder when it
  # offers one; everything else goes through the regular tokenizer call.
  encoded =
    if is_a?(TranslationPipeline) && tokenizer.respond_to?(:_build_translation_inputs)
      tokenizer._build_translation_inputs(texts, tokenizer_options, generate_kwargs)
    else
      tokenizer.call(texts, **tokenizer_options)
    end

  output_token_ids = @model.generate(encoded[:input_ids], generate_kwargs)
  decoded = tokenizer.batch_decode(output_token_ids, skip_special_tokens: true)

  # The result key is looked up on the concrete class (KEY constant).
  decoded.map do |generated|
    {self.class.const_get(:KEY) => generated}
  end
end