test_graphiti_mock.py•64.9 kB
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from datetime import datetime, timedelta
from unittest.mock import Mock
import numpy as np
import pytest
from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters
from graphiti_core.search.search_utils import (
community_fulltext_search,
community_similarity_search,
edge_bfs_search,
edge_fulltext_search,
edge_similarity_search,
episode_fulltext_search,
episode_mentions_reranker,
get_communities_by_nodes,
get_edge_invalidation_candidates,
get_embeddings_for_communities,
get_embeddings_for_edges,
get_embeddings_for_nodes,
get_mentioned_nodes,
get_relevant_edges,
get_relevant_nodes,
node_bfs_search,
node_distance_reranker,
node_fulltext_search,
node_similarity_search,
)
from graphiti_core.utils.bulk_utils import add_nodes_and_edges_bulk
from graphiti_core.utils.maintenance.community_operations import (
determine_entity_community,
get_community_clusters,
remove_communities,
)
from graphiti_core.utils.maintenance.edge_operations import filter_existing_duplicate_of_edges
from tests.helpers_test import (
GraphProvider,
assert_entity_edge_equals,
assert_entity_node_equals,
assert_episodic_edge_equals,
assert_episodic_node_equals,
get_edge_count,
get_node_count,
group_id,
group_id_2,
)
# Ask pytest to load the pytest-asyncio plugin; the coroutine tests in this
# module are marked with ``@pytest.mark.asyncio``.
pytest_plugins = ('pytest_asyncio',)
@pytest.fixture
def mock_llm_client():
    """Return a ``Mock`` constrained to the ``LLMClient`` interface.

    The client carries fixed test settings, and ``generate_response`` is replaced
    by a plain ``Mock`` that returns a canned ``extract_entities`` tool call.
    """
    client = Mock(spec=LLMClient)
    client.config = Mock()

    # Fixed configuration values, applied in declaration order.
    settings = {
        'model': 'test-model',
        'small_model': 'test-small-model',
        'temperature': 0.0,
        'max_tokens': 1000,
        'cache_enabled': False,
        'cache_dir': None,
    }
    for attr_name, attr_value in settings.items():
        setattr(client, attr_name, attr_value)

    extracted_entity = {'entity': 'test_entity', 'entity_type': 'test_type'}
    tool_call = {'name': 'extract_entities', 'arguments': {'entities': [extracted_entity]}}
    # Stub the public method that callers actually invoke.
    client.generate_response = Mock(return_value={'tool_calls': [tool_call]})
    return client
@pytest.fixture
def mock_cross_encoder_client():
    """Create a mock cross-encoder client.

    Returns a ``Mock`` constrained to the ``CrossEncoderClient`` interface whose
    ``rerank`` method is a plain ``Mock`` returning a fixed payload.
    """
    mock_cross_encoder = Mock(spec=CrossEncoderClient)
    mock_cross_encoder.config = Mock()
    # Mock the public method that's actually called
    mock_cross_encoder.rerank = Mock()
    # NOTE(review): this payload is the same ``extract_entities`` tool call used by
    # the LLM fixture and is unlikely to match what a real reranker returns —
    # confirm before any test starts depending on rerank's output.
    mock_cross_encoder.rerank.return_value = {
        'tool_calls': [
            {
                'name': 'extract_entities',
                'arguments': {'entities': [{'entity': 'test_entity', 'entity_type': 'test_type'}]},
            }
        ]
    }
    return mock_cross_encoder
