initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
from langgraph.pregel.main import NodeBuilder, Pregel
__all__ = ("Pregel", "NodeBuilder")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,269 @@
"""Utility to convert a user provided function into a Runnable with a ChannelWrite."""
from __future__ import annotations
import concurrent.futures
import functools
import inspect
import sys
import types
from collections.abc import Awaitable, Callable, Generator, Sequence
from typing import Any, Generic, TypeVar, cast
from langchain_core.runnables import Runnable
from typing_extensions import ParamSpec
from langgraph._internal._constants import CONF, CONFIG_KEY_CALL, RETURN
from langgraph._internal._runnable import (
RunnableCallable,
RunnableSeq,
is_async_callable,
run_in_executor,
)
from langgraph.config import get_config
from langgraph.pregel._write import ChannelWrite, ChannelWriteEntry
from langgraph.types import CachePolicy, RetryPolicy
##
# Utilities borrowed from cloudpickle.
# https://github.com/cloudpipe/cloudpickle/blob/6220b0ce83ffee5e47e06770a1ee38ca9e47c850/cloudpickle/cloudpickle.py#L265
def _getattribute(obj: Any, name: str) -> Any:
parent = None
for subpath in name.split("."):
if subpath == "<locals>":
raise AttributeError(f"Can't get local attribute {name!r} on {obj!r}")
try:
parent = obj
obj = getattr(obj, subpath)
except AttributeError:
raise AttributeError(f"Can't get attribute {name!r} on {obj!r}") from None
return obj, parent
def _whichmodule(obj: Any, name: str) -> str | None:
"""Find the module an object belongs to.
This function differs from ``pickle.whichmodule`` in two ways:
- it does not mangle the cases where obj's module is __main__ and obj was
not found in any module.
- Errors arising during module introspection are ignored, as those errors
are considered unwanted side effects.
"""
module_name = getattr(obj, "__module__", None)
if module_name is not None:
return module_name
# Protect the iteration by using a copy of sys.modules against dynamic
# modules that trigger imports of other modules upon calls to getattr or
# other threads importing at the same time.
for module_name, module in sys.modules.copy().items():
# Some modules such as coverage can inject non-module objects inside
# sys.modules
if (
module_name == "__main__"
or module_name == "__mp_main__"
or module is None
or not isinstance(module, types.ModuleType)
):
continue
try:
if _getattribute(module, name)[0] is obj:
return module_name
except Exception:
pass
return None
def identifier(obj: Any, name: str | None = None) -> str | None:
"""Return the module and name of an object."""
from langgraph._internal._runnable import RunnableCallable, RunnableSeq
from langgraph.pregel._read import PregelNode
if isinstance(obj, PregelNode):
obj = obj.bound
if isinstance(obj, RunnableSeq):
obj = obj.steps[0]
if isinstance(obj, RunnableCallable):
obj = obj.func
if name is None:
name = getattr(obj, "__qualname__", None)
if name is None: # pragma: no cover
# This used to be needed for Python 2.7 support but is probably not
# needed anymore. However we keep the __name__ introspection in case
# users of cloudpickle rely on this old behavior for unknown reasons.
name = getattr(obj, "__name__", None)
if name is None:
return None
module_name = getattr(obj, "__module__", None)
if module_name is None:
# In this case, obj.__module__ is None. obj is thus treated as dynamic.
return None
return f"{module_name}.{name}"
def _lookup_module_and_qualname(
obj: Any, name: str | None = None
) -> tuple[types.ModuleType, str] | None:
if name is None:
name = getattr(obj, "__qualname__", None)
if name is None: # pragma: no cover
# This used to be needed for Python 2.7 support but is probably not
# needed anymore. However we keep the __name__ introspection in case
# users of cloudpickle rely on this old behavior for unknown reasons.
name = getattr(obj, "__name__", None)
if name is None:
return None
module_name = _whichmodule(obj, name)
if module_name is None:
# In this case, obj.__module__ is None AND obj was not found in any
# imported module. obj is thus treated as dynamic.
return None
if module_name == "__main__":
return None
# Note: if module_name is in sys.modules, the corresponding module is
# assumed importable at unpickling time. See #357
module = sys.modules.get(module_name, None)
if module is None:
# The main reason why obj's module would not be imported is that this
# module has been dynamically created, using for example
# types.ModuleType. The other possibility is that module was removed
# from sys.modules after obj was created/imported. But this case is not
# supported, as the standard pickle does not support it either.
return None
try:
obj2, parent = _getattribute(module, name)
except AttributeError:
# obj was not found inside the module it points to
return None
if obj2 is not obj:
return None
return module, name
def _explode_args_trace_inputs(
sig: inspect.Signature, input: tuple[tuple[Any, ...], dict[str, Any]]
) -> dict[str, Any]:
args, kwargs = input
bound = sig.bind_partial(*args, **kwargs)
bound.apply_defaults()
arguments = dict(bound.arguments)
arguments.pop("self", None)
arguments.pop("cls", None)
for param_name, param in sig.parameters.items():
if param.kind == inspect.Parameter.VAR_KEYWORD:
# Update with the **kwargs, and remove the original entry
# This is to help flatten out keyword arguments
if param_name in arguments:
arguments.update(arguments.pop(param_name))
return arguments
def get_runnable_for_entrypoint(func: Callable[..., Any]) -> Runnable:
key = (func, False)
if key in CACHE:
return CACHE[key]
else:
if is_async_callable(func):
run = RunnableCallable(
None, func, name=func.__name__, trace=False, recurse=False
)
else:
afunc = functools.update_wrapper(
functools.partial(run_in_executor, None, func), func
)
run = RunnableCallable(
func,
afunc,
name=func.__name__,
trace=False,
recurse=False,
)
if not _lookup_module_and_qualname(func):
return run
return CACHE.setdefault(key, run)
def get_runnable_for_task(func: Callable[..., Any]) -> Runnable:
key = (func, True)
if key in CACHE:
return CACHE[key]
else:
if hasattr(func, "__name__"):
name = func.__name__
elif hasattr(func, "func"):
name = func.func.__name__
elif hasattr(func, "__class__"):
name = func.__class__.__name__
else:
name = str(func)
if is_async_callable(func):
run = RunnableCallable(
None,
func,
explode_args=True,
name=name,
trace=False,
recurse=False,
)
else:
run = RunnableCallable(
func,
functools.wraps(func)(functools.partial(run_in_executor, None, func)),
explode_args=True,
name=name,
trace=False,
recurse=False,
)
seq = RunnableSeq(
run,
ChannelWrite([ChannelWriteEntry(RETURN)]),
name=name,
trace_inputs=functools.partial(
_explode_args_trace_inputs, inspect.signature(func)
),
)
if not _lookup_module_and_qualname(func):
return seq
return CACHE.setdefault(key, seq)
CACHE: dict[tuple[Callable[..., Any], bool], Runnable] = {}
P = ParamSpec("P")
P1 = TypeVar("P1")
T = TypeVar("T")
class SyncAsyncFuture(Generic[T], concurrent.futures.Future[T]):
def __await__(self) -> Generator[T, None, T]:
yield cast(T, ...)
def call(
func: Callable[P, Awaitable[T]] | Callable[P, T],
*args: Any,
retry_policy: Sequence[RetryPolicy] | None = None,
cache_policy: CachePolicy | None = None,
**kwargs: Any,
) -> SyncAsyncFuture[T]:
config = get_config()
impl = config[CONF][CONFIG_KEY_CALL]
fut = impl(
func,
(args, kwargs),
retry_policy=retry_policy,
cache_policy=cache_policy,
callbacks=config["callbacks"],
)
return fut

View File

@@ -0,0 +1,88 @@
from __future__ import annotations
from collections.abc import Mapping
from datetime import datetime, timezone
from langgraph.checkpoint.base import Checkpoint
from langgraph.checkpoint.base.id import uuid6
from langgraph._internal._typing import MISSING
from langgraph.channels.base import BaseChannel
from langgraph.managed.base import ManagedValueMapping, ManagedValueSpec
LATEST_VERSION = 4
def empty_checkpoint() -> Checkpoint:
return Checkpoint(
v=LATEST_VERSION,
id=str(uuid6(clock_seq=-2)),
ts=datetime.now(timezone.utc).isoformat(),
channel_values={},
channel_versions={},
versions_seen={},
)
def create_checkpoint(
checkpoint: Checkpoint,
channels: Mapping[str, BaseChannel] | None,
step: int,
*,
id: str | None = None,
updated_channels: set[str] | None = None,
) -> Checkpoint:
"""Create a checkpoint for the given channels."""
ts = datetime.now(timezone.utc).isoformat()
if channels is None:
values = checkpoint["channel_values"]
else:
values = {}
for k in channels:
if k not in checkpoint["channel_versions"]:
continue
v = channels[k].checkpoint()
if v is not MISSING:
values[k] = v
return Checkpoint(
v=LATEST_VERSION,
ts=ts,
id=id or str(uuid6(clock_seq=step)),
channel_values=values,
channel_versions=checkpoint["channel_versions"],
versions_seen=checkpoint["versions_seen"],
updated_channels=None if updated_channels is None else sorted(updated_channels),
)
def channels_from_checkpoint(
specs: Mapping[str, BaseChannel | ManagedValueSpec],
checkpoint: Checkpoint,
) -> tuple[Mapping[str, BaseChannel], ManagedValueMapping]:
"""Get channels from a checkpoint."""
channel_specs: dict[str, BaseChannel] = {}
managed_specs: dict[str, ManagedValueSpec] = {}
for k, v in specs.items():
if isinstance(v, BaseChannel):
channel_specs[k] = v
else:
managed_specs[k] = v
return (
{
k: v.from_checkpoint(checkpoint["channel_values"].get(k, MISSING))
for k, v in channel_specs.items()
},
managed_specs,
)
def copy_checkpoint(checkpoint: Checkpoint) -> Checkpoint:
return Checkpoint(
v=checkpoint["v"],
ts=checkpoint["ts"],
id=checkpoint["id"],
channel_values=checkpoint["channel_values"].copy(),
channel_versions=checkpoint["channel_versions"].copy(),
versions_seen={k: v.copy() for k, v in checkpoint["versions_seen"].items()},
updated_channels=checkpoint.get("updated_channels", None),
)

View File

