Class: DNN::Iterator

Inherits:
Object
Defined in:
lib/dnn/core/iterator.rb

Overview

This class holds input data and output data together and serves them in matching mini-batches.

Constructor Details

#initialize(x_datas, y_datas, random: true, last_round_down: false) ⇒ Iterator

Returns a new instance of Iterator.

Parameters:

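  • x_datas (Numo::NArray | Array)

    Input data. When an Array of NArrays is passed (multiple inputs), the number of samples is taken from the first element.

  • y_datas (Numo::NArray | Array)

    Output (target) data.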

  • random (Boolean) (defaults to: true)

    If true, the indexes are shuffled on every reset, so batches are drawn in random order. If false, batches are returned in index order.

  • last_round_down (Boolean) (defaults to: false)

    If true, foreach skips the final incomplete batch, so every batch it yields has exactly batch_size samples.

# File 'lib/dnn/core/iterator.rb', line 11

def initialize(x_datas, y_datas, random: true, last_round_down: false)
  Utils.check_input_data_type("x_datas", x_datas, Xumo::NArray)
  Utils.check_input_data_type("y_datas", y_datas, Xumo::NArray)
  @x_datas = x_datas
  @y_datas = y_datas
  @random = random
  @last_round_down = last_round_down
  @num_datas = x_datas.is_a?(Array) ? x_datas[0].shape[0] : x_datas.shape[0]
  reset
end
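The constructor reads the sample count from the first axis of x_datas, or from the first element when x_datas is an Array (the multi-input case). The sketch below isolates that rule in plain Ruby. ShapeOnly is a hypothetical stand-in for Numo::NArray that models only #shape, and count_samples is a hypothetical helper, not part of ruby-dnn.

```ruby
# Hypothetical stand-in for Numo::NArray: only #shape is modeled here.
ShapeOnly = Struct.new(:shape)

# Mirrors the @num_datas rule in Iterator#initialize: for a single input the
# sample count is the first dimension; for multiple inputs (an Array) it is
# read from the first input only.
def count_samples(x_datas)
  x_datas.is_a?(Array) ? x_datas[0].shape[0] : x_datas.shape[0]
end

single = ShapeOnly.new([100, 784])
multi  = [ShapeOnly.new([100, 784]), ShapeOnly.new([100, 10])]

puts count_samples(single) # => 100
puts count_samples(multi)  # => 100
```

Only the first input is inspected, so all inputs are presumably expected to share the same first dimension; the constructor lines shown above do not compare them.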

Instance Attribute Details

#last_round_down ⇒ Object (readonly)

Returns the value of attribute last_round_down.

# File 'lib/dnn/core/iterator.rb', line 5

def last_round_down
  @last_round_down
end

#num_datas ⇒ Object (readonly)

Returns the value of attribute num_datas.

# File 'lib/dnn/core/iterator.rb', line 4

def num_datas
  @num_datas
end

Instance Method Details

#foreach(batch_size) { ... } ⇒ Object

Iterates over all data in mini-batches, calling the block once per batch.

Parameters:

  • batch_size (Integer)

    Batch size.

Yields:

  • Yields (x_batch, y_batch, step) for each batch, where step is the zero-based batch index. After the last batch the iterator is reset.

# File 'lib/dnn/core/iterator.rb', line 68

def foreach(batch_size, &block)
  max_steps(batch_size).times do |step|
    x_batch, y_batch = next_batch(batch_size)
    block.call(x_batch, y_batch, step)
  end
  reset
end
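To make the yielded values concrete, here is a minimal plain-Ruby model of foreach. It is a hypothetical helper (each_batch) working on ordinary arrays, it walks indexes in order as with random: false, and it omits the final reset; it is a sketch of the behavior, not the library code.

```ruby
# Hypothetical plain-Ruby model of Iterator#foreach. It walks indexes in
# order (like random: false) and yields (x_batch, y_batch, step), the same
# three values the real method passes to its block.
def each_batch(xs, ys, batch_size, last_round_down: false)
  n = xs.length
  # Same rule as Iterator#max_steps.
  steps = last_round_down ? n / batch_size : (n.to_f / batch_size).ceil
  indexes = (0...n).to_a
  steps.times do |step|
    batch_indexes = indexes.shift(batch_size)
    yield xs.values_at(*batch_indexes), ys.values_at(*batch_indexes), step
  end
end

xs = (0...10).to_a
ys = xs.map { |v| v * 2 }

each_batch(xs, ys, 4) do |x_batch, y_batch, step|
  puts "step #{step}: x=#{x_batch.inspect} y=#{y_batch.inspect}"
end
# step 0: x=[0, 1, 2, 3] y=[0, 2, 4, 6]
# step 1: x=[4, 5, 6, 7] y=[8, 10, 12, 14]
# step 2: x=[8, 9] y=[16, 18]
```

With last_round_down: true the third, two-sample step is skipped. A block that declares only |x_batch, y_batch| also works, because a non-lambda block silently ignores extra arguments.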

#has_next? ⇒ Boolean

Returns true if another batch is available.

Returns:

  • (Boolean)


61
62
63
# File 'lib/dnn/core/iterator.rb', line 61

def has_next?
  @has_next
end

#max_steps(batch_size) ⇒ Object

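Returns the number of batches in one pass over the data: num_datas / batch_size rounded down when last_round_down is true, otherwise rounded up.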

Parameters:

  • batch_size (Integer)

    Batch size.

# File 'lib/dnn/core/iterator.rb', line 87

def max_steps(batch_size)
  @last_round_down ? @num_datas / batch_size : (@num_datas.to_f / batch_size).ceil
end
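As a worked example of the rounding rule, the sketch below copies the expression into a standalone function (max_steps with explicit arguments is a hypothetical helper) and evaluates it for 1000 and 1024 samples with a batch size of 32.

```ruby
# Hypothetical standalone copy of the Iterator#max_steps rule.
def max_steps(num_datas, batch_size, last_round_down)
  if last_round_down
    num_datas / batch_size             # Integer division rounds down
  else
    (num_datas.to_f / batch_size).ceil # Float#ceil returns an Integer
  end
end

puts max_steps(1000, 32, true)  # => 31
puts max_steps(1000, 32, false) # => 32
puts max_steps(1024, 32, true)  # => 32
puts max_steps(1024, 32, false) # => 32
```

When num_datas is a multiple of batch_size the two settings agree. An all-integer alternative to the Float round trip is (num_datas + batch_size - 1) / batch_size.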

#next_batch(batch_size) ⇒ Array

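Returns the next mini-batch. If no more than batch_size samples remain, all of them are returned and has_next? becomes false.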

Parameters:

  • batch_size (Integer)

    Requested batch size. The last batch of a pass may be smaller.

Returns:

  • (Array)

    Returns the mini batch in the form [x_batch, y_batch].

Raises:

  • (DNNError)

    If no batch remains. Call reset to start a new pass.

# File 'lib/dnn/core/iterator.rb', line 25

def next_batch(batch_size)
  raise DNNError, "This iterator has not next batch. Please call reset." unless has_next?
  if @indexes.length <= batch_size
    batch_indexes = @indexes
    @has_next = false
  else
    batch_indexes = @indexes.shift(batch_size)
  end
  get_batch(batch_indexes)
end
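The has_next? / next_batch / reset cycle can also be driven by hand instead of through foreach. The following is a minimal sketch under stated assumptions: MiniIterator is a hypothetical plain-Ruby model that uses arrays instead of Numo::NArray and does not shuffle, and IteratorExhausted is a hypothetical stand-in for DNNError.

```ruby
# Hypothetical stand-in for DNN's DNNError.
class IteratorExhausted < StandardError; end

# Hypothetical plain-Ruby model of the next_batch / has_next? / reset cycle.
# Samples are Ruby arrays rather than Numo::NArray, and indexes are not shuffled.
class MiniIterator
  def initialize(xs, ys)
    @xs = xs
    @ys = ys
    reset
  end

  def has_next?
    @has_next
  end

  def reset
    @has_next = true
    @indexes = (0...@xs.length).to_a
  end

  def next_batch(batch_size)
    raise IteratorExhausted, "no next batch; call reset" unless has_next?
    if @indexes.length <= batch_size
      # Last batch: take everything that is left and mark the iterator done.
      batch_indexes = @indexes
      @has_next = false
    else
      batch_indexes = @indexes.shift(batch_size)
    end
    [@xs.values_at(*batch_indexes), @ys.values_at(*batch_indexes)]
  end
end

iter = MiniIterator.new((1..5).to_a, %w[a b c d e])
while iter.has_next?
  x_batch, y_batch = iter.next_batch(2)
  p [x_batch, y_batch]
end
# [[1, 2], ["a", "b"]]
# [[3, 4], ["c", "d"]]
# [[5], ["e"]]
```

Calling next_batch once more after the loop raises until reset is called, which matches the guard at the top of the real method.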

#num_usable_datas(batch_size) ⇒ Object

Returns the number of samples actually consumed in one pass: max_steps(batch_size) * batch_size when last_round_down is true, otherwise num_datas.

# File 'lib/dnn/core/iterator.rb', line 77

def num_usable_datas(batch_size)
  if @last_round_down
    max_steps(batch_size) * batch_size
  else
    @num_datas
  end
end
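For example, with 1000 samples and a batch size of 32, rounding down keeps 31 full batches and leaves 8 samples unused in each pass. The sketch below (num_usable_datas with explicit arguments is a hypothetical standalone helper) reproduces that arithmetic.

```ruby
# Hypothetical standalone copy of Iterator#num_usable_datas.
def num_usable_datas(num_datas, batch_size, last_round_down)
  if last_round_down
    (num_datas / batch_size) * batch_size # whole batches only
  else
    num_datas
  end
end

puts num_usable_datas(1000, 32, true)  # => 992
puts num_usable_datas(1000, 32, false) # => 1000
```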

#reset ⇒ Object

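Rewinds the iterator for a new pass: sets has_next? back to true, rebuilds the index list 0...num_datas, and shuffles it when random is true. The data itself is not modified.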

# File 'lib/dnn/core/iterator.rb', line 54

def reset
  @has_next = true
  @indexes = @num_datas.times.to_a
  @indexes.shuffle! if @random
end
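The index rebuild can be illustrated in isolation. In this sketch, reset_indexes is a hypothetical helper, and its rng parameter is an addition for reproducible tests; the real method calls shuffle! with Ruby's default random generator.

```ruby
# Hypothetical standalone copy of the index rebuild done by Iterator#reset.
# `rng` is an added parameter for reproducibility; the real method calls
# shuffle! with Ruby's default random generator.
def reset_indexes(num_datas, random, rng = Random.new)
  indexes = num_datas.times.to_a
  indexes.shuffle!(random: rng) if random
  indexes
end

p reset_indexes(5, false) # => [0, 1, 2, 3, 4]
shuffled = reset_indexes(5, true, Random.new(42))
p shuffled.sort           # => [0, 1, 2, 3, 4]
```

Because foreach calls reset after its last batch, each pass with random: true visits the samples in a fresh order while still covering every index exactly once.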