Skip to main content
Glama

mcp-run-python

Official
by pydantic
test_history_processor.py (31.2 kB)
"""Tests for agent history processors, run against a message-capturing `FunctionModel`."""

from collections.abc import AsyncIterator
from typing import Any

import pytest
from inline_snapshot import snapshot

from pydantic_ai import (
    Agent,
    ModelMessage,
    ModelRequest,
    ModelRequestPart,
    ModelResponse,
    SystemPromptPart,
    TextPart,
    UserPromptPart,
    capture_run_messages,
)
from pydantic_ai.exceptions import UserError
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.tools import RunContext
from pydantic_ai.usage import RequestUsage

from .conftest import IsDatetime

pytestmark = [pytest.mark.anyio]


@pytest.fixture
def received_messages() -> list[ModelMessage]:
    """Provide an empty list that the `function_model` fixture fills with the messages it is called with."""
    return []


@pytest.fixture
def function_model(received_messages: list[ModelMessage]) -> FunctionModel:
    """Build a `FunctionModel` whose callbacks record their input messages in `received_messages`.

    The non-streaming callback answers with the text 'Provider response'; the streaming
    callback yields the single chunk 'hello'.
    """

    # NOTE: the names of these two inner functions appear in the `model_name` values
    # asserted by the tests below, so they must not be renamed.
    def capture_model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # Capture the messages that the provider actually receives
        received_messages[:] = messages
        return ModelResponse(parts=[TextPart(content='Provider response')])

    async def capture_model_stream_function(messages: list[ModelMessage], _info: AgentInfo) -> AsyncIterator[str]:
        received_messages[:] = messages
        yield 'hello'

    return FunctionModel(capture_model_function, stream_function=capture_model_stream_function)


async def test_history_processor_no_op(function_model: FunctionModel, received_messages: list[ModelMessage]):
    """A processor that returns its input unchanged leaves the history and the new prompt intact."""

    def no_op_history_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
        return messages

    prior_messages = [
        ModelRequest(parts=[UserPromptPart(content='Previous question')]),
        ModelResponse(parts=[TextPart(content='Previous answer')]),
    ]
    agent = Agent(function_model, history_processors=[no_op_history_processor])

    with capture_run_messages() as run_messages:
        run_result = await agent.run('New question', message_history=prior_messages)

    assert received_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Previous question', timestamp=IsDatetime())]),
            ModelResponse(parts=[TextPart(content='Previous answer')], timestamp=IsDatetime()),
            ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]),
        ]
    )
    all_messages = run_result.all_messages()
    assert run_messages == all_messages
    assert all_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Previous question', timestamp=IsDatetime())]),
            ModelResponse(parts=[TextPart(content='Previous answer')], timestamp=IsDatetime()),
            ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]),
            ModelResponse(
                parts=[TextPart(content='Provider response')],
                usage=RequestUsage(input_tokens=54, output_tokens=4),
                model_name='function:capture_model_function:capture_model_stream_function',
                timestamp=IsDatetime(),
            ),
        ]
    )
    assert run_result.new_messages() == run_result.all_messages()[-2:]


async def test_history_processor_run_replaces_message_history(
    function_model: FunctionModel, received_messages: list[ModelMessage]
):
    """Test that the processed history replaces the original history in the result of `agent.run`."""

    def process_previous_answers(messages: list[ModelMessage]) -> list[ModelMessage]:
        # Keep the last message (last question) and add a new system prompt
        processed_note = ModelRequest(parts=[SystemPromptPart(content='Processed answer')])
        return [*messages[-1:], processed_note]

    prior_messages = [
        ModelRequest(parts=[UserPromptPart(content='Question 1')]),
        ModelResponse(parts=[TextPart(content='Answer 1')]),
        ModelRequest(parts=[UserPromptPart(content='Question 2')]),
        ModelResponse(parts=[TextPart(content='Answer 2')]),
    ]
    agent = Agent(function_model, history_processors=[process_previous_answers])

    with capture_run_messages() as run_messages:
        run_result = await agent.run('Question 3', message_history=prior_messages)

    assert received_messages == snapshot(
        [
            ModelRequest(
                parts=[
                    UserPromptPart(content='Question 3', timestamp=IsDatetime()),
                    SystemPromptPart(content='Processed answer', timestamp=IsDatetime()),
                ]
            )
        ]
    )
    all_messages = run_result.all_messages()
    assert run_messages == all_messages
    assert all_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Question 3', timestamp=IsDatetime())]),
            ModelRequest(parts=[SystemPromptPart(content='Processed answer', timestamp=IsDatetime())]),
            ModelResponse(
                parts=[TextPart(content='Provider response')],
                usage=RequestUsage(input_tokens=54, output_tokens=2),
                model_name='function:capture_model_function:capture_model_stream_function',
                timestamp=IsDatetime(),
            ),
        ]
    )
    assert run_result.new_messages() == run_result.all_messages()[-2:]


