"""路由配置模块 - 统一管理所有 API 路由"""
from starlette.routing import Route, Mount
from starlette.staticfiles import StaticFiles
from starlette.responses import JSONResponse, RedirectResponse
from starlette.requests import Request
from pathlib import Path
from config import (
STATIC_DIR,
PROMPT_FILE_EXTENSIONS,
SUPPORTED_ENCODINGS,
)
from src.novel_processor import NovelProcessor
from src.workspace_manager import workspace_manager
from src.logger import get_logger
# Module-level logger used by the request handlers in this file.
logger = get_logger(__name__)
# ==================== API 路由处理器 ====================
def _decode_novel_bytes(file_bytes):
    """Decode uploaded bytes with the first workable encoding in SUPPORTED_ENCODINGS.

    Returns:
        ``(text, encoding)`` on success, ``(None, None)`` when no configured
        encoding can decode the data.
    """
    for encoding in SUPPORTED_ENCODINGS:
        try:
            return file_bytes.decode(encoding), encoding
        except (UnicodeDecodeError, LookupError):
            # Wrong guess, or an unknown codec name in config: try the next one.
            continue
    return None, None


async def api_split_novel(request: Request):
    """API: split a novel into paragraph files.

    The request form must contain either:

    * ``file`` -- an uploaded text file, decoded with the first encoding in
      ``SUPPORTED_ENCODINGS`` that works, or
    * ``content`` -- text entered directly in the form.

    Optional form fields:

    * ``max_chars`` -- positive maximum paragraph length (default 500; an empty
      value also means 500).
    * ``project_name`` -- target project; an absent or empty value falls back to
      the uploaded file's stem, or ``"untitled"`` for direct input. It must be a
      bare directory name.

    The paragraphs are written to ``<output_dir>/<project_name>/source`` through
    ``NovelProcessor.save_split_files``.

    Returns:
        JSONResponse with the split result plus ``separate_files_saved``,
        ``output_dir``, ``project_name``, ``success`` and ``encoding``.
        Returns ``400`` when input is missing, undecodable or invalid, and
        ``500`` on unexpected errors.
    """
    direct_input = "直接输入"  # source_file marker for text typed into the form
    form = None
    try:
        form = await request.form()
        if "file" in form:
            # Upload mode.
            file = form["file"]
            file_bytes = await file.read()
            logger.info(f"接收文件上传: {file.filename}, 大小: {len(file_bytes)} 字节")
            content, used_encoding = _decode_novel_bytes(file_bytes)
            if content is None:
                logger.error(f"文件编码识别失败: {file.filename}")
                return JSONResponse({
                    "success": False,
                    "error": "无法识别文件编码,请使用 UTF-8、GBK 或 GB2312 编码保存文件"
                }, status_code=400)
            logger.info(f"文件编码识别成功: {used_encoding}")
            # UploadFile.filename may be None/empty; Path(None) would raise.
            source_file = file.filename or "untitled"
        elif "content" in form:
            # Direct-input mode.
            content = form["content"]
            source_file = direct_input
            used_encoding = "utf-8"
            logger.info(f"接收直接输入内容, 长度: {len(content)} 字符")
        else:
            logger.warning("分割请求缺少文件或内容")
            return JSONResponse({
                "success": False,
                "error": "缺少文件或内容"
            }, status_code=400)

        # A malformed max_chars is a client error (400), not a server crash (500).
        try:
            max_chars = int(form.get("max_chars") or 500)
        except (TypeError, ValueError):
            max_chars = 0
        if max_chars <= 0:
            logger.warning(f"分割请求参数无效: max_chars={form.get('max_chars')}")
            return JSONResponse({
                "success": False,
                "error": "max_chars 必须是正整数"
            }, status_code=400)

        is_direct = source_file == direct_input
        default_project = "untitled" if is_direct else (Path(source_file).stem or "untitled")
        # An empty field falls back too; previously it wrote to <output_dir>/source.
        project_name = form.get("project_name") or default_project
        # project_name becomes a directory under the output dir: reject separators,
        # absolute paths and "."/".." so a request cannot write outside it.
        if Path(project_name).name != project_name or project_name in (".", ".."):
            logger.warning(f"分割请求项目名称无效: {project_name}")
            return JSONResponse({
                "success": False,
                "error": "项目名称无效"
            }, status_code=400)
        logger.info(f"开始分割小说: 项目={project_name}, 最大字符={max_chars}")

        result = NovelProcessor.split_novel_with_metadata(
            content=content,
            max_chars=max_chars,
            source_file=source_file
        )
        # Save straight into the project's "source" directory.
        project_output_dir = workspace_manager.get_output_dir() / project_name / "source"
        base_name = "text" if is_direct else Path(source_file).stem
        count = NovelProcessor.save_split_files(
            paragraphs=result["paragraphs"],
            output_dir=project_output_dir,
            base_name=base_name
        )
        result["separate_files_saved"] = count
        result["output_dir"] = str(project_output_dir)
        result["project_name"] = project_name
        result["success"] = True
        result["encoding"] = used_encoding  # report which encoding was used
        logger.info(
            f"分割完成: 项目={project_name}, 段落数={count}, 输出目录={project_output_dir}")
        return JSONResponse(result)
    except Exception as e:
        logger.exception(f"分割小说异常: {str(e)}")
        return JSONResponse({
            "success": False,
            "error": str(e)
        }, status_code=500)
    finally:
        # Release the spooled temp files that back uploaded form fields.
        if form is not None:
            await form.close()
