from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any
from langgraph._internal._constants import RESERVED
from langgraph.channels.base import BaseChannel
from langgraph.managed.base import ManagedValueMapping
from langgraph.pregel._read import PregelNode
from langgraph.types import All
def validate_graph(
    nodes: Mapping[str, PregelNode],
    channels: dict[str, BaseChannel],
    managed: ManagedValueMapping,
    input_channels: str | Sequence[str],
    output_channels: str | Sequence[str],
    stream_channels: str | Sequence[str] | None,
    interrupt_after_nodes: All | Sequence[str],
    interrupt_before_nodes: All | Sequence[str],
) -> None:
    """Check that nodes, channels and managed values are mutually consistent.

    The checks run in this order and stop at the first failure:

    1. No channel, managed value or node uses a name found in ``RESERVED``.
    2. Every node is a ``PregelNode`` and only reads names that exist in
       ``channels`` (a sequence of reads may also name entries of
       ``managed``; a single string read must be a channel).
    3. Every trigger collected from the nodes is a known channel.
    4. Every input channel is known, and at least one of them is a trigger
       of some node.
    5. Every output and stream channel is known.
    6. Unless given as ``"*"``, every name listed in
       ``interrupt_after_nodes`` and ``interrupt_before_nodes`` is a key of
       ``nodes``.

    Args:
        nodes: Mapping of node name to node.
        channels: Mapping of channel name to channel.
        managed: Mapping whose keys are the managed value names.
        input_channels: A single channel name or a sequence of names.
        output_channels: A single channel name or a sequence of names.
        stream_channels: A single channel name, a sequence of names, or
            ``None`` to skip the stream check.
        interrupt_after_nodes: ``"*"`` or a sequence of node names.
        interrupt_before_nodes: ``"*"`` or a sequence of node names.

    Raises:
        ValueError: If any of the name checks above fails.
        TypeError: If a value in ``nodes`` is not a ``PregelNode``.
    """

    def known_channels() -> str:
        # Sorted channel names, truncated to 100 characters so that error
        # messages stay readable when there are many channels.
        return repr(sorted(channels))[:100]

    # --- 1. reserved names (channels, then managed values) ----------------
    for channel_name in channels:
        if channel_name in RESERVED:
            raise ValueError(f"Channel name '{channel_name}' is reserved")
    for managed_name in managed:
        if managed_name in RESERVED:
            raise ValueError(f"Managed name '{managed_name}' is reserved")

    # --- 2. per-node checks; triggers are gathered for steps 3 and 4 ------
    subscribed: set[str] = set()
    for node_name, node in nodes.items():
        if node_name in RESERVED:
            raise ValueError(f"Node name '{node_name}' is reserved")
        if not isinstance(node, PregelNode):
            raise TypeError(
                f"Invalid node type {type(node)}, expected PregelNode or NodeBuilder"
            )
        subscribed.update(node.triggers)
        reads = node.channels
        if isinstance(reads, str):
            # A single string read is only looked up in `channels`.
            if reads not in channels:
                raise ValueError(
                    f"Node {node_name} reads channel '{reads}' "
                    f"not in known channels: '{known_channels()}'"
                )
            continue
        for read_name in reads:
            # A sequence of reads may mix channels and managed values.
            if not (read_name in channels or read_name in managed):
                raise ValueError(
                    f"Node {node_name} reads channel '{read_name}' "
                    f"not in known channels: '{known_channels()}'"
                )

    # --- 3. every trigger must be a known channel --------------------------
    for trigger in subscribed:
        if trigger not in channels:
            raise ValueError(
                f"Subscribed channel '{trigger}' not "
                f"in known channels: '{known_channels()}'"
            )

    # --- 4. input channels --------------------------------------------------
    if isinstance(input_channels, str):
        if input_channels not in channels:
            raise ValueError(
                f"Input channel '{input_channels}' not "
                f"in known channels: '{known_channels()}'"
            )
        if input_channels not in subscribed:
            raise ValueError(
                f"Input channel {input_channels} is not subscribed to by any node"
            )
    else:
        for input_name in input_channels:
            if input_name not in channels:
                raise ValueError(
                    f"Input channel '{input_name}' not in '{known_channels()}'"
                )
        # An empty sequence also lands here: nothing can be subscribed to.
        if not any(input_name in subscribed for input_name in input_channels):
            raise ValueError(
                f"None of the input channels {input_channels} "
                "are subscribed to by any node"
            )

    # --- 5. output and stream channels share one membership check ----------
    emitted: set[str] = set()
    if isinstance(output_channels, str):
        emitted.add(output_channels)
    else:
        emitted.update(output_channels)
    if stream_channels is not None:
        if isinstance(stream_channels, str):
            emitted.add(stream_channels)
        else:
            emitted.update(stream_channels)
    for emitted_name in emitted:
        if emitted_name not in channels:
            raise ValueError(
                f"Output channel '{emitted_name}' not "
                f"in known channels: '{known_channels()}'"
            )

    # --- 6. interrupt lists: "after" is checked before "before" ------------
    for interrupt_nodes in (interrupt_after_nodes, interrupt_before_nodes):
        if interrupt_nodes != "*":
            for interrupt_name in interrupt_nodes:
                if interrupt_name not in nodes:
                    raise ValueError(f"Node {interrupt_name} not in nodes")
def validate_keys(
    keys: str | Sequence[str] | None,
    channels: Mapping[str, Any],
) -> None:
    """Ensure every requested key is present in ``channels``.

    Args:
        keys: A single key, a sequence of keys, or ``None``. ``None`` means
            there is nothing to check.
        channels: Mapping whose keys are the valid names.

    Raises:
        ValueError: On the first key that is not found in ``channels``.
    """
    if keys is None:
        return
    # Treat a lone string as a one-element list so both shapes share a loop.
    requested = [keys] if isinstance(keys, str) else keys
    for key in requested:
        if key not in channels:
            raise ValueError(f"Key {key} not in channels")