Class: GEPA::Proposer::MergeProposer

Inherits:
Object
  • Object
show all
Extended by:
T::Sig
Includes:
ProposeNewCandidate
Defined in:
lib/gepa/proposer/merge_proposer.rb

Overview

Port of the Python GEPA merge proposer. It fuses two descendants that share a common ancestor by recombining their component instructions and then evaluates the merged program on a Pareto-informed subsample.

Constant Summary collapse

# Sorbet type alias for a three-integer tuple. #propose appends
# [id1, id2, ancestor] program indices to the first merges_performed list,
# which presumably holds values of this shape — confirm against the sigs.
CandidateTriplet =
T.type_alias { [Integer, Integer, Integer] }
# Sorbet type alias for [Integer, Integer, Array<Integer>]; presumably two
# program ids plus a list of indices describing a merge attempt — TODO confirm
# where this is produced (not visible in this section).
MergeAttempt =
T.type_alias { [Integer, Integer, T::Array[Integer]] }

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(logger:, valset:, evaluator:, use_merge:, max_merge_invocations:, rng: nil, telemetry: nil) ⇒ MergeProposer



34
35
36
37
38
39
40
41
42
43
44
45
46
47
# File 'lib/gepa/proposer/merge_proposer.rb', line 34

# Build a merge proposer.
#
# @param logger [#log] receives progress messages
# @param valset [#[]] validation examples, looked up by integer index
# @param evaluator [#call] called as evaluator.call(examples, program); the
#   second element of its result is used as the per-example scores
# @param use_merge [Boolean] whether merges may be scheduled at all
# @param max_merge_invocations [Integer] cap compared against the number of
#   merges tested before scheduling another one
# @param rng [Random, nil] randomness source; a Random seeded with 0 when nil/false
# @param telemetry [Object, nil] telemetry sink; GEPA::Telemetry when nil/false
def initialize(logger:, valset:, evaluator:, use_merge:, max_merge_invocations:, rng: nil, telemetry: nil)
  # Injected collaborators and configuration.
  @logger, @valset, @evaluator = logger, valset, evaluator
  @use_merge, @max_merge_invocations = use_merge, max_merge_invocations

  # Fixed-seed fallback keeps default runs reproducible.
  @rng = rng || Random.new(0)
  @telemetry = telemetry || GEPA::Telemetry

  # Merge bookkeeping starts empty; index 0 of @merges_performed collects
  # merged [id1, id2, ancestor] triples.
  @merges_due = @total_merges_tested = 0
  @last_iter_found_new_program = false
  @merges_performed = [[], []]
end

Instance Attribute Details

#last_iter_found_new_program ⇒ Object

Returns the value of attribute last_iter_found_new_program.



56
57
58
# File 'lib/gepa/proposer/merge_proposer.rb', line 56

# Reader for @last_iter_found_new_program (the constructor seeds it with false).
#
# @return [Object] the current flag value
def last_iter_found_new_program; @last_iter_found_new_program; end

#max_merge_invocations ⇒ Object (readonly)

Returns the value of attribute max_merge_invocations.



59
60
61
# File 'lib/gepa/proposer/merge_proposer.rb', line 59

# Reader for the merge cap; schedule_if_needed stops scheduling once
# total_merges_tested is no longer below this value.
#
# @return [Object] the configured maximum
def max_merge_invocations; @max_merge_invocations; end

#merges_due ⇒ Object

Returns the value of attribute merges_due.



50
51
52
# File 'lib/gepa/proposer/merge_proposer.rb', line 50

# Reader for the number of scheduled merges (incremented by schedule_if_needed).
#
# @return [Object] the current count
def merges_due; @merges_due; end

#total_merges_tested ⇒ Object

Returns the value of attribute total_merges_tested.



53
54
55
# File 'lib/gepa/proposer/merge_proposer.rb', line 53

# Reader for the tested-merge counter that schedule_if_needed compares against
# max_merge_invocations. It is not incremented in the code shown here —
# presumably the caller updates it.
#
# @return [Object] the current count
def total_merges_tested; @total_merges_tested; end

#use_merge ⇒ Object (readonly)

Returns the value of attribute use_merge.



62
63
64
# File 'lib/gepa/proposer/merge_proposer.rb', line 62

# Reader for the flag that gates schedule_if_needed.
#
# @return [Object] the configured flag
def use_merge; @use_merge; end

Instance Method Details

#propose(state) ⇒ Object



119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# File 'lib/gepa/proposer/merge_proposer.rb', line 119

