Skip to main content
Glama
kaman05010

MCP Wikipedia Server

by kaman05010
_loop.py (45.3 kB)
from __future__ import annotations

import asyncio
import binascii
import concurrent.futures
from collections import defaultdict, deque
from collections.abc import Iterator, Mapping, Sequence
from contextlib import (
    AbstractAsyncContextManager,
    AbstractContextManager,
    AsyncExitStack,
    ExitStack,
)
from datetime import datetime, timezone
from inspect import signature
from types import TracebackType
from typing import (
    Any,
    Callable,
    Literal,
    Optional,
    TypeVar,
    cast,
)

from langchain_core.callbacks import AsyncParentRunManager, ParentRunManager
from langchain_core.runnables import RunnableConfig
from typing_extensions import ParamSpec, Self

from langgraph._internal._config import patch_configurable
from langgraph._internal._constants import (
    CONF,
    CONFIG_KEY_CHECKPOINT_ID,
    CONFIG_KEY_CHECKPOINT_MAP,
    CONFIG_KEY_CHECKPOINT_NS,
    CONFIG_KEY_RESUME_MAP,
    CONFIG_KEY_RESUMING,
    CONFIG_KEY_SCRATCHPAD,
    CONFIG_KEY_STREAM,
    CONFIG_KEY_TASK_ID,
    CONFIG_KEY_THREAD_ID,
    ERROR,
    INPUT,
    INTERRUPT,
    NS_END,
    NS_SEP,
    NULL_TASK_ID,
    PUSH,
    RESUME,
)
from langgraph._internal._typing import EMPTY_SEQ, MISSING
from langgraph.cache.base import BaseCache
from langgraph.channels.base import BaseChannel
from langgraph.checkpoint.base import (
    WRITES_IDX_MAP,
    BaseCheckpointSaver,
    ChannelVersions,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
    PendingWrite,
)
from langgraph.constants import TAG_HIDDEN
from langgraph.errors import (
    EmptyInputError,
    GraphInterrupt,
)
from langgraph.managed.base import (
    ManagedValueMapping,
    ManagedValueSpec,
)
from langgraph.pregel._algo import (
    Call,
    GetNextVersion,
    PregelTaskWrites,
    apply_writes,
    checkpoint_null_version,
    increment,
    prepare_next_tasks,
    prepare_single_task,
    should_interrupt,
    task_path_str,
)
from langgraph.pregel._checkpoint import (
    channels_from_checkpoint,
    copy_checkpoint,
    create_checkpoint,
    empty_checkpoint,
)
from langgraph.pregel._executor import (
    AsyncBackgroundExecutor,
    BackgroundExecutor,
    Submit,
)
from langgraph.pregel._io import (
    map_command,
    map_input,
    map_output_updates,
    map_output_values,
    read_channels,
)
from langgraph.pregel._read import PregelNode
from langgraph.pregel._scratchpad import PregelScratchpad
from langgraph.pregel._utils import get_new_channel_versions, is_xxh3_128_hexdigest
from langgraph.pregel.debug import (
    map_debug_checkpoint,
    map_debug_task_results,
    map_debug_tasks,
)
from langgraph.pregel.protocol import StreamChunk, StreamProtocol
from langgraph.store.base import BaseStore
from langgraph.types import (
    All,
    CachePolicy,
    Command,
    Durability,
    PregelExecutableTask,
    RetryPolicy,
    StreamMode,
)

V = TypeVar("V")
P = ParamSpec("P")

# A batch of (channel, value) writes produced by a single task.
WritesT = Sequence[tuple[str, Any]]


def DuplexStream(*streams: StreamProtocol) -> StreamProtocol:
    """Combine several streams into one that fans chunks out by mode.

    Each chunk is forwarded to every underlying stream whose ``modes``
    contains the chunk's mode (``value[1]``). The combined stream advertises
    the union of all underlying modes.
    """

    def __call__(value: StreamChunk) -> None:
        for stream in streams:
            if value[1] in stream.modes:
                stream(value)

    return StreamProtocol(__call__, {mode for s in streams for mode in s.modes})


