Skip to main content
Glama

Adversary MCP Server

by brettbergin
test_threat_aggregator_coverage.py (31 kB)
"""Comprehensive tests for ThreatAggregator domain service to achieve 100% coverage."""

from unittest.mock import patch

import pytest

from adversary_mcp_server.domain.aggregation.threat_aggregator import ThreatAggregator
from adversary_mcp_server.domain.entities.threat_match import ThreatMatch
from adversary_mcp_server.domain.value_objects.file_path import FilePath
from adversary_mcp_server.domain.value_objects.severity_level import SeverityLevel


def create_test_threat(
    rule_id: str = "test_rule",
    rule_name: str = "Test Rule",
    category: str = "test_category",
    severity: SeverityLevel = SeverityLevel.MEDIUM,
    line_number: int = 1,
    file_path: str = "/test/file.py",
    description: str = "Test description",
) -> ThreatMatch:
    """Helper function to create test threats.

    Every argument is forwarded to ``ThreatMatch`` unchanged, except
    ``file_path``, which is converted with ``FilePath.from_string``.
    The defaults let a test override only the fields it cares about.
    """
    return ThreatMatch(
        rule_id=rule_id,
        rule_name=rule_name,
        category=category,
        severity=severity,
        description=description,
        file_path=FilePath.from_string(file_path),
        line_number=line_number,
    )


class TestThreatAggregatorInitialization:
    """Test ThreatAggregator initialization and configuration."""

    def test_default_initialization(self):
        """Test ThreatAggregator initialization with default parameters."""
        aggregator = ThreatAggregator()

        assert aggregator.proximity_threshold == 2
        assert aggregator._aggregation_stats == {
            "total_input_threats": 0,
            "deduplicated_threats": 0,
            "final_threat_count": 0,
        }

    def test_custom_proximity_threshold(self):
        """Test ThreatAggregator initialization with custom proximity threshold."""
        aggregator = ThreatAggregator(proximity_threshold=5)

        assert aggregator.proximity_threshold == 5

    def test_configure_proximity_threshold(self):
        """Test configuring proximity threshold after initialization."""
        aggregator = ThreatAggregator(proximity_threshold=2)

        aggregator.configure_proximity_threshold(10)

        assert aggregator.proximity_threshold == 10

    def test_configure_proximity_threshold_negative(self):
        """Test configuring proximity threshold with negative value raises error."""
        aggregator = ThreatAggregator()

        with pytest.raises(
            ValueError, match="Proximity threshold must be non-negative"
        ):
            aggregator.configure_proximity_threshold(-1)

    def test_configure_proximity_threshold_zero(self):
        """Test configuring proximity threshold with zero."""
        aggregator = ThreatAggregator()

        aggregator.configure_proximity_threshold(0)

        assert aggregator.proximity_threshold == 0


