# _helpers.py
from __future__ import annotations
import asyncio
import os
import re
import ssl
import string
import sys
from abc import ABC, abstractmethod
from base64 import b64decode, urlsafe_b64encode
from collections.abc import Iterable, Iterator, Mapping
from contextlib import AbstractContextManager, contextmanager, nullcontext
from contextvars import ContextVar
from dataclasses import dataclass, replace
from datetime import datetime, timezone
from email.message import Message
from functools import cached_property
from io import BytesIO
from itertools import chain
from random import random
from secrets import randbits, token_hex
from subprocess import PIPE, STDOUT
from threading import Lock, Thread
from time import sleep, time
from types import MappingProxyType, TracebackType
from typing import (
Any,
Awaitable,
Callable,
Generator,
Generic,
Literal,
NamedTuple,
Optional,
Protocol,
Sequence,
Type,
TypeVar,
Union,
cast,
)
from urllib.parse import parse_qs, urljoin, urlparse, urlunparse
from urllib.request import urlopen
import bs4
import httpx
import jwt
import pytest
import smtpdfix
from fastapi import FastAPI
from httpx import Headers, HTTPStatusError
from jwt import DecodeError
from openinference.semconv.resource import ResourceAttributes
from opentelemetry.exporter.otlp.proto.grpc.exporter import _load_credentials
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.sdk.trace.id_generator import IdGenerator
from opentelemetry.trace import Span, Tracer, format_span_id
from opentelemetry.util.types import AttributeValue
from psutil import STATUS_ZOMBIE, Popen
from sqlalchemy import URL, create_engine, text
from sqlalchemy.exc import OperationalError
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse, Response
from strawberry.relay import GlobalID
from typing_extensions import Self, TypeAlias, assert_never, override
from phoenix.auth import (
DEFAULT_ADMIN_EMAIL,
DEFAULT_ADMIN_PASSWORD,
DEFAULT_ADMIN_USERNAME,
PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
PHOENIX_OAUTH2_NONCE_COOKIE_NAME,
PHOENIX_OAUTH2_STATE_COOKIE_NAME,
PHOENIX_REFRESH_TOKEN_COOKIE_NAME,
sanitize_email,
)
from phoenix.config import (
ENV_PHOENIX_SQL_DATABASE_SCHEMA,
ENV_PHOENIX_SQL_DATABASE_URL,
)
from phoenix.server.api.auth import IsAdmin
from phoenix.server.api.exceptions import Unauthorized
from phoenix.server.api.input_types.UserRoleInput import UserRoleInput
from phoenix.server.thread_server import ThreadServer
# Database backends the test suite can run against.
_DB_BACKEND: TypeAlias = Literal["sqlite", "postgresql"]
# Short aliases for the user roles used throughout the tests.
_ADMIN = UserRoleInput.ADMIN
_MEMBER = UserRoleInput.MEMBER
_VIEWER = UserRoleInput.VIEWER
# Semantic aliases for plain strings/dicts to make signatures self-describing.
_ProjectName: TypeAlias = str
_SpanName: TypeAlias = str
_Headers: TypeAlias = dict[str, Any]
_Name: TypeAlias = str
_Secret: TypeAlias = str
_Email: TypeAlias = str
_Password: TypeAlias = str
_Username: TypeAlias = str
@dataclass(frozen=True)
class _Profile:
    """Immutable credential triple used to create and log in test users."""

    email: _Email
    password: _Password
    username: _Username
class _GqlId(str):
    """A GraphQL relay Global ID (opaque string)."""

    ...


_AnyT = TypeVar("_AnyT")
class _CanLogOut(Generic[_AnyT], ABC):
    """Interface for auth artifacts that can end their session against *app*."""

    @abstractmethod
    def log_out(self, app: _AppInfo) -> _AnyT: ...
@dataclass(frozen=True)
class _User:
    """A Phoenix user as seen by the tests.

    The methods are thin wrappers around the module-level helpers
    (``_gql``, ``_create_user``, ...), passing this user as the
    authenticating party and *app* as the target server.
    """

    gid: _GqlId  # GraphQL node ID of the user
    role: UserRoleInput
    profile: _Profile

    def log_in(self, app: _AppInfo) -> _LoggedInUser:
        """Log in with this user's email/password; return a token-carrying user."""
        tokens = _log_in(app, self.password, email=self.email)
        return _LoggedInUser(self.gid, self.role, self.profile, tokens)

    # Convenience accessors into the profile (cached; the dataclass is frozen).
    @cached_property
    def password(self) -> _Password:
        return self.profile.password

    @cached_property
    def email(self) -> _Email:
        return self.profile.email

    @cached_property
    def username(self) -> Optional[_Username]:
        return self.profile.username

    def gql(
        self,
        app: _AppInfo,
        query: str,
        variables: Optional[Mapping[str, Any]] = None,
        operation_name: Optional[str] = None,
    ) -> tuple[dict[str, Any], Headers]:
        """Execute a GraphQL operation authenticated as this user."""
        return _gql(app, self, query=query, variables=variables, operation_name=operation_name)

    def create_user(
        self,
        app: _AppInfo,
        role: UserRoleInput = _MEMBER,
        /,
        *,
        profile: _Profile,
        send_welcome_email: bool = False,
        local: bool = True,
    ) -> _User:
        """Create another user, authenticated as this user."""
        return _create_user(
            app,
            self,
            role=role,
            profile=profile,
            send_welcome_email=send_welcome_email,
            local=local,
        )

    def delete_users(self, app: _AppInfo, *users: Union[_GqlId, _User]) -> None:
        """Delete the given users (by object or GID), authenticated as this user."""
        return _delete_users(app, self, users=users)

    def list_users(self, app: _AppInfo) -> list[_User]:
        """List all users, authenticated as this user."""
        return _list_users(app, self)

    def patch_user_gid(
        self,
        app: _AppInfo,
        gid: _GqlId,
        /,
        *,
        new_username: Optional[_Username] = None,
        new_password: Optional[_Password] = None,
        new_role: Optional[UserRoleInput] = None,
    ) -> None:
        """Patch another user identified by *gid*, authenticated as this user."""
        return _patch_user_gid(
            app,
            gid,
            self,
            new_username=new_username,
            new_password=new_password,
            new_role=new_role,
        )

    def patch_user(
        self,
        app: _AppInfo,
        user: _User,
        /,
        *,
        new_username: Optional[_Username] = None,
        new_password: Optional[_Password] = None,
        new_role: Optional[UserRoleInput] = None,
    ) -> _User:
        """Patch *user* and return an updated copy reflecting the changes."""
        return _patch_user(
            app,
            user,
            self,
            new_username=new_username,
            new_password=new_password,
            new_role=new_role,
        )

    def patch_viewer(
        self,
        app: _AppInfo,
        /,
        *,
        new_username: Optional[_Username] = None,
        new_password: Optional[_Password] = None,
    ) -> None:
        """Patch this user's own username/password via the patchViewer mutation."""
        return _patch_viewer(
            app,
            self,
            self.password,
            new_username=new_username,
            new_password=new_password,
        )

    def create_api_key(
        self,
        app: _AppInfo,
        kind: _ApiKeyKind = "User",
        /,
        *,
        name: Optional[_Name] = None,
        expires_at: Optional[datetime] = None,
    ) -> _ApiKey:
        """Create a System or User API key, authenticated as this user."""
        return _create_api_key(app, self, kind, name=name, expires_at=expires_at)

    def delete_api_key(self, app: _AppInfo, api_key: _ApiKey, /) -> None:
        """Delete *api_key*, authenticated as this user."""
        return _delete_api_key(app, api_key, self)

    def export_embeddings(self, app: _AppInfo, filename: str) -> None:
        """Hit the /exports endpoint for *filename*, authenticated as this user."""
        _export_embeddings(app, self, filename=filename)

    def initiate_password_reset(
        self,
        app: _AppInfo,
        smtpd: smtpdfix.AuthController,
        /,
        *,
        should_receive_email: bool = True,
    ) -> Optional[_PasswordResetToken]:
        """Request a password-reset email for this user and extract its token."""
        return _initiate_password_reset(
            app,
            self.email,
            smtpd,
            should_receive_email=should_receive_email,
        )
# GID of the system user (relay node_id=1).
_SYSTEM_USER_GID = _GqlId(GlobalID(type_name="User", node_id="1"))
# The default admin account (relay node_id=2) with its well-known credentials.
_DEFAULT_ADMIN = _User(
    _GqlId(GlobalID("User", "2")),
    _ADMIN,
    _Profile(
        email=DEFAULT_ADMIN_EMAIL,
        password=DEFAULT_ADMIN_PASSWORD,
        username=DEFAULT_ADMIN_USERNAME,
    ),
)
# Discriminator used to build the create/delete API-key mutation field names.
_ApiKeyKind = Literal["System", "User"]
class _ApiKey(str):
    """An API-key string that also remembers its GraphQL ID and kind."""

    def __new__(
        cls,
        string: str,
        gid: _GqlId,
        kind: _ApiKeyKind = "User",
    ) -> _ApiKey:
        # str is immutable, so the value itself must be set in __new__;
        # the extra metadata is stored by __init__.
        return super().__new__(cls, string)

    def __init__(
        self,
        string: str,
        gid: _GqlId,
        kind: _ApiKeyKind = "User",
    ) -> None:
        self._gid = gid
        self._kind: _ApiKeyKind = kind

    @cached_property
    def gid(self) -> _GqlId:
        # GraphQL node ID of the ApiKey (used by _delete_api_key).
        return self._gid

    @cached_property
    def kind(self) -> _ApiKeyKind:
        # "System" or "User"; selects the mutation field name.
        return self._kind
class _AdminSecret(str):
    """The PHOENIX_ADMIN_SECRET value; sent as a bearer token by _httpx_client."""

    ...


