Class: Gliner::Runners::PreparedTask

Inherits: Object
Defined in:
lib/gliner/runners/prepared_task.rb

Instance Method Summary

Constructor Details

#initialize(task, parsed) ⇒ PreparedTask

Returns a new instance of PreparedTask.



(source lines 6–19)
# File 'lib/gliner/runners/prepared_task.rb', line 6

# Builds a reusable runner for a single task / parsed-definition pair.
#
# All work that does not depend on the input text is done once here:
# resolving the label list, building the schema tokens from the task's
# prompt, labels and label prefix, creating an all-ones label mask (one
# entry per label), and precomputing a label-position template that #call
# reuses whenever every entry still fits inside the prepared input.
#
# @param task [Object] task object; this constructor calls #labels,
#   #input_builder, #build_prompt and #label_prefix on it
# @param parsed [Object] value handed to task.labels and task.build_prompt
#   (its exact shape is not visible here — TODO confirm)
def initialize(task, parsed)
  @task = task
  @parsed = parsed
  @labels = task.labels(parsed)

  # Bind the builder first so it is resolved before the keyword arguments,
  # matching receiver-then-arguments evaluation order.
  builder = task.input_builder
  @schema_tokens = builder.schema_tokens_for(
    prompt: task.build_prompt(parsed),
    labels: @labels,
    label_prefix: task.label_prefix
  )

  # Every label is active, so the mask is simply a 1 per label.
  @label_mask = [1] * @labels.length
  # Presumably one position per label; entries may be nil (see #call) — TODO confirm.
  @label_positions_template = precompute_label_positions
end

Instance Method Details

#call(text, **options) ⇒ Object



(source lines 21–43)
# File 'lib/gliner/runners/prepared_task.rb', line 21

# Runs the prepared task on one input text.
#
# The schema tokens computed in the constructor are reused. Label positions
# come from the precomputed template, unless any entry is nil or points at or
# past the end of the prepared input_ids. In that case they are recomputed
# through task.inference.label_positions_for from the prepared word_ids.
#
# @param text [Object] input handed to task.input_builder.prepare
# @param options [Hash] extra keyword options; forwarded to
#   task.process_output with :label_positions merged in
# @return [Object] whatever task.process_output returns
def call(text, **options)
  prepared = @task.input_builder.prepare(text, @schema_tokens)

  # The cached template is only valid if every position exists and indexes
  # inside this particular tokenized input.
  template_unusable = @label_positions_template.any? do |position|
    position.nil? || position >= prepared.input_ids.length
  end

  label_positions =
    if template_unusable
      @task.inference.label_positions_for(prepared.word_ids, @labels.length)
    else
      @label_positions_template
    end

  engine = @task.inference
  request = Inference::Request.new(
    input_ids: prepared.input_ids,
    attention_mask: prepared.attention_mask,
    words_mask: prepared.words_mask,
    text_lengths: [prepared.text_len], # one text per call
    task_type: @task.task_type,
    label_positions: label_positions,
    label_mask: @label_mask,
    want_cls: @task.needs_cls_logits?
  )
  logits = engine.run(request)

  forwarded = options.merge(label_positions: label_positions)
  @task.process_output(logits, @parsed, prepared, forwarded)
end