"""MCP stdio client wrapper used by the LangGraph agent."""
from __future__ import annotations
import json
import sys
from typing import Any
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
MCP_AVAILABLE = True
except ImportError: # pragma: no cover - local fallback for offline environments
ClientSession = None # type: ignore[assignment]
StdioServerParameters = None # type: ignore[assignment]
stdio_client = None # type: ignore[assignment]
MCP_AVAILABLE = False
def _try_json_parse(raw_text: str) -> Any:
try:
return json.loads(raw_text)
except json.JSONDecodeError:
return raw_text
def _decode_tool_response(response: Any) -> Any:
if isinstance(response, (dict, list, str, int, float, bool)) or response is None:
return response
is_error = getattr(response, "isError", False)
if is_error:
raise RuntimeError(f"MCP tool call failed: {response}")
content = getattr(response, "content", None)
if not content:
return response
parsed: list[Any] = []
for chunk in content:
text = getattr(chunk, "text", None)
if text is not None:
parsed.append(_try_json_parse(text))
continue
data = getattr(chunk, "data", None)
if data is not None:
parsed.append(data)
continue
parsed.append(str(chunk))
if len(parsed) == 1:
return parsed[0]
return parsed
class MCPProductClient:
"""Async wrapper that calls MCP product tools via stdio transport."""
def __init__(
self,
command: str | None = None,
args: list[str] | None = None,
) -> None:
self.command = command or sys.executable
self.args = args or ["-m", "mcp_server.server"]
async def _call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
if not MCP_AVAILABLE:
raise RuntimeError("mcp package is required to call MCP tools")
server_params = StdioServerParameters(command=self.command, args=self.args)
async with stdio_client(server_params) as (reader, writer):
async with ClientSession(reader, writer) as session:
await session.initialize()
result = await session.call_tool(tool_name, arguments=arguments)
return _decode_tool_response(result)
async def list_products(self, category: str | None = None) -> list[dict[str, Any]]:
arguments: dict[str, Any] = {}
if category:
arguments["category"] = category
result = await self._call_tool("list_products", arguments)
return list(result)
async def get_product(self, product_id: int) -> dict[str, Any]:
result = await self._call_tool("get_product", {"product_id": product_id})
return dict(result)
async def add_product(
self,
name: str,
price: float,
category: str,
in_stock: bool = True,
) -> dict[str, Any]:
result = await self._call_tool(
"add_product",
{
"name": name,
"price": price,
"category": category,
"in_stock": in_stock,
},
)
return dict(result)
async def get_statistics(self) -> dict[str, Any]:
result = await self._call_tool("get_statistics", {})
return dict(result)