"""Tests for client_secret_basic (HTTP Basic Auth) on the OAuth token endpoint."""
import base64
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock
import pytest
def _make_request(headers=None, body=None, content_type="application/x-www-form-urlencoded"):
"""Create a mock Starlette Request for oauth_token."""
request = AsyncMock()
request.headers = MagicMock()
_headers = {"content-type": content_type}
if headers:
_headers.update(headers)
request.headers.get = lambda key, default="": _headers.get(key.lower(), default)
form_data = body or {}
form = AsyncMock(return_value=form_data)
request.form = form
return request
def _encode_basic(client_id: str, client_secret: str) -> str:
"""Encode client_id:client_secret as Basic Auth header value."""
raw = f"{client_id}:{client_secret}"
encoded = base64.b64encode(raw.encode("utf-8")).decode("utf-8")
return f"Basic {encoded}"
@pytest.fixture
def temp_storage():
"""Create temporary storage for OAuth tests."""
with tempfile.TemporaryDirectory() as tmpdir:
os.environ["OAUTH_STORAGE_PATH"] = tmpdir
os.environ["OAUTH_JWT_SECRET_KEY"] = "test_secret_key_for_basic_auth_tests"
from core.oauth import client_registry, server, storage, token_manager
client_registry._client_registry = None
token_manager._token_manager = None
storage._storage = None
server._oauth_server = None
yield tmpdir
os.environ.pop("OAUTH_STORAGE_PATH", None)
os.environ.pop("OAUTH_JWT_SECRET_KEY", None)
client_registry._client_registry = None
token_manager._token_manager = None
storage._storage = None
server._oauth_server = None
async def _call_oauth_token(request):
"""Import and call the oauth_token endpoint from server.py."""
from server import oauth_token
return await oauth_token(request)
@pytest.mark.unit
class TestClientSecretBasicAuth:
"""Tests for client_secret_basic authentication on the token endpoint."""
async def test_basic_auth_header_parsed(self, temp_storage):
"""Basic Auth header is correctly parsed (base64 encoded client_id:client_secret)."""
request = _make_request(
headers={"authorization": _encode_basic("my_client_id", "my_client_secret")},
body={"grant_type": "client_credentials"},
)
response = await _call_oauth_token(request)
# The request will likely fail on actual client validation,
# but we're testing that it gets past the credential parsing stage.
# If it returned invalid_request with "Missing client_id", Basic Auth parsing failed.
import json
data = json.loads(response.body)
assert data.get("error") != "invalid_request" or "Missing client_id" not in data.get(
"error_description", ""
)
async def test_body_params_take_priority_over_basic_auth(self, temp_storage):
"""Body params take priority over Basic Auth (setdefault behavior)."""
request = _make_request(
headers={"authorization": _encode_basic("basic_id", "basic_secret")},
body={
"grant_type": "client_credentials",
"client_id": "body_id",
"client_secret": "body_secret",
},
)
response = await _call_oauth_token(request)
import json
data = json.loads(response.body)
# The body params should be used, not the Basic Auth ones.
# If the error mentions "body_id", body params were used.
# If it mentions "basic_id", Basic Auth overrode body params (wrong).
# Since the client won't exist, we expect invalid_client error.
# The key check: it should NOT have replaced body_id with basic_id.
if data.get("error") == "invalid_client":
# The error came from actual client validation, meaning
# credentials were parsed. The body params were used.
pass
else:
# Should not get "Missing client_id" error
assert "Missing client_id" not in data.get("error_description", "")
async def test_malformed_basic_auth_returns_invalid_client(self, temp_storage):
"""Malformed Basic Auth header returns invalid_client with 401 status."""
request = _make_request(
headers={"authorization": "Basic !!!not-valid-base64!!!"},
body={"grant_type": "client_credentials"},
)
response = await _call_oauth_token(request)
import json
data = json.loads(response.body)
assert data["error"] == "invalid_client"
assert "Invalid Basic authentication header" in data["error_description"]
assert response.status_code == 401
async def test_missing_credentials_returns_invalid_request(self, temp_storage):
"""Missing credentials (no body, no Basic header) returns invalid_request."""
request = _make_request(
body={"grant_type": "client_credentials"},
)
response = await _call_oauth_token(request)
import json
data = json.loads(response.body)
assert data["error"] == "invalid_request"
assert "Missing client_id or client_secret" in data["error_description"]
async def test_client_secret_with_colon(self, temp_storage):
"""client_secret containing ':' is handled correctly (split on first ':' only)."""
secret_with_colon = "my:secret:with:colons"
request = _make_request(
headers={"authorization": _encode_basic("my_client_id", secret_with_colon)},
body={"grant_type": "client_credentials"},
)
response = await _call_oauth_token(request)
import json
data = json.loads(response.body)
# Should not get invalid_request about missing credentials
assert data.get("error") != "invalid_request" or "Missing client_id" not in data.get(
"error_description", ""
)
# Should not get invalid_client from Basic Auth parsing
assert data.get("error_description") != "Invalid Basic authentication header"