from typing import Any, Dict, List
import structlog
from ..athena.query_builder import OCSFQueryBuilder
from ..aws.table_discovery import SecurityLakeTableDiscovery
from .base_tool import BaseTool
logger = structlog.get_logger(__name__)
class DataSourcesTool(BaseTool):
    """Tool that discovers and describes AWS Security Lake data sources.

    The primary path (``execute``) uses :class:`SecurityLakeTableDiscovery`
    to enumerate tables, check availability, and summarize results. The
    older Athena-backed helpers (``_process_table_info``,
    ``_get_table_schema``, ``_check_data_freshness``,
    ``_generate_source_summary``) are retained for callers that still use
    the legacy flow.

    NOTE(review): ``self.settings`` and ``self.athena_client`` are assumed
    to be provided by ``BaseTool`` — not visible in this module; confirm.
    """

    def __init__(self, settings):
        """Wire up the query builder and table-discovery helpers.

        Args:
            settings: Configuration object exposing
                ``security_lake_database``, ``aws_region`` and
                ``aws_profile``.
        """
        super().__init__(settings)
        self.query_builder = OCSFQueryBuilder(
            settings.security_lake_database,
            settings.aws_region,
            settings.aws_profile
        )
        self.table_discovery = SecurityLakeTableDiscovery(
            settings.aws_region,
            settings.aws_profile
        )

    async def execute(self, include_schema: bool = False) -> Dict[str, Any]:
        """List available Security Lake tables with status metadata.

        Args:
            include_schema: When True, attach per-table schema details
                (OCSF fields, storage format, partition info).

        Returns:
            A success response containing one entry per discovered table
            plus summary metadata, or an error response on failure.
        """
        try:
            logger.info("Discovering available Security Lake data sources")
            # Use the new table discovery system
            discovery_summary = self.table_discovery.get_discovery_summary(
                self.settings.security_lake_database
            )
            if 'error' in discovery_summary:
                return self._format_error_response(
                    "Failed to discover Security Lake resources",
                    discovery_summary['error']
                )
            # Process each discovered data source
            processed_sources = []
            for source_name, source_info in discovery_summary.get('data_sources', {}).items():
                # Defensive .get: one malformed source entry (missing
                # 'tables') must not abort the whole listing.
                for table_name in source_info.get('tables', []):
                    # Get detailed table information
                    table_details = self.table_discovery.get_table_schema_info(
                        table_name, self.settings.security_lake_database
                    )
                    # Check data availability
                    availability = self.table_discovery.check_data_source_availability(
                        table_name, self.settings.security_lake_database
                    )
                    source_entry = {
                        "table_name": table_name,
                        "source_type": source_name,
                        "ocsf_version": table_details.get('ocsf_version'),
                        "status": availability.get('status'),
                        "has_data": availability.get('has_data', False),
                        "table_format": availability.get('table_format', 'unknown'),
                        "location": availability.get('location', ''),
                        "column_count": table_details.get('column_count', 0),
                        "partition_keys": table_details.get('partition_keys', []),
                        "has_nested_structures": table_details.get('has_nested_structures', False)
                    }
                    # Add schema details if requested
                    if include_schema:
                        source_entry["schema"] = {
                            "ocsf_fields": table_details.get('ocsf_fields', {}),
                            "storage_format": table_details.get('storage_format'),
                            "partition_info": table_details.get('partition_keys', [])
                        }
                    processed_sources.append(source_entry)
            # Sort by data source type and OCSF version
            processed_sources.sort(
                key=lambda x: (x.get("source_type", ""), x.get("ocsf_version", ""))
            )
            metadata = {
                "database": self.settings.security_lake_database,
                "total_data_sources": discovery_summary.get('total_data_sources', 0),
                "ocsf_versions": discovery_summary.get('ocsf_versions', []),
                "recommendations": discovery_summary.get('recommendations', []),
                "data_source_summary": self._generate_enhanced_summary(processed_sources),
                "discovery_timestamp": discovery_summary.get('timestamp')
            }
            return self._format_success_response(processed_sources, metadata)
        except Exception as e:
            return self._format_error_response(
                "Failed to list data sources",
                str(e)
            )

    async def _process_table_info(
        self,
        table: Dict[str, Any],
        include_schema: bool
    ) -> Dict[str, Any]:
        """Build a legacy table-info entry from a raw Glue table record.

        Args:
            table: Raw table record (expects keys like ``name``, ``type``,
                ``columns``, ``last_accessed``, ``location``).
            include_schema: When True, also fetch the table schema via
                Athena (errors degrade to ``{"error": "Schema unavailable"}``).

        Returns:
            A dict of parsed source metadata plus data-freshness info.
        """
        table_name = table.get("name", "")
        # Parse table name to extract source information
        source_info = self._parse_table_name(table_name)
        table_info = {
            "table_name": table_name,
            "source_type": source_info.get("source_type"),
            "region": source_info.get("region"),
            "data_version": source_info.get("data_version"),
            "ocsf_version": source_info.get("ocsf_version"),
            "table_type": table.get("type"),
            "column_count": table.get("columns", 0),
            "last_accessed": table.get("last_accessed"),
            "location": table.get("location"),
            "status": "active" if table.get("location") else "inactive"
        }
        # Add schema information if requested
        if include_schema:
            try:
                schema_info = await self._get_table_schema(table_name)
                table_info["schema"] = schema_info
            except Exception as e:
                logger.warning(
                    "Failed to get schema for table",
                    table_name=table_name,
                    error=str(e)
                )
                table_info["schema"] = {"error": "Schema unavailable"}
        # Add data freshness information
        table_info["data_freshness"] = await self._check_data_freshness(table_name)
        return table_info

    def _parse_table_name(self, table_name: str) -> Dict[str, str]:
        """Best-effort extraction of source type, region, and version.

        Example: ``amazon_security_lake_table_us_east_1_cloudtrail_mgmt_1_0``.
        Components that cannot be determined are left as ``"unknown"``.
        """
        parts = table_name.split("_")
        # Hoisted: the original lowered the name on every comparison.
        lowered = table_name.lower()
        source_info = {
            "source_type": "unknown",
            "region": "unknown",
            "data_version": "unknown",
            "ocsf_version": "unknown"
        }
        try:
            # First match wins; broader substrings (e.g. "config") are
            # deliberately checked last.
            if "cloudtrail" in lowered:
                source_info["source_type"] = "CloudTrail"
            elif "guardduty" in lowered:
                source_info["source_type"] = "GuardDuty"
            elif "vpcflow" in lowered or "vpc_flow" in lowered:
                source_info["source_type"] = "VPC Flow Logs"
            elif "securityhub" in lowered or "security_hub" in lowered:
                source_info["source_type"] = "Security Hub"
            elif "route53" in lowered:
                source_info["source_type"] = "Route 53"
            elif "waf" in lowered:
                source_info["source_type"] = "AWS WAF"
            elif "config" in lowered:
                source_info["source_type"] = "AWS Config"
            # Extract region (look for patterns like us_east_1, eu_west_1, etc.)
            for i, part in enumerate(parts):
                if part in ["us", "eu", "ap", "ca", "sa"] and i + 2 < len(parts):
                    source_info["region"] = "-".join(parts[i:i + 3])
                    break
            # Trailing "<major>_<minor>" digit pair is treated as OCSF version
            if len(parts) >= 2 and parts[-2].isdigit() and parts[-1].isdigit():
                source_info["ocsf_version"] = f"{parts[-2]}.{parts[-1]}"
        except Exception as e:
            logger.warning("Failed to parse table name", table_name=table_name, error=str(e))
        return source_info

    async def _get_table_schema(self, table_name: str) -> Dict[str, Any]:
        """Fetch a table's column schema via Athena and score OCSF coverage.

        Returns:
            Dict with ``columns``, ``total_columns`` and an
            ``ocsf_compliance`` summary (fields found + percentage).
        """
        query, parameters = self.query_builder.build_table_schema_query(table_name)
        results = await self.athena_client.execute_query(query, parameters)
        schema_info = {
            "columns": [],
            "total_columns": len(results)
        }
        for row in results:
            column_info = {
                "name": row.get("column_name"),
                "type": row.get("data_type"),
                "nullable": row.get("is_nullable") == "YES",
                "default": row.get("column_default"),
                "position": row.get("ordinal_position")
            }
            schema_info["columns"].append(column_info)
        # Identify OCSF standard fields
        ocsf_fields = self._identify_ocsf_fields(schema_info["columns"])
        schema_info["ocsf_compliance"] = {
            "standard_fields_found": len(ocsf_fields),
            "standard_fields": ocsf_fields,
            # Guard: avoid ZeroDivisionError on a table with no columns.
            "compliance_percentage": (len(ocsf_fields) / len(schema_info["columns"]) * 100) if schema_info["columns"] else 0
        }
        return schema_info

    def _identify_ocsf_fields(self, columns: List[Dict[str, Any]]) -> List[str]:
        """Return the standard OCSF field names present in *columns*.

        Matching is case-insensitive; result order follows the standard
        field list, as before.
        """
        # Common OCSF fields to look for
        standard_ocsf_fields = [
            "time", "type_name", "type_uid", "severity", "severity_id",
            "activity_name", "activity_id", "category_name", "category_uid",
            "class_name", "class_uid", "metadata", "cloud", "src_endpoint",
            "dst_endpoint", "actor", "device", "finding", "resources"
        ]
        # Set membership is O(1) per lookup vs. the original list scan.
        column_names = {col.get("name", "").lower() for col in columns}
        return [field for field in standard_ocsf_fields if field.lower() in column_names]

    async def _check_data_freshness(self, table_name: str) -> Dict[str, Any]:
        """Query a table's record count and timestamp range via Athena.

        Classifies freshness from the newest ``time`` value: ``fresh``
        (< 24h), ``recent`` (< 1 week), otherwise ``stale``. Errors are
        logged and reported as ``{"status": "unknown", ...}``.
        """
        try:
            # Query for the most recent data timestamp.
            # NOTE(security): identifiers are interpolated into the SQL
            # (Athena cannot bind table names as parameters). These names
            # come from Glue discovery, not end users — do not feed this
            # method untrusted input.
            query = f"""
            SELECT
                MAX(time) as latest_timestamp,
                MIN(time) as earliest_timestamp,
                COUNT(*) as total_records
            FROM {self.settings.security_lake_database}.{table_name}
            WHERE time IS NOT NULL
            LIMIT 1
            """
            results = await self.athena_client.execute_query(query, [])
            if not results:
                return {"status": "no_data", "freshness": "unknown"}
            result = results[0]
            latest = result.get("latest_timestamp")
            earliest = result.get("earliest_timestamp")
            total_records = result.get("total_records", 0)
            freshness_info = {
                "latest_data": latest,
                "earliest_data": earliest,
                "total_records": total_records,
                "status": "active" if latest else "no_data"
            }
            # Calculate data age with timezone-aware arithmetic. The old
            # utcnow()/replace(tzinfo=None) pattern mis-computed ages for
            # offset-bearing timestamps, and datetime.utcnow() is deprecated.
            if latest:
                from datetime import datetime, timezone
                from dateutil import parser as date_parser
                try:
                    latest_dt = date_parser.parse(latest)
                    if latest_dt.tzinfo is None:
                        # Naive timestamps are assumed to be UTC — TODO confirm
                        latest_dt = latest_dt.replace(tzinfo=timezone.utc)
                    age_hours = (datetime.now(timezone.utc) - latest_dt).total_seconds() / 3600
                    if age_hours < 24:
                        freshness_info["freshness"] = "fresh"
                    elif age_hours < 168:  # 1 week
                        freshness_info["freshness"] = "recent"
                    else:
                        freshness_info["freshness"] = "stale"
                    freshness_info["age_hours"] = round(age_hours, 2)
                except Exception:
                    freshness_info["freshness"] = "unknown"
            return freshness_info
        except Exception as e:
            logger.warning(
                "Failed to check data freshness",
                table_name=table_name,
                error=str(e)
            )
            return {"status": "unknown", "error": str(e)}

    def _generate_source_summary(self, sources: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Aggregate legacy source entries by type, region, and freshness."""
        summary = {
            "source_types": {},
            "regions": {},
            "total_active_sources": 0,
            "freshness_status": {}
        }
        for source in sources:
            # Count by source type
            source_type = source.get("source_type", "unknown")
            summary["source_types"][source_type] = summary["source_types"].get(source_type, 0) + 1
            # Count by region
            region = source.get("region", "unknown")
            summary["regions"][region] = summary["regions"].get(region, 0) + 1
            # Count active sources
            if source.get("status") == "active":
                summary["total_active_sources"] += 1
            # Track freshness
            freshness = source.get("data_freshness", {}).get("freshness", "unknown")
            summary["freshness_status"][freshness] = summary["freshness_status"].get(freshness, 0) + 1
        return summary

    def _generate_enhanced_summary(self, sources: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Generate enhanced summary using the new discovery system."""
        summary = {
            "source_types": {},
            "ocsf_versions": {},
            "total_active_sources": 0,
            "data_status": {},
            "table_formats": {}
        }
        for source in sources:
            # Count by source type
            source_type = source.get("source_type", "unknown")
            summary["source_types"][source_type] = summary["source_types"].get(source_type, 0) + 1
            # Count by OCSF version
            ocsf_version = source.get("ocsf_version", "unknown")
            summary["ocsf_versions"][ocsf_version] = summary["ocsf_versions"].get(ocsf_version, 0) + 1
            # Count active sources with data
            if source.get("has_data"):
                summary["total_active_sources"] += 1
            # Track data status
            status = source.get("status", "unknown")
            summary["data_status"][status] = summary["data_status"].get(status, 0) + 1
            # Track table formats (only known formats are tallied)
            table_format = source.get("table_format", "unknown")
            if table_format != "unknown":
                summary["table_formats"][table_format] = summary["table_formats"].get(table_format, 0) + 1
        return summary