import ipaddress
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Union

import structlog
from dateutil import parser as date_parser

from ..aws.table_discovery import SecurityLakeTableDiscovery

logger = structlog.get_logger(__name__)


class OCSFQueryBuilder:
    """Builds Athena SQL against OCSF-formatted Security Lake tables.

    Tables are discovered dynamically so that queries adapt to the data
    sources and OCSF schema versions actually present in the Glue database.
    """

    def __init__(self, database: str, aws_region: str, aws_profile: Optional[str] = None):
        self.database = database
        self.aws_region = aws_region
        self.aws_profile = aws_profile
        self.table_discovery = SecurityLakeTableDiscovery(aws_region, aws_profile)
        self._available_tables: Optional[Dict[str, List[str]]] = None
def _get_available_tables(self) -> Dict[str, List[str]]:
"""Get available tables, cached for performance"""
if self._available_tables is None:
self._available_tables = self.table_discovery.discover_security_lake_tables(self.database)
return self._available_tables
def _get_table_for_query(self, preferred_sources: List[str]) -> Optional[Tuple[str, str, str]]:
"""
Get the best table for a query based on preferred data sources
Returns:
Tuple of (table_name, data_source, ocsf_version) or None
"""
available_tables = self._get_available_tables()
for source in preferred_sources:
if source in available_tables:
table_name = self.table_discovery.get_best_table_for_data_source(source, available_tables)
if table_name:
ocsf_version = self.table_discovery.get_ocsf_version_from_table(table_name)
return table_name, source, ocsf_version
return None
    def build_ip_search_query(
        self,
        ip_address: str,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        sources: Optional[List[str]] = None,
        limit: int = 100
    ) -> Tuple[str, List[str]]:
        """
        Build a query that searches for an IP address across available data sources.

        Returns:
            Tuple of (query_string, execution_parameters)
        """
        # Default to the last 7 days when no time range is given
        if not start_time:
            start_time = (datetime.utcnow() - timedelta(days=7)).isoformat()
        if not end_time:
            end_time = datetime.utcnow().isoformat()
        start_ts = self._parse_timestamp(start_time)
        end_ts = self._parse_timestamp(end_time)
# Determine which sources to search
if sources:
search_sources = sources
else:
# Auto-select sources good for IP searches
available_tables = self._get_available_tables()
search_sources = self.table_discovery.auto_select_data_sources('ip_search', available_tables)
if not search_sources:
raise ValueError("No suitable data sources available for IP search")
# Build union query across available sources
union_queries = []
parameters = []
for source in search_sources:
table_info = self._get_table_for_query([source])
if not table_info:
continue
table_name, data_source, ocsf_version = table_info
# Build source-specific query
source_query, source_params = self._build_ip_query_for_source(
table_name, data_source, ocsf_version, ip_address, start_ts, end_ts
)
if source_query:
union_queries.append(source_query)
parameters.extend(source_params)
if not union_queries:
raise ValueError("No available tables found for IP search")
        # Combine queries with UNION ALL; apply ordering and limit in either case
        if len(union_queries) == 1:
            final_query = f"""
            {union_queries[0]}
            ORDER BY time DESC
            LIMIT {limit}
            """
        else:
            final_query = f"""
            SELECT * FROM (
                {' UNION ALL '.join(union_queries)}
            )
            ORDER BY time DESC
            LIMIT {limit}
            """
logger.info(
"Built dynamic IP search query",
ip_address=ip_address,
sources_used=search_sources,
tables_queried=len(union_queries)
)
return final_query.strip(), parameters
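    # Usage sketch (hypothetical values; assumes the Glue database exists and
    # table discovery finds at least one IP-bearing source):
    #
    #   builder = OCSFQueryBuilder("amazon_security_lake_glue_db", "us-east-1")
    #   query, params = builder.build_ip_search_query("10.0.0.5", limit=50)
    #   # `params` supplies one value per '?' placeholder when the query is
    #   # run as a parameterized Athena execution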
def _build_ip_query_for_source(
self,
table_name: str,
data_source: str,
ocsf_version: str,
ip_address: str,
start_time: str,
end_time: str
) -> Tuple[str, List[str]]:
"""Build IP search query for a specific data source and OCSF version"""
# Base fields that work across OCSF versions
base_fields = [
"time",
"type_name",
"severity",
"activity_name"
]
# Version-specific field mappings
if ocsf_version == "2.0":
# OCSF 2.0 uses nested structures
source_specific_fields = self._get_ocsf_v2_fields(data_source)
else:
# OCSF 1.0 uses flatter structure
source_specific_fields = self._get_ocsf_v1_fields(data_source)
all_fields = base_fields + source_specific_fields
        # Build WHERE conditions for the IP search
        ip_conditions = self._get_ip_conditions_for_source(data_source, ocsf_version)
        # TIMESTAMP literals here avoid type mismatches on Iceberg `time` columns
        time_filter = self._build_time_filter(start_time, end_time)
        where_conditions = [time_filter, f"({ip_conditions})"]
        parameters = [ip_address, ip_address]  # One value each for the src and dst '?' placeholders
query = f"""
SELECT {", ".join(all_fields)}, '{data_source}' as data_source
FROM "{self.database}"."{table_name}"
WHERE {" AND ".join(where_conditions)}
"""
return query, parameters
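    # Illustrative shape of a generated per-source query (the table name and
    # field list below are examples; both vary with discovery results and
    # OCSF version):
    #
    #   SELECT time, type_name, severity, activity_name,
    #          src_endpoint.ip as src_ip, ..., 'vpc_flow' as data_source
    #   FROM "my_db"."amazon_security_lake_table_us_east_1_vpc_flow_2_0"
    #   WHERE time >= TIMESTAMP '2024-01-01 00:00:00'
    #     AND time <= TIMESTAMP '2024-01-08 00:00:00'
    #     AND (src_endpoint.ip = ? OR dst_endpoint.ip = ?)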
def _get_ocsf_v2_fields(self, data_source: str) -> List[str]:
"""Get OCSF 2.0 field mappings for a data source"""
common_fields = [
"src_endpoint.ip as src_ip",
"dst_endpoint.ip as dst_ip",
"src_endpoint.port as src_port",
"dst_endpoint.port as dst_port",
"metadata.product.name as product_name",
"metadata.product.vendor_name as vendor_name",
"cloud.account.uid as account_id",
"cloud.region as region"
]
# Add source-specific fields
if data_source == 'security_hub':
common_fields.extend([
"finding_info.title as finding_title",
"finding_info.desc as finding_description",
"compliance.control as compliance_control"
])
elif data_source == 'vpc_flow':
common_fields.extend([
"connection_info.protocol_num as protocol",
"traffic.bytes as bytes_transferred",
"traffic.packets as packet_count"
])
elif data_source == 'cloudtrail':
common_fields.extend([
"api.operation as api_operation",
"api.service.name as service_name",
"actor.user.name as user_name"
])
return common_fields
    def _get_ocsf_v1_fields(self, data_source: str) -> List[str]:
        """Get OCSF 1.0 field mappings for a data source"""
        # OCSF 1.0 uses a flatter structure; no source-specific fields are
        # mapped yet, so data_source is currently unused
        common_fields = [
"src_endpoint_ip as src_ip",
"dst_endpoint_ip as dst_ip",
"src_endpoint_port as src_port",
"dst_endpoint_port as dst_port",
"metadata_product_name as product_name",
"metadata_product_vendor_name as vendor_name",
"cloud_account_uid as account_id",
"cloud_region as region"
]
return common_fields
def _get_ip_conditions_for_source(self, data_source: str, ocsf_version: str) -> str:
"""Get IP search conditions based on data source and OCSF version"""
if ocsf_version == "2.0":
# OCSF 2.0 nested structure
return "src_endpoint.ip = ? OR dst_endpoint.ip = ?"
else:
# OCSF 1.0 flat structure
return "src_endpoint_ip = ? OR dst_endpoint_ip = ?"
    def build_guardduty_search_query(
        self,
        finding_id: Optional[str] = None,
        severity: Optional[str] = None,
        finding_type: Optional[str] = None,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        limit: int = 100
    ) -> Tuple[str, List[str]]:
        """
        Build a query for GuardDuty findings, falling back to Security Hub
        when no native GuardDuty table is available.

        Returns:
            Tuple of (query_string, execution_parameters)
        """
        # Default to the last 7 days when no time range is given
        if not start_time:
            start_time = (datetime.utcnow() - timedelta(days=7)).isoformat()
        if not end_time:
            end_time = datetime.utcnow().isoformat()
        # Prefer native GuardDuty tables; fall back to Security Hub
        available_tables = self._get_available_tables()
        search_sources = ['guardduty', 'security_hub']
table_info = self._get_table_for_query(search_sources)
if not table_info:
raise ValueError(
"No GuardDuty or Security Hub tables found. "
f"Available sources: {list(available_tables.keys())}"
)
table_name, data_source, ocsf_version = table_info
# Build fields based on OCSF version and data source
if ocsf_version == "2.0":
base_fields = self._get_guardduty_v2_fields(data_source)
else:
base_fields = self._get_guardduty_v1_fields(data_source)
        # Apply the TIMESTAMP-literal time filter (avoids Iceberg type mismatches)
        time_filter = self._build_time_filter(start_time, end_time)
where_conditions = [time_filter]
parameters = []
# Add data source specific filters
if data_source == 'guardduty':
where_conditions.append("metadata.product.name = 'GuardDuty'")
elif data_source == 'security_hub':
# Filter for GuardDuty findings in Security Hub
if ocsf_version == "2.0":
where_conditions.append("metadata.product.name = 'Security Hub'")
where_conditions.append("finding_info.title LIKE '%GuardDuty%'")
else:
where_conditions.append("product_name = 'Security Hub'")
where_conditions.append("finding_title LIKE '%GuardDuty%'")
# Add specific search filters
if finding_id:
if ocsf_version == "2.0":
where_conditions.append("finding_info.uid = ?")
else:
where_conditions.append("finding_uid = ?")
parameters.append(finding_id)
if severity:
where_conditions.append("severity = ?")
parameters.append(severity)
if finding_type:
where_conditions.append("type_name = ?")
parameters.append(finding_type)
where_clause = " AND ".join(where_conditions)
query = f"""
SELECT {", ".join(base_fields)}
FROM "{self.database}"."{table_name}"
WHERE {where_clause}
ORDER BY time DESC, severity_id DESC
LIMIT {limit}
"""
logger.info(
"Built dynamic GuardDuty search query",
table_used=table_name,
data_source=data_source,
ocsf_version=ocsf_version,
finding_id=finding_id,
severity=severity
)
return query.strip(), parameters
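    # Usage sketch (hypothetical filter values):
    #
    #   query, params = builder.build_guardduty_search_query(
    #       severity="High",
    #       start_time="2024-01-01T00:00:00Z",
    #   )
    #   # params would be ["High"], matching the single '?' placeholder added
    #   # for the severity filter; the time bounds are inlined as TIMESTAMP
    #   # literals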
def _get_guardduty_v2_fields(self, data_source: str) -> List[str]:
"""Get GuardDuty fields for OCSF 2.0"""
base_fields = [
"time",
"severity",
"severity_id",
"type_name",
"activity_name",
"src_endpoint.ip as src_ip",
"dst_endpoint.ip as dst_ip",
"cloud.account.uid as account_id",
"cloud.region as region",
"metadata.product.name as product_name",
"metadata.product.version as product_version"
]
if data_source == 'guardduty':
base_fields.extend([
"finding_info.uid as finding_id",
"finding_info.title as finding_title",
"finding_info.desc as finding_description",
"finding_info.types as finding_types",
"resources",
"remediation"
])
elif data_source == 'security_hub':
base_fields.extend([
"finding_info.uid as finding_id",
"finding_info.title as finding_title",
"finding_info.desc as finding_description",
"compliance.control as compliance_control",
"compliance.status as compliance_status",
"resource.type as resource_type"
])
return base_fields
def _get_guardduty_v1_fields(self, data_source: str) -> List[str]:
"""Get GuardDuty fields for OCSF 1.0"""
base_fields = [
"time",
"severity",
"severity_id",
"type_name",
"activity_name",
"src_endpoint_ip as src_ip",
"dst_endpoint_ip as dst_ip",
"cloud_account_uid as account_id",
"cloud_region as region",
"metadata_product_name as product_name",
"metadata_product_version as product_version"
]
if data_source == 'guardduty':
            base_fields.extend([
                "finding_uid as finding_id",
                "finding_title",
                "finding_desc as finding_description",
                "finding_types"
            ])
elif data_source == 'security_hub':
            base_fields.extend([
                "finding_uid as finding_id",
                "finding_title",
                "finding_desc as finding_description"
            ])
return base_fields
    def build_data_sources_query(self) -> Tuple[str, List[str]]:
        """Build a query that lists Security Lake tables in the database"""
        # Athena's information_schema.tables exposes only catalog, schema,
        # name, and type; storage details must come from the Glue API instead
        query = f"""
        SELECT
            table_name,
            table_type
        FROM information_schema.tables
        WHERE table_schema = '{self.database}'
            AND table_name LIKE '%amazon_security_lake%'
        ORDER BY table_name
        """
        return query.strip(), []
    def build_table_schema_query(self, table_name: str) -> Tuple[str, List[str]]:
        """Build a parameterized query that returns a table's column schema"""
        query = f"""
SELECT
column_name,
data_type,
is_nullable,
column_default,
ordinal_position
FROM information_schema.columns
WHERE table_schema = '{self.database}'
AND table_name = ?
ORDER BY ordinal_position
"""
return query.strip(), [table_name]
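    # Usage sketch: running the schema query via boto3's parameterized
    # execution (the workgroup name and `athena` client setup are assumptions
    # for illustration):
    #
    #   query, params = builder.build_table_schema_query("my_table")
    #   athena.start_query_execution(
    #       QueryString=query,
    #       ExecutionParameters=params,
    #       QueryExecutionContext={"Database": builder.database},
    #       WorkGroup="primary",
    #   )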
    def _build_from_clause(self, sources: Optional[List[str]] = None) -> str:
        """
        Build a FROM clause for the given sources.

        Note: per-source table selection is not implemented here yet; the
        dynamic discovery in _get_table_for_query supersedes this helper.
        """
        # Map source names to table name patterns (not yet used below)
        source_to_table_mapping = {
            "guardduty": "amazon_security_lake_table_*_guardduty_*",
            "cloudtrail": "amazon_security_lake_table_*_cloud_trail_*",
            "vpcflow": "amazon_security_lake_table_*_vpc_flow_*",
            "securityhub": "amazon_security_lake_table_*_security_hub_*",
            "route53": "amazon_security_lake_table_*_route53_*"
        }
        # For now, default to the main CloudTrail management table regardless
        # of the requested sources
        return f"{self.database}.amazon_security_lake_table_us_east_1_cloud_trail_mgmt_1_0"
    def _parse_timestamp(self, timestamp: Union[str, int, float, datetime]) -> str:
        """
        Parse a timestamp and format it for Athena's TIMESTAMP type.

        Iceberg tables expect 'YYYY-MM-DD HH:MM:SS' literals; other formats
        can trigger TYPE_MISMATCH errors.
        """
        try:
            if isinstance(timestamp, str):
                dt = date_parser.parse(timestamp)
            elif isinstance(timestamp, (int, float)):
                # Handle Unix timestamps (seconds or milliseconds)
                if timestamp > 1e10:  # Milliseconds
                    dt = datetime.utcfromtimestamp(timestamp / 1000)
                else:  # Seconds
                    dt = datetime.utcfromtimestamp(timestamp)
            else:
                dt = timestamp
            # Format as an Athena TIMESTAMP literal
            return dt.strftime("%Y-%m-%d %H:%M:%S")
        except Exception as e:
            logger.error("Failed to parse timestamp", timestamp=timestamp, error=str(e))
            # Fall back to one hour ago
            return (datetime.utcnow() - timedelta(hours=1)).strftime("%Y-%m-%d %H:%M:%S")
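    # Example inputs and the literals they normalize to (illustrative):
    #   "2024-01-15T10:30:00Z"  ->  "2024-01-15 10:30:00"
    #   1705314600              ->  "2024-01-15 10:30:00"  (Unix seconds)
    #   1705314600000           ->  "2024-01-15 10:30:00"  (Unix milliseconds)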
    def _build_time_filter(self, start_time: Optional[str], end_time: Optional[str]) -> str:
        """
        Build a timestamp range filter suitable for Iceberg tables.

        TIMESTAMP literals prevent TYPE_MISMATCH errors when the `time`
        column is a timestamp rather than a bigint.
        """
        filters = []
        if start_time:
            start_ts = self._parse_timestamp(start_time)
            filters.append(f"time >= TIMESTAMP '{start_ts}'")
        if end_time:
            end_ts = self._parse_timestamp(end_time)
            filters.append(f"time <= TIMESTAMP '{end_ts}'")
        # '1=1' keeps a generated WHERE clause valid when no bounds are set
        return " AND ".join(filters) if filters else "1=1"
    def validate_ip_address(self, ip: str) -> bool:
        """Return True if the string is a valid IPv4 or IPv6 address"""
        try:
            ipaddress.ip_address(ip)
            return True
        except ValueError:
            return False
    def sanitize_query_parameter(self, param: str) -> str:
        """
        Strip common SQL injection tokens from a parameter value.

        This is defense in depth only; prefer '?' execution parameters
        wherever the Athena API supports them.
        """
        dangerous_tokens = ["'", '"', ";", "--", "/*", "*/", "xp_", "sp_"]
        sanitized = param
        for token in dangerous_tokens:
            sanitized = sanitized.replace(token, "")
        return sanitized.strip()[:255]  # Limit length
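

if __name__ == "__main__":
    # Minimal demonstration sketch. The database name, region, and IP below
    # are illustrative assumptions, not values defined elsewhere in this
    # module; running it requires AWS credentials with Glue/Athena access.
    builder = OCSFQueryBuilder("amazon_security_lake_glue_db", "us-east-1")
    if builder.validate_ip_address("10.0.0.5"):
        query, params = builder.build_ip_search_query("10.0.0.5", limit=10)
        print(query)
        print(params)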