Skip to main content
Glama
kaman05010

MCP Wikipedia Server

by kaman05010
main.py 129 kB
from __future__ import annotations import asyncio import concurrent import concurrent.futures import queue import warnings import weakref from collections import defaultdict, deque from collections.abc import AsyncIterator, Iterator, Mapping, Sequence from dataclasses import is_dataclass from functools import partial from inspect import isclass from typing import Any, Callable, Generic, Optional, Union, cast, get_type_hints from uuid import UUID, uuid5 from langchain_core.globals import get_debug from langchain_core.runnables import ( RunnableSequence, ) from langchain_core.runnables.base import Input, Output from langchain_core.runnables.config import ( RunnableConfig, get_async_callback_manager_for_config, get_callback_manager_for_config, ) from langchain_core.runnables.graph import Graph from pydantic import BaseModel, TypeAdapter from typing_extensions import Self, Unpack, deprecated, is_typeddict from langgraph._internal._config import ( ensure_config, merge_configs, patch_checkpoint_map, patch_config, patch_configurable, recast_checkpoint_ns, ) from langgraph._internal._constants import ( CACHE_NS_WRITES, CONF, CONFIG_KEY_CACHE, CONFIG_KEY_CHECKPOINT_ID, CONFIG_KEY_CHECKPOINT_NS, CONFIG_KEY_CHECKPOINTER, CONFIG_KEY_DURABILITY, CONFIG_KEY_NODE_FINISHED, CONFIG_KEY_READ, CONFIG_KEY_RUNNER_SUBMIT, CONFIG_KEY_RUNTIME, CONFIG_KEY_SEND, CONFIG_KEY_STREAM, CONFIG_KEY_TASK_ID, CONFIG_KEY_THREAD_ID, ERROR, INPUT, INTERRUPT, NS_END, NS_SEP, NULL_TASK_ID, PUSH, TASKS, ) from langgraph._internal._pydantic import create_model from langgraph._internal._queue import ( # type: ignore[attr-defined] AsyncQueue, SyncQueue, ) from langgraph._internal._runnable import ( Runnable, RunnableLike, RunnableSeq, coerce_to_runnable, ) from langgraph._internal._typing import MISSING, DeprecatedKwargs from langgraph.cache.base import BaseCache from langgraph.channels.base import BaseChannel from langgraph.channels.topic import Topic from langgraph.checkpoint.base import ( 
BaseCheckpointSaver, Checkpoint, CheckpointTuple, ) from langgraph.config import get_config from langgraph.constants import END from langgraph.errors import ( ErrorCode, GraphRecursionError, InvalidUpdateError, create_error_message, ) from langgraph.managed.base import ManagedValueSpec from langgraph.pregel._algo import ( PregelTaskWrites, _scratchpad, apply_writes, local_read, prepare_next_tasks, ) from langgraph.pregel._call import identifier from langgraph.pregel._checkpoint import ( channels_from_checkpoint, copy_checkpoint, create_checkpoint, empty_checkpoint, ) from langgraph.pregel._draw import draw_graph from langgraph.pregel._io import map_input, read_channels from langgraph.pregel._loop import AsyncPregelLoop, SyncPregelLoop from langgraph.pregel._messages import StreamMessagesHandler from langgraph.pregel._read import DEFAULT_BOUND, PregelNode from langgraph.pregel._retry import RetryPolicy from langgraph.pregel._runner import PregelRunner from langgraph.pregel._utils import get_new_channel_versions from langgraph.pregel._validate import validate_graph, validate_keys from langgraph.pregel._write import ChannelWrite, ChannelWriteEntry from langgraph.pregel.debug import get_bolded_text, get_colored_text, tasks_w_writes from langgraph.pregel.protocol import PregelProtocol, StreamChunk, StreamProtocol from langgraph.runtime import DEFAULT_RUNTIME, Runtime from langgraph.store.base import BaseStore from langgraph.types import ( All, CachePolicy, Checkpointer, Command, Durability, Interrupt, Send, StateSnapshot, StateUpdate, StreamMode, ) from langgraph.typing import ContextT, InputT, OutputT, StateT from langgraph.warnings import LangGraphDeprecatedSinceV10 try: from langchain_core.tracers._streaming import _StreamingCallbackHandler except ImportError: _StreamingCallbackHandler = None # type: ignore __all__ = ("NodeBuilder", "Pregel") _WriteValue = Union[Callable[[Input], Output], Any] class NodeBuilder: __slots__ = ( "_channels", "_triggers", "_tags", 
"_metadata", "_writes", "_bound", "_retry_policy", "_cache_policy", ) _channels: str | list[str] _triggers: list[str] _tags: list[str] _metadata: dict[str, Any] _writes: list[ChannelWriteEntry] _bound: Runnable _retry_policy: list[RetryPolicy] _cache_policy: CachePolicy | None def __init__( self, ) -> None: self._channels = [] self._triggers = [] self._tags = [] self._metadata = {} self._writes = [] self._bound = DEFAULT_BOUND self._retry_policy = [] self._cache_policy = None def subscribe_only( self, channel: str, ) -> Self: """Subscribe to a single channel.""" if not self._channels: self._channels = channel else: raise ValueError( "Cannot subscribe to single channels when other channels are already subscribed to" ) self._triggers.append(channel) return self def subscribe_to( self, *channels: str, read: bool = True, ) -> Self: """Add channels to subscribe to. Node will be invoked when any of these channels are updated, with a dict of the channel values as input. Args: channels: Channel name(s) to subscribe to read: If True, the channels will be included in the input to the node. Otherwise, they will trigger the node without being sent in input. 
Returns: Self for chaining """ if isinstance(self._channels, str): raise ValueError( "Cannot subscribe to channels when subscribed to a single channel" ) if read: if not self._channels: self._channels = list(channels) else: self._channels.extend(channels) if isinstance(channels, str): self._triggers.append(channels) else: self._triggers.extend(channels) return self def read_from( self, *channels: str, ) -> Self: """Adds the specified channels to read from, without subscribing to them.""" assert isinstance(self._channels, list), ( "Cannot read additional channels when subscribed to single channels" ) self._channels.extend(channels) return self def do( self, node: RunnableLike, ) -> Self: """Adds the specified node.""" if self._bound is not DEFAULT_BOUND: self._bound = RunnableSeq( self._bound, coerce_to_runnable(node, name=None, trace=True) ) else: self._bound = coerce_to_runnable(node, name=None, trace=True) return self def write_to( self, *channels: str | ChannelWriteEntry, **kwargs: _WriteValue, ) -> Self: """Add channel writes. 
Args: *channels: Channel names to write to **kwargs: Channel name and value mappings Returns: Self for chaining """ self._writes.extend( ChannelWriteEntry(c) if isinstance(c, str) else c for c in channels ) self._writes.extend( ChannelWriteEntry(k, mapper=v) if callable(v) else ChannelWriteEntry(k, value=v) for k, v in kwargs.items() ) return self def meta(self, *tags: str, **metadata: Any) -> Self: """Add tags or metadata to the node.""" self._tags.extend(tags) self._metadata.update(metadata) return self def add_retry_policies(self, *policies: RetryPolicy) -> Self: """Adds retry policies to the node.""" self._retry_policy.extend(policies) return self def add_cache_policy(self, policy: CachePolicy) -> Self: """Adds cache policies to the node.""" self._cache_policy = policy return self def build(self) -> PregelNode: """Builds the node.""" return PregelNode( channels=self._channels, triggers=self._triggers, tags=self._tags, metadata=self._metadata, writers=[ChannelWrite(self._writes)], bound=self._bound, retry_policy=self._retry_policy, cache_policy=self._cache_policy, ) class Pregel( PregelProtocol[StateT, ContextT, InputT, OutputT], Generic[StateT, ContextT, InputT, OutputT], ): """Pregel manages the runtime behavior for LangGraph applications. ## Overview Pregel combines [**actors**](https://en.wikipedia.org/wiki/Actor_model) and **channels** into a single application. **Actors** read data from channels and write data to channels. Pregel organizes the execution of the application into multiple steps, following the **Pregel Algorithm**/**Bulk Synchronous Parallel** model. Each step consists of three phases: - **Plan**: Determine which **actors** to execute in this step. For example, in the first step, select the **actors** that subscribe to the special **input** channels; in subsequent steps, select the **actors** that subscribe to channels updated in the previous step. 
- **Execution**: Execute all selected **actors** in parallel, until all complete, or one fails, or a timeout is reached. During this phase, channel updates are invisible to actors until the next step. - **Update**: Update the channels with the values written by the **actors** in this step. Repeat until no **actors** are selected for execution, or a maximum number of steps is reached. ## Actors An **actor** is a `PregelNode`. It subscribes to channels, reads data from them, and writes data to them. It can be thought of as an **actor** in the Pregel algorithm. `PregelNodes` implement LangChain's Runnable interface. ## Channels Channels are used to communicate between actors (`PregelNodes`). Each channel has a value type, an update type, and an update function – which takes a sequence of updates and modifies the stored value. Channels can be used to send data from one chain to another, or to send data from a chain to itself in a future step. LangGraph provides a number of built-in channels: ### Basic channels: LastValue and Topic - `LastValue`: The default channel, stores the last value sent to the channel, useful for input and output values, or for sending data from one step to the next - `Topic`: A configurable PubSub Topic, useful for sending multiple values between *actors*, or for accumulating output. Can be configured to deduplicate values, and/or to accumulate values over the course of multiple steps. ### Advanced channels: Context and BinaryOperatorAggregate - `Context`: exposes the value of a context manager, managing its lifecycle. Useful for accessing external resources that require setup and/or teardown. eg. `client = Context(httpx.Client)` - `BinaryOperatorAggregate`: stores a persistent value, updated by applying a binary operator to the current value and each update sent to the channel, useful for computing aggregates over multiple steps. eg. 
`total = BinaryOperatorAggregate(int, operator.add)` ## Examples Most users will interact with Pregel via a [StateGraph (Graph API)][langgraph.graph.StateGraph] or via an [entrypoint (Functional API)][langgraph.func.entrypoint]. However, for **advanced** use cases, Pregel can be used directly. If you're not sure whether you need to use Pregel directly, then the answer is probably no – you should use the Graph API or Functional API instead. These are higher-level interfaces that will compile down to Pregel under the hood. Here are some examples to give you a sense of how it works: Example: Single node application ```python from langgraph.channels import EphemeralValue from langgraph.pregel import Pregel, NodeBuilder node1 = ( NodeBuilder().subscribe_only("a") .do(lambda x: x + x) .write_to("b") ) app = Pregel( nodes={"node1": node1}, channels={ "a": EphemeralValue(str), "b": EphemeralValue(str), }, input_channels=["a"], output_channels=["b"], ) app.invoke({"a": "foo"}) ``` ```con {'b': 'foofoo'} ``` Example: Using multiple nodes and multiple output channels ```python from langgraph.channels import LastValue, EphemeralValue from langgraph.pregel import Pregel, NodeBuilder node1 = ( NodeBuilder().subscribe_only("a") .do(lambda x: x + x) .write_to("b") ) node2 = ( NodeBuilder().subscribe_to("b") .do(lambda x: x["b"] + x["b"]) .write_to("c") ) app = Pregel( nodes={"node1": node1, "node2": node2}, channels={ "a": EphemeralValue(str), "b": LastValue(str), "c": EphemeralValue(str), }, input_channels=["a"], output_channels=["b", "c"], ) app.invoke({"a": "foo"}) ``` ```con {'b': 'foofoo', 'c': 'foofoofoofoo'} ``` Example: Using a Topic channel ```python from langgraph.channels import LastValue, EphemeralValue, Topic from langgraph.pregel import Pregel, NodeBuilder node1 = ( NodeBuilder().subscribe_only("a") .do(lambda x: x + x) .write_to("b", "c") ) node2 = ( NodeBuilder().subscribe_only("b") .do(lambda x: x + x) .write_to("c") ) app = Pregel( nodes={"node1": node1, "node2": 
node2}, channels={ "a": EphemeralValue(str), "b": EphemeralValue(str), "c": Topic(str, accumulate=True), }, input_channels=["a"], output_channels=["c"], ) app.invoke({"a": "foo"}) ``` ```pycon {'c': ['foofoo', 'foofoofoofoo']} ``` Example: Using a BinaryOperatorAggregate channel ```python from langgraph.channels import EphemeralValue, BinaryOperatorAggregate from langgraph.pregel import Pregel, NodeBuilder node1 = ( NodeBuilder().subscribe_only("a") .do(lambda x: x + x) .write_to("b", "c") ) node2 = ( NodeBuilder().subscribe_only("b") .do(lambda x: x + x) .write_to("c") ) def reducer(current, update): if current: return current + " | " + update else: return update app = Pregel( nodes={"node1": node1, "node2": node2}, channels={ "a": EphemeralValue(str), "b": EphemeralValue(str), "c": BinaryOperatorAggregate(str, operator=reducer), }, input_channels=["a"], output_channels=["c"] ) app.invoke({"a": "foo"}) ``` ```con {'c': 'foofoo | foofoofoofoo'} ``` Example: Introducing a cycle This example demonstrates how to introduce a cycle in the graph, by having a chain write to a channel it subscribes to. Execution will continue until a None value is written to the channel. 
```python from langgraph.channels import EphemeralValue from langgraph.pregel import Pregel, NodeBuilder, ChannelWriteEntry example_node = ( NodeBuilder().subscribe_only("value") .do(lambda x: x + x if len(x) < 10 else None) .write_to(ChannelWriteEntry(channel="value", skip_none=True)) ) app = Pregel( nodes={"example_node": example_node}, channels={ "value": EphemeralValue(str), }, input_channels=["value"], output_channels=["value"] ) app.invoke({"value": "a"}) ``` ```con {'value': 'aaaaaaaaaaaaaaaa'} ``` """ nodes: dict[str, PregelNode] channels: dict[str, BaseChannel | ManagedValueSpec] stream_mode: StreamMode = "values" """Mode to stream output, defaults to 'values'.""" stream_eager: bool = False """Whether to force emitting stream events eagerly, automatically turned on for stream_mode "messages" and "custom".""" output_channels: str | Sequence[str] stream_channels: str | Sequence[str] | None = None """Channels to stream, defaults to all channels not in reserved channels""" interrupt_after_nodes: All | Sequence[str] interrupt_before_nodes: All | Sequence[str] input_channels: str | Sequence[str] step_timeout: float | None = None """Maximum time to wait for a step to complete, in seconds. Defaults to None.""" debug: bool """Whether to print debug information during execution. Defaults to False.""" checkpointer: Checkpointer = None """Checkpointer used to save and load graph state. Defaults to None.""" store: BaseStore | None = None """Memory store to use for SharedValues. Defaults to None.""" cache: BaseCache | None = None """Cache to use for storing node results. Defaults to None.""" retry_policy: Sequence[RetryPolicy] = () """Retry policies to use when running tasks. Empty set disables retries.""" cache_policy: CachePolicy | None = None """Cache policy to use for all nodes. Can be overridden by individual nodes. 
Defaults to None.""" context_schema: type[ContextT] | None = None """Specifies the schema for the context object that will be passed to the workflow.""" config: RunnableConfig | None = None name: str = "LangGraph" trigger_to_nodes: Mapping[str, Sequence[str]] def __init__( self, *, nodes: dict[str, PregelNode | NodeBuilder], channels: dict[str, BaseChannel | ManagedValueSpec] | None, auto_validate: bool = True, stream_mode: StreamMode = "values", stream_eager: bool = False, output_channels: str | Sequence[str], stream_channels: str | Sequence[str] | None = None, interrupt_after_nodes: All | Sequence[str] = (), interrupt_before_nodes: All | Sequence[str] = (), input_channels: str | Sequence[str], step_timeout: float | None = None, debug: bool | None = None, checkpointer: BaseCheckpointSaver | None = None, store: BaseStore | None = None, cache: BaseCache | None = None, retry_policy: RetryPolicy | Sequence[RetryPolicy] = (), cache_policy: CachePolicy | None = None, context_schema: type[ContextT] | None = None, config: RunnableConfig | None = None, trigger_to_nodes: Mapping[str, Sequence[str]] | None = None, name: str = "LangGraph", **deprecated_kwargs: Unpack[DeprecatedKwargs], ) -> None: if ( config_type := deprecated_kwargs.get("config_type"), MISSING, ) is not MISSING: warnings.warn( "`config_type` is deprecated and will be removed. Please use `context_schema` instead.", category=LangGraphDeprecatedSinceV10, stacklevel=2, ) if context_schema is None: context_schema = cast(type[ContextT], config_type) self.nodes = { k: v.build() if isinstance(v, NodeBuilder) else v for k, v in nodes.items() } self.channels = channels or {} if TASKS in self.channels and not isinstance(self.channels[TASKS], Topic): raise ValueError( f"Channel '{TASKS}' is reserved and cannot be used in the graph." 
) else: self.channels[TASKS] = Topic(Send, accumulate=False) self.stream_mode = stream_mode self.stream_eager = stream_eager self.output_channels = output_channels self.stream_channels = stream_channels self.interrupt_after_nodes = interrupt_after_nodes self.interrupt_before_nodes = interrupt_before_nodes self.input_channels = input_channels self.step_timeout = step_timeout self.debug = debug if debug is not None else get_debug() self.checkpointer = checkpointer self.store = store self.cache = cache self.retry_policy = ( (retry_policy,) if isinstance(retry_policy, RetryPolicy) else retry_policy ) self.cache_policy = cache_policy self.context_schema = context_schema self.config = config self.trigger_to_nodes = trigger_to_nodes or {} self.name = name if auto_validate: self.validate() def get_graph( self, config: RunnableConfig | None = None, *, xray: int | bool = False ) -> Graph: """Return a drawable representation of the computation graph.""" # gather subgraphs if xray: subgraphs = { k: v.get_graph( config, xray=xray if isinstance(xray, bool) or xray <= 0 else xray - 1, ) for k, v in self.get_subgraphs() } else: subgraphs = {} return draw_graph( merge_configs(self.config, config), nodes=self.nodes, specs=self.channels, input_channels=self.input_channels, interrupt_after_nodes=self.interrupt_after_nodes, interrupt_before_nodes=self.interrupt_before_nodes, trigger_to_nodes=self.trigger_to_nodes, checkpointer=self.checkpointer, subgraphs=subgraphs, ) async def aget_graph( self, config: RunnableConfig | None = None, *, xray: int | bool = False ) -> Graph: """Return a drawable representation of the computation graph.""" # gather subgraphs if xray: subpregels: dict[str, PregelProtocol] = { k: v async for k, v in self.aget_subgraphs() } subgraphs = { k: v for k, v in zip( subpregels, await asyncio.gather( *( p.aget_graph( config, xray=xray if isinstance(xray, bool) or xray <= 0 else xray - 1, ) for p in subpregels.values() ) ), ) } else: subgraphs = {} return draw_graph( 
merge_configs(self.config, config), nodes=self.nodes, specs=self.channels, input_channels=self.input_channels, interrupt_after_nodes=self.interrupt_after_nodes, interrupt_before_nodes=self.interrupt_before_nodes, trigger_to_nodes=self.trigger_to_nodes, checkpointer=self.checkpointer, subgraphs=subgraphs, ) def _repr_mimebundle_(self, **kwargs: Any) -> dict[str, Any]: """Mime bundle used by Jupyter to display the graph""" return { "text/plain": repr(self), "image/png": self.get_graph().draw_mermaid_png(), } def copy(self, update: dict[str, Any] | None = None) -> Self: attrs = {k: v for k, v in self.__dict__.items() if k != "__orig_class__"} attrs.update(update or {}) return self.__class__(**attrs) def with_config(self, config: RunnableConfig | None = None, **kwargs: Any) -> Self: """Create a copy of the Pregel object with an updated config.""" return self.copy( {"config": merge_configs(self.config, config, cast(RunnableConfig, kwargs))} ) def validate(self) -> Self: validate_graph( self.nodes, {k: v for k, v in self.channels.items() if isinstance(v, BaseChannel)}, {k: v for k, v in self.channels.items() if not isinstance(v, BaseChannel)}, self.input_channels, self.output_channels, self.stream_channels, self.interrupt_after_nodes, self.interrupt_before_nodes, ) self.trigger_to_nodes = _trigger_to_nodes(self.nodes) return self @deprecated( "`config_schema` is deprecated. Use `get_context_jsonschema` for the relevant schema instead." ) def config_schema(self, *, include: Sequence[str] | None = None) -> type[BaseModel]: warnings.warn( "`config_schema` is deprecated. 
Use `get_context_jsonschema` for the relevant schema instead.", category=LangGraphDeprecatedSinceV10, stacklevel=2, ) include = include or [] fields = { **( {"configurable": (self.context_schema, None)} if self.context_schema else {} ), **{ field_name: (field_type, None) for field_name, field_type in get_type_hints(RunnableConfig).items() if field_name in [i for i in include if i != "configurable"] }, } return create_model(self.get_name("Config"), field_definitions=fields) @deprecated( "`get_config_jsonschema` is deprecated. Use `get_context_jsonschema` instead." ) def get_config_jsonschema( self, *, include: Sequence[str] | None = None ) -> dict[str, Any]: warnings.warn( "`get_config_jsonschema` is deprecated. Use `get_context_jsonschema` instead.", category=LangGraphDeprecatedSinceV10, stacklevel=2, ) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=LangGraphDeprecatedSinceV10) schema = self.config_schema(include=include) return schema.model_json_schema() def get_context_jsonschema(self) -> dict[str, Any] | None: if (context_schema := self.context_schema) is None: return None if isclass(context_schema) and issubclass(context_schema, BaseModel): return context_schema.model_json_schema() elif is_typeddict(context_schema) or is_dataclass(context_schema): return TypeAdapter(context_schema).json_schema() else: raise ValueError( f"Invalid context schema type: {context_schema}. Must be a BaseModel, TypedDict or dataclass." 
) @property def InputType(self) -> Any: if isinstance(self.input_channels, str): channel = self.channels[self.input_channels] if isinstance(channel, BaseChannel): return channel.UpdateType def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: config = merge_configs(self.config, config) if isinstance(self.input_channels, str): return super().get_input_schema(config) else: return create_model( self.get_name("Input"), field_definitions={ k: (c.UpdateType, None) for k in self.input_channels or self.channels.keys() if (c := self.channels[k]) and isinstance(c, BaseChannel) }, ) def get_input_jsonschema( self, config: RunnableConfig | None = None ) -> dict[str, Any]: schema = self.get_input_schema(config) return schema.model_json_schema() @property def OutputType(self) -> Any: if isinstance(self.output_channels, str): channel = self.channels[self.output_channels] if isinstance(channel, BaseChannel): return channel.ValueType def get_output_schema( self, config: RunnableConfig | None = None ) -> type[BaseModel]: config = merge_configs(self.config, config) if isinstance(self.output_channels, str): return super().get_output_schema(config) else: return create_model( self.get_name("Output"), field_definitions={ k: (c.ValueType, None) for k in self.output_channels if (c := self.channels[k]) and isinstance(c, BaseChannel) }, ) def get_output_jsonschema( self, config: RunnableConfig | None = None ) -> dict[str, Any]: schema = self.get_output_schema(config) return schema.model_json_schema() @property def stream_channels_list(self) -> Sequence[str]: stream_channels = self.stream_channels_asis return ( [stream_channels] if isinstance(stream_channels, str) else stream_channels ) @property def stream_channels_asis(self) -> str | Sequence[str]: return self.stream_channels or [ k for k in self.channels if isinstance(self.channels[k], BaseChannel) ] def get_subgraphs( self, *, namespace: str | None = None, recurse: bool = False ) -> Iterator[tuple[str, 
PregelProtocol]]: """Get the subgraphs of the graph. Args: namespace: The namespace to filter the subgraphs by. recurse: Whether to recurse into the subgraphs. If False, only the immediate subgraphs will be returned. Returns: Iterator[tuple[str, PregelProtocol]]: An iterator of the (namespace, subgraph) pairs. """ for name, node in self.nodes.items(): # filter by prefix if namespace is not None: if not namespace.startswith(name): continue # find the subgraph, if any graph = node.subgraphs[0] if node.subgraphs else None # if found, yield recursively if graph: if name == namespace: yield name, graph return # we found it, stop searching if namespace is None: yield name, graph if recurse and isinstance(graph, Pregel): if namespace is not None: namespace = namespace[len(name) + 1 :] yield from ( (f"{name}{NS_SEP}{n}", s) for n, s in graph.get_subgraphs( namespace=namespace, recurse=recurse ) ) async def aget_subgraphs( self, *, namespace: str | None = None, recurse: bool = False ) -> AsyncIterator[tuple[str, PregelProtocol]]: """Get the subgraphs of the graph. Args: namespace: The namespace to filter the subgraphs by. recurse: Whether to recurse into the subgraphs. If False, only the immediate subgraphs will be returned. Returns: AsyncIterator[tuple[str, PregelProtocol]]: An iterator of the (namespace, subgraph) pairs. 
""" for name, node in self.get_subgraphs(namespace=namespace, recurse=recurse): yield name, node def _migrate_checkpoint(self, checkpoint: Checkpoint) -> None: """Migrate a saved checkpoint to new channel layout.""" if checkpoint["v"] < 4 and checkpoint.get("pending_sends"): pending_sends: list[Send] = checkpoint.pop("pending_sends") checkpoint["channel_values"][TASKS] = pending_sends checkpoint["channel_versions"][TASKS] = max( checkpoint["channel_versions"].values() ) def _prepare_state_snapshot( self, config: RunnableConfig, saved: CheckpointTuple | None, recurse: BaseCheckpointSaver | None = None, apply_pending_writes: bool = False, ) -> StateSnapshot: if not saved: return StateSnapshot( values={}, next=(), config=config, metadata=None, created_at=None, parent_config=None, tasks=(), interrupts=(), ) # migrate checkpoint if needed self._migrate_checkpoint(saved.checkpoint) step = saved.metadata.get("step", -1) + 1 stop = step + 2 channels, managed = channels_from_checkpoint( self.channels, saved.checkpoint, ) # tasks for this checkpoint next_tasks = prepare_next_tasks( saved.checkpoint, saved.pending_writes or [], self.nodes, channels, managed, saved.config, step, stop, for_execution=True, store=self.store, checkpointer=( self.checkpointer if isinstance(self.checkpointer, BaseCheckpointSaver) else None ), manager=None, ) # get the subgraphs subgraphs = dict(self.get_subgraphs()) parent_ns = saved.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") task_states: dict[str, RunnableConfig | StateSnapshot] = {} for task in next_tasks.values(): if task.name not in subgraphs: continue # assemble checkpoint_ns for this task task_ns = f"{task.name}{NS_END}{task.id}" if parent_ns: task_ns = f"{parent_ns}{NS_SEP}{task_ns}" if not recurse: # set config as signal that subgraph checkpoints exist config = { CONF: { "thread_id": saved.config[CONF]["thread_id"], CONFIG_KEY_CHECKPOINT_NS: task_ns, } } task_states[task.id] = config else: # get the state of the subgraph config = { CONF: 
{ CONFIG_KEY_CHECKPOINTER: recurse, "thread_id": saved.config[CONF]["thread_id"], CONFIG_KEY_CHECKPOINT_NS: task_ns, } } task_states[task.id] = subgraphs[task.name].get_state( config, subgraphs=True ) # apply pending writes if null_writes := [ w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID ]: apply_writes( saved.checkpoint, channels, [PregelTaskWrites((), INPUT, null_writes, [])], None, self.trigger_to_nodes, ) if apply_pending_writes and saved.pending_writes: for tid, k, v in saved.pending_writes: if k in (ERROR, INTERRUPT): continue if tid not in next_tasks: continue next_tasks[tid].writes.append((k, v)) if tasks := [t for t in next_tasks.values() if t.writes]: apply_writes( saved.checkpoint, channels, tasks, None, self.trigger_to_nodes ) tasks_with_writes = tasks_w_writes( next_tasks.values(), saved.pending_writes, task_states, self.stream_channels_asis, ) # assemble the state snapshot return StateSnapshot( read_channels(channels, self.stream_channels_asis), tuple(t.name for t in next_tasks.values() if not t.writes), patch_checkpoint_map(saved.config, saved.metadata), saved.metadata, saved.checkpoint["ts"], patch_checkpoint_map(saved.parent_config, saved.metadata), tasks_with_writes, tuple([i for task in tasks_with_writes for i in task.interrupts]), ) async def _aprepare_state_snapshot( self, config: RunnableConfig, saved: CheckpointTuple | None, recurse: BaseCheckpointSaver | None = None, apply_pending_writes: bool = False, ) -> StateSnapshot: if not saved: return StateSnapshot( values={}, next=(), config=config, metadata=None, created_at=None, parent_config=None, tasks=(), interrupts=(), ) # migrate checkpoint if needed self._migrate_checkpoint(saved.checkpoint) step = saved.metadata.get("step", -1) + 1 stop = step + 2 channels, managed = channels_from_checkpoint( self.channels, saved.checkpoint, ) # tasks for this checkpoint next_tasks = prepare_next_tasks( saved.checkpoint, saved.pending_writes or [], self.nodes, channels, managed, 
saved.config, step, stop, for_execution=True, store=self.store, checkpointer=( self.checkpointer if isinstance(self.checkpointer, BaseCheckpointSaver) else None ), manager=None, ) # get the subgraphs subgraphs = {n: g async for n, g in self.aget_subgraphs()} parent_ns = saved.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") task_states: dict[str, RunnableConfig | StateSnapshot] = {} for task in next_tasks.values(): if task.name not in subgraphs: continue # assemble checkpoint_ns for this task task_ns = f"{task.name}{NS_END}{task.id}" if parent_ns: task_ns = f"{parent_ns}{NS_SEP}{task_ns}" if not recurse: # set config as signal that subgraph checkpoints exist config = { CONF: { "thread_id": saved.config[CONF]["thread_id"], CONFIG_KEY_CHECKPOINT_NS: task_ns, } } task_states[task.id] = config else: # get the state of the subgraph config = { CONF: { CONFIG_KEY_CHECKPOINTER: recurse, "thread_id": saved.config[CONF]["thread_id"], CONFIG_KEY_CHECKPOINT_NS: task_ns, } } task_states[task.id] = await subgraphs[task.name].aget_state( config, subgraphs=True ) # apply pending writes if null_writes := [ w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID ]: apply_writes( saved.checkpoint, channels, [PregelTaskWrites((), INPUT, null_writes, [])], None, self.trigger_to_nodes, ) if apply_pending_writes and saved.pending_writes: for tid, k, v in saved.pending_writes: if k in (ERROR, INTERRUPT): continue if tid not in next_tasks: continue next_tasks[tid].writes.append((k, v)) if tasks := [t for t in next_tasks.values() if t.writes]: apply_writes( saved.checkpoint, channels, tasks, None, self.trigger_to_nodes ) tasks_with_writes = tasks_w_writes( next_tasks.values(), saved.pending_writes, task_states, self.stream_channels_asis, ) # assemble the state snapshot return StateSnapshot( read_channels(channels, self.stream_channels_asis), tuple(t.name for t in next_tasks.values() if not t.writes), patch_checkpoint_map(saved.config, saved.metadata), saved.metadata, 
saved.checkpoint["ts"], patch_checkpoint_map(saved.parent_config, saved.metadata), tasks_with_writes, tuple([i for task in tasks_with_writes for i in task.interrupts]), ) def get_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: """Get the current state of the graph.""" checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") if ( checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast = recast_checkpoint_ns(checkpoint_ns) # find the subgraph with the matching name for _, pregel in self.get_subgraphs(namespace=recast, recurse=True): return pregel.get_state( patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}), subgraphs=subgraphs, ) else: raise ValueError(f"Subgraph {recast} not found") config = merge_configs(self.config, config) if self.config else config if self.checkpointer is True: ns = cast(str, config[CONF][CONFIG_KEY_CHECKPOINT_NS]) config = merge_configs( config, {CONF: {CONFIG_KEY_CHECKPOINT_NS: recast_checkpoint_ns(ns)}} ) thread_id = config[CONF][CONFIG_KEY_THREAD_ID] if not isinstance(thread_id, str): config[CONF][CONFIG_KEY_THREAD_ID] = str(thread_id) saved = checkpointer.get_tuple(config) return self._prepare_state_snapshot( config, saved, recurse=checkpointer if subgraphs else None, apply_pending_writes=CONFIG_KEY_CHECKPOINT_ID not in config[CONF], ) async def aget_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: """Get the current state of the graph.""" checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") if ( checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") ) and CONFIG_KEY_CHECKPOINTER not in 
config[CONF]: # remove task_ids from checkpoint_ns recast = recast_checkpoint_ns(checkpoint_ns) # find the subgraph with the matching name async for _, pregel in self.aget_subgraphs(namespace=recast, recurse=True): return await pregel.aget_state( patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}), subgraphs=subgraphs, ) else: raise ValueError(f"Subgraph {recast} not found") config = merge_configs(self.config, config) if self.config else config if self.checkpointer is True: ns = cast(str, config[CONF][CONFIG_KEY_CHECKPOINT_NS]) config = merge_configs( config, {CONF: {CONFIG_KEY_CHECKPOINT_NS: recast_checkpoint_ns(ns)}} ) thread_id = config[CONF][CONFIG_KEY_THREAD_ID] if not isinstance(thread_id, str): config[CONF][CONFIG_KEY_THREAD_ID] = str(thread_id) saved = await checkpointer.aget_tuple(config) return await self._aprepare_state_snapshot( config, saved, recurse=checkpointer if subgraphs else None, apply_pending_writes=CONFIG_KEY_CHECKPOINT_ID not in config[CONF], ) def get_state_history( self, config: RunnableConfig, *, filter: dict[str, Any] | None = None, before: RunnableConfig | None = None, limit: int | None = None, ) -> Iterator[StateSnapshot]: """Get the history of the state of the graph.""" config = ensure_config(config) checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") if ( checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast = recast_checkpoint_ns(checkpoint_ns) # find the subgraph with the matching name for _, pregel in self.get_subgraphs(namespace=recast, recurse=True): yield from pregel.get_state_history( patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}), filter=filter, before=before, limit=limit, ) return else: raise ValueError(f"Subgraph {recast} not found") config = merge_configs( 
self.config, config, { CONF: { CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns, CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID]), } }, ) # eagerly consume list() to avoid holding up the db cursor for checkpoint_tuple in list( checkpointer.list(config, before=before, limit=limit, filter=filter) ): yield self._prepare_state_snapshot( checkpoint_tuple.config, checkpoint_tuple ) async def aget_state_history( self, config: RunnableConfig, *, filter: dict[str, Any] | None = None, before: RunnableConfig | None = None, limit: int | None = None, ) -> AsyncIterator[StateSnapshot]: """Asynchronously get the history of the state of the graph.""" config = ensure_config(config) checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") if ( checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast = recast_checkpoint_ns(checkpoint_ns) # find the subgraph with the matching name async for _, pregel in self.aget_subgraphs(namespace=recast, recurse=True): async for state in pregel.aget_state_history( patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}), filter=filter, before=before, limit=limit, ): yield state return else: raise ValueError(f"Subgraph {recast} not found") config = merge_configs( self.config, config, { CONF: { CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns, CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID]), } }, ) # eagerly consume list() to avoid holding up the db cursor for checkpoint_tuple in [ c async for c in checkpointer.alist( config, before=before, limit=limit, filter=filter ) ]: yield await self._aprepare_state_snapshot( checkpoint_tuple.config, checkpoint_tuple ) def bulk_update_state( self, config: RunnableConfig, supersteps: Sequence[Sequence[StateUpdate]], ) -> RunnableConfig: """Apply updates to the graph 
state in bulk. Requires a checkpointer to be set. Args: config: The config to apply the updates to. supersteps: A list of supersteps, each including a list of updates to apply sequentially to a graph state. Each update is a tuple of the form `(values, as_node, task_id)` where task_id is optional. Raises: ValueError: If no checkpointer is set or no updates are provided. InvalidUpdateError: If an invalid update is provided. Returns: RunnableConfig: The updated config. """ checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") if len(supersteps) == 0: raise ValueError("No supersteps provided") if any(len(u) == 0 for u in supersteps): raise ValueError("No updates provided") # delegate to subgraph if ( checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast = recast_checkpoint_ns(checkpoint_ns) # find the subgraph with the matching name for _, pregel in self.get_subgraphs(namespace=recast, recurse=True): return pregel.bulk_update_state( patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}), supersteps, ) else: raise ValueError(f"Subgraph {recast} not found") def perform_superstep( input_config: RunnableConfig, updates: Sequence[StateUpdate] ) -> RunnableConfig: # get last checkpoint config = ensure_config(self.config, input_config) saved = checkpointer.get_tuple(config) if saved is not None: self._migrate_checkpoint(saved.checkpoint) checkpoint = ( copy_checkpoint(saved.checkpoint) if saved else empty_checkpoint() ) checkpoint_previous_versions = ( saved.checkpoint["channel_versions"].copy() if saved else {} ) step = saved.metadata.get("step", -1) if saved else -1 # merge configurable fields with previous checkpoint config checkpoint_config = patch_configurable( config, { CONFIG_KEY_CHECKPOINT_NS: config[CONF].get( 
CONFIG_KEY_CHECKPOINT_NS, "" ) }, ) if saved: checkpoint_config = patch_configurable(config, saved.config[CONF]) channels, managed = channels_from_checkpoint( self.channels, checkpoint, ) values, as_node = updates[0][:2] # no values as END, just clear all tasks if values is None and as_node == END: if len(updates) > 1: raise InvalidUpdateError( "Cannot apply multiple updates when clearing state" ) if saved is not None: # tasks for this checkpoint next_tasks = prepare_next_tasks( checkpoint, saved.pending_writes or [], self.nodes, channels, managed, saved.config, step + 1, step + 3, for_execution=True, store=self.store, checkpointer=checkpointer, manager=None, ) # apply null writes if null_writes := [ w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID ]: apply_writes( checkpoint, channels, [PregelTaskWrites((), INPUT, null_writes, [])], checkpointer.get_next_version, self.trigger_to_nodes, ) # apply writes from tasks that already ran for tid, k, v in saved.pending_writes or []: if k in (ERROR, INTERRUPT): continue if tid not in next_tasks: continue next_tasks[tid].writes.append((k, v)) # clear all current tasks apply_writes( checkpoint, channels, next_tasks.values(), checkpointer.get_next_version, self.trigger_to_nodes, ) # save checkpoint next_config = checkpointer.put( checkpoint_config, create_checkpoint(checkpoint, channels, step), { "source": "update", "step": step + 1, "parents": saved.metadata.get("parents", {}) if saved else {}, }, get_new_channel_versions( checkpoint_previous_versions, checkpoint["channel_versions"], ), ) return patch_checkpoint_map( next_config, saved.metadata if saved else None ) # act as an input if as_node == INPUT: if len(updates) > 1: raise InvalidUpdateError( "Cannot apply multiple updates when updating as input" ) if input_writes := deque(map_input(self.input_channels, values)): apply_writes( checkpoint, channels, [PregelTaskWrites((), INPUT, input_writes, [])], checkpointer.get_next_version, self.trigger_to_nodes, ) 
# apply input write to channels next_step = ( step + 1 if saved and saved.metadata.get("step") is not None else -1 ) next_config = checkpointer.put( checkpoint_config, create_checkpoint(checkpoint, channels, next_step), { "source": "input", "step": next_step, "parents": saved.metadata.get("parents", {}) if saved else {}, }, get_new_channel_versions( checkpoint_previous_versions, checkpoint["channel_versions"], ), ) # store the writes checkpointer.put_writes( next_config, input_writes, str(uuid5(UUID(checkpoint["id"]), INPUT)), ) return patch_checkpoint_map( next_config, saved.metadata if saved else None ) else: raise InvalidUpdateError( f"Received no input writes for {self.input_channels}" ) # copy checkpoint if as_node == "__copy__": if len(updates) > 1: raise InvalidUpdateError( "Cannot copy checkpoint with multiple updates" ) if saved is None: raise InvalidUpdateError("Cannot copy a non-existent checkpoint") next_checkpoint = create_checkpoint(checkpoint, None, step) # copy checkpoint next_config = checkpointer.put( saved.parent_config or patch_configurable( saved.config, {CONFIG_KEY_CHECKPOINT_ID: None} ), next_checkpoint, { "source": "fork", "step": step + 1, "parents": saved.metadata.get("parents", {}), }, {}, ) # we want to both clone a checkpoint and update state in one go. # reuse the same task ID if possible. 
if isinstance(values, list) and len(values) > 0: # figure out the task IDs for the next update checkpoint next_tasks = prepare_next_tasks( next_checkpoint, saved.pending_writes or [], self.nodes, channels, managed, next_config, step + 2, step + 4, for_execution=True, store=self.store, checkpointer=checkpointer, manager=None, ) tasks_group_by = defaultdict(list) user_group_by: dict[str, list[StateUpdate]] = defaultdict(list) for task in next_tasks.values(): tasks_group_by[task.name].append(task.id) for item in values: if not isinstance(item, Sequence): raise InvalidUpdateError( f"Invalid update item: {item} when copying checkpoint" ) values, as_node = item[:2] user_group = user_group_by[as_node] tasks_group = tasks_group_by[as_node] target_idx = len(user_group) task_id = ( tasks_group[target_idx] if target_idx < len(tasks_group) else None ) user_group_by[as_node].append( StateUpdate(values=values, as_node=as_node, task_id=task_id) ) return perform_superstep( patch_checkpoint_map(next_config, saved.metadata), [item for lst in user_group_by.values() for item in lst], ) return patch_checkpoint_map(next_config, saved.metadata) # apply pending writes, if not on specific checkpoint if ( CONFIG_KEY_CHECKPOINT_ID not in config[CONF] and saved is not None and saved.pending_writes ): # tasks for this checkpoint next_tasks = prepare_next_tasks( checkpoint, saved.pending_writes, self.nodes, channels, managed, saved.config, step + 1, step + 3, for_execution=True, store=self.store, checkpointer=checkpointer, manager=None, ) # apply null writes if null_writes := [ w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID ]: apply_writes( checkpoint, channels, [PregelTaskWrites((), INPUT, null_writes, [])], checkpointer.get_next_version, self.trigger_to_nodes, ) # apply writes for tid, k, v in saved.pending_writes: if k in (ERROR, INTERRUPT): continue if tid not in next_tasks: continue next_tasks[tid].writes.append((k, v)) if tasks := [t for t in next_tasks.values() if 
t.writes]: apply_writes( checkpoint, channels, tasks, checkpointer.get_next_version, self.trigger_to_nodes, ) valid_updates: list[tuple[str, dict[str, Any] | None, str | None]] = [] if len(updates) == 1: values, as_node, task_id = updates[0] # find last node that updated the state, if not provided if as_node is None and len(self.nodes) == 1: as_node = tuple(self.nodes)[0] elif as_node is None and not any( v for vv in checkpoint["versions_seen"].values() for v in vv.values() ): if ( isinstance(self.input_channels, str) and self.input_channels in self.nodes ): as_node = self.input_channels elif as_node is None: last_seen_by_node = sorted( (v, n) for n, seen in checkpoint["versions_seen"].items() if n in self.nodes for v in seen.values() ) # if two nodes updated the state at the same time, it's ambiguous if last_seen_by_node: if len(last_seen_by_node) == 1: as_node = last_seen_by_node[0][1] elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]: as_node = last_seen_by_node[-1][1] if as_node is None: raise InvalidUpdateError("Ambiguous update, specify as_node") if as_node not in self.nodes: raise InvalidUpdateError(f"Node {as_node} does not exist") valid_updates.append((as_node, values, task_id)) else: for values, as_node, task_id in updates: if as_node is None: raise InvalidUpdateError( "as_node is required when applying multiple updates" ) if as_node not in self.nodes: raise InvalidUpdateError(f"Node {as_node} does not exist") valid_updates.append((as_node, values, task_id)) run_tasks: list[PregelTaskWrites] = [] run_task_ids: list[str] = [] for as_node, values, provided_task_id in valid_updates: # create task to run all writers of the chosen node writers = self.nodes[as_node].flat_writers if not writers: raise InvalidUpdateError(f"Node {as_node} has no writers") writes: deque[tuple[str, Any]] = deque() task = PregelTaskWrites((), as_node, writes, [INTERRUPT]) task_id = provided_task_id or str( uuid5(UUID(checkpoint["id"]), INTERRUPT) ) run_tasks.append(task) 
run_task_ids.append(task_id) run = RunnableSequence(*writers) if len(writers) > 1 else writers[0] # execute task run.invoke( values, patch_config( config, run_name=self.name + "UpdateState", configurable={ # deque.extend is thread-safe CONFIG_KEY_SEND: writes.extend, CONFIG_KEY_TASK_ID: task_id, CONFIG_KEY_READ: partial( local_read, _scratchpad( None, [], task_id, "", None, step, step + 2, ), channels, managed, task, ), }, ), ) # save task writes for task_id, task in zip(run_task_ids, run_tasks): # channel writes are saved to current checkpoint channel_writes = [w for w in task.writes if w[0] != PUSH] if saved and channel_writes: checkpointer.put_writes(checkpoint_config, channel_writes, task_id) # apply to checkpoint and save apply_writes( checkpoint, channels, run_tasks, checkpointer.get_next_version, self.trigger_to_nodes, ) checkpoint = create_checkpoint(checkpoint, channels, step + 1) next_config = checkpointer.put( checkpoint_config, checkpoint, { "source": "update", "step": step + 1, "parents": saved.metadata.get("parents", {}) if saved else {}, }, get_new_channel_versions( checkpoint_previous_versions, checkpoint["channel_versions"] ), ) for task_id, task in zip(run_task_ids, run_tasks): # save push writes if push_writes := [w for w in task.writes if w[0] == PUSH]: checkpointer.put_writes(next_config, push_writes, task_id) return patch_checkpoint_map(next_config, saved.metadata if saved else None) current_config = patch_configurable( config, {CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID])} ) for superstep in supersteps: current_config = perform_superstep(current_config, superstep) return current_config async def abulk_update_state( self, config: RunnableConfig, supersteps: Sequence[Sequence[StateUpdate]], ) -> RunnableConfig: """Asynchronously apply updates to the graph state in bulk. Requires a checkpointer to be set. Args: config: The config to apply the updates to. 
supersteps: A list of supersteps, each including a list of updates to apply sequentially to a graph state. Each update is a tuple of the form `(values, as_node, task_id)` where task_id is optional. Raises: ValueError: If no checkpointer is set or no updates are provided. InvalidUpdateError: If an invalid update is provided. Returns: RunnableConfig: The updated config. """ checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") if len(supersteps) == 0: raise ValueError("No supersteps provided") if any(len(u) == 0 for u in supersteps): raise ValueError("No updates provided") # delegate to subgraph if ( checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast = recast_checkpoint_ns(checkpoint_ns) # find the subgraph with the matching name async for _, pregel in self.aget_subgraphs(namespace=recast, recurse=True): return await pregel.abulk_update_state( patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}), supersteps, ) else: raise ValueError(f"Subgraph {recast} not found") async def aperform_superstep( input_config: RunnableConfig, updates: Sequence[StateUpdate] ) -> RunnableConfig: # get last checkpoint config = ensure_config(self.config, input_config) saved = await checkpointer.aget_tuple(config) if saved is not None: self._migrate_checkpoint(saved.checkpoint) checkpoint = ( copy_checkpoint(saved.checkpoint) if saved else empty_checkpoint() ) checkpoint_previous_versions = ( saved.checkpoint["channel_versions"].copy() if saved else {} ) step = saved.metadata.get("step", -1) if saved else -1 # merge configurable fields with previous checkpoint config checkpoint_config = patch_configurable( config, { CONFIG_KEY_CHECKPOINT_NS: config[CONF].get( CONFIG_KEY_CHECKPOINT_NS, "" ) }, ) if saved: checkpoint_config = 
patch_configurable(config, saved.config[CONF]) channels, managed = channels_from_checkpoint( self.channels, checkpoint, ) values, as_node = updates[0][:2] # no values, just clear all tasks if values is None and as_node == END: if len(updates) > 1: raise InvalidUpdateError( "Cannot apply multiple updates when clearing state" ) if saved is not None: # tasks for this checkpoint next_tasks = prepare_next_tasks( checkpoint, saved.pending_writes or [], self.nodes, channels, managed, saved.config, step + 1, step + 3, for_execution=True, store=self.store, checkpointer=checkpointer, manager=None, ) # apply null writes if null_writes := [ w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID ]: apply_writes( checkpoint, channels, [PregelTaskWrites((), INPUT, null_writes, [])], checkpointer.get_next_version, self.trigger_to_nodes, ) # apply writes from tasks that already ran for tid, k, v in saved.pending_writes or []: if k in (ERROR, INTERRUPT): continue if tid not in next_tasks: continue next_tasks[tid].writes.append((k, v)) # clear all current tasks apply_writes( checkpoint, channels, next_tasks.values(), checkpointer.get_next_version, self.trigger_to_nodes, ) # save checkpoint next_config = await checkpointer.aput( checkpoint_config, create_checkpoint(checkpoint, channels, step), { "source": "update", "step": step + 1, "parents": saved.metadata.get("parents", {}) if saved else {}, }, get_new_channel_versions( checkpoint_previous_versions, checkpoint["channel_versions"] ), ) return patch_checkpoint_map( next_config, saved.metadata if saved else None ) # act as an input if as_node == INPUT: if len(updates) > 1: raise InvalidUpdateError( "Cannot apply multiple updates when updating as input" ) if input_writes := deque(map_input(self.input_channels, values)): apply_writes( checkpoint, channels, [PregelTaskWrites((), INPUT, input_writes, [])], checkpointer.get_next_version, self.trigger_to_nodes, ) # apply input write to channels next_step = ( step + 1 if saved and 
saved.metadata.get("step") is not None else -1 ) next_config = await checkpointer.aput( checkpoint_config, create_checkpoint(checkpoint, channels, next_step), { "source": "input", "step": next_step, "parents": saved.metadata.get("parents", {}) if saved else {}, }, get_new_channel_versions( checkpoint_previous_versions, checkpoint["channel_versions"], ), ) # store the writes await checkpointer.aput_writes( next_config, input_writes, str(uuid5(UUID(checkpoint["id"]), INPUT)), ) return patch_checkpoint_map( next_config, saved.metadata if saved else None ) else: raise InvalidUpdateError( f"Received no input writes for {self.input_channels}" ) # no values, copy checkpoint if as_node == "__copy__": if len(updates) > 1: raise InvalidUpdateError( "Cannot copy checkpoint with multiple updates" ) if saved is None: raise InvalidUpdateError("Cannot copy a non-existent checkpoint") next_checkpoint = create_checkpoint(checkpoint, None, step) # copy checkpoint next_config = await checkpointer.aput( saved.parent_config or patch_configurable( saved.config, {CONFIG_KEY_CHECKPOINT_ID: None} ), next_checkpoint, { "source": "fork", "step": step + 1, "parents": saved.metadata.get("parents", {}), }, {}, ) # we want to both clone a checkpoint and update state in one go. # reuse the same task ID if possible. 
if isinstance(values, list) and len(values) > 0: # figure out the task IDs for the next update checkpoint next_tasks = prepare_next_tasks( next_checkpoint, saved.pending_writes or [], self.nodes, channels, managed, next_config, step + 2, step + 4, for_execution=True, store=self.store, checkpointer=checkpointer, manager=None, ) tasks_group_by = defaultdict(list) user_group_by: dict[str, list[StateUpdate]] = defaultdict(list) for task in next_tasks.values(): tasks_group_by[task.name].append(task.id) for item in values: if not isinstance(item, Sequence): raise InvalidUpdateError( f"Invalid update item: {item} when copying checkpoint" ) values, as_node = item[:2] user_group = user_group_by[as_node] tasks_group = tasks_group_by[as_node] target_idx = len(user_group) task_id = ( tasks_group[target_idx] if target_idx < len(tasks_group) else None ) user_group_by[as_node].append( StateUpdate(values=values, as_node=as_node, task_id=task_id) ) return await aperform_superstep( patch_checkpoint_map(next_config, saved.metadata), [item for lst in user_group_by.values() for item in lst], ) return patch_checkpoint_map( next_config, saved.metadata if saved else None ) # apply pending writes, if not on specific checkpoint if ( CONFIG_KEY_CHECKPOINT_ID not in config[CONF] and saved is not None and saved.pending_writes ): # tasks for this checkpoint next_tasks = prepare_next_tasks( checkpoint, saved.pending_writes, self.nodes, channels, managed, saved.config, step + 1, step + 3, for_execution=True, store=self.store, checkpointer=checkpointer, manager=None, ) # apply null writes if null_writes := [ w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID ]: apply_writes( checkpoint, channels, [PregelTaskWrites((), INPUT, null_writes, [])], checkpointer.get_next_version, self.trigger_to_nodes, ) for tid, k, v in saved.pending_writes: if k in (ERROR, INTERRUPT): continue if tid not in next_tasks: continue next_tasks[tid].writes.append((k, v)) if tasks := [t for t in 
next_tasks.values() if t.writes]: apply_writes( checkpoint, channels, tasks, checkpointer.get_next_version, self.trigger_to_nodes, ) valid_updates: list[tuple[str, dict[str, Any] | None, str | None]] = [] if len(updates) == 1: values, as_node, task_id = updates[0] # find last node that updated the state, if not provided if as_node is None and len(self.nodes) == 1: as_node = tuple(self.nodes)[0] elif as_node is None and not saved: if ( isinstance(self.input_channels, str) and self.input_channels in self.nodes ): as_node = self.input_channels elif as_node is None: last_seen_by_node = sorted( (v, n) for n, seen in checkpoint["versions_seen"].items() if n in self.nodes for v in seen.values() ) # if two nodes updated the state at the same time, it's ambiguous if last_seen_by_node: if len(last_seen_by_node) == 1: as_node = last_seen_by_node[0][1] elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]: as_node = last_seen_by_node[-1][1] if as_node is None: raise InvalidUpdateError("Ambiguous update, specify as_node") if as_node not in self.nodes: raise InvalidUpdateError(f"Node {as_node} does not exist") valid_updates.append((as_node, values, task_id)) else: for values, as_node, task_id in updates: if as_node is None: raise InvalidUpdateError( "as_node is required when applying multiple updates" ) if as_node not in self.nodes: raise InvalidUpdateError(f"Node {as_node} does not exist") valid_updates.append((as_node, values, task_id)) run_tasks: list[PregelTaskWrites] = [] run_task_ids: list[str] = [] for as_node, values, provided_task_id in valid_updates: # create task to run all writers of the chosen node writers = self.nodes[as_node].flat_writers if not writers: raise InvalidUpdateError(f"Node {as_node} has no writers") writes: deque[tuple[str, Any]] = deque() task = PregelTaskWrites((), as_node, writes, [INTERRUPT]) task_id = provided_task_id or str( uuid5(UUID(checkpoint["id"]), INTERRUPT) ) run_tasks.append(task) run_task_ids.append(task_id) run = 
RunnableSequence(*writers) if len(writers) > 1 else writers[0] # execute task await run.ainvoke( values, patch_config( config, run_name=self.name + "UpdateState", configurable={ # deque.extend is thread-safe CONFIG_KEY_SEND: writes.extend, CONFIG_KEY_TASK_ID: task_id, CONFIG_KEY_READ: partial( local_read, _scratchpad( None, [], task_id, "", None, step, step + 2, ), channels, managed, task, ), }, ), ) # save task writes for task_id, task in zip(run_task_ids, run_tasks): # channel writes are saved to current checkpoint channel_writes = [w for w in task.writes if w[0] != PUSH] if saved and channel_writes: await checkpointer.aput_writes( checkpoint_config, channel_writes, task_id ) # apply to checkpoint and save apply_writes( checkpoint, channels, run_tasks, checkpointer.get_next_version, self.trigger_to_nodes, ) checkpoint = create_checkpoint(checkpoint, channels, step + 1) # save checkpoint, after applying writes next_config = await checkpointer.aput( checkpoint_config, checkpoint, { "source": "update", "step": step + 1, "parents": saved.metadata.get("parents", {}) if saved else {}, }, get_new_channel_versions( checkpoint_previous_versions, checkpoint["channel_versions"] ), ) for task_id, task in zip(run_task_ids, run_tasks): # save push writes if push_writes := [w for w in task.writes if w[0] == PUSH]: await checkpointer.aput_writes(next_config, push_writes, task_id) return patch_checkpoint_map(next_config, saved.metadata if saved else None) current_config = patch_configurable( config, {CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID])} ) for superstep in supersteps: current_config = await aperform_superstep(current_config, superstep) return current_config def update_state( self, config: RunnableConfig, values: dict[str, Any] | Any | None, as_node: str | None = None, task_id: str | None = None, ) -> RunnableConfig: """Update the state of the graph with the given values, as if they came from node `as_node`. 
If `as_node` is not provided, it will be set to the last node that updated the state, if not ambiguous. """ return self.bulk_update_state(config, [[StateUpdate(values, as_node, task_id)]]) async def aupdate_state( self, config: RunnableConfig, values: dict[str, Any] | Any, as_node: str | None = None, task_id: str | None = None, ) -> RunnableConfig: """Asynchronously update the state of the graph with the given values, as if they came from node `as_node`. If `as_node` is not provided, it will be set to the last node that updated the state, if not ambiguous. """ return await self.abulk_update_state( config, [[StateUpdate(values, as_node, task_id)]] ) def _defaults( self, config: RunnableConfig, *, stream_mode: StreamMode | Sequence[StreamMode], print_mode: StreamMode | Sequence[StreamMode], output_keys: str | Sequence[str] | None, interrupt_before: All | Sequence[str] | None, interrupt_after: All | Sequence[str] | None, durability: Durability | None = None, checkpoint_during: bool | None = None, ) -> tuple[ set[StreamMode], str | Sequence[str], All | Sequence[str], All | Sequence[str], BaseCheckpointSaver | None, BaseStore | None, BaseCache | None, Durability, ]: if config["recursion_limit"] < 1: raise ValueError("recursion_limit must be at least 1") if output_keys is None: output_keys = self.stream_channels_asis else: validate_keys(output_keys, self.channels) interrupt_before = interrupt_before or self.interrupt_before_nodes interrupt_after = interrupt_after or self.interrupt_after_nodes if not isinstance(stream_mode, list): stream_modes = {stream_mode} else: stream_modes = set(stream_mode) if isinstance(print_mode, str): stream_modes.add(print_mode) else: stream_modes.update(print_mode) if self.checkpointer is False: checkpointer: BaseCheckpointSaver | None = None elif CONFIG_KEY_CHECKPOINTER in config.get(CONF, {}): checkpointer = config[CONF][CONFIG_KEY_CHECKPOINTER] elif self.checkpointer is True: raise RuntimeError("checkpointer=True cannot be used for root 
graphs.") else: checkpointer = self.checkpointer if checkpointer and not config.get(CONF): raise ValueError( "Checkpointer requires one or more of the following 'configurable' " "keys: thread_id, checkpoint_ns, checkpoint_id" ) if CONFIG_KEY_RUNTIME in config.get(CONF, {}): store: BaseStore | None = config[CONF][CONFIG_KEY_RUNTIME].store else: store = self.store if CONFIG_KEY_CACHE in config.get(CONF, {}): cache: BaseCache | None = config[CONF][CONFIG_KEY_CACHE] else: cache = self.cache if checkpoint_during is not None: if durability is not None: raise ValueError( "Cannot use both `checkpoint_during` and `durability` parameters." ) elif checkpoint_during: durability = "async" else: durability = "exit" if durability is None: durability = config.get(CONF, {}).get(CONFIG_KEY_DURABILITY, "async") return ( stream_modes, output_keys, interrupt_before, interrupt_after, checkpointer, store, cache, durability, ) def stream( self, input: InputT | Command | None, config: RunnableConfig | None = None, *, context: ContextT | None = None, stream_mode: StreamMode | Sequence[StreamMode] | None = None, print_mode: StreamMode | Sequence[StreamMode] = (), output_keys: str | Sequence[str] | None = None, interrupt_before: All | Sequence[str] | None = None, interrupt_after: All | Sequence[str] | None = None, durability: Durability | None = None, subgraphs: bool = False, debug: bool | None = None, **kwargs: Unpack[DeprecatedKwargs], ) -> Iterator[dict[str, Any] | Any]: """Stream graph steps for a single input. Args: input: The input to the graph. config: The configuration to use for the run. context: The static context to use for the run. !!! version-added "Added in version 0.6.0." stream_mode: The mode to stream output, defaults to `self.stream_mode`. Options are: - `"values"`: Emit all values in the state after each step, including interrupts. When used with functional API, values are emitted once at the end of the workflow. 
- `"updates"`: Emit only the node or task names and updates returned by the nodes or tasks after each step. If multiple updates are made in the same step (e.g. multiple nodes are run) then those updates are emitted separately. - `"custom"`: Emit custom data from inside nodes or tasks using `StreamWriter`. - `"messages"`: Emit LLM messages token-by-token together with metadata for any LLM invocations inside nodes or tasks. Will be emitted as 2-tuples `(LLM token, metadata)`. - `"checkpoints"`: Emit an event when a checkpoint is created, in the same format as returned by get_state(). - `"tasks"`: Emit events when tasks start and finish, including their results and errors. You can pass a list as the `stream_mode` parameter to stream multiple modes at once. The streamed outputs will be tuples of `(mode, data)`. See [LangGraph streaming guide](https://langchain-ai.github.io/langgraph/how-tos/streaming/) for more details. print_mode: Accepts the same values as `stream_mode`, but only prints the output to the console, for debugging purposes. Does not affect the output of the graph in any way. output_keys: The keys to stream, defaults to all non-context channels. interrupt_before: Nodes to interrupt before, defaults to all nodes in the graph. interrupt_after: Nodes to interrupt after, defaults to all nodes in the graph. durability: The durability mode for the graph execution, defaults to "async". Options are: - `"sync"`: Changes are persisted synchronously before the next step starts. - `"async"`: Changes are persisted asynchronously while the next step executes. - `"exit"`: Changes are persisted only when the graph exits. subgraphs: Whether to stream events from inside subgraphs, defaults to False. If True, the events will be emitted as tuples `(namespace, data)`, or `(namespace, mode, data)` if `stream_mode` is a list, where `namespace` is a tuple with the path to the node where a subgraph is invoked, e.g. `("parent_node:<task_id>", "child_node:<task_id>")`. 
See [LangGraph streaming guide](https://langchain-ai.github.io/langgraph/how-tos/streaming/) for more details. Yields: The output of each step in the graph. The output shape depends on the stream_mode. """ if stream_mode is None: # if being called as a node in another graph, default to values mode # but don't overwrite stream_mode arg if provided stream_mode = ( "values" if config is not None and CONFIG_KEY_TASK_ID in config.get(CONF, {}) else self.stream_mode ) if debug or self.debug: print_mode = ["updates", "values"] stream = SyncQueue() config = ensure_config(self.config, config) callback_manager = get_callback_manager_for_config(config) run_manager = callback_manager.on_chain_start( None, input, name=config.get("run_name", self.get_name()), run_id=config.get("run_id"), ) try: deprecated_checkpoint_during = cast( Optional[bool], kwargs.get("checkpoint_during") ) if deprecated_checkpoint_during is not None: warnings.warn( "`checkpoint_during` is deprecated and will be removed. Please use `durability` instead.", category=LangGraphDeprecatedSinceV10, ) # assign defaults ( stream_modes, output_keys, interrupt_before_, interrupt_after_, checkpointer, store, cache, durability_, ) = self._defaults( config, stream_mode=stream_mode, print_mode=print_mode, output_keys=output_keys, interrupt_before=interrupt_before, interrupt_after=interrupt_after, durability=durability, checkpoint_during=deprecated_checkpoint_during, ) if checkpointer is None and ( durability is not None or deprecated_checkpoint_during is not None ): warnings.warn( "`durability` has no effect when no checkpointer is present.", ) # set up subgraph checkpointing if self.checkpointer is True: ns = cast(str, config[CONF][CONFIG_KEY_CHECKPOINT_NS]) config[CONF][CONFIG_KEY_CHECKPOINT_NS] = recast_checkpoint_ns(ns) # set up messages stream mode if "messages" in stream_modes: run_manager.inheritable_handlers.append( StreamMessagesHandler(stream.put, subgraphs) ) # set up custom stream mode if "custom" in 
stream_modes: def stream_writer(c: Any) -> None: stream.put( ( tuple( get_config()[CONF][CONFIG_KEY_CHECKPOINT_NS].split( NS_SEP )[:-1] ), "custom", c, ) ) elif CONFIG_KEY_STREAM in config[CONF]: stream_writer = config[CONF][CONFIG_KEY_RUNTIME].stream_writer else: def stream_writer(c: Any) -> None: pass # set durability mode for subgraphs if durability is not None or deprecated_checkpoint_during is not None: config[CONF][CONFIG_KEY_DURABILITY] = durability_ runtime = Runtime( context=_coerce_context(self.context_schema, context), store=store, stream_writer=stream_writer, previous=None, ) parent_runtime = config[CONF].get(CONFIG_KEY_RUNTIME, DEFAULT_RUNTIME) runtime = parent_runtime.merge(runtime) config[CONF][CONFIG_KEY_RUNTIME] = runtime with SyncPregelLoop( input, stream=StreamProtocol(stream.put, stream_modes), config=config, store=store, cache=cache, checkpointer=checkpointer, nodes=self.nodes, specs=self.channels, output_keys=output_keys, input_keys=self.input_channels, stream_keys=self.stream_channels_asis, interrupt_before=interrupt_before_, interrupt_after=interrupt_after_, manager=run_manager, durability=durability_, trigger_to_nodes=self.trigger_to_nodes, migrate_checkpoint=self._migrate_checkpoint, retry_policy=self.retry_policy, cache_policy=self.cache_policy, ) as loop: # create runner runner = PregelRunner( submit=config[CONF].get( CONFIG_KEY_RUNNER_SUBMIT, weakref.WeakMethod(loop.submit) ), put_writes=weakref.WeakMethod(loop.put_writes), node_finished=config[CONF].get(CONFIG_KEY_NODE_FINISHED), ) # enable subgraph streaming if subgraphs: loop.config[CONF][CONFIG_KEY_STREAM] = loop.stream # enable concurrent streaming if ( self.stream_eager or subgraphs or "messages" in stream_modes or "custom" in stream_modes ): # we are careful to have a single waiter live at any one time # because on exit we increment semaphore count by exactly 1 waiter: concurrent.futures.Future | None = None # because sync futures cannot be cancelled, we instead # release the 
stream semaphore on exit, which will cause # a pending waiter to return immediately loop.stack.callback(stream._count.release) def get_waiter() -> concurrent.futures.Future[None]: nonlocal waiter if waiter is None or waiter.done(): waiter = loop.submit(stream.wait) return waiter else: return waiter else: get_waiter = None # type: ignore[assignment] # Similarly to Bulk Synchronous Parallel / Pregel model # computation proceeds in steps, while there are channel updates. # Channel updates from step N are only visible in step N+1 # channels are guaranteed to be immutable for the duration of the step, # with channel updates applied only at the transition between steps. while loop.tick(): for task in loop.match_cached_writes(): loop.output_writes(task.id, task.writes, cached=True) for _ in runner.tick( [t for t in loop.tasks.values() if not t.writes], timeout=self.step_timeout, get_waiter=get_waiter, schedule_task=loop.accept_push, ): # emit output yield from _output( stream_mode, print_mode, subgraphs, stream.get, queue.Empty ) loop.after_tick() # wait for checkpoint if durability_ == "sync": loop._put_checkpoint_fut.result() # emit output yield from _output( stream_mode, print_mode, subgraphs, stream.get, queue.Empty ) # handle exit if loop.status == "out_of_steps": msg = create_error_message( message=( f"Recursion limit of {config['recursion_limit']} reached " "without hitting a stop condition. You can increase the " "limit by setting the `recursion_limit` config key." 
), error_code=ErrorCode.GRAPH_RECURSION_LIMIT, ) raise GraphRecursionError(msg) # set final channel values as run output run_manager.on_chain_end(loop.output) except BaseException as e: run_manager.on_chain_error(e) raise async def astream( self, input: InputT | Command | None, config: RunnableConfig | None = None, *, context: ContextT | None = None, stream_mode: StreamMode | Sequence[StreamMode] | None = None, print_mode: StreamMode | Sequence[StreamMode] = (), output_keys: str | Sequence[str] | None = None, interrupt_before: All | Sequence[str] | None = None, interrupt_after: All | Sequence[str] | None = None, durability: Durability | None = None, subgraphs: bool = False, debug: bool | None = None, **kwargs: Unpack[DeprecatedKwargs], ) -> AsyncIterator[dict[str, Any] | Any]: """Asynchronously stream graph steps for a single input. Args: input: The input to the graph. config: The configuration to use for the run. context: The static context to use for the run. !!! version-added "Added in version 0.6.0." stream_mode: The mode to stream output, defaults to `self.stream_mode`. Options are: - `"values"`: Emit all values in the state after each step, including interrupts. When used with functional API, values are emitted once at the end of the workflow. - `"updates"`: Emit only the node or task names and updates returned by the nodes or tasks after each step. If multiple updates are made in the same step (e.g. multiple nodes are run) then those updates are emitted separately. - `"custom"`: Emit custom data from inside nodes or tasks using `StreamWriter`. - `"messages"`: Emit LLM messages token-by-token together with metadata for any LLM invocations inside nodes or tasks. Will be emitted as 2-tuples `(LLM token, metadata)`. - `"debug"`: Emit debug events with as much information as possible for each step. You can pass a list as the `stream_mode` parameter to stream multiple modes at once. The streamed outputs will be tuples of `(mode, data)`. 
See [LangGraph streaming guide](https://langchain-ai.github.io/langgraph/how-tos/streaming/) for more details. print_mode: Accepts the same values as `stream_mode`, but only prints the output to the console, for debugging purposes. Does not affect the output of the graph in any way. output_keys: The keys to stream, defaults to all non-context channels. interrupt_before: Nodes to interrupt before, defaults to all nodes in the graph. interrupt_after: Nodes to interrupt after, defaults to all nodes in the graph. durability: The durability mode for the graph execution, defaults to "async". Options are: - `"sync"`: Changes are persisted synchronously before the next step starts. - `"async"`: Changes are persisted asynchronously while the next step executes. - `"exit"`: Changes are persisted only when the graph exits. subgraphs: Whether to stream events from inside subgraphs, defaults to False. If True, the events will be emitted as tuples `(namespace, data)`, or `(namespace, mode, data)` if `stream_mode` is a list, where `namespace` is a tuple with the path to the node where a subgraph is invoked, e.g. `("parent_node:<task_id>", "child_node:<task_id>")`. See [LangGraph streaming guide](https://langchain-ai.github.io/langgraph/how-tos/streaming/) for more details. Yields: The output of each step in the graph. The output shape depends on the stream_mode. 
""" if stream_mode is None: # if being called as a node in another graph, default to values mode # but don't overwrite stream_mode arg if provided stream_mode = ( "values" if config is not None and CONFIG_KEY_TASK_ID in config.get(CONF, {}) else self.stream_mode ) if debug or self.debug: print_mode = ["updates", "values"] stream = AsyncQueue() aioloop = asyncio.get_running_loop() stream_put = cast( Callable[[StreamChunk], None], partial(aioloop.call_soon_threadsafe, stream.put_nowait), ) config = ensure_config(self.config, config) callback_manager = get_async_callback_manager_for_config(config) run_manager = await callback_manager.on_chain_start( None, input, name=config.get("run_name", self.get_name()), run_id=config.get("run_id"), ) # if running from astream_log() run each proc with streaming do_stream = ( next( ( True for h in run_manager.handlers if isinstance(h, _StreamingCallbackHandler) and not isinstance(h, StreamMessagesHandler) ), False, ) if _StreamingCallbackHandler is not None else False ) try: deprecated_checkpoint_during = cast( Optional[bool], kwargs.get("checkpoint_during") ) if deprecated_checkpoint_during is not None: warnings.warn( "`checkpoint_during` is deprecated and will be removed. 
Please use `durability` instead.", category=LangGraphDeprecatedSinceV10, ) # assign defaults ( stream_modes, output_keys, interrupt_before_, interrupt_after_, checkpointer, store, cache, durability_, ) = self._defaults( config, stream_mode=stream_mode, print_mode=print_mode, output_keys=output_keys, interrupt_before=interrupt_before, interrupt_after=interrupt_after, durability=durability, checkpoint_during=deprecated_checkpoint_during, ) if checkpointer is None and ( durability is not None or deprecated_checkpoint_during is not None ): warnings.warn( "`durability` has no effect when no checkpointer is present.", ) # set up subgraph checkpointing if self.checkpointer is True: ns = cast(str, config[CONF][CONFIG_KEY_CHECKPOINT_NS]) config[CONF][CONFIG_KEY_CHECKPOINT_NS] = recast_checkpoint_ns(ns) # set up messages stream mode if "messages" in stream_modes: run_manager.inheritable_handlers.append( StreamMessagesHandler(stream_put, subgraphs) ) # set up custom stream mode def stream_writer(c: Any) -> None: aioloop.call_soon_threadsafe( stream.put_nowait, ( tuple( get_config()[CONF][CONFIG_KEY_CHECKPOINT_NS].split(NS_SEP)[ :-1 ] ), "custom", c, ), ) if "custom" in stream_modes: def stream_writer(c: Any) -> None: aioloop.call_soon_threadsafe( stream.put_nowait, ( tuple( get_config()[CONF][CONFIG_KEY_CHECKPOINT_NS].split( NS_SEP )[:-1] ), "custom", c, ), ) elif CONFIG_KEY_STREAM in config[CONF]: stream_writer = config[CONF][CONFIG_KEY_RUNTIME].stream_writer else: def stream_writer(c: Any) -> None: pass # set durability mode for subgraphs if durability is not None or deprecated_checkpoint_during is not None: config[CONF][CONFIG_KEY_DURABILITY] = durability_ runtime = Runtime( context=_coerce_context(self.context_schema, context), store=store, stream_writer=stream_writer, previous=None, ) parent_runtime = config[CONF].get(CONFIG_KEY_RUNTIME, DEFAULT_RUNTIME) runtime = parent_runtime.merge(runtime) config[CONF][CONFIG_KEY_RUNTIME] = runtime async with AsyncPregelLoop( input, 
stream=StreamProtocol(stream.put_nowait, stream_modes), config=config, store=store, cache=cache, checkpointer=checkpointer, nodes=self.nodes, specs=self.channels, output_keys=output_keys, input_keys=self.input_channels, stream_keys=self.stream_channels_asis, interrupt_before=interrupt_before_, interrupt_after=interrupt_after_, manager=run_manager, durability=durability_, trigger_to_nodes=self.trigger_to_nodes, migrate_checkpoint=self._migrate_checkpoint, retry_policy=self.retry_policy, cache_policy=self.cache_policy, ) as loop: # create runner runner = PregelRunner( submit=config[CONF].get( CONFIG_KEY_RUNNER_SUBMIT, weakref.WeakMethod(loop.submit) ), put_writes=weakref.WeakMethod(loop.put_writes), use_astream=do_stream, node_finished=config[CONF].get(CONFIG_KEY_NODE_FINISHED), ) # enable subgraph streaming if subgraphs: loop.config[CONF][CONFIG_KEY_STREAM] = StreamProtocol( stream_put, stream_modes ) # enable concurrent streaming if ( self.stream_eager or subgraphs or "messages" in stream_modes or "custom" in stream_modes ): def get_waiter() -> asyncio.Task[None]: return aioloop.create_task(stream.wait()) else: get_waiter = None # type: ignore[assignment] # Similarly to Bulk Synchronous Parallel / Pregel model # computation proceeds in steps, while there are channel updates # channel updates from step N are only visible in step N+1 # channels are guaranteed to be immutable for the duration of the step, # with channel updates applied only at the transition between steps while loop.tick(): for task in await loop.amatch_cached_writes(): loop.output_writes(task.id, task.writes, cached=True) async for _ in runner.atick( [t for t in loop.tasks.values() if not t.writes], timeout=self.step_timeout, get_waiter=get_waiter, schedule_task=loop.aaccept_push, ): # emit output for o in _output( stream_mode, print_mode, subgraphs, stream.get_nowait, asyncio.QueueEmpty, ): yield o loop.after_tick() # wait for checkpoint if durability_ == "sync": await cast(asyncio.Future, 
loop._put_checkpoint_fut) # emit output for o in _output( stream_mode, print_mode, subgraphs, stream.get_nowait, asyncio.QueueEmpty, ): yield o # handle exit if loop.status == "out_of_steps": msg = create_error_message( message=( f"Recursion limit of {config['recursion_limit']} reached " "without hitting a stop condition. You can increase the " "limit by setting the `recursion_limit` config key." ), error_code=ErrorCode.GRAPH_RECURSION_LIMIT, ) raise GraphRecursionError(msg) # set final channel values as run output await run_manager.on_chain_end(loop.output) except BaseException as e: await asyncio.shield(run_manager.on_chain_error(e)) raise def invoke( self, input: InputT | Command | None, config: RunnableConfig | None = None, *, context: ContextT | None = None, stream_mode: StreamMode = "values", print_mode: StreamMode | Sequence[StreamMode] = (), output_keys: str | Sequence[str] | None = None, interrupt_before: All | Sequence[str] | None = None, interrupt_after: All | Sequence[str] | None = None, **kwargs: Any, ) -> dict[str, Any] | Any: """Run the graph with a single input and config. Args: input: The input data for the graph. It can be a dictionary or any other type. config: Optional. The configuration for the graph run. context: The static context to use for the run. !!! version-added "Added in version 0.6.0." stream_mode: Optional[str]. The stream mode for the graph run. Default is "values". print_mode: Accepts the same values as `stream_mode`, but only prints the output to the console, for debugging purposes. Does not affect the output of the graph in any way. output_keys: Optional. The output keys to retrieve from the graph run. interrupt_before: Optional. The nodes to interrupt the graph run before. interrupt_after: Optional. The nodes to interrupt the graph run after. **kwargs: Additional keyword arguments to pass to the graph run. Returns: The output of the graph run. If stream_mode is "values", it returns the latest output. 
If stream_mode is not "values", it returns a list of output chunks. """ output_keys = output_keys if output_keys is not None else self.output_channels latest: dict[str, Any] | Any = None chunks: list[dict[str, Any] | Any] = [] interrupts: list[Interrupt] = [] for chunk in self.stream( input, config, context=context, stream_mode=["updates", "values"] if stream_mode == "values" else stream_mode, print_mode=print_mode, output_keys=output_keys, interrupt_before=interrupt_before, interrupt_after=interrupt_after, **kwargs, ): if stream_mode == "values": if len(chunk) == 2: mode, payload = cast(tuple[StreamMode, Any], chunk) else: _, mode, payload = cast( tuple[tuple[str, ...], StreamMode, Any], chunk ) if ( mode == "updates" and isinstance(payload, dict) and (ints := payload.get(INTERRUPT)) is not None ): interrupts.extend(ints) elif mode == "values": latest = payload else: chunks.append(chunk) if stream_mode == "values": if interrupts: return ( {**latest, INTERRUPT: interrupts} if isinstance(latest, dict) else {INTERRUPT: interrupts} ) return latest else: return chunks async def ainvoke( self, input: InputT | Command | None, config: RunnableConfig | None = None, *, context: ContextT | None = None, stream_mode: StreamMode = "values", print_mode: StreamMode | Sequence[StreamMode] = (), output_keys: str | Sequence[str] | None = None, interrupt_before: All | Sequence[str] | None = None, interrupt_after: All | Sequence[str] | None = None, **kwargs: Any, ) -> dict[str, Any] | Any: """Asynchronously invoke the graph on a single input. Args: input: The input data for the computation. It can be a dictionary or any other type. config: Optional. The configuration for the computation. context: The static context to use for the run. !!! version-added "Added in version 0.6.0." stream_mode: Optional. The stream mode for the computation. Default is "values". print_mode: Accepts the same values as `stream_mode`, but only prints the output to the console, for debugging purposes. 
Does not affect the output of the graph in any way. output_keys: Optional. The output keys to include in the result. Default is None. interrupt_before: Optional. The nodes to interrupt before. Default is None. interrupt_after: Optional. The nodes to interrupt after. Default is None. **kwargs: Additional keyword arguments. Returns: The result of the computation. If stream_mode is "values", it returns the latest value. If stream_mode is "chunks", it returns a list of chunks. """ output_keys = output_keys if output_keys is not None else self.output_channels latest: dict[str, Any] | Any = None chunks: list[dict[str, Any] | Any] = [] interrupts: list[Interrupt] = [] async for chunk in self.astream( input, config, context=context, stream_mode=["updates", "values"] if stream_mode == "values" else stream_mode, print_mode=print_mode, output_keys=output_keys, interrupt_before=interrupt_before, interrupt_after=interrupt_after, **kwargs, ): if stream_mode == "values": if len(chunk) == 2: mode, payload = cast(tuple[StreamMode, Any], chunk) else: _, mode, payload = cast( tuple[tuple[str, ...], StreamMode, Any], chunk ) if ( mode == "updates" and isinstance(payload, dict) and (ints := payload.get(INTERRUPT)) is not None ): interrupts.extend(ints) elif mode == "values": latest = payload else: chunks.append(chunk) if stream_mode == "values": if interrupts: return ( {**latest, INTERRUPT: interrupts} if isinstance(latest, dict) else {INTERRUPT: interrupts} ) return latest else: return chunks def clear_cache(self, nodes: Sequence[str] | None = None) -> None: """Clear the cache for the given nodes.""" if not self.cache: raise ValueError("No cache is set for this graph. 
Cannot clear cache.") nodes = nodes or self.nodes.keys() # collect namespaces to clear namespaces: list[tuple[str, ...]] = [] for node in nodes: if node in self.nodes: namespaces.append( ( CACHE_NS_WRITES, (identifier(self.nodes[node]) or "__dynamic__"), node, ), ) # clear cache self.cache.clear(namespaces) async def aclear_cache(self, nodes: Sequence[str] | None = None) -> None: """Asynchronously clear the cache for the given nodes.""" if not self.cache: raise ValueError("No cache is set for this graph. Cannot clear cache.") nodes = nodes or self.nodes.keys() # collect namespaces to clear namespaces: list[tuple[str, ...]] = [] for node in nodes: if node in self.nodes: namespaces.append( ( CACHE_NS_WRITES, (identifier(self.nodes[node]) or "__dynamic__"), node, ), ) # clear cache await self.cache.aclear(namespaces) def _trigger_to_nodes(nodes: dict[str, PregelNode]) -> Mapping[str, Sequence[str]]: """Index from a trigger to nodes that depend on it.""" trigger_to_nodes: defaultdict[str, list[str]] = defaultdict(list) for name, node in nodes.items(): for trigger in node.triggers: trigger_to_nodes[trigger].append(name) return dict(trigger_to_nodes) def _output( stream_mode: StreamMode | Sequence[StreamMode], print_mode: StreamMode | Sequence[StreamMode], stream_subgraphs: bool, getter: Callable[[], tuple[tuple[str, ...], str, Any]], empty_exc: type[Exception], ) -> Iterator: while True: try: ns, mode, payload = getter() except empty_exc: break if mode in print_mode: if stream_subgraphs and ns: print( " ".join( ( get_bolded_text(f"[{mode}]"), get_colored_text(f"[graph={ns}]", color="yellow"), repr(payload), ) ) ) else: print( " ".join( ( get_bolded_text(f"[{mode}]"), repr(payload), ) ) ) if mode in stream_mode: if stream_subgraphs and isinstance(stream_mode, list): yield (ns, mode, payload) elif isinstance(stream_mode, list): yield (mode, payload) elif stream_subgraphs: yield (ns, payload) else: yield payload def _coerce_context( context_schema: type[ContextT] | None, 
context: Any ) -> ContextT | None: """Coerce context input to the appropriate schema type. If context is a dict and context_schema is a dataclass or pydantic model, we coerce. Else, we return the context as-is. Args: context_schema: The schema type to coerce to (BaseModel, dataclass, or TypedDict) context: The context value to coerce Returns: The coerced context value or None if context is None """ if context is None: return None if context_schema is None: return context schema_is_class = issubclass(context_schema, BaseModel) or is_dataclass( context_schema ) if isinstance(context, dict) and schema_is_class: return context_schema(**context) # type: ignore[misc] return cast(ContextT, context)

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/kaman05010/MCPClientServer'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.