import asyncio
import json
import logging
import os
import re
import sys
import time
from decimal import Decimal
from logging.handlers import RotatingFileHandler
from typing import Any, Dict
from dotenv import load_dotenv
from db import run_query
# Load environment configuration before any settings are read.
load_dotenv("config.env", override=True)

# --- Logging setup: stderr stream handler + rotating file in LOG_DIR ---
LOG_DIR = os.getenv("LOG_DIR", "logs")
os.makedirs(LOG_DIR, exist_ok=True)
log_file = os.path.join(LOG_DIR, "pagila.log")

log_level_name = os.getenv("LOG_LEVEL", "INFO").upper()
# Fall back to INFO when LOG_LEVEL names an unknown level.
log_level = getattr(logging, log_level_name, logging.INFO)

root_logger = logging.getLogger()
root_logger.setLevel(log_level)
formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s - %(message)s")

# Logs go to stderr so stdout stays reserved for protocol responses.
stream_handler = logging.StreamHandler(sys.stderr)
stream_handler.setFormatter(formatter)

# Rotating file handler: 5 MB per file, keeping 5 backups.
file_handler = RotatingFileHandler(
    log_file, maxBytes=5 * 1024 * 1024, backupCount=5, encoding="utf-8"
)
file_handler.setFormatter(formatter)

# Drop any pre-existing handlers so module reloads don't duplicate log lines.
root_logger.handlers.clear()
root_logger.addHandler(stream_handler)
root_logger.addHandler(file_handler)

logger = logging.getLogger(__name__)
# JSON default serializer for non-JSON types (Decimal etc.)
def _json_default(o):
    """json.dumps fallback serializer: Decimal becomes float, anything else str."""
    return float(o) if isinstance(o, Decimal) else str(o)
async def handle_list_films(params: Dict[str, Any]) -> Dict[str, Any]:
    """Return the first `params['limit']` films (default 10) ordered by film_id."""
    limit = int(params.get("limit", 10))
    sql = (
        "SELECT film_id, title, release_year, rental_rate\n"
        "FROM film\n"
        "ORDER BY film_id\n"
        "LIMIT %s"
    )
    # run_query is blocking, so execute it off the event loop.
    t0 = time.monotonic()
    rows = await asyncio.to_thread(run_query, sql, (limit,))
    elapsed = time.monotonic() - t0
    row_count = len(rows) if isinstance(rows, list) else -1
    logger.debug(
        "handle_list_films finished limit=%s rows=%d duration=%.3fs",
        limit,
        row_count,
        elapsed,
    )
    return {"rows": rows}
async def handle_list_tables(params: Dict[str, Any]) -> Dict[str, Any]:
    """Return the names of all base tables in the public schema."""
    sql = (
        "SELECT table_name "
        "FROM information_schema.tables "
        "WHERE table_schema = 'public' AND table_type = 'BASE TABLE' "
        "ORDER BY table_name"
    )
    rows = await asyncio.to_thread(run_query, sql, None)
    names = [row["table_name"] for row in rows]
    return {"tables": names}
async def handle_get_table_schema(params: Dict[str, Any]) -> Dict[str, Any]:
    """Return column definitions for the requested tables.

    `params['table_names']` may be a single table name or a list of names.
    """
    table_names = params.get("table_names", [])
    if isinstance(table_names, str):
        table_names = [table_names]
    if not table_names:
        return {"schema": "No tables specified."}
    # Parameterized IN (...) clause: one %s placeholder per table name.
    placeholders = ",".join("%s" for _ in table_names)
    sql = (
        "SELECT table_name, column_name, data_type "
        "FROM information_schema.columns "
        "WHERE table_schema = 'public' "
        f"AND table_name IN ({placeholders}) "
        "ORDER BY table_name, ordinal_position"
    )
    rows = await asyncio.to_thread(run_query, sql, tuple(table_names))
    return {"schema_rows": rows}
async def _get_schema_info() -> Dict[str, list]:
    """Map each public table name to its ordered list of column names."""
    sql = (
        "SELECT table_name, column_name "
        "FROM information_schema.columns "
        "WHERE table_schema = 'public' "
        "ORDER BY table_name, ordinal_position"
    )
    rows = await asyncio.to_thread(run_query, sql, None)
    schema: Dict[str, list] = {}
    for row in rows:
        schema.setdefault(row.get("table_name"), []).append(row.get("column_name"))
    return schema
