ida-mcp-server
- ida-mcp-server
- src
- mcp_server_ida
import logging
import socket
import json
import time
import struct
import uuid
from typing import Dict, Any, List, Union, Optional, Tuple, Callable, TypeVar, Set, Awaitable, Type, cast
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import (
TextContent,
Tool,
)
from enum import Enum
from pydantic import BaseModel
# Request models: one pydantic model per MCP tool, describing that tool's arguments
# Arguments of the ida_get_function_assembly_by_name tool; serve() advertises
# this model's schema as that tool's inputSchema.
class GetFunctionAssemblyByName(BaseModel):
    function_name: str
# Arguments of the ida_get_function_assembly_by_address tool; serve() advertises
# this model's schema as that tool's inputSchema.
class GetFunctionAssemblyByAddress(BaseModel):
    # Hexadecimal address as string. Note that
    # IDAProFunctions.get_function_assembly_by_address parses a value without a
    # "0x" prefix as decimal.
    address: str
# Arguments of the ida_get_function_decompiled_by_name tool; serve() advertises
# this model's schema as that tool's inputSchema.
class GetFunctionDecompiledByName(BaseModel):
    function_name: str
# Arguments of the ida_get_function_decompiled_by_address tool; serve()
# advertises this model's schema as that tool's inputSchema.
class GetFunctionDecompiledByAddress(BaseModel):
    # Hexadecimal address as string. Note that
    # IDAProFunctions.get_function_decompiled_by_address parses a value without
    # a "0x" prefix as decimal.
    address: str
# Arguments of the ida_get_global_variable_by_name tool; serve() advertises
# this model's schema as that tool's inputSchema.
class GetGlobalVariableByName(BaseModel):
    variable_name: str
# Arguments of the ida_get_global_variable_by_address tool; serve() advertises
# this model's schema as that tool's inputSchema.
class GetGlobalVariableByAddress(BaseModel):
    # Hexadecimal address as string. Note that
    # IDAProFunctions.get_global_variable_by_address parses a value without a
    # "0x" prefix as decimal.
    address: str
# Empty argument model for the ida_get_current_function_assembly tool: the tool
# takes no parameters, but serve() still needs a schema for its inputSchema.
class GetCurrentFunctionAssembly(BaseModel):
    pass
# Empty argument model for the ida_get_current_function_decompiled tool: the
# tool takes no parameters, but serve() still needs a schema for its inputSchema.
class GetCurrentFunctionDecompiled(BaseModel):
    pass
# Arguments of the ida_rename_local_variable tool; serve() advertises this
# model's schema as that tool's inputSchema.
class RenameLocalVariable(BaseModel):
    function_name: str  # function that owns the local variable
    old_name: str
    new_name: str
# Arguments of the ida_rename_global_variable tool; serve() advertises this
# model's schema as that tool's inputSchema.
class RenameGlobalVariable(BaseModel):
    old_name: str
    new_name: str
# Arguments of the ida_rename_function tool; serve() advertises this model's
# schema as that tool's inputSchema.
class RenameFunction(BaseModel):
    old_name: str
    new_name: str
# Arguments for renaming several local variables of one function in a single
# request (presumably the input model of the ida_rename_multi_local_variables
# tool -- its registration is not fully visible here; confirm in serve()).
class RenameMultiLocalVariables(BaseModel):
    function_name: str
    rename_pairs_old2new: List[Dict[str, str]] # List of dictionaries with "old_name" and "new_name" keys
# Arguments for renaming several global variables in a single request
# (presumably the input model of the ida_rename_multi_global_variables tool --
# confirm in serve()).
class RenameMultiGlobalVariables(BaseModel):
    # NOTE(review): presumably the same {"old_name": ..., "new_name": ...}
    # dictionaries as RenameMultiLocalVariables -- confirm against the plugin.
    rename_pairs_old2new: List[Dict[str, str]]
# Arguments for renaming several functions in a single request (presumably the
# input model of the ida_rename_multi_functions tool -- confirm in serve()).
class RenameMultiFunctions(BaseModel):
    # NOTE(review): presumably the same {"old_name": ..., "new_name": ...}
    # dictionaries as RenameMultiLocalVariables -- confirm against the plugin.
    rename_pairs_old2new: List[Dict[str, str]]
# Arguments for attaching a comment to an address in the disassembly
# (presumably the input model of the ida_add_assembly_comment tool -- confirm
# in serve()).
class AddAssemblyComment(BaseModel):
    # Can be a hexadecimal address string; IDAProFunctions.add_assembly_comment
    # forwards it to the plugin unparsed.
    address: str
    comment: str
    is_repeatable: bool = False # Whether the comment should be repeatable
# Arguments for attaching a comment to a function identified by name
# (presumably the input model of the ida_add_function_comment tool -- confirm
# in serve()).
class AddFunctionComment(BaseModel):
    function_name: str
    comment: str
    is_repeatable: bool = False # Whether the comment should be repeatable
# Arguments for attaching a comment to a location in a function's decompiled
# pseudocode (presumably the input model of the ida_add_pseudocode_comment
# tool -- confirm in serve()).
class AddPseudocodeComment(BaseModel):
    function_name: str
    # Address in the pseudocode; IDAProFunctions.add_pseudocode_comment forwards
    # it to the plugin unparsed.
    address: str
    comment: str
    is_repeatable: bool = False # Whether comment should be repeated at all occurrences
# Arguments for running script text inside IDA (presumably the input model of
# the ida_execute_script tool -- confirm in serve()).
class ExecuteScript(BaseModel):
    script: str  # script source; sent to the plugin as-is by IDAProFunctions.execute_script
