mcp-dbutils
by donghao1393
- tests
- integration
"""Integration tests for prompts functionality"""
import asyncio
import tempfile
import anyio
import mcp.types as types
import pytest
import yaml
from mcp import ClientSession
from mcp_dbutils.base import ConnectionServer
from mcp_dbutils.log import create_logger
# Create the logger used by the tests in this module
logger = create_logger("test-prompts", True) # debug=True so that all log messages are shown
@pytest.mark.asyncio
async def test_prompts_capability(sqlite_db, mcp_config):
    """Check that the server's initialization options advertise prompts.

    The ``mcp_config`` fixture is dumped to a temporary YAML file, a
    ``ConnectionServer`` is built from that file, and the capabilities in
    its initialization options are inspected.

    ``sqlite_db`` is not referenced in the body; presumably it is requested
    so the database the config refers to exists -- confirm against conftest.
    """
    with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml') as config_file:
        # Write the config to disk and flush so it can be read back by path.
        yaml.dump(mcp_config, config_file)
        config_file.flush()
        connection_server = ConnectionServer(config_path=config_file.name)

        # The capabilities object must exist and must carry a prompts entry.
        capabilities = connection_server.server.create_initialization_options().capabilities
        assert capabilities is not None
        assert capabilities.prompts is not None
@pytest.mark.asyncio
async def test_list_prompts(sqlite_db, mcp_config):
    """Test that list_prompts returns an empty list.

    The ``mcp_config`` fixture is dumped to a temporary YAML file and a
    ``ConnectionServer`` is built from it. The server is run in a background
    task over two in-memory object streams, a ``ClientSession`` is connected
    to the other ends, and ``list_prompts()`` is called with a 3 second
    timeout. The response must contain an empty ``prompts`` list.

    ``sqlite_db`` is not referenced in the body; presumably it is requested
    so the database the config refers to exists -- confirm against conftest.

    Raises:
        RuntimeError: if an ``asyncio.TimeoutError`` escapes the client
            interaction (chained to the original timeout).
    """
    with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml') as tmp:
        yaml.dump(mcp_config, tmp)
        tmp.flush()
        server = ConnectionServer(config_path=tmp.name)

        # Create bidirectional streams with proper types
        # Client -> Server stream (server receives messages and exceptions)
        client_to_server_send, client_to_server_recv = anyio.create_memory_object_stream[types.JSONRPCMessage | Exception](10)
        # Server -> Client stream (client receives messages)
        server_to_client_send, server_to_client_recv = anyio.create_memory_object_stream[types.JSONRPCMessage](10)

        # Start server in background task with proper stream connections
        server_task = asyncio.create_task(
            server.server.run(
                client_to_server_recv,  # Server reads from client
                server_to_client_send,  # Server writes to client
                server.server.create_initialization_options(),
                raise_exceptions=True,  # makes errors easier to debug
            )
        )

        try:
            try:
                # Initialize client with proper stream connections
                logger("debug", "Creating client session")
                client = ClientSession(
                    server_to_client_recv,  # Client reads from server
                    client_to_server_send,  # Client writes to server
                )
                async with client:
                    # Initialize the session
                    logger("debug", "Initializing client session")
                    init_response = await client.initialize()
                    logger("debug", f"Client session initialized with response: {init_response}")

                    # Test prompts list with timeout
                    logger("debug", "Sending prompts/list request")
                    try:
                        # Use the SDK's list_prompts method
                        response = await asyncio.wait_for(client.list_prompts(), timeout=3.0)
                        logger("debug", f"Got response: {response}")

                        # Verify the response
                        assert isinstance(response.prompts, list)
                        assert len(response.prompts) == 0
                        logger("debug", "Test completed successfully")
                    except asyncio.TimeoutError:
                        logger("error", "Request timed out after 3 seconds")
                        raise
                    except Exception as e:
                        logger("error", f"Request failed: {e}")
                        logger("error", f"Request error type: {type(e)}")
                        raise
            except Exception as e:
                logger("error", f"Test failed with error: {e}")
                logger("error", f"Error type: {type(e)}")
                raise
        except asyncio.TimeoutError as exc:
            logger("error", "Test timed out")
            # Chain explicitly so the original timeout stays visible as the cause.
            raise RuntimeError("Test timed out after 3 seconds") from exc
        finally:
            # Clean up
            logger("debug", "Starting cleanup")
            server_task.cancel()
            try:
                await server_task
            except asyncio.CancelledError:
                logger("debug", "Server task cancelled")
            except Exception as e:
                # Best-effort cleanup: log and continue so the streams still get closed.
                logger("error", f"Error during server cleanup: {e}")

            # Close streams. Each one is closed independently so that a
            # failure on one stream does not leave the remaining ones open.
            all_closed = True
            for stream in (
                client_to_server_send,
                client_to_server_recv,
                server_to_client_send,
                server_to_client_recv,
            ):
                try:
                    await stream.aclose()
                except Exception as e:
                    all_closed = False
                    logger("error", f"Error during stream cleanup: {e}")
            if all_closed:
                logger("debug", "All streams closed")