import gzip
import inspect
import io
import json
from io import BytesIO, StringIO
from typing import Any

import httpx
import pandas as pd
import pyarrow as pa
import pytest
from httpx import HTTPStatusError
from pandas.testing import assert_frame_equal
from sqlalchemy import select
from strawberry.relay import GlobalID

from phoenix.db import models
from phoenix.server.api.types.Dataset import Dataset
from phoenix.server.api.types.DatasetVersion import DatasetVersion
from phoenix.server.types import DbSessionFactory
async def test_get_simple_dataset(
httpx_client: httpx.AsyncClient,
simple_dataset: Any,
) -> None:
global_id = GlobalID("Dataset", str(0))
response = await httpx_client.get(f"/v1/datasets/{global_id}")
assert response.status_code == 200
dataset_json = response.json()["data"]
assert "created_at" in dataset_json
assert "updated_at" in dataset_json
fixture_values = {
"id": str(global_id),
"name": "simple dataset",
"description": None,
"metadata": {"info": "a test dataset"},
"example_count": 1,
}
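    # the expected fixture values must be a subset of the returned fields; extra
    # fields such as the timestamps checked above are allowed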
assert all(item in dataset_json.items() for item in fixture_values.items())
async def test_get_empty_dataset(
httpx_client: httpx.AsyncClient,
empty_dataset: Any,
) -> None:
global_id = GlobalID("Dataset", str(1))
response = await httpx_client.get(f"/v1/datasets/{global_id}")
assert response.status_code == 200
dataset_json = response.json()["data"]
assert "created_at" in dataset_json
assert "updated_at" in dataset_json
fixture_values = {
"id": str(global_id),
"name": "empty dataset",
"description": "emptied after two revisions",
"metadata": {},
"example_count": 0,
}
assert all(item in dataset_json.items() for item in fixture_values.items())
async def test_get_dataset_with_revisions(
httpx_client: httpx.AsyncClient,
dataset_with_revisions: Any,
) -> None:
global_id = GlobalID("Dataset", str(2))
response = await httpx_client.get(f"/v1/datasets/{global_id}")
assert response.status_code == 200
dataset_json = response.json()["data"]
assert "created_at" in dataset_json
assert "updated_at" in dataset_json
fixture_values = {
"id": str(global_id),
"name": "revised dataset",
"description": "this dataset grows over time",
"metadata": {},
"example_count": 3,
}
assert all(item in dataset_json.items() for item in fixture_values.items())
async def test_list_datasets(
httpx_client: httpx.AsyncClient,
simple_dataset: Any,
empty_dataset: Any,
dataset_with_revisions: Any,
) -> None:
response = await httpx_client.get("/v1/datasets")
assert response.status_code == 200
datasets_json = response.json()
assert datasets_json["next_cursor"] is None, "no next cursor when all datasets are returned"
datasets = datasets_json["data"]
assert len(datasets) == 3
# datasets are returned in reverse order of insertion
assert "created_at" in datasets[0]
assert "updated_at" in datasets[0]
fixture_values: dict[str, Any] = {
"id": str(GlobalID("Dataset", str(2))),
"name": "revised dataset",
"description": "this dataset grows over time",
"metadata": {},
}
assert all(item in datasets[0].items() for item in fixture_values.items())
assert "created_at" in datasets[1]
assert "updated_at" in datasets[1]
fixture_values = {
"id": str(GlobalID("Dataset", str(1))),
"name": "empty dataset",
"description": "emptied after two revisions",
"metadata": {},
}
assert all(item in datasets[1].items() for item in fixture_values.items())
assert "created_at" in datasets[2]
assert "updated_at" in datasets[2]
fixture_values = {
"id": str(GlobalID("Dataset", str(0))),
"name": "simple dataset",
"description": None,
"metadata": {"info": "a test dataset"},
}
assert all(item in datasets[2].items() for item in fixture_values.items())
async def test_list_fewer_datasets(
httpx_client: httpx.AsyncClient,
simple_dataset: Any,
empty_dataset: Any,
) -> None:
response = await httpx_client.get("/v1/datasets")
assert response.status_code == 200
datasets_json = response.json()
assert datasets_json["next_cursor"] is None, "no next cursor when all datasets are returned"
datasets = datasets_json["data"]
assert len(datasets) == 2
# datasets are returned in reverse order of insertion
assert "created_at" in datasets[0]
assert "updated_at" in datasets[0]
fixture_values: dict[str, Any] = {
"id": str(GlobalID("Dataset", str(1))),
"name": "empty dataset",
"description": "emptied after two revisions",
"metadata": {},
}
assert all(item in datasets[0].items() for item in fixture_values.items())
assert "created_at" in datasets[1]
assert "updated_at" in datasets[1]
fixture_values = {
"id": str(GlobalID("Dataset", str(0))),
"name": "simple dataset",
"description": None,
"metadata": {"info": "a test dataset"},
}
assert all(item in datasets[1].items() for item in fixture_values.items())
async def test_list_datasets_with_cursor(
httpx_client: httpx.AsyncClient,
simple_dataset: Any,
empty_dataset: Any,
dataset_with_revisions: Any,
) -> None:
response = await httpx_client.get("/v1/datasets", params={"limit": 2})
assert response.status_code == 200
datasets_json = response.json()
next_cursor = datasets_json["next_cursor"]
assert next_cursor, "next_cursor supplied when datasets remain"
datasets = datasets_json["data"]
assert len(datasets) == 2, "only return two datasets when limit is set to 2"
# datasets are returned in reverse order of insertion
assert "created_at" in datasets[0]
assert "updated_at" in datasets[0]
fixture_values: dict[str, Any] = {
"id": str(GlobalID("Dataset", str(2))),
"name": "revised dataset",
"description": "this dataset grows over time",
"metadata": {},
}
assert all(item in datasets[0].items() for item in fixture_values.items())
assert "created_at" in datasets[1]
assert "updated_at" in datasets[1]
fixture_values = {
"id": str(GlobalID("Dataset", str(1))),
"name": "empty dataset",
"description": "emptied after two revisions",
"metadata": {},
}
assert all(item in datasets[1].items() for item in fixture_values.items())
second_page = await httpx_client.get("/v1/datasets", params={"limit": 2, "cursor": next_cursor})
assert second_page.status_code == 200
second_page_json = second_page.json()
assert second_page_json["next_cursor"] is None, "no next cursor after all datasets are returned"
second_page_datasets = second_page_json["data"]
assert len(second_page_datasets) == 1, "only return one dataset on the second page"
assert "created_at" in second_page_datasets[0]
assert "updated_at" in second_page_datasets[0]
fixture_values = {
"id": str(GlobalID("Dataset", str(0))),
"name": "simple dataset",
"description": None,
"metadata": {"info": "a test dataset"},
}
assert all(item in second_page_datasets[0].items() for item in fixture_values.items())
async def test_get_dataset_versions(
httpx_client: httpx.AsyncClient,
dataset_with_revisions: Any,
) -> None:
dataset_global_id = GlobalID("Dataset", str(2))
response = await httpx_client.get(f"/v1/datasets/{dataset_global_id}/versions?limit=2")
assert response.status_code == 200
assert response.headers.get("content-type") == "application/json"
assert response.json() == {
"next_cursor": f"{GlobalID('DatasetVersion', str(7))}",
"data": [
{
"version_id": str(GlobalID("DatasetVersion", str(9))),
"description": "datum gets deleted",
"metadata": {},
"created_at": "2024-05-28T00:00:09+00:00",
},
{
"version_id": str(GlobalID("DatasetVersion", str(8))),
"description": "datum gets created",
"metadata": {},
"created_at": "2024-05-28T00:00:08+00:00",
},
],
}
async def test_get_dataset_versions_with_cursor(
httpx_client: httpx.AsyncClient,
dataset_with_revisions: Any,
) -> None:
dataset_global_id = GlobalID("Dataset", str(2))
response = await httpx_client.get(
f"/v1/datasets/{dataset_global_id}/versions?limit=2"
f"&cursor={GlobalID('DatasetVersion', str(4))}"
)
assert response.status_code == 200
assert response.headers.get("content-type") == "application/json"
assert response.json() == {
"next_cursor": None,
"data": [
{
"version_id": str(GlobalID("DatasetVersion", str(4))),
"created_at": "2024-05-28T00:00:04+00:00",
"description": "data gets added",
"metadata": {"info": "gotta get some test data somewhere"},
},
],
}
async def test_get_dataset_versions_with_nonexistent_cursor(
httpx_client: httpx.AsyncClient,
dataset_with_revisions: Any,
) -> None:
dataset_global_id = GlobalID("Dataset", str(2))
response = await httpx_client.get(
f"/v1/datasets/{dataset_global_id}/versions?limit=1"
f"&cursor={GlobalID('DatasetVersion', str(1))}"
)
assert response.status_code == 200
assert response.headers.get("content-type") == "application/json"
assert response.json() == {"next_cursor": None, "data": []}
async def test_get_dataset_download_empty_dataset(
httpx_client: httpx.AsyncClient,
empty_dataset: Any,
) -> None:
dataset_global_id = GlobalID("Dataset", str(1))
response = await httpx_client.get(f"/v1/datasets/{dataset_global_id}/csv")
assert response.status_code == 200
assert response.headers.get("content-type") == "text/csv"
assert response.headers.get("content-encoding") == "gzip"
assert (
response.headers.get("content-disposition")
== "attachment; filename*=UTF-8''empty%20dataset.csv"
)
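    # httpx has already decompressed the gzipped body; an empty dataset exports an
    # empty CSV, which pandas should fail to parse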
with pytest.raises(Exception):
pd.read_csv(StringIO(response.content.decode()))
async def test_get_dataset_download_nonexistent_version(
httpx_client: httpx.AsyncClient,
empty_dataset: Any,
dataset_with_revisions: Any,
) -> None:
dataset_global_id = GlobalID("Dataset", str(1))
dataset_version_global_id = GlobalID("DatasetVersion", str(4)) # Version for Dataset id=2
response = await httpx_client.get(
f"/v1/datasets/{dataset_global_id}/csv?version_id={dataset_version_global_id}"
)
assert response.status_code == 200
assert response.headers.get("content-type") == "text/csv"
assert response.headers.get("content-encoding") == "gzip"
assert (
response.headers.get("content-disposition")
== "attachment; filename*=UTF-8''empty%20dataset.csv"
)
with pytest.raises(Exception):
pd.read_csv(StringIO(response.content.decode()))
async def test_get_dataset_download_latest_version(
httpx_client: httpx.AsyncClient,
dataset_with_revisions: Any,
) -> None:
dataset_global_id = GlobalID("Dataset", str(2))
response = await httpx_client.get(f"/v1/datasets/{dataset_global_id}/csv")
assert response.status_code == 200
assert response.headers.get("content-type") == "text/csv"
assert response.headers.get("content-encoding") == "gzip"
assert (
response.headers.get("content-disposition")
== "attachment; filename*=UTF-8''revised%20dataset.csv"
)
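    # the exported CSV flattens input/output/metadata fields into prefixed columns;
    # example_id holds each example's base64-encoded GlobalID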
actual = pd.read_csv(StringIO(response.content.decode())).sort_index(axis=1)
expected = pd.read_csv(
StringIO(
"example_id,input_in,metadata_info,output_out\n"
"RGF0YXNldEV4YW1wbGU6Mw==,foo,first revision,bar\n"
"RGF0YXNldEV4YW1wbGU6NA==,updated foofoo,updating revision,updated barbar\n"
"RGF0YXNldEV4YW1wbGU6NQ==,look at me,a new example,i have all the answers\n"
)
).sort_index(axis=1)
assert_frame_equal(actual, expected)
async def test_get_dataset_download_specific_version(
httpx_client: httpx.AsyncClient,
dataset_with_revisions: Any,
) -> None:
dataset_global_id = GlobalID("Dataset", str(2))
dataset_version_global_id = GlobalID("DatasetVersion", str(8))
response = await httpx_client.get(
f"/v1/datasets/{dataset_global_id}/csv?version_id={dataset_version_global_id}"
)
assert response.status_code == 200
assert response.headers.get("content-type") == "text/csv"
assert response.headers.get("content-encoding") == "gzip"
assert (
response.headers.get("content-disposition")
== "attachment; filename*=UTF-8''revised%20dataset.csv"
)
actual = pd.read_csv(StringIO(response.content.decode())).sort_index(axis=1)
expected = pd.read_csv(
StringIO(
"example_id,input_in,metadata_info,output_out\n"
"RGF0YXNldEV4YW1wbGU6Mw==,foo,first revision,bar\n"
"RGF0YXNldEV4YW1wbGU6NA==,updated foofoo,updating revision,updated barbar\n"
"RGF0YXNldEV4YW1wbGU6NQ==,look at me,a new example,i have all the answers\n"
"RGF0YXNldEV4YW1wbGU6Nw==,look at me,a newer example,i have all the answers\n"
)
).sort_index(axis=1)
assert_frame_equal(actual, expected)
async def test_get_dataset_jsonl_openai_ft(
httpx_client: httpx.AsyncClient,
dataset_with_messages: tuple[int, int],
) -> None:
dataset_id, dataset_version_id = dataset_with_messages
dataset_global_id = GlobalID(Dataset.__name__, str(dataset_id))
dataset_version_global_id = GlobalID(DatasetVersion.__name__, str(dataset_version_id))
response = await httpx_client.get(
f"/v1/datasets/{dataset_global_id}/jsonl/openai_ft?version_id={dataset_version_global_id}"
)
assert response.status_code == 200
assert response.headers.get("content-type") == "text/plain; charset=utf-8"
assert response.headers.get("content-encoding") == "gzip"
assert response.headers.get("content-disposition") == "attachment; filename*=UTF-8''xyz.jsonl"
json_lines = io.StringIO(response.text).readlines()
assert len(json_lines) == 2
assert json.loads(json_lines[0]) == {
"messages": [
{"role": "system", "content": "x"},
{"role": "user", "content": "y"},
{"role": "assistant", "content": "z"},
]
}
assert json.loads(json_lines[1]) == {
"messages": [
{"role": "system", "content": "xx"},
{"role": "user", "content": "yy"},
{"role": "assistant", "content": "zz"},
]
}
async def test_get_dataset_jsonl_openai_evals(
httpx_client: httpx.AsyncClient, dataset_with_messages: tuple[int, int]
) -> None:
dataset_id, dataset_version_id = dataset_with_messages
dataset_global_id = GlobalID(Dataset.__name__, str(dataset_id))
dataset_version_global_id = GlobalID(DatasetVersion.__name__, str(dataset_version_id))
response = await httpx_client.get(
f"/v1/datasets/{dataset_global_id}/jsonl/openai_evals?version_id={dataset_version_global_id}"
)
assert response.status_code == 200
assert response.headers.get("content-type") == "text/plain; charset=utf-8"
assert response.headers.get("content-encoding") == "gzip"
assert response.headers.get("content-disposition") == "attachment; filename*=UTF-8''xyz.jsonl"
json_lines = io.StringIO(response.text).readlines()
assert len(json_lines) == 2
assert json.loads(json_lines[0]) == {
"messages": [{"role": "system", "content": "x"}, {"role": "user", "content": "y"}],
"ideal": "z",
}
assert json.loads(json_lines[1]) == {
"messages": [{"role": "system", "content": "xx"}, {"role": "user", "content": "yy"}],
"ideal": "zz",
}
async def test_post_dataset_upload_json_create_then_append(
httpx_client: httpx.AsyncClient,
db: DbSessionFactory,
) -> None:
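    # use the current test function's name as the dataset name so the dataset is unique to this test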
name = inspect.stack()[0][3]
response = await httpx_client.post(
url="v1/datasets/upload?sync=true",
json={
"action": "create",
"name": name,
"inputs": [{"a": 1, "b": 2, "c": 3}],
"outputs": [{"b": "2", "c": "3", "d": "4"}],
"metadata": [{"c": 3, "d": 4, "e": 5}],
},
)
assert response.status_code == 200
assert (data := response.json().get("data"))
assert (dataset_id := data.get("dataset_id"))
assert "version_id" in data
version_id_str = data["version_id"]
version_global_id = GlobalID.from_id(version_id_str)
assert version_global_id.type_name == "DatasetVersion"
del response, data
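    # appending under the same dataset name should add examples to the existing
    # dataset and create a new version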
response = await httpx_client.post(
url="v1/datasets/upload?sync=true",
json={
"action": "append",
"name": name,
"inputs": [{"a": 11, "b": 22, "c": 33}],
"outputs": [{"b": "22", "c": "33", "d": "44"}],
"metadata": [],
},
)
assert response.status_code == 200
assert (data := response.json().get("data"))
assert dataset_id == data.get("dataset_id")
assert "version_id" in data
version_id_str = data["version_id"]
version_global_id = GlobalID.from_id(version_id_str)
assert version_global_id.type_name == "DatasetVersion"
async with db() as session:
revisions = list(
await session.scalars(
select(models.DatasetExampleRevision)
.join(models.DatasetExample)
.join_from(models.DatasetExample, models.Dataset)
.where(models.Dataset.name == name)
.order_by(models.DatasetExample.id)
)
)
assert len(revisions) == 2
assert revisions[0].input == {"a": 1, "b": 2, "c": 3}
assert revisions[0].output == {"b": "2", "c": "3", "d": "4"}
assert revisions[0].metadata_ == {"c": 3, "d": 4, "e": 5}
assert revisions[1].input == {"a": 11, "b": 22, "c": 33}
assert revisions[1].output == {"b": "22", "c": "33", "d": "44"}
assert revisions[1].metadata_ == {}
# Verify the DatasetVersion from the response
db_dataset_version = await session.get(models.DatasetVersion, int(version_global_id.node_id))
assert db_dataset_version is not None
assert db_dataset_version.dataset_id == int(GlobalID.from_id(dataset_id).node_id)
async def test_post_dataset_upload_csv_create_then_append(
httpx_client: httpx.AsyncClient,
db: DbSessionFactory,
) -> None:
name = inspect.stack()[0][3]
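    # gzip-compressed CSV payload; input/output/metadata keys may overlap, and CSV
    # values are ingested as strings (see the revision assertions below)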
file = gzip.compress(b"a,b,c,d,e,f\n1,2,3,4,5,6\n")
response = await httpx_client.post(
url="v1/datasets/upload?sync=true",
files={"file": (" ", file, "text/csv", {"Content-Encoding": "gzip"})},
data={
"action": "create",
"name": name,
"input_keys[]": ["a", "b", "c"],
"output_keys[]": ["b", "c", "d"],
"metadata_keys[]": ["c", "d", "e"],
},
)
assert response.status_code == 200
assert (data := response.json().get("data"))
assert (dataset_id := data.get("dataset_id"))
assert "version_id" in data
version_id_str = data["version_id"]
version_global_id = GlobalID.from_id(version_id_str)
assert version_global_id.type_name == "DatasetVersion"
del response, file, data
file = gzip.compress(b"a,b,c,d,e,f\n11,22,33,44,55,66\n")
response = await httpx_client.post(
url="v1/datasets/upload?sync=true",
files={"file": (" ", file, "text/csv", {"Content-Encoding": "gzip"})},
data={
"action": "append",
"name": name,
"input_keys[]": ["a", "b", "c"],
"output_keys[]": ["b", "c", "d"],
"metadata_keys[]": ["c", "d", "e"],
},
)
assert response.status_code == 200
assert (data := response.json().get("data"))
assert dataset_id == data.get("dataset_id")
assert "version_id" in data
version_id_str = data["version_id"]
version_global_id = GlobalID.from_id(version_id_str)
assert version_global_id.type_name == "DatasetVersion"
async with db() as session:
revisions = list(
await session.scalars(
select(models.DatasetExampleRevision)
.join(models.DatasetExample)
.join_from(models.DatasetExample, models.Dataset)
.where(models.Dataset.name == name)
.order_by(models.DatasetExample.id)
)
)
assert len(revisions) == 2
assert revisions[0].input == {"a": "1", "b": "2", "c": "3"}
assert revisions[0].output == {"b": "2", "c": "3", "d": "4"}
assert revisions[0].metadata_ == {"c": "3", "d": "4", "e": "5"}
assert revisions[1].input == {"a": "11", "b": "22", "c": "33"}
assert revisions[1].output == {"b": "22", "c": "33", "d": "44"}
assert revisions[1].metadata_ == {"c": "33", "d": "44", "e": "55"}
# Verify the DatasetVersion from the response
db_dataset_version = await session.get(models.DatasetVersion, int(version_global_id.node_id))
assert db_dataset_version is not None
assert db_dataset_version.dataset_id == int(GlobalID.from_id(dataset_id).node_id)
async def test_post_dataset_upload_pyarrow_create_then_append(
httpx_client: httpx.AsyncClient,
db: DbSessionFactory,
) -> None:
name = inspect.stack()[0][3]
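    # serialize the DataFrame as an Arrow IPC stream; unlike CSV, Arrow preserves
    # the integer dtypes asserted below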
df = pd.read_csv(StringIO("a,b,c,d,e,f\n1,2,3,4,5,6\n"))
table = pa.Table.from_pandas(df)
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
writer.write_table(table)
file = BytesIO(sink.getvalue().to_pybytes())
response = await httpx_client.post(
url="v1/datasets/upload?sync=true",
files={"file": (" ", file, "application/x-pandas-pyarrow", {})},
data={
"action": "create",
"name": name,
"input_keys[]": ["a", "b", "c"],
"output_keys[]": ["b", "c", "d"],
"metadata_keys[]": ["c", "d", "e"],
},
)
assert response.status_code == 200
assert (data := response.json().get("data"))
assert (dataset_id := data.get("dataset_id"))
assert "version_id" in data
version_id_str = data["version_id"]
version_global_id = GlobalID.from_id(version_id_str)
assert version_global_id.type_name == "DatasetVersion"
del response, file, data, df, table, sink
df = pd.read_csv(StringIO("a,b,c,d,e,f\n11,22,33,44,55,66\n"))
table = pa.Table.from_pandas(df)
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
writer.write_table(table)
file = BytesIO(sink.getvalue().to_pybytes())
response = await httpx_client.post(
url="v1/datasets/upload?sync=true",
files={"file": (" ", file, "application/x-pandas-pyarrow", {})},
data={
"action": "append",
"name": name,
"input_keys[]": ["a", "b", "c"],
"output_keys[]": ["b", "c", "d"],
"metadata_keys[]": ["c", "d", "e"],
},
)
assert response.status_code == 200
assert (data := response.json().get("data"))
assert dataset_id == data.get("dataset_id")
assert "version_id" in data
version_id_str = data["version_id"]
version_global_id = GlobalID.from_id(version_id_str)
assert version_global_id.type_name == "DatasetVersion"
async with db() as session:
revisions = list(
await session.scalars(
select(models.DatasetExampleRevision)
.join(models.DatasetExample)
.join_from(models.DatasetExample, models.Dataset)
.where(models.Dataset.name == name)
.order_by(models.DatasetExample.id)
)
)
assert len(revisions) == 2
assert revisions[0].input == {"a": 1, "b": 2, "c": 3}
assert revisions[0].output == {"b": 2, "c": 3, "d": 4}
assert revisions[0].metadata_ == {"c": 3, "d": 4, "e": 5}
assert revisions[1].input == {"a": 11, "b": 22, "c": 33}
assert revisions[1].output == {"b": 22, "c": 33, "d": 44}
assert revisions[1].metadata_ == {"c": 33, "d": 44, "e": 55}
# Verify the DatasetVersion from the response
db_dataset_version = await session.get(models.DatasetVersion, int(version_global_id.node_id))
assert db_dataset_version is not None
assert db_dataset_version.dataset_id == int(GlobalID.from_id(dataset_id).node_id)
async def test_delete_dataset(
httpx_client: httpx.AsyncClient,
empty_dataset: Any,
) -> None:
url = f"v1/datasets/{GlobalID(Dataset.__name__, str(1))}"
assert len((await httpx_client.get("v1/datasets")).json()["data"]) > 0
(await httpx_client.delete(url)).raise_for_status()
assert len((await httpx_client.get("v1/datasets")).json()["data"]) == 0
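    # deleting the same dataset again should fail, surfacing as an HTTPStatusError via raise_for_status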
with pytest.raises(HTTPStatusError):
(await httpx_client.delete(url)).raise_for_status()
async def test_get_dataset_examples_404s_with_nonexistent_dataset_id(
httpx_client: httpx.AsyncClient,
) -> None:
global_id = GlobalID("Dataset", str(0))
response = await httpx_client.get(f"/v1/datasets/{global_id}/examples")
assert response.status_code == 404
assert response.content.decode() == f"No dataset with id {global_id} can be found."
async def test_get_dataset_examples_404s_with_invalid_global_id(
httpx_client: httpx.AsyncClient,
simple_dataset: Any,
) -> None:
global_id = GlobalID("InvalidDataset", str(0))
response = await httpx_client.get(f"/v1/datasets/{global_id}/examples")
assert response.status_code == 404
assert "refers to a InvalidDataset" in response.content.decode()
async def test_get_dataset_examples_404s_with_nonexistent_version_id(
httpx_client: httpx.AsyncClient,
simple_dataset: Any,
) -> None:
global_id = GlobalID("Dataset", str(0))
version_id = GlobalID("DatasetVersion", str(99))
response = await httpx_client.get(
f"/v1/datasets/{global_id}/examples", params={"version_id": str(version_id)}
)
assert response.status_code == 404
assert response.content.decode() == f"No dataset version with id {version_id} can be found."
async def test_get_dataset_examples_404s_with_invalid_version_global_id(
httpx_client: httpx.AsyncClient,
simple_dataset: Any,
) -> None:
global_id = GlobalID("Dataset", str(0))
version_id = GlobalID("InvalidDatasetVersion", str(0))
response = await httpx_client.get(
f"/v1/datasets/{global_id}/examples", params={"version_id": str(version_id)}
)
assert response.status_code == 404
assert "refers to a InvalidDatasetVersion" in response.content.decode()
async def test_get_simple_dataset_examples(
httpx_client: httpx.AsyncClient,
simple_dataset: Any,
) -> None:
global_id = GlobalID("Dataset", str(0))
response = await httpx_client.get(f"/v1/datasets/{global_id}/examples")
assert response.status_code == 200
result = response.json()
data = result["data"]
assert data["dataset_id"] == str(GlobalID("Dataset", str(0)))
assert data["version_id"] == str(GlobalID("DatasetVersion", str(0)))
examples = data["examples"]
assert len(examples) == 1
expected_examples = [
{
"id": str(GlobalID("DatasetExample", str(0))),
"input": {"in": "foo"},
"output": {"out": "bar"},
"metadata": {"info": "the first reivision"},
}
]
for example, expected in zip(examples, expected_examples):
assert "updated_at" in example
example_subset = {k: v for k, v in example.items() if k in expected}
assert example_subset == expected
async def test_list_simple_dataset_examples_at_each_version(
httpx_client: httpx.AsyncClient,
simple_dataset: Any,
) -> None:
global_id = GlobalID("Dataset", str(0))
v0 = GlobalID("DatasetVersion", str(0))
# one example is created in version 0
response = await httpx_client.get(
f"/v1/datasets/{global_id}/examples", params={"version_id": str(v0)}
)
assert response.status_code == 200
result = response.json()
data = result["data"]
assert len(data["examples"]) == 1
async def test_list_empty_dataset_examples(
httpx_client: httpx.AsyncClient,
empty_dataset: Any,
) -> None:
global_id = GlobalID("Dataset", str(1))
response = await httpx_client.get(f"/v1/datasets/{global_id}/examples")
assert response.status_code == 200
result = response.json()
data = result["data"]
assert len(data["examples"]) == 0
async def test_list_empty_dataset_examples_at_each_version(
httpx_client: httpx.AsyncClient,
empty_dataset: Any,
) -> None:
global_id = GlobalID("Dataset", str(1))
v1 = GlobalID("DatasetVersion", str(1))
v2 = GlobalID("DatasetVersion", str(2))
v3 = GlobalID("DatasetVersion", str(3))
# two examples are created in version 1
response = await httpx_client.get(
f"/v1/datasets/{global_id}/examples", params={"version_id": str(v1)}
)
assert response.status_code == 200
result = response.json()
data = result["data"]
assert len(data["examples"]) == 2
# two examples are patched in version 2
response = await httpx_client.get(
f"/v1/datasets/{global_id}/examples", params={"version_id": str(v2)}
)
assert response.status_code == 200
result = response.json()
data = result["data"]
assert len(data["examples"]) == 2
# two examples are deleted in version 3
response = await httpx_client.get(
f"/v1/datasets/{global_id}/examples", params={"version_id": str(v3)}
)
assert response.status_code == 200
result = response.json()
data = result["data"]
assert len(data["examples"]) == 0
async def test_list_dataset_with_revisions_examples(
httpx_client: httpx.AsyncClient,
dataset_with_revisions: Any,
) -> None:
global_id = GlobalID("Dataset", str(2))
response = await httpx_client.get(f"/v1/datasets/{global_id}/examples")
assert response.status_code == 200
result = response.json()
data = result["data"]
assert data["dataset_id"] == str(GlobalID("Dataset", str(2)))
assert data["version_id"] == str(GlobalID("DatasetVersion", str(9)))
examples = data["examples"]
assert len(examples) == 3
expected_values = [
{
"id": str(GlobalID("DatasetExample", str(3))),
"input": {"in": "foo"},
"output": {"out": "bar"},
"metadata": {"info": "first revision"},
},
{
"id": str(GlobalID("DatasetExample", str(4))),
"input": {"in": "updated foofoo"},
"output": {"out": "updated barbar"},
"metadata": {"info": "updating revision"},
},
{
"id": str(GlobalID("DatasetExample", str(5))),
"input": {"in": "look at me"},
"output": {"out": "i have all the answers"},
"metadata": {"info": "a new example"},
},
]
for example, expected in zip(examples, expected_values):
assert "updated_at" in example
example_subset = {k: v for k, v in example.items() if k in expected}
assert example_subset == expected
async def test_list_dataset_with_revisions_examples_at_each_version(
httpx_client: httpx.AsyncClient,
dataset_with_revisions: Any,
) -> None:
global_id = GlobalID("Dataset", str(2))
v4 = GlobalID("DatasetVersion", str(4))
v5 = GlobalID("DatasetVersion", str(5))
v6 = GlobalID("DatasetVersion", str(6))
v7 = GlobalID("DatasetVersion", str(7))
v8 = GlobalID("DatasetVersion", str(8))
v9 = GlobalID("DatasetVersion", str(9))
# two examples are created in version 4
response = await httpx_client.get(
f"/v1/datasets/{global_id}/examples", params={"version_id": str(v4)}
)
assert response.status_code == 200
result = response.json()
data = result["data"]
assert len(data["examples"]) == 2
# two examples are patched in version 5
response = await httpx_client.get(
f"/v1/datasets/{global_id}/examples", params={"version_id": str(v5)}
)
assert response.status_code == 200
result = response.json()
data = result["data"]
assert len(data["examples"]) == 3
# one example is added in version 6
response = await httpx_client.get(
f"/v1/datasets/{global_id}/examples", params={"version_id": str(v6)}
)
assert response.status_code == 200
result = response.json()
data = result["data"]
assert len(data["examples"]) == 4
# one example is deleted in version 7
response = await httpx_client.get(
f"/v1/datasets/{global_id}/examples", params={"version_id": str(v7)}
)
assert response.status_code == 200
result = response.json()
data = result["data"]
assert len(data["examples"]) == 3
# one example is added in version 8
response = await httpx_client.get(
f"/v1/datasets/{global_id}/examples", params={"version_id": str(v8)}
)
assert response.status_code == 200
result = response.json()
data = result["data"]
assert len(data["examples"]) == 4
# one example is deleted in version 9
response = await httpx_client.get(
f"/v1/datasets/{global_id}/examples", params={"version_id": str(v9)}
)
assert response.status_code == 200
result = response.json()
data = result["data"]
assert len(data["examples"]) == 3
async def test_post_dataset_upload_json_with_splits(
httpx_client: httpx.AsyncClient,
db: DbSessionFactory,
) -> None:
"""Test JSON upload with various split formats: string, list, null, and mixed."""
name = inspect.stack()[0][3]
response = await httpx_client.post(
url="v1/datasets/upload?sync=true",
json={
"action": "create",
"name": name,
"inputs": [{"q": "Q1"}, {"q": "Q2"}, {"q": "Q3"}, {"q": "Q4"}, {"q": "Q5"}],
"outputs": [{"a": "A1"}, {"a": "A2"}, {"a": "A3"}, {"a": "A4"}, {"a": "A5"}],
"splits": [
"train", # Single string
["test", "hard"], # List of strings
None, # No splits
["validate", None, "medium", ""], # List with nulls/empty (should filter)
" ", # Whitespace-only (should filter)
],
},
)
assert response.status_code == 200
assert (data := response.json().get("data"))
assert (dataset_id := data.get("dataset_id"))
async with db() as session:
# Verify correct splits were created (empty/whitespace/nulls filtered)
splits = list(
await session.scalars(select(models.DatasetSplit).order_by(models.DatasetSplit.name))
)
assert set(s.name for s in splits) == {"train", "test", "hard", "validate", "medium"}
dataset_db_id = int(GlobalID.from_id(dataset_id).node_id)
examples = list(
await session.scalars(
select(models.DatasetExample)
.where(models.DatasetExample.dataset_id == dataset_db_id)
.order_by(models.DatasetExample.id)
)
)
assert len(examples) == 5
# Helper to get split names for an example
async def get_example_splits(example_id: int) -> set[str]:
result = await session.scalars(
select(models.DatasetSplit)
.join(models.DatasetSplitDatasetExample)
.where(models.DatasetSplitDatasetExample.dataset_example_id == example_id)
)
return {s.name for s in result}
assert await get_example_splits(examples[0].id) == {"train"}
assert await get_example_splits(examples[1].id) == {"test", "hard"}
assert await get_example_splits(examples[2].id) == set() # None -> no splits
assert await get_example_splits(examples[3].id) == {"validate", "medium"} # nulls filtered
assert await get_example_splits(examples[4].id) == set() # Whitespace-only -> no splits
async def test_post_dataset_upload_csv_with_splits(
httpx_client: httpx.AsyncClient,
db: DbSessionFactory,
) -> None:
"""Test CSV upload with split_keys, including whitespace stripping."""
name = inspect.stack()[0][3]
file = gzip.compress(
b"question,answer,data_split,category\n"
b"Q1,A1, train ,general\n" # Whitespace should be stripped
b"Q2,A2,test,technical\n"
)
response = await httpx_client.post(
url="v1/datasets/upload?sync=true",
files={"file": (" ", file, "text/csv", {"Content-Encoding": "gzip"})},
data={
"action": "create",
"name": name,
"input_keys[]": ["question"],
"output_keys[]": ["answer"],
"split_keys[]": ["data_split", "category"],
},
)
assert response.status_code == 200
assert (data := response.json().get("data"))
assert (dataset_id := data.get("dataset_id"))
async with db() as session:
splits = list(
await session.scalars(select(models.DatasetSplit).order_by(models.DatasetSplit.name))
)
# Verify whitespace was stripped and splits have default color
assert set(s.name for s in splits) == {"train", "test", "general", "technical"}
assert " train " not in [s.name for s in splits]
assert all(s.color == "#808080" for s in splits)
# Verify example assignments
dataset_db_id = int(GlobalID.from_id(dataset_id).node_id)
examples = list(
await session.scalars(
select(models.DatasetExample)
.where(models.DatasetExample.dataset_id == dataset_db_id)
.order_by(models.DatasetExample.id)
)
)
assert len(examples) == 2
async def get_example_splits(example_id: int) -> set[str]:
result = await session.scalars(
select(models.DatasetSplit)
.join(models.DatasetSplitDatasetExample)
.where(models.DatasetSplitDatasetExample.dataset_example_id == example_id)
)
return {s.name for s in result}
assert await get_example_splits(examples[0].id) == {"train", "general"}
assert await get_example_splits(examples[1].id) == {"test", "technical"}
async def test_post_dataset_upload_pyarrow_with_splits(
httpx_client: httpx.AsyncClient,
db: DbSessionFactory,
) -> None:
"""Test PyArrow upload with split_keys."""
name = inspect.stack()[0][3]
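    # build the same Arrow IPC payload as in the plain PyArrow upload test, now including a split column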
df = pd.read_csv(StringIO("question,answer,data_split\nQ1,A1,train\nQ2,A2,test\n"))
table = pa.Table.from_pandas(df)
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
writer.write_table(table)
file = BytesIO(sink.getvalue().to_pybytes())
response = await httpx_client.post(
url="v1/datasets/upload?sync=true",
files={"file": (" ", file, "application/x-pandas-pyarrow", {})},
data={
"action": "create",
"name": name,
"input_keys[]": ["question"],
"output_keys[]": ["answer"],
"split_keys[]": ["data_split"],
},
)
assert response.status_code == 200
assert (data := response.json().get("data"))
assert (dataset_id := data.get("dataset_id"))
async with db() as session:
splits = list(
await session.scalars(select(models.DatasetSplit).order_by(models.DatasetSplit.name))
)
assert set(s.name for s in splits) == {"train", "test"}
# Verify example assignments
dataset_db_id = int(GlobalID.from_id(dataset_id).node_id)
examples = list(
await session.scalars(
select(models.DatasetExample)
.where(models.DatasetExample.dataset_id == dataset_db_id)
.order_by(models.DatasetExample.id)
)
)
assert len(examples) == 2
async def get_example_splits(example_id: int) -> set[str]:
result = await session.scalars(
select(models.DatasetSplit)
.join(models.DatasetSplitDatasetExample)
.where(models.DatasetSplitDatasetExample.dataset_example_id == example_id)
)
return {s.name for s in result}
assert await get_example_splits(examples[0].id) == {"train"}
assert await get_example_splits(examples[1].id) == {"test"}
async def test_post_dataset_upload_reuses_existing_splits(
httpx_client: httpx.AsyncClient,
db: DbSessionFactory,
) -> None:
"""Test that uploading datasets reuses existing splits instead of creating duplicates."""
name1 = "dataset_with_split_1"
name2 = "dataset_with_split_2"
# Create first dataset with split
response1 = await httpx_client.post(
url="v1/datasets/upload?sync=true",
json={
"action": "create",
"name": name1,
"inputs": [{"question": "Q1"}],
"outputs": [{"answer": "A1"}],
"splits": ["train"],
},
)
assert response1.status_code == 200
dataset1_id = response1.json()["data"]["dataset_id"]
# Get split count
async with db() as session:
splits_before = list(
await session.scalars(
select(models.DatasetSplit).where(models.DatasetSplit.name == "train")
)
)
assert len(splits_before) == 1
train_split_id = splits_before[0].id
# Create second dataset with same split name
response2 = await httpx_client.post(
url="v1/datasets/upload?sync=true",
json={
"action": "create",
"name": name2,
"inputs": [{"question": "Q2"}],
"outputs": [{"answer": "A2"}],
"splits": ["train"],
},
)
assert response2.status_code == 200
dataset2_id = response2.json()["data"]["dataset_id"]
# Verify split was reused, not duplicated
async with db() as session:
splits_after = list(
await session.scalars(
select(models.DatasetSplit).where(models.DatasetSplit.name == "train")
)
)
assert len(splits_after) == 1
assert splits_after[0].id == train_split_id # Same split ID
# Verify both datasets' examples are assigned to the split
dataset1_db_id = int(GlobalID.from_id(dataset1_id).node_id)
dataset2_db_id = int(GlobalID.from_id(dataset2_id).node_id)
# Get examples from both datasets
dataset1_examples = list(
await session.scalars(
select(models.DatasetExample).where(
models.DatasetExample.dataset_id == dataset1_db_id
)
)
)
dataset2_examples = list(
await session.scalars(
select(models.DatasetExample).where(
models.DatasetExample.dataset_id == dataset2_db_id
)
)
)
assert len(dataset1_examples) == 1
assert len(dataset2_examples) == 1
# Verify first dataset's example is assigned to train split
dataset1_splits = list(
await session.scalars(
select(models.DatasetSplit)
.join(models.DatasetSplitDatasetExample)
.where(
models.DatasetSplitDatasetExample.dataset_example_id == dataset1_examples[0].id
)
)
)
assert len(dataset1_splits) == 1
assert dataset1_splits[0].name == "train"
# Verify second dataset's example is also assigned to the same train split
dataset2_splits = list(
await session.scalars(
select(models.DatasetSplit)
.join(models.DatasetSplitDatasetExample)
.where(
models.DatasetSplitDatasetExample.dataset_example_id == dataset2_examples[0].id
)
)
)
assert len(dataset2_splits) == 1
assert dataset2_splits[0].name == "train"
assert dataset2_splits[0].id == train_split_id # Same split instance
async def test_post_dataset_upload_rejects_invalid_split_formats(
httpx_client: httpx.AsyncClient,
) -> None:
"""Test that JSON upload rejects invalid split formats (dict, integer, boolean)."""
name = inspect.stack()[0][3]
# Test with dict split value (no longer supported)
response = await httpx_client.post(
url="v1/datasets/upload?sync=true",
json={
"action": "create",
"name": name,
"inputs": [{"question": "Q1"}],
"outputs": [{"answer": "A1"}],
"splits": [{"data_split": "train"}], # Dict format no longer supported
},
)
assert response.status_code == 422
assert "must be a string, list of strings, or None" in response.text
# Test with integer split value
response = await httpx_client.post(
url="v1/datasets/upload?sync=true",
json={
"action": "create",
"name": f"{name}_int",
"inputs": [{"question": "Q1"}],
"outputs": [{"answer": "A1"}],
"splits": [123], # Integer not allowed
},
)
assert response.status_code == 422
assert "must be a string, list of strings, or None" in response.text
# Test with boolean split value
response = await httpx_client.post(
url="v1/datasets/upload?sync=true",
json={
"action": "create",
"name": f"{name}_bool",
"inputs": [{"question": "Q1"}],
"outputs": [{"answer": "A1"}],
"splits": [True], # Boolean not allowed
},
)
assert response.status_code == 422
assert "must be a string, list of strings, or None" in response.text
async def test_post_dataset_upload_append_with_splits(
httpx_client: httpx.AsyncClient,
db: DbSessionFactory,
) -> None:
"""Test appending to a dataset with splits - both reusing existing and adding new splits."""
name = inspect.stack()[0][3]
# Create initial dataset with "train" split
response1 = await httpx_client.post(
url="v1/datasets/upload?sync=true",
json={
"action": "create",
"name": name,
"inputs": [{"q": "Q1"}],
"outputs": [{"a": "A1"}],
"splits": ["train"],
},
)
assert response1.status_code == 200
dataset_id = response1.json()["data"]["dataset_id"]
# Get initial split info
async with db() as session:
train_split = await session.scalar(
select(models.DatasetSplit).where(models.DatasetSplit.name == "train")
)
assert train_split is not None
train_split_id = train_split.id
# Append to the same dataset with existing "train" split and new "test" split
response2 = await httpx_client.post(
url="v1/datasets/upload?sync=true",
json={
"action": "append",
"name": name,
"inputs": [{"q": "Q2"}, {"q": "Q3"}],
"outputs": [{"a": "A2"}, {"a": "A3"}],
"splits": ["train", "test"], # Q2 -> train (existing), Q3 -> test (new)
},
)
assert response2.status_code == 200
assert response2.json()["data"]["dataset_id"] == dataset_id # Same dataset
async with db() as session:
dataset_db_id = int(GlobalID.from_id(dataset_id).node_id)
# Verify we have 3 examples total
examples = list(
await session.scalars(
select(models.DatasetExample)
.where(models.DatasetExample.dataset_id == dataset_db_id)
.order_by(models.DatasetExample.id)
)
)
assert len(examples) == 3
# Verify train split was reused (same ID)
train_split_after = await session.scalar(
select(models.DatasetSplit).where(models.DatasetSplit.name == "train")
)
assert train_split_after is not None
assert train_split_after.id == train_split_id
# Verify test split was created
test_split = await session.scalar(
select(models.DatasetSplit).where(models.DatasetSplit.name == "test")
)
assert test_split is not None
# Verify first example (from create) is assigned to train
ex1_splits = list(
await session.scalars(
select(models.DatasetSplit)
.join(models.DatasetSplitDatasetExample)
.where(models.DatasetSplitDatasetExample.dataset_example_id == examples[0].id)
)
)
assert len(ex1_splits) == 1
assert ex1_splits[0].name == "train"
# Verify second example (from append) is assigned to train
ex2_splits = list(
await session.scalars(
select(models.DatasetSplit)
.join(models.DatasetSplitDatasetExample)
.where(models.DatasetSplitDatasetExample.dataset_example_id == examples[1].id)
)
)
assert len(ex2_splits) == 1
assert ex2_splits[0].name == "train"
assert ex2_splits[0].id == train_split_id # Same split instance as before
# Verify third example (from append) is assigned to test
ex3_splits = list(
await session.scalars(
select(models.DatasetSplit)
.join(models.DatasetSplitDatasetExample)
.where(models.DatasetSplitDatasetExample.dataset_example_id == examples[2].id)
)
)
assert len(ex3_splits) == 1
assert ex3_splits[0].name == "test"