MCP Terminal Server

// Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 // This file implements production and development servers. // // The genkit CLI sends requests to the development server. // See js/common/src/reflectionApi.ts. // // The production server has a route for each flow. It // is intended for production deployments. package genkit import ( "context" "encoding/json" "errors" "fmt" "log/slog" "net/http" "os" "path/filepath" "strconv" "sync" "sync/atomic" "time" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal" "github.com/firebase/genkit/go/internal/action" "github.com/firebase/genkit/go/internal/base" "github.com/firebase/genkit/go/internal/registry" "go.opentelemetry.io/otel/trace" ) type runtimeFileData struct { ID string `json:"id"` PID int `json:"pid"` ReflectionServerURL string `json:"reflectionServerUrl"` Timestamp string `json:"timestamp"` GenkitVersion string `json:"genkitVersion"` ReflectionApiSpecVersion int `json:"reflectionApiSpecVersion"` } type devServer struct { reg *registry.Registry runtimeFilePath string } // startReflectionServer starts the Reflection API server listening at the // value of the environment variable GENKIT_REFLECTION_PORT for the port, // or ":3100" if it is empty. func startReflectionServer(ctx context.Context, r *registry.Registry, errCh chan<- error) *http.Server { slog.Debug("starting reflection server") addr := serverAddress("", "GENKIT_REFLECTION_PORT", "127.0.0.1:3100") s := &devServer{reg: r} if err := s.writeRuntimeFile(addr); err != nil { slog.Error("failed to write runtime file", "error", err) } mux := newDevServeMux(s) server := startServer(addr, mux, errCh) go func() { <-ctx.Done() if err := s.cleanupRuntimeFile(); err != nil { slog.Error("failed to cleanup runtime file", "error", err) } }() return server } // writeRuntimeFile writes a file describing the runtime to the project root. func (s *devServer) writeRuntimeFile(url string) error { projectRoot, err := findProjectRoot() if err != nil { return fmt.Errorf("failed to find project root: %w", err) } runtimesDir := filepath.Join(projectRoot, ".genkit", "runtimes") if err := os.MkdirAll(runtimesDir, 0755); err != nil { return fmt.Errorf("failed to create runtimes directory: %w", err) } runtimeID := os.Getenv("GENKIT_RUNTIME_ID") if runtimeID == "" { runtimeID = strconv.Itoa(os.Getpid()) } timestamp := time.Now().UTC().Format(time.RFC3339) s.runtimeFilePath = filepath.Join(runtimesDir, fmt.Sprintf("%d-%s.json", os.Getpid(), timestamp)) data := runtimeFileData{ ID: runtimeID, PID: os.Getpid(), ReflectionServerURL: fmt.Sprintf("http://%s", url), Timestamp: timestamp, GenkitVersion: "go/" + internal.Version, ReflectionApiSpecVersion: internal.GENKIT_REFLECTION_API_SPEC_VERSION, } fileContent, err := json.MarshalIndent(data, "", " ") if err != nil { return fmt.Errorf("failed to marshal runtime data: %w", err) } if err := os.WriteFile(s.runtimeFilePath, fileContent, 0644); err != nil { return fmt.Errorf("failed to write runtime file: %w", err) } slog.Debug("runtime file written", "path", s.runtimeFilePath) return nil } // cleanupRuntimeFile removes the runtime file associated with the dev server. func (s *devServer) cleanupRuntimeFile() error { if s.runtimeFilePath == "" { return nil } content, err := os.ReadFile(s.runtimeFilePath) if err != nil { return fmt.Errorf("failed to read runtime file: %w", err) } var data runtimeFileData if err := json.Unmarshal(content, &data); err != nil { return fmt.Errorf("failed to unmarshal runtime data: %w", err) } if data.PID == os.Getpid() { if err := os.Remove(s.runtimeFilePath); err != nil { return fmt.Errorf("failed to remove runtime file: %w", err) } slog.Debug("runtime file cleaned up", "path", s.runtimeFilePath) } return nil } // findProjectRoot finds the project root by looking for a go.mod file. func findProjectRoot() (string, error) { dir, err := os.Getwd() if err != nil { return "", err } for { if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { return dir, nil } parent := filepath.Dir(dir) if parent == dir { return "", fmt.Errorf("could not find project root (go.mod not found)") } dir = parent } } // startFlowServer starts a production server listening at the given address. // The Server has a route for each defined flow. // If addr is "", it uses the value of the environment variable PORT // for the port, and if that is empty it uses ":3400". // // To construct a server with additional routes, use [NewFlowServeMux]. func startFlowServer(g *Genkit, addr string, flows []string, errCh chan<- error) *http.Server { slog.Debug("starting flow server") addr = serverAddress(addr, "PORT", "127.0.0.1:3400") mux := NewFlowServeMux(g, flows) return startServer(addr, mux, errCh) } // flow is the type that all Flow[In, Out, Stream] have in common. type flow interface { Name() string // runJSON uses encoding/json to unmarshal the input, // calls Flow.start, then returns the marshaled result. runJSON(ctx context.Context, authHeader string, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) } // startServer starts an HTTP server listening on the address. // It returns the server an func startServer(addr string, handler http.Handler, errCh chan<- error) *http.Server { server := &http.Server{ Addr: addr, Handler: handler, } go func() { slog.Debug("server listening", "addr", addr) if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { errCh <- fmt.Errorf("server error on %s: %w", addr, err) } }() return server } // shutdownServers initiates shutdown of the servers and waits for the shutdown to complete. // After 5 seconds, it will timeout. func shutdownServers(servers []*http.Server) error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() var wg sync.WaitGroup for _, server := range servers { wg.Add(1) go func(srv *http.Server) { defer wg.Done() if err := srv.Shutdown(ctx); err != nil { slog.Error("server shutdown failed", "addr", srv.Addr, "err", err) } else { slog.Debug("server shutdown successfully", "addr", srv.Addr) } }(server) } done := make(chan struct{}) go func() { wg.Wait() close(done) }() select { case <-done: slog.Info("all servers shut down successfully") case <-ctx.Done(): return errors.New("server shutdown timed out") } return nil } func newDevServeMux(s *devServer) *http.ServeMux { mux := http.NewServeMux() handle(mux, "GET /api/__health", func(w http.ResponseWriter, _ *http.Request) error { return nil }) handle(mux, "POST /api/runAction", s.handleRunAction) handle(mux, "GET /api/actions", s.handleListActions) handle(mux, "POST /api/notify", s.handleNotify) return mux } // handleRunAction looks up an action by name in the registry, runs it with the // provided JSON input, and writes back the JSON-marshaled request. func (s *devServer) handleRunAction(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() var body struct { Key string `json:"key"` Input json.RawMessage `json:"input"` Context json.RawMessage `json:"context"` } defer r.Body.Close() if err := json.NewDecoder(r.Body).Decode(&body); err != nil { return &base.HTTPError{Code: http.StatusBadRequest, Err: err} } stream, err := parseBoolQueryParam(r, "stream") if err != nil { return err } logger.FromContext(ctx).Debug("running action", "key", body.Key, "stream", stream) var callback streamingCallback[json.RawMessage] if stream { w.Header().Set("Content-Type", "text/plain") w.Header().Set("Transfer-Encoding", "chunked") // Stream results are newline-separated JSON. callback = func(ctx context.Context, msg json.RawMessage) error { _, err := fmt.Fprintf(w, "%s\n", msg) if err != nil { return err } if f, ok := w.(http.Flusher); ok { f.Flush() } return nil } } var contextMap map[string]any = nil if body.Context != nil { json.Unmarshal(body.Context, &contextMap) } resp, err := runAction(ctx, s.reg, body.Key, body.Input, callback, contextMap) if err != nil { return err } return writeJSON(ctx, w, resp) } // handleNotify configures the telemetry server URL from the request. func (s *devServer) handleNotify(w http.ResponseWriter, r *http.Request) error { var body struct { TelemetryServerURL string `json:"telemetryServerUrl"` ReflectionApiSpecVersion int `json:"reflectionApiSpecVersion"` } defer r.Body.Close() if err := json.NewDecoder(r.Body).Decode(&body); err != nil { return &base.HTTPError{Code: http.StatusBadRequest, Err: err} } if body.TelemetryServerURL != "" { s.reg.TracingState().WriteTelemetryImmediate(tracing.NewHTTPTelemetryClient(body.TelemetryServerURL)) slog.Debug("connected to telemetry server", "url", body.TelemetryServerURL) } if body.ReflectionApiSpecVersion != internal.GENKIT_REFLECTION_API_SPEC_VERSION { slog.Error("Genkit CLI version is not compatible with runtime library. Please use `genkit-cli` version compatible with runtime library version.") } w.WriteHeader(http.StatusOK) _, err := w.Write([]byte("OK")) return err } type runActionResponse struct { Result json.RawMessage `json:"result"` Telemetry telemetry `json:"telemetry"` } type telemetry struct { TraceID string `json:"traceId"` } func runAction(ctx context.Context, reg *registry.Registry, key string, input json.RawMessage, cb streamingCallback[json.RawMessage], runtimeContext map[string]any) (*runActionResponse, error) { action := reg.LookupAction(key) if action == nil { return nil, &base.HTTPError{Code: http.StatusNotFound, Err: fmt.Errorf("no action with key %q", key)} } if runtimeContext != nil { ctx = core.WithActionContext(ctx, runtimeContext) } var traceID string output, err := tracing.RunInNewSpan(ctx, reg.TracingState(), "dev-run-action-wrapper", "", true, input, func(ctx context.Context, input json.RawMessage) (json.RawMessage, error) { tracing.SetCustomMetadataAttr(ctx, "genkit-dev-internal", "true") traceID = trace.SpanContextFromContext(ctx).TraceID().String() return action.RunJSON(ctx, input, cb) }) if err != nil { return nil, err } return &runActionResponse{ Result: output, Telemetry: telemetry{TraceID: traceID}, }, nil } // handleListActions lists all the registered actions. func (s *devServer) handleListActions(w http.ResponseWriter, r *http.Request) error { descs := s.reg.ListActions() descMap := map[string]action.Desc{} for _, d := range descs { descMap[d.Key] = d } return writeJSON(r.Context(), w, descMap) } // NewFlowServeMux constructs a [net/http.ServeMux]. // If flows is non-empty, the each of the named flows is registered as a route. // Otherwise, all defined flows are registered. // // All routes take a single query parameter, "stream", which if true will stream the // flow's results back to the client. (Not all flows support streaming, however.) // // To use the returned ServeMux as part of a server with other routes, either add routes // to it, or install it as part of another ServeMux, like so: // // mainMux := http.NewServeMux() // mainMux.Handle("POST /flow/", http.StripPrefix("/flow/", NewFlowServeMux())) func NewFlowServeMux(g *Genkit, flows []string) *http.ServeMux { return newFlowServeMux(g.reg, flows) } func newFlowServeMux(r *registry.Registry, flows []string) *http.ServeMux { mux := http.NewServeMux() m := map[string]bool{} for _, f := range flows { m[f] = true } for _, f := range r.ListFlows() { f := f.(flow) if len(flows) == 0 || m[f.Name()] { handle(mux, "POST /"+f.Name(), nonDurableFlowHandler(f)) } } return mux } func nonDurableFlowHandler(f flow) func(http.ResponseWriter, *http.Request) error { return func(w http.ResponseWriter, r *http.Request) error { var body struct { Data json.RawMessage `json:"data"` } defer r.Body.Close() if err := json.NewDecoder(r.Body).Decode(&body); err != nil { return &base.HTTPError{Code: http.StatusBadRequest, Err: err} } stream, err := parseBoolQueryParam(r, "stream") if err != nil { return err } var callback streamingCallback[json.RawMessage] if r.Header.Get("Accept") == "text/event-stream" || stream { w.Header().Set("Content-Type", "text/plain") w.Header().Set("Transfer-Encoding", "chunked") // Event Stream results are in JSON format separated by two newline escape sequences // including the `data` and `message` labels callback = func(ctx context.Context, msg json.RawMessage) error { _, err := fmt.Fprintf(w, "data: {\"message\": %s}\n\n", msg) if err != nil { return err } if f, ok := w.(http.Flusher); ok { f.Flush() } return nil } } // TODO: telemetry out, err := f.runJSON(r.Context(), r.Header.Get("Authorization"), body.Data, callback) if err != nil { if r.Header.Get("Accept") == "text/event-stream" || stream { _, err = fmt.Fprintf(w, "data: {\"error\": {\"status\": \"INTERNAL\", \"message\": \"stream flow error\", \"details\": \"%v\"}}\n\n", err) return err } return err } // Responses for streaming, non-durable flows should be prefixed // with "data" if r.Header.Get("Accept") == "text/event-stream" || stream { _, err = fmt.Fprintf(w, "data: {\"result\": %s}\n\n", out) return err } // Responses for non-streaming, non-durable flows are passed back // with the flow result stored in a field called "result." _, err = fmt.Fprintf(w, `{"result": %s}\n`, out) return err } } // serverAddress determines a server address. func serverAddress(arg, envVar, defaultValue string) string { if arg != "" { return arg } if port := os.Getenv(envVar); port != "" { return "127.0.0.1:" + port } return defaultValue } // requestID is a unique ID for each request. var requestID atomic.Int64 // handle registers pattern on mux with an http.Handler that calls f. // If f returns a non-nil error, the handler calls http.Error. // If the error is an httpError, the code it contains is used as the status code; // otherwise a 500 status is used. func handle(mux *http.ServeMux, pattern string, f func(w http.ResponseWriter, r *http.Request) error) { mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) { id := requestID.Add(1) // Create a logger that always outputs the requestID, and store it in the request context. log := slog.Default().With("reqID", id) log.Info("request start", "method", r.Method, "path", r.URL.Path) var err error defer func() { if err != nil { log.Error("request end", "err", err) } else { log.Info("request end") } }() err = f(w, r) if err != nil { // If the error is an httpError, serve the status code it contains. // Otherwise, assume this is an unexpected error and serve a 500. var herr *base.HTTPError if errors.As(err, &herr) { http.Error(w, herr.Error(), herr.Code) } else { http.Error(w, err.Error(), http.StatusInternalServerError) } } }) } func parseBoolQueryParam(r *http.Request, name string) (bool, error) { b := false if s := r.FormValue(name); s != "" { var err error b, err = strconv.ParseBool(s) if err != nil { return false, &base.HTTPError{Code: http.StatusBadRequest, Err: err} } } return b, nil } func writeJSON(ctx context.Context, w http.ResponseWriter, value any) error { data, err := json.Marshal(value) if err != nil { return err } _, err = w.Write(data) if err != nil { logger.FromContext(ctx).Error("writing output", "err", err) } if f, ok := w.(http.Flusher); ok { f.Flush() } return nil }