# main_b.py (6.22 kB)
import logging
import os
import re
from typing import Any, Dict, List
from urllib.parse import quote

import uvicorn
from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP
from pydantic import Field
from pymongo import MongoClient
from pymongo.errors import PyMongoError
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError
# --- Load env ---
# Must run before the os.getenv() lookups below so that values supplied via a
# .env file are visible to them.
load_dotenv()
# --- Postgres setup ---
# A complete DATABASE_URL takes precedence; otherwise the URL is assembled
# from the individual POSTGRES_* variables (each with a local default).
PG_URL = os.getenv("DATABASE_URL")
if not PG_URL:
    DB_USER = os.getenv("POSTGRES_USER", "postgres")
    DB_PASS = os.getenv("POSTGRES_PASSWORD", "")
    DB_HOST = os.getenv("POSTGRES_HOST", "localhost")
    DB_PORT = os.getenv("POSTGRES_PORT", "5432")
    DB_NAME = os.getenv("POSTGRES_DB", "postgres")
    # Percent-encode the credentials: characters such as "@", ":" or "/" in a
    # raw user name or password would otherwise be read as URL delimiters and
    # corrupt the host/port/database parts of the connection URL.
    PG_URL = (
        f"postgresql://{quote(DB_USER, safe='')}:{quote(DB_PASS, safe='')}"
        f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
    )
# Shared engine used by every Postgres tool in this module.
pg_engine = create_engine(PG_URL, pool_size=5, max_overflow=10)
# --- Mongo setup ---
# Database name and connection string come from the environment, falling back
# to the "test" database on mongodb://localhost:27017.
MONGO_DBNAME = os.environ.get("MONGODB_DB", "test")
MONGO_URL = os.environ.get("MONGODB_URL", "mongodb://localhost:27017")
mongo_client = MongoClient(MONGO_URL)
# Database handle used by every Mongo tool in this module.
mongo_db = mongo_client[MONGO_DBNAME]
# Configure logging
# Records at INFO and above go to the file database_mcp.log; the format is
# timestamp, level, logger name, message.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    level=logging.INFO,  # DEBUG, INFO, WARNING, ERROR
    filename="database_mcp.log",  # log file
)
# Module-wide logger shared by the server setup and the tools.
logger = logging.getLogger("database_mcp")
# --- MCP Server Setup ---
# MCP_TRANSPORT selects how the server is exposed; stdio is the default.
transport = os.getenv("MCP_TRANSPORT", "stdio")
if transport in ("http", "streamable-http"):
    # HTTP mode: bind address and port are read from MCP_HOST / MCP_PORT.
    host = os.getenv("MCP_HOST", "0.0.0.0")
    port = int(os.getenv("MCP_PORT", "3000"))
    logger.info(f"Starting MCP server in HTTP mode at {host}:{port}")
    mcp = FastMCP(name="Database MCP", host=host, port=port, debug=True)
else:
    # Any other transport value uses the FastMCP defaults.
    mcp = FastMCP(name="Database MCP")
# --- Helpers ---
# Identifiers that get interpolated into SQL text or used as collection names
# must consist solely of ASCII letters, digits and underscores.
_valid_name_rx = re.compile(r"^[A-Za-z0-9_]+$")


def _validate_name(name: str) -> None:
    """Reject any identifier that is not made up only of ``[A-Za-z0-9_]``.

    Args:
        name: Table, column or collection name supplied by the caller.

    Raises:
        ValueError: If ``name`` is not a string, is empty, or contains any
            character other than ASCII letters, digits and underscores.
    """
    # fullmatch() is required: with match(), the pattern's trailing "$" also
    # matches just before a final newline, so a name ending in a line break
    # would slip through and reach the interpolated SQL.
    if not isinstance(name, str) or not _valid_name_rx.fullmatch(name):
        raise ValueError("Invalid name. Only letters, numbers, and underscores allowed.")
