"""Unit tests for Auth Middleware.
Tests authentication middleware decorators and utilities.
"""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from mcp.shared.exceptions import McpError
from sso_mcp_server.auth import middleware
from sso_mcp_server.auth.middleware import (
check_auth,
ensure_authenticated,
get_auth_manager,
require_auth,
set_auth_manager,
)
class TestSetAuthManager:
    """Tests for set_auth_manager function.

    The module-level ``middleware._auth_manager`` is patched through
    ``monkeypatch`` so its pre-test value is restored at teardown even when an
    assertion fails; a trailing manual cleanup line would be skipped then.
    """

    def test_set_auth_manager_stores_manager(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Test that set_auth_manager stores the manager globally."""
        # Start from a known-empty state. monkeypatch records the original
        # value here and puts it back at teardown, which also undoes the
        # assignment performed by set_auth_manager below.
        monkeypatch.setattr(middleware, "_auth_manager", None)
        mock_manager = MagicMock()

        set_auth_manager(mock_manager)

        assert get_auth_manager() is mock_manager
class TestGetAuthManager:
    """Tests for get_auth_manager function.

    Global manager state is installed via ``monkeypatch`` so it is reverted at
    teardown regardless of test outcome, keeping tests order-independent.
    """

    def test_get_auth_manager_returns_none_when_not_set(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Test that get_auth_manager returns None when not set."""
        monkeypatch.setattr(middleware, "_auth_manager", None)

        assert get_auth_manager() is None

    def test_get_auth_manager_returns_manager_when_set(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Test that get_auth_manager returns the manager when set."""
        mock_manager = MagicMock()
        monkeypatch.setattr(middleware, "_auth_manager", mock_manager)

        assert get_auth_manager() is mock_manager
class TestCheckAuth:
    """Tests for check_auth function.

    Each test sets ``middleware._auth_manager`` through ``monkeypatch`` so the
    previous value is restored at teardown even if an assertion fails.
    """

    def test_check_auth_returns_false_when_manager_not_set(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Test that check_auth returns False when manager is not set."""
        monkeypatch.setattr(middleware, "_auth_manager", None)

        assert check_auth() is False

    def test_check_auth_returns_false_when_not_authenticated(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Test that check_auth returns False when not authenticated."""
        mock_manager = MagicMock()
        mock_manager.is_authenticated.return_value = False
        monkeypatch.setattr(middleware, "_auth_manager", mock_manager)

        assert check_auth() is False

    def test_check_auth_returns_true_when_authenticated(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Test that check_auth returns True when authenticated."""
        mock_manager = MagicMock()
        mock_manager.is_authenticated.return_value = True
        monkeypatch.setattr(middleware, "_auth_manager", mock_manager)

        assert check_auth() is True
class TestEnsureAuthenticated:
    """Tests for ensure_authenticated function.

    Global manager state is patched via ``monkeypatch`` so it is reverted at
    teardown whether or not the assertions pass.
    """

    def test_ensure_authenticated_returns_false_when_manager_not_set(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Test that ensure_authenticated returns False when manager is not set."""
        monkeypatch.setattr(middleware, "_auth_manager", None)

        assert ensure_authenticated() is False

    def test_ensure_authenticated_calls_manager(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Test that ensure_authenticated delegates to manager."""
        mock_manager = MagicMock()
        mock_manager.ensure_authenticated.return_value = True
        monkeypatch.setattr(middleware, "_auth_manager", mock_manager)

        result = ensure_authenticated()

        assert result is True
        mock_manager.ensure_authenticated.assert_called_once()
class TestRequireAuthDecorator:
    """Tests for require_auth decorator.

    ``middleware._auth_manager`` is installed through ``monkeypatch`` in every
    test, so the pre-test value is restored at teardown even when an awaited
    call or assertion fails part-way through.
    """

    @pytest.mark.asyncio
    async def test_require_auth_raises_when_manager_not_configured(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Test that require_auth raises McpError when manager is not configured."""
        monkeypatch.setattr(middleware, "_auth_manager", None)

        @require_auth
        async def my_tool() -> str:
            return "result"

        with pytest.raises(McpError) as exc_info:
            await my_tool()
        assert "not configured" in exc_info.value.error.message.lower()

    @pytest.mark.asyncio
    async def test_require_auth_raises_when_not_authenticated(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Test that require_auth raises McpError when authentication fails."""
        mock_manager = MagicMock()
        mock_manager.ensure_authenticated.return_value = False
        monkeypatch.setattr(middleware, "_auth_manager", mock_manager)

        @require_auth
        async def my_tool() -> str:
            return "result"

        with pytest.raises(McpError) as exc_info:
            await my_tool()
        assert "required" in exc_info.value.error.message.lower()

    @pytest.mark.asyncio
    async def test_require_auth_allows_execution_when_authenticated(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Test that require_auth allows function execution when authenticated."""
        mock_manager = MagicMock()
        mock_manager.ensure_authenticated.return_value = True
        monkeypatch.setattr(middleware, "_auth_manager", mock_manager)

        @require_auth
        async def my_tool() -> str:
            return "success"

        result = await my_tool()
        assert result == "success"

    @pytest.mark.asyncio
    async def test_require_auth_passes_arguments(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Test that require_auth passes arguments to wrapped function."""
        mock_manager = MagicMock()
        mock_manager.ensure_authenticated.return_value = True
        monkeypatch.setattr(middleware, "_auth_manager", mock_manager)

        @require_auth
        async def my_tool(name: str, count: int = 1) -> str:
            return f"{name}:{count}"

        # Both a positional and a keyword argument must reach the wrapped tool.
        result = await my_tool("test", count=5)
        assert result == "test:5"