Class: NeuralNetwork::Trainer

Inherits: Object
Defined in: lib/neural_network/trainer.rb

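Instance Attribute Summary

  #data ⇒ Object       Returns the value of attribute data.
  #network ⇒ Object    Returns the value of attribute network.

Instance Method Summary

  #initialize(network, data) ⇒ Trainer    Returns a new instance of Trainer.
  #train(options = {}) ⇒ Object           Trains the network on data for options[:epochs] epochs.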

Constructor Details

#initialize(network, data) ⇒ Trainer

Returns a new instance of Trainer.

# File 'lib/neural_network/trainer.rb', line 5

# Stores the model to be trained together with the samples used to train it.
#
# @param network [Object] the model; #train calls activate, train and error on it
# @param data [Object] the sample collection #train iterates over
def initialize(network, data)
  @network, @data = network, data
end

Instance Attribute Details

#data ⇒ Object

Returns the value of attribute data.

# File 'lib/neural_network/trainer.rb', line 3

# @return [Object] the sample collection handed to the constructor
def data; @data; end

#network ⇒ Object

Returns the value of attribute network.

# File 'lib/neural_network/trainer.rb', line 3

# @return [Object] the model handed to the constructor
def network; @network; end

Instance Method Details

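#train(options = {}) ⇒ Object

Trains the network on every sample in data for options[:epochs] epochs. For each
sample it activates the network on sample[:input] and trains it towards
sample[:output]. The average error is printed every options[:log_freqs] epochs and
after the final epoch.

# File 'lib/neural_network/trainer.rb', line 10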

# Trains +network+ on every sample in +data+ for a fixed number of epochs.
#
# For each sample the network is activated on +sample[:input]+, trained towards
# +sample[:output]+, and its +error+ is accumulated. The mean error over all
# samples is printed every +:log_freqs+ epochs and always after the final epoch.
#
# @param options [Hash]
# @option options [Integer] :epochs number of passes over +data+ (required)
# @option options [Integer, nil] :log_freqs print every N epochs; when nil or
#   not positive, only the final epoch is printed
# @return [Integer] the number of epochs run (the value returned by +Integer#times+)
# @raise [ArgumentError] if the +:epochs+ option is missing
def train(options = {})
  epochs    = options.fetch(:epochs) { raise ArgumentError, 'missing required option :epochs' }
  log_freqs = options[:log_freqs]
  samples   = data
  count     = samples.length

  epochs.times do |epoch|
    total_error = samples.sum do |sample|
      network.activate(sample[:input])
      network.train(sample[:output])
      network.error
    end
    # Divide once with fdiv: dividing each Integer error by count truncated it.
    average_error = count.zero? ? 0.0 : total_error.fdiv(count)

    # A nil/0/negative frequency would raise in `%`; treat it as "final epoch only".
    periodic = log_freqs&.positive? && (epoch % log_freqs).zero?
    if periodic || epoch + 1 == epochs
      puts "epoch: #{epoch}  error: #{average_error}"
    end
  end
end