Class: DNN::Iterator

Inherits:
Object
Defined in:
lib/dnn/core/iterator.rb

Overview

This class manages input data and the corresponding output data together and serves them as paired mini-batches.


Constructor Details

#initialize(x_datas, y_datas, random: true) ⇒ Iterator

Returns a new instance of Iterator.

Parameters:

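  • x_datas (Numo::SFloat, Array<Numo::SFloat>)

    Input data, with samples along the first axis. For a model with multiple inputs, pass an Array of arrays; the sample count is taken from the first one.

  • y_datas (Numo::SFloat, Array<Numo::SFloat>)

    Output (target) data, with samples along the first axis, or an Array of arrays for multiple outputs. Rows are selected with the same indexes as x_datas, so the sample order must match.

  • random (Boolean) (defaults to: true)

    If true, the sample order is shuffled every time the iterator is reset (including on construction), so batches are drawn in random order. If false, batches are returned in index order.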



# File 'lib/dnn/core/iterator.rb', line 7

def initialize(x_datas, y_datas, random: true)
  @x_datas = x_datas
  @y_datas = y_datas
  @random = random
  @num_datas = x_datas.is_a?(Array) ? x_datas[0].shape[0] : x_datas.shape[0]
  reset
end
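The constructor takes the sample count from the leading dimension of x_datas, or from the leading dimension of its first element when several inputs are passed as an Array. The sketch below replays that expression in plain Ruby. `FakeArray` and `count_samples` are hypothetical stand-ins (only `#shape` is modeled), not part of ruby-dnn.

```ruby
# Hypothetical stand-in for a Numo array: only #shape is modeled.
FakeArray = Struct.new(:shape)

# Mirrors the expression in Iterator#initialize: with multiple inputs
# (an Array of arrays), the count comes from the first input's first axis.
def count_samples(x_datas)
  x_datas.is_a?(Array) ? x_datas[0].shape[0] : x_datas.shape[0]
end

single = FakeArray.new([100, 784])                            # one input
multi  = [FakeArray.new([100, 784]), FakeArray.new([100, 10])] # two inputs

puts count_samples(single) # prints 100
puts count_samples(multi)  # prints 100
```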

Instance Method Details

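#foreach(batch_size, &block) ⇒ Object

Calls the block with (x_batch, y_batch, step) for each remaining batch of at most batch_size samples, where step is a 0-based counter, then resets the iterator.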



# File 'lib/dnn/core/iterator.rb', line 50

def foreach(batch_size, &block)
  step = 0
  while has_next?
    x_batch, y_batch = next_batch(batch_size)
    block.call(x_batch, y_batch, step)
    step += 1
  end
  reset
end
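To illustrate the foreach control flow without Numo, the following is a minimal pure-Ruby model of the class. `MiniIterator` is hypothetical: it stores samples as plain Arrays (element i is sample i), selects rows with `Array#values_at` instead of Numo's `[indexes, false]` indexing, and raises a plain RuntimeError instead of DNN_Error.

```ruby
# Hypothetical pure-Ruby model of DNN::Iterator's batching logic.
# Plain Arrays of samples replace Numo::SFloat; a sketch, not ruby-dnn itself.
class MiniIterator
  def initialize(x_datas, y_datas, random: true)
    @x_datas = x_datas
    @y_datas = y_datas
    @random = random
    @num_datas = x_datas.length
    reset
  end

  # Rebuild the index list, shuffling it when random is true.
  def reset
    @has_next = true
    @indexes = @num_datas.times.to_a
    @indexes.shuffle! if @random
  end

  def has_next?
    @has_next
  end

  # Take up to batch_size indexes; the last batch gets whatever remains.
  def next_batch(batch_size)
    raise "no next batch; call reset" unless has_next?
    if @indexes.length <= batch_size
      batch_indexes = @indexes
      @has_next = false
    else
      batch_indexes = @indexes.shift(batch_size)
    end
    [@x_datas.values_at(*batch_indexes), @y_datas.values_at(*batch_indexes)]
  end

  # Same control flow as Iterator#foreach: yield every remaining batch
  # with a 0-based step counter, then reset for the next epoch.
  def foreach(batch_size)
    step = 0
    while has_next?
      x_batch, y_batch = next_batch(batch_size)
      yield x_batch, y_batch, step
      step += 1
    end
    reset
  end
end

x = (0...10).to_a         # 10 samples
y = x.map { |v| v * 2 }   # paired targets
iter = MiniIterator.new(x, y, random: false)

sizes = []
iter.foreach(4) do |xb, yb, step|
  sizes << xb.length
  puts "step #{step}: x=#{xb.inspect} y=#{yb.inspect}"
end
p sizes # prints [4, 4, 2]
```

Because foreach calls reset when it finishes, the same iterator can be reused for the next epoch without further setup.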

#has_next? ⇒ Boolean

Returns true if at least one more batch is available.

Returns:

  • (Boolean)


# File 'lib/dnn/core/iterator.rb', line 46

def has_next?
  @has_next
end

#next_batch(batch_size) ⇒ Object

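Returns the next batch as a two-element array [x_batch, y_batch]. When x_datas or y_datas is an Array, the corresponding element is an Array of batches selected with the same indexes.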

Parameters:

  • batch_size (Integer)

    Maximum number of samples in the batch. The final batch contains all remaining samples, so it may be smaller.

Raises:

  • (DNN_Error)

    If no batch remains. Call #reset before requesting more batches.



# File 'lib/dnn/core/iterator.rb', line 17

def next_batch(batch_size)
  raise DNN_Error.new("This iterator has not next batch. Please call reset.") unless has_next?
  if @indexes.length <= batch_size
    batch_indexes = @indexes
    @has_next = false
  else
    batch_indexes = @indexes.shift(batch_size)
  end
  x_batch = if @x_datas.is_a?(Array)
    @x_datas.map { |datas| datas[batch_indexes, false] }
  else
    @x_datas[batch_indexes, false]
  end
  y_batch = if @y_datas.is_a?(Array)
    @y_datas.map { |datas| datas[batch_indexes, false] }
  else
    @y_datas[batch_indexes, false]
  end
  [x_batch, y_batch]
end
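The index bookkeeping in next_batch determines the batch sizes: full batches are shifted off the index list until at most batch_size indexes remain, and those form the final batch. The hypothetical helper `batch_index_plan` below replays that logic for random: false and returns the index groups in order, including the case where the sample count is an exact multiple of batch_size.

```ruby
# Hypothetical helper replaying the index bookkeeping of
# Iterator#next_batch (random: false) and collecting each batch's indexes.
def batch_index_plan(num_datas, batch_size)
  indexes = num_datas.times.to_a
  plan = []
  loop do
    if indexes.length <= batch_size
      plan << indexes # final batch: everything that is left
      break
    end
    plan << indexes.shift(batch_size)
  end
  plan
end

p batch_index_plan(10, 4) # prints [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
p batch_index_plan(8, 4)  # prints [[0, 1, 2, 3], [4, 5, 6, 7]]
p batch_index_plan(3, 4)  # prints [[0, 1, 2]]
```

When the remaining count equals batch_size exactly, that batch is the last one, so has_next? becomes false without an extra empty batch.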

#reset ⇒ Object

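Restarts iteration from the beginning: marks a next batch as available, rebuilds the list of sample indexes, and reshuffles it if random is true. The input and output data themselves are not modified.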



# File 'lib/dnn/core/iterator.rb', line 39

def reset
  @has_next = true
  @indexes = @num_datas.times.to_a
  @indexes.shuffle! if @random
end
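reset rebuilds the index list 0...num_datas and, when random is true, shuffles it in place. Since Array#shuffle! without a random: argument draws from Ruby's default generator, seeding it with Kernel#srand should make the order repeatable, assuming nothing else consumes random numbers in between. `reset_indexes` is a hypothetical helper for illustration.

```ruby
# Hypothetical helper reproducing the index list built by Iterator#reset.
# shuffle! uses Ruby's default generator, which Kernel#srand seeds.
def reset_indexes(num_datas, random)
  indexes = num_datas.times.to_a
  indexes.shuffle! if random
  indexes
end

p reset_indexes(5, false) # prints [0, 1, 2, 3, 4]

srand(42)
first = reset_indexes(5, true)
srand(42)
second = reset_indexes(5, true)
p first == second              # prints true
p first.sort == (0...5).to_a   # prints true
```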