@@ -0,0 +1,294 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, NamedTuple, cast
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.graph import Graph, Node
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph._internal._constants import CONF, CONFIG_KEY_SEND, INPUT
from langgraph.channels.base import BaseChannel
from langgraph.channels.last_value import LastValueAfterFinish
from langgraph.constants import END, START
from langgraph.managed.base import ManagedValueSpec
from langgraph.pregel._algo import (
PregelTaskWrites,
apply_writes,
increment,
prepare_next_tasks,
)
from langgraph.pregel._checkpoint import channels_from_checkpoint, empty_checkpoint
from langgraph.pregel._io import map_input
from langgraph.pregel._read import PregelNode
from langgraph.pregel._write import ChannelWrite
from langgraph.types import All, Checkpointer
class Edge(NamedTuple):
source: str
target: str
conditional: bool
data: str | None
class TriggerEdge(NamedTuple):
source: str
conditional: bool
data: str | None
def draw_graph(
config: RunnableConfig,
*,
nodes: dict[str, PregelNode],
specs: dict[str, BaseChannel | ManagedValueSpec],
input_channels: str | Sequence[str],
interrupt_after_nodes: All | Sequence[str],
interrupt_before_nodes: All | Sequence[str],
trigger_to_nodes: Mapping[str, Sequence[str]],
checkpointer: Checkpointer,
subgraphs: dict[str, Graph],
limit: int = 250,
) -> Graph:
"""Get the graph for this Pregel instance.
Args:
config: The configuration to use for the graph.
subgraphs: The subgraphs to include in the graph.
checkpointer: The checkpointer to use for the graph.
Returns:
The graph for this Pregel instance.
"""
# (src, dest, is_conditional, label)
edges: set[Edge] = set()
step = -1
checkpoint = empty_checkpoint()
get_next_version = (
checkpointer.get_next_version
if isinstance(checkpointer, BaseCheckpointSaver)
else increment
)
channels, managed = channels_from_checkpoint(
specs,
checkpoint,
)
static_seen: set[Any] = set()
sources: dict[str, set[TriggerEdge]] = {}
step_sources: dict[str, set[TriggerEdge]] = {}
static_declared_writes: dict[str, set[TriggerEdge]] = defaultdict(set)
# remove node mappers
nodes = {
k: v.copy(update={"mapper": None}) if v.mapper is not None else v
for k, v in nodes.items()
}
# apply input writes
input_writes = list(map_input(input_channels, {}))
updated_channels = apply_writes(
checkpoint,
channels,
[
PregelTaskWrites((), INPUT, input_writes, []),
],
get_next_version,
trigger_to_nodes,
)
# prepare first tasks
tasks = prepare_next_tasks(
checkpoint,
[],
nodes,
channels,
managed,
config,
step,
-1,
for_execution=True,
store=None,
checkpointer=None,
manager=None,
trigger_to_nodes=trigger_to_nodes,
updated_channels=updated_channels,
)
start_tasks = tasks
# run the pregel loop
for step in range(step, limit):
if not tasks:
break
conditionals: dict[tuple[str, str, Any], str | None] = {}
# run task writers
for task in tasks.values():
for w in task.writers:
# apply regular writes
if isinstance(w, ChannelWrite):
empty_input = (
cast(BaseChannel, specs["__root__"]).ValueType()
if "__root__" in specs
else None
)
w.invoke(empty_input, task.config)
# apply conditional writes declared for static analysis, only once
if w not in static_seen:
static_seen.add(w)
# apply static writes
if writes := ChannelWrite.get_static_writes(w):
# END writes are not written, but become edges directly
for t in writes:
if t[0] == END:
edges.add(Edge(task.name, t[0], True, t[2]))
writes = [t for t in writes if t[0] != END]
conditionals.update(
{(task.name, t[0], t[1] or None): t[2] for t in writes}
)
# record static writes for edge creation
for t in writes:
static_declared_writes[task.name].add(
TriggerEdge(t[0], True, t[2])
)
task.config[CONF][CONFIG_KEY_SEND]([t[:2] for t in writes])
# collect sources
step_sources = {}
for task in tasks.values():
task_edges = {
TriggerEdge(
w[0],
(task.name, w[0], w[1] or None) in conditionals,
conditionals.get((task.name, w[0], w[1] or None)),
)
for w in task.writes
}
task_edges |= static_declared_writes.get(task.name, set())
step_sources[task.name] = task_edges
sources.update(step_sources)
# invert triggers
trigger_to_sources: dict[str, set[TriggerEdge]] = defaultdict(set)
for src, triggers in sources.items():
for trigger, cond, label in triggers:
trigger_to_sources[trigger].add(TriggerEdge(src, cond, label))
# apply writes
updated_channels = apply_writes(
checkpoint, channels, tasks.values(), get_next_version, trigger_to_nodes
)
# prepare next tasks
tasks = prepare_next_tasks(
checkpoint,
[],
nodes,
channels,
managed,
config,
step,
limit,
for_execution=True,
store=None,
checkpointer=None,
manager=None,
trigger_to_nodes=trigger_to_nodes,
updated_channels=updated_channels,
)
# collect deferred nodes
deferred_nodes: set[str] = set()
edges_to_deferred_nodes: set[Edge] = set()
for channel, item in channels.items():
if isinstance(item, LastValueAfterFinish):
deferred_node = channel.split(":", 2)[-1]
deferred_nodes.add(deferred_node)
# collect edges
for task in tasks.values():
added = False
for trigger in task.triggers:
for src, cond, label in sorted(trigger_to_sources[trigger]):
# record edge to be reviewed later
if task.name in deferred_nodes:
edges_to_deferred_nodes.add(Edge(src, task.name, cond, label))
edges.add(Edge(src, task.name, cond, label))
# if the edge is from this step, skip adding the implicit edges
if (trigger, cond, label) in step_sources.get(src, set()):
added = True
else:
sources[src].discard(TriggerEdge(trigger, cond, label))
# if no edges from this step, add implicit edges from all previous tasks
if not added:
for src in step_sources:
edges.add(Edge(src, task.name, True, None))
# assemble the graph
graph = Graph()
# add nodes
for name, node in nodes.items():
metadata = dict(node.metadata or {})
if name in deferred_nodes:
metadata["defer"] = True
if name in interrupt_before_nodes and name in interrupt_after_nodes:
metadata["__interrupt"] = "before,after"
elif name in interrupt_before_nodes:
metadata["__interrupt"] = "before"
elif name in interrupt_after_nodes:
metadata["__interrupt"] = "after"
graph.add_node(node.bound, name, metadata=metadata or None)
# add start node
if START not in nodes:
graph.add_node(None, START)
for task in start_tasks.values():
add_edge(graph, START, task.name)
# add discovered edges
for src, dest, is_conditional, label in sorted(edges):
add_edge(
graph,
src,
dest,
data=label if label != dest else None,
conditional=is_conditional,
)
# add end edges
termini = {d for _, d, _, _ in edges if d != END}.difference(
s for s, _, _, _ in edges
)
end_edge_exists = any(d == END for _, d, _, _ in edges)
if termini:
for src in sorted(termini):
add_edge(graph, src, END)
elif len(step_sources) == 1 and not end_edge_exists:
for src in sorted(step_sources):
add_edge(graph, src, END, conditional=True)
# replace subgraphs
for name, subgraph in subgraphs.items():
if (
len(subgraph.nodes) > 1
and name in graph.nodes
and subgraph.first_node()
and subgraph.last_node()
):
subgraph.trim_first_node()
subgraph.trim_last_node()
# replace the node with the subgraph
graph.nodes.pop(name)
first, last = graph.extend(subgraph, prefix=name)
for idx, edge in enumerate(graph.edges):
if edge.source == name:
edge = edge.copy(source=cast(Node, last).id)
if edge.target == name:
edge = edge.copy(target=cast(Node, first).id)
graph.edges[idx] = edge
return graph
def add_edge(
graph: Graph,
source: str,
target: str,
*,
data: Any | None = None,
conditional: bool = False,
) -> None:
"""Add an edge to the graph."""
for edge in graph.edges:
if edge.source == source and edge.target == target:
return
if target not in graph.nodes and target == END:
graph.add_node(None, END)
graph.add_edge(graph.nodes[source], graph.nodes[target], data, conditional)

View File

