option_test.go•15.9 kB
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ai
import (
"context"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
// TestCommonOptions checks that CommonGenOption values can be applied to the
// generate, prompt, and prompt-execution option sets, and that conflicting
// options (messages vs. messages function, model vs. model name) are rejected
// when applied to generate options.
func TestCommonOptions(t *testing.T) {
	cases := []struct {
		name    string
		opts    []CommonGenOption
		wantErr bool
	}{
		{
			name: "valid options",
			opts: []CommonGenOption{
				WithMessages(NewUserTextMessage("test")),
				WithConfig(&GenerationCommonConfig{Temperature: 0.7}),
				WithModel(&mockModel{name: "test/model"}),
				WithTools(&mockTool{name: "test/tool"}),
				WithToolChoice(ToolChoiceAuto),
				WithMaxTurns(3),
				WithReturnToolRequests(true),
				WithMiddleware(func(next ModelFunc) ModelFunc { return next }),
			},
			wantErr: false,
		},
		{
			name: "mutually exclusive - messages",
			opts: []CommonGenOption{
				WithMessages(NewUserTextMessage("test")),
				WithMessagesFn(func(context.Context, any) ([]*Message, error) { return nil, nil }),
			},
			wantErr: true,
		},
		{
			name: "mutually exclusive - model",
			opts: []CommonGenOption{
				WithModel(&mockModel{name: "test/model"}),
				WithModelName("test/model"),
			},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// applyAll feeds every option of the case to apply, in order,
			// and reports the first error encountered (if any).
			applyAll := func(apply func(CommonGenOption) error) error {
				for _, o := range tc.opts {
					if err := apply(o); err != nil {
						return err
					}
				}
				return nil
			}

			genOpts := &generateOptions{}
			genErr := applyAll(func(o CommonGenOption) error { return o.applyGenerate(genOpts) })
			if gotErr := genErr != nil; gotErr != tc.wantErr {
				t.Fatalf("applyGenerate() error = %v, wantErr %v", genErr, tc.wantErr)
			}
			if tc.wantErr {
				// The conflict was detected; the remaining targets are only
				// exercised for valid option sets.
				return
			}

			promptOpts := &promptOptions{}
			if err := applyAll(func(o CommonGenOption) error { return o.applyPrompt(promptOpts) }); err != nil {
				t.Fatalf("applyPrompt() unexpected error = %v", err)
			}

			execOpts := &promptExecutionOptions{}
			if err := applyAll(func(o CommonGenOption) error { return o.applyPromptExecute(execOpts) }); err != nil {
				t.Fatalf("applyPromptExecute() unexpected error = %v", err)
			}
		})
	}
}
// TestPromptOptions verifies that prompt-only options (description, metadata,
// and input type) apply to a promptOptions value with the expected error
// outcome.
func TestPromptOptions(t *testing.T) {
	testCases := []struct {
		name    string
		opts    []PromptOption
		wantErr bool
	}{
		{
			name: "valid options",
			opts: []PromptOption{
				WithDescription("test description"),
				WithMetadata(map[string]any{"key": "value"}),
				WithInputType(struct {
					Test string `json:"test"`
				}{}),
			},
			wantErr: false,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			target := &promptOptions{}
			var applyErr error
			// Stop at the first option that fails to apply.
			for i := 0; i < len(tc.opts) && applyErr == nil; i++ {
				applyErr = tc.opts[i].applyPrompt(target)
			}
			if gotErr := applyErr != nil; gotErr != tc.wantErr {
				t.Errorf("applyPrompt() error = %v, wantErr %v", applyErr, tc.wantErr)
			}
		})
	}
}
// TestPromptingOptions verifies system/user prompt options. Valid combinations
// must apply to both generate and prompt option sets; supplying both the static
// and the function form of the same prompt must be rejected by applyGenerate.
func TestPromptingOptions(t *testing.T) {
	cases := []struct {
		name    string
		opts    []PromptingOption
		wantErr bool
	}{
		{
			name: "valid options",
			opts: []PromptingOption{
				WithSystem("system instruction"),
				WithPrompt("user prompt"),
			},
			wantErr: false,
		},
		{
			name: "mutually exclusive - system",
			opts: []PromptingOption{
				WithSystem("system instruction"),
				WithSystemFn(func(context.Context, any) (string, error) { return "system", nil }),
			},
			wantErr: true,
		},
		{
			name: "mutually exclusive - prompt",
			opts: []PromptingOption{
				WithPrompt("user prompt"),
				WithPromptFn(func(context.Context, any) (string, error) { return "prompt", nil }),
			},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			genOpts := &generateOptions{}
			var genErr error
			for _, o := range tc.opts {
				if genErr = o.applyGenerate(genOpts); genErr != nil {
					break
				}
			}

			switch {
			case (genErr != nil) != tc.wantErr:
				t.Fatalf("applyGenerate() error = %v, wantErr %v", genErr, tc.wantErr)
			case tc.wantErr:
				// Expected conflict observed; nothing more to check.
				return
			}

			promptOpts := &promptOptions{}
			for _, o := range tc.opts {
				if err := o.applyPrompt(promptOpts); err != nil {
					t.Fatalf("applyPrompt() unexpected error = %v", err)
				}
			}
		})
	}
}
// TestOutputOptions verifies that each output option (output type, output
// format, output instructions) can be applied to both the generate and the
// prompt option sets without error.
func TestOutputOptions(t *testing.T) {
	type testCase struct {
		name    string
		opts    []OutputOption
		wantErr bool
	}
	testCases := []testCase{
		{
			name: "valid - output type",
			opts: []OutputOption{
				WithOutputType(struct {
					Test string `json:"test"`
				}{}),
			},
			wantErr: false,
		},
		{
			name: "valid - output format",
			opts: []OutputOption{
				WithOutputFormat(OutputFormatText),
			},
			wantErr: false,
		},
		{
			name: "valid - output instruction",
			opts: []OutputOption{
				WithOutputInstructions(""),
			},
			wantErr: false,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			var (
				genOpts    = &generateOptions{}
				promptOpts = &promptOptions{}
				err        error
			)
			for i := range tc.opts {
				if err = tc.opts[i].applyGenerate(genOpts); err != nil {
					break
				}
			}
			if (err != nil) != tc.wantErr {
				t.Fatalf("applyGenerate() error = %v, wantErr %v", err, tc.wantErr)
			}
			if !tc.wantErr {
				for i := range tc.opts {
					if perr := tc.opts[i].applyPrompt(promptOpts); perr != nil {
						t.Fatalf("applyPrompt() unexpected error = %v", perr)
					}
				}
			}
		})
	}
}
func TestExecutionOptions(t *testing.T) {
tests := []struct {
name string
opts []ExecutionOption
wantErr bool
}{
{
name: "valid options",
opts: []ExecutionOption{
WithStreaming(func(context.Context, *ModelResponseChunk) error { return nil }),
},
wantErr: false,
},
{
name: "duplicate - streaming",
opts: []ExecutionOption{
WithStreaming(func(context.Context, *ModelResponseChunk) error { return nil }),
WithStreaming(func(context.Context, *ModelResponseChunk) error { return nil }),
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
genOpts := &generateOptions{}
pgOpts := &promptExecutionOptions{}
var err error
for _, opt := range tt.opts {
err = opt.applyGenerate(genOpts)
if err != nil {
break
}
}
if (err != nil) != tt.wantErr {
t.Errorf("applyGenerate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
for _, opt := range tt.opts {
if err = opt.applyPromptExecute(pgOpts); err != nil {
t.Errorf("applyPromptExecute() unexpected error = %v", err)
return
}
}
})
}
}
// TestPromptGenerateOptions verifies the prompt-execution input option: a
// single WithInput applies cleanly, and supplying it twice is an error.
func TestPromptGenerateOptions(t *testing.T) {
	for _, tc := range []struct {
		name    string
		opts    []PromptExecuteOption
		wantErr bool
	}{
		{
			name: "valid options",
			opts: []PromptExecuteOption{
				WithInput(map[string]string{"key": "value"}),
			},
			wantErr: false,
		},
		{
			name: "duplicate - input",
			opts: []PromptExecuteOption{
				WithInput("input1"),
				WithInput("input2"),
			},
			wantErr: true,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			execOpts := &promptExecutionOptions{}
			var firstErr error
			for _, o := range tc.opts {
				if firstErr = o.applyPromptExecute(execOpts); firstErr != nil {
					break
				}
			}
			if gotErr := firstErr != nil; gotErr != tc.wantErr {
				t.Errorf("applyPromptExecute() error = %v, wantErr %v", firstErr, tc.wantErr)
			}
		})
	}
}
// TestGenerateOptionsComplete applies every option accepted by GenerateOption
// to a single generateOptions value and checks the resulting configuration.
//
// Function-valued fields (MessagesFn, Middleware, SystemFn, PromptFn, Stream)
// are excluded from the structural diff and only checked for presence
// afterwards. OutputSchema is inferred from the output type, so its expected
// value is copied from the result; that makes the diff tautological for the
// field. A separate presence check therefore ensures it was actually
// populated, matching TestPromptOptionsComplete.
func TestGenerateOptionsComplete(t *testing.T) {
	opts := &generateOptions{}
	mw := func(next ModelFunc) ModelFunc { return next }
	model := &mockModel{name: "test/model"}
	tool := &mockTool{name: "test/tool"}
	streamFunc := func(context.Context, *ModelResponseChunk) error { return nil }
	doc := DocumentFromText("doc", nil)

	options := []GenerateOption{
		WithModel(model),
		WithMessages(NewUserTextMessage("message")),
		WithConfig(&GenerationCommonConfig{Temperature: 0.7}),
		WithTools(tool),
		WithToolChoice(ToolChoiceAuto),
		WithMaxTurns(3),
		WithReturnToolRequests(true),
		WithMiddleware(mw),
		WithSystem("system prompt"),
		WithPrompt("user prompt"),
		WithDocs(doc),
		WithOutputType(map[string]string{"key": "value"}),
		WithOutputInstructions(""),
		WithCustomConstrainedOutput(),
		WithStreaming(streamFunc),
	}
	for _, opt := range options {
		if err := opt.applyGenerate(opts); err != nil {
			t.Fatalf("Failed to apply option: %v", err)
		}
	}

	returnToolRequests := true
	outputInstructions := ""
	expected := &generateOptions{
		commonGenOptions: commonGenOptions{
			configOptions: configOptions{
				Config: &GenerationCommonConfig{Temperature: 0.7},
			},
			Model:              model,
			Tools:              []ToolRef{tool},
			ToolChoice:         ToolChoiceAuto,
			MaxTurns:           3,
			ReturnToolRequests: &returnToolRequests,
			Middleware:         []ModelMiddleware{mw},
		},
		promptingOptions: promptingOptions{
			SystemFn: opts.SystemFn,
			PromptFn: opts.PromptFn,
		},
		outputOptions: outputOptions{
			OutputFormat:       OutputFormatJSON,
			OutputSchema:       opts.OutputSchema,
			OutputInstructions: &outputInstructions,
			CustomConstrained:  true,
		},
		executionOptions: executionOptions{
			Stream: streamFunc,
		},
		documentOptions: documentOptions{
			Documents: []*Document{doc},
		},
	}
	if diff := cmp.Diff(expected, opts,
		cmpopts.IgnoreFields(commonGenOptions{}, "MessagesFn", "Middleware"),
		cmpopts.IgnoreFields(promptingOptions{}, "SystemFn", "PromptFn"),
		cmpopts.IgnoreFields(executionOptions{}, "Stream"),
		cmpopts.IgnoreUnexported(mockModel{}, mockTool{}),
		cmp.AllowUnexported(generateOptions{}, commonGenOptions{}, promptingOptions{},
			outputOptions{}, executionOptions{}, documentOptions{})); diff != "" {
		t.Errorf("Options not applied correctly, diff (-want +got):\n%s", diff)
	}

	// Presence checks for fields excluded from (or tautological in) the diff.
	if opts.MessagesFn == nil {
		t.Errorf("MessagesFn should not be nil")
	}
	if len(opts.Middleware) == 0 {
		t.Errorf("Middleware should not be empty")
	}
	if opts.SystemFn == nil {
		t.Errorf("SystemFn should not be nil")
	}
	if opts.PromptFn == nil {
		t.Errorf("PromptFn should not be nil")
	}
	if opts.Stream == nil {
		t.Errorf("Stream should not be nil")
	}
	if opts.OutputSchema == nil {
		t.Errorf("OutputSchema should not be nil")
	}
}
// TestPromptOptionsComplete applies every option accepted by PromptOption to a
// single promptOptions value and compares the result with the expected
// configuration. Function-valued fields and inferred schemas are excluded from
// the structural diff and are instead checked for presence afterwards.
func TestPromptOptionsComplete(t *testing.T) {
	got := &promptOptions{}
	middleware := func(next ModelFunc) ModelFunc { return next }
	mdl := &mockModel{name: "test/model"}
	tl := &mockTool{name: "test/tool"}
	defaultInput := struct {
		Test string `json:"test"`
	}{Test: "value"}

	for _, o := range []PromptOption{
		WithModel(mdl),
		WithMessages(NewUserTextMessage("message")),
		WithConfig(&GenerationCommonConfig{Temperature: 0.7}),
		WithTools(tl),
		WithToolChoice(ToolChoiceAuto),
		WithMaxTurns(3),
		WithReturnToolRequests(true),
		WithMiddleware(middleware),
		WithSystem("system prompt"),
		WithPrompt("user prompt"),
		WithDescription("test description"),
		WithMetadata(map[string]any{"key": "value"}),
		WithOutputType(map[string]string{"key": "value"}),
		WithOutputInstructions(""),
		WithCustomConstrainedOutput(),
		WithInputType(defaultInput),
	} {
		if err := o.applyPrompt(got); err != nil {
			t.Fatalf("Failed to apply option: %v", err)
		}
	}

	wantReturnToolRequests := true
	emptyInstructions := ""
	want := &promptOptions{
		commonGenOptions: commonGenOptions{
			configOptions:      configOptions{Config: &GenerationCommonConfig{Temperature: 0.7}},
			Model:              mdl,
			Tools:              []ToolRef{tl},
			ToolChoice:         ToolChoiceAuto,
			MaxTurns:           3,
			ReturnToolRequests: &wantReturnToolRequests,
			Middleware:         []ModelMiddleware{middleware},
		},
		promptingOptions: promptingOptions{SystemFn: got.SystemFn, PromptFn: got.PromptFn},
		outputOptions: outputOptions{
			OutputFormat:       OutputFormatJSON,
			OutputSchema:       got.OutputSchema,
			OutputInstructions: &emptyInstructions,
			CustomConstrained:  true,
		},
		Description:  "test description",
		Metadata:     map[string]any{"key": "value"},
		InputSchema:  got.InputSchema,
		DefaultInput: map[string]any{"test": "value"},
	}

	cmpOpts := []cmp.Option{
		cmpopts.IgnoreFields(commonGenOptions{}, "MessagesFn", "Middleware"),
		cmpopts.IgnoreFields(promptingOptions{}, "SystemFn", "PromptFn"),
		cmpopts.IgnoreFields(outputOptions{}, "OutputSchema"),
		cmpopts.IgnoreFields(promptOptions{}, "InputSchema"),
		cmpopts.IgnoreUnexported(mockModel{}, mockTool{}),
		cmp.AllowUnexported(promptOptions{}, commonGenOptions{}, promptingOptions{}, outputOptions{}),
	}
	if diff := cmp.Diff(want, got, cmpOpts...); diff != "" {
		t.Errorf("Options not applied correctly, diff (-want +got):\n%s", diff)
	}

	// Fields skipped by the diff must at least have been populated.
	presence := []struct {
		missing bool
		msg     string
	}{
		{got.MessagesFn == nil, "MessagesFn should not be nil"},
		{len(got.Middleware) == 0, "Middleware should not be empty"},
		{got.SystemFn == nil, "SystemFn should not be nil"},
		{got.PromptFn == nil, "PromptFn should not be nil"},
		{got.OutputSchema == nil, "OutputSchema should not be nil"},
		{got.InputSchema == nil, "InputSchema should not be nil"},
	}
	for _, p := range presence {
		if p.missing {
			t.Error(p.msg)
		}
	}
}
// TestPromptExecuteOptionsComplete applies every option accepted by
// PromptExecuteOption to a single promptExecutionOptions value and checks the
// resulting configuration.
//
// Function-valued fields (MessagesFn, Middleware, Stream) are excluded from
// the structural diff and only checked for presence afterwards. Middleware
// emptiness is tested with len, not a nil comparison: a non-nil but empty
// slice must fail the check too. This is consistent with the sibling *Complete
// tests.
func TestPromptExecuteOptionsComplete(t *testing.T) {
	opts := &promptExecutionOptions{}
	mw := func(next ModelFunc) ModelFunc { return next }
	model := &mockModel{name: "test/model"}
	tool := &mockTool{name: "test/tool"}
	streamFunc := func(context.Context, *ModelResponseChunk) error { return nil }
	input := map[string]string{"key": "value"}
	doc := DocumentFromText("doc", nil)

	options := []PromptExecuteOption{
		WithModel(model),
		WithMessages(NewUserTextMessage("message")),
		WithConfig(&GenerationCommonConfig{Temperature: 0.7}),
		WithTools(tool),
		WithToolChoice(ToolChoiceAuto),
		WithMaxTurns(3),
		WithReturnToolRequests(true),
		WithMiddleware(mw),
		WithDocs(doc),
		WithStreaming(streamFunc),
		WithInput(input),
	}
	for _, opt := range options {
		if err := opt.applyPromptExecute(opts); err != nil {
			t.Fatalf("Failed to apply option: %v", err)
		}
	}

	returnToolRequests := true
	expected := &promptExecutionOptions{
		commonGenOptions: commonGenOptions{
			configOptions: configOptions{
				Config: &GenerationCommonConfig{Temperature: 0.7},
			},
			Model:              model,
			Tools:              []ToolRef{tool},
			ToolChoice:         ToolChoiceAuto,
			MaxTurns:           3,
			ReturnToolRequests: &returnToolRequests,
			Middleware:         []ModelMiddleware{mw},
		},
		executionOptions: executionOptions{
			Stream: streamFunc,
		},
		documentOptions: documentOptions{
			Documents: []*Document{doc},
		},
		Input: input,
	}
	// documentOptions is listed alongside the other embedded option structs
	// for parity with TestGenerateOptionsComplete.
	if diff := cmp.Diff(expected, opts,
		cmpopts.IgnoreFields(commonGenOptions{}, "MessagesFn", "Middleware"),
		cmpopts.IgnoreFields(executionOptions{}, "Stream"),
		cmpopts.IgnoreUnexported(mockModel{}, mockTool{}),
		cmp.AllowUnexported(promptExecutionOptions{}, commonGenOptions{},
			executionOptions{}, documentOptions{})); diff != "" {
		t.Errorf("Options not applied correctly, diff (-want +got):\n%s", diff)
	}

	// Presence checks for fields excluded from the diff.
	if opts.MessagesFn == nil {
		t.Errorf("MessagesFn should not be nil")
	}
	if len(opts.Middleware) == 0 {
		t.Errorf("Middleware should not be empty")
	}
	if opts.Stream == nil {
		t.Errorf("Stream should not be nil")
	}
}
// mockModel is a minimal model stub for option tests: it reports a fixed name
// and its Generate always returns (nil, nil).
type mockModel struct {
	name string // value returned by Name
}

// Name returns the name the mock was constructed with.
func (mm *mockModel) Name() string { return mm.name }

// Generate ignores its arguments and returns a nil response and nil error.
func (*mockModel) Generate(context.Context, *ModelRequest, ModelStreamCallback) (*ModelResponse, error) {
	return nil, nil
}
// mockTool is a minimal tool stub for option tests. Its definition carries only
// its name, and RunRaw always returns (nil, nil).
type mockTool struct {
	name string // value returned by Name and used in Definition
}

// Name returns the name the mock was constructed with.
func (mt *mockTool) Name() string { return mt.name }

// Definition returns a fresh ToolDefinition populated only with the tool name.
func (mt *mockTool) Definition() *ToolDefinition {
	def := &ToolDefinition{Name: mt.name}
	return def
}

// RunRaw ignores its input and returns a nil result and nil error.
func (*mockTool) RunRaw(context.Context, any) (any, error) {
	return nil, nil
}