# inire/inire/router/astar.py
# (source listing: 721 lines, 24 KiB, Python)

from __future__ import annotations
import heapq
import logging
import math
from typing import TYPE_CHECKING, Any, Literal
import shapely
from inire.constants import TOLERANCE_LINEAR
from inire.geometry.components import Bend90, SBend, Straight
from inire.geometry.primitives import Port
from inire.router.config import RouterConfig, VisibilityGuidanceMode
from inire.router.visibility import VisibilityManager
if TYPE_CHECKING:
    # Needed only for annotations: `from __future__ import annotations` keeps
    # every annotation in this module as an unevaluated string at runtime.
    from inire.geometry.components import ComponentResult
    from inire.router.cost import CostEvaluator
# Module-level logger.
logger = logging.getLogger(__name__)
class AStarNode:
    """One entry of the A* open set: a port reached at a known cost.

    Nodes order by ``fh_cost == (g + h, h)``, so `heapq` pops the lowest
    estimated total first and, among equal totals, the node whose heuristic
    estimate to the target is smallest.
    """

    __slots__ = ("port", "g_cost", "h_cost", "fh_cost", "parent", "component_result")

    def __init__(
        self,
        port: Port,
        g_cost: float,
        h_cost: float,
        parent: AStarNode | None = None,
        component_result: ComponentResult | None = None,
    ) -> None:
        # Search-tree linkage: the predecessor node and the move that led here.
        self.port = port
        self.parent = parent
        self.component_result = component_result
        # Cost so far (g), heuristic cost to go (h) and the derived heap key.
        self.g_cost = g_cost
        self.h_cost = h_cost
        estimated_total = g_cost + h_cost
        self.fh_cost = (estimated_total, h_cost)

    def __lt__(self, other: AStarNode) -> bool:
        """Compare by ``fh_cost``; this is the ordering `heapq` relies on."""
        mine, theirs = self.fh_cost, other.fh_cost
        return mine < theirs
class AStarMetrics:
    """Search statistics: one lifetime expansion total plus per-route counters."""

    __slots__ = (
        "total_nodes_expanded",
        "last_expanded_nodes",
        "nodes_expanded",
        "moves_generated",
        "moves_added",
        "pruned_closed_set",
        "pruned_hard_collision",
        "pruned_cost",
    )

    def __init__(self) -> None:
        # Accumulates over every route searched with this object; never reset.
        self.total_nodes_expanded = 0
        self.reset_per_route()

    def reset_per_route(self) -> None:
        """Zero the per-route counters and drop recorded states.

        ``total_nodes_expanded`` is left untouched.
        """
        self.nodes_expanded = self.moves_generated = self.moves_added = 0
        self.pruned_closed_set = self.pruned_hard_collision = self.pruned_cost = 0
        # Filled by route_astar with expanded states when store_expanded=True.
        self.last_expanded_nodes: list[tuple[int, int, int]] = []
