base.py•9.94 kB
"""Base OCC client with shared HTTP behaviour."""
from __future__ import annotations
import asyncio
from datetime import UTC, datetime, timedelta
from typing import Any, Literal, Protocol
import httpx
from tenacity import AsyncRetrying, retry_if_exception, stop_after_attempt, wait_exponential
from app.models import JsonObject
from app.settings import settings
from app.utils import OccApiError
class SupportsOccHttp(Protocol):
    """Structural type for anything able to issue OCC HTTP calls.

    An object satisfies this protocol by exposing the coroutine methods below
    with matching signatures; no inheritance is required.
    """

    async def get_json(self, path: str, *, params: dict[str, Any]) -> JsonObject:
        """Issue a GET for *path* and return the decoded payload."""

    async def head_metadata(self, path: str, *, params: dict[str, Any]) -> dict[str, Any]:
        """Issue a HEAD for *path* and return response metadata."""

    async def post_json(
        self, path: str, *, params: dict[str, Any], json_data: dict[str, Any] | list[Any] | None
    ) -> JsonObject:
        """Issue a POST for *path* carrying *json_data* and return the decoded payload."""

    async def put_json(
        self, path: str, *, params: dict[str, Any], json_data: dict[str, Any] | list[Any] | None
    ) -> JsonObject:
        """Issue a PUT for *path* carrying *json_data* and return the decoded payload."""

    async def patch_json(
        self, path: str, *, params: dict[str, Any], json_data: dict[str, Any] | list[Any] | None
    ) -> JsonObject:
        """Issue a PATCH for *path* carrying *json_data* and return the decoded payload."""

    async def delete(
        self, path: str, *, params: dict[str, Any], expect_body: bool = False
    ) -> JsonObject | dict[str, Any]:
        """Issue a DELETE for *path*; decode the body only when *expect_body* is set."""
# HTTP statuses treated as transient: 429 Too Many Requests plus server/gateway errors.
RETRY_STATUS_CODES: set[int] = {429, 500, 502, 503, 504}


def _is_retryable_exception(exc: BaseException) -> bool:
    """Decide whether *exc* deserves another attempt.

    An ``httpx.HTTPStatusError`` qualifies only when its response status is one
    of ``RETRY_STATUS_CODES``. Any other exception qualifies exactly when it is
    an ``httpx.TransportError``.
    """
    if not isinstance(exc, httpx.HTTPStatusError):
        return isinstance(exc, httpx.TransportError)
    status = exc.response.status_code
    return status in RETRY_STATUS_CODES
