# main.py
import asyncio
import json
import logging
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import Tool, TextContent
from tools.cleaner import clean_prompt
from schemas import CleanPromptInput, CleanPromptOutput
from llm.client import LLMTimeoutError, LLMHttpError, LLMNetworkError

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Server("mcp-prompt-cleaner")

async def clean_prompt_tool(
raw_prompt: str, context: str = "", mode: str = "general", temperature: float = 0.2
) -> CleanPromptOutput:
"""
Enhance and clean raw prompts using AI.
This function uses a two-level retry strategy:
1. HTTP-level retries for network/transport issues
2. Content-level retries for AI output quality issues
Args:
raw_prompt: The user's raw, unpolished prompt
context: Additional context about the task
mode: Processing mode - "general" or "code" (default: "general")
temperature: AI sampling temperature 0.0-1.0 (default: 0.2)
Returns:
CleanPromptOutput with enhanced prompt and metadata
Raises:
ValueError: If input validation fails (empty prompt, invalid mode/temperature)
FileNotFoundError: If system prompt file is missing
LLMTimeoutError: If AI request times out after all retries
LLMHttpError: If AI API returns HTTP error after all retries
LLMNetworkError: If network error occurs after all retries
Example:
>>> result = await clean_prompt_tool(
... raw_prompt="help me write code",
... context="web development",
... mode="code",
... temperature=0.1
... )
>>> print(result.cleaned)
"""
logger.info("clean_prompt tool called with mode: %s", mode)
try:
# Validate input
input_data = CleanPromptInput(
raw_prompt=raw_prompt, context=context, mode=mode, temperature=temperature
)
logger.debug("Input validated: %s", input_data)
# Call implementation
logger.debug("Calling clean_prompt function...")
result = await clean_prompt(
raw_prompt=input_data.raw_prompt,
context=input_data.context,
mode=input_data.mode,
temperature=input_data.temperature,
)
logger.debug("clean_prompt returned: %s (type: %s)", result, type(result))
# Return the result directly
return result
    except (ValueError, TypeError, KeyError, FileNotFoundError) as e:
        logger.error("Input validation or configuration error in clean_prompt tool: %s", e)
        raise
except (LLMTimeoutError, LLMHttpError, LLMNetworkError) as e:
logger.error("LLM service error in clean_prompt tool: %s", e)
raise
except RuntimeError as e:
logger.error("Runtime error in clean_prompt tool: %s", e)
raise
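

# A minimal sketch of the content-level half of the retry strategy described
# in the docstring above. Everything here is illustrative: the helper name,
# attempt count, and the use of ValueError to flag a bad completion are
# assumptions; the real logic lives in tools.cleaner.clean_prompt and
# llm.client.
async def _retry_sketch(generate, validate, attempts: int = 3):
    """Hypothetical content-level retry loop; not called by the server."""
    for attempt in range(1, attempts + 1):
        text = await generate()  # HTTP-level retries happen inside llm.client
        try:
            return validate(text)  # e.g. parse and validate the model output
        except ValueError:
            logger.warning("Bad completion on attempt %d/%d", attempt, attempts)
    raise RuntimeError("Model output failed validation after all retries")
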
@app.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
"""Handle tool calls following MCP protocol"""
logger.info("Tool called: %s with args: %s", name, arguments)
if name == "clean_prompt":
result = await clean_prompt_tool(**arguments)
# Serialize the CleanPromptOutput to JSON and wrap in TextContent
if hasattr(result, "model_dump"):
result_json = json.dumps(result.model_dump())
else:
result_json = json.dumps(result)
return [TextContent(type="text", text=result_json)]
else:
raise ValueError(f"Unknown tool: {name}")
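

# For reference, a typical dispatch into the handler above (argument values
# are hypothetical) looks like:
#   await call_tool("clean_prompt", {"raw_prompt": "help me write code",
#                                    "mode": "code"})
# It returns a single TextContent whose text is the JSON-serialized
# CleanPromptOutput.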
@app.list_tools()
async def list_tools() -> list[Tool]:
"""List available tools"""
return [
Tool(
name="clean_prompt",
description="Enhance and clean raw prompts using AI. Accepts raw_prompt, context, mode (code/general), and temperature parameters.",
inputSchema=CleanPromptInput.model_json_schema(),
)
]
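

# CleanPromptInput.model_json_schema() emits a standard JSON Schema object.
# A rough sketch of its shape (the exact fields and defaults depend on the
# model definition in schemas.py) is:
#   {
#       "type": "object",
#       "properties": {
#           "raw_prompt": {"type": "string"},
#           "context": {"type": "string", "default": ""},
#           "mode": {"type": "string", "default": "general"},
#           "temperature": {"type": "number", "default": 0.2},
#       },
#       "required": ["raw_prompt"],
#   }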


async def main():
    """Run the MCP server over stdio until the client disconnects."""
    async with stdio_server() as (read_stream, write_stream):
        await app.run(read_stream, write_stream, app.create_initialization_options())


if __name__ == "__main__":
    asyncio.run(main())
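

# A minimal client-side smoke test, sketched against the MCP Python SDK's
# stdio client API (ClientSession, StdioServerParameters, stdio_client).
# Treat it as an assumption-laden example, not part of the server: run it
# from a separate script, not from this module.
#
#   from mcp import ClientSession, StdioServerParameters
#   from mcp.client.stdio import stdio_client
#
#   async def smoke_test():
#       params = StdioServerParameters(command="python", args=["main.py"])
#       async with stdio_client(params) as (read, write):
#           async with ClientSession(read, write) as session:
#               await session.initialize()
#               result = await session.call_tool(
#                   "clean_prompt", {"raw_prompt": "help me write code"}
#               )
#               print(result.content[0].text)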