# wireshark-mcp-server.py (20.1 kB)
#!/usr/bin/env python3
"""
Wireshark MCP Server
A Model Context Protocol server that provides AI assistants with access to
Wireshark network analysis capabilities for network troubleshooting and analysis.
Author: AI Assistant
Date: 2025-06-20
"""
import asyncio
import json
import logging
import os
import re
import subprocess
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, Any, List, Optional, Union
try:
    from fastmcp import FastMCP
except ImportError:
    # SystemExit with a str argument prints that message to stderr and exits
    # with status 1. This avoids writing to stdout (which carries the protocol
    # stream when the server is launched over the MCP stdio transport) and
    # does not depend on the site-provided ``exit()`` builtin, which is absent
    # under ``python -S`` or in embedded interpreters.
    raise SystemExit("FastMCP not installed. Install with: pip install fastmcp") from None
# Configure root logging once at import time. basicConfig's default handler
# writes to stderr, so log lines never mix with anything written to stdout.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class SecurityValidator:
    """Security validation utilities for network analysis operations.

    All helpers are static; only :meth:`sanitize_filepath` touches the
    filesystem (to resolve and stat the given path).
    """

    # Allowed interface patterns (Windows and Linux). They are applied with
    # re.fullmatch: with re.match, `$` also matches just before a trailing
    # newline, so names such as "eth0\n" would otherwise be accepted.
    INTERFACE_PATTERNS = [
        r'^(eth|wlan|lo|en|enp|wlp|docker|br-)[a-zA-Z0-9]{1,15}$',  # Linux
        r'^lo$',  # Linux loopback; the prefix pattern above needs a suffix (lo0)
        r'^Ethernet \d+$',  # Windows Ethernet
        r'^Wi-Fi \d*$',  # Windows WiFi
        r'^Local Area Connection \d*$',  # Windows LAN
        r'^\d+$'  # Interface number
    ]

    # Security limits
    MAX_CAPTURE_DURATION = 300  # 5 minutes
    MAX_PACKET_COUNT = 10000  # Maximum packets per capture
    MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB

    # Substrings rejected in capture filters. The filter is handed to tshark
    # as a single argv element (no shell), so this is a conservative
    # defense-in-depth denylist. Note it also rejects BPF's `&&` / `||`
    # spellings and `&` bitmask tests; `and` / `or` must be used instead.
    _DANGEROUS_FILTER_PATTERNS = (';', '|', '&', '$(', '`', '\n', '\r', '..')

    @staticmethod
    def validate_interface(interface: str) -> bool:
        """Return True if ``interface`` fully matches an allowed pattern.

        Empty names and names longer than 50 characters are rejected.
        """
        if not interface or len(interface) > 50:
            return False
        return any(
            re.fullmatch(pattern, interface)
            for pattern in SecurityValidator.INTERFACE_PATTERNS
        )

    @staticmethod
    def validate_capture_filter(filter_expr: str) -> bool:
        """Validate a BPF capture filter expression.

        An empty (or None) filter is accepted. Otherwise the filter must contain
        none of the denylisted substrings and be shorter than 500 characters.
        """
        if not filter_expr:
            return True
        if any(p in filter_expr for p in SecurityValidator._DANGEROUS_FILTER_PATTERNS):
            return False
        return len(filter_expr) < 500

    @staticmethod
    def sanitize_filepath(filepath: str) -> Optional[str]:
        """Resolve ``filepath`` and return it if it is an acceptable capture file.

        Returns:
            The absolute, symlink-resolved path as a string when it names an
            existing regular file with a ``.pcap`` / ``.pcapng`` extension
            (case-insensitive) no larger than MAX_FILE_SIZE; otherwise None.
        """
        try:
            resolved_path = Path(filepath).resolve()
            # is_file() rather than exists(): a directory named "x.pcap" must
            # not be accepted as a capture file.
            if not resolved_path.is_file():
                return None
            if resolved_path.suffix.lower() not in ('.pcap', '.pcapng'):
                return None
            if resolved_path.stat().st_size > SecurityValidator.MAX_FILE_SIZE:
                return None
            return str(resolved_path)
        except Exception:
            # Deliberate validation boundary: any failure (bad type, embedded
            # NUL, permission error, symlink loop, ...) means "not acceptable".
            return None
