# test_cleaner.py
import pytest
from unittest.mock import patch, AsyncMock, mock_open

from tools.cleaner import clean_prompt
from llm.client import LLMTimeoutError, LLMHttpError, LLMNetworkError
from schemas import CleanPromptOutput


class TestCleaner:
    """Test the clean_prompt function"""

    @pytest.fixture
    def mock_system_prompt(self):
        """Mock system prompt content"""
return """[ROLE] "Prompt Cleaner"; Expert prompt engineer
[TASK] Refine RAW_PROMPT; preserve intent, placeholders, files, code callouts
[INPUT] RAW_PROMPT: string
[OUTPUT] STRICT; stdout = Valid JSON (UTF-8), double quotes, no trailing commas, no extra text; FIELDS: {
"cleaned": "string",
"notes": ["string"],
"open_questions": ["string"],
"risks": ["string"],
"unchanged": boolean,
"quality": {
"score": integer (1-5),
"reasons": ["string"]
}
}
[QUALITY_GATE] Score RAW_PROMPT 0–5 (1pt each): intent clear; io stated/N/A; constraints/acceptance present/N/A; no contradictions; If score ≥4 AND no redactions: unchanged=true and cleaned=RAW_PROMPT (byte-exact). Else unchanged=false and refine
[CLEANING_RULES] Concise, actionable, unambiguous; Use "\n- " for lists; specify inputs/outputs when present or clearly implied; Developer tone if code/spec; include types and edge/error cases; Don't invent requirements or change scope; preserve {{var}}, <VAR>, $VAR, backticks; Keep original language
[TROUBLESHOOT] If RAW_PROMPT is a direct question or short-answer: produce normal json preserving original prompt
[MODE_CONTEXT] Consider MODE (code/general) and CONTEXT for domain-specific improvements"""
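
    # A minimal sketch (not part of the original suite): the QUALITY_GATE in the
    # system prompt above says a RAW_PROMPT scoring >= 4 with no redactions should
    # come back with unchanged=true and cleaned equal to the input byte-for-byte.
    # This hypothetical fixture mirrors the shape of mock_llm_response below for
    # that pass-through case; the prompt text and reasons are invented examples.
    @pytest.fixture
    def mock_unchanged_llm_response(self):
        """Mock LLM response for the unchanged pass-through case (illustrative)"""
        return (
            '{"cleaned": "Summarize `report.md` into 5 bullet points.", '
            '"notes": [], "open_questions": [], "risks": [], "unchanged": true, '
            '"quality": {"score": 5, "reasons": ["Intent, IO and constraints are clear"]}}'
        )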

    @pytest.fixture
    def mock_llm_response(self):
        """Mock LLM response"""
        return '{"cleaned": "Enhanced prompt", "notes": ["Added specificity"], "open_questions": [], "risks": [], "unchanged": false, "quality": {"score": 4, "reasons": ["Clear and actionable"]}}'

    @pytest.mark.asyncio
    async def test_successful_cleaning(self, mock_system_prompt, mock_llm_response):
        """Test successful prompt cleaning"""
        with patch("builtins.open", mock_open(read_data=mock_system_prompt)):
            with patch("tools.cleaner.LLMClient") as mock_client_class:
                mock_client = AsyncMock()
                mock_client.chat_completions.return_value = mock_llm_response
                mock_client_class.return_value = mock_client

                result = await clean_prompt(
                    raw_prompt="help me write code",
                    context="web development",
                    mode="general",
                    temperature=0.2,
                )

                assert isinstance(result, CleanPromptOutput)
                assert result.cleaned == "Enhanced prompt"
                assert result.notes == ["Added specificity"]
                assert result.unchanged is False
                assert result.quality.score == 4

    @pytest.mark.asyncio
    async def test_code_mode(self, mock_system_prompt, mock_llm_response):
        """Test cleaning with code mode"""
        with patch("builtins.open", mock_open(read_data=mock_system_prompt)):
            with patch("tools.cleaner.LLMClient") as mock_client_class:
                mock_client = AsyncMock()
                mock_client.chat_completions.return_value = mock_llm_response
                mock_client_class.return_value = mock_client

                await clean_prompt(
                    raw_prompt="write a function",
                    context="Python development",
                    mode="code",
                    temperature=0.1,
                )

                # Verify the client was called with correct parameters
                mock_client.chat_completions.assert_called_once()
                call_args = mock_client.chat_completions.call_args
                messages = call_args[1]["messages"]
                assert "MODE: code" in messages[1]["content"]
                assert "CONTEXT: Python development" in messages[1]["content"]
                assert "RAW_PROMPT:\nwrite a function" in messages[1]["content"]
                assert call_args[1]["temperature"] == 0.1

    @pytest.mark.asyncio
    async def test_invalid_input_validation(self):
        """Test input validation"""
        with pytest.raises(ValueError, match="raw_prompt must be a non-empty string"):
            await clean_prompt("")

        with pytest.raises(ValueError, match="mode must be 'code' or 'general'"):
            await clean_prompt("test", mode="invalid")

        with pytest.raises(ValueError, match="temperature must be between 0.0 and 1.0"):
            await clean_prompt("test", temperature=1.5)

    @pytest.mark.asyncio
    async def test_llm_timeout_error(self, mock_system_prompt):
        """Test handling of LLM timeout error"""
        with patch("builtins.open", mock_open(read_data=mock_system_prompt)):
            with patch("tools.cleaner.LLMClient") as mock_client_class:
                mock_client = AsyncMock()
                mock_client.chat_completions.side_effect = LLMTimeoutError(
                    "Request timed out"
                )
                mock_client_class.return_value = mock_client

                with pytest.raises(LLMTimeoutError):
                    await clean_prompt("test prompt")

    @pytest.mark.asyncio
    async def test_llm_http_error(self, mock_system_prompt):
        """Test handling of LLM HTTP error"""
        with patch("builtins.open", mock_open(read_data=mock_system_prompt)):
            with patch("tools.cleaner.LLMClient") as mock_client_class:
                mock_client = AsyncMock()
                mock_client.chat_completions.side_effect = LLMHttpError(
                    "HTTP 500: Server Error"
                )
                mock_client_class.return_value = mock_client

                with pytest.raises(LLMHttpError):
                    await clean_prompt("test prompt")

    @pytest.mark.asyncio
    async def test_llm_network_error(self, mock_system_prompt):
        """Test handling of LLM network error"""
        with patch("builtins.open", mock_open(read_data=mock_system_prompt)):
            with patch("tools.cleaner.LLMClient") as mock_client_class:
                mock_client = AsyncMock()
                mock_client.chat_completions.side_effect = LLMNetworkError(
                    "Network error"
                )
                mock_client_class.return_value = mock_client

                with pytest.raises(LLMNetworkError):
                    await clean_prompt("test prompt")

    @pytest.mark.asyncio
    async def test_json_extraction_failure(self, mock_system_prompt):
        """Test handling of JSON extraction failure"""
        with patch("builtins.open", mock_open(read_data=mock_system_prompt)):
            with patch("tools.cleaner.LLMClient") as mock_client_class:
                mock_client = AsyncMock()
                mock_client.chat_completions.return_value = "This is not valid JSON"
                mock_client_class.return_value = mock_client

                with pytest.raises(ValueError, match="Failed to extract valid JSON"):
                    await clean_prompt("test prompt")

    @pytest.mark.asyncio
    async def test_schema_validation_failure(self, mock_system_prompt):
        """Test handling of schema validation failure"""
        with patch("builtins.open", mock_open(read_data=mock_system_prompt)):
            with patch("tools.cleaner.LLMClient") as mock_client_class:
                mock_client = AsyncMock()
                # Return JSON that doesn't match schema
                mock_client.chat_completions.return_value = '{"invalid": "response"}'
                mock_client_class.return_value = mock_client

                with pytest.raises(ValueError, match="Response validation failed"):
                    await clean_prompt("test prompt")

    @pytest.mark.asyncio
    async def test_retry_mechanism(self, mock_system_prompt):
        """Test retry mechanism on failures"""
        with patch("builtins.open", mock_open(read_data=mock_system_prompt)):
            with patch("tools.cleaner.LLMClient") as mock_client_class:
                mock_client = AsyncMock()
                # First call fails, second succeeds
                mock_client.chat_completions.side_effect = [
                    LLMHttpError("HTTP 500: Server Error"),
                    '{"cleaned": "Enhanced prompt", "notes": [], "open_questions": [], "risks": [], "unchanged": false, "quality": {"score": 4, "reasons": ["Clear"]}}',
                ]
                mock_client_class.return_value = mock_client

                result = await clean_prompt("test prompt")

                # Should have been called twice (1 retry)
                assert mock_client.chat_completions.call_count == 2
                assert isinstance(result, CleanPromptOutput)

    @pytest.mark.asyncio
    async def test_strict_mode_on_retry(self, mock_system_prompt):
        """Test that strict mode is applied on retries"""
        with patch("builtins.open", mock_open(read_data=mock_system_prompt)):
            with patch("tools.cleaner.LLMClient") as mock_client_class:
                mock_client = AsyncMock()
                # First call fails JSON extraction, second succeeds
                mock_client.chat_completions.side_effect = [
                    "Invalid JSON response",
                    '{"cleaned": "Enhanced prompt", "notes": [], "open_questions": [], "risks": [], "unchanged": false, "quality": {"score": 4, "reasons": ["Clear"]}}',
                ]
                mock_client_class.return_value = mock_client

                await clean_prompt("test prompt")

                # Check that the second call used temperature 0.0 (strict mode)
                second_call = mock_client.chat_completions.call_args_list[1]
                assert second_call[1]["temperature"] == 0.0

                # Check that strict instructions were added to the system prompt
                messages = second_call[1]["messages"]
                assert "STRICT OUTPUT MODE" in messages[0]["content"]

    @pytest.mark.asyncio
    async def test_file_not_found_error(self):
        """Test handling of missing system prompt file"""
        # Clear the cache by patching the class attribute
        with patch("tools.cleaner._SystemPromptCache._cache", None):
            with patch(
                "builtins.open", side_effect=FileNotFoundError("File not found")
            ):
                with pytest.raises(FileNotFoundError):
                    await clean_prompt("test prompt")

    @pytest.mark.asyncio
    async def test_context_parameter(self, mock_system_prompt, mock_llm_response):
        """Test that context parameter is properly passed"""
        with patch("builtins.open", mock_open(read_data=mock_system_prompt)):
            with patch("tools.cleaner.LLMClient") as mock_client_class:
                mock_client = AsyncMock()
                mock_client.chat_completions.return_value = mock_llm_response
                mock_client_class.return_value = mock_client

                await clean_prompt(
                    raw_prompt="help me",
                    context="I'm building a web app",
                    mode="general",
                )

                # Verify context is included in the message
                call_args = mock_client.chat_completions.call_args
                messages = call_args[1]["messages"]
                assert "CONTEXT: I'm building a web app" in messages[1]["content"]

    @pytest.mark.asyncio
    async def test_system_prompt_caching(self, mock_system_prompt, mock_llm_response):
        """Test that system prompt is cached and reused"""
        # Clear the cache first to ensure fresh state
        with patch("tools.cleaner._SystemPromptCache._cache", None):
            with patch(
                "builtins.open", mock_open(read_data=mock_system_prompt)
            ) as mock_file:
                with patch("tools.cleaner.LLMClient") as mock_client_class:
                    mock_client = AsyncMock()
                    mock_client.chat_completions.return_value = mock_llm_response
                    mock_client_class.return_value = mock_client

                    # First call should read the file
                    await clean_prompt("test prompt 1")
                    assert mock_file.call_count == 1

                    # Second call should use cache (no additional file reads)
                    await clean_prompt("test prompt 2")
                    assert mock_file.call_count == 1  # Still 1, not 2
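
    # Illustrative sketch (not part of the original suite), reusing the file's
    # mock pattern to exercise the QUALITY_GATE pass-through path end to end.
    @pytest.mark.asyncio
    async def test_unchanged_passthrough(
        self, mock_system_prompt, mock_unchanged_llm_response
    ):
        """Sketch: unchanged=true pass-through case.

        Uses the hypothetical mock_unchanged_llm_response fixture defined above and
        assumes clean_prompt returns the parsed model output verbatim, as
        test_successful_cleaning suggests.
        """
        raw = "Summarize `report.md` into 5 bullet points."
        with patch("builtins.open", mock_open(read_data=mock_system_prompt)):
            with patch("tools.cleaner.LLMClient") as mock_client_class:
                mock_client = AsyncMock()
                mock_client.chat_completions.return_value = mock_unchanged_llm_response
                mock_client_class.return_value = mock_client

                result = await clean_prompt(raw_prompt=raw, mode="general")

                assert isinstance(result, CleanPromptOutput)
                assert result.unchanged is True
                # Per the QUALITY_GATE, "cleaned" should be the raw prompt byte-exact
                assert result.cleaned == raw
                assert result.quality.score == 5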