// tokenEstimationService.ts (9.7 kB)
import type { Prompt, PromptArgument, Resource, Tool } from '@modelcontextprotocol/sdk/types.js';
import logger from '@src/logger/logger.js';
import { encoding_for_model, type TiktokenModel } from 'tiktoken';
/**
 * Token breakdown by capability type for a single server.
 *
 * As populated by `TokenEstimationService.estimateServerTokens`, `totalTokens`
 * is the sum of every tool, resource and prompt estimate plus `serverOverhead`.
 */
export interface TokenBreakdown {
  /** One entry per tool, in the order the tools were supplied. */
  tools: ToolTokenInfo[];
  /** One entry per resource, in the order the resources were supplied. */
  resources: ResourceTokenInfo[];
  /** One entry per prompt, in the order the prompts were supplied. */
  prompts: PromptTokenInfo[];
  /** Fixed per-server token overhead added on top of the capability estimates. */
  serverOverhead: number;
  /** Grand total for the server, including `serverOverhead`. */
  totalTokens: number;
}
/**
 * Token information for a specific tool.
 */
export interface ToolTokenInfo {
  /** Tool name, copied from the tool definition. */
  name: string;
  /**
   * Estimated token count. `TokenEstimationService` derives it from the
   * JSON-serialized `{ name, description, inputSchema }` of the tool.
   */
  tokens: number;
  /** Tool description, when the definition provides one. */
  description?: string;
}
/**
 * Token information for a specific resource.
 */
export interface ResourceTokenInfo {
  /** Resource URI, copied from the resource definition. */
  uri: string;
  /** Resource name, copied from the resource definition. */
  name?: string;
  /**
   * Estimated token count. `TokenEstimationService` derives it from the
   * JSON-serialized `{ uri, name, description, mimeType }` metadata only,
   * not from the resource contents.
   */
  tokens: number;
  /** MIME type, when the resource declares one. */
  mimeType?: string;
}
/**
 * Token information for a specific prompt.
 */
export interface PromptTokenInfo {
  /** Prompt name, copied from the prompt definition. */
  name: string;
  /**
   * Estimated token count for the whole prompt. As computed by
   * `TokenEstimationService`, this covers the JSON-serialized
   * `{ name, description }` plus the JSON-serialized arguments, so it
   * normally already includes `argumentsTokens`.
   */
  tokens: number;
  /** Prompt description, when the definition provides one. */
  description?: string;
  /**
   * Estimated token count of the JSON-serialized arguments alone; 0 when the
   * prompt declares no arguments. Informational: do not add it to `tokens`.
   */
  argumentsTokens: number;
}
/**
 * Server token estimation summary.
 */
export interface ServerTokenEstimate {
  /** Name of the server the estimate belongs to. */
  serverName: string;
  /** Connection state as reported by the caller of `estimateServerTokens`. */
  connected: boolean;
  /** Per-capability token estimates and their total. */
  breakdown: TokenBreakdown;
  /**
   * Set when estimation for the server threw. In that case
   * `estimateServerTokens` returns a breakdown with empty capability lists
   * and only the base server overhead as the total.
   */
  error?: string;
}
/**
 * Minimal view of the tiktoken encoder that this service relies on.
 * Declaring only the members actually used keeps the `encoder` field typed
 * without resorting to `any`.
 */
interface TokenEncoder {
  encode(text: string): ArrayLike<number>;
  free(): void;
}

/**
 * Service for estimating MCP token usage using tiktoken.
 *
 * Each capability (tool, resource, prompt) is serialized to JSON and the JSON
 * is tokenized. When no encoder is available (initialization failed, or the
 * service has been disposed) the count is approximated from the character
 * length instead. Call {@link TokenEstimationService.dispose} when done to
 * release the encoder.
 */
export class TokenEstimationService {
  /** Active encoder, or `null` when initialization failed or after `dispose()`. */
  private encoder: TokenEncoder | null = null;
  /** Model whose encoding is in use; kept for logging purposes. */
  private model: TiktokenModel;
  private static readonly BASE_SERVER_OVERHEAD = 75; // Base overhead for server connection
  private static readonly FALLBACK_CHARS_PER_TOKEN = 3.5; // Character-based fallback estimation

