Module: Svm::CrossValidation

Included in:
Problem
Defined in:
lib/svm/cross_validation.rb

Instance Method Summary

Instance Method Details

#cross_validate(n_folds = 5, more_options = nil) ⇒ Object



10
11
12
13
14
15
16
17
18
# File 'lib/svm/cross_validation.rb', line 10

# Runs libsvm's n-fold cross-validation over this problem and returns the
# prediction produced for every sample.
#
# @param n_folds [Integer] number of folds handed to Svm.svm_cross_validation
# @param more_options [Hash, nil] when truthy, applied via +set+ before the run
#   (NOTE(review): +set+ is not scoped to this call, so these options
#   presumably stay on the problem afterwards — confirm)
# @return [Array<Float>] one predicted value per sample, read back from the
#   native double buffer libsvm filled in
def cross_validate(n_folds = 5, more_options = nil)
  set(more_options) if more_options

  sample_count = num_samples
  prediction_buffer = FFI::MemoryPointer.new(:double, sample_count)
  svm_params = options.parameter_struct

  Svm.svm_cross_validation(problem_struct, svm_params, n_folds, prediction_buffer)
  prediction_buffer.read_array_of_double(sample_count)
end

#find_best_parameters(n_folds = 5) ⇒ Object



20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# File 'lib/svm/cross_validation.rb', line 20

# Grid-searches the C and gamma parameters over powers of two and returns
# the pair with the highest results_for_cross_validation score.
#
# @param n_folds [Integer] folds used for every candidate's cross-validation
# @param c_exponents [Enumerable<Integer>] exponents e tried as C = 2.0**e
# @param gamma_exponents [Enumerable<Integer>] exponents e tried as gamma = 2.0**e
# @return [Hash] {:c => Float, :gamma => Float}; on a tie the first pair in
#   (c_exponents x gamma_exponents) order wins
# @raise [ArgumentError] if either exponent list is empty
#
# NOTE(review): every candidate is passed through cross_validate to +set+,
# so the problem's options are presumably left at the LAST pair tried, not
# at the best one — confirm, and re-apply the returned hash if needed.
def find_best_parameters(n_folds = 5, c_exponents = (-1..14), gamma_exponents = (-13..-1))
  combinations = c_exponents.to_a.product(gamma_exponents.to_a)
  raise ArgumentError, "c_exponents and gamma_exponents must not be empty" if combinations.empty?

  # 2.0 ** e keeps every value a Float. Integer ** negative Integer returns a
  # Rational (2 ** -1 #=> (1/2)), which would otherwise reach +set+ and the
  # returned hash.
  best_c_exp, best_gamma_exp = combinations.max_by do |c_exp, gamma_exp|
    results_for_cross_validation(n_folds, :c => 2.0 ** c_exp, :gamma => 2.0 ** gamma_exp)
  end

  {:c => 2.0 ** best_c_exp, :gamma => 2.0 ** best_gamma_exp}
end

#results_for_cross_validation(n_folds = 5, custom_options = nil) ⇒ Object



4
5
6
7
8
# File 'lib/svm/cross_validation.rb', line 4

# Scores one cross-validation run as the summed weight of every sample whose
# cross-validated prediction equals its own value.
#
# @param n_folds [Integer] number of folds, forwarded to cross_validate
# @param custom_options [Hash, nil] options forwarded to cross_validate
# @return [Numeric] sum of weight_for(i) over samples with value(i) == prediction;
#   0 when nothing matches or when there are no samples at all
#
# NOTE(review): the comparison is exact (==), which suits discrete class
# labels; for regression targets it will rarely match — confirm intended use.
def results_for_cross_validation(n_folds = 5, custom_options = nil)
  predictions = cross_validate(n_folds, custom_options)

  # Seeding the fold with 0 makes an empty problem score 0 instead of nil
  # (unseeded inject(:+) over zero elements returns nil).
  num_samples.times.inject(0) do |score, i|
    value(i) == predictions[i] ? score + weight_for(i) : score
  end
end