# Source: inire/inire/tests/test_astar.py
#
# 669 lines
# 24 KiB
# Python

import math
import pytest
from shapely.geometry import Polygon
from inire import CongestionOptions, NetSpec, RoutingProblem, RoutingOptions, RoutingResult, SearchOptions
from inire.geometry.components import Bend90, Straight
from inire.geometry.collision import RoutingWorld
from inire.geometry.primitives import Port
from inire.router._astar_types import AStarContext, AStarNode, SearchRunConfig
from inire.router._astar_admission import add_node
from inire.router._astar_moves import (
_distance_to_bounds_in_heading,
_should_cap_straights_to_bounds,
)
from inire.router._router import PathFinder, _RoutingState
from inire.router._search import route_astar
from inire.router.cost import CostEvaluator
from inire.router.danger_map import DangerMap
from inire.seeds import StraightSeed
# Routing area shared by the tests; passed as ``bounds`` to DangerMap and
# RoutingProblem. Presumably ordered (min_x, min_y, max_x, max_y) — confirm
# against the library.
BOUNDS = (0, -50, 150, 150)
@pytest.fixture
def basic_evaluator() -> CostEvaluator:
    """Provide a ``CostEvaluator`` for a world without obstacles.

    The collision engine is created with a clearance of 2.0, the danger map
    spans ``BOUNDS`` and is precomputed with an empty obstacle list, and the
    evaluator uses a bend penalty of 50 and an s-bend penalty of 150.
    """
    world = RoutingWorld(clearance=2.0)
    empty_danger_map = DangerMap(bounds=BOUNDS)
    empty_danger_map.precompute([])
    return CostEvaluator(
        world,
        empty_danger_map,
        bend_penalty=50.0,
        sbend_penalty=150.0,
    )
def _build_options(**search_overrides: object) -> RoutingOptions:
    """Return ``RoutingOptions`` whose ``search`` section is built from the overrides."""
    search_options = SearchOptions(**search_overrides)
    return RoutingOptions(search=search_options)
def _build_context(
    evaluator: CostEvaluator,
    *,
    bounds: tuple[float, float, float, float],
    **search_overrides: object,
) -> AStarContext:
    """Assemble an ``AStarContext`` for a net-less problem over ``bounds``.

    Any extra keyword arguments are forwarded to ``_build_options`` and end up
    as search-option overrides.
    """
    problem = RoutingProblem(bounds=bounds)
    options = _build_options(**search_overrides)
    return AStarContext(evaluator, problem, options)
def _route(context: AStarContext, start: Port, target: Port, **config_overrides: object):
    """Run ``route_astar`` from ``start`` to ``target`` for a 2.0-wide net.

    The run configuration is derived from ``context.options`` with
    ``config_overrides`` applied on top; the result of ``route_astar`` is
    returned unchanged.
    """
    run_config = SearchRunConfig.from_options(context.options, **config_overrides)
    return route_astar(
        start,
        target,
        net_width=2.0,
        context=context,
        config=run_config,
    )
def _port_mismatch_errors(label: str, actual: Port, expected: Port) -> list[str]:
    """Describe how ``actual`` deviates from ``expected``.

    ``label`` ("Initial" or "Final") prefixes every message. A position error
    is reported when the ports are more than 0.005 apart, an orientation error
    when ``r`` differs by more than 0.1. Returns an empty list on agreement.
    """
    errors: list[str] = []
    offset = math.hypot(actual.x - expected.x, actual.y - expected.y)
    if offset > 0.005:
        errors.append(f"{label} port position mismatch: {offset*1000:.2f}nm")
    if abs(actual.r - expected.r) > 0.1:
        errors.append(f"{label} port orientation mismatch: {actual.r} vs {expected.r}")
    return errors


