Class: Gliner::Inference::SessionValidator
- Inherits: Object
  - Object
    - Gliner::Inference::SessionValidator
- Defined in: lib/gliner/inference/session_validator.rb
Constant Summary
- EXPECTED_INPUTS_LOGITS =
%w[ input_ids attention_mask words_mask text_lengths task_type label_positions label_mask ].freeze
- EXPECTED_INPUTS_SPAN_LOGITS =
%w[ input_ids attention_mask ].freeze
Class Method Summary
- .[](session) ⇒ Object
- .call(session) ⇒ Object
Class Method Details
.[](session) ⇒ Object
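Shorthand for `.call(session)`: `SessionValidator[session]` returns the same result as `SessionValidator.call(session)`. Source: line 22.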
# File 'lib/gliner/inference/session_validator.rb', line 22 def [](session) = call(session) |
.call(session) ⇒ Object
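The method works in three steps:
1. It collects the input names from `session.inputs` and the output names from `session.outputs`.
2. It passes both name lists to `validation_for_outputs` to obtain the output name and the label-index mode.
3. It returns an `IOValidation` built from the input names, that output name, that label-index mode, and a flag recording whether the session exposes a `cls_logits` output.

Source: lines 24–37.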
# File 'lib/gliner/inference/session_validator.rb', line 24 def call(session) input_names = session.inputs.map { |input| input[:name] } output_names = session.outputs.map { |output| output[:name] } has_cls_logits = output_names.include?('cls_logits') validation = validation_for_outputs(output_names, input_names) IOValidation.new( input_names: input_names, output_name: validation.fetch(:output_name), label_index_mode: validation.fetch(:label_index_mode), has_cls_logits: has_cls_logits ) end |