import { AsyncLocalStorage } from 'node:async_hooks';
import { Buffer } from 'node:buffer';
import { randomUUID } from 'node:crypto';
import { setInterval } from 'node:timers';
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
import { type CancellableTimeout, createUnrefTimeout } from './timer-utils.js';
export type TaskStatus =
| 'working'
| 'input_required'
| 'completed'
| 'failed'
| 'cancelled';
export interface TaskError {
code: number;
message: string;
data?: unknown;
}
export interface TaskState {
taskId: string;
ownerKey: string;
status: TaskStatus;
statusMessage?: string;
createdAt: string;
lastUpdatedAt: string;
ttl: number; // in ms
pollInterval: number; // in ms
result?: unknown;
error?: TaskError;
}
interface InternalTaskState extends TaskState {
_createdAtMs: number;
}
export interface CreateTaskOptions {
ttl?: number;
}
export interface CreateTaskResult {
task: {
taskId: string;
status: TaskStatus;
statusMessage?: string;
createdAt: string;
lastUpdatedAt: string;
ttl: number;
pollInterval: number;
};
_meta?: Record<string, unknown>;
}
const DEFAULT_TTL_MS = 60_000;
const DEFAULT_POLL_INTERVAL_MS = 1_000;
const DEFAULT_OWNER_KEY = 'default';
const DEFAULT_PAGE_SIZE = 50;
const CLEANUP_INTERVAL_MS = 60_000;
const MAX_CURSOR_LENGTH = 256;
function isTerminalStatus(status: TaskStatus): boolean {
return (
status === 'completed' || status === 'failed' || status === 'cancelled'
);
}
type RunInContext = (fn: () => void) => void;
const asyncLocalStorageSnapshot = (
AsyncLocalStorage as unknown as {
snapshot?: () => RunInContext;
}
).snapshot;
const snapshotRunInContext: () => RunInContext =
typeof asyncLocalStorageSnapshot === 'function'
? asyncLocalStorageSnapshot
: () => (fn) => {
fn();
};
class TaskManager {
private tasks = new Map<string, InternalTaskState>();
private waiters = new Map<string, Set<(task: TaskState) => void>>();
constructor() {
this.startCleanupLoop();
}
private startCleanupLoop(): void {
const interval = setInterval(() => {
const now = Date.now();
for (const [id, task] of this.tasks) {
if (now - task._createdAtMs > task.ttl) {
this.tasks.delete(id);
}
}
}, CLEANUP_INTERVAL_MS);
interval.unref();
}
createTask(
options?: CreateTaskOptions,
statusMessage = 'Task started',
ownerKey: string = DEFAULT_OWNER_KEY
): TaskState {
const now = new Date();
const createdAt = now.toISOString();
const task: InternalTaskState = {
taskId: randomUUID(),
ownerKey,
status: 'working',
statusMessage,
createdAt,
lastUpdatedAt: createdAt,
ttl: options?.ttl ?? DEFAULT_TTL_MS,
pollInterval: DEFAULT_POLL_INTERVAL_MS,
_createdAtMs: now.getTime(),
};
this.tasks.set(task.taskId, task);
return task;
}
getTask(taskId: string, ownerKey?: string): TaskState | undefined {
const task = this.tasks.get(taskId);
if (!task) return undefined;
if (ownerKey && task.ownerKey !== ownerKey) return undefined;
if (this.isExpired(task)) {
this.tasks.delete(taskId);
return undefined;
}
return task;
}
updateTask(
taskId: string,
updates: Partial<Omit<TaskState, 'taskId' | 'createdAt'>>
): void {
const task = this.tasks.get(taskId);
if (!task) return;
if (updates.status && task.status !== updates.status) {
if (isTerminalStatus(task.status)) return;
}
Object.assign(task, {
...updates,
lastUpdatedAt: new Date().toISOString(),
});
this.notifyWaiters(task);
}
cancelTask(taskId: string, ownerKey?: string): TaskState | undefined {
const task = this.getTask(taskId, ownerKey);
if (!task) return undefined;
if (isTerminalStatus(task.status)) {
throw new McpError(
ErrorCode.InvalidParams,
`Cannot cancel task: already in terminal status '${task.status}'`
);
}
this.updateTask(taskId, {
status: 'cancelled',
statusMessage: 'The task was cancelled by request.',
});
return this.tasks.get(taskId);
}
private collectPage(
ownerKey: string,
startIndex: number,
pageSize: number
): TaskState[] {
const page: TaskState[] = [];
let currentIndex = 0;
const now = Date.now();
for (const task of this.tasks.values()) {
if (task.ownerKey !== ownerKey) continue;
if (now - task._createdAtMs > task.ttl) {
this.tasks.delete(task.taskId);
continue;
}
if (currentIndex >= startIndex) {
page.push(task);
if (page.length > pageSize) {
break;
}
}
currentIndex++;
}
return page;
}
listTasks(options: { ownerKey: string; cursor?: string; limit?: number }): {
tasks: TaskState[];
nextCursor?: string;
} {
const { ownerKey, cursor, limit } = options;
const pageSize = limit && limit > 0 ? limit : DEFAULT_PAGE_SIZE;
const startIndex = cursor ? this.decodeCursor(cursor) : 0;
if (startIndex === null) {
throw new McpError(ErrorCode.InvalidParams, 'Invalid cursor');
}
const page = this.collectPage(ownerKey, startIndex, pageSize);
const hasMore = page.length > pageSize;
if (hasMore) {
page.pop();
}
const nextCursor = hasMore
? this.encodeCursor(startIndex + page.length)
: undefined;
return nextCursor ? { tasks: page, nextCursor } : { tasks: page };
}
async waitForTerminalTask(
taskId: string,
ownerKey: string,
signal?: AbortSignal
): Promise<TaskState | undefined> {
const task = this.tasks.get(taskId);
if (!task) return undefined;
if (ownerKey && task.ownerKey !== ownerKey) return undefined;
if (this.isExpired(task)) {
this.tasks.delete(taskId);
return undefined;
}
if (isTerminalStatus(task.status)) return task;
const deadlineMs = task._createdAtMs + task.ttl;
const now = Date.now();
if (deadlineMs <= now) {
this.tasks.delete(taskId);
return undefined;
}
return new Promise((resolve, reject) => {
const runInContext = snapshotRunInContext();
const resolveInContext = (value: TaskState | undefined): void => {
runInContext(() => {
resolve(value);
});
};
const rejectInContext = (error: unknown): void => {
runInContext(() => {
reject(error instanceof Error ? error : new Error(String(error)));
});
};
let settled = false;
let waiter: ((updated: TaskState) => void) | null = null;
let deadlineTimeout: CancellableTimeout<{ timeout: true }> | undefined;
const cleanup = (): void => {
if (deadlineTimeout) {
deadlineTimeout.cancel();
deadlineTimeout = undefined;
}
if (signal) {
signal.removeEventListener('abort', onAbort);
}
};
const removeWaiter = (): void => {
if (waiter) {
const set = this.waiters.get(taskId);
if (set) {
set.delete(waiter);
if (set.size === 0) this.waiters.delete(taskId);
}
}
};
const settleOnce = (fn: () => void): void => {
if (settled) return;
settled = true;
fn();
};
const onAbort = (): void => {
settleOnce(() => {
cleanup();
removeWaiter();
rejectInContext(
new McpError(ErrorCode.ConnectionClosed, 'Request was cancelled')
);
});
};
waiter = (updated: TaskState): void => {
settleOnce(() => {
cleanup();
if (updated.ownerKey !== ownerKey) {
resolveInContext(undefined);
return;
}
resolveInContext(updated);
});
};
if (signal?.aborted) {
onAbort();
return;
}
let set = this.waiters.get(taskId);
if (!set) {
set = new Set();
this.waiters.set(taskId, set);
}
set.add(waiter);
if (signal) {
signal.addEventListener('abort', onAbort, { once: true });
}
const timeoutMs = Math.max(0, deadlineMs - Date.now());
deadlineTimeout = createUnrefTimeout(timeoutMs, { timeout: true });
void deadlineTimeout.promise
.then(() => {
settleOnce(() => {
cleanup();
removeWaiter();
this.tasks.delete(taskId);
resolveInContext(undefined);
});
})
.catch(rejectInContext);
});
}
private notifyWaiters(task: TaskState): void {
if (!isTerminalStatus(task.status)) return;
const waiters = this.waiters.get(task.taskId);
if (!waiters) return;
this.waiters.delete(task.taskId);
for (const waiter of waiters) waiter(task);
}
private isExpired(task: InternalTaskState): boolean {
return Date.now() - task._createdAtMs > task.ttl;
}
private encodeCursor(index: number): string {
return Buffer.from(String(index), 'utf8').toString('base64url');
}
private decodeCursor(cursor: string): number | null {
try {
if (!isValidBase64UrlCursor(cursor)) return null;
const decoded = Buffer.from(cursor, 'base64url').toString('utf8');
if (!/^\d+$/u.test(decoded)) return null;
const value = Number.parseInt(decoded, 10);
if (!Number.isFinite(value) || value < 0) return null;
return value;
} catch {
return null;
}
}
}
function isValidBase64UrlCursor(cursor: string): boolean {
if (!cursor) return false;
if (cursor.length > MAX_CURSOR_LENGTH) return false;
if (!/^[A-Za-z0-9_-]+={0,2}$/u.test(cursor)) return false;
const firstPaddingIndex = cursor.indexOf('=');
if (firstPaddingIndex !== -1) {
for (let i = firstPaddingIndex; i < cursor.length; i += 1) {
if (cursor[i] !== '=') return false;
}
return cursor.length % 4 === 0;
}
return cursor.length % 4 !== 1;
}
export const taskManager = new TaskManager();