"""
SQL Query Tool for MCP Demo Server.
Provides safe execution of SELECT/INSERT/UPDATE queries on the demo database
with PII scrubbing, SQL injection prevention, audit logging, and row limits.
"""
import json
import logging
import re
import sqlite3
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
# ---------------------------------------------------------------------------
# Logging / Audit trail
# ---------------------------------------------------------------------------
audit_logger = logging.getLogger("mcp.audit")
if not audit_logger.handlers:
    # Only attach a handler when the named logger has none yet, so importing
    # this module more than once does not duplicate every audit line.
    _handler = logging.StreamHandler()
    _formatter = logging.Formatter(
        "%(asctime)s | AUDIT | %(message)s", datefmt="%Y-%m-%dT%H:%M:%SZ"
    )
    # The date format ends in a literal "Z" (the UTC designator), so the
    # timestamp must actually be rendered in UTC. logging formats record
    # times with time.localtime unless the converter is overridden.
    _formatter.converter = time.gmtime
    _handler.setFormatter(_formatter)
    audit_logger.addHandler(_handler)
audit_logger.setLevel(logging.INFO)
# ---------------------------------------------------------------------------
# PII field registry
# Any column whose name matches one of these patterns will be redacted in
# SELECT results returned to the LLM.
# ---------------------------------------------------------------------------
# NOTE: r"\b" is deliberately not used as the delimiter. "_" counts as a word
# character, so r"\bemail\b" does not match snake_case column names such as
# "user_email" or "email_address". The lookarounds below accept the string
# edges, "_" and any non-word character as separators, while still refusing
# to match inside a longer alphanumeric run (e.g. "emails", "dobby").
_PII_FIELD_PATTERNS: list[re.Pattern] = [
    re.compile(rf"(?<![^\W_]){keyword}(?![^\W_])", re.IGNORECASE)
    for keyword in (
        "email",
        "phone",
        "address",
        "password",
        "secret",
        "ssn",
        "dob",
        "date_of_birth",
        "credit_card",
        "ip_address",
        "token",
        "api_key",
    )
]
PII_REDACTED = "[REDACTED]"
MAX_RESULT_ROWS = 100  # Hard cap — prevents bulk extraction


def _is_pii_field(field_name: str) -> bool:
    """Return True if the field name matches any known PII pattern.

    Matching is case-insensitive and keyword based: a name matches when one
    of the registered keywords appears in it delimited by the start/end of
    the name, an underscore, or a non-word character.
    """
    return any(p.search(field_name) for p in _PII_FIELD_PATTERNS)


def scrub_pii(rows: list[dict]) -> tuple[list[dict], list[str]]:
    """
    Redact PII fields in a list of row-dicts.

    The input rows are not modified; a new dict is built for every row with
    the value of each PII-named key replaced by ``PII_REDACTED``.

    Args:
        rows: Row dicts keyed by column name.

    Returns:
        (scrubbed_rows, list_of_redacted_field_names) — the field names are
        de-duplicated and sorted. An empty/falsy ``rows`` is returned as-is
        together with an empty list.
    """
    if not rows:
        return rows, []

    # Column names repeat on every row, so remember the verdict per name
    # instead of re-running all patterns for each cell.
    pii_by_name: dict[str, bool] = {}
    scrubbed: list[dict] = []
    for row in rows:
        clean = {}
        for name, value in row.items():
            is_pii = pii_by_name.get(name)
            if is_pii is None:
                is_pii = pii_by_name[name] = _is_pii_field(name)
            clean[name] = PII_REDACTED if is_pii else value
        scrubbed.append(clean)

    redacted_fields = sorted(name for name, flagged in pii_by_name.items() if flagged)
    return scrubbed, redacted_fields
