MCPunk

  • mcpunk
import ast import logging from functools import lru_cache from typing import Annotated, assert_never import asttokens import more_itertools from asttokens.util import Token from asttokens.util import walk as asttokens_walk from pydantic import BaseModel, ConfigDict, Field logger = logging.getLogger(__name__) @lru_cache(maxsize=1024) def _ast_cache(src: str) -> asttokens.ASTTokens: """Source code to asttokens AST with caching. Note that the cache is not particularly large, so to make effective use of it you probably want to things like `for file: for analysis: ast = _ast_cache(...)` rather than `for analysis: for file: ast = _ast_cache(...)` """ return asttokens.ASTTokens(src, parse=True) class Callable(BaseModel): """Represents a callable (function/class) definition in the source code.""" name: Annotated[str, Field(description="Name of the function or class")] # line/col/offset are for where `my_func` begins in `def my_func(...):` or # `MyClass` in `class MyClass(...):` line: Annotated[int, Field(description="First line is 1")] col: Annotated[int, Field(description="First column is 0")] offset: Annotated[int, Field(description="char offset corresponding to line & col.")] code_offset_start: Annotated[ int, Field(description="Character offset of the start of the callable"), ] code_offset_end: Annotated[ int, Field(description="Character offset of the end of the callable"), ] # This is the whole code block. For a function it will include decorators, # for example. Does NOT line up with line/col/offset. code: Annotated[str, Field(description="Complete source code of the callable")] model_config = ConfigDict(validate_assignment=True) @classmethod def from_source_code( cls, source_code: str, ) -> list["Callable"]: """Extract all callables from the given source code.""" try: atok = _ast_cache(source_code) except SyntaxError: logger.error( # noqa: TRY400 f"Skipping {source_code} because it's not valid Python syntax", ) return [] assert atok.tree is not None nodes: list[ast.FunctionDef | ast.ClassDef | ast.AsyncFunctionDef] = [ n for n in asttokens_walk( atok.tree, # Joined strings seem to relate to f-strings, and enabling # them seems to be problematic. Let's just skip them. include_joined_str=False, ) if isinstance(n, ast.FunctionDef | ast.ClassDef | ast.AsyncFunctionDef) ] callables = [] for node in nodes: # print(ast.dump(node, indent=4)) # print(atok.get_text(node)) callable_name = node.name prefixes: list[str] if isinstance(node, ast.FunctionDef): prefixes = ["def"] elif isinstance(node, ast.AsyncFunctionDef): prefixes = ["async", "def"] elif isinstance(node, ast.ClassDef): prefixes = ["class"] else: assert_never(node) all_tokens: list[Token] = list(atok.get_tokens(node)) # Make sure we sort by position in the file all_tokens = sorted(all_tokens, key=lambda x: x.startpos) callable_start_tok: Token | None = None window_size = len(prefixes) + 1 looking_for = [*prefixes, callable_name] for window_elements in more_itertools.sliding_window(all_tokens, window_size): # This is a bit gross assert len(window_elements) == window_size if looking_for == [x.string for x in window_elements]: callable_start_tok = window_elements[-1] break if callable_start_tok is None: raise AssertionError( f"Could not find callable {callable_name} in {node}\n" f"{atok.get_text(node)}\n" f"{ast.dump(node, indent=4)}\n", ) range_ = atok.get_text_range(node) callable_ = Callable( name=node.name, line=callable_start_tok.start[0], col=callable_start_tok.start[1], offset=callable_start_tok.startpos, code=atok.get_text(node), code_offset_start=range_[0], code_offset_end=range_[1], ) callables.append(callable_) return callables def extract_imports(source_code: str) -> list[str]: """Extract all module-level import statements from source code. Takes source code as input and returns a list of import statements. """ atok = _ast_cache(source_code) imports: list[str] = [] for node in atok.tree.body: # type: ignore[union-attr] if isinstance(node, ast.Import | ast.ImportFrom): import_code = atok.get_text(node) imports.append(import_code) return imports def extract_module_statements(source_code: str) -> list[str]: """Extract all module-level statements from source code. function/class definitions are inserted like `def func1...` or `class MyClass...` to provide context around where the module-level statements are defined. Takes source code as input and returns a list of statement strings. """ # TODO: include comments atok = _ast_cache(source_code) statements: list[str] = [] for node in atok.tree.body: # type: ignore[union-attr] if isinstance(node, ast.FunctionDef): statements.append(f"def {node.name}...") elif isinstance(node, ast.AsyncFunctionDef): statements.append(f"async def {node.name}...") elif isinstance(node, ast.ClassDef): statements.append(f"class {node.name}...") elif isinstance(node, ast.Import | ast.ImportFrom): pass else: statement_code = atok.get_text(node) statements.append(statement_code) return statements