import json
import asyncio
import os
import webbrowser
from typing import Optional
from contextlib import AsyncExitStack
import logging
from dotenv import load_dotenv
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
# Load settings from a local .env file before anything reads os.environ;
# connect_to_server() later hands the process environment to the server subprocess.
load_dotenv()
# Root logger at INFO so the logging.error(...) calls in this module are emitted.
logging.basicConfig(level=logging.INFO)
class EnhancedKGClient:
    """Client for the enhanced knowledge-graph MCP server.

    Focuses on the core "build + automatic content enhancement" workflow:
    it launches ``kg_server_enhanced.py`` over stdio, calls the server tools,
    and prints the returned result dictionaries to the console.
    """

    def __init__(self):
        # MCP session; stays None until connect_to_server() has succeeded.
        self.session: Optional[ClientSession] = None
        # Owns the stdio transport and the session so cleanup() closes both.
        self.exit_stack = AsyncExitStack()

    async def connect_to_server(self):
        """Start the server subprocess and open an initialized MCP session.

        Raises:
            RuntimeError: if the server does not expose the
                ``build_and_analyze_kg`` tool.
        """
        server_params = StdioServerParameters(
            command='python',
            args=['kg_server_enhanced.py'],
            # Pass a plain dict snapshot rather than the live os.environ mapping.
            env=dict(os.environ),
        )
        stdio_transport = await self.exit_stack.enter_async_context(
            stdio_client(server_params))
        stdio, write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(stdio, write))
        await self.session.initialize()

        # Verify that the tool this client depends on is actually available.
        tools_response = await self.session.list_tools()
        if not any(t.name == "build_and_analyze_kg" for t in tools_response.tools):
            raise RuntimeError("错误:服务器未提供 'build_and_analyze_kg' 工具。")

    async def _call_json_tool(self, tool_name: str, arguments: dict) -> dict:
        """Call a server tool and decode its first content item as a JSON object.

        Never raises: any failure is logged and reported as
        ``{"success": False, "error": <message>}`` so that the display
        methods can always be handed a dictionary.
        """
        try:
            if self.session is None:
                raise RuntimeError("尚未连接到服务器,请先调用 connect_to_server()。")
            result = await self.session.call_tool(tool_name, arguments)
            if not result.content:
                raise ValueError("服务器返回了空的响应内容。")
            result_text = result.content[0].text
            # A tool-level failure carries a plain error message, not JSON;
            # surface that message instead of a misleading JSON decode error.
            if getattr(result, "isError", False):
                raise RuntimeError(result_text)
            payload = json.loads(result_text)
            if not isinstance(payload, dict):
                raise ValueError(
                    f"服务器返回的结果不是 JSON 对象: {type(payload).__name__}")
            return payload
        except Exception as exc:
            logging.error(f"调用 {tool_name} 工具时出错: {exc}", exc_info=True)
            return {"success": False, "error": str(exc)}

    async def build_and_enhance_kg(self, text: str) -> dict:
        """Build a knowledge graph from ``text`` via ``build_and_analyze_kg``.

        Automatic enhancement is always requested. Returns the decoded result
        dictionary, or ``{"success": False, "error": ...}`` on any failure.
        """
        return await self._call_json_tool("build_and_analyze_kg", {
            "text": text,
            "auto_enhance": True,  # always enable automatic enhancement
        })

    async def process_file_to_cypher(self, file_path: str) -> dict:
        """Batch-process a text file via ``process_text_file_to_cypher``.

        Returns the decoded result dictionary, or
        ``{"success": False, "error": ...}`` on any failure.
        """
        return await self._call_json_tool("process_text_file_to_cypher", {
            "input_file": file_path,
        })

    @staticmethod
    def _format_seconds(value) -> str:
        """Format ``value`` with three decimals, or ``'N/A'`` if not numeric."""
        try:
            return f"{float(value):.3f}"
        except (TypeError, ValueError):
            return "N/A"

    @staticmethod
    def _print_failure(label: str, result: dict):
        """Print the error message (and optional details) of a failed result."""
        print(f"\n❌ {label}: {result.get('error', '未知错误')}")
        if 'error_details' in result:
            print("\n--- 错误详情 ---")
            print(result['error_details'])
            print("-----------------")

    def display_result(self, result: dict):
        """Print the outcome of a build-and-enhance call.

        On success, prints the processing time and the enhancement summary,
        and opens the generated visualization file in the default browser
        when that file exists on disk.
        """
        from pathlib import Path

        if not result.get("success"):
            self._print_failure("处理失败", result)
            return

        print("\n✅ 知识图谱构建与增强成功!")
        print(f"⏱️ 处理时间: {self._format_seconds(result.get('processing_time', 0))} 秒")

        stages = result.get("stages") or {}
        enhancement_summary = (
            (stages.get("enhancement_results") or {}).get("enhancement_summary") or {}
        )
        print("\n--- 增强摘要 ---")
        print(f" 原始实体数: {enhancement_summary.get('original_entity_count', 'N/A')}")
        print(f" 增强后实体数: {enhancement_summary.get('enhanced_entity_count', 'N/A')}")
        print(f" 原始关系数: {enhancement_summary.get('original_relation_count', 'N/A')}")
        print(f" 增强后关系数: {enhancement_summary.get('enhanced_relation_count', 'N/A')}")

        file_path = (stages.get("visualization") or {}).get("file_path")
        if file_path and os.path.exists(file_path):
            print(f"\n🎨 可视化文件已生成: {file_path}")
            try:
                # as_uri() yields a well-formed, percent-encoded file:// URL on
                # every platform (a hand-built "file:///" + path does not).
                opened = webbrowser.open(Path(file_path).resolve().as_uri())
            except Exception as exc:
                print(f" 无法自动打开浏览器: {exc}")
            else:
                # webbrowser.open() reports failure by returning False.
                if opened:
                    print(" 已在默认浏览器中打开。")
                else:
                    print(" 无法自动打开浏览器。")
        else:
            print("\n🎨 未生成可视化文件。")

    def display_batch_result(self, result: dict):
        """Print the outcome of a batch file-to-Cypher call."""
        if not result.get("success"):
            self._print_failure("批量处理失败", result)
            return

        print("\n✅ 批量处理成功!")
        # The time may be absent; format defensively instead of applying
        # a float format spec to the 'N/A' placeholder.
        print(f"⏱️ 处理时间: {self._format_seconds(result.get('processing_time'))} 秒")
        print("\n--- 处理摘要 ---")
        print(f" 总行数: {result.get('total_lines', 'N/A')}")
        print(f" 成功处理行数: {result.get('processed_lines', 'N/A')}")
        print(f" 失败行数: {result.get('failed_lines', 'N/A')}")
        print(f" 生成三元组总数: {result.get('total_triples_generated', 'N/A')}")

        cypher_file = result.get('cypher_script_file')
        if cypher_file:
            print(f"\n🚀 Cypher 脚本已生成: {cypher_file}")
            print(" 您可以将此文件内容复制到 Neo4j Browser 中运行以导入图谱。")
        else:
            print("\n❌ 未能生成 Cypher 脚本文件。")

    async def interactive_mode(self):
        """Run the console loop until the user enters 'quit' or interrupts.

        Input that ends in ``.txt`` and names an existing file is sent to the
        batch tool; any other non-empty input is treated as text to build a
        knowledge graph from.
        """
        print("\n🎯 增强版知识图谱客户端")
        print(" - 输入任意文本,构建并增强知识图谱。")
        print(" - 输入 .txt 文件路径 (例如: data/processed_dataset/news_sports.txt),批量处理并生成Cypher脚本。")
        print(" - 输入 'quit' 退出。")
        print("=" * 50)
        while True:
            try:
                user_input = input("\n📝 请输入文本或文件路径: ").strip()
                if user_input.lower() == 'quit':
                    break
                if not user_input:
                    continue
                # Decide whether the input is a file path or plain text.
                if user_input.lower().endswith('.txt') and os.path.isfile(user_input):
                    print(f"\n🔄 检测到文件路径,开始批量处理 '{user_input}'...")
                    result = await self.process_file_to_cypher(user_input)
                    self.display_batch_result(result)
                else:
                    print("\n🔄 正在构建并增强知识图谱,请稍候...")
                    result = await self.build_and_enhance_kg(user_input)
                    self.display_result(result)
            except (KeyboardInterrupt, EOFError):
                print("\n\n👋 再见!")
                break
            except Exception as exc:
                # Keep the loop alive: one bad request should not end the session.
                logging.error(f"交互模式中发生未知错误: {exc}", exc_info=True)

    async def cleanup(self):
        """Close the session and the stdio transport held by the exit stack."""
        if self.exit_stack:
            await self.exit_stack.aclose()
        # The session is unusable once its transport has been closed.
        self.session = None
async def main():
    """Entry coroutine: connect to the server, run the console loop, clean up.

    Any exception raised while connecting or interacting is logged rather
    than propagated; the client's resources are released in every case.
    """
    kg_client = EnhancedKGClient()
    try:
        try:
            print("🔗 正在连接到增强版知识图谱服务器...")
            await kg_client.connect_to_server()
            print("✅ 连接成功!")
            await kg_client.interactive_mode()
        except Exception as exc:
            logging.error(f"启动或连接服务器时发生致命错误: {exc}", exc_info=True)
    finally:
        # Runs on success, on a logged failure, and on cancellation alike.
        print("\n shutting down...")
        await kg_client.cleanup()
if __name__ == "__main__":
    # Script entry point: drive main() on a fresh event loop and turn a
    # Ctrl+C that escapes the loop into a short message instead of a traceback.
    entry_coroutine = main()
    try:
        asyncio.run(entry_coroutine)
    except KeyboardInterrupt:
        print("\n程序已中断。")