Skip to main content
Glama
auth.py • 12.7 kB
#!/usr/bin/env python3
"""
Authentication module for MCP SSE Server
Implements token-based authentication with API key support
"""

import os
import secrets
import hashlib
import hmac
from typing import Optional, Dict, Any, List
from datetime import datetime, timezone, timedelta
from dataclasses import dataclass
import json
import base64

from jose import jwt
from jose.exceptions import JWTError, ExpiredSignatureError
from pydantic import BaseModel, Field
import httpx
from fastapi import HTTPException, Header


def _as_utc(moment: datetime) -> datetime:
    """Return *moment* as a timezone-aware datetime.

    Naive values are interpreted as UTC, which is the timezone every
    timestamp created in this module uses. Comparing a naive datetime with
    ``datetime.now(timezone.utc)`` raises ``TypeError``, so expiry checks
    normalize through this helper first.
    """
    if moment.utcoffset() is None:
        return moment.replace(tzinfo=timezone.utc)
    return moment


@dataclass
class AuthToken:
    """Authentication token data."""

    token: str
    api_key_id: str
    created_at: datetime
    expires_at: Optional[datetime] = None
    scopes: Optional[List[str]] = None
    metadata: Optional[Dict[str, Any]] = None

    def is_valid(self) -> bool:
        """Check if token is still valid.

        Returns:
            True when the token has no expiry or the expiry is in the future.
        """
        if self.expires_at:
            return datetime.now(timezone.utc) < _as_utc(self.expires_at)
        return True

    def has_scope(self, scope: str) -> bool:
        """Check if token has specific scope.

        NOTE(review): an empty or missing scope list grants every scope,
        including "admin" -- confirm this is the intended policy.
        """
        if not self.scopes:
            return True  # No scopes means full access
        return scope in self.scopes


class APIKey(BaseModel):
    """API Key model."""

    id: str = Field(description="Unique key identifier")
    name: str = Field(description="Human-readable name")
    key_hash: str = Field(description="Hashed API key")
    created_at: datetime = Field(description="Creation timestamp")
    expires_at: Optional[datetime] = Field(default=None, description="Expiration timestamp")
    scopes: List[str] = Field(default_factory=list, description="Allowed scopes")
    active: bool = Field(default=True, description="Whether key is active")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")


