generate_test.go•34.1 kB
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ai
import (
	"context"
	"fmt"
	"math"
	"strings"
	"testing"
	"github.com/firebase/genkit/go/internal/registry"
	test_utils "github.com/firebase/genkit/go/tests/utils"
	"github.com/google/go-cmp/cmp"
)
// StructuredResponse is a sample structured-output value pairing a subject
// with the location associated with it.
type StructuredResponse struct {
	Subject, Location string
}
// r is the registry shared by most tests in this file; the package-level
// model and tool fixtures are defined against it.
var r = registry.New()

// init prepares r for use by Generate: it installs the default output
// formats and registers the generate action that Generate() expects.
func init() {
	ctx := context.Background()
	ConfigureFormats(r)
	DefineGenerateAction(ctx, r)
}
// Attributes of the "echo" test model. When a stream callback is supplied,
// echoModel emits a single "stream!" chunk; it then replies with the text of
// the last user-role message in the request (empty if there is none).
var (
	modelName = "echo"
	metadata  = ModelOptions{
		Label:    modelName,
		Versions: []string{"echo-001", "echo-002"},
		Stage:    ModelStageDeprecated,
		Supports: &ModelSupports{
			Multiturn:   true,
			Tools:       true,
			SystemRole:  true,
			Media:       false,
			Constrained: ConstrainedSupportNone,
		},
	}
	echoModel = DefineModel(r, "test/"+modelName, &metadata,
		func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
			if cb != nil {
				cb(ctx, &ModelResponseChunk{Content: []*Part{NewTextPart("stream!")}})
			}
			var reply string
			for _, msg := range req.Messages {
				if msg.Role != RoleUser {
					continue
				}
				// Later user messages overwrite earlier ones, so the last wins.
				reply = msg.Text()
			}
			return &ModelResponse{Request: req, Message: NewModelTextMessage(reply)}, nil
		})
)
// gablorkenTool is a package-level test tool that computes Value raised to
// the power Over. The input field order (Value, Over) is kept as-is because
// tests compare the generated input schema, including its "required" order.
var gablorkenTool = DefineTool(r, "gablorken", "use when need to calculate a gablorken",
	func(ctx *ToolContext, input struct {
		Value float64
		Over  float64
	}) (float64, error) {
		result := math.Pow(input.Value, input.Over)
		return result, nil
	},
)
// TestStreamingChunksHaveRoleAndIndex runs a one-round tool-calling generation
// with streaming enabled and checks the role and index of every streamed chunk:
//   - every chunk has a non-empty role;
//   - the first chunk has role "model" and index 0;
//   - at least one chunk has role "tool", and each such chunk has index 1.
func TestStreamingChunksHaveRoleAndIndex(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	convertTempTool := DefineTool(r, "convertTemp", "converts temperature",
		func(ctx *ToolContext, input struct {
			From        string
			To          string
			Temperature float64
		}) (float64, error) {
			if input.From == "celsius" && input.To == "fahrenheit" {
				return input.Temperature*9/5 + 32, nil
			}
			return input.Temperature, nil
		},
	)

	const finalAnswer = "20 degrees Celsius is 68 degrees Fahrenheit."
	// newConvertRequest builds a fresh tool-request part asking convertTemp to
	// convert 20 celsius to fahrenheit.
	newConvertRequest := func() *Part {
		return NewToolRequestPart(&ToolRequest{
			Name: "convertTemp",
			Input: map[string]any{
				"From":        "celsius",
				"To":          "fahrenheit",
				"Temperature": 20.0,
			},
			Ref: "0",
		})
	}

	// The model requests convertTemp until a tool-role message shows up in the
	// history, then answers in text. Errors returned by the stream callback are
	// propagated rather than silently dropped.
	toolModel := DefineModel(r, "test/toolModel", &metadata, func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
		hasToolResponse := false
		for _, msg := range gr.Messages {
			if msg.Role == RoleTool {
				hasToolResponse = true
				break
			}
		}
		if hasToolResponse {
			if msc != nil {
				if err := msc(ctx, &ModelResponseChunk{Content: []*Part{NewTextPart(finalAnswer)}}); err != nil {
					return nil, err
				}
			}
			return &ModelResponse{
				Request: gr,
				Message: NewModelTextMessage(finalAnswer),
			}, nil
		}
		if msc != nil {
			if err := msc(ctx, &ModelResponseChunk{Content: []*Part{newConvertRequest()}}); err != nil {
				return nil, err
			}
		}
		return &ModelResponse{
			Request: gr,
			Message: &Message{
				Role:    RoleModel,
				Content: []*Part{newConvertRequest()},
			},
		}, nil
	})

	var chunks []*ModelResponseChunk
	_, err := Generate(ctx, r,
		WithModel(toolModel),
		WithMessages(NewUserTextMessage("convert 20 c to f")),
		WithTools(convertTempTool),
		WithStreaming(func(ctx context.Context, chunk *ModelResponseChunk) error {
			chunks = append(chunks, chunk)
			return nil
		}),
	)
	if err != nil {
		t.Fatalf("Generate failed: %v", err)
	}
	if len(chunks) < 2 {
		t.Fatalf("Expected at least 2 chunks, got %d", len(chunks))
	}

	toolChunkFound := false
	for i, chunk := range chunks {
		t.Logf("Chunk %d: Role=%s, Index=%d", i, chunk.Role, chunk.Index)
		if chunk.Role == "" {
			t.Errorf("Chunk %d: Role is empty", i)
		}
		if chunk.Role == RoleTool {
			toolChunkFound = true
			if chunk.Index != 1 {
				t.Errorf("Expected tool chunk to have index 1, got %d", chunk.Index)
			}
		}
	}
	if chunks[0].Role != RoleModel {
		t.Errorf("Expected first chunk to have role 'model', got %s", chunks[0].Role)
	}
	if chunks[0].Index != 0 {
		t.Errorf("Expected first chunk to have index 0, got %d", chunks[0].Index)
	}
	if !toolChunkFound {
		t.Error("Expected to find at least one tool chunk")
	}
}
// TestValidMessage runs validTestMessage over a range of inputs:
//   - text-format and JSON-format messages that should parse cleanly;
//   - messages that violate the schema or contain malformed JSON;
//   - missing or empty messages;
//   - an invalid schema.
//
// In each failure case it checks that the returned error carries the
// expected message fragment.
func TestValidMessage(t *testing.T) {
	t.Parallel()
	t.Run("Valid message with text format", func(t *testing.T) {
		message := &Message{
			Content: []*Part{
				NewTextPart("Hello, World!"),
			},
		}
		outputSchema := &ModelOutputConfig{
			Format: OutputFormatText,
		}
		_, err := validTestMessage(message, outputSchema)
		if err != nil {
			t.Fatal(err)
		}
	})
	t.Run("Valid message with JSON format and matching schema", func(t *testing.T) {
		jsonText := `{
			"name": "John",
			"age": 30,
			"address": {
				"street": "123 Main St",
				"city": "New York",
				"country": "USA"
			}
		}`
		message := &Message{
			Content: []*Part{
				NewTextPart(JSONMarkdown(jsonText)),
			},
		}
		outputSchema := &ModelOutputConfig{
			Format: OutputFormatJSON,
			Schema: map[string]any{
				"type":     "object",
				"required": []string{"name", "age", "address"},
				"properties": map[string]any{
					"name": map[string]any{"type": "string"},
					"age":  map[string]any{"type": "integer"},
					"address": map[string]any{
						"type":     "object",
						"required": []string{"street", "city", "country"},
						"properties": map[string]any{
							"street":  map[string]any{"type": "string"},
							"city":    map[string]any{"type": "string"},
							"country": map[string]any{"type": "string"},
						},
					},
					"phone": map[string]any{"type": "string"},
				},
			},
		}
		parsed, err := validTestMessage(message, outputSchema)
		if err != nil {
			t.Fatal(err)
		}
		// The parsed text may differ from the input only in surrounding whitespace.
		got := parsed.Text()
		if strings.TrimSpace(got) != strings.TrimSpace(jsonText) {
			t.Fatalf("got %q, want %q", got, jsonText)
		}
	})
	t.Run("Invalid message with JSON format and non-matching schema", func(t *testing.T) {
		message := &Message{
			Content: []*Part{
				NewTextPart(JSONMarkdown(`{"name": "John", "age": "30"}`)),
			},
		}
		outputSchema := &ModelOutputConfig{
			Format: OutputFormatJSON,
			Schema: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"name": map[string]any{"type": "string"},
					"age":  map[string]any{"type": "integer"},
				},
			},
		}
		_, err := validTestMessage(message, outputSchema)
		errorContains(t, err, "data did not match expected schema")
	})
	t.Run("Message with invalid JSON", func(t *testing.T) {
		message := &Message{
			Content: []*Part{
				NewTextPart(JSONMarkdown(`{"name": "John", "age": 30`)), // Missing trailing }.
			},
		}
		outputSchema := &ModelOutputConfig{
			Format: OutputFormatJSON,
		}
		_, err := validTestMessage(message, outputSchema)
		errorContains(t, err, "not a valid JSON")
	})
	t.Run("No message", func(t *testing.T) {
		outputSchema := &ModelOutputConfig{
			Format: OutputFormatJSON,
		}
		_, err := validTestMessage(nil, outputSchema)
		errorContains(t, err, "message is empty")
	})
	t.Run("Empty message", func(t *testing.T) {
		message := &Message{}
		outputSchema := &ModelOutputConfig{
			Format: OutputFormatJSON,
		}
		_, err := validTestMessage(message, outputSchema)
		errorContains(t, err, "message has no content")
	})
	t.Run("Candidate contains unexpected field", func(t *testing.T) {
		message := &Message{
			Content: []*Part{
				NewTextPart(JSONMarkdown(`{"name": "John", "height": 190}`)),
			},
		}
		outputSchema := &ModelOutputConfig{
			Format: OutputFormatJSON,
			Schema: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"name": map[string]any{"type": "string"},
					"age":  map[string]any{"type": "integer"},
				},
				"additionalProperties": false,
			},
		}
		_, err := validTestMessage(message, outputSchema)
		errorContains(t, err, "data did not match expected schema")
	})
	t.Run("Invalid expected schema", func(t *testing.T) {
		message := &Message{
			Content: []*Part{
				NewTextPart(JSONMarkdown(`{"name": "John", "age": 30}`)),
			},
		}
		outputSchema := &ModelOutputConfig{
			Format: OutputFormatJSON,
			Schema: map[string]any{
				"type": "invalid",
			},
		}
		_, err := validTestMessage(message, outputSchema)
		errorContains(t, err, "failed to validate data against expected schema")
	})
}
// TestGenerate exercises Generate end to end against test models registered
// on the shared registry r. It covers:
//   - request construction (system prompt, history, docs, output type, tools);
//   - tool interrupts;
//   - parallel and multi-round tool calls, and the max-turns limit;
//   - model middleware;
//   - dynamically supplied (unregistered) tools and duplicate tool names.
func TestGenerate(t *testing.T) {
	JSON := "{\"subject\": \"bananas\", \"location\": \"tropics\"}"
	JSONmd := "```json" + JSON + "```"
	bananaModel := DefineModel(r, "test/banana", &metadata, func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
		if msc != nil {
			msc(ctx, &ModelResponseChunk{
				Content: []*Part{NewTextPart("stream!")},
			})
		}
		return &ModelResponse{
			Request: gr,
			Message: NewModelTextMessage(JSONmd),
		}, nil
	})
	t.Run("constructs request", func(t *testing.T) {
		wantText := JSON
		wantStreamText := "stream!"
		wantRequest := &ModelRequest{
			Messages: []*Message{
				{
					Role: RoleSystem,
					Content: []*Part{
						NewTextPart("You are a helpful assistant."),
						{
							ContentType: "plain/text",
							Text:        "ignored (conformance message)",
							Metadata:    map[string]any{"purpose": string("output")},
						},
					},
				},
				NewUserTextMessage("How many bananas are there?"),
				NewModelTextMessage("There are at least 10 bananas."),
				{
					Role: RoleUser,
					Content: []*Part{
						NewTextPart("Where can they be found?"),
						{
							Text: "\n\nUse the following information " +
								"to complete your task:\n\n- [0]: Bananas are plentiful in the tropics.\n\n",
							Metadata: map[string]any{"purpose": "context"},
						},
					},
				},
			},
			Config: &GenerationCommonConfig{Temperature: 1},
			Docs:   []*Document{DocumentFromText("Bananas are plentiful in the tropics.", nil)},
			Output: &ModelOutputConfig{
				Format:      OutputFormatJSON,
				ContentType: "application/json",
			},
			Tools: []*ToolDefinition{
				{
					Description: "use when need to calculate a gablorken",
					InputSchema: map[string]any{
						"additionalProperties": bool(false),
						"properties": map[string]any{
							"Over":  map[string]any{"type": string("number")},
							"Value": map[string]any{"type": string("number")},
						},
						"required": []any{string("Value"), string("Over")},
						"type":     string("object"),
					},
					Name:         "gablorken",
					OutputSchema: map[string]any{"type": string("number")},
				},
			},
			ToolChoice: ToolChoiceAuto,
		}
		streamText := ""
		res, err := Generate(context.Background(), r,
			WithModel(bananaModel),
			WithSystem("You are a helpful assistant."),
			WithMessages(
				NewUserTextMessage("How many bananas are there?"),
				NewModelTextMessage("There are at least 10 bananas."),
			),
			WithPrompt("Where can they be found?"),
			WithConfig(&GenerationCommonConfig{
				Temperature: 1,
			}),
			WithDocs(DocumentFromText("Bananas are plentiful in the tropics.", nil)),
			WithOutputType(struct {
				Subject  string `json:"subject"`
				Location string `json:"location"`
			}{}),
			WithTools(gablorkenTool),
			WithToolChoice(ToolChoiceAuto),
			WithStreaming(func(ctx context.Context, grc *ModelResponseChunk) error {
				streamText += grc.Text()
				return nil
			}),
		)
		if err != nil {
			t.Fatal(err)
		}
		// cmp.Diff(want, got) marks want with "-" and got with "+", matching
		// the "(+got -want)" labels below.
		gotText := res.Text()
		if diff := cmp.Diff(wantText, gotText); diff != "" {
			t.Errorf("Text() diff (+got -want):\n%s", diff)
		}
		if diff := cmp.Diff(wantStreamText, streamText); diff != "" {
			t.Errorf("stream text diff (+got -want):\n%s", diff)
		}
		if diff := cmp.Diff(wantRequest, res.Request, test_utils.IgnoreNoisyParts([]string{
			"{*ai.ModelRequest}.Messages[0].Content[1].Text", "{*ai.ModelRequest}.Messages[0].Content[1].Metadata",
		})); diff != "" {
			t.Errorf("Request diff (+got -want):\n%s", diff)
		}
	})
	t.Run("handles tool interrupts", func(t *testing.T) {
		interruptTool := DefineTool(r, "interruptor", "always interrupts",
			func(ctx *ToolContext, input any) (any, error) {
				return nil, ctx.Interrupt(&InterruptOptions{
					Metadata: map[string]any{
						"reason": "test interrupt",
					},
				})
			},
		)
		info := &ModelOptions{
			Supports: &ModelSupports{
				Multiturn: true,
				Tools:     true,
			},
		}
		interruptModel := DefineModel(r, "test/interrupt", info,
			func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
				return &ModelResponse{
					Request: gr,
					Message: &Message{
						Role: RoleModel,
						Content: []*Part{
							NewToolRequestPart(&ToolRequest{
								Name:  "interruptor",
								Input: nil,
							}),
						},
					},
				}, nil
			})
		res, err := Generate(context.Background(), r,
			WithModel(interruptModel),
			WithPrompt("trigger interrupt"),
			WithTools(interruptTool),
		)
		if err != nil {
			t.Fatal(err)
		}
		if res.FinishReason != "interrupted" {
			t.Errorf("expected finish reason 'interrupted', got %q", res.FinishReason)
		}
		if res.FinishMessage != "One or more tool calls resulted in interrupts." {
			t.Errorf("unexpected finish message: %q", res.FinishMessage)
		}
		if len(res.Message.Content) != 1 {
			t.Fatalf("expected 1 content part, got %d", len(res.Message.Content))
		}
		// Named partMeta to avoid shadowing the package-level metadata fixture.
		partMeta := res.Message.Content[0].Metadata
		if partMeta == nil {
			t.Fatal("expected metadata in content part")
		}
		interrupt, ok := partMeta["interrupt"].(map[string]any)
		if !ok {
			t.Fatal("expected interrupt metadata")
		}
		reason, ok := interrupt["reason"].(string)
		if !ok || reason != "test interrupt" {
			t.Errorf("expected interrupt reason 'test interrupt', got %v", reason)
		}
	})
	t.Run("handles multiple parallel tool calls", func(t *testing.T) {
		roundCount := 0
		info := &ModelOptions{
			Supports: &ModelSupports{
				Multiturn: true,
				Tools:     true,
			},
		}
		parallelModel := DefineModel(r, "test/parallel", info,
			func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
				roundCount++
				if roundCount == 1 {
					return &ModelResponse{
						Request: gr,
						Message: &Message{
							Role: RoleModel,
							Content: []*Part{
								NewToolRequestPart(&ToolRequest{
									Name:  "gablorken",
									Input: map[string]any{"Value": 2, "Over": 3},
								}),
								NewToolRequestPart(&ToolRequest{
									Name:  "gablorken",
									Input: map[string]any{"Value": 3, "Over": 2},
								}),
							},
						},
					}, nil
				}
				// Sum every tool output in the history; 2^3 + 3^2 = 17.
				var sum float64
				for _, msg := range gr.Messages {
					if msg.Role != RoleTool {
						continue
					}
					for _, part := range msg.Content {
						if part.ToolResponse == nil {
							continue
						}
						out, ok := part.ToolResponse.Output.(float64)
						if !ok {
							return nil, fmt.Errorf("tool output has type %T, want float64", part.ToolResponse.Output)
						}
						sum += out
					}
				}
				return &ModelResponse{
					Request: gr,
					Message: &Message{
						Role: RoleModel,
						Content: []*Part{
							NewTextPart(fmt.Sprintf("Final result: %d", int(sum))),
						},
					},
				}, nil
			})
		res, err := Generate(context.Background(), r,
			WithModel(parallelModel),
			WithPrompt("trigger parallel tools"),
			WithTools(gablorkenTool),
		)
		if err != nil {
			t.Fatal(err)
		}
		if got := res.Text(); got != "Final result: 17" {
			t.Errorf("expected final result text to be 'Final result: 17', got %q", got)
		}
	})
	t.Run("handles multiple rounds of tool calls", func(t *testing.T) {
		roundCount := 0
		info := &ModelOptions{
			Supports: &ModelSupports{
				Multiturn: true,
				Tools:     true,
			},
		}
		multiRoundModel := DefineModel(r, "test/multiround", info,
			func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
				roundCount++
				if roundCount == 1 {
					return &ModelResponse{
						Request: gr,
						Message: &Message{
							Role: RoleModel,
							Content: []*Part{
								NewToolRequestPart(&ToolRequest{
									Name:  "gablorken",
									Input: map[string]any{"Value": 2, "Over": 3},
								}),
							},
						},
					}, nil
				}
				if roundCount == 2 {
					return &ModelResponse{
						Request: gr,
						Message: &Message{
							Role: RoleModel,
							Content: []*Part{
								NewToolRequestPart(&ToolRequest{
									Name:  "gablorken",
									Input: map[string]any{"Value": 3, "Over": 2},
								}),
							},
						},
					}, nil
				}
				return &ModelResponse{
					Request: gr,
					Message: &Message{
						Role: RoleModel,
						Content: []*Part{
							NewTextPart("Final result"),
						},
					},
				}, nil
			})
		res, err := Generate(context.Background(), r,
			WithModel(multiRoundModel),
			WithPrompt("trigger multiple rounds"),
			WithTools(gablorkenTool),
			WithMaxTurns(2),
		)
		if err != nil {
			t.Fatal(err)
		}
		if roundCount != 3 {
			t.Errorf("expected 3 rounds, got %d", roundCount)
		}
		if res.Text() != "Final result" {
			t.Errorf("expected final message 'Final result', got %q", res.Text())
		}
	})
	t.Run("exceeds maximum turns", func(t *testing.T) {
		info := &ModelOptions{
			Supports: &ModelSupports{
				Multiturn: true,
				Tools:     true,
			},
		}
		infiniteModel := DefineModel(r, "test/infinite", info,
			func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
				return &ModelResponse{
					Request: gr,
					Message: &Message{
						Role: RoleModel,
						Content: []*Part{
							NewToolRequestPart(&ToolRequest{
								Name:  "gablorken",
								Input: map[string]any{"Value": 2, "Over": 2},
							}),
						},
					},
				}, nil
			})
		_, err := Generate(context.Background(), r,
			WithModel(infiniteModel),
			WithPrompt("trigger infinite loop"),
			WithTools(gablorkenTool),
			WithMaxTurns(2),
		)
		if err == nil {
			t.Fatal("expected error for exceeding maximum turns")
		}
		if !strings.Contains(err.Error(), "exceeded maximum tool call iterations (2)") {
			t.Errorf("unexpected error message: %v", err)
		}
	})
	t.Run("applies middleware", func(t *testing.T) {
		middlewareCalled := false
		testMiddleware := func(next ModelFunc) ModelFunc {
			return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
				middlewareCalled = true
				req.Messages = append(req.Messages, NewUserTextMessage("middleware was here"))
				return next(ctx, req, cb)
			}
		}
		res, err := Generate(context.Background(), r,
			WithModel(echoModel),
			WithPrompt("test middleware"),
			WithMiddleware(testMiddleware),
		)
		if err != nil {
			t.Fatal(err)
		}
		if !middlewareCalled {
			t.Error("middleware was not called")
		}
		// echoModel replies with the last user message, which the middleware appended.
		expectedText := "middleware was here"
		if res.Text() != expectedText {
			t.Errorf("got text %q, want %q", res.Text(), expectedText)
		}
	})
	t.Run("registers dynamic tools", func(t *testing.T) {
		// Create a tool that is NOT registered in the global registry
		dynamicTool := NewTool("dynamicTestTool", "a tool that is dynamically registered",
			func(ctx *ToolContext, input struct {
				Message string
			},
			) (string, error) {
				return "Dynamic: " + input.Message, nil
			},
		)
		// Verify the tool is not in the global registry
		if LookupTool(r, "dynamicTestTool") != nil {
			t.Fatal("dynamicTestTool should not be registered in global registry")
		}
		// Create a model that will call the dynamic tool then provide a final response
		roundCount := 0
		info := &ModelOptions{
			Supports: &ModelSupports{
				Multiturn: true,
				Tools:     true,
			},
		}
		toolCallModel := DefineModel(r, "test/toolcall", info,
			func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
				roundCount++
				if roundCount == 1 {
					// First response: call the dynamic tool
					return &ModelResponse{
						Request: gr,
						Message: &Message{
							Role: RoleModel,
							Content: []*Part{
								NewToolRequestPart(&ToolRequest{
									Name:  "dynamicTestTool",
									Input: map[string]any{"Message": "Hello from dynamic tool"},
								}),
							},
						},
					}, nil
				}
				// Second response: provide final answer based on tool response
				var toolResult string
				for _, msg := range gr.Messages {
					if msg.Role != RoleTool {
						continue
					}
					for _, part := range msg.Content {
						if part.ToolResponse == nil {
							continue
						}
						out, ok := part.ToolResponse.Output.(string)
						if !ok {
							return nil, fmt.Errorf("tool output has type %T, want string", part.ToolResponse.Output)
						}
						toolResult = out
					}
				}
				return &ModelResponse{
					Request: gr,
					Message: &Message{
						Role: RoleModel,
						Content: []*Part{
							NewTextPart(toolResult),
						},
					},
				}, nil
			})
		// Use Generate with the dynamic tool - this should trigger the dynamic registration
		res, err := Generate(context.Background(), r,
			WithModel(toolCallModel),
			WithPrompt("call the dynamic tool"),
			WithTools(dynamicTool),
		)
		if err != nil {
			t.Fatal(err)
		}
		// The tool should have been called and returned a response
		expectedText := "Dynamic: Hello from dynamic tool"
		if res.Text() != expectedText {
			t.Errorf("expected text %q, got %q", expectedText, res.Text())
		}
		// Verify two rounds were executed: tool call + final response
		if roundCount != 2 {
			t.Errorf("expected 2 rounds, got %d", roundCount)
		}
		// Verify the tool is still not in the global registry after generation
		if LookupTool(r, "dynamicTestTool") != nil {
			t.Error("dynamicTestTool should not be registered in global registry after generation")
		}
	})
	t.Run("handles duplicate dynamic tools", func(t *testing.T) {
		// Create two tools with the same name
		dynamicTool1 := NewTool("duplicateTool", "first tool",
			func(ctx *ToolContext, input any) (string, error) {
				return "tool1", nil
			},
		)
		dynamicTool2 := NewTool("duplicateTool", "second tool",
			func(ctx *ToolContext, input any) (string, error) {
				return "tool2", nil
			},
		)
		// Using both tools should result in an error
		_, err := Generate(context.Background(), r,
			WithModel(echoModel),
			WithPrompt("test duplicate tools"),
			WithTools(dynamicTool1, dynamicTool2),
		)
		if err == nil {
			t.Fatal("expected error for duplicate tool names")
		}
		if !strings.Contains(err.Error(), "duplicate tool \"duplicateTool\"") {
			t.Errorf("unexpected error message: %v", err)
		}
	})
}
// TestModelVersion checks that Generate succeeds when the requested config
// version is one of echoModel's declared Versions ("echo-001") and returns
// an error for a version that is not declared.
func TestModelVersion(t *testing.T) {
	t.Run("valid version", func(t *testing.T) {
		_, err := Generate(context.Background(), r,
			WithModel(echoModel),
			WithConfig(&GenerationCommonConfig{
				Temperature: 1,
				Version:     "echo-001",
			}),
			WithPrompt("tell a joke about batman"))
		if err != nil {
			// Report the underlying error so a failure is diagnosable.
			t.Errorf("model version should be valid: %v", err)
		}
	})
	t.Run("invalid version", func(t *testing.T) {
		_, err := Generate(context.Background(), r,
			WithModel(echoModel),
			WithConfig(&GenerationCommonConfig{
				Temperature: 1,
				Version:     "echo-im-not-a-version",
			}),
			WithPrompt("tell a joke about batman"))
		if err == nil {
			// err is nil here by construction, so there is nothing to print.
			t.Error("model version should be invalid")
		}
	})
}
func TestLookupModel(t *testing.T) {
	t.Run("should return model", func(t *testing.T) {
		if LookupModel(r, "test/"+modelName) == nil {
			t.Errorf("LookupModel did not return model")
		}
	})
	t.Run("should return nil", func(t *testing.T) {
		if LookupModel(r, "foo/bar") != nil {
			t.Errorf("LookupModel did not return nil")
		}
	})
}
// JSONMarkdown wraps text in a Markdown code fence tagged "json", placing the
// opening fence, the text, and the closing fence on separate lines.
func JSONMarkdown(text string) string {
	return strings.Join([]string{"```json", text, "```"}, "\n")
}
// errorContains records a (non-fatal) test failure unless err is non-nil and
// its message contains want as a substring.
func errorContains(t *testing.T, err error, want string) {
	t.Helper()
	switch {
	case err == nil:
		t.Error("got nil, want error")
	case !strings.Contains(err.Error(), want):
		t.Errorf("got error message %q, want it to contain %q", err, want)
	}
}
// validTestMessage resolves the output format described by output (schema
// plus format name) against the shared registry r, builds that format's
// handler, and uses it to parse m. Any error from resolving, building, or
// parsing is returned as-is.
func validTestMessage(m *Message, output *ModelOutputConfig) (*Message, error) {
	format, err := resolveFormat(r, output.Schema, output.Format)
	if err != nil {
		return nil, err
	}
	h, err := format.Handler(output.Schema)
	if err != nil {
		return nil, err
	}
	parsed, err := h.ParseMessage(m)
	return parsed, err
}
// TestToolInterruptsAndResume covers the tool interrupt/resume workflow.
// A test model always emits a text part followed by two tool requests:
// "conditional", which interrupts, and "resumable". The subtests then check:
//   - the interrupted response;
//   - building responses with tool.Respond and restarts with tool.Restart;
//   - resuming generation with each of them.
func TestToolInterruptsAndResume(t *testing.T) {
	conditionalTool := DefineTool(r, "conditional", "tool that may interrupt based on input",
		func(ctx *ToolContext, input struct {
			Value     string
			Interrupt bool
		},
		) (string, error) {
			if input.Interrupt {
				return "", ctx.Interrupt(&InterruptOptions{
					Metadata: map[string]any{
						"reason":      "user_intervention_required",
						"value":       input.Value,
						"interrupted": true,
					},
				})
			}
			return fmt.Sprintf("processed: %s", input.Value), nil
		},
	)
	resumableTool := DefineTool(r, "resumable", "tool that can be resumed",
		func(ctx *ToolContext, input struct {
			Action string
			Data   string
		},
		) (string, error) {
			if ctx.Resumed != nil {
				resumedData, ok := ctx.Resumed["data"].(string)
				if ok {
					return fmt.Sprintf("resumed with: %s, original: %s", resumedData, input.Data), nil
				}
				return fmt.Sprintf("resumed: %s", input.Data), nil
			}
			return fmt.Sprintf("first run: %s", input.Data), nil
		},
	)
	info := &ModelOptions{
		Supports: &ModelSupports{
			Multiturn: true,
			Tools:     true,
		},
	}
	toolModel := DefineModel(r, "test/toolmodel", info,
		func(ctx context.Context, mr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
			return &ModelResponse{
				Request: mr,
				Message: &Message{
					Role: RoleModel,
					Content: []*Part{
						NewTextPart("I need to use some tools."),
						NewToolRequestPart(&ToolRequest{
							Name: "conditional",
							Ref:  "tool1",
							Input: map[string]any{
								"Value":     "test_data",
								"Interrupt": true,
							},
						}),
						NewToolRequestPart(&ToolRequest{
							Name: "resumable",
							Ref:  "tool2",
							Input: map[string]any{
								"Action": "process",
								"Data":   "initial_data",
							},
						}),
					},
				},
			}, nil
		})

	// generateInterrupted runs toolModel with both tools and returns the
	// response. It fails the test unless the response message has at least two
	// content parts, because the subtests index Content[1] (the interrupted
	// "conditional" request).
	generateInterrupted := func(t *testing.T) *ModelResponse {
		t.Helper()
		res, err := Generate(context.Background(), r,
			WithModel(toolModel),
			WithPrompt("use tools"),
			WithTools(conditionalTool, resumableTool),
		)
		if err != nil {
			t.Fatal(err)
		}
		if res.Message == nil || len(res.Message.Content) < 2 {
			t.Fatalf("expected at least 2 content parts in response message, got %+v", res.Message)
		}
		return res
	}

	t.Run("basic interrupt flow", func(t *testing.T) {
		res := generateInterrupted(t)
		if res.FinishReason != "interrupted" {
			t.Errorf("expected finish reason 'interrupted', got %q", res.FinishReason)
		}
		if len(res.Message.Content) != 3 {
			t.Fatalf("expected 3 content parts, got %d", len(res.Message.Content))
		}
		interruptedPart := res.Message.Content[1]
		if !interruptedPart.IsToolRequest() {
			t.Fatal("expected second part to be a tool request")
		}
		interruptMeta, ok := interruptedPart.Metadata["interrupt"].(map[string]any)
		if !ok {
			t.Fatal("expected interrupt metadata in tool request")
		}
		if reason, ok := interruptMeta["reason"].(string); !ok || reason != "user_intervention_required" {
			t.Errorf("expected interrupt reason 'user_intervention_required', got %v", reason)
		}
	})
	t.Run("tool.Respond functionality", func(t *testing.T) {
		res := generateInterrupted(t)
		interruptedPart := res.Message.Content[1]
		responsePart := conditionalTool.Respond(interruptedPart, "manual_response_data", &RespondOptions{
			Metadata: map[string]any{
				"manual": true,
				"source": "user",
			},
		})
		if !responsePart.IsToolResponse() {
			t.Fatal("expected response part to be a tool response")
		}
		if responsePart.ToolResponse.Name != "conditional" {
			t.Errorf("expected tool response name 'conditional', got %q", responsePart.ToolResponse.Name)
		}
		if responsePart.ToolResponse.Ref != "tool1" {
			t.Errorf("expected tool response ref 'tool1', got %q", responsePart.ToolResponse.Ref)
		}
		if responsePart.ToolResponse.Output != "manual_response_data" {
			t.Errorf("expected output 'manual_response_data', got %v", responsePart.ToolResponse.Output)
		}
		interruptResponseMeta, ok := responsePart.Metadata["interruptResponse"].(map[string]any)
		if !ok {
			t.Fatal("expected interruptResponse metadata")
		}
		if manual, ok := interruptResponseMeta["manual"].(bool); !ok || !manual {
			t.Errorf("expected manual metadata to be true")
		}
	})
	t.Run("tool.Restart functionality", func(t *testing.T) {
		res := generateInterrupted(t)
		interruptedPart := res.Message.Content[1]
		restartPart := conditionalTool.Restart(interruptedPart, &RestartOptions{
			ReplaceInput: map[string]any{
				"Value":     "new_test_data",
				"Interrupt": false,
			},
			ResumedMetadata: map[string]any{
				"data":   "resumed_data",
				"source": "restart",
			},
		})
		if !restartPart.IsToolRequest() {
			t.Fatal("expected restart part to be a tool request")
		}
		if restartPart.ToolRequest.Name != "conditional" {
			t.Errorf("expected tool request name 'conditional', got %q", restartPart.ToolRequest.Name)
		}
		newInput, ok := restartPart.ToolRequest.Input.(map[string]any)
		if !ok {
			t.Fatal("expected input to be map[string]any")
		}
		if newInput["Value"] != "new_test_data" {
			t.Errorf("expected new input value 'new_test_data', got %v", newInput["Value"])
		}
		if newInput["Interrupt"] != false {
			t.Errorf("expected interrupt to be false, got %v", newInput["Interrupt"])
		}
		if _, hasInterrupt := restartPart.Metadata["interrupt"]; hasInterrupt {
			t.Error("expected interrupt metadata to be removed")
		}
		resumedMeta, ok := restartPart.Metadata["resumed"].(map[string]any)
		if !ok {
			t.Fatal("expected resumed metadata")
		}
		if resumedMeta["data"] != "resumed_data" {
			t.Errorf("expected resumed data 'resumed_data', got %v", resumedMeta["data"])
		}
	})
	t.Run("resume with respond directive", func(t *testing.T) {
		res := generateInterrupted(t)
		interruptedPart := res.Message.Content[1]
		responsePart := conditionalTool.Respond(interruptedPart, "user_provided_response", nil)
		history := res.History()
		resumeRes, err := Generate(context.Background(), r,
			WithModel(NewModelRef("test/echo", nil)),
			WithMessages(history...),
			WithTools(conditionalTool, resumableTool),
			WithToolResponses(responsePart),
		)
		if err != nil {
			t.Fatal(err)
		}
		if resumeRes.FinishReason == "interrupted" {
			t.Error("expected generation to not be interrupted after responding")
		}
	})
	t.Run("resume with restart directive", func(t *testing.T) {
		res := generateInterrupted(t)
		interruptedPart := res.Message.Content[1]
		restartPart := conditionalTool.Restart(interruptedPart, &RestartOptions{
			ReplaceInput: map[string]any{
				"Value":     "restarted_data",
				"Interrupt": false,
			},
			ResumedMetadata: map[string]any{
				"data": "restart_context",
			},
		})
		history := res.History()
		resumeRes, err := Generate(context.Background(), r,
			WithModel(NewModelRef("test/echo", nil)),
			WithMessages(history...),
			WithTools(conditionalTool, resumableTool),
			WithToolRestarts(restartPart),
		)
		if err != nil {
			t.Fatal(err)
		}
		if resumeRes.FinishReason == "interrupted" {
			t.Error("expected generation to not be interrupted after restarting")
		}
	})
}
// TestResourceProcessing checks that processResources replaces each resource
// part in a user message with the content returned by the matching resource,
// leaving the surrounding text parts in their original order. It uses a fresh
// registry (shadowing the package-level r) so the test resources stay isolated.
func TestResourceProcessing(t *testing.T) {
	r := registry.New()
	// Create test resources using DefineResource
	DefineResource(r, "test-file", &ResourceOptions{
		URI:         "file:///test.txt",
		Description: "Test file resource",
	}, func(ctx context.Context, input *ResourceInput) (*ResourceOutput, error) {
		return &ResourceOutput{Content: []*Part{NewTextPart("FILE CONTENT")}}, nil
	})
	DefineResource(r, "test-api", &ResourceOptions{
		URI:         "api://data/123",
		Description: "Test API resource",
	}, func(ctx context.Context, input *ResourceInput) (*ResourceOutput, error) {
		return &ResourceOutput{Content: []*Part{NewTextPart("API DATA")}}, nil
	})
	// Test message with resources
	messages := []*Message{
		NewUserMessage(
			NewTextPart("Read this:"),
			NewResourcePart("file:///test.txt"),
			NewTextPart("And this:"),
			NewResourcePart("api://data/123"),
			NewTextPart("Done."),
		),
	}
	// Process resources
	processed, err := processResources(context.Background(), r, messages)
	if err != nil {
		t.Fatalf("resource processing failed: %v", err)
	}
	// Guard the index below so a short result fails cleanly instead of panicking.
	if len(processed) != 1 {
		t.Fatalf("expected 1 processed message, got %d", len(processed))
	}
	// Verify content
	content := processed[0].Content
	expected := []string{"Read this:", "FILE CONTENT", "And this:", "API DATA", "Done."}
	if len(content) != len(expected) {
		t.Fatalf("expected %d parts, got %d", len(expected), len(content))
	}
	// Report every mismatching part rather than stopping at the first.
	for i, want := range expected {
		if got := content[i].Text; got != want {
			t.Errorf("part %d: got %q, want %q", i, got, want)
		}
	}
}
func TestResourceProcessingError(t *testing.T) {
	r := registry.New()
	// No resources registered
	messages := []*Message{
		NewUserMessage(NewResourcePart("missing://resource")),
	}
	_, err := processResources(context.Background(), r, messages)
	if err == nil {
		t.Fatal("expected error when no resources available")
	}
	if !strings.Contains(err.Error(), "no resource found for URI") {
		t.Fatalf("wrong error: %v", err)
	}
}