initial commit
This commit is contained in:
4
venv/Lib/site-packages/langgraph/_internal/__init__.py
Normal file
4
venv/Lib/site-packages/langgraph/_internal/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""Internal modules for LangGraph.
|
||||
|
||||
This module is not part of the public API, and thus stability is not guaranteed.
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
31
venv/Lib/site-packages/langgraph/_internal/_cache.py
Normal file
31
venv/Lib/site-packages/langgraph/_internal/_cache.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Hashable, Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _freeze(obj: Any, depth: int = 10) -> Hashable:
|
||||
if isinstance(obj, Hashable) or depth <= 0:
|
||||
# already hashable, no need to freeze
|
||||
return obj
|
||||
elif isinstance(obj, Mapping):
|
||||
# sort keys so {"a":1,"b":2} == {"b":2,"a":1}
|
||||
return tuple(sorted((k, _freeze(v, depth - 1)) for k, v in obj.items()))
|
||||
elif isinstance(obj, Sequence):
|
||||
return tuple(_freeze(x, depth - 1) for x in obj)
|
||||
# numpy / pandas etc. can provide their own .tobytes()
|
||||
elif hasattr(obj, "tobytes"):
|
||||
return (
|
||||
type(obj).__name__,
|
||||
obj.tobytes(),
|
||||
obj.shape if hasattr(obj, "shape") else None,
|
||||
)
|
||||
return obj # strings, ints, dataclasses with frozen=True, etc.
|
||||
|
||||
|
||||
def default_cache_key(*args: Any, **kwargs: Any) -> str | bytes:
|
||||
"""Default cache key function that uses the arguments and keyword arguments to generate a hashable key."""
|
||||
import pickle
|
||||
|
||||
# protocol 5 strikes a good balance between speed and size
|
||||
return pickle.dumps((_freeze(args), _freeze(kwargs)), protocol=5, fix_imports=False)
|
||||
329
venv/Lib/site-packages/langgraph/_internal/_config.py
Normal file
329
venv/Lib/site-packages/langgraph/_internal/_config.py
Normal file
@@ -0,0 +1,329 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import ChainMap
|
||||
from collections.abc import Mapping, Sequence
|
||||
from os import getenv
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManager,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.runnables.config import (
|
||||
CONFIG_KEYS,
|
||||
COPIABLE_KEYS,
|
||||
var_child_runnable_config,
|
||||
)
|
||||
from langgraph.checkpoint.base import CheckpointMetadata
|
||||
|
||||
from langgraph._internal._constants import (
|
||||
CONF,
|
||||
CONFIG_KEY_CHECKPOINT_ID,
|
||||
CONFIG_KEY_CHECKPOINT_MAP,
|
||||
CONFIG_KEY_CHECKPOINT_NS,
|
||||
NS_END,
|
||||
NS_SEP,
|
||||
)
|
||||
|
||||
DEFAULT_RECURSION_LIMIT = int(getenv("LANGGRAPH_DEFAULT_RECURSION_LIMIT", "10000"))
|
||||
|
||||
|
||||
def recast_checkpoint_ns(ns: str) -> str:
|
||||
"""Remove task IDs from checkpoint namespace.
|
||||
|
||||
Args:
|
||||
ns: The checkpoint namespace with task IDs.
|
||||
|
||||
Returns:
|
||||
str: The checkpoint namespace without task IDs.
|
||||
"""
|
||||
return NS_SEP.join(
|
||||
part.split(NS_END)[0] for part in ns.split(NS_SEP) if not part.isdigit()
|
||||
)
|
||||
|
||||
|
||||
def patch_configurable(
|
||||
config: RunnableConfig | None, patch: dict[str, Any]
|
||||
) -> RunnableConfig:
|
||||
if config is None:
|
||||
return {CONF: patch}
|
||||
elif CONF not in config:
|
||||
return {**config, CONF: patch}
|
||||
else:
|
||||
return {**config, CONF: {**config[CONF], **patch}}
|
||||
|
||||
|
||||
def patch_checkpoint_map(
|
||||
config: RunnableConfig | None, metadata: CheckpointMetadata | None
|
||||
) -> RunnableConfig:
|
||||
if config is None:
|
||||
return config
|
||||
elif parents := (metadata.get("parents") if metadata else None):
|
||||
conf = config[CONF]
|
||||
return patch_configurable(
|
||||
config,
|
||||
{
|
||||
CONFIG_KEY_CHECKPOINT_MAP: {
|
||||
**parents,
|
||||
conf[CONFIG_KEY_CHECKPOINT_NS]: conf[CONFIG_KEY_CHECKPOINT_ID],
|
||||
},
|
||||
},
|
||||
)
|
||||
else:
|
||||
return config
|
||||
|
||||
|
||||
def merge_configs(*configs: RunnableConfig | None) -> RunnableConfig:
|
||||
"""Merge multiple configs into one.
|
||||
|
||||
Args:
|
||||
*configs: The configs to merge.
|
||||
|
||||
Returns:
|
||||
RunnableConfig: The merged config.
|
||||
"""
|
||||
base: RunnableConfig = {}
|
||||
# Even though the keys aren't literals, this is correct
|
||||
# because both dicts are the same type
|
||||
for config in configs:
|
||||
if config is None:
|
||||
continue
|
||||
for key, value in config.items():
|
||||
if not value:
|
||||
continue
|
||||
if key == "metadata":
|
||||
if base_value := base.get(key):
|
||||
base[key] = {**base_value, **value} # type: ignore
|
||||
else:
|
||||
base[key] = value # type: ignore[literal-required]
|
||||
elif key == "tags":
|
||||
if base_value := base.get(key):
|
||||
base[key] = [*base_value, *value] # type: ignore
|
||||
else:
|
||||
base[key] = value # type: ignore[literal-required]
|
||||
elif key == CONF:
|
||||
if base_value := base.get(key):
|
||||
base[key] = {**base_value, **value} # type: ignore[dict-item]
|
||||
else:
|
||||
base[key] = value
|
||||
elif key == "callbacks":
|
||||
base_callbacks = base.get("callbacks")
|
||||
# callbacks can be either None, list[handler] or manager
|
||||
# so merging two callbacks values has 6 cases
|
||||
if isinstance(value, list):
|
||||
if base_callbacks is None:
|
||||
base["callbacks"] = value.copy()
|
||||
elif isinstance(base_callbacks, list):
|
||||
base["callbacks"] = base_callbacks + value
|
||||
else:
|
||||
# base_callbacks is a manager
|
||||
mngr = base_callbacks.copy()
|
||||
for callback in value:
|
||||
mngr.add_handler(callback, inherit=True)
|
||||
base["callbacks"] = mngr
|
||||
elif isinstance(value, BaseCallbackManager):
|
||||
# value is a manager
|
||||
if base_callbacks is None:
|
||||
base["callbacks"] = value.copy()
|
||||
elif isinstance(base_callbacks, list):
|
||||
mngr = value.copy()
|
||||
for callback in base_callbacks:
|
||||
mngr.add_handler(callback, inherit=True)
|
||||
base["callbacks"] = mngr
|
||||
else:
|
||||
# base_callbacks is also a manager
|
||||
base["callbacks"] = base_callbacks.merge(value)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
elif key == "recursion_limit":
|
||||
if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT:
|
||||
base["recursion_limit"] = config["recursion_limit"]
|
||||
else:
|
||||
base[key] = config[key] # type: ignore[literal-required]
|
||||
if CONF not in base:
|
||||
base[CONF] = {}
|
||||
return base
|
||||
|
||||
|
||||
def patch_config(
|
||||
config: RunnableConfig | None,
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
recursion_limit: int | None = None,
|
||||
max_concurrency: int | None = None,
|
||||
run_name: str | None = None,
|
||||
configurable: dict[str, Any] | None = None,
|
||||
) -> RunnableConfig:
|
||||
"""Patch a config with new values.
|
||||
|
||||
Args:
|
||||
config: The config to patch.
|
||||
callbacks: The callbacks to set.
|
||||
recursion_limit: The recursion limit to set.
|
||||
max_concurrency: The max number of concurrent steps to run, which also applies to parallelized steps.
|
||||
run_name: The run name to set.
|
||||
configurable: The configurable to set.
|
||||
|
||||
Returns:
|
||||
RunnableConfig: The patched config.
|
||||
"""
|
||||
config = config.copy() if config is not None else {}
|
||||
if callbacks is not None:
|
||||
# If we're replacing callbacks, we need to unset run_name
|
||||
# As that should apply only to the same run as the original callbacks
|
||||
config["callbacks"] = callbacks
|
||||
if "run_name" in config:
|
||||
del config["run_name"]
|
||||
if "run_id" in config:
|
||||
del config["run_id"]
|
||||
if recursion_limit is not None:
|
||||
config["recursion_limit"] = recursion_limit
|
||||
if max_concurrency is not None:
|
||||
config["max_concurrency"] = max_concurrency
|
||||
if run_name is not None:
|
||||
config["run_name"] = run_name
|
||||
if configurable is not None:
|
||||
config[CONF] = {**config.get(CONF, {}), **configurable}
|
||||
return config
|
||||
|
||||
|
||||
def get_callback_manager_for_config(
|
||||
config: RunnableConfig, tags: Sequence[str] | None = None
|
||||
) -> CallbackManager:
|
||||
"""Get a callback manager for a config.
|
||||
|
||||
Args:
|
||||
config: The config.
|
||||
|
||||
Returns:
|
||||
CallbackManager: The callback manager.
|
||||
"""
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
|
||||
# merge tags
|
||||
all_tags = config.get("tags")
|
||||
if all_tags is not None and tags is not None:
|
||||
all_tags = [*all_tags, *tags]
|
||||
elif tags is not None:
|
||||
all_tags = list(tags)
|
||||
# use existing callbacks if they exist
|
||||
if (callbacks := config.get("callbacks")) and isinstance(
|
||||
callbacks, CallbackManager
|
||||
):
|
||||
if all_tags:
|
||||
callbacks.add_tags(all_tags)
|
||||
if metadata := config.get("metadata"):
|
||||
callbacks.add_metadata(metadata)
|
||||
return callbacks
|
||||
else:
|
||||
# otherwise create a new manager
|
||||
return CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=all_tags,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
|
||||
|
||||
def get_async_callback_manager_for_config(
|
||||
config: RunnableConfig,
|
||||
tags: Sequence[str] | None = None,
|
||||
) -> AsyncCallbackManager:
|
||||
"""Get an async callback manager for a config.
|
||||
|
||||
Args:
|
||||
config: The config.
|
||||
|
||||
Returns:
|
||||
AsyncCallbackManager: The async callback manager.
|
||||
"""
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
# merge tags
|
||||
all_tags = config.get("tags")
|
||||
if all_tags is not None and tags is not None:
|
||||
all_tags = [*all_tags, *tags]
|
||||
elif tags is not None:
|
||||
all_tags = list(tags)
|
||||
# use existing callbacks if they exist
|
||||
if (callbacks := config.get("callbacks")) and isinstance(
|
||||
callbacks, AsyncCallbackManager
|
||||
):
|
||||
if all_tags:
|
||||
callbacks.add_tags(all_tags)
|
||||
if metadata := config.get("metadata"):
|
||||
callbacks.add_metadata(metadata)
|
||||
return callbacks
|
||||
else:
|
||||
# otherwise create a new manager
|
||||
return AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=all_tags,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
|
||||
|
||||
def _is_not_empty(value: Any) -> bool:
|
||||
if isinstance(value, (list, tuple, dict)):
|
||||
return len(value) > 0
|
||||
else:
|
||||
return value is not None
|
||||
|
||||
|
||||
def ensure_config(*configs: RunnableConfig | None) -> RunnableConfig:
|
||||
"""Return a config with all keys, merging any provided configs.
|
||||
|
||||
Args:
|
||||
*configs: Configs to merge before ensuring defaults.
|
||||
|
||||
Returns:
|
||||
RunnableConfig: The merged and ensured config.
|
||||
"""
|
||||
empty = RunnableConfig(
|
||||
tags=[],
|
||||
metadata=ChainMap(),
|
||||
callbacks=None,
|
||||
recursion_limit=DEFAULT_RECURSION_LIMIT,
|
||||
configurable={},
|
||||
)
|
||||
if var_config := var_child_runnable_config.get():
|
||||
empty.update(
|
||||
{
|
||||
k: v.copy() if k in COPIABLE_KEYS else v # type: ignore[attr-defined]
|
||||
for k, v in var_config.items()
|
||||
if _is_not_empty(v)
|
||||
},
|
||||
)
|
||||
for config in configs:
|
||||
if config is None:
|
||||
continue
|
||||
for k, v in config.items():
|
||||
if _is_not_empty(v) and k in CONFIG_KEYS:
|
||||
if k == CONF:
|
||||
empty[k] = cast(dict, v).copy()
|
||||
else:
|
||||
empty[k] = v # type: ignore[literal-required]
|
||||
for k, v in config.items():
|
||||
if _is_not_empty(v) and k not in CONFIG_KEYS:
|
||||
empty[CONF][k] = v
|
||||
_empty_metadata = empty["metadata"]
|
||||
for key, value in empty[CONF].items():
|
||||
if _exclude_as_metadata(key, value, _empty_metadata):
|
||||
continue
|
||||
_empty_metadata[key] = value
|
||||
return empty
|
||||
|
||||
|
||||
_OMIT = ("key", "token", "secret", "password", "auth")
|
||||
|
||||
|
||||
def _exclude_as_metadata(key: str, value: Any, metadata: Mapping[str, Any]) -> bool:
|
||||
key_lower = key.casefold()
|
||||
return (
|
||||
key.startswith("__")
|
||||
or not isinstance(value, (str, int, float, bool))
|
||||
or key in metadata
|
||||
or any(substr in key_lower for substr in _OMIT)
|
||||
)
|
||||
112
venv/Lib/site-packages/langgraph/_internal/_constants.py
Normal file
112
venv/Lib/site-packages/langgraph/_internal/_constants.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Constants used for Pregel operations."""
|
||||
|
||||
import sys
|
||||
from typing import Literal, cast
|
||||
|
||||
# --- Reserved write keys ---
|
||||
INPUT = sys.intern("__input__")
|
||||
# for values passed as input to the graph
|
||||
INTERRUPT = sys.intern("__interrupt__")
|
||||
# for dynamic interrupts raised by nodes
|
||||
RESUME = sys.intern("__resume__")
|
||||
# for values passed to resume a node after an interrupt
|
||||
ERROR = sys.intern("__error__")
|
||||
# for errors raised by nodes
|
||||
NO_WRITES = sys.intern("__no_writes__")
|
||||
# marker to signal node didn't write anything
|
||||
TASKS = sys.intern("__pregel_tasks")
|
||||
# for Send objects returned by nodes/edges, corresponds to PUSH below
|
||||
RETURN = sys.intern("__return__")
|
||||
# for writes of a task where we simply record the return value
|
||||
PREVIOUS = sys.intern("__previous__")
|
||||
# the implicit branch that handles each node's Control values
|
||||
|
||||
|
||||
# --- Reserved cache namespaces ---
|
||||
CACHE_NS_WRITES = sys.intern("__pregel_ns_writes")
|
||||
# cache namespace for node writes
|
||||
|
||||
# --- Reserved config.configurable keys ---
|
||||
CONFIG_KEY_SEND = sys.intern("__pregel_send")
|
||||
# holds the `write` function that accepts writes to state/edges/reserved keys
|
||||
CONFIG_KEY_READ = sys.intern("__pregel_read")
|
||||
# holds the `read` function that returns a copy of the current state
|
||||
CONFIG_KEY_CALL = sys.intern("__pregel_call")
|
||||
# holds the `call` function that accepts a node/func, args and returns a future
|
||||
CONFIG_KEY_CHECKPOINTER = sys.intern("__pregel_checkpointer")
|
||||
# holds a `BaseCheckpointSaver` passed from parent graph to child graphs
|
||||
CONFIG_KEY_STREAM = sys.intern("__pregel_stream")
|
||||
# holds a `StreamProtocol` passed from parent graph to child graphs
|
||||
CONFIG_KEY_CACHE = sys.intern("__pregel_cache")
|
||||
# holds a `BaseCache` made available to subgraphs
|
||||
CONFIG_KEY_RESUMING = sys.intern("__pregel_resuming")
|
||||
# holds a boolean indicating if subgraphs should resume from a previous checkpoint
|
||||
CONFIG_KEY_TASK_ID = sys.intern("__pregel_task_id")
|
||||
# holds the task ID for the current task
|
||||
CONFIG_KEY_THREAD_ID = sys.intern("thread_id")
|
||||
# holds the thread ID for the current invocation
|
||||
CONFIG_KEY_CHECKPOINT_MAP = sys.intern("checkpoint_map")
|
||||
# holds a mapping of checkpoint_ns -> checkpoint_id for parent graphs
|
||||
CONFIG_KEY_CHECKPOINT_ID = sys.intern("checkpoint_id")
|
||||
# holds the current checkpoint_id, if any
|
||||
CONFIG_KEY_CHECKPOINT_NS = sys.intern("checkpoint_ns")
|
||||
# holds the current checkpoint_ns, "" for root graph
|
||||
CONFIG_KEY_NODE_FINISHED = sys.intern("__pregel_node_finished")
|
||||
# holds a callback to be called when a node is finished
|
||||
CONFIG_KEY_SCRATCHPAD = sys.intern("__pregel_scratchpad")
|
||||
# holds a mutable dict for temporary storage scoped to the current task
|
||||
CONFIG_KEY_RUNNER_SUBMIT = sys.intern("__pregel_runner_submit")
|
||||
# holds a function that receives tasks from runner, executes them and returns results
|
||||
CONFIG_KEY_DURABILITY = sys.intern("__pregel_durability")
|
||||
# holds the durability mode, one of "sync", "async", or "exit"
|
||||
CONFIG_KEY_RUNTIME = sys.intern("__pregel_runtime")
|
||||
# holds a `Runtime` instance with context, store, stream writer, etc.
|
||||
CONFIG_KEY_RESUME_MAP = sys.intern("__pregel_resume_map")
|
||||
# holds a mapping of task ns -> resume value for resuming tasks
|
||||
|
||||
# --- Other constants ---
|
||||
PUSH = sys.intern("__pregel_push")
|
||||
# denotes push-style tasks, ie. those created by Send objects
|
||||
PULL = sys.intern("__pregel_pull")
|
||||
# denotes pull-style tasks, ie. those triggered by edges
|
||||
NS_SEP = sys.intern("|")
|
||||
# for checkpoint_ns, separates each level (ie. graph|subgraph|subsubgraph)
|
||||
NS_END = sys.intern(":")
|
||||
# for checkpoint_ns, for each level, separates the namespace from the task_id
|
||||
CONF = cast(Literal["configurable"], sys.intern("configurable"))
|
||||
# key for the configurable dict in RunnableConfig
|
||||
NULL_TASK_ID = sys.intern("00000000-0000-0000-0000-000000000000")
|
||||
# the task_id to use for writes that are not associated with a task
|
||||
OVERWRITE = sys.intern("__overwrite__")
|
||||
# dict key for the overwrite value, used as `{'__overwrite__': value}`
|
||||
|
||||
# redefined to avoid circular import with langgraph.constants
|
||||
_TAG_HIDDEN = sys.intern("langsmith:hidden")
|
||||
|
||||
RESERVED = {
|
||||
_TAG_HIDDEN,
|
||||
# reserved write keys
|
||||
INPUT,
|
||||
INTERRUPT,
|
||||
RESUME,
|
||||
ERROR,
|
||||
NO_WRITES,
|
||||
# reserved config.configurable keys
|
||||
CONFIG_KEY_SEND,
|
||||
CONFIG_KEY_READ,
|
||||
CONFIG_KEY_CHECKPOINTER,
|
||||
CONFIG_KEY_STREAM,
|
||||
CONFIG_KEY_CHECKPOINT_MAP,
|
||||
CONFIG_KEY_RESUMING,
|
||||
CONFIG_KEY_TASK_ID,
|
||||
CONFIG_KEY_CHECKPOINT_MAP,
|
||||
CONFIG_KEY_CHECKPOINT_ID,
|
||||
CONFIG_KEY_CHECKPOINT_NS,
|
||||
CONFIG_KEY_RESUME_MAP,
|
||||
# other constants
|
||||
PUSH,
|
||||
PULL,
|
||||
NS_SEP,
|
||||
NS_END,
|
||||
CONF,
|
||||
}
|
||||
213
venv/Lib/site-packages/langgraph/_internal/_fields.py
Normal file
213
venv/Lib/site-packages/langgraph/_internal/_fields.py
Normal file
@@ -0,0 +1,213 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import types
|
||||
import weakref
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Annotated, Any, Optional, Union, get_origin, get_type_hints
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import NotRequired, ReadOnly, Required
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
|
||||
|
||||
def _is_optional_type(type_: Any) -> bool:
|
||||
"""Check if a type is Optional."""
|
||||
|
||||
# Handle new union syntax (PEP 604): str | None
|
||||
if isinstance(type_, types.UnionType):
|
||||
return any(
|
||||
arg is type(None) or _is_optional_type(arg) for arg in type_.__args__
|
||||
)
|
||||
|
||||
if hasattr(type_, "__origin__") and hasattr(type_, "__args__"):
|
||||
origin = get_origin(type_)
|
||||
if origin is Optional:
|
||||
return True
|
||||
if origin is Union:
|
||||
return any(
|
||||
arg is type(None) or _is_optional_type(arg) for arg in type_.__args__
|
||||
)
|
||||
if origin is Annotated:
|
||||
return _is_optional_type(type_.__args__[0])
|
||||
return origin is None
|
||||
if hasattr(type_, "__bound__") and type_.__bound__ is not None:
|
||||
return _is_optional_type(type_.__bound__)
|
||||
return type_ is None
|
||||
|
||||
|
||||
def _is_required_type(type_: Any) -> bool | None:
|
||||
"""Check if an annotation is marked as Required/NotRequired.
|
||||
|
||||
Returns:
|
||||
- True if required
|
||||
- False if not required
|
||||
- None if not annotated with either
|
||||
"""
|
||||
origin = get_origin(type_)
|
||||
if origin is Required:
|
||||
return True
|
||||
if origin is NotRequired:
|
||||
return False
|
||||
if origin is Annotated or getattr(origin, "__args__", None):
|
||||
# See https://typing.readthedocs.io/en/latest/spec/typeddict.html#interaction-with-annotated
|
||||
return _is_required_type(type_.__args__[0])
|
||||
return None
|
||||
|
||||
|
||||
def _is_readonly_type(type_: Any) -> bool:
|
||||
"""Check if an annotation is marked as ReadOnly.
|
||||
|
||||
Returns:
|
||||
- True if is read only
|
||||
- False if not read only
|
||||
"""
|
||||
|
||||
# See: https://typing.readthedocs.io/en/latest/spec/typeddict.html#typing-readonly-type-qualifier
|
||||
origin = get_origin(type_)
|
||||
if origin is Annotated:
|
||||
return _is_readonly_type(type_.__args__[0])
|
||||
if origin is ReadOnly:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
_DEFAULT_KEYS: frozenset[str] = frozenset()
|
||||
|
||||
|
||||
def get_field_default(name: str, type_: Any, schema: type[Any]) -> Any:
|
||||
"""Determine the default value for a field in a state schema.
|
||||
|
||||
This is based on:
|
||||
If TypedDict:
|
||||
- Required/NotRequired
|
||||
- total=False -> everything optional
|
||||
- Type annotation (Optional/Union[None])
|
||||
"""
|
||||
optional_keys = getattr(schema, "__optional_keys__", _DEFAULT_KEYS)
|
||||
irq = _is_required_type(type_)
|
||||
if name in optional_keys:
|
||||
# Either total=False or explicit NotRequired.
|
||||
# No type annotation trumps this.
|
||||
if irq:
|
||||
# Unless it's earlier versions of python & explicit Required
|
||||
return ...
|
||||
return None
|
||||
if irq is not None:
|
||||
if irq:
|
||||
# Handle Required[<type>]
|
||||
# (we already handled NotRequired and total=False)
|
||||
return ...
|
||||
# Handle NotRequired[<type>] for earlier versions of python
|
||||
return None
|
||||
if dataclasses.is_dataclass(schema):
|
||||
field_info = next(
|
||||
(f for f in dataclasses.fields(schema) if f.name == name), None
|
||||
)
|
||||
if field_info:
|
||||
if (
|
||||
field_info.default is not dataclasses.MISSING
|
||||
and field_info.default is not ...
|
||||
):
|
||||
return field_info.default
|
||||
elif field_info.default_factory is not dataclasses.MISSING:
|
||||
return field_info.default_factory()
|
||||
# Note, we ignore ReadOnly attributes,
|
||||
# as they don't make much sense. (we don't care if you mutate the state in your node)
|
||||
# and mutating state in your node has no effect on our graph state.
|
||||
# Base case is the annotation
|
||||
if _is_optional_type(type_):
|
||||
return None
|
||||
return ...
|
||||
|
||||
|
||||
def get_enhanced_type_hints(
|
||||
type: type[Any],
|
||||
) -> Generator[tuple[str, Any, Any, str | None], None, None]:
|
||||
"""Attempt to extract default values and descriptions from provided type, used for config schema."""
|
||||
for name, typ in get_type_hints(type).items():
|
||||
default = None
|
||||
description = None
|
||||
|
||||
# Pydantic models
|
||||
try:
|
||||
if hasattr(type, "model_fields") and name in type.model_fields:
|
||||
field = type.model_fields[name]
|
||||
|
||||
if hasattr(field, "description") and field.description is not None:
|
||||
description = field.description
|
||||
|
||||
if hasattr(field, "default") and field.default is not None:
|
||||
default = field.default
|
||||
if (
|
||||
hasattr(default, "__class__")
|
||||
and getattr(default.__class__, "__name__", "")
|
||||
== "PydanticUndefinedType"
|
||||
):
|
||||
default = None
|
||||
|
||||
except (AttributeError, KeyError, TypeError):
|
||||
pass
|
||||
|
||||
# TypedDict, dataclass
|
||||
try:
|
||||
if hasattr(type, "__dict__"):
|
||||
type_dict = getattr(type, "__dict__")
|
||||
|
||||
if name in type_dict:
|
||||
default = type_dict[name]
|
||||
except (AttributeError, KeyError, TypeError):
|
||||
pass
|
||||
|
||||
yield name, typ, default, description
|
||||
|
||||
|
||||
def get_update_as_tuples(input: Any, keys: Sequence[str]) -> list[tuple[str, Any]]:
|
||||
"""Get Pydantic state update as a list of (key, value) tuples."""
|
||||
if isinstance(input, BaseModel):
|
||||
keep = input.model_fields_set
|
||||
defaults = {k: v.default for k, v in type(input).model_fields.items()}
|
||||
else:
|
||||
keep = None
|
||||
defaults = {}
|
||||
|
||||
# NOTE: This behavior for Pydantic is somewhat inelegant,
|
||||
# but we keep around for backwards compatibility
|
||||
# if input is a Pydantic model, only update values
|
||||
# that are different from the default values or in the keep set
|
||||
return [
|
||||
(k, value)
|
||||
for k in keys
|
||||
if (value := getattr(input, k, MISSING)) is not MISSING
|
||||
and (
|
||||
value is not None
|
||||
or defaults.get(k, MISSING) is not None
|
||||
or (keep is not None and k in keep)
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
ANNOTATED_KEYS_CACHE: weakref.WeakKeyDictionary[type[Any], tuple[str, ...]] = (
|
||||
weakref.WeakKeyDictionary()
|
||||
)
|
||||
|
||||
|
||||
def get_cached_annotated_keys(obj: type[Any]) -> tuple[str, ...]:
|
||||
"""Return cached annotated keys for a Python class."""
|
||||
if obj in ANNOTATED_KEYS_CACHE:
|
||||
return ANNOTATED_KEYS_CACHE[obj]
|
||||
if isinstance(obj, type):
|
||||
keys: list[str] = []
|
||||
for base in reversed(obj.__mro__):
|
||||
ann = base.__dict__.get("__annotations__")
|
||||
# In Python 3.14+, Pydantic models use descriptors for __annotations__
|
||||
# so we need to fall back to getattr if __dict__.get returns None
|
||||
if ann is None:
|
||||
ann = getattr(base, "__annotations__", None)
|
||||
if ann is None or isinstance(ann, types.GetSetDescriptorType):
|
||||
continue
|
||||
keys.extend(ann.keys())
|
||||
return ANNOTATED_KEYS_CACHE.setdefault(obj, tuple(keys))
|
||||
else:
|
||||
raise TypeError(f"Expected a type, got {type(obj)}. ")
|
||||
220
venv/Lib/site-packages/langgraph/_internal/_future.py
Normal file
220
venv/Lib/site-packages/langgraph/_internal/_future.py
Normal file
@@ -0,0 +1,220 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import contextvars
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
from collections.abc import Awaitable, Coroutine, Generator
|
||||
from typing import TypeVar, cast
|
||||
|
||||
T = TypeVar("T")
|
||||
AnyFuture = asyncio.Future | concurrent.futures.Future
|
||||
|
||||
CONTEXT_NOT_SUPPORTED = sys.version_info < (3, 11)
|
||||
EAGER_NOT_SUPPORTED = sys.version_info < (3, 12)
|
||||
|
||||
|
||||
def _get_loop(fut: asyncio.Future) -> asyncio.AbstractEventLoop:
|
||||
# Tries to call Future.get_loop() if it's available.
|
||||
# Otherwise fallbacks to using the old '_loop' property.
|
||||
try:
|
||||
get_loop = fut.get_loop
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
return get_loop()
|
||||
return fut._loop
|
||||
|
||||
|
||||
def _convert_future_exc(exc: BaseException) -> BaseException:
|
||||
exc_class = type(exc)
|
||||
if exc_class is concurrent.futures.CancelledError:
|
||||
return asyncio.CancelledError(*exc.args)
|
||||
elif exc_class is concurrent.futures.TimeoutError:
|
||||
return asyncio.TimeoutError(*exc.args)
|
||||
elif exc_class is concurrent.futures.InvalidStateError:
|
||||
return asyncio.InvalidStateError(*exc.args)
|
||||
else:
|
||||
return exc
|
||||
|
||||
|
||||
def _set_concurrent_future_state(
|
||||
concurrent: concurrent.futures.Future,
|
||||
source: AnyFuture,
|
||||
) -> None:
|
||||
"""Copy state from a future to a concurrent.futures.Future."""
|
||||
assert source.done()
|
||||
if source.cancelled():
|
||||
concurrent.cancel()
|
||||
if not concurrent.set_running_or_notify_cancel():
|
||||
return
|
||||
exception = source.exception()
|
||||
if exception is not None:
|
||||
concurrent.set_exception(_convert_future_exc(exception))
|
||||
else:
|
||||
result = source.result()
|
||||
concurrent.set_result(result)
|
||||
|
||||
|
||||
def _copy_future_state(source: AnyFuture, dest: asyncio.Future) -> None:
|
||||
"""Internal helper to copy state from another Future.
|
||||
|
||||
The other Future may be a concurrent.futures.Future.
|
||||
"""
|
||||
if dest.done():
|
||||
return
|
||||
assert source.done()
|
||||
if dest.cancelled():
|
||||
return
|
||||
if source.cancelled():
|
||||
dest.cancel()
|
||||
else:
|
||||
exception = source.exception()
|
||||
if exception is not None:
|
||||
dest.set_exception(_convert_future_exc(exception))
|
||||
else:
|
||||
result = source.result()
|
||||
dest.set_result(result)
|
||||
|
||||
|
||||
def _chain_future(source: AnyFuture, destination: AnyFuture) -> None:
|
||||
"""Chain two futures so that when one completes, so does the other.
|
||||
|
||||
The result (or exception) of source will be copied to destination.
|
||||
If destination is cancelled, source gets cancelled too.
|
||||
Compatible with both asyncio.Future and concurrent.futures.Future.
|
||||
"""
|
||||
if not asyncio.isfuture(source) and not isinstance(
|
||||
source, concurrent.futures.Future
|
||||
):
|
||||
raise TypeError("A future is required for source argument")
|
||||
if not asyncio.isfuture(destination) and not isinstance(
|
||||
destination, concurrent.futures.Future
|
||||
):
|
||||
raise TypeError("A future is required for destination argument")
|
||||
source_loop = _get_loop(source) if asyncio.isfuture(source) else None
|
||||
dest_loop = _get_loop(destination) if asyncio.isfuture(destination) else None
|
||||
|
||||
def _set_state(future: AnyFuture, other: AnyFuture) -> None:
|
||||
if asyncio.isfuture(future):
|
||||
_copy_future_state(other, future)
|
||||
else:
|
||||
_set_concurrent_future_state(future, other)
|
||||
|
||||
def _call_check_cancel(destination: AnyFuture) -> None:
|
||||
if destination.cancelled():
|
||||
if source_loop is None or source_loop is dest_loop:
|
||||
source.cancel()
|
||||
else:
|
||||
source_loop.call_soon_threadsafe(source.cancel)
|
||||
|
||||
def _call_set_state(source: AnyFuture) -> None:
|
||||
if destination.cancelled() and dest_loop is not None and dest_loop.is_closed():
|
||||
return
|
||||
if dest_loop is None or dest_loop is source_loop:
|
||||
_set_state(destination, source)
|
||||
else:
|
||||
if dest_loop.is_closed():
|
||||
return
|
||||
dest_loop.call_soon_threadsafe(_set_state, destination, source)
|
||||
|
||||
destination.add_done_callback(_call_check_cancel)
|
||||
source.add_done_callback(_call_set_state)
|
||||
|
||||
|
||||
def chain_future(source: AnyFuture, destination: AnyFuture) -> AnyFuture:
|
||||
# adapted from asyncio.run_coroutine_threadsafe
|
||||
try:
|
||||
_chain_future(source, destination)
|
||||
return destination
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
raise
|
||||
except BaseException as exc:
|
||||
if isinstance(destination, concurrent.futures.Future):
|
||||
if destination.set_running_or_notify_cancel():
|
||||
destination.set_exception(exc)
|
||||
else:
|
||||
destination.set_exception(exc)
|
||||
raise
|
||||
|
||||
|
||||
def _ensure_future(
|
||||
coro_or_future: Coroutine[None, None, T] | Awaitable[T],
|
||||
*,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
name: str | None = None,
|
||||
context: contextvars.Context | None = None,
|
||||
lazy: bool = True,
|
||||
) -> asyncio.Task[T]:
|
||||
called_wrap_awaitable = False
|
||||
if not asyncio.iscoroutine(coro_or_future):
|
||||
if inspect.isawaitable(coro_or_future):
|
||||
coro_or_future = cast(
|
||||
Coroutine[None, None, T], _wrap_awaitable(coro_or_future)
|
||||
)
|
||||
called_wrap_awaitable = True
|
||||
else:
|
||||
raise TypeError(
|
||||
"An asyncio.Future, a coroutine or an awaitable is required."
|
||||
f" Got {type(coro_or_future).__name__} instead."
|
||||
)
|
||||
|
||||
try:
|
||||
if CONTEXT_NOT_SUPPORTED:
|
||||
return loop.create_task(coro_or_future, name=name)
|
||||
elif EAGER_NOT_SUPPORTED or lazy:
|
||||
return loop.create_task(coro_or_future, name=name, context=context)
|
||||
else:
|
||||
return asyncio.eager_task_factory(
|
||||
loop, coro_or_future, name=name, context=context
|
||||
)
|
||||
except RuntimeError:
|
||||
if not called_wrap_awaitable:
|
||||
coro_or_future.close()
|
||||
raise
|
||||
|
||||
|
||||
@types.coroutine
|
||||
def _wrap_awaitable(awaitable: Awaitable[T]) -> Generator[None, None, T]:
|
||||
"""Helper for asyncio.ensure_future().
|
||||
|
||||
Wraps awaitable (an object with __await__) into a coroutine
|
||||
that will later be wrapped in a Task by ensure_future().
|
||||
"""
|
||||
return (yield from awaitable.__await__())
|
||||
|
||||
|
||||
def run_coroutine_threadsafe(
|
||||
coro: Coroutine[None, None, T],
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
*,
|
||||
lazy: bool,
|
||||
name: str | None = None,
|
||||
context: contextvars.Context | None = None,
|
||||
) -> asyncio.Future[T]:
|
||||
"""Submit a coroutine object to a given event loop.
|
||||
|
||||
Return an asyncio.Future to access the result.
|
||||
"""
|
||||
|
||||
if asyncio._get_running_loop() is loop:
|
||||
return _ensure_future(coro, loop=loop, name=name, context=context, lazy=lazy)
|
||||
else:
|
||||
future: asyncio.Future[T] = asyncio.Future(loop=loop)
|
||||
|
||||
def callback() -> None:
|
||||
try:
|
||||
chain_future(
|
||||
_ensure_future(coro, loop=loop, name=name, context=context),
|
||||
future,
|
||||
)
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
raise
|
||||
except BaseException as exc:
|
||||
future.set_exception(exc)
|
||||
raise
|
||||
|
||||
loop.call_soon_threadsafe(callback, context=context)
|
||||
return future
|
||||
275
venv/Lib/site-packages/langgraph/_internal/_pydantic.py
Normal file
275
venv/Lib/site-packages/langgraph/_internal/_pydantic.py
Normal file
@@ -0,0 +1,275 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import is_dataclass
|
||||
from functools import lru_cache
|
||||
from typing import (
|
||||
Any,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
)
|
||||
from pydantic import (
|
||||
create_model as _create_model_base,
|
||||
)
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic.json_schema import (
|
||||
DEFAULT_REF_TEMPLATE,
|
||||
GenerateJsonSchema,
|
||||
JsonSchemaMode,
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
@overload
|
||||
def get_fields(model: type[BaseModel]) -> dict[str, FieldInfo]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def get_fields(model: BaseModel) -> dict[str, FieldInfo]: ...
|
||||
|
||||
|
||||
def get_fields(
|
||||
model: type[BaseModel] | BaseModel,
|
||||
) -> dict[str, FieldInfo]:
|
||||
"""Get the field names of a Pydantic model."""
|
||||
if hasattr(model, "model_fields"):
|
||||
return model.model_fields
|
||||
|
||||
if hasattr(model, "__fields__"):
|
||||
return model.__fields__
|
||||
msg = f"Expected a Pydantic model. Got {type(model)}"
|
||||
raise TypeError(msg)
|
||||
|
||||
|
||||
_SchemaConfig = ConfigDict(
|
||||
arbitrary_types_allowed=True, frozen=True, protected_namespaces=()
|
||||
)
|
||||
|
||||
NO_DEFAULT = object()
|
||||
|
||||
|
||||
def _create_root_model(
|
||||
name: str,
|
||||
type_: Any,
|
||||
module_name: str | None = None,
|
||||
default_: object = NO_DEFAULT,
|
||||
) -> type[BaseModel]:
|
||||
"""Create a base class."""
|
||||
|
||||
def schema(
|
||||
cls: type[BaseModel],
|
||||
by_alias: bool = True, # noqa: FBT001,FBT002
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
) -> dict[str, Any]:
|
||||
# Complains about schema not being defined in superclass
|
||||
schema_ = super(cls, cls).schema( # type: ignore[misc]
|
||||
by_alias=by_alias, ref_template=ref_template
|
||||
)
|
||||
schema_["title"] = name
|
||||
return schema_
|
||||
|
||||
def model_json_schema(
|
||||
cls: type[BaseModel],
|
||||
by_alias: bool = True, # noqa: FBT001,FBT002
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
|
||||
mode: JsonSchemaMode = "validation",
|
||||
) -> dict[str, Any]:
|
||||
# Complains about model_json_schema not being defined in superclass
|
||||
schema_ = super(cls, cls).model_json_schema( # type: ignore[misc]
|
||||
by_alias=by_alias,
|
||||
ref_template=ref_template,
|
||||
schema_generator=schema_generator,
|
||||
mode=mode,
|
||||
)
|
||||
schema_["title"] = name
|
||||
return schema_
|
||||
|
||||
base_class_attributes = {
|
||||
"__annotations__": {"root": type_},
|
||||
"model_config": ConfigDict(arbitrary_types_allowed=True),
|
||||
"schema": classmethod(schema),
|
||||
"model_json_schema": classmethod(model_json_schema),
|
||||
"__module__": module_name or "langchain_core.runnables.utils",
|
||||
}
|
||||
|
||||
if default_ is not NO_DEFAULT:
|
||||
base_class_attributes["root"] = default_
|
||||
with warnings.catch_warnings():
|
||||
custom_root_type = type(name, (RootModel,), base_class_attributes)
|
||||
return cast("type[BaseModel]", custom_root_type)
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def _create_root_model_cached(
|
||||
model_name: str,
|
||||
type_: Any,
|
||||
*,
|
||||
module_name: str | None = None,
|
||||
default_: object = NO_DEFAULT,
|
||||
) -> type[BaseModel]:
|
||||
return _create_root_model(
|
||||
model_name, type_, default_=default_, module_name=module_name
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def _create_model_cached(
|
||||
model_name: str,
|
||||
/,
|
||||
**field_definitions: Any,
|
||||
) -> type[BaseModel]:
|
||||
return _create_model_base(
|
||||
model_name,
|
||||
__config__=_SchemaConfig,
|
||||
**_remap_field_definitions(field_definitions),
|
||||
)
|
||||
|
||||
|
||||
# Reserved names should capture all the `public` names / methods that are
|
||||
# used by BaseModel internally. This will keep the reserved names up-to-date.
|
||||
# For reference, the reserved names are:
|
||||
# "construct", "copy", "dict", "from_orm", "json", "parse_file", "parse_obj",
|
||||
# "parse_raw", "schema", "schema_json", "update_forward_refs", "validate",
|
||||
# "model_computed_fields", "model_config", "model_construct", "model_copy",
|
||||
# "model_dump", "model_dump_json", "model_extra", "model_fields",
|
||||
# "model_fields_set", "model_json_schema", "model_parametrized_name",
|
||||
# "model_post_init", "model_rebuild", "model_validate", "model_validate_json",
|
||||
# "model_validate_strings"
|
||||
_RESERVED_NAMES = {key for key in dir(BaseModel) if not key.startswith("_")}
|
||||
|
||||
|
||||
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
|
||||
"""This remaps fields to avoid colliding with internal pydantic fields."""
|
||||
|
||||
remapped = {}
|
||||
for key, value in field_definitions.items():
|
||||
if key.startswith("_") or key in _RESERVED_NAMES:
|
||||
# Let's add a prefix to avoid colliding with internal pydantic fields
|
||||
if isinstance(value, FieldInfo):
|
||||
msg = (
|
||||
f"Remapping for fields starting with '_' or fields with a name "
|
||||
f"matching a reserved name {_RESERVED_NAMES} is not supported if "
|
||||
f" the field is a pydantic Field instance. Got {key}."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
type_, default_ = value
|
||||
remapped[f"private_{key}"] = (
|
||||
type_,
|
||||
Field(
|
||||
default=default_,
|
||||
alias=key,
|
||||
serialization_alias=key,
|
||||
title=key.lstrip("_").replace("_", " ").title(),
|
||||
),
|
||||
)
|
||||
else:
|
||||
remapped[key] = value
|
||||
return remapped
|
||||
|
||||
|
||||
def create_model(
|
||||
model_name: str,
|
||||
*,
|
||||
field_definitions: dict[str, Any] | None = None,
|
||||
root: Any | None = None,
|
||||
) -> type[BaseModel]:
|
||||
"""Create a pydantic model with the given field definitions.
|
||||
|
||||
Attention:
|
||||
Please do not use outside of langchain packages. This API
|
||||
is subject to change at any time.
|
||||
|
||||
Args:
|
||||
model_name: The name of the model.
|
||||
module_name: The name of the module where the model is defined.
|
||||
This is used by Pydantic to resolve any forward references.
|
||||
field_definitions: The field definitions for the model.
|
||||
root: Type for a root model (RootModel)
|
||||
|
||||
Returns:
|
||||
Type[BaseModel]: The created model.
|
||||
"""
|
||||
field_definitions = field_definitions or {}
|
||||
|
||||
if root:
|
||||
if field_definitions:
|
||||
msg = (
|
||||
"When specifying __root__ no other "
|
||||
f"fields should be provided. Got {field_definitions}"
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
if isinstance(root, tuple):
|
||||
kwargs = {"type_": root[0], "default_": root[1]}
|
||||
else:
|
||||
kwargs = {"type_": root}
|
||||
|
||||
try:
|
||||
named_root_model = _create_root_model_cached(model_name, **kwargs)
|
||||
except TypeError:
|
||||
# something in the arguments into _create_root_model_cached is not hashable
|
||||
named_root_model = _create_root_model(
|
||||
model_name,
|
||||
**kwargs,
|
||||
)
|
||||
return named_root_model
|
||||
|
||||
# No root, just field definitions
|
||||
names = set(field_definitions.keys())
|
||||
|
||||
capture_warnings = False
|
||||
|
||||
for name in names:
|
||||
# Also if any non-reserved name is used (e.g., model_id or model_name)
|
||||
if name.startswith("model"):
|
||||
capture_warnings = True
|
||||
|
||||
with warnings.catch_warnings() if capture_warnings else nullcontext():
|
||||
if capture_warnings:
|
||||
warnings.filterwarnings(action="ignore")
|
||||
try:
|
||||
return _create_model_cached(model_name, **field_definitions)
|
||||
except TypeError:
|
||||
# something in field definitions is not hashable
|
||||
return _create_model_base(
|
||||
model_name,
|
||||
__config__=_SchemaConfig,
|
||||
**_remap_field_definitions(field_definitions),
|
||||
)
|
||||
|
||||
|
||||
def is_supported_by_pydantic(type_: Any) -> bool:
|
||||
"""Check if a given "complex" type is supported by pydantic.
|
||||
|
||||
This will return False for primitive types like int, str, etc.
|
||||
|
||||
The check is meant for container types like dataclasses, TypedDicts, etc.
|
||||
"""
|
||||
if is_dataclass(type_):
|
||||
return True
|
||||
|
||||
if isinstance(type_, type) and issubclass(type_, BaseModel):
|
||||
return True
|
||||
|
||||
if hasattr(type_, "__orig_bases__"):
|
||||
for base in type_.__orig_bases__:
|
||||
if base is TypedDict:
|
||||
return True
|
||||
elif base is typing.TypedDict: # noqa: TID251
|
||||
# ignoring TID251 since it's OK to use typing.TypedDict in this case.
|
||||
# Pydantic supports typing.TypedDict from Python 3.12
|
||||
# For older versions, only typing_extensions.TypedDict is supported.
|
||||
if sys.version_info >= (3, 12):
|
||||
return True
|
||||
return False
|
||||
124
venv/Lib/site-packages/langgraph/_internal/_queue.py
Normal file
124
venv/Lib/site-packages/langgraph/_internal/_queue.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# type: ignore
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import queue
|
||||
import threading
|
||||
import types
|
||||
from collections import deque
|
||||
from time import monotonic
|
||||
|
||||
|
||||
class AsyncQueue(asyncio.Queue):
|
||||
"""Async unbounded FIFO queue with a wait() method.
|
||||
|
||||
Subclassed from asyncio.Queue, adding a wait() method."""
|
||||
|
||||
async def wait(self) -> None:
|
||||
"""If queue is empty, wait until an item is available.
|
||||
|
||||
Copied from Queue.get(), removing the call to .get_nowait(),
|
||||
ie. this doesn't consume the item, just waits for it.
|
||||
"""
|
||||
while self.empty():
|
||||
getter = self._get_loop().create_future()
|
||||
self._getters.append(getter)
|
||||
try:
|
||||
await getter
|
||||
except BaseException:
|
||||
getter.cancel() # Just in case getter is not done yet.
|
||||
try:
|
||||
# Clean self._getters from canceled getters.
|
||||
self._getters.remove(getter)
|
||||
except ValueError:
|
||||
# The getter could be removed from self._getters by a
|
||||
# previous put_nowait call.
|
||||
pass
|
||||
if not self.empty() and not getter.cancelled():
|
||||
# We were woken up by put_nowait(), but can't take
|
||||
# the call. Wake up the next in line.
|
||||
self._wakeup_next(self._getters)
|
||||
raise
|
||||
|
||||
|
||||
class Semaphore(threading.Semaphore):
|
||||
"""Semaphore subclass with a wait() method."""
|
||||
|
||||
def wait(self, blocking: bool = True, timeout: float | None = None):
|
||||
"""Block until the semaphore can be acquired, but don't acquire it."""
|
||||
if not blocking and timeout is not None:
|
||||
raise ValueError("can't specify timeout for non-blocking acquire")
|
||||
rc = False
|
||||
endtime = None
|
||||
with self._cond:
|
||||
while self._value == 0:
|
||||
if not blocking:
|
||||
break
|
||||
if timeout is not None:
|
||||
if endtime is None:
|
||||
endtime = monotonic() + timeout
|
||||
else:
|
||||
timeout = endtime - monotonic()
|
||||
if timeout <= 0:
|
||||
break
|
||||
self._cond.wait(timeout)
|
||||
else:
|
||||
rc = True
|
||||
return rc
|
||||
|
||||
|
||||
class SyncQueue:
|
||||
"""Unbounded FIFO queue with a wait() method.
|
||||
Adapted from pure Python implementation of queue.SimpleQueue.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._queue = deque()
|
||||
self._count = Semaphore(0)
|
||||
|
||||
def put(self, item, block=True, timeout=None):
|
||||
"""Put the item on the queue.
|
||||
|
||||
The optional 'block' and 'timeout' arguments are ignored, as this method
|
||||
never blocks. They are provided for compatibility with the Queue class.
|
||||
"""
|
||||
self._queue.append(item)
|
||||
self._count.release()
|
||||
|
||||
def get(self, block=False, timeout=None):
|
||||
"""Remove and return an item from the queue.
|
||||
|
||||
If optional args 'block' is true and 'timeout' is None (the default),
|
||||
block if necessary until an item is available. If 'timeout' is
|
||||
a non-negative number, it blocks at most 'timeout' seconds and raises
|
||||
the Empty exception if no item was available within that time.
|
||||
Otherwise ('block' is false), return an item if one is immediately
|
||||
available, else raise the Empty exception ('timeout' is ignored
|
||||
in that case).
|
||||
"""
|
||||
if timeout is not None and timeout < 0:
|
||||
raise ValueError("'timeout' must be a non-negative number")
|
||||
if not self._count.acquire(block, timeout):
|
||||
raise queue.Empty
|
||||
try:
|
||||
return self._queue.popleft()
|
||||
except IndexError:
|
||||
raise queue.Empty
|
||||
|
||||
def wait(self, block=True, timeout=None):
|
||||
"""If queue is empty, wait until an item maybe is available,
|
||||
but don't consume it.
|
||||
"""
|
||||
if timeout is not None and timeout < 0:
|
||||
raise ValueError("'timeout' must be a non-negative number")
|
||||
self._count.wait(block, timeout)
|
||||
|
||||
def empty(self):
|
||||
"""Return True if the queue is empty, False otherwise (not reliable!)."""
|
||||
return len(self._queue) == 0
|
||||
|
||||
def qsize(self):
|
||||
"""Return the approximate size of the queue (not reliable!)."""
|
||||
return len(self._queue)
|
||||
|
||||
__class_getitem__ = classmethod(types.GenericAlias)
|
||||
29
venv/Lib/site-packages/langgraph/_internal/_retry.py
Normal file
29
venv/Lib/site-packages/langgraph/_internal/_retry.py
Normal file
@@ -0,0 +1,29 @@
|
||||
def default_retry_on(exc: Exception) -> bool:
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
if isinstance(exc, ConnectionError):
|
||||
return True
|
||||
if isinstance(exc, httpx.HTTPStatusError):
|
||||
return 500 <= exc.response.status_code < 600
|
||||
if isinstance(exc, requests.HTTPError):
|
||||
return 500 <= exc.response.status_code < 600 if exc.response else True
|
||||
if isinstance(
|
||||
exc,
|
||||
(
|
||||
ValueError,
|
||||
TypeError,
|
||||
ArithmeticError,
|
||||
ImportError,
|
||||
LookupError,
|
||||
NameError,
|
||||
SyntaxError,
|
||||
RuntimeError,
|
||||
ReferenceError,
|
||||
StopIteration,
|
||||
StopAsyncIteration,
|
||||
OSError,
|
||||
),
|
||||
):
|
||||
return False
|
||||
return True
|
||||
914
venv/Lib/site-packages/langgraph/_internal/_runnable.py
Normal file
914
venv/Lib/site-packages/langgraph/_internal/_runnable.py
Normal file
@@ -0,0 +1,914 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import enum
|
||||
import inspect
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import (
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Generator,
|
||||
Iterator,
|
||||
Sequence,
|
||||
)
|
||||
from contextlib import AsyncExitStack, contextmanager
|
||||
from contextvars import Context, Token, copy_context
|
||||
from functools import partial, wraps
|
||||
from typing import (
|
||||
Any,
|
||||
Optional,
|
||||
Protocol,
|
||||
TypeGuard,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.runnables.base import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableLambda,
|
||||
RunnableParallel,
|
||||
RunnableSequence,
|
||||
)
|
||||
from langchain_core.runnables.base import (
|
||||
RunnableLike as LCRunnableLike,
|
||||
)
|
||||
from langchain_core.runnables.config import (
|
||||
run_in_executor,
|
||||
var_child_runnable_config,
|
||||
)
|
||||
from langchain_core.runnables.utils import Input, Output
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from langgraph._internal._config import (
|
||||
ensure_config,
|
||||
get_async_callback_manager_for_config,
|
||||
get_callback_manager_for_config,
|
||||
patch_config,
|
||||
)
|
||||
from langgraph._internal._constants import (
|
||||
CONF,
|
||||
CONFIG_KEY_RUNTIME,
|
||||
)
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
try:
|
||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||
except ImportError:
|
||||
_StreamingCallbackHandler = None # type: ignore
|
||||
|
||||
|
||||
def _set_config_context(
|
||||
config: RunnableConfig, run: Any = None
|
||||
) -> Token[RunnableConfig | None]:
|
||||
"""Set the child Runnable config + tracing context.
|
||||
|
||||
Args:
|
||||
config: The config to set.
|
||||
"""
|
||||
config_token = var_child_runnable_config.set(config)
|
||||
if run is not None:
|
||||
from langsmith.run_helpers import _set_tracing_context
|
||||
|
||||
_set_tracing_context({"parent": run})
|
||||
return config_token
|
||||
|
||||
|
||||
def _unset_config_context(token: Token[RunnableConfig | None], run: Any = None) -> None:
|
||||
"""Set the child Runnable config + tracing context.
|
||||
|
||||
Args:
|
||||
token: The config token to reset.
|
||||
"""
|
||||
var_child_runnable_config.reset(token)
|
||||
if run is not None:
|
||||
from langsmith.run_helpers import _set_tracing_context
|
||||
|
||||
_set_tracing_context(
|
||||
{
|
||||
"parent": None,
|
||||
"project_name": None,
|
||||
"tags": None,
|
||||
"metadata": None,
|
||||
"enabled": None,
|
||||
"client": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_config_context(
|
||||
config: RunnableConfig, run: Any = None
|
||||
) -> Generator[Context, None, None]:
|
||||
"""Set the child Runnable config + tracing context.
|
||||
|
||||
Args:
|
||||
config: The config to set.
|
||||
"""
|
||||
ctx = copy_context()
|
||||
config_token = ctx.run(_set_config_context, config, run)
|
||||
try:
|
||||
yield ctx
|
||||
finally:
|
||||
ctx.run(_unset_config_context, config_token, run)
|
||||
|
||||
|
||||
# Before Python 3.11 native StrEnum is not available
|
||||
class StrEnum(str, enum.Enum):
|
||||
"""A string enum."""
|
||||
|
||||
|
||||
# Special type to denote any type is accepted
|
||||
ANY_TYPE = object()
|
||||
|
||||
ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11)
|
||||
|
||||
# List of keyword arguments that can be injected into nodes / tasks / tools at runtime.
|
||||
# A named argument may appear multiple times if it appears with distinct types.
|
||||
KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = (
|
||||
(
|
||||
"config",
|
||||
(
|
||||
RunnableConfig,
|
||||
"RunnableConfig",
|
||||
Optional[RunnableConfig], # noqa: UP045
|
||||
"Optional[RunnableConfig]",
|
||||
inspect.Parameter.empty,
|
||||
),
|
||||
# for now, use config directly, eventually, will pop off of Runtime
|
||||
"N/A",
|
||||
inspect.Parameter.empty,
|
||||
),
|
||||
(
|
||||
"writer",
|
||||
(StreamWriter, "StreamWriter", inspect.Parameter.empty),
|
||||
"stream_writer",
|
||||
lambda _: None,
|
||||
),
|
||||
(
|
||||
"store",
|
||||
(
|
||||
BaseStore,
|
||||
"BaseStore",
|
||||
inspect.Parameter.empty,
|
||||
),
|
||||
"store",
|
||||
inspect.Parameter.empty,
|
||||
),
|
||||
(
|
||||
"store",
|
||||
(
|
||||
Optional[BaseStore], # noqa: UP045
|
||||
"Optional[BaseStore]",
|
||||
),
|
||||
"store",
|
||||
None,
|
||||
),
|
||||
(
|
||||
"previous",
|
||||
(ANY_TYPE,),
|
||||
"previous",
|
||||
inspect.Parameter.empty,
|
||||
),
|
||||
(
|
||||
"runtime",
|
||||
(ANY_TYPE,),
|
||||
# we never hit this block, we just inject runtime directly
|
||||
"N/A",
|
||||
inspect.Parameter.empty,
|
||||
),
|
||||
)
|
||||
"""List of kwargs that can be passed to functions, and their corresponding
|
||||
config keys, default values and type annotations.
|
||||
|
||||
Used to configure keyword arguments that can be injected at runtime
|
||||
from the `Runtime` object as kwargs to `invoke`, `ainvoke`, `stream` and `astream`.
|
||||
|
||||
For a keyword to be injected from the config object, the function signature
|
||||
must contain a kwarg with the same name and a matching type annotation.
|
||||
|
||||
Each tuple contains:
|
||||
- the name of the kwarg in the function signature
|
||||
- the type annotation(s) for the kwarg
|
||||
- the `Runtime` attribute for fetching the value (N/A if not applicable)
|
||||
|
||||
This is fully internal and should be further refactored to use `get_type_hints`
|
||||
to resolve forward references and optional types formatted like BaseStore | None.
|
||||
"""
|
||||
|
||||
VALID_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
|
||||
|
||||
|
||||
class _RunnableWithWriter(Protocol[Input, Output]):
|
||||
def __call__(self, state: Input, *, writer: StreamWriter) -> Output: ...
|
||||
|
||||
|
||||
class _RunnableWithStore(Protocol[Input, Output]):
|
||||
def __call__(self, state: Input, *, store: BaseStore) -> Output: ...
|
||||
|
||||
|
||||
class _RunnableWithWriterStore(Protocol[Input, Output]):
|
||||
def __call__(
|
||||
self, state: Input, *, writer: StreamWriter, store: BaseStore
|
||||
) -> Output: ...
|
||||
|
||||
|
||||
class _RunnableWithConfigWriter(Protocol[Input, Output]):
|
||||
def __call__(
|
||||
self, state: Input, *, config: RunnableConfig, writer: StreamWriter
|
||||
) -> Output: ...
|
||||
|
||||
|
||||
class _RunnableWithConfigStore(Protocol[Input, Output]):
|
||||
def __call__(
|
||||
self, state: Input, *, config: RunnableConfig, store: BaseStore
|
||||
) -> Output: ...
|
||||
|
||||
|
||||
class _RunnableWithConfigWriterStore(Protocol[Input, Output]):
|
||||
def __call__(
|
||||
self,
|
||||
state: Input,
|
||||
*,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter,
|
||||
store: BaseStore,
|
||||
) -> Output: ...
|
||||
|
||||
|
||||
RunnableLike = (
|
||||
LCRunnableLike
|
||||
| _RunnableWithWriter[Input, Output]
|
||||
| _RunnableWithStore[Input, Output]
|
||||
| _RunnableWithWriterStore[Input, Output]
|
||||
| _RunnableWithConfigWriter[Input, Output]
|
||||
| _RunnableWithConfigStore[Input, Output]
|
||||
| _RunnableWithConfigWriterStore[Input, Output]
|
||||
)
|
||||
|
||||
|
||||
class RunnableCallable(Runnable):
|
||||
"""A much simpler version of RunnableLambda that requires sync and async functions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[..., Any | Runnable] | None,
|
||||
afunc: Callable[..., Awaitable[Any | Runnable]] | None = None,
|
||||
*,
|
||||
name: str | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
trace: bool = True,
|
||||
recurse: bool = True,
|
||||
explode_args: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.name = name
|
||||
if self.name is None:
|
||||
if func:
|
||||
try:
|
||||
if func.__name__ != "<lambda>":
|
||||
self.name = func.__name__
|
||||
except AttributeError:
|
||||
pass
|
||||
elif afunc:
|
||||
try:
|
||||
self.name = afunc.__name__
|
||||
except AttributeError:
|
||||
pass
|
||||
self.func = func
|
||||
self.afunc = afunc
|
||||
self.tags = tags
|
||||
self.kwargs = kwargs
|
||||
self.trace = trace
|
||||
self.recurse = recurse
|
||||
self.explode_args = explode_args
|
||||
# check signature
|
||||
if func is None and afunc is None:
|
||||
raise ValueError("At least one of func or afunc must be provided.")
|
||||
|
||||
self.func_accepts: dict[str, tuple[str, Any]] = {}
|
||||
params = inspect.signature(cast(Callable, func or afunc)).parameters
|
||||
|
||||
for kw, typ, runtime_key, default in KWARGS_CONFIG_KEYS:
|
||||
p = params.get(kw)
|
||||
|
||||
if p is None or p.kind not in VALID_KINDS:
|
||||
# If parameter is not found or is not a valid kind, skip
|
||||
continue
|
||||
|
||||
if typ != (ANY_TYPE,) and p.annotation not in typ:
|
||||
# A specific type is required, but the function annotation does
|
||||
# not match the expected type.
|
||||
|
||||
# If this is a config parameter with incorrect typing, emit a warning
|
||||
# because we used to support any type but are moving towards more correct typing
|
||||
if kw == "config" and p.annotation != inspect.Parameter.empty:
|
||||
warnings.warn(
|
||||
f"The 'config' parameter should be typed as 'RunnableConfig' or "
|
||||
f"'RunnableConfig | None', not '{p.annotation}'. ",
|
||||
UserWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
continue
|
||||
|
||||
# If the kwarg is accepted by the function, store the key / runtime attribute to inject
|
||||
self.func_accepts[kw] = (runtime_key, default)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_args = {
|
||||
k: v
|
||||
for k, v in self.__dict__.items()
|
||||
if k not in {"name", "func", "afunc", "config", "kwargs", "trace"}
|
||||
}
|
||||
return f"{self.get_name()}({', '.join(f'{k}={v!r}' for k, v in repr_args.items())})"
|
||||
|
||||
def invoke(
|
||||
self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
|
||||
) -> Any:
|
||||
if self.func is None:
|
||||
raise TypeError(
|
||||
f'No synchronous function provided to "{self.name}".'
|
||||
"\nEither initialize with a synchronous function or invoke"
|
||||
" via the async API (ainvoke, astream, etc.)"
|
||||
)
|
||||
if config is None:
|
||||
config = ensure_config()
|
||||
if self.explode_args:
|
||||
args, _kwargs = input
|
||||
kwargs = {**self.kwargs, **_kwargs, **kwargs}
|
||||
else:
|
||||
args = (input,)
|
||||
kwargs = {**self.kwargs, **kwargs}
|
||||
|
||||
runtime = config.get(CONF, {}).get(CONFIG_KEY_RUNTIME)
|
||||
|
||||
for kw, (runtime_key, default) in self.func_accepts.items():
|
||||
# If the kwarg is already set, use the set value
|
||||
if kw in kwargs:
|
||||
continue
|
||||
|
||||
kw_value: Any = MISSING
|
||||
if kw == "config":
|
||||
kw_value = config
|
||||
elif runtime:
|
||||
if kw == "runtime":
|
||||
kw_value = runtime
|
||||
else:
|
||||
try:
|
||||
kw_value = getattr(runtime, runtime_key)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if kw_value is MISSING:
|
||||
if default is inspect.Parameter.empty:
|
||||
raise ValueError(
|
||||
f"Missing required config key '{runtime_key}' for '{self.name}'."
|
||||
)
|
||||
kw_value = default
|
||||
kwargs[kw] = kw_value
|
||||
|
||||
if self.trace:
|
||||
callback_manager = get_callback_manager_for_config(config, self.tags)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
None,
|
||||
input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
try:
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
# get the run
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, LangChainTracer):
|
||||
run = h.run_map.get(str(run_manager.run_id))
|
||||
break
|
||||
else:
|
||||
run = None
|
||||
# run in context
|
||||
with set_config_context(child_config, run) as context:
|
||||
ret = context.run(self.func, *args, **kwargs)
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
run_manager.on_chain_end(ret)
|
||||
else:
|
||||
ret = self.func(*args, **kwargs)
|
||||
if self.recurse and isinstance(ret, Runnable):
|
||||
return ret.invoke(input, config)
|
||||
return ret
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
|
||||
) -> Any:
|
||||
if not self.afunc:
|
||||
return self.invoke(input, config)
|
||||
if config is None:
|
||||
config = ensure_config()
|
||||
if self.explode_args:
|
||||
args, _kwargs = input
|
||||
kwargs = {**self.kwargs, **_kwargs, **kwargs}
|
||||
else:
|
||||
args = (input,)
|
||||
kwargs = {**self.kwargs, **kwargs}
|
||||
|
||||
runtime = config.get(CONF, {}).get(CONFIG_KEY_RUNTIME)
|
||||
|
||||
for kw, (runtime_key, default) in self.func_accepts.items():
|
||||
# If the kwarg has already been set, use the set value
|
||||
if kw in kwargs:
|
||||
continue
|
||||
|
||||
kw_value: Any = MISSING
|
||||
if kw == "config":
|
||||
kw_value = config
|
||||
elif runtime:
|
||||
if kw == "runtime":
|
||||
kw_value = runtime
|
||||
else:
|
||||
try:
|
||||
kw_value = getattr(runtime, runtime_key)
|
||||
except AttributeError:
|
||||
pass
|
||||
if kw_value is MISSING:
|
||||
if default is inspect.Parameter.empty:
|
||||
raise ValueError(
|
||||
f"Missing required config key '{runtime_key}' for '{self.name}'."
|
||||
)
|
||||
kw_value = default
|
||||
kwargs[kw] = kw_value
|
||||
|
||||
if self.trace:
|
||||
callback_manager = get_async_callback_manager_for_config(config, self.tags)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
None,
|
||||
input,
|
||||
name=config.get("run_name") or self.name,
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
try:
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
coro = cast(Coroutine[None, None, Any], self.afunc(*args, **kwargs))
|
||||
if ASYNCIO_ACCEPTS_CONTEXT:
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, LangChainTracer):
|
||||
run = h.run_map.get(str(run_manager.run_id))
|
||||
break
|
||||
else:
|
||||
run = None
|
||||
with set_config_context(child_config, run) as context:
|
||||
ret = await asyncio.create_task(coro, context=context)
|
||||
else:
|
||||
ret = await coro
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(ret)
|
||||
else:
|
||||
ret = await self.afunc(*args, **kwargs)
|
||||
if self.recurse and isinstance(ret, Runnable):
|
||||
return await ret.ainvoke(input, config)
|
||||
return ret
|
||||
|
||||
|
||||
def is_async_callable(
|
||||
func: Any,
|
||||
) -> TypeGuard[Callable[..., Awaitable]]:
|
||||
"""Check if a function is async."""
|
||||
return (
|
||||
inspect.iscoroutinefunction(func)
|
||||
or hasattr(func, "__call__")
|
||||
and inspect.iscoroutinefunction(func.__call__)
|
||||
)
|
||||
|
||||
|
||||
def is_async_generator(
|
||||
func: Any,
|
||||
) -> TypeGuard[Callable[..., AsyncIterator]]:
|
||||
"""Check if a function is an async generator."""
|
||||
return (
|
||||
inspect.isasyncgenfunction(func)
|
||||
or hasattr(func, "__call__")
|
||||
and inspect.isasyncgenfunction(func.__call__)
|
||||
)
|
||||
|
||||
|
||||
def coerce_to_runnable(
|
||||
thing: RunnableLike, *, name: str | None, trace: bool
|
||||
) -> Runnable:
|
||||
"""Coerce a runnable-like object into a Runnable.
|
||||
|
||||
Args:
|
||||
thing: A runnable-like object.
|
||||
|
||||
Returns:
|
||||
A Runnable.
|
||||
"""
|
||||
if isinstance(thing, Runnable):
|
||||
return thing
|
||||
elif is_async_generator(thing) or inspect.isgeneratorfunction(thing):
|
||||
return RunnableLambda(thing, name=name)
|
||||
elif callable(thing):
|
||||
if is_async_callable(thing):
|
||||
return RunnableCallable(None, thing, name=name, trace=trace)
|
||||
else:
|
||||
return RunnableCallable(
|
||||
thing,
|
||||
wraps(thing)(partial(run_in_executor, None, thing)), # type: ignore[arg-type]
|
||||
name=name,
|
||||
trace=trace,
|
||||
)
|
||||
elif isinstance(thing, dict):
|
||||
return RunnableParallel(thing)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected a Runnable, callable or dict."
|
||||
f"Instead got an unsupported type: {type(thing)}"
|
||||
)
|
||||
|
||||
|
||||
class RunnableSeq(Runnable):
|
||||
"""Sequence of `Runnable`, where the output of each is the input of the next.
|
||||
|
||||
`RunnableSeq` is a simpler version of `RunnableSequence` that is internal to
|
||||
LangGraph.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*steps: RunnableLike,
|
||||
name: str | None = None,
|
||||
trace_inputs: Callable[[Any], Any] | None = None,
|
||||
) -> None:
|
||||
"""Create a new RunnableSeq.
|
||||
|
||||
Args:
|
||||
steps: The steps to include in the sequence.
|
||||
name: The name of the `Runnable`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the sequence has less than 2 steps.
|
||||
"""
|
||||
steps_flat: list[Runnable] = []
|
||||
for step in steps:
|
||||
if isinstance(step, RunnableSequence):
|
||||
steps_flat.extend(step.steps)
|
||||
elif isinstance(step, RunnableSeq):
|
||||
steps_flat.extend(step.steps)
|
||||
else:
|
||||
steps_flat.append(coerce_to_runnable(step, name=None, trace=True))
|
||||
if len(steps_flat) < 2:
|
||||
raise ValueError(
|
||||
f"RunnableSeq must have at least 2 steps, got {len(steps_flat)}"
|
||||
)
|
||||
self.steps = steps_flat
|
||||
self.name = name
|
||||
self.trace_inputs = trace_inputs
|
||||
|
||||
def __or__(
|
||||
self,
|
||||
other: Any,
|
||||
) -> Runnable:
|
||||
if isinstance(other, RunnableSequence):
|
||||
return RunnableSeq(
|
||||
*self.steps,
|
||||
other.first,
|
||||
*other.middle,
|
||||
other.last,
|
||||
name=self.name or other.name,
|
||||
)
|
||||
elif isinstance(other, RunnableSeq):
|
||||
return RunnableSeq(
|
||||
*self.steps,
|
||||
*other.steps,
|
||||
name=self.name or other.name,
|
||||
)
|
||||
else:
|
||||
return RunnableSeq(
|
||||
*self.steps,
|
||||
coerce_to_runnable(other, name=None, trace=True),
|
||||
name=self.name,
|
||||
)
|
||||
|
||||
def __ror__(
|
||||
self,
|
||||
other: Any,
|
||||
) -> Runnable:
|
||||
if isinstance(other, RunnableSequence):
|
||||
return RunnableSequence(
|
||||
other.first,
|
||||
*other.middle,
|
||||
other.last,
|
||||
*self.steps,
|
||||
name=other.name or self.name,
|
||||
)
|
||||
elif isinstance(other, RunnableSeq):
|
||||
return RunnableSeq(
|
||||
*other.steps,
|
||||
*self.steps,
|
||||
name=other.name or self.name,
|
||||
)
|
||||
else:
|
||||
return RunnableSequence(
|
||||
coerce_to_runnable(other, name=None, trace=True),
|
||||
*self.steps,
|
||||
name=self.name,
|
||||
)
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: RunnableConfig | None = None, **kwargs: Any
|
||||
) -> Any:
|
||||
if config is None:
|
||||
config = ensure_config()
|
||||
# setup callbacks and context
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
None,
|
||||
self.trace_inputs(input) if self.trace_inputs is not None else input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
# invoke all steps in sequence
|
||||
try:
|
||||
for i, step in enumerate(self.steps):
|
||||
# mark each step as a child run
|
||||
config = patch_config(
|
||||
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
|
||||
)
|
||||
# 1st step is the actual node,
|
||||
# others are writers which don't need to be run in context
|
||||
if i == 0:
|
||||
# get the run object
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, LangChainTracer):
|
||||
run = h.run_map.get(str(run_manager.run_id))
|
||||
break
|
||||
else:
|
||||
run = None
|
||||
# run in context
|
||||
with set_config_context(config, run) as context:
|
||||
input = context.run(step.invoke, input, config, **kwargs)
|
||||
else:
|
||||
input = step.invoke(input, config)
|
||||
# finish the root run
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
run_manager.on_chain_end(input)
|
||||
return input
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Input,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Any:
|
||||
if config is None:
|
||||
config = ensure_config()
|
||||
# setup callbacks
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
None,
|
||||
self.trace_inputs(input) if self.trace_inputs is not None else input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
|
||||
# invoke all steps in sequence
|
||||
try:
|
||||
for i, step in enumerate(self.steps):
|
||||
# mark each step as a child run
|
||||
config = patch_config(
|
||||
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
|
||||
)
|
||||
# 1st step is the actual node,
|
||||
# others are writers which don't need to be run in context
|
||||
if i == 0:
|
||||
if ASYNCIO_ACCEPTS_CONTEXT:
|
||||
# get the run object
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, LangChainTracer):
|
||||
run = h.run_map.get(str(run_manager.run_id))
|
||||
break
|
||||
else:
|
||||
run = None
|
||||
# run in context
|
||||
with set_config_context(config, run) as context:
|
||||
input = await asyncio.create_task(
|
||||
step.ainvoke(input, config, **kwargs), context=context
|
||||
)
|
||||
else:
|
||||
input = await step.ainvoke(input, config, **kwargs)
|
||||
else:
|
||||
input = await step.ainvoke(input, config)
|
||||
# finish the root run
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(input)
|
||||
return input
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Input,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Iterator[Any]:
|
||||
if config is None:
|
||||
config = ensure_config()
|
||||
# setup callbacks
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
None,
|
||||
self.trace_inputs(input) if self.trace_inputs is not None else input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
# get the run object
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, LangChainTracer):
|
||||
run = h.run_map.get(str(run_manager.run_id))
|
||||
break
|
||||
else:
|
||||
run = None
|
||||
# create first step config
|
||||
config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(f"seq:step:{1}"),
|
||||
)
|
||||
# run all in context
|
||||
with set_config_context(config, run) as context:
|
||||
try:
|
||||
# stream the last steps
|
||||
# transform the input stream of each step with the next
|
||||
# steps that don't natively support transforming an input stream will
|
||||
# buffer input in memory until all available, and then start emitting output
|
||||
for idx, step in enumerate(self.steps):
|
||||
if idx == 0:
|
||||
iterator = step.stream(input, config, **kwargs)
|
||||
else:
|
||||
config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(f"seq:step:{idx + 1}"),
|
||||
)
|
||||
iterator = step.transform(iterator, config)
|
||||
# populates streamed_output in astream_log() output if needed
|
||||
if _StreamingCallbackHandler is not None:
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, _StreamingCallbackHandler):
|
||||
iterator = h.tap_output_iter(run_manager.run_id, iterator)
|
||||
# consume into final output
|
||||
output = context.run(_consume_iter, iterator)
|
||||
# sequence doesn't emit output, yield to mark as generator
|
||||
yield
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
run_manager.on_chain_end(output)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Input,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> AsyncIterator[Any]:
|
||||
if config is None:
|
||||
config = ensure_config()
|
||||
# setup callbacks
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
None,
|
||||
self.trace_inputs(input) if self.trace_inputs is not None else input,
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
run_id=config.pop("run_id", None),
|
||||
)
|
||||
# stream the last steps
|
||||
# transform the input stream of each step with the next
|
||||
# steps that don't natively support transforming an input stream will
|
||||
# buffer input in memory until all available, and then start emitting output
|
||||
if ASYNCIO_ACCEPTS_CONTEXT:
|
||||
# get the run object
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, LangChainTracer):
|
||||
run = h.run_map.get(str(run_manager.run_id))
|
||||
break
|
||||
else:
|
||||
run = None
|
||||
# create first step config
|
||||
config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(f"seq:step:{1}"),
|
||||
)
|
||||
# run all in context
|
||||
with set_config_context(config, run) as context:
|
||||
try:
|
||||
async with AsyncExitStack() as stack:
|
||||
for idx, step in enumerate(self.steps):
|
||||
if idx == 0:
|
||||
aiterator = step.astream(input, config, **kwargs)
|
||||
else:
|
||||
config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(
|
||||
f"seq:step:{idx + 1}"
|
||||
),
|
||||
)
|
||||
aiterator = step.atransform(aiterator, config)
|
||||
if hasattr(aiterator, "aclose"):
|
||||
stack.push_async_callback(aiterator.aclose)
|
||||
# populates streamed_output in astream_log() output if needed
|
||||
if _StreamingCallbackHandler is not None:
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, _StreamingCallbackHandler):
|
||||
aiterator = h.tap_output_aiter(
|
||||
run_manager.run_id, aiterator
|
||||
)
|
||||
# consume into final output
|
||||
output = await asyncio.create_task(
|
||||
_consume_aiter(aiterator), context=context
|
||||
)
|
||||
# sequence doesn't emit output, yield to mark as generator
|
||||
yield
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(output)
|
||||
else:
|
||||
try:
|
||||
async with AsyncExitStack() as stack:
|
||||
for idx, step in enumerate(self.steps):
|
||||
config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(f"seq:step:{idx + 1}"),
|
||||
)
|
||||
if idx == 0:
|
||||
aiterator = step.astream(input, config, **kwargs)
|
||||
else:
|
||||
aiterator = step.atransform(aiterator, config)
|
||||
if hasattr(aiterator, "aclose"):
|
||||
stack.push_async_callback(aiterator.aclose)
|
||||
# populates streamed_output in astream_log() output if needed
|
||||
if _StreamingCallbackHandler is not None:
|
||||
for h in run_manager.handlers:
|
||||
if isinstance(h, _StreamingCallbackHandler):
|
||||
aiterator = h.tap_output_aiter(
|
||||
run_manager.run_id, aiterator
|
||||
)
|
||||
# consume into final output
|
||||
output = await _consume_aiter(aiterator)
|
||||
# sequence doesn't emit output, yield to mark as generator
|
||||
yield
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(output)
|
||||
|
||||
|
||||
def _consume_iter(it: Iterator[Any]) -> Any:
|
||||
"""Consume an iterator."""
|
||||
output: Any = None
|
||||
add_supported = False
|
||||
for chunk in it:
|
||||
# collect final output
|
||||
if output is None:
|
||||
output = chunk
|
||||
elif add_supported:
|
||||
try:
|
||||
output = output + chunk
|
||||
except TypeError:
|
||||
output = chunk
|
||||
add_supported = False
|
||||
else:
|
||||
output = chunk
|
||||
return output
|
||||
|
||||
|
||||
async def _consume_aiter(it: AsyncIterator[Any]) -> Any:
|
||||
"""Consume an async iterator."""
|
||||
output: Any = None
|
||||
add_supported = False
|
||||
async for chunk in it:
|
||||
# collect final output
|
||||
if add_supported:
|
||||
try:
|
||||
output = output + chunk
|
||||
except TypeError:
|
||||
output = chunk
|
||||
add_supported = False
|
||||
else:
|
||||
output = chunk
|
||||
return output
|
||||
19
venv/Lib/site-packages/langgraph/_internal/_scratchpad.py
Normal file
19
venv/Lib/site-packages/langgraph/_internal/_scratchpad.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import dataclasses
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langgraph.types import _DC_KWARGS
|
||||
|
||||
|
||||
@dataclasses.dataclass(**_DC_KWARGS)
|
||||
class PregelScratchpad:
|
||||
step: int
|
||||
stop: int
|
||||
# call
|
||||
call_counter: Callable[[], int]
|
||||
# interrupt
|
||||
interrupt_counter: Callable[[], int]
|
||||
get_null_resume: Callable[[bool], Any]
|
||||
resume: list[Any]
|
||||
# subgraph
|
||||
subgraph_counter: Callable[[], int]
|
||||
253
venv/Lib/site-packages/langgraph/_internal/_serde.py
Normal file
253
venv/Lib/site-packages/langgraph/_internal/_serde.py
Normal file
@@ -0,0 +1,253 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import sys
|
||||
import types
|
||||
from collections import deque
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from langchain_core import messages as lc_messages
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import NotRequired, Required, is_typeddict
|
||||
|
||||
try:
|
||||
from langgraph.checkpoint.serde._msgpack import ( # noqa: F401
|
||||
STRICT_MSGPACK_ENABLED,
|
||||
)
|
||||
except ImportError:
|
||||
STRICT_MSGPACK_ENABLED = False
|
||||
|
||||
_warned_allowlist_unsupported = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _supports_checkpointer_allowlist() -> bool:
|
||||
return hasattr(BaseCheckpointSaver, "with_allowlist")
|
||||
|
||||
|
||||
_SUPPORTS_ALLOWLIST = _supports_checkpointer_allowlist()
|
||||
|
||||
|
||||
def apply_checkpointer_allowlist(
|
||||
checkpointer: Any, allowlist: set[tuple[str, ...]] | None
|
||||
) -> Any:
|
||||
if not checkpointer or allowlist is None or checkpointer in (True, False):
|
||||
return checkpointer
|
||||
if not _SUPPORTS_ALLOWLIST:
|
||||
global _warned_allowlist_unsupported
|
||||
if not _warned_allowlist_unsupported:
|
||||
logger.warning(
|
||||
"Checkpointer does not support with_allowlist; strict msgpack "
|
||||
"allowlist will be skipped."
|
||||
)
|
||||
_warned_allowlist_unsupported = True
|
||||
return checkpointer
|
||||
return checkpointer.with_allowlist(allowlist)
|
||||
|
||||
|
||||
def curated_core_allowlist() -> set[tuple[str, ...]]:
|
||||
allowlist: set[tuple[str, ...]] = set()
|
||||
for name in (
|
||||
"BaseMessage",
|
||||
"BaseMessageChunk",
|
||||
"HumanMessage",
|
||||
"HumanMessageChunk",
|
||||
"AIMessage",
|
||||
"AIMessageChunk",
|
||||
"SystemMessage",
|
||||
"SystemMessageChunk",
|
||||
"ChatMessage",
|
||||
"ChatMessageChunk",
|
||||
"ToolMessage",
|
||||
"ToolMessageChunk",
|
||||
"FunctionMessage",
|
||||
"FunctionMessageChunk",
|
||||
"RemoveMessage",
|
||||
):
|
||||
cls = getattr(lc_messages, name, None)
|
||||
if cls is None:
|
||||
continue
|
||||
allowlist.add((cls.__module__, cls.__name__))
|
||||
|
||||
return allowlist
|
||||
|
||||
|
||||
def build_serde_allowlist(
|
||||
*,
|
||||
schemas: list[type[Any]] | None = None,
|
||||
channels: dict[str, Any] | None = None,
|
||||
) -> set[tuple[str, ...]]:
|
||||
allowlist = curated_core_allowlist()
|
||||
if schemas:
|
||||
schemas = [schema for schema in schemas if schema is not None]
|
||||
return allowlist | collect_allowlist_from_schemas(
|
||||
schemas=schemas,
|
||||
channels=channels,
|
||||
)
|
||||
|
||||
|
||||
def collect_allowlist_from_schemas(
|
||||
*,
|
||||
schemas: list[type[Any]] | None = None,
|
||||
channels: dict[str, Any] | None = None,
|
||||
) -> set[tuple[str, ...]]:
|
||||
allowlist: set[tuple[str, ...]] = set()
|
||||
seen: set[Any] = set()
|
||||
seen_ids: set[int] = set()
|
||||
|
||||
if schemas:
|
||||
for schema in schemas:
|
||||
_collect_from_type(schema, allowlist, seen, seen_ids)
|
||||
|
||||
if channels:
|
||||
for channel in channels.values():
|
||||
value_type = getattr(channel, "ValueType", None)
|
||||
if value_type is not None:
|
||||
_collect_from_type(value_type, allowlist, seen, seen_ids)
|
||||
update_type = getattr(channel, "UpdateType", None)
|
||||
if update_type is not None:
|
||||
_collect_from_type(update_type, allowlist, seen, seen_ids)
|
||||
|
||||
return allowlist
|
||||
|
||||
|
||||
def _collect_from_type(
|
||||
typ: Any,
|
||||
allowlist: set[tuple[str, ...]],
|
||||
seen: set[Any],
|
||||
seen_ids: set[int],
|
||||
) -> None:
|
||||
if _already_seen(typ, seen, seen_ids):
|
||||
return
|
||||
|
||||
if typ is Any or typ is None:
|
||||
return
|
||||
|
||||
if typ is Literal:
|
||||
return
|
||||
|
||||
if isinstance(typ, types.UnionType):
|
||||
for arg in typ.__args__:
|
||||
_collect_from_type(arg, allowlist, seen, seen_ids)
|
||||
return
|
||||
|
||||
origin = get_origin(typ)
|
||||
if origin is Union:
|
||||
for arg in get_args(typ):
|
||||
_collect_from_type(arg, allowlist, seen, seen_ids)
|
||||
return
|
||||
if origin is Annotated or origin in (Required, NotRequired):
|
||||
args = get_args(typ)
|
||||
if args:
|
||||
_collect_from_type(args[0], allowlist, seen, seen_ids)
|
||||
return
|
||||
|
||||
if origin is Literal:
|
||||
return
|
||||
|
||||
if origin in (list, set, tuple, dict, deque, frozenset):
|
||||
for arg in get_args(typ):
|
||||
_collect_from_type(arg, allowlist, seen, seen_ids)
|
||||
return
|
||||
|
||||
if hasattr(typ, "__supertype__"):
|
||||
_collect_from_type(typ.__supertype__, allowlist, seen, seen_ids)
|
||||
return
|
||||
|
||||
if is_typeddict(typ):
|
||||
for field_type in _safe_get_type_hints(typ).values():
|
||||
_collect_from_type(field_type, allowlist, seen, seen_ids)
|
||||
return
|
||||
|
||||
if _is_pydantic_model(typ):
|
||||
allowlist.add((typ.__module__, typ.__name__))
|
||||
field_types = _safe_get_type_hints(typ)
|
||||
if field_types:
|
||||
for field_type in field_types.values():
|
||||
_collect_from_type(field_type, allowlist, seen, seen_ids)
|
||||
else:
|
||||
for field_type in _pydantic_field_types(typ):
|
||||
_collect_from_type(field_type, allowlist, seen, seen_ids)
|
||||
return
|
||||
|
||||
if dataclasses.is_dataclass(typ):
|
||||
if typ_name := getattr(typ, "__name__", None):
|
||||
allowlist.add((typ.__module__, typ_name))
|
||||
field_types = _safe_get_type_hints(typ)
|
||||
if field_types:
|
||||
for field_type in field_types.values():
|
||||
_collect_from_type(field_type, allowlist, seen, seen_ids)
|
||||
else:
|
||||
for field in dataclasses.fields(typ):
|
||||
_collect_from_type(field.type, allowlist, seen, seen_ids)
|
||||
return
|
||||
|
||||
if isinstance(typ, type) and issubclass(typ, Enum):
|
||||
allowlist.add((typ.__module__, typ.__name__))
|
||||
return
|
||||
|
||||
|
||||
def _already_seen(typ: Any, seen: set[Any], seen_ids: set[int]) -> bool:
|
||||
try:
|
||||
if typ in seen:
|
||||
return True
|
||||
seen.add(typ)
|
||||
return False
|
||||
except TypeError:
|
||||
typ_id = id(typ)
|
||||
if typ_id in seen_ids:
|
||||
return True
|
||||
seen_ids.add(typ_id)
|
||||
return False
|
||||
|
||||
|
||||
def _safe_get_type_hints(typ: Any) -> dict[str, Any]:
|
||||
try:
|
||||
module = sys.modules.get(getattr(typ, "__module__", ""))
|
||||
globalns = module.__dict__ if module else None
|
||||
localns = dict(vars(typ)) if hasattr(typ, "__dict__") else None
|
||||
return get_type_hints(
|
||||
typ, globalns=globalns, localns=localns, include_extras=True
|
||||
)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _is_pydantic_model(typ: Any) -> bool:
|
||||
if not isinstance(typ, type):
|
||||
return False
|
||||
if issubclass(typ, BaseModel):
|
||||
return True
|
||||
try:
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
except Exception:
|
||||
return False
|
||||
return issubclass(typ, BaseModelV1)
|
||||
|
||||
|
||||
def _pydantic_field_types(typ: type[Any]) -> list[Any]:
|
||||
if hasattr(typ, "model_fields"):
|
||||
return [
|
||||
field.annotation
|
||||
for field in typ.model_fields.values()
|
||||
if getattr(field, "annotation", None) is not None
|
||||
]
|
||||
if hasattr(typ, "__fields__"):
|
||||
return [
|
||||
field.outer_type_
|
||||
for field in typ.__fields__.values()
|
||||
if getattr(field, "outer_type_", None) is not None
|
||||
]
|
||||
return []
|
||||
54
venv/Lib/site-packages/langgraph/_internal/_typing.py
Normal file
54
venv/Lib/site-packages/langgraph/_internal/_typing.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Private typing utilities for LangGraph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import Field
|
||||
from typing import Any, ClassVar, Protocol, TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class TypedDictLikeV1(Protocol):
|
||||
"""Protocol to represent types that behave like TypedDicts
|
||||
|
||||
Version 1: using `ClassVar` for keys."""
|
||||
|
||||
__required_keys__: ClassVar[frozenset[str]]
|
||||
__optional_keys__: ClassVar[frozenset[str]]
|
||||
|
||||
|
||||
class TypedDictLikeV2(Protocol):
|
||||
"""Protocol to represent types that behave like TypedDicts
|
||||
|
||||
Version 2: not using `ClassVar` for keys."""
|
||||
|
||||
__required_keys__: frozenset[str]
|
||||
__optional_keys__: frozenset[str]
|
||||
|
||||
|
||||
class DataclassLike(Protocol):
|
||||
"""Protocol to represent types that behave like dataclasses.
|
||||
|
||||
Inspired by the private _DataclassT from dataclasses that uses a similar protocol as a bound."""
|
||||
|
||||
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
|
||||
|
||||
|
||||
StateLike: TypeAlias = TypedDictLikeV1 | TypedDictLikeV2 | DataclassLike | BaseModel
|
||||
"""Type alias for state-like types.
|
||||
|
||||
It can either be a `TypedDict`, `dataclass`, or Pydantic `BaseModel`.
|
||||
Note: we cannot use either `TypedDict` or `dataclass` directly due to limitations in type checking.
|
||||
"""
|
||||
|
||||
MISSING = object()
|
||||
"""Unset sentinel value."""
|
||||
|
||||
|
||||
class DeprecatedKwargs(TypedDict):
|
||||
"""TypedDict to use for extra keyword arguments, enabling type checking warnings for deprecated arguments."""
|
||||
|
||||
|
||||
EMPTY_SEQ: tuple[str, ...] = tuple()
|
||||
"""An empty sequence of strings."""
|
||||
Reference in New Issue
Block a user