Class: Torchrb::ModelBase

Inherits: Object
Defined in: lib/torchrb/model_base.rb

Constant Summary

REQUIRED_OPTIONS = [:data_model]

Class Method Summary

Class Method Details

.error_rate ⇒ Object

# File 'lib/torchrb/model_base.rb', lines 33-35

# Reports the error rate currently held by the model's Torch wrapper.
#
# Pure delegation: nothing is computed here.
#
# @return [Object] whatever +torch.error_rate+ returns
def error_rate
  torch.error_rate
end

.predict(sample) ⇒ Object



# File 'lib/torchrb/model_base.rb', lines 59-61

# Runs inference on +sample+ through the Torch wrapper.
#
# The configured +network_storage_path+ is passed along with the sample.
#
# @param sample [Object] input forwarded unchanged to +torch.predict+
# @return [Object] the value returned by +torch.predict+
def predict(sample)
  torch.predict(sample, network_storage_path)
end

.progress_callback(progress: nil, message: nil, error_rate: Float::NAN) ⇒ Object

Raises: NotImplementedError

# File 'lib/torchrb/model_base.rb', lines 5-7

# Abstract progress hook; every concrete model must override it.
#
# +train+ calls it with +message:+ at each stage. +train+ also installs it
# as the Torch wrapper's +iteration_callback+.
#
# @param progress [Object, nil] progress indicator (unused here)
# @param message [String, nil] human-readable status (unused here)
# @param error_rate [Float] current error rate, NaN when unknown (unused here)
# @raise [NotImplementedError] always, until a subclass overrides it
def progress_callback(progress: nil, message: nil, error_rate: Float::NAN)
  raise NotImplementedError, "Implement this method in your Model"
end

.setup_nn(options = {}) ⇒ Object



# File 'lib/torchrb/model_base.rb', lines 9-31

# Configures the model's network.
#
# The method:
# - validates +options+;
# - merges them over the built-in defaults;
# - exposes every resulting setting as a class-level reader (e.g.
#   +enable_cuda+, +network_storage_path+);
# - creates the Torch wrapper;
# - loads the net and trainer extensions.
#
# @param options [Hash] settings that override the defaults below. They are
#   checked by +check_options+ (presumably against REQUIRED_OPTIONS, i.e.
#   +:data_model+ must be present — confirm).
# @return [Object] the result of +load_extension+ for the trainer
def setup_nn(options = {})
  check_options(options)

  settings = {
      net: Torchrb::NN::Basic,
      trainer: Torchrb::NN::TrainerDefault,
      tensor_type: "DoubleTensor",
      dimensions: [0],
      classes: [],
      dataset_split: [80, 10, 10],
      normalize: false,
      enable_cuda: false,
      auto_store_trained_network: true,
      network_storage_path: "tmp/cache/torchrb",
      debug: false,
  }.merge(options)

  settings.each do |option, value|
    # One reader plus its backing class variable per effective setting.
    cattr_reader(option)
    class_variable_set(:"@@#{option}", value)
  end

  # NOTE(review): the Torch wrapper receives only the caller-supplied options,
  # not the merged defaults. Confirm that Torchrb::Torch applies its own defaults.
  cattr_reader(:torch) { Torchrb::Torch.new options }

  # Read from the merged settings so the default net/trainer are used when the
  # caller omits :net or :trainer. Reading the raw options would pass nil here.
  @net_options = load_extension(settings[:net])
  @trainer_options = load_extension(settings[:trainer])
end

.train ⇒ Object

# File 'lib/torchrb/model_base.rb', lines 37-57

# Runs a full training session.
#
# The steps are: load the data, define the network and trainer, train, then
# report the results.
#
# Status messages go through +progress_callback+, which is also installed as
# the Torch wrapper's per-iteration callback.
#
# Each optional step is gated separately:
# - CUDA conversion runs only when +enable_cuda+ is truthy.
# - Storing the trained network runs only when +auto_store_trained_network+ is truthy.
# - The +after_training+ hook runs only when the model responds to it.
#
# @return [Object] +torch.error_rate+ once training has finished
def train
  progress_callback(message: 'Loading data')
  load_model_data

  torch.iteration_callback = method(:progress_callback)
  define_nn(@net_options)
  define_trainer(@trainer_options)
  torch.cudify if enable_cuda

  progress_callback(message: 'Start training')
  torch.train
  progress_callback(message: 'Done')

  torch.print_results
  if auto_store_trained_network
    torch.store_network(network_storage_path)
  end

  # Optional hook: only models that define after_training get it invoked.
  after_training if respond_to?(:after_training)

  torch.error_rate
end