#!/usr/bin/env python3
"""
GitHub Copilot MCP Demo Server.
Exposes SQL query tool through Model Context Protocol for Copilot integration.
Security controls implemented:
- Rate limiting: max RATE_LIMIT_MAX_CALLS calls per RATE_LIMIT_WINDOW_SECS window
- Query length cap: MAX_QUERY_LENGTH characters
- Sanitised error messages: internal details never leak to the LLM
- Audit logging via the sql_query_tool audit logger
"""
import json
import logging
import sys
import time
from collections import deque
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from mcp.server import Server
from mcp.types import Tool, TextContent, CallToolResult
from tools.sql_query_tool import SQLQueryTool
# ---------------------------------------------------------------------------
# Server-level constants
# ---------------------------------------------------------------------------
# Upper bound on the SQL text a client may submit, in characters; keeps
# oversized payloads away from the database tool.
MAX_QUERY_LENGTH = 2000
# Width of the sliding rate-limit window, in seconds.
RATE_LIMIT_WINDOW_SECS = 60
# Number of tool calls (of any tool) admitted within one window.
RATE_LIMIT_MAX_CALLS = 30
# ---------------------------------------------------------------------------
# Rate limiter (in-memory, per-process; sufficient for stdio/single-session)
# ---------------------------------------------------------------------------
class _SlidingWindowRateLimiter:
    """Admit at most ``max_calls`` events within any trailing ``window_secs`` span.

    Timestamps come from ``time.monotonic()``, so wall-clock changes do not
    shift the window. State is held only in this object and no locking is
    performed.
    """

    def __init__(self, max_calls: int, window_secs: float):
        self._max = max_calls
        self._window = window_secs
        # Monotonic timestamps of admitted calls, oldest on the left.
        self._calls: deque[float] = deque()

    def allow(self) -> bool:
        """Try to admit one call now.

        Returns:
            True (and records the call) if fewer than ``max_calls`` calls are
            still inside the window; False otherwise (nothing is recorded).
        """
        stamp = time.monotonic()
        horizon = stamp - self._window
        history = self._calls
        # Evict timestamps that have slid out of the window before counting.
        while history and history[0] < horizon:
            history.popleft()
        if len(history) < self._max:
            history.append(stamp)
            return True
        return False

    @property
    def remaining(self) -> int:
        """Calls still admissible right now; read-only, does not evict."""
        horizon = time.monotonic() - self._window
        in_window = len([stamp for stamp in self._calls if stamp >= horizon])
        return max(0, self._max - in_window)
# ---------------------------------------------------------------------------
# Module-level singletons
# ---------------------------------------------------------------------------
# Single MCP server instance; the handlers below register on it via decorators.
server = Server("copilot-mcp-demo")
# Shared database tool used by every execute_query / get_schema call.
sql_tool = SQLQueryTool()
# Process-wide limiter shared by all tools; its state is in memory only.
_rate_limiter = _SlidingWindowRateLimiter(
    max_calls=RATE_LIMIT_MAX_CALLS,
    window_secs=RATE_LIMIT_WINDOW_SECS,
)
# Server-side logger. NOTE(review): "mcp.server" is also the parent namespace
# of the MCP SDK's own module loggers (if they use getLogger(__name__)), so the
# handler and level set here may apply to SDK messages too — confirm intended.
log = logging.getLogger("mcp.server")
if not log.handlers:
    # StreamHandler() with no argument writes to sys.stderr. That matters for
    # the stdio transport, where stdout carries protocol messages.
    _h = logging.StreamHandler()
    _fmt = logging.Formatter(
        "%(asctime)s | SERVER | %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%SZ",
    )
    # The trailing "Z" declares UTC, but Formatter renders local time by
    # default; switch the converter so the timestamp matches its suffix.
    _fmt.converter = time.gmtime
    _h.setFormatter(_fmt)
    log.addHandler(_h)
log.setLevel(logging.INFO)
def define_tools() -> list[Tool]:
    """Build the tool catalogue this MCP server advertises.

    Returns:
        Two ``Tool`` definitions, in this order:
        ``execute_query`` (takes a required ``query`` string capped at
        ``MAX_QUERY_LENGTH`` characters and an optional ``params`` array; its
        description quotes the length cap and the rate limit) and
        ``get_schema`` (takes no arguments).
    """
    query_description = (
        "Execute SELECT, INSERT, or UPDATE queries on the demo database. "
        "Use this to query users, products, orders, or modify records. "
        f"Maximum query length: {MAX_QUERY_LENGTH} characters. "
        "PII fields (email, phone, address, password, …) are automatically "
        "redacted in results. Rate limit: "
        f"{RATE_LIMIT_MAX_CALLS} calls per {RATE_LIMIT_WINDOW_SECS}s."
    )
    query_property = {
        "type": "string",
        "description": "SQL query to execute (SELECT, INSERT, or UPDATE only)",
        # Mirrors the server-side check so clients can validate up front.
        "maxLength": MAX_QUERY_LENGTH,
    }
    params_property = {
        "type": "array",
        "description": "Optional positional parameters for prepared statements",
        # Empty schema: array items may be of any JSON type.
        "items": {},
    }
    execute_query_tool = Tool(
        name="execute_query",
        description=query_description,
        inputSchema={
            "type": "object",
            "properties": {"query": query_property, "params": params_property},
            "required": ["query"],
        },
    )

    schema_description = (
        "Retrieve the database schema including table names and column definitions. "
        "Columns marked pii=true will be redacted in query results. "
        "Call this first to understand available tables before writing queries."
    )
    get_schema_tool = Tool(
        name="get_schema",
        description=schema_description,
        inputSchema={"type": "object", "properties": {}},
    )

    return [execute_query_tool, get_schema_tool]