def _validate_routing_result(
    result: RoutingResult,
    static_obstacles: list[Polygon],
    clearance: float,
    expected_start: Port | None = None,
    expected_end: Port | None = None,
) -> dict[str, object]:
    """Check a routed path for collisions and endpoint connectivity.

    A fresh ``RoutingWorld`` is populated with ``static_obstacles`` and asked
    to verify ``result.path``. When ``expected_start`` / ``expected_end`` are
    given, the first component's start port and the last component's end port
    are additionally compared against them.

    Returns:
        ``{"is_valid": False, "reason": "No path found"}`` when the path is
        empty. Otherwise a dict with ``is_valid``, a space-joined ``reason``,
        the collision counters and total length from the verification report,
        and ``connectivity_ok``.
    """
    if not result.path:
        return {"is_valid": False, "reason": "No path found"}

    connectivity_errors: list[str] = []
    # Compare against None explicitly: relying on the truthiness of a Port
    # would silently skip the check if a Port ever evaluated as falsy.
    if expected_start is not None:
        connectivity_errors += _port_mismatch_errors(
            "Initial", result.path[0].start_port, expected_start
        )
    if expected_end is not None:
        connectivity_errors += _port_mismatch_errors(
            "Final", result.path[-1].end_port, expected_end
        )

    engine = RoutingWorld(clearance=clearance)
    for obstacle in static_obstacles:
        engine.add_static_obstacle(obstacle)
    report = engine.verify_path_report("validation", result.path)

    reasons: list[str] = []
    if report.static_collision_count:
        reasons.append(f"Found {report.static_collision_count} obstacle collisions.")
    if report.dynamic_collision_count:
        reasons.append(f"Found {report.dynamic_collision_count} dynamic-net collisions.")
    if report.self_collision_count:
        reasons.append(f"Found {report.self_collision_count} self-intersections.")
    reasons.extend(connectivity_errors)

    return {
        "is_valid": report.is_valid and not connectivity_errors,
        "reason": " ".join(reasons),
        "obstacle_collisions": report.static_collision_count,
        "dynamic_collisions": report.dynamic_collision_count,
        "self_intersections": report.self_collision_count,
        "total_length": report.total_length,
        "connectivity_ok": not connectivity_errors,
    }
def test_astar_straight(basic_evaluator: CostEvaluator) -> None:
    """Routing between two collinear, equally oriented ports yields a valid 50-long path."""
    start = Port(0, 0, 0)
    target = Port(50, 0, 0)
    context = _build_context(basic_evaluator, bounds=BOUNDS)

    path = _route(context, start, target)
    assert path is not None

    validation = _validate_routing_result(
        RoutingResult(net_id="test", path=path, reached_target=True),
        [],
        clearance=2.0,
        expected_start=start,
        expected_end=target,
    )
    assert validation["is_valid"], f"Validation failed: {validation.get('reason')}"
    assert validation["connectivity_ok"]

    # The ports are 50um apart on a straight line, so the route length must match.
    length_error = abs(validation["total_length"] - 50.0)
    assert length_error < 1e-6
def test_astar_bend(basic_evaluator: CostEvaluator) -> None:
    """A target offset by (20, 20) with the same heading is reachable with 10um bends."""
    context = _build_context(basic_evaluator, bounds=BOUNDS, bend_radii=(10.0,))
    start = Port(0, 0, 0)
    # Offset of 20um in x and 20um in y: two 10um-radius bends are needed.
    target = Port(20, 20, 0)

    path = _route(context, start, target)
    assert path is not None

    routed = RoutingResult(net_id="test", path=path, reached_target=True)
    validation = _validate_routing_result(
        routed,
        [],
        clearance=2.0,
        expected_start=start,
        expected_end=target,
    )
    assert validation["is_valid"], f"Validation failed: {validation.get('reason')}"
    assert validation["connectivity_ok"]
def test_astar_obstacle(basic_evaluator: CostEvaluator) -> None:
    """A block placed on the direct line forces a valid, longer-than-direct detour."""
    # Obstacle spanning x=20..40, y=-20..20, i.e. across the straight path.
    obstacle = Polygon([(20, -20), (40, -20), (40, 20), (20, 20)])
    basic_evaluator.collision_engine.add_static_obstacle(obstacle)
    basic_evaluator.danger_map.precompute([obstacle])
    context = _build_context(basic_evaluator, bounds=BOUNDS, bend_radii=(10.0,), node_limit=1000000)
    start = Port(0, 0, 0)
    target = Port(60, 0, 0)

    path = _route(context, start, target)
    assert path is not None

    result = RoutingResult(net_id="test", path=path, reached_target=True)
    validation = _validate_routing_result(result, [obstacle], clearance=2.0, expected_start=start, expected_end=target)
    assert validation["is_valid"], f"Validation failed: {validation.get('reason')}"
    # The ports are 60um apart, so a path that detours around the obstacle must
    # be strictly longer than 60um (the previous bound of 50 could never fail).
    assert validation["total_length"] > 60.0
