Skip to main content
Glama

Adversary MCP Server

by brettbergin
test_services.py (23.8 kB)
"""Comprehensive tests for domain services.""" from datetime import UTC, datetime from unittest.mock import AsyncMock, Mock import pytest from adversary_mcp_server.domain.entities.scan_request import ScanRequest from adversary_mcp_server.domain.entities.scan_result import ScanResult from adversary_mcp_server.domain.entities.threat_match import ThreatMatch from adversary_mcp_server.domain.interfaces import ( ConfigurationError, IScanStrategy, IValidationStrategy, SecurityError, ValidationError, ) from adversary_mcp_server.domain.services.scan_orchestrator import ScanOrchestrator from adversary_mcp_server.domain.services.threat_aggregator import ( FingerprintBasedAggregationStrategy, HybridAggregationStrategy, ProximityBasedAggregationStrategy, ThreatAggregator, ) from adversary_mcp_server.domain.services.validation_service import ValidationService from adversary_mcp_server.domain.value_objects.confidence_score import ConfidenceScore from adversary_mcp_server.domain.value_objects.file_path import FilePath from adversary_mcp_server.domain.value_objects.scan_context import ScanContext from adversary_mcp_server.domain.value_objects.scan_metadata import ScanMetadata from adversary_mcp_server.domain.value_objects.severity_level import SeverityLevel class TestScanOrchestrator: """Test ScanOrchestrator service.""" @pytest.fixture def orchestrator(self): """Create ScanOrchestrator instance.""" return ScanOrchestrator() @pytest.fixture def mock_scan_strategy(self): """Create mock scan strategy.""" strategy = Mock(spec=IScanStrategy) strategy.get_strategy_name.return_value = "mock_semgrep" strategy.can_scan.return_value = True strategy.execute_scan = AsyncMock() return strategy @pytest.fixture def mock_validation_strategy(self): """Create mock validation strategy.""" strategy = Mock(spec=IValidationStrategy) strategy.get_strategy_name.return_value = "mock_llm_validator" strategy.can_validate.return_value = True strategy.validate_threats = AsyncMock() return strategy 
@pytest.fixture def sample_scan_request(self): """Create sample scan request.""" file_path = FilePath.from_virtual("test.py") metadata = ScanMetadata( scan_id="test-scan-123", scan_type="file", timestamp=datetime.now(UTC), requester="test-user", ) context = ScanContext(target_path=file_path, metadata=metadata) return ScanRequest(context=context) @pytest.fixture def sample_threat(self): """Create sample threat.""" return ThreatMatch( rule_id="test-rule-001", rule_name="Test Rule", description="Test vulnerability", category="injection", severity=SeverityLevel.from_string("high"), file_path=FilePath.from_string("/home/user/test.py"), line_number=42, column_number=10, code_snippet="test code", confidence=ConfidenceScore(0.8), ) def test_register_scan_strategy(self, orchestrator, mock_scan_strategy): """Test registering scan strategy.""" orchestrator.register_scan_strategy(mock_scan_strategy) strategies = orchestrator.get_registered_strategies() assert "mock_semgrep" in strategies["scan_strategies"] def test_register_validation_strategy(self, orchestrator, mock_validation_strategy): """Test registering validation strategy.""" orchestrator.register_validation_strategy(mock_validation_strategy) strategies = orchestrator.get_registered_strategies() assert "mock_llm_validator" in strategies["validation_strategies"] def test_set_threat_aggregator(self, orchestrator): """Test setting threat aggregator.""" aggregator = Mock() aggregator.get_aggregation_strategy_name.return_value = "hybrid" orchestrator.set_threat_aggregator(aggregator) strategies = orchestrator.get_registered_strategies() assert strategies["threat_aggregator"] == ["hybrid"] @pytest.mark.asyncio async def test_execute_scan_no_strategies(self, orchestrator, sample_scan_request): """Test executing scan with no strategies raises error.""" with pytest.raises( ConfigurationError, match="No registered strategies can scan" ): await orchestrator.execute_scan(sample_scan_request) @pytest.mark.asyncio async def 
test_execute_scan_successful( self, orchestrator, mock_scan_strategy, sample_scan_request, sample_threat ): """Test successful scan execution.""" # Setup mock strategy to return a result mock_result = ScanResult.create_from_threats( request=sample_scan_request, threats=[sample_threat], scan_metadata={"scanner": "mock_semgrep"}, ) mock_scan_strategy.execute_scan.return_value = mock_result orchestrator.register_scan_strategy(mock_scan_strategy) result = await orchestrator.execute_scan(sample_scan_request) assert result is not None assert len(result.threats) == 1 assert result.threats[0] == sample_threat mock_scan_strategy.execute_scan.assert_called_once_with(sample_scan_request) @pytest.mark.asyncio async def test_execute_scan_with_validation( self, orchestrator, mock_scan_strategy, mock_validation_strategy, sample_scan_request, sample_threat, ): """Test scan execution with validation.""" # Setup mock strategy to return a result mock_result = ScanResult.create_from_threats( request=sample_scan_request, threats=[sample_threat], scan_metadata={"scanner": "mock_semgrep"}, ) mock_scan_strategy.execute_scan.return_value = mock_result # Setup validation to return filtered threats validated_threat = sample_threat.update_confidence(ConfidenceScore(0.9)) mock_validation_strategy.validate_threats.return_value = [validated_threat] orchestrator.register_scan_strategy(mock_scan_strategy) orchestrator.register_validation_strategy(mock_validation_strategy) result = await orchestrator.execute_scan(sample_scan_request) assert result is not None assert len(result.threats) == 1 assert result.threats[0].confidence == ConfidenceScore(0.9) mock_validation_strategy.validate_threats.assert_called_once() @pytest.mark.asyncio async def test_execute_scan_strategy_failure( self, orchestrator, mock_scan_strategy, sample_scan_request ): """Test scan execution with strategy failure.""" # Setup mock strategy to fail mock_scan_strategy.execute_scan.side_effect = Exception("Strategy failed") 
orchestrator.register_scan_strategy(mock_scan_strategy) result = await orchestrator.execute_scan(sample_scan_request) # Should return empty result instead of failing assert result is not None assert len(result.threats) == 0 def test_can_execute_scan( self, orchestrator, mock_scan_strategy, sample_scan_request ): """Test can_execute_scan method.""" # No strategies - cannot execute assert not orchestrator.can_execute_scan(sample_scan_request) # With strategy - can execute orchestrator.register_scan_strategy(mock_scan_strategy) assert orchestrator.can_execute_scan(sample_scan_request) @pytest.mark.asyncio async def test_severity_filtering( self, orchestrator, mock_scan_strategy, sample_scan_request ): """Test severity filtering.""" # Create threats with different severities high_threat = ThreatMatch( rule_id="high-rule", rule_name="High Rule", description="High severity", category="injection", severity=SeverityLevel.from_string("high"), file_path=FilePath.from_string("/home/user/test.py"), line_number=42, column_number=10, code_snippet="test", confidence=ConfidenceScore(0.8), ) low_threat = ThreatMatch( rule_id="low-rule", rule_name="Low Rule", description="Low severity", category="disclosure", severity=SeverityLevel.from_string("low"), file_path=FilePath.from_string("/home/user/test.py"), line_number=50, column_number=5, code_snippet="test", confidence=ConfidenceScore(0.6), ) mock_result = ScanResult.create_from_threats( request=sample_scan_request, threats=[high_threat, low_threat], scan_metadata={}, ) mock_scan_strategy.execute_scan.return_value = mock_result # Set high severity threshold high_threshold_request = ScanRequest( context=sample_scan_request.context, severity_threshold=SeverityLevel.from_string("high"), ) orchestrator.register_scan_strategy(mock_scan_strategy) result = await orchestrator.execute_scan(high_threshold_request) # Should only include high severity threat assert len(result.threats) == 1 assert result.threats[0].severity == 
SeverityLevel.from_string("high") class TestThreatAggregator: """Test ThreatAggregator service and strategies.""" def test_proximity_based_strategy(self): """Test ProximityBasedAggregationStrategy.""" strategy = ProximityBasedAggregationStrategy(proximity_threshold=3) assert strategy.get_strategy_name() == "proximity_based(threshold=3)" # Create threats at different distances threat1 = ThreatMatch( rule_id="rule-1", rule_name="Rule 1", description="Test", category="injection", severity=SeverityLevel.from_string("high"), file_path=FilePath.from_string("/test.py"), line_number=10, column_number=1, code_snippet="test", confidence=ConfidenceScore(0.8), ) threat2 = ThreatMatch( rule_id="rule-2", rule_name="Rule 2", description="Test", category="injection", severity=SeverityLevel.from_string("medium"), file_path=FilePath.from_string("/test.py"), line_number=12, # Within threshold column_number=1, code_snippet="test", confidence=ConfidenceScore(0.7), ) threat3 = ThreatMatch( rule_id="rule-3", rule_name="Rule 3", description="Test", category="injection", severity=SeverityLevel.from_string("low"), file_path=FilePath.from_string("/test.py"), line_number=20, # Outside threshold column_number=1, code_snippet="test", confidence=ConfidenceScore(0.6), ) threats = [threat1, threat2, threat3] merged = strategy.merge_similar_threats(threats) # Should merge threat1 and threat2, keep threat3 separate assert len(merged) == 2 def test_fingerprint_based_strategy(self): """Test FingerprintBasedAggregationStrategy.""" strategy = FingerprintBasedAggregationStrategy() assert strategy.get_strategy_name() == "fingerprint_based" # Create threats with same fingerprint threat1 = ThreatMatch( rule_id="rule-1", rule_name="Rule 1", description="Test", category="injection", severity=SeverityLevel.from_string("high"), file_path=FilePath.from_string("/test.py"), line_number=10, column_number=1, code_snippet="test", confidence=ConfidenceScore(0.8), ) # Identical threat (different rule_id) threat2 = 
ThreatMatch( rule_id="rule-2", rule_name="Rule 2", description="Test", category="injection", severity=SeverityLevel.from_string("medium"), file_path=FilePath.from_string("/test.py"), line_number=10, column_number=1, code_snippet="test", confidence=ConfidenceScore(0.7), ) threats = [threat1, threat2] merged = strategy.merge_similar_threats(threats) # Should merge into single threat with higher confidence assert len(merged) == 1 assert merged[0].confidence == ConfidenceScore(0.8) def test_hybrid_strategy(self): """Test HybridAggregationStrategy.""" strategy = HybridAggregationStrategy(proximity_threshold=5) assert "hybrid" in strategy.get_strategy_name() assert "proximity_threshold=5" in strategy.get_strategy_name() # Create mix of threats for testing threats = [ ThreatMatch( rule_id="rule-1", rule_name="Rule 1", description="Test", category="injection", severity=SeverityLevel.from_string("high"), file_path=FilePath.from_string("/test.py"), line_number=10, column_number=1, code_snippet="test", confidence=ConfidenceScore(0.8), ), # Duplicate fingerprint ThreatMatch( rule_id="rule-2", rule_name="Rule 2", description="Test", category="injection", severity=SeverityLevel.from_string("medium"), file_path=FilePath.from_string("/test.py"), line_number=10, column_number=1, code_snippet="test", confidence=ConfidenceScore(0.7), ), # Close proximity ThreatMatch( rule_id="rule-3", rule_name="Rule 3", description="Test", category="injection", severity=SeverityLevel.from_string("low"), file_path=FilePath.from_string("/test.py"), line_number=12, column_number=1, code_snippet="different test", confidence=ConfidenceScore(0.6), ), ] merged = strategy.merge_similar_threats(threats) # Should apply both fingerprint and proximity merging assert len(merged) <= len(threats) def test_threat_aggregator_with_strategy(self): """Test ThreatAggregator with different strategies.""" aggregator = ThreatAggregator() # Test default strategy assert isinstance(aggregator.strategy, 
HybridAggregationStrategy) # Test setting different strategy proximity_strategy = ProximityBasedAggregationStrategy(proximity_threshold=10) aggregator.set_strategy(proximity_strategy) assert aggregator.strategy == proximity_strategy assert ( aggregator.get_aggregation_strategy_name() == "proximity_based(threshold=10)" ) def test_threat_aggregator_statistics(self): """Test ThreatAggregator statistics.""" aggregator = ThreatAggregator() # Create test threats threats_group1 = [ ThreatMatch( rule_id="rule-1", rule_name="Rule 1", description="Test", category="injection", severity=SeverityLevel.from_string("high"), file_path=FilePath.from_string("/test.py"), line_number=10, column_number=1, code_snippet="test", confidence=ConfidenceScore(0.8), ) ] threats_group2 = [ ThreatMatch( rule_id="rule-2", rule_name="Rule 2", description="Test", category="xss", severity=SeverityLevel.from_string("medium"), file_path=FilePath.from_string("/test.py"), line_number=20, column_number=1, code_snippet="test", confidence=ConfidenceScore(0.7), ) ] result = aggregator.aggregate_threats([threats_group1, threats_group2]) stats = aggregator.get_statistics(original_count=2, final_count=len(result)) assert stats["original_threat_count"] == 2 assert stats["final_threat_count"] == len(result) assert stats["threats_merged"] == 2 - len(result) assert "strategy" in stats class TestValidationService: """Test ValidationService.""" @pytest.fixture def validation_service(self): """Create ValidationService instance.""" return ValidationService() @pytest.fixture def sample_scan_request(self): """Create sample scan request.""" file_path = FilePath.from_virtual("test.py") metadata = ScanMetadata( scan_id="test-scan-123", scan_type="file", timestamp=datetime.now(UTC), requester="test-user", ) context = ScanContext(target_path=file_path, metadata=metadata) return ScanRequest(context=context) @pytest.fixture def sample_threat(self): """Create sample threat.""" return ThreatMatch( rule_id="test-rule-001", 
rule_name="Test Rule", description="Test vulnerability", category="injection", severity=SeverityLevel.from_string("high"), file_path=FilePath.from_string("/home/user/test.py"), line_number=42, column_number=10, code_snippet="test code", confidence=ConfidenceScore(0.8), ) def test_validate_scan_request_valid(self, validation_service, sample_scan_request): """Test validating valid scan request.""" # Should not raise any exceptions validation_service.validate_scan_request(sample_scan_request) def test_validate_scan_request_invalid_no_scanners(self, validation_service): """Test validating scan request with no scanners enabled.""" file_path = FilePath.from_virtual("test.py") metadata = ScanMetadata( scan_id="test-scan-123", scan_type="file", timestamp=datetime.now(UTC), requester="test-user", ) context = ScanContext(target_path=file_path, metadata=metadata) with pytest.raises( ValidationError, match="At least one scanner must be enabled" ): invalid_request = ScanRequest( context=context, enable_semgrep=False, enable_llm=False ) validation_service.validate_scan_request(invalid_request) def test_validate_scan_request_validation_without_llm(self, validation_service): """Test validating scan request with validation but no LLM (should succeed).""" file_path = FilePath.from_virtual("test.py") metadata = ScanMetadata( scan_id="test-scan-123", scan_type="file", timestamp=datetime.now(UTC), requester="test-user", ) context = ScanContext(target_path=file_path, metadata=metadata) # This should now succeed - validation can work independently of LLM scanning valid_request = ScanRequest( context=context, enable_semgrep=True, enable_llm=False, enable_validation=True, ) # Should not raise any exceptions validation_service.validate_scan_request(valid_request) def test_validate_threat_data_valid(self, validation_service, sample_threat): """Test validating valid threat data.""" # Should not raise any exceptions validation_service.validate_threat_data(sample_threat) def 
test_validate_threat_data_invalid_line_number(self, validation_service): """Test that invalid line number is caught at entity level.""" # Entity validation should prevent creation of threat with invalid line number with pytest.raises(ValidationError, match="Line number must be positive"): ThreatMatch( rule_id="test-rule-001", rule_name="Test Rule", description="Test vulnerability", category="injection", severity=SeverityLevel.from_string("high"), file_path=FilePath.from_string("/home/user/test.py"), line_number=0, # Invalid column_number=10, code_snippet="test code", confidence=ConfidenceScore(0.8), ) def test_validate_scan_result_valid( self, validation_service, sample_scan_request, sample_threat ): """Test validating valid scan result.""" result = ScanResult.create_from_threats( request=sample_scan_request, threats=[sample_threat], scan_metadata={} ) # Should not raise any exceptions validation_service.validate_scan_result(result) def test_enforce_security_constraints_valid( self, validation_service, sample_scan_request ): """Test enforcing security constraints on valid request.""" # Should not raise any exceptions validation_service.enforce_security_constraints(sample_scan_request.context) def test_enforce_security_constraints_blocked_path(self, validation_service): """Test enforcing security constraints on blocked path.""" blocked_path = FilePath.from_string("/home/user/.ssh/id_rsa") metadata = ScanMetadata( scan_id="test-scan-123", scan_type="file", timestamp=datetime.now(UTC), requester="test-user", ) context = ScanContext(target_path=blocked_path, metadata=metadata) with pytest.raises(SecurityError, match="Scanning blocked path"): validation_service.enforce_security_constraints(context) def test_update_security_constraints(self, validation_service): """Test updating security constraints.""" original_constraints = validation_service.get_security_constraints() validation_service.update_security_constraints( max_file_size_bytes=5 * 1024 * 1024, # 5MB 
additional_blocked_patterns={r".*\.secret$"}, ) updated_constraints = validation_service.get_security_constraints() assert updated_constraints["max_file_size_bytes"] == 5 * 1024 * 1024 assert r".*\.secret$" in updated_constraints["blocked_path_patterns"] def test_code_scan_constraints(self, validation_service): """Test code scan size constraints.""" # Create large code content large_content = "print('test')\n" * 2000 # Exceeds default limit metadata = ScanMetadata( scan_id="test-scan-123", scan_type="code", timestamp=datetime.now(UTC), requester="test-user", ) context = ScanContext( target_path=None, metadata=metadata, content=large_content ) with pytest.raises(SecurityError, match="Code snippet too large"): validation_service.enforce_security_constraints(context)

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/brettbergin/adversary-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server