initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
"""Internal modules for LangGraph.
This module is not part of the public API, and thus stability is not guaranteed.
"""

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from collections.abc import Hashable, Mapping, Sequence
from typing import Any
def _freeze(obj: Any, depth: int = 10) -> Hashable:
if isinstance(obj, Hashable) or depth <= 0:
# already hashable, no need to freeze
return obj
elif isinstance(obj, Mapping):
# sort keys so {"a":1,"b":2} == {"b":2,"a":1}
return tuple(sorted((k, _freeze(v, depth - 1)) for k, v in obj.items()))
elif isinstance(obj, Sequence):
return tuple(_freeze(x, depth - 1) for x in obj)
# numpy / pandas etc. can provide their own .tobytes()
elif hasattr(obj, "tobytes"):
return (
type(obj).__name__,
obj.tobytes(),
obj.shape if hasattr(obj, "shape") else None,
)
return obj # strings, ints, dataclasses with frozen=True, etc.
def default_cache_key(*args: Any, **kwargs: Any) -> str | bytes:
"""Default cache key function that uses the arguments and keyword arguments to generate a hashable key."""
import pickle
# protocol 5 strikes a good balance between speed and size
return pickle.dumps((_freeze(args), _freeze(kwargs)), protocol=5, fix_imports=False)

View File

@@ -0,0 +1,329 @@
from __future__ import annotations
from collections import ChainMap
from collections.abc import Mapping, Sequence
from os import getenv
from typing import Any, cast
from langchain_core.callbacks import (
AsyncCallbackManager,
BaseCallbackManager,
CallbackManager,
Callbacks,
)
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import (
CONFIG_KEYS,
COPIABLE_KEYS,
var_child_runnable_config,
)
from langgraph.checkpoint.base import CheckpointMetadata
from langgraph._internal._constants import (
CONF,
CONFIG_KEY_CHECKPOINT_ID,
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_CHECKPOINT_NS,
NS_END,
NS_SEP,
)
DEFAULT_RECURSION_LIMIT = int(getenv("LANGGRAPH_DEFAULT_RECURSION_LIMIT", "10000"))
def recast_checkpoint_ns(ns: str) -> str:
"""Remove task IDs from checkpoint namespace.
Args:
ns: The checkpoint namespace with task IDs.
Returns:
str: The checkpoint namespace without task IDs.
"""
return NS_SEP.join(
part.split(NS_END)[0] for part in ns.split(NS_SEP) if not part.isdigit()
)
def patch_configurable(
config: RunnableConfig | None, patch: dict[str, Any]
) -> RunnableConfig:
if config is None:
return {CONF: patch}
elif CONF not in config:
return {**config, CONF: patch}
else:
return {**config, CONF: {**config[CONF], **patch}}
def patch_checkpoint_map(
config: RunnableConfig | None, metadata: CheckpointMetadata | None
) -> RunnableConfig:
if config is None:
return config
elif parents := (metadata.get("parents") if metadata else None):
conf = config[CONF]
return patch_configurable(
config,
{
CONFIG_KEY_CHECKPOINT_MAP: {
**parents,
conf[CONFIG_KEY_CHECKPOINT_NS]: conf[CONFIG_KEY_CHECKPOINT_ID],
},
},
)
else:
return config
def merge_configs(*configs: RunnableConfig | None) -> RunnableConfig:
"""Merge multiple configs into one.
Args:
*configs: The configs to merge.
Returns:
RunnableConfig: The merged config.
"""
base: RunnableConfig = {}
# Even though the keys aren't literals, this is correct
# because both dicts are the same type
for config in configs:
if config is None:
continue
for key, value in config.items():
if not value:
continue
if key == "metadata":
if base_value := base.get(key):
base[key] = {**base_value, **value} # type: ignore
else:
base[key] = value # type: ignore[literal-required]
elif key == "tags":
if base_value := base.get(key):
base[key] = [*base_value, *value] # type: ignore
else:
base[key] = value # type: ignore[literal-required]
elif key == CONF:
if base_value := base.get(key):
base[key] = {**base_value, **value} # type: ignore[dict-item]
else:
base[key] = value
elif key == "callbacks":
base_callbacks = base.get("callbacks")
# callbacks can be either None, list[handler] or manager
# so merging two callbacks values has 6 cases
if isinstance(value, list):
if base_callbacks is None:
base["callbacks"] = value.copy()
elif isinstance(base_callbacks, list):
base["callbacks"] = base_callbacks + value
else:
# base_callbacks is a manager
mngr = base_callbacks.copy()
for callback in value:
mngr.add_handler(callback, inherit=True)
base["callbacks"] = mngr
elif isinstance(value, BaseCallbackManager):
# value is a manager
if base_callbacks is None:
base["callbacks"] = value.copy()
elif isinstance(base_callbacks, list):
mngr = value.copy()
for callback in base_callbacks:
mngr.add_handler(callback, inherit=True)
base["callbacks"] = mngr
else:
# base_callbacks is also a manager
base["callbacks"] = base_callbacks.merge(value)
else:
raise NotImplementedError
elif key == "recursion_limit":
if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT:
base["recursion_limit"] = config["recursion_limit"]
else:
base[key] = config[key] # type: ignore[literal-required]
if CONF not in base:
base[CONF] = {}
return base
def patch_config(
config: RunnableConfig | None,
*,
callbacks: Callbacks = None,
recursion_limit: int | None = None,
max_concurrency: int | None = None,
run_name: str | None = None,
configurable: dict[str, Any] | None = None,
) -> RunnableConfig:
"""Patch a config with new values.
Args:
config: The config to patch.
callbacks: The callbacks to set.
recursion_limit: The recursion limit to set.
max_concurrency: The max number of concurrent steps to run, which also applies to parallelized steps.
run_name: The run name to set.
configurable: The configurable to set.
Returns:
RunnableConfig: The patched config.
"""
config = config.copy() if config is not None else {}
if callbacks is not None:
# If we're replacing callbacks, we need to unset run_name
# As that should apply only to the same run as the original callbacks
config["callbacks"] = callbacks
if "run_name" in config:
del config["run_name"]
if "run_id" in config:
del config["run_id"]
if recursion_limit is not None:
config["recursion_limit"] = recursion_limit
if max_concurrency is not None:
config["max_concurrency"] = max_concurrency
if run_name is not None:
config["run_name"] = run_name
if configurable is not None:
config[CONF] = {**config.get(CONF, {}), **configurable}
return config
def get_callback_manager_for_config(
config: RunnableConfig, tags: Sequence[str] | None = None
) -> CallbackManager:
"""Get a callback manager for a config.
Args:
config: The config.
Returns:
CallbackManager: The callback manager.
"""
from langchain_core.callbacks.manager import CallbackManager
# merge tags
all_tags = config.get("tags")
if all_tags is not None and tags is not None:
all_tags = [*all_tags, *tags]
elif tags is not None:
all_tags = list(tags)
# use existing callbacks if they exist
if (callbacks := config.get("callbacks")) and isinstance(
callbacks, CallbackManager
):
if all_tags:
callbacks.add_tags(all_tags)
if metadata := config.get("metadata"):
callbacks.add_metadata(metadata)
return callbacks
else:
# otherwise create a new manager
return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=all_tags,
inheritable_metadata=config.get("metadata"),
)
def get_async_callback_manager_for_config(
config: RunnableConfig,
tags: Sequence[str] | None = None,
) -> AsyncCallbackManager:
"""Get an async callback manager for a config.
Args:
config: The config.
Returns:
AsyncCallbackManager: The async callback manager.
"""
from langchain_core.callbacks.manager import AsyncCallbackManager
# merge tags
all_tags = config.get("tags")
if all_tags is not None and tags is not None:
all_tags = [*all_tags, *tags]
elif tags is not None:
all_tags = list(tags)
# use existing callbacks if they exist
if (callbacks := config.get("callbacks")) and isinstance(
callbacks, AsyncCallbackManager
):
if all_tags:
callbacks.add_tags(all_tags)
if metadata := config.get("metadata"):
callbacks.add_metadata(metadata)
return callbacks
else:
# otherwise create a new manager
return AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=all_tags,
inheritable_metadata=config.get("metadata"),
)
def _is_not_empty(value: Any) -> bool:
if isinstance(value, (list, tuple, dict)):
return len(value) > 0
else:
return value is not None
def ensure_config(*configs: RunnableConfig | None) -> RunnableConfig:
"""Return a config with all keys, merging any provided configs.
Args:
*configs: Configs to merge before ensuring defaults.
Returns:
RunnableConfig: The merged and ensured config.
"""
empty = RunnableConfig(
tags=[],
metadata=ChainMap(),
callbacks=None,
recursion_limit=DEFAULT_RECURSION_LIMIT,
configurable={},
)
if var_config := var_child_runnable_config.get():
empty.update(
{
k: v.copy() if k in COPIABLE_KEYS else v # type: ignore[attr-defined]
for k, v in var_config.items()
if _is_not_empty(v)
},
)
for config in configs:
if config is None:
continue
for k, v in config.items():
if _is_not_empty(v) and k in CONFIG_KEYS:
if k == CONF:
empty[k] = cast(dict, v).copy()
else:
empty[k] = v # type: ignore[literal-required]
for k, v in config.items():
if _is_not_empty(v) and k not in CONFIG_KEYS:
empty[CONF][k] = v
_empty_metadata = empty["metadata"]
for key, value in empty[CONF].items():
if _exclude_as_metadata(key, value, _empty_metadata):
continue
_empty_metadata[key] = value
return empty
_OMIT = ("key", "token", "secret", "password", "auth")
def _exclude_as_metadata(key: str, value: Any, metadata: Mapping[str, Any]) -> bool:
key_lower = key.casefold()
return (
key.startswith("__")
or not isinstance(value, (str, int, float, bool))
or key in metadata
or any(substr in key_lower for substr in _OMIT)
)

View File

@@ -0,0 +1,112 @@
"""Constants used for Pregel operations."""
import sys
from typing import Literal, cast
# --- Reserved write keys ---
INPUT = sys.intern("__input__")
# for values passed as input to the graph
INTERRUPT = sys.intern("__interrupt__")
# for dynamic interrupts raised by nodes
RESUME = sys.intern("__resume__")
# for values passed to resume a node after an interrupt
ERROR = sys.intern("__error__")
# for errors raised by nodes
NO_WRITES = sys.intern("__no_writes__")
# marker to signal node didn't write anything
TASKS = sys.intern("__pregel_tasks")
# for Send objects returned by nodes/edges, corresponds to PUSH below
RETURN = sys.intern("__return__")
# for writes of a task where we simply record the return value
PREVIOUS = sys.intern("__previous__")
# the implicit branch that handles each node's Control values
# --- Reserved cache namespaces ---
CACHE_NS_WRITES = sys.intern("__pregel_ns_writes")
# cache namespace for node writes
# --- Reserved config.configurable keys ---
CONFIG_KEY_SEND = sys.intern("__pregel_send")
# holds the `write` function that accepts writes to state/edges/reserved keys
CONFIG_KEY_READ = sys.intern("__pregel_read")
# holds the `read` function that returns a copy of the current state
CONFIG_KEY_CALL = sys.intern("__pregel_call")
# holds the `call` function that accepts a node/func, args and returns a future
CONFIG_KEY_CHECKPOINTER = sys.intern("__pregel_checkpointer")
# holds a `BaseCheckpointSaver` passed from parent graph to child graphs
CONFIG_KEY_STREAM = sys.intern("__pregel_stream")
# holds a `StreamProtocol` passed from parent graph to child graphs
CONFIG_KEY_CACHE = sys.intern("__pregel_cache")
# holds a `BaseCache` made available to subgraphs
CONFIG_KEY_RESUMING = sys.intern("__pregel_resuming")
# holds a boolean indicating if subgraphs should resume from a previous checkpoint
CONFIG_KEY_TASK_ID = sys.intern("__pregel_task_id")
# holds the task ID for the current task
CONFIG_KEY_THREAD_ID = sys.intern("thread_id")
# holds the thread ID for the current invocation
CONFIG_KEY_CHECKPOINT_MAP = sys.intern("checkpoint_map")
# holds a mapping of checkpoint_ns -> checkpoint_id for parent graphs
CONFIG_KEY_CHECKPOINT_ID = sys.intern("checkpoint_id")
# holds the current checkpoint_id, if any
CONFIG_KEY_CHECKPOINT_NS = sys.intern("checkpoint_ns")
# holds the current checkpoint_ns, "" for root graph
CONFIG_KEY_NODE_FINISHED = sys.intern("__pregel_node_finished")
# holds a callback to be called when a node is finished
CONFIG_KEY_SCRATCHPAD = sys.intern("__pregel_scratchpad")
# holds a mutable dict for temporary storage scoped to the current task
CONFIG_KEY_RUNNER_SUBMIT = sys.intern("__pregel_runner_submit")
# holds a function that receives tasks from runner, executes them and returns results
CONFIG_KEY_DURABILITY = sys.intern("__pregel_durability")
# holds the durability mode, one of "sync", "async", or "exit"
CONFIG_KEY_RUNTIME = sys.intern("__pregel_runtime")
# holds a `Runtime` instance with context, store, stream writer, etc.
CONFIG_KEY_RESUME_MAP = sys.intern("__pregel_resume_map")
# holds a mapping of task ns -> resume value for resuming tasks
# --- Other constants ---
PUSH = sys.intern("__pregel_push")
# denotes push-style tasks, ie. those created by Send objects
PULL = sys.intern("__pregel_pull")
# denotes pull-style tasks, ie. those triggered by edges
NS_SEP = sys.intern("|")
# for checkpoint_ns, separates each level (ie. graph|subgraph|subsubgraph)
NS_END = sys.intern(":")
# for checkpoint_ns, for each level, separates the namespace from the task_id
CONF = cast(Literal["configurable"], sys.intern("configurable"))
# key for the configurable dict in RunnableConfig
NULL_TASK_ID = sys.intern("00000000-0000-0000-0000-000000000000")
# the task_id to use for writes that are not associated with a task
OVERWRITE = sys.intern("__overwrite__")
# dict key for the overwrite value, used as `{'__overwrite__': value}`
# redefined to avoid circular import with langgraph.constants
_TAG_HIDDEN = sys.intern("langsmith:hidden")
RESERVED = {
_TAG_HIDDEN,
# reserved write keys
INPUT,
INTERRUPT,
RESUME,
ERROR,
NO_WRITES,
# reserved config.configurable keys
CONFIG_KEY_SEND,
CONFIG_KEY_READ,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_STREAM,
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_RESUMING,
CONFIG_KEY_TASK_ID,
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_CHECKPOINT_ID,
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_RESUME_MAP,
# other constants
PUSH,
PULL,
NS_SEP,
NS_END,
CONF,
}

View File

@@ -0,0 +1,213 @@
from __future__ import annotations
import dataclasses
import types
import weakref
from collections.abc import Generator, Sequence
from typing import Annotated, Any, Optional, Union, get_origin, get_type_hints
from pydantic import BaseModel
from typing_extensions import NotRequired, ReadOnly, Required
from langgraph._internal._typing import MISSING
def _is_optional_type(type_: Any) -> bool:
"""Check if a type is Optional."""
# Handle new union syntax (PEP 604): str | None
if isinstance(type_, types.UnionType):
return any(
arg is type(None) or _is_optional_type(arg) for arg in type_.__args__
)
if hasattr(type_, "__origin__") and hasattr(type_, "__args__"):
origin = get_origin(type_)
if origin is Optional:
return True
if origin is Union:
return any(
arg is type(None) or _is_optional_type(arg) for arg in type_.__args__
)
if origin is Annotated:
return _is_optional_type(type_.__args__[0])
return origin is None
if hasattr(type_, "__bound__") and type_.__bound__ is not None:
return _is_optional_type(type_.__bound__)
return type_ is None
def _is_required_type(type_: Any) -> bool | None:
"""Check if an annotation is marked as Required/NotRequired.
Returns:
- True if required
- False if not required
- None if not annotated with either
"""
origin = get_origin(type_)
if origin is Required:
return True
if origin is NotRequired:
return False
if origin is Annotated or getattr(origin, "__args__", None):
# See https://typing.readthedocs.io/en/latest/spec/typeddict.html#interaction-with-annotated
return _is_required_type(type_.__args__[0])
return None
def _is_readonly_type(type_: Any) -> bool:
"""Check if an annotation is marked as ReadOnly.
Returns:
- True if is read only
- False if not read only
"""
# See: https://typing.readthedocs.io/en/latest/spec/typeddict.html#typing-readonly-type-qualifier
origin = get_origin(type_)
if origin is Annotated:
return _is_readonly_type(type_.__args__[0])
if origin is ReadOnly:
return True
return False
_DEFAULT_KEYS: frozenset[str] = frozenset()
def get_field_default(name: str, type_: Any, schema: type[Any]) -> Any:
"""Determine the default value for a field in a state schema.
This is based on:
If TypedDict:
- Required/NotRequired
- total=False -> everything optional
- Type annotation (Optional/Union[None])
"""
optional_keys = getattr(schema, "__optional_keys__", _DEFAULT_KEYS)
irq = _is_required_type(type_)
if name in optional_keys:
# Either total=False or explicit NotRequired.
# No type annotation trumps this.
if irq:
# Unless it's earlier versions of python & explicit Required
return ...
return None
if irq is not None:
if irq:
# Handle Required[<type>]
# (we already handled NotRequired and total=False)
return ...
# Handle NotRequired[<type>] for earlier versions of python
return None
if dataclasses.is_dataclass(schema):
field_info = next(
(f for f in dataclasses.fields(schema) if f.name == name), None
)
if field_info:
if (
field_info.default is not dataclasses.MISSING
and field_info.default is not ...
):
return field_info.default
elif field_info.default_factory is not dataclasses.MISSING:
return field_info.default_factory()
# Note, we ignore ReadOnly attributes,
# as they don't make much sense. (we don't care if you mutate the state in your node)
# and mutating state in your node has no effect on our graph state.
# Base case is the annotation
if _is_optional_type(type_):
return None
return ...
def get_enhanced_type_hints(
type: type[Any],
) -> Generator[tuple[str, Any, Any, str | None], None, None]:
"""Attempt to extract default values and descriptions from provided type, used for config schema."""
for name, typ in get_type_hints(type).items():
default = None
description = None
# Pydantic models
try:
if hasattr(type, "model_fields") and name in type.model_fields:
field = type.model_fields[name]
if hasattr(field, "description") and field.description is not None:
description = field.description
if hasattr(field, "default") and field.default is not None:
default = field.default
if (
hasattr(default, "__class__")
and getattr(default.__class__, "__name__", "")
== "PydanticUndefinedType"
):
default = None
except (AttributeError, KeyError, TypeError):
pass
# TypedDict, dataclass
try:
if hasattr(type, "__dict__"):
type_dict = getattr(type, "__dict__")
if name in type_dict:
default = type_dict[name]
except (AttributeError, KeyError, TypeError):
pass
yield name, typ, default, description
def get_update_as_tuples(input: Any, keys: Sequence[str]) -> list[tuple[str, Any]]:
"""Get Pydantic state update as a list of (key, value) tuples."""
if isinstance(input, BaseModel):
keep = input.model_fields_set
defaults = {k: v.default for k, v in type(input).model_fields.items()}
else:
keep = None
defaults = {}
# NOTE: This behavior for Pydantic is somewhat inelegant,
# but we keep around for backwards compatibility
# if input is a Pydantic model, only update values
# that are different from the default values or in the keep set
return [
(k, value)
for k in keys
if (value := getattr(input, k, MISSING)) is not MISSING
and (
value is not None
or defaults.get(k, MISSING) is not None
or (keep is not None and k in keep)
)
]
ANNOTATED_KEYS_CACHE: weakref.WeakKeyDictionary[type[Any], tuple[str, ...]] = (
weakref.WeakKeyDictionary()
)
def get_cached_annotated_keys(obj: type[Any]) -> tuple[str, ...]:
"""Return cached annotated keys for a Python class."""
if obj in ANNOTATED_KEYS_CACHE:
return ANNOTATED_KEYS_CACHE[obj]
if isinstance(obj, type):
keys: list[str] = []
for base in reversed(obj.__mro__):
ann = base.__dict__.get("__annotations__")
# In Python 3.14+, Pydantic models use descriptors for __annotations__
# so we need to fall back to getattr if __dict__.get returns None
if ann is None:
ann = getattr(base, "__annotations__", None)
if ann is None or isinstance(ann, types.GetSetDescriptorType):
continue
keys.extend(ann.keys())
return ANNOTATED_KEYS_CACHE.setdefault(obj, tuple(keys))
else:
raise TypeError(f"Expected a type, got {type(obj)}. ")

View File

@@ -0,0 +1,220 @@
from __future__ import annotations
import asyncio
import concurrent.futures
import contextvars
import inspect
import sys
import types
from collections.abc import Awaitable, Coroutine, Generator
from typing import TypeVar, cast
T = TypeVar("T")
AnyFuture = asyncio.Future | concurrent.futures.Future
CONTEXT_NOT_SUPPORTED = sys.version_info < (3, 11)
EAGER_NOT_SUPPORTED = sys.version_info < (3, 12)
def _get_loop(fut: asyncio.Future) -> asyncio.AbstractEventLoop:
# Tries to call Future.get_loop() if it's available.
# Otherwise fallbacks to using the old '_loop' property.
try:
get_loop = fut.get_loop
except AttributeError:
pass
else:
return get_loop()
return fut._loop
def _convert_future_exc(exc: BaseException) -> BaseException:
exc_class = type(exc)
if exc_class is concurrent.futures.CancelledError:
return asyncio.CancelledError(*exc.args)
elif exc_class is concurrent.futures.TimeoutError:
return asyncio.TimeoutError(*exc.args)
elif exc_class is concurrent.futures.InvalidStateError:
return asyncio.InvalidStateError(*exc.args)
else:
return exc
def _set_concurrent_future_state(
concurrent: concurrent.futures.Future,
source: AnyFuture,
) -> None:
"""Copy state from a future to a concurrent.futures.Future."""
assert source.done()
if source.cancelled():
concurrent.cancel()
if not concurrent.set_running_or_notify_cancel():
return
exception = source.exception()
if exception is not None:
concurrent.set_exception(_convert_future_exc(exception))
else:
result = source.result()
concurrent.set_result(result)
def _copy_future_state(source: AnyFuture, dest: asyncio.Future) -> None:
"""Internal helper to copy state from another Future.
The other Future may be a concurrent.futures.Future.
"""
if dest.done():
return
assert source.done()
if dest.cancelled():
return
if source.cancelled():
dest.cancel()
else:
exception = source.exception()
if exception is not None:
dest.set_exception(_convert_future_exc(exception))
else:
result = source.result()
dest.set_result(result)
def _chain_future(source: AnyFuture, destination: AnyFuture) -> None:
"""Chain two futures so that when one completes, so does the other.
The result (or exception) of source will be copied to destination.
If destination is cancelled, source gets cancelled too.
Compatible with both asyncio.Future and concurrent.futures.Future.
"""
if not asyncio.isfuture(source) and not isinstance(
source, concurrent.futures.Future
):
raise TypeError("A future is required for source argument")
if not asyncio.isfuture(destination) and not isinstance(
destination, concurrent.futures.Future
):
raise TypeError("A future is required for destination argument")
source_loop = _get_loop(source) if asyncio.isfuture(source) else None
dest_loop = _get_loop(destination) if asyncio.isfuture(destination) else None
def _set_state(future: AnyFuture, other: AnyFuture) -> None:
if asyncio.isfuture(future):
_copy_future_state(other, future)
else:
_set_concurrent_future_state(future, other)
def _call_check_cancel(destination: AnyFuture) -> None:
if destination.cancelled():
if source_loop is None or source_loop is dest_loop:
source.cancel()
else:
source_loop.call_soon_threadsafe(source.cancel)
def _call_set_state(source: AnyFuture) -> None:
if destination.cancelled() and dest_loop is not None and dest_loop.is_closed():
return
if dest_loop is None or dest_loop is source_loop:
_set_state(destination, source)
else:
if dest_loop.is_closed():
return
dest_loop.call_soon_threadsafe(_set_state, destination, source)
destination.add_done_callback(_call_check_cancel)
source.add_done_callback(_call_set_state)
def chain_future(source: AnyFuture, destination: AnyFuture) -> AnyFuture:
# adapted from asyncio.run_coroutine_threadsafe
try:
_chain_future(source, destination)
return destination
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
if isinstance(destination, concurrent.futures.Future):
if destination.set_running_or_notify_cancel():
destination.set_exception(exc)
else:
destination.set_exception(exc)
raise
def _ensure_future(
coro_or_future: Coroutine[None, None, T] | Awaitable[T],
*,
loop: asyncio.AbstractEventLoop,
name: str | None = None,
context: contextvars.Context | None = None,
lazy: bool = True,
) -> asyncio.Task[T]:
called_wrap_awaitable = False
if not asyncio.iscoroutine(coro_or_future):
if inspect.isawaitable(coro_or_future):
coro_or_future = cast(
Coroutine[None, None, T], _wrap_awaitable(coro_or_future)
)
called_wrap_awaitable = True
else:
raise TypeError(
"An asyncio.Future, a coroutine or an awaitable is required."
f" Got {type(coro_or_future).__name__} instead."
)
try:
if CONTEXT_NOT_SUPPORTED:
return loop.create_task(coro_or_future, name=name)
elif EAGER_NOT_SUPPORTED or lazy:
return loop.create_task(coro_or_future, name=name, context=context)
else:
return asyncio.eager_task_factory(
loop, coro_or_future, name=name, context=context
)
except RuntimeError:
if not called_wrap_awaitable:
coro_or_future.close()
raise
@types.coroutine
def _wrap_awaitable(awaitable: Awaitable[T]) -> Generator[None, None, T]:
"""Helper for asyncio.ensure_future().
Wraps awaitable (an object with __await__) into a coroutine
that will later be wrapped in a Task by ensure_future().
"""
return (yield from awaitable.__await__())
def run_coroutine_threadsafe(
coro: Coroutine[None, None, T],
loop: asyncio.AbstractEventLoop,
*,
lazy: bool,
name: str | None = None,
context: contextvars.Context | None = None,
) -> asyncio.Future[T]:
"""Submit a coroutine object to a given event loop.
Return an asyncio.Future to access the result.
"""
if asyncio._get_running_loop() is loop:
return _ensure_future(coro, loop=loop, name=name, context=context, lazy=lazy)
else:
future: asyncio.Future[T] = asyncio.Future(loop=loop)
def callback() -> None:
try:
chain_future(
_ensure_future(coro, loop=loop, name=name, context=context),
future,
)
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
future.set_exception(exc)
raise
loop.call_soon_threadsafe(callback, context=context)
return future

