initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,89 @@
import os
from collections.abc import Iterable
from typing import cast
STRICT_MSGPACK_ENABLED = os.getenv("LANGGRAPH_STRICT_MSGPACK", "false").lower() in (
"1",
"true",
"yes",
)
_SENTINEL = cast(None, object())
SAFE_MSGPACK_TYPES: frozenset[tuple[str, ...]] = frozenset(
{
# datetime types
("datetime", "datetime"),
("datetime", "date"),
("datetime", "time"),
("datetime", "timedelta"),
("datetime", "timezone"),
# uuid
("uuid", "UUID"),
# numeric
("decimal", "Decimal"),
# collections
("builtins", "set"),
("builtins", "frozenset"),
("collections", "deque"),
# ip addresses
("ipaddress", "IPv4Address"),
("ipaddress", "IPv4Interface"),
("ipaddress", "IPv4Network"),
("ipaddress", "IPv6Address"),
("ipaddress", "IPv6Interface"),
("ipaddress", "IPv6Network"),
# pathlib
("pathlib", "Path"),
("pathlib", "PosixPath"),
("pathlib", "WindowsPath"),
# pathlib in Python 3.13+
("pathlib._local", "Path"),
("pathlib._local", "PosixPath"),
("pathlib._local", "WindowsPath"),
# zoneinfo
("zoneinfo", "ZoneInfo"),
# regex
("re", "compile"),
# langchain-core messages (safe container types used by graph state)
("langchain_core.messages.base", "BaseMessage"),
("langchain_core.messages.base", "BaseMessageChunk"),
("langchain_core.messages.human", "HumanMessage"),
("langchain_core.messages.human", "HumanMessageChunk"),
("langchain_core.messages.ai", "AIMessage"),
("langchain_core.messages.ai", "AIMessageChunk"),
("langchain_core.messages.system", "SystemMessage"),
("langchain_core.messages.system", "SystemMessageChunk"),
("langchain_core.messages.chat", "ChatMessage"),
("langchain_core.messages.chat", "ChatMessageChunk"),
("langchain_core.messages.tool", "ToolMessage"),
("langchain_core.messages.tool", "ToolMessageChunk"),
("langchain_core.messages.function", "FunctionMessage"),
("langchain_core.messages.function", "FunctionMessageChunk"),
("langchain_core.messages.modifier", "RemoveMessage"),
# langchain-core document model
("langchain_core.documents.base", "Document"),
# langgraph
("langgraph.types", "Send"),
("langgraph.types", "Interrupt"),
("langgraph.types", "Command"),
("langgraph.types", "StateSnapshot"),
("langgraph.types", "PregelTask"),
("langgraph.types", "Overwrite"),
("langgraph.store.base", "Item"),
("langgraph.store.base", "GetOp"),
}
)
# Allowed (module, name, method) triples for EXT_METHOD_SINGLE_ARG.
# Only these specific method invocations are permitted during deserialization.
# This is separate from SAFE_MSGPACK_TYPES which only governs construction.
SAFE_MSGPACK_METHODS: frozenset[tuple[str, str, str]] = frozenset(
{
("datetime", "datetime", "fromisoformat"),
}
)
AllowedMsgpackModules = Iterable[tuple[str, ...] | type]

View File

