#!/usr/bin/env python3
"""
MariaDB MCP Server - Model Context Protocol server for MariaDB database access
This MCP server provides comprehensive MariaDB database access with:
- Schema exploration and metadata inspection
- Read-only and read-write query execution modes
- Pagination support for large result sets
- Multiple response formats (JSON and Markdown)
- Connection pooling and error handling
- Security features including query validation
"""
import asyncio
import json
import logging
import os
import re
from datetime import date, datetime
from decimal import Decimal
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import mariadb
from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel, ConfigDict, Field, field_validator
# Configure logging: root logger at INFO, plus a module-level logger used by
# every function in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Constants
# Upper bound on the size of a formatted response; the formatters drop rows
# when their output grows past this many characters.
CHARACTER_LIMIT = 25000 # Maximum response size in characters
# Used as the default for ExecuteQueryInput.limit.
DEFAULT_ROW_LIMIT = 100 # Default maximum rows to return
# Used as the upper validation bound for ExecuteQueryInput.limit.
MAX_ROW_LIMIT = 1000 # Maximum allowed row limit
DEFAULT_TIMEOUT = 30 # Query timeout in seconds
# Initialize MCP server; tools are registered on it via @mcp.tool below.
mcp = FastMCP("mariadb_mcp")
# Database connection pool. Stays None until initialize_connection_pool()
# assigns it (get_connection() triggers that on first use).
connection_pool = None
# ============================================================================
# Input Models
# ============================================================================
class ResponseFormat(str, Enum):
    """Output format for query results"""
    # Mixing in ``str`` makes each member compare equal to its plain string
    # value (e.g. ResponseFormat.JSON == "json").
    JSON = "json"
    MARKDOWN = "markdown"
class ListDatabasesInput(BaseModel):
    """Input for listing databases"""
    # Strip surrounding whitespace from string fields, re-validate when an
    # attribute is assigned, and reject any field not declared here.
    model_config = ConfigDict(
        str_strip_whitespace=True,
        validate_assignment=True,
        extra='forbid'
    )
    # The only input: which rendering the tool should return (Markdown by default).
    response_format: ResponseFormat = Field(
        default=ResponseFormat.MARKDOWN,
        description="Output format: 'markdown' for human-readable, 'json' for structured data"
    )
class ListTablesInput(BaseModel):
    """Input for listing tables in a database"""
    # Strip surrounding whitespace from string fields, re-validate when an
    # attribute is assigned, and reject any field not declared here.
    model_config = ConfigDict(
        str_strip_whitespace=True,
        validate_assignment=True,
        extra='forbid'
    )
    # Required; validated to be 1-64 characters long.
    database_name: str = Field(
        ...,
        description="Name of the database to list tables from (e.g., 'myapp_db', 'analytics')",
        min_length=1,
        max_length=64
    )
    # Which rendering the tool should return (Markdown by default).
    response_format: ResponseFormat = Field(
        default=ResponseFormat.MARKDOWN,
        description="Output format: 'markdown' for human-readable, 'json' for structured data"
    )
class GetTableSchemaInput(BaseModel):
    """Input for getting table schema information"""
    # Strip surrounding whitespace from string fields, re-validate when an
    # attribute is assigned, and reject any field not declared here.
    model_config = ConfigDict(
        str_strip_whitespace=True,
        validate_assignment=True,
        extra='forbid'
    )
    # Required; validated to be 1-64 characters long.
    database_name: str = Field(
        ...,
        description="Name of the database containing the table",
        min_length=1,
        max_length=64
    )
    # Required; validated to be 1-64 characters long.
    table_name: str = Field(
        ...,
        description="Name of the table to get schema for",
        min_length=1,
        max_length=64
    )
    # When False, get_table_schema skips the INFORMATION_SCHEMA.STATISTICS query.
    include_indexes: bool = Field(
        default=True,
        description="Whether to include index information"
    )
    # When False, get_table_schema skips the KEY_COLUMN_USAGE query.
    include_foreign_keys: bool = Field(
        default=True,
        description="Whether to include foreign key relationships"
    )
    # Which rendering the tool should return (Markdown by default).
    response_format: ResponseFormat = Field(
        default=ResponseFormat.MARKDOWN,
        description="Output format: 'markdown' for human-readable, 'json' for structured data"
    )
