Module: Chainer::Datasets::MNIST

Defined in:
lib/chainer/datasets/mnist.rb

Class Method Summary

Class Method Details

.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: nil, label_dtype: nil) ⇒ Object



Lines 6–17
# File 'lib/chainer/datasets/mnist.rb', line 6

# Loads the MNIST training and test splits and preprocesses both the same way.
#
# Each split is fetched with retrieve_mnist and then passed through
# preprocess_mnist with the options given here.
#
# @param withlabel [Boolean] pair images with labels (forwarded to preprocess_mnist).
# @param ndim [Integer] image layout selector forwarded to preprocess_mnist.
# @param scale [Float] pixel scale factor forwarded to preprocess_mnist.
# @param dtype [Class, nil] image dtype; defaults to the device array module's SFloat.
# @param label_dtype [Class, nil] label dtype; defaults to the device array module's Int32.
# @return [Array] two elements: the processed training split, then the test split.
def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: nil, label_dtype: nil)
  xm = Chainer::Device.default.xm
  image_dtype = dtype || xm::SFloat
  target_dtype = label_dtype || xm::Int32

  # Each split is retrieved and then preprocessed before moving to the next one.
  %i[train test].map do |split|
    raw = retrieve_mnist(type: split)
    preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, target_dtype)
  end
end

.preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype) ⇒ Object



Lines 19–38
# File 'lib/chainer/datasets/mnist.rb', line 19

# Reshapes, casts and rescales one raw MNIST split.
#
# @param raw [Hash] split data; :x holds the pixel array and :y the label array
#   (assumed to be what retrieve_mnist returns — confirm for other callers).
# @param withlabel [Boolean] when truthy, wrap images and labels in a TupleDataset.
# @param ndim [Integer] 1 leaves raw[:x] as is; 2 applies reshape(true, 28, 28);
#   3 applies reshape(true, 1, 28, 28).
# @param scale [Float] images are multiplied by scale / 255.0 after casting.
# @param image_dtype [Class] type passed to cast_to for the images.
# @param label_dtype [Class] type passed to cast_to for the labels.
# @return [TupleDataset, Object] TupleDataset.new(images, labels) when withlabel,
#   otherwise the processed image array alone.
# @raise [RuntimeError] when ndim is not 1, 2 or 3.
def self.preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype)
  pixels = raw[:x]
  pixels =
    case ndim
    when 1 then pixels
    when 2 then pixels.reshape(true, 28, 28)
    when 3 then pixels.reshape(true, 1, 28, 28)
    else raise "invalid ndim for MNIST dataset"
    end

  # Cast to the requested dtype, then rescale pixel values by scale / 255.0.
  pixels = pixels.cast_to(image_dtype) * (scale / 255.0)
  return pixels unless withlabel

  labels = raw[:y].cast_to(label_dtype)
  TupleDataset.new(pixels, labels)
end

.retrieve_mnist(type:) ⇒ Object



Lines 40–45
# File 'lib/chainer/datasets/mnist.rb', line 40

# Loads one MNIST split through ::Datasets::MNIST and packs it into UInt8
# arrays of the default device's array module.
#
# @param type [Symbol] which split to load; get_mnist passes :train or :test.
# @return [Hash] { x: UInt8 array built from the table's :pixels column,
#   y: UInt8 array built from the table's :label column }.
#   NOTE(review): assumes each :pixels entry is one flat row of pixel values — confirm.
def self.retrieve_mnist(type:)
  # Named `table`, not `train_table`: this holds whichever split was requested.
  table = ::Datasets::MNIST.new(type: type).to_table

  xm = Chainer::Device.default.xm
  # `cast` takes the rows as a single Array argument. The previous
  # `UInt8[*rows]` form pushed every sample onto the VM stack as its own
  # method argument. `to_a` keeps the coercion the splat used to perform.
  { x: xm::UInt8.cast(table[:pixels].to_a), y: xm::UInt8.cast(table[:label].to_a) }
end