# test_broadcast_and_spread.py (8.35 kB)
"""Tests for broadcast (parallel) and map (fan-out) operations."""
from __future__ import annotations
from dataclasses import dataclass, field
import pytest
from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import reduce_list_append
# Module-level pytest mark: every test below is a coroutine function, so the
# whole module is tagged with the `anyio` marker.
pytestmark = pytest.mark.anyio
@dataclass
class CounterState:
    """State object handed to each graph run in this module."""

    # A fresh, empty list is created per instance (never shared between runs).
    values: list[int] = field(default_factory=list)
async def test_broadcast_to_multiple_steps():
    """Fan a single source value out to three steps and join their outputs."""
    builder = GraphBuilder(state_type=CounterState, output_type=list[int])

    @builder.step
    async def source(ctx: StepContext[CounterState, None, None]) -> int:
        return 10

    @builder.step
    async def add_one(ctx: StepContext[CounterState, None, int]) -> int:
        return ctx.inputs + 1

    @builder.step
    async def add_two(ctx: StepContext[CounterState, None, int]) -> int:
        return ctx.inputs + 2

    @builder.step
    async def add_three(ctx: StepContext[CounterState, None, int]) -> int:
        return ctx.inputs + 3

    gather = builder.join(reduce_list_append, initial_factory=list[int])

    # Wiring: start -> source -> {add_one, add_two, add_three} -> gather -> end
    entry = builder.edge_from(builder.start_node).to(source)
    fan_out = builder.edge_from(source).to(add_one, add_two, add_three)
    fan_in = builder.edge_from(add_one, add_two, add_three).to(gather)
    finish = builder.edge_from(gather).to(builder.end_node)
    builder.add(entry, fan_out, fan_in, finish)

    pipeline = builder.build()
    outputs = await pipeline.run(state=CounterState())

    # The branches may finish in any order, so compare a sorted copy.
    ordered = sorted(outputs)
    assert ordered == [11, 12, 13]
async def test_map_over_list():
    """Map a five-element list through a squaring step and join the results."""
    builder = GraphBuilder(state_type=CounterState, output_type=list[int])

    @builder.step
    async def generate_list(ctx: StepContext[CounterState, None, None]) -> list[int]:
        return [1, 2, 3, 4, 5]

    @builder.step
    async def square(ctx: StepContext[CounterState, None, int]) -> int:
        item = ctx.inputs
        return item * item

    gather = builder.join(reduce_list_append, initial_factory=list[int])

    # Each element produced by generate_list becomes one input to square.
    builder.add_mapping_edge(generate_list, square)

    entry = builder.edge_from(builder.start_node).to(generate_list)
    fan_in = builder.edge_from(square).to(gather)
    finish = builder.edge_from(gather).to(builder.end_node)
    builder.add(entry, fan_in, finish)

    pipeline = builder.build()
    outputs = await pipeline.run(state=CounterState())

    ordered = sorted(outputs)
    assert ordered == [1, 4, 9, 16, 25]
async def test_map_with_labels():
    """Map edge built with explicit pre-map and post-map labels."""
    builder = GraphBuilder(state_type=CounterState, output_type=list[str])

    @builder.step
    async def generate_numbers(ctx: StepContext[CounterState, None, None]) -> list[int]:
        return [10, 20, 30]

    @builder.step
    async def stringify(ctx: StepContext[CounterState, None, int]) -> str:
        return f'Value: {ctx.inputs}'

    gather = builder.join(reduce_list_append, initial_factory=list[str])

    builder.add_mapping_edge(
        generate_numbers,
        stringify,
        pre_map_label='before map',
        post_map_label='after map',
    )

    entry = builder.edge_from(builder.start_node).to(generate_numbers)
    fan_in = builder.edge_from(stringify).to(gather)
    finish = builder.edge_from(gather).to(builder.end_node)
    builder.add(entry, fan_in, finish)

    pipeline = builder.build()
    outputs = await pipeline.run(state=CounterState())

    ordered = sorted(outputs)
    assert ordered == ['Value: 10', 'Value: 20', 'Value: 30']
async def test_map_empty_list():
    """Mapping over an empty list yields an empty joined result."""
    builder = GraphBuilder(state_type=CounterState, output_type=list[int])

    @builder.step
    async def generate_empty(ctx: StepContext[CounterState, None, None]) -> list[int]:
        return []

    @builder.step
    async def double(ctx: StepContext[CounterState, None, int]) -> int:
        return ctx.inputs * 2  # pragma: no cover

    gather = builder.join(reduce_list_append, initial_factory=list[int])

    # The mapping edge is told explicitly which join sits downstream of it.
    builder.add_mapping_edge(generate_empty, double, downstream_join_id=gather.id)

    entry = builder.edge_from(builder.start_node).to(generate_empty)
    fan_in = builder.edge_from(double).to(gather)
    finish = builder.edge_from(gather).to(builder.end_node)
    builder.add(entry, fan_in, finish)

    pipeline = builder.build()
    outputs = await pipeline.run(state=CounterState())

    assert outputs == []
