inire/inire/router/pathfinder.py
2026-03-29 20:35:58 -07:00

310 lines
11 KiB
Python

from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal
from inire.router.astar import AStarMetrics, route_astar
from inire.router.outcomes import RoutingOutcome, infer_routing_outcome, routing_outcome_needs_retry
from inire.router.refiner import PathRefiner
from inire.router.path_state import PathStateManager
from inire.router.session import (
create_routing_session_state,
finalize_routing_session_results,
prepare_routing_session_state,
refine_routing_session_results,
run_routing_iteration,
)
if TYPE_CHECKING:
    # Imported only for static type checking. ``from __future__ import
    # annotations`` (top of file) keeps annotations as strings, so these names
    # are never needed at runtime.
    from collections.abc import Callable
    from inire.geometry.components import ComponentResult
    from inire.geometry.primitives import Port
    from inire.router.astar import AStarContext
    from inire.router.cost import CostEvaluator
# Module-level logger.
logger = logging.getLogger(__name__)
@dataclass
class RoutingResult:
    """Per-net result produced by ``PathFinder``.

    ``is_valid`` is the simple pass/fail flag; ``outcome`` carries the
    finer-grained routing outcome label. ``path`` may be partial (not ending
    on the target) or empty when routing produced no path.
    """

    # Identifier of the net this result belongs to (a netlist key).
    net_id: str
    # Ordered components forming the route; empty when no path was produced.
    path: list[ComponentResult]
    # Whether the route is considered acceptable.
    is_valid: bool
    # Number of collisions reported for ``path``.
    collisions: int
    # True when the last component of ``path`` ends on the target port.
    reached_target: bool = False
    # Routing outcome label; defaults to "unroutable".
    outcome: RoutingOutcome = "unroutable"
