# rag_interactive_test.py (17.2 kB)
#!/usr/bin/env python3
"""
多文档RAG引擎交互式测试工具
支持添加文档、查询、删除文档等操作
"""
import asyncio
import logging
import os
import sys
from pathlib import Path
# Make the sibling ``src`` directory importable so that
# ``multi_doc_rag_engine`` can be found without installing it.
current_dir = Path(__file__).parent
src_dir = current_dir / "src"
sys.path.insert(0, str(src_dir))
try:
    from multi_doc_rag_engine import MultiDocRAGEngine
except ImportError as e:
    # Fatal start-up errors go to stderr so they are not mixed with normal output.
    print(f"导入错误: {e}", file=sys.stderr)
    print("请确保multi_doc_rag_engine.py文件存在于src目录中", file=sys.stderr)
    sys.exit(1)

# Configure root logging at INFO level.
# NOTE(review): the original comment described this as "reduce log output",
# but INFO is not a reduced level. Switch to logging.WARNING if the engine's
# INFO messages clutter the interactive menu.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class InteractiveRAGTester:
    """Interactive command-line tester for ``MultiDocRAGEngine``.

    Shows a numbered menu and forwards each choice to the engine: adding PPT
    documents, querying all documents or a single one, inspecting the index,
    removing or clearing documents, and previewing the parsed-Markdown cache.
    """

    # Engine storage settings. Relative paths resolve against the current
    # working directory. ``view_markdown_cache`` reads
    # ``<CACHE_DIR>/parsed_markdown``, so the cache location is kept in one place.
    PERSIST_DIR = "./interactive_test_db"
    CACHE_DIR = "./interactive_test_cache"
    COLLECTION_NAME = "interactive_test_docs"
    TOP_K = 3
    # Directory scanned for candidate PPT files by ``add_document``.
    DATA_DIR = "./data"
    # Number of lines shown when previewing a cached Markdown file.
    PREVIEW_LINES = 20

    def __init__(self):
        # Set by initialize_engine(); None before that and after a failed
        # re-initialisation.
        self.engine = None
        self.initialized = False

    # ------------------------------------------------------------------ setup

    def check_environment(self):
        """Return True if OPENAI_API_KEY and ARK_API_KEY are set (non-empty).

        Prints which variable is missing and returns False otherwise.
        """
        if not os.getenv("OPENAI_API_KEY"):
            print("❌ 请设置 OPENAI_API_KEY 环境变量")
            return False
        if not os.getenv("ARK_API_KEY"):
            print("❌ 请设置 ARK_API_KEY 环境变量")
            return False
        return True

    def initialize_engine(self):
        """Create the RAG engine once.

        Returns:
            True if the engine exists (already or newly created), False if the
            constructor raised; the exception text is printed.
        """
        if self.initialized:
            return True
        try:
            print("🚀 正在初始化多文档RAG引擎...")
            self.engine = MultiDocRAGEngine(
                persist_dir=self.PERSIST_DIR,
                cache_dir=self.CACHE_DIR,
                collection_name=self.COLLECTION_NAME,
                top_k=self.TOP_K,
            )
        except Exception as e:
            # Broad on purpose: any constructor failure is reported to the
            # user and the caller decides whether to continue.
            print(f"❌ 初始化失败: {e}")
            return False
        self.initialized = True
        print("✅ RAG引擎初始化成功!")
        return True

    def _ensure_engine(self):
        """Return True if an engine is available; otherwise print a hint.

        The engine is None after a failed re-initialisation (menu option 9).
        Without this check, the actions would fail with an AttributeError.
        """
        if self.engine is None:
            print("❌ 引擎未初始化,请先选择 9 重新初始化引擎")
            return False
        return True

    # ---------------------------------------------------------------- helpers

    @staticmethod
    def _read_path(prompt):
        """Read a file path from stdin.

        Strips surrounding whitespace and one pair of matching quotes, since
        paths copied from a file manager are often quoted.
        """
        raw = input(prompt).strip()
        if len(raw) >= 2 and raw[0] == raw[-1] and raw[0] in "'\"":
            raw = raw[1:-1].strip()
        return raw

    def _choose_ppt_path(self):
        """Ask the user which PPT file to add.

        Offers the files found in DATA_DIR plus a "type a path" option.

        Returns:
            The chosen path as a string (possibly empty), or None if the menu
            selection was invalid. In that case a message has been printed.
        """
        data_dir = Path(self.DATA_DIR)
        # Sorted so the numbering is stable between runs; directories that
        # happen to match the pattern are skipped.
        ppt_files = (
            sorted(p for p in data_dir.glob("*.ppt*") if p.is_file())
            if data_dir.is_dir()
            else []
        )
        if not ppt_files:
            return self._read_path("请输入PPT文件路径: ")

        print("📁 在data目录中找到以下PPT文件:")
        for i, ppt_file in enumerate(ppt_files, 1):
            print(f" {i}. {ppt_file.name}")
        manual_option = len(ppt_files) + 1
        print(f" {manual_option}. 手动输入路径")

        choice = input(f"\n请选择文件 (1-{manual_option}): ").strip()
        try:
            choice_num = int(choice)
        except ValueError:
            print("❌ 请输入有效的数字")
            return None
        if 1 <= choice_num <= len(ppt_files):
            return str(ppt_files[choice_num - 1])
        if choice_num == manual_option:
            return self._read_path("请输入PPT文件路径: ")
        print("❌ 无效选择")
        return None

    def _select_document(self, prompt):
        """List the indexed documents and let the user pick one.

        Args:
            prompt: Question shown before the ``(1-N)`` range.

        Returns:
            ``(doc_id, info)`` for the chosen document, or None when nothing
            is indexed or the input is invalid. A message is printed either way.
        """
        docs_info = self.engine.get_document_info()
        if not docs_info:
            print("❌ 没有已索引的文档")
            return None

        print("📚 已索引的文档:")
        doc_list = list(docs_info.items())
        for i, (_, info) in enumerate(doc_list, 1):
            print(f" {i}. {info['doc_name']} ({info['pages']}页)")

        choice = input(f"\n{prompt} (1-{len(doc_list)}): ").strip()
        try:
            choice_num = int(choice)
        except ValueError:
            print("❌ 请输入有效的数字")
            return None
        if not 1 <= choice_num <= len(doc_list):
            print("❌ 无效选择")
            return None
        return doc_list[choice_num - 1]

    @staticmethod
    def _print_query_result(result, *, show_doc_name, empty_message):
        """Pretty-print a result dict returned by ``engine.query``.

        Args:
            result: Dict with "status" and either "answer"/"sources" on
                success or "message" on failure.
            show_doc_name: Include each source's document name in the listing.
            empty_message: Line printed when the answer has no sources.
        """
        if result.get("status") != "success":
            print(f"❌ 查询失败: {result.get('message', '未知错误')}")
            return

        print("\n✅ 查询成功!")
        print("📝 回答:")
        print("-" * 40)
        print(result.get("answer", ""))
        print("-" * 40)

        sources = result.get("sources") or []
        if not sources:
            print(empty_message)
            return
        print(f"\n📚 来源信息 ({len(sources)} 个):")
        for i, source in enumerate(sources, 1):
            if show_doc_name:
                print(f" {i}. {source['doc_name']} - 第{source['page_num']}页")
            else:
                print(f" {i}. 第{source['page_num']}页")

    # ------------------------------------------------------------ menu items

    def show_menu(self):
        """Print the main menu."""
        print("\n" + "="*60)
        print("🎯 多文档RAG引擎交互式测试")
        print("="*60)
        print("1. 📄 添加PPT文档")
        print("2. 🔍 查询文档 (全文档)")
        print("3. 🎯 查询特定文档")
        print("4. 📊 查看索引状态")
        print("5. 📋 列出所有文档")
        print("6. 🗑️ 删除文档")
        print("7. 🧹 清空所有文档")
        print("8. 📝 查看Markdown缓存")
        print("9. 🔄 重新初始化引擎")
        print("0. 🚪 退出")
        print("-"*60)

    async def add_document(self):
        """Ask for a PPT file and add it to the index via the engine."""
        print("\n📄 添加PPT文档")
        print("-"*30)
        if not self._ensure_engine():
            return

        ppt_path = self._choose_ppt_path()
        if ppt_path is None:
            return
        if not ppt_path:
            print("❌ 路径不能为空")
            return
        path = Path(ppt_path)
        # is_file() rather than exists(): a directory cannot be added.
        if not path.is_file():
            print(f"❌ 文件不存在: {ppt_path}")
            return

        force_reprocess = input("是否强制重新处理? (y/N): ").lower().strip() == 'y'
        print(f"\n🔄 正在添加文档: {path.name}")
        try:
            result = await self.engine.add_ppt_document(ppt_path, force_reprocess=force_reprocess)
        except Exception as e:
            print(f"❌ 添加过程中发生错误: {e}")
            return

        status = result.get("status")
        if status == "success":
            print("✅ 文档添加成功!")
            print(f" 页数: {result.get('pages', 'N/A')}")
        elif status == "skipped":
            print("⏭️ 文档已存在,跳过添加")
            print(f" 文档ID: {result.get('doc_id', 'N/A')}")
        else:
            print(f"❌ 添加失败: {result.get('message', '未知错误')}")

    async def query_all_documents(self):
        """Ask a question and query across every indexed document."""
        print("\n🔍 查询所有文档")
        print("-"*30)
        if not self._ensure_engine():
            return

        query = input("请输入查询问题: ").strip()
        if not query:
            print("❌ 查询不能为空")
            return

        print(f"\n🔄 正在查询: {query}")
        try:
            result = await self.engine.query(query)
            self._print_query_result(
                result, show_doc_name=True, empty_message="🤷 没有找到相关来源"
            )
        except Exception as e:
            print(f"❌ 查询过程中发生错误: {e}")

    async def query_specific_document(self):
        """Let the user pick one document, then query only that document."""
        print("\n🎯 查询特定文档")
        print("-"*30)
        if not self._ensure_engine():
            return

        selection = self._select_document("请选择文档")
        if selection is None:
            return
        doc_id, doc_info = selection

        query = input("请输入查询问题: ").strip()
        if not query:
            print("❌ 查询不能为空")
            return

        print(f"\n🔄 正在查询文档 '{doc_info['doc_name']}': {query}")
        try:
            result = await self.engine.query(query, doc_id=doc_id)
            self._print_query_result(
                result, show_doc_name=False, empty_message="🤷 在该文档中没有找到相关内容"
            )
        except Exception as e:
            print(f"❌ 查询过程中发生错误: {e}")

    def show_index_status(self):
        """Print the summary returned by ``engine.get_index_status()``."""
        print("\n📊 索引状态")
        print("-"*30)
        if not self._ensure_engine():
            return
        try:
            status = self.engine.get_index_status()
            print(f"状态: {status['status']}")
            print(f"文档数量: {status['total_documents']}")
            print(f"总页数: {status['total_pages']}")
            print(f"集合名称: {status['collection_name']}")
            print(f"索引路径: {status['index_path']}")
            if status.get('documents'):
                print("\n📄 已索引文档:")
                for i, doc_path in enumerate(status['documents'], 1):
                    print(f" {i}. {Path(doc_path).name}")
        except Exception as e:
            print(f"❌ 获取状态失败: {e}")

    def list_documents(self):
        """Print name, short ID, page count, size and path of every document."""
        print("\n📋 所有文档列表")
        print("-"*30)
        if not self._ensure_engine():
            return
        try:
            docs_info = self.engine.get_document_info()
            if not docs_info:
                print("📭 没有已索引的文档")
                return
            for i, (doc_id, info) in enumerate(docs_info.items(), 1):
                print(f"\n{i}. 📄 {info['doc_name']}")
                print(f" ID: {doc_id[:8]}...")
                print(f" 页数: {info['pages']}")
                print(f" 大小: {info['file_size']:,} 字节")
                print(f" 路径: {info['file_path']}")
        except Exception as e:
            print(f"❌ 获取文档列表失败: {e}")

    def remove_document(self):
        """Let the user pick a document and remove it after confirmation."""
        print("\n🗑️ 删除文档")
        print("-"*30)
        if not self._ensure_engine():
            return

        selection = self._select_document("请选择要删除的文档")
        if selection is None:
            return
        _, doc_info = selection

        confirm = input(f"确认删除 '{doc_info['doc_name']}'? (y/N): ").lower().strip()
        if confirm != 'y':
            print("🚫 删除取消")
            return

        print(f"\n🔄 正在删除文档: {doc_info['doc_name']}")
        try:
            # The engine removes documents by file path, not by doc_id.
            result = self.engine.remove_ppt_document(doc_info['file_path'])
            if result.get("status") == "success":
                print("✅ 文档删除成功!")
            else:
                print(f"❌ 删除失败: {result.get('message', '未知错误')}")
        except Exception as e:
            print(f"❌ 删除过程中发生错误: {e}")

    def clear_all_documents(self):
        """Clear every document via the engine after an explicit 'yes'."""
        print("\n🧹 清空所有文档")
        print("-"*30)
        if not self._ensure_engine():
            return

        # A full word is required here because the operation cannot be undone.
        confirm = input("⚠️ 确认清空所有文档和缓存? 此操作不可逆! (yes/N): ").strip()
        if confirm.lower() != 'yes':
            print("🚫 操作取消")
            return

        print("\n🔄 正在清空所有文档...")
        try:
            result = self.engine.clear_all_documents()
            if result.get("status") == "success":
                print("✅ 所有文档已清空!")
                print("💡 您可以重新添加文档")
            else:
                print(f"❌ 清空失败: {result.get('message', '未知错误')}")
        except Exception as e:
            print(f"❌ 清空过程中发生错误: {e}")

    def view_markdown_cache(self):
        """List ``*.md`` files under ``<CACHE_DIR>/parsed_markdown`` and preview one.

        Shows the first PREVIEW_LINES lines of the chosen file. Does not need
        the engine; it reads the cache directory directly.
        """
        print("\n📝 查看Markdown缓存")
        print("-"*30)
        try:
            markdown_dir = Path(self.CACHE_DIR) / "parsed_markdown"
            if not markdown_dir.exists():
                print("📭 没有找到Markdown缓存目录")
                return

            markdown_files = sorted(markdown_dir.glob("*.md"))
            if not markdown_files:
                print("📭 没有找到Markdown缓存文件")
                return

            print("📁 Markdown缓存文件:")
            for i, md_file in enumerate(markdown_files, 1):
                print(f" {i}. {md_file.name} ({md_file.stat().st_size:,} 字节)")

            choice = input(f"\n选择要查看的文件 (1-{len(markdown_files)}, 回车跳过): ").strip()
            if not choice:
                return
            # Only int() is guarded here. UnicodeDecodeError is a ValueError
            # subclass, so a wider try would report a bad file as bad input.
            try:
                choice_num = int(choice)
            except ValueError:
                print("❌ 请输入有效的数字")
                return
            if not 1 <= choice_num <= len(markdown_files):
                print("❌ 无效选择")
                return

            selected_file = markdown_files[choice_num - 1]
            with selected_file.open('r', encoding='utf-8') as f:
                lines = f.readlines()

            print(f"\n📖 文件预览: {selected_file.name}")
            print("=" * 50)
            for i, line in enumerate(lines[:self.PREVIEW_LINES], 1):
                print(f"{i:2d}: {line.rstrip()}")
            if len(lines) > self.PREVIEW_LINES:
                print(f"... ({len(lines) - self.PREVIEW_LINES} 行省略)")
            print("=" * 50)
            print(f"完整路径: {selected_file.absolute()}")
        except Exception as e:
            print(f"❌ 查看缓存失败: {e}")

    def reinitialize_engine(self):
        """Drop the current engine and build a fresh one."""
        print("\n🔄 重新初始化引擎")
        print("-"*30)
        self.engine = None
        self.initialized = False
        if self.initialize_engine():
            print("✅ 引擎重新初始化成功!")
        else:
            print("❌ 引擎重新初始化失败!")

    # -------------------------------------------------------------- main loop

    async def run(self):
        """Run the interactive menu loop until the user exits.

        Returns early if the environment check or engine initialisation fails.
        Ctrl-C or end of input (Ctrl-D, or stdin running out) ends the loop
        cleanly instead of raising.
        """
        print("🎉 欢迎使用多文档RAG引擎交互式测试工具!")
        print("本工具可以帮助您测试文档添加、查询、删除等功能")

        if not self.check_environment():
            return
        if not self.initialize_engine():
            return

        # Menu choice -> handler. Coroutine handlers are awaited; plain ones called.
        async_actions = {
            "1": self.add_document,
            "2": self.query_all_documents,
            "3": self.query_specific_document,
        }
        sync_actions = {
            "4": self.show_index_status,
            "5": self.list_documents,
            "6": self.remove_document,
            "7": self.clear_all_documents,
            "8": self.view_markdown_cache,
            "9": self.reinitialize_engine,
        }

        while True:
            try:
                self.show_menu()
                choice = input("请选择操作 (0-9): ").strip()
                if choice == "0":
                    print("\n👋 感谢使用,再见!")
                    break
                if choice in async_actions:
                    await async_actions[choice]()
                elif choice in sync_actions:
                    sync_actions[choice]()
                else:
                    print("❌ 无效选择,请输入0-9之间的数字")
                input("\n按回车键继续...")
            except (KeyboardInterrupt, EOFError):
                # EOFError: input() reached end of stdin. It must be handled
                # here; the generic handler below calls input() again.
                print("\n\n👋 用户中断,退出程序")
                break
            except Exception as e:
                print(f"\n❌ 发生错误: {e}")
                try:
                    input("按回车键继续...")
                except (KeyboardInterrupt, EOFError):
                    print("\n\n👋 用户中断,退出程序")
                    break
async def main():
    """Create the interactive tester and run its menu loop."""
    tester = InteractiveRAGTester()
    await tester.run()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C while the event loop is awaiting (e.g. an engine call)
        # surfaces here as KeyboardInterrupt from asyncio.run(). Exit quietly
        # instead of printing a traceback.
        print("\n\n👋 用户中断,退出程序")