"""
Monte Carlo Tree Search implementation for decision making in Claude Code.
This module provides an advanced MCTS implementation that can be used to select
optimal actions/tools based on simulated outcomes.
"""
import math
import numpy as np
import random
from typing import List, Dict, Any, Callable, Tuple, Optional, Union
from dataclasses import dataclass
@dataclass
class MCTSNode:
"""Represents a node in the Monte Carlo search tree."""
state: Any
parent: Optional['MCTSNode'] = None
action_taken: Any = None
visits: int = 0
value: float = 0.0
children: Dict[Any, 'MCTSNode'] = None
def __post_init__(self):
if self.children is None:
self.children = {}
def is_fully_expanded(self, possible_actions: List[Any]) -> bool:
"""Check if all possible actions have been tried from this node."""
return all(action in self.children for action in possible_actions)
def is_terminal(self) -> bool:
"""Check if this node represents a terminal state."""
# This should be customized based on your environment
return False
def best_child(self, exploration_weight: float = 1.0) -> 'MCTSNode':
"""Select the best child node according to UCB1 formula."""
if not self.children:
return None
def ucb_score(child: MCTSNode) -> float:
exploitation = child.value / child.visits if child.visits > 0 else 0
exploration = math.sqrt(2 * math.log(self.visits) / child.visits) if child.visits > 0 else float('inf')
return exploitation + exploration_weight * exploration
return max(self.children.values(), key=ucb_score)
class AdvancedMCTS:
"""
Advanced Monte Carlo Tree Search implementation with various enhancements:
- Progressive widening for large/continuous action spaces
- RAVE (Rapid Action Value Estimation)
- Parallel simulations
- Dynamic exploration weight
- Customizable simulation and backpropagation strategies
"""
def __init__(
self,
state_evaluator: Callable[[Any], float],
action_generator: Callable[[Any], List[Any]],
simulator: Callable[[Any, Any], Any],
max_iterations: int = 1000,
exploration_weight: float = 1.0,
time_limit: Optional[float] = None,
progressive_widening: bool = False,
pw_coef: float = 0.5,
pw_power: float = 0.5,
use_rave: bool = False,
rave_equiv_param: float = 1000,
):
"""
Initialize the MCTS algorithm.
Args:
state_evaluator: Function to evaluate the value of a state (terminal or not)
action_generator: Function to generate possible actions from a state
simulator: Function to simulate taking an action in a state, returning new state
max_iterations: Maximum number of search iterations
exploration_weight: Controls exploration vs exploitation balance
time_limit: Optional time limit for search in seconds
progressive_widening: Whether to use progressive widening for large action spaces
pw_coef: Coefficient for progressive widening
pw_power: Power for progressive widening
use_rave: Whether to use RAVE (Rapid Action Value Estimation)
rave_equiv_param: RAVE equivalence parameter
"""
self.state_evaluator = state_evaluator
self.action_generator = action_generator
self.simulator = simulator
self.max_iterations = max_iterations
self.exploration_weight = exploration_weight
self.time_limit = time_limit
# Progressive widening parameters
self.progressive_widening = progressive_widening
self.pw_coef = pw_coef
self.pw_power = pw_power
# RAVE parameters
self.use_rave = use_rave
self.rave_equiv_param = rave_equiv_param
self.rave_values = {} # (state, action) -> (value, visits)
def search(self, initial_state: Any, visualizer=None) -> Any:
"""
Perform MCTS search from the initial state and return the best action.
Args:
initial_state: The starting state for the search
visualizer: Optional visualizer to show progress
Returns:
The best action found by the search
"""
root = MCTSNode(state=initial_state)
# Initialize visualizer if provided
if visualizer:
visualizer.set_search_parameters(root, self.max_iterations)
# Run iterations of the MCTS algorithm
for iteration in range(self.max_iterations):
# Selection phase
selected_node = self._select(root)
# Expansion phase (if not terminal)
expanded_node = None
if not selected_node.is_terminal():
expanded_node = self._expand(selected_node)
else:
expanded_node = selected_node
# Simulation phase
simulation_path = []
if visualizer:
# Track simulation path for visualization
current = expanded_node
current_state = current.state
while current.parent:
simulation_path.insert(0, (current.parent.state, current.action_taken))
current = current.parent
simulation_result = self._simulate(expanded_node)
# Backpropagation phase
self._backpropagate(expanded_node, simulation_result)
# Update visualization
if visualizer:
# Find current best action
best_action = None
if root.children:
best_action = max(root.children.items(), key=lambda x: x[1].visits)[0]
# Update visualizer
visualizer.update_iteration(
iteration=iteration + 1,
selected_node=selected_node,
expanded_node=expanded_node,
simulation_path=simulation_path,
simulation_result=simulation_result,
best_action=best_action
)
# Return the action that leads to the child with the highest value
if not root.children:
possible_actions = self.action_generator(root.state)
if possible_actions:
best_action = random.choice(possible_actions)
if visualizer:
visualizer.update_iteration(
iteration=self.max_iterations,
best_action=best_action
)
return best_action
return None
best_action = max(root.children.items(), key=lambda x: x[1].visits)[0]
if visualizer:
visualizer.update_iteration(
iteration=self.max_iterations,
best_action=best_action
)
return best_action
def _select(self, node: MCTSNode) -> MCTSNode:
"""
Select a node to expand using UCB1 and progressive widening if enabled.
Args:
node: The current node
Returns:
The selected node for expansion
"""
while not node.is_terminal():
possible_actions = self.action_generator(node.state)
# Handle progressive widening if enabled
if self.progressive_widening:
max_children = max(1, int(self.pw_coef * (node.visits ** self.pw_power)))
if len(node.children) < min(max_children, len(possible_actions)):
return node
# If not fully expanded, select this node for expansion
if not node.is_fully_expanded(possible_actions):
return node
# Otherwise, select the best child according to UCB1
node = node.best_child(self.exploration_weight)
if node is None:
break
return node
def _expand(self, node: MCTSNode) -> MCTSNode:
"""
Expand the node by selecting an untried action and creating a new child node.
Args:
node: The node to expand
Returns:
The newly created child node
"""
possible_actions = self.action_generator(node.state)
untried_actions = [a for a in possible_actions if a not in node.children]
if not untried_actions:
return node
action = random.choice(untried_actions)
new_state = self.simulator(node.state, action)
child_node = MCTSNode(
state=new_state,
parent=node,
action_taken=action
)
node.children[action] = child_node
return child_node
def _simulate(self, node: MCTSNode, depth: int = 10) -> float:
"""
Simulate a random playout from the given node until a terminal state or max depth.
Args:
node: The node to start simulation from
depth: Maximum simulation depth
Returns:
The value of the simulated outcome
"""
state = node.state
current_depth = 0
# Continue simulation until we reach a terminal state or max depth
while current_depth < depth:
if self._is_terminal_state(state):
break
possible_actions = self.action_generator(state)
if not possible_actions:
break
action = random.choice(possible_actions)
state = self.simulator(state, action)
current_depth += 1
return self.state_evaluator(state)
def _is_terminal_state(self, state: Any) -> bool:
"""Determine if the state is terminal."""
# This should be customized based on your environment
return False
def _backpropagate(self, node: MCTSNode, value: float) -> None:
"""
Backpropagate the simulation result up the tree.
Args:
node: The leaf node where simulation started
value: The value from the simulation
"""
while node is not None:
node.visits += 1
node.value += value
# Update RAVE values if enabled
if self.use_rave and node.parent is not None:
state_hash = self._hash_state(node.parent.state)
action = node.action_taken
if (state_hash, action) not in self.rave_values:
self.rave_values[(state_hash, action)] = [0, 0] # [value, visits]
rave_value, rave_visits = self.rave_values[(state_hash, action)]
self.rave_values[(state_hash, action)] = [
rave_value + value,
rave_visits + 1
]
node = node.parent
def _hash_state(self, state: Any) -> int:
"""Create a hash of the state for RAVE table lookups."""
# This should be customized based on your state representation
if hasattr(state, "__hash__"):
return hash(state)
return hash(str(state))
class MCTSToolSelector:
"""
Specialized MCTS implementation for selecting optimal tools in Claude Code.
This class adapts the AdvancedMCTS for the specific context of tool selection.
"""
def __init__(
self,
tool_registry: Any, # Should be a reference to the tool registry
context_evaluator: Callable, # Function to evaluate quality of response given context
max_iterations: int = 200,
exploration_weight: float = 1.0,
use_learning: bool = True,
tool_history_weight: float = 0.7,
enable_plan_generation: bool = True,
use_semantic_similarity: bool = True,
adaptation_rate: float = 0.05
):
"""
Initialize the MCTS tool selector with enhanced intelligence.
Args:
tool_registry: Registry containing available tools
context_evaluator: Function to evaluate response quality
max_iterations: Maximum search iterations
exploration_weight: Controls exploration vs exploitation
use_learning: Whether to use learning from past tool selections
tool_history_weight: Weight given to historical tool performance
enable_plan_generation: Generate complete tool sequences as plans
use_semantic_similarity: Use semantic similarity for tool relevance
adaptation_rate: Rate at which the system adapts to new patterns
"""
self.tool_registry = tool_registry
self.context_evaluator = context_evaluator
self.use_learning = use_learning
self.tool_history_weight = tool_history_weight
self.enable_plan_generation = enable_plan_generation
self.use_semantic_similarity = use_semantic_similarity
self.adaptation_rate = adaptation_rate
# Tool performance history by query type
self.tool_history = {}
# Tool sequence effectiveness records
self.sequence_effectiveness = {}
# Semantic fingerprints for tools and queries
self.tool_fingerprints = {}
self.query_clusters = {}
# Cached simulation results for similar queries
self.simulation_cache = {}
# Initialize the MCTS algorithm
self.mcts = AdvancedMCTS(
state_evaluator=self._evaluate_state,
action_generator=self._generate_actions,
simulator=self._simulate_action,
max_iterations=max_iterations,
exploration_weight=exploration_weight,
progressive_widening=True
)
# Initialize tool fingerprints
self._initialize_tool_fingerprints()
def _initialize_tool_fingerprints(self):
"""Initialize semantic fingerprints for all available tools."""
if not self.use_semantic_similarity:
return
for tool_name in self.tool_registry.get_all_tool_names():
tool = self.tool_registry.get_tool(tool_name)
if tool and hasattr(tool, 'description'):
# In a real implementation, this would compute an embedding
# For now, we'll use a simple keyword extraction as a placeholder
keywords = set(word.lower() for word in tool.description.split()
if len(word) > 3)
self.tool_fingerprints[tool_name] = {
'keywords': keywords,
'description': tool.description,
'usage_contexts': set()
}
def select_tool(self, user_query: str, context: Dict[str, Any], visualizer=None) -> Union[str, List[str]]:
"""
Select the best tool to use for a given user query and context.
Args:
user_query: The user's query
context: The current conversation context
visualizer: Optional visualizer to show the selection process
Returns:
Either a single tool name or a sequence of tool names (if plan generation is enabled)
"""
# Analyze query to determine its type/characteristics
query_type = self._analyze_query(user_query)
# Update semantic fingerprints with this query
if self.use_semantic_similarity:
self._update_query_clusters(user_query, query_type)
initial_state = {
'query': user_query,
'query_type': query_type,
'context': context,
'actions_taken': [],
'response_quality': 0.0,
'steps_remaining': 3 if self.enable_plan_generation else 1,
'step_results': {}
}
# First check if we have a high-confidence cached result for similar queries
cached_result = self._check_cache(user_query, query_type)
if cached_result and random.random() > 0.1: # 10% random exploration
if visualizer:
visualizer.add_execution(
execution_id="mcts_cache_hit",
tool_name="MCTS Tool Selection (cached)",
parameters={"query": user_query[:100] + "..." if len(user_query) > 100 else user_query}
)
visualizer.complete_execution(
execution_id="mcts_cache_hit",
result={"selected_tool": cached_result, "source": "cache"},
status="success"
)
return cached_result
# Run MCTS search
best_action = self.mcts.search(initial_state, visualizer)
# If plan generation is enabled, we might want to return a sequence
if self.enable_plan_generation:
# Extract the most promising action sequence from search
plan = self._extract_plan_from_search()
if plan and len(plan) > 1:
# Store this plan in our cache
self._cache_result(user_query, query_type, plan)
return plan
# Store single action in cache
self._cache_result(user_query, query_type, best_action)
return best_action
def _analyze_query(self, query: str) -> str:
"""
Analyze a query to determine its type and characteristics.
Args:
query: The user query
Returns:
A string identifying the query type
"""
query_lower = query.lower()
# Check for search-related queries
if any(term in query_lower for term in ['find', 'search', 'where', 'look for']):
return 'search'
# Check for explanation queries
if any(term in query_lower for term in ['explain', 'how', 'why', 'what is']):
return 'explanation'
# Check for file operation queries
if any(term in query_lower for term in ['file', 'read', 'write', 'edit', 'create']):
return 'file_operation'
# Check for execution queries
if any(term in query_lower for term in ['run', 'execute', 'start']):
return 'execution'
# Check for debugging queries
if any(term in query_lower for term in ['debug', 'fix', 'error', 'problem']):
return 'debugging'
# Default to general
return 'general'
def _update_query_clusters(self, query: str, query_type: str):
"""
Update query clusters with new query information.
Args:
query: The user query
query_type: The type of query
"""
# Extract query keywords
keywords = set(word.lower() for word in query.split() if len(word) > 3)
# Update query clusters
if query_type not in self.query_clusters:
self.query_clusters[query_type] = {
'keywords': set(),
'queries': []
}
# Add keywords to cluster
self.query_clusters[query_type]['keywords'].update(keywords)
# Add query to cluster (limit to last 50)
self.query_clusters[query_type]['queries'].append(query)
if len(self.query_clusters[query_type]['queries']) > 50:
self.query_clusters[query_type]['queries'].pop(0)
# Update tool fingerprints with these keywords
for tool_name, fingerprint in self.tool_fingerprints.items():
# If tool has been used successfully for this query type before
if tool_name in self.tool_history.get(query_type, {}) and \
self.tool_history[query_type][tool_name]['success_rate'] > 0.6:
fingerprint['usage_contexts'].add(query_type)
def _check_cache(self, query: str, query_type: str) -> Union[str, List[str], None]:
"""
Check if we have a cached result for a similar query.
Args:
query: The user query
query_type: The type of query
Returns:
A cached tool selection or None
"""
if not self.use_learning or query_type not in self.tool_history:
return None
# Find the most successful tool for this query type
type_history = self.tool_history[query_type]
best_tools = sorted(
[(tool, data['success_rate']) for tool, data in type_history.items()],
key=lambda x: x[1],
reverse=True
)
# Only use cache if we have a high confidence result
if best_tools and best_tools[0][1] > 0.75:
return best_tools[0][0]
return None
def _cache_result(self, query: str, query_type: str, action: Union[str, List[str]]):
"""
Cache a result for future similar queries.
Args:
query: The user query
query_type: The type of query
action: The selected action or plan
"""
# Store in simulation cache
query_key = self._get_query_cache_key(query)
self.simulation_cache[query_key] = {
'action': action,
'timestamp': self._get_timestamp(),
'query_type': query_type
}
# Limit cache size
if len(self.simulation_cache) > 1000:
# Remove oldest entries
oldest_key = min(self.simulation_cache.keys(),
key=lambda k: self.simulation_cache[k]['timestamp'])
del self.simulation_cache[oldest_key]
def _get_query_cache_key(self, query: str) -> str:
"""Generate a cache key for a query."""
# In a real implementation, this might use a hash of query embeddings
# For now, use a simple keyword approach
keywords = ' '.join(sorted(set(word.lower() for word in query.split() if len(word) > 3)))
return keywords[:100] # Limit key length
def _get_timestamp(self):
"""Get current timestamp."""
import time
return time.time()
def _evaluate_state(self, state: Dict[str, Any]) -> float:
"""
Evaluate the quality of a state based on response quality and steps.
Args:
state: The current state
Returns:
A quality score
"""
# Base score is the response quality
score = state['response_quality']
# If plan generation is enabled, we want to encourage complete plans
if self.enable_plan_generation:
steps_completed = len(state['actions_taken'])
total_steps = steps_completed + state['steps_remaining']
# Add bonus for completing more steps
if total_steps > 0:
step_completion_bonus = steps_completed / total_steps
score += step_completion_bonus * 0.2 # 20% bonus for step completion
return score
def _generate_actions(self, state: Dict[str, Any]) -> List[str]:
"""
Generate possible tool actions from the current state with intelligent filtering.
Args:
state: The current state
Returns:
List of possible actions
"""
# Get query type
query_type = state['query_type']
query = state['query']
# Get all available tools
all_tools = set(self.tool_registry.get_all_tool_names())
# Tools already used in this sequence
used_tools = set(state['actions_taken'])
# Remaining tools
remaining_tools = all_tools - used_tools
# If we're using learning, prioritize tools based on history
if self.use_learning and query_type in self.tool_history:
prioritized_tools = []
# First, add tools that have been successful for this query type
type_history = self.tool_history[query_type]
# Check for successful tools
for tool in remaining_tools:
if tool in type_history and type_history[tool]['success_rate'] > 0.5:
prioritized_tools.append(tool)
# If we have at least some tools, return them
if prioritized_tools and random.random() < self.tool_history_weight:
return prioritized_tools
# If using semantic similarity, filter by relevant tools
if self.use_semantic_similarity:
query_keywords = set(word.lower() for word in query.split() if len(word) > 3)
# Score tools by semantic similarity to query
scored_tools = []
for tool in remaining_tools:
if tool in self.tool_fingerprints:
fingerprint = self.tool_fingerprints[tool]
# Calculate keyword overlap
keyword_overlap = len(query_keywords.intersection(fingerprint['keywords']))
# Check if tool has been used for this query type
context_match = 1.0 if query_type in fingerprint['usage_contexts'] else 0.0
# Combined score
score = keyword_overlap * 0.7 + context_match * 0.3
scored_tools.append((tool, score))
# Sort and filter tools
scored_tools.sort(key=lambda x: x[1], reverse=True)
# Take top half of tools if we have enough
if len(scored_tools) > 2:
return [t[0] for t in scored_tools[:max(2, len(scored_tools) // 2)]]
# If we reach here, use all remaining tools
return list(remaining_tools)
def _simulate_action(self, state: Dict[str, Any], action: str) -> Dict[str, Any]:
"""
Simulate taking an action (using a tool) in the given state with enhanced modeling.
Args:
state: The current state
action: The tool action to simulate
Returns:
The new state after taking the action
"""
# Create a new state with the action added
new_state = state.copy()
new_actions = state['actions_taken'].copy()
new_actions.append(action)
new_state['actions_taken'] = new_actions
# Decrement steps remaining if using plan generation
if self.enable_plan_generation and new_state['steps_remaining'] > 0:
new_state['steps_remaining'] -= 1
# Get query type and query
query_type = state['query_type']
query = state['query']
# Simulate step result
step_results = state['step_results'].copy()
step_results[action] = self._simulate_tool_result(action, query)
new_state['step_results'] = step_results
# Estimate tool relevance based on learning or semantic similarity
tool_relevance = self._estimate_tool_relevance(action, query, query_type)
# Check for sequence effects (tools that work well together)
sequence_bonus = 0.0
if len(new_actions) > 1:
prev_tool = new_actions[-2]
sequence_key = f"{prev_tool}->{action}"
if sequence_key in self.sequence_effectiveness:
sequence_bonus = self.sequence_effectiveness[sequence_key] * 0.3 # 30% weight for sequence effects
# Update quality based on relevance and sequence effects
current_quality = state['response_quality']
quality_improvement = tool_relevance + sequence_bonus
# Add diminishing returns effect for additional tools
if len(new_actions) > 1:
diminishing_factor = 1.0 / len(new_actions)
quality_improvement *= diminishing_factor
new_quality = min(1.0, current_quality + quality_improvement)
new_state['response_quality'] = new_quality
return new_state
def _simulate_tool_result(self, tool_name: str, query: str) -> Dict[str, Any]:
"""
Simulate the result of using a tool for a query.
Args:
tool_name: The name of the tool
query: The user query
Returns:
A simulated result
"""
# In a real implementation, this would be a more sophisticated simulation
return {
"tool": tool_name,
"success_probability": self._estimate_tool_relevance(tool_name, query),
"simulated": True
}
def _estimate_tool_relevance(self, tool_name: str, query: str, query_type: str = None) -> float:
"""
Estimate how relevant a tool is for a given query using history and semantics.
Args:
tool_name: The name of the tool
query: The user query
query_type: Optional query type
Returns:
A relevance score between 0.0 and 1.0
"""
relevance_score = 0.0
# If we have historical data for this query type
if self.use_learning and query_type and query_type in self.tool_history and \
tool_name in self.tool_history[query_type]:
# Get historical success rate
history_score = self.tool_history[query_type][tool_name]['success_rate']
relevance_score += history_score * self.tool_history_weight
# If we're using semantic similarity
if self.use_semantic_similarity and tool_name in self.tool_fingerprints:
fingerprint = self.tool_fingerprints[tool_name]
# Calculate keyword overlap
query_keywords = set(word.lower() for word in query.split() if len(word) > 3)
keyword_overlap = len(query_keywords.intersection(fingerprint['keywords']))
# Normalize by query keywords
if query_keywords:
semantic_score = keyword_overlap / len(query_keywords)
relevance_score += semantic_score * (1.0 - self.tool_history_weight)
# Ensure we have a minimum score for exploration
if relevance_score < 0.1:
relevance_score = 0.1 + (random.random() * 0.1) # Random boost between 0.1-0.2
return relevance_score
def _extract_plan_from_search(self) -> List[str]:
"""
Extract a complete plan (tool sequence) from the search results.
Returns:
A list of tool names representing the plan
"""
# In a real implementation, this would extract the highest value path
# from the search tree. For now, return None to indicate no plan extraction.
return None
def update_tool_history(self, tool_name: str, query: str, success: bool,
execution_time: float, result: Any = None):
"""
Update the tool history with the results of using a tool.
Args:
tool_name: The name of the tool used
query: The query the tool was used for
success: Whether the tool was successful
execution_time: The execution time in seconds
result: Optional result of the tool execution
"""
if not self.use_learning:
return
# Get query type
query_type = self._analyze_query(query)
# Initialize history entry if needed
if query_type not in self.tool_history:
self.tool_history[query_type] = {}
if tool_name not in self.tool_history[query_type]:
self.tool_history[query_type][tool_name] = {
'success_count': 0,
'failure_count': 0,
'total_time': 0.0,
'success_rate': 0.0,
'avg_time': 0.0,
'examples': []
}
# Update history
history = self.tool_history[query_type][tool_name]
# Update counts
if success:
history['success_count'] += 1
else:
history['failure_count'] += 1
# Update time
history['total_time'] += execution_time
# Update success rate
total = history['success_count'] + history['failure_count']
history['success_rate'] = history['success_count'] / total if total > 0 else 0.0
# Update average time
history['avg_time'] = history['total_time'] / total if total > 0 else 0.0
# Add example (limit to last 5)
history['examples'].append({
'query': query,
'success': success,
'timestamp': self._get_timestamp()
})
if len(history['examples']) > 5:
history['examples'].pop(0)
# Update tool fingerprint
if self.use_semantic_similarity and tool_name in self.tool_fingerprints:
if success:
# Add query type to usage contexts
self.tool_fingerprints[tool_name]['usage_contexts'].add(query_type)
# Add query keywords to tool fingerprint (with decay)
query_keywords = set(word.lower() for word in query.split() if len(word) > 3)
current_keywords = self.tool_fingerprints[tool_name]['keywords']
# Add new keywords with adaptation rate
for keyword in query_keywords:
if keyword not in current_keywords:
if random.random() < self.adaptation_rate:
current_keywords.add(keyword)
def update_sequence_effectiveness(self, tool_sequence: List[str], success: bool, quality_score: float):
"""
Update the effectiveness record for a sequence of tools.
Args:
tool_sequence: The sequence of tools used
success: Whether the sequence was successful
quality_score: A quality score for the sequence
"""
if not self.use_learning or len(tool_sequence) < 2:
return
# Update pairwise effectiveness
for i in range(len(tool_sequence) - 1):
first_tool = tool_sequence[i]
second_tool = tool_sequence[i + 1]
sequence_key = f"{first_tool}->{second_tool}"
if sequence_key not in self.sequence_effectiveness:
self.sequence_effectiveness[sequence_key] = 0.5 # Initial neutral score
# Update score with decay
current_score = self.sequence_effectiveness[sequence_key]
if success:
# Increase score with quality bonus
new_score = current_score + self.adaptation_rate * quality_score
else:
# Decrease score
new_score = current_score - self.adaptation_rate
# Clamp between 0 and 1
self.sequence_effectiveness[sequence_key] = max(0.0, min(1.0, new_score))