Class: HuggingfaceModel
- Inherits: TorchModel
- Inheritance chain: Object → VectorModel → TorchModel → HuggingfaceModel
- Defined in:
- lib/rbbt/vector/model/huggingface.rb
Direct Known Subclasses
Instance Attribute Summary
Attributes inherited from TorchModel
Attributes inherited from VectorModel
#balance, #bar, #directory, #eval_model, #extract_features, #factor_levels, #features, #init_model, #labels, #model, #model_options, #model_path, #names, #post_process, #train_model
Class Method Summary
- .tsv_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil) ⇒ Object
Instance Method Summary
-
#initialize(task, checkpoint, dir = nil, model_options = {}) ⇒ HuggingfaceModel
constructor
A new instance of HuggingfaceModel.
- #reset_model ⇒ Object
Methods inherited from TorchModel
freeze, freeze_layer, get_layer, get_weights
Methods inherited from VectorModel
R_eval, R_run, R_train, #__load_method, #add, #add_list, #balance_labels, #clear, #cross_validation, #eval, #eval_list, f1_metrics, #init, #run, #save_models, #train
Constructor Details
#initialize(task, checkpoint, dir = nil, model_options = {}) ⇒ HuggingfaceModel
Returns a new instance of HuggingfaceModel.
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
# File 'lib/rbbt/vector/model/huggingface.rb', line 41

# Sets up a HuggingfaceModel by registering its init, eval, train and
# post-process blocks and, when a model path is set, saving the model
# definition.
#
# @param task [String] task name; the code branches on "SequenceClassification"
#   and "MaskedLM", any other value returns raw predictions from post-processing
# @param checkpoint [String] checkpoint handed to the Python loader unless
#   @model_path already is a directory, in which case that directory is used
# @param dir [String, nil] forwarded to the superclass constructor
# @param model_options [Hash] forwarded to the superclass; :task and :checkpoint
#   are added as defaults. Keys read here: :model_args, :tokenizer_args,
#   :locate_tokens, :training_args, :class_labels, :class_weights
# @return [HuggingfaceModel]
def initialize(task, checkpoint, dir = nil, model_options = {})
  super(dir, model_options)

  @model_options = Misc.add_defaults @model_options, :task => task, :checkpoint => checkpoint

  # Loads the model and tokenizer through the Python bridge.
  init_model do
    # Prefer a previously saved model directory over the original checkpoint.
    checkpoint = @model_path && File.directory?(@model_path) ? @model_path : @model_options[:checkpoint]

    model = RbbtPython.call_method("rbbt_dm.huggingface", :load_model,
                                   @model_options[:task], checkpoint,
                                   **(IndiferentHash.setup(model_options[:model_args]) || {}))

    tokenizer = RbbtPython.call_method("rbbt_dm.huggingface", :load_tokenizer,
                                       @model_options[:task], checkpoint,
                                       **(IndiferentHash.setup(model_options[:tokenizer_args]) || {}))

    [model, tokenizer]
  end

  # Lists (and every MaskedLM input) go through a TSV dataset and
  # predict_model; a single non-MaskedLM text goes straight to eval_model.
  eval_model do |texts, is_list|
    model, tokenizer = self.init

    if is_list || @model_options[:task] == "MaskedLM"
      texts = [texts] unless is_list

      if @model_options.include?(:locate_tokens)
        locate_tokens = @model_options[:locate_tokens]
      elsif @model_options[:task] == "MaskedLM"
        # Default to the tokenizer's mask token and remember it for post_process.
        @model_options[:locate_tokens] = locate_tokens = tokenizer.special_tokens_map["mask_token"]
      end

      if @directory
        tsv_file = File.join(@directory, 'dataset.tsv')
        checkpoint_dir = File.join(@directory, 'checkpoints')
      else
        tmpdir = TmpFile.tmp_file
        Open.mkdir tmpdir
        tsv_file = File.join(tmpdir, 'dataset.tsv')
        checkpoint_dir = File.join(tmpdir, 'checkpoints')
      end

      begin
        dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts)
        training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args,
                                                   checkpoint_dir, @model_options[:training_args])

        RbbtPython.call_method("rbbt_dm.huggingface", :predict_model,
                               model, tokenizer, training_args_obj, dataset_file, locate_tokens)
      ensure
        # Remove the scratch directory even when any step above fails.
        Open.rm_rf tmpdir if tmpdir
      end
    else
      # locate_tokens is only assigned in the branch above, so it is nil here.
      RbbtPython.call_method("rbbt_dm.huggingface", :eval_model, model, tokenizer, [texts], locate_tokens)
    end
  end

  # Writes texts and labels to a TSV dataset, trains through the Python bridge
  # and saves model and tokenizer when a model path is set.
  train_model do |texts, labels|
    model, tokenizer = self.init

    if @directory
      tsv_file = File.join(@directory, 'dataset.tsv')
      checkpoint_dir = File.join(@directory, 'checkpoints')
    else
      tmpdir = TmpFile.tmp_file
      Open.mkdir tmpdir
      tsv_file = File.join(tmpdir, 'dataset.tsv')
      checkpoint_dir = File.join(tmpdir, 'checkpoints')
    end

    begin
      training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args,
                                                 checkpoint_dir, @model_options[:training_args])
      dataset_file = HuggingfaceModel.tsv_dataset(tsv_file, texts, labels, @model_options[:class_labels])

      RbbtPython.call_method("rbbt_dm.huggingface", :train_model,
                             model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights])
    ensure
      # Same cleanup guarantee as eval_model: never leave the scratch
      # directory behind when training raises.
      Open.rm_rf tmpdir if tmpdir
    end

    model.save_pretrained(@model_path) if @model_path
    tokenizer.save_pretrained(@model_path) if @model_path
  end

  # Turns raw predictions into class labels (SequenceClassification) or
  # decoded tokens (MaskedLM); other tasks get the predictions unchanged.
  post_process do |result, is_list|
    _model, tokenizer = self.init

    if result.respond_to?(:predictions)
      single = false
      predictions = result.predictions
    elsif result["token_positions"]
      # single stays nil (falsy) on this path.
      predictions = result["result"].predictions
      token_positions = result["token_positions"]
    else
      single = true
      predictions = result["logits"]
    end

    task, class_labels, locate_tokens = @model_options.values_at :task, :class_labels, :locate_tokens

    result = case task
             when "SequenceClassification"
               RbbtPython.collect(predictions) do |logits|
                 logits = RbbtPython.numpy2ruby logits
                 best_class = logits.index logits.max
                 best_class = class_labels[best_class] if class_labels
                 best_class
               end
             when "MaskedLM"
               all_token_positions = token_positions.to_a

               item_index = 0
               RbbtPython.collect(predictions) do |item_logits|
                 item_token_positions = all_token_positions[item_index]
                 item_index += 1

                 item_logits = RbbtPython.numpy2ruby(item_logits)
                 item_masks = item_token_positions.collect do |positions|
                   best = item_logits.values_at(*positions).collect do |logits|
                     # Index of the highest score; the first one wins on ties.
                     best_token = nil
                     best_score = nil
                     logits.each_with_index do |score, token_index|
                       if best_score.nil? || score > best_score
                         best_token = token_index
                         best_score = score
                       end
                     end
                     best_token
                   end

                   best.collect { |b| tokenizer.decode(b) } * "|"
                 end

                 Array === locate_tokens ? item_masks : item_masks.first
               end
             else
               predictions
             end

    # Unwrap the single-element array when the caller passed a single input.
    (! is_list || single) && Array === result ? result.first : result
  end

  save_models if @model_path