class AStarContext:
    """Long-lived state shared by successive `route_astar` calls.

    Holds the router configuration, the cost evaluator, a visibility manager
    and several caches keyed by move parameters:

    * ``move_cache_rel`` - move geometry generated at the origin, keyed by
      orientation and move parameters;
    * ``move_cache_abs`` - that geometry translated to a concrete start port,
      trimmed by `check_cache_eviction` against ``max_cache_size``;
    * ``hard_collision_set`` / ``static_safe_cache`` - keys of moves already
      found to hit / clear static obstacles.
    """

    __slots__ = (
        "cost_evaluator",
        "config",
        "visibility_manager",
        "move_cache_rel",
        "move_cache_abs",
        "hard_collision_set",
        "static_safe_cache",
        "max_cache_size",
    )

    def __init__(
        self,
        cost_evaluator: CostEvaluator,
        node_limit: int = 1000000,
        max_straight_length: float = 2000.0,
        min_straight_length: float = 5.0,
        bend_radii: list[float] | None = None,
        sbend_radii: list[float] | None = None,
        sbend_offsets: list[float] | None = None,
        bend_penalty: float = 250.0,
        sbend_penalty: float | None = None,
        bend_collision_type: Literal["arc", "bbox", "clipped_bbox"] | Any = "arc",
        bend_clip_margin: float = 10.0,
        visibility_guidance: VisibilityGuidanceMode = "tangent_corner",
        max_cache_size: int = 1000000,
    ) -> None:
        """Build a `RouterConfig` from the arguments and install it.

        Defaults:

        * ``bend_radii`` is ``[50.0, 100.0]``.
        * ``sbend_radii`` is ``[5.0, 10.0, 50.0, 100.0]``.
        * ``sbend_penalty`` is twice ``bend_penalty``.

        Side effect: the new config replaces ``cost_evaluator.config``, and
        the evaluator's cached config is refreshed.
        """
        if sbend_penalty is None:
            sbend_penalty = 2.0 * bend_penalty
        if bend_radii is None:
            bend_radii = [50.0, 100.0]
        if sbend_radii is None:
            sbend_radii = [5.0, 10.0, 50.0, 100.0]

        self.cost_evaluator = cost_evaluator
        self.max_cache_size = max_cache_size
        self.config = RouterConfig(
            node_limit=node_limit,
            max_straight_length=max_straight_length,
            min_straight_length=min_straight_length,
            bend_radii=bend_radii,
            sbend_radii=sbend_radii,
            sbend_offsets=sbend_offsets,
            bend_penalty=bend_penalty,
            sbend_penalty=sbend_penalty,
            bend_collision_type=bend_collision_type,
            bend_clip_margin=bend_clip_margin,
            visibility_guidance=visibility_guidance,
        )
        # Keep the evaluator and the router reading the same config object.
        cost_evaluator.config = self.config
        cost_evaluator._refresh_cached_config()

        self.visibility_manager = VisibilityManager(cost_evaluator.collision_engine)
        self.move_cache_rel: dict[tuple, ComponentResult] = {}
        self.move_cache_abs: dict[tuple, ComponentResult] = {}
        self.hard_collision_set: set[tuple] = set()
        self.static_safe_cache: set[tuple] = set()

    def clear_static_caches(self) -> None:
        """Forget cached static-collision verdicts and visibility data.

        Presumably needed whenever the static obstacles change.
        """
        for verdicts in (self.hard_collision_set, self.static_safe_cache):
            verdicts.clear()
        self.visibility_manager.clear_cache()

    def check_cache_eviction(self) -> None:
        """Trim ``move_cache_abs`` once it outgrows its budget.

        Nothing happens until the cache holds more than
        ``1.2 * max_cache_size`` entries. Past that point, the
        earliest-inserted quarter of the entries (dicts preserve insertion
        order) is dropped.
        """
        cache = self.move_cache_abs
        size = len(cache)
        if size <= self.max_cache_size * 1.2:
            return
        # Collect the victims first: a dict cannot shrink while iterated.
        doomed = [key for key, _ in zip(cache, range(int(size * 0.25)))]
        for key in doomed:
            del cache[key]
