Skip to main content
Glama

mcp-run-python

Official
by pydantic
test_graph_builder.py (13.8 kB)
"""Tests for the GraphBuilder API and basic graph construction.""" from __future__ import annotations from dataclasses import dataclass import pytest from pydantic_graph.beta import GraphBuilder, StepContext from pydantic_graph.beta.graph_builder import GraphBuildingError from pydantic_graph.beta.join import reduce_list_append, reduce_sum from pydantic_graph.beta.node import Fork from pydantic_graph.exceptions import GraphValidationError pytestmark = pytest.mark.anyio @dataclass class SimpleState: counter: int = 0 result: str | None = None async def test_basic_graph_builder(): """Test basic graph builder construction and execution.""" g = GraphBuilder(state_type=SimpleState, output_type=int) @g.step async def increment(ctx: StepContext[SimpleState, None, None]) -> int: ctx.state.counter += 1 return ctx.state.counter g.add( g.edge_from(g.start_node).to(increment), g.edge_from(increment).to(g.end_node), ) graph = g.build() state = SimpleState() result = await graph.run(state=state) assert result == 1 assert state.counter == 1 async def test_sequential_steps(): """Test multiple sequential steps in a graph.""" g = GraphBuilder(state_type=SimpleState, output_type=int) @g.step async def step_one(ctx: StepContext[SimpleState, None, None]) -> None: ctx.state.counter += 1 @g.step async def step_two(ctx: StepContext[SimpleState, None, None]) -> None: ctx.state.counter *= 2 @g.step async def step_three(ctx: StepContext[SimpleState, None, None]) -> int: ctx.state.counter += 10 return ctx.state.counter g.add( g.edge_from(g.start_node).to(step_one), g.edge_from(step_one).to(step_two), g.edge_from(step_two).to(step_three), g.edge_from(step_three).to(g.end_node), ) graph = g.build() state = SimpleState(counter=5) result = await graph.run(state=state) # (5 + 1) * 2 + 10 = 22 assert result == 22 async def test_step_with_inputs(): """Test steps that receive and transform input data.""" g = GraphBuilder(state_type=SimpleState, input_type=int, output_type=str) @g.step async def 
double_it(ctx: StepContext[SimpleState, None, int]) -> int: return ctx.inputs * 2 @g.step async def stringify(ctx: StepContext[SimpleState, None, int]) -> str: return f'Result: {ctx.inputs}' g.add( g.edge_from(g.start_node).to(double_it), g.edge_from(double_it).to(stringify), g.edge_from(stringify).to(g.end_node), ) graph = g.build() state = SimpleState() result = await graph.run(state=state, inputs=21) assert result == 'Result: 42' async def test_step_with_custom_id(): """Test creating steps with custom IDs.""" g = GraphBuilder(state_type=SimpleState, output_type=int) @g.step(node_id='custom_step_id') async def my_step(ctx: StepContext[SimpleState, None, None]) -> int: return 42 # pragma: no cover g.add( g.edge_from(g.start_node).to(my_step), g.edge_from(my_step).to(g.end_node), ) graph = g.build() assert 'custom_step_id' in graph.nodes async def test_step_with_label(): """Test creating steps with human-readable labels.""" g = GraphBuilder(state_type=SimpleState, output_type=int) @g.step(label='My Custom Label') async def my_step(ctx: StepContext[SimpleState, None, None]) -> int: return 42 assert my_step.label == 'My Custom Label' g.add( g.edge_from(g.start_node).to(my_step), g.edge_from(my_step).to(g.end_node), ) graph = g.build() result = await graph.run(state=SimpleState()) assert result == 42 async def test_add_edge_convenience(): """Test the add_edge convenience method.""" g = GraphBuilder(state_type=SimpleState, output_type=int) @g.step async def step_a(ctx: StepContext[SimpleState, None, None]) -> int: return 42 @g.step async def step_b(ctx: StepContext[SimpleState, None, int]) -> int: return ctx.inputs + 1 g.add_edge(g.start_node, step_a) g.add_edge(step_a, step_b, label='from a to b') g.add_edge(step_b, g.end_node) graph = g.build() result = await graph.run(state=SimpleState()) assert result == 43 async def test_graph_with_dependencies(): """Test graph execution with dependency injection.""" @dataclass class MyDeps: multiplier: int g = 
GraphBuilder(state_type=SimpleState, deps_type=MyDeps, output_type=int) @g.step async def multiply(ctx: StepContext[SimpleState, MyDeps, None]) -> int: return ctx.deps.multiplier * 10 g.add( g.edge_from(g.start_node).to(multiply), g.edge_from(multiply).to(g.end_node), ) graph = g.build() state = SimpleState() deps = MyDeps(multiplier=5) result = await graph.run(state=state, deps=deps) assert result == 50 async def test_empty_graph(): """Test that a minimal graph can be built and run.""" g = GraphBuilder(input_type=int, output_type=int) g.add(g.edge_from(g.start_node).to(g.end_node)) graph = g.build() result = await graph.run(inputs=42) assert result == 42 async def test_graph_name_inference(): """Test that graph names are properly inferred from variable names.""" my_graph_builder = GraphBuilder(output_type=int) @my_graph_builder.step async def return_value(ctx: StepContext[None, None, None]) -> int: return 100 my_graph_builder.add( my_graph_builder.edge_from(my_graph_builder.start_node).to(return_value), my_graph_builder.edge_from(return_value).to(my_graph_builder.end_node), ) my_custom_graph = my_graph_builder.build() result = await my_custom_graph.run() assert result == 100 assert my_custom_graph.name == 'my_custom_graph' async def test_explicit_graph_name(): """Test setting an explicit graph name.""" g = GraphBuilder(name='ExplicitName', input_type=int, output_type=int) g.add(g.edge_from(g.start_node).to(g.end_node)) graph = g.build() assert graph.name == 'ExplicitName' async def test_state_mutation(): """Test that state mutations persist across steps.""" g = GraphBuilder(state_type=SimpleState, output_type=str) @g.step async def set_counter(ctx: StepContext[SimpleState, None, None]) -> None: ctx.state.counter = 10 @g.step async def set_result(ctx: StepContext[SimpleState, None, None]) -> None: ctx.state.result = f'counter={ctx.state.counter}' @g.step async def get_result(ctx: StepContext[SimpleState, None, None]) -> str: assert ctx.state.result is not None 
return ctx.state.result g.add( g.edge_from(g.start_node).to(set_counter), g.edge_from(set_counter).to(set_result), g.edge_from(set_result).to(get_result), g.edge_from(get_result).to(g.end_node), ) graph = g.build() state = SimpleState() result = await graph.run(state=state) assert result == 'counter=10' assert state.counter == 10 assert state.result == 'counter=10' async def test_duplicate_node_ids_error(): """Test that duplicate node IDs raise a ValueError.""" g = GraphBuilder(state_type=SimpleState, output_type=int) @g.step(node_id='duplicate_id') async def step_one(ctx: StepContext[SimpleState, None, None]) -> int: return 1 # pragma: no cover @g.step(node_id='duplicate_id') async def step_two(ctx: StepContext[SimpleState, None, None]) -> int: return 2 # pragma: no cover with pytest.raises(GraphBuildingError, match='All nodes must have unique node IDs'): g.add( g.edge_from(g.start_node).to(step_one), g.edge_from(g.start_node).to(step_two), ) async def test_multiple_destinations_creates_broadcast_fork(): """Test that using .to() with multiple arguments creates a broadcast fork.""" g = GraphBuilder(state_type=SimpleState, output_type=list[int]) @g.step async def source(ctx: StepContext[SimpleState, None, None]) -> int: return 10 @g.step async def dest_a(ctx: StepContext[SimpleState, None, int]) -> int: return ctx.inputs * 2 @g.step async def dest_b(ctx: StepContext[SimpleState, None, int]) -> int: return ctx.inputs * 3 collect = g.join(reduce_list_append, initial_factory=list[int]) g.add( g.edge_from(g.start_node).to(source), g.edge_from(source).to(dest_a, dest_b), # Multiple destinations trigger broadcast fork creation g.edge_from(dest_a, dest_b).to(collect), g.edge_from(collect).to(g.end_node), ) graph = g.build() # Verify a broadcast fork was created broadcast_forks = [node for node in graph.nodes.values() if isinstance(node, Fork) and not node.is_map] assert len(broadcast_forks) > 0, 'Expected a broadcast fork to be created' result = await 
graph.run(state=SimpleState()) assert sorted(result) == [20, 30] async def test_join_without_dominating_fork_error(): """Test that a join without a dominating fork raises ValueError.""" g = GraphBuilder(output_type=int, input_type=int) @g.step async def source(ctx: StepContext[None, None, int]) -> list[int]: return [ctx.inputs, 1] # pragma: no cover join_sum = g.join(reduce_sum, initial=0) g.add( g.edge_from(g.start_node).to(source), g.edge_from(source).map().to(join_sum), g.edge_from(join_sum).to(join_sum), g.edge_from(join_sum).to(g.end_node), ) with pytest.raises( GraphBuildingError, match='For every Join J in the graph, there must be a Fork F between the StartNode and J satisfying', ): g.build() async def test_validation_no_edges_from_start(): """Test that validation catches graphs with no edges from start node.""" g = GraphBuilder(output_type=int) @g.step async def orphan_step(ctx: StepContext[None, None, None]) -> int: return 42 # pragma: no cover # Add the step to the graph but don't connect it to start g.add(g.edge_from(orphan_step).to(g.end_node)) with pytest.raises(GraphValidationError, match='The graph has no edges from the start node'): g.build() async def test_validation_no_edges_to_end(): """Test that validation catches graphs with no edges to end node.""" g = GraphBuilder(output_type=int) @g.step async def dead_end_step(ctx: StepContext[None, None, None]) -> int: return 42 # pragma: no cover # Connect start to step but don't connect step to end g.add(g.edge_from(g.start_node).to(dead_end_step)) with pytest.raises(GraphValidationError, match='The graph has no edges to the end node'): g.build() async def test_validation_node_with_no_outgoing_edges(): """Test that validation catches nodes with no outgoing edges.""" g = GraphBuilder(output_type=int) @g.step async def first_step(ctx: StepContext[None, None, None]) -> int: return 42 # pragma: no cover @g.step async def dead_end_step(ctx: StepContext[None, None, int]) -> int: return ctx.inputs # pragma: no 
cover # first_step connects to both dead_end_step and end_node # But dead_end_step has no outgoing edges g.add( g.edge_from(g.start_node).to(first_step), g.edge_from(first_step).to(dead_end_step, g.end_node), ) with pytest.raises(GraphValidationError, match='The following nodes have no outgoing edges'): g.build() async def test_validation_end_node_unreachable(): """Test that validation catches when end node is unreachable from start.""" g = GraphBuilder(input_type=int, output_type=int) @g.step async def first_step(ctx: StepContext[None, None, int]) -> int: return 42 # pragma: no cover @g.step async def second_step(ctx: StepContext[None, None, int]) -> int: return ctx.inputs # pragma: no cover # Create a cycle that doesn't reach the end node g.add( g.edge_from(g.start_node).to(first_step), g.edge_from(first_step).to(second_step), g.edge_from(second_step).to(first_step), ) with pytest.raises(GraphValidationError, match='The graph has no edges to the end node'): g.build() async def test_validation_unreachable_nodes(): """Test that validation catches nodes that are not reachable from start.""" g = GraphBuilder(output_type=int) @g.step async def reachable_step(ctx: StepContext[None, None, None]) -> int: return 10 # pragma: no cover @g.step async def unreachable_step(ctx: StepContext[None, None, int]) -> int: return ctx.inputs * 2 # pragma: no cover # unreachable_step is in the graph but not connected from start g.add( g.edge_from(g.start_node).to(reachable_step), g.edge_from(reachable_step).to(g.end_node), g.edge_from(unreachable_step).to(g.end_node), ) with pytest.raises(GraphValidationError, match='The following nodes are not reachable from the start node'): g.build() async def test_validation_can_be_disabled(): """Test that validation can be disabled with validate_graph_structure=False.""" g = GraphBuilder(output_type=int) @g.step async def orphan_step(ctx: StepContext[None, None, None]) -> int: return 42 # pragma: no cover # Add the step to the graph but don't 
connect it to start # This would normally fail validation g.add(g.edge_from(orphan_step).to(g.end_node)) # Should not raise an error when validation is disabled g.build(validate_graph_structure=False)

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/pydantic/pydantic-ai'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.