View File

@@ -0,0 +1,275 @@
from __future__ import annotations
import sys
import typing
import warnings
from contextlib import nullcontext
from dataclasses import is_dataclass
from functools import lru_cache
from typing import (
Any,
cast,
overload,
)
from pydantic import (
BaseModel,
ConfigDict,
Field,
RootModel,
)
from pydantic import (
create_model as _create_model_base,
)
from pydantic.fields import FieldInfo
from pydantic.json_schema import (
DEFAULT_REF_TEMPLATE,
GenerateJsonSchema,
JsonSchemaMode,
)
from typing_extensions import TypedDict
@overload
def get_fields(model: type[BaseModel]) -> dict[str, FieldInfo]: ...
@overload
def get_fields(model: BaseModel) -> dict[str, FieldInfo]: ...
def get_fields(
model: type[BaseModel] | BaseModel,
) -> dict[str, FieldInfo]:
"""Get the field names of a Pydantic model."""
if hasattr(model, "model_fields"):
return model.model_fields
if hasattr(model, "__fields__"):
return model.__fields__
msg = f"Expected a Pydantic model. Got {type(model)}"
raise TypeError(msg)
_SchemaConfig = ConfigDict(
arbitrary_types_allowed=True, frozen=True, protected_namespaces=()
)
NO_DEFAULT = object()
def _create_root_model(
name: str,
type_: Any,
module_name: str | None = None,
default_: object = NO_DEFAULT,
) -> type[BaseModel]:
"""Create a base class."""
def schema(
cls: type[BaseModel],
by_alias: bool = True, # noqa: FBT001,FBT002
ref_template: str = DEFAULT_REF_TEMPLATE,
) -> dict[str, Any]:
# Complains about schema not being defined in superclass
schema_ = super(cls, cls).schema( # type: ignore[misc]
by_alias=by_alias, ref_template=ref_template
)
schema_["title"] = name
return schema_
def model_json_schema(
cls: type[BaseModel],
by_alias: bool = True, # noqa: FBT001,FBT002
ref_template: str = DEFAULT_REF_TEMPLATE,
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
mode: JsonSchemaMode = "validation",
) -> dict[str, Any]:
# Complains about model_json_schema not being defined in superclass
schema_ = super(cls, cls).model_json_schema( # type: ignore[misc]
by_alias=by_alias,
ref_template=ref_template,
schema_generator=schema_generator,
mode=mode,
)
schema_["title"] = name
return schema_
base_class_attributes = {
"__annotations__": {"root": type_},
"model_config": ConfigDict(arbitrary_types_allowed=True),
"schema": classmethod(schema),
"model_json_schema": classmethod(model_json_schema),
"__module__": module_name or "langchain_core.runnables.utils",
}
if default_ is not NO_DEFAULT:
base_class_attributes["root"] = default_
with warnings.catch_warnings():
custom_root_type = type(name, (RootModel,), base_class_attributes)
return cast("type[BaseModel]", custom_root_type)
@lru_cache(maxsize=256)
def _create_root_model_cached(
model_name: str,
type_: Any,
*,
module_name: str | None = None,
default_: object = NO_DEFAULT,
) -> type[BaseModel]:
return _create_root_model(
model_name, type_, default_=default_, module_name=module_name
)
@lru_cache(maxsize=256)
def _create_model_cached(
model_name: str,
/,
**field_definitions: Any,
) -> type[BaseModel]:
return _create_model_base(
model_name,
__config__=_SchemaConfig,
**_remap_field_definitions(field_definitions),
)
# Reserved names should capture all the `public` names / methods that are
# used by BaseModel internally. This will keep the reserved names up-to-date.
# For reference, the reserved names are:
# "construct", "copy", "dict", "from_orm", "json", "parse_file", "parse_obj",
# "parse_raw", "schema", "schema_json", "update_forward_refs", "validate",
# "model_computed_fields", "model_config", "model_construct", "model_copy",
# "model_dump", "model_dump_json", "model_extra", "model_fields",
# "model_fields_set", "model_json_schema", "model_parametrized_name",
# "model_post_init", "model_rebuild", "model_validate", "model_validate_json",
# "model_validate_strings"
_RESERVED_NAMES = {key for key in dir(BaseModel) if not key.startswith("_")}
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
"""This remaps fields to avoid colliding with internal pydantic fields."""
remapped = {}
for key, value in field_definitions.items():
if key.startswith("_") or key in _RESERVED_NAMES:
# Let's add a prefix to avoid colliding with internal pydantic fields
if isinstance(value, FieldInfo):
msg = (
f"Remapping for fields starting with '_' or fields with a name "
f"matching a reserved name {_RESERVED_NAMES} is not supported if "
f" the field is a pydantic Field instance. Got {key}."
)
raise NotImplementedError(msg)
type_, default_ = value
remapped[f"private_{key}"] = (
type_,
Field(
default=default_,
alias=key,
serialization_alias=key,
title=key.lstrip("_").replace("_", " ").title(),
),
)
else:
remapped[key] = value
return remapped
def create_model(
model_name: str,
*,
field_definitions: dict[str, Any] | None = None,
root: Any | None = None,
) -> type[BaseModel]:
"""Create a pydantic model with the given field definitions.
Attention:
Please do not use outside of langchain packages. This API
is subject to change at any time.
Args:
model_name: The name of the model.
module_name: The name of the module where the model is defined.
This is used by Pydantic to resolve any forward references.
field_definitions: The field definitions for the model.
root: Type for a root model (RootModel)
Returns:
Type[BaseModel]: The created model.
"""
field_definitions = field_definitions or {}
if root:
if field_definitions:
msg = (
"When specifying __root__ no other "
f"fields should be provided. Got {field_definitions}"
)
raise NotImplementedError(msg)
if isinstance(root, tuple):
kwargs = {"type_": root[0], "default_": root[1]}
else:
kwargs = {"type_": root}
try:
named_root_model = _create_root_model_cached(model_name, **kwargs)
except TypeError:
# something in the arguments into _create_root_model_cached is not hashable
named_root_model = _create_root_model(
model_name,
**kwargs,
)
return named_root_model
# No root, just field definitions
names = set(field_definitions.keys())
capture_warnings = False
for name in names:
# Also if any non-reserved name is used (e.g., model_id or model_name)
if name.startswith("model"):
capture_warnings = True
with warnings.catch_warnings() if capture_warnings else nullcontext():
if capture_warnings:
warnings.filterwarnings(action="ignore")
try:
return _create_model_cached(model_name, **field_definitions)
except TypeError:
# something in field definitions is not hashable
return _create_model_base(
model_name,
__config__=_SchemaConfig,
**_remap_field_definitions(field_definitions),
)
def is_supported_by_pydantic(type_: Any) -> bool:
"""Check if a given "complex" type is supported by pydantic.
This will return False for primitive types like int, str, etc.
The check is meant for container types like dataclasses, TypedDicts, etc.
"""
if is_dataclass(type_):
return True
if isinstance(type_, type) and issubclass(type_, BaseModel):
return True
if hasattr(type_, "__orig_bases__"):
for base in type_.__orig_bases__:
if base is TypedDict:
return True
elif base is typing.TypedDict: # noqa: TID251
# ignoring TID251 since it's OK to use typing.TypedDict in this case.
# Pydantic supports typing.TypedDict from Python 3.12
# For older versions, only typing_extensions.TypedDict is supported.
if sys.version_info >= (3, 12):
return True
return False

View File

@@ -0,0 +1,124 @@
# type: ignore
from __future__ import annotations
import asyncio
import queue
import threading
import types
from collections import deque
from time import monotonic
class AsyncQueue(asyncio.Queue):
"""Async unbounded FIFO queue with a wait() method.
Subclassed from asyncio.Queue, adding a wait() method."""
async def wait(self) -> None:
"""If queue is empty, wait until an item is available.
Copied from Queue.get(), removing the call to .get_nowait(),
ie. this doesn't consume the item, just waits for it.
"""
while self.empty():
getter = self._get_loop().create_future()
self._getters.append(getter)
try:
await getter
except BaseException:
getter.cancel() # Just in case getter is not done yet.
try:
# Clean self._getters from canceled getters.
self._getters.remove(getter)
except ValueError:
# The getter could be removed from self._getters by a
# previous put_nowait call.
pass
if not self.empty() and not getter.cancelled():
# We were woken up by put_nowait(), but can't take
# the call. Wake up the next in line.
self._wakeup_next(self._getters)
raise
class Semaphore(threading.Semaphore):
"""Semaphore subclass with a wait() method."""
def wait(self, blocking: bool = True, timeout: float | None = None):
"""Block until the semaphore can be acquired, but don't acquire it."""
if not blocking and timeout is not None:
raise ValueError("can't specify timeout for non-blocking acquire")
rc = False
endtime = None
with self._cond:
while self._value == 0:
if not blocking:
break
if timeout is not None:
if endtime is None:
endtime = monotonic() + timeout
else:
timeout = endtime - monotonic()
if timeout <= 0:
break
self._cond.wait(timeout)
else:
rc = True
return rc
class SyncQueue:
"""Unbounded FIFO queue with a wait() method.
Adapted from pure Python implementation of queue.SimpleQueue.
"""
def __init__(self):
self._queue = deque()
self._count = Semaphore(0)
def put(self, item, block=True, timeout=None):
"""Put the item on the queue.
The optional 'block' and 'timeout' arguments are ignored, as this method
never blocks. They are provided for compatibility with the Queue class.
"""
self._queue.append(item)
self._count.release()
def get(self, block=False, timeout=None):
"""Remove and return an item from the queue.
If optional args 'block' is true and 'timeout' is None (the default),
block if necessary until an item is available. If 'timeout' is
a non-negative number, it blocks at most 'timeout' seconds and raises
the Empty exception if no item was available within that time.
Otherwise ('block' is false), return an item if one is immediately
available, else raise the Empty exception ('timeout' is ignored
in that case).
"""
if timeout is not None and timeout < 0:
raise ValueError("'timeout' must be a non-negative number")
if not self._count.acquire(block, timeout):
raise queue.Empty
try:
return self._queue.popleft()
except IndexError:
raise queue.Empty
def wait(self, block=True, timeout=None):
"""If queue is empty, wait until an item maybe is available,
but don't consume it.
"""
if timeout is not None and timeout < 0:
raise ValueError("'timeout' must be a non-negative number")
self._count.wait(block, timeout)
def empty(self):
"""Return True if the queue is empty, False otherwise (not reliable!)."""
return len(self._queue) == 0
def qsize(self):
"""Return the approximate size of the queue (not reliable!)."""
return len(self._queue)
__class_getitem__ = classmethod(types.GenericAlias)

View File

@@ -0,0 +1,29 @@
def default_retry_on(exc: Exception) -> bool:
import httpx
import requests
if isinstance(exc, ConnectionError):
return True
if isinstance(exc, httpx.HTTPStatusError):
return 500 <= exc.response.status_code < 600
if isinstance(exc, requests.HTTPError):
return 500 <= exc.response.status_code < 600 if exc.response else True
if isinstance(
exc,
(
ValueError,
TypeError,
ArithmeticError,
ImportError,
LookupError,
NameError,
SyntaxError,
RuntimeError,
ReferenceError,
StopIteration,
StopAsyncIteration,
OSError,
),
):
return False
return True

View File

@@ -0,0 +1,914 @@
from __future__ import annotations
import asyncio
import enum
import inspect
import sys
import warnings
from collections.abc import (
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Generator,
Iterator,
Sequence,
)
from contextlib import AsyncExitStack, contextmanager
from contextvars import Context, Token, copy_context
from functools import partial, wraps
from typing import (
Any,
Optional,
Protocol,
TypeGuard,
cast,
)
from langchain_core.runnables.base import (
Runnable,
RunnableConfig,
RunnableLambda,
RunnableParallel,
RunnableSequence,
)
from langchain_core.runnables.base import (
RunnableLike as LCRunnableLike,
)
from langchain_core.runnables.config import (
run_in_executor,
var_child_runnable_config,
)
from langchain_core.runnables.utils import Input, Output
from langchain_core.tracers.langchain import LangChainTracer
from langgraph.store.base import BaseStore
from langgraph._internal._config import (
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
patch_config,
)
from langgraph._internal._constants import (
CONF,
CONFIG_KEY_RUNTIME,
)
from langgraph._internal._typing import MISSING
from langgraph.types import StreamWriter
try:
from langchain_core.tracers._streaming import _StreamingCallbackHandler
except ImportError:
_StreamingCallbackHandler = None # type: ignore
def _set_config_context(
config: RunnableConfig, run: Any = None
) -> Token[RunnableConfig | None]:
"""Set the child Runnable config + tracing context.
Args:
config: The config to set.
"""
config_token = var_child_runnable_config.set(config)
if run is not None:
from langsmith.run_helpers import _set_tracing_context
_set_tracing_context({"parent": run})
return config_token
def _unset_config_context(token: Token[RunnableConfig | None], run: Any = None) -> None:
"""Set the child Runnable config + tracing context.
Args:
token: The config token to reset.
"""
var_child_runnable_config.reset(token)
if run is not None:
from langsmith.run_helpers import _set_tracing_context
_set_tracing_context(
{
"parent": None,
"project_name": None,
"tags": None,
"metadata": None,
"enabled": None,
"client": None,
}
)
@contextmanager
def set_config_context(
config: RunnableConfig, run: Any = None
) -> Generator[Context, None, None]:
"""Set the child Runnable config + tracing context.
Args:
config: The config to set.
"""
ctx = copy_context()
config_token = ctx.run(_set_config_context, config, run)
try:
yield ctx
finally:
ctx.run(_unset_config_context, config_token, run)
# Before Python 3.11 native StrEnum is not available
class StrEnum(str, enum.Enum):
"""A string enum."""
# Special type to denote any type is accepted
ANY_TYPE = object()
ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11)
# List of keyword arguments that can be injected into nodes / tasks / tools at runtime.
# A named argument may appear multiple times if it appears with distinct types.
KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = (
(
"config",
(
RunnableConfig,
"RunnableConfig",
Optional[RunnableConfig], # noqa: UP045
"Optional[RunnableConfig]",
inspect.Parameter.empty,
),
# for now, use config directly, eventually, will pop off of Runtime
"N/A",
inspect.Parameter.empty,
),
(
"writer",
(StreamWriter, "StreamWriter", inspect.Parameter.empty),
"stream_writer",
lambda _: None,
),
(
"store",
(
BaseStore,
"BaseStore",
inspect.Parameter.empty,
),
"store",
inspect.Parameter.empty,
),
(
"store",
(
Optional[BaseStore], # noqa: UP045
"Optional[BaseStore]",
),
"store",
None,
),
(
"previous",
(ANY_TYPE,),
"previous",
inspect.Parameter.empty,
),
(
"runtime",
(ANY_TYPE,),
# we never hit this block, we just inject runtime directly
"N/A",
inspect.Parameter.empty,
),
)
"""List of kwargs that can be passed to functions, and their corresponding
config keys, default values and type annotations.
Used to configure keyword arguments that can be injected at runtime
from the `Runtime` object as kwargs to `invoke`, `ainvoke`, `stream` and `astream`.
For a keyword to be injected from the config object, the function signature
must contain a kwarg with the same name and a matching type annotation.
Each tuple contains:
- the name of the kwarg in the function signature
- the type annotation(s) for the kwarg
- the `Runtime` attribute for fetching the value (N/A if not applicable)
This is fully internal and should be further refactored to use `get_type_hints`
to resolve forward references and optional types formatted like BaseStore | None.
"""
VALID_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
class _RunnableWithWriter(Protocol[Input, Output]):
def __call__(self, state: Input, *, writer: StreamWriter) -> Output: ...
class _RunnableWithStore(Protocol[Input, Output]):
def __call__(self, state: Input, *, store: BaseStore) -> Output: ...
class _RunnableWithWriterStore(Protocol[Input, Output]):
def __call__(
self, state: Input, *, writer: StreamWriter, store: BaseStore
) -> Output: ...
class _RunnableWithConfigWriter(Protocol[Input, Output]):
def __call__(
self, state: Input, *, config: RunnableConfig, writer: StreamWriter
) -> Output: ...
class _RunnableWithConfigStore(Protocol[Input, Output]):
def __call__(
self, state: Input, *, config: RunnableConfig, store: BaseStore
) -> Output: ...
class _RunnableWithConfigWriterStore(Protocol[Input, Output]):
def __call__(
self,
state: Input,
*,
config: RunnableConfig,
writer: StreamWriter,
store: BaseStore,
) -> Output: ...
RunnableLike = (
LCRunnableLike
| _RunnableWithWriter[Input, Output]
| _RunnableWithStore[Input, Output]
| _RunnableWithWriterStore[Input, Output]
| _RunnableWithConfigWriter[Input, Output]
| _RunnableWithConfigStore[Input, Output]
| _RunnableWithConfigWriterStore[Input, Output]
)
class RunnableCallable(Runnable):
"""A much simpler version of RunnableLambda that requires sync and async functions."""
def __init__(
self,
func: Callable[..., Any | Runnable] | None,
afunc: Callable[..., Awaitable[Any | Runnable]] | None = None,
*,
name: str | None = None,
tags: Sequence[str] | None = None,
trace: bool = True,
recurse: bool = True,
explode_args: bool = False,
**kwargs: Any,
) -> None:
self.name = name
if self.name is None:
if func:
try:
if func.__name__ != "<lambda>":
self.name = func.__name__
except AttributeError:
pass
elif afunc:
try:
self.name = afunc.__name__
except AttributeError:
pass
self.func = func
self.afunc = afunc
self.tags = tags
self.kwargs = kwargs
self.trace = trace
self.recurse = recurse
self.explode_args = explode_args
# check signature
if func is None and afunc is None:
raise ValueError("At least one of func or afunc must be provided.")
self.func_accepts: dict[str, tuple[str, Any]] = {}
params = inspect.signature(cast(Callable, func or afunc)).parameters
for kw, typ, runtime_key, default in KWARGS_CONFIG_KEYS:
p = params.get(kw)
if p is None or p.kind not in VALID_KINDS:
# If parameter is not found or is not a valid kind, skip
continue
if typ != (ANY_TYPE,) and p.annotation not in typ:
# A specific type is required, but the function annotation does
# not match the expected type.
# If this is a config parameter with incorrect typing, emit a warning
# because we used to support any type but are moving towards more correct typing
if kw == "config" and p.annotation != inspect.Parameter.empty:
warnings.warn(
f"The 'config' parameter should be typed as 'RunnableConfig' or "
f"'RunnableConfig | None', not '{p.annotation}'. ",
UserWarning,
stacklevel=4,
)
continue
# If the kwarg is accepted by the function, store the key / runtime attribute to inject
self.func_accepts[kw] = (runtime_key, default)
def __repr__(self) -> str:
repr_args = {
k: v
for k, v in self.__dict__.items()
if k not in {"name", "func", "afunc", "config", "kwargs", "trace"}
}
return f"{self.get_name()}({', '.join(f'{k}={v!r}' for k, v in repr_args.items())})"
def invoke(
self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
) -> Any:
if self.func is None:
raise TypeError(
f'No synchronous function provided to "{self.name}".'
"\nEither initialize with a synchronous function or invoke"
" via the async API (ainvoke, astream, etc.)"
)
if config is None:
config = ensure_config()
if self.explode_args:
args, _kwargs = input
kwargs = {**self.kwargs, **_kwargs, **kwargs}
else:
args = (input,)
kwargs = {**self.kwargs, **kwargs}
runtime = config.get(CONF, {}).get(CONFIG_KEY_RUNTIME)
for kw, (runtime_key, default) in self.func_accepts.items():
# If the kwarg is already set, use the set value
if kw in kwargs:
continue
kw_value: Any = MISSING
if kw == "config":
kw_value = config
elif runtime:
if kw == "runtime":
kw_value = runtime
else:
try:
kw_value = getattr(runtime, runtime_key)
except AttributeError:
pass
if kw_value is MISSING:
if default is inspect.Parameter.empty:
raise ValueError(
f"Missing required config key '{runtime_key}' for '{self.name}'."
)
kw_value = default
kwargs[kw] = kw_value
if self.trace:
callback_manager = get_callback_manager_for_config(config, self.tags)
run_manager = callback_manager.on_chain_start(
None,
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
# get the run
for h in run_manager.handlers:
if isinstance(h, LangChainTracer):
run = h.run_map.get(str(run_manager.run_id))
break
else:
run = None
# run in context
with set_config_context(child_config, run) as context:
ret = context.run(self.func, *args, **kwargs)
except BaseException as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(ret)
else:
ret = self.func(*args, **kwargs)
if self.recurse and isinstance(ret, Runnable):
return ret.invoke(input, config)
return ret
async def ainvoke(
self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
) -> Any:
if not self.afunc:
return self.invoke(input, config)
if config is None:
config = ensure_config()
if self.explode_args:
args, _kwargs = input
kwargs = {**self.kwargs, **_kwargs, **kwargs}
else:
args = (input,)
kwargs = {**self.kwargs, **kwargs}
runtime = config.get(CONF, {}).get(CONFIG_KEY_RUNTIME)
for kw, (runtime_key, default) in self.func_accepts.items():
# If the kwarg has already been set, use the set value
if kw in kwargs:
continue
kw_value: Any = MISSING
if kw == "config":
kw_value = config
elif runtime:
if kw == "runtime":
kw_value = runtime
else:
try:
kw_value = getattr(runtime, runtime_key)
except AttributeError:
pass
if kw_value is MISSING:
if default is inspect.Parameter.empty:
raise ValueError(
f"Missing required config key '{runtime_key}' for '{self.name}'."
)
kw_value = default
kwargs[kw] = kw_value
if self.trace:
callback_manager = get_async_callback_manager_for_config(config, self.tags)
run_manager = await callback_manager.on_chain_start(
None,
input,
name=config.get("run_name") or self.name,
run_id=config.pop("run_id", None),
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
coro = cast(Coroutine[None, None, Any], self.afunc(*args, **kwargs))
if ASYNCIO_ACCEPTS_CONTEXT:
for h in run_manager.handlers:
if isinstance(h, LangChainTracer):
run = h.run_map.get(str(run_manager.run_id))
break
else:
run = None
with set_config_context(child_config, run) as context:
ret = await asyncio.create_task(coro, context=context)
else:
ret = await coro
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(ret)
else:
ret = await self.afunc(*args, **kwargs)
if self.recurse and isinstance(ret, Runnable):
return await ret.ainvoke(input, config)
return ret
def is_async_callable(
func: Any,
) -> TypeGuard[Callable[..., Awaitable]]:
"""Check if a function is async."""
return (
inspect.iscoroutinefunction(func)
or hasattr(func, "__call__")
and inspect.iscoroutinefunction(func.__call__)
)
def is_async_generator(
func: Any,
) -> TypeGuard[Callable[..., AsyncIterator]]:
"""Check if a function is an async generator."""
return (
inspect.isasyncgenfunction(func)
or hasattr(func, "__call__")
and inspect.isasyncgenfunction(func.__call__)
)
def coerce_to_runnable(
thing: RunnableLike, *, name: str | None, trace: bool
) -> Runnable:
"""Coerce a runnable-like object into a Runnable.
Args:
thing: A runnable-like object.
Returns:
A Runnable.
"""
if isinstance(thing, Runnable):
return thing
elif is_async_generator(thing) or inspect.isgeneratorfunction(thing):
return RunnableLambda(thing, name=name)
elif callable(thing):
if is_async_callable(thing):
return RunnableCallable(None, thing, name=name, trace=trace)
else:
return RunnableCallable(
thing,
wraps(thing)(partial(run_in_executor, None, thing)), # type: ignore[arg-type]
name=name,
trace=trace,
)
elif isinstance(thing, dict):
return RunnableParallel(thing)
else:
raise TypeError(
f"Expected a Runnable, callable or dict."
f"Instead got an unsupported type: {type(thing)}"
)
class RunnableSeq(Runnable):
"""Sequence of `Runnable`, where the output of each is the input of the next.
`RunnableSeq` is a simpler version of `RunnableSequence` that is internal to
LangGraph.
"""
def __init__(
self,
*steps: RunnableLike,
name: str | None = None,
trace_inputs: Callable[[Any], Any] | None = None,
) -> None:
"""Create a new RunnableSeq.
Args:
steps: The steps to include in the sequence.
name: The name of the `Runnable`.
Raises:
ValueError: If the sequence has less than 2 steps.
"""
steps_flat: list[Runnable] = []
for step in steps:
if isinstance(step, RunnableSequence):
steps_flat.extend(step.steps)
elif isinstance(step, RunnableSeq):
steps_flat.extend(step.steps)
else:
steps_flat.append(coerce_to_runnable(step, name=None, trace=True))
if len(steps_flat) < 2:
raise ValueError(
f"RunnableSeq must have at least 2 steps, got {len(steps_flat)}"
)
self.steps = steps_flat
self.name = name
self.trace_inputs = trace_inputs
def __or__(
self,
other: Any,
) -> Runnable:
if isinstance(other, RunnableSequence):
return RunnableSeq(
*self.steps,
other.first,
*other.middle,
other.last,
name=self.name or other.name,
)
elif isinstance(other, RunnableSeq):
return RunnableSeq(
*self.steps,
*other.steps,
name=self.name or other.name,
)
else:
return RunnableSeq(
*self.steps,
coerce_to_runnable(other, name=None, trace=True),
name=self.name,
)
def __ror__(
self,
other: Any,
) -> Runnable:
if isinstance(other, RunnableSequence):
return RunnableSequence(
other.first,
*other.middle,
other.last,
*self.steps,
name=other.name or self.name,
)
elif isinstance(other, RunnableSeq):
return RunnableSeq(
*other.steps,
*self.steps,
name=other.name or self.name,
)
else:
return RunnableSequence(
coerce_to_runnable(other, name=None, trace=True),
*self.steps,
name=self.name,
)
def invoke(
self, input: Input, config: RunnableConfig | None = None, **kwargs: Any
) -> Any:
if config is None:
config = ensure_config()
# setup callbacks and context
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
None,
self.trace_inputs(input) if self.trace_inputs is not None else input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# invoke all steps in sequence
try:
for i, step in enumerate(self.steps):
# mark each step as a child run
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
)
# 1st step is the actual node,
# others are writers which don't need to be run in context
if i == 0:
# get the run object
for h in run_manager.handlers:
if isinstance(h, LangChainTracer):
run = h.run_map.get(str(run_manager.run_id))
break
else:
run = None
# run in context
with set_config_context(config, run) as context:
input = context.run(step.invoke, input, config, **kwargs)
else:
input = step.invoke(input, config)
# finish the root run
except BaseException as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(input)
return input
async def ainvoke(
self,
input: Input,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Any:
if config is None:
config = ensure_config()
# setup callbacks
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
None,
self.trace_inputs(input) if self.trace_inputs is not None else input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# invoke all steps in sequence
try:
for i, step in enumerate(self.steps):
# mark each step as a child run
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
)
# 1st step is the actual node,
# others are writers which don't need to be run in context
if i == 0:
if ASYNCIO_ACCEPTS_CONTEXT:
# get the run object
for h in run_manager.handlers:
if isinstance(h, LangChainTracer):
run = h.run_map.get(str(run_manager.run_id))
break
else:
run = None
# run in context
with set_config_context(config, run) as context:
input = await asyncio.create_task(
step.ainvoke(input, config, **kwargs), context=context
)
else:
input = await step.ainvoke(input, config, **kwargs)
else:
input = await step.ainvoke(input, config)
# finish the root run
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(input)
return input
def stream(
self,
input: Input,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Iterator[Any]:
if config is None:
config = ensure_config()
# setup callbacks
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
None,
self.trace_inputs(input) if self.trace_inputs is not None else input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# get the run object
for h in run_manager.handlers:
if isinstance(h, LangChainTracer):
run = h.run_map.get(str(run_manager.run_id))
break
else:
run = None
# create first step config
config = patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{1}"),
)
# run all in context
with set_config_context(config, run) as context:
try:
# stream the last steps
# transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output
for idx, step in enumerate(self.steps):
if idx == 0:
iterator = step.stream(input, config, **kwargs)
else:
config = patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{idx + 1}"),
)
iterator = step.transform(iterator, config)
# populates streamed_output in astream_log() output if needed
if _StreamingCallbackHandler is not None:
for h in run_manager.handlers:
if isinstance(h, _StreamingCallbackHandler):
iterator = h.tap_output_iter(run_manager.run_id, iterator)
# consume into final output
output = context.run(_consume_iter, iterator)
# sequence doesn't emit output, yield to mark as generator
yield
except BaseException as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(output)
async def astream(
self,
input: Input,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> AsyncIterator[Any]:
if config is None:
config = ensure_config()
# setup callbacks
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
None,
self.trace_inputs(input) if self.trace_inputs is not None else input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# stream the last steps
# transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output
if ASYNCIO_ACCEPTS_CONTEXT:
# get the run object
for h in run_manager.handlers:
if isinstance(h, LangChainTracer):
run = h.run_map.get(str(run_manager.run_id))
break
else:
run = None
# create first step config
config = patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{1}"),
)
# run all in context
with set_config_context(config, run) as context:
try:
async with AsyncExitStack() as stack:
for idx, step in enumerate(self.steps):
if idx == 0:
aiterator = step.astream(input, config, **kwargs)
else:
config = patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{idx + 1}"
),
)
aiterator = step.atransform(aiterator, config)
if hasattr(aiterator, "aclose"):
stack.push_async_callback(aiterator.aclose)
# populates streamed_output in astream_log() output if needed
if _StreamingCallbackHandler is not None:
for h in run_manager.handlers:
if isinstance(h, _StreamingCallbackHandler):
aiterator = h.tap_output_aiter(
run_manager.run_id, aiterator
)
# consume into final output
output = await asyncio.create_task(
_consume_aiter(aiterator), context=context
)
# sequence doesn't emit output, yield to mark as generator
yield
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(output)
else:
try:
async with AsyncExitStack() as stack:
for idx, step in enumerate(self.steps):
config = patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{idx + 1}"),
)
if idx == 0:
aiterator = step.astream(input, config, **kwargs)
else:
aiterator = step.atransform(aiterator, config)
if hasattr(aiterator, "aclose"):
stack.push_async_callback(aiterator.aclose)
# populates streamed_output in astream_log() output if needed
if _StreamingCallbackHandler is not None:
for h in run_manager.handlers:
if isinstance(h, _StreamingCallbackHandler):
aiterator = h.tap_output_aiter(
run_manager.run_id, aiterator
)
# consume into final output
output = await _consume_aiter(aiterator)
# sequence doesn't emit output, yield to mark as generator
yield
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(output)
def _consume_iter(it: Iterator[Any]) -> Any:
"""Consume an iterator."""
output: Any = None
add_supported = False
for chunk in it:
# collect final output
if output is None:
output = chunk
elif add_supported:
try:
output = output + chunk
except TypeError:
output = chunk
add_supported = False
else:
output = chunk
return output
async def _consume_aiter(it: AsyncIterator[Any]) -> Any:
"""Consume an async iterator."""
output: Any = None
add_supported = False
async for chunk in it:
# collect final output
if add_supported:
try:
output = output + chunk
except TypeError:
output = chunk
add_supported = False
else:
output = chunk
return output

