Deskaid
by ezyang
- codemcp
#!/usr/bin/env python3
import asyncio
import os
import subprocess
import sys
import tempfile
import unittest
from contextlib import asynccontextmanager
from typing import Any, List, Union
from expecttest import TestCase
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
class MCPEndToEndTestCase(TestCase, unittest.IsolatedAsyncioTestCase):
    """Base class for end-to-end tests of codemcp using MCP client.

    Each test gets a fresh temporary directory containing a git repository
    with a single initial commit (``README.md`` plus an empty
    ``codemcp.toml``), and a copy of the process environment with defaults
    filled in for git (identity, dates, pager, locale, ...).
    """

    async def asyncSetUp(self):
        """Async setup method to prepare the test environment."""
        # Create a temporary directory for testing
        self.temp_dir = tempfile.TemporaryDirectory()
        self.testing_time = "1112911993"  # Fixed timestamp for git

        # Start from the ambient environment; every value below is applied
        # with setdefault, so a variable already present in the ambient
        # environment wins over the default listed here.
        self.env = os.environ.copy()
        git_env_defaults = {
            # Reproducible, non-interactive git behavior
            "GIT_TERMINAL_PROMPT": "0",
            "EDITOR": ":",
            "GIT_MERGE_AUTOEDIT": "no",
            "LANG": "C",
            "LC_ALL": "C",
            "PAGER": "cat",
            "TZ": "UTC",
            "TERM": "dumb",
            # For deterministic commit identity and times
            "GIT_AUTHOR_EMAIL": "author@example.com",
            "GIT_AUTHOR_NAME": "A U Thor",
            "GIT_COMMITTER_EMAIL": "committer@example.com",
            "GIT_COMMITTER_NAME": "C O Mitter",
            "GIT_COMMITTER_DATE": f"{self.testing_time} -0700",
            "GIT_AUTHOR_DATE": f"{self.testing_time} -0700",
        }
        for name, value in git_env_defaults.items():
            self.env.setdefault(name, value)

        # Initialize a git repository in the temp directory
        self.init_git_repo()

    async def asyncTearDown(self):
        """Async teardown to clean up after the test."""
        self.temp_dir.cleanup()

    def init_git_repo(self):
        """Initialize a git repository for testing.

        Runs ``git init -b main`` in the temp directory, writes ``README.md``
        and an empty ``codemcp.toml``, then stages and commits both files.

        Raises:
            subprocess.CalledProcessError: If any git command exits non-zero.
        """
        repo_dir = self.temp_dir.name

        subprocess.run(
            ["git", "init", "-b", "main"],
            cwd=repo_dir,
            env=self.env,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )

        # Create initial commit
        readme_path = os.path.join(repo_dir, "README.md")
        with open(readme_path, "w", encoding="utf-8") as f:
            f.write("# Test Repository\n")

        # Create a codemcp.toml file in the repo root (required for permission checks)
        codemcp_toml_path = os.path.join(repo_dir, "codemcp.toml")
        with open(codemcp_toml_path, "w", encoding="utf-8") as f:
            f.write("")

        subprocess.run(
            ["git", "add", "README.md", "codemcp.toml"],
            cwd=repo_dir,
            env=self.env,
            check=True,
        )
        subprocess.run(
            ["git", "commit", "-m", "Initial commit"],
            cwd=repo_dir,
            env=self.env,
            check=True,
        )

    def normalize_path(self, text):
        """Normalize temporary directory paths in output text.

        Args:
            text: A string, an object exposing ``.content`` (e.g. a tool call
                result), or a list whose items expose ``.text``.

        Returns:
            For a string: the string with the temp directory path replaced by
            ``/tmp/test_dir``. For an object with ``.content``: its content,
            processed by the same rules. A non-empty list whose first item has
            ``.text`` is returned unchanged. Anything else is returned as-is.
        """
        if self.temp_dir and self.temp_dir.name:
            # Unwrap result objects (anything exposing .content) to their content
            if hasattr(text, "content"):
                text = text.content

            # Lists of TextContent-like objects are passed through untouched:
            # the paths inside each item's .text are NOT rewritten here.
            # NOTE(review): callers that extract text from this list therefore
            # still see the real temp dir path -- confirm this is intended.
            if isinstance(text, list) and len(text) > 0 and hasattr(text[0], "text"):
                return text

            # Replace the actual temp dir path with a fixed placeholder
            if isinstance(text, str):
                return text.replace(self.temp_dir.name, "/tmp/test_dir")
        return text

    def extract_text_from_result(self, result):
        """Extract text content from various result formats for assertions.

        Args:
            result: The result object (could be string, list of TextContent, etc.)

        Returns:
            str: ``result[0].text`` for a non-empty list whose first item has a
            ``text`` attribute, ``result`` itself for a string, and
            ``str(result)`` for anything else.
        """
        if isinstance(result, list) and len(result) > 0 and hasattr(result[0], "text"):
            return result[0].text
        if isinstance(result, str):
            return result
        return str(result)

    def extract_chat_id_from_text(self, text):
        """Extract chat_id from init_result_text.

        Args:
            text: The text output from InitProject tool

        Returns:
            str: The extracted chat_id

        Raises:
            AssertionError: If chat_id cannot be found in text
        """
        import re

        chat_id_match = re.search(r"chat ID: ([a-zA-Z0-9-]+)", text)
        # Raise explicitly rather than via `assert` so the check survives
        # running Python with -O.
        if chat_id_match is None:
            raise AssertionError("Could not find chat ID in text")
        return chat_id_match.group(1)

    async def call_tool_assert_error(self, session, tool_name, tool_params):
        """Call a tool and assert that it fails (isError=True).

        This is a helper method for the error path of tool calls, which:
        1. Calls the specified tool with the given parameters
        2. Asserts that the result is an error
        3. Returns the extracted text result

        Args:
            session: The client session to use
            tool_name: The name of the tool to call
            tool_params: Dictionary of parameters to pass to the tool

        Returns:
            str: The extracted text content from the result

        Raises:
            AssertionError: If the tool call does not result in an error
        """
        result = await session.call_tool(tool_name, tool_params)

        # Check that the result is an error
        self.assertTrue(
            getattr(result, "isError", False),
            f"Tool call to {tool_name} succeeded, expected to fail",
        )

        # Return the normalized, extracted text result
        normalized_result = self.normalize_path(result)
        return self.extract_text_from_result(normalized_result)

    async def call_tool_assert_success(self, session, tool_name, tool_params):
        """Call a tool and assert that it succeeds (isError=False).

        This is a helper method for the happy path of tool calls, which:
        1. Calls the specified tool with the given parameters
        2. Asserts that the result is not an error
        3. Returns the extracted text result

        Args:
            session: The client session to use
            tool_name: The name of the tool to call
            tool_params: Dictionary of parameters to pass to the tool

        Returns:
            str: The extracted text content from the result

        Raises:
            AssertionError: If the tool call results in an error
        """
        result = await session.call_tool(tool_name, tool_params)

        # Check that the result is not an error
        self.assertFalse(
            getattr(result, "isError", False),
            f"Tool call to {tool_name} failed with error: {self.extract_text_from_result(result)}",
        )

        # Return the normalized, extracted text result
        normalized_result = self.normalize_path(result)
        return self.extract_text_from_result(normalized_result)

    async def get_chat_id(self, session):
        """Initialize project and get chat_id.

        Args:
            session: The client session to use

        Returns:
            str: The chat_id

        Raises:
            AssertionError: If no chat ID can be found in the InitProject output
        """
        # First initialize project to get chat_id
        init_result = await session.call_tool(
            "codemcp",
            {
                "subtool": "InitProject",
                "path": self.temp_dir.name,
                "user_prompt": "Test initialization for get_chat_id",
                "subject_line": "test: initialize for e2e testing",
                "reuse_head_chat_id": False,
            },
        )
        init_result_text = self.extract_text_from_result(init_result)

        # Reuse the shared parser so a missing chat ID produces a clear
        # AssertionError instead of an AttributeError on a None match.
        return self.extract_chat_id_from_text(init_result_text)

    @asynccontextmanager
    async def _unwrap_exception_groups(self):
        """Context manager that unwraps ExceptionGroups with single exceptions.

        Only unwraps if there's exactly one exception at each level; the
        innermost exception reached that way is re-raised with its context
        suppressed. Groups holding several exceptions propagate unchanged.
        """
        try:
            yield
        except ExceptionGroup as eg:
            if len(eg.exceptions) != 1:
                # Multiple exceptions - don't unwrap
                raise
            exc = eg.exceptions[0]
            # Keep descending while the sole member is itself a single-member group
            while isinstance(exc, ExceptionGroup) and len(exc.exceptions) == 1:
                exc = exc.exceptions[0]
            raise exc from None

    @asynccontextmanager
    async def create_client_session(self):
        """Create an MCP client session connected to codemcp server.

        Launches ``python -m codemcp`` (using the current interpreter) over
        stdio with the test environment and temp directory as working
        directory, initializes the session, and yields it.
        """
        # Set up server parameters for the codemcp MCP server
        server_params = StdioServerParameters(
            command=sys.executable,  # Current Python executable
            args=["-m", "codemcp"],  # Module path to codemcp
            env=self.env,
            cwd=self.temp_dir.name,  # Set the working directory to our test directory
        )

        # Each layer is wrapped separately so single-exception groups raised
        # by either the stdio transport or the session are unwrapped.
        async with self._unwrap_exception_groups():
            async with stdio_client(server_params) as (read, write):
                async with self._unwrap_exception_groups():
                    async with ClientSession(read, write) as session:
                        # Initialize the connection
                        await session.initialize()
                        yield session

    async def git_run(
        self,
        args: List[str],
        check: bool = True,
        capture_output: bool = False,
        text: bool = False,
        **kwargs: Any,
    ) -> Union[subprocess.CompletedProcess, str]:
        """Run git command asynchronously with appropriate temp_dir and env settings.

        This helper method simplifies git subprocess calls in e2e tests by:
        1. Automatically using the test's temp_dir as the working directory
        2. Using the test's pre-configured env variables
        3. Supporting async execution and various output capture options

        Args:
            args: List of git command arguments (without 'git' prefix)
            check: If True, raises if the command returns a non-zero exit code
            capture_output: If True, captures stdout and stderr
            text: If True (together with capture_output), return the captured
                stdout decoded with ``bytes.decode()`` and stripped
            **kwargs: Additional keyword arguments to pass to
                ``asyncio.create_subprocess_exec`` (``cwd`` and ``env``
                default to the test's temp dir and environment)

        Returns:
            The stripped stdout string when ``capture_output`` and ``text`` are
            both True and stdout was captured; otherwise a
            ``subprocess.CompletedProcess`` whose stdout/stderr are bytes (or
            None when not captured).

        Raises:
            subprocess.CalledProcessError: If ``check`` is True and git exits
                with a non-zero status.

        Example:
            # Run git add command
            await self.git_run(["add", "file.txt"])

            # Get commit log as string
            log_output = await self.git_run(["log", "--oneline"], capture_output=True, text=True)
        """
        # Always include 'git' as the command
        cmd = ["git", *args]

        # Set defaults for working directory and environment
        kwargs.setdefault("cwd", self.temp_dir.name)
        kwargs.setdefault("env", self.env)

        # Capture output if requested
        if capture_output:
            kwargs.setdefault("stdout", subprocess.PIPE)
            kwargs.setdefault("stderr", subprocess.PIPE)

        # Run the command asynchronously
        proc = await asyncio.create_subprocess_exec(*cmd, **kwargs)
        stdout, stderr = await proc.communicate()

        # Check for error if requested; captured output travels on the exception
        if check and proc.returncode != 0:
            cmd_str = " ".join(cmd)
            raise subprocess.CalledProcessError(
                proc.returncode, cmd_str, output=stdout, stderr=stderr
            )

        # Return the appropriate result type
        if capture_output and text and stdout is not None:
            return stdout.decode().strip()

        # Build a CompletedProcess-like result
        return subprocess.CompletedProcess(
            args=cmd,
            returncode=proc.returncode,
            stdout=stdout,
            stderr=stderr,
        )