"""Tests for ResourceLimits and input size validation (Phase 3).
Tests CWE-400 prevention: resource consumption limits to prevent DoS attacks.
"""
from pathlib import Path
import pytest
from local_deepwiki.validation import (
VALID_RESEARCH_PRESETS,
ResourceLimits,
validate_deep_research_parameters,
validate_index_parameters,
validate_query_parameters,
)
class TestResourceLimitsConstants:
    """Pin down the presence, types, and exact values of ResourceLimits constants."""

    @staticmethod
    def _limit(name):
        """Return ``ResourceLimits.<name>`` after asserting the attribute exists."""
        assert hasattr(ResourceLimits, name)
        return getattr(ResourceLimits, name)

    def test_max_query_length_exists(self):
        """MAX_QUERY_LENGTH is defined and is a positive int."""
        limit = self._limit("MAX_QUERY_LENGTH")
        assert isinstance(limit, int)
        assert limit > 0

    def test_max_query_length_value(self):
        """MAX_QUERY_LENGTH equals 5000."""
        assert self._limit("MAX_QUERY_LENGTH") == 5000

    def test_max_question_length_exists(self):
        """MAX_QUESTION_LENGTH is defined and is a positive int."""
        limit = self._limit("MAX_QUESTION_LENGTH")
        assert isinstance(limit, int)
        assert limit > 0

    def test_max_question_length_value(self):
        """MAX_QUESTION_LENGTH equals 2000."""
        assert self._limit("MAX_QUESTION_LENGTH") == 2000

    def test_max_repo_size_exists(self):
        """MAX_REPO_SIZE is defined as an int."""
        assert isinstance(self._limit("MAX_REPO_SIZE"), int)

    def test_max_repo_size_value(self):
        """MAX_REPO_SIZE is 1GB (1_000_000_000)."""
        assert self._limit("MAX_REPO_SIZE") == 1_000_000_000

    def test_max_files_per_repo_exists(self):
        """MAX_FILES_PER_REPO is defined as an int."""
        assert isinstance(self._limit("MAX_FILES_PER_REPO"), int)

    def test_max_files_per_repo_value(self):
        """MAX_FILES_PER_REPO equals 50000."""
        assert self._limit("MAX_FILES_PER_REPO") == 50_000

    def test_max_file_size_exists(self):
        """MAX_FILE_SIZE is defined as an int."""
        assert isinstance(self._limit("MAX_FILE_SIZE"), int)

    def test_max_file_size_value(self):
        """MAX_FILE_SIZE is 50MB (50_000_000)."""
        assert self._limit("MAX_FILE_SIZE") == 50_000_000

    def test_max_context_chunks_exists(self):
        """MAX_CONTEXT_CHUNKS is defined as an int."""
        assert isinstance(self._limit("MAX_CONTEXT_CHUNKS"), int)

    def test_max_context_chunks_value(self):
        """MAX_CONTEXT_CHUNKS equals 500."""
        assert self._limit("MAX_CONTEXT_CHUNKS") == 500

    def test_max_sub_questions_exists(self):
        """MAX_SUB_QUESTIONS is defined and equals 20."""
        assert self._limit("MAX_SUB_QUESTIONS") == 20

    def test_max_research_depth_exists(self):
        """MAX_RESEARCH_DEPTH is defined and equals 5."""
        assert self._limit("MAX_RESEARCH_DEPTH") == 5

    def test_max_pdf_pages_exists(self):
        """MAX_PDF_PAGES is defined and equals 10000."""
        assert self._limit("MAX_PDF_PAGES") == 10_000

    def test_max_html_size_exists(self):
        """MAX_HTML_SIZE is defined and equals 100_000_000."""
        assert self._limit("MAX_HTML_SIZE") == 100_000_000
