# main.py (12.7 kB)
"""
SSH MCP Server with Advanced Session Management
"""
import asyncio
import asyncssh
import uuid
import time
from typing import Dict, Optional, Any
from mcp.server.fastmcp import FastMCP
import logging
# Logging configuration: root logger at INFO level, plus a module-level logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SSHSessionManager:
    """Manage a bounded set of named SSH client sessions.

    Every session is tracked under its session ID in three dictionaries:
    ``connections`` (the asyncssh connection), ``session_metadata`` (host,
    port, username, timestamps and a command counter) and
    ``connection_tasks`` (the background task that watches the session).
    """

    # Seconds between two checks of a session by its monitor task.
    _MONITOR_INTERVAL = 10

    def __init__(self, max_sessions: int = 10, session_timeout: int = 300):
        """Create an empty manager.

        Args:
            max_sessions: Session count at which the least recently used
                session is closed before a new one is created.
            session_timeout: Idle time in seconds after which the monitor
                task closes a session.
        """
        self.connections: Dict[str, asyncssh.SSHClientConnection] = {}
        self.session_metadata: Dict[str, Dict[str, Any]] = {}
        self.connection_tasks: Dict[str, asyncio.Task] = {}
        self.max_sessions = max_sessions
        self.session_timeout = session_timeout

    @staticmethod
    def _connection_closed(conn: Any) -> Optional[bool]:
        """Report whether *conn* is closed.

        Returns:
            True or False when the state can be determined, None when it
            cannot.
        """
        # Prefer the public API when the connection object offers it.
        is_closed = getattr(conn, 'is_closed', None)
        if callable(is_closed):
            return bool(is_closed())
        # Fallback: the private-attribute probe this class used originally.
        # NOTE(review): confirm whether any supported asyncssh release
        # actually exposes `_conn` on a client connection.
        if hasattr(conn, '_conn'):
            return conn._conn is None
        return None

    async def create_session(self, host: str, username: str, password: str,
                             port: int = 22, session_name: Optional[str] = None) -> str:
        """Open a new SSH connection and register it as a session.

        Args:
            host: Hostname or address to connect to.
            username: Login user name.
            password: Password used for authentication.
            port: TCP port of the SSH server.
            session_name: Session ID to use; a random ``ssh_<8 hex chars>``
                ID is generated when omitted or empty.

        Returns:
            The ID under which the session was stored.

        Raises:
            Exception: If the session ID is already in use or the
                connection attempt fails.
        """
        session_id = session_name or f"ssh_{uuid.uuid4().hex[:8]}"

        # Reject duplicates before doing anything else: overwriting an entry
        # would leak the previous connection and orphan its monitor task.
        if session_id in self.connections:
            raise Exception(f"Session '{session_id}' already exists")

        # Enforce the session limit by evicting the least recently used one.
        if len(self.connections) >= self.max_sessions:
            await self._cleanup_oldest_session()

        try:
            conn = await asyncssh.connect(
                host=host,
                port=port,
                username=username,
                password=password,
                # Development only: host key verification is disabled.
                # Production use needs proper host key validation.
                known_hosts=None,
                client_keys=None,
                passphrase=None
            )
        except Exception as e:
            logger.error("Failed to create SSH session: %s", e)
            raise Exception(f"SSH connection failed: {e}") from e

        now = time.time()
        self.connections[session_id] = conn
        self.session_metadata[session_id] = {
            'host': host,
            'port': port,
            'username': username,
            'created_at': now,
            'last_used': now,
            'command_count': 0
        }
        # Start the background task that watches this session.
        self.connection_tasks[session_id] = asyncio.create_task(
            self._monitor_session(session_id)
        )
        logger.info("SSH session created: %s -> %s@%s:%s",
                    session_id, username, host, port)
        return session_id

    async def execute_command(self, session_id: str, command: str,
                              timeout: int = 30) -> Dict[str, Any]:
        """Run *command* in an existing session.

        Args:
            session_id: ID of the session to use.
            command: Command line passed to the connection's ``run``.
            timeout: Maximum number of seconds to wait for completion.

        Returns:
            A dict with ``session_id``, ``command``, ``stdout``, ``stderr``,
            ``exit_status`` and ``execution_time`` (elapsed seconds).

        Raises:
            Exception: If the session is unknown, the command times out or
                execution fails.
        """
        if session_id not in self.connections:
            raise Exception(f"Session '{session_id}' not found")
        conn = self.connections[session_id]
        metadata = self.session_metadata[session_id]

        started = time.monotonic()
        try:
            result = await asyncio.wait_for(
                conn.run(command, check=False),
                timeout=timeout
            )
        except asyncio.TimeoutError as e:
            raise Exception(f"Command timeout after {timeout} seconds") from e
        except Exception as e:
            logger.error("Command execution failed in session %s: %s",
                         session_id, e)
            raise Exception(f"Command execution failed: {e}") from e
        elapsed = time.monotonic() - started

        # Record the usage so the idle timeout restarts from now.
        metadata['last_used'] = time.time()
        metadata['command_count'] += 1
        return {
            'session_id': session_id,
            'command': command,
            'stdout': result.stdout,
            'stderr': result.stderr,
            'exit_status': result.exit_status,
            'execution_time': elapsed
        }

    async def get_session_info(self, session_id: str) -> Dict[str, Any]:
        """Return the stored metadata of a session plus derived fields.

        Args:
            session_id: ID of the session to describe.

        Returns:
            A dict with the metadata fields, ``is_active`` and ``uptime``
            (seconds since creation).

        Raises:
            Exception: If the session is unknown.
        """
        if session_id not in self.connections:
            raise Exception(f"Session '{session_id}' not found")
        metadata = self.session_metadata[session_id]
        conn = self.connections[session_id]
        return {
            'session_id': session_id,
            'host': metadata['host'],
            'port': metadata['port'],
            'username': metadata['username'],
            'created_at': metadata['created_at'],
            'last_used': metadata['last_used'],
            'command_count': metadata['command_count'],
            # Active only when the connection is positively known to be open.
            'is_active': self._connection_closed(conn) is False,
            'uptime': time.time() - metadata['created_at']
        }

    async def list_sessions(self) -> Dict[str, Dict[str, Any]]:
        """Return the info dict of every session, keyed by session ID.

        Sessions whose info cannot be collected are closed and left out.
        """
        sessions = {}
        # Iterate over a snapshot because close_session mutates the dict.
        for session_id in list(self.connections):
            try:
                sessions[session_id] = await self.get_session_info(session_id)
            except Exception as e:
                logger.warning("Failed to get info for session %s: %s",
                               session_id, e)
                # Drop the session that could not be inspected.
                await self.close_session(session_id)
        return sessions

    async def close_session(self, session_id: str) -> bool:
        """Close one session and remove all of its bookkeeping.

        Args:
            session_id: ID of the session to close.

        Returns:
            True if the session existed and its connection was closed,
            False if it was unknown or closing raised an error.
        """
        conn = self.connections.pop(session_id, None)
        if conn is None:
            return False
        # Remove the bookkeeping first so a failing close() cannot leave a
        # stale entry behind.
        self.session_metadata.pop(session_id, None)
        task = self.connection_tasks.pop(session_id, None)
        # The monitor task calls this method itself; it ends on its own
        # afterwards, so only cancel when called from somewhere else.
        if task is not None and task is not asyncio.current_task():
            task.cancel()
        try:
            conn.close()
        except Exception as e:
            logger.error("Error closing session %s: %s", session_id, e)
            return False
        logger.info("SSH session closed: %s", session_id)
        return True

    async def close_all_sessions(self) -> int:
        """Close every session and return how many were closed successfully."""
        closed_count = 0
        for session_id in list(self.connections):
            if await self.close_session(session_id):
                closed_count += 1
        return closed_count

    async def _monitor_session(self, session_id: str):
        """Watch one session until it is closed.

        Closes the session when its connection is found closed, when it has
        been idle for longer than ``session_timeout`` seconds, or when the
        check itself fails.
        """
        while session_id in self.connections:
            try:
                conn = self.connections[session_id]
                metadata = self.session_metadata[session_id]
                # Connection state check.
                if self._connection_closed(conn) is True:
                    logger.info("Session %s connection closed", session_id)
                    await self.close_session(session_id)
                    break
                # Idle timeout check.
                if time.time() - metadata['last_used'] > self.session_timeout:
                    logger.info("Session %s timed out", session_id)
                    await self.close_session(session_id)
                    break
                await asyncio.sleep(self._MONITOR_INTERVAL)
            except Exception as e:
                logger.error("Error monitoring session %s: %s", session_id, e)
                await self.close_session(session_id)
                break

    async def _cleanup_oldest_session(self):
        """Close the session with the oldest ``last_used`` timestamp."""
        if not self.session_metadata:
            return
        oldest_session = min(
            self.session_metadata,
            key=lambda sid: self.session_metadata[sid]['last_used']
        )
        await self.close_session(oldest_session)
        logger.info("Cleaned up oldest session: %s", oldest_session)
