297 lines
12 KiB
Python
297 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import heapq
|
|
from typing import TYPE_CHECKING
|
|
|
|
from shapely.geometry import Polygon
|
|
|
|
from inire.constants import TOLERANCE_LINEAR
|
|
from inire.geometry.components import Bend90, SBend, Straight, MoveKind
|
|
from inire.geometry.primitives import Port
|
|
from inire.router.refiner import component_hits_ancestor_chain
|
|
|
|
from ._astar_types import AStarContext, AStarMetrics, AStarNode, SearchRunConfig
|
|
|
|
if TYPE_CHECKING:
|
|
from inire.geometry.components import ComponentResult
|
|
|
|
|
|
def process_move(
    parent: AStarNode,
    target: Port,
    net_width: float,
    net_id: str,
    open_set: list[AStarNode],
    closed_set: dict[tuple[int, int, int], float],
    context: AStarContext,
    metrics: AStarMetrics,
    congestion_cache: dict[tuple, int],
    congestion_presence_cache: dict[tuple[str, int, int, int, int], bool],
    congestion_candidate_precheck_cache: dict[tuple[str, int, int, int, int], bool],
    congestion_net_envelope_cache: dict[tuple[str, int, int, int, int], tuple[str, ...]],
    congestion_grid_net_cache: dict[tuple[str, int, int, int, int], tuple[str, ...]],
    congestion_grid_span_cache: dict[tuple[str, int, int, int, int], dict[str, tuple[int, ...]]],
    config: SearchRunConfig,
    move_class: MoveKind,
    params: tuple,
) -> None:
    """Build one ``move_class`` move from ``parent`` and hand it to :func:`add_node`.

    Geometry is resolved through two caches on ``context``:

    * ``move_cache_abs`` is keyed by the absolute start port (``as_tuple()``)
      plus every shape-determining setting;
    * ``move_cache_rel`` is keyed by the start port's ``r`` component plus the
      same settings. It holds moves generated at ``Port(0, 0, r)``, which are
      translated by the start port's ``(x, y)`` to get the absolute move.

    On a relative-cache miss the move is produced by ``Straight.generate``,
    ``Bend90.generate`` or ``SBend.generate``. If the generator raises
    ``ValueError``, the move is dropped silently and nothing is cached.
    Hit/miss counters are recorded on ``context.metrics``. The absolute cache
    key is forwarded to ``add_node`` as its ``cache_key``.
    """
    start = parent.port
    bend_collision = config.bend_collision_type
    bend_physical = config.bend_physical_geometry

    # Polygon-valued settings are keyed by object identity, so two equal but
    # distinct Polygon objects get separate cache entries.
    if isinstance(bend_collision, Polygon):
        collision_key = id(bend_collision)
    else:
        collision_key = bend_collision
    if isinstance(bend_physical, Polygon):
        physical_key = id(bend_physical)
    else:
        physical_key = bend_physical

    # The move is dilated by half the collision engine's clearance.
    dilation = context.cost_evaluator.collision_engine.clearance / 2.0

    # Everything that fixes the move's shape, independent of where it starts.
    shape_key = (move_class, params, net_width, collision_key, physical_key, dilation)
    abs_key = (start.as_tuple(), *shape_key)

    if abs_key in context.move_cache_abs:
        context.metrics.total_move_cache_abs_hits += 1
        component = context.move_cache_abs[abs_key]
    else:
        context.metrics.total_move_cache_abs_misses += 1
        context.check_cache_eviction()
        rel_key = (start.r, *shape_key)
        # Re-read the caches from ``context`` after eviction rather than
        # holding a local alias across that call.
        if rel_key in context.move_cache_rel:
            context.metrics.total_move_cache_rel_hits += 1
            relative = context.move_cache_rel[rel_key]
        else:
            context.metrics.total_move_cache_rel_misses += 1
            origin = Port(0, 0, start.r)
            try:
                relative = _generate_relative_move(
                    origin, move_class, params, net_width, config, dilation
                )
            except ValueError:
                # The generator rejected these parameters: skip the move.
                return
            context.move_cache_rel[rel_key] = relative
        component = relative.translate(start.x, start.y)
        context.move_cache_abs[abs_key] = component

    add_node(
        parent,
        component,
        target,
        net_width,
        net_id,
        open_set,
        closed_set,
        context,
        metrics,
        congestion_cache,
        congestion_presence_cache,
        congestion_candidate_precheck_cache,
        congestion_net_envelope_cache,
        congestion_grid_net_cache,
        congestion_grid_span_cache,
        config,
        move_class,
        abs_key,
    )


