# unit_test_generator.py
from __future__ import annotations
import ast
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
# Marker comment emitted as the very first line of every generated test module
# (see _build_pytest_module). Note: nothing in this file checks for it before
# overwriting an existing test file.
AUTO_HEADER = "# AUTO-GENERATED BY MCP TESTING AGENT. DO NOT EDIT MANUALLY.\n"
@dataclass
class UnitTestResult:
    """Outcome of a single ``generate_unit_tests_for_paths`` run.

    Fields are declared in positional-constructor order:
    ``UnitTestResult(generated, skipped, files)``.
    """

    # Number of test modules that were written to disk.
    generated: int
    # Number of candidate source files for which no test module was written.
    skipped: int
    # String paths of the test modules written during the run.
    files: List[str]
def generate_unit_tests_for_paths(
    project_root: Path,
    paths: List[Path],
    language: str = "python",
    overwrite: bool = False,
) -> UnitTestResult:
    """
    Generate pytest unit-test skeletons for the given paths.

    Args:
        project_root: Project root. Relative entries in ``paths`` are resolved
            against it, and test files are written under
            ``<project_root>/tests/unit``.
        paths: Files and/or directories to scan (``Path`` or ``str``).
            Directories are searched recursively for ``*.py`` files; files
            inside hidden directories below a scanned directory (``.venv``,
            ``.git``, ``.tox``, ...) are ignored.
        language: Source language. Only ``"python"`` (case-insensitive) is
            supported; any other value returns an empty result.
        overwrite: If False, a source whose target test file already exists is
            skipped. If True, the existing file is replaced. NOTE: the
            existing file is not checked for ``AUTO_HEADER``, so a
            hand-written file at the same path is replaced too.

    Returns:
        ``UnitTestResult`` with the number of generated and skipped source
        files and the string paths of the written test files.

    A source file is counted as skipped when it lies outside ``project_root``,
    looks like a test file, has no importable dotted module path, already has
    a test file (and ``overwrite`` is False), or exposes no public API.
    """
    project_root = Path(project_root).resolve()
    if language.strip().lower() != "python":
        # Phase 1: only Python supported
        return UnitTestResult(generated=0, skipped=0, files=[])

    generated = 0
    skipped = 0
    out_files: List[str] = []
    for src in sorted(_collect_python_sources(project_root, paths)):
        try:
            rel = src.relative_to(project_root)
        except ValueError:
            # Outside the project (absolute path elsewhere, or a symlink that
            # resolves out of the tree): there is no sensible test location or
            # import path, so skip it instead of aborting the whole run.
            skipped += 1
            continue
        if _looks_like_test_file(rel) or not _has_importable_module_path(rel):
            skipped += 1
            continue
        test_file = _target_test_file(project_root, rel)
        if test_file.exists() and not overwrite:
            skipped += 1
            continue
        functions, methods = _extract_public_api(src)
        if not functions and not methods:
            skipped += 1
            continue
        content = _build_pytest_module(rel, functions, methods)
        test_file.parent.mkdir(parents=True, exist_ok=True)
        test_file.write_text(content, encoding="utf-8")
        generated += 1
        out_files.append(str(test_file))
    return UnitTestResult(generated=generated, skipped=skipped, files=out_files)


def _collect_python_sources(project_root: Path, paths: List[Path]) -> Set[Path]:
    """Resolve ``paths`` (relative to ``project_root``) into resolved ``*.py`` files."""
    sources: Set[Path] = set()
    for raw in paths:
        p = Path(raw)  # accept plain strings as well as Path objects
        abs_p = (p if p.is_absolute() else project_root / p).resolve()
        if abs_p.is_dir():
            for f in abs_p.rglob("*.py"):
                # Ignore hidden directories *below* the scanned directory
                # (virtualenvs, VCS metadata, tool caches). The check is
                # relative to abs_p so a hidden ancestor of the project
                # itself does not exclude everything.
                inner_dirs = f.relative_to(abs_p).parts[:-1]
                if any(d.startswith(".") for d in inner_dirs):
                    continue
                if f.is_file():
                    sources.add(f.resolve())
        elif abs_p.suffix == ".py" and abs_p.is_file():
            sources.add(abs_p)
    return sources


def _looks_like_test_file(rel: Path) -> bool:
    """Return True for files under a ``tests`` dir or matching pytest test/config names."""
    name = rel.name
    return (
        "tests" in rel.parts
        or name.startswith("test_")
        or name.endswith("_test.py")  # pytest's other default python_files pattern
        or name == "conftest.py"
    )