def test_astar_uses_integerized_ports(basic_evaluator: CostEvaluator) -> None:
    """A port built from a fractional x (10.1) reports x == 10 and is still routable."""
    start = Port(0, 0, 0)
    target = Port(10.1, 0, 0)
    context = _build_context(basic_evaluator, bounds=BOUNDS)

    path = _route(context, start, target)
    assert path is not None

    routed = RoutingResult(net_id="test", path=path, reached_target=True)
    assert target.x == 10

    validation = _validate_routing_result(
        routed,
        [],
        clearance=2.0,
        expected_start=start,
        expected_end=target,
    )
    assert validation["is_valid"], f"Validation failed: {validation.get('reason')}"
def test_validate_routing_result_checks_expected_start() -> None:
    """The validator rejects a path whose first port is not the expected start.

    The single straight begins at x=100 while the expected start is the origin.
    """
    misplaced_straight = Straight.generate(Port(100, 0, 0), 10.0, width=2.0, dilation=1.0)
    result = RoutingResult(net_id="test", path=[misplaced_straight], reached_target=True)

    expected_start = Port(0, 0, 0)
    expected_end = Port(110, 0, 0)
    validation = _validate_routing_result(
        result,
        [],
        clearance=2.0,
        expected_start=expected_start,
        expected_end=expected_end,
    )

    assert not validation["is_valid"]
    assert "Initial port position mismatch" in validation["reason"]
def test_validate_routing_result_uses_exact_component_geometry() -> None:
    """A bend generated with ``collision_type="bbox"`` still validates next to a nearby obstacle.

    The small square at x=2..4, y=7..9 is expected not to count as a collision
    for the radius-10 CCW bend starting at the origin.
    """
    origin = Port(0, 0, 0)
    bend = Bend90.generate(origin, 10.0, 2.0, direction="CCW", collision_type="bbox", dilation=1.0)
    obstacle = Polygon([(2.0, 7.0), (4.0, 7.0), (4.0, 9.0), (2.0, 9.0)])
    result = RoutingResult(net_id="test", path=[bend], reached_target=True)

    validation = _validate_routing_result(
        result,
        [obstacle],
        clearance=2.0,
        expected_start=Port(0, 0, 0),
        expected_end=bend.end_port,
    )

    assert validation["is_valid"], f"Validation failed: {validation.get('reason')}"
def test_astar_context_keeps_evaluator_weights_separate(basic_evaluator: CostEvaluator) -> None:
    """Search options set on the context coexist with an evaluator using custom penalties.

    The fixture's engine and danger map are reused for a new evaluator with
    bend/s-bend penalties of 120/240; the context must expose the requested
    bend radii and the evaluator's Manhattan heuristic must stay positive.
    """
    basic_evaluator = CostEvaluator(
        basic_evaluator.collision_engine,
        basic_evaluator.danger_map,
        bend_penalty=120.0,
        sbend_penalty=240.0,
    )
    context = _build_context(basic_evaluator, bounds=BOUNDS, bend_radii=(5.0,))

    configured_radii = context.options.search.bend_radii
    assert configured_radii == (5.0,)

    heuristic = basic_evaluator.h_manhattan(Port(0, 0, 0), Port(10, 10, 0))
    assert heuristic > 0.0
def test_distance_to_bounds_in_heading_is_directional() -> None:
    """The distance to the bounds depends on the port heading (0, 90, 180, 270)."""
    bounds = (0, 0, 100, 200)
    expected_by_heading = (
        (0, 80.0),
        (90, 170.0),
        (180, 20.0),
        (270, 30.0),
    )
    for heading, expected_distance in expected_by_heading:
        measured = _distance_to_bounds_in_heading(Port(20, 30, heading), bounds)
        assert measured == pytest.approx(expected_distance)
