mcp-dbutils

"""Unit tests for MySQL server implementation"""

import ast
import json
from unittest.mock import AsyncMock, MagicMock, patch

import mcp.types as types
import pytest
from mysql.connector.pooling import MySQLConnectionPool

from mcp_dbutils.mysql.config import MySQLConfig
from mcp_dbutils.mysql.server import MySQLServer


def _parse_payload(text):
    """Parse a server response string back into Python objects.

    The assertions in this module treat server output as the string form of
    a Python literal (dicts/lists of str, int, bool, None) — presumably
    produced by ``str(...)`` on a dict inside the server; confirm if the
    server switches to JSON. ``ast.literal_eval`` accepts exactly such
    literals and, unlike ``eval``, cannot execute code contained in the
    response text.
    """
    return ast.literal_eval(text)


@pytest.fixture
def mock_mysql_config():
    """Mock MySQL configuration"""
    config = MagicMock(spec=MySQLConfig)
    config.host = "localhost"
    config.port = 3306
    config.database = "test_db"
    config.user = "test_user"
    config.password = "test_password"
    config.debug = False

    # Mock the get_connection_params method
    config.get_connection_params.return_value = {
        "host": "localhost",
        "port": 3306,
        "database": "test_db",
        "user": "test_user",
        "password": "test_password",
    }

    # Mock the get_masked_connection_info method
    config.get_masked_connection_info.return_value = {
        "host": "localhost",
        "port": 3306,
        "database": "test_db",
        "user": "test_user",
        "password": "********",
    }

    return config


@pytest.fixture
def mock_cursor():
    """Mock MySQL cursor usable as a context manager (``with conn.cursor() as c``)."""
    cursor = MagicMock()
    # Entering the context yields the same mock so assertions can target it.
    cursor.__enter__ = MagicMock(return_value=cursor)
    cursor.__exit__ = MagicMock(return_value=None)
    return cursor


@pytest.fixture
def mock_connection(mock_cursor):
    """Mock MySQL connection whose ``cursor()`` returns ``mock_cursor``."""
    connection = MagicMock()
    connection.cursor.return_value = mock_cursor
    connection.close = MagicMock()
    return connection


@pytest.fixture
def mock_pool(mock_connection):
    """Mock MySQL connection pool whose ``get_connection()`` returns ``mock_connection``."""
    pool = MagicMock(spec=MySQLConnectionPool)
    pool.get_connection.return_value = mock_connection
    return pool


