Class: RubyZero::Data::DataLoader
- Inherits: Object
  - Object
    - RubyZero::Data::DataLoader
- Includes:
- Enumerable
- Defined in:
- lib/rubyzero/data/dataloader.rb
Instance Method Summary
- #each ⇒ Object
- #initialize(dataset, batch_size: 1, shuffle: false) ⇒ DataLoader — constructor; returns a new instance of DataLoader.
Constructor Details
#initialize(dataset, batch_size: 1, shuffle: false) ⇒ DataLoader
Returns a new instance of DataLoader.
(Source lines 3–7.)
# File 'lib/rubyzero/data/dataloader.rb', line 3
# Builds a loader over +dataset+.
#
# @param dataset collection of samples; must respond to +size+
# @param batch_size [Integer] samples per batch; a value larger than
#   +dataset.size+ is reduced to +dataset.size+
# @param shuffle [Boolean] whether #each should visit samples in random order
def initialize(dataset, batch_size: 1, shuffle: false)
  @dataset = dataset
  @shuffle = shuffle
  # Never request more samples per batch than the dataset holds.
  @batch_size = batch_size
  @batch_size = dataset.size if batch_size > dataset.size
end
Instance Method Details
#each ⇒ Object
(Source lines 9–32.)
# File 'lib/rubyzero/data/dataloader.rb', line 9
# Iterates over the dataset in mini-batches, yielding once per batch.
#
# Each dataset item (+@dataset[idx]+) must be an Array of fields, because
# the items of a batch are transposed so that every yielded argument holds
# one field across the whole batch. A field column whose first entry is a
# RubyZero::Core::Tensor is combined with +Tensor.stack+. A column whose
# first entry is an Array is wrapped with +Tensor.new+. Any other column is
# yielded as a plain Array.
#
# Samples are visited in index order unless +@shuffle+ is truthy, in which
# case the order is randomized on every call. The last batch may hold fewer
# than +@batch_size+ samples.
#
# @yield [*columns] one collated value per field of the dataset items
# @return [Enumerator] when called without a block
# @return [Array<Array<Integer>>] the index batches that were iterated
def each
  return enum_for(:each) unless block_given?

  # Always start from the sequential index list; shuffle only on request.
  indices = (0...@dataset.length).to_a
  indices.shuffle! if @shuffle
  # An empty dataset leaves nothing to slice (and @batch_size may be 0,
  # which each_slice rejects), so stop here.
  return [] if indices.empty?

  batches = indices.each_slice(@batch_size).to_a
  batches.each do |batch|
    samples = batch.map { |idx| @dataset[idx] }
    # transpose: rows are samples -> rows are fields
    columns = samples.transpose.map do |column|
      case column.first
      when RubyZero::Core::Tensor then RubyZero::Core::Tensor.stack(column)
      when Array then RubyZero::Core::Tensor.new(column)
      else column
      end
    end
    yield(*columns)
  end
end