"""Unit tests for MLX embedding provider."""
import sys
from typing import List
from unittest.mock import MagicMock, patch
import pytest
from recall.embedding.mlx_provider import (
EMBED_PREFIX,
EmbeddingError,
MLXNotAvailableError,
MLXProvider,
)
from recall.embedding.provider import EmbeddingProvider
# Create a fake mlx_embeddings module for testing
@pytest.fixture
def mock_mlx_module():
    """Temporarily register a fake ``mlx_embeddings`` module in ``sys.modules``.

    The fake is a ``MagicMock`` exposing ``load`` and ``generate`` mocks and is
    yielded to the test. After the test, whatever was registered under
    ``"mlx_embeddings"`` beforehand is put back; if nothing was registered,
    the entry is removed again.
    """
    module_name = "mlx_embeddings"

    # Remember what (if anything) was registered before replacing it.
    previous = sys.modules.get(module_name)

    fake = MagicMock(load=MagicMock(), generate=MagicMock())
    sys.modules[module_name] = fake

    yield fake

    # Teardown: undo the injection.
    if previous is None:
        sys.modules.pop(module_name, None)
    else:
        sys.modules[module_name] = previous
def create_mlx_output(embeddings_list: List[List[float]]) -> MagicMock:
    """Build a stand-in for the object returned by ``mlx_embeddings.generate()``.

    The returned mock has a ``text_embeds`` attribute whose ``tolist()``
    method returns ``embeddings_list`` unchanged, mirroring the
    BaseModelOutput-like object that the real ``generate()`` returns.

    Args:
        embeddings_list: The embeddings that ``text_embeds.tolist()`` should
            hand back.

    Returns:
        A ``MagicMock`` configured as described above.
    """
    output = MagicMock()
    # Accessing text_embeds / tolist creates the child mocks on demand.
    output.text_embeds.tolist.return_value = embeddings_list
    return output
class TestMLXProviderInit:
    """Constructor behaviour of ``MLXProvider``."""

    def test_default_values(self):
        """A provider built without arguments uses the default mxbai model."""
        default_provider = MLXProvider()

        assert default_provider.model == "mlx-community/mxbai-embed-large-v1"
        assert default_provider._is_mxbai is True
        # Neither the model nor the tokenizer is set at construction time.
        assert default_provider._model_instance is None
        assert default_provider._tokenizer is None

    def test_custom_model(self):
        """A custom model name is stored and is not flagged as mxbai."""
        custom_provider = MLXProvider(model="custom/model")

        assert custom_provider.model == "custom/model"
        assert custom_provider._is_mxbai is False

    def test_mxbai_detection_case_insensitive(self):
        """A mixed-case model name containing "mxbai" is still flagged as mxbai."""
        mixed_case_provider = MLXProvider(model="MxBai-LARGE-Model")

        assert mixed_case_provider._is_mxbai is True
class TestEmbedPrefix:
    """Checks on the query prefix constant."""

    def test_prefix_constant(self):
        """EMBED_PREFIX equals the Ollama module's constant and the expected literal."""
        from recall.embedding.ollama import EMBED_PREFIX as ollama_prefix

        expected_literal = "Represent this sentence for searching relevant passages: "

        assert EMBED_PREFIX == ollama_prefix
        assert EMBED_PREFIX == expected_literal
class TestMLXProviderProtocol:
    """``MLXProvider`` must satisfy the ``EmbeddingProvider`` protocol."""

    def test_provider_conforms_to_protocol(self):
        """An MLXProvider instance passes an isinstance check against EmbeddingProvider."""
        assert isinstance(MLXProvider(), EmbeddingProvider)

    def test_provider_has_required_methods(self):
        """An MLXProvider instance exposes every attribute the tests rely on."""
        provider = MLXProvider()
        required_names = ("embed", "embed_batch", "close", "__aenter__", "__aexit__")

        for attribute_name in required_names:
            assert hasattr(provider, attribute_name)