def route_astar(
    start: Port,
    target: Port,
    net_width: float,
    context: AStarContext,
    metrics: AStarMetrics | None = None,
    net_id: str = "default",
    bend_collision_type: Literal["arc", "bbox", "clipped_bbox"] | None = None,
    return_partial: bool = False,
    store_expanded: bool = False,
    skip_congestion: bool = False,
    max_cost: float | None = None,
    self_collision_check: bool = False,
    node_limit: int | None = None,
) -> list[ComponentResult] | None:
    """Search for a path from ``start`` to ``target`` with A*.

    Args:
        start: Port the route leaves from.
        target: Port the route must reach exactly (``port == target``).
        net_width: Width of the routed net, forwarded to geometry and
            collision checks.
        context: Shared configuration and caches.
        metrics: Statistics sink. A fresh one is used when None; its
            per-route counters are reset either way.
        net_id: Net identifier forwarded to move expansion.
        bend_collision_type: If given, written into
            ``context.config.bend_collision_type`` before searching.
        return_partial: On failure, return the path to the popped node with
            the smallest heuristic instead of None.
        store_expanded: Record every expanded state in
            ``metrics.last_expanded_nodes``.
        skip_congestion: Forwarded to move expansion.
        max_cost: Nodes whose ``g + h`` exceeds this are discarded.
        self_collision_check: Forwarded to move expansion.
        node_limit: Maximum number of expansions. Defaults to
            ``context.config.node_limit``.

    Returns:
        The components from ``start`` to ``target``. On failure (open set
        exhausted or expansion limit reached), returns the partial path
        described above, or None.
    """
    stats = metrics if metrics is not None else AStarMetrics()
    stats.reset_per_route()
    # NOTE(review): the override is stored in the shared context config and is
    # not restored afterwards, so it persists into later calls passing None.
    if bend_collision_type is not None:
        context.config.bend_collision_type = bend_collision_type
    evaluator = context.cost_evaluator
    evaluator.set_target(target)

    limit = context.config.node_limit if node_limit is None else node_limit
    root = AStarNode(start, 0.0, evaluator.h_manhattan(start, target))
    frontier: list[AStarNode] = [root]  # a one-element list is already a heap
    # Best g recorded per expanded state; later pops that cannot beat it are stale.
    settled: dict[tuple[int, int, int], float] = {}
    congestion_cache: dict[tuple, int] = {}
    closest = root
    expanded = 0

    while frontier and expanded < limit:
        node = heapq.heappop(frontier)
        if max_cost is not None and node.fh_cost[0] > max_cost:
            stats.pruned_cost += 1
            continue
        if node.h_cost < closest.h_cost:
            closest = node

        state = node.port.as_tuple()
        settled_g = settled.get(state)
        if settled_g is not None and settled_g <= node.g_cost + TOLERANCE_LINEAR:
            continue
        settled[state] = node.g_cost

        if store_expanded:
            stats.last_expanded_nodes.append(state)
        expanded += 1
        stats.total_nodes_expanded += 1
        stats.nodes_expanded += 1

        if node.port == target:
            return reconstruct_path(node)
        expand_moves(
            node,
            target,
            net_width,
            net_id,
            frontier,
            settled,
            context,
            stats,
            congestion_cache,
            max_cost=max_cost,
            skip_congestion=skip_congestion,
            self_collision_check=self_collision_check,
        )

    return reconstruct_path(closest) if return_partial else None
def _quantized_lengths(values: list[float], max_reach: float) -> list[int]:
    """Snap candidate straight lengths to distinct positive integers within reach.

    Values that are non-positive, or longer than ``max_reach`` plus a 0.01
    float tolerance, are discarded. The rest are rounded to the nearest
    integer. A value just under the limit could round *up* past it (a reach
    of 10.6 rounds to 11). That would request a straight longer than the
    reach `expand_moves` obtained from its ray cast. Such values are clamped
    to the largest integer inside the limit instead. Lengths that end up at
    0 are dropped.

    Returns:
        The distinct lengths, longest first.
    """
    limit = max_reach + 0.01
    lengths: set[int] = set()
    for value in values:
        if not 0 < value <= limit:
            continue
        snapped = int(round(value))
        if snapped > limit:
            # Rounding overshot the reach; fall back to the longest fitting integer.
            snapped = math.floor(limit)
        if snapped > 0:
            lengths.add(snapped)
    return sorted(lengths, reverse=True)
