"""Unit tests for StoreQueue embedding methods.
Tests the embedding storage functionality added to StoreQueue, including:
- update_embedding: Store computed embeddings
- get_embedding: Retrieve embeddings by queue ID
- Embedding serialization/deserialization via struct.pack/unpack
- Float precision preservation
- Large embedding (1024-dim) support
"""
import struct
import sys
import tempfile
from collections.abc import Iterator
from pathlib import Path

import pytest
# Put ~/.claude/hooks at the front of sys.path (only once) so that the
# recall_queue import that follows can be resolved.
hooks_path = Path.home().joinpath(".claude", "hooks")
if not sys.path.count(str(hooks_path)):
    sys.path[:0] = [str(hooks_path)]
from recall_queue import QueuedStore, StoreQueue
class TestStoreQueueEmbedding:
    """Tests for StoreQueue embedding storage and retrieval.

    Every test receives a fresh queue from the ``temp_queue`` fixture, so
    queue IDs and stored embeddings never leak between tests.
    """

    @pytest.fixture
    def temp_queue(self, tmp_path: Path) -> Iterator[StoreQueue]:
        """Yield a StoreQueue stored under pytest's ``tmp_path``; close it on teardown."""
        db_path = tmp_path / "test_queue.db"
        queue = StoreQueue(db_path=db_path)
        yield queue
        # Runs as fixture teardown, whether or not the test passed.
        queue.close()

    @pytest.fixture
    def sample_embedding(self) -> list[float]:
        """Create a sample 1024-dim embedding like mxbai-embed-large produces."""
        return [0.1 * (i % 10) + 0.001 * i for i in range(1024)]

    @staticmethod
    def _assert_embeddings_close(actual: list[float], expected: list[float], tolerance: float = 1e-6) -> None:
        """Assert that ``actual`` has the same length as ``expected`` and matches it element-wise.

        Args:
            actual: Embedding read back from the queue.
            expected: Embedding that was originally stored.
            tolerance: Maximum allowed absolute difference per element.
        """
        assert len(actual) == len(expected), f"Expected {len(expected)} floats, got {len(actual)}"
        for i, (want, got) in enumerate(zip(expected, actual)):
            assert abs(want - got) < tolerance, f"Mismatch at index {i}: expected {want}, got {got}"

    def test_update_embedding_stores_correctly(self, temp_queue: StoreQueue, sample_embedding: list[float]) -> None:
        """Test that update_embedding stores embedding correctly and roundtrips."""
        # Create an entry without embedding
        entry = QueuedStore(
            content="Test memory for embedding update",
            namespace="test:embedding",
            memory_type="preference",
            importance=0.7,
        )
        queue_id = temp_queue.enqueue(entry)

        # Verify no embedding initially
        initial = temp_queue.get_embedding(queue_id)
        assert initial is None, "Expected no embedding before update"

        # Update with embedding
        success = temp_queue.update_embedding(queue_id, sample_embedding)
        assert success is True, "Expected update_embedding to return True"

        # Retrieve and verify every value
        retrieved = temp_queue.get_embedding(queue_id)
        assert retrieved is not None, "Expected embedding to be retrieved"
        self._assert_embeddings_close(retrieved, sample_embedding)

    def test_get_embedding_retrieves_correct_embedding(self, temp_queue: StoreQueue) -> None:
        """Test that get_embedding retrieves the correct embedding for a given ID."""
        # Create two entries with different embeddings
        entry1 = QueuedStore(
            content="First memory",
            namespace="test:multi",
            memory_type="pattern",
            importance=0.5,
        )
        entry2 = QueuedStore(
            content="Second memory",
            namespace="test:multi",
            memory_type="decision",
            importance=0.6,
        )
        id1 = temp_queue.enqueue(entry1)
        id2 = temp_queue.enqueue(entry2)

        # Set different embeddings
        embedding1 = [1.0] * 1024
        embedding2 = [2.0] * 1024
        temp_queue.update_embedding(id1, embedding1)
        temp_queue.update_embedding(id2, embedding2)

        # Retrieve and verify each ID maps to its own embedding, over the
        # whole vector rather than just the first element.
        retrieved1 = temp_queue.get_embedding(id1)
        retrieved2 = temp_queue.get_embedding(id2)
        assert retrieved1 is not None, "Entry 1 should have embedding [1.0, ...]"
        assert retrieved2 is not None, "Entry 2 should have embedding [2.0, ...]"
        self._assert_embeddings_close(retrieved1, embedding1)
        self._assert_embeddings_close(retrieved2, embedding2)

    def test_get_embedding_returns_none_for_missing_entry(self, temp_queue: StoreQueue) -> None:
        """Test that get_embedding returns None for non-existent queue ID."""
        result = temp_queue.get_embedding(99999)
        assert result is None, "Expected None for non-existent queue ID"

        result = temp_queue.get_embedding(-1)
        assert result is None, "Expected None for invalid queue ID"

    def test_get_embedding_returns_none_for_entry_without_embedding(self, temp_queue: StoreQueue) -> None:
        """Test that get_embedding returns None for entry that has no embedding set."""
        entry = QueuedStore(
            content="Memory without embedding",
            namespace="test:none",
            memory_type="session",
            importance=0.3,
        )
        queue_id = temp_queue.enqueue(entry)

        result = temp_queue.get_embedding(queue_id)
        assert result is None, "Expected None for entry without embedding"

    def test_dequeue_batch_returns_embedding_field(self, temp_queue: StoreQueue, sample_embedding: list[float]) -> None:
        """Test that dequeue_batch returns entries with embedding field populated."""
        # Create entry with pre-computed embedding
        entry_with_embedding = QueuedStore(
            content="Memory with pre-computed embedding",
            namespace="test:dequeue",
            memory_type="preference",
            importance=0.8,
            embedding=sample_embedding,
        )
        queue_id = temp_queue.enqueue(entry_with_embedding)

        # Dequeue and verify embedding is included
        batch = temp_queue.dequeue_batch(batch_size=10)
        assert len(batch) >= 1

        # Find our entry
        our_entry = next((e for e in batch if e.id == queue_id), None)
        assert our_entry is not None, "Expected our entry in the batch"
        assert our_entry.embedding is not None, "Expected embedding to be present"

        # Verify length and all values, not only a leading slice
        self._assert_embeddings_close(our_entry.embedding, sample_embedding)

    def test_embedding_serialization_preserves_float_precision(self, temp_queue: StoreQueue) -> None:
        """Test that embedding serialization preserves float precision correctly."""
        # Values that exercise precision edge cases
        significant_values = [
            0.123456789,  # Many decimal places
            -0.987654321,  # Negative value
            1e-7,  # Very small positive
            -1e-7,  # Very small negative
            3.14159265358979,  # Pi approximation
            2.71828182845904,  # e approximation
            0.0,  # Zero
            1.0,  # One
            -1.0,  # Negative one
            999999.0,  # Large finite value (infinity is deliberately not tested)
        ]
        # Pad to 1024 dimensions
        precision_test_embedding = significant_values + [0.0] * (1024 - len(significant_values))

        entry = QueuedStore(
            content="Precision test memory",
            namespace="test:precision",
            memory_type="pattern",
            importance=0.5,
        )
        queue_id = temp_queue.enqueue(entry)
        temp_queue.update_embedding(queue_id, precision_test_embedding)

        retrieved = temp_queue.get_embedding(queue_id)
        assert retrieved is not None
        # Guard the zip() below against silently truncating a short result.
        assert len(retrieved) == len(precision_test_embedding)

        # Check precision for the significant (non-padding) values
        for i, (expected, actual) in enumerate(zip(significant_values, retrieved)):
            # struct.pack with 'f' format gives single precision (32-bit float)
            # so we should expect ~6-7 significant digits of precision
            if abs(expected) > 1e-10:
                relative_error = abs(actual - expected) / abs(expected)
                assert relative_error < 1e-5, f"Precision loss at index {i}: expected {expected}, got {actual}"
            else:
                # Relative error is undefined at zero, so use absolute comparison
                assert abs(actual - expected) < 1e-6

    def test_large_embedding_1024_floats_works_correctly(self, temp_queue: StoreQueue) -> None:
        """Test that full 1024-dimension embedding (mxbai-embed-large size) works correctly."""
        # Varied values: a ramp from -1 towards 1 plus a small periodic offset
        large_embedding = [(i / 1024.0) * 2.0 - 1.0 + (i % 7) * 0.01 for i in range(1024)]

        # Test via enqueue with embedding
        entry_enqueue = QueuedStore(
            content="Large embedding via enqueue",
            namespace="test:large",
            memory_type="preference",
            importance=0.9,
            embedding=large_embedding,
        )
        id_enqueue = temp_queue.enqueue(entry_enqueue)

        # Retrieve via dequeue_batch
        batch = temp_queue.dequeue_batch(batch_size=10)
        entry_from_batch = next((e for e in batch if e.id == id_enqueue), None)
        assert entry_from_batch is not None
        assert entry_from_batch.embedding is not None
        self._assert_embeddings_close(entry_from_batch.embedding, large_embedding, tolerance=1e-5)

        # Test via update_embedding
        entry_update = QueuedStore(
            content="Large embedding via update",
            namespace="test:large",
            memory_type="decision",
            importance=0.8,
        )
        id_update = temp_queue.enqueue(entry_update)
        temp_queue.update_embedding(id_update, large_embedding)

        retrieved = temp_queue.get_embedding(id_update)
        assert retrieved is not None
        self._assert_embeddings_close(retrieved, large_embedding, tolerance=1e-5)

    def test_update_embedding_returns_false_for_nonexistent_id(self, temp_queue: StoreQueue) -> None:
        """Test that update_embedding returns False for non-existent queue ID."""
        fake_embedding = [0.0] * 1024
        result = temp_queue.update_embedding(99999, fake_embedding)
        assert result is False, "Expected False for non-existent queue ID"

    def test_embedding_overwrites_previous_value(self, temp_queue: StoreQueue) -> None:
        """Test that calling update_embedding twice overwrites the previous embedding."""
        entry = QueuedStore(
            content="Overwrite test memory",
            namespace="test:overwrite",
            memory_type="pattern",
            importance=0.5,
        )
        queue_id = temp_queue.enqueue(entry)

        # First embedding
        first_embedding = [1.0] * 1024
        temp_queue.update_embedding(queue_id, first_embedding)
        retrieved_first = temp_queue.get_embedding(queue_id)
        assert retrieved_first is not None
        self._assert_embeddings_close(retrieved_first, first_embedding)

        # Second embedding (overwrite); no element of the old vector may survive
        second_embedding = [2.0] * 1024
        temp_queue.update_embedding(queue_id, second_embedding)
        retrieved_second = temp_queue.get_embedding(queue_id)
        assert retrieved_second is not None, "Expected embedding to be overwritten"
        self._assert_embeddings_close(retrieved_second, second_embedding)
