// action.ts • 15.8 kB
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import type { JSONSchema7 } from 'json-schema';
import type * as z from 'zod';
import { getAsyncContext } from './async-context.js';
import { lazy } from './async.js';
import { getContext, runWithContext, type ActionContext } from './context.js';
import type { ActionType, Registry } from './registry.js';
import { parseSchema } from './schema.js';
import {
SPAN_TYPE_ATTR,
runInNewSpan,
setCustomMetadataAttributes,
} from './tracing.js';
export { StatusCodes, StatusSchema, type Status } from './statusTypes.js';
export type { JSONSchema7 };
/**
 * Returns the signal of a freshly constructed `AbortController`. The
 * controller itself is not kept, so nothing in this module can abort the
 * returned signal.
 */
const makeNoopAbortSignal = () => {
  const controller = new AbortController();
  return controller.signal;
};
/**
* Action metadata. Attached to every action as its `__action` property.
*
* Type parameters: `I` is the input schema, `O` the output schema and `S` the
* schema of streamed chunks.
*/
export interface ActionMetadata<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
> {
/** Type of the action (see `ActionType` in the registry module). */
actionType?: ActionType;
/**
* Action name. `action()` stores either the plain string name or
* `pluginId/actionId` when the name was given as an object.
*/
name: string;
/** Human-readable description of the action. */
description?: string;
/** Zod schema for the action input. */
inputSchema?: I;
/** JSON Schema for the action input. */
inputJsonSchema?: JSONSchema7;
/** Zod schema for the action output. */
outputSchema?: O;
/** JSON Schema for the action output. */
outputJsonSchema?: JSONSchema7;
/** Zod schema for streamed chunks. */
streamSchema?: S;
/** Free-form additional metadata. */
metadata?: Record<string, any>;
}
/**
* Results of an action run. Includes telemetry.
*/
export interface ActionResult<O> {
/** Value produced by the action. */
result: O;
/**
* Identifiers of the span the action ran in; `action()` reads them from the
* span's `spanContext()`.
*/
telemetry: {
traceId: string;
spanId: string;
};
}
/**
* Options (side channel data) accepted when calling, running or streaming an
* action.
*/
export interface ActionRunOptions<S> {
/**
* Streaming callback (optional). `action()` forwards it to the implementation
* function as `sendChunk`.
*/
onChunk?: StreamingCallback<S>;
/**
* Additional runtime context data (ex. auth context data).
*/
context?: ActionContext;
/**
* Additional span attributes to apply to OT spans.
*/
telemetryLabels?: Record<string, string>;
/**
* Abort signal for the action request.
*/
abortSignal?: AbortSignal;
}
/**
* Second argument handed to an action's implementation function (`fn`).
*/
export interface ActionFnArg<S> {
/**
* Whether the caller of the action requested streaming. `action()` sets this
* to true only when the caller supplied an `onChunk` callback other than
* `sentinelNoopStreamingCallback`.
*/
streamingRequested: boolean;
/**
* Callback for emitting a streamed chunk. `action()` always supplies one: the
* caller's `onChunk`, or `sentinelNoopStreamingCallback` when none was given.
*/
sendChunk: StreamingCallback<S>;
/**
* Additional runtime context data (ex. auth context data).
*/
context?: ActionContext;
/**
* Trace context containing trace and span IDs.
*/
trace: {
traceId: string;
spanId: string;
};
/**
* Abort signal for the action request. When the caller passed none,
* `action()` supplies the signal of a fresh `AbortController`.
*/
abortSignal: AbortSignal;
/**
* Registry attached to the action (its `__registry` property), if any.
*/
registry?: Registry;
}
/**
* Streaming response from an action, as returned by an action's `stream`
* method.
*/
export interface StreamingResponse<
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
> {
/**
* Iterator over the streaming chunks. The generator built by `action()` also
* returns the final output once the chunk stream is exhausted.
*/
stream: AsyncGenerator<z.infer<S>>;
/** Final output of the action. */
output: Promise<z.infer<O>>;
}
/**
* Self-describing, validating, observable, locally and remotely callable function.
*
* An action is itself callable (resolving to the output only) and carries
* extra members: `run` resolves to the output plus telemetry, and `stream`
* exposes chunks alongside the final output.
*/
export type Action<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
RunOptions extends ActionRunOptions<S> = ActionRunOptions<S>,
> = ((input?: z.infer<I>, options?: RunOptions) => Promise<z.infer<O>>) & {
/** Metadata describing this action. */
__action: ActionMetadata<I, O, S>;
// NOTE(review): presumably assigned by the registry when the action is
// registered — confirm against the registry module.
__registry?: Registry;
/** Runs the action and resolves to its output together with trace/span IDs. */
run(
input?: z.infer<I>,
options?: ActionRunOptions<z.infer<S>>
): Promise<ActionResult<z.infer<O>>>;
/** Runs the action and exposes streamed chunks plus the final output. */
stream(
input?: z.infer<I>,
opts?: ActionRunOptions<z.infer<S>>
): StreamingResponse<O, S>;
};
/**
* Action factory params.
*/
export type ActionParams<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
> = {
/**
* Action name, either a plain string or a plugin/action pair. `action()`
* joins the pair as `pluginId/actionId`.
*/
name:
| string
| {
pluginId: string;
actionId: string;
};
/** Human-readable description of the action. */
description?: string;
/** Zod schema for the action input. */
inputSchema?: I;
/** JSON Schema for the action input. */
inputJsonSchema?: JSONSchema7;
/** Zod schema for the action output. */
outputSchema?: O;
/** JSON Schema for the action output. */
outputJsonSchema?: JSONSchema7;
/** Free-form additional metadata copied onto the action's `__action`. */
metadata?: Record<string, any>;
/**
* Middleware for the action. When set, `action()` returns the action wrapped
* by `actionWithMiddleware`.
*/
use?: Middleware<z.infer<I>, z.infer<O>, z.infer<S>>[];
/** Zod schema for streamed chunks. */
streamSchema?: S;
/** Type of the action; also used as the registration type by `defineAction`. */
actionType: ActionType;
};
/**
* Action factory params bundled with the implementation function. This is the
* shape `defineActionAsync` expects its config promise to resolve to.
*/
export type ActionAsyncParams<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
> = ActionParams<I, O, S> & {
/** Implementation of the action. */
fn: (
input: z.infer<I>,
options: ActionFnArg<z.infer<S>>
) => Promise<z.infer<O>>;
};
/**
* Middleware that sees only the request. `actionWithMiddleware` recognizes it
* by its declared arity of 2. Calling `next()` without an argument forwards
* the current request unchanged.
*/
export type SimpleMiddleware<I = any, O = any> = (
req: I,
next: (req?: I) => Promise<O>
) => Promise<O>;
/**
* Middleware that sees both the request and the run options.
* `actionWithMiddleware` recognizes it by its declared arity of 3. Calling
* `next()` without arguments forwards the current request and options.
*/
export type MiddlewareWithOptions<I = any, O = any, S = any> = (
req: I,
options: ActionRunOptions<S> | undefined,
next: (req?: I, options?: ActionRunOptions<S>) => Promise<O>
) => Promise<O>;
/**
* Middleware function for actions: either the two-parameter
* ({@link SimpleMiddleware}) or the three-parameter
* ({@link MiddlewareWithOptions}) form.
*/
export type Middleware<I = any, O = any, S = any> =
| SimpleMiddleware<I, O>
| MiddlewareWithOptions<I, O, S>;
/**
 * Creates an action with provided middleware.
 *
 * Every call to the returned action (directly or through `run`) passes through
 * the middleware chain in array order before reaching the wrapped action. Each
 * middleware is classified by its declared parameter count: three parameters
 * `(req, options, next)` or two parameters `(req, next)`.
 *
 * `stream` is taken from the wrapped action unchanged, so streaming calls do
 * not pass through the middleware.
 *
 * @param action the action to wrap.
 * @param middleware middleware functions, outermost first.
 * @returns a new action that shares the wrapped action's `__action` metadata.
 * @throws Error (as a rejected promise from `run`) when the chain reaches a
 *   middleware that declares neither 2 nor 3 parameters.
 */