async def test_history_processor_streaming_replaces_message_history(
    function_model: FunctionModel, received_messages: list[ModelMessage]
):
    """Test that the processed history replaces the original history in the result of `agent.run_stream`."""

    def process_previous_answers(messages: list[ModelMessage]) -> list[ModelMessage]:
        # Keep the last message (last question) and add a new system prompt
        processed_note = ModelRequest(parts=[SystemPromptPart(content='Processed answer')])
        return [*messages[-1:], processed_note]

    prior_messages = [
        ModelRequest(parts=[UserPromptPart(content='Question 1')]),
        ModelResponse(parts=[TextPart(content='Answer 1')]),
        ModelRequest(parts=[UserPromptPart(content='Question 2')]),
        ModelResponse(parts=[TextPart(content='Answer 2')]),
    ]
    agent = Agent(function_model, history_processors=[process_previous_answers])

    with capture_run_messages() as run_messages:
        async with agent.run_stream('Question 3', message_history=prior_messages) as stream_result:
            # Drain the stream so the run completes.
            async for _ in stream_result.stream_text():
                pass

    assert received_messages == snapshot(
        [
            ModelRequest(
                parts=[
                    UserPromptPart(content='Question 3', timestamp=IsDatetime()),
                    SystemPromptPart(content='Processed answer', timestamp=IsDatetime()),
                ]
            )
        ]
    )
    all_messages = stream_result.all_messages()
    assert run_messages == all_messages
    assert all_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Question 3', timestamp=IsDatetime())]),
            ModelRequest(parts=[SystemPromptPart(content='Processed answer', timestamp=IsDatetime())]),
            ModelResponse(
                parts=[TextPart(content='hello')],
                usage=RequestUsage(input_tokens=50, output_tokens=1),
                model_name='function:capture_model_function:capture_model_stream_function',
                timestamp=IsDatetime(),
            ),
        ]
    )
    assert stream_result.new_messages() == stream_result.all_messages()[-2:]


async def test_history_processor_messages_sent_to_provider(
    function_model: FunctionModel, received_messages: list[ModelMessage]
):
    """Test which messages reach the provider once responses have been filtered out."""

    def capture_messages_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
        # Filter out ModelResponse messages
        requests_only: list[ModelMessage] = []
        for message in messages:
            if isinstance(message, ModelRequest):
                requests_only.append(message)
        return requests_only

    prior_messages = [
        ModelRequest(parts=[UserPromptPart(content='Previous question')]),
        ModelResponse(parts=[TextPart(content='Previous answer')]),  # This should be filtered out
    ]
    agent = Agent(function_model, history_processors=[capture_messages_processor])

    with capture_run_messages() as run_messages:
        run_result = await agent.run('New question', message_history=prior_messages)

    assert received_messages == snapshot(
        [
            ModelRequest(
                parts=[
                    UserPromptPart(content='Previous question', timestamp=IsDatetime()),
                    UserPromptPart(content='New question', timestamp=IsDatetime()),
                ]
            )
        ]
    )
    all_messages = run_result.all_messages()
    assert run_messages == all_messages
    assert all_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Previous question', timestamp=IsDatetime())]),
            ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]),
            ModelResponse(
                parts=[TextPart(content='Provider response')],
                usage=RequestUsage(input_tokens=54, output_tokens=2),
                model_name='function:capture_model_function:capture_model_stream_function',
                timestamp=IsDatetime(),
            ),
        ]
    )
    assert run_result.new_messages() == run_result.all_messages()[-2:]


