Class: SvmToolkit::Svm::CrossValidationSearch

Inherits:
RecursiveTask
  • Object
show all
Defined in:
lib/svm_toolkit/svm.rb

Overview

Sets up a cross-validation search over every combination of the given cost and gamma values.

Instance Method Summary collapse

Constructor Details

#initialize(gammas, costs, training_set, cross_valn_set, evaluator) ⇒ CrossValidationSearch

Creates an instance of the CrossValidationSearch.

gammas

array of gamma values to search over

costs

array of cost values to search over

training_set

for building the model

cross_valn_set

for testing the model

evaluator

name of Evaluator class, used for evaluating the model



70
71
72
73
74
75
76
77
78
# File 'lib/svm_toolkit/svm.rb', line 70

# Captures the search grid, the two data sets and the evaluator so that
# #compute can use them later; no training happens at construction time.
#
# @param gammas [Enumerable] gamma values to search over
# @param costs [Enumerable] cost values to search over
# @param training_set data set used to build each model
# @param cross_valn_set data set used to test each model
# @param evaluator evaluator handed to every trainer for scoring the model
def initialize(gammas, costs, training_set, cross_valn_set, evaluator)
  # Explicit empty parens: call the parent constructor with no arguments
  # rather than forwarding this method's five arguments.
  super()

  @gammas, @costs = gammas, costs
  @training_set, @cross_valn_set = training_set, cross_valn_set
  @evaluator = evaluator
end

Instance Method Details

#compute ⇒ Object

Performs the actual computation and returns the grid of results together with the best model found.



81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# File 'lib/svm_toolkit/svm.rb', line 81

# Trains and evaluates one C-SVC/RBF model per (gamma, cost) combination
# and gathers the scores.
#
# Every trainer is forked before any is joined, so the individual training
# runs can proceed concurrently. Results are then collected in the same
# gamma-major order in which the trainers were created, printing one line
# per combination.
#
# @return [Array] a two-element array +[results, best_model]+:
#   +results+ has one row per gamma, holding the +value+ of each cost's
#   result in cost order; +best_model+ is the model whose result last won
#   the +better_than?+ comparison (stays nil if none ever wins, e.g. when
#   the grid is empty)
def compute
  # One trainer per combination, gamma-major, matching the grid layout below.
  trainers = @gammas.flat_map do |gamma|
    @costs.map do |cost|
      parameter = Parameter.new(
        :svm_type => Parameter::C_SVC,
        :kernel_type => Parameter::RBF,
        :cost => cost,
        :gamma => gamma
      )
      SvmTrainer.new(@training_set, parameter, @cross_valn_set, @evaluator)
    end
  end

  # Schedule everything up front so the trainers can run concurrently.
  trainers.each(&:fork)

  best_model = nil
  best_result = nil
  position = 0

  grid = @gammas.map do |gamma|
    @costs.map do |cost|
      model, result = trainers[position].join
      position += 1

      # On the first comparison better_than? receives nil.
      if result.better_than?(best_result)
        best_model = model
        best_result = result
      end
      puts "Result for cost = #{cost}  gamma = #{gamma} is #{result.value}"
      result.value
    end
  end

  [grid, best_model]
end