Class: TorchVision::Datasets::MNIST

Inherits:
VisionDataset
Defined in:
lib/torchvision/datasets/mnist.rb

Direct Known Subclasses

FashionMNIST, KMNIST

Instance Method Summary

Constructor Details

#initialize(root, train: true, download: false, transform: nil, target_transform: nil) ⇒ MNIST



6
7
8
9
10
11
12
13
14
15
16
17
18
# File 'lib/torchvision/datasets/mnist.rb', line 6

# Loads one split (training or test) of the dataset stored under +root+.
#
# @param root [String] base directory; files live in <root>/<ClassName>/{raw,processed}
# @param train [Boolean] load the training split when truthy, otherwise the test split
# @param download [Boolean] run #download first (a no-op when the processed files already exist)
# @param transform [#call, nil] forwarded to the superclass; applied to each sample by #[]
# @param target_transform [#call, nil] forwarded to the superclass; applied to each label by #[]
# @raise [Error] if the processed training/test files are still missing
def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
  super(root, transform: transform, target_transform: target_transform)
  @train = train

  # The `download` keyword argument shadows the #download method, so call it with an explicit receiver.
  self.download if download
  raise Error, "Dataset not found. You can use download: true to download it" unless check_exists

  split_file =
    if @train
      training_file
    else
      test_file
    end
  @data, @targets = Torch.load(File.join(processed_folder, split_file))
end

Instance Method Details

#[](index) ⇒ Object



24
25
26
27
28
29
30
31
32
33
# File 'lib/torchvision/datasets/mnist.rb', line 24

# Fetches a single example.
#
# The raw sample is handed to the transform as-is (no image conversion yet).
#
# @param index [Integer] position of the example in the loaded split
# @return [Array] two-element array [sample, label]; the label is extracted with #item
#   before the optional target transform is applied
def [](index)
  sample = @data[index]
  label = @targets[index].item

  # TODO: convert to image
  sample = @transform ? @transform.call(sample) : sample
  label = @target_transform ? @target_transform.call(label) : label

  [sample, label]
end

#check_exists ⇒ Object



43
44
45
46
# File 'lib/torchvision/datasets/mnist.rb', line 43

# Reports whether both processed split files are present on disk.
#
# @return [Boolean] true only if the training file and the test file both exist in
#   #processed_folder; the test file is not checked when the training file is missing
def check_exists
  present = ->(name) { File.exist?(File.join(processed_folder, name)) }
  present.call(training_file) && present.call(test_file)
end

#download ⇒ Object



48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# File 'lib/torchvision/datasets/mnist.rb', line 48

# Downloads the raw files and converts them into the processed training/test files.
#
# Returns immediately when #check_exists is already true. Otherwise it:
# 1. creates the raw and processed folders;
# 2. downloads every entry of #resources into the raw folder, checksum-verified via sha256;
# 3. unpacks the four train/t10k image and label files;
# 4. saves [images, labels] pairs to the training and test files.
def download
  return if check_exists

  [raw_folder, processed_folder].each { |dir| FileUtils.mkdir_p(dir) }

  resources.each do |res|
    url = res[:url]
    # Keep the remote file name (last path segment of the URL) locally.
    download_file(url, download_root: raw_folder, filename: url.split("/").last, sha256: res[:sha256])
  end

  puts "Processing..."

  # NOTE(review): 16 / 8 look like the idx3 (images) / idx1 (labels) header sizes in bytes,
  # and the arrays are the output shapes — confirm against unpack_mnist.
  train_images = unpack_mnist("train-images-idx3-ubyte", 16, [60000, 28, 28])
  train_labels = unpack_mnist("train-labels-idx1-ubyte", 8, [60000])
  test_images = unpack_mnist("t10k-images-idx3-ubyte", 16, [10000, 28, 28])
  test_labels = unpack_mnist("t10k-labels-idx1-ubyte", 8, [10000])

  Torch.save([train_images, train_labels], File.join(processed_folder, training_file))
  Torch.save([test_images, test_labels], File.join(processed_folder, test_file))

  puts "Done!"
end

#processed_folder ⇒ Object



39
40
41
# File 'lib/torchvision/datasets/mnist.rb', line 39

# Directory holding the processed split files: <root>/<unqualified class name>/processed.
# Using the class name keeps subclasses (e.g. FashionMNIST) in separate folders.
def processed_folder
  dataset_name = self.class.name.split("::").last
  File.join(@root, dataset_name, "processed")
end

#raw_folder ⇒ Object



35
36
37
# File 'lib/torchvision/datasets/mnist.rb', line 35

# Directory the downloaded raw files go into: <root>/<unqualified class name>/raw.
# Using the class name keeps subclasses (e.g. KMNIST) in separate folders.
def raw_folder
  dataset_name = self.class.name.split("::").last
  File.join(@root, dataset_name, "raw")
end

#size ⇒ Object



20
21
22
# File 'lib/torchvision/datasets/mnist.rb', line 20

# Number of examples in the loaded split, i.e. the length of dimension 0 of @data.
def size
  sample_axis = 0
  @data.size(sample_axis)
end