Module: TensorFlow::Keras::Datasets::MNIST

Defined in:
lib/tensorflow/keras/datasets/mnist.rb

Class Method Summary

Class Method Details

.load_data(path: "mnist.npz") ⇒ Object



Source listing (lines 5–14 of lib/tensorflow/keras/datasets/mnist.rb):
# File 'lib/tensorflow/keras/datasets/mnist.rb', line 5

# Fetches the MNIST archive and splits it into training and test pairs.
#
# The archive is obtained through Utils.get_file, which receives the fixed
# download URL and an expected hash. It is then opened with Npy.load_npz.
#
# @param path [String] name handed to Utils.get_file as its first argument
#   (presumably the local cache file name — confirm against Utils.get_file)
# @return [Array] a two-element array: [[x_train, y_train], [x_test, y_test]],
#   where each element is whatever Npy.load_npz yields for that key
def self.load_data(path: "mnist.npz")
  origin_url = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz"
  expected_hash = "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"

  local_file = Utils.get_file(path, origin_url, file_hash: expected_hash)
  archive = Npy.load_npz(local_file)

  training_pair = [archive["x_train"], archive["y_train"]]
  testing_pair = [archive["x_test"], archive["y_test"]]
  [training_pair, testing_pair]
end