class TestMySQLServer:
    """Test MySQL server implementation.

    Every test patches ``MySQLServer.__init__`` to a no-op and then assigns
    the attributes the method under test reads (``config``, ``pool``,
    ``log``), so no real database connection is attempted.
    """

    def test_init(self, mock_mysql_config):
        """Test server initialization"""
        # Skip actual initialization and test the class structure
        with patch.object(MySQLServer, "__init__", return_value=None) as mock_init:
            server = MySQLServer(mock_mysql_config)
            mock_init.assert_called_once_with(mock_mysql_config)

            # Manually set attributes that would be set in __init__
            server.config = mock_mysql_config
            server.pool = MagicMock(spec=MySQLConnectionPool)

            # Verify
            assert server.config == mock_mysql_config
            assert hasattr(server, "pool")

    @pytest.mark.asyncio
    async def test_list_resources(self, mock_mysql_config, mock_pool, mock_cursor):
        """Test listing resources: one schema resource per table row returned."""
        # Setup
        mock_tables = [
            {"table_name": "users", "description": "User table"},
            {"table_name": "products", "description": None},
        ]
        mock_cursor.fetchall.return_value = mock_tables

        with patch.object(MySQLServer, "__init__", return_value=None):
            server = MySQLServer(None)
            server.config = mock_mysql_config
            server.pool = mock_pool
            server.log = MagicMock()

            # Execute
            resources = await server.list_resources()

            # Verify
            assert len(resources) == 2
            assert resources[0].name == "users schema"
            # Convert AnyUrl to string for comparison
            assert str(resources[0].uri) == "mysql://localhost/users/schema"
            assert resources[0].description == "User table"
            assert resources[1].name == "products schema"
            assert resources[1].description is None

            mock_pool.get_connection.assert_called_once()
            mock_cursor.execute.assert_called_once()
            mock_cursor.fetchall.assert_called_once()

    @pytest.mark.asyncio
    async def test_read_resource(self, mock_mysql_config, mock_pool, mock_cursor):
        """Test reading a table schema resource (columns + constraints)."""
        # Setup
        mock_columns = [
            {"column_name": "id", "data_type": "int", "is_nullable": "NO", "description": "Primary key"},
            {"column_name": "name", "data_type": "varchar", "is_nullable": "YES", "description": "User name"},
        ]
        mock_constraints = [
            {"constraint_name": "pk_users", "constraint_type": "PRIMARY KEY"},
        ]

        # Configure mock cursor to return different results for different
        # queries: the next fetchall() result is chosen by the SQL text.
        def mock_execute(query, params=None):
            if "columns" in query:
                mock_cursor.fetchall.return_value = mock_columns
            elif "constraints" in query:
                mock_cursor.fetchall.return_value = mock_constraints

        mock_cursor.execute.side_effect = mock_execute

        with patch.object(MySQLServer, "__init__", return_value=None):
            server = MySQLServer(None)
            server.config = mock_mysql_config
            server.pool = mock_pool
            server.log = MagicMock()

            # Execute
            result = await server.read_resource("mysql://localhost/users/schema")

            # Verify: convert the string representation back to a dict safely
            result_dict = _parse_payload(result)
            assert len(result_dict["columns"]) == 2
            assert result_dict["columns"][0]["name"] == "id"
            assert result_dict["columns"][0]["nullable"] is False
            assert len(result_dict["constraints"]) == 1
            assert result_dict["constraints"][0]["name"] == "pk_users"

            mock_pool.get_connection.assert_called_once()
            assert mock_cursor.execute.call_count == 2

    def test_get_tools(self, mock_mysql_config):
        """Test getting tools: a single ``query`` tool requiring ``sql``."""
        # Setup
        with patch.object(MySQLServer, "__init__", return_value=None):
            server = MySQLServer(None)

            # Execute
            tools = server.get_tools()

            # Verify
            assert len(tools) == 1
            assert tools[0].name == "query"
            assert "SQL" in tools[0].description
            assert "sql" in tools[0].inputSchema["properties"]
            assert "sql" in tools[0].inputSchema["required"]

    @pytest.mark.asyncio
    async def test_call_tool_query(self, mock_mysql_config, mock_pool, mock_cursor):
        """Test calling query tool with a valid SELECT."""
        # Setup
        mock_cursor.description = [("id",), ("name",)]
        mock_cursor.fetchall.return_value = [{"id": 1, "name": "Test User"}]

        with patch.object(MySQLServer, "__init__", return_value=None):
            server = MySQLServer(None)
            server.config = mock_mysql_config
            server.pool = mock_pool
            server.log = MagicMock()

            # Execute
            result = await server.call_tool("query", {"sql": "SELECT * FROM users"})

            # Verify
            assert len(result) == 1
            assert result[0].type == "text"
            result_dict = _parse_payload(result[0].text)
            assert result_dict["type"] == "mysql"
            assert result_dict["query_result"]["row_count"] == 1
            assert "id" in result_dict["query_result"]["columns"]
            assert "name" in result_dict["query_result"]["columns"]

            mock_pool.get_connection.assert_called_once()
            # Lower bound only: the query runs alongside transaction-control
            # statements (presumably SET TRANSACTION + query + ROLLBACK), so
            # at least two execute() calls are expected.
            assert mock_cursor.execute.call_count >= 2

    @pytest.mark.asyncio
    async def test_call_tool_invalid_name(self, mock_mysql_config):
        """Test calling invalid tool raises ValueError ("unknown tool")."""
        # Setup
        with patch.object(MySQLServer, "__init__", return_value=None):
            server = MySQLServer(None)
            server.config = mock_mysql_config
            server.log = MagicMock()

            # Execute and verify
            with pytest.raises(ValueError, match="未知工具"):
                await server.call_tool("invalid_tool", {})

    @pytest.mark.asyncio
    async def test_call_tool_empty_sql(self, mock_mysql_config):
        """Test calling query tool with empty SQL raises ValueError ("SQL query cannot be empty")."""
        # Setup
        with patch.object(MySQLServer, "__init__", return_value=None):
            server = MySQLServer(None)
            server.config = mock_mysql_config
            server.log = MagicMock()

            # Execute and verify
            with pytest.raises(ValueError, match="SQL查询不能为空"):
                await server.call_tool("query", {"sql": ""})

    @pytest.mark.asyncio
    async def test_call_tool_non_select(self, mock_mysql_config):
        """Test calling query tool with non-SELECT SQL raises ValueError ("only SELECT supported")."""
        # Setup
        with patch.object(MySQLServer, "__init__", return_value=None):
            server = MySQLServer(None)
            server.config = mock_mysql_config
            server.log = MagicMock()

            # Execute and verify
            with pytest.raises(ValueError, match="仅支持SELECT查询"):
                await server.call_tool("query", {"sql": "DELETE FROM users"})

    @pytest.mark.asyncio
    async def test_call_tool_query_error(self, mock_mysql_config, mock_pool, mock_cursor):
        """Test calling query tool when execution fails: error is reported, not raised."""
        # Setup
        mock_cursor.execute.side_effect = Exception("Test error")

        with patch.object(MySQLServer, "__init__", return_value=None):
            server = MySQLServer(None)
            server.config = mock_mysql_config
            server.pool = mock_pool
            server.log = MagicMock()

            # Execute
            result = await server.call_tool("query", {"sql": "SELECT * FROM users"})

            # Verify
            assert len(result) == 1
            assert result[0].type == "text"
            result_dict = _parse_payload(result[0].text)
            assert result_dict["type"] == "mysql"
            assert "error" in result_dict
            assert "Test error" in result_dict["error"]

            mock_pool.get_connection.assert_called_once()

    @pytest.mark.asyncio
    async def test_cleanup(self, mock_mysql_config, mock_pool):
        """Test cleanup logs exactly once."""
        # Setup
        with patch.object(MySQLServer, "__init__", return_value=None):
            server = MySQLServer(None)
            server.pool = mock_pool
            server.log = MagicMock()

            # Execute
            await server.cleanup()

            # Verify
            server.log.assert_called_once()