"""Integration tests for embedding backend switching.
This module verifies that both embedding backends (Ollama and MLX) work correctly
and produce compatible embeddings. MLX-specific tests skip gracefully when
mlx-embeddings is not installed.
Test categories:
1. Provider creation via factory with both backends
2. Single and batch embedding generation
3. Embedding dimensionality consistency (1024-dim for mxbai model)
4. HybridStore integration with both backends
5. Backend switching at runtime
"""
import sys
import uuid
from typing import List
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from recall.config import EmbeddingBackend
from recall.embedding import (
EmbeddingProvider,
MLXNotAvailableError,
OllamaProvider,
create_embedding_provider,
)
from recall.embedding.ollama import EmbeddingError
from recall.storage.chromadb import ChromaStore
from recall.storage.hybrid import HybridStore
from recall.storage.sqlite import SQLiteStore
# Check if MLX is available for conditional test skipping
def _mlx_available() -> bool:
"""Check if mlx-embeddings is installed."""
try:
import mlx_embeddings # noqa: F401
return True
except ImportError:
return False
# Evaluated once at import time so every skipif marker sees a stable value.
MLX_AVAILABLE = _mlx_available()
# Shared reason string used by all MLX-dependent skip markers below.
MLX_SKIP_REASON = "mlx-embeddings not installed (requires Apple Silicon)"
def unique_collection_name() -> str:
    """Build a collision-resistant ChromaDB collection name for one test run."""
    suffix = uuid.uuid4().hex[:8]
    return "test_backends_" + suffix
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def mock_ollama_response():
    """Factory fixture: build a fake httpx response carrying one embedding."""

    def _create_response(embedding: List[float]):
        fake = MagicMock()
        fake.raise_for_status = MagicMock()
        # Ollama's /api/embed returns {"embeddings": [...]}.
        fake.json.return_value = {"embeddings": [embedding]}
        return fake

    return _create_response
@pytest.fixture
def mock_ollama_provider():
    """Async-mocked OllamaProvider that yields fixed 1024-dim vectors."""
    mocked = AsyncMock(spec=OllamaProvider)
    mocked.model = "mxbai-embed-large"
    mocked.host = "http://localhost:11434"
    mocked.timeout = 30.0
    # 1024 dims matches the real mxbai-embed-large output size.
    mocked.embed.return_value = [0.1] * 1024
    mocked.embed_batch.return_value = [[0.1] * 1024, [0.2] * 1024]
    return mocked
def create_mlx_output(embeddings_list: List[List[float]]) -> MagicMock:
    """Mimic the output object returned by mlx_embeddings.generate().

    The real call returns a BaseModelOutput-like object whose
    ``text_embeds`` attribute exposes the embeddings via ``.tolist()``.
    """
    text_embeds = MagicMock()
    text_embeds.tolist.return_value = embeddings_list
    output = MagicMock()
    output.text_embeds = text_embeds
    return output
@pytest.fixture
def mock_mlx_module():
    """Temporarily install a fake mlx_embeddings module into sys.modules."""
    fake = MagicMock()
    fake.load = MagicMock()
    fake.generate = MagicMock()
    # Remember whatever was registered before we clobber it.
    saved = sys.modules.get("mlx_embeddings")
    sys.modules["mlx_embeddings"] = fake
    yield fake
    # Teardown: put back the previous module, or drop our injection entirely.
    if saved is None:
        sys.modules.pop("mlx_embeddings", None)
    else:
        sys.modules["mlx_embeddings"] = saved
@pytest.fixture
def ephemeral_stores():
    """Create ephemeral SQLite and ChromaDB stores for testing.

    Yields a dict with keys "sqlite" and "chroma"; each test gets a freshly
    named Chroma collection for isolation.
    """
    sqlite = SQLiteStore(ephemeral=True)
    chroma = ChromaStore(ephemeral=True, collection_name=unique_collection_name())
    yield {"sqlite": sqlite, "chroma": chroma}
    # Teardown: only the SQLite store is closed explicitly.
    # NOTE(review): the ChromaStore is never closed here — presumably ephemeral
    # collections need no cleanup; confirm ChromaStore exposes no close() contract.
    sqlite.close()
