"""Timing middleware for performance monitoring."""
import time
from typing import Any, Callable, Dict, List
import mcp.types as types
from .base import BaseMiddleware
from ..utils.logging import get_logger
# Module-level logger for this module. TimingMiddleware itself logs through a
# child logger it creates in __init__ (``self._logger``).
logger = get_logger(__name__)
class TimingMiddleware(BaseMiddleware):
    """Middleware for timing and performance monitoring.

    Tool calls, resource reads and prompt gets passing through this middleware
    are timed with a monotonic clock, and the results are aggregated per
    operation type. A call whose duration is strictly greater than
    ``slow_threshold`` is counted as slow and, if enabled, logged as a warning.
    Durations are recorded even when the downstream handler raises.
    """

    def __init__(self, slow_threshold: float = 1.0, log_slow_requests: bool = True) -> None:
        """Initialize timing middleware.

        Args:
            slow_threshold: Duration in seconds above which a request is
                considered slow.
            log_slow_requests: Whether to log a warning for each slow request.
        """
        super().__init__("timing")
        self.slow_threshold = slow_threshold
        self.log_slow_requests = log_slow_requests
        self._logger = get_logger(f"{__name__}.TimingMiddleware")
        # Performance metrics, keyed by the ``operation_type`` strings used by
        # the process_* methods below.
        self.metrics: Dict[str, Dict[str, float]] = {
            "tool_calls": {"count": 0, "total_time": 0.0, "slow_count": 0},
            "resource_reads": {"count": 0, "total_time": 0.0, "slow_count": 0},
            "prompt_gets": {"count": 0, "total_time": 0.0, "slow_count": 0},
        }

    def _record_timing(self, operation_type: str, duration: float, name: Any) -> None:
        """Record timing metrics for a single completed (or failed) operation.

        Args:
            operation_type: Type of operation (tool_calls, resource_reads, prompt_gets).
                Unknown types are silently ignored.
            duration: Duration in seconds.
            name: Operation name (tool/prompt name, or resource URI).
        """
        data = self.metrics.get(operation_type)
        if data is None:
            return

        data["count"] += 1
        data["total_time"] += duration

        if duration > self.slow_threshold:
            data["slow_count"] += 1
            if self.log_slow_requests:
                # operation_type[:-1] turns the plural key into a singular
                # label, e.g. "tool_calls" -> "tool_call".
                self._logger.warning(
                    f"Slow {operation_type[:-1]} '{name}': {duration:.3f}s (threshold: {self.slow_threshold}s)"
                )

    def get_metrics(self) -> Dict[str, Any]:
        """Get performance metrics.

        Only operation types that have been recorded at least once are
        included, which avoids division by zero.

        Returns:
            Mapping of operation type to a dict with ``total_requests``,
            ``total_time`` (seconds), ``average_time`` (seconds),
            ``slow_requests`` and ``slow_percentage`` (0-100).
        """
        result: Dict[str, Any] = {}
        for op_type, data in self.metrics.items():
            count = data["count"]
            if count > 0:
                result[op_type] = {
                    "total_requests": count,
                    "total_time": data["total_time"],
                    "average_time": data["total_time"] / count,
                    "slow_requests": data["slow_count"],
                    "slow_percentage": (data["slow_count"] / count) * 100,
                }
        return result

    def reset_metrics(self) -> None:
        """Reset all performance metrics to zero, keeping the operation types."""
        for data in self.metrics.values():
            data.update({"count": 0, "total_time": 0.0, "slow_count": 0})
        self._logger.info("Performance metrics reset")

    async def _run_timed(self, operation_type: str, name: Any, call: Callable[[], Any]) -> Any:
        """Invoke and await ``call()``, then record its duration.

        ``call`` is invoked inside the timed region, so any synchronous work
        the handler does before returning its awaitable is measured too. The
        duration is recorded in ``finally`` so that failing calls are timed
        as well; their exception still propagates to the caller.

        ``time.perf_counter`` is used because it is monotonic. ``time.time``
        can jump when the system clock is adjusted, which yields negative or
        inflated durations.
        """
        start = time.perf_counter()
        try:
            return await call()
        finally:
            self._record_timing(operation_type, time.perf_counter() - start, name)

    async def process_tool_call(
        self, name: str, arguments: Dict[str, Any], next_handler: Callable[[str, Dict[str, Any]], Any]
    ) -> List[types.ContentBlock]:
        """Process tool call with timing.

        Args:
            name: Tool name.
            arguments: Tool arguments, forwarded unchanged.
            next_handler: Downstream async handler.

        Returns:
            Whatever ``next_handler`` returns.
        """
        return await self._run_timed("tool_calls", name, lambda: next_handler(name, arguments))

    async def process_resource_read(self, uri: Any, next_handler: Callable[[Any], Any]) -> str:
        """Process resource read with timing.

        Args:
            uri: Resource URI, forwarded unchanged and used as the metric name.
            next_handler: Downstream async handler.

        Returns:
            Whatever ``next_handler`` returns.
        """
        return await self._run_timed("resource_reads", uri, lambda: next_handler(uri))

    async def process_prompt_get(
        self, name: str, arguments: Dict[str, str] | None, next_handler: Callable[[str, Dict[str, str] | None], Any]
    ) -> types.GetPromptResult:
        """Process prompt get with timing.

        Args:
            name: Prompt name.
            arguments: Prompt arguments (may be None), forwarded unchanged.
            next_handler: Downstream async handler.

        Returns:
            Whatever ``next_handler`` returns.
        """
        return await self._run_timed("prompt_gets", name, lambda: next_handler(name, arguments))