Skip to main content
Glama

mcp-run-python

Official
by pydantic
test_joins_and_reducers.py11.7 kB
"""Tests for join nodes and reducer types.""" from __future__ import annotations import asyncio from dataclasses import dataclass, field import pytest from pydantic_graph.beta import GraphBuilder, StepContext from pydantic_graph.beta.join import ( ReduceFirstValue, ReducerContext, reduce_dict_update, reduce_list_append, reduce_null, ) pytestmark = pytest.mark.anyio @dataclass class SimpleState: value: int = 0 async def test_null_reducer(): """Test NullReducer that discards all inputs.""" g = GraphBuilder(state_type=SimpleState) @g.step async def source(ctx: StepContext[SimpleState, None, None]) -> list[int]: return [1, 2, 3] @g.step async def process(ctx: StepContext[SimpleState, None, int]) -> int: ctx.state.value += ctx.inputs return ctx.inputs null_join = g.join(reduce_null, initial=None) g.add( g.edge_from(g.start_node).to(source), g.edge_from(source).map().to(process), g.edge_from(process).to(null_join), g.edge_from(null_join).to(g.end_node), ) graph = g.build() state = SimpleState() result = await graph.run(state=state) assert result is None # But side effects should still happen assert state.value == 6 async def test_reduce_list_append(): """Test reduce_list_append that collects all inputs into a list.""" g = GraphBuilder(state_type=SimpleState, output_type=list[str]) @g.step async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]: return [1, 2, 3, 4] @g.step async def to_string(ctx: StepContext[SimpleState, None, int]) -> str: return f'item-{ctx.inputs}' list_join = g.join(reduce_list_append, initial_factory=list[str]) g.add( g.edge_from(g.start_node).to(generate_numbers), g.edge_from(generate_numbers).map().to(to_string), g.edge_from(to_string).to(list_join), g.edge_from(list_join).to(g.end_node), ) graph = g.build() result = await graph.run(state=SimpleState()) # Order may vary due to parallel execution assert sorted(result) == ['item-1', 'item-2', 'item-3', 'item-4'] async def test_reduce_dict_update(): """Test reduce_dict_update that merges dictionaries.""" g = GraphBuilder(state_type=SimpleState, output_type=dict[str, int]) @g.step async def generate_keys(ctx: StepContext[SimpleState, None, None]) -> list[str]: return ['a', 'b', 'c'] @g.step async def create_dict(ctx: StepContext[SimpleState, None, str]) -> dict[str, int]: return {ctx.inputs: len(ctx.inputs)} dict_join = g.join(reduce_dict_update, initial_factory=dict[str, int]) g.add( g.edge_from(g.start_node).to(generate_keys), g.edge_from(generate_keys).map().to(create_dict), g.edge_from(create_dict).to(dict_join), g.edge_from(dict_join).to(g.end_node), ) graph = g.build() result = await graph.run(state=SimpleState()) assert result == {'a': 1, 'b': 1, 'c': 1} async def test_reducer_with_state_access(): """Test that reducers can access and modify graph state.""" # Note: doing this means the `count` will be shared across _all_ fork runs. def get_state_aware_reducer(): count = 0 def reduce_state_aware(ctx: ReducerContext[SimpleState, None], current: int, inputs: int) -> int: nonlocal count ctx.state.value += inputs count += 1 return count return reduce_state_aware g = GraphBuilder(state_type=SimpleState, output_type=int) @g.step async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]: return [1, 2, 3] @g.step async def process(ctx: StepContext[SimpleState, None, int]) -> int: return ctx.inputs * 10 aware_join = g.join(get_state_aware_reducer(), initial=0) g.add( g.edge_from(g.start_node).to(generate), g.edge_from(generate).map().to(process), g.edge_from(process).to(aware_join), g.edge_from(aware_join).to(g.end_node), ) graph = g.build() state = SimpleState() result = await graph.run(state=state) assert result == 3 # Three items were reduced assert state.value == 60 # 10 + 20 + 30 async def test_join_with_custom_id(): """Test creating a join with a custom node ID.""" g = GraphBuilder(state_type=SimpleState, output_type=list[int]) @g.step async def source(ctx: StepContext[SimpleState, None, None]) -> list[int]: return [1, 2, 3] # pragma: no cover @g.step async def process(ctx: StepContext[SimpleState, None, int]) -> int: return ctx.inputs # pragma: no cover custom_join = g.join(reduce_list_append, initial_factory=list[int], node_id='my_custom_join') g.add( g.edge_from(g.start_node).to(source), g.edge_from(source).map().to(process), g.edge_from(process).to(custom_join), g.edge_from(custom_join).to(g.end_node), ) graph = g.build() assert 'my_custom_join' in graph.nodes async def test_multiple_joins(): """Test a graph with multiple independent joins.""" @dataclass class MultiState: results: dict[str, list[int]] = field(default_factory=dict) g = GraphBuilder(state_type=MultiState, output_type=dict[str, list[int]]) @g.step async def source_a(ctx: StepContext[MultiState, None, None]) -> list[int]: return [1, 2] @g.step async def source_b(ctx: StepContext[MultiState, None, None]) -> list[int]: return [10, 20] @g.step async def process_a(ctx: StepContext[MultiState, None, int]) -> int: return ctx.inputs * 2 @g.step async def process_b(ctx: StepContext[MultiState, None, int]) -> int: return ctx.inputs * 3 join_a = g.join(reduce_list_append, initial_factory=list[int], node_id='join_a') join_b = g.join(reduce_list_append, initial_factory=list[int], node_id='join_b') @g.step async def combine(ctx: StepContext[MultiState, None, None]) -> dict[str, list[int]]: return ctx.state.results @g.step async def store_a(ctx: StepContext[MultiState, None, list[int]]) -> None: ctx.state.results['a'] = ctx.inputs @g.step async def store_b(ctx: StepContext[MultiState, None, list[int]]) -> None: ctx.state.results['b'] = ctx.inputs g.add( g.edge_from(g.start_node).to(source_a, source_b), g.edge_from(source_a).map().to(process_a), g.edge_from(source_b).map().to(process_b), g.edge_from(process_a).to(join_a), g.edge_from(process_b).to(join_b), g.edge_from(join_a).to(store_a), g.edge_from(join_b).to(store_b), g.edge_from(store_a, store_b).to(combine), g.edge_from(combine).to(g.end_node), ) graph = g.build() state = MultiState() result = await graph.run(state=state) assert sorted(result['a']) == [2, 4] assert sorted(result['b']) == [30, 60] async def test_reduce_dict_update_with_overlapping_keys(): """Test that reduce_dict_update properly handles overlapping keys (later values win).""" g = GraphBuilder(state_type=SimpleState, output_type=dict[str, int]) @g.step async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]: return [1, 2, 3] @g.step async def create_dict(ctx: StepContext[SimpleState, None, int]) -> dict[str, int]: # All create the same key return {'key': ctx.inputs} dict_join = g.join(reduce_dict_update, initial_factory=dict[str, int]) g.add( g.edge_from(g.start_node).to(generate), g.edge_from(generate).map().to(create_dict), g.edge_from(create_dict).to(dict_join), g.edge_from(dict_join).to(g.end_node), ) graph = g.build() result = await graph.run(state=SimpleState()) # One of the values should win (1, 2, or 3) assert 'key' in result assert result['key'] in [1, 2, 3] async def test_reducer_with_deps_access(): """Test that reducer context can access deps""" @dataclass class DepsWithMultiplier: multiplier: int def reducer_using_deps(ctx: ReducerContext[SimpleState, DepsWithMultiplier], current: int, inputs: int) -> int: # Access deps through the context return current + (inputs * ctx.deps.multiplier) g = GraphBuilder(state_type=SimpleState, deps_type=DepsWithMultiplier, output_type=int) @g.step async def generate(ctx: StepContext[SimpleState, DepsWithMultiplier, None]) -> list[int]: return [1, 2, 3] @g.step async def process(ctx: StepContext[SimpleState, DepsWithMultiplier, int]) -> int: return ctx.inputs deps_join = g.join(reducer_using_deps, initial=0) g.add( g.edge_from(g.start_node).to(generate), g.edge_from(generate).map().to(process), g.edge_from(process).to(deps_join), g.edge_from(deps_join).to(g.end_node), ) graph = g.build() state = SimpleState() deps = DepsWithMultiplier(multiplier=10) result = await graph.run(state=state, deps=deps) # (0 + 1*10) + (2*10) + (3*10) = 60 assert result == 60 async def test_reduce_list_extend(): """Test reduce_list_extend that extends a list with iterables""" from pydantic_graph.beta.join import reduce_list_extend g = GraphBuilder(state_type=SimpleState, output_type=list[int]) @g.step async def generate_iterables(ctx: StepContext[SimpleState, None, None]) -> list[list[int]]: return [[1, 2], [3, 4], [5, 6]] @g.step async def pass_through(ctx: StepContext[SimpleState, None, list[int]]) -> list[int]: return ctx.inputs extend_join = g.join(reduce_list_extend, initial_factory=list[int]) g.add( g.edge_from(g.start_node).to(generate_iterables), g.edge_from(generate_iterables).map().to(pass_through), g.edge_from(pass_through).to(extend_join), g.edge_from(extend_join).to(g.end_node), ) graph = g.build() result = await graph.run(state=SimpleState()) # All sublists are extended into one flat list assert sorted(result) == [1, 2, 3, 4, 5, 6] async def test_reduce_first_value(): """Test ReduceFirstValue cancels sibling tasks""" @dataclass class StateWithResults: results: list[str] = field(default_factory=list) g = GraphBuilder(state_type=StateWithResults, output_type=str) @g.step async def generate(ctx: StepContext[StateWithResults, None, None]) -> list[int]: return [1, 2, 3, 4, 5] @g.step async def slow_process(ctx: StepContext[StateWithResults, None, int]) -> str: # First task finishes quickly if ctx.inputs == 1: await asyncio.sleep(0.001) else: # Others take longer (should be cancelled) await asyncio.sleep(10) ctx.state.results.append(f'completed-{ctx.inputs}') return f'result-{ctx.inputs}' first_join = g.join(ReduceFirstValue[str](), initial='') g.add( g.edge_from(g.start_node).to(generate), g.edge_from(generate).map().to(slow_process), g.edge_from(slow_process).to(first_join), g.edge_from(first_join).to(g.end_node), ) graph = g.build() state = StateWithResults() result = await graph.run(state=state) # Only the first value should be returned assert result.startswith('result-') # Due to cancellation, not all 5 tasks should complete # (though timing can be tricky, so we just verify we got a result) assert 'completed-1' in state.results

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/pydantic/pydantic-ai'

If you have feedback or need assistance with the MCP directory API, please join our Discord server