22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
|
# File 'lib/coo-coo/trainer/batch.rb', line 22
# Train a network with mini-batch updates.
#
# The training data is cut into consecutive slices of +batch_size+ examples.
# For each slice, every (expecting, input) pair is run forward, the cost
# derivative is backpropagated, and the learning-rate-scaled weight deltas
# are computed; this per-example work is handed to #in_parallel together with
# +processes+. The per-example deltas are combined with #accumulate_deltas
# and the network's weights are adjusted once per slice.
#
# @param options [Hash, #to_h] training options:
#   - :network       (required) the network to train
#   - :data          (required) enumerable of [expecting, input] pairs
#   - :learning_rate factor multiplied into the backpropagated deltas
#                    (default 0.3)
#   - :batch_size    examples per weight update (default 1024)
#   - :cost_function object responding to #call and #derivative
#                    (default CostFunctions::MeanSquare)
#   - :processes     value passed to #in_parallel, presumably the worker
#                    count (default Parallel.processor_count)
# @yieldparam stats [BatchStats] built after each weight update from this
#   trainer, the batch index, the number of examples actually in the batch,
#   the elapsed seconds for the batch, and the summed cost over the batch.
# @raise [KeyError] if :network or :data is missing from the options.
def train(options, &block)
  options = options.to_h
  network = options.fetch(:network)
  training_data = options.fetch(:data)
  learning_rate = options.fetch(:learning_rate, 0.3)
  batch_size = options.fetch(:batch_size, 1024)
  cost_function = options.fetch(:cost_function, CostFunctions::MeanSquare)
  processes = options.fetch(:processes, Parallel.processor_count)

  # Use the monotonic clock for durations: wall-clock time (Time.now) can
  # jump because of NTP adjustments or manual changes, which would make the
  # reported batch times wrong or even negative.
  started_at = Process.clock_gettime(Process::CLOCK_MONOTONIC)

  training_data.each_slice(batch_size).with_index do |batch, i|
    deltas_errors = in_parallel(processes, batch) do |(expecting, input)|
      output, hidden_state = network.forward(input, {})
      target = network.prep_output_target(expecting)
      final_output = network.final_output(output)
      errors = cost_function.derivative(target, final_output)
      # The hidden state returned by backprop is not used afterwards.
      new_deltas, _hidden_state = network.backprop(input, output, errors, hidden_state)
      new_deltas = network.weight_deltas(input, output, new_deltas * learning_rate)
      [new_deltas, cost_function.call(target, final_output)]
    end

    deltas, total_errors = deltas_errors.transpose
    network.adjust_weights!(accumulate_deltas(deltas))

    if block
      elapsed = Process.clock_gettime(Process::CLOCK_MONOTONIC) - started_at
      # Report the real example count: the last slice produced by each_slice
      # may hold fewer than batch_size examples. NOTE(review): presumably
      # BatchStats derives per-example averages from this value — confirm.
      block.call(BatchStats.new(self, i, batch.size, elapsed,
                                CooCoo::Sequence[total_errors].sum))
    end

    # Restart the timer after the callback so time spent in the caller's
    # block is not charged to the next batch.
    started_at = Process.clock_gettime(Process::CLOCK_MONOTONIC)
  end
end
|