"""JWKS client for fetching and caching signing keys.

Fetches JSON Web Key Sets (JWKS) from OAuth/OpenID Connect authorization
servers, locating them via OpenID Connect discovery, and caches them for
efficient token validation.
"""
from __future__ import annotations

import time
from typing import TYPE_CHECKING

import httpx
import jwt
from jwt import PyJWK, PyJWKSet

from sso_mcp_server import get_logger
from sso_mcp_server.auth.exceptions import InvalidTokenError, JWKSFetchError

if TYPE_CHECKING:
    from collections.abc import Iterable

    from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
_logger = get_logger("jwks_client")

# Timeout, in seconds, handed to httpx.AsyncClient. Given as a single float,
# httpx applies it separately to the connect, read, write and pool phases.
HTTP_TIMEOUT: float = 10.0
class JWKSClient:
    """Client for fetching and caching JWKS from OAuth/OIDC issuers.

    Signing keys are located through OpenID Connect discovery
    (``{issuer}/.well-known/openid-configuration`` -> ``jwks_uri``), and the
    resulting key set is cached per issuer for ``cache_ttl`` seconds.
    :meth:`get_signing_key_with_retry` forces a single refresh when a token's
    ``kid`` is absent from the cached set, which covers issuer key rotation.

    Security: the issuer URL is taken from the *unverified* token. Without an
    ``allowed_issuers`` allowlist, this client contacts any host named in
    ``iss`` (SSRF risk) and returns whatever key that host publishes. A forged
    token could then pass signature verification unless the caller rejects
    untrusted issuers independently. Configure ``allowed_issuers`` in
    production.
    """

    def __init__(
        self,
        cache_ttl: int = 3600,
        *,
        allowed_issuers: Iterable[str] | None = None,
    ) -> None:
        """Initialize JWKS client.

        Args:
            cache_ttl: Cache time-to-live in seconds (default: 1 hour).
            allowed_issuers: Optional exact-match allowlist of issuer URLs.
                When given, tokens whose ``iss`` claim is not listed are
                rejected before any network request is made. ``None`` (the
                default) keeps the legacy behaviour of trusting any issuer
                named in the token.
        """
        self._cache_ttl = cache_ttl
        self._cache: dict[str, _CachedJWKS] = {}
        self._http_client: httpx.AsyncClient | None = None
        if isinstance(allowed_issuers, str):
            # A bare string would otherwise become a set of single characters.
            allowed_issuers = (allowed_issuers,)
        # Frozen copy, so later mutation of the caller's collection has no effect.
        self._allowed_issuers: frozenset[str] | None = (
            None if allowed_issuers is None else frozenset(allowed_issuers)
        )
        _logger.debug("jwks_client_initialized", cache_ttl=cache_ttl)

    async def _get_http_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it on first use."""
        if self._http_client is None:
            self._http_client = httpx.AsyncClient(timeout=HTTP_TIMEOUT)
        return self._http_client

    async def close(self) -> None:
        """Close the HTTP client.

        A new client is created lazily if the instance is used again.
        """
        if self._http_client is not None:
            await self._http_client.aclose()
            self._http_client = None

    async def get_signing_key(self, token: str) -> RSAPublicKey:
        """Get the signing key for a JWT token.

        Extracts the key ID (``kid``) from the token header and the issuer
        (``iss``) from the payload, both *without* verifying the signature.
        It then fetches that issuer's JWKS (with caching) and returns the
        matching public key. It does not retry on an unknown ``kid``; see
        :meth:`get_signing_key_with_retry` for that.

        Args:
            token: The JWT token string.

        Returns:
            The ``PyJWK.key`` of the matching JWK. It is annotated as RSA, but
            for non-RSA JWKs it is presumably the corresponding key type
            (TODO confirm against callers).

        Raises:
            InvalidTokenError: If the token is malformed, lacks or has an
                invalid ``kid``/``iss``, names an issuer outside
                ``allowed_issuers``, or no key matches ``kid``.
            JWKSFetchError: If the JWKS cannot be fetched from the issuer.
        """
        kid, issuer = self._extract_kid_and_issuer(token)
        jwks = await self._get_jwks(issuer)
        return self._find_key_by_kid(jwks, kid, issuer).key

    def _extract_kid_and_issuer(self, token: str) -> tuple[str, str]:
        """Read and validate ``kid`` (header) and ``iss`` (payload), unverified.

        Args:
            token: The JWT token string.

        Returns:
            A ``(kid, issuer)`` tuple of non-empty strings.

        Raises:
            InvalidTokenError: If the token cannot be decoded, either value is
                missing or not a string, or the issuer is not allowlisted.
        """
        try:
            unverified_header = jwt.get_unverified_header(token)
            unverified_payload = jwt.decode(token, options={"verify_signature": False})
        except jwt.exceptions.PyJWTError as e:
            # Catch the PyJWT base class rather than only DecodeError: PyJWT can
            # raise other InvalidTokenError subclasses for malformed headers or
            # claims, and all of them mean "malformed token" here.
            _logger.warning("failed_to_decode_token_header", error=str(e))
            raise InvalidTokenError("Malformed JWT token") from e

        kid = unverified_header.get("kid")
        issuer = unverified_payload.get("iss")

        if not kid:
            raise InvalidTokenError("Token missing 'kid' (key ID) in header")
        if not issuer:
            raise InvalidTokenError("Token missing 'iss' (issuer) claim")
        # Both values are attacker-controlled. A non-string issuer would
        # otherwise crash URL construction with AttributeError.
        if not isinstance(kid, str):
            raise InvalidTokenError("Token 'kid' (key ID) header must be a string")
        if not isinstance(issuer, str):
            raise InvalidTokenError("Token 'iss' (issuer) claim must be a string")
        if self._allowed_issuers is not None and issuer not in self._allowed_issuers:
            # Reject before any network I/O so untrusted hosts are never contacted.
            _logger.warning("untrusted_issuer_rejected", issuer=issuer)
            raise InvalidTokenError("Token issuer is not trusted")

        _logger.debug("extracting_signing_key", kid=kid, issuer=issuer)
        return kid, issuer

    async def _get_jwks(self, issuer: str, force_refresh: bool = False) -> PyJWKSet:
        """Get JWKS for an issuer (with caching).

        Args:
            issuer: The token issuer URL.
            force_refresh: Force fetching a fresh JWKS, ignoring the cache.

        Returns:
            The PyJWKSet for the issuer. A freshly fetched set replaces any
            cached entry.

        Raises:
            JWKSFetchError: If the JWKS cannot be fetched.
        """
        if not force_refresh and issuer in self._cache:
            cached = self._cache[issuer]
            if not cached.is_expired():
                _logger.debug("jwks_cache_hit", issuer=issuer)
                return cached.jwks

        _logger.debug("jwks_cache_miss", issuer=issuer, force_refresh=force_refresh)
        jwks = await self._fetch_jwks(issuer)
        self._cache[issuer] = _CachedJWKS(jwks=jwks, ttl=self._cache_ttl)
        return jwks

    async def _fetch_jwks(self, issuer: str) -> PyJWKSet:
        """Fetch JWKS from an issuer's well-known endpoint.

        First fetches the OpenID Connect discovery document to get the
        ``jwks_uri``, then fetches the JWKS from that URI.

        Args:
            issuer: The token issuer URL.

        Returns:
            The PyJWKSet containing the signing keys.

        Raises:
            JWKSFetchError: If either request fails, returns a non-2xx status,
                or yields a document that is not a usable JSON object / JWKS.
        """
        client = await self._get_http_client()

        discovery_url = self._get_discovery_url(issuer)
        _logger.debug("fetching_oidc_discovery", url=discovery_url)

        try:
            # Step 1: discovery document -> jwks_uri.
            response = await client.get(discovery_url)
            response.raise_for_status()
            discovery = response.json()
            if not isinstance(discovery, dict):
                raise JWKSFetchError(issuer, "Discovery document is not a JSON object")
            jwks_uri = discovery.get("jwks_uri")
            if not isinstance(jwks_uri, str) or not jwks_uri:
                raise JWKSFetchError(issuer, "Discovery document missing 'jwks_uri'")

            _logger.debug("fetching_jwks", jwks_uri=jwks_uri)

            # Step 2: the key set itself.
            response = await client.get(jwks_uri)
            response.raise_for_status()
            jwks_data = response.json()
            if not isinstance(jwks_data, dict):
                raise JWKSFetchError(issuer, "Invalid JWKS format: not a JSON object")
            return PyJWKSet.from_dict(jwks_data)
        except JWKSFetchError:
            # Our own, already-descriptive errors: never re-wrap them below.
            raise
        except httpx.HTTPStatusError as e:
            _logger.error("jwks_http_error", issuer=issuer, status=e.response.status_code)
            raise JWKSFetchError(issuer, f"HTTP {e.response.status_code}") from e
        except (httpx.RequestError, httpx.InvalidURL) as e:
            # InvalidURL is not a RequestError subclass. Without it, a malformed
            # issuer or jwks_uri would escape as a raw httpx exception.
            _logger.error("jwks_request_error", issuer=issuer, error=str(e))
            raise JWKSFetchError(issuer, str(e)) from e
        except (KeyError, ValueError, jwt.exceptions.PyJWTError) as e:
            # ValueError covers malformed JSON (json.JSONDecodeError).
            # PyJWTError covers PyJWKSet rejecting the set (e.g. no usable keys).
            _logger.error("jwks_parse_error", issuer=issuer, error=str(e))
            raise JWKSFetchError(issuer, f"Invalid JWKS format: {e}") from e

    def _get_discovery_url(self, issuer: str) -> str:
        """Construct the OpenID Connect discovery URL for an issuer.

        Follows OpenID Connect Discovery 1.0, section 4: the well-known suffix
        is appended after the full issuer, including any path. For example,
        ``https://login.microsoftonline.com/{tenant}/v2.0`` maps to
        ``https://login.microsoftonline.com/{tenant}/v2.0/.well-known/openid-configuration``.

        This differs from RFC 8414, which inserts the well-known segment
        between host and path. Only the OIDC form is tried.

        Args:
            issuer: The token issuer URL; a trailing slash is ignored.

        Returns:
            The discovery document URL.
        """
        return f"{issuer.rstrip('/')}/.well-known/openid-configuration"

    def _find_key_by_kid(self, jwks: PyJWKSet, kid: str, issuer: str) -> PyJWK:
        """Find a key in the JWKS by key ID.

        Performs a single lookup and never refreshes the JWKS. Refresh-on-miss
        is handled by :meth:`get_signing_key_with_retry`.

        Args:
            jwks: The JWKS to search.
            kid: The key ID to find.
            issuer: The issuer (used only for logging).

        Returns:
            The matching PyJWK.

        Raises:
            InvalidTokenError: If no matching key is found.
        """
        key = self._lookup_key(jwks, kid)
        if key is None:
            _logger.warning("key_not_found_in_jwks", kid=kid, issuer=issuer)
            raise InvalidTokenError(f"No key found for kid '{kid}'")
        return key

    @staticmethod
    def _lookup_key(jwks: PyJWKSet, kid: str) -> PyJWK | None:
        """Return the first key in *jwks* whose key ID equals *kid*, else None."""
        return next((key for key in jwks.keys if key.key_id == kid), None)

    async def get_signing_key_with_retry(self, token: str) -> RSAPublicKey:
        """Get the signing key, refreshing the JWKS once if ``kid`` is unknown.

        If the key is not in the (possibly cached) JWKS, the JWKS is re-fetched
        once, bypassing the cache, and the lookup is repeated. This handles
        key rotation.

        NOTE(review): every token with an unknown ``kid`` forces a refresh,
        which costs two HTTP requests to the issuer. Consider rate-limiting
        refreshes if this is reachable by untrusted traffic.

        Args:
            token: The JWT token string.

        Returns:
            The public key for verifying the token signature.

        Raises:
            InvalidTokenError: As for :meth:`get_signing_key`, including when
                the key is still missing after the refresh.
            JWKSFetchError: If the JWKS cannot be fetched.
        """
        kid, issuer = self._extract_kid_and_issuer(token)
        jwks = await self._get_jwks(issuer)
        key = self._lookup_key(jwks, kid)
        if key is None:
            # Possibly key rotation: the cached set predates the signing key.
            _logger.info("retrying_jwks_fetch_after_key_not_found")
            jwks = await self._get_jwks(issuer, force_refresh=True)
            key = self._find_key_by_kid(jwks, kid, issuer)
        return key.key

    def clear_cache(self) -> None:
        """Drop every cached JWKS entry."""
        self._cache.clear()
        _logger.debug("jwks_cache_cleared")
class _CachedJWKS:
    """Internal cache entry pairing a JWKS with its expiry deadline.

    The deadline is measured on ``time.monotonic()``, so wall-clock
    adjustments (NTP steps, manual clock changes) cannot make an entry live
    far longer than its TTL or expire prematurely.
    """

    def __init__(self, jwks: PyJWKSet, ttl: int) -> None:
        """Initialize cached JWKS.

        Args:
            jwks: The JWKS to cache.
            ttl: Time-to-live in seconds. ``ttl <= 0`` yields an entry that
                is already expired, i.e. effectively no caching.
        """
        self.jwks = jwks
        # Monotonic deadline: only meaningful when compared to time.monotonic().
        self.expires_at = time.monotonic() + ttl

    def is_expired(self) -> bool:
        """Return True once the entry's TTL has elapsed."""
        return time.monotonic() >= self.expires_at