Skip to main content
Glama

MCP Memory Service

sse.py (12 kB)
# Copyright 2024 Heinrich Krupp
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Server-Sent Events (SSE) implementation for real-time memory service updates.

Provides real-time notifications for memory operations, search results,
and system status changes.
"""

import asyncio
import json
import time
import uuid
from typing import Dict, List, Any, Optional, Set
from datetime import datetime, timezone
from dataclasses import dataclass, asdict
from contextlib import asynccontextmanager

from fastapi import Request
from sse_starlette import EventSourceResponse
import logging

from ..config import SSE_HEARTBEAT_INTERVAL

logger = logging.getLogger(__name__)


@dataclass
class SSEEvent:
    """Represents a Server-Sent Event.

    ``event_id`` and ``timestamp`` are filled in automatically (a random
    UUID4 string and the current UTC time in ISO-8601 form) when the caller
    leaves them as ``None``.
    """
    event_type: str
    data: Dict[str, Any]
    event_id: Optional[str] = None
    retry: Optional[int] = None
    timestamp: Optional[str] = None

    def __post_init__(self):
        """Set default values after initialization."""
        if self.event_id is None:
            self.event_id = str(uuid.uuid4())
        if self.timestamp is None:
            self.timestamp = datetime.now(timezone.utc).isoformat()


class SSEManager:
    """Manages Server-Sent Event connections and broadcasting.

    Each connection is tracked in ``self.connections`` under its connection
    id, together with an ``asyncio.Queue`` that the per-client stream
    generator drains. A background task periodically broadcasts a
    ``heartbeat`` event to every registered connection.
    """

    def __init__(self, heartbeat_interval: int = SSE_HEARTBEAT_INTERVAL):
        """Create a stopped manager.

        Args:
            heartbeat_interval: Seconds to sleep between heartbeat broadcasts.
        """
        self.connections: Dict[str, Dict[str, Any]] = {}
        self.heartbeat_interval = heartbeat_interval
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._running = False

    async def start(self):
        """Start the SSE manager and heartbeat task.

        Calling this while the manager is already running is a no-op.
        """
        if self._running:
            return

        self._running = True
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        logger.info(f"SSE Manager started with {self.heartbeat_interval}s heartbeat interval")

    async def stop(self):
        """Stop the SSE manager and cleanup connections.

        Cancels the heartbeat task (waiting for it to finish) and then
        removes every registered connection.
        """
        self._running = False

        if self._heartbeat_task:
            self._heartbeat_task.cancel()
            try:
                await self._heartbeat_task
            except asyncio.CancelledError:
                pass
            # Drop the reference so a finished task is not kept around.
            self._heartbeat_task = None

        # Close all connections (iterate over a copy: removal mutates the dict)
        for connection_id in list(self.connections.keys()):
            await self._remove_connection(connection_id)

        logger.info("SSE Manager stopped")

    async def add_connection(self, connection_id: str, request: Request) -> asyncio.Queue:
        """Add a new SSE connection.

        Registers the connection, enqueues a ``connection_established``
        welcome event and returns the queue the caller should read from.

        Args:
            connection_id: Identifier to register the connection under.
            request: The incoming request; its ``User-Agent`` header and
                client host are recorded for statistics.

        Returns:
            The event queue for this connection.
        """
        queue = asyncio.Queue()
        now = time.time()

        self.connections[connection_id] = {
            'queue': queue,
            'request': request,
            'connected_at': now,
            'last_heartbeat': now,
            'user_agent': request.headers.get('User-Agent', 'Unknown'),
            'client_ip': request.client.host if request.client else 'Unknown'
        }

        logger.info(f"SSE connection added: {connection_id} from {self.connections[connection_id]['client_ip']}")

        # Send welcome event
        welcome_event = SSEEvent(
            event_type="connection_established",
            data={
                "connection_id": connection_id,
                "message": "Connected to MCP Memory Service SSE stream",
                "heartbeat_interval": self.heartbeat_interval
            }
        )
        await queue.put(welcome_event)

        return queue

    async def _remove_connection(self, connection_id: str):
        """Remove an SSE connection.

        A ``connection_closed`` event is enqueued (best effort) before the
        connection is dropped from the registry. Unknown ids are ignored, so
        calling this twice for the same connection is harmless.
        """
        connection_info = self.connections.get(connection_id)
        if connection_info is None:
            return

        duration = time.time() - connection_info['connected_at']

        # Put a close event in the queue before removing. This is best
        # effort, but only ordinary errors are suppressed: a bare ``except``
        # would also swallow task cancellation and KeyboardInterrupt.
        try:
            close_event = SSEEvent(
                event_type="connection_closed",
                data={"connection_id": connection_id, "duration_seconds": duration}
            )
            await connection_info['queue'].put(close_event)
        except Exception as e:
            logger.debug(f"Could not enqueue close event for {connection_id}: {e}")

        # pop() instead of del: tolerate the entry having been removed
        # concurrently while the put above was awaited.
        self.connections.pop(connection_id, None)
        logger.info(f"SSE connection removed: {connection_id} (duration: {duration:.1f}s)")

    async def broadcast_event(self, event: SSEEvent, connection_filter: Optional[Set[str]] = None):
        """Broadcast an event to all or filtered connections.

        Args:
            event: The event to enqueue for each target connection.
            connection_filter: Connection ids to restrict delivery to. Ids
                that are not currently registered are skipped. ``None`` or an
                empty set means "all connections".

        A connection whose queue rejects the event is removed.
        """
        if not self.connections:
            return

        target_connections = (
            connection_filter.intersection(self.connections.keys())
            if connection_filter
            else self.connections.keys()
        )

        if not target_connections:
            return

        logger.debug(f"Broadcasting {event.event_type} to {len(target_connections)} connections")

        # Send to all target connections
        for connection_id in list(target_connections):  # Copy to avoid modification during iteration
            if connection_id in self.connections:
                try:
                    await self.connections[connection_id]['queue'].put(event)
                except Exception as e:
                    logger.error(f"Failed to send event to {connection_id}: {e}")
                    await self._remove_connection(connection_id)

    async def _heartbeat_loop(self):
        """Send periodic heartbeat events to maintain connections.

        Runs until the manager is stopped or the task is cancelled. Any other
        exception is logged and the loop continues.
        """
        while self._running:
            try:
                await asyncio.sleep(self.heartbeat_interval)

                # stop() may have been called while sleeping
                if not self._running:
                    break

                if self.connections:
                    heartbeat_event = SSEEvent(
                        event_type="heartbeat",
                        data={
                            "timestamp": datetime.now(timezone.utc).isoformat(),
                            "active_connections": len(self.connections),
                            "server_status": "healthy"
                        }
                    )

                    # Update last heartbeat time for all connections
                    current_time = time.time()
                    for connection_info in self.connections.values():
                        connection_info['last_heartbeat'] = current_time

                    await self.broadcast_event(heartbeat_event)
                    logger.debug(f"Heartbeat sent to {len(self.connections)} connections")

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in heartbeat loop: {e}")

    def get_connection_stats(self) -> Dict[str, Any]:
        """Get statistics about current connections.

        Returns:
            A dict with ``total_connections``, ``heartbeat_interval`` and a
            ``connections`` list holding, per connection, its id, client IP,
            user agent, connected duration and seconds since the last
            heartbeat. The shape is the same whether or not any connection
            is registered.
        """
        current_time = time.time()
        connection_details = [
            {
                "connection_id": connection_id,
                "client_ip": info['client_ip'],
                "user_agent": info['user_agent'],
                "connected_duration_seconds": current_time - info['connected_at'],
                "last_heartbeat_seconds_ago": current_time - info['last_heartbeat']
            }
            for connection_id, info in self.connections.items()
        ]

        return {
            "total_connections": len(self.connections),
            "heartbeat_interval": self.heartbeat_interval,
            "connections": connection_details
        }


# Global SSE manager instance
sse_manager = SSEManager()


async def create_event_stream(request: Request):
    """Create an SSE event stream for a client.

    Registers a new connection with the global ``sse_manager`` and returns an
    ``EventSourceResponse`` that relays queued events. If no event arrives
    within 60 seconds a ``ping`` event is emitted instead. The stream ends
    once a ``connection_closed`` event has been delivered or the generator is
    cancelled; the connection is always unregistered on exit.
    """
    connection_id = str(uuid.uuid4())

    async def event_generator():
        queue = await sse_manager.add_connection(connection_id, request)

        try:
            while True:
                try:
                    # Wait for events with timeout to handle disconnections
                    event = await asyncio.wait_for(queue.get(), timeout=60.0)

                    # Format the SSE event. Keys in event.data override the
                    # event-level timestamp when both are present.
                    event_data = {
                        "id": event.event_id,
                        "event": event.event_type,
                        "data": json.dumps({
                            "timestamp": event.timestamp,
                            **event.data
                        }),
                    }

                    # ``is not None`` so an explicit retry of 0 is forwarded
                    if event.retry is not None:
                        event_data["retry"] = event.retry

                    yield event_data

                    # The manager has already dropped this connection; stop
                    # instead of pinging a stream nothing will ever feed.
                    if event.event_type == "connection_closed":
                        break

                except asyncio.TimeoutError:
                    # Send a ping to keep connection alive
                    yield {
                        "event": "ping",
                        "data": json.dumps({
                            "timestamp": datetime.now(timezone.utc).isoformat(),
                            "message": "Connection alive"
                        })
                    }

                except asyncio.CancelledError:
                    break

        except Exception as e:
            logger.error(f"Error in event stream for {connection_id}: {e}")
        finally:
            # No-op if the manager already removed the connection
            await sse_manager._remove_connection(connection_id)

    return EventSourceResponse(event_generator())


# Event creation helpers
def create_memory_stored_event(memory_data: Dict[str, Any]) -> SSEEvent:
    """Create a memory_stored event.

    Args:
        memory_data: Mapping read for ``content_hash``, ``content``, ``tags``
            and ``memory_type``. Missing keys are tolerated; a missing or
            ``None`` content yields an empty preview.

    Returns:
        An event whose ``content_preview`` is the content itself, or its
        first 100 characters followed by ``"..."`` when it is longer.
    """
    content = memory_data.get("content") or ""
    content_preview = content[:100] + "..." if len(content) > 100 else content

    return SSEEvent(
        event_type="memory_stored",
        data={
            "content_hash": memory_data.get("content_hash"),
            "content_preview": content_preview,
            "tags": memory_data.get("tags", []),
            "memory_type": memory_data.get("memory_type"),
            "message": "New memory stored successfully"
        }
    )


def create_memory_deleted_event(content_hash: str, success: bool = True) -> SSEEvent:
    """Create a memory_deleted event.

    Args:
        content_hash: Hash of the memory the deletion was attempted on.
        success: Whether the deletion succeeded; selects the message text.
    """
    return SSEEvent(
        event_type="memory_deleted",
        data={
            "content_hash": content_hash,
            "success": success,
            "message": "Memory deleted successfully" if success else "Memory deletion failed"
        }
    )


def create_search_completed_event(query: str, search_type: str, results_count: int, processing_time_ms: float) -> SSEEvent:
    """Create a search_completed event.

    Args:
        query: The search query that was run.
        search_type: Label for the kind of search performed.
        results_count: Number of results found.
        processing_time_ms: Processing time in milliseconds.
    """
    return SSEEvent(
        event_type="search_completed",
        data={
            "query": query,
            "search_type": search_type,
            "results_count": results_count,
            "processing_time_ms": processing_time_ms,
            "message": f"Search completed: {results_count} results found"
        }
    )


def create_health_update_event(status: str, details: Optional[Dict[str, Any]] = None) -> SSEEvent:
    """Create a health_update event.

    Args:
        status: Status string reported in the event and its message.
        details: Optional extra details; an empty dict is used when omitted.
    """
    return SSEEvent(
        event_type="health_update",
        data={
            "status": status,
            "details": details or {},
            "message": f"System status: {status}"
        }
    )

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/doobidoo/mcp-memory-service'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.