# test_request_batcher.py.disabled
"""Tests for request batching functionality."""
import asyncio
from unittest.mock import AsyncMock
import pytest
from biomcp.request_batcher import RequestBatcher
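
# The tests below exercise RequestBatcher through roughly the interface
# sketched here. The sketch is inferred from the calls in this file, not
# from biomcp's own documentation, so treat it as an assumption:
#
#     batcher = RequestBatcher(
#         batch_func=...,     # async callable: list of params -> list of results
#         batch_size=...,     # flush once this many requests are pending
#         batch_timeout=...,  # ...or once this many seconds have elapsed
#     )
#     result = await batcher.request(params)  # one result per request, in order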


@pytest.fixture
def mock_batch_func():
    """Create a mock batch-processing function."""

    async def batch_func(params_list):
        # Simulate processing: return one result per request, in input order.
        return [{"result": f"processed_{p['id']}"} for p in params_list]

    return AsyncMock(side_effect=batch_func)
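
# AsyncMock(side_effect=batch_func) calls through to batch_func while still
# recording every invocation, so tests can both exercise real batching
# behavior and assert on call counts and the exact batches passed in.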


@pytest.mark.asyncio
async def test_single_request(mock_batch_func):
    """Test that a single request is processed correctly."""
    batcher = RequestBatcher(
        batch_func=mock_batch_func,
        batch_size=10,
        batch_timeout=0.1,
    )

    result = await batcher.request({"id": 1})

    assert result == {"result": "processed_1"}
    mock_batch_func.assert_called_once_with([{"id": 1}])


@pytest.mark.asyncio
async def test_batch_size_trigger(mock_batch_func):
    """Test that a batch is processed when the size threshold is reached."""
    batcher = RequestBatcher(
        batch_func=mock_batch_func,
        batch_size=3,
        batch_timeout=10,  # Long timeout so the size threshold triggers first
    )

    # Send requests concurrently
    tasks = [
        batcher.request({"id": 1}),
        batcher.request({"id": 2}),
        batcher.request({"id": 3}),
    ]
    results = await asyncio.gather(*tasks)

    assert results == [
        {"result": "processed_1"},
        {"result": "processed_2"},
        {"result": "processed_3"},
    ]
    mock_batch_func.assert_called_once_with([
        {"id": 1}, {"id": 2}, {"id": 3}
    ])
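
# asyncio.gather returns results in the order the awaitables were passed,
# so the equality assertion above also confirms that the batcher routes
# each result back to the request that produced it.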


@pytest.mark.asyncio
async def test_batch_timeout_trigger(mock_batch_func):
    """Test that a batch is processed when the timeout expires."""
    batcher = RequestBatcher(
        batch_func=mock_batch_func,
        batch_size=10,  # Large size so the timeout triggers first
        batch_timeout=0.05,
    )

    # Send two requests (fewer than the batch size)
    task1 = asyncio.create_task(batcher.request({"id": 1}))
    task2 = asyncio.create_task(batcher.request({"id": 2}))
    results = await asyncio.gather(task1, task2)

    assert results == [
        {"result": "processed_1"},
        {"result": "processed_2"},
    ]
    mock_batch_func.assert_called_once()


@pytest.mark.asyncio
async def test_multiple_batches(mock_batch_func):
    """Test that multiple batches are processed correctly."""
    batcher = RequestBatcher(
        batch_func=mock_batch_func,
        batch_size=2,
        batch_timeout=0.05,
    )

    # Send 5 requests with small delays between them; with batch_size=2 this
    # should produce roughly three batches (2 + 2 + 1). create_task starts
    # each request immediately; bare coroutines would not run until the
    # gather below, which would make the delays meaningless.
    tasks = []
    for i in range(5):
        tasks.append(asyncio.create_task(batcher.request({"id": i})))
        if i < 4:  # Small delay between requests
            await asyncio.sleep(0.01)
    results = await asyncio.gather(*tasks)

    assert len(results) == 5
    assert all(results[i] == {"result": f"processed_{i}"} for i in range(5))
    assert mock_batch_func.call_count >= 2  # At least 2 batches


@pytest.mark.asyncio
async def test_error_propagation(mock_batch_func):
    """Test that errors are properly propagated to individual requests."""
    # Override the mock to raise an error
    mock_batch_func.side_effect = Exception("Batch processing failed")

    batcher = RequestBatcher(
        batch_func=mock_batch_func,
        batch_size=2,
        batch_timeout=0.05,
    )

    tasks = [
        batcher.request({"id": 1}),
        batcher.request({"id": 2}),
    ]

    # Both requests should receive the error; return_exceptions=True lets us
    # inspect each outcome instead of stopping at the first raised exception.
    results = await asyncio.gather(*tasks, return_exceptions=True)
    assert len(results) == 2
    for result in results:
        assert isinstance(result, Exception)
        assert str(result) == "Batch processing failed"


@pytest.mark.asyncio
async def test_partial_error_handling():
    """Test handling when the batch function returns fewer results."""

    async def partial_batch_func(params_list):
        # Return fewer results than there were requests
        return [{"result": "only_one"}]

    batcher = RequestBatcher(
        batch_func=partial_batch_func,
        batch_size=2,
        batch_timeout=0.05,
    )

    task1 = asyncio.create_task(batcher.request({"id": 1}))
    task2 = asyncio.create_task(batcher.request({"id": 2}))

    # The first request gets the lone result; the second gets an error
    result1 = await task1
    assert result1 == {"result": "only_one"}
    with pytest.raises(Exception, match="No result for request"):
        await task2


@pytest.mark.asyncio
async def test_concurrent_batches(mock_batch_func):
    """Test that multiple batches can be processed concurrently."""
    batcher = RequestBatcher(
        batch_func=mock_batch_func,
        batch_size=2,
        batch_timeout=0.05,
    )

    # Create multiple waves of requests. create_task starts each request
    # immediately, so the delay between waves actually separates them;
    # bare coroutines would not run until the final gather.
    all_tasks = []
    for wave in range(3):
        wave_tasks = [
            asyncio.create_task(batcher.request({"id": wave * 10 + i}))
            for i in range(2)
        ]
        all_tasks.extend(wave_tasks)
        await asyncio.sleep(0.01)  # Small delay between waves

    results = await asyncio.gather(*all_tasks)

    assert len(results) == 6
    assert mock_batch_func.call_count >= 3
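
# The next test drives RequestBatcher._batch_timer directly. That is a
# private coroutine, so this is an assumption about the implementation:
# the timer is presumed to flush pending requests on timeout and to be
# safe to cancel when no requests are queued.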


@pytest.mark.asyncio
async def test_empty_batch_handling(mock_batch_func):
    """Test that empty batches are handled correctly."""
    batcher = RequestBatcher(
        batch_func=mock_batch_func,
        batch_size=5,
        batch_timeout=0.01,
    )

    # Start the timer without sending any requests
    timer_task = asyncio.create_task(batcher._batch_timer())
    await asyncio.sleep(0.02)

    # Clean up: cancel the timer and let the cancellation settle, so no
    # pending-task warning leaks out of the test
    timer_task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await timer_task

    # The batch function should never have been called
    mock_batch_func.assert_not_called()


@pytest.mark.asyncio
async def test_request_ordering(mock_batch_func):
    """Test that results are returned in the correct order."""

    # Batch function that returns results positionally, matching the input
    # order; the batcher is expected to route each result back to the
    # request at the same position.
    async def ordered_batch_func(params_list):
        return [{"result": f"processed_{p['id']}"} for p in params_list]

    mock_batch_func.side_effect = ordered_batch_func

    batcher = RequestBatcher(
        batch_func=mock_batch_func,
        batch_size=4,
        batch_timeout=0.05,
    )

    tasks = [batcher.request({"id": i}) for i in range(4)]
    results = await asyncio.gather(*tasks)

    # Results should match the request order
    for i in range(4):
        assert results[i] == {"result": f"processed_{i}"}


@pytest.mark.asyncio
async def test_batch_func_exception_cleanup():
    """Test that batcher state is cleaned up after exceptions."""
    call_count = 0

    async def failing_batch_func(params_list):
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            raise Exception("First batch fails")
        return [{"result": f"processed_{p['id']}"} for p in params_list]

    batcher = RequestBatcher(
        batch_func=failing_batch_func,
        batch_size=2,
        batch_timeout=0.05,
    )

    # The first batch should fail
    with pytest.raises(Exception, match="First batch fails"):
        await asyncio.gather(
            batcher.request({"id": 1}),
            batcher.request({"id": 2}),
        )

    # A second batch should succeed, showing the failure did not wedge
    # the batcher's internal state
    results = await asyncio.gather(
        batcher.request({"id": 3}),
        batcher.request({"id": 4}),
    )
    assert results == [
        {"result": "processed_3"},
        {"result": "processed_4"},
    ]
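
# Note: the ".disabled" suffix keeps pytest from collecting this module.
# To re-enable the suite, rename the file to test_request_batcher.py and
# run it directly (the exact path within the repo is assumed here), e.g.:
#
#     pytest test_request_batcher.py -v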