Skip to main content
Glama
security_hardening.py (21.5 kB)
#!/usr/bin/env python3
"""
Security hardening tools for MCP server.

This module provides comprehensive security hardening including:
- Role-Based Access Control (RBAC)
- Rate limiting with sliding window
- Circuit breaker pattern for resilience
- Caching layer with TTL
- Observability with metrics and telemetry
"""

import asyncio
import fnmatch
import functools
import hashlib
import json
import logging
import posixpath
import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from urllib.parse import urlparse

import aiohttp
from aiohttp import web

from config import Config
from tools.intelligent_base import IntelligentToolBase, IntelligentToolContext


class UserRole(Enum):
    """User roles for RBAC system."""

    ADMIN = "admin"
    DEVELOPER = "developer"
    READONLY = "readonly"
    GUEST = "guest"


class Permission(Enum):
    """Permissions for tool access."""

    # File operations
    FILE_READ = "file:read"
    FILE_WRITE = "file:write"
    FILE_DELETE = "file:delete"

    # Git operations
    GIT_READ = "git:read"
    GIT_WRITE = "git:write"

    # API operations
    API_READ = "api:read"
    API_WRITE = "api:write"

    # Build operations
    BUILD_EXECUTE = "build:execute"

    # Security operations
    SECURITY_ENCRYPT = "security:encrypt"
    SECURITY_DECRYPT = "security:decrypt"

    # Admin operations
    ADMIN_CONFIG = "admin:config"
    ADMIN_AUDIT = "admin:audit"


@dataclass
class RBACPolicy:
    """RBAC policy definition."""

    role: UserRole
    permissions: Set[Permission] = field(default_factory=set)
    resource_patterns: List[str] = field(default_factory=list)
    conditions: Dict[str, Any] = field(default_factory=dict)


@dataclass
class RateLimitConfig:
    """Rate limiting configuration."""

    requests_per_window: int
    window_seconds: int
    burst_limit: Optional[int] = None


@dataclass
class CircuitBreakerConfig:
    """Circuit breaker configuration."""

    failure_threshold: int
    recovery_timeout: float
    expected_exception: Tuple[Type[Exception], ...] = (Exception,)


@dataclass
class CacheEntry:
    """Cache entry with TTL."""

    data: Any
    timestamp: float
    ttl: float

    def is_expired(self) -> bool:
        """Return True once more than ``ttl`` seconds have passed since ``timestamp``."""
        return time.time() - self.timestamp > self.ttl


class MetricsCollector:
    """Collect and export metrics.

    Counters and histograms are stored per metric name and then per label
    set, where the label set is the JSON serialization (sorted keys) of the
    ``labels`` dict passed by the caller.
    """

    def __init__(self) -> None:
        self.counters: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
        self.histograms: Dict[str, Dict[str, List[Tuple[float, float]]]] = defaultdict(
            lambda: defaultdict(list)
        )
        self.start_time = time.time()

    def increment_counter(self, name: str, labels: Optional[Dict[str, str]] = None) -> None:
        """Increment the counter ``name`` for the given label set by one."""
        key = json.dumps(labels or {}, sort_keys=True)
        self.counters[name][key] += 1

    def record_histogram(
        self, name: str, value: float, labels: Optional[Dict[str, str]] = None
    ) -> None:
        """Append a ``(wall-clock time, value)`` sample to histogram ``name``."""
        key = json.dumps(labels or {}, sort_keys=True)
        self.histograms[name][key].append((time.time(), value))

    def get_metrics(self) -> Dict[str, Any]:
        """Return shallow copies of the counter and histogram tables."""
        return {"counters": dict(self.counters), "histograms": dict(self.histograms)}


