"""中间件 - 处理请求头参数"""
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from .workspace_manager import workspace_manager
from .logger import get_logger
logger = get_logger(__name__)
class WorkspaceHeaderMiddleware(BaseHTTPMiddleware):
"""从请求头中读取工作目录参数的中间件"""
async def dispatch(self, request: Request, call_next):
"""处理请求"""
# 记录请求信息
logger.debug(f"收到请求: {request.method} {request.url.path}")
# 从 headers 中读取工作目录参数
work_dir_header = request.headers.get(
'X-Work-Dir') or request.headers.get('WORK_DIR')
if work_dir_header:
logger.info(f"检测到工作目录 header: {work_dir_header}")
# 设置工作目录
result = workspace_manager.set_work_dir(work_dir_header)
if result["success"]:
logger.info(f"从 header 设置工作目录成功: {work_dir_header}")
logger.debug(f"工作目录详情: {result.get('batch_dir', '')}")
else:
logger.warning(f"从 header 设置工作目录失败: {result['message']}")
# 读取其他可能的配置参数
batch_task_dir = request.headers.get(
'X-Batch-Task-Dir') or request.headers.get('BATCH_TASK_DIR')
if batch_task_dir:
logger.info(f"检测到批处理目录名 header: {batch_task_dir}")
# 更新批处理任务目录名
workspace_manager.batch_task_dir_name = batch_task_dir
logger.info(f"批处理目录名已更新为: {batch_task_dir}")
# 处理请求
try:
response = await call_next(request)
logger.debug(
f"请求处理完成: {request.method} {request.url.path}, 状态码: {response.status_code}")
return response
except Exception as e:
logger.error(
f"请求处理失败: {request.method} {request.url.path}, 错误: {e}", exc_info=True)
raise
__all__ = ['WorkspaceHeaderMiddleware']