"""Unit tests for TokenValidator."""
import time
from unittest.mock import AsyncMock, MagicMock
import jwt
import pytest
from cryptography.hazmat.primitives.asymmetric import rsa
from sso_mcp_server.auth.cloud.claims import TokenClaims
from sso_mcp_server.auth.cloud.validator import TokenValidator
from sso_mcp_server.auth.exceptions import (
CloudTokenExpiredError,
InvalidAudienceError,
InvalidIssuerError,
InvalidTokenError,
TokenSignatureError,
)
@pytest.fixture
def rsa_key_pair():
    """Return a freshly generated 2048-bit RSA ``(private_key, public_key)`` tuple."""
    signing_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    return signing_key, signing_key.public_key()
@pytest.fixture
def create_test_token(rsa_key_pair):
    """Factory fixture that builds RS256-signed JWTs for tests.

    The returned callable signs every token with the private half of
    ``rsa_key_pair``.
    """
    private_key, _ = rsa_key_pair

    def _create_token(
        sub="user123",
        iss="https://issuer.example.com",
        aud="https://resource.example.com",
        exp=None,
        iat=None,
        scopes=None,
        kid="test-key-id",
        **extra_claims,
    ):
        """Encode a JWT with the given claims.

        Args:
            sub: Value of the ``sub`` claim.
            iss: Value of the ``iss`` claim.
            aud: Value of the ``aud`` claim (string or list of strings).
            exp: Expiry as a Unix timestamp; defaults to one hour from now.
            iat: Issued-at as a Unix timestamp; defaults to now.
            scopes: Optional iterable of scope names, written space-joined
                to the ``scp`` claim. Omitted when empty or ``None``.
            kid: Key id placed in the JWT header.
            **extra_claims: Additional claims merged over the standard ones.

        Returns:
            The encoded token as produced by ``jwt.encode``.
        """
        now = int(time.time())
        # Compare against None rather than truthiness so that an explicit
        # timestamp of 0 (the epoch) is honoured instead of being replaced
        # by the default.
        payload = {
            "sub": sub,
            "iss": iss,
            "aud": aud,
            "exp": now + 3600 if exp is None else exp,
            "iat": now if iat is None else iat,
            **extra_claims,
        }
        if scopes:
            payload["scp"] = " ".join(scopes)
        return jwt.encode(
            payload,
            private_key,
            algorithm="RS256",
            headers={"kid": kid},
        )

    return _create_token
@pytest.fixture
def mock_jwks_client(rsa_key_pair):
    """Build a stand-in JWKS client whose async key lookup yields the test public key."""
    public_key = rsa_key_pair[1]
    jwks = MagicMock()
    # The lookup is replaced with an AsyncMock so that it can be awaited.
    jwks.get_signing_key_with_retry = AsyncMock(return_value=public_key)
    return jwks
@pytest.fixture
def validator(mock_jwks_client):
    """Provide a TokenValidator wired to the mock JWKS client and the default test issuer/resource."""
    token_validator = TokenValidator(
        jwks_client=mock_jwks_client,
        allowed_issuers=["https://issuer.example.com"],
        resource_identifier="https://resource.example.com",
    )
    return token_validator
class TestTokenValidatorValidate:
    """Behavioural checks for ``TokenValidator.validate``."""

    @pytest.mark.asyncio
    async def test_validates_valid_token(self, validator, create_test_token):
        """A well-formed token yields TokenClaims carrying the expected sub/iss/aud."""
        result = await validator.validate(create_test_token())

        assert isinstance(result, TokenClaims)
        assert result.sub == "user123"
        assert result.iss == "https://issuer.example.com"
        assert result.aud == "https://resource.example.com"

    @pytest.mark.asyncio
    async def test_extracts_scopes_from_token(self, validator, create_test_token):
        """Scopes supplied to the token factory come back as a list on the claims."""
        scoped_token = create_test_token(scopes=["read", "write"])

        result = await validator.validate(scoped_token)

        assert result.scopes == ["read", "write"]

    @pytest.mark.asyncio
    async def test_rejects_expired_token(self, validator, create_test_token):
        """A token whose exp lies in the past raises CloudTokenExpiredError."""
        past_expiry = int(time.time()) - 100
        stale_token = create_test_token(exp=past_expiry)

        with pytest.raises(CloudTokenExpiredError):
            await validator.validate(stale_token)

    @pytest.mark.asyncio
    async def test_rejects_wrong_audience(self, validator, create_test_token):
        """A token addressed to a different resource raises InvalidAudienceError."""
        misaddressed = create_test_token(aud="https://wrong-resource.example.com")

        with pytest.raises(InvalidAudienceError):
            await validator.validate(misaddressed)

    @pytest.mark.asyncio
    async def test_rejects_wrong_issuer(self, validator, create_test_token):
        """A token from an issuer outside the allowlist raises InvalidIssuerError."""
        untrusted = create_test_token(iss="https://untrusted.example.com")

        with pytest.raises(InvalidIssuerError):
            await validator.validate(untrusted)

    @pytest.mark.asyncio
    async def test_rejects_invalid_signature(self, validator):
        """A token signed by a key other than the one the JWKS mock returns raises TokenSignatureError."""
        # Sign with an unrelated key so verification against the fixture's
        # public key fails.
        foreign_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
        expires_at = int(time.time()) + 3600
        issued_at = int(time.time())
        payload = {
            "sub": "user123",
            "iss": "https://issuer.example.com",
            "aud": "https://resource.example.com",
            "exp": expires_at,
            "iat": issued_at,
        }
        forged = jwt.encode(
            payload,
            foreign_key,
            algorithm="RS256",
            headers={"kid": "other-key"},
        )

        with pytest.raises(TokenSignatureError):
            await validator.validate(forged)
