Adversary MCP Server

by brettbergin
session_persistence.py (25 kB)
"""Persistent session storage with scalability improvements.""" import json import time from pathlib import Path from typing import Any from ..database.models import AdversaryDatabase from ..logger import get_logger from .session_types import AnalysisSession, SessionState logger = get_logger("session_persistence") class SessionPersistenceStore: """ Persistent storage for LLM analysis sessions with scalability improvements. Features: - Automatic session persistence to database - Session expiration and cleanup - Memory management with configurable limits - Fault-tolerant recovery - Bulk operations for performance """ def __init__( self, database: AdversaryDatabase, max_active_sessions: int = 50, session_timeout_hours: int = 6, cleanup_interval_minutes: int = 30, ): """ Initialize session persistence store. Args: database: Database connection for persistence max_active_sessions: Maximum sessions to keep in memory session_timeout_hours: Hours before session expires cleanup_interval_minutes: Minutes between cleanup operations """ self.database = database self.max_active_sessions = max_active_sessions self.session_timeout_seconds = session_timeout_hours * 3600 self.cleanup_interval_seconds = cleanup_interval_minutes * 60 # In-memory cache for active sessions (LRU-style) self._active_sessions: dict[str, AnalysisSession] = {} self._session_access_times: dict[str, float] = {} self._last_cleanup_time = time.time() logger.info( f"Session persistence initialized: max_sessions={max_active_sessions}, " f"timeout={session_timeout_hours}h, cleanup_interval={cleanup_interval_minutes}m" ) def store_session(self, session: AnalysisSession) -> bool: """ Store session both in memory and persistent storage. Args: session: Session to store Returns: True if successfully stored """ try: # Only update last activity if session is not expired # This preserves the original last_activity for testing expired sessions if not session.is_expired(self.session_timeout_seconds): session.last_activity = time.time() # Store in memory cache self._active_sessions[session.session_id] = session self._session_access_times[session.session_id] = time.time() # Enforce memory limits self._enforce_memory_limits() # Persist to database self._persist_session_to_database(session) # Periodic cleanup self._periodic_cleanup() logger.debug(f"Stored session {session.session_id}") return True except Exception as e: logger.error(f"Failed to store session {session.session_id}: {e}") return False def get_session(self, session_id: str) -> AnalysisSession | None: """ Retrieve session from memory or persistent storage. 
Args: session_id: Session ID to retrieve Returns: Session if found and not expired, None otherwise """ try: # Check memory cache first if session_id in self._active_sessions: session = self._active_sessions[session_id] # Check if expired if session.is_expired(self.session_timeout_seconds): self.remove_session(session_id) return None # Update access time self._session_access_times[session_id] = time.time() logger.debug(f"Retrieved session {session_id} from memory") return session # Try to load from database session = self._load_session_from_database(session_id) if session and not session.is_expired(self.session_timeout_seconds): # Store back in memory cache self._active_sessions[session_id] = session self._session_access_times[session_id] = time.time() self._enforce_memory_limits() logger.debug(f"Retrieved session {session_id} from database") return session elif session: # Session exists but is expired, clean it up self._remove_session_from_database(session_id) logger.debug(f"Session {session_id} expired, removed from database") return None except Exception as e: logger.error(f"Failed to retrieve session {session_id}: {e}") return None def remove_session(self, session_id: str) -> bool: """ Remove session from both memory and persistent storage. Args: session_id: Session ID to remove Returns: True if successfully removed """ try: removed = False # Remove from memory if session_id in self._active_sessions: del self._active_sessions[session_id] del self._session_access_times[session_id] removed = True # Remove from database if self._remove_session_from_database(session_id): removed = True if removed: logger.debug(f"Removed session {session_id}") return removed except Exception as e: logger.error(f"Failed to remove session {session_id}: {e}") return False def list_active_sessions(self) -> list[str]: """ Get list of active (non-expired) session IDs. Returns: List of session IDs """ try: active_sessions = [] current_time = time.time() # Check memory sessions for session_id, session in list(self._active_sessions.items()): if not session.is_expired(self.session_timeout_seconds): active_sessions.append(session_id) else: # Remove expired session self.remove_session(session_id) # Also check database for sessions not in memory database_sessions = self._list_database_sessions() for session_id in database_sessions: if session_id not in active_sessions: session = self._load_session_from_database(session_id) if session and not session.is_expired(self.session_timeout_seconds): active_sessions.append(session_id) elif session: # Expired session in database self._remove_session_from_database(session_id) logger.debug(f"Found {len(active_sessions)} active sessions") return active_sessions except Exception as e: logger.error(f"Failed to list active sessions: {e}") return [] def cleanup_expired_sessions(self) -> int: """ Clean up expired sessions from both memory and database. 
Returns: Number of sessions cleaned up """ try: cleaned_count = 0 current_time = time.time() # Clean up memory sessions expired_memory_sessions = [] for session_id, session in self._active_sessions.items(): if session.is_expired(self.session_timeout_seconds): expired_memory_sessions.append(session_id) for session_id in expired_memory_sessions: if self.remove_session(session_id): cleaned_count += 1 # Clean up database sessions database_sessions = self._list_database_sessions() for session_id in database_sessions: if session_id not in self._active_sessions: session = self._load_session_from_database(session_id) if session and session.is_expired(self.session_timeout_seconds): if self._remove_session_from_database(session_id): cleaned_count += 1 if cleaned_count > 0: logger.info(f"Cleaned up {cleaned_count} expired sessions") return cleaned_count except Exception as e: logger.error(f"Failed to cleanup expired sessions: {e}") return 0 def get_statistics(self) -> dict[str, Any]: """ Get session storage statistics. Returns: Dictionary with storage statistics """ try: return { "active_sessions_memory": len(self._active_sessions), "total_sessions_database": len(self._list_database_sessions()), "max_active_sessions": self.max_active_sessions, "session_timeout_seconds": self.session_timeout_seconds, "memory_utilization": len(self._active_sessions) / self.max_active_sessions, "last_cleanup_time": self._last_cleanup_time, } except Exception as e: logger.error(f"Failed to get statistics: {e}") return {} def _enforce_memory_limits(self) -> None: """Enforce memory limits by evicting least recently used sessions.""" while len(self._active_sessions) > self.max_active_sessions: # Find least recently used session lru_session_id = min( self._session_access_times.keys(), key=lambda sid: self._session_access_times[sid], ) # Move to database only (remove from memory) session = self._active_sessions[lru_session_id] self._persist_session_to_database(session) del self._active_sessions[lru_session_id] del self._session_access_times[lru_session_id] logger.debug( f"Evicted session {lru_session_id} from memory to enforce limits" ) def _periodic_cleanup(self) -> None: """Perform periodic cleanup if enough time has passed.""" current_time = time.time() if current_time - self._last_cleanup_time > self.cleanup_interval_seconds: self.cleanup_expired_sessions() self._last_cleanup_time = current_time def _persist_session_to_database(self, session: AnalysisSession) -> bool: """Persist session to database using ORM.""" try: with self.database.get_session() as db_session: from ..database.models import LLMAnalysisSession # Check if session already exists existing = ( db_session.query(LLMAnalysisSession) .filter_by(session_id=session.session_id) .first() ) # Serialize session data session_data = { "session_id": session.session_id, "state": session.state.value, "project_root": ( str(session.project_root) if session.project_root else None ), "created_at": session.created_at, "last_activity": session.last_activity, "metadata": session.metadata, "messages": [ { "role": msg.role, "content": msg.content, "timestamp": msg.timestamp, "metadata": msg.metadata, } for msg in session.messages ], "findings": [ { "uuid": finding.uuid, "rule_id": finding.rule_id, "rule_name": finding.rule_name, "description": finding.description, "severity": ( finding.severity.value if hasattr(finding.severity, "value") else str(finding.severity) ), "file_path": finding.file_path, "line_number": finding.line_number, "confidence": finding.confidence, 
"session_context": self._serialize_session_context( finding.session_context ), } for finding in session.findings ], "project_context": self._serialize_project_context( session.project_context ), } if existing: # Update existing session existing.session_data = json.dumps(session_data) existing.last_activity = session.last_activity existing.state = session.state.value else: # Create new session new_session = LLMAnalysisSession( session_id=session.session_id, session_data=json.dumps(session_data), created_at=session.created_at, last_activity=session.last_activity, state=session.state.value, ) db_session.add(new_session) db_session.commit() return True except Exception as e: logger.error(f"Failed to persist session to database: {e}") return False def _load_session_from_database(self, session_id: str) -> AnalysisSession | None: """Load session from database using ORM.""" try: with self.database.get_session() as db_session: from ..database.models import LLMAnalysisSession db_session_obj = ( db_session.query(LLMAnalysisSession) .filter_by(session_id=session_id) .first() ) if not db_session_obj: return None session_data = json.loads(str(db_session_obj.session_data)) # Reconstruct session object session = AnalysisSession( session_id=session_data["session_id"], state=SessionState(session_data["state"]), project_root=( Path(session_data["project_root"]) if session_data["project_root"] else None ), created_at=session_data["created_at"], last_activity=session_data["last_activity"], metadata=session_data["metadata"], project_context=self._deserialize_project_context( session_data["project_context"] ), ) # Reconstruct messages from .session_types import ConversationMessage for msg_data in session_data["messages"]: msg = ConversationMessage( role=msg_data["role"], content=msg_data["content"], timestamp=msg_data["timestamp"], metadata=msg_data["metadata"], ) session.messages.append(msg) # Reconstruct findings from ..scanner.types import Severity from .session_types import SecurityFinding for finding_data in session_data["findings"]: finding = SecurityFinding( uuid=finding_data["uuid"], rule_id=finding_data["rule_id"], rule_name=finding_data["rule_name"], description=finding_data["description"], severity=( Severity(finding_data["severity"]) if finding_data["severity"] in [s.value for s in Severity] else Severity.MEDIUM ), file_path=finding_data["file_path"], line_number=finding_data["line_number"], confidence=finding_data["confidence"], session_context=finding_data["session_context"], ) session.findings.append(finding) return session except Exception as e: logger.error(f"Failed to load session from database: {e}") return None def _remove_session_from_database(self, session_id: str) -> bool: """Remove session from database using ORM.""" try: with self.database.get_session() as db_session: from ..database.models import LLMAnalysisSession db_session_obj = ( db_session.query(LLMAnalysisSession) .filter_by(session_id=session_id) .first() ) if db_session_obj: db_session.delete(db_session_obj) db_session.commit() return True return False except Exception as e: logger.error(f"Failed to remove session from database: {e}") return False def _list_database_sessions(self) -> list[str]: """Get list of session IDs from database using ORM.""" try: with self.database.get_session() as db_session: from ..database.models import LLMAnalysisSession session_ids = db_session.query(LLMAnalysisSession.session_id).all() return [session_id[0] for session_id in session_ids] except Exception as e: logger.error(f"Failed to list database 
sessions: {e}") return [] def _serialize_session_context( self, session_context: dict[str, Any] ) -> dict[str, Any]: """Serialize session context dict, handling ProjectContext objects.""" if not session_context: return {} serialized = {} for key, value in session_context.items(): try: from .project_context import ProjectContext if isinstance(value, ProjectContext): serialized[key] = self._serialize_project_context(value) else: # For other types, try to serialize directly import json json.dumps(value) # Test if serializable serialized[key] = value except (TypeError, ImportError): # If not serializable, convert to string serialized[key] = str(value) return serialized def _serialize_project_context(self, project_context: Any) -> dict[str, Any] | None: """Serialize ProjectContext to JSON-compatible format.""" if project_context is None: return None try: from .project_context import ProjectContext if isinstance(project_context, ProjectContext): return { "project_root": str(project_context.project_root), "project_type": project_context.project_type, "structure_overview": project_context.structure_overview, "key_files": [ { "path": str(f.path), "language": f.language, "size_bytes": f.size_bytes, "priority_score": f.priority_score, "security_relevance": f.security_relevance, "content_preview": f.content_preview, "is_entry_point": f.is_entry_point, "is_config": f.is_config, "is_security_critical": f.is_security_critical, } for f in project_context.key_files ], "security_modules": project_context.security_modules, "entry_points": project_context.entry_points, "dependencies": project_context.dependencies, "architecture_summary": project_context.architecture_summary, "total_files": project_context.total_files, "total_size_bytes": project_context.total_size_bytes, "languages_used": list(project_context.languages_used), "estimated_tokens": project_context.estimated_tokens, } elif isinstance(project_context, dict): # Handle nested structure like {"context": ProjectContext, "loaded_at": time, ...} serialized = {} for key, value in project_context.items(): if isinstance(value, ProjectContext): serialized[key] = self._serialize_project_context(value) else: # For other types, try to serialize directly import json try: json.dumps(value) # Test if serializable serialized[key] = value except (TypeError, ValueError): # If not serializable, convert to string serialized[key] = str(value) return serialized else: # If it's already a dict or other serializable type return project_context except Exception as e: logger.warning(f"Failed to serialize project context: {e}") return None def _deserialize_project_context(self, context_data: dict[str, Any] | None) -> Any: """Deserialize ProjectContext from JSON-compatible format.""" if context_data is None: return None try: from .project_context import ProjectContext, ProjectFile if isinstance(context_data, dict) and "project_root" in context_data: key_files = [] for file_data in context_data.get("key_files", []): key_files.append( ProjectFile( path=Path(file_data["path"]), language=file_data["language"], size_bytes=file_data["size_bytes"], priority_score=file_data.get("priority_score", 0.0), security_relevance=file_data.get("security_relevance", 0.0), content_preview=file_data.get("content_preview", ""), is_entry_point=file_data.get("is_entry_point", False), is_config=file_data.get("is_config", False), is_security_critical=file_data.get( "is_security_critical", False ), ) ) return ProjectContext( project_root=Path(context_data["project_root"]), 
project_type=context_data.get("project_type", "unknown"), structure_overview=context_data.get("structure_overview", ""), key_files=key_files, security_modules=context_data.get("security_modules", []), entry_points=context_data.get("entry_points", []), dependencies=context_data.get("dependencies", []), architecture_summary=context_data.get("architecture_summary", ""), total_files=context_data.get("total_files", 0), total_size_bytes=context_data.get("total_size_bytes", 0), languages_used=set(context_data.get("languages_used", [])), estimated_tokens=context_data.get("estimated_tokens", 0), ) else: # Return as-is if it's not our expected format return context_data except Exception as e: logger.warning(f"Failed to deserialize project context: {e}") return context_data
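
Taken together, SessionPersistenceStore behaves like a write-through LRU cache in front of the ORM: store_session() always persists to the database, so _enforce_memory_limits() can evict the least recently used session from memory without data loss, and get_session() transparently reloads evicted sessions. Below is a minimal usage sketch, not part of the module; the import paths and the AdversaryDatabase / AnalysisSession constructor arguments are assumptions about sibling modules not shown on this page.

# Hedged usage sketch: import paths and constructor signatures are assumed.
from adversary_mcp_server.database.models import AdversaryDatabase
from adversary_mcp_server.session.session_persistence import SessionPersistenceStore
from adversary_mcp_server.session.session_types import AnalysisSession

database = AdversaryDatabase()  # hypothetical: real constructor args may differ
store = SessionPersistenceStore(
    database,
    max_active_sessions=10,      # small in-memory LRU cache
    session_timeout_hours=1,     # expire idle sessions after an hour
    cleanup_interval_minutes=5,  # sweep expired sessions during writes
)

session = AnalysisSession(session_id="demo-session")  # hypothetical constructor args
store.store_session(session)   # write-through: memory cache + database

restored = store.get_session("demo-session")  # memory hit, or reload from database
print(store.get_statistics()["memory_utilization"])
store.cleanup_expired_sessions()  # manual sweep; also runs via _periodic_cleanup()

Because eviction re-persists a session before dropping it from memory, a cache miss in get_session() affects only latency, costing one database read plus a JSON round-trip.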
