rrt/tools/ghidra/scripts/ExportStartupFunctions.java

386 lines
11 KiB
Java
Raw Normal View History

// Export startup-oriented function metadata from the current Ghidra program.
//@category RT3
import ghidra.app.script.GhidraScript;
import ghidra.program.model.address.Address;
import ghidra.program.model.listing.Function;
import ghidra.program.model.listing.Instruction;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Ghidra script that walks the internal call graph breadth-first from one or
 * more root functions and writes the result as a CSV table
 * ({@code ghidra-startup-functions.csv}) and a Markdown call tree
 * ({@code startup-call-chain.md}).
 *
 * <p>Script arguments: {@code <output-dir> [max-depth] [root-name:address ...]}.
 * When no root specs are given a single default root is used; when no depth is
 * given the walk stops at {@link #DEFAULT_MAX_DEPTH}.
 */
public class ExportStartupFunctions extends GhidraScript {

    /** Depth limit applied when the second script argument is absent. */
    private static final int DEFAULT_MAX_DEPTH = 2;

    /** Name of the root used when no {@code name:address} specs are supplied. */
    private static final String DEFAULT_ROOT_NAME = "entry";

    /** Address of the root used when no {@code name:address} specs are supplied. */
    private static final String DEFAULT_ROOT_ADDRESS = "0x005a313b";

    /** Column header line of the CSV export. */
    private static final String CSV_HEADER =
        "root_name,root_address,address,depth,name,size_bytes,calling_convention,signature_source,signature,parent_address,call_site\n";

    /**
     * Indentation emitted per nesting level in the Markdown tree. Two spaces is
     * the minimum that makes a {@code "- "} item nest under its parent in
     * CommonMark; a single space would render as a sibling.
     */
    private static final String MARKDOWN_INDENT = "  ";

    /** One pending breadth-first visit: the function plus how it was reached. */
    private static class QueueEntry {
        final Function function;
        final int depth;
        final String parentAddress;
        final String callSite;

        QueueEntry(Function function, int depth, String parentAddress, String callSite) {
            this.function = function;
            this.depth = depth;
            this.parentAddress = parentAddress;
            this.callSite = callSite;
        }
    }

    /** A root requested on the command line, exactly as the user spelled it. */
    private static class RootSpec {
        final String name;
        final String address;

        RootSpec(String name, String address) {
            this.name = name;
            this.address = address;
        }
    }

    /** One exported function; corresponds to one CSV line. */
    private static class FunctionRow {
        String rootName;
        String rootAddress;
        String address;
        int depth;
        String name;
        long sizeBytes;
        String callingConvention;
        String signatureSource;
        String signature;
        String parentAddress;
        String callSite;
    }

    /** The parent-to-child call through which a function was first reached. */
    private static class EdgeRow {
        String rootName;
        String rootAddress;
        String parentAddress;
        String childAddress;
        String callSite;
    }

    /** A resolved call: the calling instruction's address and the callee. */
    private static class CallTarget {
        final String callSite;
        final Function callee;

        CallTarget(String callSite, Function callee) {
            this.callSite = callSite;
            this.callee = callee;
        }
    }

    /**
     * Script entry point: parses the arguments, collects rows and edges for
     * every root, sorts them and writes both output files.
     *
     * @throws RuntimeException if the output directory argument is missing
     * @throws IllegalArgumentException if the depth or a root spec is malformed,
     *     or a root address does not resolve to an internal function
     * @throws IOException if the output directory or files cannot be written
     */
    @Override
    protected void run() throws Exception {
        String[] args = getScriptArgs();
        if (args.length < 1) {
            throw new RuntimeException(
                "usage: ExportStartupFunctions.java <output-dir> [max-depth] [root-name:address ...]");
        }
        String outputDir = args[0];
        int maxDepth = args.length > 1 ? parseMaxDepth(args[1]) : DEFAULT_MAX_DEPTH;
        List<RootSpec> roots = parseRoots(args);

        File output = new File(outputDir);
        output.mkdirs();
        // mkdirs() also returns false when the directory already exists, so
        // test the outcome rather than the return value.
        if (!output.isDirectory()) {
            throw new IOException("cannot create output directory: " + output);
        }

        List<FunctionRow> rows = new ArrayList<>();
        List<EdgeRow> edges = new ArrayList<>();
        for (RootSpec root : roots) {
            collectRoot(rows, edges, root, maxDepth);
        }

        rows.sort(Comparator
            .comparing((FunctionRow row) -> row.rootName)
            .thenComparing(row -> row.rootAddress)
            .thenComparingInt(row -> row.depth)
            .thenComparing(row -> row.address));
        edges.sort(Comparator
            .comparing((EdgeRow row) -> row.rootName)
            .thenComparing(row -> row.rootAddress)
            .thenComparing(row -> row.parentAddress)
            .thenComparing(row -> row.callSite)
            .thenComparing(row -> row.childAddress));

        writeCsv(new File(output, "ghidra-startup-functions.csv"), rows);
        writeMarkdown(new File(output, "startup-call-chain.md"), rows, edges, roots, maxDepth);
    }

    /**
     * Parses the depth argument, attaching the offending text to the error.
     *
     * @throws IllegalArgumentException if {@code text} is not a decimal integer
     */
    private static int parseMaxDepth(String text) {
        try {
            return Integer.parseInt(text);
        }
        catch (NumberFormatException e) {
            throw new IllegalArgumentException("max-depth must be an integer: " + text, e);
        }
    }

    /**
     * Breadth-first walk from one root, appending a row for every function
     * reached within {@code maxDepth} calls and an edge for every non-root row.
     * Each function is recorded once per root, via the first queue entry that
     * reaches it.
     *
     * @throws IllegalArgumentException if the root address has no internal function
     */
    private void collectRoot(
            List<FunctionRow> rows,
            List<EdgeRow> edges,
            RootSpec root,
            int maxDepth) {
        Function rootFunction = getInternalFunction(toAddr(root.address));
        if (rootFunction == null) {
            throw new IllegalArgumentException("no function found at " + root.address);
        }

        ArrayDeque<QueueEntry> queue = new ArrayDeque<>();
        Set<String> seen = new HashSet<>();
        queue.add(new QueueEntry(rootFunction, 0, "", ""));

        // Honour cancellation here as well as in iterInternalCalls.
        while (!queue.isEmpty() && !monitor.isCancelled()) {
            QueueEntry item = queue.removeFirst();
            Function function = item.function;
            String address = asAddress(function.getEntryPoint());
            if (!seen.add(address)) {
                continue;
            }

            FunctionRow row = new FunctionRow();
            row.rootName = root.name;
            row.rootAddress = root.address;
            row.address = address;
            row.depth = item.depth;
            row.name = function.getName();
            row.sizeBytes = function.getBody().getNumAddresses();
            row.callingConvention = safeString(function.getCallingConventionName(), "unknown");
            row.signatureSource = safeSignatureSource(function);
            row.signature = safeSignature(function);
            row.parentAddress = item.parentAddress;
            row.callSite = item.callSite;
            rows.add(row);

            // Only the root is queued with an empty parent.
            if (!item.parentAddress.isEmpty()) {
                EdgeRow edge = new EdgeRow();
                edge.rootName = root.name;
                edge.rootAddress = root.address;
                edge.parentAddress = item.parentAddress;
                edge.childAddress = address;
                edge.callSite = item.callSite;
                edges.add(edge);
            }

            if (item.depth >= maxDepth) {
                continue;
            }
            for (CallTarget callTarget : iterInternalCalls(function)) {
                queue.add(new QueueEntry(
                    callTarget.callee,
                    item.depth + 1,
                    address,
                    callTarget.callSite));
            }
        }
    }

    /**
     * Reads root specs from the third argument onward, or returns the single
     * default root when none are given.
     *
     * @throws IllegalArgumentException if a spec is not {@code name:address}
     *     with both parts non-empty
     */
    private List<RootSpec> parseRoots(String[] args) {
        List<RootSpec> roots = new ArrayList<>();
        if (args.length <= 2) {
            roots.add(new RootSpec(DEFAULT_ROOT_NAME, DEFAULT_ROOT_ADDRESS));
            return roots;
        }
        for (int index = 2; index < args.length; index++) {
            String token = args[index];
            // Split on the first ':' only; the remainder is the address text.
            int separator = token.indexOf(':');
            if (separator <= 0 || separator == token.length() - 1) {
                throw new IllegalArgumentException("root spec must be name:address: " + token);
            }
            roots.add(new RootSpec(token.substring(0, separator), token.substring(separator + 1)));
        }
        return roots;
    }

    /**
     * Resolves an address to the function starting at it, falling back to the
     * function containing it.
     *
     * @return the function, or {@code null} if the address is {@code null},
     *     no function is found, or the function is external
     */
    private Function getInternalFunction(Address address) {
        if (address == null) {
            // Lets callers report an unparsable address as "no function found"
            // instead of passing null into the Ghidra lookup.
            return null;
        }
        Function function = getFunctionAt(address);
        if (function == null) {
            function = getFunctionContaining(address);
        }
        if (function != null && function.isExternal()) {
            return null;
        }
        return function;
    }

    /**
     * Scans the call instructions in a function's body and returns each
     * distinct (call site, internal callee) pair in iteration order. Stops
     * early, returning what was gathered so far, if the monitor is cancelled.
     */
    private List<CallTarget> iterInternalCalls(Function function) {
        List<CallTarget> result = new ArrayList<>();
        Set<String> seenCalls = new HashSet<>();
        var instructions = currentProgram.getListing().getInstructions(function.getBody(), true);
        while (instructions.hasNext() && !monitor.isCancelled()) {
            Instruction instruction = instructions.next();
            if (!instruction.getFlowType().isCall()) {
                continue;
            }
            String callSite = asAddress(instruction.getAddress());
            for (Address flow : instruction.getFlows()) {
                Function callee = getInternalFunction(flow);
                if (callee == null) {
                    continue;
                }
                // Several flows of one instruction may resolve to the same callee.
                String key = callSite + "->" + asAddress(callee.getEntryPoint());
                if (seenCalls.add(key)) {
                    result.add(new CallTarget(callSite, callee));
                }
            }
        }
        return result;
    }

    /**
     * Writes the header line and one fully quoted CSV line per row, encoded as
     * UTF-8. An existing file at {@code path} is overwritten.
     *
     * @throws IOException if the file cannot be opened or written
     */
    private void writeCsv(File path, List<FunctionRow> rows) throws IOException {
        try (BufferedWriter writer =
                new BufferedWriter(new FileWriter(path, StandardCharsets.UTF_8))) {
            writer.write(CSV_HEADER);
            for (FunctionRow row : rows) {
                writer.write(String.join(",",
                    csv(row.rootName),
                    csv(row.rootAddress),
                    csv(row.address),
                    csv(Integer.toString(row.depth)),
                    csv(row.name),
                    csv(Long.toString(row.sizeBytes)),
                    csv(row.callingConvention),
                    csv(row.signatureSource),
                    csv(row.signature),
                    csv(row.parentAddress),
                    csv(row.callSite)));
                writer.write("\n");
            }
        }
    }

    /**
     * Writes the Markdown report: a short preamble, then one section per root
     * containing a nested bullet list of the functions reached from it. The
     * file is encoded as UTF-8 and overwrites any existing file.
     *
     * @throws IOException if the file cannot be opened or written
     */
    private void writeMarkdown(
            File path,
            List<FunctionRow> rows,
            List<EdgeRow> edges,
            List<RootSpec> roots,
            int maxDepth) throws IOException {
        Map<String, FunctionRow> byAddress = new HashMap<>();
        Map<String, FunctionRow> rootRows = new HashMap<>();
        Map<String, List<EdgeRow>> children = new HashMap<>();

        for (FunctionRow row : rows) {
            byAddress.put(keyedAddress(row.rootName, row.rootAddress, row.address), row);
            if (row.depth == 0) {
                // Index the root row by its spec alone. Row addresses are
                // entry points rendered by asAddress, which need not match the
                // user's spelling of the root address (for example when the
                // spec points inside the function), so the root must not be
                // looked up by comparing those two strings.
                rootRows.put(keyedAddress(row.rootName, row.rootAddress, ""), row);
            }
        }
        for (EdgeRow edge : edges) {
            children.computeIfAbsent(
                keyedAddress(edge.rootName, edge.rootAddress, edge.parentAddress),
                ignored -> new ArrayList<>()).add(edge);
        }
        for (List<EdgeRow> childList : children.values()) {
            childList.sort(Comparator.comparing((EdgeRow edge) -> edge.childAddress));
        }

        try (BufferedWriter writer =
                new BufferedWriter(new FileWriter(path, StandardCharsets.UTF_8))) {
            writer.write("# Startup Call Chain\n\n");
            writer.write("- Depth limit: `" + maxDepth + "`\n");
            writer.write("- Internal call targets only; imported APIs are intentionally excluded.\n\n");
            for (RootSpec root : roots) {
                FunctionRow row = rootRows.get(keyedAddress(root.name, root.address, ""));
                if (row == null) {
                    continue;
                }
                writer.write("## `" + root.name + "` root `" + row.address + "` `" + row.name + "`\n\n");
                writer.write("- `" + row.address + "` `" + row.name + "`\n");
                Set<String> visited = new HashSet<>();
                visited.add(keyedAddress(root.name, root.address, row.address));
                emitChildren(writer, root.name, root.address, row.address, 1, children, byAddress, visited);
                writer.write("\n");
            }
        }
    }

    /**
     * Recursively writes the bullet lines for the children of {@code address},
     * one nesting level deeper per recursion. Children without a row, or
     * already present in {@code visited}, are skipped.
     *
     * @throws IOException if writing fails
     */
    private void emitChildren(
            BufferedWriter writer,
            String rootName,
            String rootAddress,
            String address,
            int indent,
            Map<String, List<EdgeRow>> children,
            Map<String, FunctionRow> byAddress,
            Set<String> visited) throws IOException {
        List<EdgeRow> childRows = children.get(keyedAddress(rootName, rootAddress, address));
        if (childRows == null) {
            return;
        }
        for (EdgeRow edge : childRows) {
            String childKey = keyedAddress(edge.rootName, edge.rootAddress, edge.childAddress);
            FunctionRow child = byAddress.get(childKey);
            if (child == null || !visited.add(childKey)) {
                continue;
            }
            writer.write(MARKDOWN_INDENT.repeat(indent));
            writer.write("- `" + child.address + "` `" + child.name + "` via `" + edge.callSite + "`\n");
            emitChildren(
                writer,
                edge.rootName,
                edge.rootAddress,
                child.address,
                indent + 1,
                children,
                byAddress,
                visited);
        }
    }

    /** Builds the map key that scopes an address to one root spec. */
    private static String keyedAddress(String rootName, String rootAddress, String address) {
        return rootName + "|" + rootAddress + "|" + address;
    }

    /** Returns the function's prototype string, or {@code ""} if Ghidra throws. */
    private String safeSignature(Function function) {
        try {
            return function.getPrototypeString(true, false);
        }
        catch (Exception ignored) {
            // Best effort: a missing signature must not abort the export.
            return "";
        }
    }

    /**
     * Returns the signature source rendered as text, or {@code ""} if it is
     * unavailable or Ghidra throws.
     */
    private String safeSignatureSource(Function function) {
        try {
            Object source = function.getSignatureSource();
            return source == null ? "" : source.toString();
        }
        catch (Exception ignored) {
            // Best effort: a missing source must not abort the export.
            return "";
        }
    }

    /** Renders an address as {@code "0x"} followed by its {@code toString()} text. */
    private static String asAddress(Address address) {
        return "0x" + address.toString();
    }

    /** Returns {@code value}, or {@code fallback} when it is null or empty. */
    private static String safeString(String value, String fallback) {
        if (value == null || value.isEmpty()) {
            return fallback;
        }
        return value;
    }

    /**
     * Quotes a value for CSV: wraps it in double quotes and doubles any
     * embedded double quote. {@code null} is written as an empty quoted field.
     */
    private static String csv(String value) {
        String text = value == null ? "" : value;
        return "\"" + text.replace("\"", "\"\"") + "\"";
    }
}