def _Text_to_sql_local(text: str, schema: Dict[str, list]) -> Dict[str, Any]:
    """Tiny heuristic text->SQL generator (testing fallback, not an LLM).

    Handles simple requests about the film/actor/category tables: column
    selection (title / release year / rental rate), a stand-alone 4-digit
    year filter, a quoted phrase matched against the title, and an
    explicit "limit N".

    NOTE: the non-PEP8 capitalized name is kept for backward compatibility
    with existing callers.

    Returns a dict with keys:
        sql        -- parameterized SELECT statement (%s placeholders)
        params     -- positional values for the placeholders
        note       -- provenance marker
        confident  -- False when no table keyword matched and we fell
                      back to the `film` table
    """
    t = text.strip().lower()
    note = "generated by local heuristic"

    # --- choose the target table from keywords -------------------------
    confident = True
    if "category" in t or "categories" in t:
        target = "category"
    elif "actor" in t or "actors" in t or "cast" in t:
        target = "actor"
    elif "film" in t or "films" in t or "movie" in t or "movies" in t:
        target = "film"
    else:
        # Fallback guess; flagged so callers can warn the user.
        target = "film"
        confident = False

    cols = schema.get(target, [])

    # --- choose the SELECT list ----------------------------------------
    select_cols = "*"
    wants_title = wants_year = wants_rate = False
    if target == "category":
        # Category requests return the category name when available.
        select_cols = "name" if "name" in cols else "*"
    elif target == "actor":
        # Prefer first_name + last_name for actors.
        if "first_name" in cols and "last_name" in cols:
            select_cols = "first_name, last_name"
    else:
        wants_title = (
            any(k in t for k in ("title", "titles", "name")) and "title" in cols
        )
        wants_year = (
            any(k in t for k in ("year", "released", "release"))
            and "release_year" in cols
        )
        wants_rate = "rental" in t and "rental_rate" in cols
        if wants_title or wants_year or wants_rate:
            parts = []
            if wants_title:
                parts.append("title")
            if wants_year:
                parts.append("release_year")
            if wants_rate:
                parts.append("rental_rate")
            select_cols = ", ".join(parts)

    # --- WHERE clauses (all parameterized) ------------------------------
    where_clauses = []
    params: list = []
    # Stand-alone 4-digit year; word boundaries prevent matching inside
    # longer digit runs (ids, phone numbers, "20199", ...).
    m = re.search(r"\b(19|20)\d{2}\b", t)
    if m and wants_year:
        where_clauses.append("release_year = %s")
        params.append(int(m.group(0)))
    # Quoted phrase -> fuzzy title search (parameterized ILIKE).
    m2 = re.search(r"['\"]([\w \-]+)['\"]", text)
    if m2 and "title" in cols:
        where_clauses.append("title ILIKE %s")
        params.append(f"%{m2.group(1)}%")

    # --- explicit "limit N" ---------------------------------------------
    limit = None
    m3 = re.search(r"limit\s+(\d+)", t)
    if m3:
        limit = int(m3.group(1))

    sql = f"SELECT {select_cols} FROM {target}"
    if where_clauses:
        sql += " WHERE " + " AND ".join(where_clauses)
    # Always cap results. `is not None` (not truthiness) so an explicit
    # "limit 0" is honored instead of silently becoming the default 50.
    sql += " LIMIT %s"
    params.append(limit if limit is not None else 50)

    return {"sql": sql, "params": params, "note": note, "confident": confident}
