markdown.py•5.76 kB
from dataclasses import dataclass
from importlib import resources
from typing import Any, Optional, cast
from tree_sitter import Node
from llm_context.excerpters.base import Excerpt, Excerpter, Excerpts, Excluded
from llm_context.excerpters.language_mapping import to_language
from llm_context.excerpters.parser import ASTFactory, Source
@dataclass(frozen=True)
class ContentRange:
start_line: int
end_line: int
node_type: str
@dataclass(frozen=True)
class Markdown(Excerpter):
config: dict[str, Any]
def excerpt(self, sources: list[Source]) -> Excerpts:
if not sources:
return Excerpts([], {})
excerpts = [
self._excerpt_source(source)
for source in sources
if to_language(source.rel_path) == "markdown"
]
return Excerpts(excerpts, {"sample_definitions": []})
def excluded(self, sources: list[Source]) -> list[Excluded]:
if not sources:
return []
results = []
for source in sources:
if to_language(source.rel_path) != "markdown":
continue
excluded_content = self._collect_excluded(source)
if excluded_content:
results.append(
Excluded(
{"omitted_content": excluded_content},
{"file": source.rel_path},
)
)
return results
def _excerpt_source(self, source: Source) -> Excerpt:
ast = ASTFactory.create().create_from_code(source)
included_ranges = self._get_included_ranges(ast, source)
content = self._format_content(source.content, included_ranges)
return Excerpt(source.rel_path, content, self._metadata())
def _get_included_ranges(self, ast: Any, source: Source) -> set[int]:
query_str = self._get_query()
matches = ast.match(query_str)
included_lines: set[int] = set()
for _, captures in matches:
for capture_name, nodes in captures.items():
node_type = self._map_capture_to_type(capture_name)
if self._should_include(node_type):
for node in nodes:
for line_num in range(node.start_point[0], node.end_point[0] + 1):
included_lines.add(line_num)
return included_lines
def _format_content(self, content: str, included_line_set: set[int]) -> str:
lines = content.splitlines()
result: list[str] = []
last_included_index = -2
for i, line in enumerate(lines):
is_included = i in included_line_set
if is_included:
if i - last_included_index > 2 and result:
result.append("⋮...")
result.append(line)
last_included_index = i
if result and last_included_index < len(lines) - 2:
result.append("⋮...")
return "\n".join(result)
def _collect_excluded(self, source: Source) -> str:
ast = ASTFactory.create().create_from_code(source)
matches = ast.match(self._get_query())
included_ranges: set[tuple[int, int]] = set()
for _, captures in matches.items():
for capture_name, nodes in captures.items():
node_type = self._map_capture_to_type(capture_name)
if self._should_include(node_type):
for node in nodes:
included_ranges.add((node.start_point[0], node.end_point[0]))
excluded_paras = []
for _, captures in matches.items():
if "content.paragraph" in captures:
for node in captures["content.paragraph"]:
node_range = (node.start_point[0], node.end_point[0])
if node_range not in included_ranges:
text = self._node_text(node, source.content)
if text.strip():
excluded_paras.append(text)
return "\n\n".join(excluded_paras[:3]) if excluded_paras else ""
def _map_capture_to_type(self, capture_name: str) -> str:
mapping = {
"heading.atx": "heading",
"heading.setext": "heading",
"code.fenced": "code_block",
"code.indented": "code_block",
"list.item": "list_item",
"table.pipe": "table",
"quote.block": "blockquote",
"break.thematic": "thematic_break",
"content.paragraph": "paragraph",
}
return mapping.get(capture_name, "")
def _should_include(self, node_type: str) -> bool:
config_map = {
"heading": True,
"code_block": cast(bool, self.config.get("with-code-blocks", True)),
"list_item": cast(bool, self.config.get("with-lists", True)),
"table": cast(bool, self.config.get("with-tables", True)),
"blockquote": cast(bool, self.config.get("with-blockquotes", True)),
"thematic_break": cast(bool, self.config.get("with-thematic-breaks", True)),
"paragraph": False,
}
return config_map.get(node_type, False)
def _get_query(self) -> str:
return resources.files("llm_context.excerpters.ts-qry").joinpath("markdown.scm").read_text()
def _node_text(self, node: Node, content: str) -> str:
return content[node.start_byte : node.end_byte]
def _metadata(self) -> dict[str, Any]:
included = [
k.replace("with-", "") for k, v in self.config.items() if isinstance(v, bool) and v
]
return {
"processor_type": "markdown",
"included_elements": included,
}