class TestThreatAggregation:
    """Test core threat aggregation functionality."""

    def test_aggregate_empty_lists(self):
        """Test aggregating empty threat lists."""
        aggregator = ThreatAggregator()

        result = aggregator.aggregate_threats([], [])

        assert result == []
        stats = aggregator.get_aggregation_stats()
        assert stats["total_input_threats"] == 0
        assert stats["deduplicated_threats"] == 0
        assert stats["final_threat_count"] == 0

    def test_aggregate_only_semgrep_threats(self):
        """Test aggregating with only Semgrep threats."""
        aggregator = ThreatAggregator()
        semgrep_threats = [
            create_test_threat(rule_id="semgrep1", line_number=10),
            create_test_threat(rule_id="semgrep2", line_number=20),
        ]

        result = aggregator.aggregate_threats(semgrep_threats, [])

        assert len(result) == 2
        assert all(threat.rule_id.startswith("semgrep") for threat in result)

        stats = aggregator.get_aggregation_stats()
        assert stats["total_input_threats"] == 2
        assert stats["deduplicated_threats"] == 0
        assert stats["final_threat_count"] == 2

    def test_aggregate_only_llm_threats(self):
        """Test aggregating with only LLM threats."""
        aggregator = ThreatAggregator()
        llm_threats = [
            create_test_threat(rule_id="llm1", line_number=10),
            create_test_threat(rule_id="llm2", line_number=20),
        ]

        result = aggregator.aggregate_threats([], llm_threats)

        assert len(result) == 2
        assert all(threat.rule_id.startswith("llm") for threat in result)

        stats = aggregator.get_aggregation_stats()
        assert stats["total_input_threats"] == 2
        assert stats["deduplicated_threats"] == 0
        assert stats["final_threat_count"] == 2

    def test_aggregate_no_duplicates(self):
        """Test aggregating threats with no duplicates."""
        aggregator = ThreatAggregator()
        semgrep_threats = [
            create_test_threat(
                rule_id="semgrep1", line_number=10, category="sql-injection"
            ),
        ]
        llm_threats = [
            create_test_threat(rule_id="llm1", line_number=50, category="xss"),
        ]

        result = aggregator.aggregate_threats(semgrep_threats, llm_threats)

        assert len(result) == 2
        assert any(threat.rule_id == "semgrep1" for threat in result)
        assert any(threat.rule_id == "llm1" for threat in result)

        stats = aggregator.get_aggregation_stats()
        assert stats["total_input_threats"] == 2
        assert stats["deduplicated_threats"] == 0
        assert stats["final_threat_count"] == 2

    def test_aggregate_with_duplicates_proximity(self):
        """Test aggregating threats with duplicates based on proximity."""
        aggregator = ThreatAggregator(proximity_threshold=2)
        semgrep_threats = [
            create_test_threat(
                rule_id="semgrep1", line_number=10, category="sql-injection"
            ),
        ]
        llm_threats = [
            create_test_threat(
                rule_id="llm1", line_number=11, category="sql-injection"
            ),  # Within proximity
        ]

        result = aggregator.aggregate_threats(semgrep_threats, llm_threats)

        # Should keep only Semgrep threat (higher priority)
        assert len(result) == 1
        assert result[0].rule_id == "semgrep1"

        stats = aggregator.get_aggregation_stats()
        assert stats["total_input_threats"] == 2
        assert stats["deduplicated_threats"] == 1
        assert stats["final_threat_count"] == 1

    def test_aggregate_with_additional_threats(self):
        """Test aggregating with additional threats parameter."""
        aggregator = ThreatAggregator()
        semgrep_threats = [
            create_test_threat(rule_id="semgrep1", line_number=10),
        ]
        llm_threats = [
            create_test_threat(rule_id="llm1", line_number=20),
        ]
        additional_threats = [
            create_test_threat(rule_id="additional1", line_number=30),
            create_test_threat(rule_id="additional2", line_number=40),
        ]

        result = aggregator.aggregate_threats(
            semgrep_threats, llm_threats, additional_threats
        )

        assert len(result) == 4
        rule_ids = [threat.rule_id for threat in result]
        assert "semgrep1" in rule_ids
        assert "llm1" in rule_ids
        assert "additional1" in rule_ids
        assert "additional2" in rule_ids

        stats = aggregator.get_aggregation_stats()
        assert stats["total_input_threats"] == 4
        assert stats["deduplicated_threats"] == 0
        assert stats["final_threat_count"] == 4

    def test_aggregate_default_additional_threats(self):
        """Test aggregating with default None additional threats."""
        aggregator = ThreatAggregator()
        semgrep_threats = [create_test_threat(rule_id="semgrep1", line_number=10)]
        llm_threats = [
            create_test_threat(rule_id="llm1", line_number=50)
        ]  # Far apart to avoid deduplication

        # Call without additional_threats parameter (should default to None/[])
        result = aggregator.aggregate_threats(semgrep_threats, llm_threats)

        assert len(result) == 2


