Module: DNN::CIFAR100

Defined in:
lib/dnn/datasets/cifar100.rb,
ext/cifar_loader/cifar_loader.c

Defined Under Namespace

Classes: DNN_CIFAR100_LoadError

Class Method Summary

Class Method Details

.downloads ⇒ Object



# File 'lib/dnn/datasets/cifar100.rb', lines 13-24

# Fetches and unpacks the CIFAR-100 binary archive, unless the extracted
# directory DOWNLOADS_PATH/downloads/DIR_CIFAR100 is already present.
#
# Downloader.download is expected to leave the archive in
# DOWNLOADS_PATH/downloads under the last path segment of URL_CIFAR100
# (presumably — this method only reads it from there). The archive is
# gunzipped and untarred into DOWNLOADS_PATH/downloads, and the archive file
# is deleted afterwards even when extraction raises.
def self.downloads
  downloads_dir = DOWNLOADS_PATH + "/downloads"
  return if Dir.exist?(downloads_dir + "/" + DIR_CIFAR100)

  Downloader.download(URL_CIFAR100)
  archive_name = URL_CIFAR100.match(%r`.+/(.+)`)[1]
  archive_path = downloads_dir + "/" + archive_name
  begin
    Zlib::GzipReader.open(archive_path) { |gzip| Archive::Tar::Minitar.unpack(gzip, downloads_dir) }
  ensure
    # Always remove the archive so a failed extraction does not leave it behind.
    File.unlink(archive_path)
  end
end

.load_binary(rb_bin, rb_num_datas) ⇒ Object



# File 'ext/cifar_loader/cifar_loader.c', lines 39-68

/*
 * CIFAR100.load_binary(bin, num_datas) -> [x_bin, y_bin]
 *
 * Splits a CIFAR-100 binary blob into image bytes and label bytes.
 * Each record in `bin` is consumed as 2 label bytes followed by
 * CIFAR_WIDTH * CIFAR_HEIGHT * CIFAR_CHANNEL image bytes.
 *
 * Returns a two-element Array of Strings:
 *   x_bin - the image bytes of every record, concatenated (num_datas * size bytes)
 *   y_bin - the two label bytes of every record, concatenated (num_datas * 2 bytes)
 *
 * Raises ArgumentError when num_datas is negative or when `bin` holds fewer
 * than num_datas complete records (previously this read past the end of the
 * string). Raises TypeError via StringValue / NUM2INT for arguments of the
 * wrong type.
 */
static VALUE cifar100_load_binary(VALUE self, VALUE rb_bin, VALUE rb_num_datas) {
  const long size = CIFAR_WIDTH * CIFAR_HEIGHT * CIFAR_CHANNEL;
  const long record_size = size + 2;
  long num_datas;
  long i;
  VALUE rb_x_bin;
  VALUE rb_y_bin;
  const uint8_t* bin;
  uint8_t* x_bin;
  uint8_t* y_bin;

  StringValue(rb_bin);
  num_datas = NUM2INT(rb_num_datas);
  if (num_datas < 0) {
    rb_raise(rb_eArgError, "num_datas must not be negative (got %ld)", num_datas);
  }
  /* Compare by division so the check itself cannot overflow; once it passes,
     num_datas * record_size <= RSTRING_LEN(rb_bin) and fits in a long. */
  if (num_datas > RSTRING_LEN(rb_bin) / record_size) {
    rb_raise(rb_eArgError,
             "binary is too short for %ld records of %ld bytes (got %ld bytes)",
             num_datas, record_size, (long)RSTRING_LEN(rb_bin));
  }

  /* Allocate the result Strings up front and fill them in place: no malloc/free
     pair, so nothing leaks if an allocation raises. */
  rb_x_bin = rb_str_new(NULL, num_datas * size);
  rb_y_bin = rb_str_new(NULL, num_datas * 2);
  x_bin = (uint8_t*)RSTRING_PTR(rb_x_bin);
  y_bin = (uint8_t*)RSTRING_PTR(rb_y_bin);
  /* Take the input pointer only after the allocations above, which may run GC. */
  bin = (const uint8_t*)RSTRING_PTR(rb_bin);

  for (i = 0; i < num_datas; i++) {
    const uint8_t* record = bin + i * record_size;
    y_bin[i * 2] = record[0];
    y_bin[i * 2 + 1] = record[1];
    memcpy(x_bin + i * size, record + 2, size);
  }
  return rb_ary_new3(2, rb_x_bin, rb_y_bin);
}

