# main.py • 10.2 kB
import json
import logging
import os
import re
import socket
import sys
import threading
import time
import uuid
from typing import Any, Dict, List
from urllib.parse import quote

from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP
from pydantic import Field
from pymongo import MongoClient
from pymongo.errors import PyMongoError
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError
# --- Load env ---
# Populate os.environ from a .env file before any configuration is read.
load_dotenv()

# --- Postgres setup ---
# DATABASE_URL, when set, is used verbatim; otherwise the URL is assembled
# from the individual POSTGRES_* variables.
PG_URL = os.getenv("DATABASE_URL")
if not PG_URL:
    DB_USER = os.getenv("POSTGRES_USER", "postgres")
    DB_PASS = os.getenv("POSTGRES_PASSWORD", "")
    DB_HOST = os.getenv("POSTGRES_HOST", "localhost")
    DB_PORT = os.getenv("POSTGRES_PORT", "5432")
    DB_NAME = os.getenv("POSTGRES_DB", "postgres")
    # Credentials must be percent-encoded: an unescaped "@", ":" or "/" in the
    # password would be parsed as part of the host/port/database instead.
    # safe="" so that "/" is encoded too; an empty password stays empty.
    PG_URL = (
        f"postgresql://{quote(DB_USER, safe='')}:{quote(DB_PASS, safe='')}"
        f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
    )
pg_engine = create_engine(PG_URL, pool_size=5, max_overflow=10)

# --- Mongo setup ---
MONGO_URL = os.getenv("MONGODB_URL", "mongodb://localhost:27017")
MONGO_DBNAME = os.getenv("MONGODB_DB", "test")
mongo_client = MongoClient(MONGO_URL)
# Database handle shared by all Mongo tools.
mongo_db = mongo_client[MONGO_DBNAME]
# Configure logging
# File and level may be overridden via MCP_LOG_FILE / MCP_LOG_LEVEL; the
# defaults reproduce the previous hard-coded settings. A relative path is
# resolved against the process working directory, which may not be writable
# when a client launches the server, hence the override.
logging.basicConfig(
    filename=os.getenv("MCP_LOG_FILE", "database_mcp.log"),
    level=os.getenv("MCP_LOG_LEVEL", "INFO").upper(),
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
logger = logging.getLogger("database_mcp")

# --- Configuration ---
# Transport name passed to the server entry point ("stdio" by default).
transport = os.getenv("MCP_TRANSPORT", "stdio")
server_name = os.getenv("MCP_SERVER_NAME", "Database MCP Server")
# Only the value "true" (case-insensitive) enables broadcasting.
enable_broadcast = os.getenv("MCP_ENABLE_BROADCAST", "true").lower() == "true"
# Seconds between multicast announcements.
broadcast_interval = int(os.getenv("MCP_BROADCAST_INTERVAL", "30"))

# Multicast settings
# NOTE(review): 239.255.255.250 is the SSDP group address while 5353 is
# normally the mDNS port; confirm discovery clients expect this pairing
# before changing either value.
SSDP_ADDR = "239.255.255.250"
MCP_DISCOVERY_PORT = 5353
# --- Multicast Broadcaster ---
class MCPBroadcaster:
    """Broadcasts MCP server presence on multicast.

    ``start_broadcasting`` blocks, sending the JSON document built by
    ``create_announcement`` to ``SSDP_ADDR:MCP_DISCOVERY_PORT`` every
    ``broadcast_interval`` seconds until ``stop_broadcasting`` clears
    ``running``. Module-level configuration is read at call time.
    """

    def __init__(self, server_name: str, port: int, transport: str):
        self.server_name = server_name
        self.port = port
        self.transport = transport
        # Random per-instance id, repeated in every announcement.
        self.uuid = str(uuid.uuid4())
        # Loop flag polled by start_broadcasting; cleared by stop_broadcasting.
        self.running = False
        # UDP socket, created by start_broadcasting.
        self.sock = None

    def get_local_ip(self) -> str:
        """Return the local IPv4 address the OS selects to reach 8.8.8.8.

        Only a UDP ``connect`` is performed here (nothing is sent); the chosen
        source address is read back with ``getsockname``. Falls back to
        ``'127.0.0.1'`` when no route is available.
        """
        try:
            # ``with`` closes the probe socket even if connect() fails.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
                probe.connect(("8.8.8.8", 80))
                return probe.getsockname()[0]
        except OSError:
            # Narrow catch: a bare ``except:`` also swallowed
            # KeyboardInterrupt and SystemExit.
            return "127.0.0.1"

    def create_announcement(self) -> bytes:
        """Build the UTF-8 encoded JSON discovery announcement."""
        local_ip = self.get_local_ip()
        announcement = {
            "type": "mcp-announcement",
            "protocol": "MCP-DISCOVERY-v1",
            "uuid": self.uuid,
            "name": self.server_name,
            "host": local_ip,
            "port": self.port,
            "endpoint": "/mcp",
            "protocol_type": "MCP-HTTP",
            "transport": self.transport,
            "version": "1.0.0",
        }
        return json.dumps(announcement).encode("utf-8")

    def start_broadcasting(self):
        """Send announcements until ``stop_broadcasting`` is called.

        Blocks the calling thread (the entry point runs it on a daemon
        thread). Returns immediately, without opening a socket, when
        ``enable_broadcast`` is false. Send errors are logged and the loop
        continues.
        """
        if not enable_broadcast:
            logger.info("Broadcasting disabled")
            return
        self.running = True
        # Create UDP socket
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
        self.sock = sock
        # Multicast TTL of 2 bounds how far announcements propagate.
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 2)
        logger.info(f"Starting multicast broadcaster on {SSDP_ADDR}:{MCP_DISCOVERY_PORT}")
        logger.info(f"Broadcasting every {broadcast_interval} seconds")
        # NOTE(review): the leading glyph looks mis-encoded; confirm the intended character.
        print(f"š Broadcasting {self.server_name} on {SSDP_ADDR}:{MCP_DISCOVERY_PORT}")
        # Sleep in 0.1 s ticks so shutdown is responsive. Always wait at least
        # one tick: with an interval <= 0 the loop would otherwise spin and
        # send announcements as fast as the socket allows.
        ticks = max(1, broadcast_interval * 10)
        try:
            while self.running:
                try:
                    message = self.create_announcement()
                    sock.sendto(message, (SSDP_ADDR, MCP_DISCOVERY_PORT))
                    logger.debug(f"Broadcast sent: {message.decode('utf-8')}")
                except Exception as e:
                    if not self.running:
                        # stop_broadcasting() closed the socket mid-send; not an error.
                        break
                    logger.error(f"Broadcast error: {e}")
                for _ in range(ticks):
                    if not self.running:
                        break
                    time.sleep(0.1)
        finally:
            # Release the socket when the loop ends; a repeat close() after
            # stop_broadcasting() has already closed it is harmless.
            sock.close()

    def stop_broadcasting(self):
        """Ask the broadcast loop to exit and close its socket."""
        self.running = False
        if self.sock:
            self.sock.close()
        logger.info("Broadcasting stopped")
