Class: Rumale::ModelSelection::GridSearchCV

Inherits:
Object
  • Object
show all
Includes:
Base::BaseEstimator
Defined in:
lib/rumale/model_selection/grid_search_cv.rb

Overview

GridSearchCV is a class that performs hyperparameter optimization using the grid search method.

Examples:

# Example 1: grid search over the hyperparameters of a random forest classifier.
# `samples` and `labels` are assumed to be defined by the caller.
rfc = Rumale::Ensemble::RandomForestClassifier.new(random_seed: 1)
pg = { n_estimators: [5, 10], max_depth: [3, 5], max_leaf_nodes: [15, 31] }
kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
gs = Rumale::ModelSelection::GridSearchCV.new(estimator: rfc, param_grid: pg, splitter: kf)
gs.fit(samples, labels)
p gs.cv_results
p gs.best_params

# Example 2: grid search over a pipeline (RBF kernel approximation followed by a linear SVC).
# The param_grid keys are prefixed with the pipeline step name (`rbf__`, `svc__`).
rbf = Rumale::KernelApproximation::RBF.new(random_seed: 1)
svc = Rumale::LinearModel::SVC.new(random_seed: 1)
pipe = Rumale::Pipeline::Pipeline.new(steps: { rbf: rbf, svc: svc })
pg = { rbf__gamma: [32.0, 1.0], rbf__n_components: [4, 128], svc__reg_param: [16.0, 0.1] }
kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
gs = Rumale::ModelSelection::GridSearchCV.new(estimator: pipe, param_grid: pg, splitter: kf)
gs.fit(samples, labels)
p gs.cv_results
p gs.best_params

Instance Attribute Summary collapse

Attributes included from Base::BaseEstimator

#params

Instance Method Summary collapse

Constructor Details

#initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true) ⇒ GridSearchCV

Create a new grid search method.



67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# File 'lib/rumale/model_selection/grid_search_cv.rb', line 67

# Create a new grid search method.
#
# @param estimator [Object] estimator to tune; checked against Rumale::Base::BaseEstimator.
# @param param_grid [Hash] candidate values per parameter (see the usage examples);
#   stored as returned by valid_param_grid.
# @param splitter [Object] cross-validation splitter; checked against Rumale::Base::Splitter.
# @param evaluator [Object, nil] optional evaluator; checked against Rumale::Base::Evaluator when given.
# @param greater_is_better [Boolean] stored as-is in the parameter hash.
def initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true)
  # Argument checks run before any state is stored; only the evaluator goes
  # through the "or nil" variant of the type check.
  check_params_type(Rumale::Base::BaseEstimator, estimator: estimator)
  check_params_type(Rumale::Base::Splitter, splitter: splitter)
  check_params_type_or_nil(Rumale::Base::Evaluator, evaluator: evaluator)
  check_params_boolean(greater_is_better: greater_is_better)

  # A Marshal round trip gives this instance its own copy of each collaborator.
  deep_copy = ->(source) { Marshal.load(Marshal.dump(source)) }

  @params = {
    param_grid: valid_param_grid(param_grid),
    estimator: deep_copy.call(estimator),
    splitter: deep_copy.call(splitter),
    evaluator: deep_copy.call(evaluator),
    greater_is_better: greater_is_better
  }

  # Search results start out empty. The chain assigns right to left, so
  # @cv_results is still the first of these variables to be set.
  @best_estimator = @best_index = @best_params = @best_score = @cv_results = nil
end

Instance Attribute Details

#best_estimator ⇒ Estimator (readonly)

Return the estimator learned with the best parameter.



55
56
57
# File 'lib/rumale/model_selection/grid_search_cv.rb', line 55

# Return the estimator learned with the best parameter.
# The constructor initializes this to nil; #fit assigns it.
#
# @return [Estimator]
def best_estimator
  @best_estimator
end

#best_index ⇒ Integer (readonly)

Return the index of the best parameter.



51
52
53
# File 'lib/rumale/model_selection/grid_search_cv.rb', line 51

# Return the index of the best parameter.
# The constructor initializes this to nil.
#
# @return [Integer]
def best_index
  @best_index
end

#best_params ⇒ Hash (readonly)

Return the best parameter set.



47
48
49
# File 'lib/rumale/model_selection/grid_search_cv.rb', line 47

# Return the best parameter set.
# The constructor initializes this to nil.
#
# @return [Hash]
def best_params
  @best_params
end

#best_score ⇒ Float (readonly)

Return the score of the estimator learned with the best parameter.



43
44
45
# File 'lib/rumale/model_selection/grid_search_cv.rb', line 43

# Return the score of the estimator learned with the best parameter.
# The constructor initializes this to nil.
#
# @return [Float]
def best_score
  @best_score
end

#cv_results ⇒ Hash (readonly)

Return the result of cross validation for each parameter.



39
40
41
# File 'lib/rumale/model_selection/grid_search_cv.rb', line 39

# Return the result of cross validation for each parameter.
# The constructor initializes this to nil.
#
# @return [Hash]
def cv_results
  @cv_results
end

Instance Method Details

#decision_function(x) ⇒ Numo::DFloat

Call the decision_function method of learned estimator with the best parameter.



113
114
115
116
# File 'lib/rumale/model_selection/grid_search_cv.rb', line 113

