# db.py
import re
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Iterable
import psycopg
from psycopg.rows import dict_row
from config import Settings

# Statements that mutate data or schema; used to block writes on the read
# endpoint and to require a genuine write on the write endpoint.
WRITE_KEYWORDS = re.compile(
    r"\b(insert|update|delete|create|alter|drop|truncate|grant|revoke|comment|rename)\b",
    re.IGNORECASE,
)
# Captures the identifier before a dot in schema-qualified names such as "mcp.jobs".
SCHEMA_PATTERN = re.compile(r"\b([a-zA-Z_][\w]*)\.")
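# Illustrative examples (hypothetical table names): SCHEMA_PATTERN.findall("UPDATE mcp.jobs SET done = true")
# returns ["mcp"], which _validate_write_query below accepts, while "DELETE FROM public.users"
# yields ["public"] and is rejected by the same check.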


class DatabaseError(Exception):
    pass


@dataclass
class QueryResult:
    rowcount: int
    rows: list[dict[str, Any]]


class PostgresConnector:
    def __init__(self, settings: Settings):
        self.settings = settings

    @contextmanager
    def _connection(self):
        # Open a short-lived connection per call and translate driver errors
        # into DatabaseError so callers handle a single exception type.
        try:
            with psycopg.connect(self.settings.conninfo, autocommit=False) as conn:
                yield conn
        except Exception as exc:  # pragma: no cover - translated error
            raise DatabaseError(str(exc)) from exc

    def describe(self, schema: str | None = None) -> dict[str, Any]:
        filter_schema = "WHERE table_schema = %(schema)s" if schema else ""
        sql = f"""
            SELECT table_schema, table_name, column_name, data_type, is_nullable
            FROM information_schema.columns
            {filter_schema}
            ORDER BY table_schema, table_name, ordinal_position;
        """
        params = {"schema": schema} if schema else None
        with self._connection() as conn, conn.cursor(row_factory=dict_row) as cur:
            cur.execute(sql, params)
            columns = cur.fetchall()
        # Shape: {schema: {table: [{"name": ..., "type": ..., "nullable": ...}, ...]}}
        catalog: dict[str, dict[str, list[dict[str, Any]]]] = {}
        for col in columns:
            schema_name = col["table_schema"]
            table_name = col["table_name"]
            table_entry = catalog.setdefault(schema_name, {})
            table_entry.setdefault(table_name, []).append(
                {
                    "name": col["column_name"],
                    "type": col["data_type"],
                    "nullable": col["is_nullable"] == "YES",
                }
            )
        return catalog

    def run_read_query(self, sql: str, params: dict[str, Any] | None = None, limit: int = 200) -> QueryResult:
        normalized = sql.strip().rstrip(";")
        if WRITE_KEYWORDS.search(normalized):
            raise DatabaseError("Refusing to run write-like statement via read endpoint")
        # Naive substring check; append a LIMIT only when the caller did not supply one.
        needs_limit = " limit " not in normalized.lower()
        limited = f"{normalized} LIMIT {limit}" if needs_limit else normalized
        with self._connection() as conn, conn.cursor(row_factory=dict_row) as cur:
            # Make the implicitly started transaction read-only before running the query.
            cur.execute("SET TRANSACTION READ ONLY;")
            cur.execute(limited, params)
            rows = cur.fetchall()
        return QueryResult(rowcount=len(rows), rows=rows)

    def _validate_write_query(self, sql: str) -> None:
        if not WRITE_KEYWORDS.search(sql):
            raise DatabaseError("Expected a write statement (insert/update/delete/ddl) for write endpoint")
        # Every schema-qualified identifier in the statement must point at the mcp schema.
        schemas = {schema.lower() for schema in SCHEMA_PATTERN.findall(sql)}
        forbidden = {schema for schema in schemas if schema != "mcp"}
        if forbidden:
            raise DatabaseError(
                f"Write queries may only target the mcp schema (found: {', '.join(sorted(forbidden))})"
            )
        statements = [stmt for stmt in (part.strip() for part in sql.split(";")) if stmt]
        if len(statements) > 1:
            raise DatabaseError("Multiple statements per call are blocked; submit one statement at a time")

    def run_write_query(self, sql: str, params: dict[str, Any] | None = None) -> QueryResult:
        normalized = sql.strip()
        self._validate_write_query(normalized)
        with self._connection() as conn, conn.cursor(row_factory=dict_row) as cur:
            # Scope the write to the mcp schema for the lifetime of this transaction.
            cur.execute("SET LOCAL search_path = mcp;")
            cur.execute(normalized, params)
            rowcount = cur.rowcount
            try:
                # Statements with a RETURNING clause produce rows; plain writes and DDL do not.
                rows: Iterable[dict[str, Any]] = cur.fetchall()
            except psycopg.ProgrammingError:
                rows = []
            conn.commit()
        return QueryResult(rowcount=max(rowcount, 0), rows=list(rows))
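

# A minimal usage sketch (illustrative, not part of the module). It assumes that
# config.Settings exposes the `conninfo` attribute used by _connection() above and
# that it points at a reachable Postgres instance; the table name is hypothetical.
#
#     from config import Settings
#     from db import PostgresConnector
#
#     connector = PostgresConnector(Settings())  # assumed Settings() constructor
#     catalog = connector.describe(schema="mcp")
#     result = connector.run_read_query("SELECT * FROM mcp.events", limit=50)
#     print(result.rowcount, result.rows[:3])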