# Create the MCP server and the module-wide SSH session manager used by the tools below
mcp = FastMCP("SSH Session Manager")
ssh_manager = SSHSessionManager(max_sessions=5, session_timeout=600) # 10-minute timeout
@mcp.tool()
async def ssh_connect(host: str, username: str, password: str,
                      port: int = 22, session_name: Optional[str] = None) -> str:
    """Connect to an SSH server and create a new session.

    Args:
        host: Hostname or address of the SSH server.
        username: Login user name.
        password: Password used for authentication.
        port: TCP port of the SSH server (default 22).
        session_name: Optional ID for the new session; forwarded to the
            session manager, which picks an ID itself when none is given.

    Returns:
        "SSH session created: <id>" on success, otherwise a message that
        starts with "Connection failed:" followed by the error text.
    """
    try:
        session_id = await ssh_manager.create_session(
            host=host,
            username=username,
            password=password,
            port=port,
            session_name=session_name
        )
    except Exception as e:
        # Tool boundary: report the failure as text instead of raising.
        return f"Connection failed: {e}"
    return f"SSH session created: {session_id}"
@mcp.tool()
async def ssh_execute(session_id: str, command: str, timeout: int = 30) -> str:
    """Run a command in an SSH session and return a plain-text report.

    The report lists the session, the command and the exit status, followed
    by the STDOUT and STDERR sections when they are non-empty. Any error is
    returned as a "Command execution failed: ..." message.
    """
    try:
        outcome = await ssh_manager.execute_command(session_id, command, timeout)
        report = [
            f"Session: {session_id}",
            f"Command: {command}",
            f"Exit Status: {outcome['exit_status']}",
        ]
        # Append each output stream only when it actually has content.
        for heading, stream in (("STDOUT:", 'stdout'), ("STDERR:", 'stderr')):
            text = outcome[stream]
            if text:
                report.extend((heading, text))
        return "\n".join(report)
    except Exception as e:
        return f"Command execution failed: {str(e)}"