# --- MCP Server Setup ---
# Host/port/debug settings only apply to the HTTP-style transports; other
# transports get a FastMCP instance with default settings.
if transport in ("http", "streamable-http"):
    host = os.getenv("MCP_HOST", "0.0.0.0")
    port = int(os.getenv("MCP_PORT", "3000"))
    # Debug used to be hard-coded on. It stays on by default for compatibility
    # but can now be disabled (MCP_DEBUG=false), e.g. when bound to 0.0.0.0.
    http_debug = os.getenv("MCP_DEBUG", "true").lower() == "true"
    logger.info(f"Starting MCP server in {transport} mode at {host}:{port}")
    mcp = FastMCP(name=server_name, host=host, port=port, debug=http_debug)
else:
    mcp = FastMCP(name=server_name)
# --- Helpers ---
_valid_name_rx = re.compile(r"^[A-Za-z0-9_]+$")
def _validate_name(name: str) -> None:
if not _valid_name_rx.match(name):
raise ValueError("Invalid name. Only letters, numbers, and underscores allowed.")
def _row_to_dict(row) -> Dict[str, Any]:
try:
return dict(row._mapping)
except Exception:
return dict(zip(row.keys(), row))
# --- Postgres tools ---
@mcp.tool(title="Postgres: DB Version", description="Get PostgreSQL server version")
def pg_version() -> str:
    """Return the result of ``SELECT version()`` on the Postgres server.

    Unlike the other Postgres tools, failures are reported in-band: a
    SQLAlchemy error yields ``"Postgres error: ..."`` rather than raising.
    A falsy result (NULL/empty) yields ``"Unknown"``.
    """
    query = text("SELECT version()")
    try:
        with pg_engine.connect() as connection:
            version = connection.execute(query).scalar()
    except SQLAlchemyError as exc:
        return f"Postgres error: {exc}"
    if not version:
        return "Unknown"
    return version
