// model.ts • 15.6 kB
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {
ActionFnArg,
BackgroundAction,
GenkitError,
Operation,
OperationSchema,
action,
backgroundAction,
defineAction,
registerBackgroundAction,
z,
type Action,
type ActionMetadata,
type ActionParams,
type SimpleMiddleware,
type StreamingCallback,
} from '@genkit-ai/core';
import { logger } from '@genkit-ai/core/logging';
import type { Registry } from '@genkit-ai/core/registry';
import { toJsonSchema } from '@genkit-ai/core/schema';
import { performance } from 'node:perf_hooks';
import {
CustomPartSchema,
DataPartSchema,
MediaPartSchema,
TextPartSchema,
ToolRequestPartSchema,
ToolResponsePartSchema,
type CustomPart,
type DataPart,
type MediaPart,
type TextPart,
type ToolRequestPart,
type ToolResponsePart,
} from './document.js';
import {
CandidateData,
GenerateRequest,
GenerateRequestSchema,
GenerateResponseChunkData,
GenerateResponseChunkSchema,
GenerateResponseData,
GenerateResponseSchema,
GenerationUsage,
MessageData,
ModelInfo,
Part,
} from './model-types.js';
import {
augmentWithContext,
simulateConstrainedGeneration,
} from './model/middleware.js';
export { defineGenerateAction } from './generate/action.js';
export * from './model-types.js';
export {
CustomPartSchema,
DataPartSchema,
MediaPartSchema,
TextPartSchema,
ToolRequestPartSchema,
ToolResponsePartSchema,
simulateConstrainedGeneration,
type CustomPart,
type DataPart,
type MediaPart,
type TextPart,
type ToolRequestPart,
type ToolResponsePart,
};
/**
 * Model action produced by `model()` / `defineModel()`: an `Action` that takes
 * a generate request, returns a generate response, and streams
 * generate-response chunks.
 */
export type ModelAction<
CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
> = Action<
typeof GenerateRequestSchema,
typeof GenerateResponseSchema,
typeof GenerateResponseChunkSchema
> & {
/**
 * The model's custom config schema. The factories in this module attach the
 * declared `configSchema`, or `z.unknown()` when none was declared.
 */
__configSchema: CustomOptionsSchema;
};
/**
 * Background (long-running) model action produced by `backgroundModel()` /
 * `defineBackgroundModel()`: a `BackgroundAction` from a generate request to a
 * generate response.
 */
export type BackgroundModelAction<
CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
> = BackgroundAction<
typeof GenerateRequestSchema,
typeof GenerateResponseSchema
> & {
/** The model's custom config schema (`z.unknown()` when none was declared). */
__configSchema: CustomOptionsSchema;
};
/**
 * Middleware wrapped around a model action; it operates on generate request
 * data and generate response data.
 */
export type ModelMiddleware = SimpleMiddleware<
z.infer<typeof GenerateRequestSchema>,
z.infer<typeof GenerateResponseSchema>
>;
/**
 * Options accepted by `model()` and `defineModel()`; the background-model
 * factories accept these too, via `DefineBackgroundModelOptions`.
 */
export type DefineModelOptions<
CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
> = {
/** Action name of the model; also used as the label when `label` is unset. */
name: string;
/** Known version names for this model, e.g. `gemini-1.0-pro-001`. */
versions?: string[];
/**
 * Capabilities this model supports. When `context` is not supported,
 * context-augmentation middleware is added; `constrained` decides whether
 * constrained generation is simulated by middleware.
 */
supports?: ModelInfo['supports'];
/** Custom options schema for this model (published as JSON schema in metadata). */
configSchema?: CustomOptionsSchema;
/** Descriptive name for this model e.g. 'Google AI - Gemini Pro'. */
label?: string;
/** Middleware to be used with this model; placed ahead of the built-in middleware. */
use?: ModelMiddleware[];
};
/**
 * Creates a model action. Unlike `defineModel`, this takes no registry.
 *
 * The runner receives the request and the action context. The response it
 * resolves with is copied and given a `latencyMs` field, measured from just
 * before the runner is invoked until its promise resolves.
 *
 * @param options Model name, capabilities, config schema, label, middleware.
 * @param runner Implementation producing the `GenerateResponseData`.
 * @returns The model action, tagged with `__configSchema`.
 */
