"""
Advanced security features for MCP system
Includes authentication, authorization, rate limiting, and SQL injection protection
"""
import asyncio
import hashlib
import hmac
import json
import logging
import secrets
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Set, Callable, Any
from dataclasses import dataclass, field
from enum import Enum
import re
import jwt
from functools import wraps
logger = logging.getLogger(__name__)
class SecurityLevel(Enum):
    """Security levels for different operations.

    Each member's value is its lowercase string name.

    NOTE(review): no code in this file reads these members; presumably they
    are consumed by callers elsewhere - confirm before relying on them.
    """
    PUBLIC = "public"
    AUTHENTICATED = "authenticated"
    AUTHORIZED = "authorized"
    ADMIN = "admin"
class Permission(Enum):
    """Available permissions.

    Each member's value is a lowercase string. SecurityManager.authorize_action
    accepts granted permissions either as these members or as their string
    values.
    """
    READ_SCHEMA = "read_schema"
    EXECUTE_SQL = "execute_sql"
    SEMANTIC_SEARCH = "semantic_search"
    ADMIN_ACCESS = "admin_access"
    CACHE_MANAGEMENT = "cache_management"
    PERFORMANCE_MONITORING = "performance_monitoring"
@dataclass
class User:
    """User representation.

    Instances are created by SecurityManager.create_user and kept in its
    in-memory ``users`` dict, keyed by ``user_id``.
    """
    # Unique identifier; create_user generates it with secrets.token_hex(16).
    user_id: str
    username: str
    email: str
    # Permissions granted directly to this user.
    permissions: Set[Permission]
    # Free-form role names; defaults to a fresh empty list per instance.
    roles: List[str] = field(default_factory=list)
    # NOTE(review): never assigned in this file - presumably set by callers.
    api_key_hash: Optional[str] = None
    # Naive local timestamp taken when the instance is constructed.
    created_at: datetime = field(default_factory=datetime.now)
    # NOTE(review): never updated in this file - presumably set by callers.
    last_login: Optional[datetime] = None
    # Counted by SecurityManager.get_security_report under users.active.
    is_active: bool = True
@dataclass
class APIKey:
    """API key representation.

    Only the hash of the key is stored here; the raw key is returned once by
    SecurityManager.generate_api_key and is not kept on this record.
    """
    # Short identifier; generate_api_key uses secrets.token_hex(8).
    key_id: str
    # SHA-256 hex digest of the raw key; also the dict key in
    # SecurityManager.api_keys.
    key_hash: str
    # ID of the user the key was issued for.
    user_id: str
    # Human-readable label supplied when the key is generated.
    name: str
    # Permissions carried by requests authenticated with this key.
    permissions: Set[Permission]
    # Naive local timestamp taken when the record is constructed.
    created_at: datetime = field(default_factory=datetime.now)
    # Set to datetime.now() by verify_api_key on each successful verification.
    last_used: Optional[datetime] = None
    # When set, verify_api_key rejects the key once datetime.now() is later.
    expires_at: Optional[datetime] = None
    # verify_api_key rejects the key when this is False.
    is_active: bool = True
@dataclass
class RateLimitRule:
    """Rate limiting rule.

    RateLimiter.check_rate_limit allows at most ``requests_per_window``
    requests per identifier within any trailing ``window_seconds`` interval.
    """
    # Maximum number of requests allowed inside one window.
    requests_per_window: int
    # Length of the sliding window, in seconds.
    window_seconds: int
    # NOTE(review): not read by RateLimiter in this file; semantics are
    # undefined here - confirm intended use before setting it.
    burst_limit: Optional[int] = None
