Skip to main content
Glama
juanqui
by juanqui
test_reranker_deepinfra.py (12.4 kB)
"""Tests for DeepInfra reranker service.""" from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest from pdfkb.config import ServerConfig from pdfkb.exceptions import EmbeddingError from pdfkb.reranker_deepinfra import DeepInfraRerankerService @pytest.fixture def test_config(): """Create test configuration for DeepInfra reranker.""" config = ServerConfig( knowledgebase_path=Path("/tmp/test"), openai_api_key="sk-local-embeddings-dummy-key", enable_reranker=True, reranker_provider="deepinfra", deepinfra_api_key="test-deepinfra-key", ) return config @pytest.fixture def deepinfra_reranker(test_config): """Create DeepInfra reranker service instance.""" return DeepInfraRerankerService(test_config) class TestDeepInfraRerankerService: """Test DeepInfra reranker service.""" async def test_initialize(self, deepinfra_reranker): """Test service initialization.""" await deepinfra_reranker.initialize() assert deepinfra_reranker._initialized assert deepinfra_reranker.api_key == "test-deepinfra-key" async def test_initialize_missing_api_key(self): """Test initialization fails without API key.""" config = ServerConfig( knowledgebase_path=Path("/tmp/test"), openai_api_key="sk-local-embeddings-dummy-key", deepinfra_api_key="", # Empty key enable_reranker=True, reranker_provider="deepinfra", ) reranker = DeepInfraRerankerService(config) with pytest.raises(EmbeddingError) as exc_info: await reranker.initialize() assert "DeepInfra API key required" in str(exc_info.value) async def test_initialize_with_dummy_key(self): """Test initialization fails with dummy key.""" config = ServerConfig( knowledgebase_path=Path("/tmp/test"), openai_api_key="sk-local-embeddings-dummy-key", deepinfra_api_key="sk-local-embeddings-dummy-key", # Dummy key enable_reranker=True, reranker_provider="deepinfra", ) reranker = DeepInfraRerankerService(config) with pytest.raises(EmbeddingError) as exc_info: await reranker.initialize() assert "DeepInfra API key required" in 
str(exc_info.value) @patch("aiohttp.ClientSession") async def test_rerank(self, mock_session_class, deepinfra_reranker): """Test document reranking.""" # Mock the API response mock_response = MagicMock() mock_response.status = 200 mock_response.json = AsyncMock( return_value={ "scores": [0.9, 0.3, 0.7], "input_tokens": 42, "inference_status": { "status": "success", "runtime_ms": 100, "cost": 0.001, }, } ) mock_session = MagicMock() mock_session.post.return_value.__aenter__.return_value = mock_response mock_session.__aenter__.return_value = mock_session mock_session.__aexit__.return_value = None mock_session_class.return_value = mock_session # Test reranking documents = [ "Machine learning is a subset of AI", "Database optimization techniques", "Neural networks and deep learning", ] results = await deepinfra_reranker.rerank("machine learning", documents) # Check results are sorted by score assert len(results) == 3 assert results[0] == (0, 0.9) # Highest score assert results[1] == (2, 0.7) assert results[2] == (1, 0.3) # Lowest score # Verify API was called correctly mock_session.post.assert_called_once() call_args = mock_session.post.call_args assert call_args[0][0] == deepinfra_reranker.model_endpoint # Check that query was duplicated for each document assert call_args[1]["json"]["queries"] == ["machine learning"] * 3 assert call_args[1]["json"]["documents"] == documents @patch("aiohttp.ClientSession") async def test_rerank_api_error(self, mock_session_class, deepinfra_reranker): """Test fallback behavior on API error.""" mock_response = MagicMock() mock_response.status = 500 mock_response.text = AsyncMock(return_value="Internal Server Error") mock_session = MagicMock() mock_session.post.return_value.__aenter__.return_value = mock_response mock_session.__aenter__.return_value = mock_session mock_session.__aexit__.return_value = None mock_session_class.return_value = mock_session documents = ["Doc 1", "Doc 2", "Doc 3"] results = await deepinfra_reranker.rerank("test 
query", documents) # Should return fallback equal scores assert len(results) == 3 assert all(score == 1.0 for _, score in results) @patch("aiohttp.ClientSession") async def test_rerank_missing_scores(self, mock_session_class, deepinfra_reranker): """Test error handling when response missing scores.""" mock_response = MagicMock() mock_response.status = 200 mock_response.json = AsyncMock( return_value={ "input_tokens": 42, # Missing "scores" field } ) mock_session = MagicMock() mock_session.post.return_value.__aenter__.return_value = mock_response mock_session.__aenter__.return_value = mock_session mock_session.__aexit__.return_value = None mock_session_class.return_value = mock_session documents = ["Doc 1", "Doc 2", "Doc 3"] results = await deepinfra_reranker.rerank("test query", documents) # Should return fallback equal scores due to missing scores field assert len(results) == 3 assert all(score == 1.0 for _, score in results) @patch("aiohttp.ClientSession") async def test_rerank_with_token_logging(self, mock_session_class, deepinfra_reranker, caplog): """Test that token usage is logged when available.""" import logging caplog.set_level(logging.DEBUG) mock_response = MagicMock() mock_response.status = 200 mock_response.json = AsyncMock( return_value={ "scores": [0.5, 0.5], "input_tokens": 100, } ) mock_session = MagicMock() mock_session.post.return_value.__aenter__.return_value = mock_response mock_session.__aenter__.return_value = mock_session mock_session.__aexit__.return_value = None mock_session_class.return_value = mock_session await deepinfra_reranker.rerank("test", ["doc1", "doc2"]) # Check that token usage was logged assert "DeepInfra token usage: 100 input tokens" in caplog.text async def test_get_model_info(self, deepinfra_reranker): """Test getting model information.""" info = deepinfra_reranker.get_model_info() assert info["provider"] == "deepinfra" assert info["model"] == "Qwen/Qwen3-Reranker-8B" assert info["endpoint"] == 
deepinfra_reranker.model_endpoint assert "description" in info assert "capabilities" in info assert "available_models" in info assert len(info["available_models"]) == 3 @patch("aiohttp.ClientSession") async def test_test_connection_success(self, mock_session_class, deepinfra_reranker): """Test connection testing success.""" mock_response = MagicMock() mock_response.status = 200 mock_response.json = AsyncMock( return_value={ "scores": [0.5], "input_tokens": 10, } ) mock_session = MagicMock() mock_session.post.return_value.__aenter__.return_value = mock_response mock_session.__aenter__.return_value = mock_session mock_session.__aexit__.return_value = None mock_session_class.return_value = mock_session result = await deepinfra_reranker.test_connection() assert result is True @patch("pdfkb.reranker_deepinfra.DeepInfraRerankerService.rerank") async def test_test_connection_failure(self, mock_rerank, deepinfra_reranker): """Test connection testing failure.""" # Make rerank raise an exception mock_rerank.side_effect = Exception("Connection failed") result = await deepinfra_reranker.test_connection() assert result is False async def test_rerank_empty_documents(self, deepinfra_reranker): """Test reranking with empty document list.""" results = await deepinfra_reranker.rerank("test query", []) assert results == [] @patch("aiohttp.ClientSession") async def test_score_normalization(self, mock_session_class, deepinfra_reranker): """Test that scores are normalized to [0, 1] range.""" mock_response = MagicMock() mock_response.status = 200 mock_response.json = AsyncMock( return_value={ "scores": [1.5, -0.3, 0.7], # Out of range scores "input_tokens": 42, } ) mock_session = MagicMock() mock_session.post.return_value.__aenter__.return_value = mock_response mock_session.__aenter__.return_value = mock_session mock_session.__aexit__.return_value = None mock_session_class.return_value = mock_session documents = ["Doc 1", "Doc 2", "Doc 3"] results = await deepinfra_reranker.rerank("test 
query", documents) # Check that scores are clamped to [0, 1] assert len(results) == 3 assert results[0] == (0, 1.0) # 1.5 clamped to 1.0 assert results[1] == (2, 0.7) assert results[2] == (1, 0.0) # -0.3 clamped to 0.0 async def test_model_0_6b(self): """Test using the 0.6B model.""" config = ServerConfig( knowledgebase_path=Path("/tmp/test"), openai_api_key="sk-local-embeddings-dummy-key", enable_reranker=True, reranker_provider="deepinfra", deepinfra_api_key="test-deepinfra-key", deepinfra_reranker_model="Qwen/Qwen3-Reranker-0.6B", ) reranker = DeepInfraRerankerService(config) await reranker.initialize() assert reranker.model_name == "Qwen/Qwen3-Reranker-0.6B" assert reranker.model_endpoint == DeepInfraRerankerService.MODEL_ENDPOINTS["Qwen/Qwen3-Reranker-0.6B"] info = reranker.get_model_info() assert info["model"] == "Qwen/Qwen3-Reranker-0.6B" assert "0.6B" in info["description"] async def test_model_4b(self): """Test using the 4B model.""" config = ServerConfig( knowledgebase_path=Path("/tmp/test"), openai_api_key="sk-local-embeddings-dummy-key", enable_reranker=True, reranker_provider="deepinfra", deepinfra_api_key="test-deepinfra-key", deepinfra_reranker_model="Qwen/Qwen3-Reranker-4B", ) reranker = DeepInfraRerankerService(config) await reranker.initialize() assert reranker.model_name == "Qwen/Qwen3-Reranker-4B" assert reranker.model_endpoint == DeepInfraRerankerService.MODEL_ENDPOINTS["Qwen/Qwen3-Reranker-4B"] info = reranker.get_model_info() assert info["model"] == "Qwen/Qwen3-Reranker-4B" assert "4B" in info["description"] async def test_invalid_model_fallback(self): """Test that invalid model falls back to default.""" config = ServerConfig( knowledgebase_path=Path("/tmp/test"), openai_api_key="sk-local-embeddings-dummy-key", enable_reranker=True, reranker_provider="deepinfra", deepinfra_api_key="test-deepinfra-key", deepinfra_reranker_model="Qwen/Invalid-Model", ) reranker = DeepInfraRerankerService(config) # Should fall back to default model assert 
reranker.model_name == DeepInfraRerankerService.DEFAULT_MODEL assert ( reranker.model_endpoint == DeepInfraRerankerService.MODEL_ENDPOINTS[DeepInfraRerankerService.DEFAULT_MODEL] )

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/juanqui/pdfkb-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.