Skip to main content
Glama
test_context_propagation.py (13.6 kB)
#!/usr/bin/env python3
# Copyright (C) 2023 the project owner
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
Test suite for context variable propagation in batch operations.

Verifies that user context (client_id, username) set by middleware
propagates to child tasks in batch_process().

Background:
- Context variables are per-task in asyncio: every Task runs inside its
  own copy of a contextvars.Context.
- asyncio.gather() wraps coroutines in Tasks, and each Task copies the
  context that is current *when the Task is created*. Values set in the
  parent before gather() are inherited; values a child sets are not seen
  by the parent or by sibling tasks.
- The pattern simulated below (and attributed to batch_process()) captures
  client_id/username up front and re-sets them explicitly in every child
  task, so the values do not depend on which context the child was
  created in.

If propagation breaks:
- User tracking fails silently in batch mode.
- Quota enforcement doesn't apply to batch tasks.

NOTE(review): because Tasks inherit their creator's context, these tests
would most likely still pass without the explicit re-set in each child;
they exercise the capture-and-re-set pattern rather than reproduce a
failure. Confirm against the real batch_process() implementation.
"""

import asyncio
import sys
import traceback
from contextvars import ContextVar
from pathlib import Path
from typing import Optional, Tuple
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# NOTE(review): ContextVar, Tuple, AsyncMock, MagicMock and patch are not
# used anywhere in this file.

# Make the directory containing this file importable so `delia` resolves.
# NOTE(review): this assumes the file lives at the project root — confirm.
sys.path.insert(0, str(Path(__file__).parent))

# The real context variables the server's middleware sets.
from delia.mcp_server import current_client_id, current_username  # noqa: E402

# Every call to mock_delegate_impl appends one record here. Each test
# clears it first and then inspects what every simulated task observed.
context_captures = []


async def mock_delegate_impl(
    task: str,
    content: str,
    file: Optional[str] = None,
    model: Optional[str] = None,
    language: Optional[str] = None,
    context: Optional[str] = None,
    symbols: Optional[str] = None,
    include_references: bool = False,
    backend: Optional[str] = None,
    backend_obj=None,
) -> str:
    """Stand-in for ``_delegate_impl`` that records the caller's context.

    Reads ``current_client_id`` / ``current_username`` as seen by the task
    running this coroutine and appends them, together with the task name
    and content, to ``context_captures``. The record is appended before the
    first ``await``, so records appear in the order the tasks start running.
    Parameters other than ``task`` and ``content`` are accepted for call
    compatibility and ignored.

    Returns:
        The string ``"Task {task} completed"``.
    """
    # Capture the context variables at the time this task executes.
    client_id = current_client_id.get()
    username = current_username.get()
    context_captures.append({
        "task": task,
        "client_id": client_id,
        "username": username,
        "content": content,
        "content_len": len(content),
    })
    # Simulate some async work.
    await asyncio.sleep(0.01)
    return f"Task {task} completed"


@pytest.mark.asyncio
async def test_context_propagation_single_task():
    """Test that context is available in a single batch task."""
    print("\n=== Test 1: Context Propagation - Single Task ===")
    context_captures.clear()

    # Set context (as middleware would).
    current_client_id.set("client-123")
    current_username.set("testuser")

    # Simulate batch_process's context capture pattern.
    captured_client_id = current_client_id.get()
    captured_username = current_username.get()

    async def run_task(task_name: str, client_id: Optional[str], username: Optional[str]) -> str:
        # Re-set context in child task (CRITICAL FIX).
        current_client_id.set(client_id)
        current_username.set(username)
        return await mock_delegate_impl(
            task=task_name,
            content="test content",
        )

    result = await run_task("analyze", captured_client_id, captured_username)

    assert result == "Task analyze completed"
    assert len(context_captures) == 1
    assert context_captures[0]["client_id"] == "client-123"
    assert context_captures[0]["username"] == "testuser"
    print("✓ Single task captured correct context")


@pytest.mark.asyncio
async def test_context_propagation_batch_tasks():
    """Test that context propagates to all tasks in a batch."""
    print("\n=== Test 2: Context Propagation - Batch Tasks ===")
    context_captures.clear()

    # Set context (as middleware would).
    current_client_id.set("client-456")
    current_username.set("batchuser")

    # Simulate batch_process's context capture pattern.
    captured_client_id = current_client_id.get()
    captured_username = current_username.get()

    async def run_task(
        task_id: int,
        task_name: str,
        client_id: Optional[str],
        username: Optional[str],
    ) -> str:
        # Re-set context in child task (CRITICAL FIX).
        current_client_id.set(client_id)
        current_username.set(username)
        return await mock_delegate_impl(
            task=task_name,
            content=f"task {task_id} content",
        )

    # Run multiple tasks in parallel (simulating a batch).
    tasks = [
        ("analyze", "Analyzing code"),
        ("review", "Reviewing changes"),
        ("summarize", "Summarizing"),
    ]

    results = await asyncio.gather(*[
        run_task(i, task_type, captured_client_id, captured_username)
        for i, (task_type, _) in enumerate(tasks)
    ])

    # Verify all tasks captured the correct context.
    assert len(context_captures) == 3
    for i, capture in enumerate(context_captures):
        assert capture["client_id"] == "client-456", f"Task {i} lost context"
        assert capture["username"] == "batchuser", f"Task {i} lost username"
        print(f"✓ Task {i} captured correct context")

    # gather() returns results in argument order.
    assert results == [f"Task {task_type} completed" for task_type, _ in tasks]
    print(f"✓ All {len(results)} batch tasks captured correct context")


@pytest.mark.asyncio
async def test_concurrent_batches_isolated():
    """Test that concurrent batches don't mix user contexts."""
    print("\n=== Test 3: Concurrent Batches - Context Isolation ===")
    context_captures.clear()

    async def run_batch(batch_id: str, user_id: str, username: str, num_tasks: int):
        # Each batch sets its own context.
        current_client_id.set(user_id)
        current_username.set(username)

        # Capture context for this batch.
        captured_client_id = current_client_id.get()
        captured_username = current_username.get()

        async def run_task(task_id: int, client_id: Optional[str], user: Optional[str]) -> str:
            # Re-set context in child task.
            current_client_id.set(client_id)
            current_username.set(user)
            return await mock_delegate_impl(
                task=f"task_{task_id}",
                content=f"batch {batch_id} task {task_id}",
            )

        # Run batch tasks in parallel.
        return await asyncio.gather(*[
            run_task(i, captured_client_id, captured_username)
            for i in range(num_tasks)
        ])

    # Run multiple batches concurrently.
    batch_results = await asyncio.gather(
        run_batch("batch-a", "client-a", "user-a", 2),
        run_batch("batch-b", "client-b", "user-b", 2),
        run_batch("batch-c", "client-c", "user-c", 2),
    )

    assert [len(results) for results in batch_results] == [2, 2, 2]
    assert len(context_captures) == 6  # 3 batches × 2 tasks each

    # Attribute each capture to its batch via the recorded content instead of
    # by list position: asyncio makes no documented promise about the order
    # in which the child tasks start, so slicing the capture list by index
    # would make this test depend on scheduling details.
    expected_batches = [
        ("A", "batch-a", "client-a", "user-a"),
        ("B", "batch-b", "client-b", "user-b"),
        ("C", "batch-c", "client-c", "user-c"),
    ]
    for label, batch_id, client_id, username in expected_batches:
        batch_captures = [
            capture for capture in context_captures
            if capture["content"].startswith(f"batch {batch_id} ")
        ]
        assert len(batch_captures) == 2, f"Batch {label} expected 2 captures: {batch_captures}"
        for capture in batch_captures:
            assert capture["client_id"] == client_id, f"Batch {label} task lost context: {capture}"
            assert capture["username"] == username, f"Batch {label} task lost username: {capture}"
        print(f"✓ Batch {label} isolated correctly")

    print("✓ All concurrent batches maintained context isolation")


