Skip to main content
Glama

MCP Platform

by jck411
websocket_server.py•23.2 kB
""" WebSocket Server for MCP Platform This module provides a thin communication layer between the frontend and chat service. It handles WebSocket connections and message routing only. """ import contextlib import json import logging import uuid from typing import TYPE_CHECKING, Any import uvicorn from fastapi import APIRouter, FastAPI, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field, ValidationError from src.chat import ChatOrchestrator from src.chat.models import ChatMessage from src.config import Configuration from src.history import ChatRepository from src.history.models import ChatEvent from src.history.repository import SavedSessionsRepository # TYPE_CHECKING imports to avoid circular imports if TYPE_CHECKING: from typing import Any as LLMClient from typing import Any as MCPClient else: # At runtime, these will be the actual types passed in LLMClient = object MCPClient = object logger = logging.getLogger(__name__) # Pydantic models for WebSocket message validation class ChatPayload(BaseModel): """Payload for chat messages with validation.""" text: str streaming: bool | None = None metadata: dict[str, Any] = Field(default_factory=dict) class WebSocketMessage(BaseModel): """WebSocket message structure with validation.""" request_id: str payload: ChatPayload message_type: str = "chat" class WebSocketResponse(BaseModel): """WebSocket response structure.""" request_id: str status: str # "processing", "chunk", "completed", "init", "error" chunk: dict[str, Any] = Field(default_factory=dict) class WebSocketServer: """ Pure WebSocket communication server. This class only handles: - WebSocket connections - Message parsing and routing - Response streaming All business logic is delegated to ChatService. 
""" def __init__( self, clients: list["MCPClient"], llm_client: "LLMClient", config: dict[str, Any], repo: ChatRepository, configuration: Configuration, ): service_config = ChatOrchestrator.ChatOrchestratorConfig( clients=clients, llm_client=llm_client, config=config, repo=repo, configuration=configuration, ) self.chat_service = ChatOrchestrator(service_config) self.repo = repo self.config = config self.configuration = configuration self.app = self._create_app() self.active_connections: list[WebSocket] = [] # Store conversation id per socket self.conversation_ids: dict[WebSocket, str] = {} def _create_app(self) -> FastAPI: """Create and configure FastAPI app.""" app = FastAPI(title="MCP WebSocket Chat Server") router = APIRouter() # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.websocket("/ws/chat") async def websocket_endpoint(websocket: WebSocket): # type: ignore await self._handle_websocket_connection(websocket) @app.get("/") async def root(): # type: ignore return {"message": "MCP WebSocket Chat Server"} @app.get("/health") async def health(): # type: ignore return {"status": "healthy"} # Session management endpoints (AutoPersistRepo) @router.post("/sessions/{conversation_id}/save") async def save_session(conversation_id: str, name: str | None = None) -> dict[str, str]: # pyright: ignore[reportUnusedFunction] if isinstance(self.repo, SavedSessionsRepository): saved_id: str = await self.repo.save_session(conversation_id, name) return {"saved_id": saved_id} return {"error": "Saved sessions are disabled"} @router.get("/sessions") async def list_sessions() -> list[dict[str, Any]]: # type: ignore if isinstance(self.repo, SavedSessionsRepository): sessions: list[dict[str, Any]] = await self.repo.list_saved_sessions() return sessions return [] @router.get("/sessions/{saved_id}") async def load_session(saved_id: str) -> list[dict[str, Any]] | dict[str, str]: # type: 
ignore if isinstance(self.repo, SavedSessionsRepository): events: list[ChatEvent] = await self.repo.load_saved_session(saved_id) # Return as plain dicts for API consumers return [ev.model_dump() for ev in events] return {"error": "Saved sessions are disabled"} # Prevent static analyzers from marking route handlers as unused __keep_for_pyright = (save_session, list_sessions, load_session) del __keep_for_pyright app.include_router(router) return app async def _handle_websocket_connection(self, websocket: WebSocket): """Handle a WebSocket connection.""" await self._connect_websocket(websocket) try: while True: # Receive message from client data = await websocket.receive_text() message_data = json.loads(data) # Handle message based on action if message_data.get("action") == "chat": await self._handle_chat_message(websocket, message_data) elif message_data.get("action") == "clear_session": await self._handle_clear_session(websocket, message_data) else: # Unknown message format logger.warning(f"Unknown message format: {message_data}") await self._send_error_response( websocket, message_data.get("request_id", "unknown"), ("Unknown message format. 
Expected 'action': 'chat' or 'clear_session'"), ) except WebSocketDisconnect: logger.info("WebSocket disconnected") except Exception as e: logger.error(f"WebSocket error: {e}") with contextlib.suppress(Exception): await self._send_error_response( websocket, locals().get("message_data", {}).get("request_id", "unknown"), f"Server error: {e!s}", ) finally: self._disconnect_websocket(websocket) async def _handle_chat_message(self, websocket: WebSocket, message_data: dict[str, Any]): """Handle a chat message from the frontend with Pydantic validation.""" try: # Validate message structure using Pydantic ws_message = WebSocketMessage.model_validate(message_data) except ValidationError as e: await self._send_error_response( websocket, message_data.get("request_id", "unknown"), f"Invalid message format: {e}", ) return request_id = ws_message.request_id payload = ws_message.payload user_message = payload.text # Check streaming configuration - FAIL FAST approach service_config = self.config.get("chat", {}).get("service", {}) streaming_config = service_config.get("streaming", {}) if payload.streaming is not None: # Client explicitly set streaming preference - use it streaming = payload.streaming elif streaming_config.get("enabled") is not None: # Use configured default - must be explicitly set streaming = streaming_config["enabled"] else: # FAIL FAST: No streaming configuration found await self._send_error_response( websocket, request_id, "Streaming configuration missing. 
" "Set 'chat.service.streaming.enabled' in config.yaml " "or specify 'streaming: true/false' in payload.", ) return logger.info(f"Processing message with streaming={streaming}") logger.info(f"Received chat message: {user_message[:50]}...") try: # Send processing status response = WebSocketResponse( request_id=request_id, status="processing", chunk={"metadata": {"user_message": user_message}}, ) await websocket.send_text(response.model_dump_json()) # get or assign conversation_id conversation_id = self.conversation_ids.get(websocket) if not conversation_id: conversation_id = str(uuid.uuid4()) self.conversation_ids[websocket] = conversation_id if streaming: # Streaming mode: real-time chunks await self._handle_streaming_chat(websocket, request_id, conversation_id, user_message) else: # Non-streaming mode: single final assistant message await self._handle_non_streaming_chat(websocket, request_id, conversation_id, user_message) except Exception as e: logger.error(f"Error processing chat message: {e}") await self._send_error_response(websocket, request_id, str(e)) async def _handle_clear_session(self, websocket: WebSocket, message_data: dict[str, Any]): """Handle a clear session request from the frontend.""" request_id = message_data.get("request_id", str(uuid.uuid4())) try: # Get current conversation_id old_conversation_id = self.conversation_ids.get(websocket, "") # AutoPersist uses periodic full wipe logic # Returns True when a full wipe occurs full_wipe_occurred = await self.repo.handle_clear_session() if full_wipe_occurred: retention_conf = ( self.config.get("chat", {}).get("storage", {}).get("persistence", {}).get("retention", {}) ) clear_triggers = retention_conf.get("clear_triggers_before_full_wipe") logger.info( "Full history wipe occurred%s", (f" on {clear_triggers}th clear session" if isinstance(clear_triggers, int) else ""), ) # Only create new conversation_id if full wipe occurred # Otherwise keep the same conversation_id to maintain LLM memory if 
full_wipe_occurred: new_conversation_id = str(uuid.uuid4()) self.conversation_ids[websocket] = new_conversation_id logger.info( "Full wipe: Session cleared: %s -> %s", old_conversation_id, new_conversation_id, ) else: new_conversation_id = old_conversation_id # Keep same conversation counter = getattr(self.repo, "_clear_session_counter", "?") clear_triggers = getattr(self.repo, "clear_triggers_before_full_wipe", "N/A") logger.info(f"UI clear only: Session {old_conversation_id} (counter: {counter}/{clear_triggers})") # Send success response await websocket.send_text( json.dumps( { "request_id": request_id, "status": "completed", "chunk": { "type": "session_cleared", "metadata": { "new_conversation_id": new_conversation_id, "old_conversation_id": old_conversation_id, "full_wipe_occurred": full_wipe_occurred, }, }, } ) ) except Exception as e: logger.error(f"Error clearing session: {e}") await self._send_error_response(websocket, request_id, f"Clear session failed: {e}") async def _send_error_response(self, websocket: WebSocket, request_id: str, error_message: str): """Send error response using Pydantic model.""" response = WebSocketResponse(request_id=request_id, status="error", chunk={"error": error_message}) await websocket.send_text(response.model_dump_json()) async def _handle_streaming_chat( self, websocket: WebSocket, request_id: str, conversation_id: str, user_message: str, ): """Handle streaming chat response.""" # Stream response chunks using new signature (no external history needed) async for chat_message in self.chat_service.process_message(conversation_id, user_message, request_id): await self._send_chat_response(websocket, request_id, chat_message) # Send completion signal await websocket.send_text( json.dumps( { "request_id": request_id, "status": "completed", "chunk": {}, } ) ) async def _handle_non_streaming_chat( self, websocket: WebSocket, request_id: str, conversation_id: str, user_message: str, ): """Handle non-streaming chat response.""" 
assistant_ev = await self.chat_service.chat_once( conversation_id=conversation_id, user_msg=user_message, request_id=request_id, ) await websocket.send_text( json.dumps( { "request_id": request_id, "status": "chunk", "chunk": { "type": "text", "data": assistant_ev.content, "metadata": {}, }, } ) ) await websocket.send_text( json.dumps( { "request_id": request_id, "status": "completed", "chunk": {}, } ) ) async def _send_chat_response(self, websocket: WebSocket, request_id: str, chat_message: ChatMessage): """Send a chat response to the frontend.""" logger.info(f"Sending WebSocket message: type={chat_message.type}, content={chat_message.content[:50]}...") # Convert chat service message to frontend format if chat_message.type == "text": # Only send text messages that aren't tool results if not chat_message.metadata.get("tool_result"): await websocket.send_text( json.dumps( { "request_id": request_id, "status": "chunk", "chunk": { "type": "text", "data": chat_message.content, "metadata": chat_message.metadata, }, } ) ) elif chat_message.type == "tool_execution": await websocket.send_text( json.dumps( { "request_id": request_id, "status": "processing", "chunk": { "type": "tool_execution", "data": chat_message.content, "metadata": chat_message.metadata, }, } ) ) elif chat_message.type == "error": await websocket.send_text( json.dumps( { "request_id": request_id, "status": "error", "chunk": { "error": chat_message.content, "metadata": chat_message.metadata, }, } ) ) async def _load_previous_conversation(self, websocket: WebSocket): """Load and send previous conversation history to the frontend.""" try: # Get all conversations from the repository all_conversations = await self.repo.list_conversations() if not all_conversations: # No previous conversations, start fresh conversation_id = str(uuid.uuid4()) self.conversation_ids[websocket] = conversation_id logger.info(f"Starting new conversation: {conversation_id}") return # Get the most recent conversation (first one in the 
list) # list_conversations() returns newest -> oldest by last activity recent_conversation_id = all_conversations[0] self.conversation_ids[websocket] = recent_conversation_id # Load conversation history history = await self.repo.get_conversation_history(recent_conversation_id, limit=50) if not history: logger.info(f"Resuming empty conversation: {recent_conversation_id}") return logger.info(f"Loading {len(history)} messages from conversation: {recent_conversation_id}") # Send history to frontend as a special message history_messages: list[dict[str, Any]] = [] def _collapse_double(text: str) -> str: # If content is an exact duplication (XX), collapse to X if not text: return text n = len(text) if n % 2 == 0: half = n // 2 if text[:half] == text[half:]: return text[:half] return text for event in history: if event.type == "user_message": history_messages.append( { "role": "user", "content": str(event.content or ""), "timestamp": event.created_at.isoformat(), } ) elif event.type == "assistant_message": history_messages.append( { "role": "assistant", "content": _collapse_double(str(event.content or "")), "timestamp": event.created_at.isoformat(), } ) if history_messages: # Send history as a special initialization message await websocket.send_text( json.dumps( { "request_id": str(uuid.uuid4()), "status": "init", "chunk": { "type": "conversation_history", "data": history_messages, "metadata": { "conversation_id": recent_conversation_id, "message_count": len(history_messages), }, }, } ) ) logger.info(f"Sent conversation history: {len(history_messages)} messages") except Exception as e: logger.error(f"Error loading previous conversation: {e}") # Fall back to new conversation on error conversation_id = str(uuid.uuid4()) self.conversation_ids[websocket] = conversation_id logger.info(f"Fallback to new conversation: {conversation_id}") async def _connect_websocket(self, websocket: WebSocket): """Connect a WebSocket and load previous conversation history.""" try: 
logger.info(f"WebSocket connection attempt from {websocket.client}") await websocket.accept() self.active_connections.append(websocket) # Try to resume the most recent conversation await self._load_previous_conversation(websocket) logger.info(f"WebSocket connection established. Total connections: {len(self.active_connections)}") except Exception as e: logger.error(f"Failed to accept WebSocket connection: {e}") raise def _disconnect_websocket(self, websocket: WebSocket): """Disconnect a WebSocket.""" if websocket in self.active_connections: self.active_connections.remove(websocket) # Clean up conversation id if websocket in self.conversation_ids: del self.conversation_ids[websocket] logger.info(f"WebSocket connection closed. Total connections: {len(self.active_connections)}") async def start_server(self): """Start the WebSocket server with comprehensive cleanup.""" # Initialize chat service await self.chat_service.initialize() # Start server chat_config = self.config.get("chat", {}) websocket_config = chat_config.get("websocket", {}) host = websocket_config.get("host", "localhost") port = websocket_config.get("port", 8000) logger.info(f"Starting WebSocket server on {host}:{port}") server_config = uvicorn.Config(self.app, host=host, port=port, log_level="info") server = uvicorn.Server(server_config) try: await server.serve() except KeyboardInterrupt: logger.info("Received shutdown signal, cleaning up...") except Exception as e: logger.error(f"WebSocket server error: {e}") raise finally: logger.info("Shutting down WebSocket server and cleaning up resources...") try: await self.chat_service.cleanup() logger.info("Chat service cleanup completed") except Exception as e: logger.error(f"Error during chat service cleanup: {e}") # Don't re-raise cleanup errors to avoid masking the original exception async def run_websocket_server( clients: list["MCPClient"], llm_client: "LLMClient", config: dict[str, Any], repo: ChatRepository, configuration: Configuration, ) -> None: """ 
Run the WebSocket server. This function maintains the same interface as before but now uses the clean separation between communication and business logic. """ server = WebSocketServer(clients, llm_client, config, repo, configuration) await server.start_server()

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/jck411/MCP_BACKEND_OPENROUTER'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.