Class: PytorchLightningModel

Inherits:
TorchModel
Defined in:
lib/rbbt/vector/model/pytorch_lightning.rb

Instance Attribute Summary collapse

Attributes inherited from TorchModel

#criterion, #model, #optimizer, #training_args

Attributes inherited from PythonModel

#python_class, #python_module

Attributes inherited from VectorModel

#balance, #bar, #directory, #eval_model, #extract_features, #factor_levels, #features, #init_model, #labels, #model, #model_options, #model_path, #names, #post_process, #train_model

Instance Method Summary collapse

Methods inherited from TorchModel

device, dtype, feature_dataset, feature_tsv, freeze, #freeze_layer, freeze_layer, get_layer, #get_layer, #get_weights, get_weights, init_python, load_architecture, load_state, model_architecture, optimizer, save_architecture, save_state, tensor, text_dataset

Methods inherited from VectorModel

R_eval, R_run, R_train, #__load_method, #add, #add_list, #balance_labels, #clear, #cross_validation, #eval, #eval_list, f1_metrics, #init, #run, #save_models, #train

Constructor Details

#initialize → PytorchLightningModel

Returns a new instance of PytorchLightningModel.



(Source lines 5–23.)
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 5

# Builds the model and registers the training procedure used by
# VectorModel#train.
#
# All constructor arguments are forwarded unchanged to the parent class.
#
# The registered training block:
# * initialises the underlying model via +init+;
# * trains with +trainer.fit+ on the configured +loader+ (and +val_loader+);
# * when no +loader+ is configured but +features+ were supplied, first writes
#   the features/labels to a temporary TSV and loads it through the Python
#   +rbbt_dm.tsv+ helper;
# * saves the architecture and state to +model_path+ when +@directory+ is set.
def initialize(...)
  super(...)

  train_model do |features, labels|
    model = init
    train_loader = self.loader
    validation_loader = self.val_loader

    if features && features.any? && train_loader.nil?
      # Fit while still inside the TmpFile block so the temporary TSV exists
      # for as long as the trainer may read from it.
      # NOTE(review): this assumes TmpFile.with_file removes the file when the
      # block exits — confirm against TmpFile.
      TmpFile.with_file do |tsv_dataset_file|
        TorchModel.feature_dataset(tsv_dataset_file, features, labels)
        RbbtPython.pyimport :rbbt_dm
        train_loader = RbbtPython.rbbt_dm.tsv(tsv_dataset_file)
        trainer.fit(model, train_loader, validation_loader)
      end
    else
      trainer.fit(model, train_loader, validation_loader)
    end

    # Persist architecture and weights only when a model directory is configured.
    if @directory
      TorchModel.save_architecture(model, model_path)
      TorchModel.save_state(model, model_path)
    end
  end
end

Instance Attribute Details

#loader → Object

Returns the value of attribute loader.



(Source lines 4–6.)
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 4

# Reader for the training data loader (+@loader+); nil until assigned.
def loader; @loader; end

#trainer → Object

Returns the value of attribute trainer.



(Source lines 4–6.)
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 4

# Reader for the trainer object (+@trainer+) used to fit the model; nil until assigned.
def trainer; @trainer; end

#val_loader → Object

Returns the value of attribute val_loader.



(Source lines 4–6.)
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 4

# Reader for the validation data loader (+@val_loader+); nil until assigned.
def val_loader; @val_loader; end