  /**
   * @param model - tiktoken model name. If no encoder can be created for it,
   *   the service retries with `gpt-4o`; if that fails too, it keeps working
   *   with the character-based approximation.
   */
  constructor(model: string = 'gpt-4o') {
    // The cast performs no validation; an unknown name surfaces as a throw
    // from `encoding_for_model` below.
    this.model = model as TiktokenModel;
    try {
      this.encoder = encoding_for_model(this.model);
      logger.debug(`TokenEstimationService initialized with tiktoken ${this.model} encoding`);
    } catch (error) {
      logger.error(`Failed to initialize tiktoken encoder for model ${model}:`, error);
      logger.warn(`Falling back to gpt-4o encoding`);
      // Keep the default model name even if the fallback fails, for logging purposes.
      this.model = 'gpt-4o';
      try {
        this.encoder = encoding_for_model(this.model);
      } catch (fallbackError) {
        logger.error('Failed to initialize fallback encoder:', fallbackError);
        this.encoder = null;
      }
    }
  }

  /**
   * Count the tokens in `text` with the encoder, or approximate the count from
   * the character length when no encoder is available.
   */
  private countTokens(text: string): number {
    if (this.encoder) {
      return this.encoder.encode(text).length;
    }
    return Math.ceil(text.length / TokenEstimationService.FALLBACK_CHARS_PER_TOKEN);
  }

  /**
   * Estimate tokens for a single tool definition.
   *
   * @returns Token count of the JSON-serialized `{ name, description, inputSchema }`,
   *   or 150 if serialization/encoding throws.
   */
  private estimateToolTokens(tool: Tool): number {
    try {
      // Create a representative JSON structure for the tool
      const toolDefinition = {
        name: tool.name,
        description: tool.description,
        inputSchema: tool.inputSchema,
      };
      return this.countTokens(JSON.stringify(toolDefinition));
    } catch (error) {
      logger.warn(`Error estimating tokens for tool ${tool.name}:`, error);
      // Fallback estimation based on typical tool size
      return 150;
    }
  }

  /**
   * Estimate tokens for a single resource (metadata only, not its contents).
   *
   * @returns Token count of the JSON-serialized `{ uri, name, description, mimeType }`,
   *   or 50 if serialization/encoding throws.
   */
  private estimateResourceTokens(resource: Resource): number {
    try {
      // Create a representative structure for the resource metadata
      const resourceMeta = {
        uri: resource.uri,
        name: resource.name,
        description: resource.description,
        mimeType: resource.mimeType,
      };
      return this.countTokens(JSON.stringify(resourceMeta));
    } catch (error) {
      logger.warn(`Error estimating tokens for resource ${resource.uri}:`, error);
      // Fallback estimation based on typical resource size
      return 50;
    }
  }

  /**
   * Estimate tokens for a single prompt template.
   *
   * @returns Token count of the JSON-serialized `{ name, description }` plus,
   *   when present, the JSON-serialized arguments; 100 if anything throws.
   */
  private estimatePromptTokens(prompt: Prompt): number {
    try {
      // Count tokens for prompt name and description
      const promptMeta = {
        name: prompt.name,
        description: prompt.description,
      };
      const metaTokens = this.countTokens(JSON.stringify(promptMeta));
      // Count tokens for arguments
      const argumentTokens = prompt.arguments ? this.countTokens(JSON.stringify(prompt.arguments)) : 0;
      return metaTokens + argumentTokens;
    } catch (error) {
      logger.warn(`Error estimating tokens for prompt ${prompt.name}:`, error);
      // Fallback estimation based on typical prompt size
      return 100;
    }
  }

