mcp-dbutils
by donghao1393
- tests
- unit
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import mcp.types as types
import pytest
from mcp_dbutils.base import (
LOG_LEVEL_DEBUG,
LOG_LEVEL_ERROR,
ConfigurationError,
ConnectionServer,
)
# Constants for error messages
# Raised by the stand-in handlers in this file when the "connection" argument
# is missing; also used as the `match` pattern for pytest.raises.
CONNECTION_NAME_REQUIRED_ERROR = "Connection name is required"
# Raised by the stand-in read_resource handlers when the URI splits into
# fewer than three "/"-separated parts.
INVALID_URI_FORMAT_ERROR = "Invalid URI format"
# NOTE(review): not referenced in the visible part of this file; presumably
# used by tests further down. Confirm before removing.
DATABASE_CONNECTION_NAME = "Database connection name"
@pytest.fixture
def connection_server():
    """Provide a ConnectionServer built without touching the real filesystem.

    While the server is constructed, ``os.path.exists`` is patched to return
    True, ``builtins.open`` is replaced by a MagicMock and ``yaml.safe_load``
    is patched to return an empty dict. The returned server has its
    ``send_log`` attribute replaced by a MagicMock so tests can inspect
    logging calls.
    """
    exists_patch = patch("os.path.exists", return_value=True)
    open_patch = patch("builtins.open", MagicMock())
    yaml_patch = patch("yaml.safe_load", return_value={})

    # Patches are only needed for the duration of the constructor call.
    with exists_patch, open_patch, yaml_patch:
        instance = ConnectionServer(config_path="mock_config.yaml")

    instance.send_log = MagicMock()
    return instance
class TestConnectionServerPrompts:
    """Tests for the prompts/list handler wiring of ConnectionServer."""

    @pytest.mark.asyncio
    async def test_list_prompts(self, connection_server):
        """A stand-in list_prompts coroutine logs once and returns an empty list."""

        async def fake_list_prompts():
            connection_server.send_log()
            return []

        prompts = await fake_list_prompts()

        assert isinstance(prompts, list)
        assert not prompts
        connection_server.send_log.assert_called_once()

    @pytest.mark.asyncio
    async def test_list_prompts_exception(self, connection_server):
        """A stand-in list_prompts coroutine that logs and then raises propagates the error."""

        async def failing_list_prompts():
            connection_server.send_log()
            raise ValueError("Test exception")

        with pytest.raises(ValueError, match="Test exception"):
            await failing_list_prompts()

        # The log call happens before the exception is raised.
        connection_server.send_log.assert_called_once()

    def test_setup_prompts(self, connection_server):
        """_setup_prompts invokes the server's list_prompts decorator factory.

        If the factory happened to be called with the handler as a positional
        argument, the handler's name and docstring are checked as well.
        """
        # Decorator factory stand-in whose decorator returns the function as-is.
        passthrough = MagicMock(return_value=lambda fn: fn)

        saved_decorator = connection_server.server.list_prompts
        connection_server.server.list_prompts = passthrough
        try:
            connection_server._setup_prompts()

            assert passthrough.called, "Decorator should have been called"

            # Only inspect the handler when it was passed positionally; a
            # no-argument decorator call leaves nothing further to verify here.
            last_call = passthrough.call_args
            if last_call and last_call.args:
                handler = last_call.args[0]
                assert callable(handler)
                assert handler.__name__ == "handle_list_prompts"
                assert handler.__doc__ == "Handle prompts/list request"
        finally:
            connection_server.server.list_prompts = saved_decorator

    @pytest.mark.asyncio
    async def test_setup_prompts_exception_handler(self, connection_server):
        """The handler registered by _setup_prompts logs an error and re-raises.

        send_log is made to raise on the debug-level call; the test expects
        the same exception to surface from the handler and expects exactly
        two send_log calls (the debug one and the error one).
        """
        saved_decorator = connection_server.server.list_prompts
        captured = {"handler": None}

        def capturing_decorator():
            def register(func):
                captured["handler"] = func
                return func
            return register

        connection_server.server.list_prompts = capturing_decorator
        try:
            connection_server._setup_prompts()

            # The handler is captured; put the real decorator back right away.
            connection_server.server.list_prompts = saved_decorator

            handler = captured["handler"]
            assert handler is not None

            boom = ValueError("Test error in list_prompts")

            def raise_on_debug(*args, **kwargs):
                # Fail only on the debug-level call so the error-level call succeeds.
                if args[0] == LOG_LEVEL_DEBUG:
                    raise boom

            connection_server.send_log = MagicMock(side_effect=raise_on_debug)

            with pytest.raises(ValueError, match="Test error in list_prompts"):
                await handler()

            assert connection_server.send_log.call_count == 2
            connection_server.send_log.assert_any_call(LOG_LEVEL_DEBUG, "Handling list_prompts request")
            connection_server.send_log.assert_any_call(LOG_LEVEL_ERROR, f"Error in list_prompts: {boom}")
        finally:
            # Restore unconditionally, even if setup itself failed.
            connection_server.server.list_prompts = saved_decorator
