Skip to main content
Glama
validator.py (6.69 kB)
"""Token validator for OAuth 2.1 Bearer tokens. Validates JWT access tokens including signature verification, standard claims validation, audience validation (RFC 8707), and issuer validation. """ from __future__ import annotations from typing import TYPE_CHECKING import jwt from jwt.exceptions import ExpiredSignatureError, InvalidSignatureError, PyJWTError from sso_mcp_server import get_logger from sso_mcp_server.auth.cloud.claims import TokenClaims from sso_mcp_server.auth.exceptions import ( CloudTokenExpiredError, InvalidAudienceError, InvalidIssuerError, InvalidTokenError, TokenSignatureError, ) if TYPE_CHECKING: from sso_mcp_server.auth.cloud.jwks_client import JWKSClient _logger = get_logger("token_validator") class TokenValidator: """Validates OAuth 2.1 Bearer tokens. Performs comprehensive token validation including: - Signature verification using JWKS - Standard claims validation (exp, nbf, iat) - Audience validation (RFC 8707 Resource Indicators) - Issuer validation against allowlist Attributes: resource_identifier: This server's resource URL for audience validation. allowed_issuers: List of trusted token issuers. """ def __init__( self, resource_identifier: str, allowed_issuers: list[str], jwks_client: JWKSClient, ) -> None: """Initialize token validator. Args: resource_identifier: This server's resource URL (for audience validation). allowed_issuers: List of allowed token issuer URLs. jwks_client: JWKS client for fetching signing keys. """ self._resource_identifier = resource_identifier self._allowed_issuers = allowed_issuers self._jwks_client = jwks_client _logger.debug( "token_validator_initialized", resource_identifier=resource_identifier, allowed_issuers=allowed_issuers, ) async def validate(self, token: str) -> TokenClaims: """Validate a Bearer token and return claims. Performs the following validations: 1. Signature verification using JWKS 2. Standard claims (exp, nbf, iat) 3. Audience matches resource_identifier 4. 
Issuer is in allowed_issuers list Args: token: The JWT access token string. Returns: TokenClaims with the validated token's claims. Raises: InvalidTokenError: If the token is malformed. TokenSignatureError: If signature verification fails. CloudTokenExpiredError: If the token has expired. InvalidAudienceError: If the audience doesn't match. InvalidIssuerError: If the issuer is not trusted. """ _logger.debug("validating_token") # Get signing key (with retry for key rotation) try: signing_key = await self._jwks_client.get_signing_key_with_retry(token) except InvalidTokenError: raise except Exception as e: _logger.error("failed_to_get_signing_key", error=str(e)) raise InvalidTokenError(f"Failed to get signing key: {e}") from e # Decode and verify the token # We disable PyJWT's audience validation to handle it ourselves # (we need to normalize trailing slashes and handle list audiences) try: payload = jwt.decode( token, key=signing_key, algorithms=["RS256", "RS384", "RS512"], options={ "verify_signature": True, "verify_exp": True, "verify_nbf": True, "verify_iat": True, "verify_aud": False, # We validate audience ourselves "require": ["exp", "iat", "iss", "aud"], }, ) except ExpiredSignatureError as e: _logger.warning("token_expired") raise CloudTokenExpiredError() from e except InvalidSignatureError as e: _logger.warning("invalid_token_signature") raise TokenSignatureError() from e except PyJWTError as e: _logger.warning("token_validation_failed", error=str(e)) raise InvalidTokenError(str(e)) from e # Validate issuer issuer = payload.get("iss", "") self._validate_issuer(issuer) # Validate audience audience = payload.get("aud", "") self._validate_audience(audience) # Build claims claims = TokenClaims.from_jwt_payload(payload) _logger.info( "token_validated", sub=claims.sub, iss=claims.iss, scopes=claims.scopes, ) return claims def _validate_issuer(self, issuer: str) -> None: """Validate the token issuer against the allowlist. 
Args: issuer: The issuer claim from the token. Raises: InvalidIssuerError: If the issuer is not in the allowlist. """ # Normalize issuers for comparison (remove trailing slashes) normalized_issuer = issuer.rstrip("/") normalized_allowed = [iss.rstrip("/") for iss in self._allowed_issuers] if normalized_issuer not in normalized_allowed: _logger.warning( "invalid_issuer", issuer=issuer, allowed=self._allowed_issuers, ) raise InvalidIssuerError(issuer) def _validate_audience(self, audience: str | list[str]) -> None: """Validate the token audience matches this server's resource identifier. Per RFC 8707 (Resource Indicators), tokens must be bound to the specific resource they're intended for. Args: audience: The audience claim from the token (string or list). Raises: InvalidAudienceError: If the audience doesn't match. """ # Normalize resource identifier expected = self._resource_identifier.rstrip("/") # Handle both string and list audience audiences = [audience] if isinstance(audience, str) else audience # Normalize and check normalized_audiences = [aud.rstrip("/") for aud in audiences] if expected not in normalized_audiences: _logger.warning( "invalid_audience", expected=expected, actual=audience, ) raise InvalidAudienceError(self._resource_identifier, audience) _logger.debug("audience_validated", audience=audience)

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/DauQuangThanh/sso-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.