export function actionWithMiddleware<
  I extends z.ZodTypeAny,
  O extends z.ZodTypeAny,
  S extends z.ZodTypeAny = z.ZodTypeAny,
>(
  action: Action<I, O, S>,
  middleware: Middleware<z.infer<I>, z.infer<O>, z.infer<S>>[]
): Action<I, O, S> {
  const wrapped = (async (
    req: z.infer<I>,
    options?: ActionRunOptions<z.infer<S>>
  ) => {
    return (await wrapped.run(req, options)).result;
  }) as Action<I, O, S>;
  wrapped.__action = action.__action;
  // Assign `stream` up front. It used to be assigned as a side effect of the
  // first `run` call, leaving `wrapped.stream` undefined until then.
  wrapped.stream = action.stream;
  wrapped.run = async (
    req: z.infer<I>,
    options?: ActionRunOptions<z.infer<S>>
  ): Promise<ActionResult<z.infer<O>>> => {
    // Captured from the wrapped action's result; stays undefined when a
    // middleware short-circuits without calling `next`.
    let telemetry;
    const dispatch = async (
      index: number,
      req: z.infer<I>,
      opts?: ActionRunOptions<z.infer<S>>
    ) => {
      if (index === middleware.length) {
        // end of the chain, call the wrapped action
        const result = await action.run(req, opts);
        telemetry = result.telemetry;
        return result.result;
      }
      const currentMiddleware = middleware[index];
      if (currentMiddleware.length === 3) {
        return (currentMiddleware as MiddlewareWithOptions<I, O, z.infer<S>>)(
          req,
          opts,
          // `??` rather than `||`: only a missing (null/undefined) argument
          // falls back to the current value, so falsy replacement requests
          // such as 0, '' or false are honoured.
          async (modifiedReq, modifiedOptions) =>
            dispatch(index + 1, modifiedReq ?? req, modifiedOptions ?? opts)
        );
      } else if (currentMiddleware.length === 2) {
        return (currentMiddleware as SimpleMiddleware<I, O>)(
          req,
          async (modifiedReq) => dispatch(index + 1, modifiedReq ?? req, opts)
        );
      } else {
        throw new Error('unspported middleware function shape');
      }
    };
    return { result: await dispatch(0, req, options), telemetry };
  };
  return wrapped;
}
/**
 * Creates an action with the provided config.
 *
 * The returned value is a callable that resolves to the action's output and
 * additionally exposes:
 *  - `run`: resolves to the output together with the trace and span IDs of
 *    the span the implementation ran in;
 *  - `stream`: starts the action and returns the streamed chunks as an async
 *    generator next to a promise of the final output.
 *
 * Input is passed through `parseSchema` before the span is opened and the
 * output is passed through it after the span completes, using the Zod/JSON
 * schemas from `config`.
 *
 * @param config name, schemas, metadata, action type and optional middleware
 *   (`use`).
 * @param fn implementation, invoked with the parsed input and an
 *   {@link ActionFnArg}.
 * @returns the action; wrapped by `actionWithMiddleware` when `config.use` is
 *   set.
 */