View File

@@ -0,0 +1,19 @@
import dataclasses
from collections.abc import Callable
from typing import Any
from langgraph.types import _DC_KWARGS
@dataclasses.dataclass(**_DC_KWARGS)
class PregelScratchpad:
step: int
stop: int
# call
call_counter: Callable[[], int]
# interrupt
interrupt_counter: Callable[[], int]
get_null_resume: Callable[[bool], Any]
resume: list[Any]
# subgraph
subgraph_counter: Callable[[], int]

View File

@@ -0,0 +1,253 @@
from __future__ import annotations
import dataclasses
import logging
import sys
import types
from collections import deque
from enum import Enum
from typing import (
Annotated,
Any,
Literal,
Union,
get_args,
get_origin,
get_type_hints,
)
from langchain_core import messages as lc_messages
from langgraph.checkpoint.base import BaseCheckpointSaver
from pydantic import BaseModel
from typing_extensions import NotRequired, Required, is_typeddict
try:
from langgraph.checkpoint.serde._msgpack import ( # noqa: F401
STRICT_MSGPACK_ENABLED,
)
except ImportError:
STRICT_MSGPACK_ENABLED = False
_warned_allowlist_unsupported = False
logger = logging.getLogger(__name__)
def _supports_checkpointer_allowlist() -> bool:
return hasattr(BaseCheckpointSaver, "with_allowlist")
_SUPPORTS_ALLOWLIST = _supports_checkpointer_allowlist()
def apply_checkpointer_allowlist(
checkpointer: Any, allowlist: set[tuple[str, ...]] | None
) -> Any:
if not checkpointer or allowlist is None or checkpointer in (True, False):
return checkpointer
if not _SUPPORTS_ALLOWLIST:
global _warned_allowlist_unsupported
if not _warned_allowlist_unsupported:
logger.warning(
"Checkpointer does not support with_allowlist; strict msgpack "
"allowlist will be skipped."
)
_warned_allowlist_unsupported = True
return checkpointer
return checkpointer.with_allowlist(allowlist)
def curated_core_allowlist() -> set[tuple[str, ...]]:
allowlist: set[tuple[str, ...]] = set()
for name in (
"BaseMessage",
"BaseMessageChunk",
"HumanMessage",
"HumanMessageChunk",
"AIMessage",
"AIMessageChunk",
"SystemMessage",
"SystemMessageChunk",
"ChatMessage",
"ChatMessageChunk",
"ToolMessage",
"ToolMessageChunk",
"FunctionMessage",
"FunctionMessageChunk",
"RemoveMessage",
):
cls = getattr(lc_messages, name, None)
if cls is None:
continue
allowlist.add((cls.__module__, cls.__name__))
return allowlist
def build_serde_allowlist(
*,
schemas: list[type[Any]] | None = None,
channels: dict[str, Any] | None = None,
) -> set[tuple[str, ...]]:
allowlist = curated_core_allowlist()
if schemas:
schemas = [schema for schema in schemas if schema is not None]
return allowlist | collect_allowlist_from_schemas(
schemas=schemas,
channels=channels,
)
def collect_allowlist_from_schemas(
*,
schemas: list[type[Any]] | None = None,
channels: dict[str, Any] | None = None,
) -> set[tuple[str, ...]]:
allowlist: set[tuple[str, ...]] = set()
seen: set[Any] = set()
seen_ids: set[int] = set()
if schemas:
for schema in schemas:
_collect_from_type(schema, allowlist, seen, seen_ids)
if channels:
for channel in channels.values():
value_type = getattr(channel, "ValueType", None)
if value_type is not None:
_collect_from_type(value_type, allowlist, seen, seen_ids)
update_type = getattr(channel, "UpdateType", None)
if update_type is not None:
_collect_from_type(update_type, allowlist, seen, seen_ids)
return allowlist
def _collect_from_type(
typ: Any,
allowlist: set[tuple[str, ...]],
seen: set[Any],
seen_ids: set[int],
) -> None:
if _already_seen(typ, seen, seen_ids):
return
if typ is Any or typ is None:
return
if typ is Literal:
return
if isinstance(typ, types.UnionType):
for arg in typ.__args__:
_collect_from_type(arg, allowlist, seen, seen_ids)
return
origin = get_origin(typ)
if origin is Union:
for arg in get_args(typ):
_collect_from_type(arg, allowlist, seen, seen_ids)
return
if origin is Annotated or origin in (Required, NotRequired):
args = get_args(typ)
if args:
_collect_from_type(args[0], allowlist, seen, seen_ids)
return
if origin is Literal:
return
if origin in (list, set, tuple, dict, deque, frozenset):
for arg in get_args(typ):
_collect_from_type(arg, allowlist, seen, seen_ids)
return
if hasattr(typ, "__supertype__"):
_collect_from_type(typ.__supertype__, allowlist, seen, seen_ids)
return
if is_typeddict(typ):
for field_type in _safe_get_type_hints(typ).values():
_collect_from_type(field_type, allowlist, seen, seen_ids)
return
if _is_pydantic_model(typ):
allowlist.add((typ.__module__, typ.__name__))
field_types = _safe_get_type_hints(typ)
if field_types:
for field_type in field_types.values():
_collect_from_type(field_type, allowlist, seen, seen_ids)
else:
for field_type in _pydantic_field_types(typ):
_collect_from_type(field_type, allowlist, seen, seen_ids)
return
if dataclasses.is_dataclass(typ):
if typ_name := getattr(typ, "__name__", None):
allowlist.add((typ.__module__, typ_name))
field_types = _safe_get_type_hints(typ)
if field_types:
for field_type in field_types.values():
_collect_from_type(field_type, allowlist, seen, seen_ids)
else:
for field in dataclasses.fields(typ):
_collect_from_type(field.type, allowlist, seen, seen_ids)
return
if isinstance(typ, type) and issubclass(typ, Enum):
allowlist.add((typ.__module__, typ.__name__))
return
def _already_seen(typ: Any, seen: set[Any], seen_ids: set[int]) -> bool:
try:
if typ in seen:
return True
seen.add(typ)
return False
except TypeError:
typ_id = id(typ)
if typ_id in seen_ids:
return True
seen_ids.add(typ_id)
return False
def _safe_get_type_hints(typ: Any) -> dict[str, Any]:
try:
module = sys.modules.get(getattr(typ, "__module__", ""))
globalns = module.__dict__ if module else None
localns = dict(vars(typ)) if hasattr(typ, "__dict__") else None
return get_type_hints(
typ, globalns=globalns, localns=localns, include_extras=True
)
except Exception:
return {}
def _is_pydantic_model(typ: Any) -> bool:
if not isinstance(typ, type):
return False
if issubclass(typ, BaseModel):
return True
try:
from pydantic.v1 import BaseModel as BaseModelV1
except Exception:
return False
return issubclass(typ, BaseModelV1)
def _pydantic_field_types(typ: type[Any]) -> list[Any]:
if hasattr(typ, "model_fields"):
return [
field.annotation
for field in typ.model_fields.values()
if getattr(field, "annotation", None) is not None
]
if hasattr(typ, "__fields__"):
return [
field.outer_type_
for field in typ.__fields__.values()
if getattr(field, "outer_type_", None) is not None
]
return []

View File

@@ -0,0 +1,54 @@
"""Private typing utilities for LangGraph."""
from __future__ import annotations
from dataclasses import Field
from typing import Any, ClassVar, Protocol, TypeAlias
from pydantic import BaseModel
from typing_extensions import TypedDict
class TypedDictLikeV1(Protocol):
"""Protocol to represent types that behave like TypedDicts
Version 1: using `ClassVar` for keys."""
__required_keys__: ClassVar[frozenset[str]]
__optional_keys__: ClassVar[frozenset[str]]
class TypedDictLikeV2(Protocol):
"""Protocol to represent types that behave like TypedDicts
Version 2: not using `ClassVar` for keys."""
__required_keys__: frozenset[str]
__optional_keys__: frozenset[str]
class DataclassLike(Protocol):
"""Protocol to represent types that behave like dataclasses.
Inspired by the private _DataclassT from dataclasses that uses a similar protocol as a bound."""
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
StateLike: TypeAlias = TypedDictLikeV1 | TypedDictLikeV2 | DataclassLike | BaseModel
"""Type alias for state-like types.
It can either be a `TypedDict`, `dataclass`, or Pydantic `BaseModel`.
Note: we cannot use either `TypedDict` or `dataclass` directly due to limitations in type checking.
"""
MISSING = object()
"""Unset sentinel value."""
class DeprecatedKwargs(TypedDict):
"""TypedDict to use for extra keyword arguments, enabling type checking warnings for deprecated arguments."""
EMPTY_SEQ: tuple[str, ...] = tuple()
"""An empty sequence of strings."""

View File

@@ -0,0 +1,48 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from typing import Generic, TypeVar
from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
ValueT = TypeVar("ValueT")
Namespace = tuple[str, ...]
FullKey = tuple[Namespace, str]
class BaseCache(ABC, Generic[ValueT]):
"""Base class for a cache."""
serde: SerializerProtocol = JsonPlusSerializer(pickle_fallback=False)
def __init__(self, *, serde: SerializerProtocol | None = None) -> None:
"""Initialize the cache with a serializer."""
self.serde = serde or self.serde
@abstractmethod
def get(self, keys: Sequence[FullKey]) -> dict[FullKey, ValueT]:
"""Get the cached values for the given keys."""
@abstractmethod
async def aget(self, keys: Sequence[FullKey]) -> dict[FullKey, ValueT]:
"""Asynchronously get the cached values for the given keys."""
@abstractmethod
def set(self, pairs: Mapping[FullKey, tuple[ValueT, int | None]]) -> None:
"""Set the cached values for the given keys and TTLs."""
@abstractmethod
async def aset(self, pairs: Mapping[FullKey, tuple[ValueT, int | None]]) -> None:
"""Asynchronously set the cached values for the given keys and TTLs."""
@abstractmethod
def clear(self, namespaces: Sequence[Namespace] | None = None) -> None:
"""Delete the cached values for the given namespaces.
If no namespaces are provided, clear all cached values."""
@abstractmethod
async def aclear(self, namespaces: Sequence[Namespace] | None = None) -> None:
"""Asynchronously delete the cached values for the given namespaces.
If no namespaces are provided, clear all cached values."""

View File

View File

@@ -0,0 +1,73 @@
from __future__ import annotations
import datetime
import threading
from collections.abc import Mapping, Sequence
from langgraph.cache.base import BaseCache, FullKey, Namespace, ValueT
from langgraph.checkpoint.serde.base import SerializerProtocol
class InMemoryCache(BaseCache[ValueT]):
def __init__(self, *, serde: SerializerProtocol | None = None):
super().__init__(serde=serde)
self._cache: dict[Namespace, dict[str, tuple[str, bytes, float | None]]] = {}
self._lock = threading.RLock()
def get(self, keys: Sequence[FullKey]) -> dict[FullKey, ValueT]:
"""Get the cached values for the given keys."""
with self._lock:
if not keys:
return {}
now = datetime.datetime.now(datetime.timezone.utc).timestamp()
values: dict[FullKey, ValueT] = {}
for ns_tuple, key in keys:
ns = Namespace(ns_tuple)
if ns in self._cache and key in self._cache[ns]:
enc, val, expiry = self._cache[ns][key]
if expiry is None or now < expiry:
values[(ns, key)] = self.serde.loads_typed((enc, val))
else:
del self._cache[ns][key]
return values
async def aget(self, keys: Sequence[FullKey]) -> dict[FullKey, ValueT]:
"""Asynchronously get the cached values for the given keys."""
return self.get(keys)
def set(self, keys: Mapping[FullKey, tuple[ValueT, int | None]]) -> None:
"""Set the cached values for the given keys."""
with self._lock:
now = datetime.datetime.now(datetime.timezone.utc)
for (ns, key), (value, ttl) in keys.items():
if ttl is not None:
delta = datetime.timedelta(seconds=ttl)
expiry: float | None = (now + delta).timestamp()
else:
expiry = None
if ns not in self._cache:
self._cache[ns] = {}
self._cache[ns][key] = (
*self.serde.dumps_typed(value),
expiry,
)
async def aset(self, keys: Mapping[FullKey, tuple[ValueT, int | None]]) -> None:
"""Asynchronously set the cached values for the given keys."""
self.set(keys)
def clear(self, namespaces: Sequence[Namespace] | None = None) -> None:
"""Delete the cached values for the given namespaces.
If no namespaces are provided, clear all cached values."""
with self._lock:
if namespaces is None:
self._cache.clear()
else:
for ns in namespaces:
if ns in self._cache:
del self._cache[ns]
async def aclear(self, namespaces: Sequence[Namespace] | None = None) -> None:
"""Asynchronously delete the cached values for the given namespaces.
If no namespaces are provided, clear all cached values."""
self.clear(namespaces)

View File

@@ -0,0 +1,144 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any
from langgraph.cache.base import BaseCache, FullKey, Namespace, ValueT
from langgraph.checkpoint.serde.base import SerializerProtocol
class RedisCache(BaseCache[ValueT]):
"""Redis-based cache implementation with TTL support."""
def __init__(
self,
redis: Any,
*,
serde: SerializerProtocol | None = None,
prefix: str = "langgraph:cache:",
) -> None:
"""Initialize the cache with a Redis client.
Args:
redis: Redis client instance (sync or async)
serde: Serializer to use for values
prefix: Key prefix for all cached values
"""
super().__init__(serde=serde)
self.redis = redis
self.prefix = prefix
def _make_key(self, ns: Namespace, key: str) -> str:
"""Create a Redis key from namespace and key."""
ns_str = ":".join(ns) if ns else ""
return f"{self.prefix}{ns_str}:{key}" if ns_str else f"{self.prefix}{key}"
def _parse_key(self, redis_key: str) -> tuple[Namespace, str]:
"""Parse a Redis key back to namespace and key."""
if not redis_key.startswith(self.prefix):
raise ValueError(
f"Key {redis_key} does not start with prefix {self.prefix}"
)
remaining = redis_key[len(self.prefix) :]
if ":" in remaining:
parts = remaining.split(":")
key = parts[-1]
ns_parts = parts[:-1]
return (tuple(ns_parts), key)
else:
return (tuple(), remaining)
def get(self, keys: Sequence[FullKey]) -> dict[FullKey, ValueT]:
"""Get the cached values for the given keys."""
if not keys:
return {}
# Build Redis keys
redis_keys = [self._make_key(ns, key) for ns, key in keys]
# Get values from Redis using MGET
try:
raw_values = self.redis.mget(redis_keys)
except Exception:
# If Redis is unavailable, return empty dict
return {}
values: dict[FullKey, ValueT] = {}
for i, raw_value in enumerate(raw_values):
if raw_value is not None:
try:
# Deserialize the value
encoding, data = raw_value.split(b":", 1)
values[keys[i]] = self.serde.loads_typed((encoding.decode(), data))
except Exception:
# Skip corrupted entries
continue
return values
async def aget(self, keys: Sequence[FullKey]) -> dict[FullKey, ValueT]:
"""Asynchronously get the cached values for the given keys."""
return self.get(keys)
def set(self, mapping: Mapping[FullKey, tuple[ValueT, int | None]]) -> None:
"""Set the cached values for the given keys and TTLs."""
if not mapping:
return
# Use pipeline for efficient batch operations
pipe = self.redis.pipeline()
for (ns, key), (value, ttl) in mapping.items():
redis_key = self._make_key(ns, key)
encoding, data = self.serde.dumps_typed(value)
# Store as "encoding:data" format
serialized_value = f"{encoding}:".encode() + data
if ttl is not None:
pipe.setex(redis_key, ttl, serialized_value)
else:
pipe.set(redis_key, serialized_value)
try:
pipe.execute()
except Exception:
# Silently fail if Redis is unavailable
pass
async def aset(self, mapping: Mapping[FullKey, tuple[ValueT, int | None]]) -> None:
"""Asynchronously set the cached values for the given keys and TTLs."""
self.set(mapping)
def clear(self, namespaces: Sequence[Namespace] | None = None) -> None:
"""Delete the cached values for the given namespaces.
If no namespaces are provided, clear all cached values."""
try:
if namespaces is None:
# Clear all keys with our prefix
pattern = f"{self.prefix}*"
keys = self.redis.keys(pattern)
if keys:
self.redis.delete(*keys)
else:
# Clear specific namespaces
keys_to_delete = []
for ns in namespaces:
ns_str = ":".join(ns) if ns else ""
pattern = (
f"{self.prefix}{ns_str}:*" if ns_str else f"{self.prefix}*"
)
keys = self.redis.keys(pattern)
keys_to_delete.extend(keys)
if keys_to_delete:
self.redis.delete(*keys_to_delete)
except Exception:
# Silently fail if Redis is unavailable
pass
async def aclear(self, namespaces: Sequence[Namespace] | None = None) -> None:
"""Asynchronously delete the cached values for the given namespaces.
If no namespaces are provided, clear all cached values."""
self.clear(namespaces)

View File

@@ -0,0 +1,27 @@
from langgraph.channels.any_value import AnyValue
from langgraph.channels.base import BaseChannel
from langgraph.channels.binop import BinaryOperatorAggregate
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.channels.last_value import LastValue, LastValueAfterFinish
from langgraph.channels.named_barrier_value import (
NamedBarrierValue,
NamedBarrierValueAfterFinish,
)
from langgraph.channels.topic import Topic
from langgraph.channels.untracked_value import UntrackedValue
__all__ = (
# base
"BaseChannel",
# value types
"AnyValue",
"LastValue",
"LastValueAfterFinish",
"UntrackedValue",
"EphemeralValue",
"BinaryOperatorAggregate",
"NamedBarrierValue",
"NamedBarrierValueAfterFinish",
# topics
"Topic",
)

View File

@@ -0,0 +1,72 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Generic
from typing_extensions import Self
from langgraph._internal._typing import MISSING
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import EmptyChannelError
__all__ = ("AnyValue",)
class AnyValue(Generic[Value], BaseChannel[Value, Value, Value]):
"""Stores the last value received, assumes that if multiple values are
received, they are all equal."""
__slots__ = ("typ", "value")
value: Value | Any
def __init__(self, typ: Any, key: str = "") -> None:
super().__init__(typ, key)
self.value = MISSING
def __eq__(self, value: object) -> bool:
return isinstance(value, AnyValue)
@property
def ValueType(self) -> type[Value]:
"""The type of the value stored in the channel."""
return self.typ
@property
def UpdateType(self) -> type[Value]:
"""The type of the update received by the channel."""
return self.typ
def copy(self) -> Self:
"""Return a copy of the channel."""
empty = self.__class__(self.typ, self.key)
empty.value = self.value
return empty
def from_checkpoint(self, checkpoint: Value) -> Self:
empty = self.__class__(self.typ, self.key)
if checkpoint is not MISSING:
empty.value = checkpoint
return empty
def update(self, values: Sequence[Value]) -> bool:
if len(values) == 0:
if self.value is MISSING:
return False
else:
self.value = MISSING
return True
self.value = values[-1]
return True
def get(self) -> Value:
if self.value is MISSING:
raise EmptyChannelError()
return self.value
def is_available(self) -> bool:
return self.value is not MISSING
def checkpoint(self) -> Value:
return self.value

View File

@@ -0,0 +1,121 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Generic, TypeVar
from typing_extensions import Self
from langgraph._internal._typing import MISSING
from langgraph.errors import EmptyChannelError
Value = TypeVar("Value")
Update = TypeVar("Update")
Checkpoint = TypeVar("Checkpoint")
__all__ = ("BaseChannel",)
class BaseChannel(Generic[Value, Update, Checkpoint], ABC):
"""Base class for all channels."""
__slots__ = ("key", "typ")
def __init__(self, typ: Any, key: str = "") -> None:
self.typ = typ
self.key = key
@property
@abstractmethod
def ValueType(self) -> Any:
"""The type of the value stored in the channel."""
@property
@abstractmethod
def UpdateType(self) -> Any:
"""The type of the update received by the channel."""
# serialize/deserialize methods
def copy(self) -> Self:
"""Return a copy of the channel.
By default, delegates to `checkpoint()` and `from_checkpoint()`.
Subclasses can override this method with a more efficient implementation.
"""
return self.from_checkpoint(self.checkpoint())
def checkpoint(self) -> Checkpoint | Any:
"""Return a serializable representation of the channel's current state.
Raises `EmptyChannelError` if the channel is empty (never updated yet),
or doesn't support checkpoints.
"""
try:
return self.get()
except EmptyChannelError:
return MISSING
@abstractmethod
def from_checkpoint(self, checkpoint: Checkpoint | Any) -> Self:
"""Return a new identical channel, optionally initialized from a checkpoint.
If the checkpoint contains complex data structures, they should be copied.
"""
# read methods
@abstractmethod
def get(self) -> Value:
"""Return the current value of the channel.
Raises `EmptyChannelError` if the channel is empty (never updated yet)."""
def is_available(self) -> bool:
"""Return `True` if the channel is available (not empty), `False` otherwise.
Subclasses should override this method to provide a more efficient
implementation than calling `get()` and catching `EmptyChannelError`.
"""
try:
self.get()
return True
except EmptyChannelError:
return False
# write methods
@abstractmethod
def update(self, values: Sequence[Update]) -> bool:
"""Update the channel's value with the given sequence of updates.
The order of the updates in the sequence is arbitrary.
This method is called by Pregel for all channels at the end of each step.
If there are no updates, it is called with an empty sequence.
Raises `InvalidUpdateError` if the sequence of updates is invalid.
Returns `True` if the channel was updated, `False` otherwise."""
def consume(self) -> bool:
"""Notify the channel that a subscribed task ran.
By default, no-op.
A channel can use this method to modify its state, preventing the value from being consumed again.
Returns `True` if the channel was updated, `False` otherwise.
"""
return False
def finish(self) -> bool:
"""Notify the channel that the Pregel run is finishing.
By default, no-op.
A channel can use this method to modify its state, preventing finish.
Returns `True` if the channel was updated, `False` otherwise.
"""
return False