def test_should_cap_straights_to_bounds_only_for_large_no_warm_runs(basic_evaluator: CostEvaluator) -> None:
    """Straight capping applies to the 8-net, 1000x1000, no-warm-start context only.

    The default context built over ``BOUNDS`` with no nets must not be capped.
    """
    eight_nets = tuple(
        NetSpec(f"net{i}", Port(0, i * 10, 0), Port(10, i * 10, 0), width=2.0)
        for i in range(8)
    )
    large_problem = RoutingProblem(bounds=(0, 0, 1000, 1000), nets=eight_nets)
    no_warm_start = RoutingOptions(
        congestion=CongestionOptions(warm_start_enabled=False),
    )
    large_context = AStarContext(basic_evaluator, large_problem, no_warm_start)
    small_context = _build_context(basic_evaluator, bounds=BOUNDS)

    assert _should_cap_straights_to_bounds(large_context)
    assert not _should_cap_straights_to_bounds(small_context)
def test_pair_local_context_clones_live_static_obstacles() -> None:
    """The pair-local context carries the obstacle that was added to the live engine.

    The obstacle is registered on the collision engine only (the problem's
    ``static_obstacles`` stays empty); the context returned by
    ``_build_pair_local_context`` must list exactly that obstacle both in its
    problem and in its own collision engine.
    """

    def make_pair_specs() -> tuple[NetSpec, NetSpec]:
        # Fresh NetSpec instances for each consumer, as in two separate literals.
        return (
            NetSpec("pair_a", Port(0, 0, 0), Port(60, 0, 0), width=2.0),
            NetSpec("pair_b", Port(0, 10, 0), Port(60, 10, 0), width=2.0),
        )

    obstacle = Polygon([(20, -20), (40, -20), (40, 20), (20, 20)])
    engine = RoutingWorld(clearance=2.0)
    engine.add_static_obstacle(obstacle)
    danger_map = DangerMap(bounds=BOUNDS)
    danger_map.precompute([obstacle])
    evaluator = CostEvaluator(engine, danger_map, bend_penalty=50.0, sbend_penalty=150.0)

    problem = RoutingProblem(bounds=BOUNDS, nets=make_pair_specs())
    finder = PathFinder(AStarContext(evaluator, problem, RoutingOptions()))

    state_spec_a, state_spec_b = make_pair_specs()
    state = _RoutingState(
        net_specs={"pair_a": state_spec_a, "pair_b": state_spec_b},
        ordered_net_ids=["pair_a", "pair_b"],
        results={},
        needs_self_collision_check=set(),
        start_time=0.0,
        timeout_s=1.0,
        initial_paths=None,
        accumulated_expanded_nodes=[],
        best_results={},
        best_completed_nets=-1,
        best_conflict_edges=10**9,
        best_dynamic_collisions=10**9,
        last_conflict_signature=(),
        last_conflict_edge_count=0,
        repeated_conflict_count=0,
        pair_local_plateau_count=0,
        recent_attempt_work={},
        pre_pair_candidate=None,
    )

    local_context = finder._build_pair_local_context(state, {}, ("pair_a", "pair_b"))

    assert finder.context.problem.static_obstacles == ()
    cloned_obstacles = local_context.problem.static_obstacles
    assert len(cloned_obstacles) == 1
    assert len(local_context.cost_evaluator.collision_engine._static_obstacles.geometries) == 1
    assert next(iter(local_context.problem.static_obstacles)).equals(obstacle)
def test_route_astar_bend_collision_override_does_not_persist(basic_evaluator: CostEvaluator) -> None:
    """A per-run ``bend_collision_type`` override leaves the context's options untouched."""
    context = _build_context(basic_evaluator, bounds=BOUNDS, bend_radii=(10.0,), bend_collision_type="arc")
    overridden_config = SearchRunConfig.from_options(
        context.options,
        bend_collision_type="clipped_bbox",
        return_partial=True,
    )

    route_astar(
        Port(0, 0, 0),
        Port(30, 10, 0),
        net_width=2.0,
        context=context,
        config=overridden_config,
    )

    # The context was created with "arc"; the run used "clipped_bbox".
    assert context.options.search.bend_collision_type == "arc"
