# test_vulnerability_tools.py
"""Tests for vulnerability tools."""
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, Mock

import pytest

from wazuh_mcp_server.tools.vulnerabilities import VulnerabilityTools
class TestVulnerabilityTools:
    """Test cases for the VulnerabilityTools class.

    Each test receives a fresh ``VulnerabilityTools`` built around a ``Mock``
    server whose ``api_client`` is an ``AsyncMock`` and whose config lists
    agents "001" and "002" as critical. The fixtures are function-scoped
    (pytest's default), so tests that replace private helpers with ``Mock``
    objects on the instance do not affect other tests.
    """

    @pytest.fixture
    def mock_server(self):
        """Create a mock server instance.

        ``api_client`` is an ``AsyncMock`` so the async handlers can await its
        methods; ``config.critical_agents`` holds the agent IDs the risk and
        priority tests below treat as critical.
        """
        server = Mock()
        server.api_client = AsyncMock()
        server.logger = Mock()
        server.config = Mock()
        server.config.critical_agents = ["001", "002"]
        return server

    @pytest.fixture
    def vulnerability_tools(self, mock_server):
        """Create a VulnerabilityTools instance wired to the mock server."""
        return VulnerabilityTools(mock_server)

    def test_tool_definitions(self, vulnerability_tools):
        """Exactly the two vulnerability tools are exposed, by name."""
        tools = vulnerability_tools.tool_definitions
        assert len(tools) == 2

        tool_names = [tool.name for tool in tools]
        assert "get_wazuh_vulnerability_summary" in tool_names
        assert "get_wazuh_critical_vulnerabilities" in tool_names

    def test_handler_mapping(self, vulnerability_tools):
        """Both tool names map to callable handlers."""
        mapping = vulnerability_tools.get_handler_mapping()

        assert "get_wazuh_vulnerability_summary" in mapping
        assert "get_wazuh_critical_vulnerabilities" in mapping

        # Every registered handler must be invocable.
        for handler in mapping.values():
            assert callable(handler)

    @pytest.mark.asyncio
    async def test_handle_vulnerability_summary(self, vulnerability_tools, mock_server):
        """The summary handler returns every report section for a 2-item API response."""
        # Canned API response: one critical and one high vulnerability.
        mock_server.api_client.get_vulnerabilities.return_value = {
            "data": {
                "affected_items": [
                    {
                        "name": "openssl",
                        "version": "1.1.1",
                        "severity": "critical",
                        "cve": "CVE-2023-0001",
                        "agent": {"id": "001", "name": "server1"},
                        "published": "2023-01-01T00:00:00Z",
                        "cvss": {
                            "cvss3": {
                                "score": 9.8,
                                "attackVector": "network"
                            }
                        }
                    },
                    {
                        "name": "nginx",
                        "version": "1.18.0",
                        "severity": "high",
                        "cve": "CVE-2023-0002",
                        "agent": {"id": "002", "name": "server2"},
                        "published": "2023-02-01T00:00:00Z",
                        "cvss": {
                            "cvss3": {
                                "score": 7.5,
                                "attackVector": "network"
                            }
                        }
                    }
                ]
            }
        }

        result = await vulnerability_tools.handle_vulnerability_summary({
            "severity_filter": ["critical", "high"],
            "include_remediation": True
        })

        # Top-level envelope.
        assert result["status"] == "success"
        assert "data" in result

        # Every report section is present.
        data = result["data"]
        assert "overview" in data
        assert data["overview"]["total_vulnerabilities"] == 2
        assert "risk_assessment" in data
        assert "distribution" in data
        assert "affected_systems" in data
        assert "package_analysis" in data
        assert "timeline_analysis" in data
        assert "remediation" in data
        assert "compliance_impact" in data

        # The backend was queried exactly once.
        mock_server.api_client.get_vulnerabilities.assert_called_once()

    @pytest.mark.asyncio
    async def test_handle_critical_vulnerabilities(self, vulnerability_tools, mock_server):
        """The critical-vulnerabilities handler returns every report section."""
        # Canned API response: a single critical kernel vulnerability.
        mock_server.api_client.get_vulnerabilities.return_value = {
            "data": {
                "affected_items": [
                    {
                        "name": "kernel",
                        "version": "5.4.0",
                        "severity": "critical",
                        "cve": "CVE-2024-0001",
                        "agent": {"id": "001", "name": "server1"},
                        "published": "2024-01-01T00:00:00Z",
                        "cvss": {
                            "cvss3": {
                                "score": 9.8,
                                "attackVector": "network"
                            }
                        }
                    }
                ]
            }
        }

        result = await vulnerability_tools.handle_critical_vulnerabilities({
            "include_exploit_data": True,
            "priority_threshold": 80,
            "time_range_days": 30
        })

        assert result["status"] == "success"
        data = result["data"]
        assert "summary" in data
        assert "critical_vulnerabilities" in data
        assert "exploit_intelligence" in data
        assert "attack_vectors" in data
        assert "immediate_actions" in data
        assert "trending_analysis" in data

        # Returned entries should be enriched with priority and exploit data.
        # NOTE(review): this guard lets the test pass vacuously when the list is
        # empty. The fixture's "published" date is fixed (2024-01-01) while
        # time_range_days=30 is presumably relative to "now", so the handler may
        # filter the entry out. Consider a published date relative to now so the
        # enrichment assertions always run.
        if data["critical_vulnerabilities"]:
            vuln = data["critical_vulnerabilities"][0]
            assert "priority_score" in vuln
            assert "exploit_available" in vuln

    def test_calculate_risk_score(self, vulnerability_tools):
        """High-severity findings on critical assets score high; low ones score low."""
        # Critical CVSS 9.8 on a critical agent ("001") in a critical package.
        high_risk_vuln = {
            "severity": "critical",
            "cvss": {
                "cvss3": {
                    "score": 9.8
                }
            },
            "agent": {"id": "001"},  # Critical agent
            "name": "kernel"  # Critical package
        }
        score = vulnerability_tools._calculate_risk_score(high_risk_vuln)
        assert score >= 80  # Should be high risk

        # Low CVSS 2.0 on a non-critical agent in an ordinary package.
        low_risk_vuln = {
            "severity": "low",
            "cvss": {
                "cvss3": {
                    "score": 2.0
                }
            },
            "agent": {"id": "999"},  # Non-critical agent
            "name": "someapp"  # Non-critical package
        }
        score = vulnerability_tools._calculate_risk_score(low_risk_vuln)
        assert score < 50  # Should be low risk

    def test_assess_business_impact(self, vulnerability_tools):
        """Business impact buckets follow the (mocked) risk score."""
        # Stub the risk score so each bucket boundary is exercised deterministically.
        vulnerability_tools._calculate_risk_score = Mock()

        vulnerability_tools._calculate_risk_score.return_value = 90
        assert vulnerability_tools._assess_business_impact({}) == "critical"

        vulnerability_tools._calculate_risk_score.return_value = 70
        assert vulnerability_tools._assess_business_impact({}) == "high"

        vulnerability_tools._calculate_risk_score.return_value = 50
        assert vulnerability_tools._assess_business_impact({}) == "medium"

        vulnerability_tools._calculate_risk_score.return_value = 30
        assert vulnerability_tools._assess_business_impact({}) == "low"

    def test_assess_remediation_complexity(self, vulnerability_tools):
        """Remediation complexity is derived from the package name."""
        # Kernel update - high complexity.
        kernel_vuln = {"name": "linux-kernel"}
        assert vulnerability_tools._assess_remediation_complexity(kernel_vuln) == "high"

        # System library - medium complexity.
        ssl_vuln = {"name": "openssl"}
        assert vulnerability_tools._assess_remediation_complexity(ssl_vuln) == "medium"

        # Application package - low complexity.
        app_vuln = {"name": "myapp"}
        assert vulnerability_tools._assess_remediation_complexity(app_vuln) == "low"

    def test_calculate_priority_score(self, vulnerability_tools):
        """Recent exploitable criticals rank high; old low-severity findings rank low."""
        # Recent critical vulnerability with a known exploit on a critical agent.
        recent_critical = {
            "cvss": {"cvss3": {"score": 9.8}},
            "exploit_available": True,
            "days_since_discovery": 5,
            "agent": {"id": "001"}  # Critical agent
        }
        score = vulnerability_tools._calculate_priority_score(recent_critical)
        assert score >= 90  # Should be very high priority

        # Old, low-severity vulnerability with no exploit on a non-critical agent.
        old_low = {
            "cvss": {"cvss3": {"score": 3.0}},
            "exploit_available": False,
            "days_since_discovery": 200,
            "agent": {"id": "999"}  # Non-critical agent
        }
        score = vulnerability_tools._calculate_priority_score(old_low)
        assert score <= 30  # Should be low priority

    def test_calculate_days_since_discovery(self, vulnerability_tools):
        """Age is computed from 'published', defaulting to 365 when absent."""
        # A timestamp ten days ago, in the "YYYY-MM-DDTHH:MM:SS[.ffffff]Z" form.
        # Use an aware UTC "now" (datetime.utcnow() is deprecated since 3.12)
        # and swap the "+00:00" offset for "Z" to keep the same string format.
        recent_date = datetime.now(timezone.utc) - timedelta(days=10)
        recent_vuln = {"published": recent_date.isoformat().replace("+00:00", "Z")}
        days = vulnerability_tools._calculate_days_since_discovery(recent_vuln)
        assert 9 <= days <= 11  # Allow for some time variance

        # No published date falls back to the default.
        no_date_vuln = {}
        days = vulnerability_tools._calculate_days_since_discovery(no_date_vuln)
        assert days == 365  # Default value

    def test_generate_vulnerability_overview(self, vulnerability_tools):
        """The overview counts totals, severities, and critical-risk items."""
        vulns = [
            {"severity": "critical", "name": "test1"},
            {"severity": "high", "name": "test2"},
            {"severity": "critical", "name": "test3"}
        ]

        # One mocked risk score per vulnerability, in order.
        vulnerability_tools._calculate_risk_score = Mock(side_effect=[90, 70, 85])

        overview = vulnerability_tools._generate_vulnerability_overview(vulns)

        assert overview["total_vulnerabilities"] == 3
        assert overview["severity_distribution"]["critical"] == 2
        assert overview["severity_distribution"]["high"] == 1
        assert "risk_metrics" in overview
        assert overview["risk_metrics"]["critical_risk_count"] == 2

    def test_assess_infrastructure_risk(self, vulnerability_tools):
        """Many high-risk vulnerabilities yield a 'critical' infrastructure level."""
        # Build 15 distinct dicts; ``[{...}] * 15`` would alias a single dict
        # 15 times, so any mutation by the code under test would hit every entry.
        vulns = [{"name": "test"} for _ in range(15)]

        # Every vulnerability scores as high risk.
        vulnerability_tools._calculate_risk_score = Mock(return_value=85)

        risk_assessment = vulnerability_tools._assess_infrastructure_risk(vulns)

        assert risk_assessment["level"] == "critical"
        assert risk_assessment["critical_vulnerabilities"] == 15
        assert "recommendations" in risk_assessment

    def test_analyze_vulnerability_distribution(self, vulnerability_tools):
        """Vulnerabilities are counted per agent ID and per package name."""
        vulns = [
            {"agent": {"id": "001"}, "name": "package1"},
            {"agent": {"id": "001"}, "name": "package2"},
            {"agent": {"id": "002"}, "name": "package1"},
        ]

        distribution = vulnerability_tools._analyze_vulnerability_distribution(vulns)

        assert "by_agent" in distribution
        assert "by_package" in distribution
        assert distribution["by_agent"]["001"] == 2
        assert distribution["by_package"]["package1"] == 2

    def test_analyze_affected_systems(self, vulnerability_tools):
        """Findings on the same agent are grouped into one affected system."""
        vulns = [
            {
                "agent": {"id": "001", "name": "server1"},
                "cve": "CVE-2023-0001",
                "severity": "critical"
            },
            {
                "agent": {"id": "001", "name": "server1"},
                "cve": "CVE-2023-0002",
                "severity": "high"
            }
        ]

        # One mocked risk score per vulnerability, in order.
        vulnerability_tools._calculate_risk_score = Mock(side_effect=[90, 70])

        systems_analysis = vulnerability_tools._analyze_affected_systems(vulns)

        assert systems_analysis["total_affected_systems"] == 1
        assert "001" in systems_analysis["highest_risk_systems"]
        assert systems_analysis["highest_risk_systems"]["001"]["vulnerability_count"] == 2

    def test_compliance_impact_assessment(self, vulnerability_tools):
        """Only high-risk findings in PCI DSS-relevant packages count as PCI risks."""
        vulns = [
            {"name": "apache", "severity": "critical"},  # PCI DSS
            {"name": "nginx", "severity": "high"},  # PCI DSS
            {"name": "mysql", "severity": "medium"}  # PCI DSS
        ]

        # One mocked risk score per vulnerability, in order.
        vulnerability_tools._calculate_risk_score = Mock(side_effect=[90, 70, 50])

        compliance = vulnerability_tools._assess_compliance_impact(vulns)

        assert "compliance_risks" in compliance
        assert compliance["compliance_risks"]["PCI_DSS"] == 2  # High risk apache and nginx
        assert compliance["high_risk_vulnerabilities"] == 2

    @pytest.mark.asyncio
    async def test_error_handling(self, vulnerability_tools, mock_server):
        """An API failure is reported as an error response, not raised."""
        mock_server.api_client.get_vulnerabilities.side_effect = Exception("API Error")

        result = await vulnerability_tools.handle_vulnerability_summary({})

        assert result["status"] == "error"
        assert "error" in result
        assert "API Error" in result["error"]["message"]

    def test_empty_vulnerability_list(self, vulnerability_tools):
        """Empty inputs produce an explanatory overview and a zero, 'low' risk."""
        # NOTE(review): the empty overview is asserted to carry a "total" key,
        # whereas the non-empty overview test reads "total_vulnerabilities";
        # confirm this asymmetry is the intended contract.
        overview = vulnerability_tools._generate_vulnerability_overview([])
        assert overview["total"] == 0
        assert "message" in overview

        risk = vulnerability_tools._assess_infrastructure_risk([])
        assert risk["level"] == "low"
        assert risk["score"] == 0