class TestValidateQueryParameters:
    """Tests for validate_query_parameters function."""

    @pytest.fixture
    def repo_dir(self, tmp_path):
        """Provide an existing, empty repository directory inside tmp_path."""
        path = tmp_path / "test_repo"
        path.mkdir()
        return path

    def test_valid_inputs(self, repo_dir):
        """Test valid inputs pass validation without error."""
        # This should not raise
        validate_query_parameters(
            query="What is this codebase about?",
            repo_path=str(repo_dir),
            max_results=10,
        )

    def test_empty_query_raises(self, repo_dir):
        """Test empty query raises ValueError."""
        with pytest.raises(ValueError, match="Query cannot be empty"):
            validate_query_parameters(
                query="",
                repo_path=str(repo_dir),
                max_results=10,
            )

    def test_query_too_long_raises(self, repo_dir):
        """Test query exceeding max length raises ValueError."""
        long_query = "a" * (ResourceLimits.MAX_QUERY_LENGTH + 1)
        with pytest.raises(ValueError, match="exceeds maximum length"):
            validate_query_parameters(
                query=long_query,
                repo_path=str(repo_dir),
                max_results=10,
            )

    def test_query_at_max_length_passes(self, repo_dir):
        """Test query at exactly max length passes."""
        max_query = "a" * ResourceLimits.MAX_QUERY_LENGTH
        # Should not raise
        validate_query_parameters(
            query=max_query,
            repo_path=str(repo_dir),
            max_results=10,
        )

    def test_invalid_path_raises(self, tmp_path):
        """Test non-existent path raises ValueError."""
        # Build the missing path inside the per-test sandbox so the test never
        # depends on what happens to exist on the host filesystem.
        missing_repo = tmp_path / "nonexistent" / "path" / "to" / "repo"
        with pytest.raises(ValueError, match="does not exist"):
            validate_query_parameters(
                query="test query",
                repo_path=str(missing_repo),
                max_results=10,
            )

    def test_path_not_directory_raises(self, tmp_path):
        """Test path pointing to file raises ValueError."""
        # Create a file, not directory
        test_file = tmp_path / "not_a_directory.txt"
        test_file.write_text("content")
        with pytest.raises(ValueError, match="is not a directory"):
            validate_query_parameters(
                query="test query",
                repo_path=str(test_file),
                max_results=10,
            )

    def test_max_results_too_low_raises(self, repo_dir):
        """Test max_results below 1 raises ValueError."""
        with pytest.raises(ValueError, match="max_results must be between"):
            validate_query_parameters(
                query="test query",
                repo_path=str(repo_dir),
                max_results=0,
            )

    def test_max_results_negative_raises(self, repo_dir):
        """Test negative max_results raises ValueError."""
        with pytest.raises(ValueError, match="max_results must be between"):
            validate_query_parameters(
                query="test query",
                repo_path=str(repo_dir),
                max_results=-5,
            )

    def test_max_results_too_high_raises(self, repo_dir):
        """Test max_results above limit raises ValueError."""
        with pytest.raises(ValueError, match="max_results must be between"):
            validate_query_parameters(
                query="test query",
                repo_path=str(repo_dir),
                max_results=ResourceLimits.MAX_CONTEXT_CHUNKS + 1,
            )

    def test_max_results_at_limit_passes(self, repo_dir):
        """Test max_results at exactly limit passes."""
        # Should not raise
        validate_query_parameters(
            query="test query",
            repo_path=str(repo_dir),
            max_results=ResourceLimits.MAX_CONTEXT_CHUNKS,
        )
