"""Unit tests for OllamaClient retry logic.
Tests verify that transient errors (429, 502, 503) are retried with exponential backoff,
and that non-retryable errors and exhausted retries raise EmbeddingError.
"""
from unittest.mock import AsyncMock, patch
import httpx
import pytest
from recall.embedding.ollama import EmbeddingError, OllamaClient
def make_response(status_code: int, json_data: dict) -> httpx.Response:
    """Build an ``httpx.Response`` carrying a JSON body for use in tests.

    Args:
        status_code: HTTP status code the fake response should report.
        json_data: Payload that is serialized as the response's JSON body.

    Returns:
        A response bound to a POST request for the local embed endpoint.
    """
    # A request is attached to the response; presumably the client under test
    # inspects it or calls raise_for_status() -- confirm against OllamaClient.
    embed_url = "http://localhost:11434/api/embed"
    originating_request = httpx.Request("POST", embed_url)
    return httpx.Response(
        status_code,
        json=json_data,
        request=originating_request,
    )
class TestOllamaRetryLogic:
    """Tests for OllamaClient retry behavior on transient errors.

    Every test replaces the client's ``_get_client`` with a mock whose
    ``post`` replays a scripted sequence of HTTP responses, and patches
    ``asyncio.sleep`` as referenced from ``recall.embedding.ollama`` so that
    backoff delays never block the test run.
    """

    # Patch target for the sleep call expected between retry attempts.
    _SLEEP_TARGET = "recall.embedding.ollama.asyncio.sleep"

    @staticmethod
    def _scripted_post(script, *, repeat_last=False):
        """Build an async ``post`` stand-in that replays *script* in order.

        Args:
            script: Sequence of ``(status_code, json_data)`` pairs, one per
                expected request. Each response is built on demand with
                ``make_response`` so no response object is shared.
            repeat_last: When true, keep answering with the final entry once
                the script is exhausted (for "always fails" scenarios).

        Returns:
            A ``(post, calls)`` tuple. ``calls`` gains one ``(args, kwargs)``
            entry per invocation, so ``len(calls)`` is the request count.

        Raises:
            AssertionError: Raised by ``post`` when it is invoked more often
                than the script allows and ``repeat_last`` is false.
        """
        calls = []

        async def post(*args, **kwargs):
            index = len(calls)
            calls.append((args, kwargs))
            if index >= len(script):
                if not repeat_last:
                    # Fail loudly instead of with a bare IndexError when the
                    # client issues more requests than the test anticipated.
                    raise AssertionError(
                        f"unexpected request #{index + 1}; "
                        f"only {len(script)} response(s) were scripted"
                    )
                index = len(script) - 1
            status_code, json_data = script[index]
            return make_response(status_code, json_data)

        return post, calls

    @staticmethod
    def _use_post(mock_get_client, post):
        """Make the patched ``_get_client`` return a mock client using *post*."""
        mock_client = AsyncMock()
        mock_client.post = post
        mock_get_client.return_value = mock_client

    @pytest.mark.asyncio
    async def test_retry_on_503_service_unavailable(self):
        """Test that 503 errors trigger retry with eventual success."""
        client = OllamaClient()
        # Two 503 failures, then a successful embedding response.
        post, calls = self._scripted_post(
            [
                (503, {"error": "server busy"}),
                (503, {"error": "server busy"}),
                (200, {"embeddings": [[0.1, 0.2, 0.3]]}),
            ]
        )

        with patch.object(client, "_get_client") as mock_get_client:
            self._use_post(mock_get_client, post)
            # Should succeed after retries
            with patch(self._SLEEP_TARGET, new_callable=AsyncMock):
                result = await client.embed("test text")

        assert result == [0.1, 0.2, 0.3]
        assert len(calls) == 3  # 2 failures + 1 success

    @pytest.mark.asyncio
    async def test_retry_on_429_too_many_requests(self):
        """Test that 429 errors trigger retry with eventual success."""
        client = OllamaClient()
        post, calls = self._scripted_post(
            [
                (429, {"error": "rate limited"}),
                (200, {"embeddings": [[0.4, 0.5, 0.6]]}),
            ]
        )

        with patch.object(client, "_get_client") as mock_get_client:
            self._use_post(mock_get_client, post)
            with patch(self._SLEEP_TARGET, new_callable=AsyncMock):
                result = await client.embed("test text")

        assert result == [0.4, 0.5, 0.6]
        assert len(calls) == 2

    @pytest.mark.asyncio
    async def test_max_retries_exceeded_raises_error(self):
        """Test that exceeding max retries raises EmbeddingError."""
        client = OllamaClient()
        # Always return 503
        post, _calls = self._scripted_post(
            [(503, {"error": "server busy"})],
            repeat_last=True,
        )

        with patch.object(client, "_get_client") as mock_get_client:
            self._use_post(mock_get_client, post)
            with patch(self._SLEEP_TARGET, new_callable=AsyncMock):
                with pytest.raises(EmbeddingError) as exc_info:
                    await client.embed("test text")

        # Checked after the raises block: statements placed after the raising
        # await inside it would never execute.
        message = str(exc_info.value)
        assert "503" in message
        assert "3 attempts" in message

    @pytest.mark.asyncio
    async def test_non_retryable_error_fails_immediately(self):
        """Test that 4xx errors (except 429) fail immediately without retry."""
        client = OllamaClient()
        post, calls = self._scripted_post(
            [(400, {"error": "bad request"})],
            repeat_last=True,
        )

        with patch.object(client, "_get_client") as mock_get_client:
            self._use_post(mock_get_client, post)
            # Sleep is patched here as well so that a regression which starts
            # retrying 400s fails fast instead of really waiting on backoff.
            with patch(self._SLEEP_TARGET, new_callable=AsyncMock) as mock_sleep:
                with pytest.raises(EmbeddingError) as exc_info:
                    await client.embed("test text")

        assert len(calls) == 1  # No retries for 400
        assert "400" in str(exc_info.value)
        mock_sleep.assert_not_awaited()  # No backoff delay either

    @pytest.mark.asyncio
    async def test_exponential_backoff_delays(self):
        """Test that retry delays follow exponential backoff pattern."""
        client = OllamaClient()
        delays = []
        post, _calls = self._scripted_post(
            [
                (503, {"error": "busy"}),
                (503, {"error": "busy"}),
                (200, {"embeddings": [[0.1]]}),
            ]
        )

        async def record_sleep(delay):
            delays.append(delay)

        with patch.object(client, "_get_client") as mock_get_client:
            self._use_post(mock_get_client, post)
            with patch(self._SLEEP_TARGET, side_effect=record_sleep):
                await client.embed("test text")

        # Should have exponential delays: base_delay * 2^0, base_delay * 2^1
        assert delays == [1.0, 2.0]

    @pytest.mark.asyncio
    async def test_embed_batch_retries_on_503(self):
        """Test that embed_batch also retries on 503."""
        client = OllamaClient()
        post, calls = self._scripted_post(
            [
                (503, {"error": "busy"}),
                (200, {"embeddings": [[0.1], [0.2]]}),
            ]
        )

        with patch.object(client, "_get_client") as mock_get_client:
            self._use_post(mock_get_client, post)
            with patch(self._SLEEP_TARGET, new_callable=AsyncMock):
                result = await client.embed_batch(["text1", "text2"])

        assert result == [[0.1], [0.2]]
        assert len(calls) == 2

    @pytest.mark.asyncio
    async def test_retry_on_502_bad_gateway(self):
        """Test that 502 errors also trigger retry."""
        client = OllamaClient()
        post, calls = self._scripted_post(
            [
                (502, {"error": "bad gateway"}),
                (200, {"embeddings": [[0.7, 0.8]]}),
            ]
        )

        with patch.object(client, "_get_client") as mock_get_client:
            self._use_post(mock_get_client, post)
            with patch(self._SLEEP_TARGET, new_callable=AsyncMock):
                result = await client.embed("test")

        assert result == [0.7, 0.8]
        assert len(calls) == 2