def test_route_astar_returns_partial_path_when_node_limited(basic_evaluator: CostEvaluator) -> None:
    """With ``node_limit=2`` the search returns a partial path only when asked to.

    ``return_partial=True`` must give a non-empty path that stops short of the
    target; ``return_partial=False`` must give ``None``.
    """
    blocker = Polygon([(20, -20), (40, -20), (40, 20), (20, 20)])
    basic_evaluator.collision_engine.add_static_obstacle(blocker)
    basic_evaluator.danger_map.precompute([blocker])
    context = _build_context(basic_evaluator, bounds=BOUNDS, bend_radii=(10.0,), node_limit=2)
    start, target = Port(0, 0, 0), Port(60, 0, 0)

    partial_path = _route(context, start, target, return_partial=True)
    no_partial_path = _route(context, start, target, return_partial=False)

    assert partial_path is not None
    assert partial_path
    final_port = partial_path[-1].end_port
    assert final_port != target
    assert no_partial_path is None
def test_route_astar_uses_single_sbend_for_same_orientation_offset(basic_evaluator: CostEvaluator) -> None:
    """A 10um lateral offset with unchanged heading is routed with exactly one s-bend."""
    search_settings = {
        "bend_radii": (10.0,),
        "sbend_radii": (10.0,),
        "sbend_offsets": (10.0,),
        "max_straight_length": 150.0,
    }
    context = _build_context(basic_evaluator, bounds=BOUNDS, **search_settings)
    start = Port(0, 0, 0)
    target = Port(100, 10, 0)

    path = _route(context, start, target)

    assert path is not None
    assert path[-1].end_port == target

    sbend_count = 0
    for component in path:
        if component.move_type == "sbend":
            sbend_count += 1
    assert sbend_count == 1

    # No two consecutive components may both be s-bends.
    has_back_to_back_sbends = any(
        first.move_type == second.move_type == "sbend"
        for first, second in zip(path, path[1:], strict=False)
    )
    assert not has_back_to_back_sbends
@pytest.mark.parametrize("visibility_guidance", ["off", "exact_corner", "tangent_corner"])
def test_route_astar_supports_all_visibility_guidance_modes(
    basic_evaluator: CostEvaluator,
    visibility_guidance: str,
) -> None:
    """Every visibility-guidance mode produces a valid, connected route past one obstacle."""
    obstacle = Polygon([(30, 10), (50, 10), (50, 40), (30, 40)])
    basic_evaluator.collision_engine.add_static_obstacle(obstacle)
    basic_evaluator.danger_map.precompute([obstacle])

    start = Port(0, 0, 0)
    target = Port(80, 50, 0)
    context = _build_context(
        basic_evaluator,
        bounds=BOUNDS,
        bend_radii=(10.0,),
        sbend_radii=(),
        max_straight_length=150.0,
        visibility_guidance=visibility_guidance,
    )

    path = _route(context, start, target)
    assert path is not None

    validation = _validate_routing_result(
        RoutingResult(net_id="test", path=path, reached_target=True),
        [obstacle],
        clearance=2.0,
        expected_start=start,
        expected_end=target,
    )
    assert validation["is_valid"], f"Validation failed: {validation.get('reason')}"
    assert validation["connectivity_ok"]