async def handle_text_to_sql(params: Dict[str, Any]) -> Dict[str, Any]:
    """Generate SQL from natural-language text.

    Params:
        text     -- natural language query (required)
        execute  -- optional bool; when true, run the generated SQL and
                    include its rows in the response
        provider -- 'local' | 'openai'; only 'local' is implemented
    """
    text = params.get("text")
    if not text or not isinstance(text, str):
        raise ValueError("'text' parameter is required and must be a string")
    provider = params.get("provider", "local")
    execute = bool(params.get("execute", False))

    # Fetch the live schema so the generator knows which columns exist.
    schema = await _get_schema_info()
    if provider != "local":
        # future: implement LLM provider here (OpenAI/GPT). For now, reject.
        raise ValueError("Only 'local' provider is supported in this deployment")

    generated = _Text_to_sql_local(text, schema)
    sql = generated.get("sql")
    sql_params = generated.get("params") or []
    result: Dict[str, Any] = {
        "sql": sql,
        "params": sql_params,
        "note": generated.get("note"),
        "confident": generated.get("confident", True),
    }
    if execute:
        t0 = time.monotonic()
        # Hand the params to run_query so the driver handles escaping.
        rows = await asyncio.to_thread(run_query, sql, sql_params or None)
        result["rows"] = rows
        result["duration"] = time.monotonic() - t0
    logger.info(
        "text_to_sql generated sql=%r execute=%s",
        sql if len(sql) < 200 else sql[:200] + "...",
        execute,
    )
    return result
async def handle_run_pagila_query(params: Dict[str, Any]) -> Dict[str, Any]:
    """Run a raw read-only SQL query supplied in params['query'].

    Only a single SELECT statement is accepted; results are capped at
    MCP_MAX_ROWS rows. Raises ValueError on any rejected input.
    """
    query = params.get("query", "")
    if not isinstance(query, str):
        raise ValueError("Query must be a string")
    qstr = query.strip()
    # Disallow multiple / pipelined statements: after removing a single
    # trailing terminator, no semicolon may remain anywhere. (The previous
    # check let "select 1; <anything>;" through whenever the query merely
    # ENDED with a semicolon.)
    body = qstr[:-1] if qstr.endswith(";") else qstr
    if ";" in body:
        raise ValueError("Multiple statements are not allowed")
    # normalize for check
    qnorm = qstr.lstrip().lower()
    if not qnorm.startswith("select"):
        raise ValueError("Only SELECT queries are allowed")
    # SECURITY NOTE: This blacklist is a secondary defense.
    # The primary defense MUST be a database user with READ-ONLY permissions
    # (GRANT SELECT only).
    # basic blacklist to avoid destructive patterns
    forbidden = ["drop ", "delete ", "update ", "insert ", ";--", "--", "/*"]
    for bad in forbidden:
        if bad in qnorm:
            raise ValueError("Query contains disallowed patterns")
    # Execute (no params on this raw-SQL path), then cap rows to avoid
    # huge responses.
    start = time.monotonic()
    rows = await asyncio.to_thread(run_query, query, None)
    duration = time.monotonic() - start
    MAX_ROWS = int(os.getenv("MCP_MAX_ROWS", "1000"))
    note = None
    if isinstance(rows, list) and len(rows) > MAX_ROWS:
        note = f"Truncated results to first {MAX_ROWS} rows"
        rows = rows[:MAX_ROWS]
    # Log a truncated query to avoid overly long log lines.
    truncated_query = query if len(query) < 200 else query[:200] + "..."
    logger.debug(
        "handle_run_pagila_query finished query=%r rows=%d duration=%.3fs",
        truncated_query,
        len(rows) if isinstance(rows, list) else -1,
        duration,
    )
    result = {"rows": rows}
    if note:
        result["note"] = note
    return result