@@ -0,0 +1,223 @@
from __future__ import annotations
import asyncio
import concurrent.futures
import time
from collections.abc import Awaitable, Callable, Coroutine
from contextlib import AbstractAsyncContextManager, AbstractContextManager, ExitStack
from contextvars import copy_context
from types import TracebackType
from typing import (
Protocol,
TypeVar,
cast,
)
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import get_executor_for_config
from typing_extensions import ParamSpec
from langgraph._internal._future import CONTEXT_NOT_SUPPORTED, run_coroutine_threadsafe
from langgraph.errors import GraphBubbleUp
P = ParamSpec("P")
T = TypeVar("T")
class Submit(Protocol[P, T]):
def __call__( # type: ignore[valid-type]
self,
fn: Callable[P, T],
*args: P.args,
__name__: str | None = None,
__cancel_on_exit__: bool = False,
__reraise_on_exit__: bool = True,
__next_tick__: bool = False,
**kwargs: P.kwargs,
) -> concurrent.futures.Future[T]: ...
class BackgroundExecutor(AbstractContextManager):
"""A context manager that runs sync tasks in the background.
Uses a thread pool executor to delegate tasks to separate threads.
On exit,
- cancels any (not yet started) tasks with `__cancel_on_exit__=True`
- waits for all tasks to finish
- re-raises the first exception from tasks with `__reraise_on_exit__=True`"""
def __init__(self, config: RunnableConfig) -> None:
self.stack = ExitStack()
self.executor = self.stack.enter_context(get_executor_for_config(config))
# mapping of Future to (__cancel_on_exit__, __reraise_on_exit__) flags
self.tasks: dict[concurrent.futures.Future, tuple[bool, bool]] = {}
def submit( # type: ignore[valid-type]
self,
fn: Callable[P, T],
*args: P.args,
__name__: str | None = None, # currently not used in sync version
__cancel_on_exit__: bool = False, # for sync, can cancel only if not started
__reraise_on_exit__: bool = True,
__next_tick__: bool = False,
**kwargs: P.kwargs,
) -> concurrent.futures.Future[T]:
ctx = copy_context()
if __next_tick__:
task = cast(
concurrent.futures.Future[T],
self.executor.submit(next_tick, ctx.run, fn, *args, **kwargs), # type: ignore[arg-type]
)
else:
task = self.executor.submit(ctx.run, fn, *args, **kwargs)
self.tasks[task] = (__cancel_on_exit__, __reraise_on_exit__)
# add a callback to remove the task from the tasks dict when it's done
task.add_done_callback(self.done)
return task
def done(self, task: concurrent.futures.Future) -> None:
"""Remove the task from the tasks dict when it's done."""
try:
task.result()
except GraphBubbleUp:
# This exception is an interruption signal, not an error
# so we don't want to re-raise it on exit
self.tasks.pop(task)
except BaseException:
pass
else:
self.tasks.pop(task)
def __enter__(self) -> Submit:
return self.submit
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> bool | None:
# copy the tasks as done() callback may modify the dict
tasks = self.tasks.copy()
# cancel all tasks that should be cancelled
for task, (cancel, _) in tasks.items():
if cancel:
task.cancel()
# wait for all tasks to finish
if pending := {t for t in tasks if not t.done()}:
concurrent.futures.wait(pending)
# shutdown the executor
self.stack.__exit__(exc_type, exc_value, traceback)
# if there's already an exception being raised, don't raise another one
if exc_type is None:
# re-raise the first exception that occurred in a task
for task, (_, reraise) in tasks.items():
if not reraise:
continue
try:
task.result()
except concurrent.futures.CancelledError:
pass
class AsyncBackgroundExecutor(AbstractAsyncContextManager):
"""A context manager that runs async tasks in the background.
Uses the current event loop to delegate tasks to asyncio tasks.
On exit,
- cancels any tasks with `__cancel_on_exit__=True`
- waits for all tasks to finish
- re-raises the first exception from tasks with `__reraise_on_exit__=True`
ignoring CancelledError"""
def __init__(self, config: RunnableConfig) -> None:
self.tasks: dict[asyncio.Future, tuple[bool, bool]] = {}
self.sentinel = object()
self.loop = asyncio.get_running_loop()
if max_concurrency := config.get("max_concurrency"):
self.semaphore: asyncio.Semaphore | None = asyncio.Semaphore(
max_concurrency
)
else:
self.semaphore = None
def submit( # type: ignore[valid-type]
self,
fn: Callable[P, Awaitable[T]],
*args: P.args,
__name__: str | None = None,
__cancel_on_exit__: bool = False,
__reraise_on_exit__: bool = True,
__next_tick__: bool = False, # noop in async (always True)
**kwargs: P.kwargs,
) -> asyncio.Future[T]:
coro = cast(Coroutine[None, None, T], fn(*args, **kwargs))
if self.semaphore:
coro = gated(self.semaphore, coro)
if CONTEXT_NOT_SUPPORTED:
task = run_coroutine_threadsafe(
coro, self.loop, name=__name__, lazy=__next_tick__
)
else:
task = run_coroutine_threadsafe(
coro,
self.loop,
name=__name__,
context=copy_context(),
lazy=__next_tick__,
)
self.tasks[task] = (__cancel_on_exit__, __reraise_on_exit__)
task.add_done_callback(self.done)
return task
def done(self, task: asyncio.Future) -> None:
try:
if exc := task.exception():
# This exception is an interruption signal, not an error
# so we don't want to re-raise it on exit
if isinstance(exc, GraphBubbleUp):
self.tasks.pop(task)
else:
self.tasks.pop(task)
except asyncio.CancelledError:
self.tasks.pop(task)
async def __aenter__(self) -> Submit:
return self.submit
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
# copy the tasks as done() callback may modify the dict
tasks = self.tasks.copy()
# cancel all tasks that should be cancelled
for task, (cancel, _) in tasks.items():
if cancel:
task.cancel(self.sentinel)
# wait for all tasks to finish
if tasks:
await asyncio.wait(tasks)
# if there's already an exception being raised, don't raise another one
if exc_type is None:
# re-raise the first exception that occurred in a task
for task, (_, reraise) in tasks.items():
if not reraise:
continue
try:
if exc := task.exception():
raise exc
except asyncio.CancelledError:
pass
async def gated(semaphore: asyncio.Semaphore, coro: Coroutine[None, None, T]) -> T:
"""A coroutine that waits for a semaphore before running another coroutine."""
async with semaphore:
return await coro
def next_tick(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
"""A function that yields control to other threads before running another function."""
time.sleep(0)
return fn(*args, **kwargs)

View File

@@ -0,0 +1,174 @@
from __future__ import annotations
from collections import Counter
from collections.abc import Iterator, Mapping, Sequence
from typing import Any, Literal
from langgraph._internal._constants import (
ERROR,
INTERRUPT,
NULL_TASK_ID,
RESUME,
RETURN,
TASKS,
)
from langgraph._internal._typing import EMPTY_SEQ, MISSING
from langgraph.channels.base import BaseChannel, EmptyChannelError
from langgraph.constants import START, TAG_HIDDEN
from langgraph.errors import InvalidUpdateError
from langgraph.pregel._log import logger
from langgraph.types import Command, PregelExecutableTask, Send
def read_channel(
channels: Mapping[str, BaseChannel],
chan: str,
*,
catch: bool = True,
) -> Any:
try:
return channels[chan].get()
except EmptyChannelError:
if catch:
return None
else:
raise
def read_channels(
channels: Mapping[str, BaseChannel],
select: Sequence[str] | str,
*,
skip_empty: bool = True,
) -> dict[str, Any] | Any:
if isinstance(select, str):
return read_channel(channels, select)
else:
values: dict[str, Any] = {}
for k in select:
try:
values[k] = read_channel(channels, k, catch=not skip_empty)
except EmptyChannelError:
pass
return values
def map_command(cmd: Command) -> Iterator[tuple[str, str, Any]]:
"""Map input chunk to a sequence of pending writes in the form (channel, value)."""
if cmd.graph == Command.PARENT:
raise InvalidUpdateError("There is no parent graph")
if cmd.goto:
if isinstance(cmd.goto, (tuple, list)):
sends = cmd.goto
else:
sends = [cmd.goto]
for send in sends:
if isinstance(send, Send):
yield (NULL_TASK_ID, TASKS, send)
elif isinstance(send, str):
yield (NULL_TASK_ID, f"branch:to:{send}", START)
else:
raise TypeError(
f"In Command.goto, expected Send/str, got {type(send).__name__}"
)
if cmd.resume is not None:
yield (NULL_TASK_ID, RESUME, cmd.resume)
if cmd.update:
for k, v in cmd._update_as_tuples():
yield (NULL_TASK_ID, k, v)
def map_input(
input_channels: str | Sequence[str],
chunk: dict[str, Any] | Any | None,
) -> Iterator[tuple[str, Any]]:
"""Map input chunk to a sequence of pending writes in the form (channel, value)."""
if chunk is None:
return
elif isinstance(input_channels, str):
yield (input_channels, chunk)
else:
if not isinstance(chunk, dict):
raise TypeError(f"Expected chunk to be a dict, got {type(chunk).__name__}")
for k in chunk:
if k in input_channels:
yield (k, chunk[k])
else:
logger.warning(f"Input channel {k} not found in {input_channels}")
def map_output_values(
output_channels: str | Sequence[str],
pending_writes: Literal[True] | Sequence[tuple[str, Any]],
channels: Mapping[str, BaseChannel],
) -> Iterator[dict[str, Any] | Any]:
"""Map pending writes (a sequence of tuples (channel, value)) to output chunk."""
if isinstance(output_channels, str):
if pending_writes is True or any(
chan == output_channels for chan, _ in pending_writes
):
yield read_channel(channels, output_channels)
else:
if pending_writes is True or {
c for c, _ in pending_writes if c in output_channels
}:
yield read_channels(channels, output_channels)
def map_output_updates(
output_channels: str | Sequence[str],
tasks: list[tuple[PregelExecutableTask, Sequence[tuple[str, Any]]]],
cached: bool = False,
) -> Iterator[dict[str, Any | dict[str, Any]]]:
"""Map pending writes (a sequence of tuples (channel, value)) to output chunk."""
output_tasks = [
(t, ww)
for t, ww in tasks
if (not t.config or TAG_HIDDEN not in t.config.get("tags", EMPTY_SEQ))
and ww[0][0] != ERROR
and ww[0][0] != INTERRUPT
]
if not output_tasks:
return
updated: list[tuple[str, Any]] = []
for task, writes in output_tasks:
rtn = next((value for chan, value in writes if chan == RETURN), MISSING)
if rtn is not MISSING:
updated.append((task.name, rtn))
elif isinstance(output_channels, str):
updated.extend(
(task.name, value) for chan, value in writes if chan == output_channels
)
elif any(chan in output_channels for chan, _ in writes):
counts = Counter(chan for chan, _ in writes)
if any(counts[chan] > 1 for chan in output_channels):
updated.extend(
(
task.name,
{chan: value},
)
for chan, value in writes
if chan in output_channels
)
else:
updated.append(
(
task.name,
{
chan: value
for chan, value in writes
if chan in output_channels
},
)
)
grouped: dict[str, Any] = {t.name: [] for t, _ in output_tasks}
for node, value in updated:
grouped[node].append(value)
for node, value in grouped.items():
if len(value) == 0:
grouped[node] = None
if len(value) == 1:
grouped[node] = value[0]
if cached:
grouped["__metadata__"] = {"cached": cached}
yield grouped

View File

@@ -0,0 +1,3 @@
import logging
logger = logging.getLogger("langgraph")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,250 @@
from __future__ import annotations
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from dataclasses import fields, is_dataclass
from typing import (
Any,
TypeVar,
cast,
)
from uuid import UUID, uuid4
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
from pydantic import BaseModel
from langgraph._internal._constants import NS_SEP
from langgraph.constants import TAG_HIDDEN, TAG_NOSTREAM
from langgraph.pregel.protocol import StreamChunk
from langgraph.types import Command
try:
from langchain_core.tracers._streaming import _StreamingCallbackHandler
except ImportError:
_StreamingCallbackHandler = object # type: ignore
T = TypeVar("T")
Meta = tuple[tuple[str, ...], dict[str, Any]]
def _state_values(obj: Any) -> Sequence[Any]:
"""Extract top-level field values from a state object (dict, BaseModel, or dataclass)."""
if isinstance(obj, dict):
return list(obj.values())
elif isinstance(obj, BaseModel):
return [getattr(obj, k) for k in type(obj).model_fields]
elif is_dataclass(obj) and not isinstance(obj, type):
return [getattr(obj, f.name) for f in fields(obj)]
return ()
class StreamMessagesHandler(BaseCallbackHandler, _StreamingCallbackHandler):
"""A callback handler that implements stream_mode=messages.
Collects messages from:
(1) chat model stream events; and
(2) node outputs.
"""
run_inline = True
"""We want this callback to run in the main thread to avoid order/locking issues."""
def __init__(
self,
stream: Callable[[StreamChunk], None],
subgraphs: bool,
*,
parent_ns: tuple[str, ...] | None = None,
) -> None:
"""Configure the handler to stream messages from LLMs and nodes.
Args:
stream: A callable that takes a StreamChunk and emits it.
subgraphs: Whether to emit messages from subgraphs.
parent_ns: The namespace where the handler was created.
We keep track of this namespace to allow calls to subgraphs that
were explicitly requested as a stream with `messages` mode
configured.
Example:
parent_ns is used to handle scenarios where the subgraph is explicitly
streamed with `stream_mode="messages"`.
```python
def parent_graph_node():
# This node is in the parent graph.
async for event in some_subgraph(..., stream_mode="messages"):
do something with event # <-- these events will be emitted
return ...
parent_graph.invoke(subgraphs=False)
```
"""
self.stream = stream
self.subgraphs = subgraphs
self.metadata: dict[UUID, Meta] = {}
self.seen: set[int | str] = set()
self.parent_ns = parent_ns
def _emit(self, meta: Meta, message: BaseMessage, *, dedupe: bool = False) -> None:
if dedupe and message.id in self.seen:
return
else:
if message.id is None:
message.id = str(uuid4())
self.seen.add(message.id)
self.stream((meta[0], "messages", (message, meta[1])))
def _find_and_emit_messages(self, meta: Meta, response: Any) -> None:
if isinstance(response, BaseMessage):
self._emit(meta, response, dedupe=True)
elif isinstance(response, Sequence):
for value in response:
if isinstance(value, BaseMessage):
self._emit(meta, value, dedupe=True)
else:
for value in _state_values(response):
if isinstance(value, BaseMessage):
self._emit(meta, value, dedupe=True)
elif isinstance(value, Sequence):
for item in value:
if isinstance(item, BaseMessage):
self._emit(meta, item, dedupe=True)
def tap_output_aiter(
self, run_id: UUID, output: AsyncIterator[T]
) -> AsyncIterator[T]:
return output
def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
return output
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
metadata: dict[str, Any] | None = None,
**kwargs: Any,
) -> Any:
if metadata and (not tags or (TAG_NOSTREAM not in tags)):
ns = tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP))[
:-1
]
if not self.subgraphs and len(ns) > 0 and ns != self.parent_ns:
return
if tags:
if filtered_tags := [t for t in tags if not t.startswith("seq:step")]:
metadata["tags"] = filtered_tags
self.metadata[run_id] = (ns, metadata)
def on_llm_new_token(
self,
token: str,
*,
chunk: ChatGenerationChunk | None = None,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
**kwargs: Any,
) -> Any:
if not isinstance(chunk, ChatGenerationChunk):
return
if meta := self.metadata.get(run_id):
self._emit(meta, chunk.message)
def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any,
) -> Any:
if meta := self.metadata.get(run_id):
if response.generations and response.generations[0]:
gen = response.generations[0][0]
if isinstance(gen, ChatGeneration):
self._emit(meta, gen.message, dedupe=True)
self.metadata.pop(run_id, None)
def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any,
) -> Any:
self.metadata.pop(run_id, None)
def on_chain_start(
self,
serialized: dict[str, Any],
inputs: dict[str, Any],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
metadata: dict[str, Any] | None = None,
**kwargs: Any,
) -> Any:
if (
metadata
and kwargs.get("name") == metadata.get("langgraph_node")
and (not tags or TAG_HIDDEN not in tags)
):
ns = tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP))[
:-1
]
if not self.subgraphs and len(ns) > 0:
return
self.metadata[run_id] = (ns, metadata)
for value in _state_values(inputs):
if isinstance(value, BaseMessage):
if value.id is not None:
self.seen.add(value.id)
elif isinstance(value, Sequence) and not isinstance(value, str):
for item in value:
if isinstance(item, BaseMessage):
if item.id is not None:
self.seen.add(item.id)
def on_chain_end(
self,
response: Any,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any,
) -> Any:
if meta := self.metadata.pop(run_id, None):
# Handle Command node updates
if isinstance(response, Command):
self._find_and_emit_messages(meta, response.update)
# Handle list of Command updates
elif isinstance(response, Sequence) and any(
isinstance(value, Command) for value in response
):
for value in response:
if isinstance(value, Command):
self._find_and_emit_messages(meta, value.update)
else:
self._find_and_emit_messages(meta, value)
# Handle basic updates / streaming
else:
self._find_and_emit_messages(meta, response)
def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any,
) -> Any:
self.metadata.pop(run_id, None)

View File

