import json
import traceback
from datetime import datetime
from io import BytesIO
from quart import request, send_file
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from .route import Response, Route, RouteContext
class ConversationRoute(Route):
def __init__(
self,
context: RouteContext,
db_helper: BaseDatabase,
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
self.routes = {
"/conversation/list": ("GET", self.list_conversations),
"/conversation/detail": (
"POST",
self.get_conv_detail,
),
"/conversation/update": ("POST", self.upd_conv),
"/conversation/delete": ("POST", self.del_conv),
"/conversation/update_history": (
"POST",
self.update_history,
),
"/conversation/export": ("POST", self.export_conversations),
}
self.db_helper = db_helper
self.conv_mgr = core_lifecycle.conversation_manager
self.core_lifecycle = core_lifecycle
self.register_routes()
async def list_conversations(self):
"""获取对话列表,支持分页、排序和筛选"""
try:
# 获取分页参数
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
# 获取筛选参数
platforms = request.args.get("platforms", "")
message_types = request.args.get("message_types", "")
search_query = request.args.get("search", "")
exclude_ids = request.args.get("exclude_ids", "")
exclude_platforms = request.args.get("exclude_platforms", "")
# 转换为列表
platform_list = platforms.split(",") if platforms else []
message_type_list = message_types.split(",") if message_types else []
exclude_id_list = exclude_ids.split(",") if exclude_ids else []
exclude_platform_list = (
exclude_platforms.split(",") if exclude_platforms else []
)
page = max(page, 1)
if page_size < 1:
page_size = 20
page_size = min(page_size, 100)
try:
(
conversations,
total_count,
) = await self.conv_mgr.get_filtered_conversations(
page=page,
page_size=page_size,
platforms=platform_list,
message_types=message_type_list,
search_query=search_query,
exclude_ids=exclude_id_list,
exclude_platforms=exclude_platform_list,
)
except Exception as e:
logger.error(f"数据库查询出错: {e!s}\n{traceback.format_exc()}")
return Response().error(f"数据库查询出错: {e!s}").__dict__
# 计算总页数
total_pages = (
(total_count + page_size - 1) // page_size if total_count > 0 else 1
)
result = {
"conversations": conversations,
"pagination": {
"page": page,
"page_size": page_size,
"total": total_count,
"total_pages": total_pages,
},
}
return Response().ok(result).__dict__
except Exception as e:
error_msg = f"获取对话列表失败: {e!s}\n{traceback.format_exc()}"
logger.error(error_msg)
return Response().error(f"获取对话列表失败: {e!s}").__dict__
async def get_conv_detail(self):
"""获取指定对话详情(通过POST请求)"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
if not conversation:
return Response().error("对话不存在").__dict__
return (
Response()
.ok(
{
"user_id": user_id,
"cid": cid,
"title": conversation.title,
"persona_id": conversation.persona_id,
"history": conversation.history,
"created_at": conversation.created_at,
"updated_at": conversation.updated_at,
},
)
.__dict__
)
except Exception as e:
logger.error(f"获取对话详情失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"获取对话详情失败: {e!s}").__dict__
async def upd_conv(self):
"""更新对话信息(标题和角色ID)"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
title = data.get("title")
persona_id = data.get("persona_id", "")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
if not conversation:
return Response().error("对话不存在").__dict__
if title is not None or persona_id is not None:
await self.conv_mgr.update_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
title=title,
persona_id=persona_id,
)
return Response().ok({"message": "对话信息更新成功"}).__dict__
except Exception as e:
logger.error(f"更新对话信息失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"更新对话信息失败: {e!s}").__dict__
async def del_conv(self):
"""删除对话"""
try:
data = await request.get_json()
# 检查是否是批量删除
if "conversations" in data:
# 批量删除
conversations = data.get("conversations", [])
if not conversations:
return (
Response().error("批量删除时conversations参数不能为空").__dict__
)
deleted_count = 0
failed_items = []
for conv in conversations:
user_id = conv.get("user_id")
cid = conv.get("cid")
if not user_id or not cid:
failed_items.append(
f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
)
continue
try:
await self.core_lifecycle.conversation_manager.delete_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
deleted_count += 1
except Exception as e:
failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")
message = f"成功删除 {deleted_count} 个对话"
if failed_items:
message += f",失败 {len(failed_items)} 个"
return (
Response()
.ok(
{
"message": message,
"deleted_count": deleted_count,
"failed_count": len(failed_items),
"failed_items": failed_items,
},
)
.__dict__
)
# 单个删除
user_id = data.get("user_id")
cid = data.get("cid")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
await self.core_lifecycle.conversation_manager.delete_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
return Response().ok({"message": "对话删除成功"}).__dict__
except Exception as e:
logger.error(f"删除对话失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"删除对话失败: {e!s}").__dict__
async def update_history(self):
"""更新对话历史内容"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
history = data.get("history")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
if history is None:
return Response().error("缺少必要参数: history").__dict__
# 历史记录必须是合法的 JSON 字符串
try:
if isinstance(history, list):
history = json.dumps(history)
else:
# 验证是否为有效的 JSON 字符串
json.loads(history)
except json.JSONDecodeError:
return (
Response().error("history 必须是有效的 JSON 字符串或数组").__dict__
)
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
if not conversation:
return Response().error("对话不存在").__dict__
history = json.loads(history) if isinstance(history, str) else history
await self.conv_mgr.update_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
history=history,
)
return Response().ok({"message": "对话历史更新成功"}).__dict__
except Exception as e:
logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"更新对话历史失败: {e!s}").__dict__
async def export_conversations(self):
"""批量导出对话为 JSONL 格式"""
try:
data = await request.get_json()
conversations_to_export = data.get("conversations", [])
if not conversations_to_export:
return Response().error("导出列表不能为空").__dict__
# 收集所有对话的内容
jsonl_lines = []
exported_count = 0
failed_items = []
for conv_info in conversations_to_export:
user_id = conv_info.get("user_id")
cid = conv_info.get("cid")
if not user_id or not cid:
failed_items.append(
f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
)
continue
try:
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
if not conversation:
failed_items.append(
f"user_id:{user_id}, cid:{cid} - 对话不存在"
)
continue
# 解析对话内容 (history is always a JSON string from _convert_conv_from_v2_to_v1)
content = json.loads(conversation.history)
# 创建导出记录
export_record = {
"cid": cid,
"user_id": user_id,
"platform_id": conversation.platform_id,
"title": conversation.title,
"persona_id": conversation.persona_id,
"created_at": conversation.created_at,
"updated_at": conversation.updated_at,
"content": content,
}
# 将记录转换为 JSON 字符串并添加到 JSONL
jsonl_lines.append(json.dumps(export_record, ensure_ascii=False))
exported_count += 1
except Exception as e:
failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")
logger.error(
f"导出对话失败: user_id={user_id}, cid={cid}, error={e!s}"
)
if exported_count == 0:
return Response().error("没有成功导出任何对话").__dict__
# 创建 JSONL 内容
jsonl_content = "\n".join(jsonl_lines)
# 创建一个内存文件对象
file_obj = BytesIO(jsonl_content.encode("utf-8"))
file_obj.seek(0)
# 生成文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"astrbot_conversations_export_{timestamp}.jsonl"
# 返回文件流
return await send_file(
file_obj,
mimetype="application/jsonl",
as_attachment=True,
attachment_filename=filename,
)
except Exception as e:
logger.error(f"批量导出对话失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"批量导出对话失败: {e!s}").__dict__