export function model<CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny>(
  options: DefineModelOptions<CustomOptionsSchema>,
  runner: (
    request: GenerateRequest<CustomOptionsSchema>,
    options: ActionFnArg<GenerateResponseChunkData>
  ) => Promise<GenerateResponseData>
): ModelAction<CustomOptionsSchema> {
  const modelAction = action(modelActionOptions(options), (input, ctx) => {
    const began = performance.now();
    const withLatency = (response: GenerateResponseData) => ({
      ...response,
      latencyMs: performance.now() - began,
    });
    return runner(input, ctx).then(withLatency);
  });
  // Fall back to an "accept anything" schema when no config schema is given.
  const configSchema = options.configSchema || z.unknown();
  Object.assign(modelAction, { __configSchema: configSchema });
  return modelAction as ModelAction<CustomOptionsSchema>;
}
/**
 * Builds the `ActionParams` shared by `model()` and `defineModel()`.
 *
 * The params contain the generate request/response schemas, the model
 * middleware chain, and a `metadata.model` block with the label (falling back
 * to `name`), the custom-options JSON schema, versions, and supports.
 */
function modelActionOptions<
  CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
>(
  options: DefineModelOptions<CustomOptionsSchema>
): ActionParams<typeof GenerateRequestSchema, typeof GenerateResponseSchema> {
  const label = options.label || options.name;
  const use = getModelMiddleware(options);
  const { configSchema } = options;
  const customOptions = configSchema
    ? toJsonSchema({ schema: configSchema })
    : undefined;
  return {
    actionType: 'model',
    name: options.name,
    description: label,
    inputSchema: GenerateRequestSchema,
    outputSchema: GenerateResponseSchema,
    metadata: {
      model: {
        label,
        customOptions,
        versions: options.versions,
        supports: options.supports,
      },
    },
    use,
  };
}
/**
 * Defines a new model and adds it to the registry.
 *
 * `apiVersion: 'v2'` form: the runner's second argument is the full action
 * context (`ActionFnArg`), which exposes e.g. `sendChunk` for streaming.
 */
export function defineModel<
  CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
>(
  registry: Registry,
  options: {
    apiVersion: 'v2';
  } & DefineModelOptions<CustomOptionsSchema>,
  runner: (
    request: GenerateRequest<CustomOptionsSchema>,
    options: ActionFnArg<GenerateResponseChunkData>
  ) => Promise<GenerateResponseData>
): ModelAction<CustomOptionsSchema>;
/**
 * Defines a new model and adds it to the registry.
 *
 * Legacy form: the runner's second argument is a streaming callback. It is
 * supplied only when the caller requested streaming; otherwise it is
 * `undefined`.
 */
export function defineModel<
  CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
>(
  registry: Registry,
  options: DefineModelOptions<CustomOptionsSchema>,
  runner: (
    request: GenerateRequest<CustomOptionsSchema>,
    streamingCallback?: StreamingCallback<GenerateResponseChunkData>
  ) => Promise<GenerateResponseData>
): ModelAction<CustomOptionsSchema>;
export function defineModel<
  CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
>(
  registry: Registry,
  options: any,
  runner: (
    request: GenerateRequest<CustomOptionsSchema>,
    options: any
  ) => Promise<GenerateResponseData>
): ModelAction<CustomOptionsSchema> {
  const modelAction = defineAction(
    registry,
    modelActionOptions(options),
    (input, ctx) => {
      const began = performance.now();
      // v2 runners get the whole context. Legacy runners get only the chunk
      // callback, and only when the caller asked for streaming.
      let runnerArg: unknown;
      if (options.apiVersion === 'v2') {
        runnerArg = ctx;
      } else if (ctx.streamingRequested) {
        runnerArg = ctx.sendChunk;
      } else {
        runnerArg = undefined;
      }
      return runner(input, runnerArg).then((response) => ({
        ...response,
        latencyMs: performance.now() - began,
      }));
    }
  );
  Object.assign(modelAction, {
    __configSchema: options.configSchema || z.unknown(),
  });
  return modelAction as ModelAction<CustomOptionsSchema>;
}
/**
 * Options for `backgroundModel()` / `defineBackgroundModel()`: the regular
 * model options plus the operation lifecycle callbacks.
 */
