initial commit
This commit is contained in:
12
venv/Lib/site-packages/langgraph/graph/__init__.py
Normal file
12
venv/Lib/site-packages/langgraph/graph/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph.message import MessageGraph, MessagesState, add_messages
|
||||
from langgraph.graph.state import StateGraph
|
||||
|
||||
__all__ = (
|
||||
"END",
|
||||
"START",
|
||||
"StateGraph",
|
||||
"add_messages",
|
||||
"MessagesState",
|
||||
"MessageGraph",
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
225
venv/Lib/site-packages/langgraph/graph/_branch.py
Normal file
225
venv/Lib/site-packages/langgraph/graph/_branch.py
Normal file
@@ -0,0 +1,225 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Hashable, Sequence
|
||||
from inspect import (
|
||||
isfunction,
|
||||
ismethod,
|
||||
signature,
|
||||
)
|
||||
from itertools import zip_longest
|
||||
from types import FunctionType
|
||||
from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableLambda,
|
||||
)
|
||||
|
||||
from langgraph._internal._runnable import (
|
||||
RunnableCallable,
|
||||
)
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.errors import InvalidUpdateError
|
||||
from langgraph.pregel._write import PASSTHROUGH, ChannelWrite, ChannelWriteEntry
|
||||
from langgraph.types import Send
|
||||
|
||||
_Writer = Callable[
|
||||
[Sequence[str | Send], bool],
|
||||
Sequence[ChannelWriteEntry | Send],
|
||||
]
|
||||
|
||||
|
||||
def _get_branch_path_input_schema(
|
||||
path: Callable[..., Hashable | Sequence[Hashable]]
|
||||
| Callable[..., Awaitable[Hashable | Sequence[Hashable]]]
|
||||
| Runnable[Any, Hashable | Sequence[Hashable]],
|
||||
) -> type[Any] | None:
|
||||
input = None
|
||||
# detect input schema annotation in the branch callable
|
||||
try:
|
||||
callable_: (
|
||||
Callable[..., Hashable | Sequence[Hashable]]
|
||||
| Callable[..., Awaitable[Hashable | Sequence[Hashable]]]
|
||||
| None
|
||||
) = None
|
||||
if isinstance(path, (RunnableCallable, RunnableLambda)):
|
||||
if isfunction(path.func) or ismethod(path.func):
|
||||
callable_ = path.func
|
||||
elif (callable_method := getattr(path.func, "__call__", None)) and ismethod(
|
||||
callable_method
|
||||
):
|
||||
callable_ = callable_method
|
||||
elif isfunction(path.afunc) or ismethod(path.afunc):
|
||||
callable_ = path.afunc
|
||||
elif (
|
||||
callable_method := getattr(path.afunc, "__call__", None)
|
||||
) and ismethod(callable_method):
|
||||
callable_ = callable_method
|
||||
elif callable(path):
|
||||
callable_ = path
|
||||
|
||||
if callable_ is not None and (hints := get_type_hints(callable_)):
|
||||
first_parameter_name = next(
|
||||
iter(signature(cast(FunctionType, callable_)).parameters.keys())
|
||||
)
|
||||
if input_hint := hints.get(first_parameter_name):
|
||||
if isinstance(input_hint, type) and get_type_hints(input_hint):
|
||||
input = input_hint
|
||||
except (TypeError, StopIteration):
|
||||
pass
|
||||
|
||||
return input
|
||||
|
||||
|
||||
class BranchSpec(NamedTuple):
|
||||
path: Runnable[Any, Hashable | list[Hashable]]
|
||||
ends: dict[Hashable, str] | None
|
||||
input_schema: type[Any] | None = None
|
||||
|
||||
@classmethod
|
||||
def from_path(
|
||||
cls,
|
||||
path: Runnable[Any, Hashable | list[Hashable]],
|
||||
path_map: dict[Hashable, str] | list[str] | None,
|
||||
infer_schema: bool = False,
|
||||
) -> BranchSpec:
|
||||
# coerce path_map to a dictionary
|
||||
path_map_: dict[Hashable, str] | None = None
|
||||
try:
|
||||
if isinstance(path_map, dict):
|
||||
path_map_ = path_map.copy()
|
||||
elif isinstance(path_map, list):
|
||||
path_map_ = {name: name for name in path_map}
|
||||
else:
|
||||
# find func
|
||||
func: Callable | None = None
|
||||
if isinstance(path, (RunnableCallable, RunnableLambda)):
|
||||
func = path.func or path.afunc
|
||||
if func is not None:
|
||||
# find callable method
|
||||
if (cal := getattr(path, "__call__", None)) and ismethod(cal):
|
||||
func = cal
|
||||
# get the return type
|
||||
if rtn_type := get_type_hints(func).get("return"):
|
||||
if get_origin(rtn_type) is Literal:
|
||||
path_map_ = {name: name for name in get_args(rtn_type)}
|
||||
except Exception:
|
||||
pass
|
||||
# infer input schema
|
||||
input_schema = _get_branch_path_input_schema(path) if infer_schema else None
|
||||
# create branch
|
||||
return cls(path=path, ends=path_map_, input_schema=input_schema)
|
||||
|
||||
def run(
|
||||
self,
|
||||
writer: _Writer,
|
||||
reader: Callable[[RunnableConfig], Any] | None = None,
|
||||
) -> RunnableCallable:
|
||||
return ChannelWrite.register_writer(
|
||||
RunnableCallable(
|
||||
func=self._route,
|
||||
afunc=self._aroute,
|
||||
writer=writer,
|
||||
reader=reader,
|
||||
name=None,
|
||||
trace=False,
|
||||
),
|
||||
list(
|
||||
zip_longest(
|
||||
writer([e for e in self.ends.values()], True),
|
||||
[str(la) for la, e in self.ends.items()],
|
||||
)
|
||||
)
|
||||
if self.ends
|
||||
else None,
|
||||
)
|
||||
|
||||
def _route(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
reader: Callable[[RunnableConfig], Any] | None,
|
||||
writer: _Writer,
|
||||
) -> Runnable:
|
||||
if reader:
|
||||
value = reader(config)
|
||||
# passthrough additional keys from node to branch
|
||||
# only doable when using dict states
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and isinstance(input, dict)
|
||||
and self.input_schema is None
|
||||
):
|
||||
value = {**input, **value}
|
||||
else:
|
||||
value = input
|
||||
result = self.path.invoke(value, config)
|
||||
return self._finish(writer, input, result, config)
|
||||
|
||||
async def _aroute(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
reader: Callable[[RunnableConfig], Any] | None,
|
||||
writer: _Writer,
|
||||
) -> Runnable:
|
||||
if reader:
|
||||
value = reader(config)
|
||||
# passthrough additional keys from node to branch
|
||||
# only doable when using dict states
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and isinstance(input, dict)
|
||||
and self.input_schema is None
|
||||
):
|
||||
value = {**input, **value}
|
||||
else:
|
||||
value = input
|
||||
result = await self.path.ainvoke(value, config)
|
||||
return self._finish(writer, input, result, config)
|
||||
|
||||
def _finish(
|
||||
self,
|
||||
writer: _Writer,
|
||||
input: Any,
|
||||
result: Any,
|
||||
config: RunnableConfig,
|
||||
) -> Runnable | Any:
|
||||
if not isinstance(result, (list, tuple)):
|
||||
result = [result]
|
||||
if self.ends:
|
||||
destinations: Sequence[Send | str] = [
|
||||
r if isinstance(r, Send) else self.ends[r] for r in result
|
||||
]
|
||||
else:
|
||||
destinations = cast(Sequence[Send | str], result)
|
||||
if any(dest is None or dest == START for dest in destinations):
|
||||
raise ValueError("Branch did not return a valid destination")
|
||||
if any(p.node == END for p in destinations if isinstance(p, Send)):
|
||||
raise InvalidUpdateError("Cannot send a packet to the END node")
|
||||
entries = writer(destinations, False)
|
||||
if not entries:
|
||||
return input
|
||||
else:
|
||||
need_passthrough = False
|
||||
for e in entries:
|
||||
if isinstance(e, ChannelWriteEntry):
|
||||
if e.value is PASSTHROUGH:
|
||||
need_passthrough = True
|
||||
break
|
||||
if need_passthrough:
|
||||
return ChannelWrite(entries)
|
||||
else:
|
||||
ChannelWrite.do_write(config, entries)
|
||||
return input
|
||||
92
venv/Lib/site-packages/langgraph/graph/_node.py
Normal file
92
venv/Lib/site-packages/langgraph/graph/_node.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Protocol, TypeAlias
|
||||
|
||||
from langchain_core.runnables import Runnable, RunnableConfig
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from langgraph._internal._typing import EMPTY_SEQ
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import CachePolicy, RetryPolicy, StreamWriter
|
||||
from langgraph.typing import ContextT, NodeInputT, NodeInputT_contra
|
||||
|
||||
|
||||
class _Node(Protocol[NodeInputT_contra]):
|
||||
def __call__(self, state: NodeInputT_contra) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithConfig(Protocol[NodeInputT_contra]):
|
||||
def __call__(self, state: NodeInputT_contra, config: RunnableConfig) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithWriter(Protocol[NodeInputT_contra]):
|
||||
def __call__(self, state: NodeInputT_contra, *, writer: StreamWriter) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithStore(Protocol[NodeInputT_contra]):
|
||||
def __call__(self, state: NodeInputT_contra, *, store: BaseStore) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithWriterStore(Protocol[NodeInputT_contra]):
|
||||
def __call__(
|
||||
self, state: NodeInputT_contra, *, writer: StreamWriter, store: BaseStore
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithConfigWriter(Protocol[NodeInputT_contra]):
|
||||
def __call__(
|
||||
self, state: NodeInputT_contra, *, config: RunnableConfig, writer: StreamWriter
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithConfigStore(Protocol[NodeInputT_contra]):
|
||||
def __call__(
|
||||
self, state: NodeInputT_contra, *, config: RunnableConfig, store: BaseStore
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithConfigWriterStore(Protocol[NodeInputT_contra]):
|
||||
def __call__(
|
||||
self,
|
||||
state: NodeInputT_contra,
|
||||
*,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter,
|
||||
store: BaseStore,
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class _NodeWithRuntime(Protocol[NodeInputT_contra, ContextT]):
|
||||
def __call__(
|
||||
self, state: NodeInputT_contra, *, runtime: Runtime[ContextT]
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
# TODO: we probably don't want to explicitly support the config / store signatures once
|
||||
# we move to adding a context arg. Maybe what we do is we add support for kwargs with param spec
|
||||
# this is purely for typing purposes though, so can easily change in the coming weeks.
|
||||
StateNode: TypeAlias = (
|
||||
_Node[NodeInputT]
|
||||
| _NodeWithConfig[NodeInputT]
|
||||
| _NodeWithWriter[NodeInputT]
|
||||
| _NodeWithStore[NodeInputT]
|
||||
| _NodeWithWriterStore[NodeInputT]
|
||||
| _NodeWithConfigWriter[NodeInputT]
|
||||
| _NodeWithConfigStore[NodeInputT]
|
||||
| _NodeWithConfigWriterStore[NodeInputT]
|
||||
| _NodeWithRuntime[NodeInputT, ContextT]
|
||||
| Runnable[NodeInputT, Any]
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class StateNodeSpec(Generic[NodeInputT, ContextT]):
|
||||
runnable: StateNode[NodeInputT, ContextT]
|
||||
metadata: dict[str, Any] | None
|
||||
input_schema: type[NodeInputT]
|
||||
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None
|
||||
cache_policy: CachePolicy | None
|
||||
ends: tuple[str, ...] | dict[str, str] | None = EMPTY_SEQ
|
||||
defer: bool = False
|
||||
372
venv/Lib/site-packages/langgraph/graph/message.py
Normal file
372
venv/Lib/site-packages/langgraph/graph/message.py
Normal file
@@ -0,0 +1,372 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.messages import (
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
MessageLikeRepresentation,
|
||||
RemoveMessage,
|
||||
convert_to_messages,
|
||||
message_chunk_to_message,
|
||||
)
|
||||
from typing_extensions import TypedDict, deprecated
|
||||
|
||||
from langgraph._internal._constants import CONF, CONFIG_KEY_SEND, NS_SEP
|
||||
from langgraph.graph.state import StateGraph
|
||||
from langgraph.warnings import LangGraphDeprecatedSinceV10
|
||||
|
||||
__all__ = (
|
||||
"add_messages",
|
||||
"MessagesState",
|
||||
"MessageGraph",
|
||||
"REMOVE_ALL_MESSAGES",
|
||||
)
|
||||
|
||||
Messages = list[MessageLikeRepresentation] | MessageLikeRepresentation
|
||||
|
||||
REMOVE_ALL_MESSAGES = "__remove_all__"
|
||||
|
||||
|
||||
def _add_messages_wrapper(func: Callable) -> Callable[[Messages, Messages], Messages]:
|
||||
def _add_messages(
|
||||
left: Messages | None = None, right: Messages | None = None, **kwargs: Any
|
||||
) -> Messages | Callable[[Messages, Messages], Messages]:
|
||||
if left is not None and right is not None:
|
||||
return func(left, right, **kwargs)
|
||||
elif left is not None or right is not None:
|
||||
msg = (
|
||||
f"Must specify non-null arguments for both 'left' and 'right'. Only "
|
||||
f"received: '{'left' if left else 'right'}'."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
return partial(func, **kwargs)
|
||||
|
||||
_add_messages.__doc__ = func.__doc__
|
||||
return cast(Callable[[Messages, Messages], Messages], _add_messages)
|
||||
|
||||
|
||||
@_add_messages_wrapper
|
||||
def add_messages(
|
||||
left: Messages,
|
||||
right: Messages,
|
||||
*,
|
||||
format: Literal["langchain-openai"] | None = None,
|
||||
) -> Messages:
|
||||
"""Merges two lists of messages, updating existing messages by ID.
|
||||
|
||||
By default, this ensures the state is "append-only", unless the
|
||||
new message has the same ID as an existing message.
|
||||
|
||||
Args:
|
||||
left: The base list of `Messages`.
|
||||
right: The list of `Messages` (or single `Message`) to merge
|
||||
into the base list.
|
||||
format: The format to return messages in. If `None` then `Messages` will be
|
||||
returned as is. If `langchain-openai` then `Messages` will be returned as
|
||||
`BaseMessage` objects with their contents formatted to match OpenAI message
|
||||
format, meaning contents can be string, `'text'` blocks, or `'image_url'` blocks
|
||||
and tool responses are returned as their own `ToolMessage` objects.
|
||||
|
||||
!!! important "Requirement"
|
||||
|
||||
Must have `langchain-core>=0.3.11` installed to use this feature.
|
||||
|
||||
Returns:
|
||||
A new list of messages with the messages from `right` merged into `left`.
|
||||
If a message in `right` has the same ID as a message in `left`, the
|
||||
message from `right` will replace the message from `left`.
|
||||
|
||||
Example: Basic usage
|
||||
```python
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
msgs1 = [HumanMessage(content="Hello", id="1")]
|
||||
msgs2 = [AIMessage(content="Hi there!", id="2")]
|
||||
add_messages(msgs1, msgs2)
|
||||
# [HumanMessage(content='Hello', id='1'), AIMessage(content='Hi there!', id='2')]
|
||||
```
|
||||
|
||||
Example: Overwrite existing message
|
||||
```python
|
||||
msgs1 = [HumanMessage(content="Hello", id="1")]
|
||||
msgs2 = [HumanMessage(content="Hello again", id="1")]
|
||||
add_messages(msgs1, msgs2)
|
||||
# [HumanMessage(content='Hello again', id='1')]
|
||||
```
|
||||
|
||||
Example: Use in a StateGraph
|
||||
```python
|
||||
from typing import Annotated
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
|
||||
class State(TypedDict):
|
||||
messages: Annotated[list, add_messages]
|
||||
|
||||
|
||||
builder = StateGraph(State)
|
||||
builder.add_node("chatbot", lambda state: {"messages": [("assistant", "Hello")]})
|
||||
builder.set_entry_point("chatbot")
|
||||
builder.set_finish_point("chatbot")
|
||||
graph = builder.compile()
|
||||
graph.invoke({})
|
||||
# {'messages': [AIMessage(content='Hello', id=...)]}
|
||||
```
|
||||
|
||||
Example: Use OpenAI message format
|
||||
```python
|
||||
from typing import Annotated
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.graph import StateGraph, add_messages
|
||||
|
||||
|
||||
class State(TypedDict):
|
||||
messages: Annotated[list, add_messages(format="langchain-openai")]
|
||||
|
||||
|
||||
def chatbot_node(state: State) -> list:
|
||||
return {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Here's an image:",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": "1234",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
builder = StateGraph(State)
|
||||
builder.add_node("chatbot", chatbot_node)
|
||||
builder.set_entry_point("chatbot")
|
||||
builder.set_finish_point("chatbot")
|
||||
graph = builder.compile()
|
||||
graph.invoke({"messages": []})
|
||||
# {
|
||||
# 'messages': [
|
||||
# HumanMessage(
|
||||
# content=[
|
||||
# {"type": "text", "text": "Here's an image:"},
|
||||
# {
|
||||
# "type": "image_url",
|
||||
# "image_url": {"url": "data:image/jpeg;base64,1234"},
|
||||
# },
|
||||
# ],
|
||||
# ),
|
||||
# ]
|
||||
# }
|
||||
```
|
||||
|
||||
"""
|
||||
remove_all_idx = None
|
||||
# coerce to list
|
||||
if not isinstance(left, list):
|
||||
left = [left] # type: ignore[assignment]
|
||||
if not isinstance(right, list):
|
||||
right = [right] # type: ignore[assignment]
|
||||
# coerce to message
|
||||
left = [
|
||||
message_chunk_to_message(cast(BaseMessageChunk, m))
|
||||
for m in convert_to_messages(left)
|
||||
]
|
||||
right = [
|
||||
message_chunk_to_message(cast(BaseMessageChunk, m))
|
||||
for m in convert_to_messages(right)
|
||||
]
|
||||
# assign missing ids
|
||||
for m in left:
|
||||
if m.id is None:
|
||||
m.id = str(uuid.uuid4())
|
||||
for idx, m in enumerate(right):
|
||||
if m.id is None:
|
||||
m.id = str(uuid.uuid4())
|
||||
if isinstance(m, RemoveMessage) and m.id == REMOVE_ALL_MESSAGES:
|
||||
remove_all_idx = idx
|
||||
|
||||
if remove_all_idx is not None:
|
||||
return right[remove_all_idx + 1 :]
|
||||
|
||||
# merge
|
||||
merged = left.copy()
|
||||
merged_by_id = {m.id: i for i, m in enumerate(merged)}
|
||||
ids_to_remove = set()
|
||||
for m in right:
|
||||
if (existing_idx := merged_by_id.get(m.id)) is not None:
|
||||
if isinstance(m, RemoveMessage):
|
||||
ids_to_remove.add(m.id)
|
||||
else:
|
||||
ids_to_remove.discard(m.id)
|
||||
merged[existing_idx] = m
|
||||
else:
|
||||
if isinstance(m, RemoveMessage):
|
||||
raise ValueError(
|
||||
f"Attempting to delete a message with an ID that doesn't exist ('{m.id}')"
|
||||
)
|
||||
|
||||
merged_by_id[m.id] = len(merged)
|
||||
merged.append(m)
|
||||
merged = [m for m in merged if m.id not in ids_to_remove]
|
||||
|
||||
if format == "langchain-openai":
|
||||
merged = _format_messages(merged)
|
||||
elif format:
|
||||
msg = f"Unrecognized {format=}. Expected one of 'langchain-openai', None."
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
pass
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
@deprecated(
|
||||
"MessageGraph is deprecated in langgraph 1.0.0, to be removed in 2.0.0. Please use StateGraph with a `messages` key instead.",
|
||||
category=None,
|
||||
)
|
||||
class MessageGraph(StateGraph):
|
||||
"""A StateGraph where every node receives a list of messages as input and returns one or more messages as output.
|
||||
|
||||
MessageGraph is a subclass of StateGraph whose entire state is a single, append-only* list of messages.
|
||||
Each node in a MessageGraph takes a list of messages as input and returns zero or more
|
||||
messages as output. The `add_messages` function is used to merge the output messages from each node
|
||||
into the existing list of messages in the graph's state.
|
||||
|
||||
Examples:
|
||||
```pycon
|
||||
>>> from langgraph.graph.message import MessageGraph
|
||||
...
|
||||
>>> builder = MessageGraph()
|
||||
>>> builder.add_node("chatbot", lambda state: [("assistant", "Hello!")])
|
||||
>>> builder.set_entry_point("chatbot")
|
||||
>>> builder.set_finish_point("chatbot")
|
||||
>>> builder.compile().invoke([("user", "Hi there.")])
|
||||
[HumanMessage(content="Hi there.", id='...'), AIMessage(content="Hello!", id='...')]
|
||||
```
|
||||
|
||||
```pycon
|
||||
>>> from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
>>> from langgraph.graph.message import MessageGraph
|
||||
...
|
||||
>>> builder = MessageGraph()
|
||||
>>> builder.add_node(
|
||||
... "chatbot",
|
||||
... lambda state: [
|
||||
... AIMessage(
|
||||
... content="Hello!",
|
||||
... tool_calls=[{"name": "search", "id": "123", "args": {"query": "X"}}],
|
||||
... )
|
||||
... ],
|
||||
... )
|
||||
>>> builder.add_node(
|
||||
... "search", lambda state: [ToolMessage(content="Searching...", tool_call_id="123")]
|
||||
... )
|
||||
>>> builder.set_entry_point("chatbot")
|
||||
>>> builder.add_edge("chatbot", "search")
|
||||
>>> builder.set_finish_point("search")
|
||||
>>> builder.compile().invoke([HumanMessage(content="Hi there. Can you search for X?")])
|
||||
{'messages': [HumanMessage(content="Hi there. Can you search for X?", id='b8b7d8f4-7f4d-4f4d-9c1d-f8b8d8f4d9c1'),
|
||||
AIMessage(content="Hello!", id='f4d9c1d8-8d8f-4d9c-b8b7-d8f4f4d9c1d8'),
|
||||
ToolMessage(content="Searching...", id='d8f4f4d9-c1d8-4f4d-b8b7-d8f4f4d9c1d8', tool_call_id="123")]}
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
warnings.warn(
|
||||
"MessageGraph is deprecated in LangGraph v1.0.0, to be removed in v2.0.0. Please use StateGraph with a `messages` key instead.",
|
||||
category=LangGraphDeprecatedSinceV10,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__(Annotated[list[AnyMessage], add_messages]) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class MessagesState(TypedDict):
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
|
||||
|
||||
def _format_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
|
||||
try:
|
||||
from langchain_core.messages import convert_to_openai_messages
|
||||
except ImportError:
|
||||
msg = (
|
||||
"Must have langchain-core>=0.3.11 installed to use automatic message "
|
||||
"formatting (format='langchain-openai'). Please update your langchain-core "
|
||||
"version or remove the 'format' flag. Returning un-formatted "
|
||||
"messages."
|
||||
)
|
||||
warnings.warn(msg)
|
||||
return list(messages)
|
||||
else:
|
||||
return convert_to_messages(convert_to_openai_messages(messages))
|
||||
|
||||
|
||||
def push_message(
|
||||
message: MessageLikeRepresentation | BaseMessageChunk,
|
||||
*,
|
||||
state_key: str | None = "messages",
|
||||
) -> AnyMessage:
|
||||
"""Write a message manually to the `messages` / `messages-tuple` stream mode.
|
||||
|
||||
Will automatically write to the channel specified in the `state_key` unless `state_key` is `None`.
|
||||
"""
|
||||
|
||||
from langchain_core.callbacks.base import (
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
)
|
||||
|
||||
from langgraph.config import get_config
|
||||
from langgraph.pregel._messages import StreamMessagesHandler
|
||||
|
||||
config = get_config()
|
||||
message = next(x for x in convert_to_messages([message]))
|
||||
|
||||
if message.id is None:
|
||||
raise ValueError("Message ID is required")
|
||||
|
||||
if isinstance(config["callbacks"], BaseCallbackManager):
|
||||
manager = config["callbacks"]
|
||||
handlers = manager.handlers
|
||||
elif isinstance(config["callbacks"], list) and all(
|
||||
isinstance(x, BaseCallbackHandler) for x in config["callbacks"]
|
||||
):
|
||||
handlers = config["callbacks"]
|
||||
|
||||
if stream_handler := next(
|
||||
(x for x in handlers if isinstance(x, StreamMessagesHandler)), None
|
||||
):
|
||||
metadata = config["metadata"]
|
||||
message_meta = (
|
||||
tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP)),
|
||||
metadata,
|
||||
)
|
||||
stream_handler._emit(message_meta, message, dedupe=False)
|
||||
|
||||
if state_key:
|
||||
config[CONF][CONFIG_KEY_SEND]([(state_key, message)])
|
||||
|
||||
return message
|
||||
1731
venv/Lib/site-packages/langgraph/graph/state.py
Normal file
1731
venv/Lib/site-packages/langgraph/graph/state.py
Normal file
File diff suppressed because it is too large
Load Diff
227
venv/Lib/site-packages/langgraph/graph/ui.py
Normal file
227
venv/Lib/site-packages/langgraph/graph/ui.py
Normal file
@@ -0,0 +1,227 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langgraph.config import get_config, get_stream_writer
|
||||
from langgraph.constants import CONF
|
||||
|
||||
__all__ = (
|
||||
"UIMessage",
|
||||
"RemoveUIMessage",
|
||||
"AnyUIMessage",
|
||||
"push_ui_message",
|
||||
"delete_ui_message",
|
||||
"ui_message_reducer",
|
||||
)
|
||||
|
||||
|
||||
class UIMessage(TypedDict):
|
||||
"""A message type for UI updates in LangGraph.
|
||||
|
||||
This TypedDict represents a UI message that can be sent to update the UI state.
|
||||
It contains information about the UI component to render and its properties.
|
||||
|
||||
Attributes:
|
||||
type: Literal type indicating this is a UI message.
|
||||
id: Unique identifier for the UI message.
|
||||
name: Name of the UI component to render.
|
||||
props: Properties to pass to the UI component.
|
||||
metadata: Additional metadata about the UI message.
|
||||
"""
|
||||
|
||||
type: Literal["ui"]
|
||||
id: str
|
||||
name: str
|
||||
props: dict[str, Any]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
class RemoveUIMessage(TypedDict):
|
||||
"""A message type for removing UI components in LangGraph.
|
||||
|
||||
This TypedDict represents a message that can be sent to remove a UI component
|
||||
from the current state.
|
||||
|
||||
Attributes:
|
||||
type: Literal type indicating this is a remove-ui message.
|
||||
id: Unique identifier of the UI message to remove.
|
||||
"""
|
||||
|
||||
type: Literal["remove-ui"]
|
||||
id: str
|
||||
|
||||
|
||||
AnyUIMessage = UIMessage | RemoveUIMessage
|
||||
|
||||
|
||||
def push_ui_message(
|
||||
name: str,
|
||||
props: dict[str, Any],
|
||||
*,
|
||||
id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
message: AnyMessage | None = None,
|
||||
state_key: str | None = "ui",
|
||||
merge: bool = False,
|
||||
) -> UIMessage:
|
||||
"""Push a new UI message to update the UI state.
|
||||
|
||||
This function creates and sends a UI message that will be rendered in the UI.
|
||||
It also updates the graph state with the new UI message.
|
||||
|
||||
Args:
|
||||
name: Name of the UI component to render.
|
||||
props: Properties to pass to the UI component.
|
||||
id: Optional unique identifier for the UI message.
|
||||
If not provided, a random UUID will be generated.
|
||||
metadata: Optional additional metadata about the UI message.
|
||||
message: Optional message object to associate with the UI message.
|
||||
state_key: Key in the graph state where the UI messages are stored.
|
||||
merge: Whether to merge props with existing UI message (True) or replace
|
||||
them (False).
|
||||
|
||||
Returns:
|
||||
The created UI message.
|
||||
|
||||
Example:
|
||||
```python
|
||||
push_ui_message(
|
||||
name="component-name",
|
||||
props={"content": "Hello world"},
|
||||
)
|
||||
```
|
||||
|
||||
"""
|
||||
from langgraph._internal._constants import CONFIG_KEY_SEND
|
||||
|
||||
writer = get_stream_writer()
|
||||
config = get_config()
|
||||
|
||||
message_id = None
|
||||
if message:
|
||||
if isinstance(message, dict) and "id" in message:
|
||||
message_id = message.get("id")
|
||||
elif hasattr(message, "id"):
|
||||
message_id = message.id
|
||||
|
||||
evt: UIMessage = {
|
||||
"type": "ui",
|
||||
"id": id or str(uuid4()),
|
||||
"name": name,
|
||||
"props": props,
|
||||
"metadata": {
|
||||
"merge": merge,
|
||||
"run_id": config.get("run_id", None),
|
||||
"tags": config.get("tags", None),
|
||||
"name": config.get("run_name", None),
|
||||
**(metadata or {}),
|
||||
**({"message_id": message_id} if message_id else {}),
|
||||
},
|
||||
}
|
||||
|
||||
writer(evt)
|
||||
if state_key:
|
||||
config[CONF][CONFIG_KEY_SEND]([(state_key, evt)])
|
||||
|
||||
return evt
|
||||
|
||||
|
||||
def delete_ui_message(id: str, *, state_key: str = "ui") -> RemoveUIMessage:
|
||||
"""Delete a UI message by ID from the UI state.
|
||||
|
||||
This function creates and sends a message to remove a UI component from the current state.
|
||||
It also updates the graph state to remove the UI message.
|
||||
|
||||
Args:
|
||||
id: Unique identifier of the UI component to remove.
|
||||
state_key: Key in the graph state where the UI messages are stored. Defaults to "ui".
|
||||
|
||||
Returns:
|
||||
The remove UI message.
|
||||
|
||||
Example:
|
||||
```python
|
||||
delete_ui_message("message-123")
|
||||
```
|
||||
|
||||
"""
|
||||
from langgraph._internal._constants import CONFIG_KEY_SEND
|
||||
|
||||
writer = get_stream_writer()
|
||||
config = get_config()
|
||||
|
||||
evt: RemoveUIMessage = {"type": "remove-ui", "id": id}
|
||||
|
||||
writer(evt)
|
||||
config[CONF][CONFIG_KEY_SEND]([(state_key, evt)])
|
||||
|
||||
return evt
|
||||
|
||||
|
||||
def ui_message_reducer(
|
||||
left: list[AnyUIMessage] | AnyUIMessage,
|
||||
right: list[AnyUIMessage] | AnyUIMessage,
|
||||
) -> list[AnyUIMessage]:
|
||||
"""Merge two lists of UI messages, supporting removing UI messages.
|
||||
|
||||
This function combines two lists of UI messages, handling both regular UI messages
|
||||
and `remove-ui` messages. When a `remove-ui` message is encountered, it removes any
|
||||
UI message with the matching ID from the current state.
|
||||
|
||||
Args:
|
||||
left: First list of UI messages or single UI message.
|
||||
right: Second list of UI messages or single UI message.
|
||||
|
||||
Returns:
|
||||
Combined list of UI messages with removals applied.
|
||||
|
||||
Example:
|
||||
```python
|
||||
messages = ui_message_reducer(
|
||||
[{"type": "ui", "id": "1", "name": "Chat", "props": {}}],
|
||||
{"type": "remove-ui", "id": "1"},
|
||||
)
|
||||
```
|
||||
|
||||
"""
|
||||
if not isinstance(left, list):
|
||||
left = [left]
|
||||
|
||||
if not isinstance(right, list):
|
||||
right = [right]
|
||||
|
||||
# merge messages
|
||||
merged = left.copy()
|
||||
merged_by_id = {m.get("id"): i for i, m in enumerate(merged)}
|
||||
ids_to_remove = set()
|
||||
|
||||
for msg in right:
|
||||
msg_id = msg.get("id")
|
||||
|
||||
if (existing_idx := merged_by_id.get(msg_id)) is not None:
|
||||
if msg.get("type") == "remove-ui":
|
||||
ids_to_remove.add(msg_id)
|
||||
else:
|
||||
ids_to_remove.discard(msg_id)
|
||||
|
||||
if cast(UIMessage, msg).get("metadata", {}).get("merge", False):
|
||||
prev_msg = merged[existing_idx]
|
||||
msg = msg.copy()
|
||||
msg["props"] = {**prev_msg["props"], **msg["props"]}
|
||||
|
||||
merged[existing_idx] = msg
|
||||
else:
|
||||
if msg.get("type") == "remove-ui":
|
||||
raise ValueError(
|
||||
f"Attempting to delete an UI message with an ID that doesn't exist ('{msg_id}')"
|
||||
)
|
||||
|
||||
merged_by_id[msg_id] = len(merged)
|
||||
merged.append(msg)
|
||||
|
||||
merged = [m for m in merged if m.get("id") not in ids_to_remove]
|
||||
return merged
|
||||
Reference in New Issue
Block a user