"""Authentication middleware for MCP tool calls.
Provides decorators and utilities to enforce authentication
before tool execution. Supports three authentication modes:
- LOCAL mode: OAuth client flow with browser SSO
- CLOUD mode: Resource Server with Bearer token validation
- AUTO mode: CLOUD when the request carries a Bearer Authorization header, otherwise LOCAL
"""
from __future__ import annotations
import contextvars
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any, TypeVar
from mcp.shared.exceptions import McpError
from mcp.types import ErrorData
from sso_mcp_server import get_logger
from sso_mcp_server.auth.exceptions import (
AuthNotConfiguredError,
CloudAuthError,
MissingAuthorizationError,
NotAuthenticatedError,
)
from sso_mcp_server.config import AuthMode, get_settings
if TYPE_CHECKING:
from sso_mcp_server.auth.cloud import TokenClaims, TokenValidator
from sso_mcp_server.auth.manager import AuthManager
_logger = get_logger("auth_middleware")

# Generic callable type; lets ``require_auth`` return the same type it receives.
F = TypeVar("F", bound=Callable[..., Any])

# --- Collaborators injected by the server during initialization ---
# LOCAL mode: manager that drives the browser SSO flow.
_auth_manager: AuthManager | None = None
# CLOUD mode: validator for incoming Bearer tokens.
_token_validator: TokenValidator | None = None

# --- Per-request state, kept in context variables ---
# Raw Authorization header of the request currently being handled.
_current_auth_header: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_auth_header", default=None
)
# Validated token claims of the request currently being handled (CLOUD mode).
_current_claims: contextvars.ContextVar[TokenClaims | None] = contextvars.ContextVar(
    "current_claims", default=None
)
def set_auth_manager(manager: AuthManager) -> None:
    """Install the module-wide AuthManager used by LOCAL-mode authentication.

    The server calls this during initialization so the middleware can reach
    the manager without it being passed to every tool.

    Args:
        manager: AuthManager instance to install; replaces any previous one.
    """
    global _auth_manager  # noqa: PLW0603
    _auth_manager = manager
    _logger.debug("auth_manager_set")
def get_auth_manager() -> AuthManager | None:
    """Return the installed LOCAL-mode AuthManager.

    Returns:
        The instance registered via ``set_auth_manager``, or None if no
        manager has been registered yet.
    """
    manager = _auth_manager
    return manager
def set_token_validator(validator: TokenValidator) -> None:
    """Install the module-wide TokenValidator used by CLOUD-mode authentication.

    The server calls this during initialization so the middleware can
    validate incoming Bearer tokens.

    Args:
        validator: TokenValidator instance to install; replaces any previous one.
    """
    global _token_validator  # noqa: PLW0603
    _token_validator = validator
    _logger.debug("token_validator_set")
def get_token_validator() -> TokenValidator | None:
    """Return the installed CLOUD-mode TokenValidator.

    Returns:
        The instance registered via ``set_token_validator``, or None if no
        validator has been registered yet.
    """
    validator = _token_validator
    return validator
def set_authorization_header(auth_header: str | None) -> None:
    """Record the Authorization header of the request being processed.

    The server calls this while handling a request. ``require_auth`` and the
    CLOUD flow can then read the header from the current context.

    Args:
        auth_header: Raw Authorization header value, or None when absent.
    """
    # The Token returned by ContextVar.set() is not kept, so the value stays in
    # effect for the rest of the current context.
    _current_auth_header.set(auth_header)
def get_current_claims() -> TokenClaims | None:
    """Return the validated token claims for the current request.

    In CLOUD mode the middleware publishes the claims in a context variable
    after a successful token validation, so tools can inspect them while
    they run.

    Returns:
        The current TokenClaims when running inside a CLOUD-authenticated
        call, otherwise None.
    """
    claims = _current_claims.get()
    return claims
def _extract_bearer_token(auth_header: str | None) -> str | None:
"""Extract Bearer token from Authorization header.
Args:
auth_header: The Authorization header value.
Returns:
The token string, or None if not a valid Bearer token.
"""
if not auth_header:
return None
parts = auth_header.split(" ", 1)
if len(parts) != 2:
return None
scheme, token = parts
if scheme.lower() != "bearer":
return None
return token.strip() if token.strip() else None
async def _local_auth_flow(func: Callable, *args: Any, **kwargs: Any) -> Any:
    """Await ``func`` once LOCAL-mode (browser SSO) authentication succeeds.

    Args:
        func: The protected coroutine function.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The result of awaiting ``func(*args, **kwargs)``.

    Raises:
        McpError: code -32001 when no AuthManager has been registered, or
            code -32002 when the manager reports authentication failed.
    """
    if _auth_manager is None:
        not_configured = AuthNotConfiguredError()
        _logger.error(
            "auth_manager_not_configured",
            error_code=not_configured.code,
            action=not_configured.action,
        )
        raise McpError(ErrorData(code=-32001, message=str(not_configured)))

    # May start the interactive auth flow on first use if needed.
    # NOTE(review): this is a synchronous call inside a coroutine; if it blocks
    # while waiting on the browser it will stall the event loop. Confirm.
    if _auth_manager.ensure_authenticated():
        _logger.debug("local_auth_passed", function=func.__name__)
        return await func(*args, **kwargs)

    not_authenticated = NotAuthenticatedError()
    _logger.warning(
        "authentication_failed",
        error_code=not_authenticated.code,
        action=not_authenticated.action,
    )
    raise McpError(ErrorData(code=-32002, message=str(not_authenticated)))