def _row_to_dict(row) -> Dict[str, Any]:
    """Convert one result row into a plain ``dict`` keyed by column name.

    Args:
        row: A row object exposing either a ``_mapping`` attribute or a
            ``keys()`` method together with positional iteration.

    Returns:
        A new dict mapping each column name to the row's value.
    """
    try:
        mapping = row._mapping
    except AttributeError:
        # Row objects without ``_mapping``: pair keys() with the values.
        # Only the missing attribute triggers this path, so genuine
        # conversion errors are no longer silently swallowed.
        return dict(zip(row.keys(), row))
    return dict(mapping)
# --- Postgres tools ---
@mcp.tool(title="Postgres: DB Version", description="Get PostgreSQL server version")
def pg_version() -> str:
    """Return the result of ``SELECT version()`` on the Postgres server.

    Returns:
        The version string, ``"Unknown"`` when the query yields an empty
        value, or a ``"Postgres error: ..."`` message when the query fails.
        Failures are reported in the return value rather than raised.
    """
    query = text("SELECT version()")
    try:
        with pg_engine.connect() as connection:
            version = connection.execute(query).scalar()
    except SQLAlchemyError as exc:
        return f"Postgres error: {exc}"
    return version or "Unknown"
@mcp.tool(
    title="Postgres: List tables",
    description="Return the list of tables in the current Postgres database."
)
def pg_list_tables(schema: str = Field("public", description="Schema name, default is 'public'")) -> List[str]:
    """List the table names of one schema, sorted by name.

    Args:
        schema: Schema to inspect; bound as a query parameter.

    Returns:
        Table names taken from ``information_schema.tables``.

    Raises:
        RuntimeError: If the query fails with a SQLAlchemy error (the failure
            is also written to the log).
    """
    logger.debug(f"Listing tables in schema={schema}")
    query = text(
        "\n"
        "SELECT table_name\n"
        "FROM information_schema.tables\n"
        "WHERE table_schema = :schema\n"
        "ORDER BY table_name\n"
    )
    try:
        with pg_engine.connect() as connection:
            result = connection.execute(query, {"schema": schema})
            # Each row has a single column: the table name.
            names = [record[0] for record in result.fetchall()]
            logger.info(f"Found {len(names)} tables in schema={schema}")
            return names
    except SQLAlchemyError as exc:
        logger.error(f"Postgres error while listing tables: {exc}")
        raise RuntimeError(f"Postgres error: {exc}")
@mcp.tool(title="Postgres: List rows", description="List up to `limit` rows from table")
def pg_list_rows(table: str = Field(...), limit: int = Field(100)) -> List[Dict[str, Any]]:
    """Fetch at most ``limit`` rows from ``table`` as dictionaries.

    Args:
        table: Table name; must pass ``_validate_name``.
        limit: Maximum number of rows, bound as a query parameter.

    Returns:
        One dict per row, keyed by column name.

    Raises:
        ValueError: If ``table`` is not a valid identifier.
        RuntimeError: If the query fails with a SQLAlchemy error.
    """
    # The table name is interpolated into the statement text, so it is
    # checked against the identifier whitelist first.
    _validate_name(table)
    query = text(f"SELECT * FROM {table} LIMIT :limit")
    try:
        with pg_engine.connect() as connection:
            result = connection.execute(query, {"limit": limit})
            records = result.fetchall()
            return list(map(_row_to_dict, records))
    except SQLAlchemyError as exc:
        raise RuntimeError(f"Postgres error: {exc}")
