# test_llm_client.py
import pytest
import httpx
from unittest.mock import Mock, patch
from llm.client import LLMClient, LLMTimeoutError, LLMHttpError, LLMNetworkError
class TestLLMClient:
"""Test LLM client functionality"""
@pytest.fixture
def client(self):
"""Create LLM client for testing"""
return LLMClient(
endpoint="https://api.test.com/v1/chat/completions",
api_key="test-key",
model="gpt-4",
timeout=30.0,
max_retries=2,
)
@pytest.mark.asyncio
async def test_successful_request(self, client):
"""Test successful API call"""
mock_response = {
"choices": [
{"message": {"content": '{"cleaned": "test response", "notes": []}'}}
]
}
with patch("httpx.AsyncClient") as mock_client:
mock_response_obj = Mock() # Use regular Mock for response object
mock_response_obj.json.return_value = (
mock_response # Make json() return the data
)
mock_response_obj.raise_for_status.return_value = None
mock_client.return_value.__aenter__.return_value.post.return_value = (
mock_response_obj
)
result = await client.chat_completions(
messages=[{"role": "user", "content": "test"}],
temperature=0.2,
max_tokens=100,
request_id="test-123",
)
assert result == '{"cleaned": "test response", "notes": []}'
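            # A hedged sketch of also checking how the HTTP client is built;
            # whether LLMClient forwards its timeout to httpx.AsyncClient or
            # per request is an implementation detail (constructor-level assumed):
            #     mock_client.assert_called_once_with(timeout=30.0)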
@pytest.mark.asyncio
async def test_timeout_error(self, client):
"""Test timeout handling"""
with patch("httpx.AsyncClient") as mock_client:
mock_client.return_value.__aenter__.return_value.post.side_effect = (
httpx.TimeoutException("Timeout")
)
with pytest.raises(LLMTimeoutError):
await client.chat_completions(
messages=[{"role": "user", "content": "test"}],
temperature=0.2,
max_tokens=100,
request_id="test-123",
)
@pytest.mark.asyncio
async def test_http_4xx_error(self, client):
"""Test 4xx HTTP error (no retry)"""
with patch("httpx.AsyncClient") as mock_client:
            mock_response = Mock()  # Use regular Mock for response object
            mock_response.status_code = 400
            mock_response.text = "Bad Request"
            http_error = httpx.HTTPStatusError(
                "Bad Request", request=Mock(), response=mock_response
            )
mock_client.return_value.__aenter__.return_value.post.side_effect = (
http_error
)
with pytest.raises(LLMHttpError, match="HTTP 400"):
await client.chat_completions(
messages=[{"role": "user", "content": "test"}],
temperature=0.2,
max_tokens=100,
request_id="test-123",
)
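            # The docstring promises no retry for 4xx responses; a hedged extra
            # check (assuming one HTTP attempt corresponds to one post() call):
            post_mock = mock_client.return_value.__aenter__.return_value.post
            assert post_mock.call_count == 1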
@pytest.mark.asyncio
async def test_http_5xx_error_with_retry(self, client):
"""Test 5xx HTTP error with retry logic"""
with patch("httpx.AsyncClient") as mock_client:
            mock_response = Mock()  # Use regular Mock for response object
            mock_response.status_code = 500
            mock_response.text = "Internal Server Error"
            http_error = httpx.HTTPStatusError(
                "Server Error", request=Mock(), response=mock_response
            )
mock_client.return_value.__aenter__.return_value.post.side_effect = (
http_error
)
with pytest.raises(LLMHttpError, match="HTTP 500"):
await client.chat_completions(
messages=[{"role": "user", "content": "test"}],
temperature=0.2,
max_tokens=100,
request_id="test-123",
)
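            # A hedged sketch of asserting the retry count itself; the exact
            # number depends on whether max_retries counts the first attempt
            # (assuming max_retries=2 means two retries after the initial call,
            # i.e. three attempts in total):
            #     post_mock = mock_client.return_value.__aenter__.return_value.post
            #     assert post_mock.call_count == 3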
@pytest.mark.asyncio
async def test_network_error(self, client):
"""Test network error handling"""
with patch("httpx.AsyncClient") as mock_client:
mock_client.return_value.__aenter__.return_value.post.side_effect = (
httpx.RequestError("Network error")
)
with pytest.raises(LLMNetworkError):
await client.chat_completions(
messages=[{"role": "user", "content": "test"}],
temperature=0.2,
max_tokens=100,
request_id="test-123",
)
@pytest.mark.asyncio
async def test_invalid_response_format(self, client):
"""Test handling of invalid response format"""
mock_response = {"invalid": "response"}
with patch("httpx.AsyncClient") as mock_client:
mock_response_obj = Mock() # Use regular Mock for response object
mock_response_obj.json.return_value = (
mock_response # Make json() return the data
)
mock_response_obj.raise_for_status.return_value = None
mock_client.return_value.__aenter__.return_value.post.return_value = (
mock_response_obj
)
with pytest.raises(LLMHttpError, match="Invalid response format"):
await client.chat_completions(
messages=[{"role": "user", "content": "test"}],
temperature=0.2,
max_tokens=100,
request_id="test-123",
)
@pytest.mark.asyncio
async def test_empty_response_content(self, client):
"""Test handling of empty response content"""
mock_response = {"choices": [{"message": {"content": ""}}]}
with patch("httpx.AsyncClient") as mock_client:
mock_response_obj = Mock() # Use regular Mock for response object
mock_response_obj.json.return_value = (
mock_response # Make json() return the data
)
mock_response_obj.raise_for_status.return_value = None
mock_client.return_value.__aenter__.return_value.post.return_value = (
mock_response_obj
)
with pytest.raises(LLMHttpError, match="Empty response content"):
await client.chat_completions(
messages=[{"role": "user", "content": "test"}],
temperature=0.2,
max_tokens=100,
request_id="test-123",
)
@pytest.mark.asyncio
async def test_headers_inclusion(self, client):
"""Test that proper headers are included in requests"""
with patch("httpx.AsyncClient") as mock_client:
mock_response = {
"choices": [{"message": {"content": '{"cleaned": "test"}'}}]
}
mock_response_obj = Mock() # Use regular Mock for response object
mock_response_obj.json.return_value = mock_response
mock_response_obj.raise_for_status.return_value = None
mock_client.return_value.__aenter__.return_value.post.return_value = (
mock_response_obj
)
await client.chat_completions(
messages=[{"role": "user", "content": "test"}],
temperature=0.2,
max_tokens=100,
request_id="test-123",
)
# Verify the request was made with correct headers
call_args = mock_client.return_value.__aenter__.return_value.post.call_args
headers = call_args[1]["headers"]
assert headers["Content-Type"] == "application/json"
assert headers["Authorization"] == "Bearer test-key"
assert headers["X-Request-ID"] == "test-123"
@pytest.mark.asyncio
async def test_payload_structure(self, client):
"""Test that request payload has correct structure"""
with patch("httpx.AsyncClient") as mock_client:
mock_response = {
"choices": [{"message": {"content": '{"cleaned": "test"}'}}]
}
mock_response_obj = Mock() # Use regular Mock for response object
mock_response_obj.json.return_value = mock_response
mock_response_obj.raise_for_status.return_value = None
mock_client.return_value.__aenter__.return_value.post.return_value = (
mock_response_obj
)
messages = [{"role": "user", "content": "test prompt"}]
await client.chat_completions(
messages=messages,
temperature=0.5,
max_tokens=200,
request_id="test-456",
)
# Verify the payload structure
call_args = mock_client.return_value.__aenter__.return_value.post.call_args
payload = call_args[1]["json"]
assert payload["model"] == "gpt-4"
assert payload["temperature"] == 0.5
assert payload["max_tokens"] == 200
assert payload["messages"] == messages