async def api_get_prompt(request: Request):
    """API: return the text of one prompt file.

    Query parameters:
        file_name: bare name of a file in the prompts directory. Defaults to
            the name of ``workspace_manager.get_default_prompt_file()``.

    The file is decoded as UTF-8 first, then with each entry of
    ``SUPPORTED_ENCODINGS``. The other prompt endpoints also accept files in
    those encodings.

    Returns:
        JSON ``{"success", "fileName", "content"}``. Returns ``400`` for an
        invalid name, ``404`` if the file does not exist, and ``500`` if the
        file cannot be decoded or another error occurs.
    """
    try:
        default_prompt_file = workspace_manager.get_default_prompt_file()
        file_name = request.query_params.get(
            "file_name", default_prompt_file.name)
        logger.info(f"获取提示词: {file_name}")
        # Only a bare file name is allowed. "../x" would escape the prompts dir,
        # and an absolute path would replace it entirely when joined.
        if not file_name or Path(file_name).name != file_name or file_name in (".", ".."):
            logger.warning(f"提示词文件名无效: {file_name}")
            return JSONResponse({
                "success": False,
                "error": "提示词文件名无效"
            }, status_code=400)
        prompt_file = workspace_manager.get_prompts_dir() / file_name
        if not prompt_file.is_file():
            logger.warning(f"提示词文件不存在: {file_name}")
            return JSONResponse({
                "success": False,
                "error": "提示词文件不存在"
            }, status_code=404)
        content = None
        # UTF-8 first (what api_save_prompt writes), then the configured fallbacks.
        for encoding in ("utf-8", *SUPPORTED_ENCODINGS):
            try:
                content = prompt_file.read_text(encoding=encoding)
                break
            except (UnicodeDecodeError, LookupError):
                continue
        if content is None:
            logger.error(f"提示词文件编码识别失败: {file_name}")
            return JSONResponse({
                "success": False,
                "error": "无法识别提示词文件编码"
            }, status_code=500)
        logger.info(f"提示词读取成功: {file_name}, 长度: {len(content)} 字符")
        return JSONResponse({
            "success": True,
            "fileName": file_name,
            "content": content
        })
    except Exception as e:
        logger.exception(f"获取提示词异常: {str(e)}")
        return JSONResponse({
            "success": False,
            "error": str(e)
        }, status_code=500)