class TelemetryExporter:
    """Export telemetry data to an HTTP endpoint.

    Intended to be used as an async context manager; the HTTP session only
    exists between ``__aenter__`` and ``__aexit__``.
    """

    def __init__(self, endpoint: Optional[str] = None) -> None:
        self.endpoint = endpoint or Config.get_str("TELEMETRY_ENDPOINT", "")
        self.session: Optional[aiohttp.ClientSession] = None

    async def __aenter__(self) -> "TelemetryExporter":
        # Without an endpoint there is nothing to talk to, so no session is opened.
        if self.endpoint:
            self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if self.session:
            await self.session.close()
            # Drop the closed session so a later export() outside the context
            # is a no-op instead of a call on a closed session.
            self.session = None

    async def export(self, metrics: Dict[str, Any]) -> None:
        """Export metrics to telemetry endpoint.

        Best effort: does nothing when no session/endpoint is available, and
        any failure (including an HTTP error status) is logged as a warning
        rather than raised.
        """
        if not self.session or not self.endpoint:
            return

        try:
            # ``async with`` releases the response/connection back to the pool.
            async with self.session.post(
                self.endpoint,
                json={
                    "timestamp": datetime.now().isoformat(),
                    "metrics": metrics,
                    "service": "kotlin-mcp-server",
                },
                headers={"Content-Type": "application/json"},
            ) as response:
                response.raise_for_status()
        except Exception as e:
            logging.warning("Failed to export telemetry: %s", e)


class RBACManager:
    """Role-Based Access Control manager."""

    def __init__(self) -> None:
        self.policies: Dict[UserRole, RBACPolicy] = {}
        self.user_roles: Dict[str, UserRole] = {}
        self._load_default_policies()

    def _load_default_policies(self) -> None:
        """Load default RBAC policies."""
        # Admin policy
        self.policies[UserRole.ADMIN] = RBACPolicy(
            role=UserRole.ADMIN, permissions=set(Permission), resource_patterns=["*"]
        )

        # Developer policy
        self.policies[UserRole.DEVELOPER] = RBACPolicy(
            role=UserRole.DEVELOPER,
            permissions={
                Permission.FILE_READ,
                Permission.FILE_WRITE,
                Permission.GIT_READ,
                Permission.GIT_WRITE,
                Permission.API_READ,
                Permission.API_WRITE,
                Permission.BUILD_EXECUTE,
                Permission.SECURITY_ENCRYPT,
                Permission.SECURITY_DECRYPT,
            },
            resource_patterns=["src/**", "tests/**", "*.kt", "*.java", "*.gradle"],
        )

        # Read-only policy
        self.policies[UserRole.READONLY] = RBACPolicy(
            role=UserRole.READONLY,
            permissions={Permission.FILE_READ, Permission.GIT_READ, Permission.API_READ},
            resource_patterns=["**"],
        )

        # Guest policy
        self.policies[UserRole.GUEST] = RBACPolicy(
            role=UserRole.GUEST,
            permissions={Permission.FILE_READ},
            resource_patterns=["README.md", "docs/**"],
        )

    def assign_role(self, user_id: str, role: UserRole) -> None:
        """Assign role to user."""
        self.user_roles[user_id] = role

    def check_permission(self, user_id: str, permission: Permission, resource: str = "") -> bool:
        """Check if user has permission for resource.

        Users without an assigned role are treated as ``UserRole.GUEST``.
        An empty ``resource`` skips the resource-pattern check, so only the
        permission itself is evaluated.
        """
        role = self.user_roles.get(user_id, UserRole.GUEST)
        policy = self.policies.get(role)

        if not policy:
            return False

        # Check permission
        if permission not in policy.permissions:
            return False

        # Check resource pattern
        if resource and policy.resource_patterns != ["*"]:
            return any(
                self._matches_pattern(resource, pattern) for pattern in policy.resource_patterns
            )

        return True

    def _matches_pattern(self, resource: str, pattern: str) -> bool:
        """Check if resource matches a policy pattern.

        Supported forms:
        - ``*`` and ``**`` match every resource.
        - ``prefix/**`` matches ``prefix`` itself and anything below it
          (``docs/**`` matches ``docs/a.md`` but not ``docs-private/a.md``).
        - anything else is a case-sensitive ``fnmatch`` glob, where ``*``
          also spans ``/`` (``*.kt`` matches ``src/main/App.kt``).

        The resource is normalized first so that ``docs/../secret`` is
        evaluated as ``secret``; a path that still climbs out via ``..``
        never matches a scoped pattern.
        """
        if pattern in ("*", "**"):
            return True

        normalized = posixpath.normpath(resource) if resource else ""
        if normalized == ".." or normalized.startswith("../"):
            return False

        if pattern.endswith("/**"):
            prefix = pattern[:-3]
            # Require a path-segment boundary, not just a shared string prefix.
            return normalized == prefix or normalized.startswith(prefix + "/")

        return fnmatch.fnmatchcase(normalized, pattern)


