Class: TorchVision::Datasets::CIFAR10

Inherits:
VisionDataset < Object
Defined in:
lib/torchvision/datasets/cifar10.rb

Direct Known Subclasses

CIFAR100

Instance Attribute Summary

Attributes inherited from VisionDataset

#data, #targets

Instance Method Summary

Constructor Details

#initialize(root, train: true, download: false, transform: nil, target_transform: nil) ⇒ CIFAR10



6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# File 'lib/torchvision/datasets/cifar10.rb', line 6

# Loads the CIFAR binary batch files into memory.
#
# @param root [String] directory that holds (or will receive) the dataset files
# @param train [Boolean] read +train_list+ when true, +test_list+ otherwise
# @param download [Boolean] call #download before loading
# @param transform [#call, nil] forwarded to VisionDataset; #[] applies it to each image
# @param target_transform [#call, nil] forwarded to VisionDataset; #[] applies it to each label
# @raise [Error] when #_check_integrity reports missing or corrupted files
def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
  super(root, transform: transform, target_transform: target_transform)
  @train = train

  # `download` is the keyword argument here, so the method needs an explicit receiver.
  self.download if download

  unless _check_integrity
    raise Error, "Dataset not found or corrupted. You can use download: true to download it"
  end

  downloaded_list = @train ? train_list : test_list

  # Raw byte buffers. String.new with no argument is ASCII-8BIT, so binary
  # reads are appended without any encoding conversion.
  @data = String.new
  @targets = String.new

  downloaded_list.each do |file|
    file_path = File.join(@root, base_folder, file[:filename])
    File.open(file_path, "rb") do |f|
      # Each record: an optional extra label byte (skipped), one label byte, 3072 pixel bytes.
      until f.eof?
        f.read(1) if multiple_labels?
        @targets << f.read(1)
        @data << f.read(3072)
      end
    end
  end

  @targets = @targets.unpack("C*")
  # TODO: pass -1 (inferred size) as the first dimension once Numo supports it
  @data = Numo::UInt8.from_binary(@data).reshape(@targets.size, 3, 32, 32)
  # Reorder axes from (N, 3, 32, 32) to (N, 32, 32, 3).
  @data = @data.transpose(0, 2, 3, 1)
end

Instance Method Details

#[](index) ⇒ Object



42
43
44
45
46
47
48
49
50
51
52
53
# File 'lib/torchvision/datasets/cifar10.rb', line 42

# Fetches one sample as an [image, label] pair.
#
# The pixel slice is converted with Utils.image_from_array; the optional
# transform callables are then applied to the image and label respectively.
#
# @param index [Integer] sample position
# @return [Array] two-element array of (possibly transformed) image and label
def [](index)
  # TODO remove trues when Numo supports it
  pixels = @data[index, true, true, true]
  label = @targets[index]

  image = Utils.image_from_array(pixels)
  image = @transform.call(image) if @transform
  label = @target_transform.call(label) if @target_transform

  [image, label]
end

#_check_integrity ⇒ Object



55
56
57
58
59
60
61
62
# File 'lib/torchvision/datasets/cifar10.rb', line 55

# Runs check_integrity on every train and test batch file (both lists,
# regardless of @train), passing each entry's recorded :sha256 value.
#
# @return [Boolean] true when every file passes; false as soon as one fails
def _check_integrity
  (train_list + test_list).all? do |entry|
    path = File.join(@root, base_folder, entry[:filename])
    check_integrity(path, entry[:sha256])
  end
end

#download ⇒ Object



64
65
66
67
68
69
70
71
72
73
74
75
76
# File 'lib/torchvision/datasets/cifar10.rb', line 64

# Fetches the dataset archive into @root and unpacks it there.
#
# Does nothing but print a notice when #_check_integrity already passes.
# The archive is downloaded via download_file with its expected SHA-256,
# then extracted with Gem::Package#extract_tar_gz.
def download
  if _check_integrity
    puts "Files already downloaded and verified"
  else
    download_file(url, download_root: @root, filename: filename, sha256: tgz_sha256)
    archive_path = File.join(@root, filename)
    File.open(archive_path, "rb") do |archive|
      Gem::Package.new("").extract_tar_gz(archive, @root)
    end
  end
end

#size ⇒ Object



38
39
40
# File 'lib/torchvision/datasets/cifar10.rb', line 38

# Number of loaded samples: the length of the first axis of @data.
def size
  @data.shape.first
end