async def api_save_prompt(request: Request):
    """API: create or overwrite a prompt file.

    JSON body:
        fileName: bare name of a file in the prompts directory. Defaults to the
            name of ``workspace_manager.get_default_prompt_file()``.
        content: new, non-empty text, written as UTF-8.

    Before writing, the handler checks whether the target file currently has
    the same text as ``DEFAULT_PROMPT_FILE_NAME``. If so, the default file is
    then overwritten with the saved file so the two stay in sync.

    Returns:
        JSON ``{"success", "message"}``. Returns ``400`` for empty content or an
        invalid file name, and ``500`` on unexpected errors.
    """
    try:
        import shutil
        from config import DEFAULT_PROMPT_FILE_NAME
        data = await request.json()
        default_prompt_file = workspace_manager.get_default_prompt_file()
        file_name = data.get("fileName", default_prompt_file.name)
        content = data.get("content", "")
        logger.info(f"保存提示词: {file_name}, 长度: {len(content)} 字符")
        if not content:
            logger.warning(f"保存提示词失败: 内容为空, 文件: {file_name}")
            return JSONResponse({
                "success": False,
                "error": "内容不能为空"
            }, status_code=400)
        # Only a bare file name is allowed, so the write stays inside prompts_dir.
        if not file_name or Path(file_name).name != file_name or file_name in (".", ".."):
            logger.warning(f"保存提示词失败: 文件名无效 {file_name}")
            return JSONResponse({
                "success": False,
                "error": "文件名无效"
            }, status_code=400)
        prompts_dir = workspace_manager.get_prompts_dir()
        prompt_file = prompts_dir / file_name
        default_file = prompts_dir / DEFAULT_PROMPT_FILE_NAME

        # Before writing, check whether this file is the current default prompt.
        # When the target *is* the default file there is nothing to sync, and
        # copy2(x, x) would raise SameFileError after the write had succeeded.
        is_current_default = False
        if prompt_file != default_file and prompt_file.exists() and default_file.exists():
            for encoding in SUPPORTED_ENCODINGS:
                try:
                    existing_content = prompt_file.read_text(encoding=encoding)
                    default_content = default_file.read_text(encoding=encoding)
                except (UnicodeDecodeError, LookupError):
                    continue
                except OSError as read_error:
                    # Best effort: without the comparison we just skip the sync.
                    logger.warning(f"比较默认提示词失败: {read_error}")
                    break
                if existing_content == default_content:
                    is_current_default = True
                    break

        prompt_file.write_text(content, encoding='utf-8')
        logger.info(f"提示词已保存: {prompt_file}")

        if is_current_default:
            # Keep the default file identical to the prompt it mirrors.
            shutil.copy2(prompt_file, default_file)
            return JSONResponse({
                "success": True,
                "message": f"提示词已保存到: {prompt_file.name},并已同步更新默认提示词"
            })
        return JSONResponse({
            "success": True,
            "message": f"提示词已保存到: {prompt_file.name}"
        })
    except Exception as e:
        logger.exception(f"保存提示词异常: {str(e)}")
        return JSONResponse({
            "success": False,
            "error": str(e)
        }, status_code=500)
async def api_list_prompts(request: Request):
    """API: list the prompt files in the prompts directory.

    Query parameters:
        include_default: ``"true"`` (case-insensitive) to include the
            ``DEFAULT_PROMPT_FILE_NAME`` file itself. Any other value hides it.

    Files are gathered with every glob pattern in ``PROMPT_FILE_EXTENSIONS``,
    deduplicated and sorted by name. When the default file is hidden, the first
    listed file whose decoded text equals the default file's text gets
    ``is_default=True`` instead.

    Returns:
        JSON ``{"success", "files", "default_prompt", "has_default_file"}``.
        Each file entry has ``name``, ``size`` (bytes) and ``is_default``.
        Returns ``500`` on unexpected errors.
    """
    try:
        from config import DEFAULT_PROMPT_FILE_NAME

        def read_text_any(path):
            """Decode *path* with the first SUPPORTED_ENCODINGS entry that works; None if none does."""
            for encoding in SUPPORTED_ENCODINGS:
                try:
                    return path.read_text(encoding=encoding)
                except (UnicodeDecodeError, LookupError):
                    continue
            return None

        include_default = request.query_params.get(
            "include_default", "false").lower() == "true"
        prompts_dir = workspace_manager.get_prompts_dir()
        default_prompt_file = prompts_dir / DEFAULT_PROMPT_FILE_NAME

        # Text of the default file, used below to find which prompt it mirrors.
        default_content = None
        if default_prompt_file.exists():
            try:
                default_content = read_text_any(default_prompt_file)
            except OSError as read_error:
                logger.warning(f"读取默认提示词失败: {read_error}")

        files = []
        seen_names = set()
        for pattern in PROMPT_FILE_EXTENSIONS:
            for file_path in prompts_dir.glob(pattern):
                # Overlapping patterns can yield one file twice, and a directory
                # can match a pattern. Neither is a separate prompt.
                if file_path.name in seen_names or not file_path.is_file():
                    continue
                if not include_default and file_path.name == DEFAULT_PROMPT_FILE_NAME:
                    continue
                seen_names.add(file_path.name)
                files.append({
                    "name": file_path.name,
                    "size": file_path.stat().st_size,
                    "path": file_path,  # temporary; removed before responding
                    "is_default": file_path.name == DEFAULT_PROMPT_FILE_NAME
                })
        files.sort(key=lambda item: item["name"])

        # With the default file hidden, flag the first prompt whose text matches it.
        if not include_default and default_content is not None:
            for file_info in files:
                try:
                    file_content = read_text_any(file_info["path"])
                except OSError as read_error:
                    logger.warning(f"读取提示词失败: {file_info['name']}, {read_error}")
                    continue
                if file_content == default_content:
                    file_info["is_default"] = True
                    break

        for file_info in files:
            file_info.pop("path", None)
        return JSONResponse({
            "success": True,
            "files": files,
            "default_prompt": DEFAULT_PROMPT_FILE_NAME,
            "has_default_file": default_prompt_file.exists()
        })
    except Exception as e:
        logger.exception(f"列出提示词异常: {str(e)}")
        return JSONResponse({
            "success": False,
            "error": str(e)
        }, status_code=500)
