"""Code structure extraction from tree-sitter AST."""
from typing import List, Optional
import tree_sitter
from ..config import CLASS_NODE_TYPES, FUNCTION_NODE_TYPES
from ..models import CodeElement, CodeStructure, Parameter, ParseError
from ..parsers.tree_sitter import TreeSitterManager
class StructureExtractor:
    """Extract classes, functions and syntax errors from a tree-sitter AST.

    All text and line-range lookups are delegated to the injected
    ``TreeSitterManager``; this class only decides which nodes matter and
    how the resulting elements nest.
    """

    # Wrapper node holding decorators plus the real definition (Python).
    _DECORATED_NODE_TYPE = "decorated_definition"

    # Languages whose definition nodes expose their identifier via the
    # "name" field.
    _NAME_FIELD_LANGUAGES = frozenset(
        {"python", "javascript", "typescript", "java", "csharp", "go"}
    )

    # Parameter-list children that `_extract_parameter` can read through the
    # "name"/"pattern", "type" and "value" fields (or via their raw text).
    _FIELD_PARAMETER_TYPES = frozenset(
        {
            "identifier",
            "parameter",
            "required_parameter",
            "optional_parameter",
            # Python: ``b=1`` and ``c: int = 1``
            "default_parameter",
            "typed_default_parameter",
            # Python: ``*args`` and ``**kwargs`` (name keeps the stars)
            "list_splat_pattern",
            "dictionary_splat_pattern",
            # Java: ``int x``
            "formal_parameter",
        }
    )

    # Node types that may carry the name inside a Python ``typed_parameter``.
    _TYPED_PARAMETER_NAME_TYPES = (
        "identifier",
        "list_splat_pattern",
        "dictionary_splat_pattern",
    )

    def __init__(self, parser_manager: TreeSitterManager) -> None:
        """Store the parser manager used for node text and line lookups."""
        self.parser_manager = parser_manager

    def extract_structure(
        self,
        tree: tree_sitter.Tree,
        source_code: bytes,
        language: str,
        file_path: str,
        include_docstrings: bool = False,
    ) -> CodeStructure:
        """Build a ``CodeStructure`` for a parsed file.

        Args:
            tree: Parsed tree-sitter tree.
            source_code: The bytes the tree was parsed from.
            language: Language key used to look up node types in
                ``CLASS_NODE_TYPES`` / ``FUNCTION_NODE_TYPES``.
            file_path: Path recorded on the resulting structure.
            include_docstrings: When true, docstrings are extracted too.

        Returns:
            The populated structure (errors first, then classes/functions).
        """
        structure = CodeStructure(file_path=file_path, language=language)

        # Errors are collected in a separate pass over the whole tree.
        self._collect_errors(tree.root_node, source_code, structure)

        self._extract_elements(
            tree.root_node,
            source_code,
            language,
            structure,
            include_docstrings=include_docstrings,
            nesting_level=0,
        )
        return structure

    def _collect_errors(
        self,
        node: tree_sitter.Node,
        source_code: bytes,
        structure: CodeStructure,
    ) -> None:
        """Record every ``ERROR`` node under ``node`` on ``structure``.

        Nodes are visited in document (pre-order) order. An explicit stack is
        used instead of recursion so deeply nested trees cannot exhaust the
        interpreter's recursion limit.
        """
        pending = [node]
        while pending:
            current = pending.pop()
            if current.type == "ERROR":
                context = self.parser_manager.get_node_text(current, source_code)
                if len(context) > 100:
                    # Keep the error context short.
                    context = context[:100] + "..."
                structure.add_error(
                    ParseError(
                        line=current.start_point.row + 1,  # rows are 0-based
                        message="Syntax error",
                        context=context,
                    )
                )
            # Reversed so that pop() yields children left to right.
            pending.extend(reversed(current.children))

    def _extract_elements(
        self,
        node: tree_sitter.Node,
        source_code: bytes,
        language: str,
        structure: CodeStructure,
        include_docstrings: bool,
        nesting_level: int,
        parent: Optional[CodeElement] = None,
    ) -> None:
        """Walk ``node`` and register every class and function found.

        Top-level elements go to ``structure``; nested ones are attached to
        the closest enclosing extracted element. Nodes that look like a
        definition but yield no element (e.g. anonymous functions, or a
        decorator wrapper around a class) are treated as transparent: their
        children keep the current ``parent`` and ``nesting_level``.
        """
        element: Optional[CodeElement] = None
        is_class = False

        if node.type in CLASS_NODE_TYPES.get(language, []):
            element = self._extract_class(
                node, source_code, language, nesting_level, include_docstrings
            )
            is_class = element is not None

        if element is None and node.type in FUNCTION_NODE_TYPES.get(language, []):
            element = self._extract_function(
                node, source_code, language, nesting_level, include_docstrings
            )

        if element is None:
            # Not a (named) definition: do not lose the enclosing element.
            for child in node.children:
                self._extract_elements(
                    child,
                    source_code,
                    language,
                    structure,
                    include_docstrings,
                    nesting_level,
                    parent,
                )
            return

        if parent is not None:
            parent.add_child(element)
        elif is_class:
            structure.add_class(element)
        else:
            structure.add_function(element)

        # For a decorator wrapper, descend into the wrapped definition so the
        # definition itself is not extracted a second time as its own child.
        for child in self._definition_node(node, language).children:
            self._extract_elements(
                child,
                source_code,
                language,
                structure,
                include_docstrings,
                nesting_level + 1,
                element,
            )

    def _definition_node(
        self,
        node: tree_sitter.Node,
        language: str,
    ) -> tree_sitter.Node:
        """Return the function wrapped by a decorator node, else ``node``."""
        if node.type != self._DECORATED_NODE_TYPE:
            return node
        function_types = FUNCTION_NODE_TYPES.get(language, [])
        for child in node.children:
            if (
                child.type != self._DECORATED_NODE_TYPE
                and child.type in function_types
            ):
                return child
        return node

    def _extract_class(
        self,
        node: tree_sitter.Node,
        source_code: bytes,
        language: str,
        nesting_level: int,
        include_docstrings: bool,
    ) -> Optional[CodeElement]:
        """Build a ``CodeElement`` for a class node.

        Returns:
            The element, or ``None`` when no name can be determined.
        """
        name = self._extract_name(node, source_code, language, "class")
        if not name:
            return None

        start_line, end_line = self.parser_manager.get_node_line_range(node)
        docstring = None
        if include_docstrings:
            docstring = self._extract_docstring(node, source_code, language)

        return CodeElement(
            name=name,
            element_type="class",
            start_line=start_line,
            end_line=end_line,
            nesting_level=nesting_level,
            docstring=docstring,
        )

    def _extract_function(
        self,
        node: tree_sitter.Node,
        source_code: bytes,
        language: str,
        nesting_level: int,
        include_docstrings: bool,
    ) -> Optional[CodeElement]:
        """Build a ``CodeElement`` for a function node.

        A ``decorated_definition`` is resolved to the function it wraps; the
        resulting element describes that inner function.

        Returns:
            The element, or ``None`` when the node has no name or wraps
            something that is not a function.
        """
        if node.type == self._DECORATED_NODE_TYPE:
            wrapped = self._definition_node(node, language)
            if wrapped is node:
                # Decorators around something that is not a function.
                return None
            return self._extract_function(
                wrapped, source_code, language, nesting_level, include_docstrings
            )

        name = self._extract_name(node, source_code, language, "function")
        if not name:
            return None

        start_line, end_line = self.parser_manager.get_node_line_range(node)
        parameters = self._extract_parameters(node, source_code, language)
        return_type = self._extract_return_type(node, source_code, language)
        docstring = None
        if include_docstrings:
            docstring = self._extract_docstring(node, source_code, language)

        return CodeElement(
            name=name,
            element_type="function",
            start_line=start_line,
            end_line=end_line,
            nesting_level=nesting_level,
            parameters=parameters,
            return_type=return_type,
            docstring=docstring,
        )

    def _extract_name(
        self,
        node: tree_sitter.Node,
        source_code: bytes,
        language: str,
        element_type: str,
    ) -> Optional[str]:
        """Return the identifier of a class or function node.

        Args:
            node: The definition node.
            source_code: The bytes the tree was parsed from.
            language: Language key; unsupported languages yield ``None``.
            element_type: ``"class"`` or ``"function"``.

        Returns:
            The name text, or ``None`` for unsupported languages and for
            nodes without a name (e.g. anonymous functions, which are
            deliberately skipped).
        """
        if language not in self._NAME_FIELD_LANGUAGES:
            return None

        name_node = node.child_by_field_name("name")

        if name_node is None and language == "python" and element_type != "class":
            # A decorator wrapper has no name itself; look at the wrapped
            # function definition instead.
            for child in node.children:
                if child.type in ("function_definition", "async_function_definition"):
                    name_node = child.child_by_field_name("name")
                    break

        if name_node is None:
            return None
        return self.parser_manager.get_node_text(name_node, source_code)

    def _extract_parameters(
        self,
        node: tree_sitter.Node,
        source_code: bytes,
        language: str,
    ) -> List[Parameter]:
        """Return the parameters declared by a function node.

        Only children of the node's ``parameters`` field are considered.
        Punctuation and unrecognised child types are ignored.
        """
        parameters: List[Parameter] = []

        params_node = node.child_by_field_name("parameters")
        if params_node is None:
            return parameters

        for child in params_node.children:
            if child.type == "typed_parameter":
                param = self._extract_typed_parameter(child, source_code, language)
            elif child.type in self._FIELD_PARAMETER_TYPES:
                param = self._extract_parameter(child, source_code, language)
            else:
                continue
            if param:
                parameters.append(param)
        return parameters

    def _extract_parameter(
        self,
        node: tree_sitter.Node,
        source_code: bytes,
        language: str,
    ) -> Optional[Parameter]:
        """Build a ``Parameter`` from a node using its named fields.

        The name is taken from the ``name`` field, then from the ``pattern``
        field (used by TypeScript parameters), and finally from the node's
        full text (bare identifiers, ``*args``). Type and default come from
        the ``type`` and ``value`` fields when present.
        """
        name_node = node.child_by_field_name("name")
        if name_node is None:
            name_node = node.child_by_field_name("pattern")
        if name_node is None:
            name_node = node  # bare identifier / splat: the node is the name
        name = self.parser_manager.get_node_text(name_node, source_code)

        type_annotation = None
        type_node = node.child_by_field_name("type")
        if type_node is not None:
            type_annotation = self.parser_manager.get_node_text(
                type_node, source_code
            )

        default_value = None
        default_node = node.child_by_field_name("value")
        if default_node is not None:
            default_value = self.parser_manager.get_node_text(
                default_node, source_code
            )

        return Parameter(
            name=name, type_annotation=type_annotation, default_value=default_value
        )

    def _extract_typed_parameter(
        self,
        node: tree_sitter.Node,
        source_code: bytes,
        language: str,
    ) -> Optional[Parameter]:
        """Build a ``Parameter`` from a Python ``typed_parameter`` node.

        The name is not exposed as a field on this node type, so it is read
        from the first identifier or splat child (``a``, ``*args``,
        ``**kwargs``); it stays ``""`` if none is found. The annotation comes
        from the ``type`` field. A default is read from a ``default`` or
        ``value`` field if the node happens to provide one.
        """
        name = ""
        for child in node.children:
            if child.type in self._TYPED_PARAMETER_NAME_TYPES:
                name = self.parser_manager.get_node_text(child, source_code)
                break

        type_annotation = None
        type_node = node.child_by_field_name("type")
        if type_node is not None:
            type_annotation = self.parser_manager.get_node_text(
                type_node, source_code
            )

        default_value = None
        default_node = node.child_by_field_name("default")
        if default_node is None:
            default_node = node.child_by_field_name("value")
        if default_node is not None:
            default_value = self.parser_manager.get_node_text(
                default_node, source_code
            )

        return Parameter(
            name=name, type_annotation=type_annotation, default_value=default_value
        )

    def _extract_return_type(
        self,
        node: tree_sitter.Node,
        source_code: bytes,
        language: str,
    ) -> Optional[str]:
        """Return the text of the node's ``return_type`` field, if any."""
        type_node = node.child_by_field_name("return_type")
        if type_node is None:
            return None
        return self.parser_manager.get_node_text(type_node, source_code)

    def _extract_docstring(
        self,
        node: tree_sitter.Node,
        source_code: bytes,
        language: str,
    ) -> Optional[str]:
        """Return the docstring of a class or function node.

        Only Python is supported: the docstring is the string literal that
        forms the first statement of the body (comments before it are
        skipped). For the other languages documentation lives in comments
        that precede the declaration, which is not handled here, so ``None``
        is returned.
        """
        if language != "python":
            return None

        body_node = node.child_by_field_name("body")
        if body_node is None:
            return None

        # Comments are not statements, so they cannot displace a docstring.
        first_statement = next(
            (child for child in body_node.children if child.type != "comment"),
            None,
        )
        if (
            first_statement is None
            or first_statement.type != "expression_statement"
            or not first_statement.children
        ):
            return None

        literal = first_statement.children[0]
        if literal.type not in ("string", "string_content"):
            return None

        text = self.parser_manager.get_node_text(literal, source_code)
        return self._strip_string_quotes(text)

    @staticmethod
    def _strip_string_quotes(text: str) -> str:
        """Remove an ``r``/``u`` prefix and surrounding quotes from a literal.

        Text that does not start with a quote after the optional prefix is
        returned unchanged.
        """
        if not text:
            return text
        unprefixed = text.lstrip("rRuU")
        # Triple quotes must be tested before single ones.
        for quote in ('"""', "'''", '"', "'"):
            if unprefixed.startswith(quote):
                return unprefixed[len(quote):-len(quote)]
        return text