# bridge_mcp_ghidra.py (19.4 kB)
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "requests>=2,<3",
# "mcp>=1.2.0,<2",
# ]
# ///
import sys
import requests
import argparse
import logging
from urllib.parse import urljoin
from mcp.server.fastmcp import FastMCP
# Default base URL of the Ghidra HTTP server; endpoint paths are resolved against it with urljoin().
DEFAULT_GHIDRA_SERVER: str = "http://127.0.0.1:8080/"
# Default per-request timeout in seconds, passed as `timeout=` to every requests call.
DEFAULT_REQUEST_TIMEOUT: int = 5
logger = logging.getLogger(__name__)
# FastMCP server instance; every @mcp.tool() function in this module is registered on it.
mcp = FastMCP("ghidra-mcp")
# Active Ghidra server URL read by the request helpers; main() may replace it from --ghidra-server.
ghidra_server_url: str = DEFAULT_GHIDRA_SERVER
# Active request timeout (seconds) read by the request helpers; main() may replace it from --ghidra-timeout.
ghidra_request_timeout: float = DEFAULT_REQUEST_TIMEOUT
def safe_get(endpoint: str, params: dict | None = None) -> list[str]:
    """
    Perform a GET request against the Ghidra server with optional query parameters.

    Args:
        endpoint: Path resolved against ghidra_server_url with urljoin().
        params: Query parameters to send; None means no parameters.

    Returns:
        The response body split into lines on success. On an HTTP error status,
        or if anything raises (including a malformed server URL), a
        single-element list holding the error message, so MCP tools can hand
        it straight back to the client.
    """
    if params is None:
        params = {}
    try:
        # urljoin() is inside the try (as in safe_post) so a malformed server URL
        # (e.g. an unbalanced IPv6 bracket) becomes an error line instead of an
        # exception escaping the tool call.
        url = urljoin(ghidra_server_url, endpoint)
        response = requests.get(url, params=params, timeout=ghidra_request_timeout)
        response.encoding = 'utf-8'
        if response.ok:
            return response.text.splitlines()
        return [f"Error {response.status_code}: {response.text.strip()}"]
    except Exception as e:
        # Deliberate tool-boundary catch: report the failure as text rather than crash.
        return [f"Request failed: {e}"]
def safe_post(endpoint: str, data: dict | str) -> str:
    """
    Perform a POST request against the Ghidra server.

    Args:
        endpoint: Path resolved against ghidra_server_url with urljoin().
        data: A dict is sent as form fields; a str is sent as a raw UTF-8 body.

    Returns:
        The stripped response body on success, "Error <status>: <body>" on an
        HTTP error status, or "Request failed: <exception>" if anything raises.
    """
    try:
        target = urljoin(ghidra_server_url, endpoint)
        body = data if isinstance(data, dict) else data.encode("utf-8")
        resp = requests.post(target, data=body, timeout=ghidra_request_timeout)
        resp.encoding = 'utf-8'
        if not resp.ok:
            return f"Error {resp.status_code}: {resp.text.strip()}"
        return resp.text.strip()
    except Exception as exc:
        return f"Request failed: {exc}"
@mcp.tool()
def list_methods(offset: int = 0, limit: int = 100) -> list:
    """
    List all function names in the program with pagination.

    Args:
        offset: Pagination offset (default: 0)
        limit: Maximum number of names to return (default: 100)

    Returns:
        Lines returned by the Ghidra server, or a single error line if the
        request fails.
    """
    page = dict(offset=offset, limit=limit)
    return safe_get("methods", page)
@mcp.tool()
def list_classes(offset: int = 0, limit: int = 100) -> list:
    """
    List all namespace/class names in the program with pagination.

    Args:
        offset: Pagination offset (default: 0)
        limit: Maximum number of names to return (default: 100)

    Returns:
        Lines returned by the Ghidra server, or a single error line if the
        request fails.
    """
    page = dict(offset=offset, limit=limit)
    return safe_get("classes", page)
@mcp.tool()
def decompile_function(name: str) -> str:
    """
    Decompile a specific function by name and return the decompiled C code.

    Args:
        name: Name of the function to decompile

    Returns:
        Decompiled C code, or an error message if the request fails.
    """
    # The name travels as the raw request body, not as a form field.
    return safe_post(endpoint="decompile", data=name)