@@ -0,0 +1,277 @@
from __future__ import annotations
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
from functools import cached_property
from typing import (
Any,
)
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph._internal._config import merge_configs
from langgraph._internal._constants import CONF, CONFIG_KEY_READ
from langgraph._internal._runnable import RunnableCallable, RunnableSeq
from langgraph.pregel._utils import find_subgraph_pregel
from langgraph.pregel._write import ChannelWrite
from langgraph.pregel.protocol import PregelProtocol
from langgraph.types import CachePolicy, RetryPolicy
READ_TYPE = Callable[[str | Sequence[str], bool], Any | dict[str, Any]]
INPUT_CACHE_KEY_TYPE = tuple[Callable[..., Any], tuple[str, ...]]
class ChannelRead(RunnableCallable):
"""Implements the logic for reading state from CONFIG_KEY_READ.
Usable both as a runnable as well as a static method to call imperatively."""
channel: str | list[str]
fresh: bool = False
mapper: Callable[[Any], Any] | None = None
def __init__(
self,
channel: str | list[str],
*,
fresh: bool = False,
mapper: Callable[[Any], Any] | None = None,
tags: list[str] | None = None,
) -> None:
super().__init__(
func=self._read,
afunc=self._aread,
tags=tags,
name=None,
trace=False,
)
self.fresh = fresh
self.mapper = mapper
self.channel = channel
def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str:
if name:
pass
elif isinstance(self.channel, str):
name = f"ChannelRead<{self.channel}>"
else:
name = f"ChannelRead<{','.join(self.channel)}>"
return super().get_name(suffix, name=name)
def _read(self, _: Any, config: RunnableConfig) -> Any:
return self.do_read(
config, select=self.channel, fresh=self.fresh, mapper=self.mapper
)
async def _aread(self, _: Any, config: RunnableConfig) -> Any:
return self.do_read(
config, select=self.channel, fresh=self.fresh, mapper=self.mapper
)
@staticmethod
def do_read(
config: RunnableConfig,
*,
select: str | list[str],
fresh: bool = False,
mapper: Callable[[Any], Any] | None = None,
) -> Any:
try:
read: READ_TYPE = config[CONF][CONFIG_KEY_READ]
except KeyError:
raise RuntimeError(
"Not configured with a read function"
"Make sure to call in the context of a Pregel process"
)
if mapper:
return mapper(read(select, fresh))
else:
return read(select, fresh)
DEFAULT_BOUND = RunnableCallable(lambda input: input)
class PregelNode:
"""A node in a Pregel graph. This won't be invoked as a runnable by the graph
itself, but instead acts as a container for the components necessary to make
a PregelExecutableTask for a node."""
channels: str | list[str]
"""The channels that will be passed as input to `bound`.
If a str, the node will be invoked with its value if it isn't empty.
If a list, the node will be invoked with a dict of those channels' values."""
triggers: list[str]
"""If any of these channels is written to, this node will be triggered in
the next step."""
mapper: Callable[[Any], Any] | None
"""A function to transform the input before passing it to `bound`."""
writers: list[Runnable]
"""A list of writers that will be executed after `bound`, responsible for
taking the output of `bound` and writing it to the appropriate channels."""
bound: Runnable[Any, Any]
"""The main logic of the node. This will be invoked with the input from
`channels`."""
retry_policy: Sequence[RetryPolicy] | None
"""The retry policies to use when invoking the node."""
cache_policy: CachePolicy | None
"""The cache policy to use when invoking the node."""
tags: Sequence[str] | None
"""Tags to attach to the node for tracing."""
metadata: Mapping[str, Any] | None
"""Metadata to attach to the node for tracing."""
subgraphs: Sequence[PregelProtocol]
"""Subgraphs used by the node."""
def __init__(
self,
*,
channels: str | list[str],
triggers: Sequence[str],
mapper: Callable[[Any], Any] | None = None,
writers: list[Runnable] | None = None,
tags: list[str] | None = None,
metadata: Mapping[str, Any] | None = None,
bound: Runnable[Any, Any] | None = None,
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
cache_policy: CachePolicy | None = None,
subgraphs: Sequence[PregelProtocol] | None = None,
) -> None:
self.channels = channels
self.triggers = list(triggers)
self.mapper = mapper
self.writers = writers or []
self.bound = bound if bound is not None else DEFAULT_BOUND
self.cache_policy = cache_policy
if isinstance(retry_policy, RetryPolicy):
self.retry_policy = (retry_policy,)
else:
self.retry_policy = retry_policy
self.tags = tags
self.metadata = metadata
if subgraphs is not None:
self.subgraphs = subgraphs
elif self.bound is not DEFAULT_BOUND:
try:
subgraph = find_subgraph_pregel(self.bound)
except Exception:
subgraph = None
if subgraph:
self.subgraphs = [subgraph]
else:
self.subgraphs = []
else:
self.subgraphs = []
def copy(self, update: dict[str, Any]) -> PregelNode:
attrs = {**self.__dict__, **update}
# Drop the cached properties
attrs.pop("flat_writers", None)
attrs.pop("node", None)
attrs.pop("input_cache_key", None)
return PregelNode(**attrs)
@cached_property
def flat_writers(self) -> list[Runnable]:
"""Get writers with optimizations applied. Dedupes consecutive ChannelWrites."""
writers = self.writers.copy()
while (
len(writers) > 1
and isinstance(writers[-1], ChannelWrite)
and isinstance(writers[-2], ChannelWrite)
):
# we can combine writes if they are consecutive
# careful to not modify the original writers list or ChannelWrite
writers[-2] = ChannelWrite(
writes=writers[-2].writes + writers[-1].writes,
)
writers.pop()
return writers
@cached_property
def node(self) -> Runnable[Any, Any] | None:
"""Get a runnable that combines `bound` and `writers`."""
writers = self.flat_writers
if self.bound is DEFAULT_BOUND and not writers:
return None
elif self.bound is DEFAULT_BOUND and len(writers) == 1:
return writers[0]
elif self.bound is DEFAULT_BOUND:
return RunnableSeq(*writers)
elif writers:
return RunnableSeq(self.bound, *writers)
else:
return self.bound
@cached_property
def input_cache_key(self) -> INPUT_CACHE_KEY_TYPE:
"""Get a cache key for the input to the node.
This is used to avoid calculating the same input multiple times."""
return (
self.mapper,
tuple(self.channels)
if isinstance(self.channels, list)
else (self.channels,),
)
def invoke(
self,
input: Any,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Any:
self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
return self.bound.invoke(
input,
merge_configs(self_config, config),
**kwargs,
)
async def ainvoke(
self,
input: Any,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Any:
self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
return await self.bound.ainvoke(
input,
merge_configs(self_config, config),
**kwargs,
)
def stream(
self,
input: Any,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Iterator[Any]:
self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
yield from self.bound.stream(
input,
merge_configs(self_config, config),
**kwargs,
)
async def astream(
self,
input: Any,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> AsyncIterator[Any]:
self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
async for item in self.bound.astream(
input,
merge_configs(self_config, config),
**kwargs,
):
yield item

View File

@@ -0,0 +1,238 @@
from __future__ import annotations
import asyncio
import logging
import random
import sys
import time
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import replace
from typing import Any
from langgraph._internal._config import patch_configurable, recast_checkpoint_ns
from langgraph._internal._constants import (
CONF,
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_RESUMING,
NS_SEP,
)
from langgraph.errors import GraphBubbleUp, ParentCommand
from langgraph.types import Command, PregelExecutableTask, RetryPolicy
logger = logging.getLogger(__name__)
SUPPORTS_EXC_NOTES = sys.version_info >= (3, 11)
def _checkpoint_ns_for_parent_command(ns: str) -> str:
"""Return the checkpoint namespace for the parent graph.
The checkpoint namespace is a `|`-separated path. Each segment is usually
of the form `name:task_id` (e.g. `parent_first:<uuid>|node:<uuid>`), but the
runtime may also insert a purely-numeric segment (e.g. `|1`) to disambiguate
concurrent tasks (e.g. `parent_first:<uuid>|1|node:<uuid>`).
Numeric segments are not real path levels, so we drop them before computing
the parent namespace.
"""
parts = ns.split(NS_SEP)
# Drop any trailing numeric selectors for the current frame (e.g. `...|node:<id>|1`).
while parts and parts[-1].isdigit():
parts.pop()
# Drop the current frame segment itself (e.g. the `node:<id>`).
if parts:
parts.pop()
# Drop any trailing numeric selectors for the parent frame (e.g. `...|1|node:<id>`).
while parts and parts[-1].isdigit():
parts.pop()
return NS_SEP.join(parts)
def run_with_retry(
task: PregelExecutableTask,
retry_policy: Sequence[RetryPolicy] | None,
configurable: dict[str, Any] | None = None,
) -> None:
"""Run a task with retries."""
retry_policy = task.retry_policy or retry_policy
attempts = 0
config = task.config
if configurable is not None:
config = patch_configurable(config, configurable)
while True:
try:
# clear any writes from previous attempts
task.writes.clear()
# run the task
return task.proc.invoke(task.input, config)
except ParentCommand as exc:
ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS]
cmd = exc.args[0]
# strip task_ids from namespace for comparison (ns format: "node1|node2:task_id")
if cmd.graph in (ns, recast_checkpoint_ns(ns), task.name):
# this command is for the current graph, handle it
for w in task.writers:
w.invoke(cmd, config)
break
elif cmd.graph == Command.PARENT:
# this command is for the parent graph, assign it to the parent.
exc.args = (replace(cmd, graph=_checkpoint_ns_for_parent_command(ns)),)
# bubble up
raise
except GraphBubbleUp:
# if interrupted, end
raise
except Exception as exc:
if SUPPORTS_EXC_NOTES:
exc.add_note(f"During task with name '{task.name}' and id '{task.id}'")
if not retry_policy:
raise
# Check which retry policy applies to this exception
matching_policy = None
for policy in retry_policy:
if _should_retry_on(policy, exc):
matching_policy = policy
break
if not matching_policy:
raise
# increment attempts
attempts += 1
# check if we should give up
if attempts >= matching_policy.max_attempts:
raise
# sleep before retrying
interval = matching_policy.initial_interval
# Apply backoff factor based on attempt count
interval = min(
matching_policy.max_interval,
interval * (matching_policy.backoff_factor ** (attempts - 1)),
)
# Apply jitter if configured
sleep_time = (
interval + random.uniform(0, 1) if matching_policy.jitter else interval
)
time.sleep(sleep_time)
# log the retry
logger.info(
f"Retrying task {task.name} after {sleep_time:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
exc_info=exc,
)
# signal subgraphs to resume (if available)
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})
async def arun_with_retry(
task: PregelExecutableTask,
retry_policy: Sequence[RetryPolicy] | None,
stream: bool = False,
match_cached_writes: Callable[[], Awaitable[Sequence[PregelExecutableTask]]]
| None = None,
configurable: dict[str, Any] | None = None,
) -> None:
"""Run a task asynchronously with retries."""
retry_policy = task.retry_policy or retry_policy
attempts = 0
config = task.config
if configurable is not None:
config = patch_configurable(config, configurable)
if match_cached_writes is not None and task.cache_key is not None:
for t in await match_cached_writes():
if t is task:
# if the task is already cached, return
return
while True:
try:
# clear any writes from previous attempts
task.writes.clear()
# run the task
if stream:
async for _ in task.proc.astream(task.input, config):
pass
# if successful, end
break
else:
return await task.proc.ainvoke(task.input, config)
except ParentCommand as exc:
ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS]
cmd = exc.args[0]
# strip task_ids from namespace for comparison (ns format: "node1|node2:task_id")
if cmd.graph in (ns, recast_checkpoint_ns(ns), task.name):
# this command is for the current graph, handle it
for w in task.writers:
w.invoke(cmd, config)
break
elif cmd.graph == Command.PARENT:
# this command is for the parent graph, assign it to the parent.
exc.args = (replace(cmd, graph=_checkpoint_ns_for_parent_command(ns)),)
# bubble up
raise
except GraphBubbleUp:
# if interrupted, end
raise
except Exception as exc:
if SUPPORTS_EXC_NOTES:
exc.add_note(f"During task with name '{task.name}' and id '{task.id}'")
if not retry_policy:
raise
# Check which retry policy applies to this exception
matching_policy = None
for policy in retry_policy:
if _should_retry_on(policy, exc):
matching_policy = policy
break
if not matching_policy:
raise
# increment attempts
attempts += 1
# check if we should give up
if attempts >= matching_policy.max_attempts:
raise
# sleep before retrying
interval = matching_policy.initial_interval
# Apply backoff factor based on attempt count
interval = min(
matching_policy.max_interval,
interval * (matching_policy.backoff_factor ** (attempts - 1)),
)
# Apply jitter if configured
sleep_time = (
interval + random.uniform(0, 1) if matching_policy.jitter else interval
)
await asyncio.sleep(sleep_time)
# log the retry
logger.info(
f"Retrying task {task.name} after {sleep_time:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
exc_info=exc,
)
# signal subgraphs to resume (if available)
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})
def _should_retry_on(retry_policy: RetryPolicy, exc: Exception) -> bool:
"""Check if the given exception should be retried based on the retry policy."""
if isinstance(retry_policy.retry_on, Sequence):
return isinstance(exc, tuple(retry_policy.retry_on))
elif isinstance(retry_policy.retry_on, type) and issubclass(
retry_policy.retry_on, Exception
):
return isinstance(exc, retry_policy.retry_on)
elif callable(retry_policy.retry_on):
return retry_policy.retry_on(exc) # type: ignore[call-arg]
else:
raise TypeError(
"retry_on must be an Exception class, a list or tuple of Exception classes, or a callable"
)

