from typing import Any, Optional, List, Dict
import mysql.connector
from mcp.server.fastmcp import FastMCP
import sys
import logging
import os
import argparse
import time
import json
import base64
import re
from pydantic import BaseModel, Field
from typing import Literal
from abc import ABC, abstractmethod
from dotenv import load_dotenv
load_dotenv()
# -------------------------------------------------
# Logging
# -------------------------------------------------
# Shared by the console (basicConfig) and file handlers so both stay in sync.
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
)
logger = logging.getLogger("multi-db-mcp-server")
mcp = FastMCP("MySQL Database Explorer", log_level="INFO")
# Let this module's logger emit DEBUG records so the file handler below can
# capture the detailed trace.
logger.setLevel(logging.DEBUG)
# Detailed logs go to a file in the current working directory. The encoding is
# explicit so logging SQL with non-ASCII text does not depend on the platform's
# locale encoding (which could otherwise make records fail to write).
file_handler = logging.FileHandler("mcp_server.log", encoding="utf-8")
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter(LOG_FORMAT)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
# -------------------------------------------------
# Database abstraction
# -------------------------------------------------
class DatabaseConnection(ABC):
    """Abstract interface that each database backend implements.

    Concrete subclasses must provide every method below; the class itself
    cannot be instantiated.
    """

    @abstractmethod
    def connect(self):
        """Establish the connection to the database."""

    @abstractmethod
    def close(self):
        """Release the connection."""

    @abstractmethod
    def cursor(self):
        """Return a cursor bound to the open connection."""

    @abstractmethod
    def commit(self):
        """Commit the current transaction."""

    @abstractmethod
    def rollback(self):
        """Roll back the current transaction."""
class MySQLConnection(DatabaseConnection):
    """MySQL backend for ``DatabaseConnection`` built on ``mysql.connector``.

    The connection string has the form
    ``mysql://user[:password]@host[:port]/database``. The database named there
    is stored as ``allowed_schema``, the schema this server is meant to expose.
    """

    _SCHEME = "mysql://"

    def __init__(self, connection_string: str):
        self.connection_string = connection_string
        self.conn = None
        # Filled in by _parse_connection_string from the "/database" part.
        self.allowed_schema: Optional[str] = None
        # Only a redacted form is logged: the raw string embeds the password.
        logger.debug(
            f"Creating MySQL connection with string: "
            f"{self._redact(connection_string)[:50]}..."
        )
        self._parse_connection_string()

    @classmethod
    def _redact(cls, connection_string: str) -> str:
        """Return *connection_string* with its password replaced by ``***``."""
        if not connection_string.startswith(cls._SCHEME):
            # Unknown layout: do not risk echoing a secret into the logs.
            return "<unrecognized connection string>"
        auth, at, location = connection_string[len(cls._SCHEME):].rpartition("@")
        if not at:
            return connection_string
        user, has_password, _ = auth.partition(":")
        masked = f"{user}:***" if has_password else user
        return f"{cls._SCHEME}{masked}@{location}"

    def _parse_connection_string(self):
        """Populate ``self.config`` and ``self.allowed_schema`` from the string.

        Raises:
            ValueError: if the string is not of the form
                ``mysql://user[:password]@host[:port]/database`` or the port is
                not an integer.
        """
        logger.debug("Parsing connection string")
        if not self.connection_string.startswith(self._SCHEME):
            raise ValueError("Connection string must start with mysql://")
        remainder = self.connection_string[len(self._SCHEME):]
        # Split on the LAST '@' so the password itself may contain '@'.
        auth, at, rest = remainder.rpartition("@")
        if not at:
            raise ValueError(
                "Connection string must look like "
                "mysql://user[:password]@host[:port]/database"
            )
        # Only the first ':' separates user from password; the password may contain ':'.
        user, _, password = auth.partition(":")
        host_port, slash, database = rest.rpartition("/")
        if not slash or not database:
            raise ValueError(
                "Connection string must name a database: "
                "mysql://user[:password]@host[:port]/database"
            )
        self.allowed_schema = database
        host, colon, port_text = host_port.partition(":")
        if colon:
            try:
                port = int(port_text)
            except ValueError:
                raise ValueError(
                    f"Invalid port in connection string: {port_text!r}"
                ) from None
        else:
            port = 3306
        self.config = {
            "host": host,
            "port": port,
            "user": user,
            "password": password,
            "database": database,
            "charset": "utf8mb4",
            "autocommit": False,
            "raise_on_warnings": True,
            "connection_timeout": 10,
            # Callers fetch at most row_limit rows via fetchmany(); without this,
            # closing a cursor that still has unread rows raises
            # InternalError("Unread result found") instead of discarding them.
            "consume_results": True,
        }
        safe_config = {**self.config, "password": "***"}
        logger.debug(f"Parsed config: {safe_config}")

    def connect(self):
        """Open the connection with the parsed config and return it."""
        logger.info("Connecting to MySQL database")
        self.conn = mysql.connector.connect(**self.config)
        logger.info("Successfully connected to MySQL database")
        return self.conn

    def close(self):
        """Close the connection if open; safe to call more than once."""
        logger.debug("Closing database connection")
        if self.conn is not None:
            self.conn.close()
            self.conn = None
        logger.debug("Database connection closed")

    def cursor(self):
        """Return a dictionary cursor (each row is a dict keyed by column name).

        Raises:
            RuntimeError: if ``connect()`` has not been called (or ``close()`` has).
        """
        logger.debug("Creating cursor")
        if self.conn is None:
            raise RuntimeError("Not connected; call connect() first")
        return self.conn.cursor(dictionary=True)

    def commit(self):
        """Commit the current transaction (autocommit is disabled in config)."""
        logger.debug("Committing transaction")
        self.conn.commit()

    def rollback(self):
        """Roll back the current transaction."""
        logger.debug("Rolling back transaction")
        self.conn.rollback()
