Skip to main content
Glama

@arizeai/phoenix-mcp

Official
by Arize-ai
test_subscriptions.py (70.8 kB)
import json import re from datetime import datetime from typing import Any, Optional from openinference.semconv.trace import ( OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes, ) from opentelemetry.semconv.attributes.url_attributes import URL_FULL, URL_PATH from sqlalchemy import select from strawberry.relay.types import GlobalID from vcr.request import Request as VCRRequest from phoenix.server.api.types.ChatCompletionSubscriptionPayload import ( ChatCompletionSubscriptionError, ChatCompletionSubscriptionExperiment, ChatCompletionSubscriptionResult, TextChunk, ToolCallChunk, ) from phoenix.server.api.types.Dataset import Dataset from phoenix.server.api.types.DatasetExample import DatasetExample from phoenix.server.api.types.DatasetVersion import DatasetVersion from phoenix.server.api.types.Experiment import Experiment from phoenix.server.api.types.node import from_global_id from phoenix.server.experiments.utils import is_experiment_project_name from phoenix.server.types import DbSessionFactory from phoenix.trace.attributes import flatten, get_attribute_value from tests.unit._helpers import verify_experiment_examples_junction_table from tests.unit.graphql import AsyncGraphQLClient from tests.unit.vcr import CustomVCR class TestChatCompletionSubscription: QUERY = """ subscription ChatCompletionSubscription($input: ChatCompletionInput!) { chatCompletion(input: $input) { __typename ... on TextChunk { content } ... on ToolCallChunk { id function { name arguments } } ... on ChatCompletionSubscriptionResult { span { ...SpanFragment } } ... on ChatCompletionSubscriptionError { message } } } query SpanQuery($spanId: ID!) { span: node(id: $spanId) { ... 
on Span { ...SpanFragment } } } fragment SpanFragment on Span { id name statusCode statusMessage startTime endTime latencyMs parentId spanKind context { spanId traceId } attributes metadata numDocuments tokenCountTotal tokenCountPrompt tokenCountCompletion input { mimeType value } output { mimeType value } events { name message timestamp } cumulativeTokenCountTotal cumulativeTokenCountPrompt cumulativeTokenCountCompletion propagatedStatusCode } """ async def test_openai_text_response_emits_expected_payloads_and_records_expected_span( self, gql_client: AsyncGraphQLClient, openai_api_key: str, custom_vcr: CustomVCR, ) -> None: variables = { "input": { "messages": [ { "role": "USER", "content": "Who won the World Cup in 2018? Answer in one word", } ], "model": {"name": "gpt-4", "providerKey": "OPENAI"}, "invocationParameters": [ {"invocationName": "temperature", "valueFloat": 0.1}, ], "repetitions": 1, }, } async with gql_client.subscription( query=self.QUERY, variables=variables, operation_name="ChatCompletionSubscription", ) as subscription: with custom_vcr.use_cassette(): payloads = [payload async for payload in subscription.stream()] # check subscription payloads assert payloads assert (last_payload := payloads.pop())["chatCompletion"][ "__typename" ] == ChatCompletionSubscriptionResult.__name__ assert all( payload["chatCompletion"]["__typename"] == TextChunk.__name__ for payload in payloads ) response_text = "".join(payload["chatCompletion"]["content"] for payload in payloads) assert "france" in response_text.lower() subscription_span = last_payload["chatCompletion"]["span"] span_id = subscription_span["id"] # query for the span via the node interface to ensure that the span # recorded in the db contains identical information as the span emitted # by the subscription response = await gql_client.execute( query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) assert (data := response.data) is not None span = data["span"] assert 
json.loads(attributes := span.pop("attributes")) == json.loads( subscription_span.pop("attributes") ) attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check attributes assert span.pop("id") == span_id assert span.pop("name") == "ChatCompletion" assert span.pop("statusCode") == "OK" assert not span.pop("statusMessage") assert span.pop("startTime") assert span.pop("endTime") assert isinstance(span.pop("latencyMs"), float) assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") assert context.pop("traceId") assert not context assert span.pop("metadata") is None assert span.pop("numDocuments") == 0 assert isinstance(token_count_total := span.pop("tokenCountTotal"), int) assert isinstance(token_count_prompt := span.pop("tokenCountPrompt"), int) assert isinstance(token_count_completion := span.pop("tokenCountCompletion"), int) assert token_count_prompt > 0 assert token_count_completion > 0 assert token_count_total == token_count_prompt + token_count_completion assert (input := span.pop("input")).pop("mimeType") == "json" assert (input_value := input.pop("value")) assert not input assert "api_key" not in input_value assert "apiKey" not in input_value assert (output := span.pop("output")).pop("mimeType") == "text" assert output.pop("value") assert not output assert not span.pop("events") assert isinstance( cumulative_token_count_total := span.pop("cumulativeTokenCountTotal"), int ) assert isinstance( cumulative_token_count_prompt := span.pop("cumulativeTokenCountPrompt"), int ) assert isinstance( cumulative_token_count_completion := span.pop("cumulativeTokenCountCompletion"), int ) assert cumulative_token_count_total == token_count_total assert cumulative_token_count_prompt == token_count_prompt assert cumulative_token_count_completion == token_count_completion assert span.pop("propagatedStatusCode") == "OK" assert not span assert 
attributes.pop(OPENINFERENCE_SPAN_KIND) == LLM assert attributes.pop(LLM_MODEL_NAME) == "gpt-4" assert attributes.pop(LLM_INVOCATION_PARAMETERS) == json.dumps({"temperature": 0.1}) assert attributes.pop(LLM_TOKEN_COUNT_TOTAL) == token_count_total assert attributes.pop(LLM_TOKEN_COUNT_PROMPT) == token_count_prompt assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION) == token_count_completion assert attributes.pop(LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ) == 0 assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING) == 0 assert attributes.pop(INPUT_VALUE) assert attributes.pop(INPUT_MIME_TYPE) == JSON assert attributes.pop(OUTPUT_VALUE) assert attributes.pop(OUTPUT_MIME_TYPE) == TEXT assert attributes.pop(LLM_INPUT_MESSAGES) == [ { "message": { "role": "user", "content": "Who won the World Cup in 2018? Answer in one word", } } ] assert attributes.pop(LLM_OUTPUT_MESSAGES) == [ { "message": { "role": "assistant", "content": response_text, } } ] assert attributes.pop(LLM_PROVIDER) == "openai" assert attributes.pop(LLM_SYSTEM) == "openai" assert attributes.pop(URL_FULL) == "https://api.openai.com/v1/chat/completions" assert attributes.pop(URL_PATH) == "chat/completions" assert not attributes async def test_openai_emits_expected_payloads_and_records_expected_span_on_error( self, gql_client: AsyncGraphQLClient, openai_api_key: str, custom_vcr: CustomVCR, ) -> None: variables = { "input": { "messages": [ { "role": "USER", "content": "Who won the World Cup in 2018? 
Answer in one word", } ], "model": {"name": "gpt-4", "providerKey": "OPENAI"}, "invocationParameters": [ {"invocationName": "temperature", "valueFloat": 0.1}, ], "repetitions": 1, }, } async with gql_client.subscription( query=self.QUERY, variables=variables, operation_name="ChatCompletionSubscription", ) as subscription: with custom_vcr.use_cassette(): payloads = [payload async for payload in subscription.stream()] # check subscription payloads assert len(payloads) == 2 assert (error_payload := payloads[0])["chatCompletion"][ "__typename" ] == ChatCompletionSubscriptionError.__name__ assert "401" in (status_message := error_payload["chatCompletion"]["message"]) assert "api key" in status_message.lower() assert (last_payload := payloads.pop())["chatCompletion"][ "__typename" ] == ChatCompletionSubscriptionResult.__name__ subscription_span = last_payload["chatCompletion"]["span"] span_id = subscription_span["id"] # query for the span via the node interface to ensure that the span # recorded in the db contains identical information as the span emitted # by the subscription response = await gql_client.execute( query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) assert (data := response.data) is not None span = data["span"] assert json.loads(attributes := span.pop("attributes")) == json.loads( subscription_span.pop("attributes") ) attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check attributes assert span.pop("id") == span_id assert span.pop("name") == "ChatCompletion" assert span.pop("statusCode") == "ERROR" assert span.pop("statusMessage") == status_message assert span.pop("startTime") assert span.pop("endTime") assert isinstance(span.pop("latencyMs"), float) assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") assert context.pop("traceId") assert not context assert span.pop("metadata") is None assert span.pop("numDocuments") == 0 
assert isinstance(token_count_total := span.pop("tokenCountTotal"), int) assert isinstance(token_count_prompt := span.pop("tokenCountPrompt"), int) assert isinstance(token_count_completion := span.pop("tokenCountCompletion"), int) assert token_count_prompt == 0 assert token_count_completion == 0 assert token_count_total == token_count_prompt + token_count_completion assert (input := span.pop("input")).pop("mimeType") == "json" assert (input_value := input.pop("value")) assert not input assert "api_key" not in input_value assert "apiKey" not in input_value assert span.pop("output") is None assert (events := span.pop("events")) assert len(events) == 1 assert (event := events[0]) assert event.pop("name") == "exception" assert event.pop("message") == status_message assert datetime.fromisoformat(event.pop("timestamp")) assert not event assert isinstance( cumulative_token_count_total := span.pop("cumulativeTokenCountTotal"), int ) assert isinstance( cumulative_token_count_prompt := span.pop("cumulativeTokenCountPrompt"), int ) assert isinstance( cumulative_token_count_completion := span.pop("cumulativeTokenCountCompletion"), int ) assert cumulative_token_count_total == token_count_total assert cumulative_token_count_prompt == token_count_prompt assert cumulative_token_count_completion == token_count_completion assert span.pop("propagatedStatusCode") == "ERROR" assert not span assert attributes.pop(OPENINFERENCE_SPAN_KIND) == LLM assert attributes.pop(LLM_MODEL_NAME) == "gpt-4" assert attributes.pop(LLM_INVOCATION_PARAMETERS) == json.dumps({"temperature": 0.1}) assert attributes.pop(INPUT_VALUE) assert attributes.pop(INPUT_MIME_TYPE) == JSON assert attributes.pop(LLM_INPUT_MESSAGES) == [ { "message": { "role": "user", "content": "Who won the World Cup in 2018? 
Answer in one word", } } ] assert attributes.pop(LLM_PROVIDER) == "openai" assert attributes.pop(LLM_SYSTEM) == "openai" assert attributes.pop(URL_FULL) == "https://api.openai.com/v1/chat/completions" assert attributes.pop(URL_PATH) == "chat/completions" assert not attributes async def test_openai_tool_call_response_emits_expected_payloads_and_records_expected_span( self, gql_client: AsyncGraphQLClient, openai_api_key: str, custom_vcr: CustomVCR, ) -> None: get_current_weather_tool_schema = { "type": "function", "function": { "name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "The city name, e.g. San Francisco", }, }, "required": ["location"], }, }, } variables = { "input": { "messages": [ { "role": "USER", "content": "How's the weather in San Francisco?", } ], "model": {"name": "gpt-4", "providerKey": "OPENAI"}, "tools": [get_current_weather_tool_schema], "invocationParameters": [ {"invocationName": "tool_choice", "valueJson": "auto"}, ], "repetitions": 1, }, } async with gql_client.subscription( query=self.QUERY, variables=variables, operation_name="ChatCompletionSubscription", ) as subscription: with custom_vcr.use_cassette(): payloads = [payload async for payload in subscription.stream()] # check subscription payloads assert payloads assert (last_payload := payloads.pop())["chatCompletion"][ "__typename" ] == ChatCompletionSubscriptionResult.__name__ assert all( payload["chatCompletion"]["__typename"] == ToolCallChunk.__name__ for payload in payloads ) json.loads( "".join(payload["chatCompletion"]["function"]["arguments"] for payload in payloads) ) == {"location": "San Francisco"} subscription_span = last_payload["chatCompletion"]["span"] span_id = subscription_span["id"] # query for the span via the node interface to ensure that the span # recorded in the db contains identical information as the span emitted # by the 
subscription response = await gql_client.execute( query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) assert (data := response.data) is not None span = data["span"] assert json.loads(attributes := span.pop("attributes")) == json.loads( subscription_span.pop("attributes") ) attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check attributes assert span.pop("id") == span_id assert span.pop("name") == "ChatCompletion" assert span.pop("statusCode") == "OK" assert not span.pop("statusMessage") assert span.pop("startTime") assert span.pop("endTime") assert isinstance(span.pop("latencyMs"), float) assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") assert context.pop("traceId") assert not context assert span.pop("metadata") is None assert span.pop("numDocuments") == 0 assert isinstance(token_count_total := span.pop("tokenCountTotal"), int) assert isinstance(token_count_prompt := span.pop("tokenCountPrompt"), int) assert isinstance(token_count_completion := span.pop("tokenCountCompletion"), int) assert token_count_prompt > 0 assert token_count_completion > 0 assert token_count_total == token_count_prompt + token_count_completion assert (input := span.pop("input")).pop("mimeType") == "json" assert (input_value := input.pop("value")) assert not input assert "api_key" not in input_value assert "apiKey" not in input_value assert (output := span.pop("output")).pop("mimeType") == "json" assert output.pop("value") assert not output assert not span.pop("events") assert isinstance( cumulative_token_count_total := span.pop("cumulativeTokenCountTotal"), int ) assert isinstance( cumulative_token_count_prompt := span.pop("cumulativeTokenCountPrompt"), int ) assert isinstance( cumulative_token_count_completion := span.pop("cumulativeTokenCountCompletion"), int ) assert cumulative_token_count_total == token_count_total assert 
cumulative_token_count_prompt == token_count_prompt assert cumulative_token_count_completion == token_count_completion assert span.pop("propagatedStatusCode") == "OK" assert not span assert attributes.pop(OPENINFERENCE_SPAN_KIND) == LLM assert attributes.pop(LLM_MODEL_NAME) == "gpt-4" assert attributes.pop(LLM_INVOCATION_PARAMETERS) == json.dumps({"tool_choice": "auto"}) assert attributes.pop(LLM_TOKEN_COUNT_TOTAL) == token_count_total assert attributes.pop(LLM_TOKEN_COUNT_PROMPT) == token_count_prompt assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION) == token_count_completion assert attributes.pop(LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ) == 0 assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING) == 0 assert attributes.pop(INPUT_VALUE) assert attributes.pop(INPUT_MIME_TYPE) == JSON assert attributes.pop(OUTPUT_VALUE) assert attributes.pop(OUTPUT_MIME_TYPE) == JSON assert attributes.pop(LLM_INPUT_MESSAGES) == [ { "message": { "role": "user", "content": "How's the weather in San Francisco?", } } ] assert (output_messages := attributes.pop(LLM_OUTPUT_MESSAGES)) assert len(output_messages) == 1 assert (output_message := output_messages[0]["message"])["role"] == "assistant" assert "content" not in output_message assert (tool_calls := output_message["tool_calls"]) assert len(tool_calls) == 1 assert (tool_call := tool_calls[0]["tool_call"]) assert (function := tool_call["function"]) assert function["name"] == "get_current_weather" assert json.loads(function["arguments"]) == {"location": "San Francisco"} assert (llm_tools := attributes.pop(LLM_TOOLS)) assert llm_tools == [{"tool": {"json_schema": json.dumps(get_current_weather_tool_schema)}}] assert attributes.pop(LLM_PROVIDER) == "openai" assert attributes.pop(LLM_SYSTEM) == "openai" assert attributes.pop(URL_FULL) == "https://api.openai.com/v1/chat/completions" assert attributes.pop(URL_PATH) == "chat/completions" assert not attributes async def 
test_openai_tool_call_messages_emits_expected_payloads_and_records_expected_span( self, gql_client: AsyncGraphQLClient, openai_api_key: str, custom_vcr: CustomVCR, ) -> None: tool_call_id = "call_zz1hkqH3IakqnHfVhrrUemlQ" tool_calls = [ { "id": tool_call_id, "function": { "arguments": json.dumps({"city": "San Francisco"}, indent=4), "name": "get_weather", }, "type": "function", } ] variables = { "input": { "messages": [ { "role": "USER", "content": "How's the weather in San Francisco?", }, { "role": "AI", "toolCalls": tool_calls, }, { "content": "sunny", "role": "TOOL", "toolCallId": tool_call_id, }, ], "model": {"name": "gpt-4", "providerKey": "OPENAI"}, "repetitions": 1, } } async with gql_client.subscription( query=self.QUERY, variables=variables, operation_name="ChatCompletionSubscription", ) as subscription: with custom_vcr.use_cassette(): payloads = [payload async for payload in subscription.stream()] # check subscription payloads assert payloads assert (last_payload := payloads.pop())["chatCompletion"][ "__typename" ] == ChatCompletionSubscriptionResult.__name__ assert all( payload["chatCompletion"]["__typename"] == TextChunk.__name__ for payload in payloads ) response_text = "".join(payload["chatCompletion"]["content"] for payload in payloads) assert "sunny" in response_text.lower() subscription_span = last_payload["chatCompletion"]["span"] span_id = subscription_span["id"] # query for the span via the node interface to ensure that the span # recorded in the db contains identical information as the span emitted # by the subscription response = await gql_client.execute( query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) assert (data := response.data) is not None span = data["span"] assert json.loads(attributes := span.pop("attributes")) == json.loads( subscription_span.pop("attributes") ) attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check attributes assert span.pop("id") == span_id assert 
span.pop("name") == "ChatCompletion" assert span.pop("statusCode") == "OK" assert not span.pop("statusMessage") assert span.pop("startTime") assert span.pop("endTime") assert isinstance(span.pop("latencyMs"), float) assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") assert context.pop("traceId") assert not context assert span.pop("metadata") is None assert span.pop("numDocuments") == 0 assert isinstance(token_count_total := span.pop("tokenCountTotal"), int) assert isinstance(token_count_prompt := span.pop("tokenCountPrompt"), int) assert isinstance(token_count_completion := span.pop("tokenCountCompletion"), int) assert token_count_prompt > 0 assert token_count_completion > 0 assert token_count_total == token_count_prompt + token_count_completion assert (input := span.pop("input")).pop("mimeType") == "json" assert (input_value := input.pop("value")) assert not input assert "api_key" not in input_value assert "apiKey" not in input_value assert (output := span.pop("output")).pop("mimeType") == "text" assert output.pop("value") assert not output assert not span.pop("events") assert isinstance( cumulative_token_count_total := span.pop("cumulativeTokenCountTotal"), int ) assert isinstance( cumulative_token_count_prompt := span.pop("cumulativeTokenCountPrompt"), int ) assert isinstance( cumulative_token_count_completion := span.pop("cumulativeTokenCountCompletion"), int ) assert cumulative_token_count_total == token_count_total assert cumulative_token_count_prompt == token_count_prompt assert cumulative_token_count_completion == token_count_completion assert span.pop("propagatedStatusCode") == "OK" assert not span assert attributes.pop(OPENINFERENCE_SPAN_KIND) == LLM assert attributes.pop(LLM_MODEL_NAME) == "gpt-4" assert attributes.pop(LLM_TOKEN_COUNT_TOTAL) == token_count_total assert attributes.pop(LLM_TOKEN_COUNT_PROMPT) == token_count_prompt assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION) 
== token_count_completion assert attributes.pop(LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ) == 0 assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING) == 0 assert attributes.pop(INPUT_VALUE) assert attributes.pop(INPUT_MIME_TYPE) == JSON assert attributes.pop(OUTPUT_VALUE) assert attributes.pop(OUTPUT_MIME_TYPE) == TEXT assert (llm_input_messages := attributes.pop(LLM_INPUT_MESSAGES)) assert len(llm_input_messages) == 3 llm_input_message = llm_input_messages[0]["message"] assert llm_input_message == { "content": "How's the weather in San Francisco?", "role": "user", } llm_input_message = llm_input_messages[1]["message"] assert llm_input_message["content"] == "" assert llm_input_message["role"] == "ai" assert llm_input_message["tool_calls"] == [ { "tool_call": { "id": tool_call_id, "function": { "name": "get_weather", "arguments": '"{\\n \\"city\\": \\"San Francisco\\"\\n}"', }, } } ] llm_input_message = llm_input_messages[2]["message"] assert llm_input_message == { "content": "sunny", "role": "tool", "tool_call_id": tool_call_id, } assert attributes.pop(LLM_OUTPUT_MESSAGES) == [ { "message": { "role": "assistant", "content": response_text, } } ] assert attributes.pop(LLM_PROVIDER) == "openai" assert attributes.pop(LLM_SYSTEM) == "openai" assert attributes.pop(URL_FULL) == "https://api.openai.com/v1/chat/completions" assert attributes.pop(URL_PATH) == "chat/completions" assert not attributes async def test_anthropic_text_response_emits_expected_payloads_and_records_expected_span( self, gql_client: AsyncGraphQLClient, anthropic_api_key: str, custom_vcr: CustomVCR, ) -> None: variables = { "input": { "messages": [ { "role": "USER", "content": "Who won the World Cup in 2018? 
Answer in one word", } ], "model": {"name": "claude-3-5-sonnet-20240620", "providerKey": "ANTHROPIC"}, "invocationParameters": [ {"invocationName": "temperature", "valueFloat": 0.1}, {"invocationName": "max_tokens", "valueInt": 1024}, ], "repetitions": 1, }, } async with gql_client.subscription( query=self.QUERY, variables=variables, operation_name="ChatCompletionSubscription", ) as subscription: with custom_vcr.use_cassette(): payloads = [payload async for payload in subscription.stream()] # check subscription payloads assert payloads assert (last_payload := payloads.pop())["chatCompletion"][ "__typename" ] == ChatCompletionSubscriptionResult.__name__ assert all( payload["chatCompletion"]["__typename"] == TextChunk.__name__ for payload in payloads ) response_text = "".join(payload["chatCompletion"]["content"] for payload in payloads) assert "france" in response_text.lower() subscription_span = last_payload["chatCompletion"]["span"] span_id = subscription_span["id"] # query for the span via the node interface to ensure that the span # recorded in the db contains identical information as the span emitted # by the subscription response = await gql_client.execute( query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) assert (data := response.data) is not None span = data["span"] assert json.loads(attributes := span.pop("attributes")) == json.loads( subscription_span.pop("attributes") ) attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check attributes assert span.pop("id") == span_id assert span.pop("name") == "ChatCompletion" assert span.pop("statusCode") == "OK" assert not span.pop("statusMessage") assert span.pop("startTime") assert span.pop("endTime") assert isinstance(span.pop("latencyMs"), float) assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") assert context.pop("traceId") assert not context assert span.pop("metadata") is 
None assert span.pop("numDocuments") == 0 assert isinstance(token_count_total := span.pop("tokenCountTotal"), int) assert isinstance(token_count_prompt := span.pop("tokenCountPrompt"), int) assert isinstance(token_count_completion := span.pop("tokenCountCompletion"), int) assert token_count_prompt > 0 assert token_count_completion > 0 assert token_count_total == token_count_prompt + token_count_completion assert (input := span.pop("input")).pop("mimeType") == "json" assert (input_value := input.pop("value")) assert not input assert "api_key" not in input_value assert "apiKey" not in input_value assert (output := span.pop("output")).pop("mimeType") == "text" assert output.pop("value") assert not output assert not span.pop("events") assert isinstance( cumulative_token_count_total := span.pop("cumulativeTokenCountTotal"), int ) assert isinstance( cumulative_token_count_prompt := span.pop("cumulativeTokenCountPrompt"), int ) assert isinstance( cumulative_token_count_completion := span.pop("cumulativeTokenCountCompletion"), int ) assert cumulative_token_count_total == token_count_total assert cumulative_token_count_prompt == token_count_prompt assert cumulative_token_count_completion == token_count_completion assert span.pop("propagatedStatusCode") == "OK" assert not span assert attributes.pop(OPENINFERENCE_SPAN_KIND) == LLM assert attributes.pop(LLM_MODEL_NAME) == "claude-3-5-sonnet-20240620" assert attributes.pop(LLM_INVOCATION_PARAMETERS) == json.dumps( {"temperature": 0.1, "max_tokens": 1024} ) assert attributes.pop(LLM_TOKEN_COUNT_PROMPT) == token_count_prompt assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION) == token_count_completion assert attributes.pop(INPUT_VALUE) assert attributes.pop(INPUT_MIME_TYPE) == JSON assert attributes.pop(OUTPUT_VALUE) assert attributes.pop(OUTPUT_MIME_TYPE) == TEXT assert attributes.pop(LLM_INPUT_MESSAGES) == [ { "message": { "role": "user", "content": "Who won the World Cup in 2018? 
Answer in one word", } } ] assert attributes.pop(LLM_OUTPUT_MESSAGES) == [ { "message": { "role": "assistant", "content": response_text, } } ] assert attributes.pop(LLM_PROVIDER) == "anthropic" assert attributes.pop(LLM_SYSTEM) == "anthropic" assert attributes.pop(URL_FULL) == "https://api.anthropic.com/v1/messages" assert attributes.pop(URL_PATH) == "v1/messages" assert not attributes class TestChatCompletionOverDatasetSubscription: QUERY = """ subscription ChatCompletionOverDatasetSubscription($input: ChatCompletionOverDatasetInput!) { chatCompletionOverDataset(input: $input) { __typename datasetExampleId ... on TextChunk { content } ... on ChatCompletionSubscriptionResult { span { ...SpanFragment } experimentRun { ...ExperimentRunFragment } } ... on ChatCompletionSubscriptionError { message } ... on ChatCompletionSubscriptionExperiment { experiment { id } } } } query SpanQuery($spanId: ID!) { span: node(id: $spanId) { ... on Span { ...SpanFragment } } } query ExperimentQuery($experimentId: ID!) { experiment: node(id: $experimentId) { ... 
on Experiment { id name metadata projectName createdAt updatedAt description runs { edges { run: node { ...ExperimentRunFragment } } } } } } fragment ExperimentRunFragment on ExperimentRun { id experimentId startTime endTime output error traceId trace { id traceId project { name } } } fragment SpanFragment on Span { id name statusCode statusMessage startTime endTime latencyMs parentId spanKind context { spanId traceId } attributes metadata numDocuments tokenCountTotal tokenCountPrompt tokenCountCompletion input { mimeType value } output { mimeType value } events { name message timestamp } cumulativeTokenCountTotal cumulativeTokenCountPrompt cumulativeTokenCountCompletion propagatedStatusCode } """ async def test_emits_expected_payloads_and_records_expected_spans_and_experiment( self, gql_client: AsyncGraphQLClient, openai_api_key: str, playground_dataset_with_patch_revision: None, custom_vcr: CustomVCR, db: DbSessionFactory, ) -> None: dataset_id = str(GlobalID(type_name=Dataset.__name__, node_id=str(1))) version_id = str(GlobalID(type_name=DatasetVersion.__name__, node_id=str(1))) variables = { "input": { "model": {"providerKey": "OPENAI", "name": "gpt-4"}, "datasetId": dataset_id, "datasetVersionId": version_id, "messages": [ { "role": "USER", "content": "What country is {city} in? 
Answer in one word, no punctuation.", } ], "templateFormat": "F_STRING", "repetitions": 1, } } payloads: dict[Optional[str], list[Any]] = {} async with gql_client.subscription( query=self.QUERY, variables=variables, operation_name="ChatCompletionOverDatasetSubscription", ) as subscription: custom_vcr.register_matcher( _request_bodies_contain_same_city.__name__, _request_bodies_contain_same_city ) # a custom request matcher is needed since the requests are concurrent with custom_vcr.use_cassette(match_on=[_request_bodies_contain_same_city.__name__]): async for payload in subscription.stream(): if ( dataset_example_id := payload["chatCompletionOverDataset"][ "datasetExampleId" ] ) not in payloads: payloads[dataset_example_id] = [] payloads[dataset_example_id].append(payload) # check subscription payloads assert len(payloads) == 4 example_ids = [ str(GlobalID(type_name=DatasetExample.__name__, node_id=str(index))) for index in range(1, 4) ] assert set(payloads.keys()) == set(example_ids) | {None} # gather spans and experiment runs subscription_runs = {} subscription_spans = {} for example_id in example_ids: assert (result_payload := payloads[example_id].pop()["chatCompletionOverDataset"]) assert result_payload.pop("__typename") == ChatCompletionSubscriptionResult.__name__ assert result_payload.pop("datasetExampleId") == example_id subscription_runs[example_id] = result_payload.pop("experimentRun") subscription_spans[example_id] = result_payload.pop("span") assert not result_payload # check example 1 response text example_id = example_ids[0] assert all( payload["chatCompletionOverDataset"]["__typename"] == TextChunk.__name__ for payload in payloads[example_id] ) response_text = "".join( payload["chatCompletionOverDataset"]["content"] for payload in payloads[example_id] ) assert response_text == "France" # check example 2 response text example_id = example_ids[1] assert all( payload["chatCompletionOverDataset"]["__typename"] == TextChunk.__name__ for payload in 
payloads[example_id] ) response_text = "".join( payload["chatCompletionOverDataset"]["content"] for payload in payloads[example_id] ) assert response_text == "Japan" # check example 3 error message example_id = example_ids[2] assert (error_payload := payloads[example_id].pop()["chatCompletionOverDataset"])[ "__typename" ] == ChatCompletionSubscriptionError.__name__ assert error_payload["message"] == "Missing template variable(s): city" # check experiment payload assert len(payloads[None]) == 1 assert (experiment_payload := payloads[None].pop()["chatCompletionOverDataset"])[ "__typename" ] == ChatCompletionSubscriptionExperiment.__name__ experiment = experiment_payload["experiment"] assert (experiment_id := experiment.pop("id")) async with db() as session: await verify_experiment_examples_junction_table(session, experiment_id) # query for the span via the node interface to ensure that the span # recorded in the db contains identical information as the span emitted # by the subscription # check example 1 span example_id = example_ids[0] span_id = subscription_spans[example_id]["id"] response = await gql_client.execute( query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) assert (data := response.data) is not None span = data["span"] subscription_span = subscription_spans[example_id] assert json.loads(attributes := span.pop("attributes")) == json.loads( subscription_span.pop("attributes") ) attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check example 1 span attributes assert span.pop("id") == span_id assert span.pop("name") == "ChatCompletion" assert span.pop("statusCode") == "OK" assert not span.pop("statusMessage") assert span.pop("startTime") assert span.pop("endTime") assert isinstance(span.pop("latencyMs"), float) assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") assert context.pop("traceId") assert not context assert 
span.pop("metadata") is None assert span.pop("numDocuments") == 0 assert isinstance(token_count_total := span.pop("tokenCountTotal"), int) assert isinstance(token_count_prompt := span.pop("tokenCountPrompt"), int) assert isinstance(token_count_completion := span.pop("tokenCountCompletion"), int) assert token_count_prompt > 0 assert token_count_completion > 0 assert token_count_total == token_count_prompt + token_count_completion assert (input := span.pop("input")).pop("mimeType") == "json" assert (input_value := input.pop("value")) assert not input assert "api_key" not in input_value assert "apiKey" not in input_value assert (output := span.pop("output")).pop("mimeType") == "text" assert output.pop("value") assert not output assert not span.pop("events") assert isinstance( cumulative_token_count_total := span.pop("cumulativeTokenCountTotal"), int ) assert isinstance( cumulative_token_count_prompt := span.pop("cumulativeTokenCountPrompt"), int ) assert isinstance( cumulative_token_count_completion := span.pop("cumulativeTokenCountCompletion"), int ) assert cumulative_token_count_total == token_count_total assert cumulative_token_count_prompt == token_count_prompt assert cumulative_token_count_completion == token_count_completion assert span.pop("propagatedStatusCode") == "OK" assert not span assert attributes.pop(OPENINFERENCE_SPAN_KIND) == LLM assert attributes.pop(LLM_MODEL_NAME) == "gpt-4" assert attributes.pop(LLM_TOKEN_COUNT_TOTAL) == token_count_total assert attributes.pop(LLM_TOKEN_COUNT_PROMPT) == token_count_prompt assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION) == token_count_completion assert attributes.pop(LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ) == 0 assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING) == 0 assert attributes.pop(INPUT_VALUE) assert attributes.pop(INPUT_MIME_TYPE) == JSON assert attributes.pop(OUTPUT_VALUE) assert attributes.pop(OUTPUT_MIME_TYPE) == TEXT assert attributes.pop(LLM_INPUT_MESSAGES) == [ { "message": { 
"role": "user", "content": "What country is Paris in? Answer in one word, no punctuation.", } } ] assert attributes.pop(LLM_OUTPUT_MESSAGES) == [ {"message": {"role": "assistant", "content": "France"}} ] assert attributes.pop(LLM_PROVIDER) == "openai" assert attributes.pop(LLM_SYSTEM) == "openai" assert attributes.pop(URL_FULL) == "https://api.openai.com/v1/chat/completions" assert attributes.pop(URL_PATH) == "chat/completions" assert attributes.pop(PROMPT_TEMPLATE_VARIABLES) == json.dumps({"city": "Paris"}) assert not attributes # check example 2 span example_id = example_ids[1] span_id = subscription_spans[example_id]["id"] response = await gql_client.execute( query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) assert (data := response.data) is not None span = data["span"] subscription_span = subscription_spans[example_id] assert json.loads(attributes := span.pop("attributes")) == json.loads( subscription_span.pop("attributes") ) attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check example 2 span attributes assert span.pop("id") == span_id assert span.pop("name") == "ChatCompletion" assert span.pop("statusCode") == "OK" assert not span.pop("statusMessage") assert span.pop("startTime") assert span.pop("endTime") assert isinstance(span.pop("latencyMs"), float) assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") assert context.pop("traceId") assert not context assert span.pop("metadata") is None assert span.pop("numDocuments") == 0 assert isinstance(token_count_total := span.pop("tokenCountTotal"), int) assert isinstance(token_count_prompt := span.pop("tokenCountPrompt"), int) assert isinstance(token_count_completion := span.pop("tokenCountCompletion"), int) assert token_count_prompt > 0 assert token_count_completion > 0 assert token_count_total == token_count_prompt + token_count_completion assert (input := 
span.pop("input")).pop("mimeType") == "json" assert (input_value := input.pop("value")) assert not input assert "api_key" not in input_value assert "apiKey" not in input_value assert (output := span.pop("output")).pop("mimeType") == "text" assert output.pop("value") assert not output assert not span.pop("events") assert isinstance( cumulative_token_count_total := span.pop("cumulativeTokenCountTotal"), int ) assert isinstance( cumulative_token_count_prompt := span.pop("cumulativeTokenCountPrompt"), int ) assert isinstance( cumulative_token_count_completion := span.pop("cumulativeTokenCountCompletion"), int ) assert cumulative_token_count_total == token_count_total assert cumulative_token_count_prompt == token_count_prompt assert cumulative_token_count_completion == token_count_completion assert span.pop("propagatedStatusCode") == "OK" assert not span assert attributes.pop(OPENINFERENCE_SPAN_KIND) == LLM assert attributes.pop(LLM_MODEL_NAME) == "gpt-4" assert attributes.pop(LLM_TOKEN_COUNT_TOTAL) == token_count_total assert attributes.pop(LLM_TOKEN_COUNT_PROMPT) == token_count_prompt assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION) == token_count_completion assert attributes.pop(LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ) == 0 assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING) == 0 assert attributes.pop(INPUT_VALUE) assert attributes.pop(INPUT_MIME_TYPE) == JSON assert attributes.pop(OUTPUT_VALUE) assert attributes.pop(OUTPUT_MIME_TYPE) == TEXT assert attributes.pop(LLM_INPUT_MESSAGES) == [ { "message": { "role": "user", "content": "What country is Tokyo in? 
Answer in one word, no punctuation.", } } ] assert attributes.pop(LLM_OUTPUT_MESSAGES) == [ {"message": {"role": "assistant", "content": "Japan"}} ] assert attributes.pop(LLM_PROVIDER) == "openai" assert attributes.pop(LLM_SYSTEM) == "openai" assert attributes.pop(URL_FULL) == "https://api.openai.com/v1/chat/completions" assert attributes.pop(URL_PATH) == "chat/completions" assert attributes.pop(PROMPT_TEMPLATE_VARIABLES) == json.dumps({"city": "Tokyo"}) assert not attributes # check that example 3 has no span example_id = example_ids[2] assert subscription_spans[example_id] is None # check experiment response = await gql_client.execute( query=self.QUERY, variables={"experimentId": experiment_id}, operation_name="ExperimentQuery", ) assert (data := response.data) is not None experiment = data["experiment"] assert experiment.pop("id") == experiment_id type_name, _ = from_global_id(GlobalID.from_id(experiment_id)) assert type_name == Experiment.__name__ assert experiment.pop("name") == "playground-experiment" project_name = experiment.pop("projectName") assert is_experiment_project_name(project_name) assert experiment.pop("metadata") == {} assert isinstance(created_at := experiment.pop("createdAt"), str) assert isinstance(updated_at := experiment.pop("updatedAt"), str) experiment.pop("description") assert created_at == updated_at runs = {run["run"]["id"]: run["run"] for run in experiment.pop("runs")["edges"]} assert len(runs) == 3 # check example 1 run example_id = example_ids[0] subscription_run = subscription_runs[example_id] run_id = subscription_run["id"] run = runs.pop(run_id) assert run == subscription_run assert run.pop("id") == run_id assert isinstance(experiment_id := run.pop("experimentId"), str) type_name, _ = from_global_id(GlobalID.from_id(experiment_id)) assert type_name == Experiment.__name__ assert datetime.fromisoformat(run.pop("startTime")) <= datetime.fromisoformat( run.pop("endTime") ) assert run.pop("error") is None assert isinstance(run_output 
:= run.pop("output"), dict) assert set(run_output.keys()) == {"messages"} assert (trace_id := run.pop("traceId")) is not None trace = run.pop("trace") assert trace.pop("id") assert trace.pop("traceId") == trace_id project = trace.pop("project") assert project["name"] == project_name assert not trace assert not run # check example 2 run example_id = example_ids[1] subscription_run = subscription_runs[example_id] run_id = subscription_run["id"] run = runs.pop(run_id) assert run == subscription_run assert run.pop("id") == run_id assert isinstance(experiment_id := run.pop("experimentId"), str) type_name, _ = from_global_id(GlobalID.from_id(experiment_id)) assert type_name == Experiment.__name__ assert datetime.fromisoformat(run.pop("startTime")) <= datetime.fromisoformat( run.pop("endTime") ) assert run.pop("error") is None assert isinstance(run_output := run.pop("output"), dict) assert set(run_output.keys()) == {"messages"} assert (trace_id := run.pop("traceId")) is not None trace = run.pop("trace") assert trace.pop("id") assert trace.pop("traceId") == trace_id project = trace.pop("project") assert project["name"] == project_name assert not trace assert not run # check example 3 run example_id = example_ids[2] subscription_run = subscription_runs[example_id] run_id = subscription_run["id"] run = runs.pop(run_id) assert run == subscription_run assert run.pop("id") == run_id assert isinstance(experiment_id := run.pop("experimentId"), str) type_name, _ = from_global_id(GlobalID.from_id(experiment_id)) assert type_name == Experiment.__name__ assert datetime.fromisoformat(run.pop("startTime")) <= datetime.fromisoformat( run.pop("endTime") ) assert run.pop("error") == "Missing template variable(s): city" assert run.pop("output") is None assert run.pop("traceId") is None assert run.pop("trace") is None assert not run assert not runs assert not experiment async def test_all_spans_yielded_when_number_of_examples_exceeds_batch_size( self, gql_client: AsyncGraphQLClient, 
openai_api_key: str, cities_and_countries: list[tuple[str, str]], playground_city_and_country_dataset: None, custom_vcr: CustomVCR, db: DbSessionFactory, ) -> None: dataset_id = str(GlobalID(type_name=Dataset.__name__, node_id=str(1))) version_id = str(GlobalID(type_name=DatasetVersion.__name__, node_id=str(1))) variables = { "input": { "model": {"providerKey": "OPENAI", "name": "gpt-4"}, "datasetId": dataset_id, "datasetVersionId": version_id, "messages": [ { "role": "USER", "content": ( "What country is {city} in? " "Answer with the country name only without punctuation." ), } ], "templateFormat": "F_STRING", "repetitions": 1, } } payloads: dict[Optional[str], list[Any]] = {} async with gql_client.subscription( query=self.QUERY, variables=variables, operation_name="ChatCompletionOverDatasetSubscription", ) as subscription: custom_vcr.register_matcher( _request_bodies_contain_same_city.__name__, _request_bodies_contain_same_city ) # a custom request matcher is needed since the requests are concurrent with custom_vcr.use_cassette(match_on=[_request_bodies_contain_same_city.__name__]): async for payload in subscription.stream(): if ( dataset_example_id := payload["chatCompletionOverDataset"][ "datasetExampleId" ] ) not in payloads: payloads[dataset_example_id] = [] payloads[dataset_example_id].append(payload) # check subscription payloads cities_to_countries = dict(cities_and_countries) num_examples = len(cities_to_countries) example_ids = [ str(GlobalID(type_name=DatasetExample.__name__, node_id=str(index))) for index in range(1, num_examples + 1) ] assert set(payloads.keys()) == set(example_ids) | {None} # check span payloads for example_id in example_ids: assert (span_payload := payloads[example_id].pop()["chatCompletionOverDataset"])[ "__typename" ] == ChatCompletionSubscriptionResult.__name__ assert all( payload["chatCompletionOverDataset"]["__typename"] == TextChunk.__name__ for payload in payloads[example_id] ) assert (span := span_payload["span"]) assert 
isinstance(span["attributes"], str) attributes = json.loads(span["attributes"]) assert isinstance( input_messages := get_attribute_value(attributes, LLM_INPUT_MESSAGES), list, ) assert len(input_messages) == 1 assert isinstance(input_message_content := input_messages[0]["message"]["content"], str) assert (city := _extract_city(input_message_content)) in cities_to_countries assert isinstance( output_messages := get_attribute_value(attributes, LLM_OUTPUT_MESSAGES), list, ) assert len(output_messages) == 1 assert isinstance( output_message_content := output_messages[0]["message"]["content"], str ) assert output_message_content == cities_to_countries[city] response_text = "".join( payload["chatCompletionOverDataset"]["content"] for payload in payloads[example_id] ) assert response_text == output_message_content # check experiment payload assert len(payloads[None]) == 1 assert (experiment := payloads[None].pop()["chatCompletionOverDataset"]["experiment"]) experiment_id = experiment["id"] assert isinstance(experiment_id, str) async with db() as session: await verify_experiment_examples_junction_table(session, experiment_id) async def test_experiment_with_single_split_filters_examples( self, gql_client: AsyncGraphQLClient, openai_api_key: str, playground_dataset_with_splits: None, custom_vcr: CustomVCR, db: DbSessionFactory, ) -> None: """Test that providing a single split ID filters examples correctly.""" from phoenix.db import models dataset_id = str(GlobalID(type_name=Dataset.__name__, node_id=str(1))) version_id = str(GlobalID(type_name=DatasetVersion.__name__, node_id=str(1))) train_split_id = str(GlobalID(type_name="DatasetSplit", node_id=str(1))) variables = { "input": { "model": {"providerKey": "OPENAI", "name": "gpt-4"}, "datasetId": dataset_id, "datasetVersionId": version_id, "messages": [ { "role": "USER", "content": "What country is {city} in? 
Answer in one word, no punctuation.", } ], "templateFormat": "F_STRING", "repetitions": 1, "splitIds": [train_split_id], # Only train split } } payloads: dict[Optional[str], list[Any]] = {} async with gql_client.subscription( query=self.QUERY, variables=variables, operation_name="ChatCompletionOverDatasetSubscription", ) as subscription: custom_vcr.register_matcher( _request_bodies_contain_same_city.__name__, _request_bodies_contain_same_city ) with custom_vcr.use_cassette(match_on=[_request_bodies_contain_same_city.__name__]): async for payload in subscription.stream(): if ( dataset_example_id := payload["chatCompletionOverDataset"][ "datasetExampleId" ] ) not in payloads: payloads[dataset_example_id] = [] payloads[dataset_example_id].append(payload) # Should only have examples 1, 2, 3 (train split) + experiment payload # Examples 4 and 5 (test split) should NOT be present train_example_ids = [ str(GlobalID(type_name=DatasetExample.__name__, node_id=str(i))) for i in range(1, 4) ] test_example_ids = [ str(GlobalID(type_name=DatasetExample.__name__, node_id=str(i))) for i in range(4, 6) ] assert set(payloads.keys()) == set(train_example_ids) | {None} for test_id in test_example_ids: assert test_id not in payloads, f"Test example {test_id} should not be in results" # Verify experiment payload exists assert len(payloads[None]) == 1 assert (experiment_payload := payloads[None][0]["chatCompletionOverDataset"])[ "__typename" ] == ChatCompletionSubscriptionExperiment.__name__ experiment_id = experiment_payload["experiment"]["id"] # Verify experiment has the correct split association in DB async with db() as session: from phoenix.server.api.types.node import from_global_id _, exp_id = from_global_id(GlobalID.from_id(experiment_id)) result = await session.execute( select(models.ExperimentDatasetSplit).where( models.ExperimentDatasetSplit.experiment_id == exp_id ) ) split_links = result.scalars().all() assert len(split_links) == 1 assert split_links[0].dataset_split_id == 1 
# train split async def test_experiment_with_multiple_splits( self, gql_client: AsyncGraphQLClient, openai_api_key: str, playground_dataset_with_splits: None, custom_vcr: CustomVCR, db: DbSessionFactory, ) -> None: """Test that providing multiple split IDs includes examples from all specified splits.""" from phoenix.db import models dataset_id = str(GlobalID(type_name=Dataset.__name__, node_id=str(1))) version_id = str(GlobalID(type_name=DatasetVersion.__name__, node_id=str(1))) train_split_id = str(GlobalID(type_name="DatasetSplit", node_id=str(1))) test_split_id = str(GlobalID(type_name="DatasetSplit", node_id=str(2))) variables = { "input": { "model": {"providerKey": "OPENAI", "name": "gpt-4"}, "datasetId": dataset_id, "datasetVersionId": version_id, "messages": [ { "role": "USER", "content": "What country is {city} in? Answer in one word, no punctuation.", } ], "templateFormat": "F_STRING", "repetitions": 1, "splitIds": [train_split_id, test_split_id], # Both splits } } payloads: dict[Optional[str], list[Any]] = {} async with gql_client.subscription( query=self.QUERY, variables=variables, operation_name="ChatCompletionOverDatasetSubscription", ) as subscription: custom_vcr.register_matcher( _request_bodies_contain_same_city.__name__, _request_bodies_contain_same_city ) with custom_vcr.use_cassette(match_on=[_request_bodies_contain_same_city.__name__]): async for payload in subscription.stream(): if ( dataset_example_id := payload["chatCompletionOverDataset"][ "datasetExampleId" ] ) not in payloads: payloads[dataset_example_id] = [] payloads[dataset_example_id].append(payload) # Should have all examples 1-5 + experiment payload all_example_ids = [ str(GlobalID(type_name=DatasetExample.__name__, node_id=str(i))) for i in range(1, 6) ] assert set(payloads.keys()) == set(all_example_ids) | {None} # Verify experiment has both split associations in DB assert len(payloads[None]) == 1 experiment_id = payloads[None][0]["chatCompletionOverDataset"]["experiment"]["id"] 
async with db() as session: from phoenix.server.api.types.node import from_global_id _, exp_id = from_global_id(GlobalID.from_id(experiment_id)) result = await session.execute( select(models.ExperimentDatasetSplit) .where(models.ExperimentDatasetSplit.experiment_id == exp_id) .order_by(models.ExperimentDatasetSplit.dataset_split_id) ) split_links = result.scalars().all() assert len(split_links) == 2 assert split_links[0].dataset_split_id == 1 # train split assert split_links[1].dataset_split_id == 2 # test split async def test_experiment_without_splits_includes_all_examples( self, gql_client: AsyncGraphQLClient, openai_api_key: str, playground_dataset_with_splits: None, custom_vcr: CustomVCR, db: DbSessionFactory, ) -> None: """Test backward compatibility: when no splits are specified, all examples are included.""" from phoenix.db import models dataset_id = str(GlobalID(type_name=Dataset.__name__, node_id=str(1))) version_id = str(GlobalID(type_name=DatasetVersion.__name__, node_id=str(1))) variables = { "input": { "model": {"providerKey": "OPENAI", "name": "gpt-4"}, "datasetId": dataset_id, "datasetVersionId": version_id, "messages": [ { "role": "USER", "content": "What country is {city} in? 
Answer in one word, no punctuation.", } ], "templateFormat": "F_STRING", "repetitions": 1, # No splitIds provided } } payloads: dict[Optional[str], list[Any]] = {} async with gql_client.subscription( query=self.QUERY, variables=variables, operation_name="ChatCompletionOverDatasetSubscription", ) as subscription: custom_vcr.register_matcher( _request_bodies_contain_same_city.__name__, _request_bodies_contain_same_city ) with custom_vcr.use_cassette(match_on=[_request_bodies_contain_same_city.__name__]): async for payload in subscription.stream(): if ( dataset_example_id := payload["chatCompletionOverDataset"][ "datasetExampleId" ] ) not in payloads: payloads[dataset_example_id] = [] payloads[dataset_example_id].append(payload) # Should have all examples 1-5 + experiment payload all_example_ids = [ str(GlobalID(type_name=DatasetExample.__name__, node_id=str(i))) for i in range(1, 6) ] assert set(payloads.keys()) == set(all_example_ids) | {None} # Verify experiment has NO split associations in DB assert len(payloads[None]) == 1 experiment_id = payloads[None][0]["chatCompletionOverDataset"]["experiment"]["id"] async with db() as session: from phoenix.server.api.types.node import from_global_id _, exp_id = from_global_id(GlobalID.from_id(experiment_id)) result = await session.execute( select(models.ExperimentDatasetSplit).where( models.ExperimentDatasetSplit.experiment_id == exp_id ) ) split_links = result.scalars().all() assert len(split_links) == 0 # No splits associated def _request_bodies_contain_same_city(request1: VCRRequest, request2: VCRRequest) -> None: assert _extract_city(request1.body.decode()) == _extract_city(request2.body.decode()) def _extract_city(body: str) -> str: if match := re.search(r"What country is (\w+) in\?", body): return match.group(1) raise ValueError(f"Could not extract city from body: {body}") LLM = OpenInferenceSpanKindValues.LLM.value JSON = OpenInferenceMimeTypeValues.JSON.value TEXT = OpenInferenceMimeTypeValues.TEXT.value 
# Short aliases for OpenInference span-attribute keys. The tests in this
# module pop these keys from the flattened span attributes
# (``dict(flatten(json.loads(...)))``) or pass them to ``get_attribute_value``.
# Aliasing avoids repeating the ``SpanAttributes.`` prefix in every assertion.

# Span kind, model identity, and request configuration.
OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
LLM_SYSTEM = SpanAttributes.LLM_SYSTEM
LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
LLM_TOOLS = SpanAttributes.LLM_TOOLS

# Token accounting.
LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ = (
    SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ
)
LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING = (
    SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING
)

# Chat messages and prompt-template bookkeeping.
LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
# The local name drops the ``LLM_`` prefix that the underlying attribute has.
PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES

# Raw input/output payloads and their MIME types.
INPUT_VALUE = SpanAttributes.INPUT_VALUE
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/Arize-ai/phoenix'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.