Class: Informers::ZeroShotClassificationPipeline
- Defined in:
- lib/informers/pipelines.rb
Instance Method Summary collapse
- #call(texts, candidate_labels, hypothesis_template: "This example is {}.", multi_label: false) ⇒ Object
-
#initialize(**options) ⇒ ZeroShotClassificationPipeline
constructor
A new instance of ZeroShotClassificationPipeline.
Constructor Details
#initialize(**options) ⇒ ZeroShotClassificationPipeline
Returns a new instance of ZeroShotClassificationPipeline.
# File 'lib/informers/pipelines.rb', line 421

# Builds the zero-shot classification pipeline and resolves the NLI label
# ids used for scoring.
#
# Lowercases the model's label2id keys for case-insensitive lookup, then
# locates the "entailment" id and the "contradiction" (or "not_entailment")
# id. When either is missing, falls back to the conventional NLI positions
# (2 and 0 respectively) and emits a warning.
def initialize(**)
  super(**)

  # Case-insensitive view of the model's label mapping.
  @label2id = @model.config[:label2id].transform_keys(&:downcase)

  entailment = @label2id["entailment"]
  if entailment.nil?
    warn "Could not find 'entailment' in label2id mapping. Using 2 as entailment_id."
    entailment = 2
  end
  @entailment_id = entailment

  contradiction = @label2id["contradiction"] || @label2id["not_entailment"]
  if contradiction.nil?
    warn "Could not find 'contradiction' in label2id mapping. Using 0 as contradiction_id."
    contradiction = 0
  end
  @contradiction_id = contradiction
end
Instance Method Details
#call(texts, candidate_labels, hypothesis_template: "This example is {}.", multi_label: false) ⇒ Object
# File 'lib/informers/pipelines.rb', line 439

# Classifies one or more texts against a set of candidate labels using
# NLI-style zero-shot classification.
#
# Each candidate label is substituted into +hypothesis_template+ and paired
# with the text as an (premise, hypothesis) input to the model. Scores are
# either softmaxed per label over (contradiction, entailment) logits — when
# +multi_label+ is true or there is a single label — or softmaxed across all
# labels' entailment logits.
#
# Returns a hash (or array of hashes for batched input) with :sequence,
# :labels (sorted by descending score), and :scores.
def call(texts, candidate_labels, hypothesis_template: "This example is {}.", multi_label: false)
  batched = texts.is_a?(Array)
  premises = batched ? texts : [texts]
  labels = candidate_labels.is_a?(Array) ? candidate_labels : [candidate_labels]

  # Render one hypothesis per candidate label.
  hypotheses = labels.map { |label| hypothesis_template.sub("{}", label) }

  # How to perform the softmax over the logits:
  # - true: softmax over the entailment vs. contradiction dim for each label independently
  # - false: softmax the "entailment" logits over all candidate labels
  softmax_each = multi_label || labels.length == 1

  results = premises.map do |premise|
    per_label_logits = hypotheses.map do |hypothesis|
      model_inputs = @tokenizer.(
        premise,
        text_pair: hypothesis,
        padding: true,
        truncation: true
      )
      outputs = @model.(model_inputs)

      if softmax_each
        [
          outputs.logits[0][@contradiction_id],
          outputs.logits[0][@entailment_id]
        ]
      else
        outputs.logits[0][@entailment_id]
      end
    end

    scores =
      if softmax_each
        per_label_logits.map { |pair| Utils.softmax(pair)[1] }
      else
        Utils.softmax(per_label_logits)
      end

    # Rank label indices by descending score.
    ranked = scores.each_with_index.sort_by { |score, _| -score }

    {
      sequence: premise,
      labels: ranked.map { |_, idx| labels[idx] },
      scores: ranked.map { |score, _| score }
    }
  end

  batched ? results : results[0]
end