Skip to main content
Glama
jwks_client.py (9.38 kB)
"""JWKS client for fetching and caching signing keys.

Fetches JSON Web Key Sets (JWKS) from OAuth authorization servers
and caches them for efficient token validation.
"""

from __future__ import annotations

import time
from typing import TYPE_CHECKING

import httpx
import jwt
from jwt import PyJWK, PyJWKSet

from sso_mcp_server import get_logger
from sso_mcp_server.auth.exceptions import InvalidTokenError, JWKSFetchError

if TYPE_CHECKING:
    from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey

_logger = get_logger("jwks_client")

# Default timeout for HTTP requests
HTTP_TIMEOUT = 10.0


class JWKSClient:
    """Client for fetching and caching JWKS from OAuth issuers.

    Keeps one cached key set per issuer, each valid for ``cache_ttl``
    seconds. Key rotation is handled by
    :meth:`get_signing_key_with_retry`, which forces a refresh of the
    issuer's JWKS when a key ID is not found in the cached set.
    """

    def __init__(self, cache_ttl: int = 3600) -> None:
        """Initialize JWKS client.

        Args:
            cache_ttl: Cache time-to-live in seconds (default: 1 hour).
        """
        self._cache_ttl = cache_ttl
        # Maps issuer URL -> cached key set with its expiry time.
        self._cache: dict[str, _CachedJWKS] = {}
        # Created lazily by _get_http_client().
        self._http_client: httpx.AsyncClient | None = None
        _logger.debug("jwks_client_initialized", cache_ttl=cache_ttl)

    async def _get_http_client(self) -> httpx.AsyncClient:
        """Get or create HTTP client (lazy initialization)."""
        if self._http_client is None:
            self._http_client = httpx.AsyncClient(timeout=HTTP_TIMEOUT)
        return self._http_client

    async def close(self) -> None:
        """Close the HTTP client.

        Safe to call when no client was ever created; a later request
        creates a fresh client.
        """
        if self._http_client is not None:
            await self._http_client.aclose()
            self._http_client = None

    async def get_signing_key(self, token: str) -> RSAPublicKey:
        """Get the signing key for a JWT token.

        Reads the key ID (``kid``) from the token header and the issuer
        (``iss``) from the token payload, both WITHOUT verifying the
        signature, fetches the JWKS for that issuer (with caching), and
        returns the matching public key.

        Args:
            token: The JWT token string.

        Returns:
            The RSA public key for verifying the token signature.

        Raises:
            InvalidTokenError: If the token is malformed, is missing
                ``kid`` or ``iss``, has a non-string ``iss``, or no key
                with the given ``kid`` exists in the issuer's JWKS.
            JWKSFetchError: If the JWKS cannot be fetched from the issuer.
        """
        # Decode header and payload without verification to get kid and iss.
        # Catch PyJWT's base token error rather than only DecodeError:
        # header validation (e.g. a non-string 'kid') raises the base class,
        # and callers should only ever see the project's InvalidTokenError.
        try:
            unverified_header = jwt.get_unverified_header(token)
            unverified_payload = jwt.decode(token, options={"verify_signature": False})
        except jwt.exceptions.InvalidTokenError as e:
            _logger.warning("failed_to_decode_token_header", error=str(e))
            raise InvalidTokenError("Malformed JWT token") from e

        kid = unverified_header.get("kid")
        issuer = unverified_payload.get("iss")

        if not kid:
            raise InvalidTokenError("Token missing 'kid' (key ID) in header")
        if not issuer:
            raise InvalidTokenError("Token missing 'iss' (issuer) claim")
        if not isinstance(issuer, str):
            # A list/number issuer would otherwise fail later as a
            # TypeError/AttributeError when used as a cache key or URL.
            raise InvalidTokenError("Token 'iss' (issuer) claim must be a string")

        _logger.debug("extracting_signing_key", kid=kid, issuer=issuer)

        # NOTE(review): the issuer comes from an unverified token, so the
        # discovery request below goes to a URL chosen by whoever built the
        # token. Presumably the caller checks `iss` against its trusted
        # issuers before relying on the returned key - confirm.
        jwks = await self._get_jwks(issuer)

        # Find the key matching the kid
        key = self._find_key_by_kid(jwks, kid, issuer)
        return key.key

    async def _get_jwks(self, issuer: str, force_refresh: bool = False) -> PyJWKSet:
        """Get JWKS for an issuer (with caching).

        Args:
            issuer: The token issuer URL.
            force_refresh: Force fetching fresh JWKS ignoring cache.

        Returns:
            The PyJWKSet for the issuer.

        Raises:
            JWKSFetchError: If the JWKS cannot be fetched.
        """
        if not force_refresh:
            cached = self._cache.get(issuer)
            if cached is not None and not cached.is_expired():
                _logger.debug("jwks_cache_hit", issuer=issuer)
                return cached.jwks

        _logger.debug("jwks_cache_miss", issuer=issuer, force_refresh=force_refresh)

        # A failed fetch raises before the assignment, so an existing
        # (possibly expired) entry is left untouched.
        jwks = await self._fetch_jwks(issuer)
        self._cache[issuer] = _CachedJWKS(jwks=jwks, ttl=self._cache_ttl)
        return jwks

    async def _fetch_jwks(self, issuer: str) -> PyJWKSet:
        """Fetch JWKS from an issuer's well-known endpoint.

        First fetches the OpenID Connect discovery document to get the
        ``jwks_uri``, then fetches the JWKS from that URI.

        Args:
            issuer: The token issuer URL.

        Returns:
            The PyJWKSet containing the signing keys.

        Raises:
            JWKSFetchError: If either document cannot be retrieved, is
                not a JSON object, lacks ``jwks_uri``, or does not
                contain a usable key set.
        """
        client = await self._get_http_client()

        # Construct OpenID Connect discovery URL
        discovery_url = self._get_discovery_url(issuer)
        _logger.debug("fetching_oidc_discovery", url=discovery_url)

        try:
            # Fetch discovery document
            response = await client.get(discovery_url)
            response.raise_for_status()
            discovery = response.json()
            if not isinstance(discovery, dict):
                raise JWKSFetchError(issuer, "Discovery document is not a JSON object")

            jwks_uri = discovery.get("jwks_uri")
            if not jwks_uri:
                raise JWKSFetchError(issuer, "Discovery document missing 'jwks_uri'")

            _logger.debug("fetching_jwks", jwks_uri=jwks_uri)

            # Fetch JWKS
            response = await client.get(jwks_uri)
            response.raise_for_status()
            jwks_data = response.json()
            if not isinstance(jwks_data, dict):
                raise JWKSFetchError(issuer, "JWKS document is not a JSON object")

            return PyJWKSet.from_dict(jwks_data)

        except httpx.HTTPStatusError as e:
            _logger.error("jwks_http_error", issuer=issuer, status=e.response.status_code)
            raise JWKSFetchError(issuer, f"HTTP {e.response.status_code}") from e
        except (httpx.RequestError, httpx.InvalidURL) as e:
            # InvalidURL is not a RequestError subclass, but an unparsable
            # issuer or jwks_uri is still a failure to fetch.
            _logger.error("jwks_request_error", issuer=issuer, error=str(e))
            raise JWKSFetchError(issuer, str(e)) from e
        except (KeyError, ValueError, jwt.exceptions.PyJWKSetError) as e:
            # ValueError covers invalid JSON; PyJWKSetError is what PyJWT
            # raises for an empty key set or one with no usable keys.
            _logger.error("jwks_parse_error", issuer=issuer, error=str(e))
            raise JWKSFetchError(issuer, f"Invalid JWKS format: {e}") from e

    def _get_discovery_url(self, issuer: str) -> str:
        """Construct the OpenID Connect discovery URL for an issuer.

        Strips any trailing slashes from the issuer and appends
        ``/.well-known/openid-configuration``. An issuer with a path
        component (e.g. one ending in ``/v2.0``) keeps that path in
        front of the well-known suffix.

        Args:
            issuer: The token issuer URL.

        Returns:
            The discovery document URL.
        """
        base = issuer.rstrip("/")
        return f"{base}/.well-known/openid-configuration"

    def _find_key_by_kid(self, jwks: PyJWKSet, kid: str, issuer: str) -> PyJWK:
        """Find a key in the JWKS by key ID.

        This is a pure lookup and does not refresh the JWKS; refreshing
        on a miss is done by :meth:`get_signing_key_with_retry`, which
        recognises the error message raised here.

        Args:
            jwks: The JWKS to search.
            kid: The key ID to find.
            issuer: The issuer (used for logging only).

        Returns:
            The first PyJWK whose ``key_id`` equals ``kid``.

        Raises:
            InvalidTokenError: If no matching key is found.
        """
        match = next((key for key in jwks.keys if key.key_id == kid), None)
        if match is None:
            _logger.warning("key_not_found_in_jwks", kid=kid, issuer=issuer)
            # get_signing_key_with_retry matches on this message prefix.
            raise InvalidTokenError(f"No key found for kid '{kid}'")
        return match

    async def get_signing_key_with_retry(self, token: str) -> RSAPublicKey:
        """Get signing key with automatic retry on cache miss.

        If the key is not found in the cached JWKS, this method forces a
        refresh of the issuer's JWKS and tries again once. This handles
        key rotation scenarios.

        Args:
            token: The JWT token string.

        Returns:
            The RSA public key for verifying the token signature.

        Raises:
            InvalidTokenError: If the token is invalid, or the key is
                still not found after the refresh.
            JWKSFetchError: If the JWKS cannot be fetched.
        """
        try:
            return await self.get_signing_key(token)
        except InvalidTokenError as e:
            # Only a missing key is worth retrying; any other token
            # problem is re-raised unchanged.
            if "No key found for kid" not in str(e):
                raise

            # Key not found - might be key rotation, try refreshing JWKS
            _logger.info("retrying_jwks_fetch_after_key_not_found")

            # The token already decoded successfully inside
            # get_signing_key, so decoding it again here cannot fail.
            unverified_payload = jwt.decode(token, options={"verify_signature": False})
            issuer = unverified_payload.get("iss")
            if not issuer:
                raise

            # Force refresh JWKS, then retry once against the fresh cache.
            await self._get_jwks(issuer, force_refresh=True)
            return await self.get_signing_key(token)

    def clear_cache(self) -> None:
        """Clear the JWKS cache."""
        self._cache.clear()
        _logger.debug("jwks_cache_cleared")


class _CachedJWKS:
    """Internal class for caching JWKS with expiry."""

    def __init__(self, jwks: PyJWKSet, ttl: int) -> None:
        """Initialize cached JWKS.

        Args:
            jwks: The JWKS to cache.
            ttl: Time-to-live in seconds.
        """
        self.jwks = jwks
        # Absolute wall-clock timestamp (time.time()) at which the entry expires.
        self.expires_at = time.time() + ttl

    def is_expired(self) -> bool:
        """Check if the cache entry has expired."""
        return time.time() >= self.expires_at

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/DauQuangThanh/sso-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.