Skip to main content
Glama

mcp-run-python

Official
by pydantic
_agent_graph.py (56 kB)
from __future__ import annotations as _annotations import asyncio import dataclasses import inspect from asyncio import Task from collections import defaultdict, deque from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar from copy import deepcopy from dataclasses import field, replace from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast from opentelemetry.trace import Tracer from typing_extensions import TypeVar, assert_never from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION from pydantic_ai._tool_manager import ToolManager from pydantic_ai._utils import dataclasses_no_defaults_repr, get_union_args, is_async_callable, run_in_executor from pydantic_ai.builtin_tools import AbstractBuiltinTool from pydantic_graph import BaseNode, Graph, GraphRunContext from pydantic_graph.nodes import End, NodeRunEndT from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage from .exceptions import ToolRetryError from .output import OutputDataT, OutputSpec from .settings import ModelSettings from .tools import ( DeferredToolCallResult, DeferredToolResult, DeferredToolResults, RunContext, ToolApproved, ToolDefinition, ToolDenied, ToolKind, ) if TYPE_CHECKING: from .models.instrumented import InstrumentationSettings __all__ = ( 'GraphAgentState', 'GraphAgentDeps', 'UserPromptNode', 'ModelRequestNode', 'CallToolsNode', 'build_run_context', 'capture_run_messages', 'HistoryProcessor', ) T = TypeVar('T') S = TypeVar('S') NoneType = type(None) EndStrategy = Literal['early', 'exhaustive'] """The strategy for handling multiple tool calls when a final result is found. 
- `'early'`: Stop processing other tool calls once a final result is found - `'exhaustive'`: Process all tool calls even after finding a final result """ DepsT = TypeVar('DepsT') OutputT = TypeVar('OutputT') _HistoryProcessorSync = Callable[[list[_messages.ModelMessage]], list[_messages.ModelMessage]] _HistoryProcessorAsync = Callable[[list[_messages.ModelMessage]], Awaitable[list[_messages.ModelMessage]]] _HistoryProcessorSyncWithCtx = Callable[[RunContext[DepsT], list[_messages.ModelMessage]], list[_messages.ModelMessage]] _HistoryProcessorAsyncWithCtx = Callable[ [RunContext[DepsT], list[_messages.ModelMessage]], Awaitable[list[_messages.ModelMessage]] ] HistoryProcessor = ( _HistoryProcessorSync | _HistoryProcessorAsync | _HistoryProcessorSyncWithCtx[DepsT] | _HistoryProcessorAsyncWithCtx[DepsT] ) """A function that processes a list of model messages and returns a list of model messages. Can optionally accept a `RunContext` as a parameter. """ @dataclasses.dataclass(kw_only=True) class GraphAgentState: """State kept across the execution of the agent graph.""" message_history: list[_messages.ModelMessage] = dataclasses.field(default_factory=list) usage: _usage.RunUsage = dataclasses.field(default_factory=_usage.RunUsage) retries: int = 0 run_step: int = 0 def increment_retries( self, max_result_retries: int, error: BaseException | None = None, model_settings: ModelSettings | None = None, ) -> None: self.retries += 1 if self.retries > max_result_retries: if ( self.message_history and isinstance(model_response := self.message_history[-1], _messages.ModelResponse) and model_response.finish_reason == 'length' and model_response.parts and isinstance(tool_call := model_response.parts[-1], _messages.ToolCallPart) ): try: tool_call.args_as_dict() except Exception: max_tokens = (model_settings or {}).get('max_tokens') if model_settings else None raise exceptions.IncompleteToolCall( f'Model token limit ({max_tokens if max_tokens is not None else "provider default"}) 
exceeded while emitting a tool call, resulting in incomplete arguments. Increase max tokens or simplify tool call arguments to fit within limit.' ) message = f'Exceeded maximum retries ({max_result_retries}) for output validation' if error: if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None: error = error.__cause__ raise exceptions.UnexpectedModelBehavior(message) from error else: raise exceptions.UnexpectedModelBehavior(message) @dataclasses.dataclass(kw_only=True) class GraphAgentDeps(Generic[DepsT, OutputDataT]): """Dependencies/config passed to the agent graph.""" user_deps: DepsT prompt: str | Sequence[_messages.UserContent] | None new_message_index: int model: models.Model model_settings: ModelSettings | None usage_limits: _usage.UsageLimits max_result_retries: int end_strategy: EndStrategy get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]] output_schema: _output.OutputSchema[OutputDataT] output_validators: list[_output.OutputValidator[DepsT, OutputDataT]] history_processors: Sequence[HistoryProcessor[DepsT]] builtin_tools: list[AbstractBuiltinTool] = dataclasses.field(repr=False) tool_manager: ToolManager[DepsT] tracer: Tracer instrumentation_settings: InstrumentationSettings | None class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]): """The base class for all agent nodes. Using subclass of `BaseNode` for all nodes reduces the amount of boilerplate of generics everywhere """ def is_agent_node( node: BaseNode[GraphAgentState, GraphAgentDeps[T, Any], result.FinalResult[S]] | End[result.FinalResult[S]], ) -> TypeGuard[AgentNode[T, S]]: """Check if the provided node is an instance of `AgentNode`. Usage: if is_agent_node(node): # `node` is an AgentNode ... This method preserves the generic parameters on the narrowed type, unlike `isinstance(node, AgentNode)`. 
""" return isinstance(node, AgentNode) @dataclasses.dataclass class UserPromptNode(AgentNode[DepsT, NodeRunEndT]): """The node that handles the user prompt and instructions.""" user_prompt: str | Sequence[_messages.UserContent] | None _: dataclasses.KW_ONLY deferred_tool_results: DeferredToolResults | None = None instructions: str | None = None instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(default_factory=list) system_prompts: tuple[str, ...] = dataclasses.field(default_factory=tuple) system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(default_factory=list) system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field( default_factory=dict ) async def run( # noqa: C901 self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> ModelRequestNode[DepsT, NodeRunEndT] | CallToolsNode[DepsT, NodeRunEndT]: try: ctx_messages = get_captured_run_messages() except LookupError: messages: list[_messages.ModelMessage] = [] else: if ctx_messages.used: messages = [] else: messages = ctx_messages.messages ctx_messages.used = True # Replace the `capture_run_messages` list with the message history messages[:] = _clean_message_history(ctx.state.message_history) # Use the `capture_run_messages` list as the message history so that new messages are added to it ctx.state.message_history = messages ctx.deps.new_message_index = len(messages) if self.deferred_tool_results is not None: return await self._handle_deferred_tool_results(self.deferred_tool_results, messages, ctx) next_message: _messages.ModelRequest | None = None if messages and (last_message := messages[-1]): if isinstance(last_message, _messages.ModelRequest) and self.user_prompt is None: # Drop last message from history and reuse its parts messages.pop() next_message = _messages.ModelRequest(parts=last_message.parts) # Extract `UserPromptPart` content from the popped message 
and add to `ctx.deps.prompt` user_prompt_parts = [part for part in last_message.parts if isinstance(part, _messages.UserPromptPart)] if user_prompt_parts: if len(user_prompt_parts) == 1: ctx.deps.prompt = user_prompt_parts[0].content else: combined_content: list[_messages.UserContent] = [] for part in user_prompt_parts: if isinstance(part.content, str): combined_content.append(part.content) else: combined_content.extend(part.content) ctx.deps.prompt = combined_content elif isinstance(last_message, _messages.ModelResponse): if self.user_prompt is None: # Skip ModelRequestNode and go directly to CallToolsNode return CallToolsNode[DepsT, NodeRunEndT](last_message) elif last_message.tool_calls: raise exceptions.UserError( 'Cannot provide a new user prompt when the message history contains unprocessed tool calls.' ) # Build the run context after `ctx.deps.prompt` has been updated run_context = build_run_context(ctx) if messages: await self._reevaluate_dynamic_prompts(messages, run_context) if next_message: await self._reevaluate_dynamic_prompts([next_message], run_context) else: parts: list[_messages.ModelRequestPart] = [] if not messages: parts.extend(await self._sys_parts(run_context)) if self.user_prompt is not None: parts.append(_messages.UserPromptPart(self.user_prompt)) next_message = _messages.ModelRequest(parts=parts) next_message.instructions = await ctx.deps.get_instructions(run_context) return ModelRequestNode[DepsT, NodeRunEndT](request=next_message) async def _handle_deferred_tool_results( # noqa: C901 self, deferred_tool_results: DeferredToolResults, messages: list[_messages.ModelMessage], ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], ) -> CallToolsNode[DepsT, NodeRunEndT]: if not messages: raise exceptions.UserError('Tool call results were provided, but the message history is empty.') last_model_request: _messages.ModelRequest | None = None last_model_response: _messages.ModelResponse | None = None for message in 
reversed(messages): if isinstance(message, _messages.ModelRequest): last_model_request = message elif isinstance(message, _messages.ModelResponse): # pragma: no branch last_model_response = message break if not last_model_response: raise exceptions.UserError( 'Tool call results were provided, but the message history does not contain a `ModelResponse`.' ) if not last_model_response.tool_calls: raise exceptions.UserError( 'Tool call results were provided, but the message history does not contain any unprocessed tool calls.' ) if self.user_prompt is not None: raise exceptions.UserError( 'Cannot provide a new user prompt when the message history contains unprocessed tool calls.' ) tool_call_results: dict[str, DeferredToolResult | Literal['skip']] | None = None tool_call_results = {} for tool_call_id, approval in deferred_tool_results.approvals.items(): if approval is True: approval = ToolApproved() elif approval is False: approval = ToolDenied() tool_call_results[tool_call_id] = approval if calls := deferred_tool_results.calls: call_result_types = get_union_args(DeferredToolCallResult) for tool_call_id, result in calls.items(): if not isinstance(result, call_result_types): result = _messages.ToolReturn(result) tool_call_results[tool_call_id] = result if last_model_request: for part in last_model_request.parts: if isinstance(part, _messages.ToolReturnPart | _messages.RetryPromptPart): if part.tool_call_id in tool_call_results: raise exceptions.UserError( f'Tool call {part.tool_call_id!r} was already executed and its result cannot be overridden.' 
) tool_call_results[part.tool_call_id] = 'skip' # Skip ModelRequestNode and go directly to CallToolsNode return CallToolsNode[DepsT, NodeRunEndT](last_model_response, tool_call_results=tool_call_results) async def _reevaluate_dynamic_prompts( self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT] ) -> None: """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function.""" # Only proceed if there's at least one dynamic runner. if self.system_prompt_dynamic_functions: for msg in messages: if isinstance(msg, _messages.ModelRequest): reevaluated_message_parts: list[_messages.ModelRequestPart] = [] for part in msg.parts: if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref: # Look up the runner by its ref if runner := self.system_prompt_dynamic_functions.get( # pragma: lax no cover part.dynamic_ref ): updated_part_content = await runner.run(run_context) part = _messages.SystemPromptPart(updated_part_content, dynamic_ref=part.dynamic_ref) reevaluated_message_parts.append(part) # Replace message parts with reevaluated ones to prevent mutating parts list if reevaluated_message_parts != msg.parts: msg.parts = reevaluated_message_parts async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]: """Build the initial messages for the conversation.""" messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self.system_prompts] for sys_prompt_runner in self.system_prompt_functions: prompt = await sys_prompt_runner.run(run_context) if sys_prompt_runner.dynamic: messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__)) else: messages.append(_messages.SystemPromptPart(prompt)) return messages __repr__ = dataclasses_no_defaults_repr async def _prepare_request_parameters( ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], ) -> 
models.ModelRequestParameters: """Build tools and create an agent model.""" output_schema = ctx.deps.output_schema output_object = None if isinstance(output_schema, _output.NativeOutputSchema): output_object = output_schema.object_def function_tools: list[ToolDefinition] = [] output_tools: list[ToolDefinition] = [] for tool_def in ctx.deps.tool_manager.tool_defs: if tool_def.kind == 'output': output_tools.append(tool_def) else: function_tools.append(tool_def) return models.ModelRequestParameters( function_tools=function_tools, builtin_tools=ctx.deps.builtin_tools, output_mode=output_schema.mode, output_tools=output_tools, output_object=output_object, allow_text_output=output_schema.allows_text, allow_image_output=output_schema.allows_image, ) @dataclasses.dataclass class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]): """The node that makes a request to the model using the last message in state.message_history.""" request: _messages.ModelRequest _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None) _did_stream: bool = field(repr=False, init=False, default=False) async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> CallToolsNode[DepsT, NodeRunEndT]: if self._result is not None: return self._result if self._did_stream: # `self._result` gets set when exiting the `stream` contextmanager, so hitting this # means that the stream was started but not finished before `run()` was called raise exceptions.AgentRunError('You must finish streaming before calling run()') # pragma: no cover return await self._make_request(ctx) @asynccontextmanager async def stream( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]], ) -> AsyncIterator[result.AgentStream[DepsT, T]]: assert not self._did_stream, 'stream() should only be called once per node' model_settings, model_request_parameters, message_history, run_context = await self._prepare_request(ctx) async with 
ctx.deps.model.request_stream( message_history, model_settings, model_request_parameters, run_context ) as streamed_response: self._did_stream = True ctx.state.usage.requests += 1 agent_stream = result.AgentStream[DepsT, T]( _raw_stream_response=streamed_response, _output_schema=ctx.deps.output_schema, _model_request_parameters=model_request_parameters, _output_validators=ctx.deps.output_validators, _run_ctx=build_run_context(ctx), _usage_limits=ctx.deps.usage_limits, _tool_manager=ctx.deps.tool_manager, ) yield agent_stream # In case the user didn't manually consume the full stream, ensure it is fully consumed here, # otherwise usage won't be properly counted: async for _ in agent_stream: pass model_response = streamed_response.get() self._finish_handling(ctx, model_response) assert self._result is not None # this should be set by the previous line async def _make_request( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> CallToolsNode[DepsT, NodeRunEndT]: if self._result is not None: return self._result # pragma: no cover model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx) model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters) ctx.state.usage.requests += 1 return self._finish_handling(ctx, model_response) async def _prepare_request( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> tuple[ModelSettings | None, models.ModelRequestParameters, list[_messages.ModelMessage], RunContext[DepsT]]: ctx.state.message_history.append(self.request) ctx.state.run_step += 1 run_context = build_run_context(ctx) # This will raise errors for any tool name conflicts ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context) original_history = ctx.state.message_history[:] message_history = await _process_message_history(original_history, ctx.deps.history_processors, run_context) # 
`ctx.state.message_history` is the same list used by `capture_run_messages`, so we should replace its contents, not the reference ctx.state.message_history[:] = message_history # Update the new message index to ensure `result.new_messages()` returns the correct messages ctx.deps.new_message_index -= len(original_history) - len(message_history) # Merge possible consecutive trailing `ModelRequest`s into one, with tool call parts before user parts, # but don't store it in the message history on state. This is just for the benefit of model classes that want clear user/assistant boundaries. # See `tests/test_tools.py::test_parallel_tool_return_with_deferred` for an example where this is necessary message_history = _clean_message_history(message_history) model_request_parameters = await _prepare_request_parameters(ctx) model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters) model_settings = ctx.deps.model_settings usage = ctx.state.usage if ctx.deps.usage_limits.count_tokens_before_request: # Copy to avoid modifying the original usage object with the counted usage usage = deepcopy(usage) counted_usage = await ctx.deps.model.count_tokens(message_history, model_settings, model_request_parameters) usage.incr(counted_usage) ctx.deps.usage_limits.check_before_request(usage) return model_settings, model_request_parameters, message_history, run_context def _finish_handling( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], response: _messages.ModelResponse, ) -> CallToolsNode[DepsT, NodeRunEndT]: # Update usage ctx.state.usage.incr(response.usage) if ctx.deps.usage_limits: # pragma: no branch ctx.deps.usage_limits.check_tokens(ctx.state.usage) # Append the model response to state.message_history ctx.state.message_history.append(response) # Set the `_result` attribute since we can't use `return` in an async iterator self._result = CallToolsNode(response) return self._result __repr__ = 
dataclasses_no_defaults_repr @dataclasses.dataclass class CallToolsNode(AgentNode[DepsT, NodeRunEndT]): """The node that processes a model response, and decides whether to end the run or make a new request.""" model_response: _messages.ModelResponse tool_call_results: dict[str, DeferredToolResult | Literal['skip']] | None = None _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, init=False, repr=False) _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field( default=None, init=False, repr=False ) async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]: async with self.stream(ctx): pass assert self._next_node is not None, 'the stream should set `self._next_node` before it ends' return self._next_node @asynccontextmanager async def stream( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]: """Process the model response and yield events for the start and end of each function tool call.""" stream = self._run_stream(ctx) yield stream # Run the stream to completion if it was not finished: async for _event in stream: pass async def _run_stream( # noqa: C901 self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> AsyncIterator[_messages.HandleResponseEvent]: if self._events_iterator is None: # Ensure that the stream is only run once output_schema = ctx.deps.output_schema async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa: C901 if not self.model_response.parts: # we got an empty response. 
# this sometimes happens with anthropic (and perhaps other models) # when the model has already returned text along side tool calls if text_processor := output_schema.text_processor: # in this scenario, if text responses are allowed, we return text from the most recent model # response, if any for message in reversed(ctx.state.message_history): if isinstance(message, _messages.ModelResponse): text = '' for part in message.parts: if isinstance(part, _messages.TextPart): text += part.content elif isinstance(part, _messages.BuiltinToolCallPart): # Text parts before a built-in tool call are essentially thoughts, # not part of the final result output, so we reset the accumulated text text = '' # pragma: no cover if text: self._next_node = await self._handle_text_response(ctx, text, text_processor) return # Go back to the model request node with an empty request, which means we'll essentially # resubmit the most recent request that resulted in an empty response, # as the empty response and request will not create any items in the API payload, # in the hope the model will return a non-empty response this time. 
ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings) self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[])) return text = '' tool_calls: list[_messages.ToolCallPart] = [] files: list[_messages.BinaryContent] = [] for part in self.model_response.parts: if isinstance(part, _messages.TextPart): text += part.content elif isinstance(part, _messages.ToolCallPart): tool_calls.append(part) elif isinstance(part, _messages.FilePart): files.append(part.content) elif isinstance(part, _messages.BuiltinToolCallPart): # Text parts before a built-in tool call are essentially thoughts, # not part of the final result output, so we reset the accumulated text text = '' yield _messages.BuiltinToolCallEvent(part) # pyright: ignore[reportDeprecated] elif isinstance(part, _messages.BuiltinToolReturnPart): yield _messages.BuiltinToolResultEvent(part) # pyright: ignore[reportDeprecated] elif isinstance(part, _messages.ThinkingPart): pass else: assert_never(part) # At the moment, we prioritize at least executing tool calls if they are present. # In the future, we'd consider making this configurable at the agent or run level. # This accounts for cases like anthropic returns that might contain a text response # and a tool call response, where the text response just indicates the tool call will happen. 
try: alternatives: list[str] = [] if tool_calls: async for event in self._handle_tool_calls(ctx, tool_calls): yield event return elif output_schema.toolset: alternatives.append('include your response in a tool call') else: alternatives.append('call a tool') if output_schema.allows_image: if image := next((file for file in files if isinstance(file, _messages.BinaryImage)), None): self._next_node = await self._handle_image_response(ctx, image) return alternatives.append('return an image') if text_processor := output_schema.text_processor: if text: self._next_node = await self._handle_text_response(ctx, text, text_processor) return alternatives.insert(0, 'return text') # handle responses with only parts that don't constitute output. # This can happen with models that support thinking mode when they don't provide # actionable output alongside their thinking content. so we tell the model to try again. m = _messages.RetryPromptPart( content=f'Please {" or ".join(alternatives)}.', ) raise ToolRetryError(m) except ToolRetryError as e: ctx.state.increment_retries( ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings ) self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) self._events_iterator = _run_stream() async for event in self._events_iterator: yield event async def _handle_tool_calls( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], tool_calls: list[_messages.ToolCallPart], ) -> AsyncIterator[_messages.HandleResponseEvent]: run_context = build_run_context(ctx) # This will raise errors for any tool name conflicts ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context) output_parts: list[_messages.ModelRequestPart] = [] output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1) async for event in process_tool_calls( tool_manager=ctx.deps.tool_manager, tool_calls=tool_calls, tool_call_results=self.tool_call_results, 
final_result=None, ctx=ctx, output_parts=output_parts, output_final_result=output_final_result, ): yield event if output_final_result: final_result = output_final_result[0] self._next_node = self._handle_final_result(ctx, final_result, output_parts) else: instructions = await ctx.deps.get_instructions(run_context) self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( _messages.ModelRequest(parts=output_parts, instructions=instructions) ) async def _handle_text_response( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], text: str, text_processor: _output.BaseOutputProcessor[NodeRunEndT], ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]: run_context = build_run_context(ctx) result_data = await text_processor.process(text, run_context) for validator in ctx.deps.output_validators: result_data = await validator.validate(result_data, run_context) return self._handle_final_result(ctx, result.FinalResult(result_data), []) async def _handle_image_response( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], image: _messages.BinaryImage, ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]: result_data = cast(NodeRunEndT, image) return self._handle_final_result(ctx, result.FinalResult(result_data), []) def _handle_final_result( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], final_result: result.FinalResult[NodeRunEndT], tool_responses: list[_messages.ModelRequestPart], ) -> End[result.FinalResult[NodeRunEndT]]: messages = ctx.state.message_history # For backwards compatibility, append a new ModelRequest using the tool returns and retries if tool_responses: messages.append(_messages.ModelRequest(parts=tool_responses)) return End(final_result) __repr__ = dataclasses_no_defaults_repr @dataclasses.dataclass class SetFinalResult(AgentNode[DepsT, NodeRunEndT]): """A node that immediately ends the graph run after a streaming 
response produced a final result.""" final_result: result.FinalResult[NodeRunEndT] async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> End[result.FinalResult[NodeRunEndT]]: return End(self.final_result) def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]: """Build a `RunContext` object from the current agent graph run context.""" return RunContext[DepsT]( deps=ctx.deps.user_deps, model=ctx.deps.model, usage=ctx.state.usage, prompt=ctx.deps.prompt, messages=ctx.state.message_history, tracer=ctx.deps.tracer, trace_include_content=ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content, instrumentation_version=ctx.deps.instrumentation_settings.version if ctx.deps.instrumentation_settings else DEFAULT_INSTRUMENTATION_VERSION, run_step=ctx.state.run_step, tool_call_approved=ctx.state.run_step == 0, ) async def process_tool_calls( # noqa: C901 tool_manager: ToolManager[DepsT], tool_calls: list[_messages.ToolCallPart], tool_call_results: dict[str, DeferredToolResult | Literal['skip']] | None, final_result: result.FinalResult[NodeRunEndT] | None, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], output_parts: list[_messages.ModelRequestPart], output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1), ) -> AsyncIterator[_messages.HandleResponseEvent]: """Process function (i.e., non-result) tool calls in parallel. Also add stub return parts for any other tools that need it. Because async iterators can't have return values, we use `output_parts` and `output_final_result` as output arguments. 
""" tool_calls_by_kind: dict[ToolKind | Literal['unknown'], list[_messages.ToolCallPart]] = defaultdict(list) for call in tool_calls: tool_def = tool_manager.get_tool_def(call.tool_name) if tool_def: kind = tool_def.kind else: kind = 'unknown' tool_calls_by_kind[kind].append(call) # First, we handle output tool calls for call in tool_calls_by_kind['output']: if final_result: if final_result.tool_call_id == call.tool_call_id: part = _messages.ToolReturnPart( tool_name=call.tool_name, content='Final result processed.', tool_call_id=call.tool_call_id, ) else: yield _messages.FunctionToolCallEvent(call) part = _messages.ToolReturnPart( tool_name=call.tool_name, content='Output tool not used - a final result was already processed.', tool_call_id=call.tool_call_id, ) yield _messages.FunctionToolResultEvent(part) output_parts.append(part) else: try: result_data = await tool_manager.handle_call(call) except exceptions.UnexpectedModelBehavior as e: ctx.state.increment_retries( ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings ) raise e # pragma: lax no cover except ToolRetryError as e: ctx.state.increment_retries( ctx.deps.max_result_retries, error=e, model_settings=ctx.deps.model_settings ) yield _messages.FunctionToolCallEvent(call) output_parts.append(e.tool_retry) yield _messages.FunctionToolResultEvent(e.tool_retry) else: part = _messages.ToolReturnPart( tool_name=call.tool_name, content='Final result processed.', tool_call_id=call.tool_call_id, ) output_parts.append(part) final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) # Then, we handle function tool calls calls_to_run: list[_messages.ToolCallPart] = [] if final_result and ctx.deps.end_strategy == 'early': for call in tool_calls_by_kind['function']: output_parts.append( _messages.ToolReturnPart( tool_name=call.tool_name, content='Tool not executed - a final result was already processed.', tool_call_id=call.tool_call_id, ) ) else: 
calls_to_run.extend(tool_calls_by_kind['function']) # Then, we handle unknown tool calls if tool_calls_by_kind['unknown']: ctx.state.increment_retries(ctx.deps.max_result_retries, model_settings=ctx.deps.model_settings) calls_to_run.extend(tool_calls_by_kind['unknown']) calls_to_run_results: dict[str, DeferredToolResult] = {} if tool_call_results is not None: # Deferred tool calls are "run" as well, by reading their value from the tool call results calls_to_run.extend(tool_calls_by_kind['external']) calls_to_run.extend(tool_calls_by_kind['unapproved']) result_tool_call_ids = set(tool_call_results.keys()) tool_call_ids_to_run = {call.tool_call_id for call in calls_to_run} if tool_call_ids_to_run != result_tool_call_ids: raise exceptions.UserError( 'Tool call results need to be provided for all deferred tool calls. ' f'Expected: {tool_call_ids_to_run}, got: {result_tool_call_ids}' ) # Filter out calls that were already executed before and should now be skipped calls_to_run_results = {call_id: result for call_id, result in tool_call_results.items() if result != 'skip'} calls_to_run = [call for call in calls_to_run if call.tool_call_id in calls_to_run_results] deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list) if calls_to_run: async for event in _call_tools( tool_manager=tool_manager, tool_calls=calls_to_run, tool_call_results=calls_to_run_results, tracer=ctx.deps.tracer, usage=ctx.state.usage, usage_limits=ctx.deps.usage_limits, output_parts=output_parts, output_deferred_calls=deferred_calls, ): yield event # Finally, we handle deferred tool calls (unless they were already included in the run because results were provided) if tool_call_results is None: calls = [*tool_calls_by_kind['external'], *tool_calls_by_kind['unapproved']] if final_result: # If the run was already determined to end on deferred tool calls, # we shouldn't insert return parts as the deferred tools will still get a real result. 
if not isinstance(final_result.output, _output.DeferredToolRequests): for call in calls: output_parts.append( _messages.ToolReturnPart( tool_name=call.tool_name, content='Tool not executed - a final result was already processed.', tool_call_id=call.tool_call_id, ) ) elif calls: deferred_calls['external'].extend(tool_calls_by_kind['external']) deferred_calls['unapproved'].extend(tool_calls_by_kind['unapproved']) for call in calls: yield _messages.FunctionToolCallEvent(call) if not final_result and deferred_calls: if not ctx.deps.output_schema.allows_deferred_tools: raise exceptions.UserError( 'A deferred tool call was present, but `DeferredToolRequests` is not among output types. To resolve this, add `DeferredToolRequests` to the list of output types for this agent.' ) deferred_tool_requests = _output.DeferredToolRequests( calls=deferred_calls['external'], approvals=deferred_calls['unapproved'], ) final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_requests), None, None) if final_result: output_final_result.append(final_result) async def _call_tools( tool_manager: ToolManager[DepsT], tool_calls: list[_messages.ToolCallPart], tool_call_results: dict[str, DeferredToolResult], tracer: Tracer, usage: _usage.RunUsage, usage_limits: _usage.UsageLimits, output_parts: list[_messages.ModelRequestPart], output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]], ) -> AsyncIterator[_messages.HandleResponseEvent]: tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {} user_parts_by_index: dict[int, _messages.UserPromptPart] = {} deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {} if usage_limits.tool_calls_limit is not None: projected_usage = deepcopy(usage) projected_usage.tool_calls += len(tool_calls) usage_limits.check_before_tool_call(projected_usage) for call in tool_calls: yield _messages.FunctionToolCallEvent(call) with tracer.start_as_current_span( 'running tools', attributes={ 'tools': 
[call.tool_name for call in tool_calls], 'logfire.msg': f'running {len(tool_calls)} tool{"" if len(tool_calls) == 1 else "s"}', }, ): async def handle_call_or_result( coro_or_task: Awaitable[ tuple[ _messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None ] ] | Task[ tuple[ _messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None ] ], index: int, ) -> _messages.HandleResponseEvent | None: try: tool_part, tool_user_content = ( (await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result() ) except exceptions.CallDeferred: deferred_calls_by_index[index] = 'external' except exceptions.ApprovalRequired: deferred_calls_by_index[index] = 'unapproved' else: tool_parts_by_index[index] = tool_part if tool_user_content: user_parts_by_index[index] = _messages.UserPromptPart(content=tool_user_content) return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content) if tool_manager.should_call_sequentially(tool_calls): for index, call in enumerate(tool_calls): if event := await handle_call_or_result( _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)), index, ): yield event else: tasks = [ asyncio.create_task( _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)), name=call.tool_name, ) for call in tool_calls ] pending = tasks while pending: done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) for task in done: index = tasks.index(task) if event := await handle_call_or_result(coro_or_task=task, index=index): yield event # We append the results at the end, rather than as they are received, to retain a consistent ordering # This is mostly just to simplify testing output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)]) output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)]) for k in sorted(deferred_calls_by_index): 
output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k]) async def _call_tool( tool_manager: ToolManager[DepsT], tool_call: _messages.ToolCallPart, tool_call_result: DeferredToolResult | None, ) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None]: try: if tool_call_result is None: tool_result = await tool_manager.handle_call(tool_call) elif isinstance(tool_call_result, ToolApproved): if tool_call_result.override_args is not None: tool_call = dataclasses.replace(tool_call, args=tool_call_result.override_args) tool_result = await tool_manager.handle_call(tool_call) elif isinstance(tool_call_result, ToolDenied): return _messages.ToolReturnPart( tool_name=tool_call.tool_name, content=tool_call_result.message, tool_call_id=tool_call.tool_call_id, ), None elif isinstance(tool_call_result, exceptions.ModelRetry): m = _messages.RetryPromptPart( content=tool_call_result.message, tool_name=tool_call.tool_name, tool_call_id=tool_call.tool_call_id, ) raise ToolRetryError(m) elif isinstance(tool_call_result, _messages.RetryPromptPart): tool_call_result.tool_name = tool_call.tool_name tool_call_result.tool_call_id = tool_call.tool_call_id raise ToolRetryError(tool_call_result) else: tool_result = tool_call_result except ToolRetryError as e: return e.tool_retry, None if isinstance(tool_result, _messages.ToolReturn): tool_return = tool_result else: result_is_list = isinstance(tool_result, list) contents = cast(list[Any], tool_result) if result_is_list else [tool_result] return_values: list[Any] = [] user_contents: list[str | _messages.UserContent] = [] for content in contents: if isinstance(content, _messages.ToolReturn): raise exceptions.UserError( f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. ' f'`ToolReturn` should be used directly.' 
) elif isinstance(content, _messages.MultiModalContent): identifier = content.identifier return_values.append(f'See file {identifier}') user_contents.extend([f'This is file {identifier}:', content]) else: return_values.append(content) tool_return = _messages.ToolReturn( return_value=return_values[0] if len(return_values) == 1 and not result_is_list else return_values, content=user_contents, ) if ( isinstance(tool_return.return_value, _messages.MultiModalContent) or isinstance(tool_return.return_value, list) and any( isinstance(content, _messages.MultiModalContent) for content in tool_return.return_value # type: ignore ) ): raise exceptions.UserError( f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContent` objects. ' f'Please use `content` instead.' ) return_part = _messages.ToolReturnPart( tool_name=tool_call.tool_name, tool_call_id=tool_call.tool_call_id, content=tool_return.return_value, # type: ignore metadata=tool_return.metadata, ) return return_part, tool_return.content or None @dataclasses.dataclass class _RunMessages: messages: list[_messages.ModelMessage] used: bool = False _messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var') @contextmanager def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]: """Context manager to access the messages used in a [`run`][pydantic_ai.agent.AbstractAgent.run], [`run_sync`][pydantic_ai.agent.AbstractAgent.run_sync], or [`run_stream`][pydantic_ai.agent.AbstractAgent.run_stream] call. Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information. Examples: ```python from pydantic_ai import Agent, capture_run_messages agent = Agent('test') with capture_run_messages() as messages: try: result = agent.run_sync('foobar') except Exception: print(messages) raise ``` !!! 
note If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context, `messages` will represent the messages exchanged during the first call only. """ token = None messages: list[_messages.ModelMessage] = [] # Try to reuse existing message context if available try: messages = _messages_ctx_var.get().messages except LookupError: # No existing context, create a new one token = _messages_ctx_var.set(_RunMessages(messages)) try: yield messages finally: # Clean up context if we created it if token is not None: _messages_ctx_var.reset(token) def get_captured_run_messages() -> _RunMessages: return _messages_ctx_var.get() def build_agent_graph( name: str | None, deps_type: type[DepsT], output_type: OutputSpec[OutputT], ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]: """Build the execution [Graph][pydantic_graph.Graph] for a given agent.""" nodes = ( UserPromptNode[DepsT], ModelRequestNode[DepsT], CallToolsNode[DepsT], SetFinalResult[DepsT], ) graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[OutputT]]( nodes=nodes, name=name or 'Agent', state_type=GraphAgentState, run_end_type=result.FinalResult[OutputT], auto_instrument=False, ) return graph async def _process_message_history( messages: list[_messages.ModelMessage], processors: Sequence[HistoryProcessor[DepsT]], run_context: RunContext[DepsT], ) -> list[_messages.ModelMessage]: """Process message history through a sequence of processors.""" for processor in processors: takes_ctx = is_takes_ctx(processor) if is_async_callable(processor): if takes_ctx: messages = await processor(run_context, messages) else: async_processor = cast(_HistoryProcessorAsync, processor) messages = await async_processor(messages) else: if takes_ctx: sync_processor_with_ctx = cast(_HistoryProcessorSyncWithCtx[DepsT], processor) messages = await run_in_executor(sync_processor_with_ctx, run_context, messages) else: 
sync_processor = cast(_HistoryProcessorSync, processor) messages = await run_in_executor(sync_processor, messages) if len(messages) == 0: raise exceptions.UserError('Processed history cannot be empty.') if not isinstance(messages[-1], _messages.ModelRequest): raise exceptions.UserError('Processed history must end with a `ModelRequest`.') return messages def _clean_message_history(messages: list[_messages.ModelMessage]) -> list[_messages.ModelMessage]: """Clean the message history by merging consecutive messages of the same type.""" clean_messages: list[_messages.ModelMessage] = [] for message in messages: last_message = clean_messages[-1] if len(clean_messages) > 0 else None if isinstance(message, _messages.ModelRequest): if ( last_message and isinstance(last_message, _messages.ModelRequest) # Requests can only be merged if they have the same instructions and ( not last_message.instructions or not message.instructions or last_message.instructions == message.instructions ) ): parts = [*last_message.parts, *message.parts] parts.sort( # Tool return parts always need to be at the start key=lambda x: 0 if isinstance(x, _messages.ToolReturnPart | _messages.RetryPromptPart) else 1 ) merged_message = _messages.ModelRequest( parts=parts, instructions=last_message.instructions or message.instructions, ) clean_messages[-1] = merged_message else: clean_messages.append(message) elif isinstance(message, _messages.ModelResponse): # pragma: no branch if ( last_message and isinstance(last_message, _messages.ModelResponse) # Responses can only be merged if they didn't really come from an API and last_message.provider_response_id is None and last_message.provider_name is None and last_message.model_name is None and message.provider_response_id is None and message.provider_name is None and message.model_name is None ): merged_message = replace(last_message, parts=[*last_message.parts, *message.parts]) clean_messages[-1] = merged_message else: clean_messages.append(message) return 
clean_messages

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/pydantic/pydantic-ai'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.