async def test_multiple_history_processors(function_model: FunctionModel, received_messages: list[ModelMessage]):
    """Test that several processors run one after another, in list order."""

    def first_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
        # Add a prefix to user prompts
        rewritten: list[ModelMessage] = []
        for message in messages:
            if isinstance(message, ModelRequest):
                prefixed_parts: list[ModelRequestPart] = []
                for part in message.parts:
                    if isinstance(part, UserPromptPart):  # pragma: no branch
                        prefixed_parts.append(UserPromptPart(content=f'[FIRST] {part.content}'))
                rewritten.append(ModelRequest(parts=prefixed_parts))
            else:
                rewritten.append(message)
        return rewritten

    def second_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
        # Add another prefix to user prompts
        rewritten: list[ModelMessage] = []
        for message in messages:
            if isinstance(message, ModelRequest):
                prefixed_parts: list[ModelRequestPart] = []
                for part in message.parts:
                    if isinstance(part, UserPromptPart):  # pragma: no branch
                        prefixed_parts.append(UserPromptPart(content=f'[SECOND] {part.content}'))
                rewritten.append(ModelRequest(parts=prefixed_parts))
            else:
                rewritten.append(message)
        return rewritten

    prior_messages: list[ModelMessage] = [
        ModelRequest(parts=[UserPromptPart(content='Question')]),
        ModelResponse(parts=[TextPart(content='Answer')]),
    ]
    agent = Agent(function_model, history_processors=[first_processor, second_processor])

    with capture_run_messages() as run_messages:
        run_result = await agent.run('New question', message_history=prior_messages)

    assert received_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='[SECOND] [FIRST] Question', timestamp=IsDatetime())]),
            ModelResponse(parts=[TextPart(content='Answer')], timestamp=IsDatetime()),
            ModelRequest(parts=[UserPromptPart(content='[SECOND] [FIRST] New question', timestamp=IsDatetime())]),
        ]
    )
    all_messages = run_result.all_messages()
    assert run_messages == all_messages
    assert all_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='[SECOND] [FIRST] Question', timestamp=IsDatetime())]),
            ModelResponse(parts=[TextPart(content='Answer')], timestamp=IsDatetime()),
            ModelRequest(parts=[UserPromptPart(content='[SECOND] [FIRST] New question', timestamp=IsDatetime())]),
            ModelResponse(
                parts=[TextPart(content='Provider response')],
                usage=RequestUsage(input_tokens=57, output_tokens=3),
                model_name='function:capture_model_function:capture_model_stream_function',
                timestamp=IsDatetime(),
            ),
        ]
    )
    assert run_result.new_messages() == run_result.all_messages()[-2:]


async def test_async_history_processor(function_model: FunctionModel, received_messages: list[ModelMessage]):
    """Test that a coroutine-function processor is accepted and applied."""

    async def async_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
        requests_only: list[ModelMessage] = []
        for message in messages:
            if isinstance(message, ModelRequest):
                requests_only.append(message)
        return requests_only

    prior_messages = [
        ModelRequest(parts=[UserPromptPart(content='Question 1')]),
        ModelResponse(parts=[TextPart(content='Answer 1')]),  # Should be filtered out
    ]
    agent = Agent(function_model, history_processors=[async_processor])

    with capture_run_messages() as run_messages:
        run_result = await agent.run('Question 2', message_history=prior_messages)

    assert received_messages == snapshot(
        [
            ModelRequest(
                parts=[
                    UserPromptPart(content='Question 1', timestamp=IsDatetime()),
                    UserPromptPart(content='Question 2', timestamp=IsDatetime()),
                ]
            )
        ]
    )
    all_messages = run_result.all_messages()
    assert run_messages == all_messages
    assert all_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Question 1', timestamp=IsDatetime())]),
            ModelRequest(parts=[UserPromptPart(content='Question 2', timestamp=IsDatetime())]),
            ModelResponse(
                parts=[TextPart(content='Provider response')],
                usage=RequestUsage(input_tokens=54, output_tokens=2),
                model_name='function:capture_model_function:capture_model_stream_function',
                timestamp=IsDatetime(),
            ),
        ]
    )
    assert run_result.new_messages() == run_result.all_messages()[-2:]


