# tool.py (5.63 kB)
"""
Represents MCP tools with execution capabilities.
"""
import asyncio
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
from mcp import ClientSession
from src.mcp.mcp_connection_manager import MCPConnectionManager
from src.utils.context import get_context
class Tool:
    """
    Represents a tool with its metadata and execution capabilities.

    Attributes:
        name: Tool name; also the name passed to the server when executing.
        description: Human-readable description of the tool.
        input_schema: Schema dictionary; its "properties" and "required"
            entries are read when formatting the tool for an LLM.
        server_name: Name of the server the tool's session is looked up by.
        logger: Logger named "tool:<name>".
    """

    def __init__(
        self,
        name: str,
        description: str,
        input_schema: Dict[str, Any],
        server_name: str,
    ):
        self.name = name
        self.description = description
        self.input_schema = input_schema
        self.server_name = server_name
        self.logger = logging.getLogger(f"tool:{name}")

    # TODO: Might want to move the formatting to each of the LLM provider classes, to adjust for each provider.
    def format_for_llm(self) -> str:
        """Format tool information for LLM consumption.

        Returns:
            A multi-line string listing the tool name, description, one
            "- <param>: <description>" line per schema property (suffixed with
            " (required)" when the property is listed as required) and the
            server name. The string starts and ends with a newline.
        """
        # Tolerate a missing/None schema or None-valued entries instead of raising.
        schema = self.input_schema or {}
        properties = schema.get("properties") or {}
        required = schema.get("required") or []

        args_desc = []
        for param_name, param_info in properties.items():
            # A property schema that is not a mapping carries no description.
            if isinstance(param_info, dict):
                description = param_info.get("description", "No description")
            else:
                description = "No description"
            arg_desc = f"- {param_name}: {description}"
            if param_name in required:
                arg_desc += " (required)"
            args_desc.append(arg_desc)

        # The leading and trailing empty strings produce the surrounding newlines.
        lines = [
            "",
            f"Tool: {self.name}",
            f"Description: {self.description}",
            "Arguments:",
            "\n".join(args_desc),
            f"Server: {self.server_name}",
            "",
        ]
        return "\n".join(lines)

    async def execute(
        self,
        arguments: Dict[str, Any],
        connection_manager: "MCPConnectionManager",
        retries: int = 2,
        delay: float = 1.0,
    ) -> Any:
        """Execute the tool with retry mechanism.

        Args:
            arguments: Arguments passed to the tool call.
            connection_manager: Provides the session for ``self.server_name``
                through ``get_session``.
            retries: Total number of attempts (not additional retries). Values
                below 1 are treated as 1 so the tool is always tried once.
            delay: Seconds to sleep between failed attempts.

        Returns:
            Whatever the session's ``call_tool`` returns.

        Raises:
            RuntimeError: If no session is available for the server on the
                final attempt.
            Exception: The error raised by the final failed attempt.
        """
        # Always make at least one attempt; otherwise retries <= 0 would skip
        # the call entirely and silently return None.
        max_attempts = max(1, retries)

        for attempt in range(1, max_attempts + 1):
            try:
                # Get session for server
                session = await connection_manager.get_session(self.server_name)
                if not session:
                    raise RuntimeError(f"Server {self.server_name} not initialized")
                self.logger.info(f"Executing {self.name} with arguments: {arguments}")
                # Execute tool
                result = await session.call_tool(self.name, arguments)
                self.logger.info(f"Tool {self.name} execution result: {result}")
                return result
            except Exception as e:
                self.logger.warning(
                    f"Error executing tool: {e}. Attempt {attempt} of {max_attempts}"
                )
                if attempt >= max_attempts:
                    self.logger.error(f"Max retries reached. Failing tool {self.name}")
                    raise
                self.logger.info(f"Retrying in {delay} seconds...")
                await asyncio.sleep(delay)
