initial commit
This commit is contained in:
238
venv/Lib/site-packages/langgraph/pregel/_retry.py
Normal file
238
venv/Lib/site-packages/langgraph/pregel/_retry.py
Normal file
@@ -0,0 +1,238 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from dataclasses import replace
|
||||
from typing import Any
|
||||
|
||||
from langgraph._internal._config import patch_configurable, recast_checkpoint_ns
|
||||
from langgraph._internal._constants import (
|
||||
CONF,
|
||||
CONFIG_KEY_CHECKPOINT_NS,
|
||||
CONFIG_KEY_RESUMING,
|
||||
NS_SEP,
|
||||
)
|
||||
from langgraph.errors import GraphBubbleUp, ParentCommand
|
||||
from langgraph.types import Command, PregelExecutableTask, RetryPolicy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
SUPPORTS_EXC_NOTES = sys.version_info >= (3, 11)
|
||||
|
||||
|
||||
def _checkpoint_ns_for_parent_command(ns: str) -> str:
|
||||
"""Return the checkpoint namespace for the parent graph.
|
||||
|
||||
The checkpoint namespace is a `|`-separated path. Each segment is usually
|
||||
of the form `name:task_id` (e.g. `parent_first:<uuid>|node:<uuid>`), but the
|
||||
runtime may also insert a purely-numeric segment (e.g. `|1`) to disambiguate
|
||||
concurrent tasks (e.g. `parent_first:<uuid>|1|node:<uuid>`).
|
||||
|
||||
Numeric segments are not real path levels, so we drop them before computing
|
||||
the parent namespace.
|
||||
"""
|
||||
|
||||
parts = ns.split(NS_SEP)
|
||||
|
||||
# Drop any trailing numeric selectors for the current frame (e.g. `...|node:<id>|1`).
|
||||
while parts and parts[-1].isdigit():
|
||||
parts.pop()
|
||||
|
||||
# Drop the current frame segment itself (e.g. the `node:<id>`).
|
||||
if parts:
|
||||
parts.pop()
|
||||
|
||||
# Drop any trailing numeric selectors for the parent frame (e.g. `...|1|node:<id>`).
|
||||
while parts and parts[-1].isdigit():
|
||||
parts.pop()
|
||||
|
||||
return NS_SEP.join(parts)
|
||||
|
||||
|
||||
def run_with_retry(
|
||||
task: PregelExecutableTask,
|
||||
retry_policy: Sequence[RetryPolicy] | None,
|
||||
configurable: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Run a task with retries."""
|
||||
retry_policy = task.retry_policy or retry_policy
|
||||
attempts = 0
|
||||
config = task.config
|
||||
if configurable is not None:
|
||||
config = patch_configurable(config, configurable)
|
||||
while True:
|
||||
try:
|
||||
# clear any writes from previous attempts
|
||||
task.writes.clear()
|
||||
# run the task
|
||||
return task.proc.invoke(task.input, config)
|
||||
except ParentCommand as exc:
|
||||
ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS]
|
||||
cmd = exc.args[0]
|
||||
# strip task_ids from namespace for comparison (ns format: "node1|node2:task_id")
|
||||
if cmd.graph in (ns, recast_checkpoint_ns(ns), task.name):
|
||||
# this command is for the current graph, handle it
|
||||
for w in task.writers:
|
||||
w.invoke(cmd, config)
|
||||
break
|
||||
elif cmd.graph == Command.PARENT:
|
||||
# this command is for the parent graph, assign it to the parent.
|
||||
exc.args = (replace(cmd, graph=_checkpoint_ns_for_parent_command(ns)),)
|
||||
# bubble up
|
||||
raise
|
||||
except GraphBubbleUp:
|
||||
# if interrupted, end
|
||||
raise
|
||||
except Exception as exc:
|
||||
if SUPPORTS_EXC_NOTES:
|
||||
exc.add_note(f"During task with name '{task.name}' and id '{task.id}'")
|
||||
if not retry_policy:
|
||||
raise
|
||||
|
||||
# Check which retry policy applies to this exception
|
||||
matching_policy = None
|
||||
for policy in retry_policy:
|
||||
if _should_retry_on(policy, exc):
|
||||
matching_policy = policy
|
||||
break
|
||||
|
||||
if not matching_policy:
|
||||
raise
|
||||
|
||||
# increment attempts
|
||||
attempts += 1
|
||||
# check if we should give up
|
||||
if attempts >= matching_policy.max_attempts:
|
||||
raise
|
||||
# sleep before retrying
|
||||
interval = matching_policy.initial_interval
|
||||
# Apply backoff factor based on attempt count
|
||||
interval = min(
|
||||
matching_policy.max_interval,
|
||||
interval * (matching_policy.backoff_factor ** (attempts - 1)),
|
||||
)
|
||||
|
||||
# Apply jitter if configured
|
||||
sleep_time = (
|
||||
interval + random.uniform(0, 1) if matching_policy.jitter else interval
|
||||
)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
# log the retry
|
||||
logger.info(
|
||||
f"Retrying task {task.name} after {sleep_time:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
|
||||
exc_info=exc,
|
||||
)
|
||||
# signal subgraphs to resume (if available)
|
||||
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})
|
||||
|
||||
|
||||
async def arun_with_retry(
|
||||
task: PregelExecutableTask,
|
||||
retry_policy: Sequence[RetryPolicy] | None,
|
||||
stream: bool = False,
|
||||
match_cached_writes: Callable[[], Awaitable[Sequence[PregelExecutableTask]]]
|
||||
| None = None,
|
||||
configurable: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Run a task asynchronously with retries."""
|
||||
retry_policy = task.retry_policy or retry_policy
|
||||
attempts = 0
|
||||
config = task.config
|
||||
if configurable is not None:
|
||||
config = patch_configurable(config, configurable)
|
||||
if match_cached_writes is not None and task.cache_key is not None:
|
||||
for t in await match_cached_writes():
|
||||
if t is task:
|
||||
# if the task is already cached, return
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
# clear any writes from previous attempts
|
||||
task.writes.clear()
|
||||
# run the task
|
||||
if stream:
|
||||
async for _ in task.proc.astream(task.input, config):
|
||||
pass
|
||||
# if successful, end
|
||||
break
|
||||
else:
|
||||
return await task.proc.ainvoke(task.input, config)
|
||||
except ParentCommand as exc:
|
||||
ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS]
|
||||
cmd = exc.args[0]
|
||||
# strip task_ids from namespace for comparison (ns format: "node1|node2:task_id")
|
||||
if cmd.graph in (ns, recast_checkpoint_ns(ns), task.name):
|
||||
# this command is for the current graph, handle it
|
||||
for w in task.writers:
|
||||
w.invoke(cmd, config)
|
||||
break
|
||||
elif cmd.graph == Command.PARENT:
|
||||
# this command is for the parent graph, assign it to the parent.
|
||||
exc.args = (replace(cmd, graph=_checkpoint_ns_for_parent_command(ns)),)
|
||||
# bubble up
|
||||
raise
|
||||
except GraphBubbleUp:
|
||||
# if interrupted, end
|
||||
raise
|
||||
except Exception as exc:
|
||||
if SUPPORTS_EXC_NOTES:
|
||||
exc.add_note(f"During task with name '{task.name}' and id '{task.id}'")
|
||||
if not retry_policy:
|
||||
raise
|
||||
|
||||
# Check which retry policy applies to this exception
|
||||
matching_policy = None
|
||||
for policy in retry_policy:
|
||||
if _should_retry_on(policy, exc):
|
||||
matching_policy = policy
|
||||
break
|
||||
|
||||
if not matching_policy:
|
||||
raise
|
||||
|
||||
# increment attempts
|
||||
attempts += 1
|
||||
# check if we should give up
|
||||
if attempts >= matching_policy.max_attempts:
|
||||
raise
|
||||
# sleep before retrying
|
||||
interval = matching_policy.initial_interval
|
||||
# Apply backoff factor based on attempt count
|
||||
interval = min(
|
||||
matching_policy.max_interval,
|
||||
interval * (matching_policy.backoff_factor ** (attempts - 1)),
|
||||
)
|
||||
|
||||
# Apply jitter if configured
|
||||
sleep_time = (
|
||||
interval + random.uniform(0, 1) if matching_policy.jitter else interval
|
||||
)
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
# log the retry
|
||||
logger.info(
|
||||
f"Retrying task {task.name} after {sleep_time:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
|
||||
exc_info=exc,
|
||||
)
|
||||
# signal subgraphs to resume (if available)
|
||||
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})
|
||||
|
||||
|
||||
def _should_retry_on(retry_policy: RetryPolicy, exc: Exception) -> bool:
|
||||
"""Check if the given exception should be retried based on the retry policy."""
|
||||
if isinstance(retry_policy.retry_on, Sequence):
|
||||
return isinstance(exc, tuple(retry_policy.retry_on))
|
||||
elif isinstance(retry_policy.retry_on, type) and issubclass(
|
||||
retry_policy.retry_on, Exception
|
||||
):
|
||||
return isinstance(exc, retry_policy.retry_on)
|
||||
elif callable(retry_policy.retry_on):
|
||||
return retry_policy.retry_on(exc) # type: ignore[call-arg]
|
||||
else:
|
||||
raise TypeError(
|
||||
"retry_on must be an Exception class, a list or tuple of Exception classes, or a callable"
|
||||
)
|
||||
Reference in New Issue
Block a user