from datetime import datetime, timedelta, timezone
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple

import structlog

from ..athena.query_builder import OCSFQueryBuilder
from .base_tool import BaseTool
# Module-level structlog logger, named after this module.
logger = structlog.get_logger(__name__)
class UniversalSearchTool(BaseTool):
    """Universal Security Lake search tool that adapts to available data sources.

    The tool discovers which Security Lake tables exist, picks the data
    sources suited to the requested query type and runs the matching search.

    NOTE(review): ``self.athena_client``, ``self._format_success_response`` and
    ``self._format_error_response`` are not defined in this class; presumably
    they are provided by ``BaseTool`` -- confirm.
    """

    # Data sources preferred for each specialised search type.
    _FINDING_SOURCES = ('security_hub', 'guardduty')
    _NETWORK_SOURCES = ('vpc_flow', 'dns', 'route53')

    # Format used when embedding a timestamp in a SQL TIMESTAMP literal.
    _TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, settings):
        """Create the tool and the OCSF query builder it uses.

        Args:
            settings: Object exposing ``security_lake_database``,
                ``aws_region`` and ``aws_profile``.
        """
        super().__init__(settings)
        self.query_builder = OCSFQueryBuilder(
            settings.security_lake_database,
            settings.aws_region,
            settings.aws_profile
        )

    @staticmethod
    def _coerce_limit(limit: Any) -> int:
        """Return ``limit`` as a non-negative int, safe to embed in SQL.

        Raises:
            TypeError, ValueError: If ``limit`` is not integer-like or is
                negative.
        """
        value = int(limit)
        if value < 0:
            raise ValueError(f"limit must be non-negative, got {value}")
        return value

    @classmethod
    def _utc_cutoff(cls, days: int) -> str:
        """Return the UTC time ``days`` days ago as a SQL timestamp string."""
        cutoff = datetime.now(timezone.utc) - timedelta(days=days)
        return cutoff.strftime(cls._TIMESTAMP_FORMAT)

    async def execute(
        self,
        query_type: str,
        filters: Dict[str, Any],
        data_sources: Optional[List[str]] = None,
        limit: int = 100
    ) -> Dict[str, Any]:
        """
        Universal search interface that adapts to available data sources

        Args:
            query_type: Type of search ('findings', 'network', 'api_calls', etc.)
            filters: Search filters (ip_address, severity, time_range, etc.)
            data_sources: Specific data sources to search (auto-detect if None)
            limit: Maximum results to return

        Returns:
            The formatted success or error response. Exceptions are never
            propagated: any failure is logged and converted into an error
            response.
        """
        # Tolerate callers that pass no filters at all.
        filters = filters or {}

        # The limit ends up inside SQL text, so reject anything that is not a
        # plain non-negative integer before building any query.
        try:
            limit = self._coerce_limit(limit)
        except (TypeError, ValueError) as e:
            return self._format_error_response("Invalid limit", str(e))

        try:
            # Discover available tables
            available_tables = self.query_builder._get_available_tables()
            if not available_tables:
                return self._format_error_response(
                    "No Security Lake data sources found",
                    "Please verify Security Lake is enabled and configured in your AWS account"
                )

            # Auto-select data sources if not specified
            if data_sources is None:
                data_sources = self.query_builder.table_discovery.auto_select_data_sources(
                    query_type, available_tables
                )
            if not data_sources:
                return self._format_error_response(
                    f"No suitable data sources for query type '{query_type}'",
                    f"Available sources: {list(available_tables.keys())}"
                )

            logger.info(
                "Starting universal search",
                query_type=query_type,
                data_sources=data_sources,
                filters=filters
            )

            # Execute search based on query type
            if query_type == 'findings':
                return await self._search_findings(filters, data_sources, limit)
            if query_type == 'network':
                return await self._search_network_activity(filters, data_sources, limit)
            if query_type == 'api_calls':
                # No '_search_api_activity' is defined in this class; look it
                # up defensively so a missing handler yields a clear error
                # instead of an AttributeError swallowed by the handler below.
                api_search = getattr(self, '_search_api_activity', None)
                if api_search is None:
                    return self._format_error_response(
                        "API call search is not available",
                        "No '_search_api_activity' handler is implemented for query type 'api_calls'"
                    )
                return await api_search(filters, data_sources, limit)
            if query_type == 'ip_search':
                return await self._search_by_ip(filters, data_sources, limit)
            return await self._generic_search(query_type, filters, data_sources, limit)
        except Exception as e:
            # Top-level boundary: report the failure to the caller, but leave
            # a trace in the logs so it is not lost.
            logger.error(
                "Universal search failed",
                query_type=query_type,
                error=str(e)
            )
            return self._format_error_response(
                "Universal search failed",
                str(e)
            )

    async def _collect_from_sources(
        self,
        sources: List[str],
        query_source: Callable[[str, Dict[str, Any], int], Awaitable[List[Dict[str, Any]]]],
        filters: Dict[str, Any],
        limit: int,
        activity: str
    ) -> Tuple[List[Dict[str, Any]], List[str], List[str]]:
        """Query each source in turn, tolerating per-source failures.

        Args:
            sources: Data sources to query, in order.
            query_source: Coroutine function called as
                ``query_source(source, filters, limit)``.
            filters: Search filters forwarded to ``query_source``.
            limit: Per-source row limit forwarded to ``query_source``.
            activity: Label used in the warning logged when a source fails.

        Returns:
            A tuple ``(results, contributing, failed)``: the concatenated rows,
            the sources that returned at least one row, and the sources whose
            query raised.
        """
        results: List[Dict[str, Any]] = []
        contributing: List[str] = []
        failed: List[str] = []
        for source in sources:
            try:
                source_results = await query_source(source, filters, limit)
            except Exception as e:
                # Best effort: one broken source must not fail the whole search.
                logger.warning(f"Failed to query {source} for {activity}", error=str(e))
                failed.append(source)
                continue
            if source_results:
                results.extend(source_results)
                contributing.append(source)
        return results, contributing, failed

    async def _search_findings(
        self,
        filters: Dict[str, Any],
        data_sources: List[str],
        limit: int
    ) -> Dict[str, Any]:
        """Search for security findings across available sources.

        Only Security Hub and GuardDuty sources are queried. ``sources_used``
        lists the sources that were queried without error; ``sources_failed``
        lists those whose query raised.
        """
        # Prefer Security Hub and GuardDuty for findings
        finding_sources = [s for s in data_sources if s in self._FINDING_SOURCES]
        if not finding_sources:
            return self._format_error_response(
                "No finding data sources available",
                f"Available sources: {data_sources}"
            )

        results, _, failed = await self._collect_from_sources(
            finding_sources, self._query_findings_source, filters, limit, "findings"
        )
        return self._format_success_response(
            results[:limit],
            {
                "query_type": "findings",
                "sources_used": [s for s in finding_sources if s not in failed],
                "sources_failed": failed,
                "total_sources_attempted": len(finding_sources)
            }
        )

    async def _search_network_activity(
        self,
        filters: Dict[str, Any],
        data_sources: List[str],
        limit: int
    ) -> Dict[str, Any]:
        """Search for network activity across available sources.

        Only VPC Flow, DNS and Route53 sources are queried. ``sources_used``
        lists the sources that were queried without error; ``sources_failed``
        lists those whose query raised.
        """
        # Prefer VPC Flow, DNS, Route53 for network activity
        network_sources = [s for s in data_sources if s in self._NETWORK_SOURCES]
        if not network_sources:
            return self._format_error_response(
                "No network data sources available",
                f"Available sources: {data_sources}"
            )

        results, _, failed = await self._collect_from_sources(
            network_sources, self._query_network_source, filters, limit, "network activity"
        )
        return self._format_success_response(
            results[:limit],
            {
                "query_type": "network",
                "sources_used": [s for s in network_sources if s not in failed],
                "sources_failed": failed,
                "total_sources_attempted": len(network_sources)
            }
        )

    async def _search_by_ip(
        self,
        filters: Dict[str, Any],
        data_sources: List[str],
        limit: int
    ) -> Dict[str, Any]:
        """Search by IP address using the query builder's IP search query.

        Requires ``filters['ip_address']``; ``start_time`` and ``end_time`` are
        forwarded when present. A ``ValueError`` raised while building or
        running the query is turned into an error response; other exceptions
        propagate to the caller.
        """
        ip_address = filters.get('ip_address')
        if not ip_address:
            return self._format_error_response(
                "IP address required for IP search",
                "Please provide 'ip_address' in filters"
            )

        try:
            query, parameters = self.query_builder.build_ip_search_query(
                ip_address=ip_address,
                start_time=filters.get('start_time'),
                end_time=filters.get('end_time'),
                sources=data_sources,
                limit=limit
            )
            results = await self.athena_client.execute_query(query, parameters)
            return self._format_success_response(
                results,
                {
                    "query_type": "ip_search",
                    "ip_address": ip_address,
                    "sources_used": data_sources
                }
            )
        except ValueError as e:
            return self._format_error_response(
                "IP search failed",
                str(e)
            )

    async def _generic_search(
        self,
        query_type: str,
        filters: Dict[str, Any],
        data_sources: List[str],
        limit: int
    ) -> Dict[str, Any]:
        """Generic search across all available data sources.

        ``sources_used`` lists the sources that returned at least one row;
        ``sources_failed`` lists those whose query raised.
        """
        results, sources_used, failed = await self._collect_from_sources(
            data_sources, self._query_generic_source, filters, limit, query_type
        )
        return self._format_success_response(
            results[:limit],
            {
                "query_type": query_type,
                "sources_used": sources_used,
                "sources_failed": failed,
                "total_sources_attempted": len(data_sources)
            }
        )

    async def _query_findings_source(
        self,
        source: str,
        filters: Dict[str, Any],
        limit: int
    ) -> List[Dict[str, Any]]:
        """Query a specific source for findings.

        Args:
            source: Data source name to resolve to a table.
            filters: Optional ``severity``, ``start_time`` and ``end_time``.
                When neither time bound is given, the last 7 days are searched.
            limit: Maximum number of rows; must be a non-negative integer.

        Returns:
            The rows returned by the Athena client, or an empty list when no
            table is available for ``source``.

        Raises:
            TypeError, ValueError: If ``limit`` is not a non-negative integer.
        """
        row_limit = self._coerce_limit(limit)

        available_tables = self.query_builder._get_available_tables()
        table_name = self.query_builder.table_discovery.get_best_table_for_data_source(
            source, available_tables
        )
        if not table_name:
            return []

        ocsf_version = self.query_builder.table_discovery.get_ocsf_version_from_table(table_name)

        # Column names differ between the OCSF 2.0 layout and the other one.
        if ocsf_version == "2.0":
            base_fields = [
                "time",
                "severity",
                "type_name",
                "finding_info.title as title",
                "finding_info.desc as description",
                "metadata.product.name as product"
            ]
        else:
            base_fields = [
                "time",
                "severity",
                "type_name",
                "finding_title as title",
                "finding_desc as description",
                "metadata_product_name as product"
            ]

        # Build WHERE conditions
        where_conditions = []
        parameters = []

        severity = filters.get('severity')
        if severity:
            where_conditions.append("severity = ?")
            parameters.append(severity)

        # NOTE(review): the parsed timestamps are interpolated into the SQL
        # text; this presumes _parse_timestamp returns a sanitised, normalised
        # value -- confirm against the query builder.
        start_time = filters.get('start_time')
        end_time = filters.get('end_time')
        if start_time:
            start_ts = self.query_builder._parse_timestamp(start_time)
            where_conditions.append(f"time >= TIMESTAMP '{start_ts}'")
        if end_time:
            end_ts = self.query_builder._parse_timestamp(end_time)
            where_conditions.append(f"time <= TIMESTAMP '{end_ts}'")

        # Add default time range if none specified
        if not start_time and not end_time:
            where_conditions.append(f"time >= TIMESTAMP '{self._utc_cutoff(days=7)}'")

        # At least one time predicate is always present, so the clause is
        # never empty.
        where_clause = " AND ".join(where_conditions)

        query = f"""
            SELECT {", ".join(base_fields)}
            FROM "{self.query_builder.database}"."{table_name}"
            WHERE {where_clause}
            ORDER BY time DESC
            LIMIT {row_limit}
        """
        return await self.athena_client.execute_query(query, parameters)

    async def _query_network_source(
        self,
        source: str,
        filters: Dict[str, Any],
        limit: int
    ) -> List[Dict[str, Any]]:
        """Query a specific source for network activity.

        Not implemented yet: always returns an empty list, so network searches
        currently succeed with no results.
        """
        # Implementation similar to _query_findings_source but for network data
        # This would be customized based on the specific network data source
        return []

    async def _query_generic_source(
        self,
        source: str,
        filters: Dict[str, Any],
        limit: int
    ) -> List[Dict[str, Any]]:
        """Generic query for any data source.

        Fetches rows from the last day (at most 100). If that query fails, it
        falls back to an unfiltered ``SELECT *`` of at most 10 rows.

        NOTE: ``filters`` is accepted for signature parity with the other
        per-source queries but is not applied here.

        Raises:
            TypeError, ValueError: If ``limit`` is not a non-negative integer.
        """
        row_limit = self._coerce_limit(limit)

        available_tables = self.query_builder._get_available_tables()
        table_name = self.query_builder.table_discovery.get_best_table_for_data_source(
            source, available_tables
        )
        if not table_name:
            return []

        # Simple query to get recent data
        query = f"""
            SELECT time, type_name, severity, metadata.product.name as product
            FROM "{self.query_builder.database}"."{table_name}"
            WHERE time >= TIMESTAMP '{self._utc_cutoff(days=1)}'
            ORDER BY time DESC
            LIMIT {min(row_limit, 100)}
        """
        try:
            return await self.athena_client.execute_query(query, [])
        except Exception as e:
            # If the above query fails, try a simpler one -- but record why,
            # so the degraded result can be explained later.
            logger.warning(
                "Generic query failed, falling back to simple query",
                source=source,
                table=table_name,
                error=str(e)
            )
            simple_query = f"""
                SELECT *
                FROM "{self.query_builder.database}"."{table_name}"
                LIMIT {min(row_limit, 10)}
            """
            return await self.athena_client.execute_query(simple_query, [])