@mcp.tool()
def rename_function(old_name: str, new_name: str) -> str:
    """
    Rename a function by its current name to a new user-defined name.

    Args:
        old_name: Current name of the function
        new_name: Name to assign

    Returns:
        Server response text, or an error message if the request fails.
    """
    form = dict(oldName=old_name, newName=new_name)
    return safe_post("renameFunction", form)
@mcp.tool()
def rename_data(address: str, new_name: str) -> str:
    """
    Rename a data label at the specified address.

    Args:
        address: Address of the data label
        new_name: Label name to assign

    Returns:
        Server response text, or an error message if the request fails.
    """
    form = dict(address=address, newName=new_name)
    return safe_post("renameData", form)
@mcp.tool()
def list_segments(offset: int = 0, limit: int = 100) -> list:
    """
    List all memory segments in the program with pagination.

    Args:
        offset: Pagination offset (default: 0)
        limit: Maximum number of segments to return (default: 100)

    Returns:
        Lines returned by the Ghidra server, or a single error line if the
        request fails.
    """
    page = dict(offset=offset, limit=limit)
    return safe_get("segments", page)
@mcp.tool()
def list_imports(offset: int = 0, limit: int = 100) -> list:
    """
    List imported symbols in the program with pagination.

    Args:
        offset: Pagination offset (default: 0)
        limit: Maximum number of symbols to return (default: 100)

    Returns:
        Lines returned by the Ghidra server, or a single error line if the
        request fails.
    """
    page = dict(offset=offset, limit=limit)
    return safe_get("imports", page)
@mcp.tool()
def list_exports(offset: int = 0, limit: int = 100) -> list:
    """
    List exported functions/symbols with pagination.

    Args:
        offset: Pagination offset (default: 0)
        limit: Maximum number of symbols to return (default: 100)

    Returns:
        Lines returned by the Ghidra server, or a single error line if the
        request fails.
    """
    page = dict(offset=offset, limit=limit)
    return safe_get("exports", page)
@mcp.tool()
def list_namespaces(offset: int = 0, limit: int = 100) -> list:
    """
    List all non-global namespaces in the program with pagination.

    Args:
        offset: Pagination offset (default: 0)
        limit: Maximum number of namespaces to return (default: 100)

    Returns:
        Lines returned by the Ghidra server, or a single error line if the
        request fails.
    """
    page = dict(offset=offset, limit=limit)
    return safe_get("namespaces", page)
@mcp.tool()
def list_data_items(offset: int = 0, limit: int = 100) -> list:
    """
    List defined data labels and their values with pagination.

    Args:
        offset: Pagination offset (default: 0)
        limit: Maximum number of data items to return (default: 100)

    Returns:
        Lines returned by the Ghidra server, or a single error line if the
        request fails.
    """
    page = dict(offset=offset, limit=limit)
    return safe_get("data", page)
@mcp.tool()
def search_functions_by_name(query: str, offset: int = 0, limit: int = 100) -> list:
    """
    Search for functions whose name contains the given substring.

    Args:
        query: Substring to look for in function names (must be non-empty)
        offset: Pagination offset (default: 0)
        limit: Maximum number of matches to return (default: 100)

    Returns:
        Lines returned by the Ghidra server, or a single error line when the
        query is empty or the request fails.
    """
    if query:
        search = dict(query=query, offset=offset, limit=limit)
        return safe_get("searchFunctions", search)
    return ["Error: query string is required"]
@mcp.tool()
def rename_variable(function_name: str, old_name: str, new_name: str) -> str:
    """
    Rename a local variable within a function.

    Args:
        function_name: Name of the function that contains the variable
        old_name: Current variable name
        new_name: Variable name to assign

    Returns:
        Server response text, or an error message if the request fails.
    """
    form = dict(functionName=function_name, oldName=old_name, newName=new_name)
    return safe_post("renameVariable", form)
@mcp.tool()
def get_function_by_address(address: str) -> str:
    """
    Get a function by its address.

    Args:
        address: Address to look up

    Returns:
        The server's response lines joined with newlines, or an error message
        if the request fails.
    """
    lines = safe_get("get_function_by_address", dict(address=address))
    return "\n".join(lines)
