47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
|
# File 'lib/informers/pipelines.rb', line 47
# Runs text classification over a single text or a batch of texts.
#
# Texts are tokenized with padding and truncation enabled, then passed
# through the model. Each row of logits is converted to scores:
# - with a sigmoid when the model config's +:problem_type+ is
#   "multi_label_classification";
# - with a softmax otherwise.
# The top +top_k+ scores are then labelled via the config's +:id2label+
# table, looked up by the stringified class index.
#
# @param texts [String, Array<String>] one text, or an Array of texts
# @param top_k [Integer] how many predictions to keep per text
# @return [Hash, Array, nil] for an Array +texts+, one entry per logits row;
#   otherwise only the first entry (nil if there were no rows). An entry is a
#   +{label:, score:}+ Hash when +top_k == 1+, else an Array of such Hashes.
def call(texts, top_k: 1)
  encoded = @tokenizer.(texts, padding: true, truncation: true)
  logits = @model.(encoded).logits

  multi_label = @model.config[:problem_type] == "multi_label_classification"
  labels = @model.config[:id2label]

  per_row = logits.map do |row|
    probs = multi_label ? Utils.sigmoid(row) : Utils.softmax(row)
    # Each top item is an [index, score] pair; the label table is keyed by
    # the index converted to a String.
    Utils.get_top_items(probs, top_k).map do |index, score|
      { label: labels[index.to_s], score: score }
    end
  end

  # With one prediction per text, splice the inner Arrays so each row
  # contributes bare Hashes instead of a one-element Array.
  per_row = per_row.flatten(1) if top_k == 1

  texts.is_a?(Array) ? per_row : per_row.first
end
|