"""知识库管理 API 路由"""
import asyncio
import os
import traceback
import uuid
import aiofiles
from quart import request
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from ..utils import generate_tsne_visualization
from .route import Response, Route, RouteContext
class KnowledgeBaseRoute(Route):
"""知识库管理路由
提供知识库、文档、检索、会话配置等 API 接口
"""
def __init__(
self,
context: RouteContext,
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
self.core_lifecycle = core_lifecycle
self.kb_manager = None # 延迟初始化
self.kb_db = None
self.session_config_db = None # 会话配置数据库
self.retrieval_manager = None
self.upload_progress = {} # 存储上传进度 {task_id: {status, file_index, file_total, stage, current, total}}
self.upload_tasks = {} # 存储后台上传任务 {task_id: {"status", "result", "error"}}
# 注册路由
self.routes = {
# 知识库管理
"/kb/list": ("GET", self.list_kbs),
"/kb/create": ("POST", self.create_kb),
"/kb/get": ("GET", self.get_kb),
"/kb/update": ("POST", self.update_kb),
"/kb/delete": ("POST", self.delete_kb),
"/kb/stats": ("GET", self.get_kb_stats),
# 文档管理
"/kb/document/list": ("GET", self.list_documents),
"/kb/document/upload": ("POST", self.upload_document),
"/kb/document/import": ("POST", self.import_documents),
"/kb/document/upload/url": ("POST", self.upload_document_from_url),
"/kb/document/upload/progress": ("GET", self.get_upload_progress),
"/kb/document/get": ("GET", self.get_document),
"/kb/document/delete": ("POST", self.delete_document),
# # 块管理
"/kb/chunk/list": ("GET", self.list_chunks),
"/kb/chunk/delete": ("POST", self.delete_chunk),
# # 多媒体管理
# "/kb/media/list": ("GET", self.list_media),
# "/kb/media/delete": ("POST", self.delete_media),
# 检索
"/kb/retrieve": ("POST", self.retrieve),
}
self.register_routes()
def _get_kb_manager(self):
return self.core_lifecycle.kb_manager
def _init_task(self, task_id: str, status: str = "pending") -> None:
self.upload_tasks[task_id] = {
"status": status,
"result": None,
"error": None,
}
def _set_task_result(
self, task_id: str, status: str, result: any = None, error: str | None = None
) -> None:
self.upload_tasks[task_id] = {
"status": status,
"result": result,
"error": error,
}
if task_id in self.upload_progress:
self.upload_progress[task_id]["status"] = status
def _update_progress(
self,
task_id: str,
*,
status: str | None = None,
file_index: int | None = None,
file_name: str | None = None,
stage: str | None = None,
current: int | None = None,
total: int | None = None,
) -> None:
if task_id not in self.upload_progress:
return
p = self.upload_progress[task_id]
if status is not None:
p["status"] = status
if file_index is not None:
p["file_index"] = file_index
if file_name is not None:
p["file_name"] = file_name
if stage is not None:
p["stage"] = stage
if current is not None:
p["current"] = current
if total is not None:
p["total"] = total
def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str):
async def _callback(stage: str, current: int, total: int):
self._update_progress(
task_id,
status="processing",
file_index=file_idx,
file_name=file_name,
stage=stage,
current=current,
total=total,
)
return _callback
async def _background_upload_task(
self,
task_id: str,
kb_helper,
files_to_upload: list,
chunk_size: int,
chunk_overlap: int,
batch_size: int,
tasks_limit: int,
max_retries: int,
):
"""后台上传任务"""
try:
# 初始化任务状态
self._init_task(task_id, status="processing")
self.upload_progress[task_id] = {
"status": "processing",
"file_index": 0,
"file_total": len(files_to_upload),
"stage": "waiting",
"current": 0,
"total": 100,
}
uploaded_docs = []
failed_docs = []
for file_idx, file_info in enumerate(files_to_upload):
try:
# 更新整体进度
self._update_progress(
task_id,
status="processing",
file_index=file_idx,
file_name=file_info["file_name"],
stage="parsing",
current=0,
total=100,
)
# 创建进度回调函数
progress_callback = self._make_progress_callback(
task_id, file_idx, file_info["file_name"]
)
doc = await kb_helper.upload_document(
file_name=file_info["file_name"],
file_content=file_info["file_content"],
file_type=file_info["file_type"],
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
)
uploaded_docs.append(doc.model_dump())
except Exception as e:
logger.error(f"上传文档 {file_info['file_name']} 失败: {e}")
failed_docs.append(
{"file_name": file_info["file_name"], "error": str(e)},
)
# 更新任务完成状态
result = {
"task_id": task_id,
"uploaded": uploaded_docs,
"failed": failed_docs,
"total": len(files_to_upload),
"success_count": len(uploaded_docs),
"failed_count": len(failed_docs),
}
self._set_task_result(task_id, "completed", result=result)
except Exception as e:
logger.error(f"后台上传任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e))
async def _background_import_task(
self,
task_id: str,
kb_helper,
documents: list,
batch_size: int,
tasks_limit: int,
max_retries: int,
):
"""后台导入预切片文档任务"""
try:
# 初始化任务状态
self._init_task(task_id, status="processing")
self.upload_progress[task_id] = {
"status": "processing",
"file_index": 0,
"file_total": len(documents),
"stage": "waiting",
"current": 0,
"total": 100,
}
uploaded_docs = []
failed_docs = []
for file_idx, doc_info in enumerate(documents):
file_name = doc_info.get("file_name", f"imported_doc_{file_idx}")
chunks = doc_info.get("chunks", [])
try:
# 更新整体进度
self._update_progress(
task_id,
status="processing",
file_index=file_idx,
file_name=file_name,
stage="importing",
current=0,
total=100,
)
# 创建进度回调函数
progress_callback = self._make_progress_callback(
task_id, file_idx, file_name
)
# 调用 upload_document,传入 pre_chunked_text
doc = await kb_helper.upload_document(
file_name=file_name,
file_content=None, # 预切片模式下不需要原始内容
file_type=doc_info.get("file_type")
or (
file_name.rsplit(".", 1)[-1].lower()
if "." in file_name
else "txt"
),
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
pre_chunked_text=chunks,
)
uploaded_docs.append(doc.model_dump())
except Exception as e:
logger.error(f"导入文档 {file_name} 失败: {e}")
failed_docs.append(
{"file_name": file_name, "error": str(e)},
)
# 更新任务完成状态
result = {
"task_id": task_id,
"uploaded": uploaded_docs,
"failed": failed_docs,
"total": len(documents),
"success_count": len(uploaded_docs),
"failed_count": len(failed_docs),
}
self._set_task_result(task_id, "completed", result=result)
except Exception as e:
logger.error(f"后台导入任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e))
async def list_kbs(self):
"""获取知识库列表
Query 参数:
- page: 页码 (默认 1)
- page_size: 每页数量 (默认 20)
- refresh_stats: 是否刷新统计信息 (默认 false,首次加载时可设为 true)
"""
try:
kb_manager = self._get_kb_manager()
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
kbs = await kb_manager.list_kbs()
# 转换为字典列表
kb_list = []
for kb in kbs:
kb_list.append(kb.model_dump())
return (
Response()
.ok({"items": kb_list, "page": page, "page_size": page_size})
.__dict__
)
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"获取知识库列表失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"获取知识库列表失败: {e!s}").__dict__
async def create_kb(self):
"""创建知识库
Body:
- kb_name: 知识库名称 (必填)
- description: 描述 (可选)
- emoji: 图标 (可选)
- embedding_provider_id: 嵌入模型提供商ID (可选)
- rerank_provider_id: 重排序模型提供商ID (可选)
- chunk_size: 分块大小 (可选, 默认512)
- chunk_overlap: 块重叠大小 (可选, 默认50)
- top_k_dense: 密集检索数量 (可选, 默认50)
- top_k_sparse: 稀疏检索数量 (可选, 默认50)
- top_m_final: 最终返回数量 (可选, 默认5)
"""
try:
kb_manager = self._get_kb_manager()
data = await request.json
kb_name = data.get("kb_name")
if not kb_name:
return Response().error("知识库名称不能为空").__dict__
description = data.get("description")
emoji = data.get("emoji")
embedding_provider_id = data.get("embedding_provider_id")
rerank_provider_id = data.get("rerank_provider_id")
chunk_size = data.get("chunk_size")
chunk_overlap = data.get("chunk_overlap")
top_k_dense = data.get("top_k_dense")
top_k_sparse = data.get("top_k_sparse")
top_m_final = data.get("top_m_final")
# pre-check embedding dim
if not embedding_provider_id:
return Response().error("缺少参数 embedding_provider_id").__dict__
prv = await kb_manager.provider_manager.get_provider_by_id(
embedding_provider_id,
) # type: ignore
if not prv or not isinstance(prv, EmbeddingProvider):
return (
Response().error(f"嵌入模型不存在或类型错误({type(prv)})").__dict__
)
try:
vec = await prv.get_embedding("astrbot")
if len(vec) != prv.get_dim():
raise ValueError(
f"嵌入向量维度不匹配,实际是 {len(vec)},然而配置是 {prv.get_dim()}",
)
except Exception as e:
return Response().error(f"测试嵌入模型失败: {e!s}").__dict__
# pre-check rerank
if rerank_provider_id:
rerank_prv: RerankProvider = (
await kb_manager.provider_manager.get_provider_by_id(
rerank_provider_id,
)
) # type: ignore
if not rerank_prv:
return Response().error("重排序模型不存在").__dict__
# 检查重排序模型可用性
try:
res = await rerank_prv.rerank(
query="astrbot",
documents=["astrbot knowledge base"],
)
if not res:
raise ValueError("重排序模型返回结果异常")
except Exception as e:
return (
Response()
.error(f"测试重排序模型失败: {e!s},请检查平台日志输出。")
.__dict__
)
kb_helper = await kb_manager.create_kb(
kb_name=kb_name,
description=description,
emoji=emoji,
embedding_provider_id=embedding_provider_id,
rerank_provider_id=rerank_provider_id,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
top_k_dense=top_k_dense,
top_k_sparse=top_k_sparse,
top_m_final=top_m_final,
)
kb = kb_helper.kb
return Response().ok(kb.model_dump(), "创建知识库成功").__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"创建知识库失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"创建知识库失败: {e!s}").__dict__
async def get_kb(self):
"""获取知识库详情
Query 参数:
- kb_id: 知识库 ID (必填)
"""
try:
kb_manager = self._get_kb_manager()
kb_id = request.args.get("kb_id")
if not kb_id:
return Response().error("缺少参数 kb_id").__dict__
kb_helper = await kb_manager.get_kb(kb_id)
if not kb_helper:
return Response().error("知识库不存在").__dict__
kb = kb_helper.kb
return Response().ok(kb.model_dump()).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"获取知识库详情失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"获取知识库详情失败: {e!s}").__dict__
async def update_kb(self):
"""更新知识库
Body:
- kb_id: 知识库 ID (必填)
- kb_name: 新的知识库名称 (可选)
- description: 新的描述 (可选)
- emoji: 新的图标 (可选)
- embedding_provider_id: 新的嵌入模型提供商ID (可选)
- rerank_provider_id: 新的重排序模型提供商ID (可选)
- chunk_size: 分块大小 (可选)
- chunk_overlap: 块重叠大小 (可选)
- top_k_dense: 密集检索数量 (可选)
- top_k_sparse: 稀疏检索数量 (可选)
- top_m_final: 最终返回数量 (可选)
"""
try:
kb_manager = self._get_kb_manager()
data = await request.json
kb_id = data.get("kb_id")
if not kb_id:
return Response().error("缺少参数 kb_id").__dict__
kb_name = data.get("kb_name")
description = data.get("description")
emoji = data.get("emoji")
embedding_provider_id = data.get("embedding_provider_id")
rerank_provider_id = data.get("rerank_provider_id")
chunk_size = data.get("chunk_size")
chunk_overlap = data.get("chunk_overlap")
top_k_dense = data.get("top_k_dense")
top_k_sparse = data.get("top_k_sparse")
top_m_final = data.get("top_m_final")
# 检查是否至少提供了一个更新字段
if all(
v is None
for v in [
kb_name,
description,
emoji,
embedding_provider_id,
rerank_provider_id,
chunk_size,
chunk_overlap,
top_k_dense,
top_k_sparse,
top_m_final,
]
):
return Response().error("至少需要提供一个更新字段").__dict__
kb_helper = await kb_manager.update_kb(
kb_id=kb_id,
kb_name=kb_name,
description=description,
emoji=emoji,
embedding_provider_id=embedding_provider_id,
rerank_provider_id=rerank_provider_id,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
top_k_dense=top_k_dense,
top_k_sparse=top_k_sparse,
top_m_final=top_m_final,
)
if not kb_helper:
return Response().error("知识库不存在").__dict__
kb = kb_helper.kb
return Response().ok(kb.model_dump(), "更新知识库成功").__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"更新知识库失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"更新知识库失败: {e!s}").__dict__
async def delete_kb(self):
"""删除知识库
Body:
- kb_id: 知识库 ID (必填)
"""
try:
kb_manager = self._get_kb_manager()
data = await request.json
kb_id = data.get("kb_id")
if not kb_id:
return Response().error("缺少参数 kb_id").__dict__
success = await kb_manager.delete_kb(kb_id)
if not success:
return Response().error("知识库不存在").__dict__
return Response().ok(message="删除知识库成功").__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"删除知识库失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"删除知识库失败: {e!s}").__dict__
async def get_kb_stats(self):
"""获取知识库统计信息
Query 参数:
- kb_id: 知识库 ID (必填)
"""
try:
kb_manager = self._get_kb_manager()
kb_id = request.args.get("kb_id")
if not kb_id:
return Response().error("缺少参数 kb_id").__dict__
kb_helper = await kb_manager.get_kb(kb_id)
if not kb_helper:
return Response().error("知识库不存在").__dict__
kb = kb_helper.kb
stats = {
"kb_id": kb.kb_id,
"kb_name": kb.kb_name,
"doc_count": kb.doc_count,
"chunk_count": kb.chunk_count,
"created_at": kb.created_at.isoformat(),
"updated_at": kb.updated_at.isoformat(),
}
return Response().ok(stats).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"获取知识库统计失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"获取知识库统计失败: {e!s}").__dict__
# ===== 文档管理 API =====
async def list_documents(self):
"""获取文档列表
Query 参数:
- kb_id: 知识库 ID (必填)
- page: 页码 (默认 1)
- page_size: 每页数量 (默认 20)
"""
try:
kb_manager = self._get_kb_manager()
kb_id = request.args.get("kb_id")
if not kb_id:
return Response().error("缺少参数 kb_id").__dict__
kb_helper = await kb_manager.get_kb(kb_id)
if not kb_helper:
return Response().error("知识库不存在").__dict__
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 100, type=int)
offset = (page - 1) * page_size
limit = page_size
doc_list = await kb_helper.list_documents(offset=offset, limit=limit)
doc_list = [doc.model_dump() for doc in doc_list]
return (
Response()
.ok({"items": doc_list, "page": page, "page_size": page_size})
.__dict__
)
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"获取文档列表失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"获取文档列表失败: {e!s}").__dict__
async def upload_document(self):
"""上传文档
支持两种方式:
1. multipart/form-data 文件上传(支持多文件,最多10个)
2. JSON 格式 base64 编码上传(支持多文件,最多10个)
Form Data (multipart/form-data):
- kb_id: 知识库 ID (必填)
- file: 文件对象 (必填,可多个,字段名为 file, file1, file2, ... 或 files[])
JSON Body (application/json):
- kb_id: 知识库 ID (必填)
- files: 文件数组 (必填)
- file_name: 文件名 (必填)
- file_content: base64 编码的文件内容 (必填)
返回:
- task_id: 任务ID,用于查询上传进度和结果
"""
try:
kb_manager = self._get_kb_manager()
# 检查 Content-Type
content_type = request.content_type
kb_id = None
chunk_size = None
chunk_overlap = None
batch_size = 32
tasks_limit = 3
max_retries = 3
files_to_upload = [] # 存储待上传的文件信息列表
if content_type and "multipart/form-data" not in content_type:
return (
Response().error("Content-Type 须为 multipart/form-data").__dict__
)
form_data = await request.form
files = await request.files
kb_id = form_data.get("kb_id")
chunk_size = int(form_data.get("chunk_size", 512))
chunk_overlap = int(form_data.get("chunk_overlap", 50))
batch_size = int(form_data.get("batch_size", 32))
tasks_limit = int(form_data.get("tasks_limit", 3))
max_retries = int(form_data.get("max_retries", 3))
if not kb_id:
return Response().error("缺少参数 kb_id").__dict__
# 收集所有文件
file_list = []
# 支持 file, file1, file2, ... 或 files[] 格式
for key in files.keys():
if key == "file" or key.startswith("file") or key == "files[]":
file_items = files.getlist(key)
file_list.extend(file_items)
if not file_list:
return Response().error("缺少文件").__dict__
# 限制文件数量
if len(file_list) > 10:
return Response().error("最多只能上传10个文件").__dict__
# 处理每个文件
for file in file_list:
file_name = file.filename
# 保存到临时文件
temp_file_path = f"data/temp/{uuid.uuid4()}_{file_name}"
await file.save(temp_file_path)
try:
# 异步读取文件内容
async with aiofiles.open(temp_file_path, "rb") as f:
file_content = await f.read()
# 提取文件类型
file_type = (
file_name.rsplit(".", 1)[-1].lower() if "." in file_name else ""
)
files_to_upload.append(
{
"file_name": file_name,
"file_content": file_content,
"file_type": file_type,
},
)
finally:
# 清理临时文件
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
# 获取知识库
kb_helper = await kb_manager.get_kb(kb_id)
if not kb_helper:
return Response().error("知识库不存在").__dict__
# 生成任务ID
task_id = str(uuid.uuid4())
# 初始化任务状态
self._init_task(task_id, status="pending")
# 启动后台任务
asyncio.create_task(
self._background_upload_task(
task_id=task_id,
kb_helper=kb_helper,
files_to_upload=files_to_upload,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
),
)
return (
Response()
.ok(
{
"task_id": task_id,
"file_count": len(files_to_upload),
"message": "task created, processing in background",
},
)
.__dict__
)
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"上传文档失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"上传文档失败: {e!s}").__dict__
def _validate_import_request(self, data: dict):
kb_id = data.get("kb_id")
if not kb_id:
raise ValueError("缺少参数 kb_id")
documents = data.get("documents")
if not documents or not isinstance(documents, list):
raise ValueError("缺少参数 documents 或格式错误")
for doc in documents:
if "file_name" not in doc or "chunks" not in doc:
raise ValueError("文档格式错误,必须包含 file_name 和 chunks")
if not isinstance(doc["chunks"], list):
raise ValueError("chunks 必须是列表")
if not all(
isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"]
):
raise ValueError("chunks 必须是非空字符串列表")
batch_size = data.get("batch_size", 32)
tasks_limit = data.get("tasks_limit", 3)
max_retries = data.get("max_retries", 3)
return kb_id, documents, batch_size, tasks_limit, max_retries
async def import_documents(self):
"""导入预切片文档
Body:
- kb_id: 知识库 ID (必填)
- documents: 文档列表 (必填)
- file_name: 文件名 (必填)
- chunks: 切片列表 (必填, list[str])
- file_type: 文件类型 (可选, 默认从文件名推断或为 txt)
- batch_size: 批处理大小 (可选, 默认32)
- tasks_limit: 并发任务限制 (可选, 默认3)
- max_retries: 最大重试次数 (可选, 默认3)
"""
try:
kb_manager = self._get_kb_manager()
data = await request.json
kb_id, documents, batch_size, tasks_limit, max_retries = (
self._validate_import_request(data)
)
# 获取知识库
kb_helper = await kb_manager.get_kb(kb_id)
if not kb_helper:
return Response().error("知识库不存在").__dict__
# 生成任务ID
task_id = str(uuid.uuid4())
# 初始化任务状态
self._init_task(task_id, status="pending")
# 启动后台任务
asyncio.create_task(
self._background_import_task(
task_id=task_id,
kb_helper=kb_helper,
documents=documents,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
),
)
return (
Response()
.ok(
{
"task_id": task_id,
"doc_count": len(documents),
"message": "import task created, processing in background",
},
)
.__dict__
)
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"导入文档失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"导入文档失败: {e!s}").__dict__
async def get_upload_progress(self):
"""获取上传进度和结果
Query 参数:
- task_id: 任务 ID (必填)
返回状态:
- pending: 任务待处理
- processing: 任务处理中
- completed: 任务完成
- failed: 任务失败
"""
try:
task_id = request.args.get("task_id")
if not task_id:
return Response().error("缺少参数 task_id").__dict__
# 检查任务是否存在
if task_id not in self.upload_tasks:
return Response().error("找不到该任务").__dict__
task_info = self.upload_tasks[task_id]
status = task_info["status"]
# 构建返回数据
response_data = {
"task_id": task_id,
"status": status,
}
# 如果任务正在处理,返回进度信息
if status == "processing" and task_id in self.upload_progress:
response_data["progress"] = self.upload_progress[task_id]
# 如果任务完成,返回结果
if status == "completed":
response_data["result"] = task_info["result"]
# 清理已完成的任务
# del self.upload_tasks[task_id]
# if task_id in self.upload_progress:
# del self.upload_progress[task_id]
# 如果任务失败,返回错误信息
if status == "failed":
response_data["error"] = task_info["error"]
return Response().ok(response_data).__dict__
except Exception as e:
logger.error(f"获取上传进度失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"获取上传进度失败: {e!s}").__dict__
async def get_document(self):
"""获取文档详情
Query 参数:
- doc_id: 文档 ID (必填)
"""
try:
kb_manager = self._get_kb_manager()
kb_id = request.args.get("kb_id")
if not kb_id:
return Response().error("缺少参数 kb_id").__dict__
doc_id = request.args.get("doc_id")
if not doc_id:
return Response().error("缺少参数 doc_id").__dict__
kb_helper = await kb_manager.get_kb(kb_id)
if not kb_helper:
return Response().error("知识库不存在").__dict__
doc = await kb_helper.get_document(doc_id)
if not doc:
return Response().error("文档不存在").__dict__
return Response().ok(doc.model_dump()).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"获取文档详情失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"获取文档详情失败: {e!s}").__dict__
async def delete_document(self):
"""删除文档
Body:
- kb_id: 知识库 ID (必填)
- doc_id: 文档 ID (必填)
"""
try:
kb_manager = self._get_kb_manager()
data = await request.json
kb_id = data.get("kb_id")
if not kb_id:
return Response().error("缺少参数 kb_id").__dict__
doc_id = data.get("doc_id")
if not doc_id:
return Response().error("缺少参数 doc_id").__dict__
kb_helper = await kb_manager.get_kb(kb_id)
if not kb_helper:
return Response().error("知识库不存在").__dict__
await kb_helper.delete_document(doc_id)
return Response().ok(message="删除文档成功").__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"删除文档失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"删除文档失败: {e!s}").__dict__
async def delete_chunk(self):
"""删除文本块
Body:
- kb_id: 知识库 ID (必填)
- chunk_id: 块 ID (必填)
"""
try:
kb_manager = self._get_kb_manager()
data = await request.json
kb_id = data.get("kb_id")
if not kb_id:
return Response().error("缺少参数 kb_id").__dict__
chunk_id = data.get("chunk_id")
if not chunk_id:
return Response().error("缺少参数 chunk_id").__dict__
doc_id = data.get("doc_id")
if not doc_id:
return Response().error("缺少参数 doc_id").__dict__
kb_helper = await kb_manager.get_kb(kb_id)
if not kb_helper:
return Response().error("知识库不存在").__dict__
await kb_helper.delete_chunk(chunk_id, doc_id)
return Response().ok(message="删除文本块成功").__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"删除文本块失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"删除文本块失败: {e!s}").__dict__
async def list_chunks(self):
"""获取块列表
Query 参数:
- kb_id: 知识库 ID (必填)
- page: 页码 (默认 1)
- page_size: 每页数量 (默认 20)
"""
try:
kb_manager = self._get_kb_manager()
kb_id = request.args.get("kb_id")
doc_id = request.args.get("doc_id")
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 100, type=int)
if not kb_id:
return Response().error("缺少参数 kb_id").__dict__
if not doc_id:
return Response().error("缺少参数 doc_id").__dict__
kb_helper = await kb_manager.get_kb(kb_id)
offset = (page - 1) * page_size
limit = page_size
if not kb_helper:
return Response().error("知识库不存在").__dict__
chunk_list = await kb_helper.get_chunks_by_doc_id(
doc_id=doc_id,
offset=offset,
limit=limit,
)
return (
Response()
.ok(
data={
"items": chunk_list,
"page": page,
"page_size": page_size,
"total": await kb_helper.get_chunk_count_by_doc_id(doc_id),
},
)
.__dict__
)
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"获取块列表失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"获取块列表失败: {e!s}").__dict__
# ===== 检索 API =====
async def retrieve(self):
"""检索知识库
Body:
- query: 查询文本 (必填)
- kb_ids: 知识库 ID 列表 (必填)
- top_k: 返回结果数量 (可选, 默认 5)
- debug: 是否启用调试模式,返回 t-SNE 可视化图片 (可选, 默认 False)
"""
try:
kb_manager = self._get_kb_manager()
data = await request.json
query = data.get("query")
kb_names = data.get("kb_names")
debug = data.get("debug", False)
if not query:
return Response().error("缺少参数 query").__dict__
if not kb_names or not isinstance(kb_names, list):
return Response().error("缺少参数 kb_names 或格式错误").__dict__
top_k = data.get("top_k", 5)
results = await kb_manager.retrieve(
query=query,
kb_names=kb_names,
top_m_final=top_k,
)
result_list = []
if results:
result_list = results["results"]
response_data = {
"results": result_list,
"total": len(result_list),
"query": query,
}
# Debug 模式:生成 t-SNE 可视化
if debug:
try:
img_base64 = await generate_tsne_visualization(
query,
kb_names,
kb_manager,
)
if img_base64:
response_data["visualization"] = img_base64
except Exception as e:
logger.error(f"生成 t-SNE 可视化失败: {e}")
logger.error(traceback.format_exc())
response_data["visualization_error"] = str(e)
return Response().ok(response_data).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"检索失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"检索失败: {e!s}").__dict__
async def upload_document_from_url(self):
"""从 URL 上传文档
Body:
- kb_id: 知识库 ID (必填)
- url: 要提取内容的网页 URL (必填)
- chunk_size: 分块大小 (可选, 默认512)
- chunk_overlap: 块重叠大小 (可选, 默认50)
- batch_size: 批处理大小 (可选, 默认32)
- tasks_limit: 并发任务限制 (可选, 默认3)
- max_retries: 最大重试次数 (可选, 默认3)
返回:
- task_id: 任务ID,用于查询上传进度和结果
"""
try:
kb_manager = self._get_kb_manager()
data = await request.json
kb_id = data.get("kb_id")
if not kb_id:
return Response().error("缺少参数 kb_id").__dict__
url = data.get("url")
if not url:
return Response().error("缺少参数 url").__dict__
chunk_size = data.get("chunk_size", 512)
chunk_overlap = data.get("chunk_overlap", 50)
batch_size = data.get("batch_size", 32)
tasks_limit = data.get("tasks_limit", 3)
max_retries = data.get("max_retries", 3)
enable_cleaning = data.get("enable_cleaning", False)
cleaning_provider_id = data.get("cleaning_provider_id")
# 获取知识库
kb_helper = await kb_manager.get_kb(kb_id)
if not kb_helper:
return Response().error("知识库不存在").__dict__
# 生成任务ID
task_id = str(uuid.uuid4())
# 初始化任务状态
self._init_task(task_id, status="pending")
# 启动后台任务
asyncio.create_task(
self._background_upload_from_url_task(
task_id=task_id,
kb_helper=kb_helper,
url=url,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
enable_cleaning=enable_cleaning,
cleaning_provider_id=cleaning_provider_id,
),
)
return (
Response()
.ok(
{
"task_id": task_id,
"url": url,
"message": "URL upload task created, processing in background",
},
)
.__dict__
)
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"从URL上传文档失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"从URL上传文档失败: {e!s}").__dict__
async def _background_upload_from_url_task(
self,
task_id: str,
kb_helper,
url: str,
chunk_size: int,
chunk_overlap: int,
batch_size: int,
tasks_limit: int,
max_retries: int,
enable_cleaning: bool,
cleaning_provider_id: str | None,
):
"""后台上传URL任务"""
try:
# 初始化任务状态
self._init_task(task_id, status="processing")
self.upload_progress[task_id] = {
"status": "processing",
"file_index": 0,
"file_total": 1,
"file_name": f"URL: {url}",
"stage": "extracting",
"current": 0,
"total": 100,
}
# 创建进度回调函数
progress_callback = self._make_progress_callback(task_id, 0, f"URL: {url}")
# 上传文档
doc = await kb_helper.upload_from_url(
url=url,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
enable_cleaning=enable_cleaning,
cleaning_provider_id=cleaning_provider_id,
)
# 更新任务完成状态
result = {
"task_id": task_id,
"uploaded": [doc.model_dump()],
"failed": [],
"total": 1,
"success_count": 1,
"failed_count": 0,
}
self._set_task_result(task_id, "completed", result=result)
except Exception as e:
logger.error(f"后台上传URL任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e))