"""
SQL query validator for MCP Athena Analytics server.
This module validates ad-hoc SQL queries before execution to prevent:
1. Destructive operations (DROP, DELETE, TRUNCATE, etc.)
2. Schema and data modifications (CREATE, ALTER, INSERT, UPDATE, REPLACE)
3. Excessive resource usage (timeouts, row limits, scan limits)
## Validation strategy
### Ad-hoc queries (execute_query tool)
- Whitelist approach: only SELECT and WITH (CTE) allowed
- Negative regex: block forbidden keywords (case-insensitive)
- Resource limits: 100K rows, 600s timeout, 100GB scan warning
### Template queries (execute_template tool)
- Blacklist approach: specific dangerous templates blocked (see blacklist.py)
- Templates are pre-validated SQL code from files
- Same resource limits as ad-hoc queries
## Usage
    from validator import validate_query, QueryValidationError

    try:
        validate_query('SELECT * FROM table')
    except QueryValidationError as e:
        print(f'Query rejected: {e}')
## Resource limits (from user requirements)
- MAX_ROWS: 100,000 (Relaxed tier: sufficient for analytical queries)
- MAX_TIMEOUT_SEC: 600 seconds (10 minutes - allows complex aggregations)
- SCAN_WARNING_GB: 100 GB, in 1024-based units (log warning but don't block)
These limits match the 'Relaxed' tier from the planning phase.
"""
import re
# Resource limits (Relaxed tier - from user requirements)
MAX_ROWS = 100_000  # Truncate results if more
MAX_TIMEOUT_SEC = 600  # 10 minutes max execution time
SCAN_WARNING_GB = 100  # Log warning if scan exceeds this (but don't block)

# Constants for byte conversions. Units are binary: 1 KB = 1024 bytes, so a
# "GB" derived from it (BYTES_IN_KB**3) is strictly a GiB.
BYTES_IN_KB = 1024

# Forbidden SQL keywords (destructive/modifying operations)
FORBIDDEN_KEYWORDS = {
    'DROP',
    'DELETE',
    'TRUNCATE',
    'INSERT',
    'UPDATE',
    'CREATE',
    'ALTER',
    'REPLACE',
}

# Allowed query start patterns (case-insensitive): WITH for CTEs, SELECT for
# queries. Only whitespace may precede the keyword, and the trailing \b
# rejects longer words such as 'SELECTED'.
ALLOWED_START_PATTERN = re.compile(r'^\s*(WITH|SELECT)\b', re.IGNORECASE)

# Forbidden keywords pattern (case-insensitive, whole words only). The word
# boundaries avoid false positives on longer identifiers: 'DELETED' or
# 'last_update' do not match DELETE/UPDATE, because letters, digits and '_'
# are all word characters.
# Keywords are sorted so the compiled pattern text is identical on every run
# (iterating a set of strings follows hash order, which varies per process).
FORBIDDEN_PATTERN = re.compile(
    r'\b(' + '|'.join(sorted(FORBIDDEN_KEYWORDS)) + r')\b',
    re.IGNORECASE,
)
class QueryValidationError(Exception):
    """Signal that a SQL query was rejected by the validator.

    The exception message states which validation check failed.
    """
def validate_query(sql_query: str) -> None:
"""
Validate ad-hoc SQL query before execution.
Checks:
1. Query starts with SELECT or WITH (whitelist)
2. No forbidden keywords present (blacklist)
3. Query is not empty or whitespace-only
Args:
sql_query: SQL query string to validate
Raises:
QueryValidationError: If query fails any validation check
Example:
>>> validate_query('SELECT * FROM users') # OK
>>> validate_query('DROP TABLE users') # Raises QueryValidationError
"""
if not sql_query or not sql_query.strip():
raise QueryValidationError('Query cannot be empty')
if not ALLOWED_START_PATTERN.match(sql_query):
raise QueryValidationError(
'Query must start with SELECT or WITH. '
'Other operations (CREATE, DROP, INSERT, etc) are forbidden.'
)
forbidden_match = FORBIDDEN_PATTERN.search(sql_query)
if forbidden_match:
keyword = forbidden_match.group(1).upper()
forbidden_list = ', '.join(sorted(FORBIDDEN_KEYWORDS))
raise QueryValidationError(
f'Query contains forbidden keyword: {keyword}. '
f'Destructive/modifying operations are not allowed. '
f'Forbidden: {forbidden_list}'
)
def format_scan_bytes(bytes_scanned: int) -> str:
    """
    Format scan bytes as human-readable string (GB/MB/KB).

    Units are binary multiples of BYTES_IN_KB (1 KB = 1024 bytes,
    1 GB = 1024**3 bytes). KB/MB/GB values are shown with one decimal place.
    Inputs below one KB (including zero and negatives) are shown as the raw
    value followed by 'B'. The next larger unit is used whenever rounding
    would display 1024.0 of the smaller one, so a value just under 1 GB
    renders as '1.0 GB' rather than '1024.0 MB'.

    Args:
        bytes_scanned: Number of bytes scanned by Athena

    Returns:
        Formatted string like '14.0 GB' or '223.6 MB'

    Example:
        >>> format_scan_bytes(15_000_000_000)
        '14.0 GB'
        >>> format_scan_bytes(234_500_000)
        '223.6 MB'
        >>> format_scan_bytes(1024**3 - 1)
        '1.0 GB'
        >>> format_scan_bytes(512)
        '512 B'
    """
    if bytes_scanned < BYTES_IN_KB:
        return f'{bytes_scanned} B'

    for unit, size in (('KB', BYTES_IN_KB), ('MB', BYTES_IN_KB**2)):
        text = f'{bytes_scanned / size:.1f}'
        # Stay in this unit only while the *rounded* value is below one step
        # of the next unit; otherwise e.g. 1023.96 MB would print '1024.0 MB'.
        if float(text) < BYTES_IN_KB:
            return f'{text} {unit}'

    return f'{bytes_scanned / BYTES_IN_KB**3:.1f} GB'
def check_scan_limit(bytes_scanned: int) -> str | None:
    """
    Check if query scan exceeds warning threshold.

    Note: This is a soft limit (warning only, doesn't block query).
    Athena charges per byte scanned, so we warn about expensive queries.

    The threshold is SCAN_WARNING_GB in binary units, i.e.
    SCAN_WARNING_GB * BYTES_IN_KB**3 bytes. A scan exactly at the threshold
    does not trigger a warning; only a strictly larger scan does.

    Args:
        bytes_scanned: Number of bytes scanned by query

    Returns:
        Warning message if threshold exceeded, None otherwise

    Example:
        >>> print(check_scan_limit(150 * 1024**3))  # 150 GB
        Query scanned 150.0 GB (threshold: 100 GB). This is an expensive query.
        >>> print(check_scan_limit(1024**3))  # 1 GB
        None
    """
    threshold_bytes = SCAN_WARNING_GB * BYTES_IN_KB**3
    if bytes_scanned > threshold_bytes:
        scanned_str = format_scan_bytes(bytes_scanned)
        return (
            f'Query scanned {scanned_str} (threshold: {SCAN_WARNING_GB} GB). '
            f'This is an expensive query.'
        )
    return None