Class: Torchrb::ModelBase
- Inherits: Object
  - Object
    - Torchrb::ModelBase
- Defined in: lib/torchrb/model_base.rb
Constant Summary
- REQUIRED_OPTIONS = [:data_model]
Class Method Summary
- .error_rate ⇒ Object
- .predict(sample) ⇒ Object
- .progress_callback(progress: nil, message: nil, error_rate: Float::NAN) ⇒ Object
- .setup_nn(options = {}) ⇒ Object
- .train ⇒ Object
Class Method Details
.error_rate ⇒ Object
```ruby
# File 'lib/torchrb/model_base.rb', line 33

def error_rate
  torch.error_rate
end
```
.predict(sample) ⇒ Object
```ruby
# File 'lib/torchrb/model_base.rb', line 59

def predict sample
  torch.predict sample, network_storage_path
end
```
.progress_callback(progress: nil, message: nil, error_rate: Float::NAN) ⇒ Object
```ruby
# File 'lib/torchrb/model_base.rb', line 5

def progress_callback progress: nil, message: nil, error_rate: Float::NAN
  raise NotImplementedError.new("Implement this method in your Model")
end
```
.setup_nn(options = {}) ⇒ Object
```ruby
# File 'lib/torchrb/model_base.rb', line 9

def setup_nn options={}
  # [editor's note: one statement here was garbled during extraction; only "()" survived]
  {
    net: Torchrb::NN::Basic,
    trainer: Torchrb::NN::TrainerDefault,
    tensor_type: "DoubleTensor",
    dimensions: [0],
    classes: [],
    dataset_split: [80, 10, 10],
    normalize: false,
    enable_cuda: false,
    auto_store_trained_network: true,
    network_storage_path: "tmp/cache/torchrb",
    debug: false,
  }.merge!(options).each do |option, default|
    cattr_reader(option)
    class_variable_set(:"@@#{option}", default)
  end
  cattr_reader(:torch) { Torchrb::Torch.new }

  @net_options = load_extension(options[:net])
  @trainer_options = load_extension(options[:trainer])
end
```
.train ⇒ Object
```ruby
# File 'lib/torchrb/model_base.rb', line 37

def train
  progress_callback message: 'Loading data'
  load_model_data
  torch.iteration_callback= method(:progress_callback)

  define_nn @net_options
  define_trainer @trainer_options

  torch.cudify if enable_cuda

  progress_callback message: 'Start training'
  torch.train
  progress_callback message: 'Done'

  torch.print_results
  torch.store_network network_storage_path if auto_store_trained_network

  after_training if respond_to?(:after_training)

  torch.error_rate
end
```