"""
Tests for error handling across all functions.
"""
import pytest
import json
from unittest.mock import AsyncMock, patch
from aiohttp import ClientError, ClientResponseError
from threat_intel_mcp import server
from threat_intel_mcp.config import threat_cache
class TestHTTPErrorHandling:
    """Error-propagation tests for the HTTP helpers ``fetch_url``/``fetch_json``.

    HTTP traffic is intercepted with ``aioresponses``, so no real network
    access happens in these tests.
    """

    @pytest.mark.asyncio
    async def test_fetch_url_timeout(self):
        """A timeout raised by the HTTP layer should propagate to the caller."""
        from aioresponses import aioresponses as aioresponses_ctx
        import asyncio

        with aioresponses_ctx() as mocked:
            # Make the mocked GET raise instead of returning a response.
            mocked.get(
                "http://example.com/slow",
                exception=asyncio.TimeoutError("Timeout"),
            )
            with pytest.raises(asyncio.TimeoutError):
                await server.fetch_url("http://example.com/slow")

    @pytest.mark.asyncio
    async def test_fetch_json_malformed_response(self):
        """A non-JSON body should raise a parse error, not return data."""
        from aioresponses import aioresponses as aioresponses_ctx
        from aiohttp import ContentTypeError

        with aioresponses_ctx() as mocked:
            # Return a non-JSON body with a non-JSON content type.
            mocked.get(
                "http://example.com/api/bad",
                body="not valid json {{{{",
                content_type="text/plain",
            )
            # Expect a parse-related failure rather than any Exception, so an
            # unrelated error (typo, broken mock) cannot make this test pass:
            #   * ContentTypeError - aiohttp's .json() rejects the text/plain
            #     content type;
            #   * ValueError       - covers json.JSONDecodeError when the body
            #     is decoded regardless of content type.
            # NOTE(review): assumes fetch_json does not wrap these errors in a
            # project-specific exception - confirm against server.fetch_json.
            with pytest.raises((ContentTypeError, ValueError)):
                await server.fetch_json("http://example.com/api/bad")
class TestMCPToolErrorHandling:
    """Error-handling tests for the MCP tools exposed by ``server``.

    Each tool is invoked through its ``.fn`` attribute and returns a JSON
    string; the tests decode it and inspect the ``success`` flag and related
    fields.
    """

    @pytest.mark.asyncio
    async def test_fetch_feed_handles_network_error(self, clean_cache):
        """A ``ClientError`` from ``fetch_url`` becomes a failure payload."""
        with patch(
            'threat_intel_mcp.server.fetch_url',
            side_effect=ClientError("Connection failed"),
        ):
            result = await server.fetch_threat_feed.fn("feodo_tracker")
        data = json.loads(result)
        assert data["success"] is False
        assert "Network error" in data["error"]
        assert data["feed"] == "feodo_tracker"

    @pytest.mark.asyncio
    async def test_fetch_feed_handles_unexpected_error(self, clean_cache):
        """A non-network exception is also reported as a failure payload."""
        with patch(
            'threat_intel_mcp.server.fetch_url',
            side_effect=RuntimeError("Unexpected error"),
        ):
            result = await server.fetch_threat_feed.fn("feodo_tracker")
        data = json.loads(result)
        assert data["success"] is False
        assert "error" in data

    @pytest.mark.asyncio
    async def test_check_ip_reputation_handles_api_errors(self, clean_cache):
        """Feed data should still flag the IP when the HTTP session fails."""
        # Seed the cache so the IP is known from a feed.
        threat_cache.set("feed_feodo_tracker", {
            "ips": ["192.0.2.1"],
            "count": 1,
            "type": "ip_list"
        })
        # Replace aiohttp.ClientSession with a mock whose get() raises,
        # simulating a failing external API (VirusTotal per the original
        # comment).
        with patch('aiohttp.ClientSession') as mock_session_class:
            mock_session = AsyncMock()
            mock_session.__aenter__ = AsyncMock(return_value=mock_session)
            mock_session.__aexit__ = AsyncMock()
            mock_session.get = AsyncMock(side_effect=ClientError("API error"))
            mock_session_class.return_value = mock_session
            result = await server.check_ip_reputation.fn("192.0.2.1")
        data = json.loads(result)
        # Should still succeed with feed data.
        assert data["success"] is True
        assert data["is_malicious"] is True

    @pytest.mark.asyncio
    async def test_get_cisa_kev_handles_invalid_dates(self, clean_cache):
        """An entry with an unparseable ``dateAdded`` must not crash the tool."""
        # clean_cache is requested for consistency with the other tool tests;
        # presumably it prevents a previously cached KEV payload from masking
        # the mocked response - verify against the fixture and get_cisa_kev.
        invalid_kev_response = {
            "vulnerabilities": [
                {
                    "cveID": "CVE-2024-0001",
                    "vendorProject": "Test",
                    "product": "Product",
                    "vulnerabilityName": "Test",
                    "dateAdded": "invalid-date",  # Invalid format
                    "shortDescription": "Test"
                }
            ]
        }
        with patch(
            'threat_intel_mcp.server.fetch_json',
            new=AsyncMock(return_value=invalid_kev_response),
        ):
            result = await server.get_cisa_kev.fn(days=30)
        data = json.loads(result)
        # Should not crash, should skip invalid entries.
        assert data["success"] is True
        assert "vulnerabilities" in data

    @pytest.mark.asyncio
    async def test_bulk_check_handles_malformed_input(self):
        """Empty, blank and syntactically broken inputs yield valid payloads."""
        # Empty string
        result = await server.check_bulk_ips.fn("")
        data = json.loads(result)
        assert data["success"] is False

        # Whitespace only
        result = await server.check_bulk_ips.fn(" ")
        data = json.loads(result)
        assert data["success"] is False

        # Invalid JSON: the tool may fall back to comma-separated parsing, so
        # the outcome is not pinned. json.loads above the assertion already
        # proves the response is well-formed JSON; additionally require the
        # standard "success" field so this case checks something.
        result = await server.check_bulk_ips.fn("[invalid json")
        data = json.loads(result)
        assert "success" in data

    @pytest.mark.asyncio
    async def test_check_network_partial_failures(self, clean_cache):
        """A scan mixing invalid and valid device IPs is still processed."""
        threat_cache.set("feed_feodo_tracker", {
            "ips": ["192.0.2.1"],
            "count": 1,
            "type": "ip_list"
        })
        scan_results = json.dumps({
            "devices": [
                {"ip": "invalid-ip", "hostname": "bad"},
                {"ip": "10.0.0.1", "hostname": "good"},
                {"ip": "192.0.2.1", "hostname": "threat"}
            ]
        })
        result = await server.check_network_against_threats.fn(scan_results)
        data = json.loads(result)
        # Should process valid IPs; exact counts depend on validation, so
        # only the overall success flag is asserted.
        assert data["success"] is True
