from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
import structlog
from ..athena.query_builder import OCSFQueryBuilder
from .base_tool import BaseTool
logger = structlog.get_logger(__name__)
class GuardDutySearchTool(BaseTool):
    """Search tool for GuardDuty findings stored in Security Lake (OCSF rows).

    Builds parameterized Athena queries through OCSFQueryBuilder, then
    post-processes the returned rows into risk-scored finding summaries.
    """

    def __init__(self, settings):
        """Initialize the tool and its query builder.

        Args:
            settings: Application settings object; must expose
                security_lake_database, aws_region and aws_profile
                (and, per execute(), max_query_results — defined elsewhere).
        """
        super().__init__(settings)
        self.query_builder = OCSFQueryBuilder(
            settings.security_lake_database,
            settings.aws_region,
            settings.aws_profile
        )
async def execute(
    self,
    finding_id: Optional[str] = None,
    severity: Optional[str] = None,
    finding_type: Optional[str] = None,
    start_time: Optional[str] = None,
    end_time: Optional[str] = None,
    limit: int = 100
) -> Dict[str, Any]:
    """Search GuardDuty findings in Security Lake and return a formatted response.

    Args:
        finding_id: Exact finding UID to look up; sanitized before querying.
        severity: One of Critical/High/Medium/Low/Informational (case-sensitive).
        finding_type: GuardDuty finding-type filter; sanitized before querying.
        start_time: ISO-8601 window start; defaults to 7 days ago (naive UTC).
        end_time: ISO-8601 window end; defaults to "now" (naive UTC).
        limit: Maximum rows returned, capped at settings.max_query_results.

    Returns:
        A success payload with processed findings plus query/summary metadata,
        or an error payload. Never raises — the final handler converts any
        exception into an error response.

    NOTE(review): _validate_time_range, _format_error_response,
    _format_success_response, self.settings and self.athena_client are
    presumably provided by BaseTool — not visible in this chunk; confirm.
    """
    try:
        # Validate that at least one search parameter is provided
        if not any([finding_id, severity, finding_type]):
            # If no specific parameters, default to recent high/critical findings
            severity = "High"
            start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
        # Validate severity if provided
        if severity:
            valid_severities = ["Critical", "High", "Medium", "Low", "Informational"]
            if severity not in valid_severities:
                return self._format_error_response(
                    "Invalid severity level",
                    f"'{severity}' is not valid. Use one of: {', '.join(valid_severities)}"
                )
        # Validate time range (runs before the defaults below are filled in)
        if not self._validate_time_range(start_time, end_time):
            return self._format_error_response(
                "Invalid time range",
                "start_time must be before or equal to end_time"
            )
        # Set defaults for time range if not provided
        # NOTE(review): datetime.utcnow() is naive and deprecated since 3.12;
        # downstream appears to expect naive-UTC ISO strings — confirm before changing.
        if not start_time:
            start_time = (datetime.utcnow() - timedelta(days=7)).isoformat()
        if not end_time:
            end_time = datetime.utcnow().isoformat()
        # Sanitize finding_id if provided
        if finding_id:
            finding_id = self.query_builder.sanitize_query_parameter(finding_id)
        # Sanitize finding_type if provided
        if finding_type:
            finding_type = self.query_builder.sanitize_query_parameter(finding_type)
        # Enforce reasonable limits
        limit = min(limit, self.settings.max_query_results)
        logger.info(
            "Executing GuardDuty search",
            finding_id=finding_id,
            severity=severity,
            finding_type=finding_type,
            start_time=start_time,
            end_time=end_time,
            limit=limit
        )
        # Build and execute query
        try:
            query, parameters = self.query_builder.build_guardduty_search_query(
                finding_id=finding_id,
                severity=severity,
                finding_type=finding_type,
                start_time=start_time,
                end_time=end_time,
                limit=limit
            )
            results = await self.athena_client.execute_query(query, parameters)
        except ValueError as e:
            # Handle cases where no suitable tables are found
            return self._format_error_response(
                "No GuardDuty data sources available",
                str(e) + ". Consider using Security Hub findings search as an alternative."
            )
        # Process and enrich results
        processed_results = self._process_guardduty_results(results)
        metadata = {
            "query_info": {
                "finding_id": finding_id,
                "severity": severity,
                "finding_type": finding_type,
                "time_range": {
                    "start": start_time,
                    "end": end_time
                },
                "limit": limit
            },
            "summary": self._generate_summary(processed_results)
        }
        return self._format_success_response(processed_results, metadata)
    except Exception as e:
        # Broad catch by design: the tool contract is to return an error
        # payload rather than propagate exceptions to the caller.
        return self._format_error_response(
            "Query execution failed",
            str(e)
        )
