"""
Main server module for the Pentest MCP Server.
Implements MCP tools for SSH + Tmux session management.
"""
import asyncio
import logging
import json
import re
import time
from typing import Any, Dict, List, Optional, Sequence
from mcp.server import Server
from mcp.server.models import InitializationOptions
from mcp.server.stdio import stdio_server
from mcp.types import (
Resource,
Tool,
TextContent,
ImageContent,
EmbeddedResource,
)
from .config import Config
from .ssh_manager import SSHManager, SSHConnectionError
from .tmux_manager import TmuxManager, TmuxSessionError
# Configure the root logger at INFO level (runs on import) and create the
# module-level logger used by every handler in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class PentestMCPServer:
"""Main MCP Server for autonomous pentesting operations."""
def __init__(self):
    """Build the MCP server object and register its resources and tools.

    The SSH and tmux managers start out as ``None``; they are created on
    demand by ``_ensure_initialized``.
    """
    self.config = Config()

    # Created lazily by _ensure_initialized().
    self.tmux_manager: Optional[TmuxManager] = None
    self.ssh_manager: Optional[SSHManager] = None

    self.server = Server("pentest-mcp")
    for register in (self._setup_resources, self._setup_tools):
        register()
def _setup_resources(self):
    """Register the MCP resource handlers on ``self.server``.

    Two JSON resources are exposed: ``pentest://sessions`` (the tmux session
    list) and ``pentest://system-status`` (the result of
    ``_handle_get_system_status``).
    """
    sessions_uri = "pentest://sessions"
    status_uri = "pentest://system-status"

    @self.server.list_resources()
    async def handle_list_resources() -> List[Resource]:
        """Return the static list of resources this server exposes."""
        return [
            Resource(
                uri=sessions_uri,
                name="Active Sessions",
                description="List of all active tmux sessions",
                mimeType="application/json",
            ),
            Resource(
                uri=status_uri,
                name="System Status",
                description="Current system resource usage and connection status",
                mimeType="application/json",
            ),
        ]

    @self.server.read_resource()
    async def handle_read_resource(uri: Any) -> str:
        """Return the JSON text of the requested resource.

        Args:
            uri: Resource URI; may be a plain string or a URL object.

        Returns:
            The resource content serialized as a JSON string.

        Raises:
            ValueError: If the URI is not one of the known resources.
            RuntimeError: If the SSH/Tmux connection cannot be established.
        """
        # The framework may hand over a URL object instead of a str, so
        # normalise before comparing against the known URIs.
        requested = str(uri).rstrip("/")
        if requested not in (sessions_uri, status_uri):
            raise ValueError(f"Unknown resource: {uri}")

        # A resource can be read before any tool call, i.e. while the
        # managers are still None, so connect first.
        if not await self._ensure_initialized():
            raise RuntimeError("Failed to initialize SSH/Tmux connection")

        if requested == sessions_uri:
            payload = await self.tmux_manager.list_sessions()
        else:
            payload = await self._handle_get_system_status({})
        # The read-resource handler contract is text/bytes content, not a
        # list of tool-result TextContent objects.
        return json.dumps(payload)
