"""Test file discovery and tree-sitter AST helpers for test analysis."""
from __future__ import annotations
from pathlib import Path
from tree_sitter import Node
from local_deepwiki.logging import get_logger
logger = get_logger(__name__)
def find_test_files(source_file: Path, repo_root: Path) -> list[Path]:
    """Find all corresponding test files for a source file.

    Looks in ``<repo_root>/tests`` and then ``<repo_root>/test``. Within each
    directory, existing files are collected in this order:

    1. Direct match: ``test_<stem>.py``
    2. Coverage tests: ``test_<stem>_coverage.py``
    3. Suffix variants: ``test_<stem>_*.py`` (sorted by path)
    4. Alternative naming: ``<stem>_test.py``

    Args:
        source_file: Path to the source file. Only its stem is used.
        repo_root: Root directory of the repository.

    Returns:
        List of test file paths found, without duplicates (may be empty).
        An empty list is returned when the source file's own name starts
        with ``test_``.
    """
    base_name = source_file.stem  # e.g., "api_docs"

    # Skip test files themselves
    if base_name.startswith("test_"):
        return []

    test_files: list[Path] = []

    for test_dir in (repo_root / "tests", repo_root / "test"):
        if not test_dir.is_dir():
            continue

        candidates = [
            test_dir / f"test_{base_name}.py",
            test_dir / f"test_{base_name}_coverage.py",
            # Directory iteration order is filesystem-dependent; sorting
            # keeps the result (and therefore the "first" test file)
            # reproducible across platforms and runs.
            *sorted(test_dir.glob(f"test_{base_name}_*.py")),
            test_dir / f"{base_name}_test.py",
        ]
        for candidate in candidates:
            # The glob also matches the coverage file, hence the dedup check.
            if candidate not in test_files and candidate.exists():
                test_files.append(candidate)

    if test_files:
        logger.debug("Found %s test file(s) for %s", len(test_files), source_file.name)

    return test_files
def find_test_file(source_file: Path, repo_root: Path) -> Path | None:
    """Return a single test file for a source file, if any exists.

    Kept for backwards compatibility; delegates to ``find_test_files`` and
    hands back only its first result.

    Args:
        source_file: Path to the source file.
        repo_root: Root directory of the repository.

    Returns:
        The first path reported by ``find_test_files``, or None when that
        list is empty.
    """
    matches = find_test_files(source_file, repo_root)
    if not matches:
        return None
    return matches[0]
def _get_node_text(node: Node, source: bytes) -> str:
"""Get the text content of a tree-sitter node."""
return source[node.start_byte : node.end_byte].decode("utf-8")
def _find_test_functions(root: Node) -> list[tuple[Node, str | None]]:
"""Find all test function definitions in the AST.
Finds both standalone test functions and test methods in test classes.
Args:
root: Root node of the parsed test file.
Returns:
List of (function_definition_node, class_name) tuples.
class_name is None for standalone functions.
"""
test_functions: list[tuple[Node, str | None]] = []
def walk(node: Node, current_class: str | None = None) -> None:
if node.type == "class_definition":
# Get class name
name_node = node.child_by_field_name("name")
if name_node:
class_name = name_node.text.decode("utf-8") if name_node.text else ""
# Check if it's a test class
if class_name.startswith("Test"):
# Walk children with this class context
for child in node.children:
walk(child, class_name)
return
if node.type == "function_definition":
# Get the function name
name_node = node.child_by_field_name("name")
if name_node:
name = name_node.text.decode("utf-8") if name_node.text else ""
if name.startswith("test_"):
test_functions.append((node, current_class))
for child in node.children:
walk(child, current_class)
walk(root)
return test_functions
def _get_function_name(func_node: Node, source: bytes) -> str:
"""Get the name of a function from its AST node."""
name_node = func_node.child_by_field_name("name")
if name_node:
return _get_node_text(name_node, source)
return "unknown"
def _get_docstring(func_node: Node, source: bytes) -> str | None:
"""Extract docstring from a function node if present."""
body = func_node.child_by_field_name("body")
if not body or not body.children:
return None
# First statement in body might be a docstring
first_stmt = body.children[0]
if first_stmt.type == "expression_statement":
expr = first_stmt.children[0] if first_stmt.children else None
if expr and expr.type == "string":
docstring = _get_node_text(expr, source)
# Clean up the docstring
docstring = docstring.strip("\"'")
if docstring.startswith('""'):
docstring = (
docstring[2:-2] if docstring.endswith('""') else docstring[2:]
)
return docstring.strip()
return None
def _get_function_body(func_node: Node, source: bytes) -> str:
"""Get the body of a function as a string."""
body = func_node.child_by_field_name("body")
if body:
return _get_node_text(body, source)
return ""
def _is_mock_heavy(body: str) -> bool:
"""Check if a test body uses mocking extensively.
We want to exclude heavily mocked tests as they don't show
real usage patterns.
"""
mock_indicators = [
"MagicMock",
"AsyncMock",
"@patch",
"patch(",
"mock_",
"mocker.",
]
mock_count = sum(1 for indicator in mock_indicators if indicator in body)
return mock_count >= 2