def _generate_relative_move(
    origin: Port,
    move_class: MoveKind,
    params: tuple,
    net_width: float,
    config: SearchRunConfig,
    dilation: float,
) -> ComponentResult:
    """Generate ``move_class`` starting at ``origin``.

    ``"straight"`` uses ``params[0]``. ``"bend90"`` passes ``params[0]`` and
    ``params[1]`` around ``net_width``. Any other kind is treated as an S-bend
    and passes ``params[0]`` and ``params[1]`` before ``net_width``. A
    ``ValueError`` raised by a generator propagates to the caller.
    """
    if move_class == "straight":
        return Straight.generate(origin, params[0], net_width, dilation=dilation)

    # Both bend generators share the same bend-geometry options.
    bend_options = {
        "collision_type": config.bend_collision_type,
        "physical_geometry_type": config.bend_physical_geometry,
        "clip_margin": config.bend_clip_margin,
        "dilation": dilation,
    }
    if move_class == "bend90":
        return Bend90.generate(origin, params[0], net_width, params[1], **bend_options)
    return SBend.generate(origin, params[0], params[1], net_width, **bend_options)
|
|
|
|
|
|
def add_node(
    parent: AStarNode,
    result: ComponentResult,
    target: Port,
    net_width: float,
    net_id: str,
    open_set: list[AStarNode],
    closed_set: dict[tuple[int, int, int], float],
    context: AStarContext,
    metrics: AStarMetrics,
    congestion_cache: dict[tuple, int],
    congestion_presence_cache: dict[tuple[str, int, int, int, int], bool],
    congestion_candidate_precheck_cache: dict[tuple[str, int, int, int, int], bool],
    congestion_net_envelope_cache: dict[tuple[str, int, int, int, int], tuple[str, ...]],
    congestion_grid_net_cache: dict[tuple[str, int, int, int, int], tuple[str, ...]],
    congestion_grid_span_cache: dict[tuple[str, int, int, int, int], dict[str, tuple[int, ...]]],
    config: SearchRunConfig,
    move_type: MoveKind,
    cache_key: tuple,
) -> None:
    """Score the move ``result`` from ``parent`` and push it onto ``open_set``
    unless it is pruned.

    The move is dropped when any of the following hold:

    * ``closed_set`` already holds its end state at a cost no greater than
      the new one (within ``TOLERANCE_LINEAR``). This is checked three times:
      against a cheap lower bound, against the scored cost, and against the
      congestion-adjusted cost.
    * ``cache_key`` is in ``context.hard_collision_set``, or the collision
      engine's static check reports a collision (the key is then added to
      that set).
    * ``config.self_collision_check`` is on and the move hits its own ancestor
      chain.
    * ``parent.g_cost + move_cost`` exceeds ``config.max_cost``, or the move
      cost exceeds 1e12.

    A move matching the next step of ``config.guidance_seed`` has
    ``config.guidance_bonus`` subtracted from its cost (floored at 0.001). Its
    child node then advances ``seed_index``.

    Static-safety and congestion results are memoised under ``cache_key``.
    Surviving moves are pushed as an ``AStarNode`` with a Manhattan ``h_cost``.
    Counters on ``metrics`` and ``context.metrics`` are updated along the way.
    """
    frontier_trace = config.frontier_trace
    metrics.moves_generated += 1
    metrics.total_moves_generated += 1

    parent_p = parent.port
    end_p = result.end_port
    state = end_p.as_tuple()

    # Decide the guidance match up front: the bonus lowers the final cost, so
    # the cheap closed-set bound below must account for it.
    guidance_match = _matches_guidance_seed(config, parent, result)

    # NOTE(review): this bound assumes score_component(result) >= result.length.
    # Confirm against CostEvaluator.
    lower_bound_g = parent.g_cost + result.length
    if guidance_match:
        lower_bound_g = parent.g_cost + _guided_cost(result.length, config.guidance_bonus)
    if state in closed_set and closed_set[state] <= lower_bound_g + TOLERANCE_LINEAR:
        _trace_prune(frontier_trace, "closed_set", move_type, parent_p, state, result)
        metrics.pruned_closed_set += 1
        metrics.total_pruned_closed_set += 1
        return

    if cache_key in context.hard_collision_set:
        _trace_prune(frontier_trace, "hard_collision", move_type, parent_p, state, result)
        context.metrics.total_hard_collision_cache_hits += 1
        metrics.pruned_hard_collision += 1
        metrics.total_pruned_hard_collision += 1
        return

    if cache_key in context.static_safe_cache:
        context.metrics.total_static_safe_cache_hits += 1
    else:
        ce = context.cost_evaluator.collision_engine
        if move_type == "straight":
            collision_found = ce.check_move_straight_static(parent_p, result.length, net_width=net_width)
        else:
            collision_found = ce.check_move_static(result, start_port=parent_p, end_port=end_p)
        if collision_found:
            context.hard_collision_set.add(cache_key)
            _trace_prune(frontier_trace, "hard_collision", move_type, parent_p, state, result)
            metrics.pruned_hard_collision += 1
            metrics.total_pruned_hard_collision += 1
            return
        context.static_safe_cache.add(cache_key)

    if config.self_collision_check and component_hits_ancestor_chain(result, parent):
        _trace_prune(frontier_trace, "self_collision", move_type, parent_p, state, result)
        return

    move_cost = context.cost_evaluator.score_component(result, start_port=parent_p)

    next_seed_index = None
    if guidance_match:
        applied_bonus = _record_guidance_bonus(context, config, result)
        move_cost = _guided_cost(move_cost, applied_bonus)
        next_seed_index = parent.seed_index + 1

    # 1e12 presumably marks an infeasible move returned by score_component.
    over_budget = config.max_cost is not None and parent.g_cost + move_cost > config.max_cost
    if over_budget or move_cost > 1e12:
        _trace_prune(frontier_trace, "cost", move_type, parent_p, state, result)
        metrics.pruned_cost += 1
        metrics.total_pruned_cost += 1
        return

    if state in closed_set and closed_set[state] <= parent.g_cost + move_cost + TOLERANCE_LINEAR:
        _trace_prune(frontier_trace, "closed_set", move_type, parent_p, state, result)
        metrics.pruned_closed_set += 1
        metrics.total_pruned_closed_set += 1
        return

    total_overlaps = 0
    if not config.skip_congestion and context.cost_evaluator.collision_engine.has_dynamic_paths():
        ce = context.cost_evaluator.collision_engine
        # Two cheap prechecks gate the full congestion query.
        if not ce.has_possible_move_congestion(result, net_id, congestion_presence_cache):
            context.metrics.total_congestion_presence_skips += 1
        elif not ce.has_candidate_move_congestion(
            result,
            net_id,
            congestion_candidate_precheck_cache,
            congestion_net_envelope_cache,
            congestion_grid_net_cache,
        ):
            context.metrics.total_congestion_candidate_precheck_skips += 1
        elif cache_key in congestion_cache:
            context.metrics.total_congestion_cache_hits += 1
            total_overlaps = congestion_cache[cache_key]
        else:
            context.metrics.total_congestion_cache_misses += 1
            total_overlaps = ce.check_move_congestion(
                result,
                net_id,
                net_envelope_cache=congestion_net_envelope_cache,
                grid_net_cache=congestion_grid_net_cache,
                broad_phase_cache=congestion_grid_span_cache,
            )
            congestion_cache[cache_key] = total_overlaps
    move_cost += total_overlaps * context.congestion_penalty

    g_cost = parent.g_cost + move_cost
    if state in closed_set and closed_set[state] <= g_cost + TOLERANCE_LINEAR:
        _trace_prune(frontier_trace, "closed_set", move_type, parent_p, state, result)
        metrics.pruned_closed_set += 1
        metrics.total_pruned_closed_set += 1
        return

    h_cost = context.cost_evaluator.h_manhattan(
        end_p,
        target,
        min_bend_radius=context.min_bend_radius,
    )
    heapq.heappush(
        open_set,
        AStarNode(end_p, g_cost, h_cost, parent, result, seed_index=next_seed_index),
    )
    metrics.moves_added += 1
    metrics.total_moves_added += 1