@pytest.mark.asyncio
async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client):
    """Bulk-insert a small graph and check that every node and edge round-trips.

    Two episodes, four entities, four episode-to-entity edges and two
    entity-to-entity edges are written with ``add_nodes_and_edges_bulk``. They are
    counted, then fetched one by one and compared with the in-memory originals.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    timestamp = datetime.now()

    # Episodes; entity_edges is filled in once the entity edges exist.
    first_episode = EpisodicNode(
        name='test_episode',
        group_id=group_id,
        labels=[],
        created_at=timestamp,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Alice likes Bob',
        valid_at=timestamp,
        entity_edges=[],
    )
    second_episode = EpisodicNode(
        name='test_episode_2',
        group_id=group_id,
        labels=[],
        created_at=timestamp,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Bob adores Alice',
        valid_at=timestamp,
        entity_edges=[],
    )

    # Entities as (name, labels, summary, attributes); each is embedded right after
    # construction, in the listed order.
    entity_specs = (
        (
            'test_entity_1',
            ['Entity', 'Person'],
            'test_entity_1 summary',
            {'age': 30, 'location': 'New York'},
        ),
        (
            'test_entity_2',
            ['Entity', 'Person2'],
            'test_entity_2 summary',
            {'age': 25, 'location': 'Los Angeles'},
        ),
        (
            'test_entity_3',
            ['Entity', 'City', 'Location'],
            'test_entity_3 summary',
            {'age': 25, 'location': 'Los Angeles'},
        ),
        (
            'test_entity_4',
            ['Entity'],
            'test_entity_4 summary',
            {'age': 25, 'location': 'Los Angeles'},
        ),
    )
    entities = []
    for entity_name, entity_labels, entity_summary, entity_attributes in entity_specs:
        entity = EntityNode(
            name=entity_name,
            group_id=group_id,
            labels=entity_labels,
            created_at=timestamp,
            summary=entity_summary,
            attributes=entity_attributes,
        )
        await entity.generate_name_embedding(mock_embedder)
        entities.append(entity)
    entity_1, entity_2, entity_3, entity_4 = entities

    # Entity-to-entity edges, all sharing the same timestamp for every temporal field.
    likes_edge = EntityEdge(
        source_node_uuid=entity_1.uuid,
        target_node_uuid=entity_2.uuid,
        created_at=timestamp,
        name='likes',
        fact='test_entity_1 relates to test_entity_2',
        episodes=[],
        expired_at=timestamp,
        valid_at=timestamp,
        invalid_at=timestamp,
        group_id=group_id,
    )
    await likes_edge.generate_embedding(mock_embedder)
    relates_edge = EntityEdge(
        source_node_uuid=entity_3.uuid,
        target_node_uuid=entity_4.uuid,
        created_at=timestamp,
        name='relates_to',
        fact='test_entity_3 relates to test_entity_4',
        episodes=[],
        expired_at=timestamp,
        valid_at=timestamp,
        invalid_at=timestamp,
        group_id=group_id,
    )
    await relates_edge.generate_embedding(mock_embedder)

    # Episode-to-entity edges: the first episode points at entities 1-2, the second at 3-4.
    mention_edges = [
        EpisodicEdge(
            source_node_uuid=src.uuid,
            target_node_uuid=dst.uuid,
            created_at=timestamp,
            group_id=group_id,
        )
        for src, dst in (
            (first_episode, entity_1),
            (first_episode, entity_2),
            (second_episode, entity_3),
            (second_episode, entity_4),
        )
    ]
    mention_1, mention_2, mention_3, mention_4 = mention_edges

    # Link episodes and entity edges to each other by uuid.
    first_episode.entity_edges = [likes_edge.uuid]
    second_episode.entity_edges = [relates_edge.uuid]
    likes_edge.episodes = [first_episode.uuid, second_episode.uuid]
    relates_edge.episodes = [second_episode.uuid]

    await add_nodes_and_edges_bulk(
        graph_driver,
        [first_episode, second_episode],
        mention_edges,
        [entity_1, entity_2, entity_3, entity_4],
        [likes_edge, relates_edge],
        mock_embedder,
    )

    episodes_in = (first_episode, second_episode)
    entities_in = (entity_1, entity_2, entity_3, entity_4)
    mentions_in = (mention_1, mention_2, mention_3, mention_4)
    facts_in = (likes_edge, relates_edge)

    node_ids = [node.uuid for node in (*episodes_in, *entities_in)]
    edge_ids = [edge.uuid for edge in (*mentions_in, *facts_in)]
    assert await get_node_count(graph_driver, node_ids) == len(node_ids)
    assert await get_edge_count(graph_driver, edge_ids) == len(edge_ids)

    for expected in episodes_in:
        fetched = await EpisodicNode.get_by_uuid(graph_driver, expected.uuid)
        await assert_episodic_node_equals(fetched, expected)

    for expected in entities_in:
        fetched = await EntityNode.get_by_uuid(graph_driver, expected.uuid)
        await assert_entity_node_equals(graph_driver, fetched, expected)

    for expected in mentions_in:
        fetched = await EpisodicEdge.get_by_uuid(graph_driver, expected.uuid)
        await assert_episodic_edge_equals(fetched, expected)

    for expected in facts_in:
        fetched = await EntityEdge.get_by_uuid(graph_driver, expected.uuid)
        await assert_entity_edge_equals(graph_driver, fetched, expected)
@pytest.mark.asyncio
async def test_remove_episode(
    graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
):
    """Removing the only episode also removes the entities and edges written with it.

    The same objects are then bulk-inserted a second time and counted again.
    """
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    stamp = datetime.now()

    episode = EpisodicNode(
        name='test_episode',
        group_id=group_id,
        labels=[],
        created_at=stamp,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Alice likes Bob',
        valid_at=stamp,
        entity_edges=[],  # set once the entity edge exists
    )

    alice = EntityNode(
        name='Alice',
        group_id=group_id,
        labels=['Entity', 'Person'],
        created_at=stamp,
        summary='Alice summary',
        attributes={'age': 30, 'location': 'New York'},
    )
    await alice.generate_name_embedding(mock_embedder)
    bob = EntityNode(
        name='Bob',
        group_id=group_id,
        labels=['Entity', 'Person2'],
        created_at=stamp,
        summary='Bob summary',
        attributes={'age': 25, 'location': 'Los Angeles'},
    )
    await bob.generate_name_embedding(mock_embedder)

    likes_edge = EntityEdge(
        source_node_uuid=alice.uuid,
        target_node_uuid=bob.uuid,
        created_at=stamp,
        name='likes',
        fact='Alice likes Bob',
        episodes=[],
        expired_at=stamp,
        valid_at=stamp,
        invalid_at=stamp,
        group_id=group_id,
    )
    await likes_edge.generate_embedding(mock_embedder)

    # One episode-to-entity edge per mentioned person, Alice first.
    alice_mention, bob_mention = (
        EpisodicEdge(
            source_node_uuid=episode.uuid,
            target_node_uuid=person.uuid,
            created_at=stamp,
            group_id=group_id,
        )
        for person in (alice, bob)
    )

    episode.entity_edges = [likes_edge.uuid]
    likes_edge.episodes = [episode.uuid]

    node_ids = [episode.uuid, alice.uuid, bob.uuid]
    edge_ids = [alice_mention.uuid, bob_mention.uuid, likes_edge.uuid]

    async def insert_graph():
        # Write the whole fixture graph in a single bulk call.
        await add_nodes_and_edges_bulk(
            graph_driver,
            [episode],
            [alice_mention, bob_mention],
            [alice, bob],
            [likes_edge],
            mock_embedder,
        )

    async def assert_stored(expected):
        # Node count is checked before edge count, matching the test's original order.
        assert await get_node_count(graph_driver, node_ids) == expected
        assert await get_edge_count(graph_driver, edge_ids) == expected

    await insert_graph()
    await assert_stored(3)

    await graphiti.remove_episode(episode.uuid)
    await assert_stored(0)

    # Re-inserting the same objects after removal must succeed.
    await insert_graph()
    await assert_stored(3)
@pytest.mark.asyncio
async def test_graphiti_retrieve_episodes(
    graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
):
    """Episodes valid after the query time are excluded; the rest come back oldest first.

    The episodes are valid 2, 4 and 6 days ago, and the query time is 3 days ago.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    now = datetime.now()
    # (name, content, how many days before `now` the episode becomes valid)
    episode_specs = (
        ('test_episode_1', 'Test message 1', 2),
        ('test_episode_2', 'Test message 2', 4),
        ('test_episode_3', 'Test message 3', 6),
    )
    created_episodes = [
        EpisodicNode(
            name=title,
            labels=[],
            created_at=now,
            valid_at=now - timedelta(days=days_ago),
            source=EpisodeType.message,
            source_description='conversation message',
            content=body,
            entity_edges=[],
            group_id=group_id,
        )
        for title, body, days_ago in episode_specs
    ]
    for stored in created_episodes:
        await stored.save(graph_driver)
    _, episode_4_days_ago, episode_6_days_ago = created_episodes

    assert await get_node_count(graph_driver, [ep.uuid for ep in created_episodes]) == 3

    cutoff = now - timedelta(days=3)
    retrieved = await graphiti.retrieve_episodes(
        cutoff, last_n=5, group_ids=[group_id], source=EpisodeType.message
    )
    assert len(retrieved) == 2
    assert retrieved[0].name == episode_6_days_ago.name
    assert retrieved[1].name == episode_4_days_ago.name
