# mysql_mcp.py
import argparse
from fastmcp import FastMCP
import mysql.connector
from mysql.connector import Error
import os
import logging
import datetime
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
# Configure logging: the root logger is set to INFO and every record is sent
# both to the file "mysql_mcp.log" (relative to the working directory) and to
# a StreamHandler (stderr by default).
logging.basicConfig(
    handlers=[logging.FileHandler("mysql_mcp.log"), logging.StreamHandler()],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module logger shared by every function below.
logger = logging.getLogger("mysql_mcp")
# Load environment variables from a .env file (python-dotenv) into os.environ.
# This must run before DBConfig below reads the MYSQL_* variables.
load_dotenv()
# Initialize the FastMCP app. The @mcp.tool(), @mcp.resource() and
# @mcp.prompt() decorators further down register their functions on this
# instance, and mcp.run() in the __main__ block serves it.
mcp = FastMCP(
"MySQL MCP",
description="MySQL database connector for Claude",
dependencies=["mysql-connector-python", "python-dotenv"]
)
# Database connection configuration
class DBConfig(BaseModel):
    """Connection settings for the MySQL server.

    Each field defaults to the matching ``MYSQL_*`` environment variable
    (``MYSQL_HOST``, ``MYSQL_PORT``, ``MYSQL_USER``, ``MYSQL_PASSWORD``,
    ``MYSQL_DATABASE``). The defaults are produced by ``default_factory``, so
    the environment is read each time an instance is created rather than once
    when the class is defined. A ``DBConfig()`` built after the environment
    changes therefore sees the new values.
    """

    host: str = Field(default_factory=lambda: os.getenv("MYSQL_HOST", "localhost"))
    # int() raises ValueError if MYSQL_PORT is set to a non-numeric value.
    port: int = Field(default_factory=lambda: int(os.getenv("MYSQL_PORT", "3306")))
    user: str = Field(default_factory=lambda: os.getenv("MYSQL_USER", "root"))
    password: str = Field(default_factory=lambda: os.getenv("MYSQL_PASSWORD", ""))
    # None (or "") means "connect without selecting a database".
    database: Optional[str] = Field(default_factory=lambda: os.getenv("MYSQL_DATABASE"))


# Global connection state
# Name of the most recently selected database (initially MYSQL_DATABASE).
current_db = os.getenv("MYSQL_DATABASE", "")
# Shared, mutable settings read by get_connection(); use_database() updates
# config.database in place.
config = DBConfig()
def get_connection():
    """Open a new MySQL connection using the module-level ``config``.

    A fresh connection is opened on every call; the tools in this module close
    it in their ``finally`` clauses.

    Returns:
        The connection object returned by ``mysql.connector.connect``.

    Raises:
        Exception: wrapping any mysql.connector ``Error`` raised while connecting.
    """
    logger.info(f"Connecting to MySQL server at {config.host}:{config.port} with user {config.user}")
    # An empty database name is sent as None, i.e. no default schema.
    connect_kwargs = {
        "host": config.host,
        "port": config.port,
        "user": config.user,
        "password": config.password,
        "database": config.database or None,
    }
    try:
        conn = mysql.connector.connect(**connect_kwargs)
    except Error as exc:
        logger.error(f"Database connection error: {exc}")
        raise Exception(f"Database connection error: {exc}")
    logger.info(f"Connection established successfully. Using database: {config.database or 'None'}")
    return conn
@mcp.tool()
def query_sql(query: str) -> Dict[str, Any]:
    """Execute a SELECT query and return the results.

    Args:
        query: SQL text sent to the server as-is (no parameter binding).

    Returns:
        A dict with:
          - ``rows``: at most the first 100 result rows, each a column->value dict;
          - ``row_count``: ``cursor.rowcount`` after all rows were fetched;
          - ``column_names``: column names from ``cursor.description``;
          - ``execution_time_seconds``: wall time including connection setup.

    Raises:
        Exception: wrapping any mysql.connector ``Error`` raised by the query.
    """
    max_rows = 100  # Limit to 100 rows for safety
    start_time = datetime.datetime.now()
    logger.info(f"Executing SELECT query: {query}")
    conn = get_connection()
    try:
        # Created inside the outer try so the connection is still closed if
        # cursor creation itself fails.
        cursor = conn.cursor(dictionary=True)
        try:
            cursor.execute(query)
            results = cursor.fetchall()
            rows = results[:max_rows]
            execution_time = (datetime.datetime.now() - start_time).total_seconds()
            # Report both counts: the old message showed the truncated count
            # as the number fetched.
            logger.info(
                f"Query executed successfully. Fetched {len(results)} rows, "
                f"returning {len(rows)}. Execution time: {execution_time:.3f}s"
            )
            return {
                "rows": rows,
                "row_count": cursor.rowcount,
                "column_names": [desc[0] for desc in cursor.description] if cursor.description else [],
                "execution_time_seconds": execution_time,
            }
        except Error as e:
            logger.error(f"Query error: {e}")
            raise Exception(f"Query error: {e}") from e
        finally:
            cursor.close()
    finally:
        conn.close()
        logger.debug("Connection closed")
@mcp.tool()
def execute_sql(query: str) -> Dict[str, Any]:
    """Execute a non-SELECT query (INSERT, UPDATE, DELETE, etc.) and commit it.

    Args:
        query: SQL text sent to the server as-is (no parameter binding).

    Returns:
        A dict with ``affected_rows`` (``cursor.rowcount``), ``last_insert_id``
        (``cursor.lastrowid``, or None when it is falsy) and
        ``execution_time_seconds``.

    Raises:
        Exception: wrapping any mysql.connector ``Error`` from execute/commit.
            The transaction is rolled back first, on a best-effort basis.
    """
    start_time = datetime.datetime.now()
    logger.info(f"Executing non-SELECT query: {query}")
    conn = get_connection()
    try:
        # Created inside the outer try so the connection is still closed if
        # cursor creation itself fails.
        cursor = conn.cursor()
        try:
            cursor.execute(query)
            conn.commit()
            execution_time = (datetime.datetime.now() - start_time).total_seconds()
            logger.info(f"Query executed successfully. Affected rows: {cursor.rowcount}. Execution time: {execution_time:.3f}s")
            return {
                "affected_rows": cursor.rowcount,
                "last_insert_id": cursor.lastrowid if cursor.lastrowid else None,
                "execution_time_seconds": execution_time,
            }
        except Error as e:
            logger.error(f"Query error: {e}")
            # If the connection itself is broken, rollback() can raise too.
            # That secondary failure must not hide the original error.
            try:
                conn.rollback()
                logger.info("Transaction rolled back")
            except Error as rollback_error:
                logger.error(f"Rollback failed: {rollback_error}")
            raise Exception(f"Query error: {e}") from e
        finally:
            cursor.close()
    finally:
        conn.close()
        logger.debug("Connection closed")
@mcp.tool()
def explain_sql(query: str) -> Dict[str, Any]:
    """Return MySQL's execution plan for *query*.

    The query text is prefixed with ``EXPLAIN`` and sent unchanged.

    Returns:
        ``{"plan": [...]}`` with one dict per row of EXPLAIN output.

    Raises:
        Exception: wrapping any mysql.connector ``Error``.
    """
    logger.info(f"Explaining query: {query}")
    connection = get_connection()
    plan_cursor = connection.cursor(dictionary=True)
    try:
        plan_cursor.execute(f"EXPLAIN {query}")
        plan_rows = plan_cursor.fetchall()
        logger.info(f"EXPLAIN executed successfully. Plan steps: {len(plan_rows)}")
        return {"plan": plan_rows}
    except Error as err:
        logger.error(f"EXPLAIN error: {err}")
        raise Exception(f"EXPLAIN error: {err}")
    finally:
        plan_cursor.close()
        connection.close()
        logger.debug("Connection closed")
@mcp.tool()
def show_databases() -> Dict[str, Any]:
    """List the databases reported by ``SHOW DATABASES``.

    Returns:
        ``{"databases": [name, ...]}`` in the order the server returned them.

    Raises:
        Exception: wrapping any mysql.connector ``Error``.
    """
    logger.info("Listing all databases")
    connection = get_connection()
    cur = connection.cursor()
    try:
        cur.execute("SHOW DATABASES")
        names = [row[0] for row in cur.fetchall()]
        logger.info(f"Found {len(names)} databases")
        return {"databases": names}
    except Error as err:
        logger.error(f"Error listing databases: {err}")
        raise Exception(f"Error listing databases: {err}")
    finally:
        cur.close()
        connection.close()
        logger.debug("Connection closed")
@mcp.tool()
def use_database(database: str) -> Dict[str, Any]:
    """Make *database* the default for subsequent tool calls.

    The name must appear exactly (plain string equality) in the
    ``SHOW DATABASES`` list. On success, ``config.database`` and the global
    ``current_db`` are set to it, so later ``get_connection()`` calls select it.

    Returns:
        ``{"current_database": database, "status": "success"}``.

    Raises:
        ValueError: if *database* is not in the ``SHOW DATABASES`` list.
        Exception: wrapping any mysql.connector ``Error``.
    """
    # ``config`` is only mutated (never rebound), so only current_db needs ``global``.
    global current_db
    logger.info(f"Changing database to: {database}")
    connection = get_connection()
    cur = connection.cursor()
    try:
        # Verify the database exists before switching to it.
        cur.execute("SHOW DATABASES")
        known = [row[0] for row in cur.fetchall()]
        if database not in known:
            logger.error(f"Database '{database}' does not exist")
            raise ValueError(f"Database '{database}' does not exist")
        config.database = database
        current_db = database
        logger.info(f"Successfully changed to database: {database}")
        return {"current_database": database, "status": "success"}
    except Error as err:
        logger.error(f"Error changing database: {err}")
        raise Exception(f"Error changing database: {err}")
    finally:
        cur.close()
        connection.close()
        logger.debug("Connection closed")
@mcp.tool()
def show_tables() -> Dict[str, Any]:
    """List the tables of the currently selected database.

    Returns:
        ``{"database": config.database, "tables": [name, ...]}``.

    Raises:
        ValueError: when no database is selected (``config.database`` is falsy).
        Exception: wrapping any mysql.connector ``Error``.
    """
    if not config.database:
        logger.error("No database selected")
        raise ValueError("No database selected. Use 'use_database' first.")
    logger.info(f"Listing tables in database: {config.database}")
    connection = get_connection()
    cur = connection.cursor()
    try:
        cur.execute("SHOW TABLES")
        table_names = [row[0] for row in cur.fetchall()]
        logger.info(f"Found {len(table_names)} tables in database {config.database}")
        return {"database": config.database, "tables": table_names}
    except Error as err:
        logger.error(f"Error listing tables: {err}")
        raise Exception(f"Error listing tables: {err}")
    finally:
        cur.close()
        connection.close()
        logger.debug("Connection closed")
@mcp.tool()
def describe_table(table: str) -> Dict[str, Any]:
    """Return the column definitions and index entries of *table*.

    *table* is interpolated, unquoted, into ``DESCRIBE`` and
    ``SHOW INDEX FROM``.

    Returns:
        ``{"table": table, "columns": [...], "indexes": [...]}``, where each
        list entry is a dict for one result row.

    Raises:
        ValueError: when no database is selected.
        Exception: wrapping any mysql.connector ``Error``.
    """
    if not config.database:
        logger.error("No database selected")
        raise ValueError("No database selected. Use 'use_database' first.")
    logger.info(f"Describing table {table} in database {config.database}")
    connection = get_connection()
    cur = connection.cursor(dictionary=True)
    try:
        cur.execute(f"DESCRIBE {table}")
        column_rows = cur.fetchall()
        # Index information comes from a second statement on the same cursor.
        cur.execute(f"SHOW INDEX FROM {table}")
        index_rows = cur.fetchall()
        logger.info(f"Table {table} has {len(column_rows)} columns and {len(index_rows)} index entries")
        return {"table": table, "columns": column_rows, "indexes": index_rows}
    except Error as err:
        logger.error(f"Error describing table: {err}")
        raise Exception(f"Error describing table: {err}")
    finally:
        cur.close()
        connection.close()
        logger.debug("Connection closed")
@mcp.tool()
def show_table_sample(table: str, limit: int = 5) -> Dict[str, Any]:
    """Show a sample of rows from a table.

    Args:
        table: Table name, interpolated unquoted into the SQL (so callers may
            also pass ``db.table`` or a backtick-quoted name).
        limit: Requested number of rows. It is clamped to the range 0..100.

    Returns:
        ``{"table": table, "sample_rows": [...], "sample_size": n}``, where each
        row is a column->value dict.

    Raises:
        ValueError: when no database is selected.
        Exception: wrapping any mysql.connector ``Error``.
    """
    if not config.database:
        logger.error("No database selected")
        raise ValueError("No database selected. Use 'use_database' first.")
    # Clamp into [0, 100]. MySQL rejects a negative LIMIT with a syntax error,
    # and the upper bound caps how much data one call can return.
    safe_limit = max(0, min(limit, 100))
    logger.info(f"Sampling {safe_limit} rows from table {table} in database {config.database}")
    conn = get_connection()
    try:
        # Created inside the outer try so the connection is still closed if
        # cursor creation itself fails.
        cursor = conn.cursor(dictionary=True)
        try:
            cursor.execute(f"SELECT * FROM {table} LIMIT {safe_limit}")
            rows = cursor.fetchall()
            logger.info(f"Retrieved {len(rows)} sample rows from table {table}")
            return {
                "table": table,
                "sample_rows": rows,
                "sample_size": len(rows),
            }
        except Error as e:
            logger.error(f"Error sampling table: {e}")
            raise Exception(f"Error sampling table: {e}") from e
        finally:
            cursor.close()
    finally:
        conn.close()
        logger.debug("Connection closed")
@mcp.resource(f"schema://{'{database}'}")
def get_database_schema(database: Optional[str] = None) -> str:
    """Get the full schema of a database as a resource.

    Args:
        database: Database to describe (taken from the ``schema://{database}``
            URI). Falls back to the currently selected ``config.database``.

    Returns:
        The ``SHOW CREATE TABLE`` statement of every entry reported by
        ``SHOW TABLES``, in that order, separated by blank lines.

    Raises:
        ValueError: if no database was given and none is selected.
        Exception: wrapping any mysql.connector ``Error``.
    """

    def quote_identifier(name: str) -> str:
        # Backtick-quote an identifier and double any embedded backticks, so
        # names that are reserved words (e.g. `order`) or contain '-', spaces,
        # etc. still parse. Table names come from the server itself, so
        # quoting is always correct for them.
        return "`" + name.replace("`", "``") + "`"

    db_to_use = database or config.database
    if not db_to_use:
        logger.error("No database specified or selected")
        raise ValueError("No database specified or selected")
    logger.info(f"Getting schema for database: {db_to_use}")
    conn = get_connection()
    try:
        # Created inside the outer try so the connection is still closed if
        # cursor creation itself fails.
        cursor = conn.cursor()
        try:
            # Switch this (per-call) connection to the requested database.
            cursor.execute(f"USE {quote_identifier(db_to_use)}")
            cursor.execute("SHOW TABLES")
            tables = [row[0] for row in cursor.fetchall()]
            schema = []
            for table in tables:
                cursor.execute(f"SHOW CREATE TABLE {quote_identifier(table)}")
                # Column 1 of SHOW CREATE TABLE holds the CREATE statement.
                schema.append(cursor.fetchone()[1])
            logger.info(f"Retrieved schema for {len(tables)} tables in database {db_to_use}")
            return "\n\n".join(schema)
        except Error as e:
            logger.error(f"Error getting schema: {e}")
            raise Exception(f"Error getting schema: {e}") from e
        finally:
            cursor.close()
    finally:
        conn.close()
        logger.debug("Connection closed")
@mcp.prompt()
def write_query_for_task(task: str) -> str:
    """Build a prompt asking Claude to write an efficient SQL query for *task*.

    Args:
        task: Free-text description of what the query should accomplish.

    Returns:
        The prompt text: the task, five query-writing guidelines, and a pointer
        to the ``schema://`` resource, ending with a newline.
    """
    logger.info(f"Providing write_query_for_task prompt for task: {task}")
    prompt_lines = [
        f"Task: {task}",
        "Please write an SQL query that accomplishes this task efficiently.",
        "Some guidelines:",
        "1. Use appropriate JOINs (INNER, LEFT, RIGHT) based on the data relationships",
        "2. Filter data in the WHERE clause to minimize data processing",
        "3. Consider using indexes for better performance",
        "4. Use appropriate aggregation functions when needed",
        "5. Format the query with clear indentation for readability",
        "If you need to see the database schema first, you can access it using the schema:// resource.",
        "",  # trailing empty item -> prompt ends with "\n"
    ]
    return "\n".join(prompt_lines)
@mcp.prompt()
def analyze_query_performance(query: str) -> str:
    """Build a prompt asking Claude to analyze the performance of *query*.

    Args:
        query: The SQL query to analyze.

    Returns:
        The prompt text: the query followed by a five-step analysis checklist
        (starting with the ``explain_sql`` tool), ending with a newline.
    """
    logger.info(f"Providing analyze_query_performance prompt for query: {query}")
    checklist = (
        "1. First, use the explain_sql tool to get the execution plan",
        "2. Look for table scans instead of index usage",
        "3. Check if the joins are efficient",
        "4. Identify if the query can be optimized with better indexes",
        "5. Suggest concrete improvements to make the query more efficient",
    )
    header = f"Query: {query}\nPlease analyze this query for performance issues:\n"
    return header + "".join(step + "\n" for step in checklist)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MySQL MCP server")
    parser.add_argument("--host", help="MySQL host")
    parser.add_argument("--port", type=int, help="MySQL port")
    parser.add_argument("--user", help="MySQL user")
    parser.add_argument("--password", help="MySQL password")
    parser.add_argument("--database", help="MySQL database")
    args = parser.parse_args()
    # Command-line values take precedence over the environment / .env file.
    # The global `config` was already built from the environment at import
    # time, so exporting to os.environ alone has no effect on connections.
    # Each override is therefore also applied to `config` directly. The
    # environment is still updated for any code that reads it later.
    # `is not None` (rather than truthiness) lets e.g. an empty --password
    # through deliberately.
    cli_overrides = {
        "host": args.host,
        "port": args.port,
        "user": args.user,
        "password": args.password,
        "database": args.database,
    }
    for field_name, value in cli_overrides.items():
        if value is None:
            continue
        os.environ[f"MYSQL_{field_name.upper()}"] = str(value)
        setattr(config, field_name, value)
    if args.database is not None:
        current_db = args.database
    # Never write the real password to the log file / console.
    masked_password = "***" if args.password else args.password
    # Run the server directly
    logger.info(f"Starting MySQL MCP server 参数: host={args.host}, port={args.port}, user={args.user}, password={masked_password}, database={args.database}")
    mcp.run()