# mcp_postgres_server.py (3.72 kB)
import argparse
import hmac
from typing import Any

import uvicorn
from mcp.server.fastmcp import FastMCP
from mcp.types import TextContent
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from config import load_settings
from db import DatabaseError, PostgresConnector
# Module-level singletons shared by every tool/resource handler in this file.
# They are created at import time, so importing this module already calls
# load_settings() and constructs the connector.
settings = load_settings()
# Connector built from the loaded settings; all handlers delegate to it.
db = PostgresConnector(settings)
# FastMCP instance that the @server.tool / @server.resource decorators in this
# file register against. The instructions text is passed to FastMCP as-is.
server = FastMCP(
    name="postgres-mcp",
    instructions=(
        "Postgres-backed MCP server. Use run_read_query for SELECT-only tasks "
        "and run_write_query for writes that must stay inside the mcp schema."
    ),
)
@server.tool(
    name="describe_database",
    description="List schemas, tables, and columns. Optionally narrow to a schema name.",
)
def describe_database(schema: str | None = None) -> dict[str, dict[str, list[dict[str, Any]]]]:
    """Return the database layout as reported by ``db.describe``.

    Args:
        schema: Optional schema name forwarded unchanged to the connector;
            ``None`` is forwarded as-is.

    Returns:
        Whatever ``db.describe(schema)`` yields. The annotation suggests a
        nested schema -> table -> column-records mapping (NOTE(review):
        confirm against the connector implementation).
    """
    layout = db.describe(schema)
    return layout
@server.tool(
    name="run_read_query",
    description="Execute a SELECT query with an automatic LIMIT if none provided.",
)
def run_read_query(sql: str, limit: int = 200) -> dict[str, Any]:
    """Run *sql* through the connector's read path and package the outcome.

    Args:
        sql: Statement text handed to ``db.run_read_query`` unchanged.
        limit: Forwarded as the ``limit`` keyword (default 200). How the
            connector applies it is not visible here.

    Returns:
        A dict with ``"rowcount"`` and ``"rows"`` copied from the connector's
        result object.
    """
    outcome = db.run_read_query(sql, limit=limit)
    payload: dict[str, Any] = {}
    payload["rowcount"] = outcome.rowcount
    payload["rows"] = outcome.rows
    return payload
@server.tool(
    name="run_write_query",
    description=(
        "Execute a single write statement confined to the mcp schema. "
        "Schema-qualified targets other than mcp are blocked."
    ),
)
def run_write_query(sql: str) -> dict[str, Any]:
    """Run *sql* through the connector's write path and package the outcome.

    Any schema restriction is enforced (if at all) inside
    ``db.run_write_query``; this wrapper performs no checks of its own.

    Args:
        sql: Statement text handed to ``db.run_write_query`` unchanged.

    Returns:
        A dict with ``"rowcount"`` and ``"rows"`` copied from the connector's
        result object.
    """
    outcome = db.run_write_query(sql)
    payload: dict[str, Any] = {}
    payload["rowcount"] = outcome.rowcount
    payload["rows"] = outcome.rows
    return payload
@server.resource("resource://postgres/status", name="status", description="Current server configuration.")
def server_status() -> str:
    """Return a one-line summary of the configured connection target.

    The text is formatted purely from ``settings``; no connection attempt is
    made here, so it reflects configuration rather than live connectivity.
    """
    endpoint = f"Connected host={settings.db_address}:{settings.db_port} "
    identity = f"dbname={settings.db_name} user={settings.db_user} schema=mcp"
    return endpoint + identity
@server.tool(name="health_check", description="Validate the connection to Postgres.")
def health_check() -> TextContent:
    """Probe the database by describing the ``mcp`` schema.

    Returns:
        ``TextContent`` with text ``"ok"`` when ``db.describe`` completes, or
        ``"error: <message>"`` when it raises ``DatabaseError``. Any other
        exception type propagates to the caller.
    """
    try:
        # Only the probe sits inside the try; its result is discarded.
        db.describe(schema="mcp")
    except DatabaseError as err:
        return TextContent(type="text", text=f"error: {err}")
    return TextContent(type="text", text="ok")
class BearerAuthMiddleware(BaseHTTPMiddleware):
    """Starlette middleware that gates every request behind a static bearer token.

    A request must carry ``Authorization: Bearer <token>`` with a token equal
    to the one given at construction time. Anything else is answered with a
    401 JSON error and never reaches the wrapped application.
    """

    def __init__(self, app, token: str):
        """Wrap *app* and remember the *token* that clients must present."""
        super().__init__(app)
        self.token = token

    @staticmethod
    def _unauthorized(message: str) -> JSONResponse:
        """Build a 401 reply carrying *message* and a ``Bearer`` challenge.

        HTTP requires a 401 response to include ``WWW-Authenticate`` so the
        client knows which scheme to retry with.
        """
        return JSONResponse(
            {"error": message},
            status_code=401,
            headers={"WWW-Authenticate": "Bearer"},
        )

    async def dispatch(self, request, call_next):
        """Authenticate *request*, then forward it via *call_next*.

        Returns:
            The downstream response on success; otherwise a 401 JSON response
            with ``{"error": "missing bearer token"}`` (no or non-bearer
            Authorization header) or ``{"error": "invalid token"}`` (token
            mismatch).
        """
        header = request.headers.get("authorization", "")
        scheme, separator, supplied = header.partition(" ")
        # Authentication scheme names are case-insensitive, so accept
        # "bearer"/"BEARER" as well as the canonical "Bearer".
        if not separator or scheme.lower() != "bearer":
            return self._unauthorized("missing bearer token")
        # Constant-time comparison avoids leaking, through response timing,
        # how many leading characters of a guess were correct. Encoding to
        # bytes keeps compare_digest from raising on non-ASCII header text.
        tokens_match = hmac.compare_digest(
            supplied.encode("utf-8"),
            self.token.encode("utf-8"),
        )
        if not tokens_match:
            return self._unauthorized("invalid token")
        return await call_next(request)
def main() -> None:
    """Parse command-line flags and start the MCP server.

    With ``--transport stdio`` (the default) the FastMCP server runs directly.
    With ``--transport streamable-http`` the streamable-HTTP app is served by
    uvicorn on ``--host``/``--port``, with ``BearerAuthMiddleware`` attached
    only when ``settings.mcp_api_key`` is truthy.
    """
    cli = argparse.ArgumentParser(description="Postgres MCP server")
    cli.add_argument(
        "--transport",
        choices=["stdio", "streamable-http"],
        default="stdio",
        help="Transport to use for MCP (default: stdio).",
    )
    cli.add_argument("--host", default="127.0.0.1", help="Host for HTTP transport (default: 127.0.0.1).")
    cli.add_argument("--port", type=int, default=8000, help="Port for HTTP transport (default: 8000).")
    options = cli.parse_args()

    # Mirror the CLI host/port onto the FastMCP settings regardless of transport.
    server.settings.host = options.host
    server.settings.port = options.port

    # stdio needs no HTTP stack: run it and stop here.
    if options.transport == "stdio":
        server.run(options.transport)
        return

    http_app = server.streamable_http_app()
    api_key = settings.mcp_api_key
    if api_key:
        # Without a configured key the HTTP endpoint is served unauthenticated.
        http_app.add_middleware(BearerAuthMiddleware, token=api_key)
    uvicorn.run(http_app, host=options.host, port=options.port, log_level="info")
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()