def _strip_sql_line_comments(sql: str) -> str:
    """Remove ``--`` line comments from *sql* without touching quoted text.

    A ``--`` sequence starts a comment only when it appears outside of a
    quoted string/identifier and outside a ``/* ... */`` block, and is
    followed by whitespace, a control character, or the end of the input.
    The comment text up to (but not including) the next newline is dropped.

    When quoting is ambiguous (e.g. an unterminated literal) the scanner errs
    on the side of leaving text in place rather than deleting it.
    """
    kept: List[str] = []
    active_quote: Optional[str] = None  # the quote char we are inside, if any
    pos = 0
    length = len(sql)
    while pos < length:
        ch = sql[pos]

        if active_quote is not None:
            kept.append(ch)
            # Inside '...' or "..." a backslash escapes the next character.
            if ch == "\\" and active_quote != "`" and pos + 1 < length:
                kept.append(sql[pos + 1])
                pos += 2
                continue
            # A doubled quote ('' inside '...') closes and immediately
            # reopens the literal, which yields the correct state.
            if ch == active_quote:
                active_quote = None
            pos += 1
            continue

        if ch in "'\"`":
            active_quote = ch
            kept.append(ch)
            pos += 1
            continue

        if sql.startswith("/*", pos):
            # Copy block comments through verbatim; a "--" inside them is
            # not a line comment.
            block_end = sql.find("*/", pos + 2)
            block_end = length if block_end == -1 else block_end + 2
            kept.append(sql[pos:block_end])
            pos = block_end
            continue

        if sql.startswith("--", pos):
            following = sql[pos + 2:pos + 3]
            if not following or following.isspace() or ord(following) < 32:
                newline_at = sql.find("\n", pos)
                if newline_at == -1:
                    break  # comment runs to the end of the input
                pos = newline_at  # keep the newline itself
                continue

        kept.append(ch)
        pos += 1
    return "".join(kept)


class ExecuteQueryInput(BaseModel):
    """Input for executing SQL queries"""
    # Strip surrounding whitespace from string fields, re-validate when an
    # attribute is assigned, and reject any field not declared here.
    model_config = ConfigDict(
        str_strip_whitespace=True,
        validate_assignment=True,
        extra='forbid'
    )
    # Required; validated to be 1-64 characters long.
    database_name: str = Field(
        ...,
        description="Database to execute query against",
        min_length=1,
        max_length=64
    )
    # Required; post-processed by validate_query below.
    query: str = Field(
        ...,
        description="SQL query to execute. For safety, only SELECT, SHOW, DESCRIBE, and EXPLAIN are allowed in read-only mode",
        min_length=1,
        max_length=50000
    )
    # Positional values bound to the query's placeholders by the driver.
    parameters: Optional[List[Union[str, int, float, bool, None]]] = Field(
        default=None,
        description="Query parameters for parameterized queries (prevents SQL injection)"
    )
    limit: Optional[int] = Field(
        default=DEFAULT_ROW_LIMIT,
        description=f"Maximum number of rows to return (1-{MAX_ROW_LIMIT})",
        ge=1,
        le=MAX_ROW_LIMIT
    )
    offset: Optional[int] = Field(
        default=0,
        description="Number of rows to skip for pagination",
        ge=0
    )
    # Which rendering the tool should return (Markdown by default).
    response_format: ResponseFormat = Field(
        default=ResponseFormat.MARKDOWN,
        description="Output format: 'markdown' for human-readable, 'json' for structured data"
    )

    @field_validator('query')
    @classmethod
    def validate_query(cls, v: str) -> str:
        """Strip ``--`` line comments from the query and reject empty results.

        Returns:
            The query with line comments removed and outer whitespace trimmed.

        Raises:
            ValueError: if nothing but comments/whitespace was supplied, which
                would otherwise slip past the field's ``min_length`` check.
        """
        cleaned = _strip_sql_line_comments(v).strip()
        if not cleaned:
            raise ValueError("Query is empty after removing comments")
        return cleaned
class GetDatabaseStatsInput(BaseModel):
    """Input for getting database statistics"""
    # Strip surrounding whitespace from string fields, re-validate when an
    # attribute is assigned, and reject any field not declared here.
    model_config = ConfigDict(
        str_strip_whitespace=True,
        validate_assignment=True,
        extra='forbid'
    )
    # Required; validated to be 1-64 characters long.
    database_name: str = Field(
        ...,
        description="Name of the database to get statistics for",
        min_length=1,
        max_length=64
    )
    # Which rendering the tool should return (Markdown by default).
    response_format: ResponseFormat = Field(
        default=ResponseFormat.MARKDOWN,
        description="Output format: 'markdown' for human-readable, 'json' for structured data"
    )
