"""
SSH Connection Manager for Pentest MCP Server.
Handles SSH connections with auto-reconnection and health checks.
"""
import asyncio
import logging
import time
from typing import Optional, Dict, Any
import asyncssh
from .config import Config
# Module-level logger used by SSHManager for connection lifecycle events and command failures.
logger = logging.getLogger(__name__)
class SSHConnectionError(Exception):
    """Raised when the SSH link cannot be established or an operation over it (command, process, SFTP) fails."""
class SSHManager:
    """
    Manage one SSH connection to the configured target host.

    Commands, interactive processes and SFTP sessions are started through this
    object. Before each of them the connection is health-checked and, if it is
    down, re-established with backoff (see ``reconnect``).
    """

    def __init__(self, config: "Optional[Config]" = None):
        """
        Initialise manager state. No network activity happens here.

        Args:
            config: Connection settings. A fresh ``Config()`` is created when omitted.
        """
        self.config = config if config is not None else Config()
        self.connection: Optional[asyncssh.SSHClientConnection] = None
        # Refreshed on a successful connect and after every successful command.
        self.last_keepalive = time.time()
        # Number of the current/last reconnect attempt; reset to 0 by a successful connect().
        self.connection_attempts = 0
        # Guards against overlapping connect() calls.
        self._connecting = False

    async def connect(self) -> bool:
        """
        Open a new SSH connection to the configured target.

        Returns:
            bool: True once the connection is established. False, without trying,
            if another ``connect()`` call is already in progress.

        Raises:
            SSHConnectionError: If the connection attempt fails.
        """
        if self._connecting:
            logger.info("Connection attempt already in progress")
            return False
        self._connecting = True
        try:
            logger.info(f"Connecting to {self.config.TARGET_HOST}:{self.config.TARGET_PORT}")
            # Extra keyword arguments for asyncssh.connect (credentials), supplied by the config.
            auth_method = self.config.get_ssh_auth_method()
            self.connection = await asyncssh.connect(
                self.config.TARGET_HOST,
                port=self.config.TARGET_PORT,
                username=self.config.TARGET_USER,
                # NOTE: disables host-key verification (open to MITM); deliberate for pentest targets.
                known_hosts=None,
                **auth_method,
            )
            logger.info("SSH connection established successfully")
            self.connection_attempts = 0
            self.last_keepalive = time.time()
            return True
        except Exception as e:
            logger.error(f"Failed to connect: {e}")
            self.connection = None
            raise SSHConnectionError(f"Connection failed: {e}") from e
        finally:
            self._connecting = False

    async def disconnect(self):
        """
        Close the current SSH connection, if any, and clear ``self.connection``.

        The reference is cleared up front. A stale (already closed) connection
        object is therefore dropped as well, and the reference is cleared even if
        ``wait_closed()`` raises.
        """
        conn, self.connection = self.connection, None
        if conn is not None and not conn.is_closed():
            conn.close()
            await conn.wait_closed()
            logger.info("SSH connection closed")

    async def is_connected(self) -> bool:
        """
        Check whether the SSH connection is alive by running ``echo ping`` on it.

        Returns:
            bool: True if a connection exists, is open, and the probe exits with 0
            within 5 seconds.
        """
        if not self.connection or self.connection.is_closed():
            return False
        try:
            result = await self.connection.run("echo ping", timeout=5)
            return result.exit_status == 0
        except Exception as e:
            logger.warning(f"Connection check failed: {e}")
            return False

    async def ensure_connected(self) -> bool:
        """
        Make sure the SSH connection is alive, reconnecting if the health check fails.

        Returns:
            bool: True if connected after the check/reconnect.
        """
        if await self.is_connected():
            return True
        logger.info("Connection lost, attempting to reconnect...")
        return await self.reconnect()

    async def reconnect(self) -> bool:
        """
        Re-establish the connection, trying up to ``MAX_RECONNECT_ATTEMPTS`` times.

        Before every attempt after the first, waits ``RECONNECT_DELAY_BASE ** attempt``
        seconds. If a concurrent ``connect()`` restored the link during that wait,
        the restored connection is kept instead of being torn down.

        Returns:
            bool: True if a live connection exists afterwards, False if every
            attempt failed.
        """
        max_attempts = self.config.MAX_RECONNECT_ATTEMPTS
        base_delay = self.config.RECONNECT_DELAY_BASE
        for attempt in range(max_attempts):
            self.connection_attempts = attempt + 1
            try:
                if attempt > 0:
                    # NOTE(review): delay only grows when base_delay > 1 (e.g. 2 -> 2s, 4s, 8s).
                    delay = base_delay ** attempt
                    logger.info(f"Waiting {delay}s before reconnect attempt {attempt + 1}/{max_attempts}")
                    await asyncio.sleep(delay)
                    if await self.is_connected():
                        logger.info("Connection re-established by a concurrent attempt")
                        return True
                if self.connection:
                    await self.disconnect()
                # connect() returns False *without connecting* when another attempt is
                # already in progress; that must not be reported as success.
                if not await self.connect():
                    raise SSHConnectionError("Connection attempt already in progress")
                logger.info(f"Reconnection successful after {attempt + 1} attempts")
                return True
            except Exception as e:
                logger.error(f"Reconnect attempt {attempt + 1}/{max_attempts} failed: {e}")
        logger.error(f"Failed to reconnect after {max_attempts} attempts")
        return False

    async def run_command(self, command: str, timeout: Optional[int] = None) -> "asyncssh.SSHCompletedProcess":
        """
        Execute a command via SSH, reconnecting first if needed.

        Args:
            command: Command to execute.
            timeout: Command timeout in seconds. Falls back to
                ``config.DEFAULT_TIMEOUT`` when falsy.

        Returns:
            SSHCompletedProcess: Command result. A non-zero exit status is not
            raised as an error.

        Raises:
            SSHConnectionError: If no connection can be established or execution fails.
        """
        if not await self.ensure_connected():
            raise SSHConnectionError("Unable to establish SSH connection")
        timeout = timeout or self.config.DEFAULT_TIMEOUT
        try:
            result = await self.connection.run(command, timeout=timeout)
            self.last_keepalive = time.time()
            return result
        except Exception as e:
            logger.error(f"Command execution failed: {e}")
            raise SSHConnectionError(f"Command failed: {e}") from e

    async def create_process(self, command: str) -> "asyncssh.SSHClientProcess":
        """
        Start an interactive SSH process, reconnecting first if needed.

        Args:
            command: Command to start.

        Returns:
            SSHClientProcess: Interactive process.

        Raises:
            SSHConnectionError: If no connection can be established or the process cannot start.
        """
        if not await self.ensure_connected():
            raise SSHConnectionError("Unable to establish SSH connection")
        try:
            return await self.connection.create_process(command)
        except Exception as e:
            logger.error(f"Process creation failed: {e}")
            raise SSHConnectionError(f"Process creation failed: {e}") from e

    async def start_sftp_client(self) -> "asyncssh.SFTPClient":
        """
        Start an SFTP client for file operations, reconnecting first if needed.

        Returns:
            SFTPClient: SFTP client instance.

        Raises:
            SSHConnectionError: If no connection can be established or SFTP fails to start.
        """
        if not await self.ensure_connected():
            raise SSHConnectionError("Unable to establish SSH connection")
        try:
            return await self.connection.start_sftp_client()
        except Exception as e:
            logger.error(f"SFTP client creation failed: {e}")
            raise SSHConnectionError(f"SFTP failed: {e}") from e

    async def get_system_info(self) -> Dict[str, Any]:
        """
        Collect a snapshot of basic system information from the target.

        Each probe runs independently. A probe that fails is reported as
        "unknown" rather than aborting the whole snapshot.

        Returns:
            dict: "hostname", "uptime", "kernel" and "cpu_info" hold strings.
            "disk" holds a structured dict, or the raw string / "unknown" if it
            cannot be parsed. "memory" holds a structured dict, a legacy-format
            string, or "unknown". On an unexpected error, returns
            {"error": <message>}.
        """
        try:
            basic_commands = {
                "hostname": "hostname",
                "uptime": "uptime",
                "kernel": "uname -r",
                "cpu_info": "nproc",
                # Root filesystem: "<use%> <used> <size>" in df's human-readable units.
                "disk": "df -h / | awk 'NR==2{print $5,$3,$2}'",
            }
            timeout = getattr(self.config, 'DEFAULT_TIMEOUT', 30)
            info: Dict[str, Any] = {}
            for key, cmd in basic_commands.items():
                try:
                    result = await self.run_command(cmd, timeout=timeout)
                    info[key] = result.stdout.strip()
                except Exception as e:
                    logger.warning(f"Failed to get {key}: {e}")
                    info[key] = "unknown"
            info["memory"] = await self._get_memory_info(timeout)
            if info["disk"] != "unknown":
                info["disk"] = self._structure_disk_info(info["disk"])
            return info
        except Exception as e:
            logger.error(f"Failed to get system info: {e}")
            return {"error": str(e)}

    async def _get_memory_info(self, timeout) -> Any:
        """Return structured memory stats; fall back to the legacy string, then "unknown"."""
        try:
            # Memory row of `free -m`: total, used, available (MB).
            mem_result = await self.run_command(
                "free -m | awk 'NR==2{print $2,$3,$7}'", timeout=timeout
            )
            return self._parse_memory(mem_result.stdout)
        except Exception as e:
            logger.warning(f"Failed to get structured memory info: {e}")
        try:
            result = await self.run_command(
                "free -m | awk 'NR==2{printf \"%.1f,%.1f,%.1f\", $3*100/$2, $3, $2}'", timeout=timeout
            )
            return result.stdout.strip() + " (legacy format)"
        except Exception:
            return "unknown"

    @staticmethod
    def _parse_memory(free_output: str) -> Dict[str, Any]:
        """
        Build the structured memory dict from a "<total> <used> <available>" line in MB.

        Raises:
            ValueError: If the line does not hold three integers, or total is not positive.
        """
        mem_values = free_output.strip().split()
        if len(mem_values) != 3:
            raise ValueError(f"Expected 3 memory values, got {len(mem_values)}: {mem_values}")
        total_mb, used_mb, available_mb = map(int, mem_values)
        if total_mb <= 0:
            raise ValueError(f"Invalid total memory: {total_mb} MB")
        mb_per_gb = 1024
        usage_percent = round(used_mb / total_mb * 100, 1)
        used_gb = round(used_mb / mb_per_gb, 1)
        total_gb = round(total_mb / mb_per_gb, 1)
        available_gb = round(available_mb / mb_per_gb, 1)
        return {
            "usage_percent": usage_percent,
            "used_mb": used_mb,
            "total_mb": total_mb,
            "available_mb": available_mb,
            "used_gb": used_gb,
            "total_gb": total_gb,
            "available_gb": available_gb,
            "summary": f"{used_gb}GB used of {total_gb}GB total ({usage_percent}% used, {available_gb}GB available)",
        }

    @staticmethod
    def _parse_size(size_str: str) -> float:
        """
        Convert a ``df -h`` size such as '17G', '1.5T', '512M' or '0' to GB (1024-based).

        A value without a unit suffix is plain bytes. An empty string yields 0.0.

        Raises:
            ValueError: If the numeric part cannot be parsed.
        """
        if not size_str:
            return 0.0
        exponent = {"K": -2, "M": -1, "G": 0, "T": 1, "P": 2, "E": 3}.get(size_str[-1].upper())
        if exponent is None:
            # No unit suffix: the whole string is the byte count.
            return float(size_str) / 1024 ** 3
        return float(size_str[:-1]) * 1024.0 ** exponent

    @classmethod
    def _structure_disk_info(cls, raw: str) -> Any:
        """
        Turn the raw "<use%> <used> <size>" disk string into a structured dict.

        Available space is derived as size - used. Strings that do not split
        into exactly three fields are returned unchanged.
        """
        disk_parts = raw.split()
        if len(disk_parts) != 3:
            return raw
        usage_percent, used, total = disk_parts
        try:
            available_gb = cls._parse_size(total) - cls._parse_size(used)
        except ValueError as calc_error:
            logger.warning(f"Failed to calculate disk available space: {calc_error}")
            return {
                "usage_percent": usage_percent,
                "used": used,
                "total": total,
                "summary": f"Disk: {used} used of {total} total ({usage_percent} full)",
            }
        if available_gb >= 1024:
            available_str = f"{available_gb/1024:.1f}T"
        elif available_gb >= 1:
            available_str = f"{available_gb:.1f}G"
        else:
            available_str = f"{available_gb*1024:.0f}M"
        return {
            "usage_percent": usage_percent,
            "used": used,
            "total": total,
            "available": available_str,
            "summary": f"Disk: {used} used of {total} total ({usage_percent} full, {available_str} available)",
        }

    def get_connection_status(self) -> Dict[str, Any]:
        """
        Report the current connection status without any network activity.

        Returns:
            dict: Whether a non-closed connection object exists, the configured
            host/port/user, the reconnect attempt counter and the last keepalive
            timestamp (``time.time()`` seconds).
        """
        return {
            "connected": self.connection is not None and not self.connection.is_closed(),
            "host": self.config.TARGET_HOST,
            "port": self.config.TARGET_PORT,
            "user": self.config.TARGET_USER,
            "connection_attempts": self.connection_attempts,
            "last_keepalive": self.last_keepalive,
        }