Genkit MCP

// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0

package ollama

import (
	"bufio"
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"slices"
	"strings"
	"sync"
	"time"

	"github.com/firebase/genkit/go/ai"
	"github.com/firebase/genkit/go/genkit"
	"github.com/firebase/genkit/go/plugins/internal/uri"
)

const provider = "ollama"

var (
	mediaSupportedModels = []string{"llava"}
	roleMapping          = map[ai.Role]string{
		ai.RoleUser:   "user",
		ai.RoleModel:  "assistant",
		ai.RoleSystem: "system",
	}
)

var state struct {
	serverAddress string
	initted       bool
	mu            sync.Mutex
}

// DefineModel defines a model backed by the configured Ollama server and
// registers it with Genkit. It panics if Init has not been called.
func DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
	state.mu.Lock()
	defer state.mu.Unlock()
	if !state.initted {
		panic("ollama.Init not called")
	}
	var mi ai.ModelInfo
	if info != nil {
		mi = *info
	} else {
		mi = ai.ModelInfo{
			Label: model.Name,
			Supports: &ai.ModelInfoSupports{
				Multiturn:  true,
				SystemRole: true,
				Media:      slices.Contains(mediaSupportedModels, model.Name),
			},
			Versions: []string{},
		}
	}
	meta := &ai.ModelInfo{
		Label:    "Ollama - " + model.Name,
		Supports: mi.Supports,
		Versions: []string{},
	}
	gen := &generator{model: model, serverAddress: state.serverAddress}
	return genkit.DefineModel(g, provider, model.Name, meta, gen.generate)
}

// IsDefinedModel reports whether a model is defined.
func IsDefinedModel(g *genkit.Genkit, name string) bool {
	return genkit.IsDefinedModel(g, provider, name)
}

// Model returns the [ai.Model] with the given name.
// It returns nil if the model was not configured.
func Model(g *genkit.Genkit, name string) ai.Model {
	return genkit.LookupModel(g, provider, name)
}

// ModelDefinition represents a model with its name and type.
type ModelDefinition struct {
	Name string
	Type string
}

type generator struct {
	model         ModelDefinition
	serverAddress string
}

type ollamaMessage struct {
	Role    string   `json:"role"`
	Content string   `json:"content"`
	Images  []string `json:"images,omitempty"`
}

// Ollama has two API endpoints, one with a chat interface and another with a
// generate interface. That's why we have multiple request types for the Ollama
// API below.

/*
TODO: Support optional, advanced parameters:
	format: the format to return a response in. Currently the only accepted value is json
	options: additional model parameters listed in the documentation for the Modelfile, such as temperature
	system: system message (overrides what is defined in the Modelfile)
	template: the prompt template to use (overrides what is defined in the Modelfile)
	context: the context parameter returned from a previous request to /generate; this can be used to keep a short conversational memory
	stream: if false, the response will be returned as a single response object rather than a stream of objects
	raw: if true, no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API
	keep_alive: controls how long the model will stay loaded into memory following the request (default: 5m)
*/

type ollamaChatRequest struct {
	Messages []*ollamaMessage `json:"messages"`
	Model    string           `json:"model"`
	Stream   bool             `json:"stream"`
}

type ollamaModelRequest struct {
	System string   `json:"system,omitempty"`
	Images []string `json:"images,omitempty"`
	Model  string   `json:"model"`
	Prompt string   `json:"prompt"`
	Stream bool     `json:"stream"`
}

// TODO: Add optional parameters (images, format, options, etc.) based on your use case.
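// For illustration only (not part of the plugin): given the struct tags above,
// a chat request to /api/chat marshals to JSON shaped roughly like
//
//	{
//	  "messages": [
//	    {"role": "system", "content": "You are a helpful assistant."},
//	    {"role": "user", "content": "Hello"}
//	  ],
//	  "model": "llama3",
//	  "stream": false
//	}
//
// while a generate request to /api/generate carries a single flattened "prompt"
// string plus optional "system" and "images" fields. The model name and message
// text here are placeholder values, not something this file defines.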
type ollamaChatResponse struct {
	Model     string `json:"model"`
	CreatedAt string `json:"created_at"`
	Message   struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	} `json:"message"`
}

type ollamaModelResponse struct {
	Model     string `json:"model"`
	CreatedAt string `json:"created_at"`
	Response  string `json:"response"`
}

// Config provides configuration options for the Init function.
type Config struct {
	// ServerAddress is the address of the Ollama server.
	ServerAddress string
}

// Init initializes the plugin.
// Since Ollama models are locally hosted, the plugin doesn't initialize any default models.
// After downloading a model, call [DefineModel] to use it.
func Init(ctx context.Context, cfg *Config) (err error) {
	state.mu.Lock()
	defer state.mu.Unlock()
	if state.initted {
		panic("ollama.Init already called")
	}
	if cfg == nil || cfg.ServerAddress == "" {
		return errors.New("ollama: need ServerAddress")
	}
	state.serverAddress = cfg.ServerAddress
	state.initted = true
	return nil
}

// generate makes a request to the Ollama API and processes the response.
func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
	stream := cb != nil
	var payload any
	isChatModel := g.model.Type == "chat"
	if !isChatModel {
		images, err := concatImages(input, []ai.Role{ai.RoleUser, ai.RoleModel})
		if err != nil {
			return nil, fmt.Errorf("failed to grab image parts: %v", err)
		}
		payload = ollamaModelRequest{
			Model:  g.model.Name,
			Prompt: concatMessages(input, []ai.Role{ai.RoleUser, ai.RoleModel, ai.RoleTool}),
			System: concatMessages(input, []ai.Role{ai.RoleSystem}),
			Images: images,
			Stream: stream,
		}
	} else {
		var messages []*ollamaMessage
		// Translate all messages to the Ollama message format.
		for _, m := range input.Messages {
			message, err := convertParts(m.Role, m.Content)
			if err != nil {
				return nil, fmt.Errorf("failed to convert message parts: %v", err)
			}
			messages = append(messages, message)
		}
		payload = ollamaChatRequest{
			Messages: messages,
			Model:    g.model.Name,
			Stream:   stream,
		}
	}

	client := &http.Client{Timeout: 30 * time.Second}
	payloadBytes, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}

	// Determine the correct endpoint.
	endpoint := g.serverAddress + "/api/chat"
	if !isChatModel {
		endpoint = g.serverAddress + "/api/generate"
	}

	req, err := http.NewRequest("POST", endpoint, bytes.NewReader(payloadBytes))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req = req.WithContext(ctx)

	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %v", err)
	}
	defer resp.Body.Close()

	if cb == nil {
		// Non-streaming response: read the whole body and translate it once.
		var err error
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("failed to read response body: %v", err)
		}
		if resp.StatusCode != http.StatusOK {
			return nil, fmt.Errorf("server returned non-200 status: %d, body: %s", resp.StatusCode, body)
		}
		var response *ai.ModelResponse
		if isChatModel {
			response, err = translateChatResponse(body)
		} else {
			response, err = translateModelResponse(body)
		}
		if err != nil {
			return nil, fmt.Errorf("failed to parse response: %v", err)
		}
		response.Request = input
		return response, nil
	} else {
		// Streaming response: translate each line into a chunk and forward it to the callback.
		var chunks []*ai.ModelResponseChunk
		scanner := bufio.NewScanner(resp.Body)
		for scanner.Scan() {
			line := scanner.Text()
			var chunk *ai.ModelResponseChunk
			if isChatModel {
				chunk, err = translateChatChunk(line)
			} else {
				chunk, err = translateGenerateChunk(line)
			}
			if err != nil {
				return nil, fmt.Errorf("failed to translate chunk: %v", err)
			}
			chunks = append(chunks, chunk)
			cb(ctx, chunk)
		}
		if err := scanner.Err(); err != nil {
			return nil, fmt.Errorf("reading response stream: %v", err)
		}
		// Create a final response with the merged chunks.
		finalResponse := &ai.ModelResponse{
			Request:      input,
			FinishReason: ai.FinishReason("stop"),
			Message: &ai.Message{
				Role: ai.RoleModel,
			},
		}
		// Add all the merged content to the final response's message.
		for _, chunk := range chunks {
			finalResponse.Message.Content = append(finalResponse.Message.Content, chunk.Content...)
		}
		return finalResponse, nil // Return the final merged response.
	}
}

// convertParts converts Genkit message parts into a single Ollama message.
func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
	message := &ollamaMessage{
		Role: roleMapping[role],
	}
	var contentBuilder strings.Builder
	for _, part := range parts {
		if part.IsText() {
			contentBuilder.WriteString(part.Text)
		} else if part.IsMedia() {
			_, data, err := uri.Data(part)
			if err != nil {
				return nil, err
			}
			base64Encoded := base64.StdEncoding.EncodeToString(data)
			message.Images = append(message.Images, base64Encoded)
		} else {
			return nil, errors.New("unknown content type")
		}
	}
	message.Content = contentBuilder.String()
	return message, nil
}

// translateChatResponse translates an Ollama chat response into a Genkit response.
func translateChatResponse(responseData []byte) (*ai.ModelResponse, error) {
	var response ollamaChatResponse
	if err := json.Unmarshal(responseData, &response); err != nil {
		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
	}
	modelResponse := &ai.ModelResponse{
		FinishReason: ai.FinishReason("stop"),
		Message: &ai.Message{
			Role: ai.Role(response.Message.Role),
		},
	}
	aiPart := ai.NewTextPart(response.Message.Content)
	modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
	return modelResponse, nil
}

// translateModelResponse translates an Ollama generate response into a Genkit response.
func translateModelResponse(responseData []byte) (*ai.ModelResponse, error) {
	var response ollamaModelResponse
	if err := json.Unmarshal(responseData, &response); err != nil {
		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
	}
	modelResponse := &ai.ModelResponse{
		FinishReason: ai.FinishReason("stop"),
		Message: &ai.Message{
			Role: ai.RoleModel,
		},
	}
	aiPart := ai.NewTextPart(response.Response)
	modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
	modelResponse.Usage = &ai.GenerationUsage{} // TODO: can we get any of this info?
	return modelResponse, nil
}

// translateChatChunk translates one line of a streaming chat response into a Genkit chunk.
func translateChatChunk(input string) (*ai.ModelResponseChunk, error) {
	var response ollamaChatResponse
	if err := json.Unmarshal([]byte(input), &response); err != nil {
		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
	}
	chunk := &ai.ModelResponseChunk{}
	aiPart := ai.NewTextPart(response.Message.Content)
	chunk.Content = append(chunk.Content, aiPart)
	return chunk, nil
}

// translateGenerateChunk translates one line of a streaming generate response into a Genkit chunk.
func translateGenerateChunk(input string) (*ai.ModelResponseChunk, error) {
	var response ollamaModelResponse
	if err := json.Unmarshal([]byte(input), &response); err != nil {
		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
	}
	chunk := &ai.ModelResponseChunk{}
	aiPart := ai.NewTextPart(response.Response)
	chunk.Content = append(chunk.Content, aiPart)
	return chunk, nil
}

// concatMessages concatenates the text parts of messages with the given roles into a prompt-style string.
func concatMessages(input *ai.ModelRequest, roles []ai.Role) string {
	roleSet := make(map[ai.Role]bool)
	for _, role := range roles {
		roleSet[role] = true // Build a set for faster lookup.
	}
	var sb strings.Builder
	for _, message := range input.Messages {
		// Skip messages whose role is not in the allowed set.
		if !roleSet[message.Role] {
			continue
		}
		for _, part := range message.Content {
			if !part.IsText() {
				continue
			}
			sb.WriteString(part.Text)
		}
	}
	return sb.String()
}

// concatImages collects the media parts of messages with the given roles as base64-encoded images.
func concatImages(input *ai.ModelRequest, roleFilter []ai.Role) ([]string, error) {
	roleSet := make(map[ai.Role]bool)
	for _, role := range roleFilter {
		roleSet[role] = true
	}
	var images []string
	for _, message := range input.Messages {
		// Only collect images from messages whose role is in the allowed set.
		if roleSet[message.Role] {
			for _, part := range message.Content {
				if !part.IsMedia() {
					continue
				}
				_, data, err := uri.Data(part)
				if err != nil {
					return nil, err
				}
				base64Encoded := base64.StdEncoding.EncodeToString(data)
				images = append(images, base64Encoded)
			}
		}
	}
	return images, nil
}
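The file above only defines the plugin package, so a minimal usage sketch may help. The following is illustrative, not part of the source: it assumes the package is imported as github.com/firebase/genkit/go/plugins/ollama, that an Ollama server is listening on the default address http://127.0.0.1:11434, and that the llama3 model has already been pulled locally (for example with `ollama pull llama3`). Creating the *genkit.Genkit instance itself is left to the caller, since that part of the API depends on the Genkit Go version in use.

package ollamaexample

import (
	"context"

	"github.com/firebase/genkit/go/ai"
	"github.com/firebase/genkit/go/genkit"
	"github.com/firebase/genkit/go/plugins/ollama"
)

// setupOllama is a hypothetical helper showing the plugin's intended call order:
// Init must run before DefineModel, which otherwise panics.
func setupOllama(ctx context.Context, g *genkit.Genkit) (ai.Model, error) {
	// Point the plugin at a locally running Ollama server (assumed default port).
	if err := ollama.Init(ctx, &ollama.Config{ServerAddress: "http://127.0.0.1:11434"}); err != nil {
		return nil, err
	}
	// Register a model that has already been pulled. Type "chat" routes requests
	// to /api/chat; any other type falls back to /api/generate. Passing nil for
	// the ModelInfo lets the plugin fill in default capabilities.
	m := ollama.DefineModel(g, ollama.ModelDefinition{Name: "llama3", Type: "chat"}, nil)
	return m, nil
}

With a nil ModelInfo, DefineModel advertises multiturn and system-role support, plus media support only for models listed in mediaSupportedModels (currently just llava).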