Class: TorchModel
- Inherits: VectorModel
  - Object
    - VectorModel
      - TorchModel
- Defined in: lib/rbbt/vector/model/torch.rb
Direct Known Subclasses
Instance Attribute Summary
- #model ⇒ Object
  Returns the value of attribute model.
Attributes inherited from VectorModel
#balance, #bar, #directory, #eval_model, #extract_features, #factor_levels, #features, #init_model, #labels, #model_options, #model_path, #names, #post_process, #train_model
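Class Method Summary
- .freeze(layer) ⇒ Object
  Sets requires_grad = false on the weight of layer and, recursively, on every child layer; layers without a weight are skipped.
- .freeze_layer(model, layer) ⇒ Object
  Resolves the dotted layer path within model and freezes that layer and all of its children.
- .get_layer(model, layer) ⇒ Object
  Returns the submodule of model reached by following the dotted attribute path layer.
- .get_weights(model, layer) ⇒ Object
  Returns the weight attribute of the layer at the dotted path layer.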
Instance Method Summary
- #initialize(dir, model_options = {}) ⇒ TorchModel (constructor)
  A new instance of TorchModel.
Methods inherited from VectorModel
R_eval, R_run, R_train, #__load_method, #add, #add_list, #balance_labels, #clear, #cross_validation, #eval, #eval_list, f1_metrics, #init, #run, #save_models, #train
Constructor Details
#initialize(dir, model_options = {}) ⇒ TorchModel
Returns a new instance of TorchModel.
    # File 'lib/rbbt/vector/model/torch.rb', line 34
    def initialize(dir, model_options = {})
      super(dir, model_options)
    end
Instance Attribute Details
#model ⇒ Object
Returns the value of attribute model.
    # File 'lib/rbbt/vector/model/torch.rb', line 9
    def model
      @model
    end
Class Method Details
.freeze(layer) ⇒ Object
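    # File 'lib/rbbt/vector/model/torch.rb', line 19
    def self.freeze(layer)
      begin
        # Stop gradient updates for this layer's weight; layers without a
        # weight attribute raise here and are silently skipped.
        PyCall.getattr(layer, :weight).requires_grad = false
      rescue
      end
      # Recurse into every direct child module.
      RbbtPython.iterate(layer.children) do |layer|
        freeze(layer)
      end
    end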
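The recursion in `freeze` can be illustrated without Python. The following is a pure-Ruby analogue, not the library's code. It assumes two hypothetical stand-ins: `Param` plays a torch parameter with a `requires_grad` flag, and `FakeLayer` plays an `nn.Module` with an optional `weight` and a list of `children`. Neither belongs to PyCall, torch, or rbbt. `TorchModel.freeze` relies on `rescue` to skip layers that have no weight, whereas the sketch checks for `nil` explicitly. It is named `freeze_weights` because a top-level `def freeze` would shadow Ruby's built-in `freeze`.

```ruby
# Hypothetical stand-in for a torch parameter (only the requires_grad flag).
Param = Struct.new(:requires_grad)

# Hypothetical stand-in for an nn.Module: an optional weight plus named children.
class FakeLayer
  attr_reader :weight, :children

  def initialize(weight: nil, **named_children)
    @weight = weight
    @children = named_children.values
    # Expose each child as a reader, like submodule attributes on an nn.Module.
    named_children.each { |name, child| define_singleton_method(name) { child } }
  end
end

# Pure-Ruby analogue of TorchModel.freeze: disable gradients on this layer's
# weight (if it has one), then recurse into every child layer.
def freeze_weights(layer)
  layer.weight.requires_grad = false unless layer.weight.nil?
  layer.children.each { |child| freeze_weights(child) }
end

model = FakeLayer.new(
  weight: Param.new(true),
  encoder: FakeLayer.new(fc: FakeLayer.new(weight: Param.new(true))),
  act: FakeLayer.new # no weight, like an activation layer
)
freeze_weights(model)
puts model.encoder.fc.weight.requires_grad # prints "false"
```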
.freeze_layer(model, layer) ⇒ Object
    # File 'lib/rbbt/vector/model/torch.rb', line 29
    def self.freeze_layer(model, layer)
      layer = get_layer(model, layer)
      freeze(layer)
    end
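`freeze_layer` is `get_layer` followed by `freeze`. As a result, only the subtree at the given path is frozen, and sibling layers stay trainable. The following is a pure-Ruby analogue, not the library's code. It defines hypothetical stand-ins `Param` and `FakeLayer` (not torch or PyCall classes) and uses `public_send` in place of `PyCall.getattr`. The recursive helper is called `freeze_weights` to avoid shadowing Ruby's built-in `freeze`.

```ruby
# Hypothetical stand-in for a torch parameter (only the requires_grad flag).
Param = Struct.new(:requires_grad)

# Hypothetical stand-in for an nn.Module (not a torch or PyCall class).
class FakeLayer
  attr_reader :weight, :children

  def initialize(weight: nil, **named_children)
    @weight = weight
    @children = named_children.values
    named_children.each { |name, child| define_singleton_method(name) { child } }
  end
end

# Analogue of TorchModel.get_layer: follow a dotted path such as "encoder.fc".
def get_layer(model, path)
  path.split(".").inject(model) { |acc, name| acc.public_send(name.to_sym) }
end

# Analogue of TorchModel.freeze: recursive, skips layers without a weight.
def freeze_weights(layer)
  layer.weight.requires_grad = false unless layer.weight.nil?
  layer.children.each { |child| freeze_weights(child) }
end

# Analogue of TorchModel.freeze_layer: freeze only the subtree at `path`.
def freeze_layer(model, path)
  freeze_weights(get_layer(model, path))
end

model = FakeLayer.new(
  encoder: FakeLayer.new(fc: FakeLayer.new(weight: Param.new(true))),
  head: FakeLayer.new(weight: Param.new(true))
)
freeze_layer(model, "encoder")
puts model.encoder.fc.weight.requires_grad # prints "false"
puts model.head.weight.requires_grad       # prints "true"
```

This split mirrors the common fine-tuning pattern of freezing a pretrained encoder while leaving a new head trainable.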
.get_layer(model, layer) ⇒ Object
    # File 'lib/rbbt/vector/model/torch.rb', line 11
    def self.get_layer(model, layer)
      layer.split(".").inject(model){|acc,l| PyCall.getattr(acc, l.to_sym) }
    end
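`get_layer` folds over the segments of a dotted path and reads one attribute per segment. The following pure-Ruby sketch shows the same `split` plus `inject` technique, with `public_send` standing in for `PyCall.getattr`. The `Linear`, `Encoder`, and `Net` structs are hypothetical stand-ins for a small network, not torch classes. In this sketch a misspelled segment raises `NoMethodError`.

```ruby
# Analogue of TorchModel.get_layer: split a dotted path and fold over it,
# reading one attribute per segment (public_send stands in for PyCall.getattr).
def get_layer(model, path)
  path.split(".").inject(model) { |acc, name| acc.public_send(name.to_sym) }
end

# Hypothetical stand-ins for a small network; not torch classes.
Linear  = Struct.new(:weight)
Encoder = Struct.new(:fc)
Net     = Struct.new(:encoder)

net = Net.new(Encoder.new(Linear.new([[0.5, -0.5]])))
p get_layer(net, "encoder.fc").weight # prints [[0.5, -0.5]]
```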
.get_weights(model, layer) ⇒ Object
    # File 'lib/rbbt/vector/model/torch.rb', line 15
    def self.get_weights(model, layer)
      PyCall.getattr(get_layer(model, layer), :weight)
    end
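`get_weights` resolves the layer and then reads its `weight`. The object returned is the one held by the layer, not a copy. `TorchModel.freeze` depends on the same property when it sets `requires_grad` on the result of `PyCall.getattr`. The following is a pure-Ruby analogue, not the library's code. The structs are hypothetical stand-ins, and `public_send` replaces `PyCall.getattr`. The sketch mutates the returned weight and shows that the change is visible through the model.

```ruby
# Hypothetical stand-ins for a small network; not torch classes.
Linear  = Struct.new(:weight)
Encoder = Struct.new(:fc)
Net     = Struct.new(:encoder)

# Analogue of TorchModel.get_layer: follow a dotted attribute path.
def get_layer(model, path)
  path.split(".").inject(model) { |acc, name| acc.public_send(name.to_sym) }
end

# Analogue of TorchModel.get_weights: resolve the layer, then read its weight.
def get_weights(model, path)
  get_layer(model, path).public_send(:weight)
end

net = Net.new(Encoder.new(Linear.new([0.5, -0.5])))
w = get_weights(net, "encoder.fc")
w[0] = 1.0              # mutate through the returned reference
p net.encoder.fc.weight # prints [1.0, -0.5]
```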