import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from src.asl_mcp_server.tools.ip_search import IPSearchTool
class TestIPSearchTool:
    """Test cases for IP search functionality"""

    @pytest.fixture
    def ip_search_tool(self, mock_settings):
        """Create IP search tool instance for testing"""
        return IPSearchTool(mock_settings)

    @staticmethod
    def _patch_query_path(tool, rows, ip_address):
        """Build patchers that stub query building and Athena execution.

        Returns an ``(execute_patcher, build_patcher)`` pair to be entered in a
        ``with`` statement:

        * ``tool.athena_client.execute_query`` becomes an ``AsyncMock`` whose
          awaited result is *rows*.
        * ``tool.query_builder.build_ip_search_query`` returns a fixed SQL
          string with *ip_address* bound twice as its parameter list.
        """
        execute_patcher = patch.object(
            tool.athena_client,
            "execute_query",
            new_callable=AsyncMock,
            return_value=rows,  # forwarded to the AsyncMock constructor
        )
        build_patcher = patch.object(
            tool.query_builder,
            "build_ip_search_query",
            return_value=("SELECT * FROM test", [ip_address, ip_address]),
        )
        return execute_patcher, build_patcher

    @pytest.mark.asyncio
    async def test_valid_ip_search(self, ip_search_tool, sample_ip_search_results):
        """Test successful IP search with valid parameters"""
        execute_patcher, build_patcher = self._patch_query_path(
            ip_search_tool, sample_ip_search_results, "192.168.1.100"
        )
        with execute_patcher, build_patcher:
            result = await ip_search_tool.execute(
                ip_address="192.168.1.100",
                start_time="2024-01-15T00:00:00Z",
                end_time="2024-01-15T23:59:59Z",
                limit=100,
            )

        # Counts below assume the sample_ip_search_results fixture (defined
        # outside this file) holds two rows — confirm against conftest.
        assert result["success"] is True
        assert result["count"] == 2
        assert len(result["results"]) == 2

        # Check that results are properly processed
        first_result = result["results"][0]
        assert "timestamp" in first_result
        assert "ip_context" in first_result
        assert "network_info" in first_result
        assert first_result["ip_context"]["role"] == "source"

    @pytest.mark.asyncio
    async def test_invalid_ip_address(self, ip_search_tool):
        """Test IP search with invalid IP address"""
        # Stub Athena so a validation regression can never reach AWS, and so
        # we can verify validation short-circuits before any query is run.
        with patch.object(
            ip_search_tool.athena_client, "execute_query", new_callable=AsyncMock
        ) as mock_execute:
            result = await ip_search_tool.execute(ip_address="invalid-ip")

        assert result["success"] is False
        assert "Invalid IP address format" in result["error"]
        assert result["count"] == 0
        mock_execute.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_invalid_time_range(self, ip_search_tool):
        """Test IP search with invalid time range"""
        with patch.object(
            ip_search_tool.athena_client, "execute_query", new_callable=AsyncMock
        ) as mock_execute:
            result = await ip_search_tool.execute(
                ip_address="192.168.1.1",
                start_time="2024-01-15T23:59:59Z",
                end_time="2024-01-15T00:00:00Z",  # End before start
            )

        assert result["success"] is False
        assert "Invalid time range" in result["error"]
        mock_execute.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_invalid_data_sources(self, ip_search_tool):
        """Test IP search with invalid data sources"""
        with patch.object(
            ip_search_tool.athena_client, "execute_query", new_callable=AsyncMock
        ) as mock_execute:
            result = await ip_search_tool.execute(
                ip_address="192.168.1.1",
                sources=["invalid_source", "another_invalid"],
            )

        assert result["success"] is False
        assert "Invalid data sources" in result["error"]
        mock_execute.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_query_execution_failure(self, ip_search_tool):
        """Test handling of query execution failure"""
        # Only execution is stubbed; the real query builder runs so the
        # failure is injected at the Athena boundary.
        with patch.object(
            ip_search_tool.athena_client, "execute_query", new_callable=AsyncMock
        ) as mock_execute:
            mock_execute.side_effect = Exception("Athena query failed")
            result = await ip_search_tool.execute(ip_address="192.168.1.1")

        # Make sure the reported failure really came from the stubbed call.
        mock_execute.assert_awaited()
        assert result["success"] is False
        assert "Query execution failed" in result["error"]
        assert result["count"] == 0

    @pytest.mark.asyncio
    async def test_empty_results(self, ip_search_tool):
        """Test handling of empty query results"""
        execute_patcher, build_patcher = self._patch_query_path(
            ip_search_tool, [], "192.168.1.1"
        )
        with execute_patcher, build_patcher:
            result = await ip_search_tool.execute(ip_address="192.168.1.1")

        assert result["success"] is True
        assert result["count"] == 0
        assert len(result["results"]) == 0
        assert "No events found" in result["metadata"]["summary"]["message"]

    def test_ip_context_determination(self, ip_search_tool):
        """Test IP context determination logic"""
        target = "192.168.1.1"
        # (row, expected role, expected direction)
        cases = [
            ({"src_ip": target, "dst_ip": "10.0.0.1"}, "source", "outbound"),
            ({"src_ip": "10.0.0.1", "dst_ip": target}, "destination", "inbound"),
            ({"src_ip": "10.0.0.1", "dst_ip": "10.0.0.2"}, "unknown", "unknown"),
        ]
        for row, role, direction in cases:
            context = ip_search_tool._determine_ip_context(row, target)
            assert context["role"] == role, row
            assert context["direction"] == direction, row

    def test_summary_generation(self, ip_search_tool, sample_ip_search_results):
        """Test summary generation for search results"""
        # Process the sample results first
        processed_results = ip_search_tool._process_ip_search_results(
            sample_ip_search_results, "192.168.1.100"
        )
        summary = ip_search_tool._generate_summary(processed_results, "192.168.1.100")

        assert summary["total_events"] == 2
        assert "event_breakdown" in summary
        assert "by_type" in summary["event_breakdown"]
        assert "time_range" in summary
        assert "most_common_event_type" in summary

    def test_highest_severity_detection(self, ip_search_tool):
        """Test highest severity detection logic"""
        # Input labels are mixed-case; the expected result is lower-case.
        cases = [
            (["Low", "High", "Medium"], "high"),
            (["Informational", "Low"], "low"),
            (["Unknown"], "unknown"),
        ]
        for severities, expected in cases:
            highest = ip_search_tool._get_highest_severity(severities)
            assert highest == expected, severities

    @pytest.mark.asyncio
    async def test_valid_sources_filtering(self, ip_search_tool, sample_ip_search_results):
        """Test valid sources filtering"""
        execute_patcher, build_patcher = self._patch_query_path(
            ip_search_tool, sample_ip_search_results, "192.168.1.100"
        )
        with execute_patcher, build_patcher as mock_build_query:
            result = await ip_search_tool.execute(
                ip_address="192.168.1.100",
                sources=["guardduty", "cloudtrail"],
            )

        assert result["success"] is True
        # Verify that sources were passed to the query builder as a keyword.
        mock_build_query.assert_called_once()
        call_kwargs = mock_build_query.call_args.kwargs
        assert "sources" in call_kwargs
        assert call_kwargs["sources"] == ["guardduty", "cloudtrail"]

    @pytest.mark.asyncio
    async def test_limit_enforcement(self, ip_search_tool):
        """Test that query limits are properly enforced"""
        execute_patcher, build_patcher = self._patch_query_path(
            ip_search_tool, [], "192.168.1.1"
        )
        with execute_patcher, build_patcher as mock_build_query:
            # Test with limit larger than max allowed
            await ip_search_tool.execute(
                ip_address="192.168.1.1",
                limit=10000,  # Should be capped to max_query_results
            )

        # Guard first: if the tool rejected the limit instead of capping it,
        # call_args would be None and the lookup below would raise TypeError.
        mock_build_query.assert_called_once()
        call_kwargs = mock_build_query.call_args.kwargs
        assert call_kwargs["limit"] <= ip_search_tool.settings.max_query_results

    def test_result_processing(self, ip_search_tool, sample_ip_search_results):
        """Test detailed result processing logic"""
        processed = ip_search_tool._process_ip_search_results(
            sample_ip_search_results, "192.168.1.100"
        )
        assert len(processed) == 2

        # Check first result structure; report every missing key at once.
        first_result = processed[0]
        required_fields = [
            "timestamp", "event_type", "severity", "activity",
            "ip_context", "network_info", "aws_context", "product_info", "raw_data",
        ]
        missing = [field for field in required_fields if field not in first_result]
        assert not missing, f"missing result fields: {missing}"

        # Check network info structure
        network_info = first_result["network_info"]
        network_fields = ["source_ip", "destination_ip", "source_port", "destination_port"]
        missing = [field for field in network_fields if field not in network_info]
        assert not missing, f"missing network_info fields: {missing}"

        # Check AWS context
        aws_context = first_result["aws_context"]
        assert "account_id" in aws_context
        assert "region" in aws_context