class AuthManager:
    """Manage authentication for MCP server.

    API keys and session tokens are held in in-memory dictionaries on the
    instance; nothing in this class persists them.
    """

    def __init__(self):
        # NOTE(review): when AUTH_SECRET_KEY is unset a fresh random secret is
        # generated per instance, so key hashes supplied through API_KEYS
        # (which depend on the secret) will not match.
        self.secret_key = os.getenv("AUTH_SECRET_KEY", secrets.token_urlsafe(32))
        self.token_lifetime = int(os.getenv("TOKEN_LIFETIME_HOURS", "24"))
        self.api_keys: Dict[str, APIKey] = {}
        self.tokens: Dict[str, AuthToken] = {}

        # Load API keys from environment or config
        self._load_api_keys()

    def _load_api_keys(self):
        """Load API keys from configuration.

        Reads the ``API_KEYS`` environment variable (a JSON list of APIKey
        field dictionaries). If no key ends up registered, a default key is
        created.
        """
        # Load from environment variable (JSON format)
        api_keys_json = os.getenv("API_KEYS")
        if api_keys_json:
            try:
                keys_data = json.loads(api_keys_json)
                for key_data in keys_data:
                    api_key = APIKey(**key_data)
                    self.api_keys[api_key.id] = api_key
            except Exception as e:
                # Best effort: a malformed configuration must not stop startup.
                print(f"Error loading API keys: {e}")

        # Create default API key if none exist
        if not self.api_keys:
            # NOTE(review): the plaintext key returned here is discarded and
            # only its hash is stored, so this default key can never be
            # presented by a client -- confirm how it is meant to be delivered.
            self.create_api_key(
                name="Default API Key",
                scopes=["wazuh:read", "wazuh:write"]
            )
            # Only log that a key was created, not the actual key value
            print("Created default API key - save this securely for client authentication")

    def hash_api_key(self, api_key: str) -> str:
        """Hash an API key using HMAC-SHA256.

        Returns:
            Hex digest keyed with ``self.secret_key``.
        """
        return hmac.new(
            self.secret_key.encode(),
            api_key.encode(),
            hashlib.sha256
        ).hexdigest()

    def create_api_key(
        self,
        name: str,
        scopes: Optional[List[str]] = None,
        expires_at: Optional[datetime] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """Create new API key.

        Args:
            name: Human-readable name for the key.
            scopes: Allowed scopes; ``None`` is stored as an empty list.
            expires_at: Optional expiration timestamp.
            metadata: Optional extra data; ``None`` is stored as an empty dict.

        Returns:
            The plaintext API key. Only its hash is retained by the manager.
        """
        # Generate secure random key
        api_key = f"wazuh_{secrets.token_urlsafe(32)}"
        key_id = secrets.token_urlsafe(16)

        # Create key object
        key_obj = APIKey(
            id=key_id,
            name=name,
            key_hash=self.hash_api_key(api_key),
            created_at=datetime.now(timezone.utc),
            expires_at=expires_at,
            scopes=scopes or [],
            metadata=metadata or {}
        )

        self.api_keys[key_id] = key_obj
        return api_key

    def validate_api_key(self, api_key: str) -> Optional[APIKey]:
        """Validate API key and return key object if valid with proper cryptographic verification.

        Returns:
            The matching active, unexpired APIKey, or ``None``.
        """
        # Input validation
        if not api_key or not isinstance(api_key, str):
            return None

        # Format validation - must start with wazuh_ and have proper length
        # secrets.token_urlsafe(32) generates 43 chars, so total = 6 + 43 = 49
        if not api_key.startswith("wazuh_") or len(api_key) != 49:
            return None

        # Sanitize input to prevent injection attacks
        if not api_key.replace("_", "").replace("-", "").isalnum():
            return None

        # Use constant-time comparison to prevent timing attacks
        key_hash = self.hash_api_key(api_key)

        for key_obj in self.api_keys.values():
            # Use hmac.compare_digest for constant-time comparison
            if hmac.compare_digest(key_obj.key_hash, key_hash) and key_obj.active:
                # Check expiration with timezone awareness; a naive expires_at
                # (e.g. parsed from a timestamp without an offset) is treated
                # as UTC instead of raising TypeError on comparison.
                if key_obj.expires_at and datetime.now(timezone.utc) > _as_utc(key_obj.expires_at):
                    # Log expiration attempt (without exposing key)
                    print(f"Attempted use of expired API key (ID: {key_obj.id[:8]}...)")
                    return None
                return key_obj

        return None

    def create_token(self, api_key: str) -> Optional[str]:
        """Create authentication token from API key.

        Returns:
            A new session token string, or ``None`` if the API key is invalid.
        """
        key_obj = self.validate_api_key(api_key)
        if not key_obj:
            return None

        # Generate token
        token = f"wst_{secrets.token_urlsafe(48)}"  # wst = Wazuh Session Token

        # Create token object
        token_obj = AuthToken(
            token=token,
            api_key_id=key_obj.id,
            created_at=datetime.now(timezone.utc),
            expires_at=datetime.now(timezone.utc) + timedelta(hours=self.token_lifetime),
            scopes=key_obj.scopes,
            metadata={
                "api_key_name": key_obj.name,
                **key_obj.metadata
            }
        )

        self.tokens[token] = token_obj
        return token

    def validate_token(self, token: str) -> Optional[AuthToken]:
        """Validate token and return token object if valid.

        An expired token found in the store is removed as a side effect.
        """
        if not token.startswith("wst_"):
            return None

        token_obj = self.tokens.get(token)
        if token_obj and token_obj.is_valid():
            return token_obj

        # Clean up expired token
        if token_obj:
            del self.tokens[token]

        return None

    def revoke_token(self, token: str) -> bool:
        """Revoke a token.

        Returns:
            True if the token existed and was removed, otherwise False.
        """
        if token in self.tokens:
            del self.tokens[token]
            return True
        return False

    def revoke_api_key(self, key_id: str) -> bool:
        """Revoke an API key and all associated tokens.

        Returns:
            True if the key exists (it is marked inactive), otherwise False.
        """
        if key_id in self.api_keys:
            self.api_keys[key_id].active = False

            # Revoke all tokens for this key; collect first so the dict is
            # not mutated while it is being iterated.
            tokens_to_revoke = [
                token for token, token_obj in self.tokens.items()
                if token_obj.api_key_id == key_id
            ]
            for token in tokens_to_revoke:
                del self.tokens[token]

            return True
        return False

    def cleanup_expired(self):
        """Clean up expired tokens."""
        expired_tokens = [
            token for token, token_obj in self.tokens.items()
            if not token_obj.is_valid()
        ]
        for token in expired_tokens:
            del self.tokens[token]

    def get_stats(self) -> Dict[str, Any]:
        """Get authentication statistics.

        Expired tokens are purged before counting.
        """
        self.cleanup_expired()

        return {
            "api_keys": {
                "total": len(self.api_keys),
                "active": sum(1 for k in self.api_keys.values() if k.active)
            },
            "tokens": {
                "total": len(self.tokens),
                "active": sum(1 for t in self.tokens.values() if t.is_valid())
            }
        }


# Global auth manager instance
auth_manager = AuthManager()


class TokenRequest(BaseModel):
    """Token request model."""

    api_key: str = Field(description="API key to exchange for token")


class TokenResponse(BaseModel):
    """Token response model."""

    token: str = Field(description="Authentication token")
    expires_in: int = Field(description="Token lifetime in seconds")
    token_type: str = Field(default="Bearer")


async def verify_bearer_token(authorization: str) -> AuthToken:
    """Verify bearer token from Authorization header.

    Raises:
        ValueError: If the header is not a ``Bearer`` header or the token is
            unknown or expired.
    """
    if not authorization.startswith("Bearer "):
        raise ValueError("Invalid authorization header format")

    token = authorization[7:]  # Remove "Bearer " prefix
    token_obj = auth_manager.validate_token(token)

    if not token_obj:
        raise ValueError("Invalid or expired token")

    return token_obj


def create_access_token(data: Dict[str, Any], secret_key: str, expires_delta: Optional[timedelta] = None) -> str:
    """Create JWT access token.

    Args:
        data: Claims to embed; copied, so the caller's dict is not modified.
        secret_key: HS256 signing key.
        expires_delta: Lifetime of the token; defaults to 24 hours.

    Raises:
        ValueError: If encoding fails (the original error is chained).
    """
    to_encode = data.copy()

    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(hours=24)

    to_encode.update({"exp": expire, "iat": datetime.now(timezone.utc)})

    try:
        encoded_jwt = jwt.encode(to_encode, secret_key, algorithm="HS256")
        return encoded_jwt
    except Exception as e:
        raise ValueError(f"Failed to create access token: {str(e)}") from e


def verify_token(token: str, secret_key: str) -> Dict[str, Any]:
    """Verify and decode JWT token.

    Raises:
        ValueError: If the token has expired or is otherwise invalid (the
            underlying JWT error is chained).
    """
    try:
        payload = jwt.decode(token, secret_key, algorithms=["HS256"])
        return payload
    except ExpiredSignatureError as err:
        # Must precede JWTError so the expiry-specific message is reported.
        raise ValueError("Token has expired") from err
    except JWTError as err:
        raise ValueError("Invalid token") from err


async def create_auth_endpoints(app):
    """Add authentication endpoints to FastAPI app.

    Registers ``/auth/token``, ``/auth/validate``, ``/auth/revoke`` and
    ``/auth/stats`` on *app*, all backed by the module-level ``auth_manager``.
    """

    @app.post("/auth/token", response_model=TokenResponse)
    async def create_token(request: TokenRequest):
        """Exchange API key for authentication token."""
        token = auth_manager.create_token(request.api_key)
        if not token:
            raise HTTPException(status_code=401, detail="Invalid API key")

        token_obj = auth_manager.tokens[token]
        expires_in = int((token_obj.expires_at - datetime.now(timezone.utc)).total_seconds())

        return TokenResponse(
            token=token,
            expires_in=expires_in
        )

    @app.get("/auth/validate")
    async def validate_token(
        authorization: str = Header(description="Bearer token")
    ):
        """Validate authentication token."""
        try:
            token_obj = await verify_bearer_token(authorization)
            return {
                "valid": True,
                "api_key_id": token_obj.api_key_id,
                "scopes": token_obj.scopes,
                "expires_at": token_obj.expires_at.isoformat() if token_obj.expires_at else None
            }
        except ValueError as e:
            raise HTTPException(status_code=401, detail=str(e))

    @app.post("/auth/revoke")
    async def revoke_token(
        authorization: str = Header(description="Bearer token")
    ):
        """Revoke authentication token."""
        if not authorization.startswith("Bearer "):
            raise HTTPException(status_code=400, detail="Invalid authorization header")

        token = authorization[7:]
        if auth_manager.revoke_token(token):
            return {"revoked": True}
        else:
            raise HTTPException(status_code=404, detail="Token not found")

    @app.get("/auth/stats")
    async def auth_stats(
        authorization: str = Header(description="Bearer token")
    ):
        """Get authentication statistics (requires admin scope)."""
        try:
            token_obj = await verify_bearer_token(authorization)
            if not token_obj.has_scope("admin"):
                raise HTTPException(status_code=403, detail="Admin scope required")

            return auth_manager.get_stats()
        except ValueError as e:
            raise HTTPException(status_code=401, detail=str(e))

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/gensecaihq/Wazuh-MCP-Server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server