class TestMLXNotAvailable:
    """Behaviour when the mlx-embeddings dependency is missing."""

    @pytest.mark.asyncio
    async def test_embed_raises_when_mlx_not_installed(self):
        """embed() propagates MLXNotAvailableError raised by the availability check."""
        provider = MLXProvider()

        # Make the availability check fail as it would without the package.
        failing_check = patch(
            "recall.embedding.mlx_provider._check_mlx_available",
            side_effect=MLXNotAvailableError("mlx-embeddings is not installed"),
        )
        expect_error = pytest.raises(
            MLXNotAvailableError, match="mlx-embeddings is not installed"
        )

        with failing_check:
            with expect_error:
                await provider.embed("test")

    def test_mlx_not_available_error_is_exception(self):
        """MLXNotAvailableError derives from Exception."""
        assert issubclass(MLXNotAvailableError, Exception)
class TestEmbed:
    """Single-text embedding against a mocked ``mlx_embeddings`` module."""

    # Dotted path of the availability check that every mocked test disables.
    _CHECK_TARGET = "recall.embedding.mlx_provider._check_mlx_available"

    @pytest.mark.asyncio
    async def test_embed_document_no_prefix(self, mock_mlx_module):
        """A document (is_query=False) is passed to generate() without a prefix."""
        mock_mlx_module.load.return_value = (MagicMock(), MagicMock())
        mock_mlx_module.generate.return_value = create_mlx_output([[0.1, 0.2, 0.3]])

        with patch(self._CHECK_TARGET):
            async with MLXProvider() as provider:
                vector = await provider.embed("test document", is_query=False)

                assert vector == [0.1, 0.2, 0.3]
                mock_mlx_module.generate.assert_called_once()
                sent_texts = mock_mlx_module.generate.call_args.kwargs["texts"]
                assert sent_texts == ["test document"]

    @pytest.mark.asyncio
    async def test_embed_query_with_prefix(self, mock_mlx_module):
        """A query on the default (mxbai) model is sent with EMBED_PREFIX prepended."""
        mock_mlx_module.load.return_value = (MagicMock(), MagicMock())
        mock_mlx_module.generate.return_value = create_mlx_output([[0.1, 0.2, 0.3]])

        with patch(self._CHECK_TARGET):
            async with MLXProvider() as provider:
                vector = await provider.embed("test query", is_query=True)

                assert vector == [0.1, 0.2, 0.3]
                expected_input = f"{EMBED_PREFIX}test query"
                sent_texts = mock_mlx_module.generate.call_args.kwargs["texts"]
                assert sent_texts == [expected_input]

    @pytest.mark.asyncio
    async def test_embed_query_no_prefix_non_mxbai(self, mock_mlx_module):
        """A query on a non-mxbai model is sent to generate() unmodified."""
        mock_mlx_module.load.return_value = (MagicMock(), MagicMock())
        mock_mlx_module.generate.return_value = create_mlx_output([[0.1, 0.2, 0.3]])

        with patch(self._CHECK_TARGET):
            async with MLXProvider(model="nomic-embed-text") as provider:
                vector = await provider.embed("test query", is_query=True)

                assert vector == [0.1, 0.2, 0.3]
                sent_texts = mock_mlx_module.generate.call_args.kwargs["texts"]
                # No prefix expected for this model.
                assert sent_texts == ["test query"]

    @pytest.mark.asyncio
    async def test_embed_empty_text_raises(self):
        """An empty string is rejected with ValueError."""
        async with MLXProvider() as provider:
            with pytest.raises(ValueError, match="Text cannot be empty"):
                await provider.embed("")

    @pytest.mark.asyncio
    async def test_embed_whitespace_only_raises(self):
        """A whitespace-only string is rejected with ValueError."""
        async with MLXProvider() as provider:
            with pytest.raises(ValueError, match="Text cannot be empty"):
                await provider.embed(" ")

    @pytest.mark.asyncio
    async def test_embed_no_embeddings_returned(self, mock_mlx_module):
        """An empty embeddings list from generate() surfaces as EmbeddingError."""
        mock_mlx_module.load.return_value = (MagicMock(), MagicMock())
        mock_mlx_module.generate.return_value = create_mlx_output([])

        with patch(self._CHECK_TARGET):
            async with MLXProvider() as provider:
                with pytest.raises(EmbeddingError, match="No embedding returned"):
                    await provider.embed("test")