# =============================================================================
# Factory Dispatch Tests
# =============================================================================
class TestFactoryDispatch:
    """Tests for the create_embedding_provider factory function."""

    def test_factory_creates_ollama_provider(self):
        """Factory returns an OllamaProvider satisfying the provider protocol."""
        provider = create_embedding_provider("ollama")
        assert isinstance(provider, OllamaProvider)
        assert isinstance(provider, EmbeddingProvider)

    def test_factory_ollama_with_custom_params(self):
        """Factory forwards host/model/timeout kwargs to OllamaProvider."""
        provider = create_embedding_provider(
            "ollama",
            host="http://custom:11434",
            model="custom-model",
            timeout=60.0,
        )
        assert provider.host == "http://custom:11434"
        assert provider.model == "custom-model"
        assert provider.timeout == 60.0

    @pytest.mark.skipif(not MLX_AVAILABLE, reason=MLX_SKIP_REASON)
    def test_factory_creates_mlx_provider(self):
        """Factory returns an MLXProvider for the 'mlx' backend."""
        from recall.embedding import MLXProvider

        provider = create_embedding_provider("mlx")
        assert isinstance(provider, MLXProvider)
        assert isinstance(provider, EmbeddingProvider)

    @pytest.mark.skipif(not MLX_AVAILABLE, reason=MLX_SKIP_REASON)
    def test_factory_mlx_with_custom_model(self):
        """Factory forwards the mlx_model kwarg to MLXProvider."""
        # Fix: dropped an unused `from recall.embedding import MLXProvider`;
        # the factory call below performs that import itself.
        provider = create_embedding_provider(
            "mlx",
            mlx_model="custom/mlx-model",
        )
        assert provider.model == "custom/mlx-model"

    def test_factory_mlx_raises_when_not_available(self):
        """Factory raises a helpful ImportError when mlx-embeddings is absent."""
        import builtins

        # Fix: removed a redundant method-local `import sys` — the module
        # already imports sys at file level.
        # Simulate MLX being uninstalled by failing the relevant imports.
        original_import = builtins.__import__

        def mock_import(name, *args, **kwargs):
            if name in ("recall.embedding.mlx_provider", "mlx_embeddings"):
                raise ImportError("No module named 'mlx_embeddings'")
            return original_import(name, *args, **kwargs)

        # Drop any cached provider module so the factory must re-import it.
        original_module = sys.modules.pop("recall.embedding.mlx_provider", None)
        try:
            with patch.object(builtins, "__import__", mock_import):
                with pytest.raises(ImportError) as exc_info:
                    create_embedding_provider("mlx")
                # The message should tell the user how to fix the problem.
                error_msg = str(exc_info.value)
                assert "mlx-embeddings" in error_msg
                assert "pip install mlx-embeddings" in error_msg
        finally:
            # Restore the cached module so later tests are unaffected.
            if original_module is not None:
                sys.modules["recall.embedding.mlx_provider"] = original_module

    def test_factory_unknown_backend_raises_value_error(self):
        """Factory rejects unrecognized backend names with ValueError."""
        with pytest.raises(ValueError, match="Unknown embedding backend"):
            create_embedding_provider("unknown")  # type: ignore
# =============================================================================
# Ollama Provider Tests (Mocked)
# =============================================================================
class TestOllamaProviderMocked:
    """Tests for OllamaProvider using mocked HTTP responses."""

    @pytest.mark.asyncio
    async def test_embed_single_text(self, mock_ollama_response):
        """A single embed() call returns the 1024-dim vector from the API."""
        expected_embedding = [0.1] * 1024
        with patch("httpx.AsyncClient") as mock_client_class:
            mock_client = AsyncMock()
            mock_client.post.return_value = mock_ollama_response(expected_embedding)
            # is_closed=False so the provider treats the client as usable.
            mock_client.is_closed = False
            mock_client_class.return_value = mock_client

            provider = create_embedding_provider("ollama")
            result = await provider.embed("test text")

            assert len(result) == 1024
            assert result == expected_embedding

    @pytest.mark.asyncio
    async def test_embed_batch_texts(self, mock_ollama_response):
        """embed_batch() returns one 1024-dim vector per input text."""
        embeddings = [[0.1] * 1024, [0.2] * 1024, [0.3] * 1024]
        with patch("httpx.AsyncClient") as mock_client_class:
            mock_client = AsyncMock()
            # The Ollama API returns all embeddings in a single response body.
            response = MagicMock()
            response.json.return_value = {"embeddings": embeddings}
            response.raise_for_status = MagicMock()
            mock_client.post.return_value = response
            mock_client.is_closed = False
            mock_client_class.return_value = mock_client

            provider = create_embedding_provider("ollama")
            result = await provider.embed_batch(["text1", "text2", "text3"])

            assert len(result) == 3
            # Fix: iterate the results directly — the enumerate() index was unused.
            for emb in result:
                assert len(emb) == 1024
