Class: Rumale::ModelSelection::KFold

Inherits:
Object
Includes:
Base::Splitter
Defined in:
lib/rumale/model_selection/k_fold.rb

Overview

KFold is a class that generates the set of data indices for K-fold cross-validation.
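To make the index layout concrete, here is a minimal plain-Ruby sketch of the unshuffled case, with no Rumale or Numo dependency. The helper name kfold_pairs is hypothetical. It mirrors the fold construction in the source of #split: each fold serves once as the test set, and the remaining folds are concatenated into the training set.

```ruby
# Hypothetical helper: plain-Ruby sketch of unshuffled K-fold index generation.
def kfold_pairs(n_samples, n_splits)
  ids = [*0...n_samples]
  # The first (n_samples % n_splits) folds receive one extra sample.
  folds = Array.new(n_splits) do |n|
    size = n_samples / n_splits
    size += 1 if n < n_samples % n_splits
    ids.shift(size)
  end
  # Fold n is the test set; all other folds form the training set.
  Array.new(n_splits) do |n|
    train_ids = folds.select.with_index { |_, i| i != n }.flatten
    [train_ids, folds[n]]
  end
end

pairs = kfold_pairs(5, 2)
pairs.each { |train_ids, test_ids| p [train_ids, test_ids] }
# prints [[3, 4], [0, 1, 2]] and then [[0, 1, 2], [3, 4]]

# Every index appears in exactly one test fold.
p pairs.flat_map { |_, test_ids| test_ids }.sort # prints [0, 1, 2, 3, 4]
```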

Examples:

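require 'numo/narray'
require 'rumale'

samples = Numo::DFloat.new(9, 2).rand
labels = Numo::Int32[0, 0, 0, 1, 1, 1, 2, 2, 2]
kf = Rumale::ModelSelection::KFold.new(n_splits: 3, shuffle: true, random_seed: 1)
kf.split(samples, labels).each do |train_ids, test_ids|
  train_samples = samples[train_ids, true]
  test_samples = samples[test_ids, true]
  # Fit a model on train_samples and evaluate it on test_samples.
end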


Constructor Details

#initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ KFold

Create a new data splitter for K-fold cross-validation.

Parameters:

  • n_splits (Integer) (defaults to: 3)

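    The number of folds. It must be positive; #split further requires it to be at least 2 and no more than the number of samples.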

  • shuffle (Boolean) (defaults to: false)

    The flag indicating whether to shuffle the dataset.

  • random_seed (Integer) (defaults to: nil)

    The seed value used to initialize the random generator. If nil, a seed is obtained from Kernel#srand.
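Shuffling is driven by a Random instance seeded with random_seed, so the same seed yields the same permutation. A minimal plain-Ruby illustration using Array#shuffle with the random: keyword, which is the call #split relies on:

```ruby
# Two generators built from the same seed produce identical permutations.
ids = [*0...10]
perm_a = ids.shuffle(random: Random.new(1))
perm_b = ids.shuffle(random: Random.new(1))
p perm_a == perm_b # prints true

# A different seed will generally give a different order,
# but the result is still a permutation of the same ids.
perm_c = ids.shuffle(random: Random.new(2))
p perm_c.sort == ids # prints true
```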



# File 'lib/rumale/model_selection/k_fold.rb', line 38

def initialize(n_splits: 3, shuffle: false, random_seed: nil)
  check_params_integer(n_splits: n_splits)
  check_params_boolean(shuffle: shuffle)
  check_params_type_or_nil(Integer, random_seed: random_seed)
  check_params_positive(n_splits: n_splits)
  @n_splits = n_splits
  @shuffle = shuffle
  @random_seed = random_seed
  @random_seed ||= srand
  @rng = Random.new(@random_seed)
end
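When random_seed is nil, the constructor falls back to srand. Called without an argument, Kernel#srand reseeds Ruby's global generator from system entropy and returns the previous seed, an Integer, so the stored seed is always an Integer. One side effect worth knowing: creating a KFold without a seed also reseeds the global generator used by Kernel#rand. A minimal plain-Ruby illustration of the idiom:

```ruby
random_seed = nil
# srand with no argument reseeds the global generator and returns the
# previous seed, which is an Integer.
random_seed ||= srand
rng = Random.new(random_seed)

p random_seed.is_a?(Integer) # prints true
# A fresh generator built from the stored seed replays the same sequence.
p Random.new(random_seed).rand(100) == rng.rand(100) # prints true
```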

Instance Attribute Details

#n_splits ⇒ Integer (readonly)

Return the number of folds.

Returns:

  • (Integer)


# File 'lib/rumale/model_selection/k_fold.rb', line 23

def n_splits
  @n_splits
end

#rng ⇒ Random (readonly)

Return the random generator for shuffling the dataset.

Returns:

  • (Random)


# File 'lib/rumale/model_selection/k_fold.rb', line 31

def rng
  @rng
end

#shuffle ⇒ Boolean (readonly)

Return the flag indicating whether to shuffle the dataset.

Returns:

  • (Boolean)


# File 'lib/rumale/model_selection/k_fold.rb', line 27

def shuffle
  @shuffle
end

Instance Method Details

#split(x, _y = nil) ⇒ Array

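Generate data indices for K-fold cross-validation. The second argument, _y, is not used. An ArgumentError is raised unless n_splits is at least 2 and no more than the number of samples.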

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The dataset used to generate data indices for K-fold cross-validation.

Returns:

  • (Array)

    An array of n_splits pairs [train_ids, test_ids]; each pair holds the indices of the training and testing samples for one fold.
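Each test fold receives either n_samples / n_splits indices (integer division) or one more. The hypothetical fold_sizes helper below reproduces only that sizing rule from the source of #split, in plain Ruby:

```ruby
# Hypothetical helper: fold sizes as computed in #split. Every fold gets
# n_samples / n_splits samples, and the first (n_samples % n_splits) folds
# get one extra so that all samples are assigned.
def fold_sizes(n_samples, n_splits)
  Array.new(n_splits) do |n|
    size = n_samples / n_splits
    n < n_samples % n_splits ? size + 1 : size
  end
end

p fold_sizes(10, 3) # prints [4, 3, 3]
p fold_sizes(9, 3)  # prints [3, 3, 3]
p fold_sizes(10, 4) # prints [3, 3, 2, 2]
```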



# File 'lib/rumale/model_selection/k_fold.rb', line 55

def split(x, _y = nil)
  check_sample_array(x)
  # Initialize and check some variables.
  n_samples, = x.shape
  unless @n_splits.between?(2, n_samples)
    raise ArgumentError,
          'The value of n_splits must be not less than 2 and not more than the number of samples.'
  end
  sub_rng = @rng.dup
  # Splits dataset ids to each fold.
  dataset_ids = [*0...n_samples]
  dataset_ids.shuffle!(random: sub_rng) if @shuffle
  fold_sets = Array.new(@n_splits) do |n|
    n_fold_samples = n_samples / @n_splits
    n_fold_samples += 1 if n < n_samples % @n_splits
    dataset_ids.shift(n_fold_samples)
  end
  # Returns array consisting of the training and testing ids for each fold.
  Array.new(@n_splits) do |n|
    train_ids = fold_sets.select.with_index { |_, id| id != n }.flatten
    test_ids = fold_sets[n]
    [train_ids, test_ids]
  end
end
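Because #split shuffles with @rng.dup rather than @rng, the stored generator never advances, so calling #split twice on the same data returns the same folds. A minimal plain-Ruby illustration of that behaviour with Random#dup:

```ruby
rng = Random.new(1)
ids = [*0...10]

# Shuffling with copies leaves rng's own state untouched.
first  = ids.shuffle(random: rng.dup)
second = ids.shuffle(random: rng.dup)
p first == second # prints true

# rng has not advanced, so using it directly once reproduces the same order...
third = ids.shuffle(random: rng)
p third == first # prints true

# ...but that call advanced rng, so a further shuffle usually differs.
fourth = ids.shuffle(random: rng)
```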