class TestEmbedBatch:
    """Tests for batch embedding with mocked MLX."""

    @pytest.mark.asyncio
    async def test_embed_batch_single_batch(self, mock_mlx_module):
        """Three texts that need only one generate() result come back in order."""
        mock_model = MagicMock()
        mock_tokenizer = MagicMock()
        mock_mlx_module.load.return_value = (mock_model, mock_tokenizer)
        mock_mlx_module.generate.return_value = create_mlx_output(
            [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
        )

        with patch("recall.embedding.mlx_provider._check_mlx_available"):
            async with MLXProvider() as provider:
                texts = ["text1", "text2", "text3"]
                result = await provider.embed_batch(texts)

                assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

    @pytest.mark.asyncio
    async def test_embed_batch_multiple_batches(self, mock_mlx_module):
        """Three texts with batch_size=2 are embedded in two generate() calls.

        Besides the concatenated result, this checks which texts were sent in
        each call, so a wrong split boundary cannot go unnoticed.
        """
        mock_model = MagicMock()
        mock_tokenizer = MagicMock()
        mock_mlx_module.load.return_value = (mock_model, mock_tokenizer)
        # One canned output per expected generate() call.
        mock_mlx_module.generate.side_effect = [
            create_mlx_output([[0.1, 0.2], [0.3, 0.4]]),
            create_mlx_output([[0.5, 0.6]]),
        ]

        with patch("recall.embedding.mlx_provider._check_mlx_available"):
            async with MLXProvider() as provider:
                texts = ["text1", "text2", "text3"]
                result = await provider.embed_batch(texts, batch_size=2)

                assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

                # First call gets the first two texts, second call the remainder.
                sent_batches = [
                    call.kwargs["texts"]
                    for call in mock_mlx_module.generate.call_args_list
                ]
                assert sent_batches == [["text1", "text2"], ["text3"]]

    @pytest.mark.asyncio
    async def test_embed_batch_with_query_prefix(self, mock_mlx_module):
        """With is_query=True every text in the batch is sent with EMBED_PREFIX."""
        mock_model = MagicMock()
        mock_tokenizer = MagicMock()
        mock_mlx_module.load.return_value = (mock_model, mock_tokenizer)
        mock_mlx_module.generate.return_value = create_mlx_output([[0.1, 0.2], [0.3, 0.4]])

        with patch("recall.embedding.mlx_provider._check_mlx_available"):
            async with MLXProvider() as provider:
                texts = ["query1", "query2"]
                result = await provider.embed_batch(texts, is_query=True)

                assert result == [[0.1, 0.2], [0.3, 0.4]]
                call_args = mock_mlx_module.generate.call_args
                expected_inputs = [f"{EMBED_PREFIX}query1", f"{EMBED_PREFIX}query2"]
                assert call_args.kwargs["texts"] == expected_inputs

    @pytest.mark.asyncio
    async def test_embed_batch_empty_list_raises(self):
        """An empty texts list is rejected with ValueError."""
        async with MLXProvider() as provider:
            with pytest.raises(ValueError, match="Texts list cannot be empty"):
                await provider.embed_batch([])

    @pytest.mark.asyncio
    async def test_embed_batch_wrong_count(self, mock_mlx_module):
        """Fewer embeddings than input texts surfaces as EmbeddingError."""
        mock_model = MagicMock()
        mock_tokenizer = MagicMock()
        mock_mlx_module.load.return_value = (mock_model, mock_tokenizer)
        # Only 1 embedding is returned although 3 texts are submitted.
        mock_mlx_module.generate.return_value = create_mlx_output([[0.1, 0.2]])

        with patch("recall.embedding.mlx_provider._check_mlx_available"):
            async with MLXProvider() as provider:
                with pytest.raises(EmbeddingError, match="Expected 3 embeddings"):
                    await provider.embed_batch(["t1", "t2", "t3"])
class TestContextManager:
    """Resource cleanup via ``async with`` and via an explicit ``close()``."""

    # Dotted path of the availability check disabled in these tests.
    _CHECK_TARGET = "recall.embedding.mlx_provider._check_mlx_available"

    @pytest.mark.asyncio
    async def test_context_manager_closes_provider(self, mock_mlx_module):
        """Leaving the ``async with`` block clears the model and tokenizer."""
        mock_mlx_module.load.return_value = (MagicMock(), MagicMock())
        mock_mlx_module.generate.return_value = create_mlx_output([[0.1]])

        with patch(self._CHECK_TARGET):
            subject = MLXProvider()

            async with subject:
                await subject.embed("test")
                assert subject._model_instance is not None

            # Once the block has exited, both references must be gone.
            assert subject._model_instance is None
            assert subject._tokenizer is None

    @pytest.mark.asyncio
    async def test_manual_close(self, mock_mlx_module):
        """Awaiting close() clears the model and tokenizer set by embed()."""
        mock_mlx_module.load.return_value = (MagicMock(), MagicMock())
        mock_mlx_module.generate.return_value = create_mlx_output([[0.1]])

        with patch(self._CHECK_TARGET):
            subject = MLXProvider()

            await subject.embed("test")
            assert subject._model_instance is not None

            await subject.close()
            assert subject._model_instance is None
            assert subject._tokenizer is None
class TestLazyLoading:
    """Tests for lazy model loading behavior."""

    def test_model_not_loaded_on_init(self):
        """Constructing a provider leaves the model and tokenizer unset.

        This is a plain synchronous test: it awaits nothing, so it does not
        need the asyncio marker or an event loop.
        """
        provider = MLXProvider()

        assert provider._model_instance is None
        assert provider._tokenizer is None

    @pytest.mark.asyncio
    async def test_model_loaded_on_first_embed(self, mock_mlx_module):
        """The first embed() call loads the model and tokenizer exactly once."""
        mock_model = MagicMock()
        mock_tokenizer = MagicMock()
        mock_mlx_module.load.return_value = (mock_model, mock_tokenizer)
        mock_mlx_module.generate.return_value = create_mlx_output([[0.1]])

        with patch("recall.embedding.mlx_provider._check_mlx_available"):
            provider = MLXProvider()
            assert provider._model_instance is None

            await provider.embed("test")

            # The provider must hold exactly the objects returned by load().
            assert provider._model_instance is mock_model
            assert provider._tokenizer is mock_tokenizer
            mock_mlx_module.load.assert_called_once()

    @pytest.mark.asyncio
    async def test_model_not_reloaded_on_second_embed(self, mock_mlx_module):
        """A second embed() call does not trigger another load()."""
        mock_model = MagicMock()
        mock_tokenizer = MagicMock()
        mock_mlx_module.load.return_value = (mock_model, mock_tokenizer)
        mock_mlx_module.generate.return_value = create_mlx_output([[0.1]])

        with patch("recall.embedding.mlx_provider._check_mlx_available"):
            provider = MLXProvider()

            await provider.embed("test1")
            await provider.embed("test2")

            mock_mlx_module.load.assert_called_once()
class TestModuleExports:
    """The ``recall.embedding`` package re-exports the MLX provider names."""

    def test_mlx_provider_exported(self):
        """The package-level MLXProvider is the same class as the module-level one."""
        from recall.embedding import MLXProvider as package_level_provider

        assert package_level_provider is MLXProvider

    def test_mlx_not_available_error_exported(self):
        """The package-level MLXNotAvailableError is the same class as the module-level one."""
        from recall.embedding import MLXNotAvailableError as package_level_error

        assert package_level_error is MLXNotAvailableError