# unit_test_generator.py
from __future__ import annotations
import ast
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from .file_scanner import discover_python_sources
# Comment line emitted as the very first line of every generated test module
# (appended first in _build_pytest_module); flags the file as tool output.
AUTO_HEADER = "# AUTO-GENERATED BY MCP TESTING AGENT. DO NOT EDIT MANUALLY.\n"
@dataclass
class UnitTestResult:
    """Outcome of one generate_unit_tests_for_paths run.

    Attributes:
        generated: How many test modules were written to disk.
        skipped: How many candidate sources produced no test module.
        files: String paths of the written modules, in the order written.
    """

    generated: int
    skipped: int
    files: List[str]
def generate_unit_tests_for_paths(
    project_root: Path,
    paths: List[Path],
    language: str = "python",
    overwrite: bool = False,
) -> UnitTestResult:
    """Write skeleton pytest modules for the public API of project sources.

    Sources come from ``discover_python_sources(project_root)`` and are then
    restricted to those located at, or below, one of ``paths``. Each remaining
    non-test source with at least one public function or method gets a module
    under ``<project_root>/tests/unit``.

    Args:
        project_root: Project root directory; resolved before use.
        paths: Files or directories to generate tests for. Relative entries
            are interpreted relative to ``project_root``. An empty list (or
            None) selects every discovered source.
        language: Only ``"python"`` is supported; any other value returns an
            empty result without touching the filesystem.
        overwrite: When False, sources whose target test module already
            exists are skipped; when True, that module is rewritten.

    Returns:
        UnitTestResult with the number of modules written, the number of
        in-scope sources skipped (test files, sources outside
        ``project_root``, existing targets when not overwriting, sources with
        no public API), and the written paths as strings.
    """
    project_root = project_root.resolve()
    if language != "python":
        return UnitTestResult(0, 0, [])

    # ``paths`` used to be ignored, so every call regenerated tests for the
    # whole project. It is now honoured as a filter over the scanner output;
    # an empty/None value keeps the previous whole-project behaviour.
    scopes = [
        (Path(p) if Path(p).is_absolute() else project_root / p).resolve()
        for p in (paths or ())
    ]

    # The unified scanner walks the whole project; de-duplicate after resolving
    # and process in a deterministic order.
    sources = sorted({p.resolve() for p in discover_python_sources(project_root)})

    generated = 0
    skipped = 0
    out_files: List[str] = []
    for src in sources:
        if scopes and not any(src == scope or scope in src.parents for scope in scopes):
            continue  # not requested by the caller; not counted as skipped

        try:
            rel = src.relative_to(project_root)
        except ValueError:
            # resolve() can follow a symlink out of the project; such a file has
            # no project-relative location to mirror under tests/unit.
            skipped += 1
            continue

        # Ignore existing tests: anything under a "tests" directory, pytest's
        # default file patterns (test_*.py, *_test.py) and conftest.py.
        name = rel.name
        if (
            "tests" in rel.parts
            or name.startswith("test_")
            or name.endswith("_test.py")
            or name == "conftest.py"
        ):
            skipped += 1
            continue

        test_file = _target_test_file(project_root, rel)
        if test_file.exists() and not overwrite:
            skipped += 1
            continue

        functions, methods = _extract_public_api(src)
        if not functions and not methods:
            skipped += 1
            continue

        content = _build_pytest_module(rel, functions, methods)
        test_file.parent.mkdir(parents=True, exist_ok=True)
        test_file.write_text(content, encoding="utf-8")
        generated += 1
        out_files.append(str(test_file))

    return UnitTestResult(generated, skipped, out_files)
# -----------------------------------------
# Helper functions
# -----------------------------------------
def _target_test_file(project_root: Path, source_rel_path: Path) -> Path:
    """Map a project-relative source path to its generated test module path.

    A leading ``src`` component is dropped, the ``.py`` extension is removed
    from the file name, and the remaining components are joined with
    underscores, e.g. ``src/pkg/mod.py`` ->
    ``<project_root>/tests/unit/test_pkg_mod.py``. Any remaining dots become
    underscores so the result is a valid Python module name. Falls back to
    ``test_module.py`` when no components are left.
    """
    parts = list(source_rel_path.parts)
    if parts and parts[0] == "src":
        parts = parts[1:]
    if parts and parts[-1].endswith(".py"):
        # Strip only the real extension; the former str.replace(".py", "")
        # also rewrote names that merely contain ".py" (e.g. "my.pyutils").
        parts[-1] = parts[-1][: -len(".py")]
    # Dots elsewhere (e.g. a directory "a.b") would yield "test_a.b_c.py",
    # which is not a valid module name for pytest to import.
    filename = "_".join(parts).replace(".", "_") or "module"
    return project_root / "tests" / "unit" / f"test_{filename}.py"