# ---------------------------------------------------------------------------
# Tool handlers
# ---------------------------------------------------------------------------
def _rate_limit_error() -> CallToolResult:
    """Return the error result sent when the rate limiter rejects a call."""
    notice = (
        f"Rate limit exceeded: max {RATE_LIMIT_MAX_CALLS} calls per "
        f"{RATE_LIMIT_WINDOW_SECS}s. Please wait before retrying."
    )
    payload = TextContent(type="text", text=notice)
    return CallToolResult(content=[payload], isError=True)
def _error_result(message: str) -> CallToolResult:
    """Wrap *message* in a single-text-item CallToolResult flagged as an error."""
    return CallToolResult(
        content=[TextContent(type="text", text=message)],
        isError=True,
    )


def _json_result(result: dict) -> CallToolResult:
    """Serialise a sql_tool result dict as pretty JSON.

    The result is flagged as an error unless ``result["success"]`` is truthy.
    """
    return CallToolResult(
        content=[TextContent(type="text", text=json.dumps(result, indent=2))],
        isError=not result.get("success", False),
    )


def _handle_execute_query(arguments: dict) -> CallToolResult:
    """Validate the execute_query arguments, then run the query via sql_tool.

    Rejects (with an error result, without touching the database):
      * a ``query`` that is present but not a string,
      * a missing/blank ``query``,
      * a ``query`` longer than MAX_QUERY_LENGTH after stripping whitespace,
      * ``params`` that is present but not an array (list/tuple).
    """
    raw_query = arguments.get("query")
    # Without this check a non-string (e.g. an int) would crash on .strip()
    # and surface only as a generic internal error.
    if raw_query is not None and not isinstance(raw_query, str):
        return _error_result("Error: 'query' must be a string.")
    query = (raw_query or "").strip()
    if not query:
        return _error_result("Error: 'query' parameter is required.")
    # Enforce the length cap here rather than trusting the client to honour
    # the schema's maxLength.
    if len(query) > MAX_QUERY_LENGTH:
        return _error_result(
            f"Error: query exceeds maximum length of {MAX_QUERY_LENGTH} characters."
        )
    params = arguments.get("params")
    # The schema declares params as an array; refuse anything else instead of
    # handing e.g. a bare string to the database layer.
    if params is not None and not isinstance(params, (list, tuple)):
        return _error_result("Error: 'params' must be an array.")
    return _json_result(sql_tool.execute_query(query, params))


@server.call_tool()
async def handle_call_tool(name: str, arguments: dict) -> CallToolResult:
    """Dispatch a tool call from the client, with rate limiting and validation.

    Args:
        name: Tool name requested by the client (untrusted).
        arguments: Tool arguments from the client (untrusted); None is treated
            as an empty dict.

    Returns:
        A CallToolResult. Every failure path — rate limit, validation error,
        unknown tool, or unexpected exception — returns ``isError=True`` with a
        short message; internal exception details are only logged, never
        returned.
    """
    # NOTE(review): returning a CallToolResult (rather than a list of content
    # items) relies on the installed MCP SDK's call_tool decorator passing it
    # through; some SDK releases only accept an iterable of content items.
    # Confirm against the pinned mcp version.

    # Rate limit check (applies to all tools, including unknown names).
    if not _rate_limiter.allow():
        # %r keeps a client-chosen name from injecting fake log lines.
        log.warning("RATE_LIMIT_EXCEEDED | tool=%r | remaining=%d", name, _rate_limiter.remaining)
        return _rate_limit_error()
    log.info("CALL | tool=%r | remaining_calls=%d", name, _rate_limiter.remaining)

    arguments = arguments or {}
    try:
        if name == "execute_query":
            return _handle_execute_query(arguments)
        if name == "get_schema":
            return _json_result(sql_tool.get_schema())
        log.warning("UNKNOWN_TOOL | tool=%r", name)
        return _error_result(f"Unknown tool: {name}")
    except Exception:
        # Never leak internal stack traces to the LLM; keep them in the log.
        log.exception("UNHANDLED_EXCEPTION | tool=%r", name)
        return _error_result("An internal server error occurred.")
@server.list_tools()
async def list_tools_handler() -> list[Tool]:
    """Answer the client's tool-listing request with the define_tools() catalogue."""
    catalogue = define_tools()
    return catalogue
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
async def main():
    """Serve MCP requests over stdio.

    Logs the active security limits, opens the stdio transport, and hands its
    read/write streams to ``server.run``. Returns when ``server.run`` returns.
    """
    # Local import: the stdio transport is only needed when actually serving.
    # (The previous unused local `import asyncio` has been dropped.)
    from mcp.server.stdio import stdio_server

    log.info(
        "Starting MCP Demo Server | rate_limit=%d/%ds | max_query_len=%d",
        RATE_LIMIT_MAX_CALLS,
        RATE_LIMIT_WINDOW_SECS,
        MAX_QUERY_LENGTH,
    )
    async with stdio_server() as (read_stream, write_stream):
        await server.run(
            read_stream,
            write_stream,
            server.create_initialization_options(),
        )
if __name__ == "__main__":
    # Script entry: drive the async main() on a fresh event loop.
    import asyncio as _aio

    _aio.run(main())