def create_database_connection(conn: str) -> MySQLConnection:
    """Build an (unconnected) MySQLConnection from a ``mysql://`` string.

    Raises:
        ValueError: propagated from MySQLConnection when *conn* is malformed.
    """
    # Log only what follows the last '@' (host/port/database); the credentials
    # before it must never reach the logs.
    _, _, location = conn.rpartition("@")
    logger.debug(f"Creating database connection to: {location[:50]}...")
    return MySQLConnection(conn)
# -------------------------------------------------
# CLI / ENV
# -------------------------------------------------
# Every flag falls back to an environment variable, so the server can be
# configured either from the command line or from the environment.
parser = argparse.ArgumentParser()
parser.add_argument("--conn", default=os.getenv("DATABASE_CONNECTION_STRING"))
parser.add_argument("--transport", default=os.getenv("MCP_TRANSPORT", "stdio"))
parser.add_argument("--host", default=os.getenv("MCP_HOST"))
# argparse applies type=int to a string default too, so MCP_PORT="8080" -> 8080.
parser.add_argument("--port", type=int, default=os.getenv("MCP_PORT"))
parser.add_argument("--mount", default=os.getenv("MCP_SSE_MOUNT"))
parser.add_argument(
    "--readonly",
    action="store_true",
    help="Enable read-only mode (prevents INSERT, UPDATE, DELETE, etc.)",
)
# parse_known_args tolerates unrecognised flags instead of exiting.
args = parser.parse_known_args()[0]
CONNECTION_STRING = args.conn
# Read-only is on if either the flag is passed or DATABASE_READONLY is truthy.
_readonly_from_env = os.getenv("DATABASE_READONLY", "false").lower() in ("1", "true", "yes")
READONLY = args.readonly or _readonly_from_env
logger.info(f"Starting server with transport: {args.transport}")
logger.info(f"Connection string provided: {'Yes' if CONNECTION_STRING else 'No'}")
logger.info(f"Read-only mode enabled: {READONLY}")
# -------------------------------------------------
# Schema enforcement
# -------------------------------------------------
# One MySQL identifier used as a schema name: backtick-quoted (a doubled ``
# is an escaped backtick) or bare. Group 1 captures the quoted form, group 2
# the bare form.
_SCHEMA_IDENT = r"(?:`((?:[^`]|``)+)`|([\w$]+))"
# Same shape as _SCHEMA_IDENT but non-capturing, for non-schema identifiers.
_ANY_IDENT = r"(?:`(?:[^`]|``)+`|[\w$]+)"
# A USE statement. Index hints ("USE INDEX", "USE KEY") are legitimate
# query syntax and are deliberately not matched.
_USE_STATEMENT = re.compile(r"\buse\s+(?!(?:index|key)\b)`?\w+")
# Places where SQL names a schema explicitly. Each pattern has exactly two
# capture groups (those of _SCHEMA_IDENT). All are applied to lowercased SQL.
_SCHEMA_REFERENCE_PATTERNS = [
    # schema.table after FROM / JOIN / UPDATE / INTO (INSERT, REPLACE, LOAD DATA)
    # / TABLE (SHOW CREATE TABLE, ALTER TABLE, TABLE stmt) / DESCRIBE / DESC.
    re.compile(
        rf"\b(?:from|join|update|into|table|describe|desc)\s+{_SCHEMA_IDENT}\s*\."
    ),
    # SHOW [FULL] TABLES / TABLE STATUS / OPEN TABLES / TRIGGERS / EVENTS FROM|IN schema
    re.compile(
        rf"\bshow\s+(?:full\s+)?(?:tables|table\s+status|open\s+tables|triggers|events)"
        rf"\s+(?:from|in)\s+{_SCHEMA_IDENT}"
    ),
    # SHOW [FULL] COLUMNS / FIELDS / INDEX / INDEXES / KEYS FROM table FROM|IN schema
    re.compile(
        rf"\bshow\s+(?:full\s+)?(?:columns|fields|index|indexes|keys)"
        rf"\s+(?:from|in)\s+{_ANY_IDENT}\s+(?:from|in)\s+{_SCHEMA_IDENT}"
    ),
]


