"""
多文档PPT RAG引擎的MCP服务器
"""
import asyncio
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, AsyncIterator
import logging
from contextlib import asynccontextmanager
import argparse
# 添加src目录到Python路径
current_dir = Path(__file__).parent
src_dir = current_dir / "src"
sys.path.insert(0, str(src_dir))
# MCP FastMCP导入
from mcp.server.fastmcp.server import FastMCP, Context
# 加载环境变量
from dotenv import load_dotenv
load_dotenv()
# 本地导入
from multi_doc_rag_engine import MultiDocRAGEngine
# 配置日志
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# 定义应用程序上下文类
class AppContext:
def __init__(self):
self.rag_engine: Optional[MultiDocRAGEngine] = None
# 创建生命周期管理器
@asynccontextmanager
async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
"""管理应用程序生命周期,包括多文档RAG引擎初始化
Args:
server: FastMCP服务器实例
Returns:
应用程序上下文对象
"""
# 启动初始化
logger.info("应用程序正在启动,初始化资源...")
app_ctx = AppContext()
try:
# 从环境变量获取配置
cache_dir = os.getenv("CACHE_DIRECTORY", "./cache")
chroma_dir = os.getenv("CHROMA_PERSIST_DIRECTORY", "./chroma_db")
doubao_model = os.getenv("DOUBAO_MODEL", "ep-20250205153642-hzqpj")
# 初始化多文档RAG引擎
app_ctx.rag_engine = MultiDocRAGEngine(
persist_dir=chroma_dir,
cache_dir=cache_dir,
doubao_model=doubao_model
)
logger.info("多文档RAG引擎初始化成功")
# 将上下文传递给应用程序
yield app_ctx
finally:
# 关闭时清理资源
logger.info("应用程序正在关闭,清理资源...")
# 创建MCP服务器实例
app = FastMCP("multi-doc-ppt-rag", port=5053, lifespan=app_lifespan)
@app.tool()
async def add_ppt(
ctx: Context,
file_path: str,
force_reprocess: bool = False
) -> str:
"""将指定的PPT文档添加到RAG索引中
Args:
ctx: 上下文对象
file_path: 要添加的PPT文件的绝对或相对路径
force_reprocess: 是否强制重新处理,即使文档已存在于索引中
Returns:
操作结果的JSON字符串
"""
try:
if not file_path:
return json.dumps({"status": "error", "message": "需要提供file_path参数"}, ensure_ascii=False)
# 将相对路径转换为绝对路径
abs_path = str(Path(file_path).resolve())
if not os.path.exists(abs_path):
return json.dumps({"status": "error", "message": f"文件未找到: {abs_path}"}, ensure_ascii=False)
# 从上下文获取RAG引擎
rag_engine = ctx.request_context.lifespan_context.rag_engine
logger.info(f"开始添加PPT文档: {abs_path}")
await ctx.info(f"开始添加PPT文档: {abs_path}")
result = await rag_engine.add_ppt_document(abs_path, force_reprocess=force_reprocess)
logger.info(f"PPT文档添加完成: {result['status']}")
await ctx.info(f"PPT文档添加完成: {result['status']}")
return json.dumps(result, indent=2, ensure_ascii=False)
except Exception as e:
error_msg = f"添加PPT文档时出错: {str(e)}"
logger.error(error_msg)
return json.dumps({"status": "error", "message": error_msg}, ensure_ascii=False)
@app.tool()
def delete_ppt(
ctx: Context,
file_path: str
) -> str:
"""从RAG索引中删除指定的PPT文档
Args:
ctx: 上下文对象
file_path: 要删除的PPT文件的绝对或相对路径
Returns:
操作结果的JSON字符串
"""
try:
if not file_path:
return json.dumps({"status": "error", "message": "需要提供file_path参数"}, ensure_ascii=False)
# 将相对路径转换为绝对路径
abs_path = str(Path(file_path).resolve())
# 从上下文获取RAG引擎
rag_engine = ctx.request_context.lifespan_context.rag_engine
logger.info(f"开始删除PPT文档: {abs_path}")
if not rag_engine.is_document_indexed(abs_path):
return json.dumps({"status": "error", "message": f"文档不在索引中: {abs_path}"}, ensure_ascii=False)
result = rag_engine.remove_ppt_document(abs_path)
logger.info(f"PPT文档删除完成: {result['status']}")
return json.dumps(result, indent=2, ensure_ascii=False)
except Exception as e:
error_msg = f"删除PPT文档时出错: {str(e)}"
logger.error(error_msg)
return json.dumps({"status": "error", "message": error_msg}, ensure_ascii=False)
@app.tool()
async def chat_with_ppt(
ctx: Context,
query: str,
file_path: Optional[str] = None,
doc_id: Optional[str] = None
) -> str:
"""与一个或所有PPT文档进行对话
Args:
ctx: 上下文对象
query: 用户提出的问题
file_path: (可选) 指定要查询的PPT文件路径。如果提供,则只在该文档中搜索
doc_id: (可选) 指定要查询的文档ID。如果提供,则只在该文档中搜索。如果同时提供了file_path和doc_id,则优先使用doc_id
Returns:
查询结果的JSON字符串
"""
try:
if not query:
return json.dumps({"status": "error", "message": "需要提供query参数"}, ensure_ascii=False)
# 从上下文获取RAG引擎
rag_engine = ctx.request_context.lifespan_context.rag_engine
logger.info(f"开始查询: {query}")
await ctx.info(f"开始查询: {query}")
result = await rag_engine.query(query, file_path=file_path, doc_id=doc_id)
logger.info(f"查询完成: {result['status']}")
await ctx.info(f"查询完成: {result['status']}")
return json.dumps(result, indent=2, ensure_ascii=False)
except Exception as e:
error_msg = f"查询时出错: {str(e)}"
logger.error(error_msg)
return json.dumps({"status": "error", "message": error_msg}, ensure_ascii=False)
@app.tool()
def index_status(
ctx: Context
) -> str:
"""获取当前RAG索引的状态和统计信息
Args:
ctx: 上下文对象
Returns:
索引状态的JSON字符串
"""
try:
# 从上下文获取RAG引擎
rag_engine = ctx.request_context.lifespan_context.rag_engine
logger.info("获取索引状态")
result = rag_engine.get_index_status()
logger.info(f"索引状态: {result['status']}")
return json.dumps(result, indent=2, ensure_ascii=False)
except Exception as e:
error_msg = f"获取索引状态时出错: {str(e)}"
logger.error(error_msg)
return json.dumps({"status": "error", "message": error_msg}, ensure_ascii=False)
def get_available_tools():
"""获取当前服务器提供的所有工具"""
tools = app._tool_manager.list_tools()
return tools
def main():
"""MCP服务器的主入口点"""
# 命令行参数
parser = argparse.ArgumentParser(description="启动多文档PPT RAG MCP服务器")
parser.add_argument("--transport", choices=["sse", "stdio"], default="stdio",
help="选择传输方式: sse 或 stdio")
parser.add_argument("--port", type=int, default=5053, help="指定服务器端口号(仅用于SSE)")
args = parser.parse_args()
# 更新端口设置
if args.transport == "sse":
current_settings = app.settings
current_settings.port = args.port
app.settings = current_settings
# 列出可用工具
logger.info("\n")
logger.info("==== 多文档PPT RAG Server可用工具 ====")
tools = app._tool_manager.list_tools()
if tools:
for i, tool in enumerate(tools, 1):
description = tool.description.strip().split("\n")[0] if tool.description else "无描述"
logger.info(f"{i}. {tool.name}: {description}")
else:
logger.info("未发现工具函数")
logger.info("=================================\n")
# 启动服务器
logger.info(f"启动多文档PPT RAG服务器,传输方式: {args.transport}, 端口: {args.port if args.transport == 'sse' else 'N/A'}")
app.run(transport=args.transport)
if __name__ == "__main__":
main()