# mcp-dbutils
# by donghao1393
# - tests
# - unit
"""Unit tests for base.py helper methods"""
import importlib
import json
import os
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
import mcp.types as types
import pytest
import yaml
from mcp_dbutils.base import (
ConfigurationError,
ConnectionHandler,
ConnectionServer,
)
class TestBaseHelpers:
    """Test helper methods in base.py"""

    @staticmethod
    def _handler_context(handler):
        """Return a mock standing in for the object returned by ``server.get_handler(...)``.

        ``async with`` on the returned mock yields *handler*. ``__aexit__`` is
        pinned to return ``False``: awaiting a bare ``AsyncMock()`` produces a
        truthy mock, which would make ``async with`` silently swallow any
        exception raised inside the block by the code under test.
        """
        return AsyncMock(
            __aenter__=AsyncMock(return_value=handler),
            __aexit__=AsyncMock(return_value=False),
        )

    @pytest.fixture
    def mock_config_yaml(self):
        """Mock configuration YAML content.

        Contains one valid connection per supported type, one with an
        unknown type, and one without a ``type`` field.
        """
        return """
        connections:
          test_sqlite:
            type: sqlite
            path: /path/to/test.db
          test_postgres:
            type: postgres
            host: localhost
            port: 5432
            database: test_db
            user: test_user
            password: test_password
          test_mysql:
            type: mysql
            host: localhost
            port: 3306
            database: test_db
            user: test_user
            password: test_password
          test_invalid:
            type: invalid_type
          test_missing_type:
            host: localhost
        """

    @pytest.fixture
    def server(self):
        """Create a ConnectionServer with file access mocked and send_log replaced by a MagicMock."""
        with patch('builtins.open', mock_open()):
            server = ConnectionServer("/path/to/config.yaml", debug=True)
        server.send_log = MagicMock()
        return server

    def test_get_config_or_raise_valid(self, server, mock_config_yaml):
        """Test _get_config_or_raise with valid input"""
        with patch('builtins.open', mock_open(read_data=mock_config_yaml)):
            result = server._get_config_or_raise("test_sqlite")
        assert result == {"type": "sqlite", "path": "/path/to/test.db"}

    def test_get_config_or_raise_invalid_yaml(self, server):
        """Test _get_config_or_raise with invalid YAML"""
        with patch('builtins.open', mock_open(read_data="invalid: yaml: content:")), \
                pytest.raises(yaml.YAMLError):
            server._get_config_or_raise("test_connection")

    def test_get_config_or_raise_missing_connections(self, server):
        """Test _get_config_or_raise with missing connections section"""
        with patch('builtins.open', mock_open(read_data="other_section: value")), \
                pytest.raises(ConfigurationError, match="must contain 'connections' section"):
            server._get_config_or_raise("test_connection")

    def test_get_config_or_raise_connection_not_found(self, server, mock_config_yaml):
        """Test _get_config_or_raise with non-existent connection"""
        with patch('builtins.open', mock_open(read_data=mock_config_yaml)), \
                pytest.raises(ConfigurationError, match="Connection not found"):
            server._get_config_or_raise("nonexistent")

    def test_get_config_or_raise_missing_type(self, server, mock_config_yaml):
        """Test _get_config_or_raise with connection missing type field"""
        with patch('builtins.open', mock_open(read_data=mock_config_yaml)), \
                pytest.raises(ConfigurationError, match="must include 'type' field"):
            server._get_config_or_raise("test_missing_type")

    @patch('mcp_dbutils.base.ConnectionServer._create_handler_for_type')
    @pytest.mark.asyncio
    async def test_get_handler_setup_session(self, mock_create_handler, server, mock_config_yaml):
        """Test get_handler sets session if available"""
        mock_handler = MagicMock()
        mock_handler.stats = MagicMock()
        mock_handler.cleanup = AsyncMock()
        mock_create_handler.return_value = mock_handler
        server.server = MagicMock()
        server.server.session = "test_session"
        with patch('builtins.open', mock_open(read_data=mock_config_yaml)):
            async with server.get_handler("test_sqlite") as handler:
                assert handler._session == "test_session"
                assert handler.stats.record_connection_start.called
        # Leaving the context must record the end of the connection and clean up.
        assert mock_handler.stats.record_connection_end.called
        assert mock_handler.cleanup.called

    @patch('mcp_dbutils.sqlite.handler.SQLiteHandler')
    def test_create_handler_for_type_sqlite(self, mock_sqlite_handler, server):
        """Test _create_handler_for_type with SQLite"""
        mock_instance = MagicMock()
        mock_sqlite_handler.return_value = mock_instance
        result = server._create_handler_for_type('sqlite', 'test_connection')
        mock_sqlite_handler.assert_called_once_with(
            server.config_path, 'test_connection', server.debug
        )
        assert result == mock_instance
        assert server.send_log.called

    @patch('mcp_dbutils.postgres.handler.PostgreSQLHandler')
    def test_create_handler_for_type_postgres(self, mock_postgres_handler, server):
        """Test _create_handler_for_type with PostgreSQL"""
        mock_instance = MagicMock()
        mock_postgres_handler.return_value = mock_instance
        result = server._create_handler_for_type('postgres', 'test_connection')
        mock_postgres_handler.assert_called_once_with(
            server.config_path, 'test_connection', server.debug
        )
        assert result == mock_instance
        assert server.send_log.called

    @patch('mcp_dbutils.mysql.handler.MySQLHandler')
    def test_create_handler_for_type_mysql(self, mock_mysql_handler, server):
        """Test _create_handler_for_type with MySQL"""
        mock_instance = MagicMock()
        mock_mysql_handler.return_value = mock_instance
        result = server._create_handler_for_type('mysql', 'test_connection')
        mock_mysql_handler.assert_called_once_with(
            server.config_path, 'test_connection', server.debug
        )
        assert result == mock_instance
        assert server.send_log.called

    def test_create_handler_for_type_unsupported(self, server):
        """Test _create_handler_for_type with unsupported database type"""
        with pytest.raises(ConfigurationError, match="Unsupported database type"):
            server._create_handler_for_type('unsupported', 'test_connection')

    def test_import_error_handled(self, server):
        """Test ImportError is converted to ConfigurationError"""
        # Simulate an import failure: any import whose module name mentions
        # 'sqlite' raises ImportError; everything else goes to the real importer.
        original_import = __import__

        def mock_import_error(name, *args, **kwargs):
            if 'sqlite' in name:
                raise ImportError("Module not found")
            return original_import(name, *args, **kwargs)

        config_yaml = """
        connections:
          test_connection:
            type: sqlite
            path: /path/to/db.sqlite
        """
        with patch('builtins.__import__', side_effect=mock_import_error), \
                patch('builtins.open', mock_open(read_data=config_yaml)), \
                pytest.raises(ConfigurationError, match="Failed to import"):
            server._create_handler_for_type('sqlite', 'test_connection')

    @pytest.mark.asyncio
    async def test_handle_list_tables(self, server):
        """Test _handle_list_tables helper method"""
        mock_handler = AsyncMock()
        mock_handler.db_type = "test_db"
        mock_handler.get_tables.return_value = [
            types.Resource(uri="test://table1", name="table1", description="Test Table 1"),
            types.Resource(uri="test://table2", name="table2"),
        ]
        with patch.object(server, 'get_handler', return_value=self._handler_context(mock_handler)):
            result = await server._handle_list_tables("test_connection")
        assert len(result) == 1
        assert result[0].type == "text"
        assert "[test_db]" in result[0].text
        assert "Table: table1" in result[0].text
        assert "Table: table2" in result[0].text
        assert "Description: Test Table 1" in result[0].text

    @pytest.mark.asyncio
    async def test_handle_list_tables_empty(self, server):
        """Test _handle_list_tables with empty table list"""
        mock_handler = AsyncMock()
        mock_handler.db_type = "test_db"
        mock_handler.get_tables.return_value = []
        with patch.object(server, 'get_handler', return_value=self._handler_context(mock_handler)):
            result = await server._handle_list_tables("test_connection")
        assert len(result) == 1
        assert result[0].type == "text"
        assert "[test_db] No tables found" in result[0].text

    def test_get_optimization_suggestions(self, server):
        """Test _get_optimization_suggestions helper method"""
        # Test seq scan suggestion
        result = server._get_optimization_suggestions("seq scan on table", 0.2)
        assert len(result) == 1
        assert "Consider adding an index" in result[0]

        # Test hash join suggestion
        result = server._get_optimization_suggestions("hash join on tables", 0.6)
        assert len(result) > 0
        # At least one suggestion should mention indexes or join optimization.
        assert any("index" in r.lower() or "join" in r.lower() for r in result)

        # Test slow query suggestion
        result = server._get_optimization_suggestions("normal plan", 0.6)
        assert len(result) > 0
        assert any("slow" in r.lower() or "optimiz" in r.lower() for r in result)

        # Test temporary tables suggestion
        result = server._get_optimization_suggestions("creates temporary table for", 0.1)
        assert len(result) > 0
        assert any("temporary" in r.lower() for r in result)

        # Test no suggestions
        result = server._get_optimization_suggestions("normal plan", 0.05)
        assert len(result) == 0

    @pytest.mark.asyncio
    async def test_handle_analyze_query(self, server):
        """Test _handle_analyze_query helper method"""
        mock_handler = AsyncMock()
        mock_handler.db_type = "test_db"
        mock_handler.explain_query.return_value = "Execution Plan Details"
        mock_handler.execute_query = AsyncMock()
        with patch.object(server, 'get_handler', return_value=self._handler_context(mock_handler)), \
                patch.object(server, '_get_optimization_suggestions', return_value=["- Test suggestion"]):
            # Test with a valid SQL query
            result = await server._handle_analyze_query("test_connection", "SELECT * FROM test")

            # Verify the handler methods were called
            mock_handler.explain_query.assert_called_once_with("SELECT * FROM test")
            mock_handler.execute_query.assert_called_once_with("SELECT * FROM test")

            # Check the result format
            assert len(result) == 1
            assert result[0].type == "text"
            assert "[test_db] Query Analysis" in result[0].text
            assert "SQL: SELECT * FROM test" in result[0].text
            assert "Execution Plan:" in result[0].text
            assert "Execution Plan Details" in result[0].text
            assert "Optimization Suggestions:" in result[0].text
            assert "- Test suggestion" in result[0].text

    @pytest.mark.asyncio
    async def test_handle_analyze_query_empty_sql(self, server):
        """Test _handle_analyze_query with empty SQL"""
        with pytest.raises(ConfigurationError):
            await server._handle_analyze_query("test_connection", "")

    @pytest.mark.asyncio
    async def test_handle_analyze_query_non_select(self, server):
        """Test _handle_analyze_query with non-SELECT query"""
        mock_handler = AsyncMock()
        mock_handler.db_type = "test_db"
        mock_handler.explain_query.return_value = "Execution Plan Details"
        with patch.object(server, 'get_handler', return_value=self._handler_context(mock_handler)), \
                patch.object(server, '_get_optimization_suggestions', return_value=[]):
            # Test with INSERT query (should not call execute_query)
            result = await server._handle_analyze_query("test_connection", "INSERT INTO test VALUES (1, 2)")

            # Verify explain was called but execute was not
            mock_handler.explain_query.assert_called_once_with("INSERT INTO test VALUES (1, 2)")
            mock_handler.execute_query.assert_not_called()
            assert len(result) == 1
            assert result[0].type == "text"
            assert "INSERT INTO test VALUES (1, 2)" in result[0].text

    @pytest.mark.asyncio
    async def test_handle_analyze_query_execution_error(self, server):
        """Test _handle_analyze_query with query execution error"""
        mock_handler = AsyncMock()
        mock_handler.db_type = "test_db"
        mock_handler.explain_query.return_value = "Execution Plan Details"
        mock_handler.execute_query.side_effect = Exception("Query execution failed")
        with patch.object(server, 'get_handler', return_value=self._handler_context(mock_handler)), \
                patch.object(server, '_get_optimization_suggestions', return_value=[]), \
                patch.object(server, 'send_log'):
            # Test with a SELECT query that fails during execution
            result = await server._handle_analyze_query("test_connection", "SELECT * FROM test")

            # We should still get a result with the execution plan, and the
            # failure must have been logged (checked while send_log is patched).
            assert len(result) == 1
            assert server.send_log.called
            assert "Execution Plan:" in result[0].text