class SQLSecurityScanner:
    """Regex-based SQL screening for read-only query endpoints.

    A statement is considered safe only when it matches none of the
    forbidden patterns, starts with one of the allowed statement prefixes,
    and trips neither the string-concatenation nor the excessive-UNION
    heuristic. Matching is purely textual: patterns are applied to the whole
    statement, including the contents of string literals.
    """

    # Severity order, used so a later, weaker finding never lowers the level.
    _RISK_ORDER = {"low": 0, "medium": 1, "high": 2}

    def __init__(self):
        # Dangerous SQL patterns
        self.forbidden_patterns = [
            # Data modification
            r'\b(DROP|DELETE|INSERT|UPDATE|ALTER|CREATE|TRUNCATE)\b',
            # System functions. EXEC/EXECUTE are whole words; xp_/sp_ are
            # name prefixes, so they must not be followed by a word boundary
            # (``xp_\b`` can never match ``xp_cmdshell``).
            r'\b(EXEC|EXECUTE)\b|\b(xp_|sp_)\w*',
            # File operations
            r'\b(LOAD_FILE|INTO\s+OUTFILE|INTO\s+DUMPFILE)\b',
            # Information schema abuse
            r'\bunion\b.*\bselect\b.*\binformation_schema\b',
            # Time-based attacks
            r'\b(SLEEP|WAITFOR|DELAY)\b',
            # Comment-based injection
            r'(--|#|\/\*|\*\/)',
            # Hex encoding attempts
            r'0x[0-9a-fA-F]+',
            # Stacked queries
            r';\s*\w+',
        ]
        # Allowed safe patterns (anchored at the start of the statement)
        self.allowed_patterns = [
            r'^\s*SELECT\b',
            r'^\s*WITH\b.*SELECT\b',
            r'^\s*EXPLAIN\b',
            r'^\s*DESCRIBE\b',
            r'^\s*SHOW\b',
        ]
        # Compile once; scan_sql runs every pattern on every call.
        self.forbidden_regex = [
            re.compile(pattern, re.IGNORECASE) for pattern in self.forbidden_patterns
        ]
        self.allowed_regex = [
            re.compile(pattern, re.IGNORECASE) for pattern in self.allowed_patterns
        ]

    @classmethod
    def _escalate(cls, current: str, candidate: str) -> str:
        """Return whichever of the two risk levels is more severe."""
        if cls._RISK_ORDER[candidate] > cls._RISK_ORDER[current]:
            return candidate
        return current

    def scan_sql(self, sql: str) -> Dict[str, Any]:
        """Scan SQL for security issues.

        Args:
            sql: The raw SQL statement to inspect.

        Returns:
            A dict with keys ``is_safe`` (True only when no issue was found),
            ``risk_level`` ("low", "medium" or "high" - the most severe level
            reached by any finding), ``issues`` (list of dicts, each with a
            ``type`` and ``description``; forbidden-pattern issues also carry
            the matching ``pattern`` source) and ``normalized_sql``.
        """
        issues = []
        risk_level = "low"

        # Flatten newlines/tabs so that '.' in the patterns spans the whole
        # statement and '^' anchors only at its start.
        normalized_sql = sql.strip().replace('\n', ' ').replace('\t', ' ')

        # Check against forbidden patterns
        for source, compiled in zip(self.forbidden_patterns, self.forbidden_regex):
            if compiled.search(normalized_sql):
                issues.append({
                    'type': 'forbidden_pattern',
                    'pattern': source,
                    'description': 'Potentially dangerous SQL pattern detected'
                })
                risk_level = self._escalate(risk_level, "high")

        # Check if SQL starts with allowed patterns
        if not any(compiled.match(normalized_sql) for compiled in self.allowed_regex):
            issues.append({
                'type': 'disallowed_statement',
                'description': 'SQL statement type not in allowed list'
            })
            # Escalate only: a forbidden pattern found above must stay "high".
            risk_level = self._escalate(risk_level, "medium")

        # Check for suspicious string concatenation ('...' + '...')
        if re.search(r'["\'][\s]*\+[\s]*["\']', normalized_sql):
            issues.append({
                'type': 'string_concatenation',
                'description': 'Potential SQL injection via string concatenation'
            })
            risk_level = self._escalate(risk_level, "medium")

        # Check for excessive UNION statements
        union_count = len(re.findall(r'\bUNION\b', normalized_sql, re.IGNORECASE))
        if union_count > 3:
            issues.append({
                'type': 'excessive_unions',
                'description': f'Excessive UNION statements ({union_count}) detected'
            })
            risk_level = self._escalate(risk_level, "medium")

        return {
            'is_safe': not issues,
            'risk_level': risk_level,
            'issues': issues,
            'normalized_sql': normalized_sql
        }
