"""Token validator for OAuth 2.1 Bearer tokens.
Validates JWT access tokens including signature verification,
standard claims validation, audience validation (RFC 8707),
and issuer validation.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import jwt
from jwt.exceptions import ExpiredSignatureError, InvalidSignatureError, PyJWTError
from sso_mcp_server import get_logger
from sso_mcp_server.auth.cloud.claims import TokenClaims
from sso_mcp_server.auth.exceptions import (
CloudTokenExpiredError,
InvalidAudienceError,
InvalidIssuerError,
InvalidTokenError,
TokenSignatureError,
)
if TYPE_CHECKING:
from sso_mcp_server.auth.cloud.jwks_client import JWKSClient
# Module-level logger used by TokenValidator; events are logged as an event
# name plus keyword fields (e.g. "token_validated", sub=..., iss=...).
_logger = get_logger("token_validator")
class TokenValidator:
    """Validates OAuth 2.1 Bearer tokens.

    Performs comprehensive token validation including:
    - Signature verification using JWKS
    - Standard claims validation (exp, nbf, iat)
    - Audience validation (RFC 8707 Resource Indicators)
    - Issuer validation against allowlist

    Constructor arguments are kept in private attributes
    (``_resource_identifier``, ``_allowed_issuers``, ``_jwks_client``,
    ``_algorithms``, ``_leeway``).
    """

    # Signature algorithms accepted when the caller does not supply a list.
    DEFAULT_ALGORITHMS: tuple[str, ...] = ("RS256", "RS384", "RS512")

    def __init__(
        self,
        resource_identifier: str,
        allowed_issuers: list[str],
        jwks_client: JWKSClient,
        *,
        algorithms: list[str] | tuple[str, ...] | None = None,
        leeway: float = 0,
    ) -> None:
        """Initialize token validator.

        Args:
            resource_identifier: This server's resource URL (for audience validation).
            allowed_issuers: List of allowed token issuer URLs. The list is
                stored by reference and re-read on every validation.
            jwks_client: JWKS client for fetching signing keys.
            algorithms: Accepted JWS algorithms. Defaults to
                ``DEFAULT_ALGORITHMS`` (RS256/RS384/RS512).
            leeway: Clock-skew tolerance in seconds applied to exp/nbf/iat.
                Defaults to 0 (no tolerance).

        Raises:
            ValueError: If ``algorithms`` is empty or ``leeway`` is negative.
        """
        if algorithms is None:
            resolved_algorithms = self.DEFAULT_ALGORITHMS
        elif isinstance(algorithms, str):
            # Guard against tuple("RS256") splitting a bare string into characters.
            resolved_algorithms = (algorithms,)
        else:
            resolved_algorithms = tuple(algorithms)
        if not resolved_algorithms:
            raise ValueError("algorithms must not be empty")
        if leeway < 0:
            raise ValueError("leeway must be non-negative")

        self._resource_identifier = resource_identifier
        self._allowed_issuers = allowed_issuers
        self._jwks_client = jwks_client
        self._algorithms = resolved_algorithms
        self._leeway = leeway
        _logger.debug(
            "token_validator_initialized",
            resource_identifier=resource_identifier,
            allowed_issuers=allowed_issuers,
            algorithms=list(resolved_algorithms),
        )

    async def validate(self, token: str) -> TokenClaims:
        """Validate a Bearer token and return claims.

        Performs the following validations:
        1. Signature verification using JWKS
        2. Standard claims (exp, nbf, iat)
        3. Audience matches resource_identifier
        4. Issuer is in allowed_issuers list

        Args:
            token: The JWT access token string.

        Returns:
            TokenClaims with the validated token's claims.

        Raises:
            InvalidTokenError: If the token is malformed.
            TokenSignatureError: If signature verification fails.
            CloudTokenExpiredError: If the token has expired.
            InvalidAudienceError: If the audience doesn't match.
            InvalidIssuerError: If the issuer is not trusted.
        """
        _logger.debug("validating_token")

        signing_key = await self._get_signing_key(token)
        payload = self._decode_token(token, signing_key)

        # iss/aud are only inspected after the signature has been verified,
        # so their values cannot be forged by the presenter of the token.
        self._validate_issuer(payload.get("iss", ""))
        self._validate_audience(payload.get("aud", ""))

        # NOTE(review): "sub" is not in the required-claims list passed to
        # jwt.decode; confirm TokenClaims.from_jwt_payload tolerates its absence.
        claims = TokenClaims.from_jwt_payload(payload)
        _logger.info(
            "token_validated",
            sub=claims.sub,
            iss=claims.iss,
            scopes=claims.scopes,
        )
        return claims

    async def _get_signing_key(self, token: str):
        """Fetch the signing key for *token* from the JWKS client (with retry).

        Raises:
            InvalidTokenError: If the JWKS client cannot supply a key.
        """
        try:
            return await self._jwks_client.get_signing_key_with_retry(token)
        except InvalidTokenError:
            # Already a domain error; propagate unchanged.
            raise
        except Exception as e:
            # Boundary: any other failure while resolving the key (network,
            # JWKS parsing, unknown kid, ...) is reported as an invalid token.
            _logger.error("failed_to_get_signing_key", error=str(e))
            raise InvalidTokenError(f"Failed to get signing key: {e}") from e

    def _decode_token(self, token: str, signing_key) -> dict:
        """Verify the signature and time-based claims, returning the payload.

        PyJWT's audience check is disabled because audience matching is done
        in ``_validate_audience`` (trailing-slash normalization, list audiences).

        Raises:
            CloudTokenExpiredError: If ``exp`` is in the past (beyond leeway).
            TokenSignatureError: If the signature does not verify.
            InvalidTokenError: For any other decoding/validation failure.
        """
        try:
            return jwt.decode(
                token,
                key=signing_key,
                algorithms=list(self._algorithms),
                leeway=self._leeway,
                options={
                    "verify_signature": True,
                    "verify_exp": True,
                    "verify_nbf": True,
                    "verify_iat": True,
                    "verify_aud": False,  # We validate audience ourselves
                    "require": ["exp", "iat", "iss", "aud"],
                },
            )
        # Specific PyJWT subclasses must be caught before their base PyJWTError.
        except ExpiredSignatureError as e:
            _logger.warning("token_expired")
            raise CloudTokenExpiredError() from e
        except InvalidSignatureError as e:
            _logger.warning("invalid_token_signature")
            raise TokenSignatureError() from e
        except PyJWTError as e:
            _logger.warning("token_validation_failed", error=str(e))
            raise InvalidTokenError(str(e)) from e

    def _validate_issuer(self, issuer: object) -> None:
        """Validate the token issuer against the allowlist.

        Trailing slashes are ignored on both sides of the comparison. A
        non-string ``iss`` value is always rejected.

        Args:
            issuer: The issuer claim from the token.

        Raises:
            InvalidIssuerError: If the issuer is not a string or not in the allowlist.
        """
        normalized_allowed = {iss.rstrip("/") for iss in self._allowed_issuers}
        # A non-string iss (number, array, object) can never match; without this
        # guard .rstrip would raise AttributeError instead of a domain error.
        normalized_issuer = issuer.rstrip("/") if isinstance(issuer, str) else None
        if normalized_issuer is None or normalized_issuer not in normalized_allowed:
            _logger.warning(
                "invalid_issuer",
                issuer=issuer,
                allowed=self._allowed_issuers,
            )
            raise InvalidIssuerError(str(issuer))

    def _validate_audience(self, audience: object) -> None:
        """Validate the token audience matches this server's resource identifier.

        Per RFC 8707 (Resource Indicators), tokens must be bound to the
        specific resource they're intended for. Per RFC 7519 §4.1.3, ``aud``
        is either a single string or an array of strings; any other shape is
        rejected. Trailing slashes are ignored on both sides.

        Args:
            audience: The audience claim from the token.

        Raises:
            InvalidAudienceError: If the audience is malformed or doesn't match.
        """
        expected = self._resource_identifier.rstrip("/")

        if isinstance(audience, str):
            candidates = [audience]
        elif isinstance(audience, list) and all(isinstance(a, str) for a in audience):
            candidates = audience
        else:
            # Malformed aud. Notably a JSON object must not be iterated: that
            # would compare its *keys* and could let a crafted token pass.
            candidates = []

        if expected not in {aud.rstrip("/") for aud in candidates}:
            _logger.warning(
                "invalid_audience",
                expected=expected,
                actual=audience,
            )
            raise InvalidAudienceError(self._resource_identifier, audience)
        _logger.debug("audience_validated", audience=audience)