def _sbend_forward_span(offset: float, radius: float) -> float | None:
abs_offset = abs(offset)
if abs_offset <= TOLERANCE_LINEAR or radius <= 0 or abs_offset >= 2.0 * radius:
return None
theta = __import__("math").acos(1.0 - abs_offset / (2.0 * radius))
return 2.0 * radius * __import__("math").sin(theta)
def _exact_corner_lengths(
    current: Port,
    context: AStarContext,
    max_reach: float,
    cos_v: float,
    sin_v: float,
) -> list[float]:
    """Straight lengths that bring ``current`` abreast of the nearest visible corners.

    The visibility manager is queried within ``max_reach`` plus the largest
    bend radius. The 12 corners with the smallest third tuple element
    (presumably their distance) are kept. Each corner is projected onto the
    heading ``(cos_v, sin_v)``; projections beyond ``min_straight_length``
    are rounded to integers.
    """
    cfg = context.config
    search_radius = max_reach + max(cfg.bend_radii, default=0.0)
    corners = context.visibility_manager.get_corner_visibility(current, max_dist=search_radius)
    nearest = sorted(corners, key=lambda corner: corner[2])[:12]
    lengths: set[int] = set()
    for cx, cy, _ in nearest:
        along = (cx - current.x) * cos_v + (cy - current.y) * sin_v
        if along <= cfg.min_straight_length:
            continue
        lengths.add(int(round(along)))
    return sorted(lengths, reverse=True)


def _tangent_corner_lengths(
    current: Port,
    context: AStarContext,
    max_reach: float,
    cos_v: float,
    sin_v: float,
    net_width: float,
) -> list[float]:
    """Straight lengths after which a 90-degree bend would end beside a corner.

    Steps:

    1. Query the corner index over a square of half-width
       ``max_reach + largest bend radius`` around ``current``.
    2. Express each corner in the port's local frame.
    3. Keep a corner when it is ahead (``min_straight_length < along <= reach + 0.01``)
       and its lateral offset is within 2.0 of some bend radius ``R``.
    4. Its candidate straight is ``along - R``. That length must exceed
       ``min_straight_length`` and not exceed ``max_reach + 0.01``.
    5. Of the best four candidates (smallest tangent error, then distance),
       keep only those whose corner a ray cast from ``current`` actually
       reaches, within 0.01.
    """
    cfg = context.config
    vm = context.visibility_manager
    vm._ensure_current()
    largest_radius = max(cfg.bend_radii, default=0.0)
    if largest_radius <= 0 or not vm.corners:
        return []

    reach = max_reach + largest_radius
    box = (current.x - reach, current.y - reach, current.x + reach, current.y + reach)
    ranked: list[tuple[float, float, float, float, float]] = []
    for idx in list(vm.corner_index.intersection(box)):
        cx, cy = vm.corners[idx]
        dx = cx - current.x
        dy = cy - current.y
        along = dx * cos_v + dy * sin_v
        across = -dx * sin_v + dy * cos_v
        if along <= cfg.min_straight_length or along > reach + 0.01:
            continue
        lateral = abs(across)
        radius = min(cfg.bend_radii, key=lambda r: abs(lateral - r))
        tangent_error = abs(lateral - radius)
        if tangent_error > 2.0:
            continue
        length = along - radius
        if length <= cfg.min_straight_length or length > max_reach + 0.01:
            continue
        ranked.append((tangent_error, math.hypot(dx, dy), length, dx, dy))

    if not ranked:
        return []

    engine = context.cost_evaluator.collision_engine
    lengths: set[int] = set()
    for _, dist, length, dx, dy in sorted(ranked)[:4]:
        heading = math.degrees(math.atan2(dy, dx))
        clear = engine.ray_cast(current, heading, max_dist=dist + 0.05, net_width=net_width)
        if clear < dist - 0.01:
            # Something blocks the line of sight to this corner.
            continue
        snapped = int(round(length))
        if snapped > 0:
            lengths.add(snapped)
    return sorted(lengths, reverse=True)


def _visible_straight_candidates(
    current: Port,
    context: AStarContext,
    max_reach: float,
    cos_v: float,
    sin_v: float,
    net_width: float,
) -> list[float]:
    """Suggest straight lengths derived from nearby obstacle corners.

    Dispatches on ``context.config.visibility_guidance``:

    * ``"exact_corner"`` uses `_exact_corner_lengths`.
    * ``"tangent_corner"`` uses `_tangent_corner_lengths`.
    * ``"off"`` or any other value yields no suggestions.

    ``(cos_v, sin_v)`` is the unit vector of ``current``'s heading. The
    result holds distinct integer lengths, longest first.
    """
    mode = context.config.visibility_guidance
    if mode == "exact_corner":
        return _exact_corner_lengths(current, context, max_reach, cos_v, sin_v)
    if mode == "tangent_corner":
        return _tangent_corner_lengths(current, context, max_reach, cos_v, sin_v, net_width)
    return []