class TestThreatDeduplication:
    """Test threat deduplication logic."""

    def test_is_duplicate_same_line_same_category(self):
        """Test duplicate detection for same line and category."""
        aggregator = ThreatAggregator()
        threat1 = create_test_threat(line_number=10, category="sql-injection")
        threat2 = create_test_threat(line_number=10, category="sql-injection")
        existing_threats = [threat1]

        assert aggregator._is_duplicate(threat2, existing_threats)

    def test_is_duplicate_within_proximity_same_category(self):
        """Test duplicate detection within proximity threshold."""
        aggregator = ThreatAggregator(proximity_threshold=2)
        threat1 = create_test_threat(line_number=10, category="sql-injection")
        threat2 = create_test_threat(
            line_number=12, category="sql-injection"
        )  # Within 2 lines
        existing_threats = [threat1]

        assert aggregator._is_duplicate(threat2, existing_threats)

    def test_is_duplicate_outside_proximity(self):
        """Test no duplicate detection outside proximity threshold."""
        aggregator = ThreatAggregator(proximity_threshold=2)
        threat1 = create_test_threat(line_number=10, category="sql-injection")
        threat2 = create_test_threat(
            line_number=15, category="sql-injection"
        )  # Outside 2 lines
        existing_threats = [threat1]

        assert not aggregator._is_duplicate(threat2, existing_threats)

    def test_is_duplicate_different_category(self):
        """Test no duplicate detection for different categories."""
        aggregator = ThreatAggregator()
        threat1 = create_test_threat(line_number=10, category="sql-injection")
        threat2 = create_test_threat(line_number=10, category="xss")
        existing_threats = [threat1]

        assert not aggregator._is_duplicate(threat2, existing_threats)

    def test_are_threats_similar_exact_match(self):
        """Test threat similarity for exact matches."""
        aggregator = ThreatAggregator()
        threat1 = create_test_threat(line_number=10, category="sql-injection")
        threat2 = create_test_threat(line_number=10, category="sql-injection")

        assert aggregator._are_threats_similar(threat1, threat2)

    def test_are_threats_similar_within_proximity(self):
        """Test threat similarity within proximity."""
        aggregator = ThreatAggregator(proximity_threshold=3)
        threat1 = create_test_threat(line_number=10, category="sql-injection")
        threat2 = create_test_threat(line_number=13, category="sql-injection")

        assert aggregator._are_threats_similar(threat1, threat2)

    def test_are_threats_similar_outside_proximity(self):
        """Test threat similarity outside proximity."""
        aggregator = ThreatAggregator(proximity_threshold=2)
        threat1 = create_test_threat(line_number=10, category="sql-injection")
        threat2 = create_test_threat(line_number=15, category="sql-injection")

        assert not aggregator._are_threats_similar(threat1, threat2)

    def test_are_threats_similar_category_variations(self):
        """Test threat similarity for category variations."""
        aggregator = ThreatAggregator()

        # Test SQL injection variations
        threat1 = create_test_threat(line_number=10, category="sql-injection")
        threat2 = create_test_threat(line_number=10, category="injection")
        threat3 = create_test_threat(line_number=10, category="sqli")

        assert aggregator._are_threats_similar(threat1, threat2)
        assert aggregator._are_threats_similar(threat1, threat3)

        # Test XSS variations
        threat4 = create_test_threat(line_number=10, category="xss")
        threat5 = create_test_threat(line_number=10, category="cross-site-scripting")
        threat6 = create_test_threat(line_number=10, category="script-injection")

        assert aggregator._are_threats_similar(threat4, threat5)
        assert aggregator._are_threats_similar(threat4, threat6)

        # Test path traversal variations
        threat7 = create_test_threat(line_number=10, category="path-traversal")
        threat8 = create_test_threat(line_number=10, category="directory-traversal")
        threat9 = create_test_threat(line_number=10, category="path-injection")

        assert aggregator._are_threats_similar(threat7, threat8)
        assert aggregator._are_threats_similar(threat7, threat9)

    def test_are_threats_similar_case_insensitive(self):
        """Test threat similarity is case insensitive."""
        aggregator = ThreatAggregator()
        threat1 = create_test_threat(line_number=10, category="SQL-Injection")
        threat2 = create_test_threat(line_number=10, category="sql-injection")

        assert aggregator._are_threats_similar(threat1, threat2)

    def test_are_threats_similar_unrelated_categories(self):
        """Test threat similarity for unrelated categories."""
        aggregator = ThreatAggregator()
        threat1 = create_test_threat(line_number=10, category="sql-injection")
        threat2 = create_test_threat(line_number=10, category="memory-leak")

        assert not aggregator._are_threats_similar(threat1, threat2)