@@ -0,0 +1,64 @@
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
class UntypedSerializerProtocol(Protocol):
"""Protocol for serialization and deserialization of objects."""
def dumps(self, obj: Any) -> bytes: ...
def loads(self, data: bytes) -> Any: ...
@runtime_checkable
class SerializerProtocol(Protocol):
"""Protocol for serialization and deserialization of objects.
- `dumps_typed`: Serialize an object to a tuple `(type, bytes)`.
- `loads_typed`: Deserialize an object from a tuple `(type, bytes)`.
Valid implementations include the `pickle`, `json` and `orjson` modules.
"""
def dumps_typed(self, obj: Any) -> tuple[str, bytes]: ...
def loads_typed(self, data: tuple[str, bytes]) -> Any: ...
class SerializerCompat(SerializerProtocol):
def __init__(self, serde: UntypedSerializerProtocol) -> None:
self.serde = serde
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
return type(obj).__name__, self.serde.dumps(obj)
def loads_typed(self, data: tuple[str, bytes]) -> Any:
return self.serde.loads(data[1])
def maybe_add_typed_methods(
serde: SerializerProtocol | UntypedSerializerProtocol,
) -> SerializerProtocol:
"""Wrap serde old serde implementations in a class with loads_typed and dumps_typed for backwards compatibility."""
if not isinstance(serde, SerializerProtocol):
return SerializerCompat(serde)
return serde
class CipherProtocol(Protocol):
"""Protocol for encryption and decryption of data.
- `encrypt`: Encrypt plaintext.
- `decrypt`: Decrypt ciphertext.
"""
def encrypt(self, plaintext: bytes) -> tuple[str, bytes]:
"""Encrypt plaintext. Returns a tuple `(cipher name, ciphertext)`."""
...
def decrypt(self, ciphername: str, ciphertext: bytes) -> bytes:
"""Decrypt ciphertext. Returns the plaintext."""
...

View File

@@ -0,0 +1,80 @@
import os
from typing import Any
from langgraph.checkpoint.serde.base import CipherProtocol, SerializerProtocol
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
class EncryptedSerializer(SerializerProtocol):
"""Serializer that encrypts and decrypts data using an encryption protocol."""
def __init__(
self, cipher: CipherProtocol, serde: SerializerProtocol = JsonPlusSerializer()
) -> None:
self.cipher = cipher
self.serde = serde
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
"""Serialize an object to a tuple `(type, bytes)` and encrypt the bytes."""
# serialize data
typ, data = self.serde.dumps_typed(obj)
# encrypt data
ciphername, ciphertext = self.cipher.encrypt(data)
# add cipher name to type
return f"{typ}+{ciphername}", ciphertext
def loads_typed(self, data: tuple[str, bytes]) -> Any:
enc_cipher, ciphertext = data
# unencrypted data
if "+" not in enc_cipher:
return self.serde.loads_typed(data)
# extract cipher name
typ, ciphername = enc_cipher.split("+", 1)
# decrypt data
decrypted_data = self.cipher.decrypt(ciphername, ciphertext)
# deserialize data
return self.serde.loads_typed((typ, decrypted_data))
@classmethod
def from_pycryptodome_aes(
cls, serde: SerializerProtocol = JsonPlusSerializer(), **kwargs: Any
) -> "EncryptedSerializer":
"""Create an `EncryptedSerializer` using AES encryption."""
try:
from Crypto.Cipher import AES
except ImportError:
raise ImportError(
"Pycryptodome is not installed. Please install it with `pip install pycryptodome`."
) from None
# check if AES key is provided
if "key" in kwargs:
key: bytes = kwargs.pop("key")
else:
key_str = os.getenv("LANGGRAPH_AES_KEY")
if key_str is None:
raise ValueError("LANGGRAPH_AES_KEY environment variable is not set.")
key = key_str.encode()
if len(key) not in (16, 24, 32):
raise ValueError("LANGGRAPH_AES_KEY must be 16, 24, or 32 bytes long.")
# set default mode to EAX if not provided
if kwargs.get("mode") is None:
kwargs["mode"] = AES.MODE_EAX
class PycryptodomeAesCipher(CipherProtocol):
def encrypt(self, plaintext: bytes) -> tuple[str, bytes]:
cipher = AES.new(key, **kwargs)
ciphertext, tag = cipher.encrypt_and_digest(plaintext)
return "aes", cipher.nonce + tag + ciphertext
def decrypt(self, ciphername: str, ciphertext: bytes) -> bytes:
assert ciphername == "aes", f"Unsupported cipher: {ciphername}"
nonce = ciphertext[:16]
tag = ciphertext[16:32]
actual_ciphertext = ciphertext[32:]
cipher = AES.new(key, **kwargs, nonce=nonce)
return cipher.decrypt_and_verify(actual_ciphertext, tag)
return cls(PycryptodomeAesCipher(), serde)

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
import logging
from collections.abc import Callable
from threading import Lock
from typing import TypedDict
from typing_extensions import NotRequired
logger = logging.getLogger(__name__)
class SerdeEvent(TypedDict):
kind: str
module: str
name: str
method: NotRequired[str]
SerdeEventListener = Callable[[SerdeEvent], None]
_listeners: list[SerdeEventListener] = []
_listeners_lock = Lock()
def register_serde_event_listener(listener: SerdeEventListener) -> Callable[[], None]:
"""Register a listener for serde allowlist events."""
with _listeners_lock:
_listeners.append(listener)
def unregister() -> None:
with _listeners_lock:
try:
_listeners.remove(listener)
except ValueError:
pass
return unregister
def emit_serde_event(event: SerdeEvent) -> None:
"""Emit a serde event to all listeners.
Listener failures are isolated and logged.
"""
with _listeners_lock:
listeners = tuple(_listeners)
for listener in listeners:
try:
listener(event)
except Exception:
logger.warning("Serde listener failed", exc_info=True)

