Skip to main content
Glama
vitalune

Personal Knowledge Assistant

by vitalune
auth.py22 kB
""" Secure Authentication Management System This module provides comprehensive authentication handling with security best practices: - OAuth2 token management with encryption - API key secure storage and rotation - Multi-service authentication state management - Session management for MCP connections - Token refresh and validation """ import asyncio import json import secrets import time from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Dict, List, Optional, Any, Tuple from dataclasses import dataclass, asdict from enum import Enum import structlog from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from cryptography.hazmat.primitives.ciphers.aead import AESGCM from pydantic import BaseModel, Field, SecretStr from .settings import get_settings logger = structlog.get_logger(__name__) class AuthProvider(str, Enum): """Supported authentication providers""" GOOGLE = "google" MICROSOFT = "microsoft" TWITTER = "twitter" LINKEDIN = "linkedin" CUSTOM = "custom" class TokenType(str, Enum): """Types of authentication tokens""" ACCESS_TOKEN = "access_token" REFRESH_TOKEN = "refresh_token" ID_TOKEN = "id_token" API_KEY = "api_key" SESSION_TOKEN = "session_token" class AuthStatus(str, Enum): """Authentication status""" VALID = "valid" EXPIRED = "expired" REVOKED = "revoked" INVALID = "invalid" PENDING_REFRESH = "pending_refresh" @dataclass class TokenMetadata: """Metadata for authentication tokens""" created_at: datetime expires_at: Optional[datetime] last_used: Optional[datetime] usage_count: int scopes: List[str] issuer: str subject: Optional[str] client_id: Optional[str] class SecureToken(BaseModel): """Secure token container with encryption""" token_type: TokenType provider: AuthProvider encrypted_value: bytes metadata: TokenMetadata checksum: str class Config: arbitrary_types_allowed = True def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for storage""" return { "token_type": self.token_type.value, "provider": self.provider.value, "encrypted_value": self.encrypted_value.hex(), "metadata": { "created_at": self.metadata.created_at.isoformat(), "expires_at": self.metadata.expires_at.isoformat() if self.metadata.expires_at else None, "last_used": self.metadata.last_used.isoformat() if self.metadata.last_used else None, "usage_count": self.metadata.usage_count, "scopes": self.metadata.scopes, "issuer": self.metadata.issuer, "subject": self.metadata.subject, "client_id": self.metadata.client_id, }, "checksum": self.checksum, } @classmethod def from_dict(cls, data: Dict[str, Any]) -> "SecureToken": """Create from dictionary""" metadata = TokenMetadata( created_at=datetime.fromisoformat(data["metadata"]["created_at"]), expires_at=datetime.fromisoformat(data["metadata"]["expires_at"]) if data["metadata"]["expires_at"] else None, last_used=datetime.fromisoformat(data["metadata"]["last_used"]) if data["metadata"]["last_used"] else None, usage_count=data["metadata"]["usage_count"], scopes=data["metadata"]["scopes"], issuer=data["metadata"]["issuer"], subject=data["metadata"]["subject"], client_id=data["metadata"]["client_id"], ) return cls( token_type=TokenType(data["token_type"]), provider=AuthProvider(data["provider"]), encrypted_value=bytes.fromhex(data["encrypted_value"]), metadata=metadata, checksum=data["checksum"], ) class AuthenticationManager: """Secure authentication manager for multiple providers and token types""" def __init__(self): self.settings = get_settings() self._encryption_key = self._derive_encryption_key() self._token_storage: Dict[str, SecureToken] = {} self._session_states: Dict[str, Dict[str, Any]] = {} self._load_tokens() # Audit logger for authentication events self._audit_logger = structlog.get_logger("auth_audit") def _derive_encryption_key(self) -> bytes: """Derive encryption key from master key and salt""" master_key = self.settings.encryption.master_key.get_secret_value().encode() salt = self.settings.encryption.key_salt.get_secret_value().encode() kdf = PBKDF2HMAC( algorithm=hashes.SHA256(), length=32, salt=salt, iterations=self.settings.encryption.pbkdf2_iterations, ) return kdf.derive(master_key) def _encrypt_token(self, token_value: str) -> Tuple[bytes, str]: """Encrypt token value and return encrypted data with checksum""" aesgcm = AESGCM(self._encryption_key) nonce = secrets.token_bytes(12) # 12 bytes for GCM encrypted_data = aesgcm.encrypt(nonce, token_value.encode(), None) full_encrypted = nonce + encrypted_data # Create checksum for integrity validation checksum = hashes.Hash(hashes.SHA256()) checksum.update(full_encrypted) checksum_value = checksum.finalize().hex()[:16] # First 16 chars of SHA256 return full_encrypted, checksum_value def _decrypt_token(self, encrypted_data: bytes, checksum: str) -> str: """Decrypt token value and validate checksum""" # Validate checksum calculated_checksum = hashes.Hash(hashes.SHA256()) calculated_checksum.update(encrypted_data) calculated_checksum_value = calculated_checksum.finalize().hex()[:16] if calculated_checksum_value != checksum: raise ValueError("Token integrity check failed - possible tampering detected") aesgcm = AESGCM(self._encryption_key) nonce = encrypted_data[:12] ciphertext = encrypted_data[12:] decrypted_data = aesgcm.decrypt(nonce, ciphertext, None) return decrypted_data.decode() def _get_storage_path(self) -> Path: """Get secure storage path for tokens""" return self.settings.get_storage_path("auth_tokens.encrypted") def _load_tokens(self): """Load encrypted tokens from secure storage""" storage_path = self._get_storage_path() if not storage_path.exists(): logger.info("No existing token storage found, starting fresh") return try: with open(storage_path, 'r') as f: stored_data = json.load(f) for token_id, token_data in stored_data.items(): try: secure_token = SecureToken.from_dict(token_data) self._token_storage[token_id] = secure_token except Exception as e: logger.error(f"Failed to load token {token_id}", error=str(e)) logger.info(f"Loaded {len(self._token_storage)} tokens from secure storage") except Exception as e: logger.error("Failed to load token storage", error=str(e)) # Don't fail completely, just start with empty storage self._token_storage = {} def _save_tokens(self): """Save encrypted tokens to secure storage""" storage_path = self._get_storage_path() # Ensure parent directory exists storage_path.parent.mkdir(parents=True, exist_ok=True) try: # Convert all tokens to dict format storage_data = { token_id: token.to_dict() for token_id, token in self._token_storage.items() } # Write to temporary file first, then rename for atomic operation temp_path = storage_path.with_suffix('.tmp') with open(temp_path, 'w') as f: json.dump(storage_data, f, indent=2) # Set restrictive permissions temp_path.chmod(0o600) # Atomic rename temp_path.rename(storage_path) logger.info("Tokens saved to secure storage") except Exception as e: logger.error("Failed to save tokens to storage", error=str(e)) raise def _generate_token_id(self, provider: AuthProvider, token_type: TokenType, subject: Optional[str] = None) -> str: """Generate unique token identifier""" base_id = f"{provider.value}_{token_type.value}" if subject: base_id += f"_{subject}" # Add timestamp and random component for uniqueness timestamp = int(time.time()) random_suffix = secrets.token_hex(4) return f"{base_id}_{timestamp}_{random_suffix}" async def store_token( self, provider: AuthProvider, token_type: TokenType, token_value: str, expires_in: Optional[int] = None, scopes: Optional[List[str]] = None, subject: Optional[str] = None, client_id: Optional[str] = None, issuer: Optional[str] = None, ) -> str: """Store a token securely with encryption""" # Encrypt the token value encrypted_value, checksum = self._encrypt_token(token_value) # Calculate expiration time expires_at = None if expires_in: expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) # Create token metadata metadata = TokenMetadata( created_at=datetime.now(timezone.utc), expires_at=expires_at, last_used=None, usage_count=0, scopes=scopes or [], issuer=issuer or provider.value, subject=subject, client_id=client_id, ) # Create secure token secure_token = SecureToken( token_type=token_type, provider=provider, encrypted_value=encrypted_value, metadata=metadata, checksum=checksum, ) # Generate unique token ID token_id = self._generate_token_id(provider, token_type, subject) # Store token self._token_storage[token_id] = secure_token # Save to persistent storage self._save_tokens() # Audit log await self._audit_log("token_stored", { "token_id": token_id, "provider": provider.value, "token_type": token_type.value, "scopes": scopes or [], "expires_at": expires_at.isoformat() if expires_at else None, }) logger.info(f"Token stored securely", token_id=token_id, provider=provider.value, token_type=token_type.value) return token_id async def retrieve_token(self, token_id: str) -> Optional[str]: """Retrieve and decrypt a token""" if token_id not in self._token_storage: logger.warning(f"Token not found", token_id=token_id) return None secure_token = self._token_storage[token_id] # Check if token is expired if secure_token.metadata.expires_at and secure_token.metadata.expires_at < datetime.now(timezone.utc): logger.warning(f"Token expired", token_id=token_id) await self._audit_log("token_access_expired", {"token_id": token_id}) return None try: # Decrypt token token_value = self._decrypt_token(secure_token.encrypted_value, secure_token.checksum) # Update usage metadata secure_token.metadata.last_used = datetime.now(timezone.utc) secure_token.metadata.usage_count += 1 # Save updated metadata self._save_tokens() # Audit log await self._audit_log("token_accessed", { "token_id": token_id, "provider": secure_token.provider.value, "token_type": secure_token.token_type.value, "usage_count": secure_token.metadata.usage_count, }) return token_value except Exception as e: logger.error(f"Failed to decrypt token", token_id=token_id, error=str(e)) await self._audit_log("token_access_failed", { "token_id": token_id, "error": str(e), }) return None async def revoke_token(self, token_id: str) -> bool: """Revoke a token and remove from storage""" if token_id not in self._token_storage: logger.warning(f"Token not found for revocation", token_id=token_id) return False secure_token = self._token_storage[token_id] # Remove token from storage del self._token_storage[token_id] # Save changes self._save_tokens() # Audit log await self._audit_log("token_revoked", { "token_id": token_id, "provider": secure_token.provider.value, "token_type": secure_token.token_type.value, }) logger.info(f"Token revoked", token_id=token_id) return True async def list_tokens(self, provider: Optional[AuthProvider] = None, token_type: Optional[TokenType] = None) -> List[Dict[str, Any]]: """List stored tokens with metadata (without revealing token values)""" tokens = [] for token_id, secure_token in self._token_storage.items(): # Filter by provider if specified if provider and secure_token.provider != provider: continue # Filter by token type if specified if token_type and secure_token.token_type != token_type: continue # Check token status status = AuthStatus.VALID if secure_token.metadata.expires_at: if secure_token.metadata.expires_at < datetime.now(timezone.utc): status = AuthStatus.EXPIRED token_info = { "token_id": token_id, "provider": secure_token.provider.value, "token_type": secure_token.token_type.value, "status": status.value, "created_at": secure_token.metadata.created_at.isoformat(), "expires_at": secure_token.metadata.expires_at.isoformat() if secure_token.metadata.expires_at else None, "last_used": secure_token.metadata.last_used.isoformat() if secure_token.metadata.last_used else None, "usage_count": secure_token.metadata.usage_count, "scopes": secure_token.metadata.scopes, "subject": secure_token.metadata.subject, } tokens.append(token_info) return tokens async def cleanup_expired_tokens(self) -> int: """Remove expired tokens from storage""" expired_token_ids = [] current_time = datetime.now(timezone.utc) for token_id, secure_token in self._token_storage.items(): if secure_token.metadata.expires_at and secure_token.metadata.expires_at < current_time: expired_token_ids.append(token_id) # Remove expired tokens for token_id in expired_token_ids: del self._token_storage[token_id] logger.info(f"Removed expired token", token_id=token_id) if expired_token_ids: self._save_tokens() # Audit log await self._audit_log("tokens_cleanup", { "expired_count": len(expired_token_ids), "token_ids": expired_token_ids, }) return len(expired_token_ids) def create_session_state(self, provider: AuthProvider, redirect_uri: str, scopes: List[str]) -> str: """Create secure session state for OAuth flow""" state = secrets.token_urlsafe(32) session_data = { "provider": provider.value, "redirect_uri": redirect_uri, "scopes": scopes, "created_at": datetime.now(timezone.utc).isoformat(), "expires_at": (datetime.now(timezone.utc) + timedelta(minutes=self.settings.api_credentials.oauth_state_expiry_minutes)).isoformat(), } self._session_states[state] = session_data logger.info(f"Created session state", state=state[:8] + "...", provider=provider.value) return state def validate_session_state(self, state: str) -> Optional[Dict[str, Any]]: """Validate session state and return session data""" if state not in self._session_states: logger.warning(f"Invalid session state", state=state[:8] + "...") return None session_data = self._session_states[state] # Check expiration expires_at = datetime.fromisoformat(session_data["expires_at"]) if expires_at < datetime.now(timezone.utc): logger.warning(f"Expired session state", state=state[:8] + "...") del self._session_states[state] return None return session_data def consume_session_state(self, state: str) -> Optional[Dict[str, Any]]: """Consume (validate and remove) session state""" session_data = self.validate_session_state(state) if session_data: del self._session_states[state] logger.info(f"Consumed session state", state=state[:8] + "...") return session_data async def rotate_tokens(self) -> int: """Rotate tokens that support rotation""" rotated_count = 0 for token_id, secure_token in list(self._token_storage.items()): # Only rotate tokens that are close to expiry (within 1 hour) if secure_token.metadata.expires_at: time_to_expiry = secure_token.metadata.expires_at - datetime.now(timezone.utc) if time_to_expiry < timedelta(hours=1) and secure_token.token_type == TokenType.ACCESS_TOKEN: # Look for corresponding refresh token refresh_token_id = self._find_refresh_token(secure_token.provider, secure_token.metadata.subject) if refresh_token_id: logger.info(f"Token rotation needed", token_id=token_id) # This would typically involve calling the OAuth provider's refresh endpoint # For now, just log the need for rotation rotated_count += 1 if rotated_count > 0: await self._audit_log("tokens_rotation_check", { "tokens_needing_rotation": rotated_count, }) return rotated_count def _find_refresh_token(self, provider: AuthProvider, subject: Optional[str]) -> Optional[str]: """Find corresponding refresh token for a provider and subject""" for token_id, secure_token in self._token_storage.items(): if (secure_token.provider == provider and secure_token.token_type == TokenType.REFRESH_TOKEN and secure_token.metadata.subject == subject): return token_id return None async def _audit_log(self, event_type: str, event_data: Dict[str, Any]): """Log authentication events for audit trail""" if not self.settings.audit.audit_enabled: return audit_entry = { "timestamp": datetime.now(timezone.utc).isoformat(), "event_type": event_type, "component": "authentication_manager", "data": event_data, } self._audit_logger.info("Auth audit event", **audit_entry) async def get_health_status(self) -> Dict[str, Any]: """Get authentication system health status""" total_tokens = len(self._token_storage) expired_tokens = 0 expiring_soon = 0 current_time = datetime.now(timezone.utc) for secure_token in self._token_storage.values(): if secure_token.metadata.expires_at: if secure_token.metadata.expires_at < current_time: expired_tokens += 1 elif secure_token.metadata.expires_at < current_time + timedelta(hours=24): expiring_soon += 1 return { "total_tokens": total_tokens, "expired_tokens": expired_tokens, "expiring_soon": expiring_soon, "active_sessions": len(self._session_states), "storage_path": str(self._get_storage_path()), "encryption_enabled": True, } # Global authentication manager instance auth_manager: Optional[AuthenticationManager] = None def get_auth_manager() -> AuthenticationManager: """Get authentication manager singleton""" global auth_manager if auth_manager is None: auth_manager = AuthenticationManager() return auth_manager

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/vitalune/Nexus-MCP'

If you have feedback or need assistance with the MCP directory API, please join our Discord server