class TestThreatSorting:
    """Test threat sorting functionality."""

    def test_sort_threats_by_line_number(self):
        """Test sorting threats by line number."""
        aggregator = ThreatAggregator()
        threats = [
            create_test_threat(
                rule_id="threat3", line_number=30, severity=SeverityLevel.LOW
            ),
            create_test_threat(
                rule_id="threat1", line_number=10, severity=SeverityLevel.HIGH
            ),
            create_test_threat(
                rule_id="threat2", line_number=20, severity=SeverityLevel.MEDIUM
            ),
        ]

        sorted_threats = aggregator._sort_threats(threats)

        assert [t.line_number for t in sorted_threats] == [10, 20, 30]
        assert [t.rule_id for t in sorted_threats] == ["threat1", "threat2", "threat3"]

    def test_sort_threats_by_severity_when_same_line(self):
        """Test sorting threats by severity when line numbers are the same."""
        aggregator = ThreatAggregator()
        threats = [
            create_test_threat(
                rule_id="low", line_number=10, severity=SeverityLevel.LOW, category="b"
            ),
            create_test_threat(
                rule_id="high",
                line_number=10,
                severity=SeverityLevel.HIGH,
                category="a",
            ),
            create_test_threat(
                rule_id="medium",
                line_number=10,
                severity=SeverityLevel.MEDIUM,
                category="c",
            ),
        ]

        sorted_threats = aggregator._sort_threats(threats)

        # Expected order on a shared line is ascending severity: LOW, MEDIUM, HIGH.
        # NOTE(review): the original comment cites numeric values HIGH=4,
        # MEDIUM=3, LOW=2 -- confirm against SeverityLevel, which is not
        # defined in this file.
        assert (
            sorted_threats[0].severity == SeverityLevel.LOW
        )  # Lowest numeric value first
        assert sorted_threats[1].severity == SeverityLevel.MEDIUM
        assert sorted_threats[2].severity == SeverityLevel.HIGH

    def test_sort_threats_by_category_when_same_line_and_severity(self):
        """Test sorting threats by category when line and severity are the same."""
        aggregator = ThreatAggregator()
        threats = [
            create_test_threat(
                rule_id="z",
                line_number=10,
                severity=SeverityLevel.MEDIUM,
                category="z-category",
            ),
            create_test_threat(
                rule_id="a",
                line_number=10,
                severity=SeverityLevel.MEDIUM,
                category="a-category",
            ),
            create_test_threat(
                rule_id="m",
                line_number=10,
                severity=SeverityLevel.MEDIUM,
                category="m-category",
            ),
        ]

        sorted_threats = aggregator._sort_threats(threats)

        # Should be sorted by category alphabetically
        assert [t.category for t in sorted_threats] == [
            "a-category",
            "m-category",
            "z-category",
        ]