async def test_history_processor_on_streamed_run(function_model: FunctionModel, received_messages: list[ModelMessage]):
    """Test that processors are applied when model request nodes are streamed via `agent.iter`."""

    async def async_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
        requests_only: list[ModelMessage] = []
        for message in messages:
            if isinstance(message, ModelRequest):
                requests_only.append(message)
        return requests_only

    agent = Agent(function_model, history_processors=[async_processor])
    prior_messages = [
        ModelRequest(parts=[UserPromptPart(content='Question 1')]),
        ModelResponse(parts=[TextPart(content='Answer 1')]),
    ]

    with capture_run_messages() as run_messages:
        async with agent.iter('Question 2', message_history=prior_messages) as agent_run:
            async for node in agent_run:
                if agent.is_model_request_node(node):
                    async with node.stream(agent_run.ctx) as response_stream:
                        async for _ in response_stream.stream_responses(debounce_by=None):
                            ...

        run_result = agent_run.result
    assert run_result is not None

    assert received_messages == snapshot(
        [
            ModelRequest(
                parts=[
                    UserPromptPart(content='Question 1', timestamp=IsDatetime()),
                    UserPromptPart(content='Question 2', timestamp=IsDatetime()),
                ]
            )
        ]
    )
    all_messages = run_result.all_messages()
    assert run_messages == all_messages
    assert all_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Question 1', timestamp=IsDatetime())]),
            ModelRequest(parts=[UserPromptPart(content='Question 2', timestamp=IsDatetime())]),
            ModelResponse(
                parts=[TextPart(content='hello')],
                usage=RequestUsage(input_tokens=50, output_tokens=1),
                model_name='function:capture_model_function:capture_model_stream_function',
                timestamp=IsDatetime(),
            ),
        ]
    )
    assert run_result.new_messages() == run_result.all_messages()[-2:]


async def test_history_processor_with_context(function_model: FunctionModel, received_messages: list[ModelMessage]):
    """Test a processor whose first parameter is the `RunContext`, reading `deps` from it."""

    def context_processor(ctx: RunContext[str], messages: list[ModelMessage]) -> list[ModelMessage]:
        # Access deps from context
        prefix = ctx.deps
        rewritten: list[ModelMessage] = []
        for message in messages:
            if isinstance(message, ModelRequest):
                prefixed_parts: list[ModelRequestPart] = []
                for part in message.parts:
                    if isinstance(part, UserPromptPart):
                        prefixed_parts.append(UserPromptPart(content=f'{prefix}: {part.content}'))
                    else:
                        prefixed_parts.append(part)  # pragma: no cover
                rewritten.append(ModelRequest(parts=prefixed_parts))
            else:
                rewritten.append(message)  # pragma: no cover
        return rewritten

    agent = Agent(function_model, history_processors=[context_processor], deps_type=str)

    with capture_run_messages() as run_messages:
        run_result = await agent.run('test', deps='PREFIX')

    assert received_messages == snapshot(
        [ModelRequest(parts=[UserPromptPart(content='PREFIX: test', timestamp=IsDatetime())])]
    )
    all_messages = run_result.all_messages()
    assert run_messages == all_messages
    assert all_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='PREFIX: test', timestamp=IsDatetime())]),
            ModelResponse(
                parts=[TextPart(content='Provider response')],
                usage=RequestUsage(input_tokens=52, output_tokens=2),
                model_name='function:capture_model_function:capture_model_stream_function',
                timestamp=IsDatetime(),
            ),
        ]
    )
    assert run_result.new_messages() == run_result.all_messages()[-2:]


async def test_history_processor_with_context_async(
    function_model: FunctionModel, received_messages: list[ModelMessage]
):
    """Test a coroutine-function processor whose first parameter is the `RunContext`."""

    async def async_context_processor(ctx: RunContext[Any], messages: list[ModelMessage]) -> list[ModelMessage]:
        return messages[-1:]  # Keep only the last message

    agent = Agent(function_model, history_processors=[async_context_processor])
    prior_messages = [
        ModelRequest(parts=[UserPromptPart(content='Question 1')]),
        ModelResponse(parts=[TextPart(content='Answer 1')]),
        ModelRequest(parts=[UserPromptPart(content='Question 2')]),
        ModelResponse(parts=[TextPart(content='Answer 2')]),
    ]

    with capture_run_messages() as run_messages:
        run_result = await agent.run('Question 3', message_history=prior_messages)

    assert received_messages == snapshot(
        [ModelRequest(parts=[UserPromptPart(content='Question 3', timestamp=IsDatetime())])]
    )
    all_messages = run_result.all_messages()
    assert run_messages == all_messages
    assert all_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Question 3', timestamp=IsDatetime())]),
            ModelResponse(
                parts=[TextPart(content='Provider response')],
                usage=RequestUsage(input_tokens=52, output_tokens=2),
                model_name='function:capture_model_function:capture_model_stream_function',
                timestamp=IsDatetime(),
            ),
        ]
    )
    assert run_result.new_messages() == run_result.all_messages()[-2:]


