Skip to main content
Glama
enkryptai

Enkrypt AI Secure MCP Gateway

Official
by enkryptai
config_manager.py (21.6 kB)
"""Authentication configuration manager.

Holds the registry of authentication providers, an in-process session store,
and backward-compatible helpers that mirror the legacy ``auth_service`` API.
"""

import asyncio
import os
import secrets
import time
from typing import Any, Dict, List, Optional, Tuple

from mcp.server.fastmcp import Context

from secure_mcp_gateway.plugins.auth.base import (
    AuthCredentials,
    AuthProvider,
    AuthProviderRegistry,
    AuthResult,
    SessionData,
)
from secure_mcp_gateway.utils import logger


class AuthConfigManager:
    """
    Manages authentication configuration and provider instantiation.

    Sessions live in ``self.sessions`` (a plain dict in this process), keyed by
    the string produced by :meth:`create_session_key`.
    """

    def __init__(self):
        """Initialize the auth config manager."""
        self.registry = AuthProviderRegistry()
        self.sessions: Dict[str, SessionData] = {}
        self.default_provider = "enkrypt"

        # Imported here rather than at module level; presumably to avoid an
        # import cycle with the services package — TODO confirm.
        from secure_mcp_gateway.services.cache.cache_service import cache_service

        self.cache_service = cache_service

    def register_provider(self, provider: AuthProvider) -> None:
        """
        Register an authentication provider.

        Args:
            provider: Provider to register
        """
        self.registry.register(provider)
        logger.info(f"Registered auth provider: {provider.get_name()}")

    def unregister_provider(self, name: str) -> None:
        """
        Unregister a provider.

        Args:
            name: Provider name
        """
        self.registry.unregister(name)
        logger.info(f"Unregistered auth provider: {name}")

    def get_provider(self, name: Optional[str] = None) -> Optional[AuthProvider]:
        """
        Get the registered provider.

        Args:
            name: Provider name (for compatibility, but ignored since only one provider)

        Returns:
            Optional[AuthProvider]: Provider if found
        """
        return self.registry.get_provider(name)

    def list_providers(self) -> List[str]:
        """
        List all registered providers.

        Returns:
            List[str]: Provider names
        """
        return self.registry.list_providers()

    def extract_credentials(self, ctx: Context) -> AuthCredentials:
        """
        Extract credentials from MCP context.

        Request headers are read first (streamable-http transport). Any of the
        gateway key, project ID or user ID still missing afterwards falls back
        to the ``ENKRYPT_GATEWAY_KEY``, ``ENKRYPT_PROJECT_ID`` and
        ``ENKRYPT_USER_ID`` environment variables respectively.

        Args:
            ctx: MCP context (may be None)

        Returns:
            AuthCredentials: Extracted credentials
        """
        credentials = AuthCredentials()

        # Extract from request headers (for streamable-http)
        if ctx and ctx.request_context and ctx.request_context.request:
            headers = ctx.request_context.request.headers
            credentials.api_key = headers.get("apikey")
            credentials.gateway_key = headers.get("ENKRYPT_GATEWAY_KEY") or headers.get(
                "apikey"
            )
            credentials.project_id = headers.get("project_id")
            credentials.user_id = headers.get("user_id")

            # Strip only a leading "Bearer " scheme; a plain str.replace() would
            # also delete that substring from anywhere inside the token.
            authorization = headers.get("Authorization", "")
            bearer_prefix = "Bearer "
            if authorization.startswith(bearer_prefix):
                authorization = authorization[len(bearer_prefix) :]
            credentials.access_token = authorization

            credentials.username = headers.get("username")
            credentials.password = headers.get("password")

            # Mask sensitive headers before storing
            from secure_mcp_gateway.utils import mask_sensitive_headers

            credentials.headers = mask_sensitive_headers(dict(headers))

        # Fallback to environment variables
        if not credentials.gateway_key:
            credentials.gateway_key = os.environ.get("ENKRYPT_GATEWAY_KEY")
        if not credentials.project_id:
            credentials.project_id = os.environ.get("ENKRYPT_PROJECT_ID")
        if not credentials.user_id:
            credentials.user_id = os.environ.get("ENKRYPT_USER_ID")

        return credentials

    async def authenticate(
        self, ctx: Context, provider_name: Optional[str] = None
    ) -> AuthResult:
        """
        Authenticate a request using the specified provider with cache integration.

        Lookup order: existing in-memory session, then the gateway-config cache
        (keyed by the local config's ``id``), then the provider itself. A
        successful provider result is cached and turned into a session.

        Args:
            ctx: MCP context
            provider_name: Provider to use (None for default)

        Returns:
            AuthResult: Authentication result
        """
        logger.info("[AuthConfigManager] Starting authentication")

        # Extract credentials
        credentials = self.extract_credentials(ctx)
        gateway_key = credentials.gateway_key or credentials.api_key
        project_id = credentials.project_id
        user_id = credentials.user_id

        # Validate credentials
        if not gateway_key:
            return AuthResult(
                status="error",
                authenticated=False,
                message="Gateway key is required",
                error="Missing gateway_key",
            )

        # Get local config to find mcp_config_id
        local_config = await self.get_local_mcp_config(gateway_key, project_id, user_id)
        if not local_config:
            return AuthResult(
                status="error",
                authenticated=False,
                message="No configuration found",
                error="Configuration not found",
            )

        mcp_config_id = local_config.get("mcp_config_id")
        if not mcp_config_id:
            return AuthResult(
                status="error",
                authenticated=False,
                message="No MCP config ID found",
                error="Missing mcp_config_id",
            )

        session_key = self.create_session_key(
            gateway_key, project_id, user_id, mcp_config_id
        )

        # Check if already authenticated in session
        if self.is_session_authenticated(session_key):
            logger.info("[AuthConfigManager] Already authenticated in session")
            session = self.sessions[session_key]
            # gateway_config may be None/empty on a session; avoid None.get().
            session_config = session.gateway_config or {}
            return AuthResult(
                status="success",
                authenticated=True,
                message="Already authenticated (session)",
                user_id=session.user_id,
                project_id=session.project_id,
                session_id=session_key,
                gateway_config=session.gateway_config,
                mcp_config=session_config.get("mcp_config", []),
                metadata={"source": "session"},
            )

        # Check cache; the local config's "id" field is the cache key.
        config_id = local_config.get("id")
        if config_id:
            cached_config = self.cache_service.get_cached_gateway_config(config_id)
            if cached_config:
                logger.info(
                    f"[AuthConfigManager] Found cached config for ID: {config_id}"
                )
                self.create_session(session_key, cached_config)
                return AuthResult(
                    status="success",
                    authenticated=True,
                    message="Authentication successful (cache)",
                    user_id=cached_config.get("user_id"),
                    project_id=cached_config.get("project_id"),
                    session_id=session_key,
                    gateway_config=cached_config,
                    mcp_config=cached_config.get("mcp_config", []),
                    metadata={"source": "cache"},
                )

        # Get provider and authenticate
        provider = self.get_provider(provider_name)
        if not provider:
            return AuthResult(
                status="error",
                authenticated=False,
                message=f"Provider '{provider_name or self.default_provider}' not found",
                error="Provider not registered",
            )

        result = await provider.authenticate(credentials)

        # Cache and create session if successful
        if result.is_success:
            if config_id and result.gateway_config:
                self.cache_service.cache_gateway_config(
                    config_id, result.gateway_config
                )

            # create_session tolerates a None gateway_config.
            self.create_session(session_key, result.gateway_config)

            # Update result with session info
            result.session_id = session_key
            result.metadata = result.metadata or {}
            result.metadata["session_created"] = True

        return result

    def _create_session(self, auth_result: AuthResult) -> SessionData:
        """
        Create a session from authentication result.

        Args:
            auth_result: Authentication result

        Returns:
            SessionData: Created session (not stored in ``self.sessions``)
        """
        session_id = auth_result.session_id or self._generate_session_id(auth_result)
        now = time.time()
        return SessionData(
            session_id=session_id,
            user_id=auth_result.user_id,
            project_id=auth_result.project_id,
            authenticated=True,
            created_at=now,
            last_accessed=now,
            gateway_config=auth_result.gateway_config,
            metadata=auth_result.metadata,
        )

    def _generate_session_id(self, auth_result: AuthResult) -> str:
        """
        Generate a unique, unpredictable session ID.

        Returns 64 hex characters from the OS CSPRNG. A hash of
        user/project/timestamp would be guessable and could collide for two
        sessions created at the same clock tick. ``auth_result`` is accepted
        for interface compatibility.
        """
        return secrets.token_hex(32)

    def get_session(self, session_id: str) -> Optional[SessionData]:
        """
        Get session data, refreshing its ``last_accessed`` timestamp.

        Args:
            session_id: Session ID

        Returns:
            Optional[SessionData]: Session data if exists
        """
        session = self.sessions.get(session_id)
        if session:
            session.last_accessed = time.time()
        return session

    def delete_session(self, session_id: str) -> bool:
        """
        Delete a session.

        Args:
            session_id: Session ID

        Returns:
            bool: True if deleted
        """
        if session_id in self.sessions:
            del self.sessions[session_id]
            return True
        return False

    def cleanup_expired_sessions(self, max_age_hours: int = 24) -> int:
        """
        Clean up sessions older than ``max_age_hours`` (measured from creation).

        Args:
            max_age_hours: Maximum session age in hours

        Returns:
            int: Number of sessions cleaned up
        """
        current_time = time.time()
        max_age_seconds = max_age_hours * 3600

        # Collect first, then delete: the dict must not change size mid-iteration.
        expired_keys = [
            session_id
            for session_id, session in self.sessions.items()
            if current_time - session.created_at > max_age_seconds
        ]
        for key in expired_keys:
            del self.sessions[key]

        return len(expired_keys)

    def get_session_stats(self) -> Dict[str, Any]:
        """
        Get session statistics.

        Returns:
            Dict[str, Any]: Session stats
        """
        total = len(self.sessions)
        authenticated = sum(1 for s in self.sessions.values() if s.authenticated)

        return {
            "total_sessions": total,
            "authenticated_sessions": authenticated,
            "unauthenticated_sessions": total - authenticated,
            "providers": self.list_providers(),
        }

    # ========================================================================
    # BACKWARD-COMPATIBLE METHODS (matching auth_service API)
    # ========================================================================

    def get_gateway_credentials(self, ctx: Context) -> Dict[str, str]:
        """
        Backward-compatible method matching auth_service.get_gateway_credentials()

        Returns dict with keys: gateway_key, project_id, user_id
        """
        creds = self.extract_credentials(ctx)
        return {
            "gateway_key": creds.gateway_key or creds.api_key,
            "project_id": creds.project_id,
            "user_id": creds.user_id,
        }

    async def get_local_mcp_config(
        self,
        gateway_key: str,
        project_id: Optional[str] = None,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Backward-compatible method matching auth_service.get_local_mcp_config()

        Delegates to EnkryptAuthProvider._get_local_config(). Returns an empty
        dict when no such provider/method is available or it returns nothing.
        """
        provider = self.get_provider("enkrypt")
        if not provider or not hasattr(provider, "_get_local_config"):
            return {}
        return await provider._get_local_config(gateway_key, project_id, user_id) or {}

    def create_session_key(
        self, gateway_key: str, project_id: str, user_id: str, mcp_config_id: str
    ) -> str:
        """
        Backward-compatible method for creating session keys.

        NOTE(review): the key embeds the raw gateway key and is returned to
        callers as ``AuthResult.session_id``; avoid logging or exposing it.
        """
        return f"{gateway_key}_{project_id}_{user_id}_{mcp_config_id}"

    def is_session_authenticated(self, session_key: str) -> bool:
        """
        Backward-compatible method for checking session authentication.
        """
        session = self.sessions.get(session_key)
        return session is not None and session.authenticated

    def create_session(self, session_key: str, gateway_config: Dict[str, Any]) -> None:
        """
        Backward-compatible session creation.

        Creates the session if absent, otherwise marks it authenticated and
        replaces its gateway config. A None ``gateway_config`` is stored as an
        empty dict.
        """
        config = gateway_config or {}
        now = time.time()
        session = self.sessions.get(session_key)
        if session is None:
            self.sessions[session_key] = SessionData(
                session_id=session_key,
                user_id=config.get("user_id", ""),
                project_id=config.get("project_id"),
                authenticated=True,
                created_at=now,
                last_accessed=now,
                gateway_config=config,
                metadata={},
            )
        else:
            # Update existing session
            session.authenticated = True
            session.gateway_config = config
            session.last_accessed = now

    async def _resolve_session_key(self, ctx: Context) -> Optional[str]:
        """
        Build the session key for ``ctx``, or return None if it cannot be built.

        Requires gateway key, project ID and user ID to all be present, and a
        local MCP config that carries an ``mcp_config_id``.
        """
        credentials = self.extract_credentials(ctx)
        gateway_key = credentials.gateway_key or credentials.api_key
        project_id = credentials.project_id
        user_id = credentials.user_id

        if not all([gateway_key, project_id, user_id]):
            return None

        local_config = await self.get_local_mcp_config(gateway_key, project_id, user_id)
        if not local_config:
            return None

        mcp_config_id = local_config.get("mcp_config_id")
        if not mcp_config_id:
            return None

        return self.create_session_key(gateway_key, project_id, user_id, mcp_config_id)

    async def is_authenticated(self, ctx: Context) -> bool:
        """
        Backward-compatible authentication check.
        """
        session_key = await self._resolve_session_key(ctx)
        return session_key is not None and self.is_session_authenticated(session_key)

    def require_authentication(self, ctx: Context) -> Tuple[bool, Dict[str, Any]]:
        """
        Backward-compatible authentication requirement check.

        Synchronous wrapper: runs the session check and, if needed, a full
        :meth:`authenticate` on a fresh event loop via ``asyncio.run``. It
        therefore cannot be called from inside a running event loop
        (``asyncio.run`` raises ``RuntimeError``); async callers should
        ``await self.authenticate(ctx)`` instead.

        Returns:
            Tuple[bool, Dict]: (is_authenticated, auth_result)
        """
        return asyncio.run(self._require_authentication_async(ctx))

    async def _require_authentication_async(
        self, ctx: Context
    ) -> Tuple[bool, Dict[str, Any]]:
        """Async body of :meth:`require_authentication`."""
        # is_authenticated() is a coroutine function and must be awaited: an
        # un-awaited coroutine object is always truthy.
        if await self.is_authenticated(ctx):
            return True, {"status": "success", "message": "Already authenticated"}

        auth_result = await self.authenticate(ctx)
        # authenticate() builds results with plain string statuses, while other
        # code uses the AuthStatus enum; accept either form.
        status = getattr(auth_result.status, "value", auth_result.status)
        return auth_result.is_success, {
            "status": status,
            "message": auth_result.message,
            "error": auth_result.error,
        }

    async def get_authenticated_session(self, ctx: Context) -> Optional[SessionData]:
        """
        Backward-compatible authenticated session retrieval.
        """
        session_key = await self._resolve_session_key(ctx)
        if session_key is None:
            return None
        return self.get_session(session_key)

    async def clear_session(self, ctx: Context) -> bool:
        """
        Backward-compatible session clearing.
        """
        session_key = await self._resolve_session_key(ctx)
        if session_key is None:
            return False
        return self.delete_session(session_key)

    async def get_session_gateway_config_key_suffix(
        self, credentials: Dict[str, Any]
    ) -> str:
        """
        Backward-compatible config key suffix extraction.

        Best-effort: returns the local config's ``mcp_config_id``, or
        ``"not_provided"`` when it is missing or the lookup fails.
        """
        try:
            local_cfg = await self.get_local_mcp_config(
                credentials.get("gateway_key"),
                credentials.get("project_id"),
                credentials.get("user_id"),
            )
            if not local_cfg:
                return "not_provided"
            return local_cfg.get("mcp_config_id") or "not_provided"
        except Exception as exc:
            # Deliberate fallback, but leave a trace instead of swallowing silently.
            logger.warning(
                f"[AuthConfigManager] Could not resolve mcp_config_id: {exc}"
            )
            return "not_provided"

    def get_session_gateway_config(self, session_key: str) -> Dict[str, Any]:
        """
        Backward-compatible gateway config retrieval from session.

        Raises:
            ValueError: If the session is missing, unauthenticated, or has no
                gateway configuration.
        """
        session = self.get_session(session_key)
        if not session:
            raise ValueError(f"Session {session_key} not found")
        if not session.authenticated:
            raise ValueError(f"Session {session_key} not authenticated")
        if not session.gateway_config:
            raise ValueError(f"Session {session_key} has no gateway configuration")
        return session.gateway_config


# ============================================================================
# Response Format Conversion Utilities
# ============================================================================


def convert_auth_result_to_legacy_format(auth_result: AuthResult) -> Dict[str, Any]:
    """
    Convert new AuthResult to legacy dict format for backward compatibility.

    Args:
        auth_result: New format AuthResult

    Returns:
        Dict in legacy format. ``available_servers`` maps ``server_name`` to
        the server entry; entries without a ``server_name`` are left out of
        that map (but kept in ``mcp_config``).
    """
    if not auth_result.is_success:
        return {
            "status": "error",
            "message": auth_result.message,
            "error": auth_result.error or auth_result.message,
        }

    mcp_config = auth_result.mcp_config or []
    return {
        "status": "success",
        "message": auth_result.message,
        "id": auth_result.session_id,
        "mcp_config": mcp_config,
        "available_servers": {
            server["server_name"]: server
            for server in mcp_config
            if "server_name" in server
        },
        "gateway_config": auth_result.gateway_config,
    }


def convert_legacy_format_to_auth_result(legacy_result: Dict[str, Any]) -> AuthResult:
    """
    Convert legacy dict format to new AuthResult format.

    Args:
        legacy_result: Legacy format dict

    Returns:
        AuthResult object
    """
    from secure_mcp_gateway.plugins.auth.base import AuthStatus

    if legacy_result.get("status") != "success":
        return AuthResult(
            status=AuthStatus.ERROR,
            authenticated=False,
            message=legacy_result.get("message", "Authentication failed"),
            error=legacy_result.get("error"),
        )

    # The key may be present with a None value; normalise before .get().
    gateway_config = legacy_result.get("gateway_config") or {}
    return AuthResult(
        status=AuthStatus.SUCCESS,
        authenticated=True,
        message=legacy_result.get("message", "Authentication successful"),
        user_id=gateway_config.get("user_id"),
        project_id=gateway_config.get("project_id"),
        session_id=legacy_result.get("id"),
        gateway_config=gateway_config,
        mcp_config=legacy_result.get("mcp_config") or [],
        metadata={"source": "legacy_conversion"},
    )


# ============================================================================
# Global Instance
# ============================================================================

_auth_config_manager: Optional[AuthConfigManager] = None


def get_auth_config_manager() -> AuthConfigManager:
    """
    Get or create the global AuthConfigManager instance.

    Returns:
        AuthConfigManager: Global instance
    """
    global _auth_config_manager
    if _auth_config_manager is None:
        _auth_config_manager = AuthConfigManager()
    return _auth_config_manager


def initialize_auth_system(config: Optional[Dict[str, Any]] = None) -> AuthConfigManager:
    """
    Initialize the authentication system with providers.

    Args:
        config: Configuration dict containing auth settings. When None, the
            global manager is returned without loading any providers.

    Returns:
        AuthConfigManager: Initialized manager
    """
    manager = get_auth_config_manager()

    if config is None:
        return manager

    # Use the new centralized plugin loader with fallback mechanism
    from secure_mcp_gateway.plugins.plugin_loader import PluginLoader

    PluginLoader.load_plugin_providers(config, "auth", manager)

    return manager

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/enkryptai/secure-mcp-gateway'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.