@mcp.tool(
    title="Postgres: List tables",
    description="Return the list of tables in the current Postgres database."
)
def pg_list_tables(schema: str = Field("public", description="Schema name, default is 'public'")) -> List[str]:
    """Return the ``table_name`` entries of ``information_schema.tables`` for ``schema``, sorted.

    NOTE(review): information_schema.tables presumably also lists views;
    confirm whether only base tables are intended.

    Raises:
        RuntimeError: wrapping any SQLAlchemy error (also logged).
    """
    logger.debug(f"Listing tables in schema={schema}")
    query = text("""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = :schema
ORDER BY table_name
""")
    try:
        with pg_engine.connect() as connection:
            fetched = connection.execute(query, {"schema": schema}).fetchall()
    except SQLAlchemyError as exc:
        logger.error(f"Postgres error while listing tables: {exc}")
        raise RuntimeError(f"Postgres error: {exc}")
    # Each row has exactly one column (table_name).
    table_names = [name for (name,) in fetched]
    logger.info(f"Found {len(table_names)} tables in schema={schema}")
    return table_names
@mcp.tool(title="Postgres: List rows", description="List up to `limit` rows from table")
def pg_list_rows(table: str = Field(...), limit: int = Field(100)) -> List[Dict[str, Any]]:
    """Return at most ``limit`` rows of ``table``, each as a dict.

    ``table`` is whitelisted by ``_validate_name`` before being interpolated
    into the SQL text; ``limit`` is sent as a bound parameter.

    Raises:
        ValueError: if ``table`` is not a valid identifier.
        RuntimeError: wrapping any SQLAlchemy error.
    """
    _validate_name(table)
    statement = text(f"SELECT * FROM {table} LIMIT :limit")
    try:
        with pg_engine.connect() as connection:
            result = connection.execute(statement, {"limit": limit})
            return list(map(_row_to_dict, result.fetchall()))
    except SQLAlchemyError as exc:
        raise RuntimeError(f"Postgres error: {exc}")
@mcp.tool(title="Postgres: Insert row", description="Insert row into table and return id")
def pg_insert_row(table: str = Field(...), data: dict = Field(...)) -> str:
    """Insert one row into ``table`` and report its ``id`` when it has one.

    Column names are the keys of ``data``; each is validated with
    ``_validate_name`` before being interpolated into the SQL text, while the
    values are always sent as bound parameters. The insert runs inside
    ``pg_engine.begin()``, which commits on success.

    Returns:
        ``"Inserted id=<id>"`` when the inserted row has a non-NULL ``id``
        column, otherwise ``"Inserted"``.

    Raises:
        ValueError: if ``table`` or a column name is invalid, or ``data`` is empty.
        RuntimeError: wrapping any SQLAlchemy error.
    """
    _validate_name(table)
    if not data:
        raise ValueError("Data must be non-empty")
    cols = list(data)
    for col in cols:
        _validate_name(col)
    col_list = ", ".join(cols)
    bind_list = ", ".join(f":{c}" for c in cols)
    # RETURNING * rather than RETURNING id: with the latter, inserting into a
    # table that has no "id" column failed outright.
    sql = text(f"INSERT INTO {table} ({col_list}) VALUES ({bind_list}) RETURNING *")
    try:
        with pg_engine.begin() as conn:
            row = conn.execute(sql, dict(data)).mappings().first()
    except SQLAlchemyError as e:
        raise RuntimeError(f"Postgres error: {e}")
    inserted = row.get("id") if row is not None else None
    # "is not None" so a legitimate id of 0 is still reported.
    return f"Inserted id={inserted}" if inserted is not None else "Inserted"
