main.go•4.52 kB
// Package main provides a server that exposes Prometheus query capabilities via the Model Context Protocol (MCP).
package main
import (
"context"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"text/tabwriter"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/yshngg/pmcp/internal/bindingblocks"
"github.com/yshngg/pmcp/internal/prometheus/api"
"github.com/yshngg/pmcp/internal/version"
)
// Schema names the Prometheus schema; its value is the short identifier "prom".
const (
	Schema = "prom"
)
// main starts the MCP server with Prometheus query capabilities,
// selecting the transport mechanism (stdio, HTTP, or SSE) based on command-line flags.
// It initializes the Prometheus client, binds query handlers,
// and serves requests until termination.
//
// The program exits with status 0 after printing version information, with
// status 2 when -transport names an unknown transport, and with status 1 on
// any other critical error (flag parsing, client construction, serving).
func main() {
	fs := flag.NewFlagSet("pmcp", flag.ExitOnError)
	var (
		// promAddr is the address of the Prometheus server to connect to.
		promAddr = fs.String("prom-addr", "http://localhost:9090/", "The address of the Prometheus to connect to.")
		// mcpAddr is the address for the MCP server to listen on.
		mcpAddr = fs.String("mcp-addr", "localhost:8080", "The address of the MCP server to listen on.")
		// transportType specifies the transport mechanism (stdio, sse, or http).
		transportType = fs.String("transport", "stdio", "Transport type (stdio, sse or http).\nThe mechanisms that handle the underlying communication between clients and servers.")
		// printVersion prints the version and exit.
		printVersion = fs.Bool("version", false, "Print the version and exit.")
	)
	fs.Usage = usageFor(fs, "pmcp [flags]")
	// Parse command-line flags. With flag.ExitOnError, Parse exits by itself on
	// bad input; the error check is kept as a safeguard.
	if err := fs.Parse(os.Args[1:]); err != nil {
		slog.Error("parse args", "err", err)
		os.Exit(1)
	}
	if *printVersion {
		fmt.Println(version.Info)
		os.Exit(0)
	}

	// Validate the transport before doing any work, so a typo (e.g. "htp") is
	// reported instead of silently falling back to stdio.
	switch *transportType {
	case "stdio", "http", "sse":
	default:
		slog.Error("unknown transport type", "transport", *transportType)
		fs.Usage()
		os.Exit(2)
	}

	server := mcp.NewServer(&mcp.Implementation{
		Name:    "pmcp",
		Version: string(version.Info.Number),
	}, nil)
	promCli, err := api.New(*promAddr, http.DefaultClient, nil)
	if err != nil {
		slog.Error("new prometheus client", "err", err)
		os.Exit(1)
	}
	binder := bindingblocks.NewBinder(server, promCli)
	binder.Bind()

	// Every HTTP request is served by the same MCP server instance.
	getServer := func(*http.Request) *mcp.Server { return server }

	switch *transportType {
	case "http":
		// Run the server over Streamable HTTP.
		handler := mcp.NewStreamableHTTPHandler(getServer, nil)
		if err := serveMCPOverHTTP(*mcpAddr, handler); err != nil {
			slog.Error("listen and serve with Streamable HTTP transport", "err", err)
			os.Exit(1)
		}
	case "sse":
		// Backwards compatibility with the deprecated HTTP+SSE transport.
		slog.Warn("HTTP+SSE transport is deprecated. Please use Streamable HTTP instead.")
		if err := serveMCPOverHTTP(*mcpAddr, mcp.NewSSEHandler(getServer)); err != nil {
			slog.Error("listen and serve with HTTP+SSE transport", "err", err)
			os.Exit(1)
		}
	default: // "stdio", already validated above
		// Run the server over stdin/stdout, until the client disconnects.
		slog.Info("Listening on stdio")
		if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil {
			slog.Error("run server with stdio transport", "err", err)
			os.Exit(1)
		}
	}
}

// serveMCPOverHTTP listens on addr and serves mcpHandler at /mcp plus a
// liveness endpoint at /ping that answers "pong". It blocks until the
// listener fails.
//
// Routes are registered on a dedicated ServeMux rather than
// http.DefaultServeMux, so no other code in the process can implicitly add
// handlers to this server. Like http.ListenAndServe, it always returns a
// non-nil error.
func serveMCPOverHTTP(addr string, mcpHandler http.Handler) error {
	mux := http.NewServeMux()
	mux.HandleFunc("/ping", func(w http.ResponseWriter, _ *http.Request) {
		// A failed write only concerns this one client (e.g. it hung up);
		// log it instead of taking the whole server down.
		if _, err := w.Write([]byte("pong")); err != nil {
			slog.Error("write pong", "err", err)
		}
	})
	mux.Handle("/mcp", mcpHandler)
	slog.Info("Listening on http://" + addr)
	return http.ListenAndServe(addr, mux)
}
func usageFor(fs *flag.FlagSet, short string) func() {
return func() {
fmt.Fprintf(os.Stderr, "Prometheus Model Context Protocol Server\n\n")
fmt.Fprintf(os.Stderr, "USAGE\n")
fmt.Fprintf(os.Stderr, " %s\n", short)
fmt.Fprintf(os.Stderr, "\n")
fmt.Fprintf(os.Stderr, "FLAGS\n")
w := tabwriter.NewWriter(os.Stderr, 0, 2, 2, ' ', 0)
fs.VisitAll(func(f *flag.Flag) {
def := f.DefValue
if def == "" {
def = "..."
}
_, err := fmt.Fprintf(w, "\t-%s %s\t%s\n", f.Name, def, f.Usage)
if err != nil {
panic(err)
}
})
if err := w.Flush(); err != nil {
panic(err)
}
fmt.Fprintf(os.Stderr, "\n")
fmt.Fprintf(os.Stderr, "VERSION\n")
fmt.Fprintf(os.Stderr, " %s\n", version.Info.Number)
fmt.Fprintf(os.Stderr, "\n")
}
}