def test_tangent_corner_mode_avoids_exact_visibility_graph_builds(basic_evaluator: CostEvaluator) -> None:
    """Routing in ``tangent_corner`` mode leaves all visibility-build counters at zero."""
    obstacle = Polygon([(30, 10), (50, 10), (50, 40), (30, 40)])
    basic_evaluator.collision_engine.add_static_obstacle(obstacle)
    basic_evaluator.danger_map.precompute([obstacle])
    search_settings = {
        "bend_radii": (10.0,),
        "sbend_radii": (),
        "max_straight_length": 150.0,
        "visibility_guidance": "tangent_corner",
    }
    context = _build_context(basic_evaluator, bounds=BOUNDS, **search_settings)

    path = _route(context, Port(0, 0, 0), Port(80, 50, 0))

    assert path is not None
    metrics = context.metrics
    assert metrics.total_visibility_builds == 0
    assert metrics.total_visibility_corner_pairs_checked == 0
    assert metrics.total_ray_cast_calls_visibility_build == 0
def test_route_astar_repeated_searches_succeed_with_small_cache_limit(basic_evaluator: CostEvaluator) -> None:
    """Six successive straight routes still reach their targets with ``max_cache_size=2``."""
    options = _build_options(
        min_straight_length=1.0,
        max_straight_length=100.0,
    )
    context = AStarContext(
        basic_evaluator,
        RoutingProblem(bounds=BOUNDS),
        options,
        max_cache_size=2,
    )
    start = Port(0, 0, 0)

    # Targets at x = 10, 20, ..., 60 on the same horizontal line as the start.
    for target in (Port(length, 0, 0) for length in range(10, 70, 10)):
        path = _route(context, start, target)
        assert path is not None
        assert path[-1].end_port == target
def test_self_collision_prunes_before_congestion_check(basic_evaluator: CostEvaluator) -> None:
    """A self-colliding candidate is dropped without touching the congestion counters.

    The parent node's own component is re-submitted as the candidate move with
    ``self_collision_check=True``; nothing may be enqueued and no congestion
    check or cache miss may be recorded.
    """
    context = _build_context(basic_evaluator, bounds=BOUNDS)
    root = AStarNode(Port(0, 0, 0), 0.0, 0.0)
    parent_result = Straight.generate(Port(0, 0, 0), 10.0, width=2.0, dilation=1.0)
    parent = AStarNode(
        parent_result.end_port,
        g_cost=10.0,
        h_cost=0.0,
        parent=root,
        component_result=parent_result,
    )
    open_set: list[AStarNode] = []
    closed_set: dict[tuple[int, int, int], float] = {}
    # One independent, empty dict per congestion-related cache argument.
    empty_caches = {
        cache_name: {}
        for cache_name in (
            "congestion_cache",
            "congestion_presence_cache",
            "congestion_candidate_precheck_cache",
            "congestion_net_envelope_cache",
            "congestion_grid_net_cache",
            "congestion_grid_span_cache",
        )
    }

    add_node(
        parent,
        parent_result,
        target=Port(20, 0, 0),
        net_width=2.0,
        net_id="netA",
        open_set=open_set,
        closed_set=closed_set,
        context=context,
        metrics=context.metrics,
        config=SearchRunConfig.from_options(context.options, self_collision_check=True),
        move_type="straight",
        cache_key=("overlap",),
        **empty_caches,
    )

    assert not open_set
    assert context.metrics.total_congestion_check_calls == 0
    assert context.metrics.total_congestion_cache_misses == 0
def test_closed_set_dominance_prunes_before_congestion_check(basic_evaluator: CostEvaluator) -> None:
    """A candidate already recorded in the closed set is dropped before congestion checks.

    The closed set is pre-seeded with the candidate's end state at the cost
    ``score_component`` assigns to the move; nothing may be enqueued and the
    congestion counters must stay at zero.
    """
    context = _build_context(basic_evaluator, bounds=BOUNDS)
    root = AStarNode(Port(0, 0, 0), 0.0, 0.0)
    result = Straight.generate(Port(0, 0, 0), 10.0, width=2.0, dilation=1.0)
    open_set: list[AStarNode] = []

    end_state = result.end_port.as_tuple()
    recorded_cost = context.cost_evaluator.score_component(result, start_port=root.port)
    closed_set = {end_state: recorded_cost}

    # One independent, empty dict per congestion-related cache argument.
    empty_caches = {
        cache_name: {}
        for cache_name in (
            "congestion_cache",
            "congestion_presence_cache",
            "congestion_candidate_precheck_cache",
            "congestion_net_envelope_cache",
            "congestion_grid_net_cache",
            "congestion_grid_span_cache",
        )
    }

    add_node(
        root,
        result,
        target=Port(20, 0, 0),
        net_width=2.0,
        net_id="netA",
        open_set=open_set,
        closed_set=closed_set,
        context=context,
        metrics=context.metrics,
        config=SearchRunConfig.from_options(context.options),
        move_type="straight",
        cache_key=("dominated",),
        **empty_caches,
    )

    assert not open_set
    assert context.metrics.total_congestion_check_calls == 0
    assert context.metrics.total_congestion_cache_misses == 0
