Skip to main content
Glama
test_timeouts.py18.2 kB
""" Tests for the timeout utilities module. Tests cover: - @with_timeout decorator - timeout triggering and successful execution - run_with_timeout() - functional timeout handling - spark_operation_timeout context manager - warning logging - Concurrent operations - thread pool isolation and parallel execution """ import logging import time import concurrent.futures from unittest.mock import patch import pytest from src.service.exceptions import SparkTimeoutError from src.service.timeouts import ( with_timeout, run_with_timeout, spark_operation_timeout, DEFAULT_SPARK_QUERY_TIMEOUT, DEFAULT_SPARK_COLLECT_TIMEOUT, _timeout_executor, ) # ============================================================================= # Test @with_timeout Decorator # ============================================================================= class TestWithTimeoutDecorator: """Tests for the @with_timeout decorator.""" def test_successful_execution_within_timeout(self): """Test that function executes successfully when within timeout.""" @with_timeout(timeout_seconds=5.0, operation_name="test_op") def fast_function(): return "success" result = fast_function() assert result == "success" def test_timeout_triggers_spark_timeout_error(self): """Test that timeout raises SparkTimeoutError.""" @with_timeout(timeout_seconds=0.1, operation_name="slow_op") def slow_function(): time.sleep(10) return "never reached" with pytest.raises(SparkTimeoutError) as exc_info: slow_function() assert exc_info.value.operation == "slow_op" assert exc_info.value.timeout == 0.1 assert "slow_op" in str(exc_info.value) assert "0.1" in str(exc_info.value) def test_decorator_preserves_function_arguments(self): """Test that decorator correctly passes arguments.""" @with_timeout(timeout_seconds=5.0, operation_name="args_test") def func_with_args(a, b, c=None): return f"{a}-{b}-{c}" result = func_with_args("x", "y", c="z") assert result == "x-y-z" def test_decorator_preserves_return_value(self): """Test that decorator 
preserves the return value.""" @with_timeout(timeout_seconds=5.0) def return_dict(): return {"key": "value", "count": 42} result = return_dict() assert result == {"key": "value", "count": 42} def test_decorator_uses_default_timeout_when_none(self): """Test that None timeout uses DEFAULT_SPARK_QUERY_TIMEOUT.""" @with_timeout(timeout_seconds=None, operation_name="default_test") def quick_function(): return True # This should work because default is 300 seconds result = quick_function() assert result is True def test_decorator_logs_error_on_timeout(self, caplog): """Test that timeout is logged as error.""" @with_timeout(timeout_seconds=0.05, operation_name="logging_test") def slow_func(): time.sleep(10) with caplog.at_level(logging.ERROR): with pytest.raises(SparkTimeoutError): slow_func() assert "logging_test" in caplog.text assert "timed out" in caplog.text def test_decorated_function_preserves_metadata(self): """Test that @wraps preserves function metadata.""" @with_timeout(timeout_seconds=5.0) def documented_func(): """This is the docstring.""" return True assert documented_func.__name__ == "documented_func" assert documented_func.__doc__ == "This is the docstring." 
# =============================================================================
# Test run_with_timeout Function
# =============================================================================


class TestRunWithTimeout:
    """Tests for the run_with_timeout function."""

    def test_successful_execution(self):
        """Test successful function execution."""

        def simple_func():
            return 42

        result = run_with_timeout(simple_func, timeout_seconds=5.0)
        assert result == 42

    def test_timeout_raises_error(self):
        """Test that timeout raises SparkTimeoutError."""

        def slow_func():
            time.sleep(10)

        with pytest.raises(SparkTimeoutError) as exc_info:
            run_with_timeout(
                slow_func, timeout_seconds=0.1, operation_name="timeout_test"
            )

        assert exc_info.value.operation == "timeout_test"

    def test_passes_positional_args(self):
        """Test that positional args are passed correctly."""

        def add(a, b):
            return a + b

        result = run_with_timeout(add, args=(3, 4), timeout_seconds=5.0)
        assert result == 7

    def test_passes_keyword_args(self):
        """Test that keyword args are passed correctly."""

        def greet(name, greeting="Hello"):
            return f"{greeting}, {name}!"

        result = run_with_timeout(
            greet, args=("World",), kwargs={"greeting": "Hi"}, timeout_seconds=5.0
        )
        assert result == "Hi, World!"

    def test_default_kwargs_is_empty_dict(self):
        """Test that kwargs defaults to empty dict."""

        def func_with_kwargs(**kwargs):
            return len(kwargs)

        result = run_with_timeout(func_with_kwargs, timeout_seconds=5.0)
        assert result == 0

    def test_uses_default_timeout_when_none(self):
        """Test that None timeout uses default."""

        def quick_func():
            return "done"

        result = run_with_timeout(quick_func, timeout_seconds=None)
        assert result == "done"

    def test_lambda_functions(self):
        """Test that lambdas work correctly."""
        result = run_with_timeout(lambda: sum(range(100)), timeout_seconds=5.0)
        assert result == 4950

    def test_exception_propagation(self):
        """Test that exceptions from the function are propagated."""

        def raising_func():
            raise ValueError("Test error")

        with pytest.raises(ValueError, match="Test error"):
            run_with_timeout(raising_func, timeout_seconds=5.0)


# =============================================================================
# Test spark_operation_timeout Context Manager
# =============================================================================


class TestSparkOperationTimeoutContextManager:
    """Tests for the spark_operation_timeout context manager."""

    def test_fast_operation_no_warning(self, caplog):
        """Test that fast operations don't log warnings."""
        with caplog.at_level(logging.WARNING):
            with spark_operation_timeout(timeout_seconds=5.0, operation_name="fast_op"):
                pass  # Instant operation

        assert "fast_op" not in caplog.text

    def test_slow_operation_logs_warning(self, caplog):
        """Test that slow operations log a warning."""
        with caplog.at_level(logging.WARNING):
            with spark_operation_timeout(
                timeout_seconds=0.05, operation_name="slow_op"
            ):
                time.sleep(0.1)

        assert "slow_op" in caplog.text
        assert "threshold" in caplog.text

    def test_approaching_threshold_logs_info(self, caplog):
        """Test that operations approaching threshold log info."""
        with caplog.at_level(logging.INFO):
            # Sleep for 85% of timeout (> 80% threshold)
            with spark_operation_timeout(timeout_seconds=0.1, operation_name="near_op"):
                time.sleep(0.085)

        assert "near_op" in caplog.text
        assert "approaching" in caplog.text

    def test_context_manager_yields_control(self):
        """Test that code inside context manager executes."""
        executed = False
        with spark_operation_timeout(timeout_seconds=5.0):
            executed = True
        assert executed is True

    def test_default_timeout_used_when_none(self):
        """Test that None timeout uses default."""
        # Should not raise - just uses default timeout
        with spark_operation_timeout(timeout_seconds=None, operation_name="default_op"):
            pass

    def test_exception_in_context_still_logs(self, caplog):
        """Test that exceptions don't prevent logging."""
        with caplog.at_level(logging.WARNING):
            with pytest.raises(ValueError):
                with spark_operation_timeout(
                    timeout_seconds=0.01, operation_name="error_op"
                ):
                    time.sleep(0.05)
                    raise ValueError("Test error")

        # The finally block logs based on elapsed time, so the over-threshold
        # warning must appear even though the body raised. (The original test
        # stated this in a comment but never asserted it.)
        assert "error_op" in caplog.text


