# ruff: noqa: B008
import argparse
import asyncio
import logging
import os
import signal
import sys
from enum import Enum
from typing import Any
from typing import List
from typing import Literal
from typing import Union
import mcp.types as types
from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel
from pydantic import Field
from pydantic import validate_call
# Load environment variables from .env file if present
# This allows users to configure the server via .env file
load_dotenv()
from postgres_mcp.index.dta_calc import DatabaseTuningAdvisor # noqa: E402
from .artifacts import ErrorResult # noqa: E402
from .artifacts import ExplainPlanArtifact # noqa: E402
from .database_health import DatabaseHealthTool # noqa: E402
from .database_health import HealthType # noqa: E402
from .explain import ExplainPlanTool # noqa: E402
from .index.index_opt_base import MAX_NUM_INDEX_TUNING_QUERIES # noqa: E402
from .index.llm_opt import LLMOptimizerTool # noqa: E402
from .index.presentation import TextPresentation # noqa: E402
from .sql import DbConnPool # noqa: E402
from .sql import SafeSqlDriver # noqa: E402
from .sql import SqlDriver # noqa: E402
from .sql import check_hypopg_installation_status # noqa: E402
from .sql import obfuscate_password # noqa: E402
from .top_queries import TopQueriesCalc # noqa: E402
# Initialize FastMCP with default settings
mcp = FastMCP("postgres-mcp")
# Constants
PG_STAT_STATEMENTS = "pg_stat_statements"
HYPOPG_EXTENSION = "hypopg"
DEFAULT_QUERY_TIMEOUT = 30 # Default timeout in seconds for restricted mode
DEFAULT_SSE_HOST = "localhost"
DEFAULT_SSE_PORT = 8000
DEFAULT_SSE_PATH = "/sse"
class HypotheticalIndex(BaseModel):
"""Schema for hypothetical index definition used in explain_query."""
table: str = Field(description="The table name to add the index to (e.g., 'users')")
columns: list[str] = Field(description="List of column names to include in the index (e.g., ['email'] or ['last_name', 'first_name'])")
using: str = Field(default="btree", description="Index method (default: 'btree', other options include 'hash', 'gist', 'gin', 'brin')")
ResponseType = List[types.TextContent | types.ImageContent | types.EmbeddedResource]
logger = logging.getLogger(__name__)
class AccessMode(str, Enum):
"""SQL access modes for the server."""
UNRESTRICTED = "unrestricted" # Unrestricted access
RESTRICTED = "restricted" # Read-only with safety features
# Global variables
db_connection = DbConnPool()
current_access_mode = AccessMode.UNRESTRICTED
current_query_timeout = DEFAULT_QUERY_TIMEOUT
shutdown_in_progress = False
async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]:
"""Get the appropriate SQL driver based on the current access mode."""
base_driver = SqlDriver(conn=db_connection)
if current_access_mode == AccessMode.RESTRICTED:
logger.debug(f"Using SafeSqlDriver with restrictions (RESTRICTED mode, timeout={current_query_timeout}s)")
return SafeSqlDriver(sql_driver=base_driver, timeout=current_query_timeout)
else:
logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)")
return base_driver
def format_text_response(text: Any) -> ResponseType:
"""Format a text response."""
return [types.TextContent(type="text", text=str(text))]
def format_error_response(error: str) -> ResponseType:
"""Format an error response."""
return format_text_response(f"Error: {error}")
@mcp.tool(description="List all schemas in the database")
async def list_schemas() -> ResponseType:
"""List all schemas in the database."""
try:
sql_driver = await get_sql_driver()
rows = await sql_driver.execute_query(
"""
SELECT
schema_name,
schema_owner,
CASE
WHEN schema_name LIKE 'pg_%' THEN 'System Schema'
WHEN schema_name = 'information_schema' THEN 'System Information Schema'
ELSE 'User Schema'
END as schema_type
FROM information_schema.schemata
ORDER BY schema_type, schema_name
"""
)
schemas = [row.cells for row in rows] if rows else []
return format_text_response(schemas)
except Exception as e:
logger.error(f"Error listing schemas: {e}")
return format_error_response(str(e))
@mcp.tool(description="List objects in a schema")
async def list_objects(
schema_name: str = Field(description="Schema name"),
object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"),
) -> ResponseType:
"""List objects of a given type in a schema."""
try:
sql_driver = await get_sql_driver()
if object_type in ("table", "view"):
table_type = "BASE TABLE" if object_type == "table" else "VIEW"
rows = await SafeSqlDriver.execute_param_query(
sql_driver,
"""
SELECT
t.table_schema,
t.table_name,
t.table_type,
obj_description((quote_ident(t.table_schema) || '.' || quote_ident(t.table_name))::regclass, 'pg_class') AS comment
FROM information_schema.tables t
WHERE t.table_schema = {} AND t.table_type = {}
ORDER BY t.table_name
""",
[schema_name, table_type],
)
objects = (
[
{
"schema": row.cells["table_schema"],
"name": row.cells["table_name"],
"type": row.cells["table_type"],
"comment": row.cells["comment"],
}
for row in rows
]
if rows
else []
)
elif object_type == "sequence":
rows = await SafeSqlDriver.execute_param_query(
sql_driver,
"""
SELECT sequence_schema, sequence_name, data_type
FROM information_schema.sequences
WHERE sequence_schema = {}
ORDER BY sequence_name
""",
[schema_name],
)
objects = (
[{"schema": row.cells["sequence_schema"], "name": row.cells["sequence_name"], "data_type": row.cells["data_type"]} for row in rows]
if rows
else []
)
elif object_type == "extension":
# Extensions are not schema-specific
rows = await sql_driver.execute_query(
"""
SELECT extname, extversion, extrelocatable
FROM pg_extension
ORDER BY extname
"""
)
objects = (
[{"name": row.cells["extname"], "version": row.cells["extversion"], "relocatable": row.cells["extrelocatable"]} for row in rows]
if rows
else []
)
else:
return format_error_response(f"Unsupported object type: {object_type}")
return format_text_response(objects)
except Exception as e:
logger.error(f"Error listing objects: {e}")
return format_error_response(str(e))
@mcp.tool(description="Show detailed information about a database object")
async def get_object_details(
schema_name: str = Field(description="Schema name"),
object_name: str = Field(description="Object name"),
object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"),
) -> ResponseType:
"""Get detailed information about a database object."""
try:
sql_driver = await get_sql_driver()
if object_type in ("table", "view"):
# Get columns with comments
col_rows = await SafeSqlDriver.execute_param_query(
sql_driver,
"""
SELECT
c.column_name,
c.data_type,
c.is_nullable,
c.column_default,
col_description(
(quote_ident(c.table_schema) || '.' || quote_ident(c.table_name))::regclass,
c.ordinal_position
) AS comment
FROM information_schema.columns c
WHERE c.table_schema = {} AND c.table_name = {}
ORDER BY c.ordinal_position
""",
[schema_name, object_name],
)
columns = (
[
{
"column": r.cells["column_name"],
"data_type": r.cells["data_type"],
"is_nullable": r.cells["is_nullable"],
"default": r.cells["column_default"],
"comment": r.cells["comment"],
}
for r in col_rows
]
if col_rows
else []
)
# Get constraints
con_rows = await SafeSqlDriver.execute_param_query(
sql_driver,
"""
SELECT tc.constraint_name, tc.constraint_type, kcu.column_name
FROM information_schema.table_constraints AS tc
LEFT JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.table_schema = {} AND tc.table_name = {}
""",
[schema_name, object_name],
)
constraints = {}
if con_rows:
for row in con_rows:
cname = row.cells["constraint_name"]
ctype = row.cells["constraint_type"]
col = row.cells["column_name"]
if cname not in constraints:
constraints[cname] = {"type": ctype, "columns": []}
if col:
constraints[cname]["columns"].append(col)
constraints_list = [{"name": name, **data} for name, data in constraints.items()]
# Get indexes
idx_rows = await SafeSqlDriver.execute_param_query(
sql_driver,
"""
SELECT indexname, indexdef
FROM pg_indexes
WHERE schemaname = {} AND tablename = {}
""",
[schema_name, object_name],
)
indexes = [{"name": r.cells["indexname"], "definition": r.cells["indexdef"]} for r in idx_rows] if idx_rows else []
# Get table/view comment
comment_rows = await SafeSqlDriver.execute_param_query(
sql_driver,
"""
SELECT obj_description(
(quote_ident({}) || '.' || quote_ident({}))::regclass,
'pg_class'
) AS comment
""",
[schema_name, object_name],
)
table_comment = comment_rows[0].cells["comment"] if comment_rows else None
result = {
"basic": {"schema": schema_name, "name": object_name, "type": object_type, "comment": table_comment},
"columns": columns,
"constraints": constraints_list,
"indexes": indexes,
}
elif object_type == "sequence":
rows = await SafeSqlDriver.execute_param_query(
sql_driver,
"""
SELECT sequence_schema, sequence_name, data_type, start_value, increment
FROM information_schema.sequences
WHERE sequence_schema = {} AND sequence_name = {}
""",
[schema_name, object_name],
)
if rows and rows[0]:
row = rows[0]
result = {
"schema": row.cells["sequence_schema"],
"name": row.cells["sequence_name"],
"data_type": row.cells["data_type"],
"start_value": row.cells["start_value"],
"increment": row.cells["increment"],
}
else:
result = {}
elif object_type == "extension":
rows = await SafeSqlDriver.execute_param_query(
sql_driver,
"""
SELECT extname, extversion, extrelocatable
FROM pg_extension
WHERE extname = {}
""",
[object_name],
)
if rows and rows[0]:
row = rows[0]
result = {"name": row.cells["extname"], "version": row.cells["extversion"], "relocatable": row.cells["extrelocatable"]}
else:
result = {}
else:
return format_error_response(f"Unsupported object type: {object_type}")
return format_text_response(result)
except Exception as e:
logger.error(f"Error getting object details: {e}")
return format_error_response(str(e))
@mcp.tool(description="Explains the execution plan for a SQL query, showing how the database will execute it and provides detailed cost estimates.")
async def explain_query(
sql: str = Field(description="SQL query to explain"),
analyze: bool = Field(
description="When True, actually runs the query to show real execution statistics instead of estimates. "
"Takes longer but provides more accurate information.",
default=False,
),
hypothetical_indexes: list[HypotheticalIndex] = Field(
description="A list of hypothetical indexes to simulate. Each index requires 'table' and 'columns' fields, "
"with optional 'using' field for index method (default: 'btree'). "
"Example: [{'table': 'users', 'columns': ['email']}, {'table': 'orders', 'columns': ['user_id', 'created_at']}]",
default=[],
),
) -> ResponseType:
"""
Explains the execution plan for a SQL query.
Args:
sql: The SQL query to explain
analyze: When True, actually runs the query for real statistics
hypothetical_indexes: Optional list of indexes to simulate
"""
try:
sql_driver = await get_sql_driver()
explain_tool = ExplainPlanTool(sql_driver=sql_driver)
result: ExplainPlanArtifact | ErrorResult | None = None
# If hypothetical indexes are specified, check for HypoPG extension
if hypothetical_indexes and len(hypothetical_indexes) > 0:
if analyze:
return format_error_response("Cannot use analyze and hypothetical indexes together")
try:
# Use the common utility function to check if hypopg is installed
(
is_hypopg_installed,
hypopg_message,
) = await check_hypopg_installation_status(sql_driver)
# If hypopg is not installed, return the message
if not is_hypopg_installed:
return format_text_response(hypopg_message)
# HypoPG is installed, proceed with explaining with hypothetical indexes
# Convert Pydantic models to dicts for the explain tool
indexes_as_dicts = [idx.model_dump() for idx in hypothetical_indexes]
result = await explain_tool.explain_with_hypothetical_indexes(sql, indexes_as_dicts)
except Exception:
raise # Re-raise the original exception
elif analyze:
try:
# Use EXPLAIN ANALYZE
result = await explain_tool.explain_analyze(sql)
except Exception:
raise # Re-raise the original exception
else:
try:
# Use basic EXPLAIN
result = await explain_tool.explain(sql)
except Exception:
raise # Re-raise the original exception
if result and isinstance(result, ExplainPlanArtifact):
return format_text_response(result.to_text())
else:
error_message = "Error processing explain plan"
if isinstance(result, ErrorResult):
error_message = result.to_text()
return format_error_response(error_message)
except Exception as e:
logger.error(f"Error explaining query: {e}")
return format_error_response(str(e))
# Query function declaration without the decorator - we'll add it dynamically based on access mode
async def execute_sql(
sql: str = Field(description="SQL to run", default="all"),
) -> ResponseType:
"""Executes a SQL query against the database."""
try:
sql_driver = await get_sql_driver()
rows = await sql_driver.execute_query(sql) # type: ignore
if rows is None:
return format_text_response("No results")
return format_text_response(list([r.cells for r in rows]))
except Exception as e:
logger.error(f"Error executing query: {e}")
return format_error_response(str(e))
@mcp.tool(description="Analyze frequently executed queries in the database and recommend optimal indexes")
@validate_call
async def analyze_workload_indexes(
max_index_size_mb: int = Field(description="Max index size in MB", default=10000),
method: Literal["dta", "llm"] = Field(description="Method to use for analysis", default="dta"),
) -> ResponseType:
"""Analyze frequently executed queries in the database and recommend optimal indexes."""
try:
sql_driver = await get_sql_driver()
if method == "dta":
index_tuning = DatabaseTuningAdvisor(sql_driver)
else:
index_tuning = LLMOptimizerTool(sql_driver)
dta_tool = TextPresentation(sql_driver, index_tuning)
result = await dta_tool.analyze_workload(max_index_size_mb=max_index_size_mb)
return format_text_response(result)
except Exception as e:
logger.error(f"Error analyzing workload: {e}")
return format_error_response(str(e))
@mcp.tool(description="Analyze a list of (up to 10) SQL queries and recommend optimal indexes")
@validate_call
async def analyze_query_indexes(
queries: list[str] = Field(description="List of Query strings to analyze"),
max_index_size_mb: int = Field(description="Max index size in MB", default=10000),
method: Literal["dta", "llm"] = Field(description="Method to use for analysis", default="dta"),
) -> ResponseType:
"""Analyze a list of SQL queries and recommend optimal indexes."""
if len(queries) == 0:
return format_error_response("Please provide a non-empty list of queries to analyze.")
if len(queries) > MAX_NUM_INDEX_TUNING_QUERIES:
return format_error_response(f"Please provide a list of up to {MAX_NUM_INDEX_TUNING_QUERIES} queries to analyze.")
try:
sql_driver = await get_sql_driver()
if method == "dta":
index_tuning = DatabaseTuningAdvisor(sql_driver)
else:
index_tuning = LLMOptimizerTool(sql_driver)
dta_tool = TextPresentation(sql_driver, index_tuning)
result = await dta_tool.analyze_queries(queries=queries, max_index_size_mb=max_index_size_mb)
return format_text_response(result)
except Exception as e:
logger.error(f"Error analyzing queries: {e}")
return format_error_response(str(e))
@mcp.tool(
description="Analyzes database health. Here are the available health checks:\n"
"- index - checks for invalid, duplicate, and bloated indexes\n"
"- connection - checks the number of connection and their utilization\n"
"- vacuum - checks vacuum health for transaction id wraparound\n"
"- sequence - checks sequences at risk of exceeding their maximum value\n"
"- replication - checks replication health including lag and slots\n"
"- buffer - checks for buffer cache hit rates for indexes and tables\n"
"- constraint - checks for invalid constraints\n"
"- all - runs all checks\n"
"You can optionally specify a single health check or a comma-separated list of health checks. The default is 'all' checks."
)
async def analyze_db_health(
health_type: str = Field(
description=f"Optional. Valid values are: {', '.join(sorted([t.value for t in HealthType]))}.",
default="all",
),
) -> ResponseType:
"""Analyze database health for specified components.
Args:
health_type: Comma-separated list of health check types to perform.
Valid values: index, connection, vacuum, sequence, replication, buffer, constraint, all
"""
health_tool = DatabaseHealthTool(await get_sql_driver())
result = await health_tool.health(health_type=health_type)
return format_text_response(result)
@mcp.tool(
name="get_top_queries",
description=f"Reports the slowest or most resource-intensive queries using data from the '{PG_STAT_STATEMENTS}' extension.",
)
async def get_top_queries(
sort_by: str = Field(
description="Ranking criteria: 'total_time' for total execution time or 'mean_time' for mean execution time per call, or 'resources' "
"for resource-intensive queries",
default="resources",
),
limit: int = Field(description="Number of queries to return when ranking based on mean_time or total_time", default=10),
) -> ResponseType:
try:
sql_driver = await get_sql_driver()
top_queries_tool = TopQueriesCalc(sql_driver=sql_driver)
if sort_by == "resources":
result = await top_queries_tool.get_top_resource_queries()
return format_text_response(result)
elif sort_by == "mean_time" or sort_by == "total_time":
# Map the sort_by values to what get_top_queries_by_time expects
result = await top_queries_tool.get_top_queries_by_time(limit=limit, sort_by="mean" if sort_by == "mean_time" else "total")
else:
return format_error_response("Invalid sort criteria. Please use 'resources' or 'mean_time' or 'total_time'.")
return format_text_response(result)
except Exception as e:
logger.error(f"Error getting slow queries: {e}")
return format_error_response(str(e))
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description="PostgreSQL MCP Server")
parser.add_argument("database_url", help="Database connection URL", nargs="?")
parser.add_argument(
"--access-mode",
type=str,
choices=[mode.value for mode in AccessMode],
default=AccessMode.UNRESTRICTED.value,
help="Set SQL access mode: unrestricted (unrestricted) or restricted (read-only with protections)",
)
parser.add_argument(
"--transport",
type=str,
choices=["stdio", "sse"],
default="stdio",
help="Select MCP transport: stdio (default) or sse",
)
parser.add_argument(
"--sse-host",
type=str,
default=None,
help=f"Host to bind SSE server to (default: {DEFAULT_SSE_HOST}). Can also be set via SSE_HOST env var.",
)
parser.add_argument(
"--sse-port",
type=int,
default=None,
help=f"Port for SSE server (default: {DEFAULT_SSE_PORT}). Can also be set via SSE_PORT env var.",
)
parser.add_argument(
"--sse-path",
type=str,
default=None,
help=f"Path for SSE endpoint (default: {DEFAULT_SSE_PATH}). Can also be set via SSE_PATH env var.",
)
parser.add_argument(
"--cors-allow-origins",
type=str,
default=None,
help="Comma-separated list of allowed CORS origins for SSE transport (e.g., 'http://localhost:3000,https://example.com'). "
"Can also be set via CORS_ALLOW_ORIGINS env var. Use '*' to allow all origins (not recommended for production).",
)
parser.add_argument(
"--query-timeout",
type=int,
default=None,
help=f"Query timeout in seconds for restricted mode (default: {DEFAULT_QUERY_TIMEOUT}). Can also be set via QUERY_TIMEOUT env var.",
)
args = parser.parse_args()
# Store the access mode in the global variable
global current_access_mode
current_access_mode = AccessMode(args.access_mode)
# Set query timeout from CLI argument or environment variable
global current_query_timeout
if args.query_timeout is not None:
current_query_timeout = args.query_timeout
else:
env_timeout = os.environ.get("QUERY_TIMEOUT")
if env_timeout is not None:
try:
current_query_timeout = int(env_timeout)
except ValueError:
logger.warning(f"Invalid QUERY_TIMEOUT value '{env_timeout}', using default {DEFAULT_QUERY_TIMEOUT}")
current_query_timeout = DEFAULT_QUERY_TIMEOUT
# Add the query tool with a description appropriate to the access mode
if current_access_mode == AccessMode.UNRESTRICTED:
mcp.add_tool(execute_sql, description="Execute any SQL query")
else:
mcp.add_tool(execute_sql, description="Execute a read-only SQL query")
logger.info(f"Starting PostgreSQL MCP Server in {current_access_mode.upper()} mode")
# Get database URL from environment variable or command line
database_url = os.environ.get("DATABASE_URI", args.database_url)
if not database_url:
raise ValueError(
"Error: No database URL provided. Please specify via 'DATABASE_URI' environment variable or command-line argument.",
)
# Initialize database connection pool
try:
await db_connection.pool_connect(database_url)
logger.info("Successfully connected to database and initialized connection pool")
except Exception as e:
logger.warning(
f"Could not connect to database: {obfuscate_password(str(e))}",
)
logger.warning(
"The MCP server will start but database operations will fail until a valid connection is established.",
)
# Set up proper shutdown handling
try:
loop = asyncio.get_running_loop()
signals = (signal.SIGTERM, signal.SIGINT)
for s in signals:
loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown(s)))
except NotImplementedError:
# Windows doesn't support signals properly
logger.warning("Signal handling not supported on Windows")
pass
# Run the server with the selected transport (always async)
if args.transport == "stdio":
await mcp.run_stdio_async()
else:
# Set SSE host from CLI argument or environment variable
sse_host = args.sse_host
if sse_host is None:
sse_host = os.environ.get("SSE_HOST", DEFAULT_SSE_HOST)
# Set SSE port from CLI argument or environment variable
sse_port = args.sse_port
if sse_port is None:
env_port = os.environ.get("SSE_PORT")
if env_port is not None:
try:
sse_port = int(env_port)
except ValueError:
logger.warning(f"Invalid SSE_PORT value '{env_port}', using default {DEFAULT_SSE_PORT}")
sse_port = DEFAULT_SSE_PORT
else:
sse_port = DEFAULT_SSE_PORT
# Set SSE path from CLI argument or environment variable
sse_path = args.sse_path
if sse_path is None:
sse_path = os.environ.get("SSE_PATH", DEFAULT_SSE_PATH)
# Get CORS allowed origins from CLI argument or environment variable
cors_origins_str = args.cors_allow_origins
if cors_origins_str is None:
cors_origins_str = os.environ.get("CORS_ALLOW_ORIGINS")
# If CORS is configured, run with custom Starlette app
if cors_origins_str:
import uvicorn
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import Mount
# Parse CORS origins
if cors_origins_str == "*":
cors_origins = ["*"]
else:
cors_origins = [origin.strip() for origin in cors_origins_str.split(",")]
logger.info(f"Enabling CORS for origins: {cors_origins}")
# Create Starlette app with CORS middleware wrapping the SSE app
middleware = [
Middleware(
CORSMiddleware,
allow_origins=cors_origins,
allow_credentials=True,
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["*"],
)
]
# Update FastMCP settings for the SSE path
mcp.settings.sse_path = sse_path
app = Starlette(
routes=[Mount("/", app=mcp.sse_app())],
middleware=middleware,
)
logger.info(f"Starting SSE server with CORS on {sse_host}:{sse_port}{sse_path}")
config = uvicorn.Config(app, host=sse_host, port=sse_port, log_level="info")
server = uvicorn.Server(config)
await server.serve()
else:
# Update FastMCP settings based on command line arguments or env vars
mcp.settings.host = sse_host
mcp.settings.port = sse_port
mcp.settings.sse_path = sse_path
logger.info(f"Starting SSE server on {sse_host}:{sse_port}{sse_path}")
await mcp.run_sse_async()
async def shutdown(sig=None):
"""Clean shutdown of the server."""
global shutdown_in_progress
if shutdown_in_progress:
logger.warning("Forcing immediate exit")
# Use sys.exit instead of os._exit to allow for proper cleanup
sys.exit(1)
shutdown_in_progress = True
if sig:
logger.info(f"Received exit signal {sig.name}")
# Close database connections
try:
await db_connection.close()
logger.info("Closed database connections")
except Exception as e:
logger.error(f"Error closing database connections: {e}")
# Exit with appropriate status code
sys.exit(128 + sig if sig is not None else 0)