Class: CooCoo::Trainer::Base (abstract)
Includes: Singleton
Defined in: lib/coo-coo/trainer/base.rb
Overview
This class is abstract. It defines and documents the interface that the trainers implement.
Constant Summary
- DEFAULT_OPTIONS =
{ cost: CostFunctions::MeanSquare, learning_rate: 1/3.0, batch_size: 1024 }
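One detail in the defaults is that the learning rate is written as `1/3.0`: Ruby's `Integer#/` truncates, so `1/3` would evaluate to `0`. The sketch below checks that and wraps the defaults in an `OpenStruct`, which is what `#options` starts from before parsing any flags. `MEAN_SQUARE` is a hypothetical placeholder for `CostFunctions::MeanSquare` so the snippet runs without the coo-coo gem.

```ruby
require 'ostruct'

# Hypothetical placeholder for CostFunctions::MeanSquare, so this runs
# without the coo-coo gem installed.
MEAN_SQUARE = :mean_square

DEFAULTS = { cost: MEAN_SQUARE, learning_rate: 1/3.0, batch_size: 1024 }

int_rate   = 1/3    # Integer division truncates toward zero
float_rate = 1/3.0  # Float division keeps the fraction

# #options builds the same kind of OpenStruct from its defaults argument.
opts = OpenStruct.new(DEFAULTS)
puts int_rate
puts opts.learning_rate
puts opts.batch_size
```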
Instance Method Summary
- #name ⇒ Object
  Returns a user-friendly name; by default, the class name without its module prefix.
- #options(defaults = DEFAULT_OPTIONS) ⇒ [OptionParser, OpenStruct]
  Returns a command-line OptionParser that gathers the trainer's options, together with the OpenStruct it fills in.
- #train(options) {|BatchStats| ... } ⇒ Object
  Trains a network by iterating through a set of (target, input) pairs.
Instance Method Details
#name ⇒ Object
Returns a user-friendly name; by default, the class name without its module prefix.
```ruby
# File 'lib/coo-coo/trainer/base.rb', line 12
def name
  self.class.name.split('::').last
end
```
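The default `#name` keeps only the last `::`-separated segment of the class name. The runnable sketch below mirrors the `CooCoo::Trainer` nesting with a hypothetical `Stochastic` subclass; it is not the gem's code, and it leaves out `Singleton` so the classes can be instantiated with `new`.

```ruby
# Hypothetical classes that only mirror the CooCoo::Trainer nesting,
# so the naming logic can run without the gem. Singleton is omitted.
module CooCoo
  module Trainer
    class Base
      # Same body as the documented #name: drop the module prefix.
      def name
        self.class.name.split('::').last
      end
    end

    # Hypothetical subclass name, used only for illustration.
    class Stochastic < Base; end
  end
end

puts CooCoo::Trainer::Stochastic.new.name
puts CooCoo::Trainer::Base.new.name
```

Because `Class#name` returns `nil` for an anonymous class, a subclass created with `Class.new(Base)` would raise `NoMethodError` from this default unless it overrides `#name`.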
#options(defaults = DEFAULT_OPTIONS) ⇒ [OptionParser, OpenStruct]
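Returns a command-line OptionParser that gathers the trainer's options, together with the OpenStruct it fills in. If a block is given, it is called with the parser and the OpenStruct so that subclasses can register additional switches.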
```ruby
# File 'lib/coo-coo/trainer/base.rb', line 25
def options(defaults = DEFAULT_OPTIONS)
  # Start from the defaults; each switch below overwrites one field.
  options = OpenStruct.new(defaults)

  parser = OptionParser.new do |o|
    o.banner = "#{name} trainer options"

    # Convert a --cost argument into a cost function class.
    o.accept(CostFunctions::Base) do |v|
      CostFunctions.from_name(v)
    end

    o.on('--cost NAME', '--cost-function NAME', "The function to minimize during training. Choices are: #{CostFunctions.named_classes.join(', ')}", CostFunctions::Base) do |v|
      options.cost_function = v
    end

    o.on('-r', '--rate FLOAT', '--learning-rate FLOAT', Float, 'Multiplier for the changes the network calculates.') do |n|
      options.learning_rate = n
    end

    o.on('-n', '--batch-size INTEGER', Integer, 'Number of examples to train against before yielding.') do |n|
      options.batch_size = n
    end

    # Let subclasses add their own switches to the same parser.
    yield(o, options) if block_given?
  end

  [ parser, options ]
end
```
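Here is a trimmed, self-contained usage sketch of the pattern `#options` follows: defaults in an `OpenStruct`, switches that overwrite its fields, and an optional block through which a subclass can register extra switches on the same parser. `trainer_options` and the `--momentum` switch are hypothetical, and the `--cost` switch is left out because it depends on CooCoo's `CostFunctions`.

```ruby
require 'optparse'
require 'ostruct'

DEFAULTS = { learning_rate: 1/3.0, batch_size: 1024 }

# Hypothetical, trimmed copy of the documented #options pattern.
def trainer_options(defaults = DEFAULTS)
  options = OpenStruct.new(defaults)

  parser = OptionParser.new do |o|
    o.banner = "Example trainer options"

    o.on('-r', '--rate FLOAT', '--learning-rate FLOAT', Float,
         'Multiplier for the changes the network calculates.') do |n|
      options.learning_rate = n
    end

    o.on('-n', '--batch-size INTEGER', Integer,
         'Number of examples to train against before yielding.') do |n|
      options.batch_size = n
    end

    # yield here refers to the block passed to trainer_options.
    yield(o, options) if block_given?
  end

  [parser, options]
end

parser, opts = trainer_options do |o, options|
  # Hypothetical extra switch a subclass might register.
  o.on('--momentum FLOAT', Float, 'Hypothetical momentum term.') do |m|
    options.momentum = m
  end
end

# parse returns the arguments it did not consume.
rest = parser.parse(%w[--rate 0.05 -n 32 --momentum 0.9 data.bin])
puts parser.help
```

Because `OptionParser#parse` returns the leftover arguments, positional values such as a data file name survive parsing. Returning the parser alongside the struct also lets a caller print `parser.help` for a usage message.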
#train(options) {|BatchStats| ... } ⇒ Object
Trains a network by iterating through a set of (target, input) pairs.
```ruby
# File 'lib/coo-coo/trainer/base.rb', line 62
def train(options, &block)
  # Abstract: every concrete trainer must override this.
  raise NotImplementedError.new
end
```
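Because `#train` only raises `NotImplementedError`, each concrete trainer supplies its own loop. The sketch below shows one hypothetical way a subclass could honor the documented signature: it reads hypothetical `training_data` and `batch_size` fields from an `OpenStruct`, walks the (target, input) pairs in batches, and yields one stats object per batch. `AbstractTrainer`, `CountingTrainer`, and the `BatchStats` struct are stand-ins rather than CooCoo classes, and the toy squared-error "cost" replaces a real network update.

```ruby
require 'ostruct'
require 'singleton'

# Hypothetical stand-in for the per-batch statistics the docs mention.
BatchStats = Struct.new(:batch_index, :examples, :average_cost)

# Hypothetical mirror of the abstract base: subclasses must override #train.
class AbstractTrainer
  include Singleton

  def train(options, &block)
    raise NotImplementedError.new
  end
end

# Hypothetical concrete trainer.
class CountingTrainer < AbstractTrainer
  # Walks (target, input) pairs in slices of options.batch_size and yields
  # one BatchStats per slice. The squared difference between target and
  # input is a toy cost standing in for a real network's error.
  def train(options, &block)
    options.training_data.each_slice(options.batch_size).with_index do |batch, i|
      cost = batch.sum { |target, input| (target - input)**2 } / batch.size.to_f
      block.call(BatchStats.new(i, batch.size, cost)) if block
    end
  end
end

data = [[1.0, 0.5], [0.0, 0.5], [2.0, 2.0], [1.0, 1.0], [3.0, 1.0]]
options = OpenStruct.new(training_data: data, batch_size: 2)

stats = []
CountingTrainer.instance.train(options) { |s| stats << s }
stats.each do |s|
  puts "batch #{s.batch_index}: #{s.examples} examples, cost #{s.average_cost}"
end
```

Since the base class includes `Singleton`, a trainer is reached through `.instance` rather than `.new`, and each subclass gets its own instance.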