@mcp.tool()
def get_current_address() -> str:
    """
    Get the address currently selected by the user.

    Returns:
        The server's response lines joined with newlines, or an error message
        if the request fails.
    """
    reply = safe_get(endpoint="get_current_address")
    return "\n".join(reply)
@mcp.tool()
def get_current_function() -> str:
    """
    Get the function currently selected by the user.

    Returns:
        The server's response lines joined with newlines, or an error message
        if the request fails.
    """
    reply = safe_get(endpoint="get_current_function")
    return "\n".join(reply)
@mcp.tool()
def list_functions() -> list:
    """
    List all functions in the database (no pagination parameters are sent).

    Returns:
        Lines returned by the Ghidra server, or a single error line if the
        request fails.
    """
    return safe_get(endpoint="list_functions")
@mcp.tool()
def decompile_function_by_address(address: str) -> str:
    """
    Decompile a function at the given address.

    Args:
        address: Address of the function to decompile

    Returns:
        The decompiler output as newline-joined text, or an error message if
        the request fails.
    """
    lines = safe_get("decompile_function", dict(address=address))
    return "\n".join(lines)
@mcp.tool()
def disassemble_function(address: str) -> list:
    """
    Get assembly code (address: instruction; comment) for a function.

    Args:
        address: Address of the function to disassemble

    Returns:
        One list entry per line returned by the Ghidra server, or a single
        error line if the request fails.
    """
    return safe_get("disassemble_function", dict(address=address))
@mcp.tool()
def set_decompiler_comment(address: str, comment: str) -> str:
    """
    Set a comment for a given address in the function pseudocode.

    Args:
        address: Address the comment is attached to
        comment: Comment text

    Returns:
        Server response text, or an error message if the request fails.
    """
    form = dict(address=address, comment=comment)
    return safe_post("set_decompiler_comment", form)
@mcp.tool()
def set_disassembly_comment(address: str, comment: str) -> str:
    """
    Set a comment for a given address in the function disassembly.

    Args:
        address: Address the comment is attached to
        comment: Comment text

    Returns:
        Server response text, or an error message if the request fails.
    """
    form = dict(address=address, comment=comment)
    return safe_post("set_disassembly_comment", form)
@mcp.tool()
def set_plate_comment(address: str, comment: str) -> str:
    """
    Set a plate comment for a given address. Plate comments are multi-line bordered
    comments typically displayed above functions or code sections in Ghidra's listing view.

    Args:
        address: Address the plate comment is attached to
        comment: Comment text (may span multiple lines)

    Returns:
        Server response text, or an error message if the request fails.
    """
    form = dict(address=address, comment=comment)
    return safe_post("set_plate_comment", form)
@mcp.tool()
def rename_function_by_address(function_address: str, new_name: str) -> str:
    """
    Rename a function by its address.

    Args:
        function_address: Address of the function to rename
        new_name: Function name to assign

    Returns:
        Server response text, or an error message if the request fails.
    """
    form = dict(function_address=function_address, new_name=new_name)
    return safe_post("rename_function_by_address", form)
@mcp.tool()
def set_function_prototype(function_address: str, prototype: str) -> str:
    """
    Set a function's prototype.

    Args:
        function_address: Address of the function to update
        prototype: Prototype declaration to apply to the function

    Returns:
        Server response text, or an error message if the request fails.
    """
    form = dict(function_address=function_address, prototype=prototype)
    return safe_post("set_function_prototype", form)
@mcp.tool()
def set_local_variable_type(function_address: str, variable_name: str, new_type: str) -> str:
    """
    Set a local variable's type.

    Args:
        function_address: Address of the function that contains the variable
        variable_name: Name of the local variable
        new_type: Name of the data type to assign

    Returns:
        Server response text, or an error message if the request fails.
    """
    form = dict(
        function_address=function_address,
        variable_name=variable_name,
        new_type=new_type,
    )
    return safe_post("set_local_variable_type", form)
@mcp.tool()
def set_data_type(address: str, type_name: str) -> str:
    """
    Set the data type at a specific address in the Ghidra program.

    Args:
        address: Memory address in hex format (e.g. "0x1400010a0")
        type_name: Name of the data type to set (e.g. "int", "dword", "byte[20]", "PCHAR")

    Returns:
        Success or error message
    """
    form = dict(address=address, type_name=type_name)
    return safe_post("set_data_type", form)