def _validate_schema_access(sql: str, allowed_schema: str):
    """Reject SQL that switches or explicitly names a schema other than *allowed_schema*.

    This is a best-effort regex filter, not a parser: for example, tables in a
    comma-separated FROM list after the first one are not inspected. Database
    grants on the connecting user remain the real security boundary.

    Raises:
        PermissionError: on a USE statement or an explicit foreign schema.
    """
    logger.debug(f"Validating schema access for SQL: {sql[:100]}...")
    lowered = sql.lower()
    if _USE_STATEMENT.search(lowered):
        logger.warning("USE statement detected - blocking access")
        raise PermissionError("USE statements are not allowed")
    allowed = allowed_schema.lower()
    for pattern in _SCHEMA_REFERENCE_PATTERNS:
        for match in pattern.finditer(lowered):
            quoted, bare = match.groups()
            db = quoted.replace("``", "`") if quoted is not None else bare
            if db != allowed:
                logger.warning(f"Cross-schema access attempt blocked: {db} vs {allowed_schema}")
                raise PermissionError(
                    f"Cross-schema access denied: {db} (allowed: {allowed_schema})"
                )
    logger.debug("Schema validation passed")
# -------------------------------------------------
# Connection helper
# -------------------------------------------------
def get_connection() -> MySQLConnection:
    """Open a new connection pinned to the configured schema.

    The session's default database is set to the schema from the connection
    string, and a strict ``sql_mode`` is applied. The caller owns the returned
    connection and must ``close()`` it.

    Raises:
        RuntimeError: if no connection string is configured.
        Exception: whatever the driver raises while connecting or setting up
            the session; in the latter case the connection is closed first.
    """
    logger.debug("Getting database connection")
    if not CONNECTION_STRING:
        logger.error("DATABASE_CONNECTION_STRING not set")
        raise RuntimeError("DATABASE_CONNECTION_STRING not set")
    db = create_database_connection(CONNECTION_STRING)
    db.connect()
    try:
        logger.debug(f"Setting database context to: {db.allowed_schema}")
        # Double backticks so the schema name cannot escape the quoted identifier.
        schema_ident = db.allowed_schema.replace("`", "``")
        with db.cursor() as cur:
            cur.execute(f"USE `{schema_ident}`")
            cur.execute(
                "SET SESSION sql_mode = "
                "'STRICT_TRANS_TABLES,NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO'"
            )
        db.commit()
    except Exception:
        # The caller never receives db on failure, so close it here to avoid a leak.
        db.close()
        raise
    logger.info(f"Successfully connected to database: {db.allowed_schema}")
    return db