class _Token(str, ABC):
    """Base class for opaque string tokens issued by the server."""

    ...


class _PasswordResetToken(_Token):
    """Token extracted from a password-reset email."""

    def reset(self, app: _AppInfo, password: _Password, /) -> None:
        """Complete the password reset with this token and a new password."""
        return _reset_password(app, self, password=password)


class _AccessToken(_Token, _CanLogOut[None]):
    """Short-lived access token carried in the access-token cookie."""

    def log_out(self, app: _AppInfo) -> None:
        _log_out(app, self)


class _RefreshToken(_Token, _CanLogOut[None]):
    """Refresh token carried in the refresh-token cookie."""

    def log_out(self, app: _AppInfo) -> None:
        _log_out(app, self)
@dataclass(frozen=True)
class _LoggedInTokens(_CanLogOut[None]):
    """The access/refresh token pair produced by a successful login."""

    access_token: _AccessToken
    refresh_token: _RefreshToken

    @override
    def log_out(self, app: _AppInfo) -> None:
        # The access token alone is used to call the logout endpoint.
        self.access_token.log_out(app)

    def refresh(self, app: _AppInfo) -> _LoggedInTokens:
        """Exchange the current pair for a fresh one via POST auth/refresh."""
        resp = _httpx_client(app, self).post("auth/refresh")
        resp.raise_for_status()
        access_token = _AccessToken(resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME))
        refresh_token = _RefreshToken(resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME))
        return _LoggedInTokens(access_token, refresh_token)
@dataclass(frozen=True)
class _LoggedInUser(_User, _CanLogOut[_User]):
    """A _User that additionally carries its login tokens."""

    tokens: _LoggedInTokens

    @property
    def user(self) -> _User:
        # The same user stripped of its tokens.
        return _User(self.gid, self.role, self.profile)

    @override
    def log_out(self, app: _AppInfo) -> _User:
        """Log out and return the token-less user."""
        self.tokens.access_token.log_out(app)
        return self.user

    def refresh(self, app: _AppInfo) -> _LoggedInUser:
        """Return a copy of this user with freshly refreshed tokens."""
        return replace(self, tokens=self.tokens.refresh(app))

    def visit(self, app: _AppInfo, expected_status_code: int = 200) -> None:
        """GET /graphql with this user's tokens and assert the status code."""
        response = _httpx_client(app, self).get("/graphql")
        assert response.status_code == expected_status_code
# Either a role (to fabricate a user of that role) or a concrete user.
_RoleOrUser = Union[UserRoleInput, _User]
# Anything _httpx_client can turn into request credentials (cookies or headers).
_SecurityArtifact: TypeAlias = Union[
    _AdminSecret,
    _AccessToken,
    _RefreshToken,
    _LoggedInTokens,
    _ApiKey,
    _LoggedInUser,
    _User,
]
class _UserGenerator(Protocol):
    """Generator protocol: yields users when sent (app, role, profile)."""

    def send(self, _: tuple[_AppInfo, UserRoleInput, Optional[_Profile]]) -> _User: ...


class _UserFactory(Protocol):
    """Fixture callable that creates a new user with the given role/profile."""

    def __call__(
        self,
        app: _AppInfo,
        role: UserRoleInput = _MEMBER,
        /,
        *,
        profile: Optional[_Profile] = None,
    ) -> _User: ...


class _GetUser(Protocol):
    """Fixture callable that resolves a role or an existing user into a user."""

    def __call__(
        self,
        app: _AppInfo,
        role_or_user: Union[_User, UserRoleInput] = _MEMBER,
        /,
        *,
        profile: Optional[_Profile] = None,
    ) -> _User: ...


class _SpanExporterFactory(Protocol):
    """Fixture callable that builds a span exporter (HTTP or gRPC) for *app*."""

    def __call__(
        self,
        app: _AppInfo,
        /,
        *,
        headers: Optional[_Headers] = None,
    ) -> SpanExporter: ...


class _GetSpan(Protocol):
    """Fixture callable that creates a ReadableSpan with optional overrides."""

    def __call__(
        self,
        app: _AppInfo,
        /,
        project_name: Optional[str] = None,
        span_name: Optional[str] = None,
        attributes: Optional[dict[str, AttributeValue]] = None,
    ) -> ReadableSpan: ...


class _SendSpans(Protocol):
    """Fixture callable that exports *spans* to *app*, optionally with an API key."""

    def __call__(
        self,
        app: _AppInfo,
        api_key: Optional[_ApiKey] = None,
        /,
        spans: Iterable[ReadableSpan] = (),
        headers: Optional[dict[str, str]] = None,
    ) -> SpanExportResult: ...
@dataclass(frozen=True)
class _AppInfo:
    """A running (or to-be-run) Phoenix server, described by its environment.

    All URLs and file paths are derived lazily from the PHOENIX_* env vars.
    """

    # The server's environment variables; made read-only in __post_init__.
    env: Mapping[str, str]

    def __post_init__(self) -> None:
        # Freeze the mapping so tests cannot mutate a shared app's env.
        object.__setattr__(self, "env", MappingProxyType(dict(self.env)))

    @cached_property
    def base_url(self) -> str:
        """HTTP(S) base URL, honoring TLS flags, hostname, port and root path."""
        scheme = (
            "https"
            if self.env.get(
                "PHOENIX_TLS_ENABLED_FOR_HTTP",
                self.env.get("PHOENIX_TLS_ENABLED", "false"),
            ).lower()
            == "true"
            else "http"
        )
        hostname = self.env.get("PHOENIX_HOSTNAME", "127.0.0.1")
        port = self.env.get("PHOENIX_PORT", "6006")
        path = self.env.get("PHOENIX_ROOT_PATH", "")
        return str(urljoin(f"{scheme}://{hostname}:{port}", path))

    @cached_property
    def grpc_url(self) -> str:
        """gRPC endpoint URL, honoring the gRPC-specific TLS flag and port."""
        scheme = (
            "https"
            if self.env.get(
                "PHOENIX_TLS_ENABLED_FOR_GRPC",
                self.env.get("PHOENIX_TLS_ENABLED", "false"),
            ).lower()
            == "true"
            else "http"
        )
        hostname = self.env.get("PHOENIX_HOSTNAME", "127.0.0.1")
        port = self.env.get("PHOENIX_GRPC_PORT", "4317")
        return f"{scheme}://{hostname}:{port}"

    @cached_property
    def admin_secret(self) -> _AdminSecret:
        # Empty string when the server has no admin secret configured.
        return _AdminSecret(self.env.get("PHOENIX_ADMIN_SECRET", ""))

    @cached_property
    def certificate_file(self) -> Optional[str]:
        return self.env.get("PHOENIX_TLS_CERT_FILE")

    @cached_property
    def client_certificate_file(self) -> Optional[str]:
        return self.env.get("PHOENIX_TLS_CA_FILE")

    @cached_property
    def client_key_file(self) -> Optional[str]:
        # NOTE(review): reads the same env var as client_certificate_file
        # (PHOENIX_TLS_CA_FILE) — possibly a combined PEM; confirm intended.
        return self.env.get("PHOENIX_TLS_CA_FILE")
def _http_span_exporter(
    app: _AppInfo,
    /,
    *,
    headers: Optional[_Headers] = None,
) -> SpanExporter:
    """Build an OTLP/HTTP span exporter aimed at the app's v1/traces endpoint."""
    # Imported lazily so the module loads even without the HTTP exporter extra.
    from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter

    return OTLPSpanExporter(
        endpoint=urljoin(app.base_url, "v1/traces"),
        headers=headers,
        certificate_file=app.certificate_file,
        client_key_file=app.client_key_file,
        client_certificate_file=app.client_certificate_file,
    )
def _grpc_span_exporter(
    app: _AppInfo,
    *,
    headers: Optional[_Headers] = None,
) -> SpanExporter:
    """Build an OTLP/gRPC span exporter aimed at the app's gRPC endpoint."""
    # Imported lazily so the module loads even without the gRPC exporter extra.
    from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter

    credentials = _load_credentials(
        certificate_file=app.certificate_file,
        client_key_file=app.client_key_file,
        client_certificate_file=app.client_certificate_file,
    )
    return OTLPSpanExporter(
        endpoint=app.grpc_url,
        headers=headers,
        credentials=credentials,
    )
def _change_port(url: str, new_port: int) -> str:
# Parse the URL
parsed_url = urlparse(url)
# Replace the netloc part with the new port
netloc = parsed_url.netloc
if ":" in netloc:
# If there's already a port, replace it
netloc = netloc.split(":")[0] + f":{new_port}"
else:
# If there's no port, add it
netloc = netloc + f":{new_port}"
# Create a new parsed URL with the updated netloc
updated_parts = parsed_url._replace(netloc=netloc)
# Combine the parts back into a URL
return urlunparse(updated_parts)
class _RandomIdGenerator(IdGenerator):
    """
    Generate random trace and span IDs without being influenced by the current seed.
    """

    def generate_span_id(self) -> int:
        # secrets.randbits is independent of random.seed().
        return randbits(64)

    def generate_trace_id(self) -> int:
        return randbits(128)
def _get_tracer(
    *,
    project_name: _ProjectName,
    exporter: SpanExporter,
) -> Tracer:
    """Create a tracer whose spans carry *project_name* and flush to *exporter*.

    Uses SimpleSpanProcessor so every span is exported synchronously on end.
    """
    resource = Resource({ResourceAttributes.PROJECT_NAME: project_name})
    tracer_provider = TracerProvider(resource=resource, id_generator=_RandomIdGenerator())
    tracer_provider.add_span_processor(SimpleSpanProcessor(exporter))
    return tracer_provider.get_tracer(__name__)
