Skip to main content
Glama

Simple PostgreSQL MCP Server

server.py (10.3 kB)
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncIterator, Optional

import asyncpg
from mcp.server.fastmcp import FastMCP, Context
from pydantic import Field

# Constants
DEFAULT_QUERY_LIMIT = 100
DEFAULT_SCHEMA = "public"

# Pool options for the short-lived pools opened by resources and prompts.
# asyncpg.create_pool defaults to min_size=10, which would open ten
# connections for every single resource/prompt request; the handlers below
# run their queries strictly one after another, so one connection suffices.
_ONE_OFF_POOL_OPTIONS = {"min_size": 1, "max_size": 1}


# Define our own PromptMessage class if the MCP one isn't available.
# NOTE(review): FastMCP's prompt renderer may not recognise this dataclass as a
# message type and could serialise the whole object (content and role) to JSON
# text instead of sending `content` verbatim -- confirm against the installed
# mcp SDK version.
@dataclass
class PromptMessage:
    """A minimal prompt message: the text ``content`` and the speaker ``role``."""

    content: str
    role: Optional[str] = "user"


# Database context class
@dataclass
class DbContext:
    """Objects shared with handlers: the connection pool and the schema to inspect."""

    pool: asyncpg.Pool
    schema: str


# Database connection lifecycle manager
@asynccontextmanager
async def db_lifespan(server: FastMCP, **pool_options: Any) -> AsyncIterator[DbContext]:
    """Manage database connection lifecycle.

    Opens an asyncpg pool using the ``DSN`` environment variable (falling back to
    a local default DSN) and yields it together with the ``SCHEMA`` environment
    variable (falling back to ``DEFAULT_SCHEMA``). The pool is closed on exit,
    even if the body raises.

    Args:
        server: The FastMCP server passed in by the lifespan hook; unused here.
        **pool_options: Extra keyword arguments forwarded to
            ``asyncpg.create_pool`` (e.g. ``min_size``/``max_size``). FastMCP
            passes none, so the server-lifetime pool keeps asyncpg's defaults.

    Yields:
        A ``DbContext`` holding the open pool and the configured schema name.
    """
    # Initialize DB connection from environment variables
    dsn = os.environ.get("DSN", "postgresql://postgres:postgres@localhost:5432/postgres")
    schema = os.environ.get("SCHEMA", DEFAULT_SCHEMA)
    pool = await asyncpg.create_pool(dsn, **pool_options)
    try:
        yield DbContext(pool=pool, schema=schema)
    finally:
        # Clean up
        await pool.close()


# Create server with database lifecycle management
mcp = FastMCP(
    "SQL Database Server",
    dependencies=["asyncpg", "pydantic"],
    lifespan=db_lifespan,
)


@mcp.tool()
async def test_connection(ctx: Context) -> str:
    """Test database connection.

    Runs ``SELECT version();`` on a connection from the server-lifetime pool.

    Returns:
        A success message including the server version string, or
        ``"Connection failed: <error>"`` if anything raises.
    """
    try:
        pool = ctx.request_context.lifespan_context.pool
        async with pool.acquire() as conn:
            version = await conn.fetchval("SELECT version();")
            return f"Connection successful. PostgreSQL version: {version}"
    except Exception as e:
        return f"Connection failed: {str(e)}"


