test_base_handlers.py•5.52 kB
from unittest.mock import AsyncMock, MagicMock, patch
import mcp.types as types
import pytest
from mcp_dbutils.base import ConfigurationError, ConnectionServer
# Constants for error messages - use the actual error messages from the code
EMPTY_QUERY_ERROR = "SQL query cannot be empty"
SELECT_ONLY_ERROR = "Only SELECT queries are supported for security reasons"
EMPTY_TABLE_NAME_ERROR = "Table name cannot be empty"
class TestBaseHandlers:
    """Test base.py handler functions"""
    
    @pytest.fixture
    def mock_handler(self):
        """Create a mock handler"""
        handler = AsyncMock()
        handler.db_type = "mock_db"
        handler.stats = MagicMock()
        handler.stats.get_performance_stats.return_value = "Performance stats"
        handler.execute_query = AsyncMock(return_value="Query result")
        handler.execute_tool_query = AsyncMock(return_value="Tool result")
        handler.explain_query = AsyncMock(return_value="Explain result")
        return handler
    
    @pytest.fixture
    def server(self):
        """Create a server instance with mocked components"""
        with patch('mcp_dbutils.base.mcp.server.Server'), \
             patch('mcp_dbutils.base.ConnectionServer._setup_handlers'), \
             patch('mcp_dbutils.base.ConnectionServer._setup_prompts'):
            server = ConnectionServer(config_path="test_config.yaml")
            # Mock the config loading
            server._config = {"connections": {"test_conn": {"type": "mock"}}}
            return server
    
    @pytest.mark.asyncio
    async def test_handle_run_query_empty(self, server):
        """Test _handle_run_query with empty query"""
        with pytest.raises(ConfigurationError, match=EMPTY_QUERY_ERROR):
            await server._handle_run_query("test_conn", "")
    
    @pytest.mark.asyncio
    async def test_handle_run_query_non_select(self, server):
        """Test _handle_run_query with non-SELECT query"""
        with pytest.raises(ConfigurationError, match=SELECT_ONLY_ERROR):
            await server._handle_run_query("test_conn", "INSERT INTO table VALUES (1)")
    
    @pytest.mark.asyncio
    async def test_handle_run_query(self, server, mock_handler):
        """Test _handle_run_query with valid query"""
        with patch.object(server, 'get_handler') as mock_get_handler:
            mock_get_handler.return_value.__aenter__.return_value = mock_handler
            
            result = await server._handle_run_query("test_conn", "SELECT * FROM table")
            
            assert isinstance(result, list)
            assert len(result) == 1
            assert isinstance(result[0], types.TextContent)
            assert result[0].text == "Query result"
            mock_handler.execute_query.assert_awaited_once_with("SELECT * FROM table")
    
    @pytest.mark.asyncio
    async def test_handle_table_tools_empty(self, server):
        """Test _handle_table_tools with empty table name"""
        with pytest.raises(ConfigurationError, match=EMPTY_TABLE_NAME_ERROR):
            await server._handle_table_tools("tool_name", "test_conn", "")
    
    @pytest.mark.asyncio
    async def test_handle_table_tools(self, server, mock_handler):
        """Test _handle_table_tools with valid table name"""
        with patch.object(server, 'get_handler') as mock_get_handler:
            mock_get_handler.return_value.__aenter__.return_value = mock_handler
            
            result = await server._handle_table_tools("tool_name", "test_conn", "table_name")
            
            assert isinstance(result, list)
            assert len(result) == 1
            assert isinstance(result[0], types.TextContent)
            assert result[0].text == "Tool result"
            mock_handler.execute_tool_query.assert_awaited_once_with("tool_name", table_name="table_name")
    
    @pytest.mark.asyncio
    async def test_handle_explain_query_empty(self, server):
        """Test _handle_explain_query with empty query"""
        with pytest.raises(ConfigurationError, match=EMPTY_QUERY_ERROR):
            await server._handle_explain_query("test_conn", "")
    
    @pytest.mark.asyncio
    async def test_handle_explain_query(self, server, mock_handler):
        """Test _handle_explain_query with valid query"""
        with patch.object(server, 'get_handler') as mock_get_handler:
            mock_get_handler.return_value.__aenter__.return_value = mock_handler
            
            result = await server._handle_explain_query("test_conn", "SELECT * FROM table")
            
            assert isinstance(result, list)
            assert len(result) == 1
            assert isinstance(result[0], types.TextContent)
            assert result[0].text == "Tool result"
            mock_handler.execute_tool_query.assert_awaited_once_with("dbutils-explain-query", sql="SELECT * FROM table")
    
    @pytest.mark.asyncio
    async def test_handle_performance(self, server, mock_handler):
        """Test _handle_performance"""
        with patch.object(server, 'get_handler') as mock_get_handler:
            mock_get_handler.return_value.__aenter__.return_value = mock_handler
            
            result = await server._handle_performance("test_conn")
            
            assert isinstance(result, list)
            assert len(result) == 1
            assert isinstance(result[0], types.TextContent)
            assert result[0].text == "[mock_db]\nPerformance stats"
            mock_handler.stats.get_performance_stats.assert_called_once()