# Arguments for running a script file inside IDA (presumably the input model of
# the ida_execute_script_from_file tool -- confirm in serve()).
class ExecuteScriptFromFile(BaseModel):
    # Only the path is sent to the plugin (see
    # IDAProFunctions.execute_script_from_file), so it has to be readable from
    # IDA's process.
    file_path: str
# Names under which the tools are published to MCP clients. Because the enum
# also derives from str, each member can be passed directly where a plain tool
# name string is expected (serve() does this in Tool(name=...)).
class IDATools(str, Enum):
    # Read-only queries
    GET_FUNCTION_ASSEMBLY_BY_NAME = "ida_get_function_assembly_by_name"
    GET_FUNCTION_ASSEMBLY_BY_ADDRESS = "ida_get_function_assembly_by_address"
    GET_FUNCTION_DECOMPILED_BY_NAME = "ida_get_function_decompiled_by_name"
    GET_FUNCTION_DECOMPILED_BY_ADDRESS = "ida_get_function_decompiled_by_address"
    GET_GLOBAL_VARIABLE_BY_NAME = "ida_get_global_variable_by_name"
    GET_GLOBAL_VARIABLE_BY_ADDRESS = "ida_get_global_variable_by_address"
    GET_CURRENT_FUNCTION_ASSEMBLY = "ida_get_current_function_assembly"
    GET_CURRENT_FUNCTION_DECOMPILED = "ida_get_current_function_decompiled"
    # Renaming (single item)
    RENAME_LOCAL_VARIABLE = "ida_rename_local_variable"
    RENAME_GLOBAL_VARIABLE = "ida_rename_global_variable"
    RENAME_FUNCTION = "ida_rename_function"
    # Renaming (several items per request)
    RENAME_MULTI_LOCAL_VARIABLES = "ida_rename_multi_local_variables"
    RENAME_MULTI_GLOBAL_VARIABLES = "ida_rename_multi_global_variables"
    RENAME_MULTI_FUNCTIONS = "ida_rename_multi_functions"
    # Comments
    ADD_ASSEMBLY_COMMENT = "ida_add_assembly_comment"
    ADD_FUNCTION_COMMENT = "ida_add_function_comment"
    ADD_PSEUDOCODE_COMMENT = "ida_add_pseudocode_comment"
    # Script execution
    EXECUTE_SCRIPT = "ida_execute_script"
    EXECUTE_SCRIPT_FROM_FILE = "ida_execute_script_from_file"
