Skip to main content
Glama
middleware.py (8.65 kB)
"""Authentication middleware for MCP tool calls.

Provides decorators and utilities to enforce authentication before tool
execution.

Supports dual-mode authentication:
- LOCAL mode: OAuth client flow with browser SSO
- CLOUD mode: Resource Server with Bearer token validation
"""

from __future__ import annotations

import contextvars
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any, TypeVar

from mcp.shared.exceptions import McpError
from mcp.types import ErrorData

from sso_mcp_server import get_logger
from sso_mcp_server.auth.exceptions import (
    AuthNotConfiguredError,
    CloudAuthError,
    MissingAuthorizationError,
    NotAuthenticatedError,
)
from sso_mcp_server.config import AuthMode, get_settings

if TYPE_CHECKING:
    from sso_mcp_server.auth.cloud import TokenClaims, TokenValidator
    from sso_mcp_server.auth.manager import AuthManager

_logger = get_logger("auth_middleware")

# Type variable for generic function signatures
F = TypeVar("F", bound=Callable[..., Any])

# Global auth manager reference (set during server initialization) - LOCAL mode
_auth_manager: AuthManager | None = None

# Global token validator reference (set during server initialization) - CLOUD mode
_token_validator: TokenValidator | None = None

# Context variable for storing current request's token claims (CLOUD mode)
_current_claims: contextvars.ContextVar[TokenClaims | None] = contextvars.ContextVar(
    "current_claims", default=None
)

# Context variable for storing current request's authorization header
_current_auth_header: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_auth_header", default=None
)


def set_auth_manager(manager: AuthManager) -> None:
    """Set the global auth manager instance (LOCAL mode).

    Called during server initialization to provide the auth manager
    to the middleware.

    Args:
        manager: The AuthManager instance to use.
    """
    global _auth_manager  # noqa: PLW0603
    _auth_manager = manager
    _logger.debug("auth_manager_set")


def get_auth_manager() -> AuthManager | None:
    """Get the global auth manager instance.

    Returns:
        The AuthManager instance, or None if not set.
    """
    return _auth_manager


def set_token_validator(validator: TokenValidator) -> None:
    """Set the global token validator instance (CLOUD mode).

    Called during server initialization to provide the token validator
    to the middleware.

    Args:
        validator: The TokenValidator instance to use.
    """
    global _token_validator  # noqa: PLW0603
    _token_validator = validator
    _logger.debug("token_validator_set")


def get_token_validator() -> TokenValidator | None:
    """Get the global token validator instance.

    Returns:
        The TokenValidator instance, or None if not set.
    """
    return _token_validator


def set_authorization_header(auth_header: str | None) -> None:
    """Set the current request's authorization header.

    Called by the server when processing a request to make the
    Authorization header available to the middleware. The value is stored
    in a context variable, so it is visible only within the current context.

    Args:
        auth_header: The Authorization header value.
    """
    _current_auth_header.set(auth_header)


def get_current_claims() -> TokenClaims | None:
    """Get the current request's validated token claims.

    After successful authentication in CLOUD mode, the validated claims
    are stored in context for the duration of the wrapped tool call and
    can be accessed by tools.

    Returns:
        TokenClaims if authenticated in CLOUD mode, None otherwise.
    """
    return _current_claims.get()


def _extract_bearer_token(auth_header: str | None) -> str | None:
    """Extract Bearer token from Authorization header.

    The header must have the form ``"<scheme> <token>"`` with a
    case-insensitive ``bearer`` scheme and a non-blank token.

    Args:
        auth_header: The Authorization header value.

    Returns:
        The token string (surrounding whitespace stripped), or None if the
        header is missing, malformed, not a Bearer scheme, or the token is
        blank.
    """
    if not auth_header:
        return None

    # Split only on the first space: everything after the scheme is the token.
    parts = auth_header.split(" ", 1)
    if len(parts) != 2:
        return None

    scheme, token = parts
    if scheme.lower() != "bearer":
        return None

    stripped = token.strip()
    return stripped if stripped else None


async def _local_auth_flow(func: Callable, *args: Any, **kwargs: Any) -> Any:
    """Execute local authentication flow.

    Uses the AuthManager to ensure authentication via browser SSO, then
    awaits the wrapped function.

    Args:
        func: The async tool function to call once authenticated.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        McpError: With code -32001 if no auth manager is configured, or
            -32002 if authentication does not succeed.
    """
    if _auth_manager is None:
        error = AuthNotConfiguredError()
        _logger.error(
            "auth_manager_not_configured",
            error_code=error.code,
            action=error.action,
        )
        raise McpError(ErrorData(code=-32001, message=str(error)))

    # Ensure authenticated (will trigger auth flow on first call if needed).
    # NOTE(review): this is a synchronous call inside a coroutine; if it
    # blocks while waiting for the browser flow it will block the event
    # loop - confirm against AuthManager.
    if not _auth_manager.ensure_authenticated():
        error = NotAuthenticatedError()
        _logger.warning(
            "authentication_failed",
            error_code=error.code,
            action=error.action,
        )
        raise McpError(ErrorData(code=-32002, message=str(error)))

    _logger.debug("local_auth_passed", function=func.__name__)
    return await func(*args, **kwargs)


