Skip to main content
Glama
test_access_mode.py3.87 kB
import asyncio from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch import pytest from postgres_mcp.server import AccessMode from postgres_mcp.server import get_sql_driver from postgres_mcp.sql.safe_sql import SafeSqlDriver from postgres_mcp.sql.sql_driver import DbConnPool from postgres_mcp.sql.sql_driver import SqlDriver @pytest.fixture def mock_db_connection(): """Mock database connection pool.""" conn = MagicMock(spec=DbConnPool) conn.is_valid = True return conn @pytest.mark.parametrize( "access_mode,expected_driver_type", [ (AccessMode.UNRESTRICTED, SqlDriver), (AccessMode.RESTRICTED, SafeSqlDriver), ], ) @pytest.mark.asyncio async def test_get_sql_driver_returns_correct_driver(access_mode, expected_driver_type, mock_db_connection): """Test that get_sql_driver returns the correct driver type based on access mode.""" with ( patch("postgres_mcp.server.current_access_mode", access_mode), patch("postgres_mcp.server.db_connection", mock_db_connection), ): driver = await get_sql_driver() assert isinstance(driver, expected_driver_type) # When in RESTRICTED mode, verify timeout is set if access_mode == AccessMode.RESTRICTED: assert isinstance(driver, SafeSqlDriver) assert driver.timeout == 30 @pytest.mark.asyncio async def test_get_sql_driver_sets_timeout_in_restricted_mode(mock_db_connection): """Test that get_sql_driver sets the timeout in restricted mode.""" with ( patch("postgres_mcp.server.current_access_mode", AccessMode.RESTRICTED), patch("postgres_mcp.server.db_connection", mock_db_connection), ): driver = await get_sql_driver() assert isinstance(driver, SafeSqlDriver) assert driver.timeout == 30 assert hasattr(driver, "sql_driver") @pytest.mark.asyncio async def test_get_sql_driver_in_unrestricted_mode_no_timeout(mock_db_connection): """Test that get_sql_driver in unrestricted mode is a regular SqlDriver.""" with ( patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), patch("postgres_mcp.server.db_connection", mock_db_connection), ): driver = await get_sql_driver() assert isinstance(driver, SqlDriver) assert not hasattr(driver, "timeout") @pytest.mark.asyncio async def test_command_line_parsing(): """Test that command-line arguments correctly set the access mode.""" import sys from postgres_mcp.server import main # Mock sys.argv and asyncio.run original_argv = sys.argv original_run = asyncio.run try: # Test with --access-mode=restricted sys.argv = [ "postgres_mcp", "postgresql://user:password@localhost/db", "--access-mode=restricted", ] asyncio.run = AsyncMock() with ( patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock()), patch("postgres_mcp.server.shutdown", AsyncMock()), ): # Reset the current_access_mode to UNRESTRICTED import postgres_mcp.server postgres_mcp.server.current_access_mode = AccessMode.UNRESTRICTED # Run main (partially mocked to avoid actual connection) try: await main() except Exception: pass # Verify the mode was changed to RESTRICTED assert postgres_mcp.server.current_access_mode == AccessMode.RESTRICTED finally: # Restore original values sys.argv = original_argv asyncio.run = original_run

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/crystaldba/postgres-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server