"""
ReasoningSearcher - A multi-stage semantic searcher with Chain-of-Thought reasoning.
Key improvements over existing searchers:
1. REASONING STAGE: LLM analyzes the problem to understand root causes
2. CONCEPT EXTRACTION: Derives technical concepts from symptoms
3. PATTERN EXPANSION: Generates multiple pattern variations
4. CONTEXT EXPANSION: Returns full functions/classes, not fragments
5. MULTI-HOP SEARCH: Iteratively refines based on findings
This approach addresses the "semantic gap" problem where users describe
symptoms but code contains solutions in different terminology.
"""
from __future__ import annotations
from dotenv import load_dotenv
load_dotenv()
import json
import logging
import os
import re
import subprocess
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from .base import BaseSearcher, SearchItem, SearchResult
logger = logging.getLogger(__name__)
# =============================================================================
# Stage 1: Reasoning & Concept Extraction
# =============================================================================
class TechnicalConcept(BaseModel):
"""A technical concept derived from the problem description."""
concept: str = Field(description="The technical concept or pattern name")
explanation: str = Field(description="Why this concept is relevant to the problem")
search_terms: List[str] = Field(
description="Specific code patterns/names to search for this concept"
)
class ProblemAnalysis(BaseModel):
"""Result of analyzing the user's problem."""
problem_summary: str = Field(
description="One sentence summary of what the user is experiencing"
)
likely_root_causes: List[str] = Field(
description="Technical root causes that could explain this behavior"
)
technical_concepts: List[TechnicalConcept] = Field(
description="Technical concepts and patterns to search for"
)
file_name_hints: List[str] = Field(
description="Likely file names or directory names (e.g., 'handler', 'context', 'pool')"
)
code_structure_hints: List[str] = Field(
description="Expected code structures (e.g., 'class with reset method', 'function that returns copy')"
)
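# Illustrative example of the structured analysis (hypothetical values, shown
# only to document the expected shape of the LLM output):
#
#     ProblemAnalysis(
#         problem_summary="Response data occasionally belongs to another request",
#         likely_root_causes=["Pooled context objects reused without a reset"],
#         technical_concepts=[TechnicalConcept(
#             concept="Object Pooling",
#             explanation="Pooled contexts leak state between requests if not reset",
#             search_terms=["sync.Pool", "pool.Get()", "c.reset()"],
#         )],
#         file_name_hints=["context", "pool"],
#         code_structure_hints=["struct with a reset method"],
#     )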
REASONING_PROMPT = """You are an expert software engineer analyzing a developer's problem.
TASK: Analyze this problem and identify what technical concepts/patterns in the codebase
would explain the behavior described.
PROBLEM:
"{query}"
Think step by step:
1. What SYMPTOMS is the user describing?
2. What TECHNICAL ROOT CAUSES could explain these symptoms?
3. What CODE PATTERNS/CONCEPTS would implement those causes?
4. What SPECIFIC CODE TERMS would appear in the source code?
IMPORTANT:
- Think about the IMPLEMENTATION, not just the surface-level keywords
- Consider common patterns: pooling, caching, middleware chains, lazy evaluation
- The user's words (like "corrupted", "missing") won't appear in code
- Code uses technical terms: sync.Pool, reset(), Copy(), _middleware_chain
For each concept, provide VERY SPECIFIC search terms like:
- "func (c *Context) Copy" (exact method signature)
- "pool.Get()" (exact method call)
- "func.*ServeHTTP" (function definition pattern)
- "type Engine struct" (type definition)
Return a structured analysis with:
- Technical concepts that could cause this behavior
- SPECIFIC code patterns to search for (method names, function signatures)
- File/directory name hints"""
# =============================================================================
# Stage 2: Pattern Generation from Concepts
# =============================================================================
class SearchPattern(BaseModel):
"""A search pattern to execute."""
pattern: str = Field(description="Plain text pattern for ripgrep -F -i")
priority: int = Field(description="1-10, higher = more likely to find answer")
concept: str = Field(description="Which concept this pattern searches for")
class PatternSet(BaseModel):
"""Set of patterns generated from concepts."""
patterns: List[SearchPattern] = Field(description="Patterns ordered by priority")
PATTERN_GENERATION_PROMPT = """Based on this problem analysis, generate search patterns.
ANALYSIS:
{analysis}
REPOSITORY CONTEXT:
{repo_context}
Generate 20-30 search patterns that will find code related to these concepts.
RULES:
1. Use PLAIN TEXT only (no regex); patterns are run with `rg -F -i`
2. Include SPECIFIC patterns for:
- Method signatures: "func (c *Context) Copy"
- Method calls: "pool.Get()", "pool.Put("
- Type definitions: "type Context struct"
- Important function names: "ServeHTTP", "reset()"
3. Include BOTH:
- Implementation patterns (where the behavior IS)
- Connection patterns (what CALLS the implementation)
4. For object pooling issues, ALWAYS include:
- "pool.Get"
- "pool.Put"
- "sync.Pool"
- "func.*reset"
- "func.*Copy"
5. For middleware/request lifecycle:
- "ServeHTTP"
- "HandleRequest"
- "middleware"
PRIORITY ORDER:
1. Exact method/function signatures (priority 10)
2. Method calls with parentheses (priority 9)
3. Type/class definitions (priority 8)
4. General concept keywords (priority 5-7)"""
# =============================================================================
# Stage 3: Context Expansion
# =============================================================================
def find_function_boundaries(
content: str,
match_line: int,
language: str = "auto",
) -> Tuple[int, int]:
"""
Find the start and end of the function/method containing a match.
Returns (start_line, end_line) 1-indexed.
"""
lines = content.split('\n')
# Detect language from content if auto
if language == "auto":
if "func " in content or "package " in content:
language = "go"
elif "def " in content or "class " in content:
language = "python"
elif "function " in content or "const " in content or "=>" in content:
language = "js"
else:
language = "generic"
# Find function start (go backwards)
start_line = match_line
    found_start = False
    for i in range(match_line - 1, -1, -1):
        line = lines[i] if i < len(lines) else ""
# Check for function definition
if language == "go":
if re.match(r'^func\s', line) or re.match(r'^func\s*\(', line):
start_line = i + 1
found_start = True
break
elif language == "python":
if re.match(r'^(async\s+)?def\s+\w+|^class\s+\w+', line):
start_line = i + 1
found_start = True
break
elif language in ("js", "ts"):
if re.match(r'^(async\s+)?function\s+\w+|^\w+\s*[=:]\s*(async\s+)?\(', line):
start_line = i + 1
found_start = True
break
else:
# Generic: look for function-like patterns
if re.match(r'^(func|def|function|public|private|protected)\s', line):
start_line = i + 1
found_start = True
break
if not found_start:
# Fallback: expand 30 lines up
start_line = max(1, match_line - 30)
# Find function end (go forwards)
end_line = match_line
brace_count = 0
in_function = False
for i in range(start_line - 1, len(lines)):
line = lines[i]
if '{' in line:
in_function = True
brace_count += line.count('{') - line.count('}')
if in_function and brace_count <= 0:
end_line = i + 1
break
else:
# Fallback: expand 50 lines down
end_line = min(len(lines), match_line + 50)
# For Python (indentation-based)
if language == "python" and not found_start:
# Find by indentation
match_indent = len(lines[match_line - 1]) - len(lines[match_line - 1].lstrip())
for i in range(match_line - 1, -1, -1):
line = lines[i]
if line.strip() and not line.strip().startswith('#'):
indent = len(line) - len(line.lstrip())
if indent < match_indent and re.match(r'\s*(def|class|async def)\s', line):
start_line = i + 1
break
for i in range(match_line, len(lines)):
line = lines[i]
if line.strip() and not line.strip().startswith('#'):
indent = len(line) - len(line.lstrip())
if indent <= (len(lines[start_line - 1]) - len(lines[start_line - 1].lstrip())):
if i > match_line:
end_line = i
break
return start_line, end_line
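# Worked example for find_function_boundaries (a minimal sketch; the Go snippet
# below is hypothetical):
#
#     src = "func (c *Context) Copy() *Context {\n\tcp := *c\n\treturn &cp\n}"
#     find_function_boundaries(src, match_line=2)   # -> (1, 4)
#
# A match on line 2 is expanded to the enclosing method: the backward pass stops
# at the `func` definition on line 1, the forward pass stops where the braces
# balance on line 4.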
# =============================================================================
# Main Searcher Implementation
# =============================================================================
@dataclass
class RawSnippet:
"""A raw code snippet from ripgrep."""
file_path: str
lines: List[Tuple[int, str]]
matched_patterns: List[str]
matched_concepts: Set[str] = field(default_factory=set)
@property
def line_start(self) -> int:
return self.lines[0][0] if self.lines else 0
@property
def line_end(self) -> int:
return self.lines[-1][0] if self.lines else 0
@property
def content(self) -> str:
return "\n".join(line for _, line in self.lines)
class ReasoningSearcher(BaseSearcher):
"""
Multi-stage semantic searcher with Chain-of-Thought reasoning.
Flow:
1. REASON: Analyze problem → derive technical concepts
2. GENERATE: Create search patterns from concepts
3. SEARCH: Execute patterns, expand context to full functions
4. REFINE: (Optional) Second pass based on findings
"""
def __init__(
self,
model: str = "claude-sonnet-4-20250514",
max_patterns: int = 25,
context_lines: int = 10, # Initial context, will be expanded
max_results: int = 10,
enable_refinement: bool = True,
expand_to_functions: bool = True,
verbose: bool = False,
):
self.model = model
self.max_patterns = max_patterns
self.context_lines = context_lines
self.max_results = max_results
self.enable_refinement = enable_refinement
self.expand_to_functions = expand_to_functions
self.verbose = verbose
api_key = os.getenv("CLAUDE_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
if not api_key:
raise ValueError("CLAUDE_API_KEY or ANTHROPIC_API_KEY must be set")
self.llm = ChatAnthropic(
model=model,
api_key=api_key,
max_tokens=4096,
)
@property
def name(self) -> str:
return f"ReasoningSearcher ({self.model})"
def _log(self, msg: str) -> None:
if self.verbose:
print(f"[Reasoning] {msg}")
logger.debug(msg)
# =========================================================================
# Stage 1: Problem Analysis
# =========================================================================
def _analyze_problem(self, query: str) -> ProblemAnalysis:
"""Use Chain-of-Thought to analyze the problem."""
prompt = ChatPromptTemplate.from_messages([
("system", REASONING_PROMPT),
("human", "Analyze this problem and identify technical concepts to search for."),
])
chain = prompt | self.llm.with_structured_output(ProblemAnalysis)
result = chain.invoke({"query": query})
self._log(f"Problem summary: {result.problem_summary}")
self._log(f"Root causes: {result.likely_root_causes}")
self._log(f"Concepts: {[c.concept for c in result.technical_concepts]}")
return result
# =========================================================================
# Stage 2: Pattern Generation
# =========================================================================
def _get_repo_context(self, repo_path: str) -> str:
"""Get basic repo context for pattern generation."""
repo_path = Path(repo_path)
context_parts = [f"Repository: {repo_path.name}"]
# Detect languages
try:
            # Group the -name tests so that "-type f" applies to every extension
            result = subprocess.run(
                ["find", str(repo_path), "-type", "f", "(",
                 "-name", "*.py", "-o", "-name", "*.go", "-o",
                 "-name", "*.ts", "-o", "-name", "*.js", ")"],
                capture_output=True, text=True, timeout=10
            )
files = result.stdout.strip().split('\n')[:100]
ext_counts: Dict[str, int] = {}
for f in files:
if f:
ext = Path(f).suffix
ext_counts[ext] = ext_counts.get(ext, 0) + 1
top_langs = sorted(ext_counts.items(), key=lambda x: -x[1])[:3]
if top_langs:
context_parts.append(f"Main languages: {', '.join(e for e, _ in top_langs)}")
        except Exception:
pass
# Key directories
try:
dirs = []
for item in repo_path.iterdir():
if item.is_dir() and not item.name.startswith('.'):
if item.name in ('src', 'lib', 'core', 'pkg', 'internal', 'app'):
dirs.append(item.name)
if dirs:
context_parts.append(f"Key directories: {', '.join(dirs)}")
        except Exception:
pass
return "\n".join(context_parts)
def _generate_patterns(
self,
analysis: ProblemAnalysis,
repo_context: str,
) -> List[SearchPattern]:
"""Generate search patterns from problem analysis."""
# First, extract patterns directly from concepts
direct_patterns = []
for concept in analysis.technical_concepts:
for term in concept.search_terms:
if len(term) >= 4:
direct_patterns.append(SearchPattern(
pattern=term,
priority=8,
concept=concept.concept,
))
# Add hardcoded high-priority patterns for known problem types
problem_lower = analysis.problem_summary.lower()
root_causes_text = " ".join(analysis.likely_root_causes).lower()
# Object pooling / context reuse patterns
if any(kw in problem_lower + root_causes_text for kw in
['pool', 'reuse', 'goroutine', 'concurrent', 'corrupt', 'different request']):
pooling_patterns = [
("pool.Get()", 10, "Object Pooling"),
("pool.Put(", 10, "Object Pooling"),
("sync.Pool", 10, "Object Pooling"),
("ServeHTTP", 9, "Request Lifecycle"),
("func (c *Context) Copy", 10, "Safe Copy"),
("func (c *Context) reset", 10, "Object Reset"),
("c.reset()", 9, "Object Reset"),
("Copy()", 8, "Safe Copy"),
]
for pattern, priority, concept in pooling_patterns:
direct_patterns.append(SearchPattern(
pattern=pattern, priority=priority, concept=concept
))
# Middleware / request lifecycle patterns
if any(kw in problem_lower + root_causes_text for kw in
['middleware', 'lifecycle', 'request', 'response', 'handler', 'header']):
middleware_patterns = [
("_middleware_chain", 9, "Middleware Chain"),
("process_request", 8, "Middleware"),
("process_response", 8, "Middleware"),
("get_response", 8, "Request Handling"),
("HandleHTTPRequest", 9, "Request Handling"),
# Django-specific
("class WSGIHandler", 10, "WSGI Handler"),
("class BaseHandler", 10, "Base Handler"),
("convert_exception_to_response", 9, "Exception Handling"),
("start_response", 8, "WSGI"),
("class MiddlewareMixin", 9, "Middleware Mixin"),
("load_middleware", 8, "Middleware Loading"),
("__call__(self, environ", 9, "WSGI Entry Point"),
]
for pattern, priority, concept in middleware_patterns:
direct_patterns.append(SearchPattern(
pattern=pattern, priority=priority, concept=concept
))
# Add file name hints as patterns
for hint in analysis.file_name_hints:
if len(hint) >= 3:
direct_patterns.append(SearchPattern(
pattern=hint,
priority=5,
concept="file_name",
))
# Use LLM to generate additional patterns
prompt = ChatPromptTemplate.from_messages([
("system", PATTERN_GENERATION_PROMPT),
("human", "Generate search patterns for these concepts."),
])
chain = prompt | self.llm.with_structured_output(PatternSet)
analysis_text = f"""
Problem: {analysis.problem_summary}
Root Causes: {', '.join(analysis.likely_root_causes)}
Concepts:
{chr(10).join(f'- {c.concept}: {c.explanation}' for c in analysis.technical_concepts)}
Code Structure Hints: {', '.join(analysis.code_structure_hints)}
"""
result = chain.invoke({
"analysis": analysis_text,
"repo_context": repo_context,
})
# Combine and deduplicate
all_patterns = direct_patterns + result.patterns
seen = set()
unique_patterns = []
for p in sorted(all_patterns, key=lambda x: -x.priority):
pattern_lower = p.pattern.lower()
if pattern_lower not in seen and len(p.pattern) >= 4:
                # Skip regex-looking patterns (rg -F searches literally, so a
                # regex would be matched verbatim and almost never hit).
                # Parentheses are allowed: literal calls like "pool.Get()" are
                # exactly what the prompts above ask for.
                if not any(c in p.pattern for c in r'*+?[]|\^$'):
                    seen.add(pattern_lower)
                    unique_patterns.append(p)
self._log(f"Generated {len(unique_patterns)} unique patterns")
return unique_patterns[:self.max_patterns]
# =========================================================================
# Stage 3: Search Execution with Context Expansion
# =========================================================================
def _run_ripgrep(
self,
pattern: SearchPattern,
repo_path: str,
) -> List[RawSnippet]:
"""Execute ripgrep and optionally expand to full functions."""
cmd = [
"rg", "-F", "-i", "-n",
"-C", str(self.context_lines),
"--json", "-m", "20",
pattern.pattern,
repo_path,
]
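        # With the defaults this is roughly equivalent to running, e.g.:
        #   rg -F -i -n -C 10 --json -m 20 "pool.Get()" /path/to/repo
        # (the pattern and repo path above are illustrative)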
try:
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
if result.returncode not in (0, 1):
return []
snippets = self._parse_output(result.stdout, repo_path, pattern)
# Expand to full functions if enabled
if self.expand_to_functions:
snippets = self._expand_snippets_to_functions(snippets, repo_path)
return snippets
except subprocess.TimeoutExpired:
return []
except FileNotFoundError:
            raise RuntimeError("ripgrep (rg) is not installed or not on PATH")
def _parse_output(
self,
output: str,
repo_path: str,
pattern: SearchPattern,
) -> List[RawSnippet]:
"""Parse ripgrep JSON output."""
file_lines: Dict[str, List[Tuple[int, str]]] = {}
for line in output.split('\n'):
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError:
continue
if data.get('type') not in ('match', 'context'):
continue
msg = data.get('data', {})
file_path = msg.get('path', {}).get('text', '')
line_text = msg.get('lines', {}).get('text', '').rstrip('\n')
line_num = msg.get('line_number', 0)
if not file_path or not line_num:
continue
try:
rel_path = str(Path(file_path).relative_to(repo_path))
except ValueError:
rel_path = file_path
if rel_path not in file_lines:
file_lines[rel_path] = []
file_lines[rel_path].append((line_num, line_text))
snippets = []
for file_path, lines in file_lines.items():
lines.sort(key=lambda x: x[0])
# Merge nearby lines
current = []
for line_num, content in lines:
if not current:
current.append((line_num, content))
elif line_num <= current[-1][0] + 5:
current.append((line_num, content))
else:
snippet = RawSnippet(
file_path=file_path,
lines=current,
matched_patterns=[pattern.pattern],
matched_concepts={pattern.concept},
)
snippets.append(snippet)
current = [(line_num, content)]
if current:
snippet = RawSnippet(
file_path=file_path,
lines=current,
matched_patterns=[pattern.pattern],
matched_concepts={pattern.concept},
)
snippets.append(snippet)
return snippets
def _expand_snippets_to_functions(
self,
snippets: List[RawSnippet],
repo_path: str,
) -> List[RawSnippet]:
"""Expand snippets to include full function/method definitions.
Key: preserve the ORIGINAL match location, just expand context around it.
"""
expanded = []
for snippet in snippets:
try:
file_path = Path(repo_path) / snippet.file_path
with open(file_path, 'r', errors='ignore') as f:
content = f.read()
lines = content.split('\n')
# Use the START of the snippet as anchor (not middle)
# This preserves where the actual match is
match_line = snippet.line_start
# Try to find function boundaries
start, end = find_function_boundaries(content, match_line)
# Validation: the expanded range MUST include original match PLUS buffer
# Add buffer to catch adjacent important code
if start > snippet.line_start or end < snippet.line_end:
# Function boundary detection failed, use simple expansion
start = max(1, snippet.line_start - 25)
end = min(len(lines), snippet.line_end + 40)
else:
# Function boundaries found - add small buffer just in case
start = max(1, start - 5)
end = min(len(lines), end + 10)
# Ensure reasonable size
if end - start > 120:
# Too big - be more conservative, center on original snippet
center = (snippet.line_start + snippet.line_end) // 2
start = max(1, center - 50)
end = min(len(lines), center + 50)
# Extract lines
expanded_lines = [
(i + 1, lines[i])
for i in range(start - 1, min(end, len(lines)))
]
expanded.append(RawSnippet(
file_path=snippet.file_path,
lines=expanded_lines,
matched_patterns=snippet.matched_patterns,
matched_concepts=snippet.matched_concepts,
))
except Exception as e:
self._log(f"Failed to expand {snippet.file_path}: {e}")
expanded.append(snippet)
return expanded
# =========================================================================
# Stage 4: Merge and Rank
# =========================================================================
def _merge_snippets(self, snippets: List[RawSnippet]) -> List[RawSnippet]:
"""Merge overlapping snippets from same file.
        Key insight: only merge snippets that are truly overlapping (within ~15 lines).
Don't merge distant parts of the same file - they might be different functions.
"""
if not snippets:
return []
by_file: Dict[str, List[RawSnippet]] = {}
for s in snippets:
by_file.setdefault(s.file_path, []).append(s)
merged = []
for file_path, file_snippets in by_file.items():
file_snippets.sort(key=lambda s: s.line_start)
current = None
for snippet in file_snippets:
if current is None:
current = snippet
elif snippet.line_start <= current.line_end + 15: # Reduced from 30
# Merge ONLY if very close
all_lines = dict(current.lines)
all_lines.update(dict(snippet.lines))
current = RawSnippet(
file_path=file_path,
lines=sorted(all_lines.items()),
matched_patterns=list(set(
current.matched_patterns + snippet.matched_patterns
)),
matched_concepts=current.matched_concepts | snippet.matched_concepts,
)
else:
# Keep as separate snippet
merged.append(current)
current = snippet
if current:
merged.append(current)
return merged
def _rank_snippets(
self,
snippets: List[RawSnippet],
analysis: ProblemAnalysis,
) -> List[RawSnippet]:
"""Rank snippets by relevance to the problem."""
def score(s: RawSnippet) -> float:
# More concepts matched = higher score
concept_score = len(s.matched_concepts) * 5
# More patterns matched = higher score
pattern_score = len(s.matched_patterns) * 2
# Prefer medium-length snippets
length = len(s.lines)
if 20 <= length <= 80:
length_score = 3.0
elif length < 20:
length_score = length / 10
else:
length_score = max(0.5, 3.0 - (length - 80) / 50)
# File type priority
ext = Path(s.file_path).suffix.lower()
if ext in ('.py', '.go', '.ts', '.js', '.cpp', '.java'):
type_score = 2.0
elif ext in ('.md', '.txt', '.rst'):
type_score = 0.5 # Docs can be useful!
else:
type_score = 1.0
# Penalize tests
path_lower = s.file_path.lower()
if '/test' in path_lower or '_test.' in path_lower:
test_penalty = -5.0
else:
test_penalty = 0.0
# Bonus for core directories
if '/core/' in path_lower or '/src/' in path_lower or '/lib/' in path_lower:
location_bonus = 3.0
elif '/internal/' in path_lower or '/pkg/' in path_lower:
location_bonus = 2.0
else:
location_bonus = 0.0
# Bonus for matching file name hints
file_name_bonus = 0.0
for hint in analysis.file_name_hints:
if hint.lower() in path_lower:
file_name_bonus += 2.0
return (
concept_score + pattern_score + length_score +
type_score + test_penalty + location_bonus + file_name_bonus
)
return sorted(snippets, key=score, reverse=True)
# =========================================================================
# Main Search Method
# =========================================================================
def search(
self,
query: str,
repo_path: str,
path: Optional[str] = None,
) -> SearchResult:
"""Perform multi-stage semantic search with reasoning."""
start_time = time.time()
tool_time = 0.0
try:
repo_path = os.path.abspath(repo_path)
if path:
repo_path = os.path.join(repo_path, path)
self._log(f"Searching in: {repo_path}")
self._log(f"Query: {query[:100]}...")
# Stage 1: Analyze problem
self._log("=== Stage 1: Analyzing problem ===")
analysis = self._analyze_problem(query)
# Stage 2: Generate patterns
self._log("=== Stage 2: Generating patterns ===")
repo_context = self._get_repo_context(repo_path)
patterns = self._generate_patterns(analysis, repo_context)
for p in patterns[:10]:
self._log(f" Pattern: '{p.pattern}' (priority={p.priority}, concept={p.concept})")
# Stage 3: Execute search
self._log("=== Stage 3: Searching ===")
all_snippets = []
for pattern in patterns:
rg_start = time.time()
snippets = self._run_ripgrep(pattern, repo_path)
tool_time += (time.time() - rg_start) * 1000
if snippets:
self._log(f" '{pattern.pattern}': {len(snippets)} snippets")
all_snippets.extend(snippets)
self._log(f"Total raw snippets: {len(all_snippets)}")
# Stage 4: Merge and rank
self._log("=== Stage 4: Merging and ranking ===")
merged = self._merge_snippets(all_snippets)
self._log(f"After merge: {len(merged)}")
ranked = self._rank_snippets(merged, analysis)
final = ranked[:self.max_results]
self._log(f"Final results: {len(final)}")
for s in final[:5]:
self._log(f" {s.file_path}:{s.line_start}-{s.line_end} "
f"(concepts: {s.matched_concepts})")
# Convert to SearchResult
items = [
SearchItem(
file_path=s.file_path,
content=s.content,
line_start=s.line_start,
line_end=s.line_end,
match_context=f"Concepts: {', '.join(s.matched_concepts)}",
)
for s in final
]
total_time = (time.time() - start_time) * 1000
return SearchResult(
items=items,
patterns_used=[p.pattern for p in patterns[:10]],
execution_time_ms=total_time - tool_time,
total_time_ms=total_time,
tool_time_ms=tool_time,
)
except Exception as e:
total_time = (time.time() - start_time) * 1000
self._log(f"Error: {e}")
import traceback
self._log(traceback.format_exc())
return SearchResult(
items=[],
execution_time_ms=total_time,
total_time_ms=total_time,
error=str(e),
)
class ReasoningSearcherVerbose(ReasoningSearcher):
"""ReasoningSearcher with verbose logging enabled."""
def __init__(self, **kwargs):
kwargs["verbose"] = True
super().__init__(**kwargs)