"""PostgreSQL tools for MCP server."""
import csv
import io
import logging
import os
import re
from typing import Any, Dict, List, Optional

import psycopg2
import psycopg2.extras
# Module-level logger; get_db_connection reports connection failures through it.
logger = logging.getLogger(__name__)
def get_db_connection():
    """Open a new PostgreSQL connection configured from environment variables.

    Settings come from POSTGRES_HOST (default "localhost"), POSTGRES_PORT
    (default 5432), POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_DATABASE and
    POSTGRES_SSL_MODE (default "prefer").

    Returns:
        A fresh psycopg2 connection; the caller must close it.

    Raises:
        Whatever parsing the port or connecting raises; the error is logged
        first and then re-raised unchanged.
    """
    env = os.getenv
    try:
        # The port is parsed inside the try so a malformed value is logged too.
        settings = {
            "host": env("POSTGRES_HOST", "localhost"),
            "port": int(env("POSTGRES_PORT", 5432)),
            "user": env("POSTGRES_USER"),
            "password": env("POSTGRES_PASSWORD"),
            "database": env("POSTGRES_DATABASE"),
            "sslmode": env("POSTGRES_SSL_MODE", "prefer"),
        }
        return psycopg2.connect(**settings)
    except Exception as e:
        logger.error(f"Failed to connect to PostgreSQL: {e}")
        raise
# Matches LIMIT only as a whole keyword, so identifiers such as "rate_limits"
# or "limit_count" no longer suppress the default row cap.
_LIMIT_KEYWORD_RE = re.compile(r"\bLIMIT\b", re.IGNORECASE)


def _apply_default_limit(query: str, limit: int) -> str:
    """Return *query* with ``LIMIT limit`` appended if it is a SELECT lacking one.

    Trailing whitespace and semicolons are removed before appending, so the
    clause never lands after a statement terminator. The clause goes on its
    own line so a trailing ``--`` comment cannot swallow it.
    """
    stripped = query.strip()
    if not stripped.upper().startswith("SELECT") or _LIMIT_KEYWORD_RE.search(stripped):
        return query
    body = stripped.rstrip("; \t\r\n\f\v")
    return f"{body}\nLIMIT {limit}"


def _format_table(rows: List[Dict[str, Any]]) -> str:
    """Render non-empty *rows* as a fixed-width, pipe-separated text table."""
    headers = list(rows[0].keys())
    header_row = " | ".join(f"{h:<20}" for h in headers)
    lines = [header_row, "-" * len(header_row)]
    lines.extend(" | ".join(f"{str(row[h]):<20}" for h in headers) for row in rows)
    return "\n".join(lines)


async def execute_query(query: str, limit: int = 1000) -> str:
    """Execute a SQL statement and return a human-readable summary.

    A query starting with ``SELECT`` that has no ``LIMIT`` keyword is capped
    at *limit* rows. Any statement that produces a result set is rendered as
    a text table; other statements report the affected row count. Every
    statement is committed except a plain ``SELECT`` that returns rows.
    ``SELECT``, ``WITH ... SELECT``, ``... RETURNING`` and ``SHOW`` all count
    as producing a result set.

    NOTE: psycopg2 is synchronous, so this coroutine blocks the event loop
    while the statement runs.

    Args:
        query: SQL text to execute.
        limit: Row cap appended to un-limited ``SELECT`` queries.

    Returns:
        A status line, followed by a table when rows were returned.

    Raises:
        Exception: wrapping any error from connecting or executing; the
            original error is chained as ``__cause__``.
    """
    conn = None
    cur = None
    try:
        is_select = query.strip().upper().startswith("SELECT")
        conn = get_db_connection()
        cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
        cur.execute(_apply_default_limit(query, limit))

        # description is None exactly when the statement produced no result set.
        if cur.description is not None:
            rows = [dict(row) for row in cur.fetchall()]
        else:
            rows = None

        if rows is None or not is_select:
            # Persist writes, including INSERT/UPDATE ... RETURNING and
            # SELECT ... INTO. Row-returning SELECTs stay uncommitted, as before.
            conn.commit()

        if rows is None:
            return f"Query executed successfully. {cur.rowcount} rows affected."
        if not rows:
            return "Query executed successfully. No results returned."
        return (
            f"Query executed successfully. {len(rows)} rows returned:\n\n"
            + _format_table(rows)
        )
    except Exception as e:
        if conn is not None:
            try:
                conn.rollback()
            except Exception:
                # A broken connection must not mask the original error.
                logger.warning("Rollback failed after query error", exc_info=True)
        raise Exception(f"Query execution failed: {str(e)}") from e
    finally:
        if cur is not None:
            cur.close()
        if conn is not None:
            conn.close()
