Class: Tensorflow::Datasets::Images::Mnist

Inherits:
Object
  • Object
show all
Extended by:
Tensorflow::Decorator
Defined in:
lib/datasets/images/mnist.rb

Constant Summary

BASE_URL =
'https://storage.googleapis.com/cvdf-datasets/mnist'

Instance Method Summary

Methods included from Tensorflow::Decorator

extended, function, method_added, singleton_method_added, wrap_method

Instance Method Details

#dataset(images_file, labels_file) ⇒ Object



32
33
34
35
36
37
38
39
40
41
42
# File 'lib/datasets/images/mnist.rb', line 32

# Downloads an MNIST image archive and its matching label archive from
# BASE_URL and pairs their decoded records into a single zipped dataset.
#
# @param images_file [String] file name of the gzip'd image archive under BASE_URL
# @param labels_file [String] file name of the gzip'd label archive under BASE_URL
# @return [Data::ZipDataset] (image, label) pairs produced by #decode_image / #decode_label
def dataset(images_file, labels_file)
  remote_files = [images_file, labels_file].map { |name| "#{BASE_URL}/#{name}" }
  downloaded = Datasets::DownloadManager.new.download(remote_files)

  # Image archive: 16-byte header, then fixed 28*28-byte records.
  images_path = downloaded.first.path
  image_records = Data::FixedLengthRecordDataset
                  .new(images_path, 28 * 28, header_bytes: 16, compression_type: 'GZIP')
                  .map_func(self.decode_image)

  # Label archive: 8-byte header, then fixed 1-byte records.
  labels_path = downloaded.last.path
  label_records = Data::FixedLengthRecordDataset
                  .new(labels_path, 1, header_bytes: 8, compression_type: 'GZIP')
                  .map_func(self.decode_label)

  Data::ZipDataset.new(image_records, label_records)
end

#decode_image(image) ⇒ Object



16
17
18
19
20
21
22
# File 'lib/datasets/images/mnist.rb', line 16

# Decodes one raw MNIST image record into a flat, normalized float32 vector.
#
# @param image raw byte-string record; presumably one 28*28-byte record
#   read by FixedLengthRecordDataset — TODO confirm against callers
# @return float32 tensor reshaped to [784], with values scaled from
#   [0, 255] to [0.0, 1.0]
def decode_image(image)
  # Use the explicitly namespaced Tf::IO (as #decode_label does); a bare `IO`
  # can resolve to Ruby's core ::IO, which has no decode_raw.
  pixels = Tf::IO.decode_raw(image, Tf.uint8)
  pixels = Tf.cast(pixels, Tf.float32)
  # 28 * 28 matches the per-image record length used when reading the archive.
  pixels = Tf.reshape(pixels, [28 * 28])
  # Normalize from [0, 255] to [0.0, 1.0]
  pixels / 255.0
end

#decode_label(label) ⇒ Object



25
26
27
28
29
30
# File 'lib/datasets/images/mnist.rb', line 25

# Decodes one raw MNIST label record into an int32 scalar.
#
# @param label raw byte-string record; expected to hold exactly one byte so
#   that the reshape to [] below yields a scalar — TODO confirm with callers
# @return int32 scalar tensor holding the label value
def decode_label(label)
  as_bytes = Tf::IO.decode_raw(label, Tf.uint8)
  # Reshape to [] turns the single decoded byte into a rank-0 value before widening.
  Tf.cast(Tf.reshape(as_bytes, []), Tf.int32)
end

#test ⇒ Object



48
49
50
# File 'lib/datasets/images/mnist.rb', line 48

# Builds the dataset from the 't10k' image/label archives.
#
# @return whatever #dataset returns for those two archive names
def test
  images, labels = %w[t10k-images-idx3-ubyte.gz t10k-labels-idx1-ubyte.gz]
  dataset(images, labels)
end

#train ⇒ Object



44
45
46
# File 'lib/datasets/images/mnist.rb', line 44

# Builds the dataset from the 'train' image/label archives.
#
# @return whatever #dataset returns for those two archive names
def train
  images, labels = %w[train-images-idx3-ubyte.gz train-labels-idx1-ubyte.gz]
  dataset(images, labels)
end