class SlidingWindowRateLimiter:
    """Sliding window rate limiter.

    Keeps, per key, the timestamps of the requests accepted within the last
    ``window_seconds`` and rejects a request once ``requests_per_window``
    timestamps are present.
    """

    def __init__(self, config: RateLimitConfig):
        self.requests_per_window = config.requests_per_window
        self.window_seconds = config.window_seconds
        self.burst_limit = config.burst_limit or config.requests_per_window * 2
        # No ``maxlen``: a cap below ``requests_per_window`` would silently
        # discard timestamps and the limit could never be reached. The length
        # is already bounded by the check in ``is_allowed``.
        self.requests: Dict[str, deque] = defaultdict(deque)

    def _prune(self, request_times: deque, now: float) -> None:
        """Drop timestamps that have fallen out of the current window."""
        while request_times and now - request_times[0] > self.window_seconds:
            request_times.popleft()

    def is_allowed(self, key: str) -> bool:
        """Record the request and return True if ``key`` is under its limit."""
        now = time.time()
        request_times = self.requests[key]

        self._prune(request_times, now)

        if len(request_times) >= self.requests_per_window:
            return False

        request_times.append(now)
        return True

    def get_remaining_requests(self, key: str) -> int:
        """Get remaining requests in current window (never negative)."""
        # ``.get`` avoids allocating an entry for keys that never made a request.
        request_times = self.requests.get(key)
        if request_times is None:
            return max(0, self.requests_per_window)

        self._prune(request_times, time.time())
        return max(0, self.requests_per_window - len(request_times))


class CircuitBreakerOpenError(Exception):
    """Raised when a call is rejected because the circuit breaker is open."""


class CircuitBreaker:
    """Circuit breaker implementation.

    Used as a decorator on coroutine functions. After ``failure_threshold``
    failures the circuit opens and calls are rejected until
    ``recovery_timeout`` seconds have passed since the last failure; the
    next call is then let through (half-open) and a success closes the
    circuit again.
    """

    class State(Enum):
        CLOSED = "closed"
        OPEN = "open"
        HALF_OPEN = "half_open"

    def __init__(self, config: CircuitBreakerConfig) -> None:
        self.failure_threshold = config.failure_threshold
        self.recovery_timeout = config.recovery_timeout
        self.expected_exception = config.expected_exception

        self.state = self.State.CLOSED
        self.failure_count = 0
        self.last_failure_time: Optional[float] = None

    def __call__(self, func: Any) -> Any:
        """Decorator to apply circuit breaker to a coroutine function.

        Raises:
            CircuitBreakerOpenError: (an ``Exception`` subclass) when the
                circuit is open and the recovery timeout has not elapsed.
        """

        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            if self.state == self.State.OPEN:
                if self._should_attempt_reset():
                    self.state = self.State.HALF_OPEN
                else:
                    raise CircuitBreakerOpenError("Circuit breaker is OPEN")

            try:
                result = await func(*args, **kwargs)
            except self.expected_exception:
                self._on_failure()
                raise
            self._on_success()
            return result

        return wrapper

    def _should_attempt_reset(self) -> bool:
        """Check if we should attempt to reset the circuit."""
        if self.last_failure_time is None:
            return True
        return time.time() - self.last_failure_time >= self.recovery_timeout

    def _on_success(self) -> None:
        """Handle successful call: close a half-open circuit and reset the count."""
        if self.state == self.State.HALF_OPEN:
            self.state = self.State.CLOSED
        self.failure_count = 0

    def _on_failure(self) -> None:
        """Handle failed call: count it and open the circuit at the threshold."""
        self.failure_count += 1
        self.last_failure_time = time.time()

        if self.failure_count >= self.failure_threshold:
            self.state = self.State.OPEN


