Class: Rababa::Diacritizer

Inherits:
Object
  • Object
show all
Includes:
Harakats, Reconcile
Defined in:
lib/rababa/diacritizer.rb

Instance Method Summary collapse

Methods included from Reconcile

#build_pivot_map, #reconcile_strings

Methods included from Harakats

#basic_cleaners, #collapse_whitespace, #extract_haraqat, #extract_stack, #remove_diacritics, #valid_arabic_cleaners

Constructor Details

#initialize(onnx_model_path, config) ⇒ Diacritizer

Returns a new instance of Diacritizer.



19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# File 'lib/rababa/diacritizer.rb', line 19

# Set up a diacritizer from an ONNX model file and a configuration object.
#
# @param onnx_model_path path handed to OnnxRuntime::InferenceSession.new
# @param config hash-like settings; 'max_len' and 'batch_size' are read here,
#   the rest is consulted by get_text_encoder
def initialize(onnx_model_path, config)
  # Inference session backed by the model file.
  @onnx_session = OnnxRuntime::InferenceSession.new(onnx_model_path)

  # Keep the whole configuration plus the two values used on every call.
  @config = config
  @max_length = config['max_len']
  @batch_size = config['batch_size']

  # The encoder is selected from @config, so it is built after the assignment above.
  @encoder = get_text_encoder
  @start_symbol_id = @encoder.start_symbol_id
end

Instance Method Details

#combine_text_and_haraqat(vec_txt, vec_haraqat, encoding_mode = 'std') ⇒ Object

Combine encoded text with its predicted haraqat to produce diacritized Arabic text.



120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# File 'lib/rababa/diacritizer.rb', line 120

# Merge an encoded text vector with its predicted haraqat into one string.
#
# Both vectors are walked in parallel. Each pair is looked up in the encoder's
# id-to-symbol tables and appended as "<character><haraqa>". The walk stops at
# the first element of vec_txt equal to @encoder.input_pad_id, so trailing
# padding is not rendered.
#
# @param vec_txt [Array] input ids (keys of @encoder.input_id_to_symbol)
# @param vec_haraqat [Array] target ids, same length as vec_txt
# @param encoding_mode [String] 'std' or 'escaped unicode'
# @return [String] the combined text
# @raise [Exception] when the two vectors differ in length
# @raise [ArgumentError] when a symbol has to be rendered with an unknown encoding_mode
def combine_text_and_haraqat(vec_txt, vec_haraqat, encoding_mode = 'std')
  if vec_txt.length != vec_haraqat.length
    # Exception class kept for callers that rescue it; the message is now a
    # single clean line (it used to embed a backslash, a newline and indentation).
    raise Exception.new('haraqat.len != txt.len in Harakats::combine_text_and_haraqat')
  end

  pad_id = @encoder.input_pad_id
  text = +''

  # Iterating over the pairs (instead of a manual counter) renders the final
  # symbol of an unpadded vector and terminates on empty input.
  vec_txt.zip(vec_haraqat).each do |txt, haraq|
    break if txt == pad_id

    case encoding_mode
    when 'std'
      text << @encoder.input_id_to_symbol[txt].to_s
      text << @encoder.target_id_to_symbol[haraq].to_s
    when 'escaped unicode'
      # NOTE(review): @utarget_symbol_to_id is not assigned by any method
      # visible in this file; confirm where it is set before using this mode.
      text << @encoder.input_id_to_symbol[txt].to_s
      text << @utarget_symbol_to_id.utarget_id_to_symbol[haraq].to_s
    else
      raise ArgumentError, "unknown encoding_mode: #{encoding_mode.inspect}"
    end
  end

  text
end

#diacritize_file(path) ⇒ Object

Read the text file at the given path and diacritize it line by line.



74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# File 'lib/rababa/diacritizer.rb', line 74

# Diacritize every line of a text file.
#
# Lines are stripped, then grouped into chunks of @batch_size. Each full chunk
# goes through predict_batch in one call; a final, shorter chunk is handled
# line by line with diacritize_text.
#
# @param path [String] path of the file to read
# @return [Array<String>] one diacritized string per input line, in file order
# @raise [ArgumentError] from each_slice when @batch_size is not positive
def diacritize_file(path)
  # File.foreach closes the file once the enumeration is finished
  # (File.open(path).each left the handle open).
  texts = File.foreach(path).map(&:strip)

  texts.each_slice(@batch_size).flat_map do |originals|
    # Only the last slice can be short: process that remainder one by one.
    next originals.map { |t| diacritize_text(t) } if originals.length < @batch_size

    src = originals.map { |t| preprocess_text(t) }
    lengths = src.map(&:length)
    preds = predict_batch({ 'src' => src, 'lengths' => lengths })

    originals.each_index.map do |i|
      reconcile_strings(originals[i],
                        combine_text_and_haraqat(src[i], preds[i]))
    end
  end