def _start_span(
    *,
    exporter: SpanExporter,
    project_name: Optional[str] = None,
    span_name: Optional[str] = None,
    attributes: Optional[Mapping[str, AttributeValue]] = None,
    start_time: Optional[int] = None,
) -> Span:
    """Start (without ending) a span in a tracer wired to *exporter*.

    Random 32-char hex names are substituted for omitted project/span names.
    """
    project = project_name or token_hex(16)
    name = span_name or token_hex(16)
    tracer = _get_tracer(project_name=project, exporter=exporter)
    return tracer.start_span(
        name=name,
        attributes=attributes,
        start_time=start_time,
    )
class _DefaultAdminTokens(ABC):
"""
Because the tests can be run concurrently, and we need the default admin to create database
entities (e.g. to add new users), the default admin should never log out once logged in,
because logging out invalidates all existing access tokens, resulting in a race among the
tests. The approach here is to add a middleware to block any inadvertent use of the default
admin's access tokens for logging out. This class is intended to be used as a singleton
container to ensure that all tokens are always accounted for. Furthermore, the tokens are
disambiguated by the port of the server to which they belong.
"""
_set: set[tuple[int, str]] = set()
_lock: Lock = Lock()
@classmethod
def __new__(cls) -> Self:
raise NotImplementedError("This class is intended as a singleton to be used directly.")
@classmethod
def stash(cls, port: int, headers: Headers) -> None:
tokens = _extract_tokens(headers, "set-cookie").values()
for token in tokens:
with cls._lock:
cls._set.add((port, token))
@classmethod
def intersect(cls, port: int, headers: Headers) -> bool:
tokens = _extract_tokens(headers).values()
for token in tokens:
with cls._lock:
if (port, token) in cls._set:
return True
return False
class _LogResponse(httpx.Response):
    """An httpx.Response that appends its body to *info* and prints the
    accumulated request/response log once the body is fully consumed."""

    def __init__(self, info: BytesIO, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._info = info  # shared buffer started by _LogTransport

    def iter_bytes(self, *args: Any, **kwargs: Any) -> Iterator[bytes]:
        for chunk in super().iter_bytes(*args, **kwargs):
            self._info.write(chunk)
            yield chunk
        # Print only after the stream is exhausted so the log is complete.
        print(self._info.getvalue().decode())
def _get_token_from_cookie(cookie: str) -> str:
return cookie.split(";", 1)[0].split("=", 1)[1]
# Current test's name; read by _LogTransport to tag request logs.
_TEST_NAME: ContextVar[str] = ContextVar("test_name", default="")
# Per-context counter used to number httpx operations in the logs.
_HTTPX_OP_IDX: ContextVar[int] = ContextVar("httpx_operation_index", default=0)
class _LogTransport(httpx.BaseTransport):
    """Transport wrapper that records each request/response pair for debugging.

    The accumulated text is printed immediately if the request fails;
    otherwise it is handed to _LogResponse, which prints it once the
    response body has been consumed.
    """

    def __init__(self, transport: httpx.BaseTransport) -> None:
        self._transport = transport

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        info = BytesIO()
        info.write(f"{'-' * 50}\n".encode())
        if test_name := _TEST_NAME.get():
            # Number the operations within the current test for readability.
            op_idx = _HTTPX_OP_IDX.get()
            _HTTPX_OP_IDX.set(op_idx + 1)
            info.write(f"({op_idx})".encode())
        info.write(f"{test_name}\n".encode())
        info.write(f"{request.method} {request.url}\n".encode())
        # Log the JWT IDs of any tokens sent with the request.
        if token_ids := _decode_token_ids(request.headers):
            info.write(f"{' '.join(token_ids)}\n".encode())
        info.write(f"{request.headers}\n".encode())
        info.write(request.read())
        info.write(b"\n")
        try:
            response = self._transport.handle_request(request)
        except BaseException:
            # The response will never materialize; dump what we have now.
            print(info.getvalue().decode())
            raise
        info.write(f"{response.status_code} {response.headers}\n".encode())
        # Log the JWT IDs of any tokens the server set in the response.
        if returned_token_ids := _decode_token_ids(response.headers, "set-cookie"):
            info.write(f"{' '.join(returned_token_ids)}\n".encode())
        return _LogResponse(
            info=info,
            status_code=response.status_code,
            headers=response.headers,
            stream=response.stream,
            extensions=response.extensions,
        )
def _httpx_client(
    app: _AppInfo,
    auth: Optional[_SecurityArtifact] = None,
    headers: Optional[_Headers] = None,
    cookies: Optional[dict[str, Any]] = None,
    transport: Optional[httpx.BaseTransport] = None,
) -> httpx.Client:
    """Build an httpx.Client for *app*, attaching *auth* as cookies or headers.

    Tokens become cookies; API keys and the admin secret become bearer
    headers; a plain _User is logged in first and its tokens are used.
    Every request goes through _LogTransport for debug logging.
    """
    if isinstance(auth, _AccessToken):
        cookies = {**(cookies or {}), PHOENIX_ACCESS_TOKEN_COOKIE_NAME: auth}
    elif isinstance(auth, _RefreshToken):
        cookies = {**(cookies or {}), PHOENIX_REFRESH_TOKEN_COOKIE_NAME: auth}
    elif isinstance(auth, _LoggedInTokens):
        cookies = {
            **(cookies or {}),
            PHOENIX_ACCESS_TOKEN_COOKIE_NAME: auth.access_token,
            PHOENIX_REFRESH_TOKEN_COOKIE_NAME: auth.refresh_token,
        }
    elif isinstance(auth, _LoggedInUser):
        cookies = {
            **(cookies or {}),
            PHOENIX_ACCESS_TOKEN_COOKIE_NAME: auth.tokens.access_token,
            PHOENIX_REFRESH_TOKEN_COOKIE_NAME: auth.tokens.refresh_token,
        }
    elif isinstance(auth, _User):
        # A bare user has no tokens yet: log in and recurse with the tokens.
        logged_in_user = auth.log_in(app)
        return _httpx_client(app, logged_in_user.tokens, headers, cookies, transport)
    elif isinstance(auth, _ApiKey):
        headers = {**(headers or {}), "authorization": f"Bearer {auth}"}
    elif isinstance(auth, _AdminSecret):
        headers = {**(headers or {}), "authorization": f"Bearer {auth}"}
    elif auth is None:
        pass
    else:
        assert_never(auth)
    ssl_context = _get_ssl_context(app.env)
    # Having no timeout is useful when stepping through the debugger on the server side.
    return httpx.Client(
        timeout=None,
        headers=headers,
        cookies=cookies,
        base_url=app.base_url,
        transport=_LogTransport(transport or httpx.HTTPTransport(verify=ssl_context or False)),
    )
def _get_ssl_context(env: Mapping[str, str]) -> Optional[ssl.SSLContext]:
if (
env.get("PHOENIX_TLS_ENABLED_FOR_HTTP", env.get("PHOENIX_TLS_ENABLED", "false")).lower()
!= "true"
):
return None
context = ssl.create_default_context()
ca_file = env.get("PHOENIX_TLS_CERT_FILE")
context.load_verify_locations(cafile=ca_file)
if env.get("PHOENIX_TLS_VERIFY_CLIENT", "false").lower() != "true":
return context
assert (cert_file := env.get("PHOENIX_TLS_CA_FILE"))
context.load_cert_chain(certfile=cert_file)
return context
# Random per-session prefix for PostgreSQL schema names; _server refuses to
# start against postgres unless the configured schema carries this prefix.
_SCHEMA_PREFIX = f"_{token_hex(3)}"
@contextmanager
def _server(app: _AppInfo) -> Iterator[_AppInfo]:
    """Launch ``phoenix.server.main serve`` in a subprocess and yield once healthy.

    Polls GET /healthz for up to 60 seconds, streams the child's stdout to the
    test output, and — unlike the original — guarantees the child is reaped in
    a ``finally`` block even when startup times out or the context body raises
    (previously the process leaked in those cases).

    Raises:
        ValueError: when the database env vars are missing or malformed.
        TimeoutError: when the server does not come up in time.
    """
    if not (sql_database_url := app.env.get(ENV_PHOENIX_SQL_DATABASE_URL)):
        raise ValueError(f"{ENV_PHOENIX_SQL_DATABASE_URL} is required.")
    if sql_database_url.startswith("postgresql") and not str(
        app.env.get(ENV_PHOENIX_SQL_DATABASE_SCHEMA, "")
    ).startswith(_SCHEMA_PREFIX):
        raise ValueError(f"{ENV_PHOENIX_SQL_DATABASE_SCHEMA} should start with {_SCHEMA_PREFIX}")
    command = f"{sys.executable} -m phoenix.server.main serve"
    # On Windows the child needs the parent environment (e.g. SYSTEMROOT).
    env = {**os.environ, **app.env} if sys.platform == "win32" else dict(app.env)
    process = Popen(command.split(), stdout=PIPE, stderr=STDOUT, text=True, env=env)
    log: list[str] = []
    lock: Lock = Lock()
    Thread(target=_capture_stdout, args=(process, log, lock), daemon=True).start()
    t = 60
    time_limit = time() + t
    timed_out = False
    url = str(urljoin(app.base_url, "healthz"))
    ssl_context = _get_ssl_context(app.env)
    try:
        # Poll the health endpoint until it answers, the child dies, or we time out.
        while not timed_out and _is_alive(process):
            sleep(0.1)
            try:
                urlopen(url, context=ssl_context)
                break
            except BaseException:
                timed_out = time() > time_limit
        if timed_out:
            raise TimeoutError(f"Server {url} did not start within {t} seconds.")
        assert _is_alive(process)
        # Flush the startup output before handing control to the caller.
        with lock:
            for line in log:
                print(line, end="")
            log.clear()
        yield app
    finally:
        # Always reap the child, even on timeout or when the body raised.
        try:
            process.kill()
            process.wait(10)
        except Exception:
            pass  # best effort: the process may already be gone
        for line in log:
            print(line, end="")
def _is_alive(
    process: Popen,
) -> bool:
    """True while *process* is running and has not become a zombie."""
    if not process.is_running():
        return False
    return process.status() != STATUS_ZOMBIE
def _capture_stdout(
    process: Popen,
    log: list[str],
    lock: Lock,
) -> None:
    """Continuously read the child's stdout into *log* (run in a daemon thread).

    Lines are appended under *lock*; an empty read is appended only when the
    previous entry differs, so consecutive empty reads are collapsed.
    """
    while _is_alive(process):
        line = process.stdout.readline()
        if line or (log and log[-1] != line):
            with lock:
                log.append(line)
@contextmanager
def _random_schema(
    url: URL,
) -> Iterator[str]:
    """Yield a fresh random PostgreSQL schema name and drop it on exit.

    Connects once up front to fail fast when the database is unreachable.
    Only the name is generated here (truncated to PostgreSQL's 63-char
    identifier limit); the schema itself is presumably created elsewhere.
    Cleanup retries for up to 30s when the server is out of connection slots.
    """
    engine = create_engine(url.set(drivername="postgresql+psycopg"))
    engine.connect().close()  # connectivity check only
    engine.dispose()
    schema = f"{_SCHEMA_PREFIX}{token_hex(16)}"[:63]
    yield schema
    time_limit = time() + 30
    while time() < time_limit:
        try:
            with engine.connect() as conn:
                conn.execute(text(f"DROP SCHEMA {schema} CASCADE;"))
                conn.commit()
        except OperationalError as exc:
            # "too many clients": concurrent tests hold the slots; back off and retry.
            if "too many clients" in str(exc):
                sleep(1)
                continue
            raise
        break
    engine.dispose()
def _gql(
    app: _AppInfo,
    auth: Optional[_SecurityArtifact] = None,
    /,
    *,
    query: str,
    variables: Optional[Mapping[str, Any]] = None,
    operation_name: Optional[str] = None,
) -> tuple[dict[str, Any], Headers]:
    """POST a GraphQL operation to *app* and return (parsed body, response headers)."""
    payload = {
        "query": query,
        "variables": dict(variables or {}),
        "operationName": operation_name,
    }
    response = _httpx_client(app, auth).post("graphql", json=payload)
    return _json(response), response.headers
def _get_gql_spans(
    app: _AppInfo,
    auth: Optional[_SecurityArtifact] = None,
    /,
    *fields: str,
) -> dict[_ProjectName, list[dict[str, Any]]]:
    """Fetch the requested span *fields* for every project, keyed by project name."""
    out = "name spans{edges{node{" + " ".join(fields) + "}}}"
    query = "query{projects{edges{node{" + out + "}}}}"
    resp_dict, headers = _gql(app, auth, query=query)
    assert not resp_dict.get("errors")
    # The server should not set cookies on this read-only query.
    assert not headers.get("set-cookie")
    return {
        project["node"]["name"]: [span["node"] for span in project["node"]["spans"]["edges"]]
        for project in resp_dict["data"]["projects"]["edges"]
    }
def _list_users(
    app: _AppInfo,
    auth: Optional[_SecurityArtifact] = None,
    /,
) -> list[_User]:
    """Return all users, paging through the users connection 1000 at a time.

    Passwords are unknown to the API, so each returned profile carries "".
    """
    all_users = []
    has_next_page = True
    end_cursor = None
    while has_next_page:
        args = ["first:1000"]
        if end_cursor:
            args.append(f'after:"{end_cursor}"')
        args_str = f"({','.join(args)})"
        query = (
            "query{users"
            + args_str
            + "{edges{node{id email username role{name}}} pageInfo{hasNextPage endCursor}}}"
        )
        resp_dict, _ = _gql(app, auth, query=query)
        users_data = resp_dict["data"]["users"]
        users = [e["node"] for e in users_data["edges"]]
        all_users.extend(
            [
                _User(
                    _GqlId(u["id"]),
                    UserRoleInput(u["role"]["name"]),
                    # Password is not exposed by the API; use an empty string.
                    _Profile(u["email"], "", u["username"]),
                )
                for u in users
            ]
        )
        page_info = users_data["pageInfo"]
        has_next_page = page_info["hasNextPage"]
        end_cursor = page_info["endCursor"]
    return all_users
def _create_user(
    app: _AppInfo,
    auth: Optional[_SecurityArtifact] = None,
    /,
    *,
    role: UserRoleInput,
    profile: _Profile,
    send_welcome_email: bool = False,
    local: bool = True,
) -> _User:
    """Create a user via the createUser mutation and return it.

    *local* selects local (password) auth; otherwise authMethod:OAUTH2 is sent.
    Asserts the server echoes back the sanitized email and requested role.
    """
    email = profile.email
    password = profile.password
    username = profile.username
    args = [f'email:"{email}"', f'password:"{password}"', f"role:{role.value}"]
    if username:
        args.append(f'username:"{username}"')
    if not local:
        args.append("authMethod:OAUTH2")
    args.append(f"sendWelcomeEmail:{str(send_welcome_email).lower()}")
    out = "user{id email role{name}}"
    query = "mutation{createUser(input:{" + ",".join(args) + "}){" + out + "}}"
    resp_dict, headers = _gql(app, auth, query=query)
    assert (user := resp_dict["data"]["createUser"]["user"])
    assert user["email"] == sanitize_email(email)
    assert user["role"]["name"] == role.value
    assert not headers.get("set-cookie")
    return _User(_GqlId(user["id"]), role, profile)
def _delete_users(
    app: _AppInfo,
    auth: Optional[_SecurityArtifact] = None,
    /,
    *,
    users: Iterable[Union[_GqlId, _User]],
) -> None:
    """Delete the given users (objects or bare GIDs) via the deleteUsers mutation."""
    gids = [user.gid if isinstance(user, _User) else user for user in users]
    mutation = "mutation($userIds:[ID!]!){deleteUsers(input:{userIds:$userIds})}"
    _, headers = _gql(app, auth, query=mutation, variables=dict(userIds=gids))
    # The server should not set cookies on this mutation.
    assert not headers.get("set-cookie")
def _patch_user_gid(
    app: _AppInfo,
    gid: _GqlId,
    auth: Optional[_SecurityArtifact] = None,
    /,
    *,
    new_username: Optional[_Username] = None,
    new_password: Optional[_Password] = None,
    new_role: Optional[UserRoleInput] = None,
) -> None:
    """Patch the user identified by *gid* via the patchUser mutation.

    Only the provided fields are sent; asserts the response echoes the changes.
    """
    args = [f'userId:"{gid}"']
    if new_password:
        args.append(f'newPassword:"{new_password}"')
    if new_username:
        args.append(f'newUsername:"{new_username}"')
    if new_role:
        args.append(f"newRole:{new_role.value}")
    out = "user{id username role{name}}"
    query = "mutation{patchUser(input:{" + ",".join(args) + "}){" + out + "}}"
    resp_dict, headers = _gql(app, auth, query=query)
    assert (data := resp_dict["data"]["patchUser"])
    assert (result := data["user"])
    assert result["id"] == gid
    if new_username:
        assert result["username"] == new_username
    if new_role:
        assert result["role"]["name"] == new_role.value
    assert not headers.get("set-cookie")
def _patch_user(
    app: _AppInfo,
    user: _User,
    auth: Optional[_SecurityArtifact] = None,
    /,
    *,
    new_username: Optional[_Username] = None,
    new_password: Optional[_Password] = None,
    new_role: Optional[UserRoleInput] = None,
) -> _User:
    """Patch *user* on the server and return an updated local copy."""
    _patch_user_gid(
        app,
        user.gid,
        auth,
        new_username=new_username,
        new_password=new_password,
        new_role=new_role,
    )
    # Mirror the server-side changes onto the (frozen) local dataclass.
    profile = user.profile
    if new_username:
        profile = replace(profile, username=new_username)
    if new_password:
        profile = replace(profile, password=new_password)
    updated = replace(user, profile=profile)
    if new_role:
        updated = replace(updated, role=new_role)
    return updated
def _patch_viewer(
    app: _AppInfo,
    auth: Optional[_SecurityArtifact] = None,
    current_password: Optional[_Password] = None,
    /,
    *,
    new_username: Optional[_Username] = None,
    new_password: Optional[_Password] = None,
) -> None:
    """Patch the authenticated user's own username/password via patchViewer.

    Asserts that a password change sets cookies (token rotation per the
    server's set-cookie header) and that other changes do not.
    """
    args = []
    if new_password:
        args.append(f'newPassword:"{new_password}"')
    if current_password:
        args.append(f'currentPassword:"{current_password}"')
    if new_username:
        args.append(f'newUsername:"{new_username}"')
    out = "user{username}"
    query = "mutation{patchViewer(input:{" + ",".join(args) + "}){" + out + "}}"
    resp_dict, headers = _gql(app, auth, query=query)
    assert (data := resp_dict["data"]["patchViewer"])
    assert (user := data["user"])
    if new_username:
        assert user["username"] == new_username
    if new_password:
        assert headers.get("set-cookie")
    else:
        assert not headers.get("set-cookie")
def _create_api_key(
    app: _AppInfo,
    auth: Optional[_SecurityArtifact] = None,
    /,
    kind: _ApiKeyKind = "User",
    *,
    name: Optional[_Name] = None,
    expires_at: Optional[datetime] = None,
) -> _ApiKey:
    """Create a System/User API key and return it (JWT string plus metadata).

    Defaults the key name to the current UTC timestamp; asserts the server
    echoes back the name and expiry.
    """
    if name is None:
        name = datetime.now(timezone.utc).isoformat()
    exp = f' expiresAt:"{expires_at.isoformat()}"' if expires_at else ""
    args, out = (f'name:"{name}"' + exp), "jwt apiKey{id name expiresAt}"
    # The mutation field is createSystemApiKey or createUserApiKey.
    field = f"create{kind}ApiKey"
    query = "mutation{" + field + "(input:{" + args + "}){" + out + "}}"
    resp_dict, headers = _gql(app, auth, query=query)
    assert (data := resp_dict["data"][field])
    assert (key := data["apiKey"])
    assert key["name"] == name
    exp_t = datetime.fromisoformat(key["expiresAt"]) if key["expiresAt"] else None
    assert exp_t == expires_at
    assert not headers.get("set-cookie")
    return _ApiKey(data["jwt"], _GqlId(key["id"]), kind)
def _delete_api_key(
    app: _AppInfo,
    api_key: _ApiKey,
    auth: Optional[_SecurityArtifact] = None,
    /,
) -> None:
    """Delete *api_key* via the deleteSystemApiKey/deleteUserApiKey mutation."""
    field = f"delete{api_key.kind}ApiKey"
    gid = api_key.gid
    query = "mutation{" + field + '(input:{id:"' + gid + '"}){apiKeyId}}'
    resp_dict, headers = _gql(app, auth, query=query)
    assert resp_dict["data"][field]["apiKeyId"] == gid
    # The server should not set cookies on this mutation.
    assert not headers.get("set-cookie")
def _will_be_asked_to_reset_password(
    app: _AppInfo,
    user: _User,
) -> bool:
    """Log in as *user* and return their passwordNeedsReset flag."""
    query = "query($gid:ID!){node(id:$gid){... on User{passwordNeedsReset}}}"
    variables = dict(gid=user.gid)
    resp_dict, _ = user.log_in(app).gql(app, query, variables)
    return cast(bool, resp_dict["data"]["node"]["passwordNeedsReset"])
def _log_in(
    app: _AppInfo,
    password: _Password,
    /,
    *,
    email: _Email,
) -> _LoggedInTokens:
    """POST auth/login and return the access/refresh tokens from the cookies.

    Asserts both cookies are present in the response.
    """
    json_ = dict(email=email, password=password)
    resp = _httpx_client(app).post("auth/login", json=json_)
    resp.raise_for_status()
    assert (access_token := resp.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME))
    assert (refresh_token := resp.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME))
    return _LoggedInTokens(_AccessToken(access_token), _RefreshToken(refresh_token))