class TestConnectionServerTools:
    """Tests for the tool catalogue exposed by ConnectionServer."""

    def test_get_available_tools(self, connection_server):
        """_get_available_tools returns a non-empty list of well-formed tools.

        The first entry must be ``dbutils-run-query``; every entry must be a
        ``types.Tool`` with a ``dbutils-`` prefixed name, a string description
        and a dict input schema.
        """
        available = connection_server._get_available_tools()

        assert isinstance(available, list)
        assert len(available) > 0

        first_tool = available[0]
        assert first_tool.name == "dbutils-run-query"
        assert "Execute read-only SQL query" in first_tool.description

        for entry in available:
            assert isinstance(entry, types.Tool)
            assert entry.name.startswith("dbutils-")
            assert isinstance(entry.description, str)
            assert isinstance(entry.inputSchema, dict)
class TestConnectionServerHandlers:
@pytest.mark.asyncio
async def test_handle_list_resources_no_connection(self, connection_server):
    """A list_resources stand-in returns [] when no connection is supplied.

    Covers both the call without arguments and the call with an empty dict.
    """

    async def list_resources_stub(arguments=None):
        # Only reach the database handler when a connection name is present.
        if arguments and 'connection' in arguments:
            async with connection_server.get_handler(arguments['connection']) as handler:
                return await handler.get_tables()
        return []

    # No arguments at all.
    without_args = await list_resources_stub()
    assert isinstance(without_args, list)
    assert not without_args

    # Empty arguments mapping.
    with_empty_args = await list_resources_stub({})
    assert isinstance(with_empty_args, list)
    assert not with_empty_args
@pytest.mark.asyncio
async def test_handle_list_resources_with_connection(self, connection_server):
    """A list_resources stand-in returns the tables reported by the handler.

    get_handler is replaced by a mock whose async context yields an object
    with a canned get_tables result of two resources.
    """
    table_resources = [
        types.Resource(
            uri="mock://table1/schema",
            name="table1",
            description="Test table 1",
            mimeType="application/json",
        ),
        types.Resource(
            uri="mock://table2/schema",
            name="table2",
            description=None,
            mimeType="application/json",
        ),
    ]

    handler_cm = AsyncMock()
    # Object bound by `async with ... as handler`.
    entered = handler_cm.__aenter__.return_value
    entered.get_tables.return_value = table_resources
    connection_server.get_handler = MagicMock(return_value=handler_cm)

    async def list_resources_stub(arguments=None):
        if arguments and 'connection' in arguments:
            async with connection_server.get_handler(arguments['connection']) as handler:
                return await handler.get_tables()
        return []

    listed = await list_resources_stub({"connection": "test_conn"})

    assert isinstance(listed, list)
    assert len(listed) == 2
    assert [resource.name for resource in listed] == ["table1", "table2"]
    connection_server.get_handler.assert_called_once_with("test_conn")
    entered.get_tables.assert_called_once()
@pytest.mark.asyncio
async def test_handle_list_resources_exception(self, connection_server):
    """A list_resources stand-in logs and re-raises a handler failure.

    get_tables is set up to raise ValueError; the stand-in must log it via
    send_log at error level and let the exception propagate.
    """
    handler_cm = AsyncMock()
    entered = handler_cm.__aenter__.return_value
    entered.get_tables.side_effect = ValueError("Test exception")
    connection_server.get_handler = MagicMock(return_value=handler_cm)
    connection_server.send_log = MagicMock()

    async def list_resources_stub(arguments=None):
        # No connection specified: nothing to list.
        if not (arguments and 'connection' in arguments):
            return []

        try:
            async with connection_server.get_handler(arguments['connection']) as handler:
                return await handler.get_tables()
        except Exception as exc:
            connection_server.send_log(LOG_LEVEL_ERROR, f"Error in list_resources: {str(exc)}")
            # Propagate so the caller can observe the failure.
            raise

    with pytest.raises(ValueError, match="Test exception"):
        await list_resources_stub({"connection": "test_conn"})

    connection_server.get_handler.assert_called_once_with("test_conn")
    entered.get_tables.assert_called_once()
    connection_server.send_log.assert_called_once_with(LOG_LEVEL_ERROR, "Error in list_resources: Test exception")
