MCP Terminal Server

// Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 package genkit import ( "bytes" "context" "encoding/json" "io" "net/http" "net/http/httptest" "strings" "testing" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal/action" "github.com/firebase/genkit/go/internal/atype" "github.com/firebase/genkit/go/internal/registry" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/invopop/jsonschema" ) func inc(_ context.Context, x int) (int, error) { return x + 1, nil } func dec(_ context.Context, x int) (int, error) { return x - 1, nil } func TestDevServer(t *testing.T) { r, err := registry.New() if err != nil { t.Fatal(err) } tc := tracing.NewTestOnlyTelemetryClient() r.TracingState().WriteTelemetryImmediate(tc) core.DefineAction(r, "devServer", "inc", atype.Custom, map[string]any{ "foo": "bar", }, inc) core.DefineAction(r, "devServer", "dec", atype.Custom, map[string]any{ "bar": "baz", }, dec) srv := httptest.NewServer(newDevServeMux(&devServer{reg: r})) defer srv.Close() t.Run("runAction", func(t *testing.T) { body := `{"key": "/custom/devServer/inc", "input": 3}` res, err := http.Post(srv.URL+"/api/runAction", "application/json", strings.NewReader(body)) if err != nil { t.Fatal(err) } defer res.Body.Close() if res.StatusCode != 200 { t.Fatalf("got status %d, wanted 200", res.StatusCode) } got, err := readJSON[runActionResponse](res.Body) if err != nil { t.Fatal(err) } if g, w := string(got.Result), "4"; g != w { t.Errorf("got %q, want %q", g, w) } tid := got.Telemetry.TraceID if len(tid) != 32 { t.Errorf("trace ID is %q, wanted 32-byte string", tid) } checkActionTrace(t, tc, tid, "inc") }) t.Run("list actions", func(t *testing.T) { res, err := http.Get(srv.URL + "/api/actions") if err != nil { t.Fatal(err) } defer res.Body.Close() if res.StatusCode != 200 { t.Fatalf("got status %d, wanted 200", res.StatusCode) } got, err := readJSON[map[string]action.Desc](res.Body) if err != nil { t.Fatal(err) } want := map[string]action.Desc{ "/custom/devServer/inc": { Key: "/custom/devServer/inc", Name: "devServer/inc", InputSchema: &jsonschema.Schema{Type: "integer"}, OutputSchema: &jsonschema.Schema{Type: "integer"}, Metadata: map[string]any{"foo": "bar"}, }, "/custom/devServer/dec": { Key: "/custom/devServer/dec", InputSchema: &jsonschema.Schema{Type: "integer"}, OutputSchema: &jsonschema.Schema{Type: "integer"}, Name: "devServer/dec", Metadata: map[string]any{"bar": "baz"}, }, } diff := cmp.Diff(want, got, cmpopts.IgnoreUnexported(jsonschema.Schema{})) if diff != "" { t.Errorf("mismatch (-want, +got):\n%s", diff) } }) } func TestProdServer(t *testing.T) { r, err := registry.New() if err != nil { t.Fatal(err) } tc := tracing.NewTestOnlyTelemetryClient() r.TracingState().WriteTelemetryImmediate(tc) defineFlow(r, "inc", func(_ context.Context, i int, _ noStream) (int, error) { return i + 1, nil }) srv := httptest.NewServer(newFlowServeMux(r, nil)) defer srv.Close() check := func(t *testing.T, input string, wantStatus, wantResult int) { type body struct { Data json.RawMessage `json:"data"` } payload := body{ Data: json.RawMessage([]byte(input)), } jsonPayload, err := json.Marshal(payload) if err != nil { t.Fatal(err) } res, err := http.Post(srv.URL+"/inc", "application/json", bytes.NewBuffer(jsonPayload)) if err != nil { t.Fatal(err) } defer res.Body.Close() if g, w := res.StatusCode, wantStatus; g != w { t.Fatalf("status: got %d, want %d", g, w) } if res.StatusCode != 200 { return } type resultType struct { Result int } got, err := readJSON[resultType](res.Body) if err != nil { t.Fatal(err) } if g, w := got.Result, wantResult; g != w { t.Errorf("result: got %d, want %d", g, w) } } t.Run("ok", func(t *testing.T) { check(t, "2", 200, 3) }) t.Run("bad", func(t *testing.T) { check(t, "true", 400, 0) }) } func checkActionTrace(t *testing.T, tc *tracing.TestOnlyTelemetryClient, tid, name string) { td := tc.Traces[tid] if td == nil { t.Fatalf("trace %q not found", tid) } rootSpan := findRootSpan(t, td.Spans) want := &tracing.SpanData{ TraceID: tid, DisplayName: "dev-run-action-wrapper", SpanKind: "INTERNAL", SameProcessAsParentSpan: tracing.BoolValue{Value: true}, Status: tracing.Status{Code: 0}, InstrumentationLibrary: tracing.InstrumentationLibrary{ Name: "genkit-tracer", Version: "v1", }, Attributes: map[string]any{ "genkit:name": "dev-run-action-wrapper", "genkit:input": "3", "genkit:isRoot": true, "genkit:path": "/dev-run-action-wrapper", "genkit:output": "4", "genkit:metadata:genkit-dev-internal": "true", "genkit:state": "success", }, } diff := cmp.Diff(want, rootSpan, cmpopts.IgnoreFields(tracing.SpanData{}, "SpanID", "StartTime", "EndTime")) if diff != "" { t.Errorf("mismatch (-want, +got):\n%s", diff) } } // findRootSpan finds the root span in spans. // It also verifies that it is unique. func findRootSpan(t *testing.T, spans map[string]*tracing.SpanData) *tracing.SpanData { t.Helper() var root *tracing.SpanData for _, sd := range spans { if sd.ParentSpanID == "" { if root != nil { t.Fatal("more than one root span") } if g, w := sd.Attributes["genkit:isRoot"], true; g != w { t.Errorf("root span genkit:isRoot attr = %v, want %v", g, w) } root = sd } } if root == nil { t.Fatal("no root span") } return root } func readJSON[T any](r io.Reader) (T, error) { var x T if err := json.NewDecoder(r).Decode(&x); err != nil { return x, err } return x, nil }