class TestTokenValidatorAudienceValidation:
    """Checks for how the validator matches the ``aud`` claim."""

    @staticmethod
    def _build_validator(jwks_client, resource_identifier):
        """Construct a validator for the default test issuer and the given resource."""
        return TokenValidator(
            resource_identifier=resource_identifier,
            allowed_issuers=["https://issuer.example.com"],
            jwks_client=jwks_client,
        )

    @pytest.mark.asyncio
    async def test_accepts_matching_string_audience(self, mock_jwks_client, create_test_token):
        """A single string audience equal to the resource identifier is accepted."""
        subject = self._build_validator(mock_jwks_client, "https://resource.example.com")
        token = create_test_token(aud="https://resource.example.com")

        result = await subject.validate(token)

        assert result is not None

    @pytest.mark.asyncio
    async def test_accepts_matching_audience_in_list(self, mock_jwks_client, create_test_token):
        """A list audience is accepted when one entry equals the resource identifier."""
        subject = self._build_validator(mock_jwks_client, "https://resource.example.com")
        audiences = ["https://other.example.com", "https://resource.example.com"]
        token = create_test_token(aud=audiences)

        result = await subject.validate(token)

        assert result is not None

    @pytest.mark.asyncio
    async def test_normalizes_trailing_slashes(self, mock_jwks_client, create_test_token):
        """A trailing slash on the configured resource does not prevent a match."""
        subject = self._build_validator(mock_jwks_client, "https://resource.example.com/")
        token = create_test_token(aud="https://resource.example.com")

        result = await subject.validate(token)

        assert result is not None
class TestTokenValidatorIssuerValidation:
    """Checks for how the validator matches the ``iss`` claim."""

    @staticmethod
    def _build_validator(jwks_client, allowed_issuers):
        """Construct a validator for the default test resource and the given issuer allowlist."""
        return TokenValidator(
            resource_identifier="https://resource.example.com",
            allowed_issuers=allowed_issuers,
            jwks_client=jwks_client,
        )

    @pytest.mark.asyncio
    async def test_accepts_issuer_in_allowlist(self, mock_jwks_client, create_test_token):
        """An issuer listed anywhere in a multi-entry allowlist is accepted."""
        trusted = [
            "https://issuer1.example.com",
            "https://issuer.example.com",
        ]
        subject = self._build_validator(mock_jwks_client, trusted)
        token = create_test_token(iss="https://issuer.example.com")

        result = await subject.validate(token)

        assert result.iss == "https://issuer.example.com"

    @pytest.mark.asyncio
    async def test_normalizes_issuer_trailing_slashes(self, mock_jwks_client, create_test_token):
        """A trailing slash on an allowlisted issuer does not prevent a match."""
        subject = self._build_validator(mock_jwks_client, ["https://issuer.example.com/"])
        token = create_test_token(iss="https://issuer.example.com")

        result = await subject.validate(token)

        assert result is not None
class TestTokenValidatorMalformedTokens:
    """Checks for unparseable tokens and key-lookup failures."""

    @pytest.mark.asyncio
    async def test_rejects_completely_invalid_token(self, validator):
        """A string that is not a JWT at all raises InvalidTokenError."""
        garbage = "not-a-jwt-token"

        with pytest.raises(InvalidTokenError):
            await validator.validate(garbage)

    @pytest.mark.asyncio
    async def test_rejects_empty_token(self, validator):
        """The empty string raises InvalidTokenError."""
        blank = ""

        with pytest.raises(InvalidTokenError):
            await validator.validate(blank)

    @pytest.mark.asyncio
    async def test_handles_jwks_client_error(self, mock_jwks_client, create_test_token):
        """An InvalidTokenError raised by the JWKS key lookup surfaces from validate."""
        # Swap the key lookup for one that always fails.
        failing_lookup = AsyncMock(side_effect=InvalidTokenError("Key not found"))
        mock_jwks_client.get_signing_key_with_retry = failing_lookup
        subject = TokenValidator(
            resource_identifier="https://resource.example.com",
            allowed_issuers=["https://issuer.example.com"],
            jwks_client=mock_jwks_client,
        )
        token = create_test_token()

        with pytest.raises(InvalidTokenError):
            await subject.validate(token)