Class: Torchrb::Wrapper
Instance Attribute Summary
-
#model ⇒ Object
readonly
Returns the value of attribute model.
-
#progress ⇒ Object
readonly
Returns the value of attribute progress.
Attributes inherited from Torch
#error_rate, #network_loaded, #network_timestamp
Attributes inherited from Lua
Class Method Summary
Instance Method Summary
Methods inherited from Torch
#cudify, #iteration_callback=, #load_network, #print_results, #store_network
Methods inherited from Lua
Instance Attribute Details
#model ⇒ Object (readonly)
Returns the value of attribute model.
13 14 15 |
# File 'lib/torchrb/wrapper.rb', line 13
#
# Read-only accessor for the wrapped model.
#
# @return [Object] the current value of +@model+, or nil when it has
#   never been assigned
def model
  @model
end
#progress ⇒ Object (readonly)
Returns the value of attribute progress.
13 14 15 |
# File 'lib/torchrb/wrapper.rb', line 13
#
# Read-only accessor for the progress counter.
#
# @return [Object] the current value of +@progress+, or nil when it has
#   never been assigned
def progress
  @progress
end
Class Method Details
.for(model_class, options = {}) ⇒ Object
4 5 6 7 8 9 10 11 |
# File 'lib/torchrb/wrapper.rb', line 4
#
# Returns the cached wrapper for +model_class+, creating it on first use.
#
# One instance is memoized per +model_class+ in the +@@instances+ class
# variable. Because of the +||=+, +options+ only take effect on the call
# that creates the instance; later calls for the same class ignore them.
#
# @param model_class [Object] cache key, also passed as the first argument to +new+
# @param options [Hash] passed to +new+ as the second argument on first creation
# @yield [wrapper] the cached wrapper, if a block is given
# @return [Object] the block's result when a block is given, otherwise the wrapper
def self.for(model_class, options = {})
  # NOTE(review): @@instances is initialized outside this view; it must be a
  # Hash-like object before the first call.
  wrapper = (@@instances[model_class] ||= new(model_class, options))
  block_given? ? yield(wrapper) : wrapper
end
Instance Method Details
#load_model_data ⇒ Object
26 27 28 29 30 31 |
# File 'lib/torchrb/wrapper.rb', line 26
#
# Resets the progress counter to zero.
# Then calls +load_dataset+ once for each split, in this order:
# +:train_set+, +:test_set+, +:validation_set+.
#
# @return [Object] whatever +load_dataset+ returns for +:validation_set+
def load_model_data
  @progress = 0
  dataset_names = %i[train_set test_set validation_set]
  # map + last keeps the original return value (the final load_dataset result).
  dataset_names.map { |dataset_name| load_dataset(dataset_name) }.last
end
#predict(sample) ⇒ Object
44 45 46 |
# File 'lib/torchrb/wrapper.rb', line 44
#
# Delegates prediction for +sample+ to the parent class's +predict+.
#
# @param sample [Object] forwarded unchanged to the parent implementation
# @return [Object] the parent implementation's result
def predict(sample)
  # Bare super forwards the current value of `sample` (and any block).
  super
end
#train ⇒ Object
33 34 35 36 37 38 39 40 41 42 |
# File 'lib/torchrb/wrapper.rb', line 33
#
# Trains the network in three phases:
# 1. Prepare: defines the network and trainer, and calls +cudify+ when
#    +enable_cuda+ is truthy.
# 2. Train: runs the parent class's +train+.
# 3. Report: prints the results and stores the network.
#
# @return [Object] the value of +error_rate+ after training
def train
  # Preparation phase.
  define_nn
  define_trainer
  cudify if enable_cuda

  # Parent training loop (train takes no arguments, so none are forwarded).
  super

  # Reporting and persistence.
  print_results
  store_network
  error_rate
end