async def list_tables(schema: str = "public") -> str:
    """List the relations in *schema*, sorted by name.

    Every row of ``information_schema.tables`` for the schema is listed with
    its ``table_type`` shown next to the name.

    Args:
        schema: Schema to inspect; bound as a query parameter.

    Returns:
        A bulleted list of ``name (table_type)`` lines, or a "No tables found"
        message when nothing matches.

    Raises:
        Exception: wrapping any database error, chained as ``__cause__``.
    """
    query = """
        SELECT table_name, table_type
        FROM information_schema.tables
        WHERE table_schema = %s
        ORDER BY table_name
    """
    conn = None
    try:
        conn = get_db_connection()
        # The cursor closes via `with`. The connection closes in `finally`,
        # so both are released even if execute()/fetchall() raises.
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute(query, (schema,))
            results = cur.fetchall()

        if not results:
            return f"No tables found in schema '{schema}'"

        lines = [f"Tables in schema '{schema}':\n"]
        lines.extend(f"• {row['table_name']} ({row['table_type']})" for row in results)
        return "\n".join(lines)
    except Exception as e:
        raise Exception(f"Failed to list tables: {str(e)}") from e
    finally:
        if conn is not None:
            conn.close()
async def describe_table(table_name: str, schema: str = "public") -> str:
    """Describe the columns of ``schema.table_name``.

    The function queries ``information_schema.columns``. It lists each column
    in ordinal order with its data type, nullability and default expression.

    Args:
        table_name: Unqualified table name; bound as a query parameter.
        schema: Schema containing the table; bound as a query parameter.

    Returns:
        A formatted column listing, or a "not found" message when the query
        returns no columns. That presumably also covers tables the current
        role cannot see — verify if it matters.

    Raises:
        Exception: wrapping any database error, chained as ``__cause__``.
    """
    query = """
        SELECT
            column_name,
            data_type,
            is_nullable,
            column_default
        FROM information_schema.columns
        WHERE table_name = %s AND table_schema = %s
        ORDER BY ordinal_position
    """
    conn = None
    try:
        conn = get_db_connection()
        # The cursor closes via `with`. The connection closes in `finally`,
        # so both are released even if execute()/fetchall() raises.
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute(query, (table_name, schema))
            columns = cur.fetchall()

        if not columns:
            return f"Table '{schema}.{table_name}' not found"

        output = [f"Table: {schema}.{table_name}\n", "Columns:", "-" * 60]
        for col in columns:
            nullable = "NULL" if col['is_nullable'] == 'YES' else "NOT NULL"
            default = f"DEFAULT {col['column_default']}" if col['column_default'] else ""
            output.append(
                f" {col['column_name']:<25} {col['data_type']:<15} {nullable:<10} {default}"
            )
        return "\n".join(output)
    except Exception as e:
        raise Exception(f"Failed to describe table: {str(e)}") from e
    finally:
        if conn is not None:
            conn.close()
# One SQL identifier: either unquoted (letter or underscore, then word
# characters or "$") or double-quoted (with "" as an embedded quote).
_SQL_IDENT = r'(?:[^\W\d][\w$]*|"(?:[^"]|"")+")'
# A table name optionally qualified by schema (and catalog), e.g. public.users.
_QUALIFIED_TABLE_RE = re.compile(_SQL_IDENT + r"(?:\." + _SQL_IDENT + r"){0,2}")


async def export_table_to_csv(
    table_name: str, limit: int = 10000, where_clause: Optional[str] = None
) -> str:
    """Export up to *limit* rows of *table_name* as CSV text.

    Args:
        table_name: Table to export, optionally schema-qualified. It must be a
            plain (quoted or unquoted) identifier; anything else is rejected
            rather than spliced into the SQL.
        limit: Maximum number of rows; coerced to ``int``.
        where_clause: Optional raw SQL filter, inserted verbatim after
            ``WHERE``. Blank values are ignored.

    Returns:
        CSV text with a header row. ``None`` values are written as empty
        fields. Returns "No data found to export" when no rows match.

    Raises:
        Exception: wrapping validation or database errors, chained as
            ``__cause__``.
    """
    try:
        name = table_name.strip()
        if not _QUALIFIED_TABLE_RE.fullmatch(name):
            raise ValueError(f"invalid table name: {table_name!r}")

        query = f"SELECT * FROM {name}"
        if where_clause and where_clause.strip():
            # NOTE(review): where_clause is raw SQL by design and is NOT escaped;
            # only expose it to callers already trusted to run arbitrary SQL.
            query += f" WHERE {where_clause}"
        # The newline keeps a trailing "--" comment in where_clause from
        # swallowing the LIMIT. int() keeps the value strictly numeric.
        query += f"\nLIMIT {int(limit)}"

        conn = get_db_connection()
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                # No parameters are passed, so "%" in where_clause (e.g. LIKE 'a%')
                # is sent unchanged.
                cur.execute(query)
                results = cur.fetchall()
        finally:
            conn.close()

        if not results:
            return "No data found to export"

        with io.StringIO() as output:
            writer = csv.DictWriter(output, fieldnames=list(results[0].keys()))
            writer.writeheader()
            # The csv module writes None as an empty field.
            writer.writerows(results)
            return output.getvalue()
    except Exception as e:
        raise Exception(f"Failed to export table to CSV: {str(e)}") from e