Skip to main content
Glama
token_manager.py (16.3 kB)
"""
OAuth 2.1 Token Management Implementation

Handles token lifecycle with OAuth 2.1 requirements:
- Short-lived access tokens (≤ 1 hour)
- Refresh token rotation
- Resource indicators for audience validation
- Secure token storage and validation
"""

import base64
import json
import logging
import secrets
from dataclasses import dataclass, field
from datetime import datetime, timezone, timedelta
from typing import Dict, Any, Optional, List, Set

import jwt
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

logger = logging.getLogger(__name__)

# Upper bound (seconds) this module enforces on access-token lifetime.
_MAX_ACCESS_TOKEN_EXPIRY = 3600

# Signing algorithm used for every JWT issued or accepted by TokenManager.
_JWT_ALGORITHM = "HS256"


@dataclass
class TokenContext:
    """Server-side record describing one issued access token."""

    access_token: str
    token_type: str = "Bearer"
    expires_in: int = 3600  # lifetime in seconds (capped at 1 hour by TokenManager)
    refresh_token: Optional[str] = None
    scope: Optional[str] = None
    resource: Optional[str] = None  # Resource indicator
    client_id: str = ""
    user_id: str = ""
    issued_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    metadata: Dict[str, Any] = field(default_factory=dict)


class TokenError(Exception):
    """Token-related errors"""
    pass


