Class: Gliner::Runners::PreparedTask
- Inherits: Object
- Class hierarchy: Object > Gliner::Runners::PreparedTask
- Defined in: lib/gliner/runners/prepared_task.rb
Instance Method Summary
- #call(text, **options) ⇒ Object
- #initialize(task, parsed) ⇒ PreparedTask (constructor): Returns a new instance of PreparedTask.
Constructor Details
#initialize(task, parsed) ⇒ PreparedTask
Returns a new instance of PreparedTask.
Source: lib/gliner/runners/prepared_task.rb, lines 6–19
# File 'lib/gliner/runners/prepared_task.rb', line 6
# Prepares +task+ for repeated execution against a fixed +parsed+ spec.
#
# Work that does not involve the input text is done once here:
# resolving the label list, building the schema tokens from the task
# prompt and labels, building a label mask, and caching the result of
# +precompute_label_positions+ for #call to use as its default positions.
#
# @param task the task object; it is sent +labels+, +input_builder+,
#   +build_prompt+ and +label_prefix+ here.
# @param parsed the parsed task specification, passed through to +task+.
def initialize(task, parsed)
  @task = task
  @parsed = parsed
  @labels = task.labels(parsed)

  # Evaluation order matches a direct call: builder, prompt, then prefix.
  builder = task.input_builder
  prompt = task.build_prompt(parsed)
  @schema_tokens = builder.schema_tokens_for(
    prompt: prompt,
    labels: @labels,
    label_prefix: task.label_prefix
  )

  # One mask entry per label, each set to 1.
  @label_mask = [1] * @labels.length
  @label_positions_template = precompute_label_positions
end
Instance Method Details
#call(text, **options) ⇒ Object
Source: lib/gliner/runners/prepared_task.rb, lines 21–43
# File 'lib/gliner/runners/prepared_task.rb', line 21
# Runs the prepared task against +text+.
#
# Builds model inputs from +text+ plus the schema tokens cached at
# construction, runs inference, and hands the logits to the task's
# +process_output+.
#
# @param text input text handed to +input_builder.prepare+
#   (NOTE(review): presumably a String — confirm against the input builder).
# @param options [Hash] extra keyword options forwarded to +process_output+.
#   A +:label_positions+ entry is always set, overriding any caller value.
# @return [Object] whatever +process_output+ returns.
def call(text, **options)
  prepared = @task.input_builder.prepare(text, @schema_tokens)
  token_count = prepared.input_ids.length

  # The template is computed once at construction; fall back to a per-call
  # lookup when any cached position is missing or out of range for this input.
  label_positions = @label_positions_template
  if label_positions.any? { |pos| pos.nil? || pos >= token_count }
    label_positions = @task.inference.label_positions_for(prepared.word_ids, @labels.length)
  end

  logits = @task.inference.run(
    Inference::Request.new(
      input_ids: prepared.input_ids,
      attention_mask: prepared.attention_mask,
      words_mask: prepared.words_mask,
      text_lengths: [prepared.text_len],
      task_type: @task.task_type,
      label_positions: label_positions,
      label_mask: @label_mask,
      want_cls: @task.needs_cls_logits?
    )
  )

  # A keyword splat also reaches a process_output that takes a trailing
  # positional options Hash: Ruby converts keywords to a Hash for methods
  # that accept no keywords.
  @task.process_output(logits, @parsed, prepared, **options.merge(label_positions: label_positions))
end