initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,365 @@
"""Utilities for batching operations in a background task."""
from __future__ import annotations
import asyncio
import functools
import weakref
from collections.abc import Callable, Iterable
from typing import Any, Literal, TypeVar
from langgraph.store.base import (
NOT_PROVIDED,
BaseStore,
GetOp,
Item,
ListNamespacesOp,
MatchCondition,
NamespacePath,
NotProvided,
Op,
PutOp,
Result,
SearchItem,
SearchOp,
_ensure_refresh,
_ensure_ttl,
_validate_namespace,
)
# Type variable used by `_check_loop` so the decorator is annotated as returning
# the same callable type it receives.
F = TypeVar("F", bound=Callable)
def _check_loop(func: F) -> F:
@functools.wraps(func)
def wrapper(store: AsyncBatchedBaseStore, *args: Any, **kwargs: Any) -> Any:
method_name: str = func.__name__
try:
current_loop = asyncio.get_running_loop()
if current_loop is store._loop:
replacement_str = (
f"Specifically, replace `store.{method_name}(...)` with `await store.a{method_name}(...)"
if method_name
else "For example, replace `store.get(...)` with `await store.aget(...)`"
)
raise asyncio.InvalidStateError(
f"Synchronous calls to {store.__class__.__name__} detected in the main event loop. "
"This can lead to deadlocks or performance issues. "
"Please use the asynchronous interface for main thread operations. "
f"{replacement_str} "
)
except RuntimeError:
pass
return func(store, *args, **kwargs)
return wrapper
class AsyncBatchedBaseStore(BaseStore):
    """Efficiently batch operations in a background task.

    Every async method wraps its arguments in an operation object, puts it on
    an internal queue together with a future, and awaits that future. A
    background task (running `_run`) drains the queue and resolves the
    futures. The synchronous methods schedule their async counterparts on the
    store's event loop with `asyncio.run_coroutine_threadsafe` and block on the
    result; each of them is guarded by `_check_loop`.
    """

    __slots__ = ("_loop", "_aqueue", "_task")

    def __init__(self) -> None:
        """Capture the running event loop and start the background task.

        Raises:
            RuntimeError: If no event loop is running in the current thread
                (raised by `asyncio.get_running_loop`).
        """
        super().__init__()
        self._loop = asyncio.get_running_loop()
        self._aqueue: asyncio.Queue[tuple[asyncio.Future, Op]] = asyncio.Queue()
        self._task: asyncio.Task | None = None
        self._ensure_task()

    def __del__(self) -> None:
        # `_task` is never assigned when `__init__` fails before reaching it
        # (for example when no event loop is running), so look it up
        # defensively instead of raising AttributeError during finalisation.
        task = getattr(self, "_task", None)
        try:
            if task:
                task.cancel()
        except RuntimeError:
            # Best-effort cleanup; presumably the loop is already closed.
            pass

    def _ensure_task(self) -> None:
        """Ensure the background processing loop is running."""
        if self._task is None or self._task.done():
            # The task only receives a weak reference to the store.
            self._task = self._loop.create_task(_run(self._aqueue, weakref.ref(self)))

    async def aget(
        self,
        namespace: tuple[str, ...],
        key: str,
        *,
        refresh_ttl: bool | None = None,
    ) -> Item | None:
        """Queue a `GetOp` for ``(namespace, key)`` and await its result."""
        self._ensure_task()
        fut = self._loop.create_future()
        self._aqueue.put_nowait(
            (
                fut,
                GetOp(
                    namespace,
                    key,
                    refresh_ttl=_ensure_refresh(self.ttl_config, refresh_ttl),
                ),
            )
        )
        return await fut

    async def asearch(
        self,
        namespace_prefix: tuple[str, ...],
        /,
        *,
        query: str | None = None,
        filter: dict[str, Any] | None = None,
        limit: int = 10,
        offset: int = 0,
        refresh_ttl: bool | None = None,
    ) -> list[SearchItem]:
        """Queue a `SearchOp` under ``namespace_prefix`` and await its result."""
        self._ensure_task()
        fut = self._loop.create_future()
        self._aqueue.put_nowait(
            (
                fut,
                SearchOp(
                    namespace_prefix,
                    filter,
                    limit,
                    offset,
                    query,
                    refresh_ttl=_ensure_refresh(self.ttl_config, refresh_ttl),
                ),
            )
        )
        return await fut

    async def aput(
        self,
        namespace: tuple[str, ...],
        key: str,
        value: dict[str, Any],
        index: Literal[False] | list[str] | None = None,
        *,
        ttl: float | None | NotProvided = NOT_PROVIDED,
    ) -> None:
        """Validate ``namespace``, queue a `PutOp`, and await its completion."""
        self._ensure_task()
        _validate_namespace(namespace)
        fut = self._loop.create_future()
        self._aqueue.put_nowait(
            (
                fut,
                PutOp(
                    namespace, key, value, index, ttl=_ensure_ttl(self.ttl_config, ttl)
                ),
            )
        )
        return await fut

    async def adelete(
        self,
        namespace: tuple[str, ...],
        key: str,
    ) -> None:
        """Queue a deletion, expressed as a `PutOp` whose value is ``None``."""
        self._ensure_task()
        fut = self._loop.create_future()
        self._aqueue.put_nowait((fut, PutOp(namespace, key, None)))
        return await fut

    async def alist_namespaces(
        self,
        *,
        prefix: NamespacePath | None = None,
        suffix: NamespacePath | None = None,
        max_depth: int | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[tuple[str, ...]]:
        """Queue a `ListNamespacesOp` built from the arguments and await it.

        A match condition is added for ``prefix`` and for ``suffix`` only when
        the respective argument is truthy.
        """
        self._ensure_task()
        fut = self._loop.create_future()
        match_conditions = []
        if prefix:
            match_conditions.append(MatchCondition(match_type="prefix", path=prefix))
        if suffix:
            match_conditions.append(MatchCondition(match_type="suffix", path=suffix))
        op = ListNamespacesOp(
            match_conditions=tuple(match_conditions),
            max_depth=max_depth,
            limit=limit,
            offset=offset,
        )
        self._aqueue.put_nowait((fut, op))
        return await fut

    @_check_loop
    def batch(self, ops: Iterable[Op]) -> list[Result]:
        """Run ``abatch(ops)`` on the store's loop and block for the result."""
        return asyncio.run_coroutine_threadsafe(self.abatch(ops), self._loop).result()

    @_check_loop
    def get(
        self,
        namespace: tuple[str, ...],
        key: str,
        *,
        refresh_ttl: bool | None = None,
    ) -> Item | None:
        """Blocking counterpart of `aget`."""
        return asyncio.run_coroutine_threadsafe(
            self.aget(namespace, key=key, refresh_ttl=refresh_ttl), self._loop
        ).result()

    @_check_loop
    def search(
        self,
        namespace_prefix: tuple[str, ...],
        /,
        *,
        query: str | None = None,
        filter: dict[str, Any] | None = None,
        limit: int = 10,
        offset: int = 0,
        refresh_ttl: bool | None = None,
    ) -> list[SearchItem]:
        """Blocking counterpart of `asearch`."""
        return asyncio.run_coroutine_threadsafe(
            self.asearch(
                namespace_prefix,
                query=query,
                filter=filter,
                limit=limit,
                offset=offset,
                refresh_ttl=refresh_ttl,
            ),
            self._loop,
        ).result()

    @_check_loop
    def put(
        self,
        namespace: tuple[str, ...],
        key: str,
        value: dict[str, Any],
        index: Literal[False] | list[str] | None = None,
        *,
        ttl: float | None | NotProvided = NOT_PROVIDED,
    ) -> None:
        """Blocking counterpart of `aput`.

        The namespace is validated in the calling thread before anything is
        scheduled on the store's loop.
        """
        _validate_namespace(namespace)
        asyncio.run_coroutine_threadsafe(
            self.aput(
                namespace,
                key=key,
                value=value,
                index=index,
                ttl=_ensure_ttl(self.ttl_config, ttl),
            ),
            self._loop,
        ).result()

    @_check_loop
    def delete(
        self,
        namespace: tuple[str, ...],
        key: str,
    ) -> None:
        """Blocking counterpart of `adelete`."""
        asyncio.run_coroutine_threadsafe(
            self.adelete(namespace, key=key), self._loop
        ).result()

    @_check_loop
    def list_namespaces(
        self,
        *,
        prefix: NamespacePath | None = None,
        suffix: NamespacePath | None = None,
        max_depth: int | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[tuple[str, ...]]:
        """Blocking counterpart of `alist_namespaces`."""
        return asyncio.run_coroutine_threadsafe(
            self.alist_namespaces(
                prefix=prefix,
                suffix=suffix,
                max_depth=max_depth,
                limit=limit,
                offset=offset,
            ),
            self._loop,
        ).result()
def _dedupe_ops(values: list[Op]) -> tuple[list[int] | None, list[Op]]:
"""Dedupe operations while preserving order for results.
Args:
values: List of operations to dedupe
Returns:
Tuple of (listen indices, deduped operations)
where listen indices map deduped operation results back to original positions
"""
if len(values) <= 1:
return None, list(values)
dedupped: list[Op] = []
listen: list[int] = []
puts: dict[tuple[tuple[str, ...], str], int] = {}
for op in values:
if isinstance(op, (GetOp, SearchOp, ListNamespacesOp)):
try:
listen.append(dedupped.index(op))
except ValueError:
listen.append(len(dedupped))
dedupped.append(op)
elif isinstance(op, PutOp):
putkey = (op.namespace, op.key)
if putkey in puts:
# Overwrite previous put
ix = puts[putkey]
dedupped[ix] = op
listen.append(ix)
else:
puts[putkey] = len(dedupped)
listen.append(len(dedupped))
dedupped.append(op)
else: # Any new ops will be treated regularly
listen.append(len(dedupped))
dedupped.append(op)
return listen, dedupped
async def _run(
aqueue: asyncio.Queue[tuple[asyncio.Future, Op]],
store: weakref.ReferenceType[BaseStore],
) -> None:
while item := await aqueue.get():
# check if store is still alive
if s := store():
try:
# accumulate operations scheduled in same tick
items = [item]
try:
while item := aqueue.get_nowait():
items.append(item)
except asyncio.QueueEmpty:
pass
# get the operations to run
futs = [item[0] for item in items]
values = [item[1] for item in items]
# action each operation
try:
listen, dedupped = _dedupe_ops(values)
results = await s.abatch(dedupped)
if listen is not None:
results = [results[ix] for ix in listen]
# set the results of each operation
for fut, result in zip(futs, results, strict=False):
# guard against future being done (e.g. cancelled)
if not fut.done():
fut.set_result(result)
except Exception as e:
for fut in futs:
# guard against future being done (e.g. cancelled)
if not fut.done():
fut.set_exception(e)
finally:
# remove strong ref to store
del s
else:
break

View File

@@ -0,0 +1,433 @@
"""Utilities for working with embedding functions and LangChain's Embeddings interface.
This module provides tools to wrap arbitrary embedding functions (both sync and async)
into LangChain's Embeddings interface. This enables using custom embedding functions
with LangChain-compatible tools while maintaining support for both synchronous and
asynchronous operations.
"""
from __future__ import annotations
import asyncio
import functools
import json
from collections.abc import Awaitable, Callable, Sequence
from typing import Any
from langchain_core.embeddings import Embeddings
EmbeddingsFunc = Callable[[Sequence[str]], list[list[float]]]
"""Type for synchronous embedding functions.
The function should take a sequence of strings and return a list of embeddings,
where each embedding is a list of floats. The dimensionality of the embeddings
should be consistent for all inputs.
"""
AEmbeddingsFunc = Callable[[Sequence[str]], Awaitable[list[list[float]]]]
"""Type for asynchronous embedding functions.
Similar to EmbeddingsFunc, but returns an awaitable that resolves to the embeddings.
"""
def ensure_embeddings(
    embed: Embeddings | EmbeddingsFunc | AEmbeddingsFunc | str | None,
) -> Embeddings:
    """Coerce ``embed`` into an object implementing LangChain's Embeddings interface.

    Args:
        embed: One of:

            - an existing `Embeddings` instance, which is returned unchanged;
            - a sync or async function mapping texts to embeddings, which is
              wrapped in `EmbeddingsLambda`;
            - a ``"provider:identifier"`` string, which is resolved through
              LangChain's `init_embeddings`.

    Returns:
        An `Embeddings` instance.

    Raises:
        ValueError: If ``embed`` is ``None``, or if it is a string and
            `init_embeddings` cannot be imported from LangChain.

    ??? example "Examples"
        Wrap a synchronous embedding function:
        ```python
        def my_embed_fn(texts):
            return [[0.1, 0.2] for _ in texts]

        embeddings = ensure_embeddings(my_embed_fn)
        result = embeddings.embed_query("hello")  # Returns [0.1, 0.2]
        ```
        Wrap an asynchronous embedding function:
        ```python
        async def my_async_fn(texts):
            return [[0.1, 0.2] for _ in texts]

        embeddings = ensure_embeddings(my_async_fn)
        result = await embeddings.aembed_query("hello")  # Returns [0.1, 0.2]
        ```
        Initialize embeddings using a provider string:
        ```python
        # Requires langchain>=0.3.9 and langgraph-checkpoint>=2.0.11
        embeddings = ensure_embeddings("openai:text-embedding-3-small")
        result = embeddings.embed_query("hello")
        ```
    """
    if embed is None:
        raise ValueError("embed must be provided")

    if not isinstance(embed, str):
        # Ready-made Embeddings pass through; plain callables get wrapped.
        return embed if isinstance(embed, Embeddings) else EmbeddingsLambda(embed)

    loader = _get_init_embeddings()
    if loader is not None:
        return loader(embed)

    # The loader is unavailable: report which langchain (if any) is installed.
    from importlib.metadata import PackageNotFoundError, version

    try:
        installed = version("langchain")
    except PackageNotFoundError:
        version_info = "langchain is not installed;"
    else:
        version_info = f"Found langchain version {installed}, but"
    raise ValueError(
        f"Could not load embeddings from string '{embed}'. {version_info} "
        "loading embeddings by provider:identifier string requires langchain>=0.3.9 "
        "as well as the provider-specific package. "
        "Install LangChain with: pip install 'langchain>=0.3.9' "
        "and the provider-specific package (e.g., 'langchain-openai>=0.3.0'). "
        "Alternatively, specify 'embed' as a compatible Embeddings object or python function."
    )
class EmbeddingsLambda(Embeddings):
    """Adapter exposing a plain embedding function through the Embeddings interface.

    Exactly one function is supplied at construction time:

    1. A synchronous function serves the sync methods; the async methods
       defer to the base class implementation.
    2. An asynchronous function serves the async methods only; the sync
       methods raise an error.

    Args:
        func: Function that converts a sequence of texts into embeddings
            (one list of floats per text). May be sync or async.

    ??? example "Examples"
        With a sync function:
        ```python
        def my_embed_fn(texts):
            # Return 2D embeddings for each text
            return [[0.1, 0.2] for _ in texts]

        embeddings = EmbeddingsLambda(my_embed_fn)
        result = embeddings.embed_query("hello")  # Returns [0.1, 0.2]
        await embeddings.aembed_query("hello")  # Also returns [0.1, 0.2]
        ```
        With an async function:
        ```python
        async def my_async_fn(texts):
            return [[0.1, 0.2] for _ in texts]

        embeddings = EmbeddingsLambda(my_async_fn)
        await embeddings.aembed_query("hello")  # Returns [0.1, 0.2]
        # Note: embed_query() would raise an error
        ```
    """

    def __init__(
        self,
        func: EmbeddingsFunc | AEmbeddingsFunc,
    ) -> None:
        if func is None:
            raise ValueError("func must be provided")
        # Only one of `func` / `afunc` is ever set; readers use getattr().
        if not _is_async_callable(func):
            self.func = func
        else:
            self.afunc = func

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of texts with the synchronous function.

        Args:
            texts: Texts to convert to embeddings.

        Returns:
            Whatever the wrapped function returns for ``texts``; expected to
            be one embedding per input text.

        Raises:
            ValueError: If the instance was created from an async function.
        """
        sync_fn = getattr(self, "func", None)
        if sync_fn is not None:
            return sync_fn(texts)
        raise ValueError(
            "EmbeddingsLambda was initialized with an async function but no sync function. "
            "Use aembed_documents for async operation or provide a sync function."
        )

    def embed_query(self, text: str) -> list[float]:
        """Embed a single text.

        Args:
            text: Text to convert to an embedding.

        Returns:
            The first result of `embed_documents` called with ``[text]``.
        """
        vectors = self.embed_documents([text])
        return vectors[0]

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Asynchronously embed a list of texts.

        Args:
            texts: Texts to convert to embeddings.

        Returns:
            The awaited result of the async function when one was supplied,
            otherwise the result of the base class's `aembed_documents`.
        """
        async_fn = getattr(self, "afunc", None)
        if async_fn is not None:
            return await async_fn(texts)
        return await super().aembed_documents(texts)

    async def aembed_query(self, text: str) -> list[float]:
        """Asynchronously embed a single text.

        Args:
            text: Text to convert to an embedding.

        Returns:
            The first embedding produced by the async function for ``[text]``
            when one was supplied, otherwise the result of the base class's
            `aembed_query`.
        """
        async_fn = getattr(self, "afunc", None)
        if async_fn is not None:
            vectors = await async_fn([text])
            return vectors[0]
        return await super().aembed_query(text)
def get_text_at_path(obj: Any, path: str | list[str]) -> list[str]:
    """Collect the text found at ``path`` inside ``obj``.

    Args:
        obj: The object to extract text from.
        path: Either a path string (tokenized with `tokenize_path`) or an
            already tokenized list of path components. An empty path or
            ``"$"`` selects the whole object, serialized as JSON.

    Returns:
        One string per matched value. Strings, numbers and booleans are
        rendered with ``str()``, lists and dicts as JSON with sorted keys,
        and ``None`` or any other type contributes nothing.

    !!! info "Path types handled"
        - Simple paths: "field1.field2"
        - Array indexing: "[0]", "[*]", "[-1]"
        - Wildcards: "*"
        - Multi-field selection: "{field1,field2}"
        - Nested paths in multi-field: "{field1,nested.field2}"
    """
    if not path or path == "$":
        return [json.dumps(obj, sort_keys=True, ensure_ascii=False)]

    tokens = tokenize_path(path) if isinstance(path, str) else path

    def _leaf_texts(value: Any) -> list[str]:
        # Render a terminal value; None and unsupported types yield nothing.
        if isinstance(value, (str, int, float, bool)):
            return [str(value)]
        if isinstance(value, (list, dict)):
            return [json.dumps(value, sort_keys=True, ensure_ascii=False)]
        return []

    def _walk(node: Any, pos: int) -> list[str]:
        if pos >= len(tokens):
            return _leaf_texts(node)

        token = tokens[pos]
        following = pos + 1

        # Array selector: "[*]", "[3]", "[-1]"
        if token.startswith("[") and token.endswith("]"):
            if not isinstance(node, list):
                return []
            selector = token[1:-1]
            if selector == "*":
                return [text for child in node for text in _walk(child, following)]
            try:
                idx = int(selector)
                if idx < 0:
                    idx += len(node)
                if 0 <= idx < len(node):
                    return list(_walk(node[idx], following))
            except (ValueError, IndexError):
                return []
            return []

        # Multi-field selector: "{a,b.c}". Each field is resolved through
        # nested dicts only, and the match is treated as terminal.
        if token.startswith("{") and token.endswith("}"):
            if not isinstance(node, dict):
                return []
            collected: list[str] = []
            for field in token[1:-1].split(","):
                parts = tokenize_path(field.strip())
                if not parts:
                    continue
                target: Any = node
                for part in parts:
                    if not (isinstance(target, dict) and part in target):
                        target = None
                        break
                    target = target[part]
                collected.extend(_leaf_texts(target))
            return collected

        # Wildcard: every dict value or every list element.
        if token == "*":
            if isinstance(node, dict):
                children = node.values()
            elif isinstance(node, list):
                children = node
            else:
                return []
            return [text for child in children for text in _walk(child, following)]

        # Plain field name.
        if isinstance(node, dict) and token in node:
            return list(_walk(node[token], following))
        return []

    return _walk(obj, 0)
# Path tokenization (used by get_text_at_path), followed by private utility functions
def tokenize_path(path: str) -> list[str]:
    """Split a path expression into its components.

    Dots separate plain field names. A bracketed or braced group is kept as a
    single token, delimiters included, with nesting of the same delimiter
    respected. An unterminated group runs to the end of the string.

    !!! info "Types handled"
        - Simple paths: "field1.field2"
        - Array indexing: "[0]", "[*]", "[-1]"
        - Wildcards: "*"
        - Multi-field selection: "{field1,field2}"

    Args:
        path: The path expression to split.

    Returns:
        The list of tokens; empty when ``path`` is empty.
    """
    if not path:
        return []

    closers = {"[": "]", "{": "}"}
    tokens: list[str] = []
    buffer: list[str] = []
    pos = 0
    end = len(path)

    while pos < end:
        char = path[pos]
        closer = closers.get(char)

        if closer is not None:
            # A group starts: flush the field name collected so far.
            if buffer:
                tokens.append("".join(buffer))
                buffer = []
            start = pos
            depth = 1
            pos += 1
            # Consume up to and including the matching closing delimiter.
            while pos < end and depth > 0:
                if path[pos] == char:
                    depth += 1
                elif path[pos] == closer:
                    depth -= 1
                pos += 1
            tokens.append(path[start:pos])
            continue

        if char == ".":
            # Separator: emit the pending field name, if any.
            if buffer:
                tokens.append("".join(buffer))
                buffer = []
        else:
            buffer.append(char)
        pos += 1

    if buffer:
        tokens.append("".join(buffer))
    return tokens
def _is_async_callable(
func: Any,
) -> bool:
"""Check if a function is async.
This includes both async def functions and classes with async __call__ methods.
Args:
func: Function or callable object to check.
Returns:
True if the function is async, False otherwise.
"""
return (
asyncio.iscoroutinefunction(func)
or hasattr(func, "__call__") # noqa: B004
and asyncio.iscoroutinefunction(func.__call__)
)
@functools.lru_cache
def _get_init_embeddings() -> Callable[[str], Embeddings] | None:
try:
from langchain.embeddings import init_embeddings # type: ignore
return init_embeddings
except ImportError:
return None
# Names exported by a star-import of this module.
__all__ = [
"ensure_embeddings",
"EmbeddingsFunc",
"AEmbeddingsFunc",
]

View File

@@ -0,0 +1,592 @@
"""In-memory dictionary-backed store with optional vector search.
!!! example "Examples"
Basic key-value storage:
```python
from langgraph.store.memory import InMemoryStore
store = InMemoryStore()
store.put(("users", "123"), "prefs", {"theme": "dark"})
item = store.get(("users", "123"), "prefs")
```
Vector search using LangChain embeddings:
```python
from langchain.embeddings import init_embeddings
from langgraph.store.memory import InMemoryStore
store = InMemoryStore(
index={
"dims": 1536,
"embed": init_embeddings("openai:text-embedding-3-small")
}
)
# Store documents
store.put(("docs",), "doc1", {"text": "Python tutorial"})
store.put(("docs",), "doc2", {"text": "TypeScript guide"})
# Search by similarity
results = store.search(("docs",), query="python programming")
```
Vector search using OpenAI SDK directly:
```python
from openai import OpenAI
from langgraph.store.memory import InMemoryStore
client = OpenAI()
def embed_texts(texts: list[str]) -> list[list[float]]:
response = client.embeddings.create(
model="text-embedding-3-small",
input=texts
)
return [e.embedding for e in response.data]
store = InMemoryStore(
index={
"dims": 1536,
"embed": embed_texts
}
)
# Store documents
store.put(("docs",), "doc1", {"text": "Python tutorial"})
store.put(("docs",), "doc2", {"text": "TypeScript guide"})
# Search by similarity
results = store.search(("docs",), query="python programming")
```
Async vector search using OpenAI SDK:
```python
from openai import AsyncOpenAI
from langgraph.store.memory import InMemoryStore
client = AsyncOpenAI()
async def aembed_texts(texts: list[str]) -> list[list[float]]:
response = await client.embeddings.create(
model="text-embedding-3-small",
input=texts
)
return [e.embedding for e in response.data]
store = InMemoryStore(
index={
"dims": 1536,
"embed": aembed_texts
}
)
# Store documents
await store.aput(("docs",), "doc1", {"text": "Python tutorial"})
await store.aput(("docs",), "doc2", {"text": "TypeScript guide"})
# Search by similarity
results = await store.asearch(("docs",), query="python programming")
```
Warning:
This store keeps all data in memory. Data is lost when the process exits.
For persistence, use a database-backed store like PostgresStore.
Tip:
For vector search, install numpy for better performance:
```bash
pip install numpy
```
"""
from __future__ import annotations
import asyncio
import concurrent.futures as cf
import functools
import logging
from collections import defaultdict
from collections.abc import Iterable
from datetime import datetime, timezone
from importlib import util
from typing import Any
from langchain_core.embeddings import Embeddings
from langgraph.store.base import (
BaseStore,
GetOp,
IndexConfig,
Item,
ListNamespacesOp,
MatchCondition,
Op,
PutOp,
Result,
SearchItem,
SearchOp,
ensure_embeddings,
get_text_at_path,
tokenize_path,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class InMemoryStore(BaseStore):
    """In-memory dictionary-backed store with optional vector search.

    !!! example "Examples"
        Basic key-value storage:

            store = InMemoryStore()
            store.put(("users", "123"), "prefs", {"theme": "dark"})
            item = store.get(("users", "123"), "prefs")

        Vector search with embeddings:

            from langchain.embeddings import init_embeddings
            store = InMemoryStore(index={
                "dims": 1536,
                "embed": init_embeddings("openai:text-embedding-3-small"),
                "fields": ["text"],
            })

            # Store documents
            store.put(("docs",), "doc1", {"text": "Python tutorial"})
            store.put(("docs",), "doc2", {"text": "TypeScript guide"})

            # Search by similarity
            results = store.search(("docs",), query="python programming")

    Note:
        Semantic search is disabled by default. You can enable it by providing an `index` configuration
        when creating the store. Without this configuration, all `index` arguments passed to
        `put` or `aput` will have no effect.

    Warning:
        This store keeps all data in memory. Data is lost when the process exits.
        For persistence, use a database-backed store like PostgresStore.

    Tip:
        For vector search, install numpy for better performance:
        ```bash
        pip install numpy
        ```
    """

    __slots__ = (
        "_data",
        "_vectors",
        "index_config",
        "embeddings",
    )

    def __init__(self, *, index: IndexConfig | None = None) -> None:
        """Create an empty store.

        Args:
            index: Optional index configuration. When truthy, a copy is kept,
                its ``"embed"`` entry is turned into an `Embeddings` object,
                and its ``"fields"`` (default ``["$"]``) are pre-tokenized.
        """
        # Both _data and _vectors are wrapped in the In-memory API
        # Do not change their names
        self._data: dict[tuple[str, ...], dict[str, Item]] = defaultdict(dict)
        # [ns][key][path]
        self._vectors: dict[tuple[str, ...], dict[str, dict[str, list[float]]]] = (
            defaultdict(lambda: defaultdict(dict))
        )
        self.index_config = index
        if self.index_config:
            self.index_config = self.index_config.copy()
            self.embeddings: Embeddings | None = ensure_embeddings(
                self.index_config.get("embed"),
            )
            # "$" (whole document) is kept as-is; other paths are tokenized once.
            self.index_config["__tokenized_fields"] = [
                (p, tokenize_path(p)) if p != "$" else (p, p)
                for p in (self.index_config.get("fields") or ["$"])
            ]
        else:
            self.index_config = None
            self.embeddings = None

    def batch(self, ops: Iterable[Op]) -> list[Result]:
        """Execute ``ops`` synchronously and return one result per operation.

        Reads and searches are answered against the state before any put in
        the same batch is applied; puts are applied last.
        """
        # The batch/abatch methods are treated as internal.
        # Users should access via put/search/get/list_namespaces/etc.
        results, put_ops, search_ops = self._prepare_ops(ops)
        if search_ops:
            queryinmem_store = self._embed_search_queries(search_ops)
            self._batch_search(search_ops, queryinmem_store, results)
        to_embed = self._extract_texts(put_ops)
        if to_embed and self.index_config and self.embeddings:
            embeddings = self.embeddings.embed_documents(list(to_embed))
            self._insertinmem_store(to_embed, embeddings)
        self._apply_put_ops(put_ops)
        return results

    async def abatch(self, ops: Iterable[Op]) -> list[Result]:
        """Async variant of `batch`; embeddings are computed with the async API."""
        # The batch/abatch methods are treated as internal.
        # Users should access via put/search/get/list_namespaces/etc.
        results, put_ops, search_ops = self._prepare_ops(ops)
        if search_ops:
            queryinmem_store = await self._aembed_search_queries(search_ops)
            self._batch_search(search_ops, queryinmem_store, results)
        to_embed = self._extract_texts(put_ops)
        if to_embed and self.index_config and self.embeddings:
            embeddings = await self.embeddings.aembed_documents(list(to_embed))
            self._insertinmem_store(to_embed, embeddings)
        self._apply_put_ops(put_ops)
        return results

    # Helpers

    def _filter_items(self, op: SearchOp) -> list[tuple[Item, list[list[float]]]]:
        """Filter items by namespace and filter function, return items with their embeddings."""
        namespace_prefix = op.namespace_prefix

        def filter_func(item: Item) -> bool:
            if not op.filter:
                return True
            return all(
                _compare_values(item.value.get(key), filter_value)
                for key, filter_value in op.filter.items()
            )

        filtered = []
        for namespace in self._data:
            if not (
                namespace[: len(namespace_prefix)] == namespace_prefix
                if len(namespace) >= len(namespace_prefix)
                else False
            ):
                continue
            for key, item in self._data[namespace].items():
                if filter_func(item):
                    # Vectors are only attached when the search has a query.
                    if op.query and (embeddings := self._vectors[namespace].get(key)):
                        filtered.append((item, list(embeddings.values())))
                    else:
                        filtered.append((item, []))
        return filtered

    def _embed_search_queries(
        self,
        search_ops: dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]],
    ) -> dict[str, list[float]]:
        """Embed each distinct search query, using a thread pool."""
        queryinmem_store = {}
        if self.index_config and self.embeddings and search_ops:
            queries = {op.query for (op, _) in search_ops.values() if op.query}
            if queries:
                with cf.ThreadPoolExecutor() as executor:
                    futures = {
                        q: executor.submit(self.embeddings.embed_query, q)
                        for q in list(queries)
                    }
                    for query, future in futures.items():
                        queryinmem_store[query] = future.result()
        return queryinmem_store

    async def _aembed_search_queries(
        self,
        search_ops: dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]],
    ) -> dict[str, list[float]]:
        """Embed each distinct search query concurrently with `asyncio.gather`."""
        queryinmem_store = {}
        if self.index_config and self.embeddings and search_ops:
            queries = {op.query for (op, _) in search_ops.values() if op.query}
            if queries:
                coros = [self.embeddings.aembed_query(q) for q in list(queries)]
                results = await asyncio.gather(*coros)
                queryinmem_store = dict(zip(queries, results, strict=False))
        return queryinmem_store

    def _batch_search(
        self,
        ops: dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]],
        queryinmem_store: dict[str, list[float]],
        results: list[Result],
    ) -> None:
        """Perform batch similarity search for multiple queries.

        Writes the outcome of each search into ``results`` at the index the
        search had in the original batch.
        """
        for i, (op, candidates) in ops.items():
            if not candidates:
                results[i] = []
                continue
            if op.query and queryinmem_store:
                query_embedding = queryinmem_store[op.query]
                flat_items, flat_vectors = [], []
                scoreless = []
                for item, vectors in candidates:
                    for vector in vectors:
                        flat_items.append(item)
                        flat_vectors.append(vector)
                    if not vectors:
                        scoreless.append(item)

                scores = _cosine_similarity(query_embedding, flat_vectors)
                sorted_results = sorted(
                    zip(scores, flat_items, strict=False),
                    key=lambda x: x[0],
                    reverse=True,
                )
                # max pooling: keep only the best-scoring vector of each item
                seen: set[tuple[tuple[str, ...], str]] = set()
                kept: list[tuple[float | None, Item]] = []
                for score, item in sorted_results:
                    key = (item.namespace, item.key)
                    if key in seen:
                        continue
                    ix = len(seen)
                    seen.add(key)
                    if ix >= op.offset + op.limit:
                        break
                    if ix < op.offset:
                        continue

                    kept.append((score, item))
                if scoreless and len(kept) < op.limit:
                    # Corner case: if we request more items than what we have embedded,
                    # fill the rest with non-scored items
                    kept.extend(
                        (None, item) for item in scoreless[: op.limit - len(kept)]
                    )

                results[i] = [
                    SearchItem(
                        namespace=item.namespace,
                        key=item.key,
                        value=item.value,
                        created_at=item.created_at,
                        updated_at=item.updated_at,
                        score=float(score) if score is not None else None,
                    )
                    for score, item in kept
                ]
            else:
                results[i] = [
                    SearchItem(
                        namespace=item.namespace,
                        key=item.key,
                        value=item.value,
                        created_at=item.created_at,
                        updated_at=item.updated_at,
                    )
                    for (item, _) in candidates[op.offset : op.offset + op.limit]
                ]

    def _prepare_ops(
        self, ops: Iterable[Op]
    ) -> tuple[
        list[Result],
        dict[tuple[tuple[str, ...], str], PutOp],
        dict[int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]],
    ]:
        """Answer reads immediately and collect puts and searches for later.

        Returns:
            ``(results, put_ops, search_ops)``. ``results`` has one slot per
            operation (``None`` for puts and for searches not yet executed),
            ``put_ops`` maps ``(namespace, key)`` to the last put for that
            key, and ``search_ops`` maps the batch index of each search to the
            search and its pre-filtered candidates.

        Raises:
            ValueError: If an operation has an unknown type.
        """
        results: list[Result] = []
        put_ops: dict[tuple[tuple[str, ...], str], PutOp] = {}
        search_ops: dict[
            int, tuple[SearchOp, list[tuple[Item, list[list[float]]]]]
        ] = {}
        for i, op in enumerate(ops):
            if isinstance(op, GetOp):
                # Use .get() on the outer mapping: indexing the defaultdict
                # would create an empty namespace that list_namespaces would
                # then report even though nothing was ever stored in it.
                item = self._data.get(op.namespace, {}).get(op.key)
                results.append(item)
            elif isinstance(op, SearchOp):
                search_ops[i] = (op, self._filter_items(op))
                results.append(None)
            elif isinstance(op, ListNamespacesOp):
                results.append(self._handle_list_namespaces(op))
            elif isinstance(op, PutOp):
                put_ops[(op.namespace, op.key)] = op
                results.append(None)
            else:
                raise ValueError(f"Unknown operation type: {type(op)}")

        return results, put_ops, search_ops

    def _apply_put_ops(self, put_ops: dict[tuple[tuple[str, ...], str], PutOp]) -> None:
        """Apply the collected puts; a put whose value is ``None`` is a delete."""
        for (namespace, key), op in put_ops.items():
            if op.value is None:
                # Membership checks avoid creating an empty namespace entry
                # when deleting something that was never stored.
                if namespace in self._data:
                    self._data[namespace].pop(key, None)
                if namespace in self._vectors:
                    self._vectors[namespace].pop(key, None)
            else:
                # One timestamp for both fields; an overwrite keeps the
                # original creation time and only advances updated_at.
                now = datetime.now(timezone.utc)
                existing = self._data[namespace].get(key)
                self._data[namespace][key] = Item(
                    value=op.value,
                    key=key,
                    namespace=namespace,
                    created_at=existing.created_at if existing is not None else now,
                    updated_at=now,
                )

    def _extract_texts(
        self, put_ops: dict[tuple[tuple[str, ...], str], PutOp]
    ) -> dict[str, list[tuple[tuple[str, ...], str, str]]]:
        """Collect the texts to embed for the given puts.

        Returns:
            A mapping from each distinct text to the ``(namespace, key, path)``
            locations it was found at. Empty when indexing is not configured.
        """
        if put_ops and self.index_config and self.embeddings:
            to_embed = defaultdict(list)

            for op in put_ops.values():
                if op.value is not None and op.index is not False:
                    if op.index is None:
                        paths = self.index_config["__tokenized_fields"]
                    else:
                        paths = [(ix, tokenize_path(ix)) for ix in op.index]
                    for path, field in paths:
                        texts = get_text_at_path(op.value, field)
                        if texts:
                            if len(texts) > 1:
                                # Several matches: suffix the path with the position.
                                for i, text in enumerate(texts):
                                    to_embed[text].append(
                                        (op.namespace, op.key, f"{path}.{i}")
                                    )
                            else:
                                to_embed[texts[0]].append((op.namespace, op.key, path))

            return to_embed

        return {}

    def _insertinmem_store(
        self,
        to_embed: dict[str, list[tuple[tuple[str, ...], str, str]]],
        embeddings: list[list[float]],
    ) -> None:
        """Store the computed embeddings at every location of their text.

        Args:
            to_embed: Mapping from text to the locations it was found at, in
                the order the texts were passed to the embedding function.
            embeddings: One embedding per key of ``to_embed``, in key order.

        Raises:
            ValueError: If the number of embeddings differs from the number
                of distinct texts.
        """
        # Embeddings are computed once per distinct text (the dict keys), so
        # the count must be compared against the keys, not against the
        # flattened locations: the same text may occur at several locations.
        if len(to_embed) != len(embeddings):
            raise ValueError(
                f"Number of embeddings ({len(embeddings)}) does not"
                f" match number of texts ({len(to_embed)})"
            )
        for embedding, locations in zip(embeddings, to_embed.values(), strict=False):
            for ns, key, path in locations:
                self._vectors[ns][key][path] = embedding

    def _handle_list_namespaces(self, op: ListNamespacesOp) -> list[tuple[str, ...]]:
        """Return the sorted, filtered, truncated and paginated namespaces."""
        all_namespaces = list(
            self._data.keys()
        )  # Avoid collection size changing while iterating
        namespaces = all_namespaces
        if op.match_conditions:
            namespaces = [
                ns
                for ns in namespaces
                if all(_does_match(condition, ns) for condition in op.match_conditions)
            ]

        if op.max_depth is not None:
            # Truncating can produce duplicates, hence the set.
            namespaces = sorted({ns[: op.max_depth] for ns in namespaces})
        else:
            namespaces = sorted(namespaces)
        return namespaces[op.offset : op.offset + op.limit]