@pytest.mark.asyncio
async def test_handle_read_resource_no_connection(self, connection_server):
    """A read_resource stand-in raises ConfigurationError without a connection.

    Covers both ``None`` and an empty dict as the arguments value.
    """

    async def read_resource_stub(uri, arguments=None):
        if not (arguments and 'connection' in arguments):
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)

        segments = uri.split('/')
        if len(segments) < 3:
            raise ConfigurationError(INVALID_URI_FORMAT_ERROR)

        # URI format: xxx/table_name/schema
        async with connection_server.get_handler(arguments['connection']) as handler:
            return await handler.get_schema(segments[-2])

    # arguments is None
    with pytest.raises(ConfigurationError, match=CONNECTION_NAME_REQUIRED_ERROR):
        await read_resource_stub("mock://table/schema", None)

    # arguments is an empty dict
    with pytest.raises(ConfigurationError, match=CONNECTION_NAME_REQUIRED_ERROR):
        await read_resource_stub("mock://table/schema", {})
@pytest.mark.asyncio
async def test_handle_read_resource_invalid_uri(self, connection_server):
    """A read_resource stand-in rejects a URI with fewer than three '/' parts."""

    async def read_resource_stub(uri, arguments=None):
        if not (arguments and 'connection' in arguments):
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)

        segments = uri.split('/')
        if len(segments) < 3:
            raise ConfigurationError(INVALID_URI_FORMAT_ERROR)

        # URI format: xxx/table_name/schema
        async with connection_server.get_handler(arguments['connection']) as handler:
            return await handler.get_schema(segments[-2])

    # "invalid" contains no '/' at all, so it splits into a single part.
    with pytest.raises(ConfigurationError, match=INVALID_URI_FORMAT_ERROR):
        await read_resource_stub("invalid", {"connection": "test_conn"})
@pytest.mark.asyncio
async def test_handle_read_resource_valid(self, connection_server):
    """A read_resource stand-in returns the schema for the table named in the URI.

    The table name is taken from the second-to-last '/' segment of the URI.
    """
    handler_cm = AsyncMock()
    entered = handler_cm.__aenter__.return_value
    entered.get_schema.return_value = "table schema"
    connection_server.get_handler = MagicMock(return_value=handler_cm)

    async def read_resource_stub(uri, arguments=None):
        if not (arguments and 'connection' in arguments):
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)

        segments = uri.split('/')
        if len(segments) < 3:
            raise ConfigurationError(INVALID_URI_FORMAT_ERROR)

        # URI format: xxx/table_name/schema
        async with connection_server.get_handler(arguments['connection']) as handler:
            return await handler.get_schema(segments[-2])

    schema = await read_resource_stub("mock://table1/schema", {"connection": "test_conn"})

    assert schema == "table schema"
    connection_server.get_handler.assert_called_once_with("test_conn")
    entered.get_schema.assert_called_once_with("table1")
@pytest.mark.asyncio
async def test_handle_read_resource_exception(self, connection_server):
    """A read_resource stand-in logs and re-raises a handler failure.

    get_schema is set up to raise ValueError; the stand-in must log it via
    send_log at error level and let the exception propagate.
    """
    handler_cm = AsyncMock()
    entered = handler_cm.__aenter__.return_value
    entered.get_schema.side_effect = ValueError("Test exception")
    connection_server.get_handler = MagicMock(return_value=handler_cm)
    connection_server.send_log = MagicMock()

    async def read_resource_stub(uri, arguments=None):
        if not (arguments and 'connection' in arguments):
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)

        segments = uri.split('/')
        if len(segments) < 3:
            raise ConfigurationError(INVALID_URI_FORMAT_ERROR)

        # URI format: xxx/table_name/schema
        requested_table = segments[-2]
        try:
            async with connection_server.get_handler(arguments['connection']) as handler:
                return await handler.get_schema(requested_table)
        except Exception as exc:
            connection_server.send_log(LOG_LEVEL_ERROR, f"Error in read_resource: {str(exc)}")
            # Propagate so the caller can observe the failure.
            raise

    with pytest.raises(ValueError, match="Test exception"):
        await read_resource_stub("mock://table1/schema", {"connection": "test_conn"})

    connection_server.get_handler.assert_called_once_with("test_conn")
    entered.get_schema.assert_called_once_with("table1")
    connection_server.send_log.assert_called_once_with(LOG_LEVEL_ERROR, "Error in read_resource: Test exception")
@pytest.mark.asyncio
async def test_handle_list_tools(self, connection_server):
    """A list_tools stand-in returns whatever _get_available_tools provides."""
    canned_tools = [types.Tool(name="test-tool", description="Test tool", inputSchema={})]
    connection_server._get_available_tools = MagicMock(return_value=canned_tools)

    async def list_tools_stub():
        return connection_server._get_available_tools()

    listed = await list_tools_stub()

    assert listed == canned_tools
    connection_server._get_available_tools.assert_called_once()