class TestValidationErrorHandling:
    """Checks that invalid tool arguments produce descriptive error payloads."""

    @staticmethod
    def _assert_rejected(raw_response, expected_fragment):
        """Decode a tool response and assert it is a failure mentioning the fragment."""
        payload = json.loads(raw_response)
        assert payload["success"] is False
        assert "error" in payload
        assert expected_fragment in payload["error"]

    @pytest.mark.asyncio
    async def test_ip_validation_error_message(self):
        """An unparseable IP string is rejected with an "Invalid IP" message."""
        raw = await server.check_ip_reputation.fn("not.an.ip.address")
        self._assert_rejected(raw, "Invalid IP")

    @pytest.mark.asyncio
    async def test_hash_validation_error_message(self):
        """A too-short hash is rejected with an "Invalid hash" message."""
        raw = await server.check_hash_reputation.fn("tooshort")
        self._assert_rejected(raw, "Invalid hash")

    @pytest.mark.asyncio
    async def test_ioc_type_validation_error_message(self):
        """An unknown IOC type is rejected with an "Invalid IOC type" message."""
        raw = await server.get_recent_iocs.fn(ioc_type="invalid_type")
        self._assert_rejected(raw, "Invalid IOC type")
class TestCacheErrorRecovery:
    """Sanity tests for cache usability and bounded size."""

    def test_cache_handles_expiry_errors(self, clean_cache):
        """A value stored in the shared cache should be readable afterwards.

        The ``clean_cache`` fixture is requested because this test writes to
        the module-level ``threat_cache``; without it the "test" key would
        leak into whichever tests run afterwards.

        NOTE(review): despite the test name, no expiry corruption is actually
        simulated here - this is only a set/get round trip. TODO: inject a
        corrupted expiry once the cache internals to target are confirmed.
        """
        cache = threat_cache
        cache.set("test", "value")
        # Cache should still be usable.
        assert cache.get("test") == "value"

    def test_cache_handles_concurrent_eviction(self):
        """Inserting beyond ``max_size`` must not crash or exceed the bound."""
        from threat_intel_mcp.config import ThreatCache

        # Private instance, so the shared threat_cache is not touched.
        cache = ThreatCache(max_size=2)
        # Fill cache to capacity.
        cache.set("key1", "value1")
        cache.set("key2", "value2")
        # One more insert triggers eviction.
        cache.set("key3", "value3")
        # Should not crash, and the reported size stays within the bound.
        stats = cache.stats()
        assert stats["size"] <= 2
class TestPartialDataHandling:
    """Covers feeds whose body is empty or contains no usable entries."""

    @staticmethod
    async def _fetch_feodo_with_body(body):
        """Run the feodo_tracker feed fetch with ``fetch_url`` stubbed to return ``body``."""
        stub = AsyncMock(return_value=body)
        with patch('threat_intel_mcp.server.fetch_url', new=stub):
            raw = await server.fetch_threat_feed.fn("feodo_tracker")
        return json.loads(raw)

    @pytest.mark.asyncio
    async def test_feed_with_empty_response(self, clean_cache):
        """An empty feed body is a success with zero entries."""
        payload = await self._fetch_feodo_with_body("")
        assert payload["success"] is True
        assert payload["count"] == 0

    @pytest.mark.asyncio
    async def test_feed_with_all_invalid_entries(self, clean_cache):
        """A feed containing only a comment and invalid IPs yields zero entries."""
        feed_lines = (
            "# All invalid",
            "999.999.999.999",
            "invalid.ip.address",
            "not-an-ip",
        )
        payload = await self._fetch_feodo_with_body("\n".join(feed_lines))
        assert payload["success"] is True
        assert payload["count"] == 0