def _setup_tools(self):
    """Register the MCP tool catalogue and the tool-call dispatcher.

    Every advertised tool ``<name>`` is served by the method
    ``self._handle_<name>``. The dispatcher derives the set of routable
    names from the advertised catalogue, so a tool can no longer be routed
    without being listed (as ``list_files`` previously was).
    """

    def prop(kind: str, description: str, **extra: Any) -> Dict[str, Any]:
        """Build one JSON-schema property entry."""
        return {"type": kind, "description": description, **extra}

    def schema(properties: Optional[Dict[str, Any]] = None,
               required: Optional[List[str]] = None) -> Dict[str, Any]:
        """Build an object schema; ``required`` is omitted when empty."""
        spec: Dict[str, Any] = {"type": "object", "properties": properties or {}}
        if required:
            spec["required"] = list(required)
        return spec

    def reply(payload: Dict[str, Any]) -> List[TextContent]:
        """Wrap a result dict as a single JSON text content item."""
        return [TextContent(type="text", text=json.dumps(payload))]

    trigger_schema = schema(
        {
            "type": prop(
                "string",
                "prompt | regex | timeout | file_exists",
                enum=["prompt", "regex", "timeout", "file_exists"],
            ),
            "pattern": prop("string", "Regex pattern (for regex type)"),
            "timeout_seconds": prop("integer", "Timeout in seconds (for timeout type)"),
            "path": prop("string", "File path (for file_exists type)"),
            "name": prop("string", "Trigger identifier"),
        },
        ["type"],
    )

    tools = [
        Tool(
            name="create_session",
            description="Create a new persistent tmux session",
            inputSchema=schema(
                {
                    "session_id": prop("string", "Unique identifier for the session"),
                    "shell": prop("string", "Shell to use (default: bash)", default="bash"),
                },
                ["session_id"],
            ),
        ),
        Tool(
            name="list_sessions",
            description="List all active tmux sessions",
            inputSchema=schema(),
        ),
        Tool(
            name="kill_session",
            description="Terminate a tmux session",
            inputSchema=schema(
                {"session_id": prop("string", "Session to terminate")},
                ["session_id"],
            ),
        ),
        Tool(
            name="execute",
            description="Execute a command in a persistent tmux session with trigger-based monitoring",
            inputSchema=schema(
                {
                    "session_id": prop("string", "Target session ID"),
                    "command": prop("string", "Command to execute"),
                    "background": prop(
                        "boolean", "Run in background (default: false)", default=False
                    ),
                    "triggers": prop(
                        "array",
                        "Triggers for AI suspension/resumption",
                        items=trigger_schema,
                    ),
                    "max_timeout": prop(
                        "integer", "Maximum wait time in seconds (default: 300)", default=300
                    ),
                },
                ["session_id", "command"],
            ),
        ),
        Tool(
            name="read_output",
            description="Get current output from a session",
            inputSchema=schema(
                {
                    "session_id": prop("string", "Target session ID"),
                    "lines": prop("integer", "Number of lines to return (default: all)"),
                },
                ["session_id"],
            ),
        ),
        Tool(
            name="send_input",
            description="Send input to an interactive session",
            inputSchema=schema(
                {
                    "session_id": prop("string", "Target session ID"),
                    "input": prop("string", "Text to send"),
                    "press_enter": prop(
                        "boolean", "Press Enter after input (default: true)", default=True
                    ),
                },
                ["session_id", "input"],
            ),
        ),
        Tool(
            name="get_system_status",
            description="Get your OS system resource usage and status",
            inputSchema=schema(),
        ),
        Tool(
            name="recover_sessions",
            description="Find and reconnect to orphaned tmux sessions",
            inputSchema=schema(),
        ),
        Tool(
            name="upload_file",
            description="Upload a file to the your OS system",
            inputSchema=schema(
                {
                    "local_path": prop("string", "Local file path to upload"),
                    "remote_path": prop("string", "Remote destination path"),
                },
                ["local_path", "remote_path"],
            ),
        ),
        Tool(
            name="download_file",
            description="Download a file from the your OS system",
            inputSchema=schema(
                {
                    "remote_path": prop("string", "Remote file path to download"),
                    "local_path": prop("string", "Local destination path"),
                },
                ["remote_path", "local_path"],
            ),
        ),
        Tool(
            name="get_session_history",
            description="Get command history for a specific tmux session",
            inputSchema=schema(
                {"session_id": prop("string", "Session to get command history for")},
                ["session_id"],
            ),
        ),
        Tool(
            name="parse_tool_output",
            description="Parse common pentest tool outputs (nmap XML, JSON, etc.)",
            inputSchema=schema(
                {
                    "tool": prop(
                        "string",
                        "nmap | masscan | sqlmap | custom",
                        enum=["nmap", "masscan", "sqlmap", "custom"],
                    ),
                    "file_path": prop("string", "Path to output file on your OS system"),
                    "format": prop("string", "xml | json | txt", enum=["xml", "json", "txt"]),
                },
                ["tool", "file_path"],
            ),
        ),
        # Previously routable through the dispatcher but never advertised.
        Tool(
            name="list_files",
            description="List files and directories at a path on the remote system",
            inputSchema=schema(
                {
                    "path": prop("string", "Remote path to list"),
                    "recursive": prop(
                        "boolean", "List files recursively (default: false)", default=False
                    ),
                },
                ["path"],
            ),
        ),
    ]
    tool_names = frozenset(tool.name for tool in tools)

    @self.server.list_tools()
    async def handle_list_tools() -> List[Tool]:
        """Return the advertised tool catalogue."""
        return list(tools)

    @self.server.call_tool()
    async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
        """Handle tool calls.

        Ensures the SSH/Tmux connection exists, routes ``name`` to
        ``self._handle_<name>`` and returns the handler's result dict as
        JSON text. Any exception is logged and reported as an error result
        instead of propagating.
        """
        try:
            # Ensure connection is established
            if not await self._ensure_initialized():
                return reply({
                    "status": "error",
                    "message": "Failed to initialize SSH/Tmux connection"
                })

            if name in tool_names:
                # Resolved at call time so the handler is looked up on the
                # live instance.
                handler = getattr(self, f"_handle_{name}")
                result = await handler(arguments or {})
            else:
                result = {
                    "status": "error",
                    "message": f"Unknown tool: {name}"
                }
            return reply(result)
        except Exception as e:
            # logger.exception keeps the traceback for debugging.
            logger.exception(f"Error handling tool {name}: {e}")
            return reply({
                "status": "error",
                "message": f"Tool execution failed: {str(e)}"
            })
