"""
FastMCP server entry point for Prometheus MCP.
This module provides the MCP server that exposes PromQL query tools
for AWS Managed Prometheus with SigV4 authentication.
Supports both stdio (for local testing) and SSE/HTTP (for Kubernetes deployment).
"""
import asyncio
import logging
import os
import sys
from contextlib import asynccontextmanager
from typing import AsyncIterator
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.server.sse import SseServerTransport
from mcp.types import Tool, TextContent
from starlette.applications import Starlette
from starlette.routing import Route
from starlette.responses import JSONResponse
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
import uvicorn
from prometheus_mcp.promql.tools import (
init_client,
get_label_values,
list_labels,
list_metrics,
query_instant,
query_range,
)
# Configure logging
# Root-level basicConfig so both this module's logger and library loggers
# (mcp, uvicorn) emit timestamped records at INFO and above.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("prometheus-mcp")
# Create MCP server
# Single module-level Server instance; the @server.list_tools() and
# @server.call_tool() decorators below register handlers on it, and both
# the stdio and SSE transports run this same instance.
server = Server("prometheus-mcp")
def initialize_amp_client() -> None:
    """Initialize the AMP client from environment configuration.

    Environment variables:
        PROMETHEUS_WORKSPACE_ID: required; the AMP workspace to query.
        AWS_REGION: optional; defaults to "us-east-1".

    If the workspace id is missing, a warning is logged and the function
    returns without initializing, so the server can still start (queries
    will fail until the variable is set and the client is initialized).
    """
    workspace_id = os.environ.get("PROMETHEUS_WORKSPACE_ID")
    region = os.environ.get("AWS_REGION", "us-east-1")
    if not workspace_id:
        logger.warning(
            "PROMETHEUS_WORKSPACE_ID not set. Set it before making queries."
        )
        return
    init_client(workspace_id=workspace_id, region=region)
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info(
        "Initialized AMP client for workspace: %s in %s", workspace_id, region
    )
@server.list_tools()
async def list_tools() -> list[Tool]:
    """Advertise the PromQL tools this server exposes to MCP clients.

    Each tool carries a JSON Schema describing its arguments; the schemas
    here must stay in sync with the dispatch logic in call_tool().
    """
    query_instant_tool = Tool(
        name="query_instant",
        description=(
            "Execute an instant PromQL query against AWS Managed Prometheus. "
            "Returns current values of matching time series at a single point in time. "
            "Use for current metric values or point-in-time snapshots."
        ),
        inputSchema={
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": (
                        "PromQL query expression (e.g., 'up', 'rate(http_requests_total[5m])')"
                    ),
                },
                "time": {
                    "type": "string",
                    "description": (
                        "Optional evaluation timestamp in RFC3339 format "
                        "(e.g., '2024-01-15T10:30:00Z') or Unix timestamp. "
                        "Defaults to current server time."
                    ),
                },
            },
            "required": ["query"],
        },
    )

    query_range_tool = Tool(
        name="query_range",
        description=(
            "Execute a range PromQL query to get time series data over a time period. "
            "Returns data points at regular intervals. "
            "Use for historical data analysis, graphing, and trend analysis."
        ),
        inputSchema={
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "PromQL query expression",
                },
                "start": {
                    "type": "string",
                    "description": (
                        "Start timestamp in RFC3339 format (e.g., '2024-01-15T00:00:00Z') "
                        "or Unix timestamp"
                    ),
                },
                "end": {
                    "type": "string",
                    "description": (
                        "End timestamp in RFC3339 format (e.g., '2024-01-15T12:00:00Z') "
                        "or Unix timestamp"
                    ),
                },
                "step": {
                    "type": "string",
                    "description": (
                        "Query resolution step width (e.g., '15s', '1m', '5m', '1h'). "
                        "Defaults to '1m'. Smaller steps = more data points but slower queries."
                    ),
                    "default": "1m",
                },
            },
            "required": ["query", "start", "end"],
        },
    )

    list_labels_tool = Tool(
        name="list_labels",
        description=(
            "Get all label names from AWS Managed Prometheus. "
            "Use for discovering what labels are available for filtering queries."
        ),
        inputSchema={
            "type": "object",
            "properties": {
                "match": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": (
                        "Optional list of series selectors to filter which series to consider. "
                        "For example: ['up', 'http_requests_total']"
                    ),
                },
            },
            "required": [],
        },
    )

    get_label_values_tool = Tool(
        name="get_label_values",
        description=(
            "Get all values for a specific label from AWS Managed Prometheus. "
            "Use to discover possible filter values for a specific label."
        ),
        inputSchema={
            "type": "object",
            "properties": {
                "label_name": {
                    "type": "string",
                    "description": (
                        "The label name to get values for (e.g., 'job', 'instance', 'namespace')"
                    ),
                },
                "match": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": (
                        "Optional list of series selectors to filter which series to consider"
                    ),
                },
            },
            "required": ["label_name"],
        },
    )

    list_metrics_tool = Tool(
        name="list_metrics",
        description=(
            "Get all metric names from AWS Managed Prometheus. "
            "Optionally includes metadata like metric type, help text, and unit."
        ),
        inputSchema={
            "type": "object",
            "properties": {
                "with_metadata": {
                    "type": "boolean",
                    "description": (
                        "If True, fetches full metadata (type, help, unit) for each metric. "
                        "This is slower but provides more information. Defaults to False."
                    ),
                    "default": False,
                },
            },
            "required": [],
        },
    )

    return [
        query_instant_tool,
        query_range_tool,
        list_labels_tool,
        get_label_values_tool,
        list_metrics_tool,
    ]