class OccClientBase:
    """Shared HTTP plumbing for OCC API calls.

    Owns one ``httpx.AsyncClient`` (recreated on demand if it has been closed),
    retries transient failures with exponential backoff, and, when
    ``settings.use_oauth`` is enabled, attaches a cached client-credentials
    bearer token to every request.
    """

    def __init__(self) -> None:
        self._client = self._create_async_client()
        # Held while checking/refreshing the token. Concurrent callers that find
        # the cache stale therefore trigger a single token request, not one each.
        self._token_lock = asyncio.Lock()
        # (access_token, expires_at) for the most recently issued OAuth token.
        self._cached_token: tuple[str, datetime] | None = None

    async def aclose(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def get_json(self, path: str, *, params: dict[str, Any]) -> JsonObject:
        """Perform a GET request and return the decoded JSON payload."""
        return await self._request("GET", path, params=params)

    async def head_metadata(self, path: str, *, params: dict[str, Any]) -> dict[str, Any]:
        """Perform a HEAD request and return ``{"status_code": ..., "headers": ...}``."""
        return await self._request("HEAD", path, params=params, parse_response=False)

    async def post_json(
        self,
        path: str,
        *,
        params: dict[str, Any],
        json_data: dict[str, Any] | list[Any] | None,
    ) -> JsonObject:
        """Perform a POST request with JSON body and return the decoded payload."""
        return await self._request(
            "POST",
            path,
            params=params,
            json_data=json_data,
        )

    async def put_json(
        self,
        path: str,
        *,
        params: dict[str, Any],
        json_data: dict[str, Any] | list[Any] | None,
    ) -> JsonObject:
        """Perform a PUT request with JSON body and return the decoded payload."""
        return await self._request(
            "PUT",
            path,
            params=params,
            json_data=json_data,
        )

    async def patch_json(
        self,
        path: str,
        *,
        params: dict[str, Any],
        json_data: dict[str, Any] | list[Any] | None,
    ) -> JsonObject:
        """Perform a PATCH request with JSON body and return the decoded payload."""
        return await self._request(
            "PATCH",
            path,
            params=params,
            json_data=json_data,
        )

    async def delete(
        self,
        path: str,
        *,
        params: dict[str, Any],
        expect_body: bool = False,
    ) -> JsonObject | dict[str, Any]:
        """Perform a DELETE request.

        Returns the decoded body when *expect_body* is true. Otherwise returns a
        dict with the response ``status_code`` and ``headers``.
        """
        return await self._request(
            "DELETE",
            path,
            params=params,
            parse_response=expect_body,
        )

    async def _get(self, path: str, *, params: dict[str, Any]) -> JsonObject:
        """Private shorthand for :meth:`get_json` (presumably kept for older callers)."""
        return await self.get_json(path, params=params)

    async def _head(self, path: str, *, params: dict[str, Any]) -> dict[str, Any]:
        """Private shorthand for :meth:`head_metadata` (presumably kept for older callers)."""
        return await self.head_metadata(path, params=params)

    async def _request(
        self,
        method: Literal["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE"],
        path: str,
        *,
        params: dict[str, Any],
        parse_response: bool = True,
        json_data: dict[str, Any] | list[Any] | None = None,
    ) -> JsonObject | dict[str, Any]:
        """Send one logical request to OCC, retrying transient failures.

        Retries responses whose status is in ``RETRY_STATUS_CODES``, and httpx
        transport errors, up to ``settings.occ_retry_attempts`` attempts. Backoff
        is exponential and capped at ``settings.occ_retry_backoff_max_seconds``.
        Headers, including any bearer token, are rebuilt on every attempt.

        Args:
            method: HTTP verb.
            path: Request path, resolved against the client's base URL.
            params: Query-string parameters.
            parse_response: Decode the body via ``_parse_response`` when true.
                Otherwise return only the status code and headers.
            json_data: Optional JSON request body.

        Returns:
            The decoded payload, or ``{"status_code": int, "headers": dict}``
            when *parse_response* is false.

        Raises:
            OccApiError: On a non-retryable status >= 400, when retries are
                exhausted, on a transport failure, if obtaining an OAuth token
                fails, or if the response body cannot be decoded.
        """
        self._ensure_client()
        # NOTE(review): POST/PATCH are retried on the same conditions as GET;
        # confirm the OCC endpoints tolerate duplicate submissions.
        retrying = AsyncRetrying(
            stop=stop_after_attempt(settings.occ_retry_attempts),
            wait=wait_exponential(
                multiplier=1, min=1, max=settings.occ_retry_backoff_max_seconds
            ),
            retry=retry_if_exception(_is_retryable_exception),
            reraise=True,
        )
        try:
            async for attempt in retrying:
                with attempt:
                    headers = await self._build_headers()
                    response = await self._client.request(
                        method,
                        path,
                        params=params,
                        headers=headers,
                        json=json_data,
                    )
                    if response.status_code >= 400:
                        if response.status_code in RETRY_STATUS_CODES:
                            # Raise httpx.HTTPStatusError so the retry predicate
                            # sees it. After the final attempt, tenacity re-raises
                            # it and the handler below converts it.
                            response.raise_for_status()
                        if response.status_code == 401 and settings.use_oauth:
                            # The bearer token was rejected (e.g. revoked before
                            # its advertised expiry). Drop it so the next request
                            # fetches a fresh token instead of reusing this one
                            # until it expires.
                            self._cached_token = None
                        raise OccApiError(
                            f"OCC request failed with status {response.status_code}",
                            status_code=response.status_code,
                            payload=self._extract_error_payload(response),
                        )
                    if parse_response:
                        return self._parse_response(response)
                    return {
                        "status_code": response.status_code,
                        "headers": dict(response.headers),
                    }
        except httpx.HTTPStatusError as exc:
            response = exc.response
            raise OccApiError(
                "OCC request failed after retries",
                status_code=response.status_code,
                payload=self._extract_error_payload(response),
            ) from exc
        except httpx.TransportError as exc:
            raise OccApiError(
                f"Transport error while communicating with OCC: {exc}"
            ) from exc
        # Defensive fallback: only reached if the retry loop yields no attempt.
        raise OccApiError("OCC request did not complete successfully")

    async def _build_headers(self) -> dict[str, str]:
        """Return per-request headers: ``Accept`` plus a bearer token when OAuth is on."""
        headers: dict[str, str] = {"Accept": settings.occ_accept}
        if settings.use_oauth:
            token = await self._ensure_token()
            headers["Authorization"] = f"Bearer {token}"
        return headers

    def _parse_response(self, response: httpx.Response) -> JsonObject:
        """Decode a successful response body.

        When ``settings.accepts_json`` is set, the body is parsed as JSON. An
        empty body (e.g. ``204 No Content``) yields ``{}``. Otherwise the raw
        text and content type are returned.

        Raises:
            OccApiError: If a non-empty body is not valid JSON.
        """
        if settings.accepts_json:
            if not response.content.strip():
                # Bodiless successes are legitimate (204, empty 201/202), so they
                # must not be reported as malformed JSON.
                return {}
            try:
                return response.json()
            except ValueError as exc:
                raise OccApiError(
                    "Failed to decode JSON response from OCC",
                    status_code=response.status_code,
                    payload=response.text,
                ) from exc
        return {
            "raw": response.text,
            "content_type": response.headers.get("content-type"),
        }

    def _extract_error_payload(self, response: httpx.Response) -> Any:
        """Best-effort extraction of an error body for ``OccApiError.payload``.

        JSON media types (``application/json`` and ``+json`` suffixes such as
        ``application/problem+json``) are decoded. Anything else, or a body that
        fails to decode, is returned as raw text.
        """
        content_type = response.headers.get("content-type", "").lower()
        if "application/json" in content_type or "+json" in content_type:
            try:
                return response.json()
            except Exception:  # noqa: BLE001 - fallback to raw text
                return response.text
        return response.text

    async def _ensure_token(self) -> str:
        """Return a usable OAuth bearer token, fetching a new one if needed.

        Uses the client-credentials grant against ``settings.oauth_token_url``.
        A missing, zero or unparseable ``expires_in`` leaves the token already
        stale, so it serves this request only and is refetched next time.

        Raises:
            OccApiError: If OAuth is not configured, the token endpoint returns
                a status >= 400, the response is not a JSON object, or it lacks
                ``access_token``.
        """
        if not settings.use_oauth:
            raise OccApiError("OAuth credentials are not configured.")
        async with self._token_lock:
            if self._cached_token and not self._token_expired(self._cached_token):
                return self._cached_token[0]
            data = {"grant_type": "client_credentials"}
            if settings.oauth_scope:
                data["scope"] = settings.oauth_scope
            auth = (settings.oauth_client_id or "", settings.oauth_client_secret or "")
            response = await self._client.post(
                str(settings.oauth_token_url),
                data=data,
                auth=auth,
            )
            if response.status_code >= 400:
                raise OccApiError(
                    "Failed to obtain OAuth token",
                    status_code=response.status_code,
                    payload=self._extract_error_payload(response),
                )
            try:
                payload = response.json()
            except ValueError as exc:
                raise OccApiError(
                    "Failed to decode OAuth token response",
                    status_code=response.status_code,
                    payload=response.text,
                ) from exc
            if not isinstance(payload, dict) or not payload.get("access_token"):
                raise OccApiError("OAuth token response missing access_token", payload=payload)
            access_token = payload["access_token"]
            try:
                expires_in = int(payload.get("expires_in") or 0)
            except (TypeError, ValueError):
                # Unparseable lifetime: treat as unknown, so the token is not cached.
                expires_in = 0
            expires_at = datetime.now(UTC) + timedelta(seconds=expires_in)
            self._cached_token = (access_token, expires_at)
            return access_token

    def _token_expired(self, token_info: tuple[str, datetime]) -> bool:
        """Return True if the token has expired or will within the next 60 seconds."""
        _, expires_at = token_info
        # The 60 s margin avoids sending a token that lapses while in flight.
        return datetime.now(UTC) + timedelta(seconds=60) >= expires_at

    def _create_async_client(self) -> httpx.AsyncClient:
        """Build an ``AsyncClient`` for ``settings.occ_base_url`` with the configured timeout."""
        timeout = httpx.Timeout(settings.occ_timeout_seconds)
        return httpx.AsyncClient(
            base_url=str(settings.occ_base_url),
            timeout=timeout,
            headers={"Accept": settings.occ_accept},
        )

    def _ensure_client(self) -> None:
        """Recreate the HTTP client if it has been closed (e.g. after :meth:`aclose`)."""
        if getattr(self._client, "is_closed", False):
            self._client = self._create_async_client()