"""Unit tests for embedding provider factory function.
Validates that:
1. create_embedding_provider creates correct provider types
2. Parameters are passed through correctly
3. ImportError is raised with helpful message for missing mlx-embeddings
4. ValueError is raised for unknown backends
"""
import sys
from unittest.mock import MagicMock, patch
import pytest
from recall.config import EmbeddingBackend
from recall.embedding.factory import create_embedding_provider
from recall.embedding.ollama import OllamaProvider
from recall.embedding.provider import EmbeddingProvider
class TestCreateOllamaProvider:
    """Factory tests for the ``"ollama"`` backend.

    Every test builds a provider through ``create_embedding_provider`` and
    inspects the ``host``, ``model`` and ``timeout`` attributes it ends up with.
    """

    @staticmethod
    def _make(**overrides):
        """Build an Ollama provider via the factory, forwarding ``overrides``."""
        return create_embedding_provider("ollama", **overrides)

    def test_create_ollama_provider_default(self):
        """Omitting every option yields the expected default settings."""
        embedder = self._make()
        # The factory must hand back the concrete class, which also satisfies
        # the EmbeddingProvider protocol.
        assert isinstance(embedder, OllamaProvider)
        assert isinstance(embedder, EmbeddingProvider)
        assert (embedder.host, embedder.model, embedder.timeout) == (
            "http://localhost:11434",
            "mxbai-embed-large",
            30.0,
        )

    def test_create_ollama_provider_custom_host(self):
        """A ``host`` keyword is forwarded to the provider."""
        assert self._make(host="http://custom:11434").host == "http://custom:11434"

    def test_create_ollama_provider_custom_model(self):
        """A ``model`` keyword is forwarded to the provider."""
        assert self._make(model="nomic-embed-text").model == "nomic-embed-text"

    def test_create_ollama_provider_custom_timeout(self):
        """A ``timeout`` keyword is forwarded to the provider."""
        assert self._make(timeout=60.0).timeout == 60.0

    def test_create_ollama_provider_all_custom(self):
        """All three keywords can be overridden together."""
        settings = {
            "host": "http://remote:8080",
            "model": "custom-embed",
            "timeout": 120.0,
        }
        embedder = self._make(**settings)
        for attr, expected in settings.items():
            assert getattr(embedder, attr) == expected
class TestCreateMLXProvider:
    """Tests for creating MLXProvider via factory.

    The factory imports ``recall.embedding.mlx_provider`` lazily. These tests
    substitute that module in ``sys.modules`` instead of importing the real
    one, whose dependency (mlx-embeddings) may not be installed here.
    """

    @staticmethod
    def _fake_mlx_module():
        """Return ``(module, provider_class, instance)`` stand-ins.

        ``module.MLXProvider`` is ``provider_class``; calling
        ``provider_class(...)`` returns ``instance``.
        """
        instance = MagicMock()
        provider_class = MagicMock(return_value=instance)
        module = MagicMock()
        module.MLXProvider = provider_class
        return module, provider_class, instance

    def test_create_mlx_provider_default(self):
        """Test creating MLXProvider with default settings."""
        module, provider_class, instance = self._fake_mlx_module()
        with patch.dict("sys.modules", {"recall.embedding.mlx_provider": module}):
            provider = create_embedding_provider("mlx")
        provider_class.assert_called_once_with(
            model="mlx-community/mxbai-embed-large-v1"
        )
        assert provider is instance

    def test_create_mlx_provider_custom_model(self):
        """Test creating MLXProvider with custom mlx_model."""
        module, provider_class, instance = self._fake_mlx_module()
        with patch.dict("sys.modules", {"recall.embedding.mlx_provider": module}):
            provider = create_embedding_provider(
                "mlx",
                mlx_model="custom/mlx-model",
            )
        provider_class.assert_called_once_with(model="custom/mlx-model")
        assert provider is instance

    def test_create_mlx_provider_raises_import_error(self):
        """Test that a helpful ImportError is raised when mlx-embeddings is missing.

        A ``None`` entry in ``sys.modules`` makes every import of that name
        raise ``ModuleNotFoundError`` (an ``ImportError`` subclass). This
        covers absolute imports, relative imports (``from .mlx_provider
        import ...``) and ``importlib.import_module``. Patching
        ``builtins.__import__`` only matched the exact absolute name string.
        ``patch.dict`` also restores ``sys.modules`` on exit, so no
        re-imported module leaks into later tests.
        """
        with patch.dict(sys.modules, {"recall.embedding.mlx_provider": None}):
            with pytest.raises(ImportError) as exc_info:
                create_embedding_provider("mlx")
        error_message = str(exc_info.value)
        assert "mlx-embeddings" in error_message
        assert "pip install mlx-embeddings" in error_message
        assert "Apple Silicon" in error_message