export function action<
  I extends z.ZodTypeAny,
  O extends z.ZodTypeAny,
  S extends z.ZodTypeAny = z.ZodTypeAny,
>(
  config: ActionParams<I, O, S>,
  fn: (
    input: z.infer<I>,
    options: ActionFnArg<z.infer<S>>
  ) => Promise<z.infer<O>>
): Action<I, O, z.infer<S>> {
  // Object-form names are flattened to `pluginId/actionId`.
  const actionName =
    typeof config.name === 'string'
      ? config.name
      : `${config.name.pluginId}/${config.name.actionId}`;
  const actionMetadata = {
    name: actionName,
    description: config.description,
    inputSchema: config.inputSchema,
    inputJsonSchema: config.inputJsonSchema,
    outputSchema: config.outputSchema,
    outputJsonSchema: config.outputJsonSchema,
    streamSchema: config.streamSchema,
    metadata: config.metadata,
    actionType: config.actionType,
  } as ActionMetadata<I, O, S>;
  // Calling the action directly is `run` with the telemetry stripped off.
  const actionFn = (async (
    input?: I,
    options?: ActionRunOptions<z.infer<S>>
  ) => {
    return (await actionFn.run(input, options)).result;
  }) as Action<I, O, z.infer<S>>;
  actionFn.__action = { ...actionMetadata };
  actionFn.run = async (
    input: z.infer<I>,
    options?: ActionRunOptions<z.infer<S>>
  ): Promise<ActionResult<z.infer<O>>> => {
    input = parseSchema(input, {
      schema: config.inputSchema,
      jsonSchema: config.inputJsonSchema,
    });
    // Filled in from the span once it has been opened.
    let traceId;
    let spanId;
    let output = await runInNewSpan(
      {
        metadata: {
          name: actionName,
        },
        labels: {
          [SPAN_TYPE_ATTR]: 'action',
          'genkit:metadata:subtype': config.actionType,
          ...options?.telemetryLabels,
        },
      },
      async (metadata, span) => {
        setCustomMetadataAttributes({
          subtype: config.actionType,
        });
        if (options?.context) {
          setCustomMetadataAttributes({
            context: JSON.stringify(options.context),
          });
        }
        traceId = span.spanContext().traceId;
        spanId = span.spanContext().spanId;
        metadata.name = actionName;
        metadata.input = input;
        try {
          const actFn = () =>
            fn(input, {
              ...options,
              // Context can either be explicitly set, or inherited from the parent action.
              context: {
                ...actionFn.__registry?.context,
                ...(options?.context ?? getContext()),
              },
              streamingRequested:
                !!options?.onChunk &&
                options.onChunk !== sentinelNoopStreamingCallback,
              sendChunk: options?.onChunk ?? sentinelNoopStreamingCallback,
              trace: {
                traceId,
                spanId,
              },
              registry: actionFn.__registry,
              abortSignal: options?.abortSignal ?? makeNoopAbortSignal(),
            });
          // if context is explicitly passed in, we run action with the provided context,
          // otherwise we let upstream context carry through.
          const output = await runWithContext(options?.context, actFn);
          metadata.output = JSON.stringify(output);
          return output;
        } catch (err) {
          // Tag object errors with the trace ID. `typeof null` is also
          // 'object', so null is excluded: assigning a property to it would
          // throw a TypeError and mask the value that was actually thrown.
          if (typeof err === 'object' && err !== null) {
            (err as any).traceId = traceId;
          }
          throw err;
        }
      }
    );
    output = parseSchema(output, {
      schema: config.outputSchema,
      jsonSchema: config.outputJsonSchema,
    });
    return {
      result: output,
      telemetry: {
        traceId,
        spanId,
      },
    };
  };
  actionFn.stream = (
    input?: z.infer<I>,
    opts?: ActionRunOptions<z.infer<S>>
  ): StreamingResponse<O, S> => {
    // Chunks are pushed into this stream from the `onChunk` callback below, so
    // `pull` and `cancel` have nothing to do.
    let chunkStreamController: ReadableStreamController<z.infer<S>>;
    const chunkStream = new ReadableStream<z.infer<S>>({
      start(controller) {
        chunkStreamController = controller;
      },
      pull() {},
      cancel() {},
    });
    const invocationPromise = actionFn
      .run(config.inputSchema ? config.inputSchema.parse(input) : input, {
        onChunk: ((chunk: z.infer<S>) => {
          chunkStreamController.enqueue(chunk);
        }) as S extends z.ZodVoid ? undefined : StreamingCallback<z.infer<S>>,
        context: {
          ...actionFn.__registry?.context,
          ...(opts?.context ?? getContext()),
        },
        abortSignal: opts?.abortSignal,
        telemetryLabels: opts?.telemetryLabels,
      })
      .then((s) => s.result)
      .finally(() => {
        // Close the chunk stream whether the run succeeded or failed so the
        // generator below terminates.
        chunkStreamController.close();
      });
    return {
      output: invocationPromise,
      stream: (async function* () {
        const reader = chunkStream.getReader();
        while (true) {
          const { done, value } = await reader.read();
          // Decide on `done`, not on the truthiness of the value: falsy chunks
          // (0, '', false) are legitimate and must not be dropped.
          if (done) {
            break;
          }
          yield value;
        }
        return await invocationPromise;
      })(),
    };
  };
  if (config.use) {
    return actionWithMiddleware(actionFn, config.use);
  }
  return actionFn;
}
/**
 * Type guard for actions: true for any function on which an `__action`
 * property is reachable.
 */
