// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package prompts_test
import (
"context"
"errors"
"strings"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/prompts"
_ "github.com/googleapis/genai-toolbox/internal/prompts/custom"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
// mockPromptConfig is the test double returned by mockFactory. It records the
// prompt name it was built for and the kind it reports back.
type mockPromptConfig struct {
	name string // prompt name handed to the factory
	kind string // value reported by PromptConfigKind
}

// PromptConfigKind returns the kind stored on the mock.
func (m *mockPromptConfig) PromptConfigKind() string {
	kind := m.kind
	return kind
}
// Initialize is a no-op stub: it builds nothing and returns a nil
// prompts.Prompt together with a nil error.
func (*mockPromptConfig) Initialize() (prompts.Prompt, error) {
	var none prompts.Prompt
	return none, nil
}
var (
	// errMockFactory is the sentinel returned by mockErrorFactory so tests can
	// check error propagation with errors.Is.
	errMockFactory = errors.New("mock factory error")
)
// mockFactory is a factory for prompts.Register that always succeeds. It
// ignores ctx and decoder and returns a *mockPromptConfig carrying the
// requested prompt name and the kind "mockKind".
func mockFactory(ctx context.Context, name string, decoder *yaml.Decoder) (prompts.PromptConfig, error) {
	cfg := &mockPromptConfig{
		kind: "mockKind",
		name: name,
	}
	return cfg, nil
}
// mockErrorFactory is a factory for prompts.Register that always fails. It
// ignores all inputs and returns a nil config with errMockFactory.
func mockErrorFactory(ctx context.Context, name string, decoder *yaml.Decoder) (prompts.PromptConfig, error) {
	var cfg prompts.PromptConfig
	return cfg, errMockFactory
}
// TestRegistry exercises the prompt kind registry. It covers:
//   - registering a factory and decoding through it;
//   - rejecting a duplicate registration;
//   - reporting unknown kinds;
//   - propagating factory errors;
//   - falling back to the "custom" kind when DecodeConfig gets an empty kind.
//
// NOTE: Register rejects a kind that is already registered (asserted below),
// and registrations persist as package-level state. The fixed kind names used
// here can therefore be registered only once per test process. Re-running this
// test in the same process (e.g. `go test -count=2`) would fail the
// registration checks.
func TestRegistry(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	t.Run("RegisterAndDecodeSuccess", func(t *testing.T) {
		kind := "testKindSuccess"
		promptName := "testPrompt"
		if !prompts.Register(kind, mockFactory) {
			t.Fatal("expected registration to succeed")
		}
		// This should fail because we are registering a duplicate.
		if prompts.Register(kind, mockFactory) {
			t.Fatal("expected duplicate registration to fail")
		}

		decoder := yaml.NewDecoder(strings.NewReader(""))
		config, err := prompts.DecodeConfig(ctx, kind, promptName, decoder)
		if err != nil {
			t.Fatalf("expected DecodeConfig to succeed, but got error: %v", err)
		}
		if config == nil {
			t.Fatal("expected a non-nil config")
		}
		// A non-nil config alone does not prove dispatch. Check that the value
		// came from mockFactory and that the prompt name was forwarded to it.
		mock, ok := config.(*mockPromptConfig)
		if !ok {
			t.Fatalf("expected config of type *mockPromptConfig, but got %T", config)
		}
		if mock.name != promptName {
			t.Errorf("expected factory to receive prompt name %q, but got %q", promptName, mock.name)
		}
	})

	t.Run("DecodeUnknownKind", func(t *testing.T) {
		decoder := yaml.NewDecoder(strings.NewReader(""))
		_, err := prompts.DecodeConfig(ctx, "unregisteredKind", "testPrompt", decoder)
		if err == nil {
			t.Fatal("expected an error for unknown kind, but got nil")
		}
		if !strings.Contains(err.Error(), "unknown prompt kind") {
			t.Errorf("expected error to contain 'unknown prompt kind', but got: %v", err)
		}
	})

	t.Run("FactoryReturnsError", func(t *testing.T) {
		kind := "testKindError"
		if !prompts.Register(kind, mockErrorFactory) {
			t.Fatal("expected registration to succeed")
		}
		decoder := yaml.NewDecoder(strings.NewReader(""))
		_, err := prompts.DecodeConfig(ctx, kind, "testPrompt", decoder)
		if err == nil {
			t.Fatal("expected an error from the factory, but got nil")
		}
		if !errors.Is(err, errMockFactory) {
			t.Errorf("expected error to wrap mock factory error, but got: %v", err)
		}
	})

	t.Run("DecodeDefaultsToCustom", func(t *testing.T) {
		// Relies on the blank import of the custom package registering "custom".
		decoder := yaml.NewDecoder(strings.NewReader("description: A test prompt"))
		config, err := prompts.DecodeConfig(ctx, "", "testDefaultPrompt", decoder)
		if err != nil {
			t.Fatalf("expected DecodeConfig with empty kind to succeed, but got error: %v", err)
		}
		if config == nil {
			t.Fatal("expected a non-nil config for default kind")
		}
		if config.PromptConfigKind() != "custom" {
			t.Errorf("expected default kind to be 'custom', but got %q", config.PromptConfigKind())
		}
	})
}
// TestGetMcpManifest verifies prompts.GetMcpManifest. The manifest must carry
// the prompt name and description, plus one ArgMcpManifest entry per argument
// with its name, description, and required flag. Per the expectations below, a
// parameter built without an explicit required flag is expected to be reported
// as required.
func TestGetMcpManifest(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name        string
		promptName  string
		description string
		args        prompts.Arguments
		want        prompts.McpManifest
	}

	cases := []testCase{
		{
			name:        "No arguments",
			promptName:  "test-prompt",
			description: "A test prompt.",
			args:        prompts.Arguments{},
			want: prompts.McpManifest{
				Name:        "test-prompt",
				Description: "A test prompt.",
				Arguments:   []prompts.ArgMcpManifest{},
			},
		},
		{
			name:        "With arguments",
			promptName:  "arg-prompt",
			description: "Prompt with args.",
			args: prompts.Arguments{
				{Parameter: parameters.NewStringParameter("param1", "First param")},
				{Parameter: parameters.NewIntParameterWithRequired("param2", "Second param", false)},
			},
			want: prompts.McpManifest{
				Name:        "arg-prompt",
				Description: "Prompt with args.",
				Arguments: []prompts.ArgMcpManifest{
					{Name: "param1", Description: "First param", Required: true},
					{Name: "param2", Description: "Second param", Required: false},
				},
			},
		},
	}

	for i := range cases {
		c := cases[i]
		t.Run(c.name, func(t *testing.T) {
			got := prompts.GetMcpManifest(c.promptName, c.description, c.args)
			diff := cmp.Diff(c.want, got)
			if diff == "" {
				return
			}
			t.Errorf("GetMcpManifest() mismatch (-want +got):\n%s", diff)
		})
	}
}
// TestGetManifest verifies prompts.GetManifest. The manifest must carry the
// description and one parameters.ParameterManifest per argument, with name,
// type, required flag, description, and an empty AuthServices list.
func TestGetManifest(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name        string
		description string
		args        prompts.Arguments
		want        prompts.Manifest
	}

	cases := []testCase{
		{
			name:        "No arguments",
			description: "A simple prompt.",
			args:        prompts.Arguments{},
			want: prompts.Manifest{
				Description: "A simple prompt.",
				Arguments:   []parameters.ParameterManifest{},
			},
		},
		{
			name:        "With arguments",
			description: "Prompt with arguments.",
			args: prompts.Arguments{
				{Parameter: parameters.NewStringParameter("param1", "First param")},
				{Parameter: parameters.NewBooleanParameterWithRequired("param2", "Second param", false)},
			},
			want: prompts.Manifest{
				Description: "Prompt with arguments.",
				Arguments: []parameters.ParameterManifest{
					{Name: "param1", Type: "string", Required: true, Description: "First param", AuthServices: []string{}},
					{Name: "param2", Type: "boolean", Required: false, Description: "Second param", AuthServices: []string{}},
				},
			},
		},
	}

	for i := range cases {
		c := cases[i]
		t.Run(c.name, func(t *testing.T) {
			got := prompts.GetManifest(c.description, c.args)
			diff := cmp.Diff(c.want, got)
			if diff == "" {
				return
			}
			t.Errorf("GetManifest() mismatch (-want +got):\n%s", diff)
		})
	}
}