Class: Desiru::Optimizers::Base
- Inherits: Object (hierarchy: Object → Desiru::Optimizers::Base)
- Defined in: lib/desiru/optimizers/base.rb
Overview
Base class for all optimizers.
Direct Known Subclasses
Instance Attribute Summary
- #config ⇒ Object (readonly)
  Returns the value of attribute config.
- #metric ⇒ Object (readonly)
  Returns the value of attribute metric.
Instance Method Summary
- #compile(program, trainset:, valset: nil) ⇒ Object
- #evaluate(program, dataset) ⇒ Object
- #initialize(metric: :exact_match, **config) ⇒ Base (constructor)
  Returns a new instance of Base.
- #optimize_module(module_instance, examples) ⇒ Object
Constructor Details
#initialize(metric: :exact_match, **config) ⇒ Base
Returns a new instance of Base.
Source (lines 9–13):
# File 'lib/desiru/optimizers/base.rb', line 9 def initialize(metric: :exact_match, **config) @metric = normalize_metric(metric) @config = default_config.merge(config) @optimization_trace = [] end |
Instance Attribute Details
#config ⇒ Object (readonly)
Returns the value of attribute config.
Source (lines 7–9):
# File 'lib/desiru/optimizers/base.rb', line 7 def config @config end |
#metric ⇒ Object (readonly)
Returns the value of attribute metric.
Source (lines 7–9):
# File 'lib/desiru/optimizers/base.rb', line 7 def metric @metric end |
Instance Method Details
#compile(program, trainset:, valset: nil) ⇒ Object
Source (lines 15–17):
# File 'lib/desiru/optimizers/base.rb', line 15

# Optimizes +program+ using the supplied training (and optional validation) data.
#
# The base class provides no implementation; concrete optimizers override this.
#
# @abstract Subclasses must override this method.
# @param program [Object] the program to optimize (unused here)
# @param trainset [Object] training examples (required keyword; unused here)
# @param valset [Object, nil] optional validation examples (unused here)
# @raise [NotImplementedError] always, unless overridden by a subclass
def compile(program, trainset:, valset: nil)
  message = 'Subclasses must implement #compile'
  raise(NotImplementedError, message)
end
#evaluate(program, dataset) ⇒ Object
Source (lines 23–48):
# File 'lib/desiru/optimizers/base.rb', line 23

# Runs +program+ on every example in +dataset+ and scores each prediction.
#
# The label fields +answer+ and +output+ are stripped from each example before
# it is handed to the program, so the program never sees what it is scored
# against. Keys are matched by name, so both symbol (+:answer+) and string
# (+"answer"+) keys are removed.
#
# @param program [#call] receives the filtered inputs and returns a prediction
# @param dataset [Enumerable] the examples. Anything responding to +to_h+
#   (including Hash) is filtered as described above; any other object is
#   passed to the program unchanged. Each original example is given to
#   +score_prediction+ (defined elsewhere in this class) alongside the
#   prediction.
# @return [Hash] +:average_score+ (Float; 0.0 for an empty dataset),
#   +:scores+ (one score per example, in order) and +:total+ (example count)
def evaluate(program, dataset)
  scores = dataset.map do |example|
    # Hash itself responds to #to_h, so a single branch covers hashes and
    # hash-convertible objects. Compare key *names* so :answer and "answer"
    # are both excluded and the label cannot leak into the inputs.
    inputs =
      if example.respond_to?(:to_h)
        example.to_h.reject { |key, _value| %w[answer output].include?(key.to_s) }
      else
        example
      end

    score_prediction(program.call(inputs), example)
  end

  # Without this guard an empty dataset would produce 0.0 / 0 => NaN.
  average = scores.empty? ? 0.0 : scores.sum.to_f / scores.size
  { average_score: average, scores: scores, total: scores.size }
end
#optimize_module(module_instance, examples) ⇒ Object
Source (lines 19–21):
# File 'lib/desiru/optimizers/base.rb', line 19

# Optimizes a single module against +examples+.
#
# The base class provides no implementation; concrete optimizers override this.
#
# @abstract Subclasses must override this method.
# @param module_instance [Object] the module to optimize (unused here)
# @param examples [Object] examples to optimize against (unused here)
# @raise [NotImplementedError] always, unless overridden by a subclass
def optimize_module(module_instance, examples)
  message = 'Subclasses must implement #optimize_module'
  raise(NotImplementedError, message)
end