"""Athena service module for AWS Security MCP.
This module provides functions for interacting with AWS Athena for running SQL queries
on CloudTrail logs, VPC Flow Logs, and other security-related datasets.
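
A typical asynchronous flow looks like the sketch below (illustrative;
the table, database, and bucket names are placeholders, not defaults
shipped with this module):

    query_id = start_query_execution(
        query_string="SELECT eventname FROM cloudtrail_logs WHERE year = '2024' LIMIT 25",
        database='security_logs',
        output_location='s3://example-bucket/athena-results/',
    )
    state, _ = wait_for_query_completion(query_id)
    if state == 'SUCCEEDED':
        results = get_query_results(query_id)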
"""
import logging
from typing import Any, Dict, List, Optional, Tuple
from datetime import datetime
import time
import json
import re
import boto3
from botocore.exceptions import ClientError, NoCredentialsError
from aws_security_mcp.services.base import get_client, handle_aws_error, format_pagination_response
from aws_security_mcp.config import config
# Configure logging
logger = logging.getLogger(__name__)
# Athena query execution states
QUERY_STATES = {
'QUEUED': 'queued',
'RUNNING': 'running',
'SUCCEEDED': 'succeeded',
'FAILED': 'failed',
'CANCELLED': 'cancelled'
}
def serialize_datetime_objects(obj: Any) -> Any:
"""Recursively convert datetime objects to ISO format strings for JSON serialization.
Args:
obj: Object that may contain datetime objects
Returns:
Object with datetime objects converted to strings
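
    Example:
        >>> serialize_datetime_objects({'when': datetime(2024, 1, 1)})
        {'when': '2024-01-01T00:00:00'}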
"""
if isinstance(obj, datetime):
return obj.isoformat()
elif isinstance(obj, dict):
return {key: serialize_datetime_objects(value) for key, value in obj.items()}
elif isinstance(obj, list):
return [serialize_datetime_objects(item) for item in obj]
else:
return obj
def list_data_catalogs(
session_context: Optional[str] = None,
max_items: Optional[int] = None,
next_token: Optional[str] = None
) -> Dict[str, Any]:
"""List all available data catalogs.
Args:
session_context: Optional session key for cross-account access
max_items: Maximum number of catalogs to return
next_token: Pagination token for next page of results
Returns:
Dict containing list of data catalogs with pagination info
"""
try:
client = get_client('athena', session_context=session_context)
params = {}
if max_items:
params['MaxResults'] = max_items
if next_token:
params['NextToken'] = next_token
response = client.list_data_catalogs(**params)
catalogs = response.get('DataCatalogsSummary', [])
next_token = response.get('NextToken')
# Convert datetime objects to strings for JSON serialization
catalogs = serialize_datetime_objects(catalogs)
return format_pagination_response(
items=catalogs,
next_token=next_token
)
except (ClientError, NoCredentialsError) as e:
logger.error(f"Error listing Athena data catalogs: {str(e)}")
return format_pagination_response(items=[], next_token=None)
def list_databases(
catalog_name: Optional[str] = None,
session_context: Optional[str] = None,
max_items: Optional[int] = None,
next_token: Optional[str] = None
) -> Dict[str, Any]:
"""List all databases in the specified data catalog.
Args:
catalog_name: Name of the data catalog (if None, defaults to AwsDataCatalog)
session_context: Optional session key for cross-account access
max_items: Maximum number of databases to return
next_token: Pagination token for next page of results
Returns:
Dict containing list of databases with pagination info
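
    Example (an illustrative pagination loop; 'items' and 'next_token'
    follow the response shape used throughout this module):

        token = None
        while True:
            page = list_databases(next_token=token)
            for db in page['items']:
                print(db.get('Name'))
            token = page.get('next_token')
            if not token:
                break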
"""
try:
client = get_client('athena', session_context=session_context)
# Default to AwsDataCatalog if not specified, but allow override
if catalog_name is None:
catalog_name = 'AwsDataCatalog'
params = {
'CatalogName': catalog_name
}
if max_items:
params['MaxResults'] = max_items
if next_token:
params['NextToken'] = next_token
response = client.list_databases(**params)
databases = response.get('DatabaseList', [])
next_token = response.get('NextToken')
# Convert datetime objects to strings for JSON serialization
databases = serialize_datetime_objects(databases)
return format_pagination_response(
items=databases,
next_token=next_token
)
except (ClientError, NoCredentialsError) as e:
logger.error(f"Error listing Athena databases in catalog {catalog_name}: {str(e)}")
return format_pagination_response(items=[], next_token=None)
def list_table_metadata(
database_name: str,
catalog_name: Optional[str] = None,
session_context: Optional[str] = None,
max_items: Optional[int] = None,
next_token: Optional[str] = None,
expression: Optional[str] = None
) -> Dict[str, Any]:
"""List table metadata for tables in the specified database.
Args:
database_name: Name of the database
catalog_name: Name of the data catalog (if None, defaults to AwsDataCatalog)
session_context: Optional session key for cross-account access
max_items: Maximum number of tables to return
next_token: Pagination token for next page of results
expression: Optional regex expression to filter table names
Returns:
Dict containing list of table metadata with pagination info
"""
try:
client = get_client('athena', session_context=session_context)
# Default to AwsDataCatalog if not specified, but allow override
if catalog_name is None:
catalog_name = 'AwsDataCatalog'
params = {
'CatalogName': catalog_name,
'DatabaseName': database_name
}
if max_items:
params['MaxResults'] = max_items
if next_token:
params['NextToken'] = next_token
if expression:
params['Expression'] = expression
response = client.list_table_metadata(**params)
tables = response.get('TableMetadataList', [])
next_token = response.get('NextToken')
# Convert datetime objects to strings for JSON serialization
tables = serialize_datetime_objects(tables)
return format_pagination_response(
items=tables,
next_token=next_token
)
except (ClientError, NoCredentialsError) as e:
logger.error(f"Error listing table metadata for database {database_name} in catalog {catalog_name}: {str(e)}")
return format_pagination_response(items=[], next_token=None)
def get_table_metadata(
database_name: str,
table_name: str,
catalog_name: Optional[str] = None,
session_context: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Get detailed metadata for a specific table.
Args:
database_name: Name of the database
table_name: Name of the table
catalog_name: Name of the data catalog (if None, defaults to AwsDataCatalog)
session_context: Optional session key for cross-account access
Returns:
Dict containing detailed table metadata or None if error
"""
try:
client = get_client('athena', session_context=session_context)
# Default to AwsDataCatalog if not specified, but allow override
if catalog_name is None:
catalog_name = 'AwsDataCatalog'
response = client.get_table_metadata(
CatalogName=catalog_name,
DatabaseName=database_name,
TableName=table_name
)
table_metadata = response.get('TableMetadata')
# Convert datetime objects to strings for JSON serialization
if table_metadata:
table_metadata = serialize_datetime_objects(table_metadata)
return table_metadata
except (ClientError, NoCredentialsError) as e:
logger.error(f"Error getting table metadata for {catalog_name}.{database_name}.{table_name}: {str(e)}")
return None
def start_query_execution(
query_string: str,
database: str,
output_location: Optional[str] = None,
catalog_name: Optional[str] = None,
workgroup: Optional[str] = None,
description: Optional[str] = None,
session_context: Optional[str] = None
) -> Optional[str]:
"""Start execution of an Athena SQL query.
Args:
query_string: The SQL query string to execute
database: Database to run the query against
output_location: S3 location for query results (if None, uses default from config)
catalog_name: Name of the data catalog (if None, uses default from config)
workgroup: Athena workgroup to use (if None, uses default from config)
description: Optional description for the query
session_context: Optional session key for cross-account access
Returns:
Query execution ID if successful, None if error
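
    Example (illustrative; the database and bucket names are placeholders):

        query_id = start_query_execution(
            query_string="SELECT * FROM vpc_flow_logs WHERE day = '01' LIMIT 10",
            database='security_logs',
            output_location='s3://example-bucket/athena-results/',
        )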
"""
try:
client = get_client('athena', session_context=session_context)
# Use config defaults if not specified
if output_location is None:
output_location = config.athena.default_output_location
if catalog_name is None:
catalog_name = config.athena.default_catalog
if workgroup is None:
workgroup = config.athena.default_workgroup
query_context = {
'Database': database,
'Catalog': catalog_name
}
result_configuration = {
'OutputLocation': output_location
}
params = {
'QueryString': query_string,
'QueryExecutionContext': query_context,
'ResultConfiguration': result_configuration,
'WorkGroup': workgroup
}
if description:
params['Description'] = description
response = client.start_query_execution(**params)
query_execution_id = response.get('QueryExecutionId')
logger.info(f"Started Athena query execution: {query_execution_id} in catalog {catalog_name}")
return query_execution_id
except (ClientError, NoCredentialsError) as e:
logger.error(f"Error starting Athena query execution: {str(e)}")
return None
def get_query_execution(
query_execution_id: str,
session_context: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Get the status and details of a query execution.
Args:
query_execution_id: The query execution ID
session_context: Optional session key for cross-account access
Returns:
Dict containing query execution details or None if error
"""
try:
client = get_client('athena', session_context=session_context)
response = client.get_query_execution(
QueryExecutionId=query_execution_id
)
query_execution = response.get('QueryExecution')
# Convert datetime objects to strings for JSON serialization
if query_execution:
query_execution = serialize_datetime_objects(query_execution)
return query_execution
except (ClientError, NoCredentialsError) as e:
logger.error(f"Error getting query execution {query_execution_id}: {str(e)}")
return None
def wait_for_query_completion(
query_execution_id: str,
max_wait_time: int = 300,
poll_interval: int = 2,
session_context: Optional[str] = None
) -> Tuple[str, Optional[Dict[str, Any]]]:
"""Wait for a query to complete execution.
Args:
query_execution_id: The query execution ID
max_wait_time: Maximum time to wait in seconds (default: 300)
poll_interval: Polling interval in seconds (default: 2)
session_context: Optional session key for cross-account access
Returns:
Tuple of (final_status, query_execution_details)
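
    Example (illustrative; assumes query_id came from start_query_execution):

        state, execution = wait_for_query_completion(query_id, max_wait_time=120)
        if state == 'SUCCEEDED':
            rows = get_query_results(query_id)['items']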
"""
start_time = time.time()
while time.time() - start_time < max_wait_time:
query_execution = get_query_execution(query_execution_id, session_context)
if not query_execution:
return 'ERROR', None
status = query_execution.get('Status', {})
state = status.get('State')
if state in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
logger.info(f"Query {query_execution_id} completed with status: {state}")
return state, query_execution
logger.debug(f"Query {query_execution_id} still {state}, waiting...")
time.sleep(poll_interval)
logger.warning(f"Query {query_execution_id} timed out after {max_wait_time} seconds")
return 'TIMEOUT', get_query_execution(query_execution_id, session_context)
def get_query_results(
query_execution_id: str,
session_context: Optional[str] = None,
max_items: Optional[int] = None,
next_token: Optional[str] = None,
query_result_type: str = 'DATA_ROWS'
) -> Dict[str, Any]:
"""Get the results of a completed query execution.
Args:
query_execution_id: The query execution ID
session_context: Optional session key for cross-account access
max_items: Maximum number of result rows to return
next_token: Pagination token for next page of results
query_result_type: Type of result to return ('DATA_ROWS' or 'DATA_MANIFEST') - ignored if not supported
Returns:
Dict containing query results with pagination info
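
    Example (illustrative; for SELECT queries Athena typically returns the
    column headers as the first row, with each cell keyed by 'VarCharValue'):

        results = get_query_results(query_id, max_items=100)
        for row in results['items'][1:]:  # skip the header row on the first page
            values = [cell.get('VarCharValue') for cell in row.get('Data', [])]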
"""
try:
client = get_client('athena', session_context=session_context)
params = {
'QueryExecutionId': query_execution_id
}
if max_items:
params['MaxResults'] = max_items
if next_token:
params['NextToken'] = next_token
# Try with QueryResultType first (newer boto3 versions)
try:
params['QueryResultType'] = query_result_type
response = client.get_query_results(**params)
        except Exception as e:
            # Fall back to a basic call if this boto3 version does not support QueryResultType
            if 'QueryResultType' in str(e):
                logger.warning("QueryResultType parameter not supported in this boto3 version, falling back to basic call")
                params.pop('QueryResultType', None)
                response = client.get_query_results(**params)
            else:
                raise
result_set = response.get('ResultSet', {})
rows = result_set.get('Rows', [])
next_token = response.get('NextToken')
update_count = response.get('UpdateCount', 0)
# Extract metadata about columns
metadata = result_set.get('ResultSetMetadata', {})
column_info = metadata.get('ColumnInfo', [])
return {
'items': rows,
'next_token': next_token,
'is_truncated': next_token is not None,
'count': len(rows),
'column_info': column_info,
'metadata': metadata,
'update_count': update_count,
            'query_result_type': query_result_type  # Requested type; the call may have fallen back to the default if unsupported
}
except (ClientError, NoCredentialsError) as e:
logger.error(f"Error getting query results for {query_execution_id} (type: {query_result_type}): {str(e)}")
return format_pagination_response(items=[], next_token=None)
def get_query_results_paginated(
query_execution_id: str,
session_context: Optional[str] = None,
query_result_type: str = 'DATA_ROWS'
) -> Dict[str, Any]:
"""Get all results from a query execution using AWS paginator.
Args:
query_execution_id: The query execution ID
session_context: Optional session key for cross-account access
query_result_type: Type of result to return ('DATA_ROWS' or 'DATA_MANIFEST') - ignored if not supported
Returns:
Dict containing all query results (unpaginated)
"""
try:
client = get_client('athena', session_context=session_context)
# Use AWS official paginator
paginator = client.get_paginator('get_query_results')
# Try with QueryResultType first (newer boto3 versions)
try:
page_iterator = paginator.paginate(
QueryExecutionId=query_execution_id,
QueryResultType=query_result_type
)
        except Exception as e:
            # Fall back to a basic call if this boto3 version does not support QueryResultType
            if 'QueryResultType' in str(e):
                logger.warning("QueryResultType parameter not supported in paginator, falling back to basic call")
                page_iterator = paginator.paginate(
                    QueryExecutionId=query_execution_id
                )
            else:
                raise
all_rows = []
column_info = []
metadata = {}
update_count = 0
for page in page_iterator:
result_set = page.get('ResultSet', {})
rows = result_set.get('Rows', [])
all_rows.extend(rows)
# Get metadata from first page
if not column_info:
metadata = result_set.get('ResultSetMetadata', {})
column_info = metadata.get('ColumnInfo', [])
# Get update count if available
if page.get('UpdateCount'):
update_count = page.get('UpdateCount', 0)
return {
'items': all_rows,
'next_token': None,
'is_truncated': False,
'count': len(all_rows),
'column_info': column_info,
'metadata': metadata,
'update_count': update_count,
            'query_result_type': query_result_type  # Requested type; the call may have fallen back to the default if unsupported
}
except (ClientError, NoCredentialsError) as e:
logger.error(f"Error getting paginated query results for {query_execution_id} (type: {query_result_type}): {str(e)}")
return format_pagination_response(items=[], next_token=None)
def list_query_executions(
workgroup: str = 'primary',
session_context: Optional[str] = None,
max_items: Optional[int] = None,
next_token: Optional[str] = None
) -> Dict[str, Any]:
"""List query executions in the specified workgroup.
Args:
workgroup: Athena workgroup name (default: primary)
session_context: Optional session key for cross-account access
max_items: Maximum number of executions to return
next_token: Pagination token for next page of results
Returns:
Dict containing list of query executions with pagination info
"""
try:
client = get_client('athena', session_context=session_context)
params = {
'WorkGroup': workgroup
}
if max_items:
params['MaxResults'] = max_items
if next_token:
params['NextToken'] = next_token
response = client.list_query_executions(**params)
query_execution_ids = response.get('QueryExecutionIds', [])
next_token = response.get('NextToken')
return format_pagination_response(
items=query_execution_ids,
next_token=next_token
)
except (ClientError, NoCredentialsError) as e:
logger.error(f"Error listing query executions: {str(e)}")
return format_pagination_response(items=[], next_token=None)
def execute_query_async(
query_string: str,
database: str,
output_location: Optional[str] = None,
catalog_name: Optional[str] = None,
workgroup: Optional[str] = None,
description: Optional[str] = None,
session_context: Optional[str] = None
) -> Dict[str, Any]:
"""Execute a query asynchronously and return execution ID immediately.
    This is the proper MCP approach: return the query execution ID immediately
    and let the client poll for status and results separately.
Args:
query_string: The SQL query string to execute
database: Database to run the query against
output_location: S3 location for query results (if None, uses default from config)
catalog_name: Name of the data catalog (if None, uses default from config)
workgroup: Athena workgroup to use (if None, uses default from config)
description: Optional description for the query
session_context: Optional session key for cross-account access
Returns:
Dict containing query execution ID and initial status
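
    Example (an illustrative client-side flow):

        submission = execute_query_async(
            query_string="SELECT sourceipaddress FROM cloudtrail_logs WHERE year = '2024' LIMIT 50",
            database='security_logs',
        )
        if submission['success']:
            query_id = submission['query_execution_id']
            # ...poll with is_query_complete(), then fetch via get_query_results()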
"""
try:
# Use config defaults if not specified
if output_location is None:
output_location = config.athena.default_output_location
if catalog_name is None:
catalog_name = config.athena.default_catalog
if workgroup is None:
workgroup = config.athena.default_workgroup
# Start query execution
query_execution_id = start_query_execution(
query_string=query_string,
database=database,
output_location=output_location,
catalog_name=catalog_name,
workgroup=workgroup,
description=description,
session_context=session_context
)
if not query_execution_id:
return {
'success': False,
'error': 'Failed to start query execution',
'query_execution_id': None,
'status': 'FAILED'
}
# Get initial status
query_execution = get_query_execution(query_execution_id, session_context)
initial_status = 'QUEUED'
if query_execution:
status = query_execution.get('Status', {})
initial_status = status.get('State', 'QUEUED')
return {
'success': True,
'query_execution_id': query_execution_id,
'status': initial_status,
            'message': 'Query submitted successfully. Use query_execution_id to check status and get results.'
}
except Exception as e:
logger.error(f"Error executing query asynchronously: {str(e)}")
return {
'success': False,
'error': str(e),
'query_execution_id': None,
'status': 'FAILED'
}
def is_query_complete(
query_execution_id: str,
session_context: Optional[str] = None
) -> Tuple[bool, str, Optional[str]]:
"""Check if a query execution is complete and ready for results.
Args:
query_execution_id: The query execution ID
session_context: Optional session key for cross-account access
Returns:
Tuple of (is_complete, status, error_message)
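
    Example (an illustrative polling loop):

        complete, state, error = is_query_complete(query_id)
        while not complete:
            time.sleep(2)
            complete, state, error = is_query_complete(query_id)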
"""
try:
query_execution = get_query_execution(query_execution_id, session_context)
if not query_execution:
return False, 'ERROR', 'Query execution not found'
status = query_execution.get('Status', {})
state = status.get('State', 'UNKNOWN')
        if state == 'SUCCEEDED':
return True, state, None
elif state in ['FAILED', 'CANCELLED']:
error_msg = status.get('StateChangeReason', f'Query {state.lower()}')
return True, state, error_msg
else:
# QUEUED, RUNNING
return False, state, None
except Exception as e:
logger.error(f"Error checking query completion for {query_execution_id}: {str(e)}")
return False, 'ERROR', str(e)
def stop_query_execution(
query_execution_id: str,
session_context: Optional[str] = None
) -> bool:
"""Stop a running query execution.
Args:
query_execution_id: The query execution ID to stop
session_context: Optional session key for cross-account access
Returns:
True if successful, False otherwise
"""
try:
client = get_client('athena', session_context=session_context)
client.stop_query_execution(
QueryExecutionId=query_execution_id
)
logger.info(f"Stopped query execution: {query_execution_id}")
return True
except (ClientError, NoCredentialsError) as e:
logger.error(f"Error stopping query execution {query_execution_id}: {str(e)}")
return False
def list_workgroups(
session_context: Optional[str] = None,
max_items: Optional[int] = None,
next_token: Optional[str] = None
) -> Dict[str, Any]:
"""List all Athena workgroups.
Args:
session_context: Optional session key for cross-account access
max_items: Maximum number of workgroups to return
next_token: Pagination token for next page of results
Returns:
Dict containing list of workgroups with pagination info
"""
try:
client = get_client('athena', session_context=session_context)
params = {}
if max_items:
params['MaxResults'] = max_items
if next_token:
params['NextToken'] = next_token
response = client.list_work_groups(**params)
workgroups = response.get('WorkGroups', [])
next_token = response.get('NextToken')
# Convert datetime objects to strings for JSON serialization
workgroups = serialize_datetime_objects(workgroups)
return format_pagination_response(
items=workgroups,
next_token=next_token
)
except (ClientError, NoCredentialsError) as e:
logger.error(f"Error listing Athena workgroups: {str(e)}")
return format_pagination_response(items=[], next_token=None)
def validate_s3_output_location(output_location: str) -> Tuple[bool, Optional[str]]:
"""Validate S3 output location for Athena queries.
Args:
output_location: S3 output location to validate
Returns:
Tuple of (is_valid, error_message)
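
    Example:
        >>> validate_s3_output_location('s3://my-bucket/athena-results/')
        (True, None)
        >>> validate_s3_output_location('s3://my-bucket/results')[0]
        False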
"""
if not output_location:
return False, "Output location cannot be empty"
if not output_location.startswith('s3://'):
return False, "Output location must be a valid S3 URI starting with 's3://' (e.g., s3://my-bucket/athena-results/)"
# Basic S3 URI structure validation
if output_location == 's3://':
return False, "Output location must include bucket name (e.g., s3://my-bucket/athena-results/)"
# Extract bucket and path
s3_parts = output_location[5:].split('/', 1) # Remove 's3://' prefix
bucket_name = s3_parts[0]
if not bucket_name:
return False, "S3 bucket name cannot be empty in output location"
# Validate bucket naming rules (basic check)
if len(bucket_name) < 3 or len(bucket_name) > 63:
return False, "S3 bucket name must be between 3 and 63 characters long"
# Should end with / for directory-like structure
if not output_location.endswith('/'):
return False, "Output location should end with '/' to specify a directory (e.g., s3://my-bucket/athena-results/)"
return True, None
def validate_query_parameters(
query_string: str,
database: str,
output_location: str
) -> Tuple[bool, Optional[str]]:
"""Validate query parameters before execution.
Allows safe read-only operations: SELECT, SHOW, DESCRIBE, EXPLAIN.
Blocks potentially dangerous operations: DROP, DELETE, TRUNCATE, ALTER, CREATE, INSERT, UPDATE.
Args:
query_string: The SQL query string to validate
database: Database name to validate
output_location: S3 output location to validate
Returns:
Tuple of (is_valid, error_message)
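
    Example (illustrative; the first query is rejected because it scans a
    CloudTrail table without a date filter or LIMIT, the second passes):

        ok, err = validate_query_parameters(
            "SELECT * FROM cloudtrail_logs",
            "security_logs",
            "s3://example-bucket/athena-results/",
        )   # ok is False
        ok, err = validate_query_parameters(
            "SELECT eventname FROM cloudtrail_logs WHERE year = '2024' LIMIT 10",
            "security_logs",
            "s3://example-bucket/athena-results/",
        )   # ok is True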
"""
# Basic validation
if not query_string or not query_string.strip():
return False, "Query string cannot be empty"
if not database or not database.strip():
return False, "Database name cannot be empty"
# Validate S3 output location
s3_valid, s3_error = validate_s3_output_location(output_location)
if not s3_valid:
return False, s3_error
    # Check for dangerous operations (basic safety)
    dangerous_keywords = ['DROP', 'DELETE', 'TRUNCATE', 'ALTER', 'CREATE', 'INSERT', 'UPDATE']
    query_upper = query_string.upper().strip()
    # Allow safe read-only operations for database exploration
    # SHOW: SHOW TABLES, SHOW DATABASES, SHOW COLUMNS FROM table, etc.
    # DESCRIBE/DESC: DESCRIBE table_name, DESC table_name
    # EXPLAIN: EXPLAIN SELECT ... (query planning)
    safe_read_operations = ['SHOW', 'DESCRIBE', 'DESC', 'EXPLAIN']
    # Check if the query starts with a safe read operation
    is_safe_read_operation = any(query_upper.startswith(op) for op in safe_read_operations)
    # If it's not a safe read operation, check for dangerous keywords.
    # Match on word boundaries so identifiers such as CREATED_AT or
    # LAST_UPDATED are not falsely flagged as CREATE or UPDATE.
    if not is_safe_read_operation:
        for keyword in dangerous_keywords:
            if re.search(rf'\b{keyword}\b', query_upper):
                return False, f"Query contains potentially dangerous keyword: {keyword}. Only SELECT, SHOW, DESCRIBE, and EXPLAIN queries are allowed for security."
    # SHOW, DESCRIBE/DESC, and EXPLAIN commands are always safe and need no further checks
    if is_safe_read_operation:
        return True, None
# For SELECT queries, apply additional validation for performance and cost control
if query_upper.startswith('SELECT'):
# Check if querying large tables that need date filtering
large_table_patterns = ['CLOUDTRAIL', 'VPC_FLOW_LOGS', 'VPCFLOWLOGS', 'ACCESS_LOGS', 'ALB_LOGS', 'ELB_LOGS']
is_querying_large_table = any(
pattern in query_upper for pattern in large_table_patterns
)
# For large tables, recommend date/time filtering
if is_querying_large_table:
has_time_filter = any(
filter_keyword in query_upper
for filter_keyword in ['WHERE', 'LIMIT', 'DATE', 'TIMESTAMP', 'YEAR', 'MONTH', 'DAY']
)
if not has_time_filter:
return False, "Queries on large tables (CloudTrail/VPC Flow Logs/Access Logs) should include date/time filters (WHERE year='2024' AND month='01') or LIMIT clause to control result size and costs"
# General recommendation for LIMIT clause (warning only, not blocking)
if 'LIMIT' not in query_upper and not any(agg in query_upper for agg in ['COUNT(', 'SUM(', 'AVG(', 'GROUP BY']):
logger.warning("Consider adding a LIMIT clause to prevent unexpectedly large result sets")
return True, None