@pytest.mark.asyncio
async def test_handle_call_tool_no_connection(self, connection_server):
    """A call_tool stand-in raises ConfigurationError when 'connection' is absent."""
    table_tool_names = (
        "dbutils-describe-table",
        "dbutils-get-ddl",
        "dbutils-list-indexes",
        "dbutils-get-stats",
        "dbutils-list-constraints",
    )

    async def call_tool_stub(name, arguments):
        if "connection" not in arguments:
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)
        connection = arguments["connection"]

        # Tool names are mutually exclusive, so each branch returns directly.
        if name == "dbutils-list-tables":
            return await connection_server._handle_list_tables(connection)
        if name == "dbutils-run-query":
            return await connection_server._handle_run_query(connection, arguments.get("sql", "").strip())
        if name in table_tool_names:
            return await connection_server._handle_table_tools(name, connection, arguments.get("table", "").strip())
        if name == "dbutils-explain-query":
            return await connection_server._handle_explain_query(connection, arguments.get("sql", "").strip())
        if name == "dbutils-get-performance":
            return await connection_server._handle_performance(connection)
        if name == "dbutils-analyze-query":
            return await connection_server._handle_analyze_query(connection, arguments.get("sql", "").strip())
        raise ConfigurationError(f"Unknown tool: {name}")

    # A known tool name, but the arguments carry no connection.
    with pytest.raises(ConfigurationError, match=CONNECTION_NAME_REQUIRED_ERROR):
        await call_tool_stub("dbutils-run-query", {})
@pytest.mark.asyncio
async def test_handle_call_tool_unknown_tool(self, connection_server):
    """A call_tool stand-in raises ConfigurationError for an unrecognised tool name."""
    table_tool_names = (
        "dbutils-describe-table",
        "dbutils-get-ddl",
        "dbutils-list-indexes",
        "dbutils-get-stats",
        "dbutils-list-constraints",
    )

    def sql_from(arguments):
        # SQL text defaults to an empty string and is stripped of whitespace.
        return arguments.get("sql", "").strip()

    async def call_tool_stub(name, arguments):
        if "connection" not in arguments:
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)
        connection = arguments["connection"]

        if name == "dbutils-list-tables":
            return await connection_server._handle_list_tables(connection)
        if name == "dbutils-run-query":
            return await connection_server._handle_run_query(connection, sql_from(arguments))
        if name in table_tool_names:
            table = arguments.get("table", "").strip()
            return await connection_server._handle_table_tools(name, connection, table)
        if name == "dbutils-explain-query":
            return await connection_server._handle_explain_query(connection, sql_from(arguments))
        if name == "dbutils-get-performance":
            return await connection_server._handle_performance(connection)
        if name == "dbutils-analyze-query":
            return await connection_server._handle_analyze_query(connection, sql_from(arguments))
        raise ConfigurationError(f"Unknown tool: {name}")

    # Connection is present, but the tool name matches no branch.
    with pytest.raises(ConfigurationError, match="Unknown tool: unknown-tool"):
        await call_tool_stub("unknown-tool", {"connection": "test_conn"})
@pytest.mark.asyncio
async def test_handle_call_tool_list_tables(self, connection_server):
    """A call_tool stand-in forwards dbutils-list-tables to _handle_list_tables."""
    canned_reply = [types.TextContent(type="text", text="Table list")]
    connection_server._handle_list_tables = AsyncMock(return_value=canned_reply)

    async def call_tool_stub(name, arguments):
        if "connection" not in arguments:
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)
        if name != "dbutils-list-tables":
            raise ConfigurationError(f"Unknown tool: {name}")
        return await connection_server._handle_list_tables(arguments["connection"])

    reply = await call_tool_stub("dbutils-list-tables", {"connection": "test_conn"})

    assert reply == canned_reply
    connection_server._handle_list_tables.assert_called_once_with("test_conn")
@pytest.mark.asyncio
async def test_handle_call_tool_run_query(self, connection_server):
    """A call_tool stand-in forwards dbutils-run-query to _handle_run_query."""
    canned_reply = [types.TextContent(type="text", text="Query result")]
    connection_server._handle_run_query = AsyncMock(return_value=canned_reply)

    async def call_tool_stub(name, arguments):
        if "connection" not in arguments:
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)
        if name != "dbutils-run-query":
            raise ConfigurationError(f"Unknown tool: {name}")
        query_text = arguments.get("sql", "").strip()
        return await connection_server._handle_run_query(arguments["connection"], query_text)

    reply = await call_tool_stub("dbutils-run-query", {"connection": "test_conn", "sql": "SELECT 1"})

    assert reply == canned_reply
    connection_server._handle_run_query.assert_called_once_with("test_conn", "SELECT 1")
