# router_base.py
from abc import ABC, abstractmethod
from typing import Callable, Dict, Generic, List, Optional, TypeVar, TYPE_CHECKING
from pydantic import BaseModel, Field, ConfigDict
from mcp.server.fastmcp.tools import Tool as FastTool
from mcp_agent.agents.agent import Agent
from mcp_agent.core.context_dependent import ContextDependent
from mcp_agent.logging.logger import get_logger
if TYPE_CHECKING:
from mcp_agent.core.context import Context
# Module-level logger, named after this module.
logger = get_logger(__name__)

# Type variable for the payload carried by RouterResult: bounded to an MCP
# server name (str), an Agent, or a callable.
ResultT = TypeVar("ResultT", bound=str | Agent | Callable)
class RouterResult(BaseModel, Generic[ResultT]):
    """
    Outcome of a single Router.route request.

    Attributes:
        result: The target the router picked: an MCP server name, an Agent,
            or a function to route the input to.
        p_score: The probability score (i.e. 0->1) of the routing decision.
            Optional; may only be provided if the router is probabilistic
            (e.g. a probabilistic binary classifier). Defaults to None.
    """

    # Extra fields are accepted, and non-pydantic types (such as Agent or a
    # plain callable) are allowed as field values.
    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    result: ResultT
    p_score: float | None = None
class RouterCategory(BaseModel):
    """
    One candidate destination the router can choose between.

    Collects the information the router needs in order to decide.

    Attributes:
        name: The name of the category.
        description: A description of the category. Defaults to None.
        category: The target to route to: an MCP server name, an Agent,
            or a callable.
    """

    # Extra fields are accepted, and non-pydantic types (such as Agent or a
    # plain callable) are allowed as field values.
    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    name: str
    description: str | None = None
    category: str | Agent | Callable
class ServerRouterCategory(RouterCategory):
    """A class that represents a category of routing to an MCP server"""

    # Tools associated with the server; each instance gets its own empty list
    # by default.
    tools: List[FastTool] = Field(default_factory=list)
class AgentRouterCategory(RouterCategory):
    """A class that represents a category of routing to an agent"""

    # Server categories for the servers attached to the agent; each instance
    # gets its own empty list by default.
    servers: List[ServerRouterCategory] = Field(default_factory=list)