async def handle_execute_sql(params: Dict[str, Any]) -> Dict[str, Any]:
    """Execute parameterized SQL provided as {'sql': str, 'params': [..]}.

    This is safer for client-side generated SQL that includes placeholders.
    Only a single SELECT statement is accepted; results are capped at
    MCP_MAX_ROWS rows. Raises ValueError on any rejected input.
    """
    sql = params.get("sql")
    sql_params = params.get("params")
    if not isinstance(sql, str):
        raise ValueError("'sql' must be a string")
    qstr = sql.strip()
    # Disallow multiple / pipelined statements: after removing a single
    # trailing terminator, no semicolon may remain. (This guard was
    # missing here, unlike the raw-query path.)
    body = qstr[:-1] if qstr.endswith(";") else qstr
    if ";" in body:
        raise ValueError("Multiple statements are not allowed")
    qnorm = qstr.lstrip().lower()
    if not qnorm.startswith("select"):
        raise ValueError("Only SELECT queries are allowed")
    # SECURITY NOTE: This blacklist is a secondary defense.
    # The primary defense MUST be a database user with READ-ONLY permissions
    # (GRANT SELECT only).
    # basic blacklist
    forbidden = ["drop ", "delete ", "update ", "insert ", ";--", "--", "/*"]
    for bad in forbidden:
        if bad in qnorm:
            raise ValueError("Query contains disallowed patterns")
    # Execute with params (may be None).
    start = time.monotonic()
    rows = await asyncio.to_thread(
        run_query, sql, tuple(sql_params) if sql_params else None
    )
    duration = time.monotonic() - start
    MAX_ROWS = int(os.getenv("MCP_MAX_ROWS", "1000"))
    note = None
    if isinstance(rows, list) and len(rows) > MAX_ROWS:
        note = f"Truncated results to first {MAX_ROWS} rows"
        rows = rows[:MAX_ROWS]
    # Log timing for execute_sql for diagnostics.
    truncated_sql = sql if len(sql) < 200 else sql[:200] + "..."
    logger.debug(
        "handle_execute_sql finished sql=%r rows=%d duration=%.3fs",
        truncated_sql,
        len(rows) if isinstance(rows, list) else -1,
        duration,
    )
    result = {"rows": rows}
    if note:
        result["note"] = note
    return result
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch a single request dict to its method handler and time it.

    Returns {'id': ..., 'result': ...}; re-raises handler exceptions after
    logging them.
    """
    method = request.get("method")
    params = request.get("params") or {}
    req_id = request.get("id")
    logger.info("Request start id=%s method=%s params=%s", req_id, method, params)
    # Method-name -> coroutine handler dispatch table.
    dispatch = {
        "list_films": handle_list_films,
        "list_tables": handle_list_tables,
        "get_table_schema": handle_get_table_schema,
        "run_pagila_query": handle_run_pagila_query,
        "execute_sql": handle_execute_sql,
        "text_to_sql": handle_text_to_sql,
    }
    start = time.monotonic()
    try:
        handler = dispatch.get(method)
        if handler is None:
            raise ValueError(f"Unknown method: {method}")
        result = await handler(params)
        duration = time.monotonic() - start
        rows = result.get("rows") if isinstance(result, dict) else None
        row_count = len(rows) if isinstance(rows, list) else None
        logger.info(
            "Request done id=%s method=%s rows=%s duration=%.3fs",
            req_id,
            method,
            row_count,
            duration,
        )
    except Exception:
        duration = time.monotonic() - start
        logger.exception(
            "Request error id=%s method=%s duration=%.3fs", req_id, method, duration
        )
        raise
    return {"id": req_id, "result": result}
async def server_loop() -> None:
    """Serve newline-delimited JSON requests from stdin until EOF.

    Responses are written directly to stdout.buffer (one JSON object per
    line) so the loop works when both ends are plain pipes.
    """
    logger.info("Pagila MCP server started")
    loop = asyncio.get_running_loop()
    reader = asyncio.StreamReader()
    protocol = asyncio.StreamReaderProtocol(reader)
    await loop.connect_read_pipe(lambda: protocol, os.fdopen(0))
    # No asyncio StreamWriter; write responses to stdout.buffer.
    while True:
        line = await reader.readline()
        if not line:
            # EOF: the client closed the pipe.
            break
        # Reset per iteration: if json.loads raises, the error handler must
        # not hit UnboundLocalError (first line) or reuse the PREVIOUS
        # request's id (later lines).
        request = None
        try:
            request = json.loads(line.decode("utf-8"))
            logger.debug("Received request: %s", request)
            response = await handle_request(request)
        except Exception as exc:
            logger.exception("Error handling request: %s", exc)
            response = {
                "id": request.get("id") if isinstance(request, dict) else None,
                "error": {"message": str(exc)},
            }
        # write response directly to stdout (works with pipes)
        sys.stdout.buffer.write(
            (json.dumps(response, default=_json_default) + "\n").encode("utf-8")
        )
        sys.stdout.buffer.flush()
    logger.info("Pagila MCP server stopped")
def main() -> None:
    """Synchronous entry point: run the async stdio server loop until EOF."""
    asyncio.run(server_loop())


if __name__ == "__main__":
    main()