@pytest.mark.asyncio
async def test_handle_call_tool_run_query_exception(self, connection_server):
    """A call_tool stand-in logs and re-raises a dbutils-run-query failure.

    _handle_run_query is replaced by an AsyncMock raising ValueError. The
    test expects the exception to propagate and expects exactly one send_log
    call carrying the error level and the formatted message.
    """
    connection_server._handle_run_query = AsyncMock(side_effect=ValueError("Test exception"))

    async def mock_handle_call_tool(name, arguments):
        if "connection" not in arguments:
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)
        connection = arguments["connection"]
        if name == "dbutils-run-query":
            sql = arguments.get("sql", "").strip()
            try:
                return await connection_server._handle_run_query(connection, sql)
            except Exception as e:
                # Log the error and re-raise
                connection_server.send_log(LOG_LEVEL_ERROR, f"Error in run_query: {str(e)}")
                raise
        else:
            raise ConfigurationError(f"Unknown tool: {name}")

    with pytest.raises(ValueError, match="Test exception"):
        await mock_handle_call_tool("dbutils-run-query", {"connection": "test_conn", "sql": "SELECT 1"})

    connection_server._handle_run_query.assert_called_once_with("test_conn", "SELECT 1")
    # Pin level and message, not just the call count, so a wrong log level
    # or message cannot slip through.
    connection_server.send_log.assert_called_once_with(LOG_LEVEL_ERROR, "Error in run_query: Test exception")
@pytest.mark.asyncio
async def test_handle_call_tool_table_tools(self, connection_server):
    """A call_tool stand-in forwards each table tool to _handle_table_tools.

    Every one of the five table tool names is invoked in turn and must be
    passed through together with the connection and table name.
    """
    canned_reply = [types.TextContent(type="text", text="Table info")]
    connection_server._handle_table_tools = AsyncMock(return_value=canned_reply)

    table_tool_names = [
        "dbutils-describe-table",
        "dbutils-get-ddl",
        "dbutils-list-indexes",
        "dbutils-get-stats",
        "dbutils-list-constraints",
    ]

    async def call_tool_stub(name, arguments):
        if "connection" not in arguments:
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)
        if name not in table_tool_names:
            raise ConfigurationError(f"Unknown tool: {name}")
        table = arguments.get("table", "").strip()
        return await connection_server._handle_table_tools(name, arguments["connection"], table)

    for tool_name in table_tool_names:
        reply = await call_tool_stub(tool_name, {"connection": "test_conn", "table": "users"})
        assert reply == canned_reply
        # assert_called_with checks the most recent call only.
        connection_server._handle_table_tools.assert_called_with(tool_name, "test_conn", "users")
@pytest.mark.asyncio
async def test_handle_call_tool_table_tools_exception(self, connection_server):
    """A call_tool stand-in logs and re-raises a table-tool failure.

    _handle_table_tools is replaced by an AsyncMock raising ValueError. The
    test expects the exception to propagate and expects exactly one send_log
    call carrying the error level and the formatted message.
    """
    connection_server._handle_table_tools = AsyncMock(side_effect=ValueError("Test exception"))

    async def mock_handle_call_tool(name, arguments):
        if "connection" not in arguments:
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)
        connection = arguments["connection"]
        if name in ["dbutils-describe-table", "dbutils-get-ddl", "dbutils-list-indexes",
                    "dbutils-get-stats", "dbutils-list-constraints"]:
            table = arguments.get("table", "").strip()
            try:
                return await connection_server._handle_table_tools(name, connection, table)
            except Exception as e:
                # Log the error and re-raise
                connection_server.send_log(LOG_LEVEL_ERROR, f"Error in table_tools: {str(e)}")
                raise
        else:
            raise ConfigurationError(f"Unknown tool: {name}")

    with pytest.raises(ValueError, match="Test exception"):
        await mock_handle_call_tool("dbutils-describe-table", {"connection": "test_conn", "table": "users"})

    connection_server._handle_table_tools.assert_called_once_with("dbutils-describe-table", "test_conn", "users")
    # Pin level and message, not just the call count, so a wrong log level
    # or message cannot slip through.
    connection_server.send_log.assert_called_once_with(LOG_LEVEL_ERROR, "Error in table_tools: Test exception")
