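"""Tests for the EmbeddingProvider Protocol.

Validates that:
1. The Protocol is runtime_checkable (isinstance works)
2. OllamaClient conforms to the Protocol
3. Classes implementing the interface are recognized as providers
4. Classes missing required methods are not recognized as providers
5. EmbeddingError is exported from the provider module and is the same
   class exposed by the ollama module
6. The Protocol's method signatures (parameter names, order, and defaults)
   match the expected interface
"""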
import inspect
from typing import List

import pytest

from recall.embedding.provider import EmbeddingError, EmbeddingProvider
from recall.embedding.ollama import OllamaClient
class TestEmbeddingProviderProtocol:
    """Test suite for EmbeddingProvider Protocol behavior.

    Covers the runtime ``isinstance`` check itself, the concrete
    ``OllamaClient``, a fully conforming ad-hoc class, and ad-hoc classes
    that each omit one required member.
    """

    def test_protocol_is_runtime_checkable(self):
        """Verify Protocol can be used with isinstance().

        Probing for ``_is_protocol`` / ``__protocol_attrs__`` does not prove
        ``@runtime_checkable``: ``_is_protocol`` is set on every
        ``typing.Protocol`` class (and, on Python 3.12+, so is
        ``__protocol_attrs__``), decorated or not. An actual ``isinstance``
        call does prove it, because it raises ``TypeError`` for a Protocol
        that is not runtime-checkable.
        """
        # A bare object() has none of the protocol's methods (e.g. embed), so
        # the check must complete without raising and report non-conformance.
        assert not isinstance(object(), EmbeddingProvider)

    def test_ollama_client_conforms_to_protocol(self):
        """Verify OllamaClient is recognized as an EmbeddingProvider."""
        client = OllamaClient()
        assert isinstance(client, EmbeddingProvider)

    def test_conforming_class_recognized(self):
        """Verify a class implementing all methods is recognized."""

        class ConformingProvider:
            """A provider that implements all required methods."""

            async def embed(self, text: str, is_query: bool = False) -> List[float]:
                return [0.0] * 768

            async def embed_batch(
                self,
                texts: List[str],
                is_query: bool = False,
                batch_size: int = 32,
            ) -> List[List[float]]:
                return [[0.0] * 768 for _ in texts]

            async def close(self) -> None:
                pass

            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                await self.close()

        provider = ConformingProvider()
        assert isinstance(provider, EmbeddingProvider)

    def test_non_conforming_class_missing_embed(self):
        """Verify a class missing embed() is not recognized."""

        class MissingEmbed:
            async def embed_batch(
                self, texts: List[str], is_query: bool = False, batch_size: int = 32
            ) -> List[List[float]]:
                return []

            async def close(self) -> None:
                pass

            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                pass

        provider = MissingEmbed()
        assert not isinstance(provider, EmbeddingProvider)

    def test_non_conforming_class_missing_embed_batch(self):
        """Verify a class missing embed_batch() is not recognized."""

        class MissingEmbedBatch:
            async def embed(self, text: str, is_query: bool = False) -> List[float]:
                return []

            async def close(self) -> None:
                pass

            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                pass

        provider = MissingEmbedBatch()
        assert not isinstance(provider, EmbeddingProvider)

    def test_non_conforming_class_missing_close(self):
        """Verify a class missing close() is not recognized."""

        class MissingClose:
            async def embed(self, text: str, is_query: bool = False) -> List[float]:
                return []

            async def embed_batch(
                self, texts: List[str], is_query: bool = False, batch_size: int = 32
            ) -> List[List[float]]:
                return []

            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                pass

        provider = MissingClose()
        assert not isinstance(provider, EmbeddingProvider)

    def test_non_conforming_class_missing_context_manager(self):
        """Verify a class missing __aenter__ is not recognized."""

        class MissingContextManager:
            async def embed(self, text: str, is_query: bool = False) -> List[float]:
                return []

            async def embed_batch(
                self, texts: List[str], is_query: bool = False, batch_size: int = 32
            ) -> List[List[float]]:
                return []

            async def close(self) -> None:
                pass

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                pass

        provider = MissingContextManager()
        assert not isinstance(provider, EmbeddingProvider)
class TestEmbeddingErrorExport:
    """Checks that the provider module re-exports a usable EmbeddingError."""

    def test_embedding_error_is_exception(self):
        """EmbeddingError must derive from the built-in Exception class."""
        expected_base = Exception
        assert issubclass(EmbeddingError, expected_base)

    def test_embedding_error_can_be_raised(self):
        """EmbeddingError can be raised, and pytest sees its message."""
        message = "test error"
        with pytest.raises(EmbeddingError, match=message):
            raise EmbeddingError(message)

    def test_embedding_error_same_as_ollama_module(self):
        """Both modules must expose the very same class object, not a copy."""
        from recall.embedding.ollama import EmbeddingError as OllamaEmbeddingError

        assert OllamaEmbeddingError is EmbeddingError
class TestProtocolMethodSignatures:
    """Verify Protocol methods have correct signatures.

    Beyond parameter presence, these tests pin the position of the primary
    argument: callers (and conforming implementations) rely on it being the
    first parameter after ``self``. A plain membership check would accept
    any ordering.
    """

    def test_embed_signature(self):
        """Verify embed method signature matches expected interface."""
        sig = inspect.signature(EmbeddingProvider.embed)
        params = list(sig.parameters)
        # ``text`` must directly follow ``self`` so ``embed("...")`` binds it.
        assert params[:2] == ["self", "text"]
        assert "is_query" in params
        # Check is_query has default value
        assert sig.parameters["is_query"].default is False

    def test_embed_batch_signature(self):
        """Verify embed_batch method signature matches expected interface."""
        sig = inspect.signature(EmbeddingProvider.embed_batch)
        params = list(sig.parameters)
        # ``texts`` must directly follow ``self`` so ``embed_batch([...])``
        # binds it.
        assert params[:2] == ["self", "texts"]
        assert "is_query" in params
        assert "batch_size" in params
        # Check defaults
        assert sig.parameters["is_query"].default is False
        assert sig.parameters["batch_size"].default == 32

    def test_close_signature(self):
        """Verify close method signature matches expected interface."""
        sig = inspect.signature(EmbeddingProvider.close)
        params = list(sig.parameters)
        # Should only have self
        assert params == ["self"]