# Try one merge step and, on success, score the merged program on a small
# validation subsample.
#
# Every call marks the newest trace entry with :invoked_merge. If
# eligible_for_proposal? is false, or no merge pair can be produced, the reason
# is logged and nil is returned. Otherwise the merge is recorded in the trace
# and in @merges_performed[0], the merged program is evaluated on the indices
# picked by select_eval_subsample_for_merged_program, state.total_num_evals is
# increased by the subsample size, and a CandidateProposal tagged 'merge' is
# returned.
#
# @param state [Object] optimizer state exposing i, full_program_trace,
#   program_at_pareto_front_valset, per_program_tracked_scores,
#   prog_candidate_val_subscores and total_num_evals
# @return [CandidateProposal, nil]
def propose(state)
  iteration = state.i + 1
  ensure_trace_slot(state)
  # Re-read the newest trace entry on every write instead of caching it, so a
  # collaborator that appends entries mid-proposal is still honored.
  trace = -> { state.full_program_trace.last }
  trace.call[:invoked_merge] = true

  unless eligible_for_proposal?
    @logger.log("Iteration #{iteration}: No merge candidates scheduled")
    return nil
  end

  pareto_front = state.program_at_pareto_front_valset
  scores_by_program = state.per_program_tracked_scores
                           .each_with_index
                           .map { |score, program_idx| [program_idx, score] }
                           .to_h
  dominators = GEPA::Utils::Pareto.find_dominator_programs(pareto_front, scores_by_program)

  merged_ok, merged_program, id1, id2, ancestor =
    sample_and_attempt_merge_programs_by_common_predictors(state, dominators)

  unless merged_ok
    @logger.log("Iteration #{iteration}: No merge candidates found")
    return nil
  end

  trace.call[:merged] = true
  trace.call[:merged_entities] = [id1, id2, ancestor]
  @merges_performed.first << [id1, id2, ancestor]

  @logger.log("Iteration #{iteration}: Merged programs #{id1} and #{id2} via ancestor #{ancestor}")

  parent_one_scores = state.prog_candidate_val_subscores[id1]
  parent_two_scores = state.prog_candidate_val_subscores[id2]
  subsample_ids = select_eval_subsample_for_merged_program(parent_one_scores, parent_two_scores)

  mini_valset = subsample_ids.map { |k| @valset[k] }
  id1_sub_scores = subsample_ids.map { |k| parent_one_scores[k] }
  id2_sub_scores = subsample_ids.map { |k| parent_two_scores[k] }

  {
    subsample_ids: subsample_ids,
    id1_subsample_scores: id1_sub_scores,
    id2_subsample_scores: id2_sub_scores
  }.each { |key, value| trace.call[key] = value }

  _, new_sub_scores = @evaluator.call(mini_valset, merged_program)
  trace.call[:new_program_subsample_scores] = new_sub_scores

  state.total_num_evals += subsample_ids.length

  CandidateProposal.new(
    candidate: merged_program,
    parent_program_ids: [id1, id2],
    subsample_indices: subsample_ids,
    subsample_scores_before: [id1_sub_scores, id2_sub_scores].map(&:sum),
    subsample_scores_after: new_sub_scores,
    tag: 'merge',
    metadata: { ancestor: ancestor }
  )
end

#schedule_if_needed ⇒ Object



65
66
67
68
69
70
# File 'lib/gepa/proposer/merge_proposer.rb', line 65

# Queue one more merge when merging is enabled and the number of tested
# merges is still below max_merge_invocations.
#
# @return [Integer, nil] the incremented merges_due, or nil when nothing was queued
def schedule_if_needed
  @merges_due += 1 if @use_merge && @total_merges_tested < @max_merge_invocations
end

#select_eval_subsample_for_merged_program(scores1, scores2, num_subsample_ids: 5) ⇒ Object



79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# File 'lib/gepa/proposer/merge_proposer.rb', line 79

# Choose up to +num_subsample_ids+ validation indices on which to evaluate a
# freshly merged program, balancing examples where each parent wins.
#
# Shared indices are bucketed by comparing the two parents' scores: parent 1
# strictly higher (p1), parent 2 strictly higher (p2), or neither (p3, ties or
# incomparable values). Up to ceil(num_subsample_ids / 3) indices are drawn
# from p1 and from p2 without exceeding the budget, the rest of the budget is
# filled from p3, and any remaining shortfall is topped up from unused indices.
# If too few distinct indices remain, the top-up samples with replacement from
# all indices. Drawing is delegated to sample_from / sample_with_replacement,
# which are defined elsewhere (presumably backed by @rng).
#
# @param scores1 [Array<Comparable>] per-example validation scores of parent 1
# @param scores2 [Array<Comparable>] per-example validation scores of parent 2
# @param num_subsample_ids [Integer] maximum number of indices to return;
#   values below 0 are treated as 0
# @return [Array<Integer>] selected indices (may repeat when the pool is smaller
#   than the budget); empty when there are no shared indices
def select_eval_subsample_for_merged_program(scores1, scores2, num_subsample_ids: 5)
  # Never hand negative counts to the samplers or to Array#take.
  budget = [num_subsample_ids, 0].max
  all_indices = (0...[scores1.length, scores2.length].min).to_a

  p1 = []
  p2 = []
  p3 = []
  all_indices.each do |index|
    s1 = scores1[index]
    s2 = scores2[index]
    if s1 > s2
      p1 << index
    elsif s2 > s1
      p2 << index
    else
      p3 << index
    end
  end

  n_each = (budget / 3.0).ceil
  selected = []

  # Cap each winner bucket by the open budget: with budget == 1, n_each is 1
  # and two uncapped draws would leave a negative count for the tie bucket.
  [p1, p2].each do |bucket|
    open_slots = budget - selected.length
    selected.concat(sample_from(bucket, [n_each, bucket.length, open_slots].min))
  end

  open_slots = budget - selected.length
  selected.concat(sample_from(p3, [open_slots, p3.length].min))

  open_slots = budget - selected.length
  if open_slots.positive?
    unused = all_indices - selected
    if unused.length >= open_slots
      selected.concat(sample_from(unused, open_slots))
    elsif !all_indices.empty?
      # Not enough distinct indices left: reuse indices. Skipped when there are
      # no indices at all, since an empty pool cannot be sampled.
      selected.concat(sample_with_replacement(all_indices, open_slots))
    end
  end

  selected.take(budget)
end