tools.go•10.6 kB
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ai
import (
"context"
"encoding/json"
"errors"
"fmt"
"maps"
"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/api"
"github.com/firebase/genkit/go/internal/base"
)
var resumedCtxKey = base.NewContextKey[map[string]any]()
var origInputCtxKey = base.NewContextKey[any]()
// ToolFunc is the function type for tool implementations.
type ToolFunc[In, Out any] = func(ctx *ToolContext, input In) (Out, error)
// ToolRef is a reference to a tool.
type ToolRef interface {
Name() string
}
// ToolName is a distinct type for a tool name.
// It is meant to be passed where a ToolRef is expected but no Tool is had.
type ToolName string
// Name returns the name of the tool.
func (t ToolName) Name() string {
return (string)(t)
}
// tool is an action with functions specific to tools.
// It embeds [core.Action] instead of [core.ActionDef] like other primitives
// because the inputs/outputs can vary and the tool is only meant to be called
// with JSON input anyway.
type tool struct {
api.Action
}
// Tool represents a tool that can be called by a model.
type Tool interface {
// Name returns the name of the tool.
Name() string
// Definition returns the definition for this tool to be passed to models.
Definition() *ToolDefinition
// RunRaw runs this tool using the provided raw input.
RunRaw(ctx context.Context, input any) (any, error)
// Respond constructs a [Part] with a [ToolResponse] for a given interrupted tool request.
Respond(toolReq *Part, outputData any, opts *RespondOptions) *Part
// Restart constructs a [Part] with a new [ToolRequest] to re-trigger a tool,
// potentially with new input and metadata.
Restart(toolReq *Part, opts *RestartOptions) *Part
// Register registers the tool with the given registry.
Register(r api.Registry)
}
// toolInterruptError represents an intentional interruption of tool execution.
type toolInterruptError struct {
Metadata map[string]any
}
func (e *toolInterruptError) Error() string {
return "tool execution interrupted"
}
// IsToolInterruptError determines whether the error is an interrupt error returned by the tool.
func IsToolInterruptError(err error) (bool, map[string]any) {
var tie *toolInterruptError
if errors.As(err, &tie) {
return true, tie.Metadata
}
return false, nil
}
// InterruptOptions provides configuration for tool interruption.
type InterruptOptions struct {
Metadata map[string]any
}
// RestartOptions provides configuration options for restarting a tool.
type RestartOptions struct {
// ReplaceInput allows replacing the existing input arguments to the tool with different ones,
// for example if the user revised an action before confirming. When input is replaced,
// the existing tool request will be amended in the message history.
ReplaceInput any
// ResumedMetadata is the metadata you want to provide to the tool to aide in reprocessing.
// Defaults to true if none is supplied.
ResumedMetadata any
}
// RespondOptions provides configuration options for responding to a tool request.
type RespondOptions struct {
// Metadata is additional metadata to include in the response.
Metadata map[string]any
}
// ToolContext provides context and utility functions for tool execution.
type ToolContext struct {
context.Context
// Interrupt is a function that can be used to interrupt the tool execution.
// Interrupting tool execution returns the control to the caller with the
// total model response so far.
Interrupt func(opts *InterruptOptions) error
// Resumed is optional metadata that can be used to resume the tool execution.
// Map is not nil only if the tool was interrupted.
Resumed map[string]any
// OriginalInput is the original input to the tool if the tool was interrupted, otherwise nil.
OriginalInput any
}
// DefineTool creates a new [Tool] and registers it.
func DefineTool[In, Out any](
r api.Registry,
name, description string,
fn ToolFunc[In, Out],
) Tool {
metadata, wrappedFn := implementTool(name, description, fn)
toolAction := core.DefineAction(r, name, api.ActionTypeTool, metadata, nil, wrappedFn)
return &tool{Action: toolAction}
}
// DefineToolWithInputSchema creates a new [Tool] with a custom input schema and registers it.
func DefineToolWithInputSchema[Out any](
r api.Registry,
name, description string,
inputSchema map[string]any,
fn ToolFunc[any, Out],
) Tool {
metadata, wrappedFn := implementTool(name, description, fn)
toolAction := core.DefineAction(r, name, api.ActionTypeTool, metadata, inputSchema, wrappedFn)
return &tool{Action: toolAction}
}
// NewTool creates a new [Tool]. It can be passed directly to [Generate].
func NewTool[In, Out any](name, description string, fn ToolFunc[In, Out]) Tool {
metadata, wrappedFn := implementTool(name, description, fn)
metadata["dynamic"] = true
toolAction := core.NewAction(name, api.ActionTypeTool, metadata, nil, wrappedFn)
return &tool{Action: toolAction}
}
// NewToolWithInputSchema creates a new [Tool] with a custom input schema. It can be passed directly to [Generate].
func NewToolWithInputSchema[Out any](name, description string, inputSchema map[string]any, fn ToolFunc[any, Out]) Tool {
metadata, wrappedFn := implementTool(name, description, fn)
metadata["dynamic"] = true
toolAction := core.NewAction(name, api.ActionTypeTool, metadata, inputSchema, wrappedFn)
return &tool{Action: toolAction}
}
// implementTool creates the metadata and wrapped function common to both DefineTool and NewTool.
func implementTool[In, Out any](name, description string, fn ToolFunc[In, Out]) (map[string]any, func(context.Context, In) (Out, error)) {
metadata := map[string]any{
"type": api.ActionTypeTool,
"name": name,
"description": description,
}
wrappedFn := func(ctx context.Context, input In) (Out, error) {
toolCtx := &ToolContext{
Context: ctx,
Interrupt: func(opts *InterruptOptions) error {
return &toolInterruptError{
Metadata: opts.Metadata,
}
},
Resumed: resumedCtxKey.FromContext(ctx),
OriginalInput: origInputCtxKey.FromContext(ctx),
}
return fn(toolCtx, input)
}
return metadata, wrappedFn
}
// Name returns the name of the tool.
func (t *tool) Name() string {
return t.Action.Name()
}
// Definition returns [ToolDefinition] for for this tool.
func (t *tool) Definition() *ToolDefinition {
desc := t.Action.Desc()
return &ToolDefinition{
Name: desc.Name,
Description: desc.Description,
InputSchema: desc.InputSchema,
OutputSchema: desc.OutputSchema,
}
}
// RunRaw runs this tool using the provided raw map format data (JSON parsed
// as map[string]any).
func (t *tool) RunRaw(ctx context.Context, input any) (any, error) {
if t == nil {
return nil, core.NewError(core.INVALID_ARGUMENT, "Tool.RunRaw: tool called on a nil tool; check that all tools are defined")
}
mi, err := json.Marshal(input)
if err != nil {
return nil, fmt.Errorf("error marshalling tool input for %v: %v", t.Name(), err)
}
output, err := t.RunJSON(ctx, mi, nil)
if err != nil {
return nil, fmt.Errorf("error calling tool %v: %w", t.Name(), err)
}
var uo any
err = json.Unmarshal(output, &uo)
if err != nil {
return nil, fmt.Errorf("error parsing tool output for %v: %v", t.Name(), err)
}
return uo, nil
}
// LookupTool looks up the tool in the registry by provided name and returns it.
func LookupTool(r api.Registry, name string) Tool {
if name == "" {
return nil
}
provider, id := api.ParseName(name)
key := api.NewKey(api.ActionTypeTool, provider, id)
action := r.ResolveAction(key)
if action == nil {
return nil
}
return &tool{Action: action}
}
// Respond creates a tool response for an interrupted tool call to pass to the [WithToolResponses] option to [Generate].
// If the part provided is not a tool request, it returns nil.
func (t *tool) Respond(toolReq *Part, output any, opts *RespondOptions) *Part {
if toolReq == nil || !toolReq.IsToolRequest() {
return nil
}
if opts == nil {
opts = &RespondOptions{}
}
newToolResp := NewResponseForToolRequest(toolReq, output)
newToolResp.Metadata = map[string]any{
"interruptResponse": true,
}
if opts.Metadata != nil {
newToolResp.Metadata["interruptResponse"] = opts.Metadata
}
return newToolResp
}
// Restart creates a tool request for an interrupted tool call to pass to the [WithToolRestarts] option to [Generate].
// If the part provided is not a tool request, it returns nil.
func (t *tool) Restart(p *Part, opts *RestartOptions) *Part {
if p == nil || !p.IsToolRequest() {
return nil
}
if opts == nil {
opts = &RestartOptions{}
}
newInput := p.ToolRequest.Input
var originalInput any
if opts.ReplaceInput != nil {
originalInput = newInput
newInput = opts.ReplaceInput
}
newMeta := maps.Clone(p.Metadata)
if newMeta == nil {
newMeta = make(map[string]any)
}
newMeta["resumed"] = true
if opts.ResumedMetadata != nil {
newMeta["resumed"] = opts.ResumedMetadata
}
if originalInput != nil {
newMeta["replacedInput"] = originalInput
}
delete(newMeta, "interrupt")
newToolReq := NewToolRequestPart(&ToolRequest{
Name: p.ToolRequest.Name,
Ref: p.ToolRequest.Ref,
Input: newInput,
})
newToolReq.Metadata = newMeta
return newToolReq
}
// resolveUniqueTools resolves the list of tool refs to a list of all tool names and new tools that must be registered.
// Returns an error if there are tool refs with duplicate names.
func resolveUniqueTools(r api.Registry, toolRefs []ToolRef) (toolNames []string, newTools []Tool, err error) {
toolMap := make(map[string]bool)
for _, toolRef := range toolRefs {
name := toolRef.Name()
if toolMap[name] {
return nil, nil, core.NewError(core.INVALID_ARGUMENT, "duplicate tool %q", name)
}
toolMap[name] = true
toolNames = append(toolNames, name)
if LookupTool(r, name) == nil {
if tool, ok := toolRef.(Tool); ok {
newTools = append(newTools, tool)
}
}
}
return toolNames, newTools, nil
}