Module: Rumale::Dataset

Defined in:
lib/rumale/dataset.rb

Overview

Module for loading and saving a dataset file.

Class Method Summary

Class Method Details

.dump_libsvm_file(data, labels, filename, zero_based: false) ⇒ Object

Dump the dataset with the libsvm file format.



37
38
39
40
41
42
43
44
45
46
47
48
49
# File 'lib/rumale/dataset.rb', line 37

# Write samples and their labels to a file, one line per sample, using the
# libsvm text layout produced by dump_libsvm_line.
#
# Only the first min(data.shape[0], labels.shape[0]) samples are written, so a
# length mismatch between the two inputs truncates rather than raises.
#
# @param data [#shape, #[]] feature matrix, read row by row via data[i, true]
#   (presumably a Numo::NArray — confirm against callers).
# @param labels [#shape, #[]] labels; treated as one label per sample when
#   labels.shape[1] is nil, otherwise each row is converted to an Array.
# @param filename [String] path of the file to (over)write.
# @param zero_based [Boolean] forwarded unchanged to dump_libsvm_line.
# @return [Object] the value of the File.open block (the number of samples written).
def dump_libsvm_file(data, labels, filename, zero_based: false)
  row_count = [data.shape[0], labels.shape[0]].min
  multi_label = !labels.shape[1].nil?
  # Both dtypes are resolved before the file is opened, so a failure here
  # leaves any existing file untouched.
  label_fmt = detect_dtype(labels)
  value_fmt = detect_dtype(data)
  File.open(filename, 'w') do |out|
    row_count.times do |i|
      target = multi_label ? labels[i, true].to_a : labels[i]
      features = data[i, true]
      out.puts(dump_libsvm_line(target, features, label_fmt, value_fmt, zero_based))
    end
  end
end

.load_libsvm_file(filename, zero_based: false, dtype: Numo::DFloat) ⇒ Array<Numo::NArray>

Load a dataset with the libsvm file format into Numo::NArray.



18
19
20
21
22
23
24
25
26
27
28
29
# File 'lib/rumale/dataset.rb', line 18

# Read a libsvm-format file and return its samples and labels.
#
# Each line is split on single spaces by CSV and handed to parse_libsvm_line,
# which yields the label, the feature vector and a per-line index bound. The
# largest bound seen across all lines is passed to convert_to_matrix as the
# feature count.
#
# @param filename [String] path of the libsvm file to read.
# @param zero_based [Boolean] forwarded unchanged to parse_libsvm_line.
# @param dtype [Class] forwarded unchanged to convert_to_matrix.
# @return [Array] two elements: the result of convert_to_matrix, then the
#   labels wrapped by Numo::NArray.asarray.
def load_libsvm_file(filename, zero_based: false, dtype: Numo::DFloat)
  targets = []
  rows = []
  widest = 0
  CSV.foreach(filename, col_sep: "\s", headers: false) do |fields|
    target, row, max_idx = parse_libsvm_line(fields, zero_based)
    targets << target
    rows << row
    widest = max_idx if widest < max_idx
  end
  samples = convert_to_matrix(rows, widest, dtype)
  label_array = Numo::NArray.asarray(targets)
  [samples, label_array]
end