def _log_out(
    app: _AppInfo,
    auth: Optional[_SecurityArtifact] = None,
    /,
) -> None:
    """GET auth/logout and verify every Phoenix cookie is cleared.

    A 302 redirect to /login or /logout is accepted as a successful logout.
    """
    resp = _httpx_client(app, auth).get("auth/logout", follow_redirects=False)
    try:
        resp.raise_for_status()
    except HTTPStatusError as e:
        if e.response.status_code != 302:
            raise
        assert e.response.headers["location"] in ("/login", "/logout")
    tokens = _extract_tokens(resp.headers, "set-cookie")
    for k in _COOKIE_NAMES:
        # Cleared cookies come back with the literal value '""'.
        assert tokens[k] == '""'
def _initiate_password_reset(
    app: _AppInfo,
    email: _Email,
    smtpd: smtpdfix.AuthController,
    /,
    *,
    should_receive_email: bool = True,
) -> Optional[_PasswordResetToken]:
    """POST auth/password-reset-email and pull the reset token from the email.

    Asserts exactly one new message arrives iff *should_receive_email*;
    returns None when no email is expected.
    """
    old_msg_count = len(smtpd.messages)
    json_ = dict(email=email)
    resp = _httpx_client(app).post("auth/password-reset-email", json=json_)
    resp.raise_for_status()
    new_msg_count = len(smtpd.messages) - old_msg_count
    assert new_msg_count == int(should_receive_email)
    if not should_receive_email:
        return None
    msg = smtpd.messages[-1]
    assert msg["to"] == sanitize_email(email)
    return _extract_password_reset_token(msg)
