Class: DNN::Iterator

Inherits:
Object
Defined in:
lib/dnn/core/iterator.rb

Overview

This class manages input data and output data together and serves them as matching mini-batches.

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(x_datas, y_datas, random: true, last_round_down: false) ⇒ Iterator

Returns a new instance of Iterator.

Parameters:

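  • x_datas (Numo::SFloat, Array<Numo::SFloat>)

    Input data, with samples along the first axis. Pass an Array to supply several inputs.

  • y_datas (Numo::SFloat, Array<Numo::SFloat>)

    Output (target) data, with samples along the first axis. Pass an Array to supply several outputs.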

  • random (Boolean) (defaults to: true)

    If true, the sample indexes are shuffled on every #reset, so batches are drawn in random order. If false, batches are returned in index order.

  • last_round_down (Boolean) (defaults to: false)

    If true, #foreach skips the final batch when it would hold fewer than batch_size samples. If false, that smaller final batch is included.



# File 'lib/dnn/core/iterator.rb', line 11

def initialize(x_datas, y_datas, random: true, last_round_down: false)
  @x_datas = x_datas
  @y_datas = y_datas
  @random = random
  @last_round_down = last_round_down
  @num_datas = x_datas.is_a?(Array) ? x_datas[0].shape[0] : x_datas.shape[0]
  reset
end
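The constructor reads the sample count from the first axis of x_datas, or from the first element when several inputs are passed as an Array. The plain-Ruby sketch below mirrors that rule. It assumes a hypothetical `FakeArray` stand-in that models only Numo's `shape` (first axis = samples), and `count_samples` is likewise a hypothetical helper name, not part of the library.

```ruby
# Hypothetical stand-in for a Numo array: only #shape is modelled,
# and the first axis is treated as the sample axis.
FakeArray = Struct.new(:rows) do
  def shape
    [rows.length]
  end
end

# Mirrors the constructor's rule for @num_datas: when x_datas is an Array
# (several model inputs), only the first input's first axis is consulted.
def count_samples(x_datas)
  x_datas.is_a?(Array) ? x_datas[0].shape[0] : x_datas.shape[0]
end

single = FakeArray.new([[0.1], [0.2], [0.3]])
multi  = [FakeArray.new([[1], [2], [3]]), FakeArray.new([[4], [5], [6]])]

puts count_samples(single) # prints 3
puts count_samples(multi)  # prints 3
```

Only the first input is inspected. The constructor does not check that the other inputs or y_datas have the same number of samples, so mismatched shapes only surface later when batches are sliced.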

Instance Attribute Details

#last_round_down ⇒ Object (readonly)

Returns the value of attribute last_round_down, i.e. whether #foreach drops an incomplete final batch.



# File 'lib/dnn/core/iterator.rb', line 5

def last_round_down
  @last_round_down
end

#num_datas ⇒ Object (readonly)

Returns the number of samples, read from the first axis of x_datas (or of its first element when x_datas is an Array).



# File 'lib/dnn/core/iterator.rb', line 4

def num_datas
  @num_datas
end

Instance Method Details

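#foreach(batch_size, &block) ⇒ Object

Iterates over the data in batches of batch_size. For each step the block receives (x_batch, y_batch, step), where step is the zero-based step index. When the loop ends, #reset is called so the iterator can be reused.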



# File 'lib/dnn/core/iterator.rb', line 55

def foreach(batch_size, &block)
  steps = @last_round_down ? @num_datas / batch_size : (@num_datas.to_f / batch_size).ceil
  steps.times do |step|
    x_batch, y_batch = next_batch(batch_size)
    block.call(x_batch, y_batch, step)
  end
  reset
end
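The number of block calls depends on last_round_down. The sketch below uses the same expression as #foreach, wrapped in a hypothetical helper named `foreach_steps`, and works through 10 samples with a batch size of 3.

```ruby
# Number of times #foreach calls its block: integer division when the
# incomplete final batch is dropped, otherwise the ceiling of the quotient.
def foreach_steps(num_datas, batch_size, last_round_down)
  last_round_down ? num_datas / batch_size : (num_datas.to_f / batch_size).ceil
end

puts foreach_steps(10, 3, false) # prints 4 (batches of 3, 3, 3, 1)
puts foreach_steps(10, 3, true)  # prints 3 (the 1-sample remainder is skipped)
```

With random: true, the samples skipped under last_round_down: true change from epoch to epoch, because #reset reshuffles the indexes after each loop.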

#has_next? ⇒ Boolean

Returns true if another batch can be fetched with #next_batch.

Returns:

  • (Boolean)


# File 'lib/dnn/core/iterator.rb', line 51

def has_next?
  @has_next
end

#next_batch(batch_size) ⇒ Object

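Returns the next batch as an [x_batch, y_batch] pair. When x_datas or y_datas is an Array, the corresponding element of the pair is an Array of batches as well.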

Parameters:

  • batch_size (Integer)

    Number of samples per batch. The final batch may hold fewer samples than this.

Raises:

  • (DNN_Error)

    If no batch remains; call #reset to start over.



# File 'lib/dnn/core/iterator.rb', line 22

def next_batch(batch_size)
  raise DNN_Error, "This iterator has not next batch. Please call reset." unless has_next?
  if @indexes.length <= batch_size
    batch_indexes = @indexes
    @has_next = false
  else
    batch_indexes = @indexes.shift(batch_size)
  end
  x_batch = if @x_datas.is_a?(Array)
              @x_datas.map { |datas| datas[batch_indexes, false] }
            else
              @x_datas[batch_indexes, false]
            end
  y_batch = if @y_datas.is_a?(Array)
              @y_datas.map { |datas| datas[batch_indexes, false] }
            else
              @y_datas[batch_indexes, false]
            end
  [x_batch, y_batch]
end
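The has_next?/next_batch pair supports a manual loop instead of #foreach. The sketch below is a hypothetical plain-Ruby mirror named `TinyIterator`, not the library class: plain Arrays replace Numo arrays, rows are picked with `values_at` instead of `datas[batch_indexes, false]`, multi-input Arrays are not modelled, and it raises a plain RuntimeError where the original raises DNN_Error.

```ruby
# Hypothetical plain-Ruby mirror of DNN::Iterator's index bookkeeping.
class TinyIterator
  def initialize(x, y, random: false)
    @x, @y, @random = x, y, random
    reset
  end

  # Rebuild the index list and optionally shuffle it, as #reset does.
  def reset
    @has_next = true
    @indexes = @x.length.times.to_a
    @indexes.shuffle! if @random
  end

  def has_next?
    @has_next
  end

  # Take up to batch_size indexes; the last call takes whatever remains.
  def next_batch(batch_size)
    raise "no next batch; call reset" unless has_next?
    if @indexes.length <= batch_size
      batch_indexes = @indexes
      @has_next = false
    else
      batch_indexes = @indexes.shift(batch_size)
    end
    [@x.values_at(*batch_indexes), @y.values_at(*batch_indexes)]
  end
end

it = TinyIterator.new(%w[a b c d e], [0, 1, 2, 3, 4])
batches = []
batches << it.next_batch(2) while it.has_next?
p batches
# prints [[["a", "b"], [0, 1]], [["c", "d"], [2, 3]], [["e"], [4]]]
```

The final batch holds the single remaining sample, after which has_next? is false and any further next_batch call raises until reset is called.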

#reset ⇒ Object

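Resets the iterator so iteration starts again from the first batch. The index list is rebuilt and, when random is true, shuffled again.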



# File 'lib/dnn/core/iterator.rb', line 44

def reset
  @has_next = true
  @indexes = @num_datas.times.to_a
  @indexes.shuffle! if @random
end
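Because #reset rebuilds the full index list, each pass visits every sample index exactly once, before any truncation by last_round_down in #foreach. A plain-Ruby sketch of that step, with `reset_indexes` as a hypothetical helper name:

```ruby
# Mirrors #reset: rebuild the full index list and, when random is true,
# shuffle it in place. The result is always a permutation of 0...num_datas.
def reset_indexes(num_datas, random)
  indexes = num_datas.times.to_a
  indexes.shuffle! if random
  indexes
end

p reset_indexes(5, false)          # prints [0, 1, 2, 3, 4]
shuffled = reset_indexes(5, true)
p shuffled.sort == [0, 1, 2, 3, 4] # prints true
```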