Skip to main content
Glama
test_token_manager_extended.py (9.47 kB)
"""Additional unit tests for TokenManager covering edge cases.

These tests supplement the base tests in test_token_manager_base.py, focusing on
JWT parsing, edge cases around expiration, and the _request_token method.
"""

# pyright: reportPrivateUsage=false

import base64
import json
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from snc_cribl_mcp.client.token_manager import TokenManager
from snc_cribl_mcp.config import CriblConfig


def _config_with_credentials() -> CriblConfig:
    """Create a config with username/password authentication."""
    return CriblConfig(
        server_url="https://cribl.example.com",
        base_url="https://cribl.example.com/api/v1",
        username="user",
        password="pass",
    )


def _config_with_token() -> CriblConfig:
    """Create a config with bearer token authentication and no username/password."""
    return CriblConfig(
        server_url="https://cribl.example.com",
        base_url="https://cribl.example.com/api/v1",
        bearer_token="preexisting-token",
    )


def _make_jwt(exp: datetime | None = None, *, missing_exp: bool = False) -> str:
    """Create a test JWT whose payload optionally carries an ``exp`` claim.

    Args:
        exp: Expiration to encode as the integer ``exp`` claim (epoch seconds,
            truncated toward zero). When ``None``, no ``exp`` claim is written.
        missing_exp: Omit the ``exp`` claim even when ``exp`` is provided.

    Returns:
        A ``header.payload.signature`` string built from unpadded base64url
        segments. The signature is a fixed placeholder, not a real signature.
    """
    header = base64.urlsafe_b64encode(b'{"alg":"HS256","typ":"JWT"}').decode().rstrip("=")
    payload_dict: dict[str, int] = {}
    if exp is not None and not missing_exp:
        payload_dict["exp"] = int(exp.timestamp())
    payload = base64.urlsafe_b64encode(json.dumps(payload_dict).encode()).decode().rstrip("=")
    signature = base64.urlsafe_b64encode(b"fake-signature").decode().rstrip("=")
    return f"{header}.{payload}.{signature}"


class TestTokenManagerJWTParsing:
    """Tests for JWT expiration parsing."""

    def test_get_jwt_exp_valid_token(self) -> None:
        """Valid JWT should have its exp claim parsed correctly."""
        manager = TokenManager(_config_with_credentials())
        future = datetime.now(UTC) + timedelta(hours=2)
        token = _make_jwt(exp=future)

        result = manager._get_jwt_exp(token)

        # _make_jwt stores exp as whole seconds, so the parsed value can differ
        # from `future` by the truncated sub-second part; allow a small tolerance.
        assert abs((result - future).total_seconds()) < 2

    def test_get_jwt_exp_missing_exp_raises(self) -> None:
        """JWT without exp claim should raise ValueError."""
        manager = TokenManager(_config_with_credentials())
        token = _make_jwt(missing_exp=True)

        with pytest.raises(ValueError, match="missing 'exp' field"):
            manager._get_jwt_exp(token)

    def test_get_jwt_exp_invalid_format_raises(self) -> None:
        """Non-JWT string should raise ValueError."""
        manager = TokenManager(_config_with_credentials())

        with pytest.raises(ValueError, match="Invalid JWT format"):
            manager._get_jwt_exp("not.a.valid.jwt.token")

    def test_get_jwt_exp_two_parts_raises(self) -> None:
        """JWT with only two parts should raise ValueError."""
        manager = TokenManager(_config_with_credentials())

        with pytest.raises(ValueError, match="Invalid JWT format"):
            manager._get_jwt_exp("header.payload")


class TestTokenManagerCaching:
    """Tests for token caching behavior."""

    @pytest.mark.asyncio
    async def test_cached_token_returned_when_not_expired(self) -> None:
        """Cached token should be returned without network call when still valid."""
        manager = TokenManager(_config_with_token())
        # Set expiration to 1 hour from now
        manager._token_expires_at = datetime.now(UTC) + timedelta(hours=1)

        with patch.object(TokenManager, "_request_token", new_callable=AsyncMock) as mock_request:
            security = await manager.get_security()

        assert security.bearer_auth == "preexisting-token"
        # Back the "without network call" claim: no token request may be made.
        mock_request.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_cached_token_near_expiration_still_used(self) -> None:
        """Token near expiration (> 3 seconds) should still be used."""
        manager = TokenManager(_config_with_token())
        # Set expiration to 5 seconds from now (> 3 second buffer)
        manager._token_expires_at = datetime.now(UTC) + timedelta(seconds=5)

        security = await manager.get_security()

        assert security.bearer_auth == "preexisting-token"

    @pytest.mark.asyncio
    async def test_expired_token_logs_warning_no_credentials(self) -> None:
        """Expired token without credentials should log warning and return cached token."""
        manager = TokenManager(_config_with_token())
        # Set expiration to past
        manager._token_expires_at = datetime.now(UTC) - timedelta(hours=1)

        with patch("snc_cribl_mcp.client.token_manager.logger") as mock_logger:
            security = await manager.get_security()

        # The mock keeps its recorded calls after the patch is undone.
        mock_logger.warning.assert_called()
        assert security.bearer_auth == "preexisting-token"