class WiresharkInterface:
    """Interface to Wireshark CLI tools with security controls.

    Every command is executed from an argument list (never through a shell),
    so caller-supplied values reach the tools as single argv elements.
    """

    # Seconds added on top of a live capture's duration before the subprocess
    # is forcibly killed; gives tshark time to finish and flush its JSON.
    CAPTURE_GRACE_SECONDS = 15

    def __init__(self):
        self.tshark_path = self._find_tshark()
        self.dumpcap_path = self._find_dumpcap()
        self.capinfos_path = self._find_capinfos()
        if not self.tshark_path:
            raise RuntimeError("TShark not found. Please install Wireshark.")

    @staticmethod
    def _find_executable(name: str) -> Optional[str]:
        """Return the first working location of a CLI tool, or None.

        Candidates are tried in order: the bare name and ``<name>.exe``
        (resolved via PATH), the default Windows Wireshark directory, then
        /usr/bin and /usr/local/bin. A candidate is accepted when
        ``<candidate> --version`` exits with status 0 within 5 seconds.
        """
        candidates = [
            name,
            f"{name}.exe",
            "C:\\Program Files\\Wireshark\\" + name + ".exe",
            f"/usr/bin/{name}",
            f"/usr/local/bin/{name}",
        ]
        for candidate in candidates:
            try:
                result = subprocess.run([candidate, "--version"],
                                        capture_output=True, timeout=5)
            except (subprocess.TimeoutExpired, OSError):
                # OSError covers FileNotFoundError and also PermissionError
                # (file present but not executable), which must not abort
                # discovery of the remaining candidates.
                continue
            if result.returncode == 0:
                return candidate
        return None

    def _find_tshark(self) -> Optional[str]:
        """Find TShark executable."""
        return self._find_executable("tshark")

    def _find_dumpcap(self) -> Optional[str]:
        """Find dumpcap executable."""
        return self._find_executable("dumpcap")

    def _find_capinfos(self) -> Optional[str]:
        """Find capinfos executable."""
        return self._find_executable("capinfos")

    @staticmethod
    def _run(cmd: List[str], timeout: float) -> subprocess.CompletedProcess:
        """Run a CLI command and capture stdout/stderr as text.

        Output is decoded as UTF-8 with ``errors="replace"`` instead of the
        locale encoding, so undecodable bytes (e.g. in packet payload strings
        or interface names) cannot raise UnicodeDecodeError.

        Raises:
            subprocess.TimeoutExpired: if the command runs longer than ``timeout``.
            OSError: if the executable cannot be started.
        """
        return subprocess.run(cmd, capture_output=True, encoding="utf-8",
                              errors="replace", timeout=timeout)

    @staticmethod
    def _parse_json_packets(output: str) -> List[Any]:
        """Parse ``tshark -T json`` output into a list of packets.

        ``-T json`` writes a single, pretty-printed (multi-line) JSON array, so
        the whole output is decoded at once; decoding it line by line would fail
        on the opening ``[``. Empty output yields ``[]``; output that is not
        valid JSON is returned as ``[{"raw_output": output}]``.
        """
        text = output.strip()
        if not text:
            return []
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            return [{"raw_output": output}]
        return parsed if isinstance(parsed, list) else [parsed]

    def get_interfaces(self) -> Dict[str, Any]:
        """List capture interfaces reported by ``tshark -D``.

        Returns:
            ``{"status": "success", "interfaces": [...], "count": n}`` where each
            entry is one non-empty output line, or an error dict.
        """
        try:
            if self.tshark_path:
                result = self._run([self.tshark_path, "-D"], timeout=10)
                if result.returncode == 0:
                    interfaces = [line for line in result.stdout.strip().splitlines() if line]
                    return {
                        "status": "success",
                        "interfaces": interfaces,
                        "count": len(interfaces)
                    }
            return {"status": "error", "message": "Unable to list interfaces"}
        except Exception as e:
            logger.error(f"Error getting interfaces: {e}")
            return {"status": "error", "message": str(e)}

    def capture_packets(self, interface: str, count: int = 100,
                        filter_expr: str = "", timeout: int = 30) -> Dict[str, Any]:
        """Capture packets from a network interface using TShark.

        tshark is told to stop after ``count`` packets or ``timeout`` seconds
        (``-a duration:``), whichever comes first, so a quiet link returns the
        packets seen so far instead of being killed with all output lost. The
        subprocess timeout (duration + CAPTURE_GRACE_SECONDS) is only a safety net.

        Args:
            interface: Interface name or number passed to ``-i``.
            count: Packet limit passed to ``-c``.
            filter_expr: Optional BPF capture filter passed to ``-f``.
            timeout: Capture duration in seconds (values below 1 become 1).

        Returns:
            A success dict with up to 20 packets under "packets" plus the total
            under "packet_count" / "total_captured", or an error dict.
        """
        try:
            duration = max(1, int(timeout))
            cmd = [self.tshark_path, "-i", interface, "-c", str(count),
                   "-a", f"duration:{duration}", "-T", "json"]
            if filter_expr:
                cmd.extend(["-f", filter_expr])
            result = self._run(cmd, timeout=duration + self.CAPTURE_GRACE_SECONDS)
            if result.returncode != 0:
                return {
                    "status": "error",
                    "message": f"Capture failed: {result.stderr}"
                }
            packets = self._parse_json_packets(result.stdout)
            return {
                "status": "success",
                "interface": interface,
                "packet_count": len(packets),
                "packets": packets[:20],  # Limit output for display
                "total_captured": len(packets)
            }
        except subprocess.TimeoutExpired:
            return {"status": "error", "message": "Capture timeout"}
        except Exception as e:
            logger.error(f"Capture error: {e}")
            return {"status": "error", "message": str(e)}

    def analyze_pcap_file(self, filepath: str, filter_expr: str = "",
                          max_packets: int = 1000) -> Dict[str, Any]:
        """Dissect a capture file with ``tshark -r`` and return JSON packets.

        Args:
            filepath: Capture file to read (callers are expected to validate it).
            filter_expr: Optional display filter passed via ``-Y``.
            max_packets: Passed as ``-c`` when positive; 0 or less reads the
                whole file.

        Returns:
            A success dict with up to 10 packets under "packets" plus the total
            under "packet_count" / "total_analyzed", or an error dict.
        """
        try:
            cmd = [self.tshark_path, "-r", filepath, "-T", "json"]
            if filter_expr:
                cmd.extend(["-Y", filter_expr])
            if max_packets > 0:
                cmd.extend(["-c", str(max_packets)])
            result = self._run(cmd, timeout=60)
            if result.returncode != 0:
                return {
                    "status": "error",
                    "message": f"Analysis failed: {result.stderr}"
                }
            packets = self._parse_json_packets(result.stdout)
            return {
                "status": "success",
                "file": filepath,
                "packet_count": len(packets),
                "packets": packets[:10],  # Sample for display
                "total_analyzed": len(packets)
            }
        except subprocess.TimeoutExpired:
            return {"status": "error", "message": "Analysis timeout"}
        except Exception as e:
            logger.error(f"Analysis error: {e}")
            return {"status": "error", "message": str(e)}

    def get_protocol_statistics(self, filepath: str) -> Dict[str, Any]:
        """Generate protocol-hierarchy and IPv4 conversation stats for a file.

        Runs ``tshark -q -z io,phs`` and ``tshark -q -z conv,ip`` (30 s each).
        A failing sub-report is replaced by a placeholder string rather than
        failing the whole call.
        """
        try:
            hierarchy_result = self._run(
                [self.tshark_path, "-r", filepath, "-q", "-z", "io,phs"],
                timeout=30,
            )
            conv_result = self._run(
                [self.tshark_path, "-r", filepath, "-q", "-z", "conv,ip"],
                timeout=30,
            )
            return {
                "status": "success",
                "file": filepath,
                "protocol_hierarchy": (hierarchy_result.stdout
                                       if hierarchy_result.returncode == 0
                                       else "Error generating hierarchy"),
                "ip_conversations": (conv_result.stdout
                                     if conv_result.returncode == 0
                                     else "Error generating conversations")
            }
        except Exception as e:
            logger.error(f"Statistics error: {e}")
            return {"status": "error", "message": str(e)}

    def get_file_info(self, filepath: str) -> Dict[str, Any]:
        """Get information about a capture file using capinfos.

        Returns an error dict immediately if capinfos was not found at startup.
        """
        if not self.capinfos_path:
            return {"status": "error", "message": "capinfos not available"}
        try:
            result = self._run([self.capinfos_path, filepath], timeout=10)
            if result.returncode == 0:
                return {
                    "status": "success",
                    "file": filepath,
                    "info": result.stdout
                }
            return {
                "status": "error",
                "message": f"capinfos failed: {result.stderr}"
            }
        except Exception as e:
            logger.error(f"capinfos error: {e}")
            return {"status": "error", "message": str(e)}
