Class: Gliner::Inference
- Inherits: Object
- Defined in: lib/gliner/inference.rb, lib/gliner/inference/session_validator.rb
Defined Under Namespace
Classes: IOValidation, Request, SessionValidator
Constant Summary
- TASK_TYPE_ENTITIES = 0
- TASK_TYPE_CLASSIFICATION = 1
- TASK_TYPE_JSON = 2
- SCHEMA_PREFIX_LENGTH = 4
- LABEL_SPACING = 2
Instance Attribute Summary
Instance Method Summary
Constructor Details
#initialize(session) ⇒ Inference
Returns a new instance of Inference.
(source lines 29-38)
|
# File 'lib/gliner/inference.rb', line 29
# Builds an inference wrapper around +session+.
#
# The session is passed through SessionValidator. The layout details it
# reports are copied onto the instance: input names, output name,
# label-index mode, and the has_cls_logits flag.
#
# @param session [Object] the model session. It is stored as-is; #run later
#   calls session.run(output_names, inputs) on it.
#   NOTE(review): presumably an ONNX Runtime inference session — confirm.
# NOTE(review): any validation errors come from SessionValidator[]; its
#   checks are not visible here.
def initialize(session)
  @session = session
  report = SessionValidator[session]

  # Copy the validator's findings so later calls can read them directly.
  @input_names = report.input_names
  @output_name = report.output_name
  @label_index_mode = report.label_index_mode
  @has_cls_logits = report.has_cls_logits
end
|
Instance Attribute Details
#has_cls_logits ⇒ Object
Returns the value of attribute has_cls_logits.
(source lines 27-29)
|
# File 'lib/gliner/inference.rb', line 27
# Reader for the has_cls_logits flag that SessionValidator reported when
# this instance was constructed.
#
# @return [Object] the stored @has_cls_logits value (nil if never set)
def has_cls_logits; @has_cls_logits; end
|
#label_index_mode ⇒ Object
Returns the value of attribute label_index_mode.
(source lines 27-29)
|
# File 'lib/gliner/inference.rb', line 27
# Reader for the label-index mode that SessionValidator reported when this
# instance was constructed. #label_logit compares it against :label_position.
#
# @return [Object] the stored @label_index_mode value (nil if never set)
def label_index_mode; @label_index_mode; end
|
Instance Method Details
#label_logit(logits, pos, width, label_index, label_positions) ⇒ Object
(source lines 57-66)
|
# File 'lib/gliner/inference.rb', line 57
# Reads the span logit for one label.
#
# The value is read as logits[0][pos][width][column]. In :label_position
# mode the column is looked up in +label_positions+ (presumably the result
# of #label_positions_for — verify against callers). In every other mode
# +label_index+ is used directly as the column.
#
# @param logits [#[]] nested indexable model output
# @param pos [Integer] span start index
# @param width [Integer] span width index
# @param label_index [Integer] index of the label being scored
# @param label_positions [Array<Integer>, nil] label columns; required
#   only in :label_position mode
# @return [Object] the selected logit element
# @raise [Error] in :label_position mode when +label_positions+ is nil
# @raise [IndexError] in :label_position mode when +label_index+ is out of
#   range for +label_positions+ (raised by Array#fetch)
def label_logit(logits, pos, width, label_index, label_positions)
  column =
    if @label_index_mode == :label_position
      raise Error, 'Label positions required for span_logits output' if label_positions.nil?

      label_positions.fetch(label_index)
    else
      label_index
    end

  logits[0][pos][width][column]
end
|
#label_positions_for(word_ids, label_count) ⇒ Object
(source lines 46-55)
|
# File 'lib/gliner/inference.rb', line 46
# Finds where each label sits in the tokenized prompt.
#
# Label +i+ is expected at the combined index
# SCHEMA_PREFIX_LENGTH + i * LABEL_SPACING (4 + 2i with the constants
# listed in the summary). Its position is the first index in +word_ids+
# whose value equals that combined index.
#
# @param word_ids [Array] per-token word ids (may contain nil)
# @param label_count [Integer] number of labels to locate
# @return [Array<Integer>] one token position per label, in label order
# @raise [Error] when a label's combined index does not occur in +word_ids+
def label_positions_for(word_ids, label_count)
  label_count.times.map do |label_idx|
    target = SCHEMA_PREFIX_LENGTH + label_idx * LABEL_SPACING
    found = word_ids.index(target)
    next found unless found.nil?

    raise Error, "Could not locate label position at combined index #{target}"
  end
end
|
#run(request) ⇒ Object
(source lines 40-44)
|
# File 'lib/gliner/inference.rb', line 40
# Runs the model session for a single request.
#
# Steps, in order:
# 1. Ask output_names_for which outputs to fetch.
# 2. Build the input feed with build_inputs.
# 3. Call session.run with the output names and the feed.
# 4. Pass the raw result and the same output names to format_outputs.
#
# @param request [Object] the inference request (presumably a
#   Gliner::Inference::Request — confirm)
# @return [Object] whatever format_outputs returns
def run(request)
  requested = output_names_for(request)
  feeds = build_inputs(request)
  raw = @session.run(requested, feeds)
  format_outputs(raw, requested)
end
|
#sigmoid(value) ⇒ Object
(source lines 68-70)
|
# File 'lib/gliner/inference.rb', line 68
# Logistic sigmoid: 1 / (1 + e^-value).
#
# For very negative inputs Math.exp overflows to Infinity, so the result
# saturates to 0.0 instead of raising. Very positive inputs give 1.0.
#
# @param value [Numeric] the logit
# @return [Float] a value in [0.0, 1.0]
def sigmoid(value)
  denominator = 1.0 + Math.exp(-value)
  1.0 / denominator
end
|
#softmax(values) ⇒ Object
(source lines 72-77)
|
# File 'lib/gliner/inference.rb', line 72
# Softmax over a list of scores.
#
# The maximum is subtracted before exponentiating, so large scores do not
# overflow Math.exp. An empty list yields an empty list.
# NOTE(review): if the maximum is +/-Infinity, the shifted values include
# NaN (Infinity - Infinity) and so does the result.
#
# @param values [Array<Numeric>] raw scores
# @return [Array<Float>] probabilities in the same order, summing to ~1.0
def softmax(values)
  shift = values.max
  weights = values.map { |score| Math.exp(score - shift) }
  total = weights.sum
  weights.map { |weight| weight / total }
end
|