Module: Chainer::Datasets::MNIST
- Defined in:
- lib/chainer/datasets/mnist.rb
Class Method Summary
- .get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: nil, label_dtype: nil) ⇒ Object
- .preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype) ⇒ Object
- .retrieve_mnist(type:) ⇒ Object
Class Method Details
.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: nil, label_dtype: nil) ⇒ Object
6 7 8 9 10 11 12 13 14 15 16 17 |
# File 'lib/chainer/datasets/mnist.rb', line 6
#
# Fetches the MNIST train and test splits and runs each through
# preprocess_mnist with the same options.
#
# @param withlabel   forwarded unchanged to preprocess_mnist
# @param ndim        forwarded unchanged to preprocess_mnist
# @param scale       forwarded unchanged to preprocess_mnist
# @param dtype       image dtype; when nil/false, falls back to the SFloat
#                    constant of the default device's array backend
# @param label_dtype label dtype; when nil/false, falls back to the Int32
#                    constant of the default device's array backend
# @return [Array] two-element array: [train, test]
def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: nil, label_dtype: nil)
  # The backend module is looked up unconditionally, even when both dtypes
  # are supplied by the caller.
  backend = Chainer::Device.default.xm
  dtype ||= backend::SFloat
  label_dtype ||= backend::Int32

  # Train is retrieved and preprocessed first, then test; map keeps that order.
  %i[train test].map do |split|
    raw_split = retrieve_mnist(type: split)
    preprocess_mnist(raw_split, withlabel, ndim, scale, dtype, label_dtype)
  end
end
.preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype) ⇒ Object
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
# File 'lib/chainer/datasets/mnist.rb', line 19
#
# Turns a raw MNIST split into its final form: optionally reshapes the
# images, casts them to the requested dtype, rescales them, and optionally
# pairs them with their labels.
#
# @param raw         container indexed with :x (images) and, when labels are
#                    requested, :y (labels)
# @param withlabel   when truthy, return a TupleDataset of images and labels;
#                    otherwise return the images alone
# @param ndim        1 leaves the images as loaded, 2 reshapes to
#                    (true, 28, 28), 3 reshapes to (true, 1, 28, 28)
# @param scale       images are multiplied by scale / 255.0
# @param image_dtype passed to cast_to on the images
# @param label_dtype passed to cast_to on the labels
# @return the scaled images, or a TupleDataset wrapping images and labels
# @raise [RuntimeError] when ndim is anything other than 1, 2 or 3
def self.preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype)
  pixels = raw[:x]

  # NOTE(review): the leading `true` presumably asks the array library to
  # infer that dimension from the element count - confirm against its docs.
  case ndim
  when 2
    pixels = pixels.reshape(true, 28, 28)
  when 3
    pixels = pixels.reshape(true, 1, 28, 28)
  when 1
    # Keep the layout exactly as loaded.
  else
    raise "invalid ndim for MNIST dataset"
  end

  factor = scale / 255.0
  pixels = pixels.cast_to(image_dtype) * factor
  return pixels unless withlabel

  TupleDataset.new(pixels, raw[:y].cast_to(label_dtype))
end