async def test_history_processor_mixed_signatures(function_model: FunctionModel, received_messages: list[ModelMessage]):
    """Test combining a messages-only processor with one that also takes the `RunContext`."""

    def simple_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
        # Filter out responses
        requests_only: list[ModelMessage] = []
        for message in messages:
            if isinstance(message, ModelRequest):
                requests_only.append(message)
        return requests_only

    def context_processor(ctx: RunContext[Any], messages: list[ModelMessage]) -> list[ModelMessage]:
        # Add prefix based on deps
        prefix = getattr(ctx.deps, 'prefix', 'DEFAULT')
        rewritten: list[ModelMessage] = []
        for message in messages:
            if isinstance(message, ModelRequest):
                prefixed_parts: list[ModelRequestPart] = []
                for part in message.parts:
                    if isinstance(part, UserPromptPart):
                        prefixed_parts.append(UserPromptPart(content=f'{prefix}: {part.content}'))
                    else:
                        prefixed_parts.append(part)  # pragma: no cover
                rewritten.append(ModelRequest(parts=prefixed_parts))
            else:
                rewritten.append(message)  # pragma: no cover
        return rewritten

    # Create deps with prefix attribute
    class Deps:
        prefix = 'TEST'

    prior_messages = [
        ModelRequest(parts=[UserPromptPart(content='Question 1')]),
        ModelResponse(parts=[TextPart(content='Answer 1')]),
    ]
    agent = Agent(function_model, history_processors=[simple_processor, context_processor], deps_type=Deps)

    with capture_run_messages() as run_messages:
        run_result = await agent.run('Question 2', message_history=prior_messages, deps=Deps())

    # Should have filtered responses and added prefix
    assert received_messages == snapshot(
        [
            ModelRequest(
                parts=[
                    UserPromptPart(content='TEST: Question 1', timestamp=IsDatetime()),
                    UserPromptPart(content='TEST: Question 2', timestamp=IsDatetime()),
                ]
            )
        ]
    )
    all_messages = run_result.all_messages()
    assert run_messages == all_messages
    assert all_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='TEST: Question 1', timestamp=IsDatetime())]),
            ModelRequest(parts=[UserPromptPart(content='TEST: Question 2', timestamp=IsDatetime())]),
            ModelResponse(
                parts=[TextPart(content='Provider response')],
                usage=RequestUsage(input_tokens=56, output_tokens=2),
                model_name='function:capture_model_function:capture_model_stream_function',
                timestamp=IsDatetime(),
            ),
        ]
    )
    assert run_result.new_messages() == run_result.all_messages()[-2:]


async def test_history_processor_replace_messages(function_model: FunctionModel, received_messages: list[ModelMessage]):
    """Test a processor that ignores its input and returns an entirely new history."""

    def return_new_history(messages: list[ModelMessage]) -> list[ModelMessage]:
        replacement = ModelRequest(parts=[UserPromptPart(content='Modified message')])
        return [replacement]

    history: list[ModelMessage] = [
        ModelRequest(parts=[UserPromptPart(content='Original message')]),
        ModelResponse(parts=[TextPart(content='Original response')]),
        ModelRequest(parts=[UserPromptPart(content='Original followup')]),
    ]
    agent = Agent(function_model, history_processors=[return_new_history])

    with capture_run_messages() as run_messages:
        run_result = await agent.run('foobar', message_history=history)

    assert received_messages == snapshot(
        [ModelRequest(parts=[UserPromptPart(content='Modified message', timestamp=IsDatetime())])]
    )
    all_messages = run_result.all_messages()
    assert run_messages == all_messages
    assert all_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Modified message', timestamp=IsDatetime())]),
            ModelResponse(
                parts=[TextPart(content='Provider response')],
                usage=RequestUsage(input_tokens=52, output_tokens=2),
                model_name='function:capture_model_function:capture_model_stream_function',
                timestamp=IsDatetime(),
            ),
        ]
    )
    assert run_result.new_messages() == run_result.all_messages()[-2:]


