Skip to main content
Glama

mcp-dbutils

MIT License
81
  • Linux
  • Apple
"""Tests for the write-operation support in ``mcp_dbutils.base``.

Covers SQL statement classification and table-name extraction on
``ConnectionHandler``, plus the ``ConnectionServer`` handlers behind the
``dbutils-execute-write`` and ``dbutils-get-audit-logs`` tools.
"""

import json
import re
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from mcp.types import TextContent

from mcp_dbutils.base import (
    UNSUPPORTED_WRITE_OPERATION_ERROR,
    WRITE_CONFIRMATION_REQUIRED_ERROR,
    ConfigurationError,
    ConnectionHandlerError,
    ConnectionServer,
)

# Config file used by both the server fixture and the stub handler.
_CONFIG_PATH = "tests/fixtures/config.yaml"


def _writable_mysql_config():
    """Return a fresh connection-config dict for patching ``_get_config_or_raise``.

    A new dict is built on every call so that a test (or the code under test)
    mutating it cannot leak state into another test.
    """
    return {
        "type": "mysql",
        "host": "localhost",
        "port": 3306,
        "database": "test",
        "user": "test",
        "password": "test",
        "writable": True,
    }


def _make_stub_handler(connection="test_conn"):
    """Return a minimal concrete ``ConnectionHandler`` for exercising its helpers.

    Every method defined below is stubbed with a trivial return value
    (presumably these are the members ``ConnectionHandler`` leaves abstract --
    confirm against ``mcp_dbutils.base``). The tests only call the inherited
    ``_get_sql_type`` / ``_extract_table_name`` helpers.

    Args:
        connection: Connection name passed to the handler constructor.
    """
    # Imported lazily, mirroring how the original tests pulled it in.
    from mcp_dbutils.base import ConnectionHandler

    # Deliberately not named ``Test*`` so pytest never tries to collect it.
    class _StubHandler(ConnectionHandler):
        @property
        def db_type(self) -> str:
            return "test"

        async def get_tables(self):
            return []

        async def get_schema(self, table_name: str):
            return ""

        async def _execute_query(self, sql: str):
            return ""

        async def _execute_write_query(self, sql: str):
            return ""

        async def get_table_description(self, table_name: str):
            return ""

        async def get_table_ddl(self, table_name: str):
            return ""

        async def get_table_indexes(self, table_name: str):
            return ""

        async def get_table_stats(self, table_name: str):
            return ""

        async def get_table_constraints(self, table_name: str):
            return ""

        async def explain_query(self, sql: str):
            return ""

        async def test_connection(self):
            return True

        async def cleanup(self):
            pass

    return _StubHandler(_CONFIG_PATH, connection)


