Class: Rumale::ModelSelection::GridSearchCV

Inherits:
Object
Includes:
Base::BaseEstimator
Defined in:
lib/rumale/model_selection/grid_search_cv.rb

Overview

GridSearchCV performs hyperparameter optimization by grid search: it evaluates every parameter combination in the grid with cross validation and keeps the best one.

Examples:

require 'rumale'

# samples and labels are assumed to be prepared beforehand
# (a Numo::DFloat sample matrix and the matching labels).

# Grid search over the hyperparameters of a single estimator.
rfc = Rumale::Ensemble::RandomForestClassifier.new(random_seed: 1)
pg = { n_estimators: [5, 10], max_depth: [3, 5], max_leaf_nodes: [15, 31] }
kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
gs = Rumale::ModelSelection::GridSearchCV.new(estimator: rfc, param_grid: pg, splitter: kf)
gs.fit(samples, labels)
p gs.cv_results
p gs.best_params

# Grid search over a pipeline: each key is <step name>__<parameter name>.
rbf = Rumale::KernelApproximation::RBF.new(random_seed: 1)
svc = Rumale::LinearModel::SVC.new(random_seed: 1)
pipe = Rumale::Pipeline::Pipeline.new(steps: { rbf: rbf, svc: svc })
pg = { rbf__gamma: [32.0, 1.0], rbf__n_components: [4, 128], svc__reg_param: [16.0, 0.1] }
kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
gs = Rumale::ModelSelection::GridSearchCV.new(estimator: pipe, param_grid: pg, splitter: kf)
gs.fit(samples, labels)
p gs.cv_results
p gs.best_params

Instance Attribute Summary

Attributes included from Base::BaseEstimator

#params

Constructor Details

#initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true) ⇒ GridSearchCV

Create a new grid search object.

Parameters:

  • estimator (Classifier/Regressor) (defaults to: nil)

    The estimator whose hyperparameters are tuned by grid search.

  • param_grid (Array<Hash>) (defaults to: nil)

    The parameter grid: an array of hashes, each mapping a parameter name to an array of candidate values. The examples above pass a single hash.

  • splitter (Splitter) (defaults to: nil)

    The splitter that divides the dataset into training and testing sets for cross validation.

  • evaluator (Evaluator) (defaults to: nil)

    The evaluator that scores the estimator's results during cross validation. If nil is given, the estimator's own score method is used for evaluation.

  • greater_is_better (Boolean) (defaults to: true)

    Whether a larger evaluation score indicates a better estimator.
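
As a rough illustration of what a parameter grid describes, the sketch below expands one hash of candidate values into every combination. `expand_grid` is a hypothetical helper written for this example; it is not Rumale's implementation, which this page does not show.

```ruby
# Illustrative only: expand one grid hash into every parameter combination.
# `expand_grid` is a hypothetical helper, not part of Rumale's API.
# It assumes the grid has at least one key.
def expand_grid(grid)
  keys = grid.keys
  first, *rest = grid.values
  # Cartesian product of the candidate lists, re-keyed by parameter name.
  first.product(*rest).map { |values| keys.zip(values).to_h }
end

pg = { n_estimators: [5, 10], max_depth: [3, 5], max_leaf_nodes: [15, 31] }
combinations = expand_grid(pg)

puts combinations.size # 2 * 2 * 2 candidates
p combinations.first
```

The number of candidates is the product of the list lengths, so each added parameter multiplies the number of cross-validation runs.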



# File 'lib/rumale/model_selection/grid_search_cv.rb', line 67

def initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true)
  check_params_type(Rumale::Base::BaseEstimator, estimator: estimator)
  check_params_type(Rumale::Base::Splitter, splitter: splitter)
  check_params_type_or_nil(Rumale::Base::Evaluator, evaluator: evaluator)
  check_params_boolean(greater_is_better: greater_is_better)
  @params = {}
  @params[:param_grid] = valid_param_grid(param_grid)
  @params[:estimator] = Marshal.load(Marshal.dump(estimator))
  @params[:splitter] = Marshal.load(Marshal.dump(splitter))
  @params[:evaluator] = Marshal.load(Marshal.dump(evaluator))
  @params[:greater_is_better] = greater_is_better
  @cv_results = nil
  @best_score = nil
  @best_params = nil
  @best_index = nil
  @best_estimator = nil
end
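
The constructor stores `Marshal.load(Marshal.dump(...))` copies of the estimator, splitter, and evaluator, so later changes to the objects passed in should not affect the search. The sketch below shows the difference between that round trip and a shallow `dup`, using a plain nested hash as a stand-in (an assumption for illustration; no Rumale objects are involved).

```ruby
# A nested hash stands in for an object that holds its own parameters.
settings = { estimator: { reg_param: 1.0 } }

shallow = settings.dup                      # copies only the outer hash
deep = Marshal.load(Marshal.dump(settings)) # serialization round trip copies everything

deep[:estimator][:reg_param] = 16.0
after_deep_change = settings[:estimator][:reg_param]    # original is untouched

shallow[:estimator][:reg_param] = 0.1
after_shallow_change = settings[:estimator][:reg_param] # original changed too

# nil survives the round trip, which is why a nil evaluator is fine.
nil_copy = Marshal.load(Marshal.dump(nil))

p after_deep_change
p after_shallow_change
p nil_copy
```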

Instance Attribute Details

#best_estimator ⇒ Estimator (readonly)

Return the estimator trained with the best parameter set.

Returns:

  • (Estimator)


# File 'lib/rumale/model_selection/grid_search_cv.rb', line 55

def best_estimator
  @best_estimator
end

#best_index ⇒ Integer (readonly)

Return the index of the best parameter set.

Returns:

  • (Integer)


# File 'lib/rumale/model_selection/grid_search_cv.rb', line 51

def best_index
  @best_index
end

#best_params ⇒ Hash (readonly)

Return the best parameter set.

Returns:

  • (Hash)


# File 'lib/rumale/model_selection/grid_search_cv.rb', line 47

def best_params
  @best_params
end

#best_score ⇒ Float (readonly)

Return the score of the estimator trained with the best parameter set.

Returns:

  • (Float)


# File 'lib/rumale/model_selection/grid_search_cv.rb', line 43

def best_score
  @best_score
end

#cv_results ⇒ Hash (readonly)

Return the cross-validation results for each parameter set.

Returns:

  • (Hash)


# File 'lib/rumale/model_selection/grid_search_cv.rb', line 39

def cv_results
  @cv_results
end

Instance Method Details

#decision_function(x) ⇒ Numo::DFloat

Call the decision_function method of the estimator trained with the best parameter set.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples to compute the scores for.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples]) Confidence score per sample.



# File 'lib/rumale/model_selection/grid_search_cv.rb', line 113

def decision_function(x)
  check_sample_array(x)
  @best_estimator.decision_function(x)
end

#fit(x, y) ⇒ GridSearchCV

Run cross validation for every parameter set on the given training data, then refit the estimator with the best parameter set on the full training data.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The training data to be used for fitting the model.

  • y (Numo::NArray)

    (shape: [n_samples, n_outputs]) The target values or labels to be used for fitting the model.

Returns:

  • (GridSearchCV)

    The fitted GridSearchCV itself.



# File 'lib/rumale/model_selection/grid_search_cv.rb', line 90

def fit(x, y)
  check_sample_array(x)

  init_attrs

  param_combinations.each do |prm_set|
    prm_set.each do |prms|
      report = perform_cross_validation(x, y, prms)
      store_cv_result(prms, report)
    end
  end

  find_best_params

  @best_estimator = configurated_estimator(@best_params)
  @best_estimator.fit(x, y)
  self
end
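
The body of `find_best_params` is not shown on this page. As a minimal sketch of how the `greater_is_better` flag could drive the selection, assume each candidate ends up with one mean test score; `best_index_of` and the scores below are hypothetical.

```ruby
# Hypothetical mean test scores, one per parameter candidate.
mean_scores = [0.81, 0.93, 0.88]

# Pick the index of the best candidate. With greater_is_better: true the
# largest score wins (e.g. accuracy); with false the smallest wins (e.g. an error).
def best_index_of(scores, greater_is_better:)
  target = greater_is_better ? scores.max : scores.min
  scores.index(target)
end

best_when_maximizing = best_index_of(mean_scores, greater_is_better: true)
best_when_minimizing = best_index_of(mean_scores, greater_is_better: false)

p best_when_maximizing
p best_when_minimizing
```

When the evaluator returns an error measure rather than a score, passing `greater_is_better: false` keeps the search from selecting the worst candidate.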

#marshal_dump ⇒ Hash

Dump marshal data.

Returns:

  • (Hash)

    The marshal data about GridSearchCV.



# File 'lib/rumale/model_selection/grid_search_cv.rb', line 157

def marshal_dump
  { params: @params,
    cv_results: @cv_results,
    best_score: @best_score,
    best_params: @best_params,
    best_index: @best_index,
    best_estimator: @best_estimator }
end

#marshal_load(obj) ⇒ nil

Load marshal data.

Returns:

  • (nil)


# File 'lib/rumale/model_selection/grid_search_cv.rb', line 168

def marshal_load(obj)
  @params = obj[:params]
  @cv_results = obj[:cv_results]
  @best_score = obj[:best_score]
  @best_params = obj[:best_params]
  @best_index = obj[:best_index]
  @best_estimator = obj[:best_estimator]
  nil
end
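
Ruby's `Marshal.dump` and `Marshal.load` call the `marshal_dump` and `marshal_load` hooks when an object defines them, which is what makes a fitted search serializable. The round trip can be sketched with a small stand-in class that mirrors the two hooks; `SearchState` is hypothetical and holds only two of the fields.

```ruby
# SearchState is a hypothetical stand-in, not a Rumale class.
class SearchState
  attr_reader :best_params, :best_score

  def initialize(best_params, best_score)
    @best_params = best_params
    @best_score = best_score
  end

  # Called by Marshal.dump; returns the state to serialize.
  def marshal_dump
    { best_params: @best_params, best_score: @best_score }
  end

  # Called by Marshal.load on a freshly allocated object.
  def marshal_load(obj)
    @best_params = obj[:best_params]
    @best_score = obj[:best_score]
    nil
  end
end

state = SearchState.new({ max_depth: 5 }, 0.93)
bytes = Marshal.dump(state)     # a String that could be written to a file
restored = Marshal.load(bytes)

p restored.best_params
p restored.best_score
```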

#predict(x) ⇒ Numo::NArray

Call the predict method of the estimator trained with the best parameter set.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples to predict.

Returns:

  • (Numo::NArray)

    Predicted results.



# File 'lib/rumale/model_selection/grid_search_cv.rb', line 122

def predict(x)
  check_sample_array(x)
  @best_estimator.predict(x)
end

#predict_log_proba(x) ⇒ Numo::DFloat

Call the predict_log_proba method of the estimator trained with the best parameter set.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples to predict the log-probabilities for.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.



# File 'lib/rumale/model_selection/grid_search_cv.rb', line 131

def predict_log_proba(x)
  check_sample_array(x)
  @best_estimator.predict_log_proba(x)
end

#predict_proba(x) ⇒ Numo::DFloat

Call the predict_proba method of the estimator trained with the best parameter set.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples to predict the probabilities for.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples, n_classes]) Predicted probability of each class per sample.



# File 'lib/rumale/model_selection/grid_search_cv.rb', line 140

def predict_proba(x)
  check_sample_array(x)
  @best_estimator.predict_proba(x)
end
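
The prediction methods of this class forward straight to `@best_estimator`, so they work only after `fit` has run and only when the wrapped estimator defines the method being called. The sketch below reproduces that delegation with hypothetical stand-ins (`LabelOnlyEstimator` and `Wrapper` are not Rumale classes, and the input check done by `check_sample_array` is left out).

```ruby
# Hypothetical estimator that offers predict but not predict_proba.
class LabelOnlyEstimator
  def predict(x)
    x.map { |row| row.sum >= 0 ? 1 : -1 }
  end
end

# Hypothetical wrapper that delegates the same way the methods above do.
class Wrapper
  def initialize(best_estimator = nil)
    @best_estimator = best_estimator # nil stands for "fit has not run yet"
  end

  def predict(x)
    @best_estimator.predict(x)
  end

  def predict_proba(x)
    @best_estimator.predict_proba(x)
  end
end

fitted = Wrapper.new(LabelOnlyEstimator.new)
predicted = fitted.predict([[1.0, 2.0], [-3.0, 0.5]])

# The wrapped estimator has no predict_proba, so the call fails.
unsupported = begin
  fitted.predict_proba([[1.0, 2.0]])
rescue NoMethodError => e
  e.name
end

# Before fitting, the wrapped estimator is nil, so any call fails.
unfitted = begin
  Wrapper.new.predict([[1.0, 2.0]])
rescue NoMethodError => e
  e.name
end

p predicted
p unsupported
p unfitted
```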

#score(x, y) ⇒ Float

Call the score method of the estimator trained with the best parameter set.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) Testing data.

  • y (Numo::NArray)

    (shape: [n_samples, n_outputs]) True target values or labels for testing data.

Returns:

  • (Float)

    The score of the best estimator.



# File 'lib/rumale/model_selection/grid_search_cv.rb', line 150

def score(x, y)
  check_sample_array(x)
  @best_estimator.score(x, y)
end