Class: PytorchLightningModel

Inherits:
TorchModel
Defined in:
lib/rbbt/vector/model/pytorch_lightning.rb

Instance Attribute Summary

Attributes inherited from TorchModel

#model

Attributes inherited from VectorModel

#balance, #bar, #directory, #eval_model, #extract_features, #factor_levels, #features, #init_model, #labels, #model, #model_options, #model_path, #names, #post_process, #train_model

Instance Method Summary

Methods inherited from TorchModel

freeze, freeze_layer, get_layer, get_weights

Methods inherited from VectorModel

R_eval, R_run, R_train, #__load_method, #add, #add_list, #balance_labels, #clear, #cross_validation, #eval, #eval_list, f1_metrics, #init, #run, #save_models, #train

Constructor Details

#initialize(module_name, class_name, dir = nil, model_options = {}) ⇒ PytorchLightningModel

Returns a new instance of PytorchLightningModel.



Source (lines 5–31):
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 5

# Builds a VectorModel whose network is a PyTorch Lightning module living in Python.
#
# @param module_name [String, Symbol] Python module imported via RbbtPython.pyimport
# @param class_name [String, Symbol] class inside +module_name+ instantiated as the model
# @param dir [String, nil] forwarded to TorchModel
# @param model_options [Hash] forwarded to TorchModel; the +:model_args+ entry (if any)
#   is handed to the Python class constructor, defaulting to an empty hash
#
# Training ignores the features/labels passed to the train block. It uses the
# +loader+, +trainer+ and (optional) +val_loader+ attributes instead; +loader+ and
# +trainer+ must be assigned before training, otherwise a RuntimeError is raised.
def initialize(module_name, class_name, dir = nil, model_options = {})
  super(dir, model_options)
  @module_name = module_name
  @class_name = class_name

  init_model do
    RbbtPython.pyimport @module_name
    RbbtPython.class_new_obj(@module_name, @class_name, @model_options[:model_args] || {})
  end

  train_model do |_features, _labels|
    # Check prerequisites before calling init so a misconfigured instance fails
    # fast, without importing the Python module or instantiating the network.
    raise "Use the loader" if @loader.nil?
    raise "Use the trainer" if @trainer.nil?

    model = init
    # @val_loader is not checked: it is passed through as-is, possibly nil.
    @trainer.fit(model, @loader, @val_loader)
  end

  eval_model do |features, list|
    # A single sample is wrapped in a one-element list so it is converted to a tensor
    # with a leading batch dimension, matching the list case.
    input = list ? features : [features]
    model.call(RbbtPython.call_method(:torch, :tensor, input))
  end
end

Instance Attribute Details

#loader ⇒ Object

Returns the value of attribute loader.



Source (lines 4–6):
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 4

# Training data loader passed to the trainer's fit call; training refuses to run while it is nil.
def loader; @loader; end

#trainer ⇒ Object

Returns the value of attribute trainer.



Source (lines 4–6):
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 4

# Trainer object whose +fit+ is invoked during training; training refuses to run while it is nil.
def trainer; @trainer; end

#val_loader ⇒ Object

Returns the value of attribute val_loader.



Source (lines 4–6):
# File 'lib/rbbt/vector/model/pytorch_lightning.rb', line 4

# Validation data loader forwarded to the trainer's fit call; it is not required and may be nil.
def val_loader; @val_loader; end