# Call the decision_function method of the estimator learned with the best parameter.
#
# Requires a fitted instance: @best_estimator is nil after construction and is
# assigned in #fit.
#
# @param x [Object] samples; passed to check_sample_array, then forwarded unchanged.
# @return [Numo::DFloat] whatever the best estimator's decision_function returns.
def decision_function(x)
  check_sample_array(x)
  @best_estimator.decision_function(x)
end

#fit(x, y) ⇒ GridSearchCV

Fit the model with given training data and all sets of parameters.



90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# File 'lib/rumale/model_selection/grid_search_cv.rb', line 90

# Fit the model with given training data and all sets of parameters.
#
# Each candidate parameter set is cross-validated and its result stored. Once the
# best set has been chosen, an estimator configured with it is fitted on (x, y).
#
# @param x [Object] training samples; passed to check_sample_array first.
# @param y [Object] training targets, forwarded unchanged.
# @return [GridSearchCV] the receiver itself.
def fit(x, y)
  check_sample_array(x)
  init_attrs

  # param_combinations yields groups of candidates; evaluate every candidate in every group.
  param_combinations.each do |candidates|
    candidates.each { |candidate| store_cv_result(candidate, perform_cross_validation(x, y, candidate)) }
  end

  find_best_params

  # Assign first, then fit, so the attribute is set even if fitting raises.
  @best_estimator = configurated_estimator(@best_params)
  @best_estimator.fit(x, y)
  self
end

#marshal_dump ⇒ Hash

Dump marshal data.



157
158
159
160
161
162
163
164
# File 'lib/rumale/model_selection/grid_search_cv.rb', line 157

# Dump marshal data.
#
# @return [Hash] the parameters and search results, keyed by symbol in the order
#   :params, :cv_results, :best_score, :best_params, :best_index, :best_estimator.
def marshal_dump
  %i[params cv_results best_score best_params best_index best_estimator].each_with_object({}) do |key, dumped|
    # Each key maps to the instance variable of the same name.
    dumped[key] = instance_variable_get("@#{key}")
  end
end

#marshal_load(obj) ⇒ nil

Load marshal data.



168
169
170
171
172
173
174
175
176
# File 'lib/rumale/model_selection/grid_search_cv.rb', line 168

# Load marshal data.
#
# @param obj [Hash] hash with the keys :params, :cv_results, :best_score,
#   :best_params, :best_index and :best_estimator (the shape marshal_dump produces).
# @return [nil]
def marshal_load(obj)
  # Restore the six instance variables in one step, in marshal_dump's key order.
  @params, @cv_results, @best_score, @best_params, @best_index, @best_estimator =
    obj.values_at(:params, :cv_results, :best_score, :best_params, :best_index, :best_estimator)
  nil
end

#predict(x) ⇒ Numo::NArray

Call the predict method of learned estimator with the best parameter.



122
123
124
125
# File 'lib/rumale/model_selection/grid_search_cv.rb', line 122

# Call the predict method of the estimator learned with the best parameter.
#
# Requires a fitted instance: @best_estimator is nil after construction and is
# assigned in #fit.
#
# @param x [Object] samples; passed to check_sample_array, then forwarded unchanged.
# @return [Numo::NArray] whatever the best estimator's predict returns.
def predict(x)
  check_sample_array(x)
  @best_estimator.predict(x)
end

#predict_log_proba(x) ⇒ Numo::DFloat

Call the predict_log_proba method of learned estimator with the best parameter.



131
132
133
134
# File 'lib/rumale/model_selection/grid_search_cv.rb', line 131

# Call the predict_log_proba method of the estimator learned with the best parameter.
#
# Requires a fitted instance: @best_estimator is nil after construction and is
# assigned in #fit.
#
# @param x [Object] samples; passed to check_sample_array, then forwarded unchanged.
# @return [Numo::DFloat] whatever the best estimator's predict_log_proba returns.
def predict_log_proba(x)
  check_sample_array(x)
  @best_estimator.predict_log_proba(x)
end

#predict_proba(x) ⇒ Numo::DFloat

Call the predict_proba method of learned estimator with the best parameter.



140
141
142
143
# File 'lib/rumale/model_selection/grid_search_cv.rb', line 140

# Call the predict_proba method of the estimator learned with the best parameter.
#
# Requires a fitted instance: @best_estimator is nil after construction and is
# assigned in #fit.
#
# @param x [Object] samples; passed to check_sample_array, then forwarded unchanged.
# @return [Numo::DFloat] whatever the best estimator's predict_proba returns.
def predict_proba(x)
  check_sample_array(x)
  @best_estimator.predict_proba(x)
end

#score(x, y) ⇒ Float

Call the score method of learned estimator with the best parameter.



150
151
152
153
# File 'lib/rumale/model_selection/grid_search_cv.rb', line 150

# Call the score method of the estimator learned with the best parameter.
#
# Requires a fitted instance: @best_estimator is nil after construction and is
# assigned in #fit.
#
# @param x [Object] samples; passed to check_sample_array, then forwarded unchanged.
# @param y [Object] true values/labels; forwarded unchanged (no check is applied here).
# @return [Float] whatever the best estimator's score returns.
def score(x, y)
  check_sample_array(x)
  @best_estimator.score(x, y)
end