"""
Schema-Guided Reasoning (SGR) Searcher for semantic code search.
This searcher uses SGR to guide the LLM through structured reasoning:
1. Analyze the question (extract terms, concepts, locations)
2. Create a search plan (strategy + actions)
3. Execute searches and interpret results iteratively
4. Generate structured final answer
Key benefits of SGR:
- Consistent, predictable outputs via Structured Output
- Each reasoning step is explicit and auditable
- Easy to test/evaluate intermediate steps
- Better accuracy through enforced reasoning structure
Reference: https://abdullin.com/schema-guided-reasoning/
Supports multiple LLM providers:
- OpenAI (GPT-4o, GPT-4o-mini)
- Gemini (gemini-2.5-flash-lite, gemini-2.5-flash, gemini-3-pro-preview)
"""
from __future__ import annotations
import json
import os
import re
import subprocess
import time
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed
from enum import Enum
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar
from dotenv import load_dotenv
from pydantic import BaseModel
import langsmith as ls
from langsmith import traceable
from .base import BaseSearcher, SearchItem, SearchResult
from .sgr import (
RepoContext,
QuestionAnalysis,
TechnicalTerm,
TechnicalTerms,
# Typed Actions (NEW - PREFERRED)
GrepAction,
ReadFileAction,
ListDirAction,
GlobSearchAction,
TypedSearchPlan,
# Legacy Actions (backward compatibility)
SearchAction,
SearchPlan,
SearchResultInterpretation,
FinalAnswer,
CodeLocation,
SearchIteration,
NextActionDecision,
prompts,
)
from .sgr.path_validator import PathValidator
from .sgr.pattern_validator import PatternValidator, PatternValidation
# Load environment variables
load_dotenv()
T = TypeVar('T')
# ============================================
# Retry Decorator (simple implementation)
# ============================================
def retry_with_backoff(
max_retries: int = 3,
initial_delay: float = 1.0,
backoff_factor: float = 2.0,
exceptions: tuple = (Exception,),
) -> Callable:
"""
Retry decorator with exponential backoff.
Args:
max_retries: Maximum number of retry attempts
initial_delay: Initial delay between retries in seconds
backoff_factor: Multiplier for delay after each retry
exceptions: Tuple of exceptions to catch and retry
"""
def decorator(func: Callable[..., T]) -> Callable[..., T]:
@wraps(func)
def wrapper(*args, **kwargs) -> T:
delay = initial_delay
last_exception = None
for attempt in range(max_retries + 1):
try:
return func(*args, **kwargs)
except exceptions as e:
last_exception = e
if attempt < max_retries:
time.sleep(delay)
delay *= backoff_factor
else:
raise
raise last_exception # type: ignore
return wrapper
return decorator
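# Example (illustrative): with the arguments below, a failing call is retried
# after delays of 1s, 2s, then 4s before the last exception propagates.
#
#   @retry_with_backoff(max_retries=3, exceptions=(ConnectionError,))
#   def flaky_fetch() -> str:
#       ...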
class LLMProvider(str, Enum):
"""Supported LLM providers for SGR."""
OPENAI = "openai"
GEMINI = "gemini"
# ============================================
# LRU Cache for File Reading
# ============================================
class LRUCache:
"""Simple LRU cache for file contents."""
def __init__(self, max_size: int = 50):
self.cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
self.max_size = max_size
def get(self, key: str) -> Optional[Dict[str, Any]]:
if key in self.cache:
# Move to end (most recently used)
self.cache.move_to_end(key)
return self.cache[key]
return None
def put(self, key: str, value: Dict[str, Any]) -> None:
if key in self.cache:
self.cache.move_to_end(key)
else:
if len(self.cache) >= self.max_size:
# Remove oldest item
self.cache.popitem(last=False)
self.cache[key] = value
def clear(self) -> None:
self.cache.clear()
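# Example (illustrative values): the least-recently-used entry is evicted
# once max_size is exceeded.
#
#   cache = LRUCache(max_size=2)
#   cache.put("a.py", {"content": "..."})
#   cache.put("b.py", {"content": "..."})
#   cache.get("a.py")                      # touch "a.py" -> most recently used
#   cache.put("c.py", {"content": "..."})  # evicts "b.py", the oldest entry
#   assert cache.get("b.py") is None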
# ============================================
# Tool Implementations
# ============================================
class SGRTools:
"""Tools for SGR-based code search with safety and caching."""
# Default directories to skip during search
DEFAULT_SKIP_DIRS = {
'node_modules', '__pycache__', '.git', 'venv', '.venv',
'dist', 'build', 'coverage', '.cache', '.pytest_cache',
'target', 'vendor', '.idea', '.vscode', 'egg-info'
}
def __init__(
self,
repo_path: str,
verbose: bool = False,
skip_dirs: Optional[set] = None,
log_func: Optional[Callable[[str], None]] = None,
language: Optional[str] = None,
):
self.repo_path = str(Path(repo_path).resolve())
self.verbose = verbose
self._file_cache = LRUCache(max_size=50)
self.skip_dirs = skip_dirs if skip_dirs is not None else self.DEFAULT_SKIP_DIRS
self._log_func = log_func
# Path validation
self.path_validator = PathValidator(self.repo_path)
self.language = language
def _log(self, message: str) -> None:
if self._log_func:
self._log_func(f"[Tool] {message}")
elif self.verbose:
print(f"[SGR-Tool] {message}")
def _sanitize_path(self, path: str) -> Optional[str]:
"""
Sanitize path to prevent path traversal attacks.
Returns None if path is invalid/unsafe.
"""
if not path:
return self.repo_path
try:
# Resolve the full path
full_path = (Path(self.repo_path) / path).resolve()
repo_resolved = Path(self.repo_path).resolve()
            # Check that the resolved path is within the repo. A plain
            # startswith comparison would wrongly accept sibling paths
            # such as "/repo_evil" for "/repo".
            try:
                full_path.relative_to(repo_resolved)
            except ValueError:
                self._log(f"Path traversal blocked: {path}")
                return None
            return str(full_path)
except Exception as e:
self._log(f"Path sanitization error: {e}")
return None
def _validate_regex(self, pattern: str) -> Optional[str]:
"""
Validate regex pattern.
Returns error message if invalid, None if valid.
"""
if not pattern:
return "Pattern cannot be empty"
try:
re.compile(pattern)
return None
except re.error as e:
return f"Invalid regex pattern: {e}"
def grep(
self,
pattern: str,
path: Optional[str] = None,
glob: Optional[str] = None,
case_insensitive: bool = False,
context_lines: int = 3,
max_results: int = 50,
files_only: bool = False,
) -> Dict[str, Any]:
"""
Search file contents using ripgrep.
Args:
pattern: Regex pattern to search for
path: Directory to search in (relative to repo)
glob: File glob pattern (e.g., "*.py")
case_insensitive: Case insensitive search
context_lines: Lines of context around matches
max_results: Maximum number of matches
files_only: Return only file paths, not content
Returns structured result with matches.
"""
# NEW: Validate pattern with PatternValidator
pattern_validation = PatternValidator.validate_pattern(pattern, self.language)
if not pattern_validation.is_valid:
self._log(f"Pattern validation failed: {pattern} - {pattern_validation.issue}")
# Try to use transformed/suggested pattern
if pattern_validation.transformed_pattern:
pattern = pattern_validation.transformed_pattern
self._log(f" Using transformed pattern: {pattern}")
elif pattern_validation.suggested_pattern:
pattern = pattern_validation.suggested_pattern
self._log(f" Using suggested pattern: {pattern}")
else:
return {
"matches": [],
"files": [],
"files_count": 0,
"total_matches": 0,
"truncated": False,
"error": f"Pattern rejected: {pattern_validation.issue}",
}
# Validate regex pattern
regex_error = self._validate_regex(pattern)
if regex_error:
return {
"matches": [],
"files": [],
"files_count": 0,
"total_matches": 0,
"truncated": False,
"error": regex_error,
}
# Sanitize path
search_path = self._sanitize_path(path) if path else self.repo_path
if search_path is None:
return {
"matches": [],
"files": [],
"files_count": 0,
"total_matches": 0,
"truncated": False,
"error": f"Invalid or unsafe path: {path}",
}
# Build command
if files_only:
cmd = ["rg", "-l"] # Files only mode
else:
cmd = ["rg", "--json"]
if case_insensitive:
cmd.append("-i")
if not files_only and context_lines > 0:
cmd.extend(["-C", str(context_lines)])
if glob:
cmd.extend(["--glob", glob])
        if not files_only:
            # Note: rg's -m flag caps matches per file; the overall cap is
            # enforced again when slicing the parsed matches below
            cmd.extend(["-m", str(max_results)])
cmd.append(pattern)
cmd.append(search_path)
self._log(f"grep: {' '.join(cmd)}")
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=30,
cwd=self.repo_path,
)
# Handle files_only mode
if files_only:
files = []
for line in result.stdout.strip().split('\n'):
if line:
try:
rel_path = str(Path(line).relative_to(self.repo_path))
files.append(rel_path)
except ValueError:
files.append(line)
return {
"matches": [],
"files": files[:max_results],
"files_count": len(files),
"total_matches": len(files),
"truncated": len(files) > max_results,
"error": None,
}
# Parse ripgrep JSON output
matches = []
files_seen = set()
for line in result.stdout.strip().split('\n'):
if not line:
continue
try:
data = json.loads(line)
if data.get("type") == "match":
match_data = data.get("data", {})
file_path = match_data.get("path", {}).get("text", "")
line_num = match_data.get("line_number", 0)
line_text = match_data.get("lines", {}).get("text", "").strip()
# Make path relative
try:
file_path = str(Path(file_path).relative_to(self.repo_path))
except ValueError:
pass
files_seen.add(file_path)
matches.append({
"file": file_path,
"line": line_num,
"content": line_text[:500],
})
except Exception:
continue
# Fallback to plain text parsing if JSON failed
if not matches and result.stdout:
for line in result.stdout.strip().split('\n')[:max_results]:
if ':' in line:
parts = line.split(':', 2)
if len(parts) >= 3:
file_path = parts[0]
try:
file_path = str(Path(file_path).relative_to(self.repo_path))
except ValueError:
pass
files_seen.add(file_path)
matches.append({
"file": file_path,
"line": int(parts[1]) if parts[1].isdigit() else 0,
"content": parts[2][:500] if len(parts) > 2 else "",
})
return {
"matches": matches[:max_results],
"files": list(files_seen),
"files_count": len(files_seen),
"total_matches": len(matches),
"truncated": len(matches) >= max_results,
"error": None,
}
except subprocess.TimeoutExpired:
return {
"matches": [],
"files": [],
"files_count": 0,
"total_matches": 0,
"truncated": False,
"error": "Search timed out after 30 seconds",
}
except Exception as e:
return {
"matches": [],
"files": [],
"files_count": 0,
"total_matches": 0,
"truncated": False,
"error": str(e),
}
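    # Illustrative return shape from grep() (field values are examples only):
    #
    #   {"matches": [{"file": "src/auth.py", "line": 42, "content": "def login(user):"}],
    #    "files": ["src/auth.py"], "files_count": 1, "total_matches": 1,
    #    "truncated": False, "error": None}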
def read_file(
self,
file_path: str,
start_line: Optional[int] = None,
end_line: Optional[int] = None,
max_lines: int = 200,
use_cache: bool = True,
) -> Dict[str, Any]:
"""
Read file contents with caching.
Returns structured result with content.
"""
# Sanitize path
full_path = self._sanitize_path(file_path)
if full_path is None:
return {
"content": "",
"total_lines": 0,
"read_lines": 0,
"error": f"Invalid or unsafe path: {file_path}",
}
self._log(f"read_file: {file_path}")
# Check cache (only for full file reads without line range)
cache_key = file_path
if use_cache and not start_line and not end_line:
cached = self._file_cache.get(cache_key)
if cached:
self._log(f" (cache hit)")
return cached
try:
path_obj = Path(full_path)
if not path_obj.exists():
return {
"content": "",
"total_lines": 0,
"read_lines": 0,
"error": f"File not found: {file_path}",
}
with open(path_obj, 'r', encoding='utf-8', errors='replace') as f:
all_lines = f.readlines()
total_lines = len(all_lines)
# Apply line range
if start_line and end_line:
lines = all_lines[max(0, start_line-1):min(end_line, total_lines)]
start_idx = max(0, start_line - 1)
elif start_line:
lines = all_lines[max(0, start_line-1):max(0, start_line-1) + max_lines]
start_idx = max(0, start_line - 1)
else:
lines = all_lines[:max_lines]
start_idx = 0
# Format with line numbers
numbered_lines = []
for i, line in enumerate(lines):
line_num = start_idx + i + 1
numbered_lines.append(f"{line_num:6}|{line.rstrip()}")
content = "\n".join(numbered_lines)
result = {
"content": content,
"total_lines": total_lines,
"read_lines": len(lines),
"truncated": len(lines) < total_lines,
"error": None,
}
# Cache result (only for full file reads)
if use_cache and not start_line and not end_line:
self._file_cache.put(cache_key, result)
return result
except Exception as e:
return {
"content": "",
"total_lines": 0,
"read_lines": 0,
"error": str(e),
}
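    # Illustrative "content" format from read_file(): each line is prefixed
    # with a right-aligned line number and a pipe:
    #
    #        1|import os
    #        2|import sys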
def list_dir(
self,
path: str = ".",
max_depth: int = 2,
) -> Dict[str, Any]:
"""
List directory structure.
Returns structured result with directory tree.
"""
# Sanitize path
target_path = self._sanitize_path(path)
if target_path is None:
return {
"structure": "",
"files_count": 0,
"dirs_count": 0,
"error": f"Invalid or unsafe path: {path}",
}
self._log(f"list_dir: {path}")
try:
target_path_obj = Path(target_path)
if not target_path_obj.exists():
return {
"structure": "",
"files_count": 0,
"dirs_count": 0,
"error": f"Path not found: {path}",
}
lines = []
files_count = 0
dirs_count = 0
def walk_dir(dir_path: Path, depth: int, prefix: str = ""):
nonlocal files_count, dirs_count
if depth > max_depth:
return
try:
entries = sorted(dir_path.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower()))
except PermissionError:
return
for entry in entries:
if entry.name.startswith('.'):
continue
if entry.is_dir():
# Check if should skip using configurable skip_dirs
should_skip = any(skip in entry.name for skip in self.skip_dirs)
if should_skip:
lines.append(f"{prefix}{entry.name}/ [skipped]")
continue
dirs_count += 1
lines.append(f"{prefix}{entry.name}/")
if depth < max_depth:
walk_dir(entry, depth + 1, prefix + " ")
else:
files_count += 1
lines.append(f"{prefix}{entry.name}")
walk_dir(target_path_obj, 1)
# Truncate if too many lines
if len(lines) > 100:
lines = lines[:100]
lines.append(f"... (truncated, {files_count} files, {dirs_count} dirs total)")
return {
"structure": "\n".join(lines),
"files_count": files_count,
"dirs_count": dirs_count,
"error": None,
}
except Exception as e:
return {
"structure": "",
"files_count": 0,
"dirs_count": 0,
"error": str(e),
}
def glob_search(
self,
pattern: str,
path: Optional[str] = None,
max_results: int = 50,
) -> Dict[str, Any]:
"""
Search for files by name pattern using pathlib.glob.
Returns list of matching file paths.
"""
# Sanitize path
search_path = self._sanitize_path(path) if path else self.repo_path
if search_path is None:
return {
"files": [],
"count": 0,
"truncated": False,
"error": f"Invalid or unsafe path: {path}",
}
self._log(f"glob_search: {pattern} in {path or '.'}")
try:
search_path_obj = Path(search_path)
# Use rglob for recursive search
files = []
for match in search_path_obj.rglob(pattern):
if match.is_file():
try:
rel_path = str(match.relative_to(self.repo_path))
files.append(rel_path)
except ValueError:
files.append(str(match))
if len(files) >= max_results * 2: # Get extra for filtering
break
# Filter out common non-code directories using configurable skip_dirs
files = [f for f in files if not any(skip in f for skip in self.skip_dirs)]
return {
"files": files[:max_results],
"count": len(files),
"truncated": len(files) > max_results,
"error": None,
}
except Exception as e:
return {
"files": [],
"count": 0,
"truncated": False,
"error": str(e),
}
def analyze_repo_smart(self) -> Dict[str, Any]:
"""
Smart repository analysis:
1. Count files by extension → determine languages
2. Find and parse manifests → determine frameworks/deps
3. Find infrastructure files → understand stack
Returns structured info about the repository.
"""
self._log("analyze_repo_smart: scanning repository...")
# File extension to language mapping
EXT_TO_LANG = {
".py": "Python", ".pyx": "Python", ".pyi": "Python",
".js": "JavaScript", ".jsx": "JavaScript", ".mjs": "JavaScript",
".ts": "TypeScript", ".tsx": "TypeScript",
".go": "Go",
".java": "Java", ".kt": "Kotlin", ".scala": "Scala",
".rs": "Rust",
".c": "C", ".h": "C", ".cpp": "C++", ".hpp": "C++", ".cc": "C++",
".cs": "C#",
".rb": "Ruby",
".php": "PHP",
".swift": "Swift",
".qll": "CodeQL", ".ql": "CodeQL", # Important for CodeQL repos!
}
# Manifest files to look for
MANIFEST_NAMES = {
"package.json", "package-lock.json", "yarn.lock",
"requirements.txt", "pyproject.toml", "Pipfile", "setup.py", "setup.cfg",
"pom.xml", "build.gradle", "build.gradle.kts",
"go.mod", "go.sum",
"Cargo.toml", "Cargo.lock",
"composer.json",
"Gemfile", "Gemfile.lock",
"qlpack.yml", # CodeQL package
}
# Infrastructure files
INFRA_NAMES = {
"Dockerfile", "docker-compose.yml", "docker-compose.yaml",
".gitlab-ci.yml", "Jenkinsfile", "azure-pipelines.yml",
}
lang_stats: Dict[str, int] = {}
manifests_found: List[str] = []
infra_found: List[str] = []
top_level_dirs: List[str] = []
# NEW: Track source files per directory for smart key_directories detection
dir_source_counts: Dict[str, int] = {}
repo_path = Path(self.repo_path)
# Source code extensions (for identifying key source directories)
SOURCE_EXTS = {
".py", ".go", ".java", ".kt", ".scala", # Backend
".js", ".jsx", ".ts", ".tsx", ".mjs", # JavaScript/TypeScript
".c", ".cpp", ".cc", ".h", ".hpp", # C/C++
".cs", # C#
".rs", # Rust
".rb", # Ruby
".php", # PHP
".swift", # Swift
".qll", ".ql", # CodeQL
}
try:
# Get top-level directories
for item in repo_path.iterdir():
if item.is_dir() and item.name not in self.skip_dirs and not item.name.startswith('.'):
top_level_dirs.append(item.name)
dir_source_counts[item.name] = 0
# === UNIFORM SAMPLING: scan each top-level directory equally ===
            # This guarantees we see files from ALL directories, not just the
            # alphabetically first ones. Problem: rglob walks depth-first, so
            # with a global file limit we might never reach some directories.
            FILES_PER_DIR = 2000  # Sample up to N files from each top-level dir
file_count = 0
def process_file(file_path: Path, top_dir: Optional[str] = None) -> None:
"""Process a single file for language stats and manifests."""
nonlocal file_count
file_count += 1
rel_path = str(file_path.relative_to(repo_path))
# Count by extension
ext = file_path.suffix.lower()
if ext in EXT_TO_LANG:
lang = EXT_TO_LANG[ext]
lang_stats[lang] = lang_stats.get(lang, 0) + 1
# NEW: Count source files per top-level directory
if top_dir and ext in SOURCE_EXTS:
dir_source_counts[top_dir] = dir_source_counts.get(top_dir, 0) + 1
# Find manifests
if file_path.name in MANIFEST_NAMES:
manifests_found.append(rel_path)
# Find infra files
if file_path.name in INFRA_NAMES:
infra_found.append(rel_path)
# GitHub Actions
if ".github/workflows" in rel_path and file_path.suffix in (".yml", ".yaml"):
infra_found.append(rel_path)
# First: process root-level files
for item in repo_path.iterdir():
if item.is_file():
process_file(item)
# Then: sample from EACH top-level directory uniformly
for dir_name in sorted(top_level_dirs):
dir_path = repo_path / dir_name
dir_file_count = 0
try:
for file_path in dir_path.rglob("*"):
if dir_file_count >= FILES_PER_DIR:
break
# Skip ignored directories
if any(skip in file_path.parts for skip in self.skip_dirs):
continue
if file_path.is_file():
dir_file_count += 1
process_file(file_path, top_dir=dir_name)
except PermissionError:
continue
# Parse key manifests for dependencies
dependencies: Dict[str, List[str]] = {}
for manifest in manifests_found[:10]: # Limit parsing
manifest_path = repo_path / manifest
try:
content = manifest_path.read_text(encoding="utf-8", errors="ignore")[:10000]
if manifest.endswith("package.json"):
import json
try:
data = json.loads(content)
deps = list((data.get("dependencies") or {}).keys())
deps += list((data.get("devDependencies") or {}).keys())
if deps:
dependencies[manifest] = deps[:20]
except:
pass
elif manifest.endswith("requirements.txt"):
deps = []
for line in content.splitlines()[:50]:
line = line.strip()
if line and not line.startswith("#") and not line.startswith("-"):
pkg = line.split("==")[0].split(">=")[0].split("<=")[0].split("[")[0]
if pkg:
deps.append(pkg.strip())
if deps:
dependencies[manifest] = deps[:20]
elif manifest.endswith("go.mod"):
deps = []
for line in content.splitlines()[:50]:
if line.strip().startswith("require") or "/" in line:
parts = line.strip().split()
if parts and "/" in parts[0]:
deps.append(parts[0])
if deps:
dependencies[manifest] = deps[:20]
elif manifest.endswith("qlpack.yml"):
# CodeQL package - extract name
for line in content.splitlines():
if line.startswith("name:"):
dependencies[manifest] = [line.split(":", 1)[1].strip()]
break
except Exception:
pass
# Determine primary language
primary_lang = None
if lang_stats:
primary_lang = max(lang_stats.items(), key=lambda x: x[1])[0]
# Detect framework from dependencies
framework = None
all_deps = []
for deps in dependencies.values():
all_deps.extend(deps)
# Framework detection rules
if "django" in all_deps or "Django" in all_deps:
framework = "Django"
elif "flask" in all_deps or "Flask" in all_deps:
framework = "Flask"
elif "fastapi" in all_deps or "FastAPI" in all_deps:
framework = "FastAPI"
elif "express" in all_deps:
framework = "Express"
elif "react" in all_deps or "next" in all_deps:
framework = "React"
elif "vue" in all_deps or "nuxt" in all_deps:
framework = "Vue"
elif "gin-gonic/gin" in " ".join(all_deps):
framework = "Gin"
elif "spring" in " ".join(all_deps).lower():
framework = "Spring"
# NEW: Sort directories by source file count (most source files first)
# This ensures key source directories like src/, lib/, core/ are prioritized
source_directories = sorted(
[(d, c) for d, c in dir_source_counts.items() if c > 0],
key=lambda x: -x[1] # Descending by count
)
return {
"languages": lang_stats,
"primary_language": primary_lang,
"framework": framework,
"top_level_dirs": sorted(top_level_dirs),
"manifests": manifests_found,
"dependencies": dependencies,
"infrastructure": infra_found,
"total_files_scanned": file_count,
# NEW: Directories ranked by source file count
"source_directories": source_directories,
"dir_source_counts": dir_source_counts,
}
except Exception as e:
return {
"languages": {},
"primary_language": None,
"framework": None,
"top_level_dirs": [],
"manifests": [],
"dependencies": {},
"infrastructure": [],
"source_directories": [],
"dir_source_counts": {},
"error": str(e),
}
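    # Illustrative return from analyze_repo_smart() (values are examples only):
    #
    #   {"languages": {"Python": 312, "JavaScript": 18}, "primary_language": "Python",
    #    "framework": "FastAPI", "top_level_dirs": ["docs", "src", "tests"],
    #    "manifests": ["pyproject.toml"], "dependencies": {...},
    #    "infrastructure": [], "total_files_scanned": 2350,
    #    "source_directories": [("src", 290), ("tests", 40)],
    #    "dir_source_counts": {"src": 290, "tests": 40, "docs": 0}}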
def execute_action(self, action: SearchAction) -> Dict[str, Any]:
"""Execute a search action and return result."""
tool = action.tool
params = action.params or {} # Handle None params
# Validate params and log warnings for unknown params
warnings = self._validate_action_params(action)
for warning in warnings:
self._log(f" Warning: {warning}")
if tool == "grep":
# Use `or ""` to handle both missing key AND None value
pattern = params.get("pattern") or ""
if not pattern:
return {
"matches": [],
"files": [],
"files_count": 0,
"total_matches": 0,
"truncated": False,
"error": "Pattern is required for grep search",
}
            return self.grep(
                pattern=pattern,
                path=params.get("path"),
                glob=params.get("glob"),
                case_insensitive=bool(params.get("case_insensitive", False)),
                context_lines=int(params.get("context_lines", 3) or 3),
                max_results=int(params.get("max_results", 50) or 50),
                files_only=bool(params.get("files_only", False)),
            )
elif tool == "read_file":
file_path = params.get("file_path") or ""
if not file_path:
return {
"content": "",
"total_lines": 0,
"read_lines": 0,
"error": "file_path is required",
}
            return self.read_file(
                file_path=file_path,
                start_line=params.get("start_line"),
                end_line=params.get("end_line"),
                max_lines=int(params.get("max_lines", 200) or 200),
            )
elif tool == "list_dir":
return self.list_dir(
path=params.get("path") or ".",
max_depth=int(params.get("max_depth", 2) or 2),
)
elif tool == "glob_search":
pattern = params.get("pattern") or "*"
if not pattern:
return {
"files": [],
"count": 0,
"truncated": False,
"error": "Pattern is required for glob search",
}
return self.glob_search(
pattern=pattern,
path=params.get("path"),
)
else:
return {"error": f"Unknown tool: {tool}"}
def _validate_action_params(self, action: SearchAction) -> List[str]:
"""Validate action params and return list of warnings."""
warnings = []
valid_params = {
"grep": {"pattern", "path", "glob", "case_insensitive", "context_lines", "files_only", "max_results"},
"read_file": {"file_path", "start_line", "end_line", "max_lines"},
"list_dir": {"path", "max_depth"},
"glob_search": {"pattern", "path", "max_results"},
}
expected = valid_params.get(action.tool, set())
actual = set((action.params or {}).keys())
unknown = actual - expected
if unknown:
warnings.append(f"Unknown params for {action.tool}: {unknown}")
return warnings
def execute_actions_parallel(
self,
actions: List[SearchAction],
max_workers: int = 3,
) -> List[Dict[str, Any]]:
"""Execute multiple actions in parallel."""
results = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_idx = {
executor.submit(self.execute_action, action): i
for i, action in enumerate(actions)
}
# Collect results in order
results = [None] * len(actions)
for future in as_completed(future_to_idx):
idx = future_to_idx[future]
try:
results[idx] = future.result()
except Exception as e:
results[idx] = {"error": str(e)}
return results
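    # Note: results are returned positionally, matching the input action order,
    # even though futures may complete out of order. Sketch (hypothetical actions):
    #
    #   results = tools.execute_actions_parallel([grep_action, read_action])
    #   results[0] is the grep result, results[1] the read_file result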
# ============================================
# SGR Searcher Implementation
# ============================================
class SGRSearcher(BaseSearcher):
"""
Schema-Guided Reasoning searcher.
Uses structured output to enforce reasoning steps:
1. Analyze question → QuestionAnalysis
2. Create plan → SearchPlan
3. Execute & interpret → SearchIteration[]
4. Generate answer → FinalAnswer
Supports multiple LLM providers:
- OpenAI: GPT-4o, GPT-4o-mini
- Gemini: gemini-2.5-flash-lite, gemini-2.5-flash, gemini-3-pro-preview
"""
# Maximum size for tool result summaries
MAX_TOOL_RESULT_SIZE = 2000
# Maximum findings to keep (prevents memory issues)
MAX_FINDINGS = 100
def __init__(
self,
provider: LLMProvider = LLMProvider.GEMINI,
model: Optional[str] = None,
max_iterations: int = 8,
total_timeout: float = 120.0,
verbose: bool = False,
parallel_initial_actions: bool = True,
max_retries: int = 3,
):
"""
Initialize SGR searcher.
Args:
provider: LLM provider (openai or gemini)
model: Model name (optional, uses provider defaults)
max_iterations: Maximum search iterations
total_timeout: Total timeout in seconds
verbose: Print debug information
parallel_initial_actions: Execute initial plan actions in parallel
max_retries: Maximum retries for LLM calls
"""
self.provider = provider
self.max_iterations = max_iterations
self.total_timeout = total_timeout
self.verbose = verbose
self.parallel_initial_actions = parallel_initial_actions
self.max_retries = max_retries
# Set default model based on provider (using current model names)
if model:
self.model = model
elif provider == LLMProvider.OPENAI:
self.model = "gpt-4o-2024-08-06"
else: # GEMINI
self.model = "gemini-2.5-flash-lite"
# Initialize LLM client
self._llm = None
self._openai_client = None
self._init_llm()
def _init_llm(self) -> None:
"""Initialize LLM based on provider."""
if self.provider == LLMProvider.OPENAI:
from openai import OpenAI
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("OPENAI_API_KEY must be set for OpenAI SGR searcher")
self._openai_client = OpenAI(api_key=api_key)
elif self.provider == LLMProvider.GEMINI:
from langchain_google_genai import ChatGoogleGenerativeAI
# Try different API key environment variables
api_key = (
os.getenv("GOOGLE_API_KEY") or
os.getenv("GEMINI_API_KEY") or
os.getenv("AI_STUDIO") or
os.getenv("VERTEX_AI_API_KEY")
)
if not api_key:
raise ValueError(
"GOOGLE_API_KEY, GEMINI_API_KEY, AI_STUDIO or VERTEX_AI_API_KEY must be set. "
"Get a key from https://aistudio.google.com/apikey"
)
self._llm = ChatGoogleGenerativeAI(
model=self.model,
google_api_key=api_key,
max_output_tokens=4096,
convert_system_message_to_human=True,
)
else:
raise ValueError(f"Unknown provider: {self.provider}")
@property
def name(self) -> str:
return f"SGRSearcher ({self.provider.value}/{self.model})"
def _log(self, message: str) -> None:
if self.verbose:
print(f"[SGR] {message}")
def _call_structured(
self,
messages: List[Dict[str, str]],
response_format: Type[BaseModel],
) -> BaseModel:
"""Call LLM with structured output based on provider with retry."""
return self._call_structured_with_retry(messages, response_format)
@retry_with_backoff(max_retries=3, initial_delay=1.0, backoff_factor=2.0)
def _call_structured_with_retry(
self,
messages: List[Dict[str, str]],
response_format: Type[BaseModel],
) -> BaseModel:
"""Call LLM with retry logic."""
if self.provider == LLMProvider.OPENAI:
return self._call_structured_openai(messages, response_format)
else: # GEMINI
return self._call_structured_gemini(messages, response_format)
def _call_structured_openai(
self,
messages: List[Dict[str, str]],
response_format: Type[BaseModel],
) -> BaseModel:
"""Call OpenAI with structured output."""
start_time = time.time()
response = self._openai_client.beta.chat.completions.parse(
model=self.model,
messages=messages,
response_format=response_format,
)
elapsed_ms = (time.time() - start_time) * 1000
self._log(f" OpenAI call took {elapsed_ms:.0f}ms")
message = response.choices[0].message
# Handle refusal (problem 3)
if hasattr(message, 'refusal') and message.refusal:
raise ValueError(f"OpenAI refused to generate response: {message.refusal}")
if message.parsed is None:
raise ValueError(
f"OpenAI returned None parsed result for {response_format.__name__}. "
f"Content: {message.content[:200] if message.content else 'empty'}"
)
return message.parsed
def _call_structured_gemini(
self,
messages: List[Dict[str, str]],
response_format: Type[BaseModel],
) -> BaseModel:
"""Call Gemini with structured output using LangChain."""
from langchain_core.messages import HumanMessage, AIMessage
start_time = time.time()
# Convert messages to LangChain format
lc_messages = []
system_content = None
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
if role == "system":
system_content = content
elif role == "user":
if system_content:
content = f"Instructions:\n{system_content}\n\nUser query:\n{content}"
system_content = None
lc_messages.append(HumanMessage(content=content))
elif role == "assistant":
lc_messages.append(AIMessage(content=content))
# Create structured LLM
structured_llm = self._llm.with_structured_output(response_format)
# Invoke and return
result = structured_llm.invoke(lc_messages)
elapsed_ms = (time.time() - start_time) * 1000
self._log(f" Gemini call took {elapsed_ms:.0f}ms")
if result is None:
raise ValueError(f"Gemini returned None for {response_format.__name__}")
return result
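    # Illustrative message folding for the Gemini path above: a system message
    # is merged into the next user message rather than sent separately.
    #
    #   [{"role": "system", "content": "Be terse."},
    #    {"role": "user",   "content": "Where is auth?"}]
    #   -> HumanMessage("Instructions:\nBe terse.\n\nUser query:\nWhere is auth?")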
def _call_llm_text(
self,
messages: List[Dict[str, str]],
) -> str:
"""
Call LLM with plain text output (not structured).
Used for simple tasks like path correction.
"""
start_time = time.time()
if self.provider == LLMProvider.OPENAI:
response = self._openai_client.chat.completions.create(
model=self.model,
messages=messages,
)
result = response.choices[0].message.content or ""
else: # GEMINI
from langchain_core.messages import HumanMessage, AIMessage
# Convert messages to LangChain format
lc_messages = []
system_content = None
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
if role == "system":
system_content = content
elif role == "user":
if system_content:
content = f"Instructions:\n{system_content}\n\nUser query:\n{content}"
system_content = None
lc_messages.append(HumanMessage(content=content))
elif role == "assistant":
lc_messages.append(AIMessage(content=content))
# Use regular invoke (not structured)
response = self._llm.invoke(lc_messages)
result = response.content if hasattr(response, 'content') else str(response)
elapsed_ms = (time.time() - start_time) * 1000
self._log(f" LLM text call took {elapsed_ms:.0f}ms")
return result
# ========================================
# STEP 0: Analyze Repository Context
# ========================================
def _analyze_repo_context(
self,
repo_path: str,
tools: 'SGRTools',
) -> RepoContext:
"""
Step 0: Analyze repository structure to understand context.
Uses SMART analysis:
- Count files by extension → determine languages
- Parse manifests → determine frameworks
- Find infrastructure → understand stack
"""
self._log("Step 0: Analyzing repository context (smart)...")
# === SMART ANALYSIS ===
smart_info = tools.analyze_repo_smart()
self._log(f" Languages found: {smart_info.get('languages', {})}")
self._log(f" Primary (detected): {smart_info.get('primary_language')}")
self._log(f" Framework (detected): {smart_info.get('framework')}")
self._log(f" Top dirs: {smart_info.get('top_level_dirs', [])[:10]}")
self._log(f" Manifests: {smart_info.get('manifests', [])[:5]}")
# Format info for LLM
lang_summary = ", ".join(
f"{lang}: {count} files"
for lang, count in sorted(
smart_info.get("languages", {}).items(),
key=lambda x: -x[1]
)[:5]
)
deps_summary = []
for manifest, deps in smart_info.get("dependencies", {}).items():
deps_summary.append(f"{manifest}: {', '.join(deps[:10])}")
# NEW: Get source directories ranked by file count
source_directories = smart_info.get("source_directories", [])
source_dirs_summary = "\n".join(
f" {d}/: {c} source files"
for d, c in source_directories[:10]
) if source_directories else " (none detected)"
# NEW: Show ALL top-level directories with their source file counts
dir_source_counts = smart_info.get("dir_source_counts", {})
all_dirs_summary = ", ".join(
f"{d}({dir_source_counts.get(d, 0)})"
for d in smart_info.get('top_level_dirs', [])[:20]
)
# Get directory structure for LLM (but don't rely on it for key_directories)
dir_result = tools.list_dir(".", max_depth=1) # Reduced depth - just show structure
directory_listing = dir_result.get("structure", "")
messages = [
{"role": "system", "content": prompts.REPO_CONTEXT_SYSTEM},
{"role": "user", "content": prompts.REPO_CONTEXT_USER.format(
repo_path=repo_path,
directory_listing=directory_listing[:1500],
sample_files=f"""
=== DETECTED LANGUAGES (by file count) ===
{lang_summary}
=== TOP-LEVEL DIRECTORIES (with source file counts) ===
{all_dirs_summary}
=== KEY SOURCE DIRECTORIES (ranked by source file count) ===
{source_dirs_summary}
=== MANIFEST FILES FOUND ===
{chr(10).join(smart_info.get('manifests', [])[:10])}
=== DEPENDENCIES FROM MANIFESTS ===
{chr(10).join(deps_summary[:5])}
=== INFRASTRUCTURE FILES ===
{chr(10).join(smart_info.get('infrastructure', [])[:5])}
""",
)},
]
result = self._call_structured(messages, RepoContext)
# Override with detected values if LLM got it wrong
detected_lang = smart_info.get("primary_language")
detected_framework = smart_info.get("framework")
# Use detected language directly (no mapping - preserve exact name like "CodeQL", "C++")
if detected_lang and result.primary_language.lower() != detected_lang.lower():
self._log(f" [Override] Language: {result.primary_language} → {detected_lang}")
result.primary_language = detected_lang
if detected_framework and not result.framework:
self._log(f" [Override] Framework: None → {detected_framework}")
result.framework = detected_framework
# ========================================
# NEW: Smart key_directories detection
# ========================================
# CRITICAL: Don't trust LLM for key_directories - use source file counts!
# This fixes the problem where LLM picks cmake/, debian/ instead of src/
smart_key_dirs = []
# Step 1: PRIORITY - Check for well-known source directory names FIRST
# These directories almost always contain source code
PRIORITY_DIR_NAMES = {'src', 'lib', 'core', 'pkg', 'app', 'internal', 'cmd', 'api', 'source'}
all_top_dirs = smart_info.get('top_level_dirs', [])
for dir_name in all_top_dirs:
if dir_name.lower() in PRIORITY_DIR_NAMES:
dir_path = f"{dir_name}/"
source_count = dir_source_counts.get(dir_name, 0)
# Include if has ANY source files (not just >= 10)
if source_count > 0 and dir_path not in smart_key_dirs:
smart_key_dirs.insert(0, dir_path)
self._log(f" [Smart] Priority dir: {dir_name}/ ({source_count} source files)")
# Also include even if 0 source files detected (sampling might have missed)
elif dir_path not in smart_key_dirs:
# Check if directory actually exists and has files
test_path = Path(repo_path) / dir_name
if test_path.is_dir():
smart_key_dirs.insert(0, dir_path)
self._log(f" [Smart] Priority dir (exists): {dir_name}/")
# Step 2: Add directories with most source files
for dir_name, source_count in source_directories[:8]: # Top 8 by source count
dir_path = f"{dir_name}/"
if source_count >= 5 and dir_path not in smart_key_dirs: # Lowered threshold from 10 to 5
smart_key_dirs.append(dir_path)
# Step 3: Merge with LLM's suggestions (if they're actually source directories)
llm_key_dirs = result.key_directories or []
for llm_dir in llm_key_dirs:
dir_name = llm_dir.rstrip('/')
# Only add if it has source files and not already included
if dir_source_counts.get(dir_name, 0) > 0 and llm_dir not in smart_key_dirs:
smart_key_dirs.append(llm_dir)
# Step 4: Fallback - if no source directories found, use top dirs with any source files
if not smart_key_dirs:
top_dirs = smart_info.get("top_level_dirs", [])
smart_key_dirs = [
f"{d}/" for d in top_dirs[:8]
if dir_source_counts.get(d, 0) > 0
]
# Step 5: Last resort - just use top directories that might be source
if not smart_key_dirs:
# Exclude obviously non-source directories
NON_SOURCE_DIRS = {
'cmake', 'cmake_templates', 'debian', 'doc', 'docs', 'documentation',
'images', 'img', 'assets', 'resources', 'i18n', 'locale', 'locales',
'scripts', 'tools', 'util', 'utils', 'config', 'configs',
'test', 'tests', 'spec', 'specs', 'examples', 'example',
'build', 'dist', 'out', 'target', 'bin', 'obj',
'node_modules', 'vendor', 'third_party', 'external', 'deps',
'linux', 'windows', 'ms-windows', 'macos', 'darwin',
}
smart_key_dirs = [
f"{d}/" for d in smart_info.get("top_level_dirs", [])[:15]
if d.lower() not in NON_SOURCE_DIRS
][:6]
result.key_directories = smart_key_dirs[:10] # Limit to 10
self._log(f" [Smart] Key dirs by source count: {smart_key_dirs[:5]}...")
# Store full language distribution (not just primary)
detected_languages = smart_info.get("languages", {})
if detected_languages:
result.languages = detected_languages
self._log(f" Final Language: {result.primary_language}")
self._log(f" Final Languages: {result.languages}")
self._log(f" Final Framework: {result.framework}")
self._log(f" Final Type: {result.project_type}")
self._log(f" Final Key dirs: {result.key_directories}")
return result
# ========================================
# STEP 1: Analyze Question (concepts only)
# ========================================
def _analyze_question(self, query: str, repo_hint: str) -> QuestionAnalysis:
"""
Step 1: Analyze the question to extract HIGH-LEVEL CONCEPTS.
This step does NOT generate grep patterns - that's step 2.
Focus on understanding WHAT the user wants to find.
"""
self._log("Step 1: Analyzing question concepts...")
messages = [
{"role": "system", "content": prompts.QUESTION_ANALYSIS_SYSTEM},
{"role": "user", "content": prompts.QUESTION_ANALYSIS_USER.format(
repo_hint=repo_hint, query=query
)},
]
result = self._call_structured(messages, QuestionAnalysis)
# Validate that we got useful concepts
if not result.concepts:
self._log(" Warning: LLM returned empty concepts, using fallback")
words = [w for w in query.split() if len(w) > 3 and w.isalpha()]
result.concepts = words[:5] if words else ["implementation"]
self._log(f" Concepts: {result.concepts}")
self._log(f" User problem: {result.user_problem}")
self._log(f" What to find: {result.what_to_find}")
self._log(f" Question type: {result.question_type}")
return result
# ========================================
# Helper: Extract findings from file (NEW)
# ========================================
def _read_and_extract_findings(
self,
tools: 'SGRTools',
file_path: str,
) -> List[str]:
"""
Read a file and extract key definitions (class, function, etc.).
Returns list of finding strings.
"""
findings = []
result = tools.read_file(file_path)
if result.get("error"):
self._log(f" [Skip] Could not read {file_path}: {result.get('error')}")
return findings
content = result.get("content", "")
if not content:
return findings
extension = Path(file_path).suffix.lower()
# Language-specific patterns for key definitions
patterns = {
".py": [
(r"^class\s+(\w+)", "class"),
(r"^def\s+(\w+)", "function"),
(r"^\s{4}def\s+(\w+)", "method"),
],
".go": [
(r"^type\s+(\w+)\s+struct", "struct"),
(r"^type\s+(\w+)\s+interface", "interface"),
(r"^func\s+(\w+)\(", "function"),
(r"^func\s+\([^)]+\)\s+(\w+)\(", "method"),
],
".ts": [
(r"^(?:export\s+)?class\s+(\w+)", "class"),
(r"^(?:export\s+)?interface\s+(\w+)", "interface"),
(r"^(?:export\s+)?(?:async\s+)?function\s+(\w+)", "function"),
(r"^(?:export\s+)?type\s+(\w+)", "type"),
],
".js": [
(r"^class\s+(\w+)", "class"),
(r"^(?:async\s+)?function\s+(\w+)", "function"),
],
".java": [
(r"^(?:public|private|protected)?\s*(?:abstract\s+)?class\s+(\w+)", "class"),
(r"^(?:public|private|protected)?\s*interface\s+(\w+)", "interface"),
],
".qll": [
(r"^(?:private\s+)?class\s+(\w+)", "class"),
(r"^(?:private\s+)?predicate\s+(\w+)", "predicate"),
(r"^module\s+(\w+)", "module"),
],
".h": [
(r"^class\s+(\w+)", "class"),
(r"^struct\s+(\w+)", "struct"),
],
".cpp": [
(r"^(?:\w+\s+)+(\w+)::\w+\s*\(", "method"),
],
}
file_patterns = patterns.get(extension, [
(r"^class\s+(\w+)", "class"),
(r"^(?:def|func|function)\s+(\w+)", "function"),
])
# Remove line numbers from content (format: " 123|code")
lines = []
for line in content.split('\n')[:500]:
if '|' in line:
parts = line.split('|', 1)
if len(parts) == 2 and parts[0].strip().isdigit():
lines.append(parts[1])
else:
lines.append(line)
else:
lines.append(line)
for line_num, line in enumerate(lines, 1):
for pattern, kind in file_patterns:
match = re.match(pattern, line)
if match:
name = match.group(1)
findings.append(f"Found {kind} '{name}' in {file_path}:{line_num}")
self._log(f" [Read] {file_path}: {len(findings)} definitions found")
return findings[:30] # Limit
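    # Illustrative finding strings produced above (see the f-string in the
    # match loop; paths and names are examples only):
    #
    #   "Found class 'UserService' in src/services/user.py:12"
    #   "Found function 'login' in src/services/user.py:48"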
# ========================================
# STEP 2: Extract Technical Terms
# ========================================
def _get_naming_guide(self, language: str) -> str:
"""Get language-specific naming conventions guide."""
# Case-insensitive lookup (language can be "Python", "python", "PYTHON", etc.)
lang_lower = language.lower() if language else ""
guides = {
"python": prompts.NAMING_GUIDE_PYTHON,
"go": prompts.NAMING_GUIDE_GO,
"java": prompts.NAMING_GUIDE_JAVA,
"codeql": prompts.NAMING_GUIDE_DEFAULT, # CodeQL uses similar conventions
}
return guides.get(lang_lower, prompts.NAMING_GUIDE_DEFAULT)
def _extract_technical_terms(
self,
query: str,
repo_context: RepoContext,
analysis: QuestionAnalysis,
tools: 'SGRTools',
) -> TechnicalTerms:
"""
Step 2: Map concepts to actual technical identifiers.
IMPORTANT: First explores actual directory structure, then generates terms
based on REAL file/folder names, not invented ones.
"""
self._log("Step 2: Extracting technical terms...")
# === EXPLORATION: Get ACTUAL directory structure ===
self._log(" Exploring actual directory structure...")
# Get top-level structure
top_level = tools.list_dir(".", max_depth=2)
directory_structure = top_level.get("structure", "")[:2000]
# Explore key directories mentioned in repo context
files_listing_parts = []
for key_dir in repo_context.key_directories[:4]: # Limit to 4 dirs
dir_result = tools.list_dir(key_dir, max_depth=2)
if dir_result.get("structure"):
files_listing_parts.append(f"=== {key_dir} ===\n{dir_result['structure'][:800]}")
files_listing = "\n\n".join(files_listing_parts)[:3000]
self._log(f" Explored {len(files_listing_parts)} key directories")
# === Generate terms based on ACTUAL structure ===
language = repo_context.primary_language
framework = repo_context.framework or "unknown"
naming_guide = self._get_naming_guide(language)
system_prompt = prompts.TECHNICAL_TERMS_SYSTEM.format(
naming_guide=naming_guide,
)
# Format languages distribution
languages_dist = repo_context.languages or {}
languages_str = ", ".join(
f"{lang}: {count} files"
for lang, count in sorted(languages_dist.items(), key=lambda x: -x[1])[:5]
) if languages_dist else language
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompts.TECHNICAL_TERMS_USER.format(
repo_hint=framework,
language=language,
languages_distribution=languages_str,
framework=framework,
concepts=", ".join(analysis.concepts),
what_to_find=analysis.what_to_find,
directory_structure=directory_structure,
files_listing=files_listing,
)},
]
result = self._call_structured(messages, TechnicalTerms)
# Log results
primary = [t.term for t in result.primary_terms]
secondary = [t.term for t in result.secondary_terms]
self._log(f" Primary terms: {primary}")
self._log(f" Secondary terms: {secondary}")
self._log(f" Likely files: {result.likely_files}")
return result
# ========================================
# STEP 3: Create Search Plan
# ========================================
def _create_plan(
self,
query: str,
analysis: QuestionAnalysis,
technical_terms: Optional[TechnicalTerms] = None,
) -> SearchPlan:
"""
Step 3: Create a search plan using technical terms.
Uses TypedSearchPlan (with enforced required params) first,
falls back to legacy SearchPlan with smart fallback if needed.
"""
self._log("Step 3: Creating search plan...")
# Use technical terms if available, otherwise fallback to concepts
if technical_terms:
primary_terms = [t.term for t in technical_terms.primary_terms]
secondary_terms = [t.term for t in technical_terms.secondary_terms]
likely_files = technical_terms.likely_files
search_patterns = technical_terms.search_patterns
else:
# Fallback - extract terms from concepts
primary_terms = analysis.concepts[:5]
secondary_terms = []
likely_files = ["src/"]
search_patterns = []
# Try TypedSearchPlan first (enforces required params)
try:
result = self._create_typed_plan(
query=query,
primary_terms=primary_terms,
secondary_terms=secondary_terms,
search_patterns=search_patterns,
likely_files=likely_files,
concepts=analysis.concepts,
question_type=analysis.question_type,
)
if result and result.get_total_actions() > 0:
# Convert to legacy SearchPlan for execution
legacy_plan = SearchPlan.from_typed(result)
self._log(f" Strategy: {legacy_plan.strategy}")
self._log(f" Actions: {len(legacy_plan.actions)} (from TypedSearchPlan)")
for i, action in enumerate(legacy_plan.actions):
pattern = action.params.get('pattern', action.params.get('file_path', 'N/A'))
self._log(f" {i+1}. {action.tool}: {pattern} - {action.purpose[:50]}")
return legacy_plan
except Exception as e:
self._log(f" [Warning] TypedSearchPlan failed: {e}, using legacy fallback")
# Fallback to legacy SearchPlan with smart param filling
return self._create_legacy_plan_with_fallback(
query=query,
primary_terms=primary_terms,
secondary_terms=secondary_terms,
search_patterns=search_patterns,
likely_files=likely_files,
concepts=analysis.concepts,
question_type=analysis.question_type,
)
def _create_typed_plan(
self,
query: str,
primary_terms: List[str],
secondary_terms: List[str],
search_patterns: List[str],
likely_files: List[str],
concepts: List[str],
question_type: str,
) -> TypedSearchPlan:
"""
Create a TypedSearchPlan using the new schema with enforced required params.
TypedSearchPlan has separate lists for each action type, and each type
enforces its required fields (e.g., GrepAction requires 'pattern').
"""
messages = [
{"role": "system", "content": prompts.SEARCH_PLAN_SYSTEM},
{"role": "user", "content": prompts.SEARCH_PLAN_USER.format(
query=query,
primary_terms=", ".join(primary_terms),
secondary_terms=", ".join(secondary_terms),
search_patterns=", ".join(search_patterns),
likely_files=", ".join(likely_files) if likely_files else "(explore with list_dir)",
concepts=", ".join(concepts),
question_type=question_type,
)},
]
result = self._call_structured(messages, TypedSearchPlan)
# Validate that we got some actions
total = result.get_total_actions()
if total == 0:
self._log(" [Warning] TypedSearchPlan returned 0 actions, will use fallback")
raise ValueError("TypedSearchPlan returned 0 actions")
self._log(f" TypedSearchPlan: {len(result.grep_actions)} grep, "
f"{len(result.read_file_actions)} read, "
f"{len(result.list_dir_actions)} list, "
f"{len(result.glob_search_actions)} glob")
return result
def _create_legacy_plan_with_fallback(
self,
query: str,
primary_terms: List[str],
secondary_terms: List[str],
search_patterns: List[str],
likely_files: List[str],
concepts: List[str],
question_type: str,
) -> SearchPlan:
"""
Create a legacy SearchPlan with smart fallback for missing params.
This is the fallback when TypedSearchPlan fails or returns invalid results.
"""
self._log(" Using legacy SearchPlan with fallback...")
messages = [
{"role": "system", "content": prompts.SEARCH_PLAN_SYSTEM_LEGACY},
{"role": "user", "content": prompts.SEARCH_PLAN_USER_LEGACY.format(
query=query,
primary_terms=", ".join(primary_terms),
secondary_terms=", ".join(secondary_terms),
search_patterns=", ".join(search_patterns),
likely_files=", ".join(likely_files),
concepts=", ".join(concepts),
question_type=question_type,
)},
]
result = self._call_structured(messages, SearchPlan)
# SMART FALLBACK: Fill missing params intelligently
all_terms = primary_terms + secondary_terms
term_index = 0
for action in result.actions:
if action.tool == "grep":
if not action.params:
action.params = {}
if not action.params.get("pattern"):
# Try to extract pattern from purpose
pattern = self._extract_pattern_from_purpose(action.purpose, all_terms)
if pattern:
action.params["pattern"] = pattern
self._log(f" [Fallback] Extracted pattern from purpose: {pattern}")
elif term_index < len(all_terms):
# Use next available term
action.params["pattern"] = all_terms[term_index]
self._log(f" [Fallback] Set grep pattern to: {all_terms[term_index]}")
term_index += 1
else:
# Combine remaining primary terms
action.params["pattern"] = "|".join(primary_terms[:3]) if primary_terms else ".*"
self._log(f" [Fallback] Set grep pattern to combined: {action.params['pattern']}")
elif action.tool == "read_file":
if not action.params:
action.params = {}
if not action.params.get("file_path"):
# Try to extract file path from purpose or likely_files
file_path = self._extract_file_from_purpose(action.purpose, likely_files)
if file_path:
action.params["file_path"] = file_path
self._log(f" [Fallback] Extracted file_path from purpose: {file_path}")
elif likely_files:
# Use first likely file that looks like a file (not dir)
for lf in likely_files:
if '.' in lf.split('/')[-1]: # Has extension
action.params["file_path"] = lf
self._log(f" [Fallback] Set file_path to: {lf}")
break
elif action.tool == "glob_search":
if not action.params:
action.params = {}
if not action.params.get("pattern"):
# Try to extract glob pattern from purpose
pattern = self._extract_glob_from_purpose(action.purpose)
if pattern:
action.params["pattern"] = pattern
self._log(f" [Fallback] Extracted glob pattern: {pattern}")
else:
# Default to searching for common file types
action.params["pattern"] = "*.py" if "python" in query.lower() else "*"
self._log(f" [Fallback] Set glob pattern to: {action.params['pattern']}")
self._log(f" Strategy: {result.strategy}")
self._log(f" Actions: {len(result.actions)}")
for i, action in enumerate(result.actions):
pattern = action.params.get('pattern', action.params.get('file_path', 'N/A'))
self._log(f" {i+1}. {action.tool}: {pattern} - {action.purpose[:50]}")
return result
def _extract_pattern_from_purpose(self, purpose: str, available_terms: List[str]) -> Optional[str]:
"""Extract grep pattern from action purpose using NLP heuristics."""
purpose_lower = purpose.lower()
# Look for quoted terms in purpose
quoted = re.findall(r"['\"]([^'\"]+)['\"]", purpose)
if quoted:
return quoted[0]
# Look for terms after common phrases
for phrase in ["search for", "find", "look for", "locate", "grep for"]:
if phrase in purpose_lower:
idx = purpose_lower.index(phrase) + len(phrase)
rest = purpose[idx:].strip()
# Take first word or term
match = re.match(r"(\w+)", rest)
if match:
term = match.group(1)
# Verify it's meaningful (not just "the", "a", etc.)
if len(term) > 2 and term.lower() not in ["the", "and", "for", "how"]:
return term
# Look for available terms mentioned in purpose
for term in available_terms:
if term.lower() in purpose_lower:
return term
# Look for PascalCase or snake_case patterns
pascal = re.search(r"\b([A-Z][a-z]+(?:[A-Z][a-z]+)+)\b", purpose)
if pascal:
return pascal.group(1)
snake = re.search(r"\b([a-z]+_[a-z_]+)\b", purpose)
if snake:
return snake.group(1)
return None
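    # Illustrative extractions (hypothetical purposes):
    #
    #   _extract_pattern_from_purpose("Search for 'UserAuth' usages", [])  # -> "UserAuth"
    #   _extract_pattern_from_purpose("locate parse_config helper", [])    # -> "parse_config"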
def _extract_file_from_purpose(self, purpose: str, likely_files: List[str]) -> Optional[str]:
"""Extract file path from action purpose."""
# Look for file paths in purpose
file_match = re.search(r"([a-zA-Z0-9_/]+\.[a-z]{1,5})", purpose)
if file_match:
return file_match.group(1)
# Look for likely files mentioned
for lf in likely_files:
if lf in purpose:
return lf
# Check if file name part is mentioned
name = lf.split('/')[-1]
if name in purpose:
return lf
return None
def _extract_glob_from_purpose(self, purpose: str) -> Optional[str]:
"""Extract glob pattern from action purpose."""
# Look for glob patterns
glob_match = re.search(r"(\*[a-zA-Z0-9_.*]+)", purpose)
if glob_match:
return glob_match.group(1)
# Look for file extensions
ext_match = re.search(r"\.([a-z]{1,5})\b", purpose)
if ext_match:
return f"*.{ext_match.group(1)}"
# Common patterns
purpose_lower = purpose.lower()
if "python" in purpose_lower or ".py" in purpose_lower:
return "*.py"
if "typescript" in purpose_lower or ".ts" in purpose_lower:
return "*.ts"
if "javascript" in purpose_lower or ".js" in purpose_lower:
return "*.js"
if "go" in purpose_lower or ".go" in purpose_lower:
return "*.go"
if "java" in purpose_lower:
return "*.java"
return None
def _validate_plan_file_paths(
self,
tools: 'SGRTools',
plan: SearchPlan,
) -> SearchPlan:
"""
Validate file paths in read_file actions.
Uses rich validation to provide hints and attempt LLM correction
instead of just removing invalid paths.
"""
valid_actions = []
invalid_results = []
invalid_action_indices = []
for i, action in enumerate(plan.actions):
if action.tool == "read_file":
file_path = action.params.get("file_path", "")
if file_path:
result = tools.path_validator.validate_path_rich(file_path)
if result.is_valid and result.fixed_path:
# Path is valid (possibly with case fix)
if result.fixed_path != file_path:
self._log(f" [Path fix] {file_path} → {result.fixed_path}")
action.params["file_path"] = result.fixed_path
valid_actions.append(action)
elif result.is_directory and result.files_in_dir:
# It's a directory - log hint and convert to list_dir or pick a file
self._log(f" [Path validation] {file_path} is a DIRECTORY with files: {result.files_in_dir[:3]}")
# Try to find relevant file in directory
suggested_file = self._pick_relevant_file_from_dir(
result.files_in_dir, action.purpose
)
if suggested_file:
self._log(f" [Path fix] Using file from directory: {suggested_file}")
action.params["file_path"] = suggested_file
valid_actions.append(action)
else:
invalid_results.append(result)
invalid_action_indices.append(i)
else:
# Path is invalid - collect for LLM correction
self._log(f" [Path validation] Invalid: {result.get_hint_for_llm()}")
invalid_results.append(result)
invalid_action_indices.append(i)
else:
# No file_path - skip
pass
else:
valid_actions.append(action)
# If we have invalid paths, try to get LLM to correct them
if invalid_results:
self._log(f" [Path validation] {len(invalid_results)} invalid paths, asking LLM for correction...")
corrected_actions = self._ask_llm_to_fix_paths(invalid_results, plan.actions, invalid_action_indices)
for corrected in corrected_actions:
# Validate the corrected path too
result = tools.path_validator.validate_path_rich(corrected.params.get("file_path", ""))
if result.is_valid and result.fixed_path:
corrected.params["file_path"] = result.fixed_path
valid_actions.append(corrected)
self._log(f" [Path corrected] LLM fixed path to: {result.fixed_path}")
plan.actions = valid_actions
return plan
def _pick_relevant_file_from_dir(self, files: List[str], purpose: str) -> Optional[str]:
"""
Pick the most relevant file from a directory based on action purpose.
"""
if not files:
return None
purpose_lower = purpose.lower()
# Priority patterns
priority_patterns = [
# Entry points / main files
("__init__", 0.9),
("index", 0.8),
("main", 0.8),
("mod.rs", 0.8), # Rust
# Config
("config", 0.7),
("settings", 0.7),
]
# Look for keywords from purpose in filenames
best_file = None
best_score = 0.0
for f in files:
filename = f.split('/')[-1].lower()
score = 0.0
# Check priority patterns
for pattern, pattern_score in priority_patterns:
if pattern in filename:
score = max(score, pattern_score)
# Check if filename words appear in purpose
name_parts = filename.replace('_', ' ').replace('.', ' ').split()
for part in name_parts:
if len(part) > 2 and part in purpose_lower:
score = max(score, 0.6)
if score > best_score:
best_score = score
best_file = f
return best_file if best_score > 0.5 else (files[0] if files else None)
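    # Illustrative pick (hypothetical inputs): "__init__" hits a priority
    # pattern (score 0.9), beating the unrelated sibling file:
    #
    #   _pick_relevant_file_from_dir(["pkg/utils.py", "pkg/__init__.py"],
    #                                "inspect package entry point")
    #   # -> "pkg/__init__.py"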
def _ask_llm_to_fix_paths(
self,
invalid_results: List, # List[PathValidationResult]
original_actions: List[SearchAction],
invalid_indices: List[int],
) -> List[SearchAction]:
"""
Ask LLM to correct invalid paths using hints.
"""
if not invalid_results:
return []
# Build hints for LLM
hints = self._path_validator_hints(invalid_results)
# Build context about what we're looking for
context_parts = []
for idx in invalid_indices:
if idx < len(original_actions):
action = original_actions[idx]
context_parts.append(f"- Purpose: {action.purpose}")
context_parts.append(f" Original path: {action.params.get('file_path', 'N/A')}")
context = "\n".join(context_parts)
prompt = f"""You provided file paths that don't exist in the repository.
{hints}
Context for what you were trying to find:
{context}
Please provide corrected file paths. Return ONLY the corrected paths, one per line.
Format: original_path -> corrected_path
If you cannot determine the correct path, respond with: original_path -> SKIP
"""
try:
messages = [
{"role": "system", "content": "You are helping to fix file paths. Be precise and only return paths that match the hints provided."},
{"role": "user", "content": prompt},
]
response = self._call_llm_text(messages)
# Parse response
corrected_actions = []
for line in response.strip().split('\n'):
if '->' in line:
parts = line.split('->')
if len(parts) == 2:
original = parts[0].strip()
corrected = parts[1].strip()
if corrected and corrected != "SKIP":
# Find the original action
for idx in invalid_indices:
if idx < len(original_actions):
action = original_actions[idx]
if action.params.get("file_path", "").strip() == original:
new_action = SearchAction(
tool="read_file",
params={"file_path": corrected},
purpose=action.purpose,
expected_result=action.expected_result,
)
corrected_actions.append(new_action)
break
return corrected_actions
except Exception as e:
self._log(f" [Path correction] LLM call failed: {e}")
return []
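# The parser above expects plain "original -> corrected" lines from the model,
# e.g. (hypothetical output):
#   src/servr/main.py -> src/server/main.py
#   docs/setup.md -> SKIP
# The first line yields a corrected read_file action (re-validated by the caller);
# SKIP lines and anything without '->' are dropped.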
def _validate_next_action_path(
self,
tools: 'SGRTools',
action: SearchAction,
) -> tuple[Optional[SearchAction], Optional[str]]:
"""
Validate file path in a single read_file action.
Returns:
Tuple of (action_or_None, hint_for_llm_or_None)
- If valid: (action with fixed path, None)
- If invalid: (None, hint string for LLM)
"""
if action.tool != "read_file":
return (action, None)
file_path = action.params.get("file_path", "")
if not file_path:
return (None, "No file_path provided")
result = tools.path_validator.validate_path_rich(file_path)
if result.is_valid and result.fixed_path:
if result.fixed_path != file_path:
self._log(f" [Path fix] {file_path} → {result.fixed_path}")
action.params["file_path"] = result.fixed_path
return (action, None)
# Invalid path - return hint for LLM
hint = result.get_hint_for_llm()
self._log(f" [Path validation] {hint}")
# Special case: it's a directory with files
if result.is_directory and result.files_in_dir:
suggested = self._pick_relevant_file_from_dir(result.files_in_dir, action.purpose)
if suggested:
self._log(f" [Path fix] Auto-selected file from directory: {suggested}")
action.params["file_path"] = suggested
return (action, None)
return (None, hint)
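# Return contract illustration (hypothetical paths):
#   read_file("src/Main.py") when "src/main.py" exists -> (action with fixed path, None)
#   read_file("src/") with files inside                -> (action pointing at a picked file, None)
#   read_file("src/missing.py")                        -> (None, hint text for the LLM)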
@staticmethod
def _path_validator_hints(invalid_results) -> str:
"""Generate hints summary from PathValidationResult list."""
lines = ["The following paths need correction:"]
for result in invalid_results:
lines.append(f"\n- {result.get_hint_for_llm()}")
return "\n".join(lines)
# ========================================
# STEP 4: Interpret Results
# ========================================
def _restore_full_paths_in_findings(
self,
findings: List[str],
result: Dict[str, Any],
tool: str,
) -> List[str]:
"""
Restore full file paths in findings that LLM may have shortened.
Problem: LLM often extracts just 'file.py' instead of 'src/core/file.py'
Solution: Match against known full paths from the original result
"""
if tool != "grep" or not result.get("matches"):
return findings
# Build mapping: filename -> full_path
filename_to_full_path: Dict[str, str] = {}
for match in result.get("matches", []):
full_path = match.get("file", "")
if full_path:
filename = Path(full_path).name
# Prefer longer paths (more specific)
if filename not in filename_to_full_path or len(full_path) > len(filename_to_full_path[filename]):
filename_to_full_path[filename] = full_path
if not filename_to_full_path:
return findings
restored = []
for finding in findings:
restored_finding = finding
# Try to expand short filenames to full paths
for filename, full_path in filename_to_full_path.items():
# Check if finding contains just the filename (not already full path)
# Pattern: 'filename' or filename: at start, or 'filename':
if full_path not in finding:
# Replace patterns like 'session-manager.js': or session-manager.js:
# Pattern matches: 'filename': or filename:
pattern = rf"['\"]?{re.escape(filename)}['\"]?:"
if re.search(pattern, finding):
# Replace with full path
restored_finding = re.sub(
pattern,
f"{full_path}:",
restored_finding,
count=1
)
restored.append(restored_finding)
return restored
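# Illustration (hypothetical grep match on 'src/core/session-manager.js'):
#   "'session-manager.js': registers the cleanup hook"
# becomes
#   "src/core/session-manager.js: registers the cleanup hook"
# Note the surrounding quotes are consumed by the substitution pattern.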
def _interpret_result(
self,
query: str,
action: SearchAction,
result: Dict[str, Any],
previous_findings: List[str],
) -> SearchResultInterpretation:
"""
Step 3: Interpret a search result.
Uses structured output to decide next steps.
"""
self._log("Step 3: Interpreting result...")
result_text = self._format_tool_result(action.tool, result)
messages = [
{"role": "system", "content": prompts.INTERPRET_RESULT_SYSTEM},
{"role": "user", "content": prompts.INTERPRET_RESULT_USER.format(
query=query,
tool=action.tool,
purpose=action.purpose,
params=action.params,
result=result_text[:self.MAX_TOOL_RESULT_SIZE],
previous_findings=previous_findings[-5:] if previous_findings else 'None yet',
)},
]
interpretation = self._call_structured(messages, SearchResultInterpretation)
# FIX: Restore full paths in key_findings that LLM may have shortened
interpretation.key_findings = self._restore_full_paths_in_findings(
interpretation.key_findings,
result,
action.tool,
)
self._log(f" Found relevant: {interpretation.found_relevant}")
self._log(f" Key findings: {interpretation.key_findings}")
self._log(f" Confidence: {interpretation.confidence}")
self._log(f" Need more: {interpretation.next_action_needed}")
return interpretation
def _format_tool_result(self, tool: str, result: Dict[str, Any]) -> str:
"""Format tool result for LLM consumption."""
if result.get("error"):
return f"Error: {result['error']}"
if tool == "grep":
lines = [f"Found {result.get('total_matches', 0)} matches in {result.get('files_count', 0)} files:"]
# Show files if files_only mode
if result.get("files") and not result.get("matches"):
for f in result.get("files", [])[:30]:
lines.append(f" {f}")
else:
for match in result.get("matches", [])[:40]:
lines.append(f" {match.get('file')}:{match.get('line')}: {match.get('content', '')[:150]}")
if result.get("truncated"):
lines.append(" ... (results truncated)")
return "\n".join(lines)
elif tool == "read_file":
content = result.get('content', '')
return f"File content ({result.get('read_lines', 0)} lines):\n{content}"
elif tool == "list_dir":
return f"Directory structure ({result.get('files_count', 0)} files, {result.get('dirs_count', 0)} dirs):\n{result.get('structure', '')}"
elif tool == "glob_search":
files = result.get("files", [])
return f"Found {len(files)} files:\n" + "\n".join(f" {f}" for f in files[:30])
return str(result)
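# Example of the grep formatting above (hypothetical result dict), roughly:
#   Found 2 matches in 1 files:
#     src/app.py:10: def handle_request(req):
#     src/app.py:42: handle_request(req)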
# ========================================
# STEP 5: Generate Final Answer
# ========================================
def _expand_short_paths_in_answer(
self,
answer: FinalAnswer,
all_findings: List[str],
iterations: List[SearchIteration],
) -> FinalAnswer:
"""
Expand short file paths in FinalAnswer using known full paths from findings.
Problem: LLM may use 'file.py' instead of 'src/core/file.py'
Solution: Search all_findings and iteration results for full paths
"""
# Build set of all known full paths from iterations
known_full_paths: Dict[str, str] = {} # filename -> full_path
# Common file extensions to look for (non-capturing group: with a capturing
# group here, re.findall would return (path, extension) tuples instead of
# path strings, and every "'/' in full_path" check below would silently fail)
FILE_EXTS = r'\.(?:py|go|js|ts|jsx|tsx|java|kt|rs|cpp|hpp|c|h|cs|rb|php|swift|qll|ql|md|rst|txt|yml|yaml|json)'
for iteration in iterations:
if iteration.tool_result_summary:
# Extract paths from grep-style results: path/to/file.py:123: content
path_matches = re.findall(
rf'([a-zA-Z0-9_/.+-]+{FILE_EXTS})(?::\d+)?:',
iteration.tool_result_summary, re.MULTILINE
)
for full_path in path_matches:
if '/' in full_path: # Only consider paths with directories
filename = Path(full_path).name
if filename not in known_full_paths or len(full_path) > len(known_full_paths[filename]):
known_full_paths[filename] = full_path
# Also look for paths mentioned without line numbers (e.g., in read_file results)
path_matches2 = re.findall(
rf'(?:^|\s|"|\'|`)([a-zA-Z0-9_/.-]+/[a-zA-Z0-9_.-]+{FILE_EXTS})(?:\s|"|\'|`|$|:)',
iteration.tool_result_summary, re.MULTILINE
)
for full_path in path_matches2:
if '/' in full_path:
filename = Path(full_path).name
if filename not in known_full_paths or len(full_path) > len(known_full_paths[filename]):
known_full_paths[filename] = full_path
# Also extract from all_findings
for finding in all_findings:
# Pattern 1: path/to/file.ext:linenum:
path_matches = re.findall(rf'([a-zA-Z0-9_/.+-]+{FILE_EXTS}):\d+:', finding)
for full_path in path_matches:
if '/' in full_path:
filename = Path(full_path).name
if filename not in known_full_paths or len(full_path) > len(known_full_paths[filename]):
known_full_paths[filename] = full_path
# Pattern 2: path/to/file.ext (without line number)
path_matches2 = re.findall(rf'([a-zA-Z0-9_/.+-]+/[a-zA-Z0-9_/.+-]+{FILE_EXTS})', finding)
for full_path in path_matches2:
if '/' in full_path:
filename = Path(full_path).name
if filename not in known_full_paths or len(full_path) > len(known_full_paths[filename]):
known_full_paths[filename] = full_path
if not known_full_paths:
return answer
self._log(f" [Path expansion] Known paths: {len(known_full_paths)}")
# Expand short paths in code_locations
expanded_count = 0
for loc in answer.code_locations:
file_path = loc.file_path
# Short path: at most one directory level
if file_path.count('/') <= 1:
filename = Path(file_path).name
if filename in known_full_paths:
full_path = known_full_paths[filename]
if full_path != file_path:
self._log(f" [Path expansion] {file_path} → {full_path}")
loc.file_path = full_path
expanded_count += 1
else:
# Try partial match (e.g., "manager.js" matching "session-manager.js")
for known_name, known_path in known_full_paths.items():
if filename in known_name:  # substring check also covers the endswith case
self._log(f" [Path expansion partial] {file_path} → {known_path}")
loc.file_path = known_path
expanded_count += 1
break
if expanded_count > 0:
self._log(f" Expanded {expanded_count} short paths to full paths")
return answer
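# Illustration (hypothetical): if an iteration surfaced 'services/session/manager.py'
# and the final answer cites file_path='manager.py', the location is expanded to the
# full path; 'manager.py' would also partial-match a known 'session-manager.py'
# through the substring fallback above.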
def _generate_answer(
self,
query: str,
analysis: QuestionAnalysis,
iterations: List[SearchIteration],
all_findings: List[str],
technical_terms: Optional[TechnicalTerms] = None,
repo_context: Optional[RepoContext] = None,
) -> FinalAnswer:
"""
Step 5: Generate structured final answer.
Uses structured output to ensure proper format.
"""
self._log("Step 5: Generating final answer...")
# Compile iteration summaries
iteration_summaries = []
for it in iterations:
summary = f"- {it.action.tool}: {it.action.purpose}"
if it.interpretation:
summary += f" → {it.interpretation.key_findings}"
iteration_summaries.append(summary)
# Get technical terms searched
if technical_terms:
terms_searched = technical_terms.get_all_terms()
else:
terms_searched = analysis.concepts
# Get framework info
framework = repo_context.framework if repo_context else "unknown"
repo_hint = (repo_context.framework or repo_context.primary_language) if repo_context else "repository"
messages = [
{"role": "system", "content": prompts.FINAL_ANSWER_SYSTEM},
{"role": "user", "content": prompts.FINAL_ANSWER_USER.format(
query=query,
repo_hint=repo_hint,
framework=framework,
technical_terms=", ".join(terms_searched[:10]),
iteration_summaries=chr(10).join(iteration_summaries),
findings=chr(10).join(f'- {f}' for f in all_findings[-25:]),
)},
]
answer = self._call_structured(messages, FinalAnswer)
# FIX: Expand short paths in code_locations using known full paths
answer = self._expand_short_paths_in_answer(answer, all_findings, iterations)
self._log(f" Summary: {answer.summary[:100]}...")
self._log(f" Code locations: {len(answer.code_locations)}")
self._log(f" Confidence: {answer.confidence}")
return answer
# ========================================
# Decide Next Action (with parallel calls and retry)
# ========================================
def _decide_next_action(
self,
query: str,
iterations: List[SearchIteration],
all_findings: List[str],
path_hints: Optional[List[str]] = None,
retry_feedback: Optional[str] = None,
) -> NextActionDecision:
"""
Decide whether to continue searching and what to do next.
Uses parallel LLM calls and aggregates results to get better decisions.
Args:
path_hints: Hints about invalid paths that LLM tried.
retry_feedback: Feedback about why previous attempt failed (e.g., empty params).
"""
messages = self._build_next_action_messages(
query, iterations, all_findings, path_hints, retry_feedback
)
# Make parallel LLM calls for more robust results
decisions = self._parallel_llm_calls(messages, NextActionDecision, n_calls=3)
# Aggregate decisions - pick the best one
return self._aggregate_decisions(decisions)
def _build_next_action_messages(
self,
query: str,
iterations: List[SearchIteration],
all_findings: List[str],
path_hints: Optional[List[str]] = None,
retry_feedback: Optional[str] = None,
) -> List[Dict[str, str]]:
"""Build messages for NextActionDecision LLM call."""
prev_summaries = []
unexplored_leads = []
for it in iterations[-5:]:
summary = f"{it.action.tool}({it.action.params})"
if it.interpretation:
summary += f" → found_relevant={it.interpretation.found_relevant}, confidence={it.interpretation.confidence}"
if it.interpretation.new_leads:
unexplored_leads.extend(it.interpretation.new_leads)
prev_summaries.append(summary)
unexplored_leads = list(dict.fromkeys(unexplored_leads))[-5:]
# Build extra sections
extra_sections = ""
if path_hints:
extra_sections += f"""
IMPORTANT - Path Correction Needed:
The following paths you tried were invalid. Use the hints to pick correct paths:
{chr(10).join(f'- {hint}' for hint in path_hints)}
"""
if retry_feedback:
extra_sections += f"""
CRITICAL - Your previous response was invalid:
{retry_feedback}
You MUST provide valid parameters. For grep, you MUST provide 'pattern'. For read_file, you MUST provide 'file_path'.
"""
return [
{"role": "system", "content": prompts.NEXT_ACTION_SYSTEM},
{"role": "user", "content": prompts.NEXT_ACTION_USER.format(
query=query,
iteration_count=len(iterations),
prev_summaries=chr(10).join(prev_summaries),
findings=chr(10).join(f'- {f}' for f in all_findings[-10:]),
unexplored_leads=chr(10).join(f'- {lead}' for lead in unexplored_leads) if unexplored_leads else '- None',
remaining=self.max_iterations - len(iterations),
) + extra_sections},
]
def _parallel_llm_calls(
self,
messages: List[Dict[str, str]],
response_format: Type[BaseModel],
n_calls: int = 3,
) -> List[BaseModel]:
"""
Make N parallel LLM calls and return all results.
This helps get more diverse/robust results that can be aggregated.
"""
results = []
def make_call():
try:
return self._call_structured(messages, response_format)
except Exception as e:
self._log(f" [Parallel call failed] {e}")
return None
# Run calls in parallel
with ThreadPoolExecutor(max_workers=n_calls) as executor:
futures = [executor.submit(make_call) for _ in range(n_calls)]
for future in as_completed(futures):
result = future.result()
if result is not None:
results.append(result)
return results
def _aggregate_decisions(self, decisions: List[NextActionDecision]) -> NextActionDecision:
"""
Aggregate multiple NextActionDecision results.
Strategy:
1. Filter out invalid decisions (no params when needed)
2. Prefer decisions with valid actions
3. If all decisions say stop - return stop
"""
if not decisions:
# All calls failed - return safe default
return NextActionDecision(
should_continue=False,
reason="All LLM calls failed",
ready_for_answer=True,
)
# Separate valid and invalid decisions
valid_with_action = []
valid_stop = []
invalid = []
for d in decisions:
action, used_fallback = d._get_next_action_with_fallback_info()
if action:
if used_fallback:
self._log(f" [Used deprecated params dict] tool={action.tool}, params={action.params}")
valid_with_action.append(d)
elif d.ready_for_answer or not d.should_continue:
valid_stop.append(d)
else:
# Log why it's invalid - show flat fields
self._log(f" [Invalid] tool={d.next_action_tool}, grep_pattern={d.grep_pattern}, file_path={d.file_path}, glob_pattern={d.glob_pattern}")
invalid.append(d)
self._log(f" [Aggregation] {len(valid_with_action)} valid actions, {len(valid_stop)} stop, {len(invalid)} invalid")
# Prefer valid actions
if valid_with_action:
# If multiple valid actions, prefer grep over read_file (more exploratory)
grep_actions = [d for d in valid_with_action if d.next_action_tool == "grep"]
if grep_actions:
return grep_actions[0]
return valid_with_action[0]
# If ANY decision wanted to continue but had invalid params - RETRY
# This is more important than majority-stop because LLM showed intent to search
if invalid:
self._log(f" [Aggregation] Found {len(invalid)} decisions wanting to continue - will retry")
return invalid[0] # Will trigger retry with feedback
# Only if ALL decisions explicitly say stop
if valid_stop:
return valid_stop[0]
# Default to first available
return decisions[0]
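# Preference order illustration (hypothetical trio of parallel decisions):
#   [grep(pattern='auth'), read_file('src/auth.py'), stop] -> the grep decision wins
#   [stop, invalid(grep, no pattern), stop]  -> the invalid one is returned so the
#                                               caller can retry with feedback
#   [stop, stop, stop]                       -> stop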
def _add_findings(self, all_findings: List[str], new_findings: List[str]) -> None:
"""Add findings with deduplication and size limit."""
for finding in new_findings:
if finding and finding not in all_findings:
all_findings.append(finding)
# Trim if too large (keep most recent)
if len(all_findings) > self.MAX_FINDINGS:
del all_findings[:len(all_findings) - self.MAX_FINDINGS]
def _try_grep_fallback(
self,
tools: SGRTools,
file_path: str,
key_terms: List[str],
) -> str:
"""Try to get file content via grep when direct read fails."""
for term in key_terms[:3]:
# Escape special regex chars in term
escaped_term = re.escape(term)
grep_result = tools.grep(
pattern=escaped_term,
path=file_path,
context_lines=10,
)
if grep_result.get("matches"):
return "\n".join(
f"{m.get('file', '')}:{m.get('line', '')}: {m.get('content', '')}"
for m in grep_result["matches"][:10]
)
return ""
def _execute_single_iteration(
self,
tools: 'SGRTools',
query: str,
action: SearchAction,
iterations: List[SearchIteration],
all_findings: List[str],
) -> tuple[SearchIteration, float]:
"""
Execute a single search iteration (problem 4: reduce duplication).
Returns:
Tuple of (SearchIteration, tool_time_ms)
"""
self._log(f"Iteration {len(iterations) + 1}: {action.tool}")
# Execute tool
tool_start = time.time()
result = tools.execute_action(action)
tool_time_ms = (time.time() - tool_start) * 1000
# Interpret result
interpretation = self._interpret_result(query, action, result, all_findings)
# Record iteration
iteration = SearchIteration(
iteration_number=len(iterations) + 1,
action=action,
tool_result_summary=self._format_tool_result(action.tool, result)[:self.MAX_TOOL_RESULT_SIZE],
interpretation=interpretation,
next_step="ready_to_answer" if not interpretation.next_action_needed else "continue_search",
)
# Collect findings with deduplication
self._add_findings(all_findings, interpretation.key_findings)
return iteration, tool_time_ms
def _get_next_action_or_stop(
self,
query: str,
iterations: List[SearchIteration],
all_findings: List[str],
pending_actions: List[SearchAction],
tools: Optional['SGRTools'] = None,
path_hints: Optional[List[str]] = None,
) -> Optional[SearchAction]:
"""
Get next action to execute, or None if should stop.
Args:
path_hints: Accumulated hints about invalid paths for LLM
Returns:
SearchAction if should continue, None if should stop
"""
if path_hints is None:
path_hints = []
if pending_actions:
action = pending_actions.pop(0)
# Validate file paths for read_file actions
if tools and action.tool == "read_file":
validated, hint = self._validate_next_action_path(tools, action)
if not validated:
if hint:
path_hints.append(hint)
self._log(f" [Skip] Pending action has invalid path, trying next")
return self._get_next_action_or_stop(query, iterations, all_findings, pending_actions, tools, path_hints)
return validated
return action
# Decide what to do next with retry logic
max_retries = 2
retry_feedback = None
for attempt in range(max_retries + 1):
decision = self._decide_next_action(
query, iterations, all_findings, path_hints, retry_feedback
)
if decision.ready_for_answer:
self._log("LLM decided: ready to generate answer")
return None
action = decision.get_next_action()
if action:
# We have a valid action - validate paths for read_file
if tools and action.tool == "read_file":
validated, hint = self._validate_next_action_path(tools, action)
if not validated:
if hint:
self._log(f" [Path hint for LLM] {hint}")
path_hints.append(hint)
# Retry with path hint
retry_feedback = f"Path '{action.params.get('file_path')}' does not exist. {hint}"
continue
# Path invalid and no hint - retry
retry_feedback = f"Path '{action.params.get('file_path')}' does not exist."
continue
return validated
return action
# Check if LLM had intent but invalid params (retry case)
# This happens when has_next_action=True but params were empty
has_intent = decision.has_next_action and decision.next_action_tool
if has_intent and attempt < max_retries:
# LLM wanted to do something but didn't provide params - RETRY
tool = decision.next_action_tool or "unknown"
purpose = decision.next_action_purpose or "no purpose specified"
self._log(f" [Retry {attempt + 1}/{max_retries}] LLM had intent but missing required field")
self._log(f" tool={tool}, grep_pattern={decision.grep_pattern}, file_path={decision.file_path}")
if tool == "grep":
retry_feedback = f"ERROR: tool='grep' requires grep_pattern field! You must set grep_pattern='your_search_term'. Example: grep_pattern='middleware'"
elif tool == "read_file":
retry_feedback = f"ERROR: tool='read_file' requires file_path field! You must set file_path='path/to/file.py'. Example: file_path='src/main.py'"
elif tool == "glob_search":
retry_feedback = f"ERROR: tool='glob_search' requires glob_pattern field! You must set glob_pattern='*.py'. Example: glob_pattern='*handler*.py'"
elif tool == "list_dir":
retry_feedback = f"ERROR: tool='list_dir' requires list_dir_path field! You must set list_dir_path='src/'. Example: list_dir_path='src/core/'"
else:
retry_feedback = f"ERROR: Unknown tool='{tool}'. Use grep, read_file, list_dir, or glob_search."
continue
# No action AND no intent (or max retries reached) - check should_continue
if not decision.should_continue:
self._log(f"LLM decided to stop: {decision.reason}")
return None
# should_continue=True but still no valid action after retries
if attempt >= max_retries:
self._log(f" [Max retries reached] LLM failed to provide valid params after {max_retries} retries")
return None
return None
def _should_stop_early(self, interp: SearchResultInterpretation, iteration_num: int) -> bool:
"""Check if we should stop searching based on interpretation."""
# Don't stop too early - need at least 3 iterations to gather enough info
if iteration_num < 3:
return False
return interp.confidence == "high" and interp.found_relevant
def _verify_file_exists(self, tools: 'SGRTools', file_path: str) -> bool:
"""
Verify that a file exists before including in results (problem 6).
Returns:
True if file exists, False otherwise
"""
# Try to read first line to check existence
result = tools.read_file(file_path, start_line=1, end_line=1)
if result.get("error"):
error_msg = result.get("error", "").lower()
if "not found" in error_msg or "no such file" in error_msg or "does not exist" in error_msg:
return False
return True
# ========================================
# Main Search Method
# ========================================
def search(
self,
query: str,
repo_path: str,
path: Optional[str] = None,
) -> SearchResult:
"""
Perform semantic search using SGR.
"""
project_name = os.getenv("LANGSMITH_PROJECT", "semantic-search-sgr")
with ls.tracing_context(enabled=True, project_name=project_name):
return self._search_impl(query, repo_path, path)
@traceable(name="SGRSearch", run_type="chain")
def _search_impl(
self,
query: str,
repo_path: str,
path: Optional[str],
) -> SearchResult:
"""Main search implementation."""
start_time = time.time()
tool_time_total = 0.0
# Initialize tools (without scope filter initially)
tools = SGRTools(repo_path, verbose=self.verbose, log_func=self._log)
# Determine repo hint for LLM
repo_hint = Path(repo_path).name
if path:
repo_hint = f"{repo_hint}/{path}"
try:
# ========================================
# STEP 0: Analyze Repository Context
# ========================================
repo_context = self._analyze_repo_context(repo_path, tools)
# ========================================
# STEP 1: Analyze Question (concepts)
# ========================================
analysis = self._analyze_question(query, repo_hint)
# Update tools with language info
tools = SGRTools(
repo_path,
verbose=self.verbose,
log_func=self._log,
language=repo_context.primary_language,
)
# ========================================
# STEP 2: Extract Technical Terms (with exploration)
# ========================================
technical_terms = self._extract_technical_terms(query, repo_context, analysis, tools)
# NEW: Validate likely_files paths
if technical_terms.likely_files:
valid_paths, fixed_paths, invalid_paths = tools.path_validator.validate_likely_files(
technical_terms.likely_files
)
if invalid_paths:
self._log(f" [Path validation] Removed {len(invalid_paths)} invalid paths: {invalid_paths[:3]}...")
technical_terms.likely_files = valid_paths + fixed_paths
# NEW: Fallback to key_directories if likely_files are empty after validation
if not technical_terms.likely_files and repo_context.key_directories:
fallback_dirs = repo_context.key_directories[:3]
self._log(f" [Fallback] likely_files empty, using key_directories: {fallback_dirs}")
technical_terms.likely_files = fallback_dirs
# ========================================
# STEP 2.5: Enforce reading likely_files (NEW)
# ========================================
all_findings: List[str] = []
if technical_terms.likely_files:
self._log(f"Step 2.5: Reading {len(technical_terms.likely_files)} likely files...")
for lf_path in technical_terms.likely_files[:5]: # Limit to 5 files
# Check if it's a directory
if tools.path_validator.is_directory(lf_path):
files_in_dir = tools.path_validator.get_files_in_directory(lf_path, max_files=3)
for file_path in files_in_dir:
findings = self._read_and_extract_findings(tools, file_path)
all_findings.extend(findings)
else:
findings = self._read_and_extract_findings(tools, lf_path)
all_findings.extend(findings)
self._log(f" Extracted {len(all_findings)} findings from likely files")
# ========================================
# STEP 3: Create Initial Plan
# ========================================
plan = self._create_plan(query, analysis, technical_terms)
# Validate file paths in read_file actions (prevent hallucinated paths)
plan = self._validate_plan_file_paths(tools, plan)
# ========================================
# STEP 4: Execute Search Iterations
# ========================================
iterations: List[SearchIteration] = []
pending_actions = list(plan.actions)
early_stop = False
if self.parallel_initial_actions and len(pending_actions) > 1:
# Execute first batch in parallel
self._log(f"Executing {len(pending_actions)} initial actions in parallel...")
tool_start = time.time()
results = tools.execute_actions_parallel(pending_actions, max_workers=3)
tool_time_total += (time.time() - tool_start) * 1000
# Process parallel results
for action, result in zip(pending_actions, results):
interpretation = self._interpret_result(query, action, result, all_findings)
iteration = SearchIteration(
iteration_number=len(iterations) + 1,
action=action,
tool_result_summary=self._format_tool_result(action.tool, result)[:self.MAX_TOOL_RESULT_SIZE],
interpretation=interpretation,
next_step="ready_to_answer" if not interpretation.next_action_needed else "continue_search",
)
iterations.append(iteration)
self._add_findings(all_findings, interpretation.key_findings)
if self._should_stop_early(interpretation, len(iterations)):
self._log("High confidence with relevant findings - preparing answer")
early_stop = True
break
pending_actions = [] # Clear since processed
# Sequential execution loop (unified - problem 4 fix)
while not early_stop and len(iterations) < self.max_iterations:
# Check timeout
elapsed = time.time() - start_time
if elapsed >= self.total_timeout:
self._log(f"Timeout reached ({self.total_timeout}s)")
break
# Check for high confidence in previous iteration
if iterations and iterations[-1].interpretation:
if self._should_stop_early(iterations[-1].interpretation, len(iterations)):
self._log("High confidence with relevant findings - stopping search")
break
# Get next action (unified method)
action = self._get_next_action_or_stop(query, iterations, all_findings, pending_actions, tools)
if not action:
break
# Execute iteration (unified method)
iteration, tool_time_ms = self._execute_single_iteration(
tools, query, action, iterations, all_findings
)
iterations.append(iteration)
tool_time_total += tool_time_ms
# ========================================
# STEP 5: Generate Final Answer
# ========================================
answer = self._generate_answer(
query, analysis, iterations, all_findings,
technical_terms=technical_terms,
repo_context=repo_context,
)
# Convert to SearchResult format
items = []
seen_files = set() # Avoid duplicate files
for loc in answer.code_locations:
# Skip duplicates and placeholder entries
if loc.file_path in seen_files or loc.file_path == "(not found)":
continue
seen_files.add(loc.file_path)
# Verify file exists before processing (problem 6)
if not self._verify_file_exists(tools, loc.file_path):
self._log(f" Skipping non-existent file: {loc.file_path}")
continue
# Try to read actual file content
file_result = tools.read_file(loc.file_path)
content = file_result.get("content", "")
if not content and not file_result.get("error"):
# Empty content but no error - file might be empty
content = "(empty file)"
elif file_result.get("error"):
self._log(f" Warning: Could not read {loc.file_path}: {file_result.get('error')}")
# File read error - try grep as fallback using technical_terms
fallback_terms = technical_terms.get_all_terms() if technical_terms else analysis.concepts
content = self._try_grep_fallback(tools, loc.file_path, fallback_terms)
if not content:
# Still no content - try using relevant_elements from location
for elem in loc.relevant_elements[:3]:
grep_result = tools.grep(
pattern=re.escape(elem),
path=loc.file_path,
context_lines=5,
)
if grep_result.get("matches"):
content = "\n".join(
f"{m.get('file', '')}:{m.get('line', '')}: {m.get('content', '')}"
for m in grep_result["matches"][:10]
)
break
if content:
items.append(SearchItem(
file_path=loc.file_path,
content=content,
match_context=loc.relevance,
))
else:
# Include the file even without content, with relevance info
items.append(SearchItem(
file_path=loc.file_path,
content=f"(file exists but could not read content)\nRelevant elements: {', '.join(loc.relevant_elements)}",
match_context=loc.relevance,
))
total_time = (time.time() - start_time) * 1000
llm_time = total_time - tool_time_total
# Log detailed metrics
self._log(f"Search complete:")
self._log(f" Items found: {len(items)}")
self._log(f" Total iterations: {len(iterations)}")
self._log(f" Total time: {total_time:.0f}ms")
self._log(f" LLM time: {llm_time:.0f}ms")
self._log(f" Tool time: {tool_time_total:.0f}ms")
# Collect tools used stats
tools_used = {}
for it in iterations:
tool = it.action.tool
tools_used[tool] = tools_used.get(tool, 0) + 1
self._log(f" Tools used: {tools_used}")
# Collect confidence progression
confidences = [it.interpretation.confidence for it in iterations if it.interpretation]
if confidences:
self._log(f" Confidence progression: {' → '.join(confidences)}")
# Find when high confidence was reached
for i, it in enumerate(iterations):
if it.interpretation and it.interpretation.confidence == "high":
self._log(f" High confidence reached at iteration: {i + 1}")
break
return SearchResult(
items=items,
patterns_used=[it.action.params.get("pattern", str(it.action.params)) for it in iterations if it.action.params],
execution_time_ms=llm_time,
total_time_ms=total_time,
tool_time_ms=tool_time_total,
)
except Exception as e:
self._log(f"Error: {e}")
import traceback
traceback.print_exc()
total_time = (time.time() - start_time) * 1000
return SearchResult(
items=[],
execution_time_ms=total_time,
total_time_ms=total_time,
error=str(e),
)
# ============================================
# Convenience Classes - OpenAI
# ============================================
class SGRSearcherGPT4o(SGRSearcher):
"""SGR searcher using GPT-4o (best quality)."""
def __init__(
self,
max_iterations: int = 8,
total_timeout: float = 120.0,
verbose: bool = False,
):
super().__init__(
provider=LLMProvider.OPENAI,
model="gpt-4o-2024-08-06",
max_iterations=max_iterations,
total_timeout=total_timeout,
verbose=verbose,
)
class SGRSearcherGPT4oMini(SGRSearcher):
"""SGR searcher using GPT-4o-mini (faster, cheaper)."""
def __init__(
self,
max_iterations: int = 8,
total_timeout: float = 120.0,
verbose: bool = False,
):
super().__init__(
provider=LLMProvider.OPENAI,
model="gpt-4o-mini",
max_iterations=max_iterations,
total_timeout=total_timeout,
verbose=verbose,
)
# ============================================
# Convenience Classes - Gemini
# ============================================
class SGRSearcherGemini(SGRSearcher):
"""SGR searcher using Gemini (default: gemini-2.5-flash-lite)."""
def __init__(
self,
model: str = "gemini-2.5-flash-lite",
max_iterations: int = 8,
total_timeout: float = 120.0,
verbose: bool = False,
):
super().__init__(
provider=LLMProvider.GEMINI,
model=model,
max_iterations=max_iterations,
total_timeout=total_timeout,
verbose=verbose,
)
class SGRSearcherGeminiFlashLite(SGRSearcher):
"""SGR searcher using Gemini Flash Lite (fastest, cheapest)."""
def __init__(
self,
max_iterations: int = 8,
total_timeout: float = 120.0,
verbose: bool = False,
):
super().__init__(
provider=LLMProvider.GEMINI,
model="gemini-2.5-flash-lite",
max_iterations=max_iterations,
total_timeout=total_timeout,
verbose=verbose,
)
class SGRSearcherGeminiFlash(SGRSearcher):
"""SGR searcher using Gemini Flash (balanced speed/quality)."""
def __init__(
self,
max_iterations: int = 8,
total_timeout: float = 120.0,
verbose: bool = False,
):
super().__init__(
provider=LLMProvider.GEMINI,
model="gemini-2.5-flash",
max_iterations=max_iterations,
total_timeout=total_timeout,
verbose=verbose,
)
class SGRSearcherGeminiPro(SGRSearcher):
"""SGR searcher using Gemini 3 Pro Preview (highest quality)."""
def __init__(
self,
max_iterations: int = 8,
total_timeout: float = 120.0,
verbose: bool = False,
):
super().__init__(
provider=LLMProvider.GEMINI,
model="gemini-3-pro-preview",
max_iterations=max_iterations,
total_timeout=total_timeout,
verbose=verbose,
)
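# Example usage (hypothetical repo path; assumes the relevant API key, e.g.
# OPENAI_API_KEY or a Gemini key, is available in the environment or .env):
#
#   searcher = SGRSearcherGPT4oMini(max_iterations=6, verbose=True)
#   result = searcher.search(
#       query="Where is request authentication implemented?",
#       repo_path="/path/to/repo",
#   )
#   for item in result.items:
#       print(item.file_path, "-", item.match_context)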