Class: Gliner::Tasks::Classification
Instance Attribute Summary
Attributes inherited from Gliner::Task
#config_parser, #inference, #input_builder
Instance Method Summary
collapse
#normalize_text?
Constructor Details
#initialize(config_parser:, inference:, input_builder:, classifier:) ⇒ Classification
Returns a new instance of Classification.
6
7
8
9
10
|
# File 'lib/gliner/tasks/classification.rb', line 6
# Creates a classification task.
#
# +config_parser+, +inference+ and +input_builder+ are handed unchanged to the
# parent constructor; +classifier+ is kept on the instance and later used by
# #process_output to format scores.
#
# @param config_parser [Object] forwarded to the superclass
# @param inference [Object] forwarded to the superclass
# @param input_builder [Object] forwarded to the superclass
# @param classifier [Object] object responding to +format_classification+
def initialize(config_parser:, inference:, input_builder:, classifier:)
  super(
    config_parser: config_parser,
    inference: inference,
    input_builder: input_builder
  )
  @classifier = classifier
end
|
Instance Method Details
#build_prompt(parsed) ⇒ Object
28
29
30
|
# File 'lib/gliner/tasks/classification.rb', line 28
# Builds the prompt for an already-parsed classification task by delegating to
# the config parser.
#
# @param parsed [Hash] parsed task; its +:name+ and +:label_descs+ entries are
#   passed to +config_parser.build_prompt+
# @return [Object] whatever +config_parser.build_prompt+ returns
def build_prompt(parsed)
  task_name = parsed[:name]
  label_descriptions = parsed[:label_descs]
  config_parser.build_prompt(task_name, label_descriptions)
end
|
#execute_all(pipeline, text, tasks_config, **options) ⇒ Object
55
56
57
58
59
60
61
62
|
# File 'lib/gliner/tasks/classification.rb', line 55
# Runs every classification task in +tasks_config+ against +text+.
#
# Each entry is wrapped as +{ name:, config: }+ and passed to
# +pipeline.execute+ together with this task object and the given options.
#
# @param pipeline [Object] object responding to +execute(task, text, config, **options)+
# @param text [Object] input passed through to the pipeline unchanged
# @param tasks_config [Hash] task name => task configuration
# @param options [Hash] extra keyword options forwarded to every pipeline call
# @return [Hash{String => Object}] pipeline results keyed by the task name as a String
# @raise [Error] when +tasks_config+ is not a Hash
def execute_all(pipeline, text, tasks_config, **options)
  raise Error, 'tasks must be a Hash' unless tasks_config.is_a?(Hash)

  tasks_config.to_h do |task_name, task_config|
    outcome = pipeline.execute(self, text, { name: task_name, config: task_config }, **options)
    [task_name.to_s, outcome]
  end
end
|
#label_prefix ⇒ Object
24
25
26
|
# File 'lib/gliner/tasks/classification.rb', line 24
# Returns the label marker string for this task type.
#
# NOTE(review): presumably this token is placed before each label when the
# prompt is assembled; the consumer is not visible here, so confirm.
#
# @return [String] the literal "[L]"
def label_prefix
'[L]'
end
|
#labels(parsed) ⇒ Object
32
33
34
|
# File 'lib/gliner/tasks/classification.rb', line 32
# Returns the labels stored in a parsed task configuration.
#
# @param parsed [Hash] parsed task configuration
# @return [Object] the value stored under +:labels+ (nil when the key is absent
#   from a plain Hash)
def labels(parsed)
parsed[:labels]
end
|
#needs_cls_logits? ⇒ Boolean
36
37
38
|
# File 'lib/gliner/tasks/classification.rb', line 36
# Reports whether this task needs classification logits.
#
# The answer is delegated entirely to the inference object: the method returns
# whatever +inference.has_cls_logits+ returns, without coercing it.
#
# @return [Boolean] value of +inference.has_cls_logits+
def needs_cls_logits?
inference.has_cls_logits
end
|
#parse_config(input) ⇒ Object
12
13
14
15
16
17
18
|
# File 'lib/gliner/tasks/classification.rb', line 12
# Parses one classification task definition.
#
# +input+ is expected to be the +{ name:, config: }+ Hash that #execute_all
# builds for each task. The task name and configuration are handed to
# +config_parser.parse_classification_task+, and the task name (as a String)
# is merged into the result under +:name+.
#
# @param input [Hash] task definition with +:name+ and +:config+ entries
# @return [Hash] result of +config_parser.parse_classification_task+ with
#   +:name+ set to the String form of the task name
# @raise [Error] when +input+ is not a Hash
def parse_config(input)
  raise Error, 'classification config must be a Hash' unless input.is_a?(Hash)

  # A Hash on the right-hand side of a multiple assignment is not destructured
  # (Hash does not implement to_ary), so the two fields are read by key.
  name, task_config = input.values_at(:name, :config)
  parsed = config_parser.parse_classification_task(name, task_config)
  parsed.merge(name: name.to_s)
end
|
#process_output(logits, parsed, prepared, options) ⇒ Object
40
41
42
43
44
45
46
47
48
49
50
51
52
53
|
# File 'lib/gliner/tasks/classification.rb', line 40
# Turns model output into a formatted classification result.
#
# Scores come from +classification_scores+ and are formatted by the classifier
# given to the constructor. An explicit +:threshold+ option takes precedence
# over the task's own +:cls_threshold+; only a nil (or missing) option falls
# back to the parsed value.
#
# @param logits [Object] model output, passed through to +classification_scores+
# @param parsed [Hash] parsed task (+:labels+, +:multi_label+, +:cls_threshold+ are read)
# @param prepared [Object] prepared input, passed through to +classification_scores+
# @param options [Hash] supports +:include_confidence+ (default false) and +:threshold+
# @return [Object] whatever +@classifier.format_classification+ returns
def process_output(logits, parsed, prepared, options)
  with_confidence = options.fetch(:include_confidence, false)

  threshold = options[:threshold]
  threshold = parsed[:cls_threshold] if threshold.nil?

  @classifier.format_classification(
    classification_scores(logits, parsed, prepared, options),
    labels: parsed[:labels],
    multi_label: parsed[:multi_label],
    include_confidence: with_confidence,
    cls_threshold: threshold
  )
end
|
#task_type ⇒ Object
20
21
22
|
# File 'lib/gliner/tasks/classification.rb', line 20
# Identifies this task as a classification task.
#
# @return [Object] the value of +Inference::TASK_TYPE_CLASSIFICATION+
def task_type
Inference::TASK_TYPE_CLASSIFICATION
end
|