Class: PytorchLightningModel
- Inherits:
- TorchModel
- Inheritance chain:
- Object, VectorModel, TorchModel, PytorchLightningModel
- Defined in:
- lib/rbbt/vector/model/pytorch_lightning.rb
Instance Attribute Summary
- #loader ⇒ Object
Returns the value of attribute loader.
- #trainer ⇒ Object
Returns the value of attribute trainer.
- #val_loader ⇒ Object
Returns the value of attribute val_loader.
Attributes inherited from TorchModel
Attributes inherited from VectorModel
#balance, #bar, #directory, #eval_model, #extract_features, #factor_levels, #features, #init_model, #labels, #model, #model_options, #model_path, #names, #post_process, #train_model
Instance Method Summary
- #initialize(module_name, class_name, dir = nil, model_options = {}) ⇒ PytorchLightningModel (constructor)
A new instance of PytorchLightningModel.
Methods inherited from TorchModel
freeze, freeze_layer, get_layer, get_weights
Methods inherited from VectorModel
R_eval, R_run, R_train, #__load_method, #add, #add_list, #balance_labels, #clear, #cross_validation, #eval, #eval_list, f1_metrics, #init, #run, #save_models, #train
Constructor Details
#initialize(module_name, class_name, dir = nil, model_options = {}) ⇒ PytorchLightningModel
Returns a new instance of PytorchLightningModel.
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 5

def initialize(module_name, class_name, dir = nil, model_options = {})
  super(dir, model_options)
  @module_name = module_name
  @class_name = class_name

  # Import the Python module and instantiate the named class.
  init_model do
    RbbtPython.pyimport @module_name
    RbbtPython.class_new_obj(@module_name, @class_name, model_options[:model_args] || {})
  end

  # Training requires both a loader and a trainer to be assigned first.
  train_model do |features,labels|
    model = init
    raise "Use the loader" if @loader.nil?
    raise "Use the trainer" if @trainer.nil?
    trainer.fit(model, @loader, @val_loader)
  end

  # A single feature vector is wrapped in an array so the model receives a batch.
  eval_model do |features,list|
    if list
      model.call(RbbtPython.call_method(:torch, :tensor, features))
    else
      model.call(RbbtPython.call_method(:torch, :tensor, [features]))
    end
  end
end
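The constructor's control flow can be tried without rbbt or Python by using a plain-Ruby stand-in. The sketch below is only an illustration under stated assumptions: SketchModel, FakeTrainer, evaluate, and row_sums are hypothetical names that are not part of rbbt, and the tensor conversion is omitted. It mirrors two behaviors visible in the listing: training raises until both loader and trainer are assigned (val_loader may stay nil), and a single feature vector is wrapped in an array before the model is called.

```ruby
# Hypothetical stand-in, not the rbbt implementation.
class SketchModel
  attr_accessor :loader, :val_loader, :trainer

  def initialize(model)
    @model = model
  end

  # Mirrors the guards in the train_model block.
  def train
    raise "Use the loader" if @loader.nil?
    raise "Use the trainer" if @trainer.nil?
    @trainer.fit(@model, @loader, @val_loader)
  end

  # Mirrors the eval_model block: one vector becomes a batch of one.
  def evaluate(features, list = false)
    batch = list ? features : [features]
    @model.call(batch)
  end
end

# Hypothetical trainer that only records the arguments passed to fit.
class FakeTrainer
  attr_reader :calls

  def initialize
    @calls = []
  end

  def fit(model, loader, val_loader = nil)
    @calls << [model, loader, val_loader]
  end
end

# Toy "model": returns the sum of each row in the batch.
row_sums = ->(batch) { batch.map(&:sum) }
sketch = SketchModel.new(row_sums)

# Training before assigning a loader raises.
error = begin
  sketch.train
rescue RuntimeError => e
  e.message
end

sketch.loader = [[1, 2], [3, 4]]
sketch.trainer = FakeTrainer.new
sketch.train

single = sketch.evaluate([1, 2])               # wrapped as [[1, 2]]
many = sketch.evaluate([[1, 2], [3, 4]], true) # passed through as a batch
```

With the real class, the loader and trainer would presumably be Python objects (for example a data loader and a Lightning trainer) assigned through the same accessors before calling train.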
Instance Attribute Details
#loader ⇒ Object
Returns the value of attribute loader.
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 4

def loader
  @loader
end
#trainer ⇒ Object
Returns the value of attribute trainer.
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 4

def trainer
  @trainer
end
#val_loader ⇒ Object
Returns the value of attribute val_loader.
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 4

def val_loader
  @val_loader
end