Class: Rumale::ModelSelection::KFold

Inherits:
Object
  • Object
show all
Includes:
Base::Splitter
Defined in:
lib/rumale/model_selection/k_fold.rb

Overview

KFold is a class that generates the set of data indices for K-fold cross-validation.

Examples:

kf = Rumale::ModelSelection::KFold.new(n_splits: 3, shuffle: true, random_seed: 1)
kf.split(samples, labels).each do |train_ids, test_ids|
  train_samples = samples[train_ids, true]
  test_samples = samples[test_ids, true]
  ...
end

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(n_splits: 3, shuffle: false, random_seed: nil) ⇒ KFold

Create a new data splitter for K-fold cross-validation.



38
39
40
41
42
43
44
45
46
47
48
# File 'lib/rumale/model_selection/k_fold.rb', line 38

# Create a new data splitter for K-fold cross-validation.
#
# @param n_splits [Integer] The number of folds.
# @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
# @param random_seed [Integer, nil] The seed value used to initialize the random generator.
#   When nil, the value returned by Kernel#srand is used as the seed.
def initialize(n_splits: 3, shuffle: false, random_seed: nil)
  # Argument validation; the order of the checks is kept as-is so the first
  # failing check stays the same for any combination of invalid arguments.
  check_params_integer(n_splits: n_splits)
  check_params_boolean(shuffle: shuffle)
  check_params_type_or_nil(Integer, random_seed: random_seed)
  check_params_positive(n_splits: n_splits)
  @shuffle = shuffle
  @n_splits = n_splits
  # Kernel#srand (no argument) reseeds the global generator and returns the
  # previous seed; that previous seed is what gets stored here.
  @random_seed = random_seed || srand
  @rng = Random.new(@random_seed)
end

Instance Attribute Details

#n_splits ⇒ Integer (readonly)

Return the number of folds.



23
24
25
# File 'lib/rumale/model_selection/k_fold.rb', line 23

# Return the number of folds.
#
# @return [Integer]
def n_splits
  @n_splits
end

#rng ⇒ Random (readonly)

Return the random generator for shuffling the dataset.



31
32
33
# File 'lib/rumale/model_selection/k_fold.rb', line 31

# Return the random generator for shuffling the dataset.
#
# @return [Random]
def rng
  @rng
end

#shuffle ⇒ Boolean (readonly)

Return the flag indicating whether to shuffle the dataset.



27
28
29
# File 'lib/rumale/model_selection/k_fold.rb', line 27

# Return the flag indicating whether to shuffle the dataset.
#
# @return [Boolean]
def shuffle
  @shuffle
end

Instance Method Details

#split(x, _y = nil) ⇒ Array

Generate data indices for K-fold cross-validation.



55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# File 'lib/rumale/model_selection/k_fold.rb', line 55

# Generate data indices for K-fold cross-validation.
#
# @param x [Object] The dataset to be split; its first `shape` entry is taken
#   as the number of samples.
# @param _y [Object] Ignored; accepted so the signature takes labels as well.
# @raise [ArgumentError] If n_splits is less than 2 or greater than the number of samples.
# @return [Array] One [train_ids, test_ids] pair per fold, where both elements
#   are arrays of integer sample indices.
def split(x, _y = nil)
  check_sample_array(x)
  sample_count = x.shape.first
  if @n_splits < 2 || @n_splits > sample_count
    raise ArgumentError,
          'The value of n_splits must be not less than 2 and not more than the number of samples.'
  end
  # Every fold gets `base_size` ids; the first `n_larger` folds get one extra
  # so that the remainder is spread over the leading folds.
  base_size, n_larger = sample_count.divmod(@n_splits)
  ids = (0...sample_count).to_a
  ids.shuffle!(random: @rng) if @shuffle
  # Cut the id list into consecutive chunks, one per fold.
  offset = 0
  folds = (0...@n_splits).map do |k|
    size = k < n_larger ? base_size + 1 : base_size
    chunk = ids[offset, size]
    offset += size
    chunk
  end
  # Pair each fold (test ids) with the concatenation of all other folds (train ids).
  folds.each_index.map do |k|
    train_ids = folds.each_with_index.flat_map { |chunk, i| i == k ? [] : chunk }
    [train_ids, folds[k]]
  end
end