async def test_history_processor_empty_history(function_model: FunctionModel, received_messages: list[ModelMessage]):
    """Test that a processor returning an empty list makes the run raise `UserError`."""

    def return_new_history(messages: list[ModelMessage]) -> list[ModelMessage]:
        return []

    agent = Agent(function_model, history_processors=[return_new_history])

    with pytest.raises(UserError, match='Processed history cannot be empty.'):
        await agent.run('foobar')


async def test_history_processor_history_ending_in_response(
    function_model: FunctionModel, received_messages: list[ModelMessage]
):
    """Test that a processed history whose last message is a `ModelResponse` raises `UserError`."""

    def return_new_history(messages: list[ModelMessage]) -> list[ModelMessage]:
        trailing_response = ModelResponse(parts=[TextPart(content='Provider response')])
        return [trailing_response]

    agent = Agent(function_model, history_processors=[return_new_history])

    with pytest.raises(UserError, match='Processed history must end with a `ModelRequest`.'):
        await agent.run('foobar')


async def test_callable_class_history_processor_no_op(
    function_model: FunctionModel, received_messages: list[ModelMessage]
):
    """Test that an instance of a class defining `__call__(messages)` works as a processor."""

    class NoOpHistoryProcessor:
        def __call__(self, messages: list[ModelMessage]) -> list[ModelMessage]:
            return messages

    prior_messages = [
        ModelRequest(parts=[UserPromptPart(content='Previous question')]),
        ModelResponse(parts=[TextPart(content='Previous answer')]),
    ]
    agent = Agent(function_model, history_processors=[NoOpHistoryProcessor()])

    with capture_run_messages() as run_messages:
        run_result = await agent.run('New question', message_history=prior_messages)

    assert received_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Previous question', timestamp=IsDatetime())]),
            ModelResponse(parts=[TextPart(content='Previous answer')], timestamp=IsDatetime()),
            ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]),
        ]
    )
    all_messages = run_result.all_messages()
    assert run_messages == all_messages
    assert all_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Previous question', timestamp=IsDatetime())]),
            ModelResponse(parts=[TextPart(content='Previous answer')], timestamp=IsDatetime()),
            ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]),
            ModelResponse(
                parts=[TextPart(content='Provider response')],
                usage=RequestUsage(input_tokens=54, output_tokens=4),
                model_name='function:capture_model_function:capture_model_stream_function',
                timestamp=IsDatetime(),
            ),
        ]
    )
    assert run_result.new_messages() == run_result.all_messages()[-2:]


async def test_callable_class_history_processor_with_ctx_no_op(
    function_model: FunctionModel, received_messages: list[ModelMessage]
):
    """Test that an instance of a class defining `__call__(ctx, messages)` works as a processor."""

    class NoOpHistoryProcessorWithCtx:
        def __call__(self, _: RunContext, messages: list[ModelMessage]) -> list[ModelMessage]:
            return messages

    prior_messages = [
        ModelRequest(parts=[UserPromptPart(content='Previous question')]),
        ModelResponse(parts=[TextPart(content='Previous answer')]),
    ]
    agent = Agent(function_model, history_processors=[NoOpHistoryProcessorWithCtx()])

    with capture_run_messages() as run_messages:
        run_result = await agent.run('New question', message_history=prior_messages)

    assert received_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Previous question', timestamp=IsDatetime())]),
            ModelResponse(parts=[TextPart(content='Previous answer')], timestamp=IsDatetime()),
            ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]),
        ]
    )
    all_messages = run_result.all_messages()
    assert run_messages == all_messages
    assert all_messages == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='Previous question', timestamp=IsDatetime())]),
            ModelResponse(parts=[TextPart(content='Previous answer')], timestamp=IsDatetime()),
            ModelRequest(parts=[UserPromptPart(content='New question', timestamp=IsDatetime())]),
            ModelResponse(
                parts=[TextPart(content='Provider response')],
                usage=RequestUsage(input_tokens=54, output_tokens=4),
                model_name='function:capture_model_function:capture_model_stream_function',
                timestamp=IsDatetime(),
            ),
        ]
    )
    assert run_result.new_messages() == run_result.all_messages()[-2:]

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/pydantic/pydantic-ai'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.