Class: CooCoo::Trainer::Base Abstract

Inherits:
Object
Includes:
Singleton
Defined in:
lib/coo-coo/trainer/base.rb

Overview

This class is abstract.

Defines and documents the interface for the trainers.
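Because Base includes Ruby's standard Singleton module, concrete trainers are presumably obtained through `.instance` rather than `.new`. The sketch below uses hypothetical stand-in classes (`Demo::Trainer::Base`, `Demo::Trainer::Stochastic`), not CooCoo itself. It shows what Singleton implies for a subclass, and that the abstract #train raises NotImplementedError until a subclass overrides it.

```ruby
require 'singleton'

# Hypothetical stand-ins mirroring the documented shape of
# CooCoo::Trainer::Base; these are not the library's own classes.
module Demo
  module Trainer
    class Base
      include Singleton

      # Abstract: concrete trainers are expected to override this.
      def train(options, &block)
        raise NotImplementedError.new
      end
    end

    class Stochastic < Base; end
  end
end

# Singleton gives each subclass its own single, lazily created instance.
trainer = Demo::Trainer::Stochastic.instance
puts trainer.equal?(Demo::Trainer::Stochastic.instance) # prints true

begin
  Demo::Trainer::Stochastic.new # Singleton makes .new private
rescue NoMethodError
  puts 'new is private; use .instance'
end

begin
  trainer.train({})
rescue NotImplementedError
  puts 'subclasses must override #train'
end
```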

Direct Known Subclasses

Batch, MomentumStochastic, Stochastic

Constant Summary

DEFAULT_OPTIONS =
{
  cost: CostFunctions::MeanSquare,
  learning_rate: 1/3.0,
  batch_size: 1024
}
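A minimal sketch of how a caller might derive custom defaults from this constant, which could then be handed to #options(defaults). It assumes CooCoo is not loaded: `DEMO_DEFAULT_OPTIONS` is a hypothetical copy of DEFAULT_OPTIONS, with the symbol `:mean_square` standing in for CostFunctions::MeanSquare. It also shows why the learning rate is written `1/3.0`: in Ruby, `1/3` is integer division and evaluates to 0.

```ruby
require 'ostruct'

# Hypothetical copy of DEFAULT_OPTIONS; :mean_square stands in for
# CostFunctions::MeanSquare so the sketch runs without CooCoo loaded.
DEMO_DEFAULT_OPTIONS = {
  cost: :mean_square,
  learning_rate: 1/3.0, # Float division; 1/3 would be the Integer 0
  batch_size: 1024
}

# Hash#merge returns a new Hash, so the constant itself is left untouched.
custom_defaults = DEMO_DEFAULT_OPTIONS.merge(batch_size: 64)
options = OpenStruct.new(custom_defaults)

puts options.batch_size                # prints 64
puts options.learning_rate             # prints 0.3333333333333333
puts DEMO_DEFAULT_OPTIONS[:batch_size] # prints 1024
puts 1/3                               # prints 0
```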

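Instance Method Summary

  • #name ⇒ Object

    Returns a user-friendly name; by default, the last component of the class name.

  • #options(defaults = DEFAULT_OPTIONS) ⇒ [OptionParser, OpenStruct]

    Returns a command-line OptionParser to gather the trainer's options.

  • #train(options) {|BatchStats| ... } ⇒ Object

    Trains a network by iterating through a set of [ target, input ] pairs.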

Instance Method Details

#name ⇒ Object

Returns a user-friendly name for the trainer. By default this is the last component of the class name, e.g. "Base" for CooCoo::Trainer::Base.



# File 'lib/coo-coo/trainer/base.rb', line 12

def name
  self.class.name.split('::').last
end
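The default name is simply the final `::` segment of the fully qualified class name, and a subclass is free to override it. The sketch below uses hypothetical `NameDemo` classes, not CooCoo's, to show both cases; the friendlier label returned by the overriding class is invented for illustration.

```ruby
require 'singleton'

# Hypothetical classes mirroring the documented #name; not CooCoo's own.
module NameDemo
  module Trainer
    class Base
      include Singleton

      # Same body as the documented #name: keep the last '::' segment.
      def name
        self.class.name.split('::').last
      end
    end

    class MomentumStochastic < Base; end

    class Batch < Base
      # Hypothetical override returning a friendlier label.
      def name
        'Mini-batch'
      end
    end
  end
end

puts NameDemo::Trainer::MomentumStochastic.instance.name # prints MomentumStochastic
puts NameDemo::Trainer::Batch.instance.name              # prints Mini-batch
```

One consequence of the default body: an anonymous subclass such as `Class.new(Base)` has a nil class name, so calling #name on it would raise NoMethodError from `nil.split`.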

#options(defaults = DEFAULT_OPTIONS) ⇒ [OptionParser, OpenStruct]

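Returns a command-line OptionParser for gathering the trainer's options, paired with the OpenStruct that the parser's handlers update. The OpenStruct is initialized from defaults. If a block is given, it is yielded the parser and the OpenStruct so that subclasses can register additional switches.

Returns:

  • ([OptionParser, OpenStruct])

    The parser and the options object it populates.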



# File 'lib/coo-coo/trainer/base.rb', line 25

def options(defaults = DEFAULT_OPTIONS)
  options = OpenStruct.new(defaults)
  
  parser = OptionParser.new do |o|
    o.banner = "#{name} trainer options"

    o.accept(CostFunctions::Base) do |v|
      CostFunctions.from_name(v)
    end
    
    o.on('--cost NAME', '--cost-function NAME', "The function to minimize during training. Choices are: #{CostFunctions.named_classes.join(', ')}", CostFunctions::Base) do |v|
      options.cost_function = v
    end
    
    o.on('-r', '--rate FLOAT', '--learning-rate FLOAT', Float, 'Multiplier for the changes the network calculates.') do |n|
      options.learning_rate = n
    end
    
    o.on('-n', '--batch-size INTEGER', Integer, 'Number of examples to train against before yielding.') do |n|
      options.batch_size = n
    end
    
    yield(o, options) if block_given?
  end

  [ parser, options ]
end
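A sketch of how the returned pair might be used: parse an argument list, then read the results from the OpenStruct. `DemoTrainer` is a hypothetical class that copies the pattern above but omits the `--cost` switch, since that needs CooCoo's CostFunctions; the `--momentum` switch added through the block is likewise invented. Reading the source above, note that DEFAULT_OPTIONS seeds a `cost` key while the `--cost` handler assigns `cost_function` (the name #train documents), so with the base defaults `options.cost_function` stays nil unless `--cost` is passed.

```ruby
require 'optparse'
require 'ostruct'

# Hypothetical trainer following the documented #options pattern,
# minus the --cost switch (which needs CooCoo's CostFunctions).
class DemoTrainer
  DEFAULTS = { learning_rate: 1/3.0, batch_size: 1024 }

  def name
    self.class.name.split('::').last
  end

  def options(defaults = DEFAULTS)
    options = OpenStruct.new(defaults)

    parser = OptionParser.new do |o|
      o.banner = "#{name} trainer options"

      o.on('-r', '--rate FLOAT', '--learning-rate FLOAT', Float,
           'Multiplier for the changes the network calculates.') do |n|
        options.learning_rate = n
      end

      o.on('-n', '--batch-size INTEGER', Integer,
           'Number of examples to train against before yielding.') do |n|
        options.batch_size = n
      end

      # Lets callers (e.g. subclasses) register extra switches.
      yield(o, options) if block_given?
    end

    [parser, options]
  end
end

parser, opts = DemoTrainer.new.options do |o, options|
  # Hypothetical extra switch added through the block.
  o.on('--momentum FLOAT', Float, 'Momentum coefficient.') do |m|
    options.momentum = m
  end
end

# #parse leaves its argument untouched and returns the non-option words.
rest = parser.parse(%w[-r 0.05 --momentum 0.9 extra.dat])

puts opts.learning_rate # prints 0.05
puts opts.batch_size    # prints 1024 (default kept)
puts opts.momentum      # prints 0.9
p rest                  # prints ["extra.dat"]
```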

#train(options) {|BatchStats| ... } ⇒ Object

Trains a network by iterating through a set of [ target, input ] pairs. Subclasses must override this method.

Parameters:

  • options (Hash, OpenStruct)

    Options hash

Options Hash (options):

  • :network (Network, TemporalNetwork)

    The network to train.

  • :data (Array<Array<Vector, Vector>>, Enumerator<Vector, Vector>)

    An array of [ target, input ] pairs to be used for the training.

  • :learning_rate (Float)

    The multiplier of change in the network’s weights.

  • :batch_size (Integer)

    How many examples to pull from the training data in each batch.

  • :cost_function (CostFunctions::Base)

    The function used to calculate the loss and to determine how the network should change in response to incorrect outputs.
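The :data option may be an Array or an Enumerator of pairs, with the target first. A sketch of both forms holding the same four XNOR examples, assuming plain Ruby Arrays in place of CooCoo's Vector type (an assumption made only so the code runs standalone). Both forms support `each_slice`, which is one way a trainer could pull batch_size examples at a time; how CooCoo's trainers actually consume :data is not shown here.

```ruby
# Plain Arrays stand in for CooCoo's Vector type (assumption for this sketch).
pairs_array = [
  [[1.0], [0.0, 0.0]], # [target, input]
  [[0.0], [0.0, 1.0]],
  [[0.0], [1.0, 0.0]],
  [[1.0], [1.0, 1.0]]
]

# The same pairs generated lazily by an Enumerator.
pairs_enum = Enumerator.new do |y|
  [0, 1].product([0, 1]).each do |a, b|
    input  = [a.to_f, b.to_f]
    target = [a == b ? 1.0 : 0.0] # XNOR of the two inputs
    y << [target, input]
  end
end

batch_size = 3
pairs_enum.each_slice(batch_size) do |batch|
  puts "batch of #{batch.size}" # prints "batch of 3", then "batch of 1"
end

puts pairs_enum.to_a == pairs_array # prints true
```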

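Yields:

  • (BatchStats)

    Statistics for each batch, yielded after training on each batch of batch_size examples.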

Raises:

  • (NotImplementedError)

    Always raised by this base class; subclasses must override #train.


# File 'lib/coo-coo/trainer/base.rb', line 62

def train(options, &block)
  raise NotImplementedError.new
end
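Since Base#train only raises NotImplementedError, each subclass supplies its own loop. The sketch below is a toy, not CooCoo's Batch, Stochastic, or MomentumStochastic implementation: `ToyNetwork` (a single weight with prediction `weight * input`), `ToyBatchTrainer`, and `DemoBatchStats` are all hypothetical, and scalars replace Vectors. It accepts a Hash or OpenStruct, slices :data into batches of :batch_size, applies plain mini-batch gradient descent on squared error scaled by :learning_rate, and yields one stats object per batch.

```ruby
require 'ostruct'

# Hypothetical per-batch statistics; CooCoo's real BatchStats may differ.
DemoBatchStats = Struct.new(:batch_index, :examples, :average_loss)

# Toy "network" with one weight: prediction = weight * input.
class ToyNetwork
  attr_accessor :weight

  def initialize(weight = 0.0)
    @weight = weight
  end

  def predict(input)
    @weight * input
  end
end

# Toy trainer: mini-batch gradient descent on squared error.
class ToyBatchTrainer
  def train(options)
    options = OpenStruct.new(options) if options.is_a?(Hash)
    net = options.network

    options.data.each_slice(options.batch_size).with_index do |batch, index|
      loss = 0.0
      grad = 0.0
      batch.each do |target, input|
        error = net.predict(input) - target
        loss += error**2
        grad += 2 * error * input # derivative of error**2 w.r.t. weight
      end
      # Average the gradient over the batch and step against it.
      net.weight -= options.learning_rate * grad / batch.size
      yield DemoBatchStats.new(index, batch.size, loss / batch.size) if block_given?
    end

    net
  end
end

data = [[2.0, 1.0], [4.0, 2.0], [6.0, 3.0]] # [target, input]; y = 2x
trainer = ToyBatchTrainer.new

net = ToyNetwork.new
50.times do
  trainer.train(network: net, data: data, learning_rate: 0.05, batch_size: 3)
end
puts net.weight.round(3) # prints 2.0

stats = []
trainer.train(network: ToyNetwork.new, data: data,
              learning_rate: 0.05, batch_size: 2) { |s| stats << s }
stats.each { |s| puts "batch #{s.batch_index}: #{s.examples} examples" }
```

With batch_size 2 and three examples, the block runs twice: once for a batch of 2 and once for the remaining single example.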