test_agent.py (207 kB)

import asyncio
import json
import re
import sys
from collections import defaultdict
from collections.abc import AsyncIterable, Callable
from dataclasses import dataclass, replace
from datetime import timezone
from typing import Any, Literal, Union

import httpx
import pytest
from dirty_equals import IsJson
from inline_snapshot import snapshot
from pydantic import BaseModel, TypeAdapter, field_validator
from pydantic_core import to_json
from typing_extensions import Self

from pydantic_ai import (
    AbstractToolset,
    Agent,
    AgentStreamEvent,
    AudioUrl,
    BinaryContent,
    BinaryImage,
    CombinedToolset,
    DocumentUrl,
    FunctionToolset,
    ImageUrl,
    IncompleteToolCall,
    ModelMessage,
    ModelMessagesTypeAdapter,
    ModelProfile,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    ModelRetry,
    PrefixedToolset,
    RetryPromptPart,
    RunContext,
    SystemPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturn,
    ToolReturnPart,
    UnexpectedModelBehavior,
    UserError,
    UserPromptPart,
    VideoUrl,
    capture_run_messages,
)
from pydantic_ai._output import (
    NativeOutput,
    NativeOutputSchema,
    OutputSpec,
    PromptedOutput,
    TextOutput,
    ToolOutputSchema,
)
from pydantic_ai.agent import AgentRunResult, WrapperAgent
from pydantic_ai.builtin_tools import WebSearchTool
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.models.test import TestModel
from pydantic_ai.output import StructuredDict, ToolOutput
from pydantic_ai.result import RunUsage
from pydantic_ai.run import AgentRunResultEvent
from pydantic_ai.settings import ModelSettings
from pydantic_ai.tools import DeferredToolRequests, DeferredToolResults, ToolDefinition, ToolDenied
from pydantic_ai.usage import RequestUsage

from .conftest import IsDatetime, IsNow, IsStr, TestEnv

pytestmark = pytest.mark.anyio


def test_result_tuple():
    def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.output_tools is not None
        args_json = '{"response": ["foo", "bar"]}'
        return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)])

    agent = Agent(FunctionModel(return_tuple), output_type=tuple[str, str])
    result = agent.run_sync('Hello')
    assert result.output == ('foo', 'bar')
    assert result.response == snapshot(
        ModelResponse(
            parts=[ToolCallPart(tool_name='final_result', args='{"response": ["foo", "bar"]}', tool_call_id=IsStr())],
            usage=RequestUsage(input_tokens=51, output_tokens=7),
            model_name='function:return_tuple:',
            timestamp=IsDatetime(),
        )
    )


class Person(BaseModel):
    name: str


def test_result_list_of_models_with_stringified_response():
    def return_list(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.output_tools is not None
        # Simulate providers that return the nested payload as a JSON string under "response"
        args_json = json.dumps({'response': json.dumps([{'name': 'John Doe'}, {'name': 'Jane Smith'}])})
        return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)])

    agent = Agent(FunctionModel(return_list), output_type=list[Person])
    result = agent.run_sync('Hello')
    assert result.output == snapshot([Person(name='John Doe'), Person(name='Jane Smith')])


class Foo(BaseModel):
    a: int
    b: str


def test_result_pydantic_model():
    def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.output_tools is not None
        args_json = '{"a": 1, "b": "foo"}'
        return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)])

    agent = Agent(FunctionModel(return_model), output_type=Foo)
    result = agent.run_sync('Hello')
    assert isinstance(result.output, Foo)
    assert result.output.model_dump() == {'a': 1, 'b': 'foo'}


def test_result_pydantic_model_retry():
    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.output_tools is not None
        if len(messages) == 1:
            args_json = '{"a": "wrong", "b": "foo"}'
        else:
            args_json = '{"a": 42, "b": "foo"}'
        return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)])

    agent = Agent(FunctionModel(return_model), output_type=Foo)
    assert agent.name is None
    result = agent.run_sync('Hello')
    assert agent.name == 'agent'
    assert isinstance(result.output, Foo)
    assert result.output.model_dump() == {'a': 42, 'b': 'foo'}
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args='{"a": "wrong", "b": "foo"}', tool_call_id=IsStr())],
                usage=RequestUsage(input_tokens=51, output_tokens=7),
                model_name='function:return_model:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        tool_name='final_result',
                        content=[
                            {
                                'type': 'int_parsing',
                                'loc': ('a',),
                                'msg': 'Input should be a valid integer, unable to parse string as an integer',
                                'input': 'wrong',
                            }
                        ],
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr())],
                usage=RequestUsage(input_tokens=87, output_tokens=14),
                model_name='function:return_model:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
        ]
    )
    assert result.all_messages_json().startswith(b'[{"parts":[{"content":"Hello",')


def test_result_pydantic_model_validation_error():
    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.output_tools is not None
        if len(messages) == 1:
            args_json = '{"a": 1, "b": "foo"}'
        else:
            args_json = '{"a": 1, "b": "bar"}'
        return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)])

    class Bar(BaseModel):
        a: int
        b: str

        @field_validator('b')
        def check_b(cls, v: str) -> str:
            if v == 'foo':
                raise ValueError('must not be foo')
            return v

    agent = Agent(FunctionModel(return_model), output_type=Bar)
    result = agent.run_sync('Hello')
    assert isinstance(result.output, Bar)
    assert result.output.model_dump() == snapshot({'a': 1, 'b': 'bar'})
    messages_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result.all_messages()]
    assert messages_part_kinds == snapshot(
        [
            ('request', ['user-prompt']),
            ('response', ['tool-call']),
            ('request', ['retry-prompt']),
            ('response', ['tool-call']),
            ('request', ['tool-return']),
        ]
    )

    user_retry = result.all_messages()[2]
    assert isinstance(user_retry, ModelRequest)
    retry_prompt = user_retry.parts[0]
    assert isinstance(retry_prompt, RetryPromptPart)
    assert retry_prompt.model_response() == snapshot("""\
1 validation errors: [
  {
    "type": "value_error",
    "loc": [
      "b"
    ],
    "msg": "Value error, must not be foo",
    "input": "foo"
  }
]

Fix the errors and try again.""")


def test_output_validator():
    def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        assert info.output_tools is not None
        if len(messages) == 1:
            args_json = '{"a": 41, "b": "foo"}'
        else:
            args_json = '{"a": 42, "b": "foo"}'
        return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)])

    agent = Agent(FunctionModel(return_model), output_type=Foo)
    @agent.output_validator
    def validate_output(ctx: RunContext[None], o: Foo) -> Foo:
        assert ctx.tool_name == 'final_result'
        if o.a == 42:
            return o
        else:
            raise ModelRetry('"a" should be 42')

    result = agent.run_sync('Hello')
    assert isinstance(result.output, Foo)
    assert result.output.model_dump() == {'a': 42, 'b': 'foo'}
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args='{"a": 41, "b": "foo"}', tool_call_id=IsStr())],
                usage=RequestUsage(input_tokens=51, output_tokens=7),
                model_name='function:return_model:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content='"a" should be 42',
                        tool_name='final_result',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr())],
                usage=RequestUsage(input_tokens=63, output_tokens=14),
                model_name='function:return_model:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
        ]
    )


def test_plain_response_then_tuple():
    call_index = 0

    def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        nonlocal call_index
        assert info.output_tools is not None
        call_index += 1
        if call_index == 1:
            return ModelResponse(parts=[TextPart('hello')])
        else:
            args_json = '{"response": ["foo", "bar"]}'
            return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)])

    agent = Agent(FunctionModel(return_tuple), output_type=ToolOutput(tuple[str, str]))
    result = agent.run_sync('Hello')
    assert result.output == ('foo', 'bar')
    assert call_index == 2
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='hello')],
                usage=RequestUsage(input_tokens=51, output_tokens=1),
                model_name='function:return_tuple:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content='Please include your response in a tool call.',
                        timestamp=IsNow(tz=timezone.utc),
                        tool_call_id=IsStr(),
                    )
                ]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(tool_name='final_result', args='{"response": ["foo", "bar"]}', tool_call_id=IsStr())
                ],
                usage=RequestUsage(input_tokens=68, output_tokens=8),
                model_name='function:return_tuple:',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id=IsStr(),
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
        ]
    )
    assert result._output_tool_name == 'final_result'  # pyright: ignore[reportPrivateUsage]
    assert result.all_messages(output_tool_return_content='foobar')[-1] == snapshot(
        ModelRequest(
            parts=[
                ToolReturnPart(
                    tool_name='final_result', content='foobar', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
                )
            ]
        )
    )
    assert result.all_messages()[-1] == snapshot(
        ModelRequest(
            parts=[
                ToolReturnPart(
                    tool_name='final_result',
                    content='Final result processed.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
                )
            ]
        )
    )


def test_output_tool_return_content_str_return():
    agent = Agent('test')
    result = agent.run_sync('Hello')
    assert result.output == 'success (no tool calls)'
    assert result.response == snapshot(
        ModelResponse(
            parts=[TextPart(content='success (no tool calls)')],
            usage=RequestUsage(input_tokens=51, output_tokens=4),
            model_name='test',
            timestamp=IsDatetime(),
        )
    )
    msg = re.escape('Cannot set output tool return content when the return type is `str`.')
    with pytest.raises(ValueError, match=msg):
        result.all_messages(output_tool_return_content='foobar')


def test_output_tool_return_content_no_tool():
    agent = Agent('test', output_type=int)
    result = agent.run_sync('Hello')
    assert result.output == 0
    result._output_tool_name = 'wrong'  # pyright: ignore[reportPrivateUsage]
    with pytest.raises(LookupError, match=re.escape("No tool call found with tool name 'wrong'.")):
        result.all_messages(output_tool_return_content='foobar')


def test_response_tuple():
    m = TestModel()
    agent = Agent(m, output_type=tuple[str, str])
    assert isinstance(agent._output_schema, ToolOutputSchema)  # pyright: ignore[reportPrivateUsage]
    result = agent.run_sync('Hello')
    assert result.output == snapshot(('a', 'a'))
    assert m.last_model_request_parameters is not None
    assert m.last_model_request_parameters.function_tools == snapshot([])
    assert m.last_model_request_parameters.allow_text_output is False
    assert m.last_model_request_parameters.output_tools is not None
    assert len(m.last_model_request_parameters.output_tools) == 1
    assert m.last_model_request_parameters.output_tools == snapshot(
        [
            ToolDefinition(
                name='final_result',
                description='The final response which ends this conversation',
                parameters_json_schema={
                    'properties': {
                        'response': {
                            'maxItems': 2,
                            'minItems': 2,
                            'prefixItems': [{'type': 'string'}, {'type': 'string'}],
                            'type': 'array',
                        }
                    },
                    'required': ['response'],
                    'type': 'object',
                },
                outer_typed_dict_key='response',
                kind='output',
            )
        ]
    )


def upcase(text: str) -> str:
    return text.upper()


@pytest.mark.parametrize(
    'input_union_callable',
    [
        lambda: Union[str, Foo],  # noqa: UP007
        lambda: Union[Foo, str],  # noqa: UP007
        lambda: str | Foo,
        lambda: Foo | str,
        lambda: [Foo, str],
        lambda: [TextOutput(upcase), ToolOutput(Foo)],
    ],
    ids=[
        'Union[str, Foo]',
        'Union[Foo, str]',
        'str | Foo',
        'Foo | str',
        '[Foo, str]',
        '[TextOutput(upcase), ToolOutput(Foo)]',
    ],
)
def test_response_union_allow_str(input_union_callable: Callable[[], Any]):
    try:
        union = input_union_callable()
    except TypeError:  # pragma: lax no cover
        pytest.skip('Python version does not support `|` syntax for unions')
    m = TestModel()
    agent: Agent[None, str | Foo] = Agent(m, output_type=union)
    got_tool_call_name = 'unset'

    @agent.output_validator
    def validate_output(ctx: RunContext[None], o: Any) -> Any:
        nonlocal got_tool_call_name
        got_tool_call_name = ctx.tool_name
        return o

    assert agent._output_schema.allows_text  # pyright: ignore[reportPrivateUsage]
    result = agent.run_sync('Hello')
    assert isinstance(result.output, str)
    assert result.output.lower() == snapshot('success (no tool calls)')
    assert got_tool_call_name == snapshot(None)
    assert m.last_model_request_parameters is not None
    assert m.last_model_request_parameters.function_tools == snapshot([])
    assert m.last_model_request_parameters.allow_text_output is True
    assert m.last_model_request_parameters.output_tools is not None
    assert len(m.last_model_request_parameters.output_tools) == 1
    assert m.last_model_request_parameters.output_tools == snapshot(
        [
            ToolDefinition(
                name='final_result',
                description='The final response which ends this conversation',
                parameters_json_schema={
                    'properties': {
                        'a': {'type': 'integer'},
                        'b': {'type': 'string'},
                    },
                    'required': ['a', 'b'],
                    'title': 'Foo',
                    'type': 'object',
                },
                kind='output',
            )
        ]
    )


# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false


@pytest.mark.parametrize(
'union_code', [ pytest.param('OutputType = Union[Foo, Bar]'), pytest.param('OutputType = [Foo, Bar]'), pytest.param('OutputType = [ToolOutput(Foo), ToolOutput(Bar)]'), pytest.param('OutputType = Foo | Bar'), pytest.param('OutputType: TypeAlias = Foo | Bar'), pytest.param( 'type OutputType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 12), reason='3.12+') ), ], ) def test_response_multiple_return_tools(create_module: Callable[[str], Any], union_code: str): module_code = f''' from pydantic import BaseModel from typing import Union from typing_extensions import TypeAlias from pydantic_ai import ToolOutput class Foo(BaseModel): a: int b: str class Bar(BaseModel): """This is a bar model.""" b: str {union_code} ''' mod = create_module(module_code) m = TestModel() agent = Agent(m, output_type=mod.OutputType) got_tool_call_name = 'unset' @agent.output_validator def validate_output(ctx: RunContext[None], o: Any) -> Any: nonlocal got_tool_call_name got_tool_call_name = ctx.tool_name return o result = agent.run_sync('Hello') assert result.output == mod.Foo(a=0, b='a') assert got_tool_call_name == snapshot('final_result_Foo') assert m.last_model_request_parameters is not None assert m.last_model_request_parameters.function_tools == snapshot([]) assert m.last_model_request_parameters.allow_text_output is False assert m.last_model_request_parameters.output_tools is not None assert len(m.last_model_request_parameters.output_tools) == 2 assert m.last_model_request_parameters.output_tools == snapshot( [ ToolDefinition( name='final_result_Foo', description='Foo: The final response which ends this conversation', parameters_json_schema={ 'properties': { 'a': {'type': 'integer'}, 'b': {'type': 'string'}, }, 'required': ['a', 'b'], 'title': 'Foo', 'type': 'object', }, kind='output', ), ToolDefinition( name='final_result_Bar', description='This is a bar model.', parameters_json_schema={ 'properties': {'b': {'type': 'string'}}, 'required': ['b'], 'title': 'Bar', 'type': 'object', }, kind='output', ), ] ) result = agent.run_sync('Hello', model=TestModel(seed=1)) assert result.output == mod.Bar(b='b') assert got_tool_call_name == snapshot('final_result_Bar') def test_output_type_with_two_descriptions(): class MyOutput(BaseModel): """Description from docstring""" valid: bool m = TestModel() agent = Agent(m, output_type=ToolOutput(MyOutput, description='Description from ToolOutput')) result = agent.run_sync('Hello') assert result.output == snapshot(MyOutput(valid=False)) assert m.last_model_request_parameters is not None assert m.last_model_request_parameters.output_tools == snapshot( [ ToolDefinition( name='final_result', description='Description from ToolOutput. 
Description from docstring', parameters_json_schema={ 'properties': {'valid': {'type': 'boolean'}}, 'required': ['valid'], 'title': 'MyOutput', 'type': 'object', }, kind='output', ) ] ) def test_output_type_tool_output_union(): class Foo(BaseModel): a: int b: str class Bar(BaseModel): c: bool m = TestModel() marker: ToolOutput[Foo | Bar] = ToolOutput(Foo | Bar, strict=False) # type: ignore agent = Agent(m, output_type=marker) result = agent.run_sync('Hello') assert result.output == snapshot(Foo(a=0, b='a')) assert m.last_model_request_parameters is not None assert m.last_model_request_parameters.output_tools == snapshot( [ ToolDefinition( name='final_result', description='The final response which ends this conversation', parameters_json_schema={ '$defs': { 'Bar': { 'properties': {'c': {'type': 'boolean'}}, 'required': ['c'], 'title': 'Bar', 'type': 'object', }, 'Foo': { 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'string'}}, 'required': ['a', 'b'], 'title': 'Foo', 'type': 'object', }, }, 'properties': {'response': {'anyOf': [{'$ref': '#/$defs/Foo'}, {'$ref': '#/$defs/Bar'}]}}, 'required': ['response'], 'type': 'object', }, outer_typed_dict_key='response', strict=False, kind='output', ) ] ) def test_output_type_function(): class Weather(BaseModel): temperature: float description: str def get_weather(city: str) -> Weather: return Weather(temperature=28.7, description='sunny') output_tools = None def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None nonlocal output_tools output_tools = info.output_tools args_json = '{"city": "Mexico City"}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) agent = Agent(FunctionModel(call_tool), output_type=get_weather) result = agent.run_sync('Mexico City') assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) assert output_tools == snapshot( [ ToolDefinition( name='final_result', description='The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, 'properties': {'city': {'type': 'string'}}, 'required': ['city'], 'type': 'object', }, kind='output', ) ] ) def test_output_type_function_with_run_context(): class Weather(BaseModel): temperature: float description: str def get_weather(ctx: RunContext[None], city: str) -> Weather: assert ctx is not None return Weather(temperature=28.7, description='sunny') output_tools = None def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None nonlocal output_tools output_tools = info.output_tools args_json = '{"city": "Mexico City"}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) agent = Agent(FunctionModel(call_tool), output_type=get_weather) result = agent.run_sync('Mexico City') assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) assert output_tools == snapshot( [ ToolDefinition( name='final_result', description='The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, 'properties': {'city': {'type': 'string'}}, 'required': ['city'], 'type': 'object', }, kind='output', ) ] ) def test_output_type_bound_instance_method(): class Weather(BaseModel): temperature: float description: str def get_weather(self, city: str) -> Self: return self weather = Weather(temperature=28.7, description='sunny') output_tools = None def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert 
info.output_tools is not None nonlocal output_tools output_tools = info.output_tools args_json = '{"city": "Mexico City"}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) agent = Agent(FunctionModel(call_tool), output_type=weather.get_weather) result = agent.run_sync('Mexico City') assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) assert output_tools == snapshot( [ ToolDefinition( name='final_result', description='The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, 'properties': {'city': {'type': 'string'}}, 'required': ['city'], 'type': 'object', }, kind='output', ) ] ) def test_output_type_bound_instance_method_with_run_context(): class Weather(BaseModel): temperature: float description: str def get_weather(self, ctx: RunContext[None], city: str) -> Self: assert ctx is not None return self weather = Weather(temperature=28.7, description='sunny') output_tools = None def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None nonlocal output_tools output_tools = info.output_tools args_json = '{"city": "Mexico City"}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) agent = Agent(FunctionModel(call_tool), output_type=weather.get_weather) result = agent.run_sync('Mexico City') assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) assert output_tools == snapshot( [ ToolDefinition( name='final_result', description='The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, 'properties': {'city': {'type': 'string'}}, 'required': ['city'], 'type': 'object', }, kind='output', ) ] ) def test_output_type_function_with_retry(): class Weather(BaseModel): temperature: float description: str def get_weather(city: str) -> Weather: if city != 'Mexico City': raise ModelRetry('City not found, I only know Mexico City') return Weather(temperature=28.7, description='sunny') def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None if len(messages) == 1: args_json = '{"city": "New York City"}' else: args_json = '{"city": "Mexico City"}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) agent = Agent(FunctionModel(call_tool), output_type=get_weather) result = agent.run_sync('New York City') assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='New York City', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ ToolCallPart( tool_name='final_result', args='{"city": "New York City"}', tool_call_id=IsStr(), ) ], usage=RequestUsage(input_tokens=53, output_tokens=7), model_name='function:call_tool:', timestamp=IsDatetime(), ), ModelRequest( parts=[ RetryPromptPart( content='City not found, I only know Mexico City', tool_name='final_result', tool_call_id=IsStr(), timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ ToolCallPart( tool_name='final_result', args='{"city": "Mexico City"}', tool_call_id=IsStr(), ) ], usage=RequestUsage(input_tokens=68, output_tokens=13), model_name='function:call_tool:', timestamp=IsDatetime(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='final_result', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), ) ] ), ] ) def test_output_type_text_output_function_with_retry(): class 
Weather(BaseModel): temperature: float description: str def get_weather(ctx: RunContext[None], city: str) -> Weather: assert ctx is not None if city != 'Mexico City': raise ModelRetry('City not found, I only know Mexico City') return Weather(temperature=28.7, description='sunny') def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None if len(messages) == 1: city = 'New York City' else: city = 'Mexico City' return ModelResponse(parts=[TextPart(content=city)]) agent = Agent(FunctionModel(call_tool), output_type=TextOutput(get_weather)) result = agent.run_sync('New York City') assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='New York City', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[TextPart(content='New York City')], usage=RequestUsage(input_tokens=53, output_tokens=3), model_name='function:call_tool:', timestamp=IsDatetime(), ), ModelRequest( parts=[ RetryPromptPart( content='City not found, I only know Mexico City', tool_call_id=IsStr(), timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[TextPart(content='Mexico City')], usage=RequestUsage(input_tokens=70, output_tokens=5), model_name='function:call_tool:', timestamp=IsDatetime(), ), ] ) @pytest.mark.parametrize( 'output_type', [[str, str], [str, TextOutput(upcase)], [TextOutput(upcase), TextOutput(upcase)]], ) def test_output_type_multiple_text_output(output_type: OutputSpec[str]): with pytest.raises(UserError, match='Only one `str` or `TextOutput` is allowed.'): Agent('test', output_type=output_type) def test_output_type_text_output_invalid(): def int_func(x: int) -> str: return str(int) # pragma: no cover with pytest.raises(UserError, match='TextOutput must take a function taking a single `str` argument'): output_type: TextOutput[str] = TextOutput(int_func) # type: ignore Agent('test', output_type=output_type) def test_output_type_async_function(): class Weather(BaseModel): temperature: float description: str async def get_weather(city: str) -> Weather: return Weather(temperature=28.7, description='sunny') output_tools = None def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None nonlocal output_tools output_tools = info.output_tools args_json = '{"city": "Mexico City"}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) agent = Agent(FunctionModel(call_tool), output_type=get_weather) result = agent.run_sync('Mexico City') assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) assert output_tools == snapshot( [ ToolDefinition( name='final_result', description='The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, 'properties': {'city': {'type': 'string'}}, 'required': ['city'], 'type': 'object', }, kind='output', ) ] ) def test_output_type_function_with_custom_tool_name(): class Weather(BaseModel): temperature: float description: str def get_weather(city: str) -> Weather: return Weather(temperature=28.7, description='sunny') output_tools = None def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None nonlocal output_tools output_tools = info.output_tools args_json = '{"city": "Mexico City"}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) agent = Agent(FunctionModel(call_tool), 
output_type=ToolOutput(get_weather, name='get_weather')) result = agent.run_sync('Mexico City') assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) assert output_tools == snapshot( [ ToolDefinition( name='get_weather', description='The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, 'properties': {'city': {'type': 'string'}}, 'required': ['city'], 'type': 'object', }, kind='output', ) ] ) def test_output_type_function_or_model(): class Weather(BaseModel): temperature: float description: str def get_weather(city: str) -> Weather: return Weather(temperature=28.7, description='sunny') output_tools = None def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None nonlocal output_tools output_tools = info.output_tools args_json = '{"city": "Mexico City"}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) agent = Agent(FunctionModel(call_tool), output_type=[get_weather, Weather]) result = agent.run_sync('Mexico City') assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) assert output_tools == snapshot( [ ToolDefinition( name='final_result_get_weather', description='get_weather: The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, 'properties': {'city': {'type': 'string'}}, 'required': ['city'], 'type': 'object', }, kind='output', ), ToolDefinition( name='final_result_Weather', description='Weather: The final response which ends this conversation', parameters_json_schema={ 'properties': {'temperature': {'type': 'number'}, 'description': {'type': 'string'}}, 'required': ['temperature', 'description'], 'title': 'Weather', 'type': 'object', }, kind='output', ), ] ) def test_output_type_text_output_function(): def say_world(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[TextPart(content='world')]) agent = Agent(FunctionModel(say_world), output_type=TextOutput(upcase)) result = agent.run_sync('hello') assert result.output == snapshot('WORLD') assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='hello', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[TextPart(content='world')], usage=RequestUsage(input_tokens=51, output_tokens=1), model_name='function:say_world:', timestamp=IsDatetime(), ), ] ) def test_output_type_handoff_to_agent(): class Weather(BaseModel): temperature: float description: str def get_weather(city: str) -> Weather: return Weather(temperature=28.7, description='sunny') def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None args_json = '{"city": "Mexico City"}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) agent = Agent(FunctionModel(call_tool), output_type=get_weather) handoff_result = None async def handoff(city: str) -> Weather: result = await agent.run(f'Get me the weather in {city}') nonlocal handoff_result handoff_result = result return result.output def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None args_json = '{"city": "Mexico City"}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) supervisor_agent = Agent(FunctionModel(call_handoff_tool), output_type=handoff) result = supervisor_agent.run_sync('Mexico City') assert result.output == 
snapshot(Weather(temperature=28.7, description='sunny')) assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='Mexico City', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ ToolCallPart( tool_name='final_result', args='{"city": "Mexico City"}', tool_call_id=IsStr(), ) ], usage=RequestUsage(input_tokens=52, output_tokens=6), model_name='function:call_handoff_tool:', timestamp=IsDatetime(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='final_result', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), ) ] ), ] ) assert handoff_result is not None assert handoff_result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='Get me the weather in Mexico City', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ ToolCallPart( tool_name='final_result', args='{"city": "Mexico City"}', tool_call_id=IsStr(), ) ], usage=RequestUsage(input_tokens=57, output_tokens=6), model_name='function:call_tool:', timestamp=IsDatetime(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='final_result', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), ) ] ), ] ) def test_output_type_multiple_custom_tools(): class Weather(BaseModel): temperature: float description: str def get_weather(city: str) -> Weather: return Weather(temperature=28.7, description='sunny') output_tools = None def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None nonlocal output_tools output_tools = info.output_tools args_json = '{"city": "Mexico City"}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) agent = Agent( FunctionModel(call_tool), output_type=[ ToolOutput(get_weather, name='get_weather'), ToolOutput(Weather, name='return_weather'), ], ) result = agent.run_sync('Mexico City') assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) assert output_tools == snapshot( [ ToolDefinition( name='get_weather', description='get_weather: The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, 'properties': {'city': {'type': 'string'}}, 'required': ['city'], 'type': 'object', }, kind='output', ), ToolDefinition( name='return_weather', description='Weather: The final response which ends this conversation', parameters_json_schema={ 'properties': {'temperature': {'type': 'number'}, 'description': {'type': 'string'}}, 'required': ['temperature', 'description'], 'title': 'Weather', 'type': 'object', }, kind='output', ), ] ) def test_output_type_structured_dict(): PersonDict = StructuredDict( { 'type': 'object', 'properties': { 'name': {'type': 'string'}, 'age': {'type': 'integer'}, }, 'required': ['name', 'age'], }, name='Person', description='A person', ) AnimalDict = StructuredDict( { 'type': 'object', 'properties': { 'name': {'type': 'string'}, 'species': {'type': 'string'}, }, 'required': ['name', 'species'], }, name='Animal', description='An animal', ) output_tools = None def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None nonlocal output_tools output_tools = info.output_tools args_json = '{"name": "John Doe", "age": 30}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) agent = Agent( FunctionModel(call_tool), output_type=[PersonDict, AnimalDict], ) result = agent.run_sync('Generate a person') assert result.output == snapshot({'name': 'John Doe', 'age': 30}) assert 
output_tools == snapshot( [ ToolDefinition( name='final_result_Person', parameters_json_schema={ 'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}}, 'required': ['name', 'age'], 'title': 'Person', 'type': 'object', }, description='A person', kind='output', ), ToolDefinition( name='final_result_Animal', parameters_json_schema={ 'properties': {'name': {'type': 'string'}, 'species': {'type': 'string'}}, 'required': ['name', 'species'], 'title': 'Animal', 'type': 'object', }, description='An animal', kind='output', ), ] ) def test_output_type_structured_dict_nested(): """Test StructuredDict with nested JSON schemas using $ref - Issue #2466.""" # Schema with nested $ref that pydantic's generator can't resolve CarDict = StructuredDict( { '$defs': { 'Tire': { 'type': 'object', 'properties': {'brand': {'type': 'string'}, 'size': {'type': 'integer'}}, 'required': ['brand', 'size'], } }, 'type': 'object', 'properties': { 'make': {'type': 'string'}, 'model': {'type': 'string'}, 'tires': {'type': 'array', 'items': {'$ref': '#/$defs/Tire'}}, }, 'required': ['make', 'model', 'tires'], }, name='Car', description='A car with tires', ) def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None # Verify the output tool schema has been properly transformed # The $refs should be inlined by InlineDefsJsonSchemaTransformer output_tool = info.output_tools[0] schema = output_tool.parameters_json_schema assert schema is not None assert schema == snapshot( { 'properties': { 'make': {'type': 'string'}, 'model': {'type': 'string'}, 'tires': { 'items': { 'properties': {'brand': {'type': 'string'}, 'size': {'type': 'integer'}}, 'required': ['brand', 'size'], 'type': 'object', }, 'type': 'array', }, }, 'required': ['make', 'model', 'tires'], 'title': 'Car', 'type': 'object', } ) return ModelResponse( parts=[ ToolCallPart( output_tool.name, {'make': 'Toyota', 'model': 'Camry', 'tires': [{'brand': 'Michelin', 'size': 17}]} ) ] ) agent = Agent(FunctionModel(call_tool), output_type=CarDict) result = agent.run_sync('Generate a car') assert result.output == snapshot({'make': 'Toyota', 'model': 'Camry', 'tires': [{'brand': 'Michelin', 'size': 17}]}) def test_structured_dict_recursive_refs(): class Node(BaseModel): nodes: list['Node'] | dict[str, 'Node'] schema = Node.model_json_schema() assert schema == snapshot( { '$defs': { 'Node': { 'properties': { 'nodes': { 'anyOf': [ {'items': {'$ref': '#/$defs/Node'}, 'type': 'array'}, {'additionalProperties': {'$ref': '#/$defs/Node'}, 'type': 'object'}, ], 'title': 'Nodes', } }, 'required': ['nodes'], 'title': 'Node', 'type': 'object', } }, '$ref': '#/$defs/Node', } ) with pytest.raises( UserError, match=re.escape( '`StructuredDict` does not currently support recursive `$ref`s and `$defs`. See https://github.com/pydantic/pydantic/issues/12145 for more information.' 
), ): StructuredDict(schema) def test_default_structured_output_mode(): def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[TextPart(content='hello')]) # pragma: no cover tool_model = FunctionModel(hello, profile=ModelProfile(default_structured_output_mode='tool')) native_model = FunctionModel( hello, profile=ModelProfile(supports_json_schema_output=True, default_structured_output_mode='native'), ) prompted_model = FunctionModel( hello, profile=ModelProfile(default_structured_output_mode='prompted'), ) class Foo(BaseModel): bar: str tool_agent = Agent(tool_model, output_type=Foo) assert tool_agent._output_schema.mode == 'tool' # type: ignore[reportPrivateUsage] native_agent = Agent(native_model, output_type=Foo) assert native_agent._output_schema.mode == 'native' # type: ignore[reportPrivateUsage] prompted_agent = Agent(prompted_model, output_type=Foo) assert prompted_agent._output_schema.mode == 'prompted' # type: ignore[reportPrivateUsage] def test_prompted_output(): def return_city_location(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: text = CityLocation(city='Mexico City', country='Mexico').model_dump_json() return ModelResponse(parts=[TextPart(content=text)]) m = FunctionModel(return_city_location) class CityLocation(BaseModel): """Description from docstring.""" city: str country: str agent = Agent( m, output_type=PromptedOutput(CityLocation, name='City & Country', description='Description from PromptedOutput'), ) result = agent.run_sync('What is the capital of Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='What is the capital of Mexico?', timestamp=IsDatetime(), ) ], instructions="""\ Always respond with a JSON object that's compatible with this schema: {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "City & Country", "type": "object", "description": "Description from PromptedOutput. 
Description from docstring."} Don't include any text or Markdown fencing before or after.\ """, ), ModelResponse( parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], usage=RequestUsage(input_tokens=56, output_tokens=7), model_name='function:return_city_location:', timestamp=IsDatetime(), ), ] ) def test_prompted_output_with_template(): def return_foo(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: text = Foo(bar='baz').model_dump_json() return ModelResponse(parts=[TextPart(content=text)]) m = FunctionModel(return_foo) class Foo(BaseModel): bar: str agent = Agent(m, output_type=PromptedOutput(Foo, template='Gimme some JSON:')) result = agent.run_sync('What is the capital of Mexico?') assert result.output == snapshot(Foo(bar='baz')) assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='What is the capital of Mexico?', timestamp=IsDatetime(), ) ], instructions="""\ Gimme some JSON: {"properties": {"bar": {"type": "string"}}, "required": ["bar"], "title": "Foo", "type": "object"}\ """, ), ModelResponse( parts=[TextPart(content='{"bar":"baz"}')], usage=RequestUsage(input_tokens=56, output_tokens=4), model_name='function:return_foo:', timestamp=IsDatetime(), ), ] ) def test_prompted_output_with_defs(): class Foo(BaseModel): """Foo description""" foo: str class Bar(BaseModel): """Bar description""" bar: str class Baz(BaseModel): """Baz description""" baz: str class FooBar(BaseModel): """FooBar description""" foo: Foo bar: Bar class FooBaz(BaseModel): """FooBaz description""" foo: Foo baz: Baz def return_foo_bar(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: text = '{"result": {"kind": "FooBar", "data": {"foo": {"foo": "foo"}, "bar": {"bar": "bar"}}}}' return ModelResponse(parts=[TextPart(content=text)]) m = FunctionModel(return_foo_bar) agent = Agent( m, output_type=PromptedOutput( [FooBar, FooBaz], name='FooBar or FooBaz', description='FooBar or FooBaz description' ), ) result = agent.run_sync('What is foo?') assert result.output == snapshot(FooBar(foo=Foo(foo='foo'), bar=Bar(bar='bar'))) assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='What is foo?', timestamp=IsDatetime(), ) ], instructions="""\ Always respond with a JSON object that's compatible with this schema: {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "FooBar"}, "data": {"properties": {"foo": {"$ref": "#/$defs/Foo"}, "bar": {"$ref": "#/$defs/Bar"}}, "required": ["foo", "bar"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "FooBar", "description": "FooBar description"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "FooBaz"}, "data": {"properties": {"foo": {"$ref": "#/$defs/Foo"}, "baz": {"$ref": "#/$defs/Baz"}}, "required": ["foo", "baz"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "FooBaz", "description": "FooBaz description"}]}}, "required": ["result"], "additionalProperties": false, "$defs": {"Bar": {"description": "Bar description", "properties": {"bar": {"type": "string"}}, "required": ["bar"], "title": "Bar", "type": "object"}, "Foo": {"description": "Foo description", "properties": {"foo": {"type": "string"}}, "required": ["foo"], "title": "Foo", "type": "object"}, "Baz": {"description": "Baz description", "properties": {"baz": {"type": "string"}}, "required": ["baz"], "title": "Baz", "type": "object"}}, "title": 
"FooBar or FooBaz", "description": "FooBaz description"} Don't include any text or Markdown fencing before or after.\ """, ), ModelResponse( parts=[ TextPart( content='{"result": {"kind": "FooBar", "data": {"foo": {"foo": "foo"}, "bar": {"bar": "bar"}}}}' ) ], usage=RequestUsage(input_tokens=53, output_tokens=17), model_name='function:return_foo_bar:', timestamp=IsDatetime(), ), ] ) def test_native_output(): def return_city_location(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: if len(messages) == 1: text = '{"city": "Mexico City"}' else: text = '{"city": "Mexico City", "country": "Mexico"}' return ModelResponse(parts=[TextPart(content=text)]) m = FunctionModel(return_city_location) class CityLocation(BaseModel): city: str country: str agent = Agent( m, output_type=NativeOutput(CityLocation), ) result = agent.run_sync('What is the capital of Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='What is the capital of Mexico?', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City"}')], usage=RequestUsage(input_tokens=56, output_tokens=5), model_name='function:return_city_location:', timestamp=IsDatetime(), ), ModelRequest( parts=[ RetryPromptPart( content=[ { 'type': 'missing', 'loc': ('country',), 'msg': 'Field required', 'input': {'city': 'Mexico City'}, } ], tool_call_id=IsStr(), timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], usage=RequestUsage(input_tokens=85, output_tokens=12), model_name='function:return_city_location:', timestamp=IsDatetime(), ), ] ) def test_native_output_strict_mode(): class CityLocation(BaseModel): city: str country: str agent = Agent(output_type=NativeOutput(CityLocation, strict=True)) output_schema = agent._output_schema # pyright: ignore[reportPrivateUsage] assert isinstance(output_schema, NativeOutputSchema) assert output_schema.object_def.strict def test_prompted_output_function_with_retry(): class Weather(BaseModel): temperature: float description: str def get_weather(city: str) -> Weather: if city != 'Mexico City': raise ModelRetry('City not found, I only know Mexico City') return Weather(temperature=28.7, description='sunny') def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None if len(messages) == 1: args_json = '{"city": "New York City"}' else: args_json = '{"city": "Mexico City"}' return ModelResponse(parts=[TextPart(content=args_json)]) agent = Agent(FunctionModel(call_tool), output_type=PromptedOutput(get_weather)) result = agent.run_sync('New York City') assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='New York City', timestamp=IsDatetime(), ) ], instructions="""\ Always respond with a JSON object that's compatible with this schema: {"additionalProperties": false, "properties": {"city": {"type": "string"}}, "required": ["city"], "type": "object", "title": "get_weather"} Don't include any text or Markdown fencing before or after.\ """, ), ModelResponse( parts=[TextPart(content='{"city": "New York City"}')], usage=RequestUsage(input_tokens=53, output_tokens=6), model_name='function:call_tool:', timestamp=IsDatetime(), ), ModelRequest( parts=[ RetryPromptPart( content='City not found, I only know Mexico City', 
tool_call_id=IsStr(), timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City"}')], usage=RequestUsage(input_tokens=70, output_tokens=11), model_name='function:call_tool:', timestamp=IsDatetime(), ), ] ) def test_run_with_history_new(): m = TestModel() agent = Agent(m, system_prompt='Foobar') @agent.tool_plain async def ret_a(x: str) -> str: return f'{x}-apple' result1 = agent.run_sync('Hello') assert result1.new_messages() == snapshot( [ ModelRequest( parts=[ SystemPromptPart(content='Foobar', timestamp=IsNow(tz=timezone.utc)), UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), ] ), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], usage=RequestUsage(input_tokens=52, output_tokens=5), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) ) ] ), ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], usage=RequestUsage(input_tokens=53, output_tokens=9), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ] ) # if we pass new_messages, system prompt is inserted before the message_history messages result2 = agent.run_sync('Hello again', message_history=result1.new_messages()) assert result2.all_messages() == snapshot( [ ModelRequest( parts=[ SystemPromptPart(content='Foobar', timestamp=IsNow(tz=timezone.utc)), UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), ] ), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], usage=RequestUsage(input_tokens=52, output_tokens=5), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) ) ] ), ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], usage=RequestUsage(input_tokens=53, output_tokens=9), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], usage=RequestUsage(input_tokens=55, output_tokens=13), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ] ) assert result2.new_messages() == result2.all_messages()[-2:] assert result2.output == snapshot('{"ret_a":"a-apple"}') assert result2._output_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage] assert result2.usage() == snapshot(RunUsage(requests=1, input_tokens=55, output_tokens=13)) new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] assert new_msg_part_kinds == snapshot( [ ('request', ['system-prompt', 'user-prompt']), ('response', ['tool-call']), ('request', ['tool-return']), ('response', ['text']), ('request', ['user-prompt']), ('response', ['text']), ] ) assert result2.new_messages_json().startswith(b'[{"parts":[{"content":"Hello again",') # if we pass all_messages, system prompt is NOT inserted before the message_history messages, # so only one system prompt result3 = agent.run_sync('Hello again', message_history=result1.all_messages()) # same as result2 except for datetimes assert result3.all_messages() == snapshot( [ ModelRequest( parts=[ SystemPromptPart(content='Foobar', timestamp=IsNow(tz=timezone.utc)), UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), ] ), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', 
args={'x': 'a'}, tool_call_id=IsStr())], usage=RequestUsage(input_tokens=52, output_tokens=5), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) ) ] ), ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], usage=RequestUsage(input_tokens=53, output_tokens=9), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='{"ret_a":"a-apple"}')], usage=RequestUsage(input_tokens=55, output_tokens=13), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ] ) assert result3.new_messages() == result3.all_messages()[-2:] assert result3.output == snapshot('{"ret_a":"a-apple"}') assert result3._output_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage] assert result3.usage() == snapshot(RunUsage(requests=1, input_tokens=55, output_tokens=13)) assert result3.timestamp() == IsNow(tz=timezone.utc) def test_run_with_history_new_structured(): m = TestModel() class Response(BaseModel): a: int agent = Agent(m, system_prompt='Foobar', output_type=Response) @agent.tool_plain async def ret_a(x: str) -> str: return f'{x}-apple' result1 = agent.run_sync('Hello') assert result1.new_messages() == snapshot( [ ModelRequest( parts=[ SystemPromptPart(content='Foobar', timestamp=IsNow(tz=timezone.utc)), UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), ] ), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], usage=RequestUsage(input_tokens=52, output_tokens=5), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) ) ] ), ModelResponse( parts=[ ToolCallPart( tool_name='final_result', args={'a': 0}, tool_call_id=IsStr(), ) ], usage=RequestUsage(input_tokens=53, output_tokens=9), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='final_result', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ) ] ), ] ) result2 = agent.run_sync('Hello again', message_history=result1.new_messages()) assert result2.all_messages() == snapshot( [ ModelRequest( parts=[ SystemPromptPart(content='Foobar', timestamp=IsNow(tz=timezone.utc)), UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), ], ), ModelResponse( parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], usage=RequestUsage(input_tokens=52, output_tokens=5), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='ret_a', content='a-apple', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) ) ], ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args={'a': 0}, tool_call_id=IsStr())], usage=RequestUsage(input_tokens=53, output_tokens=9), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='final_result', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ], ), # second call, notice no repeated system prompt ModelRequest( parts=[ UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc)), ], ), ModelResponse( parts=[ToolCallPart(tool_name='final_result', args={'a': 0}, tool_call_id=IsStr())], 
usage=RequestUsage(input_tokens=59, output_tokens=13), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='final_result', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ] ), ] ) assert result2.output == snapshot(Response(a=0)) assert result2.new_messages() == result2.all_messages()[-3:] assert result2._output_tool_name == snapshot('final_result') # pyright: ignore[reportPrivateUsage] assert result2.usage() == snapshot(RunUsage(requests=1, input_tokens=59, output_tokens=13)) new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] assert new_msg_part_kinds == snapshot( [ ('request', ['system-prompt', 'user-prompt']), ('response', ['tool-call']), ('request', ['tool-return']), ('response', ['tool-call']), ('request', ['tool-return']), ('request', ['user-prompt']), ('response', ['tool-call']), ('request', ['tool-return']), ] ) assert result2.new_messages_json().startswith(b'[{"parts":[{"content":"Hello again",') def test_run_with_history_ending_on_model_request_and_no_user_prompt(): m = TestModel() agent = Agent(m) @agent.system_prompt(dynamic=True) async def system_prompt(ctx: RunContext) -> str: return f'System prompt: user prompt length = {len(ctx.prompt or [])}' messages: list[ModelMessage] = [ ModelRequest( parts=[ SystemPromptPart(content='System prompt', dynamic_ref=system_prompt.__qualname__), UserPromptPart(content=['Hello', ImageUrl('https://example.com/image.jpg')]), UserPromptPart(content='How goes it?'), ], instructions='Original instructions', ), ] @agent.instructions async def instructions(ctx: RunContext) -> str: assert ctx.prompt == ['Hello', ImageUrl('https://example.com/image.jpg'), 'How goes it?'] return 'New instructions' result = agent.run_sync(message_history=messages) assert result.all_messages() == snapshot( [ ModelRequest( parts=[ SystemPromptPart( content='System prompt: user prompt length = 3', timestamp=IsDatetime(), dynamic_ref=IsStr(), ), UserPromptPart( content=['Hello', ImageUrl(url='https://example.com/image.jpg', identifier='39cfc4')], timestamp=IsDatetime(), ), UserPromptPart( content='How goes it?', timestamp=IsDatetime(), ), ], instructions='New instructions', ), ModelResponse( parts=[TextPart(content='success (no tool calls)')], usage=RequestUsage(input_tokens=61, output_tokens=4), model_name='test', timestamp=IsDatetime(), ), ] ) assert result.new_messages() == result.all_messages()[-1:] def test_run_with_history_ending_on_model_response_with_tool_calls_and_no_user_prompt(): """Test that an agent run with message_history ending on ModelResponse starts with CallToolsNode.""" def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[TextPart(content='Final response')]) agent = Agent(FunctionModel(simple_response)) @agent.tool_plain def test_tool() -> str: return 'Test response' message_history = [ ModelRequest(parts=[UserPromptPart(content='Hello')]), ModelResponse(parts=[ToolCallPart(tool_name='test_tool', args='{}', tool_call_id='call_123')]), ] result = agent.run_sync(message_history=message_history) assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='Hello', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ToolCallPart(tool_name='test_tool', args='{}', tool_call_id='call_123')], timestamp=IsDatetime(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='test_tool', content='Test response', tool_call_id='call_123', 
timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[TextPart(content='Final response')], usage=RequestUsage(input_tokens=53, output_tokens=4), model_name='function:simple_response:', timestamp=IsDatetime(), ), ] ) assert result.new_messages() == result.all_messages()[-2:] def test_run_with_history_ending_on_model_response_with_tool_calls_and_user_prompt(): """Test that an agent run raises error when message_history ends on ModelResponse with tool calls and there's a new prompt.""" def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[TextPart(content='Final response')]) # pragma: no cover agent = Agent(FunctionModel(simple_response)) message_history = [ ModelRequest(parts=[UserPromptPart(content='Hello')]), ModelResponse(parts=[ToolCallPart(tool_name='test_tool', args='{}', tool_call_id='call_123')]), ] with pytest.raises( UserError, match='Cannot provide a new user prompt when the message history contains unprocessed tool calls.', ): agent.run_sync(user_prompt='New question', message_history=message_history) def test_run_with_history_ending_on_model_response_without_tool_calls_or_user_prompt(): """Test that an agent run raises error when message_history ends on ModelResponse without tool calls or a new prompt.""" def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[TextPart(content='Final response')]) # pragma: no cover agent = Agent(FunctionModel(simple_response)) message_history = [ ModelRequest(parts=[UserPromptPart(content='Hello')]), ModelResponse(parts=[TextPart('world')]), ] result = agent.run_sync(message_history=message_history) assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='Hello', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[TextPart(content='world')], timestamp=IsDatetime(), ), ] ) assert result.new_messages() == [] def test_empty_response(): def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: if len(messages) == 1: return ModelResponse(parts=[]) else: return ModelResponse(parts=[TextPart('ok here is text')]) agent = Agent(FunctionModel(llm)) result = agent.run_sync('Hello') assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='Hello', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[], usage=RequestUsage(input_tokens=51), model_name='function:llm:', timestamp=IsDatetime(), ), ModelRequest(parts=[]), ModelResponse( parts=[TextPart(content='ok here is text')], usage=RequestUsage(input_tokens=51, output_tokens=4), model_name='function:llm:', timestamp=IsDatetime(), ), ] ) def test_empty_response_without_recovery(): def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[]) agent = Agent(FunctionModel(llm), output_type=tuple[str, int]) with capture_run_messages() as messages: with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for output validation'): agent.run_sync('Hello') assert messages == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='Hello', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[], usage=RequestUsage(input_tokens=51), model_name='function:llm:', timestamp=IsDatetime(), ), ModelRequest(parts=[]), ModelResponse( parts=[], usage=RequestUsage(input_tokens=51), model_name='function:llm:', timestamp=IsDatetime(), ), ] ) def test_unknown_tool(): def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: return 
ModelResponse(parts=[ToolCallPart('foobar', '{}')]) agent = Agent(FunctionModel(empty)) with capture_run_messages() as messages: with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for output validation'): agent.run_sync('Hello') assert messages == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())], usage=RequestUsage(input_tokens=51, output_tokens=2), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ RetryPromptPart( tool_name='foobar', content="Unknown tool name: 'foobar'. No tools available.", tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ) ] ), ModelResponse( parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())], usage=RequestUsage(input_tokens=65, output_tokens=4), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), ), ] ) def test_unknown_tool_fix(): def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse: if len(m) > 1: return ModelResponse(parts=[TextPart('success')]) else: return ModelResponse(parts=[ToolCallPart('foobar', '{}')]) agent = Agent(FunctionModel(empty)) result = agent.run_sync('Hello') assert result.output == 'success' assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='foobar', args='{}', tool_call_id=IsStr())], usage=RequestUsage(input_tokens=51, output_tokens=2), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ RetryPromptPart( tool_name='foobar', content="Unknown tool name: 'foobar'. No tools available.", tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ) ] ), ModelResponse( parts=[TextPart(content='success')], usage=RequestUsage(input_tokens=65, output_tokens=3), model_name='function:empty:', timestamp=IsNow(tz=timezone.utc), ), ] ) def test_tool_exceeds_token_limit_error(): def return_incomplete_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: resp = ModelResponse(parts=[ToolCallPart('dummy_tool', args='{"foo": "bar",')]) resp.finish_reason = 'length' return resp agent = Agent(FunctionModel(return_incomplete_tool), output_type=str) with pytest.raises( IncompleteToolCall, match=r'Model token limit \(10\) exceeded while emitting a tool call, resulting in incomplete arguments. Increase max tokens or simplify tool call arguments to fit within limit.', ): agent.run_sync('Hello', model_settings=ModelSettings(max_tokens=10)) with pytest.raises( IncompleteToolCall, match=r'Model token limit \(provider default\) exceeded while emitting a tool call, resulting in incomplete arguments. 
Increase max tokens or simplify tool call arguments to fit within limit.', ): agent.run_sync('Hello') def test_tool_exceeds_token_limit_but_complete_args(): def return_complete_tool_but_hit_limit(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: resp = ModelResponse(parts=[ToolCallPart('dummy_tool', args='{"foo": "bar"}')]) resp.finish_reason = 'length' return resp return ModelResponse(parts=[TextPart('done')]) agent = Agent(FunctionModel(return_complete_tool_but_hit_limit), output_type=str) @agent.tool_plain def dummy_tool(foo: str) -> str: return 'tool-ok' result = agent.run_sync('Hello') assert result.output == 'done' def test_model_requests_blocked(env: TestEnv): try: env.set('GEMINI_API_KEY', 'foobar') agent = Agent('google-gla:gemini-1.5-flash', output_type=tuple[str, str], defer_model_check=True) with pytest.raises(RuntimeError, match='Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'): agent.run_sync('Hello') except ImportError: # pragma: lax no cover pytest.skip('google-genai not installed') def test_override_model(env: TestEnv): env.set('GEMINI_API_KEY', 'foobar') agent = Agent('google-gla:gemini-1.5-flash', output_type=tuple[int, str], defer_model_check=True) with agent.override(model='test'): result = agent.run_sync('Hello') assert result.output == snapshot((0, 'a')) def test_set_model(env: TestEnv): env.set('GEMINI_API_KEY', 'foobar') agent = Agent(output_type=tuple[int, str]) agent.model = 'test' result = agent.run_sync('Hello') assert result.output == snapshot((0, 'a')) def test_override_model_no_model(): agent = Agent() with pytest.raises(UserError, match=r'`model` must either be set.+Even when `override\(model=...\)` is customiz'): with agent.override(model='test'): agent.run_sync('Hello') def test_run_sync_multiple(): agent = Agent('test') @agent.tool_plain async def make_request() -> str: async with httpx.AsyncClient() as client: # use this as I suspect it's about the fastest globally available endpoint try: response = await client.get('https://cloudflare.com/cdn-cgi/trace') except httpx.ConnectError: # pragma: no cover pytest.skip('offline') else: return str(response.status_code) for _ in range(2): result = agent.run_sync('Hello') assert result.output == '{"make_request":"200"}' async def test_agent_name(): my_agent = Agent('test') assert my_agent.name is None await my_agent.run('Hello', infer_name=False) assert my_agent.name is None await my_agent.run('Hello') assert my_agent.name == 'my_agent' async def test_agent_name_already_set(): my_agent = Agent('test', name='fig_tree') assert my_agent.name == 'fig_tree' await my_agent.run('Hello') assert my_agent.name == 'fig_tree' async def test_agent_name_changes(): my_agent = Agent('test') await my_agent.run('Hello') assert my_agent.name == 'my_agent' new_agent = my_agent del my_agent await new_agent.run('Hello') assert new_agent.name == 'my_agent' def test_agent_name_override(): agent = Agent('test', name='custom_name') with agent.override(name='overridden_name'): agent.run_sync('Hello') assert agent.name == 'overridden_name' def test_name_from_global(create_module: Callable[[str], Any]): module_code = """ from pydantic_ai import Agent my_agent = Agent('test') def foo(): result = my_agent.run_sync('Hello') return result.output """ mod = create_module(module_code) assert mod.my_agent.name is None assert mod.foo() == snapshot('success (no tool calls)') assert mod.my_agent.name == 'my_agent' class TestMultipleToolCalls: """Tests for scenarios where multiple tool calls are 
made in a single response.""" class OutputType(BaseModel): """Result type used by all tests.""" value: str def test_early_strategy_stops_after_first_final_result(self): """Test that 'early' strategy stops processing regular tools after first final result.""" tool_called = [] def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None return ModelResponse( parts=[ ToolCallPart('final_result', {'value': 'final'}), ToolCallPart('regular_tool', {'x': 1}), ToolCallPart('another_tool', {'y': 2}), ToolCallPart('deferred_tool', {'x': 3}), ], ) agent = Agent(FunctionModel(return_model), output_type=self.OutputType, end_strategy='early') @agent.tool_plain def regular_tool(x: int) -> int: # pragma: no cover """A regular tool that should not be called.""" tool_called.append('regular_tool') return x @agent.tool_plain def another_tool(y: int) -> int: # pragma: no cover """Another tool that should not be called.""" tool_called.append('another_tool') return y async def defer(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition | None: return replace(tool_def, kind='external') @agent.tool_plain(prepare=defer) def deferred_tool(x: int) -> int: # pragma: no cover return x + 1 result = agent.run_sync('test early strategy') messages = result.all_messages() # Verify no tools were called after final result assert tool_called == [] # Verify we got tool returns for all calls assert messages[-1].parts == snapshot( [ ToolReturnPart( tool_name='final_result', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( tool_name='regular_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( tool_name='another_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( tool_name='deferred_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ] ) def test_early_strategy_uses_first_final_result(self): """Test that 'early' strategy uses the first final result and ignores subsequent ones.""" def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None return ModelResponse( parts=[ ToolCallPart('final_result', {'value': 'first'}), ToolCallPart('final_result', {'value': 'second'}), ], ) agent = Agent(FunctionModel(return_model), output_type=self.OutputType, end_strategy='early') result = agent.run_sync('test multiple final results') # Verify the result came from the first final tool assert result.output.value == 'first' # Verify we got appropriate tool returns assert result.new_messages()[-1].parts == snapshot( [ ToolReturnPart( tool_name='final_result', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( tool_name='final_result', content='Output tool not used - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ] ) def test_exhaustive_strategy_executes_all_tools(self): """Test that 'exhaustive' strategy executes all tools while using first final result.""" tool_called: list[str] = [] def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None return ModelResponse( parts=[ ToolCallPart('regular_tool', {'x': 42}), ToolCallPart('final_result', 
{'value': 'first'}), ToolCallPart('another_tool', {'y': 2}), ToolCallPart('final_result', {'value': 'second'}), ToolCallPart('unknown_tool', {'value': '???'}), ToolCallPart('deferred_tool', {'x': 4}), ], ) agent = Agent(FunctionModel(return_model), output_type=self.OutputType, end_strategy='exhaustive') @agent.tool_plain def regular_tool(x: int) -> int: """A regular tool that should be called.""" tool_called.append('regular_tool') return x @agent.tool_plain def another_tool(y: int) -> int: """Another tool that should be called.""" tool_called.append('another_tool') return y async def defer(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition | None: return replace(tool_def, kind='external') @agent.tool_plain(prepare=defer) def deferred_tool(x: int) -> int: # pragma: no cover return x + 1 result = agent.run_sync('test exhaustive strategy') # Verify the result came from the first final tool assert result.output.value == 'first' # Verify all regular tools were called assert sorted(tool_called) == sorted(['regular_tool', 'another_tool']) # Verify we got tool returns in the correct order assert result.all_messages() == snapshot( [ ModelRequest( parts=[UserPromptPart(content='test exhaustive strategy', timestamp=IsNow(tz=timezone.utc))] ), ModelResponse( parts=[ ToolCallPart(tool_name='regular_tool', args={'x': 42}, tool_call_id=IsStr()), ToolCallPart(tool_name='final_result', args={'value': 'first'}, tool_call_id=IsStr()), ToolCallPart(tool_name='another_tool', args={'y': 2}, tool_call_id=IsStr()), ToolCallPart(tool_name='final_result', args={'value': 'second'}, tool_call_id=IsStr()), ToolCallPart(tool_name='unknown_tool', args={'value': '???'}, tool_call_id=IsStr()), ToolCallPart( tool_name='deferred_tool', args={'x': 4}, tool_call_id=IsStr(), ), ], usage=RequestUsage(input_tokens=53, output_tokens=27), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='final_result', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( tool_name='final_result', content='Output tool not used - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( tool_name='regular_tool', content=42, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( tool_name='another_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) ), RetryPromptPart( content="Unknown tool name: 'unknown_tool'. 
Available tools: 'final_result', 'regular_tool', 'another_tool', 'deferred_tool'", tool_name='unknown_tool', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( tool_name='deferred_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ] ), ] ) def test_early_strategy_with_final_result_in_middle(self): """Test that 'early' strategy stops at first final result, regardless of position.""" tool_called = [] def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None return ModelResponse( parts=[ ToolCallPart('regular_tool', {'x': 1}), ToolCallPart('final_result', {'value': 'final'}), ToolCallPart('another_tool', {'y': 2}), ToolCallPart('unknown_tool', {'value': '???'}), ToolCallPart('deferred_tool', {'x': 5}), ], ) agent = Agent(FunctionModel(return_model), output_type=self.OutputType, end_strategy='early') @agent.tool_plain def regular_tool(x: int) -> int: # pragma: no cover """A regular tool that should not be called.""" tool_called.append('regular_tool') return x @agent.tool_plain def another_tool(y: int) -> int: # pragma: no cover """A tool that should not be called.""" tool_called.append('another_tool') return y async def defer(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition | None: return replace(tool_def, kind='external') @agent.tool_plain(prepare=defer) def deferred_tool(x: int) -> int: # pragma: no cover return x + 1 result = agent.run_sync('test early strategy with final result in middle') # Verify no tools were called assert tool_called == [] # Verify we got appropriate tool returns assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='test early strategy with final result in middle', timestamp=IsNow(tz=timezone.utc) ) ] ), ModelResponse( parts=[ ToolCallPart(tool_name='regular_tool', args={'x': 1}, tool_call_id=IsStr()), ToolCallPart(tool_name='final_result', args={'value': 'final'}, tool_call_id=IsStr()), ToolCallPart(tool_name='another_tool', args={'y': 2}, tool_call_id=IsStr()), ToolCallPart(tool_name='unknown_tool', args={'value': '???'}, tool_call_id=IsStr()), ToolCallPart( tool_name='deferred_tool', args={'x': 5}, tool_call_id=IsStr(), ), ], usage=RequestUsage(input_tokens=58, output_tokens=22), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='final_result', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( tool_name='regular_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), ), ToolReturnPart( tool_name='another_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), RetryPromptPart( content="Unknown tool name: 'unknown_tool'. 
Available tools: 'final_result', 'regular_tool', 'another_tool', 'deferred_tool'", tool_name='unknown_tool', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( tool_name='deferred_tool', content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ] ), ] ) def test_early_strategy_does_not_apply_to_tool_calls_without_final_tool(self): """Test that 'early' strategy does not apply to tool calls without final tool.""" tool_called = [] agent = Agent(TestModel(), output_type=self.OutputType, end_strategy='early') @agent.tool_plain def regular_tool(x: int) -> int: """A regular tool that should be called.""" tool_called.append('regular_tool') return x result = agent.run_sync('test early strategy with regular tool calls') assert tool_called == ['regular_tool'] tool_returns = [m for m in result.all_messages() if isinstance(m, ToolReturnPart)] assert tool_returns == snapshot([]) def test_multiple_final_result_are_validated_correctly(self): """Tests that if multiple final results are returned, but one fails validation, the other is used.""" def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None return ModelResponse( parts=[ ToolCallPart('final_result', {'bad_value': 'first'}, tool_call_id='first'), ToolCallPart('final_result', {'value': 'second'}, tool_call_id='second'), ], ) agent = Agent(FunctionModel(return_model), output_type=self.OutputType, end_strategy='early') result = agent.run_sync('test multiple final results') # Verify the result came from the second final tool assert result.output.value == 'second' # Verify we got appropriate tool returns assert result.new_messages()[-1].parts == snapshot( [ RetryPromptPart( content=[ {'type': 'missing', 'loc': ('value',), 'msg': 'Field required', 'input': {'bad_value': 'first'}} ], tool_name='final_result', tool_call_id='first', timestamp=IsDatetime(), ), ToolReturnPart( tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc), tool_call_id='second', ), ] ) async def test_model_settings_override() -> None: def return_settings(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[TextPart(to_json(info.model_settings).decode())]) my_agent = Agent(FunctionModel(return_settings)) assert (await my_agent.run('Hello')).output == IsJson(None) assert (await my_agent.run('Hello', model_settings={'temperature': 0.5})).output == IsJson({'temperature': 0.5}) my_agent = Agent(FunctionModel(return_settings), model_settings={'temperature': 0.1}) assert (await my_agent.run('Hello')).output == IsJson({'temperature': 0.1}) assert (await my_agent.run('Hello', model_settings={'temperature': 0.5})).output == IsJson({'temperature': 0.5}) async def test_empty_text_part(): def return_empty_text(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None args_json = '{"response": ["foo", "bar"]}' return ModelResponse( parts=[ TextPart(''), ToolCallPart(info.output_tools[0].name, args_json), ], ) agent = Agent(FunctionModel(return_empty_text), output_type=tuple[str, str]) result = await agent.run('Hello') assert result.output == ('foo', 'bar') def test_heterogeneous_responses_non_streaming() -> None: """Indicates that tool calls are prioritized over text in heterogeneous responses.""" def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert info.output_tools is not None parts: list[ModelResponsePart] = 
[] if len(messages) == 1: parts = [TextPart(content='foo'), ToolCallPart('get_location', {'loc_name': 'London'})] else: parts = [TextPart(content='final response')] return ModelResponse(parts=parts) agent = Agent(FunctionModel(return_model)) @agent.tool_plain async def get_location(loc_name: str) -> str: if loc_name == 'London': return json.dumps({'lat': 51, 'lng': 0}) else: raise ModelRetry('Wrong location, please try again') # pragma: no cover result = agent.run_sync('Hello') assert result.output == 'final response' assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ TextPart(content='foo'), ToolCallPart(tool_name='get_location', args={'loc_name': 'London'}, tool_call_id=IsStr()), ], usage=RequestUsage(input_tokens=51, output_tokens=6), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='get_location', content='{"lat": 51, "lng": 0}', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ) ] ), ModelResponse( parts=[TextPart(content='final response')], usage=RequestUsage(input_tokens=56, output_tokens=8), model_name='function:return_model:', timestamp=IsNow(tz=timezone.utc), ), ] ) def test_nested_capture_run_messages() -> None: agent = Agent('test') with capture_run_messages() as messages1: assert messages1 == [] with capture_run_messages() as messages2: assert messages2 == [] assert messages1 is messages2 result = agent.run_sync('Hello') assert result.output == 'success (no tool calls)' assert messages1 == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='success (no tool calls)')], usage=RequestUsage(input_tokens=51, output_tokens=4), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ] ) assert messages1 == messages2 def test_double_capture_run_messages() -> None: agent = Agent('test') with capture_run_messages() as messages: assert messages == [] result = agent.run_sync('Hello') assert result.output == 'success (no tool calls)' result2 = agent.run_sync('Hello 2') assert result2.output == 'success (no tool calls)' assert messages == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[TextPart(content='success (no tool calls)')], usage=RequestUsage(input_tokens=51, output_tokens=4), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ] ) def test_capture_run_messages_with_user_exception_does_not_contain_internal_errors() -> None: """Test that user exceptions within capture_run_messages context have clean stack traces.""" agent = Agent('test') try: with capture_run_messages(): agent.run_sync('Hello') raise ZeroDivisionError('division by zero') except Exception as e: assert e.__context__ is None def test_dynamic_false_no_reevaluate(): """When dynamic is false (default), the system prompt is not reevaluated i.e: SystemPromptPart( content="A", <--- Remains the same when `message_history` is passed. 
part_kind='system-prompt') """ agent = Agent('test', system_prompt='Foobar') dynamic_value = 'A' @agent.system_prompt async def func() -> str: return dynamic_value res = agent.run_sync('Hello') assert res.all_messages() == snapshot( [ ModelRequest( parts=[ SystemPromptPart(content='Foobar', part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc)), SystemPromptPart( content=dynamic_value, part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc) ), UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), ], kind='request', ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], usage=RequestUsage(input_tokens=53, output_tokens=4), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', ), ] ) dynamic_value = 'B' res_two = agent.run_sync('World', message_history=res.all_messages()) assert res_two.all_messages() == snapshot( [ ModelRequest( parts=[ SystemPromptPart(content='Foobar', part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc)), SystemPromptPart( content='A', # Remains the same part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc), ), UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), ], kind='request', ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], usage=RequestUsage(input_tokens=53, output_tokens=4), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', ), ModelRequest( parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')], kind='request', ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], usage=RequestUsage(input_tokens=54, output_tokens=8), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', ), ] ) assert res_two.new_messages() == res_two.all_messages()[-2:] def test_dynamic_true_reevaluate_system_prompt(): """When dynamic is true, the system prompt is reevaluated i.e: SystemPromptPart( content="B", <--- Updated value part_kind='system-prompt') """ agent = Agent('test', system_prompt='Foobar') dynamic_value = 'A' @agent.system_prompt(dynamic=True) async def func(): return dynamic_value res = agent.run_sync('Hello') assert res.all_messages() == snapshot( [ ModelRequest( parts=[ SystemPromptPart(content='Foobar', part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc)), SystemPromptPart( content=dynamic_value, part_kind='system-prompt', dynamic_ref=func.__qualname__, timestamp=IsNow(tz=timezone.utc), ), UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), ], kind='request', ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], usage=RequestUsage(input_tokens=53, output_tokens=4), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', ), ] ) dynamic_value = 'B' res_two = agent.run_sync('World', message_history=res.all_messages()) assert res_two.all_messages() == snapshot( [ ModelRequest( parts=[ SystemPromptPart(content='Foobar', part_kind='system-prompt', timestamp=IsNow(tz=timezone.utc)), SystemPromptPart( content='B', part_kind='system-prompt', dynamic_ref=func.__qualname__, timestamp=IsNow(tz=timezone.utc), ), UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), ], kind='request', ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], usage=RequestUsage(input_tokens=53, output_tokens=4), model_name='test', 
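# Illustrative sketch (not part of the original test suite): the two tests around
# this point contrast `@agent.system_prompt` (rendered once and kept as-is when
# history is replayed) with `@agent.system_prompt(dynamic=True)` (re-evaluated on
# each run that replays the history). A condensed version of the dynamic variant;
# the `_example_*` name is hypothetical.
def _example_dynamic_system_prompt() -> None:
    from pydantic_ai import Agent

    agent = Agent('test', system_prompt='Foobar')
    dynamic_value = 'A'

    @agent.system_prompt(dynamic=True)
    async def func() -> str:
        return dynamic_value

    first = agent.run_sync('Hello')

    dynamic_value = 'B'
    # Because the prompt function is marked dynamic, the replayed SystemPromptPart
    # is re-rendered with the updated value on the follow-up run.
    second = agent.run_sync('World', message_history=first.all_messages())
    assert second.output == 'success (no tool calls)'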
timestamp=IsNow(tz=timezone.utc), kind='response', ), ModelRequest( parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')], kind='request', ), ModelResponse( parts=[TextPart(content='success (no tool calls)', part_kind='text')], usage=RequestUsage(input_tokens=54, output_tokens=8), model_name='test', timestamp=IsNow(tz=timezone.utc), kind='response', ), ] ) assert res_two.new_messages() == res_two.all_messages()[-2:] def test_dynamic_system_prompt_no_changes(): """Test coverage for _reevaluate_dynamic_prompts branch where no parts are changed and the messages loop continues after replacement of parts. """ agent = Agent('test') @agent.system_prompt(dynamic=True) async def dynamic_func() -> str: return 'Dynamic' result1 = agent.run_sync('Hello') # Create ModelRequest with non-dynamic SystemPromptPart (no dynamic_ref) manual_request = ModelRequest(parts=[SystemPromptPart(content='Static'), UserPromptPart(content='Manual')]) # Mix dynamic and non-dynamic messages to trigger branch coverage result2 = agent.run_sync('Second call', message_history=result1.all_messages() + [manual_request]) assert result2.output == 'success (no tool calls)' def test_capture_run_messages_tool_agent() -> None: agent_outer = Agent('test') agent_inner = Agent(TestModel(custom_output_text='inner agent result')) @agent_outer.tool_plain async def foobar(x: str) -> str: result_ = await agent_inner.run(x) return result_.output with capture_run_messages() as messages: result = agent_outer.run_sync('foobar') assert result.output == snapshot('{"foobar":"inner agent result"}') assert messages == snapshot( [ ModelRequest(parts=[UserPromptPart(content='foobar', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ToolCallPart(tool_name='foobar', args={'x': 'a'}, tool_call_id=IsStr())], usage=RequestUsage(input_tokens=51, output_tokens=5), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='foobar', content='inner agent result', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ) ] ), ModelResponse( parts=[TextPart(content='{"foobar":"inner agent result"}')], usage=RequestUsage(input_tokens=54, output_tokens=11), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ] ) class Bar(BaseModel): c: int d: str def test_custom_output_type_sync() -> None: agent = Agent('test', output_type=Foo) assert agent.run_sync('Hello').output == snapshot(Foo(a=0, b='a')) assert agent.run_sync('Hello', output_type=Bar).output == snapshot(Bar(c=0, d='a')) assert agent.run_sync('Hello', output_type=str).output == snapshot('success (no tool calls)') assert agent.run_sync('Hello', output_type=int).output == snapshot(0) async def test_custom_output_type_async() -> None: agent = Agent('test') result = await agent.run('Hello') assert result.output == snapshot('success (no tool calls)') result = await agent.run('Hello', output_type=Foo) assert result.output == snapshot(Foo(a=0, b='a')) result = await agent.run('Hello', output_type=int) assert result.output == snapshot(0) def test_custom_output_type_invalid() -> None: agent = Agent('test') @agent.output_validator def validate_output(ctx: RunContext[None], o: Any) -> Any: # pragma: no cover return o with pytest.raises(UserError, match='Cannot set a custom run `output_type` when the agent has output validators'): agent.run_sync('Hello', output_type=int) def test_binary_content_serializable(): agent = Agent('test') content = BinaryContent(data=b'Hello', media_type='text/plain') result = 
agent.run_sync(['Hello', content]) serialized = result.all_messages_json() assert json.loads(serialized) == snapshot( [ { 'parts': [ { 'content': [ 'Hello', { 'data': 'SGVsbG8=', 'media_type': 'text/plain', 'vendor_metadata': None, 'kind': 'binary', 'identifier': 'f7ff9e', }, ], 'timestamp': IsStr(), 'part_kind': 'user-prompt', } ], 'instructions': None, 'kind': 'request', }, { 'parts': [{'content': 'success (no tool calls)', 'id': None, 'part_kind': 'text'}], 'usage': { 'input_tokens': 56, 'cache_write_tokens': 0, 'cache_read_tokens': 0, 'output_tokens': 4, 'input_audio_tokens': 0, 'cache_audio_read_tokens': 0, 'output_audio_tokens': 0, 'details': {}, }, 'model_name': 'test', 'provider_name': None, 'provider_details': None, 'provider_response_id': None, 'timestamp': IsStr(), 'kind': 'response', 'finish_reason': None, }, ] ) # We also need to be able to round trip the serialized messages. messages = ModelMessagesTypeAdapter.validate_json(serialized) assert messages == result.all_messages() def test_image_url_serializable_missing_media_type(): agent = Agent('test') content = ImageUrl('https://example.com/chart.jpeg') result = agent.run_sync(['Hello', content]) serialized = result.all_messages_json() assert json.loads(serialized) == snapshot( [ { 'parts': [ { 'content': [ 'Hello', { 'url': 'https://example.com/chart.jpeg', 'force_download': False, 'vendor_metadata': None, 'kind': 'image-url', 'media_type': 'image/jpeg', 'identifier': 'a72e39', }, ], 'timestamp': IsStr(), 'part_kind': 'user-prompt', } ], 'instructions': None, 'kind': 'request', }, { 'parts': [{'content': 'success (no tool calls)', 'id': None, 'part_kind': 'text'}], 'usage': { 'input_tokens': 51, 'cache_write_tokens': 0, 'cache_read_tokens': 0, 'output_tokens': 4, 'input_audio_tokens': 0, 'cache_audio_read_tokens': 0, 'output_audio_tokens': 0, 'details': {}, }, 'model_name': 'test', 'timestamp': IsStr(), 'provider_name': None, 'provider_details': None, 'provider_response_id': None, 'kind': 'response', 'finish_reason': None, }, ] ) # We also need to be able to round trip the serialized messages. messages = ModelMessagesTypeAdapter.validate_json(serialized) part = messages[0].parts[0] assert isinstance(part, UserPromptPart) content = part.content[1] assert isinstance(content, ImageUrl) assert content.media_type == 'image/jpeg' assert messages == result.all_messages() def test_image_url_serializable(): agent = Agent('test') content = ImageUrl('https://example.com/chart', media_type='image/jpeg') result = agent.run_sync(['Hello', content]) serialized = result.all_messages_json() assert json.loads(serialized) == snapshot( [ { 'parts': [ { 'content': [ 'Hello', { 'url': 'https://example.com/chart', 'force_download': False, 'vendor_metadata': None, 'kind': 'image-url', 'media_type': 'image/jpeg', 'identifier': 'bdd86d', }, ], 'timestamp': IsStr(), 'part_kind': 'user-prompt', } ], 'instructions': None, 'kind': 'request', }, { 'parts': [{'content': 'success (no tool calls)', 'id': None, 'part_kind': 'text'}], 'usage': { 'input_tokens': 51, 'cache_write_tokens': 0, 'cache_read_tokens': 0, 'output_tokens': 4, 'input_audio_tokens': 0, 'cache_audio_read_tokens': 0, 'output_audio_tokens': 0, 'details': {}, }, 'model_name': 'test', 'timestamp': IsStr(), 'provider_name': None, 'provider_details': None, 'provider_response_id': None, 'kind': 'response', 'finish_reason': None, }, ] ) # We also need to be able to round trip the serialized messages. 
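# Illustrative sketch (not part of the original test suite): the serialization
# tests above rely on two halves of one contract: `result.all_messages_json()`
# dumps a run's messages to JSON, and `ModelMessagesTypeAdapter.validate_json()`
# loads them back into equal message objects. A minimal round trip; the
# `_example_*` name is hypothetical.
def _example_message_round_trip() -> None:
    from pydantic_ai import Agent, ModelMessagesTypeAdapter

    agent = Agent('test')
    result = agent.run_sync('Hello')

    serialized = result.all_messages_json()
    messages = ModelMessagesTypeAdapter.validate_json(serialized)
    assert messages == result.all_messages()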
messages = ModelMessagesTypeAdapter.validate_json(serialized) part = messages[0].parts[0] assert isinstance(part, UserPromptPart) content = part.content[1] assert isinstance(content, ImageUrl) assert content.media_type == 'image/jpeg' assert messages == result.all_messages() def test_tool_return_part_binary_content_serialization(): """Test that ToolReturnPart can properly serialize BinaryContent.""" png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82' binary_content = BinaryContent(png_data, media_type='image/png') tool_return = ToolReturnPart(tool_name='test_tool', content=binary_content, tool_call_id='test_call_123') assert tool_return.model_response_object() == snapshot( { 'return_value': { 'data': 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzgAAAAASUVORK5CYII=', 'media_type': 'image/png', 'vendor_metadata': None, '_identifier': None, 'kind': 'binary', } } ) def test_tool_returning_binary_content_directly(): """Test that a tool returning BinaryContent directly works correctly.""" def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: return ModelResponse(parts=[ToolCallPart('get_image', {})]) else: return ModelResponse(parts=[TextPart('Image received')]) agent = Agent(FunctionModel(llm)) @agent.tool_plain def get_image() -> BinaryContent: """Return a simple image.""" png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82' return BinaryContent(png_data, media_type='image/png') # This should work without the serialization error result = agent.run_sync('Get an image') assert result.output == 'Image received' def test_tool_returning_binary_content_with_identifier(): """Test that a tool returning BinaryContent directly works correctly.""" def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: return ModelResponse(parts=[ToolCallPart('get_image', {})]) else: return ModelResponse(parts=[TextPart('Image received')]) agent = Agent(FunctionModel(llm)) @agent.tool_plain def get_image() -> BinaryContent: """Return a simple image.""" png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82' return BinaryContent(png_data, media_type='image/png', identifier='image_id_1') # This should work without the serialization error result = agent.run_sync('Get an image') assert result.all_messages()[2] == snapshot( ModelRequest( parts=[ ToolReturnPart( tool_name='get_image', content='See file image_id_1', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), UserPromptPart( content=[ 'This is file image_id_1:', BinaryContent( data=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82', media_type='image/png', identifier='image_id_1', ), ], timestamp=IsNow(tz=timezone.utc), ), ] ) ) def test_tool_returning_file_url_with_identifier(): """Test that a tool returning FileUrl subclasses with identifiers works correctly.""" def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: return 
ModelResponse(parts=[ToolCallPart('get_files', {})]) else: return ModelResponse(parts=[TextPart('Files received')]) agent = Agent(FunctionModel(llm)) @agent.tool_plain def get_files(): """Return various file URLs with custom identifiers.""" return [ ImageUrl(url='https://example.com/image.jpg', identifier='img_001'), VideoUrl(url='https://example.com/video.mp4', identifier='vid_002'), AudioUrl(url='https://example.com/audio.mp3', identifier='aud_003'), DocumentUrl(url='https://example.com/document.pdf', identifier='doc_004'), ] result = agent.run_sync('Get some files') assert result.all_messages()[2] == snapshot( ModelRequest( parts=[ ToolReturnPart( tool_name='get_files', content=['See file img_001', 'See file vid_002', 'See file aud_003', 'See file doc_004'], tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), UserPromptPart( content=[ 'This is file img_001:', ImageUrl(url='https://example.com/image.jpg', identifier='img_001'), 'This is file vid_002:', VideoUrl(url='https://example.com/video.mp4', identifier='vid_002'), 'This is file aud_003:', AudioUrl(url='https://example.com/audio.mp3', identifier='aud_003'), 'This is file doc_004:', DocumentUrl(url='https://example.com/document.pdf', identifier='doc_004'), ], timestamp=IsNow(tz=timezone.utc), ), ] ) ) def test_instructions_raise_error_when_system_prompt_is_set(): agent = Agent('test', instructions='An instructions!') @agent.system_prompt def system_prompt() -> str: return 'A system prompt!' result = agent.run_sync('Hello') assert result.all_messages()[0] == snapshot( ModelRequest( parts=[ SystemPromptPart(content='A system prompt!', timestamp=IsNow(tz=timezone.utc)), UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), ], instructions='An instructions!', ) ) def test_instructions_raise_error_when_instructions_is_set(): agent = Agent('test', system_prompt='A system prompt!') @agent.instructions def instructions() -> str: return 'An instructions!' @agent.instructions def empty_instructions() -> str: return '' result = agent.run_sync('Hello') assert result.all_messages()[0] == snapshot( ModelRequest( parts=[ SystemPromptPart(content='A system prompt!', timestamp=IsNow(tz=timezone.utc)), UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), ], instructions='An instructions!', ) ) def test_instructions_both_instructions_and_system_prompt_are_set(): agent = Agent('test', instructions='An instructions!', system_prompt='A system prompt!') result = agent.run_sync('Hello') assert result.all_messages()[0] == snapshot( ModelRequest( parts=[ SystemPromptPart(content='A system prompt!', timestamp=IsNow(tz=timezone.utc)), UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), ], instructions='An instructions!', ) ) def test_instructions_decorator_without_parenthesis(): agent = Agent('test') @agent.instructions def instructions() -> str: return 'You are a helpful assistant.' result = agent.run_sync('Hello') assert result.all_messages()[0] == snapshot( ModelRequest( parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))], instructions='You are a helpful assistant.', ) ) def test_instructions_decorator_with_parenthesis(): agent = Agent('test') @agent.instructions() def instructions_2() -> str: return 'You are a helpful assistant.' 
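# Illustrative sketch (not part of the original test suite): the tests in this
# block show that `instructions` (passed to `Agent(...)` or added with
# `@agent.instructions`) end up on the `ModelRequest.instructions` field rather
# than as a `SystemPromptPart`. A compact sketch combining both forms; the
# `_example_*` names and instruction strings are hypothetical.
def _example_instructions() -> None:
    from pydantic_ai import Agent, ModelRequest

    agent = Agent('test', instructions='Be concise.')

    @agent.instructions
    def extra_instructions() -> str:
        return 'Answer in English.'

    result = agent.run_sync('Hello')
    request = result.all_messages()[0]
    assert isinstance(request, ModelRequest)
    # Both instruction sources are rendered into the request's `instructions`
    # string; the user prompt stays in `parts`.
    assert request.instructions is not None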
result = agent.run_sync('Hello') assert result.all_messages()[0] == snapshot( ModelRequest( parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))], instructions='You are a helpful assistant.', ) ) def test_instructions_with_message_history(): agent = Agent('test', instructions='You are a helpful assistant.') result = agent.run_sync( 'Hello', message_history=[ModelRequest(parts=[SystemPromptPart(content='You are a helpful assistant')])], ) assert result.all_messages() == snapshot( [ ModelRequest( parts=[SystemPromptPart(content='You are a helpful assistant', timestamp=IsNow(tz=timezone.utc))] ), ModelRequest( parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))], instructions='You are a helpful assistant.', ), ModelResponse( parts=[TextPart(content='success (no tool calls)')], usage=RequestUsage(input_tokens=56, output_tokens=4), model_name='test', timestamp=IsNow(tz=timezone.utc), ), ] ) assert result.new_messages() == result.all_messages()[-2:] def test_instructions_parameter_with_sequence(): def instructions() -> str: return 'You are a potato.' def empty_instructions() -> str: return '' agent = Agent('test', instructions=('You are a helpful assistant.', empty_instructions, instructions)) result = agent.run_sync('Hello') assert result.all_messages()[0] == snapshot( ModelRequest( parts=[UserPromptPart(content='Hello', timestamp=IsDatetime())], instructions="""\ You are a helpful assistant. You are a potato.\ """, ) ) def test_empty_final_response(): def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: return ModelResponse(parts=[TextPart('foo'), ToolCallPart('my_tool', {'x': 1})]) elif len(messages) == 3: return ModelResponse(parts=[TextPart('bar'), ToolCallPart('my_tool', {'x': 2})]) else: return ModelResponse(parts=[]) agent = Agent(FunctionModel(llm)) @agent.tool_plain def my_tool(x: int) -> int: return x * 2 result = agent.run_sync('Hello') assert result.output == 'bar' assert result.new_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ TextPart(content='foo'), ToolCallPart(tool_name='my_tool', args={'x': 1}, tool_call_id=IsStr()), ], usage=RequestUsage(input_tokens=51, output_tokens=5), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='my_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) ) ] ), ModelResponse( parts=[ TextPart(content='bar'), ToolCallPart(tool_name='my_tool', args={'x': 2}, tool_call_id=IsStr()), ], usage=RequestUsage(input_tokens=52, output_tokens=10), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='my_tool', content=4, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) ) ] ), ModelResponse( parts=[], usage=RequestUsage(input_tokens=53, output_tokens=10), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), ] ) def test_agent_run_result_serialization() -> None: agent = Agent('test', output_type=Foo) result = agent.run_sync('Hello') # Check that dump_json doesn't raise an error adapter = TypeAdapter(AgentRunResult[Foo]) serialized_data = adapter.dump_json(result) # Check that we can load the data back deserialized_result = adapter.validate_json(serialized_data) assert deserialized_result == result def test_agent_repr() -> None: agent = Agent() assert repr(agent) == snapshot( "Agent(model=None, name=None, end_strategy='early', 
model_settings=None, output_type=<class 'str'>, instrument=None)" ) def test_tool_call_with_validation_value_error_serializable(): def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: return ModelResponse(parts=[ToolCallPart('foo_tool', {'bar': 0})]) elif len(messages) == 3: return ModelResponse(parts=[ToolCallPart('foo_tool', {'bar': 1})]) else: return ModelResponse(parts=[TextPart('Tool returned 1')]) agent = Agent(FunctionModel(llm)) class Foo(BaseModel): bar: int @field_validator('bar') def validate_bar(cls, v: int) -> int: if v == 0: raise ValueError('bar cannot be 0') return v @agent.tool_plain def foo_tool(foo: Foo) -> int: return foo.bar result = agent.run_sync('Hello') assert json.loads(result.all_messages_json())[2] == snapshot( { 'parts': [ { 'content': [ {'type': 'value_error', 'loc': ['bar'], 'msg': 'Value error, bar cannot be 0', 'input': 0} ], 'tool_name': 'foo_tool', 'tool_call_id': IsStr(), 'timestamp': IsStr(), 'part_kind': 'retry-prompt', } ], 'instructions': None, 'kind': 'request', } ) def test_unsupported_output_mode(): class Foo(BaseModel): bar: str def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[TextPart('hello')]) # pragma: no cover model = FunctionModel(hello, profile=ModelProfile(supports_tools=False, supports_json_schema_output=False)) agent = Agent(model, output_type=NativeOutput(Foo)) with pytest.raises(UserError, match='Native structured output is not supported by this model.'): agent.run_sync('Hello') agent = Agent(model, output_type=ToolOutput(Foo)) with pytest.raises(UserError, match='Tool output is not supported by this model.'): agent.run_sync('Hello') agent = Agent(model, output_type=BinaryImage) with pytest.raises(UserError, match='Image output is not supported by this model.'): agent.run_sync('Hello') def test_multimodal_tool_response(): """Test ToolReturn with custom content and tool return.""" def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: return ModelResponse(parts=[TextPart('Starting analysis'), ToolCallPart('analyze_data', {})]) else: return ModelResponse( parts=[ TextPart('Analysis completed'), ] ) agent = Agent(FunctionModel(llm)) @agent.tool_plain def analyze_data() -> ToolReturn: return ToolReturn( return_value='Data analysis completed successfully', content=[ 'Here are the analysis results:', ImageUrl('https://example.com/chart.jpg'), 'The chart shows positive trends.', ], metadata={'foo': 'bar'}, ) result = agent.run_sync('Please analyze the data') # Verify final output assert result.output == 'Analysis completed' # Verify message history contains the expected parts # Verify the complete message structure using snapshot assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Please analyze the data', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ TextPart(content='Starting analysis'), ToolCallPart( tool_name='analyze_data', args={}, tool_call_id=IsStr(), ), ], usage=RequestUsage(input_tokens=54, output_tokens=4), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='analyze_data', content='Data analysis completed successfully', tool_call_id=IsStr(), metadata={'foo': 'bar'}, timestamp=IsNow(tz=timezone.utc), ), UserPromptPart( content=[ 'Here are the analysis results:', ImageUrl(url='https://example.com/chart.jpg', identifier='672a5c'), 'The chart shows positive trends.', ], 
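# Illustrative sketch (not part of the original test suite): `ToolReturn` lets a
# tool separate what the model sees as the tool result (`return_value`), extra
# user-visible content such as text and images (`content`), and run-side
# `metadata`, as the test above exercises. A trimmed-down version of that tool;
# the `_example_*` names and URL are hypothetical.
def _example_tool_return() -> None:
    from pydantic_ai import Agent, ImageUrl, ToolReturn

    agent = Agent('test')

    @agent.tool_plain
    def analyze() -> ToolReturn:
        return ToolReturn(
            return_value='Analysis done',
            content=['Chart for reference:', ImageUrl('https://example.com/chart.jpg')],
            metadata={'source': 'example'},
        )

    result = agent.run_sync('Please analyze')
    # TestModel echoes the tool results back as JSON text.
    assert result.output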
timestamp=IsNow(tz=timezone.utc), ), ] ), ModelResponse( parts=[TextPart(content='Analysis completed')], usage=RequestUsage(input_tokens=70, output_tokens=6), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), ] ) def test_plain_tool_response(): """Test ToolReturn with custom content and tool return.""" def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: return ModelResponse(parts=[TextPart('Starting analysis'), ToolCallPart('analyze_data', {})]) else: return ModelResponse( parts=[ TextPart('Analysis completed'), ] ) agent = Agent(FunctionModel(llm)) @agent.tool_plain def analyze_data() -> ToolReturn: return ToolReturn( return_value='Data analysis completed successfully', metadata={'foo': 'bar'}, ) result = agent.run_sync('Please analyze the data') # Verify final output assert result.output == 'Analysis completed' # Verify message history contains the expected parts # Verify the complete message structure using snapshot assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Please analyze the data', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ TextPart(content='Starting analysis'), ToolCallPart( tool_name='analyze_data', args={}, tool_call_id=IsStr(), ), ], usage=RequestUsage(input_tokens=54, output_tokens=4), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ ToolReturnPart( tool_name='analyze_data', content='Data analysis completed successfully', tool_call_id=IsStr(), metadata={'foo': 'bar'}, timestamp=IsNow(tz=timezone.utc), ), ] ), ModelResponse( parts=[TextPart(content='Analysis completed')], usage=RequestUsage(input_tokens=58, output_tokens=6), model_name='function:llm:', timestamp=IsNow(tz=timezone.utc), ), ] ) def test_many_multimodal_tool_response(): """Test ToolReturn with custom content and tool return.""" def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: return ModelResponse(parts=[TextPart('Starting analysis'), ToolCallPart('analyze_data', {})]) else: return ModelResponse( # pragma: no cover parts=[ TextPart('Analysis completed'), ] ) agent = Agent(FunctionModel(llm)) @agent.tool_plain def analyze_data() -> list[Any]: return [ ToolReturn( return_value='Data analysis completed successfully', content=[ 'Here are the analysis results:', ImageUrl('https://example.com/chart.jpg'), 'The chart shows positive trends.', ], metadata={'foo': 'bar'}, ), 'Something else', ] with pytest.raises( UserError, match="The return value of tool 'analyze_data' contains invalid nested `ToolReturn` objects. `ToolReturn` should be used directly.", ): agent.run_sync('Please analyze the data') def test_multimodal_tool_response_nested(): """Test ToolReturn with custom content and tool return.""" def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: return ModelResponse(parts=[TextPart('Starting analysis'), ToolCallPart('analyze_data', {})]) else: return ModelResponse( # pragma: no cover parts=[ TextPart('Analysis completed'), ] ) agent = Agent(FunctionModel(llm)) @agent.tool_plain def analyze_data() -> ToolReturn: return ToolReturn( return_value=ImageUrl('https://example.com/chart.jpg'), content=[ 'Here are the analysis results:', ImageUrl('https://example.com/chart.jpg'), 'The chart shows positive trends.', ], metadata={'foo': 'bar'}, ) with pytest.raises( UserError, match="The `return_value` of tool 'analyze_data' contains invalid nested `MultiModalContent` objects. 
Please use `content` instead.", ): agent.run_sync('Please analyze the data') def test_deprecated_kwargs_validation_agent_init(): """Test that invalid kwargs raise UserError in Agent constructor.""" with pytest.raises(UserError, match='Unknown keyword arguments: `usage_limits`'): Agent('test', usage_limits='invalid') # type: ignore[call-arg] with pytest.raises(UserError, match='Unknown keyword arguments: `invalid_kwarg`'): Agent('test', invalid_kwarg='value') # type: ignore[call-arg] with pytest.raises(UserError, match='Unknown keyword arguments: `foo`, `bar`'): Agent('test', foo='value1', bar='value2') # type: ignore[call-arg] def test_deprecated_kwargs_still_work(): """Test that valid deprecated kwargs still work with warnings.""" import warnings try: from pydantic_ai.mcp import MCPServerStdio with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') Agent('test', mcp_servers=[MCPServerStdio('python', ['-m', 'tests.mcp_server'])]) # type: ignore[call-arg] assert len(w) == 1 assert issubclass(w[0].category, DeprecationWarning) assert '`mcp_servers` is deprecated' in str(w[0].message) except ImportError: pass def test_override_toolsets(): foo_toolset = FunctionToolset() @foo_toolset.tool def foo() -> str: return 'Hello from foo' available_tools: list[list[str]] = [] async def prepare_tools(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: nonlocal available_tools available_tools.append([tool_def.name for tool_def in tool_defs]) return tool_defs agent = Agent('test', toolsets=[foo_toolset], prepare_tools=prepare_tools) @agent.tool_plain def baz() -> str: return 'Hello from baz' result = agent.run_sync('Hello') assert available_tools[-1] == snapshot(['baz', 'foo']) assert result.output == snapshot('{"baz":"Hello from baz","foo":"Hello from foo"}') bar_toolset = FunctionToolset() @bar_toolset.tool def bar() -> str: return 'Hello from bar' with agent.override(toolsets=[bar_toolset]): result = agent.run_sync('Hello') assert available_tools[-1] == snapshot(['baz', 'bar']) assert result.output == snapshot('{"baz":"Hello from baz","bar":"Hello from bar"}') with agent.override(toolsets=[]): result = agent.run_sync('Hello') assert available_tools[-1] == snapshot(['baz']) assert result.output == snapshot('{"baz":"Hello from baz"}') result = agent.run_sync('Hello', toolsets=[bar_toolset]) assert available_tools[-1] == snapshot(['baz', 'foo', 'bar']) assert result.output == snapshot('{"baz":"Hello from baz","foo":"Hello from foo","bar":"Hello from bar"}') with agent.override(toolsets=[]): result = agent.run_sync('Hello', toolsets=[bar_toolset]) assert available_tools[-1] == snapshot(['baz']) assert result.output == snapshot('{"baz":"Hello from baz"}') def test_override_tools(): def foo() -> str: return 'Hello from foo' def bar() -> str: return 'Hello from bar' available_tools: list[list[str]] = [] async def prepare_tools(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: nonlocal available_tools available_tools.append([tool_def.name for tool_def in tool_defs]) return tool_defs agent = Agent('test', tools=[foo], prepare_tools=prepare_tools) result = agent.run_sync('Hello') assert available_tools[-1] == snapshot(['foo']) assert result.output == snapshot('{"foo":"Hello from foo"}') with agent.override(tools=[bar]): result = agent.run_sync('Hello') assert available_tools[-1] == snapshot(['bar']) assert result.output == snapshot('{"bar":"Hello from bar"}') with agent.override(tools=[]): result = agent.run_sync('Hello') assert 
available_tools[-1] == snapshot([]) assert result.output == snapshot('success (no tool calls)') def test_toolset_factory(): toolset = FunctionToolset() @toolset.tool def foo() -> str: return 'Hello from foo' available_tools: list[str] = [] async def prepare_tools(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: nonlocal available_tools available_tools = [tool_def.name for tool_def in tool_defs] return tool_defs toolset_creation_counts: dict[str, int] = defaultdict(int) def via_toolsets_arg(ctx: RunContext[None]) -> AbstractToolset[None]: nonlocal toolset_creation_counts toolset_creation_counts['via_toolsets_arg'] += 1 return toolset.prefixed('via_toolsets_arg') def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: return ModelResponse(parts=[ToolCallPart('via_toolsets_arg_foo')]) elif len(messages) == 3: return ModelResponse(parts=[ToolCallPart('via_toolset_decorator_foo')]) else: return ModelResponse(parts=[TextPart('Done')]) agent = Agent(FunctionModel(respond), toolsets=[via_toolsets_arg], prepare_tools=prepare_tools) @agent.toolset def via_toolset_decorator(ctx: RunContext[None]) -> AbstractToolset[None]: nonlocal toolset_creation_counts toolset_creation_counts['via_toolset_decorator'] += 1 return toolset.prefixed('via_toolset_decorator') @agent.toolset(per_run_step=False) async def via_toolset_decorator_for_entire_run(ctx: RunContext[None]) -> AbstractToolset[None]: nonlocal toolset_creation_counts toolset_creation_counts['via_toolset_decorator_for_entire_run'] += 1 return toolset.prefixed('via_toolset_decorator_for_entire_run') run_result = agent.run_sync('Hello') assert run_result._state.run_step == 3 # pyright: ignore[reportPrivateUsage] assert len(available_tools) == 3 assert toolset_creation_counts == snapshot( defaultdict(int, {'via_toolsets_arg': 3, 'via_toolset_decorator': 3, 'via_toolset_decorator_for_entire_run': 1}) ) def test_adding_tools_during_run(): toolset = FunctionToolset() def foo() -> str: return 'Hello from foo' @toolset.tool def add_foo_tool() -> str: toolset.add_function(foo) return 'foo tool added' def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: return ModelResponse(parts=[ToolCallPart('add_foo_tool')]) elif len(messages) == 3: return ModelResponse(parts=[ToolCallPart('foo')]) else: return ModelResponse(parts=[TextPart('Done')]) agent = Agent(FunctionModel(respond), toolsets=[toolset]) result = agent.run_sync('Add the foo tool and run it') assert result.output == snapshot('Done') assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='Add the foo tool and run it', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ToolCallPart(tool_name='add_foo_tool', tool_call_id=IsStr())], usage=RequestUsage(input_tokens=57, output_tokens=2), model_name='function:respond:', timestamp=IsDatetime(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='add_foo_tool', content='foo tool added', tool_call_id=IsStr(), timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ToolCallPart(tool_name='foo', tool_call_id=IsStr())], usage=RequestUsage(input_tokens=60, output_tokens=4), model_name='function:respond:', timestamp=IsDatetime(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='foo', content='Hello from foo', tool_call_id=IsStr(), timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[TextPart(content='Done')], usage=RequestUsage(input_tokens=63, output_tokens=5), model_name='function:respond:', 
timestamp=IsDatetime(), ), ] ) def test_prepare_output_tools(): @dataclass class AgentDeps: plan_presented: bool = False async def present_plan(ctx: RunContext[AgentDeps], plan: str) -> str: """ Present the plan to the user. """ ctx.deps.plan_presented = True return plan async def run_sql(ctx: RunContext[AgentDeps], purpose: str, query: str) -> str: """ Run an SQL query. """ return 'SQL query executed successfully' async def only_if_plan_presented( ctx: RunContext[AgentDeps], tool_defs: list[ToolDefinition] ) -> list[ToolDefinition]: return tool_defs if ctx.deps.plan_presented else [] agent = Agent( model='test', deps_type=AgentDeps, tools=[present_plan], output_type=[ToolOutput(run_sql, name='run_sql')], prepare_output_tools=only_if_plan_presented, ) result = agent.run_sync('Hello', deps=AgentDeps()) assert result.output == snapshot('SQL query executed successfully') assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='Hello', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ ToolCallPart( tool_name='present_plan', args={'plan': 'a'}, tool_call_id=IsStr(), ) ], usage=RequestUsage(input_tokens=51, output_tokens=5), model_name='test', timestamp=IsDatetime(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='present_plan', content='a', tool_call_id=IsStr(), timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ ToolCallPart( tool_name='run_sql', args={'purpose': 'a', 'query': 'a'}, tool_call_id=IsStr(), ) ], usage=RequestUsage(input_tokens=52, output_tokens=12), model_name='test', timestamp=IsDatetime(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='run_sql', content='Final result processed.', tool_call_id=IsStr(), timestamp=IsDatetime(), ) ] ), ] ) async def test_explicit_context_manager(): try: from pydantic_ai.mcp import MCPServerStdio except ImportError: # pragma: lax no cover pytest.skip('mcp is not installed') server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')]) agent = Agent('test', toolsets=[toolset]) async with agent: assert server1.is_running assert server2.is_running async with agent: assert server1.is_running assert server2.is_running async def test_implicit_context_manager(): try: from pydantic_ai.mcp import MCPServerStdio except ImportError: # pragma: lax no cover pytest.skip('mcp is not installed') server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')]) agent = Agent('test', toolsets=[toolset]) async with agent.iter(user_prompt='Hello'): assert server1.is_running assert server2.is_running def test_parallel_mcp_calls(): try: from pydantic_ai.mcp import MCPServerStdio except ImportError: # pragma: lax no cover pytest.skip('mcp is not installed') async def call_tools_parallel(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: return ModelResponse( parts=[ ToolCallPart(tool_name='get_none'), ToolCallPart(tool_name='get_multiple_items'), ] ) else: return ModelResponse(parts=[TextPart('finished')]) server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) agent = Agent(FunctionModel(call_tools_parallel), toolsets=[server]) result = agent.run_sync() assert result.output == snapshot('finished') @pytest.mark.parametrize('mode', ['argument', 'contextmanager']) def test_sequential_calls(mode: Literal['argument', 
'contextmanager']): """Test that tool calls are executed correctly when a `sequential` tool is present in the call.""" async def call_tools_sequential(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: return ModelResponse( parts=[ ToolCallPart(tool_name='call_first'), ToolCallPart(tool_name='call_first'), ToolCallPart(tool_name='call_first'), ToolCallPart(tool_name='call_first'), ToolCallPart(tool_name='call_first'), ToolCallPart(tool_name='call_first'), ToolCallPart(tool_name='increment_integer_holder'), ToolCallPart(tool_name='requires_approval'), ToolCallPart(tool_name='call_second'), ToolCallPart(tool_name='call_second'), ToolCallPart(tool_name='call_second'), ToolCallPart(tool_name='call_second'), ToolCallPart(tool_name='call_second'), ToolCallPart(tool_name='call_second'), ToolCallPart(tool_name='call_second'), ] ) sequential_toolset = FunctionToolset() integer_holder: int = 1 @sequential_toolset.tool def call_first(): nonlocal integer_holder assert integer_holder == 1 @sequential_toolset.tool(sequential=mode == 'argument') def increment_integer_holder(): nonlocal integer_holder integer_holder = 2 @sequential_toolset.tool def requires_approval(): from pydantic_ai.exceptions import ApprovalRequired raise ApprovalRequired() @sequential_toolset.tool def call_second(): nonlocal integer_holder assert integer_holder == 2 agent = Agent( FunctionModel(call_tools_sequential), toolsets=[sequential_toolset], output_type=[str, DeferredToolRequests] ) if mode == 'contextmanager': with agent.sequential_tool_calls(): result = agent.run_sync() else: result = agent.run_sync() assert result.output == snapshot( DeferredToolRequests(approvals=[ToolCallPart(tool_name='requires_approval', tool_call_id=IsStr())]) ) assert integer_holder == 2 def test_set_mcp_sampling_model(): try: from pydantic_ai.mcp import MCPServerStdio except ImportError: # pragma: lax no cover pytest.skip('mcp is not installed') test_model = TestModel() server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server'], sampling_model=test_model) toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')]) agent = Agent(None, toolsets=[toolset]) with pytest.raises(UserError, match='No sampling model provided and no model set on the agent.'): agent.set_mcp_sampling_model() assert server1.sampling_model is None assert server2.sampling_model is test_model agent.model = test_model agent.set_mcp_sampling_model() assert server1.sampling_model is test_model assert server2.sampling_model is test_model function_model = FunctionModel(lambda messages, info: ModelResponse(parts=[TextPart('Hello')])) with agent.override(model=function_model): agent.set_mcp_sampling_model() assert server1.sampling_model is function_model assert server2.sampling_model is function_model function_model2 = FunctionModel(lambda messages, info: ModelResponse(parts=[TextPart('Goodbye')])) agent.set_mcp_sampling_model(function_model2) assert server1.sampling_model is function_model2 assert server2.sampling_model is function_model2 def test_toolsets(): toolset = FunctionToolset() @toolset.tool def foo() -> str: return 'Hello from foo' # pragma: no cover agent = Agent('test', toolsets=[toolset]) assert toolset in agent.toolsets other_toolset = FunctionToolset() with agent.override(toolsets=[other_toolset]): assert other_toolset in agent.toolsets assert toolset not in agent.toolsets async def test_wrapper_agent(): async def event_stream_handler(ctx: RunContext[None], events: 
AsyncIterable[AgentStreamEvent]): pass # pragma: no cover foo_toolset = FunctionToolset() @foo_toolset.tool def foo() -> str: return 'Hello from foo' # pragma: no cover test_model = TestModel() agent = Agent(test_model, toolsets=[foo_toolset], output_type=Foo, event_stream_handler=event_stream_handler) wrapper_agent = WrapperAgent(agent) assert wrapper_agent.toolsets == agent.toolsets assert wrapper_agent.model == agent.model assert wrapper_agent.name == agent.name wrapper_agent.name = 'wrapped' assert wrapper_agent.name == 'wrapped' assert wrapper_agent.output_type == agent.output_type assert wrapper_agent.event_stream_handler == agent.event_stream_handler bar_toolset = FunctionToolset() @bar_toolset.tool def bar() -> str: return 'Hello from bar' with wrapper_agent.override(toolsets=[bar_toolset]): async with wrapper_agent: async with wrapper_agent.iter(user_prompt='Hello') as run: async for _ in run: pass assert run.result is not None assert run.result.output == snapshot(Foo(a=0, b='a')) assert test_model.last_model_request_parameters is not None assert [t.name for t in test_model.last_model_request_parameters.function_tools] == snapshot(['bar']) async def test_thinking_only_response_retry(): """Test that thinking-only responses trigger a retry mechanism.""" from pydantic_ai import ThinkingPart from pydantic_ai.models.function import FunctionModel call_count = 0 def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: nonlocal call_count call_count += 1 if call_count == 1: # First call: return thinking-only response return ModelResponse( parts=[ThinkingPart(content='Let me think about this...')], model_name='thinking-test-model', ) else: # Second call: return proper response return ModelResponse( parts=[TextPart(content='Final answer')], model_name='thinking-test-model', ) model = FunctionModel(model_function) agent = Agent(model, system_prompt='You are a helpful assistant.') result = await agent.run('Hello') assert result.all_messages() == snapshot( [ ModelRequest( parts=[ SystemPromptPart( content='You are a helpful assistant.', timestamp=IsDatetime(), ), UserPromptPart( content='Hello', timestamp=IsDatetime(), ), ] ), ModelResponse( parts=[ThinkingPart(content='Let me think about this...')], usage=RequestUsage(input_tokens=57, output_tokens=6), model_name='function:model_function:', timestamp=IsDatetime(), ), ModelRequest( parts=[ RetryPromptPart( content='Please return text or call a tool.', tool_call_id=IsStr(), timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[TextPart(content='Final answer')], usage=RequestUsage(input_tokens=73, output_tokens=8), model_name='function:model_function:', timestamp=IsDatetime(), ), ] ) async def test_hitl_tool_approval(): def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if len(messages) == 1: return ModelResponse( parts=[ ToolCallPart( tool_name='create_file', args={'path': 'new_file.py', 'content': 'print("Hello, world!")'}, tool_call_id='create_file', ), ToolCallPart( tool_name='delete_file', args={'path': 'ok_to_delete.py'}, tool_call_id='ok_to_delete' ), ToolCallPart( tool_name='delete_file', args={'path': 'never_delete.py'}, tool_call_id='never_delete' ), ] ) else: return ModelResponse(parts=[TextPart('Done!')]) model = FunctionModel(model_function) agent = Agent(model, output_type=[str, DeferredToolRequests]) @agent.tool_plain(requires_approval=True) def delete_file(path: str) -> str: return f'File {path!r} deleted' @agent.tool_plain def create_file(path: str, content: str) -> str: 
return f'File {path!r} created with content: {content}' result = await agent.run('Create new_file.py and delete ok_to_delete.py and never_delete.py') messages = result.all_messages() assert messages == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='Create new_file.py and delete ok_to_delete.py and never_delete.py', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ ToolCallPart( tool_name='create_file', args={'path': 'new_file.py', 'content': 'print("Hello, world!")'}, tool_call_id='create_file', ), ToolCallPart( tool_name='delete_file', args={'path': 'ok_to_delete.py'}, tool_call_id='ok_to_delete' ), ToolCallPart( tool_name='delete_file', args={'path': 'never_delete.py'}, tool_call_id='never_delete' ), ], usage=RequestUsage(input_tokens=60, output_tokens=23), model_name='function:model_function:', timestamp=IsDatetime(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='create_file', content='File \'new_file.py\' created with content: print("Hello, world!")', tool_call_id='create_file', timestamp=IsDatetime(), ) ] ), ] ) assert result.output == snapshot( DeferredToolRequests( approvals=[ ToolCallPart(tool_name='delete_file', args={'path': 'ok_to_delete.py'}, tool_call_id='ok_to_delete'), ToolCallPart(tool_name='delete_file', args={'path': 'never_delete.py'}, tool_call_id='never_delete'), ] ) ) result = await agent.run( message_history=messages, deferred_tool_results=DeferredToolResults( approvals={'ok_to_delete': True, 'never_delete': ToolDenied('File cannot be deleted')}, ), ) assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='Create new_file.py and delete ok_to_delete.py and never_delete.py', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ ToolCallPart( tool_name='create_file', args={'path': 'new_file.py', 'content': 'print("Hello, world!")'}, tool_call_id='create_file', ), ToolCallPart( tool_name='delete_file', args={'path': 'ok_to_delete.py'}, tool_call_id='ok_to_delete' ), ToolCallPart( tool_name='delete_file', args={'path': 'never_delete.py'}, tool_call_id='never_delete' ), ], usage=RequestUsage(input_tokens=60, output_tokens=23), model_name='function:model_function:', timestamp=IsDatetime(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='create_file', content='File \'new_file.py\' created with content: print("Hello, world!")', tool_call_id='create_file', timestamp=IsDatetime(), ) ] ), ModelRequest( parts=[ ToolReturnPart( tool_name='delete_file', content="File 'ok_to_delete.py' deleted", tool_call_id='ok_to_delete', timestamp=IsDatetime(), ), ToolReturnPart( tool_name='delete_file', content='File cannot be deleted', tool_call_id='never_delete', timestamp=IsDatetime(), ), ] ), ModelResponse( parts=[TextPart(content='Done!')], usage=RequestUsage(input_tokens=78, output_tokens=24), model_name='function:model_function:', timestamp=IsDatetime(), ), ] ) assert result.output == snapshot('Done!') assert result.new_messages() == snapshot( [ ModelRequest( parts=[ ToolReturnPart( tool_name='delete_file', content="File 'ok_to_delete.py' deleted", tool_call_id='ok_to_delete', timestamp=IsDatetime(), ), ToolReturnPart( tool_name='delete_file', content='File cannot be deleted', tool_call_id='never_delete', timestamp=IsDatetime(), ), ] ), ModelResponse( parts=[TextPart(content='Done!')], usage=RequestUsage(input_tokens=78, output_tokens=24), model_name='function:model_function:', timestamp=IsDatetime(), ), ] ) async def test_run_with_deferred_tool_results_errors(): agent = Agent('test') message_history: list[ModelMessage] = 
[ModelRequest(parts=[UserPromptPart(content=['Hello', 'world'])])] with pytest.raises( UserError, match='Tool call results were provided, but the message history does not contain a `ModelResponse`.', ): await agent.run( 'Hello again', message_history=message_history, deferred_tool_results=DeferredToolResults(approvals={'create_file': True}), ) message_history: list[ModelMessage] = [ ModelRequest(parts=[UserPromptPart(content='Hello')]), ModelResponse(parts=[TextPart(content='Hello to you too!')]), ] with pytest.raises( UserError, match='Tool call results were provided, but the message history does not contain any unprocessed tool calls.', ): await agent.run( 'Hello again', message_history=message_history, deferred_tool_results=DeferredToolResults(approvals={'create_file': True}), ) message_history: list[ModelMessage] = [ ModelRequest(parts=[UserPromptPart(content='Hello')]), ModelResponse(parts=[ToolCallPart(tool_name='say_hello')]), ] with pytest.raises( UserError, match='Cannot provide a new user prompt when the message history contains unprocessed tool calls.' ): await agent.run('Hello', message_history=message_history) with pytest.raises(UserError, match='Tool call results need to be provided for all deferred tool calls.'): await agent.run( message_history=message_history, deferred_tool_results=DeferredToolResults(), ) with pytest.raises(UserError, match='Tool call results were provided, but the message history is empty.'): await agent.run( 'Hello again', deferred_tool_results=DeferredToolResults(approvals={'create_file': True}), ) with pytest.raises( UserError, match='Cannot provide a new user prompt when the message history contains unprocessed tool calls.' ): await agent.run( 'Hello again', message_history=message_history, deferred_tool_results=DeferredToolResults(approvals={'create_file': True}), ) message_history: list[ModelMessage] = [ ModelRequest(parts=[UserPromptPart(content='Hello')]), ModelResponse( parts=[ ToolCallPart(tool_name='run_me', tool_call_id='run_me'), ToolCallPart(tool_name='run_me_too', tool_call_id='run_me_too'), ToolCallPart(tool_name='defer_me', tool_call_id='defer_me'), ] ), ModelRequest( parts=[ ToolReturnPart(tool_name='run_me', tool_call_id='run_me', content='Success'), RetryPromptPart(tool_name='run_me_too', tool_call_id='run_me_too', content='Failure'), ] ), ] with pytest.raises(UserError, match="Tool call 'run_me' was already executed and its result cannot be overridden."): await agent.run( message_history=message_history, deferred_tool_results=DeferredToolResults( calls={'run_me': 'Failure', 'defer_me': 'Failure'}, ), ) with pytest.raises( UserError, match="Tool call 'run_me_too' was already executed and its result cannot be overridden." 
): await agent.run( message_history=message_history, deferred_tool_results=DeferredToolResults( calls={'run_me_too': 'Success', 'defer_me': 'Failure'}, ), ) def test_tool_requires_approval_error(): agent = Agent('test') with pytest.raises( UserError, match='To use tools that require approval, add `DeferredToolRequests` to the list of output types for this agent.', ): @agent.tool_plain(requires_approval=True) def delete_file(path: str) -> None: pass async def test_consecutive_model_responses_in_history(): received_messages: list[ModelMessage] | None = None def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: nonlocal received_messages received_messages = messages return ModelResponse( parts=[ TextPart('All right then, goodbye!'), ] ) history: list[ModelMessage] = [ ModelRequest(parts=[UserPromptPart(content='Hello...')]), ModelResponse(parts=[TextPart(content='...world!')]), ModelResponse(parts=[TextPart(content='Anything else I can help with?')]), ] m = FunctionModel(llm) agent = Agent(m) result = await agent.run('No thanks', message_history=history) assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='Hello...', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[TextPart(content='...world!'), TextPart(content='Anything else I can help with?')], timestamp=IsDatetime(), ), ModelRequest( parts=[ UserPromptPart( content='No thanks', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[TextPart(content='All right then, goodbye!')], usage=RequestUsage(input_tokens=54, output_tokens=12), model_name='function:llm:', timestamp=IsDatetime(), ), ] ) assert result.new_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='No thanks', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[TextPart(content='All right then, goodbye!')], usage=RequestUsage(input_tokens=54, output_tokens=12), model_name='function:llm:', timestamp=IsDatetime(), ), ] ) assert received_messages == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='Hello...', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[TextPart(content='...world!'), TextPart(content='Anything else I can help with?')], timestamp=IsDatetime(), ), ModelRequest( parts=[ UserPromptPart( content='No thanks', timestamp=IsDatetime(), ) ] ), ] ) def test_override_instructions_basic(): """Test that override can override instructions.""" agent = Agent('test') @agent.instructions def instr_fn() -> str: return 'SHOULD_BE_IGNORED' with capture_run_messages() as base_messages: agent.run_sync('Hello', model=TestModel(custom_output_text='baseline')) base_req = base_messages[0] assert isinstance(base_req, ModelRequest) assert base_req.instructions == 'SHOULD_BE_IGNORED' with agent.override(instructions='OVERRIDE'): with capture_run_messages() as messages: agent.run_sync('Hello', model=TestModel(custom_output_text='ok')) req = messages[0] assert isinstance(req, ModelRequest) assert req.instructions == 'OVERRIDE' def test_override_reset_after_context(): """Test that instructions are reset after exiting the override context.""" agent = Agent('test', instructions='ORIG') with agent.override(instructions='NEW'): with capture_run_messages() as messages_new: agent.run_sync('Hi', model=TestModel(custom_output_text='ok')) with capture_run_messages() as messages_orig: agent.run_sync('Hi', model=TestModel(custom_output_text='ok')) req_new = messages_new[0] assert isinstance(req_new, ModelRequest) req_orig = messages_orig[0] assert isinstance(req_orig, ModelRequest) assert req_new.instructions == 
'NEW' assert req_orig.instructions == 'ORIG' def test_override_none_clears_instructions(): """Test that passing None for instructions clears all instructions.""" agent = Agent('test', instructions='BASE') @agent.instructions def instr_fn() -> str: # pragma: no cover - ignored under override return 'ALSO_BASE' with agent.override(instructions=None): with capture_run_messages() as messages: agent.run_sync('Hello', model=TestModel(custom_output_text='ok')) req = messages[0] assert isinstance(req, ModelRequest) assert req.instructions is None def test_override_instructions_callable_replaces_functions(): """Override with a callable should replace existing instruction functions.""" agent = Agent('test') @agent.instructions def base_fn() -> str: return 'BASE_FN' def override_fn() -> str: return 'OVERRIDE_FN' with capture_run_messages() as base_messages: agent.run_sync('Hello', model=TestModel(custom_output_text='baseline')) base_req = base_messages[0] assert isinstance(base_req, ModelRequest) assert base_req.instructions is not None assert 'BASE_FN' in base_req.instructions with agent.override(instructions=override_fn): with capture_run_messages() as messages: agent.run_sync('Hello', model=TestModel(custom_output_text='ok')) req = messages[0] assert isinstance(req, ModelRequest) assert req.instructions == 'OVERRIDE_FN' assert 'BASE_FN' not in req.instructions async def test_override_instructions_async_callable(): """Override with an async callable should be awaited.""" agent = Agent('test') async def override_fn() -> str: await asyncio.sleep(0) return 'ASYNC_FN' with agent.override(instructions=override_fn): with capture_run_messages() as messages: await agent.run('Hi', model=TestModel(custom_output_text='ok')) req = messages[0] assert isinstance(req, ModelRequest) assert req.instructions == 'ASYNC_FN' def test_override_instructions_sequence_mixed_types(): """Override can mix literal strings and functions.""" agent = Agent('test', instructions='BASE') def override_fn() -> str: return 'FUNC_PART' def override_fn_2() -> str: return 'FUNC_PART_2' with agent.override(instructions=['OVERRIDE1', override_fn, 'OVERRIDE2', override_fn_2]): with capture_run_messages() as messages: agent.run_sync('Hello', model=TestModel(custom_output_text='ok')) req = messages[0] assert isinstance(req, ModelRequest) assert req.instructions == 'OVERRIDE1\nOVERRIDE2\n\nFUNC_PART\n\nFUNC_PART_2' assert 'BASE' not in req.instructions async def test_override_concurrent_isolation(): """Test that concurrent overrides are isolated from each other.""" agent = Agent('test', instructions='ORIG') async def run_with(instr: str) -> str | None: with agent.override(instructions=instr): with capture_run_messages() as messages: await agent.run('Hi', model=TestModel(custom_output_text='ok')) req = messages[0] assert isinstance(req, ModelRequest) return req.instructions a, b = await asyncio.gather( run_with('A'), run_with('B'), ) assert a == 'A' assert b == 'B' def test_override_replaces_instructions(): """Test overriding instructions replaces the base instructions.""" agent = Agent('test', instructions='ORIG_INSTR') with agent.override(instructions='NEW_INSTR'): with capture_run_messages() as messages: agent.run_sync('Hi', model=TestModel(custom_output_text='ok')) req = messages[0] assert isinstance(req, ModelRequest) assert req.instructions == 'NEW_INSTR' def test_override_nested_contexts(): """Test nested override contexts.""" agent = Agent('test', instructions='ORIG') with agent.override(instructions='OUTER'): with capture_run_messages() 
as outer_messages: agent.run_sync('Hi', model=TestModel(custom_output_text='ok')) with agent.override(instructions='INNER'): with capture_run_messages() as inner_messages: agent.run_sync('Hi', model=TestModel(custom_output_text='ok')) outer_req = outer_messages[0] assert isinstance(outer_req, ModelRequest) inner_req = inner_messages[0] assert isinstance(inner_req, ModelRequest) assert outer_req.instructions == 'OUTER' assert inner_req.instructions == 'INNER' async def test_override_async_run(): """Test override with async run method.""" agent = Agent('test', instructions='ORIG') with agent.override(instructions='ASYNC_OVERRIDE'): with capture_run_messages() as messages: await agent.run('Hi', model=TestModel(custom_output_text='ok')) req = messages[0] assert isinstance(req, ModelRequest) assert req.instructions == 'ASYNC_OVERRIDE' def test_override_with_dynamic_prompts(): """Test override interacting with dynamic prompts.""" agent = Agent('test') dynamic_value = 'DYNAMIC' @agent.system_prompt def dynamic_sys() -> str: return dynamic_value @agent.instructions def dynamic_instr() -> str: return 'DYNAMIC_INSTR' with capture_run_messages() as base_messages: agent.run_sync('Hi', model=TestModel(custom_output_text='baseline')) base_req = base_messages[0] assert isinstance(base_req, ModelRequest) assert base_req.instructions == 'DYNAMIC_INSTR' # Override should take precedence over dynamic instructions but leave system prompts intact with agent.override(instructions='OVERRIDE_INSTR'): with capture_run_messages() as messages: agent.run_sync('Hi', model=TestModel(custom_output_text='ok')) req = messages[0] assert isinstance(req, ModelRequest) assert req.instructions == 'OVERRIDE_INSTR' sys_texts = [p.content for p in req.parts if isinstance(p, SystemPromptPart)] # The dynamic system prompt should still be present since overrides target instructions only assert dynamic_value in sys_texts def test_continue_conversation_that_ended_in_output_tool_call(allow_model_requests: None): def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: if any(isinstance(p, ToolReturnPart) and p.tool_name == 'roll_dice' for p in messages[-1].parts): return ModelResponse( parts=[ ToolCallPart( tool_name='final_result', args={'dice_roll': 4}, tool_call_id='pyd_ai_tool_call_id__final_result', ) ] ) return ModelResponse( parts=[ToolCallPart(tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id__roll_dice')] ) class Result(BaseModel): dice_roll: int agent = Agent(FunctionModel(llm), output_type=Result) @agent.tool_plain def roll_dice() -> int: return 4 result = agent.run_sync('Roll me a dice.') messages = result.all_messages() assert messages == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='Roll me a dice.', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ToolCallPart(tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id__roll_dice')], usage=RequestUsage(input_tokens=55, output_tokens=2), model_name='function:llm:', timestamp=IsDatetime(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='roll_dice', content=4, tool_call_id='pyd_ai_tool_call_id__roll_dice', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ ToolCallPart( tool_name='final_result', args={'dice_roll': 4}, tool_call_id='pyd_ai_tool_call_id__final_result', ) ], usage=RequestUsage(input_tokens=56, output_tokens=6), model_name='function:llm:', timestamp=IsDatetime(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='final_result', content='Final result processed.', 
tool_call_id='pyd_ai_tool_call_id__final_result', timestamp=IsDatetime(), ) ] ), ] ) result = agent.run_sync('Roll me a dice again.', message_history=messages) new_messages = result.new_messages() assert new_messages == snapshot( [ ModelRequest( parts=[ UserPromptPart( content='Roll me a dice again.', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ToolCallPart(tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id__roll_dice')], usage=RequestUsage(input_tokens=66, output_tokens=8), model_name='function:llm:', timestamp=IsDatetime(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='roll_dice', content=4, tool_call_id='pyd_ai_tool_call_id__roll_dice', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ ToolCallPart( tool_name='final_result', args={'dice_roll': 4}, tool_call_id='pyd_ai_tool_call_id__final_result', ) ], usage=RequestUsage(input_tokens=67, output_tokens=12), model_name='function:llm:', timestamp=IsDatetime(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='final_result', content='Final result processed.', tool_call_id='pyd_ai_tool_call_id__final_result', timestamp=IsDatetime(), ) ] ), ] ) assert not any(isinstance(p, ToolReturnPart) and p.tool_name == 'final_result' for p in new_messages[0].parts) def test_agent_builtin_tools_runtime_parameter(): """Test that Agent.run_sync accepts builtin_tools parameter.""" model = TestModel() agent = Agent(model=model, builtin_tools=[]) # Should work with empty builtin_tools result = agent.run_sync('Hello', builtin_tools=[]) assert result.output == 'success (no tool calls)' assert model.last_model_request_parameters is not None assert model.last_model_request_parameters.builtin_tools == [] async def test_agent_builtin_tools_runtime_parameter_async(): """Test that Agent.run and Agent.run_stream accept builtin_tools parameter.""" model = TestModel() agent = Agent(model=model, builtin_tools=[]) # Test async run result = await agent.run('Hello', builtin_tools=[]) assert result.output == 'success (no tool calls)' assert model.last_model_request_parameters is not None assert model.last_model_request_parameters.builtin_tools == [] # Test run_stream async with agent.run_stream('Hello', builtin_tools=[]) as stream: output = await stream.get_output() assert output == 'success (no tool calls)' assert model.last_model_request_parameters is not None assert model.last_model_request_parameters.builtin_tools == [] def test_agent_builtin_tools_testmodel_rejection(): """Test that TestModel rejects builtin tools as expected.""" model = TestModel() agent = Agent(model=model, builtin_tools=[]) # Should raise error when builtin_tools contains actual tools web_search_tool = WebSearchTool() with pytest.raises(Exception, match='TestModel does not support built-in tools'): agent.run_sync('Hello', builtin_tools=[web_search_tool]) assert model.last_model_request_parameters is not None assert len(model.last_model_request_parameters.builtin_tools) == 1 assert model.last_model_request_parameters.builtin_tools[0] == web_search_tool def test_agent_builtin_tools_runtime_vs_agent_level(): """Test that runtime builtin_tools parameter is merged with agent-level builtin_tools.""" model = TestModel() web_search_tool = WebSearchTool() # Agent has builtin tools, and we provide same type at runtime agent = Agent(model=model, builtin_tools=[web_search_tool]) # Runtime tool of same type should override agent-level tool different_web_search = WebSearchTool(search_context_size='high') with pytest.raises(Exception, match='TestModel does not support built-in 
tools'): agent.run_sync('Hello', builtin_tools=[different_web_search]) assert model.last_model_request_parameters is not None assert len(model.last_model_request_parameters.builtin_tools) == 1 runtime_tool = model.last_model_request_parameters.builtin_tools[0] assert isinstance(runtime_tool, WebSearchTool) assert runtime_tool.search_context_size == 'high' def test_agent_builtin_tools_runtime_additional(): """Test that runtime builtin_tools can add to agent-level builtin_tools when different types.""" model = TestModel() web_search_tool = WebSearchTool() agent = Agent(model=model, builtin_tools=[]) with pytest.raises(Exception, match='TestModel does not support built-in tools'): agent.run_sync('Hello', builtin_tools=[web_search_tool]) assert model.last_model_request_parameters is not None assert len(model.last_model_request_parameters.builtin_tools) == 1 assert model.last_model_request_parameters.builtin_tools[0] == web_search_tool async def test_agent_builtin_tools_run_stream_events(): """Test that Agent.run_stream_events accepts builtin_tools parameter.""" model = TestModel() agent = Agent(model=model, builtin_tools=[]) # Test with empty builtin_tools events = [event async for event in agent.run_stream_events('Hello', builtin_tools=[])] assert len(events) >= 2 assert isinstance(events[-1], AgentRunResultEvent) assert events[-1].result.output == 'success (no tool calls)' assert model.last_model_request_parameters is not None assert model.last_model_request_parameters.builtin_tools == [] # Test with builtin tool web_search_tool = WebSearchTool() with pytest.raises(Exception, match='TestModel does not support built-in tools'): events = [event async for event in agent.run_stream_events('Hello', builtin_tools=[web_search_tool])] assert model.last_model_request_parameters is not None assert len(model.last_model_request_parameters.builtin_tools) == 1 assert model.last_model_request_parameters.builtin_tools[0] == web_search_tool

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/pydantic/pydantic-ai'

If you have feedback or need assistance with the MCP directory API, please join our Discord server