@mcp.tool()
def get_xrefs_to(address: str, offset: int = 0, limit: int = 100) -> list:
    """
    Get all references to the specified address (xref to).

    Args:
        address: Target address in hex format (e.g. "0x1400010a0")
        offset: Pagination offset (default: 0)
        limit: Maximum number of references to return (default: 100)

    Returns:
        List of references to the specified address, or a single error line
        if the request fails.
    """
    query = dict(address=address, offset=offset, limit=limit)
    return safe_get("xrefs_to", query)
@mcp.tool()
def get_xrefs_from(address: str, offset: int = 0, limit: int = 100) -> list:
    """
    Get all references from the specified address (xref from).

    Args:
        address: Source address in hex format (e.g. "0x1400010a0")
        offset: Pagination offset (default: 0)
        limit: Maximum number of references to return (default: 100)

    Returns:
        List of references from the specified address, or a single error line
        if the request fails.
    """
    query = dict(address=address, offset=offset, limit=limit)
    return safe_get("xrefs_from", query)
@mcp.tool()
def get_function_xrefs(name: str, offset: int = 0, limit: int = 100) -> list:
    """
    Get all references to the specified function by name.

    Args:
        name: Function name to search for
        offset: Pagination offset (default: 0)
        limit: Maximum number of references to return (default: 100)

    Returns:
        List of references to the specified function, or a single error line
        if the request fails.
    """
    query = dict(name=name, offset=offset, limit=limit)
    return safe_get("function_xrefs", query)
@mcp.tool()
def list_strings(offset: int = 0, limit: int = 2000, filter: str | None = None) -> list:
    """
    List all defined strings in the program with their addresses.

    Args:
        offset: Pagination offset (default: 0)
        limit: Maximum number of strings to return (default: 2000)
        filter: Optional filter to match within string content; omitted when
            empty or null

    Returns:
        List of strings with their addresses, or a single error line if the
        request fails.
    """
    # `filter` shadows the builtin, but it is the tool's public parameter name,
    # so it is kept for compatibility with existing MCP clients. It is typed
    # `str | None` so an explicit null from a client validates instead of
    # being rejected by the generated `str`-only schema.
    params = {"offset": offset, "limit": limit}
    if filter:
        params["filter"] = filter
    return safe_get("strings", params)
@mcp.tool()
def bsim_select_database(database_path: str) -> str:
    """
    Select and connect to a BSim database for function similarity matching.

    Args:
        database_path: Path to BSim database file (e.g., "/path/to/database.bsim")
                       or URL (e.g., "postgresql://host:port/dbname")

    Returns:
        Connection status and database information, or an error message if
        the request fails.
    """
    return safe_post("bsim/select_database", dict(database_path=database_path))
@mcp.tool()
def bsim_query_function(
    function_address: str,
    max_matches: int = 10,
    similarity_threshold: float = 0.7,
    confidence_threshold: float = 0.0,
    max_similarity: float | None = None,
    max_confidence: float | None = None,
    offset: int = 0,
    limit: int = 100,
) -> str:
    """
    Query a single function against the BSim database to find similar functions.

    Args:
        function_address: Address of the function to query (e.g., "0x401000")
        max_matches: Maximum number of matches to return (default: 10)
        similarity_threshold: Minimum similarity score (inclusive, 0.0-1.0, default: 0.7)
        confidence_threshold: Minimum confidence score (inclusive, 0.0-1.0, default: 0.0)
        max_similarity: Maximum similarity score (exclusive, 0.0-1.0, default: unbounded)
        max_confidence: Maximum confidence score (exclusive, 0.0-1.0, default: unbounded)
        offset: Pagination offset (default: 0)
        limit: Maximum number of results to return (default: 100)

    Returns:
        List of matching functions with similarity scores and metadata, or an
        error message if the request fails.
    """
    fields = [
        ("function_address", function_address),
        ("max_matches", str(max_matches)),
        ("similarity_threshold", str(similarity_threshold)),
        ("confidence_threshold", str(confidence_threshold)),
        ("offset", str(offset)),
        ("limit", str(limit)),
    ]
    # Upper bounds are only sent when given; leaving them out means "unbounded".
    upper_bounds = (("max_similarity", max_similarity), ("max_confidence", max_confidence))
    fields.extend((key, str(bound)) for key, bound in upper_bounds if bound is not None)
    return safe_post("bsim/query_function", dict(fields))