class TestThreatDistributionAnalysis:
    """Test threat distribution analysis functionality."""

    def test_analyze_threat_distribution_empty(self):
        """Test analyzing distribution of empty threat list."""
        aggregator = ThreatAggregator()

        distribution = aggregator.analyze_threat_distribution([])

        assert distribution == {
            "by_severity": {},
            "by_category": {},
            "by_line_range": {"1-50": 0, "51-100": 0, "101-500": 0, "500+": 0},
        }

    def test_analyze_threat_distribution_by_severity(self):
        """Test analyzing distribution by severity."""
        aggregator = ThreatAggregator()
        threats = [
            create_test_threat(severity=SeverityLevel.HIGH),
            create_test_threat(severity=SeverityLevel.HIGH),
            create_test_threat(severity=SeverityLevel.MEDIUM),
            create_test_threat(severity=SeverityLevel.LOW),
        ]

        distribution = aggregator.analyze_threat_distribution(threats)

        assert distribution["by_severity"]["high"] == 2
        assert distribution["by_severity"]["medium"] == 1
        assert distribution["by_severity"]["low"] == 1

    def test_analyze_threat_distribution_by_category(self):
        """Test analyzing distribution by category."""
        aggregator = ThreatAggregator()
        threats = [
            create_test_threat(category="sql-injection"),
            create_test_threat(category="sql-injection"),
            create_test_threat(category="xss"),
            create_test_threat(category="path-traversal"),
        ]

        distribution = aggregator.analyze_threat_distribution(threats)

        assert distribution["by_category"]["sql-injection"] == 2
        assert distribution["by_category"]["xss"] == 1
        assert distribution["by_category"]["path-traversal"] == 1

    def test_analyze_threat_distribution_by_line_range(self):
        """Test analyzing distribution by line range."""
        aggregator = ThreatAggregator()
        # Lines 50, 100 and 500 sit on the upper edge of their buckets.
        threats = [
            create_test_threat(line_number=25),  # 1-50
            create_test_threat(line_number=50),  # 1-50
            create_test_threat(line_number=75),  # 51-100
            create_test_threat(line_number=100),  # 51-100
            create_test_threat(line_number=200),  # 101-500
            create_test_threat(line_number=500),  # 101-500
            create_test_threat(line_number=1000),  # 500+
        ]

        distribution = aggregator.analyze_threat_distribution(threats)

        assert distribution["by_line_range"]["1-50"] == 2
        assert distribution["by_line_range"]["51-100"] == 2
        assert distribution["by_line_range"]["101-500"] == 2
        assert distribution["by_line_range"]["500+"] == 1

    def test_analyze_threat_distribution_comprehensive(self):
        """Test comprehensive threat distribution analysis."""
        aggregator = ThreatAggregator()
        threats = [
            create_test_threat(
                category="sql-injection", severity=SeverityLevel.HIGH, line_number=25
            ),
            create_test_threat(
                category="xss", severity=SeverityLevel.MEDIUM, line_number=75
            ),
            create_test_threat(
                category="sql-injection", severity=SeverityLevel.LOW, line_number=200
            ),
        ]

        distribution = aggregator.analyze_threat_distribution(threats)

        # Check all dimensions
        assert distribution["by_severity"]["high"] == 1
        assert distribution["by_severity"]["medium"] == 1
        assert distribution["by_severity"]["low"] == 1

        assert distribution["by_category"]["sql-injection"] == 2
        assert distribution["by_category"]["xss"] == 1

        assert distribution["by_line_range"]["1-50"] == 1
        assert distribution["by_line_range"]["51-100"] == 1
        assert distribution["by_line_range"]["101-500"] == 1
        assert distribution["by_line_range"]["500+"] == 0


class TestAggregationStatistics:
    """Test aggregation statistics functionality."""

    def test_get_aggregation_stats_initial(self):
        """Test getting initial aggregation statistics."""
        aggregator = ThreatAggregator()

        stats = aggregator.get_aggregation_stats()

        assert stats == {
            "total_input_threats": 0,
            "deduplicated_threats": 0,
            "final_threat_count": 0,
        }

        # Should return a copy, not the original
        stats["total_input_threats"] = 999
        assert aggregator._aggregation_stats["total_input_threats"] == 0

    def test_get_aggregation_stats_after_aggregation(self):
        """Test getting aggregation statistics after aggregation."""
        aggregator = ThreatAggregator()
        semgrep_threats = [
            create_test_threat(
                rule_id="semgrep1", line_number=10, category="sql-injection"
            ),
        ]
        llm_threats = [
            create_test_threat(
                rule_id="llm1", line_number=11, category="sql-injection"
            ),  # Duplicate
            create_test_threat(
                rule_id="llm2", line_number=50, category="xss"
            ),  # Not duplicate
        ]

        aggregator.aggregate_threats(semgrep_threats, llm_threats)
        stats = aggregator.get_aggregation_stats()

        assert stats["total_input_threats"] == 3
        assert stats["deduplicated_threats"] == 1  # One LLM threat was duplicate
        assert stats["final_threat_count"] == 2


