"""
Tests for the MCP server module.
"""
import os
import tempfile
from pathlib import Path
import pytest
from llama_index.core import (
Document
)
from server.mcp_app import Retriever
@pytest.fixture
def temp_dir():
    """Yield the path of a scratch directory that is removed after the test."""
    scratch = tempfile.TemporaryDirectory()
    try:
        yield scratch.name
    finally:
        # Same cleanup the context-manager form performs on exit.
        scratch.cleanup()
@pytest.fixture
def sample_documents():
    """Return two llama_index ``Document`` objects, each holding several sentences."""
    texts = (
        "This is a test document. It contains multiple sentences. Each sentence is important.",
        "Another test document. With different content. For testing purposes.",
    )
    return [Document(text=body) for body in texts]
@pytest.fixture
def mock_index(temp_dir, sample_documents):
    """Build an on-disk index from ``sample_documents`` for the server tests.

    Each sample document is written to ``<temp_dir>/docs/doc_<i>.txt``, then the
    project ``Indexer`` is run over that directory with its output written to
    ``<temp_dir>/index``.

    Args:
        temp_dir: Path of the per-test scratch directory.
        sample_documents: Documents whose ``text`` is written to the input files.

    Returns:
        str: Path of the index output directory.
    """
    # Imported inside the fixture, so the indexer package is only loaded when a
    # test actually requests this fixture.
    from indexer.index import Indexer

    doc_dir = Path(temp_dir) / "docs"
    index_dir = Path(temp_dir) / "index"
    doc_dir.mkdir()

    for i, doc in enumerate(sample_documents):
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        (doc_dir / f"doc_{i}.txt").write_text(doc.text, encoding="utf-8")

    indexer = Indexer(
        input_path=str(doc_dir),
        output_path=str(index_dir),
    )
    indexer.build_index()
    return str(index_dir)
def test_retriever_model():
    """Check that ``Retriever`` accepts a question and rejects a missing one."""
    question = "What is it about ?"
    model = Retriever(question=question)
    assert model.question == question

    # Constructing without the required ``question`` field must fail validation.
    with pytest.raises(ValueError):
        Retriever()
@pytest.mark.asyncio
async def test_handle_list_tools():
    """``handle_list_tools`` should advertise exactly one ``_retrieve`` tool.

    The expected description comes from ``MCP_DESCRIPTION`` in the environment,
    falling back to ``'RAG retriever'``.
    """
    from server.mcp_app import handle_list_tools

    tools = await handle_list_tools()
    expected_description = os.getenv('MCP_DESCRIPTION', 'RAG retriever')

    assert len(tools) == 1
    tool = tools[0]
    assert tool.name == "_retrieve"
    assert tool.description == expected_description
@pytest.mark.asyncio
async def test_handle_call_tool(mock_index):
    """Exercise ``handle_call_tool`` with one valid call and two invalid ones.

    NOTE(review): ``mock_index`` builds an index in a temporary directory, but
    the path it returns is never passed to ``handle_call_tool`` here. Presumably
    the server locates its index some other way — TODO confirm this test
    actually queries the fixture's index.
    """
    from server.mcp_app import handle_call_tool

    # A well-formed call to the known tool yields a single text item.
    response = await handle_call_tool("_retrieve", {"question": "What is it about ?"})
    assert len(response) == 1
    assert response[0].type == "text"

    # An unknown tool name and a ``None`` argument mapping are both rejected.
    invalid_calls = (
        ("unknown_tool", {}),
        ("_retrieve", None),
    )
    for tool_name, arguments in invalid_calls:
        with pytest.raises(ValueError):
            await handle_call_tool(tool_name, arguments)