@mcp.tool()
def bsim_query_all_functions(
    max_matches_per_function: int = 5,
    similarity_threshold: float = 0.7,
    confidence_threshold: float = 0.0,
    max_similarity: float | None = None,
    max_confidence: float | None = None,
    offset: int = 0,
    limit: int = 100,
) -> str:
    """
    Query all functions in the current program against the BSim database.
    Returns an overview of matches for all functions.

    Args:
        max_matches_per_function: Max matches per function (default: 5)
        similarity_threshold: Minimum similarity score (inclusive, 0.0-1.0, default: 0.7)
        confidence_threshold: Minimum confidence score (inclusive, 0.0-1.0, default: 0.0)
        max_similarity: Maximum similarity score (exclusive, 0.0-1.0, default: unbounded)
        max_confidence: Maximum confidence score (exclusive, 0.0-1.0, default: unbounded)
        offset: Pagination offset (default: 0)
        limit: Maximum number of results to return (default: 100)

    Returns:
        Summary and detailed results for all matching functions, or an error
        message if the request fails.
    """
    fields = [
        ("max_matches_per_function", str(max_matches_per_function)),
        ("similarity_threshold", str(similarity_threshold)),
        ("confidence_threshold", str(confidence_threshold)),
        ("offset", str(offset)),
        ("limit", str(limit)),
    ]
    # Upper bounds are only sent when given; leaving them out means "unbounded".
    upper_bounds = (("max_similarity", max_similarity), ("max_confidence", max_confidence))
    fields.extend((key, str(bound)) for key, bound in upper_bounds if bound is not None)
    return safe_post("bsim/query_all_functions", dict(fields))
@mcp.tool()
def bsim_disconnect() -> str:
    """
    Disconnect from the current BSim database.

    Returns:
        Disconnection status message, or an error message if the request fails.
    """
    # The endpoint takes no parameters, so an empty form is posted.
    no_fields = dict()
    return safe_post("bsim/disconnect", no_fields)
@mcp.tool()
def bsim_status() -> str:
    """
    Get the current BSim database connection status.

    Returns:
        Current connection status and database path if connected, or an error
        message if the request fails.
    """
    status_lines = safe_get(endpoint="bsim/status")
    return "\n".join(status_lines)
@mcp.tool()
def bsim_get_match_disassembly(
    executable_path: str,
    function_name: str,
    function_address: str,
) -> str:
    """
    Get the disassembly of a specific BSim match. This requires the matched
    executable to be available in the Ghidra project.

    Args:
        executable_path: Path to the matched executable (from BSim match result)
        function_name: Name of the matched function
        function_address: Address of the matched function (e.g., "0x401000")

    Returns:
        Function prototype and assembly code for the matched function.
        Returns an error message if the program is not found in the project
        or the request fails.
    """
    match_info = dict(
        executable_path=executable_path,
        function_name=function_name,
        function_address=function_address,
    )
    return safe_post("bsim/get_match_disassembly", match_info)
@mcp.tool()
def bsim_get_match_decompile(
    executable_path: str,
    function_name: str,
    function_address: str,
) -> str:
    """
    Get the decompilation of a specific BSim match. This requires the matched
    executable to be available in the Ghidra project.

    Args:
        executable_path: Path to the matched executable (from BSim match result)
        function_name: Name of the matched function
        function_address: Address of the matched function (e.g., "0x401000")

    Returns:
        Function prototype and decompiled C code for the matched function.
        Returns an error message if the program is not found in the project
        or the request fails.
    """
    match_info = dict(
        executable_path=executable_path,
        function_name=function_name,
        function_address=function_address,
    )
    return safe_post("bsim/get_match_decompile", match_info)
