import asyncio
import time
from typing import Any, Dict, List, Optional
import boto3
import structlog
from botocore.exceptions import ClientError, NoCredentialsError
from ..config.settings import Settings
logger = structlog.get_logger(__name__)
class AthenaClient:
    """Asynchronous convenience wrapper around the boto3 Athena client.

    The boto3 session and client are created lazily on first use and then
    cached on the instance. Every boto3 call is executed in the event loop's
    default executor so that the (blocking) network round-trip does not stall
    other coroutines.

    The ``settings`` object is read for: ``aws_profile``, ``aws_region``,
    ``security_lake_database``, ``athena_output_location``,
    ``athena_workgroup``, ``query_timeout_seconds`` and ``max_query_results``.
    """

    # States in which polling of a query execution stops.
    _TERMINAL_STATES = frozenset({"SUCCEEDED", "FAILED", "CANCELLED"})
    # Data catalog used by the database/table listing helpers.
    _CATALOG_NAME = "AwsDataCatalog"
    # Delay between two status polls while waiting for a query.
    _POLL_INTERVAL_SECONDS = 1

    def __init__(self, settings: "Settings"):
        """Store ``settings``; no AWS session or client is created yet."""
        self.settings = settings
        # boto3 clients are generated at runtime and have no importable type.
        self._client: Optional[Any] = None
        self._session: Optional[boto3.Session] = None

    def _get_session(self) -> "boto3.Session":
        """Return the cached boto3 session, creating it on first use.

        The session always uses ``settings.aws_region``; the named profile is
        only passed when ``settings.aws_profile`` is truthy.
        """
        if self._session is None:
            session_kwargs = {"region_name": self.settings.aws_region}
            if self.settings.aws_profile:
                session_kwargs["profile_name"] = self.settings.aws_profile
            self._session = boto3.Session(**session_kwargs)
        return self._session

    def _get_client(self) -> Any:
        """Return the cached Athena client, creating it on first use."""
        if self._client is None:
            self._client = self._get_session().client("athena")
        return self._client

    async def _call(self, operation: str, **kwargs: Any) -> Dict[str, Any]:
        """Invoke the Athena client method ``operation`` off the event loop.

        boto3 is synchronous; running the call in the default executor keeps
        the event loop responsive. Exceptions raised by the call propagate to
        the awaiting coroutine unchanged.
        """
        method = getattr(self._get_client(), operation)
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: method(**kwargs))

    async def _list_all(
        self, operation: str, result_key: str, **params: Any
    ) -> List[Any]:
        """Collect ``response[result_key]`` across all pages of ``operation``.

        Pages are requested until a response carries no ``NextToken``.
        """
        items: List[Any] = []
        next_token = None
        while True:
            if next_token:
                params["NextToken"] = next_token
            response = await self._call(operation, **params)
            items.extend(response[result_key])
            next_token = response.get("NextToken")
            if not next_token:
                return items

    async def execute_query(
        self,
        query: str,
        parameters: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """Run ``query`` on Athena and return its rows as dictionaries.

        Args:
            query: SQL text to execute.
            parameters: Optional execution parameters; forwarded as
                ``ExecutionParameters`` only when non-empty.

        Returns:
            One dict per result row, keyed by column name (see
            ``_get_query_results``), capped at ``settings.max_query_results``.

        Raises:
            RuntimeError: when credentials are missing, when boto3 raises a
                ``ClientError``, when the query ends in a state other than
                ``SUCCEEDED``, or when it exceeds
                ``settings.query_timeout_seconds``. For the two boto3 cases
                the original exception is attached as ``__cause__``.
        """
        try:
            execution_params = {
                "QueryString": query,
                "QueryExecutionContext": {
                    "Database": self.settings.security_lake_database
                },
                "ResultConfiguration": {
                    "OutputLocation": self.settings.athena_output_location
                },
                "WorkGroup": self.settings.athena_workgroup,
            }
            if parameters:
                execution_params["ExecutionParameters"] = parameters

            # Only the first 200 characters are logged to keep entries short.
            query_preview = (query[:200] + "...") if len(query) > 200 else query
            logger.info(
                "Starting Athena query execution",
                query=query_preview,
                database=self.settings.security_lake_database,
                workgroup=self.settings.athena_workgroup
            )

            response = await self._call(
                "start_query_execution", **execution_params
            )
            execution_id = response["QueryExecutionId"]
            logger.info("Query execution started", execution_id=execution_id)

            execution_status = await self._wait_for_completion(execution_id)
            if execution_status != "SUCCEEDED":
                error_msg = await self._get_execution_error(execution_id)
                raise RuntimeError(f"Query failed: {error_msg}")

            results = await self._get_query_results(execution_id)
            logger.info(
                "Query completed successfully",
                execution_id=execution_id,
                result_count=len(results)
            )
            return results
        except NoCredentialsError as e:
            logger.error("AWS credentials not found")
            raise RuntimeError(
                "AWS credentials not configured. Please set up AWS credentials "
                "or specify an AWS profile."
            ) from e
        except ClientError as e:
            error_info = e.response.get("Error", {})
            error_code = error_info.get("Code", "Unknown")
            error_message = error_info.get("Message", str(e))
            logger.error(
                "AWS client error",
                error_code=error_code,
                error_message=error_message
            )
            raise RuntimeError(
                f"AWS error ({error_code}): {error_message}"
            ) from e
        except Exception as e:
            # Also reached for the RuntimeErrors raised above (failed query,
            # timeout); they are logged and re-raised unchanged.
            logger.error("Unexpected error during query execution", error=str(e))
            raise

    async def _wait_for_completion(self, execution_id: str) -> str:
        """Poll the execution until it reaches a terminal state.

        Returns:
            The terminal state: ``SUCCEEDED``, ``FAILED`` or ``CANCELLED``.

        Raises:
            RuntimeError: when ``settings.query_timeout_seconds`` elapses
                before a terminal state is observed.
        """
        timeout = self.settings.query_timeout_seconds
        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while True:
            if time.monotonic() > deadline:
                raise RuntimeError(f"Query timeout after {timeout} seconds")
            response = await self._call(
                "get_query_execution", QueryExecutionId=execution_id
            )
            status = response["QueryExecution"]["Status"]["State"]
            if status in self._TERMINAL_STATES:
                return status
            await asyncio.sleep(self._POLL_INTERVAL_SECONDS)

    async def _get_execution_error(self, execution_id: str) -> str:
        """Return the execution's ``StateChangeReason``.

        Falls back to ``"Unknown error"`` when the status carries no reason.
        """
        response = await self._call(
            "get_query_execution", QueryExecutionId=execution_id
        )
        status_info = response["QueryExecution"]["Status"]
        return status_info.get("StateChangeReason", "Unknown error")

    async def _get_query_results(self, execution_id: str) -> List[Dict[str, Any]]:
        """Fetch all result pages and convert the rows to dictionaries.

        The first row of the first non-empty page is treated as the header
        and supplies the column names for every following row, on every page.
        Cells without a ``VarCharValue`` map to ``None``; rows that yield no
        values are skipped. Collection stops, with a warning, once
        ``settings.max_query_results`` rows have been gathered and further
        rows remain.
        """
        limit = self.settings.max_query_results
        results: List[Dict[str, Any]] = []
        # Parsed once from the header row and reused for all later pages.
        columns: Optional[List[str]] = None
        next_token = None

        while True:
            params = {"QueryExecutionId": execution_id}
            if next_token:
                params["NextToken"] = next_token
            response = await self._call("get_query_results", **params)

            rows = response["ResultSet"]["Rows"]
            if columns is None and rows:
                columns = [cell["VarCharValue"] for cell in rows[0]["Data"]]
                rows = rows[1:]  # the header row is not data

            for row in rows:
                if len(results) >= limit:
                    logger.warning(
                        "Query result limit reached",
                        limit=limit
                    )
                    return results
                # zip() ignores cells beyond the known columns.
                row_data = {
                    name: cell.get("VarCharValue")
                    for name, cell in zip(columns, row["Data"])
                }
                if row_data:  # Only add non-empty rows
                    results.append(row_data)

            next_token = response.get("NextToken")
            if not next_token:
                break

        return results

    async def list_databases(self) -> List[str]:
        """Return the names of all databases in the ``AwsDataCatalog`` catalog.

        All result pages are followed. Any exception is logged and re-raised.
        """
        try:
            databases = await self._list_all(
                "list_databases",
                "DatabaseList",
                CatalogName=self._CATALOG_NAME,
            )
            return [db["Name"] for db in databases]
        except Exception as e:
            logger.error("Failed to list databases", error=str(e))
            raise

    async def list_tables(self, database: str) -> List[Dict[str, Any]]:
        """Return a summary dict for every table in ``database``.

        Each dict has the keys ``name``, ``type`` (``"UNKNOWN"`` when absent),
        ``columns`` (number of columns), ``last_accessed`` and ``location``.
        All result pages are followed. Any exception is logged and re-raised.
        """
        try:
            tables = await self._list_all(
                "list_table_metadata",
                "TableMetadataList",
                CatalogName=self._CATALOG_NAME,
                DatabaseName=database,
            )
            return [
                {
                    "name": table["Name"],
                    "type": table.get("TableType", "UNKNOWN"),
                    "columns": len(table.get("Columns", [])),
                    "last_accessed": table.get("LastAccessTime"),
                    "location": table.get("Parameters", {}).get("location")
                }
                for table in tables
            ]
        except Exception as e:
            logger.error("Failed to list tables", database=database, error=str(e))
            raise

    async def test_connection(self) -> bool:
        """Return ``True`` if listing databases succeeds, ``False`` otherwise.

        Failures are logged and swallowed; this method does not raise for
        ordinary exceptions.
        """
        try:
            await self.list_databases()
            logger.info("Athena connection test successful")
            return True
        except Exception as e:
            logger.error("Athena connection test failed", error=str(e))
            return False