def _previous_move_metadata(node: AStarNode) -> tuple[str | None, float | None]:
    """Describe the move that produced ``node``.

    Returns:
        ``(move_type, length)``. ``length`` is set only when the move was a
        ``"Straight"``. Both are None for a node without a component (the
        search root).
    """
    res = node.component_result
    if res is not None:
        kind = res.move_type
        return kind, (res.length if kind == "Straight" else None)
    return None, None
def expand_moves(
    current: AStarNode,
    target: Port,
    net_width: float,
    net_id: str,
    open_set: list[AStarNode],
    closed_set: dict[tuple[int, int, int], float],
    context: AStarContext,
    metrics: AStarMetrics,
    congestion_cache: dict[tuple, int],
    max_cost: float | None = None,
    skip_congestion: bool = False,
    self_collision_check: bool = False,
) -> None:
    """Generate every candidate move out of ``current`` and pass each to `process_move`.

    Candidates are tried in this order:

    1. A straight landing exactly on ``target``. Offered when the target lies
       dead ahead with the same orientation and a ray cast finds the way clear.
    2. Straights of assorted lengths, snapped to integers by
       `_quantized_lengths`. Sources: the minimum length, the ray-cast reach
       and fractions of it, distances that line a bend or S-bend up with the
       target, and visibility-guided lengths. Directly after a straight, only
       straights shorter by more than ``TOLERANCE_LINEAR`` are offered.
    3. 90-degree bends for every configured radius, in both directions. When
       the target is 150 units away or more, a bend whose new heading points
       more than 135 degrees away from the bearing to the target is skipped.
    4. S-bends, never directly after another S-bend. Offsets come from the
       configured ``sbend_offsets``. When the target is ahead with the same
       orientation and close enough, the exact lateral offset to it is added.

    Surviving moves are pushed onto ``open_set`` by `process_move`; nothing is
    returned.
    """
    config = context.config
    collision_engine = context.cost_evaluator.collision_engine
    cp = current.port
    prev_move_type, prev_straight_length = _previous_move_metadata(current)

    def emit(move_class: Literal["S", "B", "SB"], params: tuple) -> None:
        # Every candidate shares the same search state; only the move differs.
        process_move(
            current,
            target,
            net_width,
            net_id,
            open_set,
            closed_set,
            context,
            metrics,
            congestion_cache,
            move_class,
            params,
            skip_congestion,
            max_cost=max_cost,
            self_collision_check=self_collision_check,
        )

    # Unit vector of the current heading. Orientations other than 0/90/180
    # are treated as 270.
    if cp.r == 0:
        cos_v, sin_v = 1.0, 0.0
    elif cp.r == 90:
        cos_v, sin_v = 0.0, 1.0
    elif cp.r == 180:
        cos_v, sin_v = -1.0, 0.0
    else:
        cos_v, sin_v = 0.0, -1.0

    dx_t = target.x - cp.x
    dy_t = target.y - cp.y
    # Target offset in the port's local frame: ``ahead`` along the heading,
    # ``lateral`` perpendicular to it (positive to the left).
    ahead = dx_t * cos_v + dy_t * sin_v
    lateral = -dx_t * sin_v + dy_t * cos_v
    same_orientation = cp.r == target.r

    # 1. Direct straight onto the target.
    if ahead > 0 and abs(lateral) < 1e-6 and same_orientation:
        reach_to_target = collision_engine.ray_cast(cp, cp.r, ahead + 1.0, net_width=net_width)
        if reach_to_target >= ahead - 0.01 and (
            prev_straight_length is None or ahead < prev_straight_length - TOLERANCE_LINEAR
        ):
            emit("S", (int(round(ahead)),))

    # 2. Straights of heuristic lengths, bounded by the clear distance ahead.
    max_reach = collision_engine.ray_cast(cp, cp.r, config.max_straight_length, net_width=net_width)
    axis_target_dist = abs(dx_t) if cp.r in (0, 180) else abs(dy_t)
    candidate_lengths = [
        config.min_straight_length,
        max_reach,
        max_reach / 2.0,
        max_reach - 5.0,
        axis_target_dist,
    ]
    for radius in config.bend_radii:
        # Leave room for one or two bends of this radius.
        candidate_lengths.extend(
            (max_reach - radius, axis_target_dist - radius, axis_target_dist - 2.0 * radius)
        )
    candidate_lengths.extend(
        _visible_straight_candidates(cp, context, max_reach, cos_v, sin_v, net_width)
    )
    if same_orientation and ahead > 0 and abs(lateral) > TOLERANCE_LINEAR:
        for radius in config.sbend_radii:
            span = _sbend_forward_span(lateral, radius)
            if span is not None:
                # Stop where one (or two) S-bends of this radius would reach the target.
                candidate_lengths.extend((ahead - span, ahead - 2.0 * span))
    for length in _quantized_lengths(candidate_lengths, max_reach):
        if length < config.min_straight_length:
            continue
        if prev_straight_length is not None and length >= prev_straight_length - TOLERANCE_LINEAR:
            continue
        emit("S", (length,))

    # 3. 90-degree bends.
    angle_to_target = 0.0
    if dx_t != 0 or dy_t != 0:
        bearing = round(math.degrees(math.atan2(dy_t, dx_t)))
        angle_to_target = float((bearing + 360.0) % 360.0)
    allow_backwards = dx_t * dx_t + dy_t * dy_t < 150 * 150
    for radius in config.bend_radii:
        for direction in ("CW", "CCW"):
            if not allow_backwards:
                new_ori = (cp.r + (90 if direction == "CCW" else -90)) % 360
                heading_error = (angle_to_target - new_ori + 180.0) % 360.0 - 180.0
                if abs(heading_error) > 135.0:
                    continue
            emit("B", (radius, direction))

    # 4. S-bends.
    sbend_radii = config.sbend_radii
    max_sbend_r = max(sbend_radii) if sbend_radii else 0.0
    if max_sbend_r <= 0 or prev_move_type == "SBend":
        return
    offsets = {int(round(v)) for v in config.sbend_offsets or ()}
    # S-bends preserve orientation, so the implicit search only makes sense
    # when the target is ahead in local coordinates and keeps the same
    # orientation. Generating generic speculative offsets on the integer lattice
    # explodes the search space without contributing useful moves.
    if same_orientation and 0 < ahead <= 4 * max_sbend_r and 0 < abs(lateral) < 2 * max_sbend_r:
        offsets.add(int(round(lateral)))
    for offset in sorted(offsets):
        if offset == 0:
            continue
        for radius in sbend_radii:
            # |offset| >= 2R has no S-bend solution (cf. _sbend_forward_span).
            if abs(offset) < 2 * radius:
                emit("SB", (offset, radius))
