action_test.ts•7.17 kB
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import * as assert from 'assert';
import { beforeEach, describe, it } from 'node:test';
import { z } from 'zod';
import { action, defineAction } from '../src/action.js';
import { initNodeFeatures } from '../src/node.js';
import { Registry } from '../src/registry.js';
initNodeFeatures();

/**
 * Tests for standalone actions (`action`) and registry-bound actions
 * (`defineAction`): middleware, telemetry, streaming, context propagation,
 * trace info and abort-signal pass-through.
 */
describe('action', () => {
  // Recreated before every test so actions registered by one test (several
  // of them reuse the name 'child') cannot leak into another.
  let registry: Registry;

  beforeEach(() => {
    registry = new Registry();
  });

  /**
   * Builds the standalone `foo` util action shared by the middleware tests.
   *
   * The action returns the length of its string input and is wrapped in two
   * middlewares: the first appends 'middle1' and adds 1 to the result, the
   * second appends 'middle2' and adds 2. Invoking it with 'foo' therefore
   * yields "foomiddle1middle2".length + 1 + 2 === 20.
   */
  function defineMiddlewareAction() {
    return action(
      {
        name: 'foo',
        inputSchema: z.string(),
        outputSchema: z.number(),
        use: [
          // Two-argument middleware form: (input, next).
          async (input, next) => (await next(input + 'middle1')) + 1,
          // Three-argument middleware form: (input, opts, next), forwarding opts.
          async (input, opts, next) =>
            (await next(input + 'middle2', opts)) + 2,
        ],
        actionType: 'util',
      },
      async (input) => input.length
    );
  }

  it('applies middleware', async () => {
    const act = defineMiddlewareAction();
    assert.strictEqual(
      await act('foo'),
      20 // "foomiddle1middle2".length + 1 + 2
    );
  });

  it('returns telemetry info', async () => {
    const act = defineMiddlewareAction();
    const result = await act.run('foo');
    assert.strictEqual(
      result.result,
      20 // "foomiddle1middle2".length + 1 + 2
    );
    // assert.ok rejects both null and undefined; a `!== null` check would
    // silently pass when telemetry is missing entirely.
    assert.ok(result.telemetry, 'expected telemetry on the run result');
    assert.ok(
      result.telemetry.traceId.length > 0,
      'expected a non-empty traceId'
    );
    assert.ok(
      result.telemetry.spanId.length > 0,
      'expected a non-empty spanId'
    );
  });

  it('run the action with options', async () => {
    let passedContext: unknown;
    const act = action(
      {
        name: 'foo',
        inputSchema: z.string(),
        outputSchema: z.number(),
        actionType: 'util',
      },
      async (input, { sendChunk, context }) => {
        passedContext = context;
        sendChunk(1);
        sendChunk(2);
        sendChunk(3);
        return input.length;
      }
    );
    const chunks: unknown[] = [];
    await act.run('1234', {
      context: { foo: 'bar' },
      onChunk: (c) => chunks.push(c),
    });
    assert.deepStrictEqual(passedContext, {
      foo: 'bar',
    });
    assert.deepStrictEqual(chunks, [1, 2, 3]);
  });

  it('runs the action with context plus registry global context', async () => {
    let passedContext: unknown;
    let calledWithStreamingRequestedValue: boolean | undefined;
    const act = action(
      {
        name: 'foo',
        inputSchema: z.string(),
        outputSchema: z.number(),
        actionType: 'util',
      },
      async (input, { sendChunk, context, streamingRequested }) => {
        calledWithStreamingRequestedValue = streamingRequested;
        passedContext = context;
        sendChunk(1);
        sendChunk(2);
        sendChunk(3);
        return input.length;
      }
    );
    registry.context = { bar: 'baz' };
    // Attach the registry so its `context` is merged into the per-call context.
    act.__registry = registry;

    // Non-streaming invocation via run().
    await act.run('1234', {
      context: { foo: 'bar' },
    });
    assert.strictEqual(calledWithStreamingRequestedValue, false);
    assert.deepStrictEqual(passedContext, {
      foo: 'bar',
      bar: 'baz', // these come from the global registry context
    });

    // Streaming invocation via stream() picks up the updated registry context.
    registry.context = { bar2: 'baz2' };
    const { output } = act.stream('1234', {
      context: { foo2: 'bar2' },
    });
    await output;
    assert.strictEqual(calledWithStreamingRequestedValue, true);
    assert.deepStrictEqual(passedContext, {
      foo2: 'bar2',
      bar2: 'baz2', // these come from the global registry context
    });
  });

  it('should stream the response', async () => {
    // Named `act` (not `action`) to avoid shadowing the imported `action`.
    const act = defineAction(
      registry,
      { name: 'hello', actionType: 'custom' },
      async (input, { sendChunk }) => {
        sendChunk({ count: 1 });
        sendChunk({ count: 2 });
        sendChunk({ count: 3 });
        return `hi ${input}`;
      }
    );
    const response = act.stream('Pavel');
    const gotChunks: unknown[] = [];
    for await (const chunk of response.stream) {
      gotChunks.push(chunk);
    }
    assert.equal(await response.output, 'hi Pavel');
    assert.deepStrictEqual(gotChunks, [
      { count: 1 },
      { count: 2 },
      { count: 3 },
    ]);
  });

  it('should inherit context from parent action invocation', async () => {
    const child = defineAction(
      registry,
      { name: 'child', actionType: 'custom' },
      async (_, { context }) => {
        return `hi ${context?.auth?.email}`;
      }
    );
    // The parent passes no context explicitly; the child must still see it.
    const parent = defineAction(
      registry,
      { name: 'parent', actionType: 'custom' },
      async () => {
        return child();
      }
    );
    const response = await parent(undefined, {
      context: { auth: { email: 'a@b.c' } },
    });
    assert.strictEqual(response, 'hi a@b.c');
  });

  it('should include trace info in the context', async () => {
    const act = defineAction(
      registry,
      { name: 'child', actionType: 'custom' },
      async (_, ctx) => {
        return `traceId=${!!ctx.trace.traceId} spanId=${!!ctx.trace.spanId}`;
      }
    );
    const response = await act(undefined);
    assert.strictEqual(response, 'traceId=true spanId=true');
  });

  it('passes through the abort signal', async () => {
    let gotAbortSignal: AbortSignal | undefined;
    const act = defineAction(
      registry,
      { name: 'child', actionType: 'custom' },
      async (_, ctx) => {
        gotAbortSignal = ctx.abortSignal;
        return `traceId=${!!ctx.trace.traceId} spanId=${!!ctx.trace.spanId}`;
      }
    );
    const signal = new AbortController().signal;
    await act(undefined, { abortSignal: signal });
    assert.strictEqual(gotAbortSignal, signal);
  });

  it('passes through the abort signal with middleware', async () => {
    let gotAbortSignal: AbortSignal | undefined;
    const act = defineAction(
      registry,
      {
        name: 'child',
        actionType: 'custom',
        // Two-argument middleware: it never touches opts, yet the signal must
        // still reach the action body.
        use: [async (input, next) => (await next(input + 'middle1')) + 1],
      },
      async (_, ctx) => {
        gotAbortSignal = ctx.abortSignal;
        return `traceId=${!!ctx.trace.traceId} spanId=${!!ctx.trace.spanId}`;
      }
    );
    const signal = new AbortController().signal;
    await act(undefined, { abortSignal: signal });
    assert.strictEqual(gotAbortSignal, signal);
  });
});