# =============================================================================
# Concurrency Tests
# =============================================================================


class TestConcurrentTimeoutOperations:
    """Tests for concurrent operations using the timeout utilities."""

    def test_multiple_concurrent_operations_succeed(self, concurrent_executor):
        """Test that multiple concurrent operations can execute."""

        def quick_task(n):
            time.sleep(0.01)
            return n * 2

        args_list = [(i,) for i in range(10)]
        results, exceptions = concurrent_executor(quick_task, args_list, max_workers=10)

        assert len(exceptions) == 0
        assert sorted(results) == [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]

    def test_concurrent_timeout_operations_isolated(self, concurrent_executor):
        """Test that concurrent operations with timeouts are isolated."""

        def task_with_timeout(n):
            return run_with_timeout(
                lambda: n * 3, timeout_seconds=5.0, operation_name=f"task_{n}"
            )

        args_list = [(i,) for i in range(5)]
        results, exceptions = concurrent_executor(
            task_with_timeout, args_list, max_workers=5
        )

        assert len(exceptions) == 0
        assert sorted(results) == [0, 3, 6, 9, 12]

    def test_mixed_fast_and_slow_concurrent_operations(self, concurrent_executor):
        """Test mix of fast and timing-out operations."""

        def mixed_task(sleep_time):
            return run_with_timeout(
                lambda: time.sleep(sleep_time) or "done",
                timeout_seconds=0.2,
                operation_name="mixed_task",
            )

        # Mix of fast (0.01s) and slow (1s) operations
        args_list = [(0.01,), (0.01,), (1.0,), (0.01,), (1.0,)]
        results, exceptions = concurrent_executor(mixed_task, args_list, max_workers=5)

        # 3 should succeed, 2 should timeout
        assert len(results) == 3
        assert len(exceptions) == 2
        assert all(isinstance(e, SparkTimeoutError) for e in exceptions)

    def test_thread_pool_handles_high_load(self, concurrent_executor):
        """Test that thread pool handles more tasks than workers."""
        import threading  # local import: only this test needs a lock

        counter = {"value": 0}
        lock = threading.Lock()

        def counting_task(n):
            result = run_with_timeout(
                lambda: n, timeout_seconds=5.0, operation_name="count_task"
            )
            with lock:
                counter["value"] += 1
            return result

        # Submit more tasks (25) than thread pool workers (5)
        args_list = [(i,) for i in range(25)]
        results, exceptions = concurrent_executor(
            counting_task, args_list, max_workers=5
        )

        assert len(exceptions) == 0
        assert counter["value"] == 25

    def test_decorated_function_concurrent_execution(self, concurrent_executor):
        """Test @with_timeout decorator with concurrent execution."""

        @with_timeout(timeout_seconds=5.0, operation_name="decorated_concurrent")
        def decorated_task(n):
            return n**2

        args_list = [(i,) for i in range(8)]
        results, exceptions = concurrent_executor(
            decorated_task, args_list, max_workers=8
        )

        assert len(exceptions) == 0
        assert sorted(results) == [0, 1, 4, 9, 16, 25, 36, 49]

    def test_thread_pool_executor_exists(self):
        """Test that the module-level executor exists and is configured."""
        assert _timeout_executor is not None
        assert isinstance(_timeout_executor, concurrent.futures.ThreadPoolExecutor)
