// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
package ai
import (
"context"
"fmt"
"math"
"strings"
"testing"
"github.com/firebase/genkit/go/internal/registry"
test_utils "github.com/firebase/genkit/go/tests/utils"
"github.com/google/go-cmp/cmp"
)
// GameCharacter is a struct used to test structured output.
type GameCharacter struct {
Name string
Backstory string
}
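// r is the shared registry that the models and tools in this file are defined against.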
var r, _ = registry.New()
// echoModel attributes
var (
modelName = "echo"
metadata = ModelInfo{
Label: modelName,
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
SystemRole: true,
Media: false,
},
Versions: []string{"echo-001", "echo-002"},
}
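// echoModel emits a single "stream!" chunk when a streaming callback is provided and echoes the concatenated text of all user messages back in its response.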
echoModel = DefineModel(r, "test", modelName, &metadata, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
if msc != nil {
msc(ctx, &ModelResponseChunk{
Content: []*Part{NewTextPart("stream!")},
})
}
textResponse := ""
for _, m := range gr.Messages {
if m.Role == RoleUser {
textResponse += m.Content[0].Text
}
}
return &ModelResponse{
Request: gr,
Message: NewUserTextMessage(textResponse),
}, nil
})
)
// gablorkenTool computes Value raised to the power Over; it is used by the tool-calling tests below.
var gablorkenTool = DefineTool(r, "gablorken", "use when need to calculate a gablorken",
func(ctx *ToolContext, input struct {
Value float64
Over float64
},
) (float64, error) {
return math.Pow(input.Value, input.Over), nil
},
)
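// TestValidMessage covers validMessage across text and JSON output formats, including schema mismatches, malformed JSON, and missing or empty messages.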
func TestValidMessage(t *testing.T) {
t.Parallel()
t.Run("Valid message with text format", func(t *testing.T) {
message := &Message{
Content: []*Part{
NewTextPart("Hello, World!"),
},
}
outputSchema := &ModelRequestOutput{
Format: OutputFormatText,
}
_, err := validMessage(message, outputSchema)
if err != nil {
t.Fatal(err)
}
})
t.Run("Valid message with JSON format and matching schema", func(t *testing.T) {
json := `{
"name": "John",
"age": 30,
"address": {
"street": "123 Main St",
"city": "New York",
"country": "USA"
}
}`
message := &Message{
Content: []*Part{
NewTextPart(JSONMarkdown(json)),
},
}
outputSchema := &ModelRequestOutput{
Format: OutputFormatJSON,
Schema: map[string]any{
"type": "object",
"required": []string{"name", "age", "address"},
"properties": map[string]any{
"name": map[string]any{"type": "string"},
"age": map[string]any{"type": "integer"},
"address": map[string]any{
"type": "object",
"required": []string{"street", "city", "country"},
"properties": map[string]any{
"street": map[string]any{"type": "string"},
"city": map[string]any{"type": "string"},
"country": map[string]any{"type": "string"},
},
},
"phone": map[string]any{"type": "string"},
},
},
}
message, err := validMessage(message, outputSchema)
if err != nil {
t.Fatal(err)
}
text := message.Text()
if strings.TrimSpace(text) != strings.TrimSpace(json) {
t.Fatalf("got %q, want %q", text, json)
}
})
t.Run("Invalid message with JSON format and non-matching schema", func(t *testing.T) {
message := &Message{
Content: []*Part{
NewTextPart(JSONMarkdown(`{"name": "John", "age": "30"}`)),
},
}
outputSchema := &ModelRequestOutput{
Format: OutputFormatJSON,
Schema: map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{"type": "string"},
"age": map[string]any{"type": "integer"},
},
},
}
_, err := validMessage(message, outputSchema)
errorContains(t, err, "data did not match expected schema")
})
t.Run("Message with invalid JSON", func(t *testing.T) {
message := &Message{
Content: []*Part{
NewTextPart(JSONMarkdown(`{"name": "John", "age": 30`)), // Missing trailing }.
},
}
outputSchema := &ModelRequestOutput{
Format: OutputFormatJSON,
}
_, err := validMessage(message, outputSchema)
errorContains(t, err, "data is not valid JSON")
})
t.Run("No message", func(t *testing.T) {
outputSchema := &ModelRequestOutput{
Format: OutputFormatJSON,
}
_, err := validMessage(nil, outputSchema)
errorContains(t, err, "message is empty")
})
t.Run("Empty message", func(t *testing.T) {
message := &Message{}
outputSchema := &ModelRequestOutput{
Format: OutputFormatJSON,
}
_, err := validMessage(message, outputSchema)
errorContains(t, err, "message has no content")
})
t.Run("Candidate contains unexpected field", func(t *testing.T) {
message := &Message{
Content: []*Part{
NewTextPart(JSONMarkdown(`{"name": "John", "height": 190}`)),
},
}
outputSchema := &ModelRequestOutput{
Format: OutputFormatJSON,
Schema: map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{"type": "string"},
"age": map[string]any{"type": "integer"},
},
"additionalProperties": false,
},
}
_, err := validMessage(message, outputSchema)
errorContains(t, err, "data did not match expected schema")
})
t.Run("Invalid expected schema", func(t *testing.T) {
message := &Message{
Content: []*Part{
NewTextPart(JSONMarkdown(`{"name": "John", "age": 30}`)),
},
}
outputSchema := &ModelRequestOutput{
Format: OutputFormatJSON,
Schema: map[string]any{
"type": "invalid",
},
}
_, err := validMessage(message, outputSchema)
errorContains(t, err, "failed to validate data against expected schema")
})
}
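// TestGenerate covers request construction from Generate options as well as tool handling: interrupts, parallel tool calls, multiple rounds of calls, and the maximum-turn limit.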
func TestGenerate(t *testing.T) {
t.Run("constructs request", func(t *testing.T) {
charJSON := "{\"Name\": \"foo\", \"Backstory\": \"bar\"}"
charJSONmd := "```json" + charJSON + "```"
wantText := charJSON
wantRequest := &ModelRequest{
Messages: []*Message{
// system prompt -- always first
{
Role: RoleSystem,
Content: []*Part{{ContentType: "plain/text", Text: "you are"}},
},
// then history
{
Role: "user",
Content: []*Part{
{ContentType: "plain/text", Text: "banana"},
},
},
{
Role: "model",
Content: []*Part{
{ContentType: "plain/text", Text: "yes, banana"},
},
},
// then messages in order specified
{
Role: "user",
Content: []*Part{
{ContentType: "plain/text", Text: charJSONmd},
},
},
{
Role: "model",
Content: []*Part{
{ContentType: "plain/text", Text: "banana again"},
// structured output prompt
{
ContentType: "plain/text",
Text: "!!Ignored!!", // structured output prompt, noisy, ignored
},
},
},
},
Config: GenerationCommonConfig{Temperature: 1},
Context: []any{[]any{string("Banana")}},
Output: &ModelRequestOutput{
Format: "json",
Schema: map[string]any{
"$id": string("https://github.com/firebase/genkit/go/ai/game-character"),
"additionalProperties": bool(false),
"properties": map[string]any{
"Backstory": map[string]any{"type": string("string")},
"Name": map[string]any{"type": string("string")},
},
"required": []any{string("Name"), string("Backstory")},
"type": string("object"),
},
},
Tools: []*ToolDefinition{
{
Description: "use when need to calculate a gablorken",
InputSchema: map[string]any{
"additionalProperties": bool(false),
"properties": map[string]any{
"Over": map[string]any{"type": string("number")},
"Value": map[string]any{"type": string("number")},
},
"required": []any{
string("Value"),
string("Over"),
},
"type": string("object"),
},
Name: "gablorken",
OutputSchema: map[string]any{"type": string("number")},
},
},
}
wantStreamText := "stream!"
streamText := ""
res, err := Generate(context.Background(), r,
WithModel(echoModel),
WithTextPrompt(charJSONmd),
WithMessages(NewModelTextMessage("banana again")),
WithSystemPrompt("you are"),
WithConfig(GenerationCommonConfig{
Temperature: 1,
}),
WithHistory(NewUserTextMessage("banana"), NewModelTextMessage("yes, banana")),
WithContext([]any{"Banana"}),
WithOutputSchema(&GameCharacter{}),
WithTools(gablorkenTool),
WithStreaming(func(ctx context.Context, grc *ModelResponseChunk) error {
streamText += grc.Text()
return nil
}),
)
if err != nil {
t.Error(err)
}
gotText := res.Text()
if diff := cmp.Diff(gotText, wantText); diff != "" {
t.Errorf("Text() diff (-got +want):\n%s", diff)
}
if diff := cmp.Diff(streamText, wantStreamText); diff != "" {
t.Errorf("streaming text diff (-got +want):\n%s", diff)
}
if diff := cmp.Diff(res.Request, wantRequest, test_utils.IgnoreNoisyParts([]string{
"{*ai.ModelRequest}.Messages[4].Content[1].Text",
})); diff != "" {
t.Errorf("Request diff (+got -want):\n%s", diff)
}
})
t.Run("handles tool interrupts", func(t *testing.T) {
interruptTool := DefineTool(r, "interruptor", "always interrupts",
func(ctx *ToolContext, input any) (any, error) {
return nil, ctx.Interrupt(&InterruptOptions{
Metadata: map[string]any{
"reason": "test interrupt",
},
})
},
)
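// interruptModel always responds with a single request to the interruptor tool.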
interruptModel := DefineModel(r, "test", "interrupt", nil,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{
NewToolRequestPart(&ToolRequest{
Name: "interruptor",
Input: nil,
}),
},
},
}, nil
})
res, err := Generate(context.Background(), r,
WithModel(interruptModel),
WithTextPrompt("trigger interrupt"),
WithTools(interruptTool),
)
if err != nil {
t.Fatal(err)
}
if res.FinishReason != "interrupted" {
t.Errorf("expected finish reason 'interrupted', got %q", res.FinishReason)
}
if res.FinishMessage != "One or more tool calls resulted in interrupts." {
t.Errorf("unexpected finish message: %q", res.FinishMessage)
}
if len(res.Message.Content) != 1 {
t.Fatalf("expected 1 content part, got %d", len(res.Message.Content))
}
metadata := res.Message.Content[0].Metadata
if metadata == nil {
t.Fatal("expected metadata in content part")
}
interrupt, ok := metadata["interrupt"].(map[string]any)
if !ok {
t.Fatal("expected interrupt metadata")
}
reason, ok := interrupt["reason"].(string)
if !ok || reason != "test interrupt" {
t.Errorf("expected interrupt reason 'test interrupt', got %v", reason)
}
})
t.Run("handles multiple parallel tool calls", func(t *testing.T) {
roundCount := 0
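// parallelModel requests two gablorken calls in its first round, then sums the tool outputs (2^3 + 3^2 = 17) in the second.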
parallelModel := DefineModel(r, "test", "parallel", nil,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
roundCount++
if roundCount == 1 {
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{
NewToolRequestPart(&ToolRequest{
Name: "gablorken",
Input: map[string]any{"Value": 2, "Over": 3},
}),
NewToolRequestPart(&ToolRequest{
Name: "gablorken",
Input: map[string]any{"Value": 3, "Over": 2},
}),
},
},
}, nil
}
var sum float64
for _, msg := range gr.Messages {
if msg.Role == RoleTool {
for _, part := range msg.Content {
if part.ToolResponse != nil {
sum += part.ToolResponse.Output.(float64)
}
}
}
}
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{
NewTextPart(fmt.Sprintf("Final result: %d", int(sum))),
},
},
}, nil
})
res, err := Generate(context.Background(), r,
WithModel(parallelModel),
WithTextPrompt("trigger parallel tools"),
WithTools(gablorkenTool),
)
if err != nil {
t.Fatal(err)
}
finalPart := res.Message.Content[0]
if finalPart.Text != "Final result: 17" {
t.Errorf("expected final result text to be 'Final result: 17', got %q", finalPart.Text)
}
})
t.Run("handles multiple rounds of tool calls", func(t *testing.T) {
roundCount := 0
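// multiRoundModel issues one gablorken request in each of the first two rounds, then returns a final text message.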
multiRoundModel := DefineModel(r, "test", "multiround", nil,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
roundCount++
if roundCount == 1 {
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{
NewToolRequestPart(&ToolRequest{
Name: "gablorken",
Input: map[string]any{"Value": 2, "Over": 3},
}),
},
},
}, nil
}
if roundCount == 2 {
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{
NewToolRequestPart(&ToolRequest{
Name: "gablorken",
Input: map[string]any{"Value": 3, "Over": 2},
}),
},
},
}, nil
}
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{
NewTextPart("Final result"),
},
},
}, nil
})
res, err := Generate(context.Background(), r,
WithModel(multiRoundModel),
WithTextPrompt("trigger multiple rounds"),
WithTools(gablorkenTool),
WithMaxTurns(2),
)
if err != nil {
t.Fatal(err)
}
if roundCount != 3 {
t.Errorf("expected 3 rounds, got %d", roundCount)
}
if res.Text() != "Final result" {
t.Errorf("expected final message 'Final result', got %q", res.Text())
}
})
t.Run("exceeds maximum turns", func(t *testing.T) {
infiniteModel := DefineModel(r, "test", "infinite", nil,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{
NewToolRequestPart(&ToolRequest{
Name: "gablorken",
Input: map[string]any{"Value": 2, "Over": 2},
}),
},
},
}, nil
})
_, err := Generate(context.Background(), r,
WithModel(infiniteModel),
WithTextPrompt("trigger infinite loop"),
WithTools(gablorkenTool),
WithMaxTurns(2),
)
if err == nil {
t.Fatal("expected error for exceeding maximum turns")
}
if !strings.Contains(err.Error(), "exceeded maximum tool call iterations (2)") {
t.Errorf("unexpected error message: %v", err)
}
})
}
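// TestModelVersion checks that the Version set in GenerationCommonConfig is validated against the versions declared in the model's ModelInfo.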
func TestModelVersion(t *testing.T) {
t.Run("valid version", func(t *testing.T) {
_, err := Generate(context.Background(), r,
WithModel(echoModel),
WithConfig(&GenerationCommonConfig{
Temperature: 1,
Version: "echo-001",
}),
WithTextPrompt("tell a joke about batman"))
if err != nil {
t.Errorf("model version should be valid")
}
})
t.Run("invalid version", func(t *testing.T) {
_, err := Generate(context.Background(), r,
WithModel(echoModel),
WithConfig(&GenerationCommonConfig{
Temperature: 1,
Version: "echo-im-not-a-version",
}),
WithTextPrompt("tell a joke about batman"))
if err == nil {
t.Errorf("model version should be invalid: %v", err)
}
})
}
func TestIsDefinedModel(t *testing.T) {
t.Run("should return true", func(t *testing.T) {
if IsDefinedModel(r, "test", modelName) != true {
t.Errorf("IsDefinedModel did not return true")
}
})
t.Run("should return false", func(t *testing.T) {
if IsDefinedModel(r, "foo", "bar") != false {
t.Errorf("IsDefinedModel did not return false")
}
})
}
func TestLookupModel(t *testing.T) {
t.Run("should return model", func(t *testing.T) {
if LookupModel(r, "test", modelName) == nil {
t.Errorf("LookupModel did not return model")
}
})
t.Run("should return nil", func(t *testing.T) {
if LookupModel(r, "foo", "bar") != nil {
t.Errorf("LookupModel did not return nil")
}
})
}
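// JSONMarkdown wraps text in a fenced ```json markdown block.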
func JSONMarkdown(text string) string {
return "```json\n" + text + "\n```"
}
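// errorContains fails the test unless err is non-nil and its message contains want.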
func errorContains(t *testing.T, err error, want string) {
t.Helper()
if err == nil {
t.Error("got nil, want error")
} else if !strings.Contains(err.Error(), want) {
t.Errorf("got error message %q, want it to contain %q", err, want)
}
}