View File

@@ -0,0 +1,134 @@
import collections.abc
from collections.abc import Callable, Sequence
from typing import Any, Generic
from typing_extensions import NotRequired, Required, Self
from langgraph._internal._constants import OVERWRITE
from langgraph._internal._typing import MISSING
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import (
EmptyChannelError,
ErrorCode,
InvalidUpdateError,
create_error_message,
)
from langgraph.types import Overwrite
__all__ = ("BinaryOperatorAggregate",)
# Adapted from typing_extensions
def _strip_extras(t): # type: ignore[no-untyped-def]
"""Strips Annotated, Required and NotRequired from a given type."""
if hasattr(t, "__origin__"):
return _strip_extras(t.__origin__)
if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired):
return _strip_extras(t.__args__[0])
return t
def _get_overwrite(value: Any) -> tuple[bool, Any]:
"""Inspects the given value and returns (is_overwrite, overwrite_value)."""
if isinstance(value, Overwrite):
return True, value.value
if isinstance(value, dict) and set(value.keys()) == {OVERWRITE}:
return True, value[OVERWRITE]
return False, None
class BinaryOperatorAggregate(Generic[Value], BaseChannel[Value, Value, Value]):
"""Stores the result of applying a binary operator to the current value and each new value.
```python
import operator
total = Channels.BinaryOperatorAggregate(int, operator.add)
```
"""
__slots__ = ("value", "operator")
def __init__(self, typ: type[Value], operator: Callable[[Value, Value], Value]):
super().__init__(typ)
self.operator = operator
# special forms from typing or collections.abc are not instantiable
# so we need to replace them with their concrete counterparts
typ = _strip_extras(typ)
if typ in (collections.abc.Sequence, collections.abc.MutableSequence):
typ = list
if typ in (collections.abc.Set, collections.abc.MutableSet):
typ = set
if typ in (collections.abc.Mapping, collections.abc.MutableMapping):
typ = dict
try:
self.value = typ()
except Exception:
self.value = MISSING
def __eq__(self, value: object) -> bool:
return isinstance(value, BinaryOperatorAggregate) and (
value.operator is self.operator
if value.operator.__name__ != "<lambda>"
and self.operator.__name__ != "<lambda>"
else True
)
@property
def ValueType(self) -> type[Value]:
"""The type of the value stored in the channel."""
return self.typ
@property
def UpdateType(self) -> type[Value]:
"""The type of the update received by the channel."""
return self.typ
def copy(self) -> Self:
"""Return a copy of the channel."""
empty = self.__class__(self.typ, self.operator)
empty.key = self.key
empty.value = self.value
return empty
def from_checkpoint(self, checkpoint: Value) -> Self:
empty = self.__class__(self.typ, self.operator)
empty.key = self.key
if checkpoint is not MISSING:
empty.value = checkpoint
return empty
def update(self, values: Sequence[Value]) -> bool:
if not values:
return False
if self.value is MISSING:
self.value = values[0]
values = values[1:]
seen_overwrite: bool = False
for value in values:
is_overwrite, overwrite_value = _get_overwrite(value)
if is_overwrite:
if seen_overwrite:
msg = create_error_message(
message="Can receive only one Overwrite value per super-step.",
error_code=ErrorCode.INVALID_CONCURRENT_GRAPH_UPDATE,
)
raise InvalidUpdateError(msg)
self.value = overwrite_value
seen_overwrite = True
continue
if not seen_overwrite:
self.value = self.operator(self.value, value)
return True
def get(self) -> Value:
if self.value is MISSING:
raise EmptyChannelError()
return self.value
def is_available(self) -> bool:
return self.value is not MISSING
def checkpoint(self) -> Value:
return self.value

View File

@@ -0,0 +1,79 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Generic
from typing_extensions import Self
from langgraph._internal._typing import MISSING
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import EmptyChannelError, InvalidUpdateError
__all__ = ("EphemeralValue",)
class EphemeralValue(Generic[Value], BaseChannel[Value, Value, Value]):
"""Stores the value received in the step immediately preceding, clears after."""
__slots__ = ("value", "guard")
value: Value | Any
guard: bool
def __init__(self, typ: Any, guard: bool = True) -> None:
super().__init__(typ)
self.guard = guard
self.value = MISSING
def __eq__(self, value: object) -> bool:
return isinstance(value, EphemeralValue) and value.guard == self.guard
@property
def ValueType(self) -> type[Value]:
"""The type of the value stored in the channel."""
return self.typ
@property
def UpdateType(self) -> type[Value]:
"""The type of the update received by the channel."""
return self.typ
def copy(self) -> Self:
"""Return a copy of the channel."""
empty = self.__class__(self.typ, self.guard)
empty.key = self.key
empty.value = self.value
return empty
def from_checkpoint(self, checkpoint: Value) -> Self:
empty = self.__class__(self.typ, self.guard)
empty.key = self.key
if checkpoint is not MISSING:
empty.value = checkpoint
return empty
def update(self, values: Sequence[Value]) -> bool:
if len(values) == 0:
if self.value is not MISSING:
self.value = MISSING
return True
else:
return False
if len(values) != 1 and self.guard:
raise InvalidUpdateError(
f"At key '{self.key}': EphemeralValue(guard=True) can receive only one value per step. Use guard=False if you want to store any one of multiple values."
)
self.value = values[-1]
return True
def get(self) -> Value:
if self.value is MISSING:
raise EmptyChannelError()
return self.value
def is_available(self) -> bool:
return self.value is not MISSING
def checkpoint(self) -> Value:
return self.value

View File

@@ -0,0 +1,151 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Generic
from typing_extensions import Self
from langgraph._internal._typing import MISSING
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import (
EmptyChannelError,
ErrorCode,
InvalidUpdateError,
create_error_message,
)
__all__ = ("LastValue", "LastValueAfterFinish")
class LastValue(Generic[Value], BaseChannel[Value, Value, Value]):
"""Stores the last value received, can receive at most one value per step."""
__slots__ = ("value",)
value: Value | Any
def __init__(self, typ: Any, key: str = "") -> None:
super().__init__(typ, key)
self.value = MISSING
def __eq__(self, value: object) -> bool:
return isinstance(value, LastValue)
@property
def ValueType(self) -> type[Value]:
"""The type of the value stored in the channel."""
return self.typ
@property
def UpdateType(self) -> type[Value]:
"""The type of the update received by the channel."""
return self.typ
def copy(self) -> Self:
"""Return a copy of the channel."""
empty = self.__class__(self.typ, self.key)
empty.value = self.value
return empty
def from_checkpoint(self, checkpoint: Value) -> Self:
empty = self.__class__(self.typ, self.key)
if checkpoint is not MISSING:
empty.value = checkpoint
return empty
def update(self, values: Sequence[Value]) -> bool:
if len(values) == 0:
return False
if len(values) != 1:
msg = create_error_message(
message=f"At key '{self.key}': Can receive only one value per step. Use an Annotated key to handle multiple values.",
error_code=ErrorCode.INVALID_CONCURRENT_GRAPH_UPDATE,
)
raise InvalidUpdateError(msg)
self.value = values[-1]
return True
def get(self) -> Value:
if self.value is MISSING:
raise EmptyChannelError()
return self.value
def is_available(self) -> bool:
return self.value is not MISSING
def checkpoint(self) -> Value:
return self.value
class LastValueAfterFinish(
Generic[Value], BaseChannel[Value, Value, tuple[Value, bool]]
):
"""Stores the last value received, but only made available after finish().
Once made available, clears the value."""
__slots__ = ("value", "finished")
value: Value | Any
finished: bool
def __init__(self, typ: Any, key: str = "") -> None:
super().__init__(typ, key)
self.value = MISSING
self.finished = False
def __eq__(self, value: object) -> bool:
return isinstance(value, LastValueAfterFinish)
@property
def ValueType(self) -> type[Value]:
"""The type of the value stored in the channel."""
return self.typ
@property
def UpdateType(self) -> type[Value]:
"""The type of the update received by the channel."""
return self.typ
def checkpoint(self) -> tuple[Value | Any, bool] | Any:
if self.value is MISSING:
return MISSING
return (self.value, self.finished)
def from_checkpoint(self, checkpoint: tuple[Value | Any, bool] | Any) -> Self:
empty = self.__class__(self.typ)
empty.key = self.key
if checkpoint is not MISSING:
empty.value, empty.finished = checkpoint
return empty
def update(self, values: Sequence[Value | Any]) -> bool:
if len(values) == 0:
return False
self.finished = False
self.value = values[-1]
return True
def consume(self) -> bool:
if self.finished:
self.finished = False
self.value = MISSING
return True
return False
def finish(self) -> bool:
if not self.finished and self.value is not MISSING:
self.finished = True
return True
else:
return False
def get(self) -> Value:
if self.value is MISSING or not self.finished:
raise EmptyChannelError()
return self.value
def is_available(self) -> bool:
return self.value is not MISSING and self.finished

View File

@@ -0,0 +1,167 @@
from collections.abc import Sequence
from typing import Generic
from typing_extensions import Self
from langgraph._internal._typing import MISSING
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import EmptyChannelError, InvalidUpdateError
__all__ = ("NamedBarrierValue", "NamedBarrierValueAfterFinish")
class NamedBarrierValue(Generic[Value], BaseChannel[Value, Value, set[Value]]):
"""A channel that waits until all named values are received before making the value available."""
__slots__ = ("names", "seen")
names: set[Value]
seen: set[Value]
def __init__(self, typ: type[Value], names: set[Value]) -> None:
super().__init__(typ)
self.names = names
self.seen: set[str] = set()
def __eq__(self, value: object) -> bool:
return isinstance(value, NamedBarrierValue) and value.names == self.names
@property
def ValueType(self) -> type[Value]:
"""The type of the value stored in the channel."""
return self.typ
@property
def UpdateType(self) -> type[Value]:
"""The type of the update received by the channel."""
return self.typ
def copy(self) -> Self:
"""Return a copy of the channel."""
empty = self.__class__(self.typ, self.names)
empty.key = self.key
empty.seen = self.seen.copy()
return empty
def checkpoint(self) -> set[Value]:
return self.seen
def from_checkpoint(self, checkpoint: set[Value]) -> Self:
empty = self.__class__(self.typ, self.names)
empty.key = self.key
if checkpoint is not MISSING:
empty.seen = checkpoint
return empty
def update(self, values: Sequence[Value]) -> bool:
updated = False
for value in values:
if value in self.names:
if value not in self.seen:
self.seen.add(value)
updated = True
else:
raise InvalidUpdateError(
f"At key '{self.key}': Value {value} not in {self.names}"
)
return updated
def get(self) -> Value:
if self.seen != self.names:
raise EmptyChannelError()
return None
def is_available(self) -> bool:
return self.seen == self.names
def consume(self) -> bool:
if self.seen == self.names:
self.seen = set()
return True
return False
class NamedBarrierValueAfterFinish(
Generic[Value], BaseChannel[Value, Value, set[Value]]
):
"""A channel that waits until all named values are received before making the value ready to be made available. It is only made available after finish() is called."""
__slots__ = ("names", "seen", "finished")
names: set[Value]
seen: set[Value]
def __init__(self, typ: type[Value], names: set[Value]) -> None:
super().__init__(typ)
self.names = names
self.seen: set[str] = set()
self.finished = False
def __eq__(self, value: object) -> bool:
return (
isinstance(value, NamedBarrierValueAfterFinish)
and value.names == self.names
)
@property
def ValueType(self) -> type[Value]:
"""The type of the value stored in the channel."""
return self.typ
@property
def UpdateType(self) -> type[Value]:
"""The type of the update received by the channel."""
return self.typ
def copy(self) -> Self:
"""Return a copy of the channel."""
empty = self.__class__(self.typ, self.names)
empty.key = self.key
empty.seen = self.seen.copy()
empty.finished = self.finished
return empty
def checkpoint(self) -> tuple[set[Value], bool]:
return (self.seen, self.finished)
def from_checkpoint(self, checkpoint: tuple[set[Value], bool]) -> Self:
empty = self.__class__(self.typ, self.names)
empty.key = self.key
if checkpoint is not MISSING:
empty.seen, empty.finished = checkpoint
return empty
def update(self, values: Sequence[Value]) -> bool:
updated = False
for value in values:
if value in self.names:
if value not in self.seen:
self.seen.add(value)
updated = True
else:
raise InvalidUpdateError(
f"At key '{self.key}': Value {value} not in {self.names}"
)
return updated
def get(self) -> Value:
if not self.finished or self.seen != self.names:
raise EmptyChannelError()
return None
def is_available(self) -> bool:
return self.finished and self.seen == self.names
def consume(self) -> bool:
if self.finished and self.seen == self.names:
self.finished = False
self.seen = set()
return True
return False
def finish(self) -> bool:
if not self.finished and self.seen == self.names:
self.finished = True
return True
else:
return False

View File

@@ -0,0 +1,94 @@
from __future__ import annotations
from collections.abc import Iterator, Sequence
from typing import Any, Generic
from typing_extensions import Self
from langgraph._internal._typing import MISSING
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import EmptyChannelError
__all__ = ("Topic",)
def _flatten(values: Sequence[Value | list[Value]]) -> Iterator[Value]:
for value in values:
if isinstance(value, list):
yield from value
else:
yield value
class Topic(
Generic[Value],
BaseChannel[Sequence[Value], Value | list[Value], list[Value]],
):
"""A configurable PubSub Topic.
Args:
typ: The type of the value stored in the channel.
accumulate: Whether to accumulate values across steps. If `False`, the channel will be emptied after each step.
"""
__slots__ = ("values", "accumulate")
def __init__(self, typ: type[Value], accumulate: bool = False) -> None:
super().__init__(typ)
# attrs
self.accumulate = accumulate
# state
self.values = list[Value]()
def __eq__(self, value: object) -> bool:
return isinstance(value, Topic) and value.accumulate == self.accumulate
@property
def ValueType(self) -> Any:
"""The type of the value stored in the channel."""
return Sequence[self.typ] # type: ignore[name-defined]
@property
def UpdateType(self) -> Any:
"""The type of the update received by the channel."""
return self.typ | list[self.typ] # type: ignore[name-defined]
def copy(self) -> Self:
"""Return a copy of the channel."""
empty = self.__class__(self.typ, self.accumulate)
empty.key = self.key
empty.values = self.values.copy()
return empty
def checkpoint(self) -> list[Value]:
return self.values
def from_checkpoint(self, checkpoint: list[Value]) -> Self:
empty = self.__class__(self.typ, self.accumulate)
empty.key = self.key
if checkpoint is not MISSING:
if isinstance(checkpoint, tuple):
# backwards compatibility
empty.values = checkpoint[1]
else:
empty.values = checkpoint
return empty
def update(self, values: Sequence[Value | list[Value]]) -> bool:
updated = False
if not self.accumulate:
updated = bool(self.values)
self.values = list[Value]()
if flat_values := tuple(_flatten(values)):
updated = True
self.values.extend(flat_values)
return updated
def get(self) -> Sequence[Value]:
if self.values:
return list(self.values)
else:
raise EmptyChannelError
def is_available(self) -> bool:
return bool(self.values)

View File

@@ -0,0 +1,73 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Generic
from typing_extensions import Self
from langgraph._internal._typing import MISSING
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import EmptyChannelError, InvalidUpdateError
__all__ = ("UntrackedValue",)
class UntrackedValue(Generic[Value], BaseChannel[Value, Value, Value]):
"""Stores the last value received, never checkpointed."""
__slots__ = ("value", "guard")
guard: bool
value: Value | Any
def __init__(self, typ: type[Value], guard: bool = True) -> None:
super().__init__(typ)
self.guard = guard
self.value = MISSING
def __eq__(self, value: object) -> bool:
return isinstance(value, UntrackedValue) and value.guard == self.guard
@property
def ValueType(self) -> type[Value]:
"""The type of the value stored in the channel."""
return self.typ
@property
def UpdateType(self) -> type[Value]:
"""The type of the update received by the channel."""
return self.typ
def copy(self) -> Self:
"""Return a copy of the channel."""
empty = self.__class__(self.typ, self.guard)
empty.key = self.key
empty.value = self.value
return empty
def checkpoint(self) -> Value | Any:
return MISSING
def from_checkpoint(self, checkpoint: Value) -> Self:
empty = self.__class__(self.typ, self.guard)
empty.key = self.key
return empty
def update(self, values: Sequence[Value]) -> bool:
if len(values) == 0:
return False
if len(values) != 1 and self.guard:
raise InvalidUpdateError(
f"At key '{self.key}': UntrackedValue(guard=True) can receive only one value per step. Use guard=False if you want to store any one of multiple values."
)
self.value = values[-1]
return True
def get(self) -> Value:
if self.value is MISSING:
raise EmptyChannelError()
return self.value
def is_available(self) -> bool:
return self.value is not MISSING

View File

@@ -0,0 +1,628 @@
from __future__ import annotations
import copy
import logging
from collections.abc import AsyncIterator, Collection, Iterator, Mapping, Sequence
from typing import ( # noqa: UP035
Any,
Generic,
Literal,
NamedTuple,
TypedDict,
TypeVar,
)
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base.id import uuid6
from langgraph.checkpoint.serde.base import SerializerProtocol, maybe_add_typed_methods
from langgraph.checkpoint.serde.encrypted import EncryptedSerializer
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import (
ERROR,
INTERRUPT,
RESUME,
SCHEDULED,
ChannelProtocol,
)
V = TypeVar("V", int, float, str)
PendingWrite = tuple[str, str, Any]
logger = logging.getLogger(__name__)
# Marked as total=False to allow for future expansion.
class CheckpointMetadata(TypedDict, total=False):
"""Metadata associated with a checkpoint."""
source: Literal["input", "loop", "update", "fork"]
"""The source of the checkpoint.
- `"input"`: The checkpoint was created from an input to invoke/stream/batch.
- `"loop"`: The checkpoint was created from inside the pregel loop.
- `"update"`: The checkpoint was created from a manual state update.
- `"fork"`: The checkpoint was created as a copy of another checkpoint.
"""
step: int
"""The step number of the checkpoint.
`-1` for the first `"input"` checkpoint.
`0` for the first `"loop"` checkpoint.
`...` for the `nth` checkpoint afterwards.
"""
parents: dict[str, str]
"""The IDs of the parent checkpoints.
Mapping from checkpoint namespace to checkpoint ID.
"""
run_id: str
"""The ID of the run that created this checkpoint."""
ChannelVersions = dict[str, str | int | float]
class Checkpoint(TypedDict):
"""State snapshot at a given point in time."""
v: int
"""The version of the checkpoint format. Currently `1`."""
id: str
"""The ID of the checkpoint.
This is both unique and monotonically increasing, so can be used for sorting
checkpoints from first to last."""
ts: str
"""The timestamp of the checkpoint in ISO 8601 format."""
channel_values: dict[str, Any]
"""The values of the channels at the time of the checkpoint.
Mapping from channel name to deserialized channel snapshot value.
"""
channel_versions: ChannelVersions
"""The versions of the channels at the time of the checkpoint.
The keys are channel names and the values are monotonically increasing
version strings for each channel.
"""
versions_seen: dict[str, ChannelVersions]
"""Map from node ID to map from channel name to version seen.
This keeps track of the versions of the channels that each node has seen.
Used to determine which nodes to execute next.
"""
updated_channels: list[str] | None
"""The channels that were updated in this checkpoint.
"""
def copy_checkpoint(checkpoint: Checkpoint) -> Checkpoint:
return Checkpoint(
v=checkpoint["v"],
ts=checkpoint["ts"],
id=checkpoint["id"],
channel_values=checkpoint["channel_values"].copy(),
channel_versions=checkpoint["channel_versions"].copy(),
versions_seen={k: v.copy() for k, v in checkpoint["versions_seen"].items()},
pending_sends=checkpoint.get("pending_sends", []).copy(),
updated_channels=checkpoint.get("updated_channels", None),
)
class CheckpointTuple(NamedTuple):
"""A tuple containing a checkpoint and its associated data."""
config: RunnableConfig
checkpoint: Checkpoint
metadata: CheckpointMetadata
parent_config: RunnableConfig | None = None
pending_writes: list[PendingWrite] | None = None
class BaseCheckpointSaver(Generic[V]):
"""Base class for creating a graph checkpointer.
Checkpointers allow LangGraph agents to persist their state
within and across multiple interactions.
When a checkpointer is configured, you should pass a `thread_id` in the config when
invoking the graph:
```python
config = {"configurable": {"thread_id": "my-thread"}}
graph.invoke(inputs, config)
```
The `thread_id` is the primary key used to store and retrieve checkpoints. Without
it, the checkpointer cannot save state, resume from interrupts, or enable
time-travel debugging.
How you choose ``thread_id`` depends on your use case:
- **Single-shot workflows**: Use a unique ID (e.g., uuid4) for each run when
executions are independent.
- **Conversational memory**: Reuse the same `thread_id` across invocations
to accumulate state (e.g., chat history) within a conversation.
Attributes:
serde (SerializerProtocol): Serializer for encoding/decoding checkpoints.
Note:
When creating a custom checkpoint saver, consider implementing async
versions to avoid blocking the main thread.
"""
serde: SerializerProtocol = JsonPlusSerializer()
def __init__(
self,
*,
serde: SerializerProtocol | None = None,
) -> None:
self.serde = maybe_add_typed_methods(serde or self.serde)
@property
def config_specs(self) -> list:
"""Define the configuration options for the checkpoint saver.
Returns:
list: List of configuration field specs.
"""
return []
def get(self, config: RunnableConfig) -> Checkpoint | None:
"""Fetch a checkpoint using the given configuration.
Args:
config: Configuration specifying which checkpoint to retrieve.
Returns:
The requested checkpoint, or `None` if not found.
"""
if value := self.get_tuple(config):
return value.checkpoint
def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
"""Fetch a checkpoint tuple using the given configuration.
Args:
config: Configuration specifying which checkpoint to retrieve.
Returns:
The requested checkpoint tuple, or `None` if not found.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
def list(
self,
config: RunnableConfig | None,
*,
filter: dict[str, Any] | None = None,
before: RunnableConfig | None = None,
limit: int | None = None,
) -> Iterator[CheckpointTuple]:
"""List checkpoints that match the given criteria.
Args:
config: Base configuration for filtering checkpoints.
filter: Additional filtering criteria.
before: List checkpoints created before this configuration.
limit: Maximum number of checkpoints to return.
Returns:
Iterator of matching checkpoint tuples.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Store a checkpoint with its configuration and metadata.
Args:
config: Configuration for the checkpoint.
checkpoint: The checkpoint to store.
metadata: Additional metadata for the checkpoint.
new_versions: New channel versions as of this write.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
def put_writes(
self,
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
task_path: str = "",
) -> None:
"""Store intermediate writes linked to a checkpoint.
Args:
config: Configuration of the related checkpoint.
writes: List of writes to store.
task_id: Identifier for the task creating the writes.
task_path: Path of the task creating the writes.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
def delete_thread(
self,
thread_id: str,
) -> None:
"""Delete all checkpoints and writes associated with a specific thread ID.
Args:
thread_id: The thread ID whose checkpoints should be deleted.
"""
raise NotImplementedError
def delete_for_runs(
self,
run_ids: Sequence[str],
) -> None:
"""Delete all checkpoints and writes associated with the given run IDs.
Args:
run_ids: The run IDs whose checkpoints should be deleted.
"""
raise NotImplementedError
def copy_thread(
self,
source_thread_id: str,
target_thread_id: str,
) -> None:
"""Copy all checkpoints and writes from one thread to another.
Args:
source_thread_id: The thread ID to copy from.
target_thread_id: The thread ID to copy to.
"""
raise NotImplementedError
def prune(
self,
thread_ids: Sequence[str],
*,
strategy: str = "keep_latest",
) -> None:
"""Prune checkpoints for the given threads.
Args:
thread_ids: The thread IDs to prune.
strategy: The pruning strategy. `"keep_latest"` retains only the most
recent checkpoint per namespace. `"delete"` removes all checkpoints.
"""
raise NotImplementedError
async def aget(self, config: RunnableConfig) -> Checkpoint | None:
"""Asynchronously fetch a checkpoint using the given configuration.
Args:
config: Configuration specifying which checkpoint to retrieve.
Returns:
The requested checkpoint, or `None` if not found.
"""
if value := await self.aget_tuple(config):
return value.checkpoint
async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
"""Asynchronously fetch a checkpoint tuple using the given configuration.
Args:
config: Configuration specifying which checkpoint to retrieve.
Returns:
The requested checkpoint tuple, or `None` if not found.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
async def alist(
self,
config: RunnableConfig | None,
*,
filter: dict[str, Any] | None = None,
before: RunnableConfig | None = None,
limit: int | None = None,
) -> AsyncIterator[CheckpointTuple]:
"""Asynchronously list checkpoints that match the given criteria.
Args:
config: Base configuration for filtering checkpoints.
filter: Additional filtering criteria for metadata.
before: List checkpoints created before this configuration.
limit: Maximum number of checkpoints to return.
Returns:
Async iterator of matching checkpoint tuples.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
yield
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Asynchronously store a checkpoint with its configuration and metadata.
Args:
config: Configuration for the checkpoint.
checkpoint: The checkpoint to store.
metadata: Additional metadata for the checkpoint.
new_versions: New channel versions as of this write.
Returns:
RunnableConfig: Updated configuration after storing the checkpoint.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
async def aput_writes(
self,
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
task_path: str = "",
) -> None:
"""Asynchronously store intermediate writes linked to a checkpoint.
Args:
config: Configuration of the related checkpoint.
writes: List of writes to store.
task_id: Identifier for the task creating the writes.
task_path: Path of the task creating the writes.
Raises:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError
async def adelete_thread(
self,
thread_id: str,
) -> None:
"""Delete all checkpoints and writes associated with a specific thread ID.
Args:
thread_id: The thread ID whose checkpoints should be deleted.
"""
raise NotImplementedError
async def adelete_for_runs(
self,
run_ids: Sequence[str],
) -> None:
"""Asynchronously delete all checkpoints and writes for the given run IDs.
Args:
run_ids: The run IDs whose checkpoints should be deleted.
"""
raise NotImplementedError
async def acopy_thread(
self,
source_thread_id: str,
target_thread_id: str,
) -> None:
"""Asynchronously copy all checkpoints and writes from one thread to another.
Args:
source_thread_id: The thread ID to copy from.
target_thread_id: The thread ID to copy to.
"""
raise NotImplementedError
async def aprune(
self,
thread_ids: Sequence[str],
*,
strategy: str = "keep_latest",
) -> None:
"""Asynchronously prune checkpoints for the given threads.
Args:
thread_ids: The thread IDs to prune.
strategy: The pruning strategy. `"keep_latest"` retains only the most
recent checkpoint per namespace. `"delete"` removes all checkpoints.
"""
raise NotImplementedError
def get_next_version(self, current: V | None, channel: None) -> V:
"""Generate the next version ID for a channel.
Default is to use integer versions, incrementing by `1`.
If you override, you can use `str`/`int`/`float` versions, as long as they are monotonically increasing.
Args:
current: The current version identifier (`int`, `float`, or `str`).
channel: Deprecated argument, kept for backwards compatibility.
Returns:
V: The next version identifier, which must be increasing.
"""
if isinstance(current, str):
raise NotImplementedError
elif current is None:
return 1
else:
return current + 1
def with_allowlist(
self, extra_allowlist: Collection[tuple[str, ...]]
) -> BaseCheckpointSaver[V]:
"""Return a shallow clone with a derived msgpack allowlist."""
serde = _with_msgpack_allowlist(self.serde, extra_allowlist)
if serde is self.serde:
return self
clone = copy.copy(self)
clone.serde = maybe_add_typed_methods(serde)
return clone
def _with_msgpack_allowlist(
serde: SerializerProtocol, extra_allowlist: Collection[tuple[str, ...]]
) -> SerializerProtocol:
if isinstance(serde, JsonPlusSerializer):
return serde.with_msgpack_allowlist(extra_allowlist)
if isinstance(serde, EncryptedSerializer):
inner = serde.serde
if isinstance(inner, JsonPlusSerializer):
updated_inner = inner.with_msgpack_allowlist(extra_allowlist)
if updated_inner is inner:
return serde
return EncryptedSerializer(serde.cipher, updated_inner)
logger.warning(
"Serializer %s does not support msgpack allowlist. "
"Strict msgpack deserialization will not be enforced.",
type(serde).__name__,
)
return serde
class EmptyChannelError(Exception):
"""Raised when attempting to get the value of a channel that hasn't been updated
for the first time yet."""
pass
def get_checkpoint_id(config: RunnableConfig) -> str | None:
"""Get checkpoint ID."""
return config["configurable"].get("checkpoint_id")
def get_checkpoint_metadata(
config: RunnableConfig, metadata: CheckpointMetadata
) -> CheckpointMetadata:
"""Get checkpoint metadata in a backwards-compatible manner."""
metadata = {
k: v.replace("\u0000", "") if isinstance(v, str) else v
for k, v in metadata.items()
}
for obj in (config.get("metadata"), config.get("configurable")):
if not obj:
continue
for key, v in obj.items():
if key in metadata or key in EXCLUDED_METADATA_KEYS or key.startswith("__"):
continue
elif isinstance(v, str):
metadata[key] = v.replace("\u0000", "")
elif isinstance(v, (int, bool, float)):
metadata[key] = v
return metadata
def get_serializable_checkpoint_metadata(
config: RunnableConfig, metadata: CheckpointMetadata
) -> CheckpointMetadata:
"""Get checkpoint metadata in a backwards-compatible manner."""
checkpoint_metadata = get_checkpoint_metadata(config, metadata)
if "writes" in checkpoint_metadata:
checkpoint_metadata.pop("writes")
return checkpoint_metadata
"""
Mapping from error type to error index.
Regular writes just map to their index in the list of writes being saved.
Special writes (e.g. errors) map to negative indices, to avoid those writes from
conflicting with regular writes.
Each Checkpointer implementation should use this mapping in put_writes.
"""
WRITES_IDX_MAP = {ERROR: -1, SCHEDULED: -2, INTERRUPT: -3, RESUME: -4}
EXCLUDED_METADATA_KEYS = {
"thread_id",
"checkpoint_id",
"checkpoint_ns",
"checkpoint_map",
"langgraph_step",
"langgraph_node",
"langgraph_triggers",
"langgraph_path",
"langgraph_checkpoint_ns",
}
# --- below are deprecated utilities used by past versions of LangGraph ---
LATEST_VERSION = 2
def empty_checkpoint() -> Checkpoint:
from datetime import datetime, timezone
return Checkpoint(
v=LATEST_VERSION,
id=str(uuid6(clock_seq=-2)),
ts=datetime.now(timezone.utc).isoformat(),
channel_values={},
channel_versions={},
versions_seen={},
pending_sends=[],
updated_channels=None,
)
def create_checkpoint(
checkpoint: Checkpoint,
channels: Mapping[str, ChannelProtocol] | None,
step: int,
*,
id: str | None = None,
) -> Checkpoint:
"""Create a checkpoint for the given channels."""
from datetime import datetime, timezone
ts = datetime.now(timezone.utc).isoformat()
if channels is None:
values = checkpoint["channel_values"]
else:
values = {}
for k, v in channels.items():
if k not in checkpoint["channel_versions"]:
continue
try:
values[k] = v.checkpoint()
except EmptyChannelError:
pass
return Checkpoint(
v=LATEST_VERSION,
ts=ts,
id=id or str(uuid6(clock_seq=step)),
channel_values=values,
channel_versions=checkpoint["channel_versions"],
versions_seen=checkpoint["versions_seen"],
pending_sends=checkpoint.get("pending_sends", []),
updated_channels=None,
)

