# mysqldb-mcp-server (by burakdirin, Verified)
# Path: mysqldb-mcp-server / src / mysqldb_mcp_server
import json
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Any, AsyncIterator, Dict, List, Optional, Union
import mysql.connector
from dotenv import load_dotenv
from mcp.server.fastmcp import Context, FastMCP
from mysql.connector import Error as MySQLError
from mysql.connector.cursor import MySQLCursor
class MySQLServerError(Exception):
    """Common ancestor of every exception defined by this module."""


class ConnectionError(MySQLServerError):
    """Signals a problem with the database connection.

    Within this module the name shadows the built-in ``ConnectionError``.
    """


class QueryError(MySQLServerError):
    """Signals a problem while executing a query."""
class QueryType(Enum):
    """Enum for different types of SQL queries

    Each member's value is the upper-case keyword that starts the statement.
    """

    SELECT = "SELECT"
    INSERT = "INSERT"
    UPDATE = "UPDATE"
    DELETE = "DELETE"
    CREATE = "CREATE"
    DROP = "DROP"
    ALTER = "ALTER"
    TRUNCATE = "TRUNCATE"
    USE = "USE"
    SHOW = "SHOW"
    DESCRIBE = "DESCRIBE"

    @classmethod
    def is_write_operation(cls, query_type: Union[str, "QueryType"]) -> bool:
        """Check if the query type is a write operation

        Args:
            query_type: Either a ``QueryType`` member or a keyword string
                such as ``"insert"``. Strings are matched case-insensitively
                and surrounding whitespace is ignored.

        Returns:
            ``True`` for INSERT, UPDATE, DELETE, CREATE, DROP, ALTER and
            TRUNCATE. ``False`` for every other member and for strings that
            do not name a member of this enum.
        """
        write_operations = {
            cls.INSERT, cls.UPDATE, cls.DELETE,
            cls.CREATE, cls.DROP, cls.ALTER, cls.TRUNCATE,
        }
        # Enum members are accepted as-is so callers need not pass ``.value``.
        if isinstance(query_type, cls):
            return query_type in write_operations
        try:
            return cls(query_type.strip().upper()) in write_operations
        except ValueError:
            # Not a keyword this enum knows about.
            return False
@dataclass
class MySQLContext:
    """Context for MySQL connection

    Holds the connection settings together with the connection object, which
    is opened on demand by ``ensure_connected``.
    """

    host: str
    user: str
    # NOTE(review): the dataclass-generated repr prints this field verbatim;
    # take care not to log instances of this class.
    password: str
    # Default database; left out of the connect arguments when empty or None.
    database: Optional[str]
    # Read-only flag; stored here but not consulted by this class itself.
    readonly: bool
    # None until ensure_connected() opens a connection.
    connection: Optional[mysql.connector.MySQLConnection] = None

    def ensure_connected(self) -> None:
        """Ensure database connection is available, connecting lazily if needed

        Returns immediately when a connection object exists and reports
        itself as connected; otherwise a new connection is opened from the
        stored settings and assigned to ``self.connection``.

        Raises:
            ConnectionError: If ``mysql.connector.connect`` raises
                ``MySQLError``. The driver error is chained as the cause.
        """
        if self.connection and self.connection.is_connected():
            return

        config = {
            "host": self.host,
            "user": self.user,
            "password": self.password,
        }
        if self.database:
            config["database"] = self.database

        try:
            self.connection = mysql.connector.connect(**config)
        except MySQLError as e:
            # Chain explicitly so the driver error is kept as __cause__.
            raise ConnectionError(
                f"Failed to connect to database: {str(e)}") from e