export type DefineBackgroundModelOptions<
CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
> = DefineModelOptions<CustomOptionsSchema> & {
/**
 * Begins generation for `request` and returns the resulting operation.
 * `backgroundModel()` attaches a `latencyMs` field to the returned object.
 */
start: (
request: GenerateRequest<CustomOptionsSchema>
) => Promise<Operation<GenerateResponseData>>;
/**
 * Receives a previously returned operation and returns its current state.
 * NOTE(review): how status is refreshed is up to the implementation.
 */
check: (
operation: Operation<GenerateResponseData>
) => Promise<Operation<GenerateResponseData>>;
/** Optional cancellation hook; when omitted the action gets no `cancel` handler. */
cancel?: (
operation: Operation<GenerateResponseData>
) => Promise<Operation<GenerateResponseData>>;
};
/**
 * Defines a new model that runs in the background and registers it with
 * `registry`.
 *
 * @param registry Registry the background model action is registered in.
 * @param options Model options plus the `start`/`check`/`cancel` callbacks.
 * @returns The registered background model action.
 */
export function defineBackgroundModel<
  CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
>(
  registry: Registry,
  options: DefineBackgroundModelOptions<CustomOptionsSchema>
): BackgroundModelAction<CustomOptionsSchema> {
  const bgModel = backgroundModel(options);
  registerBackgroundAction(registry, bgModel);
  return bgModel;
}
/**
 * Creates a model that runs in the background, without registering it.
 * `defineBackgroundModel` is the registering variant.
 *
 * `start`, `check` and (optionally) `cancel` delegate to the supplied options.
 * The operation returned by `start` gets a `latencyMs` field measuring that
 * call. Metadata and middleware are built the same way as for regular models,
 * but the action type is `'background-model'`.
 *
 * @param options Model options plus the operation lifecycle callbacks.
 * @returns The background model action, tagged with `__configSchema`.
 */
export function backgroundModel<
  CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
>(
  options: DefineBackgroundModelOptions<CustomOptionsSchema>
): BackgroundModelAction<CustomOptionsSchema> {
  const label = options.label || options.name;
  const use = getModelMiddleware(options);
  const bgModel = backgroundAction({
    actionType: 'background-model',
    name: options.name,
    description: label,
    inputSchema: GenerateRequestSchema,
    outputSchema: GenerateResponseSchema,
    metadata: {
      model: {
        label,
        customOptions: options.configSchema
          ? toJsonSchema({ schema: options.configSchema })
          : undefined,
        versions: options.versions,
        supports: options.supports,
      },
    },
    use,
    start: async (request) => {
      const began = performance.now();
      const operation = await options.start(request);
      // Annotates and returns the very same operation object.
      return Object.assign(operation, {
        latencyMs: performance.now() - began,
      });
    },
    check: async (op) => options.check(op),
    cancel: options.cancel
      ? async (op) => {
          // `options.cancel` is re-read at call time, so guard again.
          if (!options.cancel) {
            throw new GenkitError({
              status: 'UNIMPLEMENTED',
              message: 'cancel not implemented',
            });
          }
          return options.cancel(op);
        }
      : undefined,
  }) as BackgroundModelAction<CustomOptionsSchema>;
  Object.assign(bgModel, {
    __configSchema: options.configSchema || z.unknown(),
  });
  return bgModel;
}
/**
 * Builds the middleware chain for a model action.
 *
 * The chain contains, in order:
 * 1. the caller-supplied `use` middleware, in their original order;
 * 2. `augmentWithContext()`, unless the model declares `supports.context`;
 * 3. a middleware that routes a request through
 *    `simulateConstrainedGeneration()` when the model cannot natively
 *    constrain output for that request.
 *
 * @param options Model options; only `use` and `supports` are read.
 * @returns A fresh array. `options.use` is never mutated.
 */
function getModelMiddleware(options: {
  use?: ModelMiddleware[];
  name: string;
  supports?: ModelInfo['supports'];
}): ModelMiddleware[] {
  // Copy instead of aliasing `options.use`. Pushing into the caller's array
  // would mutate it, so a `use` array shared by several model definitions (or
  // reused options) would accumulate duplicate context/constrained middleware.
  const middleware: ModelMiddleware[] = [...(options.use ?? [])];
  if (!options.supports?.context) {
    middleware.push(augmentWithContext());
  }
  const constrainedSimulator = simulateConstrainedGeneration();
  middleware.push((req, next) => {
    // `supports` is read per request, as before.
    const constrained = options.supports?.constrained;
    const mustSimulate =
      !constrained ||
      constrained === 'none' ||
      // 'no-tools' models only constrain natively when no tools are present.
      (constrained === 'no-tools' && (req.tools?.length ?? 0) > 0);
    return mustSimulate ? constrainedSimulator(req, next) : next(req);
  });
  return middleware;
}
/**
 * A reference to a model by name (see `modelRef()`), optionally carrying the
 * config schema, descriptive info, version, and config to use with it.
 */
