Class: TorchModel

Inherits:
VectorModel
Defined in:
lib/rbbt/vector/model/torch.rb

Direct Known Subclasses

HuggingfaceModel, PytorchLightningModel

Instance Attribute Summary

Attributes inherited from VectorModel

#balance, #bar, #directory, #eval_model, #extract_features, #factor_levels, #features, #init_model, #labels, #model_options, #model_path, #names, #post_process, #train_model

Class Method Summary

Instance Method Summary

Methods inherited from VectorModel

R_eval, R_run, R_train, #__load_method, #add, #add_list, #balance_labels, #clear, #cross_validation, #eval, #eval_list, f1_metrics, #init, #run, #save_models, #train

Constructor Details

#initialize(dir, model_options = {}) ⇒ TorchModel

Returns a new instance of TorchModel.



# File 'lib/rbbt/vector/model/torch.rb', line 34

def initialize(dir, model_options = {})
  super(dir, model_options)
end

Instance Attribute Details

#model ⇒ Object

Returns the value of attribute model.



# File 'lib/rbbt/vector/model/torch.rb', line 9

def model
  @model
end

Class Method Details

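.freeze(layer) ⇒ Object

Sets requires_grad to false on the weight of layer, then recurses into each of its child modules. A layer without a weight attribute is skipped silently, because the bare rescue swallows the error. Only the weight attribute is touched; other parameters, such as a bias, are left as they are.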



# File 'lib/rbbt/vector/model/torch.rb', line 19

def self.freeze(layer)
  begin
    PyCall.getattr(layer, :weight).requires_grad = false
  rescue
  end
  RbbtPython.iterate(layer.children) do |layer|
    freeze(layer)
  end
end
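
The recursion above can be sketched in plain Ruby without PyCall or PyTorch. The names StubWeight, StubLayer and freeze_tree are hypothetical stand-ins introduced only for illustration; they are not part of rbbt or PyTorch, and a nil weight here plays the role of the rescued attribute error in the original.

```ruby
# Hypothetical stand-ins for a PyTorch module tree (assumption, not the real API).
StubWeight = Struct.new(:requires_grad)
StubLayer  = Struct.new(:name, :weight, :children)

# Disable gradients on this layer's weight, if it has one, then on every descendant.
def freeze_tree(layer)
  layer.weight.requires_grad = false if layer.weight
  layer.children.each { |child| freeze_tree(child) }
end

linear  = StubLayer.new("linear", StubWeight.new(true), [])
relu    = StubLayer.new("relu", nil, [])              # no weight, like an activation
encoder = StubLayer.new("encoder", nil, [linear, relu])

freeze_tree(encoder)
puts linear.weight.requires_grad
```

Starting from a container that has no weight of its own still freezes every weighted layer beneath it, which is why the original method recurses even when the attribute lookup fails.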

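.freeze_layer(model, layer) ⇒ Object

Looks up layer, a dot-separated attribute path, inside model with get_layer and passes the resulting module to freeze.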



# File 'lib/rbbt/vector/model/torch.rb', line 29

def self.freeze_layer(model, layer)
  layer = get_layer(model, layer)
  freeze(layer)
end
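
The practical effect of combining lookup and freezing is that only the selected subtree stops training. The following is a minimal sketch under stated assumptions: SketchParam, SketchNode, resolve_path, freeze_subtree, freeze_path and demo_model are hypothetical names, and a Hash of named children stands in for Python attribute access.

```ruby
# Hypothetical stand-ins (assumption): subs maps a child name to a SketchNode.
SketchParam = Struct.new(:requires_grad)
SketchNode  = Struct.new(:weight, :subs)

# Follow a dotted path such as "encoder.dense" from the root node.
def resolve_path(root, path)
  path.split(".").inject(root) { |node, name| node.subs.fetch(name.to_sym) }
end

# Disable gradients on a node and on everything below it.
def freeze_subtree(node)
  node.weight.requires_grad = false if node.weight
  node.subs.each_value { |sub| freeze_subtree(sub) }
end

# Resolve first, then freeze: the same two steps as freeze_layer.
def freeze_path(root, path)
  freeze_subtree(resolve_path(root, path))
end

# Small two-branch model used for the demonstration.
def demo_model
  dense   = SketchNode.new(SketchParam.new(true), {})
  encoder = SketchNode.new(nil, { dense: dense })
  head    = SketchNode.new(SketchParam.new(true), {})
  SketchNode.new(nil, { encoder: encoder, head: head })
end

model = demo_model
freeze_path(model, "encoder")
puts model.subs[:encoder].subs[:dense].weight.requires_grad  # frozen branch
puts model.subs[:head].weight.requires_grad                  # untouched branch
```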

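.get_layer(model, layer) ⇒ Object

Returns the object reached by following layer, a dot-separated attribute path, from model one attribute at a time.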



# File 'lib/rbbt/vector/model/torch.rb', line 11

def self.get_layer(model, layer)
  layer.split(".").inject(model){|acc,l| PyCall.getattr(acc, l.to_sym) }
end
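
The split-and-inject idiom used above can be tried on ordinary Ruby objects. In this sketch, ToyDense, ToyEncoder, ToyNet and dig_attr are hypothetical names, and public_send stands in for PyCall.getattr; the real method operates on Python objects instead.

```ruby
# Hypothetical nested objects standing in for a model and its submodules.
ToyDense   = Struct.new(:weight)
ToyEncoder = Struct.new(:dense)
ToyNet     = Struct.new(:encoder)

# Walk a dotted path one attribute at a time, starting from the root object.
def dig_attr(root, path)
  path.split(".").inject(root) { |acc, name| acc.public_send(name.to_sym) }
end

net = ToyNet.new(ToyEncoder.new(ToyDense.new([0.1, 0.2])))
p dig_attr(net, "encoder.dense").weight
```

Because inject is seeded with the root, an empty path returns the root itself, and a misspelled segment fails at the first attribute that does not exist (NoMethodError in this sketch).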

.get_weights(model, layer) ⇒ Object

Returns the weight attribute of the layer found at the dot-separated path layer inside model.



# File 'lib/rbbt/vector/model/torch.rb', line 15

def self.get_weights(model, layer)
  PyCall.getattr(get_layer(model, layer), :weight)
end