class TestValidateIndexParameters:
    """Tests for validate_index_parameters function.

    Oversized repositories are simulated by patching ``Path.stat`` instead of
    writing real multi-megabyte files, so these tests stay fast and use almost
    no disk space.
    """

    @staticmethod
    def _regular_file_stat(size):
        """Return a genuine ``os.stat_result`` for a regular file of *size* bytes.

        A real stat_result (rather than a MagicMock) keeps ``st_mode`` an int,
        so mode checks such as ``stat.S_ISREG`` / ``S_ISDIR`` / ``S_ISLNK``
        evaluate as they would for an actual regular file.
        """
        import os
        import stat

        # The first ten fields of os.stat_result are, in order: st_mode,
        # st_ino, st_dev, st_nlink, st_uid, st_gid, st_size, st_atime,
        # st_mtime, st_ctime.
        return os.stat_result(
            (stat.S_IFREG | 0o644, 0, 0, 1, 0, 0, size, 0, 0, 0)
        )

    @classmethod
    def _make_fake_stat(cls, should_fake, size):
        """Build a drop-in replacement for ``Path.stat``.

        Args:
            should_fake: Predicate taking a ``Path``; when it returns True the
                fake reports a regular file of *size* bytes.
            size: The ``st_size`` reported for matching paths.

        Returns:
            A function suitable for ``patch.object(Path, "stat", ...)``.
            Non-matching paths are delegated to the real ``Path.stat``.
        """
        # Capture the real method now, before any patch replaces it.
        original_stat = Path.stat
        fake_result = cls._regular_file_stat(size)

        def fake_stat(self_path, *args, **kwargs):
            # Accept and forward extra arguments: Path.stat takes a keyword-only
            # ``follow_symlinks`` on modern Pythons, and Path.lstat calls
            # ``self.stat(follow_symlinks=False)``.
            if should_fake(self_path):
                return fake_result
            return original_stat(self_path, *args, **kwargs)

        return fake_stat

    def test_valid_small_repo(self, tmp_path):
        """Test valid small repo returns the exact file count and byte total."""
        repo_dir = tmp_path / "test_repo"
        repo_dir.mkdir()
        (repo_dir / "subdir").mkdir()
        # Create some small files (one of them nested)
        contents = {
            "file1.py": "print('hello')",
            "file2.py": "x = 1",
            "subdir/file3.py": "y = 2",
        }
        for rel_path, text in contents.items():
            (repo_dir / rel_path).write_text(text)
        total_size, file_count = validate_index_parameters(str(repo_dir))
        assert file_count == 3
        # Contents are ASCII with no newlines, so on-disk bytes equal the
        # character count on every platform.
        assert total_size == sum(len(text) for text in contents.values())

    def test_empty_repo_returns_zero(self, tmp_path):
        """Test empty repo returns zeros."""
        repo_dir = tmp_path / "test_repo"
        repo_dir.mkdir()
        total_size, file_count = validate_index_parameters(str(repo_dir))
        assert file_count == 0
        assert total_size == 0

    def test_repo_too_large_raises(self, tmp_path):
        """Test repo exceeding size limit raises ValueError."""
        from unittest.mock import patch

        repo_dir = tmp_path / "test_repo"
        repo_dir.mkdir()
        # Each faked file stays below MAX_FILE_SIZE (so the per-file check does
        # not fire first), and there are just enough of them for the combined
        # size to exceed MAX_REPO_SIZE.
        file_size = ResourceLimits.MAX_FILE_SIZE * 9 // 10
        file_count = ResourceLimits.MAX_REPO_SIZE // file_size + 1
        names = {f"file{i}.py" for i in range(file_count)}
        for name in names:
            (repo_dir / name).write_text("x")

        def is_fake_file(path):
            # Match on names rather than full paths so the fake still applies
            # if the validator resolves symlinked temp dirs.
            return path.name in names and path.parent.name == repo_dir.name

        fake_stat = self._make_fake_stat(is_fake_file, file_size)
        with patch.object(Path, "stat", fake_stat):
            with pytest.raises(ValueError, match="exceeds maximum size"):
                validate_index_parameters(str(repo_dir))

    def test_too_many_files_raises(self, tmp_path):
        """Test repo with too many files raises ValueError."""
        from unittest.mock import patch

        repo_dir = tmp_path / "test_repo"
        repo_dir.mkdir()
        # One file more than allowed; these files exist only in the faked walk.
        fake_filenames = [
            f"f{i}.py" for i in range(ResourceLimits.MAX_FILES_PER_REPO + 1)
        ]
        fake_name_set = set(fake_filenames)
        fake_walk_result = [(str(repo_dir), [], fake_filenames)]

        def fresh_walk(*args, **kwargs):
            # Hand every os.walk call its own iterator; a shared return_value
            # iterator would be exhausted after the first call.
            return iter(fake_walk_result)

        fake_stat = self._make_fake_stat(lambda p: p.name in fake_name_set, 10)
        with (
            patch("os.walk", side_effect=fresh_walk),
            patch.object(Path, "stat", fake_stat),
        ):
            with pytest.raises(ValueError, match="exceeds maximum file count"):
                validate_index_parameters(str(repo_dir))

    def test_file_too_large_raises(self, tmp_path):
        """Test single file exceeding limit raises ValueError."""
        from unittest.mock import patch

        repo_dir = tmp_path / "test_repo"
        repo_dir.mkdir()
        (repo_dir / "huge.bin").write_text("x")
        fake_stat = self._make_fake_stat(
            lambda p: p.name == "huge.bin", ResourceLimits.MAX_FILE_SIZE + 1
        )
        with patch.object(Path, "stat", fake_stat):
            with pytest.raises(ValueError, match="File too large"):
                validate_index_parameters(str(repo_dir))

    def test_returns_correct_tuple_type(self, tmp_path):
        """Test return type is tuple of (int, int)."""
        repo_dir = tmp_path / "test_repo"
        repo_dir.mkdir()
        (repo_dir / "test.py").write_text("pass")
        result = validate_index_parameters(str(repo_dir))
        assert isinstance(result, tuple)
        assert len(result) == 2
        assert isinstance(result[0], int)  # total_size
        assert isinstance(result[1], int)  # file_count

    def test_counts_nested_files(self, tmp_path):
        """Test that nested files are counted."""
        repo_dir = tmp_path / "test_repo"
        repo_dir.mkdir()
        (repo_dir / "level1").mkdir()
        (repo_dir / "level1" / "level2").mkdir()
        (repo_dir / "level1" / "level2" / "deep.py").write_text("deep")
        (repo_dir / "root.py").write_text("root")
        _total_size, file_count = validate_index_parameters(str(repo_dir))
        assert file_count == 2

    def test_only_counts_files_not_directories(self, tmp_path):
        """Test that only files are counted, not directories."""
        repo_dir = tmp_path / "test_repo"
        repo_dir.mkdir()
        (repo_dir / "dir1").mkdir()
        (repo_dir / "dir2").mkdir()
        (repo_dir / "file.py").write_text("x")
        _total_size, file_count = validate_index_parameters(str(repo_dir))
        assert file_count == 1
