import dataclasses
import gzip
import itertools
import json
import re
import time
from datetime import date, datetime
from decimal import Decimal
from io import BytesIO
from typing import Any, Callable, Iterable, Iterator
from urllib.parse import urlsplit
import pandas as pd
from botocore.client import BaseClient
from botocore.exceptions import ClientError
from lib import loggerutils
from lib.aws.athena.athena_config import AwsAthenaQuery, get_athena_query
from lib.aws.athena.exceptions import (
AthenaAWSDisabledError,
AthenaQueryExecutionError,
HIVECannotOpenSplit,
HIVEMalformedData,
HIVEPartitionExist,
S3ObjectNotFoundError,
)
from lib.aws.aws_auth_helper import (
create_boto3_session,
invalidate_credentials_cache,
)
from lib.aws.aws_profile_config import AWSProfileConfig, get_aws_profile_config
from lib.enums import AthenaQueryNames, AthenaQueryStatuses, AwsProfilesList
from settings import (
AWS_ATHENA_DB_MAIN,
AWS_ATHENA_DEFAULT_CACHE_TIMEOUT_MIN,
AWS_ATHENA_DEFAULT_QUERY_TIMEOUT_SEC,
AWS_ATHENA_LONG_QUERY_SEC,
AWS_ATHENA_S3_TMP_LOCATION,
AWS_ATHENA_WORKGROUP,
AWS_ENABLED,
)
logger = loggerutils.getLogger('analytics')
_DEFAULT_SLEEP_TIME_SEC = 1
def _extract_s3_path_from_error(error_message: str) -> str | None:
    """Extract the first s3://...parquet path from an Athena error message, if any."""
    pattern = r's3://[^\s\]]+\.parquet'
    match = re.search(pattern, error_message)
    return match.group(0) if match else None
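# Illustrative sketch of the helper above; the message shape is an assumption
# modeled on typical HIVE_BAD_DATA errors, not a guaranteed AWS format:
#
#   msg = 'HIVE_BAD_DATA: Malformed Parquet file [s3://bucket/tmp/part-00000.parquet]'
#   _extract_s3_path_from_error(msg)  # -> 's3://bucket/tmp/part-00000.parquet'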
@dataclasses.dataclass
class AthenaQueryStatus:
execution_id: str
status: AthenaQueryStatuses
raw_details: dict[str, Any]
execution_time: float = 0
error_message: str | None = None
    def raise_for_status(self) -> None:
        if self.status not in (
            AthenaQueryStatuses.FAILED,
            AthenaQueryStatuses.CANCELLED,
        ):
            return
        error_message = self.error_message or ''
        if 'HIVE_PATH_ALREADY_EXISTS' in error_message:
            raise HIVEPartitionExist(
                execution_id=self.execution_id,
                athena_error_message=error_message,
            )
        if 'HIVE_CANNOT_OPEN_SPLIT' in error_message:
            raise HIVECannotOpenSplit(
                execution_id=self.execution_id,
                athena_error_message=error_message,
            )
        if (
            'HIVE_BAD_DATA' in error_message
            and 'Malformed Parquet file' in error_message
            and 'incompatible with type' in error_message
        ):
            s3_path = _extract_s3_path_from_error(error_message)
            if s3_path:
                raise HIVEMalformedData(
                    execution_id=self.execution_id,
                    athena_error_message=error_message,
                    s3_file_path=s3_path,
                )
        raise AthenaQueryExecutionError(
            execution_id=self.execution_id,
            athena_error_message=error_message or 'Query failed with no error message',
        )
@dataclasses.dataclass
class AthenaResponse:
execution_id: str
raw_response: dict[str, Any]
status: AthenaQueryStatus
@dataclasses.dataclass
class AthenaSelectResponse:
"""
Response object for AWS Athena SELECT query execution.
Provides efficient access to query results with automatic caching of the first result.
When results are assigned, the first row is extracted and stored for quick access
without consuming the entire iterator.
Attributes:
execution_id: Query execution identifier
raw_response: Raw AWS Athena API response
status: Query execution status
results: Iterator over all result rows (includes first result)
first_result: Cached first result row for efficient access
"""
execution_id: str
raw_response: dict[str, Any]
status: AthenaQueryStatus
_results: Iterator[dict[str, Any]] | None = dataclasses.field(default=None, repr=False)
_first_result: dict[str, Any] | None = dataclasses.field(default=None, repr=False)
    @property
    def results(self) -> Iterable[dict[str, Any]] | None:
        # The wrapped iterator is single-use: iterate `results` only once.
        if self._first_result is not None:
            if self._results is not None:
                return itertools.chain([self._first_result], self._results)
            return iter([self._first_result])
        return self._results
@results.setter
def results(self, value: Iterable[dict[str, Any]]) -> None:
it: Iterator[dict[str, Any]] = iter(value)
self._first_result = next(it, None)
self._results = it
@property
def first_result(self) -> dict[str, Any] | None:
return self._first_result
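# Usage sketch for AthenaSelectResponse (illustrative; `resp` is produced by
# AthenaHelper.select below, `process` is a hypothetical callback):
#
#   if resp.first_result is None:       # cheap peek, consumes nothing twice
#       ...                             # empty result set
#   else:
#       for row in resp.results:        # re-yields the cached first row, then the rest
#           process(row)
#
# The wrapped iterator is single-use, so iterate `results` only once.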
class AthenaResultDataConverter:
    _athena_to_python_types_mapper: dict[str, Callable[[Any], Any]] = {
        'integer': int,
        'bigint': int,
        'smallint': int,
        'tinyint': int,
        'varchar': str,
        'string': str,
        # bool('false') would be truthy, so compare the lowercase literal instead.
        'boolean': lambda v: str(v).lower() == 'true',
        'double': float,
        'decimal': Decimal,
        # Athena sometimes reports parametrized decimal types in column metadata.
        'decimal(30,10)': Decimal,
        'decimal(20,0)': Decimal,
        'timestamp': datetime.fromisoformat,
        'timestamp with time zone': datetime.fromisoformat,
        'date': date.fromisoformat,
    }
def __init__(self, result_set_metadata: dict[str, Any]):
self.column_info = result_set_metadata['ColumnInfo']
self.column_names = [c['Name'] for c in self.column_info]
self.column_types_mapper = {c['Name']: c['Type'] for c in self.column_info}
    def process_row(self, row: dict[str, Any]) -> dict[str, Any]:
        column_values: list[Any] = []
        for val in row['Data']:
            # Athena returns an empty dict (no 'VarCharValue') for NULL cells.
            column_values.append(val.get('VarCharValue') if val else None)
        raw_record = dict(zip(self.column_names, column_values))
        return self._convert_to_python_type(raw_record)
    def _convert_to_python_type(self, record: dict[str, Any]) -> dict[str, Any]:
        converted_record: dict[str, Any] = {}
        for col_name, col_value in record.items():
            col_type = self.column_types_mapper[col_name]
            # Fall back to the raw string for types without an explicit converter.
            converter: Callable[[Any], Any] = self._athena_to_python_types_mapper.get(
                col_type, lambda v: v
            )
            converted_record[col_name] = converter(col_value) if col_value else col_value
        return converted_record
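# Illustrative sketch of AthenaResultDataConverter against the GetQueryResults
# response shape (values below are made up):
#
#   metadata = {'ColumnInfo': [{'Name': 'id', 'Type': 'bigint'},
#                              {'Name': 'created', 'Type': 'date'}]}
#   converter = AthenaResultDataConverter(metadata)
#   converter.process_row({'Data': [{'VarCharValue': '42'},
#                                   {'VarCharValue': '2024-01-01'}]})
#   # -> {'id': 42, 'created': datetime.date(2024, 1, 1)}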
class AthenaHelper:
_aws_profile: AWSProfileConfig
_aws_athena_database: str | None
_aws_athena_workgroup: str
_aws_athena_cache_timeout_min: int
_aws_athena_tmp_dir: str | None
_athena_client: BaseClient | None
def __init__(
self,
aws_profile_name: AwsProfilesList = AwsProfilesList.ATHENA_ANALYTICS_INTERNAL,
aws_athena_database: str | None = AWS_ATHENA_DB_MAIN,
aws_athena_workgroup: str = AWS_ATHENA_WORKGROUP,
aws_athena_cache_timeout_min: int = AWS_ATHENA_DEFAULT_CACHE_TIMEOUT_MIN,
aws_athena_tmp_dir: str | None = AWS_ATHENA_S3_TMP_LOCATION,
):
if not AWS_ENABLED:
raise AthenaAWSDisabledError()
self._athena_client = None
self._aws_profile = get_aws_profile_config(aws_profile_name)
self._aws_athena_database = aws_athena_database
self._aws_athena_cache_timeout_min = aws_athena_cache_timeout_min
self._aws_athena_workgroup = aws_athena_workgroup
self._aws_athena_tmp_dir = aws_athena_tmp_dir
self._s3_client: BaseClient | None = None
def _create_athena_client(self) -> BaseClient:
"""Create new Athena client with current credentials."""
session = create_boto3_session(self._aws_profile)
return session.client('athena')
@property
def athena_client(self) -> BaseClient:
if not self._athena_client:
self._athena_client = self._create_athena_client()
return self._athena_client
def _refresh_athena_client(self) -> None:
"""Force recreation of athena client with fresh credentials."""
invalidate_credentials_cache(self._aws_profile.name.value)
self._athena_client = self._create_athena_client()
def _athena_call(self, method: str, **kwargs: Any) -> Any:
"""Execute Athena client method with automatic credential refresh on token expiration."""
client_method = getattr(self.athena_client, method)
if not callable(client_method):
raise AttributeError(f'Athena client has no callable method "{method}"')
try:
return client_method(**kwargs)
except ClientError as e:
error_code = e.response.get('Error', {}).get('Code', '')
if error_code in {'ExpiredToken', 'ExpiredTokenException'}:
self._refresh_athena_client()
return getattr(self.athena_client, method)(**kwargs)
raise
def _paginated_athena_call(
self, operation_name: str, **paginate_kwargs: Any
) -> Iterable[dict[str, Any]]:
"""Execute paginated Athena call with credential refresh during iteration.
Handles ExpiredTokenException during page iteration by refreshing
credentials and resuming from last successful page using NextToken.
"""
next_token = None
while True:
try:
paginator = self.athena_client.get_paginator(operation_name)
# Resume from last successful page if we had to refresh
if next_token:
paginate_kwargs['NextToken'] = next_token
page_iterator = paginator.paginate(**paginate_kwargs)
for page in page_iterator:
# Save token before yielding in case iteration fails
next_token = page.get('NextToken')
yield page
# Successfully iterated all pages
return
except ClientError as e:
error_code = e.response.get('Error', {}).get('Code', '')
if error_code in {'ExpiredToken', 'ExpiredTokenException'}:
self._refresh_athena_client()
# Loop will retry with next_token to resume from last page
continue
raise
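    # Resumption sketch: if credentials expire between pages, the generator above
    # refreshes the client and restarts pagination from the saved NextToken, so
    # already-yielded pages are not re-fetched:
    #
    #   for page in self._paginated_athena_call(
    #       'get_query_results', QueryExecutionId=execution_id
    #   ):
    #       ...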
def _start_query_execution(self, **kwargs: Any) -> dict[str, Any]:
"""Start Athena query execution with automatic credential refresh on token expiration."""
return self._athena_call('start_query_execution', **kwargs)
def select(
self,
query: AwsAthenaQuery,
params: dict[str, Any] | None = None,
enable_cache: bool = False,
timeout: int = AWS_ATHENA_DEFAULT_QUERY_TIMEOUT_SEC,
) -> AthenaSelectResponse:
params = params or {}
rendered_query = query.render(params)
self._log_query_start(name=query.name, params=params, method='select')
aws_params = dict(
QueryString=rendered_query,
QueryExecutionContext={'Database': self._aws_athena_database},
ResultReuseConfiguration={
'ResultReuseByAgeConfiguration': {
'Enabled': enable_cache,
'MaxAgeInMinutes': self._aws_athena_cache_timeout_min,
}
},
WorkGroup=self._aws_athena_workgroup,
)
if self._aws_athena_tmp_dir:
aws_params['ResultConfiguration'] = {
'OutputLocation': self._aws_athena_tmp_dir,
}
response = self._start_query_execution(**aws_params)
execution_id = response['QueryExecutionId']
status = self.wait_query_execution(
execution_id=execution_id,
timeout=timeout,
)
self._log_query_finish(name=query.name, params=params, method='select', status=status)
status.raise_for_status()
results = self._get_query_results(execution_id, has_headers=True)
response_obj = AthenaSelectResponse(
execution_id=execution_id,
raw_response=response,
status=status,
)
response_obj.results = results
return response_obj
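    # Usage sketch for select() (the query name is hypothetical, not one defined
    # in this repo):
    #
    #   helper = AthenaHelper()
    #   resp = helper.select(
    #       get_athena_query(AthenaQueryNames.SOME_SELECT),  # hypothetical name
    #       params={'day': '2024-01-01'},
    #       enable_cache=True,
    #   )
    #   rows = list(resp.results or [])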
def execute(
self,
query: AwsAthenaQuery,
params: dict[str, Any] | None = None,
timeout: int = AWS_ATHENA_DEFAULT_QUERY_TIMEOUT_SEC,
enable_cache: bool = True,
wait_execution_finish: bool = True,
) -> AthenaResponse:
params = params or {}
rendered_query = query.render(params)
self._log_query_start(name=query.name, params=params, method='execute')
aws_params = dict(
QueryString=rendered_query,
QueryExecutionContext={'Database': self._aws_athena_database},
ResultReuseConfiguration={
'ResultReuseByAgeConfiguration': {
'Enabled': enable_cache,
'MaxAgeInMinutes': self._aws_athena_cache_timeout_min,
}
},
WorkGroup=self._aws_athena_workgroup,
)
if self._aws_athena_tmp_dir:
aws_params['ResultConfiguration'] = {
'OutputLocation': self._aws_athena_tmp_dir,
}
response = self._start_query_execution(**aws_params)
execution_id = response['QueryExecutionId']
if wait_execution_finish:
status = self.wait_query_execution(execution_id, timeout=timeout)
else:
status = self.get_query_status(execution_id)
self._log_query_finish(name=query.name, params=params, method='execute', status=status)
status.raise_for_status()
return AthenaResponse(
execution_id=execution_id,
raw_response=response,
status=status,
)
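    # Fire-and-poll sketch for long-running statements (query name hypothetical):
    #
    #   resp = helper.execute(
    #       get_athena_query(AthenaQueryNames.SOME_INSERT),  # hypothetical name
    #       wait_execution_finish=False,
    #   )
    #   status = helper.wait_query_execution(resp.execution_id, timeout=3600)
    #   status.raise_for_status()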
def wait_query_execution(
self,
execution_id: str,
timeout: int = AWS_ATHENA_DEFAULT_QUERY_TIMEOUT_SEC,
) -> AthenaQueryStatus:
while True:
status = self.get_query_status(execution_id)
if status.status in (
AthenaQueryStatuses.FAILED,
AthenaQueryStatuses.CANCELLED,
AthenaQueryStatuses.SUCCEEDED,
):
return status
            if timeout != -1 and status.execution_time >= timeout:
                self.stop_query_execution(execution_id)
                status.error_message = (
                    f'Terminated by athena helper, reason - timeout after {timeout}s'
                )
                status.status = AthenaQueryStatuses.CANCELLED
                return status
time.sleep(_DEFAULT_SLEEP_TIME_SEC)
def stop_query_execution(self, query_execution_id: str) -> None:
logger.info('Stop query execution', data=query_execution_id)
self._athena_call('stop_query_execution', QueryExecutionId=query_execution_id)
def get_query_results(self, query_execution_id: str) -> Iterable[dict[str, Any]]:
return self._get_query_results(query_execution_id, has_headers=True)
def _get_query_results(
self, query_execution_id: str, has_headers: bool
) -> Iterable[dict[str, Any]]:
page_iterator = self._paginated_athena_call(
'get_query_results', QueryExecutionId=query_execution_id
)
for idx, page in enumerate(page_iterator):
rows = page['ResultSet']['Rows']
if has_headers and idx == 0:
rows = rows[1:]
athena_result_converter = AthenaResultDataConverter(
page['ResultSet']['ResultSetMetadata']
)
for row in rows:
record = athena_result_converter.process_row(row)
yield record
    def get_query_status(self, query_execution_id: str) -> AthenaQueryStatus:
        response = self._athena_call('get_query_execution', QueryExecutionId=query_execution_id)
        query_execution = response['QueryExecution']
        # TotalExecutionTimeInMillis may be absent while the query is still queued.
        execution_time_ms = query_execution['Statistics'].get('TotalExecutionTimeInMillis', 0)
        status = AthenaQueryStatus(
            execution_id=query_execution_id,
            status=AthenaQueryStatuses(query_execution['Status']['State']),
            raw_details=query_execution,
            execution_time=execution_time_ms / 1000,
        )
        error_details = query_execution['Status'].get('AthenaError')
        if error_details:
            status.error_message = error_details['ErrorMessage']
        return status
def load_csv_query_result_files(
self,
execution_id: str,
s3_output_bucket: str,
s3_output_path: str,
) -> pd.DataFrame:
"""Only for not partitioned csv files,
for result of a simple Athena select query,
for result of unload use PARQUET or JSON format instead"""
key = f"{s3_output_path}/{execution_id}.csv"
response = self._get_object_from_s3(s3_output_bucket, key)
df = pd.read_csv(response['Body'])
return df
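    # Usage sketch (bucket and path are placeholders and must match the
    # workgroup's result location; the path carries no s3:// prefix):
    #
    #   df = helper.load_csv_query_result_files(
    #       execution_id=resp.execution_id,
    #       s3_output_bucket='my-athena-results',
    #       s3_output_path='tmp/athena',
    #   )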
def list_tables(self, name_filter_regex: str = '') -> list[dict[str, Any]]:
response_iterator = self._paginated_athena_call(
'list_table_metadata',
CatalogName='AwsDataCatalog',
DatabaseName=self._aws_athena_database,
Expression=name_filter_regex,
)
tables = []
for page in response_iterator:
tables.extend(page['TableMetadataList'])
return tables
def list_table_names(self, name_filter_regex: str = '') -> list[str]:
return [x['Name'] for x in self.list_tables(name_filter_regex)]
def get_table_partitions(self, table_name: str) -> dict[str, str]:
response = self._athena_call(
'get_table_metadata',
CatalogName='AwsDataCatalog',
DatabaseName=self._aws_athena_database,
TableName=table_name,
)
table_location = response['TableMetadata']['Parameters']['location']
s3_url_split = urlsplit(table_location, allow_fragments=False)
path = s3_url_split.path
partitions = {}
for folder in path.split('/'):
if '=' not in folder:
continue
            # Split once so partition values containing '=' are kept intact.
            key, value = folder.split('=', 1)
partitions[key] = value
return partitions
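    # Example of the parsing above: a table location such as
    # s3://bucket/events/year=2024/month=01/ yields {'year': '2024', 'month': '01'}.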
def get_table_ddl(self, table_name: str) -> str:
response = self.execute(
get_athena_query(AthenaQueryNames.SHOW_CREATE_TABLE),
{'table': table_name},
)
rows = self._get_query_results(response.execution_id, has_headers=False)
result = []
for row in rows:
result.append(row['createtab_stmt'])
return '\n'.join(result)
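    # Usage sketch (table name is a placeholder):
    #
    #   ddl = helper.get_table_ddl('my_events_table')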
def load_json_query_result_files(
self,
execution_id: str,
s3_output_bucket: str,
s3_output_path: str,
) -> list[dict[str, Any]]:
s3_obj_list = self._get_manifest_csv_file(s3_output_bucket, s3_output_path, execution_id)
if not s3_obj_list:
            return []
list_of_records = []
for s3_obj_url in s3_obj_list:
s3_path_split = urlsplit(s3_obj_url)
bucket, key = s3_path_split.netloc, s3_path_split.path[1:]
response = self._get_object_from_s3(bucket, key)
buffer = BytesIO(response['Body'].read())
gz_file = gzip.GzipFile(fileobj=buffer)
for line in gz_file:
json_object = json.loads(line)
list_of_records.append(json_object)
return list_of_records
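    # Usage sketch for UNLOAD output written as gzipped JSON lines (bucket and
    # path are placeholders, as above):
    #
    #   resp = helper.execute(...)  # an UNLOAD ... (format = 'JSON') query
    #   records = helper.load_json_query_result_files(
    #       execution_id=resp.execution_id,
    #       s3_output_bucket='my-athena-results',
    #       s3_output_path='tmp/athena/unload',
    #   )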
def _get_manifest_csv_file(self, bucket: str, path: str, query_execution_id: str) -> list[str]:
key = f"{path}/{query_execution_id}-manifest.csv"
response = self._get_object_from_s3(bucket, key)
s3_obj_list_bytes = response['Body'].read().splitlines()
s3_obj_list = [b.decode('utf-8') for b in s3_obj_list_bytes]
return s3_obj_list
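    # The manifest is one S3 object URL per line, e.g. (placeholder values):
    #   s3://my-athena-results/tmp/athena/unload/<file>.gz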
def _create_s3_client(self) -> BaseClient:
"""Create new S3 client with current credentials."""
session = create_boto3_session(self._aws_profile)
return session.client('s3')
def _get_s3_client(self) -> BaseClient:
if self._s3_client is None:
self._s3_client = self._create_s3_client()
return self._s3_client
def _refresh_s3_client(self) -> None:
"""Force recreation of S3 client with fresh credentials."""
invalidate_credentials_cache(self._aws_profile.name.value)
self._s3_client = self._create_s3_client()
def _get_object_from_s3(self, bucket: str, key: str) -> dict[str, Any]:
"""Get object from S3 with automatic credential refresh on token expiration."""
try:
return self._get_s3_client().get_object(Bucket=bucket, Key=key)
except ClientError as e:
error_code = e.response.get('Error', {}).get('Code', '')
if error_code in {'ExpiredToken', 'ExpiredTokenException'}:
self._refresh_s3_client()
return self._get_s3_client().get_object(Bucket=bucket, Key=key)
if error_code == 'NoSuchKey':
raise S3ObjectNotFoundError(bucket, key)
raise
def _log_query_start(self, name: AthenaQueryNames, params: dict[str, Any], method: str) -> None:
logger.info(
            'Query started',
data=name,
extra_data={'params': params, 'method': method},
)
def _log_query_finish(
self, name: AthenaQueryNames, status: AthenaQueryStatus, method: str, params: dict[str, Any]
) -> None:
log_data = {
'method': method,
'params': params,
'execution_time': status.execution_time,
'execution_id': status.execution_id,
}
if status.status in (
AthenaQueryStatuses.FAILED,
AthenaQueryStatuses.CANCELLED,
):
log_data['raw'] = status.raw_details
logger.error(
f'Query {status.status.lower()}',
data=name,
extra_data=log_data,
sentry_skip=True,
)
else:
logger.info('Query finished', data=name, extra_data=log_data)
if status.execution_time >= AWS_ATHENA_LONG_QUERY_SEC:
log_data['raw'] = status.raw_details
logger.warning(
'Long query execution',
data=name,
extra_data=log_data,
)