"""Tests for correlation ID tracking in workflow execution"""
import pytest
from src.utils.context import (
clear_correlation_id,
generate_correlation_id,
get_correlation_id,
set_correlation_id,
)
def test_generate_correlation_id():
    """Test correlation ID generation.

    Each call must return a distinct string in canonical hyphenated
    UUID layout (8-4-4-4-12 hex digits).
    """
    import uuid

    cid1 = generate_correlation_id()
    cid2 = generate_correlation_id()

    for cid in (cid1, cid2):
        assert isinstance(cid, str)
        assert len(cid) == 36  # UUID format: 8-4-4-4-12
        # Length alone accepts any 36-character string. Round-tripping through
        # uuid.UUID rejects non-hex content and misplaced hyphens (UUID raises
        # ValueError on malformed input). Compare case-insensitively because
        # str(UUID) is always lowercase.
        assert str(uuid.UUID(cid)) == cid.lower()

    # Should be unique
    assert cid1 != cid2
def test_set_and_get_correlation_id():
    """Test setting and getting correlation ID.

    Also checks that clear_correlation_id() resets the value to None.
    """
    test_id = "test-correlation-id-123"
    set_correlation_id(test_id)
    try:
        assert get_correlation_id() == test_id
    finally:
        # Always clean up. A failed assertion must not leave the ID set in
        # the thread's context, where it would leak into later sync tests.
        clear_correlation_id()
    assert get_correlation_id() is None
def test_correlation_id_isolation():
    """Test that correlation IDs are isolated between contexts.

    Three concurrent asyncio tasks each set their own ID, yield to the event
    loop, and then read it back. Each task must see only its own ID. The
    caller's context must also be unaffected once the loop finishes.
    """
    import asyncio

    results = []

    async def task_with_correlation_id(task_id: str):
        """Set a fresh correlation ID, yield, then record what is read back."""
        correlation_id = generate_correlation_id()
        set_correlation_id(correlation_id)
        # Yield so the sibling tasks run and set their own IDs before this
        # task reads its value again. Shared state would be overwritten here.
        await asyncio.sleep(0.01)
        retrieved_id = get_correlation_id()
        results.append((task_id, correlation_id, retrieved_id))

    async def run_parallel_tasks():
        """Run multiple tasks in parallel"""
        await asyncio.gather(
            task_with_correlation_id("task1"),
            task_with_correlation_id("task2"),
            task_with_correlation_id("task3"),
        )

    # Start from a known-empty context so the final leak check is meaningful.
    clear_correlation_id()
    asyncio.run(run_parallel_tasks())

    # Verify each task maintained its own correlation ID
    assert len(results) == 3
    for task_id, set_id, retrieved_id in results:
        assert set_id == retrieved_id, f"{task_id} lost its correlation ID"

    # Verify all IDs are different
    all_ids = [retrieved_id for _, _, retrieved_id in results]
    assert len(set(all_ids)) == 3, "Correlation IDs leaked between tasks"

    # IDs set inside the tasks must not leak back into the caller's context.
    assert get_correlation_id() is None, "Correlation ID leaked out of the tasks"
def test_log_call_with_correlation_id(caplog):
    """Test that log_call decorator includes correlation ID.

    Without a correlation ID, the ENTER message must carry no "[<id>] "
    prefix. With one set, both the ENTER and SUCCESS messages must be
    prefixed with "[<id>] ".
    """
    import logging

    from src.utils.logging import log_call

    # caplog only sees records the logger hierarchy lets through. log_call's
    # log level is not visible here, so capture everything.
    caplog.set_level(logging.DEBUG)

    @log_call(action_name="test_function", level_name="test")
    def test_sync_function():
        return "success"

    try:
        # Test without correlation ID
        clear_correlation_id()
        caplog.clear()
        test_sync_function()
        assert "[test] test_function - ENTER" in caplog.text
        # Inspect raw record messages rather than caplog.text, so a configured
        # log_format cannot add brackets of its own. No bracketed token should
        # sit directly before "[test]" on the ENTER message.
        enter_messages = [
            record.getMessage()
            for record in caplog.records
            if "[test] test_function - ENTER" in record.getMessage()
        ]
        assert enter_messages
        for message in enter_messages:
            assert "] [test] test_function - ENTER" not in message, message

        # Test with correlation ID
        caplog.clear()
        test_correlation_id = "test-correlation-123"
        set_correlation_id(test_correlation_id)
        test_sync_function()
        # Should have correlation ID prefix
        assert f"[{test_correlation_id}] [test] test_function - ENTER" in caplog.text
        assert f"[{test_correlation_id}] [test] test_function - SUCCESS" in caplog.text
    finally:
        # Clean up even on failure so the ID does not leak into later tests.
        clear_correlation_id()
@pytest.mark.asyncio
async def test_log_call_async_with_correlation_id(caplog):
    """Test that log_call decorator includes correlation ID for async functions.

    With a correlation ID set, awaiting a decorated coroutine must log ENTER
    and SUCCESS messages prefixed with "[<id>] ".
    """
    import logging

    from src.utils.logging import log_call

    # Capture all levels. log_call's log level is not visible from here.
    caplog.set_level(logging.DEBUG)

    @log_call(action_name="test_async_function", level_name="test")
    async def test_async_function():
        return "success"

    # Test with correlation ID
    caplog.clear()
    test_correlation_id = "test-async-correlation-123"
    set_correlation_id(test_correlation_id)
    try:
        await test_async_function()
        # Should have correlation ID prefix
        assert f"[{test_correlation_id}] [test] test_async_function - ENTER" in caplog.text
        assert f"[{test_correlation_id}] [test] test_async_function - SUCCESS" in caplog.text
    finally:
        # Clean up even when an assertion fails.
        clear_correlation_id()
if __name__ == "__main__":
    # Pass pytest's exit status through, so running this file directly
    # reports test failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))