initial commit
This commit is contained in:
250
venv/Lib/site-packages/langgraph/pregel/_messages.py
Normal file
250
venv/Lib/site-packages/langgraph/pregel/_messages.py
Normal file
@@ -0,0 +1,250 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
|
||||
from dataclasses import fields, is_dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langgraph._internal._constants import NS_SEP
|
||||
from langgraph.constants import TAG_HIDDEN, TAG_NOSTREAM
|
||||
from langgraph.pregel.protocol import StreamChunk
|
||||
from langgraph.types import Command
|
||||
|
||||
try:
|
||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||
except ImportError:
|
||||
_StreamingCallbackHandler = object # type: ignore
|
||||
|
||||
T = TypeVar("T")
|
||||
Meta = tuple[tuple[str, ...], dict[str, Any]]
|
||||
|
||||
|
||||
def _state_values(obj: Any) -> Sequence[Any]:
|
||||
"""Extract top-level field values from a state object (dict, BaseModel, or dataclass)."""
|
||||
if isinstance(obj, dict):
|
||||
return list(obj.values())
|
||||
elif isinstance(obj, BaseModel):
|
||||
return [getattr(obj, k) for k in type(obj).model_fields]
|
||||
elif is_dataclass(obj) and not isinstance(obj, type):
|
||||
return [getattr(obj, f.name) for f in fields(obj)]
|
||||
return ()
|
||||
|
||||
|
||||
class StreamMessagesHandler(BaseCallbackHandler, _StreamingCallbackHandler):
|
||||
"""A callback handler that implements stream_mode=messages.
|
||||
|
||||
Collects messages from:
|
||||
(1) chat model stream events; and
|
||||
(2) node outputs.
|
||||
"""
|
||||
|
||||
run_inline = True
|
||||
"""We want this callback to run in the main thread to avoid order/locking issues."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream: Callable[[StreamChunk], None],
|
||||
subgraphs: bool,
|
||||
*,
|
||||
parent_ns: tuple[str, ...] | None = None,
|
||||
) -> None:
|
||||
"""Configure the handler to stream messages from LLMs and nodes.
|
||||
|
||||
Args:
|
||||
stream: A callable that takes a StreamChunk and emits it.
|
||||
subgraphs: Whether to emit messages from subgraphs.
|
||||
parent_ns: The namespace where the handler was created.
|
||||
We keep track of this namespace to allow calls to subgraphs that
|
||||
were explicitly requested as a stream with `messages` mode
|
||||
configured.
|
||||
|
||||
Example:
|
||||
parent_ns is used to handle scenarios where the subgraph is explicitly
|
||||
streamed with `stream_mode="messages"`.
|
||||
|
||||
```python
|
||||
def parent_graph_node():
|
||||
# This node is in the parent graph.
|
||||
async for event in some_subgraph(..., stream_mode="messages"):
|
||||
do something with event # <-- these events will be emitted
|
||||
return ...
|
||||
|
||||
parent_graph.invoke(subgraphs=False)
|
||||
```
|
||||
"""
|
||||
self.stream = stream
|
||||
self.subgraphs = subgraphs
|
||||
self.metadata: dict[UUID, Meta] = {}
|
||||
self.seen: set[int | str] = set()
|
||||
self.parent_ns = parent_ns
|
||||
|
||||
def _emit(self, meta: Meta, message: BaseMessage, *, dedupe: bool = False) -> None:
|
||||
if dedupe and message.id in self.seen:
|
||||
return
|
||||
else:
|
||||
if message.id is None:
|
||||
message.id = str(uuid4())
|
||||
self.seen.add(message.id)
|
||||
self.stream((meta[0], "messages", (message, meta[1])))
|
||||
|
||||
def _find_and_emit_messages(self, meta: Meta, response: Any) -> None:
|
||||
if isinstance(response, BaseMessage):
|
||||
self._emit(meta, response, dedupe=True)
|
||||
elif isinstance(response, Sequence):
|
||||
for value in response:
|
||||
if isinstance(value, BaseMessage):
|
||||
self._emit(meta, value, dedupe=True)
|
||||
else:
|
||||
for value in _state_values(response):
|
||||
if isinstance(value, BaseMessage):
|
||||
self._emit(meta, value, dedupe=True)
|
||||
elif isinstance(value, Sequence):
|
||||
for item in value:
|
||||
if isinstance(item, BaseMessage):
|
||||
self._emit(meta, item, dedupe=True)
|
||||
|
||||
def tap_output_aiter(
|
||||
self, run_id: UUID, output: AsyncIterator[T]
|
||||
) -> AsyncIterator[T]:
|
||||
return output
|
||||
|
||||
def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
|
||||
return output
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if metadata and (not tags or (TAG_NOSTREAM not in tags)):
|
||||
ns = tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP))[
|
||||
:-1
|
||||
]
|
||||
if not self.subgraphs and len(ns) > 0 and ns != self.parent_ns:
|
||||
return
|
||||
if tags:
|
||||
if filtered_tags := [t for t in tags if not t.startswith("seq:step")]:
|
||||
metadata["tags"] = filtered_tags
|
||||
self.metadata[run_id] = (ns, metadata)
|
||||
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: ChatGenerationChunk | None = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
tags: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if not isinstance(chunk, ChatGenerationChunk):
|
||||
return
|
||||
if meta := self.metadata.get(run_id):
|
||||
self._emit(meta, chunk.message)
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if meta := self.metadata.get(run_id):
|
||||
if response.generations and response.generations[0]:
|
||||
gen = response.generations[0][0]
|
||||
if isinstance(gen, ChatGeneration):
|
||||
self._emit(meta, gen.message, dedupe=True)
|
||||
self.metadata.pop(run_id, None)
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.metadata.pop(run_id, None)
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
inputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if (
|
||||
metadata
|
||||
and kwargs.get("name") == metadata.get("langgraph_node")
|
||||
and (not tags or TAG_HIDDEN not in tags)
|
||||
):
|
||||
ns = tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP))[
|
||||
:-1
|
||||
]
|
||||
if not self.subgraphs and len(ns) > 0:
|
||||
return
|
||||
self.metadata[run_id] = (ns, metadata)
|
||||
for value in _state_values(inputs):
|
||||
if isinstance(value, BaseMessage):
|
||||
if value.id is not None:
|
||||
self.seen.add(value.id)
|
||||
elif isinstance(value, Sequence) and not isinstance(value, str):
|
||||
for item in value:
|
||||
if isinstance(item, BaseMessage):
|
||||
if item.id is not None:
|
||||
self.seen.add(item.id)
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
response: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if meta := self.metadata.pop(run_id, None):
|
||||
# Handle Command node updates
|
||||
if isinstance(response, Command):
|
||||
self._find_and_emit_messages(meta, response.update)
|
||||
# Handle list of Command updates
|
||||
elif isinstance(response, Sequence) and any(
|
||||
isinstance(value, Command) for value in response
|
||||
):
|
||||
for value in response:
|
||||
if isinstance(value, Command):
|
||||
self._find_and_emit_messages(meta, value.update)
|
||||
else:
|
||||
self._find_and_emit_messages(meta, value)
|
||||
# Handle basic updates / streaming
|
||||
else:
|
||||
self._find_and_emit_messages(meta, response)
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.metadata.pop(run_id, None)
|
||||
Reference in New Issue
Block a user