export function isAction(a: unknown): a is Action {
  if (typeof a !== 'function') {
    return false;
  }
  return '__action' in a;
}
/**
 * Builds an action from `config` and `fn` and registers it with `registry`
 * under `config.actionType`.
 *
 * Throws when called from inside an action's runtime context. Each invocation
 * of the resulting action first awaits `registry.initializeAllPlugins()` and
 * then executes `fn` inside the action runtime context.
 *
 * @param registry registry the action is registered with.
 * @param config action factory params.
 * @param fn implementation of the action.
 * @returns the registered action.
 */
export function defineAction<
  I extends z.ZodTypeAny,
  O extends z.ZodTypeAny,
  S extends z.ZodTypeAny = z.ZodTypeAny,
>(
  registry: Registry,
  config: ActionParams<I, O, S>,
  fn: (
    input: z.infer<I>,
    options: ActionFnArg<z.infer<S>>
  ) => Promise<z.infer<O>>
): Action<I, O, S> {
  const calledAtRuntime = isInRuntimeContext();
  if (calledAtRuntime) {
    const message =
      'Cannot define new actions at runtime.\n' +
      'See: https://github.com/firebase/genkit/blob/main/docs/errors/no_new_actions_at_runtime.md';
    throw new Error(message);
  }

  const definedAction = action(
    config,
    async (input: I, fnArg): Promise<z.infer<O>> => {
      await registry.initializeAllPlugins();
      const outcome = await runInActionRuntimeContext(() => fn(input, fnArg));
      return outcome;
    }
  );
  definedAction.__action.actionType = config.actionType;
  registry.registerAction(config.actionType, definedAction);
  return definedAction;
}
/**
 * Registers an action whose config is only available as a promise.
 *
 * The action is built through `lazy` once `config` resolves; the resulting
 * promise is registered with `registry` under `actionType` and the resolved
 * name, and is also returned to the caller.
 *
 * @param registry registry to register the pending action with.
 * @param actionType type stamped onto the action's metadata and used for
 *   registration.
 * @param name plain name, or plugin/action pair joined as
 *   `pluginId/actionId`.
 * @param config promise of the action params including the implementation.
 * @param onInit optional hook called with the action right after it is built.
 * @returns a promise-like resolving to the built action.
 */
