"""工具输入模型单元测试."""
import pytest
from src.tools import (
ConfigureNeo4jInput,
ConfigureAPIInput,
SetGroupIdInput,
AddEpisodeInput,
SearchEntitiesInput,
SearchRelationshipsInput,
QueryKnowledgeGraphInput,
DeleteEpisodeInput,
ClearGraphInput,
QueryByTimeRangeInput,
SemanticSearchInput,
BulkEpisodeInput,
ExportGraphDataInput,
ImportGraphDataInput,
ValidateDataInput,
CleanOrphanedNodesInput,
RebuildIndexesInput
)
class TestToolInputs:
    """Unit tests covering construction of every tool input model."""

    def test_configure_neo4j_input(self):
        """ConfigureNeo4jInput keeps the given connection fields and defaults the database."""
        bolt_uri = "bolt://localhost:7687"
        cfg = ConfigureNeo4jInput(uri=bolt_uri, username="neo4j", password="test")
        assert cfg.uri == bolt_uri
        assert cfg.username == "neo4j"
        # `database` is never passed, so this pins the model's default value.
        assert cfg.database == "neo4j"

    def test_configure_api_input(self):
        """ConfigureAPIInput keeps both the provider and an explicit API key."""
        cfg = ConfigureAPIInput(provider="openai", api_key="test_key")
        assert (cfg.provider, cfg.api_key) == ("openai", "test_key")

    def test_configure_api_input_no_key(self):
        """ConfigureAPIInput accepts a provider without an API key."""
        cfg = ConfigureAPIInput(provider="openai")
        assert cfg.provider == "openai"
        # Omitting api_key must leave it unset rather than empty.
        assert cfg.api_key is None

    def test_set_group_id_input(self):
        """SetGroupIdInput stores the requested group id."""
        group = "test_group"
        assert SetGroupIdInput(group_id=group).group_id == group

    def test_add_episode_input(self):
        """AddEpisodeInput stores content together with caller-supplied metadata."""
        episode = AddEpisodeInput(content="Test content", metadata={"type": "test"})
        assert episode.content == "Test content"
        assert episode.metadata == {"type": "test"}

    def test_add_episode_input_no_metadata(self):
        """AddEpisodeInput leaves metadata unset when none is provided."""
        episode = AddEpisodeInput(content="Test content")
        assert episode.content == "Test content"
        assert episode.metadata is None

    def test_search_entities_input(self):
        """SearchEntitiesInput keeps the query, entity type filter and limit."""
        search = SearchEntitiesInput(query="test", entity_type="Preference", limit=10)
        assert (search.query, search.entity_type, search.limit) == ("test", "Preference", 10)

    def test_search_relationships_input(self):
        """SearchRelationshipsInput keeps the query and limit."""
        search = SearchRelationshipsInput(query="test", limit=5)
        assert (search.query, search.limit) == ("test", 5)

    def test_query_knowledge_graph_input(self):
        """QueryKnowledgeGraphInput keeps the raw Cypher text and row limit."""
        cypher = "MATCH (n) RETURN n"
        request = QueryKnowledgeGraphInput(query=cypher, limit=20)
        assert request.query == cypher
        assert request.limit == 20

    def test_delete_episode_input(self):
        """DeleteEpisodeInput targets one episode and does not delete all by default."""
        request = DeleteEpisodeInput(episode_id=1)
        assert request.episode_id == 1
        # delete_all must default to False when only an id is given.
        assert request.delete_all is False

    def test_delete_episode_input_all(self):
        """DeleteEpisodeInput can request deletion of every episode."""
        request = DeleteEpisodeInput(delete_all=True)
        assert request.delete_all is True

    def test_clear_graph_input(self):
        """ClearGraphInput records the explicit confirmation flag."""
        request = ClearGraphInput(confirm=True)
        assert request.confirm is True

    def test_query_by_time_range_input(self):
        """QueryByTimeRangeInput keeps the look-back window in days and the limit."""
        window = QueryByTimeRangeInput(days=7, limit=10)
        assert (window.days, window.limit) == (7, 10)

    def test_semantic_search_input(self):
        """SemanticSearchInput keeps the query and number of results."""
        search = SemanticSearchInput(query="test", num_results=10)
        assert search.query == "test"
        assert search.num_results == 10

    def test_bulk_episode_input(self):
        """BulkEpisodeInput accepts a list of episode payloads."""
        payloads = [{"content": f"Test {idx}"} for idx in (1, 2)]
        batch = BulkEpisodeInput(episodes=payloads)
        assert len(batch.episodes) == 2

    def test_export_graph_data_input(self):
        """ExportGraphDataInput keeps the requested export format."""
        export = ExportGraphDataInput(format="json")
        assert export.format == "json"

    def test_import_graph_data_input(self):
        """ImportGraphDataInput keeps the payload and its declared format."""
        incoming = ImportGraphDataInput(
            data={"nodes": [], "relationships": []},
            format="json",
        )
        assert incoming.format == "json"
        assert "nodes" in incoming.data

    def test_validate_data_input(self):
        """ValidateDataInput keeps every enabled validation check."""
        checks = ValidateDataInput(
            check_orphaned=True,
            check_duplicates=True,
            check_integrity=True,
        )
        # Each flag was passed as True and must remain exactly True.
        for flag_name in ("check_orphaned", "check_duplicates", "check_integrity"):
            assert getattr(checks, flag_name) is True

    def test_clean_orphaned_nodes_input(self):
        """CleanOrphanedNodesInput keeps the confirmation and node type filter."""
        request = CleanOrphanedNodesInput(confirm=True, node_types=["Entity"])
        assert request.confirm is True
        assert request.node_types == ["Entity"]

    def test_rebuild_indexes_input(self):
        """RebuildIndexesInput keeps the list of index types to rebuild."""
        request = RebuildIndexesInput(index_types=["node"])
        assert request.index_types == ["node"]