async def _ensure_initialized(self) -> bool:
"""Ensure SSH and Tmux managers are initialized."""
max_attempts = 3
for attempt in range(max_attempts):
try:
if not self.ssh_manager:
self.ssh_manager = SSHManager(self.config)
connected = await self.ssh_manager.connect()
if not connected:
continue
if not self.tmux_manager:
self.tmux_manager = TmuxManager(self.ssh_manager)
initialized = await self.tmux_manager.initialize()
if not initialized:
continue
# Ensure connection is alive
if await self.ssh_manager.ensure_connected():
return True
except Exception as e:
logger.error(f"Initialization attempt {attempt + 1} failed: {e}")
if attempt == max_attempts - 1:
return False
await asyncio.sleep(2 ** attempt) # Exponential backoff
return False
# Tool Handlers
async def _handle_create_session(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""Handle create_session tool call."""
session_id = args.get("session_id")
shell = args.get("shell", "bash")
if not session_id:
return {"status": "error", "message": "session_id is required"}
return await self.tmux_manager.create_session(session_id, shell)
async def _handle_list_sessions(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""Handle list_sessions tool call."""
sessions = await self.tmux_manager.list_sessions()
return {"status": "success", "sessions": sessions}
async def _handle_kill_session(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""Handle kill_session tool call."""
session_id = args.get("session_id")
if not session_id:
return {"status": "error", "message": "session_id is required"}
return await self.tmux_manager.kill_session(session_id)
async def _handle_execute(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""Enhanced execute tool call with comprehensive trigger support."""
session_id = args.get("session_id")
command = args.get("command")
background = args.get("background", False)
triggers = args.get("triggers", [{"type": "prompt"}]) # Default to prompt trigger
max_timeout = args.get("max_timeout", 300)
poll_interval = args.get("poll_interval", 1.0)
if not session_id or not command:
return {"status": "error", "message": "session_id and command are required"}
# Validate command (optional security check)
if not await self._validate_command(command):
return {"status": "error", "message": f"Command blocked for security: {command}"}
# Ensure we're initialized
if not await self._ensure_initialized():
return {"status": "error", "message": "SSH/Tmux connection not available"}
# Execute command
execute_result = await self.tmux_manager.execute_command(session_id, command)
if execute_result.get("status") != "sent":
return {
"status": "error",
"message": f"Failed to execute command: {execute_result.get('message', 'Unknown error')}"
}
# Background execution: return immediately
if background:
return {
"status": "background",
"message": "Command sent, running in background",
"session_id": session_id,
"command": command
}
# Foreground execution: monitor with enhanced trigger system
try:
from .trigger_system import TriggerSystem
trigger_system = TriggerSystem(self.ssh_manager, self.tmux_manager)
monitor_result = await trigger_system.monitor_session_with_triggers(
session_id=session_id,
triggers=triggers,
max_timeout=max_timeout,
poll_interval=poll_interval
)
# Add command context to result
monitor_result["command"] = command
monitor_result["session_id"] = session_id
return monitor_result
except Exception as e:
logger.error(f"Error during trigger monitoring: {e}")
# Fallback: wait briefly and return output
await asyncio.sleep(2)
output_result = await self.tmux_manager.capture_pane(session_id)
return {
"status": "completed",
"command": command,
"output": output_result.get("output", ""),
"execution_time": 2.0,
"note": "Trigger monitoring failed, using fallback"
}
async def _handle_read_output(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""Handle read_output tool call."""
session_id = args.get("session_id")
lines = args.get("lines")
if not session_id:
return {"status": "error", "message": "session_id is required"}
result = await self.tmux_manager.capture_pane(session_id, lines)
return result
async def _handle_send_input(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""Handle send_input tool call."""
session_id = args.get("session_id")
input_text = args.get("input")
press_enter = args.get("press_enter", True)
if not session_id or input_text is None:
return {"status": "error", "message": "session_id and input are required"}
return await self.tmux_manager.send_input(session_id, input_text, press_enter)
async def _handle_get_system_status(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""Handle get_system_status tool call."""
try:
# Get system info
system_info = await self.ssh_manager.get_system_info()
# Get session count
sessions = await self.tmux_manager.list_sessions()
# Get connection status
connection_status = self.ssh_manager.get_connection_status()
return {
"status": "success",
"system_info": system_info,
"active_sessions": len(sessions),
"connection_status": connection_status,
"sessions": sessions
}
except Exception as e:
return {"status": "error", "message": f"Failed to get system status: {e}"}
async def _handle_recover_sessions(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""Handle recover_sessions tool call."""
try:
recovered = await self.tmux_manager.recover_sessions()
recovered_info = []
for session_id in recovered:
session_info = self.tmux_manager.get_session_info(session_id)
if session_info:
# Get last output
output_result = await self.tmux_manager.capture_pane(session_id, 20)
last_output = output_result.get("output", "")[-500:] # Last 500 chars
recovered_info.append({
"session_id": session_id,
"age_seconds": int(time.time() - session_info.created_at),
"last_output": last_output
})
return {
"status": "success",
"recovered_sessions": recovered_info
}
except Exception as e:
return {"status": "error", "message": f"Recovery failed: {e}"}
async def _handle_get_session_history(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""Handle get_session_history tool call."""
session_id = args.get("session_id")
if not session_id:
return {"status": "error", "message": "session_id is required"}
try:
result = await self.tmux_manager.get_session_history(session_id)
return result
except Exception as e:
return {"status": "error", "message": f"Failed to get session history: {e}"}
async def _handle_upload_file(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""Handle upload_file tool call."""
try:
import os
local_path = args.get("local_path")
remote_path = args.get("remote_path")
if not await self._ensure_initialized():
return {"status": "error", "message": "SSH connection not available"}
# Validate local path exists on the MCP host (likely Windows)
if not local_path or not os.path.isfile(local_path):
return {"status": "error", "message": f"Local file not found: {local_path}"}
# Ensure remote directory exists (use SSH command via tmux/exec)
remote_dir = os.path.dirname(remote_path) or "/"
try:
await self.ssh_manager.run_command(f"mkdir -p '{remote_dir}'", timeout=10)
except Exception:
# Non-fatal: continue and let SFTP report precise errors
pass
# Use async context manager for SFTP lifecycle
sftp = await self.ssh_manager.start_sftp_client()
try:
await sftp.put(local_path, remote_path)
finally:
try:
await sftp.exit()
except Exception:
# Some asyncssh versions use .close(); attempt fallback
try:
await sftp.close() # type: ignore[attr-defined]
except Exception:
pass
return {
"status": "success",
"message": f"File uploaded to {remote_path}"
}
except Exception as e:
return {"status": "error", "message": f"Upload failed: {e}"}
async def _handle_download_file(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""Handle download_file tool call."""
try:
import os
remote_path = args.get("remote_path")
local_path = args.get("local_path")
if not await self._ensure_initialized():
return {"status": "error", "message": "SSH connection not available"}
# Ensure local destination directory exists on MCP host
local_dir = os.path.dirname(local_path) or "."
try:
os.makedirs(local_dir, exist_ok=True)
except Exception as mkerr:
return {"status": "error", "message": f"Failed to create local directory '{local_dir}': {mkerr}"}
sftp = await self.ssh_manager.start_sftp_client()
try:
await sftp.get(remote_path, local_path)
finally:
try:
await sftp.exit()
except Exception:
try:
await sftp.close() # type: ignore[attr-defined]
except Exception:
pass
return {
"status": "success",
"message": f"File downloaded to {local_path}"
}
except Exception as e:
return {"status": "error", "message": f"Download failed: {e}"}
async def _handle_parse_tool_output(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""Handle parse_tool_output tool call."""
try:
tool = args.get("tool")
file_path = args.get("file_path")
format_type = args.get("format", "xml")
if not await self._ensure_initialized():
return {"status": "error", "message": "SSH connection not available"}
# Download the file first
import tempfile
import os
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
temp_path = temp_file.name
try:
# Download the file
sftp = await self.ssh_manager.start_sftp_client()
try:
await sftp.get(file_path, temp_path)
finally:
try:
await sftp.exit()
except Exception:
try:
await sftp.close() # type: ignore[attr-defined]
except Exception:
pass
# Parse based on tool and format
parsed_data = await self._parse_output_file(tool, format_type, temp_path)
return {
"status": "success",
"parsed_data": parsed_data,
"summary": f"Successfully parsed {tool} output"
}
finally:
# Cleanup temp file
if os.path.exists(temp_path):
os.unlink(temp_path)
except Exception as e:
return {"status": "error", "message": f"Output parsing failed: {e}"}
async def _handle_list_files(self, args: Dict[str, Any]) -> Dict[str, Any]:
"""Handle list_files tool call."""
try:
path = args.get("path")
recursive = args.get("recursive", False)
if not await self._ensure_initialized():
return {"status": "error", "message": "SSH connection not available"}
if recursive:
cmd = f"find {path} -type f -printf '%p %s %T@\\n' 2>/dev/null || echo 'NOT_FOUND'"
else:
cmd = f"ls -la {path} 2>/dev/null || echo 'NOT_FOUND'"
result = await self.ssh_manager.run_command(cmd)
if "NOT_FOUND" in result.stdout:
return {"status": "error", "message": f"Path not found: {path}"}
return {
"status": "success",
"output": result.stdout,
"path": path,
"recursive": recursive
}
except Exception as e:
return {"status": "error", "message": f"File listing failed: {e}"}
async def _parse_output_file(self, tool: str, format_type: str, file_path: str) -> Dict[str, Any]:
"""Parse tool output files."""
try:
if tool == "nmap" and format_type == "xml":
return await self._parse_nmap_xml(file_path)
elif tool == "masscan" and format_type == "xml":
return await self._parse_masscan_xml(file_path)
else:
# For unsupported combinations, return basic file info
import os
return {
"file_size": os.path.getsize(file_path),
"tool": tool,
"format": format_type,
"note": "Automatic parsing not implemented for this tool/format"
}
except Exception as e:
return {"error": f"Parsing failed: {e}"}
async def _parse_nmap_xml(self, file_path: str) -> Dict[str, Any]:
"""Parse nmap XML output."""
try:
import xml.etree.ElementTree as ET
tree = ET.parse(file_path)
root = tree.getroot()
hosts = []
for host in root.findall('host'):
host_info = {}
# Get address
address_elem = host.find('address')
if address_elem is not None:
host_info['address'] = address_elem.get('addr')
host_info['type'] = address_elem.get('addrtype')
# Get ports
ports_info = []
ports = host.find('ports')
if ports is not None:
for port in ports.findall('port'):
port_info = {
'port': port.get('portid'),
'protocol': port.get('protocol'),
'state': port.find('state').get('state') if port.find('state') is not None else 'unknown',
'service': port.find('service').get('name') if port.find('service') is not None else 'unknown'
}
ports_info.append(port_info)
host_info['ports'] = ports_info
hosts.append(host_info)
return {
"hosts": hosts,
"scan_info": {
"scanner": "nmap",
"hosts_found": len(hosts)
}
}
except Exception as e:
return {"error": f"Nmap XML parsing failed: {e}"}
async def _parse_masscan_xml(self, file_path: str) -> Dict[str, Any]:
"""Parse masscan XML output."""
try:
import xml.etree.ElementTree as ET
tree = ET.parse(file_path)
root = tree.getroot()
hosts = []
for host in root.findall('host'):
host_info = {}
# Get address
address_elem = host.find('address')
if address_elem is not None:
host_info['address'] = address_elem.get('addr')
host_info['type'] = address_elem.get('addrtype')
# Get ports
ports_info = []
for port in host.findall('ports/port'):
port_info = {
'port': port.get('portid'),
'protocol': port.get('protocol'),
'state': 'open' # masscan typically only shows open ports
}
ports_info.append(port_info)
host_info['ports'] = ports_info
hosts.append(host_info)
return {
"hosts": hosts,
"scan_info": {
"scanner": "masscan",
"hosts_found": len(hosts)
}
}
except Exception as e:
return {"error": f"Masscan XML parsing failed: {e}"}
async def _validate_command(self, command: str) -> bool:
"""
Validate command for security (optional).
Args:
command: Command to validate
Returns:
bool: True if command is allowed
"""
# Check against blocked patterns
for pattern in self.config.BLOCKED_COMMANDS:
if re.search(pattern, command, re.IGNORECASE):
logger.warning(f"Blocked dangerous command: {command}")
return False
return True
async def shutdown(self):
    """Disconnect the SSH manager, if one exists, and log the outcome.

    Exceptions raised while disconnecting are logged, not propagated.
    """
    try:
        active_ssh = self.ssh_manager
        if active_ssh:
            await active_ssh.disconnect()
        logger.info("MCP Server shutdown completed")
    except Exception as exc:
        logger.error(f"Error during shutdown: {exc}")
async def main():
    """Main entry point for the MCP server.

    Validates the configuration, builds a ``PentestMCPServer`` and serves it
    over stdio. SIGINT/SIGTERM cancel the serving task so the server really
    stops, and ``shutdown()`` is awaited on every exit path so the SSH
    connection is released.

    Raises:
        Exception: Any startup or serving error is logged and re-raised.
    """
    import signal

    server_instance = None
    try:
        # Validate configuration
        Config.validate_config()

        # Create and run server
        server_instance = PentestMCPServer()

        loop = asyncio.get_running_loop()
        serving_task = asyncio.current_task()

        def request_shutdown(signum):
            """Runs inside the event loop: stop serving by cancelling main()."""
            logger.info(f"Received signal {int(signum)}, shutting down...")
            if serving_task is not None:
                serving_task.cancel()

        def raw_signal_handler(signum, frame):
            # A raw signal handler must not touch asyncio state directly;
            # hand the request over to the loop thread-safely.
            loop.call_soon_threadsafe(request_shutdown, signum)

        # Setup signal handlers for clean shutdown
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, request_shutdown, sig)
            except (NotImplementedError, RuntimeError, ValueError):
                # Event loops without add_signal_handler support
                # (e.g. on Windows) fall back to signal.signal.
                signal.signal(sig, raw_signal_handler)

        # Run the server
        async with stdio_server() as (read_stream, write_stream):
            await server_instance.server.run(
                read_stream,
                write_stream,
                InitializationOptions(
                    server_name="pentest-mcp",
                    server_version="1.0.0",
                    capabilities={
                        "tools": {},
                        "resources": {}
                    }
                )
            )
    except asyncio.CancelledError:
        # Cancellation is the requested-shutdown path, not a failure.
        logger.info("Server stopped")
    except Exception as e:
        logger.error(f"Server failed to start: {e}")
        raise
    finally:
        if server_instance is not None:
            await server_instance.shutdown()
if __name__ == "__main__":
    # Executed as a script: run main() on a fresh asyncio event loop.
    asyncio.run(main())