@pytest.mark.asyncio
async def test_filter_existing_duplicate_of_edges(graph_driver, mock_embedder):
    """Candidate pairs already joined by an IS_DUPLICATE_OF edge are filtered out.

    Only the (1, 2) pair has a stored duplicate edge, so only (3, 4) should remain.
    """
    entities = []
    for entity_name in ('test_entity_1', 'test_entity_2', 'test_entity_3', 'test_entity_4'):
        entity = EntityNode(
            name=entity_name,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await entity.generate_name_embedding(mock_embedder)
        entities.append(entity)
    first, second, third, fourth = entities

    for entity in entities:
        await entity.save(graph_driver)
    assert await get_node_count(graph_driver, [node.uuid for node in entities]) == 4

    # Persist a duplicate marker between the first two entities only.
    duplicate_edge = EntityEdge(
        source_node_uuid=first.uuid,
        target_node_uuid=second.uuid,
        name='IS_DUPLICATE_OF',
        fact='test_entity_1 is a duplicate of test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await duplicate_edge.generate_embedding(mock_embedder)
    await duplicate_edge.save(graph_driver)

    remaining = await filter_existing_duplicate_of_edges(
        graph_driver, [(first, second), (third, fourth)]
    )
    assert len(remaining) == 1
    assert [node.name for node in remaining[0]] == [third.name, fourth.name]
@pytest.mark.asyncio
async def test_determine_entity_community(graph_driver, mock_embedder):
    """Entity 4 should be assigned to community 1, and flagged as new only the first time.

    Entity 4's neighbours are entities 1 and 2 (members of community 1) and entity 3
    (member of community 2). After entity 4 is linked to community 1, the helper
    must return community 1 again with ``is_new`` false. Finally,
    ``remove_communities`` must delete both community nodes.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')

    entities = []
    for entity_name in ('test_entity_1', 'test_entity_2', 'test_entity_3', 'test_entity_4'):
        entity = EntityNode(
            name=entity_name,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await entity.generate_name_embedding(mock_embedder)
        entities.append(entity)
    entity_1, entity_2, entity_3, entity_4 = entities

    # Entities 1-3 each point at entity 4.
    relates_edges = []
    for source, fact in (
        (entity_1, 'test_entity_1 relates to test_entity_4'),
        (entity_2, 'test_entity_2 relates to test_entity_4'),
        (entity_3, 'test_entity_3 relates to test_entity_4'),
    ):
        edge = EntityEdge(
            source_node_uuid=source.uuid,
            target_node_uuid=entity_4.uuid,
            name='RELATES_TO',
            fact=fact,
            created_at=datetime.now(),
            group_id=group_id,
        )
        await edge.generate_embedding(mock_embedder)
        relates_edges.append(edge)

    communities = []
    for community_name in ('test_community_1', 'test_community_2'):
        community_node = CommunityNode(
            name=community_name,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await community_node.generate_name_embedding(mock_embedder)
        communities.append(community_node)
    community_1, community_2 = communities

    # Community 1 holds entities 1 and 2; community 2 holds entity 3.
    membership_edges = [
        CommunityEdge(
            source_node_uuid=owner.uuid,
            target_node_uuid=member.uuid,
            created_at=datetime.now(),
            group_id=group_id,
        )
        for owner, member in (
            (community_1, entity_1),
            (community_1, entity_2),
            (community_2, entity_3),
        )
    ]

    # Save order: entities, communities, entity edges, community edges.
    for graph_object in (*entities, *communities, *relates_edges, *membership_edges):
        await graph_object.save(graph_driver)

    node_ids = [node.uuid for node in (*entities, *communities)]
    edge_ids = [edge.uuid for edge in (*relates_edges, *membership_edges)]
    assert await get_node_count(graph_driver, node_ids) == 6
    assert await get_edge_count(graph_driver, edge_ids) == 6

    community, is_new = await determine_entity_community(graph_driver, entity_4)
    assert community.name == community_1.name
    assert is_new

    # Attach entity 4 to community 1 and ask again.
    late_membership = CommunityEdge(
        source_node_uuid=community_1.uuid,
        target_node_uuid=entity_4.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    await late_membership.save(graph_driver)

    community, is_new = await determine_entity_community(graph_driver, entity_4)
    assert community.name == community_1.name
    assert not is_new

    await remove_communities(graph_driver)
    assert await get_node_count(graph_driver, [community_1.uuid, community_2.uuid]) == 0
@pytest.mark.asyncio
async def test_get_community_clusters(graph_driver, mock_embedder):
    """Two linked entity pairs in separate groups should form two clusters.

    Entities 1-2 belong to ``group_id`` and entities 3-4 to ``group_id_2``; each
    pair is joined by a RELATES_TO edge. ``get_community_clusters`` is called with
    ``group_ids=None`` and must return exactly those two pairs, in any order.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id_2,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)
    entity_node_4 = EntityNode(
        name='test_entity_4',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id_2,
    )
    await entity_node_4.generate_name_embedding(mock_embedder)

    # Create entity edges (one per group)
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_3.uuid,
        target_node_uuid=entity_node_4.uuid,
        name='RELATES_TO',
        fact='test_entity_3 relates to test_entity_4',
        created_at=datetime.now(),
        group_id=group_id_2,
    )
    await entity_edge_2.generate_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_node_4.save(graph_driver)
    await entity_edge_1.save(graph_driver)
    await entity_edge_2.save(graph_driver)

    node_ids = [entity_node_1.uuid, entity_node_2.uuid, entity_node_3.uuid, entity_node_4.uuid]
    edge_ids = [entity_edge_1.uuid, entity_edge_2.uuid]
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == 4
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == 2

    # Get community clusters
    clusters = await get_community_clusters(graph_driver, group_ids=None)
    assert len(clusters) == 2
    assert len(clusters[0]) == 2
    assert len(clusters[1]) == 2
    # Compare clusters as an unordered collection of name sets; the test makes no
    # assumption about which cluster comes first.
    cluster_names = {frozenset(node.name for node in cluster) for cluster in clusters}
    assert cluster_names == {
        frozenset({'test_entity_1', 'test_entity_2'}),
        frozenset({'test_entity_3', 'test_entity_4'}),
    }
