Class: MaskedLMModel

Inherits:
HuggingfaceModel
Defined in:
lib/rbbt/vector/model/huggingface/masked_lm.rb

Instance Attribute Summary

Attributes inherited from TorchModel

#model

Attributes inherited from VectorModel

#balance, #bar, #directory, #eval_model, #extract_features, #factor_levels, #features, #init_model, #labels, #model, #model_options, #model_path, #names, #post_process, #train_model

Instance Method Summary

Methods inherited from HuggingfaceModel

#reset_model, tsv_dataset

Methods inherited from TorchModel

freeze, freeze_layer, get_layer, get_weights

Methods inherited from VectorModel

R_eval, R_run, R_train, #__load_method, #add, #add_list, #balance_labels, #clear, #cross_validation, #eval, #eval_list, f1_metrics, #init, #run, #save_models, #train

Constructor Details

#initialize(checkpoint, dir = nil, model_options = {}) ⇒ MaskedLMModel

Returns a new instance of MaskedLMModel. The constructor sets a default :max_length of 128 and installs a train_model block that converts each text containing [MASK] markers, together with its label tokens, into a masked language modeling training example, then delegates training to the Python side through RbbtPython.
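
A minimal usage sketch (the checkpoint, directory, texts, and labels below are hypothetical; #add and #train are inherited from VectorModel, and the label layout, one array of label tokens per text with one entry per [MASK] occurrence, follows the train_model block shown below):

  require 'rbbt/vector/model/huggingface/masked_lm'

  # Hypothetical checkpoint and directory; any HuggingFace
  # masked-LM checkpoint should work here
  model = MaskedLMModel.new "bert-base-uncased", "/tmp/masked_lm",
    :max_length => 128

  # One label token per [MASK] occurrence in each text
  # (assuming #add takes the text and its list of labels)
  model.add "Paris is the capital of [MASK].", ["france"]
  model.add "The [MASK] rises in the east.", ["sun"]

  model.train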



# File 'lib/rbbt/vector/model/huggingface/masked_lm.rb', line 4

def initialize(checkpoint, dir = nil, model_options = {})
  model_options = Misc.add_defaults model_options, :max_length => 128
  super("MaskedLM", checkpoint, dir, model_options)

  train_model do |texts,labels|
    model, tokenizer = self.init
    max_length = @model_options[:max_length]
    mask_id = tokenizer.mask_token_id

    dataset = []
    texts.zip(labels).each do |text,label_values|
      # Protect the original [MASK] markers so each one can be
      # expanded individually to match the token length of its label
      fixed_text = text.gsub("[MASK]", "[PENDINGMASK]")
      label_tokens = label_values.collect{|label| tokenizer.convert_tokens_to_ids(label) }

      # A label that tokenizes to N ids needs N consecutive [MASK] tokens
      label_tokens.each do |ids|
        ids = [ids] unless Array === ids
        fixed_text.sub!("[PENDINGMASK]", "[MASK]" * ids.length)
      end

      # Pass max_length explicitly so padding and truncation honor the
      # configured :max_length (previously assigned but never used)
      tokenized_text = tokenizer.call(fixed_text, truncation: true, padding: "max_length", max_length: max_length)
      input_ids = tokenized_text["input_ids"].to_a
      attention_mask = tokenized_text["attention_mask"].to_a

      # Align labels with mask positions; -100 is the ignore index
      # that HuggingFace loss functions skip
      all_label_tokens = label_tokens.flatten
      label_ids = input_ids.collect do |id|
        if id == mask_id
          all_label_tokens.shift
        else
          -100
        end
      end
      dataset << {input_ids: input_ids, labels: label_ids, attention_mask: attention_mask}
    end

    # Serialize the dataset as one JSON object per line
    dataset_file = File.join(@directory, 'dataset.json')
    Open.write(dataset_file, dataset.collect{|e| e.to_json} * "\n")

    # Delegate training to the Python side through RbbtPython
    training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, @model_path, @model_options[:training_args])
    data_collator = RbbtPython.class_new_obj("transformers", "DefaultDataCollator", {})
    RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights], data_collator: data_collator)

    model.save_pretrained(@model_path) if @model_path
    tokenizer.save_pretrained(@model_path) if @model_path
  end

end
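
Each dataset row pairs the tokenized input with a labels array of the same length: positions whose input id equals tokenizer.mask_token_id receive the next label token id, and every other position receives -100, the ignore index that HuggingFace loss functions skip, so loss is computed only at the masked positions. Labels that tokenize to several ids are handled by expanding their [MASK] marker into that many consecutive masks via the [PENDINGMASK] placeholder. A sketch of the resulting row for "Paris is the capital of [MASK]." (token ids are illustrative, not actual vocabulary ids):

  { :input_ids      => [101, 3000, 2003, 1996, 3007, 1997, 103, 1012, 102, 0, ...],
    :labels         => [-100, -100, -100, -100, -100, -100, 2605, -100, -100, -100, ...],
    :attention_mask => [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, ...] }

The trailing zeros come from padding to :max_length; they are excluded from both attention and loss.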