Class: CooCoo::Trainer::Batch

Inherits:
Base show all
Defined in:
lib/coo-coo/trainer/batch.rb

Overview

Trains a network by only adjusting the network once per batch. This opens up parallelism during learning, as more examples can be run at one time.

Constant Summary collapse

DEFAULT_OPTIONS =
Base::DEFAULT_OPTIONS.merge(processes: Parallel.processor_count)

Instance Method Summary collapse

Methods inherited from Base

#name

Instance Method Details

#optionsObject



13
14
15
16
17
18
19
# File 'lib/coo-coo/trainer/batch.rb', line 13

# Extends the base trainer's option parser with the batch trainer's
# own command line flag for selecting the worker count.
# @return the parser produced by Base#options, with --processes added
def options
  super(DEFAULT_OPTIONS) do |parser, opts|
    parser.on('--processes INTEGER', Integer, 'Number of threads or processes to use for the batch.') do |count|
      opts.processes = count
    end
  end
end

#train(options, &block) ⇒ Object

Parameters:

  • options (Hash)

    a customizable set of options

Options Hash (options):

  • :processes (Integer)

    How many threads or processes to use for the batch. Defaults to the processor count, Parallel.processor_count.



22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# File 'lib/coo-coo/trainer/batch.rb', line 22

# Runs one training pass over the supplied data, adjusting the
# network's weights once per batch of examples. Examples within a
# batch are processed in parallel; only the resulting weight deltas
# and costs are gathered back for the single weight update.
#
# @param options [Hash] training configuration
# @option options :network the network being trained (required)
# @option options :data enumerable of [ expecting, input ] pairs (required)
# @option options [Float] :learning_rate scale applied to the deltas, default 0.3
# @option options [Integer] :batch_size examples per weight update, default 1024
# @option options :cost_function defaults to CostFunctions::MeanSquare
# @option options [Integer] :processes worker count, defaults to Parallel.processor_count
# @yield [BatchStats] per-batch statistics, when a block is given
def train(options, &block)
  opts = options.to_h
  network = opts.fetch(:network)
  training_data = opts.fetch(:data)
  learning_rate = opts.fetch(:learning_rate, 0.3)
  batch_size = opts.fetch(:batch_size, 1024)
  cost_function = opts.fetch(:cost_function, CostFunctions::MeanSquare)
  processes = opts.fetch(:processes, Parallel.processor_count)

  batch_started = Time.now

  training_data.each_slice(batch_size).with_index do |batch, batch_index|
    # Forward/backward propagate each example independently; collect
    # [ weight deltas, cost ] pairs from the workers.
    results = in_parallel(processes, batch) do |(expecting, input)|
      output, hidden_state = network.forward(input, Hash.new)
      target = network.prep_output_target(expecting)
      final_output = network.final_output(output)
      errors = cost_function.derivative(target, final_output)
      new_deltas, hidden_state = network.backprop(input, output, errors, hidden_state)
      scaled_deltas = network.weight_deltas(input, output, new_deltas * learning_rate)

      [ scaled_deltas, cost_function.call(target, final_output) ]
    end

    deltas, costs = results.transpose
    # The whole batch contributes a single accumulated adjustment.
    network.adjust_weights!(accumulate_deltas(deltas))

    block&.call(BatchStats.new(self, batch_index, batch_size, Time.now - batch_started, CooCoo::Sequence[costs].sum))

    batch_started = Time.now
  end
end