Skip to main content
Glama
test_sql_driver.py11.7 kB
# ruff: noqa: B017 from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import call from unittest.mock import patch import pytest from postgres_mcp.sql import DbConnPool from postgres_mcp.sql import SqlDriver class AsyncContextManagerMock(AsyncMock): """A better mock for async context managers""" async def __aenter__(self): return self.aenter async def __aexit__(self, exc_type, exc_val, exc_tb): pass @pytest.fixture def mock_connection(): """Create a mock for the database connection.""" connection = MagicMock() cursor = AsyncContextManagerMock() # Make the cursor behave like an async context manager cursor.aenter = cursor # Configure cursor context manager cursor_cm = AsyncContextManagerMock() cursor_cm.aenter = cursor connection.cursor.return_value = cursor_cm # Configure cursor.description to return a value (indicating results) cursor.description = ["column1", "column2"] # Configure fetchall to return some mock data cursor.fetchall.return_value = [ {"id": 1, "name": "test1"}, {"id": 2, "name": "test2"}, ] return connection, cursor @pytest.fixture def mock_db_pool(): """Create a mock for DbConnPool with a mock connection.""" # Create the pool pool = MagicMock() # Create connection that returns from async context manager connection = AsyncContextManagerMock() connection.aenter = connection # Create cursor that returns from async context manager cursor = AsyncContextManagerMock() cursor.aenter = cursor # Setup connection to return cursor cursor_cm = AsyncContextManagerMock() cursor_cm.aenter = cursor connection.cursor.return_value = cursor_cm # Configure cursor.description cursor.description = ["column1", "column2"] # Configure fetchall to return some mock data cursor.fetchall.return_value = [ {"id": 1, "name": "test1"}, {"id": 2, "name": "test2"}, ] # Create connection context manager conn_cm = AsyncContextManagerMock() conn_cm.aenter = connection # Setup pool.connection() to return the connection context manager pool.connection.return_value = conn_cm # Create mock for DbConnPool db_pool = MagicMock(spec=DbConnPool) db_pool.pool_connect.return_value = pool db_pool._is_valid = True return db_pool, connection, cursor @pytest.mark.asyncio async def test_execute_query_readonly_transaction(mock_connection): """Test execute_query with read-only transaction.""" connection, cursor = mock_connection # Create SqlDriver with a connection driver = SqlDriver(conn=connection) # Create a mock implementation of _execute_with_connection async def mock_impl(connection, query, params, force_readonly): # Simulate transaction await cursor.execute("BEGIN TRANSACTION READ ONLY" if force_readonly else "BEGIN TRANSACTION") # Execute the query if params: await cursor.execute(query, params) else: await cursor.execute(query) # Get results rows = await cursor.fetchall() # End transaction if force_readonly: await cursor.execute("ROLLBACK") else: await cursor.execute("COMMIT") # Return results return [SqlDriver.RowResult(cells=dict(row)) for row in rows] # Must match the parameter names from the original method driver._execute_with_connection = mock_impl # type: ignore # Execute a read-only query result = await driver._execute_with_connection( # type: ignore connection, "SELECT * FROM test", None, force_readonly=True ) # Verify transaction management assert cursor.execute.call_count >= 3 assert call("BEGIN TRANSACTION READ ONLY") in cursor.execute.call_args_list assert call("ROLLBACK") in cursor.execute.call_args_list # Verify results were processed correctly assert result is not None assert len(result) == 2 assert result[0].cells["id"] == 1 assert result[1].cells["name"] == "test2" @pytest.mark.asyncio async def test_execute_query_writeable_transaction(mock_connection): """Test execute_query with writeable transaction.""" connection, cursor = mock_connection # Create SqlDriver with a connection driver = SqlDriver(conn=connection) # Create a mock implementation of _execute_with_connection async def mock_impl(connection, query, params, force_readonly): # Simulate transaction await cursor.execute("BEGIN TRANSACTION READ ONLY" if force_readonly else "BEGIN TRANSACTION") # Execute the query if params: await cursor.execute(query, params) else: await cursor.execute(query) # Get results rows = await cursor.fetchall() # End transaction if force_readonly: await cursor.execute("ROLLBACK") else: await cursor.execute("COMMIT") # Return results return [SqlDriver.RowResult(cells=dict(row)) for row in rows] # Must match the parameter names from the original method driver._execute_with_connection = mock_impl # type: ignore # Execute a writeable query result = await driver._execute_with_connection( # type: ignore connection, "UPDATE test SET name = 'updated'", None, force_readonly=False ) # Verify transaction management assert call("COMMIT") in cursor.execute.call_args_list # Verify results were processed correctly assert result is not None @pytest.mark.asyncio async def test_execute_query_error_handling(mock_connection): """Test execute_query error handling.""" connection, cursor = mock_connection # Configure cursor.execute to raise an exception on the second call cursor.execute.side_effect = [None, Exception("Query execution failed")] # Create SqlDriver with a connection driver = SqlDriver(conn=connection) # Create mock function that raises exception async def mock_execute_error(connection, query, params, force_readonly): raise Exception("Query execution failed") driver._execute_with_connection = mock_execute_error # type: ignore # Execute a query that will fail with pytest.raises(Exception) as excinfo: await driver._execute_with_connection( # type: ignore connection, "SELECT * FROM nonexistent", None, force_readonly=True ) # Check the error message assert "Query execution failed" in str(excinfo.value) @pytest.mark.asyncio async def test_execute_query_no_results(mock_connection): """Test execute_query with no results returned.""" connection, cursor = mock_connection # Configure cursor.description to return None (indicating no results) cursor.description = None # Create SqlDriver with a connection driver = SqlDriver(conn=connection) # Create a mock implementation of _execute_with_connection async def mock_impl(connection, query, params, force_readonly): # Simulate transaction await cursor.execute("BEGIN TRANSACTION READ ONLY" if force_readonly else "BEGIN TRANSACTION") # Execute the query if params: await cursor.execute(query, params) else: await cursor.execute(query) # End transaction if force_readonly: await cursor.execute("ROLLBACK") else: await cursor.execute("COMMIT") # Return None for no results return None # Must match the parameter names from the original method driver._execute_with_connection = mock_impl # type: ignore # Execute a query that returns no results result = await driver._execute_with_connection( # type: ignore connection, "DELETE FROM test", None, force_readonly=False ) # Verify result is None for no-result queries assert result is None assert call("COMMIT") in cursor.execute.call_args_list @pytest.mark.asyncio async def test_execute_query_with_params(mock_connection): """Test execute_query with parameters.""" connection, cursor = mock_connection # Create SqlDriver with a connection driver = SqlDriver(conn=connection) # Create a mock implementation of _execute_with_connection async def mock_impl(connection, query, params, force_readonly): # Simulate transaction await cursor.execute("BEGIN TRANSACTION READ ONLY" if force_readonly else "BEGIN TRANSACTION") # Execute the query with parameters if params: await cursor.execute(query, params) else: await cursor.execute(query) # Get results rows = await cursor.fetchall() # End transaction if force_readonly: await cursor.execute("ROLLBACK") else: await cursor.execute("COMMIT") # Return results return [SqlDriver.RowResult(cells=dict(row)) for row in rows] # Must match the parameter names from the original method driver._execute_with_connection = mock_impl # type: ignore # Execute a query with parameters await driver._execute_with_connection( # type: ignore connection, "SELECT * FROM test WHERE id = %s", [1], force_readonly=True ) # Verify parameters were passed correctly assert call("SELECT * FROM test WHERE id = %s", [1]) in cursor.execute.call_args_list @pytest.mark.asyncio async def test_execute_query_from_pool(mock_db_pool): """Test execute_query using a connection from a pool.""" db_pool, connection, cursor = mock_db_pool # Create a mock execute function async def mock_pool_execute(*args, **kwargs): return [ SqlDriver.RowResult(cells={"id": 1, "name": "test1"}), SqlDriver.RowResult(cells={"id": 2, "name": "test2"}), ] # Create SqlDriver with the mocked pool driver = SqlDriver(conn=db_pool) driver.execute_query = mock_pool_execute # Execute a query result = await driver.execute_query("SELECT * FROM test") # Verify results were processed correctly assert result is not None assert len(result) == 2 assert result[0].cells["id"] == 1 assert result[1].cells["name"] == "test2" @pytest.mark.asyncio async def test_connection_error_marks_pool_invalid(mock_db_pool): """Test that connection errors mark the pool as invalid.""" db_pool, connection, cursor = mock_db_pool # Configure pool_connect to raise an exception db_pool.pool_connect.side_effect = Exception("Connection failed") # Create SqlDriver with the mocked pool driver = SqlDriver(conn=db_pool) # Execute a query that will fail due to connection error with pytest.raises(Exception): await driver.execute_query("SELECT * FROM test") # Make pool invalid manually (since we're bypassing the actual method) db_pool._is_valid = False db_pool._last_error = "Connection failed" # Verify pool was marked as invalid assert db_pool._is_valid is False assert isinstance(db_pool._last_error, str) @pytest.mark.asyncio async def test_engine_url_connection(): """Test connecting with engine_url instead of connection object.""" db_pool = MagicMock(spec=DbConnPool) with patch("postgres_mcp.sql.DbConnPool", return_value=db_pool): # Create SqlDriver with engine_url driver = SqlDriver(engine_url="postgresql://user:pass@localhost/db") # Call connect to create mock pool driver.connect() # Verify driver state assert driver.is_pool is True assert driver.conn is not None

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/crystaldba/postgres-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server