initial commit
This commit is contained in:
277
venv/Lib/site-packages/langgraph/pregel/_read.py
Normal file
277
venv/Lib/site-packages/langgraph/pregel/_read.py
Normal file
@@ -0,0 +1,277 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
|
||||
from functools import cached_property
|
||||
from typing import (
|
||||
Any,
|
||||
)
|
||||
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
|
||||
from langgraph._internal._config import merge_configs
|
||||
from langgraph._internal._constants import CONF, CONFIG_KEY_READ
|
||||
from langgraph._internal._runnable import RunnableCallable, RunnableSeq
|
||||
from langgraph.pregel._utils import find_subgraph_pregel
|
||||
from langgraph.pregel._write import ChannelWrite
|
||||
from langgraph.pregel.protocol import PregelProtocol
|
||||
from langgraph.types import CachePolicy, RetryPolicy
|
||||
|
||||
READ_TYPE = Callable[[str | Sequence[str], bool], Any | dict[str, Any]]
|
||||
INPUT_CACHE_KEY_TYPE = tuple[Callable[..., Any], tuple[str, ...]]
|
||||
|
||||
|
||||
class ChannelRead(RunnableCallable):
|
||||
"""Implements the logic for reading state from CONFIG_KEY_READ.
|
||||
Usable both as a runnable as well as a static method to call imperatively."""
|
||||
|
||||
channel: str | list[str]
|
||||
|
||||
fresh: bool = False
|
||||
|
||||
mapper: Callable[[Any], Any] | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channel: str | list[str],
|
||||
*,
|
||||
fresh: bool = False,
|
||||
mapper: Callable[[Any], Any] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
func=self._read,
|
||||
afunc=self._aread,
|
||||
tags=tags,
|
||||
name=None,
|
||||
trace=False,
|
||||
)
|
||||
self.fresh = fresh
|
||||
self.mapper = mapper
|
||||
self.channel = channel
|
||||
|
||||
def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str:
|
||||
if name:
|
||||
pass
|
||||
elif isinstance(self.channel, str):
|
||||
name = f"ChannelRead<{self.channel}>"
|
||||
else:
|
||||
name = f"ChannelRead<{','.join(self.channel)}>"
|
||||
return super().get_name(suffix, name=name)
|
||||
|
||||
def _read(self, _: Any, config: RunnableConfig) -> Any:
|
||||
return self.do_read(
|
||||
config, select=self.channel, fresh=self.fresh, mapper=self.mapper
|
||||
)
|
||||
|
||||
async def _aread(self, _: Any, config: RunnableConfig) -> Any:
|
||||
return self.do_read(
|
||||
config, select=self.channel, fresh=self.fresh, mapper=self.mapper
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def do_read(
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
select: str | list[str],
|
||||
fresh: bool = False,
|
||||
mapper: Callable[[Any], Any] | None = None,
|
||||
) -> Any:
|
||||
try:
|
||||
read: READ_TYPE = config[CONF][CONFIG_KEY_READ]
|
||||
except KeyError:
|
||||
raise RuntimeError(
|
||||
"Not configured with a read function"
|
||||
"Make sure to call in the context of a Pregel process"
|
||||
)
|
||||
if mapper:
|
||||
return mapper(read(select, fresh))
|
||||
else:
|
||||
return read(select, fresh)
|
||||
|
||||
|
||||
DEFAULT_BOUND = RunnableCallable(lambda input: input)
|
||||
|
||||
|
||||
class PregelNode:
|
||||
"""A node in a Pregel graph. This won't be invoked as a runnable by the graph
|
||||
itself, but instead acts as a container for the components necessary to make
|
||||
a PregelExecutableTask for a node."""
|
||||
|
||||
channels: str | list[str]
|
||||
"""The channels that will be passed as input to `bound`.
|
||||
If a str, the node will be invoked with its value if it isn't empty.
|
||||
If a list, the node will be invoked with a dict of those channels' values."""
|
||||
|
||||
triggers: list[str]
|
||||
"""If any of these channels is written to, this node will be triggered in
|
||||
the next step."""
|
||||
|
||||
mapper: Callable[[Any], Any] | None
|
||||
"""A function to transform the input before passing it to `bound`."""
|
||||
|
||||
writers: list[Runnable]
|
||||
"""A list of writers that will be executed after `bound`, responsible for
|
||||
taking the output of `bound` and writing it to the appropriate channels."""
|
||||
|
||||
bound: Runnable[Any, Any]
|
||||
"""The main logic of the node. This will be invoked with the input from
|
||||
`channels`."""
|
||||
|
||||
retry_policy: Sequence[RetryPolicy] | None
|
||||
"""The retry policies to use when invoking the node."""
|
||||
|
||||
cache_policy: CachePolicy | None
|
||||
"""The cache policy to use when invoking the node."""
|
||||
|
||||
tags: Sequence[str] | None
|
||||
"""Tags to attach to the node for tracing."""
|
||||
|
||||
metadata: Mapping[str, Any] | None
|
||||
"""Metadata to attach to the node for tracing."""
|
||||
|
||||
subgraphs: Sequence[PregelProtocol]
|
||||
"""Subgraphs used by the node."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
channels: str | list[str],
|
||||
triggers: Sequence[str],
|
||||
mapper: Callable[[Any], Any] | None = None,
|
||||
writers: list[Runnable] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: Mapping[str, Any] | None = None,
|
||||
bound: Runnable[Any, Any] | None = None,
|
||||
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
|
||||
cache_policy: CachePolicy | None = None,
|
||||
subgraphs: Sequence[PregelProtocol] | None = None,
|
||||
) -> None:
|
||||
self.channels = channels
|
||||
self.triggers = list(triggers)
|
||||
self.mapper = mapper
|
||||
self.writers = writers or []
|
||||
self.bound = bound if bound is not None else DEFAULT_BOUND
|
||||
self.cache_policy = cache_policy
|
||||
if isinstance(retry_policy, RetryPolicy):
|
||||
self.retry_policy = (retry_policy,)
|
||||
else:
|
||||
self.retry_policy = retry_policy
|
||||
self.tags = tags
|
||||
self.metadata = metadata
|
||||
if subgraphs is not None:
|
||||
self.subgraphs = subgraphs
|
||||
elif self.bound is not DEFAULT_BOUND:
|
||||
try:
|
||||
subgraph = find_subgraph_pregel(self.bound)
|
||||
except Exception:
|
||||
subgraph = None
|
||||
if subgraph:
|
||||
self.subgraphs = [subgraph]
|
||||
else:
|
||||
self.subgraphs = []
|
||||
else:
|
||||
self.subgraphs = []
|
||||
|
||||
def copy(self, update: dict[str, Any]) -> PregelNode:
|
||||
attrs = {**self.__dict__, **update}
|
||||
# Drop the cached properties
|
||||
attrs.pop("flat_writers", None)
|
||||
attrs.pop("node", None)
|
||||
attrs.pop("input_cache_key", None)
|
||||
return PregelNode(**attrs)
|
||||
|
||||
@cached_property
|
||||
def flat_writers(self) -> list[Runnable]:
|
||||
"""Get writers with optimizations applied. Dedupes consecutive ChannelWrites."""
|
||||
writers = self.writers.copy()
|
||||
while (
|
||||
len(writers) > 1
|
||||
and isinstance(writers[-1], ChannelWrite)
|
||||
and isinstance(writers[-2], ChannelWrite)
|
||||
):
|
||||
# we can combine writes if they are consecutive
|
||||
# careful to not modify the original writers list or ChannelWrite
|
||||
writers[-2] = ChannelWrite(
|
||||
writes=writers[-2].writes + writers[-1].writes,
|
||||
)
|
||||
writers.pop()
|
||||
return writers
|
||||
|
||||
@cached_property
|
||||
def node(self) -> Runnable[Any, Any] | None:
|
||||
"""Get a runnable that combines `bound` and `writers`."""
|
||||
writers = self.flat_writers
|
||||
if self.bound is DEFAULT_BOUND and not writers:
|
||||
return None
|
||||
elif self.bound is DEFAULT_BOUND and len(writers) == 1:
|
||||
return writers[0]
|
||||
elif self.bound is DEFAULT_BOUND:
|
||||
return RunnableSeq(*writers)
|
||||
elif writers:
|
||||
return RunnableSeq(self.bound, *writers)
|
||||
else:
|
||||
return self.bound
|
||||
|
||||
@cached_property
|
||||
def input_cache_key(self) -> INPUT_CACHE_KEY_TYPE:
|
||||
"""Get a cache key for the input to the node.
|
||||
This is used to avoid calculating the same input multiple times."""
|
||||
return (
|
||||
self.mapper,
|
||||
tuple(self.channels)
|
||||
if isinstance(self.channels, list)
|
||||
else (self.channels,),
|
||||
)
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Any:
|
||||
self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
|
||||
return self.bound.invoke(
|
||||
input,
|
||||
merge_configs(self_config, config),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Any:
|
||||
self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
|
||||
return await self.bound.ainvoke(
|
||||
input,
|
||||
merge_configs(self_config, config),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> Iterator[Any]:
|
||||
self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
|
||||
yield from self.bound.stream(
|
||||
input,
|
||||
merge_configs(self_config, config),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
**kwargs: Any | None,
|
||||
) -> AsyncIterator[Any]:
|
||||
self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
|
||||
async for item in self.bound.astream(
|
||||
input,
|
||||
merge_configs(self_config, config),
|
||||
**kwargs,
|
||||
):
|
||||
yield item
|
||||
Reference in New Issue
Block a user