from __future__ import annotations import heapq import logging import functools from typing import TYPE_CHECKING, Literal, Any import rtree import numpy import shapely from inire.geometry.components import Bend90, SBend, Straight, SEARCH_GRID_SNAP_UM, snap_search_grid from inire.geometry.primitives import Port from inire.router.config import RouterConfig from inire.router.visibility import VisibilityManager if TYPE_CHECKING: from inire.geometry.components import ComponentResult from inire.router.cost import CostEvaluator logger = logging.getLogger(__name__) class AStarNode: """ A node in the A* search tree. """ __slots__ = ('port', 'g_cost', 'h_cost', 'f_cost', 'parent', 'component_result') def __init__( self, port: Port, g_cost: float, h_cost: float, parent: AStarNode | None = None, component_result: ComponentResult | None = None, ) -> None: self.port = port self.g_cost = g_cost self.h_cost = h_cost self.f_cost = g_cost + h_cost self.parent = parent self.component_result = component_result def __lt__(self, other: AStarNode) -> bool: if self.f_cost < other.f_cost - 1e-6: return True if self.f_cost > other.f_cost + 1e-6: return False return self.h_cost < other.h_cost class AStarRouter: """ Waveguide router based on sparse A* search. """ __slots__ = ('cost_evaluator', 'config', 'node_limit', 'visibility_manager', '_hard_collision_set', '_congestion_cache', '_static_safe_cache', '_move_cache', 'total_nodes_expanded', 'last_expanded_nodes', 'metrics', '_self_collision_check') def __init__(self, cost_evaluator: CostEvaluator, node_limit: int | None = None, **kwargs) -> None: self.cost_evaluator = cost_evaluator self.config = RouterConfig(sbend_radii=[5.0, 10.0, 50.0, 100.0]) if node_limit is not None: self.config.node_limit = node_limit for k, v in kwargs.items(): if hasattr(self.config, k): setattr(self.config, k, v) self.node_limit = self.config.node_limit # Visibility Manager for sparse jumps self.visibility_manager = VisibilityManager(self.cost_evaluator.collision_engine) self._hard_collision_set: set[tuple] = set() self._congestion_cache: dict[tuple, int] = {} self._static_safe_cache: set[tuple] = set() self._move_cache: dict[tuple, ComponentResult] = {} self.total_nodes_expanded = 0 self.last_expanded_nodes: list[tuple[float, float, float]] = [] self.metrics = { 'nodes_expanded': 0, 'moves_generated': 0, 'moves_added': 0, 'pruned_closed_set': 0, 'pruned_hard_collision': 0, 'pruned_cost': 0 } def reset_metrics(self) -> None: """ Reset all performance counters. """ for k in self.metrics: self.metrics[k] = 0 self.cost_evaluator.collision_engine.reset_metrics() def get_metrics_summary(self) -> str: """ Return a human-readable summary of search performance. """ m = self.metrics c = self.cost_evaluator.collision_engine.get_metrics_summary() return (f"Search Performance: \n" f" Nodes Expanded: {m['nodes_expanded']}\n" f" Moves: Generated={m['moves_generated']}, Added={m['moves_added']}\n" f" Pruning: ClosedSet={m['pruned_closed_set']}, HardColl={m['pruned_hard_collision']}, Cost={m['pruned_cost']}\n" f" {c}") @property def _self_dilation(self) -> float: return self.cost_evaluator.collision_engine.clearance / 2.0 def route( self, start: Port, target: Port, net_width: float, net_id: str = 'default', bend_collision_type: Literal['arc', 'bbox', 'clipped_bbox'] | None = None, return_partial: bool = False, store_expanded: bool = False, skip_congestion: bool = False, max_cost: float | None = None, self_collision_check: bool = False, ) -> list[ComponentResult] | None: """ Route a single net using A*. Args: start: Starting port. target: Target port. net_width: Waveguide width. net_id: Identifier for the net. bend_collision_type: Type of collision model to use for bends. return_partial: If True, returns the best-effort path if target not reached. store_expanded: If True, keep track of all expanded nodes for visualization. skip_congestion: If True, ignore other nets' paths (greedy mode). max_cost: Hard limit on f_cost to prune search. self_collision_check: If True, prevent the net from crossing its own path. """ self._self_collision_check = self_collision_check self._congestion_cache.clear() if store_expanded: self.last_expanded_nodes = [] if bend_collision_type is not None: self.config.bend_collision_type = bend_collision_type self.cost_evaluator.set_target(target) open_set: list[AStarNode] = [] snap = self.config.snap_size inv_snap = 1.0 / snap # (x_grid, y_grid, orientation_grid) -> min_g_cost closed_set: dict[tuple[int, int, int], float] = {} start_node = AStarNode(start, 0.0, self.cost_evaluator.h_manhattan(start, target)) heapq.heappush(open_set, start_node) best_node = start_node nodes_expanded = 0 node_limit = self.node_limit while open_set: if nodes_expanded >= node_limit: return self._reconstruct_path(best_node) if return_partial else None current = heapq.heappop(open_set) # Cost Pruning (Fail Fast) if max_cost is not None and current.f_cost > max_cost: self.metrics['pruned_cost'] += 1 continue if current.h_cost < best_node.h_cost: best_node = current state = (int(round(current.port.x / snap)), int(round(current.port.y / snap)), int(round(current.port.orientation / 1.0))) if state in closed_set and closed_set[state] <= current.g_cost + 1e-6: continue closed_set[state] = current.g_cost if store_expanded: self.last_expanded_nodes.append((current.port.x, current.port.y, current.port.orientation)) nodes_expanded += 1 self.total_nodes_expanded += 1 self.metrics['nodes_expanded'] += 1 # Check if we reached the target exactly if (abs(current.port.x - target.x) < 1e-6 and abs(current.port.y - target.y) < 1e-6 and abs(current.port.orientation - target.orientation) < 0.1): return self._reconstruct_path(current) # Expansion self._expand_moves(current, target, net_width, net_id, open_set, closed_set, snap, nodes_expanded, skip_congestion=skip_congestion, inv_snap=inv_snap, parent_state=state, max_cost=max_cost) return self._reconstruct_path(best_node) if return_partial else None def _expand_moves( self, current: AStarNode, target: Port, net_width: float, net_id: str, open_set: list[AStarNode], closed_set: dict[tuple[int, int, int], float], snap: float = 1.0, nodes_expanded: int = 0, skip_congestion: bool = False, inv_snap: float | None = None, parent_state: tuple[int, int, int] | None = None, max_cost: float | None = None ) -> None: cp = current.port if inv_snap is None: inv_snap = 1.0 / snap if parent_state is None: parent_state = (int(round(cp.x / snap)), int(round(cp.y / snap)), int(round(cp.orientation / 1.0))) dx_t = target.x - cp.x dy_t = target.y - cp.y dist_sq = dx_t*dx_t + dy_t*dy_t rad = numpy.radians(cp.orientation) cos_v, sin_v = numpy.cos(rad), numpy.sin(rad) # 1. DIRECT JUMP TO TARGET proj_t = dx_t * cos_v + dy_t * sin_v perp_t = -dx_t * sin_v + dy_t * cos_v # A. Straight Jump if proj_t > 0 and abs(perp_t) < 1e-3 and abs(cp.orientation - target.orientation) < 0.1: max_reach = self.cost_evaluator.collision_engine.ray_cast(cp, cp.orientation, proj_t + 1.0) if max_reach >= proj_t - 0.01: self._process_move(current, target, net_width, net_id, open_set, closed_set, snap, f'S{proj_t}', 'S', (proj_t,), skip_congestion, inv_snap=inv_snap, snap_to_grid=False, parent_state=parent_state, max_cost=max_cost) # 2. VISIBILITY JUMPS & MAX REACH max_reach = self.cost_evaluator.collision_engine.ray_cast(cp, cp.orientation, self.config.max_straight_length) straight_lengths = set() if max_reach > self.config.min_straight_length: straight_lengths.add(snap_search_grid(max_reach, snap)) for radius in self.config.bend_radii: if max_reach > radius + self.config.min_straight_length: straight_lengths.add(snap_search_grid(max_reach - radius, snap)) if max_reach > self.config.min_straight_length + 5.0: straight_lengths.add(snap_search_grid(max_reach - 5.0, snap)) visible_corners = self.visibility_manager.get_visible_corners(cp, max_dist=max_reach) for cx, cy, dist in visible_corners: proj = (cx - cp.x) * cos_v + (cy - cp.y) * sin_v if proj > self.config.min_straight_length: straight_lengths.add(snap_search_grid(proj, snap)) straight_lengths.add(self.config.min_straight_length) if max_reach > self.config.min_straight_length * 4: straight_lengths.add(snap_search_grid(max_reach / 2.0, snap)) if abs(cp.orientation % 180) < 0.1: # Horizontal target_dist = abs(target.x - cp.x) if target_dist <= max_reach and target_dist > self.config.min_straight_length: sl = snap_search_grid(target_dist, snap) if sl > 0.1: straight_lengths.add(sl) for radius in self.config.bend_radii: for l in [target_dist - radius, target_dist - 2*radius]: if l > self.config.min_straight_length: s_l = snap_search_grid(l, snap) if s_l <= max_reach and s_l > 0.1: straight_lengths.add(s_l) else: # Vertical target_dist = abs(target.y - cp.y) if target_dist <= max_reach and target_dist > self.config.min_straight_length: sl = snap_search_grid(target_dist, snap) if sl > 0.1: straight_lengths.add(sl) for radius in self.config.bend_radii: for l in [target_dist - radius, target_dist - 2*radius]: if l > self.config.min_straight_length: s_l = snap_search_grid(l, snap) if s_l <= max_reach and s_l > 0.1: straight_lengths.add(s_l) for length in sorted(straight_lengths, reverse=True): self._process_move(current, target, net_width, net_id, open_set, closed_set, snap, f'S{length}', 'S', (length,), skip_congestion, inv_snap=inv_snap, parent_state=parent_state, max_cost=max_cost) # 3. BENDS & SBENDS angle_to_target = numpy.degrees(numpy.arctan2(target.y - cp.y, target.x - cp.x)) allow_backwards = (dist_sq < 150*150) for radius in self.config.bend_radii: for direction in ['CW', 'CCW']: if not allow_backwards: turn = 90 if direction == 'CCW' else -90 new_ori = (cp.orientation + turn) % 360 new_diff = (angle_to_target - new_ori + 180) % 360 - 180 if abs(new_diff) > 135: continue self._process_move(current, target, net_width, net_id, open_set, closed_set, snap, f'B{radius}{direction}', 'B', (radius, direction), skip_congestion, inv_snap=inv_snap, parent_state=parent_state, max_cost=max_cost) # 4. SBENDS max_sbend_r = max(self.config.sbend_radii) if self.config.sbend_radii else 0 if max_sbend_r > 0: user_offsets = self.config.sbend_offsets offsets: set[float] = set(user_offsets) if user_offsets is not None else set() dx_local = (target.x - cp.x) * cos_v + (target.y - cp.y) * sin_v dy_local = -(target.x - cp.x) * sin_v + (target.y - cp.y) * cos_v if dx_local > 0 and abs(dy_local) < 2 * max_sbend_r: min_d = numpy.sqrt(max(0, 4 * (abs(dy_local)/2.0) * abs(dy_local) - dy_local**2)) if dx_local >= min_d: offsets.add(dy_local) if user_offsets is None: for sign in [-1, 1]: for i in [0.1, 0.2, 0.5, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144]: o = sign * i * snap if abs(o) < 2 * max_sbend_r: offsets.add(o) for offset in sorted(offsets): for radius in self.config.sbend_radii: if abs(offset) >= 2 * radius: continue self._process_move(current, target, net_width, net_id, open_set, closed_set, snap, f'SB{offset}R{radius}', 'SB', (offset, radius), skip_congestion, inv_snap=inv_snap, parent_state=parent_state, max_cost=max_cost) def _process_move( self, parent: AStarNode, target: Port, net_width: float, net_id: str, open_set: list[AStarNode], closed_set: dict[tuple[int, int, int], float], snap: float, move_type: str, move_class: Literal['S', 'B', 'SB'], params: tuple, skip_congestion: bool, inv_snap: float | None = None, snap_to_grid: bool = True, parent_state: tuple[int, int, int] | None = None, max_cost: float | None = None ) -> None: cp = parent.port if inv_snap is None: inv_snap = 1.0 / snap base_ori = float(int(cp.orientation + 0.5)) if parent_state is None: gx = int(round(cp.x / snap)) gy = int(round(cp.y / snap)) go = int(round(cp.orientation / 1.0)) parent_state = (gx, gy, go) else: gx, gy, go = parent_state state_key = parent_state abs_key = (state_key, move_class, params, net_width, self.config.bend_collision_type, snap_to_grid) if abs_key in self._move_cache: res = self._move_cache[abs_key] move_radius = params[0] if move_class == 'B' else (params[1] if move_class == 'SB' else None) self._add_node(parent, res, target, net_width, net_id, open_set, closed_set, move_type, move_radius=move_radius, snap=snap, skip_congestion=skip_congestion, inv_snap=inv_snap, parent_state=parent_state, max_cost=max_cost) return rel_key = (base_ori, move_class, params, net_width, self.config.bend_collision_type, self._self_dilation, snap_to_grid) cache_key = (gx, gy, go, move_type, net_width) if cache_key in self._hard_collision_set: return if rel_key in self._move_cache: res_rel = self._move_cache[rel_key] else: try: p0 = Port(0, 0, base_ori) if move_class == 'S': res_rel = Straight.generate(p0, params[0], net_width, dilation=self._self_dilation, snap_to_grid=snap_to_grid, snap_size=snap) elif move_class == 'B': res_rel = Bend90.generate(p0, params[0], net_width, params[1], collision_type=self.config.bend_collision_type, clip_margin=self.config.bend_clip_margin, dilation=self._self_dilation, snap_to_grid=snap_to_grid, snap_size=snap) elif move_class == 'SB': res_rel = SBend.generate(p0, params[0], params[1], net_width, collision_type=self.config.bend_collision_type, clip_margin=self.config.bend_clip_margin, dilation=self._self_dilation, snap_to_grid=snap_to_grid, snap_size=snap) else: return self._move_cache[rel_key] = res_rel except (ValueError, ZeroDivisionError): return res = res_rel.translate(cp.x, cp.y, rel_gx=res_rel.rel_gx + gx, rel_gy=res_rel.rel_gy + gy, rel_go=res_rel.rel_go) self._move_cache[abs_key] = res move_radius = params[0] if move_class == 'B' else (params[1] if move_class == 'SB' else None) self._add_node(parent, res, target, net_width, net_id, open_set, closed_set, move_type, move_radius=move_radius, snap=snap, skip_congestion=skip_congestion, inv_snap=inv_snap, parent_state=parent_state, max_cost=max_cost) def _add_node( self, parent: AStarNode, result: ComponentResult, target: Port, net_width: float, net_id: str, open_set: list[AStarNode], closed_set: dict[tuple[int, int, int], float], move_type: str, move_radius: float | None = None, snap: float = 1.0, skip_congestion: bool = False, inv_snap: float | None = None, parent_state: tuple[int, int, int] | None = None, max_cost: float | None = None ) -> None: self.metrics['moves_generated'] += 1 state = (result.rel_gx, result.rel_gy, result.rel_go) if state in closed_set and closed_set[state] <= parent.g_cost + 1e-6: self.metrics['pruned_closed_set'] += 1 return parent_p = parent.port end_p = result.end_port if parent_state is None: pgx, pgy, pgo = int(round(parent_p.x / snap)), int(round(parent_p.y / snap)), int(round(parent_p.orientation / 1.0)) else: pgx, pgy, pgo = parent_state cache_key = (pgx, pgy, pgo, move_type, net_width) if cache_key in self._hard_collision_set: self.metrics['pruned_hard_collision'] += 1 return new_g_cost = parent.g_cost + result.length # Pre-check cost pruning before evaluation (using heuristic) if max_cost is not None: new_h_cost = self.cost_evaluator.h_manhattan(end_p, target) if new_g_cost + new_h_cost > max_cost: self.metrics['pruned_cost'] += 1 return is_static_safe = (cache_key in self._static_safe_cache) if not is_static_safe: ce = self.cost_evaluator.collision_engine if 'S' in move_type and 'SB' not in move_type: if ce.check_move_straight_static(parent_p, result.length): self._hard_collision_set.add(cache_key) self.metrics['pruned_hard_collision'] += 1 return is_static_safe = True if not is_static_safe: if ce.check_move_static(result, start_port=parent_p, end_port=end_p): self._hard_collision_set.add(cache_key) self.metrics['pruned_hard_collision'] += 1 return else: self._static_safe_cache.add(cache_key) total_overlaps = 0 if not skip_congestion: if cache_key in self._congestion_cache: total_overlaps = self._congestion_cache[cache_key] else: total_overlaps = self.cost_evaluator.collision_engine.check_move_congestion(result, net_id) self._congestion_cache[cache_key] = total_overlaps # SELF-COLLISION CHECK (Optional for performance) if getattr(self, '_self_collision_check', False): curr_p = parent new_tb = result.total_bounds while curr_p and curr_p.parent: ancestor_res = curr_p.component_result if ancestor_res: anc_tb = ancestor_res.total_bounds if (new_tb[0] < anc_tb[2] and new_tb[2] > anc_tb[0] and new_tb[1] < anc_tb[3] and new_tb[3] > anc_tb[1]): for p_anc in ancestor_res.geometry: for p_new in result.geometry: if p_new.intersects(p_anc) and not p_new.touches(p_anc): return curr_p = curr_p.parent penalty = 0.0 if 'SB' in move_type: penalty = self.config.sbend_penalty elif 'B' in move_type: penalty = self.config.bend_penalty if move_radius is not None and move_radius > 1e-6: penalty *= (10.0 / move_radius)**0.5 move_cost = self.cost_evaluator.evaluate_move( None, result.end_port, net_width, net_id, start_port=parent_p, length=result.length, dilated_geometry=None, penalty=penalty, skip_static=True, skip_congestion=True ) move_cost += total_overlaps * self.cost_evaluator.congestion_penalty if move_cost > 1e12: self.metrics['pruned_cost'] += 1 return g_cost = parent.g_cost + move_cost if state in closed_set and closed_set[state] <= g_cost + 1e-6: self.metrics['pruned_closed_set'] += 1 return h_cost = self.cost_evaluator.h_manhattan(result.end_port, target) heapq.heappush(open_set, AStarNode(result.end_port, g_cost, h_cost, parent, result)) self.metrics['moves_added'] += 1 def _reconstruct_path(self, end_node: AStarNode) -> list[ComponentResult]: path = [] curr: AStarNode | None = end_node while curr and curr.component_result: path.append(curr.component_result) curr = curr.parent return path[::-1]