import pytest
from unittest.mock import AsyncMock, patch
from src.asl_mcp_server.tools.guardduty_search import GuardDutySearchTool
class TestGuardDutySearchTool:
    """Test cases for GuardDuty search functionality.

    Relies on the ``mock_settings`` and ``sample_guardduty_results`` fixtures,
    which are not defined in this class (presumably they live in
    ``conftest.py`` — confirm). Several assertions below assume that
    ``sample_guardduty_results`` holds exactly one finding whose id is
    ``12345abc-def0-1234-5678-90abcdef1234``.
    """

    @pytest.fixture
    def guardduty_tool(self, mock_settings):
        """Create a GuardDuty search tool instance for testing."""
        return GuardDutySearchTool(mock_settings)

    @pytest.mark.asyncio
    async def test_search_by_finding_id(self, guardduty_tool, sample_guardduty_results):
        """Test searching by a specific finding ID."""
        with patch.object(guardduty_tool.athena_client, 'execute_query', new_callable=AsyncMock) as mock_execute, \
                patch.object(guardduty_tool.query_builder, 'build_guardduty_search_query') as mock_build_query:
            mock_execute.return_value = sample_guardduty_results
            mock_build_query.return_value = ("SELECT * FROM test", ["test-finding-id"])

            result = await guardduty_tool.execute(
                finding_id="12345abc-def0-1234-5678-90abcdef1234"
            )

            assert result["success"] is True
            assert result["count"] == 1
            assert len(result["results"]) == 1

            # Check finding-specific fields
            finding = result["results"][0]
            assert finding["finding_id"] == "12345abc-def0-1234-5678-90abcdef1234"
            assert "risk_assessment" in finding
            assert "finding_details" in finding

    @pytest.mark.asyncio
    async def test_search_by_severity(self, guardduty_tool, sample_guardduty_results):
        """Test searching by severity level."""
        with patch.object(guardduty_tool.athena_client, 'execute_query', new_callable=AsyncMock) as mock_execute, \
                patch.object(guardduty_tool.query_builder, 'build_guardduty_search_query') as mock_build_query:
            mock_execute.return_value = sample_guardduty_results
            mock_build_query.return_value = ("SELECT * FROM test", ["High"])

            result = await guardduty_tool.execute(severity="High")

            assert result["success"] is True
            assert result["count"] == 1

            # Verify the severity filter was forwarded to the query builder
            mock_build_query.assert_called_once()
            assert mock_build_query.call_args.kwargs["severity"] == "High"

    @pytest.mark.asyncio
    async def test_invalid_severity(self, guardduty_tool):
        """Test that an unknown severity level is rejected before any query runs."""
        # Patch Athena so that a validation regression cannot reach a real backend.
        with patch.object(guardduty_tool.athena_client, 'execute_query', new_callable=AsyncMock) as mock_execute:
            result = await guardduty_tool.execute(severity="Invalid")

            assert result["success"] is False
            assert "Invalid severity level" in result["error"]
            mock_execute.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_default_search_behavior(self, guardduty_tool, sample_guardduty_results):
        """Test default search behavior when no parameters are provided."""
        with patch.object(guardduty_tool.athena_client, 'execute_query', new_callable=AsyncMock) as mock_execute, \
                patch.object(guardduty_tool.query_builder, 'build_guardduty_search_query') as mock_build_query:
            mock_execute.return_value = sample_guardduty_results
            mock_build_query.return_value = ("SELECT * FROM test", ["High"])

            result = await guardduty_tool.execute()

            assert result["success"] is True

            # Should default to High severity with a recent time range.
            # Assert the call first: call_args is None if the builder was never
            # invoked, which would otherwise surface as an opaque TypeError.
            mock_build_query.assert_called_once()
            call_kwargs = mock_build_query.call_args.kwargs
            assert call_kwargs["severity"] == "High"
            assert call_kwargs["start_time"] is not None

    @pytest.mark.asyncio
    async def test_query_execution_failure(self, guardduty_tool):
        """Test handling of a query execution failure."""
        # Patch the builder too (as the sibling tests do) so the failure under
        # test is isolated to query execution.
        with patch.object(guardduty_tool.athena_client, 'execute_query', new_callable=AsyncMock) as mock_execute, \
                patch.object(guardduty_tool.query_builder, 'build_guardduty_search_query') as mock_build_query:
            mock_build_query.return_value = ("SELECT * FROM test", ["test-id"])
            mock_execute.side_effect = Exception("Database connection failed")

            result = await guardduty_tool.execute(finding_id="test-id")

            assert result["success"] is False
            assert "Query execution failed" in result["error"]
            # The error must come from the execution path, not an earlier step.
            mock_execute.assert_awaited()

    def test_severity_score_mapping(self, guardduty_tool):
        """Test the severity ID to score mapping."""
        assert guardduty_tool._map_severity_to_score(1) == 10  # Informational
        assert guardduty_tool._map_severity_to_score(2) == 25  # Low
        assert guardduty_tool._map_severity_to_score(3) == 50  # Medium
        assert guardduty_tool._map_severity_to_score(4) == 75  # High
        assert guardduty_tool._map_severity_to_score(5) == 90  # Critical
        assert guardduty_tool._map_severity_to_score(99) == 0  # Unknown
        assert guardduty_tool._map_severity_to_score("invalid") == 0

    def test_private_ip_detection(self, guardduty_tool):
        """Test private IP address detection."""
        assert guardduty_tool._is_private_ip("192.168.1.1") is True
        assert guardduty_tool._is_private_ip("10.0.0.1") is True
        assert guardduty_tool._is_private_ip("172.16.0.1") is True
        assert guardduty_tool._is_private_ip("8.8.8.8") is False
        assert guardduty_tool._is_private_ip("203.0.113.1") is False
        assert guardduty_tool._is_private_ip("invalid-ip") is False

    def test_risk_assessment(self, guardduty_tool):
        """Test the risk assessment logic."""
        # High severity with external IP
        result_data = {
            "severity": "High",
            "src_ip": "8.8.8.8",  # External IP
            "dst_ip": "192.168.1.1",
        }
        risk = guardduty_tool._assess_risk_level(result_data)
        assert risk["level"] in ["MEDIUM", "HIGH"]
        assert risk["score"] > 40
        assert "High severity finding" in risk["factors"]
        assert any("External IP" in factor for factor in risk["factors"])

        # Low severity, internal IPs only
        result_data = {
            "severity": "Low",
            "src_ip": "192.168.1.1",
            "dst_ip": "10.0.0.1",
        }
        risk = guardduty_tool._assess_risk_level(result_data)
        assert risk["level"] in ["LOW", "INFORMATIONAL"]

    def test_resource_extraction(self, guardduty_tool):
        """Test resource information extraction."""
        resources = [
            {
                "type": "EC2Instance",
                "uid": "i-1234567890abcdef0",
                "name": "web-server-01",
                "region": "us-east-1",
            },
            {
                "type": "S3Bucket",
                "uid": "my-secure-bucket",
                "name": "my-secure-bucket",
            },
        ]
        extracted = guardduty_tool._extract_resource_info(resources)
        assert len(extracted) == 2
        assert extracted[0]["type"] == "EC2Instance"
        assert extracted[0]["uid"] == "i-1234567890abcdef0"
        assert extracted[1]["type"] == "S3Bucket"

    def test_result_processing(self, guardduty_tool, sample_guardduty_results):
        """Test processing of GuardDuty results."""
        processed = guardduty_tool._process_guardduty_results(sample_guardduty_results)
        assert len(processed) == 1
        result = processed[0]

        # Check required fields
        required_fields = [
            "finding_id", "title", "description", "severity", "severity_score",
            "event_type", "activity", "timestamp", "finding_details",
            "network_context", "aws_context", "product_info", "risk_assessment",
        ]
        for field in required_fields:
            assert field in result

        # Check finding details structure
        finding_details = result["finding_details"]
        assert "types" in finding_details
        assert "resources" in finding_details
        assert "remediation" in finding_details

        # Check that JSON strings were parsed
        assert isinstance(finding_details["types"], list)
        assert isinstance(finding_details["resources"], list)
        assert isinstance(finding_details["remediation"], dict)

    def test_summary_generation(self, guardduty_tool, sample_guardduty_results):
        """Test summary generation for GuardDuty results."""
        processed = guardduty_tool._process_guardduty_results(sample_guardduty_results)
        summary = guardduty_tool._generate_summary(processed)
        assert summary["total_findings"] == 1
        assert "breakdown" in summary
        assert "affected_resources" in summary
        assert "recommendations" in summary

        # Check breakdown structure
        breakdown = summary["breakdown"]
        assert "by_severity" in breakdown
        assert "by_finding_type" in breakdown
        assert "by_risk_level" in breakdown

        # Check affected resources
        affected = summary["affected_resources"]
        assert "account_count" in affected
        assert "region_count" in affected
        assert "accounts" in affected
        assert "regions" in affected

    def test_recommendations_generation(self, guardduty_tool):
        """Test the recommendations generation logic."""
        # High risk findings
        high_risk_results = [
            {"severity": "Critical", "risk_assessment": {"level": "HIGH"}},
            {"severity": "High", "risk_assessment": {"level": "HIGH"}},
        ]
        recommendations = guardduty_tool._generate_recommendations(high_risk_results)
        assert len(recommendations) > 0
        assert any("critical" in rec.lower() for rec in recommendations)
        assert any("high-risk" in rec.lower() for rec in recommendations)

        # Low risk findings
        low_risk_results = [
            {"severity": "Low", "risk_assessment": {"level": "LOW"}},
        ]
        recommendations = guardduty_tool._generate_recommendations(low_risk_results)
        assert any("low risk" in rec.lower() for rec in recommendations)

    def test_external_ip_involvement_detection(self, guardduty_tool):
        """Test detection of external IP involvement.

        ``_is_external_ip_involved`` is indexed as (source_is_external,
        destination_is_external) by the assertions below.
        """
        # Result with external IPs
        result_with_external = {
            "network_context": {
                "source_ip": "8.8.8.8",
                "destination_ip": "192.168.1.1",
            }
        }
        involvement = guardduty_tool._is_external_ip_involved(result_with_external)
        assert involvement[0] is True  # Source IP is external
        assert involvement[1] is False  # Destination IP is private

        # Result with only internal IPs
        result_internal_only = {
            "network_context": {
                "source_ip": "192.168.1.1",
                "destination_ip": "10.0.0.1",
            }
        }
        involvement = guardduty_tool._is_external_ip_involved(result_internal_only)
        assert involvement[0] is False
        assert involvement[1] is False

    @pytest.mark.asyncio
    async def test_parameter_sanitization(self, guardduty_tool):
        """Test that user-supplied input parameters are passed through the sanitizer."""
        raw_finding_id = "potentially'malicious'input"
        raw_finding_type = "another;dangerous;input"

        with patch.object(guardduty_tool.query_builder, 'sanitize_query_parameter') as mock_sanitize, \
                patch.object(guardduty_tool.athena_client, 'execute_query', new_callable=AsyncMock) as mock_execute, \
                patch.object(guardduty_tool.query_builder, 'build_guardduty_search_query') as mock_build_query:
            mock_execute.return_value = []
            mock_build_query.return_value = ("SELECT * FROM test", [])
            mock_sanitize.return_value = "sanitized_value"

            await guardduty_tool.execute(
                finding_id=raw_finding_id,
                finding_type=raw_finding_type,
            )

            # Verify sanitization was called
            assert mock_sanitize.call_count >= 2

            # A bare call count would also pass if unrelated values were
            # sanitized, so check that each raw input actually reached the
            # sanitizer (positionally or by keyword).
            sanitized_inputs = [
                value
                for recorded in mock_sanitize.call_args_list
                for value in (*recorded.args, *recorded.kwargs.values())
            ]
            assert raw_finding_id in sanitized_inputs
            assert raw_finding_type in sanitized_inputs

    @pytest.mark.asyncio
    async def test_empty_results_handling(self, guardduty_tool):
        """Test handling of empty query results."""
        with patch.object(guardduty_tool.athena_client, 'execute_query', new_callable=AsyncMock) as mock_execute, \
                patch.object(guardduty_tool.query_builder, 'build_guardduty_search_query') as mock_build_query:
            mock_execute.return_value = []
            mock_build_query.return_value = ("SELECT * FROM test", [])

            result = await guardduty_tool.execute(finding_id="nonexistent-id")

            assert result["success"] is True
            assert result["count"] == 0
            assert "No GuardDuty findings found" in result["metadata"]["summary"]["message"]