initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,628 @@
from __future__ import annotations
import copy
import logging
from collections.abc import AsyncIterator, Collection, Iterator, Mapping, Sequence
from typing import ( # noqa: UP035
Any,
Generic,
Literal,
NamedTuple,
TypedDict,
TypeVar,
)
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base.id import uuid6
from langgraph.checkpoint.serde.base import SerializerProtocol, maybe_add_typed_methods
from langgraph.checkpoint.serde.encrypted import EncryptedSerializer
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import (
ERROR,
INTERRUPT,
RESUME,
SCHEDULED,
ChannelProtocol,
)
# Type of a channel version identifier; constrained to `int`, `float` or `str`.
V = TypeVar("V", int, float, str)
# One pending write, stored as a 3-tuple of (str, str, Any).
# NOTE(review): presumably (task_id, channel, value) — confirm against the
# checkpointer implementations that populate `CheckpointTuple.pending_writes`.
PendingWrite = tuple[str, str, Any]
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Marked as total=False to allow for future expansion.
class CheckpointMetadata(TypedDict, total=False):
    """Metadata associated with a checkpoint.

    The TypedDict is declared with `total=False`, so every key listed below
    is optional.
    """

    source: Literal["input", "loop", "update", "fork"]
    """The source of the checkpoint.
    - `"input"`: The checkpoint was created from an input to invoke/stream/batch.
    - `"loop"`: The checkpoint was created from inside the pregel loop.
    - `"update"`: The checkpoint was created from a manual state update.
    - `"fork"`: The checkpoint was created as a copy of another checkpoint.
    """
    step: int
    """The step number of the checkpoint.
    `-1` for the first `"input"` checkpoint.
    `0` for the first `"loop"` checkpoint.
    `...` for the `nth` checkpoint afterwards.
    """
    parents: dict[str, str]
    """The IDs of the parent checkpoints.
    Mapping from checkpoint namespace to checkpoint ID.
    """
    run_id: str
    """The ID of the run that created this checkpoint."""
# Maps a channel name to that channel's version identifier, which may be a
# `str`, an `int` or a `float`.
ChannelVersions = dict[str, str | int | float]
class Checkpoint(TypedDict):
    """State snapshot at a given point in time."""

    v: int
    """The version of the checkpoint format.

    Newly created checkpoints use `LATEST_VERSION`, currently `2`.
    """
    id: str
    """The ID of the checkpoint.

    This is both unique and monotonically increasing, so can be used for sorting
    checkpoints from first to last."""
    ts: str
    """The timestamp of the checkpoint in ISO 8601 format."""
    channel_values: dict[str, Any]
    """The values of the channels at the time of the checkpoint.

    Mapping from channel name to deserialized channel snapshot value.
    """
    channel_versions: ChannelVersions
    """The versions of the channels at the time of the checkpoint.

    The keys are channel names and the values are monotonically increasing
    version identifiers (`str`, `int` or `float`) for each channel.
    """
    versions_seen: dict[str, ChannelVersions]
    """Map from node ID to map from channel name to version seen.

    This keeps track of the versions of the channels that each node has seen.
    Used to determine which nodes to execute next.
    """
    updated_channels: list[str] | None
    """The channels that were updated in this checkpoint, or `None`."""
def copy_checkpoint(checkpoint: Checkpoint) -> Checkpoint:
    """Return a copy of `checkpoint` whose containers can be mutated independently.

    `channel_values`, `channel_versions`, every per-node mapping inside
    `versions_seen` and `pending_sends` (an empty list when the key is
    absent) are copied one level deep. The objects stored inside those
    containers, and the `updated_channels` value, are shared with the
    original checkpoint.

    Args:
        checkpoint: The checkpoint to copy.

    Returns:
        A new checkpoint mapping.
    """
    version = checkpoint["v"]
    timestamp = checkpoint["ts"]
    checkpoint_id = checkpoint["id"]
    values = checkpoint["channel_values"].copy()
    versions = checkpoint["channel_versions"].copy()
    seen = {
        node: node_versions.copy()
        for node, node_versions in checkpoint["versions_seen"].items()
    }
    # Older checkpoints may still carry `pending_sends`; default to an empty list.
    sends = checkpoint.get("pending_sends", []).copy()
    updated = checkpoint.get("updated_channels", None)
    return Checkpoint(
        v=version,
        ts=timestamp,
        id=checkpoint_id,
        channel_values=values,
        channel_versions=versions,
        versions_seen=seen,
        pending_sends=sends,
        updated_channels=updated,
    )
