initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,69 @@
"""**Storage** is an implementation of key-value store.
Storage module provides implementations of various key-value stores that conform
to a simple key-value interface.
The primary goal of these storages is to support caching.
**Class hierarchy:**
.. code-block::
BaseStore --> <name>Store # Examples: MongoDBStore, RedisStore
"""
import importlib
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from langchain_community.storage.astradb import (
AstraDBByteStore,
AstraDBStore,
)
from langchain_community.storage.cassandra import (
CassandraByteStore,
)
from langchain_community.storage.mongodb import MongoDBByteStore, MongoDBStore
from langchain_community.storage.redis import (
RedisStore,
)
from langchain_community.storage.sql import (
SQLStore,
)
from langchain_community.storage.upstash_redis import (
UpstashRedisByteStore,
UpstashRedisStore,
)
__all__ = [
"AstraDBByteStore",
"AstraDBStore",
"CassandraByteStore",
"MongoDBStore",
"MongoDBByteStore",
"RedisStore",
"SQLStore",
"UpstashRedisByteStore",
"UpstashRedisStore",
]
_module_lookup = {
"AstraDBByteStore": "langchain_community.storage.astradb",
"AstraDBStore": "langchain_community.storage.astradb",
"CassandraByteStore": "langchain_community.storage.cassandra",
"MongoDBStore": "langchain_community.storage.mongodb",
"MongoDBByteStore": "langchain_community.storage.mongodb",
"RedisStore": "langchain_community.storage.redis",
"SQLStore": "langchain_community.storage.sql",
"UpstashRedisByteStore": "langchain_community.storage.upstash_redis",
"UpstashRedisStore": "langchain_community.storage.upstash_redis",
}
def __getattr__(name: str) -> Any:
if name in _module_lookup:
module = importlib.import_module(_module_lookup[name])
return getattr(module, name)
raise AttributeError(f"module {__name__} has no attribute {name}")

View File