async def redirect_root(request: Request):
    """Send visitors of the site root to the static front-end entry page."""
    landing_page = "/static/index.html"
    return RedirectResponse(url=landing_page)
async def api_init_tasks(request: Request):
    """API: build a project's task list from its split source files.

    JSON body:
        project_name: target project (required).
        source_dir: directory of split files (required). The legacy key
            ``temp_dir`` is still accepted.
        force: passed through to ``TaskManager.init_tasks`` (default False).

    Returns the manager's result unchanged. Returns ``400`` for missing
    arguments and ``500`` on unexpected errors.
    """
    try:
        from src.task_manager import TaskManager
        payload = await request.json()
        name = payload.get("project_name")
        # "temp_dir" is the former name of "source_dir"; keep honouring it.
        src_dir = payload.get("source_dir") or payload.get("temp_dir")
        overwrite = payload.get("force", False)
        logger.info(f"初始化任务: 项目={name}, 源目录={src_dir}, 强制={overwrite}")

        if not (name and src_dir):
            logger.warning("初始化任务失败: 缺少必要参数")
            return JSONResponse(
                {"success": False, "error": "缺少项目名称或源文件目录"},
                status_code=400,
            )

        outcome = TaskManager().init_tasks(
            project_name=name,
            source_dir=Path(src_dir),
            force=overwrite,
        )
        if not outcome.get("success"):
            logger.error(f"任务初始化失败: {outcome.get('message', '未知错误')}")
        else:
            logger.info(
                f"任务初始化成功: 项目={name}, 任务数={outcome.get('count', 0)}")
        return JSONResponse(outcome)
    except Exception as exc:
        logger.exception(f"初始化任务异常: {str(exc)}")
        return JSONResponse({"success": False, "error": str(exc)}, status_code=500)
async def api_list_projects(request: Request):
    """API: return every project known to ``TaskManager.list_projects``."""
    try:
        from src.task_manager import TaskManager
        project_list = TaskManager().list_projects()
        logger.info(f"查询项目列表: 共 {len(project_list)} 个项目")
        body = {"success": True, "projects": project_list}
        return JSONResponse(body)
    except Exception as exc:
        logger.exception(f"查询项目列表异常: {str(exc)}")
        return JSONResponse({"success": False, "error": str(exc)}, status_code=500)
async def api_list_tasks(request: Request):
    """API: list a project's tasks, optionally filtered by status.

    Query parameters:
        project_name: required.
        status: filter passed to ``TaskManager.list_tasks`` (default ``"all"``).

    The manager's result is merged into ``{"success": True, ...}``.
    """
    try:
        from src.task_manager import TaskManager
        params = request.query_params
        name = params.get("project_name")
        wanted_status = params.get("status", "all")
        if not name:
            return JSONResponse({"success": False, "error": "缺少项目名称"}, status_code=400)

        listing = TaskManager().list_tasks(name, wanted_status)
        task_count = len(listing.get('tasks', []))
        logger.info(f"查询任务列表: 项目={name}, 状态={wanted_status}, 数量={task_count}")
        return JSONResponse({"success": True, **listing})
    except Exception as exc:
        logger.exception(f"查询任务列表异常: {str(exc)}")
        return JSONResponse({"success": False, "error": str(exc)}, status_code=500)
async def api_get_task_status(request: Request):
    """API: report a project's task status from ``TaskManager.get_status``.

    Query parameters:
        project_name: required.

    The manager's result is merged into ``{"success": True, ...}``.
    """
    try:
        from src.task_manager import TaskManager
        name = request.query_params.get("project_name")
        if not name:
            logger.warning("获取任务状态失败: 缺少项目名称")
            return JSONResponse({"success": False, "error": "缺少项目名称"}, status_code=400)

        status_info = TaskManager().get_status(name)
        total = status_info.get('metadata', {}).get('total', 0)
        logger.info(f"获取任务状态: 项目={name}, 总数={total}")
        return JSONResponse({"success": True, **status_info})
    except Exception as exc:
        logger.exception(f"获取任务状态异常: {str(exc)}")
        return JSONResponse({"success": False, "error": str(exc)}, status_code=500)
