from __future__ import annotations
import asyncio
import logging
import random
import sys
import time
from collections.abc import Awaitable, Sequence
from dataclasses import replace
from typing import Any, Callable
from langgraph._internal._config import patch_configurable
from langgraph._internal._constants import (
CONF,
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_RESUMING,
NS_SEP,
)
from langgraph.errors import GraphBubbleUp, ParentCommand
from langgraph.types import Command, PregelExecutableTask, RetryPolicy

logger = logging.getLogger(__name__)
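
# Exception.add_note (PEP 678) is available only on Python 3.11+.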
SUPPORTS_EXC_NOTES = sys.version_info >= (3, 11)


def run_with_retry(
task: PregelExecutableTask,
retry_policy: Sequence[RetryPolicy] | None,
configurable: dict[str, Any] | None = None,
) -> None:
"""Run a task with retries."""
retry_policy = task.retry_policy or retry_policy
attempts = 0
config = task.config
if configurable is not None:
config = patch_configurable(config, configurable)
while True:
try:
# clear any writes from previous attempts
task.writes.clear()
# run the task
return task.proc.invoke(task.input, config)
except ParentCommand as exc:
ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS]
cmd = exc.args[0]
if cmd.graph in (ns, task.name):
# this command is for the current graph, handle it
for w in task.writers:
w.invoke(cmd, config)
break
elif cmd.graph == Command.PARENT:
# this command is for the parent graph, assign it to the parent
parts = ns.split(NS_SEP)
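                # drop a trailing numeric (index) segment so it is not counted
                # as a namespace level when resolving the parent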
if parts[-1].isdigit():
parts.pop()
parent_ns = NS_SEP.join(parts[:-1])
exc.args = (replace(cmd, graph=parent_ns),)
# bubble up
raise
except GraphBubbleUp:
# if interrupted, end
raise
except Exception as exc:
if SUPPORTS_EXC_NOTES:
exc.add_note(f"During task with name '{task.name}' and id '{task.id}'")
if not retry_policy:
raise
# Check which retry policy applies to this exception
matching_policy = None
for policy in retry_policy:
if _should_retry_on(policy, exc):
matching_policy = policy
break
if not matching_policy:
raise
# increment attempts
attempts += 1
# check if we should give up
if attempts >= matching_policy.max_attempts:
raise
# sleep before retrying
interval = matching_policy.initial_interval
# Apply backoff factor based on attempt count
interval = min(
matching_policy.max_interval,
interval * (matching_policy.backoff_factor ** (attempts - 1)),
)
# Apply jitter if configured
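            # (adds up to 1 second of uniform random delay to desynchronize
            # concurrent retries)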
sleep_time = (
interval + random.uniform(0, 1) if matching_policy.jitter else interval
)
time.sleep(sleep_time)
# log the retry
logger.info(
f"Retrying task {task.name} after {sleep_time:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
exc_info=exc,
)
# signal subgraphs to resume (if available)
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})
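

# A minimal usage sketch (illustrative only; in practice `task` is produced by
# the Pregel loop rather than constructed by hand):
#
#     policy = RetryPolicy(max_attempts=3, retry_on=(TimeoutError,))
#     run_with_retry(task, [policy])
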
async def arun_with_retry(
task: PregelExecutableTask,
retry_policy: Sequence[RetryPolicy] | None,
stream: bool = False,
match_cached_writes: Callable[[], Awaitable[Sequence[PregelExecutableTask]]]
| None = None,
configurable: dict[str, Any] | None = None,
) -> None:
"""Run a task asynchronously with retries."""
retry_policy = task.retry_policy or retry_policy
attempts = 0
config = task.config
if configurable is not None:
config = patch_configurable(config, configurable)
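    # if this task's writes can be restored from the cache, skip running it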
if match_cached_writes is not None and task.cache_key is not None:
for t in await match_cached_writes():
if t is task:
# if the task is already cached, return
return
while True:
try:
# clear any writes from previous attempts
task.writes.clear()
# run the task
if stream:
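                # exhaust the stream; individual chunks are discarded here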
async for _ in task.proc.astream(task.input, config):
pass
# if successful, end
break
else:
return await task.proc.ainvoke(task.input, config)
except ParentCommand as exc:
ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS]
cmd = exc.args[0]
if cmd.graph in (ns, task.name):
# this command is for the current graph, handle it
for w in task.writers:
w.invoke(cmd, config)
break
elif cmd.graph == Command.PARENT:
# this command is for the parent graph, assign it to the parent
parts = ns.split(NS_SEP)
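                # drop a trailing numeric (index) segment so it is not counted
                # as a namespace level when resolving the parent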
if parts[-1].isdigit():
parts.pop()
parent_ns = NS_SEP.join(parts[:-1])
exc.args = (replace(cmd, graph=parent_ns),)
# bubble up
raise
except GraphBubbleUp:
# if interrupted, end
raise
except Exception as exc:
if SUPPORTS_EXC_NOTES:
exc.add_note(f"During task with name '{task.name}' and id '{task.id}'")
if not retry_policy:
raise
# Check which retry policy applies to this exception
matching_policy = None
for policy in retry_policy:
if _should_retry_on(policy, exc):
matching_policy = policy
break
if not matching_policy:
raise
# increment attempts
attempts += 1
# check if we should give up
if attempts >= matching_policy.max_attempts:
raise
# sleep before retrying
interval = matching_policy.initial_interval
# Apply backoff factor based on attempt count
interval = min(
matching_policy.max_interval,
interval * (matching_policy.backoff_factor ** (attempts - 1)),
)
# Apply jitter if configured
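            # (adds up to 1 second of uniform random delay to desynchronize
            # concurrent retries)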
sleep_time = (
interval + random.uniform(0, 1) if matching_policy.jitter else interval
)
await asyncio.sleep(sleep_time)
# log the retry
logger.info(
f"Retrying task {task.name} after {sleep_time:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
exc_info=exc,
)
# signal subgraphs to resume (if available)
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})


def _should_retry_on(retry_policy: RetryPolicy, exc: Exception) -> bool:
"""Check if the given exception should be retried based on the retry policy."""
if isinstance(retry_policy.retry_on, Sequence):
return isinstance(exc, tuple(retry_policy.retry_on))
elif isinstance(retry_policy.retry_on, type) and issubclass(
retry_policy.retry_on, Exception
):
return isinstance(exc, retry_policy.retry_on)
elif callable(retry_policy.retry_on):
return retry_policy.retry_on(exc) # type: ignore[call-arg]
else:
raise TypeError(
"retry_on must be an Exception class, a list or tuple of Exception classes, or a callable"
)
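
# For reference, the three retry_on shapes accepted above (illustrative only):
#
#     RetryPolicy(retry_on=ValueError)                      # a single exception class
#     RetryPolicy(retry_on=(ValueError, TimeoutError))      # a sequence of classes
#     RetryPolicy(retry_on=lambda exc: "429" in str(exc))   # a predicate over the exception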