import { APIGatewayProxyEvent, APIGatewayProxyResult } from 'aws-lambda';
import { AthenaService } from "./athena.js";
import { QueryInput, AthenaError } from "./types.js";
// Module-scoped Athena service instance. Lambda may reuse a warm execution
// environment across invocations, so caching it here avoids rebuilding it on
// every request.
let athenaService: AthenaService | null = null;

/**
 * Returns the shared {@link AthenaService}, constructing it on first use.
 *
 * If the constructor throws, nothing is cached and the next call tries again.
 */
function getAthenaService(): AthenaService {
  if (athenaService) {
    return athenaService;
  }
  const created = new AthenaService();
  athenaService = created;
  return created;
}
/**
 * Shape of a JSON-RPC 2.0 request expected in the POST body.
 *
 * NOTE(review): this type is applied to parsed, client-supplied JSON, so
 * runtime values are not guaranteed to match it.
 */
interface JsonRpcRequest {
  /** Method to invoke, e.g. `initialize`, `tools/list` or `tools/call`. */
  method: string;
  /** Method-specific arguments; their shape depends on `method`. */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- shape varies per method
  params?: any;
  /** Client-chosen identifier, echoed back in the response. */
  id: number | string | null;
  /** Protocol version marker sent by the client. */
  jsonrpc: string;
}
/**
 * A JSON-RPC 2.0 response object.
 *
 * Per JSON-RPC 2.0, a response carries either `result` (success) or `error`
 * (failure), never both.
 */
interface JsonRpcResponse {
  /** Protocol version marker; this server always sends `'2.0'`. */
  jsonrpc: string;
  /** Identifier of the request being answered; `null` when it is unknown. */
  id: number | string | null;
  /** Method-specific success payload. It is only serialized, so `unknown` suffices. */
  result?: unknown;
  /** Failure details. */
  error?: {
    /** JSON-RPC error code, e.g. -32601 for "Method not found". */
    code: number;
    /** Short human-readable description. */
    message: string;
    /** Optional structured details, as the JSON-RPC 2.0 error object allows. */
    data?: unknown;
  };
}
/** MCP protocol revision advertised in the `initialize` response. */
const MCP_PROTOCOL_VERSION = '2024-11-05';

/** Bounds on `maxRows`, shared by the advertised tool schemas and argument validation. */
const MIN_ROWS = 1;
const MAX_ROWS_LIMIT = 10000;

/** Lower bound on `timeoutMs`, shared by the advertised tool schemas and argument validation. */
const MIN_TIMEOUT_MS = 1000;

/** Standard JSON-RPC 2.0 error codes used by this server. */
const RPC_ERROR = {
  invalidRequest: -32600,
  methodNotFound: -32601,
  invalidParams: -32602,
  internalError: -32603,
} as const;

/** Schema fragment for the optional `maxRows` tool argument. */
const MAX_ROWS_PROPERTY = {
  type: 'number',
  description: 'Maximum number of rows to return (default: 1000)',
  minimum: MIN_ROWS,
  maximum: MAX_ROWS_LIMIT,
};

/** Schema fragment for the optional `timeoutMs` tool argument. */
const TIMEOUT_MS_PROPERTY = {
  type: 'number',
  description: 'Timeout in milliseconds (default: 60000)',
  minimum: MIN_TIMEOUT_MS,
};

/** Schema fragment for the required `queryExecutionId` tool argument. */
const QUERY_EXECUTION_ID_PROPERTY = { type: 'string', description: 'The query execution ID' };

/** Tool definitions returned by `tools/list`. They never change, so they are built once. */
const TOOL_DEFINITIONS = [
  {
    name: 'run_query',
    description: 'Execute a SQL query using AWS Athena. Returns full results if query completes before timeout, otherwise returns queryExecutionId.',
    inputSchema: {
      type: 'object',
      properties: {
        database: { type: 'string', description: 'The Athena database to query' },
        query: { type: 'string', description: 'SQL query to execute' },
        maxRows: MAX_ROWS_PROPERTY,
        timeoutMs: TIMEOUT_MS_PROPERTY,
      },
      required: ['database', 'query'],
    },
  },
  {
    name: 'get_result',
    description: 'Get results for a completed query. Returns error if query is still running.',
    inputSchema: {
      type: 'object',
      properties: {
        queryExecutionId: QUERY_EXECUTION_ID_PROPERTY,
        maxRows: MAX_ROWS_PROPERTY,
      },
      required: ['queryExecutionId'],
    },
  },
  {
    name: 'get_status',
    description: 'Get the current status of a query execution',
    inputSchema: {
      type: 'object',
      properties: {
        queryExecutionId: QUERY_EXECUTION_ID_PROPERTY,
      },
      required: ['queryExecutionId'],
    },
  },
  {
    name: 'run_saved_query',
    description: 'Execute a saved (named) Athena query by its query ID.',
    inputSchema: {
      type: 'object',
      properties: {
        namedQueryId: { type: 'string', description: 'Athena NamedQueryId' },
        databaseOverride: { type: 'string', description: 'Optional database override' },
        maxRows: MAX_ROWS_PROPERTY,
        timeoutMs: TIMEOUT_MS_PROPERTY,
      },
      required: ['namedQueryId'],
    },
  },
  {
    name: 'list_saved_queries',
    description: 'List all saved (named) Athena queries available in your AWS account.',
    inputSchema: {
      type: 'object',
      properties: {},
    },
  },
];

