initial commit
This commit is contained in:
27
venv/Lib/site-packages/langgraph/channels/__init__.py
Normal file
27
venv/Lib/site-packages/langgraph/channels/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from langgraph.channels.any_value import AnyValue
|
||||
from langgraph.channels.base import BaseChannel
|
||||
from langgraph.channels.binop import BinaryOperatorAggregate
|
||||
from langgraph.channels.ephemeral_value import EphemeralValue
|
||||
from langgraph.channels.last_value import LastValue, LastValueAfterFinish
|
||||
from langgraph.channels.named_barrier_value import (
|
||||
NamedBarrierValue,
|
||||
NamedBarrierValueAfterFinish,
|
||||
)
|
||||
from langgraph.channels.topic import Topic
|
||||
from langgraph.channels.untracked_value import UntrackedValue
|
||||
|
||||
__all__ = (
|
||||
# base
|
||||
"BaseChannel",
|
||||
# value types
|
||||
"AnyValue",
|
||||
"LastValue",
|
||||
"LastValueAfterFinish",
|
||||
"UntrackedValue",
|
||||
"EphemeralValue",
|
||||
"BinaryOperatorAggregate",
|
||||
"NamedBarrierValue",
|
||||
"NamedBarrierValueAfterFinish",
|
||||
# topics
|
||||
"Topic",
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
72
venv/Lib/site-packages/langgraph/channels/any_value.py
Normal file
72
venv/Lib/site-packages/langgraph/channels/any_value.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel, Value
|
||||
from langgraph.errors import EmptyChannelError
|
||||
|
||||
__all__ = ("AnyValue",)
|
||||
|
||||
|
||||
class AnyValue(Generic[Value], BaseChannel[Value, Value, Value]):
|
||||
"""Stores the last value received, assumes that if multiple values are
|
||||
received, they are all equal."""
|
||||
|
||||
__slots__ = ("typ", "value")
|
||||
|
||||
value: Value | Any
|
||||
|
||||
def __init__(self, typ: Any, key: str = "") -> None:
|
||||
super().__init__(typ, key)
|
||||
self.value = MISSING
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, AnyValue)
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.key)
|
||||
empty.value = self.value
|
||||
return empty
|
||||
|
||||
def from_checkpoint(self, checkpoint: Value) -> Self:
|
||||
empty = self.__class__(self.typ, self.key)
|
||||
if checkpoint is not MISSING:
|
||||
empty.value = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value]) -> bool:
|
||||
if len(values) == 0:
|
||||
if self.value is MISSING:
|
||||
return False
|
||||
else:
|
||||
self.value = MISSING
|
||||
return True
|
||||
|
||||
self.value = values[-1]
|
||||
return True
|
||||
|
||||
def get(self) -> Value:
|
||||
if self.value is MISSING:
|
||||
raise EmptyChannelError()
|
||||
return self.value
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.value is not MISSING
|
||||
|
||||
def checkpoint(self) -> Value:
|
||||
return self.value
|
||||
121
venv/Lib/site-packages/langgraph/channels/base.py
Normal file
121
venv/Lib/site-packages/langgraph/channels/base.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.errors import EmptyChannelError
|
||||
|
||||
Value = TypeVar("Value")
|
||||
Update = TypeVar("Update")
|
||||
Checkpoint = TypeVar("Checkpoint")
|
||||
|
||||
__all__ = ("BaseChannel",)
|
||||
|
||||
|
||||
class BaseChannel(Generic[Value, Update, Checkpoint], ABC):
|
||||
"""Base class for all channels."""
|
||||
|
||||
__slots__ = ("key", "typ")
|
||||
|
||||
def __init__(self, typ: Any, key: str = "") -> None:
|
||||
self.typ = typ
|
||||
self.key = key
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ValueType(self) -> Any:
|
||||
"""The type of the value stored in the channel."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def UpdateType(self) -> Any:
|
||||
"""The type of the update received by the channel."""
|
||||
|
||||
# serialize/deserialize methods
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel.
|
||||
|
||||
By default, delegates to `checkpoint()` and `from_checkpoint()`.
|
||||
|
||||
Subclasses can override this method with a more efficient implementation.
|
||||
"""
|
||||
return self.from_checkpoint(self.checkpoint())
|
||||
|
||||
def checkpoint(self) -> Checkpoint | Any:
|
||||
"""Return a serializable representation of the channel's current state.
|
||||
|
||||
Raises `EmptyChannelError` if the channel is empty (never updated yet),
|
||||
or doesn't support checkpoints.
|
||||
"""
|
||||
try:
|
||||
return self.get()
|
||||
except EmptyChannelError:
|
||||
return MISSING
|
||||
|
||||
@abstractmethod
|
||||
def from_checkpoint(self, checkpoint: Checkpoint | Any) -> Self:
|
||||
"""Return a new identical channel, optionally initialized from a checkpoint.
|
||||
|
||||
If the checkpoint contains complex data structures, they should be copied.
|
||||
"""
|
||||
|
||||
# read methods
|
||||
|
||||
@abstractmethod
|
||||
def get(self) -> Value:
|
||||
"""Return the current value of the channel.
|
||||
|
||||
Raises `EmptyChannelError` if the channel is empty (never updated yet)."""
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Return `True` if the channel is available (not empty), `False` otherwise.
|
||||
|
||||
Subclasses should override this method to provide a more efficient
|
||||
implementation than calling `get()` and catching `EmptyChannelError`.
|
||||
"""
|
||||
try:
|
||||
self.get()
|
||||
return True
|
||||
except EmptyChannelError:
|
||||
return False
|
||||
|
||||
# write methods
|
||||
|
||||
@abstractmethod
|
||||
def update(self, values: Sequence[Update]) -> bool:
|
||||
"""Update the channel's value with the given sequence of updates.
|
||||
The order of the updates in the sequence is arbitrary.
|
||||
This method is called by Pregel for all channels at the end of each step.
|
||||
|
||||
If there are no updates, it is called with an empty sequence.
|
||||
|
||||
Raises `InvalidUpdateError` if the sequence of updates is invalid.
|
||||
|
||||
Returns `True` if the channel was updated, `False` otherwise."""
|
||||
|
||||
def consume(self) -> bool:
|
||||
"""Notify the channel that a subscribed task ran.
|
||||
|
||||
By default, no-op.
|
||||
|
||||
A channel can use this method to modify its state, preventing the value from being consumed again.
|
||||
|
||||
Returns `True` if the channel was updated, `False` otherwise.
|
||||
"""
|
||||
return False
|
||||
|
||||
def finish(self) -> bool:
|
||||
"""Notify the channel that the Pregel run is finishing.
|
||||
|
||||
By default, no-op.
|
||||
|
||||
A channel can use this method to modify its state, preventing finish.
|
||||
|
||||
Returns `True` if the channel was updated, `False` otherwise.
|
||||
"""
|
||||
return False
|
||||
134
venv/Lib/site-packages/langgraph/channels/binop.py
Normal file
134
venv/Lib/site-packages/langgraph/channels/binop.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import collections.abc
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import NotRequired, Required, Self
|
||||
|
||||
from langgraph._internal._constants import OVERWRITE
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel, Value
|
||||
from langgraph.errors import (
|
||||
EmptyChannelError,
|
||||
ErrorCode,
|
||||
InvalidUpdateError,
|
||||
create_error_message,
|
||||
)
|
||||
from langgraph.types import Overwrite
|
||||
|
||||
__all__ = ("BinaryOperatorAggregate",)
|
||||
|
||||
|
||||
# Adapted from typing_extensions
|
||||
def _strip_extras(t): # type: ignore[no-untyped-def]
|
||||
"""Strips Annotated, Required and NotRequired from a given type."""
|
||||
if hasattr(t, "__origin__"):
|
||||
return _strip_extras(t.__origin__)
|
||||
if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired):
|
||||
return _strip_extras(t.__args__[0])
|
||||
|
||||
return t
|
||||
|
||||
|
||||
def _get_overwrite(value: Any) -> tuple[bool, Any]:
|
||||
"""Inspects the given value and returns (is_overwrite, overwrite_value)."""
|
||||
if isinstance(value, Overwrite):
|
||||
return True, value.value
|
||||
if isinstance(value, dict) and set(value.keys()) == {OVERWRITE}:
|
||||
return True, value[OVERWRITE]
|
||||
return False, None
|
||||
|
||||
|
||||
class BinaryOperatorAggregate(Generic[Value], BaseChannel[Value, Value, Value]):
|
||||
"""Stores the result of applying a binary operator to the current value and each new value.
|
||||
|
||||
```python
|
||||
import operator
|
||||
|
||||
total = Channels.BinaryOperatorAggregate(int, operator.add)
|
||||
```
|
||||
"""
|
||||
|
||||
__slots__ = ("value", "operator")
|
||||
|
||||
def __init__(self, typ: type[Value], operator: Callable[[Value, Value], Value]):
|
||||
super().__init__(typ)
|
||||
self.operator = operator
|
||||
# special forms from typing or collections.abc are not instantiable
|
||||
# so we need to replace them with their concrete counterparts
|
||||
typ = _strip_extras(typ)
|
||||
if typ in (collections.abc.Sequence, collections.abc.MutableSequence):
|
||||
typ = list
|
||||
if typ in (collections.abc.Set, collections.abc.MutableSet):
|
||||
typ = set
|
||||
if typ in (collections.abc.Mapping, collections.abc.MutableMapping):
|
||||
typ = dict
|
||||
try:
|
||||
self.value = typ()
|
||||
except Exception:
|
||||
self.value = MISSING
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, BinaryOperatorAggregate) and (
|
||||
value.operator is self.operator
|
||||
if value.operator.__name__ != "<lambda>"
|
||||
and self.operator.__name__ != "<lambda>"
|
||||
else True
|
||||
)
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.operator)
|
||||
empty.key = self.key
|
||||
empty.value = self.value
|
||||
return empty
|
||||
|
||||
def from_checkpoint(self, checkpoint: Value) -> Self:
|
||||
empty = self.__class__(self.typ, self.operator)
|
||||
empty.key = self.key
|
||||
if checkpoint is not MISSING:
|
||||
empty.value = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value]) -> bool:
|
||||
if not values:
|
||||
return False
|
||||
if self.value is MISSING:
|
||||
self.value = values[0]
|
||||
values = values[1:]
|
||||
seen_overwrite: bool = False
|
||||
for value in values:
|
||||
is_overwrite, overwrite_value = _get_overwrite(value)
|
||||
if is_overwrite:
|
||||
if seen_overwrite:
|
||||
msg = create_error_message(
|
||||
message="Can receive only one Overwrite value per super-step.",
|
||||
error_code=ErrorCode.INVALID_CONCURRENT_GRAPH_UPDATE,
|
||||
)
|
||||
raise InvalidUpdateError(msg)
|
||||
self.value = overwrite_value
|
||||
seen_overwrite = True
|
||||
continue
|
||||
if not seen_overwrite:
|
||||
self.value = self.operator(self.value, value)
|
||||
return True
|
||||
|
||||
def get(self) -> Value:
|
||||
if self.value is MISSING:
|
||||
raise EmptyChannelError()
|
||||
return self.value
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.value is not MISSING
|
||||
|
||||
def checkpoint(self) -> Value:
|
||||
return self.value
|
||||
79
venv/Lib/site-packages/langgraph/channels/ephemeral_value.py
Normal file
79
venv/Lib/site-packages/langgraph/channels/ephemeral_value.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel, Value
|
||||
from langgraph.errors import EmptyChannelError, InvalidUpdateError
|
||||
|
||||
__all__ = ("EphemeralValue",)
|
||||
|
||||
|
||||
class EphemeralValue(Generic[Value], BaseChannel[Value, Value, Value]):
|
||||
"""Stores the value received in the step immediately preceding, clears after."""
|
||||
|
||||
__slots__ = ("value", "guard")
|
||||
|
||||
value: Value | Any
|
||||
guard: bool
|
||||
|
||||
def __init__(self, typ: Any, guard: bool = True) -> None:
|
||||
super().__init__(typ)
|
||||
self.guard = guard
|
||||
self.value = MISSING
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, EphemeralValue) and value.guard == self.guard
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.guard)
|
||||
empty.key = self.key
|
||||
empty.value = self.value
|
||||
return empty
|
||||
|
||||
def from_checkpoint(self, checkpoint: Value) -> Self:
|
||||
empty = self.__class__(self.typ, self.guard)
|
||||
empty.key = self.key
|
||||
if checkpoint is not MISSING:
|
||||
empty.value = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value]) -> bool:
|
||||
if len(values) == 0:
|
||||
if self.value is not MISSING:
|
||||
self.value = MISSING
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
if len(values) != 1 and self.guard:
|
||||
raise InvalidUpdateError(
|
||||
f"At key '{self.key}': EphemeralValue(guard=True) can receive only one value per step. Use guard=False if you want to store any one of multiple values."
|
||||
)
|
||||
|
||||
self.value = values[-1]
|
||||
return True
|
||||
|
||||
def get(self) -> Value:
|
||||
if self.value is MISSING:
|
||||
raise EmptyChannelError()
|
||||
return self.value
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.value is not MISSING
|
||||
|
||||
def checkpoint(self) -> Value:
|
||||
return self.value
|
||||
151
venv/Lib/site-packages/langgraph/channels/last_value.py
Normal file
151
venv/Lib/site-packages/langgraph/channels/last_value.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel, Value
|
||||
from langgraph.errors import (
|
||||
EmptyChannelError,
|
||||
ErrorCode,
|
||||
InvalidUpdateError,
|
||||
create_error_message,
|
||||
)
|
||||
|
||||
__all__ = ("LastValue", "LastValueAfterFinish")
|
||||
|
||||
|
||||
class LastValue(Generic[Value], BaseChannel[Value, Value, Value]):
|
||||
"""Stores the last value received, can receive at most one value per step."""
|
||||
|
||||
__slots__ = ("value",)
|
||||
|
||||
value: Value | Any
|
||||
|
||||
def __init__(self, typ: Any, key: str = "") -> None:
|
||||
super().__init__(typ, key)
|
||||
self.value = MISSING
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, LastValue)
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.key)
|
||||
empty.value = self.value
|
||||
return empty
|
||||
|
||||
def from_checkpoint(self, checkpoint: Value) -> Self:
|
||||
empty = self.__class__(self.typ, self.key)
|
||||
if checkpoint is not MISSING:
|
||||
empty.value = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value]) -> bool:
|
||||
if len(values) == 0:
|
||||
return False
|
||||
if len(values) != 1:
|
||||
msg = create_error_message(
|
||||
message=f"At key '{self.key}': Can receive only one value per step. Use an Annotated key to handle multiple values.",
|
||||
error_code=ErrorCode.INVALID_CONCURRENT_GRAPH_UPDATE,
|
||||
)
|
||||
raise InvalidUpdateError(msg)
|
||||
|
||||
self.value = values[-1]
|
||||
return True
|
||||
|
||||
def get(self) -> Value:
|
||||
if self.value is MISSING:
|
||||
raise EmptyChannelError()
|
||||
return self.value
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.value is not MISSING
|
||||
|
||||
def checkpoint(self) -> Value:
|
||||
return self.value
|
||||
|
||||
|
||||
class LastValueAfterFinish(
|
||||
Generic[Value], BaseChannel[Value, Value, tuple[Value, bool]]
|
||||
):
|
||||
"""Stores the last value received, but only made available after finish().
|
||||
Once made available, clears the value."""
|
||||
|
||||
__slots__ = ("value", "finished")
|
||||
|
||||
value: Value | Any
|
||||
finished: bool
|
||||
|
||||
def __init__(self, typ: Any, key: str = "") -> None:
|
||||
super().__init__(typ, key)
|
||||
self.value = MISSING
|
||||
self.finished = False
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, LastValueAfterFinish)
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def checkpoint(self) -> tuple[Value | Any, bool] | Any:
|
||||
if self.value is MISSING:
|
||||
return MISSING
|
||||
return (self.value, self.finished)
|
||||
|
||||
def from_checkpoint(self, checkpoint: tuple[Value | Any, bool] | Any) -> Self:
|
||||
empty = self.__class__(self.typ)
|
||||
empty.key = self.key
|
||||
if checkpoint is not MISSING:
|
||||
empty.value, empty.finished = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value | Any]) -> bool:
|
||||
if len(values) == 0:
|
||||
return False
|
||||
|
||||
self.finished = False
|
||||
self.value = values[-1]
|
||||
return True
|
||||
|
||||
def consume(self) -> bool:
|
||||
if self.finished:
|
||||
self.finished = False
|
||||
self.value = MISSING
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def finish(self) -> bool:
|
||||
if not self.finished and self.value is not MISSING:
|
||||
self.finished = True
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def get(self) -> Value:
|
||||
if self.value is MISSING or not self.finished:
|
||||
raise EmptyChannelError()
|
||||
return self.value
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.value is not MISSING and self.finished
|
||||
167
venv/Lib/site-packages/langgraph/channels/named_barrier_value.py
Normal file
167
venv/Lib/site-packages/langgraph/channels/named_barrier_value.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Generic
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel, Value
|
||||
from langgraph.errors import EmptyChannelError, InvalidUpdateError
|
||||
|
||||
__all__ = ("NamedBarrierValue", "NamedBarrierValueAfterFinish")
|
||||
|
||||
|
||||
class NamedBarrierValue(Generic[Value], BaseChannel[Value, Value, set[Value]]):
|
||||
"""A channel that waits until all named values are received before making the value available."""
|
||||
|
||||
__slots__ = ("names", "seen")
|
||||
|
||||
names: set[Value]
|
||||
seen: set[Value]
|
||||
|
||||
def __init__(self, typ: type[Value], names: set[Value]) -> None:
|
||||
super().__init__(typ)
|
||||
self.names = names
|
||||
self.seen: set[str] = set()
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, NamedBarrierValue) and value.names == self.names
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.names)
|
||||
empty.key = self.key
|
||||
empty.seen = self.seen.copy()
|
||||
return empty
|
||||
|
||||
def checkpoint(self) -> set[Value]:
|
||||
return self.seen
|
||||
|
||||
def from_checkpoint(self, checkpoint: set[Value]) -> Self:
|
||||
empty = self.__class__(self.typ, self.names)
|
||||
empty.key = self.key
|
||||
if checkpoint is not MISSING:
|
||||
empty.seen = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value]) -> bool:
|
||||
updated = False
|
||||
for value in values:
|
||||
if value in self.names:
|
||||
if value not in self.seen:
|
||||
self.seen.add(value)
|
||||
updated = True
|
||||
else:
|
||||
raise InvalidUpdateError(
|
||||
f"At key '{self.key}': Value {value} not in {self.names}"
|
||||
)
|
||||
return updated
|
||||
|
||||
def get(self) -> Value:
|
||||
if self.seen != self.names:
|
||||
raise EmptyChannelError()
|
||||
return None
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.seen == self.names
|
||||
|
||||
def consume(self) -> bool:
|
||||
if self.seen == self.names:
|
||||
self.seen = set()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class NamedBarrierValueAfterFinish(
|
||||
Generic[Value], BaseChannel[Value, Value, set[Value]]
|
||||
):
|
||||
"""A channel that waits until all named values are received before making the value ready to be made available. It is only made available after finish() is called."""
|
||||
|
||||
__slots__ = ("names", "seen", "finished")
|
||||
|
||||
names: set[Value]
|
||||
seen: set[Value]
|
||||
|
||||
def __init__(self, typ: type[Value], names: set[Value]) -> None:
|
||||
super().__init__(typ)
|
||||
self.names = names
|
||||
self.seen: set[str] = set()
|
||||
self.finished = False
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return (
|
||||
isinstance(value, NamedBarrierValueAfterFinish)
|
||||
and value.names == self.names
|
||||
)
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.names)
|
||||
empty.key = self.key
|
||||
empty.seen = self.seen.copy()
|
||||
empty.finished = self.finished
|
||||
return empty
|
||||
|
||||
def checkpoint(self) -> tuple[set[Value], bool]:
|
||||
return (self.seen, self.finished)
|
||||
|
||||
def from_checkpoint(self, checkpoint: tuple[set[Value], bool]) -> Self:
|
||||
empty = self.__class__(self.typ, self.names)
|
||||
empty.key = self.key
|
||||
if checkpoint is not MISSING:
|
||||
empty.seen, empty.finished = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value]) -> bool:
|
||||
updated = False
|
||||
for value in values:
|
||||
if value in self.names:
|
||||
if value not in self.seen:
|
||||
self.seen.add(value)
|
||||
updated = True
|
||||
else:
|
||||
raise InvalidUpdateError(
|
||||
f"At key '{self.key}': Value {value} not in {self.names}"
|
||||
)
|
||||
return updated
|
||||
|
||||
def get(self) -> Value:
|
||||
if not self.finished or self.seen != self.names:
|
||||
raise EmptyChannelError()
|
||||
return None
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.finished and self.seen == self.names
|
||||
|
||||
def consume(self) -> bool:
|
||||
if self.finished and self.seen == self.names:
|
||||
self.finished = False
|
||||
self.seen = set()
|
||||
return True
|
||||
return False
|
||||
|
||||
def finish(self) -> bool:
|
||||
if not self.finished and self.seen == self.names:
|
||||
self.finished = True
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
94
venv/Lib/site-packages/langgraph/channels/topic.py
Normal file
94
venv/Lib/site-packages/langgraph/channels/topic.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Sequence
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel, Value
|
||||
from langgraph.errors import EmptyChannelError
|
||||
|
||||
__all__ = ("Topic",)
|
||||
|
||||
|
||||
def _flatten(values: Sequence[Value | list[Value]]) -> Iterator[Value]:
|
||||
for value in values:
|
||||
if isinstance(value, list):
|
||||
yield from value
|
||||
else:
|
||||
yield value
|
||||
|
||||
|
||||
class Topic(
|
||||
Generic[Value],
|
||||
BaseChannel[Sequence[Value], Value | list[Value], list[Value]],
|
||||
):
|
||||
"""A configurable PubSub Topic.
|
||||
|
||||
Args:
|
||||
typ: The type of the value stored in the channel.
|
||||
accumulate: Whether to accumulate values across steps. If `False`, the channel will be emptied after each step.
|
||||
"""
|
||||
|
||||
__slots__ = ("values", "accumulate")
|
||||
|
||||
def __init__(self, typ: type[Value], accumulate: bool = False) -> None:
|
||||
super().__init__(typ)
|
||||
# attrs
|
||||
self.accumulate = accumulate
|
||||
# state
|
||||
self.values = list[Value]()
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, Topic) and value.accumulate == self.accumulate
|
||||
|
||||
@property
|
||||
def ValueType(self) -> Any:
|
||||
"""The type of the value stored in the channel."""
|
||||
return Sequence[self.typ] # type: ignore[name-defined]
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> Any:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ | list[self.typ] # type: ignore[name-defined]
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.accumulate)
|
||||
empty.key = self.key
|
||||
empty.values = self.values.copy()
|
||||
return empty
|
||||
|
||||
def checkpoint(self) -> list[Value]:
|
||||
return self.values
|
||||
|
||||
def from_checkpoint(self, checkpoint: list[Value]) -> Self:
|
||||
empty = self.__class__(self.typ, self.accumulate)
|
||||
empty.key = self.key
|
||||
if checkpoint is not MISSING:
|
||||
if isinstance(checkpoint, tuple):
|
||||
# backwards compatibility
|
||||
empty.values = checkpoint[1]
|
||||
else:
|
||||
empty.values = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value | list[Value]]) -> bool:
|
||||
updated = False
|
||||
if not self.accumulate:
|
||||
updated = bool(self.values)
|
||||
self.values = list[Value]()
|
||||
if flat_values := tuple(_flatten(values)):
|
||||
updated = True
|
||||
self.values.extend(flat_values)
|
||||
return updated
|
||||
|
||||
def get(self) -> Sequence[Value]:
|
||||
if self.values:
|
||||
return list(self.values)
|
||||
else:
|
||||
raise EmptyChannelError
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return bool(self.values)
|
||||
73
venv/Lib/site-packages/langgraph/channels/untracked_value.py
Normal file
73
venv/Lib/site-packages/langgraph/channels/untracked_value.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel, Value
|
||||
from langgraph.errors import EmptyChannelError, InvalidUpdateError
|
||||
|
||||
__all__ = ("UntrackedValue",)
|
||||
|
||||
|
||||
class UntrackedValue(Generic[Value], BaseChannel[Value, Value, Value]):
|
||||
"""Stores the last value received, never checkpointed."""
|
||||
|
||||
__slots__ = ("value", "guard")
|
||||
|
||||
guard: bool
|
||||
value: Value | Any
|
||||
|
||||
def __init__(self, typ: type[Value], guard: bool = True) -> None:
|
||||
super().__init__(typ)
|
||||
self.guard = guard
|
||||
self.value = MISSING
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, UntrackedValue) and value.guard == self.guard
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.guard)
|
||||
empty.key = self.key
|
||||
empty.value = self.value
|
||||
return empty
|
||||
|
||||
def checkpoint(self) -> Value | Any:
|
||||
return MISSING
|
||||
|
||||
def from_checkpoint(self, checkpoint: Value) -> Self:
|
||||
empty = self.__class__(self.typ, self.guard)
|
||||
empty.key = self.key
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value]) -> bool:
|
||||
if len(values) == 0:
|
||||
return False
|
||||
if len(values) != 1 and self.guard:
|
||||
raise InvalidUpdateError(
|
||||
f"At key '{self.key}': UntrackedValue(guard=True) can receive only one value per step. Use guard=False if you want to store any one of multiple values."
|
||||
)
|
||||
|
||||
self.value = values[-1]
|
||||
return True
|
||||
|
||||
def get(self) -> Value:
|
||||
if self.value is MISSING:
|
||||
raise EmptyChannelError()
|
||||
return self.value
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.value is not MISSING
|
||||
Reference in New Issue
Block a user