# test_guards.py (12.2 kB)
"""
Unit tests for execution guard enforcement (timeout/output cap).
Tests that forced-termination reports (due to timeout or output cap) are properly returned.
"""
import time
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from mcp_debug_tool.schemas import (
BreakpointRequest,
ExecutionError,
StartSessionRequest,
)
from mcp_debug_tool.sessions import SessionManager
from mcp_debug_tool.utils import DEFAULT_TIMEOUT_SECONDS, MAX_OUTPUT_BYTES
@pytest.fixture
def workspace_root(tmp_path):
    """Provide pytest's per-test temporary directory as the workspace root."""
    root = tmp_path
    return root
@pytest.fixture
def session_manager(workspace_root):
    """Build a SessionManager rooted at the temporary workspace directory."""
    manager = SessionManager(workspace_root)
    return manager
class TestTimeoutGuard:
    """Tests for timeout guard enforcement."""

    @staticmethod
    def _start_session(session_manager, workspace_root, script_path):
        """Start a debug session for ``script_path`` and return its session id.

        The entry is given as a POSIX path relative to ``workspace_root`` and
        the interpreter is the one running the test suite.
        """
        create_request = StartSessionRequest(
            pythonPath=sys.executable,
            entry=script_path.relative_to(workspace_root).as_posix(),
        )
        return session_manager.create_session(create_request).sessionId

    def test_timeout_returns_execution_error(self, session_manager, workspace_root):
        """Test that a timeout surfaces as a TimeoutError from run_to_breakpoint."""
        # The breakpoint (line 4) sits after the long sleep, so a real run
        # would trip the timeout guard before ever reaching it.
        script_path = workspace_root / "timeout_script.py"
        script_path.write_text("""import time
print("Starting long operation...")
time.sleep(100) # This will timeout
print("Done") # Line 4 - breakpoint here (never reached)
""")
        bp_request = BreakpointRequest(
            file=script_path.relative_to(workspace_root).as_posix(),
            line=4,
        )
        session_id = self._start_session(session_manager, workspace_root, script_path)
        try:
            # NOTE(review): this patches the method under test itself, so it
            # only checks that a TimeoutError reaches the caller with the
            # expected message; it does not exercise the real timeout guard.
            with patch(
                "mcp_debug_tool.sessions.SessionManager.run_to_breakpoint",
                side_effect=TimeoutError(
                    f"Execution exceeded {DEFAULT_TIMEOUT_SECONDS}s limit"
                ),
            ):
                # pytest.raises fails the test when nothing is raised; the
                # former try/except let that case pass silently.
                with pytest.raises(TimeoutError, match="Execution exceeded"):
                    session_manager.run_to_breakpoint(session_id, bp_request)
        finally:
            # Always release the session, even if an assertion above failed.
            session_manager.end_session(session_id)

    def test_timeout_error_structure(self, session_manager, workspace_root):
        """Test that a timeout ExecutionError carries type, message and no traceback."""
        script_path = workspace_root / "test_script.py"
        script_path.write_text("x = 1\ny = 2\n")
        session_id = self._start_session(session_manager, workspace_root, script_path)
        try:
            timeout_error = ExecutionError(
                type="TimeoutError",
                message=f"Execution exceeded {DEFAULT_TIMEOUT_SECONDS}s limit",
                traceback=None,
            )
            assert timeout_error.type == "TimeoutError"
            assert "exceeded" in timeout_error.message.lower()
            assert timeout_error.traceback is None
        finally:
            session_manager.end_session(session_id)

    def test_timeout_prevents_further_execution(
        self, session_manager, workspace_root
    ):
        """Test that a run well within the time limit pauses at its breakpoint.

        NOTE(review): despite its name, this test never triggers a timeout; it
        only covers the baseline path in which the session ends up ``paused``.
        """
        script_path = workspace_root / "test_script.py"
        script_path.write_text("""x = 1 # Line 1
y = 2 # Line 2
z = 3 # Line 3
""")
        bp_request = BreakpointRequest(
            file=script_path.relative_to(workspace_root).as_posix(),
            line=1,
        )
        session_id = self._start_session(session_manager, workspace_root, script_path)
        try:
            response = session_manager.run_to_breakpoint(session_id, bp_request)
            assert response.hit is True
            # The session should still be alive and stopped at the breakpoint.
            state = session_manager.get_state(session_id)
            assert state.status == "paused"
        finally:
            session_manager.end_session(session_id)
