"""
TreeSearcher - Hypothesis Tree Search for semantic code discovery.
KEY IDEA: Generate MULTIPLE hypotheses at each step, creating a search tree.
Then explore the most promising branches and prune the rest.
Example for "context data corrupted in goroutine":
                   Query
                     │
      ┌──────────────┼──────────────┬──────────────┐
      ▼              ▼              ▼              ▼
 Hypothesis 1   Hypothesis 2   Hypothesis 3   Hypothesis 4
 "Object Pool"   "Race Cond"   "Ctx Lifetime" "Shallow Copy"
      │              │              │              │
  ┌───┼───┐      ┌───┼───┐      ┌───┼───┐      ┌───┼───┐
  ▼   ▼   ▼      ▼   ▼   ▼      ▼   ▼   ▼      ▼   ▼   ▼
 patterns...    patterns...    patterns...    patterns...
      │              │              │              │
      └──────────────┴──────────────┴──────────────┘
                     │
                RANK & MERGE
                     │
               Top K results
This is like Beam Search but for code exploration.
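A minimal usage sketch (the import path is hypothetical and depends on your
package layout; requires CLAUDE_API_KEY/ANTHROPIC_API_KEY and `rg` on PATH):

    from mypackage.tree_searcher import TreeSearcher  # hypothetical import path

    searcher = TreeSearcher(num_hypotheses=4, beam_width=3, verbose=True)
    result = searcher.search(
        "context data corrupted in goroutine",
        "/path/to/repo",
    )
    for item in result.items:
        print(item.file_path, item.line_start, item.line_end)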
"""
from __future__ import annotations

# Load environment variables (API keys) before anything else reads them.
from dotenv import load_dotenv

load_dotenv()

import json
import logging
import os
import subprocess
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field

from .base import BaseSearcher, SearchItem, SearchResult
logger = logging.getLogger(__name__)
# =============================================================================
# Hypothesis Generation
# =============================================================================
class Hypothesis(BaseModel):
"""A single hypothesis about what might cause the problem."""
name: str = Field(description="Short name for this hypothesis")
explanation: str = Field(description="Why this could be the cause")
confidence: float = Field(default=0.5, description="0.0-1.0, how likely this hypothesis is")
search_patterns: List[str] = Field(default_factory=list, description="Specific patterns to search for")
expected_file_hints: List[str] = Field(default_factory=list, description="File/dir names that would contain this")
class HypothesisSet(BaseModel):
"""Multiple hypotheses generated from a problem."""
hypotheses: List[Hypothesis] = Field(description="3-5 different hypotheses")
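# Illustrative structured output we expect from the LLM (all field values
# here are hypothetical; real content depends on the model's response):
#
#   Hypothesis(
#       name="Object Pooling",
#       explanation="Pooled objects may be reused without a full reset",
#       confidence=0.7,
#       search_patterns=["sync.Pool", "pool.Get", "reset("],
#       expected_file_hints=["pool", "context"],
#   )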
HYPOTHESIS_GENERATION_PROMPT = """You are analyzing a developer's problem to generate MULTIPLE HYPOTHESES.
PROBLEM:
"{query}"
Generate 3-5 DIFFERENT hypotheses about what could cause this problem.
Each hypothesis should explore a DIFFERENT root cause.
For each hypothesis:
1. Give it a short descriptive name
2. Explain WHY this could be the cause
3. Rate your confidence (0.0-1.0)
4. List 5-10 specific search patterns to find code related to this hypothesis
5. List file/directory name hints where this code might be
IMPORTANT - Generate DIVERSE hypotheses:
- Don't just vary the wording, vary the UNDERLYING CAUSE
- Consider: architecture issues, lifecycle issues, concurrency, configuration, etc.
- Think about different LAYERS: framework, application, infrastructure
Example hypotheses for "data corruption in concurrent access":
1. Object Pooling - objects are reused without proper reset
2. Race Condition - no synchronization on shared data
3. Reference Sharing - shallow copy instead of deep copy
4. Lifecycle Mismatch - object outlives its valid scope
5. Cache Invalidation - stale data served from cache
For SEARCH PATTERNS - IMPORTANT RULES:
1. Use PLAIN TEXT only - NO regex metacharacters or escapes like *, \(, \), etc.
2. Patterns are searched with literal string matching
3. Include SHORT patterns (5-15 chars) that appear in code
4. Include both specific terms and generic ones
GOOD patterns: "sync.Pool", "pool.Get", "context.Copy", "ServeHTTP", "reset()"
BAD patterns: "func.*context", "http\.Handler", "go func\("
Mix of pattern types:
- Specific: "sync.Pool", "pool.Get()", "context.Value"
- Generic: "pool", "reset", "copy", "handler"
- Framework: "Context", "Handler", "Middleware" """
# =============================================================================
# Tree Node
# =============================================================================
@dataclass
class TreeNode:
"""A node in the hypothesis search tree."""
hypothesis: Optional[Hypothesis]
parent: Optional['TreeNode']
children: List['TreeNode'] = field(default_factory=list)
snippets: List[Dict] = field(default_factory=list)
score: float = 0.0
depth: int = 0
explored: bool = False
@property
def path(self) -> str:
"""Get the path from root to this node."""
if self.parent is None:
return "root"
parent_path = self.parent.path
name = self.hypothesis.name if self.hypothesis else "?"
return f"{parent_path} → {name}"
# =============================================================================
# Main Tree Searcher
# =============================================================================
class TreeSearcher(BaseSearcher):
"""
Hypothesis Tree Search for semantic code discovery.
Algorithm:
1. Generate N hypotheses from the problem
2. For each hypothesis, search for patterns (parallel)
3. Score each hypothesis branch by results quality
4. Optionally: expand best branches with sub-hypotheses
5. Merge and rank all results
"""
def __init__(
self,
model: str = "claude-sonnet-4-20250514",
num_hypotheses: int = 4,
max_depth: int = 2,
beam_width: int = 3, # Keep top N branches at each level
patterns_per_hypothesis: int = 8,
max_results: int = 10,
context_lines: int = 40,
parallel_searches: int = 4,
verbose: bool = False,
):
self.model = model
self.num_hypotheses = num_hypotheses
self.max_depth = max_depth
self.beam_width = beam_width
self.patterns_per_hypothesis = patterns_per_hypothesis
self.max_results = max_results
self.context_lines = context_lines
self.parallel_searches = parallel_searches
self.verbose = verbose
api_key = os.getenv("CLAUDE_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
if not api_key:
raise ValueError("CLAUDE_API_KEY or ANTHROPIC_API_KEY must be set")
self.llm = ChatAnthropic(
model=model,
api_key=api_key,
max_tokens=4096,
)
        # Patterns already searched; shared across worker threads, so guard
        # check-and-add with a lock to avoid duplicate concurrent searches.
        self._searched_patterns: Set[str] = set()
        self._patterns_lock = threading.Lock()
@property
def name(self) -> str:
return f"TreeSearcher ({self.model})"
def _log(self, msg: str, indent: int = 0) -> None:
if self.verbose:
prefix = " " * indent
print(f"[Tree] {prefix}{msg}")
logger.debug(msg)
# =========================================================================
# Hypothesis Generation
# =========================================================================
def _generate_hypotheses(
self,
query: str,
context: Optional[str] = None,
) -> List[Hypothesis]:
"""Generate multiple hypotheses for the problem."""
prompt_text = HYPOTHESIS_GENERATION_PROMPT
if context:
prompt_text += f"\n\nCONTEXT FROM PREVIOUS SEARCH:\n{context}"
prompt = ChatPromptTemplate.from_messages([
("system", prompt_text),
("human", "Generate diverse hypotheses for this problem."),
])
chain = prompt | self.llm.with_structured_output(HypothesisSet)
result = chain.invoke({"query": query})
# Sort by confidence
hypotheses = sorted(result.hypotheses, key=lambda h: -h.confidence)
self._log(f"Generated {len(hypotheses)} hypotheses:")
for h in hypotheses:
self._log(f" [{h.confidence:.1f}] {h.name}: {h.explanation[:50]}...", 1)
return hypotheses[:self.num_hypotheses]
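    # E.g. a 5-hypothesis response with confidences [0.8, 0.7, 0.5, 0.4, 0.2]
    # (hypothetical values) is truncated to the top num_hypotheses=4 entries.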
# =========================================================================
# Search Execution
# =========================================================================
def _search_pattern(
self,
pattern: str,
repo_path: str,
) -> List[Dict]:
"""Search for a single pattern."""
if pattern in self._searched_patterns:
return []
if len(pattern) < 4:
return []
self._searched_patterns.add(pattern)
        cmd = [
            "rg", "-F", "-i", "-n",
            "-C", str(self.context_lines),
            "-m", "20",  # --max-count: cap matching lines per file
            "--json",
            pattern,
            repo_path,
        ]
try:
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
if result.returncode not in (0, 1):
return []
return self._parse_ripgrep(result.stdout, repo_path, pattern)
except (subprocess.TimeoutExpired, FileNotFoundError):
return []
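    # The subprocess above is equivalent to running (pattern and path are
    # illustrative):
    #   rg -F -i -n -C 40 -m 20 --json "sync.Pool" /path/to/repo
    # -F = fixed strings (no regex), -i = case-insensitive, -n = line numbers.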
def _parse_ripgrep(
self,
output: str,
repo_path: str,
pattern: str,
) -> List[Dict]:
"""Parse ripgrep JSON output."""
file_lines: Dict[str, List[Tuple[int, str]]] = {}
for line in output.split('\n'):
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError:
continue
if data.get('type') not in ('match', 'context'):
continue
msg = data.get('data', {})
file_path = msg.get('path', {}).get('text', '')
line_text = msg.get('lines', {}).get('text', '').rstrip('\n')
line_num = msg.get('line_number', 0)
if not file_path or not line_num:
continue
            # Skip tests, vendored deps, and generated files
            fp_lower = file_path.lower()
            if '_test.' in fp_lower or '/test/' in fp_lower:
                continue
            if '/vendor/' in fp_lower or '/node_modules/' in fp_lower:
                continue
try:
rel_path = str(Path(file_path).relative_to(repo_path))
except ValueError:
rel_path = file_path
if rel_path not in file_lines:
file_lines[rel_path] = []
file_lines[rel_path].append((line_num, line_text))
# Convert to snippets
snippets = []
for file_path, lines in file_lines.items():
lines.sort(key=lambda x: x[0])
# Group consecutive lines
groups = []
current = []
for line_num, content in lines:
if not current:
current = [(line_num, content)]
elif line_num <= current[-1][0] + 3:
current.append((line_num, content))
else:
groups.append(current)
current = [(line_num, content)]
if current:
groups.append(current)
for group in groups[:3]: # Max 3 snippets per file per pattern
content = '\n'.join(l for _, l in group)
snippets.append({
"file_path": file_path,
"content": content,
"line_start": group[0][0],
"line_end": group[-1][0],
"pattern": pattern,
})
return snippets
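    # Shape of each snippet dict produced above (values are illustrative):
    #   {
    #       "file_path": "internal/pool/context.go",  # relative to repo root
    #       "content": "func (p *Pool) Get() *Context { ... }",
    #       "line_start": 41,
    #       "line_end": 58,
    #       "pattern": "pool.Get",
    #   }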
def _explore_hypothesis(
self,
hypothesis: Hypothesis,
repo_path: str,
) -> Tuple[List[Dict], float]:
"""
Explore a single hypothesis by searching its patterns.
Returns (snippets, score).
"""
all_snippets = []
patterns = hypothesis.search_patterns[:self.patterns_per_hypothesis]
self._log(f"Patterns for '{hypothesis.name}':", 2)
for p in patterns[:5]:
self._log(f" - {p}", 3)
# Search patterns in parallel
with ThreadPoolExecutor(max_workers=self.parallel_searches) as executor:
futures = {
executor.submit(self._search_pattern, p, repo_path): p
for p in patterns
}
for future in as_completed(futures):
pattern = futures[future]
try:
snippets = future.result()
if snippets:
self._log(f"'{pattern}': {len(snippets)} snippets", 2)
all_snippets.extend(snippets)
except Exception as e:
self._log(f"Error searching '{pattern}': {e}", 2)
# Calculate score for this hypothesis branch
score = self._score_hypothesis(hypothesis, all_snippets)
return all_snippets, score
def _score_hypothesis(
self,
hypothesis: Hypothesis,
snippets: List[Dict],
) -> float:
"""
Score how well the search results support this hypothesis.
Higher score = more relevant results found.
"""
if not snippets:
return 0.0
score = 0.0
# Base score from number of unique files
unique_files = set(s["file_path"] for s in snippets)
score += min(len(unique_files) * 2, 10)
# Bonus for matching file hints
for s in snippets:
for hint in hypothesis.expected_file_hints:
if hint.lower() in s["file_path"].lower():
score += 1
# Bonus for finding multiple patterns
unique_patterns = set(s.get("pattern", "") for s in snippets)
score += len(unique_patterns) * 0.5
# Weighted by hypothesis confidence
score *= hypothesis.confidence
# Penalty for too many results (might be too generic)
if len(snippets) > 50:
score *= 0.7
return score
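    # Worked example (hypothetical numbers): 3 unique files -> +6; 2 file-hint
    # matches -> +2; 4 unique patterns -> +2.0; confidence 0.7 then scales the
    # total: (6 + 2 + 2.0) * 0.7 = 7.0. With more than 50 snippets, the
    # too-generic penalty would drop it to 7.0 * 0.7 = 4.9.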
# =========================================================================
# Tree Search
# =========================================================================
def _build_and_explore_tree(
self,
query: str,
repo_path: str,
) -> TreeNode:
"""
Build the hypothesis tree and explore branches.
"""
# Create root node
root = TreeNode(hypothesis=None, parent=None, depth=0)
# Generate initial hypotheses
self._log("=== Generating hypotheses ===")
hypotheses = self._generate_hypotheses(query)
# Create child nodes for each hypothesis
for h in hypotheses:
child = TreeNode(hypothesis=h, parent=root, depth=1)
root.children.append(child)
# Explore tree level by level (BFS with beam search)
current_level = root.children
for depth in range(1, self.max_depth + 1):
self._log(f"\n=== Exploring depth {depth} ({len(current_level)} nodes) ===")
# Explore all nodes at current level
for node in current_level:
if node.explored:
continue
self._log(f"Exploring: {node.hypothesis.name}", 1)
snippets, score = self._explore_hypothesis(node.hypothesis, repo_path)
node.snippets = snippets
node.score = score
node.explored = True
self._log(f"Score: {score:.1f}, Snippets: {len(snippets)}", 1)
# Rank nodes and keep top beam_width
current_level.sort(key=lambda n: -n.score)
best_nodes = current_level[:self.beam_width]
self._log(f"\nTop {self.beam_width} branches:")
for node in best_nodes:
self._log(f" [{node.score:.1f}] {node.hypothesis.name}", 1)
# Generate sub-hypotheses for next level (if not at max depth)
if depth < self.max_depth:
next_level = []
for node in best_nodes:
if node.score > 0 and node.snippets:
# Generate refined hypotheses based on what we found
context = self._summarize_findings(node)
sub_hypotheses = self._generate_hypotheses(
query,
context=context
)
for h in sub_hypotheses[:2]: # Max 2 sub-hypotheses
child = TreeNode(
hypothesis=h,
parent=node,
depth=depth + 1
)
node.children.append(child)
next_level.append(child)
current_level = next_level
if not current_level:
self._log("No more branches to explore")
break
return root
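    # Branching math with the defaults (num_hypotheses=4, beam_width=3,
    # max_depth=2): level 1 explores 4 hypothesis nodes and keeps the top 3;
    # each survivor spawns at most 2 sub-hypotheses, so level 2 explores at
    # most 6 nodes before the final rank-and-merge.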
def _summarize_findings(self, node: TreeNode) -> str:
"""Summarize what was found in a node for generating sub-hypotheses."""
if not node.snippets:
return ""
files = list(set(s["file_path"] for s in node.snippets))[:5]
patterns = list(set(s.get("pattern", "") for s in node.snippets))[:5]
        # Escape braces in ALL interpolated text (not just the code sample);
        # an unescaped "{" in a file name or pattern would break the LangChain
        # prompt template this summary is appended to.
        def esc(text: str) -> str:
            return text.replace("{", "{{").replace("}", "}}")

        sample = esc(node.snippets[0]["content"][:500])
        return f"""
Previous hypothesis: {esc(node.hypothesis.name)}
Files found: {esc(', '.join(files))}
Patterns matched: {esc(', '.join(patterns))}
Sample code:
{sample}
"""
# =========================================================================
# Result Collection
# =========================================================================
def _collect_all_snippets(self, root: TreeNode) -> List[Dict]:
"""Collect all snippets from the tree with their path scores."""
all_snippets = []
def traverse(node: TreeNode, path_score: float):
# Add snippets from this node
for s in node.snippets:
s["tree_path"] = node.path
s["path_score"] = path_score + node.score
all_snippets.append(s)
# Traverse children
for child in node.children:
traverse(child, path_score + node.score)
traverse(root, 0)
return all_snippets
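    # path_score accumulates ancestor scores: a snippet in a depth-2 node whose
    # parent scored 8.0 and which itself scored 5.0 gets path_score 13.0
    # (hypothetical numbers), so snippets from strong branches rank higher.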
def _merge_and_rank(
self,
snippets: List[Dict],
) -> List[Dict]:
"""Merge overlapping snippets and rank by relevance."""
if not snippets:
return []
# Group by file
by_file: Dict[str, List[Dict]] = {}
for s in snippets:
fp = s["file_path"]
if fp not in by_file:
by_file[fp] = []
by_file[fp].append(s)
merged = []
for file_path, file_snippets in by_file.items():
# Sort by line number
file_snippets.sort(key=lambda x: x["line_start"])
# Merge overlapping
current = file_snippets[0].copy()
current["patterns"] = [current.get("pattern", "")]
current["path_scores"] = [current.get("path_score", 0)]
for next_s in file_snippets[1:]:
if next_s["line_start"] <= current["line_end"] + 20:
# Merge
current["content"] += "\n...\n" + next_s["content"]
current["line_end"] = max(current["line_end"], next_s["line_end"])
current["patterns"].append(next_s.get("pattern", ""))
current["path_scores"].append(next_s.get("path_score", 0))
else:
merged.append(current)
current = next_s.copy()
current["patterns"] = [current.get("pattern", "")]
current["path_scores"] = [current.get("path_score", 0)]
merged.append(current)
# Score merged snippets
def score(s: Dict) -> float:
sc = 0.0
# Path score from tree exploration
sc += max(s.get("path_scores", [0])) * 2
# Number of patterns matched
sc += len(set(s.get("patterns", []))) * 3
# File type preference - STRONG preference for source code
fp = s["file_path"].lower()
if fp.endswith(('.go', '.py', '.ts', '.js', '.cpp', '.java', '.rs')):
sc += 10 # Strong bonus for source code
elif fp.endswith(('.qll', '.ql')): # CodeQL
sc += 10
# Prefer core directories
if any(d in fp for d in ['/core/', '/src/', '/lib/', '/internal/', '/pkg/']):
sc += 5
# Handler/middleware specific
if any(d in fp for d in ['/handler', '/middleware', '/util']):
sc += 3
# STRONG penalty for docs/releases
if '/doc' in fp or '/release' in fp:
sc -= 20
if fp.endswith(('.md', '.txt', '.rst')):
sc -= 15
# Penalize test files
if '_test.' in fp or '/test/' in fp:
sc -= 10
return sc
merged.sort(key=score, reverse=True)
return merged[:self.max_results]
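    # Merge example (hypothetical line numbers): two snippets in the same file
    # at lines 10-30 and 45-60 merge into one 10-60 snippet, because the gap
    # (45 <= 30 + 20) is within the 20-line threshold; a snippet at line 120
    # would start a new merged block instead.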
# =========================================================================
# Main Search
# =========================================================================
def search(
self,
query: str,
repo_path: str,
path: Optional[str] = None,
) -> SearchResult:
"""Perform hypothesis tree search."""
start_time = time.time()
self._searched_patterns.clear()
try:
repo_path = os.path.abspath(repo_path)
if path:
repo_path = os.path.join(repo_path, path)
self._log(f"Starting tree search in: {repo_path}")
self._log(f"Query: {query[:100]}...")
# Build and explore hypothesis tree
root = self._build_and_explore_tree(query, repo_path)
# Collect all snippets from tree
self._log("\n=== Collecting results ===")
all_snippets = self._collect_all_snippets(root)
self._log(f"Total snippets from tree: {len(all_snippets)}")
# Merge and rank
final = self._merge_and_rank(all_snippets)
self._log(f"Final results: {len(final)}")
for s in final[:5]:
self._log(f" {s['file_path']}:{s['line_start']}-{s['line_end']}", 1)
# Convert to SearchResult
items = [
SearchItem(
file_path=s["file_path"],
content=s["content"],
line_start=s["line_start"],
line_end=s["line_end"],
match_context=f"Patterns: {', '.join(list(set(s.get('patterns', [])))[:3])}",
)
for s in final
]
total_time = (time.time() - start_time) * 1000
return SearchResult(
items=items,
patterns_used=list(self._searched_patterns)[:20],
execution_time_ms=total_time,
total_time_ms=total_time,
)
except Exception as e:
total_time = (time.time() - start_time) * 1000
self._log(f"Error: {e}")
import traceback
self._log(traceback.format_exc())
return SearchResult(
items=[],
execution_time_ms=total_time,
total_time_ms=total_time,
error=str(e),
)
class TreeSearcherVerbose(TreeSearcher):
"""TreeSearcher with verbose logging enabled."""
def __init__(self, **kwargs):
kwargs["verbose"] = True
super().__init__(**kwargs)
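if __name__ == "__main__":
    # Minimal smoke test (the repo path is hypothetical; requires CLAUDE_API_KEY
    # or ANTHROPIC_API_KEY to be set and the `rg` binary on PATH):
    searcher = TreeSearcherVerbose()
    res = searcher.search(
        "context data corrupted when a request handler spawns a goroutine",
        "/path/to/repo",
    )
    for item in res.items:
        print(f"{item.file_path}:{item.line_start}-{item.line_end}")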