initial commit
This commit is contained in:
@@ -0,0 +1,69 @@
|
||||
"""**Storage** is an implementation of key-value store.
|
||||
|
||||
Storage module provides implementations of various key-value stores that conform
|
||||
to a simple key-value interface.
|
||||
|
||||
The primary goal of these storages is to support caching.
|
||||
|
||||
|
||||
**Class hierarchy:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
BaseStore --> <name>Store # Examples: MongoDBStore, RedisStore
|
||||
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.storage.astradb import (
|
||||
AstraDBByteStore,
|
||||
AstraDBStore,
|
||||
)
|
||||
from langchain_community.storage.cassandra import (
|
||||
CassandraByteStore,
|
||||
)
|
||||
from langchain_community.storage.mongodb import MongoDBByteStore, MongoDBStore
|
||||
from langchain_community.storage.redis import (
|
||||
RedisStore,
|
||||
)
|
||||
from langchain_community.storage.sql import (
|
||||
SQLStore,
|
||||
)
|
||||
from langchain_community.storage.upstash_redis import (
|
||||
UpstashRedisByteStore,
|
||||
UpstashRedisStore,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AstraDBByteStore",
|
||||
"AstraDBStore",
|
||||
"CassandraByteStore",
|
||||
"MongoDBStore",
|
||||
"MongoDBByteStore",
|
||||
"RedisStore",
|
||||
"SQLStore",
|
||||
"UpstashRedisByteStore",
|
||||
"UpstashRedisStore",
|
||||
]
|
||||
|
||||
_module_lookup = {
|
||||
"AstraDBByteStore": "langchain_community.storage.astradb",
|
||||
"AstraDBStore": "langchain_community.storage.astradb",
|
||||
"CassandraByteStore": "langchain_community.storage.cassandra",
|
||||
"MongoDBStore": "langchain_community.storage.mongodb",
|
||||
"MongoDBByteStore": "langchain_community.storage.mongodb",
|
||||
"RedisStore": "langchain_community.storage.redis",
|
||||
"SQLStore": "langchain_community.storage.sql",
|
||||
"UpstashRedisByteStore": "langchain_community.storage.upstash_redis",
|
||||
"UpstashRedisStore": "langchain_community.storage.upstash_redis",
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in _module_lookup:
|
||||
module = importlib.import_module(_module_lookup[name])
|
||||
return getattr(module, name)
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
238
venv/Lib/site-packages/langchain_community/storage/astradb.py
Normal file
238
venv/Lib/site-packages/langchain_community/storage/astradb.py
Normal file
@@ -0,0 +1,238 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.stores import BaseStore, ByteStore
|
||||
|
||||
from langchain_community.utilities.astradb import (
|
||||
SetupMode,
|
||||
_AstraDBCollectionEnvironment,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrapy.db import AstraDB, AsyncAstraDB
|
||||
|
||||
V = TypeVar("V")
|
||||
|
||||
|
||||
class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
|
||||
"""Base class for the DataStax AstraDB data store."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
self.astra_env = _AstraDBCollectionEnvironment(*args, **kwargs)
|
||||
self.collection = self.astra_env.collection
|
||||
self.async_collection = self.astra_env.async_collection
|
||||
|
||||
@abstractmethod
|
||||
def decode_value(self, value: Any) -> Optional[V]:
|
||||
"""Decodes value from Astra DB"""
|
||||
|
||||
@abstractmethod
|
||||
def encode_value(self, value: Optional[V]) -> Any:
|
||||
"""Encodes value for Astra DB"""
|
||||
|
||||
def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
|
||||
self.astra_env.ensure_db_setup()
|
||||
docs_dict = {}
|
||||
for doc in self.collection.paginated_find(filter={"_id": {"$in": list(keys)}}):
|
||||
docs_dict[doc["_id"]] = doc.get("value")
|
||||
return [self.decode_value(docs_dict.get(key)) for key in keys]
|
||||
|
||||
async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
|
||||
await self.astra_env.aensure_db_setup()
|
||||
docs_dict = {}
|
||||
async for doc in self.async_collection.paginated_find(
|
||||
filter={"_id": {"$in": list(keys)}}
|
||||
):
|
||||
docs_dict[doc["_id"]] = doc.get("value")
|
||||
return [self.decode_value(docs_dict.get(key)) for key in keys]
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
|
||||
self.astra_env.ensure_db_setup()
|
||||
for k, v in key_value_pairs:
|
||||
self.collection.upsert({"_id": k, "value": self.encode_value(v)})
|
||||
|
||||
async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
|
||||
await self.astra_env.aensure_db_setup()
|
||||
for k, v in key_value_pairs:
|
||||
await self.async_collection.upsert(
|
||||
{
|
||||
"_id": k,
|
||||
"value": self.encode_value(v),
|
||||
}
|
||||
)
|
||||
|
||||
def mdelete(self, keys: Sequence[str]) -> None:
|
||||
self.astra_env.ensure_db_setup()
|
||||
self.collection.delete_many(filter={"_id": {"$in": list(keys)}})
|
||||
|
||||
async def amdelete(self, keys: Sequence[str]) -> None:
|
||||
await self.astra_env.aensure_db_setup()
|
||||
await self.async_collection.delete_many(filter={"_id": {"$in": list(keys)}})
|
||||
|
||||
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
|
||||
self.astra_env.ensure_db_setup()
|
||||
docs = self.collection.paginated_find()
|
||||
for doc in docs:
|
||||
key = doc["_id"]
|
||||
if not prefix or key.startswith(prefix):
|
||||
yield key
|
||||
|
||||
async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
|
||||
await self.astra_env.aensure_db_setup()
|
||||
async for doc in self.async_collection.paginated_find():
|
||||
key = doc["_id"]
|
||||
if not prefix or key.startswith(prefix):
|
||||
yield key
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.0.22",
|
||||
removal="1.0",
|
||||
alternative_import="langchain_astradb.AstraDBStore",
|
||||
)
|
||||
class AstraDBStore(AstraDBBaseStore[Any]):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
token: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[AstraDB] = None,
|
||||
namespace: Optional[str] = None,
|
||||
*,
|
||||
async_astra_db_client: Optional[AsyncAstraDB] = None,
|
||||
pre_delete_collection: bool = False,
|
||||
setup_mode: SetupMode = SetupMode.SYNC,
|
||||
) -> None:
|
||||
"""BaseStore implementation using DataStax AstraDB as the underlying store.
|
||||
|
||||
The value type can be any type serializable by json.dumps.
|
||||
Can be used to store embeddings with the CacheBackedEmbeddings.
|
||||
|
||||
Documents in the AstraDB collection will have the format
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"_id": "<key>",
|
||||
"value": <value>
|
||||
}
|
||||
|
||||
Args:
|
||||
collection_name: name of the Astra DB collection to create/use.
|
||||
token: API token for Astra DB usage.
|
||||
api_endpoint: full URL to the API endpoint,
|
||||
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
|
||||
astra_db_client: *alternative to token+api_endpoint*,
|
||||
you can pass an already-created 'astrapy.db.AstraDB' instance.
|
||||
async_astra_db_client: *alternative to token+api_endpoint*,
|
||||
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
|
||||
namespace: namespace (aka keyspace) where the
|
||||
collection is created. Defaults to the database's "default namespace".
|
||||
setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
|
||||
OFF).
|
||||
pre_delete_collection: whether to delete the collection
|
||||
before creating it. If False and the collection already exists,
|
||||
the collection will be used as is.
|
||||
"""
|
||||
# Constructor doc is not inherited so we have to override it.
|
||||
super().__init__(
|
||||
collection_name=collection_name,
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
async_astra_db_client=async_astra_db_client,
|
||||
namespace=namespace,
|
||||
setup_mode=setup_mode,
|
||||
pre_delete_collection=pre_delete_collection,
|
||||
)
|
||||
|
||||
def decode_value(self, value: Any) -> Any:
|
||||
return value
|
||||
|
||||
def encode_value(self, value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.0.22",
|
||||
removal="1.0",
|
||||
alternative_import="langchain_astradb.AstraDBByteStore",
|
||||
)
|
||||
class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
token: Optional[str] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
astra_db_client: Optional[AstraDB] = None,
|
||||
namespace: Optional[str] = None,
|
||||
*,
|
||||
async_astra_db_client: Optional[AsyncAstraDB] = None,
|
||||
pre_delete_collection: bool = False,
|
||||
setup_mode: SetupMode = SetupMode.SYNC,
|
||||
) -> None:
|
||||
"""ByteStore implementation using DataStax AstraDB as the underlying store.
|
||||
|
||||
The bytes values are converted to base64 encoded strings
|
||||
Documents in the AstraDB collection will have the format
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"_id": "<key>",
|
||||
"value": "<byte64 string value>"
|
||||
}
|
||||
|
||||
Args:
|
||||
collection_name: name of the Astra DB collection to create/use.
|
||||
token: API token for Astra DB usage.
|
||||
api_endpoint: full URL to the API endpoint,
|
||||
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
|
||||
astra_db_client: *alternative to token+api_endpoint*,
|
||||
you can pass an already-created 'astrapy.db.AstraDB' instance.
|
||||
async_astra_db_client: *alternative to token+api_endpoint*,
|
||||
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
|
||||
namespace: namespace (aka keyspace) where the
|
||||
collection is created. Defaults to the database's "default namespace".
|
||||
setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
|
||||
OFF).
|
||||
pre_delete_collection: whether to delete the collection
|
||||
before creating it. If False and the collection already exists,
|
||||
the collection will be used as is.
|
||||
"""
|
||||
# Constructor doc is not inherited so we have to override it.
|
||||
super().__init__(
|
||||
collection_name=collection_name,
|
||||
token=token,
|
||||
api_endpoint=api_endpoint,
|
||||
astra_db_client=astra_db_client,
|
||||
async_astra_db_client=async_astra_db_client,
|
||||
namespace=namespace,
|
||||
setup_mode=setup_mode,
|
||||
pre_delete_collection=pre_delete_collection,
|
||||
)
|
||||
|
||||
def decode_value(self, value: Any) -> Optional[bytes]:
|
||||
if value is None:
|
||||
return None
|
||||
return base64.b64decode(value)
|
||||
|
||||
def encode_value(self, value: Optional[bytes]) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
return base64.b64encode(value).decode("ascii")
|
||||
220
venv/Lib/site-packages/langchain_community/storage/cassandra.py
Normal file
220
venv/Lib/site-packages/langchain_community/storage/cassandra.py
Normal file
@@ -0,0 +1,220 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from asyncio import InvalidStateError, Task
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AsyncIterator,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from langchain_core.stores import ByteStore
|
||||
|
||||
from langchain_community.utilities.cassandra import SetupMode, aexecute_cql
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cassandra.cluster import Session
|
||||
from cassandra.query import PreparedStatement
|
||||
|
||||
CREATE_TABLE_CQL_TEMPLATE = """
|
||||
CREATE TABLE IF NOT EXISTS {keyspace}.{table}
|
||||
(row_id TEXT, body_blob BLOB, PRIMARY KEY (row_id));
|
||||
"""
|
||||
SELECT_TABLE_CQL_TEMPLATE = (
|
||||
"""SELECT row_id, body_blob FROM {keyspace}.{table} WHERE row_id IN ?;"""
|
||||
)
|
||||
SELECT_ALL_TABLE_CQL_TEMPLATE = """SELECT row_id, body_blob FROM {keyspace}.{table};"""
|
||||
INSERT_TABLE_CQL_TEMPLATE = (
|
||||
"""INSERT INTO {keyspace}.{table} (row_id, body_blob) VALUES (?, ?);"""
|
||||
)
|
||||
DELETE_TABLE_CQL_TEMPLATE = """DELETE FROM {keyspace}.{table} WHERE row_id IN ?;"""
|
||||
|
||||
|
||||
class CassandraByteStore(ByteStore):
|
||||
"""A ByteStore implementation using Cassandra as the backend.
|
||||
|
||||
Parameters:
|
||||
table: The name of the table to use.
|
||||
session: A Cassandra session object. If not provided, it will be resolved
|
||||
from the cassio config.
|
||||
keyspace: The keyspace to use. If not provided, it will be resolved
|
||||
from the cassio config.
|
||||
setup_mode: The setup mode to use. Default is SYNC (SetupMode.SYNC).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
table: str,
|
||||
*,
|
||||
session: Optional[Session] = None,
|
||||
keyspace: Optional[str] = None,
|
||||
setup_mode: SetupMode = SetupMode.SYNC,
|
||||
) -> None:
|
||||
if not session or not keyspace:
|
||||
try:
|
||||
from cassio.config import check_resolve_keyspace, check_resolve_session
|
||||
|
||||
self.keyspace = keyspace or check_resolve_keyspace(keyspace)
|
||||
self.session = session or check_resolve_session()
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ImportError(
|
||||
"Could not import a recent cassio package."
|
||||
"Please install it with `pip install --upgrade cassio`."
|
||||
)
|
||||
else:
|
||||
self.keyspace = keyspace
|
||||
self.session = session
|
||||
self.table = table
|
||||
self.select_statement = None
|
||||
self.insert_statement = None
|
||||
self.delete_statement = None
|
||||
|
||||
create_cql = CREATE_TABLE_CQL_TEMPLATE.format(
|
||||
keyspace=self.keyspace,
|
||||
table=self.table,
|
||||
)
|
||||
self.db_setup_task: Optional[Task[None]] = None
|
||||
if setup_mode == SetupMode.ASYNC:
|
||||
self.db_setup_task = asyncio.create_task(
|
||||
aexecute_cql(self.session, create_cql)
|
||||
)
|
||||
else:
|
||||
self.session.execute(create_cql)
|
||||
|
||||
def ensure_db_setup(self) -> None:
|
||||
"""Ensure that the DB setup is finished. If not, raise a ValueError."""
|
||||
if self.db_setup_task:
|
||||
try:
|
||||
self.db_setup_task.result()
|
||||
except InvalidStateError:
|
||||
raise ValueError(
|
||||
"Asynchronous setup of the DB not finished. "
|
||||
"NB: AstraDB components sync methods shouldn't be called from the "
|
||||
"event loop. Consider using their async equivalents."
|
||||
)
|
||||
|
||||
async def aensure_db_setup(self) -> None:
|
||||
"""Ensure that the DB setup is finished. If not, wait for it."""
|
||||
if self.db_setup_task:
|
||||
await self.db_setup_task
|
||||
|
||||
def get_select_statement(self) -> PreparedStatement:
|
||||
"""Get the prepared select statement for the table.
|
||||
If not available, prepare it.
|
||||
|
||||
Returns:
|
||||
PreparedStatement: The prepared statement.
|
||||
"""
|
||||
if not self.select_statement:
|
||||
self.select_statement = self.session.prepare(
|
||||
SELECT_TABLE_CQL_TEMPLATE.format(
|
||||
keyspace=self.keyspace, table=self.table
|
||||
)
|
||||
)
|
||||
return self.select_statement
|
||||
|
||||
def get_insert_statement(self) -> PreparedStatement:
|
||||
"""Get the prepared insert statement for the table.
|
||||
If not available, prepare it.
|
||||
|
||||
Returns:
|
||||
PreparedStatement: The prepared statement.
|
||||
"""
|
||||
if not self.insert_statement:
|
||||
self.insert_statement = self.session.prepare(
|
||||
INSERT_TABLE_CQL_TEMPLATE.format(
|
||||
keyspace=self.keyspace, table=self.table
|
||||
)
|
||||
)
|
||||
return self.insert_statement
|
||||
|
||||
def get_delete_statement(self) -> PreparedStatement:
|
||||
"""Get the prepared delete statement for the table.
|
||||
If not available, prepare it.
|
||||
|
||||
Returns:
|
||||
PreparedStatement: The prepared statement.
|
||||
"""
|
||||
|
||||
if not self.delete_statement:
|
||||
self.delete_statement = self.session.prepare(
|
||||
DELETE_TABLE_CQL_TEMPLATE.format(
|
||||
keyspace=self.keyspace, table=self.table
|
||||
)
|
||||
)
|
||||
return self.delete_statement
|
||||
|
||||
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
||||
from cassandra.query import ValueSequence
|
||||
|
||||
self.ensure_db_setup()
|
||||
docs_dict = {}
|
||||
for row in self.session.execute(
|
||||
self.get_select_statement(), [ValueSequence(keys)]
|
||||
):
|
||||
docs_dict[row.row_id] = row.body_blob
|
||||
return [docs_dict.get(key) for key in keys]
|
||||
|
||||
async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
||||
from cassandra.query import ValueSequence
|
||||
|
||||
await self.aensure_db_setup()
|
||||
docs_dict = {}
|
||||
for row in await aexecute_cql(
|
||||
self.session, self.get_select_statement(), parameters=[ValueSequence(keys)]
|
||||
):
|
||||
docs_dict[row.row_id] = row.body_blob
|
||||
return [docs_dict.get(key) for key in keys]
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
|
||||
self.ensure_db_setup()
|
||||
insert_statement = self.get_insert_statement()
|
||||
for k, v in key_value_pairs:
|
||||
self.session.execute(insert_statement, (k, v))
|
||||
|
||||
async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
|
||||
await self.aensure_db_setup()
|
||||
insert_statement = self.get_insert_statement()
|
||||
for k, v in key_value_pairs:
|
||||
await aexecute_cql(self.session, insert_statement, parameters=(k, v))
|
||||
|
||||
def mdelete(self, keys: Sequence[str]) -> None:
|
||||
from cassandra.query import ValueSequence
|
||||
|
||||
self.ensure_db_setup()
|
||||
self.session.execute(self.get_delete_statement(), [ValueSequence(keys)])
|
||||
|
||||
async def amdelete(self, keys: Sequence[str]) -> None:
|
||||
from cassandra.query import ValueSequence
|
||||
|
||||
await self.aensure_db_setup()
|
||||
await aexecute_cql(
|
||||
self.session, self.get_delete_statement(), parameters=[ValueSequence(keys)]
|
||||
)
|
||||
|
||||
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
|
||||
self.ensure_db_setup()
|
||||
for row in self.session.execute(
|
||||
SELECT_ALL_TABLE_CQL_TEMPLATE.format(
|
||||
keyspace=self.keyspace, table=self.table
|
||||
)
|
||||
):
|
||||
key = row.row_id
|
||||
if not prefix or key.startswith(prefix):
|
||||
yield key
|
||||
|
||||
async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
|
||||
await self.aensure_db_setup()
|
||||
for row in await aexecute_cql(
|
||||
self.session,
|
||||
SELECT_ALL_TABLE_CQL_TEMPLATE.format(
|
||||
keyspace=self.keyspace, table=self.table
|
||||
),
|
||||
):
|
||||
key = row.row_id
|
||||
if not prefix or key.startswith(prefix):
|
||||
yield key
|
||||
@@ -0,0 +1,3 @@
|
||||
from langchain_core.stores import InvalidKeyException
|
||||
|
||||
__all__ = ["InvalidKeyException"]
|
||||
248
venv/Lib/site-packages/langchain_community/storage/mongodb.py
Normal file
248
venv/Lib/site-packages/langchain_community/storage/mongodb.py
Normal file
@@ -0,0 +1,248 @@
|
||||
from typing import Iterator, List, Optional, Sequence, Tuple
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.stores import BaseStore
|
||||
|
||||
|
||||
class MongoDBByteStore(BaseStore[str, bytes]):
|
||||
"""BaseStore implementation using MongoDB as the underlying store.
|
||||
|
||||
Examples:
|
||||
Create a MongoDBByteStore instance and perform operations on it:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Instantiate the MongoDBByteStore with a MongoDB connection
|
||||
from langchain_classic.storage import MongoDBByteStore
|
||||
|
||||
mongo_conn_str = "mongodb://localhost:27017/"
|
||||
mongodb_store = MongoDBBytesStore(mongo_conn_str, db_name="test-db",
|
||||
collection_name="test-collection")
|
||||
|
||||
# Set values for keys
|
||||
mongodb_store.mset([("key1", "hello"), ("key2", "workd")])
|
||||
|
||||
# Get values for keys
|
||||
values = mongodb_store.mget(["key1", "key2"])
|
||||
# [bytes1, bytes1]
|
||||
|
||||
# Iterate over keys
|
||||
for key in mongodb_store.yield_keys():
|
||||
print(key)
|
||||
|
||||
# Delete keys
|
||||
mongodb_store.mdelete(["key1", "key2"])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_string: str,
|
||||
db_name: str,
|
||||
collection_name: str,
|
||||
*,
|
||||
client_kwargs: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Initialize the MongoDBStore with a MongoDB connection string.
|
||||
|
||||
Args:
|
||||
connection_string (str): MongoDB connection string
|
||||
db_name (str): name to use
|
||||
collection_name (str): collection name to use
|
||||
client_kwargs (dict): Keyword arguments to pass to the Mongo client
|
||||
"""
|
||||
try:
|
||||
from pymongo import MongoClient
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"The MongoDBStore requires the pymongo library to be "
|
||||
"installed. "
|
||||
"pip install pymongo"
|
||||
) from e
|
||||
|
||||
if not connection_string:
|
||||
raise ValueError("connection_string must be provided.")
|
||||
if not db_name:
|
||||
raise ValueError("db_name must be provided.")
|
||||
if not collection_name:
|
||||
raise ValueError("collection_name must be provided.")
|
||||
|
||||
self.client: MongoClient = MongoClient(
|
||||
connection_string, **(client_kwargs or {})
|
||||
)
|
||||
self.collection = self.client[db_name][collection_name]
|
||||
|
||||
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
||||
"""Get the list of documents associated with the given keys.
|
||||
|
||||
Args:
|
||||
keys (list[str]): A list of keys representing Document IDs..
|
||||
|
||||
Returns:
|
||||
list[Document]: A list of Documents corresponding to the provided
|
||||
keys, where each Document is either retrieved successfully or
|
||||
represented as None if not found.
|
||||
"""
|
||||
result = self.collection.find({"_id": {"$in": keys}})
|
||||
result_dict = {doc["_id"]: doc["value"] for doc in result}
|
||||
return [result_dict.get(key) for key in keys]
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
|
||||
"""Set the given key-value pairs.
|
||||
|
||||
Args:
|
||||
key_value_pairs (list[tuple[str, Document]]): A list of id-document
|
||||
pairs.
|
||||
"""
|
||||
from pymongo import UpdateOne
|
||||
|
||||
updates = [{"_id": k, "value": v} for k, v in key_value_pairs]
|
||||
self.collection.bulk_write(
|
||||
[UpdateOne({"_id": u["_id"]}, {"$set": u}, upsert=True) for u in updates]
|
||||
)
|
||||
|
||||
def mdelete(self, keys: Sequence[str]) -> None:
|
||||
"""Delete the given ids.
|
||||
|
||||
Args:
|
||||
keys (list[str]): A list of keys representing Document IDs..
|
||||
"""
|
||||
self.collection.delete_many({"_id": {"$in": keys}})
|
||||
|
||||
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
|
||||
"""Yield keys in the store.
|
||||
|
||||
Args:
|
||||
prefix (str): prefix of keys to retrieve.
|
||||
"""
|
||||
if prefix is None:
|
||||
for doc in self.collection.find(projection=["_id"]):
|
||||
yield doc["_id"]
|
||||
else:
|
||||
for doc in self.collection.find(
|
||||
{"_id": {"$regex": f"^{prefix}"}}, projection=["_id"]
|
||||
):
|
||||
yield doc["_id"]
|
||||
|
||||
|
||||
class MongoDBStore(BaseStore[str, Document]):
|
||||
"""BaseStore implementation using MongoDB as the underlying store.
|
||||
|
||||
Examples:
|
||||
Create a MongoDBStore instance and perform operations on it:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Instantiate the MongoDBStore with a MongoDB connection
|
||||
from langchain_classic.storage import MongoDBStore
|
||||
|
||||
mongo_conn_str = "mongodb://localhost:27017/"
|
||||
mongodb_store = MongoDBStore(mongo_conn_str, db_name="test-db",
|
||||
collection_name="test-collection")
|
||||
|
||||
# Set values for keys
|
||||
doc1 = Document(...)
|
||||
doc2 = Document(...)
|
||||
mongodb_store.mset([("key1", doc1), ("key2", doc2)])
|
||||
|
||||
# Get values for keys
|
||||
values = mongodb_store.mget(["key1", "key2"])
|
||||
# [doc1, doc2]
|
||||
|
||||
# Iterate over keys
|
||||
for key in mongodb_store.yield_keys():
|
||||
print(key)
|
||||
|
||||
# Delete keys
|
||||
mongodb_store.mdelete(["key1", "key2"])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_string: str,
|
||||
db_name: str,
|
||||
collection_name: str,
|
||||
*,
|
||||
client_kwargs: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Initialize the MongoDBStore with a MongoDB connection string.
|
||||
|
||||
Args:
|
||||
connection_string (str): MongoDB connection string
|
||||
db_name (str): name to use
|
||||
collection_name (str): collection name to use
|
||||
client_kwargs (dict): Keyword arguments to pass to the Mongo client
|
||||
"""
|
||||
try:
|
||||
from pymongo import MongoClient
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"The MongoDBStore requires the pymongo library to be "
|
||||
"installed. "
|
||||
"pip install pymongo"
|
||||
) from e
|
||||
|
||||
if not connection_string:
|
||||
raise ValueError("connection_string must be provided.")
|
||||
if not db_name:
|
||||
raise ValueError("db_name must be provided.")
|
||||
if not collection_name:
|
||||
raise ValueError("collection_name must be provided.")
|
||||
|
||||
self.client: MongoClient = MongoClient(
|
||||
connection_string, **(client_kwargs or {})
|
||||
)
|
||||
self.collection = self.client[db_name][collection_name]
|
||||
|
||||
def mget(self, keys: Sequence[str]) -> List[Optional[Document]]:
|
||||
"""Get the list of documents associated with the given keys.
|
||||
|
||||
Args:
|
||||
keys (list[str]): A list of keys representing Document IDs..
|
||||
|
||||
Returns:
|
||||
list[Document]: A list of Documents corresponding to the provided
|
||||
keys, where each Document is either retrieved successfully or
|
||||
represented as None if not found.
|
||||
"""
|
||||
result = self.collection.find({"_id": {"$in": keys}})
|
||||
result_dict = {doc["_id"]: Document(**doc["value"]) for doc in result}
|
||||
return [result_dict.get(key) for key in keys]
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[str, Document]]) -> None:
|
||||
"""Set the given key-value pairs.
|
||||
|
||||
Args:
|
||||
key_value_pairs (list[tuple[str, Document]]): A list of id-document
|
||||
pairs.
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
from pymongo import UpdateOne
|
||||
|
||||
updates = [{"_id": k, "value": v.__dict__} for k, v in key_value_pairs]
|
||||
self.collection.bulk_write(
|
||||
[UpdateOne({"_id": u["_id"]}, {"$set": u}, upsert=True) for u in updates]
|
||||
)
|
||||
|
||||
def mdelete(self, keys: Sequence[str]) -> None:
|
||||
"""Delete the given ids.
|
||||
|
||||
Args:
|
||||
keys (list[str]): A list of keys representing Document IDs..
|
||||
"""
|
||||
self.collection.delete_many({"_id": {"$in": keys}})
|
||||
|
||||
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
|
||||
"""Yield keys in the store.
|
||||
|
||||
Args:
|
||||
prefix (str): prefix of keys to retrieve.
|
||||
"""
|
||||
if prefix is None:
|
||||
for doc in self.collection.find(projection=["_id"]):
|
||||
yield doc["_id"]
|
||||
else:
|
||||
for doc in self.collection.find(
|
||||
{"_id": {"$regex": f"^{prefix}"}}, projection=["_id"]
|
||||
):
|
||||
yield doc["_id"]
|
||||
144
venv/Lib/site-packages/langchain_community/storage/redis.py
Normal file
144
venv/Lib/site-packages/langchain_community/storage/redis.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from typing import Any, Iterator, List, Optional, Sequence, Tuple, cast
|
||||
|
||||
from langchain_core.stores import ByteStore
|
||||
|
||||
from langchain_community.utilities.redis import get_client
|
||||
|
||||
|
||||
class RedisStore(ByteStore):
|
||||
"""BaseStore implementation using Redis as the underlying store.
|
||||
|
||||
Examples:
|
||||
Create a RedisStore instance and perform operations on it:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Instantiate the RedisStore with a Redis connection
|
||||
from langchain_community.storage import RedisStore
|
||||
from langchain_community.utilities.redis import get_client
|
||||
|
||||
client = get_client('redis://localhost:6379')
|
||||
redis_store = RedisStore(client=client)
|
||||
|
||||
# Set values for keys
|
||||
redis_store.mset([("key1", b"value1"), ("key2", b"value2")])
|
||||
|
||||
# Get values for keys
|
||||
values = redis_store.mget(["key1", "key2"])
|
||||
# [b"value1", b"value2"]
|
||||
|
||||
# Delete keys
|
||||
redis_store.mdelete(["key1"])
|
||||
|
||||
# Iterate over keys
|
||||
for key in redis_store.yield_keys():
|
||||
print(key) # noqa: T201
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
client: Any = None,
|
||||
redis_url: Optional[str] = None,
|
||||
client_kwargs: Optional[dict] = None,
|
||||
ttl: Optional[int] = None,
|
||||
namespace: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize the RedisStore with a Redis connection.
|
||||
|
||||
Must provide either a Redis client or a redis_url with optional client_kwargs.
|
||||
|
||||
Args:
|
||||
client: A Redis connection instance
|
||||
redis_url: redis url
|
||||
client_kwargs: Keyword arguments to pass to the Redis client
|
||||
ttl: time to expire keys in seconds if provided,
|
||||
if None keys will never expire
|
||||
namespace: if provided, all keys will be prefixed with this namespace
|
||||
"""
|
||||
try:
|
||||
from redis import Redis
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"The RedisStore requires the redis library to be installed. "
|
||||
"pip install redis"
|
||||
) from e
|
||||
|
||||
if client and (redis_url or client_kwargs):
|
||||
raise ValueError(
|
||||
"Either a Redis client or a redis_url with optional client_kwargs "
|
||||
"must be provided, but not both."
|
||||
)
|
||||
|
||||
if not client and not redis_url:
|
||||
raise ValueError("Either a Redis client or a redis_url must be provided.")
|
||||
|
||||
if client:
|
||||
if not isinstance(client, Redis):
|
||||
raise TypeError(
|
||||
f"Expected Redis client, got {type(client).__name__} instead."
|
||||
)
|
||||
_client = client
|
||||
else:
|
||||
if not redis_url:
|
||||
raise ValueError(
|
||||
"Either a Redis client or a redis_url must be provided."
|
||||
)
|
||||
_client = get_client(redis_url, **(client_kwargs or {}))
|
||||
|
||||
self.client = _client
|
||||
|
||||
if not isinstance(ttl, int) and ttl is not None:
|
||||
raise TypeError(f"Expected int or None, got {type(ttl)=} instead.")
|
||||
|
||||
self.ttl = ttl
|
||||
self.namespace = namespace
|
||||
|
||||
def _get_prefixed_key(self, key: str) -> str:
|
||||
"""Get the key with the namespace prefix.
|
||||
|
||||
Args:
|
||||
key (str): The original key.
|
||||
|
||||
Returns:
|
||||
str: The key with the namespace prefix.
|
||||
"""
|
||||
delimiter = "/"
|
||||
if self.namespace:
|
||||
return f"{self.namespace}{delimiter}{key}"
|
||||
return key
|
||||
|
||||
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
||||
"""Get the values associated with the given keys."""
|
||||
return cast(
|
||||
List[Optional[bytes]],
|
||||
self.client.mget([self._get_prefixed_key(key) for key in keys]),
|
||||
)
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
|
||||
"""Set the given key-value pairs."""
|
||||
pipe = self.client.pipeline()
|
||||
|
||||
for key, value in key_value_pairs:
|
||||
pipe.set(self._get_prefixed_key(key), value, ex=self.ttl)
|
||||
pipe.execute()
|
||||
|
||||
def mdelete(self, keys: Sequence[str]) -> None:
|
||||
"""Delete the given keys."""
|
||||
_keys = [self._get_prefixed_key(key) for key in keys]
|
||||
self.client.delete(*_keys)
|
||||
|
||||
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
|
||||
"""Yield keys in the store."""
|
||||
if prefix:
|
||||
pattern = self._get_prefixed_key(prefix)
|
||||
else:
|
||||
pattern = self._get_prefixed_key("*")
|
||||
scan_iter = cast(Iterator[bytes], self.client.scan_iter(match=pattern))
|
||||
for key in scan_iter:
|
||||
decoded_key = key.decode("utf-8")
|
||||
if self.namespace:
|
||||
relative_key = decoded_key[len(self.namespace) + 1 :]
|
||||
yield relative_key
|
||||
else:
|
||||
yield decoded_key
|
||||
295
venv/Lib/site-packages/langchain_community/storage/sql.py
Normal file
295
venv/Lib/site-packages/langchain_community/storage/sql.py
Normal file
@@ -0,0 +1,295 @@
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.stores import BaseStore
|
||||
from sqlalchemy import (
|
||||
LargeBinary,
|
||||
Text,
|
||||
and_,
|
||||
create_engine,
|
||||
delete,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import (
|
||||
Mapped,
|
||||
Session,
|
||||
declarative_base,
|
||||
sessionmaker,
|
||||
)
|
||||
|
||||
try:
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
except ImportError:
|
||||
# dummy for sqlalchemy < 2
|
||||
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore[assignment,misc]
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import mapped_column
|
||||
|
||||
class LangchainKeyValueStores(Base): # type: ignore[valid-type,misc]
|
||||
"""Table used to save values."""
|
||||
|
||||
# ATTENTION:
|
||||
# Prior to modifying this table, please determine whether
|
||||
# we should create migrations for this table to make sure
|
||||
# users do not experience data loss.
|
||||
__tablename__ = "langchain_key_value_stores"
|
||||
|
||||
namespace: Mapped[str] = mapped_column(
|
||||
primary_key=True, index=True, nullable=False
|
||||
)
|
||||
key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
|
||||
value = mapped_column(LargeBinary, index=False, nullable=False)
|
||||
|
||||
except ImportError:
|
||||
# dummy for sqlalchemy < 2
|
||||
from sqlalchemy import Column
|
||||
|
||||
class LangchainKeyValueStores(Base): # type: ignore[valid-type,misc,no-redef]
|
||||
"""Table used to save values."""
|
||||
|
||||
# ATTENTION:
|
||||
# Prior to modifying this table, please determine whether
|
||||
# we should create migrations for this table to make sure
|
||||
# users do not experience data loss.
|
||||
__tablename__ = "langchain_key_value_stores"
|
||||
|
||||
namespace = Column(Text(), primary_key=True, index=True, nullable=False)
|
||||
key = Column(Text(), primary_key=True, index=True, nullable=False)
|
||||
value = Column(LargeBinary, index=False, nullable=False)
|
||||
|
||||
|
||||
def items_equal(x: Any, y: Any) -> bool:
|
||||
return x == y
|
||||
|
||||
|
||||
# This is a fix of original SQLStore.
|
||||
# This can will be removed when a PR will be merged.
|
||||
class SQLStore(BaseStore[str, bytes]):
|
||||
"""BaseStore interface that works on an SQL database.
|
||||
|
||||
Examples:
|
||||
Create a SQLStore instance and perform operations on it:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.storage import SQLStore
|
||||
|
||||
# Instantiate the SQLStore with the root path
|
||||
sql_store = SQLStore(namespace="test", db_url="sqlite://:memory:")
|
||||
|
||||
# Set values for keys
|
||||
sql_store.mset([("key1", b"value1"), ("key2", b"value2")])
|
||||
|
||||
# Get values for keys
|
||||
values = sql_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"]
|
||||
|
||||
# Delete keys
|
||||
sql_store.mdelete(["key1"])
|
||||
|
||||
# Iterate over keys
|
||||
for key in sql_store.yield_keys():
|
||||
print(key)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
namespace: str,
|
||||
db_url: Optional[Union[str, Path]] = None,
|
||||
engine: Optional[Union[Engine, AsyncEngine]] = None,
|
||||
engine_kwargs: Optional[Dict[str, Any]] = None,
|
||||
async_mode: Optional[bool] = None,
|
||||
):
|
||||
if db_url is None and engine is None:
|
||||
raise ValueError("Must specify either db_url or engine")
|
||||
|
||||
if db_url is not None and engine is not None:
|
||||
raise ValueError("Must specify either db_url or engine, not both")
|
||||
|
||||
_engine: Union[Engine, AsyncEngine]
|
||||
if db_url:
|
||||
if async_mode is None:
|
||||
async_mode = False
|
||||
if async_mode:
|
||||
_engine = create_async_engine(
|
||||
url=str(db_url),
|
||||
**(engine_kwargs or {}),
|
||||
)
|
||||
else:
|
||||
_engine = create_engine(url=str(db_url), **(engine_kwargs or {}))
|
||||
elif engine:
|
||||
_engine = engine
|
||||
|
||||
else:
|
||||
raise AssertionError("Something went wrong with configuration of engine.")
|
||||
|
||||
_session_maker: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
|
||||
if isinstance(_engine, AsyncEngine):
|
||||
self.async_mode = True
|
||||
_session_maker = async_sessionmaker(bind=_engine)
|
||||
else:
|
||||
self.async_mode = False
|
||||
_session_maker = sessionmaker(bind=_engine)
|
||||
|
||||
self.engine = _engine
|
||||
self.dialect = _engine.dialect.name
|
||||
self.session_maker = _session_maker
|
||||
self.namespace = namespace
|
||||
|
||||
def create_schema(self) -> None:
|
||||
Base.metadata.create_all(self.engine) # problem in sqlalchemy v1
|
||||
# sqlalchemy.exc.CompileError: (in table 'langchain_key_value_stores',
|
||||
# column 'namespace'): Can't generate DDL for NullType(); did you forget
|
||||
# to specify a type on this Column?
|
||||
|
||||
async def acreate_schema(self) -> None:
|
||||
assert isinstance(self.engine, AsyncEngine)
|
||||
async with self.engine.begin() as session:
|
||||
await session.run_sync(Base.metadata.create_all)
|
||||
|
||||
def drop(self) -> None:
|
||||
Base.metadata.drop_all(bind=self.engine.connect())
|
||||
|
||||
async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
||||
assert isinstance(self.engine, AsyncEngine)
|
||||
result: Dict[str, bytes] = {}
|
||||
async with self._make_async_session() as session:
|
||||
stmt = select(LangchainKeyValueStores).filter(
|
||||
and_(
|
||||
LangchainKeyValueStores.key.in_(keys),
|
||||
LangchainKeyValueStores.namespace == self.namespace,
|
||||
)
|
||||
)
|
||||
for v in await session.scalars(stmt):
|
||||
result[v.key] = v.value
|
||||
return [result.get(key) for key in keys]
|
||||
|
||||
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
||||
result = {}
|
||||
|
||||
with self._make_sync_session() as session:
|
||||
stmt = select(LangchainKeyValueStores).filter(
|
||||
and_(
|
||||
LangchainKeyValueStores.key.in_(keys),
|
||||
LangchainKeyValueStores.namespace == self.namespace,
|
||||
)
|
||||
)
|
||||
for v in session.scalars(stmt):
|
||||
result[v.key] = v.value
|
||||
return [result.get(key) for key in keys]
|
||||
|
||||
async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
|
||||
async with self._make_async_session() as session:
|
||||
await self._amdelete([key for key, _ in key_value_pairs], session)
|
||||
session.add_all(
|
||||
[
|
||||
LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
|
||||
for k, v in key_value_pairs
|
||||
]
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
|
||||
values: Dict[str, bytes] = dict(key_value_pairs)
|
||||
with self._make_sync_session() as session:
|
||||
self._mdelete(list(values.keys()), session)
|
||||
session.add_all(
|
||||
[
|
||||
LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
|
||||
for k, v in values.items()
|
||||
]
|
||||
)
|
||||
session.commit()
|
||||
|
||||
def _mdelete(self, keys: Sequence[str], session: Session) -> None:
|
||||
stmt = delete(LangchainKeyValueStores).filter(
|
||||
and_(
|
||||
LangchainKeyValueStores.key.in_(keys),
|
||||
LangchainKeyValueStores.namespace == self.namespace,
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
|
||||
async def _amdelete(self, keys: Sequence[str], session: AsyncSession) -> None:
|
||||
stmt = delete(LangchainKeyValueStores).filter(
|
||||
and_(
|
||||
LangchainKeyValueStores.key.in_(keys),
|
||||
LangchainKeyValueStores.namespace == self.namespace,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
|
||||
def mdelete(self, keys: Sequence[str]) -> None:
|
||||
with self._make_sync_session() as session:
|
||||
self._mdelete(keys, session)
|
||||
session.commit()
|
||||
|
||||
async def amdelete(self, keys: Sequence[str]) -> None:
|
||||
async with self._make_async_session() as session:
|
||||
await self._amdelete(keys, session)
|
||||
await session.commit()
|
||||
|
||||
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
|
||||
with self._make_sync_session() as session:
|
||||
for v in session.query(LangchainKeyValueStores).filter(
|
||||
LangchainKeyValueStores.namespace == self.namespace
|
||||
):
|
||||
if str(v.key).startswith(prefix or ""):
|
||||
yield str(v.key)
|
||||
session.close()
|
||||
|
||||
async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
|
||||
async with self._make_async_session() as session:
|
||||
stmt = select(LangchainKeyValueStores).filter(
|
||||
LangchainKeyValueStores.namespace == self.namespace
|
||||
)
|
||||
for v in await session.scalars(stmt):
|
||||
if str(v.key).startswith(prefix or ""):
|
||||
yield str(v.key)
|
||||
await session.close()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _make_sync_session(self) -> Generator[Session, None, None]:
|
||||
"""Make an async session."""
|
||||
if self.async_mode:
|
||||
raise ValueError(
|
||||
"Attempting to use a sync method in when async mode is turned on. "
|
||||
"Please use the corresponding async method instead."
|
||||
)
|
||||
with cast(Session, self.session_maker()) as session:
|
||||
yield cast(Session, session)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Make an async session."""
|
||||
if not self.async_mode:
|
||||
raise ValueError(
|
||||
"Attempting to use an async method in when sync mode is turned on. "
|
||||
"Please use the corresponding async method instead."
|
||||
)
|
||||
async with cast(AsyncSession, self.session_maker()) as session:
|
||||
yield cast(AsyncSession, session)
|
||||
@@ -0,0 +1,174 @@
|
||||
from typing import Any, Iterator, List, Optional, Sequence, Tuple, cast
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.stores import BaseStore, ByteStore
|
||||
|
||||
|
||||
class _UpstashRedisStore(BaseStore[str, str]):
|
||||
"""BaseStore implementation using Upstash Redis as the underlying store."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
client: Any = None,
|
||||
url: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
ttl: Optional[int] = None,
|
||||
namespace: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize the UpstashRedisStore with HTTP API.
|
||||
|
||||
Must provide either an Upstash Redis client or a url.
|
||||
|
||||
Args:
|
||||
client: An Upstash Redis instance
|
||||
url: UPSTASH_REDIS_REST_URL
|
||||
token: UPSTASH_REDIS_REST_TOKEN
|
||||
ttl: time to expire keys in seconds if provided,
|
||||
if None keys will never expire
|
||||
namespace: if provided, all keys will be prefixed with this namespace
|
||||
"""
|
||||
try:
|
||||
from upstash_redis import Redis
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"UpstashRedisStore requires the upstash_redis library to be installed. "
|
||||
"pip install upstash_redis"
|
||||
) from e
|
||||
|
||||
if client and url:
|
||||
raise ValueError(
|
||||
"Either an Upstash Redis client or a url must be provided, not both."
|
||||
)
|
||||
|
||||
if client:
|
||||
if not isinstance(client, Redis):
|
||||
raise TypeError(
|
||||
f"Expected Upstash Redis client, got {type(client).__name__}."
|
||||
)
|
||||
_client = client
|
||||
else:
|
||||
if not url or not token:
|
||||
raise ValueError(
|
||||
"Either an Upstash Redis client or url and token must be provided."
|
||||
)
|
||||
_client = Redis(url=url, token=token)
|
||||
|
||||
self.client = _client
|
||||
|
||||
if not isinstance(ttl, int) and ttl is not None:
|
||||
raise TypeError(f"Expected int or None, got {type(ttl)} instead.")
|
||||
|
||||
self.ttl = ttl
|
||||
self.namespace = namespace
|
||||
|
||||
def _get_prefixed_key(self, key: str) -> str:
|
||||
"""Get the key with the namespace prefix.
|
||||
|
||||
Args:
|
||||
key (str): The original key.
|
||||
|
||||
Returns:
|
||||
str: The key with the namespace prefix.
|
||||
"""
|
||||
delimiter = "/"
|
||||
if self.namespace:
|
||||
return f"{self.namespace}{delimiter}{key}"
|
||||
return key
|
||||
|
||||
def mget(self, keys: Sequence[str]) -> List[Optional[str]]:
|
||||
"""Get the values associated with the given keys."""
|
||||
|
||||
keys = [self._get_prefixed_key(key) for key in keys]
|
||||
return cast(
|
||||
List[Optional[str]],
|
||||
self.client.mget(*keys),
|
||||
)
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[str, str]]) -> None:
|
||||
"""Set the given key-value pairs."""
|
||||
for key, value in key_value_pairs:
|
||||
self.client.set(self._get_prefixed_key(key), value, ex=self.ttl)
|
||||
|
||||
def mdelete(self, keys: Sequence[str]) -> None:
|
||||
"""Delete the given keys."""
|
||||
_keys = [self._get_prefixed_key(key) for key in keys]
|
||||
self.client.delete(*_keys)
|
||||
|
||||
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
|
||||
"""Yield keys in the store."""
|
||||
if prefix:
|
||||
pattern = self._get_prefixed_key(prefix)
|
||||
else:
|
||||
pattern = self._get_prefixed_key("*")
|
||||
|
||||
cursor, keys = self.client.scan(0, match=pattern)
|
||||
for key in keys:
|
||||
if self.namespace:
|
||||
relative_key = key[len(self.namespace) + 1 :]
|
||||
yield relative_key
|
||||
else:
|
||||
yield key
|
||||
|
||||
while cursor != 0:
|
||||
cursor, keys = self.client.scan(cursor, match=pattern)
|
||||
for key in keys:
|
||||
if self.namespace:
|
||||
relative_key = key[len(self.namespace) + 1 :]
|
||||
yield relative_key
|
||||
else:
|
||||
yield key
|
||||
|
||||
|
||||
@deprecated("0.0.1", alternative="UpstashRedisByteStore")
|
||||
class UpstashRedisStore(_UpstashRedisStore):
|
||||
"""
|
||||
BaseStore implementation using Upstash Redis
|
||||
as the underlying store to store strings.
|
||||
|
||||
Deprecated in favor of the more generic UpstashRedisByteStore.
|
||||
"""
|
||||
|
||||
|
||||
class UpstashRedisByteStore(ByteStore):
|
||||
"""
|
||||
BaseStore implementation using Upstash Redis
|
||||
as the underlying store to store raw bytes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
client: Any = None,
|
||||
url: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
ttl: Optional[int] = None,
|
||||
namespace: Optional[str] = None,
|
||||
) -> None:
|
||||
self.underlying_store = _UpstashRedisStore(
|
||||
client=client, url=url, token=token, ttl=ttl, namespace=namespace
|
||||
)
|
||||
|
||||
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
||||
"""Get the values associated with the given keys."""
|
||||
return [
|
||||
value.encode("utf-8") if value is not None else None
|
||||
for value in self.underlying_store.mget(keys)
|
||||
]
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
|
||||
"""Set the given key-value pairs."""
|
||||
self.underlying_store.mset(
|
||||
[
|
||||
(k, v.decode("utf-8")) if v is not None else None
|
||||
for k, v in key_value_pairs
|
||||
]
|
||||
)
|
||||
|
||||
def mdelete(self, keys: Sequence[str]) -> None:
|
||||
"""Delete the given keys."""
|
||||
self.underlying_store.mdelete(keys)
|
||||
|
||||
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
|
||||
"""Yield keys in the store."""
|
||||
yield from self.underlying_store.yield_keys(prefix=prefix)
|
||||
Reference in New Issue
Block a user