# --- MongoDB tools ---
@mcp.tool(title="Mongo: List collections", description="List all collections in DB")
def mongo_list_collections() -> List[str]:
    """Return the collection names of the configured Mongo database.

    Raises:
        RuntimeError: wrapping any PyMongo error.
    """
    try:
        names = mongo_db.list_collection_names()
    except PyMongoError as exc:
        raise RuntimeError(f"Mongo error: {exc}")
    return names
@mcp.tool(title="Mongo: Find documents", description="Find documents in collection with optional filter")
def mongo_find(collection: str = Field(...), query: dict = Field(default_factory=dict), limit: int = Field(10)) -> List[Dict[str, Any]]:
    """Return up to ``limit`` documents of ``collection`` matching ``query``.

    Any ``_id`` value is converted with ``str()`` so results are
    JSON-serializable (e.g. ObjectId).

    Raises:
        ValueError: if ``collection`` is not a valid name or ``limit`` is negative.
        RuntimeError: wrapping any PyMongo error.
    """
    _validate_name(collection)
    if limit < 0:
        raise ValueError("limit must be non-negative")
    # PyMongo treats limit(0) as "no limit"; short-circuit so 0 means zero
    # documents (like SQL LIMIT 0) instead of the entire collection.
    if limit == 0:
        return []
    # NOTE(review): ``query`` is passed through unfiltered from the tool
    # caller, so server-side JS operators (e.g. $where) reach the database if
    # it allows them; confirm this is acceptable.
    try:
        # ``with`` closes the server-side cursor even if iteration fails.
        with mongo_db[collection].find(query).limit(limit) as cursor:
            docs = []
            for d in cursor:
                # Documents can lack _id (e.g. views that project it away).
                if "_id" in d:
                    d["_id"] = str(d["_id"])  # make ObjectId JSON-serializable
                docs.append(d)
            return docs
    except PyMongoError as e:
        raise RuntimeError(f"Mongo error: {e}")
@mcp.tool(title="Mongo: Insert document", description="Insert a document into collection")
def mongo_insert(collection: str = Field(...), doc: dict = Field(...)) -> str:
    """Insert ``doc`` into ``collection`` and report the generated ``_id``.

    Raises:
        ValueError: if ``collection`` is not a valid name.
        RuntimeError: wrapping any PyMongo error.
    """
    _validate_name(collection)
    try:
        new_id = mongo_db[collection].insert_one(doc).inserted_id
    except PyMongoError as exc:
        raise RuntimeError(f"Mongo error: {exc}")
    return "Inserted with id=" + str(new_id)
if __name__ == "__main__":
    is_http = transport in ("http", "streamable-http")
    # FastMCP.run() (mcp SDK) only accepts "stdio", "sse" and "streamable-http";
    # treat "http" as an alias so MCP_TRANSPORT=http does not fail at startup.
    run_transport = "streamable-http" if transport == "http" else transport
    http_host = os.getenv("MCP_HOST", "0.0.0.0")
    http_port = os.getenv("MCP_PORT", "3000")
    # The broadcaster is only started in HTTP mode, so report what actually runs.
    will_broadcast = is_http and enable_broadcast
    # With the stdio transport, stdout carries the protocol stream, so the
    # human-readable banner goes to stderr where it cannot corrupt it.
    out = sys.stderr

    logger.info(f"Starting MCP server with transport={transport}")
    print("=" * 60, file=out)
    print(f"Starting {server_name}", file=out)
    print("=" * 60, file=out)
    print(f"Transport: {transport}", file=out)
    if is_http:
        print(f"Host: {http_host}", file=out)
        print(f"Port: {http_port}", file=out)
        print(f"Endpoint: http://localhost:{http_port}/mcp", file=out)
    print(f"Broadcasting: {'ENABLED' if will_broadcast else 'DISABLED'}", file=out)
    if will_broadcast:
        print(f"Broadcast interval: {broadcast_interval}s", file=out)
        print(f"Multicast: {SSDP_ADDR}:{MCP_DISCOVERY_PORT}", file=out)
    print("=" * 60, file=out)

    # Start broadcaster if in HTTP mode
    broadcaster = None
    if will_broadcast:
        broadcaster = MCPBroadcaster(server_name, int(http_port), transport)
        broadcaster_thread = threading.Thread(target=broadcaster.start_broadcasting, daemon=True)
        broadcaster_thread.start()
    try:
        mcp.run(transport=run_transport)
    finally:
        if broadcaster:
            broadcaster.stop_broadcasting()