"""AST representation models for MCP server.
This module provides functions for converting tree-sitter AST nodes to dictionaries,
finding nodes at specific positions, and other AST-related operations.
"""
from typing import Any, Dict, List, Optional, Tuple
from ..utils.tree_sitter_helpers import (
get_node_text,
walk_tree,
)
from ..utils.tree_sitter_types import ensure_node
# Import the cursor-based implementation
from .ast_cursor import node_to_dict_cursor
def node_to_dict(
node: Any,
source_bytes: Optional[bytes] = None,
include_children: bool = True,
include_text: bool = True,
max_depth: int = 5,
) -> Dict[str, Any]:
"""
Convert a tree-sitter node to a dictionary representation.
This function now uses a cursor-based traversal approach for efficiency and
reliability, especially with large ASTs that could cause stack overflow with
recursive processing.
Args:
node: Tree-sitter Node object
source_bytes: Source code bytes
include_children: Whether to include children nodes
include_text: Whether to include node text
max_depth: Maximum depth to traverse
Returns:
Dictionary representation of the node
"""
# Use the cursor-based implementation for improved reliability
return node_to_dict_cursor(node, source_bytes, include_children, include_text, max_depth)
def summarize_node(node: Any, source_bytes: Optional[bytes] = None) -> Dict[str, Any]:
"""
Create a compact summary of a node without details or children.
Args:
node: Tree-sitter Node object
source_bytes: Source code bytes
Returns:
Dictionary with basic node information
"""
safe_node = ensure_node(node)
result = {
"type": safe_node.type,
"start_point": {
"row": safe_node.start_point[0],
"column": safe_node.start_point[1],
},
"end_point": {"row": safe_node.end_point[0], "column": safe_node.end_point[1]},
}
# Add a short text snippet if source is available
if source_bytes:
try:
# Use helper function to get text safely - make sure to decode
text = get_node_text(safe_node, source_bytes, decode=True)
if isinstance(text, bytes):
text = text.decode("utf-8", errors="replace")
lines = text.splitlines()
if lines:
snippet = lines[0][:50]
if len(snippet) < len(lines[0]) or len(lines) > 1:
snippet += "..."
result["preview"] = snippet
except Exception:
pass
return result
def find_node_at_position(root_node: Any, row: int, column: int) -> Optional[Any]:
"""
Find the most specific node at a given position using cursor-based traversal.
Args:
root_node: Root node to search from
row: Row (line) number, 0-based
column: Column number, 0-based
Returns:
The most specific node at the position, or None if not found
"""
safe_node = ensure_node(root_node)
point = (row, column)
# Check if point is within root_node
if not (safe_node.start_point <= point <= safe_node.end_point):
return None
# Find the smallest node that contains the point
cursor = walk_tree(safe_node)
current_best = cursor.node
# Special handling for function definitions and identifiers
def check_for_specific_nodes(node: Any) -> Optional[Any]:
# For function definitions, check if position is over the function name
if node.type == "function_definition":
for child in node.children:
if child.type in ["identifier", "name"]:
if (
child.start_point[0] <= row <= child.end_point[0]
and child.start_point[1] <= column <= child.end_point[1]
):
return child
return None
# First check if we have a specific node like a function name
specific_node = check_for_specific_nodes(safe_node)
if specific_node:
return specific_node
while cursor.goto_first_child():
# If current node contains the point, it's better than the parent
if cursor.node is not None and cursor.node.start_point <= point <= cursor.node.end_point:
current_best = cursor.node
# Check for specific nodes like identifiers
specific_node = check_for_specific_nodes(cursor.node)
if specific_node:
return specific_node
continue # Continue to first child
# If first child doesn't contain point, try siblings
cursor.goto_parent()
current_best = cursor.node # Reset current best to parent
# Try siblings
found_in_sibling = False
while cursor.goto_next_sibling():
if cursor.node is not None and cursor.node.start_point <= point <= cursor.node.end_point:
current_best = cursor.node
# Check for specific nodes
specific_node = check_for_specific_nodes(cursor.node)
if specific_node:
return specific_node
found_in_sibling = True
break
# If a sibling contains the point, continue to its children
if found_in_sibling:
continue
else:
# No child or sibling contains the point, we're done
break
return current_best
def extract_node_path(
root_node: Any,
target_node: Any,
) -> List[Tuple[str, Optional[str]]]:
"""
Extract the path from root to a specific node using safe node handling.
Args:
root_node: Root node
target_node: Target node
Returns:
List of (node_type, field_name) tuples from root to target
"""
safe_root = ensure_node(root_node)
safe_target = ensure_node(target_node)
# If nodes are the same, return empty path
if safe_root == safe_target:
return []
path = []
current = safe_target
while current != safe_root and current.parent:
field_name = None
# Find field name if any
parent_field_names = getattr(current.parent, "children_by_field_name", {})
if hasattr(parent_field_names, "items"):
for name, nodes in parent_field_names.items():
if current in nodes:
field_name = name
break
path.append((current.type, field_name))
current = current.parent
# Add root node unless it's already the target
if current == safe_root and path:
path.append((safe_root.type, None))
# Reverse to get root->target order
return list(reversed(path))