export interface ModelReference<CustomOptions extends z.ZodTypeAny> {
/** Model name; `modelRef()` prefixes it with `<namespace>/` when a namespace is given. */
name: string;
/** Schema for the model's custom config. */
configSchema?: CustomOptions;
/** Descriptive model info. */
info?: ModelInfo;
/** Model version to use, if any. */
version?: string;
/** Config to use with the model; `resolveModel()` shallow-copies it into its result. */
config?: z.infer<CustomOptions>;
/** Returns a new reference whose `config` is replaced by `cfg`. */
withConfig(cfg: z.infer<CustomOptions>): ModelReference<CustomOptions>;
/** Returns a new reference with `version` set. */
withVersion(version: string): ModelReference<CustomOptions>;
}
/**
 * Packages model information into an `ActionMetadata` object.
 *
 * Background models report `OperationSchema` as their output JSON schema;
 * regular models report `GenerateResponseSchema`. `info` is copied into
 * `metadata.model`, and `customOptions` is set from `configSchema`, overriding
 * any `customOptions` present in `info`.
 */
export function modelActionMetadata({
  name,
  info,
  configSchema,
  background,
}: {
  name: string;
  info?: ModelInfo;
  configSchema?: z.ZodTypeAny;
  background?: boolean;
}): ActionMetadata {
  const actionType = background ? 'background-model' : 'model';
  const outputSchema = background ? OperationSchema : GenerateResponseSchema;
  return {
    actionType,
    name,
    inputJsonSchema: toJsonSchema({ schema: GenerateRequestSchema }),
    outputJsonSchema: toJsonSchema({ schema: outputSchema }),
    metadata: {
      model: {
        ...info,
        customOptions: configSchema
          ? toJsonSchema({ schema: configSchema })
          : undefined,
      },
    },
  } as ActionMetadata;
}
/**
 * Creates a model reference.
 *
 * When `namespace` is given and `name` does not already start with
 * `<namespace>/`, the name is prefixed with it. `withConfig` and `withVersion`
 * return new references built from the same options and the resolved name,
 * with only `config` or `version` replaced.
 *
 * @param options Reference fields plus an optional `namespace` prefix.
 * @returns The model reference.
 */
export function modelRef<
  CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
>(
  options: Omit<
    ModelReference<CustomOptionsSchema>,
    'withConfig' | 'withVersion'
  > & {
    namespace?: string;
  }
): ModelReference<CustomOptionsSchema> {
  const { namespace } = options;
  const needsPrefix =
    !!namespace && !options.name.startsWith(`${namespace}/`);
  const name = needsPrefix ? `${namespace}/${options.name}` : options.name;

  // Derived references start from the original options, so every other field
  // (config, version, info, ...) carries over unchanged.
  const derive = (
    overrides: Pick<ModelReference<CustomOptionsSchema>, 'config' | 'version'>
  ): ModelReference<CustomOptionsSchema> =>
    modelRef<CustomOptionsSchema>({ ...options, name, ...overrides });

  return {
    ...options,
    name,
    withConfig: (cfg: z.infer<CustomOptionsSchema>) => derive({ config: cfg }),
    withVersion: (version: string) => derive({ version }),
  };
}
/**
 * Totals accumulated over a list of {Part}s: text characters plus the number
 * of image, video, and audio media parts.
 */
type PartCounts = {
characters: number;
images: number;
videos: number;
audio: number;
};
/**
 * Calculates basic usage statistics from the given model inputs and outputs.
 *
 * Counts text characters and image/video/audio media parts across all input
 * messages and across the output. The output may be a single message or a
 * list of candidates, whose message contents are combined.
 */
export function getBasicUsageStats(
  input: MessageData[],
  response: MessageData | CandidateData[]
): GenerationUsage {
  const inputParts = input.flatMap((message) => message.content);
  const {
    characters: inputCharacters,
    images: inputImages,
    videos: inputVideos,
    audio: inputAudioFiles,
  } = getPartCounts(inputParts);

  const outputParts = Array.isArray(response)
    ? response.flatMap((candidate) => candidate.message.content)
    : response.content;
  const {
    characters: outputCharacters,
    images: outputImages,
    videos: outputVideos,
    audio: outputAudioFiles,
  } = getPartCounts(outputParts);

  return {
    inputCharacters,
    inputImages,
    inputVideos,
    inputAudioFiles,
    outputCharacters,
    outputImages,
    outputVideos,
    outputAudioFiles,
  };
}
/**
 * Tallies text characters and image/video/audio media parts in `parts`.
 *
 * A media part is classified by its `contentType` prefix, or by a
 * `data:<kind>` URL prefix when the content type does not match.
 */
