"""Graphiti 客户端 - 用于与知识图谱交互."""
import os
from typing import Optional, Dict, Any, List
from datetime import datetime
from neo4j import GraphDatabase
from neo4j.graph import Node, Relationship
try:
from neo4j.time import DateTime as Neo4jDateTime
except ImportError:
Neo4jDateTime = None
from .config_manager import ConfigManager
from .logger import default_logger as logger
from .cache_manager import get_cache_manager
# Helper: convert Neo4j objects into serializable Python objects
def _neo4j_to_dict(obj: Any) -> Any:
    """Convert a Neo4j object (Node, Relationship, DateTime, etc.) into a serializable Python object."""
if isinstance(obj, Node):
        # Handle Node objects
result = {}
for key, value in dict(obj).items():
result[key] = _neo4j_to_dict(value)
return result
elif isinstance(obj, Relationship):
        # Handle Relationship objects
result = {
"type": obj.type,
"start_node": _neo4j_to_dict(obj.start_node),
"end_node": _neo4j_to_dict(obj.end_node)
}
        # Add the relationship's properties
for key, value in dict(obj).items():
result[key] = _neo4j_to_dict(value)
return result
elif isinstance(obj, datetime):
        # Handle Python datetime objects
return obj.isoformat()
elif Neo4jDateTime is not None and isinstance(obj, Neo4jDateTime):
        # Handle Neo4j DateTime objects
try:
return obj.isoformat()
except (AttributeError, TypeError, ValueError):
return str(obj)
elif hasattr(obj, '__class__') and 'DateTime' in obj.__class__.__name__:
        # Handle any object whose type name contains "DateTime" (fallback)
try:
if hasattr(obj, 'isoformat'):
return obj.isoformat()
else:
return str(obj)
except (AttributeError, TypeError, ValueError):
return str(obj)
elif hasattr(obj, 'isoformat') and callable(getattr(obj, 'isoformat', None)):
        # Handle any other object with an isoformat method
try:
return obj.isoformat()
except (AttributeError, TypeError, ValueError):
try:
return str(obj)
            except Exception:
return repr(obj)
elif isinstance(obj, (list, tuple)):
        # Handle lists and tuples
return [_neo4j_to_dict(item) for item in obj]
elif isinstance(obj, dict):
        # Handle dicts
return {key: _neo4j_to_dict(value) for key, value in obj.items()}
else:
        # Return any other type unchanged
return obj
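# Illustrative note (not executed): given a driver value such as
#   {"name": "Alice", "created_at": <neo4j.time.DateTime 2024-01-01T00:00:00Z>, "tags": ["a", "b"]}
# _neo4j_to_dict is expected to return something like
#   {"name": "Alice", "created_at": "2024-01-01T00:00:00+00:00", "tags": ["a", "b"]}
# i.e. datetimes become ISO-8601 strings and containers are converted recursively.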
# Lazily import graphiti-core to avoid blocking at module import time.
# Import it only when it is actually needed.
GRAPHITI_AVAILABLE = None
Graphiti = None
def _lazy_import_graphiti():
"""延迟导入 graphiti-core."""
global GRAPHITI_AVAILABLE, Graphiti
if GRAPHITI_AVAILABLE is None:
try:
from graphiti_core import Graphiti
GRAPHITI_AVAILABLE = True
except ImportError:
GRAPHITI_AVAILABLE = False
Graphiti = None
return GRAPHITI_AVAILABLE and Graphiti is not None
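# Usage sketch (assumption, mirroring how connect()/_initialize_graphiti() call this guard):
#   if _lazy_import_graphiti():
#       ...  # graphiti-core is installed; Graphiti(...) can be constructed
#   else:
#       ...  # fall back to the driver-only basic mode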
class GraphitiClient:
"""Graphiti 客户端,用于与 Neo4j 知识图谱交互."""
def __init__(self, config_manager: ConfigManager):
"""
初始化 Graphiti 客户端.
Args:
config_manager: 配置管理器
"""
self.config_manager = config_manager
self.driver: Optional[Any] = None
        self.graphiti: Optional[Any] = None  # Graphiti instance
self._connected = False
self._graphiti_initialized = False
        self._last_neo4j_config: Optional[Dict[str, Any]] = None  # Config used for the previous connection
        self._last_api_config: Optional[Dict[str, Any]] = None  # API config seen previously
        self.cache = get_cache_manager()  # Cache manager
def connect(self) -> bool:
"""
连接到 Neo4j 数据库.
Returns:
连接是否成功
"""
# 测试模式:不连接真实数据库
if os.environ.get('GRAPHITIACE_TEST_MODE', '').lower() == 'true':
logger.info("测试模式:跳过数据库连接")
return False
neo4j_config = self.config_manager.get_neo4j_config()
api_config = self.config_manager.get_api_config()
if neo4j_config is None:
return False
        # Record a fingerprint of the current configuration
current_config_dict = neo4j_config.model_dump()
if api_config:
self._last_api_config = api_config.model_dump()
try:
            # Use a short connection timeout to avoid blocking at startup
self.driver = GraphDatabase.driver(
neo4j_config.uri,
auth=(neo4j_config.username, neo4j_config.password),
                connection_timeout=3.0,  # 3-second connection timeout
                max_connection_lifetime=3600,  # 1-hour maximum connection lifetime
                connection_acquisition_timeout=3.0  # 3-second connection acquisition timeout
)
            # Verify the connection (using the short timeout)
with self.driver.session() as session:
session.run("RETURN 1")
self._connected = True
self._last_neo4j_config = current_config_dict
            # If an API key is configured, initialize a Graphiti instance for entity extraction
self._initialize_graphiti()
return True
except Exception as e:
error_msg = f"连接 Neo4j 失败: {e}"
logger.error(error_msg, exc_info=True)
            # Provide more detailed error hints
if "authentication" in str(e).lower() or "password" in str(e).lower():
error_msg += "\n💡 提示:请检查用户名和密码是否正确"
elif "connection" in str(e).lower() or "refused" in str(e).lower():
error_msg += "\n💡 提示:请检查 Neo4j 是否正在运行,URI 是否正确"
self._connected = False
            # Make sure the driver is cleaned up
if self.driver:
try:
self.driver.close()
                except Exception:
pass
self.driver = None
return False
def check_reconnect(self) -> bool:
"""
检查配置是否变更,如果变更则重连.
Returns:
是否进行了重连(True=重连了,False=无需重连或重连失败)
"""
# 强制重新加载配置
self.config_manager.load_config()
neo4j_config = self.config_manager.get_neo4j_config()
api_config = self.config_manager.get_api_config()
if neo4j_config is None:
return False
current_config_dict = neo4j_config.model_dump()
current_api_config_dict = api_config.model_dump() if api_config else None
        # 1. Check whether the Neo4j configuration changed
if not self.is_connected() or current_config_dict != self._last_neo4j_config:
if self.is_connected():
logger.info("检测到 Neo4j 配置变更,正在重连...")
self.disconnect()
else:
logger.info("Neo4j 未连接,正在连接...")
return self.connect()
        # 2. Check whether the API configuration changed (only when the Neo4j connection is up)
if current_api_config_dict != self._last_api_config:
logger.info("检测到 API 配置变更,重新初始化 Graphiti Core...")
self._last_api_config = current_api_config_dict
            self._graphiti_initialized = False  # Force re-initialization
self._initialize_graphiti()
return True
return False
def check_api_config_change(self) -> bool:
"""
检查 API 配置是否变更.
Returns:
是否有变更(True=有变更,False=无变更)
"""
api_config = self.config_manager.get_api_config()
current_api_config_dict = api_config.model_dump() if api_config else None
if current_api_config_dict != self._last_api_config:
self._last_api_config = current_api_config_dict
return True
return False
def _initialize_graphiti(self) -> bool:
"""
初始化 Graphiti 实例(如果配置了 API key).
Returns:
是否成功初始化
"""
if self._graphiti_initialized:
return self.graphiti is not None
        # Lazily import graphiti-core
if not _lazy_import_graphiti():
self._graphiti_initialized = True
return False
api_config = self.config_manager.get_api_config()
neo4j_config = self.config_manager.get_neo4j_config()
        # Without an API key, do not initialize Graphiti (basic mode)
if not api_config or not api_config.api_key:
self._graphiti_initialized = True
return False
try:
            # Initialize the Graphiti instance
            # Note: environment variables are set here based on the actual API configuration
            # Temporarily set environment variables (for OpenAI or compatible endpoints)
if api_config.provider == "openai":
os.environ["OPENAI_API_KEY"] = api_config.api_key
if api_config.base_url:
os.environ["OPENAI_API_BASE"] = api_config.base_url
os.environ["OPENAI_BASE_URL"] = api_config.base_url
if api_config.model:
os.environ["OPENAI_MODEL_NAME"] = api_config.model
os.environ["MODEL_NAME"] = api_config.model
elif api_config.provider == "anthropic":
os.environ["ANTHROPIC_API_KEY"] = api_config.api_key
if api_config.model:
os.environ["ANTHROPIC_MODEL"] = api_config.model
self.graphiti = Graphiti(
uri=neo4j_config.uri,
user=neo4j_config.username,
password=neo4j_config.password
)
self._graphiti_initialized = True
return True
except Exception as e:
            # Log the error (goes to stderr, so it does not interfere with the MCP protocol)
logger.error(f"初始化 Graphiti 失败: {e}")
self.graphiti = None
self._graphiti_initialized = True
return False
def is_connected(self) -> bool:
"""检查是否已连接."""
return self._connected and self.driver is not None
def disconnect(self):
"""断开连接."""
if self.graphiti:
try:
import asyncio
                # graphiti.close() is a coroutine and must be handled correctly
close_coro = self.graphiti.close()
try:
                    # Try to get the currently running event loop
loop = asyncio.get_running_loop()
                    # Already in an async context: schedule a task without awaiting (avoid blocking)
loop.create_task(close_coro)
except RuntimeError:
                    # No running event loop; create one to run the coroutine
try:
asyncio.run(close_coro)
except RuntimeError:
                        # If that still fails (e.g., nested event loops), just skip it
pass
except Exception:
pass
self.graphiti = None
if self.driver:
self.driver.close()
self.driver = None
self._connected = False
self._graphiti_initialized = False
async def add_episode(
self,
content: str,
metadata: Optional[Dict[str, Any]] = None,
group_id: Optional[str] = None,
saga: Optional[str] = None,
saga_previous_episode_uuid: Optional[str] = None,
custom_extraction_instructions: Optional[str] = None
) -> Dict[str, Any]:
"""
添加交互记录(episode)到知识图谱.
Args:
content: 要记录的内容
metadata: 可选的元数据
group_id: 组 ID(用于数据隔离)
saga: Saga 名称(可选),用于将多个 episode 关联到同一个事务/会话
saga_previous_episode_uuid: 前一个 episode 的 UUID(可选),用于高效连接连续 episode
custom_extraction_instructions: 自定义抽取指令(可选,Graphiti 0.25.0+)
Returns:
操作结果
"""
if not self.is_connected():
return {
"success": False,
"message": "未连接到 Neo4j 数据库"
}
if group_id is None:
group_id = self.config_manager.get_group_id()
try:
import json
with self.driver.session() as session:
                # Create the episode node
                # Convert metadata to a JSON string, because Neo4j does not support dicts as property values
metadata_str = json.dumps(metadata, ensure_ascii=False) if metadata else "{}"
query = """
CREATE (e:Episode {
content: $content,
group_id: $group_id,
created_at: datetime(),
metadata: $metadata
})
RETURN id(e) as episode_id
"""
result = session.run(
query,
content=content,
group_id=group_id,
metadata=metadata_str
)
episode_id = result.single()["episode_id"]
                # If Graphiti is initialized, use it for entity extraction
if self.graphiti:
try:
                        # Use Graphiti for entity extraction and relationship building
                        # Graphiti 0.26.0+ supports the saga parameter
graphiti_kwargs = {
"name": f"Episode_{episode_id}",
"episode_body": content,
"source_description": metadata.get("source", "user_input") if metadata else "user_input",
"reference_time": datetime.now(),
"group_id": group_id
}
                        # Add saga parameters (if provided)
if saga is not None:
graphiti_kwargs["saga"] = saga
if saga_previous_episode_uuid is not None:
graphiti_kwargs["saga_previous_episode_uuid"] = saga_previous_episode_uuid
                        # Graphiti 0.25.0+: custom extraction instructions
if custom_extraction_instructions is not None:
graphiti_kwargs["custom_extraction_instructions"] = custom_extraction_instructions
result_obj = await self.graphiti.add_episode(**graphiti_kwargs)
                        # Get the episode UUID for later saga linking
episode_uuid = None
if hasattr(result_obj, 'episode') and hasattr(result_obj.episode, 'uuid'):
episode_uuid = result_obj.episode.uuid
return {
"success": True,
"message": "Episode 已添加到知识图谱,并自动抽取了实体和关系",
"episode_id": episode_id,
"episode_uuid": episode_uuid,
"entities_extracted": True,
"saga": saga
}
except Exception as e:
                        # If Graphiti processing fails, the raw episode has at least been saved
logger.warning(f"Graphiti 实体抽取失败: {e},但 episode 已保存", exc_info=True)
                        # Invalidate related cache entries
if group_id:
self.cache.invalidate_pattern(f"statistics:{group_id}")
return {
"success": True,
"message": "Episode 已添加到知识图谱(实体抽取失败,已保存原始内容)",
"episode_id": episode_id,
"entities_extracted": False,
"warning": str(e)
}
else:
                    # No Graphiti available; only the raw episode is saved
                    # Invalidate related cache entries
if group_id:
self.cache.invalidate_pattern(f"statistics:{group_id}")
return {
"success": True,
"message": "Episode 已添加到知识图谱(未配置 API key,未进行实体抽取)",
"episode_id": episode_id,
"entities_extracted": False,
"note": "配置 API key 后可启用自动实体抽取功能",
"suggestion": "你可以分析这段文本,识别实体和关系,然后使用 query_knowledge_graph 工具创建节点和关系,无需外部 API key。"
}
except Exception as e:
return {
"success": False,
"message": f"添加 episode 失败: {str(e)}"
}
def search_entities(
self,
query: str,
entity_type: Optional[str] = None,
limit: int = 10,
group_id: Optional[str] = None
) -> Dict[str, Any]:
"""
搜索实体.
Args:
query: 搜索查询
entity_type: 实体类型过滤
limit: 结果数量限制
group_id: 组 ID
Returns:
搜索结果
"""
if not self.is_connected():
return {
"success": False,
"message": "未连接到 Neo4j 数据库",
"results": []
}
if group_id is None:
group_id = self.config_manager.get_group_id()
try:
with self.driver.session() as session:
if entity_type:
cypher_query = f"""
MATCH (e:{entity_type})
WHERE e.group_id = $group_id
AND (e.name CONTAINS $search_query OR e.description CONTAINS $search_query)
RETURN e
LIMIT $limit
"""
else:
cypher_query = """
MATCH (e)
WHERE e.group_id = $group_id
AND (
e.name CONTAINS $search_query OR
e.description CONTAINS $search_query OR
e.content CONTAINS $search_query
)
RETURN e
LIMIT $limit
"""
result = session.run(
cypher_query,
search_query=query,
group_id=group_id,
limit=limit
)
                # Convert nodes to dicts, handling Node, DateTime and other special types
entities = []
for record in result:
node = record["e"]
entity_dict = _neo4j_to_dict(node)
entities.append(entity_dict)
return {
"success": True,
"message": f"找到 {len(entities)} 个实体",
"results": entities
}
except Exception as e:
return {
"success": False,
"message": f"搜索失败: {str(e)}",
"results": []
}
def search_relationships(
self,
query: str,
limit: int = 10,
group_id: Optional[str] = None
) -> Dict[str, Any]:
"""
搜索关系.
Args:
query: 搜索查询
limit: 结果数量限制
group_id: 组 ID
Returns:
搜索结果
"""
if not self.is_connected():
return {
"success": False,
"message": "未连接到 Neo4j 数据库",
"results": []
}
if group_id is None:
group_id = self.config_manager.get_group_id()
try:
with self.driver.session() as session:
cypher_query = """
MATCH (a)-[r]->(b)
WHERE a.group_id = $group_id AND b.group_id = $group_id
AND type(r) CONTAINS $search_query
RETURN properties(a) as from_props, labels(a) as from_labels,
type(r) as rel_type,
properties(b) as to_props, labels(b) as to_labels
LIMIT $limit
"""
result = session.run(
cypher_query,
search_query=query,
group_id=group_id,
limit=limit
)
relationships = []
for record in result:
                    # Build the node dicts manually, converting every value into a serializable form
from_props = record.get("from_props", {})
from_labels = record.get("from_labels", [])
to_props = record.get("to_props", {})
to_labels = record.get("to_labels", [])
                    # Convert all property values so DateTime objects become strings
from_node = {}
for key, value in from_props.items():
from_node[key] = _neo4j_to_dict(value)
from_node["labels"] = list(from_labels)
to_node = {}
for key, value in to_props.items():
to_node[key] = _neo4j_to_dict(value)
to_node["labels"] = list(to_labels)
rel_dict = {
"from": from_node,
"relationship": record.get("rel_type", "UNKNOWN"),
"to": to_node
}
                    # Convert the whole dict again so nested DateTime objects are handled
relationships.append(_neo4j_to_dict(rel_dict))
                # Finally convert the entire result list once more as a safety net
final_results = _neo4j_to_dict(relationships)
return {
"success": True,
"message": f"找到 {len(final_results)} 个关系",
"results": final_results
}
except Exception as e:
return {
"success": False,
"message": f"搜索失败: {str(e)}",
"results": []
}
def query_knowledge_graph(
self,
cypher_query: str,
limit: int = 20
) -> Dict[str, Any]:
"""
执行 Cypher 查询.
Args:
cypher_query: Cypher 查询语句
limit: 结果数量限制(仅用于 SELECT 查询)
Returns:
查询结果
"""
if not self.is_connected():
return {
"success": False,
"message": "未连接到 Neo4j 数据库",
"results": []
}
try:
with self.driver.session() as session:
                # Only append LIMIT to read (RETURN) queries; DELETE/UPDATE style operations do not need it
query_upper = cypher_query.upper().strip()
is_read_query = any(keyword in query_upper for keyword in ['RETURN', 'MATCH', 'WHERE'])
is_write_query = any(keyword in query_upper for keyword in ['DELETE', 'CREATE', 'SET', 'REMOVE', 'MERGE'])
                # If it is a read query without a LIMIT, append one
if is_read_query and not is_write_query and "LIMIT" not in query_upper:
cypher_query = f"{cypher_query} LIMIT {limit}"
result = session.run(cypher_query)
                # Convert records, handling Node, Relationship, DateTime and other types
records = []
for record in result:
record_dict = {}
for key, value in dict(record).items():
record_dict[key] = _neo4j_to_dict(value)
records.append(record_dict)
return {
"success": True,
"message": f"查询成功,返回 {len(records)} 条记录",
"results": records
}
except Exception as e:
return {
"success": False,
"message": f"查询失败: {str(e)}",
"results": []
}
def delete_episode(
self,
episode_id: Optional[int] = None,
content: Optional[str] = None,
delete_all: bool = False,
group_id: Optional[str] = None
) -> Dict[str, Any]:
"""
删除 episode.
Args:
episode_id: Episode 的 ID
content: Episode 的内容
delete_all: 是否删除所有 episode
group_id: 组 ID
Returns:
操作结果
"""
if not self.is_connected():
return {
"success": False,
"message": "未连接到 Neo4j 数据库"
}
if group_id is None:
group_id = self.config_manager.get_group_id()
try:
with self.driver.session() as session:
if delete_all:
                    # Delete all episodes
query = "MATCH (e:Episode {group_id: $group_id}) DELETE e RETURN count(e) as deleted_count"
result = session.run(query, group_id=group_id)
deleted_count = result.single()["deleted_count"] if result.peek() else 0
return {
"success": True,
"message": f"已删除 {deleted_count} 个 episode"
}
elif episode_id is not None:
                    # Delete by ID
query = "MATCH (e:Episode) WHERE id(e) = $episode_id AND e.group_id = $group_id DELETE e RETURN count(e) as deleted_count"
result = session.run(query, episode_id=episode_id, group_id=group_id)
deleted_count = result.single()["deleted_count"] if result.peek() else 0
if deleted_count > 0:
return {
"success": True,
"message": f"已删除 ID 为 {episode_id} 的 episode"
}
else:
return {
"success": False,
"message": f"未找到 ID 为 {episode_id} 的 episode"
}
elif content:
                    # Delete by content
query = "MATCH (e:Episode {content: $content, group_id: $group_id}) DELETE e RETURN count(e) as deleted_count"
result = session.run(query, content=content, group_id=group_id)
deleted_count = result.single()["deleted_count"] if result.peek() else 0
if deleted_count > 0:
return {
"success": True,
"message": f"已删除 {deleted_count} 个匹配内容的 episode"
}
else:
return {
"success": False,
"message": "未找到匹配内容的 episode"
}
else:
return {
"success": False,
"message": "请提供 episode_id、content 或设置 delete_all=true"
}
except Exception as e:
return {
"success": False,
"message": f"删除失败: {str(e)}"
}
def clear_graph(self, group_id: Optional[str] = None) -> Dict[str, Any]:
"""
清空整个知识图谱.
Args:
group_id: 组 ID(如果提供,只清空该组的数据)
Returns:
操作结果
"""
if not self.is_connected():
return {
"success": False,
"message": "未连接到 Neo4j 数据库"
}
if group_id is None:
group_id = self.config_manager.get_group_id()
try:
with self.driver.session() as session:
if group_id:
                    # Only clear data for the specified group
query = """
MATCH (n {group_id: $group_id})
DETACH DELETE n
RETURN count(n) as deleted_count
"""
result = session.run(query, group_id=group_id)
deleted_count = result.single()["deleted_count"] if result.peek() else 0
return {
"success": True,
"message": f"已清空组 '{group_id}' 的知识图谱,删除了 {deleted_count} 个节点"
}
else:
                    # Clear all data
query = "MATCH (n) DETACH DELETE n RETURN count(n) as deleted_count"
result = session.run(query)
deleted_count = result.single()["deleted_count"] if result.peek() else 0
return {
"success": True,
"message": f"已清空整个知识图谱,删除了 {deleted_count} 个节点"
}
except Exception as e:
return {
"success": False,
"message": f"清空图谱失败: {str(e)}"
}
def query_by_time_range(
self,
start_time: Optional[str] = None,
end_time: Optional[str] = None,
days: Optional[int] = None,
entity_type: Optional[str] = None,
limit: int = 20,
group_id: Optional[str] = None
) -> Dict[str, Any]:
"""
根据时间范围查询知识图谱.
Args:
start_time: 开始时间(ISO 格式)
end_time: 结束时间(ISO 格式)
days: 查询最近 N 天的数据
entity_type: 实体类型过滤
limit: 结果数量限制
group_id: 组 ID
Returns:
查询结果
"""
if not self.is_connected():
return {
"success": False,
"message": "未连接到 Neo4j 数据库",
"results": []
}
if group_id is None:
group_id = self.config_manager.get_group_id()
try:
from datetime import datetime, timedelta
            # Handle the time parameters
if days is not None:
                # Query the last N days
end_dt = datetime.now()
start_dt = end_dt - timedelta(days=days)
start_time = start_dt.isoformat()
end_time = end_dt.isoformat()
elif start_time:
start_dt = datetime.fromisoformat(start_time.replace('Z', '+00:00'))
else:
                # If no start time was provided, query all data
start_dt = None
if end_time:
end_dt = datetime.fromisoformat(end_time.replace('Z', '+00:00'))
else:
end_dt = datetime.now()
with self.driver.session() as session:
if entity_type:
                    # Query entities of a specific type
if start_dt:
query = f"""
MATCH (e:{entity_type})
WHERE e.group_id = $group_id
AND e.created_at >= $start_time
AND e.created_at <= $end_time
RETURN e
ORDER BY e.created_at DESC
LIMIT $limit
"""
result = session.run(
query,
group_id=group_id,
start_time=start_dt,
end_time=end_dt,
limit=limit
)
else:
query = f"""
MATCH (e:{entity_type})
WHERE e.group_id = $group_id
AND e.created_at <= $end_time
RETURN e
ORDER BY e.created_at DESC
LIMIT $limit
"""
result = session.run(
query,
group_id=group_id,
end_time=end_dt,
limit=limit
)
else:
                    # Query all nodes (mainly Episode)
if start_dt:
query = """
MATCH (e)
WHERE e.group_id = $group_id
AND e.created_at >= $start_time
AND e.created_at <= $end_time
RETURN e
ORDER BY e.created_at DESC
LIMIT $limit
"""
result = session.run(
query,
group_id=group_id,
start_time=start_dt,
end_time=end_dt,
limit=limit
)
else:
query = """
MATCH (e)
WHERE e.group_id = $group_id
AND e.created_at <= $end_time
RETURN e
ORDER BY e.created_at DESC
LIMIT $limit
"""
result = session.run(
query,
group_id=group_id,
end_time=end_dt,
limit=limit
)
entities = [dict(record["e"]) for record in result]
time_range_str = f"{start_time or '开始'} 到 {end_time}"
if days:
time_range_str = f"最近 {days} 天"
return {
"success": True,
"message": f"在 {time_range_str} 范围内找到 {len(entities)} 个节点",
"results": entities,
"time_range": {
"start": start_time,
"end": end_time,
"days": days
}
}
except Exception as e:
return {
"success": False,
"message": f"时间范围查询失败: {str(e)}",
"results": []
}
def export_graph_data(
self,
format: str = "json",
group_id: Optional[str] = None
) -> Dict[str, Any]:
"""
导出知识图谱数据.
Args:
format: 导出格式(json, cypher)
group_id: 组 ID
Returns:
导出结果
"""
if not self.is_connected():
return {
"success": False,
"message": "未连接到 Neo4j 数据库",
"data": None
}
if group_id is None:
group_id = self.config_manager.get_group_id()
try:
import json
with self.driver.session() as session:
if format == "json":
                    # Export as JSON
nodes_query = """
MATCH (n {group_id: $group_id})
RETURN n, labels(n) as labels
"""
relationships_query = """
MATCH (a {group_id: $group_id})-[r]->(b {group_id: $group_id})
RETURN a, type(r) as relationship_type, b
"""
nodes_result = session.run(nodes_query, group_id=group_id)
rels_result = session.run(relationships_query, group_id=group_id)
nodes = []
for record in nodes_result:
node_dict = dict(record["n"])
node_dict["labels"] = record["labels"]
nodes.append(node_dict)
relationships = []
for record in rels_result:
                        relationships.append({
                            "from": _neo4j_to_dict(record["a"]),
                            "relationship": record["relationship_type"],
                            "to": _neo4j_to_dict(record["b"])
                        })
data = {
"nodes": nodes,
"relationships": relationships,
"export_time": datetime.now().isoformat(),
"group_id": group_id
}
return {
"success": True,
"message": f"导出成功:{len(nodes)} 个节点,{len(relationships)} 个关系",
"data": data,
"format": "json"
}
elif format == "cypher":
                    # Export as Cypher statements
cypher_statements = []
nodes_query = """
MATCH (n {group_id: $group_id})
RETURN n, labels(n) as labels
"""
nodes_result = session.run(nodes_query, group_id=group_id)
for record in nodes_result:
node = record["n"]
labels = ":".join(record["labels"])
props = {}
for key, value in dict(node).items():
if key != "group_id":
props[key] = value
props_str = ", ".join([f"{k}: ${k}" for k in props.keys()])
cypher_statements.append(
f"CREATE (n:{labels} {{{props_str}}})"
)
return {
"success": True,
"message": f"导出成功:{len(cypher_statements)} 条 Cypher 语句",
"data": "\n".join(cypher_statements),
"format": "cypher"
}
else:
return {
"success": False,
"message": f"不支持的导出格式:{format}",
"data": None
}
except Exception as e:
return {
"success": False,
"message": f"导出失败: {str(e)}",
"data": None
}
def import_graph_data(
self,
data: Dict[str, Any],
format: str = "json",
group_id: Optional[str] = None
) -> Dict[str, Any]:
"""
导入知识图谱数据.
Args:
data: 要导入的数据
format: 数据格式(json)
group_id: 组 ID
Returns:
导入结果
"""
if not self.is_connected():
return {
"success": False,
"message": "未连接到 Neo4j 数据库"
}
if group_id is None:
group_id = self.config_manager.get_group_id()
try:
with self.driver.session() as session:
if format == "json":
nodes = data.get("nodes", [])
relationships = data.get("relationships", [])
imported_nodes = 0
imported_rels = 0
                    # Import nodes
import json
for node in nodes:
node_dict = dict(node)
node_dict["group_id"] = group_id
labels = node_dict.pop("labels", ["Node"])
label_str = ":".join(labels)
                        # Serialize any dict-valued fields to JSON strings:
                        # Neo4j cannot store dict objects directly as property values
for key, value in list(node_dict.items()):
if isinstance(value, dict):
node_dict[key] = json.dumps(value, ensure_ascii=False)
props_str = ", ".join([f"{k}: ${k}" for k in node_dict.keys()])
query = f"CREATE (n:{label_str} {{{props_str}}})"
session.run(query, **node_dict)
imported_nodes += 1
                    # Import relationships (simplified; a real import needs more elaborate matching logic)
                    # Note: this is only a placeholder; importing relationships requires matching nodes first
logger.info(f"导入完成:{imported_nodes} 个节点,{imported_rels} 个关系")
return {
"success": True,
"message": f"导入成功:{imported_nodes} 个节点,{imported_rels} 个关系",
"imported_nodes": imported_nodes,
"imported_relationships": imported_rels
}
else:
return {
"success": False,
"message": f"不支持的导入格式:{format}"
}
except Exception as e:
logger.error(f"导入失败: {e}", exc_info=True)
return {
"success": False,
"message": f"导入失败: {str(e)}"
}
def get_statistics(
self,
group_id: Optional[str] = None,
use_cache: bool = True
) -> Dict[str, Any]:
"""
获取知识图谱统计信息.
Args:
group_id: 组 ID
use_cache: 是否使用缓存(默认 True)
Returns:
统计信息
"""
if not self.is_connected():
return {
"success": False,
"message": "未连接到 Neo4j 数据库",
"statistics": {}
}
if group_id is None:
group_id = self.config_manager.get_group_id()
        # Try the cache first
cache_key = f"statistics:{group_id}"
if use_cache:
cached_result = self.cache.get(cache_key)
if cached_result is not None:
logger.debug("从缓存获取统计信息")
return cached_result
try:
with self.driver.session() as session:
                # Node statistics
node_stats_query = """
MATCH (n {group_id: $group_id})
RETURN labels(n) as labels, count(n) as count
"""
node_result = session.run(node_stats_query, group_id=group_id)
node_stats = {}
total_nodes = 0
for record in node_result:
labels = record["labels"]
label = labels[0] if labels else "Unknown"
count = record["count"]
node_stats[label] = count
total_nodes += count
                # Relationship statistics
rel_stats_query = """
MATCH ()-[r]->()
WHERE startNode(r).group_id = $group_id AND endNode(r).group_id = $group_id
RETURN type(r) as rel_type, count(r) as count
"""
rel_result = session.run(rel_stats_query, group_id=group_id)
rel_stats = {}
total_rels = 0
for record in rel_result:
rel_type = record["rel_type"]
count = record["count"]
rel_stats[rel_type] = count
total_rels += count
                # Episode statistics
episode_stats_query = """
MATCH (e:Episode {group_id: $group_id})
RETURN count(e) as total_episodes,
min(e.created_at) as first_episode,
max(e.created_at) as last_episode
"""
episode_result = session.run(episode_stats_query, group_id=group_id)
episode_record = episode_result.single()
statistics = {
"nodes": {
"total": total_nodes,
"by_type": node_stats
},
"relationships": {
"total": total_rels,
"by_type": rel_stats
},
"episodes": {
"total": episode_record["total_episodes"] if episode_record else 0,
"first_episode": str(episode_record["first_episode"]) if episode_record and episode_record["first_episode"] else None,
"last_episode": str(episode_record["last_episode"]) if episode_record and episode_record["last_episode"] else None
},
"group_id": group_id,
"timestamp": datetime.now().isoformat()
}
result = {
"success": True,
"message": "统计信息获取成功",
"statistics": statistics
}
                # Cache the result (TTL: 60 seconds)
if use_cache:
self.cache.set(cache_key, result, ttl=60)
return result
except Exception as e:
return {
"success": False,
"message": f"获取统计信息失败: {str(e)}",
"statistics": {}
}
async def add_episodes_bulk(
self,
episodes: List[Dict[str, Any]],
group_id: Optional[str] = None
) -> Dict[str, Any]:
"""
批量添加多个 episode.
Args:
episodes: Episode 列表,每个元素包含 content 和可选的 metadata
group_id: 组 ID
Returns:
操作结果
"""
if not self.is_connected():
return {
"success": False,
"message": "未连接到 Neo4j 数据库",
"results": []
}
if group_id is None:
group_id = self.config_manager.get_group_id()
results = []
success_count = 0
fail_count = 0
for episode in episodes:
content = episode.get("content", "")
metadata = episode.get("metadata")
if not content:
fail_count += 1
results.append({
"success": False,
"message": "Episode 内容为空",
"content": content
})
continue
result = await self.add_episode(
content=content,
metadata=metadata,
group_id=group_id
)
if result.get("success"):
success_count += 1
else:
fail_count += 1
results.append(result)
return {
"success": True,
"message": f"批量添加完成:成功 {success_count} 个,失败 {fail_count} 个",
"total": len(episodes),
"success_count": success_count,
"fail_count": fail_count,
"results": results
}
async def semantic_search(
self,
query: str,
num_results: int = 10,
group_id: Optional[str] = None,
center_node_uuid: Optional[str] = None
) -> Dict[str, Any]:
"""
语义搜索 - 使用向量搜索和文本搜索的混合搜索.
Args:
query: 搜索查询
num_results: 返回结果数量
group_id: 组 ID
center_node_uuid: 中心节点 UUID(用于重新排序)
Returns:
搜索结果
"""
if not self.is_connected():
return {
"success": False,
"message": "未连接到 Neo4j 数据库",
"results": []
}
        # Semantic search needs a Graphiti instance (for vector search)
if not self.graphiti:
            # Without Graphiti, fall back to an enhanced keyword search,
            # approximating semantic search by broadening the query terms
logger.info("Graphiti 未配置,使用增强的关键词搜索")
            # Query-term expansion could be added later (for example via Cursor AI)
            # For now, use the basic search but return a friendlier hint
result = self.search_entities(query=query, limit=num_results, group_id=group_id)
if result['success']:
result['message'] = result['message'] + "(使用关键词搜索,配置 API key 后可启用向量搜索)"
result['search_type'] = 'enhanced_keyword'
return result
if group_id is None:
group_id = self.config_manager.get_group_id()
try:
            # Use Graphiti's semantic search (await the async method)
group_ids = [group_id] if group_id else None
edges = await self.graphiti.search(
query=query,
center_node_uuid=center_node_uuid,
group_ids=group_ids,
num_results=num_results
)
            # Convert the results into dict form
results = []
for edge in edges:
edge_dict = {
"from_node": {
"uuid": edge.from_node.uuid if hasattr(edge.from_node, 'uuid') else None,
"name": edge.from_node.name if hasattr(edge.from_node, 'name') else None,
"entity_type": edge.from_node.entity_type if hasattr(edge.from_node, 'entity_type') else None,
},
"to_node": {
"uuid": edge.to_node.uuid if hasattr(edge.to_node, 'uuid') else None,
"name": edge.to_node.name if hasattr(edge.to_node, 'name') else None,
"entity_type": edge.to_node.entity_type if hasattr(edge.to_node, 'entity_type') else None,
},
"relationship": edge.relationship_type if hasattr(edge, 'relationship_type') else None,
"description": edge.description if hasattr(edge, 'description') else None,
}
results.append(edge_dict)
return {
"success": True,
"message": f"语义搜索找到 {len(results)} 个相关关系",
"results": results,
"search_type": "semantic"
}
except Exception as e:
            # If semantic search fails, fall back to the basic search
logger.warning(f"语义搜索失败: {e},回退到基础搜索", exc_info=True)
return self.search_entities(query=query, limit=num_results, group_id=group_id)
def validate_data(
self,
check_orphaned: bool = True,
check_duplicates: bool = True,
check_integrity: bool = True,
group_id: Optional[str] = None
) -> Dict[str, Any]:
"""
验证知识图谱数据的完整性.
Args:
check_orphaned: 是否检查孤立节点
check_duplicates: 是否检查重复节点
check_integrity: 是否检查数据完整性
group_id: 组 ID
Returns:
验证结果
"""
if not self.is_connected():
return {
"success": False,
"message": "未连接到 Neo4j 数据库",
"issues": []
}
if group_id is None:
group_id = self.config_manager.get_group_id()
issues = []
try:
with self.driver.session() as session:
                # Check for orphaned nodes
if check_orphaned:
orphaned_query = """
MATCH (n {group_id: $group_id})
WHERE NOT (n)--()
AND NOT n:Episode
RETURN count(n) as count, labels(n) as labels
"""
result = session.run(orphaned_query, group_id=group_id)
for record in result:
count = record["count"]
labels = record["labels"]
if count > 0:
label = labels[0] if labels else "Unknown"
issues.append(f"发现 {count} 个孤立的 {label} 节点(没有关系)")
                # Check for duplicate nodes (same name and type)
if check_duplicates:
duplicates_query = """
MATCH (n {group_id: $group_id})
WHERE n.name IS NOT NULL
WITH labels(n) as labels, n.name as name, count(*) as count
WHERE count > 1
RETURN labels[0] as label, name, count
LIMIT 10
"""
result = session.run(duplicates_query, group_id=group_id)
for record in result:
label = record["label"]
name = record["name"]
count = record["count"]
issues.append(f"发现 {count} 个重复的 {label} 节点(名称: {name})")
                # Check data integrity (nodes missing the required group_id property)
                if check_integrity:
                    integrity_query = """
                    MATCH (n)
                    WHERE n.group_id IS NULL OR n.group_id = ''
                    RETURN count(n) as count
                    """
result = session.run(integrity_query, group_id=group_id)
record = result.single()
if record and record["count"] > 0:
issues.append(f"发现 {record['count']} 个节点缺少 group_id 属性")
return {
"success": True,
"message": f"数据验证完成,发现 {len(issues)} 个问题" if issues else "数据验证完成,未发现问题",
"issues": issues,
"issue_count": len(issues)
}
except Exception as e:
return {
"success": False,
"message": f"数据验证失败: {str(e)}",
"issues": []
}
def clean_orphaned_nodes(
self,
node_types: Optional[List[str]] = None,
group_id: Optional[str] = None
) -> Dict[str, Any]:
"""
清理孤立节点(没有关系的节点).
Args:
node_types: 要清理的节点类型列表(可选)
group_id: 组 ID
Returns:
操作结果
"""
if not self.is_connected():
return {
"success": False,
"message": "未连接到 Neo4j 数据库",
"deleted_count": 0
}
if group_id is None:
group_id = self.config_manager.get_group_id()
try:
with self.driver.session() as session:
if node_types:
                    # Clean up orphaned nodes of the specified types
deleted_count = 0
for node_type in node_types:
query = f"""
MATCH (n:{node_type} {{group_id: $group_id}})
WHERE NOT (n)--()
AND NOT n:Episode
WITH n LIMIT 1000
DETACH DELETE n
RETURN count(n) as count
"""
result = session.run(query, group_id=group_id)
record = result.single()
if record:
deleted_count += record["count"]
else:
                    # Clean up orphaned nodes of all types (except Episode)
query = """
MATCH (n {group_id: $group_id})
WHERE NOT (n)--()
AND NOT n:Episode
WITH n LIMIT 1000
DETACH DELETE n
RETURN count(n) as count
"""
result = session.run(query, group_id=group_id)
record = result.single()
deleted_count = record["count"] if record else 0
                # Invalidate related cache entries
self.cache.invalidate_pattern(f"statistics:{group_id}")
return {
"success": True,
"message": f"已清理 {deleted_count} 个孤立节点",
"deleted_count": deleted_count
}
except Exception as e:
return {
"success": False,
"message": f"清理孤立节点失败: {str(e)}",
"deleted_count": 0
}
def rebuild_indexes(
self,
index_types: Optional[List[str]] = None,
group_id: Optional[str] = None
) -> Dict[str, Any]:
"""
重建 Neo4j 数据库索引.
Args:
index_types: 要重建的索引类型(可选)
group_id: 组 ID(用于日志)
Returns:
操作结果
"""
if not self.is_connected():
return {
"success": False,
"message": "未连接到 Neo4j 数据库",
"indexes": []
}
if group_id is None:
group_id = self.config_manager.get_group_id()
try:
with self.driver.session() as session:
                # List all indexes
indexes_query = "SHOW INDEXES"
result = session.run(indexes_query)
indexes = []
for record in result:
index_name = record.get("name", "")
index_type = record.get("type", "")
indexes.append(f"{index_name} ({index_type})")
                # Note: Neo4j index rebuilds normally do not require manual intervention.
                # This method mainly reports index information; maintenance is handled by Neo4j itself.
return {
"success": True,
"message": f"已检查 {len(indexes)} 个索引,Neo4j 会自动维护索引",
"indexes": indexes,
"index_count": len(indexes)
}
except Exception as e:
return {
"success": False,
"message": f"重建索引失败: {str(e)}",
"indexes": []
}
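if __name__ == "__main__":
    # Minimal usage sketch, not a definitive entry point. It assumes a reachable Neo4j
    # instance and that ConfigManager() can be constructed without arguments and resolve
    # valid neo4j/api settings; adjust to your actual ConfigManager signature. Because this
    # module uses relative imports, run it as a module (python -m <package>.graphiti_client).
    import asyncio

    async def _demo() -> None:
        client = GraphitiClient(ConfigManager())
        if client.connect():
            # add_episode is async; basic mode works even without an API key
            print(await client.add_episode("User asked about the project roadmap."))
            print(client.get_statistics())
            client.disconnect()

    asyncio.run(_demo())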