export function defineActionAsync<
  I extends z.ZodTypeAny,
  O extends z.ZodTypeAny,
  S extends z.ZodTypeAny = z.ZodTypeAny,
>(
  registry: Registry,
  actionType: ActionType,
  name:
    | string
    | {
        pluginId: string;
        actionId: string;
      },
  config: PromiseLike<ActionAsyncParams<I, O, S>>,
  onInit?: (action: Action<I, O, S>) => void
): PromiseLike<Action<I, O, S>> {
  let actionName: string;
  if (typeof name === 'string') {
    actionName = name;
  } else {
    actionName = `${name.pluginId}/${name.actionId}`;
  }

  // Turns the resolved config into an action that initializes plugins and
  // runs the implementation inside the action runtime context.
  const buildAction = (resolvedConfig: ActionAsyncParams<I, O, S>) => {
    const built = action(
      resolvedConfig,
      async (input: I, fnArg): Promise<z.infer<O>> => {
        await registry.initializeAllPlugins();
        const outcome = await runInActionRuntimeContext(() =>
          resolvedConfig.fn(input, fnArg)
        );
        return outcome;
      }
    );
    built.__action.actionType = actionType;
    if (onInit) {
      onInit(built);
    }
    return built;
  };

  const actionPromise = lazy(() => config.then(buildAction));
  registry.registerActionAsync(actionType, actionName, actionPromise);
  return actionPromise;
}
// Streaming callback function: invoked with one chunk at a time.
export type StreamingCallback<T> = (chunk: T) => void;
// Key under which the streaming callback is stored in the async context.
const streamingAlsKey = 'core.action.streamingCallback';
// No-op callback used as a marker for "no streaming callback". It is compared
// by identity, so this exact function instance must be used for the marker.
export const sentinelNoopStreamingCallback = () => null;
/**
 * Runs `fn` with the given streaming callback stored in the async context,
 * where {@link getStreamingCallback} can pick it up. When no callback is
 * given, `sentinelNoopStreamingCallback` is stored instead.
 *
 * @param streamingCallback callback to expose to `fn`, if any.
 * @param fn function to execute.
 * @returns whatever `fn` returns.
 */
