from __future__ import annotations
from collections import Counter
from collections.abc import Iterator, Mapping, Sequence
from typing import Any, Literal
from langgraph._internal._constants import (
ERROR,
INTERRUPT,
NULL_TASK_ID,
RESUME,
RETURN,
TASKS,
)
from langgraph._internal._typing import EMPTY_SEQ, MISSING
from langgraph.channels.base import BaseChannel, EmptyChannelError
from langgraph.constants import START, TAG_HIDDEN
from langgraph.errors import InvalidUpdateError
from langgraph.pregel._log import logger
from langgraph.types import Command, PregelExecutableTask, Send
def read_channel(
    channels: Mapping[str, BaseChannel],
    chan: str,
    *,
    catch: bool = True,
) -> Any:
    """Return the current value held by the channel named ``chan``.

    Args:
        channels: Mapping from channel name to channel object.
        chan: Name of the channel to read.
        catch: When true (the default), an empty channel yields ``None``;
            when false, the ``EmptyChannelError`` is propagated.

    Returns:
        Whatever ``channels[chan].get()`` returns, or ``None`` for an empty
        channel when ``catch`` is true.

    Raises:
        EmptyChannelError: If the channel is empty and ``catch`` is false.
    """
    try:
        return channels[chan].get()
    except EmptyChannelError:
        # Only swallow the "empty" signal when the caller asked for it.
        if not catch:
            raise
        return None
def read_channels(
    channels: Mapping[str, BaseChannel],
    select: Sequence[str] | str,
    *,
    skip_empty: bool = True,
) -> dict[str, Any] | Any:
    """Read one channel, or several channels into a dict.

    Args:
        channels: Mapping from channel name to channel object.
        select: A single channel name, or a sequence of channel names.
        skip_empty: Only used when ``select`` is a sequence. When true (the
            default), empty channels are left out of the result; when false,
            they appear with the value ``None``.

    Returns:
        For a single name, the value from ``read_channel`` (``None`` if the
        channel is empty). For a sequence, a dict of name -> value.
    """
    if isinstance(select, str):
        return read_channel(channels, select)

    collected: dict[str, Any] = {}
    for name in select:
        try:
            # With skip_empty the read is asked to raise on an empty channel,
            # so that the entry can be dropped below instead of set to None.
            collected[name] = read_channel(channels, name, catch=not skip_empty)
        except EmptyChannelError:
            continue
    return collected
def map_command(cmd: Command) -> Iterator[tuple[str, str, Any]]:
    """Translate a ``Command`` into pending writes.

    Each yielded item is a ``(task_id, channel, value)`` triple whose task id
    is always ``NULL_TASK_ID``. Items are produced in this order:

    * one per ``cmd.goto`` entry: a ``Send`` is written to ``TASKS``; a string
      node name is written as ``START`` to the ``branch:to:<name>`` channel;
    * ``cmd.resume`` written to ``RESUME``, when it is not ``None``;
    * one per ``(key, value)`` pair from ``cmd._update_as_tuples()``, when
      ``cmd.update`` is truthy.

    Raises:
        InvalidUpdateError: If ``cmd.graph`` equals ``Command.PARENT``.
        TypeError: If a ``goto`` entry is neither a ``Send`` nor a ``str``.
    """
    if cmd.graph == Command.PARENT:
        raise InvalidUpdateError("There is no parent graph")

    if cmd.goto:
        # Normalise a lone target into a one-element list.
        targets = cmd.goto if isinstance(cmd.goto, (tuple, list)) else [cmd.goto]
        for target in targets:
            if isinstance(target, Send):
                yield (NULL_TASK_ID, TASKS, target)
            elif isinstance(target, str):
                yield (NULL_TASK_ID, f"branch:to:{target}", START)
            else:
                raise TypeError(
                    f"In Command.goto, expected Send/str, got {type(target).__name__}"
                )

    if cmd.resume is not None:
        yield (NULL_TASK_ID, RESUME, cmd.resume)

    if cmd.update:
        for key, value in cmd._update_as_tuples():
            yield (NULL_TASK_ID, key, value)
def map_input(
    input_channels: str | Sequence[str],
    chunk: dict[str, Any] | Any | None,
) -> Iterator[tuple[str, Any]]:
    """Turn an input chunk into pending writes of the form ``(channel, value)``.

    Args:
        input_channels: A single channel name, or a sequence of channel names.
        chunk: The input. ``None`` produces no writes. With a single input
            channel the whole chunk is written to it. With several input
            channels the chunk must be a dict keyed by channel name.

    Yields:
        ``(channel, value)`` pairs. Dict keys that are not in
        ``input_channels`` are skipped with a logged warning.

    Raises:
        TypeError: If several input channels are given and ``chunk`` is not
            a dict.
    """
    if chunk is None:
        return

    if isinstance(input_channels, str):
        yield (input_channels, chunk)
        return

    if not isinstance(chunk, dict):
        raise TypeError(f"Expected chunk to be a dict, got {type(chunk).__name__}")

    for k in chunk:
        if k not in input_channels:
            logger.warning(f"Input channel {k} not found in {input_channels}")
            continue
        yield (k, chunk[k])