def _reset_password(
    app: _AppInfo,
    token: _PasswordResetToken,
    /,
    password: _Password,
) -> None:
    """Complete a password reset by POSTing *token* and the new *password*."""
    payload = {"token": token, "password": password}
    response = _httpx_client(app).post("auth/password-reset", json=payload)
    response.raise_for_status()
def _export_embeddings(
    app: _AppInfo, auth: Optional[_SecurityArtifact] = None, /, *, filename: str
) -> None:
    """GET /exports for *filename* and raise on any HTTP error."""
    response = _httpx_client(app, auth).get("/exports", params={"filename": filename})
    response.raise_for_status()
def _json(
resp: httpx.Response,
) -> dict[str, Any]:
resp.raise_for_status()
assert (resp_dict := cast(dict[str, Any], resp.json()))
if errers := resp_dict.get("errors"):
msg = errers[0]["message"]
# Raise Unauthorized for permission-related errors
if (
"not auth" in msg
or IsAdmin.message in msg
or "Viewers cannot perform this action" in msg
):
raise Unauthorized(msg)
raise RuntimeError(msg)
return resp_dict
class _Expectation(Protocol):
    """Context-manager protocol satisfied by pytest.raises and nullcontext."""

    def __enter__(self) -> Optional[BaseException]: ...
    def __exit__(self, *args: Any, **kwargs: Any) -> None: ...


# Context manager yielding either nothing (allowed) or the Unauthorized error.
_OK_OR_DENIED: TypeAlias = AbstractContextManager[Optional[Unauthorized]]
# Reusable expectation objects for permission/HTTP-status tests.
_OK = nullcontext()
_DENIED = pytest.raises(Unauthorized)
_EXPECTATION_401 = pytest.raises(HTTPStatusError, match="401 Unauthorized")
_EXPECTATION_403 = pytest.raises(HTTPStatusError, match="403 Forbidden")
_EXPECTATION_404 = pytest.raises(HTTPStatusError, match="404 Not Found")
def _extract_tokens(
    headers: Headers,
    key: Literal["cookie", "set-cookie"] = "cookie",
) -> dict[str, str]:
    """Map Phoenix cookie names to their values found in the given header."""
    raw = headers.get(key)
    if not raw:
        return {}
    # Tokenize on cookie-header separators; adjacent tokens then act as
    # candidate (name, value) pairs, filtered to the known cookie names.
    tokens = re.split(r"[ ,;=]", raw)
    return {name: value for name, value in zip(tokens, tokens[1:]) if name in _COOKIE_NAMES}
def _decode_token_ids(
    headers: Headers,
    key: Literal["cookie", "set-cookie"] = "cookie",
) -> list[str]:
    """Return the JWT "jti" claims of all Phoenix tokens in the given header.

    Signature verification is skipped; cleared cookies ('""') and values
    that fail to decode or lack a "jti" claim are ignored.
    """
    ans = []
    for v in _extract_tokens(headers, key).values():
        if v == '""':
            continue
        try:
            token = jwt.decode(v, options={"verify_signature": False})["jti"]
        except (DecodeError, KeyError):
            continue
        ans.append(token)
    return ans
def _extract_password_reset_token(msg: Message) -> _PasswordResetToken:
    """Pull the reset token from the "reset-url" link in a password-reset email.

    Parses the HTML body, reads the link's "token" query parameter, and
    sanity-checks that its (unverified) JWT "jti" claim starts with
    "PasswordResetToken".
    """
    assert (soup := _extract_html(msg))
    assert isinstance((link := soup.find(id="reset-url")), bs4.Tag)
    assert isinstance((url := link.get("href")), str)
    assert url
    params = parse_qs(urlparse(url).query)
    assert (tokens := params["token"])
    assert (token := tokens[0])
    decoded = jwt.decode(token, options=dict(verify_signature=False))
    assert (jti := decoded["jti"])
    assert jti.startswith("PasswordResetToken")
    return _PasswordResetToken(token)
