Class: Informers::QuestionAnswering

Inherits:
Object
  • Object
show all
Defined in:
lib/informers/question_answering.rb

Instance Method Summary collapse

Constructor Details

#initialize(model_path) ⇒ QuestionAnswering

Returns a new instance of QuestionAnswering.



18
19
20
21
22
23
24
25
# File 'lib/informers/question_answering.rb', line 18

# Loads the question-answering pipeline: the bundled BERT tokenizer plus
# the ONNX model at +model_path+.
#
# model_path - String path to the ONNX question-answering model file.
def initialize(model_path)
  # Numo is required lazily so the dependency is only loaded when this
  # pipeline is actually used.
  require "numo/narray"

  vendor_dir = File.expand_path("../../vendor", __dir__)
  @tokenizer = BlingFire.load_model(File.join(vendor_dir, "bert_base_cased_tok.bin"))
  @model = OnnxRuntime::Model.new(model_path)
end

Instance Method Details

#predict(questions) ⇒ Object



27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# File 'lib/informers/question_answering.rb', line 27

# Runs extractive question answering.
#
# questions - a Hash with :question and :context String keys, or an Array
#             of such hashes. When a single hash is given, a single answer
#             hash is returned instead of an array.
#
# Returns hash(es) with :answer (String), :score (Float), :start, and
# :end (character offsets into the context).
#
# Raises RuntimeError when a tokenized input exceeds 384 tokens.
def predict(questions)
  singular = !questions.is_a?(Array)
  questions = [questions] if singular

  topk = 1
  max_answer_len = 15

  sep_pos = []          # per example: number of question tokens
  cls_pos = []          # per example: token count before CLS is prepended
  context_offsets = []  # per example: character offset of each context token

  # tokenize each example as: [CLS] question [SEP] context [SEP]
  input_ids =
    questions.map do |question|
      tokens = @tokenizer.text_to_ids(question[:question], nil, 100) # unk token
      sep_pos << tokens.size
      tokens << 102 # sep token
      context_tokens, offsets = @tokenizer.text_to_ids_with_offsets(question[:context], nil, 100) # unk token
      tokens.concat(context_tokens)
      context_offsets << offsets
      cls_pos << tokens.size
      tokens.unshift(101) # cls token
      tokens << 102 # sep token
      tokens
    end

  max_tokens = 384
  raise "Large text not supported yet" if input_ids.map(&:size).max > max_tokens

  # right-pad every sequence to max_tokens and build the matching attention mask
  attention_mask = []
  input_ids.each do |ids|
    zeros = [0] * (max_tokens - ids.size)

    mask = ([1] * ids.size) + zeros
    attention_mask << mask

    ids.concat(zeros)
  end

  # infer
  input = {
    input_ids: input_ids,
    attention_mask: attention_mask
  }
  output = @model.predict(input)

  start = output["output_0"] # start logits, one row per example
  stop = output["output_1"]  # end logits, one row per example

  # transform
  answers = []
  start.zip(stop).each_with_index do |(start_, end_), i|
    start_ = Numo::DFloat.cast(start_)
    end_ = Numo::DFloat.cast(end_)

    # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
    feature_p_mask = Numo::Int64.new(max_tokens).fill(0)
    # question tokens plus the first SEP occupy indexes 1..sep_pos[i] + 1
    feature_p_mask[1..sep_pos[i] + 1] = 1
    # the trailing SEP sits at cls_pos[i] + 1 once CLS has been prepended.
    # Fix: was `feature_p_mask[cls_pos[i][1]] = 1`, which calls Integer#[]
    # (bit reference) on the token count and masks the wrong index.
    feature_p_mask[cls_pos[i] + 1] = 1
    feature_attention_mask = Numo::Int64.cast(attention_mask[i])
    undesired_tokens = (feature_p_mask - 1).abs & feature_attention_mask

    # Generate mask
    undesired_tokens_mask = undesired_tokens.eq(0)

    # Make sure non-context indexes in the tensor cannot contribute to the softmax
    start_[undesired_tokens_mask] = -10000
    end_[undesired_tokens_mask] = -10000

    # Normalize logits and spans to retrieve the answer (log-sum-exp softmax)
    start_ = Numo::DFloat::Math.exp(start_ - Numo::DFloat::Math.log(Numo::DFloat::Math.exp(start_).sum(axis: -1)))
    end_ = Numo::DFloat::Math.exp(end_ - Numo::DFloat::Math.log(Numo::DFloat::Math.exp(end_).sum(axis: -1)))

    # Mask CLS
    start_[0] = end_[0] = 0.0

    starts, ends, scores = decode(start_, end_, topk, max_answer_len)

    # char_to_word: map each character of the context to its word index
    doc_tokens, char_to_word_offset = send(:doc_tokens, questions[i][:context])
    char_to_word = Numo::Int64.cast(char_to_word_offset)

    # token_to_orig_map: map model token positions back to word indexes;
    # context tokens start right after [CLS] + question + [SEP]
    token_to_orig_map = {}
    map_pos = sep_pos[i] + 2
    context_offsets[i].each do |offset|
      token_to_orig_map[map_pos] = char_to_word_offset[offset]
      map_pos += 1
    end

    # Convert the answer (tokens) back to the original text
    starts.to_a.zip(ends.to_a, scores) do |s, e, score|
      answers << {
        answer: doc_tokens[token_to_orig_map[s]..token_to_orig_map[e]].join(" "),
        score: score,
        start: (char_to_word.eq(token_to_orig_map[s])).where[0],
        end: (char_to_word.eq(token_to_orig_map[e])).where[-1]
      }
    end
  end

  singular ? answers.first : answers
end