# test_vulnerability_summary.py
"""Tests for the get_wazuh_vulnerability_summary tool."""
import pytest
import json
from unittest.mock import AsyncMock, MagicMock, patch
from wazuh_mcp_server.main import WazuhMCPServer
from wazuh_mcp_server.utils.validation import validate_vulnerability_summary_query, ValidationError
@pytest.fixture
def mock_agents_data():
    """Return three active mock agents: one Ubuntu, one CentOS, one Windows."""
    # Each row: (agent id, agent name, OS platform, OS version).
    agent_rows = (
        ("001", "web-server-01", "ubuntu", "20.04"),
        ("002", "db-server-01", "centos", "8"),
        ("003", "win-server-01", "windows", "2019"),
    )
    return [
        {
            "id": agent_id,
            "name": agent_name,
            "status": "active",
            "os": {"platform": os_platform, "version": os_version},
        }
        for agent_id, agent_name, os_platform, os_version in agent_rows
    ]
@pytest.fixture
def mock_vulnerability_data():
    """Return mock vulnerability records keyed by ``agent_<id>``.

    The mapping holds two records for agent 001 and one each for agents 002
    and 003. Records intentionally differ in shape: some carry a
    ``cvss2_score``, others a ``cvss3_score``, and only the kernel record has
    a ``metasploit_module`` flag.
    """
    apache_record = {
        "cve": "CVE-2024-1234",
        "name": "apache2",
        "cvss2_score": 9.3,
        "severity": "critical",
        "exploit_available": True,
        "references": ["https://nvd.nist.gov/vuln/detail/CVE-2024-1234"],
    }
    openssl_record = {
        "cve": "CVE-2024-5678",
        "name": "openssl",
        "cvss3_score": 7.5,
        "severity": "high",
        "exploit_available": False,
        "references": ["https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2024-5678"],
    }
    kernel_record = {
        "cve": "CVE-2024-9999",
        "name": "kernel",
        "cvss3_score": 8.8,
        "severity": "high",
        "exploit_available": True,
        "metasploit_module": True,
        "references": ["https://nvd.nist.gov/vuln/detail/CVE-2024-9999"],
    }
    iis_record = {
        "cve": "CVE-2024-0001",
        "name": "iis",
        "cvss2_score": 4.3,
        "severity": "medium",
        "exploit_available": False,
        "references": [],
    }

    return {
        "agent_001": [apache_record, openssl_record],
        "agent_002": [kernel_record],
        "agent_003": [iis_record],
    }
class TestVulnerabilitySummaryValidation:
    """Checks on how validate_vulnerability_summary_query handles its parameters."""

    def test_valid_basic_query(self):
        """A minimal valid query keeps its values and exposes the expected defaults."""
        query = validate_vulnerability_summary_query(
            {"cvss_threshold": 7.0, "group_by": "severity"}
        )

        assert query.cvss_threshold == 7.0
        assert query.group_by == "severity"
        # Fields not supplied by the caller.
        assert query.include_remediation is True
        assert query.max_agents == 100

    def test_valid_complete_query(self):
        """A query that sets every parameter validates and keeps the checked values."""
        query = validate_vulnerability_summary_query(
            {
                "cvss_threshold": 5.0,
                "severity_filter": ["critical", "high"],
                "cve_filter": ["CVE-2024-1234"],
                "os_filter": ["ubuntu", "centos"],
                "package_filter": ["apache2", "nginx"],
                "exploitability": True,
                "group_by": "package",
                "include_remediation": False,
                "include_analytics": True,
                "max_agents": 50,
            }
        )

        assert query.cvss_threshold == 5.0
        assert query.severity_filter == ["critical", "high"]
        assert query.exploitability is True
        assert query.group_by == "package"
        assert query.max_agents == 50

    def test_invalid_cvss_threshold(self):
        """A CVSS threshold of 11.0 (over 10.0) is rejected."""
        out_of_range = {"cvss_threshold": 11.0}
        with pytest.raises(ValidationError):
            validate_vulnerability_summary_query(out_of_range)

    def test_invalid_severity_filter(self):
        """An unrecognised severity name is rejected."""
        bad_severity = {
            "cvss_threshold": 7.0,
            "severity_filter": ["invalid_severity"],
        }
        with pytest.raises(ValidationError):
            validate_vulnerability_summary_query(bad_severity)

    def test_invalid_group_by(self):
        """An unrecognised group_by field is rejected."""
        bad_grouping = {
            "cvss_threshold": 7.0,
            "group_by": "invalid_field",
        }
        with pytest.raises(ValidationError):
            validate_vulnerability_summary_query(bad_grouping)

    def test_max_agents_boundary(self):
        """max_agents accepts 1 and 1000 and rejects 1001."""
        # Minimum first, then maximum - same order as the accepted range.
        for accepted in (1, 1000):
            query = validate_vulnerability_summary_query({"max_agents": accepted})
            assert query.max_agents == accepted

        # One past the maximum must fail validation.
        with pytest.raises(ValidationError):
            validate_vulnerability_summary_query({"max_agents": 1001})