class TestQueuedStoreDataclass:
    """Checks on the ``embedding`` field of the QueuedStore dataclass."""

    def test_queued_store_with_embedding(self):
        """An embedding passed at construction is available on the instance."""
        vector = [0.1] * 1024
        store = QueuedStore(
            content="Test content",
            namespace="test",
            memory_type="preference",
            importance=0.5,
            embedding=vector,
        )

        assert store.embedding is not None
        assert len(store.embedding) == 1024
        assert store.embedding[0] == 0.1

    def test_queued_store_without_embedding(self):
        """Leaving ``embedding`` out of the constructor yields None."""
        store = QueuedStore(
            content="Test content",
            namespace="test",
            memory_type="preference",
            importance=0.5,
        )

        assert store.embedding is None

    def test_queued_store_to_dict_includes_embedding(self):
        """``to_dict`` carries the embedding under the ``"embedding"`` key."""
        vector = [0.5] * 10
        payload = QueuedStore(
            content="Test content",
            namespace="test",
            memory_type="preference",
            importance=0.5,
            embedding=vector,
        ).to_dict()

        assert "embedding" in payload
        assert payload["embedding"] == vector

    def test_queued_store_to_dict_with_none_embedding(self):
        """``to_dict`` still has an ``"embedding"`` key, set to None, when none was given."""
        payload = QueuedStore(
            content="Test content",
            namespace="test",
            memory_type="preference",
            importance=0.5,
        ).to_dict()

        assert "embedding" in payload
        assert payload["embedding"] is None
