Skip to main content
Glama
mcts.ts5.37 kB
import { v4 as uuidv4 } from 'uuid'; import { AttackNode, ReasoningRequest, ReasoningResponse, CONFIG } from '../types.js'; import { BaseStrategy } from './base.js'; import { StateManager } from '../state.js'; interface MCTSNode extends AttackNode { visits: number; ucb1Score?: number; } export class MCTSStrategy extends BaseStrategy { private readonly explorationConstant = Math.sqrt(2); private readonly simulationDepth = 3; constructor(stateManager: StateManager) { super(stateManager); } public async processAttackStep(request: ReasoningRequest): Promise<ReasoningResponse> { // Create or get root node const rootNode: MCTSNode = { id: request.parentId || 'root', attackStep: request.attackStep, depth: request.attackStepNumber - 1, visits: 0, score: 0, isComplete: !request.nextAttackStepNeeded, }; await this.saveNode(rootNode); // Run MCTS iterations for (let i = 0; i < CONFIG.mctsIterations; i++) { const selectedNode = await this.select(rootNode); const expandedNode = await this.expand(selectedNode); const simulationScore = await this.simulate(expandedNode); await this.backpropagate(expandedNode, simulationScore); } // Get best child of root const bestChild = await this.getBestChild(rootNode); return { nodeId: bestChild.id, attackStep: bestChild.attackStep, score: bestChild.score || 0, strategyUsed: 'mcts', nextAttackStepNeeded: request.nextAttackStepNeeded, }; } private async select(node: MCTSNode): Promise<MCTSNode> { let current = node; while (Array.isArray(current.children) && current.children.length > 0) { current = await this.selectBestUCB1(current); } return current; } private async expand(node: MCTSNode): Promise<MCTSNode> { // Create a new attack step node as expansion const newNode: MCTSNode = { id: `${node.id}-${Date.now()}`, attackStep: `Simulated attack step at depth ${node.depth + 1}`, depth: (node.depth || 0) + 1, visits: 0, score: 0, isComplete: false }; // Score and save newNode.score = this.evaluateAttackStep(newNode, node); await this.saveNode(newNode); // Update parent-child relationship if (!node.children) node.children = []; node.children.push(newNode); newNode.parent = node; return newNode; } private async simulate(node: MCTSNode): Promise<number> { let current = node; let totalScore = current.score || 0; // Random playout for (let depth = 0; depth < this.simulationDepth; depth++) { const simulatedNode: MCTSNode = { id: `sim-${Date.now()}-${depth}`, attackStep: `Random attack simulation at depth ${depth + 1}`, depth: (current.depth || 0) + 1, visits: 1, score: 0, isComplete: depth === this.simulationDepth - 1 }; simulatedNode.score = this.evaluateAttackStep(simulatedNode, current); totalScore += simulatedNode.score || 0; current = simulatedNode; } return totalScore / (this.simulationDepth + 1); } private async backpropagate(node: MCTSNode, score: number) { let current: MCTSNode | undefined = node; while (current) { current.visits++; if (current.score !== undefined) { current.score = ((current.score * (current.visits - 1)) + score) / current.visits; } current = current.parent as MCTSNode; } } private async selectBestUCB1(node: MCTSNode): Promise<MCTSNode> { const children = (node.children || []).filter((c): c is MCTSNode => typeof (c as MCTSNode).visits === 'number'); const totalVisits = node.visits; for (const child of children) { const exploitation = (child.score || 0); const exploration = Math.sqrt(Math.log(totalVisits) / (child.visits || 1)); child.ucb1Score = exploitation + this.explorationConstant * exploration; } return children.reduce((a, b) => (a.ucb1Score || 0) > (b.ucb1Score || 0) ? a : b); } private async getBestChild(node: MCTSNode): Promise<MCTSNode> { const children = (node.children || []).filter((c): c is MCTSNode => typeof (c as MCTSNode).visits === 'number'); return children.reduce((a, b) => (a.visits > b.visits) ? a : b); } private calculatePathScore(path: AttackNode[]): number { if (path.length === 0) return 0; return path.reduce((sum, node) => sum + (node.score || 0), 0) / path.length; } public async getBestPath(): Promise<AttackNode[]> { const nodes = Array.from(this.nodes.values()); if (nodes.length === 0) return []; const completePaths = this.findCompletePaths(nodes); return completePaths.reduce((bestPath, currentPath) => this.calculatePathScore(currentPath) > this.calculatePathScore(bestPath) ? currentPath : bestPath ); } private findCompletePaths(nodes: AttackNode[]): AttackNode[][] { const endNodes = nodes.filter(n => n.isComplete); return endNodes.map(end => this.constructPath(end)); } private constructPath(endNode: AttackNode): AttackNode[] { const path: AttackNode[] = []; let current: AttackNode | undefined = endNode; while (current) { path.unshift(current); current = current.parent; } return path; } }

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/LT7T/SecMCP'

If you have feedback or need assistance with the MCP directory API, please join our Discord server