def test_no_dynamic_paths_skips_congestion_check(basic_evaluator: CostEvaluator) -> None:
    """In a world with no other routed paths, a node is admitted without congestion work.

    The candidate must land in the open set while the congestion check and
    cache-miss counters remain zero.
    """
    context = _build_context(basic_evaluator, bounds=BOUNDS)
    root = AStarNode(Port(0, 0, 0), 0.0, 0.0)
    result = Straight.generate(Port(0, 0, 0), 10.0, width=2.0, dilation=1.0)
    open_set: list[AStarNode] = []
    closed_set: dict[tuple[int, int, int], float] = {}
    # One independent, empty dict per congestion-related cache argument.
    empty_caches = {
        cache_name: {}
        for cache_name in (
            "congestion_cache",
            "congestion_presence_cache",
            "congestion_candidate_precheck_cache",
            "congestion_net_envelope_cache",
            "congestion_grid_net_cache",
            "congestion_grid_span_cache",
        )
    }

    add_node(
        root,
        result,
        target=Port(20, 0, 0),
        net_width=2.0,
        net_id="netA",
        open_set=open_set,
        closed_set=closed_set,
        context=context,
        metrics=context.metrics,
        config=SearchRunConfig.from_options(context.options),
        move_type="straight",
        cache_key=("no-dynamic",),
        **empty_caches,
    )

    assert open_set
    assert context.metrics.total_congestion_check_calls == 0
    assert context.metrics.total_congestion_cache_misses == 0
def test_guidance_seed_matching_move_reduces_cost_and_advances_seed_index(
    basic_evaluator: CostEvaluator,
) -> None:
    """A straight that matches the guidance seed is cheaper and advances ``seed_index``.

    The same 10-long straight is admitted twice: once with a one-element
    ``StraightSeed`` guidance and a bonus of 5.0, once without guidance. The
    guided node must have ``seed_index == 1`` and a lower ``g_cost``, and only
    the straight-move guidance counters may be incremented.
    """
    cache_names = (
        "congestion_cache",
        "congestion_presence_cache",
        "congestion_candidate_precheck_cache",
        "congestion_net_envelope_cache",
        "congestion_grid_net_cache",
        "congestion_grid_span_cache",
    )
    context = _build_context(basic_evaluator, bounds=BOUNDS)
    root = AStarNode(Port(0, 0, 0), 0.0, 0.0, seed_index=0)
    result = Straight.generate(Port(0, 0, 0), 10.0, width=2.0, dilation=1.0)
    open_set: list[AStarNode] = []
    unguided_open_set: list[AStarNode] = []
    closed_set: dict[tuple[int, int, int], float] = {}

    # Guided admission: the seed describes exactly the move being added.
    add_node(
        root,
        result,
        target=Port(20, 0, 0),
        net_width=2.0,
        net_id="netA",
        open_set=open_set,
        closed_set=closed_set,
        context=context,
        metrics=context.metrics,
        config=SearchRunConfig.from_options(
            context.options,
            guidance_seed=(StraightSeed(length=10.0),),
            guidance_bonus=5.0,
        ),
        move_type="straight",
        cache_key=("guided",),
        **{cache_name: {} for cache_name in cache_names},
    )
    # Unguided admission of the same move from a fresh root, for comparison.
    add_node(
        AStarNode(Port(0, 0, 0), 0.0, 0.0),
        result,
        target=Port(20, 0, 0),
        net_width=2.0,
        net_id="netA",
        open_set=unguided_open_set,
        closed_set={},
        context=context,
        metrics=context.metrics,
        config=SearchRunConfig.from_options(context.options),
        move_type="straight",
        cache_key=("unguided",),
        **{cache_name: {} for cache_name in cache_names},
    )

    assert open_set
    assert unguided_open_set
    guided_node = open_set[0]
    unguided_node = unguided_open_set[0]
    assert guided_node.seed_index == 1
    assert guided_node.g_cost < unguided_node.g_cost

    metrics = context.metrics
    assert metrics.total_guidance_match_moves == 1
    assert metrics.total_guidance_match_moves_straight == 1
    assert metrics.total_guidance_match_moves_bend90 == 0
    assert metrics.total_guidance_match_moves_sbend == 0
    assert metrics.total_guidance_bonus_applied == pytest.approx(5.0)
    assert metrics.total_guidance_bonus_applied_straight == pytest.approx(5.0)
    assert metrics.total_guidance_bonus_applied_bend90 == pytest.approx(0.0)
    assert metrics.total_guidance_bonus_applied_sbend == pytest.approx(0.0)