def _extract_html(msg: Message) -> Optional[bs4.BeautifulSoup]:
for part in msg.walk():
if (
part.get_content_type() == "text/html"
and (payload := part.get_payload(decode=True))
and isinstance(payload, bytes)
):
content = payload.decode(part.get_content_charset() or "utf-8")
return bs4.BeautifulSoup(content, "html.parser")
return None
# Every auth-related cookie Phoenix sets; _log_out asserts all are cleared.
_COOKIE_NAMES = (
    PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
    PHOENIX_REFRESH_TOKEN_COOKIE_NAME,
    PHOENIX_OAUTH2_STATE_COOKIE_NAME,
    PHOENIX_OAUTH2_NONCE_COOKIE_NAME,
)
async def _await_or_return(obj: Union[_AnyT, Awaitable[_AnyT]]) -> _AnyT:
"""Helper function to handle both synchronous and asynchronous operations uniformly.
This function enables writing code that works with both synchronous and asynchronous
operations without duplicating logic. It takes either a regular value or an awaitable
and returns the resolved value, abstracting away the sync/async distinction.
Args:
obj: Either a regular value or an awaitable (like a coroutine or Future)
Returns:
The resolved value. If obj was an awaitable, it will be awaited first.
Example:
# This works with both sync and async operations:
result = await _await_or_return(some_operation())
"""
if isinstance(obj, Awaitable):
return cast(_AnyT, await obj)
return obj
class _OIDCServer:
    """
    A mock OpenID Connect (OIDC) server implementation for testing OAuth2/OIDC authentication flows.
    This class provides a lightweight, in-memory OIDC server that simulates the behavior of a real
    OIDC identity provider. It implements the core OIDC endpoints required for testing authentication
    flows, including authorization, token issuance, and user information retrieval.
    The server runs in a separate thread and can be used as a context manager to ensure proper
    cleanup of resources. It generates random client credentials and signing keys for each instance,
    making it suitable for isolated test scenarios.
    Key features:
    - Implements standard OIDC endpoints (/auth, /token, /.well-known/openid-configuration, etc.)
    - Supports both standard OAuth2 authorization code flow and PKCE
    - Supports group-based access control claims
    - Generates JWT tokens with appropriate claims
    - Provides JWKS endpoint for token verification
    - Runs in a separate thread to avoid blocking the main test process
    PKCE Support:
    - Public clients (no client_secret): Validates code_verifier only
    - Confidential clients with PKCE: Validates BOTH client_secret AND code_verifier (defense-in-depth)
    Usage:
        # Standard OAuth2 flow
        with _OIDCServer(port=8000) as oidc_server:
            # Use oidc_server.client_id and oidc_server.client_secret for OAuth2 configuration
            # The server will be available at oidc_server.base_url
            ...
        # PKCE flow with groups
        with _OIDCServer(port=8000, use_pkce=True, groups=["admin", "users"]) as oidc_server:
            # PKCE-enabled server with group claims
            pass
    """

    def __init__(
        self,
        port: int,
        use_pkce: bool = False,
        groups: Optional[list[str]] = None,
    ):
        """
        Initialize a new OIDC server instance.
        Args:
            port: The port number on which the server will listen.
            use_pkce: Enable PKCE (Proof Key for Code Exchange) support.
            groups: List of groups to include in ID token claims (for group-based access control testing).
        """
        # Random per-instance identifiers keep concurrently running tests isolated.
        self._name: str = f"oidc_server_{token_hex(8)}"
        self._client_id: str = f"client_id_{token_hex(8)}"
        self._client_secret: str = f"client_secret_{token_hex(8)}"
        # Symmetric key used to sign ID tokens with HS256 (see /token and /.well-known/jwks.json).
        self._secret_key: str = f"secret_key_{token_hex(16)}"
        self._host: str = "127.0.0.1"
        self._port: int = port
        self._use_pkce: bool = use_pkce
        self._groups: list[str] = groups or []
        self._app = FastAPI()
        # Identity of the "current user"; regenerated on each /auth request.
        self._nonce: Optional[str] = None
        self._user_id: Optional[str] = None
        self._user_email: Optional[str] = None
        self._user_name: Optional[str] = None
        # PKCE state: maps auth_code -> code_challenge
        self._code_challenges: dict[str, str] = {}
        self._server: Optional[Generator[Thread, None, None]] = None
        self._thread: Optional[Thread] = None
        self._setup_routes()

    def _setup_routes(self) -> None:
        """
        Set up the FastAPI routes for the OIDC server.
        This method configures all the necessary endpoints for OIDC functionality:
        - /auth: Authorization endpoint that simulates the initial OAuth2 authorization request.
        - /token: Token endpoint that exchanges authorization codes for tokens
        - /.well-known/openid-configuration: Discovery document for OIDC clients
        - /userinfo: User information endpoint
        - /.well-known/jwks.json: JSON Web Key Set for token verification
        """

        @self._app.get("/auth")
        async def auth(request: Request) -> Response:
            """
            Authorization endpoint that simulates the initial OAuth2 authorization request.
            Validates the client_id and returns a redirect with an authorization code.
            For PKCE flows, also receives and stores the code_challenge.
            """
            params = dict(request.query_params)
            if params.get("client_id") != self._client_id:
                return JSONResponse({"error": "invalid_client"}, status_code=400)
            state = params.get("state")
            nonce = params.get("nonce")
            redirect_uri = params.get("redirect_uri")
            # Generate unique authorization code
            auth_code = f"auth_code_{token_hex(16)}"
            # PKCE: Store code_challenge if provided
            if self._use_pkce:
                code_challenge = params.get("code_challenge")
                code_challenge_method = params.get("code_challenge_method")
                if not code_challenge:
                    return JSONResponse(
                        {
                            "error": "invalid_request",
                            "error_description": "code_challenge required",
                        },
                        status_code=400,
                    )
                # Only S256 is supported (the "plain" method is rejected).
                if code_challenge_method != "S256":
                    return JSONResponse(
                        {
                            "error": "invalid_request",
                            "error_description": "code_challenge_method must be S256",
                        },
                        status_code=400,
                    )
                self._code_challenges[auth_code] = code_challenge
            # Generate user for this session
            self._nonce = nonce
            self._user_id = f"user_id_{token_hex(8)}"
            # Randomized casing — presumably to exercise case-insensitive email
            # handling downstream; TODO confirm against callers.
            self._user_email = _randomize_casing(f"{string.ascii_lowercase}@{token_hex(16)}.com")
            self._user_name = f"User {token_hex(8)}"
            return RedirectResponse(
                f"{redirect_uri}?code={auth_code}&state={state}",
                status_code=302,
            )

        @self._app.post("/token")
        async def token(request: Request) -> Response:
            """
            Token endpoint that exchanges authorization codes for access and ID tokens.
            Supports both standard OAuth2 and PKCE flows:
            - Standard: Validates client_secret via HTTP Basic Auth
            - PKCE: Validates code_verifier against stored code_challenge
            - Confidential + PKCE: Validates BOTH client_secret AND code_verifier
            Returns a token response with access_token, id_token, and refresh_token.
            """
            from hashlib import sha256

            form_data = await request.form()
            code = form_data.get("code")
            if not code:
                return JSONResponse(
                    {"error": "invalid_request", "error_description": "code required"},
                    status_code=400,
                )
            # Type assertions for form data (FastAPI form_data.get returns Union[UploadFile, str])
            assert isinstance(code, str)
            # Step 1: Validate client authentication (if required)
            client_authenticated = False
            auth_header = request.headers.get("Authorization")
            # Try HTTP Basic Auth (client_secret_basic)
            if auth_header and auth_header.startswith("Basic "):
                try:
                    credentials = b64decode(auth_header[6:]).decode()
                    client_id, client_secret = credentials.split(":", 1)
                    if client_id == self._client_id and client_secret == self._client_secret:
                        client_authenticated = True
                except Exception:
                    # Malformed Basic header: fall through to client_secret_post below.
                    pass
            # Try POST body (client_secret_post)
            if not client_authenticated:
                body_client_id = form_data.get("client_id")
                body_client_secret = form_data.get("client_secret")
                if body_client_id == self._client_id and body_client_secret == self._client_secret:
                    client_authenticated = True
            # Step 2: Validate PKCE (if required)
            pkce_valid = False
            code_verifier = form_data.get("code_verifier")
            if self._use_pkce and code in self._code_challenges:
                # Reject missing or empty code_verifier
                if not code_verifier:
                    return JSONResponse(
                        {
                            "error": "invalid_request",
                            "error_description": "code_verifier required for PKCE",
                        },
                        status_code=400,
                    )
                assert isinstance(code_verifier, str)
                # Compute challenge from verifier
                challenge = (
                    urlsafe_b64encode(sha256(code_verifier.encode()).digest()).decode().rstrip("=")
                )
                stored_challenge = self._code_challenges.get(code)
                if challenge == stored_challenge:
                    pkce_valid = True
                    # Clean up after successful validation
                    del self._code_challenges[code]
                else:
                    return JSONResponse(
                        {
                            "error": "invalid_grant",
                            "error_description": "code_verifier does not match code_challenge",
                        },
                        status_code=400,
                    )
            # Step 3: Determine authentication mode and validate
            if self._use_pkce:
                # PKCE flow
                if not pkce_valid:
                    return JSONResponse(
                        {"error": "invalid_grant", "error_description": "PKCE validation failed"},
                        status_code=400,
                    )
                # For confidential clients with PKCE, also check client_secret if provided
                # (This is defense-in-depth: both PKCE and client_secret)
                # If client_secret is in the request, it must be valid
                if auth_header or form_data.get("client_secret"):
                    if not client_authenticated:
                        return JSONResponse(
                            {
                                "error": "invalid_client",
                                "error_description": "Invalid client credentials",
                            },
                            status_code=400,
                        )
            else:
                # Standard flow (non-PKCE): Validate code_verifier BEFORE client auth
                # to avoid leaking information about server configuration
                if code_verifier is not None and code_verifier != "":
                    return JSONResponse(
                        {
                            "error": "invalid_request",
                            "error_description": "code_verifier not allowed when PKCE is not enabled",
                        },
                        status_code=400,
                    )
                # Now validate client authentication
                if not client_authenticated:
                    return JSONResponse(
                        {
                            "error": "invalid_client",
                            "error_description": "Invalid client credentials",
                        },
                        status_code=400,
                    )
            # Create ID token with required claims
            now = int(time())
            id_token_claims = {
                "iss": self.base_url,
                "sub": self._user_id,
                "aud": self._client_id,
                "iat": now,
                "exp": now + 3600,
                "email": self._user_email,
                "name": self._user_name,
                # Echo back the nonce from /auth so clients can bind the token
                # to their original authorization request.
                "nonce": self._nonce,
            }
            # NOTE: Groups are intentionally NOT included in ID token to simulate
            # real-world IDPs (AWS Cognito, Azure AD) that keep ID tokens small.
            # Groups must be fetched from the /userinfo endpoint instead.
            id_token = jwt.encode(
                payload=id_token_claims,
                key=self._secret_key.encode(),
                algorithm="HS256",
            )
            # Return token response with all required fields
            return JSONResponse(
                {
                    "access_token": f"access_token_{token_hex(8)}",
                    "id_token": id_token,
                    "token_type": "bearer",
                    "expires_in": 3600,  # 1 hour in seconds
                    "refresh_token": f"refresh_token_{token_hex(8)}",
                    "scope": "openid profile email",
                }
            )

        @self._app.get("/.well-known/openid-configuration")
        async def openid_configuration() -> Response:
            """
            OpenID Connect discovery document endpoint.
            Returns the standard OIDC configuration document that clients use to
            discover the endpoints and capabilities of this identity provider.
            """
            config = {
                "issuer": self.base_url,
                "authorization_endpoint": self.auth_url,
                "token_endpoint": self.token_url,
                "userinfo_endpoint": f"{self.base_url}/userinfo",
                "jwks_uri": f"{self.base_url}/.well-known/jwks.json",
                "response_types_supported": ["code"],
                "subject_types_supported": ["public"],
                "id_token_signing_alg_values_supported": ["HS256"],
                "scopes_supported": ["openid", "profile", "email"],
                "token_endpoint_auth_methods_supported": [
                    "client_secret_basic",
                    "client_secret_post",
                ],
                "claims_supported": [
                    "sub",
                    "iss",
                    "aud",
                    "exp",
                    "iat",
                    "name",
                    "email",
                    "picture",
                ],
            }
            # Add PKCE support to discovery document
            if self._use_pkce:
                config["code_challenge_methods_supported"] = ["S256"]
                # Public clients don't require client authentication
                token_auth_methods = config["token_endpoint_auth_methods_supported"]
                assert isinstance(token_auth_methods, list)
                token_auth_methods.append("none")
            # Add groups claim if configured
            if self._groups:
                claims_supported = config["claims_supported"]
                assert isinstance(claims_supported, list)
                claims_supported.append("groups")
            return JSONResponse(config)

        @self._app.get("/userinfo")
        async def userinfo() -> Response:
            """
            User information endpoint.
            Returns a JSON response with user profile information that would typically
            be retrieved from a real identity provider's user database.
            Includes groups claim if configured.
            """
            user_info = {
                "sub": self._user_id,
                "name": self._user_name,
                "email": self._user_email,
                "picture": "https://example.com/picture.jpg",
            }
            # Add groups if configured
            if self._groups:
                user_info["groups"] = self._groups  # type: ignore[assignment]
            return JSONResponse(user_info)

        @self._app.get("/.well-known/jwks.json")
        async def jwks() -> Response:
            """
            JSON Web Key Set endpoint.
            Returns the public keys that clients can use to verify the signatures
            of ID tokens issued by this server. In this implementation, we're using
            a symmetric key (HS256) for simplicity, but in a real OIDC provider,
            this would typically use asymmetric keys (RS256).
            """
            # Base64url encode the secret key
            encoded_key = urlsafe_b64encode(self._secret_key.encode()).decode().rstrip("=")
            return JSONResponse(
                {
                    "keys": [
                        {
                            "kty": "oct",
                            "kid": "test_key_id",
                            "use": "sig",
                            "alg": "HS256",
                            "k": encoded_key,
                        }
                    ]
                }
            )

    def __enter__(self) -> Self:
        # Start the FastAPI app on a background thread; run_in_thread yields
        # once the server is ready, and closing the generator shuts it down.
        self._server = ThreadServer(
            app=self._app,
            host=self._host,
            port=self._port,
            root_path="",
        ).run_in_thread()
        self._thread = next(self._server)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Tear down in reverse order of __enter__; tolerate a partial start.
        if not self._server:
            return
        self._server.close()
        if not self._thread:
            return
        self._thread.join(timeout=5)

    @cached_property
    def base_url(self) -> str:
        return f"http://{self._host}:{self._port}"

    @cached_property
    def auth_url(self) -> str:
        return f"{self.base_url}/auth"

    @cached_property
    def token_url(self) -> str:
        return f"{self.base_url}/token"

    @property
    def user_id(self) -> Optional[str]:
        """Get the current user ID."""
        return self._user_id

    @property
    def user_email(self) -> Optional[str]:
        """Get the current user email."""
        return self._user_email

    @property
    def user_name(self) -> Optional[str]:
        """Get the current user name."""
        return self._user_name

    @property
    def client_id(self) -> str:
        """Get the OAuth client ID."""
        return self._client_id

    @property
    def client_secret(self) -> str:
        """Get the OAuth client secret."""
        return self._client_secret

    @property
    def groups(self) -> list[str]:
        """Get the configured groups for this OIDC server."""
        return self._groups

    @property
    def use_pkce(self) -> bool:
        """Check if PKCE is enabled for this OIDC server."""
        return self._use_pkce

    def __str__(self) -> str:
        return self._name
