aws-athena-mcp

by lishenxydlgzs
Verified
import {
  AthenaClient,
  StartQueryExecutionCommand,
  GetQueryExecutionCommand,
  GetQueryResultsCommand,
  QueryExecutionState,
  GetQueryExecutionCommandOutput,
  InvalidRequestException,
} from "@aws-sdk/client-athena";
import { defaultProvider } from "@aws-sdk/credential-provider-node";
import { QueryInput, QueryResult, QueryStatus, AthenaError } from "./types.js";

/**
 * Thin wrapper around the Athena client that starts queries, polls them for
 * completion and converts result pages into plain row objects.
 *
 * Failures are signalled by throwing plain objects of the shape
 * `{ message, code, queryExecutionId? }`; the codes used in this class are
 * `INVALID_REQUEST`, `QUERY_NOT_FOUND`, `QUERY_STILL_RUNNING`, `QUERY_FAILED`,
 * `UNEXPECTED_STATE` and `TIMEOUT`.
 */
export class AthenaService {
  /** How long executeQuery waits for completion when the caller gives no timeout. */
  private static readonly defaultTimeoutMs = 60000;
  /** Delay between two consecutive status polls. */
  private static readonly pollIntervalMs = 1000;
  /** Upper bound of the `MaxResults` parameter accepted by GetQueryResults. */
  private static readonly maxResultsPerRequest = 1000;

  private client: AthenaClient;
  private outputLocation: string;

  /**
   * Builds the Athena client from the process environment.
   *
   * Reads `OUTPUT_S3_PATH` (required, used as the result output location),
   * `AWS_PROFILE` (handed to the default credential provider) and
   * `AWS_REGION` (handed to the client as its region).
   *
   * @throws Error when `OUTPUT_S3_PATH` is not set.
   */
  constructor() {
    if (!process.env.OUTPUT_S3_PATH) {
      throw new Error("OUTPUT_S3_PATH environment variable is required");
    }
    this.outputLocation = process.env.OUTPUT_S3_PATH;

    const profile = process.env.AWS_PROFILE;
    this.client = new AthenaClient({
      credentials: defaultProvider({ profile: profile }),
      region: process.env.AWS_REGION,
    });
  }

  /**
   * Starts a query and waits for it to finish.
   *
   * @param input Query text, database, and optional `timeoutMs` / `maxRows`.
   * @returns The query results when the query succeeds within the timeout
   *   (default 60 seconds), otherwise only `{ queryExecutionId }` so the
   *   caller can poll for the outcome later.
   * @throws `{ code: "INVALID_REQUEST" }` when Athena rejects a request with
   *   an `InvalidRequestException`; every other error (for example
   *   `QUERY_FAILED`) is rethrown unchanged.
   */
  async executeQuery(input: QueryInput): Promise<QueryResult | { queryExecutionId: string }> {
    try {
      const startResponse = await this.client.send(
        new StartQueryExecutionCommand({
          QueryString: input.query,
          QueryExecutionContext: {
            Database: input.database,
          },
          ResultConfiguration: {
            OutputLocation: this.outputLocation,
          },
        })
      );

      const queryExecutionId = startResponse.QueryExecutionId;
      if (!queryExecutionId) {
        throw new Error("Failed to start query execution");
      }

      const timeoutMs = input.timeoutMs || AthenaService.defaultTimeoutMs;
      // Size the poll budget from the timeout so that the deadline, not a
      // fixed attempt count, decides when we give up. A fixed budget of 100
      // one-second polls would silently cap every timeout at about 100 s.
      const maxAttempts = Math.max(
        1,
        Math.ceil(timeoutMs / AthenaService.pollIntervalMs) + 1
      );

      try {
        await this.waitForQueryCompletion(queryExecutionId, maxAttempts, timeoutMs);
        // Reaching this point means the query succeeded before the deadline.
        return await this.getQueryResults(queryExecutionId, input.maxRows);
      } catch (error) {
        if (error && typeof error === "object" && "code" in error) {
          const athenaError = error as AthenaError;
          if (athenaError.code === "TIMEOUT") {
            // A timeout is not a failure: hand back the ID for later polling.
            return { queryExecutionId };
          }
        }
        throw error;
      }
    } catch (error) {
      if (error instanceof InvalidRequestException) {
        throw {
          message: error.message,
          code: "INVALID_REQUEST",
        };
      }
      throw error;
    }
  }

  /**
   * Fetches the current state and basic statistics of a query execution.
   *
   * @param queryExecutionId ID returned when the query was started.
   * @returns The state (`"UNKNOWN"` when Athena reports none), the state
   *   change reason, and scanned-bytes / engine-time statistics (0 when
   *   absent).
   * @throws `{ code: "QUERY_NOT_FOUND" }` when the response carries no
   *   execution or Athena answers with an `InvalidRequestException`.
   */
  async getQueryStatus(queryExecutionId: string): Promise<QueryStatus> {
    try {
      const response = await this.client.send(
        new GetQueryExecutionCommand({
          QueryExecutionId: queryExecutionId,
        })
      );

      const execution = response.QueryExecution;
      if (!execution) {
        throw {
          message: "Query execution not found",
          code: "QUERY_NOT_FOUND",
        };
      }

      return {
        state: execution.Status?.State || "UNKNOWN",
        stateChangeReason: execution.Status?.StateChangeReason,
        statistics: {
          dataScannedInBytes: execution.Statistics?.DataScannedInBytes ?? 0,
          engineExecutionTimeInMillis: execution.Statistics?.EngineExecutionTimeInMillis ?? 0,
        },
      };
    } catch (error) {
      if (error instanceof InvalidRequestException) {
        throw {
          message: "Query execution not found",
          code: "QUERY_NOT_FOUND",
        };
      }
      throw error;
    }
  }

