initial commit
This commit is contained in:
628
venv/Lib/site-packages/langgraph/checkpoint/base/__init__.py
Normal file
628
venv/Lib/site-packages/langgraph/checkpoint/base/__init__.py
Normal file
@@ -0,0 +1,628 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from collections.abc import AsyncIterator, Collection, Iterator, Mapping, Sequence
|
||||
from typing import ( # noqa: UP035
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from langgraph.checkpoint.base.id import uuid6
|
||||
from langgraph.checkpoint.serde.base import SerializerProtocol, maybe_add_typed_methods
|
||||
from langgraph.checkpoint.serde.encrypted import EncryptedSerializer
|
||||
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
||||
from langgraph.checkpoint.serde.types import (
|
||||
ERROR,
|
||||
INTERRUPT,
|
||||
RESUME,
|
||||
SCHEDULED,
|
||||
ChannelProtocol,
|
||||
)
|
||||
|
||||
V = TypeVar("V", int, float, str)
|
||||
PendingWrite = tuple[str, str, Any]
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Marked as total=False to allow for future expansion.
|
||||
class CheckpointMetadata(TypedDict, total=False):
|
||||
"""Metadata associated with a checkpoint."""
|
||||
|
||||
source: Literal["input", "loop", "update", "fork"]
|
||||
"""The source of the checkpoint.
|
||||
|
||||
- `"input"`: The checkpoint was created from an input to invoke/stream/batch.
|
||||
- `"loop"`: The checkpoint was created from inside the pregel loop.
|
||||
- `"update"`: The checkpoint was created from a manual state update.
|
||||
- `"fork"`: The checkpoint was created as a copy of another checkpoint.
|
||||
"""
|
||||
step: int
|
||||
"""The step number of the checkpoint.
|
||||
|
||||
`-1` for the first `"input"` checkpoint.
|
||||
`0` for the first `"loop"` checkpoint.
|
||||
`...` for the `nth` checkpoint afterwards.
|
||||
"""
|
||||
parents: dict[str, str]
|
||||
"""The IDs of the parent checkpoints.
|
||||
|
||||
Mapping from checkpoint namespace to checkpoint ID.
|
||||
"""
|
||||
run_id: str
|
||||
"""The ID of the run that created this checkpoint."""
|
||||
|
||||
|
||||
ChannelVersions = dict[str, str | int | float]
|
||||
|
||||
|
||||
class Checkpoint(TypedDict):
|
||||
"""State snapshot at a given point in time."""
|
||||
|
||||
v: int
|
||||
"""The version of the checkpoint format. Currently `1`."""
|
||||
id: str
|
||||
"""The ID of the checkpoint.
|
||||
|
||||
This is both unique and monotonically increasing, so can be used for sorting
|
||||
checkpoints from first to last."""
|
||||
ts: str
|
||||
"""The timestamp of the checkpoint in ISO 8601 format."""
|
||||
channel_values: dict[str, Any]
|
||||
"""The values of the channels at the time of the checkpoint.
|
||||
|
||||
Mapping from channel name to deserialized channel snapshot value.
|
||||
"""
|
||||
channel_versions: ChannelVersions
|
||||
"""The versions of the channels at the time of the checkpoint.
|
||||
|
||||
The keys are channel names and the values are monotonically increasing
|
||||
version strings for each channel.
|
||||
"""
|
||||
versions_seen: dict[str, ChannelVersions]
|
||||
"""Map from node ID to map from channel name to version seen.
|
||||
|
||||
This keeps track of the versions of the channels that each node has seen.
|
||||
Used to determine which nodes to execute next.
|
||||
"""
|
||||
updated_channels: list[str] | None
|
||||
"""The channels that were updated in this checkpoint.
|
||||
"""
|
||||
|
||||
|
||||
def copy_checkpoint(checkpoint: Checkpoint) -> Checkpoint:
|
||||
return Checkpoint(
|
||||
v=checkpoint["v"],
|
||||
ts=checkpoint["ts"],
|
||||
id=checkpoint["id"],
|
||||
channel_values=checkpoint["channel_values"].copy(),
|
||||
channel_versions=checkpoint["channel_versions"].copy(),
|
||||
versions_seen={k: v.copy() for k, v in checkpoint["versions_seen"].items()},
|
||||
pending_sends=checkpoint.get("pending_sends", []).copy(),
|
||||
updated_channels=checkpoint.get("updated_channels", None),
|
||||
)
|
||||
|
||||
|
||||
class CheckpointTuple(NamedTuple):
|
||||
"""A tuple containing a checkpoint and its associated data."""
|
||||
|
||||
config: RunnableConfig
|
||||
checkpoint: Checkpoint
|
||||
metadata: CheckpointMetadata
|
||||
parent_config: RunnableConfig | None = None
|
||||
pending_writes: list[PendingWrite] | None = None
|
||||
|
||||
|
||||
class BaseCheckpointSaver(Generic[V]):
|
||||
"""Base class for creating a graph checkpointer.
|
||||
|
||||
Checkpointers allow LangGraph agents to persist their state
|
||||
within and across multiple interactions.
|
||||
|
||||
When a checkpointer is configured, you should pass a `thread_id` in the config when
|
||||
invoking the graph:
|
||||
|
||||
```python
|
||||
config = {"configurable": {"thread_id": "my-thread"}}
|
||||
graph.invoke(inputs, config)
|
||||
```
|
||||
|
||||
The `thread_id` is the primary key used to store and retrieve checkpoints. Without
|
||||
it, the checkpointer cannot save state, resume from interrupts, or enable
|
||||
time-travel debugging.
|
||||
|
||||
How you choose ``thread_id`` depends on your use case:
|
||||
|
||||
- **Single-shot workflows**: Use a unique ID (e.g., uuid4) for each run when
|
||||
executions are independent.
|
||||
- **Conversational memory**: Reuse the same `thread_id` across invocations
|
||||
to accumulate state (e.g., chat history) within a conversation.
|
||||
|
||||
Attributes:
|
||||
serde (SerializerProtocol): Serializer for encoding/decoding checkpoints.
|
||||
|
||||
Note:
|
||||
When creating a custom checkpoint saver, consider implementing async
|
||||
versions to avoid blocking the main thread.
|
||||
"""
|
||||
|
||||
serde: SerializerProtocol = JsonPlusSerializer()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
serde: SerializerProtocol | None = None,
|
||||
) -> None:
|
||||
self.serde = maybe_add_typed_methods(serde or self.serde)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> list:
|
||||
"""Define the configuration options for the checkpoint saver.
|
||||
|
||||
Returns:
|
||||
list: List of configuration field specs.
|
||||
"""
|
||||
return []
|
||||
|
||||
def get(self, config: RunnableConfig) -> Checkpoint | None:
|
||||
"""Fetch a checkpoint using the given configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration specifying which checkpoint to retrieve.
|
||||
|
||||
Returns:
|
||||
The requested checkpoint, or `None` if not found.
|
||||
"""
|
||||
if value := self.get_tuple(config):
|
||||
return value.checkpoint
|
||||
|
||||
def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
|
||||
"""Fetch a checkpoint tuple using the given configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration specifying which checkpoint to retrieve.
|
||||
|
||||
Returns:
|
||||
The requested checkpoint tuple, or `None` if not found.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def list(
|
||||
self,
|
||||
config: RunnableConfig | None,
|
||||
*,
|
||||
filter: dict[str, Any] | None = None,
|
||||
before: RunnableConfig | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[CheckpointTuple]:
|
||||
"""List checkpoints that match the given criteria.
|
||||
|
||||
Args:
|
||||
config: Base configuration for filtering checkpoints.
|
||||
filter: Additional filtering criteria.
|
||||
before: List checkpoints created before this configuration.
|
||||
limit: Maximum number of checkpoints to return.
|
||||
|
||||
Returns:
|
||||
Iterator of matching checkpoint tuples.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def put(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
checkpoint: Checkpoint,
|
||||
metadata: CheckpointMetadata,
|
||||
new_versions: ChannelVersions,
|
||||
) -> RunnableConfig:
|
||||
"""Store a checkpoint with its configuration and metadata.
|
||||
|
||||
Args:
|
||||
config: Configuration for the checkpoint.
|
||||
checkpoint: The checkpoint to store.
|
||||
metadata: Additional metadata for the checkpoint.
|
||||
new_versions: New channel versions as of this write.
|
||||
|
||||
Returns:
|
||||
RunnableConfig: Updated configuration after storing the checkpoint.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def put_writes(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
writes: Sequence[tuple[str, Any]],
|
||||
task_id: str,
|
||||
task_path: str = "",
|
||||
) -> None:
|
||||
"""Store intermediate writes linked to a checkpoint.
|
||||
|
||||
Args:
|
||||
config: Configuration of the related checkpoint.
|
||||
writes: List of writes to store.
|
||||
task_id: Identifier for the task creating the writes.
|
||||
task_path: Path of the task creating the writes.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def delete_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
) -> None:
|
||||
"""Delete all checkpoints and writes associated with a specific thread ID.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID whose checkpoints should be deleted.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def delete_for_runs(
|
||||
self,
|
||||
run_ids: Sequence[str],
|
||||
) -> None:
|
||||
"""Delete all checkpoints and writes associated with the given run IDs.
|
||||
|
||||
Args:
|
||||
run_ids: The run IDs whose checkpoints should be deleted.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def copy_thread(
|
||||
self,
|
||||
source_thread_id: str,
|
||||
target_thread_id: str,
|
||||
) -> None:
|
||||
"""Copy all checkpoints and writes from one thread to another.
|
||||
|
||||
Args:
|
||||
source_thread_id: The thread ID to copy from.
|
||||
target_thread_id: The thread ID to copy to.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def prune(
|
||||
self,
|
||||
thread_ids: Sequence[str],
|
||||
*,
|
||||
strategy: str = "keep_latest",
|
||||
) -> None:
|
||||
"""Prune checkpoints for the given threads.
|
||||
|
||||
Args:
|
||||
thread_ids: The thread IDs to prune.
|
||||
strategy: The pruning strategy. `"keep_latest"` retains only the most
|
||||
recent checkpoint per namespace. `"delete"` removes all checkpoints.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def aget(self, config: RunnableConfig) -> Checkpoint | None:
|
||||
"""Asynchronously fetch a checkpoint using the given configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration specifying which checkpoint to retrieve.
|
||||
|
||||
Returns:
|
||||
The requested checkpoint, or `None` if not found.
|
||||
"""
|
||||
if value := await self.aget_tuple(config):
|
||||
return value.checkpoint
|
||||
|
||||
async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
|
||||
"""Asynchronously fetch a checkpoint tuple using the given configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration specifying which checkpoint to retrieve.
|
||||
|
||||
Returns:
|
||||
The requested checkpoint tuple, or `None` if not found.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def alist(
|
||||
self,
|
||||
config: RunnableConfig | None,
|
||||
*,
|
||||
filter: dict[str, Any] | None = None,
|
||||
before: RunnableConfig | None = None,
|
||||
limit: int | None = None,
|
||||
) -> AsyncIterator[CheckpointTuple]:
|
||||
"""Asynchronously list checkpoints that match the given criteria.
|
||||
|
||||
Args:
|
||||
config: Base configuration for filtering checkpoints.
|
||||
filter: Additional filtering criteria for metadata.
|
||||
before: List checkpoints created before this configuration.
|
||||
limit: Maximum number of checkpoints to return.
|
||||
|
||||
Returns:
|
||||
Async iterator of matching checkpoint tuples.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
yield
|
||||
|
||||
async def aput(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
checkpoint: Checkpoint,
|
||||
metadata: CheckpointMetadata,
|
||||
new_versions: ChannelVersions,
|
||||
) -> RunnableConfig:
|
||||
"""Asynchronously store a checkpoint with its configuration and metadata.
|
||||
|
||||
Args:
|
||||
config: Configuration for the checkpoint.
|
||||
checkpoint: The checkpoint to store.
|
||||
metadata: Additional metadata for the checkpoint.
|
||||
new_versions: New channel versions as of this write.
|
||||
|
||||
Returns:
|
||||
RunnableConfig: Updated configuration after storing the checkpoint.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def aput_writes(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
writes: Sequence[tuple[str, Any]],
|
||||
task_id: str,
|
||||
task_path: str = "",
|
||||
) -> None:
|
||||
"""Asynchronously store intermediate writes linked to a checkpoint.
|
||||
|
||||
Args:
|
||||
config: Configuration of the related checkpoint.
|
||||
writes: List of writes to store.
|
||||
task_id: Identifier for the task creating the writes.
|
||||
task_path: Path of the task creating the writes.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def adelete_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
) -> None:
|
||||
"""Delete all checkpoints and writes associated with a specific thread ID.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID whose checkpoints should be deleted.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def adelete_for_runs(
|
||||
self,
|
||||
run_ids: Sequence[str],
|
||||
) -> None:
|
||||
"""Asynchronously delete all checkpoints and writes for the given run IDs.
|
||||
|
||||
Args:
|
||||
run_ids: The run IDs whose checkpoints should be deleted.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def acopy_thread(
|
||||
self,
|
||||
source_thread_id: str,
|
||||
target_thread_id: str,
|
||||
) -> None:
|
||||
"""Asynchronously copy all checkpoints and writes from one thread to another.
|
||||
|
||||
Args:
|
||||
source_thread_id: The thread ID to copy from.
|
||||
target_thread_id: The thread ID to copy to.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def aprune(
|
||||
self,
|
||||
thread_ids: Sequence[str],
|
||||
*,
|
||||
strategy: str = "keep_latest",
|
||||
) -> None:
|
||||
"""Asynchronously prune checkpoints for the given threads.
|
||||
|
||||
Args:
|
||||
thread_ids: The thread IDs to prune.
|
||||
strategy: The pruning strategy. `"keep_latest"` retains only the most
|
||||
recent checkpoint per namespace. `"delete"` removes all checkpoints.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_next_version(self, current: V | None, channel: None) -> V:
|
||||
"""Generate the next version ID for a channel.
|
||||
|
||||
Default is to use integer versions, incrementing by `1`.
|
||||
|
||||
If you override, you can use `str`/`int`/`float` versions, as long as they are monotonically increasing.
|
||||
|
||||
Args:
|
||||
current: The current version identifier (`int`, `float`, or `str`).
|
||||
channel: Deprecated argument, kept for backwards compatibility.
|
||||
|
||||
Returns:
|
||||
V: The next version identifier, which must be increasing.
|
||||
"""
|
||||
if isinstance(current, str):
|
||||
raise NotImplementedError
|
||||
elif current is None:
|
||||
return 1
|
||||
else:
|
||||
return current + 1
|
||||
|
||||
def with_allowlist(
|
||||
self, extra_allowlist: Collection[tuple[str, ...]]
|
||||
) -> BaseCheckpointSaver[V]:
|
||||
"""Return a shallow clone with a derived msgpack allowlist."""
|
||||
serde = _with_msgpack_allowlist(self.serde, extra_allowlist)
|
||||
if serde is self.serde:
|
||||
return self
|
||||
clone = copy.copy(self)
|
||||
clone.serde = maybe_add_typed_methods(serde)
|
||||
return clone
|
||||
|
||||
|
||||
def _with_msgpack_allowlist(
|
||||
serde: SerializerProtocol, extra_allowlist: Collection[tuple[str, ...]]
|
||||
) -> SerializerProtocol:
|
||||
if isinstance(serde, JsonPlusSerializer):
|
||||
return serde.with_msgpack_allowlist(extra_allowlist)
|
||||
if isinstance(serde, EncryptedSerializer):
|
||||
inner = serde.serde
|
||||
if isinstance(inner, JsonPlusSerializer):
|
||||
updated_inner = inner.with_msgpack_allowlist(extra_allowlist)
|
||||
if updated_inner is inner:
|
||||
return serde
|
||||
return EncryptedSerializer(serde.cipher, updated_inner)
|
||||
logger.warning(
|
||||
"Serializer %s does not support msgpack allowlist. "
|
||||
"Strict msgpack deserialization will not be enforced.",
|
||||
type(serde).__name__,
|
||||
)
|
||||
return serde
|
||||
|
||||
|
||||
class EmptyChannelError(Exception):
|
||||
"""Raised when attempting to get the value of a channel that hasn't been updated
|
||||
for the first time yet."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def get_checkpoint_id(config: RunnableConfig) -> str | None:
|
||||
"""Get checkpoint ID."""
|
||||
return config["configurable"].get("checkpoint_id")
|
||||
|
||||
|
||||
def get_checkpoint_metadata(
|
||||
config: RunnableConfig, metadata: CheckpointMetadata
|
||||
) -> CheckpointMetadata:
|
||||
"""Get checkpoint metadata in a backwards-compatible manner."""
|
||||
metadata = {
|
||||
k: v.replace("\u0000", "") if isinstance(v, str) else v
|
||||
for k, v in metadata.items()
|
||||
}
|
||||
for obj in (config.get("metadata"), config.get("configurable")):
|
||||
if not obj:
|
||||
continue
|
||||
for key, v in obj.items():
|
||||
if key in metadata or key in EXCLUDED_METADATA_KEYS or key.startswith("__"):
|
||||
continue
|
||||
elif isinstance(v, str):
|
||||
metadata[key] = v.replace("\u0000", "")
|
||||
elif isinstance(v, (int, bool, float)):
|
||||
metadata[key] = v
|
||||
return metadata
|
||||
|
||||
|
||||
def get_serializable_checkpoint_metadata(
|
||||
config: RunnableConfig, metadata: CheckpointMetadata
|
||||
) -> CheckpointMetadata:
|
||||
"""Get checkpoint metadata in a backwards-compatible manner."""
|
||||
checkpoint_metadata = get_checkpoint_metadata(config, metadata)
|
||||
if "writes" in checkpoint_metadata:
|
||||
checkpoint_metadata.pop("writes")
|
||||
return checkpoint_metadata
|
||||
|
||||
|
||||
"""
|
||||
Mapping from error type to error index.
|
||||
Regular writes just map to their index in the list of writes being saved.
|
||||
Special writes (e.g. errors) map to negative indices, to avoid those writes from
|
||||
conflicting with regular writes.
|
||||
Each Checkpointer implementation should use this mapping in put_writes.
|
||||
"""
|
||||
WRITES_IDX_MAP = {ERROR: -1, SCHEDULED: -2, INTERRUPT: -3, RESUME: -4}
|
||||
|
||||
EXCLUDED_METADATA_KEYS = {
|
||||
"thread_id",
|
||||
"checkpoint_id",
|
||||
"checkpoint_ns",
|
||||
"checkpoint_map",
|
||||
"langgraph_step",
|
||||
"langgraph_node",
|
||||
"langgraph_triggers",
|
||||
"langgraph_path",
|
||||
"langgraph_checkpoint_ns",
|
||||
}
|
||||
|
||||
# --- below are deprecated utilities used by past versions of LangGraph ---
|
||||
|
||||
LATEST_VERSION = 2
|
||||
|
||||
|
||||
def empty_checkpoint() -> Checkpoint:
|
||||
from datetime import datetime, timezone
|
||||
|
||||
return Checkpoint(
|
||||
v=LATEST_VERSION,
|
||||
id=str(uuid6(clock_seq=-2)),
|
||||
ts=datetime.now(timezone.utc).isoformat(),
|
||||
channel_values={},
|
||||
channel_versions={},
|
||||
versions_seen={},
|
||||
pending_sends=[],
|
||||
updated_channels=None,
|
||||
)
|
||||
|
||||
|
||||
def create_checkpoint(
|
||||
checkpoint: Checkpoint,
|
||||
channels: Mapping[str, ChannelProtocol] | None,
|
||||
step: int,
|
||||
*,
|
||||
id: str | None = None,
|
||||
) -> Checkpoint:
|
||||
"""Create a checkpoint for the given channels."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
ts = datetime.now(timezone.utc).isoformat()
|
||||
if channels is None:
|
||||
values = checkpoint["channel_values"]
|
||||
else:
|
||||
values = {}
|
||||
for k, v in channels.items():
|
||||
if k not in checkpoint["channel_versions"]:
|
||||
continue
|
||||
try:
|
||||
values[k] = v.checkpoint()
|
||||
except EmptyChannelError:
|
||||
pass
|
||||
return Checkpoint(
|
||||
v=LATEST_VERSION,
|
||||
ts=ts,
|
||||
id=id or str(uuid6(clock_seq=step)),
|
||||
channel_values=values,
|
||||
channel_versions=checkpoint["channel_versions"],
|
||||
versions_seen=checkpoint["versions_seen"],
|
||||
pending_sends=checkpoint.get("pending_sends", []),
|
||||
updated_channels=None,
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
109
venv/Lib/site-packages/langgraph/checkpoint/base/id.py
Normal file
109
venv/Lib/site-packages/langgraph/checkpoint/base/id.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Adapted from
|
||||
https://github.com/oittaa/uuid6-python/blob/main/src/uuid6/__init__.py#L95
|
||||
Bundled in to avoid install issues with uuid6 package
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
|
||||
_last_v6_timestamp = None
|
||||
|
||||
|
||||
class UUID(uuid.UUID):
|
||||
r"""UUID draft version objects"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hex: str | None = None,
|
||||
bytes: bytes | None = None,
|
||||
bytes_le: bytes | None = None,
|
||||
fields: tuple[int, int, int, int, int, int] | None = None,
|
||||
int: int | None = None,
|
||||
version: int | None = None,
|
||||
*,
|
||||
is_safe: uuid.SafeUUID = uuid.SafeUUID.unknown,
|
||||
) -> None:
|
||||
r"""Create a UUID."""
|
||||
|
||||
if int is None or [hex, bytes, bytes_le, fields].count(None) != 4:
|
||||
return super().__init__(
|
||||
hex=hex,
|
||||
bytes=bytes,
|
||||
bytes_le=bytes_le,
|
||||
fields=fields,
|
||||
int=int,
|
||||
version=version,
|
||||
is_safe=is_safe,
|
||||
)
|
||||
if not 0 <= int < 1 << 128:
|
||||
raise ValueError("int is out of range (need a 128-bit value)")
|
||||
if version is not None:
|
||||
if not 6 <= version <= 8:
|
||||
raise ValueError("illegal version number")
|
||||
# Set the variant to RFC 4122.
|
||||
int &= ~(0xC000 << 48)
|
||||
int |= 0x8000 << 48
|
||||
# Set the version number.
|
||||
int &= ~(0xF000 << 64)
|
||||
int |= version << 76
|
||||
super().__init__(int=int, is_safe=is_safe)
|
||||
|
||||
@property
|
||||
def subsec(self) -> int:
|
||||
return ((self.int >> 64) & 0x0FFF) << 8 | ((self.int >> 54) & 0xFF)
|
||||
|
||||
@property
|
||||
def time(self) -> int:
|
||||
if self.version == 6:
|
||||
return (
|
||||
(self.time_low << 28)
|
||||
| (self.time_mid << 12)
|
||||
| (self.time_hi_version & 0x0FFF)
|
||||
)
|
||||
if self.version == 7:
|
||||
return self.int >> 80
|
||||
if self.version == 8:
|
||||
return (self.int >> 80) * 10**6 + _subsec_decode(self.subsec)
|
||||
return super().time
|
||||
|
||||
|
||||
def _subsec_decode(value: int) -> int:
|
||||
return -(-value * 10**6 // 2**20)
|
||||
|
||||
|
||||
def uuid6(node: int | None = None, clock_seq: int | None = None) -> UUID:
|
||||
r"""UUID version 6 is a field-compatible version of UUIDv1, reordered for
|
||||
improved DB locality. It is expected that UUIDv6 will primarily be
|
||||
used in contexts where there are existing v1 UUIDs. Systems that do
|
||||
not involve legacy UUIDv1 SHOULD consider using UUIDv7 instead.
|
||||
|
||||
If 'node' is not given, a random 48-bit number is chosen.
|
||||
|
||||
If 'clock_seq' is given, it is used as the sequence number;
|
||||
otherwise a random 14-bit sequence number is chosen."""
|
||||
|
||||
global _last_v6_timestamp
|
||||
|
||||
nanoseconds = time.time_ns()
|
||||
# 0x01b21dd213814000 is the number of 100-ns intervals between the
|
||||
# UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00.
|
||||
timestamp = nanoseconds // 100 + 0x01B21DD213814000
|
||||
if _last_v6_timestamp is not None and timestamp <= _last_v6_timestamp:
|
||||
timestamp = _last_v6_timestamp + 1
|
||||
_last_v6_timestamp = timestamp
|
||||
if clock_seq is None:
|
||||
clock_seq = random.getrandbits(14) # instead of stable storage
|
||||
if node is None:
|
||||
node = random.getrandbits(48)
|
||||
time_high_and_time_mid = (timestamp >> 12) & 0xFFFFFFFFFFFF
|
||||
time_low_and_version = timestamp & 0x0FFF
|
||||
uuid_int = time_high_and_time_mid << 80
|
||||
uuid_int |= time_low_and_version << 64
|
||||
uuid_int |= (clock_seq & 0x3FFF) << 48
|
||||
uuid_int |= node & 0xFFFFFFFFFFFF
|
||||
return UUID(int=uuid_int, version=6)
|
||||
603
venv/Lib/site-packages/langgraph/checkpoint/memory/__init__.py
Normal file
603
venv/Lib/site-packages/langgraph/checkpoint/memory/__init__.py
Normal file
@@ -0,0 +1,603 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from contextlib import AbstractAsyncContextManager, AbstractContextManager, ExitStack
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from langgraph.checkpoint.base import (
|
||||
WRITES_IDX_MAP,
|
||||
BaseCheckpointSaver,
|
||||
ChannelVersions,
|
||||
Checkpoint,
|
||||
CheckpointMetadata,
|
||||
CheckpointTuple,
|
||||
SerializerProtocol,
|
||||
get_checkpoint_id,
|
||||
get_checkpoint_metadata,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InMemorySaver(
|
||||
BaseCheckpointSaver[str], AbstractContextManager, AbstractAsyncContextManager
|
||||
):
|
||||
"""An in-memory checkpoint saver.
|
||||
|
||||
This checkpoint saver stores checkpoints in memory using a `defaultdict`.
|
||||
|
||||
Note:
|
||||
Only use `InMemorySaver` for debugging or testing purposes.
|
||||
For production use cases we recommend installing [langgraph-checkpoint-postgres](https://pypi.org/project/langgraph-checkpoint-postgres/) and using `PostgresSaver` / `AsyncPostgresSaver`.
|
||||
|
||||
If you are using LangSmith Deployment, no checkpointer needs to be specified. The correct managed checkpointer will be used automatically.
|
||||
|
||||
Args:
|
||||
serde: The serializer to use for serializing and deserializing checkpoints.
|
||||
|
||||
Example:
|
||||
```python
|
||||
import asyncio
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
builder = StateGraph(int)
|
||||
builder.add_node("add_one", lambda x: x + 1)
|
||||
builder.set_entry_point("add_one")
|
||||
builder.set_finish_point("add_one")
|
||||
|
||||
memory = InMemorySaver()
|
||||
graph = builder.compile(checkpointer=memory)
|
||||
coro = graph.ainvoke(1, {"configurable": {"thread_id": "thread-1"}})
|
||||
asyncio.run(coro) # Output: 2
|
||||
```
|
||||
"""
|
||||
|
||||
# thread ID -> checkpoint NS -> checkpoint ID -> checkpoint mapping
|
||||
storage: defaultdict[
|
||||
str,
|
||||
dict[str, dict[str, tuple[tuple[str, bytes], tuple[str, bytes], str | None]]],
|
||||
]
|
||||
# (thread ID, checkpoint NS, checkpoint ID) -> (task ID, write idx)
|
||||
writes: defaultdict[
|
||||
tuple[str, str, str],
|
||||
dict[tuple[str, int], tuple[str, str, tuple[str, bytes], str]],
|
||||
]
|
||||
blobs: dict[
|
||||
tuple[
|
||||
str, str, str, str | int | float
|
||||
], # thread id, checkpoint ns, channel, version
|
||||
tuple[str, bytes],
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
serde: SerializerProtocol | None = None,
|
||||
factory: type[defaultdict] = defaultdict,
|
||||
) -> None:
|
||||
super().__init__(serde=serde)
|
||||
self.storage = factory(lambda: defaultdict(dict))
|
||||
self.writes = factory(dict)
|
||||
self.blobs = factory()
|
||||
self.stack = ExitStack()
|
||||
if factory is not defaultdict:
|
||||
self.stack.enter_context(self.storage) # type: ignore[arg-type]
|
||||
self.stack.enter_context(self.writes) # type: ignore[arg-type]
|
||||
self.stack.enter_context(self.blobs) # type: ignore[arg-type]
|
||||
|
||||
def __enter__(self) -> InMemorySaver:
|
||||
self.stack.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> bool | None:
|
||||
return self.stack.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
async def __aenter__(self) -> InMemorySaver:
|
||||
self.stack.__enter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
__exc_type: type[BaseException] | None,
|
||||
__exc_value: BaseException | None,
|
||||
__traceback: TracebackType | None,
|
||||
) -> bool | None:
|
||||
return self.stack.__exit__(__exc_type, __exc_value, __traceback)
|
||||
|
||||
def _load_blobs(
|
||||
self, thread_id: str, checkpoint_ns: str, versions: ChannelVersions
|
||||
) -> dict[str, Any]:
|
||||
channel_values: dict[str, Any] = {}
|
||||
for k, v in versions.items():
|
||||
kk = (thread_id, checkpoint_ns, k, v)
|
||||
if kk in self.blobs:
|
||||
vv = self.blobs[kk]
|
||||
if vv[0] != "empty":
|
||||
channel_values[k] = self.serde.loads_typed(vv)
|
||||
return channel_values
|
||||
|
||||
def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
|
||||
"""Get a checkpoint tuple from the in-memory storage.
|
||||
|
||||
This method retrieves a checkpoint tuple from the in-memory storage based on the
|
||||
provided config. If the config contains a `checkpoint_id` key, the checkpoint with
|
||||
the matching thread ID and timestamp is retrieved. Otherwise, the latest checkpoint
|
||||
for the given thread ID is retrieved.
|
||||
|
||||
Args:
|
||||
config: The config to use for retrieving the checkpoint.
|
||||
|
||||
Returns:
|
||||
The retrieved checkpoint tuple, or None if no matching checkpoint was found.
|
||||
"""
|
||||
thread_id: str = config["configurable"]["thread_id"]
|
||||
checkpoint_ns: str = config["configurable"].get("checkpoint_ns", "")
|
||||
if checkpoint_id := get_checkpoint_id(config):
|
||||
if saved := self.storage[thread_id][checkpoint_ns].get(checkpoint_id):
|
||||
checkpoint, metadata, parent_checkpoint_id = saved
|
||||
writes = self.writes[(thread_id, checkpoint_ns, checkpoint_id)].values()
|
||||
checkpoint_: Checkpoint = self.serde.loads_typed(checkpoint)
|
||||
return CheckpointTuple(
|
||||
config=config,
|
||||
checkpoint={
|
||||
**checkpoint_,
|
||||
"channel_values": self._load_blobs(
|
||||
thread_id, checkpoint_ns, checkpoint_["channel_versions"]
|
||||
),
|
||||
},
|
||||
metadata=self.serde.loads_typed(metadata),
|
||||
pending_writes=[
|
||||
(id, c, self.serde.loads_typed(v)) for id, c, v, _ in writes
|
||||
],
|
||||
parent_config=(
|
||||
{
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": checkpoint_ns,
|
||||
"checkpoint_id": parent_checkpoint_id,
|
||||
}
|
||||
}
|
||||
if parent_checkpoint_id
|
||||
else None
|
||||
),
|
||||
)
|
||||
else:
|
||||
if checkpoints := self.storage[thread_id][checkpoint_ns]:
|
||||
checkpoint_id = max(checkpoints.keys())
|
||||
checkpoint, metadata, parent_checkpoint_id = checkpoints[checkpoint_id]
|
||||
writes = self.writes[(thread_id, checkpoint_ns, checkpoint_id)].values()
|
||||
checkpoint_ = self.serde.loads_typed(checkpoint)
|
||||
return CheckpointTuple(
|
||||
config={
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": checkpoint_ns,
|
||||
"checkpoint_id": checkpoint_id,
|
||||
}
|
||||
},
|
||||
checkpoint={
|
||||
**checkpoint_,
|
||||
"channel_values": self._load_blobs(
|
||||
thread_id, checkpoint_ns, checkpoint_["channel_versions"]
|
||||
),
|
||||
},
|
||||
metadata=self.serde.loads_typed(metadata),
|
||||
pending_writes=[
|
||||
(id, c, self.serde.loads_typed(v)) for id, c, v, _ in writes
|
||||
],
|
||||
parent_config=(
|
||||
{
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": checkpoint_ns,
|
||||
"checkpoint_id": parent_checkpoint_id,
|
||||
}
|
||||
}
|
||||
if parent_checkpoint_id
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
def list(
|
||||
self,
|
||||
config: RunnableConfig | None,
|
||||
*,
|
||||
filter: dict[str, Any] | None = None,
|
||||
before: RunnableConfig | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[CheckpointTuple]:
|
||||
"""List checkpoints from the in-memory storage.
|
||||
|
||||
This method retrieves a list of checkpoint tuples from the in-memory storage based
|
||||
on the provided criteria.
|
||||
|
||||
Args:
|
||||
config: Base configuration for filtering checkpoints.
|
||||
filter: Additional filtering criteria for metadata.
|
||||
before: List checkpoints created before this configuration.
|
||||
limit: Maximum number of checkpoints to return.
|
||||
|
||||
Yields:
|
||||
An iterator of matching checkpoint tuples.
|
||||
"""
|
||||
thread_ids = (config["configurable"]["thread_id"],) if config else self.storage
|
||||
config_checkpoint_ns = (
|
||||
config["configurable"].get("checkpoint_ns") if config else None
|
||||
)
|
||||
config_checkpoint_id = get_checkpoint_id(config) if config else None
|
||||
for thread_id in thread_ids:
|
||||
for checkpoint_ns in self.storage[thread_id].keys():
|
||||
if (
|
||||
config_checkpoint_ns is not None
|
||||
and checkpoint_ns != config_checkpoint_ns
|
||||
):
|
||||
continue
|
||||
|
||||
for checkpoint_id, (
|
||||
checkpoint,
|
||||
metadata_b,
|
||||
parent_checkpoint_id,
|
||||
) in sorted(
|
||||
self.storage[thread_id][checkpoint_ns].items(),
|
||||
key=lambda x: x[0],
|
||||
reverse=True,
|
||||
):
|
||||
# filter by checkpoint ID from config
|
||||
if config_checkpoint_id and checkpoint_id != config_checkpoint_id:
|
||||
continue
|
||||
|
||||
# filter by checkpoint ID from `before` config
|
||||
if (
|
||||
before
|
||||
and (before_checkpoint_id := get_checkpoint_id(before))
|
||||
and checkpoint_id >= before_checkpoint_id
|
||||
):
|
||||
continue
|
||||
|
||||
# filter by metadata
|
||||
metadata = self.serde.loads_typed(metadata_b)
|
||||
if filter and not all(
|
||||
query_value == metadata.get(query_key)
|
||||
for query_key, query_value in filter.items()
|
||||
):
|
||||
continue
|
||||
|
||||
# limit search results
|
||||
if limit is not None and limit <= 0:
|
||||
break
|
||||
elif limit is not None:
|
||||
limit -= 1
|
||||
|
||||
writes = self.writes[
|
||||
(thread_id, checkpoint_ns, checkpoint_id)
|
||||
].values()
|
||||
|
||||
checkpoint_: Checkpoint = self.serde.loads_typed(checkpoint)
|
||||
|
||||
yield CheckpointTuple(
|
||||
config={
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": checkpoint_ns,
|
||||
"checkpoint_id": checkpoint_id,
|
||||
}
|
||||
},
|
||||
checkpoint={
|
||||
**checkpoint_,
|
||||
"channel_values": self._load_blobs(
|
||||
thread_id,
|
||||
checkpoint_ns,
|
||||
checkpoint_["channel_versions"],
|
||||
),
|
||||
},
|
||||
metadata=metadata,
|
||||
parent_config=(
|
||||
{
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": checkpoint_ns,
|
||||
"checkpoint_id": parent_checkpoint_id,
|
||||
}
|
||||
}
|
||||
if parent_checkpoint_id
|
||||
else None
|
||||
),
|
||||
pending_writes=[
|
||||
(id, c, self.serde.loads_typed(v)) for id, c, v, _ in writes
|
||||
],
|
||||
)
|
||||
|
||||
def put(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
checkpoint: Checkpoint,
|
||||
metadata: CheckpointMetadata,
|
||||
new_versions: ChannelVersions,
|
||||
) -> RunnableConfig:
|
||||
"""Save a checkpoint to the in-memory storage.
|
||||
|
||||
This method saves a checkpoint to the in-memory storage. The checkpoint is associated
|
||||
with the provided config.
|
||||
|
||||
Args:
|
||||
config: The config to associate with the checkpoint.
|
||||
checkpoint: The checkpoint to save.
|
||||
metadata: Additional metadata to save with the checkpoint.
|
||||
new_versions: New versions as of this write
|
||||
|
||||
Returns:
|
||||
RunnableConfig: The updated config containing the saved checkpoint's timestamp.
|
||||
"""
|
||||
c = checkpoint.copy()
|
||||
thread_id = config["configurable"]["thread_id"]
|
||||
checkpoint_ns = config["configurable"]["checkpoint_ns"]
|
||||
values: dict[str, Any] = c.pop("channel_values") # type: ignore[misc]
|
||||
for k, v in new_versions.items():
|
||||
self.blobs[(thread_id, checkpoint_ns, k, v)] = (
|
||||
self.serde.dumps_typed(values[k]) if k in values else ("empty", b"")
|
||||
)
|
||||
self.storage[thread_id][checkpoint_ns].update(
|
||||
{
|
||||
checkpoint["id"]: (
|
||||
self.serde.dumps_typed(c),
|
||||
self.serde.dumps_typed(get_checkpoint_metadata(config, metadata)),
|
||||
config["configurable"].get("checkpoint_id"), # parent
|
||||
)
|
||||
}
|
||||
)
|
||||
return {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": checkpoint_ns,
|
||||
"checkpoint_id": checkpoint["id"],
|
||||
}
|
||||
}
|
||||
|
||||
def put_writes(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
writes: Sequence[tuple[str, Any]],
|
||||
task_id: str,
|
||||
task_path: str = "",
|
||||
) -> None:
|
||||
"""Save a list of writes to the in-memory storage.
|
||||
|
||||
This method saves a list of writes to the in-memory storage. The writes are associated
|
||||
with the provided config.
|
||||
|
||||
Args:
|
||||
config: The config to associate with the writes.
|
||||
writes: The writes to save.
|
||||
task_id: Identifier for the task creating the writes.
|
||||
task_path: Path of the task creating the writes.
|
||||
|
||||
Returns:
|
||||
RunnableConfig: The updated config containing the saved writes' timestamp.
|
||||
"""
|
||||
thread_id = config["configurable"]["thread_id"]
|
||||
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
|
||||
checkpoint_id = config["configurable"]["checkpoint_id"]
|
||||
outer_key = (thread_id, checkpoint_ns, checkpoint_id)
|
||||
outer_writes_ = self.writes.get(outer_key)
|
||||
for idx, (c, v) in enumerate(writes):
|
||||
inner_key = (task_id, WRITES_IDX_MAP.get(c, idx))
|
||||
if inner_key[1] >= 0 and outer_writes_ and inner_key in outer_writes_:
|
||||
continue
|
||||
|
||||
self.writes[outer_key][inner_key] = (
|
||||
task_id,
|
||||
c,
|
||||
self.serde.dumps_typed(v),
|
||||
task_path,
|
||||
)
|
||||
|
||||
def delete_thread(self, thread_id: str) -> None:
|
||||
"""Delete all checkpoints and writes associated with a thread ID.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID to delete.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if thread_id in self.storage:
|
||||
del self.storage[thread_id]
|
||||
for k in list(self.writes.keys()):
|
||||
if k[0] == thread_id:
|
||||
del self.writes[k]
|
||||
for k in list(self.blobs.keys()):
|
||||
if k[0] == thread_id:
|
||||
del self.blobs[k]
|
||||
|
||||
async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
|
||||
"""Asynchronous version of `get_tuple`.
|
||||
|
||||
This method is an asynchronous wrapper around `get_tuple` that runs the synchronous
|
||||
method in a separate thread using asyncio.
|
||||
|
||||
Args:
|
||||
config: The config to use for retrieving the checkpoint.
|
||||
|
||||
Returns:
|
||||
The retrieved checkpoint tuple, or None if no matching checkpoint was found.
|
||||
"""
|
||||
return self.get_tuple(config)
|
||||
|
||||
async def alist(
|
||||
self,
|
||||
config: RunnableConfig | None,
|
||||
*,
|
||||
filter: dict[str, Any] | None = None,
|
||||
before: RunnableConfig | None = None,
|
||||
limit: int | None = None,
|
||||
) -> AsyncIterator[CheckpointTuple]:
|
||||
"""Asynchronous version of `list`.
|
||||
|
||||
This method is an asynchronous wrapper around `list` that runs the synchronous
|
||||
method in a separate thread using asyncio.
|
||||
|
||||
Args:
|
||||
config: The config to use for listing the checkpoints.
|
||||
|
||||
Yields:
|
||||
An asynchronous iterator of checkpoint tuples.
|
||||
"""
|
||||
for item in self.list(config, filter=filter, before=before, limit=limit):
|
||||
yield item
|
||||
|
||||
async def aput(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
checkpoint: Checkpoint,
|
||||
metadata: CheckpointMetadata,
|
||||
new_versions: ChannelVersions,
|
||||
) -> RunnableConfig:
|
||||
"""Asynchronous version of `put`.
|
||||
|
||||
Args:
|
||||
config: The config to associate with the checkpoint.
|
||||
checkpoint: The checkpoint to save.
|
||||
metadata: Additional metadata to save with the checkpoint.
|
||||
new_versions: New versions as of this write
|
||||
|
||||
Returns:
|
||||
RunnableConfig: The updated config containing the saved checkpoint's timestamp.
|
||||
"""
|
||||
return self.put(config, checkpoint, metadata, new_versions)
|
||||
|
||||
async def aput_writes(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
writes: Sequence[tuple[str, Any]],
|
||||
task_id: str,
|
||||
task_path: str = "",
|
||||
) -> None:
|
||||
"""Asynchronous version of `put_writes`.
|
||||
|
||||
This method is an asynchronous wrapper around `put_writes` that runs the synchronous
|
||||
method in a separate thread using asyncio.
|
||||
|
||||
Args:
|
||||
config: The config to associate with the writes.
|
||||
writes: The writes to save, each as a (channel, value) pair.
|
||||
task_id: Identifier for the task creating the writes.
|
||||
task_path: Path of the task creating the writes.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
return self.put_writes(config, writes, task_id, task_path)
|
||||
|
||||
async def adelete_thread(self, thread_id: str) -> None:
|
||||
"""Delete all checkpoints and writes associated with a thread ID.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID to delete.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
return self.delete_thread(thread_id)
|
||||
|
||||
def get_next_version(self, current: str | None, channel: None) -> str:
|
||||
if current is None:
|
||||
current_v = 0
|
||||
elif isinstance(current, int):
|
||||
current_v = current
|
||||
else:
|
||||
current_v = int(current.split(".")[0])
|
||||
next_v = current_v + 1
|
||||
next_h = random.random()
|
||||
return f"{next_v:032}.{next_h:016}"
|
||||
|
||||
|
||||
MemorySaver = InMemorySaver # Kept for backwards compatibility
|
||||
|
||||
|
||||
class PersistentDict(defaultdict):
|
||||
"""Persistent dictionary with an API compatible with shelve and anydbm.
|
||||
|
||||
The dict is kept in memory, so the dictionary operations run as fast as
|
||||
a regular dictionary.
|
||||
|
||||
Write to disk is delayed until close or sync (similar to gdbm's fast mode).
|
||||
|
||||
Input file format is automatically discovered.
|
||||
Output file format is selectable between pickle, json, and csv.
|
||||
All three serialization formats are backed by fast C implementations.
|
||||
|
||||
Adapted from https://code.activestate.com/recipes/576642-persistent-dict-with-multiple-standard-file-format/
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *args: Any, filename: str, **kwds: Any) -> None:
|
||||
self.flag = "c" # r=readonly, c=create, or n=new
|
||||
self.mode = None # None or an octal triple like 0644
|
||||
self.format = "pickle" # 'csv', 'json', or 'pickle'
|
||||
self.filename = filename
|
||||
super().__init__(*args, **kwds)
|
||||
|
||||
def sync(self) -> None:
|
||||
"Write dict to disk"
|
||||
if self.flag == "r":
|
||||
return
|
||||
tempname = self.filename + ".tmp"
|
||||
fileobj = open(tempname, "wb" if self.format == "pickle" else "w")
|
||||
try:
|
||||
self.dump(fileobj)
|
||||
except Exception:
|
||||
os.remove(tempname)
|
||||
raise
|
||||
finally:
|
||||
fileobj.close()
|
||||
shutil.move(tempname, self.filename) # atomic commit
|
||||
if self.mode is not None:
|
||||
os.chmod(self.filename, self.mode)
|
||||
|
||||
def close(self) -> None:
|
||||
self.sync()
|
||||
self.clear()
|
||||
|
||||
def __enter__(self) -> PersistentDict:
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_info: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def dump(self, fileobj: Any) -> None:
|
||||
if self.format == "pickle":
|
||||
pickle.dump(dict(self), fileobj, 2)
|
||||
else:
|
||||
raise NotImplementedError("Unknown format: " + repr(self.format))
|
||||
|
||||
def load(self) -> None:
|
||||
# try formats from most restrictive to least restrictive
|
||||
if self.flag == "n":
|
||||
return
|
||||
with open(self.filename, "rb" if self.format == "pickle" else "r") as fileobj:
|
||||
for loader in (pickle.load,):
|
||||
fileobj.seek(0)
|
||||
try:
|
||||
return self.update(loader(fileobj))
|
||||
except EOFError:
|
||||
return
|
||||
except Exception:
|
||||
logger.error(f"Failed to load file: {fileobj.name}")
|
||||
raise
|
||||
raise ValueError("File not in a supported format")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,89 @@
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from typing import cast
|
||||
|
||||
STRICT_MSGPACK_ENABLED = os.getenv("LANGGRAPH_STRICT_MSGPACK", "false").lower() in (
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
)
|
||||
|
||||
|
||||
_SENTINEL = cast(None, object())
|
||||
|
||||
SAFE_MSGPACK_TYPES: frozenset[tuple[str, ...]] = frozenset(
|
||||
{
|
||||
# datetime types
|
||||
("datetime", "datetime"),
|
||||
("datetime", "date"),
|
||||
("datetime", "time"),
|
||||
("datetime", "timedelta"),
|
||||
("datetime", "timezone"),
|
||||
# uuid
|
||||
("uuid", "UUID"),
|
||||
# numeric
|
||||
("decimal", "Decimal"),
|
||||
# collections
|
||||
("builtins", "set"),
|
||||
("builtins", "frozenset"),
|
||||
("collections", "deque"),
|
||||
# ip addresses
|
||||
("ipaddress", "IPv4Address"),
|
||||
("ipaddress", "IPv4Interface"),
|
||||
("ipaddress", "IPv4Network"),
|
||||
("ipaddress", "IPv6Address"),
|
||||
("ipaddress", "IPv6Interface"),
|
||||
("ipaddress", "IPv6Network"),
|
||||
# pathlib
|
||||
("pathlib", "Path"),
|
||||
("pathlib", "PosixPath"),
|
||||
("pathlib", "WindowsPath"),
|
||||
# pathlib in Python 3.13+
|
||||
("pathlib._local", "Path"),
|
||||
("pathlib._local", "PosixPath"),
|
||||
("pathlib._local", "WindowsPath"),
|
||||
# zoneinfo
|
||||
("zoneinfo", "ZoneInfo"),
|
||||
# regex
|
||||
("re", "compile"),
|
||||
# langchain-core messages (safe container types used by graph state)
|
||||
("langchain_core.messages.base", "BaseMessage"),
|
||||
("langchain_core.messages.base", "BaseMessageChunk"),
|
||||
("langchain_core.messages.human", "HumanMessage"),
|
||||
("langchain_core.messages.human", "HumanMessageChunk"),
|
||||
("langchain_core.messages.ai", "AIMessage"),
|
||||
("langchain_core.messages.ai", "AIMessageChunk"),
|
||||
("langchain_core.messages.system", "SystemMessage"),
|
||||
("langchain_core.messages.system", "SystemMessageChunk"),
|
||||
("langchain_core.messages.chat", "ChatMessage"),
|
||||
("langchain_core.messages.chat", "ChatMessageChunk"),
|
||||
("langchain_core.messages.tool", "ToolMessage"),
|
||||
("langchain_core.messages.tool", "ToolMessageChunk"),
|
||||
("langchain_core.messages.function", "FunctionMessage"),
|
||||
("langchain_core.messages.function", "FunctionMessageChunk"),
|
||||
("langchain_core.messages.modifier", "RemoveMessage"),
|
||||
# langchain-core document model
|
||||
("langchain_core.documents.base", "Document"),
|
||||
# langgraph
|
||||
("langgraph.types", "Send"),
|
||||
("langgraph.types", "Interrupt"),
|
||||
("langgraph.types", "Command"),
|
||||
("langgraph.types", "StateSnapshot"),
|
||||
("langgraph.types", "PregelTask"),
|
||||
("langgraph.types", "Overwrite"),
|
||||
("langgraph.store.base", "Item"),
|
||||
("langgraph.store.base", "GetOp"),
|
||||
}
|
||||
)
|
||||
|
||||
# Allowed (module, name, method) triples for EXT_METHOD_SINGLE_ARG.
|
||||
# Only these specific method invocations are permitted during deserialization.
|
||||
# This is separate from SAFE_MSGPACK_TYPES which only governs construction.
|
||||
SAFE_MSGPACK_METHODS: frozenset[tuple[str, str, str]] = frozenset(
|
||||
{
|
||||
("datetime", "datetime", "fromisoformat"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
AllowedMsgpackModules = Iterable[tuple[str, ...] | type]
|
||||
64
venv/Lib/site-packages/langgraph/checkpoint/serde/base.py
Normal file
64
venv/Lib/site-packages/langgraph/checkpoint/serde/base.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
class UntypedSerializerProtocol(Protocol):
|
||||
"""Protocol for serialization and deserialization of objects."""
|
||||
|
||||
def dumps(self, obj: Any) -> bytes: ...
|
||||
|
||||
def loads(self, data: bytes) -> Any: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SerializerProtocol(Protocol):
|
||||
"""Protocol for serialization and deserialization of objects.
|
||||
|
||||
- `dumps_typed`: Serialize an object to a tuple `(type, bytes)`.
|
||||
- `loads_typed`: Deserialize an object from a tuple `(type, bytes)`.
|
||||
|
||||
Valid implementations include the `pickle`, `json` and `orjson` modules.
|
||||
"""
|
||||
|
||||
def dumps_typed(self, obj: Any) -> tuple[str, bytes]: ...
|
||||
|
||||
def loads_typed(self, data: tuple[str, bytes]) -> Any: ...
|
||||
|
||||
|
||||
class SerializerCompat(SerializerProtocol):
|
||||
def __init__(self, serde: UntypedSerializerProtocol) -> None:
|
||||
self.serde = serde
|
||||
|
||||
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
|
||||
return type(obj).__name__, self.serde.dumps(obj)
|
||||
|
||||
def loads_typed(self, data: tuple[str, bytes]) -> Any:
|
||||
return self.serde.loads(data[1])
|
||||
|
||||
|
||||
def maybe_add_typed_methods(
|
||||
serde: SerializerProtocol | UntypedSerializerProtocol,
|
||||
) -> SerializerProtocol:
|
||||
"""Wrap serde old serde implementations in a class with loads_typed and dumps_typed for backwards compatibility."""
|
||||
|
||||
if not isinstance(serde, SerializerProtocol):
|
||||
return SerializerCompat(serde)
|
||||
|
||||
return serde
|
||||
|
||||
|
||||
class CipherProtocol(Protocol):
|
||||
"""Protocol for encryption and decryption of data.
|
||||
|
||||
- `encrypt`: Encrypt plaintext.
|
||||
- `decrypt`: Decrypt ciphertext.
|
||||
"""
|
||||
|
||||
def encrypt(self, plaintext: bytes) -> tuple[str, bytes]:
|
||||
"""Encrypt plaintext. Returns a tuple `(cipher name, ciphertext)`."""
|
||||
...
|
||||
|
||||
def decrypt(self, ciphername: str, ciphertext: bytes) -> bytes:
|
||||
"""Decrypt ciphertext. Returns the plaintext."""
|
||||
...
|
||||
@@ -0,0 +1,80 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from langgraph.checkpoint.serde.base import CipherProtocol, SerializerProtocol
|
||||
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
||||
|
||||
|
||||
class EncryptedSerializer(SerializerProtocol):
|
||||
"""Serializer that encrypts and decrypts data using an encryption protocol."""
|
||||
|
||||
def __init__(
|
||||
self, cipher: CipherProtocol, serde: SerializerProtocol = JsonPlusSerializer()
|
||||
) -> None:
|
||||
self.cipher = cipher
|
||||
self.serde = serde
|
||||
|
||||
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
|
||||
"""Serialize an object to a tuple `(type, bytes)` and encrypt the bytes."""
|
||||
# serialize data
|
||||
typ, data = self.serde.dumps_typed(obj)
|
||||
# encrypt data
|
||||
ciphername, ciphertext = self.cipher.encrypt(data)
|
||||
# add cipher name to type
|
||||
return f"{typ}+{ciphername}", ciphertext
|
||||
|
||||
def loads_typed(self, data: tuple[str, bytes]) -> Any:
|
||||
enc_cipher, ciphertext = data
|
||||
# unencrypted data
|
||||
if "+" not in enc_cipher:
|
||||
return self.serde.loads_typed(data)
|
||||
# extract cipher name
|
||||
typ, ciphername = enc_cipher.split("+", 1)
|
||||
# decrypt data
|
||||
decrypted_data = self.cipher.decrypt(ciphername, ciphertext)
|
||||
# deserialize data
|
||||
return self.serde.loads_typed((typ, decrypted_data))
|
||||
|
||||
@classmethod
|
||||
def from_pycryptodome_aes(
|
||||
cls, serde: SerializerProtocol = JsonPlusSerializer(), **kwargs: Any
|
||||
) -> "EncryptedSerializer":
|
||||
"""Create an `EncryptedSerializer` using AES encryption."""
|
||||
try:
|
||||
from Crypto.Cipher import AES
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Pycryptodome is not installed. Please install it with `pip install pycryptodome`."
|
||||
) from None
|
||||
|
||||
# check if AES key is provided
|
||||
if "key" in kwargs:
|
||||
key: bytes = kwargs.pop("key")
|
||||
else:
|
||||
key_str = os.getenv("LANGGRAPH_AES_KEY")
|
||||
if key_str is None:
|
||||
raise ValueError("LANGGRAPH_AES_KEY environment variable is not set.")
|
||||
key = key_str.encode()
|
||||
if len(key) not in (16, 24, 32):
|
||||
raise ValueError("LANGGRAPH_AES_KEY must be 16, 24, or 32 bytes long.")
|
||||
|
||||
# set default mode to EAX if not provided
|
||||
if kwargs.get("mode") is None:
|
||||
kwargs["mode"] = AES.MODE_EAX
|
||||
|
||||
class PycryptodomeAesCipher(CipherProtocol):
|
||||
def encrypt(self, plaintext: bytes) -> tuple[str, bytes]:
|
||||
cipher = AES.new(key, **kwargs)
|
||||
ciphertext, tag = cipher.encrypt_and_digest(plaintext)
|
||||
return "aes", cipher.nonce + tag + ciphertext
|
||||
|
||||
def decrypt(self, ciphername: str, ciphertext: bytes) -> bytes:
|
||||
assert ciphername == "aes", f"Unsupported cipher: {ciphername}"
|
||||
nonce = ciphertext[:16]
|
||||
tag = ciphertext[16:32]
|
||||
actual_ciphertext = ciphertext[32:]
|
||||
|
||||
cipher = AES.new(key, **kwargs, nonce=nonce)
|
||||
return cipher.decrypt_and_verify(actual_ciphertext, tag)
|
||||
|
||||
return cls(PycryptodomeAesCipher(), serde)
|
||||
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from threading import Lock
|
||||
from typing import TypedDict
|
||||
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SerdeEvent(TypedDict):
|
||||
kind: str
|
||||
module: str
|
||||
name: str
|
||||
method: NotRequired[str]
|
||||
|
||||
|
||||
SerdeEventListener = Callable[[SerdeEvent], None]
|
||||
|
||||
_listeners: list[SerdeEventListener] = []
|
||||
_listeners_lock = Lock()
|
||||
|
||||
|
||||
def register_serde_event_listener(listener: SerdeEventListener) -> Callable[[], None]:
|
||||
"""Register a listener for serde allowlist events."""
|
||||
with _listeners_lock:
|
||||
_listeners.append(listener)
|
||||
|
||||
def unregister() -> None:
|
||||
with _listeners_lock:
|
||||
try:
|
||||
_listeners.remove(listener)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return unregister
|
||||
|
||||
|
||||
def emit_serde_event(event: SerdeEvent) -> None:
|
||||
"""Emit a serde event to all listeners.
|
||||
|
||||
Listener failures are isolated and logged.
|
||||
"""
|
||||
with _listeners_lock:
|
||||
listeners = tuple(_listeners)
|
||||
for listener in listeners:
|
||||
try:
|
||||
listener(event)
|
||||
except Exception:
|
||||
logger.warning("Serde listener failed", exc_info=True)
|
||||
827
venv/Lib/site-packages/langgraph/checkpoint/serde/jsonplus.py
Normal file
827
venv/Lib/site-packages/langgraph/checkpoint/serde/jsonplus.py
Normal file
@@ -0,0 +1,827 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import decimal
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import pathlib
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
from enum import Enum
|
||||
from inspect import isclass
|
||||
from ipaddress import (
|
||||
IPv4Address,
|
||||
IPv4Interface,
|
||||
IPv4Network,
|
||||
IPv6Address,
|
||||
IPv6Interface,
|
||||
IPv6Network,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from uuid import UUID
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import ormsgpack
|
||||
from langchain_core.load.load import Reviver
|
||||
|
||||
from langgraph.checkpoint.serde import _msgpack as _lg_msgpack
|
||||
from langgraph.checkpoint.serde.base import SerializerProtocol
|
||||
from langgraph.checkpoint.serde.event_hooks import emit_serde_event
|
||||
from langgraph.checkpoint.serde.types import SendProtocol
|
||||
from langgraph.store.base import Item
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.checkpoint.serde._msgpack import (
|
||||
AllowedMsgpackModules,
|
||||
)
|
||||
from langgraph.checkpoint.serde.types import SendProtocol
|
||||
|
||||
LC_REVIVER = Reviver()
|
||||
EMPTY_BYTES = b""
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JsonPlusSerializer(SerializerProtocol):
|
||||
"""Serializer that uses ormsgpack, with optional fallbacks.
|
||||
|
||||
!!! warning
|
||||
|
||||
Security note: This serializer is intended for use within the `BaseCheckpointSaver`
|
||||
class and called within the Pregel loop. It should not be used on untrusted
|
||||
python objects. If an attacker can write directly to your checkpoint database,
|
||||
they may be able to trigger code execution when data is deserialized.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pickle_fallback: bool = False,
|
||||
allowed_json_modules: Iterable[tuple[str, ...]] | Literal[True] | None = None,
|
||||
allowed_msgpack_modules: (
|
||||
AllowedMsgpackModules | Literal[True] | None
|
||||
) = _lg_msgpack._SENTINEL,
|
||||
__unpack_ext_hook__: Callable[[int, bytes], Any] | None = None,
|
||||
) -> None:
|
||||
if allowed_msgpack_modules is _lg_msgpack._SENTINEL:
|
||||
if _lg_msgpack.STRICT_MSGPACK_ENABLED:
|
||||
allowed_msgpack_modules = None
|
||||
else:
|
||||
allowed_msgpack_modules = True
|
||||
self.pickle_fallback = pickle_fallback
|
||||
self._allowed_json_modules: set[tuple[str, ...]] | Literal[True] | None = (
|
||||
_normalize_allowlist(allowed_json_modules)
|
||||
)
|
||||
self._allowed_msgpack_modules = _normalize_allowlist(allowed_msgpack_modules)
|
||||
|
||||
self._custom_unpack_ext_hook = __unpack_ext_hook__ is not None
|
||||
self._unpack_ext_hook = (
|
||||
__unpack_ext_hook__
|
||||
if __unpack_ext_hook__ is not None
|
||||
else _create_msgpack_ext_hook(self._allowed_msgpack_modules)
|
||||
)
|
||||
|
||||
def with_msgpack_allowlist(
|
||||
self, extra_allowlist: Iterable[tuple[str, ...] | type]
|
||||
) -> JsonPlusSerializer:
|
||||
"""Return a new serializer with a merged msgpack allowlist."""
|
||||
base_allowlist = self._allowed_msgpack_modules
|
||||
if base_allowlist is True or base_allowlist is False:
|
||||
return self
|
||||
elif base_allowlist:
|
||||
base_allowlist = set(base_allowlist)
|
||||
else:
|
||||
base_allowlist = set()
|
||||
extra = _normalize_module_keys(tuple(extra_allowlist))
|
||||
merged = base_allowlist | extra
|
||||
if merged == base_allowlist:
|
||||
return self
|
||||
allowed_msgpack_modules: AllowedMsgpackModules | Literal[True] | None
|
||||
if merged:
|
||||
allowed_msgpack_modules = tuple(merged)
|
||||
elif isinstance(self._allowed_msgpack_modules, set):
|
||||
allowed_msgpack_modules = tuple(self._allowed_msgpack_modules)
|
||||
else:
|
||||
allowed_msgpack_modules = self._allowed_msgpack_modules
|
||||
|
||||
clone = copy.copy(self)
|
||||
clone._allowed_json_modules = _normalize_allowlist(self._allowed_json_modules)
|
||||
clone._allowed_msgpack_modules = _normalize_allowlist(allowed_msgpack_modules)
|
||||
if not clone._custom_unpack_ext_hook:
|
||||
clone._unpack_ext_hook = _create_msgpack_ext_hook(
|
||||
clone._allowed_msgpack_modules
|
||||
)
|
||||
return clone
|
||||
|
||||
def _encode_constructor_args(
|
||||
self,
|
||||
constructor: Callable | type[Any],
|
||||
*,
|
||||
method: None | str | Sequence[None | str] = None,
|
||||
args: Sequence[Any] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
out = {
|
||||
"lc": 2,
|
||||
"type": "constructor",
|
||||
"id": (*constructor.__module__.split("."), constructor.__name__),
|
||||
}
|
||||
if method is not None:
|
||||
out["method"] = method
|
||||
if args is not None:
|
||||
out["args"] = args
|
||||
if kwargs is not None:
|
||||
out["kwargs"] = kwargs
|
||||
return out
|
||||
|
||||
def _reviver(self, value: dict[str, Any]) -> Any:
|
||||
if self._allowed_json_modules and (
|
||||
value.get("lc", None) == 2
|
||||
and value.get("type", None) == "constructor"
|
||||
and value.get("id", None) is not None
|
||||
):
|
||||
try:
|
||||
return self._revive_lc2(value)
|
||||
except InvalidModuleError as e:
|
||||
logger.warning(
|
||||
"Object %s is not in the deserialization allowlist.\n%s",
|
||||
value["id"],
|
||||
e.message,
|
||||
)
|
||||
|
||||
return LC_REVIVER(value)
|
||||
|
||||
def _revive_lc2(self, value: dict[str, Any]) -> Any:
|
||||
self._check_allowed_json_modules(value)
|
||||
|
||||
[*module, name] = value["id"]
|
||||
try:
|
||||
mod = importlib.import_module(".".join(module))
|
||||
cls = getattr(mod, name)
|
||||
method = value.get("method")
|
||||
if isinstance(method, str):
|
||||
methods = [getattr(cls, method)]
|
||||
elif isinstance(method, list):
|
||||
methods = [cls if m is None else getattr(cls, m) for m in method]
|
||||
else:
|
||||
methods = [cls]
|
||||
args = value.get("args")
|
||||
kwargs = value.get("kwargs")
|
||||
for method in methods:
|
||||
try:
|
||||
if isclass(method) and issubclass(method, BaseException):
|
||||
return None
|
||||
if args and kwargs:
|
||||
return method(*args, **kwargs)
|
||||
elif args:
|
||||
return method(*args)
|
||||
elif kwargs:
|
||||
return method(**kwargs)
|
||||
else:
|
||||
return method()
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _check_allowed_json_modules(self, value: dict[str, Any]) -> None:
|
||||
needed = tuple(value["id"])
|
||||
method = value.get("method")
|
||||
if isinstance(method, list):
|
||||
method_display = ",".join(m or "<init>" for m in method)
|
||||
elif isinstance(method, str):
|
||||
method_display = method
|
||||
else:
|
||||
method_display = "<init>"
|
||||
|
||||
dotted = ".".join(needed)
|
||||
if not self._allowed_json_modules:
|
||||
raise InvalidModuleError(
|
||||
f"Refused to deserialize JSON constructor: {dotted} (method: {method_display}). "
|
||||
"No allowed_json_modules configured.\n\n"
|
||||
"Unblock with ONE of:\n"
|
||||
f" • JsonPlusSerializer(allowed_json_modules=[{needed!r}, ...])\n"
|
||||
" • (DANGEROUS) JsonPlusSerializer(allowed_json_modules=True)\n\n"
|
||||
"Note: Prefix allowlists are intentionally unsupported; prefer exact symbols "
|
||||
"or plain-JSON representations revived without import-time side effects."
|
||||
)
|
||||
|
||||
if self._allowed_json_modules is True:
|
||||
return
|
||||
if needed in self._allowed_json_modules:
|
||||
return
|
||||
|
||||
raise InvalidModuleError(
|
||||
f"Refused to deserialize JSON constructor: {dotted} (method: {method_display}). "
|
||||
"Symbol is not in the deserialization allowlist.\n\n"
|
||||
"Add exactly this symbol to unblock:\n"
|
||||
f" JsonPlusSerializer(allowed_json_modules=[{needed!r}, ...])\n"
|
||||
"Or, as a last resort (DANGEROUS):\n"
|
||||
" JsonPlusSerializer(allowed_json_modules=True)"
|
||||
)
|
||||
|
||||
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
|
||||
if obj is None:
|
||||
return "null", EMPTY_BYTES
|
||||
elif isinstance(obj, bytes):
|
||||
return "bytes", obj
|
||||
elif isinstance(obj, bytearray):
|
||||
return "bytearray", obj
|
||||
else:
|
||||
try:
|
||||
return "msgpack", _msgpack_enc(obj)
|
||||
except ormsgpack.MsgpackEncodeError as exc:
|
||||
if self.pickle_fallback:
|
||||
return "pickle", pickle.dumps(obj)
|
||||
raise exc
|
||||
|
||||
def loads_typed(self, data: tuple[str, bytes]) -> Any:
|
||||
type_, data_ = data
|
||||
if type_ == "null":
|
||||
return None
|
||||
elif type_ == "bytes":
|
||||
return data_
|
||||
elif type_ == "bytearray":
|
||||
return bytearray(data_)
|
||||
elif type_ == "json":
|
||||
return json.loads(data_, object_hook=self._reviver)
|
||||
elif type_ == "msgpack":
|
||||
return ormsgpack.unpackb(
|
||||
data_, ext_hook=self._unpack_ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
elif self.pickle_fallback and type_ == "pickle":
|
||||
return pickle.loads(data_)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown serialization type: {type_}")
|
||||
|
||||
|
||||
# --- msgpack ---
|
||||
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG = 0
|
||||
EXT_CONSTRUCTOR_POS_ARGS = 1
|
||||
EXT_CONSTRUCTOR_KW_ARGS = 2
|
||||
EXT_METHOD_SINGLE_ARG = 3
|
||||
EXT_PYDANTIC_V1 = 4
|
||||
EXT_PYDANTIC_V2 = 5
|
||||
EXT_NUMPY_ARRAY = 6
|
||||
|
||||
|
||||
def _msgpack_default(obj: Any) -> str | ormsgpack.Ext:
|
||||
if hasattr(obj, "model_dump") and callable(obj.model_dump): # pydantic v2
|
||||
return ormsgpack.Ext(
|
||||
EXT_PYDANTIC_V2,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
obj.model_dump(),
|
||||
"model_validate_json",
|
||||
),
|
||||
),
|
||||
)
|
||||
elif hasattr(obj, "get_secret_value") and callable(obj.get_secret_value):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
obj.get_secret_value(),
|
||||
),
|
||||
),
|
||||
)
|
||||
elif hasattr(obj, "dict") and callable(obj.dict): # pydantic v1
|
||||
return ormsgpack.Ext(
|
||||
EXT_PYDANTIC_V1,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
obj.dict(),
|
||||
),
|
||||
),
|
||||
)
|
||||
elif hasattr(obj, "_asdict") and callable(obj._asdict): # namedtuple
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_KW_ARGS,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
obj._asdict(),
|
||||
),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, pathlib.Path):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_POS_ARGS,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, obj.parts),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, re.Pattern):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_POS_ARGS,
|
||||
_msgpack_enc(
|
||||
("re", "compile", (obj.pattern, obj.flags)),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, UUID):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, obj.hex),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, decimal.Decimal):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, (set, frozenset, deque)):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, tuple(obj)),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, (IPv4Address, IPv4Interface, IPv4Network)):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, (IPv6Address, IPv6Interface, IPv6Network)):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, datetime):
|
||||
return ormsgpack.Ext(
|
||||
EXT_METHOD_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
obj.isoformat(),
|
||||
"fromisoformat",
|
||||
),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, timedelta):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_POS_ARGS,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
(obj.days, obj.seconds, obj.microseconds),
|
||||
),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, date):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_POS_ARGS,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
(obj.year, obj.month, obj.day),
|
||||
),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, time):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_KW_ARGS,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
{
|
||||
"hour": obj.hour,
|
||||
"minute": obj.minute,
|
||||
"second": obj.second,
|
||||
"microsecond": obj.microsecond,
|
||||
"tzinfo": obj.tzinfo,
|
||||
"fold": obj.fold,
|
||||
},
|
||||
),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, timezone):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_POS_ARGS,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
obj.__getinitargs__(), # type: ignore[attr-defined]
|
||||
),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, ZoneInfo):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, obj.key),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, Enum):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, obj.value),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, SendProtocol):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_POS_ARGS,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, (obj.node, obj.arg)),
|
||||
),
|
||||
)
|
||||
elif dataclasses.is_dataclass(obj):
|
||||
# doesn't use dataclasses.asdict to avoid deepcopy and recursion
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_KW_ARGS,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
{
|
||||
field.name: getattr(obj, field.name)
|
||||
for field in dataclasses.fields(obj)
|
||||
},
|
||||
),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, Item):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_KW_ARGS,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
{k: getattr(obj, k) for k in obj.__slots__},
|
||||
),
|
||||
),
|
||||
)
|
||||
elif (np_mod := sys.modules.get("numpy")) is not None and isinstance(
|
||||
obj, np_mod.ndarray
|
||||
):
|
||||
order = "F" if obj.flags.f_contiguous and not obj.flags.c_contiguous else "C"
|
||||
if obj.flags.c_contiguous:
|
||||
mv = memoryview(obj)
|
||||
try:
|
||||
meta = (obj.dtype.str, obj.shape, order, mv)
|
||||
return ormsgpack.Ext(EXT_NUMPY_ARRAY, _msgpack_enc(meta))
|
||||
finally:
|
||||
mv.release()
|
||||
else:
|
||||
buf = obj.tobytes(order="A")
|
||||
meta = (obj.dtype.str, obj.shape, order, buf)
|
||||
return ormsgpack.Ext(EXT_NUMPY_ARRAY, _msgpack_enc(meta))
|
||||
|
||||
elif isinstance(obj, BaseException):
|
||||
return repr(obj)
|
||||
else:
|
||||
raise TypeError(f"Object of type {obj.__class__.__name__} is not serializable")
|
||||
|
||||
|
||||
def _create_msgpack_ext_hook(
|
||||
allowed_modules: set[tuple[str, ...]] | Literal[True] | None,
|
||||
) -> Callable[[int, bytes], Any]:
|
||||
"""Create msgpack ext hook with allowlist.
|
||||
|
||||
Args:
|
||||
allowed_modules: Set of (module, name) tuples that are allowed to be
|
||||
deserialized, or True to allow all with warnings for unregistered types, or None to only allow safe types.
|
||||
|
||||
Returns:
|
||||
An ext_hook function for use with ormsgpack.unpackb.
|
||||
"""
|
||||
|
||||
def _check_allowed(module: str, name: str) -> bool:
|
||||
"""Check if type is allowed. Returns True if allowed, False if blocked."""
|
||||
key = (module, name)
|
||||
|
||||
if key in _lg_msgpack.SAFE_MSGPACK_TYPES:
|
||||
return True
|
||||
|
||||
if allowed_modules is True:
|
||||
# default is to warn but allow unregistered types
|
||||
emit_serde_event(
|
||||
{
|
||||
"kind": "msgpack_unregistered_allowed",
|
||||
"module": module,
|
||||
"name": name,
|
||||
}
|
||||
)
|
||||
logger.warning(
|
||||
"Deserializing unregistered type %s.%s from checkpoint. "
|
||||
"This will be blocked in a future version. "
|
||||
"Add to allowed_msgpack_modules to silence: [(%r, %r)]",
|
||||
module,
|
||||
name,
|
||||
module,
|
||||
name,
|
||||
)
|
||||
return True
|
||||
if allowed_modules is not None:
|
||||
if key in allowed_modules:
|
||||
return True
|
||||
# strict mode blocks unregistered types
|
||||
emit_serde_event(
|
||||
{
|
||||
"kind": "msgpack_blocked",
|
||||
"module": module,
|
||||
"name": name,
|
||||
}
|
||||
)
|
||||
logger.warning(
|
||||
"Blocked deserialization of %s.%s - not in allowed_msgpack_modules. "
|
||||
"Add to allowed_msgpack_modules to allow: [(%r, %r)]",
|
||||
module,
|
||||
name,
|
||||
module,
|
||||
name,
|
||||
)
|
||||
return False
|
||||
|
||||
def _check_allowed_method(module: str, name: str, method: str) -> bool:
|
||||
"""Check if a method invocation is allowed."""
|
||||
key = (module, name, method)
|
||||
if key in _lg_msgpack.SAFE_MSGPACK_METHODS:
|
||||
return True
|
||||
emit_serde_event(
|
||||
{
|
||||
"kind": "msgpack_method_blocked",
|
||||
"module": module,
|
||||
"name": name,
|
||||
"method": method,
|
||||
}
|
||||
)
|
||||
logger.warning(
|
||||
"Blocked deserialization of method call %s.%s.%s - "
|
||||
"not in allowed methods set.",
|
||||
module,
|
||||
name,
|
||||
method,
|
||||
)
|
||||
return False
|
||||
|
||||
def ext_hook(code: int, data: bytes) -> Any:
|
||||
if code == EXT_CONSTRUCTOR_SINGLE_ARG:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
if not _check_allowed(tup[0], tup[1]):
|
||||
# We default to returning the raw data. If the user
|
||||
# is using this in the context of a pydantic state, etc., then
|
||||
# it would be validated upon construction.
|
||||
return tup[2]
|
||||
# module, name, arg
|
||||
return getattr(importlib.import_module(tup[0]), tup[1])(tup[2])
|
||||
except Exception:
|
||||
return None
|
||||
elif code == EXT_CONSTRUCTOR_POS_ARGS:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
if not _check_allowed(tup[0], tup[1]):
|
||||
return tup[2]
|
||||
# module, name, args
|
||||
return getattr(importlib.import_module(tup[0]), tup[1])(*tup[2])
|
||||
except Exception:
|
||||
return None
|
||||
elif code == EXT_CONSTRUCTOR_KW_ARGS:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
if not _check_allowed(tup[0], tup[1]):
|
||||
return tup[2]
|
||||
# module, name, kwargs
|
||||
return getattr(importlib.import_module(tup[0]), tup[1])(**tup[2])
|
||||
except Exception:
|
||||
return None
|
||||
elif code == EXT_METHOD_SINGLE_ARG:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
if not _check_allowed_method(tup[0], tup[1], tup[3]):
|
||||
return tup[2]
|
||||
# module, name, arg, method
|
||||
return getattr(
|
||||
getattr(importlib.import_module(tup[0]), tup[1]), tup[3]
|
||||
)(tup[2])
|
||||
except Exception:
|
||||
return None
|
||||
elif code == EXT_PYDANTIC_V1:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
if not _check_allowed(tup[0], tup[1]):
|
||||
return tup[2]
|
||||
# module, name, kwargs
|
||||
cls = getattr(importlib.import_module(tup[0]), tup[1])
|
||||
try:
|
||||
return cls(**tup[2])
|
||||
except Exception:
|
||||
return cls.construct(**tup[2])
|
||||
except Exception:
|
||||
# for pydantic objects we can't find/reconstruct
|
||||
# let's return the kwargs dict instead
|
||||
try:
|
||||
return tup[2]
|
||||
except NameError:
|
||||
return None
|
||||
elif code == EXT_PYDANTIC_V2:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
if not _check_allowed(tup[0], tup[1]):
|
||||
return tup[2]
|
||||
# module, name, kwargs, method
|
||||
cls = getattr(importlib.import_module(tup[0]), tup[1])
|
||||
try:
|
||||
return cls(**tup[2])
|
||||
except Exception:
|
||||
return cls.model_construct(**tup[2])
|
||||
except Exception:
|
||||
# for pydantic objects we can't find/reconstruct
|
||||
# let's return the kwargs dict instead
|
||||
try:
|
||||
return tup[2]
|
||||
except NameError:
|
||||
return None
|
||||
elif code == EXT_NUMPY_ARRAY:
|
||||
try:
|
||||
import numpy as _np
|
||||
|
||||
dtype_str, shape, order, buf = ormsgpack.unpackb(
|
||||
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
arr = _np.frombuffer(buf, dtype=_np.dtype(dtype_str))
|
||||
return arr.reshape(shape, order=order)
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
return ext_hook
|
||||
|
||||
|
||||
# Aliasing in case anyone imported it directly
|
||||
_msgpack_ext_hook = _create_msgpack_ext_hook(allowed_modules=None)
|
||||
|
||||
|
||||
def _msgpack_ext_hook_to_json(code: int, data: bytes) -> Any:
|
||||
if code == EXT_CONSTRUCTOR_SINGLE_ARG:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data,
|
||||
ext_hook=_msgpack_ext_hook_to_json,
|
||||
option=ormsgpack.OPT_NON_STR_KEYS,
|
||||
)
|
||||
if tup[0] == "uuid" and tup[1] == "UUID":
|
||||
hex_ = tup[2]
|
||||
return (
|
||||
f"{hex_[:8]}-{hex_[8:12]}-{hex_[12:16]}-{hex_[16:20]}-{hex_[20:]}"
|
||||
)
|
||||
# module, name, arg
|
||||
return tup[2]
|
||||
except Exception:
|
||||
return
|
||||
elif code == EXT_CONSTRUCTOR_POS_ARGS:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data,
|
||||
ext_hook=_msgpack_ext_hook_to_json,
|
||||
option=ormsgpack.OPT_NON_STR_KEYS,
|
||||
)
|
||||
if tup[0] == "langgraph.types" and tup[1] == "Send":
|
||||
from langgraph.types import Send # type: ignore
|
||||
|
||||
return Send(*tup[2])
|
||||
# module, name, args
|
||||
return tup[2]
|
||||
except Exception:
|
||||
return
|
||||
elif code == EXT_CONSTRUCTOR_KW_ARGS:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data,
|
||||
ext_hook=_msgpack_ext_hook_to_json,
|
||||
option=ormsgpack.OPT_NON_STR_KEYS,
|
||||
)
|
||||
# module, name, args
|
||||
return tup[2]
|
||||
except Exception:
|
||||
return
|
||||
elif code == EXT_METHOD_SINGLE_ARG:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data,
|
||||
ext_hook=_msgpack_ext_hook_to_json,
|
||||
option=ormsgpack.OPT_NON_STR_KEYS,
|
||||
)
|
||||
# module, name, arg, method
|
||||
return tup[2]
|
||||
except Exception:
|
||||
return
|
||||
elif code == EXT_PYDANTIC_V1:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data,
|
||||
ext_hook=_msgpack_ext_hook_to_json,
|
||||
option=ormsgpack.OPT_NON_STR_KEYS,
|
||||
)
|
||||
# module, name, kwargs
|
||||
return tup[2]
|
||||
except Exception:
|
||||
# for pydantic objects we can't find/reconstruct
|
||||
# let's return the kwargs dict instead
|
||||
return
|
||||
elif code == EXT_PYDANTIC_V2:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data,
|
||||
ext_hook=_msgpack_ext_hook_to_json,
|
||||
option=ormsgpack.OPT_NON_STR_KEYS,
|
||||
)
|
||||
# module, name, kwargs, method
|
||||
return tup[2]
|
||||
except Exception:
|
||||
return
|
||||
elif code == EXT_NUMPY_ARRAY:
|
||||
try:
|
||||
import numpy as _np
|
||||
|
||||
dtype_str, shape, order, buf = ormsgpack.unpackb(
|
||||
data,
|
||||
ext_hook=_msgpack_ext_hook_to_json,
|
||||
option=ormsgpack.OPT_NON_STR_KEYS,
|
||||
)
|
||||
arr = _np.frombuffer(buf, dtype=_np.dtype(dtype_str))
|
||||
return arr.reshape(shape, order=order).tolist()
|
||||
except Exception:
|
||||
return
|
||||
|
||||
|
||||
class InvalidModuleError(Exception):
|
||||
"""Exception raised when a module is not in the allowlist."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
|
||||
|
||||
_option = (
|
||||
ormsgpack.OPT_NON_STR_KEYS
|
||||
| ormsgpack.OPT_PASSTHROUGH_DATACLASS
|
||||
| ormsgpack.OPT_PASSTHROUGH_DATETIME
|
||||
| ormsgpack.OPT_PASSTHROUGH_ENUM
|
||||
| ormsgpack.OPT_PASSTHROUGH_UUID
|
||||
| ormsgpack.OPT_REPLACE_SURROGATES
|
||||
)
|
||||
|
||||
|
||||
def _msgpack_enc(data: Any) -> bytes:
|
||||
return ormsgpack.packb(data, default=_msgpack_default, option=_option)
|
||||
|
||||
|
||||
def _normalize_allowlist(
|
||||
allowlist: AllowedMsgpackModules | Literal[True] | None,
|
||||
) -> set[tuple[str, ...]] | Literal[True] | None:
|
||||
if allowlist is True:
|
||||
return allowlist
|
||||
elif allowlist:
|
||||
return _normalize_module_keys(allowlist)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_module_keys(
|
||||
modules: AllowedMsgpackModules,
|
||||
) -> set[tuple[str, ...]]:
|
||||
normalized: set[tuple[str, ...]] = set()
|
||||
for module in modules:
|
||||
if isclass(module):
|
||||
normalized.add((module.__module__, module.__name__))
|
||||
else:
|
||||
normalized.add(cast(tuple[str, ...], module))
|
||||
return normalized
|
||||
51
venv/Lib/site-packages/langgraph/checkpoint/serde/types.py
Normal file
51
venv/Lib/site-packages/langgraph/checkpoint/serde/types.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
ERROR = "__error__"
|
||||
SCHEDULED = "__scheduled__"
|
||||
INTERRUPT = "__interrupt__"
|
||||
RESUME = "__resume__"
|
||||
TASKS = "__pregel_tasks"
|
||||
|
||||
Value = TypeVar("Value", covariant=True)
|
||||
Update = TypeVar("Update", contravariant=True)
|
||||
C = TypeVar("C")
|
||||
|
||||
|
||||
class ChannelProtocol(Protocol[Value, Update, C]):
|
||||
# Mirrors langgraph.channels.base.BaseChannel
|
||||
@property
|
||||
def ValueType(self) -> Any: ...
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> Any: ...
|
||||
|
||||
def checkpoint(self) -> C | None: ...
|
||||
|
||||
def from_checkpoint(self, checkpoint: C | None) -> Self: ...
|
||||
|
||||
def update(self, values: Sequence[Update]) -> bool: ...
|
||||
|
||||
def get(self) -> Value: ...
|
||||
|
||||
def consume(self) -> bool: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SendProtocol(Protocol):
|
||||
# Mirrors langgraph.constants.Send
|
||||
node: str
|
||||
arg: Any
|
||||
|
||||
def __hash__(self) -> int: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
def __eq__(self, value: object) -> bool: ...
|
||||
Reference in New Issue
Block a user