mcp_server.py•8.16 kB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MCP数据查询服务器
基于FastMCP框架,提供安全的数据库查询服务
"""
import os
from typing import Dict, Any
from fastmcp import FastMCP, Context
from fastmcp.exceptions import ToolError
from fastmcp.server.dependencies import get_access_token, AccessToken
from dotenv import load_dotenv
# 导入现有模块
from database import DatabaseManager
# 导入新的认证模块
from auth_token import create_auth_components
# 加载环境变量
load_dotenv()
# 全局数据库管理器实例
db_manager = None
# 创建认证组件
auth = create_auth_components()
mcp = FastMCP(name="data-analysis-mcp", auth=auth)
def initialize_services():
"""初始化服务"""
global db_manager
if db_manager is None:
db_manager = DatabaseManager()
if not db_manager.connect():
raise Exception("数据库连接失败")
def get_validated_access_token() -> AccessToken:
"""获取并验证访问令牌"""
try:
access_token = get_access_token()
if access_token is None:
raise ToolError("未提供访问令牌或令牌无效")
return access_token
except Exception as e:
raise ToolError(f"权限验证失败: {str(e)}")
def check_permissions(access_token: AccessToken, required_scopes: list) -> None:
"""检查权限"""
if not access_token.scopes:
raise ToolError("用户没有任何权限")
missing_scopes = [scope for scope in required_scopes if scope not in access_token.scopes]
if missing_scopes:
raise ToolError(f"权限不足:需要以下权限: {', '.join(missing_scopes)}")
# 移除convert_numpy函数,不再需要
@mcp.tool
async def get_database_tables(ctx: Context) -> Dict[str, Any]:
"""
获取数据库中所有表的列表
需要 'data:read' 权限
"""
access_token = get_validated_access_token()
check_permissions(access_token, ["data:read_tables"])
try:
initialize_services()
tables = db_manager.get_all_tables()
return {
"user_id": access_token.client_id,
"tables": tables,
"total_tables": len(tables),
"message": f"成功获取 {len(tables)} 个表"
}
except Exception as e:
raise ToolError(f"获取表列表失败: {str(e)}")
@mcp.tool
async def get_table_structure(ctx: Context, table_name: str) -> Dict[str, Any]:
"""
获取指定表的结构信息
需要 'data:read' 权限
Args:
table_name: 表名
"""
access_token = get_validated_access_token()
check_permissions(access_token, ["data:read_tables"])
try:
initialize_services()
table_info = db_manager.get_table_info(table_name)
if not table_info:
raise ToolError(f"表 '{table_name}' 不存在或无法访问")
# 直接返回字典数据,无需转换
result = {
"user_id": access_token.client_id,
"table_name": table_name,
"total_rows": int(table_info.get('total_rows', 0))
}
if 'structure' in table_info and table_info['structure'] is not None:
result["structure"] = table_info['structure']
if 'sample_data' in table_info and table_info['sample_data'] is not None:
result["sample_data"] = table_info['sample_data']
return result
except Exception as e:
raise ToolError(f"获取表结构失败: {str(e)}")
@mcp.tool
async def execute_sql_query(ctx: Context, sql_query: str, limit: int = 100) -> Dict[str, Any]:
"""
执行SQL查询
需要 'data:read' 权限,查询需要 'data:read_table_data' 权限
Args:
sql_query: SQL查询语句
limit: 返回结果的最大行数,默认100
"""
access_token = get_validated_access_token()
check_permissions(access_token, ["data:read_table_data"])
# 检查是否为敏感查询(包含特定关键词)
sensitive_keywords = ['password', 'secret', 'token', 'private', 'confidential']
is_sensitive = any(keyword in sql_query.lower() for keyword in sensitive_keywords)
if is_sensitive:
check_permissions(access_token, ["data:read_table_data"])
# 安全检查:禁止危险操作
dangerous_keywords = ['drop', 'delete', 'update', 'insert', 'alter', 'create', 'truncate']
if any(keyword in sql_query.lower() for keyword in dangerous_keywords):
raise ToolError("安全限制:不允许执行修改数据的操作")
try:
initialize_services()
# 添加LIMIT限制
if 'limit' not in sql_query.lower():
sql_query = f"{sql_query.rstrip(';')} LIMIT {limit}"
result_data = db_manager.execute_query(sql_query)
if result_data is None:
raise ToolError("查询执行失败")
# 获取列名(如果有数据的话)
columns = list(result_data[0].keys()) if result_data else []
return {
"user_id": access_token.client_id,
"query": sql_query,
"row_count": len(result_data),
"columns": columns,
"data": result_data,
"message": f"查询成功,返回 {len(result_data)} 行数据"
}
except Exception as e:
raise ToolError(f"查询执行失败: {str(e)}")
@mcp.tool
async def get_user_permissions(ctx: Context) -> dict:
"""
获取当前用户的权限信息
无需特殊权限,但需要有效的访问令牌
"""
try:
print(ctx)
access_token: AccessToken = get_access_token()
print(f'access_token: {access_token}')
# 如果没有访问令牌,返回默认信息
if access_token is None:
return {
"user_id": "anonymous",
"scopes": [],
"permissions": {
"can_read_tables": False,
"can_read_table_data": False
},
"message": "未认证用户,无权限"
}
return {
"user_id": access_token.client_id or "unknown",
"scopes": access_token.scopes or [],
"permissions": {
"can_read_tables": "data:read_tables" in (access_token.scopes or []),
"can_read_table_data": "data:read_table_data" in (access_token.scopes or []),
},
"message": "权限信息获取成功"
}
except Exception as e:
# 如果获取权限时出错,返回错误信息但不抛出异常
return {
"user_id": "error",
"scopes": [],
"permissions": {
"can_read_tables": False,
"can_read_table_data": False
},
"message": f"权限检查出错: {str(e)}"
}
# 添加一个不需要权限的健康检查工具
@mcp.tool
async def health_check(ctx: Context) -> Dict[str, Any]:
"""
健康检查
无需任何权限
"""
try:
initialize_services()
return {
"status": "healthy",
"database_connected": db_manager is not None,
"message": "服务运行正常"
}
except Exception as e:
return {
"status": "unhealthy",
"database_connected": False,
"message": f"服务异常: {str(e)}"
}
if __name__ == "__main__":
# 从环境变量获取配置
host = os.getenv('MCP_HOST', '127.0.0.1')
port = int(os.getenv('MCP_PORT', 8000))
print(f"🚀 启动MCP数据查询服务器...")
print(f"📍 地址: http://{host}:{port}")
print(f"📋 可用工具:")
print(f" - health_check: 健康检查")
print(f" - get_user_permissions: 获取用户权限")
print(f" - get_database_tables: 获取数据库表列表")
print(f" - get_table_structure: 获取表结构")
print(f" - execute_sql_query: 执行SQL查询")
print(f" - generate_sql_from_question: 自然语言生成SQL")
print(f" - analyze_query_result: 查询结果分析")
mcp.run(transport="streamable-http", host=host, port=port)