#!/usr/bin/env bash
# Test script: run an Athena query through the MCP API using OAuth
# (Cognito client-credentials) authentication.
#
# Usage: ./test-oauth-query.sh "SELECT * FROM my_table LIMIT 5" "my_database"
#   $1 - SQL to execute         (default: "SELECT 1 as test")
#   $2 - Athena database name   (default: "quickdemo")
#
# Requires: aws CLI, curl, jq (numfmt optional, only for pretty byte sizes).

# -u catches typos in variable names; pipefail surfaces failures in any
# stage of a pipeline, not just the last one.
set -euo pipefail
QUERY="${1:-SELECT 1 as test}"
DATABASE="${2:-quickdemo}"
echo "=== AWS Athena MCP - OAuth 查询测试 ==="
echo ""
# Fetch configuration from the CloudFormation stack outputs.
# The stack name can be overridden with ATHENA_MCP_STACK_NAME.
STACK_NAME="${ATHENA_MCP_STACK_NAME:-aws-athena-mcp-stack}"
echo "1. 获取配置..."
# One describe-stacks call returns every output; individual values are then
# selected locally with jq instead of making one API round trip per value.
STACK_OUTPUTS=$(aws cloudformation describe-stacks --stack-name "$STACK_NAME" \
  --query "Stacks[0].Outputs" --output json)
# Print the OutputValue for OutputKey $1, or nothing if the key is absent
# (also covers a stack with no Outputs at all, where the JSON is `null`).
stack_output() {
  jq -r --arg key "$1" \
    '[.[]? | select(.OutputKey == $key) | .OutputValue][0] // empty' \
    <<<"$STACK_OUTPUTS"
}
CLIENT_ID=$(stack_output CognitoAppClientId)
TOKEN_URL=$(stack_output CognitoTokenUrl)
API_ENDPOINT=$(stack_output ApiEndpoint)
USER_POOL_ID=$(stack_output CognitoUserPoolId)
# Fail early with a clear message instead of passing empty values on to
# the cognito-idp call and curl further down.
for output_var in CLIENT_ID TOKEN_URL API_ENDPOINT USER_POOL_ID; do
  if [[ -z "${!output_var}" ]]; then
    echo "错误: 堆栈 $STACK_NAME 缺少输出 $output_var" >&2
    exit 1
  fi
done
# Fetch the app client's secret, needed for the client-credentials grant.
CLIENT_SECRET=$(aws cognito-idp describe-user-pool-client \
  --user-pool-id "$USER_POOL_ID" --client-id "$CLIENT_ID" \
  --query "UserPoolClient.ClientSecret" --output text)
# The AWS CLI's text output renders a null value as "None".
if [[ -z "$CLIENT_SECRET" || "$CLIENT_SECRET" == "None" ]]; then
  echo "错误: 应用客户端 $CLIENT_ID 没有 Client Secret" >&2
  exit 1
fi
echo "✓ 配置获取成功"
echo ""
# Obtain an OAuth access token using the client-credentials grant.
echo "2. 获取 Access Token..."
# GNU base64 wraps its output at 76 columns, and a typical "id:secret" pair
# encodes to more than that, which would put a newline inside the
# Authorization header. Stripping newlines with tr works on both GNU and BSD.
AUTH_HEADER=$(printf '%s' "$CLIENT_ID:$CLIENT_SECRET" | base64 | tr -d '\n')
TOKEN_RESPONSE=$(curl -s -X POST "$TOKEN_URL" \
  -H "Content-Type: application/x-www-form-urlencoded" \
  -H "Authorization: Basic $AUTH_HEADER" \
  -d "grant_type=client_credentials&scope=athena-mcp-api/read+athena-mcp-api/write")
# `// empty` maps a missing or null token to "". `2>/dev/null || true` stops
# a non-JSON body (e.g. an HTML error page) from aborting the script here,
# so the check below can report what the server actually sent.
ACCESS_TOKEN=$(jq -r '.access_token // empty' <<<"$TOKEN_RESPONSE" 2>/dev/null || true)
if [[ -z "$ACCESS_TOKEN" ]]; then
  echo "错误: 无法获取 Access Token" >&2
  printf '%s\n' "$TOKEN_RESPONSE" >&2
  exit 1