# -------------------------------------------------
# Query execution
# -------------------------------------------------
def _is_select_like(sql: str) -> bool:
    """Return True when the first keyword of *sql* denotes a read statement.

    Used to gate read-only mode. Only the leading keyword is inspected, so this
    is a coarse filter: e.g. in MySQL 8 a WITH clause may also precede
    UPDATE/DELETE. Statements starting with a comment are treated as writes
    (fail closed).
    """
    # Split on any whitespace (not just " ") so "SELECT\n*" is recognised, and
    # skip opening parentheses so "(SELECT ...) UNION (SELECT ...)" counts.
    words = sql.lstrip().lstrip("(").split(None, 1)
    first_word = words[0].lower() if words else ""
    # DESCRIBE/DESC are aliases of SHOW COLUMNS / EXPLAIN; TABLE is MySQL 8's
    # shorthand for SELECT * FROM.
    result = first_word in {
        "select", "with", "show", "values", "explain", "describe", "desc", "table",
    }
    logger.debug(f"Checking if query is select-like: '{sql[:30]}...' -> {result}")
    return result
def _exec_query(sql, parameters, row_limit, as_json):
    """Run *sql* on a fresh connection and return its result.

    Args:
        sql: statement to execute; schema-checked by _validate_schema_access.
        parameters: optional sequence bound to the statement's placeholders.
        row_limit: maximum number of rows fetched from a result set.
        as_json: True to return a list of row dicts, False to return a string.

    Returns:
        With ``as_json``: a list of row dicts, or ``[]`` for statements without
        a result set, for a read-only rejection, or on any error. Without it:
        the rows as indented JSON text, ``"Rows affected: N"``,
        ``"Read-only mode enabled"``, or ``"Query error: ..."``.

    Errors are logged and reported through the return value, never raised.
    """
    logger.info(f"Executing query: {sql[:100]}...")
    logger.debug(f"Parameters: {parameters}, Row limit: {row_limit}, As JSON: {as_json}")
    db = None
    try:
        db = get_connection()
        _validate_schema_access(sql, db.allowed_schema)
        if READONLY:
            if not _is_select_like(sql):
                logger.warning("Read-only mode prevented non-SELECT query execution")
                return [] if as_json else "Read-only mode enabled"
            # The keyword check is coarse (MySQL 8 accepts WITH ... DELETE), so
            # also have the server itself reject writes for this session.
            with db.cursor() as cur:
                cur.execute("SET SESSION TRANSACTION READ ONLY")
        with db.cursor() as cur:
            cur.execute(sql, parameters or None)
            if cur.description is None:
                # No result set: DML/DDL. Persist it (autocommit is off).
                db.commit()
                logger.debug(f"Non-query executed, rows affected: {cur.rowcount}")
                return [] if as_json else f"Rows affected: {cur.rowcount}"
            rows = cur.fetchmany(row_limit)
            logger.debug(f"Fetched {len(rows)} rows")
            if as_json:
                return [dict(r) for r in rows]
            return json.dumps(rows, indent=2, default=str)
    except Exception as e:
        # logger.exception keeps the traceback in the logs for diagnosis.
        logger.exception(f"Query execution failed: {e}")
        return [] if as_json else f"Query error: {e}"
    finally:
        if db:
            db.close()
            logger.debug("Database connection closed after query execution")
# -------------------------------------------------
# MCP TOOLS
# -------------------------------------------------
# All database operations are exposed as MCP tools here.
# The agent uses these tools - it does NOT connect directly to the database.
# This separation allows:
# - Reusability: Tools can be used by other agents/clients
# - Testability: Database operations can be tested independently
# - Security: Centralized access control and validation
# -------------------------------------------------
@mcp.tool()
def server_info() -> Dict[str, Any]:
    """Get server and database information. Database type: MySQL."""
    # Static metadata only: no database connection is opened here.
    logger.info("Server info requested")
    connector_version = getattr(mysql.connector, "__version__", None)
    info: Dict[str, Any] = {
        "name": "MySQL Database Explorer",
        "database_type": "MySQL",
        "readonly": READONLY,
        "mysql_connector_version": connector_version,
    }
    return info