/** Tool-call arguments as received from the client (untrusted JSON). */
type ToolArgs = Record<string, unknown>;

/** Validates its arguments, runs one tool and resolves to a JSON-serializable payload. */
type ToolHandler = (args: ToolArgs) => Promise<unknown>;

/** Raised by argument validation; `callTool` maps it to a JSON-RPC "Invalid params" error. */
class InvalidParamsError extends Error {
  constructor(message: string) {
    super(message);
    this.name = 'InvalidParamsError';
    // Keep `instanceof` reliable even if the build targets pre-ES2015 output.
    Object.setPrototypeOf(this, new.target.prototype);
  }
}

/** Returns `value` if it is a non-empty string, otherwise `undefined`. */
function nonEmptyString(value: unknown): string | undefined {
  return typeof value === 'string' && value.length > 0 ? value : undefined;
}

/** Reads an optional string argument; absent, `null` or `''` yields `undefined`, any other non-string is rejected. */
function optionalString(args: ToolArgs, key: string): string | undefined {
  const value = args[key];
  if (value === undefined || value === null || value === '') {
    return undefined;
  }
  if (typeof value !== 'string') {
    throw new InvalidParamsError(`${key} must be a string`);
  }
  return value;
}

/** Reads an optional numeric argument and enforces the bounds advertised in the tool schema. */
function optionalNumber(
  args: ToolArgs,
  key: string,
  min: number,
  max: number = Number.POSITIVE_INFINITY
): number | undefined {
  const value = args[key];
  if (value === undefined || value === null) {
    return undefined;
  }
  if (typeof value !== 'number' || !Number.isFinite(value) || value < min || value > max) {
    const range = Number.isFinite(max) ? `between ${min} and ${max}` : `of at least ${min}`;
    throw new InvalidParamsError(`${key} must be a number ${range}`);
  }
  return value;
}

/**
 * Tool name → implementation. A `Map` is used so that keys such as
 * `constructor` cannot resolve to inherited `Object.prototype` members.
 * The Athena service is acquired only after a tool's arguments validate.
 */
const TOOL_HANDLERS = new Map<string, ToolHandler>([
  [
    'run_query',
    async (args) => {
      const database = nonEmptyString(args.database);
      const query = nonEmptyString(args.query);
      if (database === undefined || query === undefined) {
        throw new InvalidParamsError('database and query are required');
      }
      const queryInput: QueryInput = {
        database,
        query,
        maxRows: optionalNumber(args, 'maxRows', MIN_ROWS, MAX_ROWS_LIMIT),
        timeoutMs: optionalNumber(args, 'timeoutMs', MIN_TIMEOUT_MS),
      };
      return getAthenaService().executeQuery(queryInput);
    },
  ],
  [
    'get_result',
    async (args) => {
      const queryExecutionId = nonEmptyString(args.queryExecutionId);
      if (queryExecutionId === undefined) {
        throw new InvalidParamsError('queryExecutionId is required');
      }
      const maxRows = optionalNumber(args, 'maxRows', MIN_ROWS, MAX_ROWS_LIMIT);
      return getAthenaService().getQueryResults(queryExecutionId, maxRows);
    },
  ],
  [
    'get_status',
    async (args) => {
      const queryExecutionId = nonEmptyString(args.queryExecutionId);
      if (queryExecutionId === undefined) {
        throw new InvalidParamsError('queryExecutionId is required');
      }
      return getAthenaService().getQueryStatus(queryExecutionId);
    },
  ],
  [
    'run_saved_query',
    async (args) => {
      const namedQueryId = nonEmptyString(args.namedQueryId);
      if (namedQueryId === undefined) {
        throw new InvalidParamsError('namedQueryId is required');
      }
      return getAthenaService().executeNamedQuery(
        namedQueryId,
        optionalString(args, 'databaseOverride'),
        optionalNumber(args, 'maxRows', MIN_ROWS, MAX_ROWS_LIMIT),
        optionalNumber(args, 'timeoutMs', MIN_TIMEOUT_MS)
      );
    },
  ],
  ['list_saved_queries', async () => getAthenaService().listNamedQueries()],
]);

