"""Base authentication framework."""
import functools
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Callable
class Permission(Enum):
    """Permission types that can be granted to an AuthContext.

    Each member's value is its name in lowercase. READ_TOOLS,
    READ_RESOURCES and READ_PROMPTS are the permissions AuthManager hands
    to anonymous contexts.
    """
    READ_TOOLS = "read_tools"
    CALL_TOOLS = "call_tools"
    READ_RESOURCES = "read_resources"
    ACCESS_RESOURCES = "access_resources"
    READ_PROMPTS = "read_prompts"
    GET_PROMPTS = "get_prompts"
    # AuthContext.has_permission treats ADMIN as a wildcard that satisfies
    # every permission check.
    ADMIN = "admin"
@dataclass
class AuthContext:
    """Identity and granted permissions of a caller.

    Holding ``Permission.ADMIN`` makes every permission check succeed.
    """

    user_id: Optional[str] = None
    client_id: Optional[str] = None
    permissions: Set[Permission] = field(default_factory=set)
    metadata: Dict[str, Any] = field(default_factory=dict)
    authenticated: bool = False

    def __post_init__(self) -> None:
        # Callers may pass None explicitly for the container fields; swap in
        # fresh empty containers so later membership checks never see None.
        for attr_name, make_empty in (("permissions", set), ("metadata", dict)):
            if getattr(self, attr_name) is None:
                setattr(self, attr_name, make_empty())

    def has_permission(self, permission: Permission) -> bool:
        """Report whether a single permission is granted.

        Args:
            permission: Permission to look up

        Returns:
            True if the permission, or ADMIN, is in this context's permissions
        """
        granted = self.permissions
        return Permission.ADMIN in granted or permission in granted

    def has_any_permission(self, permissions: List[Permission]) -> bool:
        """Report whether at least one of the given permissions is granted.

        Args:
            permissions: Candidate permissions

        Returns:
            True as soon as one candidate passes has_permission; False for
            an empty list
        """
        for candidate in permissions:
            if self.has_permission(candidate):
                return True
        return False

    def has_all_permissions(self, permissions: List[Permission]) -> bool:
        """Report whether every one of the given permissions is granted.

        Args:
            permissions: Required permissions

        Returns:
            False as soon as one requirement fails has_permission; True for
            an empty list
        """
        for candidate in permissions:
            if not self.has_permission(candidate):
                return False
        return True