fi
echo "✓ Token 获取成功"
echo ""
# Execute the query via the MCP `run_query` tool (JSON-RPC 2.0 over HTTP).
echo "3. 执行 Athena 查询..."
echo "数据库: $DATABASE"
echo "查询: $QUERY"
echo ""
# Build the request body with jq so quotes, backslashes and newlines in the
# SQL or database name are JSON-escaped instead of corrupting the payload.
REQUEST_BODY=$(jq -n \
  --arg database "$DATABASE" \
  --arg query "$QUERY" \
  --argjson id "$(date +%s)" \
  '{
    jsonrpc: "2.0",
    id: $id,
    method: "tools/call",
    params: {
      name: "run_query",
      arguments: {database: $database, query: $query, maxRows: 10}
    }
  }')
QUERY_RESPONSE=$(curl -s -X POST "$API_ENDPOINT" \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer $ACCESS_TOKEN" \
  -d "$REQUEST_BODY")
# Validate the JSON-RPC response and extract the tool's result payload.
# All values are passed to jq quoted (via here-strings) so whitespace inside
# JSON string values is preserved exactly.
if [[ -z "$QUERY_RESPONSE" ]] || ! jq empty <<<"$QUERY_RESPONSE" 2>/dev/null; then
  echo "错误: API 返回了无效响应:" >&2
  printf '%s\n' "$QUERY_RESPONSE" >&2
  exit 1
fi
# A JSON-RPC level error object is printed as-is.
ERROR=$(jq -r '.error // empty' <<<"$QUERY_RESPONSE")
if [[ -n "$ERROR" ]]; then
  echo "查询错误:" >&2
  jq '.error' <<<"$QUERY_RESPONSE" >&2
  exit 1
fi
# The tool's payload is carried as text in the first content item.
RESULT_TEXT=$(jq -r '.result.content[0].text // empty' <<<"$QUERY_RESPONSE")
# A tool-level failure is flagged with result.isError; its text is a message.
if [[ "$(jq -r '.result.isError // false' <<<"$QUERY_RESPONSE")" == "true" ]]; then
  echo "查询错误:" >&2
  printf '%s\n' "$RESULT_TEXT" >&2
  exit 1
fi
# Pretty-print the payload; empty or non-JSON text cannot be reported as a result.
if [[ -z "$RESULT_TEXT" ]] || ! RESULT=$(jq '.' <<<"$RESULT_TEXT" 2>/dev/null); then
  echo "错误: 无法解析查询结果:" >&2
  printf '%s\n' "$RESULT_TEXT" >&2
  exit 1
fi
# A queryExecutionId in the payload means the query had not finished (it
# timed out on the server side). In that case, print a ready-made get_result
# call instead of rows.
QUERY_ID=$(jq -r '.queryExecutionId // empty' <<<"$RESULT")
if [[ -n "$QUERY_ID" ]]; then
  echo "查询已提交,ID: $QUERY_ID"
  echo "查询仍在运行,使用以下命令获取结果:"
  echo ""
  echo " curl -X POST \"$API_ENDPOINT\" \\"
  echo " -H \"Content-Type: application/json\" \\"
  echo " -H \"Authorization: Bearer \$ACCESS_TOKEN\" \\"
  echo " -d '{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"tools/call\",\"params\":{\"name\":\"get_result\",\"arguments\":{\"queryExecutionId\":\"$QUERY_ID\"}}}'"
else
  echo "查询结果:"
  # Quoted here-string: result data is shown without whitespace collapsing.
  jq '.' <<<"$RESULT"
  # Summary statistics in one jq pass; missing counters default to 0.
  IFS=$'\t' read -r BYTES_SCANNED EXECUTION_TIME ROW_COUNT < <(
    jq -r '[.bytesScanned // 0, .executionTime // 0, (.rows | length)] | @tsv' <<<"$RESULT"
  )
  echo ""
  echo "统计信息:"
  # numfmt (GNU coreutils) may be missing, e.g. on macOS; fall back to raw bytes.
  echo " - 扫描数据: $(numfmt --to=iec-i --suffix=B "$BYTES_SCANNED" 2>/dev/null || echo "$BYTES_SCANNED bytes")"
  echo " - 执行时间: ${EXECUTION_TIME}ms"
  echo " - 返回行数: $ROW_COUNT"
fi
echo ""
echo "=== 测试完成 ==="