class PregelLoop:
    """State and step logic shared by the sync and async Pregel loops.

    Subclasses supply the context-manager protocol (which loads the saved
    checkpoint and sets up ``submit``), the checkpointer bindings and the
    cache integration.
    """

    config: RunnableConfig
    store: BaseStore | None
    stream: StreamProtocol | None
    step: int
    stop: int

    input: Any | None
    cache: BaseCache[WritesT] | None
    checkpointer: BaseCheckpointSaver | None
    nodes: Mapping[str, PregelNode]
    specs: Mapping[str, BaseChannel | ManagedValueSpec]
    input_keys: str | Sequence[str]
    output_keys: str | Sequence[str]
    stream_keys: str | Sequence[str]
    skip_done_tasks: bool
    is_nested: bool
    manager: None | AsyncParentRunManager | ParentRunManager
    interrupt_after: All | Sequence[str]
    interrupt_before: All | Sequence[str]
    durability: Durability
    retry_policy: Sequence[RetryPolicy]
    cache_policy: CachePolicy | None

    checkpointer_get_next_version: GetNextVersion
    checkpointer_put_writes: Callable[[RunnableConfig, WritesT, str], Any] | None
    checkpointer_put_writes_accepts_task_path: bool
    # Called as (previous_future, config, checkpoint, metadata, new_versions);
    # set to None by subclasses when no checkpointer is configured.
    _checkpointer_put_after_previous: (
        Callable[
            [
                concurrent.futures.Future | None,
                RunnableConfig,
                Checkpoint,
                CheckpointMetadata,
                ChannelVersions,
            ],
            Any,
        ]
        | None
    )
    _migrate_checkpoint: Callable[[Checkpoint], None] | None
    submit: Submit
    channels: Mapping[str, BaseChannel]
    managed: ManagedValueMapping
    checkpoint: Checkpoint
    checkpoint_id_saved: str
    checkpoint_ns: tuple[str, ...]
    checkpoint_config: RunnableConfig
    checkpoint_metadata: CheckpointMetadata
    checkpoint_pending_writes: list[PendingWrite]
    checkpoint_previous_versions: dict[str, str | float | int]
    prev_checkpoint_config: RunnableConfig | None

    status: Literal[
        "input",
        "pending",
        "done",
        "interrupt_before",
        "interrupt_after",
        "out_of_steps",
    ]
    # NOTE: `tasks` is only assigned by `tick()`; code that may run earlier
    # guards access with `hasattr(self, "tasks")`.
    tasks: dict[str, PregelExecutableTask]
    output: None | dict[str, Any] | Any = None
    updated_channels: set[str] | None = None

    # public

    def __init__(
        self,
        input: Any | None,
        *,
        stream: StreamProtocol | None,
        config: RunnableConfig,
        store: BaseStore | None,
        cache: BaseCache | None,
        checkpointer: BaseCheckpointSaver | None,
        nodes: Mapping[str, PregelNode],
        specs: Mapping[str, BaseChannel | ManagedValueSpec],
        input_keys: str | Sequence[str],
        output_keys: str | Sequence[str],
        stream_keys: str | Sequence[str],
        trigger_to_nodes: Mapping[str, Sequence[str]],
        durability: Durability,
        interrupt_after: All | Sequence[str] = EMPTY_SEQ,
        interrupt_before: All | Sequence[str] = EMPTY_SEQ,
        manager: None | AsyncParentRunManager | ParentRunManager = None,
        migrate_checkpoint: Callable[[Checkpoint], None] | None = None,
        retry_policy: Sequence[RetryPolicy] = (),
        cache_policy: CachePolicy | None = None,
    ) -> None:
        """Store the loop's collaborators and derive the checkpoint config.

        Besides copying its arguments onto ``self``, this resolves the
        checkpoint namespace (``checkpoint_ns``) and ``checkpoint_config``
        from ``config[CONF]``. Nothing is loaded from the checkpointer here;
        that happens when the subclass context manager is entered.
        """
        self.stream = stream
        self.config = config
        self.store = store
        self.step = 0
        self.stop = 0
        self.input = input
        self.checkpointer = checkpointer
        self.cache = cache
        self.nodes = nodes
        self.specs = specs
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.stream_keys = stream_keys
        self.interrupt_after = interrupt_after
        self.interrupt_before = interrupt_before
        self.manager = manager
        # a task id in the config means this loop runs inside a parent task
        self.is_nested = CONFIG_KEY_TASK_ID in self.config.get(CONF, {})
        self.skip_done_tasks = CONFIG_KEY_CHECKPOINT_ID not in config[CONF]
        self._migrate_checkpoint = migrate_checkpoint
        self.trigger_to_nodes = trigger_to_nodes
        self.retry_policy = retry_policy
        self.cache_policy = cache_policy
        self.durability = durability
        # also forward chunks to a stream supplied through the config
        if self.stream is not None and CONFIG_KEY_STREAM in config[CONF]:
            self.stream = DuplexStream(self.stream, config[CONF][CONFIG_KEY_STREAM])
        scratchpad: PregelScratchpad | None = config[CONF].get(CONFIG_KEY_SCRATCHPAD)
        if isinstance(scratchpad, PregelScratchpad):
            # if count is > 0, append to checkpoint_ns
            # if count is 0, leave as is
            if cnt := scratchpad.subgraph_counter():
                self.config = patch_configurable(
                    self.config,
                    {
                        CONFIG_KEY_CHECKPOINT_NS: NS_SEP.join(
                            (
                                config[CONF][CONFIG_KEY_CHECKPOINT_NS],
                                str(cnt),
                            )
                        )
                    },
                )
        # a top-level graph ignores any namespace / checkpoint id it was given
        if not self.is_nested and config[CONF].get(CONFIG_KEY_CHECKPOINT_NS):
            self.config = patch_configurable(
                self.config,
                {CONFIG_KEY_CHECKPOINT_NS: "", CONFIG_KEY_CHECKPOINT_ID: None},
            )
        # when the checkpoint map has an entry for this namespace, use its id
        if (
            CONFIG_KEY_CHECKPOINT_MAP in self.config[CONF]
            and self.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS)
            in self.config[CONF][CONFIG_KEY_CHECKPOINT_MAP]
        ):
            self.checkpoint_config = patch_configurable(
                self.config,
                {
                    CONFIG_KEY_CHECKPOINT_ID: self.config[CONF][
                        CONFIG_KEY_CHECKPOINT_MAP
                    ][self.config[CONF][CONFIG_KEY_CHECKPOINT_NS]]
                },
            )
        else:
            self.checkpoint_config = self.config
        # normalize non-string thread ids to str
        if thread_id := self.checkpoint_config[CONF].get(CONFIG_KEY_THREAD_ID):
            if not isinstance(thread_id, str):
                self.checkpoint_config = patch_configurable(
                    self.checkpoint_config,
                    {CONFIG_KEY_THREAD_ID: str(thread_id)},
                )
        self.checkpoint_ns = (
            tuple(cast(str, self.config[CONF][CONFIG_KEY_CHECKPOINT_NS]).split(NS_SEP))
            if self.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS)
            else ()
        )
        self.prev_checkpoint_config = None

    def put_writes(self, task_id: str, writes: WritesT) -> None:
        """Put writes for a task, to be read by the next tick.

        The writes replace any pending writes already recorded for
        ``task_id`` (for ``NULL_TASK_ID`` they are accumulated instead), are
        submitted to the checkpointer unless durability is ``"exit"``, and are
        emitted to the stream once ``tasks`` exists. Empty ``writes`` is a
        no-op.
        """
        if not writes:
            return
        # deduplicate writes to special channels, last write wins
        if all(w[0] in WRITES_IDX_MAP for w in writes):
            writes = list({w[0]: w for w in writes}.values())
        if task_id == NULL_TASK_ID:
            # writes for the null task are accumulated
            self.checkpoint_pending_writes = [
                w
                for w in self.checkpoint_pending_writes
                if w[0] != task_id or w[1] not in WRITES_IDX_MAP
            ]
            writes_to_save: WritesT = [
                w[1:] for w in self.checkpoint_pending_writes if w[0] == task_id
            ] + list(writes)
        else:
            # remove existing writes for this task
            self.checkpoint_pending_writes = [
                w for w in self.checkpoint_pending_writes if w[0] != task_id
            ]
            writes_to_save = writes
        # save writes
        self.checkpoint_pending_writes.extend((task_id, c, v) for c, v in writes)
        if self.durability != "exit" and self.checkpointer_put_writes is not None:
            config = patch_configurable(
                self.checkpoint_config,
                {
                    CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get(
                        CONFIG_KEY_CHECKPOINT_NS, ""
                    ),
                    CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
                },
            )
            if self.checkpointer_put_writes_accepts_task_path:
                if hasattr(self, "tasks"):
                    task = self.tasks.get(task_id)
                else:
                    task = None
                self.submit(
                    self.checkpointer_put_writes,
                    config,
                    writes_to_save,
                    task_id,
                    task_path_str(task.path) if task else "",
                )
            else:
                self.submit(
                    self.checkpointer_put_writes,
                    config,
                    writes_to_save,
                    task_id,
                )
        # output writes
        if hasattr(self, "tasks"):
            self.output_writes(task_id, writes)

    def _put_pending_writes(self) -> None:
        """Submit all pending writes to the checkpointer, grouped by task id.

        Does nothing when there is no ``checkpointer_put_writes`` or no
        pending writes.
        """
        if self.checkpointer_put_writes is None:
            return
        if not self.checkpoint_pending_writes:
            return
        # patch config
        config = patch_configurable(
            self.checkpoint_config,
            {
                CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get(
                    CONFIG_KEY_CHECKPOINT_NS, ""
                ),
                CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
            },
        )
        # group by task id
        by_task = defaultdict(list)
        for task_id, channel, value in self.checkpoint_pending_writes:
            by_task[task_id].append((channel, value))
        # submit writes to checkpointer
        for task_id, writes in by_task.items():
            if self.checkpointer_put_writes_accepts_task_path and hasattr(
                self, "tasks"
            ):
                task = self.tasks.get(task_id)
                self.submit(
                    self.checkpointer_put_writes,
                    config,
                    writes,
                    task_id,
                    task_path_str(task.path) if task else "",
                )
            else:
                self.submit(
                    self.checkpointer_put_writes,
                    config,
                    writes,
                    task_id,
                )

    def accept_push(
        self, task: PregelExecutableTask, write_idx: int, call: Call | None = None
    ) -> PregelExecutableTask | None:
        """Accept a PUSH from a task, potentially returning a new task to start."""
        checkpoint_id_bytes = binascii.unhexlify(self.checkpoint["id"].replace("-", ""))
        null_version = checkpoint_null_version(self.checkpoint)
        if pushed := cast(
            Optional[PregelExecutableTask],
            prepare_single_task(
                (PUSH, task.path, write_idx, task.id, call),
                None,
                checkpoint=self.checkpoint,
                checkpoint_id_bytes=checkpoint_id_bytes,
                checkpoint_null_version=null_version,
                pending_writes=self.checkpoint_pending_writes,
                processes=self.nodes,
                channels=self.channels,
                managed=self.managed,
                config=task.config,
                step=self.step,
                stop=self.stop,
                for_execution=True,
                store=self.store,
                checkpointer=self.checkpointer,
                manager=self.manager,
                retry_policy=self.retry_policy,
                cache_policy=self.cache_policy,
            ),
        ):
            # produce debug output
            self._emit("tasks", map_debug_tasks, [pushed])
            # save the new task
            self.tasks[pushed.id] = pushed
            # match any pending writes to the new task
            if self.skip_done_tasks:
                self._match_writes({pushed.id: pushed})
            # return the new task, to be started if not run before
            return pushed

    def tick(self) -> bool:
        """Execute a single iteration of the Pregel loop.

        Prepares the tasks for the current step, emits debug output, and
        re-attaches pending writes from a previous run to matching tasks.

        Returns:
            True if more iterations are needed. False when the step limit is
            exceeded (``status == "out_of_steps"``) or no tasks remain
            (``status == "done"``).

        Raises:
            GraphInterrupt: If ``interrupt_before`` matches the prepared tasks.
        """
        # check if iteration limit is reached
        if self.step > self.stop:
            self.status = "out_of_steps"
            return False

        # prepare next tasks
        self.tasks = prepare_next_tasks(
            self.checkpoint,
            self.checkpoint_pending_writes,
            self.nodes,
            self.channels,
            self.managed,
            self.config,
            self.step,
            self.stop,
            for_execution=True,
            manager=self.manager,
            store=self.store,
            checkpointer=self.checkpointer,
            trigger_to_nodes=self.trigger_to_nodes,
            updated_channels=self.updated_channels,
            retry_policy=self.retry_policy,
            cache_policy=self.cache_policy,
        )

        # produce debug output
        if self._checkpointer_put_after_previous is not None:
            self._emit(
                "checkpoints",
                map_debug_checkpoint,
                {
                    **self.checkpoint_config,
                    CONF: {
                        **self.checkpoint_config[CONF],
                        CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
                    },
                },
                self.channels,
                self.stream_keys,
                self.checkpoint_metadata,
                self.tasks.values(),
                self.checkpoint_pending_writes,
                self.prev_checkpoint_config,
                self.output_keys,
            )

        # if no more tasks, we're done
        if not self.tasks:
            self.status = "done"
            return False

        # if there are pending writes from a previous loop, apply them
        if self.skip_done_tasks and self.checkpoint_pending_writes:
            self._match_writes(self.tasks)

        # before execution, check if we should interrupt
        if self.interrupt_before and should_interrupt(
            self.checkpoint, self.interrupt_before, self.tasks.values()
        ):
            self.status = "interrupt_before"
            raise GraphInterrupt()

        # produce debug output
        self._emit("tasks", map_debug_tasks, self.tasks.values())

        # print output for any tasks we applied previous writes to
        for task in self.tasks.values():
            if task.writes:
                self.output_writes(task.id, task.writes, cached=True)

        return True

    def after_tick(self) -> None:
        """Finish the current superstep after all of its tasks have run.

        Applies the tasks' writes to the channels, emits ``"values"`` output
        when an output channel was updated, clears pending writes and saves a
        ``"loop"`` checkpoint.

        Raises:
            GraphInterrupt: If ``interrupt_after`` matches the finished tasks.
        """
        # finish superstep
        writes = [w for t in self.tasks.values() for w in t.writes]
        # all tasks have finished
        self.updated_channels = apply_writes(
            self.checkpoint,
            self.channels,
            self.tasks.values(),
            self.checkpointer_get_next_version,
            self.trigger_to_nodes,
        )
        # produce values output
        if not self.updated_channels.isdisjoint(
            (self.output_keys,)
            if isinstance(self.output_keys, str)
            else self.output_keys
        ):
            self._emit(
                "values", map_output_values, self.output_keys, writes, self.channels
            )
        # clear pending writes
        self.checkpoint_pending_writes.clear()
        # "not skip_done_tasks" only applies to first tick after resuming
        self.skip_done_tasks = True
        # save checkpoint
        self._put_checkpoint({"source": "loop"})
        # after execution, check if we should interrupt
        if self.interrupt_after and should_interrupt(
            self.checkpoint, self.interrupt_after, self.tasks.values()
        ):
            self.status = "interrupt_after"
            raise GraphInterrupt()
        # unset resuming flag
        self.config[CONF].pop(CONFIG_KEY_RESUMING, None)

    def match_cached_writes(self) -> Sequence[PregelExecutableTask]:
        """Fill tasks from the cache; implemented by the sync subclass."""
        raise NotImplementedError

    async def amatch_cached_writes(self) -> Sequence[PregelExecutableTask]:
        """Fill tasks from the cache; implemented by the async subclass."""
        raise NotImplementedError

    # private

    def _match_writes(self, tasks: Mapping[str, PregelExecutableTask]) -> None:
        """Append pending writes to the task with the matching id.

        ERROR, INTERRUPT and RESUME writes are not re-attached.
        """
        for tid, k, v in self.checkpoint_pending_writes:
            if k in (ERROR, INTERRUPT, RESUME):
                continue
            if task := tasks.get(tid):
                task.writes.append((k, v))

    def _first(self, *, input_keys: str | Sequence[str]) -> set[str] | None:
        """Process the loop input before the first tick.

        Handles ``Command`` input, resuming from a previous checkpoint, or
        fresh input mapped through ``input_keys``, then sets ``status`` to
        ``"pending"``.

        Returns:
            The channels updated by input writes, or None when no input
            writes were applied.

        Raises:
            RuntimeError: If ``Command(resume=...)`` is used without a
                checkpointer.
            EmptyInputError: If the command or the input produced no writes
                and the loop is not resuming.
        """
        # resuming from previous checkpoint requires
        # - finding a previous checkpoint
        # - receiving None input (outer graph) or RESUMING flag (subgraph)
        configurable = self.config.get(CONF, {})
        is_resuming = bool(self.checkpoint["channel_versions"]) and bool(
            configurable.get(
                CONFIG_KEY_RESUMING,
                self.input is None
                or isinstance(self.input, Command)
                or (
                    not self.is_nested
                    and self.config.get("metadata", {}).get("run_id")
                    == self.checkpoint_metadata.get("run_id", MISSING)
                ),
            )
        )

        # this can be set only when there are input_writes
        updated_channels: set[str] | None = None

        # map command to writes
        if isinstance(self.input, Command):
            if resume_is_map := (
                (resume := self.input.resume) is not None
                and isinstance(resume, dict)
                and all(is_xxh3_128_hexdigest(k) for k in resume)
            ):
                self.config[CONF][CONFIG_KEY_RESUME_MAP] = self.input.resume
            if resume is not None and not self.checkpointer:
                raise RuntimeError(
                    "Cannot use Command(resume=...) without checkpointer"
                )
            writes: defaultdict[str, list[tuple[str, Any]]] = defaultdict(list)
            # group writes by task ID
            for tid, c, v in map_command(cmd=self.input):
                if not (c == RESUME and resume_is_map):
                    writes[tid].append((c, v))
            if not writes and not resume_is_map:
                raise EmptyInputError("Received empty Command input")
            # save writes
            for tid, ws in writes.items():
                self.put_writes(tid, ws)
        # apply NULL writes
        if null_writes := [
            w[1:] for w in self.checkpoint_pending_writes if w[0] == NULL_TASK_ID
        ]:
            apply_writes(
                self.checkpoint,
                self.channels,
                [PregelTaskWrites((), INPUT, null_writes, [])],
                self.checkpointer_get_next_version,
                self.trigger_to_nodes,
            )
        # proceed past previous checkpoint
        if is_resuming:
            self.checkpoint["versions_seen"].setdefault(INTERRUPT, {})
            for k in self.channels:
                if k in self.checkpoint["channel_versions"]:
                    version = self.checkpoint["channel_versions"][k]
                    self.checkpoint["versions_seen"][INTERRUPT][k] = version
            # produce values output
            self._emit(
                "values", map_output_values, self.output_keys, True, self.channels
            )
        # map inputs to channel updates
        elif input_writes := deque(map_input(input_keys, self.input)):
            # discard any unfinished tasks from previous checkpoint
            discard_tasks = prepare_next_tasks(
                self.checkpoint,
                self.checkpoint_pending_writes,
                self.nodes,
                self.channels,
                self.managed,
                self.config,
                self.step,
                self.stop,
                for_execution=True,
                store=None,
                checkpointer=None,
                manager=None,
            )
            # apply input writes
            updated_channels = apply_writes(
                self.checkpoint,
                self.channels,
                [
                    *discard_tasks.values(),
                    PregelTaskWrites((), INPUT, input_writes, []),
                ],
                self.checkpointer_get_next_version,
                self.trigger_to_nodes,
            )
            # save input checkpoint
            self._put_checkpoint({"source": "input"})
        elif CONFIG_KEY_RESUMING not in configurable:
            raise EmptyInputError(f"Received no input for {input_keys}")
        # update config
        if not self.is_nested:
            self.config = patch_configurable(
                self.config, {CONFIG_KEY_RESUMING: is_resuming}
            )
        # set flag
        self.status = "pending"
        return updated_channels

    def _put_checkpoint(self, metadata: CheckpointMetadata) -> None:
        """Create the next checkpoint and, when applicable, submit it for saving.

        Passing ``self.checkpoint_metadata`` itself marks an "exiting" save:
        the checkpoint id is kept, the step is not incremented, and nothing is
        done if that id was the one loaded at entry.
        """
        # assign step and parents
        exiting = metadata is self.checkpoint_metadata
        if exiting and self.checkpoint["id"] == self.checkpoint_id_saved:
            # checkpoint already saved
            return
        if not exiting:
            metadata["step"] = self.step
            metadata["parents"] = self.config[CONF].get(CONFIG_KEY_CHECKPOINT_MAP, {})
            self.checkpoint_metadata = metadata
        # do checkpoint?
        do_checkpoint = self._checkpointer_put_after_previous is not None and (
            exiting or self.durability != "exit"
        )
        # create new checkpoint
        self.checkpoint = create_checkpoint(
            self.checkpoint,
            self.channels if do_checkpoint else None,
            self.step,
            id=self.checkpoint["id"] if exiting else None,
        )
        # bail if no checkpointer
        if do_checkpoint and self._checkpointer_put_after_previous is not None:
            self.prev_checkpoint_config = (
                self.checkpoint_config
                if CONFIG_KEY_CHECKPOINT_ID in self.checkpoint_config[CONF]
                and self.checkpoint_config[CONF][CONFIG_KEY_CHECKPOINT_ID]
                else None
            )
            self.checkpoint_config = {
                **self.checkpoint_config,
                CONF: {
                    **self.checkpoint_config[CONF],
                    CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get(
                        CONFIG_KEY_CHECKPOINT_NS, ""
                    ),
                },
            }

            channel_versions = self.checkpoint["channel_versions"].copy()
            new_versions = get_new_channel_versions(
                self.checkpoint_previous_versions, channel_versions
            )
            self.checkpoint_previous_versions = channel_versions

            # save it, without blocking
            # if there's a previous checkpoint save in progress, wait for it
            # ensuring checkpointers receive checkpoints in order
            self._put_checkpoint_fut = self.submit(
                self._checkpointer_put_after_previous,
                getattr(self, "_put_checkpoint_fut", None),
                self.checkpoint_config,
                copy_checkpoint(self.checkpoint),
                self.checkpoint_metadata,
                new_versions,
            )
            self.checkpoint_config = {
                **self.checkpoint_config,
                CONF: {
                    **self.checkpoint_config[CONF],
                    CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
                },
            }
        if not exiting:
            # increment step
            self.step += 1

    def _suppress_interrupt(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        """Exit callback: persist state, record output, swallow interrupts.

        Returns:
            True (suppressing the exception) when a ``GraphInterrupt`` reaches
            a non-nested loop; otherwise None, so any exception propagates.
        """
        # persist current checkpoint and writes
        if self.durability == "exit" and (
            # if it's a top graph
            not self.is_nested
            # or a nested graph with error or interrupt
            or exc_value is not None
            # or a nested graph with checkpointer=True
            or all(NS_END not in part for part in self.checkpoint_ns)
        ):
            self._put_checkpoint(self.checkpoint_metadata)
            self._put_pending_writes()
        # suppress interrupt
        suppress = isinstance(exc_value, GraphInterrupt) and not self.is_nested
        if suppress:
            # emit one last "values" event, with pending writes applied
            if (
                hasattr(self, "tasks")
                and self.checkpoint_pending_writes
                and any(task.writes for task in self.tasks.values())
            ):
                updated_channels = apply_writes(
                    self.checkpoint,
                    self.channels,
                    self.tasks.values(),
                    self.checkpointer_get_next_version,
                    self.trigger_to_nodes,
                )
                if not updated_channels.isdisjoint(
                    (self.output_keys,)
                    if isinstance(self.output_keys, str)
                    else self.output_keys
                ):
                    self._emit(
                        "values",
                        map_output_values,
                        self.output_keys,
                        [w for t in self.tasks.values() for w in t.writes],
                        self.channels,
                    )
            # emit INTERRUPT if exception is empty (otherwise emitted by put_writes)
            if exc_value is not None and (not exc_value.args or not exc_value.args[0]):
                self._emit(
                    "updates",
                    lambda: iter(
                        [{INTERRUPT: cast(GraphInterrupt, exc_value).args[0]}]
                    ),
                )
            # save final output
            self.output = read_channels(self.channels, self.output_keys)
            # suppress interrupt
            return True
        elif exc_type is None:
            # save final output
            self.output = read_channels(self.channels, self.output_keys)

    def _emit(
        self,
        mode: StreamMode,
        values: Callable[P, Iterator[Any]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        """Stream every item yielded by ``values(*args, **kwargs)`` under ``mode``.

        ``values`` is only called when the stream subscribes to ``mode``, or to
        ``"debug"`` while ``mode`` is ``"checkpoints"`` or ``"tasks"``; in the
        latter case each item is also sent wrapped in a debug envelope.
        """
        if self.stream is None:
            return
        debug_remap = mode in ("checkpoints", "tasks") and "debug" in self.stream.modes
        if mode not in self.stream.modes and not debug_remap:
            return
        for v in values(*args, **kwargs):
            if mode in self.stream.modes:
                self.stream((self.checkpoint_ns, mode, v))
            # "debug" mode is "checkpoints" or "tasks" with a wrapper dict
            if debug_remap:
                self.stream(
                    (
                        self.checkpoint_ns,
                        "debug",
                        {
                            "step": self.step - 1
                            if mode == "checkpoints"
                            else self.step,
                            "timestamp": datetime.now(timezone.utc).isoformat(),
                            "type": "checkpoint"
                            if mode == "checkpoints"
                            else "task_result"
                            if "result" in v
                            else "task",
                            "payload": v,
                        },
                    )
                )

    def output_writes(
        self, task_id: str, writes: WritesT, *, cached: bool = False
    ) -> None:
        """Emit ``"updates"`` / ``"tasks"`` stream output for a task's writes.

        Nothing is emitted for unknown task ids or for tasks tagged with
        ``TAG_HIDDEN``. ``writes`` must be non-empty (its first entry is
        inspected).
        """
        if task := self.tasks.get(task_id):
            if task.config is not None and TAG_HIDDEN in task.config.get(
                "tags", EMPTY_SEQ
            ):
                return
            if writes[0][0] == INTERRUPT:
                # in loop.py we append a bool to the PUSH task paths to indicate
                # whether or not a call was present. If so,
                # we don't emit the interrupt as it'll be emitted by the parent
                if task.path[0] == PUSH and task.path[-1] is True:
                    return
                interrupts = [
                    {
                        INTERRUPT: tuple(
                            v
                            for w in writes
                            if w[0] == INTERRUPT
                            for v in (w[1] if isinstance(w[1], Sequence) else (w[1],))
                        )
                    }
                ]
                self._emit("updates", lambda: iter(interrupts))
            elif writes[0][0] != ERROR:
                self._emit(
                    "updates",
                    map_output_updates,
                    self.output_keys,
                    [(task, writes)],
                    cached,
                )
            if not cached:
                self._emit(
                    "tasks",
                    map_debug_task_results,
                    (task, writes),
                    self.stream_keys,
                )


class SyncPregelLoop(PregelLoop, AbstractContextManager):
    """Synchronous Pregel loop, used as a ``with`` context manager."""

    def __init__(
        self,
        input: Any | None,
        *,
        stream: StreamProtocol | None,
        config: RunnableConfig,
        store: BaseStore | None,
        cache: BaseCache | None,
        checkpointer: BaseCheckpointSaver | None,
        nodes: Mapping[str, PregelNode],
        specs: Mapping[str, BaseChannel | ManagedValueSpec],
        trigger_to_nodes: Mapping[str, Sequence[str]],
        durability: Durability,
        manager: None | AsyncParentRunManager | ParentRunManager = None,
        interrupt_after: All | Sequence[str] = EMPTY_SEQ,
        interrupt_before: All | Sequence[str] = EMPTY_SEQ,
        input_keys: str | Sequence[str] = EMPTY_SEQ,
        output_keys: str | Sequence[str] = EMPTY_SEQ,
        stream_keys: str | Sequence[str] = EMPTY_SEQ,
        migrate_checkpoint: Callable[[Checkpoint], None] | None = None,
        retry_policy: Sequence[RetryPolicy] = (),
        cache_policy: CachePolicy | None = None,
    ) -> None:
        """Initialise the base loop and bind the sync checkpointer methods."""
        super().__init__(
            input,
            stream=stream,
            config=config,
            checkpointer=checkpointer,
            cache=cache,
            store=store,
            nodes=nodes,
            specs=specs,
            input_keys=input_keys,
            output_keys=output_keys,
            stream_keys=stream_keys,
            interrupt_after=interrupt_after,
            interrupt_before=interrupt_before,
            manager=manager,
            migrate_checkpoint=migrate_checkpoint,
            trigger_to_nodes=trigger_to_nodes,
            retry_policy=retry_policy,
            cache_policy=cache_policy,
            durability=durability,
        )
        self.stack = ExitStack()
        if checkpointer:
            self.checkpointer_get_next_version = checkpointer.get_next_version
            self.checkpointer_put_writes = checkpointer.put_writes
            # older checkpointers may not accept the `task_path` argument
            self.checkpointer_put_writes_accepts_task_path = (
                signature(checkpointer.put_writes).parameters.get("task_path")
                is not None
            )
        else:
            self.checkpointer_get_next_version = increment
            # shadow the method with None so the base class skips checkpointing
            self._checkpointer_put_after_previous = None  # type: ignore[assignment]
            self.checkpointer_put_writes = None
            self.checkpointer_put_writes_accepts_task_path = False

    def _checkpointer_put_after_previous(
        self,
        prev: concurrent.futures.Future | None,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> None:
        """Save ``checkpoint`` once the previous save (if any) has finished.

        The save runs in ``finally``, so it is attempted even if the previous
        save raised; that earlier exception then propagates after the save.
        """
        try:
            if prev is not None:
                prev.result()
        finally:
            cast(BaseCheckpointSaver, self.checkpointer).put(
                config, checkpoint, metadata, new_versions
            )

    def match_cached_writes(self) -> Sequence[PregelExecutableTask]:
        """Fill tasks that have a cache key and no writes from the cache.

        Returns:
            The tasks whose writes were populated from a cache hit; empty when
            no cache is configured.
        """
        if self.cache is None:
            return ()
        matched: list[PregelExecutableTask] = []
        if cached := {
            (t.cache_key.ns, t.cache_key.key): t
            for t in self.tasks.values()
            if t.cache_key and not t.writes
        }:
            for key, values in self.cache.get(tuple(cached)).items():
                task = cached[key]
                task.writes.extend(values)
                matched.append(task)
        return matched

    def accept_push(
        self, task: PregelExecutableTask, write_idx: int, call: Call | None = None
    ) -> PregelExecutableTask | None:
        """Accept a PUSH from a task, then emit output for any cache hits."""
        if pushed := super().accept_push(task, write_idx, call):
            for cached_task in self.match_cached_writes():
                self.output_writes(cached_task.id, cached_task.writes, cached=True)
        return pushed

    def put_writes(self, task_id: str, writes: WritesT) -> None:
        """Put writes for a task, to be read by the next tick.

        After recording the writes, stores the task's writes in the cache when
        the task has a cache key and the writes are not an interrupt or error.
        """
        super().put_writes(task_id, writes)
        if not writes or self.cache is None or not hasattr(self, "tasks"):
            return
        task = self.tasks.get(task_id)
        if task is None or task.cache_key is None:
            return
        if writes[0][0] in (INTERRUPT, ERROR):
            # only cache successful tasks (same rule as AsyncPregelLoop)
            return
        self.submit(
            self.cache.set,
            {
                (task.cache_key.ns, task.cache_key.key): (
                    task.writes,
                    task.cache_key.ttl,
                )
            },
        )

    # context manager

    def __enter__(self) -> Self:
        """Load the saved checkpoint (or start empty) and process the input."""
        if self.checkpointer:
            saved = self.checkpointer.get_tuple(self.checkpoint_config)
        else:
            saved = None
        if saved is None:
            # step -2 makes the first (input) checkpoint step -1
            saved = CheckpointTuple(
                self.checkpoint_config, empty_checkpoint(), {"step": -2}, None, []
            )
        elif self._migrate_checkpoint is not None:
            self._migrate_checkpoint(saved.checkpoint)
        self.checkpoint_config = {
            **self.checkpoint_config,
            **saved.config,
            CONF: {
                CONFIG_KEY_CHECKPOINT_NS: "",
                **self.checkpoint_config.get(CONF, {}),
                **saved.config.get(CONF, {}),
            },
        }
        self.prev_checkpoint_config = saved.parent_config
        self.checkpoint_id_saved = saved.checkpoint["id"]
        self.checkpoint = saved.checkpoint
        self.checkpoint_metadata = saved.metadata
        self.checkpoint_pending_writes = (
            [(str(tid), k, v) for tid, k, v in saved.pending_writes]
            if saved.pending_writes is not None
            else []
        )

        self.submit = self.stack.enter_context(BackgroundExecutor(self.config))
        self.channels, self.managed = channels_from_checkpoint(
            self.specs, self.checkpoint
        )
        # pushed after the executor, so it runs before the executor exits
        self.stack.push(self._suppress_interrupt)
        self.status = "input"
        self.step = self.checkpoint_metadata["step"] + 1
        self.stop = self.step + self.config["recursion_limit"] + 1
        self.checkpoint_previous_versions = self.checkpoint["channel_versions"].copy()
        self.updated_channels = self._first(input_keys=self.input_keys)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        """Unwind the exit stack, returning its suppression decision."""
        # unwind stack
        return self.stack.__exit__(exc_type, exc_value, traceback)


class AsyncPregelLoop(PregelLoop, AbstractAsyncContextManager):
    """Asynchronous Pregel loop, used as an ``async with`` context manager."""

    def __init__(
        self,
        input: Any | None,
        *,
        stream: StreamProtocol | None,
        config: RunnableConfig,
        store: BaseStore | None,
        cache: BaseCache | None,
        checkpointer: BaseCheckpointSaver | None,
        nodes: Mapping[str, PregelNode],
        specs: Mapping[str, BaseChannel | ManagedValueSpec],
        trigger_to_nodes: Mapping[str, Sequence[str]],
        durability: Durability,
        interrupt_after: All | Sequence[str] = EMPTY_SEQ,
        interrupt_before: All | Sequence[str] = EMPTY_SEQ,
        manager: None | AsyncParentRunManager | ParentRunManager = None,
        input_keys: str | Sequence[str] = EMPTY_SEQ,
        output_keys: str | Sequence[str] = EMPTY_SEQ,
        stream_keys: str | Sequence[str] = EMPTY_SEQ,
        migrate_checkpoint: Callable[[Checkpoint], None] | None = None,
        retry_policy: Sequence[RetryPolicy] = (),
        cache_policy: CachePolicy | None = None,
    ) -> None:
        """Initialise the base loop and bind the async checkpointer methods."""
        super().__init__(
            input,
            stream=stream,
            config=config,
            checkpointer=checkpointer,
            cache=cache,
            store=store,
            nodes=nodes,
            specs=specs,
            input_keys=input_keys,
            output_keys=output_keys,
            stream_keys=stream_keys,
            interrupt_after=interrupt_after,
            interrupt_before=interrupt_before,
            manager=manager,
            migrate_checkpoint=migrate_checkpoint,
            trigger_to_nodes=trigger_to_nodes,
            retry_policy=retry_policy,
            cache_policy=cache_policy,
            durability=durability,
        )
        self.stack = AsyncExitStack()
        if checkpointer:
            self.checkpointer_get_next_version = checkpointer.get_next_version
            self.checkpointer_put_writes = checkpointer.aput_writes
            # older checkpointers may not accept the `task_path` argument
            self.checkpointer_put_writes_accepts_task_path = (
                signature(checkpointer.aput_writes).parameters.get("task_path")
                is not None
            )
        else:
            self.checkpointer_get_next_version = increment
            # shadow the method with None so the base class skips checkpointing
            self._checkpointer_put_after_previous = None  # type: ignore[assignment]
            self.checkpointer_put_writes = None
            self.checkpointer_put_writes_accepts_task_path = False

    async def _checkpointer_put_after_previous(
        self,
        prev: asyncio.Task | None,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> None:
        """Save ``checkpoint`` once the previous save (if any) has finished.

        The save runs in ``finally``, so it is attempted even if awaiting the
        previous save raised.
        """
        try:
            if prev is not None:
                await prev
        finally:
            await cast(BaseCheckpointSaver, self.checkpointer).aput(
                config, checkpoint, metadata, new_versions
            )

    async def amatch_cached_writes(self) -> Sequence[PregelExecutableTask]:
        """Fill tasks that have a cache key and no writes from the cache.

        Returns:
            The tasks whose writes were populated from a cache hit; empty when
            no cache is configured.
        """
        if self.cache is None:
            return []
        matched: list[PregelExecutableTask] = []
        if cached := {
            (t.cache_key.ns, t.cache_key.key): t
            for t in self.tasks.values()
            if t.cache_key and not t.writes
        }:
            for key, values in (await self.cache.aget(tuple(cached))).items():
                task = cached[key]
                task.writes.extend(values)
                matched.append(task)
        return matched

    async def aaccept_push(
        self, task: PregelExecutableTask, write_idx: int, call: Call | None = None
    ) -> PregelExecutableTask | None:
        """Accept a PUSH from a task, then emit output for any cache hits."""
        if pushed := super().accept_push(task, write_idx, call):
            for cached_task in await self.amatch_cached_writes():
                self.output_writes(cached_task.id, cached_task.writes, cached=True)
        return pushed

    def put_writes(self, task_id: str, writes: WritesT) -> None:
        """Put writes for a task, to be read by the next tick.

        After recording the writes, stores the task's writes in the cache when
        the task has a cache key and the writes are not an interrupt or error.
        """
        super().put_writes(task_id, writes)
        if not writes or self.cache is None or not hasattr(self, "tasks"):
            return
        task = self.tasks.get(task_id)
        if task is None or task.cache_key is None:
            return
        if writes[0][0] in (INTERRUPT, ERROR):
            # only cache successful tasks
            return
        self.submit(
            self.cache.aset,
            {
                (task.cache_key.ns, task.cache_key.key): (
                    task.writes,
                    task.cache_key.ttl,
                )
            },
        )

    # context manager

    async def __aenter__(self) -> Self:
        """Load the saved checkpoint (or start empty) and process the input."""
        if self.checkpointer:
            saved = await self.checkpointer.aget_tuple(self.checkpoint_config)
        else:
            saved = None
        if saved is None:
            # step -2 makes the first (input) checkpoint step -1
            saved = CheckpointTuple(
                self.checkpoint_config, empty_checkpoint(), {"step": -2}, None, []
            )
        elif self._migrate_checkpoint is not None:
            self._migrate_checkpoint(saved.checkpoint)
        self.checkpoint_config = {
            **self.checkpoint_config,
            **saved.config,
            CONF: {
                CONFIG_KEY_CHECKPOINT_NS: "",
                **self.checkpoint_config.get(CONF, {}),
                **saved.config.get(CONF, {}),
            },
        }
        self.prev_checkpoint_config = saved.parent_config
        self.checkpoint_id_saved = saved.checkpoint["id"]
        self.checkpoint = saved.checkpoint
        self.checkpoint_metadata = saved.metadata
        self.checkpoint_pending_writes = (
            [(str(tid), k, v) for tid, k, v in saved.pending_writes]
            if saved.pending_writes is not None
            else []
        )

        self.submit = await self.stack.enter_async_context(
            AsyncBackgroundExecutor(self.config)
        )
        self.channels, self.managed = channels_from_checkpoint(
            self.specs, self.checkpoint
        )
        # pushed after the executor, so it runs before the executor exits
        self.stack.push(self._suppress_interrupt)
        self.status = "input"
        self.step = self.checkpoint_metadata["step"] + 1
        self.stop = self.step + self.config["recursion_limit"] + 1
        self.checkpoint_previous_versions = self.checkpoint["channel_versions"].copy()
        self.updated_channels = self._first(input_keys=self.input_keys)

        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        """Unwind the exit stack in its own task.

        If this coroutine is cancelled while waiting, the exit task is
        appended to the ``CancelledError`` args before re-raising.
        """
        # unwind stack
        exit_task = asyncio.create_task(
            self.stack.__aexit__(exc_type, exc_value, traceback)
        )
        try:
            return await exit_task
        except asyncio.CancelledError as e:
            # Bubble up the exit task upon cancellation to permit the API
            # consumer to await it before e.g., reusing the DB connection.
            e.args = (*e.args, exit_task)
            raise

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/kaman05010/MCPClientServer'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.