import os
import logging
import psycopg2
from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP
# Logging: INFO and above, one line per record with timestamp, logger name
# and level. basicConfig is given no stream, so records go to stderr.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Connection settings. load_dotenv() runs first so that values from a .env
# file can be picked up by the os.getenv() lookups below; each setting falls
# back to the shown default when the variable is unset.
load_dotenv()
DB_HOST = os.getenv("DB_HOST", "localhost")
DB_PORT = int(os.getenv("DB_PORT", "5432"))  # ValueError at import if not an integer
DB_NAME = os.getenv("DB_NAME", "postgres")
DB_USER = os.getenv("DB_USER", "postgres")
DB_PASSWORD = os.getenv("DB_PASSWORD", "")

# The MCP server instance; the @app.tool() functions in this module register on it.
app = FastMCP("postgres-mcp")
def get_connection(database=None):
    """Open a new psycopg2 connection to the configured Postgres server.

    Args:
        database: Name of the database to connect to. Any falsy value
            (``None``, ``""``) falls back to the module-level ``DB_NAME``.

    Returns:
        A freshly opened psycopg2 connection. The caller owns it and must
        close it.

    Raises:
        Whatever ``psycopg2.connect`` raises; the failure is logged at ERROR
        level before being re-raised unchanged.
    """
    target = database if database else DB_NAME
    logger.info(f"Connecting to database: {target} at {DB_HOST}:{DB_PORT}")
    try:
        connection = psycopg2.connect(
            host=DB_HOST,
            port=DB_PORT,
            database=target,
            user=DB_USER,
            password=DB_PASSWORD,
        )
    except Exception as exc:
        logger.error(f"Failed to connect to database {target}: {exc}")
        raise
    logger.info(f"Successfully connected to database: {target}")
    return connection
@app.tool()
def list_databases():
    """List all available databases in the Postgres instance."""
    # (The docstring above doubles as the MCP tool description.)
    # Connects to the default DB_NAME database, since pg_database is a shared
    # catalog. Rows flagged datistemplate are skipped; names are sorted. The
    # result is one comma-separated line.
    logger.info("Executing list_databases tool")
    conn = get_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute("""
                SELECT datname
                FROM pg_database
                WHERE datistemplate = false
                ORDER BY datname
            """)
            names = [name for (name,) in cursor.fetchall()]
        joined = ", ".join(names)
        logger.info(f"Found {len(names)} databases: {joined}")
        return f"Databases: {joined}"
    except Exception as exc:
        logger.error(f"Error listing databases: {exc}")
        raise
    finally:
        conn.close()
        logger.debug("Closed database connection for list_databases")
@app.tool()
def list_schemas(database: str):
    """List all schemas in the specified database."""
    # (The docstring above doubles as the MCP tool description.)
    # Hides information_schema and every schema whose name begins with the
    # literal prefix "pg_" (system schemas such as pg_catalog and pg_toast).
    # "_" is a single-character wildcard in LIKE, so an unescaped 'pg_%' would
    # also hide ordinary user schemas like "pgadmin"; the underscore is
    # therefore escaped with ESCAPE '!'.
    # No query parameters are passed, so psycopg2 sends the "%" to the server
    # unchanged -- if parameters are ever added it must be written as "%%".
    logger.info(f"Executing list_schemas tool for database: {database}")
    conn = get_connection(database)
    try:
        with conn.cursor() as cur:
            cur.execute("""
                SELECT schema_name
                FROM information_schema.schemata
                WHERE schema_name NOT LIKE 'pg!_%' ESCAPE '!'
                  AND schema_name != 'information_schema'
                ORDER BY schema_name
            """)
            schemas = [row[0] for row in cur.fetchall()]
        logger.info(f"Found {len(schemas)} schemas in {database}: {', '.join(schemas)}")
        return f"Schemas in {database}: {', '.join(schemas)}"
    except Exception as e:
        logger.error(f"Error listing schemas in {database}: {e}")
        raise
    finally:
        conn.close()
        logger.debug(f"Closed database connection for list_schemas ({database})")