class TestStoreQueueEnqueueWithEmbedding:
    """Tests for enqueue with pre-computed embedding."""

    @pytest.fixture
    def temp_queue(self, tmp_path: Path) -> Iterator[StoreQueue]:
        """Yield a StoreQueue stored under pytest's ``tmp_path``; close it on teardown."""
        db_path = tmp_path / "test_queue.db"
        queue = StoreQueue(db_path=db_path)
        yield queue
        # Runs as fixture teardown, whether or not the test passed.
        queue.close()

    def test_enqueue_with_embedding_stores_embedding(self, temp_queue: StoreQueue) -> None:
        """Test that enqueue with embedding stores it in the database."""
        embedding = [0.25] * 1024
        entry = QueuedStore(
            content="Pre-embedded memory",
            namespace="test:pre",
            memory_type="preference",
            importance=0.6,
            embedding=embedding,
        )
        queue_id = temp_queue.enqueue(entry)

        # Verify via get_embedding
        retrieved = temp_queue.get_embedding(queue_id)
        assert retrieved is not None
        assert len(retrieved) == len(embedding)
        # Check every element so corruption past index 0 cannot go unnoticed.
        for i, (expected, actual) in enumerate(zip(embedding, retrieved)):
            assert abs(actual - expected) < 1e-6, f"Mismatch at index {i}: expected {expected}, got {actual}"

    def test_enqueue_without_embedding_stores_null(self, temp_queue: StoreQueue) -> None:
        """Test that enqueue without embedding stores NULL in database."""
        entry = QueuedStore(
            content="Non-embedded memory",
            namespace="test:none",
            memory_type="preference",
            importance=0.6,
        )
        queue_id = temp_queue.enqueue(entry)

        # Verify via get_embedding
        retrieved = temp_queue.get_embedding(queue_id)
        assert retrieved is None