@@ -0,0 +1,238 @@
from __future__ import annotations
import base64
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Generic,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)
from langchain_core._api.deprecation import deprecated
from langchain_core.stores import BaseStore, ByteStore
from langchain_community.utilities.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)
if TYPE_CHECKING:
from astrapy.db import AstraDB, AsyncAstraDB
V = TypeVar("V")
class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
"""Base class for the DataStax AstraDB data store."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.astra_env = _AstraDBCollectionEnvironment(*args, **kwargs)
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection
@abstractmethod
def decode_value(self, value: Any) -> Optional[V]:
"""Decodes value from Astra DB"""
@abstractmethod
def encode_value(self, value: Optional[V]) -> Any:
"""Encodes value for Astra DB"""
def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
self.astra_env.ensure_db_setup()
docs_dict = {}
for doc in self.collection.paginated_find(filter={"_id": {"$in": list(keys)}}):
docs_dict[doc["_id"]] = doc.get("value")
return [self.decode_value(docs_dict.get(key)) for key in keys]
async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
await self.astra_env.aensure_db_setup()
docs_dict = {}
async for doc in self.async_collection.paginated_find(
filter={"_id": {"$in": list(keys)}}
):
docs_dict[doc["_id"]] = doc.get("value")
return [self.decode_value(docs_dict.get(key)) for key in keys]
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
self.astra_env.ensure_db_setup()
for k, v in key_value_pairs:
self.collection.upsert({"_id": k, "value": self.encode_value(v)})
async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
await self.astra_env.aensure_db_setup()
for k, v in key_value_pairs:
await self.async_collection.upsert(
{
"_id": k,
"value": self.encode_value(v),
}
)
def mdelete(self, keys: Sequence[str]) -> None:
self.astra_env.ensure_db_setup()
self.collection.delete_many(filter={"_id": {"$in": list(keys)}})
async def amdelete(self, keys: Sequence[str]) -> None:
await self.astra_env.aensure_db_setup()
await self.async_collection.delete_many(filter={"_id": {"$in": list(keys)}})
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
self.astra_env.ensure_db_setup()
docs = self.collection.paginated_find()
for doc in docs:
key = doc["_id"]
if not prefix or key.startswith(prefix):
yield key
async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
await self.astra_env.aensure_db_setup()
async for doc in self.async_collection.paginated_find():
key = doc["_id"]
if not prefix or key.startswith(prefix):
yield key
@deprecated(
since="0.0.22",
removal="1.0",
alternative_import="langchain_astradb.AstraDBStore",
)
class AstraDBStore(AstraDBBaseStore[Any]):
def __init__(
self,
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None,
*,
async_astra_db_client: Optional[AsyncAstraDB] = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
) -> None:
"""BaseStore implementation using DataStax AstraDB as the underlying store.
The value type can be any type serializable by json.dumps.
Can be used to store embeddings with the CacheBackedEmbeddings.
Documents in the AstraDB collection will have the format
.. code-block:: json
{
"_id": "<key>",
"value": <value>
}
Args:
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
OFF).
pre_delete_collection: whether to delete the collection
before creating it. If False and the collection already exists,
the collection will be used as is.
"""
# Constructor doc is not inherited so we have to override it.
super().__init__(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
def decode_value(self, value: Any) -> Any:
return value
def encode_value(self, value: Any) -> Any:
return value
@deprecated(
since="0.0.22",
removal="1.0",
alternative_import="langchain_astradb.AstraDBByteStore",
)
class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
def __init__(
self,
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None,
*,
async_astra_db_client: Optional[AsyncAstraDB] = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
) -> None:
"""ByteStore implementation using DataStax AstraDB as the underlying store.
The bytes values are converted to base64 encoded strings
Documents in the AstraDB collection will have the format
.. code-block:: json
{
"_id": "<key>",
"value": "<byte64 string value>"
}
Args:
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
OFF).
pre_delete_collection: whether to delete the collection
before creating it. If False and the collection already exists,
the collection will be used as is.
"""
# Constructor doc is not inherited so we have to override it.
super().__init__(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
def decode_value(self, value: Any) -> Optional[bytes]:
if value is None:
return None
return base64.b64decode(value)
def encode_value(self, value: Optional[bytes]) -> Any:
if value is None:
return None
return base64.b64encode(value).decode("ascii")

View File

@@ -0,0 +1,220 @@
from __future__ import annotations
import asyncio
from asyncio import InvalidStateError, Task
from typing import (
TYPE_CHECKING,
AsyncIterator,
Iterator,
List,
Optional,
Sequence,
Tuple,
)
from langchain_core.stores import ByteStore
from langchain_community.utilities.cassandra import SetupMode, aexecute_cql
if TYPE_CHECKING:
from cassandra.cluster import Session
from cassandra.query import PreparedStatement
CREATE_TABLE_CQL_TEMPLATE = """
CREATE TABLE IF NOT EXISTS {keyspace}.{table}
(row_id TEXT, body_blob BLOB, PRIMARY KEY (row_id));
"""
SELECT_TABLE_CQL_TEMPLATE = (
"""SELECT row_id, body_blob FROM {keyspace}.{table} WHERE row_id IN ?;"""
)
SELECT_ALL_TABLE_CQL_TEMPLATE = """SELECT row_id, body_blob FROM {keyspace}.{table};"""
INSERT_TABLE_CQL_TEMPLATE = (
"""INSERT INTO {keyspace}.{table} (row_id, body_blob) VALUES (?, ?);"""
)
DELETE_TABLE_CQL_TEMPLATE = """DELETE FROM {keyspace}.{table} WHERE row_id IN ?;"""
class CassandraByteStore(ByteStore):
"""A ByteStore implementation using Cassandra as the backend.
Parameters:
table: The name of the table to use.
session: A Cassandra session object. If not provided, it will be resolved
from the cassio config.
keyspace: The keyspace to use. If not provided, it will be resolved
from the cassio config.
setup_mode: The setup mode to use. Default is SYNC (SetupMode.SYNC).
"""
def __init__(
self,
table: str,
*,
session: Optional[Session] = None,
keyspace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
) -> None:
if not session or not keyspace:
try:
from cassio.config import check_resolve_keyspace, check_resolve_session
self.keyspace = keyspace or check_resolve_keyspace(keyspace)
self.session = session or check_resolve_session()
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import a recent cassio package."
"Please install it with `pip install --upgrade cassio`."
)
else:
self.keyspace = keyspace
self.session = session
self.table = table
self.select_statement = None
self.insert_statement = None
self.delete_statement = None
create_cql = CREATE_TABLE_CQL_TEMPLATE.format(
keyspace=self.keyspace,
table=self.table,
)
self.db_setup_task: Optional[Task[None]] = None
if setup_mode == SetupMode.ASYNC:
self.db_setup_task = asyncio.create_task(
aexecute_cql(self.session, create_cql)
)
else:
self.session.execute(create_cql)
def ensure_db_setup(self) -> None:
"""Ensure that the DB setup is finished. If not, raise a ValueError."""
if self.db_setup_task:
try:
self.db_setup_task.result()
except InvalidStateError:
raise ValueError(
"Asynchronous setup of the DB not finished. "
"NB: AstraDB components sync methods shouldn't be called from the "
"event loop. Consider using their async equivalents."
)
async def aensure_db_setup(self) -> None:
"""Ensure that the DB setup is finished. If not, wait for it."""
if self.db_setup_task:
await self.db_setup_task
def get_select_statement(self) -> PreparedStatement:
"""Get the prepared select statement for the table.
If not available, prepare it.
Returns:
PreparedStatement: The prepared statement.
"""
if not self.select_statement:
self.select_statement = self.session.prepare(
SELECT_TABLE_CQL_TEMPLATE.format(
keyspace=self.keyspace, table=self.table
)
)
return self.select_statement
def get_insert_statement(self) -> PreparedStatement:
"""Get the prepared insert statement for the table.
If not available, prepare it.
Returns:
PreparedStatement: The prepared statement.
"""
if not self.insert_statement:
self.insert_statement = self.session.prepare(
INSERT_TABLE_CQL_TEMPLATE.format(
keyspace=self.keyspace, table=self.table
)
)
return self.insert_statement
def get_delete_statement(self) -> PreparedStatement:
"""Get the prepared delete statement for the table.
If not available, prepare it.
Returns:
PreparedStatement: The prepared statement.
"""
if not self.delete_statement:
self.delete_statement = self.session.prepare(
DELETE_TABLE_CQL_TEMPLATE.format(
keyspace=self.keyspace, table=self.table
)
)
return self.delete_statement
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
from cassandra.query import ValueSequence
self.ensure_db_setup()
docs_dict = {}
for row in self.session.execute(
self.get_select_statement(), [ValueSequence(keys)]
):
docs_dict[row.row_id] = row.body_blob
return [docs_dict.get(key) for key in keys]
async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
from cassandra.query import ValueSequence
await self.aensure_db_setup()
docs_dict = {}
for row in await aexecute_cql(
self.session, self.get_select_statement(), parameters=[ValueSequence(keys)]
):
docs_dict[row.row_id] = row.body_blob
return [docs_dict.get(key) for key in keys]
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
self.ensure_db_setup()
insert_statement = self.get_insert_statement()
for k, v in key_value_pairs:
self.session.execute(insert_statement, (k, v))
async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
await self.aensure_db_setup()
insert_statement = self.get_insert_statement()
for k, v in key_value_pairs:
await aexecute_cql(self.session, insert_statement, parameters=(k, v))
def mdelete(self, keys: Sequence[str]) -> None:
from cassandra.query import ValueSequence
self.ensure_db_setup()
self.session.execute(self.get_delete_statement(), [ValueSequence(keys)])
async def amdelete(self, keys: Sequence[str]) -> None:
from cassandra.query import ValueSequence
await self.aensure_db_setup()
await aexecute_cql(
self.session, self.get_delete_statement(), parameters=[ValueSequence(keys)]
)
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
self.ensure_db_setup()
for row in self.session.execute(
SELECT_ALL_TABLE_CQL_TEMPLATE.format(
keyspace=self.keyspace, table=self.table
)
):
key = row.row_id
if not prefix or key.startswith(prefix):
yield key
async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
await self.aensure_db_setup()
for row in await aexecute_cql(
self.session,
SELECT_ALL_TABLE_CQL_TEMPLATE.format(
keyspace=self.keyspace, table=self.table
),
):
key = row.row_id
if not prefix or key.startswith(prefix):
yield key

View File

@@ -0,0 +1,3 @@
from langchain_core.stores import InvalidKeyException
__all__ = ["InvalidKeyException"]

View File

@@ -0,0 +1,248 @@
from typing import Iterator, List, Optional, Sequence, Tuple
from langchain_core.documents import Document
from langchain_core.stores import BaseStore
class MongoDBByteStore(BaseStore[str, bytes]):
"""BaseStore implementation using MongoDB as the underlying store.
Examples:
Create a MongoDBByteStore instance and perform operations on it:
.. code-block:: python
# Instantiate the MongoDBByteStore with a MongoDB connection
from langchain_classic.storage import MongoDBByteStore
mongo_conn_str = "mongodb://localhost:27017/"
mongodb_store = MongoDBBytesStore(mongo_conn_str, db_name="test-db",
collection_name="test-collection")
# Set values for keys
mongodb_store.mset([("key1", "hello"), ("key2", "workd")])
# Get values for keys
values = mongodb_store.mget(["key1", "key2"])
# [bytes1, bytes1]
# Iterate over keys
for key in mongodb_store.yield_keys():
print(key)
# Delete keys
mongodb_store.mdelete(["key1", "key2"])
"""
def __init__(
self,
connection_string: str,
db_name: str,
collection_name: str,
*,
client_kwargs: Optional[dict] = None,
) -> None:
"""Initialize the MongoDBStore with a MongoDB connection string.
Args:
connection_string (str): MongoDB connection string
db_name (str): name to use
collection_name (str): collection name to use
client_kwargs (dict): Keyword arguments to pass to the Mongo client
"""
try:
from pymongo import MongoClient
except ImportError as e:
raise ImportError(
"The MongoDBStore requires the pymongo library to be "
"installed. "
"pip install pymongo"
) from e
if not connection_string:
raise ValueError("connection_string must be provided.")
if not db_name:
raise ValueError("db_name must be provided.")
if not collection_name:
raise ValueError("collection_name must be provided.")
self.client: MongoClient = MongoClient(
connection_string, **(client_kwargs or {})
)
self.collection = self.client[db_name][collection_name]
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
"""Get the list of documents associated with the given keys.
Args:
keys (list[str]): A list of keys representing Document IDs..
Returns:
list[Document]: A list of Documents corresponding to the provided
keys, where each Document is either retrieved successfully or
represented as None if not found.
"""
result = self.collection.find({"_id": {"$in": keys}})
result_dict = {doc["_id"]: doc["value"] for doc in result}
return [result_dict.get(key) for key in keys]
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
"""Set the given key-value pairs.
Args:
key_value_pairs (list[tuple[str, Document]]): A list of id-document
pairs.
"""
from pymongo import UpdateOne
updates = [{"_id": k, "value": v} for k, v in key_value_pairs]
self.collection.bulk_write(
[UpdateOne({"_id": u["_id"]}, {"$set": u}, upsert=True) for u in updates]
)
def mdelete(self, keys: Sequence[str]) -> None:
"""Delete the given ids.
Args:
keys (list[str]): A list of keys representing Document IDs..
"""
self.collection.delete_many({"_id": {"$in": keys}})
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
"""Yield keys in the store.
Args:
prefix (str): prefix of keys to retrieve.
"""
if prefix is None:
for doc in self.collection.find(projection=["_id"]):
yield doc["_id"]
else:
for doc in self.collection.find(
{"_id": {"$regex": f"^{prefix}"}}, projection=["_id"]
):
yield doc["_id"]
class MongoDBStore(BaseStore[str, Document]):
"""BaseStore implementation using MongoDB as the underlying store.
Examples:
Create a MongoDBStore instance and perform operations on it:
.. code-block:: python
# Instantiate the MongoDBStore with a MongoDB connection
from langchain_classic.storage import MongoDBStore
mongo_conn_str = "mongodb://localhost:27017/"
mongodb_store = MongoDBStore(mongo_conn_str, db_name="test-db",
collection_name="test-collection")
# Set values for keys
doc1 = Document(...)
doc2 = Document(...)
mongodb_store.mset([("key1", doc1), ("key2", doc2)])
# Get values for keys
values = mongodb_store.mget(["key1", "key2"])
# [doc1, doc2]
# Iterate over keys
for key in mongodb_store.yield_keys():
print(key)
# Delete keys
mongodb_store.mdelete(["key1", "key2"])
"""
def __init__(
self,
connection_string: str,
db_name: str,
collection_name: str,
*,
client_kwargs: Optional[dict] = None,
) -> None:
"""Initialize the MongoDBStore with a MongoDB connection string.
Args:
connection_string (str): MongoDB connection string
db_name (str): name to use
collection_name (str): collection name to use
client_kwargs (dict): Keyword arguments to pass to the Mongo client
"""
try:
from pymongo import MongoClient
except ImportError as e:
raise ImportError(
"The MongoDBStore requires the pymongo library to be "
"installed. "
"pip install pymongo"
) from e
if not connection_string:
raise ValueError("connection_string must be provided.")
if not db_name:
raise ValueError("db_name must be provided.")
if not collection_name:
raise ValueError("collection_name must be provided.")
self.client: MongoClient = MongoClient(
connection_string, **(client_kwargs or {})
)
self.collection = self.client[db_name][collection_name]
def mget(self, keys: Sequence[str]) -> List[Optional[Document]]:
"""Get the list of documents associated with the given keys.
Args:
keys (list[str]): A list of keys representing Document IDs..
Returns:
list[Document]: A list of Documents corresponding to the provided
keys, where each Document is either retrieved successfully or
represented as None if not found.
"""
result = self.collection.find({"_id": {"$in": keys}})
result_dict = {doc["_id"]: Document(**doc["value"]) for doc in result}
return [result_dict.get(key) for key in keys]
def mset(self, key_value_pairs: Sequence[Tuple[str, Document]]) -> None:
"""Set the given key-value pairs.
Args:
key_value_pairs (list[tuple[str, Document]]): A list of id-document
pairs.
Returns:
None
"""
from pymongo import UpdateOne
updates = [{"_id": k, "value": v.__dict__} for k, v in key_value_pairs]
self.collection.bulk_write(
[UpdateOne({"_id": u["_id"]}, {"$set": u}, upsert=True) for u in updates]
)
def mdelete(self, keys: Sequence[str]) -> None:
"""Delete the given ids.
Args:
keys (list[str]): A list of keys representing Document IDs..
"""
self.collection.delete_many({"_id": {"$in": keys}})
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
"""Yield keys in the store.
Args:
prefix (str): prefix of keys to retrieve.
"""
if prefix is None:
for doc in self.collection.find(projection=["_id"]):
yield doc["_id"]
else:
for doc in self.collection.find(
{"_id": {"$regex": f"^{prefix}"}}, projection=["_id"]
):
yield doc["_id"]

View File

@@ -0,0 +1,144 @@
from typing import Any, Iterator, List, Optional, Sequence, Tuple, cast
from langchain_core.stores import ByteStore
from langchain_community.utilities.redis import get_client
class RedisStore(ByteStore):
"""BaseStore implementation using Redis as the underlying store.
Examples:
Create a RedisStore instance and perform operations on it:
.. code-block:: python
# Instantiate the RedisStore with a Redis connection
from langchain_community.storage import RedisStore
from langchain_community.utilities.redis import get_client
client = get_client('redis://localhost:6379')
redis_store = RedisStore(client=client)
# Set values for keys
redis_store.mset([("key1", b"value1"), ("key2", b"value2")])
# Get values for keys
values = redis_store.mget(["key1", "key2"])
# [b"value1", b"value2"]
# Delete keys
redis_store.mdelete(["key1"])
# Iterate over keys
for key in redis_store.yield_keys():
print(key) # noqa: T201
"""
def __init__(
self,
*,
client: Any = None,
redis_url: Optional[str] = None,
client_kwargs: Optional[dict] = None,
ttl: Optional[int] = None,
namespace: Optional[str] = None,
) -> None:
"""Initialize the RedisStore with a Redis connection.
Must provide either a Redis client or a redis_url with optional client_kwargs.
Args:
client: A Redis connection instance
redis_url: redis url
client_kwargs: Keyword arguments to pass to the Redis client
ttl: time to expire keys in seconds if provided,
if None keys will never expire
namespace: if provided, all keys will be prefixed with this namespace
"""
try:
from redis import Redis
except ImportError as e:
raise ImportError(
"The RedisStore requires the redis library to be installed. "
"pip install redis"
) from e
if client and (redis_url or client_kwargs):
raise ValueError(
"Either a Redis client or a redis_url with optional client_kwargs "
"must be provided, but not both."
)
if not client and not redis_url:
raise ValueError("Either a Redis client or a redis_url must be provided.")
if client:
if not isinstance(client, Redis):
raise TypeError(
f"Expected Redis client, got {type(client).__name__} instead."
)
_client = client
else:
if not redis_url:
raise ValueError(
"Either a Redis client or a redis_url must be provided."
)
_client = get_client(redis_url, **(client_kwargs or {}))
self.client = _client
if not isinstance(ttl, int) and ttl is not None:
raise TypeError(f"Expected int or None, got {type(ttl)=} instead.")
self.ttl = ttl
self.namespace = namespace
def _get_prefixed_key(self, key: str) -> str:
"""Get the key with the namespace prefix.
Args:
key (str): The original key.
Returns:
str: The key with the namespace prefix.
"""
delimiter = "/"
if self.namespace:
return f"{self.namespace}{delimiter}{key}"
return key
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
"""Get the values associated with the given keys."""
return cast(
List[Optional[bytes]],
self.client.mget([self._get_prefixed_key(key) for key in keys]),
)
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
"""Set the given key-value pairs."""
pipe = self.client.pipeline()
for key, value in key_value_pairs:
pipe.set(self._get_prefixed_key(key), value, ex=self.ttl)
pipe.execute()
def mdelete(self, keys: Sequence[str]) -> None:
"""Delete the given keys."""
_keys = [self._get_prefixed_key(key) for key in keys]
self.client.delete(*_keys)
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
"""Yield keys in the store."""
if prefix:
pattern = self._get_prefixed_key(prefix)
else:
pattern = self._get_prefixed_key("*")
scan_iter = cast(Iterator[bytes], self.client.scan_iter(match=pattern))
for key in scan_iter:
decoded_key = key.decode("utf-8")
if self.namespace:
relative_key = decoded_key[len(self.namespace) + 1 :]
yield relative_key
else:
yield decoded_key

View File

@@ -0,0 +1,295 @@
import contextlib
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Dict,
Generator,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
)
from langchain_core.stores import BaseStore
from sqlalchemy import (
LargeBinary,
Text,
and_,
create_engine,
delete,
select,
)
from sqlalchemy.engine.base import Engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
create_async_engine,
)
from sqlalchemy.orm import (
Mapped,
Session,
declarative_base,
sessionmaker,
)
try:
from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
# dummy for sqlalchemy < 2
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore[assignment,misc]
Base = declarative_base()
try:
from sqlalchemy.orm import mapped_column
class LangchainKeyValueStores(Base): # type: ignore[valid-type,misc]
"""Table used to save values."""
# ATTENTION:
# Prior to modifying this table, please determine whether
# we should create migrations for this table to make sure
# users do not experience data loss.
__tablename__ = "langchain_key_value_stores"
namespace: Mapped[str] = mapped_column(
primary_key=True, index=True, nullable=False
)
key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
value = mapped_column(LargeBinary, index=False, nullable=False)
except ImportError:
# dummy for sqlalchemy < 2
from sqlalchemy import Column
class LangchainKeyValueStores(Base): # type: ignore[valid-type,misc,no-redef]
"""Table used to save values."""
# ATTENTION:
# Prior to modifying this table, please determine whether
# we should create migrations for this table to make sure
# users do not experience data loss.
__tablename__ = "langchain_key_value_stores"
namespace = Column(Text(), primary_key=True, index=True, nullable=False)
key = Column(Text(), primary_key=True, index=True, nullable=False)
value = Column(LargeBinary, index=False, nullable=False)
def items_equal(x: Any, y: Any) -> bool:
return x == y
# This is a fix of original SQLStore.
# This can will be removed when a PR will be merged.
class SQLStore(BaseStore[str, bytes]):
"""BaseStore interface that works on an SQL database.
Examples:
Create a SQLStore instance and perform operations on it:
.. code-block:: python
from langchain_community.storage import SQLStore
# Instantiate the SQLStore with the root path
sql_store = SQLStore(namespace="test", db_url="sqlite://:memory:")
# Set values for keys
sql_store.mset([("key1", b"value1"), ("key2", b"value2")])
# Get values for keys
values = sql_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"]
# Delete keys
sql_store.mdelete(["key1"])
# Iterate over keys
for key in sql_store.yield_keys():
print(key)
"""
def __init__(
self,
*,
namespace: str,
db_url: Optional[Union[str, Path]] = None,
engine: Optional[Union[Engine, AsyncEngine]] = None,
engine_kwargs: Optional[Dict[str, Any]] = None,
async_mode: Optional[bool] = None,
):
if db_url is None and engine is None:
raise ValueError("Must specify either db_url or engine")
if db_url is not None and engine is not None:
raise ValueError("Must specify either db_url or engine, not both")
_engine: Union[Engine, AsyncEngine]
if db_url:
if async_mode is None:
async_mode = False
if async_mode:
_engine = create_async_engine(
url=str(db_url),
**(engine_kwargs or {}),
)
else:
_engine = create_engine(url=str(db_url), **(engine_kwargs or {}))
elif engine:
_engine = engine
else:
raise AssertionError("Something went wrong with configuration of engine.")
_session_maker: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
if isinstance(_engine, AsyncEngine):
self.async_mode = True
_session_maker = async_sessionmaker(bind=_engine)
else:
self.async_mode = False
_session_maker = sessionmaker(bind=_engine)
self.engine = _engine
self.dialect = _engine.dialect.name
self.session_maker = _session_maker
self.namespace = namespace
def create_schema(self) -> None:
Base.metadata.create_all(self.engine) # problem in sqlalchemy v1
# sqlalchemy.exc.CompileError: (in table 'langchain_key_value_stores',
# column 'namespace'): Can't generate DDL for NullType(); did you forget
# to specify a type on this Column?
async def acreate_schema(self) -> None:
assert isinstance(self.engine, AsyncEngine)
async with self.engine.begin() as session:
await session.run_sync(Base.metadata.create_all)
def drop(self) -> None:
Base.metadata.drop_all(bind=self.engine.connect())
async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
assert isinstance(self.engine, AsyncEngine)
result: Dict[str, bytes] = {}
async with self._make_async_session() as session:
stmt = select(LangchainKeyValueStores).filter(
and_(
LangchainKeyValueStores.key.in_(keys),
LangchainKeyValueStores.namespace == self.namespace,
)
)
for v in await session.scalars(stmt):
result[v.key] = v.value
return [result.get(key) for key in keys]
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
result = {}
with self._make_sync_session() as session:
stmt = select(LangchainKeyValueStores).filter(
and_(
LangchainKeyValueStores.key.in_(keys),
LangchainKeyValueStores.namespace == self.namespace,
)
)
for v in session.scalars(stmt):
result[v.key] = v.value
return [result.get(key) for key in keys]
async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
async with self._make_async_session() as session:
await self._amdelete([key for key, _ in key_value_pairs], session)
session.add_all(
[
LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
for k, v in key_value_pairs
]
)
await session.commit()
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
values: Dict[str, bytes] = dict(key_value_pairs)
with self._make_sync_session() as session:
self._mdelete(list(values.keys()), session)
session.add_all(
[
LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
for k, v in values.items()
]
)
session.commit()
def _mdelete(self, keys: Sequence[str], session: Session) -> None:
stmt = delete(LangchainKeyValueStores).filter(
and_(
LangchainKeyValueStores.key.in_(keys),
LangchainKeyValueStores.namespace == self.namespace,
)
)
session.execute(stmt)
async def _amdelete(self, keys: Sequence[str], session: AsyncSession) -> None:
stmt = delete(LangchainKeyValueStores).filter(
and_(
LangchainKeyValueStores.key.in_(keys),
LangchainKeyValueStores.namespace == self.namespace,
)
)
await session.execute(stmt)
def mdelete(self, keys: Sequence[str]) -> None:
with self._make_sync_session() as session:
self._mdelete(keys, session)
session.commit()
async def amdelete(self, keys: Sequence[str]) -> None:
async with self._make_async_session() as session:
await self._amdelete(keys, session)
await session.commit()
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
with self._make_sync_session() as session:
for v in session.query(LangchainKeyValueStores).filter(
LangchainKeyValueStores.namespace == self.namespace
):
if str(v.key).startswith(prefix or ""):
yield str(v.key)
session.close()
async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
async with self._make_async_session() as session:
stmt = select(LangchainKeyValueStores).filter(
LangchainKeyValueStores.namespace == self.namespace
)
for v in await session.scalars(stmt):
if str(v.key).startswith(prefix or ""):
yield str(v.key)
await session.close()
@contextlib.contextmanager
def _make_sync_session(self) -> Generator[Session, None, None]:
"""Make an async session."""
if self.async_mode:
raise ValueError(
"Attempting to use a sync method in when async mode is turned on. "
"Please use the corresponding async method instead."
)
with cast(Session, self.session_maker()) as session:
yield cast(Session, session)
@contextlib.asynccontextmanager
async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
"""Make an async session."""
if not self.async_mode:
raise ValueError(
"Attempting to use an async method in when sync mode is turned on. "
"Please use the corresponding async method instead."
)
async with cast(AsyncSession, self.session_maker()) as session:
yield cast(AsyncSession, session)

View File

@@ -0,0 +1,174 @@
from typing import Any, Iterator, List, Optional, Sequence, Tuple, cast
from langchain_core._api.deprecation import deprecated
from langchain_core.stores import BaseStore, ByteStore
class _UpstashRedisStore(BaseStore[str, str]):
"""BaseStore implementation using Upstash Redis as the underlying store."""
def __init__(
self,
*,
client: Any = None,
url: Optional[str] = None,
token: Optional[str] = None,
ttl: Optional[int] = None,
namespace: Optional[str] = None,
) -> None:
"""Initialize the UpstashRedisStore with HTTP API.
Must provide either an Upstash Redis client or a url.
Args:
client: An Upstash Redis instance
url: UPSTASH_REDIS_REST_URL
token: UPSTASH_REDIS_REST_TOKEN
ttl: time to expire keys in seconds if provided,
if None keys will never expire
namespace: if provided, all keys will be prefixed with this namespace
"""
try:
from upstash_redis import Redis
except ImportError as e:
raise ImportError(
"UpstashRedisStore requires the upstash_redis library to be installed. "
"pip install upstash_redis"
) from e
if client and url:
raise ValueError(
"Either an Upstash Redis client or a url must be provided, not both."
)
if client:
if not isinstance(client, Redis):
raise TypeError(
f"Expected Upstash Redis client, got {type(client).__name__}."
)
_client = client
else:
if not url or not token:
raise ValueError(
"Either an Upstash Redis client or url and token must be provided."
)
_client = Redis(url=url, token=token)
self.client = _client
if not isinstance(ttl, int) and ttl is not None:
raise TypeError(f"Expected int or None, got {type(ttl)} instead.")
self.ttl = ttl
self.namespace = namespace
def _get_prefixed_key(self, key: str) -> str:
"""Get the key with the namespace prefix.
Args:
key (str): The original key.
Returns:
str: The key with the namespace prefix.
"""
delimiter = "/"
if self.namespace:
return f"{self.namespace}{delimiter}{key}"
return key
def mget(self, keys: Sequence[str]) -> List[Optional[str]]:
"""Get the values associated with the given keys."""
keys = [self._get_prefixed_key(key) for key in keys]
return cast(
List[Optional[str]],
self.client.mget(*keys),
)
def mset(self, key_value_pairs: Sequence[Tuple[str, str]]) -> None:
"""Set the given key-value pairs."""
for key, value in key_value_pairs:
self.client.set(self._get_prefixed_key(key), value, ex=self.ttl)
def mdelete(self, keys: Sequence[str]) -> None:
"""Delete the given keys."""
_keys = [self._get_prefixed_key(key) for key in keys]
self.client.delete(*_keys)
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
"""Yield keys in the store."""
if prefix:
pattern = self._get_prefixed_key(prefix)
else:
pattern = self._get_prefixed_key("*")
cursor, keys = self.client.scan(0, match=pattern)
for key in keys:
if self.namespace:
relative_key = key[len(self.namespace) + 1 :]
yield relative_key
else:
yield key
while cursor != 0:
cursor, keys = self.client.scan(cursor, match=pattern)
for key in keys:
if self.namespace:
relative_key = key[len(self.namespace) + 1 :]
yield relative_key
else:
yield key
@deprecated("0.0.1", alternative="UpstashRedisByteStore")
class UpstashRedisStore(_UpstashRedisStore):
"""
BaseStore implementation using Upstash Redis
as the underlying store to store strings.
Deprecated in favor of the more generic UpstashRedisByteStore.
"""
class UpstashRedisByteStore(ByteStore):
"""
BaseStore implementation using Upstash Redis
as the underlying store to store raw bytes.
"""
def __init__(
self,
*,
client: Any = None,
url: Optional[str] = None,
token: Optional[str] = None,
ttl: Optional[int] = None,
namespace: Optional[str] = None,
) -> None:
self.underlying_store = _UpstashRedisStore(
client=client, url=url, token=token, ttl=ttl, namespace=namespace
)
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
"""Get the values associated with the given keys."""
return [
value.encode("utf-8") if value is not None else None
for value in self.underlying_store.mget(keys)
]
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
"""Set the given key-value pairs."""
self.underlying_store.mset(
[
(k, v.decode("utf-8")) if v is not None else None
for k, v in key_value_pairs
]
)
def mdelete(self, keys: Sequence[str]) -> None:
"""Delete the given keys."""
self.underlying_store.mdelete(keys)
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
"""Yield keys in the store."""
yield from self.underlying_store.yield_keys(prefix=prefix)