@functools.lru_cache(maxsize=1)
def _check_numpy() -> bool:
if bool(util.find_spec("numpy")):
return True
logger.warning(
"NumPy not found in the current Python environment. "
"The InMemoryStore will use a pure Python implementation for vector operations, "
"which may significantly impact performance, especially for large datasets or frequent searches. "
"For optimal speed and efficiency, consider installing NumPy: "
"pip install numpy"
)
return False
def _cosine_similarity(X: list[float], Y: list[list[float]]) -> list[float]:
"""
Compute cosine similarity between a vector X and a matrix Y.
Lazy import numpy for efficiency.
"""
if not Y:
return []
if _check_numpy():
import numpy as np
X_arr = np.array(X) if not isinstance(X, np.ndarray) else X
Y_arr = np.array(Y) if not isinstance(Y, np.ndarray) else Y
X_norm = np.linalg.norm(X_arr)
Y_norm = np.linalg.norm(Y_arr, axis=1)
# Avoid division by zero
mask = Y_norm != 0
similarities = np.zeros_like(Y_norm)
similarities[mask] = np.dot(Y_arr[mask], X_arr) / (Y_norm[mask] * X_norm)
return similarities.tolist()
similarities = []
for y in Y:
dot_product = sum(a * b for a, b in zip(X, y, strict=False))
norm1 = sum(a * a for a in X) ** 0.5
norm2 = sum(a * a for a in y) ** 0.5
similarity = dot_product / (norm1 * norm2) if norm1 > 0 and norm2 > 0 else 0.0
similarities.append(similarity)
return similarities
def _does_match(match_condition: MatchCondition, key: tuple[str, ...]) -> bool:
"""Whether a namespace key matches a match condition."""
match_type = match_condition.match_type
path = match_condition.path
if len(key) < len(path):
return False
if match_type == "prefix":
for k_elem, p_elem in zip(key, path, strict=False):
if p_elem == "*":
continue # Wildcard matches any element
if k_elem != p_elem:
return False
return True
elif match_type == "suffix":
for k_elem, p_elem in zip(reversed(key), reversed(path), strict=False):
if p_elem == "*":
continue # Wildcard matches any element
if k_elem != p_elem:
return False
return True
else:
raise ValueError(f"Unsupported match type: {match_type}")
def _compare_values(item_value: Any, filter_value: Any) -> bool:
"""Compare values in a JSONB-like way, handling nested objects."""
if isinstance(filter_value, dict):
if any(k.startswith("$") for k in filter_value):
return all(
_apply_operator(item_value, op_key, op_value)
for op_key, op_value in filter_value.items()
)
if not isinstance(item_value, dict):
return False
return all(
_compare_values(item_value.get(k), v) for k, v in filter_value.items()
)
elif isinstance(filter_value, (list, tuple)):
return (
isinstance(item_value, (list, tuple))
and len(item_value) == len(filter_value)
and all(
_compare_values(iv, fv)
for iv, fv in zip(item_value, filter_value, strict=False)
)
)
else:
return item_value == filter_value
def _apply_operator(value: Any, operator: str, op_value: Any) -> bool:
"""Apply a comparison operator, matching PostgreSQL's JSONB behavior."""
if operator == "$eq":
return value == op_value
elif operator == "$gt":
return float(value) > float(op_value)
elif operator == "$gte":
return float(value) >= float(op_value)
elif operator == "$lt":
return float(value) < float(op_value)
elif operator == "$lte":
return float(value) <= float(op_value)
elif operator == "$ne":
return value != op_value
else:
raise ValueError(f"Unsupported operator: {operator}")