View File

@@ -0,0 +1,768 @@
from __future__ import annotations
import asyncio
import concurrent.futures
import inspect
import threading
import time
import weakref
from collections.abc import (
AsyncIterator,
Awaitable,
Callable,
Iterable,
Iterator,
Sequence,
)
from functools import partial
from typing import (
Any,
Generic,
TypeVar,
cast,
)
from langchain_core.callbacks import Callbacks
from langgraph._internal._constants import (
CONF,
CONFIG_KEY_CALL,
CONFIG_KEY_SCRATCHPAD,
ERROR,
INTERRUPT,
NO_WRITES,
RESUME,
RETURN,
)
from langgraph._internal._future import chain_future, run_coroutine_threadsafe
from langgraph._internal._scratchpad import PregelScratchpad
from langgraph._internal._typing import MISSING
from langgraph.constants import TAG_HIDDEN
from langgraph.errors import GraphBubbleUp, GraphInterrupt
from langgraph.pregel._algo import Call
from langgraph.pregel._executor import Submit
from langgraph.pregel._retry import arun_with_retry, run_with_retry
from langgraph.types import (
CachePolicy,
PregelExecutableTask,
RetryPolicy,
)
F = TypeVar("F", concurrent.futures.Future, asyncio.Future)
E = TypeVar("E", threading.Event, asyncio.Event)
# List of filenames to exclude from exception traceback
# Note: Frames will be removed if they are the last frame in traceback, recursively
EXCLUDED_FRAME_FNAMES = (
"langgraph/pregel/retry.py",
"langgraph/pregel/runner.py",
"langgraph/pregel/executor.py",
"langgraph/utils/runnable.py",
"langchain_core/runnables/config.py",
"concurrent/futures/thread.py",
"concurrent/futures/_base.py",
)
SKIP_RERAISE_SET: weakref.WeakSet[concurrent.futures.Future | asyncio.Future] = (
weakref.WeakSet()
)
class FuturesDict(Generic[F, E], dict[F, PregelExecutableTask | None]):
event: E
callback: weakref.ref[Callable[[PregelExecutableTask, BaseException | None], None]]
counter: int
done: set[F]
lock: threading.Lock
def __init__(
self,
event: E,
callback: weakref.ref[
Callable[[PregelExecutableTask, BaseException | None], None]
],
future_type: type[F],
# used for generic typing, newer py supports FutureDict[...](...)
) -> None:
super().__init__()
self.lock = threading.Lock()
self.event = event
self.callback = callback
self.counter = 0
self.done: set[F] = set()
def __setitem__(
self,
key: F,
value: PregelExecutableTask | None,
) -> None:
super().__setitem__(key, value) # type: ignore[index]
if value is not None:
with self.lock:
self.event.clear()
self.counter += 1
key.add_done_callback(partial(self.on_done, value))
def on_done(
self,
task: PregelExecutableTask,
fut: F,
) -> None:
try:
if cb := self.callback():
cb(task, _exception(fut))
finally:
with self.lock:
self.done.add(fut)
self.counter -= 1
if self.counter == 0 or _should_stop_others(self.done):
self.event.set()
class PregelRunner:
"""Responsible for executing a set of Pregel tasks concurrently, committing
their writes, yielding control to caller when there is output to emit, and
interrupting other tasks if appropriate."""
def __init__(
self,
*,
submit: weakref.ref[Submit],
put_writes: weakref.ref[Callable[[str, Sequence[tuple[str, Any]]], None]],
use_astream: bool = False,
node_finished: Callable[[str], None] | None = None,
) -> None:
self.submit = submit
self.put_writes = put_writes
self.use_astream = use_astream
self.node_finished = node_finished
def tick(
self,
tasks: Iterable[PregelExecutableTask],
*,
reraise: bool = True,
timeout: float | None = None,
retry_policy: Sequence[RetryPolicy] | None = None,
get_waiter: Callable[[], concurrent.futures.Future[None]] | None = None,
schedule_task: Callable[
[PregelExecutableTask, int, Call | None],
PregelExecutableTask | None,
],
) -> Iterator[None]:
tasks = tuple(tasks)
futures = FuturesDict(
callback=weakref.WeakMethod(self.commit),
event=threading.Event(),
future_type=concurrent.futures.Future,
)
# give control back to the caller
yield
# fast path if single task with no timeout and no waiter
if len(tasks) == 0:
return
elif len(tasks) == 1 and timeout is None and get_waiter is None:
t = tasks[0]
try:
run_with_retry(
t,
retry_policy,
configurable={
CONFIG_KEY_CALL: partial(
_call,
weakref.ref(t),
retry_policy=retry_policy,
futures=weakref.ref(futures),
schedule_task=schedule_task,
submit=self.submit,
),
},
)
self.commit(t, None)
except Exception as exc:
self.commit(t, exc)
if reraise and futures:
# will be re-raised after futures are done
fut: concurrent.futures.Future = concurrent.futures.Future()
fut.set_exception(exc)
futures.done.add(fut)
elif reraise:
if tb := exc.__traceback__:
while tb.tb_next is not None and any(
tb.tb_frame.f_code.co_filename.endswith(name)
for name in EXCLUDED_FRAME_FNAMES
):
tb = tb.tb_next
exc.__traceback__ = tb
raise
if not futures: # maybe `t` scheduled another task
return
else:
tasks = () # don't reschedule this task
# add waiter task if requested
if get_waiter is not None:
futures[get_waiter()] = None
# schedule tasks
for t in tasks:
fut = self.submit()( # type: ignore[misc]
run_with_retry,
t,
retry_policy,
configurable={
CONFIG_KEY_CALL: partial(
_call,
weakref.ref(t),
retry_policy=retry_policy,
futures=weakref.ref(futures),
schedule_task=schedule_task,
submit=self.submit,
),
},
__reraise_on_exit__=reraise,
)
futures[fut] = t
# execute tasks, and wait for one to fail or all to finish.
# each task is independent from all other concurrent tasks
# yield updates/debug output as each task finishes
end_time = timeout + time.monotonic() if timeout else None
while len(futures) > (1 if get_waiter is not None else 0):
done, inflight = concurrent.futures.wait(
futures,
return_when=concurrent.futures.FIRST_COMPLETED,
timeout=(max(0, end_time - time.monotonic()) if end_time else None),
)
if not done:
break # timed out
for fut in done:
task = futures.pop(fut)
if task is None:
# waiter task finished, schedule another
if inflight and get_waiter is not None:
futures[get_waiter()] = None
else:
# remove references to loop vars
del fut, task
# maybe stop other tasks
if _should_stop_others(done):
break
# give control back to the caller
yield
# wait for done callbacks
futures.event.wait(
timeout=(max(0, end_time - time.monotonic()) if end_time else None)
)
# give control back to the caller
yield
# panic on failure or timeout
try:
_panic_or_proceed(
futures.done.union(f for f, t in futures.items() if t is not None),
panic=reraise,
)
except Exception as exc:
if tb := exc.__traceback__:
while tb.tb_next is not None and any(
tb.tb_frame.f_code.co_filename.endswith(name)
for name in EXCLUDED_FRAME_FNAMES
):
tb = tb.tb_next
exc.__traceback__ = tb
raise
async def atick(
self,
tasks: Iterable[PregelExecutableTask],
*,
reraise: bool = True,
timeout: float | None = None,
retry_policy: Sequence[RetryPolicy] | None = None,
get_waiter: Callable[[], asyncio.Future[None]] | None = None,
schedule_task: Callable[
[PregelExecutableTask, int, Call | None],
Awaitable[PregelExecutableTask | None],
],
) -> AsyncIterator[None]:
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
tasks = tuple(tasks)
futures = FuturesDict(
callback=weakref.WeakMethod(self.commit),
event=asyncio.Event(),
future_type=asyncio.Future,
)
# give control back to the caller
yield
# fast path if single task with no waiter and no timeout
if len(tasks) == 0:
return
elif len(tasks) == 1 and get_waiter is None and timeout is None:
t = tasks[0]
try:
await arun_with_retry(
t,
retry_policy,
stream=self.use_astream,
configurable={
CONFIG_KEY_CALL: partial(
_acall,
weakref.ref(t),
stream=self.use_astream,
retry_policy=retry_policy,
futures=weakref.ref(futures),
schedule_task=schedule_task,
submit=self.submit,
loop=loop,
),
},
)
self.commit(t, None)
except Exception as exc:
self.commit(t, exc)
if reraise and futures:
# will be re-raised after futures are done
fut: asyncio.Future = loop.create_future()
fut.set_exception(exc)
futures.done.add(fut)
elif reraise:
if tb := exc.__traceback__:
while tb.tb_next is not None and any(
tb.tb_frame.f_code.co_filename.endswith(name)
for name in EXCLUDED_FRAME_FNAMES
):
tb = tb.tb_next
exc.__traceback__ = tb
raise
if not futures: # maybe `t` scheduled another task
return
else:
tasks = () # don't reschedule this task
# add waiter task if requested
if get_waiter is not None:
futures[get_waiter()] = None
# schedule tasks
for t in tasks:
fut = cast(
asyncio.Future,
self.submit()( # type: ignore[misc]
arun_with_retry,
t,
retry_policy,
stream=self.use_astream,
configurable={
CONFIG_KEY_CALL: partial(
_acall,
weakref.ref(t),
retry_policy=retry_policy,
stream=self.use_astream,
futures=weakref.ref(futures),
schedule_task=schedule_task,
submit=self.submit,
loop=loop,
),
},
__name__=t.name,
__cancel_on_exit__=True,
__reraise_on_exit__=reraise,
),
)
futures[fut] = t
# execute tasks, and wait for one to fail or all to finish.
# each task is independent from all other concurrent tasks
# yield updates/debug output as each task finishes
end_time = timeout + loop.time() if timeout else None
while len(futures) > (1 if get_waiter is not None else 0):
done, inflight = await asyncio.wait(
futures,
return_when=asyncio.FIRST_COMPLETED,
timeout=(max(0, end_time - loop.time()) if end_time else None),
)
if not done:
break # timed out
for fut in done:
task = futures.pop(fut)
if task is None:
# waiter task finished, schedule another
if inflight and get_waiter is not None:
futures[get_waiter()] = None
else:
# remove references to loop vars
del fut, task
# maybe stop other tasks
if _should_stop_others(done):
break
# give control back to the caller
yield
# wait for done callbacks
await asyncio.wait_for(
futures.event.wait(),
timeout=(max(0, end_time - loop.time()) if end_time else None),
)
# give control back to the caller
yield
# cancel waiter task
for fut in futures:
fut.cancel()
# panic on failure or timeout
try:
_panic_or_proceed(
futures.done.union(f for f, t in futures.items() if t is not None),
timeout_exc_cls=asyncio.TimeoutError,
panic=reraise,
)
except Exception as exc:
if tb := exc.__traceback__:
while tb.tb_next is not None and any(
tb.tb_frame.f_code.co_filename.endswith(name)
for name in EXCLUDED_FRAME_FNAMES
):
tb = tb.tb_next
exc.__traceback__ = tb
raise
def commit(
self,
task: PregelExecutableTask,
exception: BaseException | None,
) -> None:
if isinstance(exception, asyncio.CancelledError):
# for cancelled tasks, also save error in task,
# so loop can finish super-step
task.writes.append((ERROR, exception))
self.put_writes()(task.id, task.writes) # type: ignore[misc]
elif exception:
if isinstance(exception, GraphInterrupt):
# save interrupt to checkpointer
if exception.args[0]:
writes = [(INTERRUPT, exception.args[0])]
if resumes := [w for w in task.writes if w[0] == RESUME]:
writes.extend(resumes)
self.put_writes()(task.id, writes) # type: ignore[misc]
elif isinstance(exception, GraphBubbleUp):
# exception will be raised in _panic_or_proceed
pass
else:
# save error to checkpointer
task.writes.append((ERROR, exception))
self.put_writes()(task.id, task.writes) # type: ignore[misc]
else:
if self.node_finished and (
task.config is None or TAG_HIDDEN not in task.config.get("tags", [])
):
self.node_finished(task.name)
if not task.writes:
# add no writes marker
task.writes.append((NO_WRITES, None))
# save task writes to checkpointer
self.put_writes()(task.id, task.writes) # type: ignore[misc]
def _should_stop_others(
done: set[F],
) -> bool:
"""Check if any task failed, if so, cancel all other tasks.
GraphInterrupts are not considered failures."""
for fut in done:
if fut.cancelled():
continue
elif exc := fut.exception():
if not isinstance(exc, GraphBubbleUp) and fut not in SKIP_RERAISE_SET:
return True
return False
def _exception(
fut: concurrent.futures.Future[Any] | asyncio.Future[Any],
) -> BaseException | None:
"""Return the exception from a future, without raising CancelledError."""
if fut.cancelled():
if isinstance(fut, asyncio.Future):
return asyncio.CancelledError()
else:
return concurrent.futures.CancelledError()
else:
return fut.exception()
def _panic_or_proceed(
futs: set[concurrent.futures.Future] | set[asyncio.Future],
*,
timeout_exc_cls: type[Exception] = TimeoutError,
panic: bool = True,
) -> None:
"""Cancel remaining tasks if any failed, re-raise exception if panic is True."""
done: set[concurrent.futures.Future[Any] | asyncio.Future[Any]] = set()
inflight: set[concurrent.futures.Future[Any] | asyncio.Future[Any]] = set()
for fut in futs:
if fut.cancelled():
continue
elif fut.done():
done.add(fut)
else:
inflight.add(fut)
interrupts: list[GraphInterrupt] = []
while done:
# if any task failed
fut = done.pop()
if exc := _exception(fut):
# cancel all pending tasks
while inflight:
inflight.pop().cancel()
# raise the exception
if panic:
if isinstance(exc, GraphInterrupt):
# collect interrupts
interrupts.append(exc)
elif fut not in SKIP_RERAISE_SET:
raise exc
# raise combined interrupts
if interrupts:
raise GraphInterrupt(tuple(i for exc in interrupts for i in exc.args[0]))
if inflight:
# if we got here means we timed out
while inflight:
# cancel all pending tasks
inflight.pop().cancel()
# raise timeout error
raise timeout_exc_cls("Timed out")
def _call(
task: weakref.ref[PregelExecutableTask],
func: Callable[[Any], Awaitable[Any] | Any],
input: Any,
*,
retry_policy: Sequence[RetryPolicy] | None = None,
cache_policy: CachePolicy | None = None,
callbacks: Callbacks = None,
futures: weakref.ref[FuturesDict],
schedule_task: Callable[
[PregelExecutableTask, int, Call | None], PregelExecutableTask | None
],
submit: weakref.ref[Submit],
) -> concurrent.futures.Future[Any]:
if inspect.iscoroutinefunction(func):
raise RuntimeError("In an sync context async tasks cannot be called")
fut: concurrent.futures.Future | None = None
# schedule PUSH tasks, collect futures
scratchpad: PregelScratchpad = task().config[CONF][CONFIG_KEY_SCRATCHPAD] # type: ignore[union-attr]
# schedule the next task, if the callback returns one
if next_task := schedule_task(
task(), # type: ignore[arg-type]
scratchpad.call_counter(),
Call(
func,
input,
retry_policy=retry_policy,
cache_policy=cache_policy,
callbacks=callbacks,
),
):
if fut := next(
(
f
for f, t in list(futures().items()) # type: ignore[union-attr]
if t is not None and t == next_task.id
),
None,
):
# if the parent task was retried,
# the next task might already be running
pass
elif next_task.writes:
# if it already ran, return the result
fut = concurrent.futures.Future()
ret = next((v for c, v in next_task.writes if c == RETURN), MISSING)
if ret is not MISSING:
fut.set_result(ret)
elif exc := next((v for c, v in next_task.writes if c == ERROR), None):
fut.set_exception(
exc if isinstance(exc, BaseException) else Exception(exc)
)
else:
fut.set_result(None)
else:
# schedule the next task
fut = submit()( # type: ignore[misc]
run_with_retry,
next_task,
retry_policy,
configurable={
CONFIG_KEY_CALL: partial(
_call,
weakref.ref(next_task),
futures=futures,
retry_policy=retry_policy,
callbacks=callbacks,
schedule_task=schedule_task,
submit=submit,
),
},
__reraise_on_exit__=False,
# starting a new task in the next tick ensures
# updates from this tick are committed/streamed first
__next_tick__=True,
)
# exceptions for call() tasks are raised into the parent task
# so we should not re-raise at the end of the tick
SKIP_RERAISE_SET.add(fut)
futures()[fut] = next_task # type: ignore[index]
fut = cast(asyncio.Future | concurrent.futures.Future, fut)
# return a chained future to ensure commit() callback is called
# before the returned future is resolved, to ensure stream order etc
return chain_future(fut, concurrent.futures.Future())
def _acall(
task: weakref.ref[PregelExecutableTask],
func: Callable[[Any], Awaitable[Any] | Any],
input: Any,
*,
retry_policy: Sequence[RetryPolicy] | None = None,
cache_policy: CachePolicy | None = None,
callbacks: Callbacks = None,
# injected dependencies
futures: weakref.ref[FuturesDict],
schedule_task: Callable[
[PregelExecutableTask, int, Call | None],
Awaitable[PregelExecutableTask | None],
],
submit: weakref.ref[Submit],
loop: asyncio.AbstractEventLoop,
stream: bool = False,
) -> asyncio.Future[Any] | concurrent.futures.Future[Any]:
# return a chained future to ensure commit() callback is called
# before the returned future is resolved, to ensure stream order etc
try:
in_async = asyncio.current_task() is not None
except RuntimeError:
in_async = False
# if in async context return an async future, otherwise return a sync future
if in_async:
fut: asyncio.Future[Any] | concurrent.futures.Future[Any] = asyncio.Future(
loop=loop
)
else:
fut = concurrent.futures.Future()
# schedule the next task
run_coroutine_threadsafe(
_acall_impl(
fut,
task,
func,
input,
retry_policy=retry_policy,
cache_policy=cache_policy,
callbacks=callbacks,
futures=futures,
schedule_task=schedule_task,
submit=submit,
loop=loop,
stream=stream,
),
loop,
lazy=False,
)
return fut
async def _acall_impl(
destination: asyncio.Future[Any] | concurrent.futures.Future[Any],
task: weakref.ref[PregelExecutableTask],
func: Callable[[Any], Awaitable[Any] | Any],
input: Any,
*,
retry_policy: Sequence[RetryPolicy] | None = None,
cache_policy: CachePolicy | None = None,
callbacks: Callbacks = None,
# injected dependencies
futures: weakref.ref[FuturesDict[asyncio.Future, asyncio.Event]],
schedule_task: Callable[
[PregelExecutableTask, int, Call | None],
Awaitable[PregelExecutableTask | None],
],
submit: weakref.ref[Submit],
loop: asyncio.AbstractEventLoop,
stream: bool = False,
) -> None:
try:
fut: asyncio.Future | None = None
# schedule PUSH tasks, collect futures
scratchpad: PregelScratchpad = task().config[CONF][CONFIG_KEY_SCRATCHPAD] # type: ignore[union-attr]
# schedule the next task, if the callback returns one
if next_task := await schedule_task(
task(), # type: ignore[arg-type]
scratchpad.call_counter(),
Call(
func,
input,
retry_policy=retry_policy,
cache_policy=cache_policy,
callbacks=callbacks,
),
):
if fut := next(
(
f
for f, t in list(futures().items()) # type: ignore[union-attr]
if t is not None and t == next_task.id
),
None,
):
# if the parent task was retried,
# the next task might already be running
pass
elif next_task.writes:
# if it already ran, return the result
fut = asyncio.Future(loop=loop)
ret = next((v for c, v in next_task.writes if c == RETURN), MISSING)
if ret is not MISSING:
fut.set_result(ret)
elif exc := next((v for c, v in next_task.writes if c == ERROR), None):
fut.set_exception(
exc if isinstance(exc, BaseException) else Exception(exc)
)
else:
fut.set_result(None)
else:
# schedule the next task
fut = cast(
asyncio.Future,
submit()( # type: ignore[misc]
arun_with_retry,
next_task,
retry_policy,
stream=stream,
configurable={
CONFIG_KEY_CALL: partial(
_acall,
weakref.ref(next_task),
stream=stream,
futures=futures,
schedule_task=schedule_task,
submit=submit,
loop=loop,
),
},
__name__=next_task.name,
__cancel_on_exit__=True,
__reraise_on_exit__=False,
# starting a new task in the next tick ensures
# updates from this tick are committed/streamed first
__next_tick__=True,
),
)
# exceptions for call() tasks are raised into the parent task
# so we should not re-raise at the end of the tick
SKIP_RERAISE_SET.add(fut)
futures()[fut] = next_task # type: ignore[index]
if fut is not None:
chain_future(fut, destination)
else:
destination.set_exception(RuntimeError("Task not scheduled"))
except Exception as exc:
destination.set_exception(exc)

