Class: TorchVision::Datasets::MNIST
- Inherits:
  - VisionDataset
- Inheritance chain (root first):
  - Object
  - VisionDataset
  - TorchVision::Datasets::MNIST
- Defined in:
- lib/torchvision/datasets/mnist.rb
Direct Known Subclasses
Instance Method Summary
- #[](index) ⇒ Object
- #check_exists ⇒ Object
- #download ⇒ Object
- #initialize(root, train: true, download: false, transform: nil, target_transform: nil) ⇒ MNIST constructor
- #processed_folder ⇒ Object
- #raw_folder ⇒ Object
- #size ⇒ Object
Constructor Details
#initialize(root, train: true, download: false, transform: nil, target_transform: nil) ⇒ MNIST
6 7 8 9 10 11 12 13 14 15 16 17 18 |
# File 'lib/torchvision/datasets/mnist.rb', line 6

# Sets up the dataset and loads one split from the processed folder.
#
# @param root dataset root, forwarded to the superclass constructor
# @param train [Boolean] truthy selects training_file, otherwise test_file
# @param download [Boolean] when truthy, #download runs before the existence check
# @param transform forwarded to the superclass constructor
# @param target_transform forwarded to the superclass constructor
# @raise [Error] when #check_exists reports that the processed files are missing
def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
  super(root, transform: transform, target_transform: target_transform)
  @train = train

  # The keyword argument shadows the method of the same name, so the
  # explicit receiver is required to invoke the method.
  self.download if download

  raise Error, "Dataset not found. You can use download: true to download it" unless check_exists

  split_file = @train ? training_file : test_file
  # Torch.load is expected to yield a two-element pair: data, then targets.
  @data, @targets = Torch.load(File.join(processed_folder, split_file))
end
Instance Method Details
#[](index) ⇒ Object
24 25 26 27 28 29 30 31 32 33 |
# File 'lib/torchvision/datasets/mnist.rb', line 24

# Fetches one sample and its label, running the optional transforms.
#
# @param index position passed to both @data and @targets
# @return [Array] two-element array: the (possibly transformed) entry of
#   @data and the (possibly transformed) result of calling +item+ on the
#   matching entry of @targets
def [](index)
  sample = @data[index]
  label = @targets[index].item

  # TODO convert to image
  sample = @transform.call(sample) if @transform
  label = @target_transform.call(label) if @target_transform

  [sample, label]
end
#check_exists ⇒ Object
43 44 45 46 |
# File 'lib/torchvision/datasets/mnist.rb', line 43

# Reports whether both processed split files are present on disk.
#
# @return [Boolean] true only when the training file and the test file
#   both exist inside #processed_folder
def check_exists
  folder = processed_folder
  # The test file is only looked up once the training file was found.
  return false unless File.exist?(File.join(folder, training_file))

  File.exist?(File.join(folder, test_file))
end
#download ⇒ Object
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
# File 'lib/torchvision/datasets/mnist.rb', line 48

# Downloads the raw archives and writes the processed training/test files.
#
# Does nothing when #check_exists is already true. Otherwise it creates the
# raw and processed folders, fetches every entry of +resources+ into the raw
# folder, unpacks the four IDX files and saves each [images, labels] pair
# with Torch.save.
def download
  return if check_exists

  FileUtils.mkdir_p(raw_folder)
  FileUtils.mkdir_p(processed_folder)

  resources.each do |resource|
    url = resource[:url]
    # The local file name is the last path segment of the URL.
    download_file(url, download_root: raw_folder, filename: url.split("/").last, sha256: resource[:sha256])
  end

  puts "Processing..."

  # NOTE(review): the numeric argument (16 / 8) is presumably the IDX header
  # size in bytes and the array the resulting shape -- confirm against
  # unpack_mnist, which is defined outside this view.
  train_images = unpack_mnist("train-images-idx3-ubyte", 16, [60000, 28, 28])
  train_labels = unpack_mnist("train-labels-idx1-ubyte", 8, [60000])
  test_images = unpack_mnist("t10k-images-idx3-ubyte", 16, [10000, 28, 28])
  test_labels = unpack_mnist("t10k-labels-idx1-ubyte", 8, [10000])

  Torch.save([train_images, train_labels], File.join(processed_folder, training_file))
  Torch.save([test_images, test_labels], File.join(processed_folder, test_file))

  puts "Done!"
end
#processed_folder ⇒ Object
39 40 41 |
# File 'lib/torchvision/datasets/mnist.rb', line 39

# Path of the folder holding the processed files for this dataset class.
#
# @return [String] @root joined with the unqualified class name and "processed"
def processed_folder
  # Strip the namespace so that e.g. "A::B" contributes only "B".
  class_basename = self.class.name.split("::").last
  File.join(@root, class_basename, "processed")
end
#raw_folder ⇒ Object
35 36 37 |
# File 'lib/torchvision/datasets/mnist.rb', line 35

# Path of the folder holding the raw downloads for this dataset class.
#
# @return [String] @root joined with the unqualified class name and "raw"
def raw_folder
  # Strip the namespace so that e.g. "A::B" contributes only "B".
  class_basename = self.class.name.split("::").last
  File.join(@root, class_basename, "raw")
end
#size ⇒ Object
20 21 22 |
# File 'lib/torchvision/datasets/mnist.rb', line 20

# Length of the loaded data along its first dimension.
#
# @return whatever @data.size(0) yields (presumably the sample count --
#   confirm against the tensor layout produced by #download)
def size
  @data.size(0)
end