class TestBaseWriteOperations:
    """Test base module write operations"""

    @pytest.fixture
    def connection_server(self):
        """Create a ConnectionServer instance for testing"""
        server = ConnectionServer(_CONFIG_PATH, debug=True)
        # Replace send_log so tests don't emit real log traffic.
        server.send_log = MagicMock()
        return server

    @pytest.mark.asyncio
    async def test_get_sql_type(self, connection_server):
        """_get_sql_type maps each statement kind to its label, case-insensitively."""
        handler = _make_stub_handler()

        cases = [
            ("SELECT * FROM users", "SELECT"),
            ("INSERT INTO users VALUES (1, 'test')", "INSERT"),
            ("UPDATE users SET name = 'test' WHERE id = 1", "UPDATE"),
            ("DELETE FROM users WHERE id = 1", "DELETE"),
            ("CREATE TABLE users (id INT)", "CREATE"),
            ("ALTER TABLE users ADD COLUMN name TEXT", "ALTER"),
            ("DROP TABLE users", "DROP"),
            ("TRUNCATE TABLE users", "TRUNCATE"),
            ("BEGIN TRANSACTION", "TRANSACTION_START"),
            ("START TRANSACTION", "TRANSACTION_START"),
            ("COMMIT", "TRANSACTION_COMMIT"),
            ("ROLLBACK", "TRANSACTION_ROLLBACK"),
            ("UNKNOWN STATEMENT", "UNKNOWN"),
            # Case insensitivity
            ("select * from users", "SELECT"),
            ("insert into users values (1, 'test')", "INSERT"),
        ]
        for sql, expected in cases:
            assert handler._get_sql_type(sql) == expected, sql

    @pytest.mark.asyncio
    async def test_extract_table_name(self, connection_server):
        """_extract_table_name handles plain, schema-qualified and quoted names."""
        handler = _make_stub_handler()

        cases = [
            ("INSERT INTO users VALUES (1, 'test')", "users"),
            ("INSERT INTO public.users VALUES (1, 'test')", "public.users"),
            ("UPDATE users SET name = 'test' WHERE id = 1", "users"),
            ("DELETE FROM users WHERE id = 1", "users"),
            # Double-quoted, backtick-quoted and bracket-quoted identifiers
            ('INSERT INTO "users" VALUES (1, \'test\')', "users"),
            ("INSERT INTO `users` VALUES (1, 'test')", "users"),
            ("INSERT INTO [users] VALUES (1, 'test')", "users"),
        ]
        for sql, expected in cases:
            assert handler._extract_table_name(sql).lower() == expected, sql

        # Unrecognised statements fall back to a sentinel name (compared exactly).
        assert handler._extract_table_name("UNKNOWN STATEMENT") == "unknown_table"

    @pytest.mark.asyncio
    async def test_handle_execute_write_no_confirmation(self, connection_server):
        """_handle_execute_write rejects a call with an empty confirmation."""
        # re.escape: the message is matched with re.search and must be literal.
        with pytest.raises(
            ConfigurationError, match=re.escape(WRITE_CONFIRMATION_REQUIRED_ERROR)
        ):
            await connection_server._handle_execute_write(
                "test_conn", "INSERT INTO users VALUES (1, 'test')", ""
            )

    @pytest.mark.asyncio
    async def test_handle_execute_write_unsupported_operation(self, connection_server):
        """_handle_execute_write rejects statement types that are not writes."""
        expected_message = re.escape(
            UNSUPPORTED_WRITE_OPERATION_ERROR.format(operation="SELECT")
        )
        with (
            patch(
                "mcp_dbutils.base.ConnectionHandler._get_sql_type",
                return_value="SELECT",
            ),
            patch.object(
                connection_server, "_get_sql_type", side_effect=lambda sql: "SELECT"
            ),
            pytest.raises(ConfigurationError, match=expected_message),
        ):
            await connection_server._handle_execute_write(
                "test_conn", "SELECT * FROM users", "CONFIRM_WRITE"
            )

    @pytest.mark.asyncio
    async def test_handle_execute_write_success(self, connection_server):
        """A confirmed INSERT is forwarded to the handler and its result returned."""
        sql = "INSERT INTO users VALUES (1, 'test')"
        with (
            patch.object(
                connection_server, "_get_sql_type", side_effect=lambda sql: "INSERT"
            ),
            patch.object(
                connection_server, "_extract_table_name", side_effect=lambda sql: "users"
            ),
            patch.object(
                connection_server,
                "_get_config_or_raise",
                return_value=_writable_mysql_config(),
            ),
            patch.object(
                connection_server, "_check_write_permission", new_callable=AsyncMock
            ),
        ):
            # get_handler() is used as an async context manager; the object
            # yielded by __aenter__ is the handler that runs the write.
            mock_handler = AsyncMock()
            mock_handler.__aenter__.return_value.execute_write_query.return_value = (
                "Write operation executed successfully. 1 row affected."
            )
            connection_server.get_handler = MagicMock(return_value=mock_handler)

            result = await connection_server._handle_execute_write(
                "test_conn", sql, "CONFIRM_WRITE"
            )

            assert isinstance(result, list)
            assert len(result) == 1
            assert isinstance(result[0], TextContent)
            assert result[0].type == "text"
            assert "Write operation executed successfully" in result[0].text

            connection_server.get_handler.assert_called_once_with("test_conn")
            mock_handler.__aenter__.return_value.execute_write_query.assert_called_once_with(
                sql
            )

    @pytest.mark.asyncio
    async def test_handle_execute_write_error(self, connection_server):
        """Errors raised by the handler propagate out of _handle_execute_write."""
        sql = "UPDATE users SET name = 'test' WHERE id = 1"
        with (
            patch.object(
                connection_server, "_get_sql_type", side_effect=lambda sql: "UPDATE"
            ),
            patch.object(
                connection_server, "_extract_table_name", side_effect=lambda sql: "users"
            ),
            patch.object(
                connection_server,
                "_get_config_or_raise",
                return_value=_writable_mysql_config(),
            ),
            patch.object(
                connection_server, "_check_write_permission", new_callable=AsyncMock
            ),
        ):
            mock_handler = AsyncMock()
            mock_handler.__aenter__.return_value.execute_write_query.side_effect = (
                ConnectionHandlerError("Database error")
            )
            connection_server.get_handler = MagicMock(return_value=mock_handler)

            with pytest.raises(ConnectionHandlerError, match="Database error"):
                await connection_server._handle_execute_write(
                    "test_conn", sql, "CONFIRM_WRITE"
                )

            connection_server.get_handler.assert_called_once_with("test_conn")
            mock_handler.__aenter__.return_value.execute_write_query.assert_called_once_with(
                sql
            )

    @pytest.mark.asyncio
    async def test_handle_get_audit_logs(self, connection_server):
        """_handle_get_audit_logs forwards filters to get_logs and wraps format_logs."""
        with (
            patch("mcp_dbutils.base.get_logs") as mock_get_logs,
            patch("mcp_dbutils.base.format_logs") as mock_format_logs,
        ):
            mock_get_logs.return_value = [
                {
                    "timestamp": "2023-01-01T12:00:00",
                    "connection_name": "test_conn",
                    "table_name": "users",
                    "operation_type": "INSERT",
                    "sql_statement": "INSERT INTO users VALUES (?)",
                    "affected_rows": 1,
                    "status": "SUCCESS",
                    "execution_time": 10.5,
                }
            ]
            mock_format_logs.return_value = "Formatted audit logs"

            # All filters supplied
            result = await connection_server._handle_get_audit_logs(
                "test_conn", "users", "INSERT", "SUCCESS", 10
            )

            assert isinstance(result, list)
            assert len(result) == 1
            assert isinstance(result[0], TextContent)
            assert result[0].type == "text"
            assert "Formatted audit logs" in result[0].text

            mock_get_logs.assert_called_once_with(
                connection_name="test_conn",
                table_name="users",
                operation_type="INSERT",
                status="SUCCESS",
                limit=10,
            )
            mock_format_logs.assert_called_once_with(mock_get_logs.return_value)

            # Minimal filters: empty strings are passed through unchanged.
            mock_get_logs.reset_mock()
            mock_format_logs.reset_mock()

            result = await connection_server._handle_get_audit_logs(
                "test_conn", "", "", "", 100
            )

            mock_get_logs.assert_called_once_with(
                connection_name="test_conn",
                table_name="",
                operation_type="",
                status="",
                limit=100,
            )
            # Previously the second result was never checked.
            mock_format_logs.assert_called_once_with(mock_get_logs.return_value)
            assert isinstance(result, list)
            assert len(result) == 1
            assert "Formatted audit logs" in result[0].text

    @pytest.mark.asyncio
    async def test_handle_get_audit_logs_error(self, connection_server):
        """Errors from get_logs propagate out of _handle_get_audit_logs."""
        with (
            patch("mcp_dbutils.base.get_logs", side_effect=ValueError("Test error")),
            pytest.raises(ValueError, match="Test error"),
        ):
            await connection_server._handle_get_audit_logs(
                "test_conn", "users", "INSERT", "SUCCESS", 10
            )

    # NOTE(review): the two call_tool tests below drive a dispatcher defined
    # locally in each test, not the server's registered call_tool handler, so
    # they only verify argument plumbing of that local replica. Consider
    # invoking the real handler once its registration API is test-accessible.

    @pytest.mark.asyncio
    async def test_handle_call_tool_execute_write(self, connection_server):
        """Test call_tool handler with dbutils-execute-write tool"""
        connection_server._handle_execute_write = AsyncMock(
            return_value=[
                TextContent(type="text", text="Write operation executed successfully")
            ]
        )

        async def mock_handle_call_tool(name, arguments):
            if name == "dbutils-execute-write":
                connection = arguments.get("connection", "")
                sql = arguments.get("sql", "").strip()
                confirmation = arguments.get("confirmation", "").strip()
                return await connection_server._handle_execute_write(
                    connection, sql, confirmation
                )
            else:
                raise ValueError(f"Unknown tool: {name}")

        result = await mock_handle_call_tool(
            "dbutils-execute-write",
            {
                "connection": "test_conn",
                "sql": "INSERT INTO users (name) VALUES ('Test User')",
                "confirmation": "CONFIRM_WRITE",
            },
        )

        assert result == [
            TextContent(type="text", text="Write operation executed successfully")
        ]
        connection_server._handle_execute_write.assert_called_once_with(
            "test_conn",
            "INSERT INTO users (name) VALUES ('Test User')",
            "CONFIRM_WRITE",
        )

    @pytest.mark.asyncio
    async def test_handle_call_tool_get_audit_logs(self, connection_server):
        """Test call_tool handler with dbutils-get-audit-logs tool"""
        connection_server._handle_get_audit_logs = AsyncMock(
            return_value=[TextContent(type="text", text="Audit logs")]
        )

        async def mock_handle_call_tool(name, arguments):
            if name == "dbutils-get-audit-logs":
                connection = arguments.get("connection", "")
                table = arguments.get("table", "").strip()
                operation_type = arguments.get("operation_type", "").strip()
                status = arguments.get("status", "").strip()
                limit = arguments.get("limit", 100)
                return await connection_server._handle_get_audit_logs(
                    connection, table, operation_type, status, limit
                )
            else:
                raise ValueError(f"Unknown tool: {name}")

        result = await mock_handle_call_tool(
            "dbutils-get-audit-logs",
            {
                "connection": "test_conn",
                "table": "users",
                "operation_type": "INSERT",
                "status": "SUCCESS",
                "limit": 10,
            },
        )

        assert result == [TextContent(type="text", text="Audit logs")]
        connection_server._handle_get_audit_logs.assert_called_once_with(
            "test_conn", "users", "INSERT", "SUCCESS", 10
        )

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/donghao1393/mcp-dbutils'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.