@pytest.mark.asyncio
async def test_handle_call_tool_explain_query(self, connection_server):
    """A call_tool stand-in forwards dbutils-explain-query to _handle_explain_query."""
    canned_reply = [types.TextContent(type="text", text="Query explanation")]
    connection_server._handle_explain_query = AsyncMock(return_value=canned_reply)

    async def call_tool_stub(name, arguments):
        if "connection" not in arguments:
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)
        if name != "dbutils-explain-query":
            raise ConfigurationError(f"Unknown tool: {name}")
        query_text = arguments.get("sql", "").strip()
        return await connection_server._handle_explain_query(arguments["connection"], query_text)

    reply = await call_tool_stub("dbutils-explain-query", {"connection": "test_conn", "sql": "SELECT 1"})

    assert reply == canned_reply
    connection_server._handle_explain_query.assert_called_once_with("test_conn", "SELECT 1")
@pytest.mark.asyncio
async def test_handle_call_tool_explain_query_exception(self, connection_server):
    """A call_tool stand-in logs and re-raises a dbutils-explain-query failure.

    _handle_explain_query is replaced by an AsyncMock raising ValueError.
    The test expects the exception to propagate and expects exactly one
    send_log call carrying the error level and the formatted message.
    """
    connection_server._handle_explain_query = AsyncMock(side_effect=ValueError("Test exception"))

    async def mock_handle_call_tool(name, arguments):
        if "connection" not in arguments:
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)
        connection = arguments["connection"]
        if name == "dbutils-explain-query":
            sql = arguments.get("sql", "").strip()
            try:
                return await connection_server._handle_explain_query(connection, sql)
            except Exception as e:
                # Log the error and re-raise
                connection_server.send_log(LOG_LEVEL_ERROR, f"Error in explain_query: {str(e)}")
                raise
        else:
            raise ConfigurationError(f"Unknown tool: {name}")

    with pytest.raises(ValueError, match="Test exception"):
        await mock_handle_call_tool("dbutils-explain-query", {"connection": "test_conn", "sql": "SELECT 1"})

    connection_server._handle_explain_query.assert_called_once_with("test_conn", "SELECT 1")
    # Pin level and message, not just the call count, so a wrong log level
    # or message cannot slip through.
    connection_server.send_log.assert_called_once_with(LOG_LEVEL_ERROR, "Error in explain_query: Test exception")
@pytest.mark.asyncio
async def test_handle_call_tool_get_performance(self, connection_server):
    """A call_tool stand-in forwards dbutils-get-performance to _handle_performance."""
    canned_reply = [types.TextContent(type="text", text="Performance info")]
    connection_server._handle_performance = AsyncMock(return_value=canned_reply)

    async def call_tool_stub(name, arguments):
        if "connection" not in arguments:
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)
        if name != "dbutils-get-performance":
            raise ConfigurationError(f"Unknown tool: {name}")
        return await connection_server._handle_performance(arguments["connection"])

    reply = await call_tool_stub("dbutils-get-performance", {"connection": "test_conn"})

    assert reply == canned_reply
    connection_server._handle_performance.assert_called_once_with("test_conn")
@pytest.mark.asyncio
async def test_handle_call_tool_get_performance_exception(self, connection_server):
    """A call_tool stand-in logs and re-raises a dbutils-get-performance failure.

    _handle_performance is replaced by an AsyncMock raising ValueError. The
    test expects the exception to propagate and expects exactly one send_log
    call carrying the error level and the formatted message.
    """
    connection_server._handle_performance = AsyncMock(side_effect=ValueError("Test exception"))

    async def mock_handle_call_tool(name, arguments):
        if "connection" not in arguments:
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)
        connection = arguments["connection"]
        if name == "dbutils-get-performance":
            try:
                return await connection_server._handle_performance(connection)
            except Exception as e:
                # Log the error and re-raise
                connection_server.send_log(LOG_LEVEL_ERROR, f"Error in get_performance: {str(e)}")
                raise
        else:
            raise ConfigurationError(f"Unknown tool: {name}")

    with pytest.raises(ValueError, match="Test exception"):
        await mock_handle_call_tool("dbutils-get-performance", {"connection": "test_conn"})

    connection_server._handle_performance.assert_called_once_with("test_conn")
    # Pin level and message, not just the call count, so a wrong log level
    # or message cannot slip through.
    connection_server.send_log.assert_called_once_with(LOG_LEVEL_ERROR, "Error in get_performance: Test exception")
@pytest.mark.asyncio
async def test_handle_call_tool_analyze_query(self, connection_server):
    """A call_tool stand-in forwards dbutils-analyze-query to _handle_analyze_query."""
    canned_reply = [types.TextContent(type="text", text="Query analysis")]
    connection_server._handle_analyze_query = AsyncMock(return_value=canned_reply)

    async def call_tool_stub(name, arguments):
        if "connection" not in arguments:
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)
        if name != "dbutils-analyze-query":
            raise ConfigurationError(f"Unknown tool: {name}")
        query_text = arguments.get("sql", "").strip()
        return await connection_server._handle_analyze_query(arguments["connection"], query_text)

    reply = await call_tool_stub("dbutils-analyze-query", {"connection": "test_conn", "sql": "SELECT 1"})

    assert reply == canned_reply
    connection_server._handle_analyze_query.assert_called_once_with("test_conn", "SELECT 1")
