Class: SVMKit::ModelSelection::KFold

Inherits:
Object
Includes:
Base::Splitter
Defined in:
lib/svmkit/model_selection/k_fold.rb

Overview

KFold is a class that generates the set of data indices for K-fold cross-validation.

Examples:

kf = SVMKit::ModelSelection::KFold.new(n_splits: 3, shuffle: true, random_seed: 1)
kf.split(samples, labels).each do |train_ids, test_ids|
  train_samples = samples[train_ids, true]
  test_samples = samples[test_ids, true]
  ...
end

Instance Attribute Summary

Attributes included from Base::Splitter

#n_splits

Constructor Details

#initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ KFold

Create a new data splitter for K-fold cross-validation.

Parameters:

  • n_splits (Integer) (defaults to: 3)

    The number of folds.

  • shuffle (Boolean) (defaults to: false)

    The flag indicating whether to shuffle the dataset.

  • random_seed (Integer) (defaults to: nil)

    The seed value used to initialize the random generator.
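
The effect of shuffle and random_seed can be illustrated with a minimal plain-Ruby sketch. It assumes only the standard Random class and Array#shuffle; the helper name shuffled_ids is hypothetical and is not part of SVMKit.

```ruby
# Hypothetical helper (not part of SVMKit): shuffle the sample indices
# with a generator seeded from a fixed value.
def shuffled_ids(n_samples, seed)
  rng = Random.new(seed)
  (0...n_samples).to_a.shuffle(random: rng)
end

first = shuffled_ids(10, 1)
second = shuffled_ids(10, 1)

# The same seed yields the same ordering, so folds are reproducible.
puts first == second
# Shuffling only reorders the indices; none are lost or duplicated.
puts first.sort == (0...10).to_a
```

Fixing the seed in this way is what makes repeated cross-validation runs comparable.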



# File 'lib/svmkit/model_selection/k_fold.rb', line 35

def initialize(n_splits: 3, shuffle: false, random_seed: nil)
  SVMKit::Validation.check_params_integer(n_splits: n_splits)
  SVMKit::Validation.check_params_boolean(shuffle: shuffle)
  SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
  SVMKit::Validation.check_params_positive(n_splits: n_splits)
  @n_splits = n_splits
  @shuffle = shuffle
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end
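
When random_seed is nil, the constructor above falls back to the return value of Kernel#srand. The following sketch uses only core Ruby to show what that call returns; the variable names are illustrative. Note that calling srand without an argument also reseeds the global generator as a side effect.

```ruby
# Kernel#srand with no argument reseeds the global generator and
# returns the seed that was in effect before the call.
srand(42)
fallback_seed = srand
puts fallback_seed # prints 42

# A private generator built from that value is reproducible.
a = Random.new(fallback_seed).rand(100)
b = Random.new(fallback_seed).rand(100)
puts a == b # prints true
```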

Instance Attribute Details

#rng ⇒ Random (readonly)

Return the random generator for shuffling the dataset.

Returns:

  • (Random)


# File 'lib/svmkit/model_selection/k_fold.rb', line 28

def rng
  @rng
end

#shuffle ⇒ Boolean (readonly)

Return the flag indicating whether to shuffle the dataset.

Returns:

  • (Boolean)


# File 'lib/svmkit/model_selection/k_fold.rb', line 24

def shuffle
  @shuffle
end

Instance Method Details

#split(x, _y = nil) ⇒ Array

Generate data indices for K-fold cross-validation.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The dataset to be used to generate data indices for K-fold cross validation.

  • _y (defaults to: nil)

    Not used when generating the indices.

Returns:

  • (Array)

    The set of data indices for constructing the training and testing dataset in each fold.
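
The shape of the returned array can be sketched without Numo. The function below is a hypothetical plain-Ruby stand-in (k_fold_ids is not an SVMKit method) that takes a sample count instead of a dataset, skips shuffling, and assumes the first n_samples % n_splits folds each receive one extra sample.

```ruby
# Hypothetical stand-in for an unshuffled K-fold split (not SVMKit's API).
# Returns one [train_ids, test_ids] pair per fold.
def k_fold_ids(n_samples, n_splits)
  unless n_splits.between?(2, n_samples)
    raise ArgumentError, 'n_splits must be between 2 and the number of samples'
  end

  ids = (0...n_samples).to_a
  # Partition the indices; earlier folds absorb the remainder.
  folds = Array.new(n_splits) do |n|
    size = n_samples / n_splits
    size += 1 if n < n_samples % n_splits
    ids.shift(size)
  end
  # Each fold serves once as the test set; the rest form the training set.
  Array.new(n_splits) do |n|
    train_ids = folds.each_with_index.flat_map { |fold, i| i == n ? [] : fold }
    [train_ids, folds[n]]
  end
end

k_fold_ids(10, 3).each do |train_ids, test_ids|
  puts "train=#{train_ids.inspect} test=#{test_ids.inspect}"
end
# First line printed: train=[4, 5, 6, 7, 8, 9] test=[0, 1, 2, 3]
```

With 10 samples and 3 folds, the test sets have sizes 4, 3, and 3, and every index appears in exactly one test set.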



# File 'lib/svmkit/model_selection/k_fold.rb', line 52

def split(x, _y = nil)
  SVMKit::Validation.check_sample_array(x)
  # Initialize and check some variables.
  n_samples, = x.shape
  unless @n_splits.between?(2, n_samples)
    raise ArgumentError,
          'The value of n_splits must be not less than 2 and not more than the number of samples.'
  end
  # Splits dataset ids to each fold.
  dataset_ids = [*0...n_samples]
  dataset_ids.shuffle!(random: @rng) if @shuffle
  fold_sets = Array.new(@n_splits) do |n|
    n_fold_samples = n_samples / @n_splits
    n_fold_samples += 1 if n < n_samples % @n_splits
    dataset_ids.shift(n_fold_samples)
  end
  # Returns array consisting of the training and testing ids for each fold.
  Array.new(@n_splits) do |n|
    train_ids = fold_sets.select.with_index { |_, id| id != n }.flatten
    test_ids = fold_sets[n]
    [train_ids, test_ids]
  end
end