class Router(ABC, ContextDependent):
    """
    Routing classifies an input and directs it to one or more specialized followup tasks.
    This class helps to route an input to a specific MCP server,
    an Agent (an aggregation of MCP servers), or a function (any Callable).

    When to use this workflow:
        - This workflow allows for separation of concerns, and building more specialized prompts.

        - Routing works well for complex tasks where there are distinct categories that
        are better handled separately, and where classification can be handled accurately,
        either by an LLM or a more traditional classification model/algorithm.

    Examples where routing is useful:
        - Directing different types of customer service queries
        (general questions, refund requests, technical support)
        into different downstream processes, prompts, and tools.

        - Routing easy/common questions to smaller models like Claude 3.5 Haiku
        and hard/unusual questions to more capable models like Claude 3.5 Sonnet
        to optimize cost and speed.

    Args:
        server_names: A list of MCP server names to route the input to.
        agents: A list of agents to route the input to.
        functions: A list of functions to route the input to.
        routing_instruction: A string that tells the router how to route the input.
        context: Optional context, forwarded to the base-class initializer.
        **kwargs: Additional keyword arguments forwarded to the base-class initializer.

    Raises:
        ValueError: If no servers, agents, or functions are given, or if server
            names are given but the context has no server registry.
    """

    def __init__(
        self,
        server_names: List[str] | None = None,
        agents: List[Agent] | None = None,
        functions: List[Callable] | None = None,
        routing_instruction: str | None = None,
        context: Optional["Context"] = None,
        **kwargs,
    ):
        super().__init__(context=context, **kwargs)

        self.routing_instruction = routing_instruction
        # Normalize missing lists to empty ones so later code can iterate freely.
        self.server_names = server_names or []
        self.agents = agents or []
        self.functions = functions or []
        # Registry used by get_server_category to look up server configs by name.
        self.server_registry = self.context.server_registry

        # A dict of categories to route to, keyed by category name.
        # These are populated in the initialize method.
        self.server_categories: Dict[str, ServerRouterCategory] = {}
        self.agent_categories: Dict[str, AgentRouterCategory] = {}
        self.function_categories: Dict[str, RouterCategory] = {}
        self.categories: Dict[str, RouterCategory] = {}
        self.initialized: bool = False

        if not self.server_names and not self.agents and not self.functions:
            raise ValueError(
                "At least one of server_names, agents, or functions must be provided."
            )

        if self.server_names and not self.server_registry:
            raise ValueError(
                "server_registry must be available on the context if server_names are provided."
            )

    @abstractmethod
    async def route(
        self, request: str, top_k: int = 1
    ) -> List[RouterResult[str | Agent | Callable]]:
        """
        Route the input request to one or more MCP servers, agents, or functions.
        If no routing decision can be made, returns an empty list.

        Args:
            request: The input to route.
            top_k: The maximum number of top routing results to return. May return fewer.

        Returns:
            A list of routing results; empty if no routing decision can be made.
        """

    @abstractmethod
    async def route_to_server(
        self, request: str, top_k: int = 1
    ) -> List[RouterResult[str]]:
        """
        Route the input to one or more MCP servers.

        Args:
            request: The input to route.
            top_k: The maximum number of top routing results to return.
        """

    @abstractmethod
    async def route_to_agent(
        self, request: str, top_k: int = 1
    ) -> List[RouterResult[Agent]]:
        """
        Route the input to one or more agents.

        Args:
            request: The input to route.
            top_k: The maximum number of top routing results to return.
        """

    @abstractmethod
    async def route_to_function(
        self, request: str, top_k: int = 1
    ) -> List[RouterResult[Callable]]:
        """
        Route the input to one or more functions.

        Args:
            request: The input to route.
            top_k: The maximum number of top routing results to return.
        """

    async def initialize(self) -> None:
        """
        Initialize the router categories.

        Builds one category per configured server, agent, and function and
        stores them in the per-kind dicts and in the combined ``categories``
        dict, all keyed by category name. Does nothing if the router has
        already been initialized.
        """
        if self.initialized:
            return

        server_categories = [
            self.get_server_category(server_name) for server_name in self.server_names
        ]
        self.server_categories = {
            category.name: category for category in server_categories
        }

        agent_categories = [self.get_agent_category(agent) for agent in self.agents]
        self.agent_categories = {
            category.name: category for category in agent_categories
        }

        function_categories = [
            self.get_function_category(function) for function in self.functions
        ]
        self.function_categories = {
            category.name: category for category in function_categories
        }

        # Combined view. On a name clash the later entry wins
        # (servers, then agents, then functions).
        all_categories = server_categories + agent_categories + function_categories
        self.categories = {category.name: category for category in all_categories}
        self.initialized = True

    def get_server_category(self, server_name: str) -> ServerRouterCategory:
        """
        Build the routing category for an MCP server.

        Args:
            server_name: Name of the server to look up in the server registry.

        Returns:
            A ServerRouterCategory whose target is ``server_name``. The name and
            description come from the server's config when one is found;
            otherwise the category is named after ``server_name`` and has no
            description.
        """
        # The registry is only validated in __init__ when server_names is
        # non-empty, so it can be missing here when this is reached through an
        # agent's servers.
        server_config = (
            self.server_registry.get_server_config(server_name)
            if self.server_registry
            else None
        )

        # TODO: saqadri - Currently we only populate the server name and description.
        # To make even more high fidelity routing decisions, we can populate the
        # tools, resources and prompts that the server has access to.
        if not server_config:
            return ServerRouterCategory(category=server_name, name=server_name)

        return ServerRouterCategory(
            category=server_name,
            # Fall back to the lookup key if the config carries no name.
            name=server_config.name or server_name,
            description=server_config.description,
        )

    def get_agent_category(self, agent: Agent) -> AgentRouterCategory:
        """
        Build the routing category for an agent.

        Args:
            agent: The agent to describe.

        Returns:
            An AgentRouterCategory named after the agent, described by the
            agent's instruction, and listing a server category for each of the
            agent's server names.
        """
        # The instruction may be a callable; if so it is called with an empty
        # dict to obtain the description text.
        agent_description = (
            agent.instruction({}) if callable(agent.instruction) else agent.instruction
        )

        return AgentRouterCategory(
            category=agent,
            name=agent.name,
            description=agent_description,
            servers=[
                self.get_server_category(server_name)
                for server_name in agent.server_names
            ],
        )

    def get_function_category(self, function: Callable) -> RouterCategory:
        """
        Build the routing category for a function.

        Args:
            function: The callable to describe.

        Returns:
            A RouterCategory whose name and description are taken from the
            FastTool created from ``function``.
        """
        tool = FastTool.from_function(function)

        return RouterCategory(
            category=function,
            name=tool.name,
            description=tool.description,
        )

    def format_category(
        self, category: RouterCategory, index: int | None = None
    ) -> str:
        """
        Format a category into a readable string.

        Args:
            category: The category to format. Server and agent categories use
                their dedicated formatters; anything else is formatted as a
                function category.
            index: Optional number to prefix as ``"<index>. "``. Without it a
                single leading space is emitted.

        Returns:
            The prefix followed by the formatted category text.
        """
        index_str = f"{index}. " if index is not None else " "

        if isinstance(category, ServerRouterCategory):
            category_str = self._format_server_category(category)
        elif isinstance(category, AgentRouterCategory):
            category_str = self._format_agent_category(category)
        else:
            category_str = self._format_function_category(category)

        return f"{index_str}{category_str}"

    def _format_tools(self, tools: List[FastTool]) -> str:
        """
        Format a list of tools into a readable string.

        Returns:
            One ``"- name: description"`` line per tool, or a placeholder
            sentence when the list is empty.
        """
        if not tools:
            return "No tool information provided."

        return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)

    def _format_server_category(self, category: ServerRouterCategory) -> str:
        """Format a server category into a readable string."""
        description = category.description or "No description provided"
        tools = self._format_tools(category.tools)
        return f"Server Category: {category.name}\nDescription: {description}\nTools in server:\n{tools}"

    def _format_agent_category(self, category: AgentRouterCategory) -> str:
        """
        Format an agent category into a readable string.

        Servers without a description get the same placeholder text as the
        other formatters instead of a literal ``None``.
        """
        description = category.description or "No description provided"
        servers = "\n".join(
            f"- {server.name} ({server.description or 'No description provided'})"
            for server in category.servers
        )
        return f"Agent Category: {category.name}\nDescription: {description}\nServers in agent:\n{servers}"

    def _format_function_category(self, category: RouterCategory) -> str:
        """Format a function category into a readable string."""
        description = category.description or "No description provided"
        return f"Function Category: {category.name}\nDescription: {description}"