View File

@@ -0,0 +1,218 @@
from __future__ import annotations
import ast
import inspect
import re
import textwrap
from collections.abc import Callable
from typing import Any
from langchain_core.runnables import Runnable, RunnableLambda, RunnableSequence
from langgraph.checkpoint.base import ChannelVersions
from typing_extensions import override
from langgraph._internal._runnable import RunnableCallable, RunnableSeq
from langgraph.pregel.protocol import PregelProtocol
def get_new_channel_versions(
previous_versions: ChannelVersions, current_versions: ChannelVersions
) -> ChannelVersions:
"""Get subset of current_versions that are newer than previous_versions."""
if previous_versions:
version_type = type(next(iter(current_versions.values()), None))
null_version = version_type() # type: ignore[misc]
new_versions = {
k: v
for k, v in current_versions.items()
if v > previous_versions.get(k, null_version) # type: ignore[operator]
}
else:
new_versions = current_versions
return new_versions
def find_subgraph_pregel(candidate: Runnable) -> PregelProtocol | None:
from langgraph.pregel import Pregel
candidates: list[Runnable] = [candidate]
for c in candidates:
if (
isinstance(c, PregelProtocol)
# subgraphs that disabled checkpointing are not considered
and (not isinstance(c, Pregel) or c.checkpointer is not False)
):
return c
elif isinstance(c, RunnableSequence) or isinstance(c, RunnableSeq):
candidates.extend(c.steps)
elif isinstance(c, RunnableLambda):
candidates.extend(c.deps)
elif isinstance(c, RunnableCallable):
if c.func is not None:
candidates.extend(
nl.__self__ if hasattr(nl, "__self__") else nl
for nl in get_function_nonlocals(c.func)
)
elif c.afunc is not None:
candidates.extend(
nl.__self__ if hasattr(nl, "__self__") else nl
for nl in get_function_nonlocals(c.afunc)
)
return None
def get_function_nonlocals(func: Callable) -> list[Any]:
"""Get the nonlocal variables accessed by a function.
Args:
func: The function to check.
Returns:
List[Any]: The nonlocal variables accessed by the function.
"""
try:
code = inspect.getsource(func)
tree = ast.parse(textwrap.dedent(code))
visitor = FunctionNonLocals()
visitor.visit(tree)
values: list[Any] = []
closure = (
inspect.getclosurevars(func.__wrapped__)
if hasattr(func, "__wrapped__") and callable(func.__wrapped__)
else inspect.getclosurevars(func)
)
candidates = {**closure.globals, **closure.nonlocals}
for k, v in candidates.items():
if k in visitor.nonlocals:
values.append(v)
for kk in visitor.nonlocals:
if "." in kk and kk.startswith(k):
vv = v
for part in kk.split(".")[1:]:
if vv is None:
break
else:
try:
vv = getattr(vv, part)
except AttributeError:
break
else:
values.append(vv)
except (SyntaxError, TypeError, OSError, SystemError):
return []
return values
class FunctionNonLocals(ast.NodeVisitor):
"""Get the nonlocal variables accessed of a function."""
def __init__(self) -> None:
self.nonlocals: set[str] = set()
@override
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
"""Visit a function definition.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
visitor = NonLocals()
visitor.visit(node)
self.nonlocals.update(visitor.loads - visitor.stores)
@override
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
"""Visit an async function definition.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
visitor = NonLocals()
visitor.visit(node)
self.nonlocals.update(visitor.loads - visitor.stores)
@override
def visit_Lambda(self, node: ast.Lambda) -> Any:
"""Visit a lambda function.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
visitor = NonLocals()
visitor.visit(node)
self.nonlocals.update(visitor.loads - visitor.stores)
class NonLocals(ast.NodeVisitor):
"""Get nonlocal variables accessed."""
def __init__(self) -> None:
self.loads: set[str] = set()
self.stores: set[str] = set()
@override
def visit_Name(self, node: ast.Name) -> Any:
"""Visit a name node.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
if isinstance(node.ctx, ast.Load):
self.loads.add(node.id)
elif isinstance(node.ctx, ast.Store):
self.stores.add(node.id)
@override
def visit_Attribute(self, node: ast.Attribute) -> Any:
"""Visit an attribute node.
Args:
node: The node to visit.
Returns:
Any: The result of the visit.
"""
if isinstance(node.ctx, ast.Load):
parent = node.value
attr_expr = node.attr
while isinstance(parent, ast.Attribute):
attr_expr = parent.attr + "." + attr_expr
parent = parent.value
if isinstance(parent, ast.Name):
self.loads.add(parent.id + "." + attr_expr)
self.loads.discard(parent.id)
elif isinstance(parent, ast.Call):
if isinstance(parent.func, ast.Name):
self.loads.add(parent.func.id)
else:
parent = parent.func
attr_expr = ""
while isinstance(parent, ast.Attribute):
if attr_expr:
attr_expr = parent.attr + "." + attr_expr
else:
attr_expr = parent.attr
parent = parent.value
if isinstance(parent, ast.Name):
self.loads.add(parent.id + "." + attr_expr)
def is_xxh3_128_hexdigest(value: str) -> bool:
"""Check if the given string matches the format of xxh3_128_hexdigest."""
return bool(re.fullmatch(r"[0-9a-f]{32}", value))