class TestOutputCapGuard:
    """Tests for output cap guard enforcement."""

    @staticmethod
    def _start_session(session_manager, workspace_root, script_path):
        """Start a debug session for ``script_path`` and return its session id.

        The entry is given as a POSIX path relative to ``workspace_root`` and
        the interpreter is the one running the test suite.
        """
        create_request = StartSessionRequest(
            pythonPath=sys.executable,
            entry=script_path.relative_to(workspace_root).as_posix(),
        )
        return session_manager.create_session(create_request).sessionId

    def test_output_cap_returns_execution_error(self, session_manager, workspace_root):
        """Test that exceeding the output cap surfaces as an output-limit error."""
        # The script prints ~10 MB before reaching the breakpoint on line 3.
        script_path = workspace_root / "excessive_output_script.py"
        script_path.write_text(f"""for i in range(10000):
    print("x" * 1000) # Will exceed {MAX_OUTPUT_BYTES} bytes
print("Done") # Line 3 - breakpoint here
""")
        bp_request = BreakpointRequest(
            file=script_path.relative_to(workspace_root).as_posix(),
            line=3,
        )
        session_id = self._start_session(session_manager, workspace_root, script_path)
        try:
            # NOTE(review): this patches the method under test itself, so it
            # only checks that the error reaches the caller with the expected
            # message; it does not exercise the real output-cap guard.
            with patch(
                "mcp_debug_tool.sessions.SessionManager.run_to_breakpoint",
                side_effect=RuntimeError(
                    f"Output exceeded {MAX_OUTPUT_BYTES} bytes limit"
                ),
            ):
                # pytest.raises fails the test when nothing is raised; the
                # former try/except let that case pass silently.
                with pytest.raises(RuntimeError, match="Output exceeded"):
                    session_manager.run_to_breakpoint(session_id, bp_request)
        finally:
            # Always release the session, even if an assertion above failed.
            session_manager.end_session(session_id)

    def test_output_cap_error_structure(self, session_manager, workspace_root):
        """Test that an output-limit ExecutionError carries type, message and no traceback."""
        output_error = ExecutionError(
            type="OutputLimitError",
            message=f"Output exceeded {MAX_OUTPUT_BYTES} bytes limit",
            traceback=None,
        )
        assert output_error.type == "OutputLimitError"
        assert "bytes" in output_error.message.lower()
        assert output_error.traceback is None

    def test_output_tracking_during_execution(
        self, session_manager, workspace_root
    ):
        """Test that a small amount of output does not trip the output cap."""
        script_path = workspace_root / "output_script.py"
        script_path.write_text("""print("Hello")
print("World") # Line 2 - breakpoint here
print("Done")
""")
        bp_request = BreakpointRequest(
            file=script_path.relative_to(workspace_root).as_posix(),
            line=2,
        )
        session_id = self._start_session(session_manager, workspace_root, script_path)
        try:
            response = session_manager.run_to_breakpoint(session_id, bp_request)
            # Any error reported must not be the output-limit guard.
            assert (
                response.error is None
                or "OutputLimitError" not in response.error.type
            )
        finally:
            session_manager.end_session(session_id)
class TestGuardResponses:
    """Tests for guard-triggered response formats."""

    @staticmethod
    def _start_session(session_manager, workspace_root):
        """Write a one-line script, start a session for it, and return the id."""
        script_path = workspace_root / "test_script.py"
        script_path.write_text("x = 1\n")
        create_request = StartSessionRequest(
            pythonPath=sys.executable,
            entry=script_path.relative_to(workspace_root).as_posix(),
        )
        return session_manager.create_session(create_request).sessionId

    def test_guard_error_includes_type_and_message(
        self, session_manager, workspace_root
    ):
        """Test that guard errors include a non-empty type and message."""
        session_id = self._start_session(session_manager, workspace_root)
        try:
            errors = [
                ExecutionError(
                    type="TimeoutError",
                    message="Execution timeout",
                    traceback=None,
                ),
                ExecutionError(
                    type="OutputLimitError",
                    message="Output limit exceeded",
                    traceback=None,
                ),
                ExecutionError(
                    type="MemoryError",
                    message="Memory limit exceeded",
                    traceback=None,
                ),
            ]
            for error in errors:
                # Truthiness covers both "not None" and "non-empty".
                assert error.type
                assert error.message
        finally:
            session_manager.end_session(session_id)

    def test_guard_error_truncated_traceback(
        self, session_manager, workspace_root
    ):
        """Test that guard errors can carry a truncated traceback."""
        session_id = self._start_session(session_manager, workspace_root)
        try:
            error = ExecutionError(
                type="TimeoutError",
                message="Timeout",
                traceback="Traceback (most recent call last):\n File ...\n... [truncated]",
            )
            assert error.traceback is not None
            # Check the truncation marker itself rather than accepting any
            # traceback that merely mentions "file".
            assert "[truncated]" in error.traceback
        finally:
            session_manager.end_session(session_id)

    def test_multiple_guards_report_first_violation(
        self, session_manager, workspace_root
    ):
        """Test that multiple violations report the first one encountered.

        NOTE(review): this only checks the two error objects' types; it does
        not drive an execution that trips both guards, so the ordering claim
        in the name is not actually verified.
        """
        session_id = self._start_session(session_manager, workspace_root)
        try:
            # Timeout is expected to be checked before the output cap.
            timeout_error = ExecutionError(
                type="TimeoutError",
                message="Timeout",
                traceback=None,
            )
            output_error = ExecutionError(
                type="OutputLimitError",
                message="Output limit",
                traceback=None,
            )
            assert timeout_error.type == "TimeoutError"
            assert output_error.type == "OutputLimitError"
        finally:
            session_manager.end_session(session_id)