# =============================================================================
# Default Values Tests
# =============================================================================


class TestDefaultValues:
    """Tests for default timeout values."""

    def test_default_query_timeout_from_env(self):
        """Test DEFAULT_SPARK_QUERY_TIMEOUT reads from environment."""
        # The original assertion (`== 300 or isinstance(..., int)`) was
        # vacuously true: the isinstance branch passes for any int. Pin the
        # real invariants instead: an int (300 when the env var is unset),
        # and strictly positive.
        assert isinstance(DEFAULT_SPARK_QUERY_TIMEOUT, int)
        assert DEFAULT_SPARK_QUERY_TIMEOUT > 0

    def test_default_collect_timeout_from_env(self):
        """Test DEFAULT_SPARK_COLLECT_TIMEOUT reads from environment."""
        # The default is 120 if not set; either way it must be a positive int.
        assert isinstance(DEFAULT_SPARK_COLLECT_TIMEOUT, int)
        assert DEFAULT_SPARK_COLLECT_TIMEOUT > 0

    def test_custom_env_timeout_values(self):
        """Test that environment variables can override defaults."""
        import os

        # This is a design verification - the module reads from os.getenv.
        # The module-level constants are bound at import time, so patching the
        # environment here cannot change them; we only verify the override
        # mechanism (the variables are visible where os.getenv would look).
        with patch.dict(
            "os.environ", {"SPARK_QUERY_TIMEOUT": "600", "SPARK_COLLECT_TIMEOUT": "180"}
        ):
            assert os.getenv("SPARK_QUERY_TIMEOUT") == "600"
            assert os.getenv("SPARK_COLLECT_TIMEOUT") == "180"


