test_local_embeddings.py
"""Tests for local embedding service.""" from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest from pdfkb.config import ServerConfig from pdfkb.embeddings_factory import create_embedding_service from pdfkb.embeddings_local import LocalEmbeddingService, LRUCache from pdfkb.exceptions import ConfigurationError class TestLRUCache: """Test cases for LRU cache.""" def test_cache_put_and_get(self): """Test basic cache operations.""" cache = LRUCache(maxsize=3) # Add items cache.put("key1", [1.0, 2.0]) cache.put("key2", [3.0, 4.0]) # Get items assert cache.get("key1") == [1.0, 2.0] assert cache.get("key2") == [3.0, 4.0] assert cache.get("key3") is None def test_cache_lru_eviction(self): """Test LRU eviction when cache is full.""" cache = LRUCache(maxsize=2) # Add items cache.put("key1", [1.0]) cache.put("key2", [2.0]) cache.put("key3", [3.0]) # Should evict key1 # key1 should be evicted assert cache.get("key1") is None assert cache.get("key2") == [2.0] assert cache.get("key3") == [3.0] def test_cache_move_to_end(self): """Test that accessing an item moves it to end (most recently used).""" cache = LRUCache(maxsize=3) # Add items cache.put("key1", [1.0]) cache.put("key2", [2.0]) cache.put("key3", [3.0]) # Access key1 to move it to end _ = cache.get("key1") # Add new item - should evict key2 (least recently used) cache.put("key4", [4.0]) assert cache.get("key1") == [1.0] # Still present assert cache.get("key2") is None # Evicted assert cache.get("key3") == [3.0] assert cache.get("key4") == [4.0] def test_cache_clear(self): """Test cache clearing.""" cache = LRUCache(maxsize=3) cache.put("key1", [1.0]) cache.put("key2", [2.0]) cache.clear() assert cache.get("key1") is None assert cache.get("key2") is None class TestLocalEmbeddingService: """Test cases for local embedding service.""" @pytest.fixture def config(self): """Create test configuration.""" return ServerConfig( openai_api_key="sk-test-key", embedding_provider="local", local_embedding_model="Qwen/Qwen3-Embedding-0.6B", local_embedding_batch_size=2, embedding_device="cpu", embedding_cache_size=100, max_sequence_length=512, use_model_optimization=False, model_cache_dir="/tmp/test-models", ) def test_service_creation(self, config): """Test service creation.""" service = LocalEmbeddingService(config) assert service.model_name == "Qwen/Qwen3-Embedding-0.6B" assert service.batch_size == 2 assert not service._initialized def test_device_selection(self, config): """Test device selection logic.""" service = LocalEmbeddingService(config) # Device selection happens during initialization assert service.device is None # Not yet initialized # Test CPU preference is respected assert service._select_device("cpu") == "cpu" # Test that auto-detect returns at least CPU result = service._select_device(None) assert result in ["cpu", "mps", "cuda"] @pytest.mark.asyncio async def test_initialize(self, config): """Test service initialization.""" service = LocalEmbeddingService(config) # Mock the transformers imports that happen inside initialize with patch("transformers.AutoModel") as mock_model_cls: with patch("transformers.AutoTokenizer") as mock_tokenizer_cls: # Mock model and tokenizer mock_model = MagicMock() mock_tokenizer = MagicMock() mock_model_cls.from_pretrained.return_value = mock_model mock_tokenizer_cls.from_pretrained.return_value = mock_tokenizer mock_model.to.return_value = mock_model mock_model.eval.return_value = mock_model await service.initialize() assert service._initialized assert service.model is not None assert 
service.tokenizer is not None @pytest.mark.asyncio async def test_generate_embeddings_with_cache(self, config): """Test embedding generation with caching.""" service = LocalEmbeddingService(config) # Mock the synchronous batch generation mock_embeddings = [[0.1, 0.2], [0.3, 0.4]] service._generate_batch_sync = Mock(return_value=mock_embeddings) service._initialized = True # Generate embeddings texts = ["text1", "text2"] embeddings = await service.generate_embeddings(texts) assert embeddings == mock_embeddings service._generate_batch_sync.assert_called_once() # Second call should use cache embeddings2 = await service.generate_embeddings(texts) assert embeddings2 == mock_embeddings @pytest.mark.asyncio async def test_generate_single_embedding(self, config): """Test single embedding generation.""" service = LocalEmbeddingService(config) # Mock batch generation mock_embedding = [0.1, 0.2, 0.3] service.generate_embeddings = AsyncMock(return_value=[mock_embedding]) # Generate single embedding embedding = await service.generate_embedding("test text") assert embedding == mock_embedding service.generate_embeddings.assert_called_once_with(["test text"]) def test_get_embedding_dimension(self, config): """Test getting embedding dimension.""" service = LocalEmbeddingService(config) # Test known model service.model_name = "Qwen/Qwen3-Embedding-0.6B" assert service.get_embedding_dimension() == 1024 # Test with model config - need to temporarily change model name to unknown service.model_name = "unknown-model" mock_model = MagicMock() mock_model.config.hidden_size = 768 service.model = mock_model assert service.get_embedding_dimension() == 768 # Test fallback service.model = None service.model_name = "unknown-model" assert service.get_embedding_dimension() == 768 # Default @pytest.mark.asyncio async def test_test_connection(self, config): """Test connection testing.""" service = LocalEmbeddingService(config) # Mock successful embedding generation service.generate_embedding = AsyncMock(return_value=[0.1, 0.2]) result = await service.test_connection() assert result is True # Mock failed embedding generation service.generate_embedding = AsyncMock(side_effect=Exception("Test error")) service._initialized = True # Skip initialization result = await service.test_connection() assert result is False def test_get_model_info(self, config): """Test getting model information.""" service = LocalEmbeddingService(config) service.device = "mps" info = service.get_model_info() assert info["provider"] == "local" assert info["model"] == "Qwen/Qwen3-Embedding-0.6B" assert info["dimension"] == 1024 assert info["max_sequence_length"] == 32000 assert info["batch_size"] == 2 assert info["device"] == "mps" assert info["cache_size"] == 100 @pytest.mark.asyncio async def test_oom_handling(self, config): """Test out-of-memory error handling.""" service = LocalEmbeddingService(config) service._initialized = True # Create a mock that fails on first large batch, then succeeds on smaller batches call_count = [0] batch_sizes_received = [] def mock_batch_sync(texts): call_count[0] += 1 batch_sizes_received.append(len(texts)) if call_count[0] == 1 and len(texts) > 1: raise RuntimeError("CUDA out of memory") return [[0.1] * 384 for _ in texts] service._generate_batch_sync = mock_batch_sync # Should retry with smaller batch texts = ["text1", "text2", "text3", "text4"] embeddings = await service.generate_embeddings(texts) # Due to the batch reduction, we should get all embeddings but in smaller batches assert len(embeddings) >= 2 # At least 
some embeddings returned assert call_count[0] > 1 # Should have retried assert 1 in batch_sizes_received # Should have reduced to batch size 1 class TestEmbeddingFactory: """Test cases for embedding factory.""" @pytest.fixture def local_config(self): """Create local embedding configuration.""" return ServerConfig( openai_api_key="sk-test-key", embedding_provider="local", local_embedding_model="Qwen/Qwen3-Embedding-0.6B", ) @pytest.fixture def openai_config(self): """Create OpenAI embedding configuration.""" return ServerConfig( openai_api_key="sk-test-key", embedding_provider="openai", ) def test_create_local_service(self, local_config): """Test creating local embedding service.""" with patch("pdfkb.embeddings_factory.LocalEmbeddingService") as mock_local: mock_instance = Mock() mock_local.return_value = mock_instance service = create_embedding_service(local_config) assert service == mock_instance mock_local.assert_called_once_with(local_config) def test_create_openai_service(self, openai_config): """Test creating OpenAI embedding service.""" from pdfkb.embeddings_openai import OpenAIEmbeddingService service = create_embedding_service(openai_config) assert isinstance(service, OpenAIEmbeddingService) def test_fallback_to_openai(self, local_config): """Test fallback to OpenAI when local fails.""" local_config.fallback_to_openai = True with patch("pdfkb.embeddings_factory.LocalEmbeddingService") as mock_local: mock_local.side_effect = Exception("Local init failed") with patch("pdfkb.embeddings_factory.OpenAIEmbeddingService") as mock_openai: mock_instance = Mock() mock_openai.return_value = mock_instance service = create_embedding_service(local_config, fallback=True) assert service == mock_instance mock_openai.assert_called_once() def test_no_fallback_raises_error(self, local_config): """Test that error is raised when fallback is disabled.""" local_config.fallback_to_openai = False with patch("pdfkb.embeddings_factory.LocalEmbeddingService") as mock_local: mock_local.side_effect = Exception("Local init failed") with pytest.raises(ConfigurationError): create_embedding_service(local_config, fallback=True) def test_invalid_provider(self): """Test invalid provider raises error.""" # ConfigurationError should be raised during config validation, # not during factory creation with pytest.raises(ConfigurationError) as exc_info: ServerConfig( openai_api_key="sk-test-key", embedding_provider="invalid", ) assert "Invalid embedding_provider" in str(exc_info.value)
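
For context, here is a minimal sketch of an LRU cache that would satisfy the contract TestLRUCache asserts: get() returns None on a miss and refreshes recency on a hit, put() evicts the least recently used entry once maxsize is exceeded, and clear() empties the cache. This is an illustrative stand-in built on collections.OrderedDict, not the actual pdfkb.embeddings_local.LRUCache implementation.

from collections import OrderedDict
from typing import List, Optional


class LRUCacheSketch:
    """Illustrative LRU cache mirroring the behavior TestLRUCache asserts."""

    def __init__(self, maxsize: int):
        self.maxsize = maxsize
        self._data = OrderedDict()

    def get(self, key: str) -> Optional[List[float]]:
        """Return the cached value, or None on a miss; a hit becomes most recent."""
        if key not in self._data:
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def put(self, key: str, value: List[float]) -> None:
        """Insert or refresh a value, evicting the least recently used if over capacity."""
        if key in self._data:
            self._data.move_to_end(key)
        self._data[key] = value
        if len(self._data) > self.maxsize:
            self._data.popitem(last=False)  # drop the least recently used entry

    def clear(self) -> None:
        self._data.clear()

Similarly, test_oom_handling asserts that a "CUDA out of memory" RuntimeError triggers retries with progressively smaller batches, down to single texts, rather than failing outright. Below is a hedged sketch of such a retry loop, assuming an injected batch-embedding callable; the function and parameter names are illustrative, not pdfkb's actual API.

from typing import Callable, List


def embed_with_oom_retry(
    texts: List[str],
    generate_batch: Callable[[List[str]], List[List[float]]],
    batch_size: int,
) -> List[List[float]]:
    """Embed texts in batches, halving the batch size whenever a CUDA OOM is raised."""
    embeddings: List[List[float]] = []
    i = 0
    while i < len(texts):
        size = batch_size
        while True:
            batch = texts[i : i + size]
            try:
                embeddings.extend(generate_batch(batch))
                break
            except RuntimeError as exc:
                # Only shrink and retry on OOM; any other error propagates.
                if "out of memory" not in str(exc) or size == 1:
                    raise
                size = max(1, size // 2)  # halve the batch and retry
        i += len(batch)
    return embeddings

Run against the mock from test_oom_handling (four texts, batch size 2, first multi-text call raising OOM), this loop produces all four embeddings and records a batch of size 1, matching the assertions in that test.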
