import logging
import os
import argparse
import time
import json
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional
import mysql.connector
from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel
from pymilvus import connections, Collection
load_dotenv()
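# Environment variables can be supplied via a .env file or the shell and mirror
# the CLI flags defined below, e.g. (illustrative values only):
#   DATABASE_CONNECTION_STRING=mysql://user:pass@localhost:3306/mydb
#   MILVUS_HOST=localhost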
# -------------------------------------------------
# Logging
# -------------------------------------------------
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("multi-db-mcp-server")
mcp = FastMCP("MySQL Database Explorer", log_level="INFO")
# Add more detailed logging configuration
logger.setLevel(logging.DEBUG)
# Create file handler for detailed logs
file_handler = logging.FileHandler("mcp_server.log")
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
# -------------------------------------------------
# Milvus Client
# -------------------------------------------------
class MilvusClient:
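    """Thin wrapper around a pymilvus connection plus a single Collection handle."""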
def __init__(self, username, password, host, port, database_name, collection_name):
self.username = username
self.password = password
self.host = host
self.port = port
self.database_name = database_name
self.logger = logging.getLogger(__name__)
self.connect_to_milvus()
self.collection = Collection(collection_name)
def connect_to_milvus(self):
try:
connections.connect(
user=self.username,
password=self.password,
host=self.host,
port=self.port,
db_name=self.database_name,
)
self.logger.info("Successfully connected to Milvus.")
except Exception as e:
self.logger.error(f"Failed to connect to Milvus: {e}")
raise
# -------------------------------------------------
# Database abstraction
# -------------------------------------------------
class DatabaseConnection(ABC):
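    """Minimal interface that any backing database connection must implement."""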
@abstractmethod
def connect(self): ...
@abstractmethod
def close(self): ...
@abstractmethod
def cursor(self): ...
@abstractmethod
def commit(self): ...
@abstractmethod
def rollback(self): ...
class MySQLConnection(DatabaseConnection):
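    """MySQL connection built from a mysql:// URL and pinned to a single schema."""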
def __init__(self, connection_string: str):
self.connection_string = connection_string
self.conn = None
self.allowed_schema: Optional[str] = None
        logger.debug("Creating MySQL connection (connection string redacted)")
self._parse_connection_string()
def _parse_connection_string(self):
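        # Expected format: mysql://user[:password]@host[:port]/database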
logger.debug("Parsing connection string")
if not self.connection_string.startswith("mysql://"):
raise ValueError("Connection string must start with mysql://")
conn_str = self.connection_string[8:]
        if "@" not in conn_str:
            raise ValueError("Connection string must look like mysql://user[:password]@host[:port]/database")
        auth, rest = conn_str.rsplit("@", 1)
        user, password = auth.split(":", 1) if ":" in auth else (auth, "")
host_port, database = rest.rsplit("/", 1)
self.allowed_schema = database
if ":" in host_port:
host, port = host_port.split(":", 1)
port = int(port)
else:
host, port = host_port, 3306
self.config = {
"host": host,
"port": port,
"user": user,
"password": password,
"database": database,
"charset": "utf8mb4",
"autocommit": False,
"raise_on_warnings": True,
"connection_timeout": 30, # Increased timeout for slow connections
"consume_results": True, # Important for large result sets
"use_unicode": True,
"pool_reset_session": False, # Don't reset session on pool return
}
logger.debug(f"Parsed config: {self.config}")
def connect(self):
logger.info("Connecting to MySQL database")
try:
# Create connection with timeout protection
self.conn = mysql.connector.connect(**self.config)
logger.info("Successfully connected to MySQL database")
# Verify connection is actually working
if self.conn and self.conn.is_connected():
                # Run a lightweight probe query to confirm the session is usable
try:
with self.conn.cursor() as test_cursor:
test_cursor.execute("SELECT 1")
test_cursor.fetchone()
except Exception as test_error:
logger.warning(f"Connection test failed: {test_error}")
else:
raise mysql.connector.Error("Connection established but not active")
return self.conn
except mysql.connector.Error as e:
logger.error(f"Failed to connect to MySQL database: {e}")
if self.conn:
try:
self.conn.close()
except Exception:
pass
self.conn = None
raise
except Exception as e:
logger.error(f"Unexpected error during connection: {e}")
if self.conn:
try:
self.conn.close()
except Exception:
pass
self.conn = None
raise
def close(self):
logger.debug("Closing database connection")
if self.conn:
self.conn.close()
logger.debug("Database connection closed")
def cursor(self):
logger.debug("Creating cursor")
if not self.conn:
logger.warning("Connection not initialized, connecting...")
self.connect()
elif not self.conn.is_connected():
logger.warning("Connection not active, reconnecting...")
try:
self.conn.reconnect()
except Exception:
self.connect()
try:
# Create unbuffered cursor for large result sets
# This prevents loading all results into memory at once
cursor = self.conn.cursor(
dictionary=True,
buffered=False
)
logger.debug("Cursor created successfully")
return cursor
except mysql.connector.Error as e:
logger.error(f"Failed to create cursor: {e}")
# Try to reconnect once
try:
logger.info("Attempting to reconnect...")
if self.conn:
try:
self.conn.reconnect()
except Exception:
self.connect()
else:
self.connect()
return self.conn.cursor(dictionary=True, buffered=False)
except Exception as reconnect_error:
logger.error(f"Reconnection failed: {reconnect_error}")
raise
except Exception as e:
logger.error(f"Unexpected error creating cursor: {e}")
raise
def commit(self):
logger.debug("Committing transaction")
self.conn.commit()
def rollback(self):
logger.debug("Rolling back transaction")
self.conn.rollback()
def create_database_connection(conn: str) -> MySQLConnection:
logger.debug(f"Creating database connection: {conn[:50]}...")
return MySQLConnection(conn)
# -------------------------------------------------
# Advanced Timeout utilities
# -------------------------------------------------
class QueryTimeoutError(Exception):
"""Custom exception for query timeouts"""
pass
# -------------------------------------------------
# CLI / ENV
# -------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--conn", default=os.getenv("DATABASE_CONNECTION_STRING"))
parser.add_argument("--transport", default=os.getenv("MCP_TRANSPORT", "stdio"))
parser.add_argument("--host", default=os.getenv("MCP_HOST"))
parser.add_argument("--port", type=int, default=os.getenv("MCP_PORT"))
parser.add_argument("--mount", default=os.getenv("MCP_SSE_MOUNT"))
parser.add_argument(
"--readonly",
action="store_true",
help="Enable read-only mode (prevents INSERT, UPDATE, DELETE, etc.)",
)
parser.add_argument(
"--timeout",
type=int,
default=int(os.getenv("QUERY_TIMEOUT", "60")),
help="Query timeout in seconds (default: 60)",
)
parser.add_argument(
"--fetch-limit",
type=int,
default=int(os.getenv("FETCH_LIMIT", "10000")),
help="Maximum rows to fetch at once (default: 10000)",
)
# Milvus arguments
parser.add_argument("--milvus-host", default=os.getenv("MILVUS_HOST"))
parser.add_argument(
"--milvus-port", type=int, default=int(os.getenv("MILVUS_PORT", "19530"))
)
parser.add_argument("--milvus-username", default=os.getenv("MILVUS_USERNAME"))
parser.add_argument("--milvus-password", default=os.getenv("MILVUS_PASSWORD"))
parser.add_argument("--milvus-database", default=os.getenv("MILVUS_DATABASE"))
parser.add_argument("--milvus-collection", default=os.getenv("MILVUS_COLLECTION"))
args, _ = parser.parse_known_args()
CONNECTION_STRING = args.conn
READONLY = args.readonly or os.getenv("DATABASE_READONLY", "false").lower() in {
"1",
"true",
"yes",
}
TIMEOUT_SECONDS = args.timeout
FETCH_LIMIT = args.fetch_limit
# Milvus configuration
MILVUS_CONFIG = {
"host": args.milvus_host,
"port": args.milvus_port,
"username": args.milvus_username,
"password": args.milvus_password,
"database_name": args.milvus_database,
"collection_name": args.milvus_collection,
}
logger.info(f"Starting server with transport: {args.transport}")
logger.info(f"Connection string provided: {'Yes' if CONNECTION_STRING else 'No'}")
logger.info(f"Read-only mode enabled: {READONLY}")
logger.info(f"Query timeout set to: {TIMEOUT_SECONDS} seconds")
logger.info(f"Fetch limit set to: {FETCH_LIMIT} rows")
logger.info(f"Milvus host: {MILVUS_CONFIG['host']}")
logger.info(f"Milvus collection: {MILVUS_CONFIG['collection_name']}")
# Initialize the Milvus client if fully configured. Note: all() treats an empty
# string (e.g. a blank password) as missing configuration.
milvus_client = None
if all(MILVUS_CONFIG.values()):
try:
milvus_client = MilvusClient(**MILVUS_CONFIG)
logger.info("Milvus client initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize Milvus client: {e}")
else:
logger.warning("Milvus configuration incomplete - Milvus features will be disabled")
# -------------------------------------------------
# Schema enforcement
# -------------------------------------------------
def _validate_schema_access(sql: str, allowed_schema: str):
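    """Block USE statements and any schema-qualified reference outside allowed_schema."""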
logger.debug(f"Validating schema access for SQL: {sql[:100]}...")
lowered = sql.lower()
if re.search(r"\buse\s+`?\w+`?", lowered):
logger.warning("USE statement detected - blocking access")
raise PermissionError("USE statements are not allowed")
# Only check for explicit schema.table patterns in FROM/JOIN/UPDATE/INSERT clauses
schema_patterns = [
r"\bfrom\s+`?(\w+)`?\s*\.", # FROM schema.
r"\bjoin\s+`?(\w+)`?\s*\.", # JOIN schema.
r"\bupdate\s+`?(\w+)`?\s*\.", # UPDATE schema.
r"\binsert\s+into\s+`?(\w+)`?\s*\.", # INSERT INTO schema.
]
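    # e.g. "SELECT * FROM otherdb.users" is blocked when allowed_schema is "mydb"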
for pattern in schema_patterns:
matches = re.findall(pattern, lowered, re.IGNORECASE)
for db in matches:
if db.lower() != allowed_schema.lower():
logger.warning(
f"Cross-schema access attempt blocked: {db} vs {allowed_schema}"
)
raise PermissionError(
f"Cross-schema access denied: {db} (allowed: {allowed_schema})"
)
logger.debug("Schema validation passed")
# -------------------------------------------------
# Connection helper
# -------------------------------------------------
def get_connection() -> MySQLConnection:
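    """Open a fresh MySQL connection and pin the session to the allowed schema."""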
logger.debug("Getting database connection")
if not CONNECTION_STRING:
logger.error("DATABASE_CONNECTION_STRING not set")
raise RuntimeError("DATABASE_CONNECTION_STRING not set")
db = create_database_connection(CONNECTION_STRING)
db.connect()
logger.debug(f"Setting database context to: {db.allowed_schema}")
# Set up session with timeout protection
try:
with db.cursor() as cur:
            # Time the session setup so slow setups can be logged
start_time = time.time()
cur.execute(f"USE `{db.allowed_schema}`")
# Set session timeout to prevent hanging
cur.execute("SET SESSION wait_timeout = 300") # 5 minutes
cur.execute("SET SESSION interactive_timeout = 300")
            # Enforce a strict SQL mode for this session
cur.execute(
"SET SESSION sql_mode = "
"'STRICT_TRANS_TABLES,NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO'"
)
db.commit()
setup_time = time.time() - start_time
if setup_time > 1.0:
logger.warning(f"Connection setup took {setup_time:.2f} seconds")
except mysql.connector.Error as e:
logger.error(f"Failed to set up database session: {e}")
db.close()
raise
except Exception as e:
logger.error(f"Unexpected error during session setup: {e}")
db.close()
raise
logger.info(f"Successfully connected to database: {db.allowed_schema}")
return db
# -------------------------------------------------
# Enhanced Query Execution with Large Data Handling
# -------------------------------------------------
def _is_select_like(sql: str) -> bool:
    stripped = sql.lstrip()
    first_word = stripped.split(maxsplit=1)[0].lower() if stripped else ""
result = first_word in {"select", "with", "show", "values", "explain"}
logger.debug(f"Checking if query is select-like: '{sql[:30]}...' -> {result}")
return result
def _exec_query(sql, parameters, row_limit, as_json):
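    """Run a single statement on a fresh connection, streaming SELECT results in batches.

    Returns a list of dicts when as_json is True, otherwise a JSON-formatted string.
    """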
logger.info(f"Executing query: {sql[:100]}...")
logger.debug(
f"Parameters: {parameters}, Row limit: {row_limit}, As JSON: {as_json}"
)
db = None
try:
db = get_connection()
_validate_schema_access(sql, db.allowed_schema)
if READONLY and not _is_select_like(sql):
logger.warning("Read-only mode prevented non-SELECT query execution")
return [] if as_json else "Read-only mode enabled"
# For SELECT queries, use streaming approach for large results
if _is_select_like(sql):
# Check if this is a potentially large query
if "LIMIT" not in sql.upper() and row_limit > FETCH_LIMIT:
logger.warning(
f"Large query detected without LIMIT clause. Applying fetch limit: {FETCH_LIMIT}"
)
                # Append LIMIT to prevent massive data transfer. In MySQL, LIMIT
                # always follows ORDER BY, so appending at the end of the statement
                # is correct whether or not the query is ordered. (Inserting LIMIT
                # before ORDER BY would produce invalid SQL.)
                sql = sql.rstrip("; ") + f" LIMIT {FETCH_LIMIT};"
                logger.info(f"Applied automatic LIMIT to query: {sql[:100]}...")
# Execute with timeout protection
start_time = time.time()
# Create cursor with timeout check
try:
cursor_start = time.time()
cur = db.cursor()
cursor_time = time.time() - cursor_start
if cursor_time > 1.0:
logger.warning(f"Cursor creation took {cursor_time:.2f} seconds")
except Exception as e:
logger.error(f"Failed to create cursor: {e}")
if db:
db.close()
raise
try:
cur.execute(sql, parameters or None)
if cur.description is None:
db.commit()
result = [] if as_json else f"Rows affected: {cur.rowcount}"
logger.debug(f"Non-query executed, rows affected: {cur.rowcount}")
return result
# Handle large result sets efficiently
rows = []
fetched_count = 0
max_rows = min(row_limit, FETCH_LIMIT) # Respect both limits
while True:
batch = cur.fetchmany(max_rows)
if not batch:
break
rows.extend(batch)
fetched_count += len(batch)
# Check if we've reached our limit
if fetched_count >= max_rows:
logger.info(f"Reached fetch limit of {max_rows} rows")
break
# Check timeout periodically
if time.time() - start_time > TIMEOUT_SECONDS:
logger.warning(
f"Query timeout after {TIMEOUT_SECONDS} seconds during fetching"
)
raise QueryTimeoutError(
f"Query timeout after {TIMEOUT_SECONDS} seconds"
)
result = (
[dict(r) for r in rows]
if as_json
else json.dumps(rows, indent=2, default=str)
)
logger.debug(f"Fetched {len(rows)} rows")
return result
        finally:
            # Ensure the cursor is closed even if there's an error. MySQL
            # Connector/Python cursors have no `.closed` attribute, so just
            # attempt close() and swallow any failure.
            try:
                cur.close()
            except Exception:
                pass
except QueryTimeoutError:
logger.error(f"Query timed out after {TIMEOUT_SECONDS} seconds: {sql[:100]}...")
return [] if as_json else f"Query timeout after {TIMEOUT_SECONDS} seconds"
except Exception as e:
logger.error(f"Query execution failed: {e}")
return [] if as_json else f"Query error: {e}"
finally:
if db:
db.close()
logger.debug("Database connection closed after query execution")
# -------------------------------------------------
# MCP TOOLS - Enhanced for Large Datasets
# -------------------------------------------------
@mcp.tool()
def server_info() -> Dict[str, Any]:
"""Get server and database information. Database type: MySQL."""
logger.info("Server info requested")
return {
"name": "MySQL Database Explorer",
"database_type": "MySQL",
"readonly": READONLY,
"query_timeout": TIMEOUT_SECONDS,
"fetch_limit": FETCH_LIMIT,
"mysql_connector_version": getattr(mysql.connector, "__version__", None),
}
@mcp.tool()
def db_identity() -> Dict[str, Any]:
"""Get current database identity details. Database type: MySQL."""
logger.info("Database identity requested")
db = get_connection()
try:
with db.cursor() as cur:
            cur.execute(
                "SELECT DATABASE() AS `database`, USER() AS `user`, "
                "@@hostname AS `host`, @@port AS `port`"
            )
row = cur.fetchone()
cur.execute("SELECT @@version AS server_version")
version = cur.fetchone()
result = {"database_type": "MySQL", **row, **version}
logger.debug(f"Database identity retrieved: {result}")
return result
finally:
db.close()
class QueryInput(BaseModel):
sql: str
parameters: Optional[List[Any]] = None
row_limit: int = 500
format: Literal["markdown", "json"] = "markdown"
@mcp.tool()
def run_query(input: QueryInput) -> str:
"""Execute SQL queries on MySQL database. Use MySQL syntax (e.g., SHOW TABLES, DESCRIBE, backticks for identifiers).
For large datasets, consider using LIMIT clauses to avoid timeouts and memory issues.
"""
logger.info(f"Run query tool called with SQL: {input.sql[:100]}...")
as_json = input.format == "json"
res = _exec_query(input.sql, input.parameters, input.row_limit, as_json)
logger.debug(f"Query result: {str(res)[:200]}...")
return json.dumps(res, default=str) if as_json else res
class QueryJSONInput(BaseModel):
sql: str
parameters: Optional[List[Any]] = None
row_limit: int = 500
@mcp.tool()
def run_query_json(input: QueryJSONInput) -> List[Dict[str, Any]]:
"""Execute SQL queries on MySQL database and return JSON. Use MySQL syntax (e.g., SHOW TABLES, DESCRIBE, backticks for identifiers).
For large datasets, consider using LIMIT clauses to avoid timeouts and memory issues.
"""
logger.info(f"Run query JSON tool called with SQL: {input.sql[:100]}...")
result = _exec_query(input.sql, input.parameters, input.row_limit, True)
logger.debug(
f"Query JSON result length: {len(result) if isinstance(result, list) else 'N/A'}"
)
return result
@mcp.tool()
def list_tables(db_schema: Optional[str] = None) -> str:
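    """List all tables in a schema (defaults to the connection's schema). Database type: MySQL."""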
logger.info("List tables requested")
db = get_connection()
schema = db_schema or db.allowed_schema
db.close()
result = _exec_query(f"SHOW TABLES FROM `{schema}`", None, 500, False)
logger.debug(
f"Tables listed: {result[:200] if isinstance(result, str) else 'JSON response'}"
)
return result
@mcp.tool()
def describe_table(table_name: str, db_schema: Optional[str] = None) -> str:
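    """Describe the columns of a table using DESCRIBE. Database type: MySQL."""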
logger.info(f"Describe table requested: {table_name}")
db = get_connection()
schema = db_schema or db.allowed_schema
db.close()
result = _exec_query(f"DESCRIBE `{schema}`.`{table_name}`", None, 500, False)
logger.debug(f"Table description completed")
return result
# -------------------------------------------------
# Additional utility tools for large datasets
# -------------------------------------------------
@mcp.tool()
def get_table_stats(table_name: str, db_schema: Optional[str] = None) -> Dict[str, Any]:
"""Get statistics about a table including row count and size."""
logger.info(f"Getting table stats for: {table_name}")
db = get_connection()
try:
schema = db_schema or db.allowed_schema
query = f"""
SELECT
table_name,
table_rows,
data_length,
index_length,
ROUND((data_length + index_length) / 1024 / 1024, 2) AS total_size_mb
FROM information_schema.tables
WHERE table_schema = %s AND table_name = %s
"""
result = _exec_query(query, [schema, table_name], 1, True)
if result and isinstance(result, list) and len(result) > 0:
return result[0]
else:
return {"error": "Table not found or no stats available"}
finally:
db.close()
@mcp.tool()
def get_table_row_count(
table_name: str, db_schema: Optional[str] = None
) -> Dict[str, Any]:
"""Get approximate row count for a table."""
logger.info(f"Getting row count for table: {table_name}")
db = get_connection()
try:
schema = db_schema or db.allowed_schema
query = f"SELECT COUNT(*) as row_count FROM `{schema}`.`{table_name}`"
result = _exec_query(query, None, 1, True)
if result and isinstance(result, list) and len(result) > 0:
return {"table": table_name, "row_count": result[0]["row_count"]}
else:
return {"error": "Could not retrieve row count"}
finally:
db.close()
# -------------------------------------------------
# Milvus MCP Tools
# -------------------------------------------------
@mcp.tool()
def fetch_column_descriptions(table_name: str) -> List[Dict[str, Any]]:
"""
    Fetch SQL column descriptions from the Milvus collection by table name.
Args:
table_name (str): Name of the table to query in Milvus
Returns:
List[Dict[str, Any]]: List of records matching the table name
"""
if not milvus_client:
logger.error("Milvus client not initialized")
return [{"error": "Milvus client not configured"}]
try:
        # Query records whose table_name field matches exactly. Note that
        # table_name is interpolated directly into the Milvus expression, so
        # callers should pass trusted identifiers only.
        result = milvus_client.collection.query(
            expr=f"table_name == '{table_name}'",
            output_fields=["sql_column_name", "column_description", "sample_output"],
        )
logger.info(
f"Fetched {len(result)} records from Milvus for table: {table_name}"
)
return [{k: v for k, v in data.items() if k != "id"} for data in result]
except Exception as e:
logger.error(f"Failed to fetch data from Milvus: {e}")
return [{"error": f"Failed to fetch data: {str(e)}"}]
# -------------------------------------------------
# Main
# -------------------------------------------------
if __name__ == "__main__":
logger.info("Starting MCP server")
if args.host:
mcp.settings.host = args.host
if args.port:
mcp.settings.port = args.port
logger.info(
f"Server configured to run on {args.host}:{args.port if args.port else 'default'}"
)
logger.info(f"Query timeout: {TIMEOUT_SECONDS} seconds")
logger.info(f"Fetch limit: {FETCH_LIMIT} rows")
mcp.run(transport=args.transport)
logger.info("MCP server stopped")