def map_output_values(
    output_channels: str | Sequence[str],
    pending_writes: Literal[True] | Sequence[tuple[str, Any]],
    channels: Mapping[str, BaseChannel],
) -> Iterator[dict[str, Any] | Any]:
    """Yield the current output value(s) if the pending writes touched them.

    Args:
        output_channels: A single channel name, or a sequence of channel names.
        pending_writes: Either the literal ``True`` (always emit) or a
            sequence of ``(channel, value)`` writes to inspect.
        channels: Mapping from channel name to channel object to read from.

    Yields:
        At most one item: the value of the single output channel, or a dict
        of the output channels' values. Nothing is yielded when
        ``pending_writes`` is a sequence that wrote to none of the output
        channels.
    """
    always_emit = pending_writes is True

    if isinstance(output_channels, str):
        if not always_emit and not any(
            name == output_channels for name, _ in pending_writes
        ):
            return
        yield read_channel(channels, output_channels)
        return

    if not always_emit:
        touched = {name for name, _ in pending_writes if name in output_channels}
        if not touched:
            return
    yield read_channels(channels, output_channels)
def map_output_updates(
    output_channels: str | Sequence[str],
    tasks: list[tuple[PregelExecutableTask, Sequence[tuple[str, Any]]]],
    cached: bool = False,
) -> Iterator[dict[str, Any | dict[str, Any]]]:
    """Map per-task writes to a single "updates" chunk keyed by task name.

    Args:
        output_channels: A single channel name, or a sequence of channel
            names, whose writes are reported.
        tasks: ``(task, writes)`` pairs, where ``writes`` is a sequence of
            ``(channel, value)`` tuples produced by that task.
        cached: When truthy, a ``"__metadata__"`` entry holding
            ``{"cached": cached}`` is added to the chunk.

    Yields:
        At most one dict mapping each reported task name to ``None`` (the
        task contributed no update), the update itself (exactly one update),
        or a list of updates (several). Nothing is yielded when no task
        qualifies for reporting.

    A task is left out when its config tags contain ``TAG_HIDDEN`` or when
    its first write targets ``ERROR`` or ``INTERRUPT``. A task whose writes
    sequence is empty is kept and reported as ``None``.
    """
    output_tasks = [
        (t, ww)
        for t, ww in tasks
        if (not t.config or TAG_HIDDEN not in t.config.get("tags", EMPTY_SEQ))
        # Guard the index: an empty writes sequence has no first write to
        # inspect and would otherwise raise IndexError.
        and (not ww or (ww[0][0] != ERROR and ww[0][0] != INTERRUPT))
    ]
    if not output_tasks:
        return

    updated: list[tuple[str, Any]] = []
    for task, writes in output_tasks:
        # A write to RETURN takes precedence over channel writes; only the
        # first such write is used.
        rtn = next((value for chan, value in writes if chan == RETURN), MISSING)
        if rtn is not MISSING:
            updated.append((task.name, rtn))
        elif isinstance(output_channels, str):
            # Single output channel: every write to it is its own update.
            updated.extend(
                (task.name, value) for chan, value in writes if chan == output_channels
            )
        elif any(chan in output_channels for chan, _ in writes):
            counts = Counter(chan for chan, _ in writes)
            if any(counts[chan] > 1 for chan in output_channels):
                # Some output channel was written more than once, so a single
                # dict would drop values; emit one dict per matching write.
                updated.extend(
                    (task.name, {chan: value})
                    for chan, value in writes
                    if chan in output_channels
                )
            else:
                updated.append(
                    (
                        task.name,
                        {
                            chan: value
                            for chan, value in writes
                            if chan in output_channels
                        },
                    )
                )

    # Every reported task gets an entry, even if it produced no update.
    grouped: dict[str, Any] = {t.name: [] for t, _ in output_tasks}
    for node, value in updated:
        grouped[node].append(value)
    # Collapse: no updates -> None, one update -> the update itself.
    for node, values in grouped.items():
        if not values:
            grouped[node] = None
        elif len(values) == 1:
            grouped[node] = values[0]
    if cached:
        grouped["__metadata__"] = {"cached": cached}
    yield grouped