@pytest.mark.asyncio
async def test_handle_call_tool_analyze_query_exception(self, connection_server):
    """A call_tool stand-in logs and re-raises a dbutils-analyze-query failure.

    _handle_analyze_query is replaced by an AsyncMock raising ValueError.
    The test expects the exception to propagate and expects exactly one
    send_log call carrying the error level and the formatted message.
    """
    connection_server._handle_analyze_query = AsyncMock(side_effect=ValueError("Test exception"))

    async def mock_handle_call_tool(name, arguments):
        if "connection" not in arguments:
            raise ConfigurationError(CONNECTION_NAME_REQUIRED_ERROR)
        connection = arguments["connection"]
        if name == "dbutils-analyze-query":
            sql = arguments.get("sql", "").strip()
            try:
                return await connection_server._handle_analyze_query(connection, sql)
            except Exception as e:
                # Log the error and re-raise
                connection_server.send_log(LOG_LEVEL_ERROR, f"Error in analyze_query: {str(e)}")
                raise
        else:
            raise ConfigurationError(f"Unknown tool: {name}")

    with pytest.raises(ValueError, match="Test exception"):
        await mock_handle_call_tool("dbutils-analyze-query", {"connection": "test_conn", "sql": "SELECT 1"})

    connection_server._handle_analyze_query.assert_called_once_with("test_conn", "SELECT 1")
    # Pin level and message, not just the call count, so a wrong log level
    # or message cannot slip through.
    connection_server.send_log.assert_called_once_with(LOG_LEVEL_ERROR, "Error in analyze_query: Test exception")
def test_setup_handlers(self, connection_server):
    """_setup_handlers invokes all four server decorator factories.

    The list_resources, read_resource, list_tools and call_tool attributes
    of the underlying server are swapped for mocks. Each mock must have been
    called; if one was called with the handler as a positional argument, the
    handler's name is expected to be ``handle_<decorator name>``.
    """
    decorator_names = ("list_resources", "read_resource", "list_tools", "call_tool")

    # Each stand-in factory hands back a decorator that returns the function unchanged.
    stand_ins = {name: MagicMock(return_value=lambda f: f) for name in decorator_names}

    # Remember every original before replacing any of them.
    originals = {name: getattr(connection_server.server, name) for name in decorator_names}
    for name in decorator_names:
        setattr(connection_server.server, name, stand_ins[name])

    try:
        connection_server._setup_handlers()

        for name in decorator_names:
            assert stand_ins[name].called, f"{name} decorator should have been called"

        for name in decorator_names:
            last_call = stand_ins[name].call_args
            # Handler details can only be checked when it was passed positionally.
            if last_call and last_call.args:
                handler = last_call.args[0]
                assert callable(handler)
                assert handler.__name__ == f"handle_{name}"
    finally:
        for name in decorator_names:
            setattr(connection_server.server, name, originals[name])
@pytest.mark.asyncio
async def test_setup_handlers_list_resources_exception(self, connection_server):
    """Test that handle_list_resources propagates errors raised by the handler.

    A stand-in ``list_resources`` decorator captures the function that
    ``_setup_handlers`` registers. The captured function is then awaited
    with ``get_handler`` replaced by an async context manager that yields a
    mock whose ``get_tables`` raises ``ValueError``; the test expects that
    error to reach the caller and ``get_tables`` to have been called once.
    """
    captured_list_resources = None

    def mock_list_resources_decorator():
        # Mirrors the decorator-factory shape: calling it returns the
        # decorator that receives the handler function.
        def wrapper(func):
            nonlocal captured_list_resources
            captured_list_resources = func
            return func
        return wrapper

    # Only the registration step needs the stand-in decorator. patch.object
    # undoes the replacement on exit, even if _setup_handlers raises.
    with patch.object(
        connection_server.server, "list_resources", mock_list_resources_decorator
    ):
        connection_server._setup_handlers()

    # Validate we captured the handler
    assert captured_list_resources is not None

    # Handler whose table listing fails
    mock_handler = AsyncMock()
    mock_handler.get_tables = AsyncMock(side_effect=ValueError("DB error"))

    @asynccontextmanager
    async def mock_get_handler(connection_name):
        yield mock_handler

    with patch.object(connection_server, "get_handler", mock_get_handler):
        # The function should raise the exception
        with pytest.raises(ValueError, match="DB error"):
            await captured_list_resources({"connection": "test_conn"})

    # Verify mock_handler's get_tables was called
    mock_handler.get_tables.assert_called_once()
