initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
"""Internal modules for LangGraph.
This module is not part of the public API, and thus stability is not guaranteed.
"""

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from collections.abc import Hashable, Mapping, Sequence
from typing import Any
def _freeze(obj: Any, depth: int = 10) -> Hashable:
if isinstance(obj, Hashable) or depth <= 0:
# already hashable, no need to freeze
return obj
elif isinstance(obj, Mapping):
# sort keys so {"a":1,"b":2} == {"b":2,"a":1}
return tuple(sorted((k, _freeze(v, depth - 1)) for k, v in obj.items()))
elif isinstance(obj, Sequence):
return tuple(_freeze(x, depth - 1) for x in obj)
# numpy / pandas etc. can provide their own .tobytes()
elif hasattr(obj, "tobytes"):
return (
type(obj).__name__,
obj.tobytes(),
obj.shape if hasattr(obj, "shape") else None,
)
return obj # strings, ints, dataclasses with frozen=True, etc.
def default_cache_key(*args: Any, **kwargs: Any) -> str | bytes:
    """Build the default cache key for a call from its arguments.

    The positional and keyword arguments are each passed through `_freeze`
    and the resulting pair is pickled.

    Args:
        *args: Positional arguments of the call being cached.
        **kwargs: Keyword arguments of the call being cached.

    Returns:
        The pickled bytes of `(_freeze(args), _freeze(kwargs))`.
    """
    import pickle

    frozen_call = (_freeze(args), _freeze(kwargs))
    # protocol 5 strikes a good balance between speed and size
    return pickle.dumps(frozen_call, protocol=5, fix_imports=False)

View File

@@ -0,0 +1,329 @@
from __future__ import annotations
from collections import ChainMap
from collections.abc import Mapping, Sequence
from os import getenv
from typing import Any, cast
from langchain_core.callbacks import (
AsyncCallbackManager,
BaseCallbackManager,
CallbackManager,
Callbacks,
)
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import (
CONFIG_KEYS,
COPIABLE_KEYS,
var_child_runnable_config,
)
from langgraph.checkpoint.base import CheckpointMetadata
from langgraph._internal._constants import (
CONF,
CONFIG_KEY_CHECKPOINT_ID,
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_CHECKPOINT_NS,
NS_END,
NS_SEP,
)
# Default `recursion_limit` used by `ensure_config` and treated as "unset" by
# `merge_configs`. Read once at import time from the
# LANGGRAPH_DEFAULT_RECURSION_LIMIT environment variable (falls back to 10000).
DEFAULT_RECURSION_LIMIT = int(getenv("LANGGRAPH_DEFAULT_RECURSION_LIMIT", "10000"))
def recast_checkpoint_ns(ns: str) -> str:
    """Strip task IDs out of a checkpoint namespace.

    The namespace is split on `NS_SEP`; segments made only of digits are
    dropped, and every remaining segment is cut at its first `NS_END`.

    Args:
        ns: The checkpoint namespace with task IDs.

    Returns:
        str: The checkpoint namespace without task IDs.
    """
    kept: list[str] = []
    for segment in ns.split(NS_SEP):
        if segment.isdigit():
            # purely numeric segments are skipped entirely
            continue
        kept.append(segment.split(NS_END)[0])
    return NS_SEP.join(kept)
def patch_configurable(
    config: RunnableConfig | None, patch: dict[str, Any]
) -> RunnableConfig:
    """Return a new config whose `configurable` section includes `patch`.

    Args:
        config: The config to start from; it is not modified.
        patch: Keys to set in the `configurable` section. Values in `patch`
            win over existing ones.

    Returns:
        RunnableConfig: A new config dict. When there was no existing
        `configurable` section, `patch` itself becomes that section.
    """
    if config is None:
        return {CONF: patch}
    merged = {**config[CONF], **patch} if CONF in config else patch
    return {**config, CONF: merged}
def patch_checkpoint_map(
    config: RunnableConfig | None, metadata: CheckpointMetadata | None
) -> RunnableConfig:
    """Add a checkpoint map to `config` when the metadata lists parents.

    Args:
        config: The config to patch. `None` is passed straight through.
        metadata: Checkpoint metadata; only its `"parents"` entry is read.

    Returns:
        RunnableConfig: `config` unchanged when it is `None` or when there
        are no parents. Otherwise a patched copy whose checkpoint map holds
        the parents plus the config's own checkpoint namespace -> checkpoint
        ID entry.
    """
    if config is None:
        return config
    parents = metadata.get("parents") if metadata else None
    if not parents:
        return config
    conf = config[CONF]
    checkpoint_map = {
        **parents,
        conf[CONFIG_KEY_CHECKPOINT_NS]: conf[CONFIG_KEY_CHECKPOINT_ID],
    }
    return patch_configurable(config, {CONFIG_KEY_CHECKPOINT_MAP: checkpoint_map})
def merge_configs(*configs: RunnableConfig | None) -> RunnableConfig:
    """Merge multiple configs into one.

    Configs are applied left to right. Falsy values are skipped entirely.
    `metadata` and the configurable section are merged key by key (later
    wins), `tags` are concatenated, `callbacks` are combined, and
    `recursion_limit` is only taken when it differs from
    `DEFAULT_RECURSION_LIMIT`. Every other key is simply overwritten by the
    later config.

    Args:
        *configs: The configs to merge. `None` entries are ignored.

    Returns:
        RunnableConfig: The merged config. It always contains a configurable
        section, which is an empty dict if no input provided one.

    Raises:
        NotImplementedError: If a `callbacks` value is neither a list nor a
            `BaseCallbackManager`.
    """
    base: RunnableConfig = {}
    # Even though the keys aren't literals, this is correct
    # because both dicts are the same type
    for config in configs:
        if config is None:
            continue
        for key, value in config.items():
            # falsy values (None, empty containers, 0) never override anything
            if not value:
                continue
            if key == "metadata":
                if base_value := base.get(key):
                    base[key] = {**base_value, **value}  # type: ignore
                else:
                    base[key] = value  # type: ignore[literal-required]
            elif key == "tags":
                if base_value := base.get(key):
                    base[key] = [*base_value, *value]  # type: ignore
                else:
                    base[key] = value  # type: ignore[literal-required]
            elif key == CONF:
                if base_value := base.get(key):
                    base[key] = {**base_value, **value}  # type: ignore[dict-item]
                else:
                    base[key] = value
            elif key == "callbacks":
                base_callbacks = base.get("callbacks")
                # callbacks can be either None, list[handler] or manager
                # so merging two callbacks values has 6 cases
                if isinstance(value, list):
                    if base_callbacks is None:
                        base["callbacks"] = value.copy()
                    elif isinstance(base_callbacks, list):
                        base["callbacks"] = base_callbacks + value
                    else:
                        # base_callbacks is a manager
                        mngr = base_callbacks.copy()
                        for callback in value:
                            mngr.add_handler(callback, inherit=True)
                        base["callbacks"] = mngr
                elif isinstance(value, BaseCallbackManager):
                    # value is a manager
                    if base_callbacks is None:
                        base["callbacks"] = value.copy()
                    elif isinstance(base_callbacks, list):
                        # copy the incoming manager and add the handlers collected so far
                        mngr = value.copy()
                        for callback in base_callbacks:
                            mngr.add_handler(callback, inherit=True)
                        base["callbacks"] = mngr
                    else:
                        # base_callbacks is also a manager
                        base["callbacks"] = base_callbacks.merge(value)
                else:
                    raise NotImplementedError
            elif key == "recursion_limit":
                # a limit equal to the default is treated as "not set" and
                # does not override an earlier explicit limit
                if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT:
                    base["recursion_limit"] = config["recursion_limit"]
            else:
                base[key] = config[key]  # type: ignore[literal-required]
    # guarantee the configurable section exists
    if CONF not in base:
        base[CONF] = {}
    return base
def patch_config(
    config: RunnableConfig | None,
    *,
    callbacks: Callbacks = None,
    recursion_limit: int | None = None,
    max_concurrency: int | None = None,
    run_name: str | None = None,
    configurable: dict[str, Any] | None = None,
) -> RunnableConfig:
    """Return a shallow copy of `config` with the given values applied.

    Only arguments that are not `None` are applied.

    Args:
        config: The config to patch; it is copied, not modified.
        callbacks: The callbacks to set. Setting them also removes any
            existing `run_name` and `run_id` from the copy.
        recursion_limit: The recursion limit to set.
        max_concurrency: The max number of concurrent steps to run, which
            also applies to parallelized steps.
        run_name: The run name to set.
        configurable: Entries merged over the existing configurable section.

    Returns:
        RunnableConfig: The patched config.
    """
    patched = {} if config is None else config.copy()

    if callbacks is not None:
        # Replacing the callbacks starts a different run, so the run name
        # and run id tied to the original callbacks must not carry over.
        patched["callbacks"] = callbacks
        patched.pop("run_name", None)
        patched.pop("run_id", None)

    simple_overrides = (
        ("recursion_limit", recursion_limit),
        ("max_concurrency", max_concurrency),
        ("run_name", run_name),
    )
    for field, override in simple_overrides:
        if override is not None:
            patched[field] = override

    if configurable is not None:
        patched[CONF] = {**patched.get(CONF, {}), **configurable}
    return patched
def get_callback_manager_for_config(
    config: RunnableConfig, tags: Sequence[str] | None = None
) -> CallbackManager:
    """Get a callback manager for a config.

    Args:
        config: The config.
        tags: Extra tags appended after the config's own tags.

    Returns:
        CallbackManager: The manager already stored in `config["callbacks"]`
        (with the tags and metadata added to it in place) when there is one,
        otherwise a newly configured manager.
    """
    from langchain_core.callbacks.manager import CallbackManager

    # merge tags
    all_tags = config.get("tags")
    if tags is not None:
        all_tags = list(tags) if all_tags is None else [*all_tags, *tags]

    callbacks = config.get("callbacks")
    if callbacks and isinstance(callbacks, CallbackManager):
        # reuse the existing manager
        if all_tags:
            callbacks.add_tags(all_tags)
        metadata = config.get("metadata")
        if metadata:
            callbacks.add_metadata(metadata)
        return callbacks

    # no usable manager in the config: build one
    return CallbackManager.configure(
        inheritable_callbacks=callbacks,
        inheritable_tags=all_tags,
        inheritable_metadata=config.get("metadata"),
    )
def get_async_callback_manager_for_config(
    config: RunnableConfig,
    tags: Sequence[str] | None = None,
) -> AsyncCallbackManager:
    """Get an async callback manager for a config.

    Args:
        config: The config.
        tags: Extra tags appended after the config's own tags.

    Returns:
        AsyncCallbackManager: The async manager already stored in
        `config["callbacks"]` (with the tags and metadata added to it in
        place) when there is one, otherwise a newly configured manager.
    """
    from langchain_core.callbacks.manager import AsyncCallbackManager

    # merge tags
    all_tags = config.get("tags")
    if tags is not None:
        all_tags = list(tags) if all_tags is None else [*all_tags, *tags]

    callbacks = config.get("callbacks")
    if callbacks and isinstance(callbacks, AsyncCallbackManager):
        # reuse the existing manager
        if all_tags:
            callbacks.add_tags(all_tags)
        metadata = config.get("metadata")
        if metadata:
            callbacks.add_metadata(metadata)
        return callbacks

    # no usable manager in the config: build one
    return AsyncCallbackManager.configure(
        inheritable_callbacks=callbacks,
        inheritable_tags=all_tags,
        inheritable_metadata=config.get("metadata"),
    )
def _is_not_empty(value: Any) -> bool:
if isinstance(value, (list, tuple, dict)):
return len(value) > 0
else:
return value is not None
def ensure_config(*configs: RunnableConfig | None) -> RunnableConfig:
    """Return a config with all keys, merging any provided configs.

    Starts from defaults, overlays the non-empty entries of the config held
    in the `var_child_runnable_config` context variable (if any), then
    overlays each provided config in order. Finally, eligible entries of the
    configurable section are copied into the metadata.

    Args:
        *configs: Configs to merge before ensuring defaults. `None` entries
            are ignored.

    Returns:
        RunnableConfig: The merged and ensured config.
    """
    empty = RunnableConfig(
        tags=[],
        metadata=ChainMap(),
        callbacks=None,
        recursion_limit=DEFAULT_RECURSION_LIMIT,
        configurable={},
    )
    if var_config := var_child_runnable_config.get():
        empty.update(
            {
                # copy the values listed in COPIABLE_KEYS so later writes
                # don't touch the objects held by the context variable
                k: v.copy() if k in COPIABLE_KEYS else v  # type: ignore[attr-defined]
                for k, v in var_config.items()
                if _is_not_empty(v)
            },
        )
    for config in configs:
        if config is None:
            continue
        # first pass: recognised config keys overwrite the current value
        for k, v in config.items():
            if _is_not_empty(v) and k in CONFIG_KEYS:
                if k == CONF:
                    # replaces (does not merge) the configurable section with a copy
                    empty[k] = cast(dict, v).copy()
                else:
                    empty[k] = v  # type: ignore[literal-required]
        # second pass: unrecognised top-level keys are stored inside the
        # configurable section
        for k, v in config.items():
            if _is_not_empty(v) and k not in CONFIG_KEYS:
                empty[CONF][k] = v
    _empty_metadata = empty["metadata"]
    # copy configurable entries into metadata unless `_exclude_as_metadata`
    # rejects them (existing metadata keys are never overwritten)
    for key, value in empty[CONF].items():
        if _exclude_as_metadata(key, value, _empty_metadata):
            continue
        _empty_metadata[key] = value
    return empty
_OMIT = ("key", "token", "secret", "password", "auth")
def _exclude_as_metadata(key: str, value: Any, metadata: Mapping[str, Any]) -> bool:
key_lower = key.casefold()
return (
key.startswith("__")
or not isinstance(value, (str, int, float, bool))
or key in metadata
or any(substr in key_lower for substr in _OMIT)
)

View File

@@ -0,0 +1,112 @@
"""Constants used for Pregel operations."""
import sys
from typing import Literal, cast
# NOTE: throughout this module, the comment describing a constant is placed on
# the line(s) directly BELOW that constant. All strings are interned.

# --- Reserved write keys ---
INPUT = sys.intern("__input__")
# for values passed as input to the graph
INTERRUPT = sys.intern("__interrupt__")
# for dynamic interrupts raised by nodes
RESUME = sys.intern("__resume__")
# for values passed to resume a node after an interrupt
ERROR = sys.intern("__error__")
# for errors raised by nodes
NO_WRITES = sys.intern("__no_writes__")
# marker to signal node didn't write anything
TASKS = sys.intern("__pregel_tasks")
# for Send objects returned by nodes/edges, corresponds to PUSH below
RETURN = sys.intern("__return__")
# for writes of a task where we simply record the return value
PREVIOUS = sys.intern("__previous__")
# the implicit branch that handles each node's Control values
# NOTE(review): the comment above does not appear to describe `__previous__`
# and may be stale; confirm the actual purpose of this key.

# --- Reserved cache namespaces ---
CACHE_NS_WRITES = sys.intern("__pregel_ns_writes")
# cache namespace for node writes

# --- Reserved config.configurable keys ---
CONFIG_KEY_SEND = sys.intern("__pregel_send")
# holds the `write` function that accepts writes to state/edges/reserved keys
CONFIG_KEY_READ = sys.intern("__pregel_read")
# holds the `read` function that returns a copy of the current state
CONFIG_KEY_CALL = sys.intern("__pregel_call")
# holds the `call` function that accepts a node/func, args and returns a future
CONFIG_KEY_CHECKPOINTER = sys.intern("__pregel_checkpointer")
# holds a `BaseCheckpointSaver` passed from parent graph to child graphs
CONFIG_KEY_STREAM = sys.intern("__pregel_stream")
# holds a `StreamProtocol` passed from parent graph to child graphs
CONFIG_KEY_CACHE = sys.intern("__pregel_cache")
# holds a `BaseCache` made available to subgraphs
CONFIG_KEY_RESUMING = sys.intern("__pregel_resuming")
# holds a boolean indicating if subgraphs should resume from a previous checkpoint
CONFIG_KEY_TASK_ID = sys.intern("__pregel_task_id")
# holds the task ID for the current task
CONFIG_KEY_THREAD_ID = sys.intern("thread_id")
# holds the thread ID for the current invocation
CONFIG_KEY_CHECKPOINT_MAP = sys.intern("checkpoint_map")
# holds a mapping of checkpoint_ns -> checkpoint_id for parent graphs
CONFIG_KEY_CHECKPOINT_ID = sys.intern("checkpoint_id")
# holds the current checkpoint_id, if any
CONFIG_KEY_CHECKPOINT_NS = sys.intern("checkpoint_ns")
# holds the current checkpoint_ns, "" for root graph
CONFIG_KEY_NODE_FINISHED = sys.intern("__pregel_node_finished")
# holds a callback to be called when a node is finished
CONFIG_KEY_SCRATCHPAD = sys.intern("__pregel_scratchpad")
# holds a mutable dict for temporary storage scoped to the current task
CONFIG_KEY_RUNNER_SUBMIT = sys.intern("__pregel_runner_submit")
# holds a function that receives tasks from runner, executes them and returns results
CONFIG_KEY_DURABILITY = sys.intern("__pregel_durability")
# holds the durability mode, one of "sync", "async", or "exit"
CONFIG_KEY_RUNTIME = sys.intern("__pregel_runtime")
# holds a `Runtime` instance with context, store, stream writer, etc.
CONFIG_KEY_RESUME_MAP = sys.intern("__pregel_resume_map")
# holds a mapping of task ns -> resume value for resuming tasks

# --- Other constants ---
PUSH = sys.intern("__pregel_push")
# denotes push-style tasks, ie. those created by Send objects
PULL = sys.intern("__pregel_pull")
# denotes pull-style tasks, ie. those triggered by edges
NS_SEP = sys.intern("|")
# for checkpoint_ns, separates each level (ie. graph|subgraph|subsubgraph)
NS_END = sys.intern(":")
# for checkpoint_ns, for each level, separates the namespace from the task_id
CONF = cast(Literal["configurable"], sys.intern("configurable"))
# key for the configurable dict in RunnableConfig
NULL_TASK_ID = sys.intern("00000000-0000-0000-0000-000000000000")
# the task_id to use for writes that are not associated with a task
OVERWRITE = sys.intern("__overwrite__")
# dict key for the overwrite value, used as `{'__overwrite__': value}`

# redefined to avoid circular import with langgraph.constants
_TAG_HIDDEN = sys.intern("langsmith:hidden")
# Set of reserved strings: the hidden tag, reserved write keys, reserved
# config.configurable keys and the other marker constants listed below.
# NOTE(review): CONFIG_KEY_CHECKPOINT_MAP used to be listed twice in this
# literal. The duplicate had no effect on the resulting set and was removed;
# confirm that no other key was meant to take its place.
RESERVED = {
    _TAG_HIDDEN,
    # reserved write keys
    INPUT,
    INTERRUPT,
    RESUME,
    ERROR,
    NO_WRITES,
    # reserved config.configurable keys
    CONFIG_KEY_SEND,
    CONFIG_KEY_READ,
    CONFIG_KEY_CHECKPOINTER,
    CONFIG_KEY_STREAM,
    CONFIG_KEY_CHECKPOINT_MAP,
    CONFIG_KEY_RESUMING,
    CONFIG_KEY_TASK_ID,
    CONFIG_KEY_CHECKPOINT_ID,
    CONFIG_KEY_CHECKPOINT_NS,
    CONFIG_KEY_RESUME_MAP,
    # other constants
    PUSH,
    PULL,
    NS_SEP,
    NS_END,
    CONF,
}

View File

@@ -0,0 +1,213 @@
from __future__ import annotations
import dataclasses
import types
import weakref
from collections.abc import Generator, Sequence
from typing import Annotated, Any, Optional, Union, get_origin, get_type_hints
from pydantic import BaseModel
from typing_extensions import NotRequired, ReadOnly, Required
from langgraph._internal._typing import MISSING
def _is_optional_type(type_: Any) -> bool:
"""Check if a type is Optional."""
# Handle new union syntax (PEP 604): str | None
if isinstance(type_, types.UnionType):
return any(
arg is type(None) or _is_optional_type(arg) for arg in type_.__args__
)
if hasattr(type_, "__origin__") and hasattr(type_, "__args__"):
origin = get_origin(type_)
if origin is Optional:
return True
if origin is Union:
return any(
arg is type(None) or _is_optional_type(arg) for arg in type_.__args__
)
if origin is Annotated:
return _is_optional_type(type_.__args__[0])
return origin is None
if hasattr(type_, "__bound__") and type_.__bound__ is not None:
return _is_optional_type(type_.__bound__)
return type_ is None
def _is_required_type(type_: Any) -> bool | None:
"""Check if an annotation is marked as Required/NotRequired.
Returns:
- True if required
- False if not required
- None if not annotated with either
"""
origin = get_origin(type_)
if origin is Required:
return True
if origin is NotRequired:
return False
if origin is Annotated or getattr(origin, "__args__", None):
# See https://typing.readthedocs.io/en/latest/spec/typeddict.html#interaction-with-annotated
return _is_required_type(type_.__args__[0])
return None
def _is_readonly_type(type_: Any) -> bool:
"""Check if an annotation is marked as ReadOnly.
Returns:
- True if is read only
- False if not read only
"""
# See: https://typing.readthedocs.io/en/latest/spec/typeddict.html#typing-readonly-type-qualifier
origin = get_origin(type_)
if origin is Annotated:
return _is_readonly_type(type_.__args__[0])
if origin is ReadOnly:
return True
return False
# Fallback used when a schema has no `__optional_keys__` attribute.
_DEFAULT_KEYS: frozenset[str] = frozenset()


def get_field_default(name: str, type_: Any, schema: type[Any]) -> Any:
    """Determine the default value for a field in a state schema.

    The decision is taken in this order:
      1. An explicit `Required[...]` marker -> `...` (no default).
      2. The field is listed in the schema's `__optional_keys__`
         (total=False or NotRequired), or carries an explicit
         `NotRequired[...]` marker -> `None`.
      3. For dataclass schemas, the field's own default or default factory.
      4. An optional annotation (Optional / Union with None) -> `None`.
      5. Otherwise `...`.

    Args:
        name: The field name.
        type_: The field's annotation.
        schema: The schema class the field belongs to.

    Returns:
        The default value, or `...` when the field has none.
    """
    optional_keys = getattr(schema, "__optional_keys__", _DEFAULT_KEYS)
    required_marker = _is_required_type(type_)
    listed_optional = name in optional_keys

    if required_marker:
        # explicit Required[<type>] wins, even when the key is listed as
        # optional (can happen on earlier versions of python)
        return ...
    if listed_optional or required_marker is not None:
        # total=False, or NotRequired[<type>] (explicit or via __optional_keys__)
        return None

    if dataclasses.is_dataclass(schema):
        for candidate in dataclasses.fields(schema):
            if candidate.name != name:
                continue
            has_plain_default = (
                candidate.default is not dataclasses.MISSING
                and candidate.default is not ...
            )
            if has_plain_default:
                return candidate.default
            if candidate.default_factory is not dataclasses.MISSING:
                return candidate.default_factory()
            # only the first field with this name is considered
            break

    # Note, we ignore ReadOnly attributes,
    # as they don't make much sense. (we don't care if you mutate the state in your node)
    # and mutating state in your node has no effect on our graph state.
    # Base case is the annotation
    return None if _is_optional_type(type_) else ...
def get_enhanced_type_hints(
    type: type[Any],
) -> Generator[tuple[str, Any, Any, str | None], None, None]:
    """Yield `(name, type hint, default, description)` for each annotated field.

    Best-effort extraction used for config schemas: a description and default
    are read from a `model_fields` mapping when the class has one, and a
    same-named attribute in the class `__dict__` overrides the default.
    Lookup errors are swallowed, leaving `None` in place.
    """
    for name, typ in get_type_hints(type).items():
        default = None
        description = None

        # Pydantic models
        try:
            if hasattr(type, "model_fields") and name in type.model_fields:
                field = type.model_fields[name]
                field_description = getattr(field, "description", None)
                if field_description is not None:
                    description = field_description
                field_default = getattr(field, "default", None)
                if field_default is not None:
                    default = field_default
                    default_cls_name = getattr(default.__class__, "__name__", "")
                    if default_cls_name == "PydanticUndefinedType":
                        # pydantic's "no default" marker, matched by class name
                        default = None
        except (AttributeError, KeyError, TypeError):
            pass

        # TypedDict, dataclass
        try:
            if hasattr(type, "__dict__"):
                namespace = getattr(type, "__dict__")
                if name in namespace:
                    default = namespace[name]
        except (AttributeError, KeyError, TypeError):
            pass

        yield name, typ, default, description
def get_update_as_tuples(input: Any, keys: Sequence[str]) -> list[tuple[str, Any]]:
    """Get Pydantic state update as a list of (key, value) tuples.

    Args:
        input: The object to read attributes from.
        keys: The attribute names to consider, in output order.

    Returns:
        One `(key, value)` pair per key that exists on `input` and whose
        value is either not `None`, differs from a `None` model default, or
        was explicitly set on the model.
    """
    if isinstance(input, BaseModel):
        keep = input.model_fields_set
        defaults = {
            field: info.default for field, info in type(input).model_fields.items()
        }
    else:
        keep = None
        defaults = {}

    # NOTE: This behavior for Pydantic is somewhat inelegant,
    # but we keep around for backwards compatibility
    # if input is a Pydantic model, only update values
    # that are different from the default values or in the keep set
    updates: list[tuple[str, Any]] = []
    for k in keys:
        value = getattr(input, k, MISSING)
        if value is MISSING:
            continue
        explicitly_set = keep is not None and k in keep
        if (
            value is not None
            or defaults.get(k, MISSING) is not None
            or explicitly_set
        ):
            updates.append((k, value))
    return updates
ANNOTATED_KEYS_CACHE: weakref.WeakKeyDictionary[type[Any], tuple[str, ...]] = (
weakref.WeakKeyDictionary()
)
def get_cached_annotated_keys(obj: type[Any]) -> tuple[str, ...]:
"""Return cached annotated keys for a Python class."""
if obj in ANNOTATED_KEYS_CACHE:
return ANNOTATED_KEYS_CACHE[obj]
if isinstance(obj, type):
keys: list[str] = []
for base in reversed(obj.__mro__):
ann = base.__dict__.get("__annotations__")
# In Python 3.14+, Pydantic models use descriptors for __annotations__
# so we need to fall back to getattr if __dict__.get returns None
if ann is None:
ann = getattr(base, "__annotations__", None)
if ann is None or isinstance(ann, types.GetSetDescriptorType):
continue
keys.extend(ann.keys())
return ANNOTATED_KEYS_CACHE.setdefault(obj, tuple(keys))
else:
raise TypeError(f"Expected a type, got {type(obj)}. ")

View File

@@ -0,0 +1,220 @@
from __future__ import annotations
import asyncio
import concurrent.futures
import contextvars
import inspect
import sys
import types
from collections.abc import Awaitable, Coroutine, Generator
from typing import TypeVar, cast
# Result type carried by the coroutines/tasks handled in this module.
T = TypeVar("T")
# Either kind of future accepted by the chaining helpers below.
AnyFuture = asyncio.Future | concurrent.futures.Future
# True on Python < 3.11; `_ensure_future` then creates tasks without `context=`.
CONTEXT_NOT_SUPPORTED = sys.version_info < (3, 11)
# True on Python < 3.12; `_ensure_future` then never uses `asyncio.eager_task_factory`.
EAGER_NOT_SUPPORTED = sys.version_info < (3, 12)
def _get_loop(fut: asyncio.Future) -> asyncio.AbstractEventLoop:
# Tries to call Future.get_loop() if it's available.
# Otherwise fallbacks to using the old '_loop' property.
try:
get_loop = fut.get_loop
except AttributeError:
pass
else:
return get_loop()
return fut._loop
def _convert_future_exc(exc: BaseException) -> BaseException:
exc_class = type(exc)
if exc_class is concurrent.futures.CancelledError:
return asyncio.CancelledError(*exc.args)
elif exc_class is concurrent.futures.TimeoutError:
return asyncio.TimeoutError(*exc.args)
elif exc_class is concurrent.futures.InvalidStateError:
return asyncio.InvalidStateError(*exc.args)
else:
return exc
def _set_concurrent_future_state(
concurrent: concurrent.futures.Future,
source: AnyFuture,
) -> None:
"""Copy state from a future to a concurrent.futures.Future."""
assert source.done()
if source.cancelled():
concurrent.cancel()
if not concurrent.set_running_or_notify_cancel():
return
exception = source.exception()
if exception is not None:
concurrent.set_exception(_convert_future_exc(exception))
else:
result = source.result()
concurrent.set_result(result)
def _copy_future_state(source: AnyFuture, dest: asyncio.Future) -> None:
"""Internal helper to copy state from another Future.
The other Future may be a concurrent.futures.Future.
"""
if dest.done():
return
assert source.done()
if dest.cancelled():
return
if source.cancelled():
dest.cancel()
else:
exception = source.exception()
if exception is not None:
dest.set_exception(_convert_future_exc(exception))
else:
result = source.result()
dest.set_result(result)
def _chain_future(source: AnyFuture, destination: AnyFuture) -> None:
    """Chain two futures so that when one completes, so does the other.

    The result (or exception) of source will be copied to destination.
    If destination is cancelled, source gets cancelled too.
    Compatible with both asyncio.Future and concurrent.futures.Future.

    Args:
        source: The future whose outcome is propagated.
        destination: The future that receives the outcome.

    Raises:
        TypeError: If either argument is neither an asyncio future nor a
            `concurrent.futures.Future`.
    """
    if not asyncio.isfuture(source) and not isinstance(
        source, concurrent.futures.Future
    ):
        raise TypeError("A future is required for source argument")
    if not asyncio.isfuture(destination) and not isinstance(
        destination, concurrent.futures.Future
    ):
        raise TypeError("A future is required for destination argument")
    # only asyncio futures have a loop; None marks a concurrent.futures.Future
    source_loop = _get_loop(source) if asyncio.isfuture(source) else None
    dest_loop = _get_loop(destination) if asyncio.isfuture(destination) else None

    def _set_state(future: AnyFuture, other: AnyFuture) -> None:
        # copy `other`'s outcome into `future` using the helper for its kind
        if asyncio.isfuture(future):
            _copy_future_state(other, future)
        else:
            _set_concurrent_future_state(future, other)

    def _call_check_cancel(destination: AnyFuture) -> None:
        # done-callback of destination: propagate a cancellation back to source
        if destination.cancelled():
            if source_loop is None or source_loop is dest_loop:
                source.cancel()
            else:
                # source belongs to a different loop: cancel it via that loop
                source_loop.call_soon_threadsafe(source.cancel)

    def _call_set_state(source: AnyFuture) -> None:
        # done-callback of source: push its outcome to destination
        if destination.cancelled() and dest_loop is not None and dest_loop.is_closed():
            return
        if dest_loop is None or dest_loop is source_loop:
            _set_state(destination, source)
        else:
            # destination belongs to a different loop: skip if that loop is
            # closed, otherwise schedule the copy on it
            if dest_loop.is_closed():
                return
            dest_loop.call_soon_threadsafe(_set_state, destination, source)

    destination.add_done_callback(_call_check_cancel)
    source.add_done_callback(_call_set_state)
def chain_future(source: AnyFuture, destination: AnyFuture) -> AnyFuture:
    """Chain `source` to `destination` and return `destination`.

    If the chaining itself fails, the error is recorded on `destination`
    and re-raised. `SystemExit` and `KeyboardInterrupt` are re-raised
    without touching `destination`.
    """
    # adapted from asyncio.run_coroutine_threadsafe
    try:
        _chain_future(source, destination)
    except (SystemExit, KeyboardInterrupt):
        raise
    except BaseException as exc:
        is_thread_future = isinstance(destination, concurrent.futures.Future)
        # a concurrent future only takes the exception if it wasn't cancelled
        if not is_thread_future or destination.set_running_or_notify_cancel():
            destination.set_exception(exc)
        raise
    else:
        return destination
def _ensure_future(
    coro_or_future: Coroutine[None, None, T] | Awaitable[T],
    *,
    loop: asyncio.AbstractEventLoop,
    name: str | None = None,
    context: contextvars.Context | None = None,
    lazy: bool = True,
) -> asyncio.Task[T]:
    """Wrap a coroutine or awaitable in a task on `loop`.

    Args:
        coro_or_future: The coroutine, or any other awaitable (which is
            first wrapped with `_wrap_awaitable`).
        loop: The event loop to create the task on.
        name: Optional task name.
        context: Optional context for the task; ignored when
            `CONTEXT_NOT_SUPPORTED` is true.
        lazy: When false, and eager tasks are supported, the task is created
            with `asyncio.eager_task_factory` instead of `loop.create_task`.

    Returns:
        The created task.

    Raises:
        TypeError: If `coro_or_future` is neither a coroutine nor awaitable.
        RuntimeError: Re-raised from task creation.
    """
    called_wrap_awaitable = False
    if not asyncio.iscoroutine(coro_or_future):
        if inspect.isawaitable(coro_or_future):
            coro_or_future = cast(
                Coroutine[None, None, T], _wrap_awaitable(coro_or_future)
            )
            called_wrap_awaitable = True
        else:
            raise TypeError(
                "An asyncio.Future, a coroutine or an awaitable is required."
                f" Got {type(coro_or_future).__name__} instead."
            )

    try:
        if CONTEXT_NOT_SUPPORTED:
            return loop.create_task(coro_or_future, name=name)
        elif EAGER_NOT_SUPPORTED or lazy:
            return loop.create_task(coro_or_future, name=name, context=context)
        else:
            return asyncio.eager_task_factory(
                loop, coro_or_future, name=name, context=context
            )
    except RuntimeError:
        # close a caller-supplied coroutine that never got scheduled; the
        # wrapper created above is left alone
        if not called_wrap_awaitable:
            coro_or_future.close()
        raise
@types.coroutine
def _wrap_awaitable(awaitable: Awaitable[T]) -> Generator[None, None, T]:
"""Helper for asyncio.ensure_future().
Wraps awaitable (an object with __await__) into a coroutine
that will later be wrapped in a Task by ensure_future().
"""
return (yield from awaitable.__await__())
def run_coroutine_threadsafe(
    coro: Coroutine[None, None, T],
    loop: asyncio.AbstractEventLoop,
    *,
    lazy: bool,
    name: str | None = None,
    context: contextvars.Context | None = None,
) -> asyncio.Future[T]:
    """Submit a coroutine object to a given event loop.

    Return an asyncio.Future to access the result.

    Args:
        coro: The coroutine to run.
        loop: The event loop to run it on.
        lazy: Forwarded to `_ensure_future` when called from `loop` itself.
            On the cross-thread path it is not forwarded, so
            `_ensure_future` uses its default there.
        name: Optional task name.
        context: Optional context, used for the task and for the scheduled
            callback.

    Returns:
        The task itself when called from the running `loop`; otherwise a
        new future bound to `loop` that is chained to the task once the
        loop has created it.
    """
    if asyncio._get_running_loop() is loop:
        # already on the target loop: create the task directly
        return _ensure_future(coro, loop=loop, name=name, context=context, lazy=lazy)
    else:
        future: asyncio.Future[T] = asyncio.Future(loop=loop)

        def callback() -> None:
            # runs on `loop`: create the task and mirror its outcome into `future`
            try:
                chain_future(
                    _ensure_future(coro, loop=loop, name=name, context=context),
                    future,
                )
            except (SystemExit, KeyboardInterrupt):
                raise
            except BaseException as exc:
                future.set_exception(exc)
                raise

        loop.call_soon_threadsafe(callback, context=context)
        return future

View File

@@ -0,0 +1,275 @@
from __future__ import annotations
import sys
import typing
import warnings
from contextlib import nullcontext
from dataclasses import is_dataclass
from functools import lru_cache
from typing import (
Any,
cast,
overload,
)
from pydantic import (
BaseModel,
ConfigDict,
Field,
RootModel,
)
from pydantic import (
create_model as _create_model_base,
)
from pydantic.fields import FieldInfo
from pydantic.json_schema import (
DEFAULT_REF_TEMPLATE,
GenerateJsonSchema,
JsonSchemaMode,
)
from typing_extensions import TypedDict
@overload
def get_fields(model: type[BaseModel]) -> dict[str, FieldInfo]: ...


@overload
def get_fields(model: BaseModel) -> dict[str, FieldInfo]: ...


def get_fields(
    model: type[BaseModel] | BaseModel,
) -> dict[str, FieldInfo]:
    """Return the field mapping of a Pydantic model class or instance.

    `model_fields` is preferred; `__fields__` is the fallback.

    Raises:
        TypeError: If `model` has neither attribute.
    """
    for attribute in ("model_fields", "__fields__"):
        if hasattr(model, attribute):
            return getattr(model, attribute)
    msg = f"Expected a Pydantic model. Got {type(model)}"
    raise TypeError(msg)
# Model config passed as `__config__` to every model built by `create_model`
# from field definitions.
_SchemaConfig = ConfigDict(
    arbitrary_types_allowed=True, frozen=True, protected_namespaces=()
)
# Sentinel meaning "no default supplied" (so that None can be a real default).
NO_DEFAULT = object()
def _create_root_model(
    name: str,
    type_: Any,
    module_name: str | None = None,
    default_: object = NO_DEFAULT,
) -> type[BaseModel]:
    """Create a `RootModel` subclass named `name` whose root has type `type_`.

    Args:
        name: Class name; also forced as the `title` of generated schemas.
        type_: Annotation of the `root` field.
        module_name: Value for the class's `__module__`; falls back to
            `"langchain_core.runnables.utils"`.
        default_: Default for the `root` field; `NO_DEFAULT` means none.

    Returns:
        The newly created model class.
    """

    def schema(
        cls: type[BaseModel],
        by_alias: bool = True,  # noqa: FBT001,FBT002
        ref_template: str = DEFAULT_REF_TEMPLATE,
    ) -> dict[str, Any]:
        # Complains about schema not being defined in superclass
        schema_ = super(cls, cls).schema(  # type: ignore[misc]
            by_alias=by_alias, ref_template=ref_template
        )
        # override the title with the requested model name
        schema_["title"] = name
        return schema_

    def model_json_schema(
        cls: type[BaseModel],
        by_alias: bool = True,  # noqa: FBT001,FBT002
        ref_template: str = DEFAULT_REF_TEMPLATE,
        schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
        mode: JsonSchemaMode = "validation",
    ) -> dict[str, Any]:
        # Complains about model_json_schema not being defined in superclass
        schema_ = super(cls, cls).model_json_schema(  # type: ignore[misc]
            by_alias=by_alias,
            ref_template=ref_template,
            schema_generator=schema_generator,
            mode=mode,
        )
        # override the title with the requested model name
        schema_["title"] = name
        return schema_

    # namespace of the class created below
    base_class_attributes = {
        "__annotations__": {"root": type_},
        "model_config": ConfigDict(arbitrary_types_allowed=True),
        "schema": classmethod(schema),
        "model_json_schema": classmethod(model_json_schema),
        "__module__": module_name or "langchain_core.runnables.utils",
    }

    if default_ is not NO_DEFAULT:
        base_class_attributes["root"] = default_
    with warnings.catch_warnings():
        custom_root_type = type(name, (RootModel,), base_class_attributes)
    return cast("type[BaseModel]", custom_root_type)
@lru_cache(maxsize=256)
def _create_root_model_cached(
    model_name: str,
    type_: Any,
    *,
    module_name: str | None = None,
    default_: object = NO_DEFAULT,
) -> type[BaseModel]:
    """Memoized front for `_create_root_model` (up to 256 entries).

    All arguments take part in the cache key and therefore must be hashable.
    """
    return _create_root_model(
        model_name,
        type_,
        module_name=module_name,
        default_=default_,
    )
@lru_cache(maxsize=256)
def _create_model_cached(
    model_name: str,
    /,
    **field_definitions: Any,
) -> type[BaseModel]:
    """Memoized model factory (up to 256 entries).

    The field definitions are remapped with `_remap_field_definitions` and
    the model is built with the shared `_SchemaConfig`. The definitions take
    part in the cache key and therefore must be hashable.
    """
    safe_definitions = _remap_field_definitions(field_definitions)
    return _create_model_base(
        model_name,
        __config__=_SchemaConfig,
        **safe_definitions,
    )
# Field names that would collide with `BaseModel`'s own public attributes and
# methods. The set is computed from `dir(BaseModel)` (every name that does not
# start with an underscore), so it tracks the installed pydantic version.
# For reference, the reserved names include:
# "construct", "copy", "dict", "from_orm", "json", "parse_file", "parse_obj",
# "parse_raw", "schema", "schema_json", "update_forward_refs", "validate",
# "model_computed_fields", "model_config", "model_construct", "model_copy",
# "model_dump", "model_dump_json", "model_extra", "model_fields",
# "model_fields_set", "model_json_schema", "model_parametrized_name",
# "model_post_init", "model_rebuild", "model_validate", "model_validate_json",
# "model_validate_strings"
_RESERVED_NAMES = {key for key in dir(BaseModel) if not key.startswith("_")}
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
    """Rename fields that would collide with internal pydantic names.

    A field whose name starts with an underscore or is in `_RESERVED_NAMES`
    is stored as `private_<name>`, with the original name kept as alias and
    serialization alias. Every other field is passed through untouched.

    Raises:
        NotImplementedError: If a field that needs renaming is given as a
            pydantic `FieldInfo` rather than a `(type, default)` pair.
    """
    remapped: dict[str, Any] = {}
    for name, definition in field_definitions.items():
        collides = name.startswith("_") or name in _RESERVED_NAMES
        if not collides:
            remapped[name] = definition
            continue

        # Let's add a prefix to avoid colliding with internal pydantic fields
        if isinstance(definition, FieldInfo):
            msg = (
                f"Remapping for fields starting with '_' or fields with a name "
                f"matching a reserved name {_RESERVED_NAMES} is not supported if "
                f" the field is a pydantic Field instance. Got {name}."
            )
            raise NotImplementedError(msg)

        type_, default_ = definition
        display_title = name.lstrip("_").replace("_", " ").title()
        aliased_field = Field(
            default=default_,
            alias=name,
            serialization_alias=name,
            title=display_title,
        )
        remapped[f"private_{name}"] = (type_, aliased_field)
    return remapped
def create_model(
    model_name: str,
    *,
    field_definitions: dict[str, Any] | None = None,
    root: Any | None = None,
) -> type[BaseModel]:
    """Create a pydantic model with the given field definitions.

    Attention:
        Please do not use outside of langchain packages. This API
        is subject to change at any time.

    Args:
        model_name: The name of the model.
        field_definitions: The field definitions for the model.
        root: Type for a root model (RootModel). May also be a
            `(type, default)` tuple. Cannot be combined with
            `field_definitions`.

    Returns:
        Type[BaseModel]: The created model.

    Raises:
        NotImplementedError: If both `root` and `field_definitions` are given.
    """
    field_definitions = field_definitions or {}

    if root:
        if field_definitions:
            msg = (
                "When specifying __root__ no other "
                f"fields should be provided. Got {field_definitions}"
            )
            raise NotImplementedError(msg)

        if isinstance(root, tuple):
            kwargs = {"type_": root[0], "default_": root[1]}
        else:
            kwargs = {"type_": root}

        try:
            named_root_model = _create_root_model_cached(model_name, **kwargs)
        except TypeError:
            # something in the arguments into _create_root_model_cached is not hashable
            named_root_model = _create_root_model(
                model_name,
                **kwargs,
            )
        return named_root_model

    # No root, just field definitions
    names = set(field_definitions.keys())

    capture_warnings = False

    for name in names:
        # Also if any non-reserved name is used (e.g., model_id or model_name)
        if name.startswith("model"):
            capture_warnings = True

    # silence all warnings during creation only when a field name starts with "model"
    with warnings.catch_warnings() if capture_warnings else nullcontext():
        if capture_warnings:
            warnings.filterwarnings(action="ignore")
        try:
            return _create_model_cached(model_name, **field_definitions)
        except TypeError:
            # something in field definitions is not hashable
            return _create_model_base(
                model_name,
                __config__=_SchemaConfig,
                **_remap_field_definitions(field_definitions),
            )
def is_supported_by_pydantic(type_: Any) -> bool:
    """Tell whether a "complex" container type is supported by pydantic.

    Returns True for dataclasses, `BaseModel` subclasses and TypedDicts, and
    False for everything else, including primitive types like int or str.
    """
    if is_dataclass(type_):
        return True
    if isinstance(type_, type) and issubclass(type_, BaseModel):
        return True
    # Pydantic supports typing.TypedDict from Python 3.12.
    # For older versions, only typing_extensions.TypedDict is supported.
    stdlib_typeddict_ok = sys.version_info >= (3, 12)
    return any(
        base is TypedDict
        or (stdlib_typeddict_ok and base is typing.TypedDict)  # noqa: TID251
        for base in getattr(type_, "__orig_bases__", ())
    )

View File

@@ -0,0 +1,124 @@
# type: ignore
from __future__ import annotations
import asyncio
import queue
import threading
import types
from collections import deque
from time import monotonic
class AsyncQueue(asyncio.Queue):
    """Async unbounded FIFO queue with a wait() method.

    Subclassed from asyncio.Queue, adding a wait() method.
    """

    async def wait(self) -> None:
        """If queue is empty, wait until an item is available.

        Copied from Queue.get(), removing the call to .get_nowait(),
        ie. this doesn't consume the item, just waits for it.

        Raises:
            BaseException: Whatever interrupted the wait is re-raised after
                the pending getter has been cleaned up.
        """
        while self.empty():
            # Register a future in the queue's getter list and block on it.
            getter = self._get_loop().create_future()
            self._getters.append(getter)
            try:
                await getter
            except BaseException:
                getter.cancel()  # Just in case getter is not done yet.
                try:
                    # Clean self._getters from canceled getters.
                    self._getters.remove(getter)
                except ValueError:
                    # The getter could be removed from self._getters by a
                    # previous put_nowait call.
                    pass
                if not self.empty() and not getter.cancelled():
                    # We were woken up by put_nowait(), but can't take
                    # the call. Wake up the next in line.
                    self._wakeup_next(self._getters)
                raise
class Semaphore(threading.Semaphore):
    """Semaphore subclass with a wait() method."""

    def wait(self, blocking: bool = True, timeout: float | None = None):
        """Block until the semaphore can be acquired, but don't acquire it.

        Args:
            blocking: When False, only report the current state without waiting.
            timeout: Maximum number of seconds to wait; None waits indefinitely.

        Returns:
            True if the internal counter was non-zero when the wait ended,
            False if the call gave up (non-blocking, or timed out).

        Raises:
            ValueError: If a timeout is combined with `blocking=False`.
        """
        if not blocking and timeout is not None:
            raise ValueError("can't specify timeout for non-blocking acquire")
        deadline = None
        with self._cond:
            while self._value == 0:
                if not blocking:
                    return False
                if timeout is not None:
                    if deadline is None:
                        # First pass: pin the absolute deadline.
                        deadline = monotonic() + timeout
                    else:
                        # Later passes: wait only for the time remaining.
                        timeout = deadline - monotonic()
                        if timeout <= 0:
                            return False
                self._cond.wait(timeout)
            return True
class SyncQueue:
    """Unbounded FIFO queue with a wait() method.

    Adapted from pure Python implementation of queue.SimpleQueue.
    Items live in a deque; a counting semaphore tracks how many are available.
    """

    def __init__(self):
        self._queue = deque()
        self._count = Semaphore(0)

    def put(self, item, block=True, timeout=None):
        """Append the item to the queue and signal that one more is available.

        The optional 'block' and 'timeout' arguments are ignored, as this method
        never blocks. They are provided for compatibility with the Queue class.
        """
        self._queue.append(item)
        self._count.release()

    def get(self, block=False, timeout=None):
        """Remove and return the oldest item from the queue.

        'block' and 'timeout' are forwarded to the semaphore's acquire().
        Note that 'block' defaults to False here, so by default the call
        raises the Empty exception unless an item is immediately available.

        Raises:
            ValueError: If 'timeout' is negative.
            queue.Empty: If no item could be obtained.
        """
        negative_timeout = timeout is not None and timeout < 0
        if negative_timeout:
            raise ValueError("'timeout' must be a non-negative number")
        acquired = self._count.acquire(block, timeout)
        if not acquired:
            raise queue.Empty
        try:
            item = self._queue.popleft()
        except IndexError:
            raise queue.Empty
        return item

    def wait(self, block=True, timeout=None):
        """If queue is empty, wait until an item maybe is available,
        but don't consume it.

        Raises:
            ValueError: If 'timeout' is negative.
        """
        negative_timeout = timeout is not None and timeout < 0
        if negative_timeout:
            raise ValueError("'timeout' must be a non-negative number")
        self._count.wait(block, timeout)

    def empty(self):
        """Return True if the queue is empty, False otherwise (not reliable!)."""
        return not self._queue

    def qsize(self):
        """Return the approximate size of the queue (not reliable!)."""
        return len(self._queue)

    __class_getitem__ = classmethod(types.GenericAlias)

View File

@@ -0,0 +1,29 @@
def default_retry_on(exc: Exception) -> bool:
    """Decide whether an exception should trigger a retry.

    Args:
        exc: The exception that was raised.

    Returns:
        True for `ConnectionError`, for httpx / requests HTTP errors carrying
        a 5xx status (or a requests error with no response attached), and for
        any exception type not listed below. False for 4xx HTTP errors and for
        the listed built-in error types.
    """
    import httpx
    import requests

    if isinstance(exc, ConnectionError):
        return True
    if isinstance(exc, httpx.HTTPStatusError):
        return 500 <= exc.response.status_code < 600
    if isinstance(exc, requests.HTTPError):
        # Compare against None explicitly: a `requests.Response` is falsy for
        # every 4xx/5xx status, so a truthiness check would treat all error
        # responses as "missing" and retry client errors too.
        if exc.response is None:
            return True
        return 500 <= exc.response.status_code < 600
    if isinstance(
        exc,
        (
            ValueError,
            TypeError,
            ArithmeticError,
            ImportError,
            LookupError,
            NameError,
            SyntaxError,
            RuntimeError,
            ReferenceError,
            StopIteration,
            StopAsyncIteration,
            OSError,
        ),
    ):
        return False
    return True

View File

@@ -0,0 +1,914 @@
from __future__ import annotations
import asyncio
import enum
import inspect
import sys
import warnings
from collections.abc import (
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Generator,
Iterator,
Sequence,
)
from contextlib import AsyncExitStack, contextmanager
from contextvars import Context, Token, copy_context
from functools import partial, wraps
from typing import (
Any,
Optional,
Protocol,
TypeGuard,
cast,
)
from langchain_core.runnables.base import (
Runnable,
RunnableConfig,
RunnableLambda,
RunnableParallel,
RunnableSequence,
)
from langchain_core.runnables.base import (
RunnableLike as LCRunnableLike,
)
from langchain_core.runnables.config import (
run_in_executor,
var_child_runnable_config,
)
from langchain_core.runnables.utils import Input, Output
from langchain_core.tracers.langchain import LangChainTracer
from langgraph.store.base import BaseStore
from langgraph._internal._config import (
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
patch_config,
)
from langgraph._internal._constants import (
CONF,
CONFIG_KEY_RUNTIME,
)
from langgraph._internal._typing import MISSING
from langgraph.types import StreamWriter
try:
from langchain_core.tracers._streaming import _StreamingCallbackHandler
except ImportError:
_StreamingCallbackHandler = None # type: ignore
def _set_config_context(
    config: RunnableConfig, run: Any = None
) -> Token[RunnableConfig | None]:
    """Install the child Runnable config and, optionally, the tracing parent.

    Args:
        config: The config to set.
        run: When not None, it is registered as the tracing context's parent.

    Returns:
        The token produced by setting the config variable, to be used for
        resetting it later.
    """
    token = var_child_runnable_config.set(config)
    if run is None:
        return token

    from langsmith.run_helpers import _set_tracing_context

    _set_tracing_context({"parent": run})
    return token
def _unset_config_context(token: Token[RunnableConfig | None], run: Any = None) -> None:
    """Reset the child Runnable config and, optionally, clear the tracing context.

    Args:
        token: The config token to reset.
        run: When not None, every tracing context field is set back to None.
    """
    var_child_runnable_config.reset(token)
    if run is None:
        return

    from langsmith.run_helpers import _set_tracing_context

    cleared_fields = dict.fromkeys(
        ("parent", "project_name", "tags", "metadata", "enabled", "client")
    )
    _set_tracing_context(cleared_fields)
@contextmanager
def set_config_context(
    config: RunnableConfig, run: Any = None
) -> Generator[Context, None, None]:
    """Yield a copied context in which the child Runnable config + tracing are set.

    The config (and tracing parent, if `run` is given) is installed inside a
    copy of the current context, and undone inside that same copy on exit.

    Args:
        config: The config to set.
        run: Optional run object forwarded to the set / unset helpers.

    Yields:
        The copied `Context`; callers execute their work via its `run()`.
    """
    context = copy_context()
    reset_token = context.run(_set_config_context, config, run)
    try:
        yield context
    finally:
        context.run(_unset_config_context, reset_token, run)
# Before Python 3.11 native StrEnum is not available
class StrEnum(str, enum.Enum):
    """A string enum.

    Members are both `str` instances and `enum.Enum` members.
    """
# Special type to denote any type is accepted
ANY_TYPE = object()
# True when running on Python 3.11 or newer; checked before passing
# `context=` to `asyncio.create_task`.
ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11)
# List of keyword arguments that can be injected into nodes / tasks / tools at runtime.
# A named argument may appear multiple times if it appears with distinct types.
# Each entry is (kwarg name, accepted annotations, runtime attribute, default).
KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = (
    (
        "config",
        (
            RunnableConfig,
            "RunnableConfig",
            Optional[RunnableConfig],  # noqa: UP045
            "Optional[RunnableConfig]",
            # un-annotated parameters are accepted as well
            inspect.Parameter.empty,
        ),
        # for now, use config directly, eventually, will pop off of Runtime
        "N/A",
        inspect.Parameter.empty,
    ),
    (
        "writer",
        (StreamWriter, "StreamWriter", inspect.Parameter.empty),
        "stream_writer",
        # default: a writer that accepts one argument and does nothing
        lambda _: None,
    ),
    (
        "store",
        (
            BaseStore,
            "BaseStore",
            inspect.Parameter.empty,
        ),
        "store",
        # no default: an error is raised when the store is unavailable
        inspect.Parameter.empty,
    ),
    (
        "store",
        (
            Optional[BaseStore],  # noqa: UP045
            "Optional[BaseStore]",
        ),
        "store",
        # optional store: falls back to None when unavailable
        None,
    ),
    (
        "previous",
        (ANY_TYPE,),
        "previous",
        inspect.Parameter.empty,
    ),
    (
        "runtime",
        (ANY_TYPE,),
        # we never hit this block, we just inject runtime directly
        "N/A",
        inspect.Parameter.empty,
    ),
)
"""List of kwargs that can be passed to functions, and their corresponding
config keys, default values and type annotations.
Used to configure keyword arguments that can be injected at runtime
from the `Runtime` object as kwargs to `invoke`, `ainvoke`, `stream` and `astream`.
For a keyword to be injected from the config object, the function signature
must contain a kwarg with the same name and a matching type annotation.
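Each tuple contains:
- the name of the kwarg in the function signature
- the type annotation(s) for the kwarg
- the `Runtime` attribute for fetching the value (N/A if not applicable)
- the default value used when no value can be fetched
  (`inspect.Parameter.empty` if the kwarg is required)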
This is fully internal and should be further refactored to use `get_type_hints`
to resolve forward references and optional types formatted like BaseStore | None.
"""
# Parameter kinds eligible for injection; parameters of any other kind are skipped.
VALID_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
class _RunnableWithWriter(Protocol[Input, Output]):
    """Callable taking the state plus a keyword-only `writer`."""

    def __call__(self, state: Input, *, writer: StreamWriter) -> Output: ...
class _RunnableWithStore(Protocol[Input, Output]):
    """Callable taking the state plus a keyword-only `store`."""

    def __call__(self, state: Input, *, store: BaseStore) -> Output: ...
class _RunnableWithWriterStore(Protocol[Input, Output]):
    """Callable taking the state plus keyword-only `writer` and `store`."""

    def __call__(
        self, state: Input, *, writer: StreamWriter, store: BaseStore
    ) -> Output: ...
class _RunnableWithConfigWriter(Protocol[Input, Output]):
    """Callable taking the state plus keyword-only `config` and `writer`."""

    def __call__(
        self, state: Input, *, config: RunnableConfig, writer: StreamWriter
    ) -> Output: ...
class _RunnableWithConfigStore(Protocol[Input, Output]):
    """Callable taking the state plus keyword-only `config` and `store`."""

    def __call__(
        self, state: Input, *, config: RunnableConfig, store: BaseStore
    ) -> Output: ...
class _RunnableWithConfigWriterStore(Protocol[Input, Output]):
    """Callable taking the state plus keyword-only `config`, `writer` and `store`."""

    def __call__(
        self,
        state: Input,
        *,
        config: RunnableConfig,
        writer: StreamWriter,
        store: BaseStore,
    ) -> Output: ...
# Union of langchain-core's RunnableLike and the callable protocols above,
# i.e. callables that additionally accept `config` / `writer` / `store` kwargs.
RunnableLike = (
    LCRunnableLike
    | _RunnableWithWriter[Input, Output]
    | _RunnableWithStore[Input, Output]
    | _RunnableWithWriterStore[Input, Output]
    | _RunnableWithConfigWriter[Input, Output]
    | _RunnableWithConfigStore[Input, Output]
    | _RunnableWithConfigWriterStore[Input, Output]
)
class RunnableCallable(Runnable):
    """A much simpler version of RunnableLambda that requires sync and async functions.

    At construction time the wrapped function's signature is inspected against
    `KWARGS_CONFIG_KEYS`; matching keyword parameters are injected on every call.
    """

    def __init__(
        self,
        func: Callable[..., Any | Runnable] | None,
        afunc: Callable[..., Awaitable[Any | Runnable]] | None = None,
        *,
        name: str | None = None,
        tags: Sequence[str] | None = None,
        trace: bool = True,
        recurse: bool = True,
        explode_args: bool = False,
        **kwargs: Any,
    ) -> None:
        """Wrap a sync and/or async callable.

        Args:
            func: The synchronous implementation, or None.
            afunc: The asynchronous implementation, or None.
            name: Explicit name; otherwise taken from `func.__name__`
                (unless it is a lambda) or `afunc.__name__`.
            tags: Tags passed to the callback manager when tracing.
            trace: Whether calls are reported through the callback manager.
            recurse: Whether a returned `Runnable` is invoked in turn.
            explode_args: Whether the input is an `(args, kwargs)` pair to be
                unpacked into the call.
            **kwargs: Extra keyword arguments passed to every call.

        Raises:
            ValueError: If neither `func` nor `afunc` is provided.
        """
        self.name = name
        if self.name is None:
            if func:
                try:
                    if func.__name__ != "<lambda>":
                        self.name = func.__name__
                except AttributeError:
                    pass
            elif afunc:
                try:
                    self.name = afunc.__name__
                except AttributeError:
                    pass
        self.func = func
        self.afunc = afunc
        self.tags = tags
        self.kwargs = kwargs
        self.trace = trace
        self.recurse = recurse
        self.explode_args = explode_args
        # check signature
        if func is None and afunc is None:
            raise ValueError("At least one of func or afunc must be provided.")

        # Maps each injectable kwarg name to (runtime attribute, default).
        self.func_accepts: dict[str, tuple[str, Any]] = {}
        params = inspect.signature(cast(Callable, func or afunc)).parameters

        for kw, typ, runtime_key, default in KWARGS_CONFIG_KEYS:
            p = params.get(kw)

            if p is None or p.kind not in VALID_KINDS:
                # If parameter is not found or is not a valid kind, skip
                continue

            if typ != (ANY_TYPE,) and p.annotation not in typ:
                # A specific type is required, but the function annotation does
                # not match the expected type.

                # If this is a config parameter with incorrect typing, emit a warning
                # because we used to support any type but are moving towards more correct typing
                if kw == "config" and p.annotation != inspect.Parameter.empty:
                    warnings.warn(
                        f"The 'config' parameter should be typed as 'RunnableConfig' or "
                        f"'RunnableConfig | None', not '{p.annotation}'. ",
                        UserWarning,
                        stacklevel=4,
                    )
                continue

            # If the kwarg is accepted by the function, store the key / runtime attribute to inject
            self.func_accepts[kw] = (runtime_key, default)

    def __repr__(self) -> str:
        repr_args = {
            k: v
            for k, v in self.__dict__.items()
            if k not in {"name", "func", "afunc", "config", "kwargs", "trace"}
        }
        return f"{self.get_name()}({', '.join(f'{k}={v!r}' for k, v in repr_args.items())})"

    def _resolve_call_args(
        self, input: Any, config: RunnableConfig, kwargs: dict[str, Any]
    ) -> tuple[Any, dict[str, Any]]:
        """Build the positional and keyword arguments for the wrapped function.

        Merges the constructor kwargs with the call-time kwargs, then injects
        every kwarg recorded in `func_accepts` that the caller did not supply.

        Raises:
            ValueError: If a required injectable kwarg has no value available.
        """
        if self.explode_args:
            args, _kwargs = input
            kwargs = {**self.kwargs, **_kwargs, **kwargs}
        else:
            args = (input,)
            kwargs = {**self.kwargs, **kwargs}

        runtime = config.get(CONF, {}).get(CONFIG_KEY_RUNTIME)

        for kw, (runtime_key, default) in self.func_accepts.items():
            # If the kwarg is already set, use the set value
            if kw in kwargs:
                continue

            kw_value: Any = MISSING
            if kw == "config":
                kw_value = config
            elif runtime:
                if kw == "runtime":
                    kw_value = runtime
                else:
                    try:
                        kw_value = getattr(runtime, runtime_key)
                    except AttributeError:
                        pass

            if kw_value is MISSING:
                if default is inspect.Parameter.empty:
                    raise ValueError(
                        f"Missing required config key '{runtime_key}' for '{self.name}'."
                    )
                kw_value = default
            kwargs[kw] = kw_value
        return args, kwargs

    def invoke(
        self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
    ) -> Any:
        """Call the synchronous function, injecting any accepted kwargs.

        Raises:
            TypeError: If no synchronous function was provided.
            ValueError: If a required injectable kwarg has no value available.
        """
        if self.func is None:
            raise TypeError(
                f'No synchronous function provided to "{self.name}".'
                "\nEither initialize with a synchronous function or invoke"
                " via the async API (ainvoke, astream, etc.)"
            )
        if config is None:
            config = ensure_config()
        args, kwargs = self._resolve_call_args(input, config, kwargs)

        if self.trace:
            callback_manager = get_callback_manager_for_config(config, self.tags)
            run_manager = callback_manager.on_chain_start(
                None,
                input,
                name=config.get("run_name") or self.get_name(),
                run_id=config.pop("run_id", None),
            )
            try:
                child_config = patch_config(config, callbacks=run_manager.get_child())
                # get the run
                for h in run_manager.handlers:
                    if isinstance(h, LangChainTracer):
                        run = h.run_map.get(str(run_manager.run_id))
                        break
                else:
                    run = None
                # run in context
                with set_config_context(child_config, run) as context:
                    ret = context.run(self.func, *args, **kwargs)
            except BaseException as e:
                run_manager.on_chain_error(e)
                raise
            else:
                run_manager.on_chain_end(ret)
        else:
            ret = self.func(*args, **kwargs)
        if self.recurse and isinstance(ret, Runnable):
            return ret.invoke(input, config)
        return ret

    async def ainvoke(
        self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
    ) -> Any:
        """Call the asynchronous function, injecting any accepted kwargs.

        Falls back to the synchronous `invoke` when no async function exists.

        Raises:
            ValueError: If a required injectable kwarg has no value available.
        """
        if not self.afunc:
            # Forward call-time kwargs as well, so the fallback behaves like
            # a direct `invoke` call.
            return self.invoke(input, config, **kwargs)
        if config is None:
            config = ensure_config()
        args, kwargs = self._resolve_call_args(input, config, kwargs)

        if self.trace:
            callback_manager = get_async_callback_manager_for_config(config, self.tags)
            run_manager = await callback_manager.on_chain_start(
                None,
                input,
                # same naming rule as the sync path
                name=config.get("run_name") or self.get_name(),
                run_id=config.pop("run_id", None),
            )
            try:
                child_config = patch_config(config, callbacks=run_manager.get_child())
                coro = cast(Coroutine[None, None, Any], self.afunc(*args, **kwargs))
                if ASYNCIO_ACCEPTS_CONTEXT:
                    for h in run_manager.handlers:
                        if isinstance(h, LangChainTracer):
                            run = h.run_map.get(str(run_manager.run_id))
                            break
                    else:
                        run = None
                    with set_config_context(child_config, run) as context:
                        ret = await asyncio.create_task(coro, context=context)
                else:
                    ret = await coro
            except BaseException as e:
                await run_manager.on_chain_error(e)
                raise
            else:
                await run_manager.on_chain_end(ret)
        else:
            ret = await self.afunc(*args, **kwargs)
        if self.recurse and isinstance(ret, Runnable):
            return await ret.ainvoke(input, config)
        return ret
def is_async_callable(
func: Any,
) -> TypeGuard[Callable[..., Awaitable]]:
"""Check if a function is async."""
return (
inspect.iscoroutinefunction(func)
or hasattr(func, "__call__")
and inspect.iscoroutinefunction(func.__call__)
)
def is_async_generator(
func: Any,
) -> TypeGuard[Callable[..., AsyncIterator]]:
"""Check if a function is an async generator."""
return (
inspect.isasyncgenfunction(func)
or hasattr(func, "__call__")
and inspect.isasyncgenfunction(func.__call__)
)
def coerce_to_runnable(
    thing: RunnableLike, *, name: str | None, trace: bool
) -> Runnable:
    """Coerce a runnable-like object into a Runnable.

    Args:
        thing: A runnable-like object.
        name: Name given to the wrapper created for generators and callables.
        trace: Whether the `RunnableCallable` created for a plain callable
            traces its calls. Not applied to generators or dicts.

    Returns:
        A Runnable.

    Raises:
        TypeError: If `thing` is not a Runnable, callable or dict.
    """
    if isinstance(thing, Runnable):
        return thing
    if is_async_generator(thing) or inspect.isgeneratorfunction(thing):
        return RunnableLambda(thing, name=name)
    if callable(thing):
        if is_async_callable(thing):
            return RunnableCallable(None, thing, name=name, trace=trace)
        # Sync callable: the async variant runs the same function through
        # `run_in_executor`.
        return RunnableCallable(
            thing,
            wraps(thing)(partial(run_in_executor, None, thing)),  # type: ignore[arg-type]
            name=name,
            trace=trace,
        )
    if isinstance(thing, dict):
        return RunnableParallel(thing)
    raise TypeError(
        "Expected a Runnable, callable or dict. "
        f"Instead got an unsupported type: {type(thing)}"
    )
class RunnableSeq(Runnable):
    """Sequence of `Runnable`, where the output of each is the input of the next.

    `RunnableSeq` is a simpler version of `RunnableSequence` that is internal to
    LangGraph.
    """

    def __init__(
        self,
        *steps: RunnableLike,
        name: str | None = None,
        trace_inputs: Callable[[Any], Any] | None = None,
    ) -> None:
        """Create a new RunnableSeq.

        Args:
            steps: The steps to include in the sequence.
            name: The name of the `Runnable`.
            trace_inputs: Optional function applied to the input before it is
                reported to the callback manager.

        Raises:
            ValueError: If the sequence has less than 2 steps.
        """
        steps_flat: list[Runnable] = []
        for step in steps:
            # Nested sequences are flattened into their individual steps.
            if isinstance(step, RunnableSequence):
                steps_flat.extend(step.steps)
            elif isinstance(step, RunnableSeq):
                steps_flat.extend(step.steps)
            else:
                steps_flat.append(coerce_to_runnable(step, name=None, trace=True))
        if len(steps_flat) < 2:
            raise ValueError(
                f"RunnableSeq must have at least 2 steps, got {len(steps_flat)}"
            )
        self.steps = steps_flat
        self.name = name
        self.trace_inputs = trace_inputs

    def __or__(
        self,
        other: Any,
    ) -> Runnable:
        """Return a new `RunnableSeq` with `other` appended after this sequence's steps."""
        if isinstance(other, RunnableSequence):
            return RunnableSeq(
                *self.steps,
                other.first,
                *other.middle,
                other.last,
                name=self.name or other.name,
            )
        elif isinstance(other, RunnableSeq):
            return RunnableSeq(
                *self.steps,
                *other.steps,
                name=self.name or other.name,
            )
        else:
            return RunnableSeq(
                *self.steps,
                coerce_to_runnable(other, name=None, trace=True),
                name=self.name,
            )

    def __ror__(
        self,
        other: Any,
    ) -> Runnable:
        """Return a new sequence with `other` placed before this sequence's steps.

        The result is a `RunnableSeq` only when `other` is a `RunnableSeq`;
        otherwise a `RunnableSequence` is built.
        """
        if isinstance(other, RunnableSequence):
            return RunnableSequence(
                other.first,
                *other.middle,
                other.last,
                *self.steps,
                name=other.name or self.name,
            )
        elif isinstance(other, RunnableSeq):
            return RunnableSeq(
                *other.steps,
                *self.steps,
                name=other.name or self.name,
            )
        else:
            return RunnableSequence(
                coerce_to_runnable(other, name=None, trace=True),
                *self.steps,
                name=self.name,
            )

    def invoke(
        self, input: Input, config: RunnableConfig | None = None, **kwargs: Any
    ) -> Any:
        """Run every step in order, feeding each output into the next step.

        Only the first step receives `**kwargs` and is run inside the copied
        config/tracing context. Returns the output of the last step.
        """
        if config is None:
            config = ensure_config()
        # setup callbacks and context
        callback_manager = get_callback_manager_for_config(config)
        # start the root run
        run_manager = callback_manager.on_chain_start(
            None,
            self.trace_inputs(input) if self.trace_inputs is not None else input,
            name=config.get("run_name") or self.get_name(),
            run_id=config.pop("run_id", None),
        )
        # invoke all steps in sequence
        try:
            for i, step in enumerate(self.steps):
                # mark each step as a child run
                config = patch_config(
                    config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
                )
                # 1st step is the actual node,
                # others are writers which don't need to be run in context
                if i == 0:
                    # get the run object
                    for h in run_manager.handlers:
                        if isinstance(h, LangChainTracer):
                            run = h.run_map.get(str(run_manager.run_id))
                            break
                    else:
                        run = None
                    # run in context
                    with set_config_context(config, run) as context:
                        input = context.run(step.invoke, input, config, **kwargs)
                else:
                    input = step.invoke(input, config)
        # finish the root run
        except BaseException as e:
            run_manager.on_chain_error(e)
            raise
        else:
            run_manager.on_chain_end(input)
            return input

    async def ainvoke(
        self,
        input: Input,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> Any:
        """Async counterpart of `invoke`.

        When `ASYNCIO_ACCEPTS_CONTEXT` is true the first step runs as a task
        inside the copied context; otherwise it is awaited directly.
        """
        if config is None:
            config = ensure_config()
        # setup callbacks
        callback_manager = get_async_callback_manager_for_config(config)
        # start the root run
        run_manager = await callback_manager.on_chain_start(
            None,
            self.trace_inputs(input) if self.trace_inputs is not None else input,
            name=config.get("run_name") or self.get_name(),
            run_id=config.pop("run_id", None),
        )

        # invoke all steps in sequence
        try:
            for i, step in enumerate(self.steps):
                # mark each step as a child run
                config = patch_config(
                    config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
                )
                # 1st step is the actual node,
                # others are writers which don't need to be run in context
                if i == 0:
                    if ASYNCIO_ACCEPTS_CONTEXT:
                        # get the run object
                        for h in run_manager.handlers:
                            if isinstance(h, LangChainTracer):
                                run = h.run_map.get(str(run_manager.run_id))
                                break
                        else:
                            run = None
                        # run in context
                        with set_config_context(config, run) as context:
                            input = await asyncio.create_task(
                                step.ainvoke(input, config, **kwargs), context=context
                            )
                    else:
                        input = await step.ainvoke(input, config, **kwargs)
                else:
                    input = await step.ainvoke(input, config)
        # finish the root run
        except BaseException as e:
            await run_manager.on_chain_error(e)
            raise
        else:
            await run_manager.on_chain_end(input)
            return input

    def stream(
        self,
        input: Input,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> Iterator[Any]:
        """Chain the steps' streams and consume them inside the copied context.

        The first step is streamed and each later step transforms the previous
        iterator. The generator itself yields a single `None`; the consumed
        output is reported to the callback manager on completion.
        """
        if config is None:
            config = ensure_config()
        # setup callbacks
        callback_manager = get_callback_manager_for_config(config)
        # start the root run
        run_manager = callback_manager.on_chain_start(
            None,
            self.trace_inputs(input) if self.trace_inputs is not None else input,
            name=config.get("run_name") or self.get_name(),
            run_id=config.pop("run_id", None),
        )
        # get the run object
        for h in run_manager.handlers:
            if isinstance(h, LangChainTracer):
                run = h.run_map.get(str(run_manager.run_id))
                break
        else:
            run = None
        # create first step config
        config = patch_config(
            config,
            callbacks=run_manager.get_child(f"seq:step:{1}"),
        )
        # run all in context
        with set_config_context(config, run) as context:
            try:
                # stream the last steps
                # transform the input stream of each step with the next
                # steps that don't natively support transforming an input stream will
                # buffer input in memory until all available, and then start emitting output
                for idx, step in enumerate(self.steps):
                    if idx == 0:
                        iterator = step.stream(input, config, **kwargs)
                    else:
                        config = patch_config(
                            config,
                            callbacks=run_manager.get_child(f"seq:step:{idx + 1}"),
                        )
                        iterator = step.transform(iterator, config)
                # populates streamed_output in astream_log() output if needed
                if _StreamingCallbackHandler is not None:
                    for h in run_manager.handlers:
                        if isinstance(h, _StreamingCallbackHandler):
                            iterator = h.tap_output_iter(run_manager.run_id, iterator)
                # consume into final output
                output = context.run(_consume_iter, iterator)
                # sequence doesn't emit output, yield to mark as generator
                yield
            except BaseException as e:
                run_manager.on_chain_error(e)
                raise
            else:
                run_manager.on_chain_end(output)

    async def astream(
        self,
        input: Input,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> AsyncIterator[Any]:
        """Async counterpart of `stream`.

        Every intermediate async iterator exposing `aclose` is registered on an
        `AsyncExitStack` so it is closed when the stack exits. Two code paths
        exist, selected by `ASYNCIO_ACCEPTS_CONTEXT`.
        """
        if config is None:
            config = ensure_config()
        # setup callbacks
        callback_manager = get_async_callback_manager_for_config(config)
        # start the root run
        run_manager = await callback_manager.on_chain_start(
            None,
            self.trace_inputs(input) if self.trace_inputs is not None else input,
            name=config.get("run_name") or self.get_name(),
            run_id=config.pop("run_id", None),
        )
        # stream the last steps
        # transform the input stream of each step with the next
        # steps that don't natively support transforming an input stream will
        # buffer input in memory until all available, and then start emitting output
        if ASYNCIO_ACCEPTS_CONTEXT:
            # get the run object
            for h in run_manager.handlers:
                if isinstance(h, LangChainTracer):
                    run = h.run_map.get(str(run_manager.run_id))
                    break
            else:
                run = None
            # create first step config
            config = patch_config(
                config,
                callbacks=run_manager.get_child(f"seq:step:{1}"),
            )
            # run all in context
            with set_config_context(config, run) as context:
                try:
                    async with AsyncExitStack() as stack:
                        for idx, step in enumerate(self.steps):
                            if idx == 0:
                                aiterator = step.astream(input, config, **kwargs)
                            else:
                                config = patch_config(
                                    config,
                                    callbacks=run_manager.get_child(
                                        f"seq:step:{idx + 1}"
                                    ),
                                )
                                aiterator = step.atransform(aiterator, config)
                            if hasattr(aiterator, "aclose"):
                                stack.push_async_callback(aiterator.aclose)
                        # populates streamed_output in astream_log() output if needed
                        if _StreamingCallbackHandler is not None:
                            for h in run_manager.handlers:
                                if isinstance(h, _StreamingCallbackHandler):
                                    aiterator = h.tap_output_aiter(
                                        run_manager.run_id, aiterator
                                    )
                        # consume into final output
                        output = await asyncio.create_task(
                            _consume_aiter(aiterator), context=context
                        )
                        # sequence doesn't emit output, yield to mark as generator
                        yield
                except BaseException as e:
                    await run_manager.on_chain_error(e)
                    raise
                else:
                    await run_manager.on_chain_end(output)
        else:
            try:
                async with AsyncExitStack() as stack:
                    for idx, step in enumerate(self.steps):
                        config = patch_config(
                            config,
                            callbacks=run_manager.get_child(f"seq:step:{idx + 1}"),
                        )
                        if idx == 0:
                            aiterator = step.astream(input, config, **kwargs)
                        else:
                            aiterator = step.atransform(aiterator, config)
                        if hasattr(aiterator, "aclose"):
                            stack.push_async_callback(aiterator.aclose)
                    # populates streamed_output in astream_log() output if needed
                    if _StreamingCallbackHandler is not None:
                        for h in run_manager.handlers:
                            if isinstance(h, _StreamingCallbackHandler):
                                aiterator = h.tap_output_aiter(
                                    run_manager.run_id, aiterator
                                )
                    # consume into final output
                    output = await _consume_aiter(aiterator)
                    # sequence doesn't emit output, yield to mark as generator
                    yield
            except BaseException as e:
                await run_manager.on_chain_error(e)
                raise
            else:
                await run_manager.on_chain_end(output)
def _consume_iter(it: Iterator[Any]) -> Any:
"""Consume an iterator."""
output: Any = None
add_supported = False
for chunk in it:
# collect final output
if output is None:
output = chunk
elif add_supported:
try:
output = output + chunk
except TypeError:
output = chunk
add_supported = False
else:
output = chunk
return output
async def _consume_aiter(it: AsyncIterator[Any]) -> Any:
"""Consume an async iterator."""
output: Any = None
add_supported = False
async for chunk in it:
# collect final output
if add_supported:
try:
output = output + chunk
except TypeError:
output = chunk
add_supported = False
else:
output = chunk
return output

View File

@@ -0,0 +1,19 @@
import dataclasses
from collections.abc import Callable
from typing import Any
from langgraph.types import _DC_KWARGS
@dataclasses.dataclass(**_DC_KWARGS)
class PregelScratchpad:
    # NOTE(review): field meanings below are inferred from names and types
    # only; confirm against the code that builds and reads the scratchpad.
    # presumably the current step number and the step at which to stop
    step: int
    stop: int
    # call
    # zero-argument callable returning an int
    call_counter: Callable[[], int]
    # interrupt
    # zero-argument callable returning an int
    interrupt_counter: Callable[[], int]
    # callable taking a single bool and returning any value
    get_null_resume: Callable[[bool], Any]
    resume: list[Any]
    # subgraph
    # zero-argument callable returning an int
    subgraph_counter: Callable[[], int]

View File

@@ -0,0 +1,253 @@
from __future__ import annotations
import dataclasses
import logging
import sys
import types
from collections import deque
from enum import Enum
from typing import (
Annotated,
Any,
Literal,
Union,
get_args,
get_origin,
get_type_hints,
)
from langchain_core import messages as lc_messages
from langgraph.checkpoint.base import BaseCheckpointSaver
from pydantic import BaseModel
from typing_extensions import NotRequired, Required, is_typeddict
# Optional import: when the symbol cannot be imported, the flag defaults to False.
try:
    from langgraph.checkpoint.serde._msgpack import (  # noqa: F401
        STRICT_MSGPACK_ENABLED,
    )
except ImportError:
    STRICT_MSGPACK_ENABLED = False
# Flipped to True after the "with_allowlist unsupported" warning has been
# logged, so the warning is emitted only once.
_warned_allowlist_unsupported = False
logger = logging.getLogger(__name__)
def _supports_checkpointer_allowlist() -> bool:
    """Return True if `BaseCheckpointSaver` defines a `with_allowlist` attribute."""
    return hasattr(BaseCheckpointSaver, "with_allowlist")


# Evaluated once at import time.
_SUPPORTS_ALLOWLIST = _supports_checkpointer_allowlist()
def apply_checkpointer_allowlist(
    checkpointer: Any, allowlist: set[tuple[str, ...]] | None
) -> Any:
    """Return the checkpointer with the allowlist applied, when that is possible.

    The checkpointer is returned unchanged when it is falsy or a bool, when
    `allowlist` is None, or when `with_allowlist` is unsupported (in which
    case a warning is logged once).
    """
    global _warned_allowlist_unsupported

    if not checkpointer:
        return checkpointer
    if allowlist is None:
        return checkpointer
    if checkpointer in (True, False):
        return checkpointer

    if _SUPPORTS_ALLOWLIST:
        return checkpointer.with_allowlist(allowlist)

    if not _warned_allowlist_unsupported:
        logger.warning(
            "Checkpointer does not support with_allowlist; strict msgpack "
            "allowlist will be skipped."
        )
        _warned_allowlist_unsupported = True
    return checkpointer
def curated_core_allowlist() -> set[tuple[str, ...]]:
    """Return `(module, class name)` pairs for a fixed list of message classes.

    Names that are not present on the messages module are skipped.
    """
    message_class_names = (
        "BaseMessage",
        "BaseMessageChunk",
        "HumanMessage",
        "HumanMessageChunk",
        "AIMessage",
        "AIMessageChunk",
        "SystemMessage",
        "SystemMessageChunk",
        "ChatMessage",
        "ChatMessageChunk",
        "ToolMessage",
        "ToolMessageChunk",
        "FunctionMessage",
        "FunctionMessageChunk",
        "RemoveMessage",
    )
    resolved = (getattr(lc_messages, name, None) for name in message_class_names)
    return {
        (message_cls.__module__, message_cls.__name__)
        for message_cls in resolved
        if message_cls is not None
    }
def build_serde_allowlist(
    *,
    schemas: list[type[Any]] | None = None,
    channels: dict[str, Any] | None = None,
) -> set[tuple[str, ...]]:
    """Combine the curated core allowlist with entries found in schemas and channels.

    Args:
        schemas: Schema types to inspect; None entries are dropped.
        channels: Channels to inspect, passed through unchanged.

    Returns:
        The union of the curated allowlist and the collected entries.
    """
    core = curated_core_allowlist()
    if schemas:
        schemas = [candidate for candidate in schemas if candidate is not None]
    discovered = collect_allowlist_from_schemas(schemas=schemas, channels=channels)
    return core | discovered
def collect_allowlist_from_schemas(
    *,
    schemas: list[type[Any]] | None = None,
    channels: dict[str, Any] | None = None,
) -> set[tuple[str, ...]]:
    """Collect allowlist entries from schema types and channel value/update types.

    Args:
        schemas: Types to walk.
        channels: Mapping whose values are checked for `ValueType` and
            `UpdateType` attributes; each non-None attribute is walked.

    Returns:
        The set of entries gathered while walking those types.
    """
    allowlist: set[tuple[str, ...]] = set()
    seen: set[Any] = set()
    seen_ids: set[int] = set()

    def walk(typ: Any) -> None:
        _collect_from_type(typ, allowlist, seen, seen_ids)

    for schema in schemas or ():
        walk(schema)

    for channel in (channels or {}).values():
        for attr_name in ("ValueType", "UpdateType"):
            declared = getattr(channel, attr_name, None)
            if declared is not None:
                walk(declared)

    return allowlist
def _collect_from_type(
    typ: Any,
    allowlist: set[tuple[str, ...]],
    seen: set[Any],
    seen_ids: set[int],
) -> None:
    """Recursively walk a type annotation and record allowlist entries.

    Pydantic models, dataclasses and Enum subclasses are added to `allowlist`
    as `(module, name)`. Unions, Annotated / Required / NotRequired wrappers,
    the listed built-in containers, NewType-like objects and TypedDicts are
    only traversed. Types already visited are skipped.
    """
    if _already_seen(typ, seen, seen_ids):
        return
    if typ is Any or typ is None or typ is Literal:
        return

    def visit_each(candidates: Any) -> None:
        for candidate in candidates:
            _collect_from_type(candidate, allowlist, seen, seen_ids)

    # `X | Y` unions
    if isinstance(typ, types.UnionType):
        visit_each(typ.__args__)
        return

    origin = get_origin(typ)
    if origin is Union:
        visit_each(get_args(typ))
        return
    if origin is Annotated or origin in (Required, NotRequired):
        # only the first argument (the wrapped type) is followed
        visit_each(get_args(typ)[:1])
        return
    if origin is Literal:
        return
    if origin in (list, set, tuple, dict, deque, frozenset):
        visit_each(get_args(typ))
        return

    if hasattr(typ, "__supertype__"):
        visit_each((typ.__supertype__,))
        return

    if is_typeddict(typ):
        visit_each(_safe_get_type_hints(typ).values())
        return

    if _is_pydantic_model(typ):
        allowlist.add((typ.__module__, typ.__name__))
        hints = _safe_get_type_hints(typ)
        if hints:
            visit_each(hints.values())
        else:
            # hints could not be resolved: use the model's own field metadata
            visit_each(_pydantic_field_types(typ))
        return

    if dataclasses.is_dataclass(typ):
        if typ_name := getattr(typ, "__name__", None):
            allowlist.add((typ.__module__, typ_name))
        hints = _safe_get_type_hints(typ)
        if hints:
            visit_each(hints.values())
        else:
            visit_each(field.type for field in dataclasses.fields(typ))
        return

    if isinstance(typ, type) and issubclass(typ, Enum):
        allowlist.add((typ.__module__, typ.__name__))
def _already_seen(typ: Any, seen: set[Any], seen_ids: set[int]) -> bool:
try:
if typ in seen:
return True
seen.add(typ)
return False
except TypeError:
typ_id = id(typ)
if typ_id in seen_ids:
return True
seen_ids.add(typ_id)
return False
def _safe_get_type_hints(typ: Any) -> dict[str, Any]:
try:
module = sys.modules.get(getattr(typ, "__module__", ""))
globalns = module.__dict__ if module else None
localns = dict(vars(typ)) if hasattr(typ, "__dict__") else None
return get_type_hints(
typ, globalns=globalns, localns=localns, include_extras=True
)
except Exception:
return {}
def _is_pydantic_model(typ: Any) -> bool:
if not isinstance(typ, type):
return False
if issubclass(typ, BaseModel):
return True
try:
from pydantic.v1 import BaseModel as BaseModelV1
except Exception:
return False
return issubclass(typ, BaseModelV1)
def _pydantic_field_types(typ: type[Any]) -> list[Any]:
if hasattr(typ, "model_fields"):
return [
field.annotation
for field in typ.model_fields.values()
if getattr(field, "annotation", None) is not None
]
if hasattr(typ, "__fields__"):
return [
field.outer_type_
for field in typ.__fields__.values()
if getattr(field, "outer_type_", None) is not None
]
return []

View File

@@ -0,0 +1,54 @@
"""Private typing utilities for LangGraph."""
from __future__ import annotations
from dataclasses import Field
from typing import Any, ClassVar, Protocol, TypeAlias
from pydantic import BaseModel
from typing_extensions import TypedDict
class TypedDictLikeV1(Protocol):
    """Protocol to represent types that behave like TypedDicts.

    Version 1: the key sets are declared with `ClassVar`. `TypedDictLikeV2`
    declares the same two attributes without `ClassVar`.
    """

    # Both attributes use the names that `TypedDict` classes expose for their
    # required and optional key sets.
    __required_keys__: ClassVar[frozenset[str]]
    __optional_keys__: ClassVar[frozenset[str]]
class TypedDictLikeV2(Protocol):
    """Protocol to represent types that behave like TypedDicts.

    Version 2: the key sets are declared without `ClassVar`. `TypedDictLikeV1`
    declares the same two attributes wrapped in `ClassVar`.
    """

    # Both attributes use the names that `TypedDict` classes expose for their
    # required and optional key sets.
    __required_keys__: frozenset[str]
    __optional_keys__: frozenset[str]
class DataclassLike(Protocol):
    """Protocol to represent types that behave like dataclasses.

    Inspired by the private _DataclassT from dataclasses that uses a similar
    protocol as a bound.
    """

    # Matching is structural: a type satisfies this protocol by exposing a
    # class-level `__dataclass_fields__` mapping of field name to `Field`.
    __dataclass_fields__: ClassVar[dict[str, Field[Any]]]
# Union of the two TypedDict-like protocols, the dataclass-like protocol and
# pydantic's `BaseModel`.
StateLike: TypeAlias = TypedDictLikeV1 | TypedDictLikeV2 | DataclassLike | BaseModel
"""Type alias for state-like types.
It can either be a `TypedDict`, `dataclass`, or Pydantic `BaseModel`.
Note: we cannot use either `TypedDict` or `dataclass` directly due to limitations in type checking.
"""
# A fresh `object()` instance is distinct from every other value, including
# `None`, so it can mark "nothing was provided"; check for it with `is`.
MISSING = object()
"""Unset sentinel value."""
class DeprecatedKwargs(TypedDict):
    """TypedDict to use for extra keyword arguments.

    It declares no keys, enabling type checking warnings for deprecated
    arguments.
    """
# Shared immutable empty tuple, written as a literal rather than a call.
EMPTY_SEQ: tuple[str, ...] = ()
"""An empty sequence of strings."""