"""
日志中间件
提供结构化日志记录、性能监控和审计功能。
"""
import time
import logging
# Standard-library logger named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
import traceback
from typing import Optional, Dict, Any
from datetime import datetime
from contextlib import asynccontextmanager
import structlog
from functools import wraps
class LoggingMiddleware:
    """Structured-logging middleware for tool calls.

    Wraps a structlog logger named ``"mcp_server"`` and keeps simple per-tool
    timing statistics (call count and total/avg/min/max duration, in seconds).
    """

    def __init__(self, log_level: str = "INFO"):
        """Configure logging and create the structlog logger.

        Args:
            log_level: Name of a stdlib logging level (e.g. ``"INFO"``);
                matched case-insensitively.

        Note:
            Construction calls ``structlog.configure`` and
            ``logging.basicConfig``, both of which change process-wide
            logging state.
        """
        self.setup_logging(log_level)
        self.logger = structlog.get_logger("mcp_server")
        # tool_name -> {"call_count", "total_duration", "avg_duration",
        #               "min_duration", "max_duration"}; durations in seconds.
        self.performance_stats = {}

    def setup_logging(self, log_level: str):
        """Configure structlog (JSON output) and the stdlib root logger.

        Args:
            log_level: Level name resolved with ``getattr(logging, ...)`` after
                upper-casing; an unknown name raises ``AttributeError``.
        """
        structlog.configure(
            processors=[
                structlog.stdlib.filter_by_level,
                structlog.stdlib.add_logger_name,
                structlog.stdlib.add_log_level,
                structlog.stdlib.PositionalArgumentsFormatter(),
                structlog.processors.TimeStamper(fmt="iso"),
                structlog.processors.StackInfoRenderer(),
                structlog.processors.format_exc_info,
                structlog.processors.UnicodeDecoder(),
                structlog.processors.JSONRenderer(),
            ],
            context_class=dict,
            logger_factory=structlog.stdlib.LoggerFactory(),
            wrapper_class=structlog.stdlib.BoundLogger,
            cache_logger_on_first_use=True,
        )
        level = getattr(logging, log_level.upper())
        # Let the logging package own the file handle so logging.shutdown()
        # closes it. delay=True postpones opening the file until the first
        # record is emitted, so nothing is opened (and leaked) when
        # basicConfig is a no-op because the root logger already has handlers.
        file_handler = logging.FileHandler(
            "mcp_server.log", mode="a", encoding="utf-8", delay=True
        )
        logging.basicConfig(
            format="%(message)s",
            handlers=[file_handler],
            level=level,
        )

    def log_request(self, tool_name: str, params: Dict[str, Any]):
        """Log the start of a tool call; parameter values are stringified."""
        self.logger.info(
            "工具调用开始",
            tool_name=tool_name,
            params={k: str(v) for k, v in params.items()},
            timestamp=datetime.now().isoformat()
        )

    def log_response(self, tool_name: str, result: Any, duration: float):
        """Log a completed tool call and fold its duration into the stats.

        Args:
            tool_name: Name the call is recorded under.
            result: Return value; only its type name is logged.
            duration: Elapsed time in seconds (logged as milliseconds).
        """
        self.logger.info(
            "工具调用完成",
            tool_name=tool_name,
            duration_ms=duration * 1000,
            result_type=type(result).__name__,
            timestamp=datetime.now().isoformat()
        )
        self._record_duration(tool_name, duration)

    def _record_duration(self, tool_name: str, duration: float):
        """Update count, total, average, min and max for *tool_name*."""
        stats = self.performance_stats.setdefault(tool_name, {
            "call_count": 0,
            "total_duration": 0,
            "avg_duration": 0,
            "min_duration": float("inf"),
            "max_duration": 0,
        })
        stats["call_count"] += 1
        stats["total_duration"] += duration
        stats["avg_duration"] = stats["total_duration"] / stats["call_count"]
        stats["min_duration"] = min(stats["min_duration"], duration)
        stats["max_duration"] = max(stats["max_duration"], duration)

    @staticmethod
    def _format_error_traceback(error: BaseException) -> str:
        """Render *error*'s own traceback (including chained exceptions).

        Unlike ``traceback.format_exc()``, this does not depend on being
        called inside the ``except`` block that caught *error*.
        """
        return "".join(
            traceback.format_exception(type(error), error, error.__traceback__)
        )

    def log_error(self, tool_name: str, error: Exception, params: Dict[str, Any]):
        """Log a failed tool call with the error's type, message and traceback."""
        self.logger.error(
            "工具调用出错",
            tool_name=tool_name,
            error_type=type(error).__name__,
            error_message=str(error),
            error_traceback=self._format_error_traceback(error),
            params={k: str(v) for k, v in params.items()},
            timestamp=datetime.now().isoformat()
        )

    def log_security_event(self, event_type: str, details: Dict[str, Any]):
        """Log a security event at warning level; *details* is logged as-is."""
        self.logger.warning(
            "安全事件",
            event_type=event_type,
            details=details,
            timestamp=datetime.now().isoformat()
        )

    def get_performance_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the per-tool statistics.

        The inner per-tool dicts are copied as well, so mutating the result
        does not affect the middleware's internal counters.
        """
        return {name: dict(stats) for name, stats in self.performance_stats.items()}

    @asynccontextmanager
    async def performance_monitor(self, tool_name: str):
        """Async context manager that times its body and records the duration.

        NOTE: the duration is recorded in ``finally`` via ``log_response``
        (with ``result=None``) even when the body raises, so failed runs are
        logged as completed and counted in the statistics.
        """
        # perf_counter is monotonic; wall-clock time.time() can jump.
        start_time = time.perf_counter()
        try:
            yield
        finally:
            self.log_response(tool_name, None, time.perf_counter() - start_time)
def log_performance(tool_name: Optional[str] = None):
    """Decorator factory that logs the duration and errors of tool calls.

    Works for both plain and ``async def`` functions. A successful call is
    reported through ``logging_middleware.log_response``. A call that raises
    any ``Exception`` is reported through ``logging_middleware.log_error``
    (keyword arguments only) and the exception is re-raised unchanged.

    Both ``@log_performance("name")`` / ``@log_performance()`` and the bare
    ``@log_performance`` form are supported.

    Args:
        tool_name: Name recorded in the logs; defaults to the decorated
            function's ``__name__``. When the decorator is used without
            parentheses this argument receives the function itself.
    """
    if callable(tool_name):
        # Bare ``@log_performance`` usage: the function was passed directly.
        return log_performance()(tool_name)

    def decorator(func):
        import inspect

        # Resolve the name per decorated function, so one decorator object
        # applied to several functions does not reuse the first one's name.
        name = func.__name__ if tool_name is None else tool_name

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            start_time = time.perf_counter()
            try:
                result = await func(*args, **kwargs)
            except Exception as exc:
                # Log every failure, not only RuntimeError/ValueError.
                logging_middleware.log_error(name, exc, kwargs)
                raise
            else:
                duration = time.perf_counter() - start_time
                logging_middleware.log_response(name, result, duration)
                return result

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            start_time = time.perf_counter()
            try:
                result = func(*args, **kwargs)
            except Exception as exc:
                logging_middleware.log_error(name, exc, kwargs)
                raise
            else:
                duration = time.perf_counter() - start_time
                logging_middleware.log_response(name, result, duration)
                return result

        # Pick the wrapper matching the decorated function's calling style.
        if inspect.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper

    return decorator
# Shared, module-wide middleware instance. Creating it runs setup_logging,
# which configures structlog and the stdlib root logger.
logging_middleware = LoggingMiddleware(log_level="INFO")