class ToolRegistry:
    """
    Registry for available tools across all servers.

    Tools are keyed by name only, so registering a tool under a name that is
    already present replaces the earlier entry, whichever server it came from.
    """

    def __init__(self):
        # Maps tool name -> Tool.
        self.tools: Dict[str, "Tool"] = {}
        self.logger = logging.getLogger("tool-registry")

    def register_tool(self, tool: "Tool") -> None:
        """Register a tool, replacing any existing tool with the same name.

        Args:
            tool: The tool to store under ``tool.name``.
        """
        previous = self.tools.get(tool.name)
        if previous is not None and previous is not tool:
            # Make the name collision visible; the overwrite itself is kept.
            self.logger.warning(
                f"Tool {tool.name} from server {previous.server_name} "
                f"is being replaced by the one from server {tool.server_name}"
            )
        self.tools[tool.name] = tool
        self.logger.info(f"Registered tool: {tool.name} from server {tool.server_name}")

    def get_tool(self, name: str) -> Optional["Tool"]:
        """Get a tool by name, or None when no tool has that name."""
        return self.tools.get(name)

    def list_tools(self) -> List["Tool"]:
        """List all available tools as a new list, in registration order."""
        return list(self.tools.values())

    def clear(self) -> None:
        """Clear all registered tools."""
        self.tools.clear()

    async def load_from_config(
        self, config: Dict[str, Any], connection_manager: "MCPConnectionManager"
    ) -> None:
        """
        Load tools from MCP servers defined in configuration.

        Servers are processed one after another, in the iteration order of the
        "servers" entry.

        Args:
            config: Configuration dictionary; its "servers" entry is iterated
                for server names. A missing or None entry means no servers.
            connection_manager: Connection manager for server access
        """
        # Discover tools from servers
        servers = (config or {}).get("servers") or {}
        for server_name in servers:
            self.logger.info(f"Discovering tools from server: {server_name}")
            await self.discover_tools(connection_manager, server_name)

    @staticmethod
    def _extract_tool_infos(tools_response: Any) -> List[Any]:
        """Return the tool descriptors carried by a list-tools response."""
        # Prefer the explicit attribute when the response exposes one.
        tools_attr = getattr(tools_response, "tools", None)
        if tools_attr is not None:
            return list(tools_attr)
        # Fallback: scan for a ("tools", [...]) pair while iterating the
        # response (presumably a model whose iteration yields
        # (field_name, value) pairs -- TODO confirm against the MCP SDK).
        tool_infos: List[Any] = []
        for item in tools_response:
            if isinstance(item, tuple) and len(item) >= 2 and item[0] == "tools":
                tool_infos.extend(item[1])
        return tool_infos

    async def discover_tools(
        self, connection_manager: "MCPConnectionManager", server_name: str
    ) -> List["Tool"]:
        """
        Discover tools from a server and register them.

        Args:
            connection_manager: Provides the session through ``get_session``.
            server_name: Name of the server to query.

        Returns:
            The tools discovered and registered. An empty list is returned when
            the server has no session or when any error occurs; errors are
            logged, not raised. Tools registered before a mid-way failure stay
            in the registry even though the returned list is empty.
        """
        discovered_tools: List["Tool"] = []
        try:
            session = await connection_manager.get_session(server_name)
            if not session:
                self.logger.error(f"Server {server_name} not initialized")
                return []
            tools_response = await session.list_tools()
            for tool_info in self._extract_tool_infos(tools_response):
                tool = Tool(
                    name=tool_info.name,
                    description=tool_info.description,
                    input_schema=tool_info.inputSchema,
                    server_name=server_name,
                )
                self.register_tool(tool)
                discovered_tools.append(tool)
            self.logger.info(
                f"Discovered {len(discovered_tools)} tools from server {server_name}"
            )
            return discovered_tools
        except Exception as e:
            # Best-effort discovery: log with traceback and report no tools.
            self.logger.error(
                f"Error discovering tools from server {server_name}: {e}",
                exc_info=True,
            )
            return []