Class: Transformers::RerankingPipeline

Inherits:
Pipeline (which inherits from Object)
Defined in:
lib/transformers/pipelines/reranking.rb

Instance Method Summary

Methods inherited from Pipeline

#check_model_type, #get_iterator, #initialize, #torch_dtype

Constructor Details

This class inherits a constructor from Transformers::Pipeline

Instance Method Details

#_forward(model_inputs) ⇒ Object



# File 'lib/transformers/pipelines/reranking.rb', lines 15-18

# Runs the model on the prepared inputs.
#
# NOTE(review): +model_inputs+ is presumably the tokenizer output returned by
# #preprocess; confirm against the Pipeline base class.
#
# @param model_inputs [Hash] splatted into the model call as keyword arguments
# @return [Object] the model call's return value, passed through unchanged
def _forward(model_inputs)
  @model.call(**model_inputs)
end

#_sanitize_parameters(**kwargs) ⇒ Object



# File 'lib/transformers/pipelines/reranking.rb', lines 3-5

# Splits the options given to a pipeline call into three parameter hashes.
#
# The first two hashes are always empty; every keyword argument is placed,
# untouched, in the third.
# NOTE(review): the slot order is assumed to be [preprocess, forward,
# postprocess], following the Pipeline base-class convention; confirm.
#
# @param kwargs [Hash] caller-supplied options
# @return [Array(Hash, Hash, Hash)]
def _sanitize_parameters(**kwargs)
  preprocess_params = {}
  forward_params = {}
  [preprocess_params, forward_params, kwargs]
end

#call(query, documents) ⇒ Object



# File 'lib/transformers/pipelines/reranking.rb', lines 20-22

# Scores +documents+ against +query+.
#
# Packs both arguments into a single +{query:, documents:}+ hash and hands it
# to the inherited Pipeline#call as one positional argument.
# NOTE(review): presumably Pipeline#call then runs #preprocess, #_forward and
# #postprocess on that hash; confirm in the base class.
#
# @param query [Object] the query text, paired with every document
# @param documents [Array] the candidate documents to score
# @return [Object] whatever Pipeline#call returns
def call(query, documents)
  inputs = {query: query, documents: documents}
  super(inputs)
end

#postprocess(model_outputs) ⇒ Object



# File 'lib/transformers/pipelines/reranking.rb', lines 24-31

# Turns raw model outputs into per-document relevance scores, best first.
#
# NOTE(review): assumes model_outputs[0] is a logits tensor of shape
# (num_documents, 1), as produced by a single-label classification head;
# confirm against the model class.
#
# @param model_outputs [Object] the value returned by #_forward; only
#   element 0 is read
# @return [Array<Hash>] one +{index:, score:}+ hash per document. +index+ is
#   the document's position in the input and +score+ is the sigmoid of its
#   logit. Sorted by descending score, with ties broken by ascending index.
def postprocess(model_outputs)
  # Squeeze only the last axis (assumed to be the single-label axis). A bare
  # #squeeze would also drop the batch axis when exactly one document is
  # scored, collapsing the (1, 1) logits to a 0-d tensor.
  scores = model_outputs[0].sigmoid.squeeze(-1).to_a

  # Array#sort_by is not stable, so the index is part of the sort key. This
  # keeps documents with equal scores in their input order deterministically.
  scores
    .each_with_index
    .map { |score, index| {index: index, score: score} }
    .sort_by { |entry| [-entry[:score], entry[:index]] }
end

#preprocess(inputs) ⇒ Object



# File 'lib/transformers/pipelines/reranking.rb', lines 7-13

# Tokenizes the query against every document as a batch of text pairs.
#
# @param inputs [Hash] must provide +:query+ and +:documents+ (the shape
#   built by #call)
# @return [Object] the tokenizer's output; @framework is forwarded to it as
#   +return_tensors+
def preprocess(inputs)
  query = inputs[:query]
  documents = inputs[:documents]
  # One copy of the query per document, so pair i is (query, documents[i]).
  queries = Array.new(documents.length, query)
  @tokenizer.call(queries, text_pair: documents, return_tensors: @framework)
end