// generate_live_test.go (7.61 kB)
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package compat_oai_test
import (
"context"
"fmt"
"os"
"slices"
"strings"
"testing"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/plugins/compat_oai"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"github.com/stretchr/testify/assert"
)
// defaultModel is the model name handed to compat_oai.NewModelGenerator when
// setupTestClient builds the generator used by the tests in this file.
const defaultModel = "gpt-4o-mini"
// setupTestClient returns a ModelGenerator for defaultModel that wraps an
// OpenAI client authenticated with the OPENAI_API_KEY environment variable.
// If that variable is empty the calling test is skipped.
func setupTestClient(t *testing.T) *compat_oai.ModelGenerator {
	t.Helper()

	key := os.Getenv("OPENAI_API_KEY")
	if len(key) == 0 {
		t.Skip("Skipping test: OPENAI_API_KEY environment variable not set")
	}

	oaiClient := openai.NewClient(option.WithAPIKey(key))
	generator := compat_oai.NewModelGenerator(&oaiClient, defaultModel)
	return generator
}
// TestGenerator_Complete sends a three-turn conversation (user, model, user)
// through Generate without a streaming callback and verifies that the reply
// has content and carries the model role.
func TestGenerator_Complete(t *testing.T) {
	g := setupTestClient(t)
	messages := []*ai.Message{
		{
			Role: ai.RoleUser,
			Content: []*ai.Part{
				ai.NewTextPart("Tell me a joke"),
			},
		},
		{
			Role: ai.RoleModel,
			Content: []*ai.Part{
				ai.NewTextPart("Why did the scarecrow win an award?"),
			},
		},
		{
			Role: ai.RoleUser,
			Content: []*ai.Part{
				ai.NewTextPart("Why?"),
			},
		},
	}
	req := &ai.ModelRequest{
		Messages: messages,
	}

	resp, err := g.WithMessages(messages).Generate(context.Background(), req, nil)
	if err != nil {
		// Stop here: the assertions below dereference resp, which is not
		// usable when Generate reports an error.
		t.Fatal(err)
	}
	if resp == nil || resp.Message == nil {
		t.Fatal("expecting a response with a message, got nil")
	}

	if len(resp.Message.Content) == 0 {
		t.Error("empty messages content, got 0")
	}
	if resp.Message.Role != ai.RoleModel {
		t.Errorf("unexpected role, got: %q, want: %q", resp.Message.Role, ai.RoleModel)
	}
}
// TestGenerator_Stream asks the model to count from 1 to 3 with a streaming
// callback installed, collects the text of every streamed part, and checks
// that the concatenated text mentions each of the three digits.
func TestGenerator_Stream(t *testing.T) {
	g := setupTestClient(t)
	messages := []*ai.Message{
		{
			Role: ai.RoleUser,
			Content: []*ai.Part{
				ai.NewTextPart("Count from 1 to 3"),
			},
		},
	}
	req := &ai.ModelRequest{
		Messages: messages,
	}

	// chunks accumulates the text of every part delivered to the callback,
	// in arrival order.
	var chunks []string
	handleChunk := func(ctx context.Context, chunk *ai.ModelResponseChunk) error {
		for _, part := range chunk.Content {
			chunks = append(chunks, part.Text)
		}
		return nil
	}

	_, err := g.WithMessages(messages).Generate(context.Background(), req, handleChunk)
	if err != nil {
		// A failed call makes the chunk assertions below meaningless; stop
		// here so the real error is not buried under follow-on failures.
		t.Fatal(err)
	}
	if len(chunks) == 0 {
		t.Fatal("expecting stream chunks, got: 0")
	}

	// Verify we got the full response.
	fullText := strings.Join(chunks, "")
	for _, want := range []string{"1", "2", "3"} {
		if !strings.Contains(fullText, want) {
			t.Errorf("expecting chunk to contain: %q, got: %q", want, fullText)
		}
	}
}
// TestWithConfig exercises WithConfig with every accepted configuration shape
// (nil, an empty openai.ChatCompletionNewParams, a populated one, and a
// map[string]any) plus an unsupported type. For accepted configs it runs
// Generate and then inspects the request returned by GetRequest; for the
// unsupported type it checks the exact error text.
func TestWithConfig(t *testing.T) {
	// expectUnset verifies that no tunable field was populated. The fields are
	// not nil when unset; they hold their zero value.
	expectUnset := func(t *testing.T, request *openai.ChatCompletionNewParams) {
		t.Helper()
		if request.Temperature.Value != 0 {
			t.Errorf("expecting empty in temperature, got: %v", request.Temperature.Value)
		}
		if request.MaxCompletionTokens.Value != 0 {
			t.Errorf("expecting empty max completion tokens, got: %v", request.MaxCompletionTokens.Value)
		}
		if request.TopP.Value != 0 {
			t.Errorf("expecting empty in topP, got: %v", request.TopP.Value)
		}
		if len(request.Stop.OfStringArray) != 0 {
			t.Errorf("expecting empty stop reasons, got: %v", request.Stop)
		}
	}

	// expectPopulated verifies that the values shared by the two "valid
	// config" cases below reached the outgoing request unchanged.
	expectPopulated := func(t *testing.T, request *openai.ChatCompletionNewParams) {
		t.Helper()
		// Two separate stop sequences, matching what the configs supply.
		stopReasons := []string{"stop1", "stop2"}
		if request.Temperature.Value != 0.7 {
			t.Errorf("unexpected temperature, got: %v, want: %v", request.Temperature.Value, 0.7)
		}
		if request.MaxCompletionTokens.Value != 100 {
			t.Errorf("unexpected max completion tokens, got: %v, want: %v", request.MaxCompletionTokens.Value, 100)
		}
		if request.TopP.Value != 0.9 {
			t.Errorf("unexpected topP, got: %v, want: %v", request.TopP.Value, 0.9)
		}
		if !slices.Equal(request.Stop.OfStringArray, stopReasons) {
			t.Errorf("diff in stop reasons, got: %v, want: %v", request.Stop.OfStringArray, stopReasons)
		}
	}

	tests := []struct {
		name     string
		config   any
		err      error // when non-nil, Generate must fail with this exact message
		validate func(*testing.T, *openai.ChatCompletionNewParams)
	}{
		{
			name:     "nil config",
			config:   nil,
			validate: expectUnset,
		},
		{
			name:     "empty openai config",
			config:   openai.ChatCompletionNewParams{},
			validate: expectUnset,
		},
		{
			name: "valid config with all supported fields",
			config: openai.ChatCompletionNewParams{
				Temperature:         openai.Float(0.7),
				MaxCompletionTokens: openai.Int(100),
				TopP:                openai.Float(0.9),
				Stop: openai.ChatCompletionNewParamsStopUnion{
					OfStringArray: []string{"stop1", "stop2"},
				},
			},
			validate: expectPopulated,
		},
		{
			name: "valid config as map",
			config: map[string]any{
				"temperature":           0.7,
				"max_completion_tokens": 100,
				"top_p":                 0.9,
				"stop":                  []string{"stop1", "stop2"},
			},
			validate: expectPopulated,
		},
		{
			name:   "invalid config type",
			config: "not a config",
			err:    fmt.Errorf("unexpected config type: string"),
		},
	}

	// define simple messages for testing
	messages := []*ai.Message{
		{
			Role: ai.RoleUser,
			Content: []*ai.Part{
				ai.NewTextPart("Tell me a joke"),
			},
		},
	}
	req := &ai.ModelRequest{
		Messages: messages,
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			generator := setupTestClient(t)
			result, err := generator.WithMessages(messages).WithConfig(tt.config).Generate(context.Background(), req, nil)

			if tt.err != nil {
				// EqualError reports a failure for a nil err rather than
				// panicking on err.Error().
				assert.EqualError(t, err, tt.err.Error())
				return
			}

			// validate that the response was successful; inspecting the
			// request is pointless if the call itself failed
			if !assert.NoError(t, err) {
				return
			}
			assert.NotNil(t, result)

			// validate the input request was transformed correctly
			if tt.validate != nil {
				tt.validate(t, generator.GetRequest())
			}
		})
	}
}