def test_guidance_seed_bend90_keeps_full_bonus(
    basic_evaluator: CostEvaluator,
) -> None:
    """A 90-degree bend that matches the guidance seed receives the whole 5.0 bonus.

    The same CCW bend is admitted with and without guidance; the ``g_cost``
    difference must equal the bonus, ``seed_index`` must advance to 1, and
    only the bend90 guidance counters may be incremented.
    """
    cache_names = (
        "congestion_cache",
        "congestion_presence_cache",
        "congestion_candidate_precheck_cache",
        "congestion_net_envelope_cache",
        "congestion_grid_net_cache",
        "congestion_grid_span_cache",
    )
    context = _build_context(basic_evaluator, bounds=BOUNDS)
    root = AStarNode(Port(0, 0, 0), 0.0, 0.0, seed_index=0)
    result = Bend90.generate(Port(0, 0, 0), 10.0, width=2.0, direction="CCW", dilation=1.0)
    open_set: list[AStarNode] = []
    unguided_open_set: list[AStarNode] = []

    # Guided admission: the seed is the bend's own move spec.
    add_node(
        root,
        result,
        target=Port(10, 10, 90),
        net_width=2.0,
        net_id="netA",
        open_set=open_set,
        closed_set={},
        context=context,
        metrics=context.metrics,
        config=SearchRunConfig.from_options(
            context.options,
            guidance_seed=(result.move_spec,),
            guidance_bonus=5.0,
        ),
        move_type="bend90",
        cache_key=("guided-bend90",),
        **{cache_name: {} for cache_name in cache_names},
    )
    # Unguided admission of the same bend from a fresh root, for comparison.
    add_node(
        AStarNode(Port(0, 0, 0), 0.0, 0.0),
        result,
        target=Port(10, 10, 90),
        net_width=2.0,
        net_id="netA",
        open_set=unguided_open_set,
        closed_set={},
        context=context,
        metrics=context.metrics,
        config=SearchRunConfig.from_options(context.options),
        move_type="bend90",
        cache_key=("unguided-bend90",),
        **{cache_name: {} for cache_name in cache_names},
    )

    assert open_set
    assert unguided_open_set
    guided_node = open_set[0]
    unguided_node = unguided_open_set[0]
    assert guided_node.seed_index == 1
    cost_saving = unguided_node.g_cost - guided_node.g_cost
    assert cost_saving == pytest.approx(5.0)

    metrics = context.metrics
    assert metrics.total_guidance_match_moves == 1
    assert metrics.total_guidance_match_moves_straight == 0
    assert metrics.total_guidance_match_moves_bend90 == 1
    assert metrics.total_guidance_match_moves_sbend == 0
    assert metrics.total_guidance_bonus_applied == pytest.approx(5.0)
    assert metrics.total_guidance_bonus_applied_straight == pytest.approx(0.0)
    assert metrics.total_guidance_bonus_applied_bend90 == pytest.approx(5.0)
    assert metrics.total_guidance_bonus_applied_sbend == pytest.approx(0.0)