evaluators.go•6.26 kB
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
// Package evaluators defines a set of Genkit Evaluators for popular use-cases
package evaluators
import (
	"context"
	"errors"
	"fmt"
	"reflect"
	"regexp"
	"sync"
	jsonata "github.com/blues/jsonata-go"
	"github.com/firebase/genkit/go/ai"
	"github.com/firebase/genkit/go/core/api"
	"github.com/firebase/genkit/go/core/logger"
)
const provider = "genkitEval"
// EvaluatorType is an enum used to indicate the type of evaluator being
// configured for use
type EvaluatorType int
const (
	EvaluatorDeepEqual EvaluatorType = iota
	EvaluatorRegex
	EvaluatorJsonata
)
var evaluatorTypeName = map[EvaluatorType]string{
	EvaluatorDeepEqual: "DEEP_EQUAL",
	EvaluatorRegex:     "REGEX",
	EvaluatorJsonata:   "JSONATA",
}
func (ss EvaluatorType) String() string {
	return evaluatorTypeName[ss]
}
// MetricConfig provides configuration options for a specific metric. More
// Params (judge LLMs, etc.) could be configured by extending this struct
type MetricConfig struct {
	MetricType EvaluatorType
}
// GenkitEval is a Genkit plugin that provides evaluators
type GenkitEval struct {
	Metrics []MetricConfig // Configs for individual metrics
	initted bool           // Whether the plugin has been initialized
	mu      sync.Mutex     // Mutex to manage locks
}
func (ge *GenkitEval) Name() string {
	return provider
}
// Init initializes the plugin.
func (ge *GenkitEval) Init(ctx context.Context) []api.Action {
	if ge == nil {
		ge = &GenkitEval{}
	}
	ge.mu.Lock()
	defer ge.mu.Unlock()
	if ge.initted {
		panic("genkitEval.Init already called")
	}
	if ge == nil || len(ge.Metrics) == 0 {
		panic("genkitEval: need to configure at least one metric")
	}
	ge.initted = true
	var actions []api.Action
	for _, metric := range ge.Metrics {
		actions = append(actions, ConfigureMetric(metric).(api.Action))
	}
	return actions
}
func ConfigureMetric(metric MetricConfig) ai.Evaluator {
	switch metric.MetricType {
	case EvaluatorDeepEqual:
		return configureDeepEqualEvaluator()
	case EvaluatorJsonata:
		return configureJsonataEvaluator()
	case EvaluatorRegex:
		return configureRegexEvaluator()
	default:
		panic(fmt.Sprintf("Unsupported genkitEval metric type: %s", metric.MetricType.String()))
	}
}
func configureRegexEvaluator() ai.Evaluator {
	evalOptions := ai.EvaluatorOptions{
		DisplayName: "RegExp",
		Definition:  "Tests output against the regexp provided as reference",
		IsBilled:    false,
	}
	return ai.NewEvaluator(api.NewName(provider, "regex"), &evalOptions, func(ctx context.Context, req *ai.EvaluatorCallbackRequest) (*ai.EvaluatorCallbackResponse, error) {
		dataPoint := req.Input
		var score ai.Score
		if dataPoint.Output == nil {
			return nil, errors.New("output was not provided")
		}
		if dataPoint.Reference == nil {
			return nil, errors.New("reference was not provided")
		}
		if reflect.TypeOf(dataPoint.Reference).String() != "string" {
			return nil, errors.New("reference must be a string (regex)")
		}
		if reflect.TypeOf(dataPoint.Output).String() == "string" {
			// Test against provided regexp
			match, _ := regexp.MatchString((dataPoint.Reference).(string), (dataPoint.Output).(string))
			status := ai.ScoreStatusUnknown
			if match {
				status = ai.ScoreStatusPass
			} else {
				status = ai.ScoreStatusFail
			}
			score = ai.Score{
				Score:  match,
				Status: status.String(),
			}
		} else {
			// Mark as failed if output is not string type
			logger.FromContext(ctx).Debug("genkitEval",
				"regex", fmt.Sprintf("Failed regex evaluation, as output is not string api. TestCaseId: %s", dataPoint.TestCaseId))
			score = ai.Score{
				Score:  false,
				Status: ai.ScoreStatusFail.String(),
			}
		}
		callbackResponse := ai.EvaluatorCallbackResponse{
			TestCaseId: req.Input.TestCaseId,
			Evaluation: []ai.Score{score},
		}
		return &callbackResponse, nil
	})
}
func configureDeepEqualEvaluator() ai.Evaluator {
	evalOptions := ai.EvaluatorOptions{
		DisplayName: "Deep Equal",
		Definition:  "Tests equality of output against the provided reference",
		IsBilled:    false,
	}
	return ai.NewEvaluator(api.NewName(provider, "deep_equal"), &evalOptions, func(ctx context.Context, req *ai.EvaluatorCallbackRequest) (*ai.EvaluatorCallbackResponse, error) {
		dataPoint := req.Input
		var score ai.Score
		if dataPoint.Output == nil {
			return nil, errors.New("output was not provided")
		}
		if dataPoint.Reference == nil {
			return nil, errors.New("reference was not provided")
		}
		deepEqual := reflect.DeepEqual(dataPoint.Reference, dataPoint.Output)
		status := ai.ScoreStatusUnknown
		if deepEqual {
			status = ai.ScoreStatusPass
		} else {
			status = ai.ScoreStatusFail
		}
		score = ai.Score{
			Score:  deepEqual,
			Status: status.String(),
		}
		callbackResponse := ai.EvaluatorCallbackResponse{
			TestCaseId: req.Input.TestCaseId,
			Evaluation: []ai.Score{score},
		}
		return &callbackResponse, nil
	})
}
func configureJsonataEvaluator() ai.Evaluator {
	evalOptions := ai.EvaluatorOptions{
		DisplayName: "JSONata",
		Definition:  "Tests JSONata expression (provided in reference) against output",
		IsBilled:    false,
	}
	return ai.NewEvaluator(api.NewName(provider, "jsonata"), &evalOptions, func(ctx context.Context, req *ai.EvaluatorCallbackRequest) (*ai.EvaluatorCallbackResponse, error) {
		dataPoint := req.Input
		var score ai.Score
		if dataPoint.Output == nil {
			return nil, errors.New("output was not provided")
		}
		if dataPoint.Reference == nil {
			return nil, errors.New("reference was not provided")
		}
		if reflect.TypeOf(dataPoint.Reference).String() != "string" {
			return nil, errors.New("reference must be a string (jsonata)")
		}
		// Test against provided jsonata
		exp := jsonata.MustCompile((dataPoint.Reference).(string))
		res, err := exp.Eval(dataPoint.Output)
		if err != nil {
			return nil, err
		}
		status := ai.ScoreStatusUnknown
		if res == false || res == "" || res == nil {
			status = ai.ScoreStatusFail
		} else {
			status = ai.ScoreStatusPass
		}
		score = ai.Score{
			Score:  res,
			Status: status.String(),
		}
		callbackResponse := ai.EvaluatorCallbackResponse{
			TestCaseId: req.Input.TestCaseId,
			Evaluation: []ai.Score{score},
		}
		return &callbackResponse, nil
	})
}