mcp-netbird
by aantti
Verified
package mcpnetbird
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"github.com/invopop/jsonschema"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
// Tool is a struct that represents a tool definition and the function used
// to handle tool calls.
//
// The simplest way to create a Tool is to use `MustTool`, or `ConvertTool`
// if you wish to create tools at runtime and need to handle errors without
// panicking.
type Tool struct {
Tool mcp.Tool
Handler server.ToolHandlerFunc
}
// Register adds the Tool to the given MCPServer.
//
// It is a convenience method that calls `server.MCPServer.Register` with the
// Tool's Tool and Handler fields, allowing you to add the tool in a single
// statement:
//
// mcpnetbird.MustTool(name, description, toolHandler).Register(server)
func (t *Tool) Register(mcp *server.MCPServer) {
mcp.AddTool(t.Tool, t.Handler)
}
// MustTool creates a new Tool from the given name, description, and toolHandler.
// It panics if the tool cannot be created.
func MustTool[T any, R any](name, description string, toolHandler ToolHandlerFunc[T, R]) Tool {
tool, handler, err := ConvertTool(name, description, toolHandler)
if err != nil {
panic(err)
}
return Tool{Tool: tool, Handler: handler}
}
// ToolHandlerFunc is the type of a handler function for a tool.
type ToolHandlerFunc[T any, R any] = func(ctx context.Context, request T) (R, error)
// ConvertTool converts a toolHandler function to a Tool and ToolHandlerFunc.
//
// The toolHandler function must have two arguments: a context.Context and a struct
// to be used as the parameters for the tool. The second argument must not be a pointer,
// should be marshalable to JSON, and the fields should have a `jsonschema` tag with the
// description of the parameter.
func ConvertTool[T any, R any](name, description string, toolHandler ToolHandlerFunc[T, R]) (mcp.Tool, server.ToolHandlerFunc, error) {
zero := mcp.Tool{}
handlerValue := reflect.ValueOf(toolHandler)
handlerType := handlerValue.Type()
if handlerType.Kind() != reflect.Func {
return zero, nil, errors.New("tool handler must be a function")
}
if handlerType.NumIn() != 2 {
return zero, nil, errors.New("tool handler must have 2 arguments")
}
if handlerType.NumOut() != 2 {
return zero, nil, errors.New("tool handler must return 2 values")
}
if handlerType.In(0) != reflect.TypeOf((*context.Context)(nil)).Elem() {
return zero, nil, errors.New("tool handler first argument must be context.Context")
}
// We no longer check the type of the first return value
if handlerType.Out(1).Kind() != reflect.Interface {
return zero, nil, errors.New("tool handler second return value must be error")
}
argType := handlerType.In(1)
if argType.Kind() != reflect.Struct {
return zero, nil, errors.New("tool handler second argument must be a struct")
}
handler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
s, err := json.Marshal(request.Params.Arguments)
if err != nil {
return nil, fmt.Errorf("marshal args: %w", err)
}
unmarshaledArgs := reflect.New(argType).Interface()
if err := json.Unmarshal([]byte(s), unmarshaledArgs); err != nil {
return nil, fmt.Errorf("unmarshal args: %s", err)
}
// Need to dereference the unmarshaled arguments
of := reflect.ValueOf(unmarshaledArgs)
if of.Kind() != reflect.Ptr || !of.Elem().CanInterface() {
return nil, errors.New("arguments must be a struct")
}
args := []reflect.Value{reflect.ValueOf(ctx), of.Elem()}
output := handlerValue.Call(args)
if len(output) != 2 {
return nil, errors.New("tool handler must return 2 values")
}
if !output[0].CanInterface() {
return nil, errors.New("tool handler first return value must be interfaceable")
}
// Handle the error return value first
var handlerErr error
var ok bool
if output[1].Kind() == reflect.Interface && !output[1].IsNil() {
handlerErr, ok = output[1].Interface().(error)
if !ok {
return nil, errors.New("tool handler second return value must be error")
}
}
// If there's an error, return nil result and the error
if handlerErr != nil {
return nil, handlerErr
}
// Check if the first return value is nil (only for pointer, interface, map, etc.)
isNilable := output[0].Kind() == reflect.Ptr ||
output[0].Kind() == reflect.Interface ||
output[0].Kind() == reflect.Map ||
output[0].Kind() == reflect.Slice ||
output[0].Kind() == reflect.Chan ||
output[0].Kind() == reflect.Func
if isNilable && output[0].IsNil() {
return nil, nil
}
returnVal := output[0].Interface()
returnType := output[0].Type()
// Case 1: Already a *mcp.CallToolResult
if callResult, ok := returnVal.(*mcp.CallToolResult); ok {
return callResult, nil
}
// Case 2: An mcp.CallToolResult (not a pointer)
if returnType.ConvertibleTo(reflect.TypeOf(mcp.CallToolResult{})) {
callResult := returnVal.(mcp.CallToolResult)
return &callResult, nil
}
// Case 3: String or *string
if str, ok := returnVal.(string); ok {
if str == "" {
return nil, nil
}
return mcp.NewToolResultText(str), nil
}
if strPtr, ok := returnVal.(*string); ok {
if strPtr == nil || *strPtr == "" {
return nil, nil
}
return mcp.NewToolResultText(*strPtr), nil
}
// Case 4: Any other type - marshal to JSON
jsonBytes, err := json.Marshal(returnVal)
if err != nil {
return nil, fmt.Errorf("failed to marshal return value: %s", err)
}
return mcp.NewToolResultText(string(jsonBytes)), nil
}
jsonSchema := createJSONSchemaFromHandler(toolHandler)
properties := make(map[string]any, jsonSchema.Properties.Len())
for pair := jsonSchema.Properties.Oldest(); pair != nil; pair = pair.Next() {
properties[pair.Key] = pair.Value
}
inputSchema := mcp.ToolInputSchema{
Type: jsonSchema.Type,
Properties: properties,
Required: jsonSchema.Required,
}
return mcp.Tool{
Name: name,
Description: description,
InputSchema: inputSchema,
}, handler, nil
}
// Creates a full JSON schema from a user provided handler by introspecting the arguments
func createJSONSchemaFromHandler(handler any) *jsonschema.Schema {
handlerValue := reflect.ValueOf(handler)
handlerType := handlerValue.Type()
argumentType := handlerType.In(1)
inputSchema := jsonSchemaReflector.ReflectFromType(argumentType)
return inputSchema
}
var (
jsonSchemaReflector = jsonschema.Reflector{
BaseSchemaID: "",
Anonymous: true,
AssignAnchor: false,
AllowAdditionalProperties: true,
RequiredFromJSONSchemaTags: true,
DoNotReference: true,
ExpandedStruct: true,
FieldNameTag: "",
IgnoredTypes: nil,
Lookup: nil,
Mapper: nil,
Namer: nil,
KeyNamer: nil,
AdditionalFields: nil,
CommentMap: nil,
}
)