class PathFinder:
    """Route a whole netlist with repeated A* searches and a growing congestion penalty.

    ``route_all`` resets the cost evaluator's congestion penalty to
    ``base_congestion_penalty`` and runs up to ``max_iterations`` routing
    iterations (delegated to ``inire.router.session``). After every iteration
    in which some net's outcome still needs a retry, the penalty is multiplied
    by ``congestion_multiplier``. Each net is ripped up
    (``path_state.remove_path``) before it is routed again.
    """

    __slots__ = (
        "context",
        "metrics",
        "max_iterations",
        "base_congestion_penalty",
        "use_tiered_strategy",
        "congestion_multiplier",
        "accumulated_expanded_nodes",
        "warm_start",
        "refine_paths",
        "refiner",
        "path_state",
    )

    def __init__(
        self,
        context: AStarContext,
        metrics: AStarMetrics | None = None,
        max_iterations: int = 10,
        base_congestion_penalty: float = 100.0,
        congestion_multiplier: float = 1.5,
        use_tiered_strategy: bool = True,
        warm_start: Literal["shortest", "longest", "user"] | None = "shortest",
        refine_paths: bool = True,
    ) -> None:
        """Store routing parameters and build the helper objects.

        Args:
            context: A* context; supplies ``config``, ``cost_evaluator`` and
                static-cache invalidation.
            metrics: Metrics sink shared by every ``route_astar`` call made by
                this router. A fresh ``AStarMetrics`` is created when ``None``.
            max_iterations: Maximum number of routing iterations in ``route_all``.
            base_congestion_penalty: Congestion penalty installed on the cost
                evaluator at the start of every ``route_all`` call.
            congestion_multiplier: Factor applied to the congestion penalty
                after each iteration that still has nets needing a retry.
            use_tiered_strategy: When true, first-iteration searches skip
                congestion and use the ``"clipped_bbox"`` collision model in
                place of ``"arc"``.
            warm_start: Net ordering for the greedy warm-start pass
                (``"shortest"``/``"longest"`` by Manhattan distance between
                the net's ports, or ``"user"`` for netlist order). Consumed by
                the session helpers; presumably ``None`` disables warm start
                (confirm in ``inire.router.session``).
            refine_paths: Stored for the session helpers; presumably toggles
                post-routing refinement (confirm in ``inire.router.session``).
        """
        self.context = context
        self.metrics = metrics if metrics is not None else AStarMetrics()
        self.max_iterations = max_iterations
        self.base_congestion_penalty = base_congestion_penalty
        self.congestion_multiplier = congestion_multiplier
        self.use_tiered_strategy = use_tiered_strategy
        self.warm_start = warm_start
        self.refine_paths = refine_paths
        self.refiner = PathRefiner(context)
        self.path_state = PathStateManager(context.cost_evaluator.collision_engine)
        # Expanded search nodes gathered across a route_all call when
        # store_expanded is requested.
        self.accumulated_expanded_nodes: list[tuple[int, int, int]] = []

    @property
    def cost_evaluator(self) -> CostEvaluator:
        """The cost evaluator owned by ``context``."""
        return self.context.cost_evaluator

    def _build_greedy_warm_start_paths(
        self,
        netlist: dict[str, tuple[Port, Port]],
        net_widths: dict[str, float],
        order: Literal["shortest", "longest", "user"],
    ) -> dict[str, list[ComponentResult]]:
        """Route every net once, greedily, to seed the first iteration.

        Nets are processed in ``order``: ascending (``"shortest"``) or
        descending (``"longest"``) Manhattan distance between their ports, or
        netlist order for ``"user"``. Each successful path is temporarily
        staged as a static obstacle so later nets route around it. All staged
        obstacles are removed again before returning, even if a search raises.

        Args:
            netlist: Maps net id to its ``(start, target)`` ports.
            net_widths: Per-net width; nets missing here use 2.0.
            order: Net ordering, see above.

        Returns:
            Paths for the nets whose search produced a non-empty path; other
            nets are omitted.
        """
        all_net_ids = list(netlist.keys())
        if order != "user":
            all_net_ids.sort(
                key=lambda nid: abs(netlist[nid][1].x - netlist[nid][0].x)
                + abs(netlist[nid][1].y - netlist[nid][0].y),
                reverse=(order == "longest"),
            )
        greedy_paths: dict[str, list[ComponentResult]] = {}
        temp_obj_ids: list[int] = []
        # Warm-start searches use a node limit of at most 2000.
        greedy_node_limit = min(self.context.config.node_limit, 2000)
        try:
            for net_id in all_net_ids:
                start, target = netlist[net_id]
                width = net_widths.get(net_id, 2.0)
                h_start = self.cost_evaluator.h_manhattan(start, target)
                # Cost budget handed to route_astar: 3x the heuristic estimate,
                # but never below 2000.
                max_cost_limit = max(h_start * 3.0, 2000.0)
                path = route_astar(
                    start,
                    target,
                    width,
                    context=self.context,
                    metrics=self.metrics,
                    net_id=net_id,
                    skip_congestion=True,
                    max_cost=max_cost_limit,
                    self_collision_check=True,
                    node_limit=greedy_node_limit,
                )
                if not path:
                    continue
                greedy_paths[net_id] = path
                temp_obj_ids.extend(self.path_state.stage_path_as_static(path))
                # Static geometry changed, so cached static lookups are stale.
                self.context.clear_static_caches()
        finally:
            # Always undo the temporary staging so a failed warm start cannot
            # leave phantom obstacles in the collision engine.
            self.path_state.remove_static_obstacles(temp_obj_ids)
            if temp_obj_ids:
                # Removal changes static geometry too. PathStateManager only
                # holds the collision engine, not the context, so it cannot
                # invalidate these caches itself. Caches populated while the
                # staged paths were present would otherwise still reflect them.
                self.context.clear_static_caches()
        return greedy_paths

    def _path_cost(self, path: list[ComponentResult]) -> float:
        """Cost of ``path`` as computed by the refiner."""
        return self.refiner.path_cost(path)

    def _install_path(self, net_id: str, path: list[ComponentResult]) -> None:
        """Register ``path`` as the current route of ``net_id``."""
        self.path_state.install_path(net_id, path)

    def _build_routing_result(
        self,
        *,
        net_id: str,
        path: list[ComponentResult],
        reached_target: bool,
        collisions: int,
        outcome: RoutingOutcome | None = None,
    ) -> RoutingResult:
        """Assemble a ``RoutingResult``; ``is_valid`` is true iff the outcome is "completed".

        When ``outcome`` is ``None`` it is inferred from whether ``path`` is
        non-empty, ``reached_target`` and ``collisions``.
        """
        resolved_outcome = (
            infer_routing_outcome(
                has_path=bool(path),
                reached_target=reached_target,
                collision_count=collisions,
            )
            if outcome is None
            else outcome
        )
        return RoutingResult(
            net_id=net_id,
            path=path,
            is_valid=(resolved_outcome == "completed"),
            collisions=collisions,
            reached_target=reached_target,
            outcome=resolved_outcome,
        )

    def _refine_path(
        self,
        net_id: str,
        start: Port,
        target: Port,
        net_width: float,
        path: list[ComponentResult],
    ) -> list[ComponentResult]:
        """Delegate refinement of ``path`` to the ``PathRefiner``."""
        return self.refiner.refine_path(net_id, start, target, net_width, path)

    def _route_net_once(
        self,
        net_id: str,
        start: Port,
        target: Port,
        width: float,
        iteration: int,
        initial_paths: dict[str, list[ComponentResult]] | None,
        store_expanded: bool,
        needs_self_collision_check: set[str],
    ) -> tuple[RoutingResult, RoutingOutcome]:
        """Rip up ``net_id`` and route it once.

        On iteration 0, a path supplied in ``initial_paths`` is reused as-is.
        Otherwise A* runs with ``return_partial=True``. Any non-empty path is
        installed. Collisions are only counted for paths that reach the
        target. If that check reports a self-collision, ``net_id`` is added to
        ``needs_self_collision_check`` (mutated in place), so later searches
        for this net enable self-collision checking.

        Returns:
            The ``RoutingResult`` and its outcome (the same value as
            ``result.outcome``).
        """
        self.path_state.remove_path(net_id)
        path: list[ComponentResult] | None = None
        if iteration == 0 and initial_paths and net_id in initial_paths:
            path = initial_paths[net_id]
        else:
            target_coll_model = self.context.config.bend_collision_type
            coll_model = target_coll_model
            skip_cong = False
            if self.use_tiered_strategy and iteration == 0:
                # First tier: ignore congestion and, for arc bends, search
                # with the "clipped_bbox" collision model instead.
                skip_cong = True
                if target_coll_model == "arc":
                    coll_model = "clipped_bbox"
            path = route_astar(
                start,
                target,
                width,
                context=self.context,
                metrics=self.metrics,
                net_id=net_id,
                bend_collision_type=coll_model,
                return_partial=True,
                store_expanded=store_expanded,
                skip_congestion=skip_cong,
                self_collision_check=(net_id in needs_self_collision_check),
                node_limit=self.context.config.node_limit,
            )
            # Collect expansions only when a search actually ran here. When
            # the path came from initial_paths, last_expanded_nodes may still
            # hold an earlier search's nodes, which would be counted twice.
            if store_expanded and self.metrics.last_expanded_nodes:
                self.accumulated_expanded_nodes.extend(self.metrics.last_expanded_nodes)
        if not path:
            outcome = infer_routing_outcome(has_path=False, reached_target=False, collision_count=0)
            return (
                self._build_routing_result(
                    net_id=net_id, path=[], reached_target=False, collisions=0, outcome=outcome
                ),
                outcome,
            )
        last_p = path[-1].end_port
        reached = last_p == target
        collision_count = 0
        # Partial paths are installed too, so they are part of the shared
        # routing state.
        self._install_path(net_id, path)
        if reached:
            report = self.path_state.verify_path_report(net_id, path)
            collision_count = report.collision_count
            if report.self_collision_count > 0:
                needs_self_collision_check.add(net_id)
        outcome = infer_routing_outcome(
            has_path=bool(path),
            reached_target=reached,
            collision_count=collision_count,
        )
        return (
            self._build_routing_result(
                net_id=net_id,
                path=path,
                reached_target=reached,
                collisions=collision_count,
                outcome=outcome,
            ),
            outcome,
        )

    def route_all(
        self,
        netlist: dict[str, tuple[Port, Port]],
        net_widths: dict[str, float],
        store_expanded: bool = False,
        iteration_callback: Callable[[int, dict[str, RoutingResult]], None] | None = None,
        shuffle_nets: bool = False,
        sort_nets: Literal["shortest", "longest", "user", None] = None,
        initial_paths: dict[str, list[ComponentResult]] | None = None,
        seed: int | None = None,
    ) -> dict[str, RoutingResult]:
        """Route every net in ``netlist`` and return per-net results.

        Args:
            netlist: Maps net id to its ``(start, target)`` ports.
            net_widths: Per-net routing width.
            store_expanded: Record A* expanded nodes in
                ``accumulated_expanded_nodes`` (reset at the start of each call).
            iteration_callback: Optional callback taking the iteration index
                and current results; forwarded to the session state.
            shuffle_nets: Net-ordering control forwarded to the session state.
            sort_nets: Net-ordering control forwarded to the session state.
            initial_paths: Paths reused as-is on iteration 0.
            seed: Forwarded to the session state (presumably for shuffling).

        Returns:
            Mapping of net id to ``RoutingResult``. If an iteration returns
            ``None``, routing stops early: results are checked with
            ``verify_all_nets`` and the refinement and finalization steps are
            skipped.
        """
        self.cost_evaluator.congestion_penalty = self.base_congestion_penalty
        self.accumulated_expanded_nodes = []
        self.metrics.reset_per_route()
        state = create_routing_session_state(
            self,
            netlist,
            net_widths,
            store_expanded=store_expanded,
            iteration_callback=iteration_callback,
            shuffle_nets=shuffle_nets,
            sort_nets=sort_nets,
            initial_paths=initial_paths,
            seed=seed,
        )
        prepare_routing_session_state(self, state)
        for iteration in range(self.max_iterations):
            iteration_outcomes = run_routing_iteration(self, state, iteration)
            if iteration_outcomes is None:
                return self.verify_all_nets(state.results, state.netlist)
            if not any(routing_outcome_needs_retry(outcome) for outcome in iteration_outcomes.values()):
                break
            # Make congestion progressively more expensive for the next pass.
            self.cost_evaluator.congestion_penalty *= self.congestion_multiplier
        refine_routing_session_results(self, state)
        return finalize_routing_session_results(self, state)

    def verify_all_nets(
        self,
        results: dict[str, RoutingResult],
        netlist: dict[str, tuple[Port, Port]],
    ) -> dict[str, RoutingResult]:
        """Re-check every net's stored path against the current routing state.

        Nets with no result, or with an empty path, are reported as having no
        path. For the rest, collisions are recomputed with
        ``path_state.verify_path_report``.

        NOTE(review): ``is_valid`` here also requires ``report.is_valid``,
        while ``outcome`` is inferred from ``collision_count`` alone. The two
        could disagree if the report flags problems not reflected in the
        collision count. Confirm this asymmetry is intended.
        """
        final_results: dict[str, RoutingResult] = {}
        for net_id, (_, target_p) in netlist.items():
            res = results.get(net_id)
            if not res or not res.path:
                final_results[net_id] = self._build_routing_result(
                    net_id=net_id,
                    path=[],
                    reached_target=False,
                    collisions=0,
                )
                continue
            last_p = res.path[-1].end_port
            reached = last_p == target_p
            report = self.path_state.verify_path_report(net_id, res.path)
            final_results[net_id] = RoutingResult(
                net_id=net_id,
                path=res.path,
                is_valid=(reached and report.is_valid),
                collisions=report.collision_count,
                reached_target=reached,
                outcome=infer_routing_outcome(
                    has_path=True,
                    reached_target=reached,
                    collision_count=report.collision_count,
                ),
            )
        return final_results