Class: Torch::Utils::Data::DataLoader

Inherits:
Object
  • Object
show all
Includes:
Enumerable
Defined in:
lib/torch/utils/data/data_loader.rb

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(dataset, batch_size: 1, shuffle: false) ⇒ DataLoader

Returns a new instance of DataLoader.



9
10
11
12
13
# File 'lib/torch/utils/data/data_loader.rb', line 9

# Builds a loader that walks +dataset+ in groups of +batch_size+ items.
#
# Nothing is validated or copied here; the arguments are only stored and
# read later when iterating.
#
# @param dataset [Object] indexable collection; it is read through
#   +dataset[i]+ and +dataset.size+ during iteration
# @param batch_size [Integer] how many items go into each batch (default 1)
# @param shuffle [Boolean] when true, item order comes from +Torch.randperm+
#   instead of ascending index order (default false)
def initialize(dataset, batch_size: 1, shuffle: false)
  @batch_size = batch_size
  @dataset = dataset
  @shuffle = shuffle
end

Instance Attribute Details

#dataset ⇒ Object (readonly)

Returns the value of attribute dataset.



7
8
9
# File 'lib/torch/utils/data/data_loader.rb', line 7

# Read-only access to the collection this loader iterates over.
#
# @return [Object] the +dataset+ value stored by the constructor
def dataset; @dataset; end

Instance Method Details

#each ⇒ Object



15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# File 'lib/torch/utils/data/data_loader.rb', line 15

# Walks the dataset in batches of +@batch_size+ items and yields each batch
# after passing it through +collate+.
#
# Index order is +Torch.randperm(dataset.size)+ when shuffling is enabled,
# otherwise 0...dataset.size ascending. The last batch may contain fewer
# than +@batch_size+ items.
#
# @yieldparam batch [Object] the value returned by +collate+ for one slice
# @return [Enumerator] when called without a block (its +size+ is #size)
def each
  # Enumerable convention: without a block, return a lazy Enumerator instead
  # of failing at the first +yield+ with LocalJumpError. Returning before the
  # RNG draw below means the random state only advances when iteration runs.
  return to_enum(:each) { size } unless block_given?

  # try to keep the random number generator in sync with Python
  # this makes it easy to compare results. The drawn int64 value is discarded
  # on purpose; only the RNG advance matters (presumably mirroring the
  # base-seed draw in PyTorch's DataLoader — TODO confirm).
  Torch.empty([], dtype: :int64).random!

  indexes =
    if @shuffle
      Torch.randperm(@dataset.size).to_a
    else
      @dataset.size.times
    end

  indexes.each_slice(@batch_size) do |idx|
    batch = idx.map { |i| @dataset[i] }
    yield collate(batch)
  end
end

#size ⇒ Object



33
34
35
# File 'lib/torch/utils/data/data_loader.rb', line 33

# Number of batches one pass of iteration produces: the item count divided
# by the batch size, rounded up so a trailing partial batch counts as one.
#
# A zero batch size yields Infinity here, and +ceil+ then raises
# FloatDomainError.
#
# @return [Integer]
def size
  per_batch = @batch_size.to_f
  (@dataset.size / per_batch).ceil
end