Class: Rumale::ModelSelection::CrossValidation

Inherits:
Object
  • Object
show all
Defined in:
lib/rumale/model_selection/cross_validation.rb

Overview

CrossValidation is a class that evaluates a given estimator (classifier or regressor) with the cross-validation method.

Examples:

# Evaluate a linear SVC with stratified 5-fold cross-validation.
estimator = Rumale::LinearModel::SVC.new
splitter = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5)
cv = Rumale::ModelSelection::CrossValidation.new(estimator: estimator, splitter: splitter)
report = cv.perform(samples, labels)
# Average the per-fold test scores.
mean_test_score = report[:test_score].reduce(:+) / splitter.n_splits

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(estimator: nil, splitter: nil, evaluator: nil, return_train_score: false) ⇒ CrossValidation

Create a new evaluator that uses the cross-validation method.



48
49
50
51
52
53
54
55
56
57
# File 'lib/rumale/model_selection/cross_validation.rb', line 48

# Create a new evaluator that uses the cross-validation method.
#
# @param estimator [BaseEstimator] The estimator whose performance is evaluated.
#   NOTE(review): although the keyword defaults to nil, it is validated with check_params_type
#   (not the _or_nil variant), so nil is presumably rejected — confirm.
# @param splitter [Splitter] The splitter that divides the dataset into training and testing subsets.
#   Validated the same way as estimator.
# @param evaluator [Evaluator, nil] The evaluator that calculates scores.
#   When nil, #perform uses the estimator's own #score method.
# @param return_train_score [Boolean] Whether to also calculate the score on the training subset.
def initialize(estimator: nil, splitter: nil, evaluator: nil, return_train_score: false)
  # Validate the arguments in declaration order before storing any state.
  check_params_type(Rumale::Base::BaseEstimator, estimator: estimator)
  check_params_type(Rumale::Base::Splitter, splitter: splitter)
  check_params_type_or_nil(Rumale::Base::Evaluator, evaluator: evaluator)
  check_params_boolean(return_train_score: return_train_score)
  # Keep the validated settings for #perform.
  @estimator, @splitter, @evaluator, @return_train_score =
    estimator, splitter, evaluator, return_train_score
end

Instance Attribute Details

#estimator ⇒ BaseEstimator (readonly)

Return the estimator whose performance is evaluated.



28
29
30
# File 'lib/rumale/model_selection/cross_validation.rb', line 28

# Return the estimator whose performance is evaluated.
# @return [BaseEstimator]
attr_reader :estimator

#evaluator ⇒ Evaluator (readonly)

Return the evaluator that calculates score.



36
37
38
# File 'lib/rumale/model_selection/cross_validation.rb', line 36

# Return the evaluator that calculates scores.
# When nil, #perform falls back to the estimator's own #score method.
# @return [Evaluator, nil]
attr_reader :evaluator

#return_train_score ⇒ Boolean (readonly)

Return the flag indicating whether to calculate the score on the training dataset.



40
41
42
# File 'lib/rumale/model_selection/cross_validation.rb', line 40

# Return the flag indicating whether to calculate the score on the training dataset.
# @return [Boolean]
attr_reader :return_train_score

#splitter ⇒ Splitter (readonly)

Return the splitter that divides the dataset.



32
33
34
# File 'lib/rumale/model_selection/cross_validation.rb', line 32

# Return the splitter that divides the dataset into training and testing subsets.
# @return [Splitter]
attr_reader :splitter

Instance Method Details

#perform(x, y) ⇒ Hash

Perform the evaluation of the given estimator with the cross-validation method.



70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# File 'lib/rumale/model_selection/cross_validation.rb', line 70

# Perform the evaluation of the given estimator with the cross-validation method.
#
# For each (train_ids, test_ids) pair yielded by the splitter, the estimator is fitted on the
# training subset and scored on the testing subset. The same estimator instance is refitted on
# every split, so after this method returns it reflects the fit on the last split
# (assuming #fit updates the estimator in place).
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The dataset used to evaluate the estimator.
#   When kernel_machine? is true, the columns are indexed by training sample ids
#   (presumably x is then a precomputed kernel matrix — confirm).
# @param y [Numo::NArray] The labels or target values, either 1-D (shape: [n_samples])
#   or 2-D (shape: [n_samples, n_outputs]).
# @return [Hash] The report summarizing the results of cross-validation.
#   * :fit_time (Array<Float>) Elapsed seconds spent fitting the estimator on each split.
#   * :test_score (Array<Float>) The score on the testing subset of each split.
#   * :train_score (Array<Float>) The score on the training subset of each split;
#     nil when return_train_score is false.
def perform(x, y)
  check_sample_array(x)
  if @estimator.is_a?(Rumale::Base::Classifier)
    check_label_array(y)
    check_sample_label_size(x, y)
  end
  if @estimator.is_a?(Rumale::Base::Regressor)
    check_tvalue_array(y)
    check_sample_tvalue_size(x, y)
  end
  # Initialize the report of cross validation.
  report = { test_score: [], train_score: nil, fit_time: [] }
  report[:train_score] = [] if @return_train_score
  # Score a subset: use the estimator's own #score when no evaluator is given;
  # otherwise feed the evaluator probabilities (log loss) or hard predictions.
  evaluate = lambda do |eval_x, eval_y|
    if @evaluator.nil?
      @estimator.score(eval_x, eval_y)
    elsif log_loss?
      @evaluator.score(eval_y, @estimator.predict_proba(eval_x))
    else
      @evaluator.score(eval_y, @estimator.predict(eval_x))
    end
  end
  # Evaluate the estimator on each split.
  @splitter.split(x, y).each do |train_ids, test_ids|
    # `true` selects every column; for kernel machines only the training-sample columns are kept.
    feature_ids = !kernel_machine? || train_ids
    train_x = x[train_ids, feature_ids]
    test_x = x[test_ids, feature_ids]
    # Support both 1-D and 2-D (multi-output) y.
    train_y = y.shape[1].nil? ? y[train_ids] : y[train_ids, true]
    test_y = y.shape[1].nil? ? y[test_ids] : y[test_ids, true]
    # Time the fit with a monotonic clock at sub-second resolution; Time.now.to_i truncated to
    # whole seconds (usually reporting 0) and could jump with wall-clock adjustments.
    start_time = Process.clock_gettime(Process::CLOCK_MONOTONIC)
    @estimator.fit(train_x, train_y)
    report[:fit_time].push(Process.clock_gettime(Process::CLOCK_MONOTONIC) - start_time)
    # Calculate scores and prepare the report.
    report[:test_score].push(evaluate.call(test_x, test_y))
    report[:train_score].push(evaluate.call(train_x, train_y)) if @return_train_score
  end
  report
end