class RateLimiter:
    """Sliding-window rate limiter keyed by (identifier, rule).

    State is held in memory only: ``windows`` maps "identifier:rule_key" to a
    dict whose ``requests`` list holds the ``time.time()`` stamps of the
    requests accepted inside the current window.
    """

    def __init__(self):
        self.windows: Dict[str, Dict[str, Any]] = {}
        self.rules: Dict[str, "RateLimitRule"] = {}

    def add_rule(self, key: str, rule: "RateLimitRule"):
        """Register ``rule`` under ``key``, replacing any existing rule."""
        self.rules[key] = rule

    async def check_rate_limit(self, identifier: str, rule_key: str = "default") -> Dict[str, Any]:
        """Check if request is within rate limits, recording it when allowed.

        Args:
            identifier: Caller identity the limit applies to (e.g. client IP).
            rule_key: Name of a rule previously registered with add_rule.

        Returns:
            ``{'allowed': True, 'reason': 'no_rule'}`` when ``rule_key`` is
            unknown. Otherwise a dict with ``allowed`` plus either
            ``current_requests``/``limit``/``remaining`` (allowed) or
            ``reason``/``current_requests``/``limit``/``window_seconds``/
            ``reset_time`` (denied). ``reset_time`` is the epoch time at
            which the oldest recorded request leaves the window.
        """
        rule = self.rules.get(rule_key)
        if rule is None:
            return {'allowed': True, 'reason': 'no_rule'}

        current_time = time.time()
        window = self.windows.setdefault(
            f"{identifier}:{rule_key}",
            {'requests': [], 'window_start': current_time},
        )

        # Drop timestamps that have slid out of the window.
        window_cutoff = current_time - rule.window_seconds
        recent = [stamp for stamp in window['requests'] if stamp > window_cutoff]
        window['requests'] = recent

        current_requests = len(recent)
        if current_requests >= rule.requests_per_window:
            # With a limit of zero (or less) nothing is ever recorded, so the
            # list may be empty here; fall back to "now" instead of letting
            # min() raise ValueError on an empty sequence.
            oldest = min(recent, default=current_time)
            return {
                'allowed': False,
                'reason': 'rate_limit_exceeded',
                'current_requests': current_requests,
                'limit': rule.requests_per_window,
                'window_seconds': rule.window_seconds,
                'reset_time': oldest + rule.window_seconds
            }

        # Record this request
        recent.append(current_time)
        return {
            'allowed': True,
            'current_requests': current_requests + 1,
            'limit': rule.requests_per_window,
            'remaining': rule.requests_per_window - current_requests - 1
        }
