inire/inire/router/_astar_admission.py

297 lines
12 KiB
Python

from __future__ import annotations
import heapq
from typing import TYPE_CHECKING
from shapely.geometry import Polygon
from inire.constants import TOLERANCE_LINEAR
from inire.geometry.components import Bend90, SBend, Straight, MoveKind
from inire.geometry.primitives import Port
from inire.router.refiner import component_hits_ancestor_chain
from ._astar_types import AStarContext, AStarMetrics, AStarNode, SearchRunConfig
if TYPE_CHECKING:
from inire.geometry.components import ComponentResult
def _generate_relative_move(
    base_port: Port,
    move_class: MoveKind,
    params: tuple,
    net_width: float,
    collision_type,
    config: SearchRunConfig,
    dilation: float,
) -> ComponentResult:
    """Build the component for one move anchored at ``base_port``.

    ``move_class`` selects the generator: ``"straight"`` uses ``Straight``,
    ``"bend90"`` uses ``Bend90``, and anything else uses ``SBend``.  Any
    ``ValueError`` raised by a generator propagates to the caller.
    """
    if move_class == "straight":
        return Straight.generate(base_port, params[0], net_width, dilation=dilation)
    # Both bend generators take the same geometry options from the run config.
    bend_options = {
        "collision_type": collision_type,
        "physical_geometry_type": config.bend_physical_geometry,
        "clip_margin": config.bend_clip_margin,
        "dilation": dilation,
    }
    if move_class == "bend90":
        return Bend90.generate(base_port, params[0], net_width, params[1], **bend_options)
    return SBend.generate(base_port, params[0], params[1], net_width, **bend_options)


def process_move(
    parent: AStarNode,
    target: Port,
    net_width: float,
    net_id: str,
    open_set: list[AStarNode],
    closed_set: dict[tuple[int, int, int], float],
    context: AStarContext,
    metrics: AStarMetrics,
    congestion_cache: dict[tuple, int],
    congestion_presence_cache: dict[tuple[str, int, int, int, int], bool],
    congestion_candidate_precheck_cache: dict[tuple[str, int, int, int, int], bool],
    congestion_net_envelope_cache: dict[tuple[str, int, int, int, int], tuple[str, ...]],
    congestion_grid_net_cache: dict[tuple[str, int, int, int, int], tuple[str, ...]],
    congestion_grid_span_cache: dict[tuple[str, int, int, int, int], dict[str, tuple[int, ...]]],
    config: SearchRunConfig,
    move_class: MoveKind,
    params: tuple,
) -> None:
    """Materialise one candidate move from ``parent`` and hand it to ``add_node``.

    Geometry is resolved through two caches on ``context``:

    * ``move_cache_abs`` is keyed by the start port's ``as_tuple()`` plus the
      move description.
    * ``move_cache_rel`` is keyed by the start port's rotation ``r`` plus the
      same move description.  Its entries are generated at ``(0, 0)`` and
      translated to the real start position.

    Hit and miss counters on ``context.metrics`` are updated for both caches.
    If a component generator raises ``ValueError``, the move is dropped
    silently and nothing is cached.  The absolute cache key is forwarded to
    ``add_node`` as its ``cache_key``.
    """
    origin = parent.port
    collision_setting = config.bend_collision_type
    physical_setting = config.bend_physical_geometry
    dilation = context.cost_evaluator.collision_engine.clearance / 2.0

    # Polygon settings are keyed by object identity.
    # NOTE(review): an id() key is only stable while the polygon stays alive.
    key_tail = (
        move_class,
        params,
        net_width,
        id(collision_setting) if isinstance(collision_setting, Polygon) else collision_setting,
        id(physical_setting) if isinstance(physical_setting, Polygon) else physical_setting,
        dilation,
    )
    abs_key = (origin.as_tuple(), *key_tail)

    if abs_key in context.move_cache_abs:
        context.metrics.total_move_cache_abs_hits += 1
        component = context.move_cache_abs[abs_key]
    else:
        context.metrics.total_move_cache_abs_misses += 1
        context.check_cache_eviction()
        # Caches are re-read from ``context`` after eviction in case it rebinds them.
        base_port = Port(0, 0, origin.r)
        rel_key = (origin.r, *key_tail)
        if rel_key in context.move_cache_rel:
            context.metrics.total_move_cache_rel_hits += 1
            relative = context.move_cache_rel[rel_key]
        else:
            context.metrics.total_move_cache_rel_misses += 1
            try:
                relative = _generate_relative_move(
                    base_port,
                    move_class,
                    params,
                    net_width,
                    collision_setting,
                    config,
                    dilation,
                )
            except ValueError:
                # Infeasible parameters for this move: skip it.
                return
            context.move_cache_rel[rel_key] = relative
        component = relative.translate(origin.x, origin.y)
        context.move_cache_abs[abs_key] = component

    add_node(
        parent,
        component,
        target,
        net_width,
        net_id,
        open_set,
        closed_set,
        context,
        metrics,
        congestion_cache,
        congestion_presence_cache,
        congestion_candidate_precheck_cache,
        congestion_net_envelope_cache,
        congestion_grid_net_cache,
        congestion_grid_span_cache,
        config,
        move_class,
        abs_key,
    )