async def api_update_task_status(request: Request):
    """API: manually change one task's status.

    JSON body:
        project_name, task_id, status: all required.
        error_message: optional reason, used only when status is ``"failed"``.

    ``"pending"`` resets the task through ``TaskManager.reset_task``, and
    ``"failed"`` marks it through ``TaskManager.fail_task``. Any other status
    returns HTTP 400. The manager's result is returned unchanged.
    """
    try:
        from src.task_manager import TaskManager
        body = await request.json()
        name = body.get("project_name")
        tid = body.get("task_id")
        target = body.get("status")
        logger.info(
            f"更新任务状态: 项目={name}, 任务={tid}, 状态={target}")
        if not (name and tid and target):
            return JSONResponse({"success": False, "error": "缺少必要参数"}, status_code=400)

        manager = TaskManager()
        if target not in ("pending", "failed"):
            logger.warning(f"不支持的状态: {target}")
            return JSONResponse(
                {"success": False, "error": f"不支持的状态: {target}"},
                status_code=400,
            )

        if target == "pending":
            outcome = manager.reset_task(tid, name)
            logger.info(f"任务已重置: {tid}")
        else:
            reason = body.get("error_message", "手动标记为失败")
            outcome = manager.fail_task(tid, name, reason)
            logger.warning(f"任务标记为失败: {tid}, 原因: {reason}")
        return JSONResponse(outcome)
    except Exception as exc:
        logger.exception(f"更新任务状态异常: {str(exc)}")
        return JSONResponse({"success": False, "error": str(exc)}, status_code=500)
async def api_get_task_content(request: Request):
    """API: return a task's original text and its rewritten text.

    Query parameters:
        project_name, task_id: both required.

    The task is found by ``id`` in ``TaskManager.load_tasks(project_name)["tasks"]``.
    Its ``source_path`` and ``output_path`` are resolved against the workspace
    work directory and read as UTF-8. A missing file yields ``""``.
    """
    try:
        from src.task_manager import TaskManager
        params = request.query_params
        name = params.get("project_name")
        tid = params.get("task_id")
        logger.info(f"获取任务内容: 项目={name}, 任务={tid}")
        if not (name and tid):
            logger.warning("获取任务内容失败: 缺少必要参数")
            return JSONResponse({"success": False, "error": "缺少必要参数"}, status_code=400)

        all_tasks = TaskManager().load_tasks(name)["tasks"]
        task = None
        for candidate in all_tasks:
            if candidate["id"] == tid:
                task = candidate
                break
        if not task:
            logger.warning(f"任务不存在: 项目={name}, 任务={tid}")
            return JSONResponse({"success": False, "error": "任务不存在"}, status_code=404)

        base_dir = workspace_manager.get_work_dir()

        def read_or_empty(relative_path):
            # Missing files (e.g. not yet rewritten) are reported as empty text.
            full_path = base_dir / relative_path
            return full_path.read_text(encoding='utf-8') if full_path.exists() else ""

        source_content = read_or_empty(task["source_path"])
        rewrite_content = read_or_empty(task["output_path"])
        logger.info(
            f"任务内容获取成功: 项目={name}, 任务={tid}, 原文={len(source_content)}字, 改写={len(rewrite_content)}字")
        return JSONResponse({
            "success": True,
            "task": task,
            "source_content": source_content,
            "rewrite_content": rewrite_content,
        })
    except Exception as exc:
        logger.exception(f"获取任务内容异常: {str(exc)}")
        return JSONResponse({"success": False, "error": str(exc)}, status_code=500)
async def api_check_timeout(request: Request):
    """API: manually check all projects for timed-out tasks.

    Query parameters:
        timeout_minutes: non-negative integer threshold. Defaults to
            ``config.TASK_TIMEOUT_MINUTES``.

    Delegates to ``TaskManager.check_all_projects_timeout``. Its result, which
    includes ``projects_count``, ``checked_count``, ``completed_count`` and
    ``recovered_count`` as logged below, is merged into ``{"success": True, ...}``.
    NOTE(review): the counts suggest this GET endpoint changes task state.

    Returns:
        The merged JSON result. Returns ``400`` for a malformed threshold and
        ``500`` on unexpected errors.
    """
    try:
        from src.task_manager import TaskManager
        from config import TASK_TIMEOUT_MINUTES
        raw_timeout = request.query_params.get(
            "timeout_minutes", TASK_TIMEOUT_MINUTES)
        try:
            timeout_minutes = int(raw_timeout)
        except (TypeError, ValueError):
            timeout_minutes = None
        # A bad threshold is a client error: answer 400 instead of letting
        # int() raise into the generic 500 handler.
        if timeout_minutes is None or timeout_minutes < 0:
            logger.warning(f"超时检查参数无效: timeout_minutes={raw_timeout}")
            return JSONResponse({
                "success": False,
                "error": "timeout_minutes 必须是非负整数"
            }, status_code=400)
        logger.info(f"开始检查超时任务: 超时阈值={timeout_minutes}分钟")
        result = TaskManager().check_all_projects_timeout(
            timeout_minutes=timeout_minutes)
        logger.info(
            f"超时检查完成: 检查项目数={result.get('projects_count', 0)}, "
            f"超时任务数={result.get('checked_count', 0)}, "
            f"已完成={result.get('completed_count', 0)}, "
            f"已重置={result.get('recovered_count', 0)}")
        return JSONResponse({
            "success": True,
            **result
        })
    except Exception as e:
        logger.exception(f"检查超时任务异常: {str(e)}")
        return JSONResponse({
            "success": False,
            "error": str(e)
        }, status_code=500)
