"""Graphiti 客户端连接后功能测试(使用 mock)."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from src.config_manager import ConfigManager
from src.graphiti_client import GraphitiClient
class TestGraphitiClientConnected:
    """Tests for GraphitiClient operations once a (mocked) Neo4j connection exists."""

    # ------------------------------------------------------------------
    # Mock-building helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _single_result(value, *, with_peek=False):
        """Build a mock query result whose ``single()`` record returns *value* for any key.

        Args:
            value: What ``record[key]`` returns, regardless of ``key``.
            with_peek: When true, ``result.peek()`` is stubbed to return ``True``;
                otherwise ``peek`` is left as the default MagicMock attribute.

        Returns:
            The configured ``MagicMock`` result object.
        """
        result = MagicMock()
        record = MagicMock()
        record.__getitem__.return_value = value
        result.single.return_value = record
        if with_peek:
            result.peek.return_value = True
        return result

    @staticmethod
    def _iter_result(*records):
        """Build a mock query result that iterates over *records*.

        MagicMock wraps ``__iter__.return_value`` in ``iter(...)`` on every call,
        so the result can be iterated more than once.
        """
        result = MagicMock()
        result.__iter__.return_value = list(records)
        return result

    @staticmethod
    def _record_returning(value):
        """Build a mock record whose ``record[key]`` returns *value* for any key."""
        record = MagicMock()
        record.__getitem__.return_value = value
        return record

    @staticmethod
    def _keyed_record(mapping):
        """Build a mock record that looks keys up in *mapping* (missing keys raise KeyError)."""
        record = MagicMock()
        record.__getitem__.side_effect = mapping.__getitem__
        return record

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def mock_driver(self):
        """Create a mocked Neo4j driver.

        Returns:
            A ``(driver, session)`` tuple where ``with driver.session() as s``
            binds ``s`` to ``session``.
        """
        driver = MagicMock()
        session = MagicMock()
        driver.session.return_value.__enter__.return_value = session
        driver.session.return_value.__exit__.return_value = None
        return driver, session

    @pytest.fixture
    def connected_client(self, config_manager, mock_neo4j_config, mock_driver):
        """Create a GraphitiClient that is marked as connected to the mocked driver.

        ``config_manager`` and ``mock_neo4j_config`` are fixtures defined outside
        this class (presumably in conftest.py — confirm).

        Returns:
            A ``(client, session)`` tuple; tests configure ``session.run`` to
            control what each Cypher query returns.
        """
        config_manager.configure_neo4j(**mock_neo4j_config)
        client = GraphitiClient(config_manager)
        driver, session = mock_driver
        # Skip establishing a real connection: inject the mock driver and
        # mark the client as connected.
        client.driver = driver
        client._connected = True
        return client, session

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_add_episode_connected(self, connected_client):
        """Adding an episode while connected succeeds and reports an episode id."""
        client, session = connected_client
        session.run.return_value = self._single_result(1)  # record[...] -> episode id

        result = await client.add_episode(
            content="Test episode",
            metadata={"type": "test"},
        )

        assert result['success'] is True
        assert 'episode_id' in result

    def test_search_entities_connected(self, connected_client):
        """Searching entities while connected succeeds and returns results."""
        client, session = connected_client
        session.run.return_value = self._iter_result(
            self._record_returning({"name": "Test", "type": "Entity"})
        )

        result = client.search_entities(query="Test")

        assert result['success'] is True
        assert 'results' in result

    def test_search_relationships_connected(self, connected_client):
        """Searching relationships while connected returns exactly one result."""
        client, session = connected_client
        # Plain dicts stand in for nodes (presumably the client converts nodes
        # with dict(), which needs a mapping/iterable).
        node_a = {"name": "A", "group_id": "default"}
        node_b = {"name": "B", "group_id": "default"}
        mock_rel = MagicMock()
        # Every MagicMock instance gets its own class, so renaming it only
        # affects this mock. Mimics neo4j relationship objects, whose class
        # name is the relationship type.
        mock_rel.__class__.__name__ = "RELATES_TO"
        record = MagicMock()
        # Unknown keys yield None rather than raising.
        record.__getitem__.side_effect = {"a": node_a, "b": node_b, "r": mock_rel}.get
        session.run.return_value = self._iter_result(record)

        result = client.search_relationships(query="test")

        # The assertion message surfaces the client's error message on failure.
        assert result['success'] is True, (
            f"Search failed: {result.get('message', 'Unknown error')}"
        )
        assert 'results' in result
        assert len(result['results']) == 1

    def test_query_knowledge_graph_connected(self, connected_client):
        """Running a raw Cypher query while connected succeeds and returns results."""
        client, session = connected_client
        session.run.return_value = self._iter_result(
            self._record_returning({"n": {"name": "Node"}})
        )

        result = client.query_knowledge_graph(cypher_query="MATCH (n) RETURN n")

        assert result['success'] is True
        assert 'results' in result

    def test_delete_episode_by_id(self, connected_client):
        """Deleting an episode by id succeeds."""
        client, session = connected_client
        session.run.return_value = self._single_result(1, with_peek=True)  # deleted_count

        result = client.delete_episode(episode_id=1)

        assert result['success'] is True

    def test_delete_episode_by_content(self, connected_client):
        """Deleting an episode by content succeeds."""
        client, session = connected_client
        session.run.return_value = self._single_result(1, with_peek=True)  # deleted_count

        result = client.delete_episode(content="Test")

        assert result['success'] is True

    def test_delete_episode_all(self, connected_client):
        """Deleting all episodes succeeds."""
        client, session = connected_client
        session.run.return_value = self._single_result(5, with_peek=True)  # deleted_count

        result = client.delete_episode(delete_all=True)

        assert result['success'] is True

    def test_clear_graph(self, connected_client):
        """Clearing the graph succeeds."""
        client, session = connected_client
        session.run.return_value = self._single_result(10, with_peek=True)  # deleted_count

        result = client.clear_graph()

        assert result['success'] is True

    def test_query_by_time_range(self, connected_client):
        """Querying a time range succeeds and returns results."""
        client, session = connected_client
        session.run.return_value = self._iter_result(
            self._record_returning({"created_at": "2025-01-01", "name": "Test"})
        )

        result = client.query_by_time_range(days=7)

        assert result['success'] is True
        assert 'results' in result

    def test_get_statistics(self, connected_client):
        """Collecting graph statistics succeeds and returns a statistics payload."""
        client, session = connected_client
        # Node counts per label.
        node_result = self._iter_result(
            self._keyed_record({"labels": ["Entity"], "count": 5}),
            self._keyed_record({"labels": ["Episode"], "count": 10}),
        )
        # Relationship counts per type.
        rel_result = self._iter_result(
            self._keyed_record({"rel_type": "RELATES_TO", "count": 3}),
        )
        # Episode summary.
        episode_result = MagicMock()
        episode_result.single.return_value = self._keyed_record({
            "total_episodes": 10,
            "first_episode": "2025-01-01",
            "last_episode": "2025-01-10",
        })
        # Successive session.run calls return these in order; the client is
        # expected to issue the node, relationship, then episode queries.
        session.run.side_effect = [node_result, rel_result, episode_result]

        result = client.get_statistics()

        assert result['success'] is True
        assert 'statistics' in result

    def test_export_graph_data(self, connected_client):
        """Exporting the graph as JSON succeeds and returns a data payload."""
        client, session = connected_client

        class MockNode:
            """Minimal read-only mapping standing in for a neo4j Node.

            ``keys()`` plus ``__getitem__`` is all the builtin ``dict()`` needs
            to copy it, so there is no need to patch ``dict`` in the module
            under test.
            """

            def __init__(self, data):
                self._data = data

            def __getitem__(self, key):
                return self._data[key]

            def keys(self):
                return self._data.keys()

            def items(self):
                return self._data.items()

            def __iter__(self):
                return iter(self._data)

        node_result = self._iter_result(
            self._keyed_record({
                "n": MockNode({"name": "Node", "type": "Entity", "group_id": "default"}),
                "labels": ["Entity"],
            })
        )
        rel_result = self._iter_result(
            self._keyed_record({
                "a": MockNode({"name": "A", "group_id": "default"}),
                "b": MockNode({"name": "B", "group_id": "default"}),
                "relationship_type": "RELATES_TO",
            })
        )
        # First query returns nodes, second returns relationships.
        session.run.side_effect = [node_result, rel_result]

        result = client.export_graph_data(format="json")

        assert result['success'] is True
        assert 'data' in result

    def test_import_graph_data(self, connected_client):
        """Importing nodes and relationships succeeds."""
        client, session = connected_client
        session.run.return_value = MagicMock()
        data = {
            "nodes": [{"name": "Node1", "type": "Entity"}],
            "relationships": [{"from": "Node1", "to": "Node2", "type": "RELATES_TO"}],
        }

        result = client.import_graph_data(data=data)

        assert result['success'] is True

    @pytest.mark.asyncio
    async def test_add_episodes_bulk(self, connected_client):
        """Bulk-adding episodes succeeds and reports the total submitted."""
        client, session = connected_client
        session.run.return_value = self._single_result(1)  # record[...] -> episode id
        episodes = [
            {"content": "Test 1"},
            {"content": "Test 2"},
        ]

        result = await client.add_episodes_bulk(episodes=episodes)

        assert result['success'] is True
        assert result['total'] == 2

    def test_validate_data(self, connected_client):
        """Validating a clean graph succeeds and returns an issues payload."""
        client, session = connected_client
        orphaned_result = self._iter_result()    # no orphaned nodes
        duplicates_result = self._iter_result()  # no duplicates
        integrity_result = self._single_result(0)
        # Successive session.run calls return these in order.
        session.run.side_effect = [orphaned_result, duplicates_result, integrity_result]

        result = client.validate_data()

        assert result['success'] is True
        assert 'issues' in result

    def test_clean_orphaned_nodes(self, connected_client):
        """Cleaning orphaned nodes reports the deleted count from the query."""
        client, session = connected_client
        session.run.return_value = self._single_result(5, with_peek=True)  # deleted_count

        result = client.clean_orphaned_nodes()

        assert result['success'] is True
        assert result['deleted_count'] == 5

    def test_rebuild_indexes(self, connected_client):
        """Rebuilding indexes succeeds and lists the indexes."""
        client, session = connected_client
        index_info = {"name": "index_name", "type": "BTREE"}
        record = MagicMock()
        # Make the default optional so both record.get(key) and
        # record.get(key, default) behave like dict.get.
        record.get.side_effect = lambda key, default=None: index_info.get(key, default)
        session.run.return_value = self._iter_result(record)

        result = client.rebuild_indexes()

        assert result['success'] is True
        assert 'indexes' in result

    @pytest.mark.asyncio
    async def test_semantic_search_without_graphiti(self, connected_client):
        """Semantic search without a Graphiti instance reports the enhanced keyword search type."""
        client, session = connected_client
        client.graphiti = None  # ensure no Graphiti instance is available
        session.run.return_value = self._iter_result(
            self._record_returning({"name": "Test"})
        )

        result = await client.semantic_search(query="test")

        assert result['success'] is True
        assert result.get('search_type') == 'enhanced_keyword'