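"""
MCP client with SSE transport support for the TrustySign agent.

Implements the MCP HTTP+SSE transport: JSON-RPC 2.0 requests are POSTed to the
server, and their responses arrive asynchronously over a long-lived
Server-Sent Events stream. This is the transport MCP servers expect; it is
not a set of REST-style request/response HTTP calls.
"""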
import asyncio
import json
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Any, Optional, List

import aiohttp
# Module-level logger. No handlers are attached in this module, so output is
# governed by the host application's logging configuration.
logger = logging.getLogger(__name__)
class MCPError(Exception):
    """
    Error raised for MCP / JSON-RPC failures.

    Attributes:
        code: Numeric JSON-RPC style error code.
        message: Human-readable description of the failure.
        data: Optional extra payload attached to the error (None if absent).
    """

    def __init__(self, code: int, message: str, data: Any = None):
        # Exception.__init__ only sets args, so initializing it first is safe.
        super().__init__(f"MCP Error {code}: {message}")
        self.code, self.message, self.data = code, message, data
@dataclass
class MCPServer:
    """
    Connection settings for one MCP server.

    Attributes:
        name: Key under which the server is registered with the client.
        url: Base URL of the server; endpoint paths are appended to it.
        api_key: Optional bearer token. Excluded from ``repr()`` so the secret
            does not leak into log lines or tracebacks that print the config.
        description: Free-form human-readable description.
    """
    name: str
    url: str
    api_key: Optional[str] = field(default=None, repr=False)
    description: str = ""
class SSEMCPClient:
    """
    MCP client using the HTTP+SSE transport.

    For each server, a long-lived ``GET <url>/sse`` stream is held open and a
    background task parses its Server-Sent Events. JSON-RPC 2.0 requests are
    POSTed to ``<url>/message``; the matching response is delivered over the
    SSE stream and routed to the waiting caller by its JSON-RPC ``id``.
    """

    def __init__(self, timeout: int = 30):
        """
        Initialize SSE MCP Client

        Args:
            timeout: Seconds allowed for each POST, for each response to arrive
                over SSE, and for opening the SSE stream.
        """
        self.servers: Dict[str, MCPServer] = {}
        self.sse_connections: Dict[str, aiohttp.ClientSession] = {}
        # SSE messages that no in-flight request claimed (e.g. server
        # notifications, late responses to timed-out calls). Nothing in this
        # class consumes them; the queue is unbounded unless drained externally.
        self.response_queues: Dict[str, asyncio.Queue] = {}
        self.request_id = 0
        self.timeout = timeout
        self.sse_tasks: Dict[str, asyncio.Task] = {}
        # server name -> {JSON-RPC request id -> future awaiting that response}
        self._pending: Dict[str, Dict[Any, asyncio.Future]] = {}
        logger.info("SSEMCPClient initialized")

    def add_server(self, server: MCPServer):
        """Register an MCP server under ``server.name`` (replacing any previous one)."""
        self.servers[server.name] = server
        logger.info(f"Added MCP server: {server.name} ({server.url})")

    async def connect(self, server_name: str):
        """
        Open (or reopen) the SSE stream to an MCP server and start its listener.

        Any existing connection to the same server is closed first.

        Args:
            server_name: Name of a server registered via add_server.

        Raises:
            MCPError: -32001 if the server is unknown; -32002 if the SSE stream
                cannot be opened.
        """
        if server_name not in self.servers:
            raise MCPError(-32001, f"Server not found: {server_name}")
        server = self.servers[server_name]

        # Close existing connection if any
        if server_name in self.sse_connections:
            await self.disconnect(server_name)

        headers = {}
        if server.api_key:
            headers["Authorization"] = f"Bearer {server.api_key}"
        # aiohttp's default ClientTimeout caps every request at total=300s,
        # which would cut the long-lived SSE stream after 5 minutes. Disable
        # the total cap for this session; the stream open is bounded below and
        # POSTs set their own per-request timeout in _request().
        session = aiohttp.ClientSession(
            headers=headers,
            timeout=aiohttp.ClientTimeout(total=None, sock_connect=self.timeout),
        )
        connected = False
        try:
            sse_url = f"{server.url.rstrip('/')}/sse"
            logger.info(f"Opening SSE connection to {sse_url}")
            response = await asyncio.wait_for(session.get(sse_url), timeout=self.timeout)
            if response.status != 200:
                raise MCPError(-32002, f"SSE connection failed: HTTP {response.status}")

            self.sse_connections[server_name] = session
            self.response_queues[server_name] = asyncio.Queue()
            self.sse_tasks[server_name] = asyncio.create_task(
                self._sse_listener(server_name, response)
            )
            connected = True
            logger.info(f"SSE connection established to {server_name}")
        except Exception as e:
            logger.error(f"Failed to connect to {server_name}: {e}")
            if isinstance(e, MCPError):
                raise
            raise MCPError(
                -32002, f"Connection failed: {str(e) or type(e).__name__}"
            ) from e
        finally:
            # The session is owned locally until fully connected, so close it on
            # every failure path (network errors, bad status, cancellation).
            if not connected:
                self.sse_connections.pop(server_name, None)
                self.response_queues.pop(server_name, None)
                await session.close()

    @staticmethod
    async def _iter_lines(response: aiohttp.ClientResponse):
        """
        Yield decoded lines, without line terminators, from a streaming body.

        Splits on LF and drops trailing CR, so LF and CRLF endings both work.
        Unlike iterating ``response.content`` line by line, this imposes no
        per-line size limit, so a large JSON payload on a single ``data:``
        line cannot abort the stream.
        """
        buffer = bytearray()
        async for chunk in response.content.iter_any():
            scan_from = len(buffer)  # leftover bytes are known to contain no LF
            buffer.extend(chunk)
            line_start = 0
            newline = buffer.find(b"\n", scan_from)
            while newline != -1:
                yield buffer[line_start:newline].rstrip(b"\r").decode("utf-8", errors="replace")
                line_start = newline + 1
                newline = buffer.find(b"\n", line_start)
            if line_start:
                del buffer[:line_start]
        if buffer:
            yield buffer.rstrip(b"\r").decode("utf-8", errors="replace")

    async def _sse_listener(self, server_name: str, response: aiohttp.ClientResponse):
        """
        Read the SSE stream for a server until it ends or the task is cancelled.

        Lines are assembled into events per the SSE format. An ``event:`` line
        sets the event type. One or more ``data:`` lines form the payload and are
        joined with newlines. A blank line dispatches the event. Comment lines
        (starting with ``:``) and other fields are ignored.

        When the listener stops for any reason, requests still waiting on this
        server are failed immediately rather than waiting out their timeout.

        Args:
            server_name: Server name
            response: Open SSE response whose body is the event stream
        """
        event_type: Optional[str] = None
        data_lines: List[str] = []
        try:
            async for line in self._iter_lines(response):
                if not line:
                    # Blank line terminates the current event.
                    if data_lines:
                        self._dispatch_sse_event(
                            server_name, event_type or "message", "\n".join(data_lines)
                        )
                    event_type = None
                    data_lines = []
                    continue
                if line.startswith(":"):
                    continue  # comment / keep-alive
                field_name, _, value = line.partition(":")
                if value.startswith(" "):
                    value = value[1:]  # the spec strips exactly one leading space
                if field_name == "event":
                    event_type = value.strip()
                elif field_name == "data":
                    data_lines.append(value)
                # "id" and "retry" fields are not used by this client.
        except asyncio.CancelledError:
            logger.info(f"SSE listener cancelled for {server_name}")
            raise
        except Exception as e:
            logger.exception(f"SSE listener error for {server_name}: {e}")
        finally:
            self._fail_pending(server_name)
            logger.info(f"SSE listener stopped for {server_name}")

    def _dispatch_sse_event(self, server_name: str, event_type: str, data_str: str) -> None:
        """Route one complete SSE event: resolve the matching request, else enqueue it."""
        if event_type != "message":
            # e.g. 'endpoint' events
            logger.debug(f"Skipping SSE event type '{event_type}' from {server_name}")
            return
        if not data_str.lstrip().startswith("{"):
            # Skip non-JSON data (like endpoint URLs)
            logger.debug(f"Skipping non-JSON SSE data from {server_name}: {data_str[:50]}")
            return
        try:
            message = json.loads(data_str)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse SSE data: {e}")
            return

        message_id = message.get("id")
        logger.debug(f"Received SSE message from {server_name}: {message_id}")
        # Only responses (no "method" member) complete our requests; requests
        # initiated by the server carry their own ids that could collide.
        if "method" not in message and isinstance(message_id, (int, str)):
            future = self._pending.get(server_name, {}).get(message_id)
            if future is not None:
                if not future.done():
                    future.set_result(message)
                return
        queue = self.response_queues.get(server_name)
        if queue is not None:
            queue.put_nowait(message)

    def _fail_pending(self, server_name: str) -> None:
        """Fail every in-flight request on a server whose stream has closed."""
        for future in self._pending.pop(server_name, {}).values():
            if not future.done():
                future.set_exception(
                    MCPError(-32002, f"SSE connection to {server_name} closed")
                )

    async def disconnect(self, server_name: str):
        """
        Close the SSE connection to a server and fail its in-flight requests.

        Args:
            server_name: Server name
        """
        # Cancel SSE listener task
        task = self.sse_tasks.pop(server_name, None)
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        # Close session
        session = self.sse_connections.pop(server_name, None)
        if session is not None:
            await session.close()
            logger.info(f"Disconnected from {server_name}")
        # Clear queue
        self.response_queues.pop(server_name, None)
        self._fail_pending(server_name)

    async def _ensure_connected(self, server_name: str) -> None:
        """
        Connect to a server if it has no live SSE listener.

        Raises:
            MCPError: -32001 if the server is unknown; -32002 on connect failure.
        """
        if server_name not in self.servers:
            raise MCPError(-32001, f"Server not found: {server_name}")
        if server_name not in self.sse_connections:
            logger.info(f"Reconnecting to {server_name} (connection lost)")
            await self.connect(server_name)
            return
        task = self.sse_tasks.get(server_name)
        if task is not None and task.done():
            logger.warning(f"SSE task for {server_name} is not running, reconnecting...")
            await self.connect(server_name)

    async def _request(
        self,
        server_name: str,
        method: str,
        params: Dict[str, Any],
        description: str,
    ) -> Any:
        """
        Send one JSON-RPC request and wait for its response over SSE.

        Args:
            server_name: Target server; connected or reconnected on demand.
            method: JSON-RPC method name, e.g. "tools/list".
            params: JSON-RPC params object.
            description: Prefix of the info log line emitted before sending.

        Returns:
            The ``result`` member of the JSON-RPC response (None if absent).

        Raises:
            MCPError: unknown server or connection failure; -32003 for an HTTP
                error on the POST; -32000 on response timeout; the server's
                code for a JSON-RPC error response.
            Exception: other transport errors propagate unwrapped.
        """
        await self._ensure_connected(server_name)
        server = self.servers[server_name]

        # Generate request ID
        self.request_id += 1
        request_id = self.request_id
        request = {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": method,
            "params": params,
        }

        # Register the waiter *before* sending: the server may push the
        # response over SSE before the POST itself completes.
        future = asyncio.get_running_loop().create_future()
        waiters = self._pending.setdefault(server_name, {})
        waiters[request_id] = future
        try:
            message_url = f"{server.url.rstrip('/')}/message"
            headers = {}
            if server.api_key:
                headers["Authorization"] = f"Bearer {server.api_key}"
            session = self.sse_connections[server_name]
            logger.info(f"{description} (request_id={request_id})")
            async with session.post(
                message_url,
                json=request,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=self.timeout),
            ) as response:
                if response.status >= 400:
                    error_text = await response.text()
                    raise MCPError(-32003, f"Request failed: HTTP {response.status} - {error_text}")
            try:
                message = await asyncio.wait_for(future, timeout=self.timeout)
            except asyncio.TimeoutError:
                raise MCPError(-32000, f"Timeout waiting for response from {server_name}") from None
        finally:
            waiters.pop(request_id, None)
            if future.done() and not future.cancelled():
                future.exception()  # mark retrieved; avoids "never retrieved" warnings

        if "error" in message:
            error = message["error"]
            if not isinstance(error, dict):
                error = {"message": str(error)}
            raise MCPError(
                error.get("code", -32603),
                error.get("message", "Unknown error"),
                error.get("data"),
            )
        return message.get("result")

    async def list_tools(self, server_name: str) -> List[Dict[str, Any]]:
        """
        List tools available from an MCP server

        Args:
            server_name: Name of the MCP server

        Returns:
            List of tool definitions (empty if the server returns no result)

        Raises:
            MCPError: If the call fails; unexpected errors are wrapped as -32603
        """
        try:
            result = await self._request(
                server_name, "tools/list", {}, f"Listing tools from {server_name}"
            )
            return (result or {}).get("tools", [])
        except MCPError:
            raise
        except Exception as e:
            logger.error(f"Error listing tools from {server_name}: {e}")
            raise MCPError(-32603, f"Internal error: {str(e)}") from e

    async def call_tool(
        self,
        server_name: str,
        tool_name: str,
        arguments: Dict[str, Any]
    ) -> Any:
        """
        Call an MCP tool via SSE connection

        Args:
            server_name: Name of the MCP server
            tool_name: Name of the tool to call
            arguments: Tool arguments

        Returns:
            The ``result`` member of the tool's JSON-RPC response

        Raises:
            MCPError: If the call fails; unexpected errors are wrapped as -32603
        """
        try:
            result = await self._request(
                server_name,
                "tools/call",
                {"name": tool_name, "arguments": arguments},
                f"Calling {server_name}.{tool_name}",
            )
        except MCPError:
            raise
        except Exception as e:
            logger.error(f"Error calling {server_name}.{tool_name}: {e}")
            raise MCPError(-32603, f"Internal error: {str(e)}") from e
        logger.info(f"{server_name}.{tool_name} completed")
        return result

    async def close(self):
        """Close all SSE connections"""
        for server_name in list(self.sse_connections):
            await self.disconnect(server_name)
        logger.info("All MCP connections closed")