initial commit
This commit is contained in:
3
venv/Lib/site-packages/langgraph/pregel/__init__.py
Normal file
3
venv/Lib/site-packages/langgraph/pregel/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from langgraph.pregel.main import NodeBuilder, Pregel
|
||||
|
||||
__all__ = ("Pregel", "NodeBuilder")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1233
venv/Lib/site-packages/langgraph/pregel/_algo.py
Normal file
1233
venv/Lib/site-packages/langgraph/pregel/_algo.py
Normal file
File diff suppressed because it is too large
Load Diff
269
venv/Lib/site-packages/langgraph/pregel/_call.py
Normal file
269
venv/Lib/site-packages/langgraph/pregel/_call.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""Utility to convert a user provided function into a Runnable with a ChannelWrite."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
import functools
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
from collections.abc import Awaitable, Callable, Generator, Sequence
|
||||
from typing import Any, Generic, TypeVar, cast
|
||||
|
||||
from langchain_core.runnables import Runnable
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from langgraph._internal._constants import CONF, CONFIG_KEY_CALL, RETURN
|
||||
from langgraph._internal._runnable import (
|
||||
RunnableCallable,
|
||||
RunnableSeq,
|
||||
is_async_callable,
|
||||
run_in_executor,
|
||||
)
|
||||
from langgraph.config import get_config
|
||||
from langgraph.pregel._write import ChannelWrite, ChannelWriteEntry
|
||||
from langgraph.types import CachePolicy, RetryPolicy
|
||||
|
||||
##
|
||||
# Utilities borrowed from cloudpickle.
|
||||
# https://github.com/cloudpipe/cloudpickle/blob/6220b0ce83ffee5e47e06770a1ee38ca9e47c850/cloudpickle/cloudpickle.py#L265
|
||||
|
||||
|
||||
def _getattribute(obj: Any, name: str) -> Any:
|
||||
parent = None
|
||||
for subpath in name.split("."):
|
||||
if subpath == "<locals>":
|
||||
raise AttributeError(f"Can't get local attribute {name!r} on {obj!r}")
|
||||
try:
|
||||
parent = obj
|
||||
obj = getattr(obj, subpath)
|
||||
except AttributeError:
|
||||
raise AttributeError(f"Can't get attribute {name!r} on {obj!r}") from None
|
||||
return obj, parent
|
||||
|
||||
|
||||
def _whichmodule(obj: Any, name: str) -> str | None:
|
||||
"""Find the module an object belongs to.
|
||||
|
||||
This function differs from ``pickle.whichmodule`` in two ways:
|
||||
- it does not mangle the cases where obj's module is __main__ and obj was
|
||||
not found in any module.
|
||||
- Errors arising during module introspection are ignored, as those errors
|
||||
are considered unwanted side effects.
|
||||
"""
|
||||
module_name = getattr(obj, "__module__", None)
|
||||
|
||||
if module_name is not None:
|
||||
return module_name
|
||||
# Protect the iteration by using a copy of sys.modules against dynamic
|
||||
# modules that trigger imports of other modules upon calls to getattr or
|
||||
# other threads importing at the same time.
|
||||
for module_name, module in sys.modules.copy().items():
|
||||
# Some modules such as coverage can inject non-module objects inside
|
||||
# sys.modules
|
||||
if (
|
||||
module_name == "__main__"
|
||||
or module_name == "__mp_main__"
|
||||
or module is None
|
||||
or not isinstance(module, types.ModuleType)
|
||||
):
|
||||
continue
|
||||
try:
|
||||
if _getattribute(module, name)[0] is obj:
|
||||
return module_name
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def identifier(obj: Any, name: str | None = None) -> str | None:
|
||||
"""Return the module and name of an object."""
|
||||
from langgraph._internal._runnable import RunnableCallable, RunnableSeq
|
||||
from langgraph.pregel._read import PregelNode
|
||||
|
||||
if isinstance(obj, PregelNode):
|
||||
obj = obj.bound
|
||||
if isinstance(obj, RunnableSeq):
|
||||
obj = obj.steps[0]
|
||||
if isinstance(obj, RunnableCallable):
|
||||
obj = obj.func
|
||||
if name is None:
|
||||
name = getattr(obj, "__qualname__", None)
|
||||
if name is None: # pragma: no cover
|
||||
# This used to be needed for Python 2.7 support but is probably not
|
||||
# needed anymore. However we keep the __name__ introspection in case
|
||||
# users of cloudpickle rely on this old behavior for unknown reasons.
|
||||
name = getattr(obj, "__name__", None)
|
||||
if name is None:
|
||||
return None
|
||||
|
||||
module_name = getattr(obj, "__module__", None)
|
||||
if module_name is None:
|
||||
# In this case, obj.__module__ is None. obj is thus treated as dynamic.
|
||||
return None
|
||||
|
||||
return f"{module_name}.{name}"
|
||||
|
||||
|
||||
def _lookup_module_and_qualname(
|
||||
obj: Any, name: str | None = None
|
||||
) -> tuple[types.ModuleType, str] | None:
|
||||
if name is None:
|
||||
name = getattr(obj, "__qualname__", None)
|
||||
if name is None: # pragma: no cover
|
||||
# This used to be needed for Python 2.7 support but is probably not
|
||||
# needed anymore. However we keep the __name__ introspection in case
|
||||
# users of cloudpickle rely on this old behavior for unknown reasons.
|
||||
name = getattr(obj, "__name__", None)
|
||||
if name is None:
|
||||
return None
|
||||
|
||||
module_name = _whichmodule(obj, name)
|
||||
|
||||
if module_name is None:
|
||||
# In this case, obj.__module__ is None AND obj was not found in any
|
||||
# imported module. obj is thus treated as dynamic.
|
||||
return None
|
||||
|
||||
if module_name == "__main__":
|
||||
return None
|
||||
|
||||
# Note: if module_name is in sys.modules, the corresponding module is
|
||||
# assumed importable at unpickling time. See #357
|
||||
module = sys.modules.get(module_name, None)
|
||||
if module is None:
|
||||
# The main reason why obj's module would not be imported is that this
|
||||
# module has been dynamically created, using for example
|
||||
# types.ModuleType. The other possibility is that module was removed
|
||||
# from sys.modules after obj was created/imported. But this case is not
|
||||
# supported, as the standard pickle does not support it either.
|
||||
return None
|
||||
|
||||
try:
|
||||
obj2, parent = _getattribute(module, name)
|
||||
except AttributeError:
|
||||
# obj was not found inside the module it points to
|
||||
return None
|
||||
if obj2 is not obj:
|
||||
return None
|
||||
return module, name
|
||||
|
||||
|
||||
def _explode_args_trace_inputs(
|
||||
sig: inspect.Signature, input: tuple[tuple[Any, ...], dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
args, kwargs = input
|
||||
bound = sig.bind_partial(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
arguments = dict(bound.arguments)
|
||||
arguments.pop("self", None)
|
||||
arguments.pop("cls", None)
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
# Update with the **kwargs, and remove the original entry
|
||||
# This is to help flatten out keyword arguments
|
||||
if param_name in arguments:
|
||||
arguments.update(arguments.pop(param_name))
|
||||
return arguments
|
||||
|
||||
|
||||
def get_runnable_for_entrypoint(func: Callable[..., Any]) -> Runnable:
|
||||
key = (func, False)
|
||||
if key in CACHE:
|
||||
return CACHE[key]
|
||||
else:
|
||||
if is_async_callable(func):
|
||||
run = RunnableCallable(
|
||||
None, func, name=func.__name__, trace=False, recurse=False
|
||||
)
|
||||
else:
|
||||
afunc = functools.update_wrapper(
|
||||
functools.partial(run_in_executor, None, func), func
|
||||
)
|
||||
run = RunnableCallable(
|
||||
func,
|
||||
afunc,
|
||||
name=func.__name__,
|
||||
trace=False,
|
||||
recurse=False,
|
||||
)
|
||||
if not _lookup_module_and_qualname(func):
|
||||
return run
|
||||
return CACHE.setdefault(key, run)
|
||||
|
||||
|
||||
def get_runnable_for_task(func: Callable[..., Any]) -> Runnable:
|
||||
key = (func, True)
|
||||
if key in CACHE:
|
||||
return CACHE[key]
|
||||
else:
|
||||
if hasattr(func, "__name__"):
|
||||
name = func.__name__
|
||||
elif hasattr(func, "func"):
|
||||
name = func.func.__name__
|
||||
elif hasattr(func, "__class__"):
|
||||
name = func.__class__.__name__
|
||||
else:
|
||||
name = str(func)
|
||||
|
||||
if is_async_callable(func):
|
||||
run = RunnableCallable(
|
||||
None,
|
||||
func,
|
||||
explode_args=True,
|
||||
name=name,
|
||||
trace=False,
|
||||
recurse=False,
|
||||
)
|
||||
else:
|
||||
run = RunnableCallable(
|
||||
func,
|
||||
functools.wraps(func)(functools.partial(run_in_executor, None, func)),
|
||||
explode_args=True,
|
||||
name=name,
|
||||
trace=False,
|
||||
recurse=False,
|
||||
)
|
||||
seq = RunnableSeq(
|
||||
run,
|
||||
ChannelWrite([ChannelWriteEntry(RETURN)]),
|
||||
name=name,
|
||||
trace_inputs=functools.partial(
|
||||
_explode_args_trace_inputs, inspect.signature(func)
|
||||
),
|
||||
)
|
||||
if not _lookup_module_and_qualname(func):
|
||||
return seq
|
||||
return CACHE.setdefault(key, seq)
|
||||
|
||||
|
||||
CACHE: dict[tuple[Callable[..., Any], bool], Runnable] = {}
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
P1 = TypeVar("P1")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SyncAsyncFuture(Generic[T], concurrent.futures.Future[T]):
|
||||
def __await__(self) -> Generator[T, None, T]:
|
||||
yield cast(T, ...)
|
||||
|
||||
|
||||
def call(
|
||||
func: Callable[P, Awaitable[T]] | Callable[P, T],
|
||||
*args: Any,
|
||||
retry_policy: Sequence[RetryPolicy] | None = None,
|
||||
cache_policy: CachePolicy | None = None,
|
||||
**kwargs: Any,
|
||||
) -> SyncAsyncFuture[T]:
|
||||
config = get_config()
|
||||
impl = config[CONF][CONFIG_KEY_CALL]
|
||||
fut = impl(
|
||||
func,
|
||||
(args, kwargs),
|
||||
retry_policy=retry_policy,
|
||||
cache_policy=cache_policy,
|
||||
callbacks=config["callbacks"],
|
||||
)
|
||||
return fut
|
||||
88
venv/Lib/site-packages/langgraph/pregel/_checkpoint.py
Normal file
88
venv/Lib/site-packages/langgraph/pregel/_checkpoint.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from langgraph.checkpoint.base import Checkpoint
|
||||
from langgraph.checkpoint.base.id import uuid6
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel
|
||||
from langgraph.managed.base import ManagedValueMapping, ManagedValueSpec
|
||||
|
||||
LATEST_VERSION = 4
|
||||
|
||||
|
||||
def empty_checkpoint() -> Checkpoint:
|
||||
return Checkpoint(
|
||||
v=LATEST_VERSION,
|
||||
id=str(uuid6(clock_seq=-2)),
|
||||
ts=datetime.now(timezone.utc).isoformat(),
|
||||
channel_values={},
|
||||
channel_versions={},
|
||||
versions_seen={},
|
||||
)
|
||||
|
||||
|
||||
def create_checkpoint(
|
||||
checkpoint: Checkpoint,
|
||||
channels: Mapping[str, BaseChannel] | None,
|
||||
step: int,
|
||||
*,
|
||||
id: str | None = None,
|
||||
updated_channels: set[str] | None = None,
|
||||
) -> Checkpoint:
|
||||
"""Create a checkpoint for the given channels."""
|
||||
ts = datetime.now(timezone.utc).isoformat()
|
||||
if channels is None:
|
||||
values = checkpoint["channel_values"]
|
||||
else:
|
||||
values = {}
|
||||
for k in channels:
|
||||
if k not in checkpoint["channel_versions"]:
|
||||
continue
|
||||
v = channels[k].checkpoint()
|
||||
if v is not MISSING:
|
||||
values[k] = v
|
||||
return Checkpoint(
|
||||
v=LATEST_VERSION,
|
||||
ts=ts,
|
||||
id=id or str(uuid6(clock_seq=step)),
|
||||
channel_values=values,
|
||||
channel_versions=checkpoint["channel_versions"],
|
||||
versions_seen=checkpoint["versions_seen"],
|
||||
updated_channels=None if updated_channels is None else sorted(updated_channels),
|
||||
)
|
||||
|
||||
|
||||
def channels_from_checkpoint(
|
||||
specs: Mapping[str, BaseChannel | ManagedValueSpec],
|
||||
checkpoint: Checkpoint,
|
||||
) -> tuple[Mapping[str, BaseChannel], ManagedValueMapping]:
|
||||
"""Get channels from a checkpoint."""
|
||||
channel_specs: dict[str, BaseChannel] = {}
|
||||
managed_specs: dict[str, ManagedValueSpec] = {}
|
||||
for k, v in specs.items():
|
||||
if isinstance(v, BaseChannel):
|
||||
channel_specs[k] = v
|
||||
else:
|
||||
managed_specs[k] = v
|
||||
return (
|
||||
{
|
||||
k: v.from_checkpoint(checkpoint["channel_values"].get(k, MISSING))
|
||||
for k, v in channel_specs.items()
|
||||
},
|
||||
managed_specs,
|
||||
)
|
||||
|
||||
|
||||
def copy_checkpoint(checkpoint: Checkpoint) -> Checkpoint:
|
||||
return Checkpoint(
|
||||
v=checkpoint["v"],
|
||||
ts=checkpoint["ts"],
|
||||
id=checkpoint["id"],
|
||||
channel_values=checkpoint["channel_values"].copy(),
|
||||
channel_versions=checkpoint["channel_versions"].copy(),
|
||||
versions_seen={k: v.copy() for k, v in checkpoint["versions_seen"].items()},
|
||||
updated_channels=checkpoint.get("updated_channels", None),
|
||||
)
|
||||
0
venv/Lib/site-packages/langgraph/pregel/_config.py
Normal file
0
venv/Lib/site-packages/langgraph/pregel/_config.py
Normal file
294
venv/Lib/site-packages/langgraph/pregel/_draw.py
Normal file
294
venv/Lib/site-packages/langgraph/pregel/_draw.py
Normal file
@@ -0,0 +1,294 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, NamedTuple, cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.runnables.graph import Graph, Node
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
|
||||
from langgraph._internal._constants import CONF, CONFIG_KEY_SEND, INPUT
|
||||
from langgraph.channels.base import BaseChannel
|
||||
from langgraph.channels.last_value import LastValueAfterFinish
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.managed.base import ManagedValueSpec
|
||||
from langgraph.pregel._algo import (
|
||||
PregelTaskWrites,
|
||||
apply_writes,
|
||||
increment,
|
||||
prepare_next_tasks,
|
||||
)
|
||||
from langgraph.pregel._checkpoint import channels_from_checkpoint, empty_checkpoint
|
||||
from langgraph.pregel._io import map_input
|
||||
from langgraph.pregel._read import PregelNode
|
||||
from langgraph.pregel._write import ChannelWrite
|
||||
from langgraph.types import All, Checkpointer
|
||||
|
||||
|
||||
class Edge(NamedTuple):
|
||||
source: str
|
||||
target: str
|
||||
conditional: bool
|
||||
data: str | None
|
||||
|
||||
|
||||
class TriggerEdge(NamedTuple):
|
||||
source: str
|
||||
conditional: bool
|
||||
data: str | None
|
||||
|
||||
|
||||
def draw_graph(
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
nodes: dict[str, PregelNode],
|
||||
specs: dict[str, BaseChannel | ManagedValueSpec],
|
||||
input_channels: str | Sequence[str],
|
||||
interrupt_after_nodes: All | Sequence[str],
|
||||
interrupt_before_nodes: All | Sequence[str],
|
||||
trigger_to_nodes: Mapping[str, Sequence[str]],
|
||||
checkpointer: Checkpointer,
|
||||
subgraphs: dict[str, Graph],
|
||||
limit: int = 250,
|
||||
) -> Graph:
|
||||
"""Get the graph for this Pregel instance.
|
||||
|
||||
Args:
|
||||
config: The configuration to use for the graph.
|
||||
subgraphs: The subgraphs to include in the graph.
|
||||
checkpointer: The checkpointer to use for the graph.
|
||||
|
||||
Returns:
|
||||
The graph for this Pregel instance.
|
||||
"""
|
||||
# (src, dest, is_conditional, label)
|
||||
edges: set[Edge] = set()
|
||||
|
||||
step = -1
|
||||
checkpoint = empty_checkpoint()
|
||||
get_next_version = (
|
||||
checkpointer.get_next_version
|
||||
if isinstance(checkpointer, BaseCheckpointSaver)
|
||||
else increment
|
||||
)
|
||||
channels, managed = channels_from_checkpoint(
|
||||
specs,
|
||||
checkpoint,
|
||||
)
|
||||
static_seen: set[Any] = set()
|
||||
sources: dict[str, set[TriggerEdge]] = {}
|
||||
step_sources: dict[str, set[TriggerEdge]] = {}
|
||||
static_declared_writes: dict[str, set[TriggerEdge]] = defaultdict(set)
|
||||
# remove node mappers
|
||||
nodes = {
|
||||
k: v.copy(update={"mapper": None}) if v.mapper is not None else v
|
||||
for k, v in nodes.items()
|
||||
}
|
||||
# apply input writes
|
||||
input_writes = list(map_input(input_channels, {}))
|
||||
updated_channels = apply_writes(
|
||||
checkpoint,
|
||||
channels,
|
||||
[
|
||||
PregelTaskWrites((), INPUT, input_writes, []),
|
||||
],
|
||||
get_next_version,
|
||||
trigger_to_nodes,
|
||||
)
|
||||
# prepare first tasks
|
||||
tasks = prepare_next_tasks(
|
||||
checkpoint,
|
||||
[],
|
||||
nodes,
|
||||
channels,
|
||||
managed,
|
||||
config,
|
||||
step,
|
||||
-1,
|
||||
for_execution=True,
|
||||
store=None,
|
||||
checkpointer=None,
|
||||
manager=None,
|
||||
trigger_to_nodes=trigger_to_nodes,
|
||||
updated_channels=updated_channels,
|
||||
)
|
||||
start_tasks = tasks
|
||||
# run the pregel loop
|
||||
for step in range(step, limit):
|
||||
if not tasks:
|
||||
break
|
||||
conditionals: dict[tuple[str, str, Any], str | None] = {}
|
||||
# run task writers
|
||||
for task in tasks.values():
|
||||
for w in task.writers:
|
||||
# apply regular writes
|
||||
if isinstance(w, ChannelWrite):
|
||||
empty_input = (
|
||||
cast(BaseChannel, specs["__root__"]).ValueType()
|
||||
if "__root__" in specs
|
||||
else None
|
||||
)
|
||||
w.invoke(empty_input, task.config)
|
||||
# apply conditional writes declared for static analysis, only once
|
||||
if w not in static_seen:
|
||||
static_seen.add(w)
|
||||
# apply static writes
|
||||
if writes := ChannelWrite.get_static_writes(w):
|
||||
# END writes are not written, but become edges directly
|
||||
for t in writes:
|
||||
if t[0] == END:
|
||||
edges.add(Edge(task.name, t[0], True, t[2]))
|
||||
writes = [t for t in writes if t[0] != END]
|
||||
conditionals.update(
|
||||
{(task.name, t[0], t[1] or None): t[2] for t in writes}
|
||||
)
|
||||
# record static writes for edge creation
|
||||
for t in writes:
|
||||
static_declared_writes[task.name].add(
|
||||
TriggerEdge(t[0], True, t[2])
|
||||
)
|
||||
task.config[CONF][CONFIG_KEY_SEND]([t[:2] for t in writes])
|
||||
# collect sources
|
||||
step_sources = {}
|
||||
for task in tasks.values():
|
||||
task_edges = {
|
||||
TriggerEdge(
|
||||
w[0],
|
||||
(task.name, w[0], w[1] or None) in conditionals,
|
||||
conditionals.get((task.name, w[0], w[1] or None)),
|
||||
)
|
||||
for w in task.writes
|
||||
}
|
||||
task_edges |= static_declared_writes.get(task.name, set())
|
||||
step_sources[task.name] = task_edges
|
||||
sources.update(step_sources)
|
||||
# invert triggers
|
||||
trigger_to_sources: dict[str, set[TriggerEdge]] = defaultdict(set)
|
||||
for src, triggers in sources.items():
|
||||
for trigger, cond, label in triggers:
|
||||
trigger_to_sources[trigger].add(TriggerEdge(src, cond, label))
|
||||
# apply writes
|
||||
updated_channels = apply_writes(
|
||||
checkpoint, channels, tasks.values(), get_next_version, trigger_to_nodes
|
||||
)
|
||||
# prepare next tasks
|
||||
tasks = prepare_next_tasks(
|
||||
checkpoint,
|
||||
[],
|
||||
nodes,
|
||||
channels,
|
||||
managed,
|
||||
config,
|
||||
step,
|
||||
limit,
|
||||
for_execution=True,
|
||||
store=None,
|
||||
checkpointer=None,
|
||||
manager=None,
|
||||
trigger_to_nodes=trigger_to_nodes,
|
||||
updated_channels=updated_channels,
|
||||
)
|
||||
# collect deferred nodes
|
||||
deferred_nodes: set[str] = set()
|
||||
edges_to_deferred_nodes: set[Edge] = set()
|
||||
for channel, item in channels.items():
|
||||
if isinstance(item, LastValueAfterFinish):
|
||||
deferred_node = channel.split(":", 2)[-1]
|
||||
deferred_nodes.add(deferred_node)
|
||||
# collect edges
|
||||
for task in tasks.values():
|
||||
added = False
|
||||
for trigger in task.triggers:
|
||||
for src, cond, label in sorted(trigger_to_sources[trigger]):
|
||||
# record edge to be reviewed later
|
||||
if task.name in deferred_nodes:
|
||||
edges_to_deferred_nodes.add(Edge(src, task.name, cond, label))
|
||||
edges.add(Edge(src, task.name, cond, label))
|
||||
# if the edge is from this step, skip adding the implicit edges
|
||||
if (trigger, cond, label) in step_sources.get(src, set()):
|
||||
added = True
|
||||
else:
|
||||
sources[src].discard(TriggerEdge(trigger, cond, label))
|
||||
# if no edges from this step, add implicit edges from all previous tasks
|
||||
if not added:
|
||||
for src in step_sources:
|
||||
edges.add(Edge(src, task.name, True, None))
|
||||
|
||||
# assemble the graph
|
||||
graph = Graph()
|
||||
# add nodes
|
||||
for name, node in nodes.items():
|
||||
metadata = dict(node.metadata or {})
|
||||
if name in deferred_nodes:
|
||||
metadata["defer"] = True
|
||||
if name in interrupt_before_nodes and name in interrupt_after_nodes:
|
||||
metadata["__interrupt"] = "before,after"
|
||||
elif name in interrupt_before_nodes:
|
||||
metadata["__interrupt"] = "before"
|
||||
elif name in interrupt_after_nodes:
|
||||
metadata["__interrupt"] = "after"
|
||||
graph.add_node(node.bound, name, metadata=metadata or None)
|
||||
# add start node
|
||||
if START not in nodes:
|
||||
graph.add_node(None, START)
|
||||
for task in start_tasks.values():
|
||||
add_edge(graph, START, task.name)
|
||||
# add discovered edges
|
||||
for src, dest, is_conditional, label in sorted(edges):
|
||||
add_edge(
|
||||
graph,
|
||||
src,
|
||||
dest,
|
||||
data=label if label != dest else None,
|
||||
conditional=is_conditional,
|
||||
)
|
||||
# add end edges
|
||||
termini = {d for _, d, _, _ in edges if d != END}.difference(
|
||||
s for s, _, _, _ in edges
|
||||
)
|
||||
end_edge_exists = any(d == END for _, d, _, _ in edges)
|
||||
if termini:
|
||||
for src in sorted(termini):
|
||||
add_edge(graph, src, END)
|
||||
elif len(step_sources) == 1 and not end_edge_exists:
|
||||
for src in sorted(step_sources):
|
||||
add_edge(graph, src, END, conditional=True)
|
||||
# replace subgraphs
|
||||
for name, subgraph in subgraphs.items():
|
||||
if (
|
||||
len(subgraph.nodes) > 1
|
||||
and name in graph.nodes
|
||||
and subgraph.first_node()
|
||||
and subgraph.last_node()
|
||||
):
|
||||
subgraph.trim_first_node()
|
||||
subgraph.trim_last_node()
|
||||
# replace the node with the subgraph
|
||||
graph.nodes.pop(name)
|
||||
first, last = graph.extend(subgraph, prefix=name)
|
||||
for idx, edge in enumerate(graph.edges):
|
||||
if edge.source == name:
|
||||
edge = edge.copy(source=cast(Node, last).id)
|
||||
if edge.target == name:
|
||||
edge = edge.copy(target=cast(Node, first).id)
|
||||
graph.edges[idx] = edge
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def add_edge(
|
||||
graph: Graph,
|
||||
source: str,
|
||||
target: str,
|
||||
*,
|
||||
data: Any | None = None,
|
||||
conditional: bool = False,
|
||||
) -> None:
|
||||
"""Add an edge to the graph."""
|
||||
for edge in graph.edges:
|
||||
if edge.source == source and edge.target == target:
|
||||
return
|
||||
if target not in graph.nodes and target == END:
|
||||
graph.add_node(None, END)
|
||||
graph.add_edge(graph.nodes[source], graph.nodes[target], data, conditional)
|
||||
223
venv/Lib/site-packages/langgraph/pregel/_executor.py
Normal file
223
venv/Lib/site-packages/langgraph/pregel/_executor.py
Normal file
@@ -0,0 +1,223 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable, Coroutine
|
||||
from contextlib import AbstractAsyncContextManager, AbstractContextManager, ExitStack
|
||||
from contextvars import copy_context
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Protocol,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.runnables.config import get_executor_for_config
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from langgraph._internal._future import CONTEXT_NOT_SUPPORTED, run_coroutine_threadsafe
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Submit(Protocol[P, T]):
|
||||
def __call__( # type: ignore[valid-type]
|
||||
self,
|
||||
fn: Callable[P, T],
|
||||
*args: P.args,
|
||||
__name__: str | None = None,
|
||||
__cancel_on_exit__: bool = False,
|
||||
__reraise_on_exit__: bool = True,
|
||||
__next_tick__: bool = False,
|
||||
**kwargs: P.kwargs,
|
||||
) -> concurrent.futures.Future[T]: ...
|
||||
|
||||
|
||||
class BackgroundExecutor(AbstractContextManager):
|
||||
"""A context manager that runs sync tasks in the background.
|
||||
Uses a thread pool executor to delegate tasks to separate threads.
|
||||
On exit,
|
||||
- cancels any (not yet started) tasks with `__cancel_on_exit__=True`
|
||||
- waits for all tasks to finish
|
||||
- re-raises the first exception from tasks with `__reraise_on_exit__=True`"""
|
||||
|
||||
def __init__(self, config: RunnableConfig) -> None:
|
||||
self.stack = ExitStack()
|
||||
self.executor = self.stack.enter_context(get_executor_for_config(config))
|
||||
# mapping of Future to (__cancel_on_exit__, __reraise_on_exit__) flags
|
||||
self.tasks: dict[concurrent.futures.Future, tuple[bool, bool]] = {}
|
||||
|
||||
def submit( # type: ignore[valid-type]
|
||||
self,
|
||||
fn: Callable[P, T],
|
||||
*args: P.args,
|
||||
__name__: str | None = None, # currently not used in sync version
|
||||
__cancel_on_exit__: bool = False, # for sync, can cancel only if not started
|
||||
__reraise_on_exit__: bool = True,
|
||||
__next_tick__: bool = False,
|
||||
**kwargs: P.kwargs,
|
||||
) -> concurrent.futures.Future[T]:
|
||||
ctx = copy_context()
|
||||
if __next_tick__:
|
||||
task = cast(
|
||||
concurrent.futures.Future[T],
|
||||
self.executor.submit(next_tick, ctx.run, fn, *args, **kwargs), # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
task = self.executor.submit(ctx.run, fn, *args, **kwargs)
|
||||
self.tasks[task] = (__cancel_on_exit__, __reraise_on_exit__)
|
||||
# add a callback to remove the task from the tasks dict when it's done
|
||||
task.add_done_callback(self.done)
|
||||
return task
|
||||
|
||||
def done(self, task: concurrent.futures.Future) -> None:
|
||||
"""Remove the task from the tasks dict when it's done."""
|
||||
try:
|
||||
task.result()
|
||||
except GraphBubbleUp:
|
||||
# This exception is an interruption signal, not an error
|
||||
# so we don't want to re-raise it on exit
|
||||
self.tasks.pop(task)
|
||||
except BaseException:
|
||||
pass
|
||||
else:
|
||||
self.tasks.pop(task)
|
||||
|
||||
def __enter__(self) -> Submit:
|
||||
return self.submit
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> bool | None:
|
||||
# copy the tasks as done() callback may modify the dict
|
||||
tasks = self.tasks.copy()
|
||||
# cancel all tasks that should be cancelled
|
||||
for task, (cancel, _) in tasks.items():
|
||||
if cancel:
|
||||
task.cancel()
|
||||
# wait for all tasks to finish
|
||||
if pending := {t for t in tasks if not t.done()}:
|
||||
concurrent.futures.wait(pending)
|
||||
# shutdown the executor
|
||||
self.stack.__exit__(exc_type, exc_value, traceback)
|
||||
# if there's already an exception being raised, don't raise another one
|
||||
if exc_type is None:
|
||||
# re-raise the first exception that occurred in a task
|
||||
for task, (_, reraise) in tasks.items():
|
||||
if not reraise:
|
||||
continue
|
||||
try:
|
||||
task.result()
|
||||
except concurrent.futures.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
class AsyncBackgroundExecutor(AbstractAsyncContextManager):
|
||||
"""A context manager that runs async tasks in the background.
|
||||
Uses the current event loop to delegate tasks to asyncio tasks.
|
||||
On exit,
|
||||
- cancels any tasks with `__cancel_on_exit__=True`
|
||||
- waits for all tasks to finish
|
||||
- re-raises the first exception from tasks with `__reraise_on_exit__=True`
|
||||
ignoring CancelledError"""
|
||||
|
||||
def __init__(self, config: RunnableConfig) -> None:
|
||||
self.tasks: dict[asyncio.Future, tuple[bool, bool]] = {}
|
||||
self.sentinel = object()
|
||||
self.loop = asyncio.get_running_loop()
|
||||
if max_concurrency := config.get("max_concurrency"):
|
||||
self.semaphore: asyncio.Semaphore | None = asyncio.Semaphore(
|
||||
max_concurrency
|
||||
)
|
||||
else:
|
||||
self.semaphore = None
|
||||
|
||||
def submit( # type: ignore[valid-type]
|
||||
self,
|
||||
fn: Callable[P, Awaitable[T]],
|
||||
*args: P.args,
|
||||
__name__: str | None = None,
|
||||
__cancel_on_exit__: bool = False,
|
||||
__reraise_on_exit__: bool = True,
|
||||
__next_tick__: bool = False, # noop in async (always True)
|
||||
**kwargs: P.kwargs,
|
||||
) -> asyncio.Future[T]:
|
||||
coro = cast(Coroutine[None, None, T], fn(*args, **kwargs))
|
||||
if self.semaphore:
|
||||
coro = gated(self.semaphore, coro)
|
||||
if CONTEXT_NOT_SUPPORTED:
|
||||
task = run_coroutine_threadsafe(
|
||||
coro, self.loop, name=__name__, lazy=__next_tick__
|
||||
)
|
||||
else:
|
||||
task = run_coroutine_threadsafe(
|
||||
coro,
|
||||
self.loop,
|
||||
name=__name__,
|
||||
context=copy_context(),
|
||||
lazy=__next_tick__,
|
||||
)
|
||||
self.tasks[task] = (__cancel_on_exit__, __reraise_on_exit__)
|
||||
task.add_done_callback(self.done)
|
||||
return task
|
||||
|
||||
def done(self, task: asyncio.Future) -> None:
|
||||
try:
|
||||
if exc := task.exception():
|
||||
# This exception is an interruption signal, not an error
|
||||
# so we don't want to re-raise it on exit
|
||||
if isinstance(exc, GraphBubbleUp):
|
||||
self.tasks.pop(task)
|
||||
else:
|
||||
self.tasks.pop(task)
|
||||
except asyncio.CancelledError:
|
||||
self.tasks.pop(task)
|
||||
|
||||
async def __aenter__(self) -> Submit:
|
||||
return self.submit
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
# copy the tasks as done() callback may modify the dict
|
||||
tasks = self.tasks.copy()
|
||||
# cancel all tasks that should be cancelled
|
||||
for task, (cancel, _) in tasks.items():
|
||||
if cancel:
|
||||
task.cancel(self.sentinel)
|
||||
# wait for all tasks to finish
|
||||
if tasks:
|
||||
await asyncio.wait(tasks)
|
||||
# if there's already an exception being raised, don't raise another one
|
||||
if exc_type is None:
|
||||
# re-raise the first exception that occurred in a task
|
||||
for task, (_, reraise) in tasks.items():
|
||||
if not reraise:
|
||||
continue
|
||||
try:
|
||||
if exc := task.exception():
|
||||
raise exc
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
async def gated(semaphore: asyncio.Semaphore, coro: Coroutine[None, None, T]) -> T:
|
||||
"""A coroutine that waits for a semaphore before running another coroutine."""
|
||||
async with semaphore:
|
||||
return await coro
|
||||
|
||||
|
||||
def next_tick(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
"""A function that yields control to other threads before running another function."""
|
||||
time.sleep(0)
|
||||
return fn(*args, **kwargs)
|
||||
174
venv/Lib/site-packages/langgraph/pregel/_io.py
Normal file
174
venv/Lib/site-packages/langgraph/pregel/_io.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter
|
||||
from collections.abc import Iterator, Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from langgraph._internal._constants import (
|
||||
ERROR,
|
||||
INTERRUPT,
|
||||
NULL_TASK_ID,
|
||||
RESUME,
|
||||
RETURN,
|
||||
TASKS,
|
||||
)
|
||||
from langgraph._internal._typing import EMPTY_SEQ, MISSING
|
||||
from langgraph.channels.base import BaseChannel, EmptyChannelError
|
||||
from langgraph.constants import START, TAG_HIDDEN
|
||||
from langgraph.errors import InvalidUpdateError
|
||||
from langgraph.pregel._log import logger
|
||||
from langgraph.types import Command, PregelExecutableTask, Send
|
||||
|
||||
|
||||
def read_channel(
|
||||
channels: Mapping[str, BaseChannel],
|
||||
chan: str,
|
||||
*,
|
||||
catch: bool = True,
|
||||
) -> Any:
|
||||
try:
|
||||
return channels[chan].get()
|
||||
except EmptyChannelError:
|
||||
if catch:
|
||||
return None
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def read_channels(
|
||||
channels: Mapping[str, BaseChannel],
|
||||
select: Sequence[str] | str,
|
||||
*,
|
||||
skip_empty: bool = True,
|
||||
) -> dict[str, Any] | Any:
|
||||
if isinstance(select, str):
|
||||
return read_channel(channels, select)
|
||||
else:
|
||||
values: dict[str, Any] = {}
|
||||
for k in select:
|
||||
try:
|
||||
values[k] = read_channel(channels, k, catch=not skip_empty)
|
||||
except EmptyChannelError:
|
||||
pass
|
||||
return values
|
||||
|
||||
|
||||
def map_command(cmd: Command) -> Iterator[tuple[str, str, Any]]:
|
||||
"""Map input chunk to a sequence of pending writes in the form (channel, value)."""
|
||||
if cmd.graph == Command.PARENT:
|
||||
raise InvalidUpdateError("There is no parent graph")
|
||||
if cmd.goto:
|
||||
if isinstance(cmd.goto, (tuple, list)):
|
||||
sends = cmd.goto
|
||||
else:
|
||||
sends = [cmd.goto]
|
||||
for send in sends:
|
||||
if isinstance(send, Send):
|
||||
yield (NULL_TASK_ID, TASKS, send)
|
||||
elif isinstance(send, str):
|
||||
yield (NULL_TASK_ID, f"branch:to:{send}", START)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"In Command.goto, expected Send/str, got {type(send).__name__}"
|
||||
)
|
||||
if cmd.resume is not None:
|
||||
yield (NULL_TASK_ID, RESUME, cmd.resume)
|
||||
if cmd.update:
|
||||
for k, v in cmd._update_as_tuples():
|
||||
yield (NULL_TASK_ID, k, v)
|
||||
|
||||
|
||||
def map_input(
|
||||
input_channels: str | Sequence[str],
|
||||
chunk: dict[str, Any] | Any | None,
|
||||
) -> Iterator[tuple[str, Any]]:
|
||||
"""Map input chunk to a sequence of pending writes in the form (channel, value)."""
|
||||
if chunk is None:
|
||||
return
|
||||
elif isinstance(input_channels, str):
|
||||
yield (input_channels, chunk)
|
||||
else:
|
||||
if not isinstance(chunk, dict):
|
||||
raise TypeError(f"Expected chunk to be a dict, got {type(chunk).__name__}")
|
||||
for k in chunk:
|
||||
if k in input_channels:
|
||||
yield (k, chunk[k])
|
||||
else:
|
||||
logger.warning(f"Input channel {k} not found in {input_channels}")
|
||||
|
||||
|
||||
def map_output_values(
|
||||
output_channels: str | Sequence[str],
|
||||
pending_writes: Literal[True] | Sequence[tuple[str, Any]],
|
||||
channels: Mapping[str, BaseChannel],
|
||||
) -> Iterator[dict[str, Any] | Any]:
|
||||
"""Map pending writes (a sequence of tuples (channel, value)) to output chunk."""
|
||||
if isinstance(output_channels, str):
|
||||
if pending_writes is True or any(
|
||||
chan == output_channels for chan, _ in pending_writes
|
||||
):
|
||||
yield read_channel(channels, output_channels)
|
||||
else:
|
||||
if pending_writes is True or {
|
||||
c for c, _ in pending_writes if c in output_channels
|
||||
}:
|
||||
yield read_channels(channels, output_channels)
|
||||
|
||||
|
||||
def map_output_updates(
|
||||
output_channels: str | Sequence[str],
|
||||
tasks: list[tuple[PregelExecutableTask, Sequence[tuple[str, Any]]]],
|
||||
cached: bool = False,
|
||||
) -> Iterator[dict[str, Any | dict[str, Any]]]:
|
||||
"""Map pending writes (a sequence of tuples (channel, value)) to output chunk."""
|
||||
output_tasks = [
|
||||
(t, ww)
|
||||
for t, ww in tasks
|
||||
if (not t.config or TAG_HIDDEN not in t.config.get("tags", EMPTY_SEQ))
|
||||
and ww[0][0] != ERROR
|
||||
and ww[0][0] != INTERRUPT
|
||||
]
|
||||
if not output_tasks:
|
||||
return
|
||||
updated: list[tuple[str, Any]] = []
|
||||
for task, writes in output_tasks:
|
||||
rtn = next((value for chan, value in writes if chan == RETURN), MISSING)
|
||||
if rtn is not MISSING:
|
||||
updated.append((task.name, rtn))
|
||||
elif isinstance(output_channels, str):
|
||||
updated.extend(
|
||||
(task.name, value) for chan, value in writes if chan == output_channels
|
||||
)
|
||||
elif any(chan in output_channels for chan, _ in writes):
|
||||
counts = Counter(chan for chan, _ in writes)
|
||||
if any(counts[chan] > 1 for chan in output_channels):
|
||||
updated.extend(
|
||||
(
|
||||
task.name,
|
||||
{chan: value},
|
||||
)
|
||||
for chan, value in writes
|
||||
if chan in output_channels
|
||||
)
|
||||
else:
|
||||
updated.append(
|
||||
(
|
||||
task.name,
|
||||
{
|
||||
chan: value
|
||||
for chan, value in writes
|
||||
if chan in output_channels
|
||||
},
|
||||
)
|
||||
)
|
||||
grouped: dict[str, Any] = {t.name: [] for t, _ in output_tasks}
|
||||
for node, value in updated:
|
||||
grouped[node].append(value)
|
||||
for node, value in grouped.items():
|
||||
if len(value) == 0:
|
||||
grouped[node] = None
|
||||
if len(value) == 1:
|
||||
grouped[node] = value[0]
|
||||
if cached:
|
||||
grouped["__metadata__"] = {"cached": cached}
|
||||
yield grouped
|
||||
3
venv/Lib/site-packages/langgraph/pregel/_log.py
Normal file
3
venv/Lib/site-packages/langgraph/pregel/_log.py
Normal file
@@ -0,0 +1,3 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("langgraph")
|
||||
1328
venv/Lib/site-packages/langgraph/pregel/_loop.py
Normal file
1328
venv/Lib/site-packages/langgraph/pregel/_loop.py
Normal file
File diff suppressed because it is too large
Load Diff
250
venv/Lib/site-packages/langgraph/pregel/_messages.py
Normal file
250
venv/Lib/site-packages/langgraph/pregel/_messages.py
Normal file
@@ -0,0 +1,250 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
|
||||
from dataclasses import fields, is_dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langgraph._internal._constants import NS_SEP
|
||||
from langgraph.constants import TAG_HIDDEN, TAG_NOSTREAM
|
||||
from langgraph.pregel.protocol import StreamChunk
|
||||
from langgraph.types import Command
|
||||
|
||||
try:
|
||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||
except ImportError:
|
||||
_StreamingCallbackHandler = object # type: ignore
|
||||
|
||||
T = TypeVar("T")
|
||||
Meta = tuple[tuple[str, ...], dict[str, Any]]
|
||||
|
||||
|
||||
def _state_values(obj: Any) -> Sequence[Any]:
|
||||
"""Extract top-level field values from a state object (dict, BaseModel, or dataclass)."""
|
||||
if isinstance(obj, dict):
|
||||
return list(obj.values())
|
||||
elif isinstance(obj, BaseModel):
|
||||
return [getattr(obj, k) for k in type(obj).model_fields]
|
||||
elif is_dataclass(obj) and not isinstance(obj, type):
|
||||
return [getattr(obj, f.name) for f in fields(obj)]
|
||||
return ()
|
||||
|
||||
|
||||
class StreamMessagesHandler(BaseCallbackHandler, _StreamingCallbackHandler):
|
||||
"""A callback handler that implements stream_mode=messages.
|
||||
|
||||
Collects messages from:
|
||||
(1) chat model stream events; and
|
||||
(2) node outputs.
|
||||
"""
|
||||
|
||||
run_inline = True
|
||||
"""We want this callback to run in the main thread to avoid order/locking issues."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream: Callable[[StreamChunk], None],
|
||||
subgraphs: bool,
|
||||
*,
|
||||
parent_ns: tuple[str, ...] | None = None,
|
||||
) -> None:
|
||||
"""Configure the handler to stream messages from LLMs and nodes.
|
||||
|
||||
Args:
|
||||
stream: A callable that takes a StreamChunk and emits it.
|
||||
subgraphs: Whether to emit messages from subgraphs.
|
||||
parent_ns: The namespace where the handler was created.
|
||||
We keep track of this namespace to allow calls to subgraphs that
|
||||
were explicitly requested as a stream with `messages` mode
|
||||
configured.
|
||||
|
||||
Example:
|
||||
parent_ns is used to handle scenarios where the subgraph is explicitly
|
||||
streamed with `stream_mode="messages"`.
|
||||
|
||||
```python
|
||||
def parent_graph_node():
|
||||
# This node is in the parent graph.
|
||||
async for event in some_subgraph(..., stream_mode="messages"):
|
||||
do something with event # <-- these events will be emitted
|
||||
return ...
|
||||
|
||||
parent_graph.invoke(subgraphs=False)
|
||||
```
|
||||
"""
|
||||
self.stream = stream
|
||||
self.subgraphs = subgraphs
|
||||
self.metadata: dict[UUID, Meta] = {}
|
||||
self.seen: set[int | str] = set()
|
||||
self.parent_ns = parent_ns
|
||||
|
||||
def _emit(self, meta: Meta, message: BaseMessage, *, dedupe: bool = False) -> None:
|
||||
if dedupe and message.id in self.seen:
|
||||
return
|
||||
else:
|
||||
if message.id is None:
|
||||
message.id = str(uuid4())
|
||||
self.seen.add(message.id)
|
||||
self.stream((meta[0], "messages", (message, meta[1])))
|
||||
|
||||
def _find_and_emit_messages(self, meta: Meta, response: Any) -> None:
|
||||
if isinstance(response, BaseMessage):
|
||||
self._emit(meta, response, dedupe=True)
|
||||
elif isinstance(response, Sequence):
|
||||
for value in response:
|
||||
if isinstance(value, BaseMessage):
|
||||
self._emit(meta, value, dedupe=True)
|
||||
else:
|
||||
for value in _state_values(response):
|
||||
if isinstance(value, BaseMessage):
|
||||
self._emit(meta, value, dedupe=True)
|
||||
elif isinstance(value, Sequence):
|
||||
for item in value:
|
||||
if isinstance(item, BaseMessage):
|
||||
self._emit(meta, item, dedupe=True)
|
||||
|
||||
def tap_output_aiter(
|
||||
self, run_id: UUID, output: AsyncIterator[T]
|
||||
) -> AsyncIterator[T]:
|
||||
return output
|
||||
|
||||
def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
|
||||
return output
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if metadata and (not tags or (TAG_NOSTREAM not in tags)):
|
||||
ns = tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP))[
|
||||
:-1
|
||||
]
|
||||
if not self.subgraphs and len(ns) > 0 and ns != self.parent_ns:
|
||||
return
|
||||
if tags:
|
||||
if filtered_tags := [t for t in tags if not t.startswith("seq:step")]:
|
||||
metadata["tags"] = filtered_tags
|
||||
self.metadata[run_id] = (ns, metadata)
|
||||
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: ChatGenerationChunk | None = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
tags: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if not isinstance(chunk, ChatGenerationChunk):
|
||||
return
|
||||
if meta := self.metadata.get(run_id):
|
||||
self._emit(meta, chunk.message)
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if meta := self.metadata.get(run_id):
|
||||
if response.generations and response.generations[0]:
|
||||
gen = response.generations[0][0]
|
||||
if isinstance(gen, ChatGeneration):
|
||||
self._emit(meta, gen.message, dedupe=True)
|
||||
self.metadata.pop(run_id, None)
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.metadata.pop(run_id, None)
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
inputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if (
|
||||
metadata
|
||||
and kwargs.get("name") == metadata.get("langgraph_node")
|
||||
and (not tags or TAG_HIDDEN not in tags)
|
||||
):
|
||||
ns = tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP))[
|
||||
:-1
|
||||
]
|
||||
if not self.subgraphs and len(ns) > 0:
|
||||
return
|
||||
self.metadata[run_id] = (ns, metadata)
|
||||
for value in _state_values(inputs):
|
||||
if isinstance(value, BaseMessage):
|
||||
if value.id is not None:
|
||||
self.seen.add(value.id)
|
||||
elif isinstance(value, Sequence) and not isinstance(value, str):
|
||||
for item in value:
|
||||
if isinstance(item, BaseMessage):
|
||||
if item.id is not None:
|
||||
self.seen.add(item.id)
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
response: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if meta := self.metadata.pop(run_id, None):
|
||||
# Handle Command node updates
|
||||
if isinstance(response, Command):
|
||||
self._find_and_emit_messages(meta, response.update)
|
||||
# Handle list of Command updates
|
||||
elif isinstance(response, Sequence) and any(
|
||||
isinstance(value, Command) for value in response
|
||||
):
|
||||
for value in response:
|
||||
if isinstance(value, Command):
|
||||
self._find_and_emit_messages(meta, value.update)
|
||||
else:
|
||||
self._find_and_emit_messages(meta, value)
|
||||
# Handle basic updates / streaming
|
||||
else:
|
||||
self._find_and_emit_messages(meta, response)
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.metadata.pop(run_id, None)
|
||||
277
venv/Lib/site-packages/langgraph/pregel/_read.py
Normal file
277
venv/Lib/site-packages/langgraph/pregel/_read.py
Normal file
@@ -0,0 +1,277 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
|
||||
from functools import cached_property
|
||||
from typing import (
|
||||
Any,
|
||||
)
|
||||
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
|
||||
from langgraph._internal._config import merge_configs
|
||||
from langgraph._internal._constants import CONF, CONFIG_KEY_READ
|
||||
from langgraph._internal._runnable import RunnableCallable, RunnableSeq
|
||||
from langgraph.pregel._utils import find_subgraph_pregel
|
||||
from langgraph.pregel._write import ChannelWrite
|
||||
from langgraph.pregel.protocol import PregelProtocol
|
||||
from langgraph.types import CachePolicy, RetryPolicy
|
||||
|
||||
READ_TYPE = Callable[[str | Sequence[str], bool], Any | dict[str, Any]]
|
||||
INPUT_CACHE_KEY_TYPE = tuple[Callable[..., Any], tuple[str, ...]]
|
||||
|
||||
|
||||
class ChannelRead(RunnableCallable):
|
||||
"""Implements the logic for reading state from CONFIG_KEY_READ.
|
||||
Usable both as a runnable as well as a static method to call imperatively."""
|
||||
|
||||
channel: str | list[str]
|
||||
|
||||
fresh: bool = False
|
||||
|
||||
mapper: Callable[[Any], Any] | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channel: str | list[str],
|
||||
*,
|
||||
fresh: bool = False,
|
||||
mapper: Callable[[Any], Any] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
func=self._read,
|
||||
afunc=self._aread,
|
||||
tags=tags,
|
||||
name=None,
|
||||
trace=False,
|
||||
)
|
||||
self.fresh = fresh
|
||||
self.mapper = mapper
|
||||
self.channel = channel
|
||||
|
||||
def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str:
|
||||
if name:
|
||||
pass
|
||||
elif isinstance(self.channel, str):
|
||||
name = f"ChannelRead<{self.channel}>"
|
||||
else:
|
||||
name = f"ChannelRead<{','.join(self.channel)}>"
|
||||
return super().get_name(suffix, name=name)
|
||||
|
||||
def _read(self, _: Any, config: RunnableConfig) -> Any:
|
||||
return self.do_read(
|
||||
config, select=self.channel, fresh=self.fresh, mapper=self.mapper
|
||||
)
|
||||
|
||||
async def _aread(self, _: Any, config: RunnableConfig) -> Any:
|
||||
return self.do_read(
|
||||
config, select=self.channel, fresh=self.fresh, mapper=self.mapper
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def do_read(
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
select: str | list[str],
|
||||
fresh: bool = False,
|
||||
mapper: Callable[[Any], Any] | None = None,
|
||||
) -> Any:
|
||||
try:
|
||||
read: READ_TYPE = config[CONF][CONFIG_KEY_READ]
|
||||
except KeyError:
|
||||
raise RuntimeError(
|
||||
"Not configured with a read function"
|
||||
"Make sure to call in the context of a Pregel process"
|
||||
)
|
||||
if mapper:
|
||||
return mapper(read(select, fresh))
|
||||
else:
|
||||
return read(select, fresh)
|
||||
|
||||
|
||||
DEFAULT_BOUND = RunnableCallable(lambda input: input)
|
||||
|
||||
|
||||
class PregelNode:
|
||||
"""A node in a Pregel graph. This won't be invoked as a runnable by the graph
|
||||
itself, but instead acts as a container for the components necessary to make
|
||||
a PregelExecutableTask for a node."""
|
||||
|
||||
channels: str | list[str]
|
||||
"""The channels that will be passed as input to `bound`.
|
||||
If a str, the node will be invoked with its value if it isn't empty.
|
||||
If a list, the node will be invoked with a dict of those channels' values."""
|
||||
|
||||
triggers: list[str]
|
||||
"""If any of these channels is written to, this node will be triggered in
|
||||
the next step."""
|
||||
|
||||
mapper: Callable[[Any], Any] | None
|
||||
"""A function to transform the input before passing it to `bound`."""
|
||||
|
||||
writers: list[Runnable]
|
||||
"""A list of writers that will be executed after `bound`, responsible for
|
||||
taking the output of `bound` and writing it to the appropriate channels."""
|
||||
|
||||
bound: Runnable[Any, Any]
|
||||
"""The main logic of the node. This will be invoked with the input from
|
||||
`channels`."""
|
||||
|
||||
retry_policy: Sequence[RetryPolicy] | None
|
||||
"""The retry policies to use when invoking the node."""
|
||||
|
||||
cache_policy: CachePolicy | None
|
||||
"""The cache policy to use when invoking the node."""
|
||||
|
||||
tags: Sequence[str] | None
|
||||
"""Tags to attach to the node for tracing."""
|
||||
|
||||
metadata: Mapping[str, Any] | None
|
||||
"""Metadata to attach to the node for tracing."""
|
||||
|
||||
subgraphs: Sequence[PregelProtocol]
|
||||
"""Subgraphs used by the node."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
channels: str | list[str],
|
||||
triggers: Sequence[str],
|
||||
mapper: Callable[[Any], Any] | None = None,
|
||||
writers: list[Runnable] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: Mapping[str, Any] | None = None,
|
||||
bound: Runnable[Any, Any] | None = None,
|
||||
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
|
||||
cache_policy: CachePolicy | None = None,
|
||||
subgraphs: Sequence[PregelProtocol] | None = None,
|
||||
) -> None:
|
||||
self.channels = channels
|
||||
self.triggers = list(triggers)
|
||||
self.mapper = mapper
|
||||
self.writers = writers or []
|
||||
self.bound = bound if bound is not None else DEFAULT_BOUND
|
||||
self.cache_policy = cache_policy
|
||||
if isinstance(retry_policy, RetryPolicy):
|
||||
self.retry_policy = (retry_policy,)
|
||||
else:
|
||||
self.retry_policy = retry_policy
|
||||
self.tags = tags
|
||||
self.metadata = metadata
|
||||
if subgraphs is not None:
|
||||
self.subgraphs = subgraphs
|
||||
elif self.bound is not DEFAULT_BOUND:
|
||||
try:
|
||||
subgraph = find_subgraph_pregel(self.bound)
|
||||
except Exception:
|
||||
subgraph = None
|
||||
if subgraph:
|
||||
self.subgraphs = [subgraph]
|
||||
else:
|
||||
self.subgraphs = []
|
||||
else:
|
||||
self.subgraphs = []
|
||||
|
||||
def copy(self, update: dict[str, Any]) -> PregelNode:
|
||||
attrs = {**self.__dict__, **update}
|
||||
# Drop the cached properties
|
||||
attrs.pop("flat_writers", None)
|
||||
attrs.pop("node", None)
|
||||
attrs.pop("input_cache_key", None)
|
||||
return PregelNode(**attrs)
|
||||
|
||||
@cached_property
|
||||
def flat_writers(self) -> list[Runnable]:
|
||||
"""Get writers with optimizations applied. Dedupes consecutive ChannelWrites."""
|
||||
writers = self.writers.copy()
|
||||
while (
|
||||
len(writers) > 1
|
||||
and isinstance(writers[-1], ChannelWrite)
|
||||
and isinstance(writers[-2], ChannelWrite)
|
||||
):
|
||||
# we can combine writes if they are consecutive
|
||||
# careful to not modify the original writers list or ChannelWrite
|
||||
writers[-2] = ChannelWrite(
|
||||
writes=writers[-2].writes + writers[-1].writes,
|
||||
)
|
||||
writers.pop()
|
||||
return writers
|
||||
|
||||
@cached_property
|
||||
def node(self) -> Runnable[Any, Any] | None:
|
||||
"""Get a runnable that combines `bound` and `writers`."""
|
||||
writers = self.flat_writers
|
||||
if self.bound is DEFAULT_BOUND and not writers:
|
||||
return None
|
||||
elif self.bound is DEFAULT_BOUND and len(writers) == 1:
|
||||
return writers[0]
|
||||
elif self.bound is DEFAULT_BOUND:
|
||||
return RunnableSeq(*writers)
|
||||
elif writers:
|
||||
return RunnableSeq(self.bound, *writers)
|
||||
else:
|
||||
return self.bound
|
||||
|
||||
@cached_property
|
||||
def input_cache_key(self) -> INPUT_CACHE_KEY_TYPE:
|
||||
"""Get a cache key for the input to the node.
|
||||
This is used to avoid calculating the same input multiple times."""
|
||||
return (
|
||||
self.mapper,
|
||||
tuple(self.channels)
|
||||
if isinstance(self.channels, list)
|
||||
else (self.channels,),
|
||||
)
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Any:
|
||||
self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
|
||||
return self.bound.invoke(
|
||||
input,
|
||||
merge_configs(self_config, config),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Any:
|
||||
self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
|
||||
return await self.bound.ainvoke(
|
||||
input,
|
||||
merge_configs(self_config, config),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Iterator[Any]:
|
||||
self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
|
||||
yield from self.bound.stream(
|
||||
input,
|
||||
merge_configs(self_config, config),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> AsyncIterator[Any]:
|
||||
self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
|
||||
async for item in self.bound.astream(
|
||||
input,
|
||||
merge_configs(self_config, config),
|
||||
**kwargs,
|
||||
):
|
||||
yield item
|
||||
238
venv/Lib/site-packages/langgraph/pregel/_retry.py
Normal file
238
venv/Lib/site-packages/langgraph/pregel/_retry.py
Normal file
@@ -0,0 +1,238 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from dataclasses import replace
|
||||
from typing import Any
|
||||
|
||||
from langgraph._internal._config import patch_configurable, recast_checkpoint_ns
|
||||
from langgraph._internal._constants import (
|
||||
CONF,
|
||||
CONFIG_KEY_CHECKPOINT_NS,
|
||||
CONFIG_KEY_RESUMING,
|
||||
NS_SEP,
|
||||
)
|
||||
from langgraph.errors import GraphBubbleUp, ParentCommand
|
||||
from langgraph.types import Command, PregelExecutableTask, RetryPolicy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
SUPPORTS_EXC_NOTES = sys.version_info >= (3, 11)
|
||||
|
||||
|
||||
def _checkpoint_ns_for_parent_command(ns: str) -> str:
|
||||
"""Return the checkpoint namespace for the parent graph.
|
||||
|
||||
The checkpoint namespace is a `|`-separated path. Each segment is usually
|
||||
of the form `name:task_id` (e.g. `parent_first:<uuid>|node:<uuid>`), but the
|
||||
runtime may also insert a purely-numeric segment (e.g. `|1`) to disambiguate
|
||||
concurrent tasks (e.g. `parent_first:<uuid>|1|node:<uuid>`).
|
||||
|
||||
Numeric segments are not real path levels, so we drop them before computing
|
||||
the parent namespace.
|
||||
"""
|
||||
|
||||
parts = ns.split(NS_SEP)
|
||||
|
||||
# Drop any trailing numeric selectors for the current frame (e.g. `...|node:<id>|1`).
|
||||
while parts and parts[-1].isdigit():
|
||||
parts.pop()
|
||||
|
||||
# Drop the current frame segment itself (e.g. the `node:<id>`).
|
||||
if parts:
|
||||
parts.pop()
|
||||
|
||||
# Drop any trailing numeric selectors for the parent frame (e.g. `...|1|node:<id>`).
|
||||
while parts and parts[-1].isdigit():
|
||||
parts.pop()
|
||||
|
||||
return NS_SEP.join(parts)
|
||||
|
||||
|
||||
def run_with_retry(
|
||||
task: PregelExecutableTask,
|
||||
retry_policy: Sequence[RetryPolicy] | None,
|
||||
configurable: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Run a task with retries."""
|
||||
retry_policy = task.retry_policy or retry_policy
|
||||
attempts = 0
|
||||
config = task.config
|
||||
if configurable is not None:
|
||||
config = patch_configurable(config, configurable)
|
||||
while True:
|
||||
try:
|
||||
# clear any writes from previous attempts
|
||||
task.writes.clear()
|
||||
# run the task
|
||||
return task.proc.invoke(task.input, config)
|
||||
except ParentCommand as exc:
|
||||
ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS]
|
||||
cmd = exc.args[0]
|
||||
# strip task_ids from namespace for comparison (ns format: "node1|node2:task_id")
|
||||
if cmd.graph in (ns, recast_checkpoint_ns(ns), task.name):
|
||||
# this command is for the current graph, handle it
|
||||
for w in task.writers:
|
||||
w.invoke(cmd, config)
|
||||
break
|
||||
elif cmd.graph == Command.PARENT:
|
||||
# this command is for the parent graph, assign it to the parent.
|
||||
exc.args = (replace(cmd, graph=_checkpoint_ns_for_parent_command(ns)),)
|
||||
# bubble up
|
||||
raise
|
||||
except GraphBubbleUp:
|
||||
# if interrupted, end
|
||||
raise
|
||||
except Exception as exc:
|
||||
if SUPPORTS_EXC_NOTES:
|
||||
exc.add_note(f"During task with name '{task.name}' and id '{task.id}'")
|
||||
if not retry_policy:
|
||||
raise
|
||||
|
||||
# Check which retry policy applies to this exception
|
||||
matching_policy = None
|
||||
for policy in retry_policy:
|
||||
if _should_retry_on(policy, exc):
|
||||
matching_policy = policy
|
||||
break
|
||||
|
||||
if not matching_policy:
|
||||
raise
|
||||
|
||||
# increment attempts
|
||||
attempts += 1
|
||||
# check if we should give up
|
||||
if attempts >= matching_policy.max_attempts:
|
||||
raise
|
||||
# sleep before retrying
|
||||
interval = matching_policy.initial_interval
|
||||
# Apply backoff factor based on attempt count
|
||||
interval = min(
|
||||
matching_policy.max_interval,
|
||||
interval * (matching_policy.backoff_factor ** (attempts - 1)),
|
||||
)
|
||||
|
||||
# Apply jitter if configured
|
||||
sleep_time = (
|
||||
interval + random.uniform(0, 1) if matching_policy.jitter else interval
|
||||
)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
# log the retry
|
||||
logger.info(
|
||||
f"Retrying task {task.name} after {sleep_time:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
|
||||
exc_info=exc,
|
||||
)
|
||||
# signal subgraphs to resume (if available)
|
||||
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})
|
||||
|
||||
|
||||
async def arun_with_retry(
|
||||
task: PregelExecutableTask,
|
||||
retry_policy: Sequence[RetryPolicy] | None,
|
||||
stream: bool = False,
|
||||
match_cached_writes: Callable[[], Awaitable[Sequence[PregelExecutableTask]]]
|
||||
| None = None,
|
||||
configurable: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Run a task asynchronously with retries."""
|
||||
retry_policy = task.retry_policy or retry_policy
|
||||
attempts = 0
|
||||
config = task.config
|
||||
if configurable is not None:
|
||||
config = patch_configurable(config, configurable)
|
||||
if match_cached_writes is not None and task.cache_key is not None:
|
||||
for t in await match_cached_writes():
|
||||
if t is task:
|
||||
# if the task is already cached, return
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
# clear any writes from previous attempts
|
||||
task.writes.clear()
|
||||
# run the task
|
||||
if stream:
|
||||
async for _ in task.proc.astream(task.input, config):
|
||||
pass
|
||||
# if successful, end
|
||||
break
|
||||
else:
|
||||
return await task.proc.ainvoke(task.input, config)
|
||||
except ParentCommand as exc:
|
||||
ns: str = config[CONF][CONFIG_KEY_CHECKPOINT_NS]
|
||||
cmd = exc.args[0]
|
||||
# strip task_ids from namespace for comparison (ns format: "node1|node2:task_id")
|
||||
if cmd.graph in (ns, recast_checkpoint_ns(ns), task.name):
|
||||
# this command is for the current graph, handle it
|
||||
for w in task.writers:
|
||||
w.invoke(cmd, config)
|
||||
break
|
||||
elif cmd.graph == Command.PARENT:
|
||||
# this command is for the parent graph, assign it to the parent.
|
||||
exc.args = (replace(cmd, graph=_checkpoint_ns_for_parent_command(ns)),)
|
||||
# bubble up
|
||||
raise
|
||||
except GraphBubbleUp:
|
||||
# if interrupted, end
|
||||
raise
|
||||
except Exception as exc:
|
||||
if SUPPORTS_EXC_NOTES:
|
||||
exc.add_note(f"During task with name '{task.name}' and id '{task.id}'")
|
||||
if not retry_policy:
|
||||
raise
|
||||
|
||||
# Check which retry policy applies to this exception
|
||||
matching_policy = None
|
||||
for policy in retry_policy:
|
||||
if _should_retry_on(policy, exc):
|
||||
matching_policy = policy
|
||||
break
|
||||
|
||||
if not matching_policy:
|
||||
raise
|
||||
|
||||
# increment attempts
|
||||
attempts += 1
|
||||
# check if we should give up
|
||||
if attempts >= matching_policy.max_attempts:
|
||||
raise
|
||||
# sleep before retrying
|
||||
interval = matching_policy.initial_interval
|
||||
# Apply backoff factor based on attempt count
|
||||
interval = min(
|
||||
matching_policy.max_interval,
|
||||
interval * (matching_policy.backoff_factor ** (attempts - 1)),
|
||||
)
|
||||
|
||||
# Apply jitter if configured
|
||||
sleep_time = (
|
||||
interval + random.uniform(0, 1) if matching_policy.jitter else interval
|
||||
)
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
# log the retry
|
||||
logger.info(
|
||||
f"Retrying task {task.name} after {sleep_time:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}",
|
||||
exc_info=exc,
|
||||
)
|
||||
# signal subgraphs to resume (if available)
|
||||
config = patch_configurable(config, {CONFIG_KEY_RESUMING: True})
|
||||
|
||||
|
||||
def _should_retry_on(retry_policy: RetryPolicy, exc: Exception) -> bool:
|
||||
"""Check if the given exception should be retried based on the retry policy."""
|
||||
if isinstance(retry_policy.retry_on, Sequence):
|
||||
return isinstance(exc, tuple(retry_policy.retry_on))
|
||||
elif isinstance(retry_policy.retry_on, type) and issubclass(
|
||||
retry_policy.retry_on, Exception
|
||||
):
|
||||
return isinstance(exc, retry_policy.retry_on)
|
||||
elif callable(retry_policy.retry_on):
|
||||
return retry_policy.retry_on(exc) # type: ignore[call-arg]
|
||||
else:
|
||||
raise TypeError(
|
||||
"retry_on must be an Exception class, a list or tuple of Exception classes, or a callable"
|
||||
)
|
||||
768
venv/Lib/site-packages/langgraph/pregel/_runner.py
Normal file
768
venv/Lib/site-packages/langgraph/pregel/_runner.py
Normal file
@@ -0,0 +1,768 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import inspect
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from collections.abc import (
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Sequence,
|
||||
)
|
||||
from functools import partial
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
|
||||
from langgraph._internal._constants import (
|
||||
CONF,
|
||||
CONFIG_KEY_CALL,
|
||||
CONFIG_KEY_SCRATCHPAD,
|
||||
ERROR,
|
||||
INTERRUPT,
|
||||
NO_WRITES,
|
||||
RESUME,
|
||||
RETURN,
|
||||
)
|
||||
from langgraph._internal._future import chain_future, run_coroutine_threadsafe
|
||||
from langgraph._internal._scratchpad import PregelScratchpad
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.constants import TAG_HIDDEN
|
||||
from langgraph.errors import GraphBubbleUp, GraphInterrupt
|
||||
from langgraph.pregel._algo import Call
|
||||
from langgraph.pregel._executor import Submit
|
||||
from langgraph.pregel._retry import arun_with_retry, run_with_retry
|
||||
from langgraph.types import (
|
||||
CachePolicy,
|
||||
PregelExecutableTask,
|
||||
RetryPolicy,
|
||||
)
|
||||
|
||||
F = TypeVar("F", concurrent.futures.Future, asyncio.Future)
|
||||
E = TypeVar("E", threading.Event, asyncio.Event)
|
||||
|
||||
# List of filenames to exclude from exception traceback
|
||||
# Note: Frames will be removed if they are the last frame in traceback, recursively
|
||||
EXCLUDED_FRAME_FNAMES = (
|
||||
"langgraph/pregel/retry.py",
|
||||
"langgraph/pregel/runner.py",
|
||||
"langgraph/pregel/executor.py",
|
||||
"langgraph/utils/runnable.py",
|
||||
"langchain_core/runnables/config.py",
|
||||
"concurrent/futures/thread.py",
|
||||
"concurrent/futures/_base.py",
|
||||
)
|
||||
|
||||
SKIP_RERAISE_SET: weakref.WeakSet[concurrent.futures.Future | asyncio.Future] = (
|
||||
weakref.WeakSet()
|
||||
)
|
||||
|
||||
|
||||
class FuturesDict(Generic[F, E], dict[F, PregelExecutableTask | None]):
|
||||
event: E
|
||||
callback: weakref.ref[Callable[[PregelExecutableTask, BaseException | None], None]]
|
||||
counter: int
|
||||
done: set[F]
|
||||
lock: threading.Lock
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event: E,
|
||||
callback: weakref.ref[
|
||||
Callable[[PregelExecutableTask, BaseException | None], None]
|
||||
],
|
||||
future_type: type[F],
|
||||
# used for generic typing, newer py supports FutureDict[...](...)
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.lock = threading.Lock()
|
||||
self.event = event
|
||||
self.callback = callback
|
||||
self.counter = 0
|
||||
self.done: set[F] = set()
|
||||
|
||||
def __setitem__(
|
||||
self,
|
||||
key: F,
|
||||
value: PregelExecutableTask | None,
|
||||
) -> None:
|
||||
super().__setitem__(key, value) # type: ignore[index]
|
||||
if value is not None:
|
||||
with self.lock:
|
||||
self.event.clear()
|
||||
self.counter += 1
|
||||
key.add_done_callback(partial(self.on_done, value))
|
||||
|
||||
def on_done(
|
||||
self,
|
||||
task: PregelExecutableTask,
|
||||
fut: F,
|
||||
) -> None:
|
||||
try:
|
||||
if cb := self.callback():
|
||||
cb(task, _exception(fut))
|
||||
finally:
|
||||
with self.lock:
|
||||
self.done.add(fut)
|
||||
self.counter -= 1
|
||||
if self.counter == 0 or _should_stop_others(self.done):
|
||||
self.event.set()
|
||||
|
||||
|
||||
class PregelRunner:
|
||||
"""Responsible for executing a set of Pregel tasks concurrently, committing
|
||||
their writes, yielding control to caller when there is output to emit, and
|
||||
interrupting other tasks if appropriate."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
submit: weakref.ref[Submit],
|
||||
put_writes: weakref.ref[Callable[[str, Sequence[tuple[str, Any]]], None]],
|
||||
use_astream: bool = False,
|
||||
node_finished: Callable[[str], None] | None = None,
|
||||
) -> None:
|
||||
self.submit = submit
|
||||
self.put_writes = put_writes
|
||||
self.use_astream = use_astream
|
||||
self.node_finished = node_finished
|
||||
|
||||
def tick(
|
||||
self,
|
||||
tasks: Iterable[PregelExecutableTask],
|
||||
*,
|
||||
reraise: bool = True,
|
||||
timeout: float | None = None,
|
||||
retry_policy: Sequence[RetryPolicy] | None = None,
|
||||
get_waiter: Callable[[], concurrent.futures.Future[None]] | None = None,
|
||||
schedule_task: Callable[
|
||||
[PregelExecutableTask, int, Call | None],
|
||||
PregelExecutableTask | None,
|
||||
],
|
||||
) -> Iterator[None]:
|
||||
tasks = tuple(tasks)
|
||||
futures = FuturesDict(
|
||||
callback=weakref.WeakMethod(self.commit),
|
||||
event=threading.Event(),
|
||||
future_type=concurrent.futures.Future,
|
||||
)
|
||||
# give control back to the caller
|
||||
yield
|
||||
# fast path if single task with no timeout and no waiter
|
||||
if len(tasks) == 0:
|
||||
return
|
||||
elif len(tasks) == 1 and timeout is None and get_waiter is None:
|
||||
t = tasks[0]
|
||||
try:
|
||||
run_with_retry(
|
||||
t,
|
||||
retry_policy,
|
||||
configurable={
|
||||
CONFIG_KEY_CALL: partial(
|
||||
_call,
|
||||
weakref.ref(t),
|
||||
retry_policy=retry_policy,
|
||||
futures=weakref.ref(futures),
|
||||
schedule_task=schedule_task,
|
||||
submit=self.submit,
|
||||
),
|
||||
},
|
||||
)
|
||||
self.commit(t, None)
|
||||
except Exception as exc:
|
||||
self.commit(t, exc)
|
||||
if reraise and futures:
|
||||
# will be re-raised after futures are done
|
||||
fut: concurrent.futures.Future = concurrent.futures.Future()
|
||||
fut.set_exception(exc)
|
||||
futures.done.add(fut)
|
||||
elif reraise:
|
||||
if tb := exc.__traceback__:
|
||||
while tb.tb_next is not None and any(
|
||||
tb.tb_frame.f_code.co_filename.endswith(name)
|
||||
for name in EXCLUDED_FRAME_FNAMES
|
||||
):
|
||||
tb = tb.tb_next
|
||||
exc.__traceback__ = tb
|
||||
raise
|
||||
if not futures: # maybe `t` scheduled another task
|
||||
return
|
||||
else:
|
||||
tasks = () # don't reschedule this task
|
||||
# add waiter task if requested
|
||||
if get_waiter is not None:
|
||||
futures[get_waiter()] = None
|
||||
# schedule tasks
|
||||
for t in tasks:
|
||||
fut = self.submit()( # type: ignore[misc]
|
||||
run_with_retry,
|
||||
t,
|
||||
retry_policy,
|
||||
configurable={
|
||||
CONFIG_KEY_CALL: partial(
|
||||
_call,
|
||||
weakref.ref(t),
|
||||
retry_policy=retry_policy,
|
||||
futures=weakref.ref(futures),
|
||||
schedule_task=schedule_task,
|
||||
submit=self.submit,
|
||||
),
|
||||
},
|
||||
__reraise_on_exit__=reraise,
|
||||
)
|
||||
futures[fut] = t
|
||||
# execute tasks, and wait for one to fail or all to finish.
|
||||
# each task is independent from all other concurrent tasks
|
||||
# yield updates/debug output as each task finishes
|
||||
end_time = timeout + time.monotonic() if timeout else None
|
||||
while len(futures) > (1 if get_waiter is not None else 0):
|
||||
done, inflight = concurrent.futures.wait(
|
||||
futures,
|
||||
return_when=concurrent.futures.FIRST_COMPLETED,
|
||||
timeout=(max(0, end_time - time.monotonic()) if end_time else None),
|
||||
)
|
||||
if not done:
|
||||
break # timed out
|
||||
for fut in done:
|
||||
task = futures.pop(fut)
|
||||
if task is None:
|
||||
# waiter task finished, schedule another
|
||||
if inflight and get_waiter is not None:
|
||||
futures[get_waiter()] = None
|
||||
else:
|
||||
# remove references to loop vars
|
||||
del fut, task
|
||||
# maybe stop other tasks
|
||||
if _should_stop_others(done):
|
||||
break
|
||||
# give control back to the caller
|
||||
yield
|
||||
# wait for done callbacks
|
||||
futures.event.wait(
|
||||
timeout=(max(0, end_time - time.monotonic()) if end_time else None)
|
||||
)
|
||||
# give control back to the caller
|
||||
yield
|
||||
# panic on failure or timeout
|
||||
try:
|
||||
_panic_or_proceed(
|
||||
futures.done.union(f for f, t in futures.items() if t is not None),
|
||||
panic=reraise,
|
||||
)
|
||||
except Exception as exc:
|
||||
if tb := exc.__traceback__:
|
||||
while tb.tb_next is not None and any(
|
||||
tb.tb_frame.f_code.co_filename.endswith(name)
|
||||
for name in EXCLUDED_FRAME_FNAMES
|
||||
):
|
||||
tb = tb.tb_next
|
||||
exc.__traceback__ = tb
|
||||
raise
|
||||
|
||||
async def atick(
|
||||
self,
|
||||
tasks: Iterable[PregelExecutableTask],
|
||||
*,
|
||||
reraise: bool = True,
|
||||
timeout: float | None = None,
|
||||
retry_policy: Sequence[RetryPolicy] | None = None,
|
||||
get_waiter: Callable[[], asyncio.Future[None]] | None = None,
|
||||
schedule_task: Callable[
|
||||
[PregelExecutableTask, int, Call | None],
|
||||
Awaitable[PregelExecutableTask | None],
|
||||
],
|
||||
) -> AsyncIterator[None]:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
tasks = tuple(tasks)
|
||||
futures = FuturesDict(
|
||||
callback=weakref.WeakMethod(self.commit),
|
||||
event=asyncio.Event(),
|
||||
future_type=asyncio.Future,
|
||||
)
|
||||
# give control back to the caller
|
||||
yield
|
||||
# fast path if single task with no waiter and no timeout
|
||||
if len(tasks) == 0:
|
||||
return
|
||||
elif len(tasks) == 1 and get_waiter is None and timeout is None:
|
||||
t = tasks[0]
|
||||
try:
|
||||
await arun_with_retry(
|
||||
t,
|
||||
retry_policy,
|
||||
stream=self.use_astream,
|
||||
configurable={
|
||||
CONFIG_KEY_CALL: partial(
|
||||
_acall,
|
||||
weakref.ref(t),
|
||||
stream=self.use_astream,
|
||||
retry_policy=retry_policy,
|
||||
futures=weakref.ref(futures),
|
||||
schedule_task=schedule_task,
|
||||
submit=self.submit,
|
||||
loop=loop,
|
||||
),
|
||||
},
|
||||
)
|
||||
self.commit(t, None)
|
||||
except Exception as exc:
|
||||
self.commit(t, exc)
|
||||
if reraise and futures:
|
||||
# will be re-raised after futures are done
|
||||
fut: asyncio.Future = loop.create_future()
|
||||
fut.set_exception(exc)
|
||||
futures.done.add(fut)
|
||||
elif reraise:
|
||||
if tb := exc.__traceback__:
|
||||
while tb.tb_next is not None and any(
|
||||
tb.tb_frame.f_code.co_filename.endswith(name)
|
||||
for name in EXCLUDED_FRAME_FNAMES
|
||||
):
|
||||
tb = tb.tb_next
|
||||
exc.__traceback__ = tb
|
||||
raise
|
||||
if not futures: # maybe `t` scheduled another task
|
||||
return
|
||||
else:
|
||||
tasks = () # don't reschedule this task
|
||||
# add waiter task if requested
|
||||
if get_waiter is not None:
|
||||
futures[get_waiter()] = None
|
||||
# schedule tasks
|
||||
for t in tasks:
|
||||
fut = cast(
|
||||
asyncio.Future,
|
||||
self.submit()( # type: ignore[misc]
|
||||
arun_with_retry,
|
||||
t,
|
||||
retry_policy,
|
||||
stream=self.use_astream,
|
||||
configurable={
|
||||
CONFIG_KEY_CALL: partial(
|
||||
_acall,
|
||||
weakref.ref(t),
|
||||
retry_policy=retry_policy,
|
||||
stream=self.use_astream,
|
||||
futures=weakref.ref(futures),
|
||||
schedule_task=schedule_task,
|
||||
submit=self.submit,
|
||||
loop=loop,
|
||||
),
|
||||
},
|
||||
__name__=t.name,
|
||||
__cancel_on_exit__=True,
|
||||
__reraise_on_exit__=reraise,
|
||||
),
|
||||
)
|
||||
futures[fut] = t
|
||||
# execute tasks, and wait for one to fail or all to finish.
|
||||
# each task is independent from all other concurrent tasks
|
||||
# yield updates/debug output as each task finishes
|
||||
end_time = timeout + loop.time() if timeout else None
|
||||
while len(futures) > (1 if get_waiter is not None else 0):
|
||||
done, inflight = await asyncio.wait(
|
||||
futures,
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
timeout=(max(0, end_time - loop.time()) if end_time else None),
|
||||
)
|
||||
if not done:
|
||||
break # timed out
|
||||
for fut in done:
|
||||
task = futures.pop(fut)
|
||||
if task is None:
|
||||
# waiter task finished, schedule another
|
||||
if inflight and get_waiter is not None:
|
||||
futures[get_waiter()] = None
|
||||
else:
|
||||
# remove references to loop vars
|
||||
del fut, task
|
||||
# maybe stop other tasks
|
||||
if _should_stop_others(done):
|
||||
break
|
||||
# give control back to the caller
|
||||
yield
|
||||
# wait for done callbacks
|
||||
await asyncio.wait_for(
|
||||
futures.event.wait(),
|
||||
timeout=(max(0, end_time - loop.time()) if end_time else None),
|
||||
)
|
||||
# give control back to the caller
|
||||
yield
|
||||
# cancel waiter task
|
||||
for fut in futures:
|
||||
fut.cancel()
|
||||
# panic on failure or timeout
|
||||
try:
|
||||
_panic_or_proceed(
|
||||
futures.done.union(f for f, t in futures.items() if t is not None),
|
||||
timeout_exc_cls=asyncio.TimeoutError,
|
||||
panic=reraise,
|
||||
)
|
||||
except Exception as exc:
|
||||
if tb := exc.__traceback__:
|
||||
while tb.tb_next is not None and any(
|
||||
tb.tb_frame.f_code.co_filename.endswith(name)
|
||||
for name in EXCLUDED_FRAME_FNAMES
|
||||
):
|
||||
tb = tb.tb_next
|
||||
exc.__traceback__ = tb
|
||||
raise
|
||||
|
||||
def commit(
|
||||
self,
|
||||
task: PregelExecutableTask,
|
||||
exception: BaseException | None,
|
||||
) -> None:
|
||||
if isinstance(exception, asyncio.CancelledError):
|
||||
# for cancelled tasks, also save error in task,
|
||||
# so loop can finish super-step
|
||||
task.writes.append((ERROR, exception))
|
||||
self.put_writes()(task.id, task.writes) # type: ignore[misc]
|
||||
elif exception:
|
||||
if isinstance(exception, GraphInterrupt):
|
||||
# save interrupt to checkpointer
|
||||
if exception.args[0]:
|
||||
writes = [(INTERRUPT, exception.args[0])]
|
||||
if resumes := [w for w in task.writes if w[0] == RESUME]:
|
||||
writes.extend(resumes)
|
||||
self.put_writes()(task.id, writes) # type: ignore[misc]
|
||||
elif isinstance(exception, GraphBubbleUp):
|
||||
# exception will be raised in _panic_or_proceed
|
||||
pass
|
||||
else:
|
||||
# save error to checkpointer
|
||||
task.writes.append((ERROR, exception))
|
||||
self.put_writes()(task.id, task.writes) # type: ignore[misc]
|
||||
else:
|
||||
if self.node_finished and (
|
||||
task.config is None or TAG_HIDDEN not in task.config.get("tags", [])
|
||||
):
|
||||
self.node_finished(task.name)
|
||||
if not task.writes:
|
||||
# add no writes marker
|
||||
task.writes.append((NO_WRITES, None))
|
||||
# save task writes to checkpointer
|
||||
self.put_writes()(task.id, task.writes) # type: ignore[misc]
|
||||
|
||||
|
||||
def _should_stop_others(
|
||||
done: set[F],
|
||||
) -> bool:
|
||||
"""Check if any task failed, if so, cancel all other tasks.
|
||||
GraphInterrupts are not considered failures."""
|
||||
for fut in done:
|
||||
if fut.cancelled():
|
||||
continue
|
||||
elif exc := fut.exception():
|
||||
if not isinstance(exc, GraphBubbleUp) and fut not in SKIP_RERAISE_SET:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _exception(
|
||||
fut: concurrent.futures.Future[Any] | asyncio.Future[Any],
|
||||
) -> BaseException | None:
|
||||
"""Return the exception from a future, without raising CancelledError."""
|
||||
if fut.cancelled():
|
||||
if isinstance(fut, asyncio.Future):
|
||||
return asyncio.CancelledError()
|
||||
else:
|
||||
return concurrent.futures.CancelledError()
|
||||
else:
|
||||
return fut.exception()
|
||||
|
||||
|
||||
def _panic_or_proceed(
|
||||
futs: set[concurrent.futures.Future] | set[asyncio.Future],
|
||||
*,
|
||||
timeout_exc_cls: type[Exception] = TimeoutError,
|
||||
panic: bool = True,
|
||||
) -> None:
|
||||
"""Cancel remaining tasks if any failed, re-raise exception if panic is True."""
|
||||
done: set[concurrent.futures.Future[Any] | asyncio.Future[Any]] = set()
|
||||
inflight: set[concurrent.futures.Future[Any] | asyncio.Future[Any]] = set()
|
||||
for fut in futs:
|
||||
if fut.cancelled():
|
||||
continue
|
||||
elif fut.done():
|
||||
done.add(fut)
|
||||
else:
|
||||
inflight.add(fut)
|
||||
interrupts: list[GraphInterrupt] = []
|
||||
while done:
|
||||
# if any task failed
|
||||
fut = done.pop()
|
||||
if exc := _exception(fut):
|
||||
# cancel all pending tasks
|
||||
while inflight:
|
||||
inflight.pop().cancel()
|
||||
# raise the exception
|
||||
if panic:
|
||||
if isinstance(exc, GraphInterrupt):
|
||||
# collect interrupts
|
||||
interrupts.append(exc)
|
||||
elif fut not in SKIP_RERAISE_SET:
|
||||
raise exc
|
||||
# raise combined interrupts
|
||||
if interrupts:
|
||||
raise GraphInterrupt(tuple(i for exc in interrupts for i in exc.args[0]))
|
||||
if inflight:
|
||||
# if we got here means we timed out
|
||||
while inflight:
|
||||
# cancel all pending tasks
|
||||
inflight.pop().cancel()
|
||||
# raise timeout error
|
||||
raise timeout_exc_cls("Timed out")
|
||||
|
||||
|
||||
def _call(
|
||||
task: weakref.ref[PregelExecutableTask],
|
||||
func: Callable[[Any], Awaitable[Any] | Any],
|
||||
input: Any,
|
||||
*,
|
||||
retry_policy: Sequence[RetryPolicy] | None = None,
|
||||
cache_policy: CachePolicy | None = None,
|
||||
callbacks: Callbacks = None,
|
||||
futures: weakref.ref[FuturesDict],
|
||||
schedule_task: Callable[
|
||||
[PregelExecutableTask, int, Call | None], PregelExecutableTask | None
|
||||
],
|
||||
submit: weakref.ref[Submit],
|
||||
) -> concurrent.futures.Future[Any]:
|
||||
if inspect.iscoroutinefunction(func):
|
||||
raise RuntimeError("In an sync context async tasks cannot be called")
|
||||
|
||||
fut: concurrent.futures.Future | None = None
|
||||
# schedule PUSH tasks, collect futures
|
||||
scratchpad: PregelScratchpad = task().config[CONF][CONFIG_KEY_SCRATCHPAD] # type: ignore[union-attr]
|
||||
# schedule the next task, if the callback returns one
|
||||
if next_task := schedule_task(
|
||||
task(), # type: ignore[arg-type]
|
||||
scratchpad.call_counter(),
|
||||
Call(
|
||||
func,
|
||||
input,
|
||||
retry_policy=retry_policy,
|
||||
cache_policy=cache_policy,
|
||||
callbacks=callbacks,
|
||||
),
|
||||
):
|
||||
if fut := next(
|
||||
(
|
||||
f
|
||||
for f, t in list(futures().items()) # type: ignore[union-attr]
|
||||
if t is not None and t == next_task.id
|
||||
),
|
||||
None,
|
||||
):
|
||||
# if the parent task was retried,
|
||||
# the next task might already be running
|
||||
pass
|
||||
elif next_task.writes:
|
||||
# if it already ran, return the result
|
||||
fut = concurrent.futures.Future()
|
||||
ret = next((v for c, v in next_task.writes if c == RETURN), MISSING)
|
||||
if ret is not MISSING:
|
||||
fut.set_result(ret)
|
||||
elif exc := next((v for c, v in next_task.writes if c == ERROR), None):
|
||||
fut.set_exception(
|
||||
exc if isinstance(exc, BaseException) else Exception(exc)
|
||||
)
|
||||
else:
|
||||
fut.set_result(None)
|
||||
else:
|
||||
# schedule the next task
|
||||
fut = submit()( # type: ignore[misc]
|
||||
run_with_retry,
|
||||
next_task,
|
||||
retry_policy,
|
||||
configurable={
|
||||
CONFIG_KEY_CALL: partial(
|
||||
_call,
|
||||
weakref.ref(next_task),
|
||||
futures=futures,
|
||||
retry_policy=retry_policy,
|
||||
callbacks=callbacks,
|
||||
schedule_task=schedule_task,
|
||||
submit=submit,
|
||||
),
|
||||
},
|
||||
__reraise_on_exit__=False,
|
||||
# starting a new task in the next tick ensures
|
||||
# updates from this tick are committed/streamed first
|
||||
__next_tick__=True,
|
||||
)
|
||||
# exceptions for call() tasks are raised into the parent task
|
||||
# so we should not re-raise at the end of the tick
|
||||
SKIP_RERAISE_SET.add(fut)
|
||||
futures()[fut] = next_task # type: ignore[index]
|
||||
fut = cast(asyncio.Future | concurrent.futures.Future, fut)
|
||||
# return a chained future to ensure commit() callback is called
|
||||
# before the returned future is resolved, to ensure stream order etc
|
||||
return chain_future(fut, concurrent.futures.Future())
|
||||
|
||||
|
||||
def _acall(
|
||||
task: weakref.ref[PregelExecutableTask],
|
||||
func: Callable[[Any], Awaitable[Any] | Any],
|
||||
input: Any,
|
||||
*,
|
||||
retry_policy: Sequence[RetryPolicy] | None = None,
|
||||
cache_policy: CachePolicy | None = None,
|
||||
callbacks: Callbacks = None,
|
||||
# injected dependencies
|
||||
futures: weakref.ref[FuturesDict],
|
||||
schedule_task: Callable[
|
||||
[PregelExecutableTask, int, Call | None],
|
||||
Awaitable[PregelExecutableTask | None],
|
||||
],
|
||||
submit: weakref.ref[Submit],
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
stream: bool = False,
|
||||
) -> asyncio.Future[Any] | concurrent.futures.Future[Any]:
|
||||
# return a chained future to ensure commit() callback is called
|
||||
# before the returned future is resolved, to ensure stream order etc
|
||||
try:
|
||||
in_async = asyncio.current_task() is not None
|
||||
except RuntimeError:
|
||||
in_async = False
|
||||
# if in async context return an async future, otherwise return a sync future
|
||||
if in_async:
|
||||
fut: asyncio.Future[Any] | concurrent.futures.Future[Any] = asyncio.Future(
|
||||
loop=loop
|
||||
)
|
||||
else:
|
||||
fut = concurrent.futures.Future()
|
||||
# schedule the next task
|
||||
run_coroutine_threadsafe(
|
||||
_acall_impl(
|
||||
fut,
|
||||
task,
|
||||
func,
|
||||
input,
|
||||
retry_policy=retry_policy,
|
||||
cache_policy=cache_policy,
|
||||
callbacks=callbacks,
|
||||
futures=futures,
|
||||
schedule_task=schedule_task,
|
||||
submit=submit,
|
||||
loop=loop,
|
||||
stream=stream,
|
||||
),
|
||||
loop,
|
||||
lazy=False,
|
||||
)
|
||||
return fut
|
||||
|
||||
|
||||
async def _acall_impl(
|
||||
destination: asyncio.Future[Any] | concurrent.futures.Future[Any],
|
||||
task: weakref.ref[PregelExecutableTask],
|
||||
func: Callable[[Any], Awaitable[Any] | Any],
|
||||
input: Any,
|
||||
*,
|
||||
retry_policy: Sequence[RetryPolicy] | None = None,
|
||||
cache_policy: CachePolicy | None = None,
|
||||
callbacks: Callbacks = None,
|
||||
# injected dependencies
|
||||
futures: weakref.ref[FuturesDict[asyncio.Future, asyncio.Event]],
|
||||
schedule_task: Callable[
|
||||
[PregelExecutableTask, int, Call | None],
|
||||
Awaitable[PregelExecutableTask | None],
|
||||
],
|
||||
submit: weakref.ref[Submit],
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
stream: bool = False,
|
||||
) -> None:
|
||||
try:
|
||||
fut: asyncio.Future | None = None
|
||||
# schedule PUSH tasks, collect futures
|
||||
scratchpad: PregelScratchpad = task().config[CONF][CONFIG_KEY_SCRATCHPAD] # type: ignore[union-attr]
|
||||
# schedule the next task, if the callback returns one
|
||||
if next_task := await schedule_task(
|
||||
task(), # type: ignore[arg-type]
|
||||
scratchpad.call_counter(),
|
||||
Call(
|
||||
func,
|
||||
input,
|
||||
retry_policy=retry_policy,
|
||||
cache_policy=cache_policy,
|
||||
callbacks=callbacks,
|
||||
),
|
||||
):
|
||||
if fut := next(
|
||||
(
|
||||
f
|
||||
for f, t in list(futures().items()) # type: ignore[union-attr]
|
||||
if t is not None and t == next_task.id
|
||||
),
|
||||
None,
|
||||
):
|
||||
# if the parent task was retried,
|
||||
# the next task might already be running
|
||||
pass
|
||||
elif next_task.writes:
|
||||
# if it already ran, return the result
|
||||
fut = asyncio.Future(loop=loop)
|
||||
ret = next((v for c, v in next_task.writes if c == RETURN), MISSING)
|
||||
if ret is not MISSING:
|
||||
fut.set_result(ret)
|
||||
elif exc := next((v for c, v in next_task.writes if c == ERROR), None):
|
||||
fut.set_exception(
|
||||
exc if isinstance(exc, BaseException) else Exception(exc)
|
||||
)
|
||||
else:
|
||||
fut.set_result(None)
|
||||
else:
|
||||
# schedule the next task
|
||||
fut = cast(
|
||||
asyncio.Future,
|
||||
submit()( # type: ignore[misc]
|
||||
arun_with_retry,
|
||||
next_task,
|
||||
retry_policy,
|
||||
stream=stream,
|
||||
configurable={
|
||||
CONFIG_KEY_CALL: partial(
|
||||
_acall,
|
||||
weakref.ref(next_task),
|
||||
stream=stream,
|
||||
futures=futures,
|
||||
schedule_task=schedule_task,
|
||||
submit=submit,
|
||||
loop=loop,
|
||||
),
|
||||
},
|
||||
__name__=next_task.name,
|
||||
__cancel_on_exit__=True,
|
||||
__reraise_on_exit__=False,
|
||||
# starting a new task in the next tick ensures
|
||||
# updates from this tick are committed/streamed first
|
||||
__next_tick__=True,
|
||||
),
|
||||
)
|
||||
# exceptions for call() tasks are raised into the parent task
|
||||
# so we should not re-raise at the end of the tick
|
||||
SKIP_RERAISE_SET.add(fut)
|
||||
futures()[fut] = next_task # type: ignore[index]
|
||||
if fut is not None:
|
||||
chain_future(fut, destination)
|
||||
else:
|
||||
destination.set_exception(RuntimeError("Task not scheduled"))
|
||||
except Exception as exc:
|
||||
destination.set_exception(exc)
|
||||
218
venv/Lib/site-packages/langgraph/pregel/_utils.py
Normal file
218
venv/Lib/site-packages/langgraph/pregel/_utils.py
Normal file
@@ -0,0 +1,218 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import re
|
||||
import textwrap
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import Runnable, RunnableLambda, RunnableSequence
|
||||
from langgraph.checkpoint.base import ChannelVersions
|
||||
from typing_extensions import override
|
||||
|
||||
from langgraph._internal._runnable import RunnableCallable, RunnableSeq
|
||||
from langgraph.pregel.protocol import PregelProtocol
|
||||
|
||||
|
||||
def get_new_channel_versions(
|
||||
previous_versions: ChannelVersions, current_versions: ChannelVersions
|
||||
) -> ChannelVersions:
|
||||
"""Get subset of current_versions that are newer than previous_versions."""
|
||||
if previous_versions:
|
||||
version_type = type(next(iter(current_versions.values()), None))
|
||||
null_version = version_type() # type: ignore[misc]
|
||||
new_versions = {
|
||||
k: v
|
||||
for k, v in current_versions.items()
|
||||
if v > previous_versions.get(k, null_version) # type: ignore[operator]
|
||||
}
|
||||
else:
|
||||
new_versions = current_versions
|
||||
|
||||
return new_versions
|
||||
|
||||
|
||||
def find_subgraph_pregel(candidate: Runnable) -> PregelProtocol | None:
|
||||
from langgraph.pregel import Pregel
|
||||
|
||||
candidates: list[Runnable] = [candidate]
|
||||
|
||||
for c in candidates:
|
||||
if (
|
||||
isinstance(c, PregelProtocol)
|
||||
# subgraphs that disabled checkpointing are not considered
|
||||
and (not isinstance(c, Pregel) or c.checkpointer is not False)
|
||||
):
|
||||
return c
|
||||
elif isinstance(c, RunnableSequence) or isinstance(c, RunnableSeq):
|
||||
candidates.extend(c.steps)
|
||||
elif isinstance(c, RunnableLambda):
|
||||
candidates.extend(c.deps)
|
||||
elif isinstance(c, RunnableCallable):
|
||||
if c.func is not None:
|
||||
candidates.extend(
|
||||
nl.__self__ if hasattr(nl, "__self__") else nl
|
||||
for nl in get_function_nonlocals(c.func)
|
||||
)
|
||||
elif c.afunc is not None:
|
||||
candidates.extend(
|
||||
nl.__self__ if hasattr(nl, "__self__") else nl
|
||||
for nl in get_function_nonlocals(c.afunc)
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_function_nonlocals(func: Callable) -> list[Any]:
|
||||
"""Get the nonlocal variables accessed by a function.
|
||||
|
||||
Args:
|
||||
func: The function to check.
|
||||
|
||||
Returns:
|
||||
List[Any]: The nonlocal variables accessed by the function.
|
||||
"""
|
||||
try:
|
||||
code = inspect.getsource(func)
|
||||
tree = ast.parse(textwrap.dedent(code))
|
||||
visitor = FunctionNonLocals()
|
||||
visitor.visit(tree)
|
||||
values: list[Any] = []
|
||||
closure = (
|
||||
inspect.getclosurevars(func.__wrapped__)
|
||||
if hasattr(func, "__wrapped__") and callable(func.__wrapped__)
|
||||
else inspect.getclosurevars(func)
|
||||
)
|
||||
candidates = {**closure.globals, **closure.nonlocals}
|
||||
for k, v in candidates.items():
|
||||
if k in visitor.nonlocals:
|
||||
values.append(v)
|
||||
for kk in visitor.nonlocals:
|
||||
if "." in kk and kk.startswith(k):
|
||||
vv = v
|
||||
for part in kk.split(".")[1:]:
|
||||
if vv is None:
|
||||
break
|
||||
else:
|
||||
try:
|
||||
vv = getattr(vv, part)
|
||||
except AttributeError:
|
||||
break
|
||||
else:
|
||||
values.append(vv)
|
||||
except (SyntaxError, TypeError, OSError, SystemError):
|
||||
return []
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class FunctionNonLocals(ast.NodeVisitor):
|
||||
"""Get the nonlocal variables accessed of a function."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.nonlocals: set[str] = set()
|
||||
|
||||
@override
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
||||
"""Visit a function definition.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
visitor = NonLocals()
|
||||
visitor.visit(node)
|
||||
self.nonlocals.update(visitor.loads - visitor.stores)
|
||||
|
||||
@override
|
||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
|
||||
"""Visit an async function definition.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
visitor = NonLocals()
|
||||
visitor.visit(node)
|
||||
self.nonlocals.update(visitor.loads - visitor.stores)
|
||||
|
||||
@override
|
||||
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||
"""Visit a lambda function.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
visitor = NonLocals()
|
||||
visitor.visit(node)
|
||||
self.nonlocals.update(visitor.loads - visitor.stores)
|
||||
|
||||
|
||||
class NonLocals(ast.NodeVisitor):
|
||||
"""Get nonlocal variables accessed."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.loads: set[str] = set()
|
||||
self.stores: set[str] = set()
|
||||
|
||||
@override
|
||||
def visit_Name(self, node: ast.Name) -> Any:
|
||||
"""Visit a name node.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
if isinstance(node.ctx, ast.Load):
|
||||
self.loads.add(node.id)
|
||||
elif isinstance(node.ctx, ast.Store):
|
||||
self.stores.add(node.id)
|
||||
|
||||
@override
|
||||
def visit_Attribute(self, node: ast.Attribute) -> Any:
|
||||
"""Visit an attribute node.
|
||||
|
||||
Args:
|
||||
node: The node to visit.
|
||||
|
||||
Returns:
|
||||
Any: The result of the visit.
|
||||
"""
|
||||
if isinstance(node.ctx, ast.Load):
|
||||
parent = node.value
|
||||
attr_expr = node.attr
|
||||
while isinstance(parent, ast.Attribute):
|
||||
attr_expr = parent.attr + "." + attr_expr
|
||||
parent = parent.value
|
||||
if isinstance(parent, ast.Name):
|
||||
self.loads.add(parent.id + "." + attr_expr)
|
||||
self.loads.discard(parent.id)
|
||||
elif isinstance(parent, ast.Call):
|
||||
if isinstance(parent.func, ast.Name):
|
||||
self.loads.add(parent.func.id)
|
||||
else:
|
||||
parent = parent.func
|
||||
attr_expr = ""
|
||||
while isinstance(parent, ast.Attribute):
|
||||
if attr_expr:
|
||||
attr_expr = parent.attr + "." + attr_expr
|
||||
else:
|
||||
attr_expr = parent.attr
|
||||
parent = parent.value
|
||||
if isinstance(parent, ast.Name):
|
||||
self.loads.add(parent.id + "." + attr_expr)
|
||||
|
||||
|
||||
def is_xxh3_128_hexdigest(value: str) -> bool:
|
||||
"""Check if the given string matches the format of xxh3_128_hexdigest."""
|
||||
return bool(re.fullmatch(r"[0-9a-f]{32}", value))
|
||||
120
venv/Lib/site-packages/langgraph/pregel/_validate.py
Normal file
120
venv/Lib/site-packages/langgraph/pregel/_validate.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from langgraph._internal._constants import RESERVED
|
||||
from langgraph.channels.base import BaseChannel
|
||||
from langgraph.managed.base import ManagedValueMapping
|
||||
from langgraph.pregel._read import PregelNode
|
||||
from langgraph.types import All
|
||||
|
||||
|
||||
def validate_graph(
|
||||
nodes: Mapping[str, PregelNode],
|
||||
channels: dict[str, BaseChannel],
|
||||
managed: ManagedValueMapping,
|
||||
input_channels: str | Sequence[str],
|
||||
output_channels: str | Sequence[str],
|
||||
stream_channels: str | Sequence[str] | None,
|
||||
interrupt_after_nodes: All | Sequence[str],
|
||||
interrupt_before_nodes: All | Sequence[str],
|
||||
) -> None:
|
||||
for chan in channels:
|
||||
if chan in RESERVED:
|
||||
raise ValueError(f"Channel name '{chan}' is reserved")
|
||||
for name in managed:
|
||||
if name in RESERVED:
|
||||
raise ValueError(f"Managed name '{name}' is reserved")
|
||||
|
||||
subscribed_channels = set[str]()
|
||||
for name, node in nodes.items():
|
||||
if name in RESERVED:
|
||||
raise ValueError(f"Node name '{name}' is reserved")
|
||||
if isinstance(node, PregelNode):
|
||||
subscribed_channels.update(node.triggers)
|
||||
if isinstance(node.channels, str):
|
||||
if node.channels not in channels:
|
||||
raise ValueError(
|
||||
f"Node {name} reads channel '{node.channels}' "
|
||||
f"not in known channels: '{repr(sorted(channels))[:100]}'"
|
||||
)
|
||||
else:
|
||||
for chan in node.channels:
|
||||
if chan not in channels and chan not in managed:
|
||||
raise ValueError(
|
||||
f"Node {name} reads channel '{chan}' "
|
||||
f"not in known channels: '{repr(sorted(channels))[:100]}'"
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Invalid node type {type(node)}, expected PregelNode or NodeBuilder"
|
||||
)
|
||||
|
||||
for chan in subscribed_channels:
|
||||
if chan not in channels:
|
||||
raise ValueError(
|
||||
f"Subscribed channel '{chan}' not "
|
||||
f"in known channels: '{repr(sorted(channels))[:100]}'"
|
||||
)
|
||||
|
||||
if isinstance(input_channels, str):
|
||||
if input_channels not in channels:
|
||||
raise ValueError(
|
||||
f"Input channel '{input_channels}' not "
|
||||
f"in known channels: '{repr(sorted(channels))[:100]}'"
|
||||
)
|
||||
if input_channels not in subscribed_channels:
|
||||
raise ValueError(
|
||||
f"Input channel {input_channels} is not subscribed to by any node"
|
||||
)
|
||||
else:
|
||||
for chan in input_channels:
|
||||
if chan not in channels:
|
||||
raise ValueError(
|
||||
f"Input channel '{chan}' not in '{repr(sorted(channels))[:100]}'"
|
||||
)
|
||||
if all(chan not in subscribed_channels for chan in input_channels):
|
||||
raise ValueError(
|
||||
f"None of the input channels {input_channels} "
|
||||
f"are subscribed to by any node"
|
||||
)
|
||||
|
||||
all_output_channels = set[str]()
|
||||
if isinstance(output_channels, str):
|
||||
all_output_channels.add(output_channels)
|
||||
else:
|
||||
all_output_channels.update(output_channels)
|
||||
if isinstance(stream_channels, str):
|
||||
all_output_channels.add(stream_channels)
|
||||
elif stream_channels is not None:
|
||||
all_output_channels.update(stream_channels)
|
||||
|
||||
for chan in all_output_channels:
|
||||
if chan not in channels:
|
||||
raise ValueError(
|
||||
f"Output channel '{chan}' not "
|
||||
f"in known channels: '{repr(sorted(channels))[:100]}'"
|
||||
)
|
||||
|
||||
if interrupt_after_nodes != "*":
|
||||
for n in interrupt_after_nodes:
|
||||
if n not in nodes:
|
||||
raise ValueError(f"Node {n} not in nodes")
|
||||
if interrupt_before_nodes != "*":
|
||||
for n in interrupt_before_nodes:
|
||||
if n not in nodes:
|
||||
raise ValueError(f"Node {n} not in nodes")
|
||||
|
||||
|
||||
def validate_keys(
|
||||
keys: str | Sequence[str] | None,
|
||||
channels: Mapping[str, Any],
|
||||
) -> None:
|
||||
if isinstance(keys, str):
|
||||
if keys not in channels:
|
||||
raise ValueError(f"Key {keys} not in channels")
|
||||
elif keys is not None:
|
||||
for chan in keys:
|
||||
if chan not in channels:
|
||||
raise ValueError(f"Key {chan} not in channels")
|
||||
192
venv/Lib/site-packages/langgraph/pregel/_write.py
Normal file
192
venv/Lib/site-packages/langgraph/pregel/_write.py
Normal file
@@ -0,0 +1,192 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
NamedTuple,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
|
||||
from langgraph._internal._constants import CONF, CONFIG_KEY_SEND, TASKS
|
||||
from langgraph._internal._runnable import RunnableCallable
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.errors import InvalidUpdateError
|
||||
from langgraph.types import Send
|
||||
|
||||
TYPE_SEND = Callable[[Sequence[tuple[str, Any]]], None]
|
||||
R = TypeVar("R", bound=Runnable)
|
||||
|
||||
SKIP_WRITE = object()
|
||||
PASSTHROUGH = object()
|
||||
|
||||
|
||||
class ChannelWriteEntry(NamedTuple):
|
||||
channel: str
|
||||
"""Channel name to write to."""
|
||||
value: Any = PASSTHROUGH
|
||||
"""Value to write, or PASSTHROUGH to use the input."""
|
||||
skip_none: bool = False
|
||||
"""Whether to skip writing if the value is None."""
|
||||
mapper: Callable | None = None
|
||||
"""Function to transform the value before writing."""
|
||||
|
||||
|
||||
class ChannelWriteTupleEntry(NamedTuple):
|
||||
mapper: Callable[[Any], Sequence[tuple[str, Any]] | None]
|
||||
"""Function to extract tuples from value."""
|
||||
value: Any = PASSTHROUGH
|
||||
"""Value to write, or PASSTHROUGH to use the input."""
|
||||
static: Sequence[tuple[str, Any, str | None]] | None = None
|
||||
"""Optional, declared writes for static analysis."""
|
||||
|
||||
|
||||
class ChannelWrite(RunnableCallable):
|
||||
"""Implements the logic for sending writes to CONFIG_KEY_SEND.
|
||||
Can be used as a runnable or as a static method to call imperatively."""
|
||||
|
||||
writes: list[ChannelWriteEntry | ChannelWriteTupleEntry | Send]
|
||||
"""Sequence of write entries or Send objects to write."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
|
||||
*,
|
||||
tags: Sequence[str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
func=self._write,
|
||||
afunc=self._awrite,
|
||||
name=None,
|
||||
tags=tags,
|
||||
trace=False,
|
||||
)
|
||||
self.writes = cast(
|
||||
list[ChannelWriteEntry | ChannelWriteTupleEntry | Send], writes
|
||||
)
|
||||
|
||||
def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str:
|
||||
if not name:
|
||||
name = f"ChannelWrite<{','.join(w.channel if isinstance(w, ChannelWriteEntry) else '...' if isinstance(w, ChannelWriteTupleEntry) else w.node for w in self.writes)}>"
|
||||
return super().get_name(suffix, name=name)
|
||||
|
||||
def _write(self, input: Any, config: RunnableConfig) -> None:
|
||||
writes = [
|
||||
ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper)
|
||||
if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH
|
||||
else ChannelWriteTupleEntry(write.mapper, input)
|
||||
if isinstance(write, ChannelWriteTupleEntry) and write.value is PASSTHROUGH
|
||||
else write
|
||||
for write in self.writes
|
||||
]
|
||||
self.do_write(
|
||||
config,
|
||||
writes,
|
||||
)
|
||||
return input
|
||||
|
||||
async def _awrite(self, input: Any, config: RunnableConfig) -> None:
|
||||
writes = [
|
||||
ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper)
|
||||
if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH
|
||||
else ChannelWriteTupleEntry(write.mapper, input)
|
||||
if isinstance(write, ChannelWriteTupleEntry) and write.value is PASSTHROUGH
|
||||
else write
|
||||
for write in self.writes
|
||||
]
|
||||
self.do_write(
|
||||
config,
|
||||
writes,
|
||||
)
|
||||
return input
|
||||
|
||||
@staticmethod
|
||||
def do_write(
|
||||
config: RunnableConfig,
|
||||
writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
|
||||
allow_passthrough: bool = True,
|
||||
) -> None:
|
||||
# validate
|
||||
for w in writes:
|
||||
if isinstance(w, ChannelWriteEntry):
|
||||
if w.channel == TASKS:
|
||||
raise InvalidUpdateError(
|
||||
"Cannot write to the reserved channel TASKS"
|
||||
)
|
||||
if w.value is PASSTHROUGH and not allow_passthrough:
|
||||
raise InvalidUpdateError("PASSTHROUGH value must be replaced")
|
||||
if isinstance(w, ChannelWriteTupleEntry):
|
||||
if w.value is PASSTHROUGH and not allow_passthrough:
|
||||
raise InvalidUpdateError("PASSTHROUGH value must be replaced")
|
||||
# if we want to persist writes found before hitting a ParentCommand
|
||||
# can move this to a finally block
|
||||
write: TYPE_SEND = config[CONF][CONFIG_KEY_SEND]
|
||||
write(_assemble_writes(writes))
|
||||
|
||||
@staticmethod
|
||||
def is_writer(runnable: Runnable) -> bool:
|
||||
"""Used by PregelNode to distinguish between writers and other runnables."""
|
||||
return (
|
||||
isinstance(runnable, ChannelWrite)
|
||||
or getattr(runnable, "_is_channel_writer", MISSING) is not MISSING
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_static_writes(
|
||||
runnable: Runnable,
|
||||
) -> Sequence[tuple[str, Any, str | None]] | None:
|
||||
"""Used to get conditional writes a writer declares for static analysis."""
|
||||
if isinstance(runnable, ChannelWrite):
|
||||
return [
|
||||
w
|
||||
for entry in runnable.writes
|
||||
if isinstance(entry, ChannelWriteTupleEntry) and entry.static
|
||||
for w in entry.static
|
||||
] or None
|
||||
elif writes := getattr(runnable, "_is_channel_writer", MISSING):
|
||||
if writes is not MISSING:
|
||||
writes = cast(
|
||||
Sequence[tuple[ChannelWriteEntry | Send, str | None]],
|
||||
writes,
|
||||
)
|
||||
entries = [e for e, _ in writes]
|
||||
labels = [la for _, la in writes]
|
||||
return [(*t, la) for t, la in zip(_assemble_writes(entries), labels)]
|
||||
|
||||
@staticmethod
|
||||
def register_writer(
|
||||
runnable: R,
|
||||
static: Sequence[tuple[ChannelWriteEntry | Send, str | None]] | None = None,
|
||||
) -> R:
|
||||
"""Used to mark a runnable as a writer, so that it can be detected by is_writer.
|
||||
Instances of ChannelWrite are automatically marked as writers.
|
||||
Optionally, a list of declared writes can be passed for static analysis."""
|
||||
# using object.__setattr__ to work around objects that override __setattr__
|
||||
# eg. pydantic models and dataclasses
|
||||
object.__setattr__(runnable, "_is_channel_writer", static)
|
||||
return runnable
|
||||
|
||||
|
||||
def _assemble_writes(
|
||||
writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
|
||||
) -> list[tuple[str, Any]]:
|
||||
"""Assembles the writes into a list of tuples."""
|
||||
tuples: list[tuple[str, Any]] = []
|
||||
for w in writes:
|
||||
if isinstance(w, Send):
|
||||
tuples.append((TASKS, w))
|
||||
elif isinstance(w, ChannelWriteTupleEntry):
|
||||
if ww := w.mapper(w.value):
|
||||
tuples.extend(ww)
|
||||
elif isinstance(w, ChannelWriteEntry):
|
||||
value = w.mapper(w.value) if w.mapper is not None else w.value
|
||||
if value is SKIP_WRITE:
|
||||
continue
|
||||
if w.skip_none and value is None:
|
||||
continue
|
||||
tuples.append((w.channel, value))
|
||||
else:
|
||||
raise ValueError(f"Invalid write entry: {w}")
|
||||
return tuples
|
||||
308
venv/Lib/site-packages/langgraph/pregel/debug.py
Normal file
308
venv/Lib/site-packages/langgraph/pregel/debug.py
Normal file
@@ -0,0 +1,308 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Iterator, Mapping, Sequence
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.checkpoint.base import CheckpointMetadata, PendingWrite
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langgraph._internal._config import patch_checkpoint_map
|
||||
from langgraph._internal._constants import (
|
||||
CONF,
|
||||
CONFIG_KEY_CHECKPOINT_NS,
|
||||
ERROR,
|
||||
INTERRUPT,
|
||||
NS_END,
|
||||
NS_SEP,
|
||||
RETURN,
|
||||
)
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel
|
||||
from langgraph.constants import TAG_HIDDEN
|
||||
from langgraph.pregel._io import read_channels
|
||||
from langgraph.types import PregelExecutableTask, PregelTask, StateSnapshot
|
||||
|
||||
__all__ = ("TaskPayload", "TaskResultPayload", "CheckpointTask", "CheckpointPayload")
|
||||
|
||||
|
||||
class TaskPayload(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
input: Any
|
||||
triggers: list[str]
|
||||
|
||||
|
||||
class TaskResultPayload(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
error: str | None
|
||||
interrupts: list[dict]
|
||||
result: dict[str, Any]
|
||||
|
||||
|
||||
class CheckpointTask(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
error: str | None
|
||||
interrupts: list[dict]
|
||||
state: StateSnapshot | RunnableConfig | None
|
||||
|
||||
|
||||
class CheckpointPayload(TypedDict):
|
||||
config: RunnableConfig | None
|
||||
metadata: CheckpointMetadata
|
||||
values: dict[str, Any]
|
||||
next: list[str]
|
||||
parent_config: RunnableConfig | None
|
||||
tasks: list[CheckpointTask]
|
||||
|
||||
|
||||
TASK_NAMESPACE = UUID("6ba7b831-9dad-11d1-80b4-00c04fd430c8")
|
||||
|
||||
|
||||
def map_debug_tasks(tasks: Iterable[PregelExecutableTask]) -> Iterator[TaskPayload]:
|
||||
"""Produce "task" events for stream_mode=debug."""
|
||||
for task in tasks:
|
||||
if task.config is not None and TAG_HIDDEN in task.config.get("tags", []):
|
||||
continue
|
||||
|
||||
yield {
|
||||
"id": task.id,
|
||||
"name": task.name,
|
||||
"input": task.input,
|
||||
"triggers": task.triggers,
|
||||
}
|
||||
|
||||
|
||||
def is_multiple_channel_write(value: Any) -> bool:
|
||||
"""Return True if the payload already wraps multiple writes from the same channel."""
|
||||
return (
|
||||
isinstance(value, dict)
|
||||
and "$writes" in value
|
||||
and isinstance(value["$writes"], list)
|
||||
)
|
||||
|
||||
|
||||
def map_task_result_writes(writes: Sequence[tuple[str, Any]]) -> dict[str, Any]:
|
||||
"""Folds task writes into a result dict and aggregates multiple writes to the same channel.
|
||||
|
||||
If the channel contains a single write, we record the write in the result dict as `{channel: write}`
|
||||
If the channel contains multiple writes, we record the writes in the result dict as `{channel: {'$writes': [write1, write2, ...]}}`"""
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for channel, value in writes:
|
||||
existing = result.get(channel)
|
||||
|
||||
if existing is not None:
|
||||
channel_writes = (
|
||||
existing["$writes"]
|
||||
if is_multiple_channel_write(existing)
|
||||
else [existing]
|
||||
)
|
||||
channel_writes.append(value)
|
||||
result[channel] = {"$writes": channel_writes}
|
||||
else:
|
||||
result[channel] = value
|
||||
return result
|
||||
|
||||
|
||||
def map_debug_task_results(
|
||||
task_tup: tuple[PregelExecutableTask, Sequence[tuple[str, Any]]],
|
||||
stream_keys: str | Sequence[str],
|
||||
) -> Iterator[TaskResultPayload]:
|
||||
"""Produce "task_result" events for stream_mode=debug."""
|
||||
stream_channels_list = (
|
||||
[stream_keys] if isinstance(stream_keys, str) else stream_keys
|
||||
)
|
||||
task, writes = task_tup
|
||||
yield {
|
||||
"id": task.id,
|
||||
"name": task.name,
|
||||
"error": next((w[1] for w in writes if w[0] == ERROR), None),
|
||||
"result": map_task_result_writes(
|
||||
[w for w in writes if w[0] in stream_channels_list or w[0] == RETURN]
|
||||
),
|
||||
"interrupts": [
|
||||
asdict(v)
|
||||
for w in writes
|
||||
if w[0] == INTERRUPT
|
||||
for v in (w[1] if isinstance(w[1], Sequence) else [w[1]])
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def rm_pregel_keys(config: RunnableConfig | None) -> RunnableConfig | None:
|
||||
"""Remove pregel-specific keys from the config."""
|
||||
if config is None:
|
||||
return config
|
||||
return {
|
||||
"configurable": {
|
||||
k: v
|
||||
for k, v in config.get("configurable", {}).items()
|
||||
if not k.startswith("__pregel_")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def map_debug_checkpoint(
|
||||
config: RunnableConfig,
|
||||
channels: Mapping[str, BaseChannel],
|
||||
stream_channels: str | Sequence[str],
|
||||
metadata: CheckpointMetadata,
|
||||
tasks: Iterable[PregelExecutableTask],
|
||||
pending_writes: list[PendingWrite],
|
||||
parent_config: RunnableConfig | None,
|
||||
output_keys: str | Sequence[str],
|
||||
) -> Iterator[CheckpointPayload]:
|
||||
"""Produce "checkpoint" events for stream_mode=debug."""
|
||||
|
||||
parent_ns = config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
|
||||
task_states: dict[str, RunnableConfig | StateSnapshot] = {}
|
||||
|
||||
for task in tasks:
|
||||
if not task.subgraphs:
|
||||
continue
|
||||
|
||||
# assemble checkpoint_ns for this task
|
||||
task_ns = f"{task.name}{NS_END}{task.id}"
|
||||
if parent_ns:
|
||||
task_ns = f"{parent_ns}{NS_SEP}{task_ns}"
|
||||
|
||||
# set config as signal that subgraph checkpoints exist
|
||||
task_states[task.id] = {
|
||||
CONF: {
|
||||
"thread_id": config[CONF]["thread_id"],
|
||||
CONFIG_KEY_CHECKPOINT_NS: task_ns,
|
||||
}
|
||||
}
|
||||
|
||||
yield {
|
||||
"config": rm_pregel_keys(patch_checkpoint_map(config, metadata)),
|
||||
"parent_config": rm_pregel_keys(patch_checkpoint_map(parent_config, metadata)),
|
||||
"values": read_channels(channels, stream_channels),
|
||||
"metadata": metadata,
|
||||
"next": [t.name for t in tasks],
|
||||
"tasks": [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"error": t.error,
|
||||
"state": t.state,
|
||||
}
|
||||
if t.error
|
||||
else {
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"result": t.result,
|
||||
"interrupts": tuple(asdict(i) for i in t.interrupts),
|
||||
"state": t.state,
|
||||
}
|
||||
if t.result
|
||||
else {
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"interrupts": tuple(asdict(i) for i in t.interrupts),
|
||||
"state": t.state,
|
||||
}
|
||||
for t in tasks_w_writes(tasks, pending_writes, task_states, output_keys)
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def tasks_w_writes(
|
||||
tasks: Iterable[PregelTask | PregelExecutableTask],
|
||||
pending_writes: list[PendingWrite] | None,
|
||||
states: dict[str, RunnableConfig | StateSnapshot] | None,
|
||||
output_keys: str | Sequence[str],
|
||||
) -> tuple[PregelTask, ...]:
|
||||
"""Apply writes / subgraph states to tasks to be returned in a StateSnapshot."""
|
||||
pending_writes = pending_writes or []
|
||||
out: list[PregelTask] = []
|
||||
for task in tasks:
|
||||
rtn = next(
|
||||
(
|
||||
val
|
||||
for tid, chan, val in pending_writes
|
||||
if tid == task.id and chan == RETURN
|
||||
),
|
||||
MISSING,
|
||||
)
|
||||
task_error = next(
|
||||
(exc for tid, n, exc in pending_writes if tid == task.id and n == ERROR),
|
||||
None,
|
||||
)
|
||||
task_interrupts = tuple(
|
||||
v
|
||||
for tid, n, vv in pending_writes
|
||||
if tid == task.id and n == INTERRUPT
|
||||
for v in (vv if isinstance(vv, Sequence) else [vv])
|
||||
)
|
||||
|
||||
task_writes = [
|
||||
(chan, val)
|
||||
for tid, chan, val in pending_writes
|
||||
if tid == task.id and chan not in (ERROR, INTERRUPT, RETURN)
|
||||
]
|
||||
|
||||
if rtn is not MISSING:
|
||||
task_result = rtn
|
||||
elif isinstance(output_keys, str):
|
||||
# unwrap single channel writes to just the write value
|
||||
filtered_writes = [
|
||||
(chan, val) for chan, val in task_writes if chan == output_keys
|
||||
]
|
||||
mapped_writes = map_task_result_writes(filtered_writes)
|
||||
task_result = mapped_writes.get(str(output_keys)) if mapped_writes else None
|
||||
else:
|
||||
if isinstance(output_keys, str):
|
||||
output_keys = [output_keys]
|
||||
# map task result writes to the desired output channels
|
||||
# repeateed writes to the same channel are aggregated into: {'$writes': [write1, write2, ...]}
|
||||
filtered_writes = [
|
||||
(chan, val) for chan, val in task_writes if chan in output_keys
|
||||
]
|
||||
mapped_writes = map_task_result_writes(filtered_writes)
|
||||
task_result = mapped_writes if filtered_writes else {}
|
||||
|
||||
has_writes = rtn is not MISSING or any(
|
||||
w[0] == task.id and w[1] not in (ERROR, INTERRUPT) for w in pending_writes
|
||||
)
|
||||
|
||||
out.append(
|
||||
PregelTask(
|
||||
task.id,
|
||||
task.name,
|
||||
task.path,
|
||||
task_error,
|
||||
task_interrupts,
|
||||
states.get(task.id) if states else None,
|
||||
task_result if has_writes else None,
|
||||
)
|
||||
)
|
||||
return tuple(out)
|
||||
|
||||
|
||||
COLOR_MAPPING = {
|
||||
"black": "0;30",
|
||||
"red": "0;31",
|
||||
"green": "0;32",
|
||||
"yellow": "0;33",
|
||||
"blue": "0;34",
|
||||
"magenta": "0;35",
|
||||
"cyan": "0;36",
|
||||
"white": "0;37",
|
||||
"gray": "1;30",
|
||||
}
|
||||
|
||||
|
||||
def get_colored_text(text: str, color: str) -> str:
|
||||
"""Get colored text."""
|
||||
return f"\033[1;3{COLOR_MAPPING[color]}m{text}\033[0m"
|
||||
|
||||
|
||||
def get_bolded_text(text: str) -> str:
|
||||
"""Get bolded text."""
|
||||
return f"\033[1m{text}\033[0m"
|
||||
3345
venv/Lib/site-packages/langgraph/pregel/main.py
Normal file
3345
venv/Lib/site-packages/langgraph/pregel/main.py
Normal file
File diff suppressed because it is too large
Load Diff
164
venv/Lib/site-packages/langgraph/pregel/protocol.py
Normal file
164
venv/Lib/site-packages/langgraph/pregel/protocol.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
|
||||
from typing import Any, Generic, cast
|
||||
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
from langchain_core.runnables.graph import Graph as DrawableGraph
|
||||
from typing_extensions import Self
|
||||
|
||||
from langgraph.types import All, Command, StateSnapshot, StateUpdate, StreamMode
|
||||
from langgraph.typing import ContextT, InputT, OutputT, StateT
|
||||
|
||||
__all__ = ("PregelProtocol", "StreamProtocol")
|
||||
|
||||
|
||||
class PregelProtocol(Runnable[InputT, Any], Generic[StateT, ContextT, InputT, OutputT]):
|
||||
@abstractmethod
|
||||
def with_config(
|
||||
self, config: RunnableConfig | None = None, **kwargs: Any
|
||||
) -> Self: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_graph(
|
||||
self,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
xray: int | bool = False,
|
||||
) -> DrawableGraph: ...
|
||||
|
||||
@abstractmethod
|
||||
async def aget_graph(
|
||||
self,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
xray: int | bool = False,
|
||||
) -> DrawableGraph: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_state(
|
||||
self, config: RunnableConfig, *, subgraphs: bool = False
|
||||
) -> StateSnapshot: ...
|
||||
|
||||
@abstractmethod
|
||||
async def aget_state(
|
||||
self, config: RunnableConfig, *, subgraphs: bool = False
|
||||
) -> StateSnapshot: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_state_history(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
filter: dict[str, Any] | None = None,
|
||||
before: RunnableConfig | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[StateSnapshot]: ...
|
||||
|
||||
@abstractmethod
|
||||
def aget_state_history(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
filter: dict[str, Any] | None = None,
|
||||
before: RunnableConfig | None = None,
|
||||
limit: int | None = None,
|
||||
) -> AsyncIterator[StateSnapshot]: ...
|
||||
|
||||
@abstractmethod
|
||||
def bulk_update_state(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
updates: Sequence[Sequence[StateUpdate]],
|
||||
) -> RunnableConfig: ...
|
||||
|
||||
@abstractmethod
|
||||
async def abulk_update_state(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
updates: Sequence[Sequence[StateUpdate]],
|
||||
) -> RunnableConfig: ...
|
||||
|
||||
@abstractmethod
|
||||
def update_state(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
values: dict[str, Any] | Any | None,
|
||||
as_node: str | None = None,
|
||||
) -> RunnableConfig: ...
|
||||
|
||||
@abstractmethod
|
||||
async def aupdate_state(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
values: dict[str, Any] | Any | None,
|
||||
as_node: str | None = None,
|
||||
) -> RunnableConfig: ...
|
||||
|
||||
@abstractmethod
|
||||
def stream(
|
||||
self,
|
||||
input: InputT | Command | None,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
context: ContextT | None = None,
|
||||
stream_mode: StreamMode | list[StreamMode] | None = None,
|
||||
interrupt_before: All | Sequence[str] | None = None,
|
||||
interrupt_after: All | Sequence[str] | None = None,
|
||||
subgraphs: bool = False,
|
||||
) -> Iterator[dict[str, Any] | Any]: ...
|
||||
|
||||
@abstractmethod
|
||||
def astream(
|
||||
self,
|
||||
input: InputT | Command | None,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
context: ContextT | None = None,
|
||||
stream_mode: StreamMode | list[StreamMode] | None = None,
|
||||
interrupt_before: All | Sequence[str] | None = None,
|
||||
interrupt_after: All | Sequence[str] | None = None,
|
||||
subgraphs: bool = False,
|
||||
) -> AsyncIterator[dict[str, Any] | Any]: ...
|
||||
|
||||
@abstractmethod
|
||||
def invoke(
|
||||
self,
|
||||
input: InputT | Command | None,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
context: ContextT | None = None,
|
||||
interrupt_before: All | Sequence[str] | None = None,
|
||||
interrupt_after: All | Sequence[str] | None = None,
|
||||
) -> dict[str, Any] | Any: ...
|
||||
|
||||
@abstractmethod
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: InputT | Command | None,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
context: ContextT | None = None,
|
||||
interrupt_before: All | Sequence[str] | None = None,
|
||||
interrupt_after: All | Sequence[str] | None = None,
|
||||
) -> dict[str, Any] | Any: ...
|
||||
|
||||
|
||||
StreamChunk = tuple[tuple[str, ...], str, Any]
|
||||
|
||||
|
||||
class StreamProtocol:
|
||||
__slots__ = ("modes", "__call__")
|
||||
|
||||
modes: set[StreamMode]
|
||||
|
||||
__call__: Callable[[Self, StreamChunk], None]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
__call__: Callable[[StreamChunk], None],
|
||||
modes: set[StreamMode],
|
||||
) -> None:
|
||||
self.__call__ = cast(Callable[[Self, StreamChunk], None], __call__)
|
||||
self.modes = modes
|
||||
1015
venv/Lib/site-packages/langgraph/pregel/remote.py
Normal file
1015
venv/Lib/site-packages/langgraph/pregel/remote.py
Normal file
File diff suppressed because it is too large
Load Diff
38
venv/Lib/site-packages/langgraph/pregel/types.py
Normal file
38
venv/Lib/site-packages/langgraph/pregel/types.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Re-export types moved to langgraph.types"""
|
||||
|
||||
from langgraph.types import (
|
||||
All,
|
||||
CachePolicy,
|
||||
PregelExecutableTask,
|
||||
PregelTask,
|
||||
RetryPolicy,
|
||||
StateSnapshot,
|
||||
StateUpdate,
|
||||
StreamMode,
|
||||
StreamWriter,
|
||||
default_retry_on,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"All",
|
||||
"StateUpdate",
|
||||
"CachePolicy",
|
||||
"PregelExecutableTask",
|
||||
"PregelTask",
|
||||
"RetryPolicy",
|
||||
"StateSnapshot",
|
||||
"StreamMode",
|
||||
"StreamWriter",
|
||||
"default_retry_on",
|
||||
]
|
||||
|
||||
from warnings import warn
|
||||
|
||||
from langgraph.warnings import LangGraphDeprecatedSinceV10
|
||||
|
||||
warn(
|
||||
"Importing from langgraph.pregel.types is deprecated. "
|
||||
"Please use 'from langgraph.types import ...' instead.",
|
||||
LangGraphDeprecatedSinceV10,
|
||||
stacklevel=2,
|
||||
)
|
||||
Reference in New Issue
Block a user