def _guided_cost(cost: float, bonus: float) -> float:
    """Apply a guidance ``bonus`` to ``cost``, never dropping below 0.001."""
    return max(0.001, cost - bonus)


def _matches_guidance_seed(config: SearchRunConfig, parent: AStarNode, result: ComponentResult) -> bool:
    """Return True when ``result`` is the guidance-seed move expected after ``parent``."""
    seed = config.guidance_seed
    index = parent.seed_index
    return bool(
        seed is not None
        and index is not None
        and index < len(seed)
        and result.move_spec == seed[index]
    )


def _record_guidance_bonus(context: AStarContext, config: SearchRunConfig, result: ComponentResult) -> float:
    """Count a guided move on ``context.metrics`` and return the bonus to apply."""
    bonus = config.guidance_bonus
    m = context.metrics
    m.total_guidance_match_moves += 1
    if result.move_type == "straight":
        m.total_guidance_match_moves_straight += 1
        m.total_guidance_bonus_applied_straight += bonus
    elif result.move_type == "bend90":
        m.total_guidance_match_moves_bend90 += 1
        m.total_guidance_bonus_applied_bend90 += bonus
    else:
        m.total_guidance_match_moves_sbend += 1
        m.total_guidance_bonus_applied_sbend += bonus
    m.total_guidance_bonus_applied += bonus
    return bonus


def _trace_prune(frontier_trace, reason: str, move_type: MoveKind, parent_port: Port, state: tuple, result: ComponentResult) -> None:
    """Record a pruned move on ``frontier_trace`` when tracing is enabled."""
    if frontier_trace is not None:
        frontier_trace.record(reason, move_type, parent_port.as_tuple(), state, result.total_dilated_bounds)
|