@mcp.tool()
def db_identity() -> Dict[str, Any]:
    """Get current database identity details. Database type: MySQL."""
    # Returns database_type plus the database, user, host, port and
    # server_version keys produced by the two queries below.
    logger.info("Database identity requested")
    db = get_connection()
    try:
        with db.cursor() as cur:
            # DATABASE is a reserved word in MySQL, so an unquoted
            # "AS database" is a syntax error; the aliases are backtick-quoted.
            cur.execute(
                "SELECT DATABASE() AS `database`, USER() AS `user`, "
                "@@hostname AS `host`, @@port AS `port`"
            )
            row = cur.fetchone() or {}
            cur.execute("SELECT @@version AS server_version")
            version = cur.fetchone() or {}
        result = {"database_type": "MySQL", **row, **version}
        logger.debug(f"Database identity retrieved: {result}")
        return result
    finally:
        db.close()
class QueryInput(BaseModel):
    """Arguments for the run_query tool."""

    sql: str = Field(description="SQL statement to execute (MySQL syntax).")
    parameters: Optional[List[Any]] = Field(
        default=None,
        description="Positional values bound to %s placeholders in sql.",
    )
    row_limit: int = Field(
        default=500,
        description="Maximum number of rows returned from a result set.",
    )
    format: Literal["markdown", "json"] = Field(
        default="markdown",
        description="'markdown' returns rows as indented JSON text; 'json' returns a compact JSON array string.",
    )
@mcp.tool()
def run_query(input: QueryInput) -> str:
    """Execute SQL queries on MySQL database. Use MySQL syntax (e.g., SHOW TABLES, DESCRIBE, backticks for identifiers)."""
    # "markdown" passes through _exec_query's string result (indented JSON text,
    # a rows-affected line, or an error message); "json" serialises the row list.
    logger.info(f"Run query tool called with SQL: {input.sql[:100]}...")
    want_json = input.format == "json"
    outcome = _exec_query(input.sql, input.parameters, input.row_limit, want_json)
    logger.debug(f"Query result: {str(outcome)[:200]}...")
    if not want_json:
        return outcome
    return json.dumps(outcome, default=str)
class QueryJSONInput(BaseModel):
    """Arguments for the run_query_json tool."""

    sql: str = Field(description="SQL statement to execute (MySQL syntax).")
    parameters: Optional[List[Any]] = Field(
        default=None,
        description="Positional values bound to %s placeholders in sql.",
    )
    row_limit: int = Field(
        default=500,
        description="Maximum number of rows returned from a result set.",
    )
@mcp.tool()
def run_query_json(input: QueryJSONInput) -> List[Dict[str, Any]]:
    """Execute SQL queries on MySQL database and return JSON. Use MySQL syntax (e.g., SHOW TABLES, DESCRIBE, backticks for identifiers)."""
    # Always requests the list-of-dicts form from _exec_query.
    logger.info(f"Run query JSON tool called with SQL: {input.sql[:100]}...")
    rows = _exec_query(input.sql, input.parameters, input.row_limit, True)
    row_count = len(rows) if isinstance(rows, list) else 'N/A'
    logger.debug(f"Query JSON result length: {row_count}")
    return rows
@mcp.tool()
def list_table_resources(schema: Optional[str] = None) -> List[str]:
    """List tables as ``table://<schema>/<table>`` resource URIs.

    ``schema`` defaults to the database named in the connection string.
    Returns an empty list when the underlying query fails (the error is logged).
    """
    logger.info("List table resources requested")
    # A connection is opened only to learn the configured schema; release it
    # before _exec_query opens its own rather than holding two at once.
    db = get_connection()
    try:
        schema = schema or db.allowed_schema
    finally:
        db.close()
    logger.debug(f"Listing tables in schema: {schema}")
    # Double backticks so the name cannot break out of the quoted identifier.
    quoted_schema = schema.replace("`", "``")
    rows = _exec_query(f"SHOW TABLES FROM `{quoted_schema}`", None, 10000, True)
    # SHOW TABLES yields one column per row: the table name.
    result = [f"table://{schema}/{next(iter(r.values()))}" for r in rows]
    logger.debug(f"Found {len(result)} tables")
    return result