class TestUnknownBackend:
    """Factory behaviour when given a backend name it does not recognise."""

    def test_unknown_backend_raises_value_error(self):
        """An unrecognised backend string is rejected with ValueError."""
        # The argument is deliberately outside EmbeddingBackend, hence the ignore.
        with pytest.raises(ValueError) as excinfo:
            create_embedding_provider("unknown")  # type: ignore
        excinfo.match("Unknown embedding backend")

    def test_unknown_backend_error_message_helpful(self):
        """The ValueError message names each supported backend."""
        with pytest.raises(ValueError) as excinfo:
            create_embedding_provider("invalid")  # type: ignore
        message = str(excinfo.value)
        for supported in ("ollama", "mlx"):
            assert supported in message
class TestProviderConformsToProtocol:
    """The factory's return value should be usable as an EmbeddingProvider."""

    def test_ollama_provider_conforms(self):
        """The real Ollama provider passes an isinstance check on the protocol."""
        assert isinstance(create_embedding_provider("ollama"), EmbeddingProvider)

    def test_mlx_provider_conforms(self):
        """The factory returns exactly what the (stubbed) MLXProvider builds.

        The stand-in instance is spec'd against EmbeddingProvider, so reading
        an attribute the protocol class does not define raises AttributeError.
        NOTE(review): only object identity is asserted here.
        """
        spec_instance = MagicMock(spec=EmbeddingProvider)
        fake_module = MagicMock(
            MLXProvider=MagicMock(return_value=spec_instance)
        )
        with patch.dict("sys.modules", {"recall.embedding.mlx_provider": fake_module}):
            result = create_embedding_provider("mlx")
        assert result is spec_instance
class TestFactoryExport:
    """The factory must be re-exported by the ``recall.embedding`` package."""

    def test_factory_exported_from_package(self):
        """The package attribute is the very same function object."""
        import recall.embedding as embedding_pkg

        assert embedding_pkg.create_embedding_provider is create_embedding_provider

    def test_factory_in_all(self):
        """The factory's name is listed in the package's public ``__all__``."""
        import recall.embedding as embedding_pkg

        public_names = embedding_pkg.__all__
        assert "create_embedding_provider" in public_names
class TestBackendTypeHint:
    """Values annotated as EmbeddingBackend are accepted by the factory."""

    def test_accepts_literal_ollama(self):
        """An ``EmbeddingBackend``-typed ``"ollama"`` produces an OllamaProvider."""
        chosen: EmbeddingBackend = "ollama"
        assert isinstance(create_embedding_provider(chosen), OllamaProvider)

    def test_accepts_literal_mlx(self):
        """An ``EmbeddingBackend``-typed ``"mlx"`` routes to the stubbed MLXProvider."""
        built = MagicMock()
        stub_module = MagicMock()
        stub_module.MLXProvider = MagicMock(return_value=built)
        chosen: EmbeddingBackend = "mlx"
        with patch.dict("sys.modules", {"recall.embedding.mlx_provider": stub_module}):
            outcome = create_embedding_provider(chosen)
        assert outcome is built