from datetime import datetime, timedelta
from typing import Any, Dict, Iterable, List, Optional

import structlog

from ..athena.query_builder import OCSFQueryBuilder
from .base_tool import BaseTool
# Module-level structlog logger, named after this module for filtering.
logger = structlog.get_logger(__name__)
class IPSearchTool(BaseTool):
    """Search Amazon Security Lake OCSF data for events involving an IP address.

    Validates the request, builds a parameterized Athena query via
    ``OCSFQueryBuilder``, executes it through the inherited Athena client,
    and returns processed per-event records plus an aggregate summary.
    All failures are reported through the base class's error-response
    format rather than raised to the caller.
    """

    def __init__(self, settings):
        """Initialize the tool.

        Args:
            settings: Application settings object providing
                ``security_lake_database``, ``aws_region``, ``aws_profile``
                and ``max_query_results``.
        """
        super().__init__(settings)
        self.query_builder = OCSFQueryBuilder(
            settings.security_lake_database,
            settings.aws_region,
            settings.aws_profile
        )

    async def execute(
        self,
        ip_address: str,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        sources: Optional[List[str]] = None,
        limit: int = 100
    ) -> Dict[str, Any]:
        """Run an IP search against Security Lake and return a tool response.

        Args:
            ip_address: IPv4/IPv6 address to search for (validated before use).
            start_time: ISO-8601 start of the search window; defaults to
                7 days before now (UTC) when omitted.
            end_time: ISO-8601 end of the search window; defaults to now (UTC).
            sources: Optional list of Security Lake source names to restrict
                the search to; validated against the allowed set and
                lowercased before use.
            limit: Maximum number of rows to return, capped at
                ``settings.max_query_results``.

        Returns:
            A success response dict with processed results and metadata, or
            an error response dict; never raises.
        """
        try:
            # Reject malformed IPs before they reach query construction.
            if not self.query_builder.validate_ip_address(ip_address):
                return self._format_error_response(
                    "Invalid IP address format",
                    f"'{ip_address}' is not a valid IP address"
                )

            # Reject inverted time windows (validated before defaults are
            # applied, so only caller-supplied values are checked).
            if not self._validate_time_range(start_time, end_time):
                return self._format_error_response(
                    "Invalid time range",
                    "start_time must be before or equal to end_time"
                )

            # Default window: the last 7 days, in naive-UTC ISO format.
            # NOTE(review): datetime.utcnow() is deprecated (3.12+); switching
            # to datetime.now(timezone.utc) would append an offset to the ISO
            # string — confirm the query builder accepts that before changing.
            if not start_time:
                start_time = (datetime.utcnow() - timedelta(days=7)).isoformat()
            if not end_time:
                end_time = datetime.utcnow().isoformat()

            # Validate and normalize the requested data sources.
            if sources:
                allowed_sources = [
                    "guardduty", "cloudtrail", "vpcflow",
                    "securityhub", "route53", "waf"
                ]
                allowed_set = set(allowed_sources)
                invalid_sources = [s for s in sources if s.lower() not in allowed_set]
                if invalid_sources:
                    return self._format_error_response(
                        "Invalid data sources",
                        f"Unknown sources: {', '.join(invalid_sources)}. "
                        f"Allowed sources: {', '.join(allowed_sources)}"
                    )
                sources = [s.lower() for s in sources]

            # Cap the row count at the configured maximum.
            limit = min(limit, self.settings.max_query_results)

            logger.info(
                "Executing IP search",
                ip_address=ip_address,
                start_time=start_time,
                end_time=end_time,
                sources=sources,
                limit=limit
            )

            # Build and execute the Athena query. The builder raises
            # ValueError when no suitable tables exist for the request.
            try:
                query, parameters = self.query_builder.build_ip_search_query(
                    ip_address=ip_address,
                    start_time=start_time,
                    end_time=end_time,
                    sources=sources,
                    limit=limit
                )
                results = await self.athena_client.execute_query(query, parameters)
            except ValueError as e:
                return self._format_error_response(
                    "No suitable data sources available",
                    str(e) + ". Please check if Security Lake is properly configured and has data."
                )

            # Enrich raw rows and attach query metadata plus a summary.
            processed_results = self._process_ip_search_results(results, ip_address)
            metadata = {
                "query_info": {
                    "ip_address": ip_address,
                    "time_range": {
                        "start": start_time,
                        "end": end_time
                    },
                    "sources_requested": sources,
                    "limit": limit
                },
                "summary": self._generate_summary(processed_results, ip_address)
            }
            return self._format_success_response(processed_results, metadata)

        except Exception as e:
            # Tool boundary: convert any unexpected failure into an error
            # response, but log the traceback so it isn't silently lost.
            logger.exception("IP search failed", ip_address=ip_address)
            return self._format_error_response(
                "Query execution failed",
                str(e)
            )

    def _process_ip_search_results(
        self,
        results: List[Dict[str, Any]],
        ip_address: str
    ) -> List[Dict[str, Any]]:
        """Reshape raw query rows into structured, enriched event records.

        Args:
            results: Raw row dicts as returned by the Athena client.
            ip_address: The searched IP, used to derive per-event context.

        Returns:
            One structured dict per input row; missing fields become None.
        """
        processed = []
        for result in results:
            processed_result = {
                "timestamp": result.get("time"),
                "event_type": result.get("type_name"),
                "severity": result.get("severity"),
                "activity": result.get("activity_name"),
                "ip_context": self._determine_ip_context(result, ip_address),
                "network_info": {
                    "source_ip": result.get("src_ip"),
                    "destination_ip": result.get("dst_ip"),
                    "source_port": result.get("src_port"),
                    "destination_port": result.get("dst_port")
                },
                "aws_context": {
                    "account_id": result.get("account_id"),
                    "region": result.get("region")
                },
                "product_info": {
                    "name": result.get("product_name"),
                    "vendor": result.get("vendor_name"),
                    "version": result.get("metadata_version")
                },
                "raw_data": result  # Include original data for debugging
            }
            processed.append(processed_result)
        return processed

    def _determine_ip_context(self, result: Dict[str, Any], search_ip: str) -> Dict[str, str]:
        """Classify the searched IP's role in one event.

        Args:
            result: Raw row dict with optional ``src_ip``/``dst_ip`` keys.
            search_ip: The IP address being searched for.

        Returns:
            ``{"role": ..., "direction": ...}`` — source/outbound when the IP
            matches ``src_ip``, destination/inbound when it matches ``dst_ip``,
            unknown/unknown otherwise.
        """
        src_ip = result.get("src_ip")
        dst_ip = result.get("dst_ip")
        context = {"role": "unknown", "direction": "unknown"}
        if src_ip == search_ip:
            context["role"] = "source"
            context["direction"] = "outbound"
        elif dst_ip == search_ip:
            context["role"] = "destination"
            context["direction"] = "inbound"
        return context

    def _generate_summary(
        self,
        results: List[Dict[str, Any]],
        ip_address: str
    ) -> Dict[str, Any]:
        """Aggregate processed results into counts, time range and highlights.

        Args:
            results: Records produced by ``_process_ip_search_results``.
            ip_address: The searched IP, echoed in the empty-result message.

        Returns:
            Summary dict with total count, per-type/severity/product
            breakdowns, observed time range, most common event type and
            highest observed severity.
        """
        if not results:
            return {
                "total_events": 0,
                "message": f"No events found for IP address {ip_address}"
            }

        event_types = {}
        severities = {}
        products = {}
        time_range = {"earliest": None, "latest": None}

        for result in results:
            # Frequency counts; missing fields bucket under "unknown".
            event_type = result.get("event_type", "unknown")
            event_types[event_type] = event_types.get(event_type, 0) + 1

            severity = result.get("severity", "unknown")
            severities[severity] = severities.get(severity, 0) + 1

            product = result.get("product_info", {}).get("name", "unknown")
            products[product] = products.get(product, 0) + 1

            # Track earliest/latest timestamps (lexicographic comparison —
            # correct for uniform ISO-8601 strings).
            timestamp = result.get("timestamp")
            if timestamp:
                if not time_range["earliest"] or timestamp < time_range["earliest"]:
                    time_range["earliest"] = timestamp
                if not time_range["latest"] or timestamp > time_range["latest"]:
                    time_range["latest"] = timestamp

        return {
            "total_events": len(results),
            "event_breakdown": {
                "by_type": event_types,
                "by_severity": severities,
                "by_product": products
            },
            "time_range": time_range,
            "most_common_event_type": max(event_types.items(), key=lambda x: x[1])[0] if event_types else None,
            "highest_severity": self._get_highest_severity(severities.keys()) if severities else None
        }

    def _get_highest_severity(self, severities: Iterable[str]) -> str:
        """Return the most severe label present, from critical downward.

        Args:
            severities: Severity labels (any case, any iterable — callers
                pass ``dict.keys()``).

        Returns:
            The highest-ranked label found, or "unknown" if none match.
        """
        # Build the lowercased set once instead of on every comparison
        # (the original rebuilt a lowercased list per candidate severity).
        present = {s.lower() for s in severities}
        for severity in ("critical", "high", "medium", "low", "informational", "unknown"):
            if severity in present:
                return severity
        return "unknown"