MCP Terminal Server
by dillip285
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
package ollama
import (
"testing"
"github.com/firebase/genkit/go/ai"
)
// TestConcatMessages runs concatMessages over a table of message lists and
// role filters and compares the returned string with the expected text. The
// table covers a single matching message, a mixed-role conversation filtered
// down to one role, and a filter that matches nothing (expected result "").
func TestConcatMessages(t *testing.T) {
	// textMessage builds a message for the given role holding one text part.
	textMessage := func(role ai.Role, text string) *ai.Message {
		return &ai.Message{
			Role:    role,
			Content: []*ai.Part{ai.NewTextPart(text)},
		}
	}

	type testCase struct {
		name     string
		messages []*ai.Message
		roles    []ai.Role
		want     string
	}

	cases := []testCase{
		{
			name: "Single message with matching role",
			messages: []*ai.Message{
				textMessage(ai.RoleUser, "Hello, how are you?"),
			},
			roles: []ai.Role{ai.RoleUser},
			want:  "Hello, how are you?",
		},
		{
			name: "Multiple messages with mixed roles",
			messages: []*ai.Message{
				textMessage(ai.RoleUser, "Tell me a joke."),
				textMessage(ai.RoleModel, "Why did the scarecrow win an award? Because he was outstanding in his field!"),
			},
			roles: []ai.Role{ai.RoleModel},
			want:  "Why did the scarecrow win an award? Because he was outstanding in his field!",
		},
		{
			name: "No matching role",
			messages: []*ai.Message{
				textMessage(ai.RoleUser, "Any suggestions?"),
			},
			roles: []ai.Role{ai.RoleSystem},
			want:  "",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := &ai.ModelRequest{Messages: tc.messages}
			if got := concatMessages(req, tc.roles); got != tc.want {
				t.Errorf("concatMessages() = %q, want %q", got, tc.want)
			}
		})
	}
}
// TestTranslateGenerateChunk feeds raw JSON strings to translateGenerateChunk
// and checks two things: that an error is returned exactly when one is
// expected, and that on success the chunk's content matches the expected text
// parts (compared with equalContent).
func TestTranslateGenerateChunk(t *testing.T) {
	tests := []struct {
		name    string
		input   string
		want    *ai.ModelResponseChunk
		wantErr bool
	}{
		{
			name:  "Valid JSON response",
			input: `{"model": "my-model", "created_at": "2024-06-20T12:34:56Z", "response": "This is a test response."}`,
			want: &ai.ModelResponseChunk{
				Content: []*ai.Part{ai.NewTextPart("This is a test response.")},
			},
			wantErr: false,
		},
		{
			name:    "Invalid JSON",
			input:   `{invalid}`,
			want:    nil,
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := translateGenerateChunk(tt.input)
			if (err != nil) != tt.wantErr {
				t.Fatalf("translateGenerateChunk() error = %v, wantErr %v", err, tt.wantErr)
			}
			if tt.wantErr {
				// The returned chunk is not inspected on the error path.
				return
			}
			// Guard the dereference below: a nil chunk alongside a nil error
			// should be reported as a test failure, not crash with a panic.
			if got == nil {
				t.Fatalf("translateGenerateChunk() got = nil with nil error, want %v", tt.want)
			}
			if !equalContent(got.Content, tt.want.Content) {
				t.Errorf("translateGenerateChunk() got = %v, want %v", got, tt.want)
			}
		})
	}
}
// equalContent reports whether a and b have the same length and, at every
// index, hold parts with identical Text that both report IsText().
func equalContent(a, b []*ai.Part) bool {
	if len(a) != len(b) {
		return false
	}
	for idx, left := range a {
		right := b[idx]
		// Compare the text first; the kind checks only run when it matches.
		if left.Text != right.Text {
			return false
		}
		if !left.IsText() {
			return false
		}
		if !right.IsText() {
			return false
		}
	}
	return true
}