# =============================================================================
# MLX Provider Tests (Mocked)
# =============================================================================
class TestMLXProviderMocked:
    """Tests for MLXProvider backed by a mocked mlx_embeddings module."""

    @pytest.mark.asyncio
    async def test_embed_single_text(self, mock_mlx_module):
        """embed() yields a 1024-dim vector when MLX is mocked."""
        mock_mlx_module.load.return_value = (MagicMock(), MagicMock())
        mock_mlx_module.generate.return_value = create_mlx_output([[0.1] * 1024])

        with patch("recall.embedding.mlx_provider._check_mlx_available"):
            from recall.embedding.mlx_provider import MLXProvider

            async with MLXProvider() as provider:
                embedding = await provider.embed("test text")
                assert len(embedding) == 1024

    @pytest.mark.asyncio
    async def test_embed_batch_texts(self, mock_mlx_module):
        """embed_batch() yields one 1024-dim vector per input text."""
        batch = [[0.1] * 1024, [0.2] * 1024, [0.3] * 1024]
        mock_mlx_module.load.return_value = (MagicMock(), MagicMock())
        mock_mlx_module.generate.return_value = create_mlx_output(batch)

        with patch("recall.embedding.mlx_provider._check_mlx_available"):
            from recall.embedding.mlx_provider import MLXProvider

            async with MLXProvider() as provider:
                vectors = await provider.embed_batch(["text1", "text2", "text3"])
                assert len(vectors) == 3
                assert all(len(v) == 1024 for v in vectors)
# =============================================================================
# Embedding Dimensionality Consistency Tests
# =============================================================================
class TestEmbeddingDimensionality:
    """Tests verifying both providers produce consistent embedding dimensions."""

    @pytest.mark.asyncio
    async def test_ollama_mxbai_produces_1024_dim(self, mock_ollama_response):
        """Test OllamaProvider with mxbai model produces 1024-dim vectors."""
        expected_embedding = [0.1] * 1024
        with patch("httpx.AsyncClient") as mock_client_class:
            mock_client = AsyncMock()
            mock_client.post.return_value = mock_ollama_response(expected_embedding)
            # is_closed=False so the provider treats the client as usable.
            mock_client.is_closed = False
            mock_client_class.return_value = mock_client
            provider = create_embedding_provider("ollama", model="mxbai-embed-large")
            result = await provider.embed("test")
            assert len(result) == 1024

    @pytest.mark.asyncio
    async def test_mlx_mxbai_produces_1024_dim(self, mock_mlx_module):
        """Test MLXProvider with mxbai model produces 1024-dim vectors."""
        mock_model = MagicMock()
        mock_tokenizer = MagicMock()
        mock_mlx_module.load.return_value = (mock_model, mock_tokenizer)
        mock_mlx_module.generate.return_value = create_mlx_output([[0.1] * 1024])
        with patch("recall.embedding.mlx_provider._check_mlx_available"):
            # Import deferred until after the availability check is patched.
            from recall.embedding.mlx_provider import MLXProvider

            provider = MLXProvider(model="mlx-community/mxbai-embed-large-v1")
            async with provider:
                result = await provider.embed("test")
                assert len(result) == 1024

    @pytest.mark.asyncio
    async def test_both_backends_same_dimensionality(self, mock_mlx_module, mock_ollama_response):
        """Test both backends produce same embedding dimensionality."""
        # Mock Ollama
        ollama_embedding = [0.1] * 1024
        with patch("httpx.AsyncClient") as mock_client_class:
            mock_client = AsyncMock()
            mock_client.post.return_value = mock_ollama_response(ollama_embedding)
            mock_client.is_closed = False
            mock_client_class.return_value = mock_client
            ollama_provider = create_embedding_provider("ollama")
            ollama_result = await ollama_provider.embed("test text")
        # Mock MLX — deliberately different values than Ollama's: only the
        # dimensionality needs to match, not the vectors themselves.
        mock_model = MagicMock()
        mock_tokenizer = MagicMock()
        mock_mlx_module.load.return_value = (mock_model, mock_tokenizer)
        mock_mlx_module.generate.return_value = create_mlx_output([[0.2] * 1024])
        with patch("recall.embedding.mlx_provider._check_mlx_available"):
            from recall.embedding.mlx_provider import MLXProvider

            async with MLXProvider() as mlx_provider:
                mlx_result = await mlx_provider.embed("test text")
        # Both should produce 1024-dim vectors
        assert len(ollama_result) == len(mlx_result) == 1024
