# server.py
import asyncio
import logging
import os
import sys
import traceback
from mysql.connector import connect, Error
from mcp.server import Server
from mcp.types import Resource, Tool, TextContent
from pydantic import AnyUrl
# Configure logging: INFO and above, timestamped, written to stderr.
# stdout is deliberately left alone — presumably because the stdio MCP
# transport started in main() uses it (TODO confirm).
logging.basicConfig(
    stream=sys.stderr,
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("mysql_mcp_server")
def get_db_config():
    """Build MySQL connection settings from environment variables.

    Reads ``MYSQL_HOST`` (default ``"localhost"``), ``MYSQL_PORT`` (default
    ``3306``), ``MYSQL_USER``, ``MYSQL_PASSWORD`` and ``MYSQL_DATABASE``, and
    adds a fixed 5-second ``connect_timeout`` plus ``use_pure=True``.

    Returns:
        dict of keyword arguments passed to ``connect(**config)``.

    Raises:
        ValueError: if ``MYSQL_PORT`` is not an integer, or if any of
            ``MYSQL_USER``, ``MYSQL_PASSWORD`` or ``MYSQL_DATABASE`` is unset
            or empty. The message names the offending variable(s).
    """
    raw_port = os.getenv("MYSQL_PORT", "3306")
    try:
        port = int(raw_port)
    except ValueError as err:
        # Without this, the caller only sees "invalid literal for int()".
        raise ValueError(f"MYSQL_PORT must be an integer, got {raw_port!r}") from err

    config = {
        "host": os.getenv("MYSQL_HOST", "localhost"),
        "port": port,
        "user": os.getenv("MYSQL_USER"),
        "password": os.getenv("MYSQL_PASSWORD"),
        "database": os.getenv("MYSQL_DATABASE"),
        "connect_timeout": 5,
        "use_pure": True,
    }

    required = (
        ("MYSQL_USER", "user"),
        ("MYSQL_PASSWORD", "password"),
        ("MYSQL_DATABASE", "database"),
    )
    # Empty strings are rejected just like unset variables.
    missing = [env_var for env_var, key in required if not config[key]]
    if missing:
        missing_list = ", ".join(missing)
        logger.error(
            f"Missing required database configuration: {missing_list} "
            "(MYSQL_USER, MYSQL_PASSWORD, and MYSQL_DATABASE are required)"
        )
        raise ValueError(f"Missing required database configuration: {missing_list}")
    return config
def test_db_connection():
    """Open and immediately close one connection to check connectivity.

    Returns:
        True when a connection could be opened and closed; False when the
        connector raised ``Error`` (the failure is logged).

    Configuration problems raised by ``get_db_config()`` are not caught
    here and propagate to the caller.
    """
    settings = get_db_config()
    try:
        connect(**settings).close()
    except Error as exc:
        logger.error(f"Database connection test failed: {str(exc)}")
        return False
    logger.info("Database connection test successful")
    return True
# Initialize the MCP server instance. The request handlers below register
# themselves on it through the @app.list_resources / @app.read_resource /
# @app.list_tools / @app.call_tool decorators, and main() runs it.
app = Server("mysql_mcp_server")
@app.list_resources()
async def list_resources() -> list[Resource]:
    """Advertise each table of the configured database as an MCP resource.

    A table ``t`` is exposed as ``mysql://t/data`` with mime type
    ``text/plain``. Any failure, database-related or not, is logged and
    results in an empty listing instead of an exception.
    """
    try:
        db_settings = get_db_config()
        with connect(**db_settings) as conn, conn.cursor() as cursor:
            cursor.execute("SHOW TABLES")
            rows = cursor.fetchall()
            logger.info(f"Found tables: {rows}")
            listing = [
                Resource(
                    uri=f"mysql://{row[0]}/data",
                    name=f"Table: {row[0]}",
                    mimeType="text/plain",
                    description=f"Data in table: {row[0]}",
                )
                for row in rows
            ]
        return listing
    except Error as e:
        logger.error(f"Failed to list resources: {str(e)}")
        return []
    except Exception as e:
        logger.error(f"Unexpected error in list_resources: {str(e)}")
        logger.error(traceback.format_exc())
        return []
def _quote_identifier(identifier: str) -> str:
    """Return *identifier* as a backtick-quoted MySQL identifier.

    Embedded backticks are doubled (MySQL's escape for them), so the result
    is always parsed as exactly one identifier. Identifiers cannot be bound
    as query parameters, so quoting is the way to make them safe.
    """
    return "`" + identifier.replace("`", "``") + "`"


@app.read_resource()
async def read_resource(uri: AnyUrl) -> str:
    """Return up to 100 rows of the table named by a ``mysql://<table>/...`` URI.

    The output is a comma-joined header of column names followed by one
    comma-joined line per row. Errors (bad scheme, missing table name,
    database or unexpected failures) are returned as text, not raised.
    """
    uri_str = str(uri)
    logger.info(f"Reading resource: {uri_str}")
    prefix = "mysql://"
    if not uri_str.startswith(prefix):
        logger.error(f"Invalid URI scheme: {uri_str}")
        return f"Invalid URI scheme: {uri_str}"
    try:
        # The table name is the first path-like segment after the scheme.
        table = uri_str[len(prefix):].split("/", 1)[0]
        if not table:
            logger.error(f"No table name in resource URI: {uri_str}")
            return f"No table name in resource URI: {uri_str}"
        config = get_db_config()
        with connect(**config) as conn:
            with conn.cursor() as cursor:
                # The table name comes from a client-supplied URI. Splicing it
                # in raw allowed arbitrary SQL after "FROM", so it is quoted as
                # a single identifier instead. This also means a "db.table"
                # URI is treated as one name rather than a cross-database
                # reference; list_resources only advertises bare table names.
                cursor.execute(f"SELECT * FROM {_quote_identifier(table)} LIMIT 100")
                columns = [desc[0] for desc in cursor.description]
                rows = cursor.fetchall()
                result = [",".join(map(str, row)) for row in rows]
                return "\n".join([",".join(columns)] + result)
    except Error as e:
        logger.error(f"Database error reading resource {uri}: {str(e)}")
        return f"Database error: {str(e)}"
    except Exception as e:
        logger.error(f"Unexpected error in read_resource: {str(e)}")
        logger.error(traceback.format_exc())
        return f"Unexpected error: {str(e)}"
@app.list_tools()
async def list_tools() -> list[Tool]:
    """Advertise the single ``execute_sql`` tool.

    Its JSON input schema is an object with one required string property,
    ``query``.
    """
    query_schema = {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "The SQL query to execute"},
        },
        "required": ["query"],
    }
    execute_sql = Tool(
        name="execute_sql",
        description="Execute an SQL query on the MySQL server",
        inputSchema=query_schema,
    )
    return [execute_sql]
