Module: DNN::MNIST

Defined in:
lib/dnn/datasets/mnist.rb

Defined Under Namespace

Classes: DNN_MNIST_LoadError

Constant Summary

URL_BASE =
"http://yann.lecun.com/exdb/mnist/"
TRAIN_IMAGES_FILE_NAME =
"train-images-idx3-ubyte.gz"
TRAIN_LABELS_FILE_NAME =
"train-labels-idx1-ubyte.gz"
TEST_IMAGES_FILE_NAME =
"t10k-images-idx3-ubyte.gz"
TEST_LABELS_FILE_NAME =
"t10k-labels-idx1-ubyte.gz"
URL_TRAIN_IMAGES =
URL_BASE + TRAIN_IMAGES_FILE_NAME
URL_TRAIN_LABELS =
URL_BASE + TRAIN_LABELS_FILE_NAME
URL_TEST_IMAGES =
URL_BASE + TEST_IMAGES_FILE_NAME
URL_TEST_LABELS =
URL_BASE + TEST_LABELS_FILE_NAME

Class Method Summary

Class Method Details

.downloads ⇒ Object



21
22
23
24
25
26
27
28
# File 'lib/dnn/datasets/mnist.rb', line 21

# Ensures the download directories exist and fetches any MNIST archive
# that is not already present on disk.
#
# Directories are created in order: "#{DOWNLOADS_PATH}/downloads" first,
# then mnist_dir. Each archive is requested from Downloader only when
# File.exist? reports that its local path (via get_file_path) is missing.
#
# @return [Object, nil] the value of the final download step: whatever
#   Downloader.download returns for the test labels archive, or nil when
#   that file already exists.
def self.downloads
  [
    "#{DOWNLOADS_PATH}/downloads",
    mnist_dir,
  ].each do |dir|
    Dir.mkdir(dir) unless Dir.exist?(dir)
  end

  archives = [
    [URL_TRAIN_IMAGES, TRAIN_IMAGES_FILE_NAME],
    [URL_TRAIN_LABELS, TRAIN_LABELS_FILE_NAME],
    [URL_TEST_IMAGES, TEST_IMAGES_FILE_NAME],
    [URL_TEST_LABELS, TEST_LABELS_FILE_NAME],
  ]

  # map + last keeps the method's implicit return value tied to the
  # outcome of the last archive (nil when it is already downloaded).
  archives.map do |url, file_name|
    Downloader.download(url, mnist_dir) unless File.exist?(get_file_path(file_name))
  end.last
end

.load_test ⇒ Object



41
42
43
44
45
46
47
48
49
50
# File 'lib/dnn/datasets/mnist.rb', line 41

# Loads the MNIST test split.
#
# Calls `downloads` first, then resolves the local paths of the test
# images and test labels archives and verifies that both exist before
# handing them to load_images / load_labels.
#
# @return [Array] a two-element array `[images, labels]` holding the
#   results of load_images and load_labels (their types are defined
#   elsewhere in the file).
# @raise [DNN_MNIST_LoadError] if the images file, or else the labels
#   file, is not found on disk after `downloads` has run.
def self.load_test
  downloads
  images_path = get_file_path(TEST_IMAGES_FILE_NAME)
  labels_path = get_file_path(TEST_LABELS_FILE_NAME)

  # Checked in images-then-labels order so the first missing file is reported.
  [images_path, labels_path].each do |path|
    next if File.exist?(path)

    raise DNN_MNIST_LoadError, %`file "#{path}" is not found.`
  end

  [load_images(images_path), load_labels(labels_path)]
end

.load_train ⇒ Object



30
31
32
33
34
35
36
37
38
39
# File 'lib/dnn/datasets/mnist.rb', line 30

# Loads the MNIST training split.
#
# Calls `downloads` first, then resolves the local paths of the training
# images and training labels archives and verifies that both exist before
# handing them to load_images / load_labels.
#
# @return [Array] a two-element array `[images, labels]` holding the
#   results of load_images and load_labels (their types are defined
#   elsewhere in the file).
# @raise [DNN_MNIST_LoadError] if the images file, or else the labels
#   file, is not found on disk after `downloads` has run.
def self.load_train
  downloads
  images_path = get_file_path(TRAIN_IMAGES_FILE_NAME)
  labels_path = get_file_path(TRAIN_LABELS_FILE_NAME)

  # Checked in images-then-labels order so the first missing file is reported.
  [images_path, labels_path].each do |path|
    next if File.exist?(path)

    raise DNN_MNIST_LoadError, %`file "#{path}" is not found.`
  end

  [load_images(images_path), load_labels(labels_path)]
end