import itertools
from typing import Any

from jinja2 import Environment, StrictUndefined
from jinja2.sandbox import SandboxedEnvironment

import sanitizer
import validator
from athena_tools import _utils
from lib import loggerutils
from lib.aws.athena.athena_config import AwsAthenaQuery
from lib.enums import AthenaQueryNames
# Tool identifier string (presumably consumed by whatever registers this tool;
# the registry is not visible in this module).
TOOL_NAME = 'execute_query'
# Milliseconds per second, used by `execute` to turn Athena's
# TotalExecutionTimeInMillis statistic into seconds.
MS_IN_SEC = 1_000
# Free-text description of this tool, with surrounding blank lines removed by
# .strip(); presumably shown to the tool's caller alongside TOOL_INPUT_SCHEMA.
# The "100K rows, 600s timeout" figures are hard-coded here, while `execute`
# enforces validator.MAX_ROWS / validator.MAX_TIMEOUT_SEC, so the two must be
# kept in sync by hand.
# NOTE(review): the examples use single-brace placeholders ({brand}), but
# `execute` renders the query with a default-delimiter Jinja2 environment,
# whose variable syntax is {{ brand }}. Single braces are only substituted if
# AwsAthenaQuery.render does extra formatting on top of Jinja2; confirm and
# fix the examples if it does not.
TOOL_DESCRIPTION = """
Execute ad-hoc SQL query (SELECT only, validated) with Jinja2 template support.
For exploratory analysis when templates don't fit your needs.
Strict validation: only SELECT/WITH allowed, no DROP/DELETE/CREATE/ALTER/INSERT.
Same limits as templates: 100K rows, 600s timeout, data sanitization.
Query supports Jinja2 variables like {brand}, {year}, {month}, {day}.
Example: "SELECT COUNT(*) FROM provider__actions_{brand} WHERE year={year}"
with params={"brand": "alpha", "year": 2024}
""".strip()
# JSON-Schema-shaped description of the tool's arguments. It mirrors the
# signature of `execute`: a required `sql_query` string and an optional
# `params` object of template variables.
TOOL_INPUT_SCHEMA = dict(
    type='object',
    properties=dict(
        sql_query=dict(
            type='string',
            description=(
                "SQL query with Jinja2 template variables. "
                "Must start with SELECT or WITH. "
                "Example: 'SELECT * FROM provider__actions_{brand} "
                "WHERE year={year} AND month={month}'"
            ),
        ),
        params=dict(
            type='object',
            description=(
                "Optional Jinja2 template parameters. "
                "Common params: brand (alpha/beta/gamma), year, month, "
                "day, hour. Example: {'brand': 'alpha', 'year': 2024, "
                "'month': 12}"
            ),
        ),
    ),
    required=['sql_query'],
)
# Module-level logger, obtained from the project's loggerutils under the
# 'analytics' name; `execute` logs structured payloads through it.
logger = loggerutils.getLogger(
    'analytics',
)
def execute(sql_query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
    """
    Execute an ad-hoc SQL query (SELECT only) with optional Jinja2 template parameters.

    Applies strict validation:
    1. Query validation (only SELECT/WITH allowed). It runs on the raw template
       text AND on the rendered SQL, so Jinja control blocks or parameter
       values cannot introduce anything the raw-text check never saw.
    2. Sandboxed rendering. The template is compiled in a Jinja2
       ``SandboxedEnvironment`` with ``StrictUndefined``. Caller-supplied
       template text therefore cannot reach Python internals, and an undefined
       variable is an error instead of silently rendering as empty.
    3. Resource limits: at most ``validator.MAX_ROWS`` rows are fetched, and
       the Athena timeout is ``validator.MAX_TIMEOUT_SEC``.
    4. Data sanitization, via ``sanitizer``.

    Args:
        sql_query: SQL query string with Jinja2 template variables. Must start
            with SELECT or WITH.
            Example: 'SELECT * FROM provider__actions_{brand} WHERE year={year}'
        params: Optional dict of template parameters for Jinja2 rendering.
            Example: {'brand': 'alpha', 'year': 2024, 'month': 12}

    Returns:
        A dict with these keys:
            results: sanitized rows, at most ``validator.MAX_ROWS``.
            row_count: number of rows in ``results``.
            execution_id: Athena execution id of the query.
            execution_time_sec: Athena's TotalExecutionTimeInMillis divided by
                ``MS_IN_SEC``, or None when the statistic is absent.
            bytes_scanned: Athena's DataScannedInBytes, or None when absent.
            bytes_scanned_formatted: ``validator.format_scan_bytes`` of
                bytes_scanned, or None when it is missing or zero.
            sensitive_columns: result of ``sanitizer.get_sensitive_columns``.
            warnings: list of warning strings (scan limit, row truncation).

    Raises:
        QueryValidationError: If the raw or rendered query contains forbidden
            keywords or patterns (from ``validator.validate_query``).
        jinja2.exceptions.TemplateSyntaxError: If ``sql_query`` is not a valid
            Jinja2 template.
        jinja2.exceptions.SecurityError: If the template attempts something the
            sandbox forbids (e.g. touching underscore-prefixed attributes).
        jinja2.exceptions.UndefinedError: Presumably, when the template
            references a variable missing from ``params`` (StrictUndefined).
            This depends on AwsAthenaQuery.render forwarding to the template.
        Whatever ``response.status.raise_for_status()`` raises for a failed
        Athena execution.

    Example::

        execute(
            'SELECT COUNT(*) FROM provider__actions_{brand} WHERE year={year}',
            {'brand': 'alpha', 'year': 2024},
        )
        # -> {'results': [{'count': 12345}], 'row_count': 1, ...}

    NOTE(review): with default Jinja2 delimiters, variables are written
    ``{{ brand }}``. The single-brace examples above are only substituted if
    AwsAthenaQuery.render performs extra formatting; confirm.
    """
    logger.info('execute_query called', data={'sql_query': sql_query, 'params': params})
    template_params = params or {}

    # Cheap check on the raw text before any rendering work is done.
    validator.validate_query(sql_query)

    # The query text is caller-controlled. With a plain Environment, a
    # template such as {{ ''.__class__.__mro__ }} can walk into Python
    # internals and execute arbitrary code, so compile it in the sandbox.
    # Autoescaping stays off on purpose: the output is SQL, not HTML.
    env = SandboxedEnvironment(undefined=StrictUndefined)  # nosec B701
    template = env.from_string(sql_query)
    query = AwsAthenaQuery(
        name=AthenaQueryNames.AD_HOC_QUERY,
        query_template=template,
    )
    rendered_sql = query.render(template_params)
    logger.info(
        'Query rendering',
        data={
            'rendered_sql': rendered_sql,
            'query.brand': query.brand,
            'params': template_params,
        },
    )
    # Parameter values and {% %} blocks are expanded verbatim, so the SQL that
    # actually reaches Athena must pass the same validation as the raw text.
    validator.validate_query(rendered_sql)

    athena = _utils.get_athena_helper()
    response = athena.execute(
        query=query,
        params=template_params,
        timeout=validator.MAX_TIMEOUT_SEC,
        enable_cache=True,
        wait_execution_finish=True,
    )
    response.status.raise_for_status()

    # Pull a single row beyond the limit. Its presence is how truncation is
    # detected without materialising the whole result set.
    rows_iter = athena.get_query_results(response.execution_id)
    results = list(itertools.islice(rows_iter, validator.MAX_ROWS + 1))
    truncated = len(results) > validator.MAX_ROWS
    if truncated:
        results = results[: validator.MAX_ROWS]

    sensitive_columns = sanitizer.get_sensitive_columns(results)
    sanitized_results = sanitizer.sanitize_results(results)

    stats = response.status.raw_details.get('Statistics', {})
    bytes_scanned = stats.get('DataScannedInBytes')
    execution_time_ms = stats.get('TotalExecutionTimeInMillis')

    warning_messages = []
    if bytes_scanned:
        scan_warning = validator.check_scan_limit(bytes_scanned)
        if scan_warning:
            warning_messages.append(scan_warning)
    if truncated:
        # Only MAX_ROWS + 1 rows were fetched, so the true total is unknown;
        # report "more than" rather than a fabricated count.
        warning_messages.append(
            f'Results truncated: more than {validator.MAX_ROWS} rows → '
            f'{validator.MAX_ROWS} rows (limit)'
        )

    result = {
        'results': sanitized_results,
        'row_count': len(sanitized_results),
        'execution_id': response.execution_id,
        # `is not None` so that a reported 0 ms becomes 0.0 rather than None.
        'execution_time_sec': (
            execution_time_ms / MS_IN_SEC if execution_time_ms is not None else None
        ),
        'bytes_scanned': bytes_scanned,
        'bytes_scanned_formatted': (
            validator.format_scan_bytes(bytes_scanned) if bytes_scanned else None
        ),
        'sensitive_columns': sensitive_columns,
        'warnings': warning_messages,
    }
    logger.info(
        'execute_query completed',
        data={
            'execution_id': response.execution_id,
            'row_count': len(sanitized_results),
            'execution_time_sec': result['execution_time_sec'],
            'bytes_scanned_formatted': result['bytes_scanned_formatted'],
        },
    )
    return result