class TestEmbeddingBLOBFormat:
    """Tests for the BLOB serialization format of embeddings."""

    def test_struct_pack_format_matches_expected(self) -> None:
        """Test that struct.pack format produces expected BLOB size."""
        embedding = [0.0] * 1024
        blob = struct.pack(f"{len(embedding)}f", *embedding)

        # Each float is 4 bytes, so 1024 floats = 4096 bytes
        assert len(blob) == 4096, f"Expected 4096 bytes, got {len(blob)}"

    def test_struct_unpack_roundtrip(self) -> None:
        """Test that struct pack/unpack roundtrip preserves values."""
        original = [0.1 * i for i in range(1024)]

        # Pack
        blob = struct.pack(f"{len(original)}f", *original)

        # Unpack, deriving the element count from the blob size (4 bytes each)
        num_floats = len(blob) // 4
        unpacked = list(struct.unpack(f"{num_floats}f", blob))

        assert len(unpacked) == len(original)
        for i, (expected, actual) in enumerate(zip(original, unpacked)):
            assert abs(actual - expected) < 1e-5, f"Mismatch at index {i}: expected {expected}, got {actual}"

    def test_empty_embedding_not_stored_as_blob(self, tmp_path: Path) -> None:
        """Test behavior with empty embedding list.

        Storing an empty embedding must never produce a non-empty one on
        read-back: the result is either None or an empty sequence, depending
        on how the implementation treats a zero-length value.
        """
        queue = StoreQueue(db_path=tmp_path / "test.db")
        try:
            entry = QueuedStore(
                content="Test",
                namespace="test",
                memory_type="preference",
                importance=0.5,
            )
            queue_id = queue.enqueue(entry)

            empty_embedding: list[float] = []
            queue.update_embedding(queue_id, empty_embedding)

            retrieved = queue.get_embedding(queue_id)
            assert retrieved is None or len(retrieved) == 0, (
                f"Expected None or an empty embedding, got {retrieved!r}"
            )
        finally:
            # Close the queue even when an assertion or a queue call fails.
            queue.close()