@mcp.tool()
async def execute_query(
    query: str = Field(description="SQL query to execute (SELECT only)"),
    limit: Optional[int] = Field(default=DEFAULT_QUERY_LIMIT, description="Maximum number of rows to return"),
    ctx: Context = None,
) -> str:
    """Execute a read-only SQL query against the database.

    The query text must start with ``select`` (case-insensitive, after
    stripping whitespace). It is then executed inside a ``READ ONLY``
    transaction, so statements that would write (``SELECT ... INTO``,
    ``setval()``, ...) are rejected by PostgreSQL rather than merely by the
    prefix check. At most ``limit`` rows are fetched through a cursor, so
    large results are never fully materialised.

    Args:
        query: The SQL text to run.
        limit: Maximum rows to return; ``None``, 0 or a negative value falls
            back to ``DEFAULT_QUERY_LIMIT``.
        ctx: FastMCP request context providing the lifespan pool.

    Returns:
        A pipe-separated header line, a dashed separator, then one
        pipe-separated line per row; a fixed message when no rows come back;
        or ``"SQL Error: ..."`` / ``"Error: ..."`` on failure.
    """
    # Validate query - cheap first filter; the read-only transaction is the real guard
    query = query.strip()
    if not query.lower().startswith("select"):
        return "Error: Only SELECT queries are allowed for security reasons."

    try:
        # A negative limit used to slice rows off the END of the result; treat it
        # (like None/0) as "use the default" instead.
        row_limit = limit if limit and limit > 0 else DEFAULT_QUERY_LIMIT

        pool = ctx.request_context.lifespan_context.pool
        async with pool.acquire() as conn:
            # asyncpg cursors require a transaction; READ ONLY makes any write fail.
            async with conn.transaction(readonly=True):
                cursor = await conn.cursor(query)
                result = await cursor.fetch(row_limit)

        if not result:
            return "Query executed successfully. No rows returned."

        # Format results
        columns = list(result[0].keys())
        header = " | ".join(columns)
        separator = "-" * len(header)
        rows = [" | ".join(str(val) for val in row.values()) for row in result]
        return f"{header}\n{separator}\n" + "\n".join(rows)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"


# Database helper functions
async def get_all_tables(pool, schema):
    """Get all top-level tables in *schema*.

    Selects relations of kind ``'r'`` (ordinary table) or ``'p'`` (partitioned
    table) that do not appear as a child in ``pg_inherits``, so inheritance
    children and partitions are skipped. Names starting with the literal prefix
    ``pg_`` are excluded.

    Returns:
        A list of records with a single ``table_name`` field, ordered by name.
    """
    async with pool.acquire() as conn:
        # `_` is a LIKE wildcard, so the prefix must be escaped to mean a literal
        # underscore; otherwise names such as "pgbench_accounts" would be hidden.
        result = await conn.fetch("""
            SELECT c.relname AS table_name
            FROM pg_class AS c
            JOIN pg_namespace AS n ON n.oid = c.relnamespace
            WHERE NOT EXISTS (
                SELECT 1 FROM pg_inherits AS i WHERE i.inhrelid = c.oid
            )
            AND c.relkind IN ('r', 'p')
            AND n.nspname = $1
            AND c.relname NOT LIKE 'pg!_%' ESCAPE '!'
            ORDER BY c.relname;
        """, schema)
        return result


async def get_table_schema_info(pool, schema, table_name):
    """Get column metadata for *table_name* in *schema*.

    Returns:
        Records from ``information_schema.columns`` with ``column_name``,
        ``data_type``, ``is_nullable``, ``column_default`` and
        ``character_maximum_length``, in column order (empty if none found).
    """
    async with pool.acquire() as conn:
        columns = await conn.fetch("""
            SELECT column_name, data_type, is_nullable, column_default, character_maximum_length
            FROM information_schema.columns
            WHERE table_schema = $1 AND table_name = $2
            ORDER BY ordinal_position;
        """, schema, table_name)
        return columns


def format_table_schema(table_name, columns):
    """Format column records (as returned by get_table_schema_info) as readable text.

    Each column is rendered as ``- name (type[(len)]) NULL|NOT NULL[ DEFAULT expr]``
    under a ``Table:`` / ``Columns:`` heading.
    """
    if not columns:
        return f"Table '{table_name}' not found."

    result = [f"Table: {table_name}", "Columns:"]
    for col in columns:
        nullable = "NULL" if col['is_nullable'] == 'YES' else "NOT NULL"
        length = f"({col['character_maximum_length']})" if col['character_maximum_length'] else ""
        default = f" DEFAULT {col['column_default']}" if col['column_default'] else ""
        result.append(f"- {col['column_name']} ({col['data_type']}{length}) {nullable}{default}")
    return "\n".join(result)


@mcp.resource("db://tables")
async def list_tables() -> str:
    """List all tables in the database (one name per line).

    Opens its own single-connection pool for the duration of the request.
    """
    try:
        async with db_lifespan(mcp, **_ONE_OFF_POOL_OPTIONS) as db_ctx:
            result = await get_all_tables(db_ctx.pool, db_ctx.schema)
            if not result:
                return f"No tables found in the {db_ctx.schema} schema."
            return "\n".join(row['table_name'] for row in result)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"


