#!/usr/bin/env python3 from __future__ import annotations import argparse import csv import re from collections import deque from dataclasses import dataclass from pathlib import Path ADDRESS_RE = re.compile(r"0x[0-9a-fA-F]{8}") @dataclass(frozen=True) class Row: address: int size: str name: str subsystem: str calling_convention: str prototype_status: str source_tool: str confidence: str notes: str verified_against: str def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Export a bounded function subgraph from function-map.csv notes." ) parser.add_argument("function_map", type=Path) parser.add_argument("output_prefix", type=Path) parser.add_argument( "--seed", action="append", default=[], help="Seed function address in hex. May be repeated.", ) parser.add_argument( "--depth", type=int, default=2, help="Traversal depth over note-address references.", ) parser.add_argument( "--title", default="Function Subgraph", help="Title used in the markdown summary.", ) parser.add_argument( "--include-backrefs", action="store_true", help="Also traverse rows that reference currently included nodes.", ) args = parser.parse_args() if not args.seed: parser.error("at least one --seed is required") return args def parse_hex(text: str) -> int: value = text.strip().lower() if value.startswith("0x"): value = value[2:] return int(value, 16) def fmt_addr(value: int) -> str: return f"0x{value:08x}" def load_rows(path: Path) -> dict[int, Row]: with path.open(newline="", encoding="utf-8") as handle: reader = csv.DictReader(handle) rows: dict[int, Row] = {} for raw in reader: address = parse_hex(raw["address"]) rows[address] = Row( address=address, size=raw["size"], name=raw["name"], subsystem=raw["subsystem"], calling_convention=raw["calling_convention"], prototype_status=raw["prototype_status"], source_tool=raw["source_tool"], confidence=raw["confidence"], notes=raw["notes"], verified_against=raw["verified_against"], ) return rows def extract_note_refs(rows: dict[int, Row]) -> dict[int, set[int]]: refs: dict[int, set[int]] = {} known = set(rows) for address, row in rows.items(): hits = {parse_hex(match.group(0)) for match in ADDRESS_RE.finditer(row.notes)} refs[address] = {hit for hit in hits if hit in known and hit != address} return refs def build_backrefs(refs: dict[int, set[int]]) -> dict[int, set[int]]: backrefs: dict[int, set[int]] = {address: set() for address in refs} for src, dsts in refs.items(): for dst in dsts: backrefs.setdefault(dst, set()).add(src) return backrefs def walk_subgraph( rows: dict[int, Row], refs: dict[int, set[int]], seeds: list[int], depth: int, include_backrefs: bool, ) -> set[int]: backrefs = build_backrefs(refs) seen: set[int] = set() queue: deque[tuple[int, int]] = deque((seed, 0) for seed in seeds if seed in rows) while queue: address, level = queue.popleft() if address in seen: continue seen.add(address) if level >= depth: continue for dst in sorted(refs.get(address, ())): if dst not in seen: queue.append((dst, level + 1)) if include_backrefs: for src in sorted(backrefs.get(address, ())): if src not in seen: queue.append((src, level + 1)) return seen def quote_dot(text: str) -> str: return text.replace("\\", "\\\\").replace('"', '\\"') def emit_dot( rows: dict[int, Row], refs: dict[int, set[int]], included: set[int], seeds: set[int], output_path: Path, title: str, ) -> None: subsystems: dict[str, list[Row]] = {} for address in sorted(included): row = rows[address] subsystems.setdefault(row.subsystem, []).append(row) lines: list[str] = [ "digraph shell_load {", ' graph [rankdir=LR, labelloc="t", labeljust="l"];', f' label="{quote_dot(title)}";', ' node [shape=box, style="rounded,filled", fillcolor="#f8f8f8", color="#555555", fontname="Helvetica"];', ' edge [color="#666666", fontname="Helvetica"];', ] for subsystem in sorted(subsystems): cluster_id = subsystem.replace("-", "_") lines.append(f' subgraph cluster_{cluster_id} {{') lines.append(f' label="{quote_dot(subsystem)}";') lines.append(' color="#cccccc";') for row in subsystems[subsystem]: seed_mark = " [seed]" if row.address in seeds else "" label = f"{fmt_addr(row.address)}\\n{row.name}{seed_mark}" fill = "#ffe9a8" if row.address in seeds else "#f8f8f8" lines.append( f' "{fmt_addr(row.address)}" [label="{quote_dot(label)}", fillcolor="{fill}"];' ) lines.append(" }") for src in sorted(included): for dst in sorted(refs.get(src, ())): if dst not in included: continue lines.append(f' "{fmt_addr(src)}" -> "{fmt_addr(dst)}";') lines.append("}") output_path.write_text("\n".join(lines) + "\n", encoding="utf-8") def emit_markdown( rows: dict[int, Row], refs: dict[int, set[int]], included: set[int], seeds: set[int], output_path: Path, title: str, dot_path: Path, ) -> None: included_rows = [rows[address] for address in sorted(included)] edge_count = sum( 1 for src in included for dst in refs.get(src, ()) if dst in included ) lines = [ f"# {title}", "", f"- Nodes: `{len(included_rows)}`", f"- Edges: `{edge_count}`", f"- Seeds: {', '.join(f'`{fmt_addr(seed)}`' for seed in sorted(seeds))}", f"- Graphviz: `{dot_path.name}`", "", "## Nodes", "", "| Address | Name | Subsystem | Confidence |", "| --- | --- | --- | --- |", ] for row in included_rows: lines.append( f"| `{fmt_addr(row.address)}` | `{row.name}` | `{row.subsystem}` | `{row.confidence}` |" ) lines.extend(["", "## Edges", ""]) for src in sorted(included): dsts = [dst for dst in sorted(refs.get(src, ())) if dst in included] if not dsts: continue src_row = rows[src] lines.append(f"- `{fmt_addr(src)}` `{src_row.name}`") for dst in dsts: dst_row = rows[dst] lines.append(f" -> `{fmt_addr(dst)}` `{dst_row.name}`") output_path.write_text("\n".join(lines) + "\n", encoding="utf-8") def main() -> int: args = parse_args() rows = load_rows(args.function_map) refs = extract_note_refs(rows) seeds = [parse_hex(seed) for seed in args.seed] included = walk_subgraph(rows, refs, seeds, args.depth, args.include_backrefs) output_prefix = args.output_prefix.resolve() output_prefix.parent.mkdir(parents=True, exist_ok=True) dot_path = output_prefix.with_suffix(".dot") md_path = output_prefix.with_suffix(".md") emit_dot(rows, refs, included, set(seeds), dot_path, args.title) emit_markdown(rows, refs, included, set(seeds), md_path, args.title, dot_path) return 0 if __name__ == "__main__": raise SystemExit(main())