@server.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
    """Dispatch an MCP tool call to the matching PromQL helper.

    Args:
        name: Tool name, one of the names advertised by list_tools().
        arguments: Raw argument dict from the client; required keys are
            accessed directly (a missing key raises KeyError, which is
            reported back as an error message), optional keys via .get().

    Returns:
        A single-element list of TextContent: the tool result serialized
        as JSON on success, or an error/unknown-tool message on failure.
        Errors are reported as text rather than raised, so the MCP session
        stays alive after a bad query.
    """
    # Dispatch table instead of an if/elif chain: each entry maps a tool
    # name to a callable that extracts its arguments and returns a coroutine.
    handlers = {
        "query_instant": lambda a: query_instant(
            query=a["query"],
            time=a.get("time"),
        ),
        "query_range": lambda a: query_range(
            query=a["query"],
            start=a["start"],
            end=a["end"],
            step=a.get("step", "1m"),
        ),
        "list_labels": lambda a: list_labels(
            match=a.get("match"),
        ),
        "get_label_values": lambda a: get_label_values(
            label_name=a["label_name"],
            match=a.get("match"),
        ),
        "list_metrics": lambda a: list_metrics(
            with_metadata=a.get("with_metadata", False),
        ),
    }
    handler = handlers.get(name)
    if handler is None:
        return [TextContent(type="text", text=f"Unknown tool: {name}")]
    try:
        result = await handler(arguments)
        # Convert Pydantic model to JSON string
        return [TextContent(type="text", text=result.model_dump_json(indent=2))]
    except Exception as e:
        # Lazy %-args; the traceback is attached automatically by exception().
        logger.exception("Error executing tool %s", name)
        return [TextContent(type="text", text=f"Error: {str(e)}")]
# SSE transport for HTTP-based MCP
# The "/messages" endpoint passed here must match the POST Route registered
# on the Starlette app below; clients stream events from /sse and post
# replies to /messages.
sse = SseServerTransport("/messages")
async def handle_sse(request):
    """SSE endpoint: bridge one HTTP connection to the MCP server.

    Opens an SSE session on the shared transport and runs the MCP server
    loop over the resulting read/write streams until the client disconnects.
    NOTE(review): `request._send` is a private Starlette attribute that the
    SSE transport requires at this SDK version — fragile across upgrades.
    """
    async with sse.connect_sse(
        request.scope, request.receive, request._send
    ) as (read_stream, write_stream):
        init_options = server.create_initialization_options()
        await server.run(read_stream, write_stream, init_options)
async def handle_messages(request):
    """Handle POST messages for MCP.

    Thin passthrough: hands the raw ASGI scope/receive/send to the shared
    SSE transport, which processes the posted client message.
    NOTE(review): uses Starlette's private `request._send`, required by the
    SseServerTransport API at this SDK version — fragile across upgrades.
    """
    await sse.handle_post_message(request.scope, request.receive, request._send)
async def health(request):
    """Health check endpoint (e.g., for Kubernetes liveness/readiness probes)."""
    payload = {"status": "healthy", "service": "prometheus-mcp"}
    return JSONResponse(payload)
# Create Starlette app for HTTP server with CORS support
# Routes: /health for probes, /sse for the MCP event stream, /messages for
# client-to-server POSTs (path must match the SseServerTransport above).
# The AMP client is initialized once on startup via on_startup.
app = Starlette(
    debug=False,
    routes=[
        Route("/health", health, methods=["GET"]),
        Route("/sse", handle_sse, methods=["GET"]),
        Route("/messages", handle_messages, methods=["POST"]),
    ],
    middleware=[
        Middleware(
            CORSMiddleware,
            # NOTE(review): wildcard origins combined with
            # allow_credentials=True is disallowed by the CORS spec for
            # credentialed requests — confirm this is intentional and that
            # the server is not exposed beyond trusted networks.
            allow_origins=["*"],  # Allow all origins for MCP Inspector
            allow_credentials=True,
            allow_methods=["*"],  # Allow all methods including OPTIONS
            allow_headers=["*"],  # Allow all headers
        ),
    ],
    on_startup=[initialize_amp_client],
)
async def run_stdio_server() -> None:
    """Serve MCP over stdio (used for local testing / direct CLI clients).

    Initializes the AMP client eagerly (the HTTP path does this via the
    Starlette on_startup hook instead), then runs the server loop over the
    stdio stream pair until EOF.
    """
    initialize_amp_client()
    async with stdio_server() as (read_stream, write_stream):
        init_options = server.create_initialization_options()
        await server.run(read_stream, write_stream, init_options)
def main() -> None:
    """Entry point: pick the transport from the environment and run it.

    Environment:
        MCP_MODE: "stdio" for local testing; anything else (default "http")
            runs the Starlette app under uvicorn for Kubernetes deployment.
        MCP_HOST / MCP_PORT: HTTP bind address (default 0.0.0.0:8080).
    """
    # Check if we should run in HTTP mode (for Kubernetes) or stdio mode (for local)
    mode = os.environ.get("MCP_MODE", "http").lower()
    host = os.environ.get("MCP_HOST", "0.0.0.0")
    port = int(os.environ.get("MCP_PORT", "8080"))
    if mode == "stdio":
        # Guard clause: stdio mode runs the asyncio loop directly and exits.
        logger.info("Starting MCP server in stdio mode...")
        try:
            asyncio.run(run_stdio_server())
        except KeyboardInterrupt:
            logger.info("Server stopped by user")
            sys.exit(0)
        return
    logger.info(f"Starting MCP server in HTTP mode on {host}:{port}...")
    uvicorn.run(app, host=host, port=port, log_level="info")
if __name__ == "__main__":
    main()