T = TypeVar("T")
async def _get(
query_fn: Callable[..., Optional[T]] | Callable[..., Awaitable[Optional[T]]],
args: Sequence[Any] = (),
kwargs: Mapping[str, Any] = MappingProxyType({}),
error_msg: str = "",
no_wait: bool = False,
retries: int = 20,
initial_wait_time: float = 0.1,
max_wait_time: float = 1,
) -> T:
"""If no_wait, run the query once. Otherwise, retry it if it returns None
and raise if retries are exhausted.
Args:
query_fn: Function that returns Optional[T] or Awaitable[Optional[T]]
args: Positional arguments for query_fn
kwargs: Keyword arguments for query_fn
error_msg: Error message if all retries fail
no_wait: If True, only try once without retries
retries: Maximum number of retry attempts
initial_wait_time: Initial wait time between retries in seconds
max_wait_time: Maximum wait time between retries in seconds
Returns:
Result from query_fn
Raises:
AssertionError: If query_fn returns None after all retries
"""
from asyncio import sleep
wt = 0 if no_wait else initial_wait_time
while True:
await sleep(wt)
res = query_fn(*args, **kwargs)
ans = cast(Optional[T], await res) if isinstance(res, Awaitable) else res
if ans is not None:
return ans
if no_wait or not retries:
raise AssertionError(error_msg)
retries -= 1
wt = min(wt * 1.5, max_wait_time)
# Readable aliases distinguishing raw OTLP identifiers (hex strings) from
# Relay global IDs used by the GraphQL API.
_SpanId: TypeAlias = str
_TraceId: TypeAlias = str
_SessionId: TypeAlias = str
_SpanGlobalId: TypeAlias = GlobalID
_TraceGlobalId: TypeAlias = GlobalID
_SessionGlobalId: TypeAlias = GlobalID
_ProjectGlobalId: TypeAlias = GlobalID
class _ExistingProject(NamedTuple):
    """A project as it exists in the database: Relay global ID plus name."""

    id: _ProjectGlobalId
    name: _ProjectName
class _ExistingSession(NamedTuple):
    """A session as it exists in the database: Relay global ID plus raw session id."""

    id: _SessionGlobalId
    session_id: _SessionId
class _ExistingTrace(NamedTuple):
    """A trace row joined with its owning project and (optional) session."""

    id: _TraceGlobalId
    trace_id: _TraceId  # raw OTLP trace id (hex)
    project: _ExistingProject
    session: Optional[_ExistingSession]  # None when the trace has no session
class _ExistingSpan(NamedTuple):
    """A span row joined with its enclosing trace/project/session hierarchy."""

    id: _SpanGlobalId
    span_id: _SpanId  # raw OTLP span id (hex)
    trace: _ExistingTrace
def _insert_spans(app: _AppInfo, n: int) -> tuple[_ExistingSpan, ...]:
    """Create *n* spans in a fresh project and wait until they are queryable.

    Spans are collected in an InMemorySpanExporter, shipped to the app over
    OTLP/gRPC using the admin secret as a bearer token, and then polled for
    via GraphQL until all *n* appear.

    Args:
        app: The running Phoenix app under test.
        n: Number of spans to insert; must be positive.

    Returns:
        The inserted spans as _ExistingSpan records (arbitrary order).

    Raises:
        AssertionError: If export fails or the spans never become queryable.
    """
    assert n > 0, "Number of spans to insert must be greater than 0"
    memory = InMemorySpanExporter()
    # A random project name isolates this batch from other tests' spans.
    project_name = token_hex(16)
    for _ in range(n):
        _start_span(
            project_name=project_name,
            attributes={
                "session.id": token_hex(8),
                "retrieval.documents.0.document.id": token_hex(8),
                "retrieval.documents.1.document.id": token_hex(8),
                "retrieval.documents.2.document.id": token_hex(8),
            },
            exporter=memory,
        ).end()
    assert len(spans := memory.get_finished_spans()) == n
    headers = {"authorization": f"Bearer {app.admin_secret}"}
    assert _grpc_span_exporter(app, headers=headers).export(spans) is SpanExportResult.SUCCESS
    span_ids = set()
    for span in spans:
        assert (context := span.get_span_context())  # type: ignore[no-untyped-call]
        span_ids.add(format_span_id(context.span_id))
    assert len(span_ids) == n
    # Poll (with backoff) until every exported span is visible via GraphQL.
    return asyncio.run(
        _get(
            lambda: tuple(ans) if len(ans := _get_existing_spans(app, span_ids)) == n else None,
            error_msg="spans not found",
        )
    )
def _get_existing_spans(
    app: _AppInfo,
    span_ids: Iterable[_SpanId],
) -> set[_ExistingSpan]:
    """Fetch the spans matching *span_ids* via GraphQL, across all projects.

    Issues a single query (authenticated with the admin secret) using a
    ``span_id in [...]`` filter, then flattens the nested
    project/span/trace/session payload into _ExistingSpan records. Spans not
    yet ingested are simply absent from the result, so callers can poll on
    the returned set's size.
    """
    ids = list(span_ids)
    n = len(ids)
    query = """
      query ($filterCondition: String, $first: Int) {
        projects {
          edges {
            node {
              id
              name
              spans (filterCondition: $filterCondition, first: $first) {
                edges {
                  node {
                    id
                    spanId
                    trace {
                      id
                      traceId
                      session {
                        id
                        sessionId
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    """
    res, _ = _gql(
        app,
        app.admin_secret,
        query=query,
        variables={"filterCondition": f"span_id in {ids}", "first": n},
    )
    return {
        _ExistingSpan(
            id=GlobalID.from_id(span["node"]["id"]),
            span_id=span["node"]["spanId"],
            trace=_ExistingTrace(
                id=GlobalID.from_id(span["node"]["trace"]["id"]),
                trace_id=span["node"]["trace"]["traceId"],
                project=_ExistingProject(
                    id=GlobalID.from_id(project["node"]["id"]),
                    name=project["node"]["name"],
                ),
                session=(
                    _ExistingSession(
                        id=GlobalID.from_id(span["node"]["trace"]["session"]["id"]),
                        session_id=span["node"]["trace"]["session"]["sessionId"],
                    )
                    if span["node"]["trace"]["session"] is not None
                    else None
                ),
            ),
        )
        for project in res["data"]["projects"]["edges"]
        for span in project["node"]["spans"]["edges"]
        if span["node"]["spanId"] in ids
    }