async def api_set_default_prompt(request: Request):
    """API: make a prompt file the default prompt.

    JSON body:
        fileName: bare name of a file in the prompts directory (required).

    The chosen file is copied (content and metadata, via ``shutil.copy2``) over
    ``DEFAULT_PROMPT_FILE_NAME`` in the same directory.

    Returns:
        JSON ``{"success", "message"}``. Returns ``400`` for a missing or
        invalid name, ``404`` if the file does not exist, and ``500`` on
        unexpected errors.
    """
    try:
        import shutil
        from config import DEFAULT_PROMPT_FILE_NAME
        data = await request.json()
        file_name = data.get("fileName")
        if not file_name:
            return JSONResponse({
                "success": False,
                "error": "缺少文件名"
            }, status_code=400)
        # Only a bare name is allowed. Otherwise "../x" or an absolute path could
        # copy an arbitrary file into the prompts directory.
        if Path(file_name).name != file_name or file_name in (".", ".."):
            logger.warning(f"设置默认提示词失败: 文件名无效 {file_name}")
            return JSONResponse({
                "success": False,
                "error": "文件名无效"
            }, status_code=400)
        prompts_dir = workspace_manager.get_prompts_dir()
        source_file = prompts_dir / file_name
        default_file = prompts_dir / DEFAULT_PROMPT_FILE_NAME
        if not source_file.is_file():
            return JSONResponse({
                "success": False,
                "error": "源文件不存在"
            }, status_code=404)
        # Copying the default file onto itself is a no-op (and copy2 would raise).
        if source_file == default_file:
            return JSONResponse({
                "success": True,
                "message": f"{file_name} 已经是默认提示词"
            })
        shutil.copy2(source_file, default_file)
        logger.info(f"默认提示词已设置为: {file_name}")
        return JSONResponse({
            "success": True,
            "message": f"已将 {file_name} 设置为默认提示词"
        })
    except Exception as e:
        logger.exception(f"设置默认提示词异常: {str(e)}")
        return JSONResponse({
            "success": False,
            "error": str(e)
        }, status_code=500)
