Class: Torch::Utils::Data::DataLoader

Inherits:
Object
Includes:
Enumerable
Defined in:
lib/torch/utils/data/data_loader.rb

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil) ⇒ DataLoader

Returns a new instance of DataLoader.



Source lines 9–25:
# File 'lib/torch/utils/data/data_loader.rb', line 9

# Builds a loader over +dataset+.
#
# @param dataset [Object] source of samples; #each indexes it with [] and reads its #size
# @param batch_size [Integer] number of indexes grouped into each batch by #each
# @param shuffle [Boolean] when truthy, #each visits indexes in Torch.randperm order
# @param collate_fn [#call, nil] turns an array of samples into one batch; when nil,
#   a bound method is chosen based on auto_collation?
def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
  @dataset = dataset
  @batch_size = batch_size
  @shuffle = shuffle

  # Assigned before the collate choice below.
  # NOTE(review): presumably auto_collation? consults this; confirm.
  @batch_sampler = nil

  @collate_fn =
    if collate_fn.nil?
      method(auto_collation? ? :default_collate : :default_convert)
    else
      collate_fn
    end
end

Instance Attribute Details

#dataset ⇒ Object (readonly)

Returns the value of attribute dataset.



Source lines 7–9:
# File 'lib/torch/utils/data/data_loader.rb', line 7

# The dataset object this loader was constructed with.
#
# @return [Object]
def dataset; @dataset; end

Instance Method Details

#each ⇒ Object



Source lines 27–45:
# File 'lib/torch/utils/data/data_loader.rb', line 27

# Yields one collated batch per group of +@batch_size+ dataset indexes.
#
# @yield [Object] the return value of +@collate_fn+ applied to the samples of one batch
# @return [Enumerator] when called without a block
def each
  return to_enum(:each) unless block_given?

  # Draw and discard one random int64. Per the original intent, this keeps the
  # random number generator in step with Python so results are easy to compare.
  _base_seed = Torch.empty([], dtype: :int64).random!.item

  total = @dataset.size
  order = @shuffle ? Torch.randperm(total).to_a : total.times

  order.each_slice(@batch_size) do |batch_indexes|
    # TODO improve performance
    samples = batch_indexes.map { |index| @dataset[index] }
    yield @collate_fn.call(samples)
  end
end

#size ⇒ Object (also known as: length, count)



Source lines 47–49:
# File 'lib/torch/utils/data/data_loader.rb', line 47

# Number of batches a full pass of #each yields: the dataset size divided by
# the batch size, rounded up so a trailing partial batch is counted.
#
# Ceiling division is done exactly on integers, as the floor of the negated
# count negated back. The previous Float round-trip lost precision for sizes
# above 2**53 and under-reported the batch count there.
#
# @return [Integer]
# @raise [ZeroDivisionError] if the batch size is 0
def size
  total = @dataset.size
  -((-total).div(@batch_size))
end