class TestMergeThreats:
    """Test _merge_threats private method."""

    def test_merge_threats_no_duplicates(self):
        """Test merging threats with no duplicates."""
        aggregator = ThreatAggregator()
        existing_threats = [
            create_test_threat(rule_id="existing1", line_number=10),
        ]
        new_threats = [
            create_test_threat(rule_id="new1", line_number=50),
            create_test_threat(rule_id="new2", line_number=100),
        ]

        # The assertions below rely on _merge_threats extending
        # existing_threats in place and returning the number added.
        added_count = aggregator._merge_threats(existing_threats, new_threats)

        assert added_count == 2
        assert len(existing_threats) == 3
        rule_ids = [threat.rule_id for threat in existing_threats]
        assert "existing1" in rule_ids
        assert "new1" in rule_ids
        assert "new2" in rule_ids

    def test_merge_threats_with_duplicates(self):
        """Test merging threats with some duplicates."""
        aggregator = ThreatAggregator(proximity_threshold=2)
        existing_threats = [
            create_test_threat(
                rule_id="existing1", line_number=10, category="sql-injection"
            ),
        ]
        new_threats = [
            create_test_threat(
                rule_id="new1", line_number=11, category="sql-injection"
            ),  # Duplicate
            create_test_threat(
                rule_id="new2", line_number=50, category="xss"
            ),  # Not duplicate
        ]

        added_count = aggregator._merge_threats(existing_threats, new_threats)

        assert added_count == 1  # Only one threat added
        assert len(existing_threats) == 2
        rule_ids = [threat.rule_id for threat in existing_threats]
        assert "existing1" in rule_ids
        assert "new2" in rule_ids
        assert "new1" not in rule_ids  # Was duplicate


class TestThreatAggregatorLogging:
    """Test logging behavior of ThreatAggregator."""

    @patch("adversary_mcp_server.domain.aggregation.threat_aggregator.logger")
    def test_aggregation_logging(self, mock_logger):
        """Test that aggregation operations are logged."""
        aggregator = ThreatAggregator()
        semgrep_threats = [create_test_threat(rule_id="semgrep1")]
        llm_threats = [create_test_threat(rule_id="llm1")]

        aggregator.aggregate_threats(semgrep_threats, llm_threats)

        # Should log debug and info messages
        mock_logger.debug.assert_called()
        mock_logger.info.assert_called()

    @patch("adversary_mcp_server.domain.aggregation.threat_aggregator.logger")
    def test_duplicate_detection_logging(self, mock_logger):
        """Test that duplicate detection is logged."""
        aggregator = ThreatAggregator(proximity_threshold=2)
        semgrep_threats = [
            create_test_threat(
                rule_id="semgrep1", line_number=10, category="sql-injection"
            ),
        ]
        llm_threats = [
            create_test_threat(
                rule_id="llm1", line_number=11, category="sql-injection"
            ),  # Duplicate
        ]

        aggregator.aggregate_threats(semgrep_threats, llm_threats)

        # Should log duplicate detection. Skip recorded calls without
        # positional arguments so an unrelated keyword-only debug call cannot
        # fail this test with IndexError.
        debug_messages = [
            recorded.args[0]
            for recorded in mock_logger.debug.call_args_list
            if recorded.args
        ]
        assert any("Duplicate threat detected" in msg for msg in debug_messages)

    @patch("adversary_mcp_server.domain.aggregation.threat_aggregator.logger")
    def test_proximity_threshold_configuration_logging(self, mock_logger):
        """Test that proximity threshold changes are logged."""
        aggregator = ThreatAggregator(proximity_threshold=2)

        aggregator.configure_proximity_threshold(5)

        # Should log threshold change
        mock_logger.info.assert_called_with("Proximity threshold changed: 2 → 5")


