Skip to main content
Glama
test_kb_auth.py (19.7 kB)
""" Tests for the KBase authentication module. Tests cover: - KBaseAuth.create() - initialization, session management - get_user() - token validation, caching, role verification - Error cases: InvalidTokenError, MissingRoleError - Concurrent get_user() calls with cache """ import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest from cacheout.lru import LRUCache from src.service.kb_auth import ( KBaseAuth, KBaseUser, AdminPermission, _check_error, ) from src.service.exceptions import InvalidTokenError, MissingRoleError # ============================================================================= # Test AdminPermission Enum # ============================================================================= class TestAdminPermission: """Tests for AdminPermission enum.""" def test_none_permission_value(self): """Test NONE permission has expected value.""" assert AdminPermission.NONE == 1 def test_full_permission_value(self): """Test FULL permission has expected value.""" assert AdminPermission.FULL == 10 def test_permission_ordering(self): """Test that FULL > NONE.""" assert AdminPermission.FULL > AdminPermission.NONE # ============================================================================= # Test KBaseUser NamedTuple # ============================================================================= class TestKBaseUser: """Tests for KBaseUser named tuple.""" def test_user_creation(self): """Test creating a KBaseUser.""" user = KBaseUser(user="testuser", admin_perm=AdminPermission.NONE) assert user.user == "testuser" assert user.admin_perm == AdminPermission.NONE def test_admin_user_creation(self): """Test creating an admin user.""" user = KBaseUser(user="admin", admin_perm=AdminPermission.FULL) assert user.admin_perm == AdminPermission.FULL def test_user_is_tuple(self): """Test that KBaseUser behaves as tuple.""" user = KBaseUser(user="testuser", admin_perm=AdminPermission.NONE) assert user[0] == "testuser" assert user[1] == 
AdminPermission.NONE # ============================================================================= # Test _check_error Helper # ============================================================================= class TestCheckError: """Tests for the _check_error helper function.""" @pytest.mark.asyncio async def test_success_response_no_error(self): """Test that 200 response doesn't raise error.""" response = MagicMock() response.status = 200 # Should not raise await _check_error(response) @pytest.mark.asyncio async def test_invalid_token_error(self): """Test that invalid token error code raises InvalidTokenError.""" response = MagicMock() response.status = 401 response.json = AsyncMock( return_value={"error": {"appcode": 10020, "message": "Invalid token"}} ) with pytest.raises(InvalidTokenError): await _check_error(response) @pytest.mark.asyncio async def test_other_error_raises_ioerror(self): """Test that other errors raise IOError.""" response = MagicMock() response.status = 500 response.json = AsyncMock( return_value={"error": {"appcode": 99999, "message": "Server error"}} ) with pytest.raises(IOError, match="Server error"): await _check_error(response) @pytest.mark.asyncio async def test_non_json_response_raises_ioerror(self): """Test that non-JSON response raises IOError.""" response = MagicMock() response.status = 500 response.json = AsyncMock(side_effect=Exception("Not JSON")) response.text = AsyncMock(return_value="Plain text error") with pytest.raises(IOError, match="Non-JSON response"): await _check_error(response) # ============================================================================= # Test KBaseAuth.create() # ============================================================================= class TestKBaseAuthCreate: """Tests for KBaseAuth.create() factory method.""" @pytest.mark.asyncio async def test_successful_creation(self): """Test successful KBaseAuth creation.""" mock_session = MagicMock() mock_response = MagicMock() mock_response.status = 200 
mock_response.json = AsyncMock( return_value={"servicename": "Authentication Service"} ) # Create async context manager mock mock_cm = MagicMock() mock_cm.__aenter__ = AsyncMock(return_value=mock_response) mock_cm.__aexit__ = AsyncMock(return_value=None) mock_session.get.return_value = mock_cm mock_session.close = AsyncMock() with patch("aiohttp.ClientSession", return_value=mock_session): auth = await KBaseAuth.create( "https://auth.kbase.us/", required_roles=["BERDL_USER"], full_admin_roles=["KBASE_ADMIN"], ) assert auth is not None assert auth._url == "https://auth.kbase.us/" @pytest.mark.asyncio async def test_adds_trailing_slash(self): """Test that trailing slash is added to URL.""" mock_session = MagicMock() mock_response = MagicMock() mock_response.status = 200 mock_response.json = AsyncMock( return_value={"servicename": "Authentication Service"} ) mock_cm = MagicMock() mock_cm.__aenter__ = AsyncMock(return_value=mock_response) mock_cm.__aexit__ = AsyncMock(return_value=None) mock_session.get.return_value = mock_cm mock_session.close = AsyncMock() with patch("aiohttp.ClientSession", return_value=mock_session): auth = await KBaseAuth.create("https://auth.kbase.us") assert auth._url == "https://auth.kbase.us/" @pytest.mark.asyncio async def test_invalid_service_raises_error(self): """Test that non-auth service raises IOError.""" mock_session = MagicMock() mock_response = MagicMock() mock_response.status = 200 mock_response.json = AsyncMock(return_value={"servicename": "Wrong Service"}) mock_cm = MagicMock() mock_cm.__aenter__ = AsyncMock(return_value=mock_response) mock_cm.__aexit__ = AsyncMock(return_value=None) mock_session.get.return_value = mock_cm mock_session.close = AsyncMock() with patch("aiohttp.ClientSession", return_value=mock_session): with pytest.raises(IOError, match="does not appear to be"): await KBaseAuth.create("https://wrong.service/") @pytest.mark.asyncio async def test_session_closed_on_init_error(self): """Test that session is closed if 
initialization fails.""" mock_session = MagicMock() mock_session.get.side_effect = Exception("Connection failed") mock_session.close = AsyncMock() with patch("aiohttp.ClientSession", return_value=mock_session): with pytest.raises(Exception, match="Connection failed"): await KBaseAuth.create("https://auth.kbase.us/") mock_session.close.assert_called_once() @pytest.mark.asyncio async def test_empty_url_raises_error(self): """Test that empty URL raises error.""" with pytest.raises(ValueError): await KBaseAuth.create("") # ============================================================================= # Test KBaseAuth.get_user() # ============================================================================= class TestKBaseAuthGetUser: """Tests for KBaseAuth.get_user() method.""" @pytest.fixture def mock_auth(self): """Create a mock KBaseAuth instance.""" auth = MagicMock(spec=KBaseAuth) auth._url = "https://auth.kbase.us/" auth._me_url = "https://auth.kbase.us/api/V2/me" auth._req_roles = {"BERDL_USER"} auth._full_roles = {"KBASE_ADMIN"} auth._cache = LRUCache(maxsize=100, ttl=300) auth._session = MagicMock() # Make get_user call the real implementation auth.get_user = lambda token: KBaseAuth.get_user(auth, token) auth._get = AsyncMock() auth._get_admin_role = lambda roles: KBaseAuth._get_admin_role(auth, roles) return auth @pytest.mark.asyncio async def test_valid_token_returns_user(self, mock_auth): """Test that valid token returns KBaseUser.""" mock_auth._get.return_value = { "user": "testuser", "customroles": ["BERDL_USER"], } user = await mock_auth.get_user("valid_token") assert user.user == "testuser" assert user.admin_perm == AdminPermission.NONE @pytest.mark.asyncio async def test_admin_role_detected(self, mock_auth): """Test that admin role is detected.""" mock_auth._get.return_value = { "user": "adminuser", "customroles": ["BERDL_USER", "KBASE_ADMIN"], } user = await mock_auth.get_user("admin_token") assert user.admin_perm == AdminPermission.FULL 
@pytest.mark.asyncio async def test_missing_required_role_raises_error(self, mock_auth): """Test that missing required role raises MissingRoleError.""" mock_auth._get.return_value = { "user": "norolesuser", "customroles": [], # Missing BERDL_USER } with pytest.raises(MissingRoleError, match="missing required"): await mock_auth.get_user("noroles_token") @pytest.mark.asyncio async def test_cached_user_returned(self, mock_auth): """Test that cached user is returned without API call.""" # Pre-populate cache mock_auth._cache.set("cached_token", ("cacheduser", AdminPermission.NONE)) user = await mock_auth.get_user("cached_token") assert user.user == "cacheduser" # _get should not be called because cache hit mock_auth._get.assert_not_called() @pytest.mark.asyncio async def test_empty_token_raises_error(self, mock_auth): """Test that empty token raises error.""" with pytest.raises(ValueError): await mock_auth.get_user("") @pytest.mark.asyncio async def test_none_token_raises_error(self, mock_auth): """Test that None token raises error.""" with pytest.raises((ValueError, TypeError)): await mock_auth.get_user(None) @pytest.mark.asyncio async def test_user_cached_after_fetch(self, mock_auth): """Test that user is cached after successful fetch.""" mock_auth._get.return_value = { "user": "newuser", "customroles": ["BERDL_USER"], } # First call await mock_auth.get_user("new_token") # Verify cached cached = mock_auth._cache.get("new_token") assert cached is not None assert cached[0] == "newuser" # ============================================================================= # Test KBaseAuth.close() # ============================================================================= class TestKBaseAuthClose: """Tests for KBaseAuth.close() method.""" @pytest.mark.asyncio async def test_close_closes_session(self): """Test that close() closes the HTTP session.""" mock_session = MagicMock() mock_session.closed = False mock_session.close = AsyncMock() auth = MagicMock(spec=KBaseAuth) 
auth._session = mock_session await KBaseAuth.close(auth) mock_session.close.assert_called_once() @pytest.mark.asyncio async def test_close_does_nothing_if_already_closed(self): """Test that close() is safe to call on closed session.""" mock_session = MagicMock() mock_session.closed = True mock_session.close = AsyncMock() auth = MagicMock(spec=KBaseAuth) auth._session = mock_session await KBaseAuth.close(auth) mock_session.close.assert_not_called() # ============================================================================= # Test _get_admin_role Helper # ============================================================================= class TestGetAdminRole: """Tests for the _get_admin_role method.""" def test_full_admin_role_detected(self): """Test that full admin role is detected.""" auth = MagicMock() auth._full_roles = {"KBASE_ADMIN", "SUPER_ADMIN"} result = KBaseAuth._get_admin_role(auth, {"BERDL_USER", "KBASE_ADMIN"}) assert result == AdminPermission.FULL def test_no_admin_role_returns_none(self): """Test that no admin role returns NONE permission.""" auth = MagicMock() auth._full_roles = {"KBASE_ADMIN"} result = KBaseAuth._get_admin_role(auth, {"BERDL_USER", "OTHER_ROLE"}) assert result == AdminPermission.NONE def test_empty_roles_returns_none(self): """Test that empty roles returns NONE permission.""" auth = MagicMock() auth._full_roles = {"KBASE_ADMIN"} result = KBaseAuth._get_admin_role(auth, set()) assert result == AdminPermission.NONE # ============================================================================= # Concurrent Access Tests # ============================================================================= class TestConcurrentAuthAccess: """Tests for concurrent authentication access.""" @pytest.mark.asyncio async def test_concurrent_get_user_calls(self, async_concurrent_executor): """Test multiple concurrent get_user calls with cache.""" # Create a simple mock auth auth = MagicMock() auth._url = "https://auth.kbase.us/" auth._me_url = 
"https://auth.kbase.us/api/V2/me" auth._req_roles = {"BERDL_USER"} auth._full_roles = {"KBASE_ADMIN"} auth._cache = LRUCache(maxsize=100, ttl=300) call_count = {"value": 0} lock = asyncio.Lock() async def mock_get(url, headers): async with lock: call_count["value"] += 1 # Small delay to simulate network await asyncio.sleep(0.01) return {"user": "testuser", "customroles": ["BERDL_USER"]} auth._get = mock_get async def get_user_wrapper(token): return await KBaseAuth.get_user(auth, token) # All use same token - should cache after first call args_list = [("same_token",) for _ in range(10)] results = await async_concurrent_executor(get_user_wrapper, args_list) # All should succeed successful = [r for r in results if not isinstance(r, Exception)] assert len(successful) == 10 # Due to caching, API should be called fewer times than total requests # (though exact count depends on timing) assert call_count["value"] <= 10 @pytest.mark.asyncio async def test_concurrent_different_tokens(self, async_concurrent_executor): """Test concurrent calls with different tokens.""" auth = MagicMock() auth._url = "https://auth.kbase.us/" auth._me_url = "https://auth.kbase.us/api/V2/me" auth._req_roles = {"BERDL_USER"} auth._full_roles = set() auth._cache = LRUCache(maxsize=100, ttl=300) async def mock_get(url, headers): token = headers.get("Authorization", "unknown") return {"user": f"user_{token}", "customroles": ["BERDL_USER"]} auth._get = mock_get async def get_user_wrapper(token): return await KBaseAuth.get_user(auth, token) # Different tokens args_list = [(f"token_{i}",) for i in range(5)] results = await async_concurrent_executor(get_user_wrapper, args_list) successful = [r for r in results if not isinstance(r, Exception)] assert len(successful) == 5 # ============================================================================= # Integration-like Tests # ============================================================================= class TestAuthIntegration: """Integration-style tests 
for auth operations.""" @pytest.mark.asyncio async def test_full_auth_flow_with_mock(self): """Test complete auth flow with mocked HTTP.""" mock_session = MagicMock() # Mock initial service check response init_response = MagicMock() init_response.status = 200 init_response.json = AsyncMock( return_value={"servicename": "Authentication Service"} ) # Mock /me response me_response = MagicMock() me_response.status = 200 me_response.json = AsyncMock( return_value={"user": "integrationuser", "customroles": ["BERDL_USER"]} ) call_count = 0 def get_side_effect(url, headers=None): nonlocal call_count call_count += 1 cm = MagicMock() if call_count == 1: cm.__aenter__ = AsyncMock(return_value=init_response) else: cm.__aenter__ = AsyncMock(return_value=me_response) cm.__aexit__ = AsyncMock(return_value=None) return cm mock_session.get.side_effect = get_side_effect mock_session.close = AsyncMock() mock_session.closed = False with patch("aiohttp.ClientSession", return_value=mock_session): # Create auth auth = await KBaseAuth.create( "https://auth.kbase.us/", required_roles=["BERDL_USER"], ) # Get user user = await auth.get_user("integration_token") assert user.user == "integrationuser" assert user.admin_perm == AdminPermission.NONE # Cleanup await auth.close() # The session's close method should have been called assert mock_session.close.called or mock_session.close.call_count >= 0 @pytest.mark.asyncio async def test_token_rejection_flow(self): """Test flow when token is rejected.""" mock_session = MagicMock() # Mock initial service check init_response = MagicMock() init_response.status = 200 init_response.json = AsyncMock( return_value={"servicename": "Authentication Service"} ) # Mock /me response with invalid token me_response = MagicMock() me_response.status = 401 me_response.json = AsyncMock( return_value={"error": {"appcode": 10020, "message": "Invalid token"}} ) call_count = 0 def get_side_effect(url, headers=None): nonlocal call_count call_count += 1 cm = MagicMock() if 
call_count == 1: cm.__aenter__ = AsyncMock(return_value=init_response) else: cm.__aenter__ = AsyncMock(return_value=me_response) cm.__aexit__ = AsyncMock(return_value=None) return cm mock_session.get.side_effect = get_side_effect mock_session.close = AsyncMock() with patch("aiohttp.ClientSession", return_value=mock_session): auth = await KBaseAuth.create("https://auth.kbase.us/") with pytest.raises(InvalidTokenError): await auth.get_user("bad_token") await auth.close()

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/BERDataLakehouse/datalake-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.