inire/inire/router/_astar_admission.py

200 lines
6.4 KiB
Python

from __future__ import annotations
import heapq
from typing import TYPE_CHECKING
from shapely.geometry import Polygon
from inire.constants import TOLERANCE_LINEAR
from inire.geometry.components import Bend90, SBend, Straight, MoveKind
from inire.geometry.primitives import Port
from inire.router.refiner import component_hits_ancestor_chain
from ._astar_types import AStarContext, AStarMetrics, AStarNode, SearchRunConfig
if TYPE_CHECKING:
from inire.geometry.components import ComponentResult
def process_move(
    parent: AStarNode,
    target: Port,
    net_width: float,
    net_id: str,
    open_set: list[AStarNode],
    closed_set: dict[tuple[int, int, int], float],
    context: AStarContext,
    metrics: AStarMetrics,
    congestion_cache: dict[tuple, int],
    config: SearchRunConfig,
    move_class: MoveKind,
    params: tuple,
) -> None:
    """Resolve the geometry for one candidate move and submit it to ``add_node``.

    Lookup happens in two tiers:

    1. ``context.move_cache_abs`` is keyed by the parent port's full tuple plus
       the move description. A hit is used directly.
    2. On a miss, ``context.check_cache_eviction()`` is called and
       ``context.move_cache_rel`` is consulted. It is keyed by the parent
       port's ``r`` only and holds geometry generated from ``Port(0, 0, r)``.
       The relative result is translated by ``(cp.x, cp.y)`` and stored in the
       absolute cache.

    If geometry generation raises ``ValueError`` the move is dropped silently:
    nothing is cached and ``add_node`` is not called.

    Args:
        parent: Node being expanded; its ``port`` is the start of the move.
        target: Forwarded unchanged to ``add_node``.
        net_width: Passed to the generators and included in both cache keys.
        net_id: Forwarded unchanged to ``add_node``.
        open_set: Forwarded unchanged to ``add_node``.
        closed_set: Forwarded unchanged to ``add_node``.
        context: Supplies the move caches, the eviction hook and the collision
            engine whose ``clearance`` determines the dilation.
        metrics: Forwarded unchanged to ``add_node``.
        congestion_cache: Forwarded unchanged to ``add_node``.
        config: Supplies ``bend_collision_type``, ``bend_physical_geometry``
            and ``bend_clip_margin``.
        move_class: ``"straight"`` and ``"bend90"`` select ``Straight`` and
            ``Bend90``; any other value is generated as an ``SBend``.
        params: Positional generator arguments. ``"straight"`` reads
            ``params[0]``; the other two kinds read ``params[0]`` and
            ``params[1]``.
    """
    cp = parent.port
    coll_type = config.bend_collision_type
    physical_type = config.bend_physical_geometry

    # Polygon-valued settings contribute their identity to the cache keys; any
    # other value is used as the key component directly.
    # NOTE(review): id() values can be reused once an object is collected -
    # presumably the config keeps these polygons alive for as long as the
    # caches exist; confirm.
    coll_key = id(coll_type) if isinstance(coll_type, Polygon) else coll_type
    physical_key = id(physical_type) if isinstance(physical_type, Polygon) else physical_type
    self_dilation = context.cost_evaluator.collision_engine.clearance / 2.0

    # Both keys share everything except their first element.
    key_tail = (move_class, params, net_width, coll_key, physical_key, self_dilation)
    abs_key = (cp.as_tuple(), *key_tail)

    if abs_key in context.move_cache_abs:
        res = context.move_cache_abs[abs_key]
    else:
        context.check_cache_eviction()
        rel_key = (cp.r, *key_tail)
        if rel_key in context.move_cache_rel:
            res_rel = context.move_cache_rel[rel_key]
        else:
            # Origin-anchored port: only needed when geometry must be generated.
            base_port = Port(0, 0, cp.r)
            try:
                if move_class == "straight":
                    res_rel = Straight.generate(
                        base_port,
                        params[0],
                        net_width,
                        dilation=self_dilation,
                    )
                elif move_class == "bend90":
                    res_rel = Bend90.generate(
                        base_port,
                        params[0],
                        net_width,
                        params[1],
                        collision_type=coll_type,
                        physical_geometry_type=physical_type,
                        clip_margin=config.bend_clip_margin,
                        dilation=self_dilation,
                    )
                else:
                    res_rel = SBend.generate(
                        base_port,
                        params[0],
                        params[1],
                        net_width,
                        collision_type=coll_type,
                        physical_geometry_type=physical_type,
                        clip_margin=config.bend_clip_margin,
                        dilation=self_dilation,
                    )
            except ValueError:
                # The generator rejected these parameters; skip the move.
                return
            context.move_cache_rel[rel_key] = res_rel
        res = res_rel.translate(cp.x, cp.y)
        context.move_cache_abs[abs_key] = res

    # The absolute key doubles as add_node's cache key.
    add_node(
        parent,
        res,
        target,
        net_width,
        net_id,
        open_set,
        closed_set,
        context,
        metrics,
        congestion_cache,
        config,
        move_class,
        abs_key,
    )