class TokenManager:
    """
    Token manager issuing HS256-signed JWT access tokens and opaque
    refresh tokens, with resource indicator support.

    Features:
    - Short-lived access tokens (max 1 hour)
    - Refresh token rotation with reuse detection
    - Resource indicators for audience validation
    - Token revocation and introspection

    All state lives in two in-process dictionaries (``active_tokens`` keyed
    by JTI, ``refresh_tokens`` keyed by the refresh token string), so tokens
    do not survive a restart and are not shared between processes.
    """

    def __init__(self,
                 secret_key: str,
                 default_token_expiry: int = 3600,  # 1 hour
                 refresh_token_expiry: int = 7200,  # 2 hours
                 issuer: str = "mcp-personal-assistant"):
        """
        Initialize token manager

        Args:
            secret_key: Secret key for token signing
            default_token_expiry: Default access token expiry in seconds
                (values above 3600 are clamped to 3600)
            refresh_token_expiry: Refresh token expiry in seconds
            issuer: Token issuer identifier
        """
        self.secret_key = secret_key
        self.default_token_expiry = min(default_token_expiry, _MAX_ACCESS_TOKEN_EXPIRY)
        self.refresh_token_expiry = refresh_token_expiry
        self.issuer = issuer

        # In-memory token store (use Redis/database in production)
        self.active_tokens: Dict[str, TokenContext] = {}
        self.refresh_tokens: Dict[str, Dict[str, Any]] = {}

        # Log the enforced value, not the raw argument, so the log reflects
        # the lifetime tokens will actually get.
        logger.info(f"TokenManager initialized with {self.default_token_expiry}s token expiry")

    def create_access_token(self,
                            client_id: str,
                            user_id: str,
                            scope: Optional[str] = None,
                            resource: Optional[str] = None,
                            expires_in: Optional[int] = None) -> TokenContext:
        """
        Create a signed JWT access token and register it as active.

        Args:
            client_id: OAuth client identifier
            user_id: User identifier (becomes the ``sub`` claim)
            scope: Token scope
            resource: Resource indicator (RFC 8707); used as the ``aud``
                claim, falling back to the issuer when omitted
            expires_in: Token expiry in seconds (max 3600)

        Returns:
            TokenContext with access token details
        """
        if expires_in is None:
            expires_in = self.default_token_expiry
        expires_in = min(expires_in, _MAX_ACCESS_TOKEN_EXPIRY)

        now = datetime.now(timezone.utc)
        expiry = now + timedelta(seconds=expires_in)
        jti = self._generate_token_id()

        payload = {
            "iss": self.issuer,
            "sub": user_id,
            "aud": resource or self.issuer,  # Use resource as audience
            "client_id": client_id,
            "exp": int(expiry.timestamp()),
            "iat": int(now.timestamp()),
            "jti": jti,
            "token_type": "access_token",
        }

        # Optional claims are only emitted when a value was supplied.
        if scope:
            payload["scope"] = scope
        if resource:
            payload["resource"] = resource

        access_token = jwt.encode(payload, self.secret_key, algorithm=_JWT_ALGORITHM)

        token_context = TokenContext(
            access_token=access_token,
            expires_in=expires_in,
            scope=scope,
            resource=resource,
            client_id=client_id,
            user_id=user_id,
            issued_at=now,
            metadata={
                "jti": jti,
                "algorithm": _JWT_ALGORITHM,
            },
        )

        # Store token for tracking; presence in this dict is what
        # validate_access_token treats as "not revoked".
        self.active_tokens[jti] = token_context

        logger.info(f"Created access token for client {client_id}, user {user_id}")
        return token_context

    def create_refresh_token(self,
                             client_id: str,
                             user_id: str,
                             access_token_jti: str,
                             scope: Optional[str] = None) -> str:
        """
        Create an opaque refresh token linked to an access token.

        Args:
            client_id: OAuth client identifier
            user_id: User identifier
            access_token_jti: Associated access token JTI
            scope: Scope to carry over to tokens minted from this refresh
                token. When omitted, the scope of the still-active access
                token identified by ``access_token_jti`` is used, if any.

        Returns:
            Refresh token string
        """
        if scope is None:
            linked_context = self.active_tokens.get(access_token_jti)
            if linked_context is not None:
                scope = linked_context.scope

        refresh_token = self._generate_secure_token()
        now = datetime.now(timezone.utc)

        # Store refresh token metadata
        self.refresh_tokens[refresh_token] = {
            "client_id": client_id,
            "user_id": user_id,
            "access_token_jti": access_token_jti,
            "scope": scope,
            "expires_at": now + timedelta(seconds=self.refresh_token_expiry),
            "created_at": now,
            "used": False,
        }

        logger.info(f"Created refresh token for client {client_id}")
        return refresh_token

    def validate_access_token(self,
                              token: str,
                              resource: Optional[str] = None) -> TokenContext:
        """
        Validate access token with resource indicator support

        Checks, in order: JWT signature and registered claims, issuer,
        token type, resource/audience (only when ``resource`` is given),
        presence in the active-token store, and the stored lifetime.

        Args:
            token: JWT access token
            resource: Expected resource indicator

        Returns:
            TokenContext if valid

        Raises:
            TokenError: If token is invalid
        """
        try:
            payload = self._decode_jwt(token)

            if payload.get("iss") != self.issuer:
                raise TokenError("Invalid token issuer")

            if payload.get("token_type") != "access_token":
                raise TokenError("Invalid token type")

            # Validate resource indicator if provided
            if resource:
                token_resource = payload.get("resource")
                token_audience = payload.get("aud")

                if token_resource and token_resource != resource:
                    raise TokenError("Token not valid for requested resource")

                if token_audience and token_audience != resource:
                    raise TokenError("Token audience mismatch")

            # Check if token is still active
            jti = payload.get("jti")
            if jti not in self.active_tokens:
                raise TokenError("Token not found or revoked")

            token_context = self.active_tokens[jti]

            # Verify token hasn't expired (additional check against the
            # server-side record, independent of the JWT "exp" claim)
            lifetime = timedelta(seconds=token_context.expires_in)
            if datetime.now(timezone.utc) >= token_context.issued_at + lifetime:
                self._revoke_token(jti)
                raise TokenError("Token has expired")

            logger.debug(f"Token validation successful for jti: {jti}")
            return token_context

        except TokenError:
            # Already the right exception type with a precise message;
            # do not let the generic handler below re-wrap it.
            raise
        except jwt.InvalidTokenError as e:
            logger.warning(f"JWT validation failed: {e}")
            raise TokenError(f"Invalid token: {e}") from e
        except Exception as e:
            logger.error(f"Token validation error: {e}")
            raise TokenError(f"Token validation failed: {e}") from e

    def refresh_access_token(self,
                             refresh_token: str,
                             resource: Optional[str] = None) -> TokenContext:
        """
        Refresh access token using refresh token with rotation

        The presented refresh token is marked as used and kept in the store
        until it expires, so that a second presentation of the same token is
        recognised as reuse and revokes every token of that client/user.

        Args:
            refresh_token: Valid refresh token
            resource: Optional resource indicator

        Returns:
            New TokenContext with fresh access token

        Raises:
            TokenError: If refresh token is invalid
        """
        refresh_data = self.refresh_tokens.get(refresh_token)
        if not refresh_data:
            raise TokenError("Invalid refresh token")

        client_id = refresh_data["client_id"]
        user_id = refresh_data["user_id"]

        if refresh_data["used"]:
            # Refresh token reuse detected - revoke all tokens
            logger.warning(f"Refresh token reuse detected for client {client_id}")
            self._revoke_all_client_tokens(client_id, user_id)
            raise TokenError("Refresh token already used")

        if datetime.now(timezone.utc) >= refresh_data["expires_at"]:
            raise TokenError("Refresh token has expired")

        # Mark as used but keep the record: deleting it here would make the
        # reuse branch above unreachable. cleanup_expired_tokens() removes
        # it once it has expired.
        refresh_data["used"] = True

        # Revoke old access token
        self.active_tokens.pop(refresh_data["access_token_jti"], None)

        # Carry the originally granted scope over to the new token pair.
        scope = refresh_data.get("scope")

        new_token_context = self.create_access_token(
            client_id=client_id,
            user_id=user_id,
            scope=scope,
            resource=resource,
        )

        # Create new refresh token (refresh token rotation)
        new_token_context.refresh_token = self.create_refresh_token(
            client_id=client_id,
            user_id=user_id,
            access_token_jti=new_token_context.metadata["jti"],
            scope=scope,
        )

        logger.info(f"Access token refreshed for client {client_id}")
        return new_token_context

    def revoke_token(self, token: str) -> bool:
        """
        Revoke access or refresh token

        The token is first treated as a JWT access token; if it does not
        decode as one (or its JTI is not active) it is looked up as a
        refresh token.

        Args:
            token: Token to revoke

        Returns:
            True if token was revoked, False if not found
        """
        try:
            # Signature is verified so a forged JWT cannot revoke someone
            # else's token; expiry is not, so an expired token that is still
            # in the store can be revoked explicitly.
            payload = self._decode_jwt(token, verify_exp=False)
            jti = payload.get("jti")
            if jti and jti in self.active_tokens:
                return self._revoke_token(jti)
        except jwt.InvalidTokenError:
            pass  # Not a JWT signed by us, might be refresh token

        # Try as refresh token
        if token in self.refresh_tokens:
            del self.refresh_tokens[token]
            logger.info("Refresh token revoked")
            return True

        return False

    def introspect_token(self, token: str) -> Dict[str, Any]:
        """
        Introspect token (RFC 7662 style)

        Args:
            token: Token to introspect

        Returns:
            Token introspection response; ``{"active": False}`` when the
            token does not validate
        """
        try:
            token_context = self.validate_access_token(token)

            # Decode token for additional info
            payload = self._decode_jwt(token)
        except (TokenError, jwt.InvalidTokenError):
            # The second decode can still fail (e.g. the token expires
            # between the two calls); report that as inactive too.
            return {"active": False}

        return {
            "active": True,
            "client_id": token_context.client_id,
            "username": token_context.user_id,
            "scope": token_context.scope,
            "resource": token_context.resource,
            "exp": payload.get("exp"),
            "iat": payload.get("iat"),
            "iss": payload.get("iss"),
            "sub": payload.get("sub"),
            "aud": payload.get("aud"),
            "token_type": "access_token",
        }

    def cleanup_expired_tokens(self) -> int:
        """
        Clean up expired tokens from memory

        Returns:
            Number of tokens cleaned up (access and refresh combined)
        """
        now = datetime.now(timezone.utc)

        # Collect keys first; the dictionaries must not change size while
        # they are being iterated.
        expired_jtis = [
            jti
            for jti, token_context in self.active_tokens.items()
            if now >= token_context.issued_at + timedelta(seconds=token_context.expires_in)
        ]
        for jti in expired_jtis:
            del self.active_tokens[jti]

        expired_refresh_tokens = [
            refresh_token
            for refresh_token, data in self.refresh_tokens.items()
            if now >= data["expires_at"]
        ]
        for refresh_token in expired_refresh_tokens:
            del self.refresh_tokens[refresh_token]

        cleaned_count = len(expired_jtis) + len(expired_refresh_tokens)
        if cleaned_count > 0:
            logger.info(f"Cleaned up {cleaned_count} expired tokens")

        return cleaned_count

    def _decode_jwt(self, token: str, *, verify_exp: bool = True) -> Dict[str, Any]:
        """
        Decode a JWT signed with this manager's key and return its claims.

        Audience verification by PyJWT is switched off: every token issued
        here carries an ``aud`` claim, and PyJWT rejects such a token when
        ``decode`` is called without an ``audience`` argument. Callers that
        care about the audience compare the claim themselves.

        Raises:
            jwt.InvalidTokenError: If the token is malformed, badly signed,
                missing a required claim, or (when ``verify_exp``) expired
        """
        return jwt.decode(
            token,
            self.secret_key,
            algorithms=[_JWT_ALGORITHM],
            options={
                "require": ["exp", "iat", "iss", "sub", "jti"],
                "verify_exp": verify_exp,
                "verify_iat": True,
                "verify_aud": False,
            },
        )

    def _generate_token_id(self) -> str:
        """Generate unique token ID (JTI)"""
        return secrets.token_urlsafe(32)

    def _generate_secure_token(self) -> str:
        """Generate cryptographically secure token"""
        return secrets.token_urlsafe(64)

    def _revoke_token(self, jti: str) -> bool:
        """Revoke access token by JTI; return True if it was active."""
        if jti in self.active_tokens:
            del self.active_tokens[jti]
            logger.info(f"Access token revoked: {jti}")
            return True
        return False

    def _revoke_all_client_tokens(self, client_id: str, user_id: str) -> None:
        """Revoke all tokens for a client/user combination"""
        # Revoke access tokens
        jtis_to_remove = [
            jti
            for jti, token_context in self.active_tokens.items()
            if token_context.client_id == client_id and token_context.user_id == user_id
        ]
        for jti in jtis_to_remove:
            del self.active_tokens[jti]

        # Revoke refresh tokens
        refresh_tokens_to_remove = [
            refresh_token
            for refresh_token, data in self.refresh_tokens.items()
            if data["client_id"] == client_id and data["user_id"] == user_id
        ]
        for refresh_token in refresh_tokens_to_remove:
            del self.refresh_tokens[refresh_token]

        logger.warning(f"Revoked all tokens for client {client_id}, user {user_id}")


# Convenience functions
def create_token_manager(secret_key: str, **kwargs) -> TokenManager:
    """
    Create token manager with default lifetimes.

    Recognised keyword arguments: ``token_expiry`` (default 3600),
    ``refresh_expiry`` (default 7200) and ``issuer``
    (default ``'mcp-personal-assistant'``). Other keywords are ignored.
    """
    return TokenManager(
        secret_key=secret_key,
        default_token_expiry=kwargs.get('token_expiry', 3600),  # 1 hour
        refresh_token_expiry=kwargs.get('refresh_expiry', 7200),  # 2 hours
        issuer=kwargs.get('issuer', 'mcp-personal-assistant')
    )

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/swapnilsurdi/mcp-pa'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.