mcp-eunomia
- src
- orchestra_server
import asyncio
import logging
from contextlib import AsyncExitStack
import mcp
import mcp.types as types
from eunomia import *
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.server import NotificationOptions, Server
from mcp.server.models import InitializationOptions
from .config import Settings
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
settings = Settings()
server = Server(settings.APP_NAME)
eunomia_orchestra = settings.ORCHESTRA
SERVER_TOOLS_SEP = "___"
SERVER_PROMPTS_SEP = "___"
SERVER_RESOURCES_SEP = "___"
servers_sessions = {}
#
# -----------------------
# TOOLS IMPLEMENTATION
# -----------------------
#
@server.list_tools()
async def list_tools() -> list[types.Tool]:
server_tools = []
for server_name, session in servers_sessions.items():
response = await session.list_tools()
# Rename and flatten
for tool in response.tools:
renamed_tool = types.Tool(
name=f"{server_name}{SERVER_TOOLS_SEP}{tool.name}",
description=tool.description,
inputSchema=tool.inputSchema,
)
server_tools.append(renamed_tool)
return server_tools
@server.call_tool()
async def call_tool(
name: str, arguments: dict
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
server_name, tool_name = name.split(SERVER_TOOLS_SEP, 1)
try:
if server_name not in servers_sessions:
raise ValueError(f"No session found for sub-server: {server_name}")
session = servers_sessions[server_name]
logging.debug(
f"Calling tool: {tool_name} on server: {server_name} with arguments: {arguments}"
)
result = await asyncio.wait_for(
session.call_tool(tool_name, arguments), timeout=10
)
except asyncio.TimeoutError:
logging.error(f"Timeout while calling tool '{tool_name}'")
raise
except Exception as e:
logging.exception(f"Error during call_tool for '{tool_name}'")
raise
# Optionally run through Eunomia
logging.debug("Running eunomia orchestra...")
full_content_post_eunomia = []
try:
for content in result.content:
if content.type == "text":
text_post_eunomia = eunomia_orchestra.run(content.text)
full_content_post_eunomia.append(text_post_eunomia)
else:
# For non-text content, just pass it through
full_content_post_eunomia.append(content)
except Exception as e:
logging.exception(f"Error running eunomia_orchestra: {e}")
# Return them as text content for simplicity
# or adapt as needed to preserve images/embeds.
return [types.TextContent(type="text", text=str(full_content_post_eunomia))]
#
# -----------------------
# PROMPTS IMPLEMENTATION
# -----------------------
#
@server.list_prompts()
async def list_prompts() -> list[types.Prompt]:
"""
Aggregate and rename prompts from each sub-server.
"""
aggregated_prompts = []
for server_name, session in servers_sessions.items():
try:
response = await session.list_prompts()
# response.prompts is typically the list of prompts
for prompt in response.prompts:
# Optionally rename prompt to avoid collisions
renamed_prompt = types.Prompt(
name=f"{server_name}{SERVER_PROMPTS_SEP}{prompt.name}",
description=prompt.description,
arguments=prompt.arguments,
)
aggregated_prompts.append(renamed_prompt)
except Exception as e:
logger.exception(
f"Failed to list_prompts from sub-server {server_name}: {e}"
)
return aggregated_prompts
@server.get_prompt()
async def get_prompt(
name: str, arguments: dict[str, str] | None = None
) -> types.GetPromptResult:
"""
Route the get_prompt call to the correct sub-server,
and return the prompt's messages.
"""
# Extract sub-server name and actual prompt name
server_name, prompt_name = name.split(SERVER_PROMPTS_SEP, 1)
if server_name not in servers_sessions:
raise ValueError(f"No session found for sub-server: {server_name}")
session = servers_sessions[server_name]
try:
result = await session.get_prompt(prompt_name, arguments)
return result
except Exception as e:
logger.exception(
f"Failed to get_prompt '{prompt_name}' from sub-server {server_name}: {e}"
)
raise
#
# -----------------------
# RESOURCES IMPLEMENTATION
# -----------------------
#
@server.list_resources()
async def list_resources() -> list[types.Resource]:
"""
Aggregate and rename resources from each sub-server.
"""
aggregated_resources = []
for server_name, session in servers_sessions.items():
try:
response = await session.list_resources()
for resource in response.resources:
new_uri = f"{server_name}{SERVER_RESOURCES_SEP}{resource.uri}"
renamed_resource = types.Resource(
uri=new_uri,
name=f"{server_name}{SERVER_RESOURCES_SEP}{resource.name}",
mimeType=resource.mimeType,
)
aggregated_resources.append(renamed_resource)
except Exception as e:
logger.exception(
f"Failed to list_resources from sub-server {server_name}: {e}"
)
return aggregated_resources
@server.read_resource()
async def read_resource(uri: types.AnyUrl) -> str:
"""
Extract the sub-server name from the prefixed URI,
then call the read_resource on that sub-server.
"""
uri_str = str(uri)
if SERVER_RESOURCES_SEP not in uri_str:
raise ValueError("Invalid resource URI format (missing sub-server prefix).")
server_name, real_uri = uri_str.split(SERVER_RESOURCES_SEP, 1)
if server_name not in servers_sessions:
raise ValueError(f"No session found for sub-server: {server_name}")
session = servers_sessions[server_name]
try:
result = await session.read_resource(real_uri)
return result.content
except Exception as e:
logger.exception(
f"Failed to read_resource '{real_uri}' from sub-server {server_name}: {e}"
)
raise
#
# -----------------------
# SERVER MAIN
# -----------------------
#
async def main():
"""
Main entry point for the aggregator MCP server.
Sets up and runs the server using stdin/stdout streams.
"""
async with AsyncExitStack() as stack:
await initialize_sub_servers(stack)
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
await server.run(
read_stream,
write_stream,
InitializationOptions(
server_name=settings.APP_NAME,
server_version=settings.APP_VERSION,
capabilities=server.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={},
),
),
)
async def initialize_sub_servers(stack: AsyncExitStack):
"""
Create and initialize sessions for each sub-server,
storing them in the global servers_sessions dictionary.
"""
for server_name, params in settings.MCP_SERVERS.items():
command = params.get("command")
args = params.get("args", [])
env = params.get("env")
server_params = StdioServerParameters(command=command, args=args, env=env)
stdio_transport = await stack.enter_async_context(stdio_client(server_params))
stdio, write = stdio_transport
session = await stack.enter_async_context(ClientSession(stdio, write))
await session.initialize()
servers_sessions[server_name] = session