def _extract_public_api(source_file: Path) -> Tuple[List[str], Dict[str, List[str]]]:
    """Collect the public functions and public class methods of a module.

    "Public" means the name does not start with an underscore. Only direct
    children of the module body, and direct children of public classes, are
    inspected; nested definitions are ignored. Both ``def`` and ``async def``
    are recognised.

    Args:
        source_file: Path of the Python source to inspect.

    Returns:
        ``(functions, methods)`` in source order, where ``methods`` maps each
        public class name to its public method names (classes without any
        public method are omitted). ``([], {})`` if the file cannot be parsed.

    Raises:
        OSError: If the file cannot be read.
    """
    func_types = (ast.FunctionDef, ast.AsyncFunctionDef)

    # Parse the raw bytes so that PEP 263 coding cookies and a UTF-8 BOM are
    # honoured instead of assuming UTF-8 text.
    data = source_file.read_bytes()
    try:
        tree = ast.parse(data)
    except (SyntaxError, ValueError):
        # SyntaxError: invalid code or undecodable bytes; ValueError: null
        # bytes on older Pythons (UnicodeDecodeError is also a ValueError).
        return [], {}

    functions: List[str] = []
    methods: Dict[str, List[str]] = {}
    for node in tree.body:
        if isinstance(node, func_types) and not node.name.startswith("_"):
            functions.append(node.name)
        elif isinstance(node, ast.ClassDef) and not node.name.startswith("_"):
            public = [
                sub.name
                for sub in node.body
                if isinstance(sub, func_types) and not sub.name.startswith("_")
            ]
            if public:
                methods[node.name] = public
    return functions, methods
def _build_pytest_module(
    source_rel_path: Path,
    functions: List[str],
    methods: Dict[str, List[str]],
) -> str:
    """Render the text of a skeleton pytest module for one source file.

    The module imports the target module, so a broken import surfaces when
    pytest collects the file. It then defines one skipped placeholder test per
    public function and one per public method.

    Args:
        source_rel_path: Source path relative to the project root.
        functions: Public top-level function names.
        methods: Public method names keyed by class name.

    Returns:
        The complete module text, beginning with ``AUTO_HEADER``.
    """
    import keyword

    import_name = _module_import_path_from_rel(source_rel_path)
    importable = all(
        part.isidentifier() and not keyword.iskeyword(part)
        for part in import_name.split(".")
    )

    lines: List[str] = [AUTO_HEADER]
    if importable:
        lines.append("import pytest\n")
        lines.append(f"import {import_name}\n")
    else:
        # e.g. "scripts/run-me.py": "import scripts.run-me" would make the
        # generated file itself a SyntaxError, so import it by string instead.
        lines.append("import importlib\n\nimport pytest\n\n")
        lines.append(f"importlib.import_module({import_name!r})\n")

    # One placeholder per function and per method; previously the method names
    # were collected but only a single test per class was emitted.
    candidates = [f"test_{func}_basic" for func in functions]
    candidates.extend(
        f"test_{cls.lower()}_{meth}_basic"
        for cls, mlist in methods.items()
        for meth in mlist
    )

    skip_line = "    pytest.skip('Autogenerated test — parameters required.')\n"
    used = set()
    for base in candidates:
        name = base
        suffix = 2
        while name in used:
            # Distinct API members can map to one name (function "foo_bar" vs
            # method Foo.bar, or classes differing only in case). A duplicate
            # def would silently replace the earlier test, so disambiguate.
            name = f"{base}_{suffix}"
            suffix += 1
        used.add(name)
        lines.append(f"\n\ndef {name}():\n")
        lines.append(skip_line)
    return "".join(lines)
def _module_import_path_from_rel(source_rel_path: Path) -> str:
    """Return the dotted import path for a project-relative source file.

    The file extension is removed and a leading ``src`` component is dropped
    (src layout). A trailing ``__init__`` is collapsed onto its package, so
    ``src/pkg/__init__.py`` becomes ``pkg`` rather than ``pkg.__init__``.
    """
    parts = list(source_rel_path.with_suffix("").parts)
    if parts and parts[0] == "src":
        parts = parts[1:]
    # "import pkg.__init__" would execute the package's __init__.py a second
    # time as a separate module object; import the package itself instead.
    # A bare top-level "__init__" is kept so the result is never empty.
    if len(parts) > 1 and parts[-1] == "__init__":
        parts.pop()
    return ".".join(parts)