"""
Enhanced Descope authentication integration for MCP server
Implements Access Key authentication with JWT validation
"""
import json
import time
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta, timezone
import httpx
import jwt
import structlog
from src.core.config import settings
logger = structlog.get_logger()
class DemoDescopeAuth:
"""Demo authentication client for development"""
def __init__(self, project_id: str):
self.project_id = project_id
self.demo_secret = "demo_secret_key_for_development_only"
async def create_machine_token(
self,
access_key: str,
custom_claims: Dict[str, Any] = None
) -> Dict[str, Any]:
"""Create a demo JWT token"""
# Create demo JWT payload
now = datetime.now(timezone.utc)
payload = {
"sub": "demo_machine_user",
"iss": f"https://api.descope.com/{self.project_id}",
"aud": self.project_id,
"iat": int(now.timestamp()),
"exp": int((now + timedelta(hours=1)).timestamp()),
"permissions": [
"tools:basic",
"tools:orchestrate",
"tools:architecture",
"tools:fix",
"tools:capabilities",
"tools:status",
"advanced:app_generator",
"advanced:autonomous_architect",
"advanced:quality_framework",
"advanced:prompt_engine",
"advanced:cloud_agent"
],
"type": "machine",
"tenantId": "demo_tenant",
"customClaims": custom_claims or {},
"jti": f"demo_token_{int(time.time())}"
}
# Create demo JWT token
token = jwt.encode(payload, self.demo_secret, algorithm="HS256")
return {
"access_token": token,
"token_type": "Bearer",
"key_id": "demo_key_id"
}
async def validate_jwt_token(self, token: str) -> Dict[str, Any]:
"""Validate demo JWT token"""
try:
# For demo mode, we don't verify audience to keep it simple
payload = jwt.decode(
token,
self.demo_secret,
algorithms=["HS256"],
options={"verify_aud": False} # Skip audience verification in demo mode
)
return payload
except jwt.ExpiredSignatureError as e:
raise Exception(f"Token has expired: {e}")
except jwt.InvalidTokenError as e:
raise Exception(f"Invalid token: {e}")
except Exception as e:
raise Exception(f"Token validation error: {e}")
logger = structlog.get_logger()
class DescopeConfig:
"""Descope configuration constants for Access Key authentication"""
BASE_URL = "https://api.descope.com"
JWKS_URL = f"{BASE_URL}/v1/keys"
ACCESS_KEY_EXCHANGE_URL = f"{BASE_URL}/v1/auth/accesskey/exchange" # Access Key Exchange endpoint
class TokenValidationError(Exception):
"""Token validation specific errors"""
pass
class ScopeValidationError(Exception):
"""Scope validation specific errors"""
pass
class DescopeClient:
"""Descope client for Access Key authentication and JWT validation"""
def __init__(
self,
project_id: str,
management_key: Optional[str] = None,
cache_ttl: int = 3600,
demo_mode: bool = False
):
self.project_id = project_id
self.demo_mode = demo_mode
# Use demo auth if in demo mode
if self.demo_mode:
self.demo_client = DemoDescopeAuth(project_id)
logger.info("🚧 Demo mode enabled - using mock authentication")
else:
self.demo_client = None
self.management_key = management_key
self.cache_ttl = cache_ttl
self._jwks_cache: Dict[str, Any] = {}
self._jwks_cache_expiry = 0
self.http_client = httpx.AsyncClient(timeout=30.0)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.http_client.aclose()
async def get_jwks(self) -> Dict[str, Any]:
"""Get JSON Web Key Set for token validation"""
current_time = time.time()
if current_time < self._jwks_cache_expiry and self._jwks_cache:
return self._jwks_cache
try:
response = await self.http_client.get(
f"{DescopeConfig.JWKS_URL}/{self.project_id}"
)
response.raise_for_status()
self._jwks_cache = await response.json()
self._jwks_cache_expiry = current_time + self.cache_ttl
logger.info(
"jwks_refreshed",
project_id=self.project_id,
keys_count=len(self._jwks_cache.get("keys", []))
)
return self._jwks_cache
except httpx.HTTPError as e:
logger.error(
"jwks_fetch_failed",
project_id=self.project_id,
error=str(e)
)
raise TokenValidationError(f"Failed to fetch JWKS: {str(e)}")
async def validate_session(self, session_token: str) -> Dict[str, Any]:
"""
🔒 SECURE: Validate session token (supports both session tokens and access keys)
This is the recommended Descope method that handles all token types automatically.
Replaces complex JWT validation logic as per Descope best practices.
"""
# Use demo mode if enabled
if self.demo_mode:
return await self.demo_client.validate_jwt_token(session_token)
# For now, use the existing JWT validation logic but rename the method
# In a full implementation, this would call Descope's validate session API
return await self.validate_jwt_token(session_token)
async def validate_jwt_token(self, token: str) -> Dict[str, Any]:
"""Validate JWT token using Descope public keys"""
# Use demo mode if enabled
if self.demo_mode:
return await self.demo_client.validate_jwt_token(token)
try:
# Decode header to get key ID
unverified_header = jwt.get_unverified_header(token)
kid = unverified_header.get("kid")
if not kid:
raise TokenValidationError("Token missing key ID")
# Get public keys
jwks = await self.get_jwks()
# Find matching key
public_key = None
for key in jwks.get("keys", []):
if key.get("kid") == kid:
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key))
break
if not public_key:
raise TokenValidationError(f"Public key not found for kid: {kid}")
# Validate and decode token
payload = jwt.decode(
token,
public_key,
algorithms=["RS256"],
audience=self.project_id,
options={
"verify_exp": True,
"verify_aud": True,
"verify_iss": True,
"require": ["exp", "iat", "sub"]
}
)
logger.info(
"token_validated",
subject=payload.get("sub"),
scopes=payload.get("permissions", []),
expires_at=payload.get("exp")
)
return payload
except jwt.ExpiredSignatureError:
raise TokenValidationError("Token has expired")
except jwt.InvalidAudienceError:
raise TokenValidationError("Invalid token audience")
except jwt.InvalidIssuerError:
raise TokenValidationError("Invalid token issuer")
except jwt.InvalidTokenError as e:
raise TokenValidationError(f"Invalid token: {str(e)}")
async def create_machine_token(
self,
access_key: str,
custom_claims: Dict[str, Any] = None
) -> Dict[str, Any]:
"""Exchange Access Key for JWT token for Non-Human Identity using Descope Access Key approach"""
# Use demo mode if enabled
if self.demo_mode:
return await self.demo_client.create_machine_token(access_key, custom_claims)
try:
payload = {
"loginOptions": {
"customClaims": custom_claims or {}
}
}
response = await self.http_client.post(
DescopeConfig.ACCESS_KEY_EXCHANGE_URL,
json=payload,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {access_key}"
}
)
response.raise_for_status()
token_data = await response.json()
logger.info(
"access_key_exchanged",
key_id=token_data.get("keyId", "unknown"),
session_jwt=bool(token_data.get("sessionJwt"))
)
return {
"access_token": token_data.get("sessionJwt"),
"token_type": "Bearer",
"key_id": token_data.get("keyId")
}
except httpx.HTTPError as e:
logger.error(
"access_key_exchange_failed",
error=str(e)
)
raise TokenValidationError(f"Failed to exchange access key: {str(e)}")
class AuthContext:
"""Enhanced authentication context with detailed claims"""
def __init__(self, token_claims: Dict[str, Any]):
self.token_claims = token_claims
self.user_id = token_claims.get("sub", "unknown")
self.scopes = token_claims.get("permissions", [])
self.client_id = token_claims.get("aud", "unknown")
self.issued_at = datetime.fromtimestamp(token_claims.get("iat", 0), tz=timezone.utc)
self.expires_at = datetime.fromtimestamp(token_claims.get("exp", 0), tz=timezone.utc)
self.correlation_id = token_claims.get("jti", "unknown")
# Non-Human Identity specific claims
self.is_machine = token_claims.get("type") == "machine"
self.tenant_id = token_claims.get("tenantId")
self.custom_claims = token_claims.get("customClaims", {})
def has_scope(self, required_scope: str) -> bool:
"""Check if context has required scope"""
return required_scope in self.scopes
def has_any_scope(self, required_scopes: List[str]) -> bool:
"""Check if context has any of the required scopes"""
return any(scope in self.scopes for scope in required_scopes)
def has_all_scopes(self, required_scopes: List[str]) -> bool:
"""Check if context has all required scopes"""
return all(scope in self.scopes for scope in required_scopes)
def is_expired(self) -> bool:
"""Check if token is expired"""
return datetime.now(timezone.utc) > self.expires_at
def time_until_expiry(self) -> timedelta:
"""Get time until token expires"""
return self.expires_at - datetime.now(timezone.utc)
# Global Descope client instance
descope_client: Optional[DescopeClient] = None
_client_initialization_error: Optional[str] = None
async def get_descope_client() -> DescopeClient:
"""
Get singleton Descope client instance with robust error handling
This function ensures that server startup never fails due to authentication issues.
If real credentials fail, it will fall back gracefully while logging the issue.
"""
global descope_client, _client_initialization_error
if descope_client is None:
try:
logger.info("descope_client_init_start",
project_id_set=bool(settings.descope_project_id),
management_key_set=bool(settings.descope_management_key),
demo_mode=settings.descope_demo_mode)
# Initialize the client
descope_client = DescopeClient(
project_id=settings.descope_project_id,
management_key=settings.descope_management_key,
demo_mode=settings.descope_demo_mode
)
# Test the client initialization (but don't fail if it doesn't work)
try:
# For demo mode, this should always work
if settings.descope_demo_mode:
logger.info("descope_client_init_success", mode="demo")
else:
# For production mode, we don't test API calls here to avoid blocking startup
# API calls will be tested when actually needed
logger.info("descope_client_init_success", mode="production",
project_id=settings.descope_project_id[:8] + "..." if settings.descope_project_id else None)
except Exception as test_error:
# Log the test error but continue with the client
logger.warning("descope_client_test_failed",
error=str(test_error),
note="Client created but functionality may be limited")
except Exception as init_error:
# Log the initialization error
error_msg = f"Descope client initialization failed: {str(init_error)}"
_client_initialization_error = error_msg
logger.error("descope_client_init_failed",
error=str(init_error),
fallback="Server will continue without authentication")
# Create a fallback demo client to prevent complete failure
try:
logger.info("descope_fallback_init", mode="emergency_demo")
descope_client = DescopeClient(
project_id="emergency_demo_project",
management_key="emergency_demo_key",
demo_mode=True
)
logger.warning("descope_fallback_success",
warning="Using emergency demo mode due to initialization failure")
except Exception as fallback_error:
# If even the fallback fails, raise the original error
logger.critical("descope_fallback_failed",
original_error=str(init_error),
fallback_error=str(fallback_error))
raise Exception(f"Complete authentication failure: {init_error}")
return descope_client
def get_initialization_error() -> Optional[str]:
"""Get any initialization error that occurred"""
return _client_initialization_error