initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,365 @@
"""Utilities for batching operations in a background task."""
from __future__ import annotations
import asyncio
import functools
import weakref
from collections.abc import Callable, Iterable
from typing import Any, Literal, TypeVar
from langgraph.store.base import (
NOT_PROVIDED,
BaseStore,
GetOp,
Item,
ListNamespacesOp,
MatchCondition,
NamespacePath,
NotProvided,
Op,
PutOp,
Result,
SearchItem,
SearchOp,
_ensure_refresh,
_ensure_ttl,
_validate_namespace,
)
F = TypeVar("F", bound=Callable)


def _check_loop(func: F) -> F:
    """Reject synchronous store calls made from the store's own event loop.

    The wrapped method is expected to take the store as its first positional
    argument; the store must expose the loop it is bound to as ``_loop``.

    Args:
        func: The synchronous method to guard.

    Returns:
        A wrapper with the same call signature as ``func``.

    Raises:
        asyncio.InvalidStateError: Raised by the wrapper (not by this
            decorator) when it is invoked while the running loop of the
            current thread is ``store._loop``.
    """
    # functools.wraps tolerates callables that lack __name__ (for example
    # functools.partial objects), so the lookup here tolerates it as well.
    method_name: str = getattr(func, "__name__", "")

    @functools.wraps(func)
    def wrapper(store: AsyncBatchedBaseStore, *args: Any, **kwargs: Any) -> Any:
        try:
            current_loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so the call cannot be
            # happening on the store's loop.
            current_loop = None
        if current_loop is not None and current_loop is store._loop:
            replacement_str = (
                f"Specifically, replace `store.{method_name}(...)` with `await store.a{method_name}(...)`"
                if method_name
                else "For example, replace `store.get(...)` with `await store.aget(...)`"
            )
            raise asyncio.InvalidStateError(
                f"Synchronous calls to {store.__class__.__name__} detected in the main event loop. "
                "This can lead to deadlocks or performance issues. "
                "Please use the asynchronous interface for main thread operations. "
                f"{replacement_str} "
            )
        return func(store, *args, **kwargs)

    return wrapper
