Class: Gliner::Inference::SessionValidator

Inherits:
Object
  • Object
show all
Defined in:
lib/gliner/inference/session_validator.rb

Constant Summary

# Frozen list of seven input names.
# NOTE(review): presumably the inputs required of a session whose output is
# `logits` — confirm against the code that consumes this constant.
EXPECTED_INPUTS_LOGITS = %w[
  input_ids attention_mask words_mask text_lengths
  task_type label_positions label_mask
].freeze
# Frozen list of two input names.
# NOTE(review): presumably the inputs required of a session whose output is
# `span_logits` — confirm against the code that consumes this constant.
EXPECTED_INPUTS_SPAN_LOGITS = %w[input_ids attention_mask].freeze

Class Method Summary

Class Method Details

.[](session) ⇒ Object



22
# File 'lib/gliner/inference/session_validator.rb', line 22

# Bracket shorthand for #call: forwards +session+ unchanged and returns
# whatever #call returns.
#
# @param session the object handed straight to #call
# @return [Object] the result of call(session)
def [](session)
  call(session)
end

.call(session) ⇒ Object



24
25
26
27
28
29
30
31
32
33
34
35
36
37
# File 'lib/gliner/inference/session_validator.rb', line 24

# Collects the names of the session's inputs and outputs and wraps the
# validation outcome in an IOValidation.
#
# The :name entry of every element of session.inputs and session.outputs is
# read, both name lists are handed to validation_for_outputs, and the
# :output_name / :label_index_mode values of its result are fetched.
#
# @param session [#inputs, #outputs] object whose inputs and outputs are
#   enumerable and whose elements can be indexed with :name
# @return [IOValidation] built from the input names, the validated output
#   name, the label index mode, and whether a 'cls_logits' output is present
# NOTE(review): validation_for_outputs presumably returns a Hash, in which
# case a missing key surfaces as KeyError via #fetch — confirm.
def call(session)
  input_names = session.inputs.map { |entry| entry[:name] }
  output_names = session.outputs.map { |entry| entry[:name] }
  cls_logits_present = output_names.include?('cls_logits')

  checked = validation_for_outputs(output_names, input_names)
  output_name = checked.fetch(:output_name)
  label_index_mode = checked.fetch(:label_index_mode)

  IOValidation.new(
    input_names: input_names,
    output_name: output_name,
    label_index_mode: label_index_mode,
    has_cls_logits: cls_logits_present
  )
end