@mcp.tool(title="Postgres: Insert row", description="Insert row into table and return id")
def pg_insert_row(table: str = Field(...), data: dict = Field(...)) -> str:
    """Insert one row into ``table`` and report the value of its ``id`` column.

    Args:
        table: Target table name; must pass ``_validate_name``.
        data: Mapping of column name to value. Column names must pass
            ``_validate_name``; values are sent as bound parameters.

    Returns:
        ``"Inserted id=<id>"`` when the statement returns an id, otherwise
        ``"Inserted"``.

    Raises:
        ValueError: If ``data`` is empty or any name is not a valid identifier.
        RuntimeError: If the statement fails with a SQLAlchemy error.
    """
    _validate_name(table)
    if not data:
        raise ValueError("Data must be non-empty")
    # Column names are interpolated into the statement text, so every one is
    # checked against the identifier whitelist; values stay bound parameters.
    for column in data:
        _validate_name(column)
    col_list = ", ".join(data)
    bind_list = ", ".join(f":{column}" for column in data)
    sql = text(f"INSERT INTO {table} ({col_list}) VALUES ({bind_list}) RETURNING id")
    try:
        # begin() commits when the block exits without an exception.
        with pg_engine.begin() as conn:
            inserted = conn.execute(sql, dict(data)).scalar_one_or_none()
    except SQLAlchemyError as exc:
        raise RuntimeError(f"Postgres error: {exc}") from exc
    # Compare against None rather than truthiness so an id of 0 is reported.
    return f"Inserted id={inserted}" if inserted is not None else "Inserted"
# --- MongoDB tools ---
@mcp.tool(title="Mongo: List collections", description="List all collections in DB")
def mongo_list_collections() -> List[str]:
    """Return the collection names of the configured Mongo database.

    Raises:
        RuntimeError: If the driver reports a PyMongo error.
    """
    try:
        names = mongo_db.list_collection_names()
    except PyMongoError as exc:
        raise RuntimeError(f"Mongo error: {exc}")
    return names
@mcp.tool(title="Mongo: Find documents", description="Find documents in collection with optional filter")
def mongo_find(collection: str = Field(...), query: dict = Field(default_factory=dict), limit: int = Field(10)) -> List[Dict[str, Any]]:
    """Return up to ``limit`` documents from ``collection`` matching ``query``.

    Args:
        collection: Collection name; must pass ``_validate_name``.
        query: Filter document passed to ``find`` unchanged.
        limit: Value passed to the cursor's ``limit``.

    Returns:
        The matching documents, each with ``_id`` converted to ``str``.

    Raises:
        ValueError: If ``collection`` is not a valid identifier.
        RuntimeError: If the driver reports a PyMongo error.
    """
    _validate_name(collection)
    try:
        matches = list(mongo_db[collection].find(query).limit(limit))
        for document in matches:
            # make ObjectId JSON-serializable
            document["_id"] = str(document["_id"])
        return matches
    except PyMongoError as exc:
        raise RuntimeError(f"Mongo error: {exc}")
@mcp.tool(title="Mongo: Insert document", description="Insert a document into collection")
def mongo_insert(collection: str = Field(...), doc: dict = Field(...)) -> str:
    """Insert ``doc`` into ``collection`` and report the new document's id.

    Args:
        collection: Collection name; must pass ``_validate_name``.
        doc: Document handed to ``insert_one``.

    Returns:
        ``"Inserted with id=<id>"`` with the id rendered via ``str``.

    Raises:
        ValueError: If ``collection`` is not a valid identifier.
        RuntimeError: If the driver reports a PyMongo error.
    """
    _validate_name(collection)
    try:
        outcome = mongo_db[collection].insert_one(doc)
        return f"Inserted with id={outcome.inserted_id!s}"
    except PyMongoError as exc:
        raise RuntimeError(f"Mongo error: {exc}")
if __name__ == "__main__":
    # The server setup treats MCP_TRANSPORT=http as an alias for
    # streamable-http, but FastMCP.run() only recognises "stdio", "sse" and
    # "streamable-http"; translate the alias so "http" does not fail at start.
    run_transport = "streamable-http" if transport == "http" else transport
    logger.info(f"Starting MCP server with transport={run_transport}")
    mcp.run(transport=run_transport)