Class: NerRuby::Models::Onnx

Inherits:
Base
  • Object
show all
Defined in:
lib/ner_ruby/models/onnx.rb

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(model_path:) ⇒ Onnx

Returns a new instance of Onnx.

Raises:

  • ModelNotFoundError — if nothing exists at model_path.



10
11
12
13
14
15
16
17
# File 'lib/ner_ruby/models/onnx.rb', line 10

# Builds an ONNX-backed model from a model file on disk.
#
# @param model_path [String] path to the ONNX model
# @raise [ModelNotFoundError] if nothing exists at +model_path+
def initialize(model_path:)
  # Required here rather than at file load time, presumably so that
  # onnx_ruby stays an optional dependency of the gem — confirm.
  require "onnx_ruby"

  @model_path = model_path
  unless File.exist?(@model_path)
    raise ModelNotFoundError, "Model not found: #{model_path}"
  end

  @session = OnnxRuby::Session.new(@model_path)
  # Helper defined elsewhere in this class; its result backs #label_map.
  @label_map = load_config_label_map
end

Instance Attribute Details

#label_map ⇒ Object (readonly)

Returns the value of attribute label_map.



8
9
10
# File 'lib/ner_ruby/models/onnx.rb', line 8

# Label map produced by +load_config_label_map+ during construction.
# Presumably maps model output indices to entity label names — TODO confirm.
#
# @return [Object] the stored label map (nil if never assigned)
def label_map
  instance_variable_get(:@label_map)
end

Instance Method Details

#predict(input_ids) ⇒ Object



19
20
21
22
23
24
25
26
27
28
29
30
# File 'lib/ner_ruby/models/onnx.rb', line 19

# Runs one token-id sequence through the ONNX session as a batch of size one.
#
# Every position is marked as attended (attention mask of all 1s), and every
# token is assigned segment 0, so the sequence is treated as a single
# unpadded segment.
#
# @param input_ids [Array<Integer>] token ids for one sequence (presumably
#   already tokenized, including any special tokens — confirm against caller)
# @return [Object] element 0 of the session's first output, i.e. the entry
#   for the single batch item (presumably per-token logits — TODO confirm)
def predict(input_ids)
  seq_len = input_ids.length
  all_attended = [1] * seq_len
  single_segment = [0] * seq_len

  batch_outputs = @session.run(
    input_ids: [input_ids],
    attention_mask: [all_attended],
    token_type_ids: [single_segment]
  )

  first_output = batch_outputs[0]
  first_output[0]
end