import os
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from jinja2 import Environment, StrictUndefined, Template
from lib import timeutils
from lib.aws.athena.exceptions import (
TableS3LocationNotFound,
TemplateError,
TemplateFileNotFoundError,
)
from lib.enums import AthenaQueryNames
from lib.queryutils import strip_sql_comments
from settings import (
AWS_ATHENA_BRAND_NAME,
AWS_ATHENA_QUERY_TEMPLATES_DIR,
AWS_ATHENA_QUERY_TEMPLATES_STRIP_COMMENTS,
AWS_ATHENA_S3_DATA_SOURCE_LOCATION,
AWS_ATHENA_S3_UNLOAD_LOCATION,
)
def load_query_template(query_template_path: str) -> str:
    """Read a plain SQL query template from the configured templates directory.

    Args:
        query_template_path: Template file path, relative to
            ``AWS_ATHENA_QUERY_TEMPLATES_DIR``.

    Returns:
        The template text. SQL comments are removed with ``strip_sql_comments``
        when ``AWS_ATHENA_QUERY_TEMPLATES_STRIP_COMMENTS`` is truthy.

    Raises:
        ValueError: If ``AWS_ATHENA_QUERY_TEMPLATES_DIR`` is not configured.
        TemplateFileNotFoundError: If the template file does not exist.
    """
    if not AWS_ATHENA_QUERY_TEMPLATES_DIR:
        raise ValueError(
            'AWS_ATHENA_QUERY_TEMPLATES_DIR not configured. '
            'Set environment variable or check settings.py'
        )
    query_file_path = os.path.join(AWS_ATHENA_QUERY_TEMPLATES_DIR, query_template_path)
    try:
        # Explicit encoding: the platform default is locale-dependent and may
        # not be UTF-8, which would make template loading environment-specific.
        with open(query_file_path, 'r', encoding='utf-8') as sql_file:
            query_template = sql_file.read()
    except FileNotFoundError as err:
        # Chain the original error so the OS-level cause stays in the traceback.
        raise TemplateFileNotFoundError(query_file_path) from err
    if AWS_ATHENA_QUERY_TEMPLATES_STRIP_COMMENTS:
        query_template = strip_sql_comments(query_template)
    return query_template
def load_jinja_query_template(query_template_path: str) -> Template:
    """Load a SQL query template and compile it as a Jinja2 template.

    File loading (configuration check, path resolution, optional comment
    stripping, missing-file handling) is delegated to ``load_query_template``
    so both loaders behave identically.

    Args:
        query_template_path: Template file path, relative to
            ``AWS_ATHENA_QUERY_TEMPLATES_DIR``.

    Returns:
        A compiled Jinja2 ``Template``. It uses ``StrictUndefined``, so rendering
        raises if the template references a variable that was not supplied.

    Raises:
        ValueError: If ``AWS_ATHENA_QUERY_TEMPLATES_DIR`` is not configured.
        TemplateFileNotFoundError: If the template file does not exist.
    """
    query_template = load_query_template(query_template_path)
    # SQL templates, not HTML - autoescape not applicable
    env = Environment(undefined=StrictUndefined)  # nosec B701
    return env.from_string(query_template)