@pytest.mark.asyncio
class TestVulnerabilitySummaryTool:
    """Test the vulnerability summary tool functionality."""

    @pytest.fixture
    def wazuh_server(self):
        """Create a WazuhMCPServer with a mocked config and a mocked API client.

        This is deliberately a synchronous fixture. Nothing in it is awaited,
        and an ``async def`` fixture registered with plain ``@pytest.fixture``
        is only resolved by pytest-asyncio in auto mode; in strict mode the
        tests would receive an un-awaited coroutine instead of the server.
        """
        with patch('wazuh_mcp_server.main.WazuhConfig') as mock_config:
            config = MagicMock()
            config.log_level = "INFO"
            config.debug = False
            config.request_timeout_seconds = 30
            mock_config.from_env.return_value = config
            server = WazuhMCPServer()
            server.api_client = AsyncMock()
            return server

    @staticmethod
    def _stub_api(server, agents, vulnerabilities_by_agent=None):
        """Point the mocked API client at the given agents and vulnerabilities.

        Args:
            server: Server whose ``api_client`` is an ``AsyncMock``.
            agents: Agent dicts returned as the ``affected_items`` of
                ``get_agents``.
            vulnerabilities_by_agent: Optional mapping keyed ``agent_<id>``.
                When given, ``get_agent_vulnerabilities`` looks the agent up
                in it (unknown agents get an empty list). When omitted, every
                agent gets an empty vulnerability list.
        """
        server.api_client.get_agents.return_value = {
            "data": {"affected_items": agents}
        }

        if vulnerabilities_by_agent is None:
            server.api_client.get_agent_vulnerabilities.return_value = {
                "data": {"affected_items": []}
            }
            return

        def fetch_vulnerabilities(agent_id):
            return {
                "data": {
                    "affected_items": vulnerabilities_by_agent.get(f"agent_{agent_id}", [])
                }
            }

        server.api_client.get_agent_vulnerabilities.side_effect = fetch_vulnerabilities

    @staticmethod
    async def _summarize(server, arguments):
        """Run the summary handler and return its first item's text parsed as JSON."""
        result = await server._handle_get_wazuh_vulnerability_summary(arguments)
        return json.loads(result[0].text)

    async def test_basic_vulnerability_summary(self, wazuh_server, mock_agents_data, mock_vulnerability_data):
        """Test basic vulnerability summary functionality."""
        self._stub_api(wazuh_server, mock_agents_data, mock_vulnerability_data)

        arguments = {"cvss_threshold": 0.0, "group_by": "severity"}
        result = await wazuh_server._handle_get_wazuh_vulnerability_summary(arguments)

        assert len(result) == 1
        response_data = json.loads(result[0].text)

        # Check basic structure
        for section in (
            "query_parameters",
            "summary",
            "grouped_analysis",
            "risk_analytics",
            "remediation",
            "key_insights",
        ):
            assert section in response_data

        # Check summary data: 3 mock agents carrying 4 mock vulnerabilities
        assert response_data["summary"]["total_agents_analyzed"] == 3
        assert response_data["summary"]["total_vulnerabilities"] == 4
        assert response_data["grouped_analysis"]["grouping_field"] == "severity"

    async def test_severity_grouping(self, wazuh_server, mock_agents_data, mock_vulnerability_data):
        """Test grouping by severity."""
        self._stub_api(wazuh_server, mock_agents_data, mock_vulnerability_data)

        response_data = await self._summarize(
            wazuh_server, {"cvss_threshold": 0.0, "group_by": "severity"}
        )
        groups = response_data["grouped_analysis"]["groups"]

        # Should have different severity levels
        severity_levels = list(groups.keys())
        assert "critical" in severity_levels
        assert "high" in severity_levels
        assert "medium" in severity_levels

        # Counts follow the mock data: 1 critical, 2 high, 1 medium
        assert groups["critical"]["count"] == 1
        assert groups["high"]["count"] == 2
        assert groups["medium"]["count"] == 1

    async def test_package_grouping(self, wazuh_server, mock_agents_data, mock_vulnerability_data):
        """Test grouping by package."""
        self._stub_api(wazuh_server, mock_agents_data, mock_vulnerability_data)

        response_data = await self._summarize(
            wazuh_server, {"cvss_threshold": 0.0, "group_by": "package"}
        )
        groups = response_data["grouped_analysis"]["groups"]

        # Should have one group per package in the mock data
        package_names = list(groups.keys())
        assert "apache2" in package_names
        assert "openssl" in package_names
        assert "kernel" in package_names
        assert "iis" in package_names

    async def test_cvss_threshold_filtering(self, wazuh_server, mock_agents_data, mock_vulnerability_data):
        """Test filtering by CVSS threshold."""
        self._stub_api(wazuh_server, mock_agents_data, mock_vulnerability_data)

        response_data = await self._summarize(
            wazuh_server, {"cvss_threshold": 8.0, "group_by": "severity"}
        )

        # Should only include vulnerabilities with CVSS >= 8.0
        # From mock data: apache2 (9.3), kernel (8.8) - should be 2 total
        assert response_data["summary"]["total_vulnerabilities"] == 2

    async def test_exploitability_filtering(self, wazuh_server, mock_agents_data, mock_vulnerability_data):
        """Test filtering by exploitability."""
        self._stub_api(wazuh_server, mock_agents_data, mock_vulnerability_data)

        response_data = await self._summarize(
            wazuh_server, {"exploitability": True, "group_by": "severity"}
        )

        # Should only include exploitable vulnerabilities
        # From mock data: apache2 (exploit_available: True),
        # kernel (exploit_available: True, metasploit_module: True)
        assert response_data["summary"]["total_vulnerabilities"] == 2

    async def test_os_filtering(self, wazuh_server, mock_agents_data, mock_vulnerability_data):
        """Test filtering by operating system."""
        self._stub_api(wazuh_server, mock_agents_data, mock_vulnerability_data)

        response_data = await self._summarize(
            wazuh_server, {"os_filter": ["ubuntu"], "group_by": "severity"}
        )

        # Should only analyze ubuntu agents (agent_001)
        assert response_data["summary"]["total_agents_analyzed"] == 1
        assert response_data["summary"]["total_vulnerabilities"] == 2  # apache2, openssl

    async def test_risk_analytics(self, wazuh_server, mock_agents_data, mock_vulnerability_data):
        """Test risk analytics functionality."""
        self._stub_api(wazuh_server, mock_agents_data, mock_vulnerability_data)

        response_data = await self._summarize(
            wazuh_server, {"include_analytics": True, "group_by": "severity"}
        )
        analytics = response_data["risk_analytics"]

        assert "cvss_analysis" in analytics
        assert "severity_distribution" in analytics
        assert "exploitability_analysis" in analytics
        assert "risk_concentration" in analytics

        # Check CVSS analysis
        cvss_analysis = analytics["cvss_analysis"]
        assert "mean_score" in cvss_analysis
        assert "median_score" in cvss_analysis
        assert cvss_analysis["mean_score"] > 0

        # Check exploitability analysis
        exploit_analysis = analytics["exploitability_analysis"]
        assert "total_exploitable" in exploit_analysis
        assert exploit_analysis["total_exploitable"] == 2  # apache2, kernel

    async def test_remediation_recommendations(self, wazuh_server, mock_agents_data, mock_vulnerability_data):
        """Test remediation recommendations."""
        self._stub_api(wazuh_server, mock_agents_data, mock_vulnerability_data)

        response_data = await self._summarize(
            wazuh_server, {"include_remediation": True, "group_by": "severity"}
        )
        remediation = response_data["remediation"]

        assert "immediate_actions" in remediation
        assert "patch_priorities" in remediation
        assert "system_hardening" in remediation
        assert "monitoring_recommendations" in remediation

        # Should have immediate actions for critical vulnerabilities
        immediate_actions = remediation["immediate_actions"]
        assert len(immediate_actions) > 0

        # At least one action's text must mention "critical" (case-insensitive)
        critical_action = next(
            (
                action
                for action in immediate_actions
                if "critical" in action.get("action", "").lower()
            ),
            None,
        )
        assert critical_action is not None

    async def test_empty_vulnerability_response(self, wazuh_server, mock_agents_data):
        """Test handling of empty vulnerability response."""
        # No mapping given: every agent reports an empty vulnerability list.
        self._stub_api(wazuh_server, mock_agents_data)

        response_data = await self._summarize(wazuh_server, {"cvss_threshold": 0.0})

        assert response_data["summary"]["total_vulnerabilities"] == 0
        assert "No vulnerabilities found" in response_data["summary"]["message"]

    async def test_agent_error_handling(self, wazuh_server, mock_agents_data):
        """Test handling of agent-specific errors."""
        self._stub_api(wazuh_server, mock_agents_data)

        def mock_get_vulnerabilities_with_error(agent_id):
            if agent_id == "001":
                raise Exception("API Error for agent 001")
            return {"data": {"affected_items": []}}

        # The side effect takes precedence over the return value set by _stub_api.
        wazuh_server.api_client.get_agent_vulnerabilities.side_effect = mock_get_vulnerabilities_with_error

        response_data = await self._summarize(wazuh_server, {"cvss_threshold": 0.0})

        # Should have error recorded but not fail completely
        assert response_data["summary"]["agents_with_errors"] == 1
        assert "processing_errors" in response_data
        assert len(response_data["processing_errors"]) == 1

    async def test_large_agent_batch_processing(self, wazuh_server):
        """Test batch processing with many agents."""
        # 25 agents; the original test notes this exceeds a batch size of 10
        # (the batch size itself is defined in the handler, not in this file).
        large_agent_list = [
            {
                "id": f"{i:03d}",
                "name": f"agent-{i:03d}",
                "status": "active",
                "os": {"platform": "ubuntu", "version": "20.04"},
            }
            for i in range(25)
        ]
        self._stub_api(wazuh_server, large_agent_list)

        response_data = await self._summarize(
            wazuh_server, {"cvss_threshold": 0.0, "max_agents": 25}
        )

        # Should process all 25 agents
        assert response_data["summary"]["total_agents_analyzed"] == 25
        assert response_data["summary"]["coverage_percentage"] == 100.0

        # One vulnerability lookup per agent
        assert wazuh_server.api_client.get_agent_vulnerabilities.call_count == 25