class WiresharkMCPServer:
    """Main MCP server class providing Wireshark capabilities.

    Wires a FastMCP application to a WiresharkInterface. Blocking CLI calls run
    on a two-worker thread pool so the async MCP handlers stay responsive.
    """

    # Upper bounds enforced by the tools; they match the limits advertised in
    # the tool docstrings and in the network://help resource.
    LIVE_MAX_PACKETS = 1000
    LIVE_MAX_TIMEOUT = 60
    FILE_MAX_PACKETS = 1000

    def __init__(self):
        self.mcp = FastMCP("wireshark-analyzer")
        self.wireshark = WiresharkInterface()
        self.executor = ThreadPoolExecutor(max_workers=2)
        self.register_tools()
        logger.info("Wireshark MCP Server initialized")

    @staticmethod
    def _clamp(value: int, upper: int) -> int:
        """Clamp ``value`` into the inclusive range ``[1, upper]``.

        The lower bound matters: a non-positive ``max_packets`` would make the
        file analyzer omit tshark's ``-c`` option and dissect the entire file,
        and non-positive counts/timeouts are meaningless for a capture.
        """
        return max(1, min(value, upper))

    async def _run_blocking(self, func, *args):
        """Await ``func(*args)`` on ``self.executor`` instead of the event loop.

        Every tool routes its subprocess-backed work through here so long
        tshark/capinfos runs never stall the MCP transport.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, func, *args)

    def register_tools(self):
        """Register all MCP tools and the help resource on ``self.mcp``."""

        @self.mcp.tool
        async def get_network_interfaces() -> Dict[str, Any]:
            """Get list of available network interfaces for packet capture."""
            return await self._run_blocking(self.wireshark.get_interfaces)

        @self.mcp.tool
        async def capture_live_packets(
            interface: str,
            count: int = 50,
            capture_filter: str = "",
            timeout: int = 30
        ) -> Dict[str, Any]:
            """
            Capture live network packets from a specified interface.
            Args:
                interface: Network interface name or number (e.g., "eth0", "1")
                count: Number of packets to capture (max 1000)
                capture_filter: BPF capture filter (e.g., "tcp port 80")
                timeout: Capture timeout in seconds (max 60)
            """
            # Input validation
            if not SecurityValidator.validate_interface(interface):
                return {"status": "error", "message": "Invalid interface name"}
            if not SecurityValidator.validate_capture_filter(capture_filter):
                return {"status": "error", "message": "Invalid capture filter"}
            # Apply the documented limits (never above the global ceiling) and
            # reject non-positive values by raising them to 1.
            count = self._clamp(
                count, min(self.LIVE_MAX_PACKETS, SecurityValidator.MAX_PACKET_COUNT)
            )
            timeout = self._clamp(timeout, self.LIVE_MAX_TIMEOUT)
            try:
                return await self._run_blocking(
                    self.wireshark.capture_packets,
                    interface, count, capture_filter, timeout
                )
            except Exception as e:
                logger.error(f"Live capture error: {e}")
                return {"status": "error", "message": str(e)}

        @self.mcp.tool
        async def analyze_pcap_file(
            filepath: str,
            display_filter: str = "",
            max_packets: int = 100
        ) -> Dict[str, Any]:
            """
            Analyze an existing PCAP/PCAPNG file.
            Args:
                filepath: Path to the PCAP/PCAPNG file
                display_filter: Wireshark display filter (e.g., "http.request")
                max_packets: Maximum number of packets to analyze
            """
            # Validate and sanitize file path
            sanitized_path = SecurityValidator.sanitize_filepath(filepath)
            if not sanitized_path:
                return {"status": "error", "message": "Invalid or inaccessible file path"}
            # Clamp to [1, FILE_MAX_PACKETS]; 0 or a negative value previously
            # bypassed the cap entirely.
            max_packets = self._clamp(max_packets, self.FILE_MAX_PACKETS)
            # display_filter is passed to tshark as a single argv element (-Y),
            # not through a shell.
            try:
                return await self._run_blocking(
                    self.wireshark.analyze_pcap_file,
                    sanitized_path, display_filter, max_packets
                )
            except Exception as e:
                logger.error(f"File analysis error: {e}")
                return {"status": "error", "message": str(e)}

        @self.mcp.tool
        async def get_protocol_statistics(filepath: str) -> Dict[str, Any]:
            """
            Generate protocol hierarchy and conversation statistics from a PCAP file.
            Args:
                filepath: Path to the PCAP/PCAPNG file
            """
            # Validate file path
            sanitized_path = SecurityValidator.sanitize_filepath(filepath)
            if not sanitized_path:
                return {"status": "error", "message": "Invalid or inaccessible file path"}
            return await self._run_blocking(
                self.wireshark.get_protocol_statistics, sanitized_path
            )

        @self.mcp.tool
        async def get_capture_file_info(filepath: str) -> Dict[str, Any]:
            """
            Get detailed information about a capture file.
            Args:
                filepath: Path to the PCAP/PCAPNG file
            """
            # Validate file path
            sanitized_path = SecurityValidator.sanitize_filepath(filepath)
            if not sanitized_path:
                return {"status": "error", "message": "Invalid or inaccessible file path"}
            return await self._run_blocking(
                self.wireshark.get_file_info, sanitized_path
            )

        @self.mcp.resource("network://help")
        def get_help_documentation() -> str:
            """Comprehensive help documentation for Wireshark MCP tools."""
            return """
