import asyncio
import json
import logging
import os
import re
import sys
import time
from decimal import Decimal
from logging.handlers import RotatingFileHandler
from typing import Any, Dict
from dotenv import load_dotenv
from db import run_query
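# run_query(sql, params) is assumed to be a blocking helper in db.py that
# executes the statement and returns a list of dict rows (column -> value).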
load_dotenv("config.env")
# Configure logging: stream to stderr + rotating file handler in logs/
LOG_DIR = os.getenv("LOG_DIR", "logs")
os.makedirs(LOG_DIR, exist_ok=True)
log_file = os.path.join(LOG_DIR, "pagila.log")
log_level_name = os.getenv("LOG_LEVEL", "INFO").upper()
log_level = getattr(logging, log_level_name, logging.INFO)
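# e.g. LOG_LEVEL=DEBUG in config.env (or the environment) enables debug output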
# Root logger config: ensure both stderr and file handlers
root_logger = logging.getLogger()
root_logger.setLevel(log_level)
formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s - %(message)s")
# Stream handler -> stderr
sh = logging.StreamHandler(sys.stderr)
sh.setFormatter(formatter)
# Rotating file handler (5 MB per file, keep 5 backups)
fh = RotatingFileHandler(
log_file, maxBytes=5 * 1024 * 1024, backupCount=5, encoding="utf-8"
)
fh.setFormatter(formatter)
# Replace existing handlers to avoid duplicated logs when reloading
root_logger.handlers.clear()
root_logger.addHandler(sh)
root_logger.addHandler(fh)
logger = logging.getLogger(__name__)
# JSON default serializer for non-JSON types (Decimal etc.)
def _json_default(o):
if isinstance(o, Decimal):
return float(o)
return str(o)
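# e.g. json.dumps({"rental_rate": Decimal("4.99")}, default=_json_default)
# produces '{"rental_rate": 4.99}'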
async def handle_list_films(params: Dict[str, Any]) -> Dict[str, Any]:
limit = int(params.get("limit", 10))
sql = """
SELECT film_id, title, release_year, rental_rate
FROM film
ORDER BY film_id
LIMIT %s
"""
# run blocking DB call off the event loop
start = time.monotonic()
rows = await asyncio.to_thread(run_query, sql, (limit,))
duration = time.monotonic() - start
logger.debug(
"handle_list_films finished limit=%s rows=%d duration=%.3fs",
limit,
len(rows) if isinstance(rows, list) else -1,
duration,
)
return {"rows": rows}
async def _get_schema_info() -> Dict[str, list]:
"""Return a mapping table_name -> list of column names from information_schema."""
sql = """
SELECT table_name, column_name
FROM information_schema.columns
WHERE table_schema = 'public'
ORDER BY table_name, ordinal_position
"""
rows = await asyncio.to_thread(run_query, sql, None)
schema: Dict[str, list] = {}
for r in rows:
t = r.get("table_name")
c = r.get("column_name")
schema.setdefault(t, []).append(c)
return schema
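# With Pagila this yields e.g.:
#   {"actor": ["actor_id", "first_name", ...], "film": ["film_id", "title", ...], ...}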
def _text_to_sql_local(text: str, schema: Dict[str, list]) -> Dict[str, str]:
"""Very small heuristic text->SQL generator. Returns dict with keys: sql and note.
This is a fallback for quick testing. It's NOT a replacement for an LLM-based
generator, but it handles simple requests about the `film` table (titles, year,
rental_rate, limit) and a few common patterns.
"""
t = text.strip().lower()
note = "generated by local heuristic"
    # choose target table; the heuristic only knows the `film` table, so film is
    # both the match for film/movie requests and the fallback
    target = "film"
cols = schema.get(target, [])
# select columns
select_cols = "*"
    wants_title = any(k in t for k in ("title", "name")) and "title" in cols
    wants_year = any(k in t for k in ("year", "release")) and "release_year" in cols
    wants_rate = "rental" in t and "rental_rate" in cols
if wants_title or wants_year or wants_rate:
parts = []
if wants_title:
parts.append("title")
if wants_year:
parts.append("release_year")
if wants_rate:
parts.append("rental_rate")
select_cols = ", ".join(parts)
where_clauses = []
# capture 4-digit year
m = re.search(r"(19|20)\d{2}", t)
if m and wants_year:
year = m.group(0)
where_clauses.append(f"release_year = {year}")
    # quoted phrase -> search in title (backreference ensures the quotes match)
    m2 = re.search(r"(['\"])([\w \-]+)\1", text)
    if m2 and "title" in cols:
        phrase = m2.group(2).replace("'", "''")
        where_clauses.append(f"title ILIKE '%{phrase}%'")
# limit
limit = None
m3 = re.search(r"limit\s+(\d+)", t)
if m3:
limit = int(m3.group(1))
sql = f"SELECT {select_cols} FROM {target}"
if where_clauses:
sql += " WHERE " + " AND ".join(where_clauses)
    if limit is not None:
sql += f" LIMIT {limit}"
else:
# default small limit for safety
sql += " LIMIT 50"
return {"sql": sql, "note": note}
async def handle_text_to_sql(params: Dict[str, Any]) -> Dict[str, Any]:
"""Generate SQL from natural language text. Params:
- text: natural language query (required)
- execute: bool (optional) -> if true, run the generated SQL and return rows
- provider: 'local'|'openai' (optional) -- currently only 'local' implemented
"""
text = params.get("text")
if not text or not isinstance(text, str):
raise ValueError("'text' parameter is required and must be a string")
provider = params.get("provider", "local")
execute = bool(params.get("execute", False))
    if provider != "local":
        # future: implement an LLM provider here (OpenAI/GPT). For now, reject
        # early, before the schema round-trip.
        raise ValueError("Only 'local' provider is supported in this deployment")
    # fetch schema so the generator knows which tables/columns exist
    schema = await _get_schema_info()
    generated = _text_to_sql_local(text, schema)
sql = generated.get("sql")
result: Dict[str, Any] = {"sql": sql, "note": generated.get("note")}
if execute:
start = time.monotonic()
rows = await asyncio.to_thread(run_query, sql, None)
duration = time.monotonic() - start
result["rows"] = rows
result["duration"] = duration
logger.info(
"text_to_sql generated sql=%r execute=%s",
sql if len(sql) < 200 else sql[:200] + "...",
execute,
)
return result
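# Example text_to_sql request (execute=true also runs the generated SQL):
#   {"id": 2, "method": "text_to_sql",
#    "params": {"text": "titles of films released in 1995 limit 5", "execute": true}}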
async def handle_run_pagila_query(params: Dict[str, Any]) -> Dict[str, Any]:
query = params.get("query", "")
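    # naive guard: requiring a leading SELECT blocks obvious writes, but not
    # multi-statement strings, so the DB role should still be read-only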
if not isinstance(query, str) or not query.strip().lower().startswith("select"):
raise ValueError("Only SELECT queries are allowed")
start = time.monotonic()
rows = await asyncio.to_thread(run_query, query, None)
duration = time.monotonic() - start
logger.debug(
"handle_run_pagila_query finished query=%r rows=%d duration=%.3fs",
query if len(query) < 200 else query[:200] + "...",
len(rows) if isinstance(rows, list) else -1,
duration,
)
return {"rows": rows}
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
method = request.get("method")
params = request.get("params") or {}
req_id = request.get("id")
logger.info("Request start id=%s method=%s params=%s", req_id, method, params)
start = time.monotonic()
try:
if method == "list_films":
result = await handle_list_films(params)
elif method == "run_pagila_query":
result = await handle_run_pagila_query(params)
elif method == "text_to_sql":
result = await handle_text_to_sql(params)
else:
raise ValueError(f"Unknown method: {method}")
duration = time.monotonic() - start
rows = result.get("rows") if isinstance(result, dict) else None
row_count = len(rows) if isinstance(rows, list) else None
logger.info(
"Request done id=%s method=%s rows=%s duration=%.3fs",
req_id,
method,
row_count,
duration,
)
except Exception:
duration = time.monotonic() - start
logger.exception(
"Request error id=%s method=%s duration=%.3fs", req_id, method, duration
)
raise
return {"id": req_id, "result": result}
async def server_loop() -> None:
logger.info("Pagila MCP server started")
loop = asyncio.get_running_loop()
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
    await loop.connect_read_pipe(lambda: protocol, sys.stdin)
# No asyncio StreamWriter; write responses to stdout.buffer.
while True:
line = await reader.readline()
if not line:
break
        request = None  # so the except block can reference it if parsing fails
        try:
request = json.loads(line.decode("utf-8"))
logger.debug("Received request: %s", request)
response = await handle_request(request)
except Exception as exc:
logger.exception("Error handling request: %s", exc)
response = {
"id": request.get("id") if isinstance(request, dict) else None,
"error": {"message": str(exc)},
}
        # write the response (success or error) directly to stdout (works with pipes)
sys.stdout.buffer.write(
(json.dumps(response, default=_json_default) + "\n").encode("utf-8")
)
sys.stdout.buffer.flush()
logger.info("Pagila MCP server stopped")
def main() -> None:
asyncio.run(server_loop())
if __name__ == "__main__":
main()
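# Smoke test (assumes this module is saved as server.py and config.env points
# at a reachable Pagila database):
#   echo '{"id": 1, "method": "list_films", "params": {"limit": 2}}' | python server.py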