Class: Gliner::Inference

Inherits: Object
Defined in:
lib/gliner/inference.rb,
lib/gliner/inference/session_validator.rb

Defined Under Namespace

Classes: IOValidation, Request, SessionValidator

Constant Summary

TASK_TYPE_ENTITIES = 0
TASK_TYPE_CLASSIFICATION = 1
TASK_TYPE_JSON = 2
SCHEMA_PREFIX_LENGTH = 4
LABEL_SPACING = 2

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(session) ⇒ Inference

Returns a new instance of Inference.



(lib/gliner/inference.rb, lines 29–38)
# File 'lib/gliner/inference.rb', line 29

# Builds an Inference around +session+ and caches the I/O metadata that
# SessionValidator reports for it.
#
# @param session [Object] handed to SessionValidator[] here, and later sent
#   #run(output_names, inputs) by Inference#run.
#   NOTE(review): presumably a model runtime session — confirm with callers.
# Any error raised by SessionValidator[] or by its readers propagates; there
# is no rescue here.
def initialize(session)
  @session = session

  # Copy every validated attribute onto the instance so later calls do not
  # have to re-run validation.
  SessionValidator[session].tap do |report|
    @input_names = report.input_names
    @output_name = report.output_name
    @label_index_mode = report.label_index_mode
    @has_cls_logits = report.has_cls_logits
  end
end

Instance Attribute Details

#has_cls_logits ⇒ Object (readonly)

Returns the value of attribute has_cls_logits.



(lib/gliner/inference.rb, lines 27–29)
# File 'lib/gliner/inference.rb', line 27

# Read-only accessor for the has_cls_logits value that SessionValidator
# reported when this Inference was constructed.
# NOTE(review): presumably truthy when the session exposes CLS logits —
# confirm against SessionValidator.
#
# @return [Object] the cached value (nil if never assigned)
def has_cls_logits
  @has_cls_logits
end

#label_index_mode ⇒ Object (readonly)

Returns the value of attribute label_index_mode.



(lib/gliner/inference.rb, lines 27–29)
# File 'lib/gliner/inference.rb', line 27

# Read-only accessor for the label-indexing mode that SessionValidator
# reported at construction time.
#
# #label_logit compares this value against :label_position. Any other value
# makes label indices address the logits columns directly.
#
# @return [Object] the cached mode (nil if never assigned)
def label_index_mode
  @label_index_mode
end

Instance Method Details

#label_logit(logits, pos, width, label_index, label_positions) ⇒ Object



(lib/gliner/inference.rb, lines 57–66)
# File 'lib/gliner/inference.rb', line 57

# Reads one score from a nested logits array at
# logits[0][pos][width][column].
#
# The column depends on the label index mode:
# - When label_index_mode is :label_position, label_index is mapped through
#   label_positions to obtain the column.
# - Otherwise label_index is used as the column as-is.
#
# @param logits [Array] nested array; only element 0 of the outermost level
#   is read (NOTE(review): presumably a batch dimension — confirm)
# @param pos [Integer] second-level index (presumably the span start)
# @param width [Integer] third-level index (presumably the span width)
# @param label_index [Integer] which label to read
# @param label_positions [Array, nil] column for each label. Required in
#   :label_position mode and ignored otherwise.
# @return [Object] the element found at that position
# @raise [Error] in :label_position mode when label_positions is nil
# @raise [IndexError] in :label_position mode when label_positions has no
#   entry for label_index (raised by Array#fetch)
def label_logit(logits, pos, width, label_index, label_positions)
  column = label_index

  if @label_index_mode == :label_position
    raise Error, 'Label positions required for span_logits output' if label_positions.nil?

    # fetch (not []) so a missing label fails loudly instead of reading nil.
    column = label_positions.fetch(label_index)
  end

  logits[0][pos][width][column]
end

#label_positions_for(word_ids, label_count) ⇒ Object



(lib/gliner/inference.rb, lines 46–55)
# File 'lib/gliner/inference.rb', line 46

# Locates the token position of each label in word_ids.
#
# Label i is expected at combined index
# SCHEMA_PREFIX_LENGTH + i * LABEL_SPACING. With the class constants
# (4 and 2) that is 4, 6, 8, and so on. The first token whose word id
# equals that combined index is taken.
#
# @param word_ids [Array] per-token word ids
#   (NOTE(review): presumably tokenizer output — confirm with callers)
# @param label_count [Integer] number of labels to locate
# @return [Array<Integer>] one index into word_ids per label, in label order
# @raise [Error] when some label's combined index does not occur in word_ids
def label_positions_for(word_ids, label_count)
  label_count.times.map do |label_idx|
    combined_idx = SCHEMA_PREFIX_LENGTH + (label_idx * LABEL_SPACING)

    # Array#index returns nil when absent; any Integer (including 0) is truthy.
    word_ids.index(combined_idx) ||
      raise(Error, "Could not locate label position at combined index #{combined_idx}")
  end
end

#run(request) ⇒ Object



(lib/gliner/inference.rb, lines 40–44)
# File 'lib/gliner/inference.rb', line 40

# Executes the wrapped session for one request and post-processes the result.
#
# The steps run in this order:
# 1. output_names_for(request) chooses which outputs to fetch.
# 2. build_inputs(request) builds the input feed.
# 3. @session.run(names, inputs) produces the raw outputs.
# 4. format_outputs(raw, names) shapes the final value.
#
# @param request [Object] forwarded to output_names_for and build_inputs
#   (NOTE(review): presumably an Inference::Request — confirm)
# @return [Object] whatever format_outputs returns
def run(request)
  requested_outputs = output_names_for(request)
  feed = build_inputs(request)
  raw_outputs = @session.run(requested_outputs, feed)

  format_outputs(raw_outputs, requested_outputs)
end

#sigmoid(value) ⇒ Object



(lib/gliner/inference.rb, lines 68–70)
# File 'lib/gliner/inference.rb', line 68

# Logistic function: 1 / (1 + e^(-value)).
#
# Ruby's Math.exp returns Infinity on overflow instead of raising. Very
# negative inputs therefore yield 0.0 and very positive inputs yield 1.0.
#
# @param value [Numeric]
# @return [Float] a value between 0.0 and 1.0 for finite input
def sigmoid(value)
  denominator = 1.0 + Math.exp(-value)
  1.0 / denominator
end

#softmax(values) ⇒ Object



(lib/gliner/inference.rb, lines 72–77)
# File 'lib/gliner/inference.rb', line 72

# Softmax over a list of scores.
#
# The maximum is subtracted before exponentiating, so every exponent is at
# most 0 for finite input and Math.exp cannot overflow. The resulting weights
# are then divided by their sum.
#
# @param values [Array<Numeric>]
# @return [Array<Float>] same length as values; [] for an empty input
def softmax(values)
  peak = values.max
  weights = values.map { |v| Math.exp(v - peak) }
  total = weights.sum
  weights.map { |w| w / total }
end