session_middleware.py
"""Session management middleware for IPython shells with optimized persistence.""" import asyncio import logging import sqlite3 import time from pathlib import Path from typing import Optional, Dict, Any, Set import secrets import json import threading from contextlib import contextmanager import dill from IPython.core.completer import IPCompleter from IPython.core.interactiveshell import InteractiveShell from fastmcp.server.middleware import Middleware, MiddlewareContext logger = logging.getLogger("SherlogMCP.SessionMiddleware") SESSIONS_DIR = Path("/app/data/sessions") DB_FILE = SESSIONS_DIR / "sessions.db" SESSION_SHELLS: Dict[str, InteractiveShell] = {} SESSION_METADATA: Dict[str, Dict[str, Any]] = {} # Optimized persistence tracking DIRTY_SESSIONS: Set[str] = set() # Sessions that need saving LAST_SAVE_TIMES: Dict[str, float] = {} # Last save time per session SAVE_INTERVAL = 15 # Save dirty sessions every 15 seconds REQUEST_SAVE_THRESHOLD = 10 # Save after N requests per session SESSION_REQUEST_COUNTS: Dict[str, int] = {} # Background persistence worker _persistence_worker: Optional[asyncio.Task[None]] = None _shutdown_event = asyncio.Event() # SQLite connection pool _db_lock = threading.Lock() def generate_secure_session_id() -> str: """Generate a cryptographically secure session ID.""" return secrets.token_urlsafe(32) @contextmanager def get_db_connection(): """Thread-safe SQLite connection context manager.""" with _db_lock: conn = sqlite3.connect(DB_FILE, timeout=10.0) conn.row_factory = sqlite3.Row # Enable column access by name try: yield conn finally: conn.close() def _init_database() -> None: """Initialize SQLite database with required tables.""" try: with get_db_connection() as conn: cursor = conn.cursor() # Session metadata table cursor.execute(""" CREATE TABLE IF NOT EXISTS session_metadata ( session_id TEXT PRIMARY KEY, created_at REAL NOT NULL, last_initialized REAL, last_accessed REAL, metadata_json TEXT ) """) # Active sessions registry cursor.execute(""" CREATE TABLE IF NOT EXISTS active_sessions ( session_id TEXT PRIMARY KEY, created_at REAL NOT NULL, last_active REAL NOT NULL ) """) # Create indexes for performance cursor.execute(""" CREATE INDEX IF NOT EXISTS idx_session_last_active ON active_sessions(last_active) """) conn.commit() logger.info("Database initialized successfully") except Exception as e: logger.error(f"Failed to initialize database: {e}") raise def _save_session_metadata_db(session_id: str, metadata: Dict[str, Any]) -> None: """Save session metadata to SQLite database.""" try: with get_db_connection() as conn: cursor = conn.cursor() # Serialize only JSON-compatible metadata serializable_metadata = { k: v for k, v in metadata.items() if isinstance(v, (str, int, float, bool, list, dict)) } cursor.execute(""" INSERT OR REPLACE INTO session_metadata (session_id, created_at, last_initialized, last_accessed, metadata_json) VALUES (?, ?, ?, ?, ?) """, ( session_id, metadata.get('created_at'), metadata.get('last_initialized'), time.time(), # Update last_accessed json.dumps(serializable_metadata) )) conn.commit() logger.debug(f"Saved metadata for session: {session_id}") except Exception as e: logger.error(f"Failed to save session metadata for {session_id}: {e}") def _save_active_session_db(session_id: str) -> None: """Register session as active in database.""" try: with get_db_connection() as conn: cursor = conn.cursor() cursor.execute(""" INSERT OR REPLACE INTO active_sessions (session_id, created_at, last_active) VALUES (?, ?, ?) 
""", (session_id, time.time(), time.time())) conn.commit() logger.debug(f"Registered active session: {session_id}") except Exception as e: logger.error(f"Failed to register active session {session_id}: {e}") def _remove_active_session_db(session_id: str) -> None: """Remove session from active registry in database.""" try: with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("DELETE FROM active_sessions WHERE session_id = ?", (session_id,)) conn.commit() logger.debug(f"Removed active session: {session_id}") except Exception as e: logger.error(f"Failed to remove active session {session_id}: {e}") def _restore_metadata_db() -> None: """Restore session metadata from SQLite database.""" global SESSION_METADATA try: with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT session_id, metadata_json FROM session_metadata") SESSION_METADATA = {} for row in cursor.fetchall(): session_id = row['session_id'] try: metadata = json.loads(row['metadata_json']) SESSION_METADATA[session_id] = metadata except json.JSONDecodeError as e: logger.warning(f"Failed to parse metadata for session {session_id}: {e}") logger.info(f"Restored metadata for {len(SESSION_METADATA)} sessions from database") except Exception as e: logger.error(f"Failed to restore session metadata from database: {e}") SESSION_METADATA = {} def _restore_active_sessions_db() -> list: """Restore active session registry from SQLite database.""" try: with get_db_connection() as conn: cursor = conn.cursor() cursor.execute("SELECT session_id FROM active_sessions ORDER BY last_active DESC") active_sessions = [row['session_id'] for row in cursor.fetchall()] logger.info(f"Found {len(active_sessions)} active sessions to restore") return active_sessions except Exception as e: logger.error(f"Failed to restore active sessions from database: {e}") return [] def _mark_session_dirty(session_id: str) -> None: """Mark a session as needing persistence.""" if session_id != "default": DIRTY_SESSIONS.add(session_id) SESSION_REQUEST_COUNTS[session_id] = SESSION_REQUEST_COUNTS.get(session_id, 0) + 1 logger.debug(f"Marked session {session_id} as dirty (requests: {SESSION_REQUEST_COUNTS[session_id]})") async def _persistence_worker_impl() -> None: """Background worker for periodic session persistence.""" logger.info("Starting background persistence worker") while not _shutdown_event.is_set(): try: await asyncio.wait_for(_shutdown_event.wait(), timeout=SAVE_INTERVAL) break # Shutdown requested except asyncio.TimeoutError: pass if DIRTY_SESSIONS: sessions_to_save = DIRTY_SESSIONS.copy() DIRTY_SESSIONS.clear() logger.debug(f"Persistence worker saving {len(sessions_to_save)} dirty sessions") for session_id in sessions_to_save: if session_id in SESSION_SHELLS: try: await _save_session_state_async(session_id, SESSION_SHELLS[session_id]) LAST_SAVE_TIMES[session_id] = time.time() except Exception as e: logger.error(f"Failed to save session {session_id} in worker: {e}") logger.info("Persistence worker stopped") async def _save_session_state_async(session_id: str, shell: InteractiveShell) -> None: """Async wrapper for session state saving.""" if session_id == "default": return logger.debug(f"Saving session state for: {session_id}") try: await asyncio.to_thread(_save_session_state, session_id, shell) logger.debug(f"Successfully saved session state for: {session_id}") except Exception as e: logger.error(f"Error saving session state for {session_id}: {e}") def _save_session_state(session_id: str, shell: InteractiveShell) -> None: """Save 
session state to disk using dill serialization.""" if session_id == "default": return try: # Extract serializable user namespace user_ns = { k: v for k, v in shell.user_ns.items() if not k.startswith("_") and k not in {"In", "Out", "exit", "quit", "get_ipython"} } state = {"user_ns": user_ns} session_file = SESSIONS_DIR / f"{session_id}.pkl" with open(session_file, "wb") as f: dill.dump(state, f) logger.debug(f"Saved {len(user_ns)} variables for session: {session_id}") except Exception as e: logger.error(f"Failed to save session state {session_id}: {e}") async def _should_save_session_immediately(session_id: str) -> bool: """Check if session should be saved immediately based on request count.""" if session_id == "default": return False request_count = SESSION_REQUEST_COUNTS.get(session_id, 0) return request_count >= REQUEST_SAVE_THRESHOLD async def start_persistence_worker() -> None: """Start the persistence worker during application startup.""" global _persistence_worker if _persistence_worker is None or _persistence_worker.done(): try: _persistence_worker = asyncio.create_task(_persistence_worker_impl()) logger.info("Started background persistence worker during startup") except Exception as e: logger.error(f"Failed to start persistence worker: {e}") raise async def shutdown_persistence() -> None: """Gracefully shutdown persistence and save all dirty sessions.""" global _persistence_worker logger.info("Shutting down session persistence...") # Signal shutdown _shutdown_event.set() # Wait for worker to finish if _persistence_worker and not _persistence_worker.done(): try: await asyncio.wait_for(_persistence_worker, timeout=10.0) except asyncio.TimeoutError: logger.warning("Persistence worker did not stop gracefully") _persistence_worker.cancel() # Save all remaining dirty sessions if DIRTY_SESSIONS: logger.info(f"Saving {len(DIRTY_SESSIONS)} dirty sessions on shutdown") for session_id in list(DIRTY_SESSIONS): if session_id in SESSION_SHELLS: try: _save_session_state(session_id, SESSION_SHELLS[session_id]) logger.debug(f"Saved session {session_id} on shutdown") except Exception as e: logger.error(f"Failed to save session {session_id} on shutdown: {e}") logger.info("Session persistence shutdown complete") # Initialize database and restore data on module load _init_database() _restore_metadata_db() _pending_session_restores = _restore_active_sessions_db() class SessionMiddleware(Middleware): """Middleware that manages IPython shell sessions with optimized persistence. Features: - SQLite-based metadata and registry storage - In-memory session state with periodic snapshots - Background persistence worker - Dirty flag tracking to minimize I/O - Graceful shutdown with state preservation """ def __init__(self, max_sessions: int = 4): """Initialize session middleware with optimized persistence. 
Parameters ---------- max_sessions : int Maximum number of concurrent sessions to maintain (default: 4) """ self.max_sessions = max_sessions self._lock = asyncio.Lock() self._worker_started = False logger.info(f"SessionMiddleware initialized with max_sessions={max_sessions}") logger.debug(f"Sessions directory: {SESSIONS_DIR}") # Note: Worker will be started when event loop is available # Restore sessions that were active for session_id in _pending_session_restores: if session_id in SESSION_SHELLS: continue # Already restored logger.debug(f"Restoring shell for session: {session_id}") try: restored_shell = self._restore_session(session_id) if restored_shell: SESSION_SHELLS[session_id] = restored_shell logger.info(f"Successfully restored shell for session: {session_id}") else: # If restoration failed, create a fresh shell logger.info(f"Creating fresh shell for session: {session_id}") SESSION_SHELLS[session_id] = self._create_shell(session_id) except Exception as e: logger.error(f"Failed to restore session {session_id}: {e}") def _start_persistence_worker(self) -> None: """Start the background persistence worker.""" global _persistence_worker if _persistence_worker is None or _persistence_worker.done(): try: _persistence_worker = asyncio.create_task(_persistence_worker_impl()) self._worker_started = True logger.debug("Started background persistence worker") except RuntimeError as e: if "no running event loop" in str(e): logger.debug("Event loop not yet available, will start worker on first request") self._worker_started = False else: raise async def startup(self) -> None: """Initialize middleware during app startup when event loop is available.""" if not self._worker_started: self._start_persistence_worker() logger.info("Session middleware startup complete") async def on_request(self, context: MiddlewareContext, call_next): """Process MCP requests with optimized session management.""" # Start persistence worker if not already started (lazy initialization) if not self._worker_started: self._start_persistence_worker() is_initialize = context.method == "initialize" logger.info(f"Session middleware on_request: {context.method}") logger.info(f"Session middleware context: {context}") # Get session ID from context session_id = None if hasattr(context, 'fastmcp_context') and context.fastmcp_context: session_id = getattr(context.fastmcp_context, 'session_id', None) logger.debug(f"Processing {context.method}, session_id: {session_id}") # Handle initialize requests if is_initialize: if session_id and session_id in SESSION_SHELLS: # Existing session logger.info(f"Reusing existing session: {session_id}") SESSION_METADATA[session_id]['last_initialized'] = time.time() _save_session_metadata_db(session_id, SESSION_METADATA[session_id]) else: # New session session_id = generate_secure_session_id() logger.info(f"Generated new session ID: {session_id}") SESSION_METADATA[session_id] = { 'created_at': time.time(), 'last_initialized': time.time(), 'needs_header': True } _save_session_metadata_db(session_id, SESSION_METADATA[session_id]) # Ensure shell exists if session_id: async with self._lock: if session_id not in SESSION_SHELLS: logger.info(f"Creating new shell for session: {session_id}") SESSION_SHELLS[session_id] = self._create_shell(session_id) _save_active_session_db(session_id) # Check session limit if len(SESSION_SHELLS) > self.max_sessions: await self._evict_oldest_session() else: # Default session session_id = "default" if session_id not in SESSION_SHELLS: logger.info("Creating default shell session") 
SESSION_SHELLS[session_id] = self._create_shell(session_id) # Mark session as accessed if session_id != "default": _mark_session_dirty(session_id) # Call next middleware/handler result = await call_next(context) # Handle initialize response if is_initialize and session_id != "default" and SESSION_METADATA.get(session_id, {}).get('needs_header'): SESSION_METADATA[session_id]['needs_header'] = False _save_session_metadata_db(session_id, SESSION_METADATA[session_id]) if hasattr(result, '__dict__'): result._mcp_session_id = session_id logger.info(f"Marked session {session_id} for header injection") # Check if immediate save is needed based on request frequency if session_id and session_id != "default" and session_id in SESSION_SHELLS: if await _should_save_session_immediately(session_id): logger.debug(f"Immediate save triggered for session {session_id}") asyncio.create_task(_save_session_state_async(session_id, SESSION_SHELLS[session_id])) SESSION_REQUEST_COUNTS[session_id] = 0 # Reset counter DIRTY_SESSIONS.discard(session_id) # Remove from dirty set return result async def _evict_oldest_session(self): """Evict the oldest session when limit is exceeded.""" logger.warning(f"Session limit exceeded ({len(SESSION_SHELLS)} > {self.max_sessions})") # Find oldest non-default session oldest_session = None oldest_time = float('inf') for sid, metadata in SESSION_METADATA.items(): if sid != "default" and metadata.get('created_at', 0) < oldest_time: oldest_time = metadata['created_at'] oldest_session = sid if oldest_session and oldest_session in SESSION_SHELLS: evicted_shell = SESSION_SHELLS.pop(oldest_session) SESSION_METADATA.pop(oldest_session, None) logger.info(f"Evicted session due to limit: {oldest_session}") # Remove from database _remove_active_session_db(oldest_session) # Save evicted session state await _save_session_state_async(oldest_session, evicted_shell) def _create_shell(self, session_id: str) -> InteractiveShell: """Create a new IPython shell for a session.""" logger.debug(f"Creating shell for session: {session_id}") # Try to restore from disk first restored_shell = self._restore_session(session_id) if restored_shell: logger.info(f"Successfully restored shell from disk for session: {session_id}") return restored_shell logger.debug(f"Creating fresh IPython shell for session: {session_id}") shell = InteractiveShell.instance() shell.reset() # Setup shell configuration shell.run_line_magic("load_ext", "autoreload") shell.run_line_magic("autoreload", "2") shell.Completer = IPCompleter(shell=shell, use_jedi=False) # Configure DataFrame column completion def _df_column_matcher(text): import re import pandas as pd try: m = re.match(r"(.+?)\['([^\]]*$)", text) if not m: return None var, stub = m.groups() if var in shell.user_ns and isinstance(shell.user_ns[var], pd.DataFrame): cols = shell.user_ns[var].columns.astype(str) return [f"{var}['{c}']" for c in cols if c.startswith(stub)][:200] except Exception: return None shell.Completer.custom_matchers.append(_df_column_matcher) # Import default libraries shell.run_cell("import pandas as pd\nimport numpy as np\nimport polars as pl") # Push helper functions self._push_helper_functions(shell) logger.info(f"Successfully created fresh IPython shell for session: {session_id}") return shell def _restore_session(self, session_id: str) -> Optional[InteractiveShell]: """Restore a session from disk if it exists.""" session_file = SESSIONS_DIR / f"{session_id}.pkl" if not session_file.exists(): logger.debug(f"No saved session file found for session: 
{session_id}") return None logger.info(f"Attempting to restore session from disk: {session_id}") try: with open(session_file, "rb") as f: state = dill.load(f) shell = InteractiveShell.instance() shell.reset() # Setup shell configuration shell.run_line_magic("load_ext", "autoreload") shell.run_line_magic("autoreload", "2") shell.Completer = IPCompleter(shell=shell, use_jedi=False) # Restore user namespace user_ns = state.get("user_ns", {}) logger.debug(f"Restoring {len(user_ns)} variables for session: {session_id}") shell.user_ns.update(user_ns) self._push_helper_functions(shell) logger.info(f"Successfully restored session from disk: {session_id}") return shell except Exception as e: logger.error(f"Failed to restore session {session_id}: {e}") try: session_file.unlink() logger.debug(f"Removed corrupted session file: {session_file}") except Exception: pass return None def _push_helper_functions(self, shell: InteractiveShell) -> None: """Push helper functions needed by tools into the shell.""" logger.debug("Loading helper functions for shell") try: from sherlog_mcp.tools.cli_tools import _search_pypi_impl, _query_apt_package_status_impl cli_funcs = { "_search_pypi_impl": _search_pypi_impl, "_query_apt_package_status_impl": _query_apt_package_status_impl, } except ImportError as e: logger.warning(f"Failed to import CLI tools: {e}") cli_funcs = {} try: from sherlog_mcp.tools.code_retrieval import ( _find_method_implementation_impl, _find_class_implementation_impl, _list_all_methods_impl, _list_all_classes_impl, _get_codebase_stats_impl, ) code_retrieval_funcs = { "_find_method_implementation_impl": _find_method_implementation_impl, "_find_class_implementation_impl": _find_class_implementation_impl, "_list_all_methods_impl": _list_all_methods_impl, "_list_all_classes_impl": _list_all_classes_impl, "_get_codebase_stats_impl": _get_codebase_stats_impl, } except ImportError as e: logger.warning(f"Failed to import code retrieval tools: {e}") code_retrieval_funcs = {} try: from sherlog_mcp.tools.external_mcp_tools import convert_to_dataframe external_funcs = {"convert_to_dataframe": convert_to_dataframe} except ImportError as e: logger.warning(f"Failed to import external MCP tools: {e}") external_funcs = {} all_funcs = {**cli_funcs, **code_retrieval_funcs, **external_funcs} if all_funcs: shell.push(all_funcs) logger.info(f"Successfully pushed {len(all_funcs)} helper functions to shell") def get_session_shell(session_id: str) -> Optional[InteractiveShell]: """Get shell for a session ID.""" if not session_id: session_id = "default" logger.debug(f"Retrieving shell for session: {session_id}") return SESSION_SHELLS.get(session_id)
