5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
|
# File 'lib/tensorflow/keras/datasets/fashion_mnist.rb', line 5
# Downloads (if not already cached) and loads the Fashion-MNIST dataset.
#
# The four gzip-compressed IDX files are fetched with Utils.get_file into the
# "datasets/fashion-mnist" cache subdirectory and decoded into Numo arrays.
#
# @return [Array] [[x_train, y_train], [x_test, y_test]]. The x arrays are
#   Numo::UInt8 with shape [count, rows, cols], taken from each file's IDX
#   header. The y arrays are 1-D Numo::UInt8 label vectors.
# @raise [RuntimeError] if a file's header is truncated or its IDX magic
#   number is not the expected one (e.g. a corrupt cached download).
def self.load_data
  base_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets"
  files = [
    "train-labels-idx1-ubyte.gz", "train-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz"
  ]
  paths = files.map do |file|
    Utils.get_file(file, "#{base_url}/#{file}", cache_subdir: "datasets/fashion-mnist")
  end

  # IDX label file header: big-endian uint32 magic (2049) and item count,
  # followed by one unsigned byte per label.
  read_labels = lambda do |path|
    Zlib::GzipReader.open(path) do |gz|
      header = gz.read(8)
      raise "Truncated label file: #{path}" if header.nil? || header.bytesize < 8

      magic, _count = header.unpack("N2")
      raise "Invalid label file #{path} (magic number #{magic})" unless magic == 2049

      Numo::UInt8.from_string(gz.read)
    end
  end

  # IDX image file header: big-endian uint32 magic (2051), image count, rows,
  # cols, followed by count * rows * cols unsigned pixel bytes.
  read_images = lambda do |path|
    Zlib::GzipReader.open(path) do |gz|
      header = gz.read(16)
      raise "Truncated image file: #{path}" if header.nil? || header.bytesize < 16

      magic, count, rows, cols = header.unpack("N4")
      raise "Invalid image file #{path} (magic number #{magic})" unless magic == 2051

      Numo::UInt8.from_string(gz.read, [count, rows, cols])
    end
  end

  y_train_path, x_train_path, y_test_path, x_test_path = paths
  y_train = read_labels.call(y_train_path)
  x_train = read_images.call(x_train_path)
  y_test = read_labels.call(y_test_path)
  x_test = read_images.call(x_test_path)

  [[x_train, y_train], [x_test, y_test]]
end
|