@mcp.resource("db://tables/{table_name}")
async def get_table_schema(table_name: str) -> str:
    """Get schema information for a specific table.

    Opens its own single-connection pool for the duration of the request and
    uses the schema configured by ``db_lifespan``.
    """
    try:
        async with db_lifespan(mcp, **_ONE_OFF_POOL_OPTIONS) as db_ctx:
            columns = await get_table_schema_info(db_ctx.pool, db_ctx.schema, table_name)
            if not columns:
                return f"Table '{table_name}' not found in {db_ctx.schema} schema."
            return format_table_schema(table_name, columns)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"


@mcp.resource("db://schema")
async def get_all_schemas() -> str:
    """Get schema information for all tables in the configured schema.

    Each table's formatted description is followed by an empty line.
    Opens its own single-connection pool for the duration of the request.
    """
    try:
        async with db_lifespan(mcp, **_ONE_OFF_POOL_OPTIONS) as db_ctx:
            tables = await get_all_tables(db_ctx.pool, db_ctx.schema)
            if not tables:
                return f"No tables found in the {db_ctx.schema} schema."

            all_schemas = []
            for table in tables:
                table_name = table['table_name']
                columns = await get_table_schema_info(db_ctx.pool, db_ctx.schema, table_name)
                all_schemas.append(format_table_schema(table_name, columns))
                all_schemas.append("")  # Add empty line between tables
            return "\n".join(all_schemas)
    except asyncpg.exceptions.PostgresError as e:
        return f"SQL Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"


async def _fetch_column_summary(db_ctx: DbContext, table_name: str) -> Optional[str]:
    """Return ``- name (type)`` lines for *table_name*'s columns, or None if none are found."""
    async with db_ctx.pool.acquire() as conn:
        columns = await conn.fetch("""
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_schema = $1 AND table_name = $2
            ORDER BY ordinal_position
        """, db_ctx.schema, table_name)
    if not columns:
        return None
    return "\n".join(f"- {col['column_name']} ({col['data_type']})" for col in columns)


@mcp.prompt()
async def generate_select_query(table_name: str) -> list[PromptMessage]:
    """Generate a SELECT query with best practices for a table.

    Returns a single message embedding the table's columns and a list of
    SELECT best practices, or a single error/not-found message.
    """
    try:
        async with db_lifespan(mcp, **_ONE_OFF_POOL_OPTIONS) as db_ctx:
            columns_text = await _fetch_column_summary(db_ctx, table_name)
            if columns_text is None:
                return [PromptMessage(f"Table '{table_name}' not found in schema '{db_ctx.schema}'.")]

        return [
            PromptMessage(
                f"""Please help me write a well-structured, efficient SELECT query for the '{table_name}' table.

Table Schema:
{columns_text}

PostgreSQL SQL Best Practices:
- Use explicit column names instead of * when possible
- Include LIMIT clauses to restrict result sets
- Consider adding WHERE clauses to filter results
- Use appropriate indexing considerations
- Format SQL with proper indentation and line breaks

Create a basic SELECT query following these best practices:"""
            )
        ]
    except Exception as e:
        return [PromptMessage(f"Error generating select query: {str(e)}")]


@mcp.prompt()
async def generate_analytical_query(table_name: str) -> list[PromptMessage]:
    """Generate analytical queries for a table.

    Returns a single message embedding the table's columns and a list of
    aggregation/analytics best practices, or a single error/not-found message.
    """
    try:
        async with db_lifespan(mcp, **_ONE_OFF_POOL_OPTIONS) as db_ctx:
            columns_text = await _fetch_column_summary(db_ctx, table_name)
            if columns_text is None:
                return [PromptMessage(f"Table '{table_name}' not found in schema '{db_ctx.schema}'.")]

        return [
            PromptMessage(
                f"""Please help me create analytical queries for the '{table_name}' table.

Table Schema:
{columns_text}

PostgreSQL SQL Best Practices:
- Use aggregation functions (COUNT, SUM, AVG, MIN, MAX) appropriately
- Group data using GROUP BY for meaningful aggregations
- Filter groups with HAVING clauses when needed
- Consider using window functions for advanced analytics
- Format SQL with proper indentation and line breaks

Create a set of analytical queries for this table:"""
            )
        ]
    except Exception as e:
        return [PromptMessage(f"Error generating analytical query: {str(e)}")]


def main():
    """Run the MCP server."""
    mcp.run()

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/NetanelBollag/simple-psql-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.