Class: Gliner::Tasks::Classification

Inherits:
Gliner::Task
Defined in:
lib/gliner/tasks/classification.rb

Instance Attribute Summary

Attributes inherited from Gliner::Task

#config_parser, #inference, #input_builder

Instance Method Summary

Methods inherited from Gliner::Task

#normalize_text?

Constructor Details

#initialize(config_parser:, inference:, input_builder:, classifier:) ⇒ Classification

Returns a new instance of Classification.



6
7
8
9
10
# File 'lib/gliner/tasks/classification.rb', line 6

# Build a classification task.
#
# The three shared collaborators are handed to Gliner::Task's constructor.
# The classifier is kept locally because #process_output uses it to format results.
#
# @param config_parser used by #build_prompt and #parse_config
# @param inference queried by #needs_cls_logits?
# @param input_builder forwarded to the parent constructor
# @param classifier object responding to #format_classification
def initialize(config_parser:, inference:, input_builder:, classifier:)
  super(
    config_parser: config_parser,
    inference: inference,
    input_builder: input_builder
  )
  @classifier = classifier
end

Instance Method Details

#build_prompt(parsed) ⇒ Object



28
29
30
# File 'lib/gliner/tasks/classification.rb', line 28

# Build the prompt for one parsed classification task.
#
# @param parsed [Hash] parsed task config; only :name and :label_descs are read
# @return whatever config_parser.build_prompt returns for those two values
def build_prompt(parsed)
  task_name = parsed[:name]
  descriptions = parsed[:label_descs]
  config_parser.build_prompt(task_name, descriptions)
end

#execute_all(pipeline, text, tasks_config, **options) ⇒ Object

Raises:

  • (Error) — if tasks_config is not a Hash.



55
56
57
58
59
60
61
62
# File 'lib/gliner/tasks/classification.rb', line 55

# Run every classification task in +tasks_config+ against the same text.
#
# Each entry becomes a { name:, config: } hash that is passed, together with
# this task object, to pipeline.execute.
#
# @param pipeline object responding to #execute(task, text, parsed_config, **options)
# @param text input text shared by all tasks
# @param tasks_config [Hash] task name => task configuration
# @param options forwarded unchanged to every pipeline.execute call
# @return [Hash{String => Object}] per-task results keyed by the stringified task name
# @raise [Error] if tasks_config is not a Hash
def execute_all(pipeline, text, tasks_config, **options)
  raise Error, 'tasks must be a Hash' unless tasks_config.is_a?(Hash)

  tasks_config.to_h do |task_name, task_config|
    request = { name: task_name, config: task_config }
    [task_name.to_s, pipeline.execute(self, text, request, **options)]
  end
end

#label_prefix ⇒ Object



24
25
26
# File 'lib/gliner/tasks/classification.rb', line 24

# Marker string used for this task's labels.
#
# NOTE(review): where the prefix is consumed is not visible here —
# presumably in Gliner::Task or the input builder; confirm there.
#
# @return [String] the literal '[L]'
def label_prefix
  '[L]'
end

#labels(parsed) ⇒ Object



32
33
34
# File 'lib/gliner/tasks/classification.rb', line 32

# Label list of a parsed classification task.
#
# @param parsed [Hash] parsed task config
# @return the value stored under :labels (nil when the key is absent)
def labels(parsed)
  parsed[:labels]
end

#needs_cls_logits? ⇒ Boolean

Returns:

  • (Boolean)


36
37
38
# File 'lib/gliner/tasks/classification.rb', line 36

# Whether this task should consume CLS logits.
#
# Delegates to the inference object and returns its has_cls_logits value
# unchanged; the value is not coerced to true/false here.
#
# @return [Boolean] (or whatever truthy/falsy value has_cls_logits yields)
def needs_cls_logits?
  inference.has_cls_logits
end

#parse_config(input) ⇒ Object

Raises:

  • (Error) — if input is not a Hash.



12
13
14
15
16
17
18
# File 'lib/gliner/tasks/classification.rb', line 12

# Parse a single classification task definition.
#
# The task name and its configuration are pulled out of +input+ with
# #extract_task_config and handed to config_parser.parse_classification_task.
# The result then gets :name set to the stringified task name.
#
# @param input [Hash] raw task definition
# @return [Hash] the parser's result merged with name: <String>
# @raise [Error] if input is not a Hash
def parse_config(input)
  unless input.is_a?(Hash)
    raise Error, 'classification config must be a Hash'
  end

  task_name, task_config = extract_task_config(input)
  config_parser
    .parse_classification_task(task_name, task_config)
    .merge(name: task_name.to_s)
end

#process_output(logits, parsed, prepared, options) ⇒ Object



40
41
42
43
44
45
46
47
48
49
50
51
52
53
# File 'lib/gliner/tasks/classification.rb', line 40

# Turn raw model output into the classifier's formatted result for one task.
#
# @param logits raw model output, forwarded to #classification_scores
# @param parsed [Hash] parsed task config; :labels, :multi_label and
#   :cls_threshold are read from it
# @param prepared prepared model input, forwarded to #classification_scores
# @param options [Hash] :include_confidence (defaults to false) and
#   :threshold, which overrides parsed[:cls_threshold] unless it is nil
# @return whatever @classifier.format_classification returns
def process_output(logits, parsed, prepared, options)
  with_confidence = options.fetch(:include_confidence, false)

  # Only an explicit nil falls back to the task's configured threshold;
  # `||` is avoided so a non-nil override (even false) is passed through as-is.
  threshold = options[:threshold]
  threshold = parsed[:cls_threshold] if threshold.nil?

  scores = classification_scores(logits, parsed, prepared, options)
  format_opts = {
    labels: parsed[:labels],
    multi_label: parsed[:multi_label],
    include_confidence: with_confidence,
    cls_threshold: threshold
  }
  @classifier.format_classification(scores, **format_opts)
end

#task_type ⇒ Object



20
21
22
# File 'lib/gliner/tasks/classification.rb', line 20

# Task-type identifier for classification tasks.
#
# @return the Inference::TASK_TYPE_CLASSIFICATION constant
def task_type
  Inference::TASK_TYPE_CLASSIFICATION
end