initial commit
This commit is contained in:
157
venv/Lib/site-packages/langgraph_sdk/sse.py
Normal file
157
venv/Lib/site-packages/langgraph_sdk/sse.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""Adapted from httpx_sse to split lines on \n, \r, \r\n per the SSE spec."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
import orjson
|
||||
|
||||
from langgraph_sdk.schema import StreamPart
|
||||
|
||||
BytesLike = bytes | bytearray | memoryview
|
||||
|
||||
|
||||
class BytesLineDecoder:
|
||||
"""
|
||||
Handles incrementally reading lines from text.
|
||||
|
||||
Has the same behaviour as the stdllib bytes splitlines,
|
||||
but handling the input iteratively.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.buffer = bytearray()
|
||||
self.trailing_cr: bool = False
|
||||
|
||||
def decode(self, text: bytes) -> list[BytesLike]:
|
||||
# See https://docs.python.org/3/glossary.html#term-universal-newlines
|
||||
NEWLINE_CHARS = b"\n\r"
|
||||
|
||||
# We always push a trailing `\r` into the next decode iteration.
|
||||
if self.trailing_cr:
|
||||
text = b"\r" + text
|
||||
self.trailing_cr = False
|
||||
if text.endswith(b"\r"):
|
||||
self.trailing_cr = True
|
||||
text = text[:-1]
|
||||
|
||||
if not text:
|
||||
# NOTE: the edge case input of empty text doesn't occur in practice,
|
||||
# because other httpx internals filter out this value
|
||||
return [] # pragma: no cover
|
||||
|
||||
trailing_newline = text[-1] in NEWLINE_CHARS
|
||||
lines = cast(list[BytesLike], text.splitlines())
|
||||
|
||||
if len(lines) == 1 and not trailing_newline:
|
||||
# No new lines, buffer the input and continue.
|
||||
self.buffer.extend(lines[0])
|
||||
return []
|
||||
|
||||
if self.buffer:
|
||||
# Include any existing buffer in the first portion of the
|
||||
# splitlines result.
|
||||
self.buffer.extend(lines[0])
|
||||
lines = cast(list[BytesLike], [self.buffer, *lines[1:]])
|
||||
self.buffer = bytearray()
|
||||
|
||||
if not trailing_newline:
|
||||
# If the last segment of splitlines is not newline terminated,
|
||||
# then drop it from our output and start a new buffer.
|
||||
self.buffer.extend(lines.pop())
|
||||
|
||||
return lines
|
||||
|
||||
def flush(self) -> list[BytesLike]:
|
||||
if not self.buffer and not self.trailing_cr:
|
||||
return []
|
||||
|
||||
lines = [self.buffer]
|
||||
self.buffer = bytearray()
|
||||
self.trailing_cr = False
|
||||
return lines
|
||||
|
||||
|
||||
class SSEDecoder:
|
||||
def __init__(self) -> None:
|
||||
self._event = ""
|
||||
self._data = bytearray()
|
||||
self._last_event_id = ""
|
||||
self._retry: int | None = None
|
||||
|
||||
@property
|
||||
def last_event_id(self) -> str | None:
|
||||
"""Return the last event identifier that was seen."""
|
||||
|
||||
return self._last_event_id or None
|
||||
|
||||
def decode(self, line: bytes) -> StreamPart | None:
|
||||
# See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
|
||||
|
||||
if not line:
|
||||
if (
|
||||
not self._event
|
||||
and not self._data
|
||||
and not self._last_event_id
|
||||
and self._retry is None
|
||||
):
|
||||
return None
|
||||
|
||||
sse = StreamPart(
|
||||
event=self._event,
|
||||
data=orjson.loads(self._data) if self._data else None, # type: ignore[invalid-argument-type]
|
||||
id=self.last_event_id,
|
||||
)
|
||||
|
||||
# NOTE: as per the SSE spec, do not reset last_event_id.
|
||||
self._event = ""
|
||||
self._data = bytearray()
|
||||
self._retry = None
|
||||
|
||||
return sse
|
||||
|
||||
if line.startswith(b":"):
|
||||
return None
|
||||
|
||||
fieldname, _, value = line.partition(b":")
|
||||
|
||||
if value.startswith(b" "):
|
||||
value = value[1:]
|
||||
|
||||
if fieldname == b"event":
|
||||
self._event = value.decode()
|
||||
elif fieldname == b"data":
|
||||
self._data.extend(value)
|
||||
elif fieldname == b"id":
|
||||
if b"\0" in value:
|
||||
pass
|
||||
else:
|
||||
self._last_event_id = value.decode()
|
||||
elif fieldname == b"retry":
|
||||
with contextlib.suppress(TypeError, ValueError):
|
||||
self._retry = int(value)
|
||||
else:
|
||||
pass # Field is ignored.
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def aiter_lines_raw(response: httpx.Response) -> AsyncIterator[BytesLike]:
|
||||
decoder = BytesLineDecoder()
|
||||
async for chunk in response.aiter_bytes():
|
||||
for line in decoder.decode(chunk):
|
||||
yield line
|
||||
for line in decoder.flush():
|
||||
yield line
|
||||
|
||||
|
||||
def iter_lines_raw(response: httpx.Response) -> Iterator[BytesLike]:
|
||||
decoder = BytesLineDecoder()
|
||||
for chunk in response.iter_bytes():
|
||||
for line in decoder.decode(chunk):
|
||||
yield line
|
||||
for line in decoder.flush():
|
||||
yield line
|
||||
Reference in New Issue
Block a user