inire/inire/router/_search.py

106 lines
3.4 KiB
Python

from __future__ import annotations
import heapq
from typing import TYPE_CHECKING
from inire.constants import TOLERANCE_LINEAR
from ._astar_moves import expand_moves as _expand_moves
from ._astar_types import AStarContext, AStarMetrics, AStarNode as _AStarNode, SearchRunConfig
if TYPE_CHECKING:
from inire.geometry.components import ComponentResult
from inire.geometry.primitives import Port
def _reconstruct_path(end_node: _AStarNode) -> list[ComponentResult]:
    """Return the components leading to *end_node*, ordered first to last.

    Follows ``parent`` links backwards from *end_node* and collects each
    node's ``component_result``. The walk stops at the first node that is
    missing or falsy, or whose ``component_result`` is falsy. The root of
    the chain therefore contributes nothing when it carries no component.
    Presumably that root is the search's start node; TODO confirm.
    """
    collected: list[ComponentResult] = []
    node: _AStarNode | None = end_node
    while node:
        component = node.component_result
        if not component:
            break
        collected.append(component)
        node = node.parent
    # Components were gathered end-to-start; flip them into travel order.
    collected.reverse()
    return collected
def route_astar(
    start: Port,
    target: Port,
    net_width: float,
    context: AStarContext,
    *,
    metrics: AStarMetrics | None = None,
    net_id: str = "default",
    config: SearchRunConfig,
) -> list[ComponentResult] | None:
    """Run an A* search for a component path from *start* to *target*.

    Before searching, this resets the per-route counters on *metrics*,
    refreshes the context's static caches and points the context's cost
    evaluator at *target*.

    Args:
        start: Port the search begins from. Its root node has g-cost 0.0 and
            an h-cost from the evaluator's ``h_manhattan``.
        target: The search succeeds once a node whose ``port`` equals it is
            expanded.
        net_width: Forwarded unchanged to the move expander.
        context: Shared search context. Provides the cost evaluator, the
            static caches and ``min_bend_radius``.
        metrics: Counters updated in place. A fresh ``AStarMetrics`` is used,
            and then discarded, when omitted.
        net_id: Forwarded unchanged to the move expander.
        config: Fields read here are ``node_limit``, ``max_cost``,
            ``store_expanded``, ``return_partial`` and ``guidance_seed``.
            It is also forwarded to the move expander.

    Returns:
        When the target is reached, the ``ComponentResult`` list from start
        to target, in order. The search can also stop early: either
        ``config.node_limit`` expansions are reached or the frontier empties.
        In that case the result is the path to the lowest-``h_cost`` node
        popped so far (nodes pruned by ``max_cost`` excluded) when
        ``config.return_partial`` is set, otherwise ``None``.
    """
    stats = AStarMetrics() if metrics is None else metrics
    stats.reset_per_route()
    context.ensure_static_caches_current()
    context.cost_evaluator.set_target(target)

    # g_cost recorded when each port state (``Port.as_tuple()``) was expanded.
    # It is also handed to the move expander.
    expanded_g: dict[tuple[int, int, int], float] = {}
    # Per-route memo tables, created fresh on every call and only consumed by
    # the move expander. They are passed positionally, so their order matters.
    congestion_cache: dict[tuple, int] = {}
    presence_cache: dict[tuple[str, int, int, int, int], bool] = {}
    precheck_cache: dict[tuple[str, int, int, int, int], bool] = {}
    net_envelope_cache: dict[tuple[str, int, int, int, int], tuple[str, ...]] = {}
    grid_net_cache: dict[tuple[str, int, int, int, int], tuple[str, ...]] = {}
    grid_span_cache: dict[tuple[str, int, int, int, int], dict[str, tuple[int, ...]]] = {}

    root = _AStarNode(
        start,
        0.0,
        context.cost_evaluator.h_manhattan(start, target, min_bend_radius=context.min_bend_radius),
        seed_index=0 if config.guidance_seed else None,
    )
    # A single-element list already satisfies the heap invariant.
    frontier: list[_AStarNode] = [root]
    closest = root
    expanded = 0

    while frontier:
        # Expansion budget exhausted: fall through to the partial/None return.
        if expanded >= config.node_limit:
            break
        node = heapq.heappop(frontier)

        # Cost ceiling: drop nodes whose fh_cost[0] exceeds it, and count them.
        if config.max_cost is not None and node.fh_cost[0] > config.max_cost:
            stats.pruned_cost += 1
            stats.total_pruned_cost += 1
            continue

        # Track the node with the smallest heuristic for partial results.
        if node.h_cost < closest.h_cost:
            closest = node

        # A state already expanded is only re-expanded when this node's
        # g_cost beats the recorded one by more than TOLERANCE_LINEAR.
        state = node.port.as_tuple()
        if state in expanded_g and expanded_g[state] <= node.g_cost + TOLERANCE_LINEAR:
            continue
        expanded_g[state] = node.g_cost

        if config.store_expanded:
            stats.last_expanded_nodes.append(state)
        expanded += 1
        stats.total_nodes_expanded += 1
        stats.nodes_expanded += 1

        if node.port == target:
            return _reconstruct_path(node)

        _expand_moves(
            node,
            target,
            net_width,
            net_id,
            frontier,
            expanded_g,
            context,
            stats,
            congestion_cache,
            presence_cache,
            precheck_cache,
            net_envelope_cache,
            grid_net_cache,
            grid_span_cache,
            config=config,
        )

    if not config.return_partial:
        return None
    return _reconstruct_path(closest)