async def _cloud_auth_flow(func: Callable, *args: Any, **kwargs: Any) -> Any:
    """Await ``func`` once the request's Bearer token is validated (CLOUD mode).

    The Authorization header recorded via ``set_authorization_header`` is
    parsed, and its Bearer token is checked with the registered
    TokenValidator. The resulting claims are exposed through
    ``get_current_claims()`` while ``func`` runs.

    Args:
        func: The protected coroutine function.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The result of awaiting ``func(*args, **kwargs)``.

    Raises:
        McpError: code -32001 when no TokenValidator has been registered;
            code -32002 when the Bearer token is missing or fails validation.
            Exceptions raised by ``func`` itself propagate unchanged.
    """
    if _token_validator is None:
        error = AuthNotConfiguredError()
        _logger.error(
            "token_validator_not_configured",
            error_code=error.code,
            action=error.action,
        )
        raise McpError(ErrorData(code=-32001, message=str(error)))

    token = _extract_bearer_token(_current_auth_header.get())
    if not token:
        error = MissingAuthorizationError()
        _logger.warning(
            "missing_bearer_token",
            error_code=error.code,
        )
        raise McpError(
            ErrorData(
                code=-32002,
                message=str(error),
            )
        )

    # Guard only the validation step. A CloudAuthError raised by the tool
    # itself must not be misreported as an authentication failure.
    try:
        claims = await _token_validator.validate(token)
    except CloudAuthError as e:
        _logger.warning(
            "cloud_auth_failed",
            error_code=e.code,
            error=e.error,
        )
        raise McpError(
            ErrorData(
                code=-32002,
                message=str(e),
            )
        ) from e

    # Publish the claims for the duration of the tool call. reset() restores
    # the previous value instead of forcing None, so a nested protected call
    # does not wipe the claims of the call that invoked it.
    claims_token = _current_claims.set(claims)
    try:
        _logger.debug(
            "cloud_auth_passed",
            function=func.__name__,
            sub=claims.sub,
        )
        return await func(*args, **kwargs)
    finally:
        _current_claims.reset(claims_token)
def require_auth(func: F) -> F:
    """Wrap a tool coroutine so it only runs after authentication succeeds.

    The flow is chosen per call from the configured ``auth_mode``:

    - LOCAL: browser SSO via the registered AuthManager.
    - CLOUD: validation of the incoming Bearer token.
    - AUTO (any other mode): CLOUD when the Authorization header starts with
      ``"bearer "`` (case-insensitive), otherwise LOCAL.

    Args:
        func: The tool coroutine function to protect.

    Returns:
        An async wrapper with ``func``'s metadata copied via ``functools.wraps``.

    Raises:
        McpError: Raised by the selected flow if authentication fails or is
            not configured.
    """

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        mode = get_settings().auth_mode
        if mode == AuthMode.LOCAL:
            flow = _local_auth_flow
        elif mode == AuthMode.CLOUD:
            flow = _cloud_auth_flow
        else:
            # AUTO: prefer CLOUD only when the request actually carries a
            # Bearer Authorization header.
            header = _current_auth_header.get() or ""
            flow = (
                _cloud_auth_flow
                if header.lower().startswith("bearer ")
                else _local_auth_flow
            )
        return await flow(func, *args, **kwargs)

    return wrapper  # type: ignore[return-value]
def check_auth() -> bool:
    """Report whether a LOCAL-mode session is currently authenticated.

    This is a passive check for manual use. Unlike ``ensure_authenticated``,
    it does not go through the manager's ``ensure_authenticated()``.

    Returns:
        False when no AuthManager is registered; otherwise the result of the
        manager's ``is_authenticated()``.
    """
    return _auth_manager is not None and _auth_manager.is_authenticated()
def ensure_authenticated() -> bool:
    """Make sure a LOCAL-mode session exists, starting the flow if needed.

    Returns:
        The result of the manager's ``ensure_authenticated()``, or False
        (after logging an error) when no AuthManager is registered.
    """
    if _auth_manager is not None:
        return _auth_manager.ensure_authenticated()
    _logger.error("auth_manager_not_configured")
    return False