initial commit
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
4
venv/Lib/site-packages/langgraph/_internal/__init__.py
Normal file
4
venv/Lib/site-packages/langgraph/_internal/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""Internal modules for LangGraph.
|
||||
|
||||
This module is not part of the public API, and thus stability is not guaranteed.
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
31
venv/Lib/site-packages/langgraph/_internal/_cache.py
Normal file
31
venv/Lib/site-packages/langgraph/_internal/_cache.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Hashable, Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _freeze(obj: Any, depth: int = 10) -> Hashable:
|
||||
if isinstance(obj, Hashable) or depth <= 0:
|
||||
# already hashable, no need to freeze
|
||||
return obj
|
||||
elif isinstance(obj, Mapping):
|
||||
# sort keys so {"a":1,"b":2} == {"b":2,"a":1}
|
||||
return tuple(sorted((k, _freeze(v, depth - 1)) for k, v in obj.items()))
|
||||
elif isinstance(obj, Sequence):
|
||||
return tuple(_freeze(x, depth - 1) for x in obj)
|
||||
# numpy / pandas etc. can provide their own .tobytes()
|
||||
elif hasattr(obj, "tobytes"):
|
||||
return (
|
||||
type(obj).__name__,
|
||||
obj.tobytes(),
|
||||
obj.shape if hasattr(obj, "shape") else None,
|
||||
)
|
||||
return obj # strings, ints, dataclasses with frozen=True, etc.
|
||||
|
||||
|
||||
def default_cache_key(*args: Any, **kwargs: Any) -> str | bytes:
|
||||
"""Default cache key function that uses the arguments and keyword arguments to generate a hashable key."""
|
||||
import pickle
|
||||
|
||||
# protocol 5 strikes a good balance between speed and size
|
||||
return pickle.dumps((_freeze(args), _freeze(kwargs)), protocol=5, fix_imports=False)
|
||||
329
venv/Lib/site-packages/langgraph/_internal/_config.py
Normal file
329
venv/Lib/site-packages/langgraph/_internal/_config.py
Normal file
@@ -0,0 +1,329 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import ChainMap
|
||||
from collections.abc import Mapping, Sequence
|
||||
from os import getenv
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.runnables.config import (
|
||||
CONFIG_KEYS,
|
||||
COPIABLE_KEYS,
|
||||
var_child_runnable_config,
|
||||
)
|
||||
from langgraph.checkpoint.base import CheckpointMetadata
|
||||
|
||||
from langgraph._internal._constants import (
|
||||
CONF,
|
||||
CONFIG_KEY_CHECKPOINT_ID,
|
||||
CONFIG_KEY_CHECKPOINT_MAP,
|
||||
CONFIG_KEY_CHECKPOINT_NS,
|
||||
NS_END,
|
||||
NS_SEP,
|
||||
)
|
||||
|
||||
DEFAULT_RECURSION_LIMIT = int(getenv("LANGGRAPH_DEFAULT_RECURSION_LIMIT", "10000"))
|
||||
|
||||
|
||||
def recast_checkpoint_ns(ns: str) -> str:
|
||||
"""Remove task IDs from checkpoint namespace.
|
||||
|
||||
Args:
|
||||
ns: The checkpoint namespace with task IDs.
|
||||
|
||||
Returns:
|
||||
str: The checkpoint namespace without task IDs.
|
||||
"""
|
||||
return NS_SEP.join(
|
||||
part.split(NS_END)[0] for part in ns.split(NS_SEP) if not part.isdigit()
|
||||
)
|
||||
|
||||
|
||||
def patch_configurable(
|
||||
config: RunnableConfig | None, patch: dict[str, Any]
|
||||
) -> RunnableConfig:
|
||||
if config is None:
|
||||
return {CONF: patch}
|
||||
elif CONF not in config:
|
||||
return {**config, CONF: patch}
|
||||
else:
|
||||
return {**config, CONF: {**config[CONF], **patch}}
|
||||
|
||||
|
||||
def patch_checkpoint_map(
|
||||
config: RunnableConfig | None, metadata: CheckpointMetadata | None
|
||||
) -> RunnableConfig:
|
||||
if config is None:
|
||||
return config
|
||||
elif parents := (metadata.get("parents") if metadata else None):
|
||||
conf = config[CONF]
|
||||
return patch_configurable(
|
||||
config,
|
||||
{
|
||||
CONFIG_KEY_CHECKPOINT_MAP: {
|
||||
**parents,
|
||||
conf[CONFIG_KEY_CHECKPOINT_NS]: conf[CONFIG_KEY_CHECKPOINT_ID],
|
||||
},
|
||||
},
|
||||
)
|
||||
else:
|
||||
return config
|
||||
|
||||
|
||||
def merge_configs(*configs: RunnableConfig | None) -> RunnableConfig:
|
||||
"""Merge multiple configs into one.
|
||||
|
||||
Args:
|
||||
*configs: The configs to merge.
|
||||
|
||||
Returns:
|
||||
RunnableConfig: The merged config.
|
||||
"""
|
||||
base: RunnableConfig = {}
|
||||
# Even though the keys aren't literals, this is correct
|
||||
# because both dicts are the same type
|
||||
for config in configs:
|
||||
if config is None:
|
||||
continue
|
||||
for key, value in config.items():
|
||||
if not value:
|
||||
continue
|
||||
if key == "metadata":
|
||||
if base_value := base.get(key):
|
||||
base[key] = {**base_value, **value} # type: ignore
|
||||
else:
|
||||
base[key] = value # type: ignore[literal-required]
|
||||
elif key == "tags":
|
||||
if base_value := base.get(key):
|
||||
base[key] = [*base_value, *value] # type: ignore
|
||||
else:
|
||||
base[key] = value # type: ignore[literal-required]
|
||||
elif key == CONF:
|
||||
if base_value := base.get(key):
|
||||
base[key] = {**base_value, **value} # type: ignore[dict-item]
|
||||
else:
|
||||
base[key] = value
|
||||
elif key == "callbacks":
|
||||
base_callbacks = base.get("callbacks")
|
||||
# callbacks can be either None, list[handler] or manager
|
||||
# so merging two callbacks values has 6 cases
|
||||
if isinstance(value, list):
|
||||
if base_callbacks is None:
|
||||
base["callbacks"] = value.copy()
|
||||
elif isinstance(base_callbacks, list):
|
||||
base["callbacks"] = base_callbacks + value
|
||||
else:
|
||||
# base_callbacks is a manager
|
||||
mngr = base_callbacks.copy()
|
||||
for callback in value:
|
||||
mngr.add_handler(callback, inherit=True)
|
||||
base["callbacks"] = mngr
|
||||
elif isinstance(value, BaseCallbackManager):
|
||||
# value is a manager
|
||||
if base_callbacks is None:
|
||||
base["callbacks"] = value.copy()
|
||||
elif isinstance(base_callbacks, list):
|
||||
mngr = value.copy()
|
||||
for callback in base_callbacks:
|
||||
mngr.add_handler(callback, inherit=True)
|
||||
base["callbacks"] = mngr
|
||||
else:
|
||||
# base_callbacks is also a manager
|
||||
base["callbacks"] = base_callbacks.merge(value)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
elif key == "recursion_limit":
|
||||
if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT:
|
||||
base["recursion_limit"] = config["recursion_limit"]
|
||||
else:
|
||||
base[key] = config[key] # type: ignore[literal-required]
|
||||
if CONF not in base:
|
||||
base[CONF] = {}
|
||||
return base
|
||||
|
||||
|
||||
def patch_config(
|
||||
config: RunnableConfig | None,
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
recursion_limit: int | None = None,
|
||||
max_concurrency: int | None = None,
|
||||
run_name: str | None = None,
|
||||
configurable: dict[str, Any] | None = None,
|
||||
) -> RunnableConfig:
|
||||
"""Patch a config with new values.
|
||||
|
||||
Args:
|
||||
config: The config to patch.
|
||||
callbacks: The callbacks to set.
|
||||
recursion_limit: The recursion limit to set.
|
||||
max_concurrency: The max number of concurrent steps to run, which also applies to parallelized steps.
|
||||
run_name: The run name to set.
|
||||
configurable: The configurable to set.
|
||||
|
||||
Returns:
|
||||
RunnableConfig: The patched config.
|
||||
"""
|
||||
config = config.copy() if config is not None else {}
|
||||
if callbacks is not None:
|
||||
# If we're replacing callbacks, we need to unset run_name
|
||||
# As that should apply only to the same run as the original callbacks
|
||||
config["callbacks"] = callbacks
|
||||
if "run_name" in config:
|
||||
del config["run_name"]
|
||||
if "run_id" in config:
|
||||
del config["run_id"]
|
||||
if recursion_limit is not None:
|
||||
config["recursion_limit"] = recursion_limit
|
||||
if max_concurrency is not None:
|
||||
config["max_concurrency"] = max_concurrency
|
||||
if run_name is not None:
|
||||
config["run_name"] = run_name
|
||||
if configurable is not None:
|
||||
config[CONF] = {**config.get(CONF, {}), **configurable}
|
||||
return config
|
||||
|
||||
|
||||
def get_callback_manager_for_config(
|
||||
config: RunnableConfig, tags: Sequence[str] | None = None
|
||||
) -> CallbackManager:
|
||||
"""Get a callback manager for a config.
|
||||
|
||||
Args:
|
||||
config: The config.
|
||||
|
||||
Returns:
|
||||
CallbackManager: The callback manager.
|
||||
"""
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
|
||||
# merge tags
|
||||
all_tags = config.get("tags")
|
||||
if all_tags is not None and tags is not None:
|
||||
all_tags = [*all_tags, *tags]
|
||||
elif tags is not None:
|
||||
all_tags = list(tags)
|
||||
# use existing callbacks if they exist
|
||||
if (callbacks := config.get("callbacks")) and isinstance(
|
||||
callbacks, CallbackManager
|
||||
):
|
||||
if all_tags:
|
||||
callbacks.add_tags(all_tags)
|
||||
if metadata := config.get("metadata"):
|
||||
callbacks.add_metadata(metadata)
|
||||
return callbacks
|
||||
else:
|
||||
# otherwise create a new manager
|
||||
return CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=all_tags,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
|
||||
|
||||
def get_async_callback_manager_for_config(
|
||||
config: RunnableConfig,
|
||||
tags: Sequence[str] | None = None,
|
||||
) -> AsyncCallbackManager:
|
||||
"""Get an async callback manager for a config.
|
||||
|
||||
Args:
|
||||
config: The config.
|
||||
|
||||
Returns:
|
||||
AsyncCallbackManager: The async callback manager.
|
||||
"""
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
# merge tags
|
||||
all_tags = config.get("tags")
|
||||
if all_tags is not None and tags is not None:
|
||||
all_tags = [*all_tags, *tags]
|
||||
elif tags is not None:
|
||||
all_tags = list(tags)
|
||||
# use existing callbacks if they exist
|
||||
if (callbacks := config.get("callbacks")) and isinstance(
|
||||
callbacks, AsyncCallbackManager
|
||||
):
|
||||
if all_tags:
|
||||
callbacks.add_tags(all_tags)
|
||||
if metadata := config.get("metadata"):
|
||||
callbacks.add_metadata(metadata)
|
||||
return callbacks
|
||||
else:
|
||||
# otherwise create a new manager
|
||||
return AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=all_tags,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
|
||||
|
||||
def _is_not_empty(value: Any) -> bool:
|
||||
if isinstance(value, (list, tuple, dict)):
|
||||
return len(value) > 0
|
||||
else:
|
||||
return value is not None
|
||||
|
||||
|
||||
def ensure_config(*configs: RunnableConfig | None) -> RunnableConfig:
|
||||
"""Return a config with all keys, merging any provided configs.
|
||||
|
||||
Args:
|
||||
*configs: Configs to merge before ensuring defaults.
|
||||
|
||||
Returns:
|
||||
RunnableConfig: The merged and ensured config.
|
||||
"""
|
||||
empty = RunnableConfig(
|
||||
tags=[],
|
||||
metadata=ChainMap(),
|
||||
callbacks=None,
|
||||
recursion_limit=DEFAULT_RECURSION_LIMIT,
|
||||
configurable={},
|
||||
)
|
||||
if var_config := var_child_runnable_config.get():
|
||||
empty.update(
|
||||
{
|
||||
k: v.copy() if k in COPIABLE_KEYS else v # type: ignore[attr-defined]
|
||||
for k, v in var_config.items()
|
||||
if _is_not_empty(v)
|
||||
},
|
||||
)
|
||||
for config in configs:
|
||||
if config is None:
|
||||
continue
|
||||
for k, v in config.items():
|
||||
if _is_not_empty(v) and k in CONFIG_KEYS:
|
||||
if k == CONF:
|
||||
empty[k] = cast(dict, v).copy()
|
||||
else:
|
||||
empty[k] = v # type: ignore[literal-required]
|
||||
for k, v in config.items():
|
||||
if _is_not_empty(v) and k not in CONFIG_KEYS:
|
||||
empty[CONF][k] = v
|
||||
_empty_metadata = empty["metadata"]
|
||||
for key, value in empty[CONF].items():
|
||||
if _exclude_as_metadata(key, value, _empty_metadata):
|
||||
continue
|
||||
_empty_metadata[key] = value
|
||||
return empty
|
||||
|
||||
|
||||
_OMIT = ("key", "token", "secret", "password", "auth")
|
||||
|
||||
|
||||
def _exclude_as_metadata(key: str, value: Any, metadata: Mapping[str, Any]) -> bool:
|
||||
key_lower = key.casefold()
|
||||
return (
|
||||
key.startswith("__")
|
||||
or not isinstance(value, (str, int, float, bool))
|
||||
or key in metadata
|
||||
or any(substr in key_lower for substr in _OMIT)
|
||||
)
|
||||
112
venv/Lib/site-packages/langgraph/_internal/_constants.py
Normal file
112
venv/Lib/site-packages/langgraph/_internal/_constants.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Constants used for Pregel operations."""
|
||||
|
||||
import sys
|
||||
from typing import Literal, cast
|
||||
|
||||
# --- Reserved write keys ---
|
||||
INPUT = sys.intern("__input__")
|
||||
# for values passed as input to the graph
|
||||
INTERRUPT = sys.intern("__interrupt__")
|
||||
# for dynamic interrupts raised by nodes
|
||||
RESUME = sys.intern("__resume__")
|
||||
# for values passed to resume a node after an interrupt
|
||||
ERROR = sys.intern("__error__")
|
||||
# for errors raised by nodes
|
||||
NO_WRITES = sys.intern("__no_writes__")
|
||||
# marker to signal node didn't write anything
|
||||
TASKS = sys.intern("__pregel_tasks")
|
||||
# for Send objects returned by nodes/edges, corresponds to PUSH below
|
||||
RETURN = sys.intern("__return__")
|
||||
# for writes of a task where we simply record the return value
|
||||
PREVIOUS = sys.intern("__previous__")
|
||||
# the implicit branch that handles each node's Control values
|
||||
|
||||
|
||||
# --- Reserved cache namespaces ---
|
||||
CACHE_NS_WRITES = sys.intern("__pregel_ns_writes")
|
||||
# cache namespace for node writes
|
||||
|
||||
# --- Reserved config.configurable keys ---
|
||||
CONFIG_KEY_SEND = sys.intern("__pregel_send")
|
||||
# holds the `write` function that accepts writes to state/edges/reserved keys
|
||||
CONFIG_KEY_READ = sys.intern("__pregel_read")
|
||||
# holds the `read` function that returns a copy of the current state
|
||||
CONFIG_KEY_CALL = sys.intern("__pregel_call")
|
||||
# holds the `call` function that accepts a node/func, args and returns a future
|
||||
CONFIG_KEY_CHECKPOINTER = sys.intern("__pregel_checkpointer")
|
||||
# holds a `BaseCheckpointSaver` passed from parent graph to child graphs
|
||||
CONFIG_KEY_STREAM = sys.intern("__pregel_stream")
|
||||
# holds a `StreamProtocol` passed from parent graph to child graphs
|
||||
CONFIG_KEY_CACHE = sys.intern("__pregel_cache")
|
||||
# holds a `BaseCache` made available to subgraphs
|
||||
CONFIG_KEY_RESUMING = sys.intern("__pregel_resuming")
|
||||
# holds a boolean indicating if subgraphs should resume from a previous checkpoint
|
||||
CONFIG_KEY_TASK_ID = sys.intern("__pregel_task_id")
|
||||
# holds the task ID for the current task
|
||||
CONFIG_KEY_THREAD_ID = sys.intern("thread_id")
|
||||
# holds the thread ID for the current invocation
|
||||
CONFIG_KEY_CHECKPOINT_MAP = sys.intern("checkpoint_map")
|
||||
# holds a mapping of checkpoint_ns -> checkpoint_id for parent graphs
|
||||
CONFIG_KEY_CHECKPOINT_ID = sys.intern("checkpoint_id")
|
||||
# holds the current checkpoint_id, if any
|
||||
CONFIG_KEY_CHECKPOINT_NS = sys.intern("checkpoint_ns")
|
||||
# holds the current checkpoint_ns, "" for root graph
|
||||
CONFIG_KEY_NODE_FINISHED = sys.intern("__pregel_node_finished")
|
||||
# holds a callback to be called when a node is finished
|
||||
CONFIG_KEY_SCRATCHPAD = sys.intern("__pregel_scratchpad")
|
||||
# holds a mutable dict for temporary storage scoped to the current task
|
||||
CONFIG_KEY_RUNNER_SUBMIT = sys.intern("__pregel_runner_submit")
|
||||
# holds a function that receives tasks from runner, executes them and returns results
|
||||
CONFIG_KEY_DURABILITY = sys.intern("__pregel_durability")
|
||||
# holds the durability mode, one of "sync", "async", or "exit"
|
||||
CONFIG_KEY_RUNTIME = sys.intern("__pregel_runtime")
|
||||
# holds a `Runtime` instance with context, store, stream writer, etc.
|
||||
CONFIG_KEY_RESUME_MAP = sys.intern("__pregel_resume_map")
|
||||
# holds a mapping of task ns -> resume value for resuming tasks
|
||||
|
||||
# --- Other constants ---
|
||||
PUSH = sys.intern("__pregel_push")
|
||||
# denotes push-style tasks, ie. those created by Send objects
|
||||
PULL = sys.intern("__pregel_pull")
|
||||
# denotes pull-style tasks, ie. those triggered by edges
|
||||
NS_SEP = sys.intern("|")
|
||||
# for checkpoint_ns, separates each level (ie. graph|subgraph|subsubgraph)
|
||||
NS_END = sys.intern(":")
|
||||
# for checkpoint_ns, for each level, separates the namespace from the task_id
|
||||
CONF = cast(Literal["configurable"], sys.intern("configurable"))
|
||||
# key for the configurable dict in RunnableConfig
|
||||
NULL_TASK_ID = sys.intern("00000000-0000-0000-0000-000000000000")
|
||||
# the task_id to use for writes that are not associated with a task
|
||||
OVERWRITE = sys.intern("__overwrite__")
|
||||
# dict key for the overwrite value, used as `{'__overwrite__': value}`
|
||||
|
||||
# redefined to avoid circular import with langgraph.constants
|
||||
_TAG_HIDDEN = sys.intern("langsmith:hidden")
|
||||
|
||||
RESERVED = {
|
||||
_TAG_HIDDEN,
|
||||
# reserved write keys
|
||||
INPUT,
|
||||
INTERRUPT,
|
||||
RESUME,
|
||||
ERROR,
|
||||
NO_WRITES,
|
||||
# reserved config.configurable keys
|
||||
CONFIG_KEY_SEND,
|
||||
CONFIG_KEY_READ,
|
||||
CONFIG_KEY_CHECKPOINTER,
|
||||
CONFIG_KEY_STREAM,
|
||||
CONFIG_KEY_CHECKPOINT_MAP,
|
||||
CONFIG_KEY_RESUMING,
|
||||
CONFIG_KEY_TASK_ID,
|
||||
CONFIG_KEY_CHECKPOINT_MAP,
|
||||
CONFIG_KEY_CHECKPOINT_ID,
|
||||
CONFIG_KEY_CHECKPOINT_NS,
|
||||
CONFIG_KEY_RESUME_MAP,
|
||||
# other constants
|
||||
PUSH,
|
||||
PULL,
|
||||
NS_SEP,
|
||||
NS_END,
|
||||
CONF,
|
||||
}
|
||||
213
venv/Lib/site-packages/langgraph/_internal/_fields.py
Normal file
213
venv/Lib/site-packages/langgraph/_internal/_fields.py
Normal file
@@ -0,0 +1,213 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import types
|
||||
import weakref
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Annotated, Any, Optional, Union, get_origin, get_type_hints
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import NotRequired, ReadOnly, Required
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
|
||||
|
||||
def _is_optional_type(type_: Any) -> bool:
|
||||
"""Check if a type is Optional."""
|
||||
|
||||
# Handle new union syntax (PEP 604): str | None
|
||||
if isinstance(type_, types.UnionType):
|
||||
return any(
|
||||
arg is type(None) or _is_optional_type(arg) for arg in type_.__args__
|
||||
)
|
||||
|
||||
if hasattr(type_, "__origin__") and hasattr(type_, "__args__"):
|
||||
origin = get_origin(type_)
|
||||
if origin is Optional:
|
||||
return True
|
||||
if origin is Union:
|
||||
return any(
|
||||
arg is type(None) or _is_optional_type(arg) for arg in type_.__args__
|
||||
)
|
||||
if origin is Annotated:
|
||||
return _is_optional_type(type_.__args__[0])
|
||||
return origin is None
|
||||
if hasattr(type_, "__bound__") and type_.__bound__ is not None:
|
||||
return _is_optional_type(type_.__bound__)
|
||||
return type_ is None
|
||||
|
||||
|
||||
def _is_required_type(type_: Any) -> bool | None:
|
||||
"""Check if an annotation is marked as Required/NotRequired.
|
||||
|
||||
Returns:
|
||||
- True if required
|
||||
- False if not required
|
||||
- None if not annotated with either
|
||||
"""
|
||||
origin = get_origin(type_)
|
||||
if origin is Required:
|
||||
return True
|
||||
if origin is NotRequired:
|
||||
return False
|
||||
if origin is Annotated or getattr(origin, "__args__", None):
|
||||
# See https://typing.readthedocs.io/en/latest/spec/typeddict.html#interaction-with-annotated
|
||||
return _is_required_type(type_.__args__[0])
|
||||
return None
|
||||
|
||||
|
||||
def _is_readonly_type(type_: Any) -> bool:
|
||||
"""Check if an annotation is marked as ReadOnly.
|
||||
|
||||
Returns:
|
||||
- True if is read only
|
||||
- False if not read only
|
||||
"""
|
||||
|
||||
# See: https://typing.readthedocs.io/en/latest/spec/typeddict.html#typing-readonly-type-qualifier
|
||||
origin = get_origin(type_)
|
||||
if origin is Annotated:
|
||||
return _is_readonly_type(type_.__args__[0])
|
||||
if origin is ReadOnly:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
_DEFAULT_KEYS: frozenset[str] = frozenset()
|
||||
|
||||
|
||||
def get_field_default(name: str, type_: Any, schema: type[Any]) -> Any:
|
||||
"""Determine the default value for a field in a state schema.
|
||||
|
||||
This is based on:
|
||||
If TypedDict:
|
||||
- Required/NotRequired
|
||||
- total=False -> everything optional
|
||||
- Type annotation (Optional/Union[None])
|
||||
"""
|
||||
optional_keys = getattr(schema, "__optional_keys__", _DEFAULT_KEYS)
|
||||
irq = _is_required_type(type_)
|
||||
if name in optional_keys:
|
||||
# Either total=False or explicit NotRequired.
|
||||
# No type annotation trumps this.
|
||||
if irq:
|
||||
# Unless it's earlier versions of python & explicit Required
|
||||
return ...
|
||||
return None
|
||||
if irq is not None:
|
||||
if irq:
|
||||
# Handle Required[<type>]
|
||||
# (we already handled NotRequired and total=False)
|
||||
return ...
|
||||
# Handle NotRequired[<type>] for earlier versions of python
|
||||
return None
|
||||
if dataclasses.is_dataclass(schema):
|
||||
field_info = next(
|
||||
(f for f in dataclasses.fields(schema) if f.name == name), None
|
||||
)
|
||||
if field_info:
|
||||
if (
|
||||
field_info.default is not dataclasses.MISSING
|
||||
and field_info.default is not ...
|
||||
):
|
||||
return field_info.default
|
||||
elif field_info.default_factory is not dataclasses.MISSING:
|
||||
return field_info.default_factory()
|
||||
# Note, we ignore ReadOnly attributes,
|
||||
# as they don't make much sense. (we don't care if you mutate the state in your node)
|
||||
# and mutating state in your node has no effect on our graph state.
|
||||
# Base case is the annotation
|
||||
if _is_optional_type(type_):
|
||||
return None
|
||||
return ...
|
||||
|
||||
|
||||
def get_enhanced_type_hints(
|
||||
type: type[Any],
|
||||
) -> Generator[tuple[str, Any, Any, str | None], None, None]:
|
||||
"""Attempt to extract default values and descriptions from provided type, used for config schema."""
|
||||
for name, typ in get_type_hints(type).items():
|
||||
default = None
|
||||
description = None
|
||||
|
||||
# Pydantic models
|
||||
try:
|
||||
if hasattr(type, "model_fields") and name in type.model_fields:
|
||||
field = type.model_fields[name]
|
||||
|
||||
if hasattr(field, "description") and field.description is not None:
|
||||
description = field.description
|
||||
|
||||
if hasattr(field, "default") and field.default is not None:
|
||||
default = field.default
|
||||
if (
|
||||
hasattr(default, "__class__")
|
||||
and getattr(default.__class__, "__name__", "")
|
||||
== "PydanticUndefinedType"
|
||||
):
|
||||
default = None
|
||||
|
||||
except (AttributeError, KeyError, TypeError):
|
||||
pass
|
||||
|
||||
# TypedDict, dataclass
|
||||
try:
|
||||
if hasattr(type, "__dict__"):
|
||||
type_dict = getattr(type, "__dict__")
|
||||
|
||||
if name in type_dict:
|
||||
default = type_dict[name]
|
||||
except (AttributeError, KeyError, TypeError):
|
||||
pass
|
||||
|
||||
yield name, typ, default, description
|
||||
|
||||
|
||||
def get_update_as_tuples(input: Any, keys: Sequence[str]) -> list[tuple[str, Any]]:
|
||||
"""Get Pydantic state update as a list of (key, value) tuples."""
|
||||
if isinstance(input, BaseModel):
|
||||
keep = input.model_fields_set
|
||||
defaults = {k: v.default for k, v in type(input).model_fields.items()}
|
||||
else:
|
||||
keep = None
|
||||
defaults = {}
|
||||
|
||||
# NOTE: This behavior for Pydantic is somewhat inelegant,
|
||||
# but we keep around for backwards compatibility
|
||||
# if input is a Pydantic model, only update values
|
||||
# that are different from the default values or in the keep set
|
||||
return [
|
||||
(k, value)
|
||||
for k in keys
|
||||
if (value := getattr(input, k, MISSING)) is not MISSING
|
||||
and (
|
||||
value is not None
|
||||
or defaults.get(k, MISSING) is not None
|
||||
or (keep is not None and k in keep)
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
ANNOTATED_KEYS_CACHE: weakref.WeakKeyDictionary[type[Any], tuple[str, ...]] = (
|
||||
weakref.WeakKeyDictionary()
|
||||
)
|
||||
|
||||
|
||||
def get_cached_annotated_keys(obj: type[Any]) -> tuple[str, ...]:
|
||||
"""Return cached annotated keys for a Python class."""
|
||||
if obj in ANNOTATED_KEYS_CACHE:
|
||||
return ANNOTATED_KEYS_CACHE[obj]
|
||||
if isinstance(obj, type):
|
||||
keys: list[str] = []
|
||||
for base in reversed(obj.__mro__):
|
||||
ann = base.__dict__.get("__annotations__")
|
||||
# In Python 3.14+, Pydantic models use descriptors for __annotations__
|
||||
# so we need to fall back to getattr if __dict__.get returns None
|
||||
if ann is None:
|
||||
ann = getattr(base, "__annotations__", None)
|
||||
if ann is None or isinstance(ann, types.GetSetDescriptorType):
|
||||
continue
|
||||
keys.extend(ann.keys())
|
||||
return ANNOTATED_KEYS_CACHE.setdefault(obj, tuple(keys))
|
||||
else:
|
||||
raise TypeError(f"Expected a type, got {type(obj)}. ")
|
||||
220
venv/Lib/site-packages/langgraph/_internal/_future.py
Normal file
220
venv/Lib/site-packages/langgraph/_internal/_future.py
Normal file
@@ -0,0 +1,220 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import contextvars
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
from collections.abc import Awaitable, Coroutine, Generator
|
||||
from typing import TypeVar, cast
|
||||
|
||||
T = TypeVar("T")
|
||||
AnyFuture = asyncio.Future | concurrent.futures.Future
|
||||
|
||||
CONTEXT_NOT_SUPPORTED = sys.version_info < (3, 11)
|
||||
EAGER_NOT_SUPPORTED = sys.version_info < (3, 12)
|
||||
|
||||
|
||||
def _get_loop(fut: asyncio.Future) -> asyncio.AbstractEventLoop:
|
||||
# Tries to call Future.get_loop() if it's available.
|
||||
# Otherwise fallbacks to using the old '_loop' property.
|
||||
try:
|
||||
get_loop = fut.get_loop
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
return get_loop()
|
||||
return fut._loop
|
||||
|
||||
|
||||
def _convert_future_exc(exc: BaseException) -> BaseException:
|
||||
exc_class = type(exc)
|
||||
if exc_class is concurrent.futures.CancelledError:
|
||||
return asyncio.CancelledError(*exc.args)
|
||||
elif exc_class is concurrent.futures.TimeoutError:
|
||||
return asyncio.TimeoutError(*exc.args)
|
||||
elif exc_class is concurrent.futures.InvalidStateError:
|
||||
return asyncio.InvalidStateError(*exc.args)
|
||||
else:
|
||||
return exc
|
||||
|
||||
|
||||
def _set_concurrent_future_state(
|
||||
concurrent: concurrent.futures.Future,
|
||||
source: AnyFuture,
|
||||
) -> None:
|
||||
"""Copy state from a future to a concurrent.futures.Future."""
|
||||
assert source.done()
|
||||
if source.cancelled():
|
||||
concurrent.cancel()
|
||||
if not concurrent.set_running_or_notify_cancel():
|
||||
return
|
||||
exception = source.exception()
|
||||
if exception is not None:
|
||||
concurrent.set_exception(_convert_future_exc(exception))
|
||||
else:
|
||||
result = source.result()
|
||||
concurrent.set_result(result)
|
||||
|
||||
|
||||
def _copy_future_state(source: AnyFuture, dest: asyncio.Future) -> None:
|
||||
"""Internal helper to copy state from another Future.
|
||||
|
||||
The other Future may be a concurrent.futures.Future.
|
||||
"""
|
||||
if dest.done():
|
||||
return
|
||||
assert source.done()
|
||||
if dest.cancelled():
|
||||
return
|
||||
if source.cancelled():
|
||||
dest.cancel()
|
||||
else:
|
||||
exception = source.exception()
|
||||
if exception is not None:
|
||||
dest.set_exception(_convert_future_exc(exception))
|
||||
else:
|
||||
result = source.result()
|
||||
dest.set_result(result)
|
||||
|
||||
|
||||
def _chain_future(source: AnyFuture, destination: AnyFuture) -> None:
|
||||
"""Chain two futures so that when one completes, so does the other.
|
||||
|
||||
The result (or exception) of source will be copied to destination.
|
||||
If destination is cancelled, source gets cancelled too.
|
||||
Compatible with both asyncio.Future and concurrent.futures.Future.
|
||||
"""
|
||||
if not asyncio.isfuture(source) and not isinstance(
|
||||
source, concurrent.futures.Future
|
||||
):
|
||||
raise TypeError("A future is required for source argument")
|
||||
if not asyncio.isfuture(destination) and not isinstance(
|
||||
destination, concurrent.futures.Future
|
||||
):
|
||||
raise TypeError("A future is required for destination argument")
|
||||
source_loop = _get_loop(source) if asyncio.isfuture(source) else None
|
||||
dest_loop = _get_loop(destination) if asyncio.isfuture(destination) else None
|
||||
|
||||
def _set_state(future: AnyFuture, other: AnyFuture) -> None:
|
||||
if asyncio.isfuture(future):
|
||||
_copy_future_state(other, future)
|
||||
else:
|
||||
_set_concurrent_future_state(future, other)
|
||||
|
||||
def _call_check_cancel(destination: AnyFuture) -> None:
|
||||
if destination.cancelled():
|
||||
if source_loop is None or source_loop is dest_loop:
|
||||
source.cancel()
|
||||
else:
|
||||
source_loop.call_soon_threadsafe(source.cancel)
|
||||
|
||||
def _call_set_state(source: AnyFuture) -> None:
|
||||
if destination.cancelled() and dest_loop is not None and dest_loop.is_closed():
|
||||
return
|
||||
if dest_loop is None or dest_loop is source_loop:
|
||||
_set_state(destination, source)
|
||||
else:
|
||||
if dest_loop.is_closed():
|
||||
return
|
||||
dest_loop.call_soon_threadsafe(_set_state, destination, source)
|
||||
|
||||
destination.add_done_callback(_call_check_cancel)
|
||||
source.add_done_callback(_call_set_state)
|
||||
|
||||
|
||||
def chain_future(source: AnyFuture, destination: AnyFuture) -> AnyFuture:
|
||||
# adapted from asyncio.run_coroutine_threadsafe
|
||||
try:
|
||||
_chain_future(source, destination)
|
||||
return destination
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
raise
|
||||
except BaseException as exc:
|
||||
if isinstance(destination, concurrent.futures.Future):
|
||||
if destination.set_running_or_notify_cancel():
|
||||
destination.set_exception(exc)
|
||||
else:
|
||||
destination.set_exception(exc)
|
||||
raise
|
||||
|
||||
|
||||
def _ensure_future(
|
||||
coro_or_future: Coroutine[None, None, T] | Awaitable[T],
|
||||
*,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
name: str | None = None,
|
||||
context: contextvars.Context | None = None,
|
||||
lazy: bool = True,
|
||||
) -> asyncio.Task[T]:
|
||||
called_wrap_awaitable = False
|
||||
if not asyncio.iscoroutine(coro_or_future):
|
||||
if inspect.isawaitable(coro_or_future):
|
||||
coro_or_future = cast(
|
||||
Coroutine[None, None, T], _wrap_awaitable(coro_or_future)
|
||||
)
|
||||
called_wrap_awaitable = True
|
||||
else:
|
||||
raise TypeError(
|
||||
"An asyncio.Future, a coroutine or an awaitable is required."
|
||||
f" Got {type(coro_or_future).__name__} instead."
|
||||
)
|
||||
|
||||
try:
|
||||
if CONTEXT_NOT_SUPPORTED:
|
||||
return loop.create_task(coro_or_future, name=name)
|
||||
elif EAGER_NOT_SUPPORTED or lazy:
|
||||
return loop.create_task(coro_or_future, name=name, context=context)
|
||||
else:
|
||||
return asyncio.eager_task_factory(
|
||||
loop, coro_or_future, name=name, context=context
|
||||
)
|
||||
except RuntimeError:
|
||||
if not called_wrap_awaitable:
|
||||
coro_or_future.close()
|
||||
raise
|
||||
|
||||
|
||||
@types.coroutine
|
||||
def _wrap_awaitable(awaitable: Awaitable[T]) -> Generator[None, None, T]:
|
||||
"""Helper for asyncio.ensure_future().
|
||||
|
||||
Wraps awaitable (an object with __await__) into a coroutine
|
||||
that will later be wrapped in a Task by ensure_future().
|
||||
"""
|
||||
return (yield from awaitable.__await__())
|
||||
|
||||
|
||||
def run_coroutine_threadsafe(
|
||||
coro: Coroutine[None, None, T],
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
*,
|
||||
lazy: bool,
|
||||
name: str | None = None,
|
||||
context: contextvars.Context | None = None,
|
||||
) -> asyncio.Future[T]:
|
||||
"""Submit a coroutine object to a given event loop.
|
||||
|
||||
Return an asyncio.Future to access the result.
|
||||
"""
|
||||
|
||||
if asyncio._get_running_loop() is loop:
|
||||
return _ensure_future(coro, loop=loop, name=name, context=context, lazy=lazy)
|
||||
else:
|
||||
future: asyncio.Future[T] = asyncio.Future(loop=loop)
|
||||
|
||||
def callback() -> None:
|
||||
try:
|
||||
chain_future(
|
||||
_ensure_future(coro, loop=loop, name=name, context=context),
|
||||
future,
|
||||
)
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
raise
|
||||
except BaseException as exc:
|
||||
future.set_exception(exc)
|
||||
raise
|
||||
|
||||
loop.call_soon_threadsafe(callback, context=context)
|
||||
return future
|
||||
275
venv/Lib/site-packages/langgraph/_internal/_pydantic.py
Normal file
275
venv/Lib/site-packages/langgraph/_internal/_pydantic.py
Normal file
@@ -0,0 +1,275 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import is_dataclass
|
||||
from functools import lru_cache
|
||||
from typing import (
|
||||
Any,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
)
|
||||
from pydantic import (
|
||||
create_model as _create_model_base,
|
||||
)
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic.json_schema import (
|
||||
DEFAULT_REF_TEMPLATE,
|
||||
GenerateJsonSchema,
|
||||
JsonSchemaMode,
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
@overload
|
||||
def get_fields(model: type[BaseModel]) -> dict[str, FieldInfo]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def get_fields(model: BaseModel) -> dict[str, FieldInfo]: ...
|
||||
|
||||
|
||||
def get_fields(
|
||||
model: type[BaseModel] | BaseModel,
|
||||
) -> dict[str, FieldInfo]:
|
||||
"""Get the field names of a Pydantic model."""
|
||||
if hasattr(model, "model_fields"):
|
||||
return model.model_fields
|
||||
|
||||
if hasattr(model, "__fields__"):
|
||||
return model.__fields__
|
||||
msg = f"Expected a Pydantic model. Got {type(model)}"
|
||||
raise TypeError(msg)
|
||||
|
||||
|
||||
_SchemaConfig = ConfigDict(
|
||||
arbitrary_types_allowed=True, frozen=True, protected_namespaces=()
|
||||
)
|
||||
|
||||
NO_DEFAULT = object()
|
||||
|
||||
|
||||
def _create_root_model(
|
||||
name: str,
|
||||
type_: Any,
|
||||
module_name: str | None = None,
|
||||
default_: object = NO_DEFAULT,
|
||||
) -> type[BaseModel]:
|
||||
"""Create a base class."""
|
||||
|
||||
def schema(
|
||||
cls: type[BaseModel],
|
||||
by_alias: bool = True, # noqa: FBT001,FBT002
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
) -> dict[str, Any]:
|
||||
# Complains about schema not being defined in superclass
|
||||
schema_ = super(cls, cls).schema( # type: ignore[misc]
|
||||
by_alias=by_alias, ref_template=ref_template
|
||||
)
|
||||
schema_["title"] = name
|
||||
return schema_
|
||||
|
||||
def model_json_schema(
|
||||
cls: type[BaseModel],
|
||||
by_alias: bool = True, # noqa: FBT001,FBT002
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
|
||||
mode: JsonSchemaMode = "validation",
|
||||
) -> dict[str, Any]:
|
||||
# Complains about model_json_schema not being defined in superclass
|
||||
schema_ = super(cls, cls).model_json_schema( # type: ignore[misc]
|
||||
by_alias=by_alias,
|
||||
ref_template=ref_template,
|
||||
schema_generator=schema_generator,
|
||||
mode=mode,
|
||||
)
|
||||
schema_["title"] = name
|
||||
return schema_
|
||||
|
||||
base_class_attributes = {
|
||||
"__annotations__": {"root": type_},
|
||||
"model_config": ConfigDict(arbitrary_types_allowed=True),
|
||||
"schema": classmethod(schema),
|
||||
"model_json_schema": classmethod(model_json_schema),
|
||||
"__module__": module_name or "langchain_core.runnables.utils",
|
||||
}
|
||||
|
||||
if default_ is not NO_DEFAULT:
|
||||
base_class_attributes["root"] = default_
|
||||
with warnings.catch_warnings():
|
||||
custom_root_type = type(name, (RootModel,), base_class_attributes)
|
||||
return cast("type[BaseModel]", custom_root_type)
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def _create_root_model_cached(
|
||||
model_name: str,
|
||||
type_: Any,
|
||||
*,
|
||||
module_name: str | None = None,
|
||||
default_: object = NO_DEFAULT,
|
||||
) -> type[BaseModel]:
|
||||
return _create_root_model(
|
||||
model_name, type_, default_=default_, module_name=module_name
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def _create_model_cached(
|
||||
model_name: str,
|
||||
/,
|
||||
**field_definitions: Any,
|
||||
) -> type[BaseModel]:
|
||||
return _create_model_base(
|
||||
model_name,
|
||||
__config__=_SchemaConfig,
|
||||
**_remap_field_definitions(field_definitions),
|
||||
)
|
||||
|
||||
|
||||
# Reserved names should capture all the `public` names / methods that are
|
||||
# used by BaseModel internally. This will keep the reserved names up-to-date.
|
||||
# For reference, the reserved names are:
|
||||
# "construct", "copy", "dict", "from_orm", "json", "parse_file", "parse_obj",
|
||||
# "parse_raw", "schema", "schema_json", "update_forward_refs", "validate",
|
||||
# "model_computed_fields", "model_config", "model_construct", "model_copy",
|
||||
# "model_dump", "model_dump_json", "model_extra", "model_fields",
|
||||
# "model_fields_set", "model_json_schema", "model_parametrized_name",
|
||||
# "model_post_init", "model_rebuild", "model_validate", "model_validate_json",
|
||||
# "model_validate_strings"
|
||||
_RESERVED_NAMES = {key for key in dir(BaseModel) if not key.startswith("_")}
|
||||
|
||||
|
||||
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
|
||||
"""This remaps fields to avoid colliding with internal pydantic fields."""
|
||||
|
||||
remapped = {}
|
||||
for key, value in field_definitions.items():
|
||||
if key.startswith("_") or key in _RESERVED_NAMES:
|
||||
# Let's add a prefix to avoid colliding with internal pydantic fields
|
||||
if isinstance(value, FieldInfo):
|
||||
msg = (
|
||||
f"Remapping for fields starting with '_' or fields with a name "
|
||||
f"matching a reserved name {_RESERVED_NAMES} is not supported if "
|
||||
f" the field is a pydantic Field instance. Got {key}."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
type_, default_ = value
|
||||
remapped[f"private_{key}"] = (
|
||||
type_,
|
||||
Field(
|
||||
default=default_,
|
||||
alias=key,
|
||||
serialization_alias=key,
|
||||
title=key.lstrip("_").replace("_", " ").title(),
|
||||
),
|
||||
)
|
||||
else:
|
||||
remapped[key] = value
|
||||
return remapped
|
||||
|
||||
|
||||
def create_model(
|
||||
model_name: str,
|
||||
*,
|
||||
field_definitions: dict[str, Any] | None = None,
|
||||
root: Any | None = None,
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic model with the given field definitions.
|
||||
|
||||
Attention:
|
||||
Please do not use outside of langchain packages. This API
|
||||
is subject to change at any time.
|
||||
|
||||
Args:
|
||||
model_name: The name of the model.
|
||||
module_name: The name of the module where the model is defined.
|
||||
This is used by Pydantic to resolve any forward references.
|
||||
field_definitions: The field definitions for the model.
|
||||
root: Type for a root model (RootModel)
|
||||
|
||||
Returns:
|
||||
Type[BaseModel]: The created model.
|
||||
"""
|
||||
field_definitions = field_definitions or {}
|
||||
|
||||
if root:
|
||||
if field_definitions:
|
||||
msg = (
|
||||
"When specifying __root__ no other "
|
||||
f"fields should be provided. Got {field_definitions}"
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
if isinstance(root, tuple):
|
||||
kwargs = {"type_": root[0], "default_": root[1]}
|
||||
else:
|
||||
kwargs = {"type_": root}
|
||||
|
||||
try:
|
||||
named_root_model = _create_root_model_cached(model_name, **kwargs)
|
||||
except TypeError:
|
||||
# something in the arguments into _create_root_model_cached is not hashable
|
||||
named_root_model = _create_root_model(
|
||||
model_name,
|
||||
**kwargs,
|
||||
)
|
||||
return named_root_model
|
||||
|
||||
# No root, just field definitions
|
||||
names = set(field_definitions.keys())
|
||||
|
||||
capture_warnings = False
|
||||
|
||||
for name in names:
|
||||
# Also if any non-reserved name is used (e.g., model_id or model_name)
|
||||
if name.startswith("model"):
|
||||
capture_warnings = True
|
||||
|
||||
with warnings.catch_warnings() if capture_warnings else nullcontext():
|
||||
if capture_warnings:
|
||||
warnings.filterwarnings(action="ignore")
|
||||
try:
|
||||
return _create_model_cached(model_name, **field_definitions)
|
||||
except TypeError:
|
||||
# something in field definitions is not hashable
|
||||
return _create_model_base(
|
||||
model_name,
|
||||
__config__=_SchemaConfig,
|
||||
**_remap_field_definitions(field_definitions),
|
||||
)
|
||||
|
||||
|
||||
def is_supported_by_pydantic(type_: Any) -> bool:
|
||||
"""Check if a given "complex" type is supported by pydantic.
|
||||
|
||||
This will return False for primitive types like int, str, etc.
|
||||
|
||||
The check is meant for container types like dataclasses, TypedDicts, etc.
|
||||
"""
|
||||
if is_dataclass(type_):
|
||||
return True
|
||||
|
||||
if isinstance(type_, type) and issubclass(type_, BaseModel):
|
||||
return True
|
||||
|
||||
if hasattr(type_, "__orig_bases__"):
|
||||
for base in type_.__orig_bases__:
|
||||
if base is TypedDict:
|
||||
return True
|
||||
elif base is typing.TypedDict: # noqa: TID251
|
||||
# ignoring TID251 since it's OK to use typing.TypedDict in this case.
|
||||
# Pydantic supports typing.TypedDict from Python 3.12
|
||||
# For older versions, only typing_extensions.TypedDict is supported.
|
||||
if sys.version_info >= (3, 12):
|
||||
return True
|
||||
return False
|
||||
124
venv/Lib/site-packages/langgraph/_internal/_queue.py
Normal file
124
venv/Lib/site-packages/langgraph/_internal/_queue.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# type: ignore
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import queue
|
||||
import threading
|
||||
import types
|
||||
from collections import deque
|
||||
from time import monotonic
|
||||
|
||||
|
||||
class AsyncQueue(asyncio.Queue):
|
||||
"""Async unbounded FIFO queue with a wait() method.
|
||||
|
||||
Subclassed from asyncio.Queue, adding a wait() method."""
|
||||
|
||||
async def wait(self) -> None:
|
||||
"""If queue is empty, wait until an item is available.
|
||||
|
||||
Copied from Queue.get(), removing the call to .get_nowait(),
|
||||
ie. this doesn't consume the item, just waits for it.
|
||||
"""
|
||||
while self.empty():
|
||||
getter = self._get_loop().create_future()
|
||||
self._getters.append(getter)
|
||||
try:
|
||||
await getter
|
||||
except BaseException:
|
||||
getter.cancel() # Just in case getter is not done yet.
|
||||
try:
|
||||
# Clean self._getters from canceled getters.
|
||||
self._getters.remove(getter)
|
||||
except ValueError:
|
||||
# The getter could be removed from self._getters by a
|
||||
# previous put_nowait call.
|
||||
pass
|
||||
if not self.empty() and not getter.cancelled():
|
||||
# We were woken up by put_nowait(), but can't take
|
||||
# the call. Wake up the next in line.
|
||||
self._wakeup_next(self._getters)
|
||||
raise
|
||||
|
||||
|
||||
class Semaphore(threading.Semaphore):
|
||||
"""Semaphore subclass with a wait() method."""
|
||||
|
||||
def wait(self, blocking: bool = True, timeout: float | None = None):
|
||||
"""Block until the semaphore can be acquired, but don't acquire it."""
|
||||
if not blocking and timeout is not None:
|
||||
raise ValueError("can't specify timeout for non-blocking acquire")
|
||||
rc = False
|
||||
endtime = None
|
||||
with self._cond:
|
||||
while self._value == 0:
|
||||
if not blocking:
|
||||
break
|
||||
if timeout is not None:
|
||||
if endtime is None:
|
||||
endtime = monotonic() + timeout
|
||||
else:
|
||||
timeout = endtime - monotonic()
|
||||
if timeout <= 0:
|
||||
break
|
||||
self._cond.wait(timeout)
|
||||
else:
|
||||
rc = True
|
||||
return rc
|
||||
|
||||
|
||||
class SyncQueue:
|
||||
"""Unbounded FIFO queue with a wait() method.
|
||||
Adapted from pure Python implementation of queue.SimpleQueue.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._queue = deque()
|
||||
self._count = Semaphore(0)
|
||||
|
||||
def put(self, item, block=True, timeout=None):
|
||||
"""Put the item on the queue.
|
||||
|
||||
The optional 'block' and 'timeout' arguments are ignored, as this method
|
||||
never blocks. They are provided for compatibility with the Queue class.
|
||||
"""
|
||||
self._queue.append(item)
|
||||
self._count.release()
|
||||
|
||||
def get(self, block=False, timeout=None):
|
||||
"""Remove and return an item from the queue.
|
||||
|
||||
If optional args 'block' is true and 'timeout' is None (the default),
|
||||
block if necessary until an item is available. If 'timeout' is
|
||||
a non-negative number, it blocks at most 'timeout' seconds and raises
|
||||
the Empty exception if no item was available within that time.
|
||||
Otherwise ('block' is false), return an item if one is immediately
|
||||
available, else raise the Empty exception ('timeout' is ignored
|
||||
in that case).
|
||||
"""
|
||||
if timeout is not None and timeout < 0:
|
||||
raise ValueError("'timeout' must be a non-negative number")
|
||||
if not self._count.acquire(block, timeout):
|
||||
raise queue.Empty
|
||||
try:
|
||||
return self._queue.popleft()
|
||||
except IndexError:
|
||||
raise queue.Empty
|
||||
|
||||
def wait(self, block=True, timeout=None):
|
||||
"""If queue is empty, wait until an item maybe is available,
|
||||
but don't consume it.
|
||||
"""
|
||||
if timeout is not None and timeout < 0:
|
||||
raise ValueError("'timeout' must be a non-negative number")
|
||||
self._count.wait(block, timeout)
|
||||
|
||||
def empty(self):
|
||||
"""Return True if the queue is empty, False otherwise (not reliable!)."""
|
||||
return len(self._queue) == 0
|
||||
|
||||
def qsize(self):
|
||||
"""Return the approximate size of the queue (not reliable!)."""
|
||||
return len(self._queue)
|
||||
|
||||
__class_getitem__ = classmethod(types.GenericAlias)
|
||||
29
venv/Lib/site-packages/langgraph/_internal/_retry.py
Normal file
29
venv/Lib/site-packages/langgraph/_internal/_retry.py
Normal file
@@ -0,0 +1,29 @@
|
||||
def default_retry_on(exc: Exception) -> bool:
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
if isinstance(exc, ConnectionError):
|
||||
return True
|
||||
if isinstance(exc, httpx.HTTPStatusError):
|
||||
return 500 <= exc.response.status_code < 600
|
||||
if isinstance(exc, requests.HTTPError):
|
||||
return 500 <= exc.response.status_code < 600 if exc.response else True
|
||||
if isinstance(
|
||||
exc,
|
||||
(
|
||||
ValueError,
|
||||
TypeError,
|
||||
ArithmeticError,
|
||||
ImportError,
|
||||
LookupError,
|
||||
NameError,
|
||||
SyntaxError,
|
||||
RuntimeError,
|
||||
ReferenceError,
|
||||
StopIteration,
|
||||
StopAsyncIteration,
|
||||
OSError,
|
||||
),
|
||||
):
|
||||
return False
|
||||
return True
|
||||
914
venv/Lib/site-packages/langgraph/_internal/_runnable.py
Normal file
914
venv/Lib/site-packages/langgraph/_internal/_runnable.py
Normal file
@@ -0,0 +1,914 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import enum
|
||||
import inspect
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import (
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Generator,
|
||||
Iterator,
|
||||
Sequence,
|
||||
)
|
||||
from contextlib import AsyncExitStack, contextmanager
|
||||
from contextvars import Context, Token, copy_context
|
||||
from functools import partial, wraps
|
||||
from typing import (
|
||||
Any,
|
||||
Optional,
|
||||
Protocol,
|
||||
TypeGuard,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.runnables.base import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableLambda,
|
||||
RunnableParallel,
|
||||
RunnableSequence,
|
||||
)
|
||||
from langchain_core.runnables.base import (
|
||||
RunnableLike as LCRunnableLike,
|
||||
)
|
||||
from langchain_core.runnables.config import (
|
||||
run_in_executor,
|
||||
var_child_runnable_config,
|
||||
)
|
||||
from langchain_core.runnables.utils import Input, Output
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from langgraph._internal._config import (
|
||||
ensure_config,
|
||||
get_async_callback_manager_for_config,
|
||||
get_callback_manager_for_config,
|
||||
patch_config,
|
||||
)
|
||||
from langgraph._internal._constants import (
|
||||
CONF,
|
||||
CONFIG_KEY_RUNTIME,
|
||||
)
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
try:
|
||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||
except ImportError:
|
||||
_StreamingCallbackHandler = None # type: ignore
|
||||
|
||||
|
||||
def _set_config_context(
|
||||
config: RunnableConfig, run: Any = None
|
||||
) -> Token[RunnableConfig | None]:
|
||||
"""Set the child Runnable config + tracing context.
|
||||
|
||||
Args:
|
||||
config: The config to set.
|
||||
"""
|
||||
config_token = var_child_runnable_config.set(config)
|
||||
if run is not None:
|
||||
from langsmith.run_helpers import _set_tracing_context
|
||||
|
||||
_set_tracing_context({"parent": run})
|
||||
return config_token
|
||||
|
||||
|
||||
def _unset_config_context(token: Token[RunnableConfig | None], run: Any = None) -> None:
|
||||
"""Set the child Runnable config + tracing context.
|
||||
|
||||
Args:
|
||||
token: The config token to reset.
|
||||
"""
|
||||
var_child_runnable_config.reset(token)
|
||||
if run is not None:
|
||||
from langsmith.run_helpers import _set_tracing_context
|
||||
|
||||
_set_tracing_context(
|
||||
{
|
||||
"parent": None,
|
||||
"project_name": None,
|
||||
"tags": None,
|
||||
"metadata": None,
|
||||
"enabled": None,
|
||||
"client": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_config_context(
|
||||
config: RunnableConfig, run: Any = None
|
||||
) -> Generator[Context, None, None]:
|
||||
"""Set the child Runnable config + tracing context.
|
||||
|
||||
Args:
|
||||
config: The config to set.
|
||||
"""
|
||||
ctx = copy_context()
|
||||
config_token = ctx.run(_set_config_context, config, run)
|
||||
try:
|
||||
yield ctx
|
||||
finally:
|
||||
ctx.run(_unset_config_context, config_token, run)
|
||||
|
||||
|
||||
# Before Python 3.11 native StrEnum is not available
|
||||
class StrEnum(str, enum.Enum):
|
||||
"""A string enum."""
|
||||
|
||||
|
||||
# Special type to denote any type is accepted
|
||||
ANY_TYPE = object()
|
||||
|
||||
ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11)
|
||||
|
||||
# List of keyword arguments that can be injected into nodes / tasks / tools at runtime.
|
||||
# A named argument may appear multiple times if it appears with distinct types.
|
||||
KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = (
|
||||
(
|
||||
"config",
|
||||
(
|
||||
RunnableConfig,
|
||||
"RunnableConfig",
|
||||
Optional[RunnableConfig], # noqa: UP045
|
||||
"Optional[RunnableConfig]",
|
||||
inspect.Parameter.empty,
|
||||
),
|
||||
# for now, use config directly, eventually, will pop off of Runtime
|
||||
"N/A",
|
||||
inspect.Parameter.empty,
|
||||
),
|
||||
(
|
||||
"writer",
|
||||
(StreamWriter, "StreamWriter", inspect.Parameter.empty),
|
||||
"stream_writer",
|
||||
lambda _: None,
|
||||
),
|
||||
(
|
||||
"store",
|
||||
(
|
||||
BaseStore,
|
||||
"BaseStore",
|
||||
inspect.Parameter.empty,
|
||||
),
|
||||
"store",
|
||||
inspect.Parameter.empty,
|
||||
),
|
||||
(
|
||||
"store",
|
||||
(
|
||||
Optional[BaseStore], # noqa: UP045
|
||||
"Optional[BaseStore]",
|
||||
),
|
||||
"store",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"previous",
|
||||
(ANY_TYPE,),
|
||||
"previous",
|
||||
inspect.Parameter.empty,
|
||||
),
|
||||
(
|
||||
"runtime",
|
||||
(ANY_TYPE,),
|
||||
# we never hit this block, we just inject runtime directly
|
||||
"N/A",
|
||||
inspect.Parameter.empty,
|
||||
),
|
||||
)
|
||||
"""List of kwargs that can be passed to functions, and their corresponding
|
||||
config keys, default values and type annotations.
|
||||
|
||||
Used to configure keyword arguments that can be injected at runtime
|
||||
from the `Runtime` object as kwargs to `invoke`, `ainvoke`, `stream` and `astream`.
|
||||
|
||||
For a keyword to be injected from the config object, the function signature
|
||||
must contain a kwarg with the same name and a matching type annotation.
|
||||
|
||||
Each tuple contains:
|
||||
- the name of the kwarg in the function signature
|
||||
- the type annotation(s) for the kwarg
|
||||
- the `Runtime` attribute for fetching the value (N/A if not applicable)
|
||||
|
||||
This is fully internal and should be further refactored to use `get_type_hints`
|
||||
to resolve forward references and optional types formatted like BaseStore | None.
|
||||
"""
|
||||
|
||||
VALID_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
|
||||
|
||||
|
||||
class _RunnableWithWriter(Protocol[Input, Output]):
|
||||
def __call__(self, state: Input, *, writer: StreamWriter) -> Output: ...
|
||||
|
||||
|
||||
class _RunnableWithStore(Protocol[Input, Output]):
|
||||
def __call__(self, state: Input, *, store: BaseStore) -> Output: ...
|
||||
|
||||
|
||||
class _RunnableWithWriterStore(Protocol[Input, Output]):
|
||||
def __call__(
|
||||
self, state: Input, *, writer: StreamWriter, store: BaseStore
|
||||
) -> Output: ...
|
||||
|
||||
|
||||
class _RunnableWithConfigWriter(Protocol[Input, Output]):
|
||||
def __call__(
|
||||
self, state: Input, *, config: RunnableConfig, writer: StreamWriter
|
||||
) -> Output: ...
|
||||
|
||||
|
||||
class _RunnableWithConfigStore(Protocol[Input, Output]):
|
||||
def __call__(
|
||||
self, state: Input, *, config: RunnableConfig, store: BaseStore
|
||||
) -> Output: ...
|
||||
|
||||
|
||||
class _RunnableWithConfigWriterStore(Protocol[Input, Output]):
|
||||
def __call__(
|
||||
self,
|
||||
state: Input,
|
||||
*,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter,
|
||||
store: BaseStore,
|
||||
) -> Output: ...
|
||||
|
||||
|
||||
RunnableLike = (
|
||||
LCRunnableLike
|
||||
| _RunnableWithWriter[Input, Output]
|
||||
| _RunnableWithStore[Input, Output]
|
||||
| _RunnableWithWriterStore[Input, Output]
|
||||
| _RunnableWithConfigWriter[Input, Output]
|
||||
| _RunnableWithConfigStore[Input, Output]
|
||||
| _RunnableWithConfigWriterStore[Input, Output]
|
||||
)
|
||||
|
||||
|
||||
class RunnableCallable(Runnable):
|
||||
"""A much simpler version of RunnableLambda that requires sync and async functions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[..., Any | Runnable] | None,
|
||||
afunc: Callable[..., Awaitable[Any | Runnable]] | None = None,
|
||||
*,
|
||||
name: str | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
trace: bool = True,
|
||||
recurse: bool = True,
|
||||
explode_args: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.name = name
|
||||
if self.name is None:
|
||||
if func:
|
||||
try:
|
||||
if func.__name__ != "<lambda>":
|
||||
self.name = func.__name__
|
||||
except AttributeError:
|
||||
pass
|
||||
elif afunc:
|
||||
try:
|
||||
self.name = afunc.__name__
|
||||
except AttributeError:
|
||||
pass
|
||||
self.func = func
|
||||
self.afunc = afunc
|
||||
self.tags = tags
|
||||
self.kwargs = kwargs
|
||||
self.trace = trace
|
||||
self.recurse = recurse
|
||||
self.explode_args = explode_args
|
||||
# check signature
|
||||
if func is None and afunc is None:
|
||||
raise ValueError("At least one of func or afunc must be provided.")
|
||||
|
||||
self.func_accepts: dict[str, tuple[str, Any]] = {}
|
||||
params = inspect.signature(cast(Callable, func or afunc)).parameters
|
||||
|
||||
for kw, typ, runtime_key, default in KWARGS_CONFIG_KEYS:
|
||||
p = params.get(kw)
|
||||
|
||||
if p is None or p.kind not in VALID_KINDS:
|
||||
# If parameter is not found or is not a valid kind, skip
|
||||
continue
|
||||
|
||||
if typ != (ANY_TYPE,) and p.annotation not in typ:
|
||||
# A specific type is required, but the function annotation does
|
||||
# not match the expected type.
|
||||
|
||||
# If this is a config parameter with incorrect typing, emit a warning
|
||||
# because we used to support any type but are moving towards more correct typing
|
||||
if kw == "config" and p.annotation != inspect.Parameter.empty:
|
||||
warnings.warn(
|
||||
f"The 'config' parameter should be typed as 'RunnableConfig' or "
|
||||
f"'RunnableConfig | None', not '{p.annotation}'. ",
|
||||
UserWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
continue
|
||||
|
||||
# If the kwarg is accepted by the function, store the key / runtime attribute to inject
|
||||
self.func_accepts[kw] = (runtime_key, default)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_args = {
|
||||
k: v
|
||||
for k, v in self.__dict__.items()
|
||||
if k not in {"name", "func", "afunc", "config", "kwargs", "trace"}
|
||||
}
|
||||
return f"{self.get_name()}({', '.join(f'{k}={v!r}' for k, v in repr_args.items())})"
|
||||
|
||||
def invoke(
|
||||
self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
|
||||
) -> Any:
|
||||
if self.func is None:
|
||||
raise TypeError(
|
||||
f'No synchronous function provided to "{self.name}".'
|
||||
"\nEither initialize with a synchronous function or invoke"
|
||||
" via the async API (ainvoke, astream, etc.)"
|
||||
)
|
||||
if config is None:
|
||||
config = ensure_config()
|
||||
if self.explode_args:
|
||||
args, _kwargs = input
|
||||
kwargs = {**self.kwargs, **_kwargs, **kwargs}
|
||||
else:
|
||||
args = (input,)
|
||||
kwargs = {**self.kwargs, **kwargs}
|
||||
|
||||
runtime = config.get(CONF, {}).get(CONFIG_KEY_RUNTIME)
|
||||
|
||||
for kw, (runtime_key, default) in self.func_accepts.items():
|
||||
# If the kwarg is already set, use the set value
|
||||
if kw in kwargs:
|
||||
continue
|
||||
|
||||
kw_value: Any = MISSING
|
||||
if kw == "config":
|
||||
kw_value = config
|
||||
elif runtime:
|
||||
if kw == "runtime":
|
||||
kw_value = runtime
|
||||
else:
|
||||
try:
|
||||
kw_value = getattr(runtime, runtime_key)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if kw_value is MISSING:
|
||||
if default is inspect.Parameter.empty:
|
||||
raise ValueError(
|
||||
f"Missing required config key '{runtime_key}' for '{self.name}'."
|
||||
)
|
||||
kw_value = default
|
||||
kwargs[kw] = kw_value
|
||||
|
||||
if self.trace:
|
||||
callback_manager = get_callback_manager_for_config(config, self.tags)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
None,
|
||||
input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
try:
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
# get the run
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, LangChainTracer):
|
||||
run = h.run_map.get(str(run_manager.run_id))
|
||||
break
|
||||
else:
|
||||
run = None
|
||||
# run in context
|
||||
with set_config_context(child_config, run) as context:
|
||||
ret = context.run(self.func, *args, **kwargs)
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
run_manager.on_chain_end(ret)
|
||||
else:
|
||||
ret = self.func(*args, **kwargs)
|
||||
if self.recurse and isinstance(ret, Runnable):
|
||||
return ret.invoke(input, config)
|
||||
return ret
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
|
||||
) -> Any:
|
||||
if not self.afunc:
|
||||
return self.invoke(input, config)
|
||||
if config is None:
|
||||
config = ensure_config()
|
||||
if self.explode_args:
|
||||
args, _kwargs = input
|
||||
kwargs = {**self.kwargs, **_kwargs, **kwargs}
|
||||
else:
|
||||
args = (input,)
|
||||
kwargs = {**self.kwargs, **kwargs}
|
||||
|
||||
runtime = config.get(CONF, {}).get(CONFIG_KEY_RUNTIME)
|
||||
|
||||
for kw, (runtime_key, default) in self.func_accepts.items():
|
||||
# If the kwarg has already been set, use the set value
|
||||
if kw in kwargs:
|
||||
continue
|
||||
|
||||
kw_value: Any = MISSING
|
||||
if kw == "config":
|
||||
kw_value = config
|
||||
elif runtime:
|
||||
if kw == "runtime":
|
||||
kw_value = runtime
|
||||
else:
|
||||
try:
|
||||
kw_value = getattr(runtime, runtime_key)
|
||||
except AttributeError:
|
||||
pass
|
||||
if kw_value is MISSING:
|
||||
if default is inspect.Parameter.empty:
|
||||
raise ValueError(
|
||||
f"Missing required config key '{runtime_key}' for '{self.name}'."
|
||||
)
|
||||
kw_value = default
|
||||
kwargs[kw] = kw_value
|
||||
|
||||
if self.trace:
|
||||
callback_manager = get_async_callback_manager_for_config(config, self.tags)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
None,
|
||||
input,
|
||||
name=config.get("run_name") or self.name,
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
try:
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
coro = cast(Coroutine[None, None, Any], self.afunc(*args, **kwargs))
|
||||
if ASYNCIO_ACCEPTS_CONTEXT:
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, LangChainTracer):
|
||||
run = h.run_map.get(str(run_manager.run_id))
|
||||
break
|
||||
else:
|
||||
run = None
|
||||
with set_config_context(child_config, run) as context:
|
||||
ret = await asyncio.create_task(coro, context=context)
|
||||
else:
|
||||
ret = await coro
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(ret)
|
||||
else:
|
||||
ret = await self.afunc(*args, **kwargs)
|
||||
if self.recurse and isinstance(ret, Runnable):
|
||||
return await ret.ainvoke(input, config)
|
||||
return ret
|
||||
|
||||
|
||||
def is_async_callable(
|
||||
func: Any,
|
||||
) -> TypeGuard[Callable[..., Awaitable]]:
|
||||
"""Check if a function is async."""
|
||||
return (
|
||||
inspect.iscoroutinefunction(func)
|
||||
or hasattr(func, "__call__")
|
||||
and inspect.iscoroutinefunction(func.__call__)
|
||||
)
|
||||
|
||||
|
||||
def is_async_generator(
|
||||
func: Any,
|
||||
) -> TypeGuard[Callable[..., AsyncIterator]]:
|
||||
"""Check if a function is an async generator."""
|
||||
return (
|
||||
inspect.isasyncgenfunction(func)
|
||||
or hasattr(func, "__call__")
|
||||
and inspect.isasyncgenfunction(func.__call__)
|
||||
)
|
||||
|
||||
|
||||
def coerce_to_runnable(
|
||||
thing: RunnableLike, *, name: str | None, trace: bool
|
||||
) -> Runnable:
|
||||
"""Coerce a runnable-like object into a Runnable.
|
||||
|
||||
Args:
|
||||
thing: A runnable-like object.
|
||||
|
||||
Returns:
|
||||
A Runnable.
|
||||
"""
|
||||
if isinstance(thing, Runnable):
|
||||
return thing
|
||||
elif is_async_generator(thing) or inspect.isgeneratorfunction(thing):
|
||||
return RunnableLambda(thing, name=name)
|
||||
elif callable(thing):
|
||||
if is_async_callable(thing):
|
||||
return RunnableCallable(None, thing, name=name, trace=trace)
|
||||
else:
|
||||
return RunnableCallable(
|
||||
thing,
|
||||
wraps(thing)(partial(run_in_executor, None, thing)), # type: ignore[arg-type]
|
||||
name=name,
|
||||
trace=trace,
|
||||
)
|
||||
elif isinstance(thing, dict):
|
||||
return RunnableParallel(thing)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected a Runnable, callable or dict."
|
||||
f"Instead got an unsupported type: {type(thing)}"
|
||||
)
|
||||
|
||||
|
||||
class RunnableSeq(Runnable):
|
||||
"""Sequence of `Runnable`, where the output of each is the input of the next.
|
||||
|
||||
`RunnableSeq` is a simpler version of `RunnableSequence` that is internal to
|
||||
LangGraph.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*steps: RunnableLike,
|
||||
name: str | None = None,
|
||||
trace_inputs: Callable[[Any], Any] | None = None,
|
||||
) -> None:
|
||||
"""Create a new RunnableSeq.
|
||||
|
||||
Args:
|
||||
steps: The steps to include in the sequence.
|
||||
name: The name of the `Runnable`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the sequence has less than 2 steps.
|
||||
"""
|
||||
steps_flat: list[Runnable] = []
|
||||
for step in steps:
|
||||
if isinstance(step, RunnableSequence):
|
||||
steps_flat.extend(step.steps)
|
||||
elif isinstance(step, RunnableSeq):
|
||||
steps_flat.extend(step.steps)
|
||||
else:
|
||||
steps_flat.append(coerce_to_runnable(step, name=None, trace=True))
|
||||
if len(steps_flat) < 2:
|
||||
raise ValueError(
|
||||
f"RunnableSeq must have at least 2 steps, got {len(steps_flat)}"
|
||||
)
|
||||
self.steps = steps_flat
|
||||
self.name = name
|
||||
self.trace_inputs = trace_inputs
|
||||
|
||||
def __or__(
|
||||
self,
|
||||
other: Any,
|
||||
) -> Runnable:
|
||||
if isinstance(other, RunnableSequence):
|
||||
return RunnableSeq(
|
||||
*self.steps,
|
||||
other.first,
|
||||
*other.middle,
|
||||
other.last,
|
||||
name=self.name or other.name,
|
||||
)
|
||||
elif isinstance(other, RunnableSeq):
|
||||
return RunnableSeq(
|
||||
*self.steps,
|
||||
*other.steps,
|
||||
name=self.name or other.name,
|
||||
)
|
||||
else:
|
||||
return RunnableSeq(
|
||||
*self.steps,
|
||||
coerce_to_runnable(other, name=None, trace=True),
|
||||
name=self.name,
|
||||
)
|
||||
|
||||
def __ror__(
|
||||
self,
|
||||
other: Any,
|
||||
) -> Runnable:
|
||||
if isinstance(other, RunnableSequence):
|
||||
return RunnableSequence(
|
||||
other.first,
|
||||
*other.middle,
|
||||
other.last,
|
||||
*self.steps,
|
||||
name=other.name or self.name,
|
||||
)
|
||||
elif isinstance(other, RunnableSeq):
|
||||
return RunnableSeq(
|
||||
*other.steps,
|
||||
*self.steps,
|
||||
name=other.name or self.name,
|
||||
)
|
||||
else:
|
||||
return RunnableSequence(
|
||||
coerce_to_runnable(other, name=None, trace=True),
|
||||
*self.steps,
|
||||
name=self.name,
|
||||
)
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: RunnableConfig | None = None, **kwargs: Any
|
||||
) -> Any:
|
||||
if config is None:
|
||||
config = ensure_config()
|
||||
# setup callbacks and context
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
None,
|
||||
self.trace_inputs(input) if self.trace_inputs is not None else input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
# invoke all steps in sequence
|
||||
try:
|
||||
for i, step in enumerate(self.steps):
|
||||
# mark each step as a child run
|
||||
config = patch_config(
|
||||
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
|
||||
)
|
||||
# 1st step is the actual node,
|
||||
# others are writers which don't need to be run in context
|
||||
if i == 0:
|
||||
# get the run object
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, LangChainTracer):
|
||||
run = h.run_map.get(str(run_manager.run_id))
|
||||
break
|
||||
else:
|
||||
run = None
|
||||
# run in context
|
||||
with set_config_context(config, run) as context:
|
||||
input = context.run(step.invoke, input, config, **kwargs)
|
||||
else:
|
||||
input = step.invoke(input, config)
|
||||
# finish the root run
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
run_manager.on_chain_end(input)
|
||||
return input
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Input,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Any:
|
||||
if config is None:
|
||||
config = ensure_config()
|
||||
# setup callbacks
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
None,
|
||||
self.trace_inputs(input) if self.trace_inputs is not None else input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
|
||||
# invoke all steps in sequence
|
||||
try:
|
||||
for i, step in enumerate(self.steps):
|
||||
# mark each step as a child run
|
||||
config = patch_config(
|
||||
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
|
||||
)
|
||||
# 1st step is the actual node,
|
||||
# others are writers which don't need to be run in context
|
||||
if i == 0:
|
||||
if ASYNCIO_ACCEPTS_CONTEXT:
|
||||
# get the run object
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, LangChainTracer):
|
||||
run = h.run_map.get(str(run_manager.run_id))
|
||||
break
|
||||
else:
|
||||
run = None
|
||||
# run in context
|
||||
with set_config_context(config, run) as context:
|
||||
input = await asyncio.create_task(
|
||||
step.ainvoke(input, config, **kwargs), context=context
|
||||
)
|
||||
else:
|
||||
input = await step.ainvoke(input, config, **kwargs)
|
||||
else:
|
||||
input = await step.ainvoke(input, config)
|
||||
# finish the root run
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(input)
|
||||
return input
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Input,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Iterator[Any]:
|
||||
if config is None:
|
||||
config = ensure_config()
|
||||
# setup callbacks
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
None,
|
||||
self.trace_inputs(input) if self.trace_inputs is not None else input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
# get the run object
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, LangChainTracer):
|
||||
run = h.run_map.get(str(run_manager.run_id))
|
||||
break
|
||||
else:
|
||||
run = None
|
||||
# create first step config
|
||||
config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(f"seq:step:{1}"),
|
||||
)
|
||||
# run all in context
|
||||
with set_config_context(config, run) as context:
|
||||
try:
|
||||
# stream the last steps
|
||||
# transform the input stream of each step with the next
|
||||
# steps that don't natively support transforming an input stream will
|
||||
# buffer input in memory until all available, and then start emitting output
|
||||
for idx, step in enumerate(self.steps):
|
||||
if idx == 0:
|
||||
iterator = step.stream(input, config, **kwargs)
|
||||
else:
|
||||
config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(f"seq:step:{idx + 1}"),
|
||||
)
|
||||
iterator = step.transform(iterator, config)
|
||||
# populates streamed_output in astream_log() output if needed
|
||||
if _StreamingCallbackHandler is not None:
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, _StreamingCallbackHandler):
|
||||
iterator = h.tap_output_iter(run_manager.run_id, iterator)
|
||||
# consume into final output
|
||||
output = context.run(_consume_iter, iterator)
|
||||
# sequence doesn't emit output, yield to mark as generator
|
||||
yield
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
run_manager.on_chain_end(output)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Input,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> AsyncIterator[Any]:
|
||||
if config is None:
|
||||
config = ensure_config()
|
||||
# setup callbacks
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
None,
|
||||
self.trace_inputs(input) if self.trace_inputs is not None else input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
# stream the last steps
|
||||
# transform the input stream of each step with the next
|
||||
# steps that don't natively support transforming an input stream will
|
||||
# buffer input in memory until all available, and then start emitting output
|
||||
if ASYNCIO_ACCEPTS_CONTEXT:
|
||||
# get the run object
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, LangChainTracer):
|
||||
run = h.run_map.get(str(run_manager.run_id))
|
||||
break
|
||||
else:
|
||||
run = None
|
||||
# create first step config
|
||||
config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(f"seq:step:{1}"),
|
||||
)
|
||||
# run all in context
|
||||
with set_config_context(config, run) as context:
|
||||
try:
|
||||
async with AsyncExitStack() as stack:
|
||||
for idx, step in enumerate(self.steps):
|
||||
if idx == 0:
|
||||
aiterator = step.astream(input, config, **kwargs)
|
||||
else:
|
||||
config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(
|
||||
f"seq:step:{idx + 1}"
|
||||
),
|
||||
)
|
||||
aiterator = step.atransform(aiterator, config)
|
||||
if hasattr(aiterator, "aclose"):
|
||||
stack.push_async_callback(aiterator.aclose)
|
||||
# populates streamed_output in astream_log() output if needed
|
||||
if _StreamingCallbackHandler is not None:
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, _StreamingCallbackHandler):
|
||||
aiterator = h.tap_output_aiter(
|
||||
run_manager.run_id, aiterator
|
||||
)
|
||||
# consume into final output
|
||||
output = await asyncio.create_task(
|
||||
_consume_aiter(aiterator), context=context
|
||||
)
|
||||
# sequence doesn't emit output, yield to mark as generator
|
||||
yield
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(output)
|
||||
else:
|
||||
try:
|
||||
async with AsyncExitStack() as stack:
|
||||
for idx, step in enumerate(self.steps):
|
||||
config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(f"seq:step:{idx + 1}"),
|
||||
)
|
||||
if idx == 0:
|
||||
aiterator = step.astream(input, config, **kwargs)
|
||||
else:
|
||||
aiterator = step.atransform(aiterator, config)
|
||||
if hasattr(aiterator, "aclose"):
|
||||
stack.push_async_callback(aiterator.aclose)
|
||||
# populates streamed_output in astream_log() output if needed
|
||||
if _StreamingCallbackHandler is not None:
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, _StreamingCallbackHandler):
|
||||
aiterator = h.tap_output_aiter(
|
||||
run_manager.run_id, aiterator
|
||||
)
|
||||
# consume into final output
|
||||
output = await _consume_aiter(aiterator)
|
||||
# sequence doesn't emit output, yield to mark as generator
|
||||
yield
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(output)
|
||||
|
||||
|
||||
def _consume_iter(it: Iterator[Any]) -> Any:
|
||||
"""Consume an iterator."""
|
||||
output: Any = None
|
||||
add_supported = False
|
||||
for chunk in it:
|
||||
# collect final output
|
||||
if output is None:
|
||||
output = chunk
|
||||
elif add_supported:
|
||||
try:
|
||||
output = output + chunk
|
||||
except TypeError:
|
||||
output = chunk
|
||||
add_supported = False
|
||||
else:
|
||||
output = chunk
|
||||
return output
|
||||
|
||||
|
||||
async def _consume_aiter(it: AsyncIterator[Any]) -> Any:
|
||||
"""Consume an async iterator."""
|
||||
output: Any = None
|
||||
add_supported = False
|
||||
async for chunk in it:
|
||||
# collect final output
|
||||
if add_supported:
|
||||
try:
|
||||
output = output + chunk
|
||||
except TypeError:
|
||||
output = chunk
|
||||
add_supported = False
|
||||
else:
|
||||
output = chunk
|
||||
return output
|
||||
19
venv/Lib/site-packages/langgraph/_internal/_scratchpad.py
Normal file
19
venv/Lib/site-packages/langgraph/_internal/_scratchpad.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import dataclasses
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langgraph.types import _DC_KWARGS
|
||||
|
||||
|
||||
@dataclasses.dataclass(**_DC_KWARGS)
|
||||
class PregelScratchpad:
|
||||
step: int
|
||||
stop: int
|
||||
# call
|
||||
call_counter: Callable[[], int]
|
||||
# interrupt
|
||||
interrupt_counter: Callable[[], int]
|
||||
get_null_resume: Callable[[bool], Any]
|
||||
resume: list[Any]
|
||||
# subgraph
|
||||
subgraph_counter: Callable[[], int]
|
||||
253
venv/Lib/site-packages/langgraph/_internal/_serde.py
Normal file
253
venv/Lib/site-packages/langgraph/_internal/_serde.py
Normal file
@@ -0,0 +1,253 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import sys
|
||||
import types
|
||||
from collections import deque
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from langchain_core import messages as lc_messages
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import NotRequired, Required, is_typeddict
|
||||
|
||||
try:
|
||||
from langgraph.checkpoint.serde._msgpack import ( # noqa: F401
|
||||
STRICT_MSGPACK_ENABLED,
|
||||
)
|
||||
except ImportError:
|
||||
STRICT_MSGPACK_ENABLED = False
|
||||
|
||||
_warned_allowlist_unsupported = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _supports_checkpointer_allowlist() -> bool:
|
||||
return hasattr(BaseCheckpointSaver, "with_allowlist")
|
||||
|
||||
|
||||
_SUPPORTS_ALLOWLIST = _supports_checkpointer_allowlist()
|
||||
|
||||
|
||||
def apply_checkpointer_allowlist(
|
||||
checkpointer: Any, allowlist: set[tuple[str, ...]] | None
|
||||
) -> Any:
|
||||
if not checkpointer or allowlist is None or checkpointer in (True, False):
|
||||
return checkpointer
|
||||
if not _SUPPORTS_ALLOWLIST:
|
||||
global _warned_allowlist_unsupported
|
||||
if not _warned_allowlist_unsupported:
|
||||
logger.warning(
|
||||
"Checkpointer does not support with_allowlist; strict msgpack "
|
||||
"allowlist will be skipped."
|
||||
)
|
||||
_warned_allowlist_unsupported = True
|
||||
return checkpointer
|
||||
return checkpointer.with_allowlist(allowlist)
|
||||
|
||||
|
||||
def curated_core_allowlist() -> set[tuple[str, ...]]:
|
||||
allowlist: set[tuple[str, ...]] = set()
|
||||
for name in (
|
||||
"BaseMessage",
|
||||
"BaseMessageChunk",
|
||||
"HumanMessage",
|
||||
"HumanMessageChunk",
|
||||
"AIMessage",
|
||||
"AIMessageChunk",
|
||||
"SystemMessage",
|
||||
"SystemMessageChunk",
|
||||
"ChatMessage",
|
||||
"ChatMessageChunk",
|
||||
"ToolMessage",
|
||||
"ToolMessageChunk",
|
||||
"FunctionMessage",
|
||||
"FunctionMessageChunk",
|
||||
"RemoveMessage",
|
||||
):
|
||||
cls = getattr(lc_messages, name, None)
|
||||
if cls is None:
|
||||
continue
|
||||
allowlist.add((cls.__module__, cls.__name__))
|
||||
|
||||
return allowlist
|
||||
|
||||
|
||||
def build_serde_allowlist(
|
||||
*,
|
||||
schemas: list[type[Any]] | None = None,
|
||||
channels: dict[str, Any] | None = None,
|
||||
) -> set[tuple[str, ...]]:
|
||||
allowlist = curated_core_allowlist()
|
||||
if schemas:
|
||||
schemas = [schema for schema in schemas if schema is not None]
|
||||
return allowlist | collect_allowlist_from_schemas(
|
||||
schemas=schemas,
|
||||
channels=channels,
|
||||
)
|
||||
|
||||
|
||||
def collect_allowlist_from_schemas(
|
||||
*,
|
||||
schemas: list[type[Any]] | None = None,
|
||||
channels: dict[str, Any] | None = None,
|
||||
) -> set[tuple[str, ...]]:
|
||||
allowlist: set[tuple[str, ...]] = set()
|
||||
seen: set[Any] = set()
|
||||
seen_ids: set[int] = set()
|
||||
|
||||
if schemas:
|
||||
for schema in schemas:
|
||||
_collect_from_type(schema, allowlist, seen, seen_ids)
|
||||
|
||||
if channels:
|
||||
for channel in channels.values():
|
||||
value_type = getattr(channel, "ValueType", None)
|
||||
if value_type is not None:
|
||||
_collect_from_type(value_type, allowlist, seen, seen_ids)
|
||||
update_type = getattr(channel, "UpdateType", None)
|
||||
if update_type is not None:
|
||||
_collect_from_type(update_type, allowlist, seen, seen_ids)
|
||||
|
||||
return allowlist
|
||||
|
||||
|
||||
def _collect_from_type(
|
||||
typ: Any,
|
||||
allowlist: set[tuple[str, ...]],
|
||||
seen: set[Any],
|
||||
seen_ids: set[int],
|
||||
) -> None:
|
||||
if _already_seen(typ, seen, seen_ids):
|
||||
return
|
||||
|
||||
if typ is Any or typ is None:
|
||||
return
|
||||
|
||||
if typ is Literal:
|
||||
return
|
||||
|
||||
if isinstance(typ, types.UnionType):
|
||||
for arg in typ.__args__:
|
||||
_collect_from_type(arg, allowlist, seen, seen_ids)
|
||||
return
|
||||
|
||||
origin = get_origin(typ)
|
||||
if origin is Union:
|
||||
for arg in get_args(typ):
|
||||
_collect_from_type(arg, allowlist, seen, seen_ids)
|
||||
return
|
||||
if origin is Annotated or origin in (Required, NotRequired):
|
||||
args = get_args(typ)
|
||||
if args:
|
||||
_collect_from_type(args[0], allowlist, seen, seen_ids)
|
||||
return
|
||||
|
||||
if origin is Literal:
|
||||
return
|
||||
|
||||
if origin in (list, set, tuple, dict, deque, frozenset):
|
||||
for arg in get_args(typ):
|
||||
_collect_from_type(arg, allowlist, seen, seen_ids)
|
||||
return
|
||||
|
||||
if hasattr(typ, "__supertype__"):
|
||||
_collect_from_type(typ.__supertype__, allowlist, seen, seen_ids)
|
||||
return
|
||||
|
||||
if is_typeddict(typ):
|
||||
for field_type in _safe_get_type_hints(typ).values():
|
||||
_collect_from_type(field_type, allowlist, seen, seen_ids)
|
||||
return
|
||||
|
||||
if _is_pydantic_model(typ):
|
||||
allowlist.add((typ.__module__, typ.__name__))
|
||||
field_types = _safe_get_type_hints(typ)
|
||||
if field_types:
|
||||
for field_type in field_types.values():
|
||||
_collect_from_type(field_type, allowlist, seen, seen_ids)
|
||||
else:
|
||||
for field_type in _pydantic_field_types(typ):
|
||||
_collect_from_type(field_type, allowlist, seen, seen_ids)
|
||||
return
|
||||
|
||||
if dataclasses.is_dataclass(typ):
|
||||
if typ_name := getattr(typ, "__name__", None):
|
||||
allowlist.add((typ.__module__, typ_name))
|
||||
field_types = _safe_get_type_hints(typ)
|
||||
if field_types:
|
||||
for field_type in field_types.values():
|
||||
_collect_from_type(field_type, allowlist, seen, seen_ids)
|
||||
else:
|
||||
for field in dataclasses.fields(typ):
|
||||
_collect_from_type(field.type, allowlist, seen, seen_ids)
|
||||
return
|
||||
|
||||
if isinstance(typ, type) and issubclass(typ, Enum):
|
||||
allowlist.add((typ.__module__, typ.__name__))
|
||||
return
|
||||
|
||||
|
||||
def _already_seen(typ: Any, seen: set[Any], seen_ids: set[int]) -> bool:
|
||||
try:
|
||||
if typ in seen:
|
||||
return True
|
||||
seen.add(typ)
|
||||
return False
|
||||
except TypeError:
|
||||
typ_id = id(typ)
|
||||
if typ_id in seen_ids:
|
||||
return True
|
||||
seen_ids.add(typ_id)
|
||||
return False
|
||||
|
||||
|
||||
def _safe_get_type_hints(typ: Any) -> dict[str, Any]:
|
||||
try:
|
||||
module = sys.modules.get(getattr(typ, "__module__", ""))
|
||||
globalns = module.__dict__ if module else None
|
||||
localns = dict(vars(typ)) if hasattr(typ, "__dict__") else None
|
||||
return get_type_hints(
|
||||
typ, globalns=globalns, localns=localns, include_extras=True
|
||||
)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _is_pydantic_model(typ: Any) -> bool:
|
||||
if not isinstance(typ, type):
|
||||
return False
|
||||
if issubclass(typ, BaseModel):
|
||||
return True
|
||||
try:
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
except Exception:
|
||||
return False
|
||||
return issubclass(typ, BaseModelV1)
|
||||
|
||||
|
||||
def _pydantic_field_types(typ: type[Any]) -> list[Any]:
|
||||
if hasattr(typ, "model_fields"):
|
||||
return [
|
||||
field.annotation
|
||||
for field in typ.model_fields.values()
|
||||
if getattr(field, "annotation", None) is not None
|
||||
]
|
||||
if hasattr(typ, "__fields__"):
|
||||
return [
|
||||
field.outer_type_
|
||||
for field in typ.__fields__.values()
|
||||
if getattr(field, "outer_type_", None) is not None
|
||||
]
|
||||
return []
|
||||
54
venv/Lib/site-packages/langgraph/_internal/_typing.py
Normal file
54
venv/Lib/site-packages/langgraph/_internal/_typing.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Private typing utilities for LangGraph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import Field
|
||||
from typing import Any, ClassVar, Protocol, TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class TypedDictLikeV1(Protocol):
|
||||
"""Protocol to represent types that behave like TypedDicts
|
||||
|
||||
Version 1: using `ClassVar` for keys."""
|
||||
|
||||
__required_keys__: ClassVar[frozenset[str]]
|
||||
__optional_keys__: ClassVar[frozenset[str]]
|
||||
|
||||
|
||||
class TypedDictLikeV2(Protocol):
|
||||
"""Protocol to represent types that behave like TypedDicts
|
||||
|
||||
Version 2: not using `ClassVar` for keys."""
|
||||
|
||||
__required_keys__: frozenset[str]
|
||||
__optional_keys__: frozenset[str]
|
||||
|
||||
|
||||
class DataclassLike(Protocol):
|
||||
"""Protocol to represent types that behave like dataclasses.
|
||||
|
||||
Inspired by the private _DataclassT from dataclasses that uses a similar protocol as a bound."""
|
||||
|
||||
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
|
||||
|
||||
|
||||
StateLike: TypeAlias = TypedDictLikeV1 | TypedDictLikeV2 | DataclassLike | BaseModel
|
||||
"""Type alias for state-like types.
|
||||
|
||||
It can either be a `TypedDict`, `dataclass`, or Pydantic `BaseModel`.
|
||||
Note: we cannot use either `TypedDict` or `dataclass` directly due to limitations in type checking.
|
||||
"""
|
||||
|
||||
MISSING = object()
|
||||
"""Unset sentinel value."""
|
||||
|
||||
|
||||
class DeprecatedKwargs(TypedDict):
|
||||
"""TypedDict to use for extra keyword arguments, enabling type checking warnings for deprecated arguments."""
|
||||
|
||||
|
||||
EMPTY_SEQ: tuple[str, ...] = tuple()
|
||||
"""An empty sequence of strings."""
|
||||
48
venv/Lib/site-packages/langgraph/cache/base/__init__.py
vendored
Normal file
48
venv/Lib/site-packages/langgraph/cache/base/__init__.py
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from langgraph.checkpoint.serde.base import SerializerProtocol
|
||||
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
||||
|
||||
ValueT = TypeVar("ValueT")
|
||||
Namespace = tuple[str, ...]
|
||||
FullKey = tuple[Namespace, str]
|
||||
|
||||
|
||||
class BaseCache(ABC, Generic[ValueT]):
|
||||
"""Base class for a cache."""
|
||||
|
||||
serde: SerializerProtocol = JsonPlusSerializer(pickle_fallback=False)
|
||||
|
||||
def __init__(self, *, serde: SerializerProtocol | None = None) -> None:
|
||||
"""Initialize the cache with a serializer."""
|
||||
self.serde = serde or self.serde
|
||||
|
||||
@abstractmethod
|
||||
def get(self, keys: Sequence[FullKey]) -> dict[FullKey, ValueT]:
|
||||
"""Get the cached values for the given keys."""
|
||||
|
||||
@abstractmethod
|
||||
async def aget(self, keys: Sequence[FullKey]) -> dict[FullKey, ValueT]:
|
||||
"""Asynchronously get the cached values for the given keys."""
|
||||
|
||||
@abstractmethod
|
||||
def set(self, pairs: Mapping[FullKey, tuple[ValueT, int | None]]) -> None:
|
||||
"""Set the cached values for the given keys and TTLs."""
|
||||
|
||||
@abstractmethod
|
||||
async def aset(self, pairs: Mapping[FullKey, tuple[ValueT, int | None]]) -> None:
|
||||
"""Asynchronously set the cached values for the given keys and TTLs."""
|
||||
|
||||
@abstractmethod
|
||||
def clear(self, namespaces: Sequence[Namespace] | None = None) -> None:
|
||||
"""Delete the cached values for the given namespaces.
|
||||
If no namespaces are provided, clear all cached values."""
|
||||
|
||||
@abstractmethod
|
||||
async def aclear(self, namespaces: Sequence[Namespace] | None = None) -> None:
|
||||
"""Asynchronously delete the cached values for the given namespaces.
|
||||
If no namespaces are provided, clear all cached values."""
|
||||
BIN
venv/Lib/site-packages/langgraph/cache/base/__pycache__/__init__.cpython-311.pyc
vendored
Normal file
BIN
venv/Lib/site-packages/langgraph/cache/base/__pycache__/__init__.cpython-311.pyc
vendored
Normal file
Binary file not shown.
0
venv/Lib/site-packages/langgraph/cache/base/py.typed
vendored
Normal file
0
venv/Lib/site-packages/langgraph/cache/base/py.typed
vendored
Normal file
73
venv/Lib/site-packages/langgraph/cache/memory/__init__.py
vendored
Normal file
73
venv/Lib/site-packages/langgraph/cache/memory/__init__.py
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import threading
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from langgraph.cache.base import BaseCache, FullKey, Namespace, ValueT
|
||||
from langgraph.checkpoint.serde.base import SerializerProtocol
|
||||
|
||||
|
||||
class InMemoryCache(BaseCache[ValueT]):
|
||||
def __init__(self, *, serde: SerializerProtocol | None = None):
|
||||
super().__init__(serde=serde)
|
||||
self._cache: dict[Namespace, dict[str, tuple[str, bytes, float | None]]] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def get(self, keys: Sequence[FullKey]) -> dict[FullKey, ValueT]:
|
||||
"""Get the cached values for the given keys."""
|
||||
with self._lock:
|
||||
if not keys:
|
||||
return {}
|
||||
now = datetime.datetime.now(datetime.timezone.utc).timestamp()
|
||||
values: dict[FullKey, ValueT] = {}
|
||||
for ns_tuple, key in keys:
|
||||
ns = Namespace(ns_tuple)
|
||||
if ns in self._cache and key in self._cache[ns]:
|
||||
enc, val, expiry = self._cache[ns][key]
|
||||
if expiry is None or now < expiry:
|
||||
values[(ns, key)] = self.serde.loads_typed((enc, val))
|
||||
else:
|
||||
del self._cache[ns][key]
|
||||
return values
|
||||
|
||||
async def aget(self, keys: Sequence[FullKey]) -> dict[FullKey, ValueT]:
|
||||
"""Asynchronously get the cached values for the given keys."""
|
||||
return self.get(keys)
|
||||
|
||||
def set(self, keys: Mapping[FullKey, tuple[ValueT, int | None]]) -> None:
|
||||
"""Set the cached values for the given keys."""
|
||||
with self._lock:
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
for (ns, key), (value, ttl) in keys.items():
|
||||
if ttl is not None:
|
||||
delta = datetime.timedelta(seconds=ttl)
|
||||
expiry: float | None = (now + delta).timestamp()
|
||||
else:
|
||||
expiry = None
|
||||
if ns not in self._cache:
|
||||
self._cache[ns] = {}
|
||||
self._cache[ns][key] = (
|
||||
*self.serde.dumps_typed(value),
|
||||
expiry,
|
||||
)
|
||||
|
||||
async def aset(self, keys: Mapping[FullKey, tuple[ValueT, int | None]]) -> None:
|
||||
"""Asynchronously set the cached values for the given keys."""
|
||||
self.set(keys)
|
||||
|
||||
def clear(self, namespaces: Sequence[Namespace] | None = None) -> None:
|
||||
"""Delete the cached values for the given namespaces.
|
||||
If no namespaces are provided, clear all cached values."""
|
||||
with self._lock:
|
||||
if namespaces is None:
|
||||
self._cache.clear()
|
||||
else:
|
||||
for ns in namespaces:
|
||||
if ns in self._cache:
|
||||
del self._cache[ns]
|
||||
|
||||
async def aclear(self, namespaces: Sequence[Namespace] | None = None) -> None:
|
||||
"""Asynchronously delete the cached values for the given namespaces.
|
||||
If no namespaces are provided, clear all cached values."""
|
||||
self.clear(namespaces)
|
||||
BIN
venv/Lib/site-packages/langgraph/cache/memory/__pycache__/__init__.cpython-311.pyc
vendored
Normal file
BIN
venv/Lib/site-packages/langgraph/cache/memory/__pycache__/__init__.cpython-311.pyc
vendored
Normal file
Binary file not shown.
144
venv/Lib/site-packages/langgraph/cache/redis/__init__.py
vendored
Normal file
144
venv/Lib/site-packages/langgraph/cache/redis/__init__.py
vendored
Normal file
@@ -0,0 +1,144 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from langgraph.cache.base import BaseCache, FullKey, Namespace, ValueT
|
||||
from langgraph.checkpoint.serde.base import SerializerProtocol
|
||||
|
||||
|
||||
class RedisCache(BaseCache[ValueT]):
|
||||
"""Redis-based cache implementation with TTL support."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis: Any,
|
||||
*,
|
||||
serde: SerializerProtocol | None = None,
|
||||
prefix: str = "langgraph:cache:",
|
||||
) -> None:
|
||||
"""Initialize the cache with a Redis client.
|
||||
|
||||
Args:
|
||||
redis: Redis client instance (sync or async)
|
||||
serde: Serializer to use for values
|
||||
prefix: Key prefix for all cached values
|
||||
"""
|
||||
super().__init__(serde=serde)
|
||||
self.redis = redis
|
||||
self.prefix = prefix
|
||||
|
||||
def _make_key(self, ns: Namespace, key: str) -> str:
|
||||
"""Create a Redis key from namespace and key."""
|
||||
ns_str = ":".join(ns) if ns else ""
|
||||
return f"{self.prefix}{ns_str}:{key}" if ns_str else f"{self.prefix}{key}"
|
||||
|
||||
def _parse_key(self, redis_key: str) -> tuple[Namespace, str]:
|
||||
"""Parse a Redis key back to namespace and key."""
|
||||
if not redis_key.startswith(self.prefix):
|
||||
raise ValueError(
|
||||
f"Key {redis_key} does not start with prefix {self.prefix}"
|
||||
)
|
||||
|
||||
remaining = redis_key[len(self.prefix) :]
|
||||
if ":" in remaining:
|
||||
parts = remaining.split(":")
|
||||
key = parts[-1]
|
||||
ns_parts = parts[:-1]
|
||||
return (tuple(ns_parts), key)
|
||||
else:
|
||||
return (tuple(), remaining)
|
||||
|
||||
def get(self, keys: Sequence[FullKey]) -> dict[FullKey, ValueT]:
|
||||
"""Get the cached values for the given keys."""
|
||||
if not keys:
|
||||
return {}
|
||||
|
||||
# Build Redis keys
|
||||
redis_keys = [self._make_key(ns, key) for ns, key in keys]
|
||||
|
||||
# Get values from Redis using MGET
|
||||
try:
|
||||
raw_values = self.redis.mget(redis_keys)
|
||||
except Exception:
|
||||
# If Redis is unavailable, return empty dict
|
||||
return {}
|
||||
|
||||
values: dict[FullKey, ValueT] = {}
|
||||
for i, raw_value in enumerate(raw_values):
|
||||
if raw_value is not None:
|
||||
try:
|
||||
# Deserialize the value
|
||||
encoding, data = raw_value.split(b":", 1)
|
||||
values[keys[i]] = self.serde.loads_typed((encoding.decode(), data))
|
||||
except Exception:
|
||||
# Skip corrupted entries
|
||||
continue
|
||||
|
||||
return values
|
||||
|
||||
async def aget(self, keys: Sequence[FullKey]) -> dict[FullKey, ValueT]:
|
||||
"""Asynchronously get the cached values for the given keys."""
|
||||
return self.get(keys)
|
||||
|
||||
def set(self, mapping: Mapping[FullKey, tuple[ValueT, int | None]]) -> None:
|
||||
"""Set the cached values for the given keys and TTLs."""
|
||||
if not mapping:
|
||||
return
|
||||
|
||||
# Use pipeline for efficient batch operations
|
||||
pipe = self.redis.pipeline()
|
||||
|
||||
for (ns, key), (value, ttl) in mapping.items():
|
||||
redis_key = self._make_key(ns, key)
|
||||
encoding, data = self.serde.dumps_typed(value)
|
||||
|
||||
# Store as "encoding:data" format
|
||||
serialized_value = f"{encoding}:".encode() + data
|
||||
|
||||
if ttl is not None:
|
||||
pipe.setex(redis_key, ttl, serialized_value)
|
||||
else:
|
||||
pipe.set(redis_key, serialized_value)
|
||||
|
||||
try:
|
||||
pipe.execute()
|
||||
except Exception:
|
||||
# Silently fail if Redis is unavailable
|
||||
pass
|
||||
|
||||
async def aset(self, mapping: Mapping[FullKey, tuple[ValueT, int | None]]) -> None:
|
||||
"""Asynchronously set the cached values for the given keys and TTLs."""
|
||||
self.set(mapping)
|
||||
|
||||
def clear(self, namespaces: Sequence[Namespace] | None = None) -> None:
|
||||
"""Delete the cached values for the given namespaces.
|
||||
If no namespaces are provided, clear all cached values."""
|
||||
try:
|
||||
if namespaces is None:
|
||||
# Clear all keys with our prefix
|
||||
pattern = f"{self.prefix}*"
|
||||
keys = self.redis.keys(pattern)
|
||||
if keys:
|
||||
self.redis.delete(*keys)
|
||||
else:
|
||||
# Clear specific namespaces
|
||||
keys_to_delete = []
|
||||
for ns in namespaces:
|
||||
ns_str = ":".join(ns) if ns else ""
|
||||
pattern = (
|
||||
f"{self.prefix}{ns_str}:*" if ns_str else f"{self.prefix}*"
|
||||
)
|
||||
keys = self.redis.keys(pattern)
|
||||
keys_to_delete.extend(keys)
|
||||
|
||||
if keys_to_delete:
|
||||
self.redis.delete(*keys_to_delete)
|
||||
except Exception:
|
||||
# Silently fail if Redis is unavailable
|
||||
pass
|
||||
|
||||
async def aclear(self, namespaces: Sequence[Namespace] | None = None) -> None:
|
||||
"""Asynchronously delete the cached values for the given namespaces.
|
||||
If no namespaces are provided, clear all cached values."""
|
||||
self.clear(namespaces)
|
||||
BIN
venv/Lib/site-packages/langgraph/cache/redis/__pycache__/__init__.cpython-311.pyc
vendored
Normal file
BIN
venv/Lib/site-packages/langgraph/cache/redis/__pycache__/__init__.cpython-311.pyc
vendored
Normal file
Binary file not shown.
27
venv/Lib/site-packages/langgraph/channels/__init__.py
Normal file
27
venv/Lib/site-packages/langgraph/channels/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from langgraph.channels.any_value import AnyValue
|
||||
from langgraph.channels.base import BaseChannel
|
||||
from langgraph.channels.binop import BinaryOperatorAggregate
|
||||
from langgraph.channels.ephemeral_value import EphemeralValue
|
||||
from langgraph.channels.last_value import LastValue, LastValueAfterFinish
|
||||
from langgraph.channels.named_barrier_value import (
|
||||
NamedBarrierValue,
|
||||
NamedBarrierValueAfterFinish,
|
||||
)
|
||||
from langgraph.channels.topic import Topic
|
||||
from langgraph.channels.untracked_value import UntrackedValue
|
||||
|
||||
__all__ = (
|
||||
# base
|
||||
"BaseChannel",
|
||||
# value types
|
||||
"AnyValue",
|
||||
"LastValue",
|
||||
"LastValueAfterFinish",
|
||||
"UntrackedValue",
|
||||
"EphemeralValue",
|
||||
"BinaryOperatorAggregate",
|
||||
"NamedBarrierValue",
|
||||
"NamedBarrierValueAfterFinish",
|
||||
# topics
|
||||
"Topic",
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
72
venv/Lib/site-packages/langgraph/channels/any_value.py
Normal file
72
venv/Lib/site-packages/langgraph/channels/any_value.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel, Value
|
||||
from langgraph.errors import EmptyChannelError
|
||||
|
||||
__all__ = ("AnyValue",)
|
||||
|
||||
|
||||
class AnyValue(Generic[Value], BaseChannel[Value, Value, Value]):
|
||||
"""Stores the last value received, assumes that if multiple values are
|
||||
received, they are all equal."""
|
||||
|
||||
__slots__ = ("typ", "value")
|
||||
|
||||
value: Value | Any
|
||||
|
||||
def __init__(self, typ: Any, key: str = "") -> None:
|
||||
super().__init__(typ, key)
|
||||
self.value = MISSING
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, AnyValue)
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.key)
|
||||
empty.value = self.value
|
||||
return empty
|
||||
|
||||
def from_checkpoint(self, checkpoint: Value) -> Self:
|
||||
empty = self.__class__(self.typ, self.key)
|
||||
if checkpoint is not MISSING:
|
||||
empty.value = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value]) -> bool:
|
||||
if len(values) == 0:
|
||||
if self.value is MISSING:
|
||||
return False
|
||||
else:
|
||||
self.value = MISSING
|
||||
return True
|
||||
|
||||
self.value = values[-1]
|
||||
return True
|
||||
|
||||
def get(self) -> Value:
|
||||
if self.value is MISSING:
|
||||
raise EmptyChannelError()
|
||||
return self.value
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.value is not MISSING
|
||||
|
||||
def checkpoint(self) -> Value:
|
||||
return self.value
|
||||
121
venv/Lib/site-packages/langgraph/channels/base.py
Normal file
121
venv/Lib/site-packages/langgraph/channels/base.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.errors import EmptyChannelError
|
||||
|
||||
Value = TypeVar("Value")
|
||||
Update = TypeVar("Update")
|
||||
Checkpoint = TypeVar("Checkpoint")
|
||||
|
||||
__all__ = ("BaseChannel",)
|
||||
|
||||
|
||||
class BaseChannel(Generic[Value, Update, Checkpoint], ABC):
|
||||
"""Base class for all channels."""
|
||||
|
||||
__slots__ = ("key", "typ")
|
||||
|
||||
def __init__(self, typ: Any, key: str = "") -> None:
|
||||
self.typ = typ
|
||||
self.key = key
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ValueType(self) -> Any:
|
||||
"""The type of the value stored in the channel."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def UpdateType(self) -> Any:
|
||||
"""The type of the update received by the channel."""
|
||||
|
||||
# serialize/deserialize methods
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel.
|
||||
|
||||
By default, delegates to `checkpoint()` and `from_checkpoint()`.
|
||||
|
||||
Subclasses can override this method with a more efficient implementation.
|
||||
"""
|
||||
return self.from_checkpoint(self.checkpoint())
|
||||
|
||||
def checkpoint(self) -> Checkpoint | Any:
|
||||
"""Return a serializable representation of the channel's current state.
|
||||
|
||||
Raises `EmptyChannelError` if the channel is empty (never updated yet),
|
||||
or doesn't support checkpoints.
|
||||
"""
|
||||
try:
|
||||
return self.get()
|
||||
except EmptyChannelError:
|
||||
return MISSING
|
||||
|
||||
@abstractmethod
|
||||
def from_checkpoint(self, checkpoint: Checkpoint | Any) -> Self:
|
||||
"""Return a new identical channel, optionally initialized from a checkpoint.
|
||||
|
||||
If the checkpoint contains complex data structures, they should be copied.
|
||||
"""
|
||||
|
||||
# read methods
|
||||
|
||||
@abstractmethod
|
||||
def get(self) -> Value:
|
||||
"""Return the current value of the channel.
|
||||
|
||||
Raises `EmptyChannelError` if the channel is empty (never updated yet)."""
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Return `True` if the channel is available (not empty), `False` otherwise.
|
||||
|
||||
Subclasses should override this method to provide a more efficient
|
||||
implementation than calling `get()` and catching `EmptyChannelError`.
|
||||
"""
|
||||
try:
|
||||
self.get()
|
||||
return True
|
||||
except EmptyChannelError:
|
||||
return False
|
||||
|
||||
# write methods
|
||||
|
||||
@abstractmethod
|
||||
def update(self, values: Sequence[Update]) -> bool:
|
||||
"""Update the channel's value with the given sequence of updates.
|
||||
The order of the updates in the sequence is arbitrary.
|
||||
This method is called by Pregel for all channels at the end of each step.
|
||||
|
||||
If there are no updates, it is called with an empty sequence.
|
||||
|
||||
Raises `InvalidUpdateError` if the sequence of updates is invalid.
|
||||
|
||||
Returns `True` if the channel was updated, `False` otherwise."""
|
||||
|
||||
def consume(self) -> bool:
|
||||
"""Notify the channel that a subscribed task ran.
|
||||
|
||||
By default, no-op.
|
||||
|
||||
A channel can use this method to modify its state, preventing the value from being consumed again.
|
||||
|
||||
Returns `True` if the channel was updated, `False` otherwise.
|
||||
"""
|
||||
return False
|
||||
|
||||
def finish(self) -> bool:
|
||||
"""Notify the channel that the Pregel run is finishing.
|
||||
|
||||
By default, no-op.
|
||||
|
||||
A channel can use this method to modify its state, preventing finish.
|
||||
|
||||
Returns `True` if the channel was updated, `False` otherwise.
|
||||
"""
|
||||
return False
|
||||
134
venv/Lib/site-packages/langgraph/channels/binop.py
Normal file
134
venv/Lib/site-packages/langgraph/channels/binop.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import collections.abc
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import NotRequired, Required, Self
|
||||
|
||||
from langgraph._internal._constants import OVERWRITE
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel, Value
|
||||
from langgraph.errors import (
|
||||
EmptyChannelError,
|
||||
ErrorCode,
|
||||
InvalidUpdateError,
|
||||
create_error_message,
|
||||
)
|
||||
from langgraph.types import Overwrite
|
||||
|
||||
__all__ = ("BinaryOperatorAggregate",)
|
||||
|
||||
|
||||
# Adapted from typing_extensions
|
||||
def _strip_extras(t): # type: ignore[no-untyped-def]
|
||||
"""Strips Annotated, Required and NotRequired from a given type."""
|
||||
if hasattr(t, "__origin__"):
|
||||
return _strip_extras(t.__origin__)
|
||||
if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired):
|
||||
return _strip_extras(t.__args__[0])
|
||||
|
||||
return t
|
||||
|
||||
|
||||
def _get_overwrite(value: Any) -> tuple[bool, Any]:
|
||||
"""Inspects the given value and returns (is_overwrite, overwrite_value)."""
|
||||
if isinstance(value, Overwrite):
|
||||
return True, value.value
|
||||
if isinstance(value, dict) and set(value.keys()) == {OVERWRITE}:
|
||||
return True, value[OVERWRITE]
|
||||
return False, None
|
||||
|
||||
|
||||
class BinaryOperatorAggregate(Generic[Value], BaseChannel[Value, Value, Value]):
|
||||
"""Stores the result of applying a binary operator to the current value and each new value.
|
||||
|
||||
```python
|
||||
import operator
|
||||
|
||||
total = Channels.BinaryOperatorAggregate(int, operator.add)
|
||||
```
|
||||
"""
|
||||
|
||||
__slots__ = ("value", "operator")
|
||||
|
||||
def __init__(self, typ: type[Value], operator: Callable[[Value, Value], Value]):
|
||||
super().__init__(typ)
|
||||
self.operator = operator
|
||||
# special forms from typing or collections.abc are not instantiable
|
||||
# so we need to replace them with their concrete counterparts
|
||||
typ = _strip_extras(typ)
|
||||
if typ in (collections.abc.Sequence, collections.abc.MutableSequence):
|
||||
typ = list
|
||||
if typ in (collections.abc.Set, collections.abc.MutableSet):
|
||||
typ = set
|
||||
if typ in (collections.abc.Mapping, collections.abc.MutableMapping):
|
||||
typ = dict
|
||||
try:
|
||||
self.value = typ()
|
||||
except Exception:
|
||||
self.value = MISSING
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, BinaryOperatorAggregate) and (
|
||||
value.operator is self.operator
|
||||
if value.operator.__name__ != "<lambda>"
|
||||
and self.operator.__name__ != "<lambda>"
|
||||
else True
|
||||
)
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.operator)
|
||||
empty.key = self.key
|
||||
empty.value = self.value
|
||||
return empty
|
||||
|
||||
def from_checkpoint(self, checkpoint: Value) -> Self:
|
||||
empty = self.__class__(self.typ, self.operator)
|
||||
empty.key = self.key
|
||||
if checkpoint is not MISSING:
|
||||
empty.value = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value]) -> bool:
|
||||
if not values:
|
||||
return False
|
||||
if self.value is MISSING:
|
||||
self.value = values[0]
|
||||
values = values[1:]
|
||||
seen_overwrite: bool = False
|
||||
for value in values:
|
||||
is_overwrite, overwrite_value = _get_overwrite(value)
|
||||
if is_overwrite:
|
||||
if seen_overwrite:
|
||||
msg = create_error_message(
|
||||
message="Can receive only one Overwrite value per super-step.",
|
||||
error_code=ErrorCode.INVALID_CONCURRENT_GRAPH_UPDATE,
|
||||
)
|
||||
raise InvalidUpdateError(msg)
|
||||
self.value = overwrite_value
|
||||
seen_overwrite = True
|
||||
continue
|
||||
if not seen_overwrite:
|
||||
self.value = self.operator(self.value, value)
|
||||
return True
|
||||
|
||||
def get(self) -> Value:
|
||||
if self.value is MISSING:
|
||||
raise EmptyChannelError()
|
||||
return self.value
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.value is not MISSING
|
||||
|
||||
def checkpoint(self) -> Value:
|
||||
return self.value
|
||||
79
venv/Lib/site-packages/langgraph/channels/ephemeral_value.py
Normal file
79
venv/Lib/site-packages/langgraph/channels/ephemeral_value.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel, Value
|
||||
from langgraph.errors import EmptyChannelError, InvalidUpdateError
|
||||
|
||||
__all__ = ("EphemeralValue",)
|
||||
|
||||
|
||||
class EphemeralValue(Generic[Value], BaseChannel[Value, Value, Value]):
|
||||
"""Stores the value received in the step immediately preceding, clears after."""
|
||||
|
||||
__slots__ = ("value", "guard")
|
||||
|
||||
value: Value | Any
|
||||
guard: bool
|
||||
|
||||
def __init__(self, typ: Any, guard: bool = True) -> None:
|
||||
super().__init__(typ)
|
||||
self.guard = guard
|
||||
self.value = MISSING
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, EphemeralValue) and value.guard == self.guard
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.guard)
|
||||
empty.key = self.key
|
||||
empty.value = self.value
|
||||
return empty
|
||||
|
||||
def from_checkpoint(self, checkpoint: Value) -> Self:
|
||||
empty = self.__class__(self.typ, self.guard)
|
||||
empty.key = self.key
|
||||
if checkpoint is not MISSING:
|
||||
empty.value = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value]) -> bool:
|
||||
if len(values) == 0:
|
||||
if self.value is not MISSING:
|
||||
self.value = MISSING
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
if len(values) != 1 and self.guard:
|
||||
raise InvalidUpdateError(
|
||||
f"At key '{self.key}': EphemeralValue(guard=True) can receive only one value per step. Use guard=False if you want to store any one of multiple values."
|
||||
)
|
||||
|
||||
self.value = values[-1]
|
||||
return True
|
||||
|
||||
def get(self) -> Value:
|
||||
if self.value is MISSING:
|
||||
raise EmptyChannelError()
|
||||
return self.value
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.value is not MISSING
|
||||
|
||||
def checkpoint(self) -> Value:
|
||||
return self.value
|
||||
151
venv/Lib/site-packages/langgraph/channels/last_value.py
Normal file
151
venv/Lib/site-packages/langgraph/channels/last_value.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel, Value
|
||||
from langgraph.errors import (
|
||||
EmptyChannelError,
|
||||
ErrorCode,
|
||||
InvalidUpdateError,
|
||||
create_error_message,
|
||||
)
|
||||
|
||||
__all__ = ("LastValue", "LastValueAfterFinish")
|
||||
|
||||
|
||||
class LastValue(Generic[Value], BaseChannel[Value, Value, Value]):
|
||||
"""Stores the last value received, can receive at most one value per step."""
|
||||
|
||||
__slots__ = ("value",)
|
||||
|
||||
value: Value | Any
|
||||
|
||||
def __init__(self, typ: Any, key: str = "") -> None:
|
||||
super().__init__(typ, key)
|
||||
self.value = MISSING
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, LastValue)
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.key)
|
||||
empty.value = self.value
|
||||
return empty
|
||||
|
||||
def from_checkpoint(self, checkpoint: Value) -> Self:
|
||||
empty = self.__class__(self.typ, self.key)
|
||||
if checkpoint is not MISSING:
|
||||
empty.value = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value]) -> bool:
|
||||
if len(values) == 0:
|
||||
return False
|
||||
if len(values) != 1:
|
||||
msg = create_error_message(
|
||||
message=f"At key '{self.key}': Can receive only one value per step. Use an Annotated key to handle multiple values.",
|
||||
error_code=ErrorCode.INVALID_CONCURRENT_GRAPH_UPDATE,
|
||||
)
|
||||
raise InvalidUpdateError(msg)
|
||||
|
||||
self.value = values[-1]
|
||||
return True
|
||||
|
||||
def get(self) -> Value:
|
||||
if self.value is MISSING:
|
||||
raise EmptyChannelError()
|
||||
return self.value
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.value is not MISSING
|
||||
|
||||
def checkpoint(self) -> Value:
|
||||
return self.value
|
||||
|
||||
|
||||
class LastValueAfterFinish(
|
||||
Generic[Value], BaseChannel[Value, Value, tuple[Value, bool]]
|
||||
):
|
||||
"""Stores the last value received, but only made available after finish().
|
||||
Once made available, clears the value."""
|
||||
|
||||
__slots__ = ("value", "finished")
|
||||
|
||||
value: Value | Any
|
||||
finished: bool
|
||||
|
||||
def __init__(self, typ: Any, key: str = "") -> None:
|
||||
super().__init__(typ, key)
|
||||
self.value = MISSING
|
||||
self.finished = False
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, LastValueAfterFinish)
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def checkpoint(self) -> tuple[Value | Any, bool] | Any:
|
||||
if self.value is MISSING:
|
||||
return MISSING
|
||||
return (self.value, self.finished)
|
||||
|
||||
def from_checkpoint(self, checkpoint: tuple[Value | Any, bool] | Any) -> Self:
|
||||
empty = self.__class__(self.typ)
|
||||
empty.key = self.key
|
||||
if checkpoint is not MISSING:
|
||||
empty.value, empty.finished = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value | Any]) -> bool:
|
||||
if len(values) == 0:
|
||||
return False
|
||||
|
||||
self.finished = False
|
||||
self.value = values[-1]
|
||||
return True
|
||||
|
||||
def consume(self) -> bool:
|
||||
if self.finished:
|
||||
self.finished = False
|
||||
self.value = MISSING
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def finish(self) -> bool:
|
||||
if not self.finished and self.value is not MISSING:
|
||||
self.finished = True
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def get(self) -> Value:
|
||||
if self.value is MISSING or not self.finished:
|
||||
raise EmptyChannelError()
|
||||
return self.value
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.value is not MISSING and self.finished
|
||||
167
venv/Lib/site-packages/langgraph/channels/named_barrier_value.py
Normal file
167
venv/Lib/site-packages/langgraph/channels/named_barrier_value.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Generic
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel, Value
|
||||
from langgraph.errors import EmptyChannelError, InvalidUpdateError
|
||||
|
||||
__all__ = ("NamedBarrierValue", "NamedBarrierValueAfterFinish")
|
||||
|
||||
|
||||
class NamedBarrierValue(Generic[Value], BaseChannel[Value, Value, set[Value]]):
|
||||
"""A channel that waits until all named values are received before making the value available."""
|
||||
|
||||
__slots__ = ("names", "seen")
|
||||
|
||||
names: set[Value]
|
||||
seen: set[Value]
|
||||
|
||||
def __init__(self, typ: type[Value], names: set[Value]) -> None:
|
||||
super().__init__(typ)
|
||||
self.names = names
|
||||
self.seen: set[str] = set()
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, NamedBarrierValue) and value.names == self.names
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.names)
|
||||
empty.key = self.key
|
||||
empty.seen = self.seen.copy()
|
||||
return empty
|
||||
|
||||
def checkpoint(self) -> set[Value]:
|
||||
return self.seen
|
||||
|
||||
def from_checkpoint(self, checkpoint: set[Value]) -> Self:
|
||||
empty = self.__class__(self.typ, self.names)
|
||||
empty.key = self.key
|
||||
if checkpoint is not MISSING:
|
||||
empty.seen = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value]) -> bool:
|
||||
updated = False
|
||||
for value in values:
|
||||
if value in self.names:
|
||||
if value not in self.seen:
|
||||
self.seen.add(value)
|
||||
updated = True
|
||||
else:
|
||||
raise InvalidUpdateError(
|
||||
f"At key '{self.key}': Value {value} not in {self.names}"
|
||||
)
|
||||
return updated
|
||||
|
||||
def get(self) -> Value:
|
||||
if self.seen != self.names:
|
||||
raise EmptyChannelError()
|
||||
return None
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.seen == self.names
|
||||
|
||||
def consume(self) -> bool:
|
||||
if self.seen == self.names:
|
||||
self.seen = set()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class NamedBarrierValueAfterFinish(
|
||||
Generic[Value], BaseChannel[Value, Value, set[Value]]
|
||||
):
|
||||
"""A channel that waits until all named values are received before making the value ready to be made available. It is only made available after finish() is called."""
|
||||
|
||||
__slots__ = ("names", "seen", "finished")
|
||||
|
||||
names: set[Value]
|
||||
seen: set[Value]
|
||||
|
||||
def __init__(self, typ: type[Value], names: set[Value]) -> None:
|
||||
super().__init__(typ)
|
||||
self.names = names
|
||||
self.seen: set[str] = set()
|
||||
self.finished = False
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return (
|
||||
isinstance(value, NamedBarrierValueAfterFinish)
|
||||
and value.names == self.names
|
||||
)
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.names)
|
||||
empty.key = self.key
|
||||
empty.seen = self.seen.copy()
|
||||
empty.finished = self.finished
|
||||
return empty
|
||||
|
||||
def checkpoint(self) -> tuple[set[Value], bool]:
|
||||
return (self.seen, self.finished)
|
||||
|
||||
def from_checkpoint(self, checkpoint: tuple[set[Value], bool]) -> Self:
|
||||
empty = self.__class__(self.typ, self.names)
|
||||
empty.key = self.key
|
||||
if checkpoint is not MISSING:
|
||||
empty.seen, empty.finished = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value]) -> bool:
|
||||
updated = False
|
||||
for value in values:
|
||||
if value in self.names:
|
||||
if value not in self.seen:
|
||||
self.seen.add(value)
|
||||
updated = True
|
||||
else:
|
||||
raise InvalidUpdateError(
|
||||
f"At key '{self.key}': Value {value} not in {self.names}"
|
||||
)
|
||||
return updated
|
||||
|
||||
def get(self) -> Value:
|
||||
if not self.finished or self.seen != self.names:
|
||||
raise EmptyChannelError()
|
||||
return None
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.finished and self.seen == self.names
|
||||
|
||||
def consume(self) -> bool:
|
||||
if self.finished and self.seen == self.names:
|
||||
self.finished = False
|
||||
self.seen = set()
|
||||
return True
|
||||
return False
|
||||
|
||||
def finish(self) -> bool:
|
||||
if not self.finished and self.seen == self.names:
|
||||
self.finished = True
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
94
venv/Lib/site-packages/langgraph/channels/topic.py
Normal file
94
venv/Lib/site-packages/langgraph/channels/topic.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Sequence
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel, Value
|
||||
from langgraph.errors import EmptyChannelError
|
||||
|
||||
__all__ = ("Topic",)
|
||||
|
||||
|
||||
def _flatten(values: Sequence[Value | list[Value]]) -> Iterator[Value]:
|
||||
for value in values:
|
||||
if isinstance(value, list):
|
||||
yield from value
|
||||
else:
|
||||
yield value
|
||||
|
||||
|
||||
class Topic(
|
||||
Generic[Value],
|
||||
BaseChannel[Sequence[Value], Value | list[Value], list[Value]],
|
||||
):
|
||||
"""A configurable PubSub Topic.
|
||||
|
||||
Args:
|
||||
typ: The type of the value stored in the channel.
|
||||
accumulate: Whether to accumulate values across steps. If `False`, the channel will be emptied after each step.
|
||||
"""
|
||||
|
||||
__slots__ = ("values", "accumulate")
|
||||
|
||||
def __init__(self, typ: type[Value], accumulate: bool = False) -> None:
|
||||
super().__init__(typ)
|
||||
# attrs
|
||||
self.accumulate = accumulate
|
||||
# state
|
||||
self.values = list[Value]()
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, Topic) and value.accumulate == self.accumulate
|
||||
|
||||
@property
|
||||
def ValueType(self) -> Any:
|
||||
"""The type of the value stored in the channel."""
|
||||
return Sequence[self.typ] # type: ignore[name-defined]
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> Any:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ | list[self.typ] # type: ignore[name-defined]
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.accumulate)
|
||||
empty.key = self.key
|
||||
empty.values = self.values.copy()
|
||||
return empty
|
||||
|
||||
def checkpoint(self) -> list[Value]:
|
||||
return self.values
|
||||
|
||||
def from_checkpoint(self, checkpoint: list[Value]) -> Self:
|
||||
empty = self.__class__(self.typ, self.accumulate)
|
||||
empty.key = self.key
|
||||
if checkpoint is not MISSING:
|
||||
if isinstance(checkpoint, tuple):
|
||||
# backwards compatibility
|
||||
empty.values = checkpoint[1]
|
||||
else:
|
||||
empty.values = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value | list[Value]]) -> bool:
|
||||
updated = False
|
||||
if not self.accumulate:
|
||||
updated = bool(self.values)
|
||||
self.values = list[Value]()
|
||||
if flat_values := tuple(_flatten(values)):
|
||||
updated = True
|
||||
self.values.extend(flat_values)
|
||||
return updated
|
||||
|
||||
def get(self) -> Sequence[Value]:
|
||||
if self.values:
|
||||
return list(self.values)
|
||||
else:
|
||||
raise EmptyChannelError
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return bool(self.values)
|
||||
73
venv/Lib/site-packages/langgraph/channels/untracked_value.py
Normal file
73
venv/Lib/site-packages/langgraph/channels/untracked_value.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel, Value
|
||||
from langgraph.errors import EmptyChannelError, InvalidUpdateError
|
||||
|
||||
__all__ = ("UntrackedValue",)
|
||||
|
||||
|
||||
class UntrackedValue(Generic[Value], BaseChannel[Value, Value, Value]):
|
||||
"""Stores the last value received, never checkpointed."""
|
||||
|
||||
__slots__ = ("value", "guard")
|
||||
|
||||
guard: bool
|
||||
value: Value | Any
|
||||
|
||||
def __init__(self, typ: type[Value], guard: bool = True) -> None:
|
||||
super().__init__(typ)
|
||||
self.guard = guard
|
||||
self.value = MISSING
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, UntrackedValue) and value.guard == self.guard
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.guard)
|
||||
empty.key = self.key
|
||||
empty.value = self.value
|
||||
return empty
|
||||
|
||||
def checkpoint(self) -> Value | Any:
|
||||
return MISSING
|
||||
|
||||
def from_checkpoint(self, checkpoint: Value) -> Self:
|
||||
empty = self.__class__(self.typ, self.guard)
|
||||
empty.key = self.key
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value]) -> bool:
|
||||
if len(values) == 0:
|
||||
return False
|
||||
if len(values) != 1 and self.guard:
|
||||
raise InvalidUpdateError(
|
||||
f"At key '{self.key}': UntrackedValue(guard=True) can receive only one value per step. Use guard=False if you want to store any one of multiple values."
|
||||
)
|
||||
|
||||
self.value = values[-1]
|
||||
return True
|
||||
|
||||
def get(self) -> Value:
|
||||
if self.value is MISSING:
|
||||
raise EmptyChannelError()
|
||||
return self.value
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.value is not MISSING
|
||||
628
venv/Lib/site-packages/langgraph/checkpoint/base/__init__.py
Normal file
628
venv/Lib/site-packages/langgraph/checkpoint/base/__init__.py
Normal file
@@ -0,0 +1,628 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from collections.abc import AsyncIterator, Collection, Iterator, Mapping, Sequence
|
||||
from typing import ( # noqa: UP035
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from langgraph.checkpoint.base.id import uuid6
|
||||
from langgraph.checkpoint.serde.base import SerializerProtocol, maybe_add_typed_methods
|
||||
from langgraph.checkpoint.serde.encrypted import EncryptedSerializer
|
||||
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
||||
from langgraph.checkpoint.serde.types import (
|
||||
ERROR,
|
||||
INTERRUPT,
|
||||
RESUME,
|
||||
SCHEDULED,
|
||||
ChannelProtocol,
|
||||
)
|
||||
|
||||
V = TypeVar("V", int, float, str)
|
||||
PendingWrite = tuple[str, str, Any]
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Marked as total=False to allow for future expansion.
|
||||
class CheckpointMetadata(TypedDict, total=False):
|
||||
"""Metadata associated with a checkpoint."""
|
||||
|
||||
source: Literal["input", "loop", "update", "fork"]
|
||||
"""The source of the checkpoint.
|
||||
|
||||
- `"input"`: The checkpoint was created from an input to invoke/stream/batch.
|
||||
- `"loop"`: The checkpoint was created from inside the pregel loop.
|
||||
- `"update"`: The checkpoint was created from a manual state update.
|
||||
- `"fork"`: The checkpoint was created as a copy of another checkpoint.
|
||||
"""
|
||||
step: int
|
||||
"""The step number of the checkpoint.
|
||||
|
||||
`-1` for the first `"input"` checkpoint.
|
||||
`0` for the first `"loop"` checkpoint.
|
||||
`...` for the `nth` checkpoint afterwards.
|
||||
"""
|
||||
parents: dict[str, str]
|
||||
"""The IDs of the parent checkpoints.
|
||||
|
||||
Mapping from checkpoint namespace to checkpoint ID.
|
||||
"""
|
||||
run_id: str
|
||||
"""The ID of the run that created this checkpoint."""
|
||||
|
||||
|
||||
ChannelVersions = dict[str, str | int | float]
|
||||
|
||||
|
||||
class Checkpoint(TypedDict):
|
||||
"""State snapshot at a given point in time."""
|
||||
|
||||
v: int
|
||||
"""The version of the checkpoint format. Currently `1`."""
|
||||
id: str
|
||||
"""The ID of the checkpoint.
|
||||
|
||||
This is both unique and monotonically increasing, so can be used for sorting
|
||||
checkpoints from first to last."""
|
||||
ts: str
|
||||
"""The timestamp of the checkpoint in ISO 8601 format."""
|
||||
channel_values: dict[str, Any]
|
||||
"""The values of the channels at the time of the checkpoint.
|
||||
|
||||
Mapping from channel name to deserialized channel snapshot value.
|
||||
"""
|
||||
channel_versions: ChannelVersions
|
||||
"""The versions of the channels at the time of the checkpoint.
|
||||
|
||||
The keys are channel names and the values are monotonically increasing
|
||||
version strings for each channel.
|
||||
"""
|
||||
versions_seen: dict[str, ChannelVersions]
|
||||
"""Map from node ID to map from channel name to version seen.
|
||||
|
||||
This keeps track of the versions of the channels that each node has seen.
|
||||
Used to determine which nodes to execute next.
|
||||
"""
|
||||
updated_channels: list[str] | None
|
||||
"""The channels that were updated in this checkpoint.
|
||||
"""
|
||||
|
||||
|
||||
def copy_checkpoint(checkpoint: Checkpoint) -> Checkpoint:
|
||||
return Checkpoint(
|
||||
v=checkpoint["v"],
|
||||
ts=checkpoint["ts"],
|
||||
id=checkpoint["id"],
|
||||
channel_values=checkpoint["channel_values"].copy(),
|
||||
channel_versions=checkpoint["channel_versions"].copy(),
|
||||
versions_seen={k: v.copy() for k, v in checkpoint["versions_seen"].items()},
|
||||
pending_sends=checkpoint.get("pending_sends", []).copy(),
|
||||
updated_channels=checkpoint.get("updated_channels", None),
|
||||
)
|
||||
|
||||
|
||||
class CheckpointTuple(NamedTuple):
|
||||
"""A tuple containing a checkpoint and its associated data."""
|
||||
|
||||
config: RunnableConfig
|
||||
checkpoint: Checkpoint
|
||||
metadata: CheckpointMetadata
|
||||
parent_config: RunnableConfig | None = None
|
||||
pending_writes: list[PendingWrite] | None = None
|
||||
|
||||
|
||||
class BaseCheckpointSaver(Generic[V]):
|
||||
"""Base class for creating a graph checkpointer.
|
||||
|
||||
Checkpointers allow LangGraph agents to persist their state
|
||||
within and across multiple interactions.
|
||||
|
||||
When a checkpointer is configured, you should pass a `thread_id` in the config when
|
||||
invoking the graph:
|
||||
|
||||
```python
|
||||
config = {"configurable": {"thread_id": "my-thread"}}
|
||||
graph.invoke(inputs, config)
|
||||
```
|
||||
|
||||
The `thread_id` is the primary key used to store and retrieve checkpoints. Without
|
||||
it, the checkpointer cannot save state, resume from interrupts, or enable
|
||||
time-travel debugging.
|
||||
|
||||
How you choose ``thread_id`` depends on your use case:
|
||||
|
||||
- **Single-shot workflows**: Use a unique ID (e.g., uuid4) for each run when
|
||||
executions are independent.
|
||||
- **Conversational memory**: Reuse the same `thread_id` across invocations
|
||||
to accumulate state (e.g., chat history) within a conversation.
|
||||
|
||||
Attributes:
|
||||
serde (SerializerProtocol): Serializer for encoding/decoding checkpoints.
|
||||
|
||||
Note:
|
||||
When creating a custom checkpoint saver, consider implementing async
|
||||
versions to avoid blocking the main thread.
|
||||
"""
|
||||
|
||||
serde: SerializerProtocol = JsonPlusSerializer()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
serde: SerializerProtocol | None = None,
|
||||
) -> None:
|
||||
self.serde = maybe_add_typed_methods(serde or self.serde)
|
||||
|
||||
@property
|
||||
def config_specs(self) -> list:
|
||||
"""Define the configuration options for the checkpoint saver.
|
||||
|
||||
Returns:
|
||||
list: List of configuration field specs.
|
||||
"""
|
||||
return []
|
||||
|
||||
def get(self, config: RunnableConfig) -> Checkpoint | None:
|
||||
"""Fetch a checkpoint using the given configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration specifying which checkpoint to retrieve.
|
||||
|
||||
Returns:
|
||||
The requested checkpoint, or `None` if not found.
|
||||
"""
|
||||
if value := self.get_tuple(config):
|
||||
return value.checkpoint
|
||||
|
||||
def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
|
||||
"""Fetch a checkpoint tuple using the given configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration specifying which checkpoint to retrieve.
|
||||
|
||||
Returns:
|
||||
The requested checkpoint tuple, or `None` if not found.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def list(
|
||||
self,
|
||||
config: RunnableConfig | None,
|
||||
*,
|
||||
filter: dict[str, Any] | None = None,
|
||||
before: RunnableConfig | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[CheckpointTuple]:
|
||||
"""List checkpoints that match the given criteria.
|
||||
|
||||
Args:
|
||||
config: Base configuration for filtering checkpoints.
|
||||
filter: Additional filtering criteria.
|
||||
before: List checkpoints created before this configuration.
|
||||
limit: Maximum number of checkpoints to return.
|
||||
|
||||
Returns:
|
||||
Iterator of matching checkpoint tuples.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def put(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
checkpoint: Checkpoint,
|
||||
metadata: CheckpointMetadata,
|
||||
new_versions: ChannelVersions,
|
||||
) -> RunnableConfig:
|
||||
"""Store a checkpoint with its configuration and metadata.
|
||||
|
||||
Args:
|
||||
config: Configuration for the checkpoint.
|
||||
checkpoint: The checkpoint to store.
|
||||
metadata: Additional metadata for the checkpoint.
|
||||
new_versions: New channel versions as of this write.
|
||||
|
||||
Returns:
|
||||
RunnableConfig: Updated configuration after storing the checkpoint.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def put_writes(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
writes: Sequence[tuple[str, Any]],
|
||||
task_id: str,
|
||||
task_path: str = "",
|
||||
) -> None:
|
||||
"""Store intermediate writes linked to a checkpoint.
|
||||
|
||||
Args:
|
||||
config: Configuration of the related checkpoint.
|
||||
writes: List of writes to store.
|
||||
task_id: Identifier for the task creating the writes.
|
||||
task_path: Path of the task creating the writes.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def delete_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
) -> None:
|
||||
"""Delete all checkpoints and writes associated with a specific thread ID.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID whose checkpoints should be deleted.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def delete_for_runs(
|
||||
self,
|
||||
run_ids: Sequence[str],
|
||||
) -> None:
|
||||
"""Delete all checkpoints and writes associated with the given run IDs.
|
||||
|
||||
Args:
|
||||
run_ids: The run IDs whose checkpoints should be deleted.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def copy_thread(
|
||||
self,
|
||||
source_thread_id: str,
|
||||
target_thread_id: str,
|
||||
) -> None:
|
||||
"""Copy all checkpoints and writes from one thread to another.
|
||||
|
||||
Args:
|
||||
source_thread_id: The thread ID to copy from.
|
||||
target_thread_id: The thread ID to copy to.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def prune(
|
||||
self,
|
||||
thread_ids: Sequence[str],
|
||||
*,
|
||||
strategy: str = "keep_latest",
|
||||
) -> None:
|
||||
"""Prune checkpoints for the given threads.
|
||||
|
||||
Args:
|
||||
thread_ids: The thread IDs to prune.
|
||||
strategy: The pruning strategy. `"keep_latest"` retains only the most
|
||||
recent checkpoint per namespace. `"delete"` removes all checkpoints.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def aget(self, config: RunnableConfig) -> Checkpoint | None:
|
||||
"""Asynchronously fetch a checkpoint using the given configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration specifying which checkpoint to retrieve.
|
||||
|
||||
Returns:
|
||||
The requested checkpoint, or `None` if not found.
|
||||
"""
|
||||
if value := await self.aget_tuple(config):
|
||||
return value.checkpoint
|
||||
|
||||
async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
|
||||
"""Asynchronously fetch a checkpoint tuple using the given configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration specifying which checkpoint to retrieve.
|
||||
|
||||
Returns:
|
||||
The requested checkpoint tuple, or `None` if not found.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def alist(
|
||||
self,
|
||||
config: RunnableConfig | None,
|
||||
*,
|
||||
filter: dict[str, Any] | None = None,
|
||||
before: RunnableConfig | None = None,
|
||||
limit: int | None = None,
|
||||
) -> AsyncIterator[CheckpointTuple]:
|
||||
"""Asynchronously list checkpoints that match the given criteria.
|
||||
|
||||
Args:
|
||||
config: Base configuration for filtering checkpoints.
|
||||
filter: Additional filtering criteria for metadata.
|
||||
before: List checkpoints created before this configuration.
|
||||
limit: Maximum number of checkpoints to return.
|
||||
|
||||
Returns:
|
||||
Async iterator of matching checkpoint tuples.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
yield
|
||||
|
||||
async def aput(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
checkpoint: Checkpoint,
|
||||
metadata: CheckpointMetadata,
|
||||
new_versions: ChannelVersions,
|
||||
) -> RunnableConfig:
|
||||
"""Asynchronously store a checkpoint with its configuration and metadata.
|
||||
|
||||
Args:
|
||||
config: Configuration for the checkpoint.
|
||||
checkpoint: The checkpoint to store.
|
||||
metadata: Additional metadata for the checkpoint.
|
||||
new_versions: New channel versions as of this write.
|
||||
|
||||
Returns:
|
||||
RunnableConfig: Updated configuration after storing the checkpoint.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def aput_writes(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
writes: Sequence[tuple[str, Any]],
|
||||
task_id: str,
|
||||
task_path: str = "",
|
||||
) -> None:
|
||||
"""Asynchronously store intermediate writes linked to a checkpoint.
|
||||
|
||||
Args:
|
||||
config: Configuration of the related checkpoint.
|
||||
writes: List of writes to store.
|
||||
task_id: Identifier for the task creating the writes.
|
||||
task_path: Path of the task creating the writes.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Implement this method in your custom checkpoint saver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def adelete_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
) -> None:
|
||||
"""Delete all checkpoints and writes associated with a specific thread ID.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID whose checkpoints should be deleted.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def adelete_for_runs(
|
||||
self,
|
||||
run_ids: Sequence[str],
|
||||
) -> None:
|
||||
"""Asynchronously delete all checkpoints and writes for the given run IDs.
|
||||
|
||||
Args:
|
||||
run_ids: The run IDs whose checkpoints should be deleted.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def acopy_thread(
|
||||
self,
|
||||
source_thread_id: str,
|
||||
target_thread_id: str,
|
||||
) -> None:
|
||||
"""Asynchronously copy all checkpoints and writes from one thread to another.
|
||||
|
||||
Args:
|
||||
source_thread_id: The thread ID to copy from.
|
||||
target_thread_id: The thread ID to copy to.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def aprune(
|
||||
self,
|
||||
thread_ids: Sequence[str],
|
||||
*,
|
||||
strategy: str = "keep_latest",
|
||||
) -> None:
|
||||
"""Asynchronously prune checkpoints for the given threads.
|
||||
|
||||
Args:
|
||||
thread_ids: The thread IDs to prune.
|
||||
strategy: The pruning strategy. `"keep_latest"` retains only the most
|
||||
recent checkpoint per namespace. `"delete"` removes all checkpoints.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_next_version(self, current: V | None, channel: None) -> V:
|
||||
"""Generate the next version ID for a channel.
|
||||
|
||||
Default is to use integer versions, incrementing by `1`.
|
||||
|
||||
If you override, you can use `str`/`int`/`float` versions, as long as they are monotonically increasing.
|
||||
|
||||
Args:
|
||||
current: The current version identifier (`int`, `float`, or `str`).
|
||||
channel: Deprecated argument, kept for backwards compatibility.
|
||||
|
||||
Returns:
|
||||
V: The next version identifier, which must be increasing.
|
||||
"""
|
||||
if isinstance(current, str):
|
||||
raise NotImplementedError
|
||||
elif current is None:
|
||||
return 1
|
||||
else:
|
||||
return current + 1
|
||||
|
||||
def with_allowlist(
|
||||
self, extra_allowlist: Collection[tuple[str, ...]]
|
||||
) -> BaseCheckpointSaver[V]:
|
||||
"""Return a shallow clone with a derived msgpack allowlist."""
|
||||
serde = _with_msgpack_allowlist(self.serde, extra_allowlist)
|
||||
if serde is self.serde:
|
||||
return self
|
||||
clone = copy.copy(self)
|
||||
clone.serde = maybe_add_typed_methods(serde)
|
||||
return clone
|
||||
|
||||
|
||||
def _with_msgpack_allowlist(
|
||||
serde: SerializerProtocol, extra_allowlist: Collection[tuple[str, ...]]
|
||||
) -> SerializerProtocol:
|
||||
if isinstance(serde, JsonPlusSerializer):
|
||||
return serde.with_msgpack_allowlist(extra_allowlist)
|
||||
if isinstance(serde, EncryptedSerializer):
|
||||
inner = serde.serde
|
||||
if isinstance(inner, JsonPlusSerializer):
|
||||
updated_inner = inner.with_msgpack_allowlist(extra_allowlist)
|
||||
if updated_inner is inner:
|
||||
return serde
|
||||
return EncryptedSerializer(serde.cipher, updated_inner)
|
||||
logger.warning(
|
||||
"Serializer %s does not support msgpack allowlist. "
|
||||
"Strict msgpack deserialization will not be enforced.",
|
||||
type(serde).__name__,
|
||||
)
|
||||
return serde
|
||||
|
||||
|
||||
class EmptyChannelError(Exception):
|
||||
"""Raised when attempting to get the value of a channel that hasn't been updated
|
||||
for the first time yet."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def get_checkpoint_id(config: RunnableConfig) -> str | None:
|
||||
"""Get checkpoint ID."""
|
||||
return config["configurable"].get("checkpoint_id")
|
||||
|
||||
|
||||
def get_checkpoint_metadata(
|
||||
config: RunnableConfig, metadata: CheckpointMetadata
|
||||
) -> CheckpointMetadata:
|
||||
"""Get checkpoint metadata in a backwards-compatible manner."""
|
||||
metadata = {
|
||||
k: v.replace("\u0000", "") if isinstance(v, str) else v
|
||||
for k, v in metadata.items()
|
||||
}
|
||||
for obj in (config.get("metadata"), config.get("configurable")):
|
||||
if not obj:
|
||||
continue
|
||||
for key, v in obj.items():
|
||||
if key in metadata or key in EXCLUDED_METADATA_KEYS or key.startswith("__"):
|
||||
continue
|
||||
elif isinstance(v, str):
|
||||
metadata[key] = v.replace("\u0000", "")
|
||||
elif isinstance(v, (int, bool, float)):
|
||||
metadata[key] = v
|
||||
return metadata
|
||||
|
||||
|
||||
def get_serializable_checkpoint_metadata(
|
||||
config: RunnableConfig, metadata: CheckpointMetadata
|
||||
) -> CheckpointMetadata:
|
||||
"""Get checkpoint metadata in a backwards-compatible manner."""
|
||||
checkpoint_metadata = get_checkpoint_metadata(config, metadata)
|
||||
if "writes" in checkpoint_metadata:
|
||||
checkpoint_metadata.pop("writes")
|
||||
return checkpoint_metadata
|
||||
|
||||
|
||||
"""
|
||||
Mapping from error type to error index.
|
||||
Regular writes just map to their index in the list of writes being saved.
|
||||
Special writes (e.g. errors) map to negative indices, to avoid those writes from
|
||||
conflicting with regular writes.
|
||||
Each Checkpointer implementation should use this mapping in put_writes.
|
||||
"""
|
||||
WRITES_IDX_MAP = {ERROR: -1, SCHEDULED: -2, INTERRUPT: -3, RESUME: -4}
|
||||
|
||||
EXCLUDED_METADATA_KEYS = {
|
||||
"thread_id",
|
||||
"checkpoint_id",
|
||||
"checkpoint_ns",
|
||||
"checkpoint_map",
|
||||
"langgraph_step",
|
||||
"langgraph_node",
|
||||
"langgraph_triggers",
|
||||
"langgraph_path",
|
||||
"langgraph_checkpoint_ns",
|
||||
}
|
||||
|
||||
# --- below are deprecated utilities used by past versions of LangGraph ---
|
||||
|
||||
LATEST_VERSION = 2
|
||||
|
||||
|
||||
def empty_checkpoint() -> Checkpoint:
|
||||
from datetime import datetime, timezone
|
||||
|
||||
return Checkpoint(
|
||||
v=LATEST_VERSION,
|
||||
id=str(uuid6(clock_seq=-2)),
|
||||
ts=datetime.now(timezone.utc).isoformat(),
|
||||
channel_values={},
|
||||
channel_versions={},
|
||||
versions_seen={},
|
||||
pending_sends=[],
|
||||
updated_channels=None,
|
||||
)
|
||||
|
||||
|
||||
def create_checkpoint(
|
||||
checkpoint: Checkpoint,
|
||||
channels: Mapping[str, ChannelProtocol] | None,
|
||||
step: int,
|
||||
*,
|
||||
id: str | None = None,
|
||||
) -> Checkpoint:
|
||||
"""Create a checkpoint for the given channels."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
ts = datetime.now(timezone.utc).isoformat()
|
||||
if channels is None:
|
||||
values = checkpoint["channel_values"]
|
||||
else:
|
||||
values = {}
|
||||
for k, v in channels.items():
|
||||
if k not in checkpoint["channel_versions"]:
|
||||
continue
|
||||
try:
|
||||
values[k] = v.checkpoint()
|
||||
except EmptyChannelError:
|
||||
pass
|
||||
return Checkpoint(
|
||||
v=LATEST_VERSION,
|
||||
ts=ts,
|
||||
id=id or str(uuid6(clock_seq=step)),
|
||||
channel_values=values,
|
||||
channel_versions=checkpoint["channel_versions"],
|
||||
versions_seen=checkpoint["versions_seen"],
|
||||
pending_sends=checkpoint.get("pending_sends", []),
|
||||
updated_channels=None,
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
109
venv/Lib/site-packages/langgraph/checkpoint/base/id.py
Normal file
109
venv/Lib/site-packages/langgraph/checkpoint/base/id.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Adapted from
|
||||
https://github.com/oittaa/uuid6-python/blob/main/src/uuid6/__init__.py#L95
|
||||
Bundled in to avoid install issues with uuid6 package
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
|
||||
_last_v6_timestamp = None
|
||||
|
||||
|
||||
class UUID(uuid.UUID):
|
||||
r"""UUID draft version objects"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hex: str | None = None,
|
||||
bytes: bytes | None = None,
|
||||
bytes_le: bytes | None = None,
|
||||
fields: tuple[int, int, int, int, int, int] | None = None,
|
||||
int: int | None = None,
|
||||
version: int | None = None,
|
||||
*,
|
||||
is_safe: uuid.SafeUUID = uuid.SafeUUID.unknown,
|
||||
) -> None:
|
||||
r"""Create a UUID."""
|
||||
|
||||
if int is None or [hex, bytes, bytes_le, fields].count(None) != 4:
|
||||
return super().__init__(
|
||||
hex=hex,
|
||||
bytes=bytes,
|
||||
bytes_le=bytes_le,
|
||||
fields=fields,
|
||||
int=int,
|
||||
version=version,
|
||||
is_safe=is_safe,
|
||||
)
|
||||
if not 0 <= int < 1 << 128:
|
||||
raise ValueError("int is out of range (need a 128-bit value)")
|
||||
if version is not None:
|
||||
if not 6 <= version <= 8:
|
||||
raise ValueError("illegal version number")
|
||||
# Set the variant to RFC 4122.
|
||||
int &= ~(0xC000 << 48)
|
||||
int |= 0x8000 << 48
|
||||
# Set the version number.
|
||||
int &= ~(0xF000 << 64)
|
||||
int |= version << 76
|
||||
super().__init__(int=int, is_safe=is_safe)
|
||||
|
||||
@property
|
||||
def subsec(self) -> int:
|
||||
return ((self.int >> 64) & 0x0FFF) << 8 | ((self.int >> 54) & 0xFF)
|
||||
|
||||
@property
|
||||
def time(self) -> int:
|
||||
if self.version == 6:
|
||||
return (
|
||||
(self.time_low << 28)
|
||||
| (self.time_mid << 12)
|
||||
| (self.time_hi_version & 0x0FFF)
|
||||
)
|
||||
if self.version == 7:
|
||||
return self.int >> 80
|
||||
if self.version == 8:
|
||||
return (self.int >> 80) * 10**6 + _subsec_decode(self.subsec)
|
||||
return super().time
|
||||
|
||||
|
||||
def _subsec_decode(value: int) -> int:
|
||||
return -(-value * 10**6 // 2**20)
|
||||
|
||||
|
||||
def uuid6(node: int | None = None, clock_seq: int | None = None) -> UUID:
|
||||
r"""UUID version 6 is a field-compatible version of UUIDv1, reordered for
|
||||
improved DB locality. It is expected that UUIDv6 will primarily be
|
||||
used in contexts where there are existing v1 UUIDs. Systems that do
|
||||
not involve legacy UUIDv1 SHOULD consider using UUIDv7 instead.
|
||||
|
||||
If 'node' is not given, a random 48-bit number is chosen.
|
||||
|
||||
If 'clock_seq' is given, it is used as the sequence number;
|
||||
otherwise a random 14-bit sequence number is chosen."""
|
||||
|
||||
global _last_v6_timestamp
|
||||
|
||||
nanoseconds = time.time_ns()
|
||||
# 0x01b21dd213814000 is the number of 100-ns intervals between the
|
||||
# UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00.
|
||||
timestamp = nanoseconds // 100 + 0x01B21DD213814000
|
||||
if _last_v6_timestamp is not None and timestamp <= _last_v6_timestamp:
|
||||
timestamp = _last_v6_timestamp + 1
|
||||
_last_v6_timestamp = timestamp
|
||||
if clock_seq is None:
|
||||
clock_seq = random.getrandbits(14) # instead of stable storage
|
||||
if node is None:
|
||||
node = random.getrandbits(48)
|
||||
time_high_and_time_mid = (timestamp >> 12) & 0xFFFFFFFFFFFF
|
||||
time_low_and_version = timestamp & 0x0FFF
|
||||
uuid_int = time_high_and_time_mid << 80
|
||||
uuid_int |= time_low_and_version << 64
|
||||
uuid_int |= (clock_seq & 0x3FFF) << 48
|
||||
uuid_int |= node & 0xFFFFFFFFFFFF
|
||||
return UUID(int=uuid_int, version=6)
|
||||
603
venv/Lib/site-packages/langgraph/checkpoint/memory/__init__.py
Normal file
603
venv/Lib/site-packages/langgraph/checkpoint/memory/__init__.py
Normal file
@@ -0,0 +1,603 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from contextlib import AbstractAsyncContextManager, AbstractContextManager, ExitStack
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from langgraph.checkpoint.base import (
|
||||
WRITES_IDX_MAP,
|
||||
BaseCheckpointSaver,
|
||||
ChannelVersions,
|
||||
Checkpoint,
|
||||
CheckpointMetadata,
|
||||
CheckpointTuple,
|
||||
SerializerProtocol,
|
||||
get_checkpoint_id,
|
||||
get_checkpoint_metadata,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InMemorySaver(
|
||||
BaseCheckpointSaver[str], AbstractContextManager, AbstractAsyncContextManager
|
||||
):
|
||||
"""An in-memory checkpoint saver.
|
||||
|
||||
This checkpoint saver stores checkpoints in memory using a `defaultdict`.
|
||||
|
||||
Note:
|
||||
Only use `InMemorySaver` for debugging or testing purposes.
|
||||
For production use cases we recommend installing [langgraph-checkpoint-postgres](https://pypi.org/project/langgraph-checkpoint-postgres/) and using `PostgresSaver` / `AsyncPostgresSaver`.
|
||||
|
||||
If you are using LangSmith Deployment, no checkpointer needs to be specified. The correct managed checkpointer will be used automatically.
|
||||
|
||||
Args:
|
||||
serde: The serializer to use for serializing and deserializing checkpoints.
|
||||
|
||||
Example:
|
||||
```python
|
||||
import asyncio
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
builder = StateGraph(int)
|
||||
builder.add_node("add_one", lambda x: x + 1)
|
||||
builder.set_entry_point("add_one")
|
||||
builder.set_finish_point("add_one")
|
||||
|
||||
memory = InMemorySaver()
|
||||
graph = builder.compile(checkpointer=memory)
|
||||
coro = graph.ainvoke(1, {"configurable": {"thread_id": "thread-1"}})
|
||||
asyncio.run(coro) # Output: 2
|
||||
```
|
||||
"""
|
||||
|
||||
# thread ID -> checkpoint NS -> checkpoint ID -> checkpoint mapping
|
||||
storage: defaultdict[
|
||||
str,
|
||||
dict[str, dict[str, tuple[tuple[str, bytes], tuple[str, bytes], str | None]]],
|
||||
]
|
||||
# (thread ID, checkpoint NS, checkpoint ID) -> (task ID, write idx)
|
||||
writes: defaultdict[
|
||||
tuple[str, str, str],
|
||||
dict[tuple[str, int], tuple[str, str, tuple[str, bytes], str]],
|
||||
]
|
||||
blobs: dict[
|
||||
tuple[
|
||||
str, str, str, str | int | float
|
||||
], # thread id, checkpoint ns, channel, version
|
||||
tuple[str, bytes],
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
serde: SerializerProtocol | None = None,
|
||||
factory: type[defaultdict] = defaultdict,
|
||||
) -> None:
|
||||
super().__init__(serde=serde)
|
||||
self.storage = factory(lambda: defaultdict(dict))
|
||||
self.writes = factory(dict)
|
||||
self.blobs = factory()
|
||||
self.stack = ExitStack()
|
||||
if factory is not defaultdict:
|
||||
self.stack.enter_context(self.storage) # type: ignore[arg-type]
|
||||
self.stack.enter_context(self.writes) # type: ignore[arg-type]
|
||||
self.stack.enter_context(self.blobs) # type: ignore[arg-type]
|
||||
|
||||
def __enter__(self) -> InMemorySaver:
|
||||
self.stack.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> bool | None:
|
||||
return self.stack.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
async def __aenter__(self) -> InMemorySaver:
|
||||
self.stack.__enter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
__exc_type: type[BaseException] | None,
|
||||
__exc_value: BaseException | None,
|
||||
__traceback: TracebackType | None,
|
||||
) -> bool | None:
|
||||
return self.stack.__exit__(__exc_type, __exc_value, __traceback)
|
||||
|
||||
def _load_blobs(
|
||||
self, thread_id: str, checkpoint_ns: str, versions: ChannelVersions
|
||||
) -> dict[str, Any]:
|
||||
channel_values: dict[str, Any] = {}
|
||||
for k, v in versions.items():
|
||||
kk = (thread_id, checkpoint_ns, k, v)
|
||||
if kk in self.blobs:
|
||||
vv = self.blobs[kk]
|
||||
if vv[0] != "empty":
|
||||
channel_values[k] = self.serde.loads_typed(vv)
|
||||
return channel_values
|
||||
|
||||
def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
|
||||
"""Get a checkpoint tuple from the in-memory storage.
|
||||
|
||||
This method retrieves a checkpoint tuple from the in-memory storage based on the
|
||||
provided config. If the config contains a `checkpoint_id` key, the checkpoint with
|
||||
the matching thread ID and timestamp is retrieved. Otherwise, the latest checkpoint
|
||||
for the given thread ID is retrieved.
|
||||
|
||||
Args:
|
||||
config: The config to use for retrieving the checkpoint.
|
||||
|
||||
Returns:
|
||||
The retrieved checkpoint tuple, or None if no matching checkpoint was found.
|
||||
"""
|
||||
thread_id: str = config["configurable"]["thread_id"]
|
||||
checkpoint_ns: str = config["configurable"].get("checkpoint_ns", "")
|
||||
if checkpoint_id := get_checkpoint_id(config):
|
||||
if saved := self.storage[thread_id][checkpoint_ns].get(checkpoint_id):
|
||||
checkpoint, metadata, parent_checkpoint_id = saved
|
||||
writes = self.writes[(thread_id, checkpoint_ns, checkpoint_id)].values()
|
||||
checkpoint_: Checkpoint = self.serde.loads_typed(checkpoint)
|
||||
return CheckpointTuple(
|
||||
config=config,
|
||||
checkpoint={
|
||||
**checkpoint_,
|
||||
"channel_values": self._load_blobs(
|
||||
thread_id, checkpoint_ns, checkpoint_["channel_versions"]
|
||||
),
|
||||
},
|
||||
metadata=self.serde.loads_typed(metadata),
|
||||
pending_writes=[
|
||||
(id, c, self.serde.loads_typed(v)) for id, c, v, _ in writes
|
||||
],
|
||||
parent_config=(
|
||||
{
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": checkpoint_ns,
|
||||
"checkpoint_id": parent_checkpoint_id,
|
||||
}
|
||||
}
|
||||
if parent_checkpoint_id
|
||||
else None
|
||||
),
|
||||
)
|
||||
else:
|
||||
if checkpoints := self.storage[thread_id][checkpoint_ns]:
|
||||
checkpoint_id = max(checkpoints.keys())
|
||||
checkpoint, metadata, parent_checkpoint_id = checkpoints[checkpoint_id]
|
||||
writes = self.writes[(thread_id, checkpoint_ns, checkpoint_id)].values()
|
||||
checkpoint_ = self.serde.loads_typed(checkpoint)
|
||||
return CheckpointTuple(
|
||||
config={
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": checkpoint_ns,
|
||||
"checkpoint_id": checkpoint_id,
|
||||
}
|
||||
},
|
||||
checkpoint={
|
||||
**checkpoint_,
|
||||
"channel_values": self._load_blobs(
|
||||
thread_id, checkpoint_ns, checkpoint_["channel_versions"]
|
||||
),
|
||||
},
|
||||
metadata=self.serde.loads_typed(metadata),
|
||||
pending_writes=[
|
||||
(id, c, self.serde.loads_typed(v)) for id, c, v, _ in writes
|
||||
],
|
||||
parent_config=(
|
||||
{
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": checkpoint_ns,
|
||||
"checkpoint_id": parent_checkpoint_id,
|
||||
}
|
||||
}
|
||||
if parent_checkpoint_id
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
def list(
|
||||
self,
|
||||
config: RunnableConfig | None,
|
||||
*,
|
||||
filter: dict[str, Any] | None = None,
|
||||
before: RunnableConfig | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[CheckpointTuple]:
|
||||
"""List checkpoints from the in-memory storage.
|
||||
|
||||
This method retrieves a list of checkpoint tuples from the in-memory storage based
|
||||
on the provided criteria.
|
||||
|
||||
Args:
|
||||
config: Base configuration for filtering checkpoints.
|
||||
filter: Additional filtering criteria for metadata.
|
||||
before: List checkpoints created before this configuration.
|
||||
limit: Maximum number of checkpoints to return.
|
||||
|
||||
Yields:
|
||||
An iterator of matching checkpoint tuples.
|
||||
"""
|
||||
thread_ids = (config["configurable"]["thread_id"],) if config else self.storage
|
||||
config_checkpoint_ns = (
|
||||
config["configurable"].get("checkpoint_ns") if config else None
|
||||
)
|
||||
config_checkpoint_id = get_checkpoint_id(config) if config else None
|
||||
for thread_id in thread_ids:
|
||||
for checkpoint_ns in self.storage[thread_id].keys():
|
||||
if (
|
||||
config_checkpoint_ns is not None
|
||||
and checkpoint_ns != config_checkpoint_ns
|
||||
):
|
||||
continue
|
||||
|
||||
for checkpoint_id, (
|
||||
checkpoint,
|
||||
metadata_b,
|
||||
parent_checkpoint_id,
|
||||
) in sorted(
|
||||
self.storage[thread_id][checkpoint_ns].items(),
|
||||
key=lambda x: x[0],
|
||||
reverse=True,
|
||||
):
|
||||
# filter by checkpoint ID from config
|
||||
if config_checkpoint_id and checkpoint_id != config_checkpoint_id:
|
||||
continue
|
||||
|
||||
# filter by checkpoint ID from `before` config
|
||||
if (
|
||||
before
|
||||
and (before_checkpoint_id := get_checkpoint_id(before))
|
||||
and checkpoint_id >= before_checkpoint_id
|
||||
):
|
||||
continue
|
||||
|
||||
# filter by metadata
|
||||
metadata = self.serde.loads_typed(metadata_b)
|
||||
if filter and not all(
|
||||
query_value == metadata.get(query_key)
|
||||
for query_key, query_value in filter.items()
|
||||
):
|
||||
continue
|
||||
|
||||
# limit search results
|
||||
if limit is not None and limit <= 0:
|
||||
break
|
||||
elif limit is not None:
|
||||
limit -= 1
|
||||
|
||||
writes = self.writes[
|
||||
(thread_id, checkpoint_ns, checkpoint_id)
|
||||
].values()
|
||||
|
||||
checkpoint_: Checkpoint = self.serde.loads_typed(checkpoint)
|
||||
|
||||
yield CheckpointTuple(
|
||||
config={
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": checkpoint_ns,
|
||||
"checkpoint_id": checkpoint_id,
|
||||
}
|
||||
},
|
||||
checkpoint={
|
||||
**checkpoint_,
|
||||
"channel_values": self._load_blobs(
|
||||
thread_id,
|
||||
checkpoint_ns,
|
||||
checkpoint_["channel_versions"],
|
||||
),
|
||||
},
|
||||
metadata=metadata,
|
||||
parent_config=(
|
||||
{
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": checkpoint_ns,
|
||||
"checkpoint_id": parent_checkpoint_id,
|
||||
}
|
||||
}
|
||||
if parent_checkpoint_id
|
||||
else None
|
||||
),
|
||||
pending_writes=[
|
||||
(id, c, self.serde.loads_typed(v)) for id, c, v, _ in writes
|
||||
],
|
||||
)
|
||||
|
||||
def put(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
checkpoint: Checkpoint,
|
||||
metadata: CheckpointMetadata,
|
||||
new_versions: ChannelVersions,
|
||||
) -> RunnableConfig:
|
||||
"""Save a checkpoint to the in-memory storage.
|
||||
|
||||
This method saves a checkpoint to the in-memory storage. The checkpoint is associated
|
||||
with the provided config.
|
||||
|
||||
Args:
|
||||
config: The config to associate with the checkpoint.
|
||||
checkpoint: The checkpoint to save.
|
||||
metadata: Additional metadata to save with the checkpoint.
|
||||
new_versions: New versions as of this write
|
||||
|
||||
Returns:
|
||||
RunnableConfig: The updated config containing the saved checkpoint's timestamp.
|
||||
"""
|
||||
c = checkpoint.copy()
|
||||
thread_id = config["configurable"]["thread_id"]
|
||||
checkpoint_ns = config["configurable"]["checkpoint_ns"]
|
||||
values: dict[str, Any] = c.pop("channel_values") # type: ignore[misc]
|
||||
for k, v in new_versions.items():
|
||||
self.blobs[(thread_id, checkpoint_ns, k, v)] = (
|
||||
self.serde.dumps_typed(values[k]) if k in values else ("empty", b"")
|
||||
)
|
||||
self.storage[thread_id][checkpoint_ns].update(
|
||||
{
|
||||
checkpoint["id"]: (
|
||||
self.serde.dumps_typed(c),
|
||||
self.serde.dumps_typed(get_checkpoint_metadata(config, metadata)),
|
||||
config["configurable"].get("checkpoint_id"), # parent
|
||||
)
|
||||
}
|
||||
)
|
||||
return {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": checkpoint_ns,
|
||||
"checkpoint_id": checkpoint["id"],
|
||||
}
|
||||
}
|
||||
|
||||
def put_writes(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
writes: Sequence[tuple[str, Any]],
|
||||
task_id: str,
|
||||
task_path: str = "",
|
||||
) -> None:
|
||||
"""Save a list of writes to the in-memory storage.
|
||||
|
||||
This method saves a list of writes to the in-memory storage. The writes are associated
|
||||
with the provided config.
|
||||
|
||||
Args:
|
||||
config: The config to associate with the writes.
|
||||
writes: The writes to save.
|
||||
task_id: Identifier for the task creating the writes.
|
||||
task_path: Path of the task creating the writes.
|
||||
|
||||
Returns:
|
||||
RunnableConfig: The updated config containing the saved writes' timestamp.
|
||||
"""
|
||||
thread_id = config["configurable"]["thread_id"]
|
||||
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
|
||||
checkpoint_id = config["configurable"]["checkpoint_id"]
|
||||
outer_key = (thread_id, checkpoint_ns, checkpoint_id)
|
||||
outer_writes_ = self.writes.get(outer_key)
|
||||
for idx, (c, v) in enumerate(writes):
|
||||
inner_key = (task_id, WRITES_IDX_MAP.get(c, idx))
|
||||
if inner_key[1] >= 0 and outer_writes_ and inner_key in outer_writes_:
|
||||
continue
|
||||
|
||||
self.writes[outer_key][inner_key] = (
|
||||
task_id,
|
||||
c,
|
||||
self.serde.dumps_typed(v),
|
||||
task_path,
|
||||
)
|
||||
|
||||
def delete_thread(self, thread_id: str) -> None:
|
||||
"""Delete all checkpoints and writes associated with a thread ID.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID to delete.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if thread_id in self.storage:
|
||||
del self.storage[thread_id]
|
||||
for k in list(self.writes.keys()):
|
||||
if k[0] == thread_id:
|
||||
del self.writes[k]
|
||||
for k in list(self.blobs.keys()):
|
||||
if k[0] == thread_id:
|
||||
del self.blobs[k]
|
||||
|
||||
async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
|
||||
"""Asynchronous version of `get_tuple`.
|
||||
|
||||
This method is an asynchronous wrapper around `get_tuple` that runs the synchronous
|
||||
method in a separate thread using asyncio.
|
||||
|
||||
Args:
|
||||
config: The config to use for retrieving the checkpoint.
|
||||
|
||||
Returns:
|
||||
The retrieved checkpoint tuple, or None if no matching checkpoint was found.
|
||||
"""
|
||||
return self.get_tuple(config)
|
||||
|
||||
async def alist(
|
||||
self,
|
||||
config: RunnableConfig | None,
|
||||
*,
|
||||
filter: dict[str, Any] | None = None,
|
||||
before: RunnableConfig | None = None,
|
||||
limit: int | None = None,
|
||||
) -> AsyncIterator[CheckpointTuple]:
|
||||
"""Asynchronous version of `list`.
|
||||
|
||||
This method is an asynchronous wrapper around `list` that runs the synchronous
|
||||
method in a separate thread using asyncio.
|
||||
|
||||
Args:
|
||||
config: The config to use for listing the checkpoints.
|
||||
|
||||
Yields:
|
||||
An asynchronous iterator of checkpoint tuples.
|
||||
"""
|
||||
for item in self.list(config, filter=filter, before=before, limit=limit):
|
||||
yield item
|
||||
|
||||
async def aput(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
checkpoint: Checkpoint,
|
||||
metadata: CheckpointMetadata,
|
||||
new_versions: ChannelVersions,
|
||||
) -> RunnableConfig:
|
||||
"""Asynchronous version of `put`.
|
||||
|
||||
Args:
|
||||
config: The config to associate with the checkpoint.
|
||||
checkpoint: The checkpoint to save.
|
||||
metadata: Additional metadata to save with the checkpoint.
|
||||
new_versions: New versions as of this write
|
||||
|
||||
Returns:
|
||||
RunnableConfig: The updated config containing the saved checkpoint's timestamp.
|
||||
"""
|
||||
return self.put(config, checkpoint, metadata, new_versions)
|
||||
|
||||
async def aput_writes(
|
||||
self,
|
||||
config: RunnableConfig,
|
||||
writes: Sequence[tuple[str, Any]],
|
||||
task_id: str,
|
||||
task_path: str = "",
|
||||
) -> None:
|
||||
"""Asynchronous version of `put_writes`.
|
||||
|
||||
This method is an asynchronous wrapper around `put_writes` that runs the synchronous
|
||||
method in a separate thread using asyncio.
|
||||
|
||||
Args:
|
||||
config: The config to associate with the writes.
|
||||
writes: The writes to save, each as a (channel, value) pair.
|
||||
task_id: Identifier for the task creating the writes.
|
||||
task_path: Path of the task creating the writes.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
return self.put_writes(config, writes, task_id, task_path)
|
||||
|
||||
async def adelete_thread(self, thread_id: str) -> None:
|
||||
"""Delete all checkpoints and writes associated with a thread ID.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID to delete.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
return self.delete_thread(thread_id)
|
||||
|
||||
def get_next_version(self, current: str | None, channel: None) -> str:
|
||||
if current is None:
|
||||
current_v = 0
|
||||
elif isinstance(current, int):
|
||||
current_v = current
|
||||
else:
|
||||
current_v = int(current.split(".")[0])
|
||||
next_v = current_v + 1
|
||||
next_h = random.random()
|
||||
return f"{next_v:032}.{next_h:016}"
|
||||
|
||||
|
||||
MemorySaver = InMemorySaver # Kept for backwards compatibility
|
||||
|
||||
|
||||
class PersistentDict(defaultdict):
|
||||
"""Persistent dictionary with an API compatible with shelve and anydbm.
|
||||
|
||||
The dict is kept in memory, so the dictionary operations run as fast as
|
||||
a regular dictionary.
|
||||
|
||||
Write to disk is delayed until close or sync (similar to gdbm's fast mode).
|
||||
|
||||
Input file format is automatically discovered.
|
||||
Output file format is selectable between pickle, json, and csv.
|
||||
All three serialization formats are backed by fast C implementations.
|
||||
|
||||
Adapted from https://code.activestate.com/recipes/576642-persistent-dict-with-multiple-standard-file-format/
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *args: Any, filename: str, **kwds: Any) -> None:
|
||||
self.flag = "c" # r=readonly, c=create, or n=new
|
||||
self.mode = None # None or an octal triple like 0644
|
||||
self.format = "pickle" # 'csv', 'json', or 'pickle'
|
||||
self.filename = filename
|
||||
super().__init__(*args, **kwds)
|
||||
|
||||
def sync(self) -> None:
|
||||
"Write dict to disk"
|
||||
if self.flag == "r":
|
||||
return
|
||||
tempname = self.filename + ".tmp"
|
||||
fileobj = open(tempname, "wb" if self.format == "pickle" else "w")
|
||||
try:
|
||||
self.dump(fileobj)
|
||||
except Exception:
|
||||
os.remove(tempname)
|
||||
raise
|
||||
finally:
|
||||
fileobj.close()
|
||||
shutil.move(tempname, self.filename) # atomic commit
|
||||
if self.mode is not None:
|
||||
os.chmod(self.filename, self.mode)
|
||||
|
||||
def close(self) -> None:
|
||||
self.sync()
|
||||
self.clear()
|
||||
|
||||
def __enter__(self) -> PersistentDict:
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_info: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def dump(self, fileobj: Any) -> None:
|
||||
if self.format == "pickle":
|
||||
pickle.dump(dict(self), fileobj, 2)
|
||||
else:
|
||||
raise NotImplementedError("Unknown format: " + repr(self.format))
|
||||
|
||||
def load(self) -> None:
|
||||
# try formats from most restrictive to least restrictive
|
||||
if self.flag == "n":
|
||||
return
|
||||
with open(self.filename, "rb" if self.format == "pickle" else "r") as fileobj:
|
||||
for loader in (pickle.load,):
|
||||
fileobj.seek(0)
|
||||
try:
|
||||
return self.update(loader(fileobj))
|
||||
except EOFError:
|
||||
return
|
||||
except Exception:
|
||||
logger.error(f"Failed to load file: {fileobj.name}")
|
||||
raise
|
||||
raise ValueError("File not in a supported format")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,89 @@
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from typing import cast
|
||||
|
||||
STRICT_MSGPACK_ENABLED = os.getenv("LANGGRAPH_STRICT_MSGPACK", "false").lower() in (
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
)
|
||||
|
||||
|
||||
_SENTINEL = cast(None, object())
|
||||
|
||||
SAFE_MSGPACK_TYPES: frozenset[tuple[str, ...]] = frozenset(
|
||||
{
|
||||
# datetime types
|
||||
("datetime", "datetime"),
|
||||
("datetime", "date"),
|
||||
("datetime", "time"),
|
||||
("datetime", "timedelta"),
|
||||
("datetime", "timezone"),
|
||||
# uuid
|
||||
("uuid", "UUID"),
|
||||
# numeric
|
||||
("decimal", "Decimal"),
|
||||
# collections
|
||||
("builtins", "set"),
|
||||
("builtins", "frozenset"),
|
||||
("collections", "deque"),
|
||||
# ip addresses
|
||||
("ipaddress", "IPv4Address"),
|
||||
("ipaddress", "IPv4Interface"),
|
||||
("ipaddress", "IPv4Network"),
|
||||
("ipaddress", "IPv6Address"),
|
||||
("ipaddress", "IPv6Interface"),
|
||||
("ipaddress", "IPv6Network"),
|
||||
# pathlib
|
||||
("pathlib", "Path"),
|
||||
("pathlib", "PosixPath"),
|
||||
("pathlib", "WindowsPath"),
|
||||
# pathlib in Python 3.13+
|
||||
("pathlib._local", "Path"),
|
||||
("pathlib._local", "PosixPath"),
|
||||
("pathlib._local", "WindowsPath"),
|
||||
# zoneinfo
|
||||
("zoneinfo", "ZoneInfo"),
|
||||
# regex
|
||||
("re", "compile"),
|
||||
# langchain-core messages (safe container types used by graph state)
|
||||
("langchain_core.messages.base", "BaseMessage"),
|
||||
("langchain_core.messages.base", "BaseMessageChunk"),
|
||||
("langchain_core.messages.human", "HumanMessage"),
|
||||
("langchain_core.messages.human", "HumanMessageChunk"),
|
||||
("langchain_core.messages.ai", "AIMessage"),
|
||||
("langchain_core.messages.ai", "AIMessageChunk"),
|
||||
("langchain_core.messages.system", "SystemMessage"),
|
||||
("langchain_core.messages.system", "SystemMessageChunk"),
|
||||
("langchain_core.messages.chat", "ChatMessage"),
|
||||
("langchain_core.messages.chat", "ChatMessageChunk"),
|
||||
("langchain_core.messages.tool", "ToolMessage"),
|
||||
("langchain_core.messages.tool", "ToolMessageChunk"),
|
||||
("langchain_core.messages.function", "FunctionMessage"),
|
||||
("langchain_core.messages.function", "FunctionMessageChunk"),
|
||||
("langchain_core.messages.modifier", "RemoveMessage"),
|
||||
# langchain-core document model
|
||||
("langchain_core.documents.base", "Document"),
|
||||
# langgraph
|
||||
("langgraph.types", "Send"),
|
||||
("langgraph.types", "Interrupt"),
|
||||
("langgraph.types", "Command"),
|
||||
("langgraph.types", "StateSnapshot"),
|
||||
("langgraph.types", "PregelTask"),
|
||||
("langgraph.types", "Overwrite"),
|
||||
("langgraph.store.base", "Item"),
|
||||
("langgraph.store.base", "GetOp"),
|
||||
}
|
||||
)
|
||||
|
||||
# Allowed (module, name, method) triples for EXT_METHOD_SINGLE_ARG.
|
||||
# Only these specific method invocations are permitted during deserialization.
|
||||
# This is separate from SAFE_MSGPACK_TYPES which only governs construction.
|
||||
SAFE_MSGPACK_METHODS: frozenset[tuple[str, str, str]] = frozenset(
|
||||
{
|
||||
("datetime", "datetime", "fromisoformat"),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
AllowedMsgpackModules = Iterable[tuple[str, ...] | type]
|
||||
64
venv/Lib/site-packages/langgraph/checkpoint/serde/base.py
Normal file
64
venv/Lib/site-packages/langgraph/checkpoint/serde/base.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
class UntypedSerializerProtocol(Protocol):
|
||||
"""Protocol for serialization and deserialization of objects."""
|
||||
|
||||
def dumps(self, obj: Any) -> bytes: ...
|
||||
|
||||
def loads(self, data: bytes) -> Any: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SerializerProtocol(Protocol):
|
||||
"""Protocol for serialization and deserialization of objects.
|
||||
|
||||
- `dumps_typed`: Serialize an object to a tuple `(type, bytes)`.
|
||||
- `loads_typed`: Deserialize an object from a tuple `(type, bytes)`.
|
||||
|
||||
Valid implementations include the `pickle`, `json` and `orjson` modules.
|
||||
"""
|
||||
|
||||
def dumps_typed(self, obj: Any) -> tuple[str, bytes]: ...
|
||||
|
||||
def loads_typed(self, data: tuple[str, bytes]) -> Any: ...
|
||||
|
||||
|
||||
class SerializerCompat(SerializerProtocol):
|
||||
def __init__(self, serde: UntypedSerializerProtocol) -> None:
|
||||
self.serde = serde
|
||||
|
||||
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
|
||||
return type(obj).__name__, self.serde.dumps(obj)
|
||||
|
||||
def loads_typed(self, data: tuple[str, bytes]) -> Any:
|
||||
return self.serde.loads(data[1])
|
||||
|
||||
|
||||
def maybe_add_typed_methods(
|
||||
serde: SerializerProtocol | UntypedSerializerProtocol,
|
||||
) -> SerializerProtocol:
|
||||
"""Wrap serde old serde implementations in a class with loads_typed and dumps_typed for backwards compatibility."""
|
||||
|
||||
if not isinstance(serde, SerializerProtocol):
|
||||
return SerializerCompat(serde)
|
||||
|
||||
return serde
|
||||
|
||||
|
||||
class CipherProtocol(Protocol):
|
||||
"""Protocol for encryption and decryption of data.
|
||||
|
||||
- `encrypt`: Encrypt plaintext.
|
||||
- `decrypt`: Decrypt ciphertext.
|
||||
"""
|
||||
|
||||
def encrypt(self, plaintext: bytes) -> tuple[str, bytes]:
|
||||
"""Encrypt plaintext. Returns a tuple `(cipher name, ciphertext)`."""
|
||||
...
|
||||
|
||||
def decrypt(self, ciphername: str, ciphertext: bytes) -> bytes:
|
||||
"""Decrypt ciphertext. Returns the plaintext."""
|
||||
...
|
||||
@@ -0,0 +1,80 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from langgraph.checkpoint.serde.base import CipherProtocol, SerializerProtocol
|
||||
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
||||
|
||||
|
||||
class EncryptedSerializer(SerializerProtocol):
|
||||
"""Serializer that encrypts and decrypts data using an encryption protocol."""
|
||||
|
||||
def __init__(
|
||||
self, cipher: CipherProtocol, serde: SerializerProtocol = JsonPlusSerializer()
|
||||
) -> None:
|
||||
self.cipher = cipher
|
||||
self.serde = serde
|
||||
|
||||
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
|
||||
"""Serialize an object to a tuple `(type, bytes)` and encrypt the bytes."""
|
||||
# serialize data
|
||||
typ, data = self.serde.dumps_typed(obj)
|
||||
# encrypt data
|
||||
ciphername, ciphertext = self.cipher.encrypt(data)
|
||||
# add cipher name to type
|
||||
return f"{typ}+{ciphername}", ciphertext
|
||||
|
||||
def loads_typed(self, data: tuple[str, bytes]) -> Any:
|
||||
enc_cipher, ciphertext = data
|
||||
# unencrypted data
|
||||
if "+" not in enc_cipher:
|
||||
return self.serde.loads_typed(data)
|
||||
# extract cipher name
|
||||
typ, ciphername = enc_cipher.split("+", 1)
|
||||
# decrypt data
|
||||
decrypted_data = self.cipher.decrypt(ciphername, ciphertext)
|
||||
# deserialize data
|
||||
return self.serde.loads_typed((typ, decrypted_data))
|
||||
|
||||
@classmethod
|
||||
def from_pycryptodome_aes(
|
||||
cls, serde: SerializerProtocol = JsonPlusSerializer(), **kwargs: Any
|
||||
) -> "EncryptedSerializer":
|
||||
"""Create an `EncryptedSerializer` using AES encryption."""
|
||||
try:
|
||||
from Crypto.Cipher import AES
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Pycryptodome is not installed. Please install it with `pip install pycryptodome`."
|
||||
) from None
|
||||
|
||||
# check if AES key is provided
|
||||
if "key" in kwargs:
|
||||
key: bytes = kwargs.pop("key")
|
||||
else:
|
||||
key_str = os.getenv("LANGGRAPH_AES_KEY")
|
||||
if key_str is None:
|
||||
raise ValueError("LANGGRAPH_AES_KEY environment variable is not set.")
|
||||
key = key_str.encode()
|
||||
if len(key) not in (16, 24, 32):
|
||||
raise ValueError("LANGGRAPH_AES_KEY must be 16, 24, or 32 bytes long.")
|
||||
|
||||
# set default mode to EAX if not provided
|
||||
if kwargs.get("mode") is None:
|
||||
kwargs["mode"] = AES.MODE_EAX
|
||||
|
||||
class PycryptodomeAesCipher(CipherProtocol):
|
||||
def encrypt(self, plaintext: bytes) -> tuple[str, bytes]:
|
||||
cipher = AES.new(key, **kwargs)
|
||||
ciphertext, tag = cipher.encrypt_and_digest(plaintext)
|
||||
return "aes", cipher.nonce + tag + ciphertext
|
||||
|
||||
def decrypt(self, ciphername: str, ciphertext: bytes) -> bytes:
|
||||
assert ciphername == "aes", f"Unsupported cipher: {ciphername}"
|
||||
nonce = ciphertext[:16]
|
||||
tag = ciphertext[16:32]
|
||||
actual_ciphertext = ciphertext[32:]
|
||||
|
||||
cipher = AES.new(key, **kwargs, nonce=nonce)
|
||||
return cipher.decrypt_and_verify(actual_ciphertext, tag)
|
||||
|
||||
return cls(PycryptodomeAesCipher(), serde)
|
||||
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from threading import Lock
|
||||
from typing import TypedDict
|
||||
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SerdeEvent(TypedDict):
|
||||
kind: str
|
||||
module: str
|
||||
name: str
|
||||
method: NotRequired[str]
|
||||
|
||||
|
||||
SerdeEventListener = Callable[[SerdeEvent], None]
|
||||
|
||||
_listeners: list[SerdeEventListener] = []
|
||||
_listeners_lock = Lock()
|
||||
|
||||
|
||||
def register_serde_event_listener(listener: SerdeEventListener) -> Callable[[], None]:
|
||||
"""Register a listener for serde allowlist events."""
|
||||
with _listeners_lock:
|
||||
_listeners.append(listener)
|
||||
|
||||
def unregister() -> None:
|
||||
with _listeners_lock:
|
||||
try:
|
||||
_listeners.remove(listener)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return unregister
|
||||
|
||||
|
||||
def emit_serde_event(event: SerdeEvent) -> None:
|
||||
"""Emit a serde event to all listeners.
|
||||
|
||||
Listener failures are isolated and logged.
|
||||
"""
|
||||
with _listeners_lock:
|
||||
listeners = tuple(_listeners)
|
||||
for listener in listeners:
|
||||
try:
|
||||
listener(event)
|
||||
except Exception:
|
||||
logger.warning("Serde listener failed", exc_info=True)
|
||||
827
venv/Lib/site-packages/langgraph/checkpoint/serde/jsonplus.py
Normal file
827
venv/Lib/site-packages/langgraph/checkpoint/serde/jsonplus.py
Normal file
@@ -0,0 +1,827 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import decimal
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import pathlib
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
from enum import Enum
|
||||
from inspect import isclass
|
||||
from ipaddress import (
|
||||
IPv4Address,
|
||||
IPv4Interface,
|
||||
IPv4Network,
|
||||
IPv6Address,
|
||||
IPv6Interface,
|
||||
IPv6Network,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from uuid import UUID
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import ormsgpack
|
||||
from langchain_core.load.load import Reviver
|
||||
|
||||
from langgraph.checkpoint.serde import _msgpack as _lg_msgpack
|
||||
from langgraph.checkpoint.serde.base import SerializerProtocol
|
||||
from langgraph.checkpoint.serde.event_hooks import emit_serde_event
|
||||
from langgraph.checkpoint.serde.types import SendProtocol
|
||||
from langgraph.store.base import Item
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.checkpoint.serde._msgpack import (
|
||||
AllowedMsgpackModules,
|
||||
)
|
||||
from langgraph.checkpoint.serde.types import SendProtocol
|
||||
|
||||
LC_REVIVER = Reviver()
|
||||
EMPTY_BYTES = b""
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JsonPlusSerializer(SerializerProtocol):
|
||||
"""Serializer that uses ormsgpack, with optional fallbacks.
|
||||
|
||||
!!! warning
|
||||
|
||||
Security note: This serializer is intended for use within the `BaseCheckpointSaver`
|
||||
class and called within the Pregel loop. It should not be used on untrusted
|
||||
python objects. If an attacker can write directly to your checkpoint database,
|
||||
they may be able to trigger code execution when data is deserialized.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pickle_fallback: bool = False,
|
||||
allowed_json_modules: Iterable[tuple[str, ...]] | Literal[True] | None = None,
|
||||
allowed_msgpack_modules: (
|
||||
AllowedMsgpackModules | Literal[True] | None
|
||||
) = _lg_msgpack._SENTINEL,
|
||||
__unpack_ext_hook__: Callable[[int, bytes], Any] | None = None,
|
||||
) -> None:
|
||||
if allowed_msgpack_modules is _lg_msgpack._SENTINEL:
|
||||
if _lg_msgpack.STRICT_MSGPACK_ENABLED:
|
||||
allowed_msgpack_modules = None
|
||||
else:
|
||||
allowed_msgpack_modules = True
|
||||
self.pickle_fallback = pickle_fallback
|
||||
self._allowed_json_modules: set[tuple[str, ...]] | Literal[True] | None = (
|
||||
_normalize_allowlist(allowed_json_modules)
|
||||
)
|
||||
self._allowed_msgpack_modules = _normalize_allowlist(allowed_msgpack_modules)
|
||||
|
||||
self._custom_unpack_ext_hook = __unpack_ext_hook__ is not None
|
||||
self._unpack_ext_hook = (
|
||||
__unpack_ext_hook__
|
||||
if __unpack_ext_hook__ is not None
|
||||
else _create_msgpack_ext_hook(self._allowed_msgpack_modules)
|
||||
)
|
||||
|
||||
def with_msgpack_allowlist(
|
||||
self, extra_allowlist: Iterable[tuple[str, ...] | type]
|
||||
) -> JsonPlusSerializer:
|
||||
"""Return a new serializer with a merged msgpack allowlist."""
|
||||
base_allowlist = self._allowed_msgpack_modules
|
||||
if base_allowlist is True or base_allowlist is False:
|
||||
return self
|
||||
elif base_allowlist:
|
||||
base_allowlist = set(base_allowlist)
|
||||
else:
|
||||
base_allowlist = set()
|
||||
extra = _normalize_module_keys(tuple(extra_allowlist))
|
||||
merged = base_allowlist | extra
|
||||
if merged == base_allowlist:
|
||||
return self
|
||||
allowed_msgpack_modules: AllowedMsgpackModules | Literal[True] | None
|
||||
if merged:
|
||||
allowed_msgpack_modules = tuple(merged)
|
||||
elif isinstance(self._allowed_msgpack_modules, set):
|
||||
allowed_msgpack_modules = tuple(self._allowed_msgpack_modules)
|
||||
else:
|
||||
allowed_msgpack_modules = self._allowed_msgpack_modules
|
||||
|
||||
clone = copy.copy(self)
|
||||
clone._allowed_json_modules = _normalize_allowlist(self._allowed_json_modules)
|
||||
clone._allowed_msgpack_modules = _normalize_allowlist(allowed_msgpack_modules)
|
||||
if not clone._custom_unpack_ext_hook:
|
||||
clone._unpack_ext_hook = _create_msgpack_ext_hook(
|
||||
clone._allowed_msgpack_modules
|
||||
)
|
||||
return clone
|
||||
|
||||
def _encode_constructor_args(
|
||||
self,
|
||||
constructor: Callable | type[Any],
|
||||
*,
|
||||
method: None | str | Sequence[None | str] = None,
|
||||
args: Sequence[Any] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
out = {
|
||||
"lc": 2,
|
||||
"type": "constructor",
|
||||
"id": (*constructor.__module__.split("."), constructor.__name__),
|
||||
}
|
||||
if method is not None:
|
||||
out["method"] = method
|
||||
if args is not None:
|
||||
out["args"] = args
|
||||
if kwargs is not None:
|
||||
out["kwargs"] = kwargs
|
||||
return out
|
||||
|
||||
def _reviver(self, value: dict[str, Any]) -> Any:
|
||||
if self._allowed_json_modules and (
|
||||
value.get("lc", None) == 2
|
||||
and value.get("type", None) == "constructor"
|
||||
and value.get("id", None) is not None
|
||||
):
|
||||
try:
|
||||
return self._revive_lc2(value)
|
||||
except InvalidModuleError as e:
|
||||
logger.warning(
|
||||
"Object %s is not in the deserialization allowlist.\n%s",
|
||||
value["id"],
|
||||
e.message,
|
||||
)
|
||||
|
||||
return LC_REVIVER(value)
|
||||
|
||||
def _revive_lc2(self, value: dict[str, Any]) -> Any:
|
||||
self._check_allowed_json_modules(value)
|
||||
|
||||
[*module, name] = value["id"]
|
||||
try:
|
||||
mod = importlib.import_module(".".join(module))
|
||||
cls = getattr(mod, name)
|
||||
method = value.get("method")
|
||||
if isinstance(method, str):
|
||||
methods = [getattr(cls, method)]
|
||||
elif isinstance(method, list):
|
||||
methods = [cls if m is None else getattr(cls, m) for m in method]
|
||||
else:
|
||||
methods = [cls]
|
||||
args = value.get("args")
|
||||
kwargs = value.get("kwargs")
|
||||
for method in methods:
|
||||
try:
|
||||
if isclass(method) and issubclass(method, BaseException):
|
||||
return None
|
||||
if args and kwargs:
|
||||
return method(*args, **kwargs)
|
||||
elif args:
|
||||
return method(*args)
|
||||
elif kwargs:
|
||||
return method(**kwargs)
|
||||
else:
|
||||
return method()
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _check_allowed_json_modules(self, value: dict[str, Any]) -> None:
|
||||
needed = tuple(value["id"])
|
||||
method = value.get("method")
|
||||
if isinstance(method, list):
|
||||
method_display = ",".join(m or "<init>" for m in method)
|
||||
elif isinstance(method, str):
|
||||
method_display = method
|
||||
else:
|
||||
method_display = "<init>"
|
||||
|
||||
dotted = ".".join(needed)
|
||||
if not self._allowed_json_modules:
|
||||
raise InvalidModuleError(
|
||||
f"Refused to deserialize JSON constructor: {dotted} (method: {method_display}). "
|
||||
"No allowed_json_modules configured.\n\n"
|
||||
"Unblock with ONE of:\n"
|
||||
f" • JsonPlusSerializer(allowed_json_modules=[{needed!r}, ...])\n"
|
||||
" • (DANGEROUS) JsonPlusSerializer(allowed_json_modules=True)\n\n"
|
||||
"Note: Prefix allowlists are intentionally unsupported; prefer exact symbols "
|
||||
"or plain-JSON representations revived without import-time side effects."
|
||||
)
|
||||
|
||||
if self._allowed_json_modules is True:
|
||||
return
|
||||
if needed in self._allowed_json_modules:
|
||||
return
|
||||
|
||||
raise InvalidModuleError(
|
||||
f"Refused to deserialize JSON constructor: {dotted} (method: {method_display}). "
|
||||
"Symbol is not in the deserialization allowlist.\n\n"
|
||||
"Add exactly this symbol to unblock:\n"
|
||||
f" JsonPlusSerializer(allowed_json_modules=[{needed!r}, ...])\n"
|
||||
"Or, as a last resort (DANGEROUS):\n"
|
||||
" JsonPlusSerializer(allowed_json_modules=True)"
|
||||
)
|
||||
|
||||
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
|
||||
if obj is None:
|
||||
return "null", EMPTY_BYTES
|
||||
elif isinstance(obj, bytes):
|
||||
return "bytes", obj
|
||||
elif isinstance(obj, bytearray):
|
||||
return "bytearray", obj
|
||||
else:
|
||||
try:
|
||||
return "msgpack", _msgpack_enc(obj)
|
||||
except ormsgpack.MsgpackEncodeError as exc:
|
||||
if self.pickle_fallback:
|
||||
return "pickle", pickle.dumps(obj)
|
||||
raise exc
|
||||
|
||||
def loads_typed(self, data: tuple[str, bytes]) -> Any:
|
||||
type_, data_ = data
|
||||
if type_ == "null":
|
||||
return None
|
||||
elif type_ == "bytes":
|
||||
return data_
|
||||
elif type_ == "bytearray":
|
||||
return bytearray(data_)
|
||||
elif type_ == "json":
|
||||
return json.loads(data_, object_hook=self._reviver)
|
||||
elif type_ == "msgpack":
|
||||
return ormsgpack.unpackb(
|
||||
data_, ext_hook=self._unpack_ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
elif self.pickle_fallback and type_ == "pickle":
|
||||
return pickle.loads(data_)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown serialization type: {type_}")
|
||||
|
||||
|
||||
# --- msgpack ---
|
||||
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG = 0
|
||||
EXT_CONSTRUCTOR_POS_ARGS = 1
|
||||
EXT_CONSTRUCTOR_KW_ARGS = 2
|
||||
EXT_METHOD_SINGLE_ARG = 3
|
||||
EXT_PYDANTIC_V1 = 4
|
||||
EXT_PYDANTIC_V2 = 5
|
||||
EXT_NUMPY_ARRAY = 6
|
||||
|
||||
|
||||
def _msgpack_default(obj: Any) -> str | ormsgpack.Ext:
|
||||
if hasattr(obj, "model_dump") and callable(obj.model_dump): # pydantic v2
|
||||
return ormsgpack.Ext(
|
||||
EXT_PYDANTIC_V2,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
obj.model_dump(),
|
||||
"model_validate_json",
|
||||
),
|
||||
),
|
||||
)
|
||||
elif hasattr(obj, "get_secret_value") and callable(obj.get_secret_value):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
obj.get_secret_value(),
|
||||
),
|
||||
),
|
||||
)
|
||||
elif hasattr(obj, "dict") and callable(obj.dict): # pydantic v1
|
||||
return ormsgpack.Ext(
|
||||
EXT_PYDANTIC_V1,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
obj.dict(),
|
||||
),
|
||||
),
|
||||
)
|
||||
elif hasattr(obj, "_asdict") and callable(obj._asdict): # namedtuple
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_KW_ARGS,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
obj._asdict(),
|
||||
),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, pathlib.Path):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_POS_ARGS,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, obj.parts),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, re.Pattern):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_POS_ARGS,
|
||||
_msgpack_enc(
|
||||
("re", "compile", (obj.pattern, obj.flags)),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, UUID):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, obj.hex),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, decimal.Decimal):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, (set, frozenset, deque)):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, tuple(obj)),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, (IPv4Address, IPv4Interface, IPv4Network)):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, (IPv6Address, IPv6Interface, IPv6Network)):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, str(obj)),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, datetime):
|
||||
return ormsgpack.Ext(
|
||||
EXT_METHOD_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
obj.isoformat(),
|
||||
"fromisoformat",
|
||||
),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, timedelta):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_POS_ARGS,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
(obj.days, obj.seconds, obj.microseconds),
|
||||
),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, date):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_POS_ARGS,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
(obj.year, obj.month, obj.day),
|
||||
),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, time):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_KW_ARGS,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
{
|
||||
"hour": obj.hour,
|
||||
"minute": obj.minute,
|
||||
"second": obj.second,
|
||||
"microsecond": obj.microsecond,
|
||||
"tzinfo": obj.tzinfo,
|
||||
"fold": obj.fold,
|
||||
},
|
||||
),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, timezone):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_POS_ARGS,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
obj.__getinitargs__(), # type: ignore[attr-defined]
|
||||
),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, ZoneInfo):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, obj.key),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, Enum):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_SINGLE_ARG,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, obj.value),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, SendProtocol):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_POS_ARGS,
|
||||
_msgpack_enc(
|
||||
(obj.__class__.__module__, obj.__class__.__name__, (obj.node, obj.arg)),
|
||||
),
|
||||
)
|
||||
elif dataclasses.is_dataclass(obj):
|
||||
# doesn't use dataclasses.asdict to avoid deepcopy and recursion
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_KW_ARGS,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
{
|
||||
field.name: getattr(obj, field.name)
|
||||
for field in dataclasses.fields(obj)
|
||||
},
|
||||
),
|
||||
),
|
||||
)
|
||||
elif isinstance(obj, Item):
|
||||
return ormsgpack.Ext(
|
||||
EXT_CONSTRUCTOR_KW_ARGS,
|
||||
_msgpack_enc(
|
||||
(
|
||||
obj.__class__.__module__,
|
||||
obj.__class__.__name__,
|
||||
{k: getattr(obj, k) for k in obj.__slots__},
|
||||
),
|
||||
),
|
||||
)
|
||||
elif (np_mod := sys.modules.get("numpy")) is not None and isinstance(
|
||||
obj, np_mod.ndarray
|
||||
):
|
||||
order = "F" if obj.flags.f_contiguous and not obj.flags.c_contiguous else "C"
|
||||
if obj.flags.c_contiguous:
|
||||
mv = memoryview(obj)
|
||||
try:
|
||||
meta = (obj.dtype.str, obj.shape, order, mv)
|
||||
return ormsgpack.Ext(EXT_NUMPY_ARRAY, _msgpack_enc(meta))
|
||||
finally:
|
||||
mv.release()
|
||||
else:
|
||||
buf = obj.tobytes(order="A")
|
||||
meta = (obj.dtype.str, obj.shape, order, buf)
|
||||
return ormsgpack.Ext(EXT_NUMPY_ARRAY, _msgpack_enc(meta))
|
||||
|
||||
elif isinstance(obj, BaseException):
|
||||
return repr(obj)
|
||||
else:
|
||||
raise TypeError(f"Object of type {obj.__class__.__name__} is not serializable")
|
||||
|
||||
|
||||
def _create_msgpack_ext_hook(
|
||||
allowed_modules: set[tuple[str, ...]] | Literal[True] | None,
|
||||
) -> Callable[[int, bytes], Any]:
|
||||
"""Create msgpack ext hook with allowlist.
|
||||
|
||||
Args:
|
||||
allowed_modules: Set of (module, name) tuples that are allowed to be
|
||||
deserialized, or True to allow all with warnings for unregistered types, or None to only allow safe types.
|
||||
|
||||
Returns:
|
||||
An ext_hook function for use with ormsgpack.unpackb.
|
||||
"""
|
||||
|
||||
def _check_allowed(module: str, name: str) -> bool:
|
||||
"""Check if type is allowed. Returns True if allowed, False if blocked."""
|
||||
key = (module, name)
|
||||
|
||||
if key in _lg_msgpack.SAFE_MSGPACK_TYPES:
|
||||
return True
|
||||
|
||||
if allowed_modules is True:
|
||||
# default is to warn but allow unregistered types
|
||||
emit_serde_event(
|
||||
{
|
||||
"kind": "msgpack_unregistered_allowed",
|
||||
"module": module,
|
||||
"name": name,
|
||||
}
|
||||
)
|
||||
logger.warning(
|
||||
"Deserializing unregistered type %s.%s from checkpoint. "
|
||||
"This will be blocked in a future version. "
|
||||
"Add to allowed_msgpack_modules to silence: [(%r, %r)]",
|
||||
module,
|
||||
name,
|
||||
module,
|
||||
name,
|
||||
)
|
||||
return True
|
||||
if allowed_modules is not None:
|
||||
if key in allowed_modules:
|
||||
return True
|
||||
# strict mode blocks unregistered types
|
||||
emit_serde_event(
|
||||
{
|
||||
"kind": "msgpack_blocked",
|
||||
"module": module,
|
||||
"name": name,
|
||||
}
|
||||
)
|
||||
logger.warning(
|
||||
"Blocked deserialization of %s.%s - not in allowed_msgpack_modules. "
|
||||
"Add to allowed_msgpack_modules to allow: [(%r, %r)]",
|
||||
module,
|
||||
name,
|
||||
module,
|
||||
name,
|
||||
)
|
||||
return False
|
||||
|
||||
def _check_allowed_method(module: str, name: str, method: str) -> bool:
|
||||
"""Check if a method invocation is allowed."""
|
||||
key = (module, name, method)
|
||||
if key in _lg_msgpack.SAFE_MSGPACK_METHODS:
|
||||
return True
|
||||
emit_serde_event(
|
||||
{
|
||||
"kind": "msgpack_method_blocked",
|
||||
"module": module,
|
||||
"name": name,
|
||||
"method": method,
|
||||
}
|
||||
)
|
||||
logger.warning(
|
||||
"Blocked deserialization of method call %s.%s.%s - "
|
||||
"not in allowed methods set.",
|
||||
module,
|
||||
name,
|
||||
method,
|
||||
)
|
||||
return False
|
||||
|
||||
def ext_hook(code: int, data: bytes) -> Any:
|
||||
if code == EXT_CONSTRUCTOR_SINGLE_ARG:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
if not _check_allowed(tup[0], tup[1]):
|
||||
# We default to returning the raw data. If the user
|
||||
# is using this in the context of a pydantic state, etc., then
|
||||
# it would be validated upon construction.
|
||||
return tup[2]
|
||||
# module, name, arg
|
||||
return getattr(importlib.import_module(tup[0]), tup[1])(tup[2])
|
||||
except Exception:
|
||||
return None
|
||||
elif code == EXT_CONSTRUCTOR_POS_ARGS:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
if not _check_allowed(tup[0], tup[1]):
|
||||
return tup[2]
|
||||
# module, name, args
|
||||
return getattr(importlib.import_module(tup[0]), tup[1])(*tup[2])
|
||||
except Exception:
|
||||
return None
|
||||
elif code == EXT_CONSTRUCTOR_KW_ARGS:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
if not _check_allowed(tup[0], tup[1]):
|
||||
return tup[2]
|
||||
# module, name, kwargs
|
||||
return getattr(importlib.import_module(tup[0]), tup[1])(**tup[2])
|
||||
except Exception:
|
||||
return None
|
||||
elif code == EXT_METHOD_SINGLE_ARG:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
if not _check_allowed_method(tup[0], tup[1], tup[3]):
|
||||
return tup[2]
|
||||
# module, name, arg, method
|
||||
return getattr(
|
||||
getattr(importlib.import_module(tup[0]), tup[1]), tup[3]
|
||||
)(tup[2])
|
||||
except Exception:
|
||||
return None
|
||||
elif code == EXT_PYDANTIC_V1:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
if not _check_allowed(tup[0], tup[1]):
|
||||
return tup[2]
|
||||
# module, name, kwargs
|
||||
cls = getattr(importlib.import_module(tup[0]), tup[1])
|
||||
try:
|
||||
return cls(**tup[2])
|
||||
except Exception:
|
||||
return cls.construct(**tup[2])
|
||||
except Exception:
|
||||
# for pydantic objects we can't find/reconstruct
|
||||
# let's return the kwargs dict instead
|
||||
try:
|
||||
return tup[2]
|
||||
except NameError:
|
||||
return None
|
||||
elif code == EXT_PYDANTIC_V2:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
if not _check_allowed(tup[0], tup[1]):
|
||||
return tup[2]
|
||||
# module, name, kwargs, method
|
||||
cls = getattr(importlib.import_module(tup[0]), tup[1])
|
||||
try:
|
||||
return cls(**tup[2])
|
||||
except Exception:
|
||||
return cls.model_construct(**tup[2])
|
||||
except Exception:
|
||||
# for pydantic objects we can't find/reconstruct
|
||||
# let's return the kwargs dict instead
|
||||
try:
|
||||
return tup[2]
|
||||
except NameError:
|
||||
return None
|
||||
elif code == EXT_NUMPY_ARRAY:
|
||||
try:
|
||||
import numpy as _np
|
||||
|
||||
dtype_str, shape, order, buf = ormsgpack.unpackb(
|
||||
data, ext_hook=ext_hook, option=ormsgpack.OPT_NON_STR_KEYS
|
||||
)
|
||||
arr = _np.frombuffer(buf, dtype=_np.dtype(dtype_str))
|
||||
return arr.reshape(shape, order=order)
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
return ext_hook
|
||||
|
||||
|
||||
# Aliasing in case anyone imported it directly
|
||||
_msgpack_ext_hook = _create_msgpack_ext_hook(allowed_modules=None)
|
||||
|
||||
|
||||
def _msgpack_ext_hook_to_json(code: int, data: bytes) -> Any:
|
||||
if code == EXT_CONSTRUCTOR_SINGLE_ARG:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data,
|
||||
ext_hook=_msgpack_ext_hook_to_json,
|
||||
option=ormsgpack.OPT_NON_STR_KEYS,
|
||||
)
|
||||
if tup[0] == "uuid" and tup[1] == "UUID":
|
||||
hex_ = tup[2]
|
||||
return (
|
||||
f"{hex_[:8]}-{hex_[8:12]}-{hex_[12:16]}-{hex_[16:20]}-{hex_[20:]}"
|
||||
)
|
||||
# module, name, arg
|
||||
return tup[2]
|
||||
except Exception:
|
||||
return
|
||||
elif code == EXT_CONSTRUCTOR_POS_ARGS:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data,
|
||||
ext_hook=_msgpack_ext_hook_to_json,
|
||||
option=ormsgpack.OPT_NON_STR_KEYS,
|
||||
)
|
||||
if tup[0] == "langgraph.types" and tup[1] == "Send":
|
||||
from langgraph.types import Send # type: ignore
|
||||
|
||||
return Send(*tup[2])
|
||||
# module, name, args
|
||||
return tup[2]
|
||||
except Exception:
|
||||
return
|
||||
elif code == EXT_CONSTRUCTOR_KW_ARGS:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data,
|
||||
ext_hook=_msgpack_ext_hook_to_json,
|
||||
option=ormsgpack.OPT_NON_STR_KEYS,
|
||||
)
|
||||
# module, name, args
|
||||
return tup[2]
|
||||
except Exception:
|
||||
return
|
||||
elif code == EXT_METHOD_SINGLE_ARG:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data,
|
||||
ext_hook=_msgpack_ext_hook_to_json,
|
||||
option=ormsgpack.OPT_NON_STR_KEYS,
|
||||
)
|
||||
# module, name, arg, method
|
||||
return tup[2]
|
||||
except Exception:
|
||||
return
|
||||
elif code == EXT_PYDANTIC_V1:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data,
|
||||
ext_hook=_msgpack_ext_hook_to_json,
|
||||
option=ormsgpack.OPT_NON_STR_KEYS,
|
||||
)
|
||||
# module, name, kwargs
|
||||
return tup[2]
|
||||
except Exception:
|
||||
# for pydantic objects we can't find/reconstruct
|
||||
# let's return the kwargs dict instead
|
||||
return
|
||||
elif code == EXT_PYDANTIC_V2:
|
||||
try:
|
||||
tup = ormsgpack.unpackb(
|
||||
data,
|
||||
ext_hook=_msgpack_ext_hook_to_json,
|
||||
option=ormsgpack.OPT_NON_STR_KEYS,
|
||||
)
|
||||
# module, name, kwargs, method
|
||||
return tup[2]
|
||||
except Exception:
|
||||
return
|
||||
elif code == EXT_NUMPY_ARRAY:
|
||||
try:
|
||||
import numpy as _np
|
||||
|
||||
dtype_str, shape, order, buf = ormsgpack.unpackb(
|
||||
data,
|
||||
ext_hook=_msgpack_ext_hook_to_json,
|
||||
option=ormsgpack.OPT_NON_STR_KEYS,
|
||||
)
|
||||
arr = _np.frombuffer(buf, dtype=_np.dtype(dtype_str))
|
||||
return arr.reshape(shape, order=order).tolist()
|
||||
except Exception:
|
||||
return
|
||||
|
||||
|
||||
class InvalidModuleError(Exception):
|
||||
"""Exception raised when a module is not in the allowlist."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
|
||||
|
||||
_option = (
|
||||
ormsgpack.OPT_NON_STR_KEYS
|
||||
| ormsgpack.OPT_PASSTHROUGH_DATACLASS
|
||||
| ormsgpack.OPT_PASSTHROUGH_DATETIME
|
||||
| ormsgpack.OPT_PASSTHROUGH_ENUM
|
||||
| ormsgpack.OPT_PASSTHROUGH_UUID
|
||||
| ormsgpack.OPT_REPLACE_SURROGATES
|
||||
)
|
||||
|
||||
|
||||
def _msgpack_enc(data: Any) -> bytes:
|
||||
return ormsgpack.packb(data, default=_msgpack_default, option=_option)
|
||||
|
||||
|
||||
def _normalize_allowlist(
|
||||
allowlist: AllowedMsgpackModules | Literal[True] | None,
|
||||
) -> set[tuple[str, ...]] | Literal[True] | None:
|
||||
if allowlist is True:
|
||||
return allowlist
|
||||
elif allowlist:
|
||||
return _normalize_module_keys(allowlist)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_module_keys(
|
||||
modules: AllowedMsgpackModules,
|
||||
) -> set[tuple[str, ...]]:
|
||||
normalized: set[tuple[str, ...]] = set()
|
||||
for module in modules:
|
||||
if isclass(module):
|
||||
normalized.add((module.__module__, module.__name__))
|
||||
else:
|
||||
normalized.add(cast(tuple[str, ...], module))
|
||||
return normalized
|
||||
51
venv/Lib/site-packages/langgraph/checkpoint/serde/types.py
Normal file
51
venv/Lib/site-packages/langgraph/checkpoint/serde/types.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
ERROR = "__error__"
|
||||
SCHEDULED = "__scheduled__"
|
||||
INTERRUPT = "__interrupt__"
|
||||
RESUME = "__resume__"
|
||||
TASKS = "__pregel_tasks"
|
||||
|
||||
Value = TypeVar("Value", covariant=True)
|
||||
Update = TypeVar("Update", contravariant=True)
|
||||
C = TypeVar("C")
|
||||
|
||||
|
||||
class ChannelProtocol(Protocol[Value, Update, C]):
|
||||
# Mirrors langgraph.channels.base.BaseChannel
|
||||
@property
|
||||
def ValueType(self) -> Any: ...
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> Any: ...
|
||||
|
||||
def checkpoint(self) -> C | None: ...
|
||||
|
||||
def from_checkpoint(self, checkpoint: C | None) -> Self: ...
|
||||
|
||||
def update(self, values: Sequence[Update]) -> bool: ...
|
||||
|
||||
def get(self) -> Value: ...
|
||||
|
||||
def consume(self) -> bool: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SendProtocol(Protocol):
|
||||
# Mirrors langgraph.constants.Send
|
||||
node: str
|
||||
arg: Any
|
||||
|
||||
def __hash__(self) -> int: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
def __eq__(self, value: object) -> bool: ...
|
||||
196
venv/Lib/site-packages/langgraph/config.py
Normal file
196
venv/Lib/site-packages/langgraph/config.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.runnables.config import var_child_runnable_config
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from langgraph._internal._constants import CONF, CONFIG_KEY_RUNTIME
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
|
||||
def _no_op_stream_writer(c: Any) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def get_config() -> RunnableConfig:
|
||||
if sys.version_info < (3, 11):
|
||||
try:
|
||||
if asyncio.current_task():
|
||||
raise RuntimeError(
|
||||
"Python 3.11 or later required to use this in an async context"
|
||||
)
|
||||
except RuntimeError:
|
||||
pass
|
||||
if var_config := var_child_runnable_config.get():
|
||||
return var_config
|
||||
else:
|
||||
raise RuntimeError("Called get_config outside of a runnable context")
|
||||
|
||||
|
||||
def get_store() -> BaseStore:
|
||||
"""Access LangGraph store from inside a graph node or entrypoint task at runtime.
|
||||
|
||||
Can be called from inside any [`StateGraph`][langgraph.graph.StateGraph] node or
|
||||
functional API [`task`][langgraph.func.task], as long as the `StateGraph` or the [`entrypoint`][langgraph.func.entrypoint]
|
||||
was initialized with a store, e.g.:
|
||||
|
||||
```python
|
||||
# with StateGraph
|
||||
graph = (
|
||||
StateGraph(...)
|
||||
...
|
||||
.compile(store=store)
|
||||
)
|
||||
|
||||
# or with entrypoint
|
||||
@entrypoint(store=store)
|
||||
def workflow(inputs):
|
||||
...
|
||||
```
|
||||
|
||||
!!! warning "Async with Python < 3.11"
|
||||
|
||||
If you are using Python < 3.11 and are running LangGraph asynchronously,
|
||||
`get_store()` won't work since it uses [`contextvar`](https://docs.python.org/3/library/contextvars.html) propagation (only available in [Python >= 3.11](https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task)).
|
||||
|
||||
|
||||
Example: Using with `StateGraph`
|
||||
```python
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.graph import StateGraph, START
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
from langgraph.config import get_store
|
||||
|
||||
store = InMemoryStore()
|
||||
store.put(("values",), "foo", {"bar": 2})
|
||||
|
||||
|
||||
class State(TypedDict):
|
||||
foo: int
|
||||
|
||||
|
||||
def my_node(state: State):
|
||||
my_store = get_store()
|
||||
stored_value = my_store.get(("values",), "foo").value["bar"]
|
||||
return {"foo": stored_value + 1}
|
||||
|
||||
|
||||
graph = (
|
||||
StateGraph(State)
|
||||
.add_node(my_node)
|
||||
.add_edge(START, "my_node")
|
||||
.compile(store=store)
|
||||
)
|
||||
|
||||
graph.invoke({"foo": 1})
|
||||
```
|
||||
|
||||
```pycon
|
||||
{"foo": 3}
|
||||
```
|
||||
|
||||
Example: Using with functional API
|
||||
```python
|
||||
from langgraph.func import entrypoint, task
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
from langgraph.config import get_store
|
||||
|
||||
store = InMemoryStore()
|
||||
store.put(("values",), "foo", {"bar": 2})
|
||||
|
||||
|
||||
@task
|
||||
def my_task(value: int):
|
||||
my_store = get_store()
|
||||
stored_value = my_store.get(("values",), "foo").value["bar"]
|
||||
return stored_value + 1
|
||||
|
||||
|
||||
@entrypoint(store=store)
|
||||
def workflow(value: int):
|
||||
return my_task(value).result()
|
||||
|
||||
|
||||
workflow.invoke(1)
|
||||
```
|
||||
|
||||
```pycon
|
||||
3
|
||||
```
|
||||
"""
|
||||
return get_config()[CONF][CONFIG_KEY_RUNTIME].store
|
||||
|
||||
|
||||
def get_stream_writer() -> StreamWriter:
|
||||
"""Access LangGraph [`StreamWriter`][langgraph.types.StreamWriter] from inside a graph node or entrypoint task at runtime.
|
||||
|
||||
Can be called from inside any [`StateGraph`][langgraph.graph.StateGraph] node or
|
||||
functional API [`task`][langgraph.func.task].
|
||||
|
||||
!!! warning "Async with Python < 3.11"
|
||||
|
||||
If you are using Python < 3.11 and are running LangGraph asynchronously,
|
||||
`get_stream_writer()` won't work since it uses [`contextvar`](https://docs.python.org/3/library/contextvars.html) propagation (only available in [Python >= 3.11](https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task)).
|
||||
|
||||
Example: Using with `StateGraph`
|
||||
```python
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.graph import StateGraph, START
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
|
||||
class State(TypedDict):
|
||||
foo: int
|
||||
|
||||
|
||||
def my_node(state: State):
|
||||
my_stream_writer = get_stream_writer()
|
||||
my_stream_writer({"custom_data": "Hello!"})
|
||||
return {"foo": state["foo"] + 1}
|
||||
|
||||
|
||||
graph = (
|
||||
StateGraph(State)
|
||||
.add_node(my_node)
|
||||
.add_edge(START, "my_node")
|
||||
.compile(store=store)
|
||||
)
|
||||
|
||||
for chunk in graph.stream({"foo": 1}, stream_mode="custom"):
|
||||
print(chunk)
|
||||
```
|
||||
|
||||
```pycon
|
||||
{"custom_data": "Hello!"}
|
||||
```
|
||||
|
||||
Example: Using with functional API
|
||||
```python
|
||||
from langgraph.func import entrypoint, task
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
|
||||
@task
|
||||
def my_task(value: int):
|
||||
my_stream_writer = get_stream_writer()
|
||||
my_stream_writer({"custom_data": "Hello!"})
|
||||
return value + 1
|
||||
|
||||
|
||||
@entrypoint(store=store)
|
||||
def workflow(value: int):
|
||||
return my_task(value).result()
|
||||
|
||||
|
||||
for chunk in workflow.stream(1, stream_mode="custom"):
|
||||
print(chunk)
|
||||
```
|
||||
|
||||
```pycon
|
||||
{"custom_data": "Hello!"}
|
||||
```
|
||||
"""
|
||||
runtime = get_config()[CONF][CONFIG_KEY_RUNTIME]
|
||||
return runtime.stream_writer
|
||||
64
venv/Lib/site-packages/langgraph/constants.py
Normal file
64
venv/Lib/site-packages/langgraph/constants.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import sys
|
||||
from typing import Any
|
||||
from warnings import warn
|
||||
|
||||
from langgraph._internal._constants import (
|
||||
CONF,
|
||||
CONFIG_KEY_CHECKPOINTER,
|
||||
TASKS,
|
||||
)
|
||||
from langgraph.warnings import LangGraphDeprecatedSinceV10
|
||||
|
||||
__all__ = (
|
||||
"TAG_NOSTREAM",
|
||||
"TAG_HIDDEN",
|
||||
"START",
|
||||
"END",
|
||||
# retained for backwards compatibility (mostly langgraph-api), should be removed in v2 (or earlier)
|
||||
"CONF",
|
||||
"TASKS",
|
||||
"CONFIG_KEY_CHECKPOINTER",
|
||||
)
|
||||
|
||||
# --- Public constants ---
|
||||
TAG_NOSTREAM = sys.intern("nostream")
|
||||
"""Tag to disable streaming for a chat model."""
|
||||
TAG_HIDDEN = sys.intern("langsmith:hidden")
|
||||
"""Tag to hide a node/edge from certain tracing/streaming environments."""
|
||||
END = sys.intern("__end__")
|
||||
"""The last (maybe virtual) node in graph-style Pregel."""
|
||||
START = sys.intern("__start__")
|
||||
"""The first (maybe virtual) node in graph-style Pregel."""
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in ["Send", "Interrupt"]:
|
||||
warn(
|
||||
f"Importing {name} from langgraph.constants is deprecated. "
|
||||
f"Please use 'from langgraph.types import {name}' instead.",
|
||||
LangGraphDeprecatedSinceV10,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
module = import_module("langgraph.types")
|
||||
return getattr(module, name)
|
||||
|
||||
try:
|
||||
from importlib import import_module
|
||||
|
||||
private_constants = import_module("langgraph._internal._constants")
|
||||
attr = getattr(private_constants, name)
|
||||
warn(
|
||||
f"Importing {name} from langgraph.constants is deprecated. "
|
||||
f"This constant is now private and should not be used directly. "
|
||||
"Please let the LangGraph team know if you need this value.",
|
||||
LangGraphDeprecatedSinceV10,
|
||||
stacklevel=2,
|
||||
)
|
||||
return attr
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
raise AttributeError(f"module has no attribute '{name}'")
|
||||
127
venv/Lib/site-packages/langgraph/errors.py
Normal file
127
venv/Lib/site-packages/langgraph/errors.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from warnings import warn
|
||||
|
||||
# EmptyChannelError is re-exported from langgraph.channels.base
|
||||
from langgraph.checkpoint.base import EmptyChannelError # noqa: F401
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from langgraph.types import Command, Interrupt
|
||||
from langgraph.warnings import LangGraphDeprecatedSinceV10
|
||||
|
||||
__all__ = (
|
||||
"EmptyChannelError",
|
||||
"ErrorCode",
|
||||
"GraphRecursionError",
|
||||
"InvalidUpdateError",
|
||||
"GraphBubbleUp",
|
||||
"GraphInterrupt",
|
||||
"NodeInterrupt",
|
||||
"ParentCommand",
|
||||
"EmptyInputError",
|
||||
"TaskNotFound",
|
||||
)
|
||||
|
||||
|
||||
class ErrorCode(Enum):
|
||||
GRAPH_RECURSION_LIMIT = "GRAPH_RECURSION_LIMIT"
|
||||
INVALID_CONCURRENT_GRAPH_UPDATE = "INVALID_CONCURRENT_GRAPH_UPDATE"
|
||||
INVALID_GRAPH_NODE_RETURN_VALUE = "INVALID_GRAPH_NODE_RETURN_VALUE"
|
||||
MULTIPLE_SUBGRAPHS = "MULTIPLE_SUBGRAPHS"
|
||||
INVALID_CHAT_HISTORY = "INVALID_CHAT_HISTORY"
|
||||
|
||||
|
||||
def create_error_message(*, message: str, error_code: ErrorCode) -> str:
|
||||
return (
|
||||
f"{message}\n"
|
||||
"For troubleshooting, visit: https://docs.langchain.com/oss/python/langgraph/"
|
||||
f"errors/{error_code.value}"
|
||||
)
|
||||
|
||||
|
||||
class GraphRecursionError(RecursionError):
|
||||
"""Raised when the graph has exhausted the maximum number of steps.
|
||||
|
||||
This prevents infinite loops. To increase the maximum number of steps,
|
||||
run your graph with a config specifying a higher `recursion_limit`.
|
||||
|
||||
Troubleshooting guides:
|
||||
|
||||
- [`GRAPH_RECURSION_LIMIT`](https://docs.langchain.com/oss/python/langgraph/GRAPH_RECURSION_LIMIT)
|
||||
|
||||
Examples:
|
||||
|
||||
graph = builder.compile()
|
||||
graph.invoke(
|
||||
{"messages": [("user", "Hello, world!")]},
|
||||
# The config is the second positional argument
|
||||
{"recursion_limit": 1000},
|
||||
)
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidUpdateError(Exception):
|
||||
"""Raised when attempting to update a channel with an invalid set of updates.
|
||||
|
||||
Troubleshooting guides:
|
||||
|
||||
- [`INVALID_CONCURRENT_GRAPH_UPDATE`](https://docs.langchain.com/oss/python/langgraph/INVALID_CONCURRENT_GRAPH_UPDATE)
|
||||
- [`INVALID_GRAPH_NODE_RETURN_VALUE`](https://docs.langchain.com/oss/python/langgraph/INVALID_GRAPH_NODE_RETURN_VALUE)
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GraphBubbleUp(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class GraphInterrupt(GraphBubbleUp):
|
||||
"""Raised when a subgraph is interrupted, suppressed by the root graph.
|
||||
Never raised directly, or surfaced to the user."""
|
||||
|
||||
def __init__(self, interrupts: Sequence[Interrupt] = ()) -> None:
|
||||
super().__init__(interrupts)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"NodeInterrupt is deprecated. Please use [`interrupt`][langgraph.types.interrupt] instead.",
|
||||
category=None,
|
||||
)
|
||||
class NodeInterrupt(GraphInterrupt):
|
||||
"""Raised by a node to interrupt execution."""
|
||||
|
||||
def __init__(self, value: Any, id: str | None = None) -> None:
|
||||
warn(
|
||||
"NodeInterrupt is deprecated. Please use `langgraph.types.interrupt` instead.",
|
||||
LangGraphDeprecatedSinceV10,
|
||||
stacklevel=2,
|
||||
)
|
||||
if id is None:
|
||||
super().__init__([Interrupt(value=value)])
|
||||
else:
|
||||
super().__init__([Interrupt(value=value, id=id)])
|
||||
|
||||
|
||||
class ParentCommand(GraphBubbleUp):
|
||||
args: tuple[Command]
|
||||
|
||||
def __init__(self, command: Command) -> None:
|
||||
super().__init__(command)
|
||||
|
||||
|
||||
class EmptyInputError(Exception):
|
||||
"""Raised when graph receives an empty input."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TaskNotFound(Exception):
|
||||
"""Raised when the executor is unable to find a task (for distributed mode)."""
|
||||
|
||||
pass
|
||||
575
venv/Lib/site-packages/langgraph/func/__init__.py
Normal file
575
venv/Lib/site-packages/langgraph/func/__init__.py
Normal file
@@ -0,0 +1,575 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
TypeVar,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
overload,
|
||||
)
|
||||
|
||||
from langgraph.cache.base import BaseCache
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from langgraph.store.base import BaseStore
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from langgraph._internal import _serde
|
||||
from langgraph._internal._constants import CACHE_NS_WRITES, PREVIOUS
|
||||
from langgraph._internal._typing import MISSING, DeprecatedKwargs
|
||||
from langgraph.channels.ephemeral_value import EphemeralValue
|
||||
from langgraph.channels.last_value import LastValue
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.pregel import Pregel
|
||||
from langgraph.pregel._call import (
|
||||
P,
|
||||
SyncAsyncFuture,
|
||||
T,
|
||||
call,
|
||||
get_runnable_for_entrypoint,
|
||||
identifier,
|
||||
)
|
||||
from langgraph.pregel._read import PregelNode
|
||||
from langgraph.pregel._write import ChannelWrite, ChannelWriteEntry
|
||||
from langgraph.types import _DC_KWARGS, CachePolicy, RetryPolicy, StreamMode
|
||||
from langgraph.typing import ContextT
|
||||
from langgraph.warnings import LangGraphDeprecatedSinceV05, LangGraphDeprecatedSinceV10
|
||||
|
||||
__all__ = ("task", "entrypoint")
|
||||
|
||||
|
||||
class _TaskFunction(Generic[P, T]):
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[P, Awaitable[T]] | Callable[P, T],
|
||||
*,
|
||||
retry_policy: Sequence[RetryPolicy],
|
||||
cache_policy: CachePolicy[Callable[P, str | bytes]] | None = None,
|
||||
name: str | None = None,
|
||||
) -> None:
|
||||
if name is not None:
|
||||
if hasattr(func, "__func__"):
|
||||
# handle class methods
|
||||
# NOTE: we're modifying the instance method to avoid modifying
|
||||
# the original class method in case it's shared across multiple tasks
|
||||
instance_method = functools.partial(func.__func__, func.__self__) # type: ignore [union-attr]
|
||||
instance_method.__name__ = name # type: ignore [attr-defined]
|
||||
func = instance_method
|
||||
else:
|
||||
# handle regular functions / partials / callable classes, etc.
|
||||
func.__name__ = name
|
||||
self.func = func
|
||||
self.retry_policy = retry_policy
|
||||
self.cache_policy = cache_policy
|
||||
functools.update_wrapper(self, func)
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> SyncAsyncFuture[T]:
|
||||
return call(
|
||||
self.func,
|
||||
retry_policy=self.retry_policy,
|
||||
cache_policy=self.cache_policy,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def clear_cache(self, cache: BaseCache) -> None:
|
||||
"""Clear the cache for this task."""
|
||||
if self.cache_policy is not None:
|
||||
cache.clear(((CACHE_NS_WRITES, identifier(self.func) or "__dynamic__"),))
|
||||
|
||||
async def aclear_cache(self, cache: BaseCache) -> None:
|
||||
"""Clear the cache for this task."""
|
||||
if self.cache_policy is not None:
|
||||
await cache.aclear(
|
||||
((CACHE_NS_WRITES, identifier(self.func) or "__dynamic__"),)
|
||||
)
|
||||
|
||||
|
||||
@overload
|
||||
def task(
|
||||
__func_or_none__: None = None,
|
||||
*,
|
||||
name: str | None = None,
|
||||
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
|
||||
cache_policy: CachePolicy[Callable[P, str | bytes]] | None = None,
|
||||
**kwargs: Unpack[DeprecatedKwargs],
|
||||
) -> Callable[
|
||||
[Callable[P, Awaitable[T]] | Callable[P, T]],
|
||||
_TaskFunction[P, T],
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def task(__func_or_none__: Callable[P, Awaitable[T]]) -> _TaskFunction[P, T]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def task(__func_or_none__: Callable[P, T]) -> _TaskFunction[P, T]: ...
|
||||
|
||||
|
||||
def task(
|
||||
__func_or_none__: Callable[P, Awaitable[T]] | Callable[P, T] | None = None,
|
||||
*,
|
||||
name: str | None = None,
|
||||
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
|
||||
cache_policy: CachePolicy[Callable[P, str | bytes]] | None = None,
|
||||
**kwargs: Unpack[DeprecatedKwargs],
|
||||
) -> (
|
||||
Callable[[Callable[P, Awaitable[T]] | Callable[P, T]], _TaskFunction[P, T]]
|
||||
| _TaskFunction[P, T]
|
||||
):
|
||||
"""Define a LangGraph task using the `task` decorator.
|
||||
|
||||
!!! important "Requires python 3.11 or higher for async functions"
|
||||
The `task` decorator supports both sync and async functions. To use async
|
||||
functions, ensure that you are using Python 3.11 or higher.
|
||||
|
||||
Tasks can only be called from within an [`entrypoint`][langgraph.func.entrypoint] or
|
||||
from within a `StateGraph`. A task can be called like a regular function with the
|
||||
following differences:
|
||||
|
||||
- When a checkpointer is enabled, the function inputs and outputs must be serializable.
|
||||
- The decorated function can only be called from within an entrypoint or `StateGraph`.
|
||||
- Calling the function produces a future. This makes it easy to parallelize tasks.
|
||||
|
||||
Args:
|
||||
name: An optional name for the task. If not provided, the function name will be used.
|
||||
retry_policy: An optional retry policy (or list of policies) to use for the task in case of a failure.
|
||||
cache_policy: An optional cache policy to use for the task. This allows caching of the task results.
|
||||
|
||||
Returns:
|
||||
A callable function when used as a decorator.
|
||||
|
||||
Example: Sync Task
|
||||
```python
|
||||
from langgraph.func import entrypoint, task
|
||||
|
||||
|
||||
@task
|
||||
def add_one_task(a: int) -> int:
|
||||
return a + 1
|
||||
|
||||
|
||||
@entrypoint()
|
||||
def add_one(numbers: list[int]) -> list[int]:
|
||||
futures = [add_one_task(n) for n in numbers]
|
||||
results = [f.result() for f in futures]
|
||||
return results
|
||||
|
||||
|
||||
# Call the entrypoint
|
||||
add_one.invoke([1, 2, 3]) # Returns [2, 3, 4]
|
||||
```
|
||||
|
||||
Example: Async Task
|
||||
```python
|
||||
import asyncio
|
||||
from langgraph.func import entrypoint, task
|
||||
|
||||
|
||||
@task
|
||||
async def add_one_task(a: int) -> int:
|
||||
return a + 1
|
||||
|
||||
|
||||
@entrypoint()
|
||||
async def add_one(numbers: list[int]) -> list[int]:
|
||||
futures = [add_one_task(n) for n in numbers]
|
||||
return asyncio.gather(*futures)
|
||||
|
||||
|
||||
# Call the entrypoint
|
||||
await add_one.ainvoke([1, 2, 3]) # Returns [2, 3, 4]
|
||||
```
|
||||
"""
|
||||
if (retry := kwargs.get("retry", MISSING)) is not MISSING:
|
||||
warnings.warn(
|
||||
"`retry` is deprecated and will be removed. Please use `retry_policy` instead.",
|
||||
category=LangGraphDeprecatedSinceV05,
|
||||
stacklevel=2,
|
||||
)
|
||||
if retry_policy is None:
|
||||
retry_policy = retry # type: ignore[assignment]
|
||||
|
||||
retry_policies: Sequence[RetryPolicy] = (
|
||||
()
|
||||
if retry_policy is None
|
||||
else (retry_policy,)
|
||||
if isinstance(retry_policy, RetryPolicy)
|
||||
else retry_policy
|
||||
)
|
||||
|
||||
def decorator(
|
||||
func: Callable[P, Awaitable[T]] | Callable[P, T],
|
||||
) -> Callable[P, SyncAsyncFuture[T]]:
|
||||
return _TaskFunction(
|
||||
func, retry_policy=retry_policies, cache_policy=cache_policy, name=name
|
||||
)
|
||||
|
||||
if __func_or_none__ is not None:
|
||||
return decorator(__func_or_none__)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
R = TypeVar("R")
|
||||
S = TypeVar("S")
|
||||
|
||||
|
||||
# The decorator was wrapped in a class to support the `final` attribute.
|
||||
# In this form, the `final` attribute should play nicely with IDE autocompletion,
|
||||
# and type checking tools.
|
||||
# In addition, we'll be able to surface this information in the API Reference.
|
||||
class entrypoint(Generic[ContextT]):
|
||||
"""Define a LangGraph workflow using the `entrypoint` decorator.
|
||||
|
||||
### Function signature
|
||||
|
||||
The decorated function must accept a **single parameter**, which serves as the input
|
||||
to the function. This input parameter can be of any type. Use a dictionary
|
||||
to pass **multiple parameters** to the function.
|
||||
|
||||
### Injectable parameters
|
||||
|
||||
The decorated function can request access to additional parameters
|
||||
that will be injected automatically at run time. These parameters include:
|
||||
|
||||
| Parameter | Description |
|
||||
|------------------|------------------------------------------------------------------------------------------------------|
|
||||
| **`config`** | A configuration object (aka `RunnableConfig`) that holds run-time configuration values. |
|
||||
| **`previous`** | The previous return value for the given thread (available only when a checkpointer is provided). |
|
||||
| **`runtime`** | A `Runtime` object that contains information about the current run, including context, store, writer |
|
||||
|
||||
The entrypoint decorator can be applied to sync functions or async functions.
|
||||
|
||||
### State management
|
||||
|
||||
The **`previous`** parameter can be used to access the return value of the previous
|
||||
invocation of the entrypoint on the same thread id. This value is only available
|
||||
when a checkpointer is provided.
|
||||
|
||||
If you want **`previous`** to be different from the return value, you can use the
|
||||
`entrypoint.final` object to return a value while saving a different value to the
|
||||
checkpoint.
|
||||
|
||||
Args:
|
||||
checkpointer: Specify a checkpointer to create a workflow that can persist
|
||||
its state across runs.
|
||||
store: A generalized key-value store. Some implementations may support
|
||||
semantic search capabilities through an optional `index` configuration.
|
||||
cache: A cache to use for caching the results of the workflow.
|
||||
context_schema: Specifies the schema for the context object that will be
|
||||
passed to the workflow.
|
||||
cache_policy: A cache policy to use for caching the results of the workflow.
|
||||
retry_policy: A retry policy (or list of policies) to use for the workflow in case of a failure.
|
||||
|
||||
!!! warning "`config_schema` Deprecated"
|
||||
The `config_schema` parameter is deprecated in v0.6.0 and support will be removed in v2.0.0.
|
||||
Please use `context_schema` instead to specify the schema for run-scoped context.
|
||||
|
||||
|
||||
Example: Using entrypoint and tasks
|
||||
```python
|
||||
import time
|
||||
|
||||
from langgraph.func import entrypoint, task
|
||||
from langgraph.types import interrupt, Command
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
@task
|
||||
def compose_essay(topic: str) -> str:
|
||||
time.sleep(1.0) # Simulate slow operation
|
||||
return f"An essay about {topic}"
|
||||
|
||||
@entrypoint(checkpointer=InMemorySaver())
|
||||
def review_workflow(topic: str) -> dict:
|
||||
\"\"\"Manages the workflow for generating and reviewing an essay.
|
||||
|
||||
The workflow includes:
|
||||
1. Generating an essay about the given topic.
|
||||
2. Interrupting the workflow for human review of the generated essay.
|
||||
|
||||
Upon resuming the workflow, compose_essay task will not be re-executed
|
||||
as its result is cached by the checkpointer.
|
||||
|
||||
Args:
|
||||
topic: The subject of the essay.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the generated essay and the human review.
|
||||
\"\"\"
|
||||
essay_future = compose_essay(topic)
|
||||
essay = essay_future.result()
|
||||
human_review = interrupt({
|
||||
\"question\": \"Please provide a review\",
|
||||
\"essay\": essay
|
||||
})
|
||||
return {
|
||||
\"essay\": essay,
|
||||
\"review\": human_review,
|
||||
}
|
||||
|
||||
# Example configuration for the workflow
|
||||
config = {
|
||||
\"configurable\": {
|
||||
\"thread_id\": \"some_thread\"
|
||||
}
|
||||
}
|
||||
|
||||
# Topic for the essay
|
||||
topic = \"cats\"
|
||||
|
||||
# Stream the workflow to generate the essay and await human review
|
||||
for result in review_workflow.stream(topic, config):
|
||||
print(result)
|
||||
|
||||
# Example human review provided after the interrupt
|
||||
human_review = \"This essay is great.\"
|
||||
|
||||
# Resume the workflow with the provided human review
|
||||
for result in review_workflow.stream(Command(resume=human_review), config):
|
||||
print(result)
|
||||
```
|
||||
|
||||
Example: Accessing the previous return value
|
||||
When a checkpointer is enabled the function can access the previous return value
|
||||
of the previous invocation on the same thread id.
|
||||
|
||||
```python
|
||||
from typing import Optional
|
||||
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
|
||||
from langgraph.func import entrypoint
|
||||
|
||||
|
||||
@entrypoint(checkpointer=InMemorySaver())
|
||||
def my_workflow(input_data: str, previous: Optional[str] = None) -> str:
|
||||
return "world"
|
||||
|
||||
|
||||
config = {"configurable": {"thread_id": "some_thread"}}
|
||||
my_workflow.invoke("hello", config)
|
||||
```
|
||||
|
||||
Example: Using `entrypoint.final` to save a value
|
||||
The `entrypoint.final` object allows you to return a value while saving
|
||||
a different value to the checkpoint. This value will be accessible
|
||||
in the next invocation of the entrypoint via the `previous` parameter, as
|
||||
long as the same thread id is used.
|
||||
|
||||
```python
|
||||
from typing import Any
|
||||
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
|
||||
from langgraph.func import entrypoint
|
||||
|
||||
|
||||
@entrypoint(checkpointer=InMemorySaver())
|
||||
def my_workflow(
|
||||
number: int,
|
||||
*,
|
||||
previous: Any = None,
|
||||
) -> entrypoint.final[int, int]:
|
||||
previous = previous or 0
|
||||
# This will return the previous value to the caller, saving
|
||||
# 2 * number to the checkpoint, which will be used in the next invocation
|
||||
# for the `previous` parameter.
|
||||
return entrypoint.final(value=previous, save=2 * number)
|
||||
|
||||
|
||||
config = {"configurable": {"thread_id": "some_thread"}}
|
||||
|
||||
my_workflow.invoke(3, config) # 0 (previous was None)
|
||||
my_workflow.invoke(1, config) # 6 (previous was 3 * 2 from the previous invocation)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpointer: BaseCheckpointSaver | None = None,
|
||||
store: BaseStore | None = None,
|
||||
cache: BaseCache | None = None,
|
||||
context_schema: type[ContextT] | None = None,
|
||||
cache_policy: CachePolicy | None = None,
|
||||
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
|
||||
**kwargs: Unpack[DeprecatedKwargs],
|
||||
) -> None:
|
||||
"""Initialize the entrypoint decorator."""
|
||||
if (config_schema := kwargs.get("config_schema", MISSING)) is not MISSING:
|
||||
warnings.warn(
|
||||
"`config_schema` is deprecated and will be removed. Please use `context_schema` instead.",
|
||||
category=LangGraphDeprecatedSinceV10,
|
||||
stacklevel=2,
|
||||
)
|
||||
if context_schema is None:
|
||||
context_schema = cast(type[ContextT], config_schema)
|
||||
|
||||
if (retry := kwargs.get("retry", MISSING)) is not MISSING:
|
||||
warnings.warn(
|
||||
"`retry` is deprecated and will be removed. Please use `retry_policy` instead.",
|
||||
category=LangGraphDeprecatedSinceV05,
|
||||
stacklevel=2,
|
||||
)
|
||||
if retry_policy is None:
|
||||
retry_policy = cast("RetryPolicy | Sequence[RetryPolicy]", retry)
|
||||
|
||||
self.checkpointer = checkpointer
|
||||
self.store = store
|
||||
self.cache = cache
|
||||
self.cache_policy = cache_policy
|
||||
self.retry_policy = retry_policy
|
||||
self.context_schema = context_schema
|
||||
|
||||
@dataclass(**_DC_KWARGS)
|
||||
class final(Generic[R, S]):
|
||||
"""A primitive that can be returned from an entrypoint.
|
||||
|
||||
This primitive allows to save a value to the checkpointer distinct from the
|
||||
return value from the entrypoint.
|
||||
|
||||
Example: Decoupling the return value and the save value
|
||||
```python
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.func import entrypoint
|
||||
|
||||
|
||||
@entrypoint(checkpointer=InMemorySaver())
|
||||
def my_workflow(
|
||||
number: int,
|
||||
*,
|
||||
previous: Any = None,
|
||||
) -> entrypoint.final[int, int]:
|
||||
previous = previous or 0
|
||||
# This will return the previous value to the caller, saving
|
||||
# 2 * number to the checkpoint, which will be used in the next invocation
|
||||
# for the `previous` parameter.
|
||||
return entrypoint.final(value=previous, save=2 * number)
|
||||
|
||||
|
||||
config = {"configurable": {"thread_id": "1"}}
|
||||
|
||||
my_workflow.invoke(3, config) # 0 (previous was None)
|
||||
my_workflow.invoke(1, config) # 6 (previous was 3 * 2 from the previous invocation)
|
||||
```
|
||||
"""
|
||||
|
||||
value: R
|
||||
"""Value to return. A value will always be returned even if it is `None`."""
|
||||
save: S
|
||||
"""The value for the state for the next checkpoint.
|
||||
|
||||
A value will always be saved even if it is `None`.
|
||||
"""
|
||||
|
||||
def __call__(self, func: Callable[..., Any]) -> Pregel:
|
||||
"""Convert a function into a Pregel graph.
|
||||
|
||||
Args:
|
||||
func: The function to convert. Support both sync and async functions.
|
||||
|
||||
Returns:
|
||||
A Pregel graph.
|
||||
"""
|
||||
# wrap generators in a function that writes to StreamWriter
|
||||
if inspect.isgeneratorfunction(func) or inspect.isasyncgenfunction(func):
|
||||
raise NotImplementedError(
|
||||
"Generators are not supported in the Functional API."
|
||||
)
|
||||
|
||||
bound = get_runnable_for_entrypoint(func)
|
||||
stream_mode: StreamMode = "updates"
|
||||
|
||||
# get input and output types
|
||||
sig = inspect.signature(func)
|
||||
first_parameter_name = next(iter(sig.parameters.keys()), None)
|
||||
if not first_parameter_name:
|
||||
raise ValueError("Entrypoint function must have at least one parameter")
|
||||
input_type = (
|
||||
sig.parameters[first_parameter_name].annotation
|
||||
if sig.parameters[first_parameter_name].annotation
|
||||
is not inspect.Signature.empty
|
||||
else Any
|
||||
)
|
||||
|
||||
def _pluck_return_value(value: Any) -> Any:
|
||||
"""Extract the return_ value the entrypoint.final object or passthrough."""
|
||||
return value.value if isinstance(value, entrypoint.final) else value
|
||||
|
||||
def _pluck_save_value(value: Any) -> Any:
|
||||
"""Get save value from the entrypoint.final object or passthrough."""
|
||||
return value.save if isinstance(value, entrypoint.final) else value
|
||||
|
||||
output_type, save_type = Any, Any
|
||||
if sig.return_annotation is not inspect.Signature.empty:
|
||||
# User does not parameterize entrypoint.final properly
|
||||
if (
|
||||
sig.return_annotation is entrypoint.final
|
||||
): # Un-parameterized entrypoint.final
|
||||
output_type = save_type = Any
|
||||
else:
|
||||
origin = get_origin(sig.return_annotation)
|
||||
if origin is entrypoint.final:
|
||||
type_annotations = get_args(sig.return_annotation)
|
||||
if len(type_annotations) != 2:
|
||||
raise TypeError(
|
||||
"Please an annotation for both the return_ and "
|
||||
"the save values."
|
||||
"For example, `-> entrypoint.final[int, str]` would assign a "
|
||||
"return_ a type of `int` and save the type `str`."
|
||||
)
|
||||
output_type, save_type = get_args(sig.return_annotation)
|
||||
else:
|
||||
output_type = save_type = sig.return_annotation
|
||||
|
||||
graph: Pregel[Any, ContextT, Any, Any] = Pregel(
|
||||
nodes={
|
||||
func.__name__: PregelNode(
|
||||
bound=bound,
|
||||
triggers=[START],
|
||||
channels=START,
|
||||
writers=[
|
||||
ChannelWrite(
|
||||
[
|
||||
ChannelWriteEntry(END, mapper=_pluck_return_value),
|
||||
ChannelWriteEntry(PREVIOUS, mapper=_pluck_save_value),
|
||||
]
|
||||
)
|
||||
],
|
||||
)
|
||||
},
|
||||
channels={
|
||||
START: EphemeralValue(input_type),
|
||||
END: LastValue(output_type, END),
|
||||
PREVIOUS: LastValue(save_type, PREVIOUS),
|
||||
},
|
||||
input_channels=START,
|
||||
output_channels=END,
|
||||
stream_channels=END,
|
||||
stream_mode=stream_mode,
|
||||
stream_eager=True,
|
||||
checkpointer=self.checkpointer,
|
||||
store=self.store,
|
||||
cache=self.cache,
|
||||
cache_policy=self.cache_policy,
|
||||
retry_policy=self.retry_policy or (),
|
||||
context_schema=self.context_schema,
|
||||
)
|
||||
if _serde.STRICT_MSGPACK_ENABLED:
|
||||
serde_allowlist = _serde.build_serde_allowlist(
|
||||
schemas=[input_type, output_type, save_type]
|
||||
+ ([self.context_schema] if self.context_schema is not None else []),
|
||||
channels=graph.channels,
|
||||
)
|
||||
graph._serde_allowlist = serde_allowlist
|
||||
graph.checkpointer = _serde.apply_checkpointer_allowlist(
|
||||
graph.checkpointer, serde_allowlist
|
||||
)
|
||||
return graph
|
||||
Binary file not shown.
12
venv/Lib/site-packages/langgraph/graph/__init__.py
Normal file
12
venv/Lib/site-packages/langgraph/graph/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph.message import MessageGraph, MessagesState, add_messages
|
||||
from langgraph.graph.state import StateGraph
|
||||
|
||||
__all__ = (
|
||||
"END",
|
||||
"START",
|
||||
"StateGraph",
|
||||
"add_messages",
|
||||
"MessagesState",
|
||||
"MessageGraph",
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
225
venv/Lib/site-packages/langgraph/graph/_branch.py
Normal file
225
venv/Lib/site-packages/langgraph/graph/_branch.py
Normal file
@@ -0,0 +1,225 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Hashable, Sequence
|
||||
from inspect import (
|
||||
isfunction,
|
||||
ismethod,
|
||||
signature,
|
||||
)
|
||||
from itertools import zip_longest
|
||||
from types import FunctionType
|
||||
from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableLambda,
|
||||
)
|
||||
|
||||
from langgraph._internal._runnable import (
|
||||
RunnableCallable,
|
||||
)
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.errors import InvalidUpdateError
|
||||
from langgraph.pregel._write import PASSTHROUGH, ChannelWrite, ChannelWriteEntry
|
||||
from langgraph.types import Send
|
||||
|
||||
_Writer = Callable[
|
||||
[Sequence[str | Send], bool],
|
||||
Sequence[ChannelWriteEntry | Send],
|
||||
]
|
||||
|
||||
|
||||
def _get_branch_path_input_schema(
|
||||
path: Callable[..., Hashable | Sequence[Hashable]]
|
||||
| Callable[..., Awaitable[Hashable | Sequence[Hashable]]]
|
||||
| Runnable[Any, Hashable | Sequence[Hashable]],
|
||||
) -> type[Any] | None:
|
||||
input = None
|
||||
# detect input schema annotation in the branch callable
|
||||
try:
|
||||
callable_: (
|
||||
Callable[..., Hashable | Sequence[Hashable]]
|
||||
| Callable[..., Awaitable[Hashable | Sequence[Hashable]]]
|
||||
| None
|
||||
) = None
|
||||
if isinstance(path, (RunnableCallable, RunnableLambda)):
|
||||
if isfunction(path.func) or ismethod(path.func):
|
||||
callable_ = path.func
|
||||
elif (callable_method := getattr(path.func, "__call__", None)) and ismethod(
|
||||
callable_method
|
||||
):
|
||||
callable_ = callable_method
|
||||
elif isfunction(path.afunc) or ismethod(path.afunc):
|
||||
callable_ = path.afunc
|
||||
elif (
|
||||
callable_method := getattr(path.afunc, "__call__", None)
|
||||
) and ismethod(callable_method):
|
||||
callable_ = callable_method
|
||||
elif callable(path):
|
||||
callable_ = path
|
||||
|
||||
if callable_ is not None and (hints := get_type_hints(callable_)):
|
||||
first_parameter_name = next(
|
||||
iter(signature(cast(FunctionType, callable_)).parameters.keys())
|
||||
)
|
||||
if input_hint := hints.get(first_parameter_name):
|
||||
if isinstance(input_hint, type) and get_type_hints(input_hint):
|
||||
input = input_hint
|
||||
except (TypeError, StopIteration):
|
||||
pass
|
||||
|
||||
return input
|
||||
|
||||
|
||||
class BranchSpec(NamedTuple):
|
||||
path: Runnable[Any, Hashable | list[Hashable]]
|
||||
ends: dict[Hashable, str] | None
|
||||
input_schema: type[Any] | None = None
|
||||
|
||||
@classmethod
|
||||
def from_path(
|
||||
cls,
|
||||
path: Runnable[Any, Hashable | list[Hashable]],
|
||||
path_map: dict[Hashable, str] | list[str] | None,
|
||||
infer_schema: bool = False,
|
||||
) -> BranchSpec:
|
||||
# coerce path_map to a dictionary
|
||||
path_map_: dict[Hashable, str] | None = None
|
||||
try:
|
||||
if isinstance(path_map, dict):
|
||||
path_map_ = path_map.copy()
|
||||
elif isinstance(path_map, list):
|
||||
path_map_ = {name: name for name in path_map}
|
||||
else:
|
||||
# find func
|
||||
func: Callable | None = None
|
||||
if isinstance(path, (RunnableCallable, RunnableLambda)):
|
||||
func = path.func or path.afunc
|
||||
if func is not None:
|
||||
# find callable method
|
||||
if (cal := getattr(path, "__call__", None)) and ismethod(cal):
|
||||
func = cal
|
||||
# get the return type
|
||||
if rtn_type := get_type_hints(func).get("return"):
|
||||
if get_origin(rtn_type) is Literal:
|
||||
path_map_ = {name: name for name in get_args(rtn_type)}
|
||||
except Exception:
|
||||
pass
|
||||
# infer input schema
|
||||
input_schema = _get_branch_path_input_schema(path) if infer_schema else None
|
||||
# create branch
|
||||
return cls(path=path, ends=path_map_, input_schema=input_schema)
|
||||
|
||||
def run(
|
||||
self,
|
||||
writer: _Writer,
|
||||
reader: Callable[[RunnableConfig], Any] | None = None,
|
||||
) -> RunnableCallable:
|
||||
return ChannelWrite.register_writer(
|
||||
RunnableCallable(
|
||||
func=self._route,
|
||||
afunc=self._aroute,
|
||||
writer=writer,
|
||||
reader=reader,
|
||||
name=None,
|
||||
trace=False,
|
||||
),
|
||||
list(
|
||||
zip_longest(
|
||||
writer([e for e in self.ends.values()], True),
|
||||
[str(la) for la, e in self.ends.items()],
|
||||
)
|
||||
)
|
||||
if self.ends
|
||||
else None,
|
||||
)
|
||||
|
||||
def _route(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
reader: Callable[[RunnableConfig], Any] | None,
|
||||
writer: _Writer,
|
||||
) -> Runnable:
|
||||
if reader:
|
||||
value = reader(config)
|
||||
# passthrough additional keys from node to branch
|
||||
# only doable when using dict states
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and isinstance(input, dict)
|
||||
and self.input_schema is None
|
||||
):
|
||||
value = {**input, **value}
|
||||
else:
|
||||
value = input
|
||||
result = self.path.invoke(value, config)
|
||||
return self._finish(writer, input, result, config)
|
||||
|
||||
async def _aroute(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
reader: Callable[[RunnableConfig], Any] | None,
|
||||
writer: _Writer,
|
||||
) -> Runnable:
|
||||
if reader:
|
||||
value = reader(config)
|
||||
# passthrough additional keys from node to branch
|
||||
# only doable when using dict states
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and isinstance(input, dict)
|
||||
and self.input_schema is None
|
||||
):
|
||||
value = {**input, **value}
|
||||
else:
|
||||
value = input
|
||||
result = await self.path.ainvoke(value, config)
|
||||
return self._finish(writer, input, result, config)
|
||||
|
||||
def _finish(
|
||||
self,
|
||||
writer: _Writer,
|
||||
input: Any,
|
||||
result: Any,
|
||||
config: RunnableConfig,
|
||||
) -> Runnable | Any:
|
||||
if not isinstance(result, (list, tuple)):
|
||||
result = [result]
|
||||
if self.ends:
|
||||
destinations: Sequence[Send | str] = [
|
||||
r if isinstance(r, Send) else self.ends[r] for r in result
|
||||
]
|
||||
else:
|
||||
destinations = cast(Sequence[Send | str], result)
|
||||
if any(dest is None or dest == START for dest in destinations):
|
||||
raise ValueError("Branch did not return a valid destination")
|
||||
if any(p.node == END for p in destinations if isinstance(p, Send)):
|
||||
raise InvalidUpdateError("Cannot send a packet to the END node")
|
||||
entries = writer(destinations, False)
|
||||
if not entries:
|
||||
return input
|
||||
else:
|
||||
need_passthrough = False
|
||||
for e in entries:
|
||||
if isinstance(e, ChannelWriteEntry):
|
||||
if e.value is PASSTHROUGH:
|
||||
need_passthrough = True
|
||||
break
|
||||
if need_passthrough:
|
||||
return ChannelWrite(entries)
|
||||
else:
|
||||
ChannelWrite.do_write(config, entries)
|
||||
return input
|
||||
92
venv/Lib/site-packages/langgraph/graph/_node.py
Normal file
92
venv/Lib/site-packages/langgraph/graph/_node.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Protocol, TypeAlias
|
||||
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from langgraph._internal._typing import EMPTY_SEQ
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import CachePolicy, RetryPolicy, StreamWriter
|
||||
from langgraph.typing import ContextT, NodeInputT, NodeInputT_contra
|
||||
|
||||
|
||||
class _Node(Protocol[NodeInputT_contra]):
|
||||
def __call__(self, state: NodeInputT_contra) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithConfig(Protocol[NodeInputT_contra]):
|
||||
def __call__(self, state: NodeInputT_contra, config: RunnableConfig) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithWriter(Protocol[NodeInputT_contra]):
|
||||
def __call__(self, state: NodeInputT_contra, *, writer: StreamWriter) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithStore(Protocol[NodeInputT_contra]):
|
||||
def __call__(self, state: NodeInputT_contra, *, store: BaseStore) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithWriterStore(Protocol[NodeInputT_contra]):
|
||||
def __call__(
|
||||
self, state: NodeInputT_contra, *, writer: StreamWriter, store: BaseStore
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithConfigWriter(Protocol[NodeInputT_contra]):
|
||||
def __call__(
|
||||
self, state: NodeInputT_contra, *, config: RunnableConfig, writer: StreamWriter
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithConfigStore(Protocol[NodeInputT_contra]):
|
||||
def __call__(
|
||||
self, state: NodeInputT_contra, *, config: RunnableConfig, store: BaseStore
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithConfigWriterStore(Protocol[NodeInputT_contra]):
|
||||
def __call__(
|
||||
self,
|
||||
state: NodeInputT_contra,
|
||||
*,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter,
|
||||
store: BaseStore,
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithRuntime(Protocol[NodeInputT_contra, ContextT]):
|
||||
def __call__(
|
||||
self, state: NodeInputT_contra, *, runtime: Runtime[ContextT]
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
# TODO: we probably don't want to explicitly support the config / store signatures once
|
||||
# we move to adding a context arg. Maybe what we do is we add support for kwargs with param spec
|
||||
# this is purely for typing purposes though, so can easily change in the coming weeks.
|
||||
StateNode: TypeAlias = (
|
||||
_Node[NodeInputT]
|
||||
| _NodeWithConfig[NodeInputT]
|
||||
| _NodeWithWriter[NodeInputT]
|
||||
| _NodeWithStore[NodeInputT]
|
||||
| _NodeWithWriterStore[NodeInputT]
|
||||
| _NodeWithConfigWriter[NodeInputT]
|
||||
| _NodeWithConfigStore[NodeInputT]
|
||||
| _NodeWithConfigWriterStore[NodeInputT]
|
||||
| _NodeWithRuntime[NodeInputT, ContextT]
|
||||
| Runnable[NodeInputT, Any]
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class StateNodeSpec(Generic[NodeInputT, ContextT]):
|
||||
runnable: StateNode[NodeInputT, ContextT]
|
||||
metadata: dict[str, Any] | None
|
||||
input_schema: type[NodeInputT]
|
||||
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None
|
||||
cache_policy: CachePolicy | None
|
||||
ends: tuple[str, ...] | dict[str, str] | None = EMPTY_SEQ
|
||||
defer: bool = False
|
||||
372
venv/Lib/site-packages/langgraph/graph/message.py
Normal file
372
venv/Lib/site-packages/langgraph/graph/message.py
Normal file
@@ -0,0 +1,372 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.messages import (
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
MessageLikeRepresentation,
|
||||
RemoveMessage,
|
||||
convert_to_messages,
|
||||
message_chunk_to_message,
|
||||
)
|
||||
from typing_extensions import TypedDict, deprecated
|
||||
|
||||
from langgraph._internal._constants import CONF, CONFIG_KEY_SEND, NS_SEP
|
||||
from langgraph.graph.state import StateGraph
|
||||
from langgraph.warnings import LangGraphDeprecatedSinceV10
|
||||
|
||||
__all__ = (
|
||||
"add_messages",
|
||||
"MessagesState",
|
||||
"MessageGraph",
|
||||
"REMOVE_ALL_MESSAGES",
|
||||
)
|
||||
|
||||
Messages = list[MessageLikeRepresentation] | MessageLikeRepresentation
|
||||
|
||||
REMOVE_ALL_MESSAGES = "__remove_all__"
|
||||
|
||||
|
||||
def _add_messages_wrapper(func: Callable) -> Callable[[Messages, Messages], Messages]:
|
||||
def _add_messages(
|
||||
left: Messages | None = None, right: Messages | None = None, **kwargs: Any
|
||||
) -> Messages | Callable[[Messages, Messages], Messages]:
|
||||
if left is not None and right is not None:
|
||||
return func(left, right, **kwargs)
|
||||
elif left is not None or right is not None:
|
||||
msg = (
|
||||
f"Must specify non-null arguments for both 'left' and 'right'. Only "
|
||||
f"received: '{'left' if left else 'right'}'."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
return partial(func, **kwargs)
|
||||
|
||||
_add_messages.__doc__ = func.__doc__
|
||||
return cast(Callable[[Messages, Messages], Messages], _add_messages)
|
||||
|
||||
|
||||
@_add_messages_wrapper
|
||||
def add_messages(
|
||||
left: Messages,
|
||||
right: Messages,
|
||||
*,
|
||||
format: Literal["langchain-openai"] | None = None,
|
||||
) -> Messages:
|
||||
"""Merges two lists of messages, updating existing messages by ID.
|
||||
|
||||
By default, this ensures the state is "append-only", unless the
|
||||
new message has the same ID as an existing message.
|
||||
|
||||
Args:
|
||||
left: The base list of `Messages`.
|
||||
right: The list of `Messages` (or single `Message`) to merge
|
||||
into the base list.
|
||||
format: The format to return messages in. If `None` then `Messages` will be
|
||||
returned as is. If `langchain-openai` then `Messages` will be returned as
|
||||
`BaseMessage` objects with their contents formatted to match OpenAI message
|
||||
format, meaning contents can be string, `'text'` blocks, or `'image_url'` blocks
|
||||
and tool responses are returned as their own `ToolMessage` objects.
|
||||
|
||||
!!! important "Requirement"
|
||||
|
||||
Must have `langchain-core>=0.3.11` installed to use this feature.
|
||||
|
||||
Returns:
|
||||
A new list of messages with the messages from `right` merged into `left`.
|
||||
If a message in `right` has the same ID as a message in `left`, the
|
||||
message from `right` will replace the message from `left`.
|
||||
|
||||
Example: Basic usage
|
||||
```python
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
msgs1 = [HumanMessage(content="Hello", id="1")]
|
||||
msgs2 = [AIMessage(content="Hi there!", id="2")]
|
||||
add_messages(msgs1, msgs2)
|
||||
# [HumanMessage(content='Hello', id='1'), AIMessage(content='Hi there!', id='2')]
|
||||
```
|
||||
|
||||
Example: Overwrite existing message
|
||||
```python
|
||||
msgs1 = [HumanMessage(content="Hello", id="1")]
|
||||
msgs2 = [HumanMessage(content="Hello again", id="1")]
|
||||
add_messages(msgs1, msgs2)
|
||||
# [HumanMessage(content='Hello again', id='1')]
|
||||
```
|
||||
|
||||
Example: Use in a StateGraph
|
||||
```python
|
||||
from typing import Annotated
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
|
||||
class State(TypedDict):
|
||||
messages: Annotated[list, add_messages]
|
||||
|
||||
|
||||
builder = StateGraph(State)
|
||||
builder.add_node("chatbot", lambda state: {"messages": [("assistant", "Hello")]})
|
||||
builder.set_entry_point("chatbot")
|
||||
builder.set_finish_point("chatbot")
|
||||
graph = builder.compile()
|
||||
graph.invoke({})
|
||||
# {'messages': [AIMessage(content='Hello', id=...)]}
|
||||
```
|
||||
|
||||
Example: Use OpenAI message format
|
||||
```python
|
||||
from typing import Annotated
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.graph import StateGraph, add_messages
|
||||
|
||||
|
||||
class State(TypedDict):
|
||||
messages: Annotated[list, add_messages(format="langchain-openai")]
|
||||
|
||||
|
||||
def chatbot_node(state: State) -> list:
|
||||
return {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Here's an image:",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": "1234",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
builder = StateGraph(State)
|
||||
builder.add_node("chatbot", chatbot_node)
|
||||
builder.set_entry_point("chatbot")
|
||||
builder.set_finish_point("chatbot")
|
||||
graph = builder.compile()
|
||||
graph.invoke({"messages": []})
|
||||
# {
|
||||
# 'messages': [
|
||||
# HumanMessage(
|
||||
# content=[
|
||||
# {"type": "text", "text": "Here's an image:"},
|
||||
# {
|
||||
# "type": "image_url",
|
||||
# "image_url": {"url": "data:image/jpeg;base64,1234"},
|
||||
# },
|
||||
# ],
|
||||
# ),
|
||||
# ]
|
||||
# }
|
||||
```
|
||||
|
||||
"""
|
||||
remove_all_idx = None
|
||||
# coerce to list
|
||||
if not isinstance(left, list):
|
||||
left = [left] # type: ignore[assignment]
|
||||
if not isinstance(right, list):
|
||||
right = [right] # type: ignore[assignment]
|
||||
# coerce to message
|
||||
left = [
|
||||
message_chunk_to_message(cast(BaseMessageChunk, m))
|
||||
for m in convert_to_messages(left)
|
||||
]
|
||||
right = [
|
||||
message_chunk_to_message(cast(BaseMessageChunk, m))
|
||||
for m in convert_to_messages(right)
|
||||
]
|
||||
# assign missing ids
|
||||
for m in left:
|
||||
if m.id is None:
|
||||
m.id = str(uuid.uuid4())
|
||||
for idx, m in enumerate(right):
|
||||
if m.id is None:
|
||||
m.id = str(uuid.uuid4())
|
||||
if isinstance(m, RemoveMessage) and m.id == REMOVE_ALL_MESSAGES:
|
||||
remove_all_idx = idx
|
||||
|
||||
if remove_all_idx is not None:
|
||||
return right[remove_all_idx + 1 :]
|
||||
|
||||
# merge
|
||||
merged = left.copy()
|
||||
merged_by_id = {m.id: i for i, m in enumerate(merged)}
|
||||
ids_to_remove = set()
|
||||
for m in right:
|
||||
if (existing_idx := merged_by_id.get(m.id)) is not None:
|
||||
if isinstance(m, RemoveMessage):
|
||||
ids_to_remove.add(m.id)
|
||||
else:
|
||||
ids_to_remove.discard(m.id)
|
||||
merged[existing_idx] = m
|
||||
else:
|
||||
if isinstance(m, RemoveMessage):
|
||||
raise ValueError(
|
||||
f"Attempting to delete a message with an ID that doesn't exist ('{m.id}')"
|
||||
)
|
||||
|
||||
merged_by_id[m.id] = len(merged)
|
||||
merged.append(m)
|
||||
merged = [m for m in merged if m.id not in ids_to_remove]
|
||||
|
||||
if format == "langchain-openai":
|
||||
merged = _format_messages(merged)
|
||||
elif format:
|
||||
msg = f"Unrecognized {format=}. Expected one of 'langchain-openai', None."
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
pass
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
@deprecated(
|
||||
"MessageGraph is deprecated in langgraph 1.0.0, to be removed in 2.0.0. Please use StateGraph with a `messages` key instead.",
|
||||
category=None,
|
||||
)
|
||||
class MessageGraph(StateGraph):
|
||||
"""A StateGraph where every node receives a list of messages as input and returns one or more messages as output.
|
||||
|
||||
MessageGraph is a subclass of StateGraph whose entire state is a single, append-only* list of messages.
|
||||
Each node in a MessageGraph takes a list of messages as input and returns zero or more
|
||||
messages as output. The `add_messages` function is used to merge the output messages from each node
|
||||
into the existing list of messages in the graph's state.
|
||||
|
||||
Examples:
|
||||
```pycon
|
||||
>>> from langgraph.graph.message import MessageGraph
|
||||
...
|
||||
>>> builder = MessageGraph()
|
||||
>>> builder.add_node("chatbot", lambda state: [("assistant", "Hello!")])
|
||||
>>> builder.set_entry_point("chatbot")
|
||||
>>> builder.set_finish_point("chatbot")
|
||||
>>> builder.compile().invoke([("user", "Hi there.")])
|
||||
[HumanMessage(content="Hi there.", id='...'), AIMessage(content="Hello!", id='...')]
|
||||
```
|
||||
|
||||
```pycon
|
||||
>>> from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
>>> from langgraph.graph.message import MessageGraph
|
||||
...
|
||||
>>> builder = MessageGraph()
|
||||
>>> builder.add_node(
|
||||
... "chatbot",
|
||||
... lambda state: [
|
||||
... AIMessage(
|
||||
... content="Hello!",
|
||||
... tool_calls=[{"name": "search", "id": "123", "args": {"query": "X"}}],
|
||||
... )
|
||||
... ],
|
||||
... )
|
||||
>>> builder.add_node(
|
||||
... "search", lambda state: [ToolMessage(content="Searching...", tool_call_id="123")]
|
||||
... )
|
||||
>>> builder.set_entry_point("chatbot")
|
||||
>>> builder.add_edge("chatbot", "search")
|
||||
>>> builder.set_finish_point("search")
|
||||
>>> builder.compile().invoke([HumanMessage(content="Hi there. Can you search for X?")])
|
||||
{'messages': [HumanMessage(content="Hi there. Can you search for X?", id='b8b7d8f4-7f4d-4f4d-9c1d-f8b8d8f4d9c1'),
|
||||
AIMessage(content="Hello!", id='f4d9c1d8-8d8f-4d9c-b8b7-d8f4f4d9c1d8'),
|
||||
ToolMessage(content="Searching...", id='d8f4f4d9-c1d8-4f4d-b8b7-d8f4f4d9c1d8', tool_call_id="123")]}
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
warnings.warn(
|
||||
"MessageGraph is deprecated in LangGraph v1.0.0, to be removed in v2.0.0. Please use StateGraph with a `messages` key instead.",
|
||||
category=LangGraphDeprecatedSinceV10,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__(Annotated[list[AnyMessage], add_messages]) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class MessagesState(TypedDict):
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
|
||||
|
||||
def _format_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
|
||||
try:
|
||||
from langchain_core.messages import convert_to_openai_messages
|
||||
except ImportError:
|
||||
msg = (
|
||||
"Must have langchain-core>=0.3.11 installed to use automatic message "
|
||||
"formatting (format='langchain-openai'). Please update your langchain-core "
|
||||
"version or remove the 'format' flag. Returning un-formatted "
|
||||
"messages."
|
||||
)
|
||||
warnings.warn(msg)
|
||||
return list(messages)
|
||||
else:
|
||||
return convert_to_messages(convert_to_openai_messages(messages))
|
||||
|
||||
|
||||
def push_message(
|
||||
message: MessageLikeRepresentation | BaseMessageChunk,
|
||||
*,
|
||||
state_key: str | None = "messages",
|
||||
) -> AnyMessage:
|
||||
"""Write a message manually to the `messages` / `messages-tuple` stream mode.
|
||||
|
||||
Will automatically write to the channel specified in the `state_key` unless `state_key` is `None`.
|
||||
"""
|
||||
|
||||
from langchain_core.callbacks.base import (
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
)
|
||||
|
||||
from langgraph.config import get_config
|
||||
from langgraph.pregel._messages import StreamMessagesHandler
|
||||
|
||||
config = get_config()
|
||||
message = next(x for x in convert_to_messages([message]))
|
||||
|
||||
if message.id is None:
|
||||
raise ValueError("Message ID is required")
|
||||
|
||||
if isinstance(config["callbacks"], BaseCallbackManager):
|
||||
manager = config["callbacks"]
|
||||
handlers = manager.handlers
|
||||
elif isinstance(config["callbacks"], list) and all(
|
||||
isinstance(x, BaseCallbackHandler) for x in config["callbacks"]
|
||||
):
|
||||
handlers = config["callbacks"]
|
||||
|
||||
if stream_handler := next(
|
||||
(x for x in handlers if isinstance(x, StreamMessagesHandler)), None
|
||||
):
|
||||
metadata = config["metadata"]
|
||||
message_meta = (
|
||||
tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP)),
|
||||
metadata,
|
||||
)
|
||||
stream_handler._emit(message_meta, message, dedupe=False)
|
||||
|
||||
if state_key:
|
||||
config[CONF][CONFIG_KEY_SEND]([(state_key, message)])
|
||||
|
||||
return message
|
||||
1731
venv/Lib/site-packages/langgraph/graph/state.py
Normal file
1731
venv/Lib/site-packages/langgraph/graph/state.py
Normal file
File diff suppressed because it is too large
Load Diff
227
venv/Lib/site-packages/langgraph/graph/ui.py
Normal file
227
venv/Lib/site-packages/langgraph/graph/ui.py
Normal file
@@ -0,0 +1,227 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langgraph.config import get_config, get_stream_writer
|
||||
from langgraph.constants import CONF
|
||||
|
||||
__all__ = (
|
||||
"UIMessage",
|
||||
"RemoveUIMessage",
|
||||
"AnyUIMessage",
|
||||
"push_ui_message",
|
||||
"delete_ui_message",
|
||||
"ui_message_reducer",
|
||||
)
|
||||
|
||||
|
||||
class UIMessage(TypedDict):
|
||||
"""A message type for UI updates in LangGraph.
|
||||
|
||||
This TypedDict represents a UI message that can be sent to update the UI state.
|
||||
It contains information about the UI component to render and its properties.
|
||||
|
||||
Attributes:
|
||||
type: Literal type indicating this is a UI message.
|
||||
id: Unique identifier for the UI message.
|
||||
name: Name of the UI component to render.
|
||||
props: Properties to pass to the UI component.
|
||||
metadata: Additional metadata about the UI message.
|
||||
"""
|
||||
|
||||
type: Literal["ui"]
|
||||
id: str
|
||||
name: str
|
||||
props: dict[str, Any]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
class RemoveUIMessage(TypedDict):
|
||||
"""A message type for removing UI components in LangGraph.
|
||||
|
||||
This TypedDict represents a message that can be sent to remove a UI component
|
||||
from the current state.
|
||||
|
||||
Attributes:
|
||||
type: Literal type indicating this is a remove-ui message.
|
||||
id: Unique identifier of the UI message to remove.
|
||||
"""
|
||||
|
||||
type: Literal["remove-ui"]
|
||||
id: str
|
||||
|
||||
|
||||
AnyUIMessage = UIMessage | RemoveUIMessage
|
||||
|
||||
|
||||
def push_ui_message(
|
||||
name: str,
|
||||
props: dict[str, Any],
|
||||
*,
|
||||
id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
message: AnyMessage | None = None,
|
||||
state_key: str | None = "ui",
|
||||
merge: bool = False,
|
||||
) -> UIMessage:
|
||||
"""Push a new UI message to update the UI state.
|
||||
|
||||
This function creates and sends a UI message that will be rendered in the UI.
|
||||
It also updates the graph state with the new UI message.
|
||||
|
||||
Args:
|
||||
name: Name of the UI component to render.
|
||||
props: Properties to pass to the UI component.
|
||||
id: Optional unique identifier for the UI message.
|
||||
If not provided, a random UUID will be generated.
|
||||
metadata: Optional additional metadata about the UI message.
|
||||
message: Optional message object to associate with the UI message.
|
||||
state_key: Key in the graph state where the UI messages are stored.
|
||||
merge: Whether to merge props with existing UI message (True) or replace
|
||||
them (False).
|
||||
|
||||
Returns:
|
||||
The created UI message.
|
||||
|
||||
Example:
|
||||
```python
|
||||
push_ui_message(
|
||||
name="component-name",
|
||||
props={"content": "Hello world"},
|
||||
)
|
||||
```
|
||||
|
||||
"""
|
||||
from langgraph._internal._constants import CONFIG_KEY_SEND
|
||||
|
||||
writer = get_stream_writer()
|
||||
config = get_config()
|
||||
|
||||
message_id = None
|
||||
if message:
|
||||
if isinstance(message, dict) and "id" in message:
|
||||
message_id = message.get("id")
|
||||
elif hasattr(message, "id"):
|
||||
message_id = message.id
|
||||
|
||||
evt: UIMessage = {
|
||||
"type": "ui",
|
||||
"id": id or str(uuid4()),
|
||||
"name": name,
|
||||
"props": props,
|
||||
"metadata": {
|
||||
"merge": merge,
|
||||
"run_id": config.get("run_id", None),
|
||||
"tags": config.get("tags", None),
|
||||
"name": config.get("run_name", None),
|
||||
**(metadata or {}),
|
||||
**({"message_id": message_id} if message_id else {}),
|
||||
},
|
||||
}
|
||||
|
||||
writer(evt)
|
||||
if state_key:
|
||||
config[CONF][CONFIG_KEY_SEND]([(state_key, evt)])
|
||||
|
||||
return evt
|
||||
|
||||
|
||||
def delete_ui_message(id: str, *, state_key: str = "ui") -> RemoveUIMessage:
|
||||
"""Delete a UI message by ID from the UI state.
|
||||
|
||||
This function creates and sends a message to remove a UI component from the current state.
|
||||
It also updates the graph state to remove the UI message.
|
||||
|
||||
Args:
|
||||
id: Unique identifier of the UI component to remove.
|
||||
state_key: Key in the graph state where the UI messages are stored. Defaults to "ui".
|
||||
|
||||
Returns:
|
||||
The remove UI message.
|
||||
|
||||
Example:
|
||||
```python
|
||||
delete_ui_message("message-123")
|
||||
```
|
||||
|
||||
"""
|
||||
from langgraph._internal._constants import CONFIG_KEY_SEND
|
||||
|
||||
writer = get_stream_writer()
|
||||
config = get_config()
|
||||
|
||||
evt: RemoveUIMessage = {"type": "remove-ui", "id": id}
|
||||
|
||||
writer(evt)
|
||||
config[CONF][CONFIG_KEY_SEND]([(state_key, evt)])
|
||||
|
||||
return evt
|
||||
|
||||
|
||||
def ui_message_reducer(
|
||||
left: list[AnyUIMessage] | AnyUIMessage,
|
||||
right: list[AnyUIMessage] | AnyUIMessage,
|
||||
) -> list[AnyUIMessage]:
|
||||
"""Merge two lists of UI messages, supporting removing UI messages.
|
||||
|
||||
This function combines two lists of UI messages, handling both regular UI messages
|
||||
and `remove-ui` messages. When a `remove-ui` message is encountered, it removes any
|
||||
UI message with the matching ID from the current state.
|
||||
|
||||
Args:
|
||||
left: First list of UI messages or single UI message.
|
||||
right: Second list of UI messages or single UI message.
|
||||
|
||||
Returns:
|
||||
Combined list of UI messages with removals applied.
|
||||
|
||||
Example:
|
||||
```python
|
||||
messages = ui_message_reducer(
|
||||
[{"type": "ui", "id": "1", "name": "Chat", "props": {}}],
|
||||
{"type": "remove-ui", "id": "1"},
|
||||
)
|
||||
```
|
||||
|
||||
"""
|
||||
if not isinstance(left, list):
|
||||
left = [left]
|
||||
|
||||
if not isinstance(right, list):
|
||||
right = [right]
|
||||
|
||||
# merge messages
|
||||
merged = left.copy()
|
||||
merged_by_id = {m.get("id"): i for i, m in enumerate(merged)}
|
||||
ids_to_remove = set()
|
||||
|
||||
for msg in right:
|
||||
msg_id = msg.get("id")
|
||||
|
||||
if (existing_idx := merged_by_id.get(msg_id)) is not None:
|
||||
if msg.get("type") == "remove-ui":
|
||||
ids_to_remove.add(msg_id)
|
||||
else:
|
||||
ids_to_remove.discard(msg_id)
|
||||
|
||||
if cast(UIMessage, msg).get("metadata", {}).get("merge", False):
|
||||
prev_msg = merged[existing_idx]
|
||||
msg = msg.copy()
|
||||
msg["props"] = {**prev_msg["props"], **msg["props"]}
|
||||
|
||||
merged[existing_idx] = msg
|
||||
else:
|
||||
if msg.get("type") == "remove-ui":
|
||||
raise ValueError(
|
||||
f"Attempting to delete an UI message with an ID that doesn't exist ('{msg_id}')"
|
||||
)
|
||||
|
||||
merged_by_id[msg_id] = len(merged)
|
||||
merged.append(msg)
|
||||
|
||||
merged = [m for m in merged if m.get("id") not in ids_to_remove]
|
||||
return merged
|
||||
3
venv/Lib/site-packages/langgraph/managed/__init__.py
Normal file
3
venv/Lib/site-packages/langgraph/managed/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from langgraph.managed.is_last_step import IsLastStep, RemainingSteps
|
||||
|
||||
__all__ = ("IsLastStep", "RemainingSteps")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user