class TestVulnerabilitySummaryHelperMethods:
    """Unit checks for the private helpers behind the vulnerability summary."""

    @pytest.fixture
    def wazuh_server(self):
        """Build a bare WazuhMCPServer with WazuhConfig patched out."""
        with patch('wazuh_mcp_server.main.WazuhConfig') as mock_config:
            mock_config.from_env.return_value = MagicMock()
            fake_config = mock_config.from_env.return_value
            fake_config.log_level = "INFO"
            fake_config.debug = False
            return WazuhMCPServer()

    def test_extract_cvss_score(self, wazuh_server):
        """_extract_cvss_score copes with each supported record shape."""
        expectations = (
            ({"cvss3_score": 8.5}, 8.5),    # CVSS3 score
            ({"cvss2_score": "7.2"}, 7.2),  # CVSS2 score, supplied as a string
            ({"severity": "high"}, 7.5),    # falls back to the severity label
            ({}, 0.0),                      # no score information at all
        )
        for record, expected_score in expectations:
            assert wazuh_server._extract_cvss_score(record) == expected_score

    def test_map_cvss_to_severity(self, wazuh_server):
        """_map_cvss_to_severity returns the expected label for each sample score."""
        expectations = (
            (9.5, "critical"),
            (8.0, "high"),
            (5.5, "medium"),
            (2.0, "low"),
        )
        for score, expected_label in expectations:
            assert wazuh_server._map_cvss_to_severity(score) == expected_label

    def test_has_known_exploit(self, wazuh_server):
        """_has_known_exploit recognises each exploit indicator and its absence."""
        expectations = (
            # explicit exploit flag
            ({"exploit_available": True}, True),
            # metasploit module flag
            ({"metasploit_module": True}, True),
            # exploit link among the references
            ({"references": ["https://www.exploit-db.com/exploits/12345"]}, True),
            # no exploit indicators
            ({"exploit_available": False, "references": []}, False),
        )
        for record, expected_flag in expectations:
            assert wazuh_server._has_known_exploit(record) is expected_flag

    def test_filter_vulnerabilities(self, wazuh_server):
        """_filter_vulnerabilities honours the CVSS threshold and the package filter."""
        from wazuh_mcp_server.utils.validation import VulnerabilitySummaryQuery

        records = [
            {"cve": "CVE-2024-1234", "name": "apache2", "cvss3_score": 9.0},
            {"cve": "CVE-2024-5678", "name": "nginx", "cvss3_score": 6.0},
            {"cve": "CVE-2024-9999", "name": "kernel", "cvss3_score": 4.0},
        ]

        # Threshold 7.0 leaves only the 9.0-scored record.
        by_threshold = wazuh_server._filter_vulnerabilities(
            records, VulnerabilitySummaryQuery(cvss_threshold=7.0)
        )
        assert len(by_threshold) == 1
        assert by_threshold[0]["cve"] == "CVE-2024-1234"

        # The filter term "apache" is expected to select the "apache2" record.
        by_package = wazuh_server._filter_vulnerabilities(
            records, VulnerabilitySummaryQuery(package_filter=["apache"])
        )
        assert len(by_package) == 1
        assert by_package[0]["name"] == "apache2"

    def test_calculate_exploitation_risk_score(self, wazuh_server):
        """The exploitation risk score is a positive float within 0..100."""
        from collections import Counter

        score = wazuh_server._calculate_exploitation_risk_score(
            Counter({"critical": 2, "high": 5, "medium": 10}),
            Counter({"critical": 1, "high": 2}),
        )

        assert 0 <= score <= 100
        assert isinstance(score, float)
        # Exploitable vulnerabilities are present, so the score must be non-zero.
        assert score > 0