async def test_nested_broadcasts():
    """Broadcast into two branches, each a two-step chain, then join both."""
    builder = GraphBuilder(state_type=CounterState, output_type=list[int])

    @builder.step
    async def start_value(ctx: StepContext[CounterState, None, None]) -> int:
        return 5

    @builder.step
    async def path_a1(ctx: StepContext[CounterState, None, int]) -> int:
        return ctx.inputs + 1

    @builder.step
    async def path_a2(ctx: StepContext[CounterState, None, int]) -> int:
        return ctx.inputs + 10

    @builder.step
    async def path_b1(ctx: StepContext[CounterState, None, int]) -> int:
        return ctx.inputs * 2

    @builder.step
    async def path_b2(ctx: StepContext[CounterState, None, int]) -> int:
        return ctx.inputs * 3

    gather = builder.join(reduce_list_append, initial_factory=list[int])

    entry = builder.edge_from(builder.start_node).to(start_value)
    split = builder.edge_from(start_value).to(path_a1, path_b1)
    branch_a = builder.edge_from(path_a1).to(path_a2)
    branch_b = builder.edge_from(path_b1).to(path_b2)
    fan_in = builder.edge_from(path_a2, path_b2).to(gather)
    finish = builder.edge_from(gather).to(builder.end_node)
    builder.add(entry, split, branch_a, branch_b, fan_in, finish)

    pipeline = builder.build()
    outputs = await pipeline.run(state=CounterState())

    # Branch A: 5 + 1 + 10 = 16
    # Branch B: 5 * 2 * 3 = 30
    ordered = sorted(outputs)
    assert ordered == [16, 30]
async def test_map_then_broadcast():
    """Map over a list, sending every item to two steps, then join all outputs."""
    builder = GraphBuilder(state_type=CounterState, output_type=list[int])

    @builder.step
    async def generate_list(ctx: StepContext[CounterState, None, None]) -> list[int]:
        return [10, 20]

    @builder.step
    async def add_one(ctx: StepContext[CounterState, None, int]) -> int:
        return ctx.inputs + 1

    @builder.step
    async def add_two(ctx: StepContext[CounterState, None, int]) -> int:
        return ctx.inputs + 2

    gather = builder.join(reduce_list_append, initial_factory=list[int])

    entry = builder.edge_from(builder.start_node).to(generate_list)
    mapped_fan_out = builder.edge_from(generate_list).map().to(add_one, add_two)
    fan_in = builder.edge_from(add_one, add_two).to(gather)
    finish = builder.edge_from(gather).to(builder.end_node)
    builder.add(entry, mapped_fan_out, fan_in, finish)

    pipeline = builder.build()
    outputs = await pipeline.run(state=CounterState())

    # Item 10 -> 11 and 12; item 20 -> 21 and 22.
    ordered = sorted(outputs)
    assert ordered == [11, 12, 21, 22]
async def test_multiple_sequential_maps():
    """Chain two map edges: pairs -> individual ints -> formatted strings."""
    builder = GraphBuilder(state_type=CounterState, output_type=list[str])

    @builder.step
    async def generate_pairs(ctx: StepContext[CounterState, None, None]) -> list[tuple[int, int]]:
        return [(1, 2), (3, 4)]

    @builder.step
    async def unpack_pair(ctx: StepContext[CounterState, None, tuple[int, int]]) -> list[int]:
        left, right = ctx.inputs
        return [left, right]

    @builder.step
    async def stringify(ctx: StepContext[CounterState, None, int]) -> str:
        return f'num:{ctx.inputs}'

    gather = builder.join(reduce_list_append, initial_factory=list[str])

    entry = builder.edge_from(builder.start_node).to(generate_pairs)
    over_pairs = builder.edge_from(generate_pairs).map().to(unpack_pair)
    over_numbers = builder.edge_from(unpack_pair).map().to(stringify)
    fan_in = builder.edge_from(stringify).to(gather)
    finish = builder.edge_from(gather).to(builder.end_node)
    builder.add(entry, over_pairs, over_numbers, fan_in, finish)

    pipeline = builder.build()
    outputs = await pipeline.run(state=CounterState())

    ordered = sorted(outputs)
    assert ordered == ['num:1', 'num:2', 'num:3', 'num:4']
async def test_broadcast_with_different_outputs():
    """Broadcast one value to an int-returning and a str-returning step.

    Both branch outputs are appended to a single `list[int | str]` join.
    """
    g = GraphBuilder(state_type=CounterState, output_type=list[int | str])

    @g.step
    async def source(ctx: StepContext[CounterState, None, None]) -> int:
        return 42

    @g.step
    async def return_int(ctx: StepContext[CounterState, None, int]) -> int:
        return ctx.inputs

    @g.step
    async def return_str(ctx: StepContext[CounterState, None, int]) -> str:
        return str(ctx.inputs)

    collect = g.join(reduce_list_append, initial_factory=list[int | str])

    g.add(
        g.edge_from(g.start_node).to(source),
        g.edge_from(source).to(return_int, return_str),
        g.edge_from(return_int, return_str).to(collect),
        g.edge_from(collect).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run(state=CounterState())

    # A set comparison alone would hide duplicated outputs (e.g. a branch that
    # ran twice), so pin the element count as well.
    assert len(result) == 2
    # Order may vary, and int/str cannot be sorted together, so compare as a set.
    assert set(result) == {42, '42'}