@pytest.mark.asyncio
async def test_get_mentioned_nodes(graph_driver, mock_embedder):
    """An entity linked from an episode by an EpisodicEdge is reported as mentioned."""
    episode = EpisodicNode(
        name='test_episodic_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='test_source_description',
        content='test_content',
        valid_at=datetime.now(),
    )

    mentioned = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await mentioned.generate_name_embedding(mock_embedder)

    mention = EpisodicEdge(
        source_node_uuid=episode.uuid,
        target_node_uuid=mentioned.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    # Persist episode, entity, then the edge between them.
    for graph_object in (episode, mentioned, mention):
        await graph_object.save(graph_driver)

    result = await get_mentioned_nodes(graph_driver, [episode])
    assert len(result) == 1
    assert result[0].name == mentioned.name
@pytest.mark.asyncio
async def test_get_communities_by_nodes(graph_driver, mock_embedder):
    """A community linked to an entity by a CommunityEdge is found from that entity."""
    member = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await member.generate_name_embedding(mock_embedder)

    community = CommunityNode(
        name='test_community_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community.generate_name_embedding(mock_embedder)

    membership = CommunityEdge(
        source_node_uuid=community.uuid,
        target_node_uuid=member.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    # Persist entity, community, then the edge between them.
    for graph_object in (member, community, membership):
        await graph_object.save(graph_driver)

    found = await get_communities_by_nodes(graph_driver, [member])
    assert len(found) == 1
    assert found[0].name == community.name
@pytest.mark.asyncio
async def test_edge_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Full-text edge search finds a stored edge by its fact text under date filters."""
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    endpoints = []
    for entity_name in ('test_entity_1', 'test_entity_2'):
        node = EntityNode(
            name=entity_name,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await node.generate_name_embedding(mock_embedder)
        endpoints.append(node)
    source_node, target_node = endpoints

    now = datetime.now()
    one_day = timedelta(days=1)
    # Timeline: created now, valid from +2d, invalid from +4d, expired at +6d.
    edge = EntityEdge(
        source_node_uuid=source_node.uuid,
        target_node_uuid=target_node.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=now,
        valid_at=now + 2 * one_day,
        invalid_at=now + 4 * one_day,
        expired_at=now + 6 * one_day,
        group_id=group_id,
    )
    await edge.generate_embedding(mock_embedder)

    for graph_object in (source_node, target_node, edge):
        await graph_object.save(graph_driver)

    op = ComparisonOperator
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[[DateFilter(date=now, comparison_operator=op.equals)]],
        expired_at=[[DateFilter(date=now, comparison_operator=op.not_equals)]],
        valid_at=[
            [DateFilter(date=now + one_day, comparison_operator=op.greater_than_equal)],
            [DateFilter(date=now + 3 * one_day, comparison_operator=op.less_than_equal)],
        ],
        invalid_at=[
            [DateFilter(date=now + 3 * one_day, comparison_operator=op.greater_than)],
            [DateFilter(date=now + 5 * one_day, comparison_operator=op.less_than)],
        ],
    )

    found = await edge_fulltext_search(
        graph_driver, 'test_entity_1 relates to test_entity_2', search_filters, group_ids=[group_id]
    )
    assert len(found) == 1
    assert found[0].name == edge.name
@pytest.mark.asyncio
async def test_edge_similarity_search(graph_driver, mock_embedder):
    """Vector edge search finds a stored edge from its own fact embedding under date filters."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    endpoints = []
    for entity_name in ('test_entity_1', 'test_entity_2'):
        node = EntityNode(
            name=entity_name,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await node.generate_name_embedding(mock_embedder)
        endpoints.append(node)
    source_node, target_node = endpoints

    now = datetime.now()
    one_day = timedelta(days=1)
    # Timeline: created now, valid from +2d, invalid from +4d, expired at +6d.
    edge = EntityEdge(
        source_node_uuid=source_node.uuid,
        target_node_uuid=target_node.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=now,
        valid_at=now + 2 * one_day,
        invalid_at=now + 4 * one_day,
        expired_at=now + 6 * one_day,
        group_id=group_id,
    )
    await edge.generate_embedding(mock_embedder)

    for graph_object in (source_node, target_node, edge):
        await graph_object.save(graph_driver)

    op = ComparisonOperator
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[[DateFilter(date=now, comparison_operator=op.equals)]],
        expired_at=[[DateFilter(date=now, comparison_operator=op.not_equals)]],
        valid_at=[
            [DateFilter(date=now + one_day, comparison_operator=op.greater_than_equal)],
            [DateFilter(date=now + 3 * one_day, comparison_operator=op.less_than_equal)],
        ],
        invalid_at=[
            [DateFilter(date=now + 3 * one_day, comparison_operator=op.greater_than)],
            [DateFilter(date=now + 5 * one_day, comparison_operator=op.less_than)],
        ],
    )

    found = await edge_similarity_search(
        graph_driver,
        edge.fact_embedding,
        source_node.uuid,
        target_node.uuid,
        search_filters,
        group_ids=[group_id],
    )
    assert len(found) == 1
    assert found[0].name == edge.name
@pytest.mark.asyncio
async def test_edge_bfs_search(graph_driver, mock_embedder):
    """Check which RELATES_TO facts edge_bfs_search reaches at each BFS depth.

    Graph under test: an episode linked by an EpisodicEdge to entity_1, plus
    RELATES_TO edges entity_1 -> entity_2 -> entity_3. Both RELATES_TO edges
    share one created/valid/invalid/expired timestamp window, and every
    DateFilter in the search filter is satisfied by that window. Search results
    are deduplicated by edge uuid before their facts are compared.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    episode = EpisodicNode(
        name='test_episodic_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='test_source_description',
        content='test_content',
        valid_at=datetime.now(),
    )

    entities = []
    for entity_name in ('test_entity_1', 'test_entity_2', 'test_entity_3'):
        entity = EntityNode(
            name=entity_name,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await entity.generate_name_embedding(mock_embedder)
        entities.append(entity)
    first, second, third = entities

    # Timestamp window shared by both RELATES_TO edges.
    now = datetime.now()
    created_at = now
    valid_at = now + timedelta(days=2)
    invalid_at = now + timedelta(days=4)
    expired_at = now + timedelta(days=6)

    fact_1_2 = 'test_entity_1 relates to test_entity_2'
    fact_2_3 = 'test_entity_2 relates to test_entity_3'
    relations = []
    for source, target, fact in ((first, second, fact_1_2), (second, third, fact_2_3)):
        relation = EntityEdge(
            source_node_uuid=source.uuid,
            target_node_uuid=target.uuid,
            name='RELATES_TO',
            fact=fact,
            created_at=created_at,
            valid_at=valid_at,
            invalid_at=invalid_at,
            expired_at=expired_at,
            group_id=group_id,
        )
        await relation.generate_embedding(mock_embedder)
        relations.append(relation)

    mention = EpisodicEdge(
        source_node_uuid=episode.uuid,
        target_node_uuid=first.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    # Save nodes first, then the edges between them.
    for element in (episode, *entities, *relations, mention):
        await element.save(graph_driver)

    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[
            [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
        ],
        expired_at=[
            [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
        ],
        valid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=1),
                    comparison_operator=ComparisonOperator.greater_than_equal,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.less_than_equal,
                )
            ],
        ],
        invalid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.greater_than,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
                )
            ],
        ],
    )

    # (origin uuid, max depth, facts expected after deduplicating by edge uuid).
    # The first three start from the episode and the last two from entity_1.
    expectations = [
        (episode.uuid, 1, set()),
        (episode.uuid, 2, {fact_1_2}),
        (episode.uuid, 3, {fact_1_2, fact_2_3}),
        (first.uuid, 1, {fact_1_2}),
        (first.uuid, 2, {fact_1_2, fact_2_3}),
    ]
    for origin_uuid, depth, expected_facts in expectations:
        found = await edge_bfs_search(
            graph_driver,
            [origin_uuid],
            depth,
            search_filters,
            group_ids=[group_id],
        )
        unique_facts = set({edge.uuid: edge.fact for edge in found}.values())
        assert len(unique_facts) == len(expected_facts)
        assert unique_facts == expected_facts
@pytest.mark.asyncio
async def test_node_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """node_fulltext_search for 'Alice' returns only the entity whose summary mentions Alice.

    Two entity nodes are saved, with summaries 'Summary about Alice' and
    'Summary about Bob'. Skipped on Kuzu, where fulltext indexing is not
    supported.
    """
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')
    # Graphiti is only used to build indices/constraints, presumably including
    # the fulltext index the search below depends on.
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        summary='Summary about Alice',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        summary='Summary about Bob',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    # Search for entity nodes
    search_filters = SearchFilters(node_labels=['Entity'])
    nodes = await node_fulltext_search(
        graph_driver,
        'Alice',
        search_filters,
        group_ids=[group_id],
    )
    assert len(nodes) == 1
    assert nodes[0].name == entity_node_1.name
@pytest.mark.asyncio
async def test_node_similarity_search(graph_driver, mock_embedder):
    """node_similarity_search with Alice's own name embedding returns only Alice.

    Two entity nodes ('test_entity_alice', 'test_entity_bob') are saved. The
    query vector is the Alice node's name embedding, and min_score=0.9 is
    expected to exclude the Bob node.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_alice',
        summary='Summary about Alice',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_bob',
        summary='Summary about Bob',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    # Search for entity nodes
    search_filters = SearchFilters(node_labels=['Entity'])
    nodes = await node_similarity_search(
        graph_driver,
        entity_node_1.name_embedding,
        search_filters,
        group_ids=[group_id],
        min_score=0.9,
    )
    assert len(nodes) == 1
    assert nodes[0].name == entity_node_1.name
@pytest.mark.asyncio
async def test_node_bfs_search(graph_driver, mock_embedder):
    """Check which entity nodes node_bfs_search reaches at each BFS depth.

    Graph under test: an episode linked by an EpisodicEdge to entity_1, plus
    RELATES_TO edges entity_1 -> entity_2 -> entity_3. Search results are
    deduplicated by node uuid before their names are compared.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    episode = EpisodicNode(
        name='test_episodic_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='test_source_description',
        content='test_content',
        valid_at=datetime.now(),
    )

    entities = []
    for entity_name in ('test_entity_1', 'test_entity_2', 'test_entity_3'):
        entity = EntityNode(
            name=entity_name,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await entity.generate_name_embedding(mock_embedder)
        entities.append(entity)
    first, second, third = entities

    relations = []
    for source, target, fact in (
        (first, second, 'test_entity_1 relates to test_entity_2'),
        (second, third, 'test_entity_2 relates to test_entity_3'),
    ):
        relation = EntityEdge(
            source_node_uuid=source.uuid,
            target_node_uuid=target.uuid,
            name='RELATES_TO',
            fact=fact,
            created_at=datetime.now(),
            group_id=group_id,
        )
        await relation.generate_embedding(mock_embedder)
        relations.append(relation)

    mention = EpisodicEdge(
        source_node_uuid=episode.uuid,
        target_node_uuid=first.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    # Save nodes first, then the edges between them.
    for element in (episode, *entities, *relations, mention):
        await element.save(graph_driver)

    entity_only = SearchFilters(
        node_labels=['Entity'],
    )

    # (origin uuid, max depth, names expected after deduplicating by node uuid).
    # The first two start from the episode and the last from entity_1.
    expectations = [
        (episode.uuid, 1, {'test_entity_1'}),
        (episode.uuid, 2, {'test_entity_1', 'test_entity_2'}),
        (first.uuid, 1, {'test_entity_2'}),
    ]
    for origin_uuid, depth, expected_names in expectations:
        found = await node_bfs_search(
            graph_driver,
            [origin_uuid],
            entity_only,
            depth,
            group_ids=[group_id],
        )
        unique_names = set({node.uuid: node.name for node in found}.values())
        assert len(unique_names) == len(expected_names)
        assert unique_names == expected_names
@pytest.mark.asyncio
async def test_episode_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """episode_fulltext_search for 'Alice' returns only the first episode.

    Of the two saved episodes, 'Alice' appears only in the first one's
    source_description. Skipped on Kuzu, where fulltext indexing is not
    supported.
    """
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')

    client = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await client.build_indices_and_constraints()

    alice_episode, bob_episode = (
        EpisodicNode(
            name=episode_name,
            content=episode_content,
            created_at=datetime.now(),
            valid_at=datetime.now(),
            group_id=group_id,
            source=EpisodeType.message,
            source_description=description,
        )
        for episode_name, episode_content, description in (
            ('test_episodic_1', 'test_content', 'Description about Alice'),
            ('test_episodic_2', 'test_content_2', 'Description about Bob'),
        )
    )
    for episode in (alice_episode, bob_episode):
        await episode.save(graph_driver)

    episodic_only = SearchFilters(node_labels=['Episodic'])
    matches = await episode_fulltext_search(
        graph_driver,
        'Alice',
        episodic_only,
        group_ids=[group_id],
    )
    assert len(matches) == 1
    assert matches[0].name == alice_episode.name
@pytest.mark.asyncio
async def test_community_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """community_fulltext_search for 'Alice' returns only the 'Alice' community.

    Skipped on Kuzu, where fulltext indexing is not supported.
    """
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')

    client = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await client.build_indices_and_constraints()

    communities = []
    for community_name in ('Alice', 'Bob'):
        community = CommunityNode(
            name=community_name,
            created_at=datetime.now(),
            group_id=group_id,
        )
        await community.generate_name_embedding(mock_embedder)
        communities.append(community)
    for community in communities:
        await community.save(graph_driver)

    hits = await community_fulltext_search(
        graph_driver,
        'Alice',
        group_ids=[group_id],
    )
    assert len(hits) == 1
    assert hits[0].name == communities[0].name
@pytest.mark.asyncio
async def test_community_similarity_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """community_similarity_search with Alice's own name embedding returns only Alice.

    The query vector is the 'Alice' community's name embedding, and
    min_score=0.9 is expected to exclude the 'Bob' community.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    client = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await client.build_indices_and_constraints()

    alice = CommunityNode(name='Alice', created_at=datetime.now(), group_id=group_id)
    await alice.generate_name_embedding(mock_embedder)
    bob = CommunityNode(name='Bob', created_at=datetime.now(), group_id=group_id)
    await bob.generate_name_embedding(mock_embedder)
    for community in (alice, bob):
        await community.save(graph_driver)

    hits = await community_similarity_search(
        graph_driver,
        alice.name_embedding,
        group_ids=[group_id],
        min_score=0.9,
    )
    assert len(hits) == 1
    assert hits[0].name == alice.name
@pytest.mark.asyncio
async def test_get_relevant_nodes(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """get_relevant_nodes for the 'Alice' node returns 'Alice' and 'Alice Smith'.

    Three entity nodes ('Alice', 'Bob', 'Alice Smith') are saved. Querying with
    the 'Alice' node and min_score=0.9 is expected to exclude 'Bob'. Skipped on
    FalkorDB and Kuzu.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as tests fail on Kuzu')
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='Alice',
        summary='Alice',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='Bob',
        summary='Bob',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='Alice Smith',
        summary='Alice Smith',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)
    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    # Search for entity nodes; [0] selects the result list for the single query node.
    search_filters = SearchFilters(node_labels=['Entity'])
    nodes = (
        await get_relevant_nodes(
            graph_driver,
            [entity_node_1],
            search_filters,
            min_score=0.9,
        )
    )[0]
    assert len(nodes) == 2
    assert {node.name for node in nodes} == {entity_node_1.name, entity_node_3.name}
@pytest.mark.asyncio
async def test_get_relevant_edges_and_invalidation_candidates(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Check get_relevant_edges and get_edge_invalidation_candidates on a 3-edge graph.

    All three edges are named 'RELATES_TO' and share one timestamp window::

        entity_edge_1: entity_1 -> entity_2, fact 'Alice'
        entity_edge_2: entity_2 -> entity_3, fact 'Bob'
        entity_edge_3: entity_1 -> entity_3, fact 'Alice'

    When queried with entity_edge_1 and min_score=0.9, get_relevant_edges is
    expected to return only entity_edge_1. get_edge_invalidation_candidates is
    expected to return entity_edge_1 and entity_edge_3.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        summary='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        summary='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        summary='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)
    now = datetime.now()
    created_at = now
    expired_at = now + timedelta(days=6)
    valid_at = now + timedelta(days=2)
    invalid_at = now + timedelta(days=4)
    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='Alice',
        created_at=created_at,
        expired_at=expired_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_2.uuid,
        target_node_uuid=entity_node_3.uuid,
        name='RELATES_TO',
        fact='Bob',
        created_at=created_at,
        expired_at=expired_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        group_id=group_id,
    )
    await entity_edge_2.generate_embedding(mock_embedder)
    entity_edge_3 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_3.uuid,
        name='RELATES_TO',
        fact='Alice',
        created_at=created_at,
        expired_at=expired_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        group_id=group_id,
    )
    await entity_edge_3.generate_embedding(mock_embedder)
    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_edge_1.save(graph_driver)
    await entity_edge_2.save(graph_driver)
    await entity_edge_3.save(graph_driver)
    # Search for entity edges; every DateFilter below is satisfied by the
    # timestamp window shared by all three edges.
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[
            [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
        ],
        expired_at=[
            [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
        ],
        valid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=1),
                    comparison_operator=ComparisonOperator.greater_than_equal,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.less_than_equal,
                )
            ],
        ],
        invalid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.greater_than,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
                )
            ],
        ],
    )
    # Identify edges by uuid: every edge is named 'RELATES_TO' and edges 1 and 3
    # share the fact 'Alice', so name- or fact-based sets cannot tell them apart.
    edges = (
        await get_relevant_edges(
            graph_driver,
            [entity_edge_1],
            search_filters,
            min_score=0.9,
        )
    )[0]
    assert len(edges) == 1
    assert {edge.uuid for edge in edges} == {entity_edge_1.uuid}
    edges = (
        await get_edge_invalidation_candidates(
            graph_driver,
            [entity_edge_1],
            search_filters,
            min_score=0.9,
        )
    )[0]
    assert len(edges) == 2
    assert {edge.uuid for edge in edges} == {entity_edge_1.uuid, entity_edge_3.uuid}
@pytest.mark.asyncio
async def test_node_distance_reranker(graph_driver, mock_embedder):
    """Rerank candidate nodes by graph distance from a centre node.

    entity_1 is linked to entity_2 by a single RELATES_TO edge, and entity_3 has
    no edges. Reranking [entity_2, entity_3] around entity_1 is expected to
    order entity_2 first (score 1.0) and entity_3 last (score 0.0).
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    nodes = []
    for node_name in ('test_entity_1', 'test_entity_2', 'test_entity_3'):
        node = EntityNode(
            name=node_name,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await node.generate_name_embedding(mock_embedder)
        nodes.append(node)
    center, linked, isolated = nodes

    link = EntityEdge(
        source_node_uuid=center.uuid,
        target_node_uuid=linked.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await link.generate_embedding(mock_embedder)

    for element in (*nodes, link):
        await element.save(graph_driver)

    ranked_uuids, ranked_scores = await node_distance_reranker(
        graph_driver,
        [linked.uuid, isolated.uuid],
        center.uuid,
    )
    name_by_uuid = {node.uuid: node.name for node in nodes}
    ranked_names = [name_by_uuid[ranked] for ranked in ranked_uuids]
    assert ranked_names == [linked.name, isolated.name]
    assert np.allclose(ranked_scores, [1.0, 0.0])
@pytest.mark.asyncio
async def test_episode_mentions_reranker(graph_driver, mock_embedder):
    """episode_mentions_reranker ranks the mentioned entity ahead of the unmentioned one.

    One episode is linked by an EpisodicEdge to entity_1, and entity_2 is never
    linked. The test expects the order [entity_1, entity_2] with scores
    [1.0, inf].
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')
    # Create episodic nodes
    episodic_node_1 = EpisodicNode(
        name='test_episodic_1',
        content='test_content',
        created_at=datetime.now(),
        valid_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='Description about Alice',
    )
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    # Create the episodic edge from the episode to entity_1 (entity_2 is left unlinked)
    episodic_edge_1 = EpisodicEdge(
        source_node_uuid=episodic_node_1.uuid,
        target_node_uuid=entity_node_1.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await episodic_node_1.save(graph_driver)
    await episodic_edge_1.save(graph_driver)
    # Test reranker; the argument is a list of ranked uuid lists (here just one).
    reranked_uuids, reranked_scores = await episode_mentions_reranker(
        graph_driver,
        [[entity_node_1.uuid, entity_node_2.uuid]],
    )
    uuid_to_name = {entity_node_1.uuid: entity_node_1.name, entity_node_2.uuid: entity_node_2.name}
    names = [uuid_to_name[uuid] for uuid in reranked_uuids]
    assert names == [entity_node_1.name, entity_node_2.name]
    # np.allclose treats matching infinities as equal, so the inf entry compares cleanly.
    assert np.allclose(reranked_scores, [1.0, float('inf')])
@pytest.mark.asyncio
async def test_get_embeddings_for_edges(graph_driver, mock_embedder):
    """get_embeddings_for_edges returns the stored fact embedding keyed by edge uuid.

    Saves two entity nodes joined by one RELATES_TO edge, then checks that the
    returned mapping has exactly that edge's uuid and that its value matches
    the locally generated fact_embedding.
    """
    endpoints = []
    for node_name in ('test_entity_1', 'test_entity_2'):
        node = EntityNode(
            name=node_name,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await node.generate_name_embedding(mock_embedder)
        endpoints.append(node)
    source, target = endpoints

    edge = EntityEdge(
        source_node_uuid=source.uuid,
        target_node_uuid=target.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await edge.generate_embedding(mock_embedder)

    for element in (source, target, edge):
        await element.save(graph_driver)

    stored = await get_embeddings_for_edges(graph_driver, [edge])
    assert len(stored) == 1
    assert edge.uuid in stored
    assert np.allclose(stored[edge.uuid], edge.fact_embedding)
@pytest.mark.asyncio
async def test_get_embeddings_for_nodes(graph_driver, mock_embedder):
    """get_embeddings_for_nodes returns the stored name embedding keyed by node uuid.

    Saves one entity node, then checks that the returned mapping has exactly
    that node's uuid and that its value matches the locally generated
    name_embedding.
    """
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    # Save the graph
    await entity_node_1.save(graph_driver)
    # Get embeddings for nodes
    embeddings = await get_embeddings_for_nodes(graph_driver, [entity_node_1])
    assert len(embeddings) == 1
    assert entity_node_1.uuid in embeddings
    assert np.allclose(embeddings[entity_node_1.uuid], entity_node_1.name_embedding)
@pytest.mark.asyncio
async def test_get_embeddings_for_communities(graph_driver, mock_embedder):
    """get_embeddings_for_communities returns the stored name embedding keyed by uuid.

    Saves one community node, then checks that the returned mapping has exactly
    that community's uuid and that its value matches the locally generated
    name_embedding.
    """
    community = CommunityNode(
        name='test_community_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community.generate_name_embedding(mock_embedder)
    await community.save(graph_driver)

    stored = await get_embeddings_for_communities(graph_driver, [community])
    expected_vector = community.name_embedding
    assert len(stored) == 1
    assert community.uuid in stored
    assert np.allclose(stored[community.uuid], expected_vector)