View File

@@ -0,0 +1,120 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any
from langgraph._internal._constants import RESERVED
from langgraph.channels.base import BaseChannel
from langgraph.managed.base import ManagedValueMapping
from langgraph.pregel._read import PregelNode
from langgraph.types import All
def validate_graph(
nodes: Mapping[str, PregelNode],
channels: dict[str, BaseChannel],
managed: ManagedValueMapping,
input_channels: str | Sequence[str],
output_channels: str | Sequence[str],
stream_channels: str | Sequence[str] | None,
interrupt_after_nodes: All | Sequence[str],
interrupt_before_nodes: All | Sequence[str],
) -> None:
for chan in channels:
if chan in RESERVED:
raise ValueError(f"Channel name '{chan}' is reserved")
for name in managed:
if name in RESERVED:
raise ValueError(f"Managed name '{name}' is reserved")
subscribed_channels = set[str]()
for name, node in nodes.items():
if name in RESERVED:
raise ValueError(f"Node name '{name}' is reserved")
if isinstance(node, PregelNode):
subscribed_channels.update(node.triggers)
if isinstance(node.channels, str):
if node.channels not in channels:
raise ValueError(
f"Node {name} reads channel '{node.channels}' "
f"not in known channels: '{repr(sorted(channels))[:100]}'"
)
else:
for chan in node.channels:
if chan not in channels and chan not in managed:
raise ValueError(
f"Node {name} reads channel '{chan}' "
f"not in known channels: '{repr(sorted(channels))[:100]}'"
)
else:
raise TypeError(
f"Invalid node type {type(node)}, expected PregelNode or NodeBuilder"
)
for chan in subscribed_channels:
if chan not in channels:
raise ValueError(
f"Subscribed channel '{chan}' not "
f"in known channels: '{repr(sorted(channels))[:100]}'"
)
if isinstance(input_channels, str):
if input_channels not in channels:
raise ValueError(
f"Input channel '{input_channels}' not "
f"in known channels: '{repr(sorted(channels))[:100]}'"
)
if input_channels not in subscribed_channels:
raise ValueError(
f"Input channel {input_channels} is not subscribed to by any node"
)
else:
for chan in input_channels:
if chan not in channels:
raise ValueError(
f"Input channel '{chan}' not in '{repr(sorted(channels))[:100]}'"
)
if all(chan not in subscribed_channels for chan in input_channels):
raise ValueError(
f"None of the input channels {input_channels} "
f"are subscribed to by any node"
)
all_output_channels = set[str]()
if isinstance(output_channels, str):
all_output_channels.add(output_channels)
else:
all_output_channels.update(output_channels)
if isinstance(stream_channels, str):
all_output_channels.add(stream_channels)
elif stream_channels is not None:
all_output_channels.update(stream_channels)
for chan in all_output_channels:
if chan not in channels:
raise ValueError(
f"Output channel '{chan}' not "
f"in known channels: '{repr(sorted(channels))[:100]}'"
)
if interrupt_after_nodes != "*":
for n in interrupt_after_nodes:
if n not in nodes:
raise ValueError(f"Node {n} not in nodes")
if interrupt_before_nodes != "*":
for n in interrupt_before_nodes:
if n not in nodes:
raise ValueError(f"Node {n} not in nodes")
def validate_keys(
keys: str | Sequence[str] | None,
channels: Mapping[str, Any],
) -> None:
if isinstance(keys, str):
if keys not in channels:
raise ValueError(f"Key {keys} not in channels")
elif keys is not None:
for chan in keys:
if chan not in channels:
raise ValueError(f"Key {chan} not in channels")

View File

@@ -0,0 +1,192 @@
from __future__ import annotations
from collections.abc import Callable, Sequence
from typing import (
Any,
NamedTuple,
TypeVar,
cast,
)
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph._internal._constants import CONF, CONFIG_KEY_SEND, TASKS
from langgraph._internal._runnable import RunnableCallable
from langgraph._internal._typing import MISSING
from langgraph.errors import InvalidUpdateError
from langgraph.types import Send
TYPE_SEND = Callable[[Sequence[tuple[str, Any]]], None]
R = TypeVar("R", bound=Runnable)
SKIP_WRITE = object()
PASSTHROUGH = object()
class ChannelWriteEntry(NamedTuple):
channel: str
"""Channel name to write to."""
value: Any = PASSTHROUGH
"""Value to write, or PASSTHROUGH to use the input."""
skip_none: bool = False
"""Whether to skip writing if the value is None."""
mapper: Callable | None = None
"""Function to transform the value before writing."""
class ChannelWriteTupleEntry(NamedTuple):
mapper: Callable[[Any], Sequence[tuple[str, Any]] | None]
"""Function to extract tuples from value."""
value: Any = PASSTHROUGH
"""Value to write, or PASSTHROUGH to use the input."""
static: Sequence[tuple[str, Any, str | None]] | None = None
"""Optional, declared writes for static analysis."""
class ChannelWrite(RunnableCallable):
"""Implements the logic for sending writes to CONFIG_KEY_SEND.
Can be used as a runnable or as a static method to call imperatively."""
writes: list[ChannelWriteEntry | ChannelWriteTupleEntry | Send]
"""Sequence of write entries or Send objects to write."""
def __init__(
self,
writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
*,
tags: Sequence[str] | None = None,
):
super().__init__(
func=self._write,
afunc=self._awrite,
name=None,
tags=tags,
trace=False,
)
self.writes = cast(
list[ChannelWriteEntry | ChannelWriteTupleEntry | Send], writes
)
def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str:
if not name:
name = f"ChannelWrite<{','.join(w.channel if isinstance(w, ChannelWriteEntry) else '...' if isinstance(w, ChannelWriteTupleEntry) else w.node for w in self.writes)}>"
return super().get_name(suffix, name=name)
def _write(self, input: Any, config: RunnableConfig) -> None:
writes = [
ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper)
if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH
else ChannelWriteTupleEntry(write.mapper, input)
if isinstance(write, ChannelWriteTupleEntry) and write.value is PASSTHROUGH
else write
for write in self.writes
]
self.do_write(
config,
writes,
)
return input
async def _awrite(self, input: Any, config: RunnableConfig) -> None:
writes = [
ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper)
if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH
else ChannelWriteTupleEntry(write.mapper, input)
if isinstance(write, ChannelWriteTupleEntry) and write.value is PASSTHROUGH
else write
for write in self.writes
]
self.do_write(
config,
writes,
)
return input
@staticmethod
def do_write(
config: RunnableConfig,
writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
allow_passthrough: bool = True,
) -> None:
# validate
for w in writes:
if isinstance(w, ChannelWriteEntry):
if w.channel == TASKS:
raise InvalidUpdateError(
"Cannot write to the reserved channel TASKS"
)
if w.value is PASSTHROUGH and not allow_passthrough:
raise InvalidUpdateError("PASSTHROUGH value must be replaced")
if isinstance(w, ChannelWriteTupleEntry):
if w.value is PASSTHROUGH and not allow_passthrough:
raise InvalidUpdateError("PASSTHROUGH value must be replaced")
# if we want to persist writes found before hitting a ParentCommand
# can move this to a finally block
write: TYPE_SEND = config[CONF][CONFIG_KEY_SEND]
write(_assemble_writes(writes))
@staticmethod
def is_writer(runnable: Runnable) -> bool:
"""Used by PregelNode to distinguish between writers and other runnables."""
return (
isinstance(runnable, ChannelWrite)
or getattr(runnable, "_is_channel_writer", MISSING) is not MISSING
)
@staticmethod
def get_static_writes(
runnable: Runnable,
) -> Sequence[tuple[str, Any, str | None]] | None:
"""Used to get conditional writes a writer declares for static analysis."""
if isinstance(runnable, ChannelWrite):
return [
w
for entry in runnable.writes
if isinstance(entry, ChannelWriteTupleEntry) and entry.static
for w in entry.static
] or None
elif writes := getattr(runnable, "_is_channel_writer", MISSING):
if writes is not MISSING:
writes = cast(
Sequence[tuple[ChannelWriteEntry | Send, str | None]],
writes,
)
entries = [e for e, _ in writes]
labels = [la for _, la in writes]
return [(*t, la) for t, la in zip(_assemble_writes(entries), labels)]
@staticmethod
def register_writer(
runnable: R,
static: Sequence[tuple[ChannelWriteEntry | Send, str | None]] | None = None,
) -> R:
"""Used to mark a runnable as a writer, so that it can be detected by is_writer.
Instances of ChannelWrite are automatically marked as writers.
Optionally, a list of declared writes can be passed for static analysis."""
# using object.__setattr__ to work around objects that override __setattr__
# eg. pydantic models and dataclasses
object.__setattr__(runnable, "_is_channel_writer", static)
return runnable
def _assemble_writes(
writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
) -> list[tuple[str, Any]]:
"""Assembles the writes into a list of tuples."""
tuples: list[tuple[str, Any]] = []
for w in writes:
if isinstance(w, Send):
tuples.append((TASKS, w))
elif isinstance(w, ChannelWriteTupleEntry):
if ww := w.mapper(w.value):
tuples.extend(ww)
elif isinstance(w, ChannelWriteEntry):
value = w.mapper(w.value) if w.mapper is not None else w.value
if value is SKIP_WRITE:
continue
if w.skip_none and value is None:
continue
tuples.append((w.channel, value))
else:
raise ValueError(f"Invalid write entry: {w}")
return tuples

View File