class SecurityManager:
    """Comprehensive security manager for MCP system.

    Bundles API-key and JWT authentication, permission checks, SQL screening
    and rate limiting. All state (users, API keys, rate-limit windows) lives
    in memory on the instance.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Build a manager from an optional config dict.

        Recognised keys: ``jwt_secret``, ``jwt_algorithm`` (default "HS256"),
        ``jwt_expiry_hours`` (default 24), ``rate_limit_general`` (100),
        ``rate_limit_sql`` (50) and ``rate_limit_admin`` (20).
        """
        self.config = config or {}

        # Components
        self.sql_scanner = SQLSecurityScanner()
        self.rate_limiter = RateLimiter()

        # Storage
        self.users: Dict[str, User] = {}
        self.api_keys: Dict[str, APIKey] = {}  # keyed by SHA-256 of the raw key
        self.active_sessions: Dict[str, Dict[str, Any]] = {}

        # Configuration. Without an explicit jwt_secret a random one is
        # generated per instance, so tokens do not validate across instances.
        self.jwt_secret = self.config.get('jwt_secret', secrets.token_hex(32))
        self.jwt_algorithm = self.config.get('jwt_algorithm', 'HS256')
        self.jwt_expiry_hours = self.config.get('jwt_expiry_hours', 24)

        # Setup default rate limits
        self._setup_default_rate_limits()

    def _setup_default_rate_limits(self):
        """Register the api_general, api_sql and api_admin rules (60 s windows)."""
        defaults = (
            ("api_general", 'rate_limit_general', 100),
            ("api_sql", 'rate_limit_sql', 50),
            ("api_admin", 'rate_limit_admin', 20),
        )
        for rule_key, config_key, fallback in defaults:
            self.rate_limiter.add_rule(rule_key, RateLimitRule(
                requests_per_window=self.config.get(config_key, fallback),
                window_seconds=60
            ))

    def generate_api_key(self, user_id: str, name: str, permissions: Set[Permission]) -> str:
        """Generate a new API key and store its hashed record.

        Returns:
            The raw key ("mcp_" + URL-safe token). Only its SHA-256 digest is
            stored, so this return value is the only copy of the raw key.
        """
        api_key = f"mcp_{secrets.token_urlsafe(32)}"
        key_hash = hashlib.sha256(api_key.encode()).hexdigest()

        self.api_keys[key_hash] = APIKey(
            key_id=secrets.token_hex(8),
            key_hash=key_hash,
            user_id=user_id,
            name=name,
            permissions=permissions
        )
        logger.info("Generated API key '%s' for user %s", name, user_id)
        return api_key

    def verify_api_key(self, api_key: str) -> Optional[APIKey]:
        """Verify API key and return key info.

        Returns:
            The stored APIKey record (with ``last_used`` refreshed), or None
            when the key is malformed, unknown, inactive or expired.
        """
        if not isinstance(api_key, str) or not api_key.startswith("mcp_"):
            return None

        key_record = self.api_keys.get(hashlib.sha256(api_key.encode()).hexdigest())
        if key_record is None or not key_record.is_active:
            return None

        now = datetime.now()
        if key_record.expires_at and now > key_record.expires_at:
            return None

        key_record.last_used = now
        return key_record

    def create_jwt_token(self, user_id: str, permissions: List[str]) -> str:
        """Create JWT token for user.

        Args:
            user_id: Stored in the ``user_id`` claim.
            permissions: Permission strings (Permission members are accepted
                and stored by value) for the ``permissions`` claim.

        Returns:
            The encoded token, valid for ``jwt_expiry_hours`` hours.
        """
        # Use epoch seconds for exp/iat: a naive local datetime would be
        # interpreted as UTC when encoded, skewing both claims by the local
        # UTC offset.
        issued_at = int(time.time())
        payload = {
            'user_id': user_id,
            # Enum members are not JSON-serialisable; store their values.
            'permissions': [getattr(perm, 'value', perm) for perm in permissions],
            'exp': issued_at + int(self.jwt_expiry_hours * 3600),
            'iat': issued_at
        }
        return jwt.encode(payload, self.jwt_secret, algorithm=self.jwt_algorithm)

    def verify_jwt_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify JWT token and return payload, or None if expired/invalid."""
        try:
            return jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
        except jwt.ExpiredSignatureError:
            logger.warning("JWT token expired")
        except jwt.InvalidTokenError:
            logger.warning("Invalid JWT token")
        return None

    async def authenticate_request(self, headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
        """Authenticate incoming request.

        Tries an API key first (``X-API-Key`` header, or the ``Authorization``
        value with "Bearer " removed), then a JWT from
        ``Authorization: Bearer <token>``.

        Returns:
            A dict with ``method`` ("api_key" or "jwt"), ``user_id`` and
            ``permissions`` (plus ``key_id`` for API keys), or None when
            neither credential is valid.
        """
        auth_header = headers.get('Authorization', '')

        # Try API key authentication
        api_key = headers.get('X-API-Key') or auth_header.replace('Bearer ', '')
        if api_key:
            key_record = self.verify_api_key(api_key)
            if key_record:
                return {
                    'method': 'api_key',
                    'user_id': key_record.user_id,
                    'permissions': list(key_record.permissions),
                    'key_id': key_record.key_id
                }

        # Try JWT authentication
        if auth_header.startswith('Bearer '):
            payload = self.verify_jwt_token(auth_header[len('Bearer '):])
            # A correctly signed token without a user_id claim identifies
            # nobody; treat it as unauthenticated instead of raising KeyError.
            if payload and 'user_id' in payload:
                return {
                    'method': 'jwt',
                    'user_id': payload['user_id'],
                    'permissions': payload.get('permissions', [])
                }

        return None

    async def authorize_action(self, auth_info: Dict[str, Any], required_permission: Permission) -> bool:
        """Check if authenticated user has required permission.

        ``auth_info['permissions']`` may hold Permission members, their
        string values, or a mix; entries are compared by value. Missing,
        empty or unrecognised permissions simply do not match.
        """
        if not auth_info:
            return False

        # Compare by value so that an empty list, or a string that is not a
        # valid Permission, yields False instead of IndexError / ValueError.
        required_value = getattr(required_permission, 'value', required_permission)
        granted = auth_info.get('permissions') or []
        return any(getattr(perm, 'value', perm) == required_value for perm in granted)

    async def check_sql_security(self, sql: str) -> Dict[str, Any]:
        """Check SQL query for security issues (delegates to the SQL scanner)."""
        return self.sql_scanner.scan_sql(sql)

    async def check_rate_limit(
        self,
        identifier: str,
        rule_key: str = "api_general"
    ) -> Dict[str, Any]:
        """Check rate limit for identifier (delegates to the rate limiter)."""
        return await self.rate_limiter.check_rate_limit(identifier, rule_key)

    @staticmethod
    def _client_host(request: Any) -> str:
        """Return the client host of ``request``, or 'unknown'.

        Accepts ``request.client`` as a dict with a 'host' key, as an object
        with a ``host`` attribute, or as None/missing.
        """
        client = getattr(request, 'client', None)
        if client is None:
            return 'unknown'
        if isinstance(client, dict):
            return client.get('host', 'unknown')
        return getattr(client, 'host', None) or 'unknown'

    def security_middleware(self, required_permission: Optional[Permission] = None, rate_limit_key: str = "api_general"):
        """Decorator for adding security to endpoints.

        The wrapped coroutine must be called with a ``request`` keyword
        argument exposing ``client`` and ``headers``. The wrapper applies the
        rate limit first; when ``required_permission`` is given it then
        authenticates and authorizes the caller and passes the resulting
        ``auth_info`` dict to the endpoint as a keyword argument.

        Failures are returned as dicts with ``error`` and ``status`` (500 for
        missing request, 429, 401 or 403) rather than raised.
        """
        def decorator(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                # Simplified request extraction; adapt to the web framework.
                request = kwargs.get('request')
                if not request:
                    return {'error': 'No request context', 'status': 500}

                client_ip = self._client_host(request)
                headers = getattr(request, 'headers', None) or {}

                # Check rate limit
                rate_check = await self.check_rate_limit(client_ip, rate_limit_key)
                if not rate_check['allowed']:
                    return {
                        'error': 'Rate limit exceeded',
                        'status': 429,
                        'details': rate_check
                    }

                # Check authentication if required
                if required_permission is not None:
                    auth_info = await self.authenticate_request(headers)
                    if not auth_info:
                        return {'error': 'Authentication required', 'status': 401}

                    # Check authorization
                    if not await self.authorize_action(auth_info, required_permission):
                        return {'error': 'Insufficient permissions', 'status': 403}

                    # Add auth info to kwargs for the function
                    kwargs['auth_info'] = auth_info

                # Call original function
                return await func(*args, **kwargs)
            return wrapper
        return decorator

    def create_user(
        self,
        username: str,
        email: str,
        permissions: Set[Permission],
        roles: Optional[List[str]] = None
    ) -> User:
        """Create a new user with a random ID and store it in ``users``."""
        user_id = secrets.token_hex(16)
        user = User(
            user_id=user_id,
            username=username,
            email=email,
            permissions=permissions,
            roles=roles or []
        )
        self.users[user_id] = user
        logger.info("Created user '%s' with ID %s", username, user_id)
        return user

    def get_security_report(self) -> Dict[str, Any]:
        """Get comprehensive security report.

        Returns:
            Counts grouped under ``users``, ``api_keys``, ``rate_limiting``
            and ``security_config``. ``recent_activity`` counts keys used
            within the last 7 days; ``expired`` counts keys whose
            ``expires_at`` is in the past.
        """
        now = datetime.now()
        keys = list(self.api_keys.values())
        windows = self.rate_limiter.windows
        rule_count = len(self.rate_limiter.rules)

        return {
            'users': {
                'total': len(self.users),
                'active': sum(1 for user in self.users.values() if user.is_active)
            },
            'api_keys': {
                'total': len(keys),
                'active': sum(1 for key in keys if key.is_active),
                'expired': sum(1 for key in keys
                               if key.expires_at and key.expires_at < now),
                'recent_activity': sum(1 for key in keys
                                       if key.last_used and (now - key.last_used).days < 7)
            },
            'rate_limiting': {
                'total_windows': len(windows),
                'active_windows': sum(1 for window in windows.values()
                                      if window['requests']),
                'rules_configured': rule_count
            },
            'security_config': {
                'jwt_expiry_hours': self.jwt_expiry_hours,
                'rate_limits_enabled': rule_count > 0,
                'sql_scanning_enabled': True
            }
        }
# Global security manager instance.
# Starts as None; get_security_manager() creates a default-configured instance
# on first use, and setup_security() replaces it with a newly built one.
_security_manager: Optional[SecurityManager] = None
def get_security_manager() -> SecurityManager:
    """Return the module-wide SecurityManager, building it on first use.

    The lazily created instance uses the default configuration.
    """
    global _security_manager
    manager = _security_manager
    if manager is not None:
        return manager

    # First access: build a default-configured manager and remember it.
    manager = SecurityManager()
    _security_manager = manager
    return manager
def setup_security(config: Dict[str, Any] = None) -> SecurityManager:
    """Install a freshly built SecurityManager as the module-wide instance.

    Any previously installed instance is replaced. ``config`` is passed to
    the SecurityManager constructor unchanged.
    """
    global _security_manager
    manager = SecurityManager(config)
    _security_manager = manager
    return manager
# Convenience decorators
def require_permission(permission: Permission, rate_limit_key: str = "api_general"):
    """Decorator requiring specific permission.

    Applies the global security manager's middleware (rate limit,
    authentication, authorization) to an async endpoint. The endpoint must be
    called with a ``request`` keyword argument and receives ``auth_info`` as
    an extra keyword argument on success.

    The manager is looked up on every call rather than when the decorator is
    applied, so endpoints decorated at import time still use the instance
    installed later by setup_security().
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            manager = get_security_manager()
            secured = manager.security_middleware(permission, rate_limit_key)(func)
            return await secured(*args, **kwargs)
        return wrapper
    return decorator
