initial commit
This commit is contained in:
174
venv/Lib/site-packages/langgraph/pregel/_io.py
Normal file
174
venv/Lib/site-packages/langgraph/pregel/_io.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter
|
||||
from collections.abc import Iterator, Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from langgraph._internal._constants import (
|
||||
ERROR,
|
||||
INTERRUPT,
|
||||
NULL_TASK_ID,
|
||||
RESUME,
|
||||
RETURN,
|
||||
TASKS,
|
||||
)
|
||||
from langgraph._internal._typing import EMPTY_SEQ, MISSING
|
||||
from langgraph.channels.base import BaseChannel, EmptyChannelError
|
||||
from langgraph.constants import START, TAG_HIDDEN
|
||||
from langgraph.errors import InvalidUpdateError
|
||||
from langgraph.pregel._log import logger
|
||||
from langgraph.types import Command, PregelExecutableTask, Send
|
||||
|
||||
|
||||
def read_channel(
|
||||
channels: Mapping[str, BaseChannel],
|
||||
chan: str,
|
||||
*,
|
||||
catch: bool = True,
|
||||
) -> Any:
|
||||
try:
|
||||
return channels[chan].get()
|
||||
except EmptyChannelError:
|
||||
if catch:
|
||||
return None
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def read_channels(
|
||||
channels: Mapping[str, BaseChannel],
|
||||
select: Sequence[str] | str,
|
||||
*,
|
||||
skip_empty: bool = True,
|
||||
) -> dict[str, Any] | Any:
|
||||
if isinstance(select, str):
|
||||
return read_channel(channels, select)
|
||||
else:
|
||||
values: dict[str, Any] = {}
|
||||
for k in select:
|
||||
try:
|
||||
values[k] = read_channel(channels, k, catch=not skip_empty)
|
||||
except EmptyChannelError:
|
||||
pass
|
||||
return values
|
||||
|
||||
|
||||
def map_command(cmd: Command) -> Iterator[tuple[str, str, Any]]:
|
||||
"""Map input chunk to a sequence of pending writes in the form (channel, value)."""
|
||||
if cmd.graph == Command.PARENT:
|
||||
raise InvalidUpdateError("There is no parent graph")
|
||||
if cmd.goto:
|
||||
if isinstance(cmd.goto, (tuple, list)):
|
||||
sends = cmd.goto
|
||||
else:
|
||||
sends = [cmd.goto]
|
||||
for send in sends:
|
||||
if isinstance(send, Send):
|
||||
yield (NULL_TASK_ID, TASKS, send)
|
||||
elif isinstance(send, str):
|
||||
yield (NULL_TASK_ID, f"branch:to:{send}", START)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"In Command.goto, expected Send/str, got {type(send).__name__}"
|
||||
)
|
||||
if cmd.resume is not None:
|
||||
yield (NULL_TASK_ID, RESUME, cmd.resume)
|
||||
if cmd.update:
|
||||
for k, v in cmd._update_as_tuples():
|
||||
yield (NULL_TASK_ID, k, v)
|
||||
|
||||
|
||||
def map_input(
|
||||
input_channels: str | Sequence[str],
|
||||
chunk: dict[str, Any] | Any | None,
|
||||
) -> Iterator[tuple[str, Any]]:
|
||||
"""Map input chunk to a sequence of pending writes in the form (channel, value)."""
|
||||
if chunk is None:
|
||||
return
|
||||
elif isinstance(input_channels, str):
|
||||
yield (input_channels, chunk)
|
||||
else:
|
||||
if not isinstance(chunk, dict):
|
||||
raise TypeError(f"Expected chunk to be a dict, got {type(chunk).__name__}")
|
||||
for k in chunk:
|
||||
if k in input_channels:
|
||||
yield (k, chunk[k])
|
||||
else:
|
||||
logger.warning(f"Input channel {k} not found in {input_channels}")
|
||||
|
||||
|
||||
def map_output_values(
|
||||
output_channels: str | Sequence[str],
|
||||
pending_writes: Literal[True] | Sequence[tuple[str, Any]],
|
||||
channels: Mapping[str, BaseChannel],
|
||||
) -> Iterator[dict[str, Any] | Any]:
|
||||
"""Map pending writes (a sequence of tuples (channel, value)) to output chunk."""
|
||||
if isinstance(output_channels, str):
|
||||
if pending_writes is True or any(
|
||||
chan == output_channels for chan, _ in pending_writes
|
||||
):
|
||||
yield read_channel(channels, output_channels)
|
||||
else:
|
||||
if pending_writes is True or {
|
||||
c for c, _ in pending_writes if c in output_channels
|
||||
}:
|
||||
yield read_channels(channels, output_channels)
|
||||
|
||||
|
||||
def map_output_updates(
|
||||
output_channels: str | Sequence[str],
|
||||
tasks: list[tuple[PregelExecutableTask, Sequence[tuple[str, Any]]]],
|
||||
cached: bool = False,
|
||||
) -> Iterator[dict[str, Any | dict[str, Any]]]:
|
||||
"""Map pending writes (a sequence of tuples (channel, value)) to output chunk."""
|
||||
output_tasks = [
|
||||
(t, ww)
|
||||
for t, ww in tasks
|
||||
if (not t.config or TAG_HIDDEN not in t.config.get("tags", EMPTY_SEQ))
|
||||
and ww[0][0] != ERROR
|
||||
and ww[0][0] != INTERRUPT
|
||||
]
|
||||
if not output_tasks:
|
||||
return
|
||||
updated: list[tuple[str, Any]] = []
|
||||
for task, writes in output_tasks:
|
||||
rtn = next((value for chan, value in writes if chan == RETURN), MISSING)
|
||||
if rtn is not MISSING:
|
||||
updated.append((task.name, rtn))
|
||||
elif isinstance(output_channels, str):
|
||||
updated.extend(
|
||||
(task.name, value) for chan, value in writes if chan == output_channels
|
||||
)
|
||||
elif any(chan in output_channels for chan, _ in writes):
|
||||
counts = Counter(chan for chan, _ in writes)
|
||||
if any(counts[chan] > 1 for chan in output_channels):
|
||||
updated.extend(
|
||||
(
|
||||
task.name,
|
||||
{chan: value},
|
||||
)
|
||||
for chan, value in writes
|
||||
if chan in output_channels
|
||||
)
|
||||
else:
|
||||
updated.append(
|
||||
(
|
||||
task.name,
|
||||
{
|
||||
chan: value
|
||||
for chan, value in writes
|
||||
if chan in output_channels
|
||||
},
|
||||
)
|
||||
)
|
||||
grouped: dict[str, Any] = {t.name: [] for t, _ in output_tasks}
|
||||
for node, value in updated:
|
||||
grouped[node].append(value)
|
||||
for node, value in grouped.items():
|
||||
if len(value) == 0:
|
||||
grouped[node] = None
|
||||
if len(value) == 1:
|
||||
grouped[node] = value[0]
|
||||
if cached:
|
||||
grouped["__metadata__"] = {"cached": cached}
|
||||
yield grouped
|
||||
Reference in New Issue
Block a user