class CheckpointTuple(NamedTuple):
    """A tuple containing a checkpoint and its associated data.

    `parent_config` and `pending_writes` are optional and default to `None`.
    """

    # NOTE(review): presumably the config that identifies this checkpoint — confirm.
    config: RunnableConfig
    # The checkpoint itself.
    checkpoint: Checkpoint
    # Metadata associated with the checkpoint.
    metadata: CheckpointMetadata
    # NOTE(review): presumably the config of the parent checkpoint, when there
    # is one — confirm against the saver implementations.
    parent_config: RunnableConfig | None = None
    # Writes as `PendingWrite` 3-tuples, or `None` when not provided.
    pending_writes: list[PendingWrite] | None = None
class BaseCheckpointSaver(Generic[V]):
    """Base class for creating a graph checkpointer.

    Checkpointers allow LangGraph agents to persist their state
    within and across multiple interactions.

    When a checkpointer is configured, you should pass a `thread_id` in the
    config when invoking the graph:

    ```python
    config = {"configurable": {"thread_id": "my-thread"}}
    graph.invoke(inputs, config)
    ```

    The `thread_id` is the primary key used to store and retrieve checkpoints.
    Without it, the checkpointer cannot save state, resume from interrupts, or
    enable time-travel debugging.

    How you choose `thread_id` depends on your use case:

    - **Single-shot workflows**: Use a unique ID (e.g., uuid4) for each run when
      executions are independent.
    - **Conversational memory**: Reuse the same `thread_id` across invocations
      to accumulate state (e.g., chat history) within a conversation.

    Attributes:
        serde (SerializerProtocol): Serializer for encoding/decoding checkpoints.

    Note:
        When creating a custom checkpoint saver, consider implementing async
        versions to avoid blocking the main thread.
    """

    serde: SerializerProtocol = JsonPlusSerializer()

    def __init__(
        self,
        *,
        serde: SerializerProtocol | None = None,
    ) -> None:
        """Set up the serializer used by this saver.

        Args:
            serde: Serializer to use. When omitted (or falsy), the class-level
                default `serde` is used instead.
        """
        chosen = serde or self.serde
        self.serde = maybe_add_typed_methods(chosen)

    @property
    def config_specs(self) -> list:
        """Define the configuration options for the checkpoint saver.

        Returns:
            list: List of configuration field specs; empty in this base class.
        """
        return []

    def get(self, config: RunnableConfig) -> Checkpoint | None:
        """Fetch a checkpoint using the given configuration.

        Delegates to `get_tuple` and unwraps the checkpoint.

        Args:
            config: Configuration specifying which checkpoint to retrieve.

        Returns:
            The requested checkpoint, or `None` if not found.
        """
        found = self.get_tuple(config)
        if found:
            return found.checkpoint
        return None

    def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Fetch a checkpoint tuple using the given configuration.

        Args:
            config: Configuration specifying which checkpoint to retrieve.

        Returns:
            The requested checkpoint tuple, or `None` if not found.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError

    def list(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int | None = None,
    ) -> Iterator[CheckpointTuple]:
        """List checkpoints that match the given criteria.

        Args:
            config: Base configuration for filtering checkpoints.
            filter: Additional filtering criteria.
            before: List checkpoints created before this configuration.
            limit: Maximum number of checkpoints to return.

        Returns:
            Iterator of matching checkpoint tuples.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Store a checkpoint with its configuration and metadata.

        Args:
            config: Configuration for the checkpoint.
            checkpoint: The checkpoint to store.
            metadata: Additional metadata for the checkpoint.
            new_versions: New channel versions as of this write.

        Returns:
            RunnableConfig: Updated configuration after storing the checkpoint.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError

    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Store intermediate writes linked to a checkpoint.

        Args:
            config: Configuration of the related checkpoint.
            writes: List of writes to store.
            task_id: Identifier for the task creating the writes.
            task_path: Path of the task creating the writes.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError

    def delete_thread(
        self,
        thread_id: str,
    ) -> None:
        """Delete all checkpoints and writes associated with a specific thread ID.

        Args:
            thread_id: The thread ID whose checkpoints should be deleted.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError

    def delete_for_runs(
        self,
        run_ids: Sequence[str],
    ) -> None:
        """Delete all checkpoints and writes associated with the given run IDs.

        Args:
            run_ids: The run IDs whose checkpoints should be deleted.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError

    def copy_thread(
        self,
        source_thread_id: str,
        target_thread_id: str,
    ) -> None:
        """Copy all checkpoints and writes from one thread to another.

        Args:
            source_thread_id: The thread ID to copy from.
            target_thread_id: The thread ID to copy to.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError

    def prune(
        self,
        thread_ids: Sequence[str],
        *,
        strategy: str = "keep_latest",
    ) -> None:
        """Prune checkpoints for the given threads.

        Args:
            thread_ids: The thread IDs to prune.
            strategy: The pruning strategy. `"keep_latest"` retains only the most
                recent checkpoint per namespace. `"delete"` removes all checkpoints.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError

    async def aget(self, config: RunnableConfig) -> Checkpoint | None:
        """Asynchronously fetch a checkpoint using the given configuration.

        Delegates to `aget_tuple` and unwraps the checkpoint.

        Args:
            config: Configuration specifying which checkpoint to retrieve.

        Returns:
            The requested checkpoint, or `None` if not found.
        """
        found = await self.aget_tuple(config)
        if found:
            return found.checkpoint
        return None

    async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Asynchronously fetch a checkpoint tuple using the given configuration.

        Args:
            config: Configuration specifying which checkpoint to retrieve.

        Returns:
            The requested checkpoint tuple, or `None` if not found.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError

    async def alist(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int | None = None,
    ) -> AsyncIterator[CheckpointTuple]:
        """Asynchronously list checkpoints that match the given criteria.

        Args:
            config: Base configuration for filtering checkpoints.
            filter: Additional filtering criteria for metadata.
            before: List checkpoints created before this configuration.
            limit: Maximum number of checkpoints to return.

        Returns:
            Async iterator of matching checkpoint tuples.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError
        # Never reached: the bare `yield` only exists so that Python treats
        # this function as an async generator.
        yield

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Asynchronously store a checkpoint with its configuration and metadata.

        Args:
            config: Configuration for the checkpoint.
            checkpoint: The checkpoint to store.
            metadata: Additional metadata for the checkpoint.
            new_versions: New channel versions as of this write.

        Returns:
            RunnableConfig: Updated configuration after storing the checkpoint.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Asynchronously store intermediate writes linked to a checkpoint.

        Args:
            config: Configuration of the related checkpoint.
            writes: List of writes to store.
            task_id: Identifier for the task creating the writes.
            task_path: Path of the task creating the writes.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError

    async def adelete_thread(
        self,
        thread_id: str,
    ) -> None:
        """Asynchronously delete all checkpoints and writes for a specific thread ID.

        Args:
            thread_id: The thread ID whose checkpoints should be deleted.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError

    async def adelete_for_runs(
        self,
        run_ids: Sequence[str],
    ) -> None:
        """Asynchronously delete all checkpoints and writes for the given run IDs.

        Args:
            run_ids: The run IDs whose checkpoints should be deleted.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError

    async def acopy_thread(
        self,
        source_thread_id: str,
        target_thread_id: str,
    ) -> None:
        """Asynchronously copy all checkpoints and writes from one thread to another.

        Args:
            source_thread_id: The thread ID to copy from.
            target_thread_id: The thread ID to copy to.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError

    async def aprune(
        self,
        thread_ids: Sequence[str],
        *,
        strategy: str = "keep_latest",
    ) -> None:
        """Asynchronously prune checkpoints for the given threads.

        Args:
            thread_ids: The thread IDs to prune.
            strategy: The pruning strategy. `"keep_latest"` retains only the most
                recent checkpoint per namespace. `"delete"` removes all checkpoints.

        Raises:
            NotImplementedError: Implement this method in your custom checkpoint saver.
        """
        raise NotImplementedError

    def get_next_version(self, current: V | None, channel: None) -> V:
        """Generate the next version ID for a channel.

        Default is to use integer versions, incrementing by `1`.
        If you override, you can use `str`/`int`/`float` versions, as long as
        they are monotonically increasing.

        Args:
            current: The current version identifier (`int`, `float`, or `str`),
                or `None` when there is no version yet.
            channel: Deprecated argument, kept for backwards compatibility.

        Returns:
            V: The next version identifier, which must be increasing. `1` when
            `current` is `None`, otherwise `current + 1`.

        Raises:
            NotImplementedError: If `current` is a `str`; string versions need
                an overriding implementation.
        """
        if current is None:
            return 1
        if isinstance(current, str):
            raise NotImplementedError
        return current + 1

    def with_allowlist(
        self, extra_allowlist: Collection[tuple[str, ...]]
    ) -> BaseCheckpointSaver[V]:
        """Return a shallow clone with a derived msgpack allowlist.

        Args:
            extra_allowlist: Allowlist entries handed to
                `_with_msgpack_allowlist` together with the current serializer.

        Returns:
            `self` when the derived serializer is the very same object as the
            current one; otherwise a `copy.copy` of this saver whose `serde`
            is the derived serializer.
        """
        derived = _with_msgpack_allowlist(self.serde, extra_allowlist)
        if derived is not self.serde:
            duplicate = copy.copy(self)
            duplicate.serde = maybe_add_typed_methods(derived)
            return duplicate
        return self
def _with_msgpack_allowlist(
    serde: SerializerProtocol, extra_allowlist: Collection[tuple[str, ...]]
) -> SerializerProtocol:
    """Derive a serializer that carries `extra_allowlist` for msgpack.

    Args:
        serde: The serializer to derive from.
        extra_allowlist: Entries passed to
            `JsonPlusSerializer.with_msgpack_allowlist`.

    Returns:
        - For a `JsonPlusSerializer`: whatever its `with_msgpack_allowlist`
          returns.
        - For an `EncryptedSerializer` wrapping a `JsonPlusSerializer`: the
          same wrapper when the inner serializer came back unchanged,
          otherwise a new `EncryptedSerializer` with the same cipher around
          the updated inner serializer.
        - For anything else: `serde` itself, after logging a warning.
    """
    if isinstance(serde, JsonPlusSerializer):
        return serde.with_msgpack_allowlist(extra_allowlist)

    # Only look inside the wrapper when it really is an encrypted serializer.
    wrapped = serde.serde if isinstance(serde, EncryptedSerializer) else None
    if isinstance(wrapped, JsonPlusSerializer):
        replacement = wrapped.with_msgpack_allowlist(extra_allowlist)
        if replacement is wrapped:
            return serde
        return EncryptedSerializer(serde.cipher, replacement)

    logger.warning(
        "Serializer %s does not support msgpack allowlist. "
        "Strict msgpack deserialization will not be enforced.",
        type(serde).__name__,
    )
    return serde
class EmptyChannelError(Exception):
    """Signal that a channel's value was requested before its first update.

    Raised when attempting to get the value of a channel that hasn't been
    updated for the first time yet.
    """
def get_checkpoint_id(config: RunnableConfig) -> str | None:
    """Get the checkpoint ID stored in a config, if any.

    Args:
        config: Config to inspect. The ID is read from
            `config["configurable"]["checkpoint_id"]`.

    Returns:
        The checkpoint ID, or `None` when the config has no (or an empty)
        `configurable` section or that section has no `checkpoint_id`.
    """
    # A config without a `configurable` section simply has no checkpoint ID;
    # treat it like an empty section instead of raising `KeyError`.
    configurable = config.get("configurable") or {}
    return configurable.get("checkpoint_id")
def get_checkpoint_metadata(
    config: RunnableConfig, metadata: CheckpointMetadata
) -> CheckpointMetadata:
    """Get checkpoint metadata in a backwards-compatible manner.

    Starts from `metadata` and then adds entries from `config["metadata"]`
    and `config["configurable"]` (in that order). NUL characters are removed
    from every string value.

    An entry from the config is only added when its key is not already
    present, is not listed in `EXCLUDED_METADATA_KEYS`, does not start with
    `"__"`, and its value is a `str`, `int`, `bool` or `float`.

    Args:
        config: Config whose `metadata` and `configurable` sections are merged in.
        metadata: Explicit checkpoint metadata; its keys always win.

    Returns:
        A new metadata mapping; the `metadata` argument is not modified.
    """
    merged = {}
    for name, value in metadata.items():
        merged[name] = value.replace("\u0000", "") if isinstance(value, str) else value

    for section in (config.get("metadata"), config.get("configurable")):
        # A missing or empty section contributes nothing.
        for key, value in (section or {}).items():
            skip = (
                key in merged
                or key in EXCLUDED_METADATA_KEYS
                or key.startswith("__")
            )
            if skip:
                continue
            if isinstance(value, str):
                merged[key] = value.replace("\u0000", "")
            elif isinstance(value, (int, bool, float)):
                merged[key] = value
            # Values of any other type are dropped.
    return merged
def get_serializable_checkpoint_metadata(
    config: RunnableConfig, metadata: CheckpointMetadata
) -> CheckpointMetadata:
    """Get checkpoint metadata without the `"writes"` entry.

    Builds the metadata with `get_checkpoint_metadata` and then removes the
    `"writes"` key when it is present.

    Args:
        config: Config forwarded to `get_checkpoint_metadata`.
        metadata: Metadata forwarded to `get_checkpoint_metadata`.

    Returns:
        The merged metadata, guaranteed not to contain a `"writes"` key.
    """
    merged = get_checkpoint_metadata(config, metadata)
    merged.pop("writes", None)
    return merged
"""
Mapping from error type to error index.
Regular writes just map to their index in the list of writes being saved.
Special writes (e.g. errors) map to negative indices, to avoid those writes from
conflicting with regular writes.
Each Checkpointer implementation should use this mapping in put_writes.
"""
# Fixed negative indices for the special write names imported from
# `langgraph.checkpoint.serde.types`.
WRITES_IDX_MAP = {ERROR: -1, SCHEDULED: -2, INTERRUPT: -3, RESUME: -4}
# Keys that `get_checkpoint_metadata` never copies from the config's
# "metadata" / "configurable" sections into checkpoint metadata.
EXCLUDED_METADATA_KEYS = {
    "thread_id",
    "checkpoint_id",
    "checkpoint_ns",
    "checkpoint_map",
    "langgraph_step",
    "langgraph_node",
    "langgraph_triggers",
    "langgraph_path",
    "langgraph_checkpoint_ns",
}
# --- below are deprecated utilities used by past versions of LangGraph ---
# Format version written to the `v` field by `empty_checkpoint` and
# `create_checkpoint`.
LATEST_VERSION = 2
def empty_checkpoint() -> Checkpoint:
    """Create a checkpoint that holds no channel data.

    Returns:
        A checkpoint at `LATEST_VERSION` with a fresh `uuid6` ID (generated
        with `clock_seq=-2`), the current UTC time as its ISO 8601 timestamp,
        empty `channel_values`, `channel_versions`, `versions_seen` and
        `pending_sends`, and `updated_channels` set to `None`.
    """
    from datetime import datetime, timezone

    new_id = str(uuid6(clock_seq=-2))
    created_at = datetime.now(timezone.utc).isoformat()
    return Checkpoint(
        v=LATEST_VERSION,
        id=new_id,
        ts=created_at,
        channel_values={},
        channel_versions={},
        versions_seen={},
        pending_sends=[],
        updated_channels=None,
    )
def create_checkpoint(
    checkpoint: Checkpoint,
    channels: Mapping[str, ChannelProtocol] | None,
    step: int,
    *,
    id: str | None = None,
) -> Checkpoint:
    """Create a checkpoint for the given channels.

    Args:
        checkpoint: The checkpoint to derive from. Its `channel_versions`,
            `versions_seen` and `pending_sends` (empty list if absent) are
            carried over as-is, without copying.
        channels: Channels to snapshot. When `None`, the `channel_values` of
            `checkpoint` are reused. Otherwise only channels that have an
            entry in `checkpoint["channel_versions"]` are snapshotted, and
            channels raising `EmptyChannelError` are left out.
        step: Passed to `uuid6` as `clock_seq` when a new ID is generated.
        id: Explicit checkpoint ID; a new `uuid6` is generated when falsy.

    Returns:
        A new checkpoint at `LATEST_VERSION`, timestamped with the current
        UTC time and with `updated_channels` set to `None`.
    """
    from datetime import datetime, timezone

    created_at = datetime.now(timezone.utc).isoformat()

    if channels is not None:
        snapshot = {}
        for name, channel in channels.items():
            if name in checkpoint["channel_versions"]:
                try:
                    snapshot[name] = channel.checkpoint()
                except EmptyChannelError:
                    # The channel has no value yet, so there is nothing to store.
                    continue
    else:
        snapshot = checkpoint["channel_values"]

    checkpoint_id = id if id else str(uuid6(clock_seq=step))
    return Checkpoint(
        v=LATEST_VERSION,
        ts=created_at,
        id=checkpoint_id,
        channel_values=snapshot,
        channel_versions=checkpoint["channel_versions"],
        versions_seen=checkpoint["versions_seen"],
        pending_sends=checkpoint.get("pending_sends", []),
        updated_channels=None,
    )

View File

@@ -0,0 +1,109 @@
"""Adapted from
https://github.com/oittaa/uuid6-python/blob/main/src/uuid6/__init__.py#L95
Bundled in to avoid install issues with uuid6 package
"""
from __future__ import annotations
import random
import time
import uuid
# Timestamp (in 100-ns intervals since the UUID epoch) used by the most recent
# `uuid6()` call, or `None` before the first call. `uuid6()` reads and updates
# it so that each generated timestamp is greater than the previous one.
_last_v6_timestamp = None
class UUID(uuid.UUID):
    r"""UUID draft version objects (versions 6, 7 and 8)."""

    __slots__ = ()

    def __init__(
        self,
        hex: str | None = None,
        bytes: bytes | None = None,
        bytes_le: bytes | None = None,
        fields: tuple[int, int, int, int, int, int] | None = None,
        int: int | None = None,
        version: int | None = None,
        *,
        is_safe: uuid.SafeUUID = uuid.SafeUUID.unknown,
    ) -> None:
        r"""Create a UUID.

        When only `int` is supplied (optionally with `version`), the value is
        handled here so that versions 6 to 8 are accepted. Every other
        argument combination is forwarded unchanged to `uuid.UUID`.

        Raises:
            ValueError: If `int` is not a 128-bit value, or `version` is
                given and is not 6, 7 or 8.
        """
        other_sources_given = any(
            source is not None for source in (hex, bytes, bytes_le, fields)
        )
        if int is None or other_sources_given:
            super().__init__(
                hex=hex,
                bytes=bytes,
                bytes_le=bytes_le,
                fields=fields,
                int=int,
                version=version,
                is_safe=is_safe,
            )
            return

        if int < 0 or int >= 1 << 128:
            raise ValueError("int is out of range (need a 128-bit value)")
        if version is not None:
            if version < 6 or version > 8:
                raise ValueError("illegal version number")
            # Set the variant to RFC 4122.
            int = (int & ~(0xC000 << 48)) | (0x8000 << 48)
            # Set the version number.
            int = (int & ~(0xF000 << 64)) | (version << 76)
        super().__init__(int=int, is_safe=is_safe)

    @property
    def subsec(self) -> int:
        """Sub-second field: 12 bits taken at bit 64 followed by 8 bits at bit 54."""
        upper = (self.int >> 64) & 0x0FFF
        lower = (self.int >> 54) & 0xFF
        return (upper << 8) | lower

    @property
    def time(self) -> int:
        """Timestamp embedded in the UUID, decoded according to its version.

        Versions 6, 7 and 8 are decoded here; any other version falls back to
        `uuid.UUID.time`.
        """
        ver = self.version
        if ver == 6:
            high = self.time_low << 28
            middle = self.time_mid << 12
            low = self.time_hi_version & 0x0FFF
            return high | middle | low
        if ver == 7:
            return self.int >> 80
        if ver == 8:
            return (self.int >> 80) * 10**6 + _subsec_decode(self.subsec)
        return super().time
def _subsec_decode(value: int) -> int:
return -(-value * 10**6 // 2**20)
def uuid6(node: int | None = None, clock_seq: int | None = None) -> UUID:
    r"""UUID version 6 is a field-compatible version of UUIDv1, reordered for
    improved DB locality. It is expected that UUIDv6 will primarily be
    used in contexts where there are existing v1 UUIDs. Systems that do
    not involve legacy UUIDv1 SHOULD consider using UUIDv7 instead.

    If 'node' is not given, a random 48-bit number is chosen.

    If 'clock_seq' is given, it is used as the sequence number (only its low
    14 bits are kept); otherwise a random 14-bit sequence number is chosen.

    The timestamp is forced to be greater than the one used by the previous
    call, which is tracked in the module-level `_last_v6_timestamp`."""
    global _last_v6_timestamp

    # 0x01b21dd213814000 is the number of 100-ns intervals between the
    # UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00.
    epoch_offset = 0x01B21DD213814000
    ticks = time.time_ns() // 100 + epoch_offset

    previous = _last_v6_timestamp
    if previous is not None and ticks <= previous:
        # The clock did not advance (or went backwards): step past the last value.
        ticks = previous + 1
    _last_v6_timestamp = ticks

    if clock_seq is None:
        clock_seq = random.getrandbits(14)  # instead of stable storage
    if node is None:
        node = random.getrandbits(48)

    # Layout: 48 high timestamp bits | version nibble + 12 low timestamp bits
    # | 14-bit clock sequence | 48-bit node.
    time_high_and_mid = (ticks >> 12) & 0xFFFFFFFFFFFF
    time_low = ticks & 0x0FFF
    packed = (
        (time_high_and_mid << 80)
        | (time_low << 64)
        | ((clock_seq & 0x3FFF) << 48)
        | (node & 0xFFFFFFFFFFFF)
    )
    return UUID(int=packed, version=6)