def _trace_rejection(
    frontier_trace,
    reason: str,
    move_type: MoveKind,
    parent: AStarNode,
    state: tuple,
    result: ComponentResult,
) -> None:
    """Report a rejected candidate to ``frontier_trace`` when tracing is enabled."""
    if frontier_trace is not None:
        frontier_trace.record(reason, move_type, parent.port.as_tuple(), state, result.total_dilated_bounds)


def _closed_set_dominates(closed_set: dict[tuple[int, int, int], float], state: tuple, g_cost: float) -> bool:
    """Return True when ``state`` is already closed with a cost no worse than ``g_cost``.

    The comparison allows ``TOLERANCE_LINEAR`` of slack.
    """
    return state in closed_set and closed_set[state] <= g_cost + TOLERANCE_LINEAR


def add_node(
    parent: AStarNode,
    result: ComponentResult,
    target: Port,
    net_width: float,
    net_id: str,
    open_set: list[AStarNode],
    closed_set: dict[tuple[int, int, int], float],
    context: AStarContext,
    metrics: AStarMetrics,
    congestion_cache: dict[tuple, int],
    congestion_presence_cache: dict[tuple[str, int, int, int, int], bool],
    congestion_candidate_precheck_cache: dict[tuple[str, int, int, int, int], bool],
    congestion_net_envelope_cache: dict[tuple[str, int, int, int, int], tuple[str, ...]],
    congestion_grid_net_cache: dict[tuple[str, int, int, int, int], tuple[str, ...]],
    congestion_grid_span_cache: dict[tuple[str, int, int, int, int], dict[str, tuple[int, ...]]],
    config: SearchRunConfig,
    move_type: MoveKind,
    cache_key: tuple,
) -> None:
    """Run admission checks on one generated move and push it onto ``open_set`` if it survives.

    Checks run cheapest-first:

    1. closed-set prune using a lower bound on the new g-cost;
    2. hard (static) collision, memoised in ``context.hard_collision_set`` and
       ``context.static_safe_cache`` under ``cache_key``;
    3. optional self-collision against the ancestor chain;
    4. scoring plus the guidance-seed bonus, then the ``max_cost`` and
       infeasible-cost cut-offs;
    5. closed-set prune with the scored cost;
    6. congestion penalty against other nets' dynamic paths, memoised in
       ``congestion_cache``;
    7. final closed-set prune, heuristic, and ``heapq.heappush``.

    Rejections update the pruning counters on ``metrics`` and are reported to
    ``config.frontier_trace`` when it is set.  Self-collision is the exception:
    it is traced but not counted.  Nothing is returned.
    """
    frontier_trace = config.frontier_trace
    metrics.moves_generated += 1
    metrics.total_moves_generated += 1
    state = result.end_port.as_tuple()
    parent_p = parent.port
    end_p = result.end_port

    # Decide up front whether this move follows the guidance seed.  The bonus
    # affects the cost, so the early lower bound below must account for it.
    seed = config.guidance_seed
    guidance_match = (
        seed is not None
        and parent.seed_index is not None
        and parent.seed_index < len(seed)
        and result.move_spec == seed[parent.seed_index]
    )

    # Cheap prune before any collision work.  This assumes the scored cost is
    # at least ``result.length``; a matched guidance bonus can lower it by up
    # to ``guidance_bonus`` (clamped at 0.001, mirroring the scoring step).
    length_bound = result.length
    if guidance_match:
        length_bound = max(0.001, length_bound - config.guidance_bonus)
    if _closed_set_dominates(closed_set, state, parent.g_cost + length_bound):
        _trace_rejection(frontier_trace, "closed_set", move_type, parent, state, result)
        metrics.pruned_closed_set += 1
        metrics.total_pruned_closed_set += 1
        return

    if cache_key in context.hard_collision_set:
        _trace_rejection(frontier_trace, "hard_collision", move_type, parent, state, result)
        context.metrics.total_hard_collision_cache_hits += 1
        metrics.pruned_hard_collision += 1
        metrics.total_pruned_hard_collision += 1
        return

    collision_engine = context.cost_evaluator.collision_engine
    if cache_key in context.static_safe_cache:
        context.metrics.total_static_safe_cache_hits += 1
    else:
        # Straight moves use a dedicated check driven by start port and length.
        if move_type == "straight":
            collision_found = collision_engine.check_move_straight_static(
                parent_p, result.length, net_width=net_width
            )
        else:
            collision_found = collision_engine.check_move_static(result, start_port=parent_p, end_port=end_p)
        if collision_found:
            context.hard_collision_set.add(cache_key)
            _trace_rejection(frontier_trace, "hard_collision", move_type, parent, state, result)
            metrics.pruned_hard_collision += 1
            metrics.total_pruned_hard_collision += 1
            return
        context.static_safe_cache.add(cache_key)

    if config.self_collision_check and component_hits_ancestor_chain(result, parent):
        _trace_rejection(frontier_trace, "self_collision", move_type, parent, state, result)
        return

    move_cost = context.cost_evaluator.score_component(result, start_port=parent_p)
    next_seed_index = None
    if guidance_match:
        bonus = config.guidance_bonus
        counters = context.metrics
        counters.total_guidance_match_moves += 1
        if result.move_type == "straight":
            counters.total_guidance_match_moves_straight += 1
            counters.total_guidance_bonus_applied_straight += bonus
        elif result.move_type == "bend90":
            counters.total_guidance_match_moves_bend90 += 1
            counters.total_guidance_bonus_applied_bend90 += bonus
        else:
            counters.total_guidance_match_moves_sbend += 1
            counters.total_guidance_bonus_applied_sbend += bonus
        counters.total_guidance_bonus_applied += bonus
        move_cost = max(0.001, move_cost - bonus)
        next_seed_index = parent.seed_index + 1

    # ``max_cost`` is checked before the congestion penalty is added.  A score
    # above 1e12 is treated as infeasible.
    over_budget = config.max_cost is not None and parent.g_cost + move_cost > config.max_cost
    if over_budget or move_cost > 1e12:
        _trace_rejection(frontier_trace, "cost", move_type, parent, state, result)
        metrics.pruned_cost += 1
        metrics.total_pruned_cost += 1
        return

    if _closed_set_dominates(closed_set, state, parent.g_cost + move_cost):
        _trace_rejection(frontier_trace, "closed_set", move_type, parent, state, result)
        metrics.pruned_closed_set += 1
        metrics.total_pruned_closed_set += 1
        return

    total_overlaps = 0
    if not config.skip_congestion and collision_engine.has_dynamic_paths():
        # Two cheap pre-filters run before the full congestion count.
        if not collision_engine.has_possible_move_congestion(result, net_id, congestion_presence_cache):
            context.metrics.total_congestion_presence_skips += 1
        elif not collision_engine.has_candidate_move_congestion(
            result,
            net_id,
            congestion_candidate_precheck_cache,
            congestion_net_envelope_cache,
            congestion_grid_net_cache,
        ):
            context.metrics.total_congestion_candidate_precheck_skips += 1
        elif cache_key in congestion_cache:
            context.metrics.total_congestion_cache_hits += 1
            total_overlaps = congestion_cache[cache_key]
        else:
            context.metrics.total_congestion_cache_misses += 1
            total_overlaps = collision_engine.check_move_congestion(
                result,
                net_id,
                net_envelope_cache=congestion_net_envelope_cache,
                grid_net_cache=congestion_grid_net_cache,
                broad_phase_cache=congestion_grid_span_cache,
            )
            congestion_cache[cache_key] = total_overlaps
    move_cost += total_overlaps * context.congestion_penalty

    g_cost = parent.g_cost + move_cost
    if _closed_set_dominates(closed_set, state, g_cost):
        _trace_rejection(frontier_trace, "closed_set", move_type, parent, state, result)
        metrics.pruned_closed_set += 1
        metrics.total_pruned_closed_set += 1
        return

    h_cost = context.cost_evaluator.h_manhattan(
        result.end_port,
        target,
        min_bend_radius=context.min_bend_radius,
    )
    heapq.heappush(
        open_set,
        AStarNode(
            result.end_port,
            g_cost,
            h_cost,
            parent,
            result,
            seed_index=next_seed_index,
        ),
    )
    metrics.moves_added += 1
    metrics.total_moves_added += 1