/** Builds a successful JSON-RPC response. */
function successResponse(id: JsonRpcResponse['id'], result: unknown): JsonRpcResponse {
  return { jsonrpc: '2.0', id, result };
}

/** Builds a JSON-RPC error response. */
function errorResponse(id: JsonRpcResponse['id'], code: number, message: string): JsonRpcResponse {
  return { jsonrpc: '2.0', id, error: { code, message } };
}

/** Extracts a readable message from any thrown value, including `null` and non-Error objects. */
function describeError(error: unknown): string {
  if (typeof error === 'object' && error !== null) {
    const { message } = error as { message?: unknown };
    if (message) {
      return String(message);
    }
  }
  return String(error);
}

/**
 * Handles `tools/call`: looks up the tool, runs it, and wraps its payload as
 * a single MCP text content item containing pretty-printed JSON.
 *
 * NOTE(review): the MCP spec recommends reporting tool *execution* failures as
 * a result with `isError: true` so the model can see them. They are kept as
 * JSON-RPC -32603 errors here to preserve existing client-visible behavior.
 */
async function callTool(id: JsonRpcResponse['id'], params: unknown): Promise<JsonRpcResponse> {
  const { name, arguments: rawArgs } = (
    typeof params === 'object' && params !== null ? params : {}
  ) as { name?: unknown; arguments?: unknown };

  const runTool = typeof name === 'string' ? TOOL_HANDLERS.get(name) : undefined;
  if (runTool === undefined) {
    // The MCP spec reports unknown tools as -32602 (Invalid params).
    return errorResponse(id, RPC_ERROR.invalidParams, `Unknown tool: ${String(name)}`);
  }

  const args: ToolArgs =
    typeof rawArgs === 'object' && rawArgs !== null && !Array.isArray(rawArgs)
      ? (rawArgs as ToolArgs)
      : {};

  try {
    const payload = await runTool(args);
    return successResponse(id, {
      content: [{ type: 'text', text: JSON.stringify(payload, null, 2) }],
    });
  } catch (error: unknown) {
    if (error instanceof InvalidParamsError) {
      return errorResponse(id, RPC_ERROR.invalidParams, `Invalid params: ${error.message}`);
    }
    return errorResponse(id, RPC_ERROR.internalError, `Internal error: ${describeError(error)}`);
  }
}

/**
 * Dispatches one MCP JSON-RPC request and always resolves to a response object.
 *
 * Supported methods: `initialize`, `ping`, `tools/list` and `tools/call`.
 * Any other method yields -32601. A malformed envelope (not an object,
 * `jsonrpc` not `'2.0'`, or non-string `method`) yields -32600. Unexpected
 * failures, including Athena service construction errors, yield -32603
 * instead of rejecting.
 *
 * @param request Decoded request body. It is re-validated at runtime because it comes from untrusted JSON.
 * @returns The response to serialize back to the client.
 */