# ---------------------------------------------------------------------------
# SQL injection / dangerous-pattern detection
# ---------------------------------------------------------------------------
_BLOCKED_PATTERNS: list[re.Pattern] = [
re.compile(r"\bDROP\b", re.IGNORECASE),
re.compile(r"\bDELETE\b", re.IGNORECASE),
re.compile(r"\bTRUNCATE\b", re.IGNORECASE),
re.compile(r"\bALTER\b", re.IGNORECASE),
re.compile(r"\bCREATE\b", re.IGNORECASE),
re.compile(r"\bEXEC\b", re.IGNORECASE),
re.compile(r"\bEXECUTE\b", re.IGNORECASE),
re.compile(r"\bATTACH\b", re.IGNORECASE),
re.compile(r"\bDETACH\b", re.IGNORECASE),
re.compile(r"--",), # inline comment (SQLi vector)
re.compile(r"/\*.*?\*/", re.DOTALL), # block comment
re.compile(r";\s*\S"), # stacked queries e.g. "SELECT 1; DROP TABLE …"
re.compile(r"\bUNION\b.*\bSELECT\b", re.IGNORECASE), # UNION-based injection
re.compile(r"\bOR\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+['\"]?", re.IGNORECASE), # OR 1=1
re.compile(r"\bAND\b\s+['\"]?\d+['\"]?\s*=\s*['\"]?\d+['\"]?", re.IGNORECASE), # AND 1=1
re.compile(r"xp_cmdshell", re.IGNORECASE), # MSSQL-style RCE (defence-in-depth)
re.compile(r"load_file", re.IGNORECASE), # MySQL file read
re.compile(r"outfile", re.IGNORECASE), # MySQL file write
]
_ALLOWED_PREFIXES = ("SELECT", "INSERT", "UPDATE")
def validate_query(query: str) -> str | None:
"""
Validate a SQL query string.
Returns:
None if the query is safe, or an error message string if it is not.
"""
stripped = query.strip()
if not stripped:
return "Query cannot be empty."
upper = stripped.upper().lstrip()
if not any(upper.startswith(p) for p in _ALLOWED_PREFIXES):
return "Only SELECT, INSERT, and UPDATE statements are permitted."
for pattern in _BLOCKED_PATTERNS:
if pattern.search(stripped):
return "Query contains a blocked pattern and was rejected for security reasons."
return None # safe
# ---------------------------------------------------------------------------
# Main tool class
# ---------------------------------------------------------------------------
class SQLQueryTool:
"""Execute SQL queries on the demo database with PII scrubbing and security controls."""
def __init__(self, db_path: str | None = None):
"""
Initialize the SQL query tool.
Args:
db_path: Path to SQLite database. If None, uses default location.
"""
if db_path is None:
db_path = Path(__file__).parent.parent / "db" / "demo.db"
self.db_path = str(db_path)
# ------------------------------------------------------------------
def execute_query(self, query: str, params: list[Any] | None = None) -> dict:
"""
Execute a SQL query and return results with PII scrubbed.
Args:
query: SQL query to execute (SELECT / INSERT / UPDATE only)
params: Positional parameters for prepared-statement binding
Returns:
{success, row_count, rows, pii_redacted_fields} on SELECT
{success, rows_affected, message} on INSERT/UPDATE
{success: False, error} on failure
"""
# 1. Validate
error = validate_query(query)
if error:
audit_logger.warning("REJECTED query | reason=%s | query=%.120s", error, query)
return {"success": False, "error": error}
normalized = query.strip().upper().lstrip()
audit_logger.info(
"QUERY | type=%s | params_count=%d | query=%.200s",
"READ" if normalized.startswith("SELECT") else "WRITE",
len(params) if params else 0,
query,
)
try:
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
if params:
cursor.execute(query, params)
else:
cursor.execute(query)
if normalized.startswith("SELECT"):
rows = cursor.fetchmany(MAX_RESULT_ROWS)
raw_results = [dict(row) for row in rows]
conn.close()
scrubbed, redacted_fields = scrub_pii(raw_results)
response: dict = {
"success": True,
"row_count": len(scrubbed),
"rows": scrubbed,
}
if redacted_fields:
response["pii_redacted_fields"] = redacted_fields
response["pii_notice"] = (
f"The following field(s) were redacted to protect PII: "
f"{', '.join(redacted_fields)}"
)
if len(raw_results) == MAX_RESULT_ROWS:
response["warning"] = (
f"Result set was capped at {MAX_RESULT_ROWS} rows. "
"Add a LIMIT clause for more precise control."
)
return response
else: # INSERT / UPDATE
conn.commit()
row_count = cursor.rowcount
conn.close()
return {
"success": True,
"rows_affected": row_count,
"message": f"Query executed successfully. {row_count} row(s) affected.",
}
except sqlite3.Error as e:
# Sanitise DB error — never leak internal paths or schema details
audit_logger.error("DB_ERROR | %s", str(e))
return {"success": False, "error": "A database error occurred. Check your query syntax."}
except Exception as e:
audit_logger.error("UNEXPECTED_ERROR | %s", str(e))
return {"success": False, "error": "An unexpected error occurred."}
# ------------------------------------------------------------------
def get_schema(self) -> dict:
"""
Retrieve database schema (table names + column definitions).
PII column names are flagged but structure is not hidden — the LLM
needs to know which fields exist to write correct queries.
Returns:
{success, schema} where schema maps table_name → [column_info, …]
"""
try:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
tables = [row[0] for row in cursor.fetchall()]
schema: dict = {}
for table in tables:
cursor.execute(f"PRAGMA table_info({table})") # noqa: S608 — table names are from sqlite_master
columns = cursor.fetchall()
schema[table] = [
{
"name": col[1],
"type": col[2],
"nullable": col[3] == 0,
"default": col[4],
"primary_key": col[5] == 1,
"pii": _is_pii_field(col[1]), # ← advisory flag for LLM
}
for col in columns
]
conn.close()
return {
"success": True,
"schema": schema,
"pii_policy": (
"Columns marked pii=true will be redacted ([REDACTED]) in query results. "
"Do not attempt to circumvent this by aliasing or casting PII columns."
),
}
except Exception as e:
audit_logger.error("SCHEMA_ERROR | %s", str(e))
return {"success": False, "error": "Error retrieving schema."}
# ---------------------------------------------------------------------------
# Quick smoke-test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_tool = SQLQueryTool()

    # Start with the schema dump, including the per-column PII flags.
    print("📊 Schema (with PII flags):")
    print(json.dumps(demo_tool.get_schema(), indent=2))

    # (heading, SQL) pairs, run in order: one ordinary read, then two
    # queries the demo expects to be refused.
    demo_queries = (
        ("\n👥 Sample Users (PII scrubbed):", "SELECT * FROM users LIMIT 3"),
        ("\n🚫 Injection attempt:", "SELECT * FROM users WHERE id=1 OR 1=1"),
        ("\n🚫 DROP attempt:", "DROP TABLE users"),
    )
    for heading, sql in demo_queries:
        print(heading)
        print(json.dumps(demo_tool.execute_query(sql), indent=2))