"""Test the migration from ContextVar to FastMCP Context."""
from unittest.mock import Mock, patch
import pytest
from fastmcp import Context, FastMCP
# Import the migrated context functions
from auth.context import (
_get_pending_service_requests,
_set_injected_service,
_set_service_error,
clear_all_context,
clear_session,
clear_session_context,
clear_user_email_context,
delete_session_data,
get_injected_service,
get_session_context,
get_session_count,
get_session_data,
get_user_email_context,
list_sessions,
request_google_service,
set_session_context,
set_user_email_context,
store_session_data,
)
@pytest.fixture
def mock_context():
"""Create a mock FastMCP Context for testing."""
context = Mock(spec=Context)
context._state = {}
def get_state(key):
return context._state.get(key)
def set_state(key, value):
context._state[key] = value
context.get_state = Mock(side_effect=get_state)
context.set_state = Mock(side_effect=set_state)
return context
def test_session_context_functions(mock_context):
"""Test session context functions with FastMCP Context."""
with patch("auth.context.get_context", return_value=mock_context):
# Test setting session context
set_session_context("test-session-123")
assert mock_context._state["session_id"] == "test-session-123"
# Test getting session context
session_id = get_session_context()
assert session_id == "test-session-123"
# Test clearing session context
clear_session_context()
assert mock_context._state["session_id"] is None
def test_user_email_context_functions(mock_context):
"""Test user email context functions with FastMCP Context."""
with patch("auth.context.get_context", return_value=mock_context):
# Test setting user email context
set_user_email_context("test@example.com")
assert mock_context._state["user_email"] == "test@example.com"
# Test getting user email context
email = get_user_email_context()
assert email == "test@example.com"
# Test clearing user email context
clear_user_email_context()
assert mock_context._state["user_email"] is None
def test_service_injection_functions(mock_context):
"""Test service injection functions with FastMCP Context."""
with patch("auth.context.get_context", return_value=mock_context):
# Test requesting a service
service_key = request_google_service(
service_type="drive",
scopes=["drive.readonly"],
version="v3",
cache_enabled=True,
)
assert service_key == "drive"
assert "service_requests" in mock_context._state
assert "drive" in mock_context._state["service_requests"]
# Test getting pending service requests
pending = _get_pending_service_requests()
assert "drive" in pending
assert pending["drive"]["requested"] is True
assert pending["drive"]["fulfilled"] is False
# Test injecting a service
mock_service = Mock()
_set_injected_service("drive", mock_service)
assert mock_context._state["service_requests"]["drive"]["fulfilled"] is True
assert (
mock_context._state["service_requests"]["drive"]["service"] == mock_service
)
# Test getting injected service
service = get_injected_service("drive")
assert service == mock_service
# Test setting service error
_set_service_error("drive", "Test error")
assert mock_context._state["service_requests"]["drive"]["error"] == "Test error"
assert mock_context._state["service_requests"]["drive"]["fulfilled"] is False
def test_clear_all_context(mock_context):
"""Test clearing all context data."""
with patch("auth.context.get_context", return_value=mock_context):
# Set some context data
set_session_context("test-session")
set_user_email_context("test@example.com")
mock_context._state["service_requests"] = {"test": "data"}
# Clear all context
clear_all_context()
# Verify all context is cleared
assert mock_context._state["session_id"] is None
assert mock_context._state["user_email"] is None
assert mock_context._state["service_requests"] == {}
def test_session_data_storage():
"""Test session data storage functions (these don't use FastMCP Context)."""
session_id = "test-session-456"
# Store session data
store_session_data(session_id, "test_key", "test_value")
# Retrieve session data
value = get_session_data(session_id, "test_key")
assert value == "test_value"
# Test default value
default_value = get_session_data(session_id, "nonexistent", "default")
assert default_value == "default"
# Delete session data
deleted = delete_session_data(session_id, "test_key")
assert deleted is True
# Try to delete again
deleted = delete_session_data(session_id, "test_key")
assert deleted is False
# Clear session
store_session_data(session_id, "another_key", "another_value")
cleared = clear_session(session_id)
assert cleared is True
# Verify session is cleared
value = get_session_data(session_id, "another_key")
assert value is None
def test_session_management():
"""Test session count and listing functions."""
# Clear any existing sessions
for session_id in list_sessions():
clear_session(session_id)
initial_count = get_session_count()
assert initial_count == 0
# Add some sessions
store_session_data("session1", "key", "value1")
store_session_data("session2", "key", "value2")
store_session_data("session3", "key", "value3")
# Check session count
count = get_session_count()
assert count == 3
# List sessions
sessions = list_sessions()
assert len(sessions) == 3
assert "session1" in sessions
assert "session2" in sessions
assert "session3" in sessions
# Clean up
for session_id in sessions:
clear_session(session_id)
def test_context_outside_request():
"""Test that functions handle being called outside a FastMCP request context."""
with patch(
"auth.context.get_context",
side_effect=RuntimeError("not in a FastMCP request context"),
):
# These should handle the error gracefully
result = get_session_context()
assert result is None
result = get_user_email_context()
assert result is None
# These should not raise but log warnings
set_session_context("test")
set_user_email_context("test@example.com")
clear_session_context()
clear_user_email_context()
# Service requests should raise RuntimeError with a clear message
with pytest.raises(RuntimeError) as exc_info:
request_google_service("drive")
assert "requires an active FastMCP request context" in str(exc_info.value)
# Getting pending requests should return empty dict
pending = _get_pending_service_requests()
assert pending == {}
@pytest.mark.asyncio
async def test_integration_with_fastmcp():
"""Test that the context works with actual FastMCP tools."""
mcp = FastMCP("test-server")
@mcp.tool
async def test_tool(ctx: Context) -> str:
"""Test tool that uses context."""
# The tool should be able to access context state
session_id = ctx.get_state("session_id")
user_email = ctx.get_state("user_email")
# Set some state
ctx.set_state("test_value", "Hello from tool")
return f"Session: {session_id}, Email: {user_email}"
# Simulate middleware setting context before tool execution
@mcp.tool
async def setup_and_call_tool(ctx: Context) -> str:
"""Tool that sets up context and calls another tool."""
# Simulate what middleware would do
ctx.set_state("session_id", "test-session-789")
ctx.set_state("user_email", "user@example.com")
# Now the context functions should work within this request
with patch("auth.context.get_context", return_value=ctx):
# These should now work
session = get_session_context()
email = get_user_email_context()
return f"Got session: {session}, email: {email}"
# Note: Actually running these tools would require a full FastMCP server setup
# This test just verifies the structure is correct
assert len(mcp.tools) == 2
assert any(tool.name == "test_tool" for tool in mcp.tools)
assert any(tool.name == "setup_and_call_tool" for tool in mcp.tools)
if __name__ == "__main__":
# Run basic tests
print("Testing FastMCP Context migration...")
# Create mock context
mock_ctx = Mock(spec=Context)
mock_ctx._state = {}
mock_ctx.get_state = lambda key: mock_ctx._state.get(key)
mock_ctx.set_state = lambda key, value: mock_ctx._state.__setitem__(key, value)
# Test session context
test_session_context_functions(mock_ctx)
print("✅ Session context functions work")
# Test user email context
test_user_email_context_functions(mock_ctx)
print("✅ User email context functions work")
# Test service injection
test_service_injection_functions(mock_ctx)
print("✅ Service injection functions work")
# Test session data storage
test_session_data_storage()
print("✅ Session data storage works")
# Test session management
test_session_management()
print("✅ Session management works")
print(
"\n🎉 All tests passed! Migration from ContextVar to FastMCP Context is successful."
)