Class: Informers::ZeroShotClassificationPipeline

Inherits:
Pipeline
  • Object
Defined in:
lib/informers/pipelines.rb

Instance Method Summary

Constructor Details

#initialize(**options) ⇒ ZeroShotClassificationPipeline

Returns a new instance of ZeroShotClassificationPipeline.



# File 'lib/informers/pipelines.rb', line 421

def initialize(**options)
  super(**options)

  @label2id = @model.config[:label2id].transform_keys(&:downcase)

  @entailment_id = @label2id["entailment"]
  if @entailment_id.nil?
    warn "Could not find 'entailment' in label2id mapping. Using 2 as entailment_id."
    @entailment_id = 2
  end

  @contradiction_id = @label2id["contradiction"] || @label2id["not_entailment"]
  if @contradiction_id.nil?
    warn "Could not find 'contradiction' in label2id mapping. Using 0 as contradiction_id."
    @contradiction_id = 0
  end
end
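
The constructor resolves the NLI model's entailment and contradiction class indices from its label2id mapping, falling back to conventional defaults (2 and 0) with a warning when the labels are missing. A minimal construction sketch, assuming the pipeline is obtained through the top-level Informers.pipeline helper with the task default model (model name omitted here):

require "informers"

# The task name routes to ZeroShotClassificationPipeline and runs
# the #initialize shown above.
classifier = Informers.pipeline("zero-shot-classification")

# #initialize downcases the model's label2id keys, so "ENTAILMENT" and
# "entailment" are both found; missing entries trigger the fallback
# ids 2 (entailment) and 0 (contradiction) with a warning.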

Instance Method Details

#call(texts, candidate_labels, hypothesis_template: "This example is {}.", multi_label: false) ⇒ Object



# File 'lib/informers/pipelines.rb', line 439

def call(texts, candidate_labels, hypothesis_template: "This example is {}.", multi_label: false)
  is_batched = texts.is_a?(Array)
  if !is_batched
    texts = [texts]
  end
  if !candidate_labels.is_a?(Array)
    candidate_labels = [candidate_labels]
  end

  # Insert labels into hypothesis template
  hypotheses = candidate_labels.map { |x| hypothesis_template.sub("{}", x) }

  # How to perform the softmax over the logits:
  #  - true:  softmax over the entailment vs. contradiction dim for each label independently
  #  - false: softmax the "entailment" logits over all candidate labels
  softmax_each = multi_label || candidate_labels.length == 1

  to_return = []
  texts.each do |premise|
    entails_logits = []

    hypotheses.each do |hypothesis|
      inputs = @tokenizer.(
        premise,
        text_pair: hypothesis,
        padding: true,
        truncation: true
      )
      outputs = @model.(inputs)

      if softmax_each
        entails_logits << [
          outputs.logits[0][@contradiction_id],
          outputs.logits[0][@entailment_id]
        ]
      else
        entails_logits << outputs.logits[0][@entailment_id]
      end
    end

    scores =
      if softmax_each
        entails_logits.map { |x| Utils.softmax(x)[1] }
      else
        Utils.softmax(entails_logits)
      end

    # Sort by scores (desc) and return scores with indices
    scores_sorted = scores.map.with_index { |x, i| [x, i] }.sort_by { |v| -v[0] }

    to_return << {
      sequence: premise,
      labels: scores_sorted.map { |x| candidate_labels[x[1]] },
      scores: scores_sorted.map { |x| x[0] }
    }
  end
  is_batched ? to_return : to_return[0]
end
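
A usage sketch for #call; the input text, candidate labels, and the scores shown in the comments are illustrative, not actual model output:

classifier = Informers.pipeline("zero-shot-classification")

# Single text, single-label mode: the entailment logits are softmaxed
# across all candidate labels, so the scores sum to 1.
result = classifier.(
  "The new GPU doubles training throughput.",
  ["technology", "sports", "politics"]
)
# result is a hash like:
# {
#   sequence: "The new GPU doubles training throughput.",
#   labels:   ["technology", "politics", "sports"],  # sorted by score, descending
#   scores:   [0.93, 0.04, 0.03]                     # illustrative values
# }

# Batched input with multi_label: true: each label is scored independently
# (entailment vs. contradiction softmax), so scores need not sum to 1,
# and an array of result hashes is returned, one per input text.
results = classifier.(
  ["I love this phone", "The battery died after an hour"],
  ["positive", "negative"],
  multi_label: true
)

Passing a custom hypothesis_template (for example "This text is about {}.") changes how each candidate label is turned into an NLI hypothesis; the "{}" placeholder is replaced by the label before tokenization.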