async def api_delete_prompt(request: Request):
    """API: delete a prompt file.

    JSON body:
        fileName: bare name of a file in the prompts directory (required).

    Rules:
        * the default prompt file (``DEFAULT_PROMPT_FILE_NAME``) cannot be deleted;
        * at least one non-default prompt file must remain;
        * if the deleted file had the same text as the default file, the
          alphabetically first remaining prompt is copied over the default file.

    Returns:
        JSON ``{"success", "message"}``. Returns ``400`` for an invalid or
        forbidden request, ``404`` if the file does not exist, and ``500`` on
        unexpected errors.
    """
    try:
        import shutil
        from config import DEFAULT_PROMPT_FILE_NAME

        def read_text_any(path):
            """Decode *path* with the first SUPPORTED_ENCODINGS entry that works; None if none does."""
            for encoding in SUPPORTED_ENCODINGS:
                try:
                    return path.read_text(encoding=encoding)
                except (UnicodeDecodeError, LookupError):
                    continue
            return None

        def list_prompt_files(directory):
            """Non-default prompt files in *directory*, deduplicated, sorted by name."""
            # Keyed by name so a file matched by several patterns counts once.
            found = {}
            for pattern in PROMPT_FILE_EXTENSIONS:
                for path in directory.glob(pattern):
                    if path.name != DEFAULT_PROMPT_FILE_NAME:
                        found[path.name] = path
            return [found[name] for name in sorted(found)]

        data = await request.json()
        file_name = data.get("fileName")
        if not file_name:
            return JSONResponse({
                "success": False,
                "error": "缺少文件名"
            }, status_code=400)
        # Only a bare name is allowed. Otherwise "../x" or an absolute path could
        # delete a file anywhere the server can write.
        if Path(file_name).name != file_name or file_name in (".", ".."):
            logger.warning(f"删除提示词失败: 文件名无效 {file_name}")
            return JSONResponse({
                "success": False,
                "error": "文件名无效"
            }, status_code=400)
        prompts_dir = workspace_manager.get_prompts_dir()
        file_to_delete = prompts_dir / file_name
        default_file = prompts_dir / DEFAULT_PROMPT_FILE_NAME
        if not file_to_delete.is_file():
            return JSONResponse({
                "success": False,
                "error": "文件不存在"
            }, status_code=404)
        if file_to_delete == default_file:
            return JSONResponse({
                "success": False,
                "error": "不能删除默认提示词文件"
            }, status_code=400)
        if len(list_prompt_files(prompts_dir)) <= 1:
            return JSONResponse({
                "success": False,
                "error": "至少需要保留一个提示词文件"
            }, status_code=400)

        # A prompt is the "current default" when its text equals the default file's.
        default_content = None
        if default_file.exists():
            try:
                default_content = read_text_any(default_file)
            except OSError as read_error:
                logger.warning(f"读取默认提示词失败: {read_error}")
        file_to_delete_content = None
        try:
            file_to_delete_content = read_text_any(file_to_delete)
        except OSError as read_error:
            logger.warning(f"读取待删除提示词失败: {file_name}, {read_error}")
        is_current_default = (default_content is not None and
                              default_content == file_to_delete_content)

        file_to_delete.unlink()
        logger.info(f"已删除提示词: {file_name}")

        if is_current_default:
            # The default mirrored the deleted file: promote the first remaining one.
            remaining_files = list_prompt_files(prompts_dir)
            if remaining_files:
                new_default = remaining_files[0]
                shutil.copy2(new_default, default_file)
                logger.info(f"默认提示词已切换为: {new_default.name}")
                return JSONResponse({
                    "success": True,
                    "message": f"已删除 {file_name},并将 {new_default.name} 设置为新的默认提示词"
                })
        return JSONResponse({
            "success": True,
            "message": f"已删除 {file_name}"
        })
    except Exception as e:
        logger.exception(f"删除提示词异常: {str(e)}")
        return JSONResponse({
            "success": False,
            "error": str(e)
        }, status_code=500)
async def api_update_project_prompt(request: Request):
    """API: change which prompt file a project uses.

    JSON body:
        project_name, prompt_file: both required.

    Delegates to ``TaskManager.update_project_prompt``. A result whose
    ``success`` is falsy is returned with HTTP 404; otherwise with 200.
    """
    try:
        from src.task_manager import TaskManager
        body = await request.json()
        name = body.get("project_name")
        prompt_name = body.get("prompt_file")
        logger.info(f"更新项目提示词: 项目={name}, 提示词={prompt_name}")
        if not (name and prompt_name):
            logger.warning("更新项目提示词失败: 缺少必要参数")
            return JSONResponse({"success": False, "error": "缺少必要参数"}, status_code=400)

        outcome = TaskManager().update_project_prompt(name, prompt_name)
        if not outcome["success"]:
            logger.warning(f"项目提示词更新失败: {outcome.get('message', '未知错误')}")
            return JSONResponse(outcome, status_code=404)
        logger.info(f"项目提示词更新成功: 项目={name}, 提示词={prompt_name}")
        return JSONResponse(outcome)
    except Exception as exc:
        logger.exception(f"更新项目提示词异常: {str(exc)}")
        return JSONResponse({"success": False, "error": str(exc)}, status_code=500)
