test_dataset_mutations.py
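Unit tests covering Phoenix's dataset GraphQL mutations (createDataset, patchDataset, addSpansToDataset, patchDatasetExamples, and deleteDataset), exercised through an async GraphQL test client against fixture-seeded database sessions.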
import textwrap
from datetime import datetime
from typing import Any

import pytest
import pytz
from sqlalchemy import insert, select
from strawberry.relay import GlobalID

from phoenix.config import DEFAULT_PROJECT_NAME
from phoenix.db import models
from phoenix.server.api.types.DatasetExample import DatasetExample
from phoenix.server.types import DbSessionFactory
from tests.unit.graphql import AsyncGraphQLClient


async def test_create_dataset(
    gql_client: AsyncGraphQLClient,
) -> None:
    create_dataset_mutation = """
      mutation ($name: String!, $description: String!, $metadata: JSON!) {
        createDataset(
          input: {name: $name, description: $description, metadata: $metadata}
        ) {
          dataset {
            id
            name
            description
            metadata
          }
        }
      }
    """
    response = await gql_client.execute(
        query=create_dataset_mutation,
        variables={
            "name": "original-dataset-name",
            "description": "original-dataset-description",
            "metadata": {"original-metadata-key": "original-metadata-value"},
        },
    )
    assert not response.errors
    assert response.data == {
        "createDataset": {
            "dataset": {
                "id": str(GlobalID(type_name="Dataset", node_id=str(1))),
                "name": "original-dataset-name",
                "description": "original-dataset-description",
                "metadata": {"original-metadata-key": "original-metadata-value"},
            }
        }
    }


class TestPatchDatasetMutation:
    MUTATION = """
      mutation ($datasetId: ID!, $name: String, $description: String, $metadata: JSON) {
        patchDataset(
          input: {datasetId: $datasetId, name: $name, description: $description, metadata: $metadata}
        ) {
          dataset {
            id
            name
            description
            metadata
          }
        }
      }
    """

    async def test_patch_all_dataset_fields(
        self,
        gql_client: AsyncGraphQLClient,
        dataset_with_a_single_version: None,
    ) -> None:
        response = await gql_client.execute(
            query=self.MUTATION,
            variables={
                "datasetId": str(GlobalID(type_name="Dataset", node_id=str(1))),
                "name": "patched-dataset-name",
                "description": "patched-dataset-description",
                "metadata": {"patched-metadata-key": "patched-metadata-value"},
            },
        )
        assert not response.errors
        assert response.data == {
            "patchDataset": {
                "dataset": {
                    "id": str(GlobalID(type_name="Dataset", node_id=str(1))),
                    "name": "patched-dataset-name",
                    "description": "patched-dataset-description",
                    "metadata": {"patched-metadata-key": "patched-metadata-value"},
                }
            }
        }

    async def test_only_description_field_can_be_set_to_null(
        self,
        gql_client: AsyncGraphQLClient,
        dataset_with_a_single_version: None,
    ) -> None:
        response = await gql_client.execute(
            query=self.MUTATION,
            variables={
                "datasetId": str(GlobalID(type_name="Dataset", node_id=str(1))),
                "name": None,
                "description": None,
                "metadata": None,
            },
        )
        assert not response.errors
        assert response.data == {
            "patchDataset": {
                "dataset": {
                    "id": str(GlobalID(type_name="Dataset", node_id=str(1))),
                    "name": "dataset-name",
                    "description": None,
                    "metadata": {"dataset-metadata-key": "dataset-metadata-value"},
                }
            }
        }

    async def test_updating_a_single_field_leaves_remaining_fields_unchanged(
        self,
        gql_client: AsyncGraphQLClient,
        dataset_with_a_single_version: None,
    ) -> None:
        response = await gql_client.execute(
            query=self.MUTATION,
            variables={
                "datasetId": str(GlobalID(type_name="Dataset", node_id=str(1))),
                "description": "patched-dataset-description",
            },
        )
        assert not response.errors
        assert response.data == {
            "patchDataset": {
                "dataset": {
                    "id": str(GlobalID(type_name="Dataset", node_id=str(1))),
                    "name": "dataset-name",
                    "description": "patched-dataset-description",
                    "metadata": {"dataset-metadata-key": "dataset-metadata-value"},
                }
            }
        }


async def test_add_span_to_dataset(
    gql_client: AsyncGraphQLClient,
    empty_dataset: None,
    spans: list[models.Span],
    span_annotation: None,
) -> None:
    dataset_id = GlobalID(type_name="Dataset", node_id=str(1))
    mutation = """
      mutation ($datasetId: ID!, $spanIds: [ID!]!) {
        addSpansToDataset(input: {datasetId: $datasetId, spanIds: $spanIds}) {
          dataset {
            id
            examples {
              edges {
                example: node {
                  revision {
                    input
                    output
                    metadata
                  }
                }
              }
            }
          }
        }
      }
    """
    response = await gql_client.execute(
        query=mutation,
        variables={
            "datasetId": str(dataset_id),
            "spanIds": [str(GlobalID(type_name="Span", node_id=str(span.id))) for span in spans],
        },
    )
    assert not response.errors
    assert response.data == {
        "addSpansToDataset": {
            "dataset": {
                "id": str(dataset_id),
                "examples": {
                    "edges": [
                        {
                            "example": {
                                "revision": {
                                    "input": {
                                        "messages": [
                                            {"content": "user-message-content", "role": "user"}
                                        ]
                                    },
                                    "metadata": {
                                        "span_kind": "LLM",
                                        "annotations": {},
                                    },
                                    "output": {
                                        "messages": [
                                            {
                                                "content": "assistant-message-content",
                                                "role": "assistant",
                                            }
                                        ]
                                    },
                                }
                            }
                        },
                        {
                            "example": {
                                "revision": {
                                    "input": {"input": "retriever-span-input"},
                                    "output": {
                                        "documents": [
                                            {
                                                "content": "retrieved-document-content",
                                                "score": 1,
                                            }
                                        ]
                                    },
                                    "metadata": {
                                        "span_kind": "RETRIEVER",
                                        "annotations": {},
                                    },
                                }
                            }
                        },
                        {
                            "example": {
                                "revision": {
                                    "input": {"input": "chain-span-input-value"},
                                    "output": {"output": "chain-span-output-value"},
                                    "metadata": {
                                        "span_kind": "CHAIN",
                                        "annotations": {
                                            "test annotation": [
                                                {
                                                    "label": "ambiguous",
                                                    "score": 0.5,
                                                    "explanation": "meaningful words",
                                                    "metadata": {},
                                                    "annotator_kind": "HUMAN",
                                                    "user_id": None,
                                                    "username": None,
                                                    "email": None,
                                                }
                                            ]
                                        },
                                    },
                                }
                            }
                        },
                    ]
                },
            }
        }
    }


class TestPatchDatasetExamples:
    MUTATION = """
      mutation ($input: PatchDatasetExamplesInput!) {
        patchDatasetExamples(input: $input) {
          dataset {
            examples {
              edges {
                example: node {
                  id
                  revision {
                    input
                    output
                    metadata
                    revisionKind
                  }
                }
              }
            }
          }
        }
      }"""

    async def test_happy_path(
        self,
        gql_client: AsyncGraphQLClient,
        dataset_with_revisions: None,
    ) -> None:
        # todo: update this test case to verify that version description and
        # metadata are updated once a versions resolver has been implemented
        # https://github.com/Arize-ai/phoenix/issues/3359
        mutation_input = {
            "patches": [
                {
                    "exampleId": str(GlobalID(type_name=DatasetExample.__name__, node_id=str(1))),
                    "input": {"input": "patched-example-1-input"},
                },
                {
                    "exampleId": str(GlobalID(type_name=DatasetExample.__name__, node_id=str(2))),
                    "input": {"input": "patched-example-2-input"},
                    "output": {"output": "patched-example-2-output"},
                    "metadata": {"metadata": "patched-example-2-metadata"},
                },
            ]
        }
        expected_examples = [
            {
                "example": {
                    "id": str(GlobalID(type_name=DatasetExample.__name__, node_id=str(2))),
                    "revision": {
                        "input": {"input": "patched-example-2-input"},
                        "output": {"output": "patched-example-2-output"},
                        "metadata": {"metadata": "patched-example-2-metadata"},
                        "revisionKind": "PATCH",
                    },
                }
            },
            {
                "example": {
                    "id": str(GlobalID(type_name=DatasetExample.__name__, node_id=str(1))),
                    "revision": {
                        "input": {"input": "patched-example-1-input"},
                        "output": {"output": "original-example-1-version-1-output"},
                        "metadata": {"metadata": "original-example-1-version-1-metadata"},
                        "revisionKind": "PATCH",
                    },
                }
            },
        ]
        response = await gql_client.execute(
            query=self.MUTATION,
            variables={"input": mutation_input},
        )
        assert not response.errors
        assert (data := response.data) is not None
        actual_examples = data["patchDatasetExamples"]["dataset"]["examples"]["edges"]
        assert actual_examples == expected_examples

    @pytest.mark.parametrize(
        "mutation_input, expected_error_message",
        [
            pytest.param(
                {"patches": []},
                "Must provide examples to patch.",
                id="empty-example-patches",
            ),
            pytest.param(
                {
                    "patches": [
                        {
                            "exampleId": str(
                                GlobalID(type_name=DatasetExample.__name__, node_id=str(1))
                            ),
                            "input": {"input": "value"},
                        },
                        {
                            "exampleId": str(
                                GlobalID(type_name=DatasetExample.__name__, node_id=str(1))
                            ),
                            "input": {"input": "value"},
                        },
                    ]
                },
                "Cannot patch the same example more than once per mutation.",
                id="same-example-patched-twice",
            ),
            pytest.param(
                {
                    "patches": [
                        {
                            "exampleId": str(
                                GlobalID(type_name=DatasetExample.__name__, node_id=str(1))
                            ),
                        },
                        {
                            "exampleId": str(
                                GlobalID(type_name=DatasetExample.__name__, node_id=str(2))
                            ),
                            "input": {"input": "value"},
                        },
                    ]
                },
                "Received one or more empty patches that contain no fields to update.",
                id="found-patch-with-nothing-to-update",
            ),
            pytest.param(
                {
                    "patches": [
                        {
                            "exampleId": str(
                                GlobalID(type_name=DatasetExample.__name__, node_id=str(500))
                            ),
                            "input": {"input": "value"},
                        },
                    ]
                },
                "No examples found.",
                id="invalid-example-id",
            ),
            pytest.param(
                {
                    "patches": [
                        {
                            "exampleId": str(
                                GlobalID(type_name=DatasetExample.__name__, node_id=str(1))
                            ),
                            "input": {"input": "value"},
                        },
                        {
                            "exampleId": str(
                                GlobalID(type_name=DatasetExample.__name__, node_id=str(4))
                            ),
                            "input": {"input": "value"},
                        },
                    ]
                },
                "Examples must come from the same dataset.",
                id="examples-from-different-datasets",
            ),
            pytest.param(
                {
                    "patches": [
                        {
                            "exampleId": str(
                                GlobalID(type_name=DatasetExample.__name__, node_id=str(3))
                            ),
                            "input": {"input": "value"},
                        },
                    ]
                },
                "1 example(s) could not be found.",
                id="deleted-example-id",
            ),
        ],
    )
    async def test_raises_value_error_for_invalid_input(
        self,
        mutation_input: dict[str, Any],
        expected_error_message: str,
        gql_client: AsyncGraphQLClient,
        dataset_with_revisions: None,
        dataset_with_a_single_version: None,
    ) -> None:
        response = await gql_client.execute(
            query=self.MUTATION,
            variables={"input": mutation_input},
        )
        assert (errors := response.errors)
        assert len(errors) == 1
        assert errors[0].message == expected_error_message


async def test_delete_a_dataset(
    db: DbSessionFactory,
    gql_client: AsyncGraphQLClient,
    empty_dataset: None,
) -> None:
    dataset_id = GlobalID(type_name="Dataset", node_id=str(1))
    mutation = textwrap.dedent(
        """
        mutation ($datasetId: ID!) {
          deleteDataset(input: { datasetId: $datasetId }) {
            dataset {
              id
            }
          }
        }
        """
    )
    response = await gql_client.execute(
        query=mutation,
        variables={
            "datasetId": str(dataset_id),
        },
    )
    assert not response.errors
    assert (data := response.data) is not None
    assert data["deleteDataset"]["dataset"] == {"id": str(dataset_id)}, (
        "deleted dataset is returned"
    )
    async with db() as session:
        dataset = (
            await session.execute(select(models.Dataset).where(models.Dataset.id == 1))
        ).first()
        assert not dataset


async def test_deleting_a_nonexistent_dataset_fails(gql_client: AsyncGraphQLClient) -> None:
    dataset_id = GlobalID(type_name="Dataset", node_id=str(1))
    mutation = textwrap.dedent(
        """
        mutation ($datasetId: ID!) {
          deleteDataset(input: { datasetId: $datasetId }) {
            dataset {
              id
            }
          }
        }
        """
    )
    response = await gql_client.execute(
        query=mutation,
        variables={
            "datasetId": str(dataset_id),
        },
    )
    assert (errors := response.errors)
    assert len(errors) == 1
    assert f"Unknown dataset: {dataset_id}" in errors[0].message


@pytest.fixture
async def empty_dataset(db: DbSessionFactory) -> None:
    """
    Inserts an empty dataset.
    """
    async with db() as session:
        dataset = models.Dataset(
            id=1,
            name="empty dataset",
            description=None,
            metadata_={},
        )
        session.add(dataset)
        await session.flush()


@pytest.fixture
async def spans(db: DbSessionFactory) -> list[models.Span]:
    """
    Inserts three spans from a single trace: a chain root span, a retriever
    child span, and an LLM child span.
    """
    spans = []
    async with db() as session:
        project_row_id = await session.scalar(
            insert(models.Project).values(name=DEFAULT_PROJECT_NAME).returning(models.Project.id)
        )
        trace_row_id = await session.scalar(
            insert(models.Trace)
            .values(
                trace_id="1",
                project_rowid=project_row_id,
                start_time=datetime.fromisoformat("2021-01-01T00:00:00.000+00:00"),
                end_time=datetime.fromisoformat("2021-01-01T00:01:00.000+00:00"),
            )
            .returning(models.Trace.id)
        )
        span = await session.scalar(
            insert(models.Span)
            .values(
                trace_rowid=trace_row_id,
                span_id="1",
                parent_id=None,
                name="chain span",
                span_kind="CHAIN",
                start_time=datetime.fromisoformat("2021-01-01T00:00:00.000+00:00"),
                end_time=datetime.fromisoformat("2021-01-01T00:00:30.000+00:00"),
                attributes={
                    "input": {"value": "chain-span-input-value", "mime_type": "text/plain"},
                    "output": {"value": "chain-span-output-value", "mime_type": "text/plain"},
                },
                events=[],
                status_code="OK",
                status_message="okay",
                cumulative_error_count=0,
                cumulative_llm_token_count_prompt=0,
                cumulative_llm_token_count_completion=0,
            )
            .returning(models.Span)
        )
        assert span is not None
        spans.append(span)
        span = await session.scalar(
            insert(models.Span)
            .values(
                trace_rowid=trace_row_id,
                span_id="2",
                parent_id="1",
                name="retriever span",
                span_kind="RETRIEVER",
                start_time=datetime.fromisoformat("2021-01-01T00:00:05.000+00:00"),
                end_time=datetime.fromisoformat("2021-01-01T00:00:20.000+00:00"),
                attributes={
                    "input": {
                        "value": "retriever-span-input",
                        "mime_type": "text/plain",
                    },
                    "retrieval": {
                        "documents": [
                            {"document": {"content": "retrieved-document-content", "score": 1}},
                        ],
                    },
                },
                events=[],
                status_code="OK",
                status_message="okay",
                cumulative_error_count=0,
                cumulative_llm_token_count_prompt=0,
                cumulative_llm_token_count_completion=0,
            )
            .returning(models.Span)
        )
        assert span is not None
        spans.append(span)
        span = await session.scalar(
            insert(models.Span)
            .values(
                trace_rowid=trace_row_id,
                span_id="3",
                parent_id="1",
                name="llm span",
                span_kind="LLM",
                start_time=datetime.fromisoformat("2021-01-01T00:00:05.000+00:00"),
                end_time=datetime.fromisoformat("2021-01-01T00:00:20.000+00:00"),
                attributes={
                    "llm": {
                        "input_messages": [
                            {"message": {"role": "user", "content": "user-message-content"}}
                        ],
                        "output_messages": [
                            {
                                "message": {
                                    "role": "assistant",
                                    "content": "assistant-message-content",
                                }
                            },
                        ],
                        "invocation_parameters": {"temperature": 1},
                    },
                },
                events=[],
                status_code="OK",
                status_message="okay",
                cumulative_error_count=0,
                cumulative_llm_token_count_prompt=0,
                cumulative_llm_token_count_completion=0,
            )
            .returning(models.Span)
        )
        assert span is not None
        spans.append(span)
    return spans


@pytest.fixture
async def span_annotation(db: DbSessionFactory) -> None:
    """
    Inserts a single human annotation on the chain root span (span_rowid=1).
    """
    async with db() as session:
        span_annotation = models.SpanAnnotation(
            span_rowid=1,
            name="test annotation",
            annotator_kind="HUMAN",
            label="ambiguous",
            score=0.5,
            explanation="meaningful words",
            identifier="",
            source="APP",
            user_id=None,
        )
        session.add(span_annotation)
        await session.flush()


@pytest.fixture
async def dataset_with_a_single_version(
    db: DbSessionFactory,
) -> None:
    """
    A dataset with a single example and a single version.
    """
    async with db() as session:
        # insert dataset
        dataset_id = await session.scalar(
            insert(models.Dataset)
            .returning(models.Dataset.id)
            .values(
                name="dataset-name",
                description="dataset-description",
                metadata_={"dataset-metadata-key": "dataset-metadata-value"},
            )
        )
        # insert example
        example_id = await session.scalar(
            insert(models.DatasetExample)
            .values(
                dataset_id=dataset_id,
                created_at=datetime(year=2020, month=1, day=1, hour=0, minute=0, tzinfo=pytz.utc),
            )
            .returning(models.DatasetExample.id)
        )
        # insert version
        version_id = await session.scalar(
            insert(models.DatasetVersion)
            .returning(models.DatasetVersion.id)
            .values(
                dataset_id=dataset_id,
                description="original-description",
                metadata_={"metadata": "original-metadata"},
            )
        )
        # insert revision
        await session.scalar(
            insert(models.DatasetExampleRevision)
            .returning(models.DatasetExampleRevision.id)
            .values(
                dataset_example_id=example_id,
                dataset_version_id=version_id,
                input={"input": "first-input"},
                output={"output": "first-output"},
                metadata_={"metadata": "first-metadata"},
                revision_kind="CREATE",
            )
        )


@pytest.fixture
async def dataset_with_revisions(db: DbSessionFactory) -> None:
    """
    A dataset with three examples and two versions. The first version creates
    all three examples, and the second version deletes the third example.
    """
    async with db() as session:
        # insert dataset
        dataset_id = await session.scalar(
            insert(models.Dataset)
            .returning(models.Dataset.id)
            .values(
                name="original-dataset-name",
                description="original-dataset-description",
                metadata_={},
            )
        )
        # insert examples
        example_id_1 = await session.scalar(
            insert(models.DatasetExample)
            .values(dataset_id=dataset_id)
            .returning(models.DatasetExample.id)
        )
        example_id_2 = await session.scalar(
            insert(models.DatasetExample)
            .values(dataset_id=dataset_id)
            .returning(models.DatasetExample.id)
        )
        example_id_3 = await session.scalar(
            insert(models.DatasetExample)
            .values(dataset_id=dataset_id)
            .returning(models.DatasetExample.id)
        )
        # insert first version
        version_id_1 = await session.scalar(
            insert(models.DatasetVersion)
            .returning(models.DatasetVersion.id)
            .values(
                dataset_id=dataset_id,
                description="original-version-1-description",
                metadata_={"metadata": "original-version-1-metadata"},
                created_at=datetime.fromisoformat("2024-05-28T00:00:04+00:00"),
            )
        )
        # insert revisions for first version
        await session.execute(
            insert(models.DatasetExampleRevision).values(
                dataset_example_id=example_id_1,
                dataset_version_id=version_id_1,
                input={"input": "original-example-1-version-1-input"},
                output={"output": "original-example-1-version-1-output"},
                metadata_={"metadata": "original-example-1-version-1-metadata"},
                revision_kind="CREATE",
            )
        )
        await session.execute(
            insert(models.DatasetExampleRevision).values(
                dataset_example_id=example_id_2,
                dataset_version_id=version_id_1,
                input={"input": "original-example-2-version-1-input"},
                output={"output": "original-example-2-version-1-output"},
                metadata_={"metadata": "original-example-2-version-1-metadata"},
                revision_kind="CREATE",
            )
        )
        await session.execute(
            insert(models.DatasetExampleRevision).values(
                dataset_example_id=example_id_3,
                dataset_version_id=version_id_1,
                input={"input": "original-example-3-version-1-input"},
                output={"output": "original-example-3-version-1-output"},
                metadata_={"metadata": "original-example-3-version-1-metadata"},
                revision_kind="CREATE",
            )
        )
        # insert second version
        version_id_2 = await session.scalar(
            insert(models.DatasetVersion)
            .returning(models.DatasetVersion.id)
            .values(
                dataset_id=dataset_id,
                description="original-version-2-description",
                metadata_={"metadata": "original-version-2-metadata"},
                created_at=datetime.fromisoformat("2024-05-28T00:00:04+00:00"),
            )
        )
        # insert revisions for second version
        await session.execute(
            insert(models.DatasetExampleRevision).values(
                dataset_example_id=example_id_3,
                dataset_version_id=version_id_2,
                input={"input": "original-example-3-version-1-input"},
                output={"output": "original-example-3-version-1-output"},
                metadata_={"metadata": "original-example-3-version-1-metadata"},
                revision_kind="DELETE",
            )
        )