View File

@@ -0,0 +1,109 @@
"""Adapted from
https://github.com/oittaa/uuid6-python/blob/main/src/uuid6/__init__.py#L95
Bundled in to avoid install issues with uuid6 package
"""
from __future__ import annotations
import random
import time
import uuid
_last_v6_timestamp = None
class UUID(uuid.UUID):
r"""UUID draft version objects"""
__slots__ = ()
def __init__(
self,
hex: str | None = None,
bytes: bytes | None = None,
bytes_le: bytes | None = None,
fields: tuple[int, int, int, int, int, int] | None = None,
int: int | None = None,
version: int | None = None,
*,
is_safe: uuid.SafeUUID = uuid.SafeUUID.unknown,
) -> None:
r"""Create a UUID."""
if int is None or [hex, bytes, bytes_le, fields].count(None) != 4:
return super().__init__(
hex=hex,
bytes=bytes,
bytes_le=bytes_le,
fields=fields,
int=int,
version=version,
is_safe=is_safe,
)
if not 0 <= int < 1 << 128:
raise ValueError("int is out of range (need a 128-bit value)")
if version is not None:
if not 6 <= version <= 8:
raise ValueError("illegal version number")
# Set the variant to RFC 4122.
int &= ~(0xC000 << 48)
int |= 0x8000 << 48
# Set the version number.
int &= ~(0xF000 << 64)
int |= version << 76
super().__init__(int=int, is_safe=is_safe)
@property
def subsec(self) -> int:
return ((self.int >> 64) & 0x0FFF) << 8 | ((self.int >> 54) & 0xFF)
@property
def time(self) -> int:
if self.version == 6:
return (
(self.time_low << 28)
| (self.time_mid << 12)
| (self.time_hi_version & 0x0FFF)
)
if self.version == 7:
return self.int >> 80
if self.version == 8:
return (self.int >> 80) * 10**6 + _subsec_decode(self.subsec)
return super().time
def _subsec_decode(value: int) -> int:
return -(-value * 10**6 // 2**20)
def uuid6(node: int | None = None, clock_seq: int | None = None) -> UUID:
r"""UUID version 6 is a field-compatible version of UUIDv1, reordered for
improved DB locality. It is expected that UUIDv6 will primarily be
used in contexts where there are existing v1 UUIDs. Systems that do
not involve legacy UUIDv1 SHOULD consider using UUIDv7 instead.
If 'node' is not given, a random 48-bit number is chosen.
If 'clock_seq' is given, it is used as the sequence number;
otherwise a random 14-bit sequence number is chosen."""
global _last_v6_timestamp
nanoseconds = time.time_ns()
# 0x01b21dd213814000 is the number of 100-ns intervals between the
# UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00.
timestamp = nanoseconds // 100 + 0x01B21DD213814000
if _last_v6_timestamp is not None and timestamp <= _last_v6_timestamp:
timestamp = _last_v6_timestamp + 1
_last_v6_timestamp = timestamp
if clock_seq is None:
clock_seq = random.getrandbits(14) # instead of stable storage
if node is None:
node = random.getrandbits(48)
time_high_and_time_mid = (timestamp >> 12) & 0xFFFFFFFFFFFF
time_low_and_version = timestamp & 0x0FFF
uuid_int = time_high_and_time_mid << 80
uuid_int |= time_low_and_version << 64
uuid_int |= (clock_seq & 0x3FFF) << 48
uuid_int |= node & 0xFFFFFFFFFFFF
return UUID(int=uuid_int, version=6)

View File

@@ -0,0 +1,603 @@
from __future__ import annotations
import logging
import os
import pickle
import random
import shutil
from collections import defaultdict
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import AbstractAsyncContextManager, AbstractContextManager, ExitStack
from types import TracebackType
from typing import Any
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
BaseCheckpointSaver,
ChannelVersions,
Checkpoint,
CheckpointMetadata,
CheckpointTuple,
SerializerProtocol,
get_checkpoint_id,
get_checkpoint_metadata,
)
logger = logging.getLogger(__name__)
class InMemorySaver(
BaseCheckpointSaver[str], AbstractContextManager, AbstractAsyncContextManager
):
"""An in-memory checkpoint saver.
This checkpoint saver stores checkpoints in memory using a `defaultdict`.
Note:
Only use `InMemorySaver` for debugging or testing purposes.
For production use cases we recommend installing [langgraph-checkpoint-postgres](https://pypi.org/project/langgraph-checkpoint-postgres/) and using `PostgresSaver` / `AsyncPostgresSaver`.
If you are using LangSmith Deployment, no checkpointer needs to be specified. The correct managed checkpointer will be used automatically.
Args:
serde: The serializer to use for serializing and deserializing checkpoints.
Example:
```python
import asyncio
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import StateGraph
builder = StateGraph(int)
builder.add_node("add_one", lambda x: x + 1)
builder.set_entry_point("add_one")
builder.set_finish_point("add_one")
memory = InMemorySaver()
graph = builder.compile(checkpointer=memory)
coro = graph.ainvoke(1, {"configurable": {"thread_id": "thread-1"}})
asyncio.run(coro) # Output: 2
```
"""
# thread ID -> checkpoint NS -> checkpoint ID -> checkpoint mapping
storage: defaultdict[
str,
dict[str, dict[str, tuple[tuple[str, bytes], tuple[str, bytes], str | None]]],
]
# (thread ID, checkpoint NS, checkpoint ID) -> (task ID, write idx)
writes: defaultdict[
tuple[str, str, str],
dict[tuple[str, int], tuple[str, str, tuple[str, bytes], str]],
]
blobs: dict[
tuple[
str, str, str, str | int | float
], # thread id, checkpoint ns, channel, version
tuple[str, bytes],
]
def __init__(
self,
*,
serde: SerializerProtocol | None = None,
factory: type[defaultdict] = defaultdict,
) -> None:
super().__init__(serde=serde)
self.storage = factory(lambda: defaultdict(dict))
self.writes = factory(dict)
self.blobs = factory()
self.stack = ExitStack()
if factory is not defaultdict:
self.stack.enter_context(self.storage) # type: ignore[arg-type]
self.stack.enter_context(self.writes) # type: ignore[arg-type]
self.stack.enter_context(self.blobs) # type: ignore[arg-type]
def __enter__(self) -> InMemorySaver:
self.stack.__enter__()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> bool | None:
return self.stack.__exit__(exc_type, exc_value, traceback)
async def __aenter__(self) -> InMemorySaver:
self.stack.__enter__()
return self
async def __aexit__(
self,
__exc_type: type[BaseException] | None,
__exc_value: BaseException | None,
__traceback: TracebackType | None,
) -> bool | None:
return self.stack.__exit__(__exc_type, __exc_value, __traceback)
def _load_blobs(
self, thread_id: str, checkpoint_ns: str, versions: ChannelVersions
) -> dict[str, Any]:
channel_values: dict[str, Any] = {}
for k, v in versions.items():
kk = (thread_id, checkpoint_ns, k, v)
if kk in self.blobs:
vv = self.blobs[kk]
if vv[0] != "empty":
channel_values[k] = self.serde.loads_typed(vv)
return channel_values
def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
"""Get a checkpoint tuple from the in-memory storage.
This method retrieves a checkpoint tuple from the in-memory storage based on the
provided config. If the config contains a `checkpoint_id` key, the checkpoint with
the matching thread ID and timestamp is retrieved. Otherwise, the latest checkpoint
for the given thread ID is retrieved.
Args:
config: The config to use for retrieving the checkpoint.
Returns:
The retrieved checkpoint tuple, or None if no matching checkpoint was found.
"""
thread_id: str = config["configurable"]["thread_id"]
checkpoint_ns: str = config["configurable"].get("checkpoint_ns", "")
if checkpoint_id := get_checkpoint_id(config):
if saved := self.storage[thread_id][checkpoint_ns].get(checkpoint_id):
checkpoint, metadata, parent_checkpoint_id = saved
writes = self.writes[(thread_id, checkpoint_ns, checkpoint_id)].values()
checkpoint_: Checkpoint = self.serde.loads_typed(checkpoint)
return CheckpointTuple(
config=config,
checkpoint={
**checkpoint_,
"channel_values": self._load_blobs(
thread_id, checkpoint_ns, checkpoint_["channel_versions"]
),
},
metadata=self.serde.loads_typed(metadata),
pending_writes=[
(id, c, self.serde.loads_typed(v)) for id, c, v, _ in writes
],
parent_config=(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
if parent_checkpoint_id
else None
),
)
else:
if checkpoints := self.storage[thread_id][checkpoint_ns]:
checkpoint_id = max(checkpoints.keys())
checkpoint, metadata, parent_checkpoint_id = checkpoints[checkpoint_id]
writes = self.writes[(thread_id, checkpoint_ns, checkpoint_id)].values()
checkpoint_ = self.serde.loads_typed(checkpoint)
return CheckpointTuple(
config={
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
}
},
checkpoint={
**checkpoint_,
"channel_values": self._load_blobs(
thread_id, checkpoint_ns, checkpoint_["channel_versions"]
),
},
metadata=self.serde.loads_typed(metadata),
pending_writes=[
(id, c, self.serde.loads_typed(v)) for id, c, v, _ in writes
],
parent_config=(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
if parent_checkpoint_id
else None
),
)
def list(
self,
config: RunnableConfig | None,
*,
filter: dict[str, Any] | None = None,
before: RunnableConfig | None = None,
limit: int | None = None,
) -> Iterator[CheckpointTuple]:
"""List checkpoints from the in-memory storage.
This method retrieves a list of checkpoint tuples from the in-memory storage based
on the provided criteria.
Args:
config: Base configuration for filtering checkpoints.
filter: Additional filtering criteria for metadata.
before: List checkpoints created before this configuration.
limit: Maximum number of checkpoints to return.
Yields:
An iterator of matching checkpoint tuples.
"""
thread_ids = (config["configurable"]["thread_id"],) if config else self.storage
config_checkpoint_ns = (
config["configurable"].get("checkpoint_ns") if config else None
)
config_checkpoint_id = get_checkpoint_id(config) if config else None
for thread_id in thread_ids:
for checkpoint_ns in self.storage[thread_id].keys():
if (
config_checkpoint_ns is not None
and checkpoint_ns != config_checkpoint_ns
):
continue
for checkpoint_id, (
checkpoint,
metadata_b,
parent_checkpoint_id,
) in sorted(
self.storage[thread_id][checkpoint_ns].items(),
key=lambda x: x[0],
reverse=True,
):
# filter by checkpoint ID from config
if config_checkpoint_id and checkpoint_id != config_checkpoint_id:
continue
# filter by checkpoint ID from `before` config
if (
before
and (before_checkpoint_id := get_checkpoint_id(before))
and checkpoint_id >= before_checkpoint_id
):
continue
# filter by metadata
metadata = self.serde.loads_typed(metadata_b)
if filter and not all(
query_value == metadata.get(query_key)
for query_key, query_value in filter.items()
):
continue
# limit search results
if limit is not None and limit <= 0:
break
elif limit is not None:
limit -= 1
writes = self.writes[
(thread_id, checkpoint_ns, checkpoint_id)
].values()
checkpoint_: Checkpoint = self.serde.loads_typed(checkpoint)
yield CheckpointTuple(
config={
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint_id,
}
},
checkpoint={
**checkpoint_,
"channel_values": self._load_blobs(
thread_id,
checkpoint_ns,
checkpoint_["channel_versions"],
),
},
metadata=metadata,
parent_config=(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": parent_checkpoint_id,
}
}
if parent_checkpoint_id
else None
),
pending_writes=[
(id, c, self.serde.loads_typed(v)) for id, c, v, _ in writes
],
)
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Save a checkpoint to the in-memory storage.
This method saves a checkpoint to the in-memory storage. The checkpoint is associated
with the provided config.
Args:
config: The config to associate with the checkpoint.
checkpoint: The checkpoint to save.
metadata: Additional metadata to save with the checkpoint.
new_versions: New versions as of this write
Returns:
RunnableConfig: The updated config containing the saved checkpoint's timestamp.
"""
c = checkpoint.copy()
thread_id = config["configurable"]["thread_id"]
checkpoint_ns = config["configurable"]["checkpoint_ns"]
values: dict[str, Any] = c.pop("channel_values") # type: ignore[misc]
for k, v in new_versions.items():
self.blobs[(thread_id, checkpoint_ns, k, v)] = (
self.serde.dumps_typed(values[k]) if k in values else ("empty", b"")
)
self.storage[thread_id][checkpoint_ns].update(
{
checkpoint["id"]: (
self.serde.dumps_typed(c),
self.serde.dumps_typed(get_checkpoint_metadata(config, metadata)),
config["configurable"].get("checkpoint_id"), # parent
)
}
)
return {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
}
}
def put_writes(
self,
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
task_path: str = "",
) -> None:
"""Save a list of writes to the in-memory storage.
This method saves a list of writes to the in-memory storage. The writes are associated
with the provided config.
Args:
config: The config to associate with the writes.
writes: The writes to save.
task_id: Identifier for the task creating the writes.
task_path: Path of the task creating the writes.
Returns:
RunnableConfig: The updated config containing the saved writes' timestamp.
"""
thread_id = config["configurable"]["thread_id"]
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
checkpoint_id = config["configurable"]["checkpoint_id"]
outer_key = (thread_id, checkpoint_ns, checkpoint_id)
outer_writes_ = self.writes.get(outer_key)
for idx, (c, v) in enumerate(writes):
inner_key = (task_id, WRITES_IDX_MAP.get(c, idx))
if inner_key[1] >= 0 and outer_writes_ and inner_key in outer_writes_:
continue
self.writes[outer_key][inner_key] = (
task_id,
c,
self.serde.dumps_typed(v),
task_path,
)
def delete_thread(self, thread_id: str) -> None:
"""Delete all checkpoints and writes associated with a thread ID.
Args:
thread_id: The thread ID to delete.
Returns:
None
"""
if thread_id in self.storage:
del self.storage[thread_id]
for k in list(self.writes.keys()):
if k[0] == thread_id:
del self.writes[k]
for k in list(self.blobs.keys()):
if k[0] == thread_id:
del self.blobs[k]
async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
"""Asynchronous version of `get_tuple`.
This method is an asynchronous wrapper around `get_tuple` that runs the synchronous
method in a separate thread using asyncio.
Args:
config: The config to use for retrieving the checkpoint.
Returns:
The retrieved checkpoint tuple, or None if no matching checkpoint was found.
"""
return self.get_tuple(config)
async def alist(
self,
config: RunnableConfig | None,
*,
filter: dict[str, Any] | None = None,
before: RunnableConfig | None = None,
limit: int | None = None,
) -> AsyncIterator[CheckpointTuple]:
"""Asynchronous version of `list`.
This method is an asynchronous wrapper around `list` that runs the synchronous
method in a separate thread using asyncio.
Args:
config: The config to use for listing the checkpoints.
Yields:
An asynchronous iterator of checkpoint tuples.
"""
for item in self.list(config, filter=filter, before=before, limit=limit):
yield item
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Asynchronous version of `put`.
Args:
config: The config to associate with the checkpoint.
checkpoint: The checkpoint to save.
metadata: Additional metadata to save with the checkpoint.
new_versions: New versions as of this write
Returns:
RunnableConfig: The updated config containing the saved checkpoint's timestamp.
"""
return self.put(config, checkpoint, metadata, new_versions)
async def aput_writes(
self,
config: RunnableConfig,
writes: Sequence[tuple[str, Any]],
task_id: str,
task_path: str = "",
) -> None:
"""Asynchronous version of `put_writes`.
This method is an asynchronous wrapper around `put_writes` that runs the synchronous
method in a separate thread using asyncio.
Args:
config: The config to associate with the writes.
writes: The writes to save, each as a (channel, value) pair.
task_id: Identifier for the task creating the writes.
task_path: Path of the task creating the writes.
Returns:
None
"""
return self.put_writes(config, writes, task_id, task_path)
async def adelete_thread(self, thread_id: str) -> None:
"""Delete all checkpoints and writes associated with a thread ID.
Args:
thread_id: The thread ID to delete.
Returns:
None
"""
return self.delete_thread(thread_id)
def get_next_version(self, current: str | None, channel: None) -> str:
if current is None:
current_v = 0
elif isinstance(current, int):
current_v = current
else:
current_v = int(current.split(".")[0])
next_v = current_v + 1
next_h = random.random()
return f"{next_v:032}.{next_h:016}"
MemorySaver = InMemorySaver # Kept for backwards compatibility
class PersistentDict(defaultdict):
"""Persistent dictionary with an API compatible with shelve and anydbm.
The dict is kept in memory, so the dictionary operations run as fast as
a regular dictionary.
Write to disk is delayed until close or sync (similar to gdbm's fast mode).
Input file format is automatically discovered.
Output file format is selectable between pickle, json, and csv.
All three serialization formats are backed by fast C implementations.
Adapted from https://code.activestate.com/recipes/576642-persistent-dict-with-multiple-standard-file-format/
"""
def __init__(self, *args: Any, filename: str, **kwds: Any) -> None:
self.flag = "c" # r=readonly, c=create, or n=new
self.mode = None # None or an octal triple like 0644
self.format = "pickle" # 'csv', 'json', or 'pickle'
self.filename = filename
super().__init__(*args, **kwds)
def sync(self) -> None:
"Write dict to disk"
if self.flag == "r":
return
tempname = self.filename + ".tmp"
fileobj = open(tempname, "wb" if self.format == "pickle" else "w")
try:
self.dump(fileobj)
except Exception:
os.remove(tempname)
raise
finally:
fileobj.close()
shutil.move(tempname, self.filename) # atomic commit
if self.mode is not None:
os.chmod(self.filename, self.mode)
def close(self) -> None:
self.sync()
self.clear()
def __enter__(self) -> PersistentDict:
return self
def __exit__(self, *exc_info: Any) -> None:
self.close()
def dump(self, fileobj: Any) -> None:
if self.format == "pickle":
pickle.dump(dict(self), fileobj, 2)
else:
raise NotImplementedError("Unknown format: " + repr(self.format))
def load(self) -> None:
# try formats from most restrictive to least restrictive
if self.flag == "n":
return
with open(self.filename, "rb" if self.format == "pickle" else "r") as fileobj:
for loader in (pickle.load,):
fileobj.seek(0)
try:
return self.update(loader(fileobj))
except EOFError:
return
except Exception:
logger.error(f"Failed to load file: {fileobj.name}")
raise
raise ValueError("File not in a supported format")

View File

@@ -0,0 +1,89 @@
import os
from collections.abc import Iterable
from typing import cast
STRICT_MSGPACK_ENABLED = os.getenv("LANGGRAPH_STRICT_MSGPACK", "false").lower() in (
"1",
"true",
"yes",
)
_SENTINEL = cast(None, object())
SAFE_MSGPACK_TYPES: frozenset[tuple[str, ...]] = frozenset(
{
# datetime types
("datetime", "datetime"),
("datetime", "date"),
("datetime", "time"),
("datetime", "timedelta"),
("datetime", "timezone"),
# uuid
("uuid", "UUID"),
# numeric
("decimal", "Decimal"),
# collections
("builtins", "set"),
("builtins", "frozenset"),
("collections", "deque"),
# ip addresses
("ipaddress", "IPv4Address"),
("ipaddress", "IPv4Interface"),
("ipaddress", "IPv4Network"),
("ipaddress", "IPv6Address"),
("ipaddress", "IPv6Interface"),
("ipaddress", "IPv6Network"),
# pathlib
("pathlib", "Path"),
("pathlib", "PosixPath"),
("pathlib", "WindowsPath"),
# pathlib in Python 3.13+
("pathlib._local", "Path"),
("pathlib._local", "PosixPath"),
("pathlib._local", "WindowsPath"),
# zoneinfo
("zoneinfo", "ZoneInfo"),
# regex
("re", "compile"),
# langchain-core messages (safe container types used by graph state)
("langchain_core.messages.base", "BaseMessage"),
("langchain_core.messages.base", "BaseMessageChunk"),
("langchain_core.messages.human", "HumanMessage"),
("langchain_core.messages.human", "HumanMessageChunk"),
("langchain_core.messages.ai", "AIMessage"),
("langchain_core.messages.ai", "AIMessageChunk"),
("langchain_core.messages.system", "SystemMessage"),
("langchain_core.messages.system", "SystemMessageChunk"),
("langchain_core.messages.chat", "ChatMessage"),
("langchain_core.messages.chat", "ChatMessageChunk"),
("langchain_core.messages.tool", "ToolMessage"),
("langchain_core.messages.tool", "ToolMessageChunk"),
("langchain_core.messages.function", "FunctionMessage"),
("langchain_core.messages.function", "FunctionMessageChunk"),
("langchain_core.messages.modifier", "RemoveMessage"),
# langchain-core document model
("langchain_core.documents.base", "Document"),
# langgraph
("langgraph.types", "Send"),
("langgraph.types", "Interrupt"),
("langgraph.types", "Command"),
("langgraph.types", "StateSnapshot"),
("langgraph.types", "PregelTask"),
("langgraph.types", "Overwrite"),
("langgraph.store.base", "Item"),
("langgraph.store.base", "GetOp"),
}
)
# Allowed (module, name, method) triples for EXT_METHOD_SINGLE_ARG.
# Only these specific method invocations are permitted during deserialization.
# This is separate from SAFE_MSGPACK_TYPES which only governs construction.
SAFE_MSGPACK_METHODS: frozenset[tuple[str, str, str]] = frozenset(
{
("datetime", "datetime", "fromisoformat"),
}
)
AllowedMsgpackModules = Iterable[tuple[str, ...] | type]

