/**
* @file Unified Tool Registry
* @description Central registry for tools that works for both MCP and API
*/
import {
UnifiedToolDefinition,
IToolRegistry,
ToolContext,
ToolResult,
ToolCategory,
ToolMiddleware,
} from './base.js';
import { logger } from '../../logging/logger.js';
import { metrics } from '../../logging/metrics.js';
interface RegistryEntry {
tool: UnifiedToolDefinition;
registeredAt: Date;
callCount: number;
errorCount: number;
totalDuration: number;
lastCalled?: Date;
}
/**
* Unified Tool Registry implementation
*/
export class UnifiedToolRegistry implements IToolRegistry {
private tools: Map<string, RegistryEntry> = new Map();
private middlewares: ToolMiddleware[] = [];
private categories: Map<ToolCategory, Set<string>> = new Map();
constructor() {
// Initialize category sets
const categories: ToolCategory[] = [
'ai', 'network', 'ops', 'security', 'database',
'file', 'code', 'chat', 'search', 'system'
];
categories.forEach(cat => this.categories.set(cat, new Set()));
}
/**
* Register a tool
*/
register<TInput, TOutput>(tool: UnifiedToolDefinition<TInput, TOutput>): void {
if (this.tools.has(tool.name)) {
logger.warn('Tool already registered, replacing', { toolName: tool.name });
}
const entry: RegistryEntry = {
tool: tool as UnifiedToolDefinition,
registeredAt: new Date(),
callCount: 0,
errorCount: 0,
totalDuration: 0,
};
this.tools.set(tool.name, entry);
this.categories.get(tool.category)?.add(tool.name);
logger.info('Tool registered', {
toolName: tool.name,
category: tool.category,
hasAuth: tool.metadata?.requiresAuth,
});
}
/**
* Unregister a tool
*/
unregister(name: string): boolean {
const entry = this.tools.get(name);
if (!entry) return false;
this.categories.get(entry.tool.category)?.delete(name);
this.tools.delete(name);
logger.info('Tool unregistered', { toolName: name });
return true;
}
/**
* Get a tool by name
*/
get(name: string): UnifiedToolDefinition | undefined {
return this.tools.get(name)?.tool;
}
/**
* List all tools with optional filtering
*/
list(filter?: { category?: ToolCategory; tags?: string[] }): UnifiedToolDefinition[] {
let tools = Array.from(this.tools.values()).map(e => e.tool);
if (filter?.category) {
tools = tools.filter(t => t.category === filter.category);
}
if (filter?.tags && filter.tags.length > 0) {
tools = tools.filter(t => {
const toolTags = t.metadata?.tags || [];
return filter.tags!.some(tag => toolTags.includes(tag));
});
}
return tools;
}
/**
* Execute a tool with full middleware support
*/
async execute<TInput, TOutput>(
name: string,
input: TInput,
context: ToolContext
): Promise<ToolResult<TOutput>> {
const entry = this.tools.get(name);
if (!entry) {
return {
success: false,
error: {
code: 'TOOL_NOT_FOUND',
message: `Tool not found: ${name}`,
},
};
}
const { tool } = entry;
const startTime = Date.now();
try {
// Validate input
const validatedInput = tool.inputSchema.parse(input);
// Run before middlewares
for (const mw of this.middlewares) {
if (mw.before) {
await mw.before(tool, validatedInput, context);
}
}
// Execute tool
logger.info('Executing tool', {
toolName: name,
executionId: context.executionId,
userId: context.userId,
});
const result = await tool.handler(validatedInput, context);
let duration = Date.now() - startTime;
if (duration <= 0) duration = 1; // ensure measurable duration for tests
// Update stats
entry.callCount++;
entry.totalDuration += duration;
entry.lastCalled = new Date();
// Add duration to result metadata
if (!result.metadata) result.metadata = {};
result.metadata.duration = duration;
// Validate output if schema provided
if (tool.outputSchema && result.success && result.data) {
try {
result.data = tool.outputSchema.parse(result.data);
} catch (validationError) {
logger.warn('Output validation failed', {
toolName: name,
error: validationError,
});
}
}
// Run after middlewares
for (const mw of this.middlewares) {
if (mw.after) {
await mw.after(tool, validatedInput, result, context);
}
}
// Record metrics
metrics.recordToolExecution(name, duration, result.success);
logger.info('Tool execution completed', {
toolName: name,
executionId: context.executionId,
success: result.success,
duration,
});
return result;
} catch (error) {
let duration = Date.now() - startTime;
if (duration <= 0) duration = 1; // ensure measurable duration for tests
entry.errorCount++;
// Run error middlewares
for (const mw of this.middlewares) {
if (mw.onError) {
await mw.onError(tool, input, error as Error, context);
}
}
logger.error('Tool execution failed', {
toolName: name,
executionId: context.executionId,
error: error instanceof Error ? error.message : String(error),
duration,
});
metrics.recordToolExecution(name, duration, false);
return {
success: false,
error: {
code: error instanceof Error && error.name ? error.name : 'EXECUTION_ERROR',
message: error instanceof Error ? error.message : String(error),
details: error instanceof Error ? error.stack : undefined,
},
metadata: { duration },
};
}
}
/**
* Get tool statistics
*/
getStats(name: string) {
const entry = this.tools.get(name);
if (!entry) return undefined;
return {
callCount: entry.callCount,
errorCount: entry.errorCount,
avgDuration: entry.callCount > 0 ? entry.totalDuration / entry.callCount : 0,
lastCalled: entry.lastCalled,
};
}
/**
* Add middleware
*/
use(middleware: ToolMiddleware): void {
this.middlewares.push(middleware);
logger.info('Middleware added to tool registry');
}
/**
* Get all statistics
*/
getAllStats() {
const stats: Record<string, any> = {};
for (const [name, entry] of this.tools.entries()) {
stats[name] = {
callCount: entry.callCount,
errorCount: entry.errorCount,
avgDuration: entry.callCount > 0 ? entry.totalDuration / entry.callCount : 0,
lastCalled: entry.lastCalled,
successRate: entry.callCount > 0
? ((entry.callCount - entry.errorCount) / entry.callCount * 100).toFixed(2) + '%'
: 'N/A',
};
}
return stats;
}
/**
* Clear all tools (useful for testing)
*/
clear(): void {
this.tools.clear();
this.categories.forEach(set => set.clear());
logger.info('Tool registry cleared');
}
}
/**
* Global unified tool registry instance
*/
export const unifiedRegistry = new UnifiedToolRegistry();