.load_test ⇒ Object



# File 'lib/dnn/datasets/cifar100.rb', lines 38-47

# Loads the CIFAR-100 test split.
#
# Calls +downloads+ first, so the archive is fetched and extracted when the
# extracted directory is missing. Then reads
# DOWNLOADS_PATH/downloads/DIR_CIFAR100/test.bin and splits it with
# CIFAR100.load_binary into 10000 records.
#
# @return [Array] +[x_test, y_test]+
#   - x_test is a Numo::UInt8 reshaped to (10000, 3, 32, 32), then transposed
#     to (10000, 32, 32, 3).
#   - y_test is a Numo::UInt8 of shape (10000, 2) holding the two label bytes
#     of each record (presumably coarse and fine labels; confirm against the
#     CIFAR-100 spec).
# @raise [DNN_CIFAR100_LoadError] if test.bin does not exist.
def self.load_test
  downloads
  test_file = DOWNLOADS_PATH + "/downloads/#{DIR_CIFAR100}/test.bin"
  unless File.exist?(test_file)
    raise DNN_CIFAR100_LoadError, %`file "#{test_file}" is not found.`
  end

  record_count = 10_000
  image_bytes, label_bytes = CIFAR100.load_binary(File.binread(test_file), record_count)
  images = Numo::UInt8.from_binary(image_bytes)
                      .reshape(record_count, 3, 32, 32)
                      .transpose(0, 2, 3, 1)
                      .clone
  labels = Numo::UInt8.from_binary(label_bytes).reshape(record_count, 2)
  [images, labels]
end

.load_train ⇒ Object



# File 'lib/dnn/datasets/cifar100.rb', lines 26-36

# Loads the CIFAR-100 training split.
#
# Calls +downloads+ first, so the archive is fetched and extracted when the
# extracted directory is missing. Then reads
# DOWNLOADS_PATH/downloads/DIR_CIFAR100/train.bin and splits it with
# CIFAR100.load_binary into 50000 records.
#
# @return [Array] +[x_train, y_train]+
#   - x_train is a Numo::UInt8 reshaped to (50000, 3, 32, 32), then transposed
#     to (50000, 32, 32, 3).
#   - y_train is a Numo::UInt8 of shape (50000, 2) holding the two label bytes
#     of each record (presumably coarse and fine labels; confirm against the
#     CIFAR-100 spec).
# @raise [DNN_CIFAR100_LoadError] if train.bin does not exist.
def self.load_train
  downloads
  fname = DOWNLOADS_PATH + "/downloads/#{DIR_CIFAR100}/train.bin"
  raise DNN_CIFAR100_LoadError, %`file "#{fname}" is not found.` unless File.exist?(fname)
  # CIFAR-100 ships a single training file, so hand its contents over directly.
  # Appending to a fresh "" buffer duplicated the whole file in memory and
  # would raise FrozenError under frozen_string_literal.
  bin = File.binread(fname)
  x_bin, y_bin = CIFAR100.load_binary(bin, 50_000)
  x_train = Numo::UInt8.from_binary(x_bin).reshape(50_000, 3, 32, 32).transpose(0, 2, 3, 1).clone
  y_train = Numo::UInt8.from_binary(y_bin).reshape(50_000, 2)
  [x_train, y_train]
end