# cli.py (2.81 kB)
import json
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, Optional
import anyio
import typer
from rich import print_json
from mcp.client.session import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
# Typer application that every command below registers itself on.
app = typer.Typer(help="CLI harness for the Postgres MCP server.")
# Directory containing this file; the server script is expected to sit next to it.
ROOT = Path(__file__).resolve().parent
SERVER_PATH = ROOT.joinpath("mcp_postgres_server.py")
def _server_params() -> StdioServerParameters:
    """Build the stdio launch parameters for the bundled Postgres MCP server.

    The server script (``SERVER_PATH``) is run with the same interpreter as
    this CLI (``sys.executable``), using ``ROOT`` as its working directory.

    NOTE(review): no ``env`` is passed here, so the child's environment is
    whatever the MCP SDK's stdio client supplies by default. Confirm that this
    includes any database settings the server reads from the environment.
    """
    launch_args = [str(SERVER_PATH)]
    working_dir = str(ROOT)
    return StdioServerParameters(command=sys.executable, args=launch_args, cwd=working_dir)
@asynccontextmanager
async def session_context():
    """Yield an initialized ``ClientSession`` connected to a newly spawned server.

    The server is launched through ``stdio_client`` with ``_server_params()``.
    The session and the stdio transport are both exited when the caller's
    ``async with`` block ends.
    """
    transport = stdio_client(_server_params())
    async with transport as (reader, writer):
        async with ClientSession(reader, writer) as client:
            # The MCP handshake must complete before any requests are sent.
            await client.initialize()
            yield client
async def _list_tools() -> None:
    """Open a server session and pretty-print its tool listing as JSON."""
    async with session_context() as session:
        listing = await session.list_tools()
        # Convert the result model into JSON-compatible data before printing.
        as_json = listing.model_dump(mode="json")
        print_json(data=as_json)
async def _call_tool(name: str, arguments: Dict[str, Any]) -> None:
    """Call tool *name* with *arguments* on a fresh server session and print the result.

    On success, the printed JSON is chosen in this order:
    ``structuredContent`` if non-empty; otherwise the list of content items;
    otherwise the whole result.

    Raises:
        typer.Exit: with code 1 when the tool result has ``isError`` set. The
            error text is written to stderr first.
    """
    async with session_context() as session:
        result = await session.call_tool(name=name, arguments=arguments)

    # The result is handled only after the session and transport contexts have
    # closed. This way typer.Exit is not raised inside the transport's task
    # groups, which could otherwise wrap it in an exception group.
    if result.isError:
        # CallToolResult has no `.error` attribute; the error details are
        # carried as content items, usually text.
        parts = [
            getattr(item, "text", None) or json.dumps(item.model_dump(mode="json"))
            for item in result.content
        ]
        message = "\n".join(parts) or "tool reported an error"
        typer.echo(f"Error: {message}", err=True)
        raise typer.Exit(code=1)

    if result.structuredContent:
        payload: Any = result.structuredContent
    elif result.content:
        # Content items are MCP model objects. json (which rich's print_json
        # uses) cannot encode them directly, so dump them to plain data first.
        payload = [item.model_dump(mode="json") for item in result.content]
    else:
        payload = result.model_dump(mode="json")
    print_json(data=payload)
@app.command("serve")
def serve() -> None:
    """Run the MCP server over stdio (blocks)."""
    # Imported lazily so that the other commands never import the server module.
    # NOTE(review): this relies on mcp_postgres_server being importable from
    # sys.path (e.g. when this file is executed as a script from ROOT).
    from mcp_postgres_server import main as server_main

    server_main()
@app.command("list-tools")
def list_tools() -> None:
    """List tools exposed by the server."""
    # Typer commands are synchronous, so the async helper is driven by
    # anyio.run on its (default) asyncio backend.
    anyio.run(_list_tools, backend="asyncio")
@app.command("call")
def call(tool: str, args: str = typer.Option("{}", help="JSON string of arguments to the tool")) -> None:
    """Call any tool with raw JSON arguments."""
    # Parse and validate the user-supplied arguments before spawning the server.
    try:
        parsed_args = json.loads(args)
    except json.JSONDecodeError as exc:
        typer.echo(f"Invalid JSON for args: {exc}", err=True)
        raise typer.Exit(code=1) from exc
    # Tool arguments must be a JSON object: _call_tool takes Dict[str, Any].
    # Reject arrays and scalars here instead of handing them to the session.
    if not isinstance(parsed_args, dict):
        typer.echo(
            f"Invalid JSON for args: expected a JSON object, got {type(parsed_args).__name__}",
            err=True,
        )
        raise typer.Exit(code=1)
    anyio.run(_call_tool, tool, parsed_args)
@app.command("describe")
def describe(schema: Optional[str] = typer.Option(None, help="Schema to inspect")) -> None:
    """Describe database schemas/tables/columns."""
    # The "schema" key is always sent, even when schema is None (a JSON null).
    # NOTE(review): this assumes the server's describe_database tool accepts null.
    tool_args: Dict[str, Any] = {"schema": schema}
    anyio.run(_call_tool, "describe_database", tool_args)
@app.command("read")
def read(sql: str, limit: int = typer.Option(200, help="Row limit if query lacks LIMIT")) -> None:
    """Execute a read-only SQL query."""
    # This CLI only forwards the limit; presumably the server's run_read_query
    # tool applies it. Confirm against the server implementation.
    payload = dict(sql=sql, limit=limit)
    anyio.run(_call_tool, "run_read_query", payload)
@app.command("write")
def write(sql: str) -> None:
    """Execute a write query restricted to the mcp schema."""
    # The schema restriction is not checked here; the server's run_write_query
    # tool is expected to enforce it (NOTE(review): confirm server-side).
    anyio.run(_call_tool, "run_write_query", dict(sql=sql))
if __name__ == "__main__":
    # Dispatch to the Typer app when this file is executed as a script.
    app()