"""Unit tests for dual-mode authentication middleware."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from mcp.shared.exceptions import McpError
from sso_mcp_server.auth.cloud.claims import TokenClaims
from sso_mcp_server.auth.exceptions import (
CloudTokenExpiredError,
InvalidTokenError,
)
from sso_mcp_server.auth.middleware import (
_extract_bearer_token,
get_auth_manager,
get_current_claims,
get_token_validator,
require_auth,
set_auth_manager,
set_authorization_header,
set_token_validator,
)
from sso_mcp_server.config import AuthMode
class TestExtractBearerToken:
    """Tests for the ``_extract_bearer_token`` helper."""

    def test_extracts_valid_bearer_token(self):
        """A well-formed ``Bearer <token>`` header yields the token part."""
        token = _extract_bearer_token("Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9")
        assert token == "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"

    def test_handles_lowercase_bearer(self):
        """A lowercase ``bearer`` scheme name is also accepted."""
        token = _extract_bearer_token("bearer eyJtoken")
        assert token == "eyJtoken"

    def test_returns_none_for_none_header(self):
        """A missing header (``None``) yields ``None``."""
        assert _extract_bearer_token(None) is None

    def test_returns_none_for_empty_header(self):
        """An empty header string yields ``None``."""
        assert _extract_bearer_token("") is None

    def test_returns_none_for_non_bearer_scheme(self):
        """A header using another scheme (here Basic) yields ``None``."""
        assert _extract_bearer_token("Basic dXNlcjpwYXNz") is None

    def test_returns_none_for_malformed_header(self):
        """A scheme glued to its credentials (no separating space) is rejected."""
        assert _extract_bearer_token("BearerWithNoSpace") is None

    def test_returns_none_for_empty_token(self):
        """A Bearer scheme with nothing after the space yields ``None``."""
        assert _extract_bearer_token("Bearer ") is None
class TestSetGetAuthManager:
    """Round-trip tests for ``set_auth_manager`` / ``get_auth_manager``."""

    def test_set_and_get_auth_manager(self):
        """The getter returns the very object most recently passed to the setter."""
        manager = MagicMock()
        set_auth_manager(manager)
        retrieved = get_auth_manager()
        assert retrieved is manager
class TestSetGetTokenValidator:
    """Round-trip tests for ``set_token_validator`` / ``get_token_validator``."""

    def test_set_and_get_token_validator(self):
        """The getter returns the very object most recently passed to the setter."""
        validator = MagicMock()
        set_token_validator(validator)
        retrieved = get_token_validator()
        assert retrieved is validator
class TestSetAuthorizationHeader:
    """Tests for the authorization-header context variable."""

    @pytest.mark.asyncio
    async def test_set_authorization_header(self):
        """A header stored via ``set_authorization_header`` reaches ``require_auth``.

        The context variable itself is private, so the round-trip is observed
        through its consumer: in CLOUD mode the decorator extracts the Bearer
        token from the stored header and passes it to the token validator.
        """
        settings = MagicMock()
        settings.auth_mode = AuthMode.CLOUD
        validator = MagicMock()
        validator.validate = AsyncMock(return_value=MagicMock())

        with patch(
            "sso_mcp_server.auth.middleware.get_settings",
            return_value=settings,
        ):
            set_token_validator(validator)
            set_authorization_header("Bearer test-token")

            @require_auth
            async def probe():
                return "ok"

            assert await probe() == "ok"

        validator.validate.assert_called_once_with("test-token")
class TestGetCurrentClaims:
    """Tests for ``get_current_claims``."""

    def test_returns_none_by_default(self):
        """Once the claims context variable is cleared, no claims are reported."""
        from sso_mcp_server.auth import middleware

        middleware._current_claims.set(None)
        claims = get_current_claims()
        assert claims is None
class TestRequireAuthLocalMode:
    """Tests for the ``require_auth`` decorator in LOCAL mode.

    In LOCAL mode the decorator consults the module-level auth manager
    (``ensure_authenticated()``) before running the wrapped tool.
    """

    @pytest.fixture
    def mock_settings_local(self):
        """Settings stub whose ``auth_mode`` is LOCAL."""
        settings = MagicMock()
        settings.auth_mode = AuthMode.LOCAL
        return settings

    @pytest.fixture
    def mock_auth_manager(self):
        """Auth manager stub that reports the user as authenticated."""
        manager = MagicMock()
        manager.ensure_authenticated = MagicMock(return_value=True)
        return manager

    @pytest.mark.asyncio
    async def test_calls_function_when_authenticated(
        self, mock_settings_local, mock_auth_manager
    ):
        """The wrapped tool runs and its result is returned unchanged."""
        with patch(
            "sso_mcp_server.auth.middleware.get_settings",
            return_value=mock_settings_local,
        ):
            set_auth_manager(mock_auth_manager)

            @require_auth
            async def test_tool():
                return {"result": "success"}

            result = await test_tool()
            assert result == {"result": "success"}
            mock_auth_manager.ensure_authenticated.assert_called_once()

    @pytest.mark.asyncio
    async def test_raises_error_when_not_authenticated(self, mock_settings_local):
        """A falsy ``ensure_authenticated()`` surfaces as McpError code -32002."""
        mock_manager = MagicMock()
        mock_manager.ensure_authenticated = MagicMock(return_value=False)
        with patch(
            "sso_mcp_server.auth.middleware.get_settings",
            return_value=mock_settings_local,
        ):
            set_auth_manager(mock_manager)

            @require_auth
            async def test_tool():
                return {"result": "success"}

            with pytest.raises(McpError) as exc_info:
                await test_tool()
            assert exc_info.value.error.code == -32002
            # The failure must come from the manager, not from some other path.
            mock_manager.ensure_authenticated.assert_called()

    @pytest.mark.asyncio
    async def test_raises_error_when_auth_manager_not_set(
        self, mock_settings_local, monkeypatch
    ):
        """A missing auth manager surfaces as McpError code -32001."""
        from sso_mcp_server.auth import middleware

        # monkeypatch restores the previous manager at teardown, so clearing
        # the module global here does not leak into later tests.
        monkeypatch.setattr(middleware, "_auth_manager", None)
        with patch(
            "sso_mcp_server.auth.middleware.get_settings",
            return_value=mock_settings_local,
        ):

            @require_auth
            async def test_tool():
                return {"result": "success"}

            with pytest.raises(McpError) as exc_info:
                await test_tool()
            assert exc_info.value.error.code == -32001
class TestRequireAuthCloudMode:
    """Tests for the ``require_auth`` decorator in CLOUD mode.

    In CLOUD mode the decorator extracts a Bearer token from the stored
    authorization header and awaits the module-level token validator's
    ``validate(token)``.
    """

    @pytest.fixture
    def mock_settings_cloud(self):
        """Settings stub whose ``auth_mode`` is CLOUD."""
        settings = MagicMock()
        settings.auth_mode = AuthMode.CLOUD
        return settings

    @pytest.fixture
    def mock_token_validator(self):
        """Validator stub with an awaitable ``validate``; tests configure it."""
        validator = MagicMock()
        validator.validate = AsyncMock()
        return validator

    @pytest.fixture
    def mock_claims(self):
        """Claims for a currently valid token (expiry one hour from now)."""
        from datetime import datetime, timedelta

        now = datetime.now()
        return TokenClaims(
            sub="user123",
            iss="https://issuer.example.com",
            aud="https://resource.example.com",
            exp=now + timedelta(hours=1),
            iat=now,
            scopes=["read", "write"],
        )

    @pytest.mark.asyncio
    async def test_validates_bearer_token(
        self, mock_settings_cloud, mock_token_validator, mock_claims
    ):
        """A valid Bearer token is validated and the tool result is returned."""
        mock_token_validator.validate.return_value = mock_claims
        with patch(
            "sso_mcp_server.auth.middleware.get_settings",
            return_value=mock_settings_cloud,
        ):
            set_token_validator(mock_token_validator)
            set_authorization_header("Bearer valid-token")

            @require_auth
            async def test_tool():
                return {"result": "success"}

            result = await test_tool()
            assert result == {"result": "success"}
            mock_token_validator.validate.assert_called_once_with("valid-token")

    @pytest.mark.asyncio
    async def test_raises_error_when_no_auth_header(self, mock_settings_cloud):
        """A missing Authorization header surfaces as McpError code -32002."""
        mock_validator = MagicMock()
        with patch(
            "sso_mcp_server.auth.middleware.get_settings",
            return_value=mock_settings_cloud,
        ):
            set_token_validator(mock_validator)
            set_authorization_header(None)

            @require_auth
            async def test_tool():
                return {"result": "success"}

            with pytest.raises(McpError) as exc_info:
                await test_tool()
            assert exc_info.value.error.code == -32002
            # With no token there is nothing to validate.
            mock_validator.validate.assert_not_called()

    @pytest.mark.asyncio
    async def test_raises_error_when_token_invalid(
        self, mock_settings_cloud, mock_token_validator
    ):
        """An InvalidTokenError from the validator surfaces as McpError -32002."""
        mock_token_validator.validate.side_effect = InvalidTokenError(
            "Token is malformed"
        )
        with patch(
            "sso_mcp_server.auth.middleware.get_settings",
            return_value=mock_settings_cloud,
        ):
            set_token_validator(mock_token_validator)
            set_authorization_header("Bearer invalid-token")

            @require_auth
            async def test_tool():
                return {"result": "success"}

            with pytest.raises(McpError) as exc_info:
                await test_tool()
            assert exc_info.value.error.code == -32002
            mock_token_validator.validate.assert_called_once_with("invalid-token")

    @pytest.mark.asyncio
    async def test_raises_error_when_token_expired(
        self, mock_settings_cloud, mock_token_validator
    ):
        """A CloudTokenExpiredError from the validator surfaces as McpError -32002."""
        mock_token_validator.validate.side_effect = CloudTokenExpiredError()
        with patch(
            "sso_mcp_server.auth.middleware.get_settings",
            return_value=mock_settings_cloud,
        ):
            set_token_validator(mock_token_validator)
            set_authorization_header("Bearer expired-token")

            @require_auth
            async def test_tool():
                return {"result": "success"}

            with pytest.raises(McpError) as exc_info:
                await test_tool()
            assert exc_info.value.error.code == -32002
            mock_token_validator.validate.assert_called_once_with("expired-token")

    @pytest.mark.asyncio
    async def test_raises_error_when_validator_not_set(
        self, mock_settings_cloud, monkeypatch
    ):
        """A missing token validator surfaces as McpError code -32001."""
        from sso_mcp_server.auth import middleware

        # monkeypatch restores the previous validator at teardown, so clearing
        # the module global here does not leak into later tests.
        monkeypatch.setattr(middleware, "_token_validator", None)
        with patch(
            "sso_mcp_server.auth.middleware.get_settings",
            return_value=mock_settings_cloud,
        ):
            set_authorization_header("Bearer some-token")

            @require_auth
            async def test_tool():
                return {"result": "success"}

            with pytest.raises(McpError) as exc_info:
                await test_tool()
            assert exc_info.value.error.code == -32001
class TestRequireAuthAutoMode:
    """Tests for the ``require_auth`` decorator in AUTO mode.

    AUTO mode picks a flow per request: the CLOUD flow (token validator) when
    a Bearer token is present, otherwise the LOCAL flow (auth manager). Each
    test installs BOTH collaborators and checks that only the expected one
    is used.
    """

    @pytest.fixture
    def mock_settings_auto(self):
        """Settings stub whose ``auth_mode`` is AUTO."""
        settings = MagicMock()
        settings.auth_mode = AuthMode.AUTO
        return settings

    @pytest.fixture
    def mock_auth_manager(self):
        """Auth manager stub that reports the user as authenticated."""
        manager = MagicMock()
        manager.ensure_authenticated = MagicMock(return_value=True)
        return manager

    @pytest.fixture
    def mock_token_validator(self):
        """Validator stub with an awaitable ``validate``; tests configure it."""
        validator = MagicMock()
        validator.validate = AsyncMock()
        return validator

    @pytest.fixture
    def mock_claims(self):
        """Claims for a currently valid token (expiry one hour from now)."""
        from datetime import datetime, timedelta

        now = datetime.now()
        return TokenClaims(
            sub="user123",
            iss="https://issuer.example.com",
            aud="https://resource.example.com",
            exp=now + timedelta(hours=1),
            iat=now,
            scopes=["read"],
        )

    @pytest.mark.asyncio
    async def test_uses_cloud_flow_when_bearer_present(
        self, mock_settings_auto, mock_token_validator, mock_auth_manager, mock_claims
    ):
        """A Bearer header routes to the validator and skips the auth manager."""
        mock_token_validator.validate.return_value = mock_claims
        with patch(
            "sso_mcp_server.auth.middleware.get_settings",
            return_value=mock_settings_auto,
        ):
            set_token_validator(mock_token_validator)
            set_auth_manager(mock_auth_manager)
            set_authorization_header("Bearer cloud-token")

            @require_auth
            async def test_tool():
                return {"mode": "cloud"}

            result = await test_tool()
            assert result == {"mode": "cloud"}
            mock_token_validator.validate.assert_called_once_with("cloud-token")
            mock_auth_manager.ensure_authenticated.assert_not_called()

    @pytest.mark.asyncio
    async def test_uses_local_flow_when_no_bearer(
        self, mock_settings_auto, mock_auth_manager, mock_token_validator
    ):
        """No Authorization header routes to the auth manager, not the validator."""
        with patch(
            "sso_mcp_server.auth.middleware.get_settings",
            return_value=mock_settings_auto,
        ):
            set_auth_manager(mock_auth_manager)
            set_token_validator(mock_token_validator)
            set_authorization_header(None)

            @require_auth
            async def test_tool():
                return {"mode": "local"}

            result = await test_tool()
            assert result == {"mode": "local"}
            mock_auth_manager.ensure_authenticated.assert_called_once()
            mock_token_validator.validate.assert_not_called()

    @pytest.mark.asyncio
    async def test_uses_local_flow_for_non_bearer_auth(
        self, mock_settings_auto, mock_auth_manager, mock_token_validator
    ):
        """A non-Bearer header (Basic) routes to the auth manager, not the validator."""
        with patch(
            "sso_mcp_server.auth.middleware.get_settings",
            return_value=mock_settings_auto,
        ):
            set_auth_manager(mock_auth_manager)
            set_token_validator(mock_token_validator)
            set_authorization_header("Basic dXNlcjpwYXNz")

            @require_auth
            async def test_tool():
                return {"mode": "local"}

            result = await test_tool()
            assert result == {"mode": "local"}
            mock_auth_manager.ensure_authenticated.assert_called_once()
            mock_token_validator.validate.assert_not_called()