initial commit
This commit is contained in:
134
venv/Lib/site-packages/langgraph/channels/binop.py
Normal file
134
venv/Lib/site-packages/langgraph/channels/binop.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import collections.abc
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import NotRequired, Required, Self
|
||||
|
||||
from langgraph._internal._constants import OVERWRITE
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.channels.base import BaseChannel, Value
|
||||
from langgraph.errors import (
|
||||
EmptyChannelError,
|
||||
ErrorCode,
|
||||
InvalidUpdateError,
|
||||
create_error_message,
|
||||
)
|
||||
from langgraph.types import Overwrite
|
||||
|
||||
__all__ = ("BinaryOperatorAggregate",)
|
||||
|
||||
|
||||
# Adapted from typing_extensions
|
||||
def _strip_extras(t): # type: ignore[no-untyped-def]
|
||||
"""Strips Annotated, Required and NotRequired from a given type."""
|
||||
if hasattr(t, "__origin__"):
|
||||
return _strip_extras(t.__origin__)
|
||||
if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired):
|
||||
return _strip_extras(t.__args__[0])
|
||||
|
||||
return t
|
||||
|
||||
|
||||
def _get_overwrite(value: Any) -> tuple[bool, Any]:
|
||||
"""Inspects the given value and returns (is_overwrite, overwrite_value)."""
|
||||
if isinstance(value, Overwrite):
|
||||
return True, value.value
|
||||
if isinstance(value, dict) and set(value.keys()) == {OVERWRITE}:
|
||||
return True, value[OVERWRITE]
|
||||
return False, None
|
||||
|
||||
|
||||
class BinaryOperatorAggregate(Generic[Value], BaseChannel[Value, Value, Value]):
|
||||
"""Stores the result of applying a binary operator to the current value and each new value.
|
||||
|
||||
```python
|
||||
import operator
|
||||
|
||||
total = Channels.BinaryOperatorAggregate(int, operator.add)
|
||||
```
|
||||
"""
|
||||
|
||||
__slots__ = ("value", "operator")
|
||||
|
||||
def __init__(self, typ: type[Value], operator: Callable[[Value, Value], Value]):
|
||||
super().__init__(typ)
|
||||
self.operator = operator
|
||||
# special forms from typing or collections.abc are not instantiable
|
||||
# so we need to replace them with their concrete counterparts
|
||||
typ = _strip_extras(typ)
|
||||
if typ in (collections.abc.Sequence, collections.abc.MutableSequence):
|
||||
typ = list
|
||||
if typ in (collections.abc.Set, collections.abc.MutableSet):
|
||||
typ = set
|
||||
if typ in (collections.abc.Mapping, collections.abc.MutableMapping):
|
||||
typ = dict
|
||||
try:
|
||||
self.value = typ()
|
||||
except Exception:
|
||||
self.value = MISSING
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(value, BinaryOperatorAggregate) and (
|
||||
value.operator is self.operator
|
||||
if value.operator.__name__ != "<lambda>"
|
||||
and self.operator.__name__ != "<lambda>"
|
||||
else True
|
||||
)
|
||||
|
||||
@property
|
||||
def ValueType(self) -> type[Value]:
|
||||
"""The type of the value stored in the channel."""
|
||||
return self.typ
|
||||
|
||||
@property
|
||||
def UpdateType(self) -> type[Value]:
|
||||
"""The type of the update received by the channel."""
|
||||
return self.typ
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the channel."""
|
||||
empty = self.__class__(self.typ, self.operator)
|
||||
empty.key = self.key
|
||||
empty.value = self.value
|
||||
return empty
|
||||
|
||||
def from_checkpoint(self, checkpoint: Value) -> Self:
|
||||
empty = self.__class__(self.typ, self.operator)
|
||||
empty.key = self.key
|
||||
if checkpoint is not MISSING:
|
||||
empty.value = checkpoint
|
||||
return empty
|
||||
|
||||
def update(self, values: Sequence[Value]) -> bool:
|
||||
if not values:
|
||||
return False
|
||||
if self.value is MISSING:
|
||||
self.value = values[0]
|
||||
values = values[1:]
|
||||
seen_overwrite: bool = False
|
||||
for value in values:
|
||||
is_overwrite, overwrite_value = _get_overwrite(value)
|
||||
if is_overwrite:
|
||||
if seen_overwrite:
|
||||
msg = create_error_message(
|
||||
message="Can receive only one Overwrite value per super-step.",
|
||||
error_code=ErrorCode.INVALID_CONCURRENT_GRAPH_UPDATE,
|
||||
)
|
||||
raise InvalidUpdateError(msg)
|
||||
self.value = overwrite_value
|
||||
seen_overwrite = True
|
||||
continue
|
||||
if not seen_overwrite:
|
||||
self.value = self.operator(self.value, value)
|
||||
return True
|
||||
|
||||
def get(self) -> Value:
|
||||
if self.value is MISSING:
|
||||
raise EmptyChannelError()
|
||||
return self.value
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self.value is not MISSING
|
||||
|
||||
def checkpoint(self) -> Value:
|
||||
return self.value
|
||||
Reference in New Issue
Block a user