mcp-reasoner
by Jacck
- mcp-reasoner
- dist
- strategies
import { v4 as uuidv4 } from 'uuid';
import { CONFIG } from '../types.js';
import { BaseStrategy } from './base.js';
/**
 * Reasoning strategy that scores each submitted thought and then explores
 * hypothetical continuations with Monte Carlo Tree Search
 * (select -> expand -> simulate -> backpropagate).
 *
 * `getNode`, `saveNode`, `evaluateThought` and `this.stateManager` are not
 * defined in this class; they are expected to be provided by BaseStrategy.
 */
export class MonteCarloTreeSearchStrategy extends BaseStrategy {
  /**
   * @param {object} stateManager - Forwarded to BaseStrategy; this class also
   *   calls `stateManager.getPath(nodeId)`.
   * @param {number} [numSimulations=CONFIG.numSimulations] - Simulations to
   *   run per incomplete thought; clamped to the range [1, 150].
   */
  constructor(stateManager, numSimulations = CONFIG.numSimulations) {
    super(stateManager);
    this.explorationConstant = Math.sqrt(2);
    this.simulationDepth = CONFIG.maxDepth;
    this.root = null;
    // Ensure numSimulations is within reasonable bounds
    this.numSimulations = Math.max(1, Math.min(150, numSimulations));
  }

  /**
   * Records a thought as a tree node, links it to its parent, runs the
   * simulations (unless the thought is final) and reports the node's stats.
   *
   * @param {object} request - Reads `thought`, `thoughtNumber`,
   *   `nextThoughtNeeded` and the optional `parentId`.
   * @returns {Promise<object>} `{ nodeId, thought, score, depth, isComplete,
   *   nextThoughtNeeded, possiblePaths, bestScore }`, where `bestScore` is the
   *   mean score along the path returned by `stateManager.getPath(nodeId)`.
   */
  async processThought(request) {
    const nodeId = uuidv4();
    const parentNode = request.parentId
      ? await this.getNode(request.parentId)
      : undefined;
    const node = {
      id: nodeId,
      thought: request.thought,
      depth: request.thoughtNumber - 1,
      score: 0,
      children: [],
      parentId: request.parentId,
      isComplete: !request.nextThoughtNeeded,
      visits: 0,
      totalReward: 0,
      untriedActions: [],
    };

    // The initial evaluation counts as the node's first visit.
    node.score = this.evaluateThought(node, parentNode);
    node.visits = 1;
    node.totalReward = node.score;
    await this.saveNode(node);

    if (parentNode) {
      parentNode.children.push(node.id);
      await this.saveNode(parentNode);
    } else {
      // No parent could be loaded (or none was given): this node becomes the root.
      this.root = node;
    }

    // Final thoughts have nothing left to explore.
    if (!node.isComplete) {
      await this.runSimulations(node);
    }

    const currentPath = await this.stateManager.getPath(nodeId);
    const pathScore = this.calculatePathScore(currentPath);
    return {
      nodeId: node.id,
      thought: node.thought,
      score: node.score,
      depth: node.depth,
      isComplete: node.isComplete,
      nextThoughtNeeded: request.nextThoughtNeeded,
      possiblePaths: this.calculatePossiblePaths(node),
      bestScore: pathScore,
    };
  }

  /**
   * Runs `this.numSimulations` MCTS iterations starting at `node`.
   *
   * @param {object} node - Node the selection phase starts from.
   * @returns {Promise<void>}
   */
  async runSimulations(node) {
    for (let i = 0; i < this.numSimulations; i++) {
      const selectedNode = await this.select(node);
      const expandedNode = await this.expand(selectedNode);
      const reward = await this.simulate(expandedNode);
      await this.backpropagate(expandedNode, reward);
    }
  }

  /**
   * Selection phase: walks down from `node`, always following the child with
   * the highest UCT value, until reaching a node that has no loadable
   * children or still lists untried actions.
   *
   * @param {object} node - Starting node.
   * @returns {Promise<object>} The node chosen for expansion.
   */
  async select(node) {
    let current = node;
    while (current.children.length > 0 && !current.untriedActions?.length) {
      const loaded = await Promise.all(current.children.map((id) => this.getNode(id)));
      // getNode may yield nothing for an id (processThought guards against
      // that too); skip such ids instead of crashing on `undefined.visits`.
      const children = loaded.filter((child) => child != null);
      if (children.length === 0) {
        break;
      }
      current = this.selectBestUCT(children);
    }
    return current;
  }

  /**
   * Expansion phase: attaches one placeholder child to `node` and saves both.
   * Complete nodes and nodes at or beyond `simulationDepth` are returned
   * unchanged.
   *
   * @param {object} node - Node picked by `select`.
   * @returns {Promise<object>} The new child, or `node` if it cannot expand.
   */
  async expand(node) {
    if (node.isComplete || node.depth >= this.simulationDepth) {
      return node;
    }
    const newNode = {
      id: uuidv4(),
      thought: `Simulated thought at depth ${node.depth + 1}`,
      depth: node.depth + 1,
      score: 0,
      children: [],
      parentId: node.id,
      isComplete: false,
      visits: 1,
      totalReward: 0,
    };
    newNode.score = this.evaluateThought(newNode, node);
    // Same convention as processThought: the first visit carries the node's
    // own score. Leaving totalReward at 0 would count a zero-reward visit and
    // drag down the mean reward used by UCT.
    newNode.totalReward = newNode.score;
    await this.saveNode(newNode);
    node.children.push(newNode.id);
    await this.saveNode(node);
    return newNode;
  }

