googlegenai.go•15 kB
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
package googlegenai
import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"os"
	"strings"
	"sync"
	"cloud.google.com/go/auth/credentials"
	"cloud.google.com/go/auth/httptransport"
	"github.com/firebase/genkit/go/ai"
	"github.com/firebase/genkit/go/core/api"
	"github.com/firebase/genkit/go/genkit"
	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
	"google.golang.org/genai"
)
const (
	// Provider names: returned by each plugin's Name method and used as the
	// provider part of model and embedder names.
	googleAIProvider = "googleai"
	vertexAIProvider = "vertexai"
	// Label prefixes placed in front of model and embedder labels, in the
	// form "<prefix> - <label>".
	googleAILabelPrefix = "Google AI"
	vertexAILabelPrefix = "Vertex AI"
)
var (
	// defaultGeminiOpts is the fallback used by ListActions for listed Gemini
	// models that have no entry in supportedGeminiModels.
	defaultGeminiOpts = ai.ModelOptions{
		Supports: &Multimodal,
		Versions: []string{},
		Stage:    ai.ModelStageUnstable,
	}
	// defaultImagenOpts is the fallback used by ListActions for listed Imagen
	// models that have no entry in supportedImagenModels.
	defaultImagenOpts = ai.ModelOptions{
		Supports: &Media,
		Versions: []string{},
		Stage:    ai.ModelStageUnstable,
	}
	// defaultEmbedOpts is the fallback used by ListActions for listed
	// embedders that have no entry in googleAIEmbedderConfig.
	defaultEmbedOpts = ai.EmbedderOptions{
		Supports: &ai.EmbedderSupports{
			Input: []string{"text"},
		},
		Dimensions: 768,
	}
)
// GoogleAI is a Genkit plugin for interacting with the Google AI service.
// Call Init before using any method that needs the service client.
type GoogleAI struct {
	// APIKey is the API key used to access the service. If empty, Init
	// consults the environment variables GEMINI_API_KEY and GOOGLE_API_KEY,
	// in that order.
	APIKey string
	gclient *genai.Client // Client for the Google AI service; set by Init.
	mu      sync.Mutex    // Held by Init, DefineModel and DefineEmbedder while they read or write the fields below.
	initted bool          // Whether Init has completed successfully.
}
// VertexAI is a Genkit plugin for interacting with the Google Vertex AI service.
// Call Init before using any method that needs the service client.
type VertexAI struct {
	// ProjectID is the Google Cloud project to use for Vertex AI. If empty,
	// Init consults the environment variable GOOGLE_CLOUD_PROJECT.
	ProjectID string
	// Location is the location of the Vertex AI service. If empty, Init
	// consults the environment variables GOOGLE_CLOUD_LOCATION and
	// GOOGLE_CLOUD_REGION, in that order.
	Location  string
	gclient *genai.Client // Client for the Vertex AI service; set by Init.
	mu      sync.Mutex    // Held by Init, DefineModel and DefineEmbedder while they read or write the fields below.
	initted bool          // Whether Init has completed successfully.
}
// Name reports the provider name of the Google AI plugin ("googleai").
func (ga *GoogleAI) Name() string { return googleAIProvider }
// Name reports the provider name of the Vertex AI plugin ("vertexai").
func (v *VertexAI) Name() string { return vertexAIProvider }
// Init prepares the Google AI plugin for use by creating its genai client.
//
// The API key is taken from the APIKey field or, when that is empty, from the
// GEMINI_API_KEY and then the GOOGLE_API_KEY environment variable.
//
// Init panics if the plugin was already initialized, if no API key can be
// found, or if the client cannot be created. It returns an empty action list.
// After calling Init, you may call [DefineModel] and [DefineEmbedder] to
// create additional generative models and embedders.
func (ga *GoogleAI) Init(ctx context.Context) []api.Action {
	if ga == nil {
		ga = &GoogleAI{}
	}
	ga.mu.Lock()
	defer ga.mu.Unlock()
	if ga.initted {
		panic("plugin already initialized")
	}

	// An explicit key wins; otherwise take the first non-empty environment value.
	key := ga.APIKey
	for _, envVar := range []string{"GEMINI_API_KEY", "GOOGLE_API_KEY"} {
		if key != "" {
			break
		}
		key = os.Getenv(envVar)
	}
	if key == "" {
		panic("Google AI requires setting GEMINI_API_KEY or GOOGLE_API_KEY in the environment. You can get an API key at https://ai.google.dev")
	}

	// Route requests through the OpenTelemetry HTTP transport.
	tracedClient := &http.Client{
		Transport: otelhttp.NewTransport(http.DefaultTransport),
	}
	client, err := genai.NewClient(ctx, &genai.ClientConfig{
		Backend:     genai.BackendGeminiAPI,
		APIKey:      key,
		HTTPClient:  tracedClient,
		HTTPOptions: genai.HTTPOptions{Headers: genkitClientHeader},
	})
	if err != nil {
		panic(fmt.Errorf("GoogleAI.Init: %w", err))
	}

	ga.gclient, ga.initted = client, true
	return []api.Action{}
}
// Init prepares the Vertex AI plugin for use by creating its genai client.
//
// The project is taken from the ProjectID field or, when that is empty, from
// the GOOGLE_CLOUD_PROJECT environment variable. The location is taken from
// the Location field or, when that is empty, from GOOGLE_CLOUD_LOCATION and
// then GOOGLE_CLOUD_REGION. Requests are authenticated with the detected
// default credentials.
//
// Init panics if the plugin was already initialized, if the project or
// location cannot be determined, if default credentials or their quota
// project cannot be obtained, or if the HTTP or genai client cannot be
// created. It returns an empty action list.
// After calling Init, you may call [DefineModel] and [DefineEmbedder] to
// create additional generative models and embedders.
func (v *VertexAI) Init(ctx context.Context) []api.Action {
	if v == nil {
		v = &VertexAI{}
	}
	v.mu.Lock()
	defer v.mu.Unlock()
	if v.initted {
		panic("plugin already initialized")
	}

	projectID := v.ProjectID
	if projectID == "" {
		projectID = os.Getenv("GOOGLE_CLOUD_PROJECT")
		if projectID == "" {
			panic("Vertex AI requires setting GOOGLE_CLOUD_PROJECT in the environment. You can get a project ID at https://console.cloud.google.com/home/dashboard")
		}
	}

	location := v.Location
	if location == "" {
		location = os.Getenv("GOOGLE_CLOUD_LOCATION")
		if location == "" {
			location = os.Getenv("GOOGLE_CLOUD_REGION")
		}
		if location == "" {
			panic("Vertex AI requires setting GOOGLE_CLOUD_LOCATION or GOOGLE_CLOUD_REGION in the environment. You can get a location at https://cloud.google.com/vertex-ai/docs/general/locations")
		}
	}

	cred, err := credentials.DetectDefault(&credentials.DetectOptions{
		Scopes: []string{"https://www.googleapis.com/auth/cloud-platform"},
	})
	if err != nil {
		panic(fmt.Errorf("failed to find default credentials: %w", err))
	}

	quotaProjectID, err := cred.QuotaProjectID(ctx)
	if err != nil {
		// Report the underlying error; the quota project ID carries no
		// diagnostic information when the lookup itself failed.
		panic(fmt.Errorf("failed to get quota project ID: %w", err))
	}

	// Authenticated client whose requests go through the OpenTelemetry HTTP
	// transport and carry the quota project header.
	httpClient, err := httptransport.NewClient(&httptransport.Options{
		Credentials:      cred,
		BaseRoundTripper: otelhttp.NewTransport(http.DefaultTransport),
		Headers: http.Header{
			"X-Goog-User-Project": []string{quotaProjectID},
		},
	})
	if err != nil {
		panic(fmt.Errorf("failed to create http client: %w", err))
	}

	// Pass the resolved project and location (field value or environment
	// fallback), not the raw struct fields, which may be empty.
	// Project and location values get validated by the genai SDK upon client creation.
	gc := genai.ClientConfig{
		Backend:    genai.BackendVertexAI,
		Project:    projectID,
		Location:   location,
		HTTPClient: httpClient,
		HTTPOptions: genai.HTTPOptions{
			Headers: genkitClientHeader,
		},
	}
	client, err := genai.NewClient(ctx, &gc)
	if err != nil {
		panic(fmt.Errorf("VertexAI.Init: %w", err))
	}

	v.gclient = client
	v.initted = true
	return []api.Action{}
}
// DefineModel creates a model with the given name backed by the Google AI client.
//
// opts describes the capabilities of the model. When opts is nil, the options
// are looked up by name among the models returned by listModels for this
// provider; an unknown name combined with nil opts is an error.
// It also returns an error if the plugin has not been initialized or if
// listing the models fails. The g argument is not used by this method.
func (ga *GoogleAI) DefineModel(g *genkit.Genkit, name string, opts *ai.ModelOptions) (ai.Model, error) {
	ga.mu.Lock()
	defer ga.mu.Unlock()

	if !ga.initted {
		return nil, errors.New("GoogleAI plugin not initialized")
	}

	known, err := listModels(googleAIProvider)
	if err != nil {
		return nil, err
	}

	// Caller-supplied options take precedence over the known defaults.
	if opts != nil {
		return newModel(ga.gclient, name, *opts), nil
	}

	fallback, found := known[name]
	if !found {
		return nil, fmt.Errorf("GoogleAI.DefineModel: called with unknown model %q and nil ModelOptions", name)
	}
	return newModel(ga.gclient, name, fallback), nil
}
// DefineModel creates a model with the given name backed by the Vertex AI client.
//
// opts describes the capabilities of the model. When opts is nil, the options
// are looked up by name among the models returned by listModels for this
// provider; an unknown name combined with nil opts is an error.
// It also returns an error if the plugin has not been initialized or if
// listing the models fails. The g argument is not used by this method.
func (v *VertexAI) DefineModel(g *genkit.Genkit, name string, opts *ai.ModelOptions) (ai.Model, error) {
	v.mu.Lock()
	defer v.mu.Unlock()

	if !v.initted {
		return nil, errors.New("VertexAI plugin not initialized")
	}

	known, err := listModels(vertexAIProvider)
	if err != nil {
		return nil, err
	}

	// Caller-supplied options take precedence over the known defaults.
	if opts != nil {
		return newModel(v.gclient, name, *opts), nil
	}

	fallback, found := known[name]
	if !found {
		return nil, fmt.Errorf("VertexAI.DefineModel: called with unknown model %q and nil ModelOptions", name)
	}
	return newModel(v.gclient, name, fallback), nil
}
// DefineEmbedder creates an embedder with the given name and options, backed
// by the Google AI client. It returns an error if the plugin has not been
// initialized. The g argument is not used by this method.
func (ga *GoogleAI) DefineEmbedder(g *genkit.Genkit, name string, embedOpts *ai.EmbedderOptions) (ai.Embedder, error) {
	ga.mu.Lock()
	defer ga.mu.Unlock()

	if !ga.initted {
		return nil, errors.New("GoogleAI plugin not initialized")
	}

	embedder := newEmbedder(ga.gclient, name, embedOpts)
	return embedder, nil
}
// DefineEmbedder creates an embedder with the given name and options, backed
// by the Vertex AI client. It returns an error if the plugin has not been
// initialized. The g argument is not used by this method.
func (v *VertexAI) DefineEmbedder(g *genkit.Genkit, name string, embedOpts *ai.EmbedderOptions) (ai.Embedder, error) {
	v.mu.Lock()
	defer v.mu.Unlock()

	if !v.initted {
		return nil, errors.New("VertexAI plugin not initialized")
	}

	embedder := newEmbedder(v.gclient, name, embedOpts)
	return embedder, nil
}
// IsDefinedEmbedder reports whether an embedder with the given name can be
// looked up in g under the Google AI provider.
func (ga *GoogleAI) IsDefinedEmbedder(g *genkit.Genkit, name string) bool {
	key := api.NewName(googleAIProvider, name)
	embedder := genkit.LookupEmbedder(g, key)
	return embedder != nil
}
// IsDefinedEmbedder reports whether an embedder with the given name can be
// looked up in g under the Vertex AI provider.
func (v *VertexAI) IsDefinedEmbedder(g *genkit.Genkit, name string) bool {
	key := api.NewName(vertexAIProvider, name)
	embedder := genkit.LookupEmbedder(g, key)
	return embedder != nil
}
// GoogleAIModelRef builds a ModelRef for the Google AI model with the given
// name, using the "googleai/<name>" form, together with the given configuration.
func GoogleAIModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef {
	qualified := googleAIProvider + "/" + name
	return ai.NewModelRef(qualified, config)
}
// VertexAIModelRef builds a ModelRef for the Vertex AI model with the given
// name, using the "vertexai/<name>" form, together with the given configuration.
func VertexAIModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef {
	qualified := vertexAIProvider + "/" + name
	return ai.NewModelRef(qualified, config)
}
// GoogleAIModel looks up the [ai.Model] registered in g under the Google AI
// provider with the given name.
// It returns nil if the model was not defined.
func GoogleAIModel(g *genkit.Genkit, name string) ai.Model {
	key := api.NewName(googleAIProvider, name)
	return genkit.LookupModel(g, key)
}
// VertexAIModel looks up the [ai.Model] registered in g under the Vertex AI
// provider with the given name.
// It returns nil if the model was not defined.
func VertexAIModel(g *genkit.Genkit, name string) ai.Model {
	key := api.NewName(vertexAIProvider, name)
	return genkit.LookupModel(g, key)
}
// GoogleAIEmbedder looks up the [ai.Embedder] registered in g under the
// Google AI provider with the given name.
// It returns nil if the embedder was not defined.
func GoogleAIEmbedder(g *genkit.Genkit, name string) ai.Embedder {
	key := api.NewName(googleAIProvider, name)
	return genkit.LookupEmbedder(g, key)
}
// VertexAIEmbedder looks up the [ai.Embedder] registered in g under the
// Vertex AI provider with the given name.
// It returns nil if the embedder was not defined.
func VertexAIEmbedder(g *genkit.Genkit, name string) ai.Embedder {
	key := api.NewName(vertexAIProvider, name)
	return genkit.LookupEmbedder(g, key)
}
// ListActions returns descriptors for the Gemini models, Imagen models and
// embedders reported by listGenaiModels for the Google AI client, in that order.
//
// Names with an entry in the supported-model or embedder tables use those
// options; other names fall back to the package defaults. Model labels are
// always prefixed with "Google AI - ". Known embedder options are used as-is,
// and only fallback embedders get a generated label.
// It returns nil if the models cannot be listed.
func (ga *GoogleAI) ListActions(ctx context.Context) []api.ActionDesc {
	models, err := listGenaiModels(ctx, ga.gclient)
	if err != nil {
		return nil
	}

	actions := []api.ActionDesc{}
	// collect appends the descriptor of v when v implements api.Action.
	collect := func(v any) {
		if action, ok := v.(api.Action); ok {
			actions = append(actions, action.Desc())
		}
	}

	// Generative models.
	for _, name := range models.gemini {
		opts, label := defaultGeminiOpts, name
		if known, ok := supportedGeminiModels[name]; ok {
			opts, label = known, known.Label
		}
		opts.Label = fmt.Sprintf("%s - %s", googleAILabelPrefix, label)
		collect(newModel(ga.gclient, name, opts))
	}

	// Imagen models.
	for _, name := range models.imagen {
		opts, label := defaultImagenOpts, name
		if known, ok := supportedImagenModels[name]; ok {
			opts, label = known, known.Label
		}
		opts.Label = fmt.Sprintf("%s - %s", googleAILabelPrefix, label)
		collect(newModel(ga.gclient, name, opts))
	}

	// Embedders.
	for _, e := range models.embedders {
		embedOpts := defaultEmbedOpts
		if known, ok := googleAIEmbedderConfig[e]; ok {
			embedOpts = known
		} else {
			embedOpts.Label = fmt.Sprintf("%s - %s", googleAILabelPrefix, e)
		}
		collect(newEmbedder(ga.gclient, e, &embedOpts))
	}

	return actions
}
// ResolveAction builds the action of the given type and name on top of the
// Google AI client.
//
// Embedders are created with empty options. Models whose name contains
// "imagen" are created as media models with an image-generation config
// schema; every other model is created as a multimodal model with a
// content-generation config schema.
// It returns nil for any other action type, and panics if the created
// embedder or model does not implement api.Action.
func (ga *GoogleAI) ResolveAction(atype api.ActionType, name string) api.Action {
	if atype == api.ActionTypeEmbedder {
		embedder := newEmbedder(ga.gclient, name, &ai.EmbedderOptions{})
		return embedder.(api.Action)
	}
	if atype != api.ActionTypeModel {
		return nil
	}

	var (
		supports *ai.ModelSupports
		config   any
	)
	// TODO: Add veo case.
	if strings.Contains(name, "imagen") {
		supports = &Media
		config = &genai.GenerateImagesConfig{}
	} else {
		supports = &Multimodal
		config = &genai.GenerateContentConfig{}
	}

	opts := ai.ModelOptions{
		Label:        fmt.Sprintf("%s - %s", googleAILabelPrefix, name),
		Stage:        ai.ModelStageStable,
		Versions:     []string{},
		Supports:     supports,
		ConfigSchema: configToMap(config),
	}
	model := newModel(ga.gclient, name, opts)
	return model.(api.Action)
}
// ListActions returns descriptors for the Gemini models, Imagen models and
// embedders reported by listGenaiModels for the Vertex AI client, in that order.
//
// Names with an entry in the supported-model or embedder tables use those
// options; other names fall back to the package defaults. Model labels are
// always prefixed with "Vertex AI - ". Known embedder options are used as-is,
// and only fallback embedders get a generated label.
// It returns nil if the models cannot be listed.
func (v *VertexAI) ListActions(ctx context.Context) []api.ActionDesc {
	models, err := listGenaiModels(ctx, v.gclient)
	if err != nil {
		return nil
	}

	actions := []api.ActionDesc{}
	// collect appends the descriptor of a when a implements api.Action.
	collect := func(a any) {
		if action, ok := a.(api.Action); ok {
			actions = append(actions, action.Desc())
		}
	}

	// Gemini generative models.
	for _, name := range models.gemini {
		opts, label := defaultGeminiOpts, name
		if known, ok := supportedGeminiModels[name]; ok {
			opts, label = known, known.Label
		}
		opts.Label = fmt.Sprintf("%s - %s", vertexAILabelPrefix, label)
		collect(newModel(v.gclient, name, opts))
	}

	// Imagen models.
	for _, name := range models.imagen {
		opts, label := defaultImagenOpts, name
		if known, ok := supportedImagenModels[name]; ok {
			opts, label = known, known.Label
		}
		opts.Label = fmt.Sprintf("%s - %s", vertexAILabelPrefix, label)
		collect(newModel(v.gclient, name, opts))
	}

	// Embedders.
	for _, e := range models.embedders {
		embedOpts := defaultEmbedOpts
		if known, ok := googleAIEmbedderConfig[e]; ok {
			embedOpts = known
		} else {
			embedOpts.Label = fmt.Sprintf("%s - %s", vertexAILabelPrefix, e)
		}
		collect(newEmbedder(v.gclient, e, &embedOpts))
	}

	return actions
}
// ResolveAction builds the action of the given type and id on top of the
// Vertex AI client.
//
// Embedders are created with empty options. Models whose id contains
// "imagen" are created as media models with an image-generation config
// schema; every other model is created as a multimodal model with a
// content-generation config schema.
// It returns nil for any other action type, and panics if the created
// embedder or model does not implement api.Action.
func (v *VertexAI) ResolveAction(atype api.ActionType, id string) api.Action {
	if atype == api.ActionTypeEmbedder {
		embedder := newEmbedder(v.gclient, id, &ai.EmbedderOptions{})
		return embedder.(api.Action)
	}
	if atype != api.ActionTypeModel {
		return nil
	}

	var (
		supports *ai.ModelSupports
		config   any
	)
	// TODO: Add veo case.
	if strings.Contains(id, "imagen") {
		supports = &Media
		config = &genai.GenerateImagesConfig{}
	} else {
		supports = &Multimodal
		config = &genai.GenerateContentConfig{}
	}

	opts := ai.ModelOptions{
		Label:        fmt.Sprintf("%s - %s", vertexAILabelPrefix, id),
		Stage:        ai.ModelStageStable,
		Versions:     []string{},
		Supports:     supports,
		ConfigSchema: configToMap(config),
	}
	model := newModel(v.gclient, id, opts)
	return model.(api.Action)
}