end

#diacritize_text(text) ⇒ Object

Diacritize a single Arabic string.



55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# File 'lib/rababa/diacritizer.rb', line 55

# Diacritize a single Arabic string.
#
# The model is invoked through predict_batch, so the encoded sequence is
# repeated @batch_size times to fill a batch, and only the first prediction
# is used.
#
# @param text [String] text to diacritize
# @return whatever reconcile_strings returns for the stripped text and the
#   combined prediction
def diacritize_text(text)
  text = text.strip
  seq = preprocess_text(text)

  # Redundant copies of the same sequence, required because the network is
  # run on whole batches.
  # NOTE(review): preprocess_text pads seq to @max_length, so 'lengths'
  # carries the padded length - confirm that is what the model expects.
  ort_inputs = {
    'src' => Array.new(@batch_size, seq),
    'lengths' => Array.new(@batch_size, seq.length)
  }

  # Every row of the batch is identical; the first prediction is enough.
  preds = predict_batch(ort_inputs)[0]

  reconcile_strings(text, combine_text_and_haraqat(seq, preds))
end

#get_text_encoder ⇒ Object

Initialise text encoder from config params



149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# File 'lib/rababa/diacritizer.rb', line 149

# Build the text encoder named in the configuration.
#
# Reads @config['text_cleaner'] (must be 'basic_cleaners',
# 'valid_arabic_cleaners' or nil) and @config['text_encoder']
# ('BasicArabicEncoder' or 'ArabicEncoderWithStartSymbol'); the cleaner name
# is passed to the chosen encoder's constructor.
#
# @return the newly created encoder instance
# @raise [Exception] when the cleaner or the encoder name is not recognised
def get_text_encoder
  cleaner = @config['text_cleaner']
  known_cleaners = ['basic_cleaners', 'valid_arabic_cleaners', nil]
  unless known_cleaners.include?(cleaner)
    raise Exception.new('cleaner is not known: ' + cleaner.to_s)
  end

  case @config['text_encoder']
  when 'BasicArabicEncoder'
    Encoders::BasicArabicEncoder.new(cleaner)
  when 'ArabicEncoderWithStartSymbol'
    Encoders::ArabicEncoderWithStartSymbol.new(cleaner)
  else
    raise Exception.new('the text encoder is not found: ' + @config['text_encoder'].to_s)
  end
end

#predict_batch(batch_data) ⇒ Object

Call ONNX model with data transformed in batches



111
112
113
114
115
116
117
# File 'lib/rababa/diacritizer.rb', line 111

# Run the ONNX session on one batch and reduce its scores to class indices.
#
# The first output of the session is treated as nested rows of scores: for
# each innermost row the index of its maximum is kept. Pairs of
# [score, index] are compared, so equal scores resolve to the highest index.
#
# @param batch_data input passed unchanged to @onnx_session.run
# @return [Array<Array<Integer>>] one list of winning indices per batch entry
def predict_batch(batch_data)
  outputs = @onnx_session.run(nil, batch_data)
  outputs[0].map do |sequence|
    sequence.map { |scores| scores.each_with_index.max.last }
  end
end

#preprocess_text(text) ⇒ Object

Preprocess text into a zero-padded sequence of input indices.



36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# File 'lib/rababa/diacritizer.rb', line 36

# Turn raw text into a sequence of input ids, padded with zeros.
#
# Text longer than @max_length is cut to @max_length characters (with a
# warning on stderr) rather than rejected, then cleaned by the encoder,
# stripped of diacritics and encoded. The result is right-padded with 0.
#
# @param text [String] raw input text
# @return [Array] encoded sequence followed by zero padding up to @max_length
def preprocess_text(text)
  # Hack in the absence of real preprocessing: truncate instead of raising.
  if text.length > @max_length
    # Keep exactly @max_length characters; the inclusive range 0..@max_length
    # kept one too many, which made the padding count below negative.
    text = text[0, @max_length]
    # Double quotes so the limit is interpolated into the message.
    warn("WARNING:: string cut length > #{@max_length},")
    warn('text:: ' + text)
  end

  text = @encoder.clean(text)
  text = remove_diacritics(text)
  seq = @encoder.input_to_sequence(text)

  # Pad with 0's to the length the model expects.
  # NOTE(review): if the encoder emits more ids than characters (e.g. an added
  # start symbol - TODO confirm), seq can still exceed @max_length here.
  seq + [0] * (@max_length - seq.length)
end