@app.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
    """Dispatch an MCP tool call; only ``execute_sql`` is supported.

    ``execute_sql`` runs ``arguments["query"]`` on a fresh connection:

    * ``SHOW TABLES ...``: a ``Tables_in_<MYSQL_DATABASE>`` header followed
      by one table name per line.
    * Any other statement that produces a result set (DB-API
      ``cursor.description`` is not None, e.g. SELECT, SHOW, DESCRIBE,
      EXPLAIN, ``WITH ... SELECT``): a comma-joined header row plus one
      comma-joined line per row.
    * Statements without a result set: committed, and the affected row
      count is reported.

    Every failure is returned as a text result rather than raised.
    """
    if name != "execute_sql":
        logger.error(f"Unknown tool: {name}")
        return [TextContent(type="text", text=f"Unknown tool: {name}")]

    query = (arguments or {}).get("query")
    # Blank or non-string input is treated the same as a missing query.
    if not isinstance(query, str) or not query.strip():
        logger.error("Query is required but not provided")
        return [TextContent(type="text", text="Error: Query is required")]

    logger.info(f"Executing SQL query: {query}")
    try:
        config = get_db_config()
        with connect(**config) as conn:
            with conn.cursor() as cursor:
                cursor.execute(query)
                # SHOW TABLES keeps its dedicated header built from the
                # configured database name.
                if query.strip().upper().startswith("SHOW TABLES"):
                    tables = cursor.fetchall()
                    result = ["Tables_in_" + config["database"]]
                    result.extend(table[0] for table in tables)
                    return [TextContent(type="text", text="\n".join(result))]
                # Detect rows by the presence of a result set, not by a leading
                # "SELECT". SHOW, DESCRIBE, EXPLAIN, WITH ... and comment- or
                # parenthesis-prefixed SELECTs also return rows. Committing with
                # those rows unread makes the connector fail ("Unread result
                # found") and the rows are never returned.
                if cursor.description is not None:
                    columns = [desc[0] for desc in cursor.description]
                    rows = cursor.fetchall()
                    lines = [",".join(columns)]
                    lines.extend(",".join(map(str, row)) for row in rows)
                    return [TextContent(type="text", text="\n".join(lines))]
                conn.commit()
                return [TextContent(type="text", text=f"Query executed successfully. Rows affected: {cursor.rowcount}")]
    except Error as e:
        logger.error(f"Error executing SQL '{query}': {e}")
        return [TextContent(type="text", text=f"Error executing query: {str(e)}")]
    except Exception as e:
        logger.error(f"Unexpected error executing SQL '{query}': {e}")
        logger.error(traceback.format_exc())
        return [TextContent(type="text", text=f"Unexpected error: {str(e)}")]
async def main():
    """Run the MySQL MCP server over stdio.

    First probes the database; a failed probe only logs a warning and the
    server still starts. Then it serves MCP requests via ``stdio_server``.
    Any exception during startup or in the server loop is logged, and the
    process then exits with status 1 so a supervisor or client can detect
    the failure. Previously these exceptions were swallowed and the process
    exited 0.
    """
    from mcp.server.stdio import stdio_server

    logger.info("Starting MySQL MCP server...")
    failed = False
    try:
        # Test database connection first
        if not test_db_connection():
            logger.warning("Database connection test failed. Server will start but may have limited functionality.")
        async with stdio_server() as (read_stream, write_stream):
            try:
                await app.run(
                    read_stream,
                    write_stream,
                    app.create_initialization_options(),
                )
            except Exception as e:
                logger.error(f"Server error in main loop: {str(e)}")
                logger.error(traceback.format_exc())
                failed = True
    except Exception as e:
        logger.error(f"Startup error: {str(e)}")
        logger.error(traceback.format_exc())
        failed = True
    if failed:
        # Raised only after the stdio context has closed. SystemExit is not an
        # Exception, so it propagates out of asyncio.run() without being
        # logged a second time by the script's top-level handler.
        sys.exit(1)
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except Exception as exc:
        # Last-resort handler: log with traceback, echo the message to
        # stderr, and exit with a non-zero status.
        fatal = f"Fatal error: {str(exc)}"
        logger.error(fatal)
        logger.error(traceback.format_exc())
        print(fatal, file=sys.stderr)
        sys.exit(1)