def process_move(
    parent: AStarNode,
    target: Port,
    net_width: float,
    net_id: str,
    open_set: list[AStarNode],
    closed_set: dict[tuple[int, int, int], float],
    context: AStarContext,
    metrics: AStarMetrics,
    congestion_cache: dict[tuple, int],
    move_class: Literal["S", "B", "SB"],
    params: tuple,
    skip_congestion: bool,
    max_cost: float | None = None,
    self_collision_check: bool = False,
) -> None:
    """Build (or fetch) the geometry of one move from ``parent`` and hand it to `add_node`.

    ``move_class`` selects the component:

    * ``"S"``: `Straight`, with ``params == (length,)``.
    * ``"B"``: `Bend90`, with ``params == (radius, direction)``.
    * anything else: `SBend`, with ``params == (offset, radius)``.

    Geometry is generated at the origin once per orientation, move parameters,
    width and collision settings, and cached in ``context.move_cache_rel``.
    Translated copies are cached per start port in ``context.move_cache_abs``.
    A move whose generator raises ``ValueError`` is silently dropped.
    """
    origin = parent.port
    cfg = context.config
    collision_type = cfg.bend_collision_type
    # NOTE(review): custom polygon shapes are keyed by id(). If such a polygon
    # is freed and a different one later lands at the same address, cached
    # geometry and collision verdicts could be reused for the wrong shape.
    if isinstance(collision_type, shapely.geometry.Polygon):
        type_key = id(collision_type)
    else:
        type_key = collision_type
    dilation = context.cost_evaluator.collision_engine.clearance / 2.0
    # Everything except the start position; shared by both cache keys.
    move_signature = (move_class, params, net_width, type_key, cfg.bend_clip_margin, dilation)
    abs_key = (origin.as_tuple(), *move_signature)

    res = context.move_cache_abs.get(abs_key)
    if res is None:
        context.check_cache_eviction()
        rel_key = (origin.r, *move_signature)
        res_rel = context.move_cache_rel.get(rel_key)
        if res_rel is None:
            base = Port(0, 0, origin.r)
            try:
                if move_class == "S":
                    res_rel = Straight.generate(base, params[0], net_width, dilation=dilation)
                elif move_class == "B":
                    res_rel = Bend90.generate(
                        base,
                        params[0],
                        net_width,
                        params[1],
                        collision_type=collision_type,
                        clip_margin=cfg.bend_clip_margin,
                        dilation=dilation,
                    )
                else:
                    res_rel = SBend.generate(
                        base,
                        params[0],
                        params[1],
                        net_width,
                        collision_type=collision_type,
                        clip_margin=cfg.bend_clip_margin,
                        dilation=dilation,
                    )
            except ValueError:
                return
            context.move_cache_rel[rel_key] = res_rel
        res = res_rel.translate(origin.x, origin.y)
        context.move_cache_abs[abs_key] = res

    if move_class == "B":
        move_radius = params[0]
    elif move_class == "SB":
        move_radius = params[1]
    else:
        move_radius = None

    add_node(
        parent,
        res,
        target,
        net_width,
        net_id,
        open_set,
        closed_set,
        context,
        metrics,
        congestion_cache,
        move_class,
        abs_key,
        move_radius=move_radius,
        skip_congestion=skip_congestion,
        max_cost=max_cost,
        self_collision_check=self_collision_check,
    )