View File

@@ -0,0 +1,64 @@
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
class UntypedSerializerProtocol(Protocol):
"""Protocol for serialization and deserialization of objects."""
def dumps(self, obj: Any) -> bytes: ...
def loads(self, data: bytes) -> Any: ...
@runtime_checkable
class SerializerProtocol(Protocol):
"""Protocol for serialization and deserialization of objects.
- `dumps_typed`: Serialize an object to a tuple `(type, bytes)`.
- `loads_typed`: Deserialize an object from a tuple `(type, bytes)`.
Valid implementations include the `pickle`, `json` and `orjson` modules.
"""
def dumps_typed(self, obj: Any) -> tuple[str, bytes]: ...
def loads_typed(self, data: tuple[str, bytes]) -> Any: ...
class SerializerCompat(SerializerProtocol):
def __init__(self, serde: UntypedSerializerProtocol) -> None:
self.serde = serde
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
return type(obj).__name__, self.serde.dumps(obj)
def loads_typed(self, data: tuple[str, bytes]) -> Any:
return self.serde.loads(data[1])
def maybe_add_typed_methods(
serde: SerializerProtocol | UntypedSerializerProtocol,
) -> SerializerProtocol:
"""Wrap serde old serde implementations in a class with loads_typed and dumps_typed for backwards compatibility."""
if not isinstance(serde, SerializerProtocol):
return SerializerCompat(serde)
return serde
class CipherProtocol(Protocol):
"""Protocol for encryption and decryption of data.
- `encrypt`: Encrypt plaintext.
- `decrypt`: Decrypt ciphertext.
"""
def encrypt(self, plaintext: bytes) -> tuple[str, bytes]:
"""Encrypt plaintext. Returns a tuple `(cipher name, ciphertext)`."""
...
def decrypt(self, ciphername: str, ciphertext: bytes) -> bytes:
"""Decrypt ciphertext. Returns the plaintext."""
...

View File

@@ -0,0 +1,80 @@
import os
from typing import Any
from langgraph.checkpoint.serde.base import CipherProtocol, SerializerProtocol
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
class EncryptedSerializer(SerializerProtocol):
"""Serializer that encrypts and decrypts data using an encryption protocol."""
def __init__(
self, cipher: CipherProtocol, serde: SerializerProtocol = JsonPlusSerializer()
) -> None:
self.cipher = cipher
self.serde = serde
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
"""Serialize an object to a tuple `(type, bytes)` and encrypt the bytes."""
# serialize data
typ, data = self.serde.dumps_typed(obj)
# encrypt data
ciphername, ciphertext = self.cipher.encrypt(data)
# add cipher name to type
return f"{typ}+{ciphername}", ciphertext
def loads_typed(self, data: tuple[str, bytes]) -> Any:
enc_cipher, ciphertext = data
# unencrypted data
if "+" not in enc_cipher:
return self.serde.loads_typed(data)
# extract cipher name
typ, ciphername = enc_cipher.split("+", 1)
# decrypt data
decrypted_data = self.cipher.decrypt(ciphername, ciphertext)
# deserialize data
return self.serde.loads_typed((typ, decrypted_data))
@classmethod
def from_pycryptodome_aes(
cls, serde: SerializerProtocol = JsonPlusSerializer(), **kwargs: Any
) -> "EncryptedSerializer":
"""Create an `EncryptedSerializer` using AES encryption."""
try:
from Crypto.Cipher import AES
except ImportError:
raise ImportError(
"Pycryptodome is not installed. Please install it with `pip install pycryptodome`."
) from None
# check if AES key is provided
if "key" in kwargs:
key: bytes = kwargs.pop("key")
else:
key_str = os.getenv("LANGGRAPH_AES_KEY")
if key_str is None:
raise ValueError("LANGGRAPH_AES_KEY environment variable is not set.")
key = key_str.encode()
if len(key) not in (16, 24, 32):
raise ValueError("LANGGRAPH_AES_KEY must be 16, 24, or 32 bytes long.")
# set default mode to EAX if not provided
if kwargs.get("mode") is None:
kwargs["mode"] = AES.MODE_EAX
class PycryptodomeAesCipher(CipherProtocol):
def encrypt(self, plaintext: bytes) -> tuple[str, bytes]:
cipher = AES.new(key, **kwargs)
ciphertext, tag = cipher.encrypt_and_digest(plaintext)
return "aes", cipher.nonce + tag + ciphertext
def decrypt(self, ciphername: str, ciphertext: bytes) -> bytes:
assert ciphername == "aes", f"Unsupported cipher: {ciphername}"
nonce = ciphertext[:16]
tag = ciphertext[16:32]
actual_ciphertext = ciphertext[32:]
cipher = AES.new(key, **kwargs, nonce=nonce)
return cipher.decrypt_and_verify(actual_ciphertext, tag)
return cls(PycryptodomeAesCipher(), serde)

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
import logging
from collections.abc import Callable
from threading import Lock
from typing import TypedDict
from typing_extensions import NotRequired
logger = logging.getLogger(__name__)
class SerdeEvent(TypedDict):
kind: str
module: str
name: str
method: NotRequired[str]
SerdeEventListener = Callable[[SerdeEvent], None]
_listeners: list[SerdeEventListener] = []
_listeners_lock = Lock()
def register_serde_event_listener(listener: SerdeEventListener) -> Callable[[], None]:
"""Register a listener for serde allowlist events."""
with _listeners_lock:
_listeners.append(listener)
def unregister() -> None:
with _listeners_lock:
try:
_listeners.remove(listener)
except ValueError:
pass
return unregister
def emit_serde_event(event: SerdeEvent) -> None:
"""Emit a serde event to all listeners.
Listener failures are isolated and logged.
"""
with _listeners_lock:
listeners = tuple(_listeners)
for listener in listeners:
try:
listener(event)
except Exception:
logger.warning("Serde listener failed", exc_info=True)

View File

