Class: SequenceClassificationModel

Inherits:
HuggingfaceModel
Defined in:
lib/scout/model/python/huggingface/classification.rb

Instance Attribute Summary

Attributes inherited from TorchModel

#criterion, #device, #dtype, #optimizer

Attributes inherited from ScoutModel

#directory, #options, #state

Instance Method Summary

Methods inherited from HuggingfaceModel

#fix_options

Methods inherited from TorchModel

.criterion, .device, .dtype, .feature_dataset, .feature_tsv, #fix_options, .freeze, .freeze_layer, #freeze_layer, #get_layer, .get_layer, #get_weights, .get_weights, .init_python, .load, .load_architecture, .load_state, .model_architecture, .optimizer, #reset_state, .save, .save_architecture, .save_state, .tensor, .text_dataset

Methods inherited from ScoutModel

#add, #add_list, #eval, #eval_list, #execute, #extract_features, #extract_features_list, #init, #load_method, #load_options, #load_ruby_code, #load_state, #post_process, #post_process_list, #restore, #save, #save_method, #save_options, #save_state, #state_file, #train

Constructor Details

#initialize(...) ⇒ SequenceClassificationModel

Returns a new instance of SequenceClassificationModel.



(Source listing below covers lines 4–49 of the file.)
# File 'lib/scout/model/python/huggingface/classification.rb', line 4

# Builds a Hugging Face sequence-classification model and registers its
# eval, post-process and train hooks.
#
# All arguments are forwarded unchanged to HuggingfaceModel#initialize,
# with the task name "SequenceClassification" prepended.
#
# Options read by the hooks below:
#   :locate_tokens - passed through to scout_ai.huggingface.eval.eval_model
#   :class_labels  - maps predicted class indices to labels; also handed to
#                    the training dataset builder
#   :training_args - passed to scout_ai.huggingface.train.training_args
#   :class_weights - passed to scout_ai.huggingface.train.train_model
def initialize(...)
  super("SequenceClassification", ...)

  # Runs the model over the input text(s). For a single sample the feature
  # is wrapped in a one-element batch and only the first entry of the
  # Python result is returned; for a batch the whole result is returned.
  self.eval do |features, list|
    model, tokenizer = @state
    texts = list ? list : [features]
    res = ScoutPython.call_method("scout_ai.huggingface.eval", :eval_model,
                                  model, tokenizer, texts, options[:locate_tokens])
    list ? res : res[0]
  end

  # Turns each row of logits into the index of its highest score, mapped
  # through options[:class_labels] when labels are configured.
  post_process do |result, list|
    # In the batch case the logits live on the eval output (`result`);
    # `list` is just the input texts. In the single case eval already
    # returned `res[0]`, which is iterated directly (presumably the logits
    # tensor — confirm against eval_model's return value).
    logit_list = list ? result.logits : result

    res = ScoutPython.collect(logit_list) do |logits|
      logits = ScoutPython.numpy2ruby logits
      best_class = logits.index logits.max
      best_class = options[:class_labels][best_class] if options[:class_labels]
      best_class
    end

    list ? res : res[0]
  end

  # Writes the training data to a TSV, builds training arguments and runs
  # training. Without a model directory, the dataset and checkpoints go to
  # a temporary directory that is always removed afterwards, even when
  # training raises.
  train do |texts, labels|
    model, tokenizer = @state

    tmpdir = nil
    unless directory
      tmpdir = TmpFile.tmp_file
      Open.mkdir tmpdir
    end

    begin
      work_dir = directory || tmpdir
      tsv_file = File.join(work_dir, 'dataset.tsv')
      checkpoint_dir = File.join(work_dir, 'checkpoints')

      training_args_obj = ScoutPython.call_method("scout_ai.huggingface.train", :training_args,
                                                  checkpoint_dir, options[:training_args])
      dataset_file = HuggingfaceModel.text_dataset(tsv_file, texts, labels, options[:class_labels])

      ScoutPython.call_method("scout_ai.huggingface.train", :train_model,
                              model, tokenizer, training_args_obj, dataset_file, options[:class_weights])
    ensure
      # Only the temporary directory is cleaned up; a user-supplied
      # directory keeps its dataset and checkpoints.
      Open.rm_rf tmpdir if tmpdir
    end
  end
end