Module: DNN::MNIST

Defined in:
lib/dnn/lib/mnist.rb

Defined Under Namespace

Classes: DNN_MNIST_DownloadError, DNN_MNIST_LoadError

Constant Summary

# Download locations of the four gzip-compressed MNIST IDX archives.
# Frozen so no caller can mutate the shared strings in place.
URL_TRAIN_IMAGES =
"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz".freeze
URL_TRAIN_LABELS =
"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz".freeze
URL_TEST_IMAGES =
"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz".freeze
URL_TEST_LABELS =
"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz".freeze

Class Method Summary

Class Method Details

.download(url) ⇒ Object



63
64
65
66
67
68
69
# File 'lib/dnn/lib/mnist.rb', line 63

# Fetches +url+ and writes the response body, byte for byte, to the local
# path given by url_to_file_name(url).
#
# @param url [String] absolute http(s) URL of an MNIST archive.
# @raise [DNN_MNIST_DownloadError] wrapping the message of any StandardError
#   raised while parsing the URL, fetching it, or writing the file.
def self.download(url)
  # Kernel#open stopped delegating URLs to open-uri in Ruby 3.0, and it treats
  # a leading "|" as a shell command. Parse the URL explicitly and use
  # open-uri's URI#open instead.
  # NOTE(review): relies on this file requiring "open-uri" (the original
  # Kernel#open(url) call needed it too) — confirm at the top of the file.
  URI.parse(url).open("rb") do |f|
    File.binwrite(url_to_file_name(url), f.read)
  end
rescue => ex
  raise DNN_MNIST_DownloadError, ex.message
end

.downloads ⇒ Object



20
21
22
23
24
25
26
27
28
29
# File 'lib/dnn/lib/mnist.rb', line 20

# Ensures all four MNIST archives exist locally, fetching only the missing
# ones via download. It creates mnist_dir on demand.
#
# Each file is checked individually, not just the directory. A run that
# failed part-way (directory created, some archives absent) is therefore
# resumed on the next call instead of being skipped forever.
#
# @raise [DNN_MNIST_DownloadError] propagated from download.
def self.downloads
  urls = [URL_TRAIN_IMAGES, URL_TRAIN_LABELS, URL_TEST_IMAGES, URL_TEST_LABELS]
  missing = urls.reject { |url| File.exist?(url_to_file_name(url)) }
  return if missing.empty?

  Dir.mkdir(mnist_dir) unless Dir.exist?(mnist_dir)
  puts "Now downloading..."
  missing.each { |url| download(url) }
  puts "The download has ended."
end

.load_images(file_name) ⇒ Object



71
72
73
74
75
76
77
78
79
# File 'lib/dnn/lib/mnist.rb', line 71

# Reads a gzip-compressed MNIST image archive.
#
# The 16-byte header is four big-endian 32-bit integers: magic (ignored),
# image count, rows, columns. The remaining bytes are handed to the native
# helper _mnist_load_images.
#
# @param file_name [String] path to the .gz image archive.
# @return [Object] whatever _mnist_load_images builds (not visible here).
def self.load_images(file_name)
  # GzipReader.open with a block closes the stream and returns the block's value.
  Zlib::GzipReader.open(file_name) do |gz|
    _magic, count = gz.read(8).unpack("N2")
    height, width = gz.read(8).unpack("N2")
    # The helper takes the column count before the row count.
    _mnist_load_images(gz.read, count, width, height)
  end
end

.load_labels(file_name) ⇒ Object



81
82
83
84
85
86
87
88
# File 'lib/dnn/lib/mnist.rb', line 81

# Reads a gzip-compressed MNIST label archive.
#
# The 8-byte header is two big-endian 32-bit integers: magic (ignored) and
# label count. The remaining bytes are handed to the native helper
# _mnist_load_labels.
#
# @param file_name [String] path to the .gz label archive.
# @return [Object] whatever _mnist_load_labels builds (not visible here).
def self.load_labels(file_name)
  Zlib::GzipReader.open(file_name) do |gz|
    _magic, count = gz.read(8).unpack("N2")
    _mnist_load_labels(gz.read, count)
  end
end

.load_test ⇒ Object



46
47
48
49
50
51
52
53
54
55
56
57
58
59
# File 'lib/dnn/lib/mnist.rb', line 46

# Ensures the archives are present (via downloads), then loads the test
# split from the t10k image and label files.
#
# @return [Array] two elements: [images, labels], as returned by
#   load_images and load_labels.
# @raise [DNN_MNIST_LoadError] if either test archive is missing after
#   downloads has run.
def self.load_test
  downloads
  test_images_file_name = url_to_file_name(URL_TEST_IMAGES)
  test_labels_file_name = url_to_file_name(URL_TEST_LABELS)
  # Each message names the test file that was actually checked.
  unless File.exist?(test_images_file_name)
    raise DNN_MNIST_LoadError, %`file "#{test_images_file_name}" is not found.`
  end
  unless File.exist?(test_labels_file_name)
    raise DNN_MNIST_LoadError, %`file "#{test_labels_file_name}" is not found.`
  end
  images = load_images(test_images_file_name)
  labels = load_labels(test_labels_file_name)
  [images, labels]
end

.load_train ⇒ Object



31
32
33
34
35
36
37
38
39
40
41
42
43
44
# File 'lib/dnn/lib/mnist.rb', line 31

# Ensures the archives are present (via downloads), then loads the training
# split from the train image and label files.
#
# @return [Array] two elements: [images, labels], as returned by
#   load_images and load_labels.
# @raise [DNN_MNIST_LoadError] if either training archive is missing after
#   downloads has run. The images file is checked first.
def self.load_train
  downloads
  images_path = url_to_file_name(URL_TRAIN_IMAGES)
  labels_path = url_to_file_name(URL_TRAIN_LABELS)
  [images_path, labels_path].each do |path|
    next if File.exist?(path)

    raise DNN_MNIST_LoadError, %`file "#{path}" is not found.`
  end
  [load_images(images_path), load_labels(labels_path)]
end

.mnist_dir ⇒ Object



90
91
92
# File 'lib/dnn/lib/mnist.rb', line 90

# Directory in which the MNIST archives are stored: an "mnist" subdirectory
# next to this source file.
#
# @return [String]
def self.mnist_dir
  "#{__dir__}/mnist"
end

.url_to_file_name(url) ⇒ Object



94
95
96
# File 'lib/dnn/lib/mnist.rb', line 94

# Maps a download URL to its local path: mnist_dir joined with the text
# after the last "/" in +url+.
#
# @param url [String]
# @return [String]
def self.url_to_file_name(url)
  basename = url.match(%r`.+/(.+)$`).captures.first
  "#{mnist_dir}/#{basename}"
end