# ============================================================================
# Database Connection Management
# ============================================================================
def get_connection_config() -> Dict[str, Any]:
    """Assemble the connection-pool settings from ``MARIADB_*`` environment variables.

    Returns:
        A dict holding user, password, host, port, database and the pool
        options. ``port`` and ``pool_size`` are converted to ``int``; unset
        variables fall back to the defaults written below.

    Raises:
        ValueError: if ``MARIADB_PORT`` or ``MARIADB_POOL_SIZE`` is not an integer.
    """
    env = os.getenv
    settings: Dict[str, Any] = {}
    settings['user'] = env('MARIADB_USER', 'root')
    settings['password'] = env('MARIADB_PASSWORD', '')
    settings['host'] = env('MARIADB_HOST', 'localhost')
    settings['port'] = int(env('MARIADB_PORT', '3306'))
    settings['database'] = env('MARIADB_DATABASE', '')
    # Fixed (non-configurable) options.
    settings['autocommit'] = False
    settings['pool_name'] = 'mariadb_mcp_pool'
    settings['pool_size'] = int(env('MARIADB_POOL_SIZE', '5'))
    settings['pool_reset_connection'] = True
    return settings
def initialize_connection_pool():
    """Create the module-wide MariaDB connection pool.

    Reads the settings from get_connection_config() and stores the resulting
    pool in the global ``connection_pool``.

    Raises:
        mariadb.Error: re-raised after logging when the pool cannot be created.
    """
    global connection_pool
    try:
        # Leave out a blank default database so the pool connects without
        # selecting one.
        pool_kwargs = {
            option: setting
            for option, setting in get_connection_config().items()
            if option != 'database' or setting
        }
        connection_pool = mariadb.ConnectionPool(**pool_kwargs)
    except mariadb.Error as e:
        logger.error(f"Failed to initialize connection pool: {e}")
        raise
    else:
        logger.info("MariaDB connection pool initialized successfully")
def get_connection():
    """Hand out a pooled connection, creating the pool on first use.

    Returns:
        Whatever ``connection_pool.get_connection()`` returns.
    """
    if connection_pool:
        return connection_pool.get_connection()
    # No pool yet: build it, then serve the request from the fresh pool.
    initialize_connection_pool()
    return connection_pool.get_connection()
# Statement keywords accepted as read-only.
_READ_ONLY_KEYWORDS = frozenset({'SELECT', 'SHOW', 'DESCRIBE', 'DESC', 'EXPLAIN', 'WITH'})

# Leading whitespace / opening parentheses, then the first identifier-like token.
_LEADING_KEYWORD_RE = re.compile(r"[\s(]*([A-Za-z_][A-Za-z0-9_$]*)")


def is_read_only_query(query: str) -> bool:
    """Check if a query is read-only, judged by its leading keyword.

    The keyword is the first identifier-like token after any leading
    whitespace and opening parentheses, so ``SELECT(1)``, ``SELECT*FROM t``
    and ``(SELECT ...) UNION (SELECT ...)`` are recognised even though no
    whitespace follows the keyword.

    Args:
        query: SQL text to classify.

    Returns:
        True when the leading keyword is one of SELECT, SHOW, DESCRIBE, DESC,
        EXPLAIN or WITH (case-insensitive); False otherwise, including for
        empty input and for queries that begin with a comment.
    """
    # NOTE(review): only the first keyword is inspected; the remainder of the
    # statement is not analysed here.
    match = _LEADING_KEYWORD_RE.match(query)
    if match is None:
        return False
    return match.group(1).upper() in _READ_ONLY_KEYWORDS
# ============================================================================
# Response Formatting Utilities
# ============================================================================
def format_value(value: Any) -> str:
    """Render one database value as display text.

    ``None`` becomes ``"NULL"``, dates/datetimes use ISO 8601, ``bytes`` are
    summarised by length instead of being dumped, and everything else
    (including ``Decimal``) goes through ``str()``.
    """
    if value is None:
        return "NULL"
    if isinstance(value, bytes):
        return f"<binary:{len(value)} bytes>"
    if isinstance(value, (datetime, date)):
        return value.isoformat()
    return str(value)
def _escape_markdown_cell(text: str) -> str:
    """Make *text* safe to place inside a single Markdown table cell."""
    # A raw pipe would start a new column and a line break would end the row.
    return (
        text.replace("|", "\\|")
        .replace("\r\n", " ")
        .replace("\n", " ")
        .replace("\r", " ")
    )


def _render_markdown_table(columns: List[str], rows: List[tuple],
                           truncated: bool, total_count: Optional[int]) -> str:
    """Render the summary line, header, rows and optional truncation notice."""
    lines = []
    if total_count is not None:
        lines.append(f"**Results:** Showing {len(rows)} of {total_count} total rows\n")
    else:
        lines.append(f"**Results:** {len(rows)} row(s)\n")
    lines.append("| " + " | ".join(_escape_markdown_cell(col) for col in columns) + " |")
    lines.append("|" + "|".join("---" for _ in columns) + "|")
    for row in rows:
        cells = [_escape_markdown_cell(format_value(val)) for val in row]
        lines.append("| " + " | ".join(cells) + " |")
    if truncated:
        lines.append("\n*Results truncated due to size limits. Use pagination or add filters to see more.*")
    return "\n".join(lines)


