"""Graphiti 客户端所有方法完整测试."""
import pytest
from unittest.mock import Mock, MagicMock, patch
from src.config_manager import ConfigManager
from src.graphiti_client import GraphitiClient
class TestGraphitiClientAllMethods:
    """Smoke tests for the GraphitiClient methods, run against a mocked driver.

    Each test wires canned ``MagicMock`` results into ``session.run`` and then
    checks the ``success`` flag of the dictionary returned by the client.
    """

    @pytest.fixture
    def config_manager(self, temp_config_dir):
        """Return a ConfigManager whose config file lives under the temp dir."""
        config_path = temp_config_dir / ".graphitiace" / "config.json"
        return ConfigManager(config_path=config_path)

    @pytest.fixture
    def connected_client(self, config_manager, mock_neo4j_config, mock_driver):
        """Return ``(client, session)`` for a client forced into the connected state.

        The Neo4j settings are written through ``configure_neo4j``; the
        client's ``driver`` is then replaced by the mock driver and its
        ``_connected`` flag is set directly, so no real connection is attempted
        by this fixture. ``session`` is the second element of the
        ``mock_driver`` fixture (presumably the session the mock driver hands
        out -- confirm against conftest).
        """
        config_manager.configure_neo4j(**mock_neo4j_config)
        client = GraphitiClient(config_manager)
        driver, session = mock_driver
        client.driver = driver
        client._connected = True
        return client, session

    def test_is_connected_true(self, connected_client):
        """is_connected() reports True for the pre-connected client."""
        client, _ = connected_client

        assert client.is_connected() is True

    def test_is_connected_false(self, config_manager):
        """is_connected() reports False for a freshly constructed client."""
        client = GraphitiClient(config_manager)

        assert client.is_connected() is False

    def test_disconnect_when_connected(self, connected_client):
        """disconnect() leaves a previously connected client disconnected."""
        client, _ = connected_client

        client.disconnect()

        assert client.is_connected() is False

    @pytest.mark.asyncio
    async def test_add_episode_with_metadata(self, connected_client):
        """add_episode() succeeds when content is passed along with metadata."""
        client, session = connected_client
        record = MagicMock()
        record.__getitem__.return_value = 1  # any key lookup yields 1
        run_result = MagicMock()
        run_result.single.return_value = record
        session.run.return_value = run_result

        result = await client.add_episode(
            content="Test",
            metadata={"type": "test", "category": "programming"},
        )

        assert result['success'] is True

    def test_search_entities_with_type_filter(self, connected_client):
        """search_entities() succeeds when an entity type and limit are given."""
        client, session = connected_client
        record = MagicMock()
        record.__getitem__.return_value = {"name": "Test", "type": "Entity"}
        run_result = MagicMock()
        run_result.__iter__.return_value = [record]
        session.run.return_value = run_result

        result = client.search_entities(
            query="test",
            entity_type="Entity",
            limit=5,
        )

        assert result['success'] is True

    def test_query_knowledge_graph_with_limit(self, connected_client):
        """query_knowledge_graph() succeeds for a Cypher query with a limit."""
        client, session = connected_client
        record = MagicMock()
        record.__getitem__.return_value = {"n": {"name": "Node"}}
        run_result = MagicMock()
        run_result.__iter__.return_value = [record]
        session.run.return_value = run_result

        result = client.query_knowledge_graph(
            cypher_query="MATCH (n) RETURN n",
            limit=10,
        )

        assert result['success'] is True

    def test_delete_episode_by_content(self, connected_client):
        """delete_episode() succeeds when the episode is selected by content."""
        client, session = connected_client
        record = MagicMock()
        record.__getitem__.return_value = 1
        run_result = MagicMock()
        run_result.single.return_value = record
        run_result.peek.return_value = True
        session.run.return_value = run_result

        result = client.delete_episode(content="Test")

        assert result['success'] is True

    def test_export_graph_data_cypher_format(self, connected_client):
        """export_graph_data() succeeds when the "cypher" format is requested."""
        client, session = connected_client
        record = MagicMock()
        record.__getitem__.return_value = {"n": MagicMock(), "labels": ["Entity"]}
        run_result = MagicMock()
        run_result.__iter__.return_value = [record]
        session.run.return_value = run_result

        result = client.export_graph_data(format="cypher")

        assert result['success'] is True

    def test_import_graph_data_cypher_format(self, connected_client):
        """import_graph_data() returns some result for the "cypher" format."""
        client, session = connected_client
        session.run.return_value = MagicMock()
        data = {
            "nodes": [{"name": "Node1"}],
            "relationships": [],
        }

        result = client.import_graph_data(data=data, format="cypher")

        # The Cypher format may be unsupported, in which case the client is
        # expected to return an error or otherwise handle it -- so only
        # require that a result comes back at all.
        assert result is not None

    def test_get_statistics_with_group_id(self, connected_client):
        """get_statistics() succeeds when a group id is supplied."""
        client, session = connected_client
        node_result = MagicMock()
        node_result.__iter__.return_value = []
        rel_result = MagicMock()
        rel_result.__iter__.return_value = []

        episode_stats = {
            "total_episodes": 0,
            "first_episode": None,
            "last_episode": None,
        }
        episode_record = MagicMock()
        # Key lookups on the record are served from the dict above; unknown
        # keys raise KeyError.
        episode_record.__getitem__.side_effect = episode_stats.__getitem__
        episode_result = MagicMock()
        episode_result.single.return_value = episode_record

        # Successive session.run() calls return these results in this order.
        session.run.side_effect = [node_result, rel_result, episode_result]

        result = client.get_statistics(group_id="test_group")

        assert result['success'] is True

    def test_validate_data_with_options(self, connected_client):
        """validate_data() succeeds with all three checks switched on."""
        client, session = connected_client
        orphaned_result = MagicMock()
        orphaned_result.__iter__.return_value = []
        duplicates_result = MagicMock()
        duplicates_result.__iter__.return_value = []
        integrity_record = MagicMock()
        integrity_record.__getitem__.return_value = 0
        integrity_result = MagicMock()
        integrity_result.single.return_value = integrity_record

        # Successive session.run() calls return these results in this order.
        session.run.side_effect = [
            orphaned_result,
            duplicates_result,
            integrity_result,
        ]

        result = client.validate_data(
            check_orphaned=True,
            check_duplicates=True,
            check_integrity=True,
        )

        assert result['success'] is True

    def test_clean_orphaned_nodes_with_types(self, connected_client):
        """clean_orphaned_nodes() succeeds when node types are specified."""
        client, session = connected_client
        record = MagicMock()
        record.__getitem__.return_value = 3
        run_result = MagicMock()
        run_result.single.return_value = record
        run_result.peek.return_value = True
        session.run.return_value = run_result

        result = client.clean_orphaned_nodes(node_types=["Entity"])

        assert result['success'] is True

    def test_rebuild_indexes_with_types(self, connected_client):
        """rebuild_indexes() succeeds when index types are specified."""
        client, session = connected_client
        index_info = {"name": "index_name", "type": "BTREE"}
        record = MagicMock()
        # `default` is optional so the mock also tolerates record.get(key)
        # calls that omit it, as a mapping-style get() would.
        record.get.side_effect = (
            lambda key, default=None: index_info.get(key, default)
        )
        run_result = MagicMock()
        run_result.__iter__.return_value = [record]
        session.run.return_value = run_result

        result = client.rebuild_indexes(index_types=["BTREE"])

        assert result['success'] is True