# IDA Pro communication handler
class IDAProCommunicator:
    """Blocking TCP client for the IDA Pro plugin.

    Messages in both directions are framed as a 4-byte big-endian length
    followed by a UTF-8 encoded JSON document.  Requests are sent one at a
    time; each call to :meth:`send_request` waits for its reply.
    """

    # Request types that get the longer ``batch_timeout``.
    _BATCH_REQUEST_TYPES = frozenset({
        "rename_multi_local_variables",
        "rename_multi_global_variables",
        "rename_multi_functions",
    })
    # Largest number of bytes requested from the socket in one recv() call.
    _RECV_CHUNK_SIZE = 4096

    def __init__(self, host: str = 'localhost', port: int = 5000):
        """Store connection settings; no socket is opened here."""
        self.host: str = host
        self.port: int = port
        self.sock: Optional[socket.socket] = None
        self.logger: logging.Logger = logging.getLogger(__name__)
        self.connected: bool = False
        self.reconnect_attempts: int = 0
        self.max_reconnect_attempts: int = 5
        self.last_reconnect_time: float = 0
        self.reconnect_cooldown: int = 5  # seconds
        self.request_count: int = 0
        self.default_timeout: int = 10
        self.batch_timeout: int = 60  # it may take more time for batch operations

    def connect(self) -> bool:
        """Connect to the IDA plugin.

        Returns:
            True when a connection was established.  False when the attempt
            failed, or when it was skipped because a previous attempt failed
            less than ``reconnect_cooldown`` seconds ago.
        """
        current_time: float = time.time()
        in_cooldown = current_time - self.last_reconnect_time < self.reconnect_cooldown
        if self.reconnect_attempts > 0 and in_cooldown:
            self.logger.debug("In reconnection cooldown, skipping")
            return False

        # If already connected, disconnect first
        if self.connected:
            self.disconnect()

        sock: Optional[socket.socket] = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(self.default_timeout)
            sock.connect((self.host, self.port))
        except Exception as e:
            # Close the half-open socket so a failed attempt does not leak a
            # file descriptor.
            if sock is not None:
                try:
                    sock.close()
                except OSError:
                    pass
            self.last_reconnect_time = current_time
            self.reconnect_attempts += 1
            if self.reconnect_attempts <= self.max_reconnect_attempts:
                self.logger.warning(f"Failed to connect to IDA Pro: {str(e)}. Attempt {self.reconnect_attempts}/{self.max_reconnect_attempts}")
            else:
                self.logger.error(f"Failed to connect to IDA Pro after {self.max_reconnect_attempts} attempts: {str(e)}")
            return False

        self.sock = sock
        self.connected = True
        self.reconnect_attempts = 0
        self.logger.info(f"Connected to IDA Pro ({self.host}:{self.port})")
        return True

    def disconnect(self) -> None:
        """Close the socket (best effort) and mark the client as disconnected."""
        if self.sock:
            try:
                self.sock.close()
            except OSError as e:
                # Closing is best effort; the socket is dropped either way.
                self.logger.debug(f"Ignoring error while closing socket: {str(e)}")
            self.sock = None
        self.connected = False

    def ensure_connection(self) -> bool:
        """Return True if connected, connecting first when necessary."""
        if not self.connected:
            return self.connect()
        return True

    def send_message(self, data: bytes) -> None:
        """Send ``data`` preceded by its 4-byte length prefix.

        Raises:
            ConnectionError: if no socket is open.
        """
        if self.sock is None:
            raise ConnectionError("Socket is not connected")
        length_bytes: bytes = struct.pack('!I', len(data))  # 4-byte length prefix
        self.sock.sendall(length_bytes + data)

    def receive_message(self) -> Optional[bytes]:
        """Receive one length-prefixed message.

        Returns:
            The message body, or None when the peer closed the connection or
            any error (including a timeout) occurred; errors are logged.
        """
        try:
            # Receive 4-byte length prefix
            length_bytes: Optional[bytes] = self.receive_exactly(4)
            if not length_bytes:
                return None
            length: int = struct.unpack('!I', length_bytes)[0]
            # Receive message body
            return self.receive_exactly(length)
        except Exception as e:
            self.logger.error(f"Error receiving message: {str(e)}")
            return None

    def receive_exactly(self, n: int) -> Optional[bytes]:
        """Receive exactly ``n`` bytes.

        Returns:
            The bytes read, or None if the connection was closed first.

        Raises:
            ConnectionError: if no socket is open.
        """
        if self.sock is None:
            raise ConnectionError("Socket is not connected")
        # bytearray grows in place, avoiding a full copy per received chunk.
        buffer = bytearray()
        while len(buffer) < n:
            chunk: bytes = self.sock.recv(min(n - len(buffer), self._RECV_CHUNK_SIZE))
            if not chunk:  # Connection closed
                return None
            buffer += chunk
        return bytes(buffer)

    def _decode_response(self, response_data: bytes, request_id: str) -> Dict[str, Any]:
        """Parse a raw reply into a dict, or into an ``{"error": ...}`` dict."""
        self.logger.debug(f"Received raw data length: {len(response_data)}")
        try:
            response: Any = json.loads(response_data.decode('utf-8'))
        except json.JSONDecodeError as e:
            self.logger.error(f"Failed to parse JSON response: {str(e)}")
            return {"error": f"Invalid JSON response: {str(e)}"}

        # The type check must come before any ``.get`` call: valid JSON can be
        # a list, string or number, none of which have that method.
        if not isinstance(response, dict):
            self.logger.error(f"Received response is not a dictionary: {type(response)}")
            return {"error": f"Response format error: expected dictionary, got {type(response).__name__}"}

        response_id: Any = response.get("id")
        if response_id != request_id:
            # Only reported; the response is still returned to the caller.
            self.logger.warning(f"Response ID mismatch! Request ID: {request_id}, Response ID: {response_id}")
        self.logger.debug(f"Received response: ID={response_id}, count={response.get('count')}")
        return response

    def send_request(self, request_type: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Send one request to the IDA plugin and return its reply.

        Args:
            request_type: value placed in the request's ``"type"`` field.
            data: value placed in the request's ``"data"`` field.

        Returns:
            The decoded response dictionary.  Connection problems, missing
            replies and malformed replies are reported as a dictionary with an
            ``"error"`` key instead of being raised.
        """
        # Ensure connection is established
        if not self.ensure_connection():
            return {"error": "Cannot connect to IDA Pro"}

        try:
            is_batch = request_type in self._BATCH_REQUEST_TYPES
            timeout = self.batch_timeout if is_batch else self.default_timeout
            if self.sock:
                self.sock.settimeout(timeout)
                self.logger.debug(f"Set timeout to {timeout}s for {'batch' if is_batch else 'normal'} operation")

            # Add request ID
            request_id: str = str(uuid.uuid4())
            self.request_count += 1
            request_count: int = self.request_count
            request: Dict[str, Any] = {
                "id": request_id,
                "count": request_count,
                "type": request_type,
                "data": data,
            }
            self.logger.debug(f"Sending request: {request_id}, type: {request_type}, count: {request_count}")

            try:
                self.send_message(json.dumps(request).encode('utf-8'))
                response_data: Optional[bytes] = self.receive_message()
                # If no data received, assume connection is closed
                if not response_data:
                    self.logger.warning("No data received, connection may be closed")
                    self.disconnect()
                    return {"error": "No response received from IDA Pro"}
                return self._decode_response(response_data, request_id)
            except Exception as e:
                self.logger.error(f"Error communicating with IDA Pro: {str(e)}")
                self.disconnect()  # Disconnect after error
                return {"error": str(e)}
        finally:
            # restore timeout (the socket is None if we disconnected above)
            if self.sock:
                self.sock.settimeout(self.default_timeout)

    def ping(self) -> bool:
        """Return True if the plugin answers a ``ping`` request with ``pong``."""
        response: Dict[str, Any] = self.send_request("ping", {})
        return response.get("status") == "pong"
# Actual IDA Pro functionality implementation
class IDAProFunctions:
    """IDA Pro operations that return human-readable text.

    Every public method sends one request through the communicator's
    ``send_request`` and turns the reply into a string.  Failures are
    reported in the returned text rather than raised.
    """

    def __init__(self, communicator: "IDAProCommunicator"):
        self.communicator: "IDAProCommunicator" = communicator
        self.logger: logging.Logger = logging.getLogger(__name__)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_address(address: Any) -> Optional[int]:
        """Parse an address string; return None when it is not a number.

        A ``0x``/``0X`` prefix selects hexadecimal, anything else is read as
        decimal.  Surrounding whitespace is ignored.
        """
        text = str(address).strip()
        try:
            return int(text, 16) if text.lower().startswith("0x") else int(text)
        except ValueError:
            return None

    @staticmethod
    def _invalid_address(address: Any) -> str:
        """Build the message returned for an unparsable address."""
        return f"Error: Invalid address format '{address}', expected hexadecimal (0x...) or decimal"

    def _coerce_text(self, value: Any, label: str, dict_as_json: bool = False) -> str:
        """Return ``value`` as a string, logging a warning if it was not one."""
        if isinstance(value, str):
            return value
        self.logger.warning(f"{label} type is not string but {type(value).__name__}, attempting conversion")
        if dict_as_json and isinstance(value, dict):
            return json.dumps(value, indent=2)
        return str(value)

    @staticmethod
    def _variable_name_from(variable_info: str) -> str:
        """Extract ``"name"`` from JSON variable info, or return "Unknown"."""
        try:
            parsed = json.loads(variable_info)
        except ValueError:  # not JSON; json.JSONDecodeError is a ValueError
            return "Unknown"
        if isinstance(parsed, dict) and "name" in parsed:
            return parsed["name"]
        return "Unknown"

    def _fetch_text(
        self,
        request_type: str,
        payload: Dict[str, Any],
        *,
        field: str,
        label: str,
        error_text: str,
        missing_text: str,
        log_label: str,
        heading: Callable[[Dict[str, Any], str], str],
        dict_as_json: bool = False,
    ) -> str:
        """Request one text field from IDA and format it under a heading.

        ``heading`` receives the response and the extracted text and returns
        the first line of the result (without the trailing colon).
        """
        try:
            response: Dict[str, Any] = self.communicator.send_request(request_type, payload)
            if "error" in response:
                return f"{error_text}: {response['error']}"
            value: Any = response.get(field)
            if value is None:
                return missing_text
            text = self._coerce_text(value, label, dict_as_json)
            return f"{heading(response, text)}:\n{text}"
        except Exception as e:
            self.logger.error(f"Error getting {log_label}: {str(e)}", exc_info=True)
            return f"{error_text}: {str(e)}"

    def _run_action(
        self,
        request_type: str,
        payload: Dict[str, Any],
        *,
        error_text: str,
        success_text: str,
        failure_text: str,
        log_label: str,
    ) -> str:
        """Send a modifying request and report its success/message fields."""
        try:
            response: Dict[str, Any] = self.communicator.send_request(request_type, payload)
            if "error" in response:
                return f"{error_text}: {response['error']}"
            message: str = response.get("message", "")
            outcome = success_text if response.get("success", False) else failure_text
            return f"{outcome}: {message}"
        except Exception as e:
            self.logger.error(f"Error {log_label}: {str(e)}", exc_info=True)
            return f"{error_text}: {str(e)}"

    @staticmethod
    def _describe_failed_pair(pair: Any) -> str:
        """Format one failed rename entry, tolerating missing keys."""
        if not isinstance(pair, dict):
            return f"- {pair}"
        old_name = pair.get('old_name', '?')
        new_name = pair.get('new_name', '?')
        return f"- {old_name} → {new_name}: {pair.get('error', 'Unknown error')}"

    def _run_batch_rename(
        self,
        request_type: str,
        payload: Dict[str, Any],
        *,
        renamed_what: str,
        error_text: str,
        log_label: str,
    ) -> str:
        """Send a batch rename and summarise success count and failures."""
        try:
            response: Dict[str, Any] = self.communicator.send_request(request_type, payload)
            if "error" in response:
                return f"{error_text}: {response['error']}"
            success_count: int = response.get("success_count", 0)
            result_parts: List[str] = [f"Successfully renamed {success_count} {renamed_what}"]
            failed_pairs: List[Dict[str, str]] = response.get("failed_pairs", [])
            if failed_pairs:
                result_parts.append("\nFailed renamings:")
                result_parts.extend(self._describe_failed_pair(pair) for pair in failed_pairs)
            return "\n".join(result_parts)
        except Exception as e:
            self.logger.error(f"Error renaming multiple {log_label}: {str(e)}", exc_info=True)
            return f"{error_text}: {str(e)}"

    def _run_script(
        self,
        request_type: str,
        payload: Dict[str, Any],
        *,
        error_text: str,
        failure_text: str,
        success_text: str,
        log_label: str,
    ) -> str:
        """Send a script-execution request and format its captured output."""
        try:
            response: Dict[str, Any] = self.communicator.send_request(request_type, payload)
            # Defensive: tolerate a communicator that returns None.
            if response is None:
                self.logger.error(f"Received None response from IDA when {log_label}")
                return f"{error_text}: Received empty response from IDA"
            if "error" in response:
                return f"{error_text}: {response['error']}"
            if not response.get("success", False):
                # No "error" key can be present here (handled above), so the
                # failure reason is always reported as unknown.
                traceback_text: str = response.get("traceback", "")
                return f"{failure_text}: Unknown error\n\nTraceback:\n{traceback_text}"

            # A stream reported as null must not be printed as the text "None".
            raw_stdout = response.get("stdout")
            raw_stderr = response.get("stderr")
            stdout: str = "" if raw_stdout is None else str(raw_stdout)
            stderr: str = "" if raw_stderr is None else str(raw_stderr)
            return_value: str = str(response.get("return_value", ""))

            result_text: List[str] = [success_text]
            if return_value and return_value != "None":
                result_text.append(f"\nReturn value:\n{return_value}")
            if stdout:
                result_text.append(f"\nStandard output:\n{stdout}")
            if stderr:
                result_text.append(f"\nStandard error:\n{stderr}")
            return "\n".join(result_text)
        except Exception as e:
            self.logger.error(f"Error {log_label}: {str(e)}", exc_info=True)
            return f"{error_text}: {str(e)}"

    # ------------------------------------------------------------------
    # Assembly / decompilation queries
    # ------------------------------------------------------------------

    def get_function_assembly(self, function_name: str) -> str:
        """Get assembly code for a function by name (legacy method)"""
        return self.get_function_assembly_by_name(function_name)

    def get_function_assembly_by_name(self, function_name: str) -> str:
        """Get assembly code for a function by its name"""
        return self._fetch_text(
            "get_function_assembly_by_name",
            {"function_name": function_name},
            field="assembly",
            label="Assembly data",
            error_text=f"Error retrieving assembly for function '{function_name}'",
            missing_text=f"Error: No assembly data returned for function '{function_name}'",
            log_label="function assembly",
            heading=lambda response, text: f"Assembly code for function '{function_name}'",
        )

    def get_function_decompiled(self, function_name: str) -> str:
        """Get decompiled code for a function by name (legacy method)"""
        return self.get_function_decompiled_by_name(function_name)

    def get_function_decompiled_by_name(self, function_name: str) -> str:
        """Get decompiled pseudocode for a function by its name"""
        return self._fetch_text(
            "get_function_decompiled_by_name",
            {"function_name": function_name},
            field="decompiled_code",
            label="Decompiled code",
            error_text=f"Error retrieving decompiled code for function '{function_name}'",
            missing_text=f"Error: No decompiled code returned for function '{function_name}'",
            log_label="function decompiled code",
            heading=lambda response, text: f"Decompiled code for function '{function_name}'",
        )

    def get_current_function_assembly(self) -> str:
        """Get assembly code for the function at current cursor position"""
        return self._fetch_text(
            "get_current_function_assembly",
            {},
            field="assembly",
            label="Assembly data",
            error_text="Error retrieving assembly for current function",
            missing_text="Error: No assembly data returned for current function",
            log_label="current function assembly",
            heading=lambda response, text: (
                f"Assembly code for function '{response.get('function_name', 'Current function')}'"
            ),
        )

    def get_current_function_decompiled(self) -> str:
        """Get decompiled code for the function at current cursor position"""
        return self._fetch_text(
            "get_current_function_decompiled",
            {},
            field="decompiled_code",
            label="Decompiled code",
            error_text="Error retrieving decompiled code for current function",
            missing_text="Error: No decompiled code returned for current function",
            log_label="current function decompiled code",
            heading=lambda response, text: (
                f"Decompiled code for function '{response.get('function_name', 'Current function')}'"
            ),
        )

    def get_function_assembly_by_address(self, address: str) -> str:
        """Get assembly code for a function by its address"""
        addr_int = self._parse_address(address)
        if addr_int is None:
            return self._invalid_address(address)
        return self._fetch_text(
            "get_function_assembly_by_address",
            {"address": addr_int},
            field="assembly",
            label="Assembly data",
            error_text=f"Error retrieving assembly for address '{address}'",
            missing_text=f"Error: No assembly data returned for address '{address}'",
            log_label="function assembly by address",
            heading=lambda response, text: (
                f"Assembly code for function '{response.get('function_name', 'Unknown function')}'"
                f" at address {address}"
            ),
        )

    def get_function_decompiled_by_address(self, address: str) -> str:
        """Get decompiled pseudocode for a function by its address"""
        addr_int = self._parse_address(address)
        if addr_int is None:
            return self._invalid_address(address)
        return self._fetch_text(
            "get_function_decompiled_by_address",
            {"address": addr_int},
            field="decompiled_code",
            label="Decompiled code",
            error_text=f"Error retrieving decompiled code for address '{address}'",
            missing_text=f"Error: No decompiled code returned for address '{address}'",
            log_label="function decompiled code by address",
            heading=lambda response, text: (
                f"Decompiled code for function '{response.get('function_name', 'Unknown function')}'"
                f" at address {address}"
            ),
        )

    # ------------------------------------------------------------------
    # Global variables
    # ------------------------------------------------------------------

    def get_global_variable(self, variable_name: str) -> str:
        """Get global variable information by name (legacy method)"""
        return self.get_global_variable_by_name(variable_name)

    def get_global_variable_by_name(self, variable_name: str) -> str:
        """Get global variable information by its name"""
        return self._fetch_text(
            "get_global_variable_by_name",
            {"variable_name": variable_name},
            field="variable_info",
            label="Variable info",
            error_text=f"Error retrieving global variable '{variable_name}'",
            missing_text=f"Error: No variable info returned for '{variable_name}'",
            log_label="global variable",
            heading=lambda response, text: f"Global variable '{variable_name}'",
            dict_as_json=True,
        )

    def get_global_variable_by_address(self, address: str) -> str:
        """Get global variable information by its address"""
        addr_int = self._parse_address(address)
        if addr_int is None:
            return self._invalid_address(address)
        return self._fetch_text(
            "get_global_variable_by_address",
            {"address": addr_int},
            field="variable_info",
            label="Variable info",
            error_text=f"Error retrieving global variable at address '{address}'",
            missing_text=f"Error: No variable info returned for address '{address}'",
            log_label="global variable by address",
            # The variable name is taken from the JSON info when available.
            heading=lambda response, text: (
                f"Global variable '{self._variable_name_from(text)}' at address {address}"
            ),
            dict_as_json=True,
        )

    # ------------------------------------------------------------------
    # Renaming
    # ------------------------------------------------------------------

    def rename_local_variable(self, function_name: str, old_name: str, new_name: str) -> str:
        """Rename a local variable within a function"""
        subject = f"local variable from '{old_name}' to '{new_name}' in function '{function_name}'"
        return self._run_action(
            "rename_local_variable",
            {"function_name": function_name, "old_name": old_name, "new_name": new_name},
            error_text=f"Error renaming {subject}",
            success_text=f"Successfully renamed {subject}",
            failure_text=f"Failed to rename {subject}",
            log_label="renaming local variable",
        )

    def rename_global_variable(self, old_name: str, new_name: str) -> str:
        """Rename a global variable"""
        subject = f"global variable from '{old_name}' to '{new_name}'"
        return self._run_action(
            "rename_global_variable",
            {"old_name": old_name, "new_name": new_name},
            error_text=f"Error renaming {subject}",
            success_text=f"Successfully renamed {subject}",
            failure_text=f"Failed to rename {subject}",
            log_label="renaming global variable",
        )

    def rename_function(self, old_name: str, new_name: str) -> str:
        """Rename a function"""
        subject = f"function from '{old_name}' to '{new_name}'"
        return self._run_action(
            "rename_function",
            {"old_name": old_name, "new_name": new_name},
            error_text=f"Error renaming {subject}",
            success_text=f"Successfully renamed {subject}",
            failure_text=f"Failed to rename {subject}",
            log_label="renaming function",
        )

    def rename_multi_local_variables(self, function_name: str, rename_pairs_old2new: List[Dict[str, str]]) -> str:
        """Rename multiple local variables within a function at once"""
        return self._run_batch_rename(
            "rename_multi_local_variables",
            {"function_name": function_name, "rename_pairs_old2new": rename_pairs_old2new},
            renamed_what=f"local variables in function '{function_name}'",
            error_text=f"Error renaming multiple local variables in function '{function_name}'",
            log_label="local variables",
        )

    def rename_multi_global_variables(self, rename_pairs_old2new: List[Dict[str, str]]) -> str:
        """Rename multiple global variables at once"""
        return self._run_batch_rename(
            "rename_multi_global_variables",
            {"rename_pairs_old2new": rename_pairs_old2new},
            renamed_what="global variables",
            error_text="Error renaming multiple global variables",
            log_label="global variables",
        )

    def rename_multi_functions(self, rename_pairs_old2new: List[Dict[str, str]]) -> str:
        """Rename multiple functions at once"""
        return self._run_batch_rename(
            "rename_multi_functions",
            {"rename_pairs_old2new": rename_pairs_old2new},
            renamed_what="functions",
            error_text="Error renaming multiple functions",
            log_label="functions",
        )

    # ------------------------------------------------------------------
    # Comments
    # ------------------------------------------------------------------

    def add_assembly_comment(self, address: str, comment: str, is_repeatable: bool = False) -> str:
        """Add an assembly comment"""
        comment_type: str = "repeatable" if is_repeatable else "regular"
        # The address is forwarded as given; it is not parsed on this side.
        return self._run_action(
            "add_assembly_comment",
            {"address": address, "comment": comment, "is_repeatable": is_repeatable},
            error_text=f"Error adding assembly comment at address '{address}'",
            success_text=f"Successfully added {comment_type} assembly comment at address '{address}'",
            failure_text=f"Failed to add assembly comment at address '{address}'",
            log_label="adding assembly comment",
        )

    def add_function_comment(self, function_name: str, comment: str, is_repeatable: bool = False) -> str:
        """Add a comment to a function"""
        comment_type: str = "repeatable" if is_repeatable else "regular"
        return self._run_action(
            "add_function_comment",
            {"function_name": function_name, "comment": comment, "is_repeatable": is_repeatable},
            error_text=f"Error adding comment to function '{function_name}'",
            success_text=f"Successfully added {comment_type} comment to function '{function_name}'",
            failure_text=f"Failed to add comment to function '{function_name}'",
            log_label="adding function comment",
        )

    def add_pseudocode_comment(self, function_name: str, address: str, comment: str, is_repeatable: bool = False) -> str:
        """Add a comment to a specific address in the function's decompiled pseudocode"""
        comment_type: str = "repeatable" if is_repeatable else "regular"
        where = f"at address {address} in function '{function_name}'"
        # The address is forwarded as given; it is not parsed on this side.
        return self._run_action(
            "add_pseudocode_comment",
            {
                "function_name": function_name,
                "address": address,
                "comment": comment,
                "is_repeatable": is_repeatable,
            },
            error_text=f"Error adding comment {where}",
            success_text=f"Successfully added {comment_type} comment {where}",
            failure_text=f"Failed to add comment {where}",
            log_label="adding pseudocode comment",
        )

    # ------------------------------------------------------------------
    # Script execution
    # ------------------------------------------------------------------

    def execute_script(self, script: str) -> str:
        """Execute a Python script in IDA Pro and return its output. The script runs in IDA's context with access to all IDA API modules."""
        return self._run_script(
            "execute_script",
            {"script": script},
            error_text="Error executing script",
            failure_text="Script execution failed",
            success_text="Script executed successfully",
            log_label="executing script",
        )

    def execute_script_from_file(self, file_path: str) -> str:
        """Execute a Python script from a file path in IDA Pro and return its output. The file should be accessible from IDA's process."""
        return self._run_script(
            "execute_script_from_file",
            {"file_path": file_path},
            error_text=f"Error executing script from file '{file_path}'",
            failure_text=f"Script execution from file '{file_path}' failed",
            success_text=f"Script from file '{file_path}' executed successfully",
            log_label="executing script from file",
        )
async def serve() -> None:
    """Run the IDA Pro MCP server over the stdio transport.

    Creates the MCP ``Server``, makes a first attempt to connect to the IDA
    Pro plugin (a failure is only logged; ``call_tool`` tries again before
    every request), registers the tool-listing and tool-dispatch handlers,
    and then serves requests on stdin/stdout.
    """
    log: logging.Logger = logging.getLogger(__name__)
    # DEBUG keeps the detailed per-request information visible
    log.setLevel(logging.DEBUG)

    server: Server = Server("mcp-ida")

    # One communicator is shared by every tool call for the server lifetime
    ida_communicator: IDAProCommunicator = IDAProCommunicator()
    log.info("Attempting to connect to IDA Pro plugin...")
    if not ida_communicator.connect():
        log.warning("Initial connection to IDA Pro plugin failed, will retry on request")
    else:
        log.info("Successfully connected to IDA Pro plugin")

    # High-level IDA operations built on top of the shared communicator
    ida_functions: IDAProFunctions = IDAProFunctions(ida_communicator)

    def text_reply(message: str) -> List[TextContent]:
        """Wrap ``message`` in the single-item text response a tool call returns."""
        return [TextContent(type="text", text=message)]

    @server.list_tools()
    async def list_tools() -> List[Tool]:
        """Return one ``Tool`` descriptor per supported IDA operation."""
        # (tool identifier, human-readable description, request model)
        tool_specs = (
            (
                IDATools.GET_FUNCTION_ASSEMBLY_BY_NAME,
                "Get assembly code for a function by name",
                GetFunctionAssemblyByName,
            ),
            (
                IDATools.GET_FUNCTION_ASSEMBLY_BY_ADDRESS,
                "Get assembly code for a function by address",
                GetFunctionAssemblyByAddress,
            ),
            (
                IDATools.GET_FUNCTION_DECOMPILED_BY_NAME,
                "Get decompiled pseudocode for a function by name",
                GetFunctionDecompiledByName,
            ),
            (
                IDATools.GET_FUNCTION_DECOMPILED_BY_ADDRESS,
                "Get decompiled pseudocode for a function by address",
                GetFunctionDecompiledByAddress,
            ),
            (
                IDATools.GET_GLOBAL_VARIABLE_BY_NAME,
                "Get information about a global variable by name",
                GetGlobalVariableByName,
            ),
            (
                IDATools.GET_GLOBAL_VARIABLE_BY_ADDRESS,
                "Get information about a global variable by address",
                GetGlobalVariableByAddress,
            ),
            (
                IDATools.GET_CURRENT_FUNCTION_ASSEMBLY,
                "Get assembly code for the function at the current cursor position",
                GetCurrentFunctionAssembly,
            ),
            (
                IDATools.GET_CURRENT_FUNCTION_DECOMPILED,
                "Get decompiled pseudocode for the function at the current cursor position",
                GetCurrentFunctionDecompiled,
            ),
            (
                IDATools.RENAME_LOCAL_VARIABLE,
                "Rename a local variable within a function in the IDA database",
                RenameLocalVariable,
            ),
            (
                IDATools.RENAME_GLOBAL_VARIABLE,
                "Rename a global variable in the IDA database",
                RenameGlobalVariable,
            ),
            (
                IDATools.RENAME_FUNCTION,
                "Rename a function in the IDA database",
                RenameFunction,
            ),
            (
                IDATools.RENAME_MULTI_LOCAL_VARIABLES,
                "Rename multiple local variables within a function at once in the IDA database",
                RenameMultiLocalVariables,
            ),
            (
                IDATools.RENAME_MULTI_GLOBAL_VARIABLES,
                "Rename multiple global variables at once in the IDA database",
                RenameMultiGlobalVariables,
            ),
            (
                IDATools.RENAME_MULTI_FUNCTIONS,
                "Rename multiple functions at once in the IDA database",
                RenameMultiFunctions,
            ),
            (
                IDATools.ADD_ASSEMBLY_COMMENT,
                "Add a comment at a specific address in the assembly view of the IDA database",
                AddAssemblyComment,
            ),
            (
                IDATools.ADD_FUNCTION_COMMENT,
                "Add a comment to a function in the IDA database",
                AddFunctionComment,
            ),
            (
                IDATools.ADD_PSEUDOCODE_COMMENT,
                "Add a comment to a specific address in the function's decompiled pseudocode",
                AddPseudocodeComment,
            ),
            (
                IDATools.EXECUTE_SCRIPT,
                "Execute a Python script in IDA Pro and return its output. The script runs in IDA's context with access to all IDA API modules.",
                ExecuteScript,
            ),
            (
                IDATools.EXECUTE_SCRIPT_FROM_FILE,
                "Execute a Python script from a file path in IDA Pro and return its output. The file should be accessible from IDA's process.",
                ExecuteScriptFromFile,
            ),
        )
        return [
            Tool(name=tool, description=summary, inputSchema=model.schema())
            for tool, summary, model in tool_specs
        ]

    @server.call_tool()
    async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
        """Dispatch the tool ``name`` with ``arguments`` and return its text result.

        Failures never propagate to the caller: a missing plugin connection,
        an unknown tool name, a missing argument key or any exception raised
        while handling the request is reported back as an error text item.
        """
        # Reconnect lazily when the plugin link is down
        if not (ida_communicator.connected or ida_communicator.ensure_connection()):
            return text_reply(f"Error: Cannot connect to IDA Pro plugin. Please ensure the plugin is running.")

        try:
            # --- read-only queries ---
            if name == IDATools.GET_FUNCTION_ASSEMBLY_BY_NAME:
                return text_reply(
                    ida_functions.get_function_assembly_by_name(arguments["function_name"])
                )
            if name == IDATools.GET_FUNCTION_ASSEMBLY_BY_ADDRESS:
                return text_reply(
                    ida_functions.get_function_assembly_by_address(arguments["address"])
                )
            if name == IDATools.GET_FUNCTION_DECOMPILED_BY_NAME:
                return text_reply(
                    ida_functions.get_function_decompiled_by_name(arguments["function_name"])
                )
            if name == IDATools.GET_FUNCTION_DECOMPILED_BY_ADDRESS:
                return text_reply(
                    ida_functions.get_function_decompiled_by_address(arguments["address"])
                )
            if name == IDATools.GET_GLOBAL_VARIABLE_BY_NAME:
                return text_reply(
                    ida_functions.get_global_variable_by_name(arguments["variable_name"])
                )
            if name == IDATools.GET_GLOBAL_VARIABLE_BY_ADDRESS:
                return text_reply(
                    ida_functions.get_global_variable_by_address(arguments["address"])
                )
            if name == IDATools.GET_CURRENT_FUNCTION_ASSEMBLY:
                return text_reply(ida_functions.get_current_function_assembly())
            if name == IDATools.GET_CURRENT_FUNCTION_DECOMPILED:
                return text_reply(ida_functions.get_current_function_decompiled())

            # --- renames ---
            if name == IDATools.RENAME_LOCAL_VARIABLE:
                return text_reply(
                    ida_functions.rename_local_variable(
                        arguments["function_name"],
                        arguments["old_name"],
                        arguments["new_name"],
                    )
                )
            if name == IDATools.RENAME_GLOBAL_VARIABLE:
                return text_reply(
                    ida_functions.rename_global_variable(
                        arguments["old_name"],
                        arguments["new_name"],
                    )
                )
            if name == IDATools.RENAME_FUNCTION:
                return text_reply(
                    ida_functions.rename_function(
                        arguments["old_name"],
                        arguments["new_name"],
                    )
                )
            if name == IDATools.RENAME_MULTI_LOCAL_VARIABLES:
                return text_reply(
                    ida_functions.rename_multi_local_variables(
                        arguments["function_name"],
                        arguments["rename_pairs_old2new"],
                    )
                )
            if name == IDATools.RENAME_MULTI_GLOBAL_VARIABLES:
                return text_reply(
                    ida_functions.rename_multi_global_variables(
                        arguments["rename_pairs_old2new"]
                    )
                )
            if name == IDATools.RENAME_MULTI_FUNCTIONS:
                return text_reply(
                    ida_functions.rename_multi_functions(
                        arguments["rename_pairs_old2new"]
                    )
                )

            # --- comments ("is_repeatable" is optional and defaults to False) ---
            if name == IDATools.ADD_ASSEMBLY_COMMENT:
                return text_reply(
                    ida_functions.add_assembly_comment(
                        arguments["address"],
                        arguments["comment"],
                        arguments.get("is_repeatable", False),
                    )
                )
            if name == IDATools.ADD_FUNCTION_COMMENT:
                return text_reply(
                    ida_functions.add_function_comment(
                        arguments["function_name"],
                        arguments["comment"],
                        arguments.get("is_repeatable", False),
                    )
                )
            if name == IDATools.ADD_PSEUDOCODE_COMMENT:
                return text_reply(
                    ida_functions.add_pseudocode_comment(
                        arguments["function_name"],
                        arguments["address"],
                        arguments["comment"],
                        arguments.get("is_repeatable", False),
                    )
                )

            # --- script execution: own try block so the error text names the script step ---
            if name == IDATools.EXECUTE_SCRIPT:
                try:
                    # Missing and empty script are rejected alike
                    if not arguments.get("script"):
                        return text_reply("Error: No script content provided")
                    return text_reply(ida_functions.execute_script(arguments["script"]))
                except Exception as script_error:
                    log.error(f"Error executing script: {str(script_error)}", exc_info=True)
                    return text_reply(f"Error executing script: {str(script_error)}")
            if name == IDATools.EXECUTE_SCRIPT_FROM_FILE:
                try:
                    # Missing and empty path are rejected alike
                    if not arguments.get("file_path"):
                        return text_reply("Error: No file path provided")
                    return text_reply(
                        ida_functions.execute_script_from_file(arguments["file_path"])
                    )
                except Exception as script_error:
                    log.error(f"Error executing script from file: {str(script_error)}", exc_info=True)
                    return text_reply(f"Error executing script from file: {str(script_error)}")

            # Nothing matched; reported through the generic handler below
            raise ValueError(f"Unknown tool: {name}")
        except Exception as tool_error:
            log.error(f"Error calling tool: {str(tool_error)}", exc_info=True)
            return text_reply(f"Error executing {name}: {str(tool_error)}")

    init_options = server.create_initialization_options()
    async with stdio_server() as (reader, writer):
        await server.run(reader, writer, init_options, raise_exceptions=True)