import inspect
import logging
from functools import wraps
from typing import Dict, Any, Callable, List, Optional

import jsonschema

from app.protocol.models import ToolResult
from app.core.errors import ToolExecutionError, InvalidParamsError
# Module logger; used for tool-registration and tool-failure messages below.
logger = logging.getLogger(__name__)
class ToolRegistry:
    """Registry for MCP tools with JSON schema validation.

    Tools are stored by name together with the JSON schema used both to
    validate incoming arguments and to advertise the tool, and the async
    handler that ``call_tool`` awaits.
    """

    def __init__(self):
        # name -> {"name": ..., "schema": ..., "handler": ...}
        self.tools: Dict[str, Dict[str, Any]] = {}

    def register_tool(self, name: str, schema: Dict[str, Any], func: Callable):
        """Register a tool with its schema and handler function.

        Args:
            name: Tool name used as the lookup key in ``call_tool``.
            schema: JSON schema for the tool's arguments. Its optional
                ``"description"`` key doubles as the advertised description.
            func: Handler that ``call_tool`` awaits with the validated
                arguments passed as keyword arguments.

        Registering an existing name replaces the earlier entry; a warning is
        logged so accidental name collisions do not go unnoticed.
        """
        if name in self.tools:
            logger.warning("Replacing previously registered tool: %s", name)
        self.tools[name] = {
            "name": name,
            "schema": schema,
            "handler": func,
        }
        logger.info("Registered tool: %s", name)

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """Return one capability-advertisement entry per registered tool.

        Each entry has ``name``, ``description`` (the schema's
        ``"description"`` value, or ``""`` if absent) and ``inputSchema``
        (the full registered schema).
        """
        return [
            {
                "name": tool_data["name"],
                "description": tool_data["schema"].get("description", ""),
                "inputSchema": tool_data["schema"],
            }
            for tool_data in self.tools.values()
        ]

    async def call_tool(
        self, name: str, arguments: Optional[Dict[str, Any]] = None
    ) -> ToolResult:
        """Validate ``arguments`` against the tool's schema and run the tool.

        Args:
            name: Registered tool name.
            arguments: Keyword arguments for the handler. ``None`` (e.g. an
                omitted ``arguments`` field) is treated as ``{}``.

        Returns:
            The handler's result as a ToolResult. A dict containing
            ``"content"`` is passed to ``ToolResult(**result)``; any other
            value is wrapped as a single text content item via ``str()``.

        Raises:
            ToolExecutionError: The tool is unknown, the handler raised, or
                its result could not be converted. ToolExecutionError and
                InvalidParamsError raised by the handler propagate unchanged.
            InvalidParamsError: ``arguments`` fails schema validation.
        """
        if name not in self.tools:
            raise ToolExecutionError(name, f"Tool '{name}' not found")
        tool_data = self.tools[name]
        if arguments is None:
            arguments = {}

        try:
            # Validation has its own handler so that a jsonschema error raised
            # *inside* a tool handler is not misreported as bad client params.
            try:
                jsonschema.validate(arguments, tool_data["schema"])
            except jsonschema.ValidationError as e:
                raise InvalidParamsError(
                    f"Invalid arguments for tool '{name}': {e.message}"
                ) from e

            result = await tool_data["handler"](**arguments)

            # Ensure result is in the correct format
            if isinstance(result, dict) and "content" in result:
                return ToolResult(**result)
            # Wrap simple results as a single text item.
            return ToolResult(
                content=[{"type": "text", "text": str(result)}],
                isError=False,
            )
        except (InvalidParamsError, ToolExecutionError):
            # Already a protocol-level error; re-wrapping would lose detail.
            raise
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Tool execution error for '%s': %s", name, e)
            raise ToolExecutionError(name, str(e)) from e
# Module-level registry shared by every @mcp_tool declaration; tools are added
# to it when their decorators run (i.e. when this module is imported).
tool_registry = ToolRegistry()
def mcp_tool(name: str, schema: Dict[str, Any]):
    """Decorator that registers a function as an MCP tool on ``tool_registry``.

    Args:
        name: Tool name under which the function is registered.
        schema: JSON schema for the tool's arguments; passed unchanged to
            ``tool_registry.register_tool``.

    Returns:
        A decorator. The decorated name is bound to an async wrapper, and that
        same wrapper is what gets registered, so callers always ``await`` it.
        Both coroutine functions and plain synchronous functions are
        accepted: an awaitable result is awaited, anything else is returned
        as-is.
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            # The registry always awaits its handlers; awaiting only when
            # there is something to await lets sync tools be registered too.
            if inspect.isawaitable(result):
                result = await result
            return result

        # Register the tool at decoration time.
        tool_registry.register_tool(name, schema, wrapper)
        return wrapper

    return decorator
# Sample tool implementations
@mcp_tool(
    "sum_values",
    {
        "type": "object",
        "description": "Add two numbers together",
        "properties": {
            "a": {"type": "number", "description": "First number to add"},
            "b": {"type": "number", "description": "Second number to add"},
        },
        "required": ["a", "b"],
    },
)
async def sum_values(a: float, b: float) -> Dict[str, Any]:
    """Add ``a`` and ``b`` and report the total as one text content item.

    Returns:
        A dict with a single ``"text"`` entry under ``"content"`` describing
        the sum and ``"isError"`` set to False (the dict form that
        ``ToolRegistry.call_tool`` converts into a ToolResult).
    """
    message = f"The sum of {a} and {b} is {a + b}"
    text_item = {"type": "text", "text": message}
    return {"content": [text_item], "isError": False}
@mcp_tool(
    "calculate",
    {
        "type": "object",
        "description": "Perform basic arithmetic operations",
        "properties": {
            "operation": {
                "type": "string",
                "enum": ["add", "subtract", "multiply", "divide"],
                "description": "The arithmetic operation to perform",
            },
            "a": {"type": "number", "description": "First operand"},
            "b": {"type": "number", "description": "Second operand"},
        },
        "required": ["operation", "a", "b"],
    },
)
async def calculate(operation: str, a: float, b: float) -> Dict[str, Any]:
    """Apply one of four arithmetic operations to ``a`` and ``b``.

    Failures (division by zero, an unrecognized operation, or any other
    exception raised while computing or formatting) are not propagated; they
    are reported in the returned dict with ``"isError"`` set to True.
    """

    def _text_reply(text, failed):
        # Single-text-item response dict shared by the success and error paths.
        return {
            "content": [{"type": "text", "text": text}],
            "isError": failed,
        }

    def _divide(x, y):
        if y == 0:
            raise ValueError("Division by zero")
        return x / y

    # Ordered (name, implementation) pairs, compared with ``==`` in the same
    # order an if/elif chain would use.
    handlers = (
        ("add", lambda x, y: x + y),
        ("subtract", lambda x, y: x - y),
        ("multiply", lambda x, y: x * y),
        ("divide", _divide),
    )

    try:
        for op_name, compute in handlers:
            if operation == op_name:
                break
        else:
            raise ValueError(f"Unknown operation: {operation}")
        result = compute(a, b)
        return _text_reply(f"{a} {operation} {b} = {result}", False)
    except Exception as e:
        return _text_reply(f"Calculation error: {str(e)}", True)