inire/inire/router/astar.py

506 lines
22 KiB
Python

from __future__ import annotations
import heapq
import logging
import functools
from typing import TYPE_CHECKING, Literal, Any
import rtree
import numpy
import shapely
from inire.geometry.components import Bend90, SBend, Straight, SEARCH_GRID_SNAP_UM, snap_search_grid
from inire.geometry.primitives import Port
from inire.router.config import RouterConfig
from inire.router.visibility import VisibilityManager
if TYPE_CHECKING:
from inire.geometry.components import ComponentResult
from inire.router.cost import CostEvaluator
logger = logging.getLogger(__name__)
class AStarNode:
    """
    A single search-tree entry for the A* router.

    Nodes order themselves for ``heapq``: a smaller ``f_cost`` wins, and when
    two f-costs agree to within 1e-6 the node with the smaller ``h_cost``
    (i.e. the one estimated to be closer to the target) wins.
    """
    __slots__ = ('port', 'g_cost', 'h_cost', 'f_cost', 'parent', 'component_result')

    def __init__(
        self,
        port: Port,
        g_cost: float,
        h_cost: float,
        parent: AStarNode | None = None,
        component_result: ComponentResult | None = None,
    ) -> None:
        """
        Args:
            port: Port reached by this node.
            g_cost: Accumulated cost from the start.
            h_cost: Heuristic estimate of the remaining cost.
            parent: Predecessor node, or None for the root.
            component_result: Move that produced this node, or None for the root.
        """
        self.port, self.g_cost, self.h_cost = port, g_cost, h_cost
        # Total priority used by the open-set heap.
        self.f_cost = g_cost + h_cost
        self.parent, self.component_result = parent, component_result

    def __lt__(self, other: AStarNode) -> bool:
        mine, theirs = self.f_cost, other.f_cost
        if mine < theirs - 1e-6:
            return True
        if mine > theirs + 1e-6:
            return False
        # f-costs tie within tolerance: prefer the node nearer the target.
        return self.h_cost < other.h_cost