  /**
   * Estimate tokens for all capabilities of a server.
   *
   * @param serverName - Name reported back in the result and used in logs.
   * @param tools - Tool definitions exposed by the server.
   * @param resources - Resource definitions exposed by the server.
   * @param prompts - Prompt definitions exposed by the server.
   * @param connected - Connection state, echoed into the result unchanged.
   * @returns Per-capability estimates plus a total that includes the base
   *   server overhead. If estimation throws, the result carries `error` and a
   *   breakdown containing only the base overhead.
   */
  public estimateServerTokens(
    serverName: string,
    tools: Tool[] = [],
    resources: Resource[] = [],
    prompts: Prompt[] = [],
    connected: boolean = true,
  ): ServerTokenEstimate {
    try {
      logger.debug(`Estimating tokens for server: ${serverName}`);
      // Calculate token breakdown by capability type
      const toolTokens: ToolTokenInfo[] = tools.map((tool) => ({
        name: tool.name,
        description: tool.description,
        tokens: this.estimateToolTokens(tool),
      }));
      const resourceTokens: ResourceTokenInfo[] = resources.map((resource) => ({
        uri: resource.uri,
        name: resource.name,
        mimeType: resource.mimeType,
        tokens: this.estimateResourceTokens(resource),
      }));
      // `tokens` already includes the arguments; `argumentsTokens` is reported
      // separately for information and is NOT added to the totals below.
      const promptTokens: PromptTokenInfo[] = prompts.map((prompt) => ({
        name: prompt.name,
        description: prompt.description,
        tokens: this.estimatePromptTokens(prompt),
        argumentsTokens: prompt.arguments ? this.estimateArgumentsTokens(prompt.arguments) : 0,
      }));
      // Calculate totals
      const totalToolTokens = toolTokens.reduce((sum, tool) => sum + tool.tokens, 0);
      const totalResourceTokens = resourceTokens.reduce((sum, resource) => sum + resource.tokens, 0);
      const totalPromptTokens = promptTokens.reduce((sum, prompt) => sum + prompt.tokens, 0);
      const serverOverhead = TokenEstimationService.BASE_SERVER_OVERHEAD;
      const breakdown: TokenBreakdown = {
        tools: toolTokens,
        resources: resourceTokens,
        prompts: promptTokens,
        serverOverhead,
        totalTokens: totalToolTokens + totalResourceTokens + totalPromptTokens + serverOverhead,
      };
      return {
        serverName,
        connected,
        breakdown,
      };
    } catch (error) {
      logger.error(`Error estimating tokens for server ${serverName}:`, error);
      return {
        serverName,
        connected,
        breakdown: {
          tools: [],
          resources: [],
          prompts: [],
          serverOverhead: TokenEstimationService.BASE_SERVER_OVERHEAD,
          totalTokens: TokenEstimationService.BASE_SERVER_OVERHEAD,
        },
        error: error instanceof Error ? error.message : 'Unknown error',
      };
    }
  }

  /**
   * Estimate tokens for prompt arguments.
   *
   * @returns Token count of the JSON-serialized argument list, or 25 if
   *   serialization/encoding throws.
   */
  private estimateArgumentsTokens(arguments_: PromptArgument[]): number {
    try {
      return this.countTokens(JSON.stringify(arguments_));
    } catch (error) {
      logger.warn('Error estimating tokens for prompt arguments:', error);
      return 25; // Conservative fallback
    }
  }

  /**
   * Calculate aggregate statistics across multiple servers.
   *
   * Only estimates that are connected and carry no `error` contribute to the
   * capability counts, `overallTokens` and `serverBreakdown`; `totalServers`
   * counts every estimate passed in. If two estimates share a server name,
   * the later one wins in `serverBreakdown`.
   */
  public calculateAggregateStats(estimates: ServerTokenEstimate[]): {
    totalServers: number;
    connectedServers: number;
    totalTools: number;
    totalResources: number;
    totalPrompts: number;
    overallTokens: number;
    serverBreakdown: { [serverName: string]: number };
  } {
    const connectedEstimates = estimates.filter((est) => est.connected && !est.error);
    return {
      totalServers: estimates.length,
      connectedServers: connectedEstimates.length,
      totalTools: connectedEstimates.reduce((sum, est) => sum + est.breakdown.tools.length, 0),
      totalResources: connectedEstimates.reduce((sum, est) => sum + est.breakdown.resources.length, 0),
      totalPrompts: connectedEstimates.reduce((sum, est) => sum + est.breakdown.prompts.length, 0),
      overallTokens: connectedEstimates.reduce((sum, est) => sum + est.breakdown.totalTokens, 0),
      serverBreakdown: Object.fromEntries(connectedEstimates.map((est) => [est.serverName, est.breakdown.totalTokens])),
    };
  }

  /**
   * Clean up resources when done.
   *
   * Safe to call more than once. After disposal the service stays usable and
   * falls back to the character-based approximation.
   */
  public dispose(): void {
    const encoder = this.encoder;
    // Drop the reference before freeing so that neither a later estimate nor a
    // repeated dispose() can touch an encoder that has already been released.
    this.encoder = null;
    if (!encoder || typeof encoder.free !== 'function') {
      return;
    }
    try {
      encoder.free();
      logger.debug('TokenEstimationService encoder disposed');
    } catch (error) {
      logger.warn('Error disposing tiktoken encoder:', error);
    }
  }
}