from __future__ import annotations as _annotations
import json
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import datetime, timezone
from functools import cached_property
from typing import Any, cast
import pytest
from inline_snapshot import snapshot
from pydantic import BaseModel
from typing_extensions import TypedDict
from pydantic_ai import (
BinaryContent,
DocumentUrl,
ImageUrl,
ModelRequest,
ModelResponse,
RetryPromptPart,
SystemPromptPart,
TextPart,
ThinkingPart,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
VideoUrl,
)
from pydantic_ai.agent import Agent
from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry
from pydantic_ai.usage import RequestUsage
from ..conftest import IsDatetime, IsNow, IsStr, raise_if_exception, try_import
from .mock_async_stream import MockAsyncStream
with try_import() as imports_successful:
from mistralai import (
AssistantMessage as MistralAssistantMessage,
ChatCompletionChoice as MistralChatCompletionChoice,
CompletionChunk as MistralCompletionChunk,
CompletionResponseStreamChoice as MistralCompletionResponseStreamChoice,
CompletionResponseStreamChoiceFinishReason as MistralCompletionResponseStreamChoiceFinishReason,
DeltaMessage as MistralDeltaMessage,
FunctionCall as MistralFunctionCall,
Mistral,
TextChunk as MistralTextChunk,
UsageInfo as MistralUsageInfo,
)
from mistralai.models import (
ChatCompletionResponse as MistralChatCompletionResponse,
CompletionEvent as MistralCompletionEvent,
SDKError,
ToolCall as MistralToolCall,
)
from mistralai.types.basemodel import Unset as MistralUnset
from pydantic_ai.models.mistral import MistralModel, MistralStreamedResponse
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
from pydantic_ai.providers.mistral import MistralProvider
from pydantic_ai.providers.openai import OpenAIProvider
MockChatCompletion = MistralChatCompletionResponse | Exception
MockCompletionEvent = MistralCompletionEvent | Exception
pytestmark = [
pytest.mark.skipif(not imports_successful(), reason='mistral or openai not installed'),
pytest.mark.anyio,
]
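# Minimal stand-in for the Mistral SDK configuration; `get_server_details` supplies
# the server URL that `MistralModel.base_url` reports (see `test_init` below).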
@dataclass
class MockSdkConfiguration:
def get_server_details(self) -> tuple[str, ...]:
return ('https://api.mistral.ai',)
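# Drop-in replacement for the `Mistral` client: serves canned completions (or streams
# of completion events) in order, advancing `index` on each call.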
@dataclass
class MockMistralAI:
completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None
stream: Sequence[MockCompletionEvent] | Sequence[Sequence[MockCompletionEvent]] | None = None
index: int = 0
@cached_property
def sdk_configuration(self) -> MockSdkConfiguration:
return MockSdkConfiguration()
@cached_property
def chat(self) -> Any:
if self.stream:
return type(
'Chat',
(),
{'stream_async': self.chat_completions_create, 'complete_async': self.chat_completions_create},
)
else:
return type('Chat', (), {'complete_async': self.chat_completions_create})
@classmethod
def create_mock(cls, completions: MockChatCompletion | Sequence[MockChatCompletion]) -> Mistral:
return cast(Mistral, cls(completions=completions))
@classmethod
def create_stream_mock(
cls, completions_streams: Sequence[MockCompletionEvent] | Sequence[Sequence[MockCompletionEvent]]
) -> Mistral:
return cast(Mistral, cls(stream=completions_streams))
async def chat_completions_create( # pragma: lax no cover
self, *_args: Any, stream: bool = False, **_kwargs: Any
) -> MistralChatCompletionResponse | MockAsyncStream[MockCompletionEvent]:
if stream or self.stream:
assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
if isinstance(self.stream[0], list):
response = MockAsyncStream(iter(cast(list[MockCompletionEvent], self.stream[self.index])))
else:
response = MockAsyncStream(iter(cast(list[MockCompletionEvent], self.stream)))
else:
            assert self.completions is not None, 'you can only use `stream=False` if `completions` is provided'
if isinstance(self.completions, Sequence):
raise_if_exception(self.completions[self.index])
response = cast(MistralChatCompletionResponse, self.completions[self.index])
else:
raise_if_exception(self.completions)
response = cast(MistralChatCompletionResponse, self.completions)
self.index += 1
return response
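# Build a deterministic non-streaming chat completion; `created=1704067200` is
# 2024-01-01 00:00:00 UTC, and `with_created=False` sets it to 0 so the response
# timestamp falls back to "now".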
def completion_message(
message: MistralAssistantMessage, *, usage: MistralUsageInfo | None = None, with_created: bool = True
) -> MistralChatCompletionResponse:
return MistralChatCompletionResponse(
id='123',
choices=[MistralChatCompletionChoice(finish_reason='stop', index=0, message=message)],
created=1704067200 if with_created else 0, # 2024-01-01
model='mistral-large-123',
object='chat.completion',
usage=usage or MistralUsageInfo(prompt_tokens=1, completion_tokens=1, total_tokens=1),
)
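# Wrap delta messages in a streaming completion event, one stream choice per delta.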
def chunk(
delta: list[MistralDeltaMessage],
finish_reason: MistralCompletionResponseStreamChoiceFinishReason | None = None,
with_created: bool = True,
) -> MistralCompletionEvent:
return MistralCompletionEvent(
data=MistralCompletionChunk(
id='x',
choices=[
MistralCompletionResponseStreamChoice(index=index, delta=delta, finish_reason=finish_reason)
for index, delta in enumerate(delta)
],
created=1704067200 if with_created else 0, # 2024-01-01
model='gpt-4',
object='chat.completion.chunk',
usage=MistralUsageInfo(prompt_tokens=1, completion_tokens=1, total_tokens=1),
)
)
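# Delta content may be a plain string or a list of content chunks; the two helpers
# below cover both shapes.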
def text_chunk(
text: str, finish_reason: MistralCompletionResponseStreamChoiceFinishReason | None = None
) -> MistralCompletionEvent:
return chunk([MistralDeltaMessage(content=text, role='assistant')], finish_reason=finish_reason)
def text_content_chunk(
text: str, finish_reason: MistralCompletionResponseStreamChoiceFinishReason | None = None
) -> MistralCompletionEvent:
return chunk(
[MistralDeltaMessage(content=[MistralTextChunk(text=text)], role='assistant')], finish_reason=finish_reason
)
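# Streaming delta that carries tool calls rather than text content.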
def func_chunk(
tool_calls: list[MistralToolCall], finish_reason: MistralCompletionResponseStreamChoiceFinishReason | None = None
) -> MistralCompletionEvent:
return chunk([MistralDeltaMessage(tool_calls=tool_calls, role='assistant')], finish_reason=finish_reason)
#####################
## Init
#####################
def test_init():
m = MistralModel('mistral-large-latest', provider=MistralProvider(api_key='foobar'))
assert m.model_name == 'mistral-large-latest'
assert m.base_url == 'https://api.mistral.ai'
#####################
## Completion
#####################
async def test_multiple_completions(allow_model_requests: None):
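    """Two sequential runs sharing message history.

    The first completion uses `with_created=False` (created=0), so its response
    timestamp falls back to "now"; the second uses the fixed 2024-01-01 created time.
    """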
completions = [
        # First completion: created=0, so the response timestamp falls back to "now" (matches IsNow)
completion_message(
MistralAssistantMessage(content='world'),
usage=MistralUsageInfo(prompt_tokens=1, completion_tokens=1, total_tokens=1),
with_created=False,
),
# Second completion: created is fixed 2024-01-01 00:00:00 UTC
completion_message(MistralAssistantMessage(content='hello again')),
]
mock_client = MockMistralAI.create_mock(completions)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model=model)
result = await agent.run('hello')
assert result.output == 'world'
assert result.usage().input_tokens == 1
assert result.usage().output_tokens == 1
result = await agent.run('hello again', message_history=result.new_messages())
assert result.output == 'hello again'
assert result.usage().input_tokens == 1
assert result.usage().output_tokens == 1
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))],
run_id=IsStr(),
),
ModelResponse(
parts=[TextPart(content='world')],
usage=RequestUsage(input_tokens=1, output_tokens=1),
model_name='mistral-large-123',
timestamp=IsNow(tz=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
ModelRequest(
parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))],
run_id=IsStr(),
),
ModelResponse(
parts=[TextPart(content='hello again')],
usage=RequestUsage(input_tokens=1, output_tokens=1),
model_name='mistral-large-123',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
]
)
async def test_three_completions(allow_model_requests: None):
completions = [
completion_message(
MistralAssistantMessage(content='world'),
usage=MistralUsageInfo(prompt_tokens=1, completion_tokens=1, total_tokens=1),
),
completion_message(MistralAssistantMessage(content='hello again')),
completion_message(MistralAssistantMessage(content='final message')),
]
mock_client = MockMistralAI.create_mock(completions)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model=model)
result = await agent.run('hello')
assert result.output == 'world'
assert result.usage().input_tokens == 1
assert result.usage().output_tokens == 1
result = await agent.run('hello again', message_history=result.all_messages())
assert result.output == 'hello again'
assert result.usage().input_tokens == 1
assert result.usage().output_tokens == 1
result = await agent.run('final message', message_history=result.all_messages())
assert result.output == 'final message'
assert result.usage().input_tokens == 1
assert result.usage().output_tokens == 1
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))],
run_id=IsStr(),
),
ModelResponse(
parts=[TextPart(content='world')],
usage=RequestUsage(input_tokens=1, output_tokens=1),
model_name='mistral-large-123',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
ModelRequest(
parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))],
run_id=IsStr(),
),
ModelResponse(
parts=[TextPart(content='hello again')],
usage=RequestUsage(input_tokens=1, output_tokens=1),
model_name='mistral-large-123',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
ModelRequest(
parts=[UserPromptPart(content='final message', timestamp=IsNow(tz=timezone.utc))],
run_id=IsStr(),
),
ModelResponse(
parts=[TextPart(content='final message')],
usage=RequestUsage(input_tokens=1, output_tokens=1),
model_name='mistral-large-123',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
]
)
#####################
## Completion Stream
#####################
async def test_stream_text(allow_model_requests: None):
stream = [
text_chunk('hello '),
text_chunk('world '),
text_chunk('welcome '),
        text_content_chunk('mistral'),
chunk([]),
]
mock_client = MockMistralAI.create_stream_mock(stream)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model=model)
async with agent.run_stream('') as result:
assert not result.is_complete
assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(
['hello ', 'hello world ', 'hello world welcome ', 'hello world welcome mistral']
)
assert result.is_complete
assert result.usage().input_tokens == 5
assert result.usage().output_tokens == 5
async def test_stream_text_finish_reason(allow_model_requests: None):
stream = [
text_chunk('hello '),
        text_content_chunk('world'),
text_chunk('.', finish_reason='stop'),
]
mock_client = MockMistralAI.create_stream_mock(stream)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model=model)
async with agent.run_stream('') as result:
assert not result.is_complete
assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(
['hello ', 'hello world', 'hello world.']
)
assert result.is_complete
async def test_no_delta(allow_model_requests: None):
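    """A leading chunk with no choices contributes no text but still counts towards usage."""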
stream = [
chunk([], with_created=False),
text_chunk('hello '),
text_chunk('world'),
]
mock_client = MockMistralAI.create_stream_mock(stream)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model=model)
async with agent.run_stream('') as result:
assert not result.is_complete
assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world'])
assert result.is_complete
assert result.usage().input_tokens == 3
assert result.usage().output_tokens == 3
#####################
## Completion Model Structured
#####################
async def test_request_native_with_arguments_dict_response(allow_model_requests: None):
class CityLocation(BaseModel):
city: str
country: str
completion = completion_message(
MistralAssistantMessage(
content=None,
role='assistant',
tool_calls=[
MistralToolCall(
id='123',
function=MistralFunctionCall(arguments={'city': 'paris', 'country': 'france'}, name='final_result'),
type='function',
)
],
),
usage=MistralUsageInfo(prompt_tokens=1, completion_tokens=2, total_tokens=3),
)
mock_client = MockMistralAI.create_mock(completion)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model=model, output_type=CityLocation)
result = await agent.run('User prompt value')
assert result.output == CityLocation(city='paris', country='france')
assert result.usage().input_tokens == 1
assert result.usage().output_tokens == 2
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc))],
run_id=IsStr(),
),
ModelResponse(
parts=[
ToolCallPart(
tool_name='final_result',
args={'city': 'paris', 'country': 'france'},
tool_call_id='123',
)
],
usage=RequestUsage(input_tokens=1, output_tokens=2),
model_name='mistral-large-123',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result',
content='Final result processed.',
tool_call_id='123',
timestamp=IsNow(tz=timezone.utc),
)
],
run_id=IsStr(),
),
]
)
async def test_request_native_with_arguments_str_response(allow_model_requests: None):
class CityLocation(BaseModel):
city: str
country: str
completion = completion_message(
MistralAssistantMessage(
content=None,
role='assistant',
tool_calls=[
MistralToolCall(
id='123',
function=MistralFunctionCall(
arguments='{"city": "paris", "country": "france"}', name='final_result'
),
type='function',
)
],
)
)
mock_client = MockMistralAI.create_mock(completion)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model=model, output_type=CityLocation)
result = await agent.run('User prompt value')
assert result.output == CityLocation(city='paris', country='france')
assert result.usage().input_tokens == 1
assert result.usage().output_tokens == 1
assert result.usage().details == {}
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc))],
run_id=IsStr(),
),
ModelResponse(
parts=[
ToolCallPart(
tool_name='final_result',
args='{"city": "paris", "country": "france"}',
tool_call_id='123',
)
],
usage=RequestUsage(input_tokens=1, output_tokens=1),
model_name='mistral-large-123',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result',
content='Final result processed.',
tool_call_id='123',
timestamp=IsNow(tz=timezone.utc),
)
],
run_id=IsStr(),
),
]
)
async def test_request_output_type_with_arguments_str_response(allow_model_requests: None):
completion = completion_message(
MistralAssistantMessage(
content=None,
role='assistant',
tool_calls=[
MistralToolCall(
id='123',
function=MistralFunctionCall(arguments='{"response": 42}', name='final_result'),
type='function',
)
],
)
)
mock_client = MockMistralAI.create_mock(completion)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model=model, output_type=int, system_prompt='System prompt value')
result = await agent.run('User prompt value')
assert result.output == 42
assert result.usage().input_tokens == 1
assert result.usage().output_tokens == 1
assert result.usage().details == {}
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='System prompt value', timestamp=IsNow(tz=timezone.utc)),
UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc)),
],
run_id=IsStr(),
),
ModelResponse(
parts=[
ToolCallPart(
tool_name='final_result',
args='{"response": 42}',
tool_call_id='123',
)
],
usage=RequestUsage(input_tokens=1, output_tokens=1),
model_name='mistral-large-123',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result',
content='Final result processed.',
tool_call_id='123',
timestamp=IsNow(tz=timezone.utc),
)
],
run_id=IsStr(),
),
]
)
#####################
## Completion Model Structured Stream (JSON Mode)
#####################
async def test_stream_structured_with_all_type(allow_model_requests: None):
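    """Stream a TypedDict covering strings, ints, bools, nullables, arrays and dicts,
    checking the partial output produced after each JSON fragment."""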
class MyTypedDict(TypedDict, total=False):
first: str
second: int
bool_value: bool
nullable_value: int | None
array_value: list[str]
dict_value: dict[str, Any]
dict_int_value: dict[str, int]
dict_str_value: dict[int, str]
stream = [
text_chunk('{'),
text_chunk('"first": "One'),
text_chunk(
'", "second": 2',
),
text_chunk(
', "bool_value": true',
),
text_chunk(
', "nullable_value": null',
),
text_chunk(
', "array_value": ["A", "B", "C"]',
),
text_chunk(
', "dict_value": {"A": "A", "B":"B"}',
),
text_chunk(
', "dict_int_value": {"A": 1, "B":2}',
),
text_chunk('}'),
chunk([]),
]
mock_client = MockMistralAI.create_stream_mock(stream)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model, output_type=MyTypedDict)
async with agent.run_stream('User prompt value') as result:
assert not result.is_complete
v = [dict(c) async for c in result.stream_output(debounce_by=None)]
assert v == snapshot(
[
{'first': 'One'},
{'first': 'One', 'second': 2},
{'first': 'One', 'second': 2, 'bool_value': True},
{'first': 'One', 'second': 2, 'bool_value': True, 'nullable_value': None},
{
'first': 'One',
'second': 2,
'bool_value': True,
'nullable_value': None,
'array_value': ['A', 'B', 'C'],
},
{
'first': 'One',
'second': 2,
'bool_value': True,
'nullable_value': None,
'array_value': ['A', 'B', 'C'],
'dict_value': {'A': 'A', 'B': 'B'},
},
{
'first': 'One',
'second': 2,
'bool_value': True,
'nullable_value': None,
'array_value': ['A', 'B', 'C'],
'dict_value': {'A': 'A', 'B': 'B'},
'dict_int_value': {'A': 1, 'B': 2},
},
{
'first': 'One',
'second': 2,
'bool_value': True,
'nullable_value': None,
'array_value': ['A', 'B', 'C'],
'dict_value': {'A': 'A', 'B': 'B'},
'dict_int_value': {'A': 1, 'B': 2},
},
]
)
assert result.is_complete
assert result.usage().input_tokens == 10
assert result.usage().output_tokens == 10
# double check usage matches stream count
assert result.usage().output_tokens == len(stream)
async def test_stream_result_type_primitive_dict(allow_model_requests: None):
    """Test streaming a primitive result type with the Pydantic AI model response format."""
class MyTypedDict(TypedDict, total=False):
first: str
second: str
stream = [
text_chunk('{'),
text_chunk('"'),
text_chunk('f'),
text_chunk('i'),
text_chunk('r'),
text_chunk('s'),
text_chunk('t'),
text_chunk('"'),
text_chunk(':'),
text_chunk(' '),
text_chunk('"'),
text_chunk('O'),
text_chunk('n'),
text_chunk('e'),
text_chunk('"'),
text_chunk(','),
text_chunk(' '),
text_chunk('"'),
text_chunk('s'),
text_chunk('e'),
text_chunk('c'),
text_chunk('o'),
text_chunk('n'),
text_chunk('d'),
text_chunk('"'),
text_chunk(':'),
text_chunk(' '),
text_chunk('"'),
text_chunk('T'),
text_chunk('w'),
text_chunk('o'),
text_chunk('"'),
text_chunk('}'),
chunk([]),
]
mock_client = MockMistralAI.create_stream_mock(stream)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model=model, output_type=MyTypedDict)
async with agent.run_stream('User prompt value') as result:
assert not result.is_complete
v = [c async for c in result.stream_output(debounce_by=None)]
assert v == snapshot(
[
{'first': 'O'},
{'first': 'On'},
{'first': 'One'},
{'first': 'One'},
{'first': 'One'},
{'first': 'One'},
{'first': 'One'},
{'first': 'One'},
{'first': 'One'},
{'first': 'One'},
{'first': 'One'},
{'first': 'One'},
{'first': 'One'},
{'first': 'One'},
{'first': 'One'},
{'first': 'One'},
{'first': 'One', 'second': ''},
{'first': 'One', 'second': 'T'},
{'first': 'One', 'second': 'Tw'},
{'first': 'One', 'second': 'Two'},
{'first': 'One', 'second': 'Two'},
{'first': 'One', 'second': 'Two'},
]
)
assert result.is_complete
assert result.usage().input_tokens == 34
assert result.usage().output_tokens == 34
# double check usage matches stream count
assert result.usage().output_tokens == len(stream)
async def test_stream_result_type_primitive_int(allow_model_requests: None):
    """Test streaming a primitive result type with the Pydantic AI model response format."""
stream = [
# {'response':
text_chunk('{'),
text_chunk('"resp'),
text_chunk('onse":'),
text_chunk('1'),
text_chunk('}'),
chunk([]),
]
mock_client = MockMistralAI.create_stream_mock(stream)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model=model, output_type=int)
async with agent.run_stream('User prompt value') as result:
assert not result.is_complete
v = [c async for c in result.stream_output(debounce_by=None)]
assert v == snapshot([1, 1])
assert result.is_complete
assert result.usage().input_tokens == 6
assert result.usage().output_tokens == 6
# double check usage matches stream count
assert result.usage().output_tokens == len(stream)
async def test_stream_result_type_primitive_array(allow_model_requests: None):
    """Test streaming a primitive result type with the Pydantic AI model response format."""
stream = [
# {'response':
text_chunk('{'),
text_chunk('"resp'),
text_chunk('onse":'),
text_chunk('['),
text_chunk('"'),
text_chunk('f'),
text_chunk('i'),
text_chunk('r'),
text_chunk('s'),
text_chunk('t'),
text_chunk('"'),
text_chunk(','),
text_chunk('"'),
text_chunk('O'),
text_chunk('n'),
text_chunk('e'),
text_chunk('"'),
text_chunk(','),
text_chunk('"'),
text_chunk('s'),
text_chunk('e'),
text_chunk('c'),
text_chunk('o'),
text_chunk('n'),
text_chunk('d'),
text_chunk('"'),
text_chunk(','),
text_chunk('"'),
text_chunk('T'),
text_chunk('w'),
text_chunk('o'),
text_chunk('"'),
text_chunk(']'),
text_chunk('}'),
chunk([]),
]
mock_client = MockMistralAI.create_stream_mock(stream)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model, output_type=list[str])
async with agent.run_stream('User prompt value') as result:
assert not result.is_complete
v = [c async for c in result.stream_output(debounce_by=None)]
assert v == snapshot(
[
[''],
['f'],
['fi'],
['fir'],
['firs'],
['first'],
['first'],
['first'],
['first', ''],
['first', 'O'],
['first', 'On'],
['first', 'One'],
['first', 'One'],
['first', 'One'],
['first', 'One', ''],
['first', 'One', 's'],
['first', 'One', 'se'],
['first', 'One', 'sec'],
['first', 'One', 'seco'],
['first', 'One', 'secon'],
['first', 'One', 'second'],
['first', 'One', 'second'],
['first', 'One', 'second'],
['first', 'One', 'second', ''],
['first', 'One', 'second', 'T'],
['first', 'One', 'second', 'Tw'],
['first', 'One', 'second', 'Two'],
['first', 'One', 'second', 'Two'],
['first', 'One', 'second', 'Two'],
['first', 'One', 'second', 'Two'],
]
)
assert result.is_complete
assert result.usage().input_tokens == 35
assert result.usage().output_tokens == 35
# double check usage matches stream count
assert result.usage().output_tokens == len(stream)
async def test_stream_result_type_basemodel_with_default_params(allow_model_requests: None):
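    """With defaulted fields, partial validation succeeds early, so a partial model
    is yielded for nearly every chunk."""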
class MyTypedBaseModel(BaseModel):
        first: str = ''  # Note: has a default value
        second: str = ''  # Note: has a default value
stream = [
text_chunk('{'),
text_chunk('"'),
text_chunk('f'),
text_chunk('i'),
text_chunk('r'),
text_chunk('s'),
text_chunk('t'),
text_chunk('"'),
text_chunk(':'),
text_chunk(' '),
text_chunk('"'),
text_chunk('O'),
text_chunk('n'),
text_chunk('e'),
text_chunk('"'),
text_chunk(','),
text_chunk(' '),
text_chunk('"'),
text_chunk('s'),
text_chunk('e'),
text_chunk('c'),
text_chunk('o'),
text_chunk('n'),
text_chunk('d'),
text_chunk('"'),
text_chunk(':'),
text_chunk(' '),
text_chunk('"'),
text_chunk('T'),
text_chunk('w'),
text_chunk('o'),
text_chunk('"'),
text_chunk('}'),
chunk([]),
]
mock_client = MockMistralAI.create_stream_mock(stream)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model=model, output_type=MyTypedBaseModel)
async with agent.run_stream('User prompt value') as result:
assert not result.is_complete
v = [c async for c in result.stream_output(debounce_by=None)]
assert v == snapshot(
[
MyTypedBaseModel(first='O', second=''),
MyTypedBaseModel(first='On', second=''),
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second='T'),
MyTypedBaseModel(first='One', second='Tw'),
MyTypedBaseModel(first='One', second='Two'),
MyTypedBaseModel(first='One', second='Two'),
MyTypedBaseModel(first='One', second='Two'),
]
)
assert result.is_complete
assert result.usage().input_tokens == 34
assert result.usage().output_tokens == 34
# double check usage matches stream count
assert result.usage().output_tokens == len(stream)
async def test_stream_result_type_basemodel_with_required_params(allow_model_requests: None):
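    """With required fields, models are only yielded once both required fields
    appear in the partial JSON."""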
class MyTypedBaseModel(BaseModel):
        first: str  # Note: required field
        second: str  # Note: required field
stream = [
text_chunk('{'),
text_chunk('"'),
text_chunk('f'),
text_chunk('i'),
text_chunk('r'),
text_chunk('s'),
text_chunk('t'),
text_chunk('"'),
text_chunk(':'),
text_chunk(' '),
text_chunk('"'),
text_chunk('O'),
text_chunk('n'),
text_chunk('e'),
text_chunk('"'),
text_chunk(','),
text_chunk(' '),
text_chunk('"'),
text_chunk('s'),
text_chunk('e'),
text_chunk('c'),
text_chunk('o'),
text_chunk('n'),
text_chunk('d'),
text_chunk('"'),
text_chunk(':'),
text_chunk(' '),
text_chunk('"'),
text_chunk('T'),
text_chunk('w'),
text_chunk('o'),
text_chunk('"'),
text_chunk('}'),
chunk([]),
]
mock_client = MockMistralAI.create_stream_mock(stream)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model=model, output_type=MyTypedBaseModel)
async with agent.run_stream('User prompt value') as result:
assert not result.is_complete
v = [c async for c in result.stream_output(debounce_by=None)]
assert v == snapshot(
[
MyTypedBaseModel(first='One', second=''),
MyTypedBaseModel(first='One', second='T'),
MyTypedBaseModel(first='One', second='Tw'),
MyTypedBaseModel(first='One', second='Two'),
MyTypedBaseModel(first='One', second='Two'),
MyTypedBaseModel(first='One', second='Two'),
]
)
assert result.is_complete
assert result.usage().input_tokens == 34
assert result.usage().output_tokens == 34
        # double check usage matches stream count
assert result.usage().output_tokens == len(stream)
#####################
## Completion Function call
#####################
async def test_request_tool_call(allow_model_requests: None):
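    """Tool-call round trip: the first location raises `ModelRetry`, the retry with
    London succeeds, and the model then returns a final text response."""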
completion = [
completion_message(
MistralAssistantMessage(
content=None,
role='assistant',
tool_calls=[
MistralToolCall(
id='1',
function=MistralFunctionCall(arguments='{"loc_name": "San Fransisco"}', name='get_location'),
type='function',
)
],
),
usage=MistralUsageInfo(
completion_tokens=1,
prompt_tokens=2,
total_tokens=3,
),
),
completion_message(
MistralAssistantMessage(
content=None,
role='assistant',
tool_calls=[
MistralToolCall(
id='2',
function=MistralFunctionCall(arguments='{"loc_name": "London"}', name='get_location'),
type='function',
)
],
),
usage=MistralUsageInfo(
completion_tokens=2,
prompt_tokens=3,
total_tokens=6,
),
),
completion_message(MistralAssistantMessage(content='final response', role='assistant')),
]
mock_client = MockMistralAI.create_mock(completion)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model, system_prompt='this is the system prompt')
@agent.tool_plain
async def get_location(loc_name: str) -> str:
if loc_name == 'London':
return json.dumps({'lat': 51, 'lng': 0})
else:
raise ModelRetry('Wrong location, please try again')
result = await agent.run('Hello')
assert result.output == 'final response'
assert result.usage().input_tokens == 6
assert result.usage().output_tokens == 4
assert result.usage().total_tokens == 10
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='this is the system prompt', timestamp=IsNow(tz=timezone.utc)),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
],
run_id=IsStr(),
),
ModelResponse(
parts=[
ToolCallPart(
tool_name='get_location',
args='{"loc_name": "San Fransisco"}',
tool_call_id='1',
)
],
usage=RequestUsage(input_tokens=2, output_tokens=1),
model_name='mistral-large-123',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
ModelRequest(
parts=[
RetryPromptPart(
content='Wrong location, please try again',
tool_name='get_location',
tool_call_id='1',
timestamp=IsNow(tz=timezone.utc),
)
],
run_id=IsStr(),
),
ModelResponse(
parts=[
ToolCallPart(
tool_name='get_location',
args='{"loc_name": "London"}',
tool_call_id='2',
)
],
usage=RequestUsage(input_tokens=3, output_tokens=2),
model_name='mistral-large-123',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='get_location',
content='{"lat": 51, "lng": 0}',
tool_call_id='2',
timestamp=IsNow(tz=timezone.utc),
)
],
run_id=IsStr(),
),
ModelResponse(
parts=[TextPart(content='final response')],
usage=RequestUsage(input_tokens=1, output_tokens=1),
model_name='mistral-large-123',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
]
)
async def test_request_tool_call_with_result_type(allow_model_requests: None):
class MyTypedDict(TypedDict, total=False):
lat: int
lng: int
completion = [
completion_message(
MistralAssistantMessage(
content=None,
role='assistant',
tool_calls=[
MistralToolCall(
id='1',
function=MistralFunctionCall(arguments='{"loc_name": "San Fransisco"}', name='get_location'),
type='function',
)
],
),
usage=MistralUsageInfo(
completion_tokens=1,
prompt_tokens=2,
total_tokens=3,
),
),
completion_message(
MistralAssistantMessage(
content=None,
role='assistant',
tool_calls=[
MistralToolCall(
id='2',
function=MistralFunctionCall(arguments='{"loc_name": "London"}', name='get_location'),
type='function',
)
],
),
usage=MistralUsageInfo(
completion_tokens=2,
prompt_tokens=3,
total_tokens=6,
),
),
completion_message(
MistralAssistantMessage(
content=None,
role='assistant',
tool_calls=[
MistralToolCall(
id='1',
function=MistralFunctionCall(arguments='{"lat": 51, "lng": 0}', name='final_result'),
type='function',
)
],
),
usage=MistralUsageInfo(
completion_tokens=1,
prompt_tokens=2,
total_tokens=3,
),
),
]
mock_client = MockMistralAI.create_mock(completion)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model, system_prompt='this is the system prompt', output_type=MyTypedDict)
@agent.tool_plain
async def get_location(loc_name: str) -> str:
if loc_name == 'London':
return json.dumps({'lat': 51, 'lng': 0})
else:
raise ModelRetry('Wrong location, please try again')
result = await agent.run('Hello')
assert result.output == {'lat': 51, 'lng': 0}
assert result.usage().input_tokens == 7
assert result.usage().output_tokens == 4
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='this is the system prompt', timestamp=IsNow(tz=timezone.utc)),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
],
run_id=IsStr(),
),
ModelResponse(
parts=[
ToolCallPart(
tool_name='get_location',
args='{"loc_name": "San Fransisco"}',
tool_call_id='1',
)
],
usage=RequestUsage(input_tokens=2, output_tokens=1),
model_name='mistral-large-123',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
ModelRequest(
parts=[
RetryPromptPart(
content='Wrong location, please try again',
tool_name='get_location',
tool_call_id='1',
timestamp=IsNow(tz=timezone.utc),
)
],
run_id=IsStr(),
),
ModelResponse(
parts=[
ToolCallPart(
tool_name='get_location',
args='{"loc_name": "London"}',
tool_call_id='2',
)
],
usage=RequestUsage(input_tokens=3, output_tokens=2),
model_name='mistral-large-123',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='get_location',
content='{"lat": 51, "lng": 0}',
tool_call_id='2',
timestamp=IsNow(tz=timezone.utc),
)
],
run_id=IsStr(),
),
ModelResponse(
parts=[
ToolCallPart(
tool_name='final_result',
args='{"lat": 51, "lng": 0}',
tool_call_id='1',
)
],
usage=RequestUsage(input_tokens=2, output_tokens=1),
model_name='mistral-large-123',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result',
content='Final result processed.',
tool_call_id='1',
timestamp=IsNow(tz=timezone.utc),
)
],
run_id=IsStr(),
),
]
)
#####################
## Completion Function call Stream
#####################
async def test_stream_tool_call_with_return_type(allow_model_requests: None):
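    """Stream a `get_location` tool call, then a `final_result` tool call that
    produces the structured output."""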
class MyTypedDict(TypedDict, total=False):
won: bool
completion = [
[
chunk(
delta=[MistralDeltaMessage(role=MistralUnset(), content='', tool_calls=MistralUnset())],
finish_reason='tool_calls',
),
func_chunk(
tool_calls=[
MistralToolCall(
id='1',
function=MistralFunctionCall(arguments='{"loc_name": "San Fransisco"}', name='get_location'),
type='function',
)
],
finish_reason='tool_calls',
),
],
[
chunk(
delta=[MistralDeltaMessage(role=MistralUnset(), content='', tool_calls=MistralUnset())],
finish_reason='tool_calls',
),
func_chunk(
tool_calls=[
MistralToolCall(
id='1',
function=MistralFunctionCall(arguments='{"won": true}', name='final_result'),
type=None,
)
],
finish_reason='tool_calls',
),
],
]
mock_client = MockMistralAI.create_stream_mock(completion)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model, system_prompt='this is the system prompt', output_type=MyTypedDict)
@agent.tool_plain
async def get_location(loc_name: str) -> str:
return json.dumps({'lat': 51, 'lng': 0})
async with agent.run_stream('User prompt value') as result:
assert not result.is_complete
v = [c async for c in result.stream_output(debounce_by=None)]
assert v == snapshot([{'won': True}])
assert result.is_complete
assert result.timestamp() == datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)
assert result.usage().input_tokens == 4
assert result.usage().output_tokens == 4
# double check usage matches stream count
assert result.usage().output_tokens == 4
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='this is the system prompt', timestamp=IsNow(tz=timezone.utc)),
UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc)),
],
run_id=IsStr(),
),
ModelResponse(
parts=[
ToolCallPart(
tool_name='get_location',
args='{"loc_name": "San Fransisco"}',
tool_call_id='1',
)
],
usage=RequestUsage(input_tokens=2, output_tokens=2),
model_name='gpt-4',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'tool_calls'},
provider_response_id='x',
finish_reason='tool_call',
run_id=IsStr(),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='get_location',
content='{"lat": 51, "lng": 0}',
tool_call_id='1',
timestamp=IsNow(tz=timezone.utc),
)
],
run_id=IsStr(),
),
ModelResponse(
parts=[ToolCallPart(tool_name='final_result', args='{"won": true}', tool_call_id='1')],
usage=RequestUsage(input_tokens=2, output_tokens=2),
model_name='gpt-4',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'tool_calls'},
provider_response_id='x',
finish_reason='tool_call',
run_id=IsStr(),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result',
content='Final result processed.',
tool_call_id='1',
timestamp=IsNow(tz=timezone.utc),
)
],
run_id=IsStr(),
),
]
)
assert await result.get_output() == {'won': True}
async def test_stream_tool_call(allow_model_requests: None):
completion = [
[
chunk(
delta=[MistralDeltaMessage(role=MistralUnset(), content='', tool_calls=MistralUnset())],
finish_reason='tool_calls',
),
func_chunk(
tool_calls=[
MistralToolCall(
id='1',
function=MistralFunctionCall(arguments='{"loc_name": "San Fransisco"}', name='get_location'),
type='function',
)
],
finish_reason='tool_calls',
),
],
[
chunk(delta=[MistralDeltaMessage(role='assistant', content='', tool_calls=MistralUnset())]),
chunk(delta=[MistralDeltaMessage(role=MistralUnset(), content='final ', tool_calls=MistralUnset())]),
chunk(delta=[MistralDeltaMessage(role=MistralUnset(), content='response', tool_calls=MistralUnset())]),
chunk(
delta=[MistralDeltaMessage(role=MistralUnset(), content='', tool_calls=MistralUnset())],
finish_reason='stop',
),
],
]
mock_client = MockMistralAI.create_stream_mock(completion)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model, system_prompt='this is the system prompt')
@agent.tool_plain
async def get_location(loc_name: str) -> str:
return json.dumps({'lat': 51, 'lng': 0})
async with agent.run_stream('User prompt value') as result:
assert not result.is_complete
v = [c async for c in result.stream_output(debounce_by=None)]
assert v == snapshot(['final ', 'final response'])
assert result.is_complete
assert result.timestamp() == datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)
assert result.usage().input_tokens == 6
assert result.usage().output_tokens == 6
# double check usage matches stream count
assert result.usage().output_tokens == 6
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='this is the system prompt', timestamp=IsNow(tz=timezone.utc)),
UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc)),
],
run_id=IsStr(),
),
ModelResponse(
parts=[
ToolCallPart(
tool_name='get_location',
args='{"loc_name": "San Fransisco"}',
tool_call_id='1',
)
],
usage=RequestUsage(input_tokens=2, output_tokens=2),
model_name='gpt-4',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'tool_calls'},
provider_response_id='x',
finish_reason='tool_call',
run_id=IsStr(),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='get_location',
content='{"lat": 51, "lng": 0}',
tool_call_id='1',
timestamp=IsNow(tz=timezone.utc),
)
],
run_id=IsStr(),
),
ModelResponse(
parts=[TextPart(content='final response')],
usage=RequestUsage(input_tokens=4, output_tokens=4),
model_name='gpt-4',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='x',
finish_reason='stop',
run_id=IsStr(),
),
]
)
async def test_stream_tool_call_with_retry(allow_model_requests: None):
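    """Streamed tool calls with a retry: the first location is rejected via
    `ModelRetry` before the model streams a final text response."""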
completion = [
[
chunk(
delta=[MistralDeltaMessage(role=MistralUnset(), content='', tool_calls=MistralUnset())],
finish_reason='tool_calls',
),
func_chunk(
tool_calls=[
MistralToolCall(
id='1',
function=MistralFunctionCall(arguments='{"loc_name": "San Fransisco"}', name='get_location'),
type='function',
)
],
finish_reason='tool_calls',
),
],
[
func_chunk(
tool_calls=[
MistralToolCall(
id='2',
function=MistralFunctionCall(arguments='{"loc_name": "London"}', name='get_location'),
type='function',
)
],
finish_reason='tool_calls',
),
],
[
chunk(delta=[MistralDeltaMessage(role='assistant', content='', tool_calls=MistralUnset())]),
chunk(delta=[MistralDeltaMessage(role=MistralUnset(), content='final ', tool_calls=MistralUnset())]),
chunk(delta=[MistralDeltaMessage(role=MistralUnset(), content='response', tool_calls=MistralUnset())]),
chunk(
delta=[MistralDeltaMessage(role=MistralUnset(), content='', tool_calls=MistralUnset())],
finish_reason='stop',
),
],
]
mock_client = MockMistralAI.create_stream_mock(completion)
model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(model, system_prompt='this is the system prompt')
@agent.tool_plain
async def get_location(loc_name: str) -> str:
if loc_name == 'London':
return json.dumps({'lat': 51, 'lng': 0})
else:
raise ModelRetry('Wrong location, please try again')
async with agent.run_stream('User prompt value') as result:
assert not result.is_complete
v = [c async for c in result.stream_text(debounce_by=None)]
assert v == snapshot(['final ', 'final response'])
assert result.is_complete
assert result.timestamp() == datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)
assert result.usage().input_tokens == 7
assert result.usage().output_tokens == 7
# double check usage matches stream count
assert result.usage().output_tokens == 7
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='this is the system prompt', timestamp=IsNow(tz=timezone.utc)),
UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc)),
],
run_id=IsStr(),
),
ModelResponse(
parts=[
ToolCallPart(
tool_name='get_location',
args='{"loc_name": "San Fransisco"}',
tool_call_id='1',
)
],
usage=RequestUsage(input_tokens=2, output_tokens=2),
model_name='gpt-4',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'tool_calls'},
provider_response_id='x',
finish_reason='tool_call',
run_id=IsStr(),
),
ModelRequest(
parts=[
RetryPromptPart(
content='Wrong location, please try again',
tool_name='get_location',
tool_call_id='1',
timestamp=IsNow(tz=timezone.utc),
)
],
run_id=IsStr(),
),
ModelResponse(
parts=[
ToolCallPart(
tool_name='get_location',
args='{"loc_name": "London"}',
tool_call_id='2',
)
],
usage=RequestUsage(input_tokens=1, output_tokens=1),
model_name='gpt-4',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'tool_calls'},
provider_response_id='x',
finish_reason='tool_call',
run_id=IsStr(),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='get_location',
content='{"lat": 51, "lng": 0}',
tool_call_id='2',
timestamp=IsNow(tz=timezone.utc),
)
],
run_id=IsStr(),
),
ModelResponse(
parts=[TextPart(content='final response')],
usage=RequestUsage(input_tokens=4, output_tokens=4),
model_name='gpt-4',
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='x',
finish_reason='stop',
run_id=IsStr(),
),
]
)
#####################
## Test methods
#####################
def test_generate_user_output_format_complex(mistral_api_key: str):
"""
Single test that includes properties exercising every branch
in _get_python_type (anyOf, arrays, objects with additionalProperties, etc.).
"""
schema = {
'properties': {
'prop_anyOf': {'anyOf': [{'type': 'string'}, {'type': 'integer'}]},
'prop_no_type': {
# no 'type' key
},
'prop_simple_string': {'type': 'string'},
'prop_array_booleans': {'type': 'array', 'items': {'type': 'boolean'}},
'prop_object_simple': {'type': 'object', 'additionalProperties': {'type': 'boolean'}},
'prop_object_array': {
'type': 'object',
'additionalProperties': {'type': 'array', 'items': {'type': 'integer'}},
},
'prop_object_object': {'type': 'object', 'additionalProperties': {'type': 'object'}},
'prop_object_unknown': {'type': 'object', 'additionalProperties': {'type': 'someUnknownType'}},
'prop_unrecognized_type': {'type': 'customSomething'},
}
}
m = MistralModel('', json_mode_schema_prompt='{schema}', provider=MistralProvider(api_key=mistral_api_key))
result = m._generate_user_output_format([schema]) # pyright: ignore[reportPrivateUsage]
assert result.content == (
"{'prop_anyOf': 'Optional[str]', "
"'prop_no_type': 'Any', "
"'prop_simple_string': 'str', "
"'prop_array_booleans': 'list[bool]', "
"'prop_object_simple': 'dict[str, bool]', "
"'prop_object_array': 'dict[str, list[int]]', "
"'prop_object_object': 'dict[str, dict[str, Any]]', "
"'prop_object_unknown': 'dict[str, Any]', "
"'prop_unrecognized_type': 'Any'}"
)
def test_generate_user_output_format_multiple(mistral_api_key: str):
schema = {'properties': {'prop_anyOf': {'anyOf': [{'type': 'string'}, {'type': 'integer'}]}}}
m = MistralModel('', json_mode_schema_prompt='{schema}', provider=MistralProvider(api_key=mistral_api_key))
result = m._generate_user_output_format([schema, schema]) # pyright: ignore[reportPrivateUsage]
assert result.content == "[{'prop_anyOf': 'Optional[str]'}, {'prop_anyOf': 'Optional[str]'}]"
@pytest.mark.parametrize(
'desc, schema, data, expected',
[
(
'Missing required parameter',
{
'required': ['name', 'age'],
'properties': {
'name': {'type': 'string'},
'age': {'type': 'integer'},
},
},
{'name': 'Alice'}, # Missing "age"
False,
),
(
'Type mismatch (expected string, got int)',
{'required': ['name'], 'properties': {'name': {'type': 'string'}}},
{'name': 123}, # Should be a string, got int
False,
),
(
'Array parameter check (param not a list)',
{'required': ['tags'], 'properties': {'tags': {'type': 'array', 'items': {'type': 'string'}}}},
{'tags': 'not a list'}, # Not a list
False,
),
(
'Array item type mismatch',
{'required': ['tags'], 'properties': {'tags': {'type': 'array', 'items': {'type': 'string'}}}},
{'tags': ['ok', 123, 'still ok']}, # One item is int, not str
False,
),
(
'Nested object fails',
{
'required': ['user'],
'properties': {
'user': {
'type': 'object',
'required': ['id', 'profile'],
'properties': {
'id': {'type': 'integer'},
'profile': {
'type': 'object',
'required': ['address'],
'properties': {'address': {'type': 'string'}},
},
},
}
},
},
{'user': {'id': 101, 'profile': {}}}, # Missing "address" in the nested profile
False,
),
(
'All requirements met (success)',
{
'required': ['name', 'age', 'tags', 'user'],
'properties': {
'name': {'type': 'string'},
'age': {'type': 'integer'},
'tags': {'type': 'array', 'items': {'type': 'string'}},
'user': {
'type': 'object',
'required': ['id', 'profile'],
'properties': {
'id': {'type': 'integer'},
'profile': {
'type': 'object',
'required': ['address'],
'properties': {'address': {'type': 'string'}},
},
},
},
},
},
{
'name': 'Alice',
'age': 30,
'tags': ['tag1', 'tag2'],
'user': {'id': 101, 'profile': {'address': '123 Street'}},
},
True,
),
],
)
def test_validate_required_json_schema(desc: str, schema: dict[str, Any], data: dict[str, Any], expected: bool) -> None:
result = MistralStreamedResponse._validate_required_json_schema(data, schema) # pyright: ignore[reportPrivateUsage]
assert result == expected, f'{desc} — expected {expected}, got {result}'
@pytest.mark.vcr()
async def test_image_as_binary_content_tool_response(
allow_model_requests: None, mistral_api_key: str, image_content: BinaryContent
):
m = MistralModel('pixtral-12b-latest', provider=MistralProvider(api_key=mistral_api_key))
agent = Agent(m)
@agent.tool_plain
async def get_image() -> BinaryContent:
return image_content
result = await agent.run(['What fruit is in the image you can get from the get_image tool? Call the tool.'])
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[
UserPromptPart(
content=['What fruit is in the image you can get from the get_image tool? Call the tool.'],
timestamp=IsDatetime(),
)
],
run_id=IsStr(),
),
ModelResponse(
parts=[ToolCallPart(tool_name='get_image', args='{}', tool_call_id='GJYBCIkcS')],
usage=RequestUsage(input_tokens=65, output_tokens=16),
model_name='pixtral-12b-latest',
timestamp=IsDatetime(),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'tool_calls'},
provider_response_id='412174432ea945889703eac58b44ae35',
finish_reason='tool_call',
run_id=IsStr(),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='get_image',
content='See file 1c8566',
tool_call_id='GJYBCIkcS',
timestamp=IsDatetime(),
),
UserPromptPart(
content=[
'This is file 1c8566:',
image_content,
],
timestamp=IsDatetime(),
),
],
run_id=IsStr(),
),
ModelResponse(
parts=[
TextPart(
content='The image you\'re referring to, labeled as "file 1c8566," shows a kiwi fruit that has been cut in half. The kiwi is known for its bright green flesh with tiny black seeds and a central white core. It is a popular fruit known for its sweet taste and nutritional benefits.'
)
],
usage=RequestUsage(input_tokens=2931, output_tokens=66),
model_name='pixtral-12b-latest',
timestamp=IsDatetime(),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='049b5c7704554d3396e727a95cb6d947',
finish_reason='stop',
run_id=IsStr(),
),
]
)
async def test_image_url_input(allow_model_requests: None):
c = completion_message(MistralAssistantMessage(content='world', role='assistant'))
mock_client = MockMistralAI.create_mock(c)
m = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(m)
result = await agent.run(
[
'hello',
ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg'),
]
)
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[
UserPromptPart(
content=[
'hello',
ImageUrl(
url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg',
identifier='bd38f5',
),
],
timestamp=IsDatetime(),
)
],
run_id=IsStr(),
),
ModelResponse(
parts=[TextPart(content='world')],
usage=RequestUsage(input_tokens=1, output_tokens=1),
model_name='mistral-large-123',
timestamp=IsDatetime(),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
]
)
async def test_image_as_binary_content_input(allow_model_requests: None):
c = completion_message(MistralAssistantMessage(content='world', role='assistant'))
mock_client = MockMistralAI.create_mock(c)
m = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(m)
base64_content = (
b'/9j/4AAQSkZJRgABAQEAYABgAAD/4QBYRXhpZgAATU0AKgAAAAgAA1IBAAEAAAABAAAAPgIBAAEAAAABAAAARgMBAAEAAAABAAAA'
b'WgAAAAAAAAAE'
)
result = await agent.run(['hello', BinaryContent(data=base64_content, media_type='image/jpeg')])
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[
UserPromptPart(
content=[
'hello',
BinaryContent(data=base64_content, media_type='image/jpeg', identifier='cb93e3'),
],
timestamp=IsDatetime(),
)
],
run_id=IsStr(),
),
ModelResponse(
parts=[TextPart(content='world')],
usage=RequestUsage(input_tokens=1, output_tokens=1),
model_name='mistral-large-123',
timestamp=IsDatetime(),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
]
)
async def test_pdf_url_input(allow_model_requests: None):
c = completion_message(MistralAssistantMessage(content='world', role='assistant'))
mock_client = MockMistralAI.create_mock(c)
m = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(m)
result = await agent.run(
[
'hello',
DocumentUrl(url='https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf'),
]
)
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[
UserPromptPart(
content=[
'hello',
DocumentUrl(
url='https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf',
identifier='c6720d',
),
],
timestamp=IsDatetime(),
)
],
run_id=IsStr(),
),
ModelResponse(
parts=[TextPart(content='world')],
usage=RequestUsage(input_tokens=1, output_tokens=1),
model_name='mistral-large-123',
timestamp=IsDatetime(),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
]
)
async def test_pdf_as_binary_content_input(allow_model_requests: None):
c = completion_message(MistralAssistantMessage(content='world', role='assistant'))
mock_client = MockMistralAI.create_mock(c)
m = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(m)
base64_content = b'%PDF-1.\rtrailer<</Root<</Pages<</Kids[<</MediaBox[0 0 3 3]>>>>>>>>>'
result = await agent.run(['hello', BinaryContent(data=base64_content, media_type='application/pdf')])
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[
UserPromptPart(
content=[
'hello',
BinaryContent(data=base64_content, media_type='application/pdf', identifier='b9d976'),
],
timestamp=IsDatetime(),
)
],
run_id=IsStr(),
),
ModelResponse(
parts=[TextPart(content='world')],
usage=RequestUsage(input_tokens=1, output_tokens=1),
model_name='mistral-large-123',
timestamp=IsDatetime(),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
]
)
async def test_txt_url_input(allow_model_requests: None):
c = completion_message(MistralAssistantMessage(content='world', role='assistant'))
mock_client = MockMistralAI.create_mock(c)
m = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(m)
with pytest.raises(RuntimeError, match='DocumentUrl other than PDF is not supported in Mistral.'):
await agent.run(
[
'hello',
DocumentUrl(url='https://examplefiles.org/files/documents/plaintext-example-file-download.txt'),
]
)
async def test_audio_as_binary_content_input(allow_model_requests: None):
c = completion_message(MistralAssistantMessage(content='world', role='assistant'))
mock_client = MockMistralAI.create_mock(c)
m = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(m)
base64_content = b'//uQZ'
with pytest.raises(RuntimeError, match='BinaryContent other than image or PDF is not supported in Mistral.'):
await agent.run(['hello', BinaryContent(data=base64_content, media_type='audio/wav')])
async def test_video_url_input(allow_model_requests: None):
c = completion_message(MistralAssistantMessage(content='world', role='assistant'))
mock_client = MockMistralAI.create_mock(c)
m = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(m)
with pytest.raises(RuntimeError, match='VideoUrl is not supported in Mistral.'):
await agent.run(['hello', VideoUrl(url='https://www.google.com')])
def test_model_status_error(allow_model_requests: None) -> None:
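    """An `SDKError` with a 500 status surfaces as `ModelHTTPError`."""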
mock_client = MockMistralAI.create_mock(
SDKError(
'test error',
status_code=500,
body='test error',
)
)
m = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(m)
with pytest.raises(ModelHTTPError) as exc_info:
agent.run_sync('hello')
assert str(exc_info.value) == snapshot('status_code: 500, model_name: mistral-large-latest, body: test error')
def test_model_non_http_error(allow_model_requests: None) -> None:
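    """An `SDKError` without a proper HTTP error status (here a 300 redirect)
    surfaces as a generic `ModelAPIError`."""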
mock_client = MockMistralAI.create_mock(
SDKError(
'Connection error',
status_code=300,
body='redirect',
)
)
m = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(m)
with pytest.raises(ModelAPIError) as exc_info:
agent.run_sync('hello')
assert exc_info.value.model_name == 'mistral-large-latest'
async def test_mistral_model_instructions(allow_model_requests: None, mistral_api_key: str):
c = completion_message(MistralAssistantMessage(content='world', role='assistant'))
mock_client = MockMistralAI.create_mock(c)
m = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
agent = Agent(m, instructions='You are a helpful assistant.')
result = await agent.run('hello')
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[UserPromptPart(content='hello', timestamp=IsDatetime())],
instructions='You are a helpful assistant.',
run_id=IsStr(),
),
ModelResponse(
parts=[TextPart(content='world')],
usage=RequestUsage(input_tokens=1, output_tokens=1),
model_name='mistral-large-123',
timestamp=IsDatetime(),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='123',
finish_reason='stop',
run_id=IsStr(),
),
]
)
@pytest.mark.vcr()
async def test_mistral_model_thinking_part(allow_model_requests: None, openai_api_key: str, mistral_api_key: str):
openai_model = OpenAIResponsesModel('o3-mini', provider=OpenAIProvider(api_key=openai_api_key))
settings = OpenAIResponsesModelSettings(openai_reasoning_effort='high', openai_reasoning_summary='detailed')
agent = Agent(openai_model, model_settings=settings)
result = await agent.run('How do I cross the street?')
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[UserPromptPart(content='How do I cross the street?', timestamp=IsDatetime())],
run_id=IsStr(),
),
ModelResponse(
parts=[
ThinkingPart(
content=IsStr(),
id='rs_68bb645d50f48196a0c49fd603b87f4503498c8aa840cf12',
signature=IsStr(),
provider_name='openai',
),
ThinkingPart(content=IsStr(), id='rs_68bb645d50f48196a0c49fd603b87f4503498c8aa840cf12'),
ThinkingPart(content=IsStr(), id='rs_68bb645d50f48196a0c49fd603b87f4503498c8aa840cf12'),
TextPart(content=IsStr(), id='msg_68bb64663d1c8196b9c7e78e7018cc4103498c8aa840cf12'),
],
usage=RequestUsage(input_tokens=13, output_tokens=1616, details={'reasoning_tokens': 1344}),
model_name='o3-mini-2025-01-31',
timestamp=IsDatetime(),
provider_name='openai',
provider_url='https://api.openai.com/v1/',
provider_details={'finish_reason': 'completed'},
provider_response_id='resp_68bb6452990081968f5aff503a55e3b903498c8aa840cf12',
finish_reason='stop',
run_id=IsStr(),
),
]
)
mistral_model = MistralModel('magistral-medium-latest', provider=MistralProvider(api_key=mistral_api_key))
result = await agent.run(
'Considering the way to cross the street, analogously, how do I cross the river?',
model=mistral_model,
message_history=result.all_messages(),
)
assert result.new_messages() == snapshot(
[
ModelRequest(
parts=[
UserPromptPart(
content='Considering the way to cross the street, analogously, how do I cross the river?',
timestamp=IsDatetime(),
)
],
run_id=IsStr(),
),
ModelResponse(
parts=[
ThinkingPart(content=IsStr()),
TextPart(content=IsStr()),
],
usage=RequestUsage(input_tokens=664, output_tokens=747),
model_name='magistral-medium-latest',
timestamp=IsDatetime(),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='9abe8b736bff46af8e979b52334a57cd',
finish_reason='stop',
run_id=IsStr(),
),
]
)
@pytest.mark.vcr()
async def test_mistral_model_thinking_part_iter(allow_model_requests: None, mistral_api_key: str):
model = MistralModel('magistral-medium-latest', provider=MistralProvider(api_key=mistral_api_key))
agent = Agent(model)
async with agent.iter(user_prompt='How do I cross the street?') as agent_run:
async for node in agent_run:
if Agent.is_model_request_node(node) or Agent.is_call_tools_node(node):
async with node.stream(agent_run.ctx) as request_stream:
async for _ in request_stream:
pass
assert agent_run.result is not None
assert agent_run.result.all_messages() == snapshot(
[
ModelRequest(
parts=[
UserPromptPart(
content='How do I cross the street?',
timestamp=IsDatetime(),
)
],
run_id=IsStr(),
),
ModelResponse(
parts=[
ThinkingPart(
content='Okay, the user is asking how to cross the street. I know that crossing the street safely involves a few key steps: first, look both ways to check for oncoming traffic; second, use a crosswalk if one is available; third, obey any traffic signals or signs that may be present; and finally, proceed with caution until you have safely reached the other side. Let me compile this information into a clear and concise response.'
),
TextPart(
content="""\
To cross the street safely, follow these steps:
1. Look both ways to check for oncoming traffic.
2. Use a crosswalk if one is available.
3. Obey any traffic signals or signs that may be present.
4. Proceed with caution until you have safely reached the other side.
```markdown
To cross the street safely, follow these steps:
1. Look both ways to check for oncoming traffic.
2. Use a crosswalk if one is available.
3. Obey any traffic signals or signs that may be present.
4. Proceed with caution until you have safely reached the other side.
```
By following these steps, you can ensure a safe crossing.\
"""
),
],
usage=RequestUsage(input_tokens=10, output_tokens=232),
model_name='magistral-medium-latest',
timestamp=IsDatetime(),
provider_name='mistral',
provider_url='https://api.mistral.ai',
provider_details={'finish_reason': 'stop'},
provider_response_id='9f9d90210f194076abeee223863eaaf0',
finish_reason='stop',
run_id=IsStr(),
),
]
)