Module: DNN::CIFAR10
- Defined in:
  - lib/dnn/lib/cifar10.rb
  - ext/cifar10_loader/cifar10_loader.c
Defined Under Namespace
Classes: DNN_CIFAR10_LoadError
Class Method Summary
collapse
Class Method Details
.downloads ⇒ Object
15
16
17
18
19
20
21
22
23
24
25
26
|
# File 'lib/dnn/lib/cifar10.rb', line 15
# Makes sure the CIFAR-10 binary files are available next to this source file.
#
# Returns immediately when the CIFAR10_DIR directory already exists under
# __dir__. Otherwise it calls Downloader.download with URL_CIFAR10, unpacks
# the gzipped tar archive named after the last path segment of that URL into
# __dir__, and finally deletes the archive file.
# NOTE(review): presumably Downloader.download stores the archive in __dir__;
# confirm against the Downloader implementation.
def self.downloads
  base_dir = __dir__
  return if Dir.exist?("#{base_dir}/#{CIFAR10_DIR}")

  Downloader.download(URL_CIFAR10)
  # Everything after the final "/" of the URL is the archive's file name.
  archive_name = URL_CIFAR10.match(%r`.+/(.+)`)[1]
  archive_path = "#{base_dir}/#{archive_name}"
  begin
    Zlib::GzipReader.open(archive_path) { |gz| Archive::Tar::Minitar.unpack(gz, base_dir) }
  ensure
    # The archive is removed whether or not unpacking succeeded.
    File.unlink(archive_path)
  end
end
|
.load_binary(rb_bin, rb_num_datas) ⇒ Object
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
|
# File 'ext/cifar10_loader/cifar10_loader.c', line 9
/*
 * Splits a raw CIFAR-10 binary blob into an image buffer and a label buffer.
 *
 * The input is read as `num_datas` consecutive records, each made of one
 * label byte followed by CIFAR10_WIDTH * CIFAR10_HEIGHT * CIFAR10_CHANNEL
 * image bytes.
 *
 * rb_bin       - String holding the concatenated records.
 * rb_num_datas - Integer number of records to read from rb_bin.
 *
 * Returns a two-element Array [x_bin, y_bin]: x_bin is a String with all image
 * bytes back to back, y_bin is a String with one label byte per record.
 *
 * Raises TypeError when rb_bin cannot be converted to a String or
 * rb_num_datas is not an Integer, and ArgumentError when rb_num_datas is
 * negative or rb_bin is too short to contain that many records.
 */
static VALUE cifar10_load_binary(VALUE self, VALUE rb_bin, VALUE rb_num_datas) {
    const long image_size = (long)CIFAR10_WIDTH * CIFAR10_HEIGHT * CIFAR10_CHANNEL;
    const long record_size = image_size + 1;
    long num_datas;
    long bin_len;
    long i;
    const uint8_t* src;
    uint8_t* x_dst;
    uint8_t* y_dst;
    VALUE rb_x_bin;
    VALUE rb_y_bin;

    StringValue(rb_bin);
    num_datas = NUM2LONG(rb_num_datas);
    bin_len = RSTRING_LEN(rb_bin);

    if (num_datas < 0) {
        rb_raise(rb_eArgError, "number of records must not be negative (got %ld)", num_datas);
    }
    /* Compare via division so the check itself cannot overflow. */
    if (num_datas > bin_len / record_size) {
        rb_raise(rb_eArgError,
                 "binary data is too short: %ld records of %ld bytes requested, but only %ld bytes given",
                 num_datas, record_size, bin_len);
    }

    /*
     * Let Ruby own the output buffers from the start: nothing has to be freed
     * by hand, so nothing leaks if an allocation raises.
     */
    rb_x_bin = rb_str_new(NULL, num_datas * image_size);
    rb_y_bin = rb_str_new(NULL, num_datas);

    /* Fetch raw pointers only after every allocation has been done. */
    src = (const uint8_t*)RSTRING_PTR(rb_bin);
    x_dst = (uint8_t*)RSTRING_PTR(rb_x_bin);
    y_dst = (uint8_t*)RSTRING_PTR(rb_y_bin);

    for (i = 0; i < num_datas; i++) {
        y_dst[i] = src[0];                           /* first byte of a record: label */
        memcpy(x_dst, src + 1, (size_t)image_size);  /* remaining bytes: image data */
        src += record_size;
        x_dst += image_size;
    }

    RB_GC_GUARD(rb_bin);
    return rb_ary_new3(2, rb_x_bin, rb_y_bin);
}
|
.load_test ⇒ Object
42
43
44
45
46
47
48
49
50
51
|
# File 'lib/dnn/lib/cifar10.rb', line 42
# Loads the CIFAR-10 test batch.
#
# Calls downloads first, then reads "test_batch.bin" from the CIFAR10_DIR
# directory under __dir__ and splits it with load_binary into image and label
# bytes for 10000 records.
#
# Returns [x_test, y_test]: x_test is a Numo::UInt8 array reshaped to
# (10000, 3, 32, 32) and then transposed with axes (0, 2, 3, 1); y_test is a
# Numo::UInt8 array built from the label bytes.
# Raises DNN_CIFAR10_LoadError when the batch file does not exist.
def self.load_test
  downloads
  path = "#{__dir__}/#{CIFAR10_DIR}/test_batch.bin"
  unless File.exist?(path)
    raise DNN_CIFAR10_LoadError, %`file "#{path}" is not found.`
  end

  num_records = 10000
  image_bytes, label_bytes = load_binary(File.binread(path), num_records)
  images = Numo::UInt8.from_binary(image_bytes).reshape(num_records, 3, 32, 32)
  # Move the axis of length 3 to the end, then copy the transposed result.
  x_test = images.transpose(0, 2, 3, 1).clone
  y_test = Numo::UInt8.from_binary(label_bytes)
  [x_test, y_test]
end
|
.load_train ⇒ Object
28
29
30
31
32
33
34
35
36
37
38
39
40
|
# File 'lib/dnn/lib/cifar10.rb', line 28
# Loads the CIFAR-10 training batches.
#
# Calls downloads first, then reads "data_batch_1.bin" through
# "data_batch_5.bin" from the CIFAR10_DIR directory under __dir__, appends
# them in order, and splits the result with load_binary into image and label
# bytes for 50000 records.
#
# Returns [x_train, y_train]: x_train is a Numo::UInt8 array reshaped to
# (50000, 3, 32, 32) and then transposed with axes (0, 2, 3, 1); y_train is a
# Numo::UInt8 array built from the label bytes.
# Raises DNN_CIFAR10_LoadError as soon as one of the batch files is missing.
def self.load_train
  downloads
  raw = ""
  1.upto(5) do |batch_no|
    path = "#{__dir__}/#{CIFAR10_DIR}/data_batch_#{batch_no}.bin"
    unless File.exist?(path)
      raise DNN_CIFAR10_LoadError, %`file "#{path}" is not found.`
    end
    raw << File.binread(path)
  end

  num_records = 50000
  image_bytes, label_bytes = load_binary(raw, num_records)
  images = Numo::UInt8.from_binary(image_bytes).reshape(num_records, 3, 32, 32)
  # Move the axis of length 3 to the end, then copy the transposed result.
  x_train = images.transpose(0, 2, 3, 1).clone
  y_train = Numo::UInt8.from_binary(label_bytes)
  [x_train, y_train]
end
|