def _overlaps_ancestor_path(parent: AStarNode, result: ComponentResult) -> bool:
    """Return True if ``result`` overlaps (beyond touching) a move already on ``parent``'s path.

    Ancestors are scanned from ``parent`` back towards the root; the root
    itself (a node with no parent) is not checked. Bounds are compared as
    ``(minx, miny, maxx, maxy)`` boxes before the per-polygon test.
    """
    new_box = result.total_bounds
    node = parent
    while node and node.parent:
        earlier = node.component_result
        if earlier:
            old_box = earlier.total_bounds
            if (
                new_box[0] < old_box[2]
                and new_box[2] > old_box[0]
                and new_box[1] < old_box[3]
                and new_box[3] > old_box[1]
            ):
                for old_poly in earlier.geometry:
                    for new_poly in result.geometry:
                        if new_poly.intersects(old_poly) and not new_poly.touches(old_poly):
                            return True
        node = node.parent
    return False


def add_node(
    parent: AStarNode,
    result: ComponentResult,
    target: Port,
    net_width: float,
    net_id: str,
    open_set: list[AStarNode],
    closed_set: dict[tuple[int, int, int], float],
    context: AStarContext,
    metrics: AStarMetrics,
    congestion_cache: dict[tuple, int],
    move_type: str,
    cache_key: tuple,
    move_radius: float | None = None,
    skip_congestion: bool = False,
    max_cost: float | None = None,
    self_collision_check: bool = False,
) -> None:
    """Validate and cost one generated move, pushing a new node if it survives.

    Stages, in order:

    1. Early closed-set prune. The move's length is used as a lower bound on
       its cost.
    2. Known hard collision: ``cache_key`` is in ``hard_collision_set``.
    3. Static collision check, skipped for keys in ``static_safe_cache``.
       The verdict is cached in one of the two sets.
    4. Congestion overlap count, cached per ``cache_key`` unless
       ``skip_congestion``.
    5. Optional self-collision check against the path so far. A hit drops the
       move without updating any metric.
    6. Cost evaluation. Bend and S-bend penalties are scaled by
       ``sqrt(10 / radius)``, and congestion overlaps are added at the
       evaluator's congestion penalty.
    7. Cost limits: ``max_cost`` on g, and an absolute 1e12 cut-off.
    8. Final closed-set prune with the real g, then push onto ``open_set``.
    """
    metrics.moves_generated += 1
    end_port = result.end_port
    state = end_port.as_tuple()

    closed_g = closed_set.get(state)
    if closed_g is not None and closed_g <= parent.g_cost + result.length + TOLERANCE_LINEAR:
        metrics.pruned_closed_set += 1
        return

    if cache_key in context.hard_collision_set:
        metrics.pruned_hard_collision += 1
        return
    if cache_key not in context.static_safe_cache:
        engine = context.cost_evaluator.collision_engine
        if move_type == "S":
            blocked = engine.check_move_straight_static(parent.port, result.length, net_width=net_width)
        else:
            blocked = engine.check_move_static(
                result, start_port=parent.port, end_port=end_port, net_width=net_width
            )
        if blocked:
            context.hard_collision_set.add(cache_key)
            metrics.pruned_hard_collision += 1
            return
        context.static_safe_cache.add(cache_key)

    overlaps = 0
    if not skip_congestion:
        if cache_key in congestion_cache:
            overlaps = congestion_cache[cache_key]
        else:
            overlaps = context.cost_evaluator.collision_engine.check_move_congestion(result, net_id)
            congestion_cache[cache_key] = overlaps

    if self_collision_check and _overlaps_ancestor_path(parent, result):
        return

    if move_type == "SB":
        penalty = context.config.sbend_penalty
    elif move_type == "B":
        penalty = context.config.bend_penalty
    else:
        penalty = 0.0
    # Radius 10 keeps the base penalty; tighter turns pay more, wider ones less.
    if move_radius is not None and move_radius > TOLERANCE_LINEAR:
        penalty *= (10.0 / move_radius) ** 0.5

    evaluator = context.cost_evaluator
    move_cost = evaluator.evaluate_move(
        result.geometry,
        end_port,
        net_width,
        net_id,
        start_port=parent.port,
        length=result.length,
        dilated_geometry=result.dilated_geometry,
        penalty=penalty,
        skip_static=True,
        skip_congestion=True,
    )
    move_cost += overlaps * evaluator.congestion_penalty

    if (max_cost is not None and parent.g_cost + move_cost > max_cost) or move_cost > 1e12:
        metrics.pruned_cost += 1
        return

    g_cost = parent.g_cost + move_cost
    closed_g = closed_set.get(state)
    if closed_g is not None and closed_g <= g_cost + TOLERANCE_LINEAR:
        metrics.pruned_closed_set += 1
        return

    h_cost = evaluator.h_manhattan(end_port, target)
    heapq.heappush(open_set, AStarNode(end_port, g_cost, h_cost, parent, result))
    metrics.moves_added += 1
def reconstruct_path(end_node: AStarNode) -> list[ComponentResult]:
    """Return the components leading to ``end_node``, ordered from the start.

    Parent links are followed back from ``end_node``. The walk stops at the
    first node without a ``component_result``; in `route_astar` that is the
    start node, which is created without one.
    """
    moves_backwards: list[ComponentResult] = []
    node: AStarNode | None = end_node
    while node is not None and node.component_result:
        moves_backwards.append(node.component_result)
        node = node.parent
    moves_backwards.reverse()
    return moves_backwards