async def api_export_combined(request: Request):
    """API: download a project's rewritten chunks merged into one text file.

    Query parameters:
        project_name: bare project name (required). Its
            ``<output_dir>/<project_name>/rewrite`` directory is merged by
            ``NovelProcessor.combine_novels``.

    Before sending, dialogue markup ``[[Name]]"line`` is rewritten to
    ``"<<Name>>line``.

    Returns:
        A UTF-8 ``text/plain`` attachment named ``<project>_合并_<timestamp>.txt``.
        Errors are JSON: ``400`` for a missing or invalid name or a failed
        merge, ``404`` when the rewrite directory is absent, and ``500``
        otherwise.
    """
    try:
        import re
        from datetime import datetime
        from urllib.parse import quote
        from starlette.responses import Response
        project_name = request.query_params.get("project_name")
        logger.info(f"收到导出请求,项目: {project_name}")
        if not project_name:
            logger.warning("导出失败: 缺少项目名称")
            return JSONResponse({
                "success": False,
                "error": "缺少项目名称"
            }, status_code=400)
        # project_name is joined onto the output dir. Only a bare name is allowed,
        # otherwise "../.." could export text files from elsewhere on disk.
        if Path(project_name).name != project_name or project_name in (".", ".."):
            logger.warning(f"导出失败: 项目名称无效 {project_name}")
            return JSONResponse({
                "success": False,
                "error": "项目名称无效"
            }, status_code=400)
        project_output_dir = workspace_manager.get_output_dir() / project_name / "rewrite"
        logger.info(f"导出目录: {project_output_dir}")
        if not project_output_dir.exists():
            logger.error(f"导出失败: 目录不存在 {project_output_dir}")
            return JSONResponse({
                "success": False,
                "error": f"项目 {project_name} 的改写目录不存在"
            }, status_code=404)
        logger.info("开始合并文件...")
        result = NovelProcessor.combine_novels(project_output_dir)
        if not result["success"]:
            logger.error(f"合并失败: {result.get('error', '未知错误')}")
            return JSONResponse({
                "success": False,
                "error": result.get("error", "合并失败")
            }, status_code=400)
        # Stats are only logged; missing keys must not fail an otherwise good export.
        logger.info(
            f"合并成功: 找到 {result.get('filesFound', 0)} 个文件, 总字数 {result.get('totalWords', 0)}")
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{project_name}_合并_{timestamp}.txt"
        # RFC 5987 percent-encoding so non-ASCII names survive the latin-1 header.
        encoded_filename = quote(filename)
        # Convert [[Name]]"dialogue into "<<Name>>dialogue.
        content = re.sub(r'\[\[([^\]]+)\]\]"', r'"<<\1>>', result["content"])
        logger.info(f"开始发送文件: {filename} ({len(content)} 字符)")
        return Response(
            content=content.encode('utf-8'),
            media_type='text/plain; charset=utf-8',
            headers={
                'Content-Disposition': f'attachment; filename*=UTF-8\'\'{encoded_filename}',
                'Content-Type': 'text/plain; charset=utf-8'
            }
        )
    except Exception as e:
        logger.exception(f"导出异常: {str(e)}")
        return JSONResponse({
            "success": False,
            "error": str(e)
        }, status_code=500)
# ==================== 路由配置 ====================
def get_api_routes():
    """Build the JSON API routes, one ``Route`` per (path, handler, method) entry."""
    route_table = (
        ("/api/split", api_split_novel, "POST"),
        ("/api/init_tasks", api_init_tasks, "POST"),
        ("/api/projects", api_list_projects, "GET"),
        ("/api/tasks", api_list_tasks, "GET"),
        ("/api/tasks/status", api_get_task_status, "GET"),
        ("/api/tasks/update", api_update_task_status, "POST"),
        ("/api/tasks/content", api_get_task_content, "GET"),
        ("/api/tasks/check_timeout", api_check_timeout, "GET"),
        ("/api/projects/update_prompt", api_update_project_prompt, "POST"),
        ("/api/projects/export", api_export_combined, "GET"),
        ("/api/prompts", api_list_prompts, "GET"),
        ("/api/prompts/get", api_get_prompt, "GET"),
        ("/api/prompts/save", api_save_prompt, "POST"),
        ("/api/prompts/set_default", api_set_default_prompt, "POST"),
        ("/api/prompts/delete", api_delete_prompt, "POST"),
    )
    return [
        Route(path, endpoint, methods=[verb])
        for path, endpoint, verb in route_table
    ]
def get_static_routes():
    """Return the front-end routes.

    ``/static`` is mounted only when ``STATIC_DIR`` exists. The root redirect
    to the index page is always included, after the mount.
    """
    static_mount = (
        [Mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")]
        if STATIC_DIR.exists()
        else []
    )
    return static_mount + [Route("/", redirect_root)]
def get_all_routes(mcp_app, mcp_mount_path: str = "/mcp"):
    """
    Assemble the application's complete route list.

    Args:
        mcp_app: the MCP ASGI application to mount.
        mcp_mount_path: URL prefix under which ``mcp_app`` is mounted.

    Returns:
        A list with the API routes first, then the static routes, and finally
        the MCP mount.
    """
    return [
        *get_api_routes(),
        *get_static_routes(),
        Mount(mcp_mount_path, app=mcp_app),
    ]
def add_middleware(app):
    """Register this package's ``WorkspaceHeaderMiddleware`` on *app*.

    The middleware is imported lazily from the sibling ``middleware`` module.
    """
    from .middleware import WorkspaceHeaderMiddleware as middleware_cls
    app.add_middleware(middleware_cls)