mcp-dbutils

"""Unit tests for SQLite server implementation""" import json import sqlite3 from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import mcp.types as types import pytest from mcp_dbutils.sqlite.config import SQLiteConfig from mcp_dbutils.sqlite.server import SQLiteServer @pytest.fixture def mock_sqlite_config(): """Mock SQLite configuration""" config = MagicMock(spec=SQLiteConfig) config.path = "/path/to/test.db" config.absolute_path = "/path/to/test.db" config.debug = False # Mock the get_connection_params method config.get_connection_params.return_value = { "database": "/path/to/test.db" } # Mock the get_masked_connection_info method config.get_masked_connection_info.return_value = { "database": "/path/to/test.db" } return config @pytest.fixture def mock_cursor(): """Mock SQLite cursor""" cursor = MagicMock() cursor.fetchall.return_value = [] cursor.description = [] return cursor @pytest.fixture def mock_connection(mock_cursor): """Mock SQLite connection""" connection = MagicMock() connection.execute.return_value = mock_cursor connection.row_factory = None connection.close = MagicMock() return connection class TestSQLiteServer: """Test SQLite server implementation""" @patch("sqlite3.connect") @patch("pathlib.Path.mkdir") def test_init(self, mock_mkdir, mock_connect, mock_sqlite_config): """Test server initialization""" # Setup mock_conn = MagicMock() mock_connect.return_value = mock_conn mock_conn.__enter__.return_value = mock_conn mock_conn.__exit__.return_value = None # Execute server = SQLiteServer(mock_sqlite_config) # Verify assert server.config == mock_sqlite_config mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) mock_connect.assert_called_once() @patch("sqlite3.connect") def test_get_connection(self, mock_connect, mock_sqlite_config): """Test getting connection""" # Setup mock_conn = MagicMock() mock_connect.return_value = mock_conn with patch.object(SQLiteServer, "__init__", return_value=None): server = SQLiteServer(None) server.config = mock_sqlite_config server.log = MagicMock() # Execute connection = server._get_connection() # Verify assert connection == mock_conn mock_connect.assert_called_once_with(**mock_sqlite_config.get_connection_params()) assert connection.row_factory == sqlite3.Row @pytest.mark.asyncio @patch("sqlite3.connect") async def test_list_resources(self, mock_connect, mock_sqlite_config, mock_connection, mock_cursor): """Test listing resources""" # Setup mock_connect.return_value = mock_connection mock_cursor.fetchall.return_value = [("users",), ("products",)] with patch.object(SQLiteServer, "__init__", return_value=None), \ patch.object(SQLiteServer, "_get_connection", return_value=mock_connection): server = SQLiteServer(None) server.config = mock_sqlite_config server.log = MagicMock() # Execute resources = await server.list_resources() # Verify assert len(resources) == 2 assert resources[0].name == "users schema" # Convert AnyUrl to string for comparison assert str(resources[0].uri) == "sqlite://users/schema" assert resources[1].name == "products schema" mock_connection.execute.assert_called_once() mock_cursor.fetchall.assert_called_once() @pytest.mark.asyncio @patch("sqlite3.connect") async def test_read_resource(self, mock_connect, mock_sqlite_config, mock_connection, mock_cursor): """Test reading resource""" # Setup mock_connect.return_value = mock_connection # Mock table_info results table_info_results = [ {"name": "id", "type": "INTEGER", "notnull": 1, "pk": 1}, {"name": "name", "type": "TEXT", "notnull": 0, "pk": 0} ] # Mock index_list results index_list_results = [ {"name": "idx_name", "unique": 1} ] # Configure mock cursor to return different results for different queries def mock_execute(query): if "table_info" in query: mock_cursor.fetchall.return_value = table_info_results elif "index_list" in query: mock_cursor.fetchall.return_value = index_list_results return mock_cursor mock_connection.execute.side_effect = mock_execute with patch.object(SQLiteServer, "__init__", return_value=None), \ patch.object(SQLiteServer, "_get_connection", return_value=mock_connection): server = SQLiteServer(None) server.config = mock_sqlite_config server.log = MagicMock() # Execute result = await server.read_resource("sqlite://users/schema") # Verify import json result_dict = json.loads(result) # Convert JSON string to dict assert len(result_dict["columns"]) == 2 assert result_dict["columns"][0]["name"] == "id" assert result_dict["columns"][0]["primary_key"] is True assert result_dict["columns"][1]["name"] == "name" assert result_dict["columns"][1]["nullable"] is True assert len(result_dict["indexes"]) == 1 assert result_dict["indexes"][0]["name"] == "idx_name" assert result_dict["indexes"][0]["unique"] is True assert mock_connection.execute.call_count == 2 def test_get_tools(self, mock_sqlite_config): """Test getting tools""" # Setup with patch.object(SQLiteServer, "__init__", return_value=None): server = SQLiteServer(None) # Execute tools = server.get_tools() # Verify assert len(tools) == 1 assert tools[0].name == "query" assert "SQL" in tools[0].description assert "sql" in tools[0].inputSchema["properties"] assert "sql" in tools[0].inputSchema["required"] @pytest.mark.asyncio @patch("sqlite3.connect") async def test_call_tool_query(self, mock_connect, mock_sqlite_config, mock_connection, mock_cursor): """Test calling query tool""" # Setup mock_connect.return_value = mock_connection mock_cursor.description = [("id",), ("name",)] mock_cursor.fetchall.return_value = [{"id": 1, "name": "Test User"}] with patch.object(SQLiteServer, "__init__", return_value=None), \ patch.object(SQLiteServer, "_get_connection", return_value=mock_connection): server = SQLiteServer(None) server.config = mock_sqlite_config server.log = MagicMock() # Execute result = await server.call_tool("query", {"sql": "SELECT * FROM users"}) # Verify assert len(result) == 1 assert result[0].type == "text" import json result_dict = json.loads(result[0].text) assert result_dict["type"] == "sqlite" assert result_dict["query_result"]["row_count"] == 1 assert "id" in result_dict["query_result"]["columns"] assert "name" in result_dict["query_result"]["columns"] mock_connection.execute.assert_called_once_with("SELECT * FROM users") mock_cursor.fetchall.assert_called_once() @pytest.mark.asyncio @patch("sqlite3.connect") async def test_call_tool_with_connection(self, mock_connect, mock_sqlite_config, mock_cursor): """Test calling query tool with specific connection""" # Setup mock_connection = MagicMock() mock_connection.execute.return_value = mock_cursor mock_cursor.description = [("id",), ("name",)] mock_cursor.fetchall.return_value = [{"id": 1, "name": "Test User"}] mock_connect.return_value = mock_connection with patch.object(SQLiteServer, "__init__", return_value=None), \ patch.object(SQLiteConfig, "from_yaml") as mock_from_yaml: mock_config = MagicMock(spec=SQLiteConfig) mock_config.get_connection_params.return_value = {"database": "/path/to/other.db"} mock_config.get_masked_connection_info.return_value = {"database": "/path/to/other.db"} mock_from_yaml.return_value = mock_config server = SQLiteServer(None) server.config = mock_sqlite_config server.config_path = "/path/to/config.yaml" server.log = MagicMock() # Execute the method with a specific connection try: await server.call_tool("query", { "sql": "SELECT * FROM users", "connection": "test_connection" }) except Exception: # We don't care about the result, just that from_yaml was called pass # Verify that from_yaml was called with the correct arguments mock_from_yaml.assert_called_once_with("/path/to/config.yaml", "test_connection") mock_connect.assert_called_once_with(**mock_config.get_connection_params()) @pytest.mark.asyncio async def test_call_tool_invalid_name(self, mock_sqlite_config): """Test calling invalid tool""" # Setup with patch.object(SQLiteServer, "__init__", return_value=None): server = SQLiteServer(None) server.config = mock_sqlite_config server.log = MagicMock() # Execute and verify with pytest.raises(ValueError, match="未知工具"): await server.call_tool("invalid_tool", {}) @pytest.mark.asyncio async def test_call_tool_empty_sql(self, mock_sqlite_config): """Test calling query tool with empty SQL""" # Setup with patch.object(SQLiteServer, "__init__", return_value=None): server = SQLiteServer(None) server.config = mock_sqlite_config server.log = MagicMock() # Execute and verify with pytest.raises(ValueError, match="SQL查询不能为空"): await server.call_tool("query", {"sql": ""}) @pytest.mark.asyncio async def test_call_tool_non_select(self, mock_sqlite_config): """Test calling query tool with non-SELECT SQL""" # Setup with patch.object(SQLiteServer, "__init__", return_value=None): server = SQLiteServer(None) server.config = mock_sqlite_config server.log = MagicMock() # Execute and verify with pytest.raises(ValueError, match="仅支持SELECT查询"): await server.call_tool("query", {"sql": "DELETE FROM users"}) @pytest.mark.asyncio @patch("sqlite3.connect") async def test_call_tool_query_error(self, mock_connect, mock_sqlite_config): """Test calling query tool with error""" # Setup mock_connect.return_value = MagicMock() mock_connect.return_value.execute.side_effect = sqlite3.Error("no such table: users") with patch.object(SQLiteServer, "__init__", return_value=None), \ patch.object(SQLiteServer, "_get_connection", side_effect=mock_connect): server = SQLiteServer(None) server.config = mock_sqlite_config server.log = MagicMock() # Execute result = await server.call_tool("query", {"sql": "SELECT * FROM users"}) # Verify assert len(result) == 1 assert result[0].type == "text" import json result_dict = json.loads(result[0].text) assert result_dict["type"] == "sqlite" assert "error" in result_dict assert "no such table" in result_dict["error"] @pytest.mark.asyncio async def test_cleanup(self, mock_sqlite_config): """Test cleanup""" # Setup with patch.object(SQLiteServer, "__init__", return_value=None): server = SQLiteServer(None) server.config = mock_sqlite_config # Execute await server.cleanup() # Verify - SQLite doesn't need cleanup, so just make sure it doesn't error pass