Skip to main content
Glama

Adversary MCP Server

by brettbergin
test_session_persistence_coverage.py (18.5 kB)
"""Additional tests for session persistence to improve coverage.""" import tempfile import time from pathlib import Path from unittest.mock import Mock, patch import pytest from adversary_mcp_server.database.models import AdversaryDatabase from adversary_mcp_server.scanner.types import Severity from adversary_mcp_server.session.session_persistence import SessionPersistenceStore from adversary_mcp_server.session.session_types import ( AnalysisSession, ConversationMessage, SecurityFinding, SessionState, ) class TestSessionPersistenceCoverage: """Additional tests to improve coverage of session persistence.""" @pytest.fixture def temp_db_path(self): """Create a temporary database for testing.""" with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: db_path = Path(f.name) yield db_path # Cleanup if db_path.exists(): db_path.unlink() @pytest.fixture def session_store(self, temp_db_path): """Create a session persistence store for testing.""" database = AdversaryDatabase(temp_db_path) return SessionPersistenceStore( database=database, max_active_sessions=3, session_timeout_hours=1, cleanup_interval_minutes=1, ) def test_store_session_exception_handling(self, session_store): """Test exception handling in store_session.""" session = AnalysisSession() session.update_state(SessionState.READY) # Mock database persistence to raise exception with patch.object( session_store, "_persist_session_to_database", side_effect=Exception("DB Error"), ): result = session_store.store_session(session) assert result is False def test_get_session_exception_handling(self, session_store): """Test exception handling in get_session.""" # Mock _active_sessions to raise exception with patch.object( session_store, "_active_sessions", side_effect=Exception("Memory Error") ): result = session_store.get_session("test_id") assert result is None def test_remove_session_exception_handling(self, session_store): """Test exception handling in remove_session.""" # Test case where exception occurs 
during the removal process with patch.object( session_store, "_active_sessions", side_effect=Exception("Memory Error") ): result = session_store.remove_session("test_id") assert result is False def test_list_active_sessions_exception_handling(self, session_store): """Test exception handling in list_active_sessions.""" with patch.object( session_store, "_active_sessions", side_effect=Exception("Memory Error") ): result = session_store.list_active_sessions() assert result == [] def test_cleanup_expired_sessions_exception_handling(self, session_store): """Test exception handling in cleanup_expired_sessions.""" with patch.object( session_store, "_active_sessions", side_effect=Exception("Memory Error") ): result = session_store.cleanup_expired_sessions() assert result == 0 def test_get_statistics_exception_handling(self, session_store): """Test exception handling in get_statistics.""" with patch.object( session_store, "_list_database_sessions", side_effect=Exception("DB Error") ): result = session_store.get_statistics() assert result == {} def test_list_active_sessions_with_database_cleanup(self, session_store): """Test list_active_sessions with database session cleanup.""" # Create a session that will be expired expired_session = AnalysisSession() expired_session.update_state(SessionState.READY) expired_session.last_activity = time.time() - 7200 # 2 hours ago # Store it so it's in database but not active memory session_store.store_session(expired_session) # Clear from memory to simulate it being only in database session_store._active_sessions.clear() session_store._session_access_times.clear() # Mock database session list to return our expired session with patch.object( session_store, "_list_database_sessions", return_value=[expired_session.session_id], ): # Mock load to return the expired session with patch.object( session_store, "_load_session_from_database", return_value=expired_session, ): active_sessions = session_store.list_active_sessions() # Should be empty 
because session is expired and should be removed assert len(active_sessions) == 0 def test_cleanup_expired_sessions_database_only(self, session_store): """Test cleanup of expired sessions that are only in database.""" # Create expired session expired_session = AnalysisSession() expired_session.update_state(SessionState.READY) expired_session.last_activity = time.time() - 7200 # 2 hours ago # Store it then clear from memory session_store.store_session(expired_session) session_store._active_sessions.clear() session_store._session_access_times.clear() # Mock database methods with patch.object( session_store, "_list_database_sessions", return_value=[expired_session.session_id], ): with patch.object( session_store, "_load_session_from_database", return_value=expired_session, ): with patch.object( session_store, "_remove_session_from_database", return_value=True ) as mock_remove: cleaned_count = session_store.cleanup_expired_sessions() assert cleaned_count >= 1 mock_remove.assert_called_with(expired_session.session_id) def test_persist_session_to_database_exception(self, session_store): """Test exception handling in _persist_session_to_database.""" session = AnalysisSession() session.update_state(SessionState.READY) # Mock database session to raise exception with patch.object( session_store.database, "get_session", side_effect=Exception("DB Connection Error"), ): result = session_store._persist_session_to_database(session) assert result is False def test_load_session_from_database_exception(self, session_store): """Test exception handling in _load_session_from_database.""" with patch.object( session_store.database, "get_session", side_effect=Exception("DB Connection Error"), ): result = session_store._load_session_from_database("test_id") assert result is None def test_load_session_from_database_no_session(self, session_store): """Test loading non-existent session from database.""" # Mock database session that returns None mock_db_session = Mock() mock_query = Mock() 
mock_query.filter_by.return_value.first.return_value = None mock_db_session.query.return_value = mock_query with patch.object( session_store.database, "get_session", return_value=mock_db_session ): result = session_store._load_session_from_database("nonexistent") assert result is None def test_remove_session_from_database_exception(self, session_store): """Test exception handling in _remove_session_from_database.""" with patch.object( session_store.database, "get_session", side_effect=Exception("DB Connection Error"), ): result = session_store._remove_session_from_database("test_id") assert result is False def test_remove_session_from_database_no_session(self, session_store): """Test removing non-existent session from database.""" # Mock database session that returns None for the query mock_db_session = Mock() mock_query = Mock() mock_query.filter_by.return_value.first.return_value = None mock_db_session.query.return_value = mock_query with patch.object( session_store.database, "get_session", return_value=mock_db_session ): result = session_store._remove_session_from_database("nonexistent") assert result is False def test_list_database_sessions_exception(self, session_store): """Test exception handling in _list_database_sessions.""" with patch.object( session_store.database, "get_session", side_effect=Exception("DB Connection Error"), ): result = session_store._list_database_sessions() assert result == [] def test_serialize_session_context_with_project_context(self, session_store): """Test serialization of session context with ProjectContext.""" # Mock ProjectContext class mock_project_context = Mock() mock_project_context.__class__.__name__ = "ProjectContext" session_context = { "project": mock_project_context, "timestamp": time.time(), "serializable_data": {"key": "value"}, } with patch.object( session_store, "_serialize_project_context", return_value={"serialized": "context"}, ): result = session_store._serialize_session_context(session_context) assert 
"timestamp" in result assert "serializable_data" in result assert result["serializable_data"] == {"key": "value"} def test_serialize_session_context_non_serializable(self, session_store): """Test serialization of session context with non-serializable data.""" # Create non-serializable object class NonSerializable: def __init__(self): self.data = "test" session_context = {"non_serializable": NonSerializable(), "normal_data": "test"} result = session_store._serialize_session_context(session_context) assert "normal_data" in result assert result["normal_data"] == "test" # Non-serializable should be converted to string assert "non_serializable" in result assert isinstance(result["non_serializable"], str) def test_serialize_session_context_empty(self, session_store): """Test serialization of empty session context.""" result = session_store._serialize_session_context(None) assert result == {} result = session_store._serialize_session_context({}) assert result == {} def test_serialize_project_context_with_dict_nested(self, session_store): """Test serialization of project context with nested dict structure.""" # Create a dict with nested ProjectContext (simulated) context_dict = {"loaded_at": time.time(), "serializable_data": {"key": "value"}} # Test serialization of dict that's not a ProjectContext result = session_store._serialize_project_context(context_dict) assert "loaded_at" in result assert "serializable_data" in result assert result["serializable_data"] == {"key": "value"} def test_serialize_project_context_serialization_error(self, session_store): """Test serialization error handling in _serialize_project_context.""" # Create a dict that will cause serialization issues problematic_data = {"data": object()} # Non-serializable # Mock json.dumps to raise TypeError for non-serializable object with patch("json.dumps", side_effect=TypeError("Not serializable")): result = session_store._serialize_project_context(problematic_data) # Should convert to string representation 
assert isinstance(result["data"], str) def test_serialize_project_context_exception(self, session_store): """Test exception handling in _serialize_project_context.""" # Mock to raise general exception with patch("adversary_mcp_server.session.session_persistence.logger"): # Use a context that will cause an ImportError when checking ProjectContext result = session_store._serialize_project_context("invalid_context") assert result == "invalid_context" def test_deserialize_project_context_none(self, session_store): """Test deserialization of None project context.""" result = session_store._deserialize_project_context(None) assert result is None def test_deserialize_project_context_exception(self, session_store): """Test exception handling in _deserialize_project_context.""" context_data = {"project_root": "/test"} # Mock to raise ImportError when importing ProjectContext with patch("adversary_mcp_server.session.session_persistence.logger"): # This should trigger the exception handler with patch( "builtins.__import__", side_effect=ImportError("Module not found") ): result = session_store._deserialize_project_context(context_data) # Should return the original data when deserialization fails assert result == context_data def test_deserialize_project_context_not_expected_format(self, session_store): """Test deserialization of project context that's not in expected format.""" # Data without project_root key context_data = {"some_other_field": "value"} result = session_store._deserialize_project_context(context_data) # Should return as-is if not in expected format assert result == context_data def test_periodic_cleanup_triggers(self, session_store): """Test that periodic cleanup triggers when interval passes.""" # Set last cleanup time to be old enough to trigger cleanup session_store._last_cleanup_time = time.time() - 3700 # Over an hour ago # Create an expired session to clean up expired_session = AnalysisSession() expired_session.update_state(SessionState.READY) 
expired_session.last_activity = time.time() - 7200 # 2 hours ago session_store._active_sessions[expired_session.session_id] = expired_session session_store._session_access_times[expired_session.session_id] = ( time.time() - 7200 ) # Call periodic cleanup initial_cleanup_time = session_store._last_cleanup_time session_store._periodic_cleanup() # Cleanup time should be updated assert session_store._last_cleanup_time > initial_cleanup_time def test_session_with_findings_and_messages_serialization(self, session_store): """Test serialization of session with findings and messages.""" session = AnalysisSession() session.update_state(SessionState.READY) # Add message message = ConversationMessage( role="user", content="Test message", metadata={"test": "data"} ) session.messages.append(message) # Add finding finding = SecurityFinding( uuid="test-finding", rule_id="test-rule", rule_name="Test Rule", description="Test finding", severity=Severity.HIGH, file_path="/test/file.py", line_number=10, confidence=0.9, session_context={"test": "context"}, ) session.findings.append(finding) # Store and retrieve assert session_store.store_session(session) is True retrieved = session_store.get_session(session.session_id) assert retrieved is not None assert len(retrieved.messages) == 1 assert retrieved.messages[0].content == "Test message" assert len(retrieved.findings) == 1 assert retrieved.findings[0].rule_id == "test-rule" def test_get_session_expired_cleanup(self, session_store): """Test that expired sessions are cleaned up when retrieved.""" # Create expired session expired_session = AnalysisSession() expired_session.update_state(SessionState.READY) expired_session.last_activity = time.time() - 7200 # 2 hours ago # Add to memory cache directly session_store._active_sessions[expired_session.session_id] = expired_session session_store._session_access_times[expired_session.session_id] = time.time() # Try to get it - should be cleaned up and return None result = 
session_store.get_session(expired_session.session_id) assert result is None # Should be removed from memory assert expired_session.session_id not in session_store._active_sessions def test_memory_enforcement_with_lru(self, session_store): """Test that LRU eviction works correctly.""" sessions = [] # Create sessions up to the limit for i in range(session_store.max_active_sessions): session = AnalysisSession(metadata={"index": i}) session.update_state(SessionState.READY) sessions.append(session) session_store.store_session(session) time.sleep(0.01) # Small delay to ensure different access times # Access the first session to make it more recently used session_store.get_session(sessions[0].session_id) # Add one more session to trigger eviction new_session = AnalysisSession(metadata={"index": "new"}) new_session.update_state(SessionState.READY) session_store.store_session(new_session) # First session should still be in memory (was accessed recently) assert sessions[0].session_id in session_store._active_sessions # New session should be in memory assert new_session.session_id in session_store._active_sessions # Memory should not exceed limit assert len(session_store._active_sessions) <= session_store.max_active_sessions

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/brettbergin/adversary-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.