Class: RubyZero::Data::DataLoader

Inherits: Object
Includes: Enumerable
Defined in: lib/rubyzero/data/dataloader.rb

Instance Method Summary

Constructor Details

#initialize(dataset, batch_size: 1, shuffle: false) ⇒ DataLoader

Returns a new instance of DataLoader.



(source lines 3–7)
# File 'lib/rubyzero/data/dataloader.rb', line 3

# Builds a loader that groups dataset items into batches.
#
# @param dataset [#size, #[]] indexable collection of items
#   (NOTE(review): items appear to be Arrays of fields, since #each
#   transposes them — confirm against callers)
# @param batch_size [Integer] maximum number of items per batch; values larger
#   than the dataset are clamped to the dataset size
# @param shuffle [Boolean] when true, #each visits items in a random order
# @raise [ArgumentError] if batch_size is less than 1
def initialize(dataset, batch_size: 1, shuffle: false)
    # Fail fast: each_slice would reject a non-positive slice size anyway,
    # but only once iteration starts, far from the offending call site.
    raise ArgumentError, "batch_size must be >= 1 (got #{batch_size})" if batch_size < 1

    @dataset = dataset
    # Clamp to the dataset size, but never below 1: an empty dataset would
    # otherwise produce a slice size of 0 and make iteration raise.
    @batch_size = [[batch_size, dataset.size].min, 1].max
    @shuffle = shuffle
end

Instance Method Details

#each ⇒ Object



(source lines 9–32)
# File 'lib/rubyzero/data/dataloader.rb', line 9

# Iterates over the dataset in mini-batches, yielding one collated batch per step.
#
# Each batch gathers up to @batch_size items and transposes them, so the i-th
# field of every item ends up in the same column. Each column is then converted:
#   * a column whose first element is a RubyZero::Core::Tensor is combined
#     with Tensor.stack;
#   * a column whose first element is an Array is wrapped with Tensor.new;
#   * any other column is yielded as a plain Array.
# Only the first element of a column is inspected to choose the conversion.
# The columns are yielded as separate block arguments (e.g. |inputs, targets|).
#
# @yield [*columns] one collated column per item field
# @return [Enumerator] when no block is given
# @return [self] after all batches have been yielded
def each
    # Enumerable expects each to return an Enumerator when called without a block.
    return enum_for(:each) unless block_given?

    # Dataset order by default; reshuffled on every call when shuffling is on.
    indices = (0...@dataset.size).to_a
    indices.shuffle! if @shuffle

    indices.each_slice(@batch_size) do |batch_indices|
        items = batch_indices.map { |idx| @dataset[idx] }
        # transpose: [[x1, y1], [x2, y2]] -> [[x1, x2], [y1, y2]]
        columns = items.transpose.map do |column|
            first = column.first
            if first.is_a?(RubyZero::Core::Tensor)
                RubyZero::Core::Tensor.stack(column)
            elsif first.is_a?(Array)
                RubyZero::Core::Tensor.new(column)
            else
                column
            end
        end
        yield(*columns)
    end
    self
end