@dataclass
class AwsAthenaQuery:
    """A named Athena query template plus the values used to render it.

    ``query_template`` is either a plain ``str`` (rendered with ``str.format``)
    or a compiled Jinja2 ``Template`` (rendered with ``Template.render``).
    The S3 locations and brand default to the ``settings`` values captured
    when this module is imported.
    """

    name: AthenaQueryNames
    query_template: str | Template
    source_s3_location: str | None = AWS_ATHENA_S3_DATA_SOURCE_LOCATION
    aggregation_s3_location: str | None = AWS_ATHENA_S3_UNLOAD_LOCATION
    brand: str | None = AWS_ATHENA_BRAND_NAME
    version: str = 'latest'
    comment: str | None = None
    available_up_to_date: datetime | None = None

    def __post_init__(self) -> None:
        # Pass the date through timeutils.to_utc (presumably converting it to
        # UTC - confirm against lib.timeutils).
        if self.available_up_to_date is not None:
            self.available_up_to_date = timeutils.to_utc(self.available_up_to_date)

    def gen_uuid(self, params: dict[str, Any] | None = None) -> str:
        """Build an identifier string from the query identity and parameters.

        Args:
            params: Extra key/value pairs appended after ``name``, ``brand`` and
                ``version``. A key with one of those names overrides the
                query's own value. ``None`` or empty adds nothing.

        Returns:
            Comma-separated ``key=value`` pairs in insertion order.
        """
        query_conf: dict[str, Any] = {
            'name': self.name,
            'brand': self.brand,
            'version': self.version,
        }
        if params:
            query_conf.update(params)
        return ','.join(f'{key}={value}' for key, value in query_conf.items())

    def render(self, params: dict[str, Any] | None = None) -> str:
        """Render the query template into SQL text.

        ``aggregation_s3_location``, ``source_s3_location`` and ``brand`` are
        always available to the template; ``params`` may add to or override them.

        Args:
            params: Additional template variables.

        Returns:
            The rendered SQL. Plain ``str`` templates go through ``str.format``,
            so literal braces in them must be doubled.
        """
        data: dict[str, Any] = {
            'aggregation_s3_location': self.aggregation_s3_location,
            'source_s3_location': self.source_s3_location,
            'brand': self.brand,
        }
        if params:
            data.update(params)
        if isinstance(self.query_template, Template):
            return self.query_template.render(**data)
        return self.query_template.format(**data)

    def get_s3_location(self, params: dict[str, Any] | None = None) -> str:
        """Return the first quoted ``s3://`` URL after ``TO`` or ``LOCATION``.

        Args:
            params: Template variables forwarded to ``render``.

        Returns:
            The S3 URL found in the rendered query (without quotes).

        Raises:
            TableS3LocationNotFound: If the rendered query has no such clause.
        """
        rendered_query = self.render(params)
        # Case-insensitive match on e.g. "LOCATION 's3://...'" or "TO 's3://...'".
        pattern = r"\b(?:TO|LOCATION)\s*'(?P<url>s3://[^']+)'"
        match = re.search(pattern, rendered_query, re.IGNORECASE)
        if match is None:
            raise TableS3LocationNotFound(self.name)
        return match.group('url')
# Fail fast at import time: the system queries below are all tagged with this
# brand, so a missing AWS_ATHENA_BRAND_NAME is a configuration error.
if AWS_ATHENA_BRAND_NAME:
    _BRAND: str = AWS_ATHENA_BRAND_NAME
else:
    raise ValueError('AWS_ATHENA_BRAND_NAME is required. Set environment variable.')
# System queries for MCP server (minimal set).
# Each entry is (query name, str.format template); all share the same brand.
SYS_QUERIES: list[AwsAthenaQuery] = [
    AwsAthenaQuery(name=sys_name, query_template=sys_template, brand=_BRAND)
    for sys_name, sys_template in (
        (AthenaQueryNames.MSCK_REPAIR_TABLE, 'MSCK REPAIR TABLE {table}'),
        (AthenaQueryNames.SHOW_CREATE_TABLE, 'SHOW CREATE TABLE {table}'),
        (AthenaQueryNames.DROP_TABLE_IF_EXISTS, 'DROP TABLE IF EXISTS {table}'),
        (AthenaQueryNames.DROP_TABLE, 'DROP TABLE {table}'),
    )
]
# Registry searched by get_athena_query; a new list seeded with the system queries.
QUERIES: list[AwsAthenaQuery] = list(SYS_QUERIES)
def get_athena_query(
    query_name: AthenaQueryNames, query_version: str | None = None
) -> AwsAthenaQuery:
    """Look up a registered Athena query by name and version.

    Args:
        query_name: Query name from the AthenaQueryNames enum.
        query_version: Version to match; ``None`` means ``'latest'``.

    Returns:
        The first AwsAthenaQuery in QUERIES whose name and version both match.

    Raises:
        TemplateError: If no entry in QUERIES matches.
    """
    wanted_version = 'latest' if query_version is None else query_version
    found = next(
        (
            candidate
            for candidate in QUERIES
            if candidate.name == query_name and candidate.version == wanted_version
        ),
        None,
    )
    if found is None:
        raise TemplateError(query_name)
    return found