Class: Informers::ZeroShotClassificationPipeline

Inherits: Pipeline < Object

Defined in: lib/informers/pipelines.rb

Instance Method Summary

Constructor Details

#initialize(**options) ⇒ ZeroShotClassificationPipeline

Returns a new instance of ZeroShotClassificationPipeline.



# File 'lib/informers/pipelines.rb', line 407

def initialize(**options)
  super(**options)

  @label2id = @model.config[:label2id].transform_keys(&:downcase)

  @entailment_id = @label2id["entailment"]
  if @entailment_id.nil?
    warn "Could not find 'entailment' in label2id mapping. Using 2 as entailment_id."
    @entailment_id = 2
  end

  @contradiction_id = @label2id["contradiction"] || @label2id["not_entailment"]
  if @contradiction_id.nil?
    warn "Could not find 'contradiction' in label2id mapping. Using 0 as contradiction_id."
    @contradiction_id = 0
  end
end
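
The constructor locates the entailment and contradiction logits by reading the NLI model's label2id mapping from its config, falling back to indices 2 and 0 when the expected labels are missing. A minimal sketch of that lookup, using a hypothetical label2id hash modeled on a typical NLI checkpoint (actual labels and indices vary by model; in the pipeline the hash comes from @model.config[:label2id]):

# Hypothetical config values for illustration only
label2id = { "CONTRADICTION" => 0, "NEUTRAL" => 1, "ENTAILMENT" => 2 }.transform_keys(&:downcase)

entailment_id    = label2id["entailment"]                                  # => 2
contradiction_id = label2id["contradiction"] || label2id["not_entailment"] # => 0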

Instance Method Details

#call(texts, candidate_labels, hypothesis_template: "This example is {}.", multi_label: false) ⇒ Object



# File 'lib/informers/pipelines.rb', line 425

def call(texts, candidate_labels, hypothesis_template: "This example is {}.", multi_label: false)
  is_batched = texts.is_a?(Array)
  if !is_batched
    texts = [texts]
  end
  if !candidate_labels.is_a?(Array)
    candidate_labels = [candidate_labels]
  end

  # Insert labels into hypothesis template
  hypotheses = candidate_labels.map { |x| hypothesis_template.sub("{}", x) }

  # How to perform the softmax over the logits:
  #  - true:  softmax over the entailment vs. contradiction dim for each label independently
  #  - false: softmax the "entailment" logits over all candidate labels
  softmax_each = multi_label || candidate_labels.length == 1

  to_return = []
  texts.each do |premise|
    entails_logits = []

    hypotheses.each do |hypothesis|
      inputs = @tokenizer.(
        premise,
        text_pair: hypothesis,
        padding: true,
        truncation: true
      )
      outputs = @model.(inputs)

      if softmax_each
        entails_logits << [
          outputs.logits[0][@contradiction_id],
          outputs.logits[0][@entailment_id]
        ]
      else
        entails_logits << outputs.logits[0][@entailment_id]
      end
    end

    scores =
      if softmax_each
        entails_logits.map { |x| Utils.softmax(x)[1] }
      else
        Utils.softmax(entails_logits)
      end

    # Sort by scores (desc) and return scores with indices
    scores_sorted = scores.map.with_index { |x, i| [x, i] }.sort_by { |v| -v[0] }

    to_return << {
      sequence: premise,
      labels: scores_sorted.map { |x| candidate_labels[x[1]] },
      scores: scores_sorted.map { |x| x[0] }
    }
  end
  is_batched ? to_return : to_return[0]
end
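
For reference, a minimal usage sketch, assuming the pipeline is obtained via Informers.pipeline("zero-shot-classification"); the downloaded model and exact scores depend on your environment:

require "informers"

classifier = Informers.pipeline("zero-shot-classification")

# Single-label mode (default): the entailment logits are softmaxed across all
# candidate labels, so the returned scores sum to 1.
classifier.(
  "Ruby is a great language for building web applications.",
  ["technology", "sports", "politics"]
)
# => {sequence: "...", labels: [...], scores: [...]} (labels sorted by score, descending)

# multi_label: true scores each label independently (entailment vs. contradiction),
# so several labels can score high at the same time.
classifier.(
  "I love travel and cooking.",
  ["travel", "cooking", "dancing"],
  multi_label: true
)

Passing an array of texts returns an array of result hashes in the same order; a single string returns a single hash, as shown in the method body above.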