def _has_importable_module_path(rel: Path) -> bool:
    """Return True if ``rel`` maps to a dotted path whose parts are valid, non-keyword identifiers."""
    import keyword

    # A generated ``import my-script`` (or ``import class``) would be a
    # syntax error in the emitted test module.
    dotted = _module_import_path_from_rel(rel)
    return all(
        part.isidentifier() and not keyword.iskeyword(part)
        for part in dotted.split(".")
    )
# ---------------------------------------------------------------------------
# Internals
# ---------------------------------------------------------------------------
def _target_test_file(project_root: Path, source_rel_path: Path) -> Path:
"""
Map src/module/foo.py -> tests/unit/test_module_foo.py
Map foo.py -> tests/unit/test_foo.py
"""
tests_unit = project_root / "tests" / "unit"
parts = list(source_rel_path.parts)
# Remove "src" prefix if present for nicer names
if parts and parts[0] == "src":
parts = parts[1:]
filename = "_".join(p.replace(".py", "") for p in parts)
if not filename:
filename = "module"
test_name = f"test_{filename}.py"
return tests_unit / test_name
def _extract_public_api(source_file: Path) -> Tuple[List[str], Dict[str, List[str]]]:
"""
Extract top-level function names and class->method mapping from a Python file.
Returns:
(functions, methods) where:
- functions: [func_name, ...]
- methods: {class_name: [method_name, ...], ...}
"""
text = source_file.read_text(encoding="utf-8")
try:
tree = ast.parse(text, filename=str(source_file))
except SyntaxError:
return [], {}
functions: List[str] = []
methods: Dict[str, List[str]] = {}
for node in tree.body:
if isinstance(node, ast.FunctionDef):
if not node.name.startswith("_"): # public-ish
functions.append(node.name)
elif isinstance(node, ast.ClassDef):
class_methods: List[str] = []
for sub in node.body:
if isinstance(sub, ast.FunctionDef) and not sub.name.startswith("_"):
class_methods.append(sub.name)
if class_methods:
methods[node.name] = class_methods
return functions, methods
def _build_pytest_module(
    source_rel_path: Path,
    functions: List[str],
    methods: Dict[str, List[str]],
) -> str:
    """
    Build a pytest test module, as source text, for one source file.

    The module starts with ``AUTO_HEADER``, imports the source module by its
    dotted path, and then contains:
      - one smoke test per function, calling it with no arguments;
      - one test per class, which instantiates it with no arguments and calls
        each listed method with no arguments.
    Inputs and assertions are TODO placeholders to be filled in by a human.

    Test names are unique within the module. For example, function ``foo``
    and class ``Foo`` both map to ``test_foo_basic``, and a duplicate ``def``
    would silently shadow the earlier test. A numeric suffix is appended on
    collision (``test_foo_basic_2``).
    """
    import_name = _module_import_path_from_rel(source_rel_path)
    used_names: Set[str] = set()

    def unique_test_name(stem: str) -> str:
        # First free name among test_<stem>_basic, test_<stem>_basic_2, ...
        candidate = f"test_{stem}_basic"
        suffix = 2
        while candidate in used_names:
            candidate = f"test_{stem}_basic_{suffix}"
            suffix += 1
        used_names.add(candidate)
        return candidate

    lines: List[str] = [AUTO_HEADER, "import pytest\n", f"import {import_name}\n"]

    # Simple function tests (each preceded by two blank lines, per PEP 8)
    for func in functions:
        lines.append(f"\n\ndef {unique_test_name(func)}():\n")
        lines.append("    # TODO: provide meaningful input values\n")
        lines.append(f"    result = {import_name}.{func}()\n")
        lines.append("    # TODO: assert expected result\n")
        lines.append("    assert result is not None\n")

    # Simple class/method tests
    for cls, method_names in methods.items():
        lines.append(f"\n\ndef {unique_test_name(cls.lower())}():\n")
        lines.append(f"    obj = {import_name}.{cls}()\n")
        for m in method_names:
            lines.append(f"    # TODO: provide meaningful input values for {cls}.{m}\n")
            lines.append(f"    result = obj.{m}()\n")
            lines.append("    assert result is not None\n")

    return "".join(lines)
def _module_import_path_from_rel(source_rel_path: Path) -> str:
"""
Convert a relative path into a Python import path.
Example:
src/app/services/user.py -> app.services.user
app/main.py -> app.main
main.py -> main
"""
parts = list(source_rel_path.with_suffix("").parts)
# Strip "src" prefix if present
if parts and parts[0] == "src":
parts = parts[1:]
return ".".join(parts)