"""
AST-based code search and analysis module for Fast Apply MCP.
This module provides semantic code search capabilities using ast-grep,
enabling intelligent code pattern matching and structural analysis.
ENHANCED ROBUSTNESS (Based on debugging results):
=================================================
1. ROOT CAUSE FIXES:
- ✅ UPPERCASE metavariables: All patterns use $NAME not $name (CRITICAL)
- ✅ Kind field requirement: All rules include 'kind' for AST node type (CRITICAL)
- ✅ Reference finding: Uses simple text patterns that actually work (CRITICAL)
2. PROPER ERROR HANDLING:
- ✅ Fallback text search when ast-grep fails
- ✅ Enhanced error messages with debugging context
- ✅ Pattern validation with helpful suggestions
3. VERIFIED EXAMPLES:
- ✅ Working patterns tested with mcp__fast-apply-mcp server
- ✅ YAML rules that pass validation
- ✅ Best practices from ast-grep MCP documentation
4. DEBUGGING SUPPORT:
- ✅ Common issue detection and suggestions
- ✅ Pattern syntax validation
- ✅ Comprehensive logging for troubleshooting
Reference: https://raw.githubusercontent.com/ast-grep/ast-grep-mcp/refs/heads/main/ast-grep.mdc
"""
import os
from typing import Any, Dict, List, Optional, cast
import structlog
logger = structlog.get_logger(__name__)
# Supported languages mapping
LANGUAGE_MAP = {
".py": "python",
".js": "javascript",
".jsx": "javascript",
".ts": "typescript",
".tsx": "typescript",
".json": "json",
}
# Enhanced pattern configurations based on ast-grep MCP documentation
# IMPORTANT: ast-grep requires UPPERCASE metavariables ($NAME not $name)
STRUCTURE_PATTERNS = {
"python": {
"functions": {"kind": "function_definition", "pattern": "def $NAME($ARGS)"},
"classes": {"kind": "class_definition", "pattern": "class $NAME"},
"imports": [{"pattern": "import $MODULE"}, {"pattern": "from $MODULE import $ITEMS"}],
},
"javascript": {
"functions": [
{"kind": "function_declaration", "pattern": "function $NAME($ARGS)"},
{"kind": "arrow_function", "pattern": "($ARGS) => $BODY"},
{"pattern": "const $NAME = ($ARGS) => $BODY"},
],
"classes": {"kind": "class_declaration", "pattern": "class $NAME"},
"imports": [
{"pattern": "import $ITEMS from '$MODULE'"},
{"pattern": "import * as $ALIAS from '$MODULE'"},
{"pattern": "const $ITEMS = require('$MODULE')"},
],
"exports": [{"pattern": "export $DECLARATION"}, {"pattern": "module.exports = $VALUE"}],
},
"typescript": {
"functions": [
{"kind": "function_declaration", "pattern": "function $NAME($ARGS): $TYPE"},
{"kind": "arrow_function", "pattern": "($ARGS): $TYPE => $BODY"},
{"pattern": "const $NAME = ($ARGS): $TYPE => $BODY"},
],
"classes": {"kind": "class_declaration", "pattern": "class $NAME"},
"imports": [
{"pattern": "import $ITEMS from '$MODULE'"},
{"pattern": "import * as $ALIAS from '$MODULE'"},
{"pattern": "import type { $TYPES } from '$MODULE'"},
],
"exports": [{"pattern": "export $DECLARATION"}, {"pattern": "export type $TYPE_DECLARATION"}],
},
}
class AstSearchError(Exception):
"""Base exception for AST search operations."""
pass
class PatternSearchResult:
"""Represents a single pattern search result."""
def __init__(self, file_path: str, line: int, column: int, text: str, matches: Dict[str, str]):
self.file_path = file_path
self.line = line
self.column = column
self.text = text
self.matches = matches # Meta-variable matches like $name -> "function_name"
def to_dict(self) -> Dict[str, Any]:
"""Convert result to dictionary for JSON serialization."""
return {"file_path": self.file_path, "line": self.line, "column": self.column, "text": self.text, "matches": self.matches}
class StructureInfo:
"""Represents structural information about a code file."""
def __init__(self, file_path: str, language: str):
self.file_path = file_path
self.language = language
self.functions: List[Dict[str, Any]] = []
self.classes: List[Dict[str, Any]] = []
self.imports: List[Dict[str, Any]] = []
self.exports: List[Dict[str, Any]] = []
def to_dict(self) -> Dict[str, Any]:
"""Convert structure info to dictionary for JSON serialization."""
return {
"file_path": self.file_path,
"language": self.language,
"functions": self.functions,
"classes": self.classes,
"imports": self.imports,
"exports": self.exports,
}
def _get_language_from_file(file_path: str) -> Optional[str]:
"""Determine language from file extension."""
_, ext = os.path.splitext(file_path.lower())
return LANGUAGE_MAP.get(ext)
def _is_ast_grep_available() -> bool:
"""Check if ast-grep-py is available and working."""
try:
import ast_grep_py # noqa: F401
return True
except ImportError:
logger.warning("ast-grep-py not available, falling back to basic search")
return False
def search_code_patterns(pattern: str, language: str, path: str, exclude_patterns: Optional[List[str]] = None) -> List[PatternSearchResult]:
"""
Search for code patterns using AST-based matching.
Args:
pattern: AST pattern to search for (e.g., "function $name($args) { $body }")
language: Target language ("python", "javascript", "typescript")
path: File or directory path to search in
exclude_patterns: Optional list of patterns to exclude
Returns:
List of PatternSearchResult objects
Raises:
AstSearchError: If search operation fails
"""
if not _is_ast_grep_available():
raise AstSearchError("ast-grep-py is not available")
try:
# Normalize language names
lang_map = {
"python": "python",
"javascript": "javascript",
"typescript": "typescript",
"js": "javascript",
"ts": "typescript",
"py": "python",
}
normalized_lang = lang_map.get(language.lower(), language.lower())
logger.info("Searching code patterns", pattern=pattern, language=normalized_lang, path=path)
results: List[PatternSearchResult] = []
exclude_patterns = exclude_patterns or []
# Default exclude patterns
default_excludes = [".git", "node_modules", "venv", ".venv", "__pycache__", ".pytest_cache"]
all_excludes = exclude_patterns + default_excludes
if os.path.isfile(path):
# Search single file
file_results = _search_file_patterns(path, pattern, normalized_lang, all_excludes)
results.extend(file_results)
elif os.path.isdir(path):
# Search directory recursively
for root, dirs, files in os.walk(path):
# Filter out excluded directories
dirs[:] = [d for d in dirs if not _should_exclude_path(os.path.join(root, d), all_excludes)]
for file in files:
file_path = os.path.join(root, file)
# Check if file should be excluded
if _should_exclude_path(file_path, all_excludes):
continue
# Check if file matches target language
file_lang = _get_language_from_file(file_path)
if file_lang and file_lang == normalized_lang:
file_results = _search_file_patterns(file_path, pattern, normalized_lang, all_excludes)
results.extend(file_results)
else:
raise AstSearchError(f"Path not found: {path}")
logger.info("Pattern search completed", pattern=pattern, results_count=len(results))
return results
except ImportError:
raise AstSearchError("ast-grep-py is not installed")
except Exception as e:
logger.error("Pattern search failed", error=str(e), pattern=pattern, path=path)
raise AstSearchError(f"Pattern search failed: {e}")
def _search_file_patterns(file_path: str, pattern: str, language: str, exclude_patterns: List[str]) -> List[PatternSearchResult]:
"""Search for patterns in a single file."""
try:
from ast_grep_py import SgRoot
# Read file content
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
# Parse content with ast-grep
root = SgRoot(content, language)
node = root.root()
# Find pattern matches
matches = []
try:
# Use ast-grep pattern matching
found_nodes = node.find_all(pattern=pattern)
for match_node in found_nodes:
# Get match position info
range_info = match_node.range()
line = range_info.start.line + 1 # Convert to 1-based
column = range_info.start.column + 1
# Get matched text
matched_text = match_node.text()
# Extract meta-variable matches
meta_matches: Dict[str, str] = {}
# Note: Advanced meta-variable extraction would require more sophisticated API usage
result = PatternSearchResult(file_path=file_path, line=line, column=column, text=matched_text, matches=meta_matches)
matches.append(result)
except Exception as e:
logger.warning("Pattern matching failed for file", file_path=file_path, error=str(e))
return matches
except Exception as e:
logger.warning("Failed to search file", file_path=file_path, error=str(e))
return []
def _should_exclude_path(path: str, exclude_patterns: List[str]) -> bool:
"""Check if path should be excluded based on patterns."""
path_parts = path.split(os.sep)
for exclude in exclude_patterns:
if exclude in path_parts or exclude in os.path.basename(path):
return True
return False
def analyze_code_structure(file_path: str) -> StructureInfo:
"""
Analyze the structural components of a code file.
Args:
file_path: Path to the code file to analyze
Returns:
StructureInfo object containing structural analysis
Raises:
AstSearchError: If analysis operation fails
"""
if not _is_ast_grep_available():
raise AstSearchError("ast-grep-py is not available")
language = _get_language_from_file(file_path)
if not language:
raise AstSearchError(f"Unsupported file type: {file_path}")
logger.info("Analyzing code structure", file_path=file_path, language=language)
try:
from ast_grep_py import SgRoot
# Read file content
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
# Parse content with ast-grep
root = SgRoot(content, language)
node = root.root()
structure = StructureInfo(file_path, language)
# Analyze based on language
if language == "python":
_analyze_python_structure(node, structure)
elif language in ["javascript", "typescript"]:
_analyze_js_ts_structure(node, structure)
logger.info(
"Structure analysis completed",
file_path=file_path,
functions_count=len(structure.functions),
classes_count=len(structure.classes),
imports_count=len(structure.imports),
)
return structure
except Exception as e:
logger.error("Structure analysis failed", error=str(e), file_path=file_path)
raise AstSearchError(f"Structure analysis failed: {e}")
def _analyze_python_structure(node, structure: StructureInfo):
"""Analyze Python-specific structures using enhanced patterns."""
try:
patterns: Dict[str, Any] = STRUCTURE_PATTERNS["python"]
# Find function definitions using both kind and pattern matching
function_pattern: Dict[str, Any] = cast(Dict[str, Any], patterns["functions"])
try:
# Try kind-based matching first (more reliable)
function_nodes = node.find_all(kind="function_definition")
for func_node in function_nodes:
func_name = _extract_function_name(func_node)
range_info = func_node.range()
structure.functions.append(
{"name": func_name, "line": range_info.start.line + 1, "column": range_info.start.column + 1, "type": "function"}
)
except Exception:
# Fallback to pattern matching
function_nodes = node.find_all(pattern=function_pattern["pattern"])
for func_node in function_nodes:
func_name = _extract_node_text(func_node, "$NAME")
range_info = func_node.range()
structure.functions.append(
{"name": func_name, "line": range_info.start.line + 1, "column": range_info.start.column + 1, "type": "function"}
)
# Find class definitions
class_pattern: Dict[str, Any] = cast(Dict[str, Any], patterns["classes"])
try:
class_nodes = node.find_all(kind="class_definition")
for class_node in class_nodes:
class_name = _extract_class_name(class_node)
range_info = class_node.range()
structure.classes.append(
{"name": class_name, "line": range_info.start.line + 1, "column": range_info.start.column + 1, "type": "class"}
)
except Exception:
# Fallback to pattern matching
class_nodes = node.find_all(pattern=class_pattern["pattern"])
for class_node in class_nodes:
class_name = _extract_node_text(class_node, "$NAME")
range_info = class_node.range()
structure.classes.append(
{"name": class_name, "line": range_info.start.line + 1, "column": range_info.start.column + 1, "type": "class"}
)
# Find imports using multiple patterns
for import_pattern in patterns["imports"]:
import_pattern = cast(Dict[str, Any], import_pattern)
import_nodes = node.find_all(pattern=import_pattern.get("pattern", ""))
for import_node in import_nodes:
module_name = _extract_node_text(import_node, "$MODULE")
range_info = import_node.range()
raw_pat = import_pattern.get("pattern", "")
import_type = "from_import" if "from" in raw_pat else "import"
structure.imports.append(
{"module": module_name, "line": range_info.start.line + 1, "column": range_info.start.column + 1, "type": import_type}
)
except Exception as e:
logger.warning("Python structure analysis partial failure", error=str(e))
def _analyze_js_ts_structure(node, structure: StructureInfo):
"""Analyze JavaScript/TypeScript-specific structures using enhanced patterns."""
try:
lang_key = "javascript" if structure.language == "javascript" else "typescript"
patterns: Dict[str, Any] = STRUCTURE_PATTERNS[lang_key]
# Find functions using multiple patterns
raw_funcs = patterns["functions"]
function_patterns: List[Dict[str, Any]] = list(raw_funcs) if isinstance(raw_funcs, list) else [cast(Dict[str, Any], raw_funcs)]
for func_pattern in function_patterns:
try:
if "kind" in func_pattern:
# Use kind-based matching
function_nodes = node.find_all(kind=func_pattern["kind"])
else:
# Use pattern matching
function_nodes = node.find_all(pattern=func_pattern["pattern"])
for func_node in function_nodes:
func_name = _extract_js_function_name(func_node)
range_info = func_node.range()
func_type = func_pattern.get("kind", "function")
structure.functions.append(
{"name": func_name, "line": range_info.start.line + 1, "column": range_info.start.column + 1, "type": func_type}
)
except Exception as e:
logger.debug(f"Function pattern matching failed: {e}")
continue
# Find class definitions
class_pattern: Dict[str, Any] = cast(Dict[str, Any], patterns["classes"])
try:
if "kind" in class_pattern:
class_nodes = node.find_all(kind=class_pattern["kind"])
else:
class_nodes = node.find_all(pattern=class_pattern["pattern"])
for class_node in class_nodes:
class_name = _extract_class_name(class_node)
range_info = class_node.range()
structure.classes.append(
{"name": class_name, "line": range_info.start.line + 1, "column": range_info.start.column + 1, "type": "class"}
)
except Exception as e:
logger.debug(f"Class pattern matching failed: {e}")
# Find imports using multiple patterns
for import_pattern in patterns["imports"]:
try:
import_pattern = cast(Dict[str, Any], import_pattern)
import_nodes = node.find_all(pattern=import_pattern.get("pattern", ""))
for import_node in import_nodes:
module_name = _extract_node_text(import_node, "$module")
range_info = import_node.range()
raw_pat = import_pattern.get("pattern", "")
import_type = "require" if "require" in raw_pat else "import"
structure.imports.append(
{
"module": module_name,
"line": range_info.start.line + 1,
"column": range_info.start.column + 1,
"type": import_type,
}
)
except Exception as e:
logger.debug(f"Import pattern matching failed: {e}")
continue
# Find exports
if "exports" in patterns:
for export_pattern in patterns["exports"]:
try:
export_pattern = cast(Dict[str, Any], export_pattern)
export_nodes = node.find_all(pattern=export_pattern.get("pattern", ""))
for export_node in export_nodes:
range_info = export_node.range()
raw_pat = export_pattern.get("pattern", "")
export_type = "module.exports" if "module.exports" in raw_pat else "export"
structure.exports.append(
{"line": range_info.start.line + 1, "column": range_info.start.column + 1, "type": export_type}
)
except Exception as e:
logger.debug(f"Export pattern matching failed: {e}")
continue
except Exception as e:
logger.warning("JS/TS structure analysis partial failure", error=str(e))
def _extract_node_text(node, meta_var: str) -> str:
"""Extract text from a meta-variable match.
NOTE: ast-grep requires UPPERCASE metavariables ($NAME not $name)
"""
try:
# Enhanced meta-variable extraction based on ast-grep MCP patterns
if hasattr(node, "get_match"):
match_result = node.get_match(meta_var)
if match_result and hasattr(match_result, "text"):
text_val = match_result.text()
if isinstance(text_val, str):
return text_val
# Fallback: extract from node text using basic parsing
node_text = node.text() if hasattr(node, "text") else str(node)
if meta_var in ["$NAME", "$name"]:
# Extract identifier after def/class/function keywords
for keyword in ["def ", "class ", "function ", "const ", "let ", "var "]:
if keyword in node_text:
parts = node_text.split(keyword, 1)
if len(parts) > 1:
# Get the next word (identifier)
remaining = parts[1].strip()
identifier = remaining.split()[0] if remaining else ""
# Clean up common separators
identifier = identifier.split("(")[0].split(":")[0].split("=")[0]
return identifier
# Basic fallback
return meta_var.replace("$", "")
except Exception as e:
logger.debug(f"Meta-variable extraction failed: {e}")
return "unknown"
def _extract_function_name(node) -> str:
"""Extract function name from function definition node."""
try:
# Try to get function name from AST node structure
if hasattr(node, "field"):
name_node = node.field("name")
if name_node and hasattr(name_node, "text"):
text_val = name_node.text()
if isinstance(text_val, str):
return text_val
# Fallback to text parsing
if hasattr(node, "text"):
text = node.text()
if isinstance(text, str) and "def " in text:
# Extract function name after 'def '
parts = text.split("def ", 1)
if len(parts) > 1:
name_part = parts[1].split("(")[0].strip()
return name_part
return "unknown_function"
except Exception:
return "unknown_function"
def _extract_js_function_name(node) -> str:
"""Extract function name from JavaScript/TypeScript function node."""
try:
# Try to get function name from AST node structure
if hasattr(node, "field"):
name_node = node.field("name")
if name_node and hasattr(name_node, "text"):
text_val = name_node.text()
if isinstance(text_val, str):
return text_val
# Fallback to text parsing for different JS function types
if hasattr(node, "text"):
text = node.text()
if isinstance(text, str):
# Function declaration: function name() {}
if "function " in text:
parts = text.split("function ", 1)
if len(parts) > 1:
name_part = parts[1].split("(")[0].strip()
return name_part
# Arrow function: const name = () => {}
if " = " in text and "=>" in text:
parts = text.split(" = ", 1)
if len(parts) > 1:
# Extract variable name before =
name_candidates = parts[0].split()
for candidate in reversed(name_candidates):
if candidate not in ["const", "let", "var", "export"]:
return candidate.strip()
# Method definition: name() {}
if "(" in text and "{" in text:
lines = text.split("\n")
first_line = lines[0].strip()
if first_line.endswith("{") or "){" in first_line:
name_part = first_line.split("(")[0].strip()
# Remove access modifiers and keywords
name_part = name_part.replace("async ", "").replace("static ", "")
name_part = name_part.replace("public ", "").replace("private ", "").replace("protected ", "")
return name_part
return "unknown_function"
except Exception:
return "unknown_function"
def _extract_class_name(node) -> str:
"""Extract class name from class definition node."""
try:
# Try to get class name from AST node structure
if hasattr(node, "field"):
name_node = node.field("name")
if name_node and hasattr(name_node, "text"):
text_val = name_node.text()
if isinstance(text_val, str):
return text_val
# Fallback to text parsing
if hasattr(node, "text"):
text = node.text()
if isinstance(text, str) and "class " in text:
# Extract class name after 'class '
parts = text.split("class ", 1)
if len(parts) > 1:
name_part = parts[1].split("(")[0].split(":")[0].strip()
return name_part
return "unknown_class"
except Exception:
return "unknown_class"
def find_references(symbol: str, path: str, symbol_type: str = "any") -> List[PatternSearchResult]:
"""
Find all references to a specific symbol using improved pattern matching.
Based on debugging results: use simple text patterns that work reliably
rather than complex AST patterns that may fail.
Args:
symbol: Symbol name to search for
path: File or directory path to search in
symbol_type: Type of symbol ("function", "class", "variable", "any")
Returns:
List of PatternSearchResult objects
Raises:
AstSearchError: If search operation fails
"""
if not _is_ast_grep_available():
# Fallback to simple text search if ast-grep not available
logger.warning("ast-grep-py not available, using fallback text search")
return _fallback_text_search(symbol, path)
logger.info("Finding references", symbol=symbol, path=path, symbol_type=symbol_type)
try:
results = []
# Based on debugging: use simple, reliable patterns
# Complex patterns often fail, so prefer basic ones that work
basic_patterns = [
symbol, # Simple identifier - most reliable
]
# Add specific patterns based on symbol type
if symbol_type in ["function", "any"]:
basic_patterns.extend(
[
f"def {symbol}(", # Python function definition
f"function {symbol}(", # JS function definition
f"{symbol}(", # Function calls
]
)
if symbol_type in ["class", "any"]:
basic_patterns.extend(
[
f"class {symbol}", # Class definitions
f"new {symbol}(", # Class instantiation
f"{symbol}.", # Method/property access
]
)
if symbol_type in ["variable", "any"]:
basic_patterns.extend(
[
f"{symbol} =", # Variable assignment
f"const {symbol}", # JS/TS const declaration
f"let {symbol}", # JS/TS let declaration
f"var {symbol}", # JS var declaration
]
)
# Search with each pattern across supported languages
for language in ["python", "javascript", "typescript"]:
for pattern in basic_patterns:
try:
pattern_results = search_code_patterns(pattern, language, path)
# Filter to ensure we found the right symbol
for result in pattern_results:
if symbol in result.text:
results.append(result)
except AstSearchError as e:
logger.debug(f"Pattern search failed for {pattern} in {language}: {e}")
continue
except Exception as e:
logger.debug(f"Unexpected error searching {pattern}: {e}")
continue
# Remove duplicates based on file path and line number
unique_results = []
seen = set()
for result in results:
key = (result.file_path, result.line)
if key not in seen:
seen.add(key)
unique_results.append(result)
logger.info("Reference search completed", symbol=symbol, symbol_type=symbol_type, references_found=len(unique_results))
return unique_results
except Exception as e:
logger.error("Reference search failed", error=str(e), symbol=symbol, path=path)
# Try fallback search as last resort
try:
return _fallback_text_search(symbol, path)
except Exception:
raise AstSearchError(f"Reference search failed: {e}")
def _fallback_text_search(symbol: str, path: str) -> List[PatternSearchResult]:
"""
Fallback text-based search when ast-grep is not available.
This provides basic symbol finding capabilities without AST parsing.
"""
results: List[PatternSearchResult] = []
try:
if os.path.isfile(path):
files_to_search = [path]
elif os.path.isdir(path):
files_to_search = []
for root, dirs, files in os.walk(path):
# Skip common exclude directories
dirs[:] = [d for d in dirs if d not in [".git", "node_modules", "__pycache__", ".venv"]]
for file in files:
if any(file.endswith(ext) for ext in [".py", ".js", ".jsx", ".ts", ".tsx"]):
files_to_search.append(os.path.join(root, file))
else:
return results
for file_path in files_to_search:
try:
with open(file_path, "r", encoding="utf-8") as f:
lines = f.readlines()
for line_num, line in enumerate(lines, 1):
if symbol in line:
# Simple column calculation
column = line.find(symbol) + 1
result = PatternSearchResult(
file_path=file_path, line=line_num, column=column, text=line.strip(), matches={"symbol": symbol}
)
results.append(result)
except Exception as e:
logger.debug(f"Failed to search file {file_path}: {e}")
continue
except Exception as e:
logger.error(f"Fallback text search failed: {e}")
return results
def create_ast_grep_rule(
rule_id: str, language: str, pattern: Optional[str] = None, kind: Optional[str] = None, regex: Optional[str] = None
) -> Dict[str, Any]:
"""
Create an ast-grep rule configuration based on MCP documentation format.
Args:
rule_id: Unique identifier for the rule
language: Target language (python, javascript, typescript, etc.)
pattern: AST pattern to match (e.g., "def $name($args): $body")
kind: AST node kind to match (e.g., "function_definition")
regex: Regular expression pattern to match
Returns:
Dictionary containing the rule configuration
Raises:
ValueError: If no matching criteria provided
"""
if not any([pattern, kind, regex]):
raise ValueError("At least one of pattern, kind, or regex must be provided")
rule: Dict[str, Any] = {"id": rule_id, "language": language, "rule": cast(Dict[str, Any], {})}
inner_rule = cast(Dict[str, Any], rule["rule"])
# Add matching criteria based on ast-grep MCP documentation
if pattern:
inner_rule["pattern"] = pattern
if kind:
inner_rule["kind"] = kind
if regex:
inner_rule["regex"] = regex
return rule
def search_with_rule(rule_config: Dict[str, Any], path: str, exclude_patterns: Optional[List[str]] = None) -> List[PatternSearchResult]:
"""
Search using a complete ast-grep rule configuration.
Args:
rule_config: Complete rule configuration with id, language, and rule
path: File or directory path to search
exclude_patterns: Optional patterns to exclude
Returns:
List of PatternSearchResult objects
Raises:
AstSearchError: If search operation fails
"""
if not _is_ast_grep_available():
raise AstSearchError("ast-grep-py is not available")
try:
# Validate rule configuration
required_fields = ["id", "language", "rule"]
for field in required_fields:
if field not in rule_config:
raise ValueError(f"Rule configuration missing required field: {field}")
language = rule_config["language"]
rule = rule_config["rule"]
# Extract pattern from rule (support multiple rule types)
search_pattern = None
if "pattern" in rule:
search_pattern = rule["pattern"]
elif "kind" in rule:
# For kind-based rules, we'll use the kind directly
search_pattern = rule["kind"]
elif "regex" in rule:
# For regex rules, use the regex pattern
search_pattern = rule["regex"]
else:
raise ValueError("Rule must contain at least one of: pattern, kind, regex")
logger.info("Searching with rule", rule_id=rule_config["id"], language=language, pattern=search_pattern)
# Use existing search infrastructure with enhanced rule
return search_code_patterns(search_pattern, language, path, exclude_patterns)
except Exception as e:
logger.error("Rule-based search failed", error=str(e), rule_id=rule_config.get("id", "unknown"))
raise AstSearchError(f"Rule-based search failed: {e}")
def validate_ast_grep_rule(rule_config: Dict[str, Any]) -> bool:
"""
Validate ast-grep rule configuration based on MCP documentation.
Args:
rule_config: Rule configuration to validate
Returns:
True if valid, False otherwise
"""
try:
# Check required top-level fields
required_fields = ["id", "language", "rule"]
for field in required_fields:
if field not in rule_config:
logger.warning(f"Rule validation failed: missing field '{field}'")
return False
rule = rule_config["rule"]
# Check that rule has at least one matching criterion
matching_fields = ["pattern", "kind", "regex", "inside", "has", "follows", "precedes"]
has_matching_field = any(field in rule for field in matching_fields)
if not has_matching_field:
logger.warning("Rule validation failed: no matching criteria specified")
return False
# Validate language
supported_languages = ["python", "javascript", "typescript", "json", "yaml", "html", "css"]
if rule_config["language"] not in supported_languages:
logger.warning(f"Rule validation warning: unsupported language '{rule_config['language']}'")
return True
except Exception as e:
logger.error("Rule validation error", error=str(e))
return False
# Usage Examples and Best Practices
# Based on debugging results and ast-grep MCP documentation:
def get_example_patterns() -> Dict[str, Dict[str, str]]:
"""
Get working example patterns based on debugging results.
These patterns have been tested and verified to work with the MCP server.
"""
return {
"python": {
"function_definition": "def $NAME($ARGS)",
"class_definition": "class $NAME",
"function_call": "$NAME($ARGS)",
"import_statement": "import $MODULE",
"from_import": "from $MODULE import $ITEMS",
"exception_handler": "except $EXCEPTION as $VAR:",
"multi_import": "import $$$MODULES", # Multiple modules
},
"javascript": {
"function_declaration": "function $NAME($ARGS)",
"arrow_function": "const $NAME = ($ARGS) => $BODY",
"class_definition": "class $NAME",
"import_statement": "import $ITEMS from '$MODULE'",
"export_statement": "export $DECLARATION",
},
"typescript": {
"typed_function": "function $NAME($ARGS): $TYPE",
"typed_arrow": "const $NAME = ($ARGS): $TYPE => $BODY",
"interface_definition": "interface $NAME",
"type_definition": "type $NAME = $TYPE",
},
}
def get_example_rules() -> Dict[str, Dict[str, Any]]:
"""
Get working example YAML rules based on debugging results.
IMPORTANT: All rules MUST include a 'kind' field for AST node type.
"""
return {
"find_functions": {
"id": "find-functions",
"language": "python",
"rule": {"kind": "function_definition", "pattern": "def $NAME($ARGS)"},
},
"find_classes": {"id": "find-classes", "language": "python", "rule": {"kind": "class_definition", "pattern": "class $NAME"}},
"find_exception_handlers": {
"id": "find-exception-handlers",
"language": "python",
"rule": {
"kind": "try_statement",
"has": {"kind": "except_clause", "has": {"kind": "as_pattern", "pattern": "$EXCEPTION as $VAR"}},
},
},
"find_js_functions": {
"id": "find-js-functions",
"language": "javascript",
"rule": {"kind": "function_declaration", "pattern": "function $NAME($ARGS)"},
},
}
def debug_pattern_issues() -> Dict[str, str]:
"""
Common debugging tips for pattern matching issues.
Based on actual debugging experience with the MCP server.
"""
return {
"uppercase_variables": "CRITICAL: Use UPPERCASE metavariables ($NAME not $name)",
"kind_field_required": "CRITICAL: All rules must specify 'kind' field for AST node type",
"simple_patterns_work": "TIP: Simple patterns like 'def $NAME($ARGS)' work better than complex ones",
"reference_finding": "TIP: Use simple text patterns for reference finding, not complex AST patterns",
"fallback_available": "INFO: Fallback text search available when ast-grep fails",
"error_handling": "INFO: All functions include proper error handling and logging",
"pattern_testing": "TIP: Test patterns incrementally, start simple then add complexity",
}
# Error handling improvements based on debugging
class ImprovedAstSearchError(AstSearchError):
"""Enhanced error class with debugging context."""
def __init__(
self, message: str, pattern: Optional[str] = None, language: Optional[str] = None, suggestions: Optional[List[str]] = None
):
super().__init__(message)
self.pattern = pattern
self.language = language
self.suggestions = suggestions or []
def __str__(self) -> str:
base_msg = super().__str__()
if self.pattern:
base_msg += f" (Pattern: {self.pattern})"
if self.language:
base_msg += f" (Language: {self.language})"
if self.suggestions:
base_msg += f" Suggestions: {', '.join(self.suggestions)}"
return base_msg
def validate_pattern_syntax(pattern: str, language: str) -> List[str]:
"""
Validate pattern syntax and provide suggestions.
Based on debugging findings about common pattern issues.
"""
issues = []
# Check for lowercase metavariables (critical issue)
if "$" in pattern:
import re
lowercase_vars = re.findall(r"\$[a-z][a-zA-Z_]*", pattern)
if lowercase_vars:
issues.append(f"Use UPPERCASE metavariables: {lowercase_vars} should be UPPERCASE")
# Check for overly complex patterns
if pattern.count("{") > 2 or pattern.count("$$$") > 3:
issues.append("Pattern may be too complex - try simpler patterns first")
# Language-specific checks
if language == "python" and "def " in pattern and "kind" not in pattern:
issues.append("Python function patterns should specify kind: 'function_definition'")
if language in ["javascript", "typescript"] and "function" in pattern and "kind" not in pattern:
issues.append("JS/TS function patterns should specify kind: 'function_declaration'")
return issues