@pytest.mark.asyncio
class TestVulnerabilitySummaryEdgeCases:
    """Test edge cases and error conditions."""

    @pytest.fixture
    def wazuh_server(self):
        """Create a WazuhMCPServer with a mocked config and a mocked API client.

        This is deliberately a synchronous fixture. Nothing in it is awaited,
        and an ``async def`` fixture registered with plain ``@pytest.fixture``
        is only resolved by pytest-asyncio in auto mode; in strict mode the
        tests would receive an un-awaited coroutine instead of the server.
        """
        with patch('wazuh_mcp_server.main.WazuhConfig') as mock_config:
            config = MagicMock()
            config.log_level = "INFO"
            config.debug = False
            config.request_timeout_seconds = 30
            mock_config.from_env.return_value = config
            server = WazuhMCPServer()
            server.api_client = AsyncMock()
            return server

    @staticmethod
    async def _summarize(server, arguments):
        """Run the summary handler and return its first item's text parsed as JSON."""
        result = await server._handle_get_wazuh_vulnerability_summary(arguments)
        return json.loads(result[0].text)

    async def test_malformed_vulnerability_data(self, wazuh_server):
        """Test handling of malformed vulnerability data."""
        wazuh_server.api_client.get_agents.return_value = {
            "data": {"affected_items": [{"id": "001", "name": "test-agent", "status": "active"}]}
        }

        # Mock response with malformed data
        malformed_vulns = [
            {"cve": "CVE-2024-001"},  # Missing required fields
            {"invalid": "structure"},  # Completely wrong structure
            {
                "cve": "CVE-2024-002",
                "cvss3_score": "not_a_number",
                "severity": "invalid_severity",
            },
        ]
        wazuh_server.api_client.get_agent_vulnerabilities.return_value = {
            "data": {"affected_items": malformed_vulns}
        }

        # Should not crash and should handle gracefully
        response_data = await self._summarize(wazuh_server, {"cvss_threshold": 0.0})

        assert "summary" in response_data
        assert response_data["summary"]["total_vulnerabilities"] >= 0

    async def test_no_agents_found(self, wazuh_server):
        """Test handling when no agents are found."""
        wazuh_server.api_client.get_agents.return_value = {
            "data": {"affected_items": []}
        }

        response_data = await self._summarize(wazuh_server, {"cvss_threshold": 0.0})

        assert response_data["summary"]["total_agents_analyzed"] == 0
        assert response_data["summary"]["total_vulnerabilities"] == 0

    async def test_agent_api_failure(self, wazuh_server):
        """Test handling when agent API fails."""
        wazuh_server.api_client.get_agents.side_effect = Exception("API Failure")

        arguments = {"cvss_threshold": 0.0}
        # Only asserts that some exception surfaces; the concrete type the
        # handler raises is not visible from this file.
        with pytest.raises(Exception):
            await wazuh_server._handle_get_wazuh_vulnerability_summary(arguments)

    async def test_partial_agent_failures(self, wazuh_server):
        """Test handling partial agent failures gracefully."""
        agents = [
            {"id": "001", "name": "agent-001", "status": "active"},
            {"id": "002", "name": "agent-002", "status": "active"},
            {"id": "003", "name": "agent-003", "status": "active"},
        ]
        wazuh_server.api_client.get_agents.return_value = {
            "data": {"affected_items": agents}
        }

        def failing_vulnerability_fetch(agent_id):
            # Two of the three agents fail; the third returns no vulnerabilities.
            if agent_id in ["001", "002"]:
                raise Exception("Agent API failure")
            return {"data": {"affected_items": []}}

        wazuh_server.api_client.get_agent_vulnerabilities.side_effect = failing_vulnerability_fetch

        response_data = await self._summarize(wazuh_server, {"cvss_threshold": 0.0})

        # Should process successfully with partial results
        assert response_data["summary"]["total_agents_analyzed"] == 3
        assert response_data["summary"]["agents_with_errors"] == 2
        assert response_data["summary"]["coverage_percentage"] == 100.0  # All agents attempted
        assert len(response_data["processing_errors"]) == 2

    async def test_performance_with_many_vulnerabilities(self, wazuh_server):
        """Test performance handling with large vulnerability datasets."""
        # Single agent with many vulnerabilities
        agent = {"id": "001", "name": "vuln-heavy-agent", "status": "active"}

        # 1000 vulnerabilities spread over 50 package names, scores 5.0..9.9
        many_vulns = [
            {
                "cve": f"CVE-2024-{i:04d}",
                "name": f"package-{i % 50}",  # 50 different packages
                "cvss3_score": 5.0 + (i % 50) / 10,  # Varying scores
                "severity": "medium",
            }
            for i in range(1000)
        ]

        wazuh_server.api_client.get_agents.return_value = {
            "data": {"affected_items": [agent]}
        }
        wazuh_server.api_client.get_agent_vulnerabilities.return_value = {
            "data": {"affected_items": many_vulns}
        }

        response_data = await self._summarize(
            wazuh_server, {"cvss_threshold": 0.0, "group_by": "package"}
        )

        # Should handle large dataset successfully
        assert response_data["summary"]["total_vulnerabilities"] == 1000
        assert "analysis_metadata" in response_data
        assert "processing_time_seconds" in response_data["analysis_metadata"]

        # Check grouping worked correctly
        groups = response_data["grouped_analysis"]["groups"]
        assert len(groups) <= 50  # Should have <= 50 package groups