export function runWithStreamingCallback<S, O>(
  streamingCallback: StreamingCallback<S> | undefined,
  fn: () => O
): O {
  const callbackToStore = streamingCallback
    ? streamingCallback
    : sentinelNoopStreamingCallback;
  return getAsyncContext().run(streamingAlsKey, callbackToStore, fn);
}
/**
 * Looks up the {@link StreamingCallback} stored by
 * {@link runWithStreamingCallback}. The no-op sentinel is reported as
 * `undefined`.
 *
 * @hidden
 */
export function getStreamingCallback<S>(): StreamingCallback<S> | undefined {
  const stored =
    getAsyncContext().getStore<StreamingCallback<S>>(streamingAlsKey);
  return stored === sentinelNoopStreamingCallback ? undefined : stored;
}
// Async-context key holding the action runtime marker; this module stores
// either 'runtime' or 'outside' under it.
const runtimeContextAslKey = 'core.action.runtimeContext';
/**
 * Reports whether the current async context is marked as an action's runtime
 * context, i.e. the stored marker equals `'runtime'`.
 */
export function isInRuntimeContext() {
  const marker = getAsyncContext().getStore(runtimeContextAslKey);
  return marker === 'runtime';
}
/**
 * Runs `fn` with the async context marked as `'runtime'`, so that
 * `isInRuntimeContext()` is true while it executes.
 *
 * @returns whatever the async context's `run` returns for `fn`.
 */
export function runInActionRuntimeContext<R>(fn: () => R) {
  const asyncContext = getAsyncContext();
  return asyncContext.run(runtimeContextAslKey, 'runtime', fn);
}
/**
 * Runs `fn` with the async context marked as `'outside'`, so that
 * `isInRuntimeContext()` is false while it executes.
 *
 * @returns whatever the async context's `run` returns for `fn`.
 */
export function runOutsideActionRuntimeContext<R>(fn: () => R) {
  const asyncContext = getAsyncContext();
  return asyncContext.run(runtimeContextAslKey, 'outside', fn);
}