View File

@@ -0,0 +1,827 @@
from __future__ import annotations
import copy
import dataclasses
import decimal
import importlib
import json
import logging
import pathlib
import pickle
import re
import sys
from collections import deque
from collections.abc import Callable, Iterable, Sequence
from datetime import date, datetime, time, timedelta, timezone
from enum import Enum
from inspect import isclass
from ipaddress import (
IPv4Address,
IPv4Interface,
IPv4Network,
IPv6Address,
IPv6Interface,
IPv6Network,
)
from typing import TYPE_CHECKING, Any, Literal, cast
from uuid import UUID
from zoneinfo import ZoneInfo
import ormsgpack
from langchain_core.load.load import Reviver
from langgraph.checkpoint.serde import _msgpack as _lg_msgpack
from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.event_hooks import emit_serde_event
from langgraph.checkpoint.serde.types import SendProtocol
from langgraph.store.base import Item
if TYPE_CHECKING:
from langgraph.checkpoint.serde._msgpack import (
AllowedMsgpackModules,
)
from langgraph.checkpoint.serde.types import SendProtocol
LC_REVIVER = Reviver()
EMPTY_BYTES = b""
logger = logging.getLogger(__name__)
class JsonPlusSerializer(SerializerProtocol):
"""Serializer that uses ormsgpack, with optional fallbacks.
!!! warning
Security note: This serializer is intended for use within the `BaseCheckpointSaver`
class and called within the Pregel loop. It should not be used on untrusted
python objects. If an attacker can write directly to your checkpoint database,
they may be able to trigger code execution when data is deserialized.
"""
def __init__(
self,
*,
pickle_fallback: bool = False,
allowed_json_modules: Iterable[tuple[str, ...]] | Literal[True] | None = None,
allowed_msgpack_modules: (
AllowedMsgpackModules | Literal[True] | None
) = _lg_msgpack._SENTINEL,
__unpack_ext_hook__: Callable[[int, bytes], Any] | None = None,
) -> None:
if allowed_msgpack_modules is _lg_msgpack._SENTINEL:
if _lg_msgpack.STRICT_MSGPACK_ENABLED:
allowed_msgpack_modules = None
else:
allowed_msgpack_modules = True
self.pickle_fallback = pickle_fallback
self._allowed_json_modules: set[tuple[str, ...]] | Literal[True] | None = (
_normalize_allowlist(allowed_json_modules)
)
self._allowed_msgpack_modules = _normalize_allowlist(allowed_msgpack_modules)
self._custom_unpack_ext_hook = __unpack_ext_hook__ is not None
self._unpack_ext_hook = (
__unpack_ext_hook__
if __unpack_ext_hook__ is not None
else _create_msgpack_ext_hook(self._allowed_msgpack_modules)
)
def with_msgpack_allowlist(
self, extra_allowlist: Iterable[tuple[str, ...] | type]
) -> JsonPlusSerializer:
"""Return a new serializer with a merged msgpack allowlist."""
base_allowlist = self._allowed_msgpack_modules
if base_allowlist is True or base_allowlist is False:
return self
elif base_allowlist:
base_allowlist = set(base_allowlist)
else:
base_allowlist = set()
extra = _normalize_module_keys(tuple(extra_allowlist))
merged = base_allowlist | extra
if merged == base_allowlist:
return self
allowed_msgpack_modules: AllowedMsgpackModules | Literal[True] | None
if merged:
allowed_msgpack_modules = tuple(merged)
elif isinstance(self._allowed_msgpack_modules, set):
allowed_msgpack_modules = tuple(self._allowed_msgpack_modules)
else:
allowed_msgpack_modules = self._allowed_msgpack_modules
clone = copy.copy(self)
clone._allowed_json_modules = _normalize_allowlist(self._allowed_json_modules)
clone._allowed_msgpack_modules = _normalize_allowlist(allowed_msgpack_modules)
if not clone._custom_unpack_ext_hook:
clone._unpack_ext_hook = _create_msgpack_ext_hook(
clone._allowed_msgpack_modules
)
return clone
def _encode_constructor_args(
self,
constructor: Callable | type[Any],
*,
method: None | str | Sequence[None | str] = None,
args: Sequence[Any] | None = None,
kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
out = {
"lc": 2,
"type": "constructor",
"id": (*constructor.__module__.split("."), constructor.__name__),
}
if method is not None:
out["method"] = method
if args is not None:
out["args"] = args
if kwargs is not None:
out["kwargs"] = kwargs
return out
def _reviver(self, value: dict[str, Any]) -> Any:
if self._allowed_json_modules and (
value.get("lc", None) == 2
and value.get("type", None) == "constructor"
and value.get("id", None) is not None
):
try:
return self._revive_lc2(value)
except InvalidModuleError as e:
logger.warning(
"Object %s is not in the deserialization allowlist.\n%s",
value["id"],
e.message,
)
return LC_REVIVER(value)
def _revive_lc2(self, value: dict[str, Any]) -> Any:
self._check_allowed_json_modules(value)
[*module, name] = value["id"]
try:
mod = importlib.import_module(".".join(module))
cls = getattr(mod, name)
method = value.get("method")
if isinstance(method, str):
methods = [getattr(cls, method)]
elif isinstance(method, list):
methods = [cls if m is None else getattr(cls, m) for m in method]
else:
methods = [cls]
args = value.get("args")
kwargs = value.get("kwargs")
for method in methods:
try:
if isclass(method) and issubclass(method, BaseException):
return None
if args and kwargs:
return method(*args, **kwargs)
elif args:
return method(*args)
elif kwargs:
return method(**kwargs)
else:
return method()
except Exception:
continue
except Exception:
return None
def _check_allowed_json_modules(self, value: dict[str, Any]) -> None:
needed = tuple(value["id"])
method = value.get("method")
if isinstance(method, list):
method_display = ",".join(m or "<init>" for m in method)
elif isinstance(method, str):
method_display = method
else:
method_display = "<init>"
dotted = ".".join(needed)
if not self._allowed_json_modules:
raise InvalidModuleError(
f"Refused to deserialize JSON constructor: {dotted} (method: {method_display}). "
"No allowed_json_modules configured.\n\n"
"Unblock with ONE of:\n"
f" • JsonPlusSerializer(allowed_json_modules=[{needed!r}, ...])\n"
" • (DANGEROUS) JsonPlusSerializer(allowed_json_modules=True)\n\n"
"Note: Prefix allowlists are intentionally unsupported; prefer exact symbols "
"or plain-JSON representations revived without import-time side effects."
)
if self._allowed_json_modules is True:
return
if needed in self._allowed_json_modules:
return
raise InvalidModuleError(
f"Refused to deserialize JSON constructor: {dotted} (method: {method_display}). "
"Symbol is not in the deserialization allowlist.\n\n"
"Add exactly this symbol to unblock:\n"
f" JsonPlusSerializer(allowed_json_modules=[{needed!r}, ...])\n"
"Or, as a last resort (DANGEROUS):\n"
" JsonPlusSerializer(allowed_json_modules=True)"
)
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
if obj is None:
return "null", EMPTY_BYTES
elif isinstance(obj, bytes):
return "bytes", obj
elif isinstance(obj, bytearray):
return "bytearray", obj
else:
try:
return "msgpack", _msgpack_enc(obj)
except ormsgpack.MsgpackEncodeError as exc:
if self.pickle_fallback:
return "pickle", pickle.dumps(obj)
raise exc
def loads_typed(self, data: tuple[str, bytes]) -> Any:
type_, data_ = data
if type_ == "null":
return None
elif type_ == "bytes":
return data_
elif type_ == "bytearray":
return bytearray(data_)
elif type_ == "json":
return json.loads(data_, object_hook=self._reviver)
elif type_ == "msgpack":
return ormsgpack.unpackb(
data_, ext_hook=self._unpack_ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
elif self.pickle_fallback and type_ == "pickle":
return pickle.loads(data_)
else:
raise NotImplementedError(f"Unknown serialization type: {type_}")
# --- msgpack ---
EXT_CONSTRUCTOR_SINGLE_ARG = 0
EXT_CONSTRUCTOR_POS_ARGS = 1
EXT_CONSTRUCTOR_KW_ARGS = 2
EXT_METHOD_SINGLE_ARG = 3
EXT_PYDANTIC_V1 = 4
EXT_PYDANTIC_V2 = 5
EXT_NUMPY_ARRAY = 6
def _msgpack_default(obj: Any) -> str | ormsgpack.Ext:
if hasattr(obj, "model_dump") and callable(obj.model_dump): # pydantic v2
return ormsgpack.Ext(
EXT_PYDANTIC_V2,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.model_dump(),
"model_validate_json",
),
),
)
elif hasattr(obj, "get_secret_value") and callable(obj.get_secret_value):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.get_secret_value(),
),
),
)
elif hasattr(obj, "dict") and callable(obj.dict): # pydantic v1
return ormsgpack.Ext(
EXT_PYDANTIC_V1,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.dict(),
),
),
)
elif hasattr(obj, "_asdict") and callable(obj._asdict): # namedtuple
return ormsgpack.Ext(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj._asdict(),
),
),
)
elif isinstance(obj, pathlib.Path):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.parts),
),
)
elif isinstance(obj, re.Pattern):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
("re", "compile", (obj.pattern, obj.flags)),
),
)
elif isinstance(obj, UUID):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.hex),
),
)
elif isinstance(obj, decimal.Decimal):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
),
)
elif isinstance(obj, (set, frozenset, deque)):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, tuple(obj)),
),
)
elif isinstance(obj, (IPv4Address, IPv4Interface, IPv4Network)):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
),
)
elif isinstance(obj, (IPv6Address, IPv6Interface, IPv6Network)):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
),
)
elif isinstance(obj, datetime):
return ormsgpack.Ext(
EXT_METHOD_SINGLE_ARG,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.isoformat(),
"fromisoformat",
),
),
)
elif isinstance(obj, timedelta):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
(obj.days, obj.seconds, obj.microseconds),
),
),
)
elif isinstance(obj, date):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
(obj.year, obj.month, obj.day),
),
),
)
elif isinstance(obj, time):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
{
"hour": obj.hour,
"minute": obj.minute,
"second": obj.second,
"microsecond": obj.microsecond,
"tzinfo": obj.tzinfo,
"fold": obj.fold,
},
),
),
)
elif isinstance(obj, timezone):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.__getinitargs__(), # type: ignore[attr-defined]
),
),
)
elif isinstance(obj, ZoneInfo):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.key),
),
)
elif isinstance(obj, Enum):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_SINGLE_ARG,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, obj.value),
),
)
elif isinstance(obj, SendProtocol):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_POS_ARGS,
_msgpack_enc(
(obj.__class__.__module__, obj.__class__.__name__, (obj.node, obj.arg)),
),
)
elif dataclasses.is_dataclass(obj):
# doesn't use dataclasses.asdict to avoid deepcopy and recursion
return ormsgpack.Ext(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
{
field.name: getattr(obj, field.name)
for field in dataclasses.fields(obj)
},
),
),
)
elif isinstance(obj, Item):
return ormsgpack.Ext(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
{k: getattr(obj, k) for k in obj.__slots__},
),
),
)
elif (np_mod := sys.modules.get("numpy")) is not None and isinstance(
obj, np_mod.ndarray
):
order = "F" if obj.flags.f_contiguous and not obj.flags.c_contiguous else "C"
if obj.flags.c_contiguous:
mv = memoryview(obj)
try:
meta = (obj.dtype.str, obj.shape, order, mv)
return ormsgpack.Ext(EXT_NUMPY_ARRAY, _msgpack_enc(meta))
finally:
mv.release()
else:
buf = obj.tobytes(order="A")
meta = (obj.dtype.str, obj.shape, order, buf)
return ormsgpack.Ext(EXT_NUMPY_ARRAY, _msgpack_enc(meta))
elif isinstance(obj, BaseException):
return repr(obj)
else:
raise TypeError(f"Object of type {obj.__class__.__name__} is not serializable")
def _create_msgpack_ext_hook(
allowed_modules: set[tuple[str, ...]] | Literal[True] | None,
) -> Callable[[int, bytes], Any]:
"""Create msgpack ext hook with allowlist.
Args:
allowed_modules: Set of (module, name) tuples that are allowed to be
deserialized, or True to allow all with warnings for unregistered types, or None to only allow safe types.
Returns:
An ext_hook function for use with ormsgpack.unpackb.
"""
def _check_allowed(module: str, name: str) -> bool:
"""Check if type is allowed. Returns True if allowed, False if blocked."""
key = (module, name)
if key in _lg_msgpack.SAFE_MSGPACK_TYPES:
return True
if allowed_modules is True:
# default is to warn but allow unregistered types
emit_serde_event(
{
"kind": "msgpack_unregistered_allowed",
"module": module,
"name": name,
}
)
logger.warning(
"Deserializing unregistered type %s.%s from checkpoint. "
"This will be blocked in a future version. "
"Add to allowed_msgpack_modules to silence: [(%r, %r)]",
module,
name,
module,
name,
)
return True
if allowed_modules is not None:
if key in allowed_modules:
return True
# strict mode blocks unregistered types
emit_serde_event(
{
"kind": "msgpack_blocked",
"module": module,
"name": name,
}
)
logger.warning(
"Blocked deserialization of %s.%s - not in allowed_msgpack_modules. "
"Add to allowed_msgpack_modules to allow: [(%r, %r)]",
module,
name,
module,
name,
)
return False
def _check_allowed_method(module: str, name: str, method: str) -> bool:
"""Check if a method invocation is allowed."""
key = (module, name, method)
if key in _lg_msgpack.SAFE_MSGPACK_METHODS:
return True
emit_serde_event(
{
"kind": "msgpack_method_blocked",
"module": module,
"name": name,
"method": method,
}
)
logger.warning(
"Blocked deserialization of method call %s.%s.%s - "
"not in allowed methods set.",
module,
name,
method,
)
return False
def ext_hook(code: int, data: bytes) -> Any:
if code == EXT_CONSTRUCTOR_SINGLE_ARG:
try:
tup = ormsgpack.unpackb(
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
if not _check_allowed(tup[0], tup[1]):
# We default to returning the raw data. If the user
# is using this in the context of a pydantic state, etc., then
# it would be validated upon construction.
return tup[2]
# module, name, arg
return getattr(importlib.import_module(tup[0]), tup[1])(tup[2])
except Exception:
return None
elif code == EXT_CONSTRUCTOR_POS_ARGS:
try:
tup = ormsgpack.unpackb(
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
if not _check_allowed(tup[0], tup[1]):
return tup[2]
# module, name, args
return getattr(importlib.import_module(tup[0]), tup[1])(*tup[2])
except Exception:
return None
elif code == EXT_CONSTRUCTOR_KW_ARGS:
try:
tup = ormsgpack.unpackb(
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
if not _check_allowed(tup[0], tup[1]):
return tup[2]
# module, name, kwargs
return getattr(importlib.import_module(tup[0]), tup[1])(**tup[2])
except Exception:
return None
elif code == EXT_METHOD_SINGLE_ARG:
try:
tup = ormsgpack.unpackb(
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
if not _check_allowed_method(tup[0], tup[1], tup[3]):
return tup[2]
# module, name, arg, method
return getattr(
getattr(importlib.import_module(tup[0]), tup[1]), tup[3]
)(tup[2])
except Exception:
return None
elif code == EXT_PYDANTIC_V1:
try:
tup = ormsgpack.unpackb(
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
if not _check_allowed(tup[0], tup[1]):
return tup[2]
# module, name, kwargs
cls = getattr(importlib.import_module(tup[0]), tup[1])
try:
return cls(**tup[2])
except Exception:
return cls.construct(**tup[2])
except Exception:
# for pydantic objects we can't find/reconstruct
# let's return the kwargs dict instead
try:
return tup[2]
except NameError:
return None
elif code == EXT_PYDANTIC_V2:
try:
tup = ormsgpack.unpackb(
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
if not _check_allowed(tup[0], tup[1]):
return tup[2]
# module, name, kwargs, method
cls = getattr(importlib.import_module(tup[0]), tup[1])
try:
return cls(**tup[2])
except Exception:
return cls.model_construct(**tup[2])
except Exception:
# for pydantic objects we can't find/reconstruct
# let's return the kwargs dict instead
try:
return tup[2]
except NameError:
return None
elif code == EXT_NUMPY_ARRAY:
try:
import numpy as _np
dtype_str, shape, order, buf = ormsgpack.unpackb(
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
)
arr = _np.frombuffer(buf, dtype=_np.dtype(dtype_str))
return arr.reshape(shape, order=order)
except Exception:
return None
return None
return ext_hook
# Aliasing in case anyone imported it directly
_msgpack_ext_hook = _create_msgpack_ext_hook(allowed_modules=None)
def _msgpack_ext_hook_to_json(code: int, data: bytes) -> Any:
if code == EXT_CONSTRUCTOR_SINGLE_ARG:
try:
tup = ormsgpack.unpackb(
data,
ext_hook=_msgpack_ext_hook_to_json,
option=ormsgpack.OPT_NON_STR_KEYS,
)
if tup[0] == "uuid" and tup[1] == "UUID":
hex_ = tup[2]
return (
f"{hex_[:8]}-{hex_[8:12]}-{hex_[12:16]}-{hex_[16:20]}-{hex_[20:]}"
)
# module, name, arg
return tup[2]
except Exception:
return
elif code == EXT_CONSTRUCTOR_POS_ARGS:
try:
tup = ormsgpack.unpackb(
data,
ext_hook=_msgpack_ext_hook_to_json,
option=ormsgpack.OPT_NON_STR_KEYS,
)
if tup[0] == "langgraph.types" and tup[1] == "Send":
from langgraph.types import Send # type: ignore
return Send(*tup[2])
# module, name, args
return tup[2]
except Exception:
return
elif code == EXT_CONSTRUCTOR_KW_ARGS:
try:
tup = ormsgpack.unpackb(
data,
ext_hook=_msgpack_ext_hook_to_json,
option=ormsgpack.OPT_NON_STR_KEYS,
)
# module, name, args
return tup[2]
except Exception:
return
elif code == EXT_METHOD_SINGLE_ARG:
try:
tup = ormsgpack.unpackb(
data,
ext_hook=_msgpack_ext_hook_to_json,
option=ormsgpack.OPT_NON_STR_KEYS,
)
# module, name, arg, method
return tup[2]
except Exception:
return
elif code == EXT_PYDANTIC_V1:
try:
tup = ormsgpack.unpackb(
data,
ext_hook=_msgpack_ext_hook_to_json,
option=ormsgpack.OPT_NON_STR_KEYS,
)
# module, name, kwargs
return tup[2]
except Exception:
# for pydantic objects we can't find/reconstruct
# let's return the kwargs dict instead
return
elif code == EXT_PYDANTIC_V2:
try:
tup = ormsgpack.unpackb(
data,
ext_hook=_msgpack_ext_hook_to_json,
option=ormsgpack.OPT_NON_STR_KEYS,
)
# module, name, kwargs, method
return tup[2]
except Exception:
return
elif code == EXT_NUMPY_ARRAY:
try:
import numpy as _np
dtype_str, shape, order, buf = ormsgpack.unpackb(
data,
ext_hook=_msgpack_ext_hook_to_json,
option=ormsgpack.OPT_NON_STR_KEYS,
)
arr = _np.frombuffer(buf, dtype=_np.dtype(dtype_str))
return arr.reshape(shape, order=order).tolist()
except Exception:
return
class InvalidModuleError(Exception):
"""Exception raised when a module is not in the allowlist."""
def __init__(self, message: str):
self.message = message
_option = (
ormsgpack.OPT_NON_STR_KEYS
| ormsgpack.OPT_PASSTHROUGH_DATACLASS
| ormsgpack.OPT_PASSTHROUGH_DATETIME
| ormsgpack.OPT_PASSTHROUGH_ENUM
| ormsgpack.OPT_PASSTHROUGH_UUID
| ormsgpack.OPT_REPLACE_SURROGATES
)
def _msgpack_enc(data: Any) -> bytes:
return ormsgpack.packb(data, default=_msgpack_default, option=_option)
def _normalize_allowlist(
allowlist: AllowedMsgpackModules | Literal[True] | None,
) -> set[tuple[str, ...]] | Literal[True] | None:
if allowlist is True:
return allowlist
elif allowlist:
return _normalize_module_keys(allowlist)
else:
return None
def _normalize_module_keys(
modules: AllowedMsgpackModules,
) -> set[tuple[str, ...]]:
normalized: set[tuple[str, ...]] = set()
for module in modules:
if isclass(module):
normalized.add((module.__module__, module.__name__))
else:
normalized.add(cast(tuple[str, ...], module))
return normalized

View File

@@ -0,0 +1,51 @@
from collections.abc import Sequence
from typing import (
Any,
Protocol,
TypeVar,
runtime_checkable,
)
from typing_extensions import Self
ERROR = "__error__"
SCHEDULED = "__scheduled__"
INTERRUPT = "__interrupt__"
RESUME = "__resume__"
TASKS = "__pregel_tasks"
Value = TypeVar("Value", covariant=True)
Update = TypeVar("Update", contravariant=True)
C = TypeVar("C")
class ChannelProtocol(Protocol[Value, Update, C]):
# Mirrors langgraph.channels.base.BaseChannel
@property
def ValueType(self) -> Any: ...
@property
def UpdateType(self) -> Any: ...
def checkpoint(self) -> C | None: ...
def from_checkpoint(self, checkpoint: C | None) -> Self: ...
def update(self, values: Sequence[Update]) -> bool: ...
def get(self) -> Value: ...
def consume(self) -> bool: ...
@runtime_checkable
class SendProtocol(Protocol):
# Mirrors langgraph.constants.Send
node: str
arg: Any
def __hash__(self) -> int: ...
def __repr__(self) -> str: ...
def __eq__(self, value: object) -> bool: ...