# OpenAI MCP Server

""" Monte Carlo Tree Search implementation for decision making in Claude Code. This module provides an advanced MCTS implementation that can be used to select optimal actions/tools based on simulated outcomes. """ import math import numpy as np import random from typing import List, Dict, Any, Callable, Tuple, Optional, Union from dataclasses import dataclass @dataclass class MCTSNode: """Represents a node in the Monte Carlo search tree.""" state: Any parent: Optional['MCTSNode'] = None action_taken: Any = None visits: int = 0 value: float = 0.0 children: Dict[Any, 'MCTSNode'] = None def __post_init__(self): if self.children is None: self.children = {} def is_fully_expanded(self, possible_actions: List[Any]) -> bool: """Check if all possible actions have been tried from this node.""" return all(action in self.children for action in possible_actions) def is_terminal(self) -> bool: """Check if this node represents a terminal state.""" # This should be customized based on your environment return False def best_child(self, exploration_weight: float = 1.0) -> 'MCTSNode': """Select the best child node according to UCB1 formula.""" if not self.children: return None def ucb_score(child: MCTSNode) -> float: exploitation = child.value / child.visits if child.visits > 0 else 0 exploration = math.sqrt(2 * math.log(self.visits) / child.visits) if child.visits > 0 else float('inf') return exploitation + exploration_weight * exploration return max(self.children.values(), key=ucb_score) class AdvancedMCTS: """ Advanced Monte Carlo Tree Search implementation with various enhancements: - Progressive widening for large/continuous action spaces - RAVE (Rapid Action Value Estimation) - Parallel simulations - Dynamic exploration weight - Customizable simulation and backpropagation strategies """ def __init__( self, state_evaluator: Callable[[Any], float], action_generator: Callable[[Any], List[Any]], simulator: Callable[[Any, Any], Any], max_iterations: int = 1000, exploration_weight: 
float = 1.0, time_limit: Optional[float] = None, progressive_widening: bool = False, pw_coef: float = 0.5, pw_power: float = 0.5, use_rave: bool = False, rave_equiv_param: float = 1000, ): """ Initialize the MCTS algorithm. Args: state_evaluator: Function to evaluate the value of a state (terminal or not) action_generator: Function to generate possible actions from a state simulator: Function to simulate taking an action in a state, returning new state max_iterations: Maximum number of search iterations exploration_weight: Controls exploration vs exploitation balance time_limit: Optional time limit for search in seconds progressive_widening: Whether to use progressive widening for large action spaces pw_coef: Coefficient for progressive widening pw_power: Power for progressive widening use_rave: Whether to use RAVE (Rapid Action Value Estimation) rave_equiv_param: RAVE equivalence parameter """ self.state_evaluator = state_evaluator self.action_generator = action_generator self.simulator = simulator self.max_iterations = max_iterations self.exploration_weight = exploration_weight self.time_limit = time_limit # Progressive widening parameters self.progressive_widening = progressive_widening self.pw_coef = pw_coef self.pw_power = pw_power # RAVE parameters self.use_rave = use_rave self.rave_equiv_param = rave_equiv_param self.rave_values = {} # (state, action) -> (value, visits) def search(self, initial_state: Any, visualizer=None) -> Any: """ Perform MCTS search from the initial state and return the best action. 
Args: initial_state: The starting state for the search visualizer: Optional visualizer to show progress Returns: The best action found by the search """ root = MCTSNode(state=initial_state) # Initialize visualizer if provided if visualizer: visualizer.set_search_parameters(root, self.max_iterations) # Run iterations of the MCTS algorithm for iteration in range(self.max_iterations): # Selection phase selected_node = self._select(root) # Expansion phase (if not terminal) expanded_node = None if not selected_node.is_terminal(): expanded_node = self._expand(selected_node) else: expanded_node = selected_node # Simulation phase simulation_path = [] if visualizer: # Track simulation path for visualization current = expanded_node current_state = current.state while current.parent: simulation_path.insert(0, (current.parent.state, current.action_taken)) current = current.parent simulation_result = self._simulate(expanded_node) # Backpropagation phase self._backpropagate(expanded_node, simulation_result) # Update visualization if visualizer: # Find current best action best_action = None if root.children: best_action = max(root.children.items(), key=lambda x: x[1].visits)[0] # Update visualizer visualizer.update_iteration( iteration=iteration + 1, selected_node=selected_node, expanded_node=expanded_node, simulation_path=simulation_path, simulation_result=simulation_result, best_action=best_action ) # Return the action that leads to the child with the highest value if not root.children: possible_actions = self.action_generator(root.state) if possible_actions: best_action = random.choice(possible_actions) if visualizer: visualizer.update_iteration( iteration=self.max_iterations, best_action=best_action ) return best_action return None best_action = max(root.children.items(), key=lambda x: x[1].visits)[0] if visualizer: visualizer.update_iteration( iteration=self.max_iterations, best_action=best_action ) return best_action def _select(self, node: MCTSNode) -> MCTSNode: """ Select 
a node to expand using UCB1 and progressive widening if enabled. Args: node: The current node Returns: The selected node for expansion """ while not node.is_terminal(): possible_actions = self.action_generator(node.state) # Handle progressive widening if enabled if self.progressive_widening: max_children = max(1, int(self.pw_coef * (node.visits ** self.pw_power))) if len(node.children) < min(max_children, len(possible_actions)): return node # If not fully expanded, select this node for expansion if not node.is_fully_expanded(possible_actions): return node # Otherwise, select the best child according to UCB1 node = node.best_child(self.exploration_weight) if node is None: break return node def _expand(self, node: MCTSNode) -> MCTSNode: """ Expand the node by selecting an untried action and creating a new child node. Args: node: The node to expand Returns: The newly created child node """ possible_actions = self.action_generator(node.state) untried_actions = [a for a in possible_actions if a not in node.children] if not untried_actions: return node action = random.choice(untried_actions) new_state = self.simulator(node.state, action) child_node = MCTSNode( state=new_state, parent=node, action_taken=action ) node.children[action] = child_node return child_node def _simulate(self, node: MCTSNode, depth: int = 10) -> float: """ Simulate a random playout from the given node until a terminal state or max depth. 
Args: node: The node to start simulation from depth: Maximum simulation depth Returns: The value of the simulated outcome """ state = node.state current_depth = 0 # Continue simulation until we reach a terminal state or max depth while current_depth < depth: if self._is_terminal_state(state): break possible_actions = self.action_generator(state) if not possible_actions: break action = random.choice(possible_actions) state = self.simulator(state, action) current_depth += 1 return self.state_evaluator(state) def _is_terminal_state(self, state: Any) -> bool: """Determine if the state is terminal.""" # This should be customized based on your environment return False def _backpropagate(self, node: MCTSNode, value: float) -> None: """ Backpropagate the simulation result up the tree. Args: node: The leaf node where simulation started value: The value from the simulation """ while node is not None: node.visits += 1 node.value += value # Update RAVE values if enabled if self.use_rave and node.parent is not None: state_hash = self._hash_state(node.parent.state) action = node.action_taken if (state_hash, action) not in self.rave_values: self.rave_values[(state_hash, action)] = [0, 0] # [value, visits] rave_value, rave_visits = self.rave_values[(state_hash, action)] self.rave_values[(state_hash, action)] = [ rave_value + value, rave_visits + 1 ] node = node.parent def _hash_state(self, state: Any) -> int: """Create a hash of the state for RAVE table lookups.""" # This should be customized based on your state representation if hasattr(state, "__hash__"): return hash(state) return hash(str(state)) class MCTSToolSelector: """ Specialized MCTS implementation for selecting optimal tools in Claude Code. This class adapts the AdvancedMCTS for the specific context of tool selection. 
""" def __init__( self, tool_registry: Any, # Should be a reference to the tool registry context_evaluator: Callable, # Function to evaluate quality of response given context max_iterations: int = 200, exploration_weight: float = 1.0, use_learning: bool = True, tool_history_weight: float = 0.7, enable_plan_generation: bool = True, use_semantic_similarity: bool = True, adaptation_rate: float = 0.05 ): """ Initialize the MCTS tool selector with enhanced intelligence. Args: tool_registry: Registry containing available tools context_evaluator: Function to evaluate response quality max_iterations: Maximum search iterations exploration_weight: Controls exploration vs exploitation use_learning: Whether to use learning from past tool selections tool_history_weight: Weight given to historical tool performance enable_plan_generation: Generate complete tool sequences as plans use_semantic_similarity: Use semantic similarity for tool relevance adaptation_rate: Rate at which the system adapts to new patterns """ self.tool_registry = tool_registry self.context_evaluator = context_evaluator self.use_learning = use_learning self.tool_history_weight = tool_history_weight self.enable_plan_generation = enable_plan_generation self.use_semantic_similarity = use_semantic_similarity self.adaptation_rate = adaptation_rate # Tool performance history by query type self.tool_history = {} # Tool sequence effectiveness records self.sequence_effectiveness = {} # Semantic fingerprints for tools and queries self.tool_fingerprints = {} self.query_clusters = {} # Cached simulation results for similar queries self.simulation_cache = {} # Initialize the MCTS algorithm self.mcts = AdvancedMCTS( state_evaluator=self._evaluate_state, action_generator=self._generate_actions, simulator=self._simulate_action, max_iterations=max_iterations, exploration_weight=exploration_weight, progressive_widening=True ) # Initialize tool fingerprints self._initialize_tool_fingerprints() def 
_initialize_tool_fingerprints(self): """Initialize semantic fingerprints for all available tools.""" if not self.use_semantic_similarity: return for tool_name in self.tool_registry.get_all_tool_names(): tool = self.tool_registry.get_tool(tool_name) if tool and hasattr(tool, 'description'): # In a real implementation, this would compute an embedding # For now, we'll use a simple keyword extraction as a placeholder keywords = set(word.lower() for word in tool.description.split() if len(word) > 3) self.tool_fingerprints[tool_name] = { 'keywords': keywords, 'description': tool.description, 'usage_contexts': set() } def select_tool(self, user_query: str, context: Dict[str, Any], visualizer=None) -> Union[str, List[str]]: """ Select the best tool to use for a given user query and context. Args: user_query: The user's query context: The current conversation context visualizer: Optional visualizer to show the selection process Returns: Either a single tool name or a sequence of tool names (if plan generation is enabled) """ # Analyze query to determine its type/characteristics query_type = self._analyze_query(user_query) # Update semantic fingerprints with this query if self.use_semantic_similarity: self._update_query_clusters(user_query, query_type) initial_state = { 'query': user_query, 'query_type': query_type, 'context': context, 'actions_taken': [], 'response_quality': 0.0, 'steps_remaining': 3 if self.enable_plan_generation else 1, 'step_results': {} } # First check if we have a high-confidence cached result for similar queries cached_result = self._check_cache(user_query, query_type) if cached_result and random.random() > 0.1: # 10% random exploration if visualizer: visualizer.add_execution( execution_id="mcts_cache_hit", tool_name="MCTS Tool Selection (cached)", parameters={"query": user_query[:100] + "..." 
if len(user_query) > 100 else user_query} ) visualizer.complete_execution( execution_id="mcts_cache_hit", result={"selected_tool": cached_result, "source": "cache"}, status="success" ) return cached_result # Run MCTS search best_action = self.mcts.search(initial_state, visualizer) # If plan generation is enabled, we might want to return a sequence if self.enable_plan_generation: # Extract the most promising action sequence from search plan = self._extract_plan_from_search() if plan and len(plan) > 1: # Store this plan in our cache self._cache_result(user_query, query_type, plan) return plan # Store single action in cache self._cache_result(user_query, query_type, best_action) return best_action def _analyze_query(self, query: str) -> str: """ Analyze a query to determine its type and characteristics. Args: query: The user query Returns: A string identifying the query type """ query_lower = query.lower() # Check for search-related queries if any(term in query_lower for term in ['find', 'search', 'where', 'look for']): return 'search' # Check for explanation queries if any(term in query_lower for term in ['explain', 'how', 'why', 'what is']): return 'explanation' # Check for file operation queries if any(term in query_lower for term in ['file', 'read', 'write', 'edit', 'create']): return 'file_operation' # Check for execution queries if any(term in query_lower for term in ['run', 'execute', 'start']): return 'execution' # Check for debugging queries if any(term in query_lower for term in ['debug', 'fix', 'error', 'problem']): return 'debugging' # Default to general return 'general' def _update_query_clusters(self, query: str, query_type: str): """ Update query clusters with new query information. 
Args: query: The user query query_type: The type of query """ # Extract query keywords keywords = set(word.lower() for word in query.split() if len(word) > 3) # Update query clusters if query_type not in self.query_clusters: self.query_clusters[query_type] = { 'keywords': set(), 'queries': [] } # Add keywords to cluster self.query_clusters[query_type]['keywords'].update(keywords) # Add query to cluster (limit to last 50) self.query_clusters[query_type]['queries'].append(query) if len(self.query_clusters[query_type]['queries']) > 50: self.query_clusters[query_type]['queries'].pop(0) # Update tool fingerprints with these keywords for tool_name, fingerprint in self.tool_fingerprints.items(): # If tool has been used successfully for this query type before if tool_name in self.tool_history.get(query_type, {}) and \ self.tool_history[query_type][tool_name]['success_rate'] > 0.6: fingerprint['usage_contexts'].add(query_type) def _check_cache(self, query: str, query_type: str) -> Union[str, List[str], None]: """ Check if we have a cached result for a similar query. Args: query: The user query query_type: The type of query Returns: A cached tool selection or None """ if not self.use_learning or query_type not in self.tool_history: return None # Find the most successful tool for this query type type_history = self.tool_history[query_type] best_tools = sorted( [(tool, data['success_rate']) for tool, data in type_history.items()], key=lambda x: x[1], reverse=True ) # Only use cache if we have a high confidence result if best_tools and best_tools[0][1] > 0.75: return best_tools[0][0] return None def _cache_result(self, query: str, query_type: str, action: Union[str, List[str]]): """ Cache a result for future similar queries. 
Args: query: The user query query_type: The type of query action: The selected action or plan """ # Store in simulation cache query_key = self._get_query_cache_key(query) self.simulation_cache[query_key] = { 'action': action, 'timestamp': self._get_timestamp(), 'query_type': query_type } # Limit cache size if len(self.simulation_cache) > 1000: # Remove oldest entries oldest_key = min(self.simulation_cache.keys(), key=lambda k: self.simulation_cache[k]['timestamp']) del self.simulation_cache[oldest_key] def _get_query_cache_key(self, query: str) -> str: """Generate a cache key for a query.""" # In a real implementation, this might use a hash of query embeddings # For now, use a simple keyword approach keywords = ' '.join(sorted(set(word.lower() for word in query.split() if len(word) > 3))) return keywords[:100] # Limit key length def _get_timestamp(self): """Get current timestamp.""" import time return time.time() def _evaluate_state(self, state: Dict[str, Any]) -> float: """ Evaluate the quality of a state based on response quality and steps. Args: state: The current state Returns: A quality score """ # Base score is the response quality score = state['response_quality'] # If plan generation is enabled, we want to encourage complete plans if self.enable_plan_generation: steps_completed = len(state['actions_taken']) total_steps = steps_completed + state['steps_remaining'] # Add bonus for completing more steps if total_steps > 0: step_completion_bonus = steps_completed / total_steps score += step_completion_bonus * 0.2 # 20% bonus for step completion return score def _generate_actions(self, state: Dict[str, Any]) -> List[str]: """ Generate possible tool actions from the current state with intelligent filtering. 
Args: state: The current state Returns: List of possible actions """ # Get query type query_type = state['query_type'] query = state['query'] # Get all available tools all_tools = set(self.tool_registry.get_all_tool_names()) # Tools already used in this sequence used_tools = set(state['actions_taken']) # Remaining tools remaining_tools = all_tools - used_tools # If we're using learning, prioritize tools based on history if self.use_learning and query_type in self.tool_history: prioritized_tools = [] # First, add tools that have been successful for this query type type_history = self.tool_history[query_type] # Check for successful tools for tool in remaining_tools: if tool in type_history and type_history[tool]['success_rate'] > 0.5: prioritized_tools.append(tool) # If we have at least some tools, return them if prioritized_tools and random.random() < self.tool_history_weight: return prioritized_tools # If using semantic similarity, filter by relevant tools if self.use_semantic_similarity: query_keywords = set(word.lower() for word in query.split() if len(word) > 3) # Score tools by semantic similarity to query scored_tools = [] for tool in remaining_tools: if tool in self.tool_fingerprints: fingerprint = self.tool_fingerprints[tool] # Calculate keyword overlap keyword_overlap = len(query_keywords.intersection(fingerprint['keywords'])) # Check if tool has been used for this query type context_match = 1.0 if query_type in fingerprint['usage_contexts'] else 0.0 # Combined score score = keyword_overlap * 0.7 + context_match * 0.3 scored_tools.append((tool, score)) # Sort and filter tools scored_tools.sort(key=lambda x: x[1], reverse=True) # Take top half of tools if we have enough if len(scored_tools) > 2: return [t[0] for t in scored_tools[:max(2, len(scored_tools) // 2)]] # If we reach here, use all remaining tools return list(remaining_tools) def _simulate_action(self, state: Dict[str, Any], action: str) -> Dict[str, Any]: """ Simulate taking an action (using a 
tool) in the given state with enhanced modeling. Args: state: The current state action: The tool action to simulate Returns: The new state after taking the action """ # Create a new state with the action added new_state = state.copy() new_actions = state['actions_taken'].copy() new_actions.append(action) new_state['actions_taken'] = new_actions # Decrement steps remaining if using plan generation if self.enable_plan_generation and new_state['steps_remaining'] > 0: new_state['steps_remaining'] -= 1 # Get query type and query query_type = state['query_type'] query = state['query'] # Simulate step result step_results = state['step_results'].copy() step_results[action] = self._simulate_tool_result(action, query) new_state['step_results'] = step_results # Estimate tool relevance based on learning or semantic similarity tool_relevance = self._estimate_tool_relevance(action, query, query_type) # Check for sequence effects (tools that work well together) sequence_bonus = 0.0 if len(new_actions) > 1: prev_tool = new_actions[-2] sequence_key = f"{prev_tool}->{action}" if sequence_key in self.sequence_effectiveness: sequence_bonus = self.sequence_effectiveness[sequence_key] * 0.3 # 30% weight for sequence effects # Update quality based on relevance and sequence effects current_quality = state['response_quality'] quality_improvement = tool_relevance + sequence_bonus # Add diminishing returns effect for additional tools if len(new_actions) > 1: diminishing_factor = 1.0 / len(new_actions) quality_improvement *= diminishing_factor new_quality = min(1.0, current_quality + quality_improvement) new_state['response_quality'] = new_quality return new_state def _simulate_tool_result(self, tool_name: str, query: str) -> Dict[str, Any]: """ Simulate the result of using a tool for a query. 
Args: tool_name: The name of the tool query: The user query Returns: A simulated result """ # In a real implementation, this would be a more sophisticated simulation return { "tool": tool_name, "success_probability": self._estimate_tool_relevance(tool_name, query), "simulated": True } def _estimate_tool_relevance(self, tool_name: str, query: str, query_type: str = None) -> float: """ Estimate how relevant a tool is for a given query using history and semantics. Args: tool_name: The name of the tool query: The user query query_type: Optional query type Returns: A relevance score between 0.0 and 1.0 """ relevance_score = 0.0 # If we have historical data for this query type if self.use_learning and query_type and query_type in self.tool_history and \ tool_name in self.tool_history[query_type]: # Get historical success rate history_score = self.tool_history[query_type][tool_name]['success_rate'] relevance_score += history_score * self.tool_history_weight # If we're using semantic similarity if self.use_semantic_similarity and tool_name in self.tool_fingerprints: fingerprint = self.tool_fingerprints[tool_name] # Calculate keyword overlap query_keywords = set(word.lower() for word in query.split() if len(word) > 3) keyword_overlap = len(query_keywords.intersection(fingerprint['keywords'])) # Normalize by query keywords if query_keywords: semantic_score = keyword_overlap / len(query_keywords) relevance_score += semantic_score * (1.0 - self.tool_history_weight) # Ensure we have a minimum score for exploration if relevance_score < 0.1: relevance_score = 0.1 + (random.random() * 0.1) # Random boost between 0.1-0.2 return relevance_score def _extract_plan_from_search(self) -> List[str]: """ Extract a complete plan (tool sequence) from the search results. Returns: A list of tool names representing the plan """ # In a real implementation, this would extract the highest value path # from the search tree. For now, return None to indicate no plan extraction. 
return None def update_tool_history(self, tool_name: str, query: str, success: bool, execution_time: float, result: Any = None): """ Update the tool history with the results of using a tool. Args: tool_name: The name of the tool used query: The query the tool was used for success: Whether the tool was successful execution_time: The execution time in seconds result: Optional result of the tool execution """ if not self.use_learning: return # Get query type query_type = self._analyze_query(query) # Initialize history entry if needed if query_type not in self.tool_history: self.tool_history[query_type] = {} if tool_name not in self.tool_history[query_type]: self.tool_history[query_type][tool_name] = { 'success_count': 0, 'failure_count': 0, 'total_time': 0.0, 'success_rate': 0.0, 'avg_time': 0.0, 'examples': [] } # Update history history = self.tool_history[query_type][tool_name] # Update counts if success: history['success_count'] += 1 else: history['failure_count'] += 1 # Update time history['total_time'] += execution_time # Update success rate total = history['success_count'] + history['failure_count'] history['success_rate'] = history['success_count'] / total if total > 0 else 0.0 # Update average time history['avg_time'] = history['total_time'] / total if total > 0 else 0.0 # Add example (limit to last 5) history['examples'].append({ 'query': query, 'success': success, 'timestamp': self._get_timestamp() }) if len(history['examples']) > 5: history['examples'].pop(0) # Update tool fingerprint if self.use_semantic_similarity and tool_name in self.tool_fingerprints: if success: # Add query type to usage contexts self.tool_fingerprints[tool_name]['usage_contexts'].add(query_type) # Add query keywords to tool fingerprint (with decay) query_keywords = set(word.lower() for word in query.split() if len(word) > 3) current_keywords = self.tool_fingerprints[tool_name]['keywords'] # Add new keywords with adaptation rate for keyword in query_keywords: if keyword not in 
current_keywords: if random.random() < self.adaptation_rate: current_keywords.add(keyword) def update_sequence_effectiveness(self, tool_sequence: List[str], success: bool, quality_score: float): """ Update the effectiveness record for a sequence of tools. Args: tool_sequence: The sequence of tools used success: Whether the sequence was successful quality_score: A quality score for the sequence """ if not self.use_learning or len(tool_sequence) < 2: return # Update pairwise effectiveness for i in range(len(tool_sequence) - 1): first_tool = tool_sequence[i] second_tool = tool_sequence[i + 1] sequence_key = f"{first_tool}->{second_tool}" if sequence_key not in self.sequence_effectiveness: self.sequence_effectiveness[sequence_key] = 0.5 # Initial neutral score # Update score with decay current_score = self.sequence_effectiveness[sequence_key] if success: # Increase score with quality bonus new_score = current_score + self.adaptation_rate * quality_score else: # Decrease score new_score = current_score - self.adaptation_rate # Clamp between 0 and 1 self.sequence_effectiveness[sequence_key] = max(0.0, min(1.0, new_score))