mcp_handler.py•4.41 kB
import re
import xml.etree.ElementTree as ET
import json
from typing import Dict, List, Any, Optional, Callable
class Tool:
    """Base class for tools that can be used via MCP.

    Subclasses implement ``execute``; the base class renders the tool's XML
    definition from its name, description and parameter specs.

    Each parameter spec is a dict that must contain the keys ``'name'``,
    ``'type'`` and ``'description'``. It may also contain ``'required'`` (a
    truthy value emits ``<required>true</required>``) and ``'enum'`` (emitted
    as a JSON-encoded list).
    """

    def __init__(self, name: str, description: str, parameters: List[Dict[str, Any]]):
        self.name = name
        self.description = description
        self.parameters = parameters

    def get_definition(self) -> str:
        """Return the XML definition of this tool for MCP.

        Elements are separated by newlines. All text content is XML-escaped,
        so names, descriptions or enum values containing ``&``, ``<`` or ``>``
        still yield well-formed XML.

        Raises:
            KeyError: if a parameter spec lacks ``'name'``, ``'type'`` or
                ``'description'``.
        """
        from xml.sax.saxutils import escape

        def element(tag: str, value: Any) -> str:
            # Escape the value: raw interpolation broke the XML for '&'/'<'/'>'.
            return f"<{tag}>{escape(str(value))}</{tag}>"

        lines = [
            "<tool>",
            element("name", self.name),
            element("description", self.description),
            "<parameters>",
        ]
        for param in self.parameters:
            lines.append("<parameter>")
            lines.append(element("name", param["name"]))
            lines.append(element("type", param["type"]))
            lines.append(element("description", param["description"]))
            if param.get("required", False):
                lines.append("<required>true</required>")
            if "enum" in param:
                lines.append(element("enum", json.dumps(param["enum"])))
            lines.append("</parameter>")
        lines.append("</parameters>")
        lines.append("</tool>")
        return "\n".join(lines)

    def execute(self, **kwargs) -> Any:
        """Execute the tool with the given keyword parameters.

        Raises:
            NotImplementedError: always, in the base class; subclasses override.
        """
        raise NotImplementedError("Tool subclasses must implement execute()")
class MCPHandler:
    """Handles MCP tool definitions and invocations.

    Tools are stored by their ``name``. ``process_response`` scans text for
    function-call syntax such as ``tool_name(a="x", b=2)`` and inserts each
    registered tool's result (or error) directly after the call.
    """

    # A call is ``name(args)``. The argument span may contain double- or
    # single-quoted strings (which may include ")"); any other character except
    # ")" and quotes is taken literally. Each alternative starts with a
    # distinct character, so the pattern cannot backtrack catastrophically.
    _CALL_PATTERN = re.compile(r'(\w+)\(((?:"[^"]*"|\'[^\']*\'|[^)"\'])*)\)')

    # One keyword argument: key="str", key='str', or key=<int or float>,
    # with an optional leading minus sign on numbers.
    _ARG_PATTERN = re.compile(r'(\w+)=(?:"([^"]*?)"|\'([^\']*?)\'|(-?\d+(?:\.\d+)?))')

    def __init__(self):
        self.tools: Dict[str, "Tool"] = {}

    def register_tool(self, tool: "Tool") -> None:
        """Register *tool* under ``tool.name``, replacing any tool with that name."""
        self.tools[tool.name] = tool

    def get_tool_definitions(self) -> str:
        """Return the XML definitions of all registered tools inside ``<tools>``.

        Definitions are newline-separated, in the order names were first
        registered.
        """
        parts = ["<tools>"]
        parts.extend(tool.get_definition() for tool in self.tools.values())
        parts.append("</tools>")
        return "\n".join(parts)

    @classmethod
    def _parse_args(cls, args_str: str) -> Dict[str, Any]:
        """Parse ``key="v"``, ``key='v'`` and ``key=<number>`` pairs into a dict.

        Numbers become ``float`` when they contain a ``.``, otherwise ``int``.
        Text that matches none of these forms is ignored.
        """
        kwargs: Dict[str, Any] = {}
        for key, double_quoted, single_quoted, number in cls._ARG_PATTERN.findall(args_str):
            if number:
                kwargs[key] = float(number) if "." in number else int(number)
            else:
                # An empty quoted string leaves both groups empty, giving "".
                kwargs[key] = double_quoted or single_quoted
        return kwargs

    def process_response(self, response: str) -> str:
        """Process a response from the LLM, executing any tool invocations.

        Every ``name(args)`` occurrence whose ``name`` is a registered tool is
        executed once. A block of the form ``Result from <name>:`` followed by
        the result is inserted right after that occurrence. Dict and list
        results are pretty-printed as JSON. If the tool raises, or its result
        cannot be rendered, an ``Error executing <name>: <message>`` block is
        inserted instead. Calls to unregistered names are left untouched.

        A call whose arguments contain an unterminated quote is not recognized
        at all, because its arguments could not be parsed.

        Args:
            response: Raw text produced by the LLM.

        Returns:
            The text with results inserted; unchanged if nothing matched.
        """
        def run(match: "re.Match") -> str:
            invocation = match.group(0)
            tool_name, args_str = match.group(1), match.group(2)
            tool = self.tools.get(tool_name)
            if tool is None:
                return invocation
            kwargs = self._parse_args(args_str)
            try:
                tool_result = tool.execute(**kwargs)
                if isinstance(tool_result, (dict, list)):
                    rendered = json.dumps(tool_result, indent=2)
                else:
                    rendered = tool_result
                result_text = f"\n\nResult from {tool_name}:\n{rendered}\n\n"
            except Exception as e:
                # Report tool failures in-band instead of aborting the response.
                return invocation + f"\n\nError executing {tool_name}: {str(e)}\n\n"
            return invocation + result_text

        # A single re.sub pass visits each occurrence exactly once, left to
        # right, and never rescans inserted text. Repeated str.replace calls
        # duplicated results for repeated calls and also matched registered
        # names that were suffixes of longer identifiers (e.g. foo_bar()).
        return self._CALL_PATTERN.sub(run, response)