# Wireshark MCP Server Help
## Available Tools
### get_network_interfaces()
- Lists all available network interfaces for packet capture
- No parameters required
- Returns interface names and numbers
### capture_live_packets(interface, count, capture_filter, timeout)
- Captures live network packets from specified interface
- Parameters:
- interface: Interface name (e.g., "eth0") or number (e.g., "1")
- count: Number of packets to capture (default: 50, max: 1000)
- capture_filter: BPF filter expression (optional)
- timeout: Capture timeout in seconds (default: 30, max: 60)
### analyze_pcap_file(filepath, display_filter, max_packets)
- Analyzes existing PCAP/PCAPNG files
- Parameters:
- filepath: Path to capture file
- display_filter: Wireshark display filter (optional)
- max_packets: Maximum packets to analyze (default: 100, max: 1000)
### get_protocol_statistics(filepath)
- Generates protocol hierarchy and conversation statistics
- Parameters:
- filepath: Path to capture file
### get_capture_file_info(filepath)
- Gets detailed information about capture file
- Parameters:
- filepath: Path to capture file
## Common Filters
### Capture Filters (BPF syntax):
- "tcp port 80" - HTTP traffic
- "host 192.168.1.1" - Traffic to/from specific host
- "net 10.0.0.0/8" - Traffic on specific network
### Display Filters (Wireshark syntax):
- "http.request" - HTTP requests
- "tcp.flags.syn == 1" - TCP SYN packets
- "dns.flags.response == 1" - DNS responses
## Security Notes
- All inputs are validated for security
- File paths are sanitized and checked
- Capture limits are enforced
- Only PCAP/PCAPNG files are accepted
"""

    def run(self):
        """Run the MCP server."""
        logger.info("Starting Wireshark MCP Server...")
        self.mcp.run()
def main():
    """Build the Wireshark MCP server and serve until stopped.

    A KeyboardInterrupt is logged as a normal shutdown. Any other exception is
    logged and then re-raised so the process still exits with a traceback.
    """
    try:
        WiresharkMCPServer().run()
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    except Exception as exc:
        logger.error(f"Server error: {exc}")
        raise


if __name__ == "__main__":
    main()