/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { GenkitError, StatusName } from '@genkit-ai/core';
import { HasRegistry } from '@genkit-ai/core/registry';
import { Document } from '../document.js';
import { injectInstructions } from '../formats/index.js';
import { ModelArgument } from '../index.js';
import type {
MediaPart,
MessageData,
ModelInfo,
ModelMiddleware,
Part,
} from '../model.js';
import { resolveModel } from '../model.js';
/**
* Preprocess a GenerateRequest to download referenced http(s) media URLs and
* inline them as data URIs.
*/
export function downloadRequestMedia(options?: {
maxBytes?: number;
filter?: (part: MediaPart) => boolean;
}): ModelMiddleware {
return async (req, next) => {
const { default: fetch } = await import('node-fetch');
const newReq = {
...req,
messages: await Promise.all(
req.messages.map(async (message) => {
const content: Part[] = await Promise.all(
message.content.map(async (part) => {
// skip non-media parts and non-http urls, or parts that have been
// filtered out by user config
if (
!part.media ||
!part.media.url.startsWith('http') ||
(options?.filter && !options?.filter(part))
) {
return part;
}
const response = await fetch(part.media.url, {
size: options?.maxBytes,
});
if (response.status !== 200)
throw new Error(
`HTTP error downloading media '${
part.media.url
}': ${await response.text()}`
);
// use provided contentType or sniff from response
const contentType =
part.media.contentType ||
response.headers.get('content-type') ||
'';
return {
media: {
contentType,
url: `data:${contentType};base64,${Buffer.from(
await response.arrayBuffer()
).toString('base64')}`,
},
};
})
);
return {
...message,
content,
};
})
),
};
return next(newReq);
};
}
/**
* Validates that a GenerateRequest does not include unsupported features.
*/
export function validateSupport(options: {
name: string;
supports?: ModelInfo['supports'];
}): ModelMiddleware {
const supports = options.supports || {};
return async (req, next) => {
function invalid(message: string): never {
throw new Error(
`Model '${
options.name
}' does not support ${message}. Request: ${JSON.stringify(
req,
null,
2
)}`
);
}
if (
supports.media === false &&
req.messages.some((message) => message.content.some((part) => part.media))
)
invalid('media, but media was provided');
if (supports.tools === false && req.tools?.length)
invalid('tool use, but tools were provided');
if (supports.multiturn === false && req.messages.length > 1)
invalid(`multiple messages, but ${req.messages.length} were provided`);
// if (
// typeof supports.output !== 'undefined' &&
// req.output?.format &&
// !supports.output.includes(req.output?.format)
// )
// invalid(`requested output format '${req.output?.format}'`);
return next();
};
}
// N.B. Figure out why array.findLast isn't available despite setting target
// to ES2022 (Node 16.14.0)
function lastUserMessage(messages: MessageData[]) {
for (let i = messages.length - 1; i >= 0; i--) {
if (messages[i].role === 'user') {
return messages[i];
}
}
return undefined;
}
/**
* Provide a simulated system prompt for models that don't support it natively.
*/
export function simulateSystemPrompt(options?: {
preface: string;
acknowledgement: string;
}): ModelMiddleware {
const preface = options?.preface || 'SYSTEM INSTRUCTIONS:\n';
const acknowledgement = options?.acknowledgement || 'Understood.';
return (req, next) => {
const messages = [...req.messages];
for (let i = 0; i < messages.length; i++) {
if (req.messages[i].role === 'system') {
const systemPrompt = messages[i].content;
messages.splice(
i,
1,
{ role: 'user', content: [{ text: preface }, ...systemPrompt] },
{ role: 'model', content: [{ text: acknowledgement }] }
);
break;
}
}
return next({ ...req, messages });
};
}
export interface AugmentWithContextOptions {
/** Preceding text to place before the rendered context documents. */
preface?: string | null;
/** A function to render a document into a text part to be included in the message. */
itemTemplate?: (d: Document, options?: AugmentWithContextOptions) => string;
/** The metadata key to use for citation reference. Pass `null` to provide no citations. */
citationKey?: string | null;
}
export const CONTEXT_PREFACE =
'\n\nUse the following information to complete your task:\n\n';
const CONTEXT_ITEM_TEMPLATE = (
d: Document,
index: number,
options?: AugmentWithContextOptions
) => {
let out = '- ';
if (options?.citationKey) {
out += `[${d.metadata![options.citationKey]}]: `;
} else if (options?.citationKey === undefined) {
out += `[${d.metadata?.['ref'] || d.metadata?.['id'] || index}]: `;
}
out += d.text + '\n';
return out;
};
export function augmentWithContext(
options?: AugmentWithContextOptions
): ModelMiddleware {
const preface =
typeof options?.preface === 'undefined' ? CONTEXT_PREFACE : options.preface;
const itemTemplate = options?.itemTemplate || CONTEXT_ITEM_TEMPLATE;
return (req, next) => {
// if there is no context in the request, no-op
if (!req.docs?.length) return next(req);
const userMessage = lastUserMessage(req.messages);
// if there are no messages, no-op
if (!userMessage) return next(req);
// if there is already a context part, no-op
const contextPartIndex = userMessage?.content.findIndex(
(p) => p.metadata?.purpose === 'context'
);
const contextPart =
contextPartIndex >= 0 && userMessage.content[contextPartIndex];
if (contextPart && !contextPart.metadata?.pending) {
return next(req);
}
let out = `${preface || ''}`;
req.docs?.forEach((d, i) => {
out += itemTemplate(new Document(d), i, options);
});
out += '\n';
if (contextPartIndex >= 0) {
userMessage.content[contextPartIndex] = {
...contextPart,
text: out,
metadata: { purpose: 'context' },
} as Part;
} else {
userMessage.content.push({ text: out, metadata: { purpose: 'context' } });
}
return next(req);
};
}
/**
* Options for the `retry` middleware.
*/
export interface RetryOptions {
/**
* The maximum number of times to retry a failed request.
* @default 3
*/
maxRetries?: number;
/**
* An array of `StatusName` values that should trigger a retry.
* @default ['UNAVAILABLE', 'DEADLINE_EXCEEDED', 'RESOURCE_EXHAUSTED', 'ABORTED', 'INTERNAL']
*/
statuses?: StatusName[];
/**
* The initial delay between retries, in milliseconds.
* @default 1000
*/
initialDelayMs?: number;
/**
* The maximum delay between retries, in milliseconds.
* @default 60000
*/
maxDelayMs?: number;
/**
* The factor by which the delay increases after each retry (exponential backoff).
* @default 2
*/
backoffFactor?: number;
/**
* Whether to disable jitter on the delay. Jitter adds a random factor to the
* delay to help prevent a "thundering herd" of clients all retrying at the
* same time.
* @default false
*/
noJitter?: boolean;
/**
* A callback to be executed on each retry attempt.
*/
onError?: (error: Error, attempt: number) => void;
}
let __setTimeout: (
callback: (...args: any[]) => void,
ms?: number
) => NodeJS.Timeout = setTimeout;
/**
* FOR TESTING ONLY.
* @internal
*/
export const TEST_ONLY = {
setRetryTimeout(
impl: (callback: (...args: any[]) => void, ms?: number) => NodeJS.Timeout
) {
__setTimeout = impl;
},
};
const DEFAULT_RETRY_STATUSES: StatusName[] = [
'UNAVAILABLE',
'DEADLINE_EXCEEDED',
'RESOURCE_EXHAUSTED',
'ABORTED',
'INTERNAL',
];
const DEFAULT_FALLBACK_STATUSES: StatusName[] = [
'UNAVAILABLE',
'DEADLINE_EXCEEDED',
'RESOURCE_EXHAUSTED',
'ABORTED',
'INTERNAL',
'NOT_FOUND',
'UNIMPLEMENTED',
];
/**
* Creates a middleware that retries requests on specific error statuses.
*
* ```ts
* const { text } = await ai.generate({
* model: googleAI.model('gemini-2.5-pro'),
* prompt: 'You are a helpful AI assistant named Walt, say hello',
* use: [
* retry({
* maxRetries: 2,
* initialDelayMs: 1000,
* backoffFactor: 2,
* }),
* ],
* });
* ```
*/
export function retry(options: RetryOptions = {}): ModelMiddleware {
const {
maxRetries = 3,
statuses = DEFAULT_RETRY_STATUSES,
initialDelayMs = 1000,
maxDelayMs = 60000,
backoffFactor = 2,
noJitter = false,
onError,
} = options;
return async (req, next) => {
let lastError: any;
let currentDelay = initialDelayMs;
for (let i = 0; i <= maxRetries; i++) {
try {
return await next(req);
} catch (e) {
lastError = e;
const error = e as Error;
if (i < maxRetries) {
let shouldRetry = false;
if (error instanceof GenkitError) {
if (statuses.includes(error.status)) {
shouldRetry = true;
}
} else {
shouldRetry = true;
}
if (shouldRetry) {
onError?.(error, i + 1);
let delay = currentDelay;
if (!noJitter) {
delay = delay + 1000 * Math.pow(2, i) * Math.random();
}
await new Promise((resolve) => __setTimeout(resolve, delay));
currentDelay = Math.min(currentDelay * backoffFactor, maxDelayMs);
continue;
}
}
throw error;
}
}
throw lastError;
};
}
/**
* Options for the `fallback` middleware.
*/
export interface FallbackOptions {
/**
* An array of models to try in order.
*/
models: ModelArgument[];
/**
* An array of `StatusName` values that should trigger a fallback.
* @default ['UNAVAILABLE', 'DEADLINE_EXCEEDED', 'RESOURCE_EXHAUSTED', 'ABORTED', 'INTERNAL', 'NOT_FOUND', 'UNIMPLEMENTED']
*/
statuses?: StatusName[];
/**
* A callback to be executed on each fallback attempt.
*/
onError?: (error: Error) => void;
}
/**
* Creates a middleware that falls back to a different model on specific error statuses.
*
* ```ts
* const { text } = await ai.generate({
* model: googleAI.model('gemini-2.5-pro'),
* prompt: 'You are a helpful AI assistant named Walt, say hello',
* use: [
* fallback(ai, {
* models: [googleAI.model('gemini-2.5-flash')],
* statuses: ['RESOURCE_EXHAUSTED'],
* }),
* ],
* });
* ```
*/
export function fallback(
ai: HasRegistry,
options: FallbackOptions
): ModelMiddleware {
const { models, statuses = DEFAULT_FALLBACK_STATUSES, onError } = options;
return async (req, next) => {
try {
return await next(req);
} catch (e) {
if (e instanceof GenkitError && statuses.includes(e.status)) {
onError?.(e);
let lastError: any = e;
for (const model of models) {
try {
const resolved = await resolveModel(ai.registry, model);
return await resolved.modelAction(req);
} catch (e2) {
lastError = e2;
if (e2 instanceof GenkitError && statuses.includes(e2.status)) {
onError?.(e2);
continue;
}
throw e2;
}
}
throw lastError;
}
throw e;
}
};
}
export interface SimulatedConstrainedGenerationOptions {
instructionsRenderer?: (schema: Record<string, any>) => string;
}
const DEFAULT_CONSTRAINED_GENERATION_INSTRUCTIONS = (
schema: Record<string, any>
) => `Output should be in JSON format and conform to the following schema:
\`\`\`
${JSON.stringify(schema)}
\`\`\`
`;
/**
* Model middleware that simulates constrained generation by injecting generation
* instructions into the user message.
*/
export function simulateConstrainedGeneration(
options?: SimulatedConstrainedGenerationOptions
): ModelMiddleware {
return (req, next) => {
let instructions: string | undefined;
if (req.output?.constrained && req.output?.schema) {
instructions = (
options?.instructionsRenderer ??
DEFAULT_CONSTRAINED_GENERATION_INSTRUCTIONS
)(req.output?.schema);
req = {
...req,
messages: injectInstructions(req.messages, instructions),
output: {
...req.output,
// we're simulating it, so to the underlying model it's unconstrained.
constrained: false,
format: undefined,
contentType: undefined,
schema: undefined,
},
};
}
return next(req);
};
}