Class: Informers::ZeroShotClassificationPipeline
- Defined in:
- lib/informers/pipelines.rb
Instance Method Summary collapse
- #call(texts, candidate_labels, hypothesis_template: "This example is {}.", multi_label: false) ⇒ Object
-
#initialize(**options) ⇒ ZeroShotClassificationPipeline
constructor
A new instance of ZeroShotClassificationPipeline.
Constructor Details
#initialize(**options) ⇒ ZeroShotClassificationPipeline
Returns a new instance of ZeroShotClassificationPipeline.
407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 |
# File 'lib/informers/pipelines.rb', line 407

# Builds the pipeline and resolves the NLI label ids used for scoring.
#
# The model config's label2id mapping is downcased and searched for the
# "entailment" and "contradiction" (or "not_entailment") labels. When a
# label is missing, the conventional index is used as a fallback (2 for
# entailment, 0 for contradiction) and a warning is emitted.
def initialize(**)
  super(**)

  @label2id = @model.config[:label2id].transform_keys(&:downcase)

  entailment = @label2id["entailment"]
  if entailment.nil?
    warn "Could not find 'entailment' in label2id mapping. Using 2 as entailment_id."
    entailment = 2
  end
  @entailment_id = entailment

  contradiction = @label2id["contradiction"] || @label2id["not_entailment"]
  if contradiction.nil?
    warn "Could not find 'contradiction' in label2id mapping. Using 0 as contradiction_id."
    contradiction = 0
  end
  @contradiction_id = contradiction
end
Instance Method Details
#call(texts, candidate_labels, hypothesis_template: "This example is {}.", multi_label: false) ⇒ Object
425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 |
# File 'lib/informers/pipelines.rb', line 425

# Classifies one or more texts against a set of candidate labels using
# NLI-style zero-shot classification.
#
# Each label is substituted into +hypothesis_template+ and scored against
# the text as a premise/hypothesis entailment problem. With +multi_label+
# (or when only one label is given) each label is scored independently by
# a softmax over its contradiction/entailment logits; otherwise the
# entailment logits are softmaxed across all candidate labels.
#
# Returns a hash with :sequence, :labels (sorted by score, descending)
# and :scores — or an array of such hashes when +texts+ is an array.
def call(texts, candidate_labels, hypothesis_template: "This example is {}.", multi_label: false)
  is_batched = texts.is_a?(Array)
  texts = [texts] unless is_batched
  candidate_labels = [candidate_labels] unless candidate_labels.is_a?(Array)

  # Insert each candidate label into the hypothesis template
  hypotheses = candidate_labels.map { |label| hypothesis_template.sub("{}", label) }

  # How to perform the softmax over the logits:
  # - true: softmax over the entailment vs. contradiction dim for each label independently
  # - false: softmax the "entailment" logits over all candidate labels
  softmax_each = multi_label || candidate_labels.length == 1

  results = texts.map do |premise|
    entails_logits = hypotheses.map do |hypothesis|
      inputs = @tokenizer.(
        premise,
        text_pair: hypothesis,
        padding: true,
        truncation: true
      )
      outputs = @model.(inputs)

      if softmax_each
        [outputs.logits[0][@contradiction_id], outputs.logits[0][@entailment_id]]
      else
        outputs.logits[0][@entailment_id]
      end
    end

    scores =
      if softmax_each
        entails_logits.map { |pair| Utils.softmax(pair)[1] }
      else
        Utils.softmax(entails_logits)
      end

    # Pair each score with its label index and order by score, descending
    ranked = scores.each_with_index.sort_by { |score, _i| -score }

    {
      sequence: premise,
      labels: ranked.map { |_score, i| candidate_labels[i] },
      scores: ranked.map { |score, _i| score }
    }
  end

  is_batched ? results : results[0]
end