def require_authentication(rate_limit_key: str = "api_general"):
    """Decorator requiring authentication.

    Applies the global security manager's rate limit and then requires a
    valid API key or JWT, without demanding any particular permission.
    Unauthenticated calls get ``{'error': 'Authentication required',
    'status': 401}``; on success the endpoint receives ``auth_info`` as an
    extra keyword argument, as it does under require_permission.

    The manager is looked up on every call, so endpoints decorated at import
    time use the instance installed later by setup_security().
    """
    def decorator(func):
        @wraps(func)
        async def authenticated(*args, **kwargs):
            # security_middleware skips authentication when no permission is
            # requested, so the credential check has to happen here.
            manager = get_security_manager()
            headers = getattr(kwargs.get('request'), 'headers', None) or {}
            auth_info = await manager.authenticate_request(headers)
            if not auth_info:
                return {'error': 'Authentication required', 'status': 401}
            kwargs['auth_info'] = auth_info
            return await func(*args, **kwargs)

        @wraps(func)
        async def wrapper(*args, **kwargs):
            manager = get_security_manager()
            # Request validation and rate limiting run first, then the
            # authentication step above, then the endpoint.
            rate_limited = manager.security_middleware(None, rate_limit_key)(authenticated)
            return await rate_limited(*args, **kwargs)
        return wrapper
    return decorator
