Module: DNN::STL10

Defined in:
lib/dnn/datasets/stl-10.rb

Defined Under Namespace

Classes: DNN_STL10_LoadError

Class Method Summary collapse

Class Method Details

.downloads ⇒ Object



12
13
14
15
16
17
18
19
20
21
22
23
# File 'lib/dnn/datasets/stl-10.rb', line 12

# Fetches and unpacks the STL-10 archive unless it is already unpacked.
#
# Returns immediately when the DIR_STL10 directory already exists under
# DOWNLOADS_PATH/downloads. Otherwise it calls Downloader.download with
# URL_STL10, gunzips and untars the archive (named after the last path
# segment of URL_STL10) into DOWNLOADS_PATH/downloads, and deletes the
# archive afterwards.
#
# When extraction does not run to completion, the partially written dataset
# directory is removed so that the next call downloads again instead of
# trusting an incomplete directory.
#
# @return [Object, nil] nil when the dataset directory already exists,
#   otherwise whatever the extraction call returns
def self.downloads
  require "fileutils" # local require: only this cleanup path needs FileUtils

  extract_root = DOWNLOADS_PATH + "/downloads"
  dataset_dir = extract_root + "/" + DIR_STL10
  return if Dir.exist?(dataset_dir)

  Downloader.download(URL_STL10)
  archive_path = extract_root + "/" + URL_STL10.match(%r`.+/(.+)`)[1]
  extracted = false
  begin
    result = Zlib::GzipReader.open(archive_path) do |gz|
      Archive::Tar::Minitar.unpack(gz, extract_root)
    end
    extracted = true
    result
  ensure
    # Guarded so a missing archive does not replace the original error.
    File.unlink(archive_path) if File.exist?(archive_path)
    # A half-extracted directory would satisfy the Dir.exist? guard above on
    # every later call. A flag checked in ensure (rather than a rescue) also
    # covers Interrupt and other non-StandardError exits.
    FileUtils.rm_rf(dataset_dir) unless extracted
  end
end

.load_test ⇒ Object



38
39
40
41
42
43
44
45
46
47
48
49
# File 'lib/dnn/datasets/stl-10.rb', line 38

# Loads the STL-10 test split.
#
# Calls downloads first, then reads test_X.bin and test_y.bin from
# DOWNLOADS_PATH/downloads/DIR_STL10.
#
# @return [Array] two-element array [x_test, y_test]; x_test is the image
#   data reshaped to (8000, 3, 96, 96) and transposed with axes (0, 3, 2, 1),
#   giving shape (8000, 96, 96, 3); y_test is the raw label bytes as UInt8
# @raise [DNN_STL10_LoadError] if either binary file is missing
def self.load_test
  downloads

  image_path = "#{DOWNLOADS_PATH}/downloads/#{DIR_STL10}/test_X.bin"
  label_path = "#{DOWNLOADS_PATH}/downloads/#{DIR_STL10}/test_y.bin"
  # Images are checked before labels, so a missing image file is reported first.
  [image_path, label_path].each do |path|
    raise DNN_STL10_LoadError, %(file "#{path}" is not found.) unless File.exist?(path)
  end

  raw_images, raw_labels = [image_path, label_path].map { |path| File.binread(path) }
  images = Numo::UInt8.from_binary(raw_images).reshape(8000, 3, 96, 96)
  images = images.transpose(0, 3, 2, 1).clone
  [images, Numo::UInt8.from_binary(raw_labels)]
end

.load_train ⇒ Object



25
26
27
28
29
30
31
32
33
34
35
36
# File 'lib/dnn/datasets/stl-10.rb', line 25

# Loads the STL-10 training split.
#
# Calls downloads first, then reads train_X.bin and train_y.bin from
# DOWNLOADS_PATH/downloads/DIR_STL10.
#
# @return [Array] two-element array [x_train, y_train]; x_train is the image
#   data reshaped to (5000, 3, 96, 96) and transposed with axes (0, 3, 2, 1),
#   giving shape (5000, 96, 96, 3); y_train is the raw label bytes as UInt8
# @raise [DNN_STL10_LoadError] if either binary file is missing
def self.load_train
  downloads

  image_path = "#{DOWNLOADS_PATH}/downloads/#{DIR_STL10}/train_X.bin"
  label_path = "#{DOWNLOADS_PATH}/downloads/#{DIR_STL10}/train_y.bin"
  # Images are checked before labels, so a missing image file is reported first.
  [image_path, label_path].each do |path|
    raise DNN_STL10_LoadError, %(file "#{path}" is not found.) unless File.exist?(path)
  end

  raw_images, raw_labels = [image_path, label_path].map { |path| File.binread(path) }
  images = Numo::UInt8.from_binary(raw_images).reshape(5000, 3, 96, 96)
  images = images.transpose(0, 3, 2, 1).clone
  [images, Numo::UInt8.from_binary(raw_labels)]
end

.load_unlabeled(range = 0...100000) ⇒ Object

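Raises:

- DNNError — if the range does not lie within 0...100000.
- DNN_STL10_LoadError — if unlabeled_X.bin is not found.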



51
52
53
54
55
56
57
58
59
60
61
# File 'lib/dnn/datasets/stl-10.rb', line 51

# Loads a slice of the unlabeled STL-10 images.
#
# Calls downloads first, then reads only the bytes of the requested images
# from unlabeled_X.bin (3 * 96 * 96 bytes per image).
#
# @param range [Range] image indices to load. The end is always treated as
#   exclusive, even for an inclusive (`..`) range. A nil begin defaults to 0
#   and a nil end (endless range) defaults to 100000.
# @return [Object] the image data reshaped to (n, 3, 96, 96) and transposed
#   with axes (0, 3, 2, 1), giving shape (n, 96, 96, 3), where n is the
#   number of requested images
# @raise [DNNError] if the range leaves 0...100000 or its begin exceeds its end
# @raise [DNN_STL10_LoadError] if unlabeled_X.bin is missing
def self.load_unlabeled(range = 0...100000)
  # Endless/beginless ranges have nil endpoints; map them to the dataset bounds.
  first = range.begin || 0
  last = range.end || 100000
  # A reversed range would otherwise reach File.binread with a negative length.
  if first < 0 || last > 100000 || first > last
    raise DNNError, "Range must between 0 and 100000. (But the end is excluded)"
  end

  downloads
  x_fname = DOWNLOADS_PATH + "/downloads/#{DIR_STL10}/unlabeled_X.bin"
  raise DNN_STL10_LoadError, %`file "#{x_fname}" is not found.` unless File.exist?(x_fname)

  num_datas = last - first
  bytes_per_image = 3 * 96 * 96
  x_bin = File.binread(x_fname, num_datas * bytes_per_image, first * bytes_per_image)
  Numo::UInt8.from_binary(x_bin).reshape(num_datas, 3, 96, 96).transpose(0, 3, 2, 1).clone
end