@pytest.mark.asyncio
async def test_context_with_none_values():
    """Test that None context values are properly propagated."""
    print("\n=== Test 4: Context Propagation - None Values ===")
    context_captures.clear()

    # Clear context (simulate an unauthenticated user).
    current_client_id.set(None)
    current_username.set(None)

    captured_client_id = current_client_id.get()
    captured_username = current_username.get()

    async def run_task(client_id: Optional[str], username: Optional[str]) -> str:
        current_client_id.set(client_id)
        current_username.set(username)
        return await mock_delegate_impl(
            task="analyze",
            content="test",
        )

    result = await run_task(captured_client_id, captured_username)

    # Verify None context was preserved.
    assert result == "Task analyze completed"
    assert len(context_captures) == 1
    assert context_captures[0]["client_id"] is None
    assert context_captures[0]["username"] is None
    print("✓ None context values properly propagated")


@pytest.mark.asyncio
async def test_nested_batch_operations():
    """Test context propagation in nested batch-like operations."""
    print("\n=== Test 5: Nested Batch Operations ===")
    context_captures.clear()

    # Set context.
    current_client_id.set("client-nested")
    current_username.set("nested-user")

    captured_client_id = current_client_id.get()
    captured_username = current_username.get()

    async def run_subtask(subtask_id: int, client_id: Optional[str], username: Optional[str]) -> str:
        # Re-set context.
        current_client_id.set(client_id)
        current_username.set(username)
        return await mock_delegate_impl(
            task=f"subtask_{subtask_id}",
            content=f"subtask {subtask_id}",
        )

    async def run_main_task(task_id: int, client_id: Optional[str], username: Optional[str]):
        # Re-set context for the main task.
        current_client_id.set(client_id)
        current_username.set(username)

        # This task spawns sub-tasks.
        subtasks = await asyncio.gather(*[
            run_subtask(i, client_id, username)
            for i in range(2)
        ])

        # The main task also records itself.
        await mock_delegate_impl(
            task=f"main_task_{task_id}",
            content=f"main task {task_id}",
        )

        return subtasks

    results = await asyncio.gather(*[
        run_main_task(i, captured_client_id, captured_username)
        for i in range(2)
    ])

    # Each main task returns its two subtask results.
    assert results == [["Task subtask_0 completed", "Task subtask_1 completed"]] * 2

    # 2 main tasks × (2 subtasks + 1 main recording) = 6 captures.
    assert len(context_captures) == 6
    for capture in context_captures:
        assert capture["client_id"] == "client-nested"
        assert capture["username"] == "nested-user"

    print("✓ Nested batch operations maintained context throughout")


