test_graph_execution.py•11.2 kB
"""Tests for graph execution internals and edge cases."""
from __future__ import annotations
from dataclasses import dataclass, field
import pytest
from pydantic_graph.beta import GraphBuilder, StepContext
from pydantic_graph.beta.join import ReduceFirstValue, reduce_list_append
pytestmark = pytest.mark.anyio
@dataclass
class ExecutionState:
log: list[str] = field(default_factory=list)
counter: int = 0
async def test_map_to_end_node_cancels_pending():
"""Test that mapping directly to end_node cancels pending tasks"""
import asyncio
g = GraphBuilder(state_type=ExecutionState, output_type=int)
@g.step
async def generate(ctx: StepContext[ExecutionState, None, None]) -> list[int]:
return [1, 2, 3, 4, 5]
@g.step
async def early_exit(ctx: StepContext[ExecutionState, None, int]) -> int:
# First item returns immediately
if ctx.inputs == 1:
return ctx.inputs
# Others would take longer
await asyncio.sleep(1)
return ctx.inputs # pragma: no cover
g.add(
g.edge_from(g.start_node).to(generate),
g.edge_from(generate).map().to(early_exit),
g.edge_from(early_exit).to(g.end_node),
)
graph = g.build()
result = await graph.run(state=ExecutionState())
# Should complete quickly with the first result
assert result in [1, 2, 3, 4, 5]
async def test_map_non_iterable_raises_error():
"""Test that mapping a non-iterable raises RuntimeError."""
g = GraphBuilder(state_type=ExecutionState, output_type=int)
@g.step
async def return_non_iterable(ctx: StepContext[ExecutionState, None, None]) -> int:
return 42 # Not iterable!
@g.step
async def process_item(ctx: StepContext[ExecutionState, None, int]) -> int:
return ctx.inputs # pragma: no cover
g.add(
g.edge_from(g.start_node).to(return_non_iterable),
g.edge_from(return_non_iterable).map().to(process_item), # type: ignore # purposely have a type error here
g.edge_from(process_item).to(g.end_node),
)
graph = g.build()
with pytest.raises(RuntimeError, match='Cannot map non-iterable value'):
await graph.run(state=ExecutionState())
async def test_broadcast_marker_handling():
"""Test that BroadcastMarker is handled in paths"""
g = GraphBuilder(state_type=ExecutionState, output_type=list[str])
@g.step
async def source(ctx: StepContext[ExecutionState, None, None]) -> str:
return 'data'
@g.step
async def branch_a(ctx: StepContext[ExecutionState, None, str]) -> str:
return f'{ctx.inputs}-A'
@g.step
async def branch_b(ctx: StepContext[ExecutionState, None, str]) -> str:
return f'{ctx.inputs}-B'
collect = g.join(reduce_list_append, initial_factory=list[str])
g.add(
g.edge_from(g.start_node).to(source),
# Use multiple .to() destinations to create broadcast
g.edge_from(source).to(branch_a, branch_b),
g.edge_from(branch_a, branch_b).to(collect),
g.edge_from(collect).to(g.end_node),
)
graph = g.build()
result = await graph.run(state=ExecutionState())
assert sorted(result) == ['data-A', 'data-B']
async def test_nested_joins_with_different_fork_stacks():
"""Test nested joins with different fork stack depths"""
g = GraphBuilder(state_type=ExecutionState, output_type=list[int])
@g.step
async def generate_outer(ctx: StepContext[ExecutionState, None, None]) -> list[int]:
return [1, 2]
@g.step
async def generate_inner(ctx: StepContext[ExecutionState, None, int]) -> list[int]:
return [ctx.inputs * 10, ctx.inputs * 20]
@g.step
async def process(ctx: StepContext[ExecutionState, None, int]) -> int:
return ctx.inputs
final_join = g.join(reduce_list_append, initial_factory=list[int])
g.add(
g.edge_from(g.start_node).to(generate_outer),
g.edge_from(generate_outer).map().to(generate_inner),
g.edge_from(generate_inner).map().to(process),
g.edge_from(process).to(final_join),
g.edge_from(final_join).to(g.end_node),
)
graph = g.build()
result = await graph.run(state=ExecutionState())
# Should have 4 total elements (2 outer * 2 inner each)
assert len(result) == 4
assert sorted(result) == [10, 20, 20, 40]
async def test_reduce_first_value_task_cancellation():
"""Test that ReduceFirstValue properly cancels sibling tasks"""
import asyncio
g = GraphBuilder(state_type=ExecutionState, output_type=str)
@g.step
async def generate(ctx: StepContext[ExecutionState, None, None]) -> list[int]:
return [1, 2, 3, 4, 5]
@g.step
async def slow_process(ctx: StepContext[ExecutionState, None, int]) -> str:
if ctx.inputs == 1:
# First one completes quickly
await asyncio.sleep(0.01)
else:
# Others take longer (should be cancelled)
await asyncio.sleep(10)
ctx.state.log.append(f'completed-{ctx.inputs}')
return f'result-{ctx.inputs}'
first_join = g.join(ReduceFirstValue[str](), initial='')
g.add(
g.edge_from(g.start_node).to(generate),
g.edge_from(generate).map().to(slow_process),
g.edge_from(slow_process).to(first_join),
g.edge_from(first_join).to(g.end_node),
)
graph = g.build()
state = ExecutionState()
result = await graph.run(state=state)
# Should get a result
assert result is not None and result.startswith('result-')
# Not all tasks should have completed due to cancellation
assert len(state.log) < 5
async def test_empty_map_handling():
"""Test handling of mapping an empty iterable.
Note: Empty maps with joins can be tricky and may need the downstream_join_id hint.
This test documents expected behavior.
"""
# Skipping this test as empty maps need special handling with downstream_join_id
# The actual line coverage is achieved through other tests
pass
async def test_complex_fork_stack_with_multiple_levels():
"""Test complex scenarios with multiple fork levels"""
g = GraphBuilder(state_type=ExecutionState, output_type=list[int])
@g.step
async def level1(ctx: StepContext[ExecutionState, None, None]) -> list[int]:
return [1, 2]
@g.step
async def level2(ctx: StepContext[ExecutionState, None, int]) -> list[int]:
return [ctx.inputs * 10, ctx.inputs * 10 + 1]
@g.step
async def level3(ctx: StepContext[ExecutionState, None, int]) -> int:
ctx.state.log.append(f'processing-{ctx.inputs}')
return ctx.inputs
collect = g.join(reduce_list_append, initial_factory=list[int])
g.add(
g.edge_from(g.start_node).to(level1),
g.edge_from(level1).map().to(level2),
g.edge_from(level2).map().to(level3),
g.edge_from(level3).to(collect),
g.edge_from(collect).to(g.end_node),
)
graph = g.build()
state = ExecutionState()
result = await graph.run(state=state)
# Should process 4 items total (2 from level1 * 2 from each level2)
assert len(result) == 4
assert sorted(result) == [10, 11, 20, 21]
assert len(state.log) == 4
async def test_broadcast_with_immediate_join():
"""Test broadcast that immediately joins."""
g = GraphBuilder(state_type=ExecutionState, output_type=list[int])
@g.step
async def source(ctx: StepContext[ExecutionState, None, None]) -> int:
return 10
@g.step
async def path_a(ctx: StepContext[ExecutionState, None, int]) -> int:
return ctx.inputs * 2
@g.step
async def path_b(ctx: StepContext[ExecutionState, None, int]) -> int:
return ctx.inputs * 3
@g.step
async def path_c(ctx: StepContext[ExecutionState, None, int]) -> int:
return ctx.inputs * 4
collect = g.join(reduce_list_append, initial_factory=list[int])
g.add(
g.edge_from(g.start_node).to(source),
# Multiple .to() destinations creates a broadcast
g.edge_from(source).to(path_a, path_b, path_c),
g.edge_from(path_a, path_b, path_c).to(collect),
g.edge_from(collect).to(g.end_node),
)
graph = g.build()
result = await graph.run(state=ExecutionState())
assert sorted(result) == [20, 30, 40]
async def test_implicit_broadcast_with_immediate_join():
"""Test broadcast that immediately joins by just manually adding multiple edges from a single node."""
g = GraphBuilder(state_type=ExecutionState, output_type=list[int])
@g.step
async def source(ctx: StepContext[ExecutionState, None, None]) -> int:
return 10
@g.step
async def path_a(ctx: StepContext[ExecutionState, None, int]) -> int:
return ctx.inputs * 2
@g.step
async def path_b(ctx: StepContext[ExecutionState, None, int]) -> int:
return ctx.inputs * 3
@g.step
async def path_c(ctx: StepContext[ExecutionState, None, int]) -> int:
return ctx.inputs * 4
collect = g.join(reduce_list_append, initial_factory=list[int])
g.add(
g.edge_from(g.start_node).to(source),
# Multiple .to() destinations creates a broadcast
g.edge_from(source).to(path_a),
g.edge_from(source).to(path_b),
g.edge_from(source).to(path_c),
g.edge_from(path_a).to(collect),
g.edge_from(path_b).to(collect),
g.edge_from(path_c).to(collect),
g.edge_from(collect).to(g.end_node),
)
graph = g.build()
result = await graph.run(state=ExecutionState())
assert sorted(result) == [20, 30, 40]
async def test_mixed_sequential_and_parallel_execution():
"""Test graph with both sequential and parallel sections."""
g = GraphBuilder(state_type=ExecutionState, output_type=str)
@g.step
async def step1(ctx: StepContext[ExecutionState, None, None]) -> int:
ctx.state.log.append('step1')
return 5
@g.step
async def step2(ctx: StepContext[ExecutionState, None, int]) -> list[int]:
ctx.state.log.append('step2')
return [ctx.inputs * 10, ctx.inputs * 20]
@g.step
async def parallel_step(ctx: StepContext[ExecutionState, None, int]) -> int:
ctx.state.log.append(f'parallel-{ctx.inputs}')
return ctx.inputs + 1
@g.step
async def step3(ctx: StepContext[ExecutionState, None, list[int]]) -> str:
ctx.state.log.append('step3')
return f'Result: {sum(ctx.inputs)}'
collect = g.join(reduce_list_append, initial_factory=list[int])
g.add(
g.edge_from(g.start_node).to(step1),
g.edge_from(step1).to(step2),
g.edge_from(step2).map().to(parallel_step),
g.edge_from(parallel_step).to(collect),
g.edge_from(collect).to(step3),
g.edge_from(step3).to(g.end_node),
)
graph = g.build()
state = ExecutionState()
result = await graph.run(state=state)
assert 'step1' in state.log
assert 'step2' in state.log
assert 'parallel-50' in state.log
assert 'parallel-100' in state.log
assert 'step3' in state.log
assert result == 'Result: 152' # (50+1) + (100+1) = 152