class QueryExecutor:
    """Handles MySQL query execution and result processing

    Statement-level failures are reported as ``QueryError`` and row values
    are converted to JSON-serialisable types before they are returned.
    """

    def __init__(self, context: "MySQLContext"):
        """Store the context whose connection and settings are used."""
        self.context = context

    def _format_datetime(self, value: Any) -> Any:
        """Format datetime values to string

        Any value offering ``strftime`` is rendered as
        ``YYYY-MM-DD HH:MM:SS``; everything else is returned unchanged.
        """
        return value.strftime('%Y-%m-%d %H:%M:%S') if hasattr(value, 'strftime') else value

    def _to_json_safe(self, value: Any) -> Any:
        """Convert one column value into something ``json.dumps`` accepts.

        ``json.dumps`` rejects types such as ``Decimal``, ``timedelta``,
        ``bytes`` and ``set``, which a driver may return for DECIMAL, TIME,
        BLOB or SET columns. Converting here keeps serialisation of query
        results from failing later.
        """
        value = self._format_datetime(value)
        if value is None or isinstance(value, (str, int, float, bool)):
            return value
        if isinstance(value, (bytes, bytearray)):
            # Never fail on undecodable bytes; invalid sequences are replaced.
            return bytes(value).decode("utf-8", errors="replace")
        if isinstance(value, (set, frozenset)):
            # Sorted for a deterministic order.
            return sorted(str(item) for item in value)
        # Decimal, timedelta and similar: the string form keeps precision.
        return str(value)

    def _process_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single row of results"""
        return {key: self._to_json_safe(value) for key, value in row.items()}

    def _process_results(self, cursor: "MySQLCursor") -> Union[List[Dict[str, Any]], Dict[str, int]]:
        """Process query results

        Returns the processed rows when the cursor has a result set,
        otherwise ``{"affected_rows": cursor.rowcount}``.
        """
        if cursor.with_rows:
            results = cursor.fetchall()
            return [self._process_row(row) for row in results]
        return {"affected_rows": cursor.rowcount}

    @staticmethod
    def _split_statements(script: str) -> List[str]:
        """Split *script* on ``;`` while respecting quotes and comments.

        Semicolons inside ``'...'``, ``"..."`` or backtick-quoted text, and
        inside ``#``, ``-- `` or ``/* */`` comments, do not end a statement.
        Comments are kept in the statement text. Empty statements are
        dropped and the remaining ones are stripped.
        """
        statements: List[str] = []
        start = 0      # index where the current statement begins
        quote = None   # quote character currently open, if any
        i = 0
        n = len(script)
        while i < n:
            ch = script[i]
            if quote is not None:
                # NOTE(review): assumes backslash escaping is enabled (no
                # NO_BACKSLASH_ESCAPES SQL mode). Backticks never use it.
                if ch == "\\" and quote != "`":
                    i += 2
                    continue
                # A doubled quote closes and immediately reopens the literal,
                # so no special handling is needed for it.
                if ch == quote:
                    quote = None
            elif ch in ("'", '"', "`"):
                quote = ch
            elif ch == "#" or (
                script.startswith("--", i)
                and (i + 2 == n or script[i + 2].isspace())
            ):
                # Line comment: skip to the end of the line.
                newline = script.find("\n", i)
                i = n if newline == -1 else newline
                continue
            elif script.startswith("/*", i):
                # Block comment: skip past the terminator.
                end = script.find("*/", i + 2)
                i = n if end == -1 else end + 2
                continue
            elif ch == ";":
                statements.append(script[start:i])
                start = i + 1
            i += 1
        statements.append(script[start:])
        return [part.strip() for part in statements if part.strip()]

    def execute_single_query(self, query: str) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
        """Execute a single query and return results

        The first word of the query must be a ``QueryType`` keyword; other
        statements are rejected rather than executed. Validation and the
        read-only check run before a connection is opened.

        Args:
            query: A single SQL statement.

        Returns:
            A list of row dicts when the statement produced rows,
            ``{"affected_rows": n}`` when it did not, or
            ``{"message": ...}`` for a ``USE`` statement.

        Raises:
            QueryError: If the query is empty, starts with an unsupported
                keyword, is a write operation while the context is
                read-only, or the driver reports an error.
            ConnectionError: If no connection could be established.
        """
        words = query.split(None, 1)
        if not words:
            raise QueryError("Error executing query: query is empty")
        keyword = words[0].upper()
        try:
            query_type = QueryType(keyword)
        except ValueError:
            # Report through the module's own exception type so callers that
            # catch QueryError see it.
            raise QueryError(f"Unsupported query type: {keyword}") from None

        # Handle readonly mode
        if self.context.readonly and QueryType.is_write_operation(query_type.value):
            raise QueryError(
                "Server is in read-only mode. Write operations are not allowed.")

        self.context.ensure_connected()
        cursor = None
        try:
            cursor = self.context.connection.cursor(dictionary=True)

            # Handle USE statements
            if query_type == QueryType.USE:
                cursor.execute(query)
                # Record the new default database only after the server
                # accepted the statement, so a failed USE leaves the context
                # untouched.
                target = words[1] if len(words) > 1 else ""
                db_name = target.strip().rstrip(";").strip().strip("`")
                db_name = db_name.replace("``", "`")
                self.context.database = db_name
                return {"message": f"Switched to database: {db_name}"}

            # Execute query
            cursor.execute(query)
            results = self._process_results(cursor)
            # NOTE(review): nothing is committed or rolled back in read-only
            # mode; if autocommit is off the implicit transaction presumably
            # stays open between calls. Confirm that this is intended.
            if not self.context.readonly:
                self.context.connection.commit()
            return results
        except MySQLError as e:
            raise QueryError(f"Error executing query: {str(e)}") from e
        finally:
            if cursor:
                cursor.close()

    def execute_multiple_queries(self, query: str) -> List[Union[List[Dict[str, Any]], Dict[str, Any]]]:
        """Execute multiple queries and return results

        The script is split into statements on ``;`` (ignoring semicolons
        inside quotes and comments) and each statement is executed in order.
        A failing statement does not stop the run: its slot in the returned
        list holds ``{"error": message}``.

        Raises:
            ConnectionError: If no connection could be established.
        """
        results = []
        for single_query in self._split_statements(query):
            try:
                results.append(self.execute_single_query(single_query))
            except QueryError as e:
                results.append({"error": str(e)})
        return results
def get_env_vars() -> tuple[str, str, str, Optional[str], bool]:
    """Get MySQL connection settings from environment variables

    ``load_dotenv()`` is called first, then these variables are read:

    * ``MYSQL_HOST`` (default ``"localhost"``)
    * ``MYSQL_USER`` (default ``"root"``)
    * ``MYSQL_PASSWORD`` (default ``""``)
    * ``MYSQL_DATABASE`` (optional, ``None`` when unset)
    * ``MYSQL_READONLY`` (default ``"0"``): read-only mode is enabled by
      ``1``, ``true``, ``yes`` or ``on``, compared case-insensitively with
      surrounding whitespace ignored. Any other value disables it.

    Returns:
        Tuple of (host, user, password, database, readonly)
    """
    load_dotenv()
    host = os.getenv("MYSQL_HOST", "localhost")
    user = os.getenv("MYSQL_USER", "root")
    password = os.getenv("MYSQL_PASSWORD", "")
    database = os.getenv("MYSQL_DATABASE")  # Optional
    # Normalise before comparing so spellings such as "TRUE" or " yes " do
    # not silently leave the server writable.
    readonly_flag = os.getenv("MYSQL_READONLY", "0").strip().lower()
    readonly = readonly_flag in {"1", "true", "yes", "on"}
    return host, user, password, database, readonly
@asynccontextmanager
async def mysql_lifespan(server: FastMCP) -> AsyncIterator[MySQLContext]:
    """Yield the MySQLContext used for the duration of the server's lifespan.

    The context is built from ``get_env_vars()`` without opening a
    connection. On exit, a connection that exists and is still connected
    is closed.
    """
    field_names = ("host", "user", "password", "database", "readonly")
    settings = dict(zip(field_names, get_env_vars()))

    # No connection is opened here; the context starts with connection=None.
    context = MySQLContext(connection=None, **settings)

    try:
        yield context
    finally:
        connection = context.connection
        if connection and connection.is_connected():
            connection.close()
# Create the MCP server instance, named "MySQL Explorer", passing
# mysql_lifespan as its lifespan handler
mcp = FastMCP("MySQL Explorer", lifespan=mysql_lifespan)
def _get_executor(ctx: Context) -> QueryExecutor:
    """Build a QueryExecutor around the lifespan context carried by *ctx*."""
    return QueryExecutor(ctx.request_context.lifespan_context)
# Switches the default database by running a USE statement through the
# executor. Returns the executor's result as indented JSON, or the error
# message as plain text when the connection or the statement fails.
# NOTE(review): the docstring is presumably surfaced by FastMCP as the tool
# description, so it is kept verbatim.
@mcp.tool()
def connect_database(database: str, ctx: Context) -> str:
    """Connect to a specific MySQL database"""
    if not database or not database.strip():
        # Reject up front instead of sending an empty identifier to the server.
        return "Error executing query: database name must not be empty"
    # Double embedded backticks so the name cannot terminate the quoted
    # identifier and inject further SQL text.
    quoted = database.replace("`", "``")
    try:
        executor = _get_executor(ctx)
        result = executor.execute_single_query(f"USE `{quoted}`")
        return json.dumps(result, indent=2)
    except (ConnectionError, QueryError) as e:
        return str(e)
# Runs one or more ';'-separated statements through the executor and returns
# the outcome as indented JSON: the single result when exactly one statement
# ran, otherwise the list of results. Connection and query errors raised by
# the executor are returned as plain text.
# NOTE(review): the docstring is presumably surfaced by FastMCP as the tool
# description, so it is kept verbatim.
@mcp.tool()
def execute_query(query: str, ctx: Context) -> str:
    """Execute MySQL queries"""
    try:
        executor = _get_executor(ctx)
        results = executor.execute_multiple_queries(query)
        payload = results[0] if len(results) == 1 else results
        # default=str is a deliberate fallback: column values that are not
        # JSON-native (e.g. Decimal, timedelta) are emitted as strings
        # instead of raising an uncaught TypeError.
        return json.dumps(payload, indent=2, default=str)
    except (ConnectionError, QueryError) as e:
        return str(e)
# Start the MCP server only when this file is executed as a script
if __name__ == "__main__":
    mcp.run()