  /**
   * Simulation phase: rolls out placeholder thoughts from `node` until
   * `simulationDepth` is reached. Roll-out nodes are never saved.
   *
   * @param {object} node - Roll-out start.
   * @returns {Promise<number>} Mean score over `node` and every roll-out step.
   */
  async simulate(node) {
    let current = node;
    let totalScore = current.score;
    let depth = current.depth;
    while (depth < this.simulationDepth && !current.isComplete) {
      const simulatedNode = {
        id: uuidv4(),
        thought: `Random simulation at depth ${depth + 1}`,
        depth: depth + 1,
        score: 0,
        children: [],
        parentId: current.id,
        isComplete: depth + 1 >= this.simulationDepth,
        visits: 1,
        totalReward: 0,
      };
      simulatedNode.score = this.evaluateThought(simulatedNode, current);
      totalScore += simulatedNode.score;
      current = simulatedNode;
      depth++;
    }
    // Number of scores summed = roll-out steps + the starting node itself.
    return totalScore / (depth - node.depth + 1);
  }

  /**
   * Backpropagation phase: adds one visit and `reward` to `node` and to every
   * ancestor reachable through `parentId`, saving each one.
   *
   * @param {object} node - Node the simulation started from.
   * @param {number} reward - Value returned by `simulate`.
   * @returns {Promise<void>}
   */
  async backpropagate(node, reward) {
    let current = node;
    while (current) {
      current.visits++;
      current.totalReward += reward;
      await this.saveNode(current);
      current = current.parentId
        ? await this.getNode(current.parentId)
        : undefined;
    }
  }

  /**
   * Picks the node with the highest UCT value:
   * `totalReward / visits + c * sqrt(ln(totalVisits) / visits)`, where
   * `totalVisits` is the visit sum over `nodes`. Ties keep the earliest node.
   *
   * @param {object[]} nodes - Sibling nodes with `visits` and `totalReward`.
   * @returns {object|undefined} The best node, or `undefined` for an empty list.
   */
  selectBestUCT(nodes) {
    if (nodes.length === 0) {
      return undefined;
    }
    const totalVisits = nodes.reduce((sum, node) => sum + node.visits, 0);
    const logTotalVisits = Math.log(totalVisits);
    const uctOf = (node) =>
      node.totalReward / node.visits +
      this.explorationConstant * Math.sqrt(logTotalVisits / node.visits);

    let best = nodes[0];
    let bestUct = uctOf(best);
    for (const candidate of nodes.slice(1)) {
      const candidateUct = uctOf(candidate);
      if (candidateUct > bestUct) {
        best = candidate;
        bestUct = candidateUct;
      }
    }
    return best;
  }

  /**
   * @param {object[]} path - Nodes with a numeric `score`.
   * @returns {number} Mean score of the path, or 0 for an empty path.
   */
  calculatePathScore(path) {
    if (path.length === 0) {
      return 0;
    }
    return path.reduce((acc, node) => acc + node.score, 0) / path.length;
  }

  /**
   * @param {object} node - Node whose `depth` is read.
   * @returns {number} `2 ** (simulationDepth - node.depth)`, i.e. an estimate
   *   that assumes two continuations per remaining level.
   */
  calculatePossiblePaths(node) {
    return Math.pow(2, this.simulationDepth - node.depth);
  }

  /**
   * Returns the path to the most visited child of the root.
   *
   * @returns {Promise<object[]>} `[]` when no root exists; the path to the
   *   root itself when the root has no loadable children; otherwise whatever
   *   `stateManager.getPath` returns for the most visited child.
   */
  async getBestPath() {
    if (!this.root) {
      return [];
    }
    const loaded = await Promise.all(this.root.children.map((id) => this.getNode(id)));
    const children = loaded.filter((child) => child != null);
    if (children.length === 0) {
      // A root without explored children (e.g. a single final thought) is
      // itself the best path; reducing an empty array would throw instead.
      return this.stateManager.getPath(this.root.id);
    }
    const bestChild = children.reduce(
      (best, node) => (node.visits > best.visits ? node : best),
      children[0],
    );
    return this.stateManager.getPath(bestChild.id);
  }

  /**
   * @returns {Promise<object>} The base-class metrics extended with
   *   `simulationDepth`, `numSimulations`, `explorationConstant` and
   *   `totalSimulations` (the root's visit count, 0 without a root).
   */
  async getMetrics() {
    const baseMetrics = await super.getMetrics();
    return {
      ...baseMetrics,
      simulationDepth: this.simulationDepth,
      numSimulations: this.numSimulations,
      explorationConstant: this.explorationConstant,
      totalSimulations: this.root?.visits ?? 0,
    };
  }

  /**
   * Clears the base-class state and forgets the current root.
   *
   * @returns {Promise<void>}
   */
  async clear() {
    await super.clear();
    this.root = null;
  }
}