@@ -0,0 +1,827 @@
from __future__ import annotations
import copy
import dataclasses
import decimal
import importlib
import json
import logging
import pathlib
import pickle
import re
import sys
from collections import deque
from collections.abc import Callable, Iterable, Sequence
from datetime import date, datetime, time, timedelta, timezone
from enum import Enum
from inspect import isclass
from ipaddress import (
IPv4Address,
IPv4Interface,
IPv4Network,
IPv6Address,
IPv6Interface,
IPv6Network,
)
from typing import TYPE_CHECKING, Any, Literal, cast
from uuid import UUID
from zoneinfo import ZoneInfo
import ormsgpack
from langchain_core.load.load import Reviver
from langgraph.checkpoint.serde import _msgpack as _lg_msgpack
from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.event_hooks import emit_serde_event
from langgraph.checkpoint.serde.types import SendProtocol
from langgraph.store.base import Item
if TYPE_CHECKING:
from langgraph.checkpoint.serde._msgpack import (
AllowedMsgpackModules,
)
from langgraph.checkpoint.serde.types import SendProtocol
LC_REVIVER = Reviver()
EMPTY_BYTES = b""
logger = logging.getLogger(__name__)
class JsonPlusSerializer(SerializerProtocol):
"""Serializer that uses ormsgpack, with optional fallbacks.
!!! warning
Security note: This serializer is intended for use within the `BaseCheckpointSaver`
class and called within the Pregel loop. It should not be used on untrusted
python objects. If an attacker can write directly to your checkpoint database,
they may be able to trigger code execution when data is deserialized.
"""
def __init__(
self,
*,
pickle_fallback: bool = False,
allowed_json_modules: Iterable[tuple[str, ...]] | Literal[True] | None = None,
allowed_msgpack_modules: (
AllowedMsgpackModules | Literal[True] | None
) = _lg_msgpack._SENTINEL,
__unpack_ext_hook__: Callable[[int, bytes], Any] | None = None,
) -> None:
if allowed_msgpack_modules is _lg_msgpack._SENTINEL:
if _lg_msgpack.STRICT_MSGPACK_ENABLED:
allowed_msgpack_modules = None
else:
allowed_msgpack_modules = True
self.pickle_fallback = pickle_fallback
self._allowed_json_modules: set[tuple[str, ...]] | Literal[True] | None = (
_normalize_allowlist(allowed_json_modules)
)
self._allowed_msgpack_modules = _normalize_allowlist(allowed_msgpack_modules)
self._custom_unpack_ext_hook = __unpack_ext_hook__ is not None
self._unpack_ext_hook = (
__unpack_ext_hook__
if __unpack_ext_hook__ is not None
else _create_msgpack_ext_hook(self._allowed_msgpack_modules)
)
def with_msgpack_allowlist(
self, extra_allowlist: Iterable[tuple[str, ...] | type]
) -> JsonPlusSerializer:
"""Return a new serializer with a merged msgpack allowlist."""
base_allowlist = self._allowed_msgpack_modules
if base_allowlist is True or base_allowlist is False:
return self
elif base_allowlist:
base_allowlist = set(base_allowlist)
else:
base_allowlist = set()
extra = _normalize_module_keys(tuple(extra_allowlist))
merged = base_allowlist | extra
if merged == base_allowlist:
return self
allowed_msgpack_modules: AllowedMsgpackModules | Literal[True] | None
if merged:
allowed_msgpack_modules = tuple(merged)
elif isinstance(self._allowed_msgpack_modules, set):
allowed_msgpack_modules = tuple(self._allowed_msgpack_modules)
else:
allowed_msgpack_modules = self._allowed_msgpack_modules
clone = copy.copy(self)
clone._allowed_json_modules = _normalize_allowlist(self._allowed_json_modules)
clone._allowed_msgpack_modules = _normalize_allowlist(allowed_msgpack_modules)
if not clone._custom_unpack_ext_hook:
clone._unpack_ext_hook = _create_msgpack_ext_hook(
clone._allowed_msgpack_modules
)
return clone
def _encode_constructor_args(
self,
constructor: Callable | type[Any],
*,
method: None | str | Sequence[None | str] = None,
args: Sequence[Any] | None = None,
kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
out = {
"lc": 2,
"type": "constructor",
"id": (*constructor.__module__.split("."), constructor.__name__),
}
if method is not None:
out["method"] = method
if args is not None:
out["args"] = args
if kwargs is not None:
out["kwargs"] = kwargs
return out
def _reviver(self, value: dict[str, Any]) -> Any:
if self._allowed_json_modules and (
value.get("lc", None) == 2
and value.get("type", None) == "constructor"
and value.get("id", None) is not None
):
try:
return self._revive_lc2(value)
except InvalidModuleError as e:
logger.warning(
"Object %s is not in the deserialization allowlist.\n%s",
value["id"],
e.message,
)
return LC_REVIVER(value)
def _revive_lc2(self, value: dict[str, Any]) -> Any:
self._check_allowed_json_modules(value)
[*module, name] = value["id"]
try:
mod = importlib.import_module(".".join(module))
cls = getattr(mod, name)
method = value.get("method")
if isinstance(method, str):
methods = [getattr(cls, method)]
elif isinstance(method, list):
methods = [cls if m is None else getattr(cls, m) for m in method]
else:
methods = [cls]
args = value.get("args")
kwargs = value.get("kwargs")
for method in methods:
try:
if isclass(method) and issubclass(method, BaseException):
return None
if args and kwargs:
return method(*args, **kwargs)
elif args:
return method(*args)
elif kwargs:
return method(**kwargs)
else:
return method()
except Exception:
continue
except Exception:
return None
def _check_allowed_json_modules(self, value: dict[str, Any]) -> None:
needed = tuple(value["id"])
method = value.get("method")
if isinstance(method, list):
method_display = ",".join(m or "<init>" for m in method)
elif isinstance(method, str):
method_display = method
else:
method_display = "<init>"
dotted = ".".join(needed)
if not self._allowed_json_modules:
raise InvalidModuleError(
f"Refused to deserialize JSON constructor: {dotted} (method: {method_display}). "
"No allowed_json_modules configured.\n\n"
"Unblock with ONE of:\n"
f" • JsonPlusSerializer(allowed_json_modules=[{needed!r}, ...])\n"
" • (DANGEROUS) JsonPlusSerializer(allowed_json_modules=True)\n\n"
"Note: Prefix allowlists are intentionally unsupported; prefer exact symbols "
"or plain-JSON representations revived without import-time side effects."
)
if self._allowed_json_modules is True:
return
if needed in self._allowed_json_modules:
return
raise InvalidModuleError(
f"Refused to deserialize JSON constructor: {dotted} (method: {method_display}). "
"Symbol is not in the deserialization allowlist.\n\n"
"Add exactly this symbol to unblock:\n"
f" JsonPlusSerializer(allowed_json_modules=[{needed!r}, ...])\n"
"Or, as a last resort (DANGEROUS):\n"
" JsonPlusSerializer(allowed_json_modules=True)"
)
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
if obj is None:
return "null", EMPTY_BYTES
elif isinstance(obj, bytes):
return "bytes", obj
elif isinstance(obj, bytearray):
return "bytearray", obj
else:
try:
return "msgpack", _msgpack_enc(obj)
except ormsgpack.MsgpackEncodeError as exc:
if self.pickle_fallback:
return "pickle", pickle.dumps(obj)
raise exc
def loads_typed(self, data: tuple[str, bytes]) -> Any:
type_, data_ = data
if type_ == "null":
return None
elif type_ == "bytes":
return data_
elif type_ == "bytearray":
return bytearray(data_)
elif type_ == "json":
return json.loads(data_, object_hook=self._reviver)
elif type_ == "msgpack":
return ormsgpack.unpackb(
data_, ext_hook=self._unpack_ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
elif self.pickle_fallback and type_ == "pickle":
return pickle.loads(data_)
else:
raise NotImplementedError(f"Unknown serialization type: {type_}")
# --- msgpack ---
EXT_CONSTRUCTOR_SINGLE_ARG = 0
EXT_CONSTRUCTOR_POS_ARGS = 1
EXT_CONSTRUCTOR_KW_ARGS = 2
EXT_METHOD_SINGLE_ARG = 3
EXT_PYDANTIC_V1 = 4
EXT_PYDANTIC_V2 = 5
EXT_NUMPY_ARRAY = 6
def _msgpack_default(obj: Any) -> str | ormsgpack.Ext:
if hasattr(obj, "model_dump") and callable(obj.model_dump): # pydantic v2
return ormsgpack.Ext(
EXT_PYDANTIC_V2,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.model_dump(),
"model_validate_json",
),
),
)
elif hasattr(obj, "get_secret_value") and callable(obj.get_secret_value):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.get_secret_value(),
),
),
)
elif hasattr(obj, "dict") and callable(obj.dict): # pydantic v1
return ormsgpack.Ext(
EXT_PYDANTIC_V1,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.dict(),
),
),
)
elif hasattr(obj, "_asdict") and callable(obj._asdict): # namedtuple
return ormsgpack.Ext(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj._asdict(),
),
),
)
elif isinstance(obj, pathlib.Path):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.parts),
),
)
elif isinstance(obj, re.Pattern):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
("re", "compile", (obj.pattern, obj.flags)),
),
)
elif isinstance(obj, UUID):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.hex),
),
)
elif isinstance(obj, decimal.Decimal):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
),
)
elif isinstance(obj, (set, frozenset, deque)):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, tuple(obj)),
),
)
elif isinstance(obj, (IPv4Address, IPv4Interface, IPv4Network)):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
),
)
elif isinstance(obj, (IPv6Address, IPv6Interface, IPv6Network)):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
),
)
elif isinstance(obj, datetime):
return ormsgpack.Ext(
EXT_METHOD_SINGLE_ARG,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.isoformat(),
"fromisoformat",
),
),
)
elif isinstance(obj, timedelta):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
(obj.days, obj.seconds, obj.microseconds),
),
),
)
elif isinstance(obj, date):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
(obj.year, obj.month, obj.day),
),
),
)
elif isinstance(obj, time):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
{
"hour": obj.hour,
"minute": obj.minute,
"second": obj.second,
"microsecond": obj.microsecond,
"tzinfo": obj.tzinfo,
"fold": obj.fold,
},
),
),
)
elif isinstance(obj, timezone):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.__getinitargs__(), # type: ignore[attr-defined]
),
),
)
elif isinstance(obj, ZoneInfo):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.key),
),
)
elif isinstance(obj, Enum):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.value),
),
)
elif isinstance(obj, SendProtocol):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, (obj.node, obj.arg)),
),
)
elif dataclasses.is_dataclass(obj):
# doesn't use dataclasses.asdict to avoid deepcopy and recursion
return ormsgpack.Ext(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
{
field.name: getattr(obj, field.name)
for field in dataclasses.fields(obj)
},
),
),
)
elif isinstance(obj, Item):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
{k: getattr(obj, k) for k in obj.__slots__},
),
),
)
elif (np_mod := sys.modules.get("numpy")) is not None and isinstance(
obj, np_mod.ndarray
):
order = "F" if obj.flags.f_contiguous and not obj.flags.c_contiguous else "C"
if obj.flags.c_contiguous:
mv = memoryview(obj)
try:
meta = (obj.dtype.str, obj.shape, order, mv)
return ormsgpack.Ext(EXT_NUMPY_ARRAY, _msgpack_enc(meta))
finally:
mv.release()
else:
buf = obj.tobytes(order="A")
meta = (obj.dtype.str, obj.shape, order, buf)
return ormsgpack.Ext(EXT_NUMPY_ARRAY, _msgpack_enc(meta))
elif isinstance(obj, BaseException):
return repr(obj)
else:
raise TypeError(f"Object of type {obj.__class__.__name__} is not serializable")
def _create_msgpack_ext_hook(
allowed_modules: set[tuple[str, ...]] | Literal[True] | None,
) -> Callable[[int, bytes], Any]:
"""Create msgpack ext hook with allowlist.
Args:
allowed_modules: Set of (module, name) tuples that are allowed to be
deserialized, or True to allow all with warnings for unregistered types, or None to only allow safe types.
Returns:
An ext_hook function for use with ormsgpack.unpackb.
"""
def _check_allowed(module: str, name: str) -> bool:
"""Check if type is allowed. Returns True if allowed, False if blocked."""
key = (module, name)
if key in _lg_msgpack.SAFE_MSGPACK_TYPES:
return True
if allowed_modules is True:
# default is to warn but allow unregistered types
emit_serde_event(
{
"kind": "msgpack_unregistered_allowed",
"module": module,
"name": name,
}
)
logger.warning(
"Deserializing unregistered type %s.%s from checkpoint. "
"This will be blocked in a future version. "
"Add to allowed_msgpack_modules to silence: [(%r, %r)]",
module,
name,
module,
name,
)
return True
if allowed_modules is not None:
if key in allowed_modules:
return True
# strict mode blocks unregistered types
emit_serde_event(
{
"kind": "msgpack_blocked",
"module": module,
"name": name,
}
)
logger.warning(
"Blocked deserialization of %s.%s - not in allowed_msgpack_modules. "
"Add to allowed_msgpack_modules to allow: [(%r, %r)]",
module,
name,
module,
name,
)
return False
def _check_allowed_method(module: str, name: str, method: str) -> bool:
"""Check if a method invocation is allowed."""
key = (module, name, method)
if key in _lg_msgpack.SAFE_MSGPACK_METHODS:
return True
emit_serde_event(
{
"kind": "msgpack_method_blocked",
"module": module,
"name": name,
"method": method,
}
)
logger.warning(
"Blocked deserialization of method call %s.%s.%s - "
"not in allowed methods set.",
module,
name,
method,
)
return False
def ext_hook(code: int, data: bytes) -> Any:
if code == EXT_CONSTRUCTOR_SINGLE_ARG:
try:
tup = ormsgpack.unpackb(
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
if not _check_allowed(tup[0], tup[1]):
# We default to returning the raw data. If the user
# is using this in the context of a pydantic state, etc., then
# it would be validated upon construction.
return tup[2]
# module, name, arg
return getattr(importlib.import_module(tup[0]), tup[1])(tup[2])
except Exception:
return None
elif code == EXT_CONSTRUCTOR_POS_ARGS:
try:
tup = ormsgpack.unpackb(
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
if not _check_allowed(tup[0], tup[1]):
return tup[2]
# module, name, args
return getattr(importlib.import_module(tup[0]), tup[1])(*tup[2])
except Exception:
return None
elif code == EXT_CONSTRUCTOR_KW_ARGS:
try:
tup = ormsgpack.unpackb(
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
if not _check_allowed(tup[0], tup[1]):
return tup[2]
# module, name, kwargs
return getattr(importlib.import_module(tup[0]), tup[1])(**tup[2])
except Exception:
return None
elif code == EXT_METHOD_SINGLE_ARG:
try:
tup = ormsgpack.unpackb(
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
if not _check_allowed_method(tup[0], tup[1], tup[3]):
return tup[2]
# module, name, arg, method
return getattr(
getattr(importlib.import_module(tup[0]), tup[1]), tup[3]
)(tup[2])
except Exception:
return None
elif code == EXT_PYDANTIC_V1:
try:
tup = ormsgpack.unpackb(
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
if not _check_allowed(tup[0], tup[1]):
return tup[2]
# module, name, kwargs
cls = getattr(importlib.import_module(tup[0]), tup[1])
try:
return cls(**tup[2])
except Exception:
return cls.construct(**tup[2])
except Exception:
# for pydantic objects we can't find/reconstruct
# let's return the kwargs dict instead
try:
return tup[2]
except NameError:
return None
elif code == EXT_PYDANTIC_V2:
try:
tup = ormsgpack.unpackb(
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
if not _check_allowed(tup[0], tup[1]):
return tup[2]
# module, name, kwargs, method
cls = getattr(importlib.import_module(tup[0]), tup[1])
try:
return cls(**tup[2])
except Exception:
return cls.model_construct(**tup[2])
except Exception:
# for pydantic objects we can't find/reconstruct
# let's return the kwargs dict instead
try:
return tup[2]
except NameError:
return None
elif code == EXT_NUMPY_ARRAY:
try:
import numpy as _np
dtype_str, shape, order, buf = ormsgpack.unpackb(
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
arr = _np.frombuffer(buf, dtype=_np.dtype(dtype_str))
return arr.reshape(shape, order=order)
except Exception:
return None
return None
return ext_hook
# Aliasing in case anyone imported it directly
_msgpack_ext_hook = _create_msgpack_ext_hook(allowed_modules=None)
def _msgpack_ext_hook_to_json(code: int, data: bytes) -> Any:
if code == EXT_CONSTRUCTOR_SINGLE_ARG:
try:
tup = ormsgpack.unpackb(
data,
ext_hook=_msgpack_ext_hook_to_json,
option=ormsgpack.OPT_NON_STR_KEYS,
)
if tup[0] == "uuid" and tup[1] == "UUID":
hex_ = tup[2]
return (
f"{hex_[:8]}-{hex_[8:12]}-{hex_[12:16]}-{hex_[16:20]}-{hex_[20:]}"
)
# module, name, arg
return tup[2]
except Exception:
return
elif code == EXT_CONSTRUCTOR_POS_ARGS:
try:
tup = ormsgpack.unpackb(
data,
ext_hook=_msgpack_ext_hook_to_json,
option=ormsgpack.OPT_NON_STR_KEYS,
)
if tup[0] == "langgraph.types" and tup[1] == "Send":
from langgraph.types import Send # type: ignore
return Send(*tup[2])
# module, name, args
return tup[2]
except Exception:
return
elif code == EXT_CONSTRUCTOR_KW_ARGS:
try:
tup = ormsgpack.unpackb(
data,
ext_hook=_msgpack_ext_hook_to_json,
option=ormsgpack.OPT_NON_STR_KEYS,
)
# module, name, args
return tup[2]
except Exception:
return
elif code == EXT_METHOD_SINGLE_ARG:
try:
tup = ormsgpack.unpackb(
data,
ext_hook=_msgpack_ext_hook_to_json,
option=ormsgpack.OPT_NON_STR_KEYS,
)
# module, name, arg, method
return tup[2]
except Exception:
return
elif code == EXT_PYDANTIC_V1:
try:
tup = ormsgpack.unpackb(
data,
ext_hook=_msgpack_ext_hook_to_json,
option=ormsgpack.OPT_NON_STR_KEYS,
)
# module, name, kwargs
return tup[2]
except Exception:
# for pydantic objects we can't find/reconstruct
# let's return the kwargs dict instead
return
elif code == EXT_PYDANTIC_V2:
try:
tup = ormsgpack.unpackb(
data,
ext_hook=_msgpack_ext_hook_to_json,
option=ormsgpack.OPT_NON_STR_KEYS,
)
# module, name, kwargs, method
return tup[2]
except Exception:
return
elif code == EXT_NUMPY_ARRAY:
try:
import numpy as _np
dtype_str, shape, order, buf = ormsgpack.unpackb(
data,
ext_hook=_msgpack_ext_hook_to_json,
option=ormsgpack.OPT_NON_STR_KEYS,
)
arr = _np.frombuffer(buf, dtype=_np.dtype(dtype_str))
return arr.reshape(shape, order=order).tolist()
except Exception:
return
class InvalidModuleError(Exception):
"""Exception raised when a module is not in the allowlist."""
def __init__(self, message: str):
self.message = message
_option = (
ormsgpack.OPT_NON_STR_KEYS
| ormsgpack.OPT_PASSTHROUGH_DATACLASS
| ormsgpack.OPT_PASSTHROUGH_DATETIME
| ormsgpack.OPT_PASSTHROUGH_ENUM
| ormsgpack.OPT_PASSTHROUGH_UUID
| ormsgpack.OPT_REPLACE_SURROGATES
)
def _msgpack_enc(data: Any) -> bytes:
return ormsgpack.packb(data, default=_msgpack_default, option=_option)
def _normalize_allowlist(
allowlist: AllowedMsgpackModules | Literal[True] | None,
) -> set[tuple[str, ...]] | Literal[True] | None:
if allowlist is True:
return allowlist
elif allowlist:
return _normalize_module_keys(allowlist)
else:
return None
def _normalize_module_keys(
modules: AllowedMsgpackModules,
) -> set[tuple[str, ...]]:
normalized: set[tuple[str, ...]] = set()
for module in modules:
if isclass(module):
normalized.add((module.__module__, module.__name__))
else:
normalized.add(cast(tuple[str, ...], module))
return normalized

View File

@@ -0,0 +1,51 @@
from collections.abc import Sequence
from typing import (
Any,
Protocol,
TypeVar,
runtime_checkable,
)
from typing_extensions import Self
ERROR = "__error__"
SCHEDULED = "__scheduled__"
INTERRUPT = "__interrupt__"
RESUME = "__resume__"
TASKS = "__pregel_tasks"
Value = TypeVar("Value", covariant=True)
Update = TypeVar("Update", contravariant=True)
C = TypeVar("C")
class ChannelProtocol(Protocol[Value, Update, C]):
# Mirrors langgraph.channels.base.BaseChannel
@property
def ValueType(self) -> Any: ...
@property
def UpdateType(self) -> Any: ...
def checkpoint(self) -> C | None: ...
def from_checkpoint(self, checkpoint: C | None) -> Self: ...
def update(self, values: Sequence[Update]) -> bool: ...
def get(self) -> Value: ...
def consume(self) -> bool: ...
@runtime_checkable
class SendProtocol(Protocol):
# Mirrors langgraph.constants.Send
node: str
arg: Any
def __hash__(self) -> int: ...
def __repr__(self) -> str: ...
def __eq__(self, value: object) -> bool: ...

View File

@@ -0,0 +1,196 @@
import asyncio
import sys
from typing import Any
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import var_child_runnable_config
from langgraph.store.base import BaseStore
from langgraph._internal._constants import CONF, CONFIG_KEY_RUNTIME
from langgraph.types import StreamWriter
def _no_op_stream_writer(c: Any) -> None:
pass
def get_config() -> RunnableConfig:
if sys.version_info < (3, 11):
try:
if asyncio.current_task():
raise RuntimeError(
"Python 3.11 or later required to use this in an async context"
)
except RuntimeError:
pass
if var_config := var_child_runnable_config.get():
return var_config
else:
raise RuntimeError("Called get_config outside of a runnable context")
def get_store() -> BaseStore:
"""Access LangGraph store from inside a graph node or entrypoint task at runtime.
Can be called from inside any [`StateGraph`][langgraph.graph.StateGraph] node or
functional API [`task`][langgraph.func.task], as long as the `StateGraph` or the [`entrypoint`][langgraph.func.entrypoint]
was initialized with a store, e.g.:
```python
# with StateGraph
graph = (
StateGraph(...)
...
.compile(store=store)
)
# or with entrypoint
@entrypoint(store=store)
def workflow(inputs):
...
```
!!! warning "Async with Python < 3.11"
If you are using Python < 3.11 and are running LangGraph asynchronously,
`get_store()` won't work since it uses [`contextvar`](https://docs.python.org/3/library/contextvars.html) propagation (only available in [Python >= 3.11](https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task)).
Example: Using with `StateGraph`
```python
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START
from langgraph.store.memory import InMemoryStore
from langgraph.config import get_store
store = InMemoryStore()
store.put(("values",), "foo", {"bar": 2})
class State(TypedDict):
foo: int
def my_node(state: State):
my_store = get_store()
stored_value = my_store.get(("values",), "foo").value["bar"]
return {"foo": stored_value + 1}
graph = (
StateGraph(State)
.add_node(my_node)
.add_edge(START, "my_node")
.compile(store=store)
)
graph.invoke({"foo": 1})
```
```pycon
{"foo": 3}
```
Example: Using with functional API
```python
from langgraph.func import entrypoint, task
from langgraph.store.memory import InMemoryStore
from langgraph.config import get_store
store = InMemoryStore()
store.put(("values",), "foo", {"bar": 2})
@task
def my_task(value: int):
my_store = get_store()
stored_value = my_store.get(("values",), "foo").value["bar"]
return stored_value + 1
@entrypoint(store=store)
def workflow(value: int):
return my_task(value).result()
workflow.invoke(1)
```
```pycon
3
```
"""
return get_config()[CONF][CONFIG_KEY_RUNTIME].store
def get_stream_writer() -> StreamWriter:
"""Access LangGraph [`StreamWriter`][langgraph.types.StreamWriter] from inside a graph node or entrypoint task at runtime.
Can be called from inside any [`StateGraph`][langgraph.graph.StateGraph] node or
functional API [`task`][langgraph.func.task].
!!! warning "Async with Python < 3.11"
If you are using Python < 3.11 and are running LangGraph asynchronously,
`get_stream_writer()` won't work since it uses [`contextvar`](https://docs.python.org/3/library/contextvars.html) propagation (only available in [Python >= 3.11](https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task)).
Example: Using with `StateGraph`
```python
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START
from langgraph.config import get_stream_writer
class State(TypedDict):
foo: int
def my_node(state: State):
my_stream_writer = get_stream_writer()
my_stream_writer({"custom_data": "Hello!"})
return {"foo": state["foo"] + 1}
graph = (
StateGraph(State)
.add_node(my_node)
.add_edge(START, "my_node")
.compile(store=store)
)
for chunk in graph.stream({"foo": 1}, stream_mode="custom"):
print(chunk)
```
```pycon
{"custom_data": "Hello!"}
```
Example: Using with functional API
```python
from langgraph.func import entrypoint, task
from langgraph.config import get_stream_writer
@task
def my_task(value: int):
my_stream_writer = get_stream_writer()
my_stream_writer({"custom_data": "Hello!"})
return value + 1
@entrypoint(store=store)
def workflow(value: int):
return my_task(value).result()
for chunk in workflow.stream(1, stream_mode="custom"):
print(chunk)
```
```pycon
{"custom_data": "Hello!"}
```
"""
runtime = get_config()[CONF][CONFIG_KEY_RUNTIME]
return runtime.stream_writer

View File

@@ -0,0 +1,64 @@
import sys
from typing import Any
from warnings import warn
from langgraph._internal._constants import (
CONF,
CONFIG_KEY_CHECKPOINTER,
TASKS,
)
from langgraph.warnings import LangGraphDeprecatedSinceV10
__all__ = (
"TAG_NOSTREAM",
"TAG_HIDDEN",
"START",
"END",
# retained for backwards compatibility (mostly langgraph-api), should be removed in v2 (or earlier)
"CONF",
"TASKS",
"CONFIG_KEY_CHECKPOINTER",
)
# --- Public constants ---
TAG_NOSTREAM = sys.intern("nostream")
"""Tag to disable streaming for a chat model."""
TAG_HIDDEN = sys.intern("langsmith:hidden")
"""Tag to hide a node/edge from certain tracing/streaming environments."""
END = sys.intern("__end__")
"""The last (maybe virtual) node in graph-style Pregel."""
START = sys.intern("__start__")
"""The first (maybe virtual) node in graph-style Pregel."""
def __getattr__(name: str) -> Any:
if name in ["Send", "Interrupt"]:
warn(
f"Importing {name} from langgraph.constants is deprecated. "
f"Please use 'from langgraph.types import {name}' instead.",
LangGraphDeprecatedSinceV10,
stacklevel=2,
)
from importlib import import_module
module = import_module("langgraph.types")
return getattr(module, name)
try:
from importlib import import_module
private_constants = import_module("langgraph._internal._constants")
attr = getattr(private_constants, name)
warn(
f"Importing {name} from langgraph.constants is deprecated. "
f"This constant is now private and should not be used directly. "
"Please let the LangGraph team know if you need this value.",
LangGraphDeprecatedSinceV10,
stacklevel=2,
)
return attr
except AttributeError:
pass
raise AttributeError(f"module has no attribute '{name}'")

View File

@@ -0,0 +1,127 @@
from __future__ import annotations
from collections.abc import Sequence
from enum import Enum
from typing import Any
from warnings import warn
# EmptyChannelError is re-exported from langgraph.channels.base
from langgraph.checkpoint.base import EmptyChannelError # noqa: F401
from typing_extensions import deprecated
from langgraph.types import Command, Interrupt
from langgraph.warnings import LangGraphDeprecatedSinceV10
__all__ = (
"EmptyChannelError",
"ErrorCode",
"GraphRecursionError",
"InvalidUpdateError",
"GraphBubbleUp",
"GraphInterrupt",
"NodeInterrupt",
"ParentCommand",
"EmptyInputError",
"TaskNotFound",
)
class ErrorCode(Enum):
GRAPH_RECURSION_LIMIT = "GRAPH_RECURSION_LIMIT"
INVALID_CONCURRENT_GRAPH_UPDATE = "INVALID_CONCURRENT_GRAPH_UPDATE"
INVALID_GRAPH_NODE_RETURN_VALUE = "INVALID_GRAPH_NODE_RETURN_VALUE"
MULTIPLE_SUBGRAPHS = "MULTIPLE_SUBGRAPHS"
INVALID_CHAT_HISTORY = "INVALID_CHAT_HISTORY"
def create_error_message(*, message: str, error_code: ErrorCode) -> str:
return (
f"{message}\n"
"For troubleshooting, visit: https://docs.langchain.com/oss/python/langgraph/"
f"errors/{error_code.value}"
)
class GraphRecursionError(RecursionError):
"""Raised when the graph has exhausted the maximum number of steps.
This prevents infinite loops. To increase the maximum number of steps,
run your graph with a config specifying a higher `recursion_limit`.
Troubleshooting guides:
- [`GRAPH_RECURSION_LIMIT`](https://docs.langchain.com/oss/python/langgraph/GRAPH_RECURSION_LIMIT)
Examples:
graph = builder.compile()
graph.invoke(
{"messages": [("user", "Hello, world!")]},
# The config is the second positional argument
{"recursion_limit": 1000},
)
"""
pass
class InvalidUpdateError(Exception):
"""Raised when attempting to update a channel with an invalid set of updates.
Troubleshooting guides:
- [`INVALID_CONCURRENT_GRAPH_UPDATE`](https://docs.langchain.com/oss/python/langgraph/INVALID_CONCURRENT_GRAPH_UPDATE)
- [`INVALID_GRAPH_NODE_RETURN_VALUE`](https://docs.langchain.com/oss/python/langgraph/INVALID_GRAPH_NODE_RETURN_VALUE)
"""
pass
class GraphBubbleUp(Exception):
pass
class GraphInterrupt(GraphBubbleUp):
"""Raised when a subgraph is interrupted, suppressed by the root graph.
Never raised directly, or surfaced to the user."""
def __init__(self, interrupts: Sequence[Interrupt] = ()) -> None:
super().__init__(interrupts)
@deprecated(
"NodeInterrupt is deprecated. Please use [`interrupt`][langgraph.types.interrupt] instead.",
category=None,
)
class NodeInterrupt(GraphInterrupt):
"""Raised by a node to interrupt execution."""
def __init__(self, value: Any, id: str | None = None) -> None:
warn(
"NodeInterrupt is deprecated. Please use `langgraph.types.interrupt` instead.",
LangGraphDeprecatedSinceV10,
stacklevel=2,
)
if id is None:
super().__init__([Interrupt(value=value)])
else:
super().__init__([Interrupt(value=value, id=id)])
class ParentCommand(GraphBubbleUp):
args: tuple[Command]
def __init__(self, command: Command) -> None:
super().__init__(command)
class EmptyInputError(Exception):
"""Raised when graph receives an empty input."""
pass
class TaskNotFound(Exception):
"""Raised when the executor is unable to find a task (for distributed mode)."""
pass

View File

@@ -0,0 +1,575 @@
from __future__ import annotations
import functools
import inspect
import warnings
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass
from typing import (
Any,
Generic,
TypeVar,
cast,
get_args,
get_origin,
overload,
)
from langgraph.cache.base import BaseCache
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.store.base import BaseStore
from typing_extensions import Unpack
from langgraph._internal import _serde
from langgraph._internal._constants import CACHE_NS_WRITES, PREVIOUS
from langgraph._internal._typing import MISSING, DeprecatedKwargs
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.channels.last_value import LastValue
from langgraph.constants import END, START
from langgraph.pregel import Pregel
from langgraph.pregel._call import (
P,
SyncAsyncFuture,
T,
call,
get_runnable_for_entrypoint,
identifier,
)
from langgraph.pregel._read import PregelNode
from langgraph.pregel._write import ChannelWrite, ChannelWriteEntry
from langgraph.types import _DC_KWARGS, CachePolicy, RetryPolicy, StreamMode
from langgraph.typing import ContextT
from langgraph.warnings import LangGraphDeprecatedSinceV05, LangGraphDeprecatedSinceV10
__all__ = ("task", "entrypoint")
class _TaskFunction(Generic[P, T]):
def __init__(
self,
func: Callable[P, Awaitable[T]] | Callable[P, T],
*,
retry_policy: Sequence[RetryPolicy],
cache_policy: CachePolicy[Callable[P, str | bytes]] | None = None,
name: str | None = None,
) -> None:
if name is not None:
if hasattr(func, "__func__"):
# handle class methods
# NOTE: we're modifying the instance method to avoid modifying
# the original class method in case it's shared across multiple tasks
instance_method = functools.partial(func.__func__, func.__self__) # type: ignore [union-attr]
instance_method.__name__ = name # type: ignore [attr-defined]
func = instance_method
else:
# handle regular functions / partials / callable classes, etc.
func.__name__ = name
self.func = func
self.retry_policy = retry_policy
self.cache_policy = cache_policy
functools.update_wrapper(self, func)
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> SyncAsyncFuture[T]:
return call(
self.func,
retry_policy=self.retry_policy,
cache_policy=self.cache_policy,
*args,
**kwargs,
)
def clear_cache(self, cache: BaseCache) -> None:
"""Clear the cache for this task."""
if self.cache_policy is not None:
cache.clear(((CACHE_NS_WRITES, identifier(self.func) or "__dynamic__"),))
async def aclear_cache(self, cache: BaseCache) -> None:
"""Clear the cache for this task."""
if self.cache_policy is not None:
await cache.aclear(
((CACHE_NS_WRITES, identifier(self.func) or "__dynamic__"),)
)
@overload
def task(
__func_or_none__: None = None,
*,
name: str | None = None,
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
cache_policy: CachePolicy[Callable[P, str | bytes]] | None = None,
**kwargs: Unpack[DeprecatedKwargs],
) -> Callable[
[Callable[P, Awaitable[T]] | Callable[P, T]],
_TaskFunction[P, T],
]: ...
@overload
def task(__func_or_none__: Callable[P, Awaitable[T]]) -> _TaskFunction[P, T]: ...
@overload
def task(__func_or_none__: Callable[P, T]) -> _TaskFunction[P, T]: ...
def task(
__func_or_none__: Callable[P, Awaitable[T]] | Callable[P, T] | None = None,
*,
name: str | None = None,
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
cache_policy: CachePolicy[Callable[P, str | bytes]] | None = None,
**kwargs: Unpack[DeprecatedKwargs],
) -> (
Callable[[Callable[P, Awaitable[T]] | Callable[P, T]], _TaskFunction[P, T]]
| _TaskFunction[P, T]
):
"""Define a LangGraph task using the `task` decorator.
!!! important "Requires python 3.11 or higher for async functions"
The `task` decorator supports both sync and async functions. To use async
functions, ensure that you are using Python 3.11 or higher.
Tasks can only be called from within an [`entrypoint`][langgraph.func.entrypoint] or
from within a `StateGraph`. A task can be called like a regular function with the
following differences:
- When a checkpointer is enabled, the function inputs and outputs must be serializable.
- The decorated function can only be called from within an entrypoint or `StateGraph`.
- Calling the function produces a future. This makes it easy to parallelize tasks.
Args:
name: An optional name for the task. If not provided, the function name will be used.
retry_policy: An optional retry policy (or list of policies) to use for the task in case of a failure.
cache_policy: An optional cache policy to use for the task. This allows caching of the task results.
Returns:
A callable function when used as a decorator.
Example: Sync Task
```python
from langgraph.func import entrypoint, task
@task
def add_one_task(a: int) -> int:
return a + 1
@entrypoint()
def add_one(numbers: list[int]) -> list[int]:
futures = [add_one_task(n) for n in numbers]
results = [f.result() for f in futures]
return results
# Call the entrypoint
add_one.invoke([1, 2, 3]) # Returns [2, 3, 4]
```
Example: Async Task
```python
import asyncio
from langgraph.func import entrypoint, task
@task
async def add_one_task(a: int) -> int:
return a + 1
@entrypoint()
async def add_one(numbers: list[int]) -> list[int]:
futures = [add_one_task(n) for n in numbers]
return asyncio.gather(*futures)
# Call the entrypoint
await add_one.ainvoke([1, 2, 3]) # Returns [2, 3, 4]
```
"""
if (retry := kwargs.get("retry", MISSING)) is not MISSING:
warnings.warn(
"`retry` is deprecated and will be removed. Please use `retry_policy` instead.",
category=LangGraphDeprecatedSinceV05,
stacklevel=2,
)
if retry_policy is None:
retry_policy = retry # type: ignore[assignment]
retry_policies: Sequence[RetryPolicy] = (
()
if retry_policy is None
else (retry_policy,)
if isinstance(retry_policy, RetryPolicy)
else retry_policy
)
def decorator(
func: Callable[P, Awaitable[T]] | Callable[P, T],
) -> Callable[P, SyncAsyncFuture[T]]:
return _TaskFunction(
func, retry_policy=retry_policies, cache_policy=cache_policy, name=name
)
if __func_or_none__ is not None:
return decorator(__func_or_none__)
return decorator
R = TypeVar("R")
S = TypeVar("S")
# The decorator was wrapped in a class to support the `final` attribute.
# In this form, the `final` attribute should play nicely with IDE autocompletion,
# and type checking tools.
# In addition, we'll be able to surface this information in the API Reference.
class entrypoint(Generic[ContextT]):
"""Define a LangGraph workflow using the `entrypoint` decorator.
### Function signature
The decorated function must accept a **single parameter**, which serves as the input
to the function. This input parameter can be of any type. Use a dictionary
to pass **multiple parameters** to the function.
### Injectable parameters
The decorated function can request access to additional parameters
that will be injected automatically at run time. These parameters include:
| Parameter | Description |
|------------------|------------------------------------------------------------------------------------------------------|
| **`config`** | A configuration object (aka `RunnableConfig`) that holds run-time configuration values. |
| **`previous`** | The previous return value for the given thread (available only when a checkpointer is provided). |
| **`runtime`** | A `Runtime` object that contains information about the current run, including context, store, writer |
The entrypoint decorator can be applied to sync functions or async functions.
### State management
The **`previous`** parameter can be used to access the return value of the previous
invocation of the entrypoint on the same thread id. This value is only available
when a checkpointer is provided.
If you want **`previous`** to be different from the return value, you can use the
`entrypoint.final` object to return a value while saving a different value to the
checkpoint.
Args:
checkpointer: Specify a checkpointer to create a workflow that can persist
its state across runs.
store: A generalized key-value store. Some implementations may support
semantic search capabilities through an optional `index` configuration.
cache: A cache to use for caching the results of the workflow.
context_schema: Specifies the schema for the context object that will be
passed to the workflow.
cache_policy: A cache policy to use for caching the results of the workflow.
retry_policy: A retry policy (or list of policies) to use for the workflow in case of a failure.
!!! warning "`config_schema` Deprecated"
The `config_schema` parameter is deprecated in v0.6.0 and support will be removed in v2.0.0.
Please use `context_schema` instead to specify the schema for run-scoped context.
Example: Using entrypoint and tasks
```python
import time
from langgraph.func import entrypoint, task
from langgraph.types import interrupt, Command
from langgraph.checkpoint.memory import InMemorySaver
@task
def compose_essay(topic: str) -> str:
time.sleep(1.0) # Simulate slow operation
return f"An essay about {topic}"
@entrypoint(checkpointer=InMemorySaver())
def review_workflow(topic: str) -> dict:
\"\"\"Manages the workflow for generating and reviewing an essay.
The workflow includes:
1. Generating an essay about the given topic.
2. Interrupting the workflow for human review of the generated essay.
Upon resuming the workflow, compose_essay task will not be re-executed
as its result is cached by the checkpointer.
Args:
topic: The subject of the essay.
Returns:
dict: A dictionary containing the generated essay and the human review.
\"\"\"
essay_future = compose_essay(topic)
essay = essay_future.result()
human_review = interrupt({
\"question\": \"Please provide a review\",
\"essay\": essay
})
return {
\"essay\": essay,
\"review\": human_review,
}
# Example configuration for the workflow
config = {
\"configurable\": {
\"thread_id\": \"some_thread\"
}
}
# Topic for the essay
topic = \"cats\"
# Stream the workflow to generate the essay and await human review
for result in review_workflow.stream(topic, config):
print(result)
# Example human review provided after the interrupt
human_review = \"This essay is great.\"
# Resume the workflow with the provided human review
for result in review_workflow.stream(Command(resume=human_review), config):
print(result)
```
Example: Accessing the previous return value
When a checkpointer is enabled the function can access the previous return value
of the previous invocation on the same thread id.
```python
from typing import Optional
from langgraph.checkpoint.memory import MemorySaver
from langgraph.func import entrypoint
@entrypoint(checkpointer=InMemorySaver())
def my_workflow(input_data: str, previous: Optional[str] = None) -> str:
return "world"
config = {"configurable": {"thread_id": "some_thread"}}
my_workflow.invoke("hello", config)
```
Example: Using `entrypoint.final` to save a value
The `entrypoint.final` object allows you to return a value while saving
a different value to the checkpoint. This value will be accessible
in the next invocation of the entrypoint via the `previous` parameter, as
long as the same thread id is used.
```python
from typing import Any
from langgraph.checkpoint.memory import MemorySaver
from langgraph.func import entrypoint
@entrypoint(checkpointer=InMemorySaver())
def my_workflow(
number: int,
*,
previous: Any = None,
) -> entrypoint.final[int, int]:
previous = previous or 0
# This will return the previous value to the caller, saving
# 2 * number to the checkpoint, which will be used in the next invocation
# for the `previous` parameter.
return entrypoint.final(value=previous, save=2 * number)
config = {"configurable": {"thread_id": "some_thread"}}
my_workflow.invoke(3, config) # 0 (previous was None)
my_workflow.invoke(1, config) # 6 (previous was 3 * 2 from the previous invocation)
```
"""
def __init__(
self,
checkpointer: BaseCheckpointSaver | None = None,
store: BaseStore | None = None,
cache: BaseCache | None = None,
context_schema: type[ContextT] | None = None,
cache_policy: CachePolicy | None = None,
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
**kwargs: Unpack[DeprecatedKwargs],
) -> None:
"""Initialize the entrypoint decorator."""
if (config_schema := kwargs.get("config_schema", MISSING)) is not MISSING:
warnings.warn(
"`config_schema` is deprecated and will be removed. Please use `context_schema` instead.",
category=LangGraphDeprecatedSinceV10,
stacklevel=2,
)
if context_schema is None:
context_schema = cast(type[ContextT], config_schema)
if (retry := kwargs.get("retry", MISSING)) is not MISSING:
warnings.warn(
"`retry` is deprecated and will be removed. Please use `retry_policy` instead.",
category=LangGraphDeprecatedSinceV05,
stacklevel=2,
)
if retry_policy is None:
retry_policy = cast("RetryPolicy | Sequence[RetryPolicy]", retry)
self.checkpointer = checkpointer
self.store = store
self.cache = cache
self.cache_policy = cache_policy
self.retry_policy = retry_policy
self.context_schema = context_schema
@dataclass(**_DC_KWARGS)
class final(Generic[R, S]):
"""A primitive that can be returned from an entrypoint.
This primitive allows to save a value to the checkpointer distinct from the
return value from the entrypoint.
Example: Decoupling the return value and the save value
```python
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.func import entrypoint
@entrypoint(checkpointer=InMemorySaver())
def my_workflow(
number: int,
*,
previous: Any = None,
) -> entrypoint.final[int, int]:
previous = previous or 0
# This will return the previous value to the caller, saving
# 2 * number to the checkpoint, which will be used in the next invocation
# for the `previous` parameter.
return entrypoint.final(value=previous, save=2 * number)
config = {"configurable": {"thread_id": "1"}}
my_workflow.invoke(3, config) # 0 (previous was None)
my_workflow.invoke(1, config) # 6 (previous was 3 * 2 from the previous invocation)
```
"""
value: R
"""Value to return. A value will always be returned even if it is `None`."""
save: S
"""The value for the state for the next checkpoint.
A value will always be saved even if it is `None`.
"""
def __call__(self, func: Callable[..., Any]) -> Pregel:
"""Convert a function into a Pregel graph.
Args:
func: The function to convert. Support both sync and async functions.
Returns:
A Pregel graph.
"""
# wrap generators in a function that writes to StreamWriter
if inspect.isgeneratorfunction(func) or inspect.isasyncgenfunction(func):
raise NotImplementedError(
"Generators are not supported in the Functional API."
)
bound = get_runnable_for_entrypoint(func)
stream_mode: StreamMode = "updates"
# get input and output types
sig = inspect.signature(func)
first_parameter_name = next(iter(sig.parameters.keys()), None)
if not first_parameter_name:
raise ValueError("Entrypoint function must have at least one parameter")
input_type = (
sig.parameters[first_parameter_name].annotation
if sig.parameters[first_parameter_name].annotation
is not inspect.Signature.empty
else Any
)
def _pluck_return_value(value: Any) -> Any:
"""Extract the return_ value the entrypoint.final object or passthrough."""
return value.value if isinstance(value, entrypoint.final) else value
def _pluck_save_value(value: Any) -> Any:
"""Get save value from the entrypoint.final object or passthrough."""
return value.save if isinstance(value, entrypoint.final) else value
output_type, save_type = Any, Any
if sig.return_annotation is not inspect.Signature.empty:
# User does not parameterize entrypoint.final properly
if (
sig.return_annotation is entrypoint.final
): # Un-parameterized entrypoint.final
output_type = save_type = Any
else:
origin = get_origin(sig.return_annotation)
if origin is entrypoint.final:
type_annotations = get_args(sig.return_annotation)
if len(type_annotations) != 2:
raise TypeError(
"Please an annotation for both the return_ and "
"the save values."
"For example, `-> entrypoint.final[int, str]` would assign a "
"return_ a type of `int` and save the type `str`."
)
output_type, save_type = get_args(sig.return_annotation)
else:
output_type = save_type = sig.return_annotation
graph: Pregel[Any, ContextT, Any, Any] = Pregel(
nodes={
func.__name__: PregelNode(
bound=bound,
triggers=[START],
channels=START,
writers=[
ChannelWrite(
[
ChannelWriteEntry(END, mapper=_pluck_return_value),
ChannelWriteEntry(PREVIOUS, mapper=_pluck_save_value),
]
)
],
)
},
channels={
START: EphemeralValue(input_type),
END: LastValue(output_type, END),
PREVIOUS: LastValue(save_type, PREVIOUS),
},
input_channels=START,
output_channels=END,
stream_channels=END,
stream_mode=stream_mode,
stream_eager=True,
checkpointer=self.checkpointer,
store=self.store,
cache=self.cache,
cache_policy=self.cache_policy,
retry_policy=self.retry_policy or (),
context_schema=self.context_schema,
)
if _serde.STRICT_MSGPACK_ENABLED:
serde_allowlist = _serde.build_serde_allowlist(
schemas=[input_type, output_type, save_type]
+ ([self.context_schema] if self.context_schema is not None else []),
channels=graph.channels,
)
graph._serde_allowlist = serde_allowlist
graph.checkpointer = _serde.apply_checkpointer_allowlist(
graph.checkpointer, serde_allowlist
)
return graph

View File

@@ -0,0 +1,12 @@
from langgraph.constants import END, START
from langgraph.graph.message import MessageGraph, MessagesState, add_messages
from langgraph.graph.state import StateGraph
__all__ = (
"END",
"START",
"StateGraph",
"add_messages",
"MessagesState",
"MessageGraph",
)

View File

@@ -0,0 +1,225 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable, Hashable, Sequence
from inspect import (
isfunction,
ismethod,
signature,
)
from itertools import zip_longest
from types import FunctionType
from typing import (
Any,
Literal,
NamedTuple,
cast,
get_args,
get_origin,
get_type_hints,
)
from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableLambda,
)
from langgraph._internal._runnable import (
RunnableCallable,
)
from langgraph.constants import END, START
from langgraph.errors import InvalidUpdateError
from langgraph.pregel._write import PASSTHROUGH, ChannelWrite, ChannelWriteEntry
from langgraph.types import Send
_Writer = Callable[
[Sequence[str | Send], bool],
Sequence[ChannelWriteEntry | Send],
]
def _get_branch_path_input_schema(
path: Callable[..., Hashable | Sequence[Hashable]]
| Callable[..., Awaitable[Hashable | Sequence[Hashable]]]
| Runnable[Any, Hashable | Sequence[Hashable]],
) -> type[Any] | None:
input = None
# detect input schema annotation in the branch callable
try:
callable_: (
Callable[..., Hashable | Sequence[Hashable]]
| Callable[..., Awaitable[Hashable | Sequence[Hashable]]]
| None
) = None
if isinstance(path, (RunnableCallable, RunnableLambda)):
if isfunction(path.func) or ismethod(path.func):
callable_ = path.func
elif (callable_method := getattr(path.func, "__call__", None)) and ismethod(
callable_method
):
callable_ = callable_method
elif isfunction(path.afunc) or ismethod(path.afunc):
callable_ = path.afunc
elif (
callable_method := getattr(path.afunc, "__call__", None)
) and ismethod(callable_method):
callable_ = callable_method
elif callable(path):
callable_ = path
if callable_ is not None and (hints := get_type_hints(callable_)):
first_parameter_name = next(
iter(signature(cast(FunctionType, callable_)).parameters.keys())
)
if input_hint := hints.get(first_parameter_name):
if isinstance(input_hint, type) and get_type_hints(input_hint):
input = input_hint
except (TypeError, StopIteration):
pass
return input
class BranchSpec(NamedTuple):
path: Runnable[Any, Hashable | list[Hashable]]
ends: dict[Hashable, str] | None
input_schema: type[Any] | None = None
@classmethod
def from_path(
cls,
path: Runnable[Any, Hashable | list[Hashable]],
path_map: dict[Hashable, str] | list[str] | None,
infer_schema: bool = False,
) -> BranchSpec:
# coerce path_map to a dictionary
path_map_: dict[Hashable, str] | None = None
try:
if isinstance(path_map, dict):
path_map_ = path_map.copy()
elif isinstance(path_map, list):
path_map_ = {name: name for name in path_map}
else:
# find func
func: Callable | None = None
if isinstance(path, (RunnableCallable, RunnableLambda)):
func = path.func or path.afunc
if func is not None:
# find callable method
if (cal := getattr(path, "__call__", None)) and ismethod(cal):
func = cal
# get the return type
if rtn_type := get_type_hints(func).get("return"):
if get_origin(rtn_type) is Literal:
path_map_ = {name: name for name in get_args(rtn_type)}
except Exception:
pass
# infer input schema
input_schema = _get_branch_path_input_schema(path) if infer_schema else None
# create branch
return cls(path=path, ends=path_map_, input_schema=input_schema)
def run(
self,
writer: _Writer,
reader: Callable[[RunnableConfig], Any] | None = None,
) -> RunnableCallable:
return ChannelWrite.register_writer(
RunnableCallable(
func=self._route,
afunc=self._aroute,
writer=writer,
reader=reader,
name=None,
trace=False,
),
list(
zip_longest(
writer([e for e in self.ends.values()], True),
[str(la) for la, e in self.ends.items()],
)
)
if self.ends
else None,
)
def _route(
self,
input: Any,
config: RunnableConfig,
*,
reader: Callable[[RunnableConfig], Any] | None,
writer: _Writer,
) -> Runnable:
if reader:
value = reader(config)
# passthrough additional keys from node to branch
# only doable when using dict states
if (
isinstance(value, dict)
and isinstance(input, dict)
and self.input_schema is None
):
value = {**input, **value}
else:
value = input
result = self.path.invoke(value, config)
return self._finish(writer, input, result, config)
async def _aroute(
self,
input: Any,
config: RunnableConfig,
*,
reader: Callable[[RunnableConfig], Any] | None,
writer: _Writer,
) -> Runnable:
if reader:
value = reader(config)
# passthrough additional keys from node to branch
# only doable when using dict states
if (
isinstance(value, dict)
and isinstance(input, dict)
and self.input_schema is None
):
value = {**input, **value}
else:
value = input
result = await self.path.ainvoke(value, config)
return self._finish(writer, input, result, config)
def _finish(
self,
writer: _Writer,
input: Any,
result: Any,
config: RunnableConfig,
) -> Runnable | Any:
if not isinstance(result, (list, tuple)):
result = [result]
if self.ends:
destinations: Sequence[Send | str] = [
r if isinstance(r, Send) else self.ends[r] for r in result
]
else:
destinations = cast(Sequence[Send | str], result)
if any(dest is None or dest == START for dest in destinations):
raise ValueError("Branch did not return a valid destination")
if any(p.node == END for p in destinations if isinstance(p, Send)):
raise InvalidUpdateError("Cannot send a packet to the END node")
entries = writer(destinations, False)
if not entries:
return input
else:
need_passthrough = False
for e in entries:
if isinstance(e, ChannelWriteEntry):
if e.value is PASSTHROUGH:
need_passthrough = True
break
if need_passthrough:
return ChannelWrite(entries)
else:
ChannelWrite.do_write(config, entries)
return input

View File

@@ -0,0 +1,92 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Generic, Protocol, TypeAlias
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.store.base import BaseStore
from langgraph._internal._typing import EMPTY_SEQ
from langgraph.runtime import Runtime
from langgraph.types import CachePolicy, RetryPolicy, StreamWriter
from langgraph.typing import ContextT, NodeInputT, NodeInputT_contra
class _Node(Protocol[NodeInputT_contra]):
def __call__(self, state: NodeInputT_contra) -> Any: ...
class _NodeWithConfig(Protocol[NodeInputT_contra]):
def __call__(self, state: NodeInputT_contra, config: RunnableConfig) -> Any: ...
class _NodeWithWriter(Protocol[NodeInputT_contra]):
def __call__(self, state: NodeInputT_contra, *, writer: StreamWriter) -> Any: ...
class _NodeWithStore(Protocol[NodeInputT_contra]):
def __call__(self, state: NodeInputT_contra, *, store: BaseStore) -> Any: ...
class _NodeWithWriterStore(Protocol[NodeInputT_contra]):
def __call__(
self, state: NodeInputT_contra, *, writer: StreamWriter, store: BaseStore
) -> Any: ...
class _NodeWithConfigWriter(Protocol[NodeInputT_contra]):
def __call__(
self, state: NodeInputT_contra, *, config: RunnableConfig, writer: StreamWriter
) -> Any: ...
class _NodeWithConfigStore(Protocol[NodeInputT_contra]):
def __call__(
self, state: NodeInputT_contra, *, config: RunnableConfig, store: BaseStore
) -> Any: ...
class _NodeWithConfigWriterStore(Protocol[NodeInputT_contra]):
def __call__(
self,
state: NodeInputT_contra,
*,
config: RunnableConfig,
writer: StreamWriter,
store: BaseStore,
) -> Any: ...
class _NodeWithRuntime(Protocol[NodeInputT_contra, ContextT]):
def __call__(
self, state: NodeInputT_contra, *, runtime: Runtime[ContextT]
) -> Any: ...
# TODO: we probably don't want to explicitly support the config / store signatures once
# we move to adding a context arg. Maybe what we do is we add support for kwargs with param spec
# this is purely for typing purposes though, so can easily change in the coming weeks.
StateNode: TypeAlias = (
_Node[NodeInputT]
| _NodeWithConfig[NodeInputT]
| _NodeWithWriter[NodeInputT]
| _NodeWithStore[NodeInputT]
| _NodeWithWriterStore[NodeInputT]
| _NodeWithConfigWriter[NodeInputT]
| _NodeWithConfigStore[NodeInputT]
| _NodeWithConfigWriterStore[NodeInputT]
| _NodeWithRuntime[NodeInputT, ContextT]
| Runnable[NodeInputT, Any]
)
@dataclass(slots=True)
class StateNodeSpec(Generic[NodeInputT, ContextT]):
runnable: StateNode[NodeInputT, ContextT]
metadata: dict[str, Any] | None
input_schema: type[NodeInputT]
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None
cache_policy: CachePolicy | None
ends: tuple[str, ...] | dict[str, str] | None = EMPTY_SEQ
defer: bool = False

View File

@@ -0,0 +1,372 @@
from __future__ import annotations
import uuid
import warnings
from collections.abc import Callable, Sequence
from functools import partial
from typing import (
Annotated,
Any,
Literal,
cast,
)
from langchain_core.messages import (
AnyMessage,
BaseMessage,
BaseMessageChunk,
MessageLikeRepresentation,
RemoveMessage,
convert_to_messages,
message_chunk_to_message,
)
from typing_extensions import TypedDict, deprecated
from langgraph._internal._constants import CONF, CONFIG_KEY_SEND, NS_SEP
from langgraph.graph.state import StateGraph
from langgraph.warnings import LangGraphDeprecatedSinceV10
__all__ = (
"add_messages",
"MessagesState",
"MessageGraph",
"REMOVE_ALL_MESSAGES",
)
Messages = list[MessageLikeRepresentation] | MessageLikeRepresentation
REMOVE_ALL_MESSAGES = "__remove_all__"
def _add_messages_wrapper(func: Callable) -> Callable[[Messages, Messages], Messages]:
def _add_messages(
left: Messages | None = None, right: Messages | None = None, **kwargs: Any
) -> Messages | Callable[[Messages, Messages], Messages]:
if left is not None and right is not None:
return func(left, right, **kwargs)
elif left is not None or right is not None:
msg = (
f"Must specify non-null arguments for both 'left' and 'right'. Only "
f"received: '{'left' if left else 'right'}'."
)
raise ValueError(msg)
else:
return partial(func, **kwargs)
_add_messages.__doc__ = func.__doc__
return cast(Callable[[Messages, Messages], Messages], _add_messages)
@_add_messages_wrapper
def add_messages(
left: Messages,
right: Messages,
*,
format: Literal["langchain-openai"] | None = None,
) -> Messages:
"""Merges two lists of messages, updating existing messages by ID.
By default, this ensures the state is "append-only", unless the
new message has the same ID as an existing message.
Args:
left: The base list of `Messages`.
right: The list of `Messages` (or single `Message`) to merge
into the base list.
format: The format to return messages in. If `None` then `Messages` will be
returned as is. If `langchain-openai` then `Messages` will be returned as
`BaseMessage` objects with their contents formatted to match OpenAI message
format, meaning contents can be string, `'text'` blocks, or `'image_url'` blocks
and tool responses are returned as their own `ToolMessage` objects.
!!! important "Requirement"
Must have `langchain-core>=0.3.11` installed to use this feature.
Returns:
A new list of messages with the messages from `right` merged into `left`.
If a message in `right` has the same ID as a message in `left`, the
message from `right` will replace the message from `left`.
Example: Basic usage
```python
from langchain_core.messages import AIMessage, HumanMessage
msgs1 = [HumanMessage(content="Hello", id="1")]
msgs2 = [AIMessage(content="Hi there!", id="2")]
add_messages(msgs1, msgs2)
# [HumanMessage(content='Hello', id='1'), AIMessage(content='Hi there!', id='2')]
```
Example: Overwrite existing message
```python
msgs1 = [HumanMessage(content="Hello", id="1")]
msgs2 = [HumanMessage(content="Hello again", id="1")]
add_messages(msgs1, msgs2)
# [HumanMessage(content='Hello again', id='1')]
```
Example: Use in a StateGraph
```python
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph
class State(TypedDict):
messages: Annotated[list, add_messages]
builder = StateGraph(State)
builder.add_node("chatbot", lambda state: {"messages": [("assistant", "Hello")]})
builder.set_entry_point("chatbot")
builder.set_finish_point("chatbot")
graph = builder.compile()
graph.invoke({})
# {'messages': [AIMessage(content='Hello', id=...)]}
```
Example: Use OpenAI message format
```python
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, add_messages
class State(TypedDict):
messages: Annotated[list, add_messages(format="langchain-openai")]
def chatbot_node(state: State) -> list:
return {
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Here's an image:",
"cache_control": {"type": "ephemeral"},
},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": "1234",
},
},
],
},
]
}
builder = StateGraph(State)
builder.add_node("chatbot", chatbot_node)
builder.set_entry_point("chatbot")
builder.set_finish_point("chatbot")
graph = builder.compile()
graph.invoke({"messages": []})
# {
# 'messages': [
# HumanMessage(
# content=[
# {"type": "text", "text": "Here's an image:"},
# {
# "type": "image_url",
# "image_url": {"url": "data:image/jpeg;base64,1234"},
# },
# ],
# ),
# ]
# }
```
"""
remove_all_idx = None
# coerce to list
if not isinstance(left, list):
left = [left] # type: ignore[assignment]
if not isinstance(right, list):
right = [right] # type: ignore[assignment]
# coerce to message
left = [
message_chunk_to_message(cast(BaseMessageChunk, m))
for m in convert_to_messages(left)
]
right = [
message_chunk_to_message(cast(BaseMessageChunk, m))
for m in convert_to_messages(right)
]
# assign missing ids
for m in left:
if m.id is None:
m.id = str(uuid.uuid4())
for idx, m in enumerate(right):
if m.id is None:
m.id = str(uuid.uuid4())
if isinstance(m, RemoveMessage) and m.id == REMOVE_ALL_MESSAGES:
remove_all_idx = idx
if remove_all_idx is not None:
return right[remove_all_idx + 1 :]
# merge
merged = left.copy()
merged_by_id = {m.id: i for i, m in enumerate(merged)}
ids_to_remove = set()
for m in right:
if (existing_idx := merged_by_id.get(m.id)) is not None:
if isinstance(m, RemoveMessage):
ids_to_remove.add(m.id)
else:
ids_to_remove.discard(m.id)
merged[existing_idx] = m
else:
if isinstance(m, RemoveMessage):
raise ValueError(
f"Attempting to delete a message with an ID that doesn't exist ('{m.id}')"
)
merged_by_id[m.id] = len(merged)
merged.append(m)
merged = [m for m in merged if m.id not in ids_to_remove]
if format == "langchain-openai":
merged = _format_messages(merged)
elif format:
msg = f"Unrecognized {format=}. Expected one of 'langchain-openai', None."
raise ValueError(msg)
else:
pass
return merged
@deprecated(
"MessageGraph is deprecated in langgraph 1.0.0, to be removed in 2.0.0. Please use StateGraph with a `messages` key instead.",
category=None,
)
class MessageGraph(StateGraph):
"""A StateGraph where every node receives a list of messages as input and returns one or more messages as output.
MessageGraph is a subclass of StateGraph whose entire state is a single, append-only* list of messages.
Each node in a MessageGraph takes a list of messages as input and returns zero or more
messages as output. The `add_messages` function is used to merge the output messages from each node
into the existing list of messages in the graph's state.
Examples:
```pycon
>>> from langgraph.graph.message import MessageGraph
...
>>> builder = MessageGraph()
>>> builder.add_node("chatbot", lambda state: [("assistant", "Hello!")])
>>> builder.set_entry_point("chatbot")
>>> builder.set_finish_point("chatbot")
>>> builder.compile().invoke([("user", "Hi there.")])
[HumanMessage(content="Hi there.", id='...'), AIMessage(content="Hello!", id='...')]
```
```pycon
>>> from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
>>> from langgraph.graph.message import MessageGraph
...
>>> builder = MessageGraph()
>>> builder.add_node(
... "chatbot",
... lambda state: [
... AIMessage(
... content="Hello!",
... tool_calls=[{"name": "search", "id": "123", "args": {"query": "X"}}],
... )
... ],
... )
>>> builder.add_node(
... "search", lambda state: [ToolMessage(content="Searching...", tool_call_id="123")]
... )
>>> builder.set_entry_point("chatbot")
>>> builder.add_edge("chatbot", "search")
>>> builder.set_finish_point("search")
>>> builder.compile().invoke([HumanMessage(content="Hi there. Can you search for X?")])
{'messages': [HumanMessage(content="Hi there. Can you search for X?", id='b8b7d8f4-7f4d-4f4d-9c1d-f8b8d8f4d9c1'),
AIMessage(content="Hello!", id='f4d9c1d8-8d8f-4d9c-b8b7-d8f4f4d9c1d8'),
ToolMessage(content="Searching...", id='d8f4f4d9-c1d8-4f4d-b8b7-d8f4f4d9c1d8', tool_call_id="123")]}
```
"""
def __init__(self) -> None:
warnings.warn(
"MessageGraph is deprecated in LangGraph v1.0.0, to be removed in v2.0.0. Please use StateGraph with a `messages` key instead.",
category=LangGraphDeprecatedSinceV10,
stacklevel=2,
)
super().__init__(Annotated[list[AnyMessage], add_messages]) # type: ignore[arg-type]
class MessagesState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages]
def _format_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
try:
from langchain_core.messages import convert_to_openai_messages
except ImportError:
msg = (
"Must have langchain-core>=0.3.11 installed to use automatic message "
"formatting (format='langchain-openai'). Please update your langchain-core "
"version or remove the 'format' flag. Returning un-formatted "
"messages."
)
warnings.warn(msg)
return list(messages)
else:
return convert_to_messages(convert_to_openai_messages(messages))
def push_message(
message: MessageLikeRepresentation | BaseMessageChunk,
*,
state_key: str | None = "messages",
) -> AnyMessage:
"""Write a message manually to the `messages` / `messages-tuple` stream mode.
Will automatically write to the channel specified in the `state_key` unless `state_key` is `None`.
"""
from langchain_core.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
)
from langgraph.config import get_config
from langgraph.pregel._messages import StreamMessagesHandler
config = get_config()
message = next(x for x in convert_to_messages([message]))
if message.id is None:
raise ValueError("Message ID is required")
if isinstance(config["callbacks"], BaseCallbackManager):
manager = config["callbacks"]
handlers = manager.handlers
elif isinstance(config["callbacks"], list) and all(
isinstance(x, BaseCallbackHandler) for x in config["callbacks"]
):
handlers = config["callbacks"]
if stream_handler := next(
(x for x in handlers if isinstance(x, StreamMessagesHandler)), None
):
metadata = config["metadata"]
message_meta = (
tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP)),
metadata,
)
stream_handler._emit(message_meta, message, dedupe=False)
if state_key:
config[CONF][CONFIG_KEY_SEND]([(state_key, message)])
return message

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,227 @@
from __future__ import annotations
from typing import Any, Literal, cast
from uuid import uuid4
from langchain_core.messages import AnyMessage
from typing_extensions import TypedDict
from langgraph.config import get_config, get_stream_writer
from langgraph.constants import CONF
__all__ = (
"UIMessage",
"RemoveUIMessage",
"AnyUIMessage",
"push_ui_message",
"delete_ui_message",
"ui_message_reducer",
)
class UIMessage(TypedDict):
"""A message type for UI updates in LangGraph.
This TypedDict represents a UI message that can be sent to update the UI state.
It contains information about the UI component to render and its properties.
Attributes:
type: Literal type indicating this is a UI message.
id: Unique identifier for the UI message.
name: Name of the UI component to render.
props: Properties to pass to the UI component.
metadata: Additional metadata about the UI message.
"""
type: Literal["ui"]
id: str
name: str
props: dict[str, Any]
metadata: dict[str, Any]
class RemoveUIMessage(TypedDict):
"""A message type for removing UI components in LangGraph.
This TypedDict represents a message that can be sent to remove a UI component
from the current state.
Attributes:
type: Literal type indicating this is a remove-ui message.
id: Unique identifier of the UI message to remove.
"""
type: Literal["remove-ui"]
id: str
AnyUIMessage = UIMessage | RemoveUIMessage
def push_ui_message(
name: str,
props: dict[str, Any],
*,
id: str | None = None,
metadata: dict[str, Any] | None = None,
message: AnyMessage | None = None,
state_key: str | None = "ui",
merge: bool = False,
) -> UIMessage:
"""Push a new UI message to update the UI state.
This function creates and sends a UI message that will be rendered in the UI.
It also updates the graph state with the new UI message.
Args:
name: Name of the UI component to render.
props: Properties to pass to the UI component.
id: Optional unique identifier for the UI message.
If not provided, a random UUID will be generated.
metadata: Optional additional metadata about the UI message.
message: Optional message object to associate with the UI message.
state_key: Key in the graph state where the UI messages are stored.
merge: Whether to merge props with existing UI message (True) or replace
them (False).
Returns:
The created UI message.
Example:
```python
push_ui_message(
name="component-name",
props={"content": "Hello world"},
)
```
"""
from langgraph._internal._constants import CONFIG_KEY_SEND
writer = get_stream_writer()
config = get_config()
message_id = None
if message:
if isinstance(message, dict) and "id" in message:
message_id = message.get("id")
elif hasattr(message, "id"):
message_id = message.id
evt: UIMessage = {
"type": "ui",
"id": id or str(uuid4()),
"name": name,
"props": props,
"metadata": {
"merge": merge,
"run_id": config.get("run_id", None),
"tags": config.get("tags", None),
"name": config.get("run_name", None),
**(metadata or {}),
**({"message_id": message_id} if message_id else {}),
},
}
writer(evt)
if state_key:
config[CONF][CONFIG_KEY_SEND]([(state_key, evt)])
return evt
def delete_ui_message(id: str, *, state_key: str = "ui") -> RemoveUIMessage:
"""Delete a UI message by ID from the UI state.
This function creates and sends a message to remove a UI component from the current state.
It also updates the graph state to remove the UI message.
Args:
id: Unique identifier of the UI component to remove.
state_key: Key in the graph state where the UI messages are stored. Defaults to "ui".
Returns:
The remove UI message.
Example:
```python
delete_ui_message("message-123")
```
"""
from langgraph._internal._constants import CONFIG_KEY_SEND
writer = get_stream_writer()
config = get_config()
evt: RemoveUIMessage = {"type": "remove-ui", "id": id}
writer(evt)
config[CONF][CONFIG_KEY_SEND]([(state_key, evt)])
return evt
def ui_message_reducer(
left: list[AnyUIMessage] | AnyUIMessage,
right: list[AnyUIMessage] | AnyUIMessage,
) -> list[AnyUIMessage]:
"""Merge two lists of UI messages, supporting removing UI messages.
This function combines two lists of UI messages, handling both regular UI messages
and `remove-ui` messages. When a `remove-ui` message is encountered, it removes any
UI message with the matching ID from the current state.
Args:
left: First list of UI messages or single UI message.
right: Second list of UI messages or single UI message.
Returns:
Combined list of UI messages with removals applied.
Example:
```python
messages = ui_message_reducer(
[{"type": "ui", "id": "1", "name": "Chat", "props": {}}],
{"type": "remove-ui", "id": "1"},
)
```
"""
if not isinstance(left, list):
left = [left]
if not isinstance(right, list):
right = [right]
# merge messages
merged = left.copy()
merged_by_id = {m.get("id"): i for i, m in enumerate(merged)}
ids_to_remove = set()
for msg in right:
msg_id = msg.get("id")
if (existing_idx := merged_by_id.get(msg_id)) is not None:
if msg.get("type") == "remove-ui":
ids_to_remove.add(msg_id)
else:
ids_to_remove.discard(msg_id)
if cast(UIMessage, msg).get("metadata", {}).get("merge", False):
prev_msg = merged[existing_idx]
msg = msg.copy()
msg["props"] = {**prev_msg["props"], **msg["props"]}
merged[existing_idx] = msg
else:
if msg.get("type") == "remove-ui":
raise ValueError(
f"Attempting to delete an UI message with an ID that doesn't exist ('{msg_id}')"
)
merged_by_id[msg_id] = len(merged)
merged.append(msg)
merged = [m for m in merged if m.get("id") not in ids_to_remove]
return merged

View File

@@ -0,0 +1,3 @@
from langgraph.managed.is_last_step import IsLastStep, RemainingSteps
__all__ = ("IsLastStep", "RemainingSteps")

Some files were not shown because too many files have changed in this diff Show More