Class: Torchrb::Wrapper

Inherits:
Torch (show all ancestors)
Defined in:
lib/torchrb/wrapper.rb

Instance Attribute Summary

Attributes inherited from Torch

#error_rate, #network_loaded, #network_timestamp

Attributes inherited from Lua

#debug, #enable_cuda

Class Method Summary

Instance Method Summary

Methods inherited from Torch

#cudify, #iteration_callback=, #load_network, #print_results, #store_network

Methods inherited from Lua

#eval

Instance Attribute Details

#model ⇒ Object (readonly)

Returns the value of attribute model.



Source (lines 13–15):
# File 'lib/torchrb/wrapper.rb', line 13

# Reader for the +@model+ instance variable (nil until something assigns it).
def model; @model; end

#progress ⇒ Object (readonly)

Returns the value of attribute progress.



Source (lines 13–15):
# File 'lib/torchrb/wrapper.rb', line 13

# Reader for the +@progress+ instance variable (nil until something assigns it).
def progress; @progress; end

Class Method Details

.for(model_class, options = {}) ⇒ Object



Source (lines 4–11):
# File 'lib/torchrb/wrapper.rb', line 4

# Returns the cached wrapper for +model_class+, creating it with
# +new(model_class, options)+ the first time that key is requested.
#
# NOTE: +options+ only matter on the call that creates the instance; later
# calls for the same +model_class+ return the cached wrapper unchanged.
# NOTE(review): +@@instances+ is a class variable, so the cache is shared with
# every subclass in the inheritance tree — confirm that is intended.
#
# @param model_class [Object] cache key, also passed to the constructor
# @param options [Hash] constructor options (used on first creation only)
# @yieldparam wrapper the cached instance, when a block is given
# @return the block's value when a block is given, otherwise the wrapper
def self.for(model_class, options = {})
  wrapper = (@@instances[model_class] ||= new(model_class, options))
  return yield(wrapper) if block_given?

  wrapper
end

Instance Method Details

#load_model_data ⇒ Object



Source (lines 26–31):
# File 'lib/torchrb/wrapper.rb', line 26

# Sets +@progress+ back to 0, then loads the train, test and validation splits
# (in that order) through +load_dataset+.
#
# @return whatever the final +load_dataset(:validation_set)+ call returns
def load_model_data
  @progress = 0
  %i[train_set test_set validation_set].map { |split| load_dataset(split) }.last
end

#predict(sample) ⇒ Object



Source (lines 44–46):
# File 'lib/torchrb/wrapper.rb', line 44

# Delegates to the superclass implementation. Bare +super+ re-sends the same
# +sample+ argument (and any block) unchanged.
#
# @param sample input handed straight to the parent +predict+
# @return whatever the parent +predict+ returns
def predict(sample)
  super
end

#train ⇒ Object



Source (lines 33–42):
# File 'lib/torchrb/wrapper.rb', line 33

# Sets up the network and trainer, then runs the superclass #train.
# Afterwards it calls +print_results+ and +store_network+.
#
# @return the value of +error_rate+ (inherited from Torch) after training
def train
  define_nn
  define_trainer

  # Only call cudify when the enable_cuda attribute (inherited from Lua) is truthy.
  cudify if enable_cuda
  # Bare super: calls the parent #train with no arguments, forwarding any block.
  super
  print_results
  store_network
  error_rate
end