class AsyncBatchedBaseStore(BaseStore):
    """Efficiently batch operations in a background task.

    Each async method wraps its arguments in an operation object, places it
    on an internal queue together with a future, and awaits that future. A
    background task running ``_run`` drains the queue and resolves the
    futures from the results of ``abatch``. ``abatch`` is not defined in this
    class; subclasses are expected to supply it.

    The synchronous methods forward to their async counterparts through
    ``asyncio.run_coroutine_threadsafe`` and block on the result. They are
    wrapped in ``_check_loop``, which raises ``asyncio.InvalidStateError``
    when they are called from the event loop the store is bound to.

    Instances must be created while an event loop is running, because
    ``__init__`` calls ``asyncio.get_running_loop()``.
    """

    __slots__ = ("_loop", "_aqueue", "_task")

    def __init__(self) -> None:
        """Bind the store to the running loop and start the background task.

        Raises:
            RuntimeError: If no event loop is running in the current thread.
        """
        super().__init__()
        self._loop = asyncio.get_running_loop()
        self._aqueue: asyncio.Queue[tuple[asyncio.Future, Op]] = asyncio.Queue()
        self._task: asyncio.Task | None = None
        self._ensure_task()

    def __del__(self) -> None:
        """Cancel the background task when the store is finalized."""
        # __init__ can raise before `_task` is assigned (for example when no
        # event loop is running). The finalizer still runs for such a
        # half-built instance, so the attribute may be missing.
        task = getattr(self, "_task", None)
        if not task:
            return
        try:
            task.cancel()
        except RuntimeError:
            # Best effort only: cancelling is allowed to fail during teardown.
            pass

    def _ensure_task(self) -> None:
        """Ensure the background processing loop is running."""
        if self._task is None or self._task.done():
            # The task only receives a weak reference so that it does not
            # keep the store alive on its own.
            self._task = self._loop.create_task(_run(self._aqueue, weakref.ref(self)))

    async def aget(
        self,
        namespace: tuple[str, ...],
        key: str,
        *,
        refresh_ttl: bool | None = None,
    ) -> Item | None:
        """Queue a ``GetOp`` for ``(namespace, key)`` and return its result."""
        self._ensure_task()
        fut = self._loop.create_future()
        self._aqueue.put_nowait(
            (
                fut,
                GetOp(
                    namespace,
                    key,
                    refresh_ttl=_ensure_refresh(self.ttl_config, refresh_ttl),
                ),
            )
        )
        return await fut

    async def asearch(
        self,
        namespace_prefix: tuple[str, ...],
        /,
        *,
        query: str | None = None,
        filter: dict[str, Any] | None = None,
        limit: int = 10,
        offset: int = 0,
        refresh_ttl: bool | None = None,
    ) -> list[SearchItem]:
        """Queue a ``SearchOp`` under ``namespace_prefix`` and return its result."""
        self._ensure_task()
        fut = self._loop.create_future()
        self._aqueue.put_nowait(
            (
                fut,
                SearchOp(
                    namespace_prefix,
                    filter,
                    limit,
                    offset,
                    query,
                    refresh_ttl=_ensure_refresh(self.ttl_config, refresh_ttl),
                ),
            )
        )
        return await fut

    async def aput(
        self,
        namespace: tuple[str, ...],
        key: str,
        value: dict[str, Any],
        index: Literal[False] | list[str] | None = None,
        *,
        ttl: float | None | NotProvided = NOT_PROVIDED,
    ) -> None:
        """Validate ``namespace``, queue a ``PutOp`` and wait for it to be processed."""
        self._ensure_task()
        _validate_namespace(namespace)
        fut = self._loop.create_future()
        self._aqueue.put_nowait(
            (
                fut,
                PutOp(
                    namespace, key, value, index, ttl=_ensure_ttl(self.ttl_config, ttl)
                ),
            )
        )
        return await fut

    async def adelete(
        self,
        namespace: tuple[str, ...],
        key: str,
    ) -> None:
        """Queue a ``PutOp`` whose value is ``None`` for ``(namespace, key)``."""
        self._ensure_task()
        fut = self._loop.create_future()
        self._aqueue.put_nowait((fut, PutOp(namespace, key, None)))
        return await fut

    async def alist_namespaces(
        self,
        *,
        prefix: NamespacePath | None = None,
        suffix: NamespacePath | None = None,
        max_depth: int | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[tuple[str, ...]]:
        """Queue a ``ListNamespacesOp`` and return its result.

        A match condition is added for ``prefix`` and for ``suffix`` only
        when the respective argument is truthy.
        """
        self._ensure_task()
        fut = self._loop.create_future()
        match_conditions = []
        if prefix:
            match_conditions.append(MatchCondition(match_type="prefix", path=prefix))
        if suffix:
            match_conditions.append(MatchCondition(match_type="suffix", path=suffix))
        op = ListNamespacesOp(
            match_conditions=tuple(match_conditions),
            max_depth=max_depth,
            limit=limit,
            offset=offset,
        )
        self._aqueue.put_nowait((fut, op))
        return await fut

    @_check_loop
    def batch(self, ops: Iterable[Op]) -> list[Result]:
        """Run ``abatch(ops)`` on the store's loop and block until it completes."""
        return asyncio.run_coroutine_threadsafe(self.abatch(ops), self._loop).result()

    @_check_loop
    def get(
        self,
        namespace: tuple[str, ...],
        key: str,
        *,
        refresh_ttl: bool | None = None,
    ) -> Item | None:
        """Blocking counterpart of ``aget``."""
        return asyncio.run_coroutine_threadsafe(
            self.aget(namespace, key=key, refresh_ttl=refresh_ttl), self._loop
        ).result()

    @_check_loop
    def search(
        self,
        namespace_prefix: tuple[str, ...],
        /,
        *,
        query: str | None = None,
        filter: dict[str, Any] | None = None,
        limit: int = 10,
        offset: int = 0,
        refresh_ttl: bool | None = None,
    ) -> list[SearchItem]:
        """Blocking counterpart of ``asearch``."""
        return asyncio.run_coroutine_threadsafe(
            self.asearch(
                namespace_prefix,
                query=query,
                filter=filter,
                limit=limit,
                offset=offset,
                refresh_ttl=refresh_ttl,
            ),
            self._loop,
        ).result()

    @_check_loop
    def put(
        self,
        namespace: tuple[str, ...],
        key: str,
        value: dict[str, Any],
        index: Literal[False] | list[str] | None = None,
        *,
        ttl: float | None | NotProvided = NOT_PROVIDED,
    ) -> None:
        """Blocking counterpart of ``aput``.

        The namespace is validated in the calling thread before the
        coroutine is handed to the store's loop.
        """
        _validate_namespace(namespace)
        asyncio.run_coroutine_threadsafe(
            self.aput(
                namespace,
                key=key,
                value=value,
                index=index,
                ttl=_ensure_ttl(self.ttl_config, ttl),
            ),
            self._loop,
        ).result()

    @_check_loop
    def delete(
        self,
        namespace: tuple[str, ...],
        key: str,
    ) -> None:
        """Blocking counterpart of ``adelete``."""
        asyncio.run_coroutine_threadsafe(
            self.adelete(namespace, key=key), self._loop
        ).result()

    @_check_loop
    def list_namespaces(
        self,
        *,
        prefix: NamespacePath | None = None,
        suffix: NamespacePath | None = None,
        max_depth: int | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[tuple[str, ...]]:
        """Blocking counterpart of ``alist_namespaces``."""
        return asyncio.run_coroutine_threadsafe(
            self.alist_namespaces(
                prefix=prefix,
                suffix=suffix,
                max_depth=max_depth,
                limit=limit,
                offset=offset,
            ),
            self._loop,
        ).result()
def _dedupe_ops(values: list[Op]) -> tuple[list[int] | None, list[Op]]:
    """Collapse duplicate operations while remembering where each result goes.

    Read-style operations (``GetOp``, ``SearchOp``, ``ListNamespacesOp``) that
    compare equal share a single slot. ``PutOp`` instances addressing the same
    ``(namespace, key)`` share a slot as well, and the latest one replaces the
    earlier one in that slot. Any other operation always gets a slot of its own.

    Args:
        values: Operations in the order they were submitted.

    Returns:
        ``(listen, deduped)``. ``listen[i]`` is the index in ``deduped`` whose
        result belongs to ``values[i]``. ``listen`` is ``None`` when at most
        one operation was given, in which case ``deduped`` is a plain copy.
    """
    if len(values) <= 1:
        return None, list(values)

    unique: list[Op] = []
    positions: list[int] = []
    put_slots: dict[tuple[tuple[str, ...], str], int] = {}

    for op in values:
        if isinstance(op, (GetOp, SearchOp, ListNamespacesOp)):
            try:
                slot = unique.index(op)
            except ValueError:
                slot = len(unique)
                unique.append(op)
        elif isinstance(op, PutOp):
            slot_key = (op.namespace, op.key)
            if slot_key in put_slots:
                # A later put to the same item replaces the earlier one.
                slot = put_slots[slot_key]
                unique[slot] = op
            else:
                slot = put_slots[slot_key] = len(unique)
                unique.append(op)
        else:
            # Operation types not known here are never merged.
            slot = len(unique)
            unique.append(op)
        positions.append(slot)

    return positions, unique
async def _run(
    aqueue: asyncio.Queue[tuple[asyncio.Future, Op]],
    store: weakref.ReferenceType[BaseStore],
) -> None:
    """Drain the queue in batches and resolve the queued futures.

    Args:
        aqueue: Queue of ``(future, operation)`` pairs.
        store: Weak reference to the store whose ``abatch`` executes the
            operations. The loop exits once the reference is dead.
    """
    while True:
        head = await aqueue.get()
        if not head:
            break
        # Take a strong reference only for the duration of one batch.
        target = store()
        if not target:
            break
        try:
            # Collect everything else that is already waiting in the queue.
            batch = [head]
            try:
                while True:
                    extra = aqueue.get_nowait()
                    if not extra:
                        break
                    batch.append(extra)
            except asyncio.QueueEmpty:
                pass

            waiters = [entry[0] for entry in batch]
            ops = [entry[1] for entry in batch]
            try:
                positions, unique_ops = _dedupe_ops(ops)
                outcomes = await target.abatch(unique_ops)
                if positions is not None:
                    # Fan the deduplicated results back out to every caller.
                    outcomes = [outcomes[ix] for ix in positions]
                for waiter, outcome in zip(waiters, outcomes, strict=False):
                    # A waiter may already be done (e.g. cancelled).
                    if not waiter.done():
                        waiter.set_result(outcome)
            except Exception as exc:
                for waiter in waiters:
                    # A waiter may already be done (e.g. cancelled).
                    if not waiter.done():
                        waiter.set_exception(exc)
        finally:
            # Drop the strong reference before waiting on the queue again.
            del target

View File

@@ -0,0 +1,433 @@
"""Utilities for working with embedding functions and LangChain's Embeddings interface.
This module provides tools to wrap arbitrary embedding functions (both sync and async)
into LangChain's Embeddings interface. This enables using custom embedding functions
with LangChain-compatible tools while maintaining support for both synchronous and
asynchronous operations.
"""
from __future__ import annotations
import asyncio
import functools
import json
from collections.abc import Awaitable, Callable, Sequence
from typing import Any
from langchain_core.embeddings import Embeddings
# Alias for a plain callable: Sequence[str] in, one list[float] per input out.
EmbeddingsFunc = Callable[[Sequence[str]], list[list[float]]]
"""Type for synchronous embedding functions.
The function should take a sequence of strings and return a list of embeddings,
where each embedding is a list of floats. The dimensionality of the embeddings
should be consistent for all inputs.
"""
# Same input as EmbeddingsFunc, but the call returns an awaitable of the vectors.
AEmbeddingsFunc = Callable[[Sequence[str]], Awaitable[list[list[float]]]]
"""Type for asynchronous embedding functions.
Similar to EmbeddingsFunc, but returns an awaitable that resolves to the embeddings.
"""
def ensure_embeddings(
    embed: Embeddings | EmbeddingsFunc | AEmbeddingsFunc | str | None,
) -> Embeddings:
    """Coerce ``embed`` into an object implementing LangChain's Embeddings interface.

    Args:
        embed: One of:

            - an ``Embeddings`` instance, which is returned unchanged;
            - a sync or async function mapping texts to vectors, which is
              wrapped in ``EmbeddingsLambda``;
            - a ``"provider:identifier"`` string, which is resolved through
              LangChain's ``init_embeddings`` when that function is importable.

    Returns:
        An Embeddings instance.

    Raises:
        ValueError: If ``embed`` is ``None``, or if it is a string and
            ``init_embeddings`` could not be imported.

    ??? example "Examples"
        Wrap a synchronous embedding function:
        ```python
        def my_embed_fn(texts):
            return [[0.1, 0.2] for _ in texts]

        embeddings = ensure_embeddings(my_embed_fn)
        result = embeddings.embed_query("hello")  # Returns [0.1, 0.2]
        ```
        Wrap an asynchronous embedding function:
        ```python
        async def my_async_fn(texts):
            return [[0.1, 0.2] for _ in texts]

        embeddings = ensure_embeddings(my_async_fn)
        result = await embeddings.aembed_query("hello")  # Returns [0.1, 0.2]
        ```
        Initialize embeddings using a provider string:
        ```python
        # Requires langchain>=0.3.9 and langgraph-checkpoint>=2.0.11
        embeddings = ensure_embeddings("openai:text-embedding-3-small")
        result = embeddings.embed_query("hello")
        ```
    """
    if embed is None:
        raise ValueError("embed must be provided")

    if not isinstance(embed, str):
        return embed if isinstance(embed, Embeddings) else EmbeddingsLambda(embed)

    # From here on `embed` is a provider string.
    init_embeddings = _get_init_embeddings()
    if init_embeddings is not None:
        return init_embeddings(embed)

    # The loader is unavailable: report which langchain version (if any) is installed.
    from importlib.metadata import PackageNotFoundError, version

    try:
        lc_version = version("langchain")
    except PackageNotFoundError:
        version_info = "langchain is not installed;"
    else:
        version_info = f"Found langchain version {lc_version}, but"

    raise ValueError(
        f"Could not load embeddings from string '{embed}'. {version_info} "
        "loading embeddings by provider:identifier string requires langchain>=0.3.9 "
        "as well as the provider-specific package. "
        "Install LangChain with: pip install 'langchain>=0.3.9' "
        "and the provider-specific package (e.g., 'langchain-openai>=0.3.0'). "
        "Alternatively, specify 'embed' as a compatible Embeddings object or python function."
    )
class EmbeddingsLambda(Embeddings):
    """Adapter exposing a plain embedding function through the Embeddings interface.

    The wrapped callable is stored under exactly one attribute:

    1. A synchronous function is stored as ``func``. The sync methods call it
       directly and the async methods defer to the base-class implementation.
    2. An asynchronous function is stored as ``afunc``. The async methods
       await it and the sync methods raise ``ValueError``.

    Args:
        func: Function that converts a list of texts into a list of vectors.
            May be sync or async.

    ??? example "Examples"
        With a sync function:
        ```python
        def my_embed_fn(texts):
            # Return 2D embeddings for each text
            return [[0.1, 0.2] for _ in texts]

        embeddings = EmbeddingsLambda(my_embed_fn)
        result = embeddings.embed_query("hello")  # Returns [0.1, 0.2]
        await embeddings.aembed_query("hello")  # Also returns [0.1, 0.2]
        ```
        With an async function:
        ```python
        async def my_async_fn(texts):
            return [[0.1, 0.2] for _ in texts]

        embeddings = EmbeddingsLambda(my_async_fn)
        await embeddings.aembed_query("hello")  # Returns [0.1, 0.2]
        # Note: embed_query() would raise an error
        ```
    """

    def __init__(
        self,
        func: EmbeddingsFunc | AEmbeddingsFunc,
    ) -> None:
        """Store ``func`` as ``afunc`` when it is async, otherwise as ``func``.

        Raises:
            ValueError: If ``func`` is ``None``.
        """
        if func is None:
            raise ValueError("func must be provided")
        slot = "afunc" if _is_async_callable(func) else "func"
        setattr(self, slot, func)

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of texts with the synchronous function.

        Args:
            texts: Texts to convert to embeddings.

        Returns:
            Whatever the wrapped function returns for ``texts``.

        Raises:
            ValueError: If only an async function was provided.
        """
        sync_fn = getattr(self, "func", None)
        if sync_fn is not None:
            return sync_fn(texts)
        raise ValueError(
            "EmbeddingsLambda was initialized with an async function but no sync function. "
            "Use aembed_documents for async operation or provide a sync function."
        )

    def embed_query(self, text: str) -> list[float]:
        """Embed one text by calling ``embed_documents`` with a one-element list.

        Args:
            text: Text to convert to an embedding.

        Returns:
            The first element of the ``embed_documents`` result.
        """
        vectors = self.embed_documents([text])
        return vectors[0]

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Asynchronously embed a list of texts.

        Args:
            texts: Texts to convert to embeddings.

        Returns:
            The awaited result of the async function when one was provided;
            otherwise the result of the base-class ``aembed_documents``.
        """
        async_fn = getattr(self, "afunc", None)
        if async_fn is not None:
            return await async_fn(texts)
        return await super().aembed_documents(texts)

    async def aembed_query(self, text: str) -> list[float]:
        """Asynchronously embed one text.

        Args:
            text: Text to convert to an embedding.

        Returns:
            The first vector produced by the async function for ``[text]``
            when one was provided; otherwise the result of the base-class
            ``aembed_query``.
        """
        async_fn = getattr(self, "afunc", None)
        if async_fn is not None:
            vectors = await async_fn([text])
            return vectors[0]
        return await super().aembed_query(text)
def get_text_at_path(obj: Any, path: str | list[str]) -> list[str]:
    """Collect the text found in ``obj`` at the locations selected by ``path``.

    Args:
        obj: The object to extract text from.
        path: A path string (tokenized with ``tokenize_path``) or an already
            tokenized list. An empty path or ``"$"`` selects the whole object.

    Returns:
        One string per selected value. Strings, numbers and booleans are
        rendered with ``str``; lists and dicts are rendered as JSON with
        sorted keys; ``None`` and values of other types yield nothing.

    !!! info "Path types handled"
        - Simple paths: "field1.field2"
        - Array indexing: "[0]", "[*]", "[-1]"
        - Wildcards: "*"
        - Multi-field selection: "{field1,field2}"
        - Nested paths in multi-field: "{field1,nested.field2}"
    """
    if not path or path == "$":
        return [json.dumps(obj, sort_keys=True, ensure_ascii=False)]

    tokens = tokenize_path(path) if isinstance(path, str) else path

    def _render(value: Any) -> list[str]:
        # Turn a selected value into zero or one output strings.
        if isinstance(value, (str, int, float, bool)):
            return [str(value)]
        if isinstance(value, (list, dict)):
            return [json.dumps(value, sort_keys=True, ensure_ascii=False)]
        return []

    def _walk(node: Any, depth: int) -> list[str]:
        if depth >= len(tokens):
            return _render(node)

        token = tokens[depth]
        below = depth + 1
        found: list[str] = []

        # Array selector: "[*]", "[3]", "[-1]"
        if token.startswith("[") and token.endswith("]"):
            if not isinstance(node, list):
                return []
            selector = token[1:-1]
            if selector == "*":
                for element in node:
                    found.extend(_walk(element, below))
                return found
            try:
                position = int(selector)
                if position < 0:
                    position += len(node)
                if 0 <= position < len(node):
                    found.extend(_walk(node[position], below))
            except (ValueError, IndexError):
                return []
            return found

        # Multi-field selector: "{a,b.c}"
        if token.startswith("{") and token.endswith("}"):
            if not isinstance(node, dict):
                return []
            for raw_field in token[1:-1].split(","):
                keys = tokenize_path(raw_field.strip())
                if not keys:
                    continue
                cursor: Any = node
                for key in keys:
                    if isinstance(cursor, dict) and key in cursor:
                        cursor = cursor[key]
                    else:
                        cursor = None
                        break
                found.extend(_render(cursor))
            return found

        # Wildcard over dict values or list items
        if token == "*":
            if isinstance(node, dict):
                for child in node.values():
                    found.extend(_walk(child, below))
            elif isinstance(node, list):
                for child in node:
                    found.extend(_walk(child, below))
            return found

        # Plain field name
        if isinstance(node, dict) and token in node:
            found.extend(_walk(node[token], below))
        return found

    return _walk(obj, 0)
# Private utility functions
def tokenize_path(path: str) -> list[str]:
    """Split a path expression into its component tokens.

    Dots separate field names. A ``[...]`` or ``{...}`` group is emitted as a
    single token including its delimiters; dots inside a group do not split it.
    Nested delimiters of the same kind are balanced, and a group left
    unclosed runs to the end of the string.

    !!! info "Types handled"
        - Simple paths: "field1.field2"
        - Array indexing: "[0]", "[*]", "[-1]"
        - Wildcards: "*"
        - Multi-field selection: "{field1,field2}"
    """
    if not path:
        return []

    closers = {"[": "]", "{": "}"}
    length = len(path)
    tokens: list[str] = []
    pending: list[str] = []

    def _flush() -> None:
        # Emit the field name collected so far, if any.
        if pending:
            tokens.append("".join(pending))
            pending.clear()

    def _group_end(start: int, opener: str, closer: str) -> int:
        # Return the index just past the group opened at `start`.
        depth = 1
        cursor = start + 1
        while cursor < length and depth > 0:
            symbol = path[cursor]
            if symbol == opener:
                depth += 1
            elif symbol == closer:
                depth -= 1
            cursor += 1
        return cursor

    pos = 0
    while pos < length:
        char = path[pos]
        if char in closers:
            _flush()
            end = _group_end(pos, char, closers[char])
            tokens.append(path[pos:end])
            pos = end
            continue
        if char == ".":
            _flush()
        else:
            pending.append(char)
        pos += 1

    _flush()
    return tokens
def _is_async_callable(
    func: Any,
) -> bool:
    """Report whether calling ``func`` produces a coroutine.

    Both ``async def`` functions and objects whose ``__call__`` is an
    ``async def`` method count as async.

    Args:
        func: Function or callable object to check.

    Returns:
        True if the function is async, False otherwise.
    """
    if asyncio.iscoroutinefunction(func):
        return True
    if not hasattr(func, "__call__"):  # noqa: B004
        return False
    return asyncio.iscoroutinefunction(func.__call__)
@functools.lru_cache
def _get_init_embeddings() -> Callable[[str], Embeddings] | None:
    """Return LangChain's ``init_embeddings`` loader, or ``None`` if it cannot be imported.

    The outcome is memoized by ``functools.lru_cache``, so the import is
    attempted only on the first call.
    """
    try:
        from langchain.embeddings import init_embeddings  # type: ignore
    except ImportError:
        return None
    else:
        return init_embeddings
# Names exported by a star-import of this module. EmbeddingsLambda,
# get_text_at_path and tokenize_path are defined above but are not listed here.
__all__ = [
"ensure_embeddings",
"EmbeddingsFunc",
"AEmbeddingsFunc",
]