Class: CausalModel

Inherits:
HuggingfaceModel
Defined in:
lib/scout/model/python/huggingface/causal.rb

Direct Known Subclasses

NextTokenModel

Instance Attribute Summary

Attributes inherited from TorchModel

#criterion, #device, #dtype, #optimizer

Attributes inherited from ScoutModel

#directory, #options, #state

Instance Method Summary

Methods inherited from HuggingfaceModel

#fix_options

Methods inherited from TorchModel

criterion, device, dtype, feature_dataset, feature_tsv, #fix_options, freeze, freeze_layer, #freeze_layer, #get_layer, get_layer, #get_weights, get_weights, init_python, load, load_architecture, load_state, model_architecture, optimizer, #reset_state, save, save_architecture, save_state, tensor, text_dataset

Methods inherited from ScoutModel

#add, #add_list, #eval, #eval_list, #execute, #extract_features, #extract_features_list, #init, #load_method, #load_options, #load_ruby_code, #load_state, #post_process, #post_process_list, #restore, #save, #save_method, #save_options, #save_state, #state_file, #train

Constructor Details

#initialize ⇒ CausalModel

Returns a new instance of CausalModel.



(source listing, lines 4–28)
# File 'lib/scout/model/python/huggingface/causal.rb', line 4

# Builds a causal-language-model wrapper on top of HuggingfaceModel.
#
# All arguments are forwarded unchanged to the parent initializer, with the
# task name "CausalLM" prepended.
#
# Two behaviours are registered on the instance:
# * eval  -- forwards chat messages, plus the :chat_template,
#            :chat_template_kwargs and :generation_kwargs options, to the
#            Python helper scout_ai.huggingface.eval.eval_causal_lm_chat.
# * train -- forwards pairs and labels, plus the :rlhf_config option, to
#            scout_ai.huggingface.rlhf.train_rlhf, then calls load_state.
def initialize(...)
  super("CausalLM", ...)

  # @state is destructured as [model, tokenizer]. The second block
  # parameter is not used by this model.
  self.eval do |messages, _list|
    model, tokenizer = @state
    ScoutPython.call_method(
      "scout_ai.huggingface.eval", :eval_causal_lm_chat,
      model, tokenizer, messages,
      options[:chat_template],
      options[:chat_template_kwargs],
      options[:generation_kwargs]
    )
  end

  # NOTE(review): an earlier comment described the training data as rows of
  # [response, reward] or [prompt, response, reward]; confirm how those map
  # onto the (pairs, labels) arguments received here.
  train do |pairs, labels|
    # Only the tokenizer is taken from @state. The Python side receives the
    # state file path instead of the in-memory model.
    _model, tokenizer = @state

    ScoutPython.call_method(
      "scout_ai.huggingface.rlhf", :train_rlhf,
      state_file, tokenizer, pairs, labels, options[:rlhf_config]
    )

    # Reload state after training. Presumably this picks up what train_rlhf
    # wrote to state_file -- confirm against the Python implementation.
    load_state
  end
end