"""Unit tests for TokenClaims dataclass."""
import time
from datetime import datetime
from sso_mcp_server.auth.cloud.claims import TokenClaims
class TestTokenClaims:
    """Scope-query behaviour of a directly constructed TokenClaims."""

    @staticmethod
    def _claims_with(scopes):
        """Build a TokenClaims holding *scopes* plus fixed placeholder identity fields.

        The leading underscore keeps pytest from collecting this as a test.
        """
        return TokenClaims(
            sub="user123",
            iss="https://issuer.example.com",
            aud="https://resource.example.com",
            exp=datetime.now(),
            iat=datetime.now(),
            scopes=scopes,
        )

    def test_has_scope_returns_true_for_existing_scope(self):
        """Every scope granted to the token is reported as present."""
        granted = self._claims_with(["read", "write", "admin"])
        for scope in ("read", "write", "admin"):
            assert granted.has_scope(scope) is True

    def test_has_scope_returns_false_for_missing_scope(self):
        """Scopes never granted to the token are reported as absent."""
        granted = self._claims_with(["read"])
        for scope in ("write", "admin"):
            assert granted.has_scope(scope) is False

    def test_has_any_scope_returns_true_when_any_match(self):
        """A single overlapping scope is enough for has_any_scope."""
        granted = self._claims_with(["read", "write"])
        assert granted.has_any_scope(["read", "admin"]) is True
        assert granted.has_any_scope(["write", "delete"]) is True

    def test_has_any_scope_returns_false_when_none_match(self):
        """has_any_scope is False when the requested set is disjoint from the grant."""
        granted = self._claims_with(["read"])
        assert granted.has_any_scope(["write", "admin"]) is False

    def test_has_all_scopes_returns_true_when_all_match(self):
        """has_all_scopes is True when the request is a subset of the grant."""
        granted = self._claims_with(["read", "write", "admin"])
        assert granted.has_all_scopes(["read", "write"]) is True
        assert granted.has_all_scopes(["read"]) is True

    def test_has_all_scopes_returns_false_when_any_missing(self):
        """One ungranted scope in the request makes has_all_scopes False."""
        granted = self._claims_with(["read", "write"])
        assert granted.has_all_scopes(["read", "admin"]) is False
class TestTokenClaimsFromJwtPayload:
    """Tests for TokenClaims.from_jwt_payload factory method."""

    @staticmethod
    def _payload(now=None, **extra):
        """Return a minimal JWT payload with standard claims, merged with *extra*.

        Args:
            now: Issue time as integer epoch seconds. Defaults to the current
                time. ``exp`` is always set to ``now + 3600``.
            **extra: Additional claims. Keys that match a standard claim
                (e.g. ``iss``, ``aud``) override the default value.

        Why a single clock read: computing ``exp`` and ``iat`` from two separate
        ``time.time()`` calls can straddle a second boundary. The two
        timestamps would then not be exactly one hour apart.

        The leading underscore keeps pytest from collecting this as a test.
        """
        if now is None:
            now = int(time.time())
        payload = {
            "sub": "user123",
            "iss": "https://issuer.example.com",
            "aud": "https://resource.example.com",
            "exp": now + 3600,
            "iat": now,
        }
        # update() keeps the position of overridden keys and appends new ones,
        # so the resulting key order matches a hand-written literal.
        payload.update(extra)
        return payload

    def test_parses_standard_claims(self):
        """Test parsing standard JWT claims."""
        now = int(time.time())
        claims = TokenClaims.from_jwt_payload(self._payload(now))
        assert claims.sub == "user123"
        assert claims.iss == "https://issuer.example.com"
        assert claims.aud == "https://resource.example.com"
        assert claims.exp.timestamp() == now + 3600
        assert claims.iat.timestamp() == now

    def test_parses_space_separated_scopes(self):
        """Test parsing space-separated scope claim."""
        claims = TokenClaims.from_jwt_payload(self._payload(scope="read write admin"))
        assert claims.scopes == ["read", "write", "admin"]

    def test_parses_azure_scp_claim(self):
        """Test parsing Azure AD 'scp' claim."""
        payload = self._payload(
            iss="https://login.microsoftonline.com/tenant/v2.0",
            aud="api://resource-id",
            scp="User.Read Files.ReadWrite",
        )
        claims = TokenClaims.from_jwt_payload(payload)
        assert claims.scopes == ["User.Read", "Files.ReadWrite"]

    def test_parses_array_scopes(self):
        """Test parsing array-format scope claim."""
        claims = TokenClaims.from_jwt_payload(self._payload(scope=["read", "write"]))
        assert claims.scopes == ["read", "write"]

    def test_parses_empty_scopes(self):
        """Test parsing when no scope claim present."""
        claims = TokenClaims.from_jwt_payload(self._payload())
        assert claims.scopes == []

    def test_parses_email_from_email_claim(self):
        """Test parsing email from 'email' claim."""
        claims = TokenClaims.from_jwt_payload(self._payload(email="user@example.com"))
        assert claims.email == "user@example.com"

    def test_parses_email_from_preferred_username(self):
        """Test parsing email from 'preferred_username' claim when 'email' not present."""
        claims = TokenClaims.from_jwt_payload(
            self._payload(preferred_username="user@example.com")
        )
        assert claims.email == "user@example.com"

    def test_parses_name_claim(self):
        """Test parsing 'name' claim."""
        claims = TokenClaims.from_jwt_payload(self._payload(name="John Doe"))
        assert claims.name == "John Doe"

    def test_stores_raw_claims(self):
        """Test that raw claims are stored for additional access."""
        payload = self._payload(custom_claim="custom_value", tenant_id="tenant-123")
        claims = TokenClaims.from_jwt_payload(payload)
        assert claims.raw_claims["custom_claim"] == "custom_value"
        assert claims.raw_claims["tenant_id"] == "tenant-123"

    def test_handles_list_audience(self):
        """Test handling of list audience claim."""
        audiences = ["https://resource1.example.com", "https://resource2.example.com"]
        # Pass a fresh copy so the assertion compares against an independent list.
        claims = TokenClaims.from_jwt_payload(self._payload(aud=list(audiences)))
        assert claims.aud == ["https://resource1.example.com", "https://resource2.example.com"]