from __future__ import annotations
import base64
from pathlib import Path
from typing import Any, Dict
import httpx
import pytest
from openapi2mcpserver.client import OpenAPIMCPClient
from openapi2mcpserver.config import AuthConfig, ServerConfig
from openapi2mcpserver.openapi_loader import OperationSpec
def _make_operation() -> OperationSpec:
    """Return a minimal ``GET /test`` operation used to drive ``invoke``.

    The operation declares no path, query, or header parameters and no
    request body; its input schema is a closed object with no properties.
    """
    # Closed, empty object schema: no properties and nothing extra allowed.
    no_arguments_schema = {
        "type": "object",
        "properties": {},
        "additionalProperties": False,
    }
    return OperationSpec(
        # Identity of the operation.
        operation_id="TestOperation",
        method="get",
        path="/test",
        summary=None,
        description=None,
        input_schema=no_arguments_schema,
        # No parameters of any kind (each list is a distinct object).
        path_params=[],
        query_params=[],
        header_params=[],
        # No request body.
        request_body_schema=None,
        request_body_required=False,
        request_body_media_type=None,
    )
# Run every test in this module through the anyio pytest plugin. The marker
# takes no arguments; the event-loop backend is selected by the
# ``anyio_backend`` fixture defined in this module.
pytestmark = pytest.mark.anyio
@pytest.fixture
def anyio_backend() -> str:
    """Select asyncio as the backend for the anyio-marked tests here."""
    backend_name = "asyncio"
    return backend_name
async def _invoke_with_transport(
    username: str,
    password: str,
    *,
    headers: Dict[str, str] | None = None,
) -> httpx.Headers:
    """Invoke a dummy operation through a basic-auth client and capture headers.

    The client's HTTP layer is replaced by one built on ``httpx.MockTransport``.
    The mock records each outgoing request's headers and answers with
    ``200`` and the JSON body ``{"ok": true}``.

    Args:
        username: Basic-auth username placed in the ``AuthConfig``.
        password: Basic-auth password placed in the ``AuthConfig``.
        headers: Optional extra request headers forwarded to ``invoke``.

    Returns:
        The headers of the last request the mock transport received.

    Raises:
        AssertionError: If ``invoke`` never reached the transport.
    """
    seen: list[httpx.Headers] = []

    def record(request: httpx.Request) -> httpx.Response:
        # Remember what the client actually put on the wire.
        seen.append(request.headers)
        return httpx.Response(200, json={"ok": True})

    config = ServerConfig(
        base_url="https://example.com",
        verify_ssl=False,
        openapi_path=Path(__file__),
        auth=AuthConfig(
            mode="basic",
            basic_username=username,
            basic_password=password,
        ),
    )
    client = OpenAPIMCPClient(config)
    # Swap in an httpx client backed by the mock transport. ``_client`` is
    # private, and the ignore silences the type checker's attr-defined error.
    # NOTE(review): if OpenAPIMCPClient.__init__ already created a client in
    # ``_client``, it is replaced here without being closed. Confirm whether
    # that needs an explicit close.
    client._client = httpx.AsyncClient(  # type: ignore[attr-defined]
        transport=httpx.MockTransport(record),
        base_url=config.base_url,
        verify=False,
    )
    try:
        await client.invoke(_make_operation(), headers=headers)
    finally:
        await client.close()

    assert seen
    return seen[-1]
async def test_client_sets_basic_auth_header() -> None:
    """Configured basic-auth credentials are sent as ``Authorization: Basic``."""
    user, secret = "alice", "secret"
    sent_headers = await _invoke_with_transport(user, secret)
    credentials = f"{user}:{secret}".encode("ascii")
    expected_value = "Basic " + base64.b64encode(credentials).decode("ascii")
    assert sent_headers.get("authorization") == expected_value
async def test_client_does_not_override_explicit_authorization_header() -> None:
    """A caller-supplied ``Authorization`` header wins over configured basic auth."""
    explicit_value = "Bearer custom-token"
    sent_headers = await _invoke_with_transport(
        "alice",
        "secret",
        headers={"Authorization": explicit_value},
    )
    # httpx header lookups are case-insensitive, and ``.get`` comma-joins
    # repeated headers. Exact equality therefore also rules out a Basic
    # header being sent alongside the explicit one.
    assert sent_headers.get("authorization") == explicit_value