Module: Rumale::Dataset

Defined in:
lib/rumale/dataset.rb

Overview

Module for loading and saving a dataset file.


Class Method Details

.dump_libsvm_file(data, labels, filename, zero_based: false) ⇒ Object

Dump the dataset to a file in the libsvm format.

Parameters:

  • data (Numo::NArray)

    (shape: [n_samples, n_features]) matrix consisting of feature vectors.

  • labels (Numo::NArray)

    (shape: [n_samples] or [n_samples, n_targets]) vector or matrix consisting of labels or target values. A 2-D array is treated as several targets per sample.

  • filename (String)

    A path to the output libsvm file.

  • zero_based (Boolean) (defaults to: false)

    Whether the column index starts from 0 (true) or 1 (false).
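The zero_based flag only changes how column positions are written. As a minimal sketch, assuming plain Ruby Arrays instead of Numo::NArray and the usual sparse libsvm convention of omitting zero entries, one row could be formatted as shown below. The helper name format_libsvm_line is hypothetical: Rumale does this in its private dump_libsvm_line, which also receives the detected label and value types, and this sketch ignores them.

```ruby
# Hypothetical helper, not Rumale's private dump_libsvm_line: formats one
# sample as "label index:value ...". Zero entries are skipped (sparse format).
def format_libsvm_line(label, row, zero_based: false)
  offset = zero_based ? 0 : 1 # first column is written as index 0 or 1
  pairs = row.each_with_index.filter_map do |value, j|
    "#{j + offset}:#{value}" unless value.zero?
  end
  [label, *pairs].join(' ')
end

puts format_libsvm_line(1, [0.5, 0.0, 2.0])                   # 1 1:0.5 3:2.0
puts format_libsvm_line(1, [0.5, 0.0, 2.0], zero_based: true) # 1 0:0.5 2:2.0
```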



# File 'lib/rumale/dataset.rb', line 37

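def dump_libsvm_file(data, labels, filename, zero_based: false)
  # Only rows present in both data and labels are written.
  n_samples = [data.shape[0], labels.shape[0]].min
  # A 1-D labels array holds one label per sample; a 2-D array holds several.
  single_label = labels.shape[1].nil?
  # Detected element types are passed on to the per-line formatter.
  label_type = detect_dtype(labels)
  value_type = detect_dtype(data)
  File.open(filename, 'w') do |file|
    n_samples.times do |n|
      label = single_label ? labels[n] : labels[n, true].to_a
      # Each sample becomes one line of the output file.
      file.puts(dump_libsvm_line(label, data[n, true],
                                 label_type, value_type, zero_based))
    end
  end
end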
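The method writes as many lines as the smaller of the two inputs has rows, and a 2-D labels array is treated as several targets per sample. The pure-Ruby sketch below mirrors that loop with nested Arrays. The name write_libsvm is hypothetical, and writing multiple targets as a comma-separated list (the convention used by LIBSVM's multi-label datasets) is an assumption of this sketch, not something the source above shows.

```ruby
require 'tempfile'

# Hypothetical pure-Ruby sketch of the dump loop, using nested Arrays instead
# of Numo::NArray. Multi-target labels are written comma-separated (assumed).
def write_libsvm(data, labels, path, zero_based: false)
  offset = zero_based ? 0 : 1
  # Like dump_libsvm_file, only rows present in both inputs are written.
  n_samples = [data.size, labels.size].min
  File.open(path, 'w') do |file|
    n_samples.times do |n|
      label = labels[n].is_a?(Array) ? labels[n].join(',') : labels[n].to_s
      # Zero entries are omitted, following the sparse libsvm convention.
      pairs = data[n].each_with_index.filter_map { |v, j| "#{j + offset}:#{v}" unless v.zero? }
      file.puts([label, *pairs].join(' '))
    end
  end
end

Tempfile.create('sample') do |tmp|
  # Three data rows but only two label rows: the third data row is dropped.
  write_libsvm([[1.0, 0.0], [0.0, 3.0], [4.0, 4.0]], [[0, 1], [1, 0]], tmp.path)
  puts File.read(tmp.path)
end
```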

.load_libsvm_file(filename, zero_based: false, dtype: Numo::DFloat) ⇒ Array<Numo::NArray>

Load a dataset in the libsvm file format into Numo::NArray objects.

Parameters:

  • filename (String)

    A path to a dataset file.

  • zero_based (Boolean) (defaults to: false)

    Whether the column index starts from 0 (true) or 1 (false).

  • dtype (Numo::NArray) (defaults to: Numo::DFloat)

    Data type of Numo::NArray for features to be loaded.

Returns:

  • (Array<Numo::NArray>)

    Returns a two-element array: the (n_samples x n_features) feature matrix of type dtype, and the (n_samples) vector of labels or target values.
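Each line of a libsvm file has the form `label index:value index:value ...`. A minimal parsing sketch, assuming space-separated fields, a single numeric label, and the same zero_based meaning as above, is shown below. The name parse_line and its return shape ([label, Hash of column => value, column count]) are hypothetical; Rumale's private parse_libsvm_line is not shown on this page, and multi-target labels are not handled here.

```ruby
# Hypothetical parser, not Rumale's private parse_libsvm_line: splits one
# libsvm line into its label, a Hash of zero-based column => value, and the
# number of columns the line implies.
def parse_line(line, zero_based: false)
  offset = zero_based ? 0 : 1
  label_str, *pairs = line.split
  # Integer labels stay Integer; anything else (e.g. "0.75") becomes Float.
  label = Integer(label_str, 10, exception: false) || Float(label_str)
  ftvec = {}
  pairs.each do |pair|
    idx, val = pair.split(':', 2)
    ftvec[Integer(idx, 10) - offset] = Float(val)
  end
  n_cols = ftvec.empty? ? 0 : ftvec.keys.max + 1
  [label, ftvec, n_cols]
end

p parse_line('1 1:0.5 3:2.0')
p parse_line('-1 0:1 4:2.5', zero_based: true)
```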



# File 'lib/rumale/dataset.rb', line 18

def load_libsvm_file(filename, zero_based: false, dtype: Numo::DFloat)
  ftvecs = []
  labels = []
  n_features = 0
  # Read the file line by line, splitting each line on spaces.
  CSV.foreach(filename, col_sep: "\s", headers: false) do |line|
    label, ftvec, max_idx = parse_libsvm_line(line, zero_based)
    labels.push(label)
    ftvecs.push(ftvec)
    # n_features tracks the largest max_idx reported by any line.
    n_features = max_idx if n_features < max_idx
  end
  # Features use the requested dtype; the label type is inferred by asarray.
  [convert_to_matrix(ftvecs, n_features, dtype), Numo::NArray.asarray(labels)]
end
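The loop keeps the largest max_idx seen on any line as n_features, and convert_to_matrix then builds the feature matrix with the requested dtype. Assuming each parsed row is a Hash of zero-based column => value, the densification step can be sketched in plain Ruby. The helper to_dense is hypothetical and only stands in for the private convert_to_matrix, whose body is not shown here.

```ruby
# Hypothetical stand-in for the densification step: each sparse row
# (a Hash of column => value) becomes an Array of n_features floats.
def to_dense(ftvecs, n_features)
  ftvecs.map do |ftvec|
    row = Array.new(n_features, 0.0) # missing columns default to 0.0
    ftvec.each { |j, v| row[j] = v }
    row
  end
end

ftvecs = [{ 0 => 0.5, 2 => 2.0 }, { 1 => 1.0 }]
# n_features is the largest per-row column count, as in the load loop.
n_features = ftvecs.map { |f| f.empty? ? 0 : f.keys.max + 1 }.max || 0
p to_dense(ftvecs, n_features) # [[0.5, 0.0, 2.0], [0.0, 1.0, 0.0]]
```

With Numo installed, `Numo::DFloat.cast(rows)` converts such a nested Array into a Numo matrix, the kind of object load_libsvm_file returns for the features.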