async function handleMcpRequest(request: JsonRpcRequest): Promise<JsonRpcResponse> {
  const raw: unknown = request;
  if (typeof raw !== 'object' || raw === null || Array.isArray(raw)) {
    return errorResponse(
      null,
      RPC_ERROR.invalidRequest,
      'Invalid Request: expected a JSON-RPC request object'
    );
  }

  const fields = raw as Record<string, unknown>;
  const { jsonrpc, method, params } = fields;
  // Echo only ids of a type JSON-RPC allows; anything else is answered with null.
  const id = typeof fields.id === 'string' || typeof fields.id === 'number' ? fields.id : null;

  if (jsonrpc !== '2.0' || typeof method !== 'string') {
    return errorResponse(
      id,
      RPC_ERROR.invalidRequest,
      'Invalid Request: jsonrpc must be "2.0" and method must be a string'
    );
  }

  try {
    switch (method) {
      case 'initialize':
        return successResponse(id, {
          protocolVersion: MCP_PROTOCOL_VERSION,
          capabilities: {
            tools: {},
          },
          serverInfo: {
            name: 'aws-athena-mcp',
            version: '1.0.0',
          },
        });
      case 'ping':
        // MCP liveness check: the response must be an empty result.
        return successResponse(id, {});
      case 'tools/list':
        return successResponse(id, { tools: TOOL_DEFINITIONS });
      case 'tools/call':
        return await callTool(id, params);
      default:
        return errorResponse(id, RPC_ERROR.methodNotFound, `Method not found: ${method}`);
    }
  } catch (error: unknown) {
    return errorResponse(id, RPC_ERROR.internalError, `Internal error: ${describeError(error)}`);
  }
}
/**
 * AWS Lambda entry point for an API Gateway REST (proxy integration) event.
 *
 * Routing:
 * - `OPTIONS` (any path): CORS preflight, empty 200.
 * - `GET /` or `GET .../health`: health-check JSON.
 * - `POST` (any path): a single MCP JSON-RPC 2.0 message.
 * - anything else: 404 with a usage hint.
 *
 * POST handling:
 * - A body that is not valid JSON gets a 400 with code -32700.
 * - JSON that is not a single request object (e.g. a batch array or a scalar) gets a 400 with code -32600.
 * - A notification (no `id` member) is acknowledged with an empty 202, because JSON-RPC forbids replying to it.
 * - A failure while handling a well-formed request is logged and reported as a 500 with code -32603,
 *   instead of being mislabelled as a parse error.
 */
export const handler = async (
  event: APIGatewayProxyEvent
): Promise<APIGatewayProxyResult> => {
  const httpMethod = event.httpMethod;
  const path = event.path || '/';
  const corsHeaders = {
    'Access-Control-Allow-Origin': '*',
    'Access-Control-Allow-Methods': 'POST, OPTIONS, GET',
    'Access-Control-Allow-Headers': 'Content-Type, Authorization',
    'Content-Type': 'application/json',
  };

  // Builds an HTTP response carrying a JSON-RPC error that cannot be tied to a request id.
  const rpcErrorResponse = (
    statusCode: number,
    code: number,
    message: string
  ): APIGatewayProxyResult => ({
    statusCode,
    headers: corsHeaders,
    body: JSON.stringify({
      jsonrpc: '2.0',
      id: null,
      error: { code, message },
    }),
  });

  try {
    // OPTIONS - CORS preflight
    if (httpMethod === 'OPTIONS') {
      return { statusCode: 200, headers: corsHeaders, body: '' };
    }

    // GET - health check. `endsWith('/health')` also matches the bare `/health` path.
    if (httpMethod === 'GET' && (path === '/' || path.endsWith('/health'))) {
      return {
        statusCode: 200,
        headers: corsHeaders,
        body: JSON.stringify({
          status: 'healthy',
          service: 'aws-athena-mcp',
          protocol: 'MCP JSON-RPC 2.0',
          athena_configured: !!process.env.OUTPUT_S3_PATH,
        }),
      };
    }

    // POST - MCP JSON-RPC requests
    if (httpMethod === 'POST') {
      let message: unknown;
      try {
        let bodyStr = event.body || '{}';
        if (event.isBase64Encoded) {
          bodyStr = Buffer.from(bodyStr, 'base64').toString('utf-8');
        }
        message = JSON.parse(bodyStr);
      } catch {
        // Only decoding failures land here; request-handling failures reach the outer catch.
        return rpcErrorResponse(400, -32700, 'Parse error: Invalid JSON');
      }

      // Batches (arrays) and bare scalars are not single request objects.
      if (typeof message !== 'object' || message === null || Array.isArray(message)) {
        return rpcErrorResponse(
          400,
          -32600,
          'Invalid Request: expected a single JSON-RPC request object'
        );
      }

      // JSON-RPC 2.0: a request without an `id` member is a notification
      // (e.g. `notifications/initialized`), and the server must not reply to it.
      if (!('id' in message) && typeof (message as { method?: unknown }).method === 'string') {
        return { statusCode: 202, headers: corsHeaders, body: '' };
      }

      const responseData = await handleMcpRequest(message as JsonRpcRequest);
      return {
        statusCode: 200,
        headers: corsHeaders,
        body: JSON.stringify(responseData),
      };
    }

    // Other methods
    return {
      statusCode: 404,
      headers: corsHeaders,
      body: JSON.stringify({
        error: 'Not found',
        info: 'This is an MCP JSON-RPC 2.0 server. Send POST requests with JSON-RPC format.',
      }),
    };
  } catch (error: unknown) {
    console.error('Lambda handler error:', error);
    const detail = error instanceof Error && error.message ? error.message : String(error);
    return rpcErrorResponse(500, -32603, `Internal error: ${detail}`);
  }
};