"""Base middleware classes."""
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List
import mcp.types as types
from ..utils.logging import get_logger
# Module-level logger for this module. Neither BaseMiddleware nor
# MiddlewareChain references it (MiddlewareChain creates its own
# ``self._logger``); presumably kept for importers/subclasses -- verify
# before removing.
logger = get_logger(__name__)
class BaseMiddleware(ABC):
    """Abstract interface implemented by every middleware component.

    A middleware exposes one async hook per request kind: tool calls,
    resource reads and prompt lookups. Every hook receives the request data
    together with ``next_handler``, the callable representing the rest of
    the chain; an implementation typically awaits it to continue processing.

    Attributes:
        name: Identifier used to look this middleware up in a chain.
        enabled: Activity flag, True on construction; MiddlewareChain skips
            middleware whose flag is False.
    """

    def __init__(self, name: str) -> None:
        """Create a middleware component that starts out enabled.

        Args:
            name: Identifier for this middleware.
        """
        self.enabled = True
        self.name = name

    @abstractmethod
    async def process_tool_call(
        self, name: str, arguments: Dict[str, Any], next_handler: Callable[[str, Dict[str, Any]], Any]
    ) -> List[types.ContentBlock]:
        """Intercept a tool invocation.

        Args:
            name: Name of the tool being called.
            arguments: Arguments supplied to the tool.
            next_handler: Remainder of the chain, called as
                ``next_handler(name, arguments)``.

        Returns:
            The content blocks produced for the tool call.
        """
        ...

    @abstractmethod
    async def process_resource_read(self, uri: Any, next_handler: Callable[[Any], Any]) -> str:
        """Intercept a resource read.

        Args:
            uri: URI of the resource, either a plain string or an AnyUrl.
            next_handler: Remainder of the chain, called as
                ``next_handler(uri)``.

        Returns:
            The textual content of the resource.
        """
        ...

    @abstractmethod
    async def process_prompt_get(
        self, name: str, arguments: Dict[str, str] | None, next_handler: Callable[[str, Dict[str, str] | None], Any]
    ) -> types.GetPromptResult:
        """Intercept a prompt lookup.

        Args:
            name: Name of the requested prompt.
            arguments: Prompt arguments, or None when none were supplied.
            next_handler: Remainder of the chain, called as
                ``next_handler(name, arguments)``.

        Returns:
            The resolved prompt.
        """
        ...
class MiddlewareChain:
    """Ordered chain of middleware wrapped around a final handler.

    Middleware run in the order they were added: the first one added is the
    outermost wrapper, so it sees each request first. Middleware whose
    ``enabled`` flag is False are skipped.

    Each ``process_*`` call snapshots the enabled middleware when it starts.
    Adding, removing, enabling or disabling middleware while a call is in
    flight (from another task, or from inside a middleware) therefore only
    affects later calls. It can never make the in-flight call skip or repeat
    a stage.
    """

    def __init__(self) -> None:
        """Initialize an empty middleware chain."""
        self._middleware: List[BaseMiddleware] = []
        self._logger = get_logger(__name__)

    def add(self, middleware: BaseMiddleware) -> None:
        """Append middleware to the end (innermost position) of the chain.

        Args:
            middleware: Middleware to add
        """
        self._middleware.append(middleware)
        self._logger.debug(f"Added middleware: {middleware.name}")

    def remove(self, middleware_name: str) -> bool:
        """Remove the first middleware with the given name.

        Args:
            middleware_name: Name of middleware to remove

        Returns:
            True if removed, False if not found
        """
        for index, middleware in enumerate(self._middleware):
            if middleware.name == middleware_name:
                # Safe to mutate: we return immediately, so iteration stops.
                del self._middleware[index]
                self._logger.debug(f"Removed middleware: {middleware_name}")
                return True
        return False

    def enable(self, middleware_name: str) -> bool:
        """Enable the first middleware with the given name.

        Args:
            middleware_name: Name of middleware to enable

        Returns:
            True if found and enabled, False otherwise
        """
        middleware = self._find(middleware_name)
        if middleware is None:
            return False
        middleware.enabled = True
        self._logger.debug(f"Enabled middleware: {middleware_name}")
        return True

    def disable(self, middleware_name: str) -> bool:
        """Disable the first middleware with the given name.

        Args:
            middleware_name: Name of middleware to disable

        Returns:
            True if found and disabled, False otherwise
        """
        middleware = self._find(middleware_name)
        if middleware is None:
            return False
        middleware.enabled = False
        self._logger.debug(f"Disabled middleware: {middleware_name}")
        return True

    async def process_tool_call(
        self, name: str, arguments: Dict[str, Any], final_handler: Callable[[str, Dict[str, Any]], Any]
    ) -> List[types.ContentBlock]:
        """Process a tool call through the middleware chain.

        Args:
            name: Tool name
            arguments: Tool arguments
            final_handler: Innermost async handler, called as
                ``final_handler(name, arguments)``

        Returns:
            Processing result
        """
        handler = self._build_chain("process_tool_call", final_handler)
        return await handler(name, arguments)

    async def process_resource_read(self, uri: Any, final_handler: Callable[[Any], Any]) -> str:
        """Process a resource read through the middleware chain.

        Args:
            uri: Resource URI (a string or AnyUrl, as accepted by
                ``BaseMiddleware.process_resource_read``)
            final_handler: Innermost async handler, called as
                ``final_handler(uri)``

        Returns:
            Resource content
        """
        handler = self._build_chain("process_resource_read", final_handler)
        return await handler(uri)

    async def process_prompt_get(
        self, name: str, arguments: Dict[str, str] | None, final_handler: Callable[[str, Dict[str, str] | None], Any]
    ) -> types.GetPromptResult:
        """Process a prompt lookup through the middleware chain.

        Args:
            name: Prompt name
            arguments: Prompt arguments, or None
            final_handler: Innermost async handler, called as
                ``final_handler(name, arguments)``

        Returns:
            Prompt result
        """
        handler = self._build_chain("process_prompt_get", final_handler)
        return await handler(name, arguments)

    def _find(self, middleware_name: str) -> BaseMiddleware | None:
        """Return the first middleware named ``middleware_name``, or None."""
        return next((m for m in self._middleware if m.name == middleware_name), None)

    def _build_chain(self, stage_name: str, final_handler: Callable[..., Any]) -> Callable[..., Any]:
        """Compose the currently enabled middleware around ``final_handler``.

        Args:
            stage_name: Name of the BaseMiddleware hook each middleware runs
                (e.g. ``"process_tool_call"``).
            final_handler: Innermost handler, reached after every middleware.

        Returns:
            An async callable taking the same positional arguments as
            ``final_handler``.
        """
        # Snapshot once, up front: building the chain lazily from the live
        # list lets a concurrent add/remove shift indices mid-call and
        # silently skip or repeat a middleware.
        active = [m for m in self._middleware if m.enabled]
        handler = final_handler
        # Wrap innermost-first so the first-added middleware ends up outermost.
        for middleware in reversed(active):
            handler = self._wrap(getattr(middleware, stage_name), handler)
        return handler

    @staticmethod
    def _wrap(stage: Callable[..., Any], next_handler: Callable[..., Any]) -> Callable[..., Any]:
        """Bind ``next_handler`` as the trailing argument of one middleware hook."""
        # Defined outside the loop in _build_chain so each closure captures its
        # own stage/next_handler (avoids late binding of loop variables).
        async def handler(*args: Any) -> Any:
            return await stage(*args, next_handler)

        return handler