"""Code chunking strategies for embedding generation."""
from typing import Any
from src.config import settings
from src.logger import get_logger
# Module-level logger obtained from the project's get_logger helper,
# keyed by this module's __name__.
logger = get_logger(__name__)
# Constants for chunking
# Upper bound on how many leading lines of a file CodeChunker scans (and may
# include) when it builds the module-level prolog chunk.
MAX_MODULE_DOCSTRING_SEARCH_LINES = 50
class CodeChunker:
"""Chunk code into logical units for embedding."""
def __init__(self) -> None:
self.chunk_size = settings.parser.chunk_size
self.max_tokens = settings.embeddings.max_tokens
def chunk_by_entity(
self,
entities: dict[str, list[dict[str, Any]]],
file_content: str,
) -> list[dict[str, Any]]:
"""Chunk code by logical entities (functions, classes, etc.)."""
chunks = []
lines = file_content.split("\n")
# Process functions
for func in entities.get("functions", []):
chunk = self._create_entity_chunk(
"function",
func,
lines,
include_context=True,
)
chunks.append(chunk)
# Process classes
for cls in entities.get("classes", []):
# Create chunk for entire class
chunk = self._create_entity_chunk("class", cls, lines, include_context=True)
chunks.append(chunk)
# Also create chunks for individual methods if class is large
if cls.get("end_line", 0) - cls.get("start_line", 0) > self.chunk_size:
for method in cls.get("methods", []):
method_chunk = self._create_entity_chunk(
"method",
method,
lines,
include_context=True,
parent_class=cls["name"],
)
chunks.append(method_chunk)
# Process module-level code
module_chunk = self._create_module_chunk(entities, lines)
if module_chunk:
chunks.append(module_chunk)
return chunks
def chunk_by_lines(
self,
file_content: str,
overlap: int = 20,
) -> list[dict[str, Any]]:
"""Chunk code by line count with overlap."""
lines = file_content.split("\n")
chunks = []
for i in range(0, len(lines), self.chunk_size - overlap):
start_line = i + 1
end_line = min(i + self.chunk_size, len(lines))
chunk_lines = lines[i:end_line]
chunk_content = "\n".join(chunk_lines)
chunks.append(
{
"type": "lines",
"content": chunk_content,
"start_line": start_line,
"end_line": end_line,
"metadata": {"line_count": len(chunk_lines), "has_overlap": i > 0},
},
)
return chunks
def _create_entity_chunk(
self,
entity_type: str,
entity: dict[str, Any],
lines: list[str],
*,
include_context: bool = False,
parent_class: str | None = None,
) -> dict[str, Any]:
"""Create a chunk for a code entity."""
start_line = entity.get("start_line", 1) - 1
end_line = entity.get("end_line", len(lines))
# Include context lines if requested
if include_context:
context_before = 3
# Include 2 trailing lines to capture the next significant line after blank spacing
context_after = 2
start_line = max(0, start_line - context_before)
end_line = min(len(lines), end_line + context_after)
chunk_lines = lines[start_line:end_line]
chunk_content = "\n".join(chunk_lines)
# Build metadata
metadata = {
"entity_name": entity.get("name", "unknown"),
"entity_type": entity_type,
"has_docstring": bool(entity.get("docstring")),
"line_count": len(chunk_lines),
}
if parent_class:
metadata["parent_class"] = parent_class
if entity_type in {"function", "method"}:
metadata["parameters"] = entity.get("parameters", [])
metadata["return_type"] = entity.get("return_type")
metadata["is_async"] = entity.get("is_async", False)
metadata["is_generator"] = entity.get("is_generator", False)
elif entity_type == "class":
metadata["base_classes"] = entity.get("base_classes", [])
metadata["method_count"] = len(entity.get("methods", []))
metadata["is_abstract"] = entity.get("is_abstract", False)
return {
"type": entity_type,
"content": chunk_content,
"start_line": start_line + 1,
"end_line": end_line,
"metadata": metadata,
}
def _create_module_chunk(
self,
entities: dict[str, list[dict[str, Any]]],
lines: list[str],
) -> dict[str, Any] | None:
"""Create a chunk for module-level code.
Includes leading module docstring block (if present), comments, and imports,
up to the first non-prolog line or MAX_MODULE_DOCSTRING_SEARCH_LINES.
"""
module_end = 0
in_docstring = False
doc_delim: str | None = None
for i, line in enumerate(lines):
if i >= MAX_MODULE_DOCSTRING_SEARCH_LINES - 1:
# Stop scanning too deep; cap at limit
module_end = MAX_MODULE_DOCSTRING_SEARCH_LINES
break
stripped = line.strip()
if not in_docstring:
# Start of a module docstring block
if stripped.startswith(('"""', "'''")):
in_docstring = True
doc_delim = '"""' if stripped.startswith('"""') else "'''"
module_end = i + 1
# Handle single-line docstring like """text"""
if stripped.count(doc_delim) >= 2:
in_docstring = False
doc_delim = None
continue
# Prolog lines we keep: blank, comments, imports
if stripped == "" or stripped.startswith(("#", "import ", "from ")):
module_end = i + 1
continue
# First non-prolog line reached
break
# Inside docstring block; keep lines until closing delimiter
module_end = i + 1
if doc_delim and doc_delim in stripped:
in_docstring = False
doc_delim = None
if module_end == 0:
return None
# Ensure the module chunk has substantive content (not just a lone docstring delimiter)
content_lines = lines[:module_end]
content = "\n".join(content_lines).strip()
if content in {'"""', "'''", ""}:
return None
chunk_content = "\n".join(content_lines)
return {
"type": "module",
"content": chunk_content,
"start_line": 1,
"end_line": module_end,
"metadata": {
"import_count": len(entities.get("imports", [])),
"class_count": len(entities.get("classes", [])),
"function_count": len(entities.get("functions", [])),
},
}
def merge_small_chunks(
self,
chunks: list[dict[str, Any]],
min_size: int = 10,
) -> list[dict[str, Any]]:
"""Merge small chunks to improve efficiency."""
merged = []
buffer = None
for chunk in chunks:
chunk_size = chunk["end_line"] - chunk["start_line"] + 1
if chunk_size < min_size and chunk["type"] in ("function", "method"):
buffer = chunk if buffer is None else self._merge_chunks(buffer, chunk)
else:
if buffer:
merged.append(buffer)
buffer = None
merged.append(chunk)
if buffer:
merged.append(buffer)
return merged
def _merge_chunks(
self,
chunk1: dict[str, Any],
chunk2: dict[str, Any],
) -> dict[str, Any]:
"""Merge two chunks."""
return {
"type": "merged",
"content": chunk1["content"] + "\n\n" + chunk2["content"],
"start_line": chunk1["start_line"],
"end_line": chunk2["end_line"],
"metadata": {
"merged_types": [chunk1["type"], chunk2["type"]],
"merged_entities": [
chunk1["metadata"].get("entity_name"),
chunk2["metadata"].get("entity_name"),
],
},
}