"""
Integration helpers for SGR improvements.
This module provides integration code for:
1. PathValidator - validating LLM-generated paths
2. PatternValidator - validating grep patterns
3. Enforce likely_files reading
Usage in SGRSearcherGemini:
from .sgr.integration import (
apply_path_validation,
apply_pattern_validation,
enforce_likely_files,
)
"""
import re
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from .path_validator import PathValidator
from .pattern_validator import PatternValidator, PatternValidation
@dataclass
class LikelyFileResult:
    """Result from reading a likely file."""
    # Path of the file that was read (as produced by PathValidator).
    file_path: str
    # True when read_file_func returned content without an 'error' entry.
    success: bool
    # Heuristically extracted entries, formatted "kind 'name' in path:line".
    definitions_found: List[str]
    # Error message when success is False; None on success.
    error: Optional[str] = None
class SGRIntegration:
    """
    Integration layer for SGR improvements.

    Wraps PathValidator and PatternValidator and provides the three steps
    used by the SGR search pipeline: validating LLM-generated paths,
    validating grep patterns, and enforced reading of likely_files with
    heuristic (no-LLM) definition extraction.
    """

    # Per-extension definition-extraction patterns, compiled once at class
    # creation instead of being re-looked-up for every scanned line.
    _DEFINITION_PATTERNS = {
        ext: [(re.compile(expr), kind) for expr, kind in specs]
        for ext, specs in {
            ".py": [
                (r"^class\s+(\w+)", "class"),
                (r"^def\s+(\w+)", "function"),
                (r"^\s{4}def\s+(\w+)", "method"),
            ],
            ".go": [
                (r"^type\s+(\w+)\s+struct", "struct"),
                (r"^type\s+(\w+)\s+interface", "interface"),
                (r"^func\s+(\w+)\(", "function"),
                (r"^func\s+\([^)]+\)\s+(\w+)\(", "method"),
            ],
            ".ts": [
                (r"^(?:export\s+)?class\s+(\w+)", "class"),
                (r"^(?:export\s+)?interface\s+(\w+)", "interface"),
                (r"^(?:export\s+)?(?:async\s+)?function\s+(\w+)", "function"),
                (r"^(?:export\s+)?type\s+(\w+)", "type"),
            ],
            ".js": [
                (r"^class\s+(\w+)", "class"),
                (r"^(?:async\s+)?function\s+(\w+)", "function"),
                (r"^const\s+(\w+)\s*=\s*(?:async\s+)?\(", "arrow_function"),
            ],
            ".java": [
                (r"^(?:public|private|protected)?\s*(?:abstract\s+)?class\s+(\w+)", "class"),
                (r"^(?:public|private|protected)?\s*interface\s+(\w+)", "interface"),
                (r"^\s+(?:public|private|protected)?\s*(?:static\s+)?[\w<>,\s]+\s+(\w+)\s*\(", "method"),
            ],
            ".qll": [
                (r"^(?:private\s+)?class\s+(\w+)", "class"),
                (r"^(?:private\s+)?predicate\s+(\w+)", "predicate"),
                (r"^module\s+(\w+)", "module"),
            ],
            ".h": [
                (r"^class\s+(\w+)", "class"),
                (r"^struct\s+(\w+)", "struct"),
            ],
            ".cpp": [
                (r"^(?:\w+\s+)+(\w+)::\w+\s*\(", "method"),
            ],
        }.items()
    }

    # Fallback for extensions not covered above.
    _FALLBACK_PATTERNS = [
        (re.compile(r"^class\s+(\w+)"), "class"),
        (re.compile(r"^(?:def|func|function)\s+(\w+)"), "function"),
    ]

    # Scan/return limits for _extract_definitions.
    _MAX_SCAN_LINES = 500
    _MAX_FINDINGS = 30

    def __init__(
        self,
        repo_path: str,
        language: Optional[str] = None,
        verbose: bool = False,
    ):
        """
        Initialize integration.

        Args:
            repo_path: Path to repository
            language: Primary language of the repo
            verbose: Whether to log operations
        """
        self.repo_path = Path(repo_path)
        self.language = language
        self.verbose = verbose
        # str() keeps this working if a Path object is passed despite the
        # annotation.
        self.path_validator = PathValidator(str(repo_path))

    def _log(self, message: str) -> None:
        """Log a message if verbose."""
        if self.verbose:
            print(f"[SGR-Integration] {message}")

    # ========================================
    # Path Validation
    # ========================================
    def validate_likely_files(
        self,
        paths: List[str],
    ) -> Tuple[List[str], Dict[str, str]]:
        """
        Validate likely_files from TechnicalTerms.

        Args:
            paths: Candidate file paths proposed by the LLM.

        Returns:
            Tuple of:
                - Valid/fixed paths to use
                - Dict of {original: error} for invalid paths
        """
        valid = []
        errors = {}
        for path in paths:
            is_valid, fixed_path, error = self.path_validator.validate_path(path)
            if is_valid and fixed_path:
                valid.append(fixed_path)
                # A non-None error alongside a valid result means the
                # validator auto-corrected the path.
                if error:
                    self._log(f"Path fixed: {path} -> {fixed_path}")
            else:
                errors[path] = error or "Unknown error"
                self._log(f"Path invalid: {path} - {error}")
        return valid, errors

    def validate_file_path(self, path: str) -> Tuple[bool, Optional[str]]:
        """
        Validate a single file path.

        Returns:
            Tuple of (is_valid, fixed_path_or_None)
        """
        is_valid, fixed, _error = self.path_validator.validate_path(path)
        if is_valid:
            return True, fixed
        return False, None

    # ========================================
    # Pattern Validation
    # ========================================
    def validate_pattern(
        self,
        pattern: str,
        context: Optional[Dict[str, Any]] = None,
    ) -> PatternValidation:
        """
        Validate a grep pattern.

        Args:
            pattern: The pattern to validate
            context: Optional context dict with 'concepts', etc.

        Returns:
            PatternValidation result
        """
        return PatternValidator.validate_pattern(
            pattern,
            language=self.language,
            context=context,
        )

    def get_safe_pattern(
        self,
        pattern: str,
        context: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Get a safe pattern to use.

        If pattern is too broad, returns a transformed/suggested pattern;
        otherwise falls back to the original.
        """
        return PatternValidator.get_safe_pattern(
            pattern,
            language=self.language,
            context=context,
            fallback_to_original=True,
        )

    # ========================================
    # Enforce Likely Files Reading
    # ========================================
    def read_likely_files(
        self,
        paths: List[str],
        read_file_func,  # Callable(path) -> Dict with 'content'/'error' keys
        max_files: int = 5,
    ) -> List[LikelyFileResult]:
        """
        Read likely_files and extract key definitions.

        Directories among the validated paths are expanded to up to 3 of
        their files; at most max_files files are read in total.

        Args:
            paths: List of likely file paths
            read_file_func: Function(path) -> Dict with 'content' key
            max_files: Maximum files to read (hard cap on results)

        Returns:
            List of LikelyFileResult with extracted definitions
        """
        results: List[LikelyFileResult] = []
        # Validate paths first; invalid ones are dropped (and logged).
        valid_paths, _ = self.validate_likely_files(paths)
        for path in valid_paths[:max_files]:
            # Enforce the documented cap even after directory expansion
            # (previously up to max_files * 3 files could be read).
            if len(results) >= max_files:
                break
            if self.path_validator.is_directory(path):
                # Expand a directory to a few of its files.
                files = self.path_validator.get_files_in_directory(path, max_files=3)
                for file_path in files:
                    if len(results) >= max_files:
                        break
                    results.append(self._read_and_extract(file_path, read_file_func))
            else:
                results.append(self._read_and_extract(path, read_file_func))
        return results

    def _read_and_extract(
        self,
        file_path: str,
        read_file_func,
    ) -> LikelyFileResult:
        """Read a single file and extract definitions; never raises."""
        try:
            result = read_file_func(file_path)
            if result.get("error"):
                return LikelyFileResult(
                    file_path=file_path,
                    success=False,
                    definitions_found=[],
                    error=result.get("error"),
                )
            content = result.get("content", "")
            definitions = self._extract_definitions(file_path, content)
            self._log(f"Read {file_path}: {len(definitions)} definitions found")
            return LikelyFileResult(
                file_path=file_path,
                success=True,
                definitions_found=definitions,
            )
        except Exception as e:
            # Best-effort: surface the failure as a result, not an exception.
            return LikelyFileResult(
                file_path=file_path,
                success=False,
                definitions_found=[],
                error=str(e),
            )

    def _extract_definitions(
        self,
        file_path: str,
        content: str,
    ) -> List[str]:
        """
        Extract key definitions from file content.

        Heuristic, line-anchored regex extraction - no LLM call. Scans at
        most _MAX_SCAN_LINES lines and returns at most _MAX_FINDINGS
        entries formatted as "kind 'name' in path:line".
        """
        findings = []
        extension = Path(file_path).suffix.lower()
        file_patterns = self._DEFINITION_PATTERNS.get(extension) or self._FALLBACK_PATTERNS
        lines = content.split('\n')
        for line_num, line in enumerate(lines[:self._MAX_SCAN_LINES], 1):
            for regex, kind in file_patterns:
                match = regex.match(line)
                if match:
                    name = match.group(1)
                    findings.append(f"{kind} '{name}' in {file_path}:{line_num}")
        return findings[:self._MAX_FINDINGS]
# ========================================
# Convenience functions for direct use
# ========================================
def create_integration(
    repo_path: str,
    language: Optional[str] = None,
    verbose: bool = False,
) -> SGRIntegration:
    """Convenience factory: build and return an SGRIntegration.

    Args:
        repo_path: Path to the repository root.
        language: Primary language of the repo, if known.
        verbose: Whether the integration should log its operations.
    """
    return SGRIntegration(repo_path=repo_path, language=language, verbose=verbose)
def apply_path_validation(
    paths: List[str],
    repo_path: str,
) -> Tuple[List[str], Dict[str, str]]:
    """
    Validate a list of paths against repository.

    Args:
        paths: Candidate file paths (e.g. LLM-generated likely_files).
        repo_path: Path to the repository root.

    Returns:
        Tuple of (valid_paths, {invalid_path: error})
    """
    # Route through SGRIntegration so this helper uses the same
    # validate_path() loop as the pipeline, rather than relying on
    # PathValidator exposing its own validate_likely_files() method
    # (nothing else in this module uses or shows such a method).
    return SGRIntegration(repo_path=repo_path).validate_likely_files(paths)
def apply_pattern_validation(
    pattern: str,
    language: Optional[str] = None,
    context: Optional[Dict[str, Any]] = None,
) -> PatternValidation:
    """Validate a grep pattern.

    Thin wrapper over PatternValidator.validate_pattern; arguments are
    forwarded by keyword, matching the call style used by
    SGRIntegration.validate_pattern.
    """
    return PatternValidator.validate_pattern(
        pattern,
        language=language,
        context=context,
    )
def enforce_likely_files(
    paths: List[str],
    repo_path: str,
    read_file_func,
    language: Optional[str] = None,
    max_files: int = 5,
) -> Tuple[List[str], List[LikelyFileResult]]:
    """
    Enforce reading of likely_files.

    Builds a throwaway SGRIntegration, reads up to max_files of the given
    paths, and flattens the definitions found in successful reads.

    Returns:
        Tuple of:
        - List of findings strings
        - List of LikelyFileResult
    """
    integration = SGRIntegration(
        repo_path=repo_path,
        language=language,
    )
    results = integration.read_likely_files(paths, read_file_func, max_files)
    # Flatten definitions from successful reads only.
    all_findings = [
        finding
        for result in results
        if result.success
        for finding in result.definitions_found
    ]
    return all_findings, results