"""AST-based code chunking for semantic extraction."""
import hashlib
from pathlib import Path
from typing import Iterator
from tree_sitter import Node
from local_deepwiki.config import ChunkingConfig, get_config
from local_deepwiki.core.parser import (
CodeParser,
get_node_text,
get_node_name,
get_docstring,
find_nodes_by_type,
)
from local_deepwiki.models import CodeChunk, ChunkType, Language
# Node types that represent extractable code units per language
FUNCTION_NODE_TYPES: dict[Language, set[str]] = {
Language.PYTHON: {"function_definition", "async_function_definition"},
Language.JAVASCRIPT: {"function_declaration", "arrow_function", "function_expression", "method_definition"},
Language.TYPESCRIPT: {"function_declaration", "arrow_function", "function_expression", "method_definition"},
Language.GO: {"function_declaration", "method_declaration"},
Language.RUST: {"function_item"},
Language.JAVA: {"method_declaration", "constructor_declaration"},
Language.C: {"function_definition"},
Language.CPP: {"function_definition"},
Language.SWIFT: {"function_declaration", "init_declaration"},
}
CLASS_NODE_TYPES: dict[Language, set[str]] = {
Language.PYTHON: {"class_definition"},
Language.JAVASCRIPT: {"class_declaration"},
Language.TYPESCRIPT: {"class_declaration", "interface_declaration", "type_alias_declaration"},
Language.GO: {"type_declaration"},
Language.RUST: {"struct_item", "impl_item", "trait_item", "enum_item"},
Language.JAVA: {"class_declaration", "interface_declaration", "enum_declaration"},
Language.C: {"struct_specifier"},
Language.CPP: {"class_specifier", "struct_specifier"},
Language.SWIFT: {"class_declaration", "struct_declaration", "protocol_declaration", "enum_declaration", "extension_declaration"},
}
IMPORT_NODE_TYPES: dict[Language, set[str]] = {
Language.PYTHON: {"import_statement", "import_from_statement"},
Language.JAVASCRIPT: {"import_statement", "import_declaration"},
Language.TYPESCRIPT: {"import_statement", "import_declaration"},
Language.GO: {"import_declaration"},
Language.RUST: {"use_declaration"},
Language.JAVA: {"import_declaration"},
Language.C: {"preproc_include"},
Language.CPP: {"preproc_include"},
Language.SWIFT: {"import_declaration"},
}
def get_parent_classes(class_node: Node, source: bytes, language: Language) -> list[str]:
"""Extract parent class names from a class definition.
Args:
class_node: The class AST node.
source: Source bytes.
language: Programming language.
Returns:
List of parent class names.
"""
parents = []
if language == Language.PYTHON:
# Python: class Child(Parent, Mixin): → argument_list > identifier
for child in class_node.children:
if child.type == "argument_list":
for arg in child.children:
if arg.type == "identifier":
parents.append(get_node_text(arg, source))
elif language in (Language.TYPESCRIPT, Language.JAVASCRIPT):
# TS/JS: class Child extends Parent implements Interface
for child in class_node.children:
if child.type == "class_heritage":
for clause in child.children:
if clause.type in ("extends_clause", "implements_clause"):
for item in clause.children:
if item.type in ("identifier", "type_identifier"):
parents.append(get_node_text(item, source))
elif language == Language.JAVA:
# Java: class Child extends Parent implements Interface
for child in class_node.children:
if child.type == "superclass":
for item in child.children:
if item.type == "type_identifier":
parents.append(get_node_text(item, source))
elif child.type == "super_interfaces":
for item in find_nodes_by_type(child, {"type_identifier"}):
parents.append(get_node_text(item, source))
elif language == Language.SWIFT:
# Swift: class Child: Parent, Protocol
for child in class_node.children:
if child.type == "type_inheritance_clause":
for item in child.children:
if item.type in ("user_type", "type_identifier"):
# Get the identifier from user_type
text = get_node_text(item, source)
if text and text not in (":", ","):
parents.append(text)
elif language == Language.CPP:
# C++: class Child : public Parent
for child in class_node.children:
if child.type == "base_class_clause":
for item in find_nodes_by_type(child, {"type_identifier"}):
parents.append(get_node_text(item, source))
return parents
class CodeChunker:
"""Extract semantic code chunks from source files using AST analysis."""
def __init__(self, config: ChunkingConfig | None = None):
"""Initialize the chunker.
Args:
config: Optional chunking configuration.
"""
self.config = config or get_config().chunking
self.parser = CodeParser()
def chunk_file(self, file_path: Path, repo_root: Path) -> Iterator[CodeChunk]:
"""Extract code chunks from a source file.
Args:
file_path: Path to the source file.
repo_root: Root directory of the repository.
Yields:
CodeChunk objects for each semantic unit found.
"""
result = self.parser.parse_file(file_path)
if result is None:
return
root, language, source = result
rel_path = str(file_path.relative_to(repo_root))
# Extract module-level chunk (file overview)
yield self._create_module_chunk(root, source, language, rel_path)
# Extract imports
import_types = IMPORT_NODE_TYPES.get(language, set())
import_nodes = find_nodes_by_type(root, import_types)
if import_nodes:
yield self._create_imports_chunk(import_nodes, source, language, rel_path)
# Extract classes and their methods
class_types = CLASS_NODE_TYPES.get(language, set())
for class_node in find_nodes_by_type(root, class_types):
yield from self._extract_class_chunks(class_node, source, language, rel_path)
# Extract top-level functions (not inside classes)
function_types = FUNCTION_NODE_TYPES.get(language, set())
for func_node in find_nodes_by_type(root, function_types):
# Skip if inside a class (already processed)
if not self._is_inside_class(func_node, class_types):
yield self._create_function_chunk(func_node, source, language, rel_path)
def _create_module_chunk(
self,
root: Node,
source: bytes,
language: Language,
file_path: str,
) -> CodeChunk:
"""Create a chunk for the module/file overview.
Args:
root: AST root node.
source: Source bytes.
language: Programming language.
file_path: Relative file path.
Returns:
A CodeChunk for the module.
"""
# Get module docstring if present
docstring = None
if language == Language.PYTHON:
# Python module docstring is first expression
if root.children and root.children[0].type == "expression_statement":
expr = root.children[0]
if expr.children and expr.children[0].type == "string":
docstring = get_node_text(expr.children[0], source)
if docstring.startswith('"""') or docstring.startswith("'''"):
docstring = docstring[3:-3].strip()
# Create a summary of the file structure
content = self._create_file_summary(root, source, language)
chunk_id = self._generate_id(file_path, "module", 0)
return CodeChunk(
id=chunk_id,
file_path=file_path,
language=language,
chunk_type=ChunkType.MODULE,
name=Path(file_path).stem,
content=content,
start_line=1,
end_line=source.count(b"\n") + 1,
docstring=docstring,
metadata={"is_overview": True},
)
def _create_file_summary(self, root: Node, source: bytes, language: Language) -> str:
"""Create a summary of file structure for the module chunk.
Args:
root: AST root node.
source: Source bytes.
language: Programming language.
Returns:
A summary string of file contents.
"""
parts = []
# List imports
import_types = IMPORT_NODE_TYPES.get(language, set())
imports = find_nodes_by_type(root, import_types)
if imports:
import_text = "\n".join(get_node_text(n, source) for n in imports[:10])
if len(imports) > 10:
import_text += f"\n# ... and {len(imports) - 10} more imports"
parts.append(f"# Imports:\n{import_text}")
# List classes
class_types = CLASS_NODE_TYPES.get(language, set())
classes = find_nodes_by_type(root, class_types)
if classes:
class_names = [get_node_name(c, source, language) or "anonymous" for c in classes]
parts.append(f"# Classes: {', '.join(class_names)}")
# List functions
function_types = FUNCTION_NODE_TYPES.get(language, set())
functions = [
f for f in find_nodes_by_type(root, function_types)
if not self._is_inside_class(f, class_types)
]
if functions:
func_names = [get_node_name(f, source, language) or "anonymous" for f in functions]
parts.append(f"# Functions: {', '.join(func_names)}")
return "\n\n".join(parts) if parts else "# Empty file"
def _create_imports_chunk(
self,
import_nodes: list[Node],
source: bytes,
language: Language,
file_path: str,
) -> CodeChunk:
"""Create a chunk for import statements.
Args:
import_nodes: List of import nodes.
source: Source bytes.
language: Programming language.
file_path: Relative file path.
Returns:
A CodeChunk for imports.
"""
content = "\n".join(get_node_text(n, source) for n in import_nodes)
start_line = min(n.start_point[0] + 1 for n in import_nodes)
end_line = max(n.end_point[0] + 1 for n in import_nodes)
chunk_id = self._generate_id(file_path, "imports", start_line)
return CodeChunk(
id=chunk_id,
file_path=file_path,
language=language,
chunk_type=ChunkType.IMPORT,
name="imports",
content=content,
start_line=start_line,
end_line=end_line,
metadata={"import_count": len(import_nodes)},
)
def _extract_class_chunks(
self,
class_node: Node,
source: bytes,
language: Language,
file_path: str,
) -> Iterator[CodeChunk]:
"""Extract chunks from a class definition.
Args:
class_node: The class AST node.
source: Source bytes.
language: Programming language.
file_path: Relative file path.
Yields:
CodeChunks for the class and its methods.
"""
class_name = get_node_name(class_node, source, language) or "anonymous"
docstring = get_docstring(class_node, source, language)
content = get_node_text(class_node, source)
# Extract parent classes for inheritance
parent_classes = get_parent_classes(class_node, source, language)
# Check if class is too large and needs to be split
lines = content.count("\n") + 1
if lines > 100:
# For large classes, create a summary chunk and method chunks
yield self._create_class_summary_chunk(
class_node, source, language, file_path, class_name, docstring, parent_classes
)
# Extract methods separately
function_types = FUNCTION_NODE_TYPES.get(language, set())
for method_node in find_nodes_by_type(class_node, function_types):
yield self._create_method_chunk(
method_node, source, language, file_path, class_name
)
else:
# Small class - include everything in one chunk
chunk_id = self._generate_id(file_path, f"class_{class_name}", class_node.start_point[0])
metadata = {"line_count": lines}
if parent_classes:
metadata["parent_classes"] = parent_classes
yield CodeChunk(
id=chunk_id,
file_path=file_path,
language=language,
chunk_type=ChunkType.CLASS,
name=class_name,
content=content,
start_line=class_node.start_point[0] + 1,
end_line=class_node.end_point[0] + 1,
docstring=docstring,
metadata=metadata,
)
def _create_class_summary_chunk(
self,
class_node: Node,
source: bytes,
language: Language,
file_path: str,
class_name: str,
docstring: str | None,
parent_classes: list[str] | None = None,
) -> CodeChunk:
"""Create a summary chunk for a large class.
Args:
class_node: The class AST node.
source: Source bytes.
language: Programming language.
file_path: Relative file path.
class_name: Name of the class.
docstring: Class docstring if any.
parent_classes: List of parent class names.
Returns:
A summary CodeChunk for the class.
"""
# Get class signature and method list
function_types = FUNCTION_NODE_TYPES.get(language, set())
methods = find_nodes_by_type(class_node, function_types)
method_names = [get_node_name(m, source, language) or "anonymous" for m in methods]
# Build summary content
signature_end = class_node.start_byte
for child in class_node.children:
if child.type in ("block", "class_body", "declaration_list"):
signature_end = child.start_byte
break
signature = source[class_node.start_byte:signature_end].decode("utf-8", errors="replace").strip()
content = f"{signature}\n # Methods: {', '.join(method_names)}"
chunk_id = self._generate_id(file_path, f"class_{class_name}", class_node.start_point[0])
metadata = {"is_summary": True, "method_count": len(methods)}
if parent_classes:
metadata["parent_classes"] = parent_classes
return CodeChunk(
id=chunk_id,
file_path=file_path,
language=language,
chunk_type=ChunkType.CLASS,
name=class_name,
content=content,
start_line=class_node.start_point[0] + 1,
end_line=class_node.end_point[0] + 1,
docstring=docstring,
metadata=metadata,
)
def _create_method_chunk(
self,
method_node: Node,
source: bytes,
language: Language,
file_path: str,
class_name: str,
) -> CodeChunk:
"""Create a chunk for a class method.
Args:
method_node: The method AST node.
source: Source bytes.
language: Programming language.
file_path: Relative file path.
class_name: Name of the parent class.
Returns:
A CodeChunk for the method.
"""
method_name = get_node_name(method_node, source, language) or "anonymous"
content = get_node_text(method_node, source)
docstring = get_docstring(method_node, source, language)
chunk_id = self._generate_id(file_path, f"{class_name}.{method_name}", method_node.start_point[0])
return CodeChunk(
id=chunk_id,
file_path=file_path,
language=language,
chunk_type=ChunkType.METHOD,
name=method_name,
content=content,
start_line=method_node.start_point[0] + 1,
end_line=method_node.end_point[0] + 1,
docstring=docstring,
parent_name=class_name,
)
def _create_function_chunk(
self,
func_node: Node,
source: bytes,
language: Language,
file_path: str,
) -> CodeChunk:
"""Create a chunk for a top-level function.
Args:
func_node: The function AST node.
source: Source bytes.
language: Programming language.
file_path: Relative file path.
Returns:
A CodeChunk for the function.
"""
func_name = get_node_name(func_node, source, language) or "anonymous"
content = get_node_text(func_node, source)
docstring = get_docstring(func_node, source, language)
chunk_id = self._generate_id(file_path, f"func_{func_name}", func_node.start_point[0])
return CodeChunk(
id=chunk_id,
file_path=file_path,
language=language,
chunk_type=ChunkType.FUNCTION,
name=func_name,
content=content,
start_line=func_node.start_point[0] + 1,
end_line=func_node.end_point[0] + 1,
docstring=docstring,
)
def _is_inside_class(self, node: Node, class_types: set[str]) -> bool:
"""Check if a node is inside a class definition.
Args:
node: The node to check.
class_types: Set of class node type names.
Returns:
True if the node is inside a class.
"""
parent = node.parent
while parent:
if parent.type in class_types:
return True
parent = parent.parent
return False
def _generate_id(self, file_path: str, name: str, line: int) -> str:
"""Generate a unique chunk ID.
Args:
file_path: File path.
name: Chunk name.
line: Line number.
Returns:
A unique ID string.
"""
key = f"{file_path}:{name}:{line}"
return hashlib.sha256(key.encode()).hexdigest()[:16]