"""Example: Creating and registering a custom MCP service."""
from typing import Any, Optional, Dict
from fastmcp import Context
from mcp_sql import MCPSQLServer
from mcp_sql.tools import MCPTool
# ============================================================================
# Example 1: Simple Statistics Service
# ============================================================================
class DatabaseStatsService(MCPTool):
    """Service to get comprehensive database statistics.

    Registered as the ``get_database_stats`` tool. For every table returned
    by the inspector it records row, column, primary-key, foreign-key and
    index counts, and aggregates database-wide row/column totals.
    """

    @property
    def name(self) -> str:
        """Tool name used when the service is registered."""
        return "get_database_stats"

    @property
    def description(self) -> str:
        """Human-readable description of the tool."""
        return "Get comprehensive statistics about a database"

    async def execute(
        self,
        ctx: Context,
        database: Optional[str] = None,
        server_name: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        driver: Optional[str] = None,
        port: Optional[int] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Get database statistics.

        Returns detailed statistics including table counts, row counts,
        and column information for all tables in the database.

        Args:
            ctx: Request context; the credentials manager combines it with
                the explicit connection arguments below.
            database: Database name (may instead be supplied via ``ctx``).
            server_name: Server to connect to.
            user: Login user.
            password: Login password.
            driver: Database driver name.
            port: Server port.
            **kwargs: Accepted and ignored.

        Returns:
            A dict with ``database``, ``server``, ``table_count``,
            ``total_rows``, ``total_columns`` and a per-table ``tables`` list,
            or ``{"error": ...}`` when credentials, the database name, or the
            table list are unavailable.
        """
        creds = self.creds_manager.get_from_context(
            ctx, user, password, server_name, database, driver, port
        )
        if not creds.is_valid():
            return {"error": "Missing credentials"}
        if not creds.database:
            return {"error": "Database name is required"}

        tables = self.inspector.get_tables(creds)
        # The inspector appears to signal failure with a single entry that
        # contains "Error" (NOTE(review): confirm against the inspector).
        if not tables or (len(tables) == 1 and "Error" in tables[0]):
            return {"error": "Could not retrieve tables"}

        stats: Dict[str, Any] = {
            # Report the resolved database name: the ``database`` argument is
            # None when the name was taken from the request context.
            "database": creds.database,
            # NOTE(review): only the explicit argument is reported here; it is
            # None when the server comes from the request context.
            "server": server_name,
            "table_count": len(tables),
            "total_rows": 0,
            "total_columns": 0,
            "tables": [],
        }

        for table in tables:
            table_info = self.inspector.describe_table(creds, table)
            if "error" in table_info:
                continue

            row_count = table_info.get("row_count", 0)
            # ``or []`` also covers keys that are present but set to None.
            col_count = len(table_info.get("columns") or [])

            # Only positive counts contribute to the total; None or negative
            # values (e.g. "unknown") are skipped instead of raising.
            if row_count is not None and row_count > 0:
                stats["total_rows"] += row_count
            stats["total_columns"] += col_count

            stats["tables"].append({
                "name": table,
                "rows": row_count,
                "columns": col_count,
                "has_primary_key": bool(table_info.get("primary_key")),
                "foreign_keys": len(table_info.get("foreign_keys") or []),
                "indexes": len(table_info.get("indexes") or []),
            })

        return stats
# ============================================================================
# Example 2: Schema Validation Service
# ============================================================================
class ValidateSchemaService(MCPTool):
    """Service to validate database schema against best practices.

    Registered as the ``validate_schema`` tool.
    - A table without a primary key is a *violation* and makes the schema
      invalid.
    - Unnamed foreign keys, wide tables without indexes, and tables that
      could not be described are reported as *warnings*.
    """

    # Tables with more columns than this and no index at all produce a
    # ``missing_indexes`` warning.
    MIN_COLUMNS_FOR_INDEX_CHECK = 5

    @property
    def name(self) -> str:
        """Tool name used when the service is registered."""
        return "validate_schema"

    @property
    def description(self) -> str:
        """Human-readable description of the tool."""
        return "Validate database schema against best practices and rules"

    async def execute(
        self,
        ctx: Context,
        database: Optional[str] = None,
        server_name: Optional[str] = None,
        check_primary_keys: bool = True,
        check_foreign_keys: bool = True,
        check_indexes: bool = True,
        user: Optional[str] = None,
        password: Optional[str] = None,
        driver: Optional[str] = None,
        port: Optional[int] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Validate schema against rules.

        Args:
            ctx: Request context; the credentials manager combines it with
                the explicit connection arguments.
            database: Database name (may instead be supplied via ``ctx``).
            server_name: Server to connect to.
            check_primary_keys: Check if all tables have primary keys
            check_foreign_keys: Check foreign key relationships
            check_indexes: Check if tables have appropriate indexes
            user: Login user.
            password: Login password.
            driver: Database driver name.
            port: Server port.
            **kwargs: Accepted and ignored.

        Returns:
            A dict with ``database``, ``valid``, ``total_violations``,
            ``total_warnings``, ``violations`` and ``warnings``, or
            ``{"error": ...}`` when credentials are missing or the table list
            cannot be retrieved.
        """
        creds = self.creds_manager.get_from_context(
            ctx, user, password, server_name, database, driver, port
        )
        if not creds.is_valid() or not creds.database:
            return {"error": "Missing credentials or database name"}

        tables = self.inspector.get_tables(creds)
        # Without this check a failed lookup (apparently signalled by a single
        # entry containing "Error") would make every describe fail, every table
        # be skipped, and the schema be reported as valid.
        if tables is None or (len(tables) == 1 and "Error" in tables[0]):
            return {"error": "Could not retrieve tables"}

        violations = []
        warnings = []

        for table in tables:
            table_info = self.inspector.describe_table(creds, table)
            if "error" in table_info:
                # Surface the failure rather than silently letting an
                # unreadable table pass validation.
                warnings.append({
                    "table": table,
                    "severity": "warning",
                    "rule": "describe_failed",
                    "message": f"Could not describe table '{table}': {table_info.get('error')}"
                })
                continue

            if check_primary_keys and not table_info.get("primary_key"):
                violations.append({
                    "table": table,
                    "severity": "error",
                    "rule": "missing_primary_key",
                    "message": f"Table '{table}' does not have a primary key"
                })

            if check_foreign_keys:
                for fk in table_info.get("foreign_keys") or []:
                    if not fk.get("name"):
                        warnings.append({
                            "table": table,
                            "severity": "warning",
                            "rule": "unnamed_foreign_key",
                            "message": f"Foreign key in '{table}' is not explicitly named"
                        })

            if check_indexes:
                columns = table_info.get("columns") or []
                indexes = table_info.get("indexes") or []
                if len(columns) > self.MIN_COLUMNS_FOR_INDEX_CHECK and not indexes:
                    warnings.append({
                        "table": table,
                        "severity": "warning",
                        "rule": "missing_indexes",
                        "message": f"Table '{table}' has {len(columns)} columns but no indexes"
                    })

        return {
            # Resolved name; the ``database`` argument is None when the name
            # came from the request context.
            "database": creds.database,
            "valid": len(violations) == 0,
            "total_violations": len(violations),
            "total_warnings": len(warnings),
            "violations": violations,
            "warnings": warnings
        }
# ============================================================================
# Example 3: Comparison Service
# ============================================================================
class CompareTablesService(MCPTool):
    """Service to compare structure of two tables.

    Registered as the ``compare_tables`` tool. The two tables may live in
    the same database or in two different databases on the same server.
    """

    @property
    def name(self) -> str:
        """Tool name used when the service is registered."""
        return "compare_tables"

    @property
    def description(self) -> str:
        """Human-readable description of the tool."""
        return "Compare the structure of two tables (same or different databases)"

    async def execute(
        self,
        ctx: Context,
        table1: str,
        table2: str,
        database1: Optional[str] = None,
        database2: Optional[str] = None,
        server_name: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        driver: Optional[str] = None,
        port: Optional[int] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Compare two tables.

        If database2 is not provided, assumes both tables are in database1.

        Args:
            ctx: Request context; the credentials manager combines it with
                the explicit connection arguments.
            table1: Name of the first table.
            table2: Name of the second table.
            database1: Database of ``table1`` (may come from ``ctx``).
            database2: Database of ``table2``; defaults to ``database1``.
            server_name: Server to connect to (shared by both tables).
            user: Login user.
            password: Login password.
            driver: Database driver name.
            port: Server port.
            **kwargs: Accepted and ignored.

        Returns:
            A dict with ``table1``/``table2`` summaries and a ``comparison``
            section; all column lists are sorted so the output is
            deterministic. Returns ``{"error": ...}`` when credentials are
            missing or either table cannot be described.
        """
        creds1 = self.creds_manager.get_from_context(
            ctx, user, password, server_name, database1, driver, port
        )
        if not creds1.is_valid():
            return {"error": "Missing credentials"}

        info1 = self.inspector.describe_table(creds1, table1)

        # Build separate credentials only when table2 lives in another database.
        if database2 and database2 != database1:
            creds2 = self.creds_manager.get_from_context(
                ctx, user, password, server_name, database2, driver, port
            )
            if not creds2.is_valid():
                return {"error": "Missing credentials for database2"}
        else:
            creds2 = creds1
        info2 = self.inspector.describe_table(creds2, table2)

        if "error" in info1 or "error" in info2:
            return {
                "error": "Could not retrieve table information",
                "table1_error": info1.get("error"),
                "table2_error": info2.get("error")
            }

        cols1 = {col["name"]: col for col in info1.get("columns") or []}
        cols2 = {col["name"]: col for col in info2.get("columns") or []}

        # Sort the set results: plain set iteration order is not stable
        # across runs, which makes the tool output nondeterministic.
        only_in_table1 = sorted(cols1.keys() - cols2.keys())
        only_in_table2 = sorted(cols2.keys() - cols1.keys())
        common = sorted(cols1.keys() & cols2.keys())

        type_differences = []
        for col_name in common:
            type1 = cols1[col_name].get("type")
            type2 = cols2[col_name].get("type")
            # NOTE(review): if the inspector returns type objects rather than
            # strings, ``!=`` semantics depend on that class; confirm.
            if type1 != type2:
                type_differences.append({
                    "column": col_name,
                    "table1_type": type1,
                    "table2_type": type2
                })

        return {
            "table1": {
                "name": table1,
                # Resolved names: the arguments are None when the database
                # comes from the request context.
                "database": creds1.database,
                "column_count": len(cols1)
            },
            "table2": {
                "name": table2,
                "database": creds2.database,
                "column_count": len(cols2)
            },
            "comparison": {
                "identical_structure": (
                    not only_in_table1
                    and not only_in_table2
                    and not type_differences
                ),
                "columns_only_in_table1": only_in_table1,
                "columns_only_in_table2": only_in_table2,
                "common_columns": common,
                "type_differences": type_differences
            }
        }
# ============================================================================
# Main: Register custom services and run server
# ============================================================================
def main():
    """Run MCP server with custom services.

    Creates an ``MCPSQLServer`` and registers each example service via
    ``add_custom_tool``. Every service is given the server's
    connection manager, credentials manager, inspector and executor. Then
    the server is started on port 3939.
    """
    server = MCPSQLServer()

    # The literal previously held mojibake ("đź”§"): the UTF-8 bytes of the
    # wrench emoji decoded as cp1250.
    print("\n🔧 Adding custom services...")

    # All example services are constructed identically from the shared
    # server components, so register them in one loop.
    for service_cls in (
        DatabaseStatsService,
        ValidateSchemaService,
        CompareTablesService,
    ):
        service = service_cls(
            connection_manager=server.connection_manager,
            credentials_manager=server.credentials_manager,
            inspector=server.inspector,
            executor=server.executor,
        )
        server.add_custom_tool(service)

    server.run(port=3939)


if __name__ == "__main__":
    main()