class AStarRouter:
    """
    Waveguide router based on sparse A* search.

    Each expansion proposes a sparse set of candidate moves (straights of
    selected lengths, 90-degree bends and S-bends) instead of stepping through
    every grid cell. Search states are ports quantised to the snap grid.
    Generated move geometry and static-collision results are memoised on the
    instance and kept across ``route()`` calls; the congestion cache is
    cleared at the start of every ``route()``.
    """
    __slots__ = (
        'cost_evaluator', 'config', 'node_limit', 'visibility_manager',
        '_hard_collision_set', '_congestion_cache', '_static_safe_cache',
        '_move_cache', 'total_nodes_expanded', 'last_expanded_nodes', 'metrics',
        '_self_collision_check',
    )

    # Counters kept in ``self.metrics``.
    _METRIC_KEYS = (
        'nodes_expanded',
        'moves_generated',
        'moves_added',
        'pruned_closed_set',
        'pruned_hard_collision',
        'pruned_cost',
    )

    def __init__(self, cost_evaluator: CostEvaluator, node_limit: int | None = None, **kwargs) -> None:
        """
        Args:
            cost_evaluator: Supplies heuristics, move costs and the collision engine.
            node_limit: Optional override for ``RouterConfig.node_limit``.
            **kwargs: Overrides for any existing ``RouterConfig`` attribute.
                Names the config does not have are ignored with a warning.
        """
        self.cost_evaluator = cost_evaluator
        self.config = RouterConfig(sbend_radii=[5.0, 10.0, 50.0, 100.0])
        if node_limit is not None:
            self.config.node_limit = node_limit
        for key, value in kwargs.items():
            if hasattr(self.config, key):
                setattr(self.config, key, value)
            else:
                # Previously dropped silently, which hid misspelled options.
                logger.warning("AStarRouter: ignoring unknown config option %r", key)
        self.node_limit = self.config.node_limit
        # Visibility Manager for sparse jumps
        self.visibility_manager = VisibilityManager(self.cost_evaluator.collision_engine)
        # (gx, gy, go, move_type, net_width) keys of moves known to hit static obstacles.
        self._hard_collision_set: set[tuple] = set()
        # Same keys -> overlap count with other nets; reset per route().
        self._congestion_cache: dict[tuple, int] = {}
        # Same keys for moves known to be clear of static obstacles.
        self._static_safe_cache: set[tuple] = set()
        # Holds both origin-relative and grid-placed move geometry.
        self._move_cache: dict[tuple, ComponentResult] = {}
        self.total_nodes_expanded = 0
        self.last_expanded_nodes: list[tuple[float, float, float]] = []
        # Set per route(); initialised here so the slot is never unbound.
        self._self_collision_check = False
        self.metrics = dict.fromkeys(self._METRIC_KEYS, 0)

    def reset_metrics(self) -> None:
        """ Reset all performance counters. """
        for k in self.metrics:
            self.metrics[k] = 0
        self.cost_evaluator.collision_engine.reset_metrics()

    def get_metrics_summary(self) -> str:
        """ Return a human-readable summary of search performance. """
        m = self.metrics
        c = self.cost_evaluator.collision_engine.get_metrics_summary()
        return (f"Search Performance: \n"
                f" Nodes Expanded: {m['nodes_expanded']}\n"
                f" Moves: Generated={m['moves_generated']}, Added={m['moves_added']}\n"
                f" Pruning: ClosedSet={m['pruned_closed_set']}, HardColl={m['pruned_hard_collision']}, Cost={m['pruned_cost']}\n"
                f" {c}")

    @property
    def _self_dilation(self) -> float:
        """ Dilation applied to generated move geometry: half the engine clearance. """
        return self.cost_evaluator.collision_engine.clearance / 2.0

    @staticmethod
    def _grid_state(port: Port, snap: float) -> tuple[int, int, int]:
        """ Quantise a port to its (x, y, orientation) search-grid key. """
        return (int(round(port.x / snap)), int(round(port.y / snap)), int(round(port.orientation)))

    @staticmethod
    def _move_radius(move_class: str, params: tuple) -> float | None:
        """ Radius used to scale a move's penalty: bends/S-bends only, else None. """
        if move_class == 'B':
            return params[0]
        if move_class == 'SB':
            return params[1]
        return None

    def route(
        self,
        start: Port,
        target: Port,
        net_width: float,
        net_id: str = 'default',
        bend_collision_type: Literal['arc', 'bbox', 'clipped_bbox'] | None = None,
        return_partial: bool = False,
        store_expanded: bool = False,
        skip_congestion: bool = False,
        max_cost: float | None = None,
        self_collision_check: bool = False,
    ) -> list[ComponentResult] | None:
        """
        Route a single net using A*.

        Args:
            start: Starting port.
            target: Target port.
            net_width: Waveguide width.
            net_id: Identifier for the net.
            bend_collision_type: Type of collision model to use for bends. When
                given it is written into ``self.config`` and so persists for
                later calls.
            return_partial: If True, returns the best-effort path if target not reached.
            store_expanded: If True, keep track of all expanded nodes for visualization.
            skip_congestion: If True, ignore other nets' paths (greedy mode).
            max_cost: Hard limit on f_cost to prune search.
            self_collision_check: If True, prevent the net from crossing its own path.

        Returns:
            The components from start to target when the target port is reached.
            Otherwise (open set exhausted or node limit hit) the path to the
            node with the smallest heuristic if ``return_partial`` is True,
            which may be empty, else None.
        """
        self._self_collision_check = self_collision_check
        self._congestion_cache.clear()
        if store_expanded:
            self.last_expanded_nodes = []
        if bend_collision_type is not None:
            self.config.bend_collision_type = bend_collision_type
        self.cost_evaluator.set_target(target)

        snap = self.config.snap_size
        inv_snap = 1.0 / snap
        # (x_grid, y_grid, orientation_grid) -> g_cost at which it was expanded
        closed_set: dict[tuple[int, int, int], float] = {}
        start_node = AStarNode(start, 0.0, self.cost_evaluator.h_manhattan(start, target))
        open_set: list[AStarNode] = [start_node]
        best_node = start_node
        nodes_expanded = 0
        node_limit = self.node_limit

        while open_set:
            if nodes_expanded >= node_limit:
                break
            current = heapq.heappop(open_set)
            # Cost Pruning (Fail Fast)
            if max_cost is not None and current.f_cost > max_cost:
                self.metrics['pruned_cost'] += 1
                continue
            if current.h_cost < best_node.h_cost:
                best_node = current
            state = self._grid_state(current.port, snap)
            # Skip stale heap entries for states already expanded at least as cheaply.
            if state in closed_set and closed_set[state] <= current.g_cost + 1e-6:
                continue
            closed_set[state] = current.g_cost

            port = current.port
            if store_expanded:
                self.last_expanded_nodes.append((port.x, port.y, port.orientation))
            nodes_expanded += 1
            self.total_nodes_expanded += 1
            self.metrics['nodes_expanded'] += 1

            # Check if we reached the target exactly
            if (abs(port.x - target.x) < 1e-6 and
                    abs(port.y - target.y) < 1e-6 and
                    abs(port.orientation - target.orientation) < 0.1):
                return self._reconstruct_path(current)

            self._expand_moves(
                current, target, net_width, net_id, open_set, closed_set, snap, nodes_expanded,
                skip_congestion=skip_congestion, inv_snap=inv_snap, parent_state=state, max_cost=max_cost,
            )

        return self._reconstruct_path(best_node) if return_partial else None

    def _expand_moves(
        self,
        current: AStarNode,
        target: Port,
        net_width: float,
        net_id: str,
        open_set: list[AStarNode],
        closed_set: dict[tuple[int, int, int], float],
        snap: float = 1.0,
        nodes_expanded: int = 0,
        skip_congestion: bool = False,
        inv_snap: float | None = None,
        parent_state: tuple[int, int, int] | None = None,
        max_cost: float | None = None
    ) -> None:
        """
        Generate and enqueue every candidate move out of ``current``.

        Candidates, in the order they are tried:
          1. an exact (not grid-snapped) straight onto the target, when the
             target lies dead ahead with the same orientation and the ray
             cast reaches it;
          2. straights with lengths from ``_straight_lengths``, longest first;
          3. CW/CCW 90-degree bends for each ``bend_radii`` entry; when the
             squared distance to the target is at least 150**2, turns whose
             new heading is more than 135 degrees off the target bearing are
             skipped;
          4. S-bends for each offset from ``_sbend_offsets`` and each
             ``sbend_radii`` entry with ``abs(offset) < 2 * radius``.

        ``nodes_expanded`` is unused; it is kept for call compatibility.
        """
        cp = current.port
        if inv_snap is None:
            inv_snap = 1.0 / snap
        if parent_state is None:
            parent_state = self._grid_state(cp, snap)
        ce = self.cost_evaluator.collision_engine

        def try_move(move_type: str, move_class: Literal['S', 'B', 'SB'], params: tuple,
                     snap_to_grid: bool = True) -> None:
            # Forward one candidate with the arguments shared by every move.
            self._process_move(
                current, target, net_width, net_id, open_set, closed_set, snap,
                move_type, move_class, params, skip_congestion,
                inv_snap=inv_snap, snap_to_grid=snap_to_grid,
                parent_state=parent_state, max_cost=max_cost,
            )

        dx_t = target.x - cp.x
        dy_t = target.y - cp.y
        dist_sq = dx_t * dx_t + dy_t * dy_t
        rad = numpy.radians(cp.orientation)
        cos_v, sin_v = numpy.cos(rad), numpy.sin(rad)
        # Target offset in the port's frame: along (proj_t) and across (perp_t) the heading.
        proj_t = dx_t * cos_v + dy_t * sin_v
        perp_t = -dx_t * sin_v + dy_t * cos_v

        # 1. Direct straight jump onto the target.
        if proj_t > 0 and abs(perp_t) < 1e-3 and abs(cp.orientation - target.orientation) < 0.1:
            reach = ce.ray_cast(cp, cp.orientation, proj_t + 1.0)
            if reach >= proj_t - 0.01:
                try_move(f'S{proj_t}', 'S', (proj_t,), snap_to_grid=False)

        # 2. Straights sized from ray cast, visibility corners and target alignment.
        max_reach = ce.ray_cast(cp, cp.orientation, self.config.max_straight_length)
        for length in sorted(self._straight_lengths(cp, target, max_reach, cos_v, sin_v, snap), reverse=True):
            try_move(f'S{length}', 'S', (length,))

        # 3. 90-degree bends.
        angle_to_target = numpy.degrees(numpy.arctan2(dy_t, dx_t))
        allow_backwards = dist_sq < 150 * 150
        for radius in self.config.bend_radii:
            for direction in ('CW', 'CCW'):
                if not allow_backwards:
                    new_ori = (cp.orientation + (90 if direction == 'CCW' else -90)) % 360
                    # Signed heading error wrapped to [-180, 180).
                    if abs((angle_to_target - new_ori + 180) % 360 - 180) > 135:
                        continue
                try_move(f'B{radius}{direction}', 'B', (radius, direction))

        # 4. S-bends.
        sbend_radii = self.config.sbend_radii
        max_sbend_r = max(sbend_radii) if sbend_radii else 0
        if max_sbend_r > 0:
            for offset in sorted(self._sbend_offsets(proj_t, perp_t, max_sbend_r, snap)):
                for radius in sbend_radii:
                    if abs(offset) >= 2 * radius:
                        continue
                    try_move(f'SB{offset}R{radius}', 'SB', (offset, radius))

    def _straight_lengths(
        self,
        cp: Port,
        target: Port,
        max_reach: float,
        cos_v: float,
        sin_v: float,
        snap: float,
    ) -> set[float]:
        """
        Collect candidate straight lengths out of port ``cp``.

        ``max_reach`` is the free distance reported by the ray cast along the
        port heading, and ``cos_v``/``sin_v`` are that heading's unit vector.
        The set is filled in a fixed order; when equal values of different
        numeric types meet, the first one inserted is kept.
        """
        min_len = self.config.min_straight_length
        lengths: set[float] = set()
        # Run up to the obstruction, or stop short of it by a bend radius or by 5.0.
        if max_reach > min_len:
            lengths.add(snap_search_grid(max_reach, snap))
            for radius in self.config.bend_radii:
                if max_reach > radius + min_len:
                    lengths.add(snap_search_grid(max_reach - radius, snap))
            if max_reach > min_len + 5.0:
                lengths.add(snap_search_grid(max_reach - 5.0, snap))
        # Stop level with each visible obstacle corner (projected onto the heading).
        for cx, cy, _dist in self.visibility_manager.get_visible_corners(cp, max_dist=max_reach):
            proj = (cx - cp.x) * cos_v + (cy - cp.y) * sin_v
            if proj > min_len:
                lengths.add(snap_search_grid(proj, snap))
        lengths.add(min_len)
        if max_reach > min_len * 4:
            lengths.add(snap_search_grid(max_reach / 2.0, snap))
        # Align with the target along the travel axis, or stop one/two radii short of it.
        if abs(cp.orientation % 180) < 0.1:  # Horizontal
            target_dist = abs(target.x - cp.x)
        else:  # Vertical
            target_dist = abs(target.y - cp.y)
        if target_dist <= max_reach and target_dist > min_len:
            sl = snap_search_grid(target_dist, snap)
            if sl > 0.1:
                lengths.add(sl)
        for radius in self.config.bend_radii:
            for candidate in (target_dist - radius, target_dist - 2 * radius):
                if candidate > min_len:
                    s_l = snap_search_grid(candidate, snap)
                    if s_l <= max_reach and s_l > 0.1:
                        lengths.add(s_l)
        return lengths

    def _sbend_offsets(self, dx_local: float, dy_local: float, max_sbend_r: float, snap: float) -> set[float]:
        """
        Collect candidate lateral S-bend offsets.

        ``dx_local``/``dy_local`` are the target's offset along and across the
        current heading. User-configured ``sbend_offsets`` (if any) come
        first, then the exact offset onto the target when it fits, then (only
        when no user offsets are configured) a Fibonacci-like ladder of
        snap-scaled offsets on both sides.
        """
        user_offsets = self.config.sbend_offsets
        offsets: set[float] = set(user_offsets) if user_offsets is not None else set()
        if dx_local > 0 and abs(dy_local) < 2 * max_sbend_r:
            # sqrt(4*R*|dy| - dy**2) evaluated at R = |dy|/2: the forward length
            # of the tightest two-arc S-bend that reaches this lateral offset.
            min_d = numpy.sqrt(max(0, 4 * (abs(dy_local) / 2.0) * abs(dy_local) - dy_local ** 2))
            if dx_local >= min_d:
                offsets.add(dy_local)
        if user_offsets is None:
            for sign in (-1, 1):
                for step in (0.1, 0.2, 0.5, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144):
                    o = sign * step * snap
                    if abs(o) < 2 * max_sbend_r:
                        offsets.add(o)
        return offsets

    def _process_move(
        self,
        parent: AStarNode,
        target: Port,
        net_width: float,
        net_id: str,
        open_set: list[AStarNode],
        closed_set: dict[tuple[int, int, int], float],
        snap: float,
        move_type: str,
        move_class: Literal['S', 'B', 'SB'],
        params: tuple,
        skip_congestion: bool,
        inv_snap: float | None = None,
        snap_to_grid: bool = True,
        parent_state: tuple[int, int, int] | None = None,
        max_cost: float | None = None
    ) -> None:
        """
        Resolve one candidate move from ``parent`` into placed geometry and hand it to ``_add_node``.

        Geometry is looked up first in the absolute cache (move placed at this
        grid state), then built from the relative cache (move generated at the
        origin, then translated). Moves already known to collide with static
        obstacles, or that fail to generate, are dropped.
        """
        cp = parent.port
        if inv_snap is None:
            inv_snap = 1.0 / snap
        if parent_state is None:
            parent_state = self._grid_state(cp, snap)
        gx, gy, go = parent_state

        # Dilation is part of the key so a clearance change cannot serve stale geometry.
        # NOTE(review): keyed on the grid state, not exact coordinates; this assumes
        # ports sharing a grid state share their exact position.
        abs_key = (parent_state, move_class, params, net_width,
                   self.config.bend_collision_type, self._self_dilation, snap_to_grid)
        if abs_key in self._move_cache:
            res = self._move_cache[abs_key]
        else:
            if (gx, gy, go, move_type, net_width) in self._hard_collision_set:
                return
            base_ori = float(int(cp.orientation + 0.5))
            res_rel = self._relative_move(base_ori, move_class, params, net_width, snap, snap_to_grid)
            if res_rel is None:
                return
            res = res_rel.translate(cp.x, cp.y, rel_gx=res_rel.rel_gx + gx, rel_gy=res_rel.rel_gy + gy,
                                    rel_go=res_rel.rel_go)
            self._move_cache[abs_key] = res

        self._add_node(
            parent, res, target, net_width, net_id, open_set, closed_set, move_type,
            move_radius=self._move_radius(move_class, params), snap=snap,
            skip_congestion=skip_congestion, inv_snap=inv_snap,
            parent_state=parent_state, max_cost=max_cost,
        )

    def _relative_move(
        self,
        base_ori: float,
        move_class: str,
        params: tuple,
        net_width: float,
        snap: float,
        snap_to_grid: bool,
    ) -> ComponentResult | None:
        """ Return (cached) move geometry for a port at the origin facing ``base_ori``, or None if it cannot be generated. """
        cfg = self.config
        dilation = self._self_dilation
        rel_key = (base_ori, move_class, params, net_width, cfg.bend_collision_type, dilation, snap_to_grid)
        if rel_key in self._move_cache:
            return self._move_cache[rel_key]
        try:
            p0 = Port(0, 0, base_ori)
            if move_class == 'S':
                res_rel = Straight.generate(p0, params[0], net_width, dilation=dilation,
                                            snap_to_grid=snap_to_grid, snap_size=snap)
            elif move_class == 'B':
                res_rel = Bend90.generate(p0, params[0], net_width, params[1],
                                          collision_type=cfg.bend_collision_type,
                                          clip_margin=cfg.bend_clip_margin, dilation=dilation,
                                          snap_to_grid=snap_to_grid, snap_size=snap)
            elif move_class == 'SB':
                res_rel = SBend.generate(p0, params[0], params[1], net_width,
                                         collision_type=cfg.bend_collision_type,
                                         clip_margin=cfg.bend_clip_margin, dilation=dilation,
                                         snap_to_grid=snap_to_grid, snap_size=snap)
            else:
                return None
        except (ValueError, ZeroDivisionError):
            # Geometrically impossible move for these parameters.
            return None
        self._move_cache[rel_key] = res_rel
        return res_rel

    def _add_node(
        self,
        parent: AStarNode,
        result: ComponentResult,
        target: Port,
        net_width: float,
        net_id: str,
        open_set: list[AStarNode],
        closed_set: dict[tuple[int, int, int], float],
        move_type: str,
        move_radius: float | None = None,
        snap: float = 1.0,
        skip_congestion: bool = False,
        inv_snap: float | None = None,
        parent_state: tuple[int, int, int] | None = None,
        max_cost: float | None = None
    ) -> None:
        """
        Validate a placed move and, if it survives every check, push its child node onto ``open_set``.

        Checks in order: closed set against the parent's cost, known hard
        collision, ``max_cost`` bound, static collision, optional
        self-collision, evaluated cost above 1e12, and closed set against the
        child's cost. Each rejection bumps the matching ``self.metrics``
        counter (except self-collision, which is not counted).
        """
        metrics = self.metrics
        metrics['moves_generated'] += 1
        state = (result.rel_gx, result.rel_gy, result.rel_go)
        # Cheap early exit before any collision work: the child is presumed no
        # cheaper than its parent (non-negative move cost -- assumption).
        if state in closed_set and closed_set[state] <= parent.g_cost + 1e-6:
            metrics['pruned_closed_set'] += 1
            return

        parent_p = parent.port
        end_p = result.end_port
        if parent_state is None:
            parent_state = self._grid_state(parent_p, snap)
        cache_key = (*parent_state, move_type, net_width)
        if cache_key in self._hard_collision_set:
            metrics['pruned_hard_collision'] += 1
            return

        # Pre-check cost pruning before evaluation (using heuristic); keep h for reuse.
        h_cost = None
        if max_cost is not None:
            h_cost = self.cost_evaluator.h_manhattan(end_p, target)
            if parent.g_cost + result.length + h_cost > max_cost:
                metrics['pruned_cost'] += 1
                return

        if self._is_statically_blocked(cache_key, move_type, parent_p, end_p, result):
            metrics['pruned_hard_collision'] += 1
            return

        if skip_congestion:
            total_overlaps = 0
        elif cache_key in self._congestion_cache:
            total_overlaps = self._congestion_cache[cache_key]
        else:
            total_overlaps = self.cost_evaluator.collision_engine.check_move_congestion(result, net_id)
            self._congestion_cache[cache_key] = total_overlaps

        # SELF-COLLISION CHECK (Optional for performance)
        if self._self_collision_check and self._collides_with_own_path(parent, result):
            return

        if 'SB' in move_type:
            penalty = self.config.sbend_penalty
        elif 'B' in move_type:
            penalty = self.config.bend_penalty
        else:
            penalty = 0.0
        if move_radius is not None and move_radius > 1e-6:
            # Tighter radii cost more: scale by sqrt(10 / radius).
            penalty *= (10.0 / move_radius) ** 0.5

        move_cost = self.cost_evaluator.evaluate_move(
            None, end_p, net_width, net_id,
            start_port=parent_p, length=result.length,
            dilated_geometry=None, penalty=penalty,
            skip_static=True, skip_congestion=True,
        )
        move_cost += total_overlaps * self.cost_evaluator.congestion_penalty
        if move_cost > 1e12:
            metrics['pruned_cost'] += 1
            return

        g_cost = parent.g_cost + move_cost
        if state in closed_set and closed_set[state] <= g_cost + 1e-6:
            metrics['pruned_closed_set'] += 1
            return
        if h_cost is None:
            h_cost = self.cost_evaluator.h_manhattan(end_p, target)
        heapq.heappush(open_set, AStarNode(end_p, g_cost, h_cost, parent, result))
        metrics['moves_added'] += 1

    def _is_statically_blocked(
        self,
        cache_key: tuple,
        move_type: str,
        parent_p: Port,
        end_p: Port,
        result: ComponentResult,
    ) -> bool:
        """
        Return True if the move hits static obstacles, memoising the verdict.

        Blocked keys go into ``_hard_collision_set``; clear keys go into
        ``_static_safe_cache`` (straights included, so their check is not
        repeated on revisits).
        """
        if cache_key in self._static_safe_cache:
            return False
        ce = self.cost_evaluator.collision_engine
        if 'S' in move_type and 'SB' not in move_type:
            # Plain straights use the dedicated straight-line check.
            blocked = ce.check_move_straight_static(parent_p, result.length)
        else:
            blocked = ce.check_move_static(result, start_port=parent_p, end_port=end_p)
        if blocked:
            self._hard_collision_set.add(cache_key)
            return True
        self._static_safe_cache.add(cache_key)
        return False

    @staticmethod
    def _collides_with_own_path(parent: AStarNode, result: ComponentResult) -> bool:
        """ True if ``result`` overlaps (not merely touches) a component on the path ending at ``parent``. """
        new_tb = result.total_bounds
        node = parent
        while node and node.parent:
            ancestor = node.component_result
            if ancestor:
                anc_tb = ancestor.total_bounds
                # Bounding-box rejection before the exact polygon predicates.
                if (new_tb[0] < anc_tb[2] and new_tb[2] > anc_tb[0] and
                        new_tb[1] < anc_tb[3] and new_tb[3] > anc_tb[1]):
                    for p_anc in ancestor.geometry:
                        for p_new in result.geometry:
                            if p_new.intersects(p_anc) and not p_new.touches(p_anc):
                                return True
            node = node.parent
        return False

    def _reconstruct_path(self, end_node: AStarNode) -> list[ComponentResult]:
        """ Walk parent links from ``end_node`` and return its components in start-to-end order. """
        path = []
        node: AStarNode | None = end_node
        # The root node carries no component, which terminates the walk.
        while node and node.component_result:
            path.append(node.component_result)
            node = node.parent
        path.reverse()
        return path