astar refactor

This commit is contained in:
Jan Petykiewicz 2026-03-21 12:57:55 -07:00
commit 62d357c147

View file

@ -2,14 +2,12 @@ from __future__ import annotations
import heapq import heapq
import logging import logging
import functools
from typing import TYPE_CHECKING, Literal, Any from typing import TYPE_CHECKING, Literal, Any
import rtree
import numpy import numpy
import shapely import shapely
from inire.geometry.components import Bend90, SBend, Straight, SEARCH_GRID_SNAP_UM, snap_search_grid from inire.geometry.components import Bend90, SBend, Straight, snap_search_grid
from inire.geometry.primitives import Port from inire.geometry.primitives import Port
from inire.router.config import RouterConfig from inire.router.config import RouterConfig
from inire.router.visibility import VisibilityManager from inire.router.visibility import VisibilityManager
@ -50,73 +48,70 @@ class AStarNode:
return self.h_cost < other.h_cost return self.h_cost < other.h_cost
class AStarRouter: class AStarMetrics:
""" """
Waveguide router based on sparse A* search. Performance metrics and instrumentation for A* search.
""" """
__slots__ = ('cost_evaluator', 'config', 'node_limit', 'visibility_manager', __slots__ = ('total_nodes_expanded', 'last_expanded_nodes', 'nodes_expanded',
'_hard_collision_set', '_congestion_cache', '_static_safe_cache', 'moves_generated', 'moves_added', 'pruned_closed_set',
'_move_cache', 'total_nodes_expanded', 'last_expanded_nodes', 'metrics', 'pruned_hard_collision', 'pruned_cost')
'_self_collision_check')
def __init__(self, cost_evaluator: CostEvaluator, node_limit: int | None = None, **kwargs) -> None:
self.cost_evaluator = cost_evaluator
self.config = RouterConfig(sbend_radii=[5.0, 10.0, 50.0, 100.0])
if node_limit is not None:
self.config.node_limit = node_limit
for k, v in kwargs.items():
if hasattr(self.config, k):
setattr(self.config, k, v)
self.node_limit = self.config.node_limit
# Visibility Manager for sparse jumps
self.visibility_manager = VisibilityManager(self.cost_evaluator.collision_engine)
self._hard_collision_set: set[tuple] = set()
self._congestion_cache: dict[tuple, int] = {}
self._static_safe_cache: set[tuple] = set()
self._move_cache: dict[tuple, ComponentResult] = {}
def __init__(self) -> None:
self.total_nodes_expanded = 0 self.total_nodes_expanded = 0
self.last_expanded_nodes: list[tuple[float, float, float]] = [] self.last_expanded_nodes: list[tuple[float, float, float]] = []
self.nodes_expanded = 0
self.moves_generated = 0
self.moves_added = 0
self.pruned_closed_set = 0
self.pruned_hard_collision = 0
self.pruned_cost = 0
self.metrics = { def reset_per_route(self) -> None:
'nodes_expanded': 0, """ Reset metrics that are specific to a single route() call. """
'moves_generated': 0, self.nodes_expanded = 0
'moves_added': 0, self.moves_generated = 0
'pruned_closed_set': 0, self.moves_added = 0
'pruned_hard_collision': 0, self.pruned_closed_set = 0
'pruned_cost': 0 self.pruned_hard_collision = 0
self.pruned_cost = 0
self.last_expanded_nodes = []
def get_summary_dict(self) -> dict[str, int]:
""" Return a dictionary of current metrics. """
return {
'nodes_expanded': self.nodes_expanded,
'moves_generated': self.moves_generated,
'moves_added': self.moves_added,
'pruned_closed_set': self.pruned_closed_set,
'pruned_hard_collision': self.pruned_hard_collision,
'pruned_cost': self.pruned_cost
} }
def reset_metrics(self) -> None:
""" Reset all performance counters. """
for k in self.metrics:
self.metrics[k] = 0
self.cost_evaluator.collision_engine.reset_metrics()
def get_metrics_summary(self) -> str: class AStarContext:
""" Return a human-readable summary of search performance. """ """
m = self.metrics Persistent state for A* search, decoupled from search logic.
c = self.cost_evaluator.collision_engine.get_metrics_summary() """
return (f"Search Performance: \n" __slots__ = ('cost_evaluator', 'config', 'visibility_manager',
f" Nodes Expanded: {m['nodes_expanded']}\n" 'move_cache', 'hard_collision_set', 'static_safe_cache')
f" Moves: Generated={m['moves_generated']}, Added={m['moves_added']}\n"
f" Pruning: ClosedSet={m['pruned_closed_set']}, HardColl={m['pruned_hard_collision']}, Cost={m['pruned_cost']}\n"
f" {c}")
@property def __init__(self, cost_evaluator: CostEvaluator, config: RouterConfig | None = None) -> None:
def _self_dilation(self) -> float: self.cost_evaluator = cost_evaluator
return self.cost_evaluator.collision_engine.clearance / 2.0 self.config = config if config is not None else RouterConfig()
self.visibility_manager = VisibilityManager(self.cost_evaluator.collision_engine)
def route( # Long-lived caches (shared across multiple route calls)
self, self.move_cache: dict[tuple, ComponentResult] = {}
self.hard_collision_set: set[tuple] = set()
self.static_safe_cache: set[tuple] = set()
def route_astar(
start: Port, start: Port,
target: Port, target: Port,
net_width: float, net_width: float,
context: AStarContext,
metrics: AStarMetrics | None = None,
net_id: str = 'default', net_id: str = 'default',
bend_collision_type: Literal['arc', 'bbox', 'clipped_bbox'] | None = None, bend_collision_type: Literal['arc', 'bbox', 'clipped_bbox'] | None = None,
return_partial: bool = False, return_partial: bool = False,
@ -124,56 +119,48 @@ class AStarRouter:
skip_congestion: bool = False, skip_congestion: bool = False,
max_cost: float | None = None, max_cost: float | None = None,
self_collision_check: bool = False, self_collision_check: bool = False,
node_limit: int | None = None,
) -> list[ComponentResult] | None: ) -> list[ComponentResult] | None:
""" """
Route a single net using A*. Functional implementation of A* routing.
Args:
start: Starting port.
target: Target port.
net_width: Waveguide width.
net_id: Identifier for the net.
bend_collision_type: Type of collision model to use for bends.
return_partial: If True, returns the best-effort path if target not reached.
store_expanded: If True, keep track of all expanded nodes for visualization.
skip_congestion: If True, ignore other nets' paths (greedy mode).
max_cost: Hard limit on f_cost to prune search.
self_collision_check: If True, prevent the net from crossing its own path.
""" """
self._self_collision_check = self_collision_check if metrics is None:
self._congestion_cache.clear() metrics = AStarMetrics()
if store_expanded:
self.last_expanded_nodes = [] metrics.reset_per_route()
# Per-route congestion cache (not shared across different routes)
congestion_cache: dict[tuple, int] = {}
if bend_collision_type is not None: if bend_collision_type is not None:
self.config.bend_collision_type = bend_collision_type context.config.bend_collision_type = bend_collision_type
self.cost_evaluator.set_target(target) context.cost_evaluator.set_target(target)
open_set: list[AStarNode] = [] open_set: list[AStarNode] = []
snap = self.config.snap_size snap = context.config.snap_size
inv_snap = 1.0 / snap inv_snap = 1.0 / snap
# (x_grid, y_grid, orientation_grid) -> min_g_cost # (x_grid, y_grid, orientation_grid) -> min_g_cost
closed_set: dict[tuple[int, int, int], float] = {} closed_set: dict[tuple[int, int, int], float] = {}
start_node = AStarNode(start, 0.0, self.cost_evaluator.h_manhattan(start, target)) start_node = AStarNode(start, 0.0, context.cost_evaluator.h_manhattan(start, target))
heapq.heappush(open_set, start_node) heapq.heappush(open_set, start_node)
best_node = start_node best_node = start_node
nodes_expanded = 0 nodes_expanded = 0
node_limit = self.node_limit effective_node_limit = node_limit if node_limit is not None else context.config.node_limit
while open_set: while open_set:
if nodes_expanded >= node_limit: if nodes_expanded >= effective_node_limit:
return self._reconstruct_path(best_node) if return_partial else None return reconstruct_path(best_node) if return_partial else None
current = heapq.heappop(open_set) current = heapq.heappop(open_set)
# Cost Pruning (Fail Fast) # Cost Pruning (Fail Fast)
if max_cost is not None and current.f_cost > max_cost: if max_cost is not None and current.f_cost > max_cost:
self.metrics['pruned_cost'] += 1 metrics.pruned_cost += 1
continue continue
if current.h_cost < best_node.h_cost: if current.h_cost < best_node.h_cost:
@ -185,38 +172,50 @@ class AStarRouter:
closed_set[state] = current.g_cost closed_set[state] = current.g_cost
if store_expanded: if store_expanded:
self.last_expanded_nodes.append((current.port.x, current.port.y, current.port.orientation)) metrics.last_expanded_nodes.append((current.port.x, current.port.y, current.port.orientation))
nodes_expanded += 1 nodes_expanded += 1
self.total_nodes_expanded += 1 metrics.total_nodes_expanded += 1
self.metrics['nodes_expanded'] += 1 metrics.nodes_expanded += 1
# Check if we reached the target exactly # Check if we reached the target exactly
if (abs(current.port.x - target.x) < 1e-6 and if (abs(current.port.x - target.x) < 1e-6 and
abs(current.port.y - target.y) < 1e-6 and abs(current.port.y - target.y) < 1e-6 and
abs(current.port.orientation - target.orientation) < 0.1): abs(current.port.orientation - target.orientation) < 0.1):
return self._reconstruct_path(current) return reconstruct_path(current)
# Expansion # Expansion
self._expand_moves(current, target, net_width, net_id, open_set, closed_set, snap, nodes_expanded, skip_congestion=skip_congestion, inv_snap=inv_snap, parent_state=state, max_cost=max_cost) expand_moves(
current, target, net_width, net_id, open_set, closed_set,
context, metrics, congestion_cache,
snap=snap, inv_snap=inv_snap, parent_state=state,
max_cost=max_cost, skip_congestion=skip_congestion,
self_collision_check=self_collision_check
)
return self._reconstruct_path(best_node) if return_partial else None return reconstruct_path(best_node) if return_partial else None
def _expand_moves(
self, def expand_moves(
current: AStarNode, current: AStarNode,
target: Port, target: Port,
net_width: float, net_width: float,
net_id: str, net_id: str,
open_set: list[AStarNode], open_set: list[AStarNode],
closed_set: dict[tuple[int, int, int], float], closed_set: dict[tuple[int, int, int], float],
context: AStarContext,
metrics: AStarMetrics,
congestion_cache: dict[tuple, int],
snap: float = 1.0, snap: float = 1.0,
nodes_expanded: int = 0,
skip_congestion: bool = False,
inv_snap: float | None = None, inv_snap: float | None = None,
parent_state: tuple[int, int, int] | None = None, parent_state: tuple[int, int, int] | None = None,
max_cost: float | None = None max_cost: float | None = None,
skip_congestion: bool = False,
self_collision_check: bool = False,
) -> None: ) -> None:
"""
Extract moves and add valid successors to the open set.
"""
cp = current.port cp = current.port
if inv_snap is None: inv_snap = 1.0 / snap if inv_snap is None: inv_snap = 1.0 / snap
if parent_state is None: if parent_state is None:
@ -228,68 +227,77 @@ class AStarRouter:
rad = numpy.radians(cp.orientation) rad = numpy.radians(cp.orientation)
cos_v, sin_v = numpy.cos(rad), numpy.sin(rad) cos_v, sin_v = numpy.cos(rad), numpy.sin(rad)
# 1. DIRECT JUMP TO TARGET # 1. DIRECT JUMP TO TARGET
proj_t = dx_t * cos_v + dy_t * sin_v proj_t = dx_t * cos_v + dy_t * sin_v
perp_t = -dx_t * sin_v + dy_t * cos_v perp_t = -dx_t * sin_v + dy_t * cos_v
# A. Straight Jump # A. Straight Jump
if proj_t > 0 and abs(perp_t) < 1e-3 and abs(cp.orientation - target.orientation) < 0.1: if proj_t > 0 and abs(perp_t) < 1e-3 and abs(cp.orientation - target.orientation) < 0.1:
max_reach = self.cost_evaluator.collision_engine.ray_cast(cp, cp.orientation, proj_t + 1.0) max_reach = context.cost_evaluator.collision_engine.ray_cast(cp, cp.orientation, proj_t + 1.0)
if max_reach >= proj_t - 0.01: if max_reach >= proj_t - 0.01:
self._process_move(current, target, net_width, net_id, open_set, closed_set, snap, f'S{proj_t}', 'S', (proj_t,), skip_congestion, inv_snap=inv_snap, snap_to_grid=False, parent_state=parent_state, max_cost=max_cost) process_move(
current, target, net_width, net_id, open_set, closed_set, context, metrics, congestion_cache,
f'S{proj_t}', 'S', (proj_t,), skip_congestion, inv_snap=inv_snap, snap_to_grid=False,
parent_state=parent_state, max_cost=max_cost, snap=snap, self_collision_check=self_collision_check
)
# 2. VISIBILITY JUMPS & MAX REACH # 2. VISIBILITY JUMPS & MAX REACH
max_reach = self.cost_evaluator.collision_engine.ray_cast(cp, cp.orientation, self.config.max_straight_length) max_reach = context.cost_evaluator.collision_engine.ray_cast(cp, cp.orientation, context.config.max_straight_length)
straight_lengths = set() straight_lengths = set()
if max_reach > self.config.min_straight_length: if max_reach > context.config.min_straight_length:
straight_lengths.add(snap_search_grid(max_reach, snap)) straight_lengths.add(snap_search_grid(max_reach, snap))
for radius in self.config.bend_radii: for radius in context.config.bend_radii:
if max_reach > radius + self.config.min_straight_length: if max_reach > radius + context.config.min_straight_length:
straight_lengths.add(snap_search_grid(max_reach - radius, snap)) straight_lengths.add(snap_search_grid(max_reach - radius, snap))
if max_reach > self.config.min_straight_length + 5.0: if max_reach > context.config.min_straight_length + 5.0:
straight_lengths.add(snap_search_grid(max_reach - 5.0, snap)) straight_lengths.add(snap_search_grid(max_reach - 5.0, snap))
visible_corners = self.visibility_manager.get_visible_corners(cp, max_dist=max_reach) visible_corners = context.visibility_manager.get_visible_corners(cp, max_dist=max_reach)
for cx, cy, dist in visible_corners: for cx, cy, dist in visible_corners:
proj = (cx - cp.x) * cos_v + (cy - cp.y) * sin_v proj = (cx - cp.x) * cos_v + (cy - cp.y) * sin_v
if proj > self.config.min_straight_length: if proj > context.config.min_straight_length:
straight_lengths.add(snap_search_grid(proj, snap)) straight_lengths.add(snap_search_grid(proj, snap))
straight_lengths.add(self.config.min_straight_length) straight_lengths.add(context.config.min_straight_length)
if max_reach > self.config.min_straight_length * 4: if max_reach > context.config.min_straight_length * 4:
straight_lengths.add(snap_search_grid(max_reach / 2.0, snap)) straight_lengths.add(snap_search_grid(max_reach / 2.0, snap))
if abs(cp.orientation % 180) < 0.1: # Horizontal if abs(cp.orientation % 180) < 0.1: # Horizontal
target_dist = abs(target.x - cp.x) target_dist = abs(target.x - cp.x)
if target_dist <= max_reach and target_dist > self.config.min_straight_length: if target_dist <= max_reach and target_dist > context.config.min_straight_length:
sl = snap_search_grid(target_dist, snap) sl = snap_search_grid(target_dist, snap)
if sl > 0.1: straight_lengths.add(sl) if sl > 0.1: straight_lengths.add(sl)
for radius in self.config.bend_radii: for radius in context.config.bend_radii:
for l in [target_dist - radius, target_dist - 2*radius]: for l in [target_dist - radius, target_dist - 2*radius]:
if l > self.config.min_straight_length: if l > context.config.min_straight_length:
s_l = snap_search_grid(l, snap) s_l = snap_search_grid(l, snap)
if s_l <= max_reach and s_l > 0.1: straight_lengths.add(s_l) if s_l <= max_reach and s_l > 0.1: straight_lengths.add(s_l)
else: # Vertical else: # Vertical
target_dist = abs(target.y - cp.y) target_dist = abs(target.y - cp.y)
if target_dist <= max_reach and target_dist > self.config.min_straight_length: if target_dist <= max_reach and target_dist > context.config.min_straight_length:
sl = snap_search_grid(target_dist, snap) sl = snap_search_grid(target_dist, snap)
if sl > 0.1: straight_lengths.add(sl) if sl > 0.1: straight_lengths.add(sl)
for radius in self.config.bend_radii: for radius in context.config.bend_radii:
for l in [target_dist - radius, target_dist - 2*radius]: for l in [target_dist - radius, target_dist - 2*radius]:
if l > self.config.min_straight_length: if l > context.config.min_straight_length:
s_l = snap_search_grid(l, snap) s_l = snap_search_grid(l, snap)
if s_l <= max_reach and s_l > 0.1: straight_lengths.add(s_l) if s_l <= max_reach and s_l > 0.1: straight_lengths.add(s_l)
for length in sorted(straight_lengths, reverse=True): for length in sorted(straight_lengths, reverse=True):
self._process_move(current, target, net_width, net_id, open_set, closed_set, snap, f'S{length}', 'S', (length,), skip_congestion, inv_snap=inv_snap, parent_state=parent_state, max_cost=max_cost) process_move(
current, target, net_width, net_id, open_set, closed_set, context, metrics, congestion_cache,
f'S{length}', 'S', (length,), skip_congestion, inv_snap=inv_snap, parent_state=parent_state,
max_cost=max_cost, snap=snap, self_collision_check=self_collision_check
)
# 3. BENDS & SBENDS # 3. BENDS & SBENDS
angle_to_target = numpy.degrees(numpy.arctan2(target.y - cp.y, target.x - cp.x)) angle_to_target = numpy.degrees(numpy.arctan2(target.y - cp.y, target.x - cp.x))
allow_backwards = (dist_sq < 150*150) allow_backwards = (dist_sq < 150*150)
for radius in self.config.bend_radii: for radius in context.config.bend_radii:
for direction in ['CW', 'CCW']: for direction in ['CW', 'CCW']:
if not allow_backwards: if not allow_backwards:
turn = 90 if direction == 'CCW' else -90 turn = 90 if direction == 'CCW' else -90
@ -297,12 +305,16 @@ class AStarRouter:
new_diff = (angle_to_target - new_ori + 180) % 360 - 180 new_diff = (angle_to_target - new_ori + 180) % 360 - 180
if abs(new_diff) > 135: if abs(new_diff) > 135:
continue continue
self._process_move(current, target, net_width, net_id, open_set, closed_set, snap, f'B{radius}{direction}', 'B', (radius, direction), skip_congestion, inv_snap=inv_snap, parent_state=parent_state, max_cost=max_cost) process_move(
current, target, net_width, net_id, open_set, closed_set, context, metrics, congestion_cache,
f'B{radius}{direction}', 'B', (radius, direction), skip_congestion, inv_snap=inv_snap,
parent_state=parent_state, max_cost=max_cost, snap=snap, self_collision_check=self_collision_check
)
# 4. SBENDS # 4. SBENDS
max_sbend_r = max(self.config.sbend_radii) if self.config.sbend_radii else 0 max_sbend_r = max(context.config.sbend_radii) if context.config.sbend_radii else 0
if max_sbend_r > 0: if max_sbend_r > 0:
user_offsets = self.config.sbend_offsets user_offsets = context.config.sbend_offsets
offsets: set[float] = set(user_offsets) if user_offsets is not None else set() offsets: set[float] = set(user_offsets) if user_offsets is not None else set()
dx_local = (target.x - cp.x) * cos_v + (target.y - cp.y) * sin_v dx_local = (target.x - cp.x) * cos_v + (target.y - cp.y) * sin_v
dy_local = -(target.x - cp.x) * sin_v + (target.y - cp.y) * cos_v dy_local = -(target.x - cp.x) * sin_v + (target.y - cp.y) * cos_v
@ -318,19 +330,25 @@ class AStarRouter:
if abs(o) < 2 * max_sbend_r: offsets.add(o) if abs(o) < 2 * max_sbend_r: offsets.add(o)
for offset in sorted(offsets): for offset in sorted(offsets):
for radius in self.config.sbend_radii: for radius in context.config.sbend_radii:
if abs(offset) >= 2 * radius: continue if abs(offset) >= 2 * radius: continue
self._process_move(current, target, net_width, net_id, open_set, closed_set, snap, f'SB{offset}R{radius}', 'SB', (offset, radius), skip_congestion, inv_snap=inv_snap, parent_state=parent_state, max_cost=max_cost) process_move(
current, target, net_width, net_id, open_set, closed_set, context, metrics, congestion_cache,
f'SB{offset}R{radius}', 'SB', (offset, radius), skip_congestion, inv_snap=inv_snap,
parent_state=parent_state, max_cost=max_cost, snap=snap, self_collision_check=self_collision_check
)
def _process_move(
self, def process_move(
parent: AStarNode, parent: AStarNode,
target: Port, target: Port,
net_width: float, net_width: float,
net_id: str, net_id: str,
open_set: list[AStarNode], open_set: list[AStarNode],
closed_set: dict[tuple[int, int, int], float], closed_set: dict[tuple[int, int, int], float],
snap: float, context: AStarContext,
metrics: AStarMetrics,
congestion_cache: dict[tuple, int],
move_type: str, move_type: str,
move_class: Literal['S', 'B', 'SB'], move_class: Literal['S', 'B', 'SB'],
params: tuple, params: tuple,
@ -338,8 +356,13 @@ class AStarRouter:
inv_snap: float | None = None, inv_snap: float | None = None,
snap_to_grid: bool = True, snap_to_grid: bool = True,
parent_state: tuple[int, int, int] | None = None, parent_state: tuple[int, int, int] | None = None,
max_cost: float | None = None max_cost: float | None = None,
snap: float = 1.0,
self_collision_check: bool = False,
) -> None: ) -> None:
"""
Generate or retrieve geometry and delegate to add_node.
"""
cp = parent.port cp = parent.port
if inv_snap is None: inv_snap = 1.0 / snap if inv_snap is None: inv_snap = 1.0 / snap
base_ori = float(int(cp.orientation + 0.5)) base_ori = float(int(cp.orientation + 0.5))
@ -350,45 +373,55 @@ class AStarRouter:
parent_state = (gx, gy, go) parent_state = (gx, gy, go)
else: else:
gx, gy, go = parent_state gx, gy, go = parent_state
state_key = parent_state
abs_key = (state_key, move_class, params, net_width, self.config.bend_collision_type, snap_to_grid) abs_key = (parent_state, move_class, params, net_width, context.config.bend_collision_type, snap_to_grid)
if abs_key in self._move_cache: if abs_key in context.move_cache:
res = self._move_cache[abs_key] res = context.move_cache[abs_key]
move_radius = params[0] if move_class == 'B' else (params[1] if move_class == 'SB' else None) move_radius = params[0] if move_class == 'B' else (params[1] if move_class == 'SB' else None)
self._add_node(parent, res, target, net_width, net_id, open_set, closed_set, move_type, move_radius=move_radius, snap=snap, skip_congestion=skip_congestion, inv_snap=inv_snap, parent_state=parent_state, max_cost=max_cost) add_node(
parent, res, target, net_width, net_id, open_set, closed_set, context, metrics, congestion_cache,
move_type, move_radius=move_radius, snap=snap, skip_congestion=skip_congestion,
inv_snap=inv_snap, parent_state=parent_state, max_cost=max_cost,
self_collision_check=self_collision_check
)
return return
rel_key = (base_ori, move_class, params, net_width, self.config.bend_collision_type, self._self_dilation, snap_to_grid) self_dilation = context.cost_evaluator.collision_engine.clearance / 2.0
rel_key = (base_ori, move_class, params, net_width, context.config.bend_collision_type, self_dilation, snap_to_grid)
cache_key = (gx, gy, go, move_type, net_width) cache_key = (gx, gy, go, move_type, net_width)
if cache_key in self._hard_collision_set: if cache_key in context.hard_collision_set:
return return
if rel_key in self._move_cache: if rel_key in context.move_cache:
res_rel = self._move_cache[rel_key] res_rel = context.move_cache[rel_key]
else: else:
try: try:
p0 = Port(0, 0, base_ori) p0 = Port(0, 0, base_ori)
if move_class == 'S': if move_class == 'S':
res_rel = Straight.generate(p0, params[0], net_width, dilation=self._self_dilation, snap_to_grid=snap_to_grid, snap_size=snap) res_rel = Straight.generate(p0, params[0], net_width, dilation=self_dilation, snap_to_grid=snap_to_grid, snap_size=snap)
elif move_class == 'B': elif move_class == 'B':
res_rel = Bend90.generate(p0, params[0], net_width, params[1], collision_type=self.config.bend_collision_type, clip_margin=self.config.bend_clip_margin, dilation=self._self_dilation, snap_to_grid=snap_to_grid, snap_size=snap) res_rel = Bend90.generate(p0, params[0], net_width, params[1], collision_type=context.config.bend_collision_type, clip_margin=context.config.bend_clip_margin, dilation=self_dilation, snap_to_grid=snap_to_grid, snap_size=snap)
elif move_class == 'SB': elif move_class == 'SB':
res_rel = SBend.generate(p0, params[0], params[1], net_width, collision_type=self.config.bend_collision_type, clip_margin=self.config.bend_clip_margin, dilation=self._self_dilation, snap_to_grid=snap_to_grid, snap_size=snap) res_rel = SBend.generate(p0, params[0], params[1], net_width, collision_type=context.config.bend_collision_type, clip_margin=context.config.bend_clip_margin, dilation=self_dilation, snap_to_grid=snap_to_grid, snap_size=snap)
else: else:
return return
self._move_cache[rel_key] = res_rel context.move_cache[rel_key] = res_rel
except (ValueError, ZeroDivisionError): except (ValueError, ZeroDivisionError):
return return
res = res_rel.translate(cp.x, cp.y, rel_gx=res_rel.rel_gx + gx, rel_gy=res_rel.rel_gy + gy, rel_go=res_rel.rel_go) res = res_rel.translate(cp.x, cp.y, rel_gx=res_rel.rel_gx + gx, rel_gy=res_rel.rel_gy + gy, rel_go=res_rel.rel_go)
self._move_cache[abs_key] = res context.move_cache[abs_key] = res
move_radius = params[0] if move_class == 'B' else (params[1] if move_class == 'SB' else None) move_radius = params[0] if move_class == 'B' else (params[1] if move_class == 'SB' else None)
self._add_node(parent, res, target, net_width, net_id, open_set, closed_set, move_type, move_radius=move_radius, snap=snap, skip_congestion=skip_congestion, inv_snap=inv_snap, parent_state=parent_state, max_cost=max_cost) add_node(
parent, res, target, net_width, net_id, open_set, closed_set, context, metrics, congestion_cache,
move_type, move_radius=move_radius, snap=snap, skip_congestion=skip_congestion,
inv_snap=inv_snap, parent_state=parent_state, max_cost=max_cost,
self_collision_check=self_collision_check
)
def _add_node(
self, def add_node(
parent: AStarNode, parent: AStarNode,
result: ComponentResult, result: ComponentResult,
target: Port, target: Port,
@ -396,19 +429,26 @@ class AStarRouter:
net_id: str, net_id: str,
open_set: list[AStarNode], open_set: list[AStarNode],
closed_set: dict[tuple[int, int, int], float], closed_set: dict[tuple[int, int, int], float],
context: AStarContext,
metrics: AStarMetrics,
congestion_cache: dict[tuple, int],
move_type: str, move_type: str,
move_radius: float | None = None, move_radius: float | None = None,
snap: float = 1.0, snap: float = 1.0,
skip_congestion: bool = False, skip_congestion: bool = False,
inv_snap: float | None = None, inv_snap: float | None = None,
parent_state: tuple[int, int, int] | None = None, parent_state: tuple[int, int, int] | None = None,
max_cost: float | None = None max_cost: float | None = None,
self_collision_check: bool = False,
) -> None: ) -> None:
self.metrics['moves_generated'] += 1 """
Check collisions and costs, and add node to the open set.
"""
metrics.moves_generated += 1
state = (result.rel_gx, result.rel_gy, result.rel_go) state = (result.rel_gx, result.rel_gy, result.rel_go)
if state in closed_set and closed_set[state] <= parent.g_cost + 1e-6: if state in closed_set and closed_set[state] <= parent.g_cost + 1e-6:
self.metrics['pruned_closed_set'] += 1 metrics.pruned_closed_set += 1
return return
parent_p = parent.port parent_p = parent.port
@ -419,44 +459,45 @@ class AStarRouter:
pgx, pgy, pgo = parent_state pgx, pgy, pgo = parent_state
cache_key = (pgx, pgy, pgo, move_type, net_width) cache_key = (pgx, pgy, pgo, move_type, net_width)
if cache_key in self._hard_collision_set: if cache_key in context.hard_collision_set:
self.metrics['pruned_hard_collision'] += 1 metrics.pruned_hard_collision += 1
return return
new_g_cost = parent.g_cost + result.length new_g_cost = parent.g_cost + result.length
# Pre-check cost pruning before evaluation (using heuristic) # Pre-check cost pruning before evaluation (using heuristic)
if max_cost is not None: if max_cost is not None:
new_h_cost = self.cost_evaluator.h_manhattan(end_p, target) new_h_cost = context.cost_evaluator.h_manhattan(end_p, target)
if new_g_cost + new_h_cost > max_cost: if new_g_cost + new_h_cost > max_cost:
self.metrics['pruned_cost'] += 1 metrics.pruned_cost += 1
return return
is_static_safe = (cache_key in self._static_safe_cache) is_static_safe = (cache_key in context.static_safe_cache)
if not is_static_safe: if not is_static_safe:
ce = self.cost_evaluator.collision_engine ce = context.cost_evaluator.collision_engine
if 'S' in move_type and 'SB' not in move_type: if 'S' in move_type and 'SB' not in move_type:
if ce.check_move_straight_static(parent_p, result.length): if ce.check_move_straight_static(parent_p, result.length):
self._hard_collision_set.add(cache_key) context.hard_collision_set.add(cache_key)
self.metrics['pruned_hard_collision'] += 1 metrics.pruned_hard_collision += 1
return return
is_static_safe = True is_static_safe = True
if not is_static_safe: if not is_static_safe:
if ce.check_move_static(result, start_port=parent_p, end_port=end_p): if ce.check_move_static(result, start_port=parent_p, end_port=end_p):
self._hard_collision_set.add(cache_key) context.hard_collision_set.add(cache_key)
self.metrics['pruned_hard_collision'] += 1 metrics.pruned_hard_collision += 1
return return
else: self._static_safe_cache.add(cache_key) else: context.static_safe_cache.add(cache_key)
total_overlaps = 0 total_overlaps = 0
if not skip_congestion: if not skip_congestion:
if cache_key in self._congestion_cache: total_overlaps = self._congestion_cache[cache_key] if cache_key in congestion_cache:
total_overlaps = congestion_cache[cache_key]
else: else:
total_overlaps = self.cost_evaluator.collision_engine.check_move_congestion(result, net_id) total_overlaps = context.cost_evaluator.collision_engine.check_move_congestion(result, net_id)
self._congestion_cache[cache_key] = total_overlaps congestion_cache[cache_key] = total_overlaps
# SELF-COLLISION CHECK (Optional for performance) # SELF-COLLISION CHECK (Optional for performance)
if getattr(self, '_self_collision_check', False): if self_collision_check:
curr_p = parent curr_p = parent
new_tb = result.total_bounds new_tb = result.total_bounds
while curr_p and curr_p.parent: while curr_p and curr_p.parent:
@ -472,35 +513,115 @@ class AStarRouter:
curr_p = curr_p.parent curr_p = curr_p.parent
penalty = 0.0 penalty = 0.0
if 'SB' in move_type: penalty = self.config.sbend_penalty if 'SB' in move_type: penalty = context.config.sbend_penalty
elif 'B' in move_type: penalty = self.config.bend_penalty elif 'B' in move_type: penalty = context.config.bend_penalty
if move_radius is not None and move_radius > 1e-6: penalty *= (10.0 / move_radius)**0.5 if move_radius is not None and move_radius > 1e-6: penalty *= (10.0 / move_radius)**0.5
move_cost = self.cost_evaluator.evaluate_move( move_cost = context.cost_evaluator.evaluate_move(
None, result.end_port, net_width, net_id, None, result.end_port, net_width, net_id,
start_port=parent_p, length=result.length, start_port=parent_p, length=result.length,
dilated_geometry=None, penalty=penalty, dilated_geometry=None, penalty=penalty,
skip_static=True, skip_congestion=True skip_static=True, skip_congestion=True
) )
move_cost += total_overlaps * self.cost_evaluator.congestion_penalty move_cost += total_overlaps * context.cost_evaluator.congestion_penalty
if move_cost > 1e12: if move_cost > 1e12:
self.metrics['pruned_cost'] += 1 metrics.pruned_cost += 1
return return
g_cost = parent.g_cost + move_cost g_cost = parent.g_cost + move_cost
if state in closed_set and closed_set[state] <= g_cost + 1e-6: if state in closed_set and closed_set[state] <= g_cost + 1e-6:
self.metrics['pruned_closed_set'] += 1 metrics.pruned_closed_set += 1
return return
h_cost = self.cost_evaluator.h_manhattan(result.end_port, target) h_cost = context.cost_evaluator.h_manhattan(result.end_port, target)
heapq.heappush(open_set, AStarNode(result.end_port, g_cost, h_cost, parent, result)) heapq.heappush(open_set, AStarNode(result.end_port, g_cost, h_cost, parent, result))
self.metrics['moves_added'] += 1 metrics.moves_added += 1
def _reconstruct_path(self, end_node: AStarNode) -> list[ComponentResult]:
def reconstruct_path(end_node: AStarNode) -> list[ComponentResult]:
""" Trace back from end node to start node to get the path. """
path = [] path = []
curr: AStarNode | None = end_node curr: AStarNode | None = end_node
while curr and curr.component_result: while curr and curr.component_result:
path.append(curr.component_result) path.append(curr.component_result)
curr = curr.parent curr = curr.parent
return path[::-1] return path[::-1]
class AStarRouter:
"""
Waveguide router based on sparse A* search.
Wrapper around functional core.
"""
__slots__ = ('context', 'metrics')
def __init__(self, cost_evaluator: CostEvaluator, node_limit: int | None = None, **kwargs) -> None:
config = RouterConfig(sbend_radii=[5.0, 10.0, 50.0, 100.0])
if node_limit is not None:
config.node_limit = node_limit
for k, v in kwargs.items():
if hasattr(config, k):
setattr(config, k, v)
self.context = AStarContext(cost_evaluator, config)
self.metrics = AStarMetrics()
@property
def cost_evaluator(self): return self.context.cost_evaluator
@property
def config(self): return self.context.config
@property
def visibility_manager(self): return self.context.visibility_manager
@property
def node_limit(self): return self.context.config.node_limit
@node_limit.setter
def node_limit(self, value): self.context.config.node_limit = value
@property
def total_nodes_expanded(self): return self.metrics.total_nodes_expanded
@total_nodes_expanded.setter
def total_nodes_expanded(self, value): self.metrics.total_nodes_expanded = value
@property
def last_expanded_nodes(self): return self.metrics.last_expanded_nodes
@property
def metrics_dict(self): return self.metrics.get_summary_dict()
def reset_metrics(self) -> None:
""" Reset all performance counters. """
self.metrics.reset_per_route()
self.context.cost_evaluator.collision_engine.reset_metrics()
def get_metrics_summary(self) -> str:
""" Return a human-readable summary of search performance. """
m = self.metrics
c = self.context.cost_evaluator.collision_engine.get_metrics_summary()
return (f"Search Performance: \n"
f" Nodes Expanded: {m.nodes_expanded}\n"
f" Moves: Generated={m.moves_generated}, Added={m.moves_added}\n"
f" Pruning: ClosedSet={m.pruned_closed_set}, HardColl={m.pruned_hard_collision}, Cost={m.pruned_cost}\n"
f" {c}")
def route(
self,
start: Port,
target: Port,
net_width: float,
net_id: str = 'default',
bend_collision_type: Literal['arc', 'bbox', 'clipped_bbox'] | None = None,
return_partial: bool = False,
store_expanded: bool = False,
skip_congestion: bool = False,
max_cost: float | None = None,
self_collision_check: bool = False,
) -> list[ComponentResult] | None:
"""
Route a single net using A*. Delegates to route_astar.
"""
return route_astar(
start, target, net_width, self.context, self.metrics,
net_id=net_id, bend_collision_type=bend_collision_type,
return_partial=return_partial, store_expanded=store_expanded,
skip_congestion=skip_congestion, max_cost=max_cost,
self_collision_check=self_collision_check,
node_limit=self.context.config.node_limit
)