"""Helper utilities for calling the MCP server from the DPS Coach UI."""
from __future__ import annotations

import asyncio
import codecs
import contextlib
import io
import json
import os
import sys
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from mcp.client.session import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.types import CallToolResult

from .constants import MCP_SERVER_MODULE, SAFETY_LIMIT_RUNS
# Optional progress hook; when given, it is called with a single status message string.
_StatusHook = Callable[[str], None]
StatusCallback = Optional[_StatusHook]
class MCPClientError(RuntimeError):
    """Signals that a request to the MCP analysis server did not succeed."""
class MCPAnalyzerClient:
    """Synchronous-friendly wrapper around the MCP stdio client.

    By default every call spawns ``python -m MCP_SERVER_MODULE`` and talks to it over
    stdio. With ``inprocess=True`` the ``mcp_server`` module is imported and its tool
    coroutines are run directly instead. All failures surface as ``MCPClientError``.
    """

    DEFAULT_INIT_TIMEOUT = 10.0

    def __init__(
        self,
        python_executable: Path | None = None,
        cwd: Path | None = None,
        *,
        inprocess: bool = False,
        init_timeout: float = DEFAULT_INIT_TIMEOUT,
    ) -> None:
        """Configure how the MCP server is launched.

        Args:
            python_executable: Interpreter used to start the server (default: current one).
            cwd: Working directory for the server (default: parent of this package).
            inprocess: Call tools in this process instead of spawning a server.
            init_timeout: Seconds to wait for the MCP session handshake.
        """
        project_root = Path(__file__).resolve().parents[1]
        self._python_executable = str(python_executable or sys.executable)
        self._cwd = str(cwd or project_root)
        self._inprocess = inprocess
        self._init_timeout = init_timeout

    def analyze(
        self,
        log_dir: str | Path | None,
        limit_runs: Optional[int],
        *,
        status_callback: StatusCallback = None,
    ) -> Dict[str, Any]:
        """Run ``analyze_dps_logs``.

        ``log_dir`` is forwarded only when truthy; ``limit_runs`` falls back to
        ``SAFETY_LIMIT_RUNS`` when ``None`` so the server never scans unbounded history.
        """
        arguments: Dict[str, Any] = {}
        if log_dir:
            arguments["log_dir"] = str(log_dir)
        arguments["limit_runs"] = int(limit_runs) if limit_runs is not None else SAFETY_LIMIT_RUNS
        return self.call_tool("analyze_dps_logs", arguments, status_callback=status_callback)

    def get_events_schema(self, *, status_callback: StatusCallback = None) -> Dict[str, Any]:
        """Return the payload of the ``get_events_schema`` tool."""
        return self.call_tool("get_events_schema", status_callback=status_callback)

    def query(self, sql: str, *, status_callback: StatusCallback = None) -> Dict[str, Any]:
        """Run ``query_dps`` with the given SQL text."""
        return self.call_tool("query_dps", {"sql": sql}, status_callback=status_callback)

    def get_analysis_packet(
        self,
        *,
        run_id: str | int | None = "last",
        last_n_runs: int = 10,
        top_k_skills: int = 10,
        bucket_seconds: int = 5,
        status_callback: StatusCallback = None,
    ) -> Dict[str, Any]:
        """Return the payload of the ``get_analysis_packet`` tool."""
        args = {
            "run_id": run_id,
            "last_n_runs": last_n_runs,
            "top_k_skills": top_k_skills,
            "bucket_seconds": bucket_seconds,
        }
        return self.call_tool("get_analysis_packet", args, status_callback=status_callback)

    def call_tool(
        self,
        tool_name: str,
        arguments: Optional[Dict[str, Any]] = None,
        *,
        status_callback: StatusCallback = None,
    ) -> Dict[str, Any]:
        """Invoke *tool_name* and return its payload as a dict.

        Raises:
            MCPClientError: for any failure; other exceptions are wrapped and chained.
        """
        try:
            if self._inprocess:
                return self._call_tool_inprocess(
                    tool_name, arguments, status_callback=status_callback
                )
            return asyncio.run(
                self._call_tool(tool_name, arguments, status_callback=status_callback)
            )
        except MCPClientError:
            raise
        except Exception as exc:  # pragma: no cover - surfaced to UI
            raise MCPClientError(str(exc)) from exc

    async def _call_tool(
        self,
        tool_name: str,
        arguments: Optional[Dict[str, Any]] = None,
        *,
        status_callback: StatusCallback = None,
    ) -> Dict[str, Any]:
        """Spawn the MCP server over stdio, call one tool and return its payload.

        Server stderr is written to a temporary file that is tailed every 0.2 s;
        complete non-blank lines go to *status_callback*, and everything collected
        is appended to error messages.

        Raises:
            MCPClientError: if the session handshake times out or the call fails.
        """
        with contextlib.ExitStack() as cleanup:
            stderr_file = tempfile.NamedTemporaryFile(mode="w+b", delete=False)
            stderr_path = stderr_file.name
            # ExitStack unwinds LIFO: close the reader, then the writer, then delete the
            # file -- on every exit path, including failures before the server starts.
            cleanup.callback(self._unlink_quietly, stderr_path)
            cleanup.enter_context(stderr_file)
            progress_reader = cleanup.enter_context(open(stderr_path, "rb"))
            stderr_chunks: List[str] = []
            stderr_state = self._new_stderr_state()

            def drain() -> None:
                self._drain_stderr(progress_reader, stderr_state, stderr_chunks, status_callback)

            def failure(message: str) -> MCPClientError:
                # Pick up whatever the server printed right before failing.
                drain()
                return MCPClientError(
                    self._format_server_error(message, self._build_stderr_text(stderr_chunks))
                )

            async def tail_stderr() -> None:
                while True:
                    drain()
                    await asyncio.sleep(0.2)

            tail_task = asyncio.create_task(tail_stderr())
            try:
                server = StdioServerParameters(
                    command=self._python_executable,
                    args=["-m", MCP_SERVER_MODULE],
                    cwd=self._cwd,
                )
                self._notify_status(status_callback, "Starting MCP server…")
                try:
                    async with stdio_client(server, errlog=stderr_file) as (
                        read_stream,
                        write_stream,
                    ):
                        async with ClientSession(read_stream, write_stream) as session:
                            self._notify_status(status_callback, "Initializing session…")
                            try:
                                await asyncio.wait_for(
                                    session.initialize(), timeout=self._init_timeout
                                )
                            except asyncio.TimeoutError as exc:
                                raise failure(
                                    "Timed out while initializing MCP session."
                                ) from exc
                            self._notify_status(status_callback, f"Calling {tool_name}…")
                            result = await session.call_tool(tool_name, arguments or None)
                except MCPClientError:
                    raise
                except Exception as exc:
                    raise failure(f"MCP call failed: {exc}") from exc
            finally:
                tail_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await tail_task
                drain()
        return self._extract_payload(result)

    def _call_tool_inprocess(
        self,
        tool_name: str,
        arguments: Optional[Dict[str, Any]],
        *,
        status_callback: StatusCallback = None,
    ) -> Dict[str, Any]:
        """Run a tool from the local ``mcp_server`` module without a subprocess.

        Raises:
            MCPClientError: if the module cannot be imported, the tool name is not
                supported here, or running the tool raises.
        """
        try:
            import mcp_server  # pylint: disable=import-error
        except Exception as exc:  # pragma: no cover - surfaced to UI
            raise MCPClientError(f"Unable to import mcp_server: {exc}") from exc
        args = arguments or {}
        self._notify_status(status_callback, f"Calling {tool_name}…")
        # Lambdas keep attribute lookup and coroutine creation lazy (only the chosen tool).
        dispatch: Dict[str, Callable[[], Any]] = {
            "analyze_dps_logs": lambda: mcp_server.analyze_dps_logs(**args),
            "get_events_schema": lambda: mcp_server.get_events_schema(),
            "query_dps": lambda: mcp_server.query_dps(**args),
            "get_analysis_packet": lambda: mcp_server.get_analysis_packet(**args),
        }
        make_coro = dispatch.get(tool_name)
        if make_coro is None:  # pragma: no cover - defensive
            raise MCPClientError(f"Unsupported tool {tool_name} in in-process mode")
        coro = make_coro()
        try:
            return asyncio.run(coro)
        except Exception as exc:  # pragma: no cover - surfaced to UI
            raise MCPClientError(str(exc)) from exc

    @staticmethod
    def _new_decoder() -> io.IncrementalNewlineDecoder:
        """Return a UTF-8 decoder that survives chunk boundaries and normalizes newlines."""
        # Incremental decoding holds back a multi-byte sequence (or a "\r" that may be
        # followed by "\n") split across two reads instead of replacing it with U+FFFD.
        return io.IncrementalNewlineDecoder(
            codecs.getincrementaldecoder("utf-8")(errors="replace"), translate=True
        )

    @staticmethod
    def _new_stderr_state() -> Dict[str, Any]:
        """Return fresh bookkeeping for ``_drain_stderr``."""
        return {"pos": 0, "remainder": "", "decoder": MCPAnalyzerClient._new_decoder()}

    @staticmethod
    def _unlink_quietly(path: str) -> None:
        """Delete *path*, ignoring OS errors (e.g. already removed or still locked)."""
        with contextlib.suppress(OSError):
            os.unlink(path)

    @staticmethod
    def _drain_stderr(
        reader,
        state: Dict[str, Any],
        chunk_store: List[str],
        status_callback: StatusCallback,
    ) -> None:
        """Read stderr written since the last call and report completed lines.

        Args:
            reader: Seekable handle on the stderr file (binary; text is also accepted).
            state: Bookkeeping from ``_new_stderr_state``: read offset, unfinished
                trailing line and decoder (created lazily if missing).
            chunk_store: Receives every decoded chunk; joined, it is the full stderr text.
            status_callback: Called once per complete, non-blank, stripped line.
        """
        reader.seek(state["pos"])
        data = reader.read()
        if not data:
            return
        state["pos"] = reader.tell()
        if isinstance(data, str):
            text = data
        else:
            decoder = state.get("decoder")
            if decoder is None:
                decoder = state["decoder"] = MCPAnalyzerClient._new_decoder()
            text = decoder.decode(data)
        if not text:
            # Only part of a multi-byte character has arrived so far.
            return
        chunk_store.append(text)
        lines = (state["remainder"] + text).splitlines(keepends=True)
        # The last piece is an unfinished line unless it carries a line break.
        if lines and lines[-1].splitlines()[0] == lines[-1]:
            state["remainder"] = lines.pop()
        else:
            state["remainder"] = ""
        for line in lines:
            stripped = line.strip()
            if stripped:
                MCPAnalyzerClient._notify_status(status_callback, stripped)

    @staticmethod
    def _build_stderr_text(chunks: List[str]) -> str:
        """Return all captured stderr text.

        The unfinished trailing line is already part of *chunks*, so it is not
        appended a second time.
        """
        return "".join(chunks)

    @staticmethod
    def _format_server_error(message: str, stderr_text: str) -> str:
        """Append non-blank server stderr to *message*."""
        stderr_text = stderr_text.strip()
        if stderr_text:
            return f"{message}\n\nServer stderr:\n{stderr_text}"
        return message

    @staticmethod
    def _notify_status(callback: StatusCallback, message: str) -> None:
        """Forward *message* to *callback* when one was supplied."""
        if callback:
            callback(message)

    @staticmethod
    def _content_text(result: Any) -> str:
        """Join the ``text`` of all content blocks (objects or dicts), stripped."""
        pieces: List[str] = []
        for block in getattr(result, "content", None) or ():
            text = block.get("text") if isinstance(block, dict) else getattr(block, "text", None)
            if text:
                pieces.append(text)
        return "\n".join(pieces).strip()

    @staticmethod
    def _extract_payload(result: CallToolResult) -> Dict[str, Any]:
        """Convert a tool result into a dict payload.

        Structured content is preferred (a dict nested under ``"result"`` is unwrapped).
        Otherwise text content is parsed as a JSON object, else returned as
        ``{"text": ...}``. Empty results yield ``{}``.

        Raises:
            MCPClientError: if the server flagged the result with ``isError``.
        """
        if getattr(result, "isError", False):
            raise MCPClientError(
                MCPAnalyzerClient._content_text(result)
                or "MCP tool reported an error without details."
            )
        structured = getattr(result, "structuredContent", None)
        if structured is None:
            structured = getattr(result, "structured_content", None)
        if structured is not None:
            if hasattr(structured, "model_dump"):
                structured = structured.model_dump()
            if isinstance(structured, str):
                # Not JSON -> fall through to the text content below.
                with contextlib.suppress(ValueError):
                    structured = json.loads(structured)
            if isinstance(structured, dict):
                payload = structured.get("result")
                return payload if isinstance(payload, dict) else structured
        joined = MCPAnalyzerClient._content_text(result)
        if not joined:
            return {}
        try:
            parsed = json.loads(joined)
        except ValueError:
            return {"text": joined}
        # Non-object JSON (list, number, string) is kept as text rather than dropped.
        return parsed if isinstance(parsed, dict) else {"text": joined}