"""
Integration tests for context-aware MCP safety.
These tests verify that the system correctly enforces context safety
and fails closed under various edge cases and error conditions.
"""
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from src.core.tool_context import PageType, EntityType, PageContext
from src.core.context_validation import ValidatedPageContext, ValidationError
from src.web_chat.context import set_current_context, clear_current_context
from src.core.tool_registry import ContextAwareToolRegistry
from src.core.domain_policies import ALL_POLICIES
@pytest.fixture
async def test_client():
    """Yield an httpx AsyncClient that talks to the web-chat ASGI app in-process.

    Requests are dispatched through ``ASGITransport`` against ``base_url``
    "http://test" instead of a real network socket. The client (and its
    transport) is closed when the ``async with`` block exits at fixture teardown.
    """
    # Imported inside the fixture, so the app module is only loaded when a
    # test actually requests this fixture.
    from httpx import ASGITransport
    from src.web_chat.main import app

    # ``AsyncClient(app=...)`` was deprecated in httpx 0.27 and removed in 0.28;
    # an explicit ASGITransport works on both old and new httpx versions.
    # NOTE(review): this async-generator fixture uses plain ``@pytest.fixture``,
    # which pytest-asyncio only drives in "auto" mode — confirm the project's
    # asyncio_mode, or switch to ``pytest_asyncio.fixture``.
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        yield client
@pytest.fixture
def registry():
    """Provide a fresh ContextAwareToolRegistry bound to a throwaway FastAPI app."""
    return ContextAwareToolRegistry(FastAPI())
@pytest.mark.asyncio
async def test_tool_without_policy_fails():
    """Registering a tool with ``policy=None`` must be rejected with ValueError.

    Tools without a policy must never be exposed, so the registry refuses
    them at registration time.
    """
    tool_registry = ContextAwareToolRegistry(FastAPI())

    with pytest.raises(ValueError) as excinfo:
        # Build the decorator first, then apply it; either step may raise.
        decorator = tool_registry.register_tool(
            name="bad_tool",
            policy=None,
            description="This should fail",
        )

        @decorator
        async def bad_tool():
            pass

    message = str(excinfo.value)
    assert "must have a ToolContextPolicy" in message
@pytest.mark.asyncio
async def test_late_registration_fails():
    """Registering a tool after ``complete_startup()`` must raise RuntimeError."""
    late_registry = ContextAwareToolRegistry(FastAPI())
    # Close the registration window before attempting to add a tool.
    late_registry.complete_startup()

    with pytest.raises(RuntimeError) as excinfo:

        @late_registry.register_tool(
            name="late_tool",
            policy=ALL_POLICIES["list_vendors"],
            description="This should fail",
        )
        async def late_tool():
            pass

    message = str(excinfo.value)
    assert "registered after startup" in message
@pytest.mark.asyncio
async def test_context_validation():
    """ValidatedPageContext must reject inconsistent page/entity combinations.

    Each case below is expected to raise ValidationError on construction.
    """
    invalid_contexts = [
        # Table page that nevertheless carries an entity.
        {
            "page_type": PageType.INVOICE_TABLE,
            "entity_type": EntityType.INVOICE,
            "entity_id": "123",
        },
        # Detail page with no entity at all.
        {
            "page_type": PageType.INVOICE_DETAIL,
            "entity_type": None,
            "entity_id": None,
        },
        # Invoice detail page paired with a vendor entity.
        {
            "page_type": PageType.INVOICE_DETAIL,
            "entity_type": EntityType.VENDOR,
            "entity_id": "123",
        },
    ]

    for fields in invalid_contexts:
        with pytest.raises(ValidationError):
            ValidatedPageContext(**fields)
@pytest.mark.asyncio
async def test_entity_id_binding():
    """The active PageContext's entity ID wins over a caller-supplied argument.

    The tool is registered with ``entity_param="invoice_id"``. When it is invoked
    on an invoice detail page, it must receive the context's ID rather than the
    one passed in explicitly.
    """
    reg = ContextAwareToolRegistry(FastAPI())

    @reg.register_tool(
        name="test_tool",
        policy=ALL_POLICIES["get_invoice_details"],
        entity_param="invoice_id",
    )
    async def test_tool(invoice_id: str):
        return {"invoice_id": invoice_id}

    set_current_context(
        PageContext(
            page_type=PageType.INVOICE_DETAIL,
            entity_type=EntityType.INVOICE,
            entity_id="context_123",
        )
    )
    try:
        # The explicit argument deliberately differs from the context's ID.
        payload = await test_tool(invoice_id="param_456")
        assert payload["invoice_id"] == "context_123"
    finally:
        # Always reset the shared context so later tests start clean.
        clear_current_context()
@pytest.mark.asyncio
async def test_tool_context_mismatch():
    """A detail-page tool invoked while a table page is active must raise RuntimeError."""
    reg = ContextAwareToolRegistry(FastAPI())

    @reg.register_tool(
        name="detail_tool",
        policy=ALL_POLICIES["get_invoice_details"],
    )
    async def detail_tool():
        return {"status": "ok"}

    table_page = PageContext(
        page_type=PageType.INVOICE_TABLE,
        entity_type=None,
        entity_id=None,
    )
    set_current_context(table_page)
    try:
        with pytest.raises(RuntimeError) as excinfo:
            await detail_tool()
        # Checked outside the raises-block so the assertion is actually reached.
        assert "cannot be used in current context" in str(excinfo.value)
    finally:
        clear_current_context()
@pytest.mark.asyncio
async def test_missing_context():
    """Invoking a registered tool with no PageContext set must raise RuntimeError."""
    reg = ContextAwareToolRegistry(FastAPI())

    @reg.register_tool(name="test_tool", policy=ALL_POLICIES["list_vendors"])
    async def test_tool():
        return {"status": "ok"}

    # Ensure no context is active (e.g. one left behind elsewhere).
    clear_current_context()

    with pytest.raises(RuntimeError) as excinfo:
        await test_tool()

    error_text = str(excinfo.value)
    assert "without PageContext" in error_text
@pytest.mark.asyncio
async def test_stale_entity():
    """The context is re-checked on every call, so a stale entity cannot be reused.

    After switching from an invoice detail page to the invoice table page, a
    second call must fail even though it passes the previously valid ID.
    """
    reg = ContextAwareToolRegistry(FastAPI())

    @reg.register_tool(
        name="test_tool",
        policy=ALL_POLICIES["get_invoice_details"],
        entity_param="invoice_id",
    )
    async def test_tool(invoice_id: str):
        return {"invoice_id": invoice_id}

    detail_page = PageContext(
        page_type=PageType.INVOICE_DETAIL,
        entity_type=EntityType.INVOICE,
        entity_id="invoice_123",
    )
    set_current_context(detail_page)
    try:
        # On the detail page the context's ID is used; the argument is ignored.
        first = await test_tool(invoice_id="ignored")
        assert first["invoice_id"] == "invoice_123"

        # Navigate to the table view, which carries no entity.
        set_current_context(
            PageContext(
                page_type=PageType.INVOICE_TABLE,
                entity_type=None,
                entity_id=None,
            )
        )
        with pytest.raises(RuntimeError):
            await test_tool(invoice_id="invoice_123")
    finally:
        clear_current_context()