def format_results_as_markdown(columns: List[str], rows: List[tuple],
                               truncated: bool = False, total_count: Optional[int] = None) -> str:
    """Format query results as a Markdown table.

    Args:
        columns: Column headings, in result order.
        rows: Result rows; each value is rendered with format_value().
        truncated: When True, a truncation notice is appended.
        total_count: If given, the summary line reads "Showing N of total".

    Returns:
        The Markdown text, or ``"*No results found*"`` when *rows* is empty.
        If the table exceeds CHARACTER_LIMIT the row set is halved repeatedly
        (and the truncation notice added) until it fits. If even one row is
        too long, the text itself is cut to CHARACTER_LIMIT.
    """
    if not rows:
        return "*No results found*"

    visible = rows
    text = _render_markdown_table(columns, visible, truncated, total_count)
    # Iterative halving: stops at one row, so it always terminates.
    while len(text) > CHARACTER_LIMIT and len(visible) > 1:
        visible = visible[:max(1, len(visible) // 2)]
        truncated = True
        text = _render_markdown_table(columns, visible, truncated, total_count)

    if len(text) > CHARACTER_LIMIT:
        # A single row is larger than the whole budget; cut the text itself.
        notice = "\n\n*Output truncated: a single row exceeds the response size limit.*"
        text = text[:max(0, CHARACTER_LIMIT - len(notice))] + notice
    return text
def _to_json_value(value: Any) -> Any:
    """Convert one database value into something ``json.dumps`` can emit."""
    if isinstance(value, (datetime, date)):
        return value.isoformat()
    if isinstance(value, Decimal):
        # NOTE(review): float() can lose precision for high-precision
        # decimals; kept because consumers receive numbers here today.
        return float(value)
    if isinstance(value, bytes):
        return f"<binary:{len(value)} bytes>"
    return value


def format_results_as_json(columns: List[str], rows: List[tuple],
                           total_count: Optional[int] = None,
                           has_more: bool = False) -> str:
    """Format query results as JSON.

    Args:
        columns: Column names, used as the keys of each result object.
        rows: Result rows, positionally aligned with *columns*.
        total_count: Included as ``total_count`` when not None.
        has_more: Reported as ``has_more`` (forced to True on truncation).

    Returns:
        A JSON document with ``results``, ``count`` and ``has_more``. When the
        document exceeds CHARACTER_LIMIT the result list is halved repeatedly
        until it fits or one row remains. ``count`` then reflects the rows
        actually returned, and ``truncated`` / ``original_count`` are added.
        A single oversized row is returned as-is, since cutting the text
        would produce invalid JSON.
    """
    results = [
        {col: _to_json_value(row[i]) for i, col in enumerate(columns)}
        for row in rows
    ]
    output = {
        "results": results,
        "count": len(results),
        "has_more": has_more
    }
    if total_count is not None:
        output["total_count"] = total_count
    json_str = json.dumps(output, indent=2, default=str)

    kept = len(results)
    while len(json_str) > CHARACTER_LIMIT and kept > 1:
        kept = max(1, kept // 2)
        output["results"] = results[:kept]
        output["count"] = kept          # keep count consistent with results
        output["has_more"] = True       # rows were dropped, so more exist
        output["truncated"] = True
        output["original_count"] = len(results)
        json_str = json.dumps(output, indent=2, default=str)
    return json_str
# ============================================================================
# MCP Tool Implementations
# ============================================================================
@mcp.tool(
    name="list_databases",
    annotations={
        "title": "List All Databases",
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": True
    }
)
async def list_databases(params: ListDatabasesInput) -> str:
    """
    List all databases available in the MariaDB server.
    Returns a list of database names that you have access to.
    Use this to discover what databases are available before querying specific tables.
    Example: When asked "What databases do I have?", use this tool to list them all.
    Returns:
    List of database names in the requested format (markdown or json)
    """
    # Runs SHOW DATABASES on a pooled connection. Database errors are logged
    # and returned as a "❌ ..." string; cursor and connection are always
    # released in the finally clause.
    conn = None
    cursor = None
    try:
        conn = get_connection()
        cursor = conn.cursor()
        cursor.execute("SHOW DATABASES")
        names = []
        for record in cursor.fetchall():
            names.append(record[0])

        if params.response_format != ResponseFormat.JSON:
            # Markdown: heading with the count, then one bullet per database.
            heading = f"## Available Databases ({len(names)})\n"
            bullets = [f"- **{db}**" for db in names]
            return "\n".join([heading] + bullets)

        payload = {
            "databases": names,
            "count": len(names)
        }
        return json.dumps(payload, indent=2)
    except mariadb.Error as e:
        error_msg = f"Database error: {str(e)}"
        logger.error(error_msg)
        return f"❌ {error_msg}"
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@mcp.tool(
    name="list_tables",
    annotations={
        "title": "List Tables in Database",
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": True
    }
)
async def list_tables(params: ListTablesInput) -> str:
    """
    List all tables in a specific database.
    Use this to explore the structure of a database and find available tables.
    The response includes table names and basic metadata like row counts when available.
    Example: When asked "What tables are in the sales database?", use this tool.
    Returns:
    List of tables with metadata in the requested format
    """
    # Reads name, type, row estimate and comment for every table of the
    # schema from INFORMATION_SCHEMA.TABLES. Database errors are logged and
    # returned as a "❌ ..." string.
    conn = None
    cursor = None
    try:
        conn = get_connection()
        cursor = conn.cursor()
        tables_sql = (
            "SELECT TABLE_NAME, TABLE_TYPE, TABLE_ROWS, TABLE_COMMENT "
            "FROM INFORMATION_SCHEMA.TABLES "
            "WHERE TABLE_SCHEMA = %s "
            "ORDER BY TABLE_NAME"
        )
        cursor.execute(tables_sql, (params.database_name,))
        found = cursor.fetchall()
        if not found:
            return f"No tables found in database '{params.database_name}'"

        if params.response_format == ResponseFormat.JSON:
            payload = {
                "database": params.database_name,
                "tables": [
                    {
                        "name": tbl_name,
                        "type": tbl_type,
                        "estimated_rows": row_estimate,
                        "comment": tbl_comment or ""
                    }
                    for tbl_name, tbl_type, row_estimate, tbl_comment in found
                ],
                "count": len(found)
            }
            return json.dumps(payload, indent=2)

        # Markdown: one sub-heading per table followed by its metadata bullets.
        md = [f"## Tables in `{params.database_name}` ({len(found)} total)\n"]
        for tbl_name, tbl_type, row_estimate, tbl_comment in found:
            md.append(f"### `{tbl_name}`")
            md.append(f"- **Type:** {tbl_type}")
            if row_estimate is not None:
                md.append(f"- **Estimated Rows:** {row_estimate:,}")
            if tbl_comment:
                md.append(f"- **Comment:** {tbl_comment}")
            md.append("")
        return "\n".join(md)
    except mariadb.Error as e:
        error_msg = f"Database error: {str(e)}"
        logger.error(error_msg)
        return f"❌ {error_msg}"
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@mcp.tool(
    name="get_table_schema",
    annotations={
        "title": "Get Table Schema Details",
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": True
    }
)
async def get_table_schema(params: GetTableSchemaInput) -> str:
    """
    Get detailed schema information for a specific table.
    Returns comprehensive information about table structure including:
    - Column names, types, and constraints
    - Primary keys and unique constraints
    - Indexes (optional)
    - Foreign key relationships (optional)
    Use this to understand table structure before writing queries.
    Example: When asked "What columns does the users table have?", use this tool.
    Returns:
    Detailed schema information in the requested format
    """
    # Three INFORMATION_SCHEMA lookups (columns, then optionally indexes and
    # foreign keys) share one cursor. Database errors are logged and returned
    # as a "❌ ..." string.
    conn = None
    cursor = None
    try:
        conn = get_connection()
        cursor = conn.cursor()
        target = (params.database_name, params.table_name)

        # Column definitions, in table order.
        cursor.execute(
            """
            SELECT
            COLUMN_NAME,
            COLUMN_TYPE,
            IS_NULLABLE,
            COLUMN_KEY,
            COLUMN_DEFAULT,
            EXTRA,
            COLUMN_COMMENT
            FROM INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
            ORDER BY ORDINAL_POSITION
            """,
            target
        )
        column_rows = cursor.fetchall()
        if not column_rows:
            return f"Table '{params.table_name}' not found in database '{params.database_name}'"

        # One row per index, its columns joined in index order.
        index_rows = []
        if params.include_indexes:
            cursor.execute(
                """
                SELECT
                INDEX_NAME,
                NON_UNIQUE,
                GROUP_CONCAT(COLUMN_NAME ORDER BY SEQ_IN_INDEX) as COLUMNS
                FROM INFORMATION_SCHEMA.STATISTICS
                WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
                GROUP BY INDEX_NAME, NON_UNIQUE
                """,
                target
            )
            index_rows = cursor.fetchall()

        # Only key-column entries that reference another table.
        fk_rows = []
        if params.include_foreign_keys:
            cursor.execute(
                """
                SELECT
                CONSTRAINT_NAME,
                COLUMN_NAME,
                REFERENCED_TABLE_NAME,
                REFERENCED_COLUMN_NAME
                FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
                WHERE TABLE_SCHEMA = %s
                AND TABLE_NAME = %s
                AND REFERENCED_TABLE_NAME IS NOT NULL
                """,
                target
            )
            fk_rows = cursor.fetchall()

        if params.response_format == ResponseFormat.JSON:
            # index_rows / fk_rows are empty when their section was not
            # requested, so the lists below come out empty in that case.
            schema = {
                "database": params.database_name,
                "table": params.table_name,
                "columns": [
                    {
                        "name": col_name,
                        "type": col_type,
                        "nullable": is_nullable == "YES",
                        "key": col_key or "",
                        "default": col_default,
                        "extra": col_extra or "",
                        "comment": col_comment or ""
                    }
                    for (col_name, col_type, is_nullable, col_key,
                         col_default, col_extra, col_comment) in column_rows
                ],
                "indexes": [
                    {
                        "name": idx_name,
                        "unique": not non_unique,
                        "columns": idx_columns.split(',')
                    }
                    for idx_name, non_unique, idx_columns in index_rows
                ],
                "foreign_keys": [
                    {
                        "constraint": fk_name,
                        "column": fk_column,
                        "references_table": ref_table,
                        "references_column": ref_column
                    }
                    for fk_name, fk_column, ref_table, ref_column in fk_rows
                ]
            }
            return json.dumps(schema, indent=2, default=str)

        # Markdown rendering.
        md = [f"## Schema for `{params.database_name}`.`{params.table_name}`\n"]
        md.append("### Columns\n")
        md.append("| Column | Type | Nullable | Key | Default | Extra |")
        md.append("|--------|------|----------|-----|---------|-------|")
        for (col_name, col_type, is_nullable, col_key,
             col_default, col_extra, _col_comment) in column_rows:
            nullable = "✓" if is_nullable == "YES" else "✗"
            key = col_key or "-"
            default = "NULL" if col_default is None else col_default
            extra = col_extra or "-"
            md.append(f"| `{col_name}` | {col_type} | {nullable} | {key} | {default} | {extra} |")

        if params.include_indexes and index_rows:
            md.append("\n### Indexes\n")
            for idx_name, non_unique, idx_columns in index_rows:
                unique = "INDEX" if non_unique else "UNIQUE"
                md.append(f"- **{idx_name}** ({unique}): {idx_columns}")

        if params.include_foreign_keys and fk_rows:
            md.append("\n### Foreign Keys\n")
            for fk_name, fk_column, ref_table, ref_column in fk_rows:
                md.append(f"- **{fk_name}**: `{fk_column}` → `{ref_table}`.`{ref_column}`")
        return "\n".join(md)
    except mariadb.Error as e:
        error_msg = f"Database error: {str(e)}"
        logger.error(error_msg)
        return f"❌ {error_msg}"
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
def _quote_identifier(name: str) -> str:
    """Quote *name* as a MariaDB identifier, doubling any embedded backtick."""
    return "`" + name.replace("`", "``") + "`"


# NOTE(review): destructiveHint is False although INSERT/UPDATE/DELETE can run
# when MARIADB_READ_ONLY is not 'true' — confirm the intended hint.
@mcp.tool(
    name="execute_query",
    annotations={
        "title": "Execute SQL Query",
        "readOnlyHint": False,
        "destructiveHint": False,
        "idempotentHint": False,
        "openWorldHint": True
    }
)
async def execute_query(params: ExecuteQueryInput) -> str:
    """
    Execute a SQL query against the database.
    Supports:
    - SELECT queries for data retrieval
    - SHOW, DESCRIBE, EXPLAIN for metadata
    - INSERT, UPDATE, DELETE (when write mode is enabled)
    - Parameterized queries for security
    - Pagination with limit and offset
    Security: Use parameterized queries to prevent SQL injection.
    Examples:
    - "Execute SELECT * FROM users WHERE age > 18 LIMIT 10"
    - "Run query to find top selling products"
    Returns:
    Query results in the requested format with pagination info
    """
    # Flow: enforce read-only mode -> select the database -> set a
    # best-effort statement timeout -> add LIMIT/OFFSET to SELECTs that have
    # none -> execute -> return rows (any statement that produced a result
    # set) or the affected-row count (after commit). All errors are returned
    # as "❌ ..." strings rather than raised.
    conn = None
    cursor = None
    try:
        # Check if query is allowed
        read_only_mode = os.getenv('MARIADB_READ_ONLY', 'true').lower() == 'true'
        if read_only_mode and not is_read_only_query(params.query):
            return "❌ Error: Only read-only queries (SELECT, SHOW, DESCRIBE, EXPLAIN) are allowed in read-only mode"

        conn = get_connection()
        cursor = conn.cursor()

        # Switch to the specified database; the name is quoted so a backtick
        # in it cannot break out of the identifier.
        cursor.execute(f"USE {_quote_identifier(params.database_name)}")

        # Ask the server to abort statements running longer than the timeout.
        # Best effort: a server that rejects the variable only costs a warning.
        try:
            cursor.execute(f"SET SESSION max_statement_time = {int(DEFAULT_TIMEOUT)}")
        except mariadb.Error as timeout_error:
            logger.warning(f"Could not set statement timeout: {timeout_error}")

        # A trailing ';' would make both the appended LIMIT clause and the
        # COUNT(*) wrapper below syntactically invalid.
        base_query = params.query.strip().rstrip("; \t\r\n")
        is_select = base_query.upper().startswith('SELECT')
        limit = params.limit if params.limit is not None else DEFAULT_ROW_LIMIT
        offset = params.offset or 0

        # Add LIMIT only when the SELECT has no LIMIT keyword of its own.
        # The word-boundary match ignores identifiers such as `credit_limit`.
        # NOTE: a LIMIT inside a string literal still counts as present.
        query = base_query
        limit_applied = False
        if is_select and not re.search(r"\bLIMIT\b", base_query, re.IGNORECASE):
            query += f" LIMIT {limit}"
            if offset:
                query += f" OFFSET {offset}"
            limit_applied = True

        if params.parameters:
            cursor.execute(query, params.parameters)
        else:
            cursor.execute(query)

        # cursor.description is set exactly when the statement returned a
        # result set, which covers SELECT, SHOW, DESCRIBE, EXPLAIN and WITH.
        if cursor.description is not None:
            columns = [desc[0] for desc in cursor.description]
            results = cursor.fetchall()

            # Try to get total count (for pagination)
            total_count = None
            if is_select:
                try:
                    count_query = f"SELECT COUNT(*) FROM ({base_query}) as subquery"
                    # The wrapped query has the same placeholders, so it
                    # needs the same bound parameters.
                    if params.parameters:
                        cursor.execute(count_query, params.parameters)
                    else:
                        cursor.execute(count_query)
                    total_count = cursor.fetchone()[0]
                except mariadb.Error as count_error:
                    # Count is optional; continue without it.
                    logger.debug(f"Count query failed: {count_error}")

            if limit_applied and total_count is not None:
                has_more = offset + len(results) < total_count
            else:
                # Without a reliable total, a full page suggests more rows.
                has_more = len(results) == limit

            if params.response_format == ResponseFormat.JSON:
                return format_results_as_json(columns, results, total_count, has_more)
            return format_results_as_markdown(columns, results, False, total_count)

        # No result set: report affected rows and make the change durable.
        affected = cursor.rowcount
        conn.commit()
        if params.response_format == ResponseFormat.JSON:
            return json.dumps({
                "success": True,
                "affected_rows": affected,
                "message": f"Query executed successfully. {affected} row(s) affected."
            }, indent=2)
        return f"✅ Query executed successfully. **{affected}** row(s) affected."
    except mariadb.Error as e:
        error_msg = f"Database error: {str(e)}"
        logger.error(error_msg)
        if conn:
            try:
                conn.rollback()
            except mariadb.Error as rollback_error:
                # Don't let a failed rollback hide the original error.
                logger.error(f"Rollback failed: {rollback_error}")
        return f"❌ {error_msg}"
    except Exception as e:
        error_msg = f"Unexpected error: {str(e)}"
        logger.error(error_msg)
        return f"❌ {error_msg}"
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@mcp.tool(
    name="get_database_stats",
    annotations={
        "title": "Get Database Statistics",
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": True
    }
)
async def get_database_stats(params: GetDatabaseStatsInput) -> str:
    """
    Get statistics and metadata about a database.
    Returns information about:
    - Total number of tables
    - Database size
    - Character set and collation
    - Table statistics (rows, size)
    Use this to understand database usage and performance characteristics.
    Example: When asked "How large is the analytics database?", use this tool.
    Returns:
    Database statistics in the requested format
    """
    # Three INFORMATION_SCHEMA lookups: aggregate size/table count, default
    # charset/collation, and the ten largest base tables. A schema missing
    # from SCHEMATA is reported as not found. Database errors are logged and
    # returned as a "❌ ..." string.
    conn = None
    cursor = None
    try:
        conn = get_connection()
        cursor = conn.cursor()

        # Get database size
        cursor.execute(
            """
            SELECT
            SUM(DATA_LENGTH + INDEX_LENGTH) as SIZE_BYTES,
            COUNT(DISTINCT TABLE_NAME) as TABLE_COUNT
            FROM INFORMATION_SCHEMA.TABLES
            WHERE TABLE_SCHEMA = %s
            """,
            (params.database_name,)
        )
        size_info = cursor.fetchone()
        # The SUM() value may arrive as Decimal (or NULL). Coerce to int so
        # JSON gets a number instead of the string json's default=str yields.
        total_bytes = int(size_info[0] or 0) if size_info else 0
        table_count = int(size_info[1] or 0) if size_info else 0

        # Get character set and collation
        cursor.execute(
            """
            SELECT
            DEFAULT_CHARACTER_SET_NAME,
            DEFAULT_COLLATION_NAME
            FROM INFORMATION_SCHEMA.SCHEMATA
            WHERE SCHEMA_NAME = %s
            """,
            (params.database_name,)
        )
        charset_info = cursor.fetchone()
        if charset_info is None:
            # No SCHEMATA row: report that instead of all-zero statistics.
            return f"Database '{params.database_name}' not found"

        # Get top tables by size
        cursor.execute(
            """
            SELECT
            TABLE_NAME,
            TABLE_ROWS,
            DATA_LENGTH + INDEX_LENGTH as SIZE_BYTES,
            CREATE_TIME,
            UPDATE_TIME
            FROM INFORMATION_SCHEMA.TABLES
            WHERE TABLE_SCHEMA = %s
            AND TABLE_TYPE = 'BASE TABLE'
            ORDER BY DATA_LENGTH + INDEX_LENGTH DESC
            LIMIT 10
            """,
            (params.database_name,)
        )
        top_tables = cursor.fetchall()

        if params.response_format == ResponseFormat.JSON:
            stats = {
                "database": params.database_name,
                "total_size_bytes": total_bytes,
                "total_size_mb": round(total_bytes / 1024 / 1024, 2),
                "table_count": table_count,
                "character_set": charset_info[0],
                "collation": charset_info[1],
                "top_tables": []
            }
            for name, row_estimate, size_bytes, created, updated in top_tables:
                table_bytes = int(size_bytes or 0)
                stats["top_tables"].append({
                    "name": name,
                    "rows": row_estimate or 0,
                    "size_bytes": table_bytes,
                    "size_mb": round(table_bytes / 1024 / 1024, 2),
                    "created": created.isoformat() if created else None,
                    "updated": updated.isoformat() if updated else None
                })
            return json.dumps(stats, indent=2, default=str)

        # Markdown
        lines = [f"## Database Statistics: `{params.database_name}`\n"]
        if size_info:
            lines.append(f"**Total Size:** {total_bytes / 1024 / 1024:.2f} MB")
            lines.append(f"**Total Tables:** {table_count}")
        lines.append(f"**Character Set:** {charset_info[0]}")
        lines.append(f"**Collation:** {charset_info[1]}")
        if top_tables:
            lines.append("\n### Top Tables by Size\n")
            lines.append("| Table | Rows | Size (MB) | Created | Updated |")
            lines.append("|-------|------|-----------|---------|---------|")
            for name, row_estimate, size_bytes, created_at, updated_at in top_tables:
                rows = f"{row_estimate:,}" if row_estimate else "0"
                size_mb = int(size_bytes or 0) / 1024 / 1024
                created = created_at.strftime("%Y-%m-%d") if created_at else "-"
                updated = updated_at.strftime("%Y-%m-%d") if updated_at else "-"
                lines.append(f"| `{name}` | {rows} | {size_mb:.2f} | {created} | {updated} |")
        return "\n".join(lines)
    except mariadb.Error as e:
        error_msg = f"Database error: {str(e)}"
        logger.error(error_msg)
        return f"❌ {error_msg}"
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
# ============================================================================
# Main Entry Point
# ============================================================================
if __name__ == "__main__":
    # Initialize connection pool and run the MCP server
    import sys

    # Process exit status: 0 for a clean stop or Ctrl-C, 1 after a failure.
    exit_code = 0
    logger.info("Starting MariaDB MCP Server...")
    try:
        initialize_connection_pool()
        logger.info("MariaDB MCP Server started successfully")
    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        sys.exit(1)
    try:
        mcp.run()
    except KeyboardInterrupt:
        logger.info("Server interrupted by user")
    except Exception as e:
        logger.error(f"Server error: {e}")
        exit_code = 1
    finally:
        # Clean up resources on server shutdown
        logger.info("Shutting down MariaDB MCP Server...")
        if connection_pool:
            try:
                connection_pool.close()
                logger.info("Connection pool closed")
            except Exception as e:
                logger.error(f"Error closing connection pool: {e}")
    # Exit outside the finally clause so a failure status (or a SystemExit
    # raised by mcp.run) is not overwritten with 0.
    sys.exit(exit_code)