end
Class Method Details
.tsv_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil) ⇒ Object
5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
# File 'lib/rbbt/vector/model/huggingface.rb', line 5

# Writes +elements+ (and optionally their labels) to a tab-separated file with
# a header row: "label\ttext" when labels are given, "text" otherwise.
#
# @param tsv_dataset_file [String] path handed to Open.write
# @param elements [Array<String>] texts, one row each; tabs, carriage returns
#   and newlines inside a text are replaced by a space so each text stays in
#   one cell of one row
# @param labels [Array, nil] one label per element; when nil only the text
#   column is written
# @param class_labels [Array, Hash, nil] how to encode labels: with an Array
#   each label becomes its index in the array; with a Hash each label becomes
#   the key whose value is that label; otherwise labels are written as given.
#   Labels not found are encoded as nil (an empty cell).
# @return [String] tsv_dataset_file
def self.tsv_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
  if labels
    labels = case class_labels
             when Array
               labels.collect { |l| class_labels.index l }
             when Hash
               inverse_class_labels = class_labels.invert
               labels.collect { |l| inverse_class_labels[l] }
             else
               labels
             end
  end

  Open.write(tsv_dataset_file) do |ffile|
    if labels
      ffile.puts ["label", "text"] * "\t"
      elements.zip(labels).each do |element, label|
        element = element.gsub(/[\t\r\n]/, " ")
        ffile.puts [label, element].flatten * "\t"
      end
    else
      ffile.puts "text"
      elements.each do |element|
        ffile.puts element.gsub(/[\t\r\n]/, " ")
      end
    end

    # IO#sync only reads the sync flag; flush pushes buffered rows out.
    # NOTE(review): assumes Open.write yields an IO-like object — confirm.
    ffile.flush
  end

  tsv_dataset_file
end
Instance Method Details
#reset_model ⇒ Object
174 175 176 177 178 |
# File 'lib/rbbt/vector/model/huggingface.rb', line 174

# Drops the loaded model and tokenizer, removes whatever is stored at
# @model_path (when one is set) and runs init again.
#
# @return [Object] whatever init returns
def reset_model
  @model = nil
  @tokenizer = nil

  # A model built without a model path has nothing on disk to remove.
  Open.rm_rf @model_path if @model_path

  init
end