function getPartCounts(parts: Part[]): PartCounts {
  const mediaIs = (part: Part, kind: string) =>
    part.media?.contentType?.startsWith(kind) ||
    part.media?.url?.startsWith(`data:${kind}`);

  const totals: PartCounts = { characters: 0, images: 0, videos: 0, audio: 0 };
  parts.forEach((part) => {
    totals.characters += part.text?.length || 0;
    if (mediaIs(part, 'image')) totals.images += 1;
    if (mediaIs(part, 'video')) totals.videos += 1;
    if (mediaIs(part, 'audio')) totals.audio += 1;
  });
  return totals;
}
/**
 * Anything that can identify a model: a model action, a model reference, or a
 * model name string.
 */
export type ModelArgument<CustomOptions extends z.ZodTypeAny = z.ZodTypeAny> =
| ModelAction<CustomOptions>
| ModelReference<CustomOptions>
| string;
/** Result of `resolveModel()`. */
export interface ResolvedModel<
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
> {
/** The resolved model action. */
modelAction: ModelAction;
/** Config taken from a model reference (set only when resolving a reference). */
config?: z.infer<CustomOptions>;
/** Version taken from a model reference (set only when resolving a reference). */
version?: string;
}
/**
 * Resolves a model argument to a concrete model action.
 *
 * - `undefined`: falls back to the registry's `defaultModel` value.
 * - string: looked up by name (regular models first, then background models).
 * - model action (has an own `__action` property): used as-is.
 * - model reference: looked up by `ref.name`. Its `config` (shallow-copied)
 *   and `version` are carried into the result.
 *
 * @param registry Registry to look models up in.
 * @param model Model to resolve; may be omitted when a default model is set.
 * @param options `warnDeprecated` logs a warning for models whose metadata
 *   has `model.stage === 'deprecated'`.
 * @throws GenkitError `INVALID_ARGUMENT` when no model is given and no default
 *   exists; `NOT_FOUND` when the lookup yields no action.
 */
export async function resolveModel<C extends z.ZodTypeAny = z.ZodTypeAny>(
  registry: Registry,
  model: ModelArgument<C> | undefined,
  options?: { warnDeprecated?: boolean }
): Promise<ResolvedModel<C>> {
  let target = model;
  if (!target) {
    target = await registry.lookupValue('defaultModel', 'defaultModel');
  }
  if (!target) {
    throw new GenkitError({
      status: 'INVALID_ARGUMENT',
      message: 'Must supply a `model` to `generate()` calls.',
    });
  }

  let requestedId: string;
  let resolved: ResolvedModel<C>;
  if (typeof target === 'string') {
    requestedId = target;
    resolved = { modelAction: await lookupModel(registry, target) };
  } else if (target.hasOwnProperty('__action')) {
    const given = target as ModelAction;
    requestedId = given.__action.name;
    resolved = { modelAction: given };
  } else {
    const ref = target as ModelReference<any>;
    requestedId = ref.name;
    resolved = {
      modelAction: await lookupModel(registry, ref.name),
      config: {
        ...ref.config,
      },
      version: ref.version,
    };
  }

  const { modelAction } = resolved;
  if (!modelAction) {
    throw new GenkitError({
      status: 'NOT_FOUND',
      message: `Model '${requestedId}' not found`,
    });
  }
  if (
    options?.warnDeprecated &&
    modelAction.__action.metadata?.model?.stage === 'deprecated'
  ) {
    logger.warn(
      `Model '${modelAction.__action.name}' is deprecated and may be removed in a future release.`
    );
  }
  return resolved;
}
/**
 * Looks up a model action by name. It tries the regular-model key
 * (`/model/<name>`) first, then the background-model key
 * (`/background-model/<name>`).
 *
 * NOTE(review): despite the `Promise<ModelAction>` signature, the result is
 * presumably falsy when neither key is registered. `resolveModel` checks for
 * that case.
 */
async function lookupModel(
  registry: Registry,
  model: string
): Promise<ModelAction> {
  const modelKey = `/model/${model}`;
  const backgroundModelKey = `/background-model/${model}`;
  return (
    (await registry.lookupAction(modelKey)) ||
    (await registry.lookupAction(backgroundModelKey))
  );
}