def add_node(
    parent: AStarNode,
    result: ComponentResult,
    target: Port,
    net_width: float,
    net_id: str,
    open_set: list[AStarNode],
    closed_set: dict[tuple[int, int, int], float],
    context: AStarContext,
    metrics: AStarMetrics,
    congestion_cache: dict[tuple, int],
    config: SearchRunConfig,
    move_type: MoveKind,
    cache_key: tuple,
) -> None:
    """Admit a generated move to the open set unless one of the filters rejects it.

    Filters run in this order, each returning early on rejection:

    1. Closed-set check using ``parent.g_cost + result.length`` as the
       candidate cost (counts ``pruned_closed_set``).
    2. Known hard collision for ``cache_key`` (counts ``pruned_hard_collision``).
    3. Static collision check, skipped when ``cache_key`` is already in
       ``context.static_safe_cache``; the outcome is remembered in either
       ``context.hard_collision_set`` or ``context.static_safe_cache``.
    4. Congestion overlap count (unless ``config.skip_congestion``), memoised
       in ``congestion_cache``. This only feeds the cost; it never rejects.
    5. Optional ancestor-chain self-collision check. A rejection here does not
       touch any prune counter.
    6. Cost caps: ``config.max_cost`` on the accumulated cost, and a fixed
       ``1e12`` ceiling on the single-move cost (counts ``pruned_cost``).
    7. Closed-set check again, now with the fully scored cost.

    Survivors are pushed onto ``open_set`` as an ``AStarNode`` and counted in
    ``moves_added``. Every call increments ``moves_generated`` regardless of
    the outcome. The ``total_*`` counters are bumped alongside their
    per-run counterparts.
    """
    metrics.moves_generated += 1
    metrics.total_moves_generated += 1

    end_port = result.end_port
    state = end_port.as_tuple()

    # Cheap pre-filter before any collision or scoring work is done.
    recorded_cost = closed_set.get(state)
    optimistic_g = parent.g_cost + result.length
    if recorded_cost is not None and recorded_cost <= optimistic_g + TOLERANCE_LINEAR:
        metrics.pruned_closed_set += 1
        metrics.total_pruned_closed_set += 1
        return

    start_port = parent.port

    if cache_key in context.hard_collision_set:
        metrics.pruned_hard_collision += 1
        metrics.total_pruned_hard_collision += 1
        return

    if cache_key not in context.static_safe_cache:
        engine = context.cost_evaluator.collision_engine
        if move_type == "straight":
            blocked = engine.check_move_straight_static(start_port, result.length, net_width=net_width)
        else:
            blocked = engine.check_move_static(result, start_port=start_port, end_port=end_port)
        if blocked:
            context.hard_collision_set.add(cache_key)
            metrics.pruned_hard_collision += 1
            metrics.total_pruned_hard_collision += 1
            return
        context.static_safe_cache.add(cache_key)

    if config.skip_congestion:
        overlap_count = 0
    elif cache_key in congestion_cache:
        overlap_count = congestion_cache[cache_key]
    else:
        overlap_count = context.cost_evaluator.collision_engine.check_move_congestion(result, net_id)
        congestion_cache[cache_key] = overlap_count

    if config.self_collision_check and component_hits_ancestor_chain(result, parent):
        return

    base_cost = context.cost_evaluator.score_component(result, start_port=start_port)
    step_cost = base_cost + overlap_count * context.congestion_penalty
    g_cost = parent.g_cost + step_cost

    # Both caps are reported under the same counter.
    exceeds_budget = config.max_cost is not None and g_cost > config.max_cost
    if exceeds_budget or step_cost > 1e12:
        metrics.pruned_cost += 1
        metrics.total_pruned_cost += 1
        return

    # Re-check against the closed set now that the real cost is known.
    recorded_cost = closed_set.get(state)
    if recorded_cost is not None and recorded_cost <= g_cost + TOLERANCE_LINEAR:
        metrics.pruned_closed_set += 1
        metrics.total_pruned_closed_set += 1
        return

    h_cost = context.cost_evaluator.h_manhattan(
        end_port,
        target,
        min_bend_radius=context.min_bend_radius,
    )
    heapq.heappush(open_set, AStarNode(end_port, g_cost, h_cost, parent, result))
    metrics.moves_added += 1
    metrics.total_moves_added += 1