Class: CooCoo::Trainer::MomentumStochastic
- Defined in:
- lib/coo-coo/trainer/momentum_stochastic.rb
Constant Summary
- DEFAULT_OPTIONS =
Base::DEFAULT_OPTIONS.merge(momentum: 1/30.0)
Instance Method Summary
- #learn(network, input, expecting, rate, last_deltas, momentum, cost_function, hidden_state) ⇒ Object
- #options ⇒ Object
- #train(options, &block) ⇒ Object
Methods inherited from Base
Instance Method Details
#learn(network, input, expecting, rate, last_deltas, momentum, cost_function, hidden_state) ⇒ Object
46 47 48 49 50 51 52 53 54 55 56 |
# File 'lib/coo-coo/trainer/momentum_stochastic.rb', line 46
#
# Runs one training step on a single example: a forward pass, a backward
# pass, and a weight update that combines the new rate-scaled deltas with the
# previous step's deltas scaled by +momentum+.
#
# @param network the network being trained; must respond to +forward+,
#   +prep_output_target+, +final_output+, +backprop+ and +update_weights!+
# @param input the example fed to +network.forward+
# @param expecting the raw target, converted by +network.prep_output_target+
# @param rate multiplier applied to the backpropagated deltas
# @param last_deltas the rate-scaled deltas returned by the previous call
#   (the caller seeds this with 0.0)
# @param momentum multiplier applied to +last_deltas+
# @param cost_function object responding to +derivative(target, output)+ and
#   +call(target, output)+
# @param hidden_state state threaded through +forward+ and +backprop+
# @return [Array] the cost for this example, the hidden state returned by
#   +backprop+, and the rate-scaled deltas to pass as +last_deltas+ next time
def learn(network, input, expecting, rate, last_deltas, momentum, cost_function, hidden_state)
  output, hidden_state = network.forward(input, hidden_state)
  target = network.prep_output_target(expecting)
  final_output = network.final_output(output)
  errors = cost_function.derivative(target, final_output)
  deltas, hidden_state = network.backprop(input, output, errors, hidden_state)
  deltas = CooCoo::Sequence[deltas] * rate
  # The weights are updated before the cost is computed, so the reported cost
  # reflects the output produced by the pre-update forward pass.
  network.update_weights!(input, output, deltas - last_deltas * momentum)
  [cost_function.call(target, final_output), hidden_state, deltas]
end
#options ⇒ Object
11 12 13 14 15 16 17 |
# File 'lib/coo-coo/trainer/momentum_stochastic.rb', line 11
#
# Extends the inherited option definitions with a +--momentum+ flag.
#
# Delegates to the superclass implementation with this trainer's
# DEFAULT_OPTIONS and a block that registers the extra flag. The block
# receives the option parser and the options object being filled in; a parsed
# +--momentum+ value is coerced to Float and stored on that object.
#
# @return whatever the superclass implementation returns
def options
  super(DEFAULT_OPTIONS) do |o, options|
    o.on('--momentum FLOAT', Float, 'Multiplier for the accumulated changes.') do |n|
      options.momentum = n
    end
  end
end
#train(options, &block) ⇒ Object
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
# File 'lib/coo-coo/trainer/momentum_stochastic.rb', line 20
#
# Trains a network over the supplied data in batches, calling #learn once per
# example and reporting per-batch statistics to the optional block.
#
# Recognised keys of +options+ (after +to_h+):
# * +:network+ (required) - the network to train
# * +:data+ (required) - enumerable of +[expecting, input]+ pairs
# * +:learning_rate+ - defaults to 1/3.0
# * +:batch_size+ - defaults to 1024
# * +:cost_function+ - defaults to CostFunctions::MeanSquare
# * +:momentum+ - defaults to 1/30.0
#
# @param options an object convertible with +to_h+ holding the keys above
# @yield [BatchStats] once per batch, built from this trainer, the batch
#   index, the configured batch size, the elapsed time and the summed cost
# @raise [KeyError] if +:network+ or +:data+ is missing
def train(options, &block)
  options = options.to_h
  network = options.fetch(:network)
  training_data = options.fetch(:data)
  learning_rate = options.fetch(:learning_rate, 1/3.0)
  batch_size = options.fetch(:batch_size, 1024)
  cost_function = options.fetch(:cost_function, CostFunctions::MeanSquare)
  momentum = options.fetch(:momentum, 1/30.0)

  t = Time.now

  training_data.each_slice(batch_size).with_index do |batch, i|
    # The momentum term starts from zero at the beginning of every batch.
    last_delta = 0.0
    total_errs = batch.inject(nil) do |acc, (expecting, input)|
      # Each example gets a fresh hidden state; the one returned is discarded.
      errs, _hidden_state, last_delta = learn(network, input, expecting, learning_rate, last_delta, momentum, cost_function, {})
      # errs stays on the left so its own #+ is the one used for the sum.
      errs + (acc || 0)
    end

    block.call(BatchStats.new(self, i, batch_size, Time.now - t, total_errs)) if block

    t = Time.now
  end
end