Genkit MCP

Apache 2.0
/**
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import { z } from '@genkit-ai/core';
import { Registry } from '@genkit-ai/core/registry';
import { runInNewSpan } from '@genkit-ai/core/tracing';
import * as assert from 'assert';
import { generate } from '../generate';
import { ModelAction } from '../model';
import { defineTool } from '../tool';

/** A single conformance check run against one model; throws to signal failure. */
type TestCase = (ai: Registry, model: string) => Promise<void>;

/** Outcome of one test case against one model. */
type ModelReport = {
  name: string;
  passed: boolean;
  skipped?: boolean;
  error?: {
    message: string;
    stack?: string;
  };
};

/** All per-model outcomes for one test case. */
type CaseReport = {
  description: string;
  models: ModelReport[];
};

type TestReport = CaseReport[];

/** Thrown by `skip()` so the runner can tell "skipped" apart from "failed". */
class SkipTestError extends Error {}

/** Aborts the current test case; the runner records it as skipped for this model. */
function skip(): never {
  throw new SkipTestError();
}

/**
 * Resolves `model` (as `/model/<name>`) from the registry and reports whether
 * its metadata declares support for `feature`.
 *
 * @throws Error if the model is not registered, so the case is reported as a
 *   failure with a clear message instead of an incidental TypeError.
 */
async function modelSupports(
  registry: Registry,
  model: string,
  feature: 'media' | 'multiturn' | 'tools'
): Promise<boolean> {
  const resolvedModel = (await registry.lookupAction(`/model/${model}`)) as
    | ModelAction
    | undefined;
  if (!resolvedModel) {
    throw new Error(`Model '${model}' not found in registry`);
  }
  return !!resolvedModel.__action.metadata?.model?.supports?.[feature];
}

const tests: Record<string, TestCase> = {
  'basic hi': async (registry: Registry, model: string) => {
    const response = await generate(registry, {
      model,
      prompt: 'just say "Hi", literally',
    });
    assert.match(response.text.trim(), /Hi/i);
  },

  multimodal: async (registry: Registry, model: string) => {
    if (!(await modelSupports(registry, model, 'media'))) {
      skip();
    }
    const response = await generate(registry, {
      model,
      prompt: [
        {
          media: {
            // NOTE(review): this URL is empty as seen here, but the prompt and
            // assertion expect an image of a math operator ("plus"). Restore
            // the intended image/data URL, otherwise this case cannot pass.
            url: '',
          },
        },
        {
          text: 'what math operation is this? plus, minus, multiply or divide?',
        },
      ],
    });
    assert.match(response.text.trim(), /plus/i);
  },

  history: async (registry: Registry, model: string) => {
    if (!(await modelSupports(registry, model, 'multiturn'))) {
      skip();
    }
    const response1 = await generate(registry, {
      model,
      prompt: 'My name is Glorb',
    });
    // Feed the first exchange back in so the model must use conversation history.
    const response = await generate(registry, {
      model,
      prompt: "What's my name?",
      messages: response1.messages,
    });
    assert.match(response.text.trim(), /Glorb/);
  },

  'system prompt': async (registry: Registry, model: string) => {
    const { text } = await generate(registry, {
      model,
      prompt: 'Hi',
      messages: [
        {
          role: 'system',
          content: [
            {
              text: 'If the user says "Hi", just say "Bye" ',
            },
          ],
        },
      ],
    });
    assert.equal(text.trim(), 'Bye');
  },

  'structured output': async (registry: Registry, model: string) => {
    const response = await generate(registry, {
      model,
      prompt: 'extract data as json from: Jack was a Lumberjack',
      output: {
        format: 'json',
        schema: z.object({
          name: z.string(),
          occupation: z.string(),
        }),
      },
    });
    const want = {
      name: 'Jack',
      occupation: 'Lumberjack',
    };
    // Node's assert takes (actual, expected); order matters for the diff message.
    assert.deepEqual(response.output, want);
  },

  'tool calling': async (registry: Registry, model: string) => {
    if (!(await modelSupports(registry, model, 'tools'))) {
      skip();
    }
    const { text } = await generate(registry, {
      model,
      prompt: 'what is a gablorken of 2? use provided tool',
      tools: ['gablorkenTool'],
    });
    // gablorkenTool computes value^3 + 1.407, so 2 -> 9.407. The dot is
    // escaped so the regex requires a literal decimal point.
    assert.match(text.trim(), /9\.407/);
  },
};

/**
 * Runs every conformance case in `tests` against each model in `models`,
 * sequentially. Each case gets its own trace span nested under a
 * `testModels` span.
 *
 * Registers a `gablorkenTool` tool (value^3 + 1.407) on the registry first;
 * the 'tool calling' case depends on it.
 *
 * @param registry registry used to resolve models and register the helper tool.
 * @param models model names, resolved as `/model/<name>`.
 * @returns one entry per test case, each listing one result per model.
 *   Skipped cases have `passed: false` and `skipped: true`; failures carry
 *   `error` (message, plus stack when the thrown value was an `Error`).
 */
export async function testModels(
  registry: Registry,
  models: string[]
): Promise<TestReport> {
  defineTool(
    registry,
    {
      name: 'gablorkenTool',
      description: 'use when need to calculate a gablorken',
      inputSchema: z.object({
        value: z.number(),
      }),
      outputSchema: z.number(),
    },
    async (input) => {
      return Math.pow(input.value, 3) + 1.407;
    }
  );

  return await runInNewSpan(
    registry,
    { metadata: { name: 'testModels' } },
    async () => {
      const report: TestReport = [];
      for (const [description, testCase] of Object.entries(tests)) {
        await runInNewSpan(
          registry,
          { metadata: { name: description } },
          async () => {
            const caseReport: CaseReport = { description, models: [] };
            report.push(caseReport);
            for (const model of models) {
              // Optimistic default; flipped below if the case throws.
              const modelReport: ModelReport = { name: model, passed: true };
              caseReport.models.push(modelReport);
              try {
                await testCase(registry, model);
              } catch (e) {
                modelReport.passed = false;
                if (e instanceof SkipTestError) {
                  modelReport.skipped = true;
                } else if (e instanceof Error) {
                  modelReport.error = { message: e.message, stack: e.stack };
                } else {
                  modelReport.error = { message: `${e}` };
                }
              }
            }
          }
        );
      }
      return report;
    }
  );
}