Module: Tensorflow::Keras::Datasets::FashionMNIST

Defined in:
lib/tensorflow/keras/datasets/fashion_mnist.rb

Class Method Summary

Class Method Details

.load_data ⇒ Object



(source lines 5–40)
# File 'lib/tensorflow/keras/datasets/fashion_mnist.rb', line 5

# Loads the Fashion-MNIST dataset, downloading and caching the files on first use.
#
# The four gzipped IDX files are fetched from the TensorFlow dataset bucket via
# Utils.get_file and cached under "datasets/fashion-mnist". Each file's IDX
# header is validated rather than skipped blindly. A corrupt or truncated
# download therefore fails loudly instead of producing misshaped arrays.
#
# @return [Array] [[x_train, y_train], [x_test, y_test]] where the x arrays are
#   Numo::UInt8 images shaped [count, rows, cols] (taken from the file header)
#   and the y arrays are 1-D Numo::UInt8 label vectors.
# @raise [IOError] if a file is not an unsigned-byte IDX file of the expected
#   rank, its payload size disagrees with its header, or the image and label
#   counts of a split differ.
def self.load_data
  base_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets"
  files = [
    "train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"
  ]

  paths = files.map do |file|
    Utils.get_file(file, "#{base_url}/#{file}", cache_subdir: "datasets/fashion-mnist")
  end

  # Reads one gzipped IDX file of unsigned bytes with `rank` dimensions.
  # Header layout: two zero bytes, a type code (0x08 = unsigned byte), the
  # dimension count, then one big-endian uint32 per dimension; data follows.
  read_idx = lambda do |path, rank|
    Zlib::GzipReader.open(path) do |gz|
      magic = gz.read(4).to_s.unpack("C4")
      unless magic == [0, 0, 0x08, rank]
        raise IOError, "Unexpected IDX header in #{path}: #{magic.inspect}"
      end

      dims = gz.read(4 * rank).to_s.unpack("N*")
      raise IOError, "Truncated IDX header in #{path}" unless dims.size == rank

      data = gz.read.to_s
      expected = dims.reduce(1, :*)
      unless data.bytesize == expected
        raise IOError, "Expected #{expected} bytes of data in #{path}, got #{data.bytesize}"
      end

      Numo::UInt8.from_string(data, dims)
    end
  end

  y_train = read_idx.call(paths[0], 1)
  x_train = read_idx.call(paths[1], 3)
  y_test = read_idx.call(paths[2], 1)
  x_test = read_idx.call(paths[3], 3)

  splits = [[x_train, y_train], [x_test, y_test]]
  # Images and labels are stored in separate files; make sure they pair up.
  splits.each do |x, y|
    next if x.shape[0] == y.shape[0]

    raise IOError, "Image count #{x.shape[0]} does not match label count #{y.shape[0]}"
  end

  splits
end