  /**
   * Returns the first page of results of a successfully finished query.
   *
   * @param queryExecutionId ID returned when the query was started.
   * @param maxRows Maximum number of data rows to return. Defaults to 1000.
   *   A single GetQueryResults request is limited to 1000 rows including the
   *   header row, so at most 999 data rows are returned per call.
   * @returns Column names, rows keyed by column name, and the scan/time
   *   statistics of the execution.
   * @throws `{ code: "QUERY_STILL_RUNNING" }` while the query is queued or
   *   running, `{ code: "QUERY_FAILED" }` when it failed,
   *   `{ code: "UNEXPECTED_STATE" }` for any other non-succeeded state,
   *   `{ code: "QUERY_NOT_FOUND" }` on an `InvalidRequestException`, and an
   *   `Error` when the response carries no result set.
   */
  async getQueryResults(queryExecutionId: string, maxRows?: number): Promise<QueryResult> {
    try {
      // Results are only available once the query has succeeded.
      const status = await this.getQueryStatus(queryExecutionId);

      if (
        status.state === QueryExecutionState.RUNNING ||
        status.state === QueryExecutionState.QUEUED
      ) {
        throw {
          message: "Query is still running",
          code: "QUERY_STILL_RUNNING",
          queryExecutionId,
        };
      }

      if (status.state === QueryExecutionState.FAILED) {
        throw {
          message: status.stateChangeReason || "Query failed",
          code: "QUERY_FAILED",
          queryExecutionId,
        };
      }

      if (status.state !== QueryExecutionState.SUCCEEDED) {
        throw {
          message: `Unexpected query state: ${status.state}`,
          code: "UNEXPECTED_STATE",
          queryExecutionId,
        };
      }

      // The first returned row is dropped below as the header, so ask for one
      // extra row to actually deliver `maxRows` data rows. The request size is
      // clamped to the range the API accepts (1..1000).
      const requestedRows = maxRows || AthenaService.maxResultsPerRequest;
      const pageSize = Math.min(
        Math.max(Math.trunc(requestedRows) + 1, 1),
        AthenaService.maxResultsPerRequest
      );

      const results = await this.client.send(
        new GetQueryResultsCommand({
          QueryExecutionId: queryExecutionId,
          MaxResults: pageSize,
        })
      );

      if (!results.ResultSet) {
        throw new Error("No results returned from query");
      }

      const columns =
        results.ResultSet.ResultSetMetadata?.ColumnInfo?.map((col) => col.Name || "") || [];

      // NOTE(review): the first row is always treated as the header. Statements
      // whose result has no header row would lose their first data row here —
      // confirm whether such statements are expected by callers.
      const rows = (results.ResultSet.Rows || []).slice(1).map((row) => {
        const record: Record<string, unknown> = {};
        row.Data?.forEach((cell, index) => {
          const column = columns[index];
          // Cells without a (non-empty) column name are skipped.
          if (column) {
            record[column] = cell.VarCharValue;
          }
        });
        return record;
      });

      return {
        columns,
        rows,
        queryExecutionId,
        bytesScanned: status.statistics?.dataScannedInBytes || 0,
        executionTime: status.statistics?.engineExecutionTimeInMillis || 0,
      };
    } catch (error) {
      if (error instanceof InvalidRequestException) {
        throw {
          message: "Query execution not found",
          code: "QUERY_NOT_FOUND",
        };
      }
      throw error;
    }
  }

  /**
   * Polls a query execution until it succeeds, fails, or the budget runs out.
   *
   * @param queryExecutionId ID of the execution to poll.
   * @param maxAttempts Maximum number of status polls.
   * @param timeoutMs Optional wall-clock limit, measured from the first poll.
   * @returns The GetQueryExecution response observed in the SUCCEEDED state.
   * @throws `{ code: "QUERY_FAILED" }` when the query failed or was cancelled,
   *   `{ code: "TIMEOUT" }` when `timeoutMs` elapses or `maxAttempts` polls
   *   are used up first.
   */
  private async waitForQueryCompletion(
    queryExecutionId: string,
    maxAttempts = 100,
    timeoutMs?: number
  ): Promise<GetQueryExecutionCommandOutput> {
    const startTime = Date.now();

    for (let attempts = 0; attempts < maxAttempts; attempts++) {
      if (timeoutMs && Date.now() - startTime >= timeoutMs) {
        throw {
          message: "Query timed out",
          code: "TIMEOUT",
          queryExecutionId,
        };
      }

      const response = await this.client.send(
        new GetQueryExecutionCommand({
          QueryExecutionId: queryExecutionId,
        })
      );

      const state = response.QueryExecution?.Status?.State;

      if (state === QueryExecutionState.SUCCEEDED) {
        return response;
      }

      if (state === QueryExecutionState.FAILED || state === QueryExecutionState.CANCELLED) {
        throw {
          message: response.QueryExecution?.Status?.StateChangeReason || "Query failed",
          code: "QUERY_FAILED",
          queryExecutionId,
        };
      }

      // Wait before polling again, but never sleep past the deadline.
      const remainingMs = timeoutMs
        ? timeoutMs - (Date.now() - startTime)
        : AthenaService.pollIntervalMs;
      const delayMs = Math.min(AthenaService.pollIntervalMs, Math.max(0, remainingMs));
      await new Promise<void>((resolve) => setTimeout(resolve, delayMs));
    }

    throw {
      message: "Query timed out",
      code: "TIMEOUT",
      queryExecutionId,
    };
  }
}