Class: Tensorflow::Datasets::Images::Mnist
Constant Summary
collapse
- BASE_URL =
'https://storage.googleapis.com/cvdf-datasets/mnist'
Instance Method Summary
collapse
extended, function, method_added, singleton_method_added, wrap_method
Instance Method Details
#dataset(images_file, labels_file) ⇒ Object
(source lines 32-42)
|
# File 'lib/datasets/images/mnist.rb', line 32
# Downloads an image archive and a label archive from BASE_URL and zips the
# decoded image records with the decoded label records.
#
# @param images_file [String] archive name (relative to BASE_URL) holding images
# @param labels_file [String] archive name (relative to BASE_URL) holding labels
# @return [Data::ZipDataset] zip of the mapped image and label datasets
def dataset(images_file, labels_file)
  manager = Datasets::DownloadManager.new
  remote = [images_file, labels_file].map { |file_name| "#{BASE_URL}/#{file_name}" }
  downloaded = manager.download(remote)

  # header_bytes skips the leading file header; each record is 28 * 28 bytes.
  # NOTE(review): decode_image is invoked with no arguments here; presumably a
  # method-wrapping hook turns that into a mappable function — confirm.
  image_ds = Data::FixedLengthRecordDataset
             .new(downloaded.first.path, 28 * 28, header_bytes: 16, compression_type: 'GZIP')
             .map_func(self.decode_image)

  # Same layout for labels, with an 8-byte header and 1-byte records.
  label_ds = Data::FixedLengthRecordDataset
             .new(downloaded.last.path, 1, header_bytes: 8, compression_type: 'GZIP')
             .map_func(self.decode_label)

  Data::ZipDataset.new(image_ds, label_ds)
end
|
#decode_image(image) ⇒ Object
(source lines 16-22)
|
# File 'lib/datasets/images/mnist.rb', line 16
# Decodes one raw image record into a flat float vector scaled to [0, 1].
#
# @param image raw record bytes (28 * 28 uint8 pixels, as read by #dataset)
# @return tensor of shape [784], cast to float32 and divided by 255.0
def decode_image(image)
  # Qualify IO explicitly, as decode_label does: a bare `IO` can resolve to
  # Ruby's core ::IO (which has no decode_raw) depending on module nesting.
  pixels = Tf::IO.decode_raw(image, Tf.uint8)
  pixels = Tf.cast(pixels, Tf.float32)
  pixels = Tf.reshape(pixels, [28 * 28])
  pixels / 255.0
end
|
#decode_label(label) ⇒ Object
(source lines 25-30)
|
# File 'lib/datasets/images/mnist.rb', line 25
# Decodes one raw label record into a scalar int32 label.
#
# @param label raw record bytes (a single uint8, as read by #dataset)
# @return scalar (shape []) tensor cast to int32
def decode_label(label)
  label = Tf::IO.decode_raw(label, Tf.uint8)
  # Collapse the one-element decoded vector to a scalar. This was previously
  # fused with the cast on one line, which is a Ruby syntax error.
  label = Tf.reshape(label, [])
  Tf.cast(label, Tf.int32)
end
|
#test ⇒ Object
(source lines 48-50)
|
# File 'lib/datasets/images/mnist.rb', line 48
# Builds the evaluation split from the t10k image and label archives.
#
# @return result of #dataset for the t10k archive pair
def test
  images, labels = %w[t10k-images-idx3-ubyte.gz t10k-labels-idx1-ubyte.gz]
  dataset(images, labels)
end
|
#train ⇒ Object
(source lines 44-46)
|
# File 'lib/datasets/images/mnist.rb', line 44
# Builds the training split from the train image and label archives.
#
# @return result of #dataset for the train archive pair
def train
  images, labels = %w[train-images-idx3-ubyte.gz train-labels-idx1-ubyte.gz]
  dataset(images, labels)
end
|