# cleaner.py
"""
Prompt cleaning and enhancement module.
This module provides AI-powered prompt cleaning functionality with a two-level retry strategy:
1. HTTP-level retries: Handled by LLMClient for network/transport issues
2. Content-level retries: Handled here for AI output quality issues
The module uses a cached system prompt and implements exponential backoff for retries.
On content retries, it switches to "strict mode" with temperature=0.0 for deterministic output.
"""
import asyncio
import logging
from pathlib import Path
from typing import Optional
from schemas import CleanPromptOutput
from llm.client import LLMClient, LLMTimeoutError, LLMHttpError, LLMNetworkError
from utils.json_extractor import extract_first_json_object
from config import settings
logger = logging.getLogger(__name__)


class _SystemPromptCache:
    """Simple cache for the system prompt to avoid global variables"""

    _cache: Optional[str] = None

    @classmethod
    def get(cls) -> str:
        """Load and cache the system prompt from prompts/cleaner.md"""
        if cls._cache is None:
            prompt_path = Path(__file__).parent.parent / "prompts" / "cleaner.md"
            try:
                with open(prompt_path, "r", encoding="utf-8") as f:
                    cls._cache = f.read().strip()
            except FileNotFoundError as exc:
                logger.error("System prompt file not found: %s", prompt_path)
                raise FileNotFoundError(
                    f"System prompt file not found: {prompt_path}"
                ) from exc
            except Exception as e:
                logger.error("Error loading system prompt: %s", e)
                raise
        return cls._cache


def _load_system_prompt() -> str:
    """Load and cache the system prompt from prompts/cleaner.md"""
    return _SystemPromptCache.get()
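

# Note: the system prompt is read from disk once per process and then served
# from _SystemPromptCache._cache. A test that needs to reload a modified
# prompts/cleaner.md can clear the cache first (illustrative, not a public API):
#
#     _SystemPromptCache._cache = None
#     fresh = _load_system_prompt()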


async def clean_prompt(
    raw_prompt: str, context: str = "", mode: str = "general", temperature: float = 0.2
) -> CleanPromptOutput:
    """
    Clean and enhance a raw prompt using AI.

    Args:
        raw_prompt: The user's raw, unpolished prompt
        context: Additional context about the task
        mode: Processing mode ("code" or "general")
        temperature: AI sampling temperature (0.0-1.0)

    Returns:
        CleanPromptOutput with the enhanced prompt and metadata

    Raises:
        ValueError: If input validation fails, or if the AI response cannot be
            parsed or validated after all content-level retries
        FileNotFoundError: If the system prompt file is missing
        LLMTimeoutError: If the AI request times out
        LLMHttpError: If the AI API returns an error response
        LLMNetworkError: If a network error occurs
    """
    if not raw_prompt or not isinstance(raw_prompt, str):
        raise ValueError("raw_prompt must be a non-empty string")
    if mode not in ["code", "general"]:
        raise ValueError("mode must be 'code' or 'general'")
    if not 0.0 <= temperature <= 1.0:
        raise ValueError("temperature must be between 0.0 and 1.0")

    # Load system prompt
    system_prompt = _load_system_prompt()

    # Format user message
    user_message = f"MODE: {mode}\nCONTEXT: {context}\nRAW_PROMPT:\n{raw_prompt}"
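    # For example, with mode="code" and context="FastAPI service", the rendered
    # message is:
    #
    #     MODE: code
    #     CONTEXT: FastAPI service
    #     RAW_PROMPT:
    #     <raw_prompt text>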

    # Create LLM client
    client = LLMClient(
        endpoint=settings.llm_api_endpoint,
        api_key=settings.llm_api_key,
        model=settings.llm_model,
        timeout=settings.llm_timeout,
        max_retries=settings.content_max_retries,
    )
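    # LLMClient handles HTTP-level retries internally (level 1 in the module
    # docstring); the loop below adds content-level retries (level 2) for cases
    # where the transport succeeded but the output is unusable.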

    # Content-level retry loop
    last_exception = None
    for attempt in range(settings.content_max_retries + 1):
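        # If an attempt fails, the sleep below is 2**attempt + 0.1*attempt
        # seconds: 1.0s after the first attempt, 2.1s after the second,
        # 4.2s after the third, and so on.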
        try:
            # On retry, tighten the system prompt and force deterministic sampling
            current_system_prompt = system_prompt
            current_temperature = temperature
            if attempt > 0:
                # Strict instructions help when the AI previously returned
                # prose or malformed JSON instead of a bare JSON object
                strict_instructions = (
                    "\n\nSTRICT OUTPUT MODE: Respond with EXACTLY ONE JSON object and nothing else. "
                    "No prose. No code fences. No prefix/suffix. Just the JSON."
                )
                current_system_prompt += strict_instructions
                current_temperature = 0.0  # Force deterministic output on retries

            # Prepare messages
            messages = [
                {"role": "system", "content": current_system_prompt},
                {"role": "user", "content": user_message},
            ]

            # Make API call
            response = await client.chat_completions(
                messages=messages,
                temperature=current_temperature,
                max_tokens=settings.llm_max_tokens,
                request_id=f"clean_prompt_{attempt}",
            )

            # Extract JSON from response
            try:
                json_data = extract_first_json_object(response)
            except ValueError as e:
                logger.warning(
                    "JSON extraction failed (attempt %d): %s", attempt + 1, e
                )
                if attempt < settings.content_max_retries:
                    # Exponential backoff: 2**attempt seconds plus a small
                    # per-attempt increment (deterministic, no random jitter)
                    delay = 2**attempt + (0.1 * attempt)
                    logger.info("Retrying in %.1fs...", delay)
                    await asyncio.sleep(delay)
                    continue
                raise ValueError(
                    f"Failed to extract valid JSON from response: {e}"
                ) from e

            # Validate with Pydantic
            try:
                result = CleanPromptOutput.model_validate(json_data)
                logger.info("Successfully cleaned prompt (attempt %d)", attempt + 1)
                return result
            except (ValueError, TypeError, KeyError) as e:
                logger.warning(
                    "Schema validation failed (attempt %d): %s", attempt + 1, e
                )
                if attempt < settings.content_max_retries:
                    # Exponential backoff: 2**attempt seconds plus a small
                    # per-attempt increment (deterministic, no random jitter)
                    delay = 2**attempt + (0.1 * attempt)
                    logger.info("Retrying in %.1fs...", delay)
                    await asyncio.sleep(delay)
                    continue
                raise ValueError(f"Response validation failed: {e}") from e
        except (LLMTimeoutError, LLMHttpError, LLMNetworkError) as e:
            last_exception = e
            logger.warning("LLM error (attempt %d): %s", attempt + 1, e)
            if attempt < settings.content_max_retries:
                # Exponential backoff: 2**attempt seconds plus a small
                # per-attempt increment (deterministic, no random jitter)
                delay = 2**attempt + (0.1 * attempt)
                logger.info("Retrying in %.1fs...", delay)
                await asyncio.sleep(delay)
                continue
            raise

    # Defensive fallback: every failure path above either retries or raises,
    # so this point is only reached if the retry loop body never ran.
    if last_exception:
        raise last_exception
    raise RuntimeError("All retry attempts failed")
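

# ---------------------------------------------------------------------------
# Minimal manual test sketch. Illustrative only: it assumes `settings` points
# at a reachable LLM endpoint and that prompts/cleaner.md exists, and it avoids
# assuming specific CleanPromptOutput field names by dumping the whole model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    async def _demo() -> None:
        result = await clean_prompt(
            raw_prompt="make this endpoint faster plz",
            context="FastAPI service",
            mode="code",
        )
        print(result.model_dump_json(indent=2))

    asyncio.run(_demo())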