"""
Context-aware pattern searcher.
Improvement over basic PatternSearcher:
1. First collects repository context (structure, file names, class names)
2. Passes context to LLM for better pattern generation
3. Uses file name patterns in addition to content patterns
This helps the LLM generate more targeted patterns that match
actual code in the repository.
"""
from __future__ import annotations

import json
import logging
import os
import re
import subprocess
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

from dotenv import load_dotenv
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field

from .base import BaseSearcher, SearchItem, SearchResult

load_dotenv()
logger = logging.getLogger(__name__)
# =============================================================================
# Repository Context Collection
# =============================================================================
@dataclass
class RepoContext:
"""Context information about a repository."""
repo_name: str = ""
languages: List[str] = field(default_factory=list) # e.g., [".py", ".go", ".ts"]
key_directories: List[str] = field(default_factory=list) # e.g., ["src", "lib", "core"]
relevant_files: List[str] = field(default_factory=list) # Files matching query keywords
class_names: List[str] = field(default_factory=list) # Extracted class/type names
function_names: List[str] = field(default_factory=list) # Extracted function names
def to_prompt_context(self) -> str:
"""Convert to string for LLM prompt."""
lines = []
if self.languages:
lines.append(f"Main languages: {', '.join(self.languages)}")
if self.key_directories:
lines.append(f"Key directories: {', '.join(self.key_directories)}")
if self.relevant_files:
lines.append(f"Relevant files found: {', '.join(self.relevant_files[:15])}")
if self.class_names:
lines.append(f"Class/type names found: {', '.join(self.class_names[:20])}")
if self.function_names:
lines.append(f"Function names found: {', '.join(self.function_names[:15])}")
return "\n".join(lines)
def collect_repo_context(
repo_path: str,
query: str,
max_files: int = 100,
) -> RepoContext:
"""
Collect context about a repository to help with pattern generation.
Args:
repo_path: Path to repository
query: Search query (used to find relevant files)
max_files: Maximum files to scan
Returns:
RepoContext with collected information
"""
repo_path = Path(repo_path).resolve()
context = RepoContext(repo_name=repo_path.name)
# Extract keywords from query for file matching
query_words = set(re.findall(r'\b\w{4,}\b', query.lower()))
# Remove common words
stop_words = {'what', 'when', 'where', 'which', 'that', 'this', 'have',
'from', 'with', 'will', 'would', 'could', 'should', 'does',
'want', 'need', 'understand', 'looking', 'trying', 'sometimes',
'different', 'multiple', 'after', 'before', 'during'}
query_words -= stop_words
# 1. Find key directories
try:
for item in repo_path.iterdir():
if item.is_dir() and not item.name.startswith('.'):
if item.name.lower() in ('src', 'lib', 'core', 'pkg', 'internal',
'app', 'server', 'client', 'api'):
context.key_directories.append(item.name)
except PermissionError:
pass
# 2. Find all code files and detect languages
    try:
        # Parenthesize the -name alternation so the -path exclusions apply to
        # every extension (find's -o binds more loosely than the implicit -a).
        result = subprocess.run(
            ["find", str(repo_path), "-type", "f",
             "!", "-path", "*/test*", "!", "-path", "*/.git/*",
             "!", "-path", "*/node_modules/*", "!", "-path", "*/__pycache__/*",
             "(", "-name", "*.py", "-o", "-name", "*.go", "-o", "-name", "*.ts",
             "-o", "-name", "*.js", "-o", "-name", "*.cpp", "-o", "-name", "*.h",
             "-o", "-name", "*.qll", "-o", "-name", "*.java", "-o", "-name", "*.rs",
             ")"],
            capture_output=True, text=True, timeout=15
        )
        files = [f for f in result.stdout.strip().split('\n') if f][:max_files * 3]
    except (subprocess.TimeoutExpired, OSError):
        files = []
# Count extensions
ext_counts: Dict[str, int] = {}
for f in files:
ext = Path(f).suffix.lower()
if ext:
ext_counts[ext] = ext_counts.get(ext, 0) + 1
# Top 3 languages
top_exts = sorted(ext_counts.items(), key=lambda x: -x[1])[:3]
context.languages = [ext for ext, _ in top_exts]
# 3. Find files with names matching query keywords
relevant_files: List[Tuple[str, int]] = [] # (path, score)
for f in files:
try:
rel_path = str(Path(f).relative_to(repo_path))
except ValueError:
rel_path = f
# Skip test files for now
if '/test' in rel_path.lower() or '_test.' in rel_path.lower():
continue
fname = Path(f).stem.lower()
# Score by how many query words match
score = sum(1 for w in query_words if w in fname or w in rel_path.lower())
if score > 0:
relevant_files.append((rel_path, score))
# Sort by score and take top matches
relevant_files.sort(key=lambda x: -x[1])
context.relevant_files = [f for f, _ in relevant_files[:20]]
# 4. Extract class/function names from ALL code files (not just query-matching)
# This is crucial because key implementations may be in files like "tree.go"
# that don't match query keywords
class_names: Set[str] = set()
func_names: Set[str] = set()
    # Scan both the query-relevant files AND a shallow sample of other files
    files_to_scan = list(context.relevant_files[:10])
    # Also take files near the repository root, at most two levels deep (skip tests)
for f in files[:50]:
try:
rel_path = str(Path(f).relative_to(repo_path))
except ValueError:
continue
# Skip tests and deep nested files
if '/test' in rel_path.lower() or '_test.' in rel_path.lower():
continue
if rel_path.count('/') > 2: # Not too deep
continue
if rel_path not in files_to_scan:
files_to_scan.append(rel_path)
for rel_path in files_to_scan:
full_path = repo_path / rel_path if not Path(rel_path).is_absolute() else Path(rel_path)
try:
with open(full_path, 'r', errors='ignore') as fp:
content = fp.read(10000) # First 10KB
# Extract based on file type
ext = full_path.suffix.lower()
if ext == '.py':
# Python: class Name, def name
class_names.update(re.findall(r'^class\s+(\w+)', content, re.MULTILINE))
func_names.update(re.findall(r'^def\s+(\w+)', content, re.MULTILINE))
elif ext == '.go':
# Go: type Name struct, func Name(
class_names.update(re.findall(r'type\s+(\w+)\s+struct', content))
func_names.update(re.findall(r'^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\(', content, re.MULTILINE))
elif ext in ('.ts', '.js'):
# TypeScript/JS: class Name, function name, export class
class_names.update(re.findall(r'class\s+(\w+)', content))
func_names.update(re.findall(r'function\s+(\w+)', content))
func_names.update(re.findall(r'(?:export\s+)?(?:async\s+)?(\w+)\s*[=:]\s*(?:async\s+)?\(', content))
elif ext in ('.cpp', '.h'):
# C++: class Name, void Name(
class_names.update(re.findall(r'class\s+(\w+)', content))
                func_names.update(re.findall(r'(?:void|int|bool|QString|Qgs\w*)\s+(\w+)\s*\(', content))
elif ext == '.qll':
# CodeQL: class Name extends, predicate name
class_names.update(re.findall(r'class\s+(\w+)', content))
func_names.update(re.findall(r'predicate\s+(\w+)', content))
        except OSError:
            # Unreadable or missing files are skipped; context is best-effort.
            pass
    # Filter to meaningful names (4+ chars, not common words)
common = {'class', 'self', 'this', 'super', 'return', 'string', 'number', 'boolean',
'object', 'array', 'error', 'result', 'value', 'index', 'count'}
# Filter and prioritize names that might be more relevant
# Sort by length (longer names are often more specific/important)
filtered_classes = [
n for n in class_names
if len(n) >= 4 and n.lower() not in common
]
# Sort by length descending, then alphabetically
context.class_names = sorted(filtered_classes, key=lambda n: (-len(n), n))[:50]
filtered_funcs = [
n for n in func_names
if len(n) >= 4 and n.lower() not in common and not n.startswith('_')
]
# Include important-sounding functions first
important_keywords = {'route', 'handle', 'process', 'execute', 'render', 'query',
'match', 'find', 'search', 'create', 'build', 'parse'}
def func_priority(n: str) -> tuple:
# Higher priority for functions with important keywords
has_important = any(kw in n.lower() for kw in important_keywords)
return (0 if has_important else 1, -len(n), n)
context.function_names = sorted(filtered_funcs, key=func_priority)[:40]
return context
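
# A minimal usage sketch for the collector above (illustrative path and query):
#
#     ctx = collect_repo_context("/path/to/repo", "how is request routing handled")
#     print(ctx.to_prompt_context())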
# =============================================================================
# Pattern Generation with Context
# =============================================================================
class SearchPattern(BaseModel):
"""A single search pattern for ripgrep."""
pattern: str = Field(description="Plain text pattern to search for")
description: str = Field(description="What this pattern finds")
search_files: List[str] = Field(
default_factory=list,
description="Specific files to search in (optional)"
)
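
# Illustrative instance (hypothetical names, echoing the prompt's own example):
#
#     SearchPattern(pattern="NodeWrangler",
#                   description="core node manager class",
#                   search_files=["core/nodes.py"])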
class PatternGenerationResult(BaseModel):
"""Result of pattern generation."""
terms: List[SearchPattern] = Field(description="List of search patterns")
CONTEXT_PATTERN_PROMPT = """You are CodeSearchPatternGenerator - an expert at finding code in large repositories.
REPOSITORY CONTEXT:
{repo_context}
TASK:
Generate 8-12 search patterns to find code that answers this query:
"{query}"
OUTPUT FORMAT:
Return JSON with this schema:
{{
  "terms": [
    {{
      "pattern": "exact text to search for",
      "description": "what this finds",
      "search_files": ["optional/specific/file.py"]
    }}
  ]
}}
CRITICAL RULES:
1. Use PLAIN TEXT patterns only (no regex); they will be run with `rg -F -i`.
2. Look at the class/function names I found - USE THEM if relevant!
3. Look at the relevant files I found - search IN THEM if they match the query.
4. Generate patterns for:
- Class names (exact matches from context if available)
- Function/method names (exact matches from context if available)
- Key domain terms from the query
5. If I found files like "tree.go" or "context.go" and query is about trees/contexts, include those file names.
STRATEGY:
- If context shows class "NodeWrangler" and query asks about nodes → use "NodeWrangler" as pattern
- If context shows file "render/render.go" and query asks about rendering → search in that file
- Combine context-derived patterns with query-derived patterns
AVOID:
- Generic words: get, set, data, value, error, result
- Patterns shorter than 5 characters
- Regex syntax: * + ? [ ] ( ) | \\"""
# =============================================================================
# Context-Aware Pattern Searcher
# =============================================================================
@dataclass
class RawSnippet:
"""A raw code snippet from ripgrep output."""
file_path: str
lines: List[Tuple[int, str]]
matched_patterns: List[str]
@property
def line_start(self) -> int:
return self.lines[0][0] if self.lines else 0
@property
def line_end(self) -> int:
return self.lines[-1][0] if self.lines else 0
@property
def content(self) -> str:
return "\n".join(line for _, line in self.lines)
class ContextPatternSearcher(BaseSearcher):
"""
Context-aware pattern searcher.
Collects repository context before generating patterns,
allowing the LLM to generate more targeted patterns.
"""
def __init__(
self,
model: str = "claude-sonnet-4-20250514",
max_patterns: int = 12,
context_lines: int = 15,
max_results: int = 10,
verbose: bool = False,
):
self.model = model
self.max_patterns = max_patterns
self.context_lines = context_lines
self.max_results = max_results
self.verbose = verbose
api_key = os.getenv("CLAUDE_API_KEY") or os.getenv("ANTHROPIC_API_KEY")
if not api_key:
raise ValueError("CLAUDE_API_KEY or ANTHROPIC_API_KEY must be set")
self.llm = ChatAnthropic(
model=model,
api_key=api_key,
max_tokens=2048,
)
@property
def name(self) -> str:
return f"ContextPatternSearcher ({self.model})"
def _log(self, msg: str) -> None:
if self.verbose:
print(f"[ContextPattern] {msg}")
logger.debug(msg)
def _generate_patterns(
self,
query: str,
context: RepoContext,
) -> List[SearchPattern]:
"""Generate patterns using repo context."""
prompt = ChatPromptTemplate.from_messages([
("system", CONTEXT_PATTERN_PROMPT),
("human", "Generate search patterns for the query above."),
])
chain = prompt | self.llm.with_structured_output(PatternGenerationResult)
result = chain.invoke({
"repo_context": context.to_prompt_context(),
"query": query,
})
self._log(f"Generated {len(result.terms)} patterns")
# Filter invalid patterns
valid = []
regex_chars = {'*', '+', '?', '[', ']', '(', ')', '|', '^', '$', '\\'}
for p in result.terms:
if any(c in p.pattern for c in regex_chars):
self._log(f"Skipping regex pattern: {p.pattern}")
continue
if len(p.pattern) < 4:
self._log(f"Skipping short pattern: {p.pattern}")
continue
valid.append(p)
return valid[:self.max_patterns]
def _run_ripgrep(
self,
pattern: SearchPattern,
repo_path: str,
) -> List[RawSnippet]:
"""Execute ripgrep for a pattern."""
repo_path = os.path.abspath(repo_path)
snippets = []
# Determine search paths
if pattern.search_files:
paths = [str(Path(repo_path) / f) for f in pattern.search_files]
paths = [p for p in paths if Path(p).exists()]
if not paths:
paths = [repo_path]
else:
paths = [repo_path]
for search_path in paths:
cmd = [
"rg", "-F", "-i", "-n",
"-C", str(self.context_lines),
"--json", "-m", "30",
pattern.pattern,
search_path,
]
self._log(f"Running: rg -F -i '{pattern.pattern}' {search_path}")
try:
result = subprocess.run(
cmd, capture_output=True, text=True, timeout=30
)
if result.returncode == 0:
snippets.extend(self._parse_output(result.stdout, repo_path, pattern.pattern))
elif result.returncode != 1: # 1 = no matches
self._log(f"rg error: {result.stderr}")
except subprocess.TimeoutExpired:
self._log(f"Timeout for pattern: {pattern.pattern}")
            except FileNotFoundError:
                raise RuntimeError("ripgrep (rg) is not installed or not on PATH")
return snippets
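
    # For reference, the command built above is roughly equivalent to running
    # (with the default context_lines=15 and an illustrative pattern/path):
    #
    #     rg -F -i -n -C 15 --json -m 30 "NodeWrangler" /path/to/repo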
def _parse_output(
self,
output: str,
repo_path: str,
pattern: str,
) -> List[RawSnippet]:
"""Parse ripgrep JSON output."""
file_lines: Dict[str, List[Tuple[int, str]]] = {}
for line in output.split("\n"):
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError:
continue
if data.get("type") not in ("match", "context"):
continue
msg = data.get("data", {})
file_path = msg.get("path", {}).get("text", "")
line_text = msg.get("lines", {}).get("text", "").rstrip("\n")
line_num = msg.get("line_number", 0)
if not file_path or not line_num:
continue
try:
rel_path = str(Path(file_path).relative_to(repo_path))
except ValueError:
rel_path = file_path
if rel_path not in file_lines:
file_lines[rel_path] = []
file_lines[rel_path].append((line_num, line_text))
# Merge into snippets
snippets = []
for file_path, lines in file_lines.items():
lines.sort(key=lambda x: x[0])
current = []
for line_num, content in lines:
if not current:
current.append((line_num, content))
elif line_num <= current[-1][0] + 10:
current.append((line_num, content))
else:
snippets.append(RawSnippet(file_path, current, [pattern]))
current = [(line_num, content)]
if current:
snippets.append(RawSnippet(file_path, current, [pattern]))
return snippets
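
    # The stream parsed above is ripgrep's --json output: one event per line.
    # An abridged "match" event looks like this (fields not read here omitted):
    #
    #     {"type": "match", "data": {"path": {"text": "src/app.py"},
    #      "lines": {"text": "def handle_request(req):\n"}, "line_number": 42}}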
def _merge_snippets(self, snippets: List[RawSnippet]) -> List[RawSnippet]:
"""Merge overlapping snippets."""
if not snippets:
return []
by_file: Dict[str, List[RawSnippet]] = {}
for s in snippets:
by_file.setdefault(s.file_path, []).append(s)
merged = []
for file_path, file_snippets in by_file.items():
file_snippets.sort(key=lambda s: s.line_start)
current = None
for snippet in file_snippets:
if current is None:
current = snippet
elif snippet.line_start <= current.line_end + 20:
# Merge
all_lines = dict(current.lines)
all_lines.update(dict(snippet.lines))
current = RawSnippet(
file_path,
sorted(all_lines.items()),
list(set(current.matched_patterns + snippet.matched_patterns)),
)
else:
merged.append(current)
current = snippet
if current:
merged.append(current)
return merged
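
    # Worked example of the 20-line merge gap above: snippets covering lines
    # 10-30 and 45-60 of the same file merge into one 10-60 snippet
    # (45 <= 30 + 20), while a snippet starting at line 51 would stay separate.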
def _rank_snippets(self, snippets: List[RawSnippet]) -> List[RawSnippet]:
"""Rank snippets by relevance."""
def score(s: RawSnippet) -> float:
# Pattern matches
pattern_score = len(s.matched_patterns) * 3
# Length preference (20-60 lines is ideal)
length = len(s.lines)
if 20 <= length <= 60:
length_score = 2.0
elif length < 20:
length_score = length / 10
else:
length_score = max(0.5, 2.0 - (length - 60) / 40)
# File type
ext = Path(s.file_path).suffix.lower()
if ext in ('.py', '.go', '.ts', '.js', '.cpp', '.qll', '.java'):
type_score = 2.0
else:
type_score = 0.5
# Penalize tests
path_lower = s.file_path.lower()
if '/test' in path_lower or '_test.' in path_lower:
test_penalty = -5.0
else:
test_penalty = 0.0
# Bonus for core directories
if '/core/' in path_lower or '/src/' in path_lower or '/lib/' in path_lower:
location_bonus = 2.0
else:
location_bonus = 0.0
return pattern_score + length_score + type_score + test_penalty + location_bonus
return sorted(snippets, key=score, reverse=True)
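
    # Worked example of the scoring above: a 30-line snippet in
    # "app/src/router.py" matched by two patterns scores
    # 2*3 + 2.0 + 2.0 + 0.0 + 2.0 = 12.0; the same snippet in a test file
    # would take the -5.0 penalty and lose the +2.0 location bonus.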
def search(
self,
query: str,
repo_path: str,
path: Optional[str] = None,
) -> SearchResult:
"""Perform context-aware semantic search."""
start_time = time.time()
tool_time = 0.0
try:
repo_path = os.path.abspath(repo_path)
if path:
repo_path = os.path.join(repo_path, path)
self._log(f"Searching in: {repo_path}")
self._log(f"Query: {query}")
# Step 1: Collect repo context
context = collect_repo_context(repo_path, query)
self._log(f"Context: {context.to_prompt_context()}")
# Step 2: Generate patterns with context
patterns = self._generate_patterns(query, context)
patterns_used = [p.pattern for p in patterns]
self._log(f"Patterns: {patterns_used}")
# Step 3: Run ripgrep
all_snippets = []
for pattern in patterns:
rg_start = time.time()
snippets = self._run_ripgrep(pattern, repo_path)
tool_time += (time.time() - rg_start) * 1000
all_snippets.extend(snippets)
self._log(f"Pattern '{pattern.pattern}': {len(snippets)} snippets")
self._log(f"Total raw snippets: {len(all_snippets)}")
# Step 4: Merge and rank
merged = self._merge_snippets(all_snippets)
self._log(f"After merge: {len(merged)}")
ranked = self._rank_snippets(merged)
final = ranked[:self.max_results]
# Convert to SearchResult
items = [
SearchItem(
file_path=s.file_path,
content=s.content,
line_start=s.line_start,
line_end=s.line_end,
match_context=", ".join(s.matched_patterns),
)
for s in final
]
total_time = (time.time() - start_time) * 1000
return SearchResult(
items=items,
patterns_used=patterns_used,
execution_time_ms=total_time - tool_time,
total_time_ms=total_time,
tool_time_ms=tool_time,
)
except Exception as e:
total_time = (time.time() - start_time) * 1000
self._log(f"Error: {e}")
import traceback
self._log(traceback.format_exc())
return SearchResult(
items=[],
execution_time_ms=total_time - tool_time,
total_time_ms=total_time,
tool_time_ms=tool_time,
error=str(e),
)
class ContextPatternSearcherVerbose(ContextPatternSearcher):
"""Context pattern searcher with verbose logging."""
def __init__(self, **kwargs):
kwargs["verbose"] = True
super().__init__(**kwargs)
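

# Minimal manual smoke test -- a sketch, assuming ripgrep is on PATH and
# CLAUDE_API_KEY or ANTHROPIC_API_KEY is set in the environment.
if __name__ == "__main__":
    import sys

    searcher = ContextPatternSearcherVerbose()
    result = searcher.search(
        query=sys.argv[1] if len(sys.argv) > 1 else "How are search patterns generated?",
        repo_path=sys.argv[2] if len(sys.argv) > 2 else ".",
    )
    if result.error:
        print(f"error: {result.error}")
    for item in result.items:
        print(f"{item.file_path}:{item.line_start}-{item.line_end}")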