Refactor router late-phase control flow

This commit is contained in:
Jan Petykiewicz 2026-04-02 20:17:03 -07:00
commit 7e0d96f987

View file

@ -135,6 +135,37 @@ class PathFinder:
def _metric_deltas(self, before: dict[str, int], after: dict[str, int]) -> dict[str, int]:
return {metric_name: after[metric_name] - before[metric_name] for metric_name in before}
def _results_all_reached_target(self, state: _RoutingState) -> bool:
return (
len(state.results) == len(state.ordered_net_ids)
and all(result.reached_target for result in state.results.values())
)
def _has_incumbent_fallback(self, result: RoutingResult | None) -> bool:
return bool(result and result.reached_target and result.path)
def _restore_incumbent_fallback(
self,
net_id: str,
result: RoutingResult,
guidance_seed_present: bool,
) -> tuple[RoutingResult, bool]:
self.metrics.total_late_phase_capped_fallbacks += 1
self._install_path(net_id, result.path)
return result, guidance_seed_present
def _guidance_for_result(
self,
result: RoutingResult | None,
) -> tuple[Sequence[ComponentResult] | None, float, bool]:
if result is None or not result.reached_target or not result.path:
return None, 0.0, False
return (
result.as_seed().segments,
max(10.0, self.context.options.objective.bend_penalty * 0.25),
True,
)
def _install_path(self, net_id: str, path: Sequence[ComponentResult]) -> None:
all_geoms: list[Polygon] = []
all_dilated: list[Polygon] = []
@ -272,6 +303,38 @@ class PathFinder:
if result and result.path:
self._install_path(net_id, result.path)
def _analyze_restored_best(
self,
state: _RoutingState,
) -> tuple[dict[str, PathVerificationDetail], _IterationReview]:
capture_component_conflicts = (
self.context.options.diagnostics.capture_conflict_trace
or self.context.options.diagnostics.capture_pre_pair_frontier_trace
)
state.results, details_by_net, review = self._analyze_results(
state.ordered_net_ids,
state.results,
capture_component_conflicts=capture_component_conflicts,
count_iteration_metrics=False,
)
if self.context.options.diagnostics.capture_conflict_trace:
self._capture_conflict_trace_entry(
state,
stage="restored_best",
iteration=None,
results=state.results,
details_by_net=details_by_net,
review=review,
)
if self.context.options.diagnostics.capture_pre_pair_frontier_trace:
self.pre_pair_frontier_trace = self._materialize_pre_pair_frontier_trace(
state,
state.results,
details_by_net,
review,
)
return details_by_net, review
def _update_best_iteration(self, state: _RoutingState, review: _IterationReview) -> bool:
completed_nets = len(review.completed_net_ids)
conflict_edges = len(review.conflict_edges)
@ -635,6 +698,9 @@ class PathFinder:
return candidate_length < incumbent_length
return False
def _pair_local_attempt_orders(self, target: _PairLocalTarget) -> tuple[tuple[str, str], tuple[str, str]]:
return target.net_ids, target.net_ids[::-1]
def _collect_pair_local_targets(
self,
state: _RoutingState,
@ -756,6 +822,97 @@ class PathFinder:
nets=tuple(nets),
)
def _build_iteration_reroute_plan(
self,
state: _RoutingState,
reroute_net_ids: set[str],
) -> tuple[list[str], set[str]]:
routed_net_ids = [net_id for net_id in state.ordered_net_ids if net_id in reroute_net_ids]
capped_net_ids: set[str] = set()
if len(reroute_net_ids) >= len(state.ordered_net_ids) or not state.recent_attempt_work:
return routed_net_ids, capped_net_ids
order_index = {net_id: idx for idx, net_id in enumerate(state.ordered_net_ids)}
routed_net_ids.sort(key=lambda net_id: (state.recent_attempt_work.get(net_id, 0), order_index[net_id]))
if (
len(routed_net_ids) == 4
and state.best_conflict_edges <= 2
and self._results_all_reached_target(state)
):
heavy_net_ids = sorted(
routed_net_ids,
key=lambda net_id: (-state.recent_attempt_work.get(net_id, 0), order_index[net_id]),
)[:2]
capped_net_ids = {
net_id for net_id in heavy_net_ids if state.recent_attempt_work.get(net_id, 0) >= 200
}
return routed_net_ids, capped_net_ids
def _update_pre_pair_candidate(
self,
state: _RoutingState,
*,
iteration: int,
reroute_net_ids: set[str],
routed_net_ids: list[str],
attempt_traces: list[IterationNetAttemptTrace],
review: _IterationReview,
) -> None:
if self._results_all_reached_target(state) and len(reroute_net_ids) < len(state.ordered_net_ids) and review.conflict_edges:
state.pre_pair_candidate = _PrePairCandidate(
iteration=iteration,
routed_net_ids=tuple(routed_net_ids),
conflict_edges=tuple(sorted(review.conflict_edges)),
net_attempts=tuple(attempt_traces),
)
return
state.pre_pair_candidate = None
def _next_reroute_net_ids(
self,
state: _RoutingState,
review: _IterationReview,
) -> set[str]:
if self._results_all_reached_target(state) and 0 < len(review.conflict_edges) <= 3:
return set(review.conflicting_nets)
return set(state.ordered_net_ids)
def _should_stop_for_pair_local_plateau(
self,
state: _RoutingState,
*,
improved: bool,
) -> bool:
if improved:
state.pair_local_plateau_count = 0
return False
if self._results_all_reached_target(state) and state.best_conflict_edges <= 2:
# Once the run is fully reached-target and already in the final <=2-edge
# basin, another non-improving negotiated iteration is just churn before
# the bounded pair-local repair.
state.pair_local_plateau_count += 1
return state.pair_local_plateau_count >= 1
state.pair_local_plateau_count = 0
return False
def _update_repeated_conflict_state(
self,
state: _RoutingState,
review: _IterationReview,
) -> bool:
current_signature = tuple(sorted(review.conflict_edges))
repeated = (
bool(current_signature)
and (
current_signature == state.last_conflict_signature
or len(current_signature) == state.last_conflict_edge_count
)
)
state.repeated_conflict_count = state.repeated_conflict_count + 1 if repeated else 0
state.last_conflict_signature = current_signature
state.last_conflict_edge_count = len(current_signature)
return state.repeated_conflict_count >= 2
def _run_pair_local_attempt(
self,
state: _RoutingState,
@ -767,12 +924,7 @@ class PathFinder:
for net_id in pair_order:
net = state.net_specs[net_id]
guidance_result = incumbent_results.get(net_id)
guidance_seed = None
guidance_bonus = 0.0
if guidance_result and guidance_result.reached_target and guidance_result.path:
guidance_seed = guidance_result.as_seed().segments
guidance_bonus = max(10.0, self.context.options.objective.bend_penalty * 0.25)
guidance_seed, guidance_bonus, _ = self._guidance_for_result(incumbent_results.get(net_id))
run_config = SearchRunConfig.from_options(
self.context.options,
@ -811,29 +963,13 @@ class PathFinder:
return local_results, local_context.metrics.total_nodes_expanded
def _run_pair_local_search(self, state: _RoutingState) -> None:
state.results, _details_by_net, review = self._analyze_results(
state.ordered_net_ids,
state.results,
capture_component_conflicts=True,
count_iteration_metrics=False,
)
targets = self._collect_pair_local_targets(state, state.results, review)
if not targets:
return
for target in targets[:2]:
self.metrics.total_pair_local_search_pairs_considered += 1
incumbent_results = dict(state.results)
incumbent_review = review
accepted = False
for pair_order in (target.net_ids, target.net_ids[::-1]):
self.metrics.total_pair_local_search_attempts += 1
candidate = self._run_pair_local_attempt(state, incumbent_results, pair_order)
if candidate is None:
continue
candidate_results, nodes_expanded = candidate
self.metrics.total_pair_local_search_nodes_expanded += nodes_expanded
def _apply_pair_local_candidate(
self,
state: _RoutingState,
candidate_results: dict[str, RoutingResult],
incumbent_results: dict[str, RoutingResult],
incumbent_review: _IterationReview,
) -> tuple[bool, _IterationReview]:
self._replace_installed_paths(state, candidate_results)
candidate_results, _candidate_details_by_net, candidate_review = self._analyze_results(
state.ordered_net_ids,
@ -849,14 +985,53 @@ class PathFinder:
):
self.metrics.total_pair_local_search_accepts += 1
state.results = candidate_results
review = candidate_review
accepted = True
break
self._replace_installed_paths(state, incumbent_results)
return True, candidate_review
self._replace_installed_paths(state, incumbent_results)
return False, incumbent_review
def _run_pair_local_target(
self,
state: _RoutingState,
target: _PairLocalTarget,
review: _IterationReview,
) -> _IterationReview:
incumbent_results = dict(state.results)
incumbent_review = review
self.metrics.total_pair_local_search_pairs_considered += 1
for pair_order in self._pair_local_attempt_orders(target):
self.metrics.total_pair_local_search_attempts += 1
candidate = self._run_pair_local_attempt(state, incumbent_results, pair_order)
if candidate is None:
continue
candidate_results, nodes_expanded = candidate
self.metrics.total_pair_local_search_nodes_expanded += nodes_expanded
accepted, next_review = self._apply_pair_local_candidate(
state,
candidate_results,
incumbent_results,
incumbent_review,
)
if accepted:
return next_review
if not accepted:
state.results = incumbent_results
self._replace_installed_paths(state, incumbent_results)
return incumbent_review
def _run_pair_local_search(self, state: _RoutingState) -> None:
state.results, _details_by_net, review = self._analyze_results(
state.ordered_net_ids,
state.results,
capture_component_conflicts=True,
count_iteration_metrics=False,
)
targets = self._collect_pair_local_targets(state, state.results, review)
if not targets:
return
for target in targets[:2]:
review = self._run_pair_local_target(state, target, review)
def _route_net_once(
self,
@ -883,28 +1058,18 @@ class PathFinder:
else:
coll_model, _ = resolve_bend_geometry(search)
skip_congestion = False
guidance_seed = None
guidance_bonus = 0.0
guidance_seed, guidance_bonus, guidance_seed_present = (None, 0.0, False)
if congestion.use_tiered_strategy and iteration == 0:
skip_congestion = True
if coll_model == "arc":
coll_model = "clipped_bbox"
elif iteration > 0:
guidance_result = state.results.get(net_id)
if guidance_result and guidance_result.reached_target and guidance_result.path:
guidance_seed = guidance_result.as_seed().segments
guidance_bonus = max(10.0, self.context.options.objective.bend_penalty * 0.25)
guidance_seed_present = True
guidance_seed, guidance_bonus, guidance_seed_present = self._guidance_for_result(
state.results.get(net_id)
)
if (
node_limit_override is not None
and incumbent_fallback is not None
and incumbent_fallback.reached_target
and incumbent_fallback.path
):
self.metrics.total_late_phase_capped_fallbacks += 1
self._install_path(net_id, incumbent_fallback.path)
return incumbent_fallback, guidance_seed_present
if node_limit_override is not None and self._has_incumbent_fallback(incumbent_fallback):
return self._restore_incumbent_fallback(net_id, incumbent_fallback, guidance_seed_present)
run_config = SearchRunConfig.from_options(
self.context.options,
@ -931,17 +1096,13 @@ class PathFinder:
state.accumulated_expanded_nodes.extend(self.metrics.last_expanded_nodes)
if not path:
if incumbent_fallback is not None and incumbent_fallback.reached_target and incumbent_fallback.path:
self.metrics.total_late_phase_capped_fallbacks += 1
self._install_path(net_id, incumbent_fallback.path)
return incumbent_fallback, guidance_seed_present
if self._has_incumbent_fallback(incumbent_fallback):
return self._restore_incumbent_fallback(net_id, incumbent_fallback, guidance_seed_present)
return RoutingResult(net_id=net_id, path=(), reached_target=False), guidance_seed_present
reached_target = path[-1].end_port == net.target
if not reached_target and incumbent_fallback is not None and incumbent_fallback.reached_target and incumbent_fallback.path:
self.metrics.total_late_phase_capped_fallbacks += 1
self._install_path(net_id, incumbent_fallback.path)
return incumbent_fallback, guidance_seed_present
if not reached_target and self._has_incumbent_fallback(incumbent_fallback):
return self._restore_incumbent_fallback(net_id, incumbent_fallback, guidance_seed_present)
if reached_target:
self.metrics.total_nets_reached_target += 1
report = None
@ -977,24 +1138,7 @@ class PathFinder:
random.Random(iteration_seed).shuffle(state.ordered_net_ids)
iteration_penalty = self.context.congestion_penalty
routed_net_ids = [net_id for net_id in state.ordered_net_ids if net_id in reroute_net_ids]
capped_net_ids: set[str] = set()
if len(reroute_net_ids) < len(state.ordered_net_ids) and state.recent_attempt_work:
order_index = {net_id: idx for idx, net_id in enumerate(state.ordered_net_ids)}
routed_net_ids.sort(key=lambda net_id: (state.recent_attempt_work.get(net_id, 0), order_index[net_id]))
if (
len(routed_net_ids) == 4
and state.best_conflict_edges <= 2
and len(state.results) == len(state.ordered_net_ids)
and all(result.reached_target for result in state.results.values())
):
heavy_net_ids = sorted(
routed_net_ids,
key=lambda net_id: (-state.recent_attempt_work.get(net_id, 0), order_index[net_id]),
)[:2]
capped_net_ids = {
net_id for net_id in heavy_net_ids if state.recent_attempt_work.get(net_id, 0) >= 200
}
routed_net_ids, capped_net_ids = self._build_iteration_reroute_plan(state, reroute_net_ids)
self.metrics.total_nets_carried_forward += len(state.ordered_net_ids) - len(routed_net_ids)
iteration_before = {}
attempt_traces: list[IterationNetAttemptTrace] = []
@ -1039,19 +1183,14 @@ class PathFinder:
state.recent_attempt_work = attempt_work
review = self._reverify_iteration_results(state)
all_reached_target = (
len(state.results) == len(state.ordered_net_ids)
and all(result.reached_target for result in state.results.values())
)
if all_reached_target and len(reroute_net_ids) < len(state.ordered_net_ids) and review.conflict_edges:
state.pre_pair_candidate = _PrePairCandidate(
self._update_pre_pair_candidate(
state,
iteration=iteration,
routed_net_ids=tuple(routed_net_ids),
conflict_edges=tuple(sorted(review.conflict_edges)),
net_attempts=tuple(attempt_traces),
reroute_net_ids=reroute_net_ids,
routed_net_ids=routed_net_ids,
attempt_traces=attempt_traces,
review=review,
)
else:
state.pre_pair_candidate = None
if diagnostics.capture_iteration_trace:
iteration_after = self._capture_metric_totals(_ITERATION_TRACE_TOTALS)
deltas = self._metric_deltas(iteration_before, iteration_after)
@ -1115,38 +1254,11 @@ class PathFinder:
):
return False
all_reached_target = (
len(state.results) == len(state.ordered_net_ids)
and all(result.reached_target for result in state.results.values())
)
reroute_net_ids = set(state.ordered_net_ids)
if all_reached_target and 0 < len(review.conflict_edges) <= 3:
reroute_net_ids = set(review.conflicting_nets)
if improved:
state.pair_local_plateau_count = 0
elif all_reached_target and state.best_conflict_edges <= 2:
# Once all nets reach target and the best snapshot is already in the
# final <=2-edge basin, later negotiated reroutes tend to churn.
# Hand off to the bounded pair-local repair instead of exploring
# additional late iterations that are not improving the best state.
state.pair_local_plateau_count += 1
if state.pair_local_plateau_count >= 1:
reroute_net_ids = self._next_reroute_net_ids(state, review)
if self._should_stop_for_pair_local_plateau(state, improved=improved):
return False
else:
state.pair_local_plateau_count = 0
current_signature = tuple(sorted(review.conflict_edges))
repeated = (
bool(current_signature)
and (
current_signature == state.last_conflict_signature
or len(current_signature) == state.last_conflict_edge_count
)
)
state.repeated_conflict_count = state.repeated_conflict_count + 1 if repeated else 0
state.last_conflict_signature = current_signature
state.last_conflict_edge_count = len(current_signature)
if state.repeated_conflict_count >= 2:
if self._update_repeated_conflict_state(state, review):
return False
self.context.congestion_penalty *= congestion.multiplier
return False
@ -1228,32 +1340,7 @@ class PathFinder:
timed_out = self._run_iterations(state, iteration_callback)
self.accumulated_expanded_nodes = list(state.accumulated_expanded_nodes)
self._restore_best_iteration(state)
capture_component_conflicts = (
self.context.options.diagnostics.capture_conflict_trace
or self.context.options.diagnostics.capture_pre_pair_frontier_trace
)
state.results, details_by_net, review = self._analyze_results(
state.ordered_net_ids,
state.results,
capture_component_conflicts=capture_component_conflicts,
count_iteration_metrics=False,
)
if self.context.options.diagnostics.capture_conflict_trace:
self._capture_conflict_trace_entry(
state,
stage="restored_best",
iteration=None,
results=state.results,
details_by_net=details_by_net,
review=review,
)
if self.context.options.diagnostics.capture_pre_pair_frontier_trace:
self.pre_pair_frontier_trace = self._materialize_pre_pair_frontier_trace(
state,
state.results,
details_by_net,
review,
)
self._analyze_restored_best(state)
if timed_out:
final_results = self._verify_results(state)