inire/scripts/characterize_pair_local_search.py

177 lines
6.1 KiB
Python

#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
from dataclasses import asdict
from datetime import datetime
from pathlib import Path
from time import perf_counter
from inire.tests.example_scenarios import _run_example_07_variant
def _parse_csv_ints(raw: str) -> tuple[int, ...]:
return tuple(int(part) for part in raw.split(",") if part.strip())
def _run_case(num_nets: int, seed: int) -> dict[str, object]:
    """Run one example_07 variant with warm start disabled and summarize it.

    Returns a dict with:
      - ``duration_s``: wall-clock seconds spent in ``_run_example_07_variant``
        (measured with ``perf_counter``),
      - ``summary``: counts of per-net results, of results flagged ``is_valid``,
        and of results flagged ``reached_target``,
      - ``metrics``: ``run.metrics`` converted with ``dataclasses.asdict``.
    """
    started = perf_counter()
    run = _run_example_07_variant(num_nets=num_nets, seed=seed, warm_start_enabled=False)
    elapsed = perf_counter() - started

    per_net = run.results_by_net.values()
    summary = {
        "total_results": len(run.results_by_net),
        "valid_results": sum(1 for outcome in per_net if outcome.is_valid),
        "reached_targets": sum(1 for outcome in per_net if outcome.reached_target),
    }
    return {"duration_s": elapsed, "summary": summary, "metrics": asdict(run.metrics)}
def _is_smoke_candidate(entry: dict[str, object]) -> bool:
summary = entry["summary"]
metrics = entry["metrics"]
return (
summary["valid_results"] == summary["total_results"]
and metrics["pair_local_search_accepts"] >= 1
and entry["duration_s"] <= 1.0
)
def _select_smoke_case(cases: list[dict[str, object]]) -> dict[str, object] | None:
grouped: dict[tuple[int, int], list[dict[str, object]]] = {}
for case in cases:
key = (case["num_nets"], case["seed"])
grouped.setdefault(key, []).append(case)
candidates = []
for (num_nets, seed), repeats in grouped.items():
if repeats and all(_is_smoke_candidate(repeat) for repeat in repeats):
candidates.append({"num_nets": num_nets, "seed": seed})
if not candidates:
return None
candidates.sort(key=lambda item: (item["num_nets"], item["seed"]))
return candidates[0]
def _render_markdown(payload: dict[str, object]) -> str:
lines = [
"# Pair-Local Search Characterization",
"",
f"Generated at {payload['generated_at']} by `{payload['generator']}`.",
"",
f"Grid: `num_nets={payload['grid']['num_nets']}`, `seed={payload['grid']['seeds']}`, repeats={payload['grid']['repeats']}.",
"",
"| Nets | Seed | Repeat | Duration (s) | Valid | Reached | Pair Pairs | Pair Accepts | Pair Nodes | Nodes | Checks |",
"| :-- | :-- | :-- | --: | --: | --: | --: | --: | --: | --: | --: |",
]
for case in payload["cases"]:
summary = case["summary"]
metrics = case["metrics"]
lines.append(
"| "
f"{case['num_nets']} | "
f"{case['seed']} | "
f"{case['repeat']} | "
f"{case['duration_s']:.4f} | "
f"{summary['valid_results']} | "
f"{summary['reached_targets']} | "
f"{metrics['pair_local_search_pairs_considered']} | "
f"{metrics['pair_local_search_accepts']} | "
f"{metrics['pair_local_search_nodes_expanded']} | "
f"{metrics['nodes_expanded']} | "
f"{metrics['congestion_check_calls']} |"
)
lines.extend(["", "## Recommendation", ""])
recommended = payload["recommended_smoke_scenario"]
if recommended is None:
lines.append(
"No smaller stable pair-local smoke scenario satisfied the rule "
"`valid_results == total_results`, `pair_local_search_accepts >= 1`, and `duration_s <= 1.0` across all repeats."
)
else:
lines.append(
f"Recommended smoke scenario: `num_nets={recommended['num_nets']}`, `seed={recommended['seed']}`."
)
return "\n".join(lines)
def _display_path(path: Path, repo_root: Path) -> Path:
    """Return *path* relative to *repo_root* when it lies inside it, else *path* unchanged."""
    return path.relative_to(repo_root) if path.is_relative_to(repo_root) else path


def main() -> None:
    """Sweep example_07-style no-warm-start runs and write JSON and Markdown reports.

    Every ``(num_nets, seed)`` combination from the comma-separated ``--num-nets``
    and ``--seeds`` lists is run ``--repeats`` times via ``_run_case``. The results
    and the recommended smoke scenario (from ``_select_smoke_case``) are written to
    ``pair_local_characterization.json`` and ``pair_local_characterization.md``.
    The files go to ``--output-dir``, which defaults to ``<repo>/docs``.

    Invalid CLI input exits through ``argparse``'s usage error (exit status 2).
    This covers non-integer list items, empty lists and ``--repeats`` below 1.
    """
    parser = argparse.ArgumentParser(description="Characterize pair-local search across example_07-style no-warm runs.")
    parser.add_argument(
        "--num-nets",
        default="6,8,10",
        help="Comma-separated num_nets values to sweep. Default: 6,8,10.",
    )
    parser.add_argument(
        "--seeds",
        default="41,42,43",
        help="Comma-separated seed values to sweep. Default: 41,42,43.",
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=2,
        help="Number of repeated runs per (num_nets, seed). Default: 2.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Directory to write pair_local_characterization.json and .md into. Defaults to <repo>/docs.",
    )
    args = parser.parse_args()

    # Validate the sweep before touching the filesystem. A typo then produces a
    # usage message rather than a traceback, and an empty grid or
    # --repeats <= 0 no longer silently writes an empty report.
    try:
        num_nets_values = _parse_csv_ints(args.num_nets)
        seed_values = _parse_csv_ints(args.seeds)
    except ValueError as exc:
        parser.error(f"--num-nets and --seeds must be comma-separated integers ({exc})")
    if not num_nets_values:
        parser.error("--num-nets must list at least one value")
    if not seed_values:
        parser.error("--seeds must list at least one value")
    if args.repeats < 1:
        parser.error("--repeats must be >= 1")

    repo_root = Path(__file__).resolve().parents[1]
    output_dir = repo_root / "docs" if args.output_dir is None else args.output_dir.resolve()
    # parents=True so a nested --output-dir that does not exist yet is created too.
    output_dir.mkdir(parents=True, exist_ok=True)

    cases: list[dict[str, object]] = []
    for num_nets in num_nets_values:
        for seed in seed_values:
            for repeat in range(args.repeats):
                case = _run_case(num_nets, seed)
                case["num_nets"] = num_nets
                case["seed"] = seed
                case["repeat"] = repeat
                cases.append(case)

    payload = {
        "generated_at": datetime.now().astimezone().isoformat(timespec="seconds"),
        "generator": "scripts/characterize_pair_local_search.py",
        "grid": {
            "num_nets": list(num_nets_values),
            "seeds": list(seed_values),
            "repeats": args.repeats,
        },
        "cases": cases,
        "recommended_smoke_scenario": _select_smoke_case(cases),
    }

    json_path = output_dir / "pair_local_characterization.json"
    markdown_path = output_dir / "pair_local_characterization.md"
    json_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
    markdown_path.write_text(_render_markdown(payload) + "\n", encoding="utf-8")
    for written in (json_path, markdown_path):
        print(f"Wrote {_display_path(written, repo_root)}")


if __name__ == "__main__":
    main()