@pytest.mark.asyncio
async def test_context_error_handling():
    """Test that context is maintained for sibling tasks when one task raises."""
    print("\n=== Test 6: Context in Error Scenarios ===")
    context_captures.clear()

    current_client_id.set("client-error")
    current_username.set("error-user")

    captured_client_id = current_client_id.get()
    captured_username = current_username.get()

    async def run_task_with_error(
        task_id: int,
        should_fail: bool,
        client_id: Optional[str],
        username: Optional[str],
    ) -> str:
        # Re-set context even if the task will fail.
        current_client_id.set(client_id)
        current_username.set(username)

        if should_fail:
            raise ValueError(f"Intentional error in task {task_id}")

        return await mock_delegate_impl(
            task=f"task_{task_id}",
            content=f"task {task_id}",
        )

    # Run a mix of successful and failing tasks; return_exceptions=True keeps
    # one failure from cancelling the result collection for the others.
    results = await asyncio.gather(
        run_task_with_error(0, False, captured_client_id, captured_username),
        run_task_with_error(1, True, captured_client_id, captured_username),
        run_task_with_error(2, False, captured_client_id, captured_username),
        return_exceptions=True,
    )

    # Only tasks 0 and 2 reach mock_delegate_impl; task 1 raises after
    # re-setting its context but before delegating, so it records nothing.
    assert len(context_captures) == 2
    for capture in context_captures:
        assert capture["client_id"] == "client-error"
        assert capture["username"] == "error-user"

    assert results[0] == "Task task_0 completed"
    assert isinstance(results[1], ValueError)
    assert results[2] == "Task task_2 completed"
    print("✓ Context properly maintained even with task failures")


async def main():
    """Run all tests without pytest; return 0 if all pass, 1 otherwise."""
    print("=" * 70)
    print("CONTEXT PROPAGATION IN BATCH OPERATIONS TEST SUITE")
    print("=" * 70)

    tests = [
        test_context_propagation_single_task,
        test_context_propagation_batch_tasks,
        test_concurrent_batches_isolated,
        test_context_with_none_values,
        test_nested_batch_operations,
        test_context_error_handling,
    ]

    passed = 0
    failed = 0

    for test in tests:
        try:
            # Run each test in its own Task so its ContextVar.set() calls
            # apply to a copied context and cannot leak into later tests
            # (pytest-asyncio likewise runs each test in a separate task).
            await asyncio.create_task(test())
            passed += 1
        except Exception as e:
            failed += 1
            # Bare asserts carry no message, so name the failing test.
            print(f"✗ Test failed: {test.__name__}: {e!r}")
            traceback.print_exc()

    print("\n" + "=" * 70)
    print(f"RESULTS: {passed} passed, {failed} failed")
    print("=" * 70)

    return 0 if failed == 0 else 1


if __name__ == "__main__":
    exit_code = asyncio.run(main())
    sys.exit(exit_code)

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/zbrdc/delia'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.