mysql_mcp_zh.py•12.7 kB
from fastmcp import FastMCP
import mysql.connector
from mysql.connector import Error
import os
import logging
import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
import argparse
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("mysql_mcp_zh.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger("mysql_mcp_zh")
# 加载环境变量
load_dotenv()
# 初始化FastMCP应用
mcp = FastMCP(
"MySQL MCP",
description="MySQL数据库Claude连接器",
dependencies=["mysql-connector-python", "python-dotenv"]
)
# 数据库连接配置
class DBConfig(BaseModel):
host: str = Field(default=os.getenv("MYSQL_HOST", "localhost"))
port: int = Field(default=int(os.getenv("MYSQL_PORT", "3306")))
user: str = Field(default=os.getenv("MYSQL_USER", "root"))
password: str = Field(default=os.getenv("MYSQL_PASSWORD", ""))
database: Optional[str] = Field(default=os.getenv("MYSQL_DATABASE"))
# 全局连接状态
current_db = os.getenv("MYSQL_DATABASE", "")
config = DBConfig()
def get_connection():
"""创建MySQL连接,使用当前配置"""
try:
logger.info(f"正在连接到MySQL服务器 {config.host}:{config.port} 用户 {config.user}")
conn = mysql.connector.connect(
host=config.host,
port=config.port,
user=config.user,
password=config.password,
database=config.database if config.database else None
)
logger.info(f"连接成功建立. 使用数据库: {config.database or '未选择'}")
return conn
except Error as e:
logger.error(f"数据库连接错误: {e}")
raise Exception(f"数据库连接错误: {e}")
@mcp.tool()
def query_sql(query: str) -> Dict[str, Any]:
"""执行SELECT查询并返回结果"""
start_time = datetime.datetime.now()
logger.info(f"执行SELECT查询: {query}")
conn = get_connection()
cursor = conn.cursor(dictionary=True)
try:
cursor.execute(query)
results = cursor.fetchall()
execution_time = (datetime.datetime.now() - start_time).total_seconds()
logger.info(f"查询执行成功. 获取 {len(results[:100])} 行数据. 执行时间: {execution_time:.3f}秒")
return {
"rows": results[:100], # 为安全起见限制100行
"row_count": cursor.rowcount,
"column_names": [desc[0] for desc in cursor.description] if cursor.description else [],
"execution_time_seconds": execution_time
}
except Error as e:
logger.error(f"查询错误: {e}")
raise Exception(f"查询错误: {e}")
finally:
cursor.close()
conn.close()
logger.debug("连接已关闭")
@mcp.tool()
def execute_sql(query: str) -> Dict[str, Any]:
"""执行非SELECT查询(INSERT, UPDATE, DELETE等)"""
start_time = datetime.datetime.now()
logger.info(f"执行非SELECT查询: {query}")
conn = get_connection()
cursor = conn.cursor()
try:
cursor.execute(query)
conn.commit()
execution_time = (datetime.datetime.now() - start_time).total_seconds()
logger.info(f"查询执行成功. 受影响行数: {cursor.rowcount}. 执行时间: {execution_time:.3f}秒")
return {
"affected_rows": cursor.rowcount,
"last_insert_id": cursor.lastrowid if cursor.lastrowid else None,
"execution_time_seconds": execution_time
}
except Error as e:
conn.rollback()
logger.error(f"查询错误: {e}")
logger.info("事务已回滚")
raise Exception(f"查询错误: {e}")
finally:
cursor.close()
conn.close()
logger.debug("连接已关闭")
@mcp.tool()
def explain_sql(query: str) -> Dict[str, Any]:
"""获取查询的执行计划"""
logger.info(f"分析查询: {query}")
conn = get_connection()
cursor = conn.cursor(dictionary=True)
try:
explain_query = f"EXPLAIN {query}"
cursor.execute(explain_query)
results = cursor.fetchall()
logger.info(f"EXPLAIN执行成功. 计划步骤: {len(results)}")
return {
"plan": results
}
except Error as e:
logger.error(f"EXPLAIN错误: {e}")
raise Exception(f"EXPLAIN错误: {e}")
finally:
cursor.close()
conn.close()
logger.debug("连接已关闭")
@mcp.tool()
def show_databases() -> Dict[str, Any]:
"""列出所有可用的数据库"""
logger.info("正在列出所有数据库")
conn = get_connection()
cursor = conn.cursor()
try:
cursor.execute("SHOW DATABASES")
results = cursor.fetchall()
databases = [db[0] for db in results]
logger.info(f"找到 {len(databases)} 个数据库")
return {
"databases": databases
}
except Error as e:
logger.error(f"列出数据库错误: {e}")
raise Exception(f"列出数据库错误: {e}")
finally:
cursor.close()
conn.close()
logger.debug("连接已关闭")
@mcp.tool()
def use_database(database: str) -> Dict[str, Any]:
"""切换到不同的数据库"""
global config, current_db
logger.info(f"切换到数据库: {database}")
# 验证数据库是否存在
conn = get_connection()
cursor = conn.cursor()
try:
cursor.execute("SHOW DATABASES")
dbs = [db[0] for db in cursor.fetchall()]
if database not in dbs:
logger.error(f"数据库'{database}'不存在")
raise ValueError(f"数据库'{database}'不存在")
# 更新配置
config.database = database
current_db = database
logger.info(f"成功切换到数据库: {database}")
return {
"current_database": database,
"status": "success"
}
except Error as e:
logger.error(f"切换数据库错误: {e}")
raise Exception(f"切换数据库错误: {e}")
finally:
cursor.close()
conn.close()
logger.debug("连接已关闭")
@mcp.tool()
def show_tables() -> Dict[str, Any]:
"""列出当前数据库中的所有表"""
if not config.database:
logger.error("未选择数据库")
raise ValueError("未选择数据库。请先使用'use_database'。")
logger.info(f"正在列出数据库 {config.database} 中的表")
conn = get_connection()
cursor = conn.cursor()
try:
cursor.execute("SHOW TABLES")
results = cursor.fetchall()
tables = [table[0] for table in results]
logger.info(f"在数据库 {config.database} 中找到 {len(tables)} 个表")
return {
"database": config.database,
"tables": tables
}
except Error as e:
logger.error(f"列出表错误: {e}")
raise Exception(f"列出表错误: {e}")
finally:
cursor.close()
conn.close()
logger.debug("连接已关闭")
@mcp.tool()
def describe_table(table: str) -> Dict[str, Any]:
"""获取表的列定义"""
if not config.database:
logger.error("未选择数据库")
raise ValueError("未选择数据库。请先使用'use_database'。")
logger.info(f"描述数据库 {config.database} 中的表 {table}")
conn = get_connection()
cursor = conn.cursor(dictionary=True)
try:
cursor.execute(f"DESCRIBE {table}")
columns = cursor.fetchall()
# 获取索引信息
cursor.execute(f"SHOW INDEX FROM {table}")
indexes = cursor.fetchall()
logger.info(f"表 {table} 有 {len(columns)} 列和 {len(indexes)} 个索引条目")
return {
"table": table,
"columns": columns,
"indexes": indexes
}
except Error as e:
logger.error(f"描述表错误: {e}")
raise Exception(f"描述表错误: {e}")
finally:
cursor.close()
conn.close()
logger.debug("连接已关闭")
@mcp.tool()
def show_table_sample(table: str, limit: int = 5) -> Dict[str, Any]:
"""显示表中的样本行"""
if not config.database:
logger.error("未选择数据库")
raise ValueError("未选择数据库。请先使用'use_database'。")
logger.info(f"从数据库 {config.database} 的表 {table} 中抽样 {limit} 行")
conn = get_connection()
cursor = conn.cursor(dictionary=True)
try:
# 安全限制行数
safe_limit = min(limit, 100)
cursor.execute(f"SELECT * FROM {table} LIMIT {safe_limit}")
rows = cursor.fetchall()
logger.info(f"从表 {table} 中检索到 {len(rows)} 个样本行")
return {
"table": table,
"sample_rows": rows,
"sample_size": len(rows)
}
except Error as e:
logger.error(f"抽样表错误: {e}")
raise Exception(f"抽样表错误: {e}")
finally:
cursor.close()
conn.close()
logger.debug("连接已关闭")
@mcp.resource(f"schema://{'{database}'}")
def get_database_schema(database: Optional[str] = None) -> str:
"""获取数据库的完整模式作为资源"""
db_to_use = database or config.database
if not db_to_use:
logger.error("未指定或选择数据库")
raise ValueError("未指定或选择数据库")
logger.info(f"获取数据库架构: {db_to_use}")
conn = get_connection()
cursor = conn.cursor()
schema = []
try:
# 切换到指定数据库
cursor.execute(f"USE {db_to_use}")
# 获取所有表
cursor.execute("SHOW TABLES")
tables = [table[0] for table in cursor.fetchall()]
# 获取每个表的CREATE TABLE语句
for table in tables:
cursor.execute(f"SHOW CREATE TABLE {table}")
create_stmt = cursor.fetchone()[1]
schema.append(create_stmt)
logger.info(f"获取了数据库 {db_to_use} 中 {len(tables)} 个表的架构")
return "\n\n".join(schema)
except Error as e:
logger.error(f"获取架构错误: {e}")
raise Exception(f"获取架构错误: {e}")
finally:
cursor.close()
conn.close()
logger.debug("连接已关闭")
@mcp.prompt()
def write_query_for_task(task: str) -> str:
"""帮助Claude为给定任务编写最优SQL查询"""
logger.info(f"为任务提供编写SQL查询提示: {task}")
return f"""任务: {task}
请编写一个能够高效完成此任务的SQL查询。
一些指导原则:
1. 根据数据关系使用适当的连接(INNER, LEFT, RIGHT)
2. 在WHERE子句中过滤数据以减少数据处理
3. 考虑使用索引以提高性能
4. 在需要时使用适当的聚合函数
5. 使用清晰的缩进格式化查询,提高可读性
如果您需要先查看数据库模式,可以使用schema://资源访问。
"""
@mcp.prompt()
def analyze_query_performance(query: str) -> str:
"""帮助Claude分析查询性能"""
logger.info(f"为查询提供性能分析提示: {query}")
return f"""查询: {query}
请分析此查询的性能问题:
1. 首先使用explain_sql工具获取执行计划
2. 查找表扫描而非索引使用的情况
3. 检查连接是否高效
4. 确定查询是否可以通过更好的索引进行优化
5. 提出具体的改进建议,使查询更高效
"""
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", help="MySQL host")
parser.add_argument("--port", type=int, help="MySQL port")
parser.add_argument("--user", help="MySQL user")
parser.add_argument("--password", help="MySQL password")
parser.add_argument("--database", help="MySQL database")
args = parser.parse_args()
# 如果命令行提供了参数,则覆盖环境变量
if args.host:
os.environ["MYSQL_HOST"] = args.host
if args.port:
os.environ["MYSQL_PORT"] = str(args.port)
if args.user:
os.environ["MYSQL_USER"] = args.user
if args.password:
os.environ["MYSQL_PASSWORD"] = args.password
if args.database:
os.environ["MYSQL_DATABASE"] = args.database
# 直接运行服务器
logger.info("启动MySQL MCP服务器 参数: host={args.host}, port={args.port}, user={args.user}, password={args.password}, database={args.database}")
mcp.run()