Skip to main content
Glama
test_token_validator.py (9.14 kB)
"""Unit tests for TokenValidator.""" import time from unittest.mock import AsyncMock, MagicMock import jwt import pytest from cryptography.hazmat.primitives.asymmetric import rsa from sso_mcp_server.auth.cloud.claims import TokenClaims from sso_mcp_server.auth.cloud.validator import TokenValidator from sso_mcp_server.auth.exceptions import ( CloudTokenExpiredError, InvalidAudienceError, InvalidIssuerError, InvalidTokenError, TokenSignatureError, ) @pytest.fixture def rsa_key_pair(): """Generate RSA key pair for testing.""" private_key = rsa.generate_private_key( public_exponent=65537, key_size=2048, ) public_key = private_key.public_key() return private_key, public_key @pytest.fixture def create_test_token(rsa_key_pair): """Factory fixture to create test JWT tokens.""" private_key, _ = rsa_key_pair def _create_token( sub="user123", iss="https://issuer.example.com", aud="https://resource.example.com", exp=None, iat=None, scopes=None, kid="test-key-id", **extra_claims, ): now = int(time.time()) payload = { "sub": sub, "iss": iss, "aud": aud, "exp": exp or now + 3600, "iat": iat or now, **extra_claims, } if scopes: payload["scp"] = " ".join(scopes) return jwt.encode( payload, private_key, algorithm="RS256", headers={"kid": kid}, ) return _create_token @pytest.fixture def mock_jwks_client(rsa_key_pair): """Create a mock JWKS client that returns our test key.""" _, public_key = rsa_key_pair client = MagicMock() client.get_signing_key_with_retry = AsyncMock(return_value=public_key) return client @pytest.fixture def validator(mock_jwks_client): """Create TokenValidator with mock JWKS client.""" return TokenValidator( resource_identifier="https://resource.example.com", allowed_issuers=["https://issuer.example.com"], jwks_client=mock_jwks_client, ) class TestTokenValidatorValidate: """Tests for TokenValidator.validate method.""" @pytest.mark.asyncio async def test_validates_valid_token(self, validator, create_test_token): """Test validating a valid token returns claims.""" 
token = create_test_token() claims = await validator.validate(token) assert isinstance(claims, TokenClaims) assert claims.sub == "user123" assert claims.iss == "https://issuer.example.com" assert claims.aud == "https://resource.example.com" @pytest.mark.asyncio async def test_extracts_scopes_from_token(self, validator, create_test_token): """Test scopes are extracted from token.""" token = create_test_token(scopes=["read", "write"]) claims = await validator.validate(token) assert claims.scopes == ["read", "write"] @pytest.mark.asyncio async def test_rejects_expired_token(self, validator, create_test_token): """Test expired token raises CloudTokenExpiredError.""" expired_token = create_test_token(exp=int(time.time()) - 100) with pytest.raises(CloudTokenExpiredError): await validator.validate(expired_token) @pytest.mark.asyncio async def test_rejects_wrong_audience(self, validator, create_test_token): """Test wrong audience raises InvalidAudienceError.""" token = create_test_token(aud="https://wrong-resource.example.com") with pytest.raises(InvalidAudienceError): await validator.validate(token) @pytest.mark.asyncio async def test_rejects_wrong_issuer(self, validator, create_test_token): """Test wrong issuer raises InvalidIssuerError.""" token = create_test_token(iss="https://untrusted.example.com") with pytest.raises(InvalidIssuerError): await validator.validate(token) @pytest.mark.asyncio async def test_rejects_invalid_signature(self, validator): """Test invalid signature raises TokenSignatureError.""" # Create token with different key other_private_key = rsa.generate_private_key( public_exponent=65537, key_size=2048, ) token = jwt.encode( { "sub": "user123", "iss": "https://issuer.example.com", "aud": "https://resource.example.com", "exp": int(time.time()) + 3600, "iat": int(time.time()), }, other_private_key, algorithm="RS256", headers={"kid": "other-key"}, ) with pytest.raises(TokenSignatureError): await validator.validate(token) class 
TestTokenValidatorAudienceValidation: """Tests for audience validation.""" @pytest.mark.asyncio async def test_accepts_matching_string_audience(self, mock_jwks_client, create_test_token): """Test accepts token with matching string audience.""" validator = TokenValidator( resource_identifier="https://resource.example.com", allowed_issuers=["https://issuer.example.com"], jwks_client=mock_jwks_client, ) token = create_test_token(aud="https://resource.example.com") claims = await validator.validate(token) assert claims is not None @pytest.mark.asyncio async def test_accepts_matching_audience_in_list(self, mock_jwks_client, create_test_token): """Test accepts token with matching audience in list.""" validator = TokenValidator( resource_identifier="https://resource.example.com", allowed_issuers=["https://issuer.example.com"], jwks_client=mock_jwks_client, ) token = create_test_token( aud=["https://other.example.com", "https://resource.example.com"] ) claims = await validator.validate(token) assert claims is not None @pytest.mark.asyncio async def test_normalizes_trailing_slashes(self, mock_jwks_client, create_test_token): """Test audience comparison normalizes trailing slashes.""" validator = TokenValidator( resource_identifier="https://resource.example.com/", allowed_issuers=["https://issuer.example.com"], jwks_client=mock_jwks_client, ) token = create_test_token(aud="https://resource.example.com") claims = await validator.validate(token) assert claims is not None class TestTokenValidatorIssuerValidation: """Tests for issuer validation.""" @pytest.mark.asyncio async def test_accepts_issuer_in_allowlist(self, mock_jwks_client, create_test_token): """Test accepts issuer that's in the allowlist.""" validator = TokenValidator( resource_identifier="https://resource.example.com", allowed_issuers=[ "https://issuer1.example.com", "https://issuer.example.com", ], jwks_client=mock_jwks_client, ) token = create_test_token(iss="https://issuer.example.com") claims = await 
validator.validate(token) assert claims.iss == "https://issuer.example.com" @pytest.mark.asyncio async def test_normalizes_issuer_trailing_slashes(self, mock_jwks_client, create_test_token): """Test issuer comparison normalizes trailing slashes.""" validator = TokenValidator( resource_identifier="https://resource.example.com", allowed_issuers=["https://issuer.example.com/"], jwks_client=mock_jwks_client, ) token = create_test_token(iss="https://issuer.example.com") claims = await validator.validate(token) assert claims is not None class TestTokenValidatorMalformedTokens: """Tests for handling malformed tokens.""" @pytest.mark.asyncio async def test_rejects_completely_invalid_token(self, validator): """Test rejects completely invalid token string.""" with pytest.raises(InvalidTokenError): await validator.validate("not-a-jwt-token") @pytest.mark.asyncio async def test_rejects_empty_token(self, validator): """Test rejects empty token string.""" with pytest.raises(InvalidTokenError): await validator.validate("") @pytest.mark.asyncio async def test_handles_jwks_client_error(self, mock_jwks_client, create_test_token): """Test handles error from JWKS client.""" mock_jwks_client.get_signing_key_with_retry = AsyncMock( side_effect=InvalidTokenError("Key not found") ) validator = TokenValidator( resource_identifier="https://resource.example.com", allowed_issuers=["https://issuer.example.com"], jwks_client=mock_jwks_client, ) token = create_test_token() with pytest.raises(InvalidTokenError): await validator.validate(token)

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/DauQuangThanh/sso-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.