class TestValidateDeepResearchParameters:
    """Tests for validate_deep_research_parameters function."""

    def test_valid_inputs(self):
        """Test valid inputs pass validation without error."""
        # Should not raise
        validate_deep_research_parameters(
            question="How does authentication work?",
            preset="default",
            max_chunks=30,
        )

    def test_valid_with_none_preset(self):
        """Test None preset is valid."""
        validate_deep_research_parameters(
            question="What is the architecture?",
            preset=None,
            max_chunks=50,
        )

    def test_question_empty_raises(self):
        """Test empty question raises ValueError."""
        with pytest.raises(ValueError, match="Question cannot be empty"):
            validate_deep_research_parameters(
                question="",
                preset="default",
                max_chunks=30,
            )

    def test_question_too_long_raises(self):
        """Test question exceeding max length raises ValueError."""
        long_question = "q" * (ResourceLimits.MAX_QUESTION_LENGTH + 1)
        with pytest.raises(ValueError, match="exceeds maximum length"):
            validate_deep_research_parameters(
                question=long_question,
                preset="default",
                max_chunks=30,
            )

    def test_question_at_max_length_passes(self):
        """Test question at exactly max length passes."""
        max_question = "q" * ResourceLimits.MAX_QUESTION_LENGTH
        # Should not raise
        validate_deep_research_parameters(
            question=max_question,
            preset="default",
            max_chunks=30,
        )

    def test_invalid_preset_raises(self):
        """Test invalid preset raises ValueError."""
        with pytest.raises(ValueError, match="Invalid preset"):
            validate_deep_research_parameters(
                question="test question",
                preset="invalid_preset",
                max_chunks=30,
            )

    def test_preset_quick_valid(self):
        """Test 'quick' preset is valid."""
        validate_deep_research_parameters(
            question="test",
            preset="quick",
            max_chunks=30,
        )

    def test_preset_default_valid(self):
        """Test 'default' preset is valid."""
        validate_deep_research_parameters(
            question="test",
            preset="default",
            max_chunks=30,
        )

    def test_preset_thorough_valid(self):
        """Test 'thorough' preset is valid."""
        validate_deep_research_parameters(
            question="test",
            preset="thorough",
            max_chunks=30,
        )

    def test_max_chunks_too_low_raises(self):
        """Test max_chunks below 1 raises ValueError."""
        with pytest.raises(ValueError, match="max_chunks must be between"):
            validate_deep_research_parameters(
                question="test",
                preset="default",
                max_chunks=0,
            )

    def test_max_chunks_negative_raises(self):
        """Test negative max_chunks raises ValueError."""
        with pytest.raises(ValueError, match="max_chunks must be between"):
            validate_deep_research_parameters(
                question="test",
                preset="default",
                max_chunks=-10,
            )

    def test_max_chunks_too_high_raises(self):
        """Test max_chunks above limit raises ValueError."""
        with pytest.raises(ValueError, match="max_chunks must be between"):
            validate_deep_research_parameters(
                question="test",
                preset="default",
                max_chunks=ResourceLimits.MAX_CONTEXT_CHUNKS + 1,
            )

    def test_max_chunks_at_limit_passes(self):
        """Test max_chunks at exactly limit passes."""
        validate_deep_research_parameters(
            question="test",
            preset="default",
            max_chunks=ResourceLimits.MAX_CONTEXT_CHUNKS,
        )

    def test_error_shows_valid_presets(self):
        """Test the invalid-preset error message lists every valid preset."""
        # match= ensures this is the preset error, not some other ValueError.
        with pytest.raises(ValueError, match="Invalid preset") as exc_info:
            validate_deep_research_parameters(
                question="test",
                preset="bad",
                max_chunks=30,
            )
        error_msg = str(exc_info.value)
        # Require every preset, not just one of them, so a partial or
        # truncated list of options is caught.
        missing = sorted(p for p in VALID_RESEARCH_PRESETS if p not in error_msg)
        assert not missing, f"presets missing from error message: {missing}"
class TestValidResearchPresets:
    """Checks on the shape and contents of VALID_RESEARCH_PRESETS."""

    @staticmethod
    def _assert_is_preset(name):
        """Fail unless *name* is a member of VALID_RESEARCH_PRESETS."""
        assert name in VALID_RESEARCH_PRESETS

    def test_presets_exist(self):
        """VALID_RESEARCH_PRESETS is defined (not None)."""
        assert VALID_RESEARCH_PRESETS is not None

    def test_presets_is_set(self):
        """VALID_RESEARCH_PRESETS is an immutable frozenset."""
        is_frozen = isinstance(VALID_RESEARCH_PRESETS, frozenset)
        assert is_frozen

    def test_presets_contains_quick(self):
        """'quick' is accepted as a preset."""
        self._assert_is_preset("quick")

    def test_presets_contains_default(self):
        """'default' is accepted as a preset."""
        self._assert_is_preset("default")

    def test_presets_contains_thorough(self):
        """'thorough' is accepted as a preset."""
        self._assert_is_preset("thorough")

    def test_presets_count(self):
        """Exactly three presets are defined."""
        expected_count = 3
        assert len(VALID_RESEARCH_PRESETS) == expected_count