"""
多文档RAG引擎 - 支持多个PPT文档的统一索引管理
整体架构:创建空index -> 不断增加多个ppt到index -> index缓存 -> 引擎初始化自动从缓存加载index用于查询
"""
import os
import chromadb
import pickle
from pathlib import Path
from typing import List, Dict, Optional, Any, Set
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
import hashlib
# LlamaIndex 导入
from llama_index.core import Settings, VectorStoreIndex, StorageContext, load_index_from_storage
from llama_index.core.schema import TextNode, NodeWithScore, MetadataMode
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.prompts import PromptTemplate
from llama_index.core.base.response.schema import Response
from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters, FilterOperator
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.embeddings.openai import OpenAIEmbedding
from doubao import DoubaoVisionLLM
from ppt_utils import PPTUtils
logger = logging.getLogger(__name__)
class MultimodalQueryEngine(CustomQueryEngine):
"""结合文本检索和图像分析的自定义查询引擎。"""
def __init__(
self,
retriever: BaseRetriever,
doubao_llm: DoubaoVisionLLM,
qa_prompt: Optional[PromptTemplate] = None
):
# 默认问答提示模板
default_prompt = """以下是PPT幻灯片中解析的Markdown文本和图片信息。Markdown文本已经尝试将相关图表转换为表格。
优先使用图片信息来回答问题。在无法理解图像时才使用Markdown文本信息。
---------------------
{context_str}
---------------------
-- 根据上下文信息并且不依赖先验知识, 回答查询。
-- 解释你是从解析的markdown、还是图片中得到答案的, 如果有差异, 请说明最终答案的理由。
-- 尽可能详细的回答问题。
-- 给出你重点参考的图片路径和页码。
查询: {query_str}
答案: """
final_qa_prompt = qa_prompt or PromptTemplate(default_prompt)
# 调用父类构造函数
super().__init__()
# 设置属性
self._retriever = retriever
self._doubao_llm = doubao_llm
self._qa_prompt = final_qa_prompt
def custom_query(self, query_str: str) -> Response:
"""执行具有多模态理解的查询。"""
# 检索相关节点
nodes = self._retriever.retrieve(query_str)
if not nodes:
return Response(
response="抱歉,没有找到相关的PPT内容来回答您的问题。",
source_nodes=[],
metadata={}
)
# 从文本节点创建上下文字符串,包含文档名称
context_str = "\n\n".join([
f"文档: {Path(node.metadata['source']).name}, 页面 {node.metadata['page_num']}: {node.get_content(metadata_mode=MetadataMode.LLM)}\n"
f"来源图片: {node.metadata['image_path']}"
for node in nodes
])
# 格式化提示
fmt_prompt = self._qa_prompt.format(
context_str=context_str,
query_str=query_str
)
# 获取图片路径用于视觉分析
image_paths = [node.metadata["image_path"] for node in nodes]
# 使用豆包视觉LLM生成回答
try:
response_text = self._doubao_llm.generate_response(
prompt=fmt_prompt,
image_paths=image_paths
)
except Exception as e:
response_text = f"生成回答时出现错误: {str(e)}"
return Response(
response=response_text,
source_nodes=nodes,
metadata={
"num_sources": len(nodes),
"image_paths": image_paths,
"source_documents": list(set(node.metadata["source"] for node in nodes))
}
)
class MultiDocRAGEngine:
"""
多文档RAG引擎:
- 使用单个统一的向量索引管理多个PPT文档
- 支持增量添加文档到索引
- 支持从索引中删除特定文档
- 自动索引缓存和加载
- 支持全文档检索和文档特定检索
"""
def __init__(
self,
persist_dir: str = "./.multi_doc_chroma_db",
cache_dir: str = "./.cache",
collection_name: str = "multi_ppt_documents",
embedding_model: str = "text-embedding-3-small",
doubao_model: str = "ep-20250205153642-hzqpj",
top_k: int = 3
):
self.persist_dir = Path(persist_dir)
self.persist_dir.mkdir(parents=True, exist_ok=True)
self.collection_name = collection_name
self.top_k = top_k
# 索引存储路径
self.index_storage_dir = self.persist_dir / "unified_index"
self.index_storage_dir.mkdir(parents=True, exist_ok=True)
# 节点缓存目录
self.node_cache_dir = Path(cache_dir) / "multi_doc_nodes"
self.node_cache_dir.mkdir(parents=True, exist_ok=True)
# Markdown缓存目录
self.markdown_cache_dir = Path(cache_dir) / "parsed_markdown"
self.markdown_cache_dir.mkdir(parents=True, exist_ok=True)
# 初始化PPT处理工具
self.ppt_utils = PPTUtils(cache_dir=cache_dir)
# 初始化嵌入模型(始终使用OpenAI)
self.embed_model = OpenAIEmbedding(model=embedding_model)
Settings.embed_model = self.embed_model
logger.info(f"正在使用OpenAI嵌入模型: {embedding_model}")
# 初始化豆包视觉LLM
self.doubao_llm = DoubaoVisionLLM(model_name=doubao_model)
# 文档元数据缓存路径
self.docs_metadata_path = self.persist_dir / "docs_metadata.pkl"
# 向量存储和索引
self._vector_store = None
self._index = None
self._query_engine = None
# 图像解析提示
self.parse_prompt = """用中文提取图片中的详细信息,并使用Markdown格式化输出。
-- 对于其中的文字,使用OCR识别,并尽量保持原格式或类似格式输出。
-- 对于其中的表格与统计图表信息,选择表格结合文字的方式进行描述。
-- 对于其中的图形、图表、流程图等视觉元素,请用文字详细描述其内容和布局。
-- 对于其他有意义的图像部分,请使用文字描述。
-- 合理排版,使得输出内容清晰易懂。"""
# 初始化时自动加载索引
self._load_index()
def _initialize_vector_store(self):
"""初始化ChromaDB向量存储。"""
if self._vector_store is None:
# 初始化ChromaDB客户端
chroma_client = chromadb.PersistentClient(path=str(self.persist_dir))
chroma_collection = chroma_client.get_or_create_collection(self.collection_name)
# 创建向量存储
self._vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
logger.info(f"已初始化ChromaDB向量存储: {self.collection_name}")
return self._vector_store
def _load_index(self):
"""自动加载缓存的索引(如果存在)。"""
try:
if self.index_storage_dir.exists() and (self.index_storage_dir / "docstore.json").exists():
logger.info("正在加载缓存的统一索引")
# 初始化向量存储
vector_store = self._initialize_vector_store()
# 加载索引
storage_context = StorageContext.from_defaults(
persist_dir=str(self.index_storage_dir),
vector_store=vector_store
)
self._index = load_index_from_storage(storage_context=storage_context)
logger.info("成功加载缓存的统一索引")
else:
logger.info("未找到缓存索引,将按需创建空索引")
except Exception as e:
logger.warning(f"加载缓存索引失败: {e},将创建新索引")
self._index = None
def _ensure_index_exists(self):
"""确保索引存在,如果不存在则创建空索引。"""
if self._index is None:
logger.info("正在创建新的空索引")
# 初始化向量存储
vector_store = self._initialize_vector_store()
storage_context = StorageContext.from_defaults(vector_store=vector_store)
# 创建空索引(没有节点)
self._index = VectorStoreIndex(
[],
storage_context=storage_context,
show_progress=False
)
# 立即持久化空索引
self._persist_index()
logger.info("已创建并持久化空索引")
def _persist_index(self):
"""持久化索引到存储。"""
if self._index is not None:
self._index.storage_context.persist(persist_dir=str(self.index_storage_dir))
logger.info(f"索引已持久化到: {self.index_storage_dir}")
def _get_node_cache_path(self, ppt_path: str) -> Path:
"""获取PPT文件的文本节点缓存路径。"""
file_hash = self._get_file_hash(ppt_path)
return self.node_cache_dir / f"{file_hash}_nodes.pkl"
def _get_markdown_cache_path(self, ppt_path: str) -> Path:
"""获取PPT文件的Markdown缓存路径。"""
file_hash = self._get_file_hash(ppt_path)
ppt_name = Path(ppt_path).stem
return self.markdown_cache_dir / f"{ppt_name}_{file_hash}.md"
def _get_file_hash(self, file_path: str) -> str:
"""计算文件的哈希值用于缓存。"""
with open(file_path, 'rb') as f:
return hashlib.md5(f.read()).hexdigest()
def _save_docs_metadata(self, docs_info: Dict[str, Dict]):
"""保存文档元数据(以doc_id为key)。"""
with open(self.docs_metadata_path, 'wb') as f:
pickle.dump(docs_info, f)
def _load_docs_metadata(self) -> Dict[str, Dict]:
"""加载文档元数据(以doc_id为key)。"""
if self.docs_metadata_path.exists():
try:
with open(self.docs_metadata_path, 'rb') as f:
return pickle.load(f)
except Exception as e:
logger.warning(f"加载文档元数据失败: {e}")
return {}
def _get_doc_id_from_path(self, ppt_path: str) -> str:
"""从文件路径获取文档ID(hash值)。"""
return self._get_file_hash(ppt_path)
def _find_doc_by_id(self, doc_id: str) -> Optional[Dict[str, Any]]:
"""根据doc_id查找文档信息。"""
docs_info = self._load_docs_metadata()
return docs_info.get(doc_id)
def is_document_indexed(self, ppt_path: str) -> bool:
"""
检查文档是否已经在索引中(推荐方法)。
Args:
ppt_path: PPT文件路径
Returns:
文档是否已索引
"""
try:
doc_id = self._get_doc_id_from_path(ppt_path)
return self._find_doc_by_id(doc_id) is not None
except Exception:
return False
def is_document_id_indexed(self, doc_id: str) -> bool:
"""
检查文档ID是否已经在索引中。
Args:
doc_id: 文档ID
Returns:
文档是否已索引
"""
return self._find_doc_by_id(doc_id) is not None
async def _parse_image_async(self, image_path: str, page_num: int) -> str:
"""使用豆包视觉LLM异步解析单张图片。"""
try:
content = await asyncio.to_thread(
self.doubao_llm.generate_response,
self.parse_prompt,
[image_path]
)
logger.info(f"已解析第{page_num}页图片: {len(content)} 个字符")
return content
except Exception as e:
logger.error(f"解析图片失败 {image_path}: {e}")
return f"解析图片失败: {str(e)}"
async def _save_parsed_markdown(self, ppt_path: str, image_paths: List[str], parsed_contents: List[str]) -> None:
"""将解析的内容保存到Markdown文件中。"""
try:
markdown_path = self._get_markdown_cache_path(ppt_path)
ppt_name = Path(ppt_path).name
# 创建Markdown内容
markdown_content = []
markdown_content.append(f"# {ppt_name} - 视觉解析结果")
markdown_content.append(f"\n生成时间: {Path(ppt_path).stat().st_mtime}")
markdown_content.append(f"总页数: {len(image_paths)}")
markdown_content.append(f"文档路径: {ppt_path}")
markdown_content.append("\n" + "="*80 + "\n")
# 添加每页的解析结果
for i, (image_path, content) in enumerate(zip(image_paths, parsed_contents), 1):
markdown_content.append(f"## 第 {i} 页")
markdown_content.append(f"**图片路径:** `{image_path}`")
markdown_content.append(f"**内容长度:** {len(content)} 字符")
markdown_content.append("\n### 解析内容:")
markdown_content.append(content)
markdown_content.append("\n" + "-"*50 + "\n")
# 写入文件
full_content = "\n".join(markdown_content)
with open(markdown_path, 'w', encoding='utf-8') as f:
f.write(full_content)
logger.info(f"已保存解析的Markdown到: {markdown_path}")
except Exception as e:
logger.error(f"保存Markdown解析结果失败 {ppt_path}: {e}")
async def _create_text_nodes(self, ppt_path: str, image_paths: List[str]) -> List[TextNode]:
"""从图片创建文本节点,包含关联的元数据。"""
logger.info(f"正在为{len(image_paths)}张图片创建文本节点")
# 检查API密钥是否配置
openai_key = os.getenv("OPENAI_API_KEY", "")
ark_key = os.getenv("ARK_API_KEY", "")
if not openai_key or not ark_key:
raise RuntimeError("必须配置OPENAI_API_KEY和ARK_API_KEY环境变量")
if openai_key.startswith(("test", "sk-test")) or ark_key.startswith("test"):
raise RuntimeError("API密钥不能是测试密钥,必须使用真实的API密钥")
# 并发解析所有图片
parse_tasks = [
self._parse_image_async(image_path, i + 1)
for i, image_path in enumerate(image_paths)
]
parsed_contents = await asyncio.gather(*parse_tasks)
# 保存解析结果到Markdown文件
await self._save_parsed_markdown(ppt_path, image_paths, parsed_contents)
# 创建文本节点,添加文档唯一标识符
doc_id = self._get_file_hash(ppt_path)
nodes = []
for i, (image_path, content) in enumerate(zip(image_paths, parsed_contents)):
node = TextNode(
text=content,
metadata={
"source": ppt_path,
"source_file_id": doc_id, # 文档唯一标识符
"doc_name": Path(ppt_path).name, # 文档名称
"page_num": i + 1,
"image_path": image_path,
"doc_type": "ppt_slide"
}
)
nodes.append(node)
logger.info(f"已创建{len(nodes)}个文本节点,文档ID: {doc_id}")
return nodes
async def add_ppt_document(self, ppt_path: str, force_reprocess: bool = False) -> Dict[str, Any]:
"""
向统一索引中添加PPT文档。
Args:
ppt_path: PPT文件路径
force_reprocess: 强制重新处理,即使已缓存
Returns:
添加结果字典
"""
ppt_path = str(Path(ppt_path).resolve())
if not Path(ppt_path).exists():
raise FileNotFoundError(f"PPT文件未找到:{ppt_path}")
# 检查文档是否已经在索引中 - 直接用doc_id检查
doc_id = self._get_doc_id_from_path(ppt_path)
existing_doc_info = self._find_doc_by_id(doc_id)
if not force_reprocess and existing_doc_info is not None:
return {
"status": "skipped",
"message": "文档已存在于索引中",
"source": ppt_path,
"doc_id": doc_id
}
logger.info(f"正在添加PPT文档到统一索引: {ppt_path}")
# 确保索引存在
self._ensure_index_exists()
# 获取节点缓存路径
node_cache_path = self._get_node_cache_path(ppt_path)
try:
# 检查是否有缓存的节点
if not force_reprocess and node_cache_path.exists():
logger.info(f"正在使用缓存的节点:{ppt_path}")
with open(node_cache_path, 'rb') as f:
nodes = pickle.load(f)
else:
# 完整处理流程
logger.info("正在通过完整流程处理PPT")
# 阶段1:PPT → 图片
image_paths = self.ppt_utils.ppt_to_images(ppt_path)
if not image_paths:
raise RuntimeError("PPT中未生成任何图片")
logger.info(f"从PPT生成了{len(image_paths)}张图片")
# 阶段2:图片 → 文本节点
nodes = await self._create_text_nodes(ppt_path, image_paths)
# 缓存节点
with open(node_cache_path, 'wb') as f:
pickle.dump(nodes, f)
logger.info(f"已缓存{len(nodes)}个节点")
logger.info(f"{nodes}")
# 阶段3:将节点添加到统一索引
logger.info(f"正在将{len(nodes)}个节点添加到统一索引")
self._index.insert_nodes(nodes)
# 持久化索引
self._persist_index()
# 更新文档元数据(使用doc_id作为key)
docs_info = self._load_docs_metadata()
docs_info[doc_id] = {
"file_path": ppt_path, # 文件路径作为字段
"doc_name": Path(ppt_path).name,
"pages": len(nodes),
"added_at": str(Path(ppt_path).stat().st_mtime),
"file_size": Path(ppt_path).stat().st_size
}
self._save_docs_metadata(docs_info)
# 重置查询引擎
self._query_engine = None
logger.info(f"成功将PPT添加到统一索引:{ppt_path}")
return {
"status": "success",
"message": "PPT已成功添加到统一索引",
"pages": len(nodes),
"source": ppt_path
}
except Exception as e:
logger.error(f"添加PPT失败 {ppt_path}: {e}")
return {
"status": "error",
"message": f"添加PPT失败:{str(e)}",
"source": ppt_path
}
def remove_ppt_document(self, ppt_path: str) -> Dict[str, Any]:
"""
从统一索引中删除PPT文档。
Args:
ppt_path: PPT文件路径
Returns:
删除结果字典
"""
ppt_path = str(Path(ppt_path).resolve())
if self._index is None:
return {
"status": "error",
"message": "索引不存在",
"source": ppt_path
}
try:
# 通过文件路径计算doc_id,然后查找文档
doc_id = self._get_doc_id_from_path(ppt_path)
doc_info = self._find_doc_by_id(doc_id)
if doc_info is None:
return {
"status": "error",
"message": "文档未在索引中找到",
"source": ppt_path,
"doc_id": doc_id
}
# 从向量存储中删除具有指定doc_id的节点
vector_store = self._index.vector_store
if hasattr(vector_store, '_collection'):
# 直接从ChromaDB集合中删除
collection = vector_store._collection
try:
collection.delete(where={"source_file_id": doc_id})
deleted_count = "unknown" # ChromaDB delete不返回计数
except Exception as e:
logger.error(f"从ChromaDB删除失败:{e}")
return {
"status": "error",
"message": f"从向量存储删除失败:{str(e)}",
"source": ppt_path
}
else:
logger.warning("无法从向量存储删除 - 不支持的操作")
return {
"status": "error",
"message": "向量存储不支持删除操作",
"source": ppt_path
}
# 更新文档元数据(删除对应的doc_id)
docs_info = self._load_docs_metadata()
if doc_id in docs_info:
del docs_info[doc_id]
self._save_docs_metadata(docs_info)
# 删除节点缓存
node_cache_path = self._get_node_cache_path(ppt_path)
if node_cache_path.exists():
node_cache_path.unlink()
logger.info(f"已删除节点缓存:{node_cache_path}")
# 删除Markdown缓存
markdown_cache_path = self._get_markdown_cache_path(ppt_path)
if markdown_cache_path.exists():
markdown_cache_path.unlink()
logger.info(f"已删除Markdown缓存:{markdown_cache_path}")
# 重新加载索引以确保一致性
logger.info("正在重新加载索引以确保一致性")
self._index = None
self._load_index()
# 重置查询引擎
self._query_engine = None
logger.info(f"成功从统一索引中删除PPT:{ppt_path}")
return {
"status": "success",
"message": f"PPT已成功从统一索引中移除",
"source": ppt_path,
"deleted_nodes": deleted_count
}
except Exception as e:
logger.error(f"删除PPT失败 {ppt_path}: {e}")
return {
"status": "error",
"message": f"删除PPT失败:{str(e)}",
"source": ppt_path
}
def get_document_info(self, doc_id: Optional[str] = None) -> Dict[str, Any]:
"""
获取文档信息。
Args:
doc_id: 特定文档ID
如果两为None则返回所有文档信息
Returns:
文档信息字典(以doc_id为key)
"""
docs_info = self._load_docs_metadata()
if doc_id is not None:
# 通过doc_id查询
if doc_id in docs_info:
return {doc_id: docs_info[doc_id]}
else:
return {}
# 返回所有文档信息
return docs_info
def _get_query_engine(self, file_path: Optional[str] = None, doc_id: Optional[str] = None) -> MultimodalQueryEngine:
"""获取或创建多模态查询引擎,支持文档过滤。
Args:
file_path: 通过文件路径过滤文档
doc_id: 通过文档ID过滤文档
注意:file_path和doc_id只能传入其中一个,如果都传入则优先使用doc_id
Returns:
配置好的多模态查询引擎
"""
if self._index is None:
raise RuntimeError("尚未索引任何文档。请先添加PPT文档。")
# 确定过滤的doc_id
filter_doc_id = None
if doc_id is not None and file_path is not None:
logger.warning("同时提供了doc_id和file_path,将使用doc_id")
if doc_id is not None:
# 直接使用提供的doc_id
filter_doc_id = doc_id
# 验证doc_id是否存在
if self._find_doc_by_id(filter_doc_id) is None:
raise ValueError(f"未找到文档ID:{filter_doc_id}")
elif file_path is not None:
# 通过文件路径计算doc_id
try:
ppt_path = str(Path(file_path).resolve())
filter_doc_id = self._get_doc_id_from_path(ppt_path)
# 验证doc_id是否存在
if self._find_doc_by_id(filter_doc_id) is None:
raise ValueError(f"文件路径对应的文档未找到:{file_path}")
except Exception as e:
raise ValueError(f"无效的文件路径:{file_path},错误:{str(e)}")
# 创建检索器
if filter_doc_id is not None:
# 创建带文档过滤的检索器
filters = MetadataFilters(
filters=[
MetadataFilter(
key="source_file_id",
value=filter_doc_id,
operator=FilterOperator.EQ
)
]
)
retriever = VectorIndexRetriever(
index=self._index,
similarity_top_k=self.top_k,
filters=filters
)
logger.info(f"已创建过滤检索器,文档ID:{filter_doc_id}")
else:
# 创建通用检索器(无过滤)
retriever = self._index.as_retriever(similarity_top_k=self.top_k)
logger.info("已创建通用检索器(无文档过滤)")
# 创建多模态查询引擎
query_engine = MultimodalQueryEngine(
retriever=retriever,
doubao_llm=self.doubao_llm
)
return query_engine
async def query(self, query: str, file_path: Optional[str] = None, doc_id: Optional[str] = None) -> Dict[str, Any]:
"""
查询所有文档或特定文档。
Args:
query: 用户查询
file_path: 可选的文档文件路径过滤器,如果指定则只在该文档中搜索
doc_id: 可选的文档ID过滤器,如果指定则只在该文档中搜索
注意:file_path和doc_id只能传入其中一个,如果都传入则优先使用doc_id
Returns:
查询结果字典
"""
try:
# 记录过滤参数
filter_info = []
if doc_id:
filter_info.append(f"doc_id={doc_id}")
if file_path:
filter_info.append(f"file_path={file_path}")
filter_desc = ", ".join(filter_info) if filter_info else "无过滤器"
logger.info(f"正在使用{filter_desc}进行查询")
# 获取查询引擎
query_engine = self._get_query_engine(file_path=file_path, doc_id=doc_id)
# 执行查询
response = await asyncio.to_thread(query_engine.query, query)
# 提取信息
result = {
"status": "success",
"query": query,
"file_path": file_path,
"doc_id": doc_id,
"answer": response.response,
"sources": [
{
"doc_name": node.metadata.get("doc_name"),
"doc_id": node.metadata.get("source_file_id"),
"page_num": node.metadata.get("page_num"),
"image_path": node.metadata.get("image_path"),
"source": node.metadata.get("source")
}
for node in response.source_nodes
],
"metadata": response.metadata or {}
}
logger.info(f"使用{len(response.source_nodes)}个来源生成了答案")
return result
except Exception as e:
logger.error(f"查询失败:{e}")
return {
"status": "error",
"query": query,
"file_path": file_path,
"doc_id": doc_id,
"message": f"查询失败:{str(e)}"
}
def get_index_status(self) -> Dict[str, Any]:
"""获取索引状态。"""
try:
docs_info = self.get_document_info()
total_pages = sum(info.get("pages", 0) for info in docs_info.values())
indexed_docs = [info["file_path"] for info in docs_info.values()]
if self._index is None:
status = "empty"
elif not docs_info:
status = "empty"
else:
status = "ready"
return {
"status": status,
"total_documents": len(docs_info),
"total_pages": total_pages,
"documents": indexed_docs,
"document_details": docs_info,
"collection_name": self.collection_name,
"index_path": str(self.index_storage_dir)
}
except Exception as e:
return {
"status": "error",
"message": str(e)
}
def print_vectorstore_info(self) -> None:
"""
打印向量存储中集合的详细信息。
包括节点数量、文档数量、文档ID和路径等信息。
"""
try:
print("\n🔍 向量存储集合信息")
print("=" * 60)
# 确保向量存储已初始化
if self._vector_store is None:
vector_store = self._initialize_vector_store()
else:
vector_store = self._vector_store
# 获取ChromaDB集合
if hasattr(vector_store, '_collection'):
collection = vector_store._collection
# 获取集合基本信息
collection_name = collection.name
print(f"📚 集合名称: {collection_name}")
# 获取所有文档
try:
result = collection.get(include=['metadatas', 'documents'])
if not result['ids']:
print("📭 集合为空,没有任何节点")
return
total_nodes = len(result['ids'])
print(f"📄 总节点数: {total_nodes}")
# 分析文档信息
doc_stats = {}
for i, metadata in enumerate(result['metadatas']):
if metadata:
doc_id = metadata.get('source_file_id', 'unknown')
doc_name = metadata.get('doc_name', 'unknown')
page_num = metadata.get('page_num', 0)
print(f"🔍 处理文档 {i + 1}/{len(result['metadatas'])}: {doc_name} (ID: {doc_id}, 页面: {page_num})")
if doc_id not in doc_stats:
doc_stats[doc_id] = {
'doc_name': doc_name,
'pages': set(),
'node_count': 0
}
doc_stats[doc_id]['pages'].add(page_num)
doc_stats[doc_id]['node_count'] += 1
print(f"📊 包含文档数: {len(doc_stats)}")
print("\n📋 文档详情:")
print("-" * 60)
# 按文档显示详细信息
for i, (doc_id, stats) in enumerate(doc_stats.items(), 1):
print(f"\n{i}. 📄 {stats['doc_name']}")
print(f" 🆔 文档ID: {doc_id}")
print(f" 📍 节点数: {stats['node_count']}")
print(f" 📃 页面数: {len(stats['pages'])}")
if stats['pages']:
page_range = f"{min(stats['pages'])}-{max(stats['pages'])}"
print(f" 📖 页面范围: {page_range}")
print("\n" + "=" * 60)
# 验证与元数据的一致性
metadata_docs = self.get_document_info()
if len(metadata_docs) != len(doc_stats):
print("⚠️ 警告: 向量存储中的文档数与元数据不一致!")
print(f" 向量存储: {len(doc_stats)} 个文档")
print(f" 元数据: {len(metadata_docs)} 个文档")
else:
print("✅ 向量存储与元数据一致")
except Exception as e:
print(f"❌ 无法获取集合数据: {e}")
else:
print("❌ 向量存储不支持此操作")
except Exception as e:
print(f"❌ 获取向量存储信息失败: {e}")
def get_vectorstore_stats(self) -> Dict[str, Any]:
"""
获取向量存储统计信息(返回字典格式)。
Returns:
包含向量存储统计信息的字典
"""
try:
# 确保向量存储已初始化
if self._vector_store is None:
vector_store = self._initialize_vector_store()
else:
vector_store = self._vector_store
# 获取ChromaDB集合
if hasattr(vector_store, '_collection'):
collection = vector_store._collection
# 获取所有文档
result = collection.get(include=['metadatas', 'documents'])
if not result['ids']:
return {
"status": "empty",
"collection_name": collection.name,
"total_nodes": 0,
"total_documents": 0,
"documents": []
}
total_nodes = len(result['ids'])
# 分析文档信息
doc_stats = {}
for metadata in result['metadatas']:
if metadata:
doc_id = metadata.get('doc_id', 'unknown')
doc_name = metadata.get('doc_name', 'unknown')
source = metadata.get('source', 'unknown')
page_num = metadata.get('page_num', 0)
if doc_id not in doc_stats:
doc_stats[doc_id] = {
'doc_id': doc_id,
'doc_name': doc_name,
'source': source,
'pages': set(),
'node_count': 0
}
doc_stats[doc_id]['pages'].add(page_num)
doc_stats[doc_id]['node_count'] += 1
# 转换页面集合为列表并排序
documents = []
for stats in doc_stats.values():
pages_list = sorted(list(stats['pages']))
documents.append({
'doc_id': stats['doc_id'],
'doc_name': stats['doc_name'],
'source': stats['source'],
'node_count': stats['node_count'],
'page_count': len(pages_list),
'page_range': f"{min(pages_list)}-{max(pages_list)}" if pages_list else "0-0"
})
return {
"status": "success",
"collection_name": collection.name,
"total_nodes": total_nodes,
"total_documents": len(doc_stats),
"documents": documents
}
else:
return {
"status": "error",
"message": "向量存储不支持此操作"
}
except Exception as e:
return {
"status": "error",
"message": f"获取向量存储统计信息失败:{str(e)}"
}
def clear_all_documents(self) -> Dict[str, Any]:
"""
清除所有文档和索引,通过循环调用 remove_ppt_document 来实现。
Returns:
清理结果
"""
try:
logger.info("开始逐个删除所有文档以清空索引")
docs_info = self.get_document_info()
if not docs_info:
logger.info("未找到需要清除的文档")
return {
"status": "success",
"message": "未找到需要清除的文档"
}
# 创建要删除的文档路径列表的副本,以避免在迭代时修改字典
doc_paths_to_remove = [doc['file_path'] for doc in docs_info.values()]
total_docs = len(doc_paths_to_remove)
logger.info(f"找到{total_docs}个文档需要删除")
all_successful = True
errors = []
for i, ppt_path in enumerate(doc_paths_to_remove):
print(f"正在删除文档 ({i + 1}/{total_docs}): {Path(ppt_path).name}")
logger.info(f"正在删除文档:{ppt_path}")
result = self.remove_ppt_document(ppt_path)
if result['status'] == 'error':
all_successful = False
error_message = f"删除失败 {ppt_path}: {result['message']}"
logger.warning(error_message)
errors.append(error_message)
# 最终状态检查
final_docs_info = self.get_document_info()
if not final_docs_info and all_successful:
message = "所有文档清除成功"
logger.info(message)
return {"status": "success", "message": message}
else:
message = f"清除所有文档完成。成功:{all_successful}。剩余文档:{len(final_docs_info)}。错误:{errors}"
logger.warning(message)
return {"status": "error" if not all_successful else "success", "message": message}
except Exception as e:
logger.error(f"清除所有文档失败:{e}", exc_info=True)
return {
"status": "error",
"message": f"清空所有文档时发生意外错误:{str(e)}"
}