@mcp.tool()
async def ssh_list_sessions() -> str:
    """List all SSH sessions known to the manager as a text summary.

    Returns "No active SSH sessions" when there are none, and a
    "Failed to list sessions: ..." message when an error occurs.
    """
    try:
        sessions = await ssh_manager.list_sessions()
        if not sessions:
            return "No active SSH sessions"
        lines = ["Active SSH Sessions:", "-" * 50]
        for sid, details in sessions.items():
            # Uptime is reported in whole minutes.
            uptime_min = int(details['uptime'] / 60)
            lines += [
                f"- {sid}",
                f" Host: {details['username']}@{details['host']}:{details['port']}",
                f" Status: {'Active' if details['is_active'] else 'Inactive'}",
                f" Uptime: {uptime_min} minutes",
                f" Commands: {details['command_count']}",
                "",
            ]
        return "\n".join(lines)
    except Exception as e:
        return f"Failed to list sessions: {str(e)}"
@mcp.tool()
async def ssh_close_session(session_id: str) -> str:
    """Close one SSH session and report the outcome as text."""
    try:
        if await ssh_manager.close_session(session_id):
            return f"Session '{session_id}' closed successfully"
        # The manager returned a falsy result for this ID.
        return f"Session '{session_id}' not found"
    except Exception as e:
        return f"Failed to close session: {str(e)}"
@mcp.tool()
async def ssh_close_all_sessions() -> str:
    """Close every SSH session and report how many were closed."""
    try:
        total_closed = await ssh_manager.close_all_sessions()
        summary = f"Closed {total_closed} SSH sessions"
    except Exception as e:
        summary = f"Failed to close sessions: {str(e)}"
    return summary
@mcp.tool()
async def ssh_session_info(session_id: str) -> str:
    """Return a detailed text description of one SSH session.

    Errors are reported as a "Failed to get session info: ..." message.
    """
    try:
        details = await ssh_manager.get_session_info(session_id)
        # Both durations are reported in whole minutes.
        minutes_up = int(details['uptime'] / 60)
        minutes_idle = int((time.time() - details['last_used']) / 60)
        return "\n".join((
            f"SSH Session Info: {session_id}",
            "-" * 40,
            f"Host: {details['username']}@{details['host']}:{details['port']}",
            f"Status: {'Active' if details['is_active'] else 'Inactive'}",
            f"Created: {time.ctime(details['created_at'])}",
            f"Last Used: {minutes_idle} minutes ago",
            f"Uptime: {minutes_up} minutes",
            f"Commands Executed: {details['command_count']}",
        ))
    except Exception as e:
        return f"Failed to get session info: {str(e)}"
# Startup cleanup hook (disabled: FastMCP does not support on_startup)
# @mcp.on_startup()
# async def startup_cleanup():
#     """Clean up existing sessions when the server starts"""
#     logger.info("SSH MCP Server starting up...")
# Shutdown cleanup hook (disabled: FastMCP does not support on_shutdown)
# @mcp.on_shutdown()
# async def shutdown_cleanup():
#     """Close all sessions when the server shuts down"""
#     logger.info("Shutting down SSH MCP Server...")
#     await ssh_manager.close_all_sessions()
#     logger.info("All SSH sessions closed")
if __name__ == "__main__":
    # Start the MCP server only when this file is executed as a script.
    mcp.run()