class TestThreatAggregatorIntegration:
    """Integration tests for ThreatAggregator."""

    def test_complex_aggregation_scenario(self):
        """Test complex aggregation scenario with multiple threat types."""
        aggregator = ThreatAggregator(proximity_threshold=3)

        # Semgrep threats (high priority)
        semgrep_threats = [
            create_test_threat(
                rule_id="semgrep.sql-injection",
                line_number=15,
                category="sql-injection",
                severity=SeverityLevel.HIGH,
            ),
            create_test_threat(
                rule_id="semgrep.xss",
                line_number=50,
                category="xss",
                severity=SeverityLevel.MEDIUM,
            ),
        ]

        # LLM threats (some duplicates, some unique)
        llm_threats = [
            create_test_threat(
                rule_id="llm.sql-injection",
                line_number=16,  # Close to semgrep finding
                category="injection",  # Similar category
                severity=SeverityLevel.MEDIUM,
            ),
            create_test_threat(
                rule_id="llm.path-traversal",
                line_number=100,  # Unique finding
                category="path-traversal",
                severity=SeverityLevel.HIGH,
            ),
            create_test_threat(
                rule_id="llm.xss-variant",
                line_number=52,  # Close to semgrep XSS
                category="script-injection",  # Similar to XSS
                severity=SeverityLevel.LOW,
            ),
        ]

        # Additional threats
        additional_threats = [
            create_test_threat(
                rule_id="custom.memory-leak",
                line_number=200,
                category="memory-leak",
                severity=SeverityLevel.LOW,
            ),
        ]

        result = aggregator.aggregate_threats(
            semgrep_threats, llm_threats, additional_threats
        )

        # Should prioritize Semgrep threats and deduplicate similar ones
        rule_ids = [threat.rule_id for threat in result]

        # Semgrep threats should be kept
        assert "semgrep.sql-injection" in rule_ids
        assert "semgrep.xss" in rule_ids

        # LLM path-traversal should be kept (unique)
        assert "llm.path-traversal" in rule_ids

        # Additional threat should be kept (unique)
        assert "custom.memory-leak" in rule_ids

        # LLM duplicates should be filtered out
        assert "llm.sql-injection" not in rule_ids  # Duplicate of semgrep SQL injection
        assert "llm.xss-variant" not in rule_ids  # Duplicate of semgrep XSS

        # Check statistics
        stats = aggregator.get_aggregation_stats()
        assert stats["total_input_threats"] == 6
        assert stats["deduplicated_threats"] == 2
        assert stats["final_threat_count"] == 4

        # Check sorting (should be by line number)
        line_numbers = [threat.line_number for threat in result]
        assert line_numbers == sorted(line_numbers)

    def test_distribution_analysis_after_aggregation(self):
        """Test distribution analysis on aggregated results."""
        aggregator = ThreatAggregator()
        semgrep_threats = [
            create_test_threat(
                category="sql-injection", severity=SeverityLevel.HIGH, line_number=25
            ),
            create_test_threat(
                category="xss", severity=SeverityLevel.MEDIUM, line_number=75
            ),
        ]
        llm_threats = [
            create_test_threat(
                category="path-traversal", severity=SeverityLevel.LOW, line_number=150
            ),
        ]

        aggregated = aggregator.aggregate_threats(semgrep_threats, llm_threats)
        distribution = aggregator.analyze_threat_distribution(aggregated)

        # Verify distribution matches aggregated results
        assert len(aggregated) == 3
        assert distribution["by_severity"]["high"] == 1
        assert distribution["by_severity"]["medium"] == 1
        assert distribution["by_severity"]["low"] == 1
        assert distribution["by_category"]["sql-injection"] == 1
        assert distribution["by_category"]["xss"] == 1
        assert distribution["by_category"]["path-traversal"] == 1

    def test_zero_proximity_threshold(self):
        """Test aggregation with zero proximity threshold."""
        aggregator = ThreatAggregator(proximity_threshold=0)

        # Same line threats should still be duplicates
        semgrep_threats = [
            create_test_threat(
                rule_id="semgrep1", line_number=10, category="sql-injection"
            ),
        ]
        llm_threats = [
            create_test_threat(
                rule_id="llm1", line_number=10, category="sql-injection"
            ),  # Same line
            create_test_threat(
                rule_id="llm2", line_number=11, category="sql-injection"
            ),  # Different line, no proximity
        ]

        result = aggregator.aggregate_threats(semgrep_threats, llm_threats)

        # Should keep semgrep1 and llm2, filter out llm1
        rule_ids = [threat.rule_id for threat in result]
        assert "semgrep1" in rule_ids
        assert "llm2" in rule_ids
        assert "llm1" not in rule_ids
        assert len(result) == 2

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/brettbergin/adversary-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.