@pytest.mark.asyncio
async def test_setup_handlers_read_resource_exception(self, connection_server):
    """Test that handle_read_resource propagates errors raised by the handler.

    A stand-in ``read_resource`` decorator captures the function that
    ``_setup_handlers`` registers. The captured function is then awaited
    for the URI ``mock://table1/schema`` with ``get_handler`` replaced by an
    async context manager that yields a mock whose ``get_schema`` raises
    ``ValueError``; the test expects that error to reach the caller and
    ``get_schema`` to have been called once with ``"table1"``.
    """
    captured_read_resource = None

    def mock_read_resource_decorator():
        # Mirrors the decorator-factory shape: calling it returns the
        # decorator that receives the handler function.
        def wrapper(func):
            nonlocal captured_read_resource
            captured_read_resource = func
            return func
        return wrapper

    # Only the registration step needs the stand-in decorator. patch.object
    # undoes the replacement on exit, even if _setup_handlers raises.
    with patch.object(
        connection_server.server, "read_resource", mock_read_resource_decorator
    ):
        connection_server._setup_handlers()

    # Validate we captured the handler
    assert captured_read_resource is not None

    # Handler whose schema lookup fails
    mock_handler = AsyncMock()
    mock_handler.get_schema = AsyncMock(side_effect=ValueError("Schema error"))

    @asynccontextmanager
    async def mock_get_handler(connection_name):
        yield mock_handler

    with patch.object(connection_server, "get_handler", mock_get_handler):
        # The function should raise the exception
        with pytest.raises(ValueError, match="Schema error"):
            await captured_read_resource("mock://table1/schema", {"connection": "test_conn"})

    # Verify mock_handler's get_schema was called
    mock_handler.get_schema.assert_called_once_with("table1")
class TestConnectionServerRun:
@pytest.mark.asyncio
@patch("mcp.server.stdio.stdio_server")
async def test_run(self, mock_stdio_server, connection_server):
    """Test that run() enters the stdio context and calls server.run once.

    ``stdio_server`` is patched to return an async context manager whose
    ``__aenter__`` yields a pair of mock streams, and ``server.run`` is
    replaced with an ``AsyncMock``.
    """
    # Setup mocks
    mock_stdin = AsyncMock()
    mock_stdout = AsyncMock()
    mock_context_manager = AsyncMock()
    mock_context_manager.__aenter__.return_value = [mock_stdin, mock_stdout]
    mock_stdio_server.return_value = mock_context_manager

    # Mock the server.run method to avoid validation errors
    connection_server.server.run = AsyncMock()

    # Call the run method
    await connection_server.run()

    # Verify the stdio context was entered and server.run was called
    mock_stdio_server.assert_called_once()
    mock_context_manager.__aenter__.assert_called_once()
    # Require exactly one call; a bare `.called` check would also pass if
    # the server were started more than once.
    connection_server.server.run.assert_called_once()
@pytest.mark.asyncio
@patch("mcp.server.stdio.stdio_server")
async def test_run_with_exception(self, mock_stdio_server, connection_server):
    """Test that a failure inside server.run surfaces from run() and is logged.

    NOTE(review): the error log asserted at the end is emitted by the
    wrapper defined inside this test; whether the real run() logs on its
    own is not visible here -- confirm against the implementation.
    """
    # stdio_server() returns an async context manager yielding two streams
    stdio_cm = AsyncMock()
    stdio_cm.__aenter__.return_value = [AsyncMock(), AsyncMock()]
    mock_stdio_server.return_value = stdio_cm

    # Make the underlying server raise when it is started
    failing_server_run = AsyncMock(side_effect=ValueError("Test exception"))
    connection_server.server.run = failing_server_run

    # Wrap run() so that any exception is logged and then re-raised
    unwrapped_run = connection_server.run

    async def run_and_log():
        try:
            await unwrapped_run()
        except Exception as exc:
            connection_server.send_log(LOG_LEVEL_ERROR, f"Error in run: {str(exc)}")
            raise

    connection_server.run = run_and_log

    # The exception must propagate out of the wrapped run()
    with pytest.raises(ValueError, match="Test exception"):
        await connection_server.run()

    # The stdio context was entered, server.run was called, the error logged
    mock_stdio_server.assert_called_once()
    stdio_cm.__aenter__.assert_called_once()
    failing_server_run.assert_called_once()
    connection_server.send_log.assert_called_with(LOG_LEVEL_ERROR, "Error in run: Test exception")