"""Prompt registry for managing MCP prompts."""
import inspect
from typing import Any, Callable
import mcp.types as types
from .base import BasePrompt
class PromptRegistry:
    """Registry for managing MCP prompts.

    Holds two kinds of prompt:

    * ``BasePrompt`` instances, which describe themselves via
      ``to_mcp_prompt()`` and render via ``generate()``.
    * Plain callables (sync or async). Their description and argument schema
      are stored at registration time. The callable is invoked with the
      client-supplied arguments as keyword arguments.
    """

    def __init__(self) -> None:
        """Initialize an empty prompt registry."""
        # Prompt name -> BasePrompt instance or plain callable.
        self._prompts: dict[str, BasePrompt | Callable[..., Any]] = {}
        # Prompt name -> description. Only read for plain callables.
        self._descriptions: dict[str, str] = {}
        # Prompt name -> list of {"name", "description", "required"} dicts.
        # Only read for plain callables.
        self._arguments: dict[str, list[dict[str, Any]]] = {}

    def register(
        self,
        name: str,
        prompt_func: Any,
        description: str = "",
        arguments: list[dict[str, Any]] | None = None,
    ) -> None:
        """Register a prompt under ``name``, replacing any existing entry.

        Args:
            name: Prompt name.
            prompt_func: Prompt function or BasePrompt instance.
            description: Prompt description. If empty, the cleaned docstring
                of ``prompt_func`` is used.
            arguments: Prompt arguments schema. If ``None``, it is derived
                from the signature of ``prompt_func``. An explicit empty list
                means "no arguments".
        """
        # BasePrompt instances and plain callables are stored identically;
        # list_prompts/get_prompt tell them apart with isinstance().
        self._prompts[name] = prompt_func
        self._descriptions[name] = description or self._extract_description(prompt_func)
        # Compare against None rather than truthiness, so that an explicit
        # ``arguments=[]`` is honoured instead of replaced by introspection.
        if arguments is None:
            arguments = self._generate_arguments(prompt_func)
        self._arguments[name] = arguments

    def _extract_description(self, func: Callable[..., Any]) -> str:
        """Extract a description from a function's docstring.

        Args:
            func: Function to extract the description from.

        Returns:
            The docstring with common indentation removed and surrounding
            whitespace stripped, or ``""`` if there is no docstring.
        """
        # cleandoc removes the source-code indentation that multi-line
        # docstrings carry, so clients receive clean text.
        return inspect.cleandoc(func.__doc__ or "").strip()

    def _generate_arguments(self, func: Callable[..., Any]) -> list[dict[str, Any]]:
        """Generate an arguments schema from a function signature.

        ``*args`` and ``**kwargs`` parameters are skipped because they are not
        named prompt arguments. A parameter is required when it has no default.

        Args:
            func: Function to generate arguments for.

        Returns:
            Arguments schema. Empty if the signature cannot be introspected.
        """
        try:
            sig = inspect.signature(func)
        except (TypeError, ValueError):
            # Not introspectable (e.g. an object that is not callable, or a
            # builtin without signature metadata): advertise no arguments
            # rather than failing registration.
            return []

        variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        return [
            {
                "name": param_name,
                "description": f"Parameter {param_name}",
                # ``empty`` is a sentinel class, so compare by identity.
                "required": param.default is inspect.Parameter.empty,
            }
            for param_name, param in sig.parameters.items()
            if param.kind not in variadic_kinds
        ]

    async def list_prompts(self) -> list[types.Prompt]:
        """List all registered prompts.

        Returns:
            List of MCP prompts, in registration order.
        """
        prompts: list[types.Prompt] = []
        for name, prompt in self._prompts.items():
            if isinstance(prompt, BasePrompt):
                # BasePrompt instances describe themselves.
                prompts.append(prompt.to_mcp_prompt())
                continue

            prompt_args = [
                types.PromptArgument(
                    name=arg["name"],
                    description=arg.get("description", ""),
                    required=arg.get("required", False),
                )
                for arg in self._arguments.get(name, [])
            ]
            prompts.append(
                types.Prompt(
                    name=name,
                    description=self._descriptions.get(name, ""),
                    arguments=prompt_args,
                )
            )
        return prompts

    @staticmethod
    def _text_result(text: str, description: str) -> types.GetPromptResult:
        """Wrap ``text`` as a single user text message in a GetPromptResult."""
        return types.GetPromptResult(
            messages=[
                types.PromptMessage(
                    role="user",
                    content=types.TextContent(type="text", text=text),
                )
            ],
            description=description,
        )

    async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
        """Get a prompt by name.

        For plain callables, a returned ``GetPromptResult`` is passed through
        unchanged. Any other return value is converted to text and sent as a
        single user message. Exceptions raised while calling or converting
        are reported as an error prompt rather than propagated.

        Args:
            name: Prompt name.
            arguments: Prompt arguments, passed to plain callables as keyword
                arguments.

        Returns:
            Generated prompt result.

        Raises:
            ValueError: If the prompt is not found.
        """
        if name not in self._prompts:
            raise ValueError(f"Unknown prompt: {name}")

        prompt = self._prompts[name]
        if isinstance(prompt, BasePrompt):
            return await prompt.generate(arguments)

        try:
            result = prompt(**(arguments or {}))
            # Await whatever comes back if it is awaitable. This covers
            # ``async def`` functions and also objects with an async
            # ``__call__`` or sync wrappers returning coroutines, which
            # ``inspect.iscoroutinefunction`` alone would miss.
            if inspect.isawaitable(result):
                result = await result

            if isinstance(result, types.GetPromptResult):
                return result
            text = result if isinstance(result, str) else str(result)
            return self._text_result(text, self._descriptions.get(name, ""))
        except Exception as e:
            # Deliberate best-effort: surface the failure to the client as a
            # prompt instead of raising out of the request handler.
            return self._text_result(
                f"Error generating prompt {name}: {str(e)}",
                f"Error in prompt {name}",
            )