MCP Terminal Server

// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package ai import ( "context" "os" "testing" "github.com/firebase/genkit/go/internal/registry" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gopkg.in/yaml.v3" ) type specSuite struct { Tests []testCase `yaml:"tests"` } type testCase struct { Name string `yaml:"name"` Input *GenerateActionOptions `yaml:"input"` StreamChunks [][]*ModelResponseChunk `yaml:"streamChunks,omitempty"` ModelResponses []*ModelResponse `yaml:"modelResponses"` ExpectResponse *ModelResponse `yaml:"expectResponse,omitempty"` Stream bool `yaml:"stream,omitempty"` ExpectChunks []*ModelResponseChunk `yaml:"expectChunks,omitempty"` } type programmableModel struct { r *registry.Registry handleResp func(ctx context.Context, req *ModelRequest, cb func(context.Context, *ModelResponseChunk) error) (*ModelResponse, error) lastRequest *ModelRequest } func (pm *programmableModel) Name() string { return "programmableModel" } func (pm *programmableModel) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, toolCfg *ToolConfig, cb func(context.Context, *ModelResponseChunk) error) (*ModelResponse, error) { pm.lastRequest = req return pm.handleResp(ctx, req, cb) } func defineProgrammableModel(r *registry.Registry) *programmableModel { pm := &programmableModel{r: r} DefineModel(r, "default", "programmableModel", nil, func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) { return pm.Generate(ctx, r, req, &ToolConfig{MaxTurns: 5}, cb) }) return pm } func TestGenerateAction(t *testing.T) { data, err := os.ReadFile("../../tests/specs/generate.yaml") if err != nil { t.Fatalf("failed to read spec file: %v", err) } var suite specSuite if err := yaml.Unmarshal(data, &suite); err != nil { t.Fatalf("failed to parse spec file: %v", err) } for _, tc := range suite.Tests { t.Run(tc.Name, func(t *testing.T) { ctx := context.Background() r, err := registry.New() if err != nil { t.Fatalf("failed to create registry: %v", err) } pm := defineProgrammableModel(r) DefineTool(r, "testTool", "description", func(ctx *ToolContext, input any) (any, error) { return "tool called", nil }) if len(tc.ModelResponses) > 0 || len(tc.StreamChunks) > 0 { reqCounter := 0 pm.handleResp = func(ctx context.Context, req *ModelRequest, cb func(context.Context, *ModelResponseChunk) error) (*ModelResponse, error) { if len(tc.StreamChunks) > 0 && cb != nil { for _, chunk := range tc.StreamChunks[reqCounter] { if err := cb(ctx, chunk); err != nil { return nil, err } } } resp := tc.ModelResponses[reqCounter] resp.Request = req resp.Custom = map[string]any{} resp.Request.Output = &ModelRequestOutput{} resp.Usage = &GenerationUsage{} reqCounter++ return resp, nil } } genAction := DefineGenerateAction(ctx, r) if tc.Stream { chunks := []*ModelResponseChunk{} streamCb := func(ctx context.Context, chunk *ModelResponseChunk) error { chunks = append(chunks, chunk) return nil } resp, err := genAction.Run(ctx, tc.Input, streamCb) if err != nil { t.Fatalf("action failed: %v", err) } if diff := cmp.Diff(tc.ExpectChunks, chunks); diff != "" { t.Errorf("chunks mismatch (-want +got):\n%s", diff) } if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{cmpopts.EquateEmpty()}); diff != "" { t.Errorf("response mismatch (-want +got):\n%s", diff) } } else { resp, err := genAction.Run(ctx, tc.Input, nil) if err != nil { t.Fatalf("action failed: %v", err) } if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{cmpopts.EquateEmpty()}); diff != "" { t.Errorf("response mismatch (-want +got):\n%s", diff) } } }) } }