class TestTokenManagerLocking:
    """Tests for async lock behavior."""

    @pytest.mark.asyncio
    async def test_ensure_lock_creates_new_lock_on_new_loop(self) -> None:
        """Repeated calls on the same running event loop should reuse one lock.

        NOTE(review): despite this test's name, only same-loop reuse is exercised;
        recreating the lock for a different event loop is not covered here.
        TODO: add a cross-loop case once the intended behavior is confirmed.
        """
        manager = TokenManager(_config_with_credentials())

        # First call to _ensure_lock
        lock1 = manager._ensure_lock()
        assert lock1 is not None

        # Same running loop - should return the same lock
        lock2 = manager._ensure_lock()
        assert lock1 is lock2


class TestTokenManagerRequestToken:
    """Tests for the _request_token method."""

    @pytest.mark.asyncio
    async def test_request_token_success(self) -> None:
        """Successful token request should return the token."""
        manager = TokenManager(_config_with_credentials())

        with patch("snc_cribl_mcp.client.token_manager.httpx.AsyncClient") as mock_client_class:
            mock_client = AsyncMock()
            # `async with httpx.AsyncClient(...)` on the patched class yields mock_client.
            mock_client_class.return_value.__aenter__.return_value = mock_client

            # Make the control-plane mock usable as an async context manager
            # that yields itself.
            mock_control_plane = MagicMock()
            mock_control_plane.__aenter__ = AsyncMock(return_value=mock_control_plane)
            mock_control_plane.__aexit__ = AsyncMock(return_value=None)

            mock_response = MagicMock()
            mock_response.token = "new-token"
            mock_control_plane.auth.tokens.get_async = AsyncMock(return_value=mock_response)

            with patch("snc_cribl_mcp.client.token_manager.CriblControlPlane", return_value=mock_control_plane):
                token = await manager._request_token(username="user", password="pass")

        assert token == "new-token"
        # The token must have come from the mocked auth endpoint.
        mock_control_plane.auth.tokens.get_async.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_fetch_and_cache_token_handles_unparseable_jwt(self) -> None:
        """If JWT parsing fails, token should still be cached with default expiration."""
        manager = TokenManager(_config_with_credentials())

        with (
            patch.object(TokenManager, "_request_token", new_callable=AsyncMock, return_value="opaque-token"),
            patch.object(manager, "_get_jwt_exp", side_effect=ValueError("Invalid JWT")),
            patch("snc_cribl_mcp.client.token_manager.logger") as mock_logger,
        ):
            security = await manager.get_security()

        assert security.bearer_auth == "opaque-token"
        assert manager._cached_token == "opaque-token"
        assert manager._token_expires_at is not None
        # Should log warning about unparseable token
        mock_logger.warning.assert_called()

    @pytest.mark.asyncio
    async def test_fetch_and_cache_token_handles_json_decode_error(self) -> None:
        """If JWT payload is not valid JSON, token should still be cached."""
        manager = TokenManager(_config_with_credentials())

        with (
            patch.object(TokenManager, "_request_token", new_callable=AsyncMock, return_value="bad-json-token"),
            patch.object(manager, "_get_jwt_exp", side_effect=json.JSONDecodeError("msg", "doc", 0)),
            patch("snc_cribl_mcp.client.token_manager.logger") as mock_logger,
        ):
            security = await manager.get_security()

        assert security.bearer_auth == "bad-json-token"
        mock_logger.warning.assert_called()

    @pytest.mark.asyncio
    async def test_fetch_token_exception_logged_and_raised(self) -> None:
        """Authentication failures should be logged and re-raised."""
        manager = TokenManager(_config_with_credentials())

        with (
            patch.object(TokenManager, "_request_token", new_callable=AsyncMock, side_effect=Exception("Auth failed")),
            patch("snc_cribl_mcp.client.token_manager.logger") as mock_logger,
            pytest.raises(Exception, match="Auth failed"),
        ):
            await manager.get_security()

        # This must stay outside the `with`: inside it, the line would be skipped
        # once get_security() raises. The mock still holds its recorded calls.
        mock_logger.exception.assert_called()


class TestTokenManagerRefresh:
    """Tests for token refresh logic."""

    @pytest.mark.asyncio
    async def test_token_refresh_when_near_expiration(self) -> None:
        """Token should be refreshed when within 3 seconds of expiration."""
        manager = TokenManager(_config_with_credentials())
        manager._cached_token = "old-token"
        # Set expiration to 2 seconds from now (< 3 second buffer)
        manager._token_expires_at = datetime.now(UTC) + timedelta(seconds=2)

        with (
            patch.object(TokenManager, "_request_token", new_callable=AsyncMock, return_value="new-token"),
            patch.object(manager, "_get_jwt_exp", return_value=datetime.now(UTC) + timedelta(hours=1)),
        ):
            security = await manager.get_security()

        assert security.bearer_auth == "new-token"

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/atree1023/snc-cribl-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.