We provide all the information about MCP servers via our MCP API.
curl -X GET 'https://glama.ai/api/mcp/v1/servers/DauQuangThanh/sso-mcp-server'
If you have feedback or need assistance with the MCP directory API, please join our Discord server.
"""JWKS client for fetching and caching signing keys.
Fetches JSON Web Key Sets (JWKS) from OAuth authorization servers
and caches them for efficient token validation.
"""
from __future__ import annotations
import time
from typing import TYPE_CHECKING
import httpx
import jwt
from jwt import PyJWK, PyJWKSet
from sso_mcp_server import get_logger
from sso_mcp_server.auth.exceptions import InvalidTokenError, JWKSFetchError
if TYPE_CHECKING:
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
# Module-level logger, obtained from the project's get_logger helper under the
# name "jwks_client"; every log call in this module goes through it.
_logger = get_logger("jwks_client")
# Timeout handed to the shared httpx.AsyncClient, so it applies to both the
# discovery-document request and the JWKS request. httpx interprets a bare
# float timeout as a number of seconds.
HTTP_TIMEOUT = 10.0
class JWKSClient:
    """Client for fetching and caching JWKS from OAuth issuers.

    JWKS documents are cached per issuer for ``cache_ttl`` seconds.
    ``get_signing_key_with_retry`` additionally forces one refresh when a
    token's key ID is absent from the JWKS, to cope with key rotation.

    NOTE(review): the issuer is read from the *unverified* token payload and
    is used both to build the discovery URL and as the cache key. A crafted
    token can therefore make this client send requests to, and cache an entry
    for, an arbitrary URL. Confirm that callers check the issuer against an
    allow-list before (or in addition to) using this class.
    """

    # Single source of truth for the "unknown kid" message: used to build the
    # error in _find_key_by_kid and to recognise it in
    # get_signing_key_with_retry, so the two cannot drift apart.
    _KEY_NOT_FOUND_PREFIX = "No key found for kid"

    def __init__(self, cache_ttl: int = 3600) -> None:
        """Initialize JWKS client.

        Args:
            cache_ttl: Cache time-to-live in seconds (default: 1 hour).
        """
        self._cache_ttl = cache_ttl
        self._cache: dict[str, _CachedJWKS] = {}
        # Created lazily by _get_http_client so that constructing the client
        # does not open any connection pool.
        self._http_client: httpx.AsyncClient | None = None
        _logger.debug("jwks_client_initialized", cache_ttl=cache_ttl)

    async def _get_http_client(self) -> httpx.AsyncClient:
        """Get or create HTTP client (lazy initialization)."""
        if self._http_client is None:
            self._http_client = httpx.AsyncClient(timeout=HTTP_TIMEOUT)
        return self._http_client

    async def close(self) -> None:
        """Close the HTTP client, if one was created."""
        if self._http_client is not None:
            await self._http_client.aclose()
            self._http_client = None

    async def get_signing_key(self, token: str) -> RSAPublicKey:
        """Get the signing key for a JWT token.

        Reads the key ID (kid) from the token header and the issuer (iss)
        from the token payload without verifying the signature, obtains the
        issuer's JWKS (cached), and returns the matching public key.

        Args:
            token: The JWT token string.

        Returns:
            The public key object of the JWK whose key ID matches the token.

        Raises:
            InvalidTokenError: If the token is malformed, lacks a usable
                ``kid`` or ``iss``, or no key matches the ``kid``.
            JWKSFetchError: If the JWKS cannot be fetched from the issuer.
        """
        # Decode header and payload without verification to get kid and iss.
        # The base jwt InvalidTokenError is caught (not only DecodeError)
        # because PyJWT's header validation raises it directly, e.g. for a
        # non-string kid.
        try:
            unverified_header = jwt.get_unverified_header(token)
            unverified_payload = jwt.decode(token, options={"verify_signature": False})
        except jwt.exceptions.InvalidTokenError as e:
            _logger.warning("failed_to_decode_token_header", error=str(e))
            raise InvalidTokenError("Malformed JWT token") from e

        kid = unverified_header.get("kid")
        issuer = unverified_payload.get("iss")

        if not kid:
            raise InvalidTokenError("Token missing 'kid' (key ID) in header")
        if not issuer:
            raise InvalidTokenError("Token missing 'iss' (issuer) claim")
        if not isinstance(issuer, str):
            # A list/dict/number issuer would otherwise fail later with a
            # TypeError or AttributeError (cache lookup, rstrip).
            raise InvalidTokenError("Token 'iss' (issuer) claim must be a string")

        _logger.debug("extracting_signing_key", kid=kid, issuer=issuer)

        # Get JWKS for this issuer, then find the key matching the kid.
        jwks = await self._get_jwks(issuer)
        key = self._find_key_by_kid(jwks, kid, issuer)
        return key.key

    async def _get_jwks(self, issuer: str, force_refresh: bool = False) -> PyJWKSet:
        """Get JWKS for an issuer (with caching).

        Args:
            issuer: The token issuer URL.
            force_refresh: Force fetching fresh JWKS ignoring cache.

        Returns:
            The PyJWKSet for the issuer.

        Raises:
            JWKSFetchError: If the JWKS cannot be fetched.
        """
        if not force_refresh:
            cached = self._cache.get(issuer)
            if cached is not None and not cached.is_expired():
                _logger.debug("jwks_cache_hit", issuer=issuer)
                return cached.jwks

        _logger.debug("jwks_cache_miss", issuer=issuer, force_refresh=force_refresh)

        # Fetch, then overwrite whatever entry (if any) was cached before.
        jwks = await self._fetch_jwks(issuer)
        self._cache[issuer] = _CachedJWKS(jwks=jwks, ttl=self._cache_ttl)
        return jwks

    async def _fetch_jwks(self, issuer: str) -> PyJWKSet:
        """Fetch JWKS from an issuer's well-known endpoint.

        First fetches the OpenID Connect discovery document to get the
        ``jwks_uri``, then fetches the JWKS from that URI.

        Args:
            issuer: The token issuer URL.

        Returns:
            The PyJWKSet containing the signing keys.

        Raises:
            JWKSFetchError: If either request fails, a response is not the
                expected JSON object, or the JWKS contains no usable keys.
        """
        client = await self._get_http_client()

        discovery_url = self._get_discovery_url(issuer)
        _logger.debug("fetching_oidc_discovery", url=discovery_url)

        try:
            # Fetch discovery document
            response = await client.get(discovery_url)
            response.raise_for_status()
            discovery = response.json()

            # Valid JSON that is not an object (e.g. a list) has no .get().
            jwks_uri = discovery.get("jwks_uri") if isinstance(discovery, dict) else None
            if not jwks_uri or not isinstance(jwks_uri, str):
                raise JWKSFetchError(issuer, "Discovery document missing 'jwks_uri'")

            _logger.debug("fetching_jwks", jwks_uri=jwks_uri)

            # Fetch JWKS
            response = await client.get(jwks_uri)
            response.raise_for_status()
            jwks_data = response.json()
            if not isinstance(jwks_data, dict):
                # Routed through the ValueError handler below.
                raise ValueError("JWKS document is not a JSON object")

            return PyJWKSet.from_dict(jwks_data)

        except httpx.HTTPStatusError as e:
            _logger.error("jwks_http_error", issuer=issuer, status=e.response.status_code)
            raise JWKSFetchError(issuer, f"HTTP {e.response.status_code}") from e
        except (httpx.RequestError, httpx.InvalidURL) as e:
            # InvalidURL is not a RequestError subclass; an unparsable issuer
            # or jwks_uri would otherwise escape unwrapped.
            _logger.error("jwks_request_error", issuer=issuer, error=str(e))
            raise JWKSFetchError(issuer, str(e)) from e
        except jwt.exceptions.PyJWTError as e:
            # PyJWKSet raises PyJWKSetError (a PyJWTError, not a ValueError)
            # for an empty key list or a set with no usable keys.
            _logger.error("jwks_parse_error", issuer=issuer, error=str(e))
            raise JWKSFetchError(issuer, f"Invalid JWKS format: {e}") from e
        except (KeyError, ValueError) as e:
            _logger.error("jwks_parse_error", issuer=issuer, error=str(e))
            raise JWKSFetchError(issuer, f"Invalid JWKS format: {e}") from e

    def _get_discovery_url(self, issuer: str) -> str:
        """Construct the OpenID Connect discovery URL for an issuer.

        The well-known suffix is appended to the full issuer URL, path
        included (the OpenID Connect Discovery form), after stripping any
        trailing slashes. For an Azure AD v2.0 issuer ending in ``/v2.0``
        this gives ``.../v2.0/.well-known/openid-configuration``.

        Args:
            issuer: The token issuer URL.

        Returns:
            The discovery document URL.
        """
        base = issuer.rstrip("/")
        return f"{base}/.well-known/openid-configuration"

    def _find_key_by_kid(self, jwks: PyJWKSet, kid: str, issuer: str) -> PyJWK:
        """Find a key in the JWKS by key ID.

        This method only searches the given key set; it does not refresh it.
        The refresh-and-retry behaviour for key rotation lives in
        ``get_signing_key_with_retry``.

        Args:
            jwks: The JWKS to search.
            kid: The key ID to find.
            issuer: The issuer (used for logging only).

        Returns:
            The matching PyJWK.

        Raises:
            InvalidTokenError: If no matching key is found.
        """
        for key in jwks.keys:
            if key.key_id == kid:
                return key

        _logger.warning("key_not_found_in_jwks", kid=kid, issuer=issuer)
        raise InvalidTokenError(f"{self._KEY_NOT_FOUND_PREFIX} '{kid}'")

    async def get_signing_key_with_retry(self, token: str) -> RSAPublicKey:
        """Get signing key, refreshing the JWKS once if the kid is unknown.

        If the key is not found in the JWKS, this method forces a refresh of
        the issuer's JWKS and tries again once. This handles key rotation
        scenarios. Any other failure is re-raised unchanged.

        Args:
            token: The JWT token string.

        Returns:
            The public key object of the JWK whose key ID matches the token.

        Raises:
            InvalidTokenError: If the token is invalid, or the key is still
                not found after the refresh.
            JWKSFetchError: If the JWKS cannot be fetched.
        """
        try:
            return await self.get_signing_key(token)
        except InvalidTokenError as e:
            if self._KEY_NOT_FOUND_PREFIX not in str(e):
                raise

            # Key not found - might be key rotation, try refreshing JWKS
            _logger.info("retrying_jwks_fetch_after_key_not_found")

            # The token already decoded successfully inside get_signing_key,
            # so this second unverified decode is only re-reading the issuer.
            unverified_payload = jwt.decode(token, options={"verify_signature": False})
            issuer = unverified_payload.get("iss")
            if issuer:
                await self._get_jwks(issuer, force_refresh=True)
                return await self.get_signing_key(token)
            raise

    def clear_cache(self) -> None:
        """Clear the JWKS cache."""
        self._cache.clear()
        _logger.debug("jwks_cache_cleared")
class _CachedJWKS:
    """A fetched JWKS paired with the wall-clock time at which it goes stale."""

    def __init__(self, jwks: PyJWKSet, ttl: int) -> None:
        """Record the key set and compute its expiry timestamp.

        Args:
            jwks: The JWKS to cache.
            ttl: Time-to-live in seconds, counted from now.
        """
        created_at = time.time()
        self.jwks = jwks
        self.expires_at = created_at + ttl

    def is_expired(self) -> bool:
        """Return True once the current time has reached ``expires_at``."""
        return self.expires_at <= time.time()