# =============================================================================
# Error Handling Tests
# =============================================================================


class TestErrorHandling:
    """Tests for error handling in timeout utilities."""

    def test_spark_timeout_error_attributes(self):
        """Test SparkTimeoutError has correct attributes."""

        @with_timeout(timeout_seconds=0.05, operation_name="attr_test")
        def slow_func():
            time.sleep(10)

        with pytest.raises(SparkTimeoutError) as exc_info:
            slow_func()

        error = exc_info.value
        assert hasattr(error, "operation")
        assert hasattr(error, "timeout")
        assert error.operation == "attr_test"
        assert error.timeout == 0.05

    def test_spark_timeout_error_message(self):
        """Test SparkTimeoutError has descriptive message."""
        error = SparkTimeoutError(operation="test_op", timeout=30.0)
        message = str(error)
        assert "test_op" in message
        assert "30" in message
        assert "timed out" in message.lower() or "timeout" in message.lower()

    def test_future_cancellation_attempted_on_timeout(self):
        """Test that future.cancel() is called on timeout."""
        # We can't easily verify cancellation worked (it may not stop the thread),
        # but we can verify the timeout mechanism works correctly
        call_count = {"value": 0}

        def slow_incrementer():
            time.sleep(0.5)
            call_count["value"] += 1
            return call_count["value"]

        with pytest.raises(SparkTimeoutError):
            run_with_timeout(
                slow_incrementer, timeout_seconds=0.05, operation_name="cancel_test"
            )

        # Even after timeout, the function might still complete in background
        # This test verifies the timeout mechanism, not actual cancellation


# =============================================================================
# Edge Cases
# =============================================================================


class TestEdgeCases:
    """Tests for edge cases and boundary conditions."""

    def test_zero_timeout_immediate_failure(self):
        """Test that zero timeout causes immediate failure for non-instant operations."""
        # Very small timeout - may or may not succeed depending on system load
        # Using a very small positive value instead of 0
        # (an unused `instant_func` helper was removed here)
        with pytest.raises(SparkTimeoutError):
            run_with_timeout(
                lambda: time.sleep(1),
                timeout_seconds=0.001,
                operation_name="zero_timeout",
            )

    def test_negative_timeout_behavior(self):
        """Test behavior with negative timeout (should raise ValueError)."""
        # Negative timeout should be rejected with ValueError
        with pytest.raises(ValueError, match="timeout_seconds must be positive"):
            run_with_timeout(
                lambda: time.sleep(1),
                timeout_seconds=-1,
                operation_name="negative_timeout",
            )

    def test_very_long_operation_name(self):
        """Test that very long operation names work correctly."""
        long_name = "a" * 1000

        @with_timeout(timeout_seconds=5.0, operation_name=long_name)
        def quick_func():
            return True

        result = quick_func()
        assert result is True

    def test_special_characters_in_operation_name(self):
        """Test operation names with special characters."""

        @with_timeout(timeout_seconds=5.0, operation_name="test:op/with<special>chars")
        def quick_func():
            return True

        result = quick_func()
        assert result is True

    def test_none_return_value(self):
        """Test that None return values are handled correctly."""

        @with_timeout(timeout_seconds=5.0)
        def returns_none():
            return None

        result = returns_none()
        assert result is None

    def test_large_return_value(self):
        """Test handling of large return values."""

        @with_timeout(timeout_seconds=5.0)
        def returns_large_list():
            return list(range(100000))

        result = returns_large_list()
        assert len(result) == 100000

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/BERDataLakehouse/datalake-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server