def sql_security_check(func):
    """Decorator for SQL security checking.

    Before calling the wrapped coroutine, locates its SQL text and runs it
    through the global security manager's scanner. The SQL is taken from, in
    order: the ``sql`` keyword argument, the argument bound to a parameter
    named ``sql``, or the first non-empty string among positional arguments.
    Looking past the first positional argument matters for methods, where
    that slot holds ``self``.

    When the scan reports issues the wrapped coroutine is not called and a
    dict with ``success`` False, ``error``, ``security_issues`` and
    ``risk_level`` is returned instead. When no SQL string is found the
    wrapped coroutine is called unchecked.
    """
    import inspect

    try:
        signature = inspect.signature(func)
    except (TypeError, ValueError):
        # Some callables expose no introspectable signature.
        signature = None

    def _find_sql(args, kwargs):
        """Return the SQL string among the call arguments, or None."""
        candidate = kwargs.get('sql')
        if isinstance(candidate, str) and candidate:
            return candidate

        if signature is not None:
            try:
                bound = signature.bind_partial(*args, **kwargs)
            except TypeError:
                bound = None
            if bound is not None:
                candidate = bound.arguments.get('sql')
                if isinstance(candidate, str) and candidate:
                    return candidate

        return next((arg for arg in args if isinstance(arg, str) and arg), None)

    @wraps(func)
    async def wrapper(*args, **kwargs):
        sql = _find_sql(args, kwargs)
        if sql:
            security_manager = get_security_manager()
            security_check = await security_manager.check_sql_security(sql)
            if not security_check['is_safe']:
                return {
                    'success': False,
                    'error': 'SQL security check failed',
                    'security_issues': security_check['issues'],
                    'risk_level': security_check['risk_level']
                }
        return await func(*args, **kwargs)
    return wrapper