class TTLCache:
    """Time-To-Live cache implementation."""

    def __init__(self, default_ttl: float = 300.0) -> None:  # 5 minutes default
        self.cache: Dict[str, CacheEntry] = {}
        self.default_ttl = default_ttl

    def get(self, key: str) -> Optional[Any]:
        """Get value from cache if not expired; expired entries are evicted."""
        entry = self.cache.get(key)
        if entry is None:
            return None
        if entry.is_expired():
            del self.cache[key]  # Remove expired entry
            return None
        return entry.data

    def set(self, key: str, value: Any, ttl: Optional[float] = None) -> None:
        """Set value in cache with TTL.

        ``ttl=None`` selects ``default_ttl``; an explicit ``0`` is honoured
        rather than being replaced by the default.
        """
        if ttl is None:
            ttl = self.default_ttl
        self.cache[key] = CacheEntry(data=value, timestamp=time.time(), ttl=ttl)

    def delete(self, key: str) -> None:
        """Delete key from cache (missing keys are ignored)."""
        self.cache.pop(key, None)

    def clear(self) -> None:
        """Clear all cache entries."""
        self.cache.clear()

    def cleanup_expired(self) -> None:
        """Remove expired entries."""
        expired_keys = [k for k, v in self.cache.items() if v.is_expired()]
        for key in expired_keys:
            del self.cache[key]