@mcp.tool()
def read_table_resource(schema: str, table: str, row_limit: int = 100) -> List[Dict[str, Any]]:
    """Return up to ``row_limit`` rows of ``schema.table`` as a list of row dicts.

    Returns an empty list when the query fails (the error is logged).
    """
    logger.info(f"Reading table resource: {schema}.{table}")
    # Double backticks so neither name can break out of its quoted identifier.
    quoted_schema = schema.replace("`", "``")
    quoted_table = table.replace("`", "``")
    result = _exec_query(
        f"SELECT * FROM `{quoted_schema}`.`{quoted_table}`", None, row_limit, True
    )
    logger.debug(f"Table read completed, rows: {len(result) if isinstance(result, list) else 'N/A'}")
    return result
@mcp.tool()
def list_tables(db_schema: Optional[str] = None) -> str:
    """List the tables in a schema (defaults to the connection's database).

    Returns the SHOW TABLES result as indented JSON text, or an error message.
    """
    logger.info("List tables requested")
    # Connect only to learn the configured schema, then release immediately.
    db = get_connection()
    try:
        schema = db_schema or db.allowed_schema
    finally:
        db.close()
    # Double backticks so the name cannot break out of the quoted identifier.
    quoted_schema = schema.replace("`", "``")
    result = _exec_query(f"SHOW TABLES FROM `{quoted_schema}`", None, 500, False)
    logger.debug(f"Tables listed: {result[:200] if isinstance(result, str) else 'JSON response'}")
    return result
@mcp.tool()
def describe_table(table_name: str, db_schema: Optional[str] = None) -> str:
    """Describe a table's columns (Field, Type, Null, Key, Default, Extra).

    ``db_schema`` defaults to the connection's database. Returns indented JSON
    text, or an error message.
    """
    logger.info(f"Describe table requested: {table_name}")
    # Connect only to learn the configured schema, then release immediately.
    db = get_connection()
    try:
        schema = db_schema or db.allowed_schema
    finally:
        db.close()
    # Double backticks so neither name can break out of its quoted identifier.
    quoted_schema = schema.replace("`", "``")
    quoted_table = table_name.replace("`", "``")
    # SHOW COLUMNS FROM returns the same output as DESCRIBE, but starts with
    # SHOW, so read-only mode accepts it and the schema check sees "FROM schema.".
    result = _exec_query(
        f"SHOW COLUMNS FROM `{quoted_schema}`.`{quoted_table}`", None, 500, False
    )
    logger.debug("Table description completed")
    return result
@mcp.tool()
def get_foreign_keys(table_name: str, db_schema: Optional[str] = None) -> str:
    """Return the table's full SHOW CREATE TABLE output, which includes its FOREIGN KEY constraints.

    ``db_schema`` defaults to the connection's database. Returns indented JSON
    text, or an error message.
    """
    logger.info(f"Get foreign keys requested for table: {table_name}")
    # Connect only to learn the configured schema, then release immediately.
    db = get_connection()
    try:
        schema = db_schema or db.allowed_schema
    finally:
        db.close()
    # Double backticks so neither name can break out of its quoted identifier.
    quoted_schema = schema.replace("`", "``")
    quoted_table = table_name.replace("`", "``")
    result = _exec_query(
        f"SHOW CREATE TABLE `{quoted_schema}`.`{quoted_table}`", None, 500, False
    )
    logger.debug("Foreign keys retrieved")
    return result
# -------------------------------------------------
# NOTE: Tools that list schemas were intentionally removed (by request).
# -------------------------------------------------
# -------------------------------------------------
# Main
# -------------------------------------------------
if __name__ == "__main__":
    logger.info("Starting MCP server")
    if args.host:
        mcp.settings.host = args.host
    if args.port:
        mcp.settings.port = args.port
    if args.mount:
        # --mount / MCP_SSE_MOUNT was previously parsed but never applied.
        # NOTE(review): newer FastMCP releases read the SSE mount path from
        # settings.mount_path; older ones may lack that setting, hence the guard.
        if hasattr(mcp.settings, "mount_path"):
            mcp.settings.mount_path = args.mount
            logger.info(f"SSE mount path set to: {args.mount}")
        else:
            logger.warning(
                "--mount/MCP_SSE_MOUNT ignored: this mcp SDK has no settings.mount_path"
            )
    logger.info(
        f"Server configured to run on {args.host or 'default'}:"
        f"{args.port if args.port else 'default'}"
    )
    mcp.run(transport=args.transport)
    logger.info("MCP server stopped")