initial commit
This commit is contained in:
1314
venv/Lib/site-packages/langgraph/store/base/__init__.py
Normal file
1314
venv/Lib/site-packages/langgraph/store/base/__init__.py
Normal file
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
365
venv/Lib/site-packages/langgraph/store/base/batch.py
Normal file
365
venv/Lib/site-packages/langgraph/store/base/batch.py
Normal file
@@ -0,0 +1,365 @@
|
||||
"""Utilities for batching operations in a background task."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import weakref
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, Literal, TypeVar
|
||||
|
||||
from langgraph.store.base import (
|
||||
NOT_PROVIDED,
|
||||
BaseStore,
|
||||
GetOp,
|
||||
Item,
|
||||
ListNamespacesOp,
|
||||
MatchCondition,
|
||||
NamespacePath,
|
||||
NotProvided,
|
||||
Op,
|
||||
PutOp,
|
||||
Result,
|
||||
SearchItem,
|
||||
SearchOp,
|
||||
_ensure_refresh,
|
||||
_ensure_ttl,
|
||||
_validate_namespace,
|
||||
)
|
||||
|
||||
F = TypeVar("F", bound=Callable)
|
||||
|
||||
|
||||
def _check_loop(func: F) -> F:
|
||||
@functools.wraps(func)
|
||||
def wrapper(store: AsyncBatchedBaseStore, *args: Any, **kwargs: Any) -> Any:
|
||||
method_name: str = func.__name__
|
||||
try:
|
||||
current_loop = asyncio.get_running_loop()
|
||||
if current_loop is store._loop:
|
||||
replacement_str = (
|
||||
f"Specifically, replace `store.{method_name}(...)` with `await store.a{method_name}(...)"
|
||||
if method_name
|
||||
else "For example, replace `store.get(...)` with `await store.aget(...)`"
|
||||
)
|
||||
raise asyncio.InvalidStateError(
|
||||
f"Synchronous calls to {store.__class__.__name__} detected in the main event loop. "
|
||||
"This can lead to deadlocks or performance issues. "
|
||||
"Please use the asynchronous interface for main thread operations. "
|
||||
f"{replacement_str} "
|
||||
)
|
||||
except RuntimeError:
|
||||
pass
|
||||
return func(store, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class AsyncBatchedBaseStore(BaseStore):
|
||||
"""Efficiently batch operations in a background task."""
|
||||
|
||||
__slots__ = ("_loop", "_aqueue", "_task")
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._aqueue: asyncio.Queue[tuple[asyncio.Future, Op]] = asyncio.Queue()
|
||||
self._task: asyncio.Task | None = None
|
||||
self._ensure_task()
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def _ensure_task(self) -> None:
|
||||
"""Ensure the background processing loop is running."""
|
||||
if self._task is None or self._task.done():
|
||||
self._task = self._loop.create_task(_run(self._aqueue, weakref.ref(self)))
|
||||
|
||||
async def aget(
|
||||
self,
|
||||
namespace: tuple[str, ...],
|
||||
key: str,
|
||||
*,
|
||||
refresh_ttl: bool | None = None,
|
||||
) -> Item | None:
|
||||
self._ensure_task()
|
||||
fut = self._loop.create_future()
|
||||
self._aqueue.put_nowait(
|
||||
(
|
||||
fut,
|
||||
GetOp(
|
||||
namespace,
|
||||
key,
|
||||
refresh_ttl=_ensure_refresh(self.ttl_config, refresh_ttl),
|
||||
),
|
||||
)
|
||||
)
|
||||
return await fut
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
namespace_prefix: tuple[str, ...],
|
||||
/,
|
||||
*,
|
||||
query: str | None = None,
|
||||
filter: dict[str, Any] | None = None,
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
refresh_ttl: bool | None = None,
|
||||
) -> list[SearchItem]:
|
||||
self._ensure_task()
|
||||
fut = self._loop.create_future()
|
||||
self._aqueue.put_nowait(
|
||||
(
|
||||
fut,
|
||||
SearchOp(
|
||||
namespace_prefix,
|
||||
filter,
|
||||
limit,
|
||||
offset,
|
||||
query,
|
||||
refresh_ttl=_ensure_refresh(self.ttl_config, refresh_ttl),
|
||||
),
|
||||
)
|
||||
)
|
||||
return await fut
|
||||
|
||||
async def aput(
|
||||
self,
|
||||
namespace: tuple[str, ...],
|
||||
key: str,
|
||||
value: dict[str, Any],
|
||||
index: Literal[False] | list[str] | None = None,
|
||||
*,
|
||||
ttl: float | None | NotProvided = NOT_PROVIDED,
|
||||
) -> None:
|
||||
self._ensure_task()
|
||||
_validate_namespace(namespace)
|
||||
fut = self._loop.create_future()
|
||||
self._aqueue.put_nowait(
|
||||
(
|
||||
fut,
|
||||
PutOp(
|
||||
namespace, key, value, index, ttl=_ensure_ttl(self.ttl_config, ttl)
|
||||
),
|
||||
)
|
||||
)
|
||||
return await fut
|
||||
|
||||
async def adelete(
|
||||
self,
|
||||
namespace: tuple[str, ...],
|
||||
key: str,
|
||||
) -> None:
|
||||
self._ensure_task()
|
||||
fut = self._loop.create_future()
|
||||
self._aqueue.put_nowait((fut, PutOp(namespace, key, None)))
|
||||
return await fut
|
||||
|
||||
async def alist_namespaces(
|
||||
self,
|
||||
*,
|
||||
prefix: NamespacePath | None = None,
|
||||
suffix: NamespacePath | None = None,
|
||||
max_depth: int | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[tuple[str, ...]]:
|
||||
self._ensure_task()
|
||||
fut = self._loop.create_future()
|
||||
match_conditions = []
|
||||
if prefix:
|
||||
match_conditions.append(MatchCondition(match_type="prefix", path=prefix))
|
||||
if suffix:
|
||||
match_conditions.append(MatchCondition(match_type="suffix", path=suffix))
|
||||
|
||||
op = ListNamespacesOp(
|
||||
match_conditions=tuple(match_conditions),
|
||||
max_depth=max_depth,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
self._aqueue.put_nowait((fut, op))
|
||||
return await fut
|
||||
|
||||
@_check_loop
|
||||
def batch(self, ops: Iterable[Op]) -> list[Result]:
|
||||
return asyncio.run_coroutine_threadsafe(self.abatch(ops), self._loop).result()
|
||||
|
||||
@_check_loop
|
||||
def get(
|
||||
self,
|
||||
namespace: tuple[str, ...],
|
||||
key: str,
|
||||
*,
|
||||
refresh_ttl: bool | None = None,
|
||||
) -> Item | None:
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
self.aget(namespace, key=key, refresh_ttl=refresh_ttl), self._loop
|
||||
).result()
|
||||
|
||||
@_check_loop
|
||||
def search(
|
||||
self,
|
||||
namespace_prefix: tuple[str, ...],
|
||||
/,
|
||||
*,
|
||||
query: str | None = None,
|
||||
filter: dict[str, Any] | None = None,
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
refresh_ttl: bool | None = None,
|
||||
) -> list[SearchItem]:
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
self.asearch(
|
||||
namespace_prefix,
|
||||
query=query,
|
||||
filter=filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
refresh_ttl=refresh_ttl,
|
||||
),
|
||||
self._loop,
|
||||
).result()
|
||||
|
||||
@_check_loop
|
||||
def put(
|
||||
self,
|
||||
namespace: tuple[str, ...],
|
||||
key: str,
|
||||
value: dict[str, Any],
|
||||
index: Literal[False] | list[str] | None = None,
|
||||
*,
|
||||
ttl: float | None | NotProvided = NOT_PROVIDED,
|
||||
) -> None:
|
||||
_validate_namespace(namespace)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.aput(
|
||||
namespace,
|
||||
key=key,
|
||||
value=value,
|
||||
index=index,
|
||||
ttl=_ensure_ttl(self.ttl_config, ttl),
|
||||
),
|
||||
self._loop,
|
||||
).result()
|
||||
|
||||
@_check_loop
|
||||
def delete(
|
||||
self,
|
||||
namespace: tuple[str, ...],
|
||||
key: str,
|
||||
) -> None:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.adelete(namespace, key=key), self._loop
|
||||
).result()
|
||||
|
||||
@_check_loop
|
||||
def list_namespaces(
|
||||
self,
|
||||
*,
|
||||
prefix: NamespacePath | None = None,
|
||||
suffix: NamespacePath | None = None,
|
||||
max_depth: int | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[tuple[str, ...]]:
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
self.alist_namespaces(
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
max_depth=max_depth,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
),
|
||||
self._loop,
|
||||
).result()
|
||||
|
||||
|
||||
def _dedupe_ops(values: list[Op]) -> tuple[list[int] | None, list[Op]]:
|
||||
"""Dedupe operations while preserving order for results.
|
||||
|
||||
Args:
|
||||
values: List of operations to dedupe
|
||||
|
||||
Returns:
|
||||
Tuple of (listen indices, deduped operations)
|
||||
where listen indices map deduped operation results back to original positions
|
||||
"""
|
||||
if len(values) <= 1:
|
||||
return None, list(values)
|
||||
|
||||
dedupped: list[Op] = []
|
||||
listen: list[int] = []
|
||||
puts: dict[tuple[tuple[str, ...], str], int] = {}
|
||||
|
||||
for op in values:
|
||||
if isinstance(op, (GetOp, SearchOp, ListNamespacesOp)):
|
||||
try:
|
||||
listen.append(dedupped.index(op))
|
||||
except ValueError:
|
||||
listen.append(len(dedupped))
|
||||
dedupped.append(op)
|
||||
elif isinstance(op, PutOp):
|
||||
putkey = (op.namespace, op.key)
|
||||
if putkey in puts:
|
||||
# Overwrite previous put
|
||||
ix = puts[putkey]
|
||||
dedupped[ix] = op
|
||||
listen.append(ix)
|
||||
else:
|
||||
puts[putkey] = len(dedupped)
|
||||
listen.append(len(dedupped))
|
||||
dedupped.append(op)
|
||||
|
||||
else: # Any new ops will be treated regularly
|
||||
listen.append(len(dedupped))
|
||||
dedupped.append(op)
|
||||
|
||||
return listen, dedupped
|
||||
|
||||
|
||||
async def _run(
|
||||
aqueue: asyncio.Queue[tuple[asyncio.Future, Op]],
|
||||
store: weakref.ReferenceType[BaseStore],
|
||||
) -> None:
|
||||
while item := await aqueue.get():
|
||||
# check if store is still alive
|
||||
if s := store():
|
||||
try:
|
||||
# accumulate operations scheduled in same tick
|
||||
items = [item]
|
||||
try:
|
||||
while item := aqueue.get_nowait():
|
||||
items.append(item)
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
# get the operations to run
|
||||
futs = [item[0] for item in items]
|
||||
values = [item[1] for item in items]
|
||||
# action each operation
|
||||
try:
|
||||
listen, dedupped = _dedupe_ops(values)
|
||||
results = await s.abatch(dedupped)
|
||||
if listen is not None:
|
||||
results = [results[ix] for ix in listen]
|
||||
|
||||
# set the results of each operation
|
||||
for fut, result in zip(futs, results, strict=False):
|
||||
# guard against future being done (e.g. cancelled)
|
||||
if not fut.done():
|
||||
fut.set_result(result)
|
||||
except Exception as e:
|
||||
for fut in futs:
|
||||
# guard against future being done (e.g. cancelled)
|
||||
if not fut.done():
|
||||
fut.set_exception(e)
|
||||
finally:
|
||||
# remove strong ref to store
|
||||
del s
|
||||
else:
|
||||
break
|
||||
433
venv/Lib/site-packages/langgraph/store/base/embed.py
Normal file
433
venv/Lib/site-packages/langgraph/store/base/embed.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""Utilities for working with embedding functions and LangChain's Embeddings interface.
|
||||
|
||||
This module provides tools to wrap arbitrary embedding functions (both sync and async)
|
||||
into LangChain's Embeddings interface. This enables using custom embedding functions
|
||||
with LangChain-compatible tools while maintaining support for both synchronous and
|
||||
asynchronous operations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
EmbeddingsFunc = Callable[[Sequence[str]], list[list[float]]]
|
||||
"""Type for synchronous embedding functions.
|
||||
|
||||
The function should take a sequence of strings and return a list of embeddings,
|
||||
where each embedding is a list of floats. The dimensionality of the embeddings
|
||||
should be consistent for all inputs.
|
||||
"""
|
||||
|
||||
AEmbeddingsFunc = Callable[[Sequence[str]], Awaitable[list[list[float]]]]
|
||||
"""Type for asynchronous embedding functions.
|
||||
|
||||
Similar to EmbeddingsFunc, but returns an awaitable that resolves to the embeddings.
|
||||
"""
|
||||
|
||||
|
||||
def ensure_embeddings(
|
||||
embed: Embeddings | EmbeddingsFunc | AEmbeddingsFunc | str | None,
|
||||
) -> Embeddings:
|
||||
"""Ensure that an embedding function conforms to LangChain's Embeddings interface.
|
||||
|
||||
This function wraps arbitrary embedding functions to make them compatible with
|
||||
LangChain's Embeddings interface. It handles both synchronous and asynchronous
|
||||
functions.
|
||||
|
||||
Args:
|
||||
embed: Either an existing Embeddings instance, or a function that converts
|
||||
text to embeddings. If the function is async, it will be used for both
|
||||
sync and async operations.
|
||||
|
||||
Returns:
|
||||
An Embeddings instance that wraps the provided function(s).
|
||||
|
||||
??? example "Examples"
|
||||
|
||||
Wrap a synchronous embedding function:
|
||||
|
||||
```python
|
||||
def my_embed_fn(texts):
|
||||
return [[0.1, 0.2] for _ in texts]
|
||||
|
||||
embeddings = ensure_embeddings(my_embed_fn)
|
||||
result = embeddings.embed_query("hello") # Returns [0.1, 0.2]
|
||||
```
|
||||
|
||||
Wrap an asynchronous embedding function:
|
||||
|
||||
```python
|
||||
async def my_async_fn(texts):
|
||||
return [[0.1, 0.2] for _ in texts]
|
||||
|
||||
embeddings = ensure_embeddings(my_async_fn)
|
||||
result = await embeddings.aembed_query("hello") # Returns [0.1, 0.2]
|
||||
```
|
||||
|
||||
Initialize embeddings using a provider string:
|
||||
|
||||
```python
|
||||
# Requires langchain>=0.3.9 and langgraph-checkpoint>=2.0.11
|
||||
embeddings = ensure_embeddings("openai:text-embedding-3-small")
|
||||
result = embeddings.embed_query("hello")
|
||||
```
|
||||
"""
|
||||
if embed is None:
|
||||
raise ValueError("embed must be provided")
|
||||
if isinstance(embed, str):
|
||||
init_embeddings = _get_init_embeddings()
|
||||
if init_embeddings is None:
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
try:
|
||||
lc_version = version("langchain")
|
||||
version_info = f"Found langchain version {lc_version}, but"
|
||||
except PackageNotFoundError:
|
||||
version_info = "langchain is not installed;"
|
||||
|
||||
raise ValueError(
|
||||
f"Could not load embeddings from string '{embed}'. {version_info} "
|
||||
"loading embeddings by provider:identifier string requires langchain>=0.3.9 "
|
||||
"as well as the provider-specific package. "
|
||||
"Install LangChain with: pip install 'langchain>=0.3.9' "
|
||||
"and the provider-specific package (e.g., 'langchain-openai>=0.3.0'). "
|
||||
"Alternatively, specify 'embed' as a compatible Embeddings object or python function."
|
||||
)
|
||||
return init_embeddings(embed)
|
||||
|
||||
if isinstance(embed, Embeddings):
|
||||
return embed
|
||||
return EmbeddingsLambda(embed)
|
||||
|
||||
|
||||
class EmbeddingsLambda(Embeddings):
|
||||
"""Wrapper to convert embedding functions into LangChain's Embeddings interface.
|
||||
|
||||
This class allows arbitrary embedding functions to be used with LangChain-compatible
|
||||
tools. It supports both synchronous and asynchronous operations, and can handle:
|
||||
1. A synchronous function for sync operations (async operations will use sync function)
|
||||
2. An async function for both sync/async operations (sync operations will raise an error)
|
||||
|
||||
The embedding functions should convert text into fixed-dimensional vectors that
|
||||
capture the semantic meaning of the text.
|
||||
|
||||
Args:
|
||||
func: Function that converts text to embeddings. Can be sync or async.
|
||||
If async, it will be used for async operations, but sync operations
|
||||
will raise an error. If sync, it will be used for both sync and async operations.
|
||||
|
||||
??? example "Examples"
|
||||
|
||||
With a sync function:
|
||||
|
||||
```python
|
||||
def my_embed_fn(texts):
|
||||
# Return 2D embeddings for each text
|
||||
return [[0.1, 0.2] for _ in texts]
|
||||
|
||||
embeddings = EmbeddingsLambda(my_embed_fn)
|
||||
result = embeddings.embed_query("hello") # Returns [0.1, 0.2]
|
||||
await embeddings.aembed_query("hello") # Also returns [0.1, 0.2]
|
||||
```
|
||||
|
||||
With an async function:
|
||||
|
||||
```python
|
||||
async def my_async_fn(texts):
|
||||
return [[0.1, 0.2] for _ in texts]
|
||||
|
||||
embeddings = EmbeddingsLambda(my_async_fn)
|
||||
await embeddings.aembed_query("hello") # Returns [0.1, 0.2]
|
||||
# Note: embed_query() would raise an error
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: EmbeddingsFunc | AEmbeddingsFunc,
|
||||
) -> None:
|
||||
if func is None:
|
||||
raise ValueError("func must be provided")
|
||||
if _is_async_callable(func):
|
||||
self.afunc = func
|
||||
else:
|
||||
self.func = func
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed a list of texts into vectors.
|
||||
|
||||
Args:
|
||||
texts: list of texts to convert to embeddings.
|
||||
|
||||
Returns:
|
||||
list of embeddings, one per input text. Each embedding is a list of floats.
|
||||
|
||||
Raises:
|
||||
ValueError: If the instance was initialized with only an async function.
|
||||
"""
|
||||
func = getattr(self, "func", None)
|
||||
if func is None:
|
||||
raise ValueError(
|
||||
"EmbeddingsLambda was initialized with an async function but no sync function. "
|
||||
"Use aembed_documents for async operation or provide a sync function."
|
||||
)
|
||||
return func(texts)
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed a single piece of text.
|
||||
|
||||
Args:
|
||||
text: Text to convert to an embedding.
|
||||
|
||||
Returns:
|
||||
Embedding vector as a list of floats.
|
||||
|
||||
Note:
|
||||
This is equivalent to calling embed_documents with a single text
|
||||
and taking the first result.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Asynchronously embed a list of texts into vectors.
|
||||
|
||||
Args:
|
||||
texts: list of texts to convert to embeddings.
|
||||
|
||||
Returns:
|
||||
list of embeddings, one per input text. Each embedding is a list of floats.
|
||||
|
||||
Note:
|
||||
If no async function was provided, this falls back to the sync implementation.
|
||||
"""
|
||||
afunc = getattr(self, "afunc", None)
|
||||
if afunc is None:
|
||||
return await super().aembed_documents(texts)
|
||||
return await afunc(texts)
|
||||
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
"""Asynchronously embed a single piece of text.
|
||||
|
||||
Args:
|
||||
text: Text to convert to an embedding.
|
||||
|
||||
Returns:
|
||||
Embedding vector as a list of floats.
|
||||
|
||||
Note:
|
||||
This is equivalent to calling aembed_documents with a single text
|
||||
and taking the first result.
|
||||
"""
|
||||
afunc = getattr(self, "afunc", None)
|
||||
if afunc is None:
|
||||
return await super().aembed_query(text)
|
||||
return (await afunc([text]))[0]
|
||||
|
||||
|
||||
def get_text_at_path(obj: Any, path: str | list[str]) -> list[str]:
|
||||
"""Extract text from an object using a path expression or pre-tokenized path.
|
||||
|
||||
Args:
|
||||
obj: The object to extract text from
|
||||
path: Either a path string or pre-tokenized path list.
|
||||
|
||||
!!! info "Path types handled"
|
||||
- Simple paths: "field1.field2"
|
||||
- Array indexing: "[0]", "[*]", "[-1]"
|
||||
- Wildcards: "*"
|
||||
- Multi-field selection: "{field1,field2}"
|
||||
- Nested paths in multi-field: "{field1,nested.field2}"
|
||||
"""
|
||||
if not path or path == "$":
|
||||
return [json.dumps(obj, sort_keys=True, ensure_ascii=False)]
|
||||
|
||||
tokens = tokenize_path(path) if isinstance(path, str) else path
|
||||
|
||||
def _extract_from_obj(obj: Any, tokens: list[str], pos: int) -> list[str]:
|
||||
if pos >= len(tokens):
|
||||
if isinstance(obj, (str, int, float, bool)):
|
||||
return [str(obj)]
|
||||
elif obj is None:
|
||||
return []
|
||||
elif isinstance(obj, (list, dict)):
|
||||
return [json.dumps(obj, sort_keys=True, ensure_ascii=False)]
|
||||
return []
|
||||
|
||||
token = tokens[pos]
|
||||
results = []
|
||||
|
||||
if token.startswith("[") and token.endswith("]"):
|
||||
if not isinstance(obj, list):
|
||||
return []
|
||||
|
||||
index = token[1:-1]
|
||||
if index == "*":
|
||||
for item in obj:
|
||||
results.extend(_extract_from_obj(item, tokens, pos + 1))
|
||||
else:
|
||||
try:
|
||||
idx = int(index)
|
||||
if idx < 0:
|
||||
idx = len(obj) + idx
|
||||
if 0 <= idx < len(obj):
|
||||
results.extend(_extract_from_obj(obj[idx], tokens, pos + 1))
|
||||
except (ValueError, IndexError):
|
||||
return []
|
||||
|
||||
elif token.startswith("{") and token.endswith("}"):
|
||||
if not isinstance(obj, dict):
|
||||
return []
|
||||
|
||||
fields = [f.strip() for f in token[1:-1].split(",")]
|
||||
for field in fields:
|
||||
nested_tokens = tokenize_path(field)
|
||||
if nested_tokens:
|
||||
current_obj: dict | None = obj
|
||||
for nested_token in nested_tokens:
|
||||
if (
|
||||
isinstance(current_obj, dict)
|
||||
and nested_token in current_obj
|
||||
):
|
||||
current_obj = current_obj[nested_token]
|
||||
else:
|
||||
current_obj = None
|
||||
break
|
||||
if current_obj is not None:
|
||||
if isinstance(current_obj, (str, int, float, bool)):
|
||||
results.append(str(current_obj))
|
||||
elif isinstance(current_obj, (list, dict)):
|
||||
results.append(
|
||||
json.dumps(
|
||||
current_obj, sort_keys=True, ensure_ascii=False
|
||||
)
|
||||
)
|
||||
|
||||
# Handle wildcard
|
||||
elif token == "*":
|
||||
if isinstance(obj, dict):
|
||||
for value in obj.values():
|
||||
results.extend(_extract_from_obj(value, tokens, pos + 1))
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
results.extend(_extract_from_obj(item, tokens, pos + 1))
|
||||
|
||||
# Handle regular field
|
||||
else:
|
||||
if isinstance(obj, dict) and token in obj:
|
||||
results.extend(_extract_from_obj(obj[token], tokens, pos + 1))
|
||||
|
||||
return results
|
||||
|
||||
return _extract_from_obj(obj, tokens, 0)
|
||||
|
||||
|
||||
# Private utility functions
|
||||
|
||||
|
||||
def tokenize_path(path: str) -> list[str]:
|
||||
"""Tokenize a path into components.
|
||||
|
||||
!!! info "Types handled"
|
||||
- Simple paths: "field1.field2"
|
||||
- Array indexing: "[0]", "[*]", "[-1]"
|
||||
- Wildcards: "*"
|
||||
- Multi-field selection: "{field1,field2}"
|
||||
"""
|
||||
if not path:
|
||||
return []
|
||||
|
||||
tokens = []
|
||||
current: list[str] = []
|
||||
i = 0
|
||||
while i < len(path):
|
||||
char = path[i]
|
||||
|
||||
if char == "[": # Handle array index
|
||||
if current:
|
||||
tokens.append("".join(current))
|
||||
current = []
|
||||
bracket_count = 1
|
||||
index_chars = ["["]
|
||||
i += 1
|
||||
while i < len(path) and bracket_count > 0:
|
||||
if path[i] == "[":
|
||||
bracket_count += 1
|
||||
elif path[i] == "]":
|
||||
bracket_count -= 1
|
||||
index_chars.append(path[i])
|
||||
i += 1
|
||||
tokens.append("".join(index_chars))
|
||||
continue
|
||||
|
||||
elif char == "{": # Handle multi-field selection
|
||||
if current:
|
||||
tokens.append("".join(current))
|
||||
current = []
|
||||
brace_count = 1
|
||||
field_chars = ["{"]
|
||||
i += 1
|
||||
while i < len(path) and brace_count > 0:
|
||||
if path[i] == "{":
|
||||
brace_count += 1
|
||||
elif path[i] == "}":
|
||||
brace_count -= 1
|
||||
field_chars.append(path[i])
|
||||
i += 1
|
||||
tokens.append("".join(field_chars))
|
||||
continue
|
||||
|
||||
elif char == ".": # Handle regular field
|
||||
if current:
|
||||
tokens.append("".join(current))
|
||||
current = []
|
||||
else:
|
||||
current.append(char)
|
||||
i += 1
|
||||
|
||||
if current:
|
||||
tokens.append("".join(current))
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
def _is_async_callable(
|
||||
func: Any,
|
||||
) -> bool:
|
||||
"""Check if a function is async.
|
||||
|
||||
This includes both async def functions and classes with async __call__ methods.
|
||||
|
||||
Args:
|
||||
func: Function or callable object to check.
|
||||
|
||||
Returns:
|
||||
True if the function is async, False otherwise.
|
||||
"""
|
||||
return (
|
||||
asyncio.iscoroutinefunction(func)
|
||||
or hasattr(func, "__call__") # noqa: B004
|
||||
and asyncio.iscoroutinefunction(func.__call__)
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def _get_init_embeddings() -> Callable[[str], Embeddings] | None:
|
||||
try:
|
||||
from langchain.embeddings import init_embeddings # type: ignore
|
||||
|
||||
return init_embeddings
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ensure_embeddings",
|
||||
"EmbeddingsFunc",
|
||||
"AEmbeddingsFunc",
|
||||
]
|
||||
592
venv/Lib/site-packages/langgraph/store/memory/__init__.py
Normal file
592
venv/Lib/site-packages/langgraph/store/memory/__init__.py
Normal file
@@ -0,0 +1,592 @@
|
||||
"""In-memory dictionary-backed store with optional vector search.
|
||||
|
||||
!!! example "Examples"
|
||||
Basic key-value storage:
|
||||
```python
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
store = InMemoryStore()
|
||||
store.put(("users", "123"), "prefs", {"theme": "dark"})
|
||||
item = store.get(("users", "123"), "prefs")
|
||||
```
|
||||
|
||||
Vector search using LangChain embeddings:
|
||||
```python
|
||||
from langchain.embeddings import init_embeddings
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
store = InMemoryStore(
|
||||
index={
|
||||
"dims": 1536,
|
||||
"embed": init_embeddings("openai:text-embedding-3-small")
|
||||
}
|
||||
)
|
||||
|
||||
# Store documents
|
||||
store.put(("docs",), "doc1", {"text": "Python tutorial"})
|
||||
store.put(("docs",), "doc2", {"text": "TypeScript guide"})
|
||||
|
||||
# Search by similarity
|
||||
results = store.search(("docs",), query="python programming")
|
||||
```
|
||||
|
||||
Vector search using OpenAI SDK directly:
|
||||
```python
|
||||
from openai import OpenAI
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
client = OpenAI()
|
||||
|
||||
def embed_texts(texts: list[str]) -> list[list[float]]:
|
||||
response = client.embeddings.create(
|
||||
model="text-embedding-3-small",
|
||||
input=texts
|
||||
)
|
||||
return [e.embedding for e in response.data]
|
||||
|
||||
store = InMemoryStore(
|
||||
index={
|
||||
"dims": 1536,
|
||||
"embed": embed_texts
|
||||
}
|
||||
)
|
||||
|
||||
# Store documents
|
||||
store.put(("docs",), "doc1", {"text": "Python tutorial"})
|
||||
store.put(("docs",), "doc2", {"text": "TypeScript guide"})
|
||||
|
||||
# Search by similarity
|
||||
results = store.search(("docs",), query="python programming")
|
||||
```
|
||||
|
||||
Async vector search using OpenAI SDK:
|
||||
```python
|
||||
from openai import AsyncOpenAI
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
client = AsyncOpenAI()
|
||||
|
||||
async def aembed_texts(texts: list[str]) -> list[list[float]]:
|
||||
response = await client.embeddings.create(
|
||||
model="text-embedding-3-small",
|
||||
input=texts
|
||||
)
|
||||
return [e.embedding for e in response.data]
|
||||
|
||||
store = InMemoryStore(
|
||||
index={
|
||||
"dims": 1536,
|
||||
"embed": aembed_texts
|
||||
}
|
||||
)
|
||||
|
||||
# Store documents
|
||||
await store.aput(("docs",), "doc1", {"text": "Python tutorial"})
|
||||
await store.aput(("docs",), "doc2", {"text": "TypeScript guide"})
|
||||
|
||||
# Search by similarity
|
||||
results = await store.asearch(("docs",), query="python programming")
|
||||
```
|
||||
|
||||
Warning:
|
||||
This store keeps all data in memory. Data is lost when the process exits.
|
||||
For persistence, use a database-backed store like PostgresStore.
|
||||
|
||||
Tip:
|
||||
For vector search, install numpy for better performance:
|
||||
```bash
|
||||
pip install numpy
|
||||
```
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures as cf
|
||||
import functools
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime, timezone
|
||||
from importlib import util
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from langgraph.store.base import (
|
||||
BaseStore,
|
||||
GetOp,
|
||||
IndexConfig,
|
||||
Item,
|
||||
ListNamespacesOp,
|
||||
MatchCondition,
|
||||
Op,
|
||||
PutOp,
|
||||
Result,
|
||||
SearchItem,
|
||||
SearchOp,
|
||||
ensure_embeddings,
|
||||
get_text_at_path,
|
||||
tokenize_path,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InMemoryStore(BaseStore):
|
||||
"""In-memory dictionary-backed store with optional vector search.
|
||||
|
||||
!!! example "Examples"
|
||||
Basic key-value storage:
|
||||
store = InMemoryStore()
|
||||
store.put(("users", "123"), "prefs", {"theme": "dark"})
|
||||
item = store.get(("users", "123"), "prefs")
|
||||
|
||||
Vector search with embeddings:
|
||||
from langchain.embeddings import init_embeddings
|
||||
store = InMemoryStore(index={
|
||||
"dims": 1536,
|
||||
"embed": init_embeddings("openai:text-embedding-3-small"),
|
||||
"fields": ["text"],
|
||||
})
|
||||
|
||||
# Store documents
|
||||
store.put(("docs",), "doc1", {"text": "Python tutorial"})
|
||||
store.put(("docs",), "doc2", {"text": "TypeScript guide"})
|
||||
|
||||
# Search by similarity
|
||||
results = store.search(("docs",), query="python programming")
|
||||
|
||||
Note:
|
||||
Semantic search is disabled by default. You can enable it by providing an `index` configuration
|
||||
when creating the store. Without this configuration, all `index` arguments passed to
|
||||
`put` or `aput`will have no effect.
|
||||
|
||||
Warning:
|
||||
This store keeps all data in memory. Data is lost when the process exits.
|
||||
For persistence, use a database-backed store like PostgresStore.
|
||||
|
||||
Tip:
|
||||
For vector search, install numpy for better performance:
|
||||
```bash
|
||||
pip install numpy
|
||||
```
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"_data",
|
||||
"_vectors",
|
||||
"index_config",
|
||||
"embeddings",
|
||||
)
|
||||
|
||||
def __init__(self, *, index: IndexConfig | None = None) -> None:
|
||||
# Both _data and _vectors are wrapped in the In-memory API
|
||||
# Do not change their names
|
||||
self._data: dict[tuple[str, ...], dict[str, Item]] = defaultdict(dict)
|
||||
# [ns][key][path]
|
||||
self._vectors: dict[tuple[str, ...], dict[str, dict[str, list[float]]]] = (
|
||||
defaultdict(lambda: defaultdict(dict))
|
||||
)
|
||||
self.index_config = index
|
||||
if self.index_config:
|
||||
self.index_config = self.index_config.copy()
|
||||
self.embeddings: Embeddings | None = ensure_embeddings(
|
||||
self.index_config.get("embed"),
|
||||
)
|
||||
self.index_config["__tokenized_fields"] = [
|
||||
(p, tokenize_path(p)) if p != "$" else (p, p)
|
||||
for p in (self.index_config.get("fields") or ["$"])
|
||||
]
|
||||
|
||||
else:
|
||||
self.index_config = None
|
||||
self.embeddings = None
|
||||
|
||||
def batch(self, ops: Iterable[Op]) -> list[Result]:
|
||||
# The batch/abatch methods are treated as internal.
|
||||
# Users should access via put/search/get/list_namespaces/etc.
|
||||
results, put_ops, search_ops = self._prepare_ops(ops)
|
||||
if search_ops:
|
||||
queryinmem_store = self._embed_search_queries(search_ops)
|
||||
self._batch_search(search_ops, queryinmem_store, results)
|
||||
|
||||
to_embed = self._extract_texts(put_ops)
|
||||
if to_embed and self.index_config and self.embeddings:
|
||||
embeddings = self.embeddings.embed_documents(list(to_embed))
|
||||
self._insertinmem_store(to_embed, embeddings)
|
||||
self._apply_put_ops(put_ops)
|
||||
return results
|
||||
|
||||
async def abatch(self, ops: Iterable[Op]) -> list[Result]:
|
||||
# The batch/abatch methods are treated as internal.
|
||||
# Users should access via put/search/get/list_namespaces/etc.
|
||||
results, put_ops, search_ops = self._prepare_ops(ops)
|
||||
if search_ops:
|
||||
queryinmem_store = await self._aembed_search_queries(search_ops)
|
||||
self._batch_search(search_ops, queryinmem_store, results)
|
||||
|
||||
to_embed = self._extract_texts(put_ops)
|
||||
if to_embed and self.index_config and self.embeddings:
|
||||
embeddings = await self.embeddings.aembed_documents(list(to_embed))
|
||||
self._insertinmem_store(to_embed, embeddings)
|
||||
self._apply_put_ops(put_ops)
|
||||
return results
|
||||
|
||||
# Helpers
|
||||
|
||||
def _filter_items(self, op: SearchOp) -> list[tuple[Item, list[list[float]]]]:
|
||||
"""Filter items by namespace and filter function, return items with their embeddings."""
|
||||
namespace_prefix = op.namespace_prefix
|
||||
|
||||
def filter_func(item: Item) -> bool:
|
||||
if not op.filter:
|
||||
return True
|
||||
|
||||
return all(
|
||||
_compare_values(item.value.get(key), filter_value)
|
||||
for key, filter_value in op.filter.items()
|
||||
)
|
||||
|
||||
filtered = []
|
||||
for namespace in self._data:
|
||||
if not (
|
||||
namespace[: len(namespace_prefix)] == namespace_prefix
|
||||
if len(namespace) >= len(namespace_prefix)
|
||||
else False
|
||||
):
|
||||
continue
|
||||
|
||||
for key, item in self._data[namespace].items():
|
||||
if filter_func(item):
|
||||
if op.query and (embeddings := self._vectors[namespace].get(key)):
|
||||
filtered.append((item, list(embeddings.values())))
|
||||
else:
|
||||
filtered.append((item, []))
|
||||
return filtered
|
||||
|
||||
def _embed_search_queries(
|
||||
self,
|
||||
search_ops: dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]],
|
||||
) -> dict[str, list[float]]:
|
||||
queryinmem_store = {}
|
||||
if self.index_config and self.embeddings and search_ops:
|
||||
queries = {op.query for (op, _) in search_ops.values() if op.query}
|
||||
|
||||
if queries:
|
||||
with cf.ThreadPoolExecutor() as executor:
|
||||
futures = {
|
||||
q: executor.submit(self.embeddings.embed_query, q)
|
||||
for q in list(queries)
|
||||
}
|
||||
for query, future in futures.items():
|
||||
queryinmem_store[query] = future.result()
|
||||
|
||||
return queryinmem_store
|
||||
|
||||
async def _aembed_search_queries(
|
||||
self,
|
||||
search_ops: dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]],
|
||||
) -> dict[str, list[float]]:
|
||||
queryinmem_store = {}
|
||||
if self.index_config and self.embeddings and search_ops:
|
||||
queries = {op.query for (op, _) in search_ops.values() if op.query}
|
||||
|
||||
if queries:
|
||||
coros = [self.embeddings.aembed_query(q) for q in list(queries)]
|
||||
results = await asyncio.gather(*coros)
|
||||
queryinmem_store = dict(zip(queries, results, strict=False))
|
||||
|
||||
return queryinmem_store
|
||||
|
||||
def _batch_search(
|
||||
self,
|
||||
ops: dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]],
|
||||
queryinmem_store: dict[str, list[float]],
|
||||
results: list[Result],
|
||||
) -> None:
|
||||
"""Perform batch similarity search for multiple queries."""
|
||||
for i, (op, candidates) in ops.items():
|
||||
if not candidates:
|
||||
results[i] = []
|
||||
continue
|
||||
if op.query and queryinmem_store:
|
||||
query_embedding = queryinmem_store[op.query]
|
||||
flat_items, flat_vectors = [], []
|
||||
scoreless = []
|
||||
for item, vectors in candidates:
|
||||
for vector in vectors:
|
||||
flat_items.append(item)
|
||||
flat_vectors.append(vector)
|
||||
if not vectors:
|
||||
scoreless.append(item)
|
||||
|
||||
scores = _cosine_similarity(query_embedding, flat_vectors)
|
||||
sorted_results = sorted(
|
||||
zip(scores, flat_items, strict=False),
|
||||
key=lambda x: x[0],
|
||||
reverse=True,
|
||||
)
|
||||
# max pooling
|
||||
seen: set[tuple[tuple[str, ...], str]] = set()
|
||||
kept: list[tuple[float | None, Item]] = []
|
||||
for score, item in sorted_results:
|
||||
key = (item.namespace, item.key)
|
||||
if key in seen:
|
||||
continue
|
||||
ix = len(seen)
|
||||
seen.add(key)
|
||||
if ix >= op.offset + op.limit:
|
||||
break
|
||||
if ix < op.offset:
|
||||
continue
|
||||
|
||||
kept.append((score, item))
|
||||
if scoreless and len(kept) < op.limit:
|
||||
# Corner case: if we request more items than what we have embedded,
|
||||
# fill the rest with non-scored items
|
||||
kept.extend(
|
||||
(None, item) for item in scoreless[: op.limit - len(kept)]
|
||||
)
|
||||
|
||||
results[i] = [
|
||||
SearchItem(
|
||||
namespace=item.namespace,
|
||||
key=item.key,
|
||||
value=item.value,
|
||||
created_at=item.created_at,
|
||||
updated_at=item.updated_at,
|
||||
score=float(score) if score is not None else None,
|
||||
)
|
||||
for score, item in kept
|
||||
]
|
||||
else:
|
||||
results[i] = [
|
||||
SearchItem(
|
||||
namespace=item.namespace,
|
||||
key=item.key,
|
||||
value=item.value,
|
||||
created_at=item.created_at,
|
||||
updated_at=item.updated_at,
|
||||
)
|
||||
for (item, _) in candidates[op.offset : op.offset + op.limit]
|
||||
]
|
||||
|
||||
def _prepare_ops(
|
||||
self, ops: Iterable[Op]
|
||||
) -> tuple[
|
||||
list[Result],
|
||||
dict[tuple[tuple[str, ...], str], PutOp],
|
||||
dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]],
|
||||
]:
|
||||
results: list[Result] = []
|
||||
put_ops: dict[tuple[tuple[str, ...], str], PutOp] = {}
|
||||
search_ops: dict[
|
||||
int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]
|
||||
] = {}
|
||||
for i, op in enumerate(ops):
|
||||
if isinstance(op, GetOp):
|
||||
item = self._data[op.namespace].get(op.key)
|
||||
results.append(item)
|
||||
elif isinstance(op, SearchOp):
|
||||
search_ops[i] = (op, self._filter_items(op))
|
||||
results.append(None)
|
||||
elif isinstance(op, ListNamespacesOp):
|
||||
results.append(self._handle_list_namespaces(op))
|
||||
elif isinstance(op, PutOp):
|
||||
put_ops[(op.namespace, op.key)] = op
|
||||
results.append(None)
|
||||
else:
|
||||
raise ValueError(f"Unknown operation type: {type(op)}")
|
||||
|
||||
return results, put_ops, search_ops
|
||||
|
||||
def _apply_put_ops(self, put_ops: dict[tuple[tuple[str, ...], str], PutOp]) -> None:
|
||||
for (namespace, key), op in put_ops.items():
|
||||
if op.value is None:
|
||||
self._data[namespace].pop(key, None)
|
||||
self._vectors[namespace].pop(key, None)
|
||||
else:
|
||||
self._data[namespace][key] = Item(
|
||||
value=op.value,
|
||||
key=key,
|
||||
namespace=namespace,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
def _extract_texts(
|
||||
self, put_ops: dict[tuple[tuple[str, ...], str], PutOp]
|
||||
) -> dict[str, list[tuple[tuple[str, ...], str, str]]]:
|
||||
if put_ops and self.index_config and self.embeddings:
|
||||
to_embed = defaultdict(list)
|
||||
|
||||
for op in put_ops.values():
|
||||
if op.value is not None and op.index is not False:
|
||||
if op.index is None:
|
||||
paths = self.index_config["__tokenized_fields"]
|
||||
else:
|
||||
paths = [(ix, tokenize_path(ix)) for ix in op.index]
|
||||
for path, field in paths:
|
||||
texts = get_text_at_path(op.value, field)
|
||||
if texts:
|
||||
if len(texts) > 1:
|
||||
for i, text in enumerate(texts):
|
||||
to_embed[text].append(
|
||||
(op.namespace, op.key, f"{path}.{i}")
|
||||
)
|
||||
|
||||
else:
|
||||
to_embed[texts[0]].append((op.namespace, op.key, path))
|
||||
|
||||
return to_embed
|
||||
|
||||
return {}
|
||||
|
||||
def _insertinmem_store(
|
||||
self,
|
||||
to_embed: dict[str, list[tuple[tuple[str, ...], str, str]]],
|
||||
embeddings: list[list[float]],
|
||||
) -> None:
|
||||
indices = [index for indices in to_embed.values() for index in indices]
|
||||
if len(indices) != len(embeddings):
|
||||
raise ValueError(
|
||||
f"Number of embeddings ({len(embeddings)}) does not"
|
||||
f" match number of indices ({len(indices)})"
|
||||
)
|
||||
for embedding, (ns, key, path) in zip(embeddings, indices, strict=False):
|
||||
self._vectors[ns][key][path] = embedding
|
||||
|
||||
def _handle_list_namespaces(self, op: ListNamespacesOp) -> list[tuple[str, ...]]:
|
||||
all_namespaces = list(
|
||||
self._data.keys()
|
||||
) # Avoid collection size changing while iterating
|
||||
namespaces = all_namespaces
|
||||
if op.match_conditions:
|
||||
namespaces = [
|
||||
ns
|
||||
for ns in namespaces
|
||||
if all(_does_match(condition, ns) for condition in op.match_conditions)
|
||||
]
|
||||
|
||||
if op.max_depth is not None:
|
||||
namespaces = sorted({ns[: op.max_depth] for ns in namespaces})
|
||||
else:
|
||||
namespaces = sorted(namespaces)
|
||||
return namespaces[op.offset : op.offset + op.limit]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _check_numpy() -> bool:
|
||||
if bool(util.find_spec("numpy")):
|
||||
return True
|
||||
logger.warning(
|
||||
"NumPy not found in the current Python environment. "
|
||||
"The InMemoryStore will use a pure Python implementation for vector operations, "
|
||||
"which may significantly impact performance, especially for large datasets or frequent searches. "
|
||||
"For optimal speed and efficiency, consider installing NumPy: "
|
||||
"pip install numpy"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def _cosine_similarity(X: list[float], Y: list[list[float]]) -> list[float]:
|
||||
"""
|
||||
Compute cosine similarity between a vector X and a matrix Y.
|
||||
Lazy import numpy for efficiency.
|
||||
"""
|
||||
if not Y:
|
||||
return []
|
||||
if _check_numpy():
|
||||
import numpy as np
|
||||
|
||||
X_arr = np.array(X) if not isinstance(X, np.ndarray) else X
|
||||
Y_arr = np.array(Y) if not isinstance(Y, np.ndarray) else Y
|
||||
X_norm = np.linalg.norm(X_arr)
|
||||
Y_norm = np.linalg.norm(Y_arr, axis=1)
|
||||
|
||||
# Avoid division by zero
|
||||
mask = Y_norm != 0
|
||||
similarities = np.zeros_like(Y_norm)
|
||||
similarities[mask] = np.dot(Y_arr[mask], X_arr) / (Y_norm[mask] * X_norm)
|
||||
return similarities.tolist()
|
||||
|
||||
similarities = []
|
||||
for y in Y:
|
||||
dot_product = sum(a * b for a, b in zip(X, y, strict=False))
|
||||
norm1 = sum(a * a for a in X) ** 0.5
|
||||
norm2 = sum(a * a for a in y) ** 0.5
|
||||
similarity = dot_product / (norm1 * norm2) if norm1 > 0 and norm2 > 0 else 0.0
|
||||
similarities.append(similarity)
|
||||
|
||||
return similarities
|
||||
|
||||
|
||||
def _does_match(match_condition: MatchCondition, key: tuple[str, ...]) -> bool:
|
||||
"""Whether a namespace key matches a match condition."""
|
||||
match_type = match_condition.match_type
|
||||
path = match_condition.path
|
||||
|
||||
if len(key) < len(path):
|
||||
return False
|
||||
|
||||
if match_type == "prefix":
|
||||
for k_elem, p_elem in zip(key, path, strict=False):
|
||||
if p_elem == "*":
|
||||
continue # Wildcard matches any element
|
||||
if k_elem != p_elem:
|
||||
return False
|
||||
return True
|
||||
elif match_type == "suffix":
|
||||
for k_elem, p_elem in zip(reversed(key), reversed(path), strict=False):
|
||||
if p_elem == "*":
|
||||
continue # Wildcard matches any element
|
||||
if k_elem != p_elem:
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Unsupported match type: {match_type}")
|
||||
|
||||
|
||||
def _compare_values(item_value: Any, filter_value: Any) -> bool:
|
||||
"""Compare values in a JSONB-like way, handling nested objects."""
|
||||
if isinstance(filter_value, dict):
|
||||
if any(k.startswith("$") for k in filter_value):
|
||||
return all(
|
||||
_apply_operator(item_value, op_key, op_value)
|
||||
for op_key, op_value in filter_value.items()
|
||||
)
|
||||
if not isinstance(item_value, dict):
|
||||
return False
|
||||
return all(
|
||||
_compare_values(item_value.get(k), v) for k, v in filter_value.items()
|
||||
)
|
||||
elif isinstance(filter_value, (list, tuple)):
|
||||
return (
|
||||
isinstance(item_value, (list, tuple))
|
||||
and len(item_value) == len(filter_value)
|
||||
and all(
|
||||
_compare_values(iv, fv)
|
||||
for iv, fv in zip(item_value, filter_value, strict=False)
|
||||
)
|
||||
)
|
||||
else:
|
||||
return item_value == filter_value
|
||||
|
||||
|
||||
def _apply_operator(value: Any, operator: str, op_value: Any) -> bool:
|
||||
"""Apply a comparison operator, matching PostgreSQL's JSONB behavior."""
|
||||
if operator == "$eq":
|
||||
return value == op_value
|
||||
elif operator == "$gt":
|
||||
return float(value) > float(op_value)
|
||||
elif operator == "$gte":
|
||||
return float(value) >= float(op_value)
|
||||
elif operator == "$lt":
|
||||
return float(value) < float(op_value)
|
||||
elif operator == "$lte":
|
||||
return float(value) <= float(op_value)
|
||||
elif operator == "$ne":
|
||||
return value != op_value
|
||||
else:
|
||||
raise ValueError(f"Unsupported operator: {operator}")
|
||||
Binary file not shown.
Reference in New Issue
Block a user