class SecurityHardeningManager:
    """Main security hardening manager.

    Bundles RBAC, rate limiting, a circuit breaker, a TTL cache and metrics,
    all configured from ``Config``.
    """

    # Permissions whose results may be served from / stored in the cache.
    _CACHEABLE_PERMISSIONS = (Permission.FILE_READ, Permission.GIT_READ, Permission.API_READ)

    def __init__(self) -> None:
        self.rbac = RBACManager()
        self.rate_limiter = SlidingWindowRateLimiter(
            RateLimitConfig(
                requests_per_window=Config.get_int("RATE_LIMIT_REQUESTS", 100),
                window_seconds=Config.get_int("RATE_LIMIT_WINDOW", 60),
                burst_limit=Config.get_int("RATE_LIMIT_BURST", 200),
            )
        )
        self.circuit_breaker = CircuitBreaker(
            CircuitBreakerConfig(
                failure_threshold=Config.get_int("CIRCUIT_BREAKER_THRESHOLD", 5),
                recovery_timeout=Config.get_float("CIRCUIT_BREAKER_TIMEOUT", 60.0),
            )
        )
        self.cache = TTLCache(default_ttl=Config.get_float("CACHE_DEFAULT_TTL", 300.0))
        self.metrics = MetricsCollector()
        self.telemetry = TelemetryExporter()

    async def check_access(
        self, user_id: str, tool_name: str, permission: Permission, resource: str = ""
    ) -> bool:
        """Check if user has access to perform operation.

        Counts the attempt, applies the per-user rate limit, then the RBAC
        check. Returns False when either the rate limit or RBAC rejects it.
        """
        # Record access attempt
        self.metrics.increment_counter(
            "access_attempts",
            {"user_id": user_id, "tool": tool_name, "permission": permission.value},
        )

        # Check rate limit
        if not self.rate_limiter.is_allowed(user_id):
            self.metrics.increment_counter("rate_limit_exceeded", {"user_id": user_id})
            return False

        # Check RBAC
        allowed = self.rbac.check_permission(user_id, permission, resource)
        if not allowed:
            self.metrics.increment_counter(
                "access_denied", {"user_id": user_id, "tool": tool_name, "reason": "rbac"}
            )

        return allowed

    def get_cache_key(self, tool_name: str, args: Dict[str, Any]) -> str:
        """Generate cache key from tool name and arguments.

        The arguments are nested under their own key so that an argument
        named ``"tool"`` cannot overwrite the tool name.

        Raises:
            TypeError: if ``args`` is not JSON-serializable.
        """
        key_data = {"tool": tool_name, "args": args}
        key_str = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_str.encode()).hexdigest()

    def _build_operation_cache_key(
        self, tool_name: str, resource: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Optional[str]:
        """Build a cache key covering everything that identifies a call.

        Returns None when the arguments cannot be serialized, in which case
        the caller simply skips caching for that call.
        """
        try:
            return self.get_cache_key(
                tool_name,
                {"resource": resource, "positional": list(args), "keyword": kwargs},
            )
        except (TypeError, ValueError):
            return None

    async def get_cached_result(self, cache_key: str) -> Optional[Any]:
        """Get cached result if available."""
        return self.cache.get(cache_key)

    async def cache_result(
        self, cache_key: str, result: Any, ttl: Optional[float] = None
    ) -> None:
        """Cache result with TTL."""
        self.cache.set(cache_key, result, ttl)

    async def execute_with_hardening(
        self,
        user_id: str,
        tool_name: str,
        permission: Permission,
        resource: str,
        operation_func: Callable[..., Awaitable[Any]],
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Execute operation with full hardening applied.

        Steps: access check (rate limit + RBAC), cache lookup for read
        permissions, execution through the circuit breaker, caching of
        non-None read results, and duration/outcome metrics.

        Raises:
            PermissionError: when ``check_access`` rejects the call.
            Exception: whatever ``operation_func`` or the circuit breaker
                raises; failures are recorded in the metrics and re-raised.
        """
        start_time = time.time()

        try:
            # Check access
            if not await self.check_access(user_id, tool_name, permission, resource):
                raise PermissionError(f"Access denied for user {user_id}")

            # Check cache for read operations
            cache_key = None
            if permission in self._CACHEABLE_PERMISSIONS:
                # Positional args and the resource are part of the key so that
                # calls differing only in those never share a cached result.
                cache_key = self._build_operation_cache_key(tool_name, resource, args, kwargs)
                if cache_key is not None:
                    cached_result = await self.get_cached_result(cache_key)
                    if cached_result is not None:
                        self.metrics.increment_counter("cache_hit", {"tool": tool_name})
                        return cached_result
                    self.metrics.increment_counter("cache_miss", {"tool": tool_name})

            # Execute with circuit breaker
            @self.circuit_breaker
            async def execute() -> Any:
                return await operation_func(*args, **kwargs)

            result = await execute()

            # Cache result for read operations
            if cache_key and result is not None:
                await self.cache_result(cache_key, result)

            # Record success metrics
            execution_time = time.time() - start_time
            self.metrics.record_histogram(
                "operation_duration", execution_time, {"tool": tool_name, "status": "success"}
            )
            self.metrics.increment_counter("operations_completed", {"tool": tool_name})

            return result

        except Exception as e:
            # Record failure metrics
            execution_time = time.time() - start_time
            self.metrics.record_histogram(
                "operation_duration",
                execution_time,
                {"tool": tool_name, "status": "error", "error_type": type(e).__name__},
            )
            self.metrics.increment_counter("operations_failed", {"tool": tool_name})
            raise

    async def export_telemetry(self) -> None:
        """Export collected metrics via telemetry."""
        async with self.telemetry:
            await self.telemetry.export(self.metrics.get_metrics())

    def _counter_total(self, name: str) -> int:
        """Sum counter ``name`` across all of its label sets (0 if never incremented)."""
        # Each counter maps label-set -> int, so the values are summed directly.
        return sum(self.metrics.counters.get(name, {}).values())

    def get_metrics_summary(self) -> Dict[str, Any]:
        """Get metrics summary for monitoring."""
        return {
            "uptime_seconds": time.time() - self.metrics.start_time,
            "total_operations": self._counter_total("operations_completed"),
            "total_failures": self._counter_total("operations_failed"),
            "cache_hit_rate": self._calculate_cache_hit_rate(),
            "rate_limit_exceeded": self._counter_total("rate_limit_exceeded"),
        }

    def _calculate_cache_hit_rate(self) -> float:
        """Calculate cache hit rate as hits / (hits + misses), 0.0 with no lookups."""
        hits = self._counter_total("cache_hit")
        misses = self._counter_total("cache_miss")
        total = hits + misses
        return hits / total if total > 0 else 0.0


# Global hardening manager instance
hardening_manager = SecurityHardeningManager()


class HardeningTool(IntelligentToolBase):
    """Tool for managing security hardening features."""

    def __init__(self, project_path: str):
        super().__init__(project_path)
        self.name = "security_hardening"
        self.description = (
            "Manage security hardening features including RBAC, rate limiting, and monitoring"
        )

    async def _execute_core_functionality(
        self, context: IntelligentToolContext, arguments: Dict[str, Any]
    ) -> Any:
        """Execute the core functionality of the hardening tool.

        Dispatches on ``arguments["operation"]`` and always returns a dict;
        invalid input is reported with ``{"status": "error", ...}`` instead
        of raising.
        """
        operation = arguments.get("operation")

        if operation == "get_metrics":
            return {
                "metrics": hardening_manager.get_metrics_summary(),
                "timestamp": datetime.now().isoformat(),
            }

        elif operation == "assign_role":
            user_id = arguments.get("user_id")
            role_name = arguments.get("role")

            if not isinstance(user_id, str) or not isinstance(role_name, str):
                return {"status": "error", "message": "user_id and role must be strings"}

            try:
                role = UserRole(role_name)
            except ValueError:
                return {"status": "error", "message": f"Invalid role: {role_name}"}

            hardening_manager.rbac.assign_role(user_id, role)
            return {"status": "success", "user_id": user_id, "role": role_name}

        elif operation == "check_permission":
            user_id = arguments.get("user_id")
            permission_name = arguments.get("permission")
            # A missing or null resource means "no specific resource".
            resource = arguments.get("resource") or ""

            if not isinstance(user_id, str) or not isinstance(permission_name, str):
                return {"status": "error", "message": "user_id and permission must be strings"}
            if not isinstance(resource, str):
                # Pattern matching works on strings; reject anything else up front.
                return {"status": "error", "message": "resource must be a string"}

            try:
                permission = Permission(permission_name)
            except ValueError:
                return {"status": "error", "message": f"Invalid permission: {permission_name}"}

            allowed = hardening_manager.rbac.check_permission(user_id, permission, resource)
            return {
                "user_id": user_id,
                "permission": permission_name,
                "resource": resource,
                "allowed": allowed,
            }

        elif operation == "clear_cache":
            hardening_manager.cache.clear()
            return {"status": "success", "message": "Cache cleared"}

        elif operation == "export_telemetry":
            await hardening_manager.export_telemetry()
            return {"status": "success", "message": "Telemetry exported"}

        else:
            return {
                "status": "error",
                "message": f"Unknown operation: {operation}",
                "available_operations": [
                    "get_metrics",
                    "assign_role",
                    "check_permission",
                    "clear_cache",
                    "export_telemetry",
                ],
            }


# Export the hardening manager for use by other modules
__all__ = ["hardening_manager", "HardeningTool", "UserRole", "Permission"]

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/normaltusker/kotlin-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.