Skip to main content
Glama
test_async_utils.py39.1 kB
""" Comprehensive tests for async_utils module - critical connection pooling and async utilities. """ import asyncio import time from unittest.mock import AsyncMock, MagicMock, patch import pytest from tools.common.async_utils import ( ConnectionPool, RetryManager, get_connection_pool, run_concurrent_operations, shutdown_connection_pool, ) from tools.common.errors import ( ConnectionPoolError, ConnectionPoolExhaustedError, ServerConnectionError, ) class TestConnectionPool: """Comprehensive tests for ConnectionPool class.""" @pytest.mark.asyncio async def test_connection_pool_initialization(self): """Test ConnectionPool initialization and setup.""" pool = ConnectionPool( max_connections_per_server=3, connection_timeout=10.0, idle_timeout=60.0, ) assert pool.max_connections == 3 assert pool.connection_timeout == 10.0 assert pool.idle_timeout == 60.0 assert pool._pools == {} assert pool._active_connections == {} assert pool._connection_configs == {} assert pool._cleanup_task is None assert pool._shutdown is False await pool.initialize() assert pool._cleanup_task is not None assert not pool._cleanup_task.done() await pool.close_all() @pytest.mark.asyncio async def test_register_server(self): """Test server registration with connection pool.""" pool = ConnectionPool() await pool.initialize() server_config = { "type": "sse", "url": "http://localhost:8001/mcp", } await pool.register_server("test-server", server_config) assert "test-server" in pool._connection_configs assert pool._connection_configs["test-server"] == server_config assert "test-server" in pool._pools assert "test-server" in pool._active_connections assert pool._active_connections["test-server"] == 0 await pool.close_all() @pytest.mark.asyncio async def test_register_multiple_servers(self): """Test registering multiple servers.""" pool = ConnectionPool() await pool.initialize() servers = { "server1": {"type": "sse", "url": "http://localhost:8001/mcp"}, "server2": {"type": "stdio", "command": "python", "args": ["-m", "server"]}, "server3": {"type": "sse", "url": "http://localhost:8003/mcp"}, } for name, config in servers.items(): await pool.register_server(name, config) assert len(pool._connection_configs) == 3 assert len(pool._pools) == 3 assert len(pool._active_connections) == 3 for name in servers: assert name in pool._connection_configs assert pool._active_connections[name] == 0 await pool.close_all() @pytest.mark.asyncio async def test_get_connection_unregistered_server(self): """Test getting connection for unregistered server.""" pool = ConnectionPool() await pool.initialize() with pytest.raises( ConnectionPoolError, match="Server unregistered not registered" ): async with pool.get_connection("unregistered"): pass await pool.close_all() @pytest.mark.asyncio async def test_create_connection_sse(self): """Test creating SSE connection.""" pool = ConnectionPool() await pool.initialize() server_config = { "type": "sse", "url": "http://localhost:8001/mcp", } await pool.register_server("sse-server", server_config) with patch("tools.common.async_utils.Client") as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client connection = await pool._create_connection("sse-server") assert connection == mock_client mock_client_class.assert_called_once_with("http://localhost:8001/mcp") await pool.close_all() @pytest.mark.asyncio async def test_create_connection_stdio(self): """Test creating stdio connection.""" pool = ConnectionPool() await pool.initialize() server_config = { "type": "stdio", "command": "python", "args": ["-m", "server"], } await pool.register_server("stdio-server", server_config) with patch("tools.common.async_utils.Client") as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client connection = await pool._create_connection("stdio-server") assert connection == mock_client # Should encode the full command mock_client_class.assert_called_once() call_args = mock_client_class.call_args[0][0] assert call_args.startswith("stdio://") await pool.close_all() @pytest.mark.asyncio async def test_create_connection_stdio_no_args(self): """Test creating stdio connection without args.""" pool = ConnectionPool() await pool.initialize() server_config = { "type": "stdio", "command": "python", } await pool.register_server("stdio-server", server_config) with patch("tools.common.async_utils.Client") as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client connection = await pool._create_connection("stdio-server") assert connection == mock_client mock_client_class.assert_called_once() await pool.close_all() @pytest.mark.asyncio async def test_create_connection_stdio_with_special_characters(self): """Test creating stdio connection with special characters in command.""" pool = ConnectionPool() await pool.initialize() server_config = { "type": "stdio", "command": "python", "args": ["-m", "server", "--config", "path with spaces/config.json"], } await pool.register_server("stdio-server", server_config) with patch("tools.common.async_utils.Client") as mock_client_class: mock_client = MagicMock() mock_client_class.return_value = mock_client connection = await pool._create_connection("stdio-server") assert connection == mock_client mock_client_class.assert_called_once() call_args = mock_client_class.call_args[0][0] assert call_args.startswith("stdio://") # Should properly encode special characters assert "path%20with%20spaces" in call_args await pool.close_all() @pytest.mark.asyncio async def test_create_connection_failure(self): """Test connection creation failure.""" pool = ConnectionPool() await pool.initialize() server_config = { "type": "sse", "url": "http://localhost:8001/mcp", } await pool.register_server("failing-server", server_config) with patch( "tools.common.async_utils.Client", side_effect=Exception("Connection failed"), ): with pytest.raises( ServerConnectionError, match="Failed to create connection" ): await pool._create_connection("failing-server") await pool.close_all() @pytest.mark.asyncio async def test_connection_pool_reuse(self): """Test connection reuse from pool.""" pool = ConnectionPool(max_connections_per_server=2) await pool.initialize() server_config = {"type": "sse", "url": "http://localhost:8001/mcp"} await pool.register_server("test-server", server_config) mock_client = MagicMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) with patch("tools.common.async_utils.Client", return_value=mock_client): # First connection - should create new async with pool.get_connection("test-server") as conn1: assert conn1 == mock_client assert pool._active_connections["test-server"] == 1 # Second connection - should reuse from pool async with pool.get_connection("test-server") as conn2: assert conn2 == mock_client await pool.close_all() @pytest.mark.asyncio async def test_connection_pool_exhaustion(self): """Test connection pool exhaustion.""" pool = ConnectionPool(max_connections_per_server=1, connection_timeout=0.1) await pool.initialize() server_config = {"type": "sse", "url": "http://localhost:8001/mcp"} await pool.register_server("test-server", server_config) mock_client = MagicMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) with patch("tools.common.async_utils.Client", return_value=mock_client): # Hold one connection async with pool.get_connection("test-server"): # Try to get another - should timeout with pytest.raises(ConnectionPoolExhaustedError): async with pool.get_connection("test-server"): pass await pool.close_all() @pytest.mark.asyncio async def test_connection_pool_concurrent_access(self): """Test concurrent access to connection pool.""" pool = ConnectionPool(max_connections_per_server=3) await pool.initialize() server_config = {"type": "sse", "url": "http://localhost:8001/mcp"} await pool.register_server("test-server", server_config) mock_client = MagicMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) connection_count = 0 def create_mock_client(*args, **kwargs): nonlocal connection_count connection_count += 1 return mock_client with patch("tools.common.async_utils.Client", side_effect=create_mock_client): # Start multiple concurrent connections async def use_connection(): async with pool.get_connection("test-server") as conn: await asyncio.sleep(0.1) return conn tasks = [use_connection() for _ in range(5)] results = await asyncio.gather(*tasks) assert len(results) == 5 assert all(r == mock_client for r in results) # Should have created at most max_connections assert connection_count <= 3 await pool.close_all() @pytest.mark.asyncio async def test_connection_cleanup(self): """Test idle connection cleanup.""" pool = ConnectionPool(idle_timeout=0.1) # Very short timeout await pool.initialize() server_config = {"type": "sse", "url": "http://localhost:8001/mcp"} await pool.register_server("test-server", server_config) mock_client = MagicMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) with patch("tools.common.async_utils.Client", return_value=mock_client): # Use and release a connection async with pool.get_connection("test-server"): pass # Wait for cleanup to run await asyncio.sleep(0.2) await pool.close_all() @pytest.mark.asyncio async def test_connection_cleanup_with_active_connections(self): """Test cleanup doesn't affect active connections.""" pool = ConnectionPool(idle_timeout=0.1) await pool.initialize() server_config = {"type": "sse", "url": "http://localhost:8001/mcp"} await pool.register_server("test-server", server_config) mock_client = MagicMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) with patch("tools.common.async_utils.Client", return_value=mock_client): # Hold an active connection while cleanup runs async with pool.get_connection("test-server"): await asyncio.sleep(0.2) # Longer than idle timeout # Connection should still be valid await pool.close_all() @pytest.mark.asyncio async def test_close_all_connections(self): """Test closing all connections.""" pool = ConnectionPool() await pool.initialize() server_config = {"type": "sse", "url": "http://localhost:8001/mcp"} await pool.register_server("test-server", server_config) mock_client = MagicMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) with patch("tools.common.async_utils.Client", return_value=mock_client): async with pool.get_connection("test-server"): pass await pool.close_all() assert pool._shutdown is True assert pool._cleanup_task.cancelled() @pytest.mark.asyncio async def test_cleanup_task_exception_handling(self): """Test cleanup task handles exceptions gracefully.""" pool = ConnectionPool(idle_timeout=0.01) await pool.initialize() # Mock the cleanup to raise an exception async def failing_cleanup(): await asyncio.sleep(0.01) raise Exception("Cleanup failed") pool._cleanup_task.cancel() pool._cleanup_task = asyncio.create_task(failing_cleanup()) # Should not crash the pool await asyncio.sleep(0.05) # The cleanup task should have failed, but close_all should handle it try: await pool.close_all() except Exception: # Expected - the failing cleanup task will raise when awaited pass @pytest.mark.asyncio async def test_connection_pool_stress_test(self): """Stress test connection pool with many concurrent operations.""" pool = ConnectionPool(max_connections_per_server=5, connection_timeout=1.0) await pool.initialize() server_config = {"type": "sse", "url": "http://localhost:8001/mcp"} await pool.register_server("stress-server", server_config) mock_client = MagicMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) successful_connections = 0 async def stress_connection(): nonlocal successful_connections try: async with pool.get_connection("stress-server"): await asyncio.sleep(0.01) successful_connections += 1 except Exception: pass with patch("tools.common.async_utils.Client", return_value=mock_client): # Run many concurrent operations tasks = [stress_connection() for _ in range(50)] await asyncio.gather(*tasks, return_exceptions=True) # Should have handled all connections successfully assert successful_connections > 40 # Allow for some timing variations await pool.close_all() class TestRetryManager: """Comprehensive tests for RetryManager class.""" @pytest.mark.asyncio async def test_retry_manager_initialization(self): """Test RetryManager initialization.""" retry_manager = RetryManager( max_attempts=5, base_delay=2.0, max_delay=30.0, exponential_base=3.0, jitter=False, ) assert retry_manager.max_attempts == 5 assert retry_manager.base_delay == 2.0 assert retry_manager.max_delay == 30.0 assert retry_manager.exponential_base == 3.0 assert retry_manager.jitter is False @pytest.mark.asyncio async def test_retry_manager_default_initialization(self): """Test RetryManager with default values.""" retry_manager = RetryManager() assert retry_manager.max_attempts == 3 assert retry_manager.base_delay == 1.0 assert retry_manager.max_delay == 60.0 assert retry_manager.exponential_base == 2.0 assert retry_manager.jitter is True @pytest.mark.asyncio async def test_successful_operation_first_attempt(self): """Test successful operation on first attempt.""" retry_manager = RetryManager(max_attempts=3) async def successful_operation(): return "success" result = await retry_manager.execute_with_retry( successful_operation, operation_name="test_op" ) assert result.is_success assert result.data == "success" @pytest.mark.asyncio async def test_successful_operation_after_retries(self): """Test successful operation after some failures.""" retry_manager = RetryManager(max_attempts=3, base_delay=0.01) attempt_count = 0 async def eventually_successful_operation(): nonlocal attempt_count attempt_count += 1 if attempt_count < 3: raise ValueError("Temporary failure") return "success" result = await retry_manager.execute_with_retry( eventually_successful_operation, operation_name="test_op" ) assert result.is_success assert result.data == "success" assert attempt_count == 3 @pytest.mark.asyncio async def test_operation_fails_all_attempts(self): """Test operation that fails all retry attempts.""" retry_manager = RetryManager(max_attempts=2, base_delay=0.01) async def always_failing_operation(): raise ValueError("Always fails") result = await retry_manager.execute_with_retry( always_failing_operation, operation_name="test_op" ) assert result.is_failed assert "RETRY_EXHAUSTED" in result.error_code assert "Always fails" in result.error @pytest.mark.asyncio async def test_non_retryable_exception(self): """Test non-retryable exception handling.""" retry_manager = RetryManager(max_attempts=3) async def operation_with_non_retryable_error(): raise TypeError("Non-retryable error") result = await retry_manager.execute_with_retry( operation_with_non_retryable_error, retryable_exceptions=(ValueError,), operation_name="test_op", ) assert result.is_failed assert result.error_code == "TypeError" assert "Non-retryable error" in result.error @pytest.mark.asyncio async def test_multiple_retryable_exceptions(self): """Test with multiple retryable exception types.""" retry_manager = RetryManager(max_attempts=4, base_delay=0.01) attempt_count = 0 async def operation_with_multiple_exceptions(): nonlocal attempt_count attempt_count += 1 if attempt_count == 1: raise ValueError("First failure") elif attempt_count == 2: raise ConnectionError("Second failure") elif attempt_count == 3: raise TimeoutError("Third failure") return "success" result = await retry_manager.execute_with_retry( operation_with_multiple_exceptions, retryable_exceptions=(ValueError, ConnectionError, TimeoutError), operation_name="test_op", ) assert result.is_success assert result.data == "success" assert attempt_count == 4 @pytest.mark.asyncio async def test_exponential_backoff_calculation(self): """Test exponential backoff delay calculation.""" retry_manager = RetryManager( max_attempts=4, base_delay=1.0, max_delay=10.0, exponential_base=2.0, jitter=False, ) attempt_times = [] async def failing_operation(): attempt_times.append(time.time()) raise ValueError("Test failure") start_time = time.time() result = await retry_manager.execute_with_retry( failing_operation, operation_name="test_op" ) assert result.is_failed assert len(attempt_times) == 4 # Check that delays increase exponentially (approximately) # First attempt: immediate # Second attempt: ~1s delay # Third attempt: ~2s delay # Fourth attempt: ~4s delay total_time = time.time() - start_time assert total_time >= 6.0 # At least 1 + 2 + 4 = 7 seconds of delays @pytest.mark.asyncio async def test_max_delay_cap(self): """Test that delay is capped at max_delay.""" retry_manager = RetryManager( max_attempts=3, base_delay=10.0, max_delay=0.1, # Very low max delay exponential_base=2.0, jitter=False, ) attempt_times = [] async def failing_operation(): attempt_times.append(time.time()) raise ValueError("Test failure") start_time = time.time() await retry_manager.execute_with_retry( failing_operation, operation_name="test_op" ) total_time = time.time() - start_time # Should be capped at max_delay, not exponential assert total_time < 1.0 # Much less than base_delay would suggest @pytest.mark.asyncio async def test_jitter_enabled(self): """Test retry with jitter enabled.""" retry_manager = RetryManager(max_attempts=3, base_delay=0.1, jitter=True) attempt_times = [] async def failing_operation(): attempt_times.append(time.time()) raise ValueError("Test failure") # Run multiple times to test jitter variability delays = [] for _ in range(3): attempt_times.clear() start_time = time.time() await retry_manager.execute_with_retry( failing_operation, operation_name="test_op" ) total_time = time.time() - start_time delays.append(total_time) # With jitter, delays should vary assert len(set(delays)) > 1 or all( d < 0.5 for d in delays ) # Allow for fast execution @pytest.mark.asyncio async def test_zero_max_attempts(self): """Test retry manager with zero max attempts.""" retry_manager = RetryManager(max_attempts=0) async def operation(): return "success" result = await retry_manager.execute_with_retry(operation) # Should not execute at all assert result.is_failed assert "RETRY_EXHAUSTED" in result.error_code @pytest.mark.asyncio async def test_single_attempt(self): """Test retry manager with single attempt.""" retry_manager = RetryManager(max_attempts=1) async def failing_operation(): raise ValueError("Failure") result = await retry_manager.execute_with_retry(failing_operation) assert result.is_failed assert "RETRY_EXHAUSTED" in result.error_code @pytest.mark.asyncio async def test_very_large_exponential_base(self): """Test with very large exponential base.""" retry_manager = RetryManager( max_attempts=3, base_delay=0.01, max_delay=0.1, exponential_base=100.0, jitter=False, ) async def failing_operation(): raise ValueError("Failure") start_time = time.time() await retry_manager.execute_with_retry(failing_operation) total_time = time.time() - start_time # Should be capped by max_delay assert total_time < 0.5 class TestConcurrentOperations: """Tests for concurrent operation utilities.""" @pytest.mark.asyncio async def test_run_concurrent_operations_success(self): """Test successful concurrent operations.""" async def operation_1(): await asyncio.sleep(0.01) return "result_1" async def operation_2(): await asyncio.sleep(0.01) return "result_2" operations = [operation_1, operation_2] operation_names = ["op_1", "op_2"] results = await run_concurrent_operations( operations, max_concurrent=2, operation_names=operation_names ) assert len(results) == 2 assert all(r.is_success for r in results) assert results[0].data == "result_1" assert results[1].data == "result_2" assert all(r.duration_ms > 0 for r in results) @pytest.mark.asyncio async def test_run_concurrent_operations_with_failures(self): """Test concurrent operations with some failures.""" async def successful_operation(): return "success" async def failing_operation(): raise ValueError("Operation failed") operations = [successful_operation, failing_operation] results = await run_concurrent_operations(operations, max_concurrent=2) assert len(results) == 2 assert results[0].is_success assert results[0].data == "success" assert results[1].is_failed assert "Operation failed" in results[1].error @pytest.mark.asyncio async def test_run_concurrent_operations_concurrency_limit(self): """Test concurrent operations respect concurrency limit.""" execution_order = [] async def timed_operation(): execution_order.append("start") await asyncio.sleep(0.1) execution_order.append("end") return "done" operations = [timed_operation for _ in range(4)] start_time = time.time() results = await run_concurrent_operations(operations, max_concurrent=2) total_time = time.time() - start_time assert len(results) == 4 assert all(r.is_success for r in results) # With max_concurrent=2, should take at least 0.2s (two batches) assert total_time >= 0.15 @pytest.mark.asyncio async def test_run_concurrent_operations_empty_list(self): """Test executing empty list of operations.""" results = await run_concurrent_operations([]) assert results == [] @pytest.mark.asyncio async def test_run_concurrent_operations_default_names(self): """Test concurrent operations with default names.""" async def simple_operation(): return "result" operations = [simple_operation, simple_operation] results = await run_concurrent_operations(operations) assert len(results) == 2 assert all(r.is_success for r in results) @pytest.mark.asyncio async def test_run_concurrent_operations_exception_handling(self): """Test exception handling in concurrent operations.""" async def exception_operation(): raise RuntimeError("Runtime error") operations = [exception_operation] results = await run_concurrent_operations(operations) assert len(results) == 1 assert results[0].is_failed assert results[0].error_code == "RuntimeError" @pytest.mark.asyncio async def test_run_concurrent_operations_mixed_timing(self): """Test operations with different execution times.""" async def fast_operation(): await asyncio.sleep(0.01) return "fast" async def slow_operation(): await asyncio.sleep(0.1) return "slow" operations = [fast_operation, slow_operation, fast_operation] start_time = time.time() results = await run_concurrent_operations(operations, max_concurrent=3) total_time = time.time() - start_time assert len(results) == 3 assert all(r.is_success for r in results) # Should complete in time of slowest operation assert 0.08 <= total_time <= 0.2 @pytest.mark.asyncio async def test_run_concurrent_operations_large_batch(self): """Test concurrent operations with large batch.""" async def batch_operation(): await asyncio.sleep(0.01) return "batch_result" operations = [batch_operation for _ in range(20)] results = await run_concurrent_operations(operations, max_concurrent=5) assert len(results) == 20 assert all(r.is_success for r in results) assert all(r.data == "batch_result" for r in results) @pytest.mark.asyncio async def test_run_concurrent_operations_zero_max_concurrent(self): """Test concurrent operations with zero max_concurrent (should use unlimited).""" async def test_operation(): await asyncio.sleep(0.01) return "unlimited" operations = [test_operation for _ in range(5)] start_time = time.time() results = await run_concurrent_operations(operations, max_concurrent=0) total_time = time.time() - start_time assert len(results) == 5 assert all(r.is_success for r in results) assert all(r.data == "unlimited" for r in results) # Should complete quickly since all operations run concurrently assert total_time < 0.1 @pytest.mark.asyncio async def test_run_concurrent_operations_negative_max_concurrent(self): """Test concurrent operations with negative max_concurrent (should use unlimited).""" async def test_operation(): await asyncio.sleep(0.01) return "negative_unlimited" operations = [test_operation for _ in range(3)] results = await run_concurrent_operations(operations, max_concurrent=-1) assert len(results) == 3 assert all(r.is_success for r in results) assert all(r.data == "negative_unlimited" for r in results) class TestGlobalConnectionPool: """Tests for global connection pool management.""" @pytest.mark.asyncio async def test_get_connection_pool_singleton(self): """Test global connection pool singleton behavior.""" # Clean up any existing pool await shutdown_connection_pool() pool1 = await get_connection_pool() pool2 = await get_connection_pool() assert pool1 is pool2 assert pool1._cleanup_task is not None await shutdown_connection_pool() @pytest.mark.asyncio async def test_shutdown_connection_pool(self): """Test shutting down global connection pool.""" pool = await get_connection_pool() assert pool is not None await shutdown_connection_pool() # Should create new pool after shutdown new_pool = await get_connection_pool() assert new_pool is not pool await shutdown_connection_pool() @pytest.mark.asyncio async def test_shutdown_connection_pool_when_none(self): """Test shutting down when no pool exists.""" await shutdown_connection_pool() # Should not raise an error await shutdown_connection_pool() @pytest.mark.asyncio async def test_global_pool_lifecycle(self): """Test complete lifecycle of global connection pool.""" # Start with clean state await shutdown_connection_pool() # Create pool and register server pool = await get_connection_pool() await pool.register_server( "test-server", {"type": "sse", "url": "http://localhost:8001"} ) # Verify pool is working assert "test-server" in pool._connection_configs # Shutdown and verify cleanup await shutdown_connection_pool() # Create new pool - should be fresh new_pool = await get_connection_pool() assert new_pool is not pool assert "test-server" not in new_pool._connection_configs await shutdown_connection_pool() class TestConnectionPoolEdgeCases: """Test edge cases and error scenarios for ConnectionPool.""" @pytest.mark.asyncio async def test_connection_pool_queue_full_scenario(self): """Test connection pool behavior when queue is full.""" pool = ConnectionPool(max_connections_per_server=2) await pool.initialize() server_config = {"type": "sse", "url": "http://localhost:8001/mcp"} await pool.register_server("test-server", server_config) mock_client = MagicMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) # Fill the queue to capacity with patch("tools.common.async_utils.Client", return_value=mock_client): connections = [] for _ in range(2): async with pool.get_connection("test-server") as conn: connections.append(conn) # Queue should be full, additional connections should be closed async with pool.get_connection("test-server") as conn: assert conn == mock_client await pool.close_all() @pytest.mark.asyncio async def test_connection_release_error_handling(self): """Test error handling during connection release.""" pool = ConnectionPool() await pool.initialize() server_config = {"type": "sse", "url": "http://localhost:8001/mcp"} await pool.register_server("test-server", server_config) mock_client = MagicMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) # Mock queue to raise exception on put with patch("tools.common.async_utils.Client", return_value=mock_client): with patch.object( pool._pools["test-server"], "put_nowait", side_effect=asyncio.QueueFull ): async with pool.get_connection("test-server") as conn: assert conn == mock_client # Should handle the queue full exception gracefully await pool.close_all() @pytest.mark.asyncio async def test_cleanup_with_connection_errors(self): """Test cleanup handles connection errors gracefully.""" pool = ConnectionPool(idle_timeout=0.01) await pool.initialize() server_config = {"type": "sse", "url": "http://localhost:8001/mcp"} await pool.register_server("test-server", server_config) mock_client = MagicMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) with patch("tools.common.async_utils.Client", return_value=mock_client): # Use connection and let it become idle async with pool.get_connection("test-server"): pass # Mock close_connection to raise an error async def failing_close(*args, **kwargs): raise Exception("Close failed") pool._close_connection = failing_close # Wait for cleanup - should handle errors gracefully await asyncio.sleep(0.1) await pool.close_all() @pytest.mark.asyncio async def test_connection_pool_initialization_without_initialize_call(self): """Test using connection pool without calling initialize.""" pool = ConnectionPool() # Don't call initialize() server_config = {"type": "sse", "url": "http://localhost:8001/mcp"} await pool.register_server("test-server", server_config) # Should still work, just without cleanup task assert pool._cleanup_task is None await pool.close_all() @pytest.mark.asyncio async def test_double_initialization(self): """Test calling initialize multiple times.""" pool = ConnectionPool() await pool.initialize() first_task = pool._cleanup_task await pool.initialize() second_task = pool._cleanup_task # Should not create multiple cleanup tasks assert first_task is second_task await pool.close_all() @pytest.mark.asyncio async def test_connection_pool_with_zero_max_connections(self): """Test connection pool with zero max connections.""" pool = ConnectionPool(max_connections_per_server=0, connection_timeout=0.1) await pool.initialize() server_config = {"type": "sse", "url": "http://localhost:8001/mcp"} await pool.register_server("test-server", server_config) # Should immediately timeout since no connections allowed with pytest.raises(ConnectionPoolExhaustedError): async with pool.get_connection("test-server"): pass await pool.close_all() @pytest.mark.asyncio async def test_connection_pool_server_reregistration(self): """Test re-registering a server with different config.""" pool = ConnectionPool() await pool.initialize() # Register server first time config1 = {"type": "sse", "url": "http://localhost:8001/mcp"} await pool.register_server("test-server", config1) # Re-register with different config config2 = {"type": "sse", "url": "http://localhost:8002/mcp"} await pool.register_server("test-server", config2) # Should use new config assert pool._connection_configs["test-server"] == config2 await pool.close_all() @pytest.mark.asyncio async def test_connection_pool_with_very_short_idle_timeout(self): """Test connection pool with very short idle timeout.""" pool = ConnectionPool(idle_timeout=0.001) # 1ms await pool.initialize() server_config = {"type": "sse", "url": "http://localhost:8001/mcp"} await pool.register_server("test-server", server_config) mock_client = MagicMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) with patch("tools.common.async_utils.Client", return_value=mock_client): # Use connection briefly async with pool.get_connection("test-server"): pass # Wait longer than idle timeout await asyncio.sleep(0.1) # Connection should have been cleaned up # (This is hard to test directly, but cleanup should run) await pool.close_all()

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/lightfastai/lightfast-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server