@mcp.tool()
def bulk_operations(operations: list[dict]) -> str:
    """
    Execute multiple operations in a single request. This is more efficient than
    making multiple individual requests.

    Args:
        operations: List of operations to execute. Each operation is a dict with:
            - endpoint: The API endpoint path (e.g., "/methods", "/decompile")
            - params: Dict of parameters for that endpoint (e.g., {"name": "main"})

    Example:
        operations = [
            {"endpoint": "/methods", "params": {"offset": 0, "limit": 10}},
            {"endpoint": "/decompile", "params": {"name": "main"}},
            {"endpoint": "/rename_function_by_address", "params": {"function_address": "0x401000", "new_name": "initialize"}}
        ]

    Returns:
        The server's response body on success (a JSON string containing a
        results array with the response for each operation),
        "Error <status>: <body>" on an HTTP error status, or
        "Request failed: <exception>" if the request raises.
    """
    try:
        url = urljoin(ghidra_server_url, "bulk")
        # Unlike the form-encoded safe_post(), this endpoint takes a JSON body;
        # requests serializes the payload itself via `json=`, so no json import
        # is needed here.
        response = requests.post(
            url,
            json={"operations": operations},
            timeout=ghidra_request_timeout,
        )
        response.encoding = 'utf-8'
        if response.ok:
            return response.text
        # Strip the error body, matching safe_get()/safe_post() error formatting.
        return f"Error {response.status_code}: {response.text.strip()}"
    except Exception as e:
        return f"Request failed: {e}"
def main():
    """
    Entry point: parse command-line options, point the bridge at the Ghidra
    server, and run the MCP server over the selected transport.

    Options:
        --ghidra-server   Base URL of the Ghidra server; a trailing "/" is
                          appended when missing.
        --ghidra-timeout  Per-request timeout in seconds; must be positive.
        --transport       "stdio" (default) or "sse".
        --mcp-host/--mcp-port
                          Listen address for the sse transport
                          (defaults 127.0.0.1 and 8081).
    """
    parser = argparse.ArgumentParser(description="MCP server for Ghidra")
    parser.add_argument("--ghidra-server", type=str, default=DEFAULT_GHIDRA_SERVER,
                        help=f"Ghidra server URL, default: {DEFAULT_GHIDRA_SERVER}")
    parser.add_argument("--mcp-host", type=str, default="127.0.0.1",
                        help="Host to run MCP server on (only used for sse), default: 127.0.0.1")
    parser.add_argument("--mcp-port", type=int,
                        help="Port to run MCP server on (only used for sse), default: 8081")
    parser.add_argument("--transport", type=str, default="stdio", choices=["stdio", "sse"],
                        help="Transport protocol for MCP, default: stdio")
    # float so fractional timeouts work; plain integers still parse as before.
    parser.add_argument("--ghidra-timeout", type=float, default=DEFAULT_REQUEST_TIMEOUT,
                        help="Timeout in seconds for each request to the Ghidra server, "
                             f"default: {DEFAULT_REQUEST_TIMEOUT}")
    args = parser.parse_args()

    # requests/urllib3 reject timeouts <= 0 on every call. Previously 0 was
    # silently ignored and negatives made every tool call fail; fail fast instead.
    if args.ghidra_timeout <= 0:
        parser.error("--ghidra-timeout must be a positive number of seconds")

    # Rebind the module-level settings read by the request helpers.
    global ghidra_server_url, ghidra_request_timeout
    if args.ghidra_server:
        server_url = args.ghidra_server
        # urljoin() replaces the last path segment of a base URL without a
        # trailing slash ("http://h:8080/api" + "methods" -> "http://h:8080/methods"),
        # so make sure the base always ends in "/".
        if not server_url.endswith("/"):
            server_url += "/"
        ghidra_server_url = server_url
    ghidra_request_timeout = args.ghidra_timeout

    if args.transport == "sse":
        logging.basicConfig(level=logging.INFO)
        # basicConfig() is a no-op if the root logger is already configured,
        # so set the level explicitly as well.
        logging.getLogger().setLevel(logging.INFO)
        mcp.settings.log_level = "INFO"
        mcp.settings.host = args.mcp_host or "127.0.0.1"
        mcp.settings.port = args.mcp_port or 8081
        logger.info("Connecting to Ghidra server at %s", ghidra_server_url)
        logger.info("Starting MCP server on http://%s:%s/sse",
                    mcp.settings.host, mcp.settings.port)
        logger.info("Using transport: %s", args.transport)

    try:
        mcp.run(transport=args.transport)
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
# Start the server only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()