Module: DNN::CIFAR10

Defined in:
lib/dnn/lib/cifar10.rb

Defined Under Namespace

Classes: DNN_CIFAR10_DownloadError, DNN_CIFAR10_LoadError

Class Method Summary

Class Method Details

.downloads ⇒ Object



(source lines 17–34)
# File 'lib/dnn/lib/cifar10.rb', line 17

# Download the CIFAR-10 archive from URL_CIFAR10 and extract it into this
# file's directory, unless a directory named CIFAR10_DIR already exists there.
# (Presumably the archive's top-level directory is CIFAR10_DIR — TODO confirm.)
#
# The archive is saved next to this file, gunzipped and untarred with Minitar,
# and the archive file is removed afterwards whether or not extraction succeeds.
#
# @return [nil]
# @raise [DNN_CIFAR10_DownloadError] carrying the message of any error raised
#   while fetching, saving, or extracting the archive.
def self.downloads
  return if Dir.exist?(__dir__ + "/" + CIFAR10_DIR)
  # Name the local archive after the last path segment of the URL.
  cifar10_binary_file_name = __dir__ + "/" + URL_CIFAR10.match(%r`.+/(.+)`)[1]
  puts "Now downloading..."
  begin
    # Kernel#open no longer fetches URLs as of Ruby 3.0; URI.open (open-uri)
    # is the supported entry point. Stream straight to disk rather than
    # buffering the whole archive in memory.
    URI.open(URL_CIFAR10, "rb") do |remote|
      File.open(cifar10_binary_file_name, "wb") { |local| IO.copy_stream(remote, local) }
    end
    Zlib::GzipReader.open(cifar10_binary_file_name) do |gz|
      Archive::Tar::Minitar.unpack(gz, __dir__)
    end
  ensure
    # Remove the archive even if the download or extraction failed partway.
    # The existence guard keeps a missing file from masking the original error.
    File.unlink(cifar10_binary_file_name) if File.exist?(cifar10_binary_file_name)
  end
  puts "The download has ended."
rescue => ex
  raise DNN_CIFAR10_DownloadError, ex.message
end

.load_test ⇒ Object



(source lines 50–59)
# File 'lib/dnn/lib/cifar10.rb', line 50

# Load the CIFAR-10 test batch, first calling `downloads` so the dataset
# directory is present.
#
# @return [Array] a two-element array [x_test, y_test]:
#   - x_test: Numo::UInt8 built from the first value returned by load_binary,
#     reshaped to (10000, 3, 32, 32) and transposed so the size-3 axis is last,
#     giving (10000, 32, 32, 3); `clone` materializes the transposed view.
#   - y_test: Numo::UInt8 built from the second value returned by load_binary
#     (presumably the label bytes — TODO confirm).
# @raise [DNN_CIFAR10_LoadError] if test_batch.bin does not exist.
# @raise [DNN_CIFAR10_DownloadError] propagated from `downloads`.
def self.load_test
  downloads
  path = __dir__ + "/#{CIFAR10_DIR}/test_batch.bin"
  unless File.exist?(path)
    raise DNN_CIFAR10_LoadError, %`file "#{path}" is not found.`
  end
  images_bin, labels_bin = load_binary(File.binread(path), 10000)
  images = Numo::UInt8.from_binary(images_bin)
  x_test = images.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).clone
  [x_test, Numo::UInt8.from_binary(labels_bin)]
end

.load_train ⇒ Object



(source lines 36–48)
# File 'lib/dnn/lib/cifar10.rb', line 36

# Load the five CIFAR-10 training batches (data_batch_1.bin .. data_batch_5.bin),
# first calling `downloads` so the dataset directory is present.
#
# @return [Array] a two-element array [x_train, y_train]:
#   - x_train: Numo::UInt8 built from the first value returned by load_binary,
#     reshaped to (50000, 3, 32, 32) and transposed so the size-3 axis is last,
#     giving (50000, 32, 32, 3); `clone` materializes the transposed view.
#   - y_train: Numo::UInt8 built from the second value returned by load_binary
#     (presumably the label bytes — TODO confirm).
# @raise [DNN_CIFAR10_LoadError] if any of the five batch files is missing.
# @raise [DNN_CIFAR10_DownloadError] propagated from `downloads`.
def self.load_train
  downloads
  # Use an explicitly binary (ASCII-8BIT), mutable buffer: a bare "" literal is
  # UTF-8 and becomes frozen under `# frozen_string_literal: true`, which would
  # make the `<<` below raise FrozenError.
  bin = String.new
  (1..5).each do |i|
    fname = __dir__ + "/#{CIFAR10_DIR}/data_batch_#{i}.bin"
    raise DNN_CIFAR10_LoadError, %`file "#{fname}" is not found.` unless File.exist?(fname)
    bin << File.binread(fname)
  end
  # load_binary is defined elsewhere in this module; 50000 is assumed to be the
  # combined record count of the five batches — TODO confirm against load_binary.
  x_bin, y_bin = load_binary(bin, 50000)
  x_train = Numo::UInt8.from_binary(x_bin).reshape(50000, 3, 32, 32).transpose(0, 2, 3, 1).clone
  y_train = Numo::UInt8.from_binary(y_bin)
  [x_train, y_train]
end