@app.tool()
def list_tables(database: str, schema: str):
    """List all tables in the specified schema of the database."""
    # (The docstring above doubles as the MCP tool description.)
    # Only rows with table_type = 'BASE TABLE' are returned, so views and other
    # table_type values reported by information_schema.tables are skipped.
    # The schema name is bound as a query parameter, not interpolated.
    qualified = f"{database}.{schema}"
    logger.info(f"Executing list_tables tool for database: {database}, schema: {schema}")
    conn = get_connection(database)
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = %s
                  AND table_type = 'BASE TABLE'
                ORDER BY table_name
                """,
                (schema,),
            )
            table_names = [record[0] for record in cur.fetchall()]
        listing = ", ".join(table_names)
        logger.info(f"Found {len(table_names)} tables in {qualified}: {listing}")
        return f"Tables in {qualified}: {listing}"
    except Exception as exc:
        logger.error(f"Error listing tables in {qualified}: {exc}")
        raise
    finally:
        conn.close()
        logger.debug(f"Closed database connection for list_tables ({qualified})")
@app.tool()
def get_table_schema(database: str, schema: str, table: str):
    """Get the schema (columns) of the specified table."""
    # (The docstring above doubles as the MCP tool description.)
    # Output: a "Schema for db.schema.table:" header, then one line per column
    # in ordinal_position order, shaped
    #   "<name>: <data_type> (nullable|not null)[ default: <expr>]".
    # schema and table are bound as query parameters, not interpolated.
    full_name = f"{database}.{schema}.{table}"
    logger.info(f"Executing get_table_schema tool for {full_name}")
    conn = get_connection(database)
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT column_name, data_type, is_nullable, column_default
                FROM information_schema.columns
                WHERE table_schema = %s
                  AND table_name = %s
                ORDER BY ordinal_position
                """,
                (schema, table),
            )
            rows = cur.fetchall()
        column_lines = [
            "{}: {} {}{}".format(
                name,
                dtype,
                "(not null)" if nullable_flag != "YES" else "(nullable)",
                # A default is shown only when the column actually has one.
                f" default: {default_expr}" if default_expr else "",
            )
            for name, dtype, nullable_flag, default_expr in rows
        ]
        logger.info(f"Retrieved schema for {full_name} with {len(column_lines)} columns")
        return f"Schema for {full_name}:\n" + "\n".join(column_lines)
    except Exception as exc:
        logger.error(f"Error getting schema for {full_name}: {exc}")
        raise
    finally:
        conn.close()
        logger.debug(f"Closed database connection for get_table_schema ({full_name})")
@app.tool()
def execute_query(database: str, query: str):
    """Execute a read-only SQL query and return results.

    Only a single SELECT statement is accepted (trailing semicolons are
    allowed); it runs in a read-only transaction.
    """
    # The docstring above doubles as the MCP tool description.
    # Returns "Query results:" followed by one str(row) line per result row.
    # Raises ValueError, before any connection is opened, for non-SELECT or
    # multi-statement input.
    logger.info(f"Executing execute_query tool for database: {database}")
    logger.debug(f"Query: {query}")
    statement = query.strip()
    if not statement.upper().startswith("SELECT"):
        logger.warning(f"Rejected non-SELECT query: {query}")
        raise ValueError("Only SELECT queries are allowed")
    # Without parameters psycopg2 sends the string as-is, and the server runs
    # every ";"-separated statement in it. So "SELECT 1; COMMIT; DROP TABLE t"
    # passes the prefix check above yet still writes. Reject any ";" other
    # than trailing ones. This is deliberately conservative: a ";" inside a
    # string literal or comment is rejected too.
    if ";" in statement.rstrip("; \t\r\n"):
        logger.warning(f"Rejected multi-statement query: {query}")
        raise ValueError("Only a single SELECT statement is allowed")
    conn = get_connection(database)
    try:
        # Defence in depth: make the transaction psycopg2 opens for this query
        # READ ONLY, so a SELECT that calls data-modifying functions (or uses
        # SELECT ... INTO) fails instead of writing. Nothing is ever committed.
        # NOTE(review): a read-only transaction is not a full security
        # boundary (e.g. extensions that open their own connections are
        # unaffected); run the server with a role that has only SELECT
        # privileges.
        conn.set_session(readonly=True)
        with conn.cursor() as cur:
            cur.execute(query)
            rows = cur.fetchall()
        result = "\n".join(str(row) for row in rows)
        logger.info(f"Query executed successfully, returned {len(rows)} rows")
        return f"Query results:\n{result}"
    except Exception as e:
        logger.error(f"Error executing query in {database}: {e}")
        raise
    finally:
        conn.close()
        logger.debug(f"Closed database connection for execute_query ({database})")
if __name__ == "__main__":
    # Script entry point. app.run() blocks for the whole lifetime of the
    # server, so an exception here may come from startup or from a later
    # crash while serving. logger.exception records the traceback; the
    # exception is then re-raised so the process still exits with an error.
    logger.info("Starting Postgres MCP server")
    try:
        app.run()
    except Exception:
        logger.exception("Postgres MCP server failed to start or stopped with an error")
        raise