async def _until_spans_exist(app: _AppInfo, span_ids: Iterable[_SpanId]) -> None:
    """Block (with backoff) until every given span id is queryable via GraphQL."""
    wanted = set(span_ids)

    def _all_present() -> Optional[bool]:
        # _get treats None as "not yet"; report success only when complete.
        return True if len(_get_existing_spans(app, wanted)) == len(wanted) else None

    await _get(_all_present)
def _randomize_casing(email: str) -> str:
return "".join(c.lower() if random() < 0.5 else c.upper() for c in email)
# Each entry below is (expected_status, HTTP method, path template). The
# "fake-id-{}" placeholder stands in for a nonexistent resource id, so the
# expected 404/422 codes reflect failed lookups/validation rather than
# authorization failures.

# GET endpoints that all roles can read with expected status codes
_COMMON_RESOURCE_ENDPOINTS = (
    # Projects
    (404, "GET", "v1/projects/fake-id-{}"),
    (200, "GET", "v1/projects"),
    # Datasets
    (422, "GET", "v1/datasets/fake-id-{}"),
    (200, "GET", "v1/datasets"),
    (422, "GET", "v1/datasets/fake-id-{}/versions"),
    (422, "GET", "v1/datasets/fake-id-{}/examples"),
    (422, "GET", "v1/datasets/fake-id-{}/csv"),
    (422, "GET", "v1/datasets/fake-id-{}/jsonl/openai_ft"),
    (422, "GET", "v1/datasets/fake-id-{}/jsonl/openai_evals"),
    # Experiments
    (422, "GET", "v1/experiments/fake-id-{}"),
    (422, "GET", "v1/datasets/fake-id-{}/experiments"),
    (422, "GET", "v1/experiments/fake-id-{}/runs"),
    (422, "GET", "v1/experiments/fake-id-{}/json"),
    (422, "GET", "v1/experiments/fake-id-{}/csv"),
    # Prompts
    (200, "GET", "v1/prompts"),
    (200, "GET", "v1/prompts/fake-id-{}/versions"),
    (422, "GET", "v1/prompt_versions/fake-id-{}"),
    (404, "GET", "v1/prompts/fake-id-{}/tags/test-tag"),
    (404, "GET", "v1/prompts/fake-id-{}/latest"),
    (422, "GET", "v1/prompt_versions/fake-id-{}/tags"),
    # Annotation configs
    (200, "GET", "v1/annotation_configs"),
    (404, "GET", "v1/annotation_configs/fake-id-{}"),
    # Evaluations
    (404, "GET", "v1/evaluations"),
    # Spans (project-scoped)
    (404, "GET", "v1/projects/fake-id-{}/spans"),
    (404, "GET", "v1/projects/fake-id-{}/spans/otlpv1"),
    # Annotations (project-scoped)
    (422, "GET", "v1/projects/fake-id-{}/span_annotations"),
    (422, "GET", "v1/projects/fake-id-{}/trace_annotations"),
    (422, "GET", "v1/projects/fake-id-{}/session_annotations"),
    # Spans
    (422, "GET", "v1/spans"),
)
# Admin-only endpoints (user management, project CRUD)
# Non-admins always receive 403, admins get expected_admin_status
_ADMIN_ONLY_ENDPOINTS = (
    (200, "GET", "v1/users"),
    (422, "POST", "v1/users"),
    (422, "DELETE", "v1/users/fake-id-{}"),
    (422, "PUT", "v1/projects/fake-id-{}"),
    (404, "DELETE", "v1/projects/fake-id-{}"),
)
# Write operations blocked for viewers (POST/PUT/DELETE)
# Viewers always receive 403, non-viewers (admins/members) get expected_non_viewer_status
_VIEWER_BLOCKED_WRITE_OPERATIONS = (
    # POST routes
    (422, "POST", "v1/annotation_configs"),
    (400, "POST", "v1/datasets/upload"),
    (422, "POST", "v1/datasets/fake-id-{}/experiments"),
    (422, "POST", "v1/document_annotations"),
    (415, "POST", "v1/evaluations"),
    (422, "POST", "v1/experiment_evaluations"),
    (422, "POST", "v1/experiments/fake-id-{}/runs"),
    (422, "POST", "v1/projects"),
    (422, "POST", "v1/projects/fake-id-{}/spans"),
    (422, "POST", "v1/prompts"),
    (422, "POST", "v1/prompt_versions/fake-id-{}/tags"),
    (422, "POST", "v1/session_annotations"),
    (422, "POST", "v1/span_annotations"),
    (422, "POST", "v1/spans"),
    (422, "POST", "v1/trace_annotations"),
    (415, "POST", "v1/traces"),
    # PUT routes
    (422, "PUT", "v1/annotation_configs/fake-id-{}"),
    # DELETE routes
    (422, "DELETE", "v1/annotation_configs/fake-id-{}"),
    (422, "DELETE", "v1/datasets/fake-id-{}"),
    (404, "DELETE", "v1/spans/fake-id-{}"),
    (404, "DELETE", "v1/traces/fake-id-{}"),
)
def _ensure_endpoint_coverage_is_exhaustive() -> None:
    """Verify that test constants cover all actual v1 API routes.

    This runs at module import time as a prerequisite check. If endpoint
    coverage is incomplete, all tests that import this module will fail fast.

    Raises:
        AssertionError: Listing any routes present on the server but missing
            from the test constants, and vice versa.
    """
    import re

    from fastapi.routing import APIRoute

    from phoenix.server.api.routers.v1 import create_v1_router

    # Get all actual routes from the v1 router
    router = create_v1_router(authentication_enabled=False)
    actual_routes = {
        (method, route.path)
        for route in router.routes
        if isinstance(route, APIRoute)
        for method in route.methods
    }
    # Get all routes from test constants
    test_routes = {
        (method, endpoint)
        for _, method, endpoint in chain(
            _COMMON_RESOURCE_ENDPOINTS,
            _ADMIN_ONLY_ENDPOINTS,
            _VIEWER_BLOCKED_WRITE_OPERATIONS,
        )
    }

    # Normalize paths: server uses {param_name}, tests use fake-id-{}
    def normalize_path(path: str) -> str:
        if not path.startswith("/"):
            path = "/" + path
        # Collapse every parameter spelling into a single "{id}" placeholder
        # so server paths and test paths compare equal.
        path = re.sub(r"fake-id-\{\}", "{id}", path)
        path = re.sub(r"\{[^}]*\}", "{id}", path)
        # "test-tag" is a literal segment in the test constants standing in
        # for the server's {tag} parameter; fold it into "{id}" as well.
        path = re.sub(r"/tags/test-tag$", "/tags/{id}", path)
        return path

    # Map normalized paths back to original paths for error reporting
    normalized_to_actual = {(m, normalize_path(p)): (m, p) for m, p in actual_routes}
    normalized_to_test = {(m, normalize_path(p)): (m, p) for m, p in test_routes}
    normalized_actual = set(normalized_to_actual.keys())
    normalized_test = set(normalized_to_test.keys())
    # Check for discrepancies
    missing_in_tests = normalized_actual - normalized_test
    extra_in_tests = normalized_test - normalized_actual
    if missing_in_tests or extra_in_tests:
        error_parts = []
        if missing_in_tests:
            # Show actual server paths (not normalized)
            actual_paths = [normalized_to_actual[route] for route in sorted(missing_in_tests)]
            routes_str = "\n".join(f"  {m} {p}" for m, p in actual_paths)
            error_parts.append(
                f"Routes in server but NOT in test constants:\n{routes_str}\n\n"
                f"Add these to _helpers.py:\n"
                f"  - GET routes → _COMMON_RESOURCE_ENDPOINTS\n"
                f"  - Admin-only routes (users, project CRUD) → _ADMIN_ONLY_ENDPOINTS\n"
                f"  - Write operations (POST/PUT/DELETE) → _VIEWER_BLOCKED_WRITE_OPERATIONS\n\n"
                f"Format: (expected_status_code, method, endpoint_path)\n"
                f'Example: (404, "GET", "v1/projects/fake-id-{{}}") or (422, "POST", "v1/datasets/upload")'
            )
        if extra_in_tests:
            # Show actual test paths (not normalized)
            test_paths = [normalized_to_test[route] for route in sorted(extra_in_tests)]
            routes_str = "\n".join(f"  {m} {p}" for m, p in test_paths)
            error_parts.append(
                f"Routes in test constants but NOT in server (removed?):\n{routes_str}\n\n"
                f"Remove these from _COMMON_RESOURCE_ENDPOINTS, _ADMIN_ONLY_ENDPOINTS,\n"
                f"or _VIEWER_BLOCKED_WRITE_OPERATIONS in _helpers.py"
            )
        raise AssertionError("Endpoint coverage is incomplete!\n\n" + "\n\n".join(error_parts))
# Fail fast at import time if the endpoint constants above drift from the API.
_ensure_endpoint_coverage_is_exhaustive()