Amazon Security Lake MCP Server

by kebabmane
query_builder.py
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple

import structlog
from dateutil import parser as date_parser

from ..aws.table_discovery import SecurityLakeTableDiscovery

logger = structlog.get_logger(__name__)


class OCSFQueryBuilder:
    def __init__(self, database: str, aws_region: str, aws_profile: Optional[str] = None):
        self.database = database
        self.aws_region = aws_region
        self.aws_profile = aws_profile
        self.table_discovery = SecurityLakeTableDiscovery(aws_region, aws_profile)
        self._available_tables: Optional[Dict[str, List[str]]] = None

    def _get_available_tables(self) -> Dict[str, List[str]]:
        """Get available tables, cached for performance"""
        if self._available_tables is None:
            self._available_tables = self.table_discovery.discover_security_lake_tables(self.database)
        return self._available_tables

    def _get_table_for_query(self, preferred_sources: List[str]) -> Optional[Tuple[str, str, str]]:
        """
        Get the best table for a query based on preferred data sources

        Returns:
            Tuple of (table_name, data_source, ocsf_version) or None
        """
        available_tables = self._get_available_tables()

        for source in preferred_sources:
            if source in available_tables:
                table_name = self.table_discovery.get_best_table_for_data_source(source, available_tables)
                if table_name:
                    ocsf_version = self.table_discovery.get_ocsf_version_from_table(table_name)
                    return table_name, source, ocsf_version

        return None

    def build_ip_search_query(
        self,
        ip_address: str,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        sources: Optional[List[str]] = None,
        limit: int = 100
    ) -> Tuple[str, List[str]]:
        if not start_time:
            start_time = (datetime.utcnow() - timedelta(days=7)).isoformat()
        if not end_time:
            end_time = datetime.utcnow().isoformat()

        start_ts = self._parse_timestamp(start_time)
        end_ts = self._parse_timestamp(end_time)

        # Determine which sources to search
        if sources:
            search_sources = sources
        else:
            # Auto-select sources good for IP searches
            available_tables = self._get_available_tables()
            search_sources = self.table_discovery.auto_select_data_sources('ip_search', available_tables)

        if not search_sources:
            raise ValueError("No suitable data sources available for IP search")

        # Build union query across available sources
        union_queries = []
        parameters = []

        for source in search_sources:
            table_info = self._get_table_for_query([source])
            if not table_info:
                continue

            table_name, data_source, ocsf_version = table_info

            # Build source-specific query
            source_query, source_params = self._build_ip_query_for_source(
                table_name, data_source, ocsf_version, ip_address, start_ts, end_ts
            )

            if source_query:
                union_queries.append(source_query)
                parameters.extend(source_params)

        if not union_queries:
            raise ValueError("No available tables found for IP search")

        # Combine queries with UNION ALL; wrap even a single query so that
        # ORDER BY and LIMIT are also applied in the single-source case
        if len(union_queries) == 1:
            inner_query = union_queries[0]
        else:
            inner_query = ' UNION ALL '.join(union_queries)

        final_query = f"""
        SELECT * FROM (
            {inner_query}
        )
        ORDER BY time DESC
        LIMIT {limit}
        """

        logger.info(
            "Built dynamic IP search query",
            ip_address=ip_address,
            sources_used=search_sources,
            tables_queried=len(union_queries)
        )

        return final_query.strip(), parameters
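    # Illustrative usage (values here are hypothetical, not part of this file):
    #
    #   builder = OCSFQueryBuilder("amazon_security_lake_glue_db_us_east_1", "us-east-1")
    #   sql, params = builder.build_ip_search_query("203.0.113.10", limit=50)
    #
    # `sql` contains ? placeholders; `params` carries the matching src/dst IP
    # values (one pair per UNION ALL branch) for parameterized execution.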
    def _build_ip_query_for_source(
        self,
        table_name: str,
        data_source: str,
        ocsf_version: str,
        start_time: str,
        end_time: str
    ) -> Tuple[str, List[str]]:
        """Build IP search query for a specific data source and OCSF version"""

        # Base fields that work across OCSF versions
        base_fields = [
            "time",
            "type_name",
            "severity",
            "activity_name"
        ]

        # Version-specific field mappings
        if ocsf_version == "2.0":
            # OCSF 2.0 uses nested structures
            source_specific_fields = self._get_ocsf_v2_fields(data_source)
        else:
            # OCSF 1.0 uses flatter structure
            source_specific_fields = self._get_ocsf_v1_fields(data_source)

        all_fields = base_fields + source_specific_fields

        # Build WHERE conditions for IP search
        ip_conditions = self._get_ip_conditions_for_source(data_source, ocsf_version)

        # Use the fixed timestamp filter - CRITICAL FIX
        time_filter = self._build_time_filter(start_time, end_time)

        where_conditions = [time_filter, f"({ip_conditions})"]
        parameters = [ip_address, ip_address]  # For src and dst IP parameters

        query = f"""
        SELECT {", ".join(all_fields)},
               '{data_source}' as data_source
        FROM "{self.database}"."{table_name}"
        WHERE {" AND ".join(where_conditions)}
        """

        return query, parameters

    def _get_ocsf_v2_fields(self, data_source: str) -> List[str]:
        """Get OCSF 2.0 field mappings for a data source"""
        common_fields = [
            "src_endpoint.ip as src_ip",
            "dst_endpoint.ip as dst_ip",
            "src_endpoint.port as src_port",
            "dst_endpoint.port as dst_port",
            "metadata.product.name as product_name",
            "metadata.product.vendor_name as vendor_name",
            "cloud.account.uid as account_id",
            "cloud.region as region"
        ]

        # Add source-specific fields
        if data_source == 'security_hub':
            common_fields.extend([
                "finding_info.title as finding_title",
                "finding_info.desc as finding_description",
                "compliance.control as compliance_control"
            ])
        elif data_source == 'vpc_flow':
            common_fields.extend([
                "connection_info.protocol_num as protocol",
                "traffic.bytes as bytes_transferred",
                "traffic.packets as packet_count"
            ])
        elif data_source == 'cloudtrail':
            common_fields.extend([
                "api.operation as api_operation",
                "api.service.name as service_name",
                "actor.user.name as user_name"
            ])

        return common_fields

    def _get_ocsf_v1_fields(self, data_source: str) -> List[str]:
        """Get OCSF 1.0 field mappings for a data source"""
        # OCSF 1.0 typically has flatter structure
        common_fields = [
            "src_endpoint_ip as src_ip",
            "dst_endpoint_ip as dst_ip",
            "src_endpoint_port as src_port",
            "dst_endpoint_port as dst_port",
            "metadata_product_name as product_name",
            "metadata_product_vendor_name as vendor_name",
            "cloud_account_uid as account_id",
            "cloud_region as region"
        ]
        return common_fields

    def _get_ip_conditions_for_source(self, data_source: str, ocsf_version: str) -> str:
        """Get IP search conditions based on data source and OCSF version"""
        if ocsf_version == "2.0":
            # OCSF 2.0 nested structure
            return "src_endpoint.ip = ? OR dst_endpoint.ip = ?"
        else:
            # OCSF 1.0 flat structure
            return "src_endpoint_ip = ? OR dst_endpoint_ip = ?"

    def build_guardduty_search_query(
        self,
        finding_id: Optional[str] = None,
        severity: Optional[str] = None,
        finding_type: Optional[str] = None,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        limit: int = 100
    ) -> Tuple[str, List[str]]:
        if not start_time:
            start_time = (datetime.utcnow() - timedelta(days=7)).isoformat()
        if not end_time:
            end_time = datetime.utcnow().isoformat()

        start_ts = self._parse_timestamp(start_time)
        end_ts = self._parse_timestamp(end_time)

        # Try to find GuardDuty tables, fall back to Security Hub
        available_tables = self._get_available_tables()
        search_sources = ['guardduty', 'security_hub']  # Prefer GuardDuty, fall back to Security Hub

        table_info = self._get_table_for_query(search_sources)
        if not table_info:
            raise ValueError(
                "No GuardDuty or Security Hub tables found. "
                f"Available sources: {list(available_tables.keys())}"
            )

        table_name, data_source, ocsf_version = table_info

        # Build fields based on OCSF version and data source
        if ocsf_version == "2.0":
            base_fields = self._get_guardduty_v2_fields(data_source)
        else:
            base_fields = self._get_guardduty_v1_fields(data_source)

        # Build WHERE conditions using fixed timestamp filter
        time_filter = self._build_time_filter(start_time, end_time)
        where_conditions = [time_filter]
        parameters = []

        # Add data source specific filters
        if data_source == 'guardduty':
            where_conditions.append("metadata.product.name = 'GuardDuty'")
        elif data_source == 'security_hub':
            # Filter for GuardDuty findings in Security Hub
            if ocsf_version == "2.0":
                where_conditions.append("metadata.product.name = 'Security Hub'")
                where_conditions.append("finding_info.title LIKE '%GuardDuty%'")
            else:
                where_conditions.append("product_name = 'Security Hub'")
                where_conditions.append("finding_title LIKE '%GuardDuty%'")

        # Add specific search filters
        if finding_id:
            if ocsf_version == "2.0":
                where_conditions.append("finding_info.uid = ?")
            else:
                where_conditions.append("finding_uid = ?")
            parameters.append(finding_id)

        if severity:
            where_conditions.append("severity = ?")
            parameters.append(severity)

        if finding_type:
            where_conditions.append("type_name = ?")
            parameters.append(finding_type)

        where_clause = " AND ".join(where_conditions)

        query = f"""
        SELECT {", ".join(base_fields)}
        FROM "{self.database}"."{table_name}"
        WHERE {where_clause}
        ORDER BY time DESC, severity_id DESC
        LIMIT {limit}
        """

        logger.info(
            "Built dynamic GuardDuty search query",
            table_used=table_name,
            data_source=data_source,
            ocsf_version=ocsf_version,
            finding_id=finding_id,
            severity=severity
        )

        return query.strip(), parameters

    def _get_guardduty_v2_fields(self, data_source: str) -> List[str]:
        """Get GuardDuty fields for OCSF 2.0"""
        base_fields = [
            "time",
            "severity",
            "severity_id",
            "type_name",
            "activity_name",
            "src_endpoint.ip as src_ip",
            "dst_endpoint.ip as dst_ip",
            "cloud.account.uid as account_id",
            "cloud.region as region",
            "metadata.product.name as product_name",
            "metadata.product.version as product_version"
        ]

        if data_source == 'guardduty':
            base_fields.extend([
                "finding_info.uid as finding_id",
                "finding_info.title as finding_title",
                "finding_info.desc as finding_description",
                "finding_info.types as finding_types",
                "resources",
                "remediation"
            ])
        elif data_source == 'security_hub':
            base_fields.extend([
                "finding_info.uid as finding_id",
                "finding_info.title as finding_title",
                "finding_info.desc as finding_description",
                "compliance.control as compliance_control",
                "compliance.status as compliance_status",
                "resource.type as resource_type"
            ])

        return base_fields

    def _get_guardduty_v1_fields(self, data_source: str) -> List[str]:
        """Get GuardDuty fields for OCSF 1.0"""
        base_fields = [
            "time",
            "severity",
            "severity_id",
            "type_name",
            "activity_name",
            "src_endpoint_ip as src_ip",
            "dst_endpoint_ip as dst_ip",
            "cloud_account_uid as account_id",
            "cloud_region as region",
            "metadata_product_name as product_name",
            "metadata_product_version as product_version"
        ]

        if data_source == 'guardduty':
            base_fields.extend([
                "finding_uid as finding_id",
                "finding_title as finding_title",
                "finding_desc as finding_description",
                "finding_types as finding_types"
            ])
        elif data_source == 'security_hub':
            base_fields.extend([
                "finding_uid as finding_id",
                "finding_title as finding_title",
                "finding_desc as finding_description"
            ])

        return base_fields

    def build_data_sources_query(self) -> Tuple[str, List[str]]:
        query = f"""
        SELECT
            table_name,
            table_type,
            input_format,
            output_format,
            location,
            num_buckets,
            bucket_count,
            compressed
        FROM information_schema.tables
        WHERE table_schema = '{self.database}'
        AND table_name LIKE '%amazon_security_lake%'
        ORDER BY table_name
        """
        return query.strip(), []

    def build_table_schema_query(self, table_name: str) -> Tuple[str, List[str]]:
        query = f"""
        SELECT
            column_name,
            data_type,
            is_nullable,
            column_default,
            ordinal_position
        FROM information_schema.columns
        WHERE table_schema = '{self.database}'
        AND table_name = ?
        ORDER BY ordinal_position
        """
        return query.strip(), [table_name]

    def _build_from_clause(self, sources: Optional[List[str]] = None) -> str:
        if sources:
            # Map source names to table patterns
            source_to_table_mapping = {
                "guardduty": "amazon_security_lake_table_*_guardduty_*",
                "cloudtrail": "amazon_security_lake_table_*_cloud_trail_*",
                "vpcflow": "amazon_security_lake_table_*_vpc_flow_*",
                "securityhub": "amazon_security_lake_table_*_security_hub_*",
                "route53": "amazon_security_lake_table_*_route53_*"
            }
            # For simplicity, use the main table for now
            # In production, you might want to query specific tables based on sources
            return f"{self.database}.amazon_security_lake_table_us_east_1_cloud_trail_mgmt_1_0"
        else:
            # Default to main Security Lake table
            return f"{self.database}.amazon_security_lake_table_us_east_1_cloud_trail_mgmt_1_0"

    def _parse_timestamp(self, timestamp_str: str) -> str:
        """
        Parse timestamp and format for Athena TIMESTAMP type.
        Critical fix: Ensures proper timestamp format for Iceberg tables.
        """
        try:
            if isinstance(timestamp_str, str):
                dt = date_parser.parse(timestamp_str)
            elif isinstance(timestamp_str, (int, float)):
                # Handle Unix timestamps (seconds or milliseconds)
                if timestamp_str > 1e10:  # Milliseconds
                    dt = datetime.utcfromtimestamp(timestamp_str / 1000)
                else:  # Seconds
                    dt = datetime.utcfromtimestamp(timestamp_str)
            else:
                dt = timestamp_str

            # Convert to UTC and format for Athena TIMESTAMP
            # CRITICAL: Use proper timestamp format for Iceberg tables
            return dt.strftime("%Y-%m-%d %H:%M:%S")
        except Exception as e:
            logger.error("Failed to parse timestamp", timestamp=timestamp_str, error=str(e))
            # Fallback to current time minus 1 hour
            return (datetime.utcnow() - timedelta(hours=1)).strftime("%Y-%m-%d %H:%M:%S")

    def _build_time_filter(self, start_time: Optional[str], end_time: Optional[str]) -> str:
        """
        Build proper timestamp filter for Iceberg tables.
        Critical fix: Prevents TYPE_MISMATCH errors.
        """
        filters = []

        if start_time:
            start_ts = self._parse_timestamp(start_time)
            # CRITICAL: Use TIMESTAMP cast to avoid bigint comparison errors
            filters.append(f"time >= TIMESTAMP '{start_ts}'")

        if end_time:
            end_ts = self._parse_timestamp(end_time)
            # CRITICAL: Use TIMESTAMP cast to avoid bigint comparison errors
            filters.append(f"time <= TIMESTAMP '{end_ts}'")

        return " AND ".join(filters) if filters else "1=1"
    def validate_ip_address(self, ip: str) -> bool:
        import ipaddress
        try:
            ipaddress.ip_address(ip)
            return True
        except ValueError:
            return False

    def sanitize_query_parameter(self, param: str) -> str:
        # Basic SQL injection prevention
        dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "xp_", "sp_"]
        sanitized = param
        for char in dangerous_chars:
            sanitized = sanitized.replace(char, "")
        return sanitized.strip()[:255]  # Limit length
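
A minimal end-to-end sketch of driving the builder, assuming execution through boto3's Athena client; the database name, region, results bucket, and import path below are illustrative placeholders rather than values defined by this file:

import boto3

from query_builder import OCSFQueryBuilder  # hypothetical import path for this sketch

builder = OCSFQueryBuilder("amazon_security_lake_glue_db_us_east_1", "us-east-1")
sql, params = builder.build_ip_search_query("203.0.113.10", limit=50)
# Or, for findings: sql, params = builder.build_guardduty_search_query(severity="High")

athena = boto3.client("athena", region_name="us-east-1")
response = athena.start_query_execution(
    QueryString=sql,
    QueryExecutionContext={"Database": "amazon_security_lake_glue_db_us_east_1"},
    ResultConfiguration={"OutputLocation": "s3://example-athena-results/"},  # placeholder bucket
    ExecutionParameters=params,  # binds the ? placeholders positionally
)
print(response["QueryExecutionId"])

Athena's ExecutionParameters accepts a list of strings that are substituted positionally for the ? placeholders, which is why every builder method returns the query and its parameter list as a pair.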