class BaseAuthProvider(ABC):
    """Abstract interface implemented by every authentication backend.

    Subclasses must implement ``authenticate`` and ``authorize``. The
    ``extract_credentials`` and ``validate_token`` hooks have defaults:
    an empty credentials dict, and an unauthenticated AuthContext with no
    permissions.
    """

    def __init__(self, name: str) -> None:
        """Create a provider.

        Args:
            name: Identifier for this provider
        """
        self.name = name
        # Every provider starts out enabled.
        self.enabled = True

    @abstractmethod
    async def authenticate(self, credentials: Dict[str, Any]) -> AuthContext:
        """Verify credentials and build the caller's context.

        Args:
            credentials: Provider-specific credential mapping

        Returns:
            Context describing the authenticated caller

        Raises:
            AuthenticationError: If the credentials are rejected
        """
        ...

    @abstractmethod
    async def authorize(self, context: AuthContext, resource: str, action: str) -> bool:
        """Decide whether a context may perform an action on a resource.

        Args:
            context: Caller's authentication context
            resource: Name of the resource being accessed
            action: Action being attempted

        Returns:
            True if access is allowed
        """
        ...

    def extract_credentials(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
        """Pull credentials out of raw request data.

        The base implementation extracts nothing.

        Args:
            request_data: Incoming request payload

        Returns:
            An empty credentials dict
        """
        return dict()

    def validate_token(self, token: str) -> AuthContext:
        """Turn a token into an authentication context.

        The base implementation ignores the token.

        Args:
            token: Token to check

        Returns:
            A default AuthContext (not authenticated, no permissions)
        """
        return AuthContext()
class AuthenticationError(Exception):
    """Raised when authentication fails, e.g. when no enabled provider is available."""
class AuthorizationError(Exception):
    """Raised when an auth context lacks permissions required for an operation."""
class AuthManager:
    """Registry of authentication providers with one designated default.

    Requests are routed to a named provider, or to the default provider when
    no name is given. Disabled providers are refused for both authentication
    and authorization.
    """

    def __init__(self) -> None:
        """Initialize an empty manager with no providers and no default."""
        self.providers: Dict[str, BaseAuthProvider] = {}
        self.default_provider: Optional[str] = None
        # Template for unauthenticated callers: read-only access to tools,
        # resources and prompts. get_anonymous_context() hands out copies so
        # this shared instance is never mutated by callers.
        self._anonymous_context = AuthContext(
            user_id="anonymous",
            authenticated=False,
            permissions={Permission.READ_TOOLS, Permission.READ_RESOURCES, Permission.READ_PROMPTS},
        )

    def add_provider(self, provider: BaseAuthProvider, is_default: bool = False) -> None:
        """Register an authentication provider.

        A provider registered under an existing name replaces the old one.
        When no default is set yet, the new provider becomes the default.

        Args:
            provider: Auth provider to add, keyed by ``provider.name``
            is_default: Make this provider the default even if one is already set
        """
        self.providers[provider.name] = provider
        # Compare against None (not truthiness) so a provider named "" can
        # also serve as the default.
        if is_default or self.default_provider is None:
            self.default_provider = provider.name

    def remove_provider(self, provider_name: str) -> None:
        """Unregister an authentication provider; unknown names are ignored.

        Args:
            provider_name: Name of provider to remove
        """
        if provider_name not in self.providers:
            return
        del self.providers[provider_name]
        if self.default_provider == provider_name:
            # Fall back to the first remaining provider in dict insertion
            # order, or to no default when none remain.
            self.default_provider = next(iter(self.providers), None)

    def get_provider(self, provider_name: Optional[str] = None) -> Optional[BaseAuthProvider]:
        """Look up an authentication provider.

        Args:
            provider_name: Provider name; None or "" selects the default provider

        Returns:
            The matching provider, or None if it is not registered or no
            default is set
        """
        if provider_name:
            return self.providers.get(provider_name)
        if self.default_provider is None:
            return None
        return self.providers.get(self.default_provider)

    async def authenticate(self, credentials: Dict[str, Any], provider_name: Optional[str] = None) -> AuthContext:
        """Authenticate credentials with the selected provider.

        Args:
            credentials: Authentication credentials
            provider_name: Specific provider to use (default provider if None)

        Returns:
            Authentication context produced by the provider

        Raises:
            AuthenticationError: If no provider is available, the provider is
                disabled, or the provider itself rejects the credentials
        """
        provider = self.get_provider(provider_name)
        if provider is None:
            raise AuthenticationError("No authentication provider available")
        if not provider.enabled:
            raise AuthenticationError(f"Authentication provider {provider.name} is disabled")
        return await provider.authenticate(credentials)

    async def authorize(
        self, context: AuthContext, resource: str, action: str, provider_name: Optional[str] = None
    ) -> bool:
        """Decide whether a context may perform an action on a resource.

        Args:
            context: Authentication context
            resource: Resource name
            action: Action being performed
            provider_name: Specific provider to use (default provider if None)

        Returns:
            False for unauthenticated contexts, missing providers or disabled
            providers; otherwise the provider's decision
        """
        if not context.authenticated:
            return False
        provider = self.get_provider(provider_name)
        # A disabled provider must not grant access, mirroring authenticate(),
        # which refuses to use disabled providers.
        if provider is None or not provider.enabled:
            return False
        return await provider.authorize(context, resource, action)

    def get_anonymous_context(self) -> AuthContext:
        """Build a fresh context for anonymous access.

        Returns:
            Unauthenticated context carrying a copy of the anonymous permissions
        """
        return AuthContext(
            user_id="anonymous", authenticated=False, permissions=set(self._anonymous_context.permissions)
        )

    def require_permissions(self, *permissions: Permission) -> Callable:
        """Build a decorator that enforces permissions on an async function.

        The auth context is the first positional ``AuthContext`` argument, else
        the ``auth_context`` keyword argument if it is an ``AuthContext``,
        else a fresh anonymous context.

        Args:
            permissions: Required permissions (all must be granted)

        Returns:
            Decorator for coroutine functions; the wrapped function raises
            AuthorizationError listing the missing permission values
        """
        required = list(permissions)

        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                context = next((arg for arg in args if isinstance(arg, AuthContext)), None)
                if context is None:
                    candidate = kwargs.get("auth_context")
                    # Anything that is not an AuthContext (including an explicit
                    # None) falls back to least-privilege anonymous access
                    # instead of crashing on a missing method.
                    if isinstance(candidate, AuthContext):
                        context = candidate
                    else:
                        context = self.get_anonymous_context()
                missing = [p for p in required if not context.has_permission(p)]
                if missing:
                    raise AuthorizationError(f"Missing permissions: {[p.value for p in missing]}")
                return await func(*args, **kwargs)

            return wrapper

        return decorator