async def _cloud_auth_flow(func: Callable, *args: Any, **kwargs: Any) -> Any:
    """Execute cloud authentication flow.

    Validates the incoming Bearer token using the TokenValidator, exposes
    the validated claims through the claims context variable while the
    wrapped function runs, and restores the previous claims afterwards.

    Args:
        func: The async tool function to call once authenticated.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        McpError: With code -32001 if no token validator is configured, or
            -32002 if the Bearer token is missing or a CloudAuthError is
            raised during validation or by the wrapped function.
    """
    if _token_validator is None:
        error = AuthNotConfiguredError()
        _logger.error(
            "token_validator_not_configured",
            error_code=error.code,
            action=error.action,
        )
        raise McpError(ErrorData(code=-32001, message=str(error)))

    # Get Authorization header from context
    auth_header = _current_auth_header.get()

    # Extract Bearer token
    token = _extract_bearer_token(auth_header)
    if not token:
        error = MissingAuthorizationError()
        _logger.warning(
            "missing_bearer_token",
            error_code=error.code,
        )
        raise McpError(
            ErrorData(
                code=-32002,
                message=str(error),
            )
        )

    # Handle returned by ContextVar.set(); stays None when validation fails
    # before the claims are stored, in which case there is nothing to undo.
    claims_token = None

    # Validate token
    try:
        claims = await _token_validator.validate(token)
        claims_token = _current_claims.set(claims)
        _logger.debug(
            "cloud_auth_passed",
            function=func.__name__,
            sub=claims.sub,
        )
        return await func(*args, **kwargs)
    except CloudAuthError as e:
        _logger.warning(
            "cloud_auth_failed",
            error_code=e.code,
            error=e.error,
        )
        raise McpError(
            ErrorData(
                code=-32002,
                message=str(e),
            )
        ) from e
    finally:
        # Restore the previous claims instead of forcing None, so a nested
        # authenticated call does not wipe the claims of the call that
        # invoked it. For a top-level call the previous value is the
        # default (None), i.e. the claims are cleared after the request.
        if claims_token is not None:
            _current_claims.reset(claims_token)


def require_auth(func: F) -> F:
    """Decorator to require authentication for a tool function.

    Routes to the appropriate authentication flow based on AUTH_MODE:
    - LOCAL: Uses AuthManager for browser SSO
    - CLOUD: Validates incoming Bearer token
    - AUTO: Attempts CLOUD if a Bearer Authorization header is present,
      else LOCAL

    Args:
        func: The async tool function to wrap.

    Returns:
        Wrapped function that enforces authentication.

    Raises:
        McpError: If authentication fails or is not configured.
    """

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Settings are read on every call rather than at decoration time.
        settings = get_settings()

        if settings.auth_mode == AuthMode.LOCAL:
            return await _local_auth_flow(func, *args, **kwargs)
        elif settings.auth_mode == AuthMode.CLOUD:
            return await _cloud_auth_flow(func, *args, **kwargs)
        else:
            # AUTO mode
            # Check if Authorization header is present
            auth_header = _current_auth_header.get()
            if auth_header and auth_header.lower().startswith("bearer "):
                return await _cloud_auth_flow(func, *args, **kwargs)
            else:
                return await _local_auth_flow(func, *args, **kwargs)

    return wrapper  # type: ignore[return-value]


def check_auth() -> bool:
    """Check if currently authenticated (LOCAL mode only).

    Utility function for manual auth checks. Does not trigger an
    authentication flow.

    Returns:
        True if authenticated, False otherwise (including when no auth
        manager is configured).
    """
    if _auth_manager is None:
        return False
    return _auth_manager.is_authenticated()


def ensure_authenticated() -> bool:
    """Ensure authentication, triggering flow if needed (LOCAL mode only).

    Returns:
        True if authenticated after the check; False if no auth manager is
        configured or authentication did not succeed.
    """
    if _auth_manager is None:
        _logger.error("auth_manager_not_configured")
        return False
    return _auth_manager.ensure_authenticated()

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/DauQuangThanh/sso-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.