Module: DNN::CIFAR10

Defined in:
lib/dnn/lib/cifar10.rb,
ext/cifar10_loader/cifar10_loader.c

Defined Under Namespace

Classes: DNN_CIFAR10_LoadError

Class Method Summary

Class Method Details

.downloads ⇒ Object



15
16
17
18
19
20
21
22
23
24
25
26
# File 'lib/dnn/lib/cifar10.rb', line 15

# Fetches the CIFAR-10 archive and unpacks it into this file's directory,
# unless the CIFAR10_DIR directory already exists there.
#
# NOTE(review): assumes Downloader.download saves the archive into this same
# directory under the last path segment of URL_CIFAR10 — confirm in Downloader.
# The archive file is deleted afterwards, whether or not extraction succeeded.
def self.downloads
  return if Dir.exist?("#{__dir__}/#{CIFAR10_DIR}")

  Downloader.download(URL_CIFAR10)
  archive_name = URL_CIFAR10.match(%r{.+/(.+)})[1]
  archive_path = "#{__dir__}/#{archive_name}"
  begin
    Zlib::GzipReader.open(archive_path) { |gz| Archive::Tar::Minitar.unpack(gz, __dir__) }
  ensure
    File.unlink(archive_path)
  end
end

.load_binary(rb_bin, rb_num_datas) ⇒ Object



9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# File 'ext/cifar10_loader/cifar10_loader.c', line 9

/*
 * Splits packed CIFAR-10 binary records into separate image and label Strings.
 *
 * rb_bin       - String of consecutive records. Each record is one label byte
 *                followed by CIFAR10_WIDTH * CIFAR10_HEIGHT * CIFAR10_CHANNEL
 *                image bytes.
 * rb_num_datas - number of records to extract from the start of rb_bin.
 *
 * Returns [x_bin, y_bin]. x_bin is the concatenated image bytes of every record.
 * y_bin holds the first (label) byte of each record.
 *
 * Raises TypeError if rb_bin is not String-convertible or rb_num_datas is not
 * an Integer. Raises ArgumentError if rb_num_datas is negative or rb_bin is too
 * short to contain that many records.
 */
static VALUE cifar10_load_binary(VALUE self, VALUE rb_bin, VALUE rb_num_datas) {
  const long size = (long)CIFAR10_WIDTH * CIFAR10_HEIGHT * CIFAR10_CHANNEL;
  const long record_size = size + 1;
  long num_datas;
  long i;
  VALUE rb_x_bin;
  VALUE rb_y_bin;
  const uint8_t* bin;
  uint8_t* x_bin;
  uint8_t* y_bin;

  StringValue(rb_bin);
  num_datas = NUM2LONG(rb_num_datas);
  if (num_datas < 0) {
    rb_raise(rb_eArgError, "number of records must be non-negative (given %ld)", num_datas);
  }
  /* Dividing (rather than multiplying) keeps this check overflow-free and
     guarantees every memcpy below stays inside rb_bin. */
  if (num_datas > RSTRING_LEN(rb_bin) / record_size) {
    rb_raise(rb_eArgError, "binary data too short: %ld records requested but only %ld bytes given",
             num_datas, (long)RSTRING_LEN(rb_bin));
  }

  /* Allocate the result Strings first and fill them in place: there are no
     malloc'd scratch buffers that could leak if a Ruby allocation raised. */
  rb_x_bin = rb_str_new(NULL, num_datas * size);
  rb_y_bin = rb_str_new(NULL, num_datas);
  /* Take raw pointers only after all Ruby allocations are done. */
  bin = (const uint8_t*)RSTRING_PTR(rb_bin);
  x_bin = (uint8_t*)RSTRING_PTR(rb_x_bin);
  y_bin = (uint8_t*)RSTRING_PTR(rb_y_bin);
  for (i = 0; i < num_datas; i++) {
    const uint8_t* record = bin + i * record_size;
    y_bin[i] = record[0];
    memcpy(x_bin + i * size, record + 1, (size_t)size);
  }
  RB_GC_GUARD(rb_bin);
  return rb_ary_new3(2, rb_x_bin, rb_y_bin);
}

.load_test ⇒ Object



42
43
44
45
46
47
48
49
50
51
# File 'lib/dnn/lib/cifar10.rb', line 42

# Loads the CIFAR-10 test split (10,000 records), downloading the dataset first if needed.
#
# Reads test_batch.bin from CIFAR10_DIR and splits its records with load_binary.
#
# @return [Array(Numo::UInt8, Numo::UInt8)] images reshaped to (10000, 32, 32, 3)
#   (channel axis moved last), and one label byte per record
# @raise [DNN_CIFAR10_LoadError] if test_batch.bin is missing or too short
def self.load_test
  downloads
  num_datas = 10_000
  fname = __dir__ + "/#{CIFAR10_DIR}/test_batch.bin"
  raise DNN_CIFAR10_LoadError, %`file "#{fname}" is not found.` unless File.exist?(fname)
  bin = File.binread(fname)
  # Each record is one label byte followed by a 3x32x32 image. Check the size
  # before handing the buffer to the native loader, so a truncated or partially
  # extracted file fails with a clear error instead of an out-of-bounds read.
  expected_bytes = num_datas * (1 + 3 * 32 * 32)
  if bin.bytesize < expected_bytes
    raise DNN_CIFAR10_LoadError,
          %`file "#{fname}" is too short: expected at least #{expected_bytes} bytes, got #{bin.bytesize}.`
  end
  x_bin, y_bin = load_binary(bin, num_datas)
  x_test = Numo::UInt8.from_binary(x_bin).reshape(num_datas, 3, 32, 32).transpose(0, 2, 3, 1).clone
  y_test = Numo::UInt8.from_binary(y_bin)
  [x_test, y_test]
end

.load_train ⇒ Object



28
29
30
31
32
33
34
35
36
37
38
39
40
# File 'lib/dnn/lib/cifar10.rb', line 28

# Loads the CIFAR-10 training split (50,000 records), downloading the dataset first if needed.
#
# Reads data_batch_1.bin .. data_batch_5.bin from CIFAR10_DIR, concatenates them
# and splits the records with load_binary.
#
# @return [Array(Numo::UInt8, Numo::UInt8)] images reshaped to (50000, 32, 32, 3)
#   (channel axis moved last), and one label byte per record
# @raise [DNN_CIFAR10_LoadError] if a batch file is missing or the combined data is too short
def self.load_train
  downloads
  num_datas = 50_000
  # String.new is a mutable binary buffer; a "" literal would be frozen if the
  # file enables frozen_string_literal, making << raise FrozenError.
  bin = String.new
  (1..5).each do |i|
    fname = __dir__ + "/#{CIFAR10_DIR}/data_batch_#{i}.bin"
    raise DNN_CIFAR10_LoadError, %`file "#{fname}" is not found.` unless File.exist?(fname)
    bin << File.binread(fname)
  end
  # Each record is one label byte followed by a 3x32x32 image. Check the size
  # before handing the buffer to the native loader, so truncated batches fail
  # with a clear error instead of an out-of-bounds read.
  expected_bytes = num_datas * (1 + 3 * 32 * 32)
  if bin.bytesize < expected_bytes
    raise DNN_CIFAR10_LoadError,
          %`training data is too short: expected at least #{expected_bytes} bytes, got #{bin.bytesize}.`
  end
  x_bin, y_bin = load_binary(bin, num_datas)
  x_train = Numo::UInt8.from_binary(x_bin).reshape(num_datas, 3, 32, 32).transpose(0, 2, 3, 1).clone
  y_train = Numo::UInt8.from_binary(y_bin)
  [x_train, y_train]
end