# =============================================================================
# HybridStore Integration Tests
# =============================================================================
class TestHybridStoreWithBothBackends:
    """Integration tests for HybridStore with different embedding backends."""

    @pytest.mark.asyncio
    async def test_hybrid_store_with_ollama_backend(self, ephemeral_stores, mock_ollama_provider):
        """Test HybridStore full store/recall cycle with Ollama backend."""
        store = HybridStore(
            sqlite_store=ephemeral_stores["sqlite"],
            chroma_store=ephemeral_stores["chroma"],
            embedding_client=mock_ollama_provider,
        )
        # Store a memory
        memory_id = await store.add_memory(
            content="Test memory with Ollama backend",
            memory_type="preference",
            namespace="test:ollama",
        )
        assert memory_id is not None
        # Verify embedding was generated
        mock_ollama_provider.embed.assert_called()
        # Recall should work
        results = await store.search("Test memory", n_results=5)
        assert len(results) >= 1
        assert any(r["id"] == memory_id for r in results)
        await store.close()

    @pytest.mark.asyncio
    async def test_hybrid_store_with_mlx_backend(self, ephemeral_stores, mock_mlx_module):
        """Test HybridStore full store/recall cycle with MLX backend."""
        mock_model = MagicMock()
        mock_tokenizer = MagicMock()
        mock_mlx_module.load.return_value = (mock_model, mock_tokenizer)
        mock_mlx_module.generate.return_value = create_mlx_output([[0.1] * 1024])
        with patch("recall.embedding.mlx_provider._check_mlx_available"):
            from recall.embedding.mlx_provider import MLXProvider

            mlx_provider = MLXProvider()
            store = HybridStore(
                sqlite_store=ephemeral_stores["sqlite"],
                chroma_store=ephemeral_stores["chroma"],
                embedding_client=mlx_provider,
            )
            # Store a memory
            memory_id = await store.add_memory(
                content="Test memory with MLX backend",
                memory_type="preference",
                namespace="test:mlx",
            )
            assert memory_id is not None
            # Verify MLX generate was called for embedding
            mock_mlx_module.generate.assert_called()
            # Recall should work
            results = await store.search("Test memory", n_results=5)
            assert len(results) >= 1
            assert any(r["id"] == memory_id for r in results)
            # NOTE(review): presumably store.close() also closes the embedding
            # client, since mlx_provider is not closed separately — confirm.
            await store.close()

    @pytest.mark.asyncio
    async def test_hybrid_store_create_with_ollama_backend(self):
        """Test HybridStore.create() factory with ollama backend."""
        with patch("httpx.AsyncClient") as mock_client_class:
            mock_client = AsyncMock()
            response = MagicMock()
            response.json.return_value = {"embeddings": [[0.1] * 1024]}
            response.raise_for_status = MagicMock()
            mock_client.post.return_value = response
            # is_closed=False so the provider treats the client as usable.
            mock_client.is_closed = False
            mock_client_class.return_value = mock_client
            store = await HybridStore.create(
                ephemeral=True,
                collection_name=unique_collection_name(),
                embedding_backend="ollama",
            )
            # Store and recall
            memory_id = await store.add_memory(
                content="Test via factory",
                memory_type="preference",
                namespace="test:factory:ollama",
            )
            assert memory_id is not None
            await store.close()

    @pytest.mark.asyncio
    async def test_hybrid_store_create_with_mlx_backend(self, mock_mlx_module):
        """Test HybridStore.create() factory with mlx backend."""
        mock_model = MagicMock()
        mock_tokenizer = MagicMock()
        mock_mlx_module.load.return_value = (mock_model, mock_tokenizer)
        mock_mlx_module.generate.return_value = create_mlx_output([[0.1] * 1024])
        with patch("recall.embedding.mlx_provider._check_mlx_available"):
            store = await HybridStore.create(
                ephemeral=True,
                collection_name=unique_collection_name(),
                embedding_backend="mlx",
            )
            # Store and recall
            memory_id = await store.add_memory(
                content="Test via factory with MLX",
                memory_type="preference",
                namespace="test:factory:mlx",
            )
            assert memory_id is not None
            await store.close()
