Skip to main content
Glama
mcts.js23.9 kB
/** * Monte Carlo Tree Search (MCTS) Pattern Handler Implementation * * Combines tree exploration with random sampling for decision-making * under uncertainty using selection, expansion, simulation, and backpropagation. */ import { v4 as uuidv4 } from 'uuid'; export class MCTSHandler { defaultConfig = { ucbVariant: 'ucb1', useRAVE: false, raveBias: 0.5, progressiveWidening: { enabled: false, alpha: 0.5, beta: 0.0 }, virtualLoss: 0, reuseTree: false, usePriors: false, useTranspositions: false }; defaultTerminationCriteria = { maxSimulations: 1000, timeLimit: 60000, confidenceThreshold: 0.95, minVisitsThreshold: 10 }; /** * Initialize a new MCTS session */ initializeSession(explorationConstant = 1.414, config) { const sessionId = uuidv4(); const rootNodeId = uuidv4(); const timestamp = new Date().toISOString(); // Create root node const rootNode = { id: rootNodeId, content: 'Root: Initial State', timestamp, childrenIds: [], visits: 0, totalValue: 0, averageValue: 0, untriedActions: this.generateInitialActions(), isTerminal: false, metadata: { depth: 0 } }; const nodes = new Map(); nodes.set(rootNodeId, rootNode); return { sessionId, patternType: 'mcts', iteration: 0, nextStepNeeded: true, createdAt: timestamp, updatedAt: timestamp, rootNodeId, nodes, explorationConstant, simulationDepth: 10, totalSimulations: 0, rolloutPolicy: 'random', terminationCriteria: { ...this.defaultTerminationCriteria }, config: { ...this.defaultConfig, ...config }, stats: { totalNodes: 1, maxDepth: 0, averageBranchingFactor: 0, simulationsPerSecond: 0, bestActionVisits: 0, bestActionValue: 0, explorationRatio: 0, effectiveBranchingFactor: 0 } }; } /** * Generate initial actions for root node */ generateInitialActions() { // In practice, these would be domain-specific return ['action1', 'action2', 'action3', 'action4', 'action5']; } /** * Select a leaf node using UCB */ selectLeaf(session) { let currentNode = session.nodes.get(session.rootNodeId); while (!this.isLeaf(currentNode) && !currentNode.isTerminal) { currentNode = this.selectBestChild(currentNode, session); } return currentNode; } /** * Check if node is a leaf (has untried actions or no children) */ isLeaf(node) { return (node.untriedActions && node.untriedActions.length > 0) || node.childrenIds.length === 0; } /** * Select best child using UCB formula */ selectBestChild(node, session) { let bestChild = null; let bestUCB = -Infinity; node.childrenIds.forEach(childId => { const child = session.nodes.get(childId); if (child) { const ucb = this.calculateUCB(child.id, node.visits, session); if (ucb > bestUCB) { bestUCB = ucb; bestChild = child; } } }); return bestChild || session.nodes.get(node.childrenIds[0]); } /** * Calculate UCB score for a node */ calculateUCB(nodeId, parentVisits, session) { const node = session.nodes.get(nodeId); if (!node) return 0; const components = { exploitationTerm: 0, explorationTerm: 0, ucbScore: 0 }; // Handle unvisited nodes if (node.visits === 0) { components.ucbScore = Infinity; return components.ucbScore; } // Calculate components based on variant switch (session.config.ucbVariant) { case 'ucb1': components.exploitationTerm = node.averageValue; components.explorationTerm = session.explorationConstant * Math.sqrt(Math.log(parentVisits) / node.visits); break; case 'ucb1-tuned': components.exploitationTerm = node.averageValue; const variance = this.calculateVariance(node); const V = variance + Math.sqrt(2 * Math.log(parentVisits) / node.visits); components.explorationTerm = Math.sqrt(Math.log(parentVisits) / node.visits * Math.min(0.25, V)); break; case 'puct': // PUCT formula (used in AlphaGo) components.exploitationTerm = node.averageValue; const prior = node.metadata?.priorProbability || 1 / (node.parentId ? session.nodes.get(node.parentId)?.childrenIds.length || 1 : 1); components.priorBias = prior; components.explorationTerm = session.explorationConstant * prior * Math.sqrt(parentVisits) / (1 + node.visits); break; default: components.exploitationTerm = node.averageValue; components.explorationTerm = session.explorationConstant * Math.sqrt(Math.log(parentVisits) / node.visits); } // Add RAVE term if enabled if (session.config.useRAVE && node.metadata?.action) { const raveStats = this.getRAVEStatistics(node.metadata.action, session); if (raveStats.visits > 0) { const beta = Math.sqrt(session.config.raveBias / (3 * node.visits + session.config.raveBias)); components.raveTerm = beta * raveStats.averageValue; components.exploitationTerm = (1 - beta) * components.exploitationTerm + beta * raveStats.averageValue; } } components.ucbScore = components.exploitationTerm + components.explorationTerm + (components.priorBias || 0); // Store UCB score in node node.ucbScore = components.ucbScore; return components.ucbScore; } /** * Calculate variance for UCB1-Tuned */ calculateVariance(node) { if (node.visits <= 1) return 0; // Simplified variance calculation // In practice, would track individual simulation values const mean = node.averageValue; return mean * (1 - mean); // Bernoulli variance approximation } /** * Get RAVE statistics for an action */ getRAVEStatistics(action, session) { // Simplified RAVE - in practice would track action values across all simulations let totalVisits = 0; let totalValue = 0; session.nodes.forEach(node => { if (node.metadata?.action === action) { totalVisits += node.visits; totalValue += node.totalValue; } }); return { visits: totalVisits, averageValue: totalVisits > 0 ? totalValue / totalVisits : 0 }; } /** * Expand a node by adding a child */ expandNode(nodeId, session) { const node = session.nodes.get(nodeId); if (!node || !node.untriedActions || node.untriedActions.length === 0) { throw new Error('Cannot expand node: no untried actions'); } // Select random untried action const actionIndex = Math.floor(Math.random() * node.untriedActions.length); const action = node.untriedActions[actionIndex]; // Remove action from untried list node.untriedActions.splice(actionIndex, 1); // Create child node const childId = uuidv4(); const childNode = { id: childId, content: `${node.content} -> ${action}`, timestamp: new Date().toISOString(), parentId: nodeId, childrenIds: [], visits: 0, totalValue: 0, averageValue: 0, untriedActions: this.generateActionsForState(action), isTerminal: this.isTerminalState(action, session), metadata: { action, depth: (node.metadata?.depth || 0) + 1, stateRepresentation: this.getStateRepresentation(node, action) } }; // Add child to tree session.nodes.set(childId, childNode); node.childrenIds.push(childId); // Update statistics session.stats.totalNodes++; session.stats.maxDepth = Math.max(session.stats.maxDepth, childNode.metadata?.depth || 0); return childNode; } /** * Generate actions for a given state */ generateActionsForState(previousAction) { // In practice, would generate based on game/problem state const actions = ['continue', 'branch_left', 'branch_right', 'backtrack', 'terminate']; return actions.filter(a => a !== previousAction); } /** * Check if state is terminal */ isTerminalState(action, session) { // Simple heuristic - terminate after certain actions or depth return action === 'terminate' || session.stats.maxDepth >= session.simulationDepth; } /** * Get state representation for a node */ getStateRepresentation(parentNode, action) { // In practice, would encode the actual game/problem state return { path: parentNode.content + ' -> ' + action, depth: (parentNode.metadata?.depth || 0) + 1, action }; } /** * Simulate from a node to terminal state */ simulate(nodeId, session) { const startNode = session.nodes.get(nodeId); if (!startNode) return 0; const result = { value: 0, path: [nodeId], steps: 0, duration: Date.now() }; // Perform rollout based on policy switch (session.rolloutPolicy) { case 'random': result.value = this.randomRollout(startNode, session, result); break; case 'heuristic': result.value = this.heuristicRollout(startNode, session, result); break; case 'neural': result.value = this.neuralRollout(startNode, session, result); break; case 'hybrid': result.value = Math.random() < 0.5 ? this.randomRollout(startNode, session, result) : this.heuristicRollout(startNode, session, result); break; } if (result.duration !== undefined) { result.duration = Date.now() - result.duration; } return result.value; } /** * Random rollout policy */ randomRollout(startNode, session, result) { let currentDepth = startNode.metadata?.depth || 0; let value = 0; while (currentDepth < session.simulationDepth && !startNode.isTerminal) { // Random action selection const randomValue = Math.random(); value += randomValue; currentDepth++; result.steps++; // Early termination based on random chance if (Math.random() < 0.1) break; } // Normalize value return result.steps > 0 ? value / result.steps : 0; } /** * Heuristic-based rollout policy */ heuristicRollout(startNode, session, result) { let currentDepth = startNode.metadata?.depth || 0; let value = 0; while (currentDepth < session.simulationDepth && !startNode.isTerminal) { // Use heuristic to guide simulation const heuristicValue = this.evaluateHeuristic(startNode, currentDepth); value += heuristicValue; currentDepth++; result.steps++; // Terminate if heuristic indicates solution if (heuristicValue > 0.9) break; } return result.steps > 0 ? value / result.steps : 0; } /** * Neural network guided rollout (placeholder) */ neuralRollout(startNode, session, result) { // In practice, would use a neural network for value estimation // For now, use a weighted random approach return this.heuristicRollout(startNode, session, result) * 0.7 + this.randomRollout(startNode, session, result) * 0.3; } /** * Evaluate heuristic value for a state */ evaluateHeuristic(node, depth) { // Simple heuristic based on depth and node properties const depthBonus = Math.max(0, 1 - depth / 20); const explorationBonus = node.visits > 0 ? 1 / (1 + node.visits) : 0.5; const randomFactor = Math.random() * 0.2; return (depthBonus * 0.5 + explorationBonus * 0.3 + randomFactor); } /** * Backpropagate simulation results */ backpropagate(leafNodeId, value, session) { let currentNodeId = leafNodeId; while (currentNodeId) { const node = session.nodes.get(currentNodeId); if (!node) break; // Update node statistics node.visits++; node.totalValue += value; node.averageValue = node.totalValue / node.visits; // Update win/loss/draw statistics if tracked if (node.metadata && !node.metadata.outcomes) { node.metadata.outcomes = { wins: 0, losses: 0, draws: 0 }; } if (node.metadata?.outcomes) { if (value > 0.6) node.metadata.outcomes.wins++; else if (value < 0.4) node.metadata.outcomes.losses++; else node.metadata.outcomes.draws++; } // Move to parent currentNodeId = node.parentId; // Alternate value for two-player games (minimax) // value = 1 - value; } } /** * Get best action from root */ getBestAction(session) { const rootNode = session.nodes.get(session.rootNodeId); if (!rootNode || rootNode.childrenIds.length === 0) { return 'no_action'; } // Select child with most visits (robust child) let bestChild = null; let maxVisits = -1; rootNode.childrenIds.forEach(childId => { const child = session.nodes.get(childId); if (child && child.visits > maxVisits) { maxVisits = child.visits; bestChild = child; } }); // Update statistics if (bestChild) { const child = bestChild; session.stats.bestActionVisits = child.visits; session.stats.bestActionValue = child.averageValue; session.bestLeafNodeId = child.id; } return bestChild?.metadata?.action || 'no_action'; } /** * Get action probabilities from a node */ getActionProbabilities(nodeId, session) { const node = session.nodes.get(nodeId); const probabilities = new Map(); if (!node || node.childrenIds.length === 0) { return probabilities; } const totalVisits = node.childrenIds.reduce((sum, childId) => { const child = session.nodes.get(childId); return sum + (child?.visits || 0); }, 0); if (totalVisits === 0) return probabilities; node.childrenIds.forEach(childId => { const child = session.nodes.get(childId); if (child && child.metadata?.action) { probabilities.set(child.metadata.action, child.visits / totalVisits); } }); return probabilities; } /** * Run one complete MCTS iteration */ runIteration(session) { const startTime = Date.now(); // Selection session.currentPhase = 'selection'; const leafNode = this.selectLeaf(session); // Expansion session.currentPhase = 'expansion'; let selectedNode = leafNode; if (!leafNode.isTerminal && leafNode.untriedActions && leafNode.untriedActions.length > 0) { selectedNode = this.expandNode(leafNode.id, session); } // Simulation session.currentPhase = 'simulation'; if (session.currentSimulationPath) { session.currentSimulationPath = [selectedNode.id]; } else { session.currentSimulationPath = [selectedNode.id]; } const value = this.simulate(selectedNode.id, session); // Backpropagation session.currentPhase = 'backpropagation'; this.backpropagate(selectedNode.id, value, session); // Update session session.totalSimulations++; session.iteration++; session.updatedAt = new Date().toISOString(); // Update statistics this.updateStatistics(session, Date.now() - startTime); // Check termination if (this.checkTermination(session)) { session.nextStepNeeded = false; } } /** * Check termination criteria */ checkTermination(session) { const criteria = session.terminationCriteria; // Max simulations if (criteria.maxSimulations && session.totalSimulations >= criteria.maxSimulations) { return true; } // Time limit if (criteria.timeLimit) { const elapsed = Date.now() - new Date(session.createdAt).getTime(); if (elapsed >= criteria.timeLimit) { return true; } } // Confidence threshold if (criteria.confidenceThreshold) { const probabilities = this.getActionProbabilities(session.rootNodeId, session); const maxProb = Math.max(...probabilities.values()); if (maxProb >= criteria.confidenceThreshold) { return true; } } // Minimum visits for best action if (criteria.minVisitsThreshold && session.stats.bestActionVisits && session.stats.bestActionVisits >= criteria.minVisitsThreshold) { return true; } return false; } /** * Update session statistics */ updateStatistics(session, iterationTime) { // Calculate simulations per second const totalTime = Date.now() - new Date(session.createdAt).getTime(); session.stats.simulationsPerSecond = session.totalSimulations / (totalTime / 1000); // Calculate average branching factor let totalBranches = 0; let nodeCount = 0; session.nodes.forEach(node => { if (node.childrenIds.length > 0) { totalBranches += node.childrenIds.length; nodeCount++; } }); session.stats.averageBranchingFactor = nodeCount > 0 ? totalBranches / nodeCount : 0; // Calculate effective branching factor (nodes with visits > threshold) const visitThreshold = session.totalSimulations * 0.01; let effectiveBranches = 0; session.nodes.forEach(node => { if (node.visits > visitThreshold) { effectiveBranches++; } }); session.stats.effectiveBranchingFactor = effectiveBranches / session.stats.maxDepth; // Calculate exploration ratio const rootNode = session.nodes.get(session.rootNodeId); if (rootNode) { const exploredChildren = rootNode.childrenIds.filter(id => { const child = session.nodes.get(id); return child && child.visits > 0; }).length; session.stats.explorationRatio = rootNode.childrenIds.length > 0 ? exploredChildren / rootNode.childrenIds.length : 0; } } /** * Export to sequential thinking format */ exportToSequentialFormat(session) { const thoughts = []; // Find best path from root to best leaf const bestPath = this.getBestPath(session); bestPath.forEach((nodeId, index) => { const node = session.nodes.get(nodeId); if (node) { thoughts.push({ thought: node.content, thoughtNumber: index + 1, totalThoughts: bestPath.length, nextThoughtNeeded: index < bestPath.length - 1 }); } }); return thoughts; } /** * Get best path through the tree */ getBestPath(session) { const path = []; let currentNodeId = session.rootNodeId; while (currentNodeId) { path.push(currentNodeId); const node = session.nodes.get(currentNodeId); if (!node || node.childrenIds.length === 0) break; // Select child with most visits let bestChildId; let maxVisits = -1; node.childrenIds.forEach(childId => { const child = session.nodes.get(childId); if (child && child.visits > maxVisits) { maxVisits = child.visits; bestChildId = childId; } }); currentNodeId = bestChildId; } return path; } /** * Import from sequential thinking format */ importFromSequentialFormat(thoughts) { const session = this.initializeSession(); let currentNodeId = session.rootNodeId; thoughts.forEach((thought, index) => { if (index === 0) { // Update root node const rootNode = session.nodes.get(currentNodeId); if (rootNode) { rootNode.content = thought.thought; } } else { // Expand and create path const parentNode = session.nodes.get(currentNodeId); if (parentNode) { const childNode = this.expandNode(currentNodeId, session); childNode.content = thought.thought; currentNodeId = childNode.id; } } }); return session; } /** * Get action statistics for analysis */ getActionStatistics(session) { const rootNode = session.nodes.get(session.rootNodeId); if (!rootNode) return []; const stats = []; rootNode.childrenIds.forEach(childId => { const child = session.nodes.get(childId); if (child && child.metadata?.action) { stats.push({ action: child.metadata.action, visits: child.visits, averageValue: child.averageValue, winRate: child.metadata.outcomes ? child.metadata.outcomes.wins / child.visits : undefined, prior: child.metadata.priorProbability }); } }); return stats.sort((a, b) => b.visits - a.visits); } } export default MCTSHandler; //# sourceMappingURL=mcts.js.map

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/waldzellai/clearthought-onepointfive'

If you have feedback or need assistance with the MCP directory API, please join our Discord server