@@ -0,0 +1,308 @@
from __future__ import annotations
from collections.abc import Iterable, Iterator, Mapping, Sequence
from dataclasses import asdict
from typing import Any
from uuid import UUID
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import CheckpointMetadata, PendingWrite
from typing_extensions import TypedDict
from langgraph._internal._config import patch_checkpoint_map
from langgraph._internal._constants import (
CONF,
CONFIG_KEY_CHECKPOINT_NS,
ERROR,
INTERRUPT,
NS_END,
NS_SEP,
RETURN,
)
from langgraph._internal._typing import MISSING
from langgraph.channels.base import BaseChannel
from langgraph.constants import TAG_HIDDEN
from langgraph.pregel._io import read_channels
from langgraph.types import PregelExecutableTask, PregelTask, StateSnapshot
__all__ = ("TaskPayload", "TaskResultPayload", "CheckpointTask", "CheckpointPayload")
class TaskPayload(TypedDict):
id: str
name: str
input: Any
triggers: list[str]
class TaskResultPayload(TypedDict):
id: str
name: str
error: str | None
interrupts: list[dict]
result: dict[str, Any]
class CheckpointTask(TypedDict):
id: str
name: str
error: str | None
interrupts: list[dict]
state: StateSnapshot | RunnableConfig | None
class CheckpointPayload(TypedDict):
config: RunnableConfig | None
metadata: CheckpointMetadata
values: dict[str, Any]
next: list[str]
parent_config: RunnableConfig | None
tasks: list[CheckpointTask]
TASK_NAMESPACE = UUID("6ba7b831-9dad-11d1-80b4-00c04fd430c8")
def map_debug_tasks(tasks: Iterable[PregelExecutableTask]) -> Iterator[TaskPayload]:
"""Produce "task" events for stream_mode=debug."""
for task in tasks:
if task.config is not None and TAG_HIDDEN in task.config.get("tags", []):
continue
yield {
"id": task.id,
"name": task.name,
"input": task.input,
"triggers": task.triggers,
}
def is_multiple_channel_write(value: Any) -> bool:
"""Return True if the payload already wraps multiple writes from the same channel."""
return (
isinstance(value, dict)
and "$writes" in value
and isinstance(value["$writes"], list)
)
def map_task_result_writes(writes: Sequence[tuple[str, Any]]) -> dict[str, Any]:
"""Folds task writes into a result dict and aggregates multiple writes to the same channel.
If the channel contains a single write, we record the write in the result dict as `{channel: write}`
If the channel contains multiple writes, we record the writes in the result dict as `{channel: {'$writes': [write1, write2, ...]}}`"""
result: dict[str, Any] = {}
for channel, value in writes:
existing = result.get(channel)
if existing is not None:
channel_writes = (
existing["$writes"]
if is_multiple_channel_write(existing)
else [existing]
)
channel_writes.append(value)
result[channel] = {"$writes": channel_writes}
else:
result[channel] = value
return result
def map_debug_task_results(
task_tup: tuple[PregelExecutableTask, Sequence[tuple[str, Any]]],
stream_keys: str | Sequence[str],
) -> Iterator[TaskResultPayload]:
"""Produce "task_result" events for stream_mode=debug."""
stream_channels_list = (
[stream_keys] if isinstance(stream_keys, str) else stream_keys
)
task, writes = task_tup
yield {
"id": task.id,
"name": task.name,
"error": next((w[1] for w in writes if w[0] == ERROR), None),
"result": map_task_result_writes(
[w for w in writes if w[0] in stream_channels_list or w[0] == RETURN]
),
"interrupts": [
asdict(v)
for w in writes
if w[0] == INTERRUPT
for v in (w[1] if isinstance(w[1], Sequence) else [w[1]])
],
}
def rm_pregel_keys(config: RunnableConfig | None) -> RunnableConfig | None:
"""Remove pregel-specific keys from the config."""
if config is None:
return config
return {
"configurable": {
k: v
for k, v in config.get("configurable", {}).items()
if not k.startswith("__pregel_")
}
}
def map_debug_checkpoint(
config: RunnableConfig,
channels: Mapping[str, BaseChannel],
stream_channels: str | Sequence[str],
metadata: CheckpointMetadata,
tasks: Iterable[PregelExecutableTask],
pending_writes: list[PendingWrite],
parent_config: RunnableConfig | None,
output_keys: str | Sequence[str],
) -> Iterator[CheckpointPayload]:
"""Produce "checkpoint" events for stream_mode=debug."""
parent_ns = config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
task_states: dict[str, RunnableConfig | StateSnapshot] = {}
for task in tasks:
if not task.subgraphs:
continue
# assemble checkpoint_ns for this task
task_ns = f"{task.name}{NS_END}{task.id}"
if parent_ns:
task_ns = f"{parent_ns}{NS_SEP}{task_ns}"
# set config as signal that subgraph checkpoints exist
task_states[task.id] = {
CONF: {
"thread_id": config[CONF]["thread_id"],
CONFIG_KEY_CHECKPOINT_NS: task_ns,
}
}
yield {
"config": rm_pregel_keys(patch_checkpoint_map(config, metadata)),
"parent_config": rm_pregel_keys(patch_checkpoint_map(parent_config, metadata)),
"values": read_channels(channels, stream_channels),
"metadata": metadata,
"next": [t.name for t in tasks],
"tasks": [
{
"id": t.id,
"name": t.name,
"error": t.error,
"state": t.state,
}
if t.error
else {
"id": t.id,
"name": t.name,
"result": t.result,
"interrupts": tuple(asdict(i) for i in t.interrupts),
"state": t.state,
}
if t.result
else {
"id": t.id,
"name": t.name,
"interrupts": tuple(asdict(i) for i in t.interrupts),
"state": t.state,
}
for t in tasks_w_writes(tasks, pending_writes, task_states, output_keys)
],
}
def tasks_w_writes(
tasks: Iterable[PregelTask | PregelExecutableTask],
pending_writes: list[PendingWrite] | None,
states: dict[str, RunnableConfig | StateSnapshot] | None,
output_keys: str | Sequence[str],
) -> tuple[PregelTask, ...]:
"""Apply writes / subgraph states to tasks to be returned in a StateSnapshot."""
pending_writes = pending_writes or []
out: list[PregelTask] = []
for task in tasks:
rtn = next(
(
val
for tid, chan, val in pending_writes
if tid == task.id and chan == RETURN
),
MISSING,
)
task_error = next(
(exc for tid, n, exc in pending_writes if tid == task.id and n == ERROR),
None,
)
task_interrupts = tuple(
v
for tid, n, vv in pending_writes
if tid == task.id and n == INTERRUPT
for v in (vv if isinstance(vv, Sequence) else [vv])
)
task_writes = [
(chan, val)
for tid, chan, val in pending_writes
if tid == task.id and chan not in (ERROR, INTERRUPT, RETURN)
]
if rtn is not MISSING:
task_result = rtn
elif isinstance(output_keys, str):
# unwrap single channel writes to just the write value
filtered_writes = [
(chan, val) for chan, val in task_writes if chan == output_keys
]
mapped_writes = map_task_result_writes(filtered_writes)
task_result = mapped_writes.get(str(output_keys)) if mapped_writes else None
else:
if isinstance(output_keys, str):
output_keys = [output_keys]
# map task result writes to the desired output channels
# repeateed writes to the same channel are aggregated into: {'$writes': [write1, write2, ...]}
filtered_writes = [
(chan, val) for chan, val in task_writes if chan in output_keys
]
mapped_writes = map_task_result_writes(filtered_writes)
task_result = mapped_writes if filtered_writes else {}
has_writes = rtn is not MISSING or any(
w[0] == task.id and w[1] not in (ERROR, INTERRUPT) for w in pending_writes
)
out.append(
PregelTask(
task.id,
task.name,
task.path,
task_error,
task_interrupts,
states.get(task.id) if states else None,
task_result if has_writes else None,
)
)
return tuple(out)
COLOR_MAPPING = {
"black": "0;30",
"red": "0;31",
"green": "0;32",
"yellow": "0;33",
"blue": "0;34",
"magenta": "0;35",
"cyan": "0;36",
"white": "0;37",
"gray": "1;30",
}
def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
return f"\033[1;3{COLOR_MAPPING[color]}m{text}\033[0m"
def get_bolded_text(text: str) -> str:
"""Get bolded text."""
return f"\033[1m{text}\033[0m"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,164 @@
from __future__ import annotations
from abc import abstractmethod
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from typing import Any, Generic, cast
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.graph import Graph as DrawableGraph
from typing_extensions import Self
from langgraph.types import All, Command, StateSnapshot, StateUpdate, StreamMode
from langgraph.typing import ContextT, InputT, OutputT, StateT
__all__ = ("PregelProtocol", "StreamProtocol")
class PregelProtocol(Runnable[InputT, Any], Generic[StateT, ContextT, InputT, OutputT]):
@abstractmethod
def with_config(
self, config: RunnableConfig | None = None, **kwargs: Any
) -> Self: ...
@abstractmethod
def get_graph(
self,
config: RunnableConfig | None = None,
*,
xray: int | bool = False,
) -> DrawableGraph: ...
@abstractmethod
async def aget_graph(
self,
config: RunnableConfig | None = None,
*,
xray: int | bool = False,
) -> DrawableGraph: ...
@abstractmethod
def get_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot: ...
@abstractmethod
async def aget_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot: ...
@abstractmethod
def get_state_history(
self,
config: RunnableConfig,
*,
filter: dict[str, Any] | None = None,
before: RunnableConfig | None = None,
limit: int | None = None,
) -> Iterator[StateSnapshot]: ...
@abstractmethod
def aget_state_history(
self,
config: RunnableConfig,
*,
filter: dict[str, Any] | None = None,
before: RunnableConfig | None = None,
limit: int | None = None,
) -> AsyncIterator[StateSnapshot]: ...
@abstractmethod
def bulk_update_state(
self,
config: RunnableConfig,
updates: Sequence[Sequence[StateUpdate]],
) -> RunnableConfig: ...
@abstractmethod
async def abulk_update_state(
self,
config: RunnableConfig,
updates: Sequence[Sequence[StateUpdate]],
) -> RunnableConfig: ...
@abstractmethod
def update_state(
self,
config: RunnableConfig,
values: dict[str, Any] | Any | None,
as_node: str | None = None,
) -> RunnableConfig: ...
@abstractmethod
async def aupdate_state(
self,
config: RunnableConfig,
values: dict[str, Any] | Any | None,
as_node: str | None = None,
) -> RunnableConfig: ...
@abstractmethod
def stream(
self,
input: InputT | Command | None,
config: RunnableConfig | None = None,
*,
context: ContextT | None = None,
stream_mode: StreamMode | list[StreamMode] | None = None,
interrupt_before: All | Sequence[str] | None = None,
interrupt_after: All | Sequence[str] | None = None,
subgraphs: bool = False,
) -> Iterator[dict[str, Any] | Any]: ...
@abstractmethod
def astream(
self,
input: InputT | Command | None,
config: RunnableConfig | None = None,
*,
context: ContextT | None = None,
stream_mode: StreamMode | list[StreamMode] | None = None,
interrupt_before: All | Sequence[str] | None = None,
interrupt_after: All | Sequence[str] | None = None,
subgraphs: bool = False,
) -> AsyncIterator[dict[str, Any] | Any]: ...
@abstractmethod
def invoke(
self,
input: InputT | Command | None,
config: RunnableConfig | None = None,
*,
context: ContextT | None = None,
interrupt_before: All | Sequence[str] | None = None,
interrupt_after: All | Sequence[str] | None = None,
) -> dict[str, Any] | Any: ...
@abstractmethod
async def ainvoke(
self,
input: InputT | Command | None,
config: RunnableConfig | None = None,
*,
context: ContextT | None = None,
interrupt_before: All | Sequence[str] | None = None,
interrupt_after: All | Sequence[str] | None = None,
) -> dict[str, Any] | Any: ...
StreamChunk = tuple[tuple[str, ...], str, Any]
class StreamProtocol:
__slots__ = ("modes", "__call__")
modes: set[StreamMode]
__call__: Callable[[Self, StreamChunk], None]
def __init__(
self,
__call__: Callable[[StreamChunk], None],
modes: set[StreamMode],
) -> None:
self.__call__ = cast(Callable[[Self, StreamChunk], None], __call__)
self.modes = modes

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,38 @@
"""Re-export types moved to langgraph.types"""
from langgraph.types import (
All,
CachePolicy,
PregelExecutableTask,
PregelTask,
RetryPolicy,
StateSnapshot,
StateUpdate,
StreamMode,
StreamWriter,
default_retry_on,
)
__all__ = [
"All",
"StateUpdate",
"CachePolicy",
"PregelExecutableTask",
"PregelTask",
"RetryPolicy",
"StateSnapshot",
"StreamMode",
"StreamWriter",
"default_retry_on",
]
from warnings import warn
from langgraph.warnings import LangGraphDeprecatedSinceV10
warn(
"Importing from langgraph.pregel.types is deprecated. "
"Please use 'from langgraph.types import ...' instead.",
LangGraphDeprecatedSinceV10,
stacklevel=2,
)