def _process_guardduty_results(
self,
results: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
processed = []
for result in results:
# Parse finding types if it's a JSON string
finding_types = result.get("finding_types")
if isinstance(finding_types, str):
try:
import json
finding_types = json.loads(finding_types)
except (json.JSONDecodeError, TypeError):
finding_types = [finding_types] if finding_types else []
elif not isinstance(finding_types, list):
finding_types = []
# Parse resources if it's a JSON string
resources = result.get("resources")
if isinstance(resources, str):
try:
import json
resources = json.loads(resources)
except (json.JSONDecodeError, TypeError):
resources = []
elif not isinstance(resources, list):
resources = []
# Parse remediation if it's a JSON string
remediation = result.get("remediation")
if isinstance(remediation, str):
try:
import json
remediation = json.loads(remediation)
except (json.JSONDecodeError, TypeError):
remediation = {}
elif not isinstance(remediation, dict):
remediation = {}
processed_result = {
"finding_id": result.get("finding_id"),
"title": result.get("finding_title"),
"description": result.get("finding_description"),
"severity": result.get("severity"),
"severity_score": self._map_severity_to_score(result.get("severity_id")),
"event_type": result.get("type_name"),
"activity": result.get("activity_name"),
"timestamp": result.get("time"),
"finding_details": {
"types": finding_types,
"resources": self._extract_resource_info(resources),
"remediation": remediation
},
"network_context": {
"source_ip": result.get("src_ip"),
"destination_ip": result.get("dst_ip")
},
"aws_context": {
"account_id": result.get("account_id"),
"region": result.get("region")
},
"product_info": {
"name": result.get("product_name"),
"version": result.get("product_version")
},
"risk_assessment": self._assess_risk_level(result),
"raw_data": result # Include original data for debugging
}
processed.append(processed_result)
return processed
def _extract_resource_info(self, resources: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
resource_info = []
for resource in resources:
if isinstance(resource, dict):
info = {
"type": resource.get("type"),
"uid": resource.get("uid"),
"name": resource.get("name"),
"region": resource.get("region")
}
resource_info.append(info)
return resource_info
def _map_severity_to_score(self, severity_id: Any) -> int:
# Map OCSF severity IDs to numeric scores
severity_mapping = {
1: 10, # Informational
2: 25, # Low
3: 50, # Medium
4: 75, # High
5: 90, # Critical
99: 0 # Unknown
}
try:
return severity_mapping.get(int(severity_id), 0)
except (ValueError, TypeError):
return 0
def _assess_risk_level(self, result: Dict[str, Any]) -> Dict[str, Any]:
risk_factors = []
risk_score = 0
# Severity contribution
severity = result.get("severity", "").lower()
if severity == "critical":
risk_score += 40
risk_factors.append("Critical severity finding")
elif severity == "high":
risk_score += 30
risk_factors.append("High severity finding")
elif severity == "medium":
risk_score += 20
risk_factors.append("Medium severity finding")
# Network context contribution
if result.get("src_ip") or result.get("dst_ip"):
risk_score += 10
risk_factors.append("Network activity detected")
# External IP involvement (basic heuristic)
external_ips = []
for ip_field in ["src_ip", "dst_ip"]:
ip = result.get(ip_field)
if ip and not self._is_private_ip(ip):
external_ips.append(ip)
risk_score += 15
if external_ips:
risk_factors.append(f"External IP addresses involved: {', '.join(external_ips)}")
# Determine risk level
if risk_score >= 70:
risk_level = "HIGH"
elif risk_score >= 40:
risk_level = "MEDIUM"
elif risk_score >= 20:
risk_level = "LOW"
else:
risk_level = "INFORMATIONAL"
return {
"level": risk_level,
"score": risk_score,
"factors": risk_factors
}
def _is_private_ip(self, ip: str) -> bool:
try:
import ipaddress
ip_obj = ipaddress.ip_address(ip)
return ip_obj.is_private
except ValueError:
return False
def _generate_summary(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
if not results:
return {
"total_findings": 0,
"message": "No GuardDuty findings found matching the criteria"
}
# Analyze results for summary
severities = {}
finding_types = {}
risk_levels = {}
time_range = {"earliest": None, "latest": None}
affected_accounts = set()
affected_regions = set()
for result in results:
# Count severities
severity = result.get("severity", "unknown")
severities[severity] = severities.get(severity, 0) + 1
# Count finding types
event_type = result.get("event_type", "unknown")
finding_types[event_type] = finding_types.get(event_type, 0) + 1
# Count risk levels
risk_level = result.get("risk_assessment", {}).get("level", "unknown")
risk_levels[risk_level] = risk_levels.get(risk_level, 0) + 1
# Track affected resources
aws_context = result.get("aws_context", {})
if aws_context.get("account_id"):
affected_accounts.add(aws_context["account_id"])
if aws_context.get("region"):
affected_regions.add(aws_context["region"])
# Track time range
timestamp = result.get("timestamp")
if timestamp:
if not time_range["earliest"] or timestamp < time_range["earliest"]:
time_range["earliest"] = timestamp
if not time_range["latest"] or timestamp > time_range["latest"]:
time_range["latest"] = timestamp
return {
"total_findings": len(results),
"breakdown": {
"by_severity": severities,
"by_finding_type": finding_types,
"by_risk_level": risk_levels
},
"affected_resources": {
"account_count": len(affected_accounts),
"region_count": len(affected_regions),
"accounts": list(affected_accounts),
"regions": list(affected_regions)
},
"time_range": time_range,
"highest_risk_findings": len([r for r in results if r.get("risk_assessment", {}).get("level") == "HIGH"]),
"recommendations": self._generate_recommendations(results)
}
def _generate_recommendations(self, results: List[Dict[str, Any]]) -> List[str]:
recommendations = []
high_risk_count = len([r for r in results if r.get("risk_assessment", {}).get("level") == "HIGH"])
critical_count = len([r for r in results if r.get("severity", "").lower() == "critical"])
if critical_count > 0:
recommendations.append(f"Immediate attention required: {critical_count} critical findings detected")
if high_risk_count > 0:
recommendations.append(f"High priority: {high_risk_count} high-risk findings need investigation")
# Check for external IP involvement
external_ip_findings = [
r for r in results
if any(self._is_external_ip_involved(r))
]
if external_ip_findings:
recommendations.append(f"Network security review needed: {len(external_ip_findings)} findings involve external IPs")
if not recommendations:
recommendations.append("Continue monitoring - findings appear to be low risk")
return recommendations
def _is_external_ip_involved(self, result: Dict[str, Any]) -> List[bool]:
network_context = result.get("network_context", {})
return [
network_context.get("source_ip") and not self._is_private_ip(network_context["source_ip"]),
network_context.get("destination_ip") and not self._is_private_ip(network_context["destination_ip"])
]