# =============================================================================
# Backend Switching Tests
# =============================================================================
class TestBackendSwitching:
    """Tests for switching between backends at runtime."""

    @pytest.mark.asyncio
    async def test_memories_searchable_after_backend_switch(
        self, ephemeral_stores, mock_ollama_provider, mock_mlx_module
    ):
        """Test that memories stored with one backend are searchable with another.

        This validates that both backends produce compatible embeddings that work
        with ChromaDB's vector search.
        """
        sqlite = ephemeral_stores["sqlite"]
        chroma = ephemeral_stores["chroma"]
        # Store memory with Ollama backend
        ollama_store = HybridStore(
            sqlite_store=sqlite,
            chroma_store=chroma,
            embedding_client=mock_ollama_provider,
        )
        memory_id = await ollama_store.add_memory(
            content="Important preference stored with Ollama",
            memory_type="preference",
            namespace="test:switch",
        )
        # Create MLX provider
        mock_model = MagicMock()
        mock_tokenizer = MagicMock()
        mock_mlx_module.load.return_value = (mock_model, mock_tokenizer)
        mock_mlx_module.generate.return_value = create_mlx_output([[0.1] * 1024])
        with patch("recall.embedding.mlx_provider._check_mlx_available"):
            from recall.embedding.mlx_provider import MLXProvider

            mlx_provider = MLXProvider()
            # Create new store instance with MLX backend (same underlying stores)
            mlx_store = HybridStore(
                sqlite_store=sqlite,
                chroma_store=chroma,
                embedding_client=mlx_provider,
            )
            # Search should find the memory
            results = await mlx_store.search("Important preference", n_results=5)
            # Memory should be found (SQLite has the metadata)
            assert len(results) >= 1
            assert any(r["id"] == memory_id for r in results)
            await mlx_provider.close()

    def test_different_backends_same_protocol(self, mock_mlx_module):
        """Test that both backends implement EmbeddingProvider protocol."""
        # Ollama
        ollama_provider = create_embedding_provider("ollama")
        assert isinstance(ollama_provider, EmbeddingProvider)
        # MLX (mocked)
        with patch("recall.embedding.mlx_provider._check_mlx_available"):
            from recall.embedding.mlx_provider import MLXProvider

            mlx_provider = MLXProvider()
            assert isinstance(mlx_provider, EmbeddingProvider)
# =============================================================================
# Real Backend Tests (Optional - Skip if services unavailable)
# =============================================================================
@pytest.mark.skipif(
    not MLX_AVAILABLE,
    reason=MLX_SKIP_REASON
)
class TestMLXRealBackend:
    """Real integration tests for the MLX backend.

    These tests require mlx-embeddings to be installed and actually load and
    run the model, so they are skipped off Apple Silicon.
    """

    @pytest.mark.asyncio
    async def test_real_mlx_embed(self):
        """A real embed() call returns a 1024-dim list of floats."""
        from recall.embedding import MLXProvider

        async with MLXProvider() as provider:
            vector = await provider.embed("Hello world", is_query=False)
            assert isinstance(vector, list)
            assert len(vector) == 1024
            assert all(isinstance(value, float) for value in vector)

    @pytest.mark.asyncio
    async def test_real_mlx_embed_batch(self):
        """A real embed_batch() call returns one 1024-dim float list per text."""
        from recall.embedding import MLXProvider

        async with MLXProvider() as provider:
            vectors = await provider.embed_batch(["Hello", "World", "Test"])
            assert len(vectors) == 3
            for vector in vectors:
                assert len(vector) == 1024
                assert all(isinstance(value, float) for value in vector)

    @pytest.mark.asyncio
    async def test_real_mlx_query_prefix(self):
        """Query and document embeddings differ because of the mxbai query prefix."""
        from recall.embedding import MLXProvider

        async with MLXProvider() as provider:
            # is_query=True applies the retrieval prefix; is_query=False does not.
            query_result = await provider.embed("search query", is_query=True)
            doc_result = await provider.embed("document text", is_query=False)
            assert len(query_result) == 1024
            assert len(doc_result) == 1024
            # The prefix must change the produced vector.
            assert query_result != doc_result