initial commit
This commit is contained in:
@@ -0,0 +1,13 @@
|
||||
"""**Index** is used to avoid writing duplicated content
|
||||
into the vectorstore and to avoid over-writing content if it's unchanged.
|
||||
|
||||
Indexes also:
|
||||
|
||||
* Create knowledge graphs from data.
|
||||
|
||||
* Support indexing workflows from LangChain data loaders to vectorstores.
|
||||
|
||||
Importantly, Index keeps on working even if the content being written is derived
|
||||
via a set of transformations from some source content (e.g., indexing children
|
||||
documents that were derived from parent documents by chunking.)
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,237 @@
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from langchain_community.indexes.base import RecordManager
|
||||
|
||||
# Message used for the ImportError raised when the optional `pymongo`
# dependency (sync client) cannot be imported.
IMPORT_PYMONGO_ERROR = (
    "Could not import MongoClient. Please install it with `pip install pymongo`."
)
# Message used for the ImportError raised when the optional `motor`
# dependency (async client) cannot be imported.
IMPORT_MOTOR_ASYNCIO_ERROR = (
    "Could not import AsyncIOMotorClient. Please install it with `pip install motor`."
)
|
||||
|
||||
|
||||
def _import_pymongo() -> Any:
    """Import PyMongo lazily and return the ``MongoClient`` class.

    Returns:
        The ``pymongo.MongoClient`` class (not an instance).

    Raises:
        ImportError: If ``pymongo`` is not installed. The message contains
            install instructions and the original failure is chained as the
            cause.
    """
    try:
        from pymongo import MongoClient
    except ImportError as e:
        # Chain explicitly so the underlying import failure stays visible.
        raise ImportError(IMPORT_PYMONGO_ERROR) from e
    return MongoClient
|
||||
|
||||
|
||||
def _get_pymongo_client(mongodb_url: str, **kwargs: Any) -> Any:
    """Create a ``MongoClient`` for synchronous operations.

    Args:
        mongodb_url: Connection string handed to ``MongoClient``.
        **kwargs: Extra keyword arguments forwarded to ``MongoClient``.

    Returns:
        The constructed client instance.

    Raises:
        ImportError: If ``pymongo`` is not installed, or if constructing the
            client raises ``ValueError``. The exception type is kept for
            backward compatibility; the ``ValueError`` is chained as the cause.
    """
    # Resolved outside the ``try``: a missing dependency is reported by the
    # import helper itself and is unrelated to the URL format check below.
    client_cls = _import_pymongo()
    try:
        client = client_cls(mongodb_url, **kwargs)
    except ValueError as e:
        raise ImportError(
            f"MongoClient string provided is not in proper format. Got error: {e} "
        ) from e
    return client
|
||||
|
||||
|
||||
def _import_motor_asyncio() -> Any:
    """Import Motor lazily and return the ``AsyncIOMotorClient`` class.

    Returns:
        The ``motor.motor_asyncio.AsyncIOMotorClient`` class (not an instance).

    Raises:
        ImportError: If ``motor`` is not installed. The message contains
            install instructions and the original failure is chained as the
            cause.
    """
    try:
        from motor.motor_asyncio import AsyncIOMotorClient
    except ImportError as e:
        # Chain explicitly so the underlying import failure stays visible.
        raise ImportError(IMPORT_MOTOR_ASYNCIO_ERROR) from e
    return AsyncIOMotorClient
|
||||
|
||||
|
||||
def _get_motor_client(mongodb_url: str, **kwargs: Any) -> Any:
    """Create an ``AsyncIOMotorClient`` for asynchronous operations.

    Args:
        mongodb_url: Connection string handed to ``AsyncIOMotorClient``.
        **kwargs: Extra keyword arguments forwarded to ``AsyncIOMotorClient``.

    Returns:
        The constructed client instance.

    Raises:
        ImportError: If ``motor`` is not installed, or if constructing the
            client raises ``ValueError``. The exception type is kept for
            backward compatibility; the ``ValueError`` is chained as the cause.
    """
    # Resolved outside the ``try``: a missing dependency is reported by the
    # import helper itself and is unrelated to the URL format check below.
    client_cls = _import_motor_asyncio()
    try:
        client = client_cls(mongodb_url, **kwargs)
    except ValueError as e:
        raise ImportError(
            f"AsyncIOMotorClient string provided is not in proper format. "
            f"Got error: {e} "
        ) from e
    return client
|
||||
|
||||
|
||||
class MongoDocumentManager(RecordManager):
    """A MongoDB based implementation of the document manager.

    Every tracked key is stored as one document carrying the fields
    ``namespace``, ``key``, ``group_id`` and ``updated_at``. Both a
    synchronous (PyMongo) and an asynchronous (Motor) client are created
    against the same database and collection.
    """

    def __init__(
        self,
        namespace: str,
        *,
        mongodb_url: str,
        db_name: str,
        collection_name: str = "documentMetadata",
    ) -> None:
        """Initialize the MongoDocumentManager.

        Args:
            namespace: The namespace associated with this document manager.
            mongodb_url: The MongoDB connection string used for both the
                sync and the async client.
            db_name: The name of the database to use.
            collection_name: The name of the collection to use.
                Default is 'documentMetadata'.

        Raises:
            ImportError: If ``pymongo`` or ``motor`` cannot be imported, or
                if a client rejects ``mongodb_url`` with a ``ValueError``.
        """
        super().__init__(namespace=namespace)
        self.sync_client = _get_pymongo_client(mongodb_url)
        self.sync_db = self.sync_client[db_name]
        self.sync_collection = self.sync_db[collection_name]
        self.async_client = _get_motor_client(mongodb_url)
        self.async_db = self.async_client[db_name]
        self.async_collection = self.async_db[collection_name]

    def create_schema(self) -> None:
        """Create the database schema for the document manager.

        Intentionally a no-op: this implementation performs no setup here.
        """

    async def acreate_schema(self) -> None:
        """Create the database schema for the document manager.

        Intentionally a no-op: this implementation performs no setup here.
        """

    # ------------------------------------------------------------------
    # Private helpers shared by the sync and async code paths.
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_group_ids(
        keys: Sequence[str],
        group_ids: Optional[Sequence[Optional[str]]],
    ) -> Sequence[Optional[str]]:
        """Return one group id per key, defaulting to ``None`` for each key."""
        if group_ids is None:
            group_ids = [None] * len(keys)

        if len(keys) != len(group_ids):
            raise ValueError("Number of keys does not match number of group_ids")
        return group_ids

    @staticmethod
    def _check_time(update_time: float, time_at_least: Optional[float]) -> None:
        """Raise if the server clock is behind the caller's lower bound."""
        if time_at_least is not None and update_time < time_at_least:
            raise ValueError("Server time is behind the expected time_at_least")

    def _keys_filter(self, keys: Sequence[str]) -> Dict[str, Any]:
        """Build the filter matching ``keys`` inside this namespace."""
        return {"namespace": self.namespace, "key": {"$in": list(keys)}}

    def _list_query(
        self,
        before: Optional[float],
        after: Optional[float],
        group_ids: Optional[Sequence[str]],
    ) -> Dict[str, Any]:
        """Build the filter used by ``list_keys`` / ``alist_keys``."""
        query: Dict[str, Any] = {"namespace": self.namespace}

        # Both bounds go into a single sub-document so that supplying
        # ``before`` and ``after`` together yields a range instead of the
        # second bound replacing the first.
        updated_at: Dict[str, float] = {}
        if before is not None:
            updated_at["$lt"] = before
        if after is not None:
            updated_at["$gt"] = after
        if updated_at:
            query["updated_at"] = updated_at

        if group_ids:
            query["group_id"] = {"$in": list(group_ids)}
        return query

    # ------------------------------------------------------------------
    # Upserts
    # ------------------------------------------------------------------

    def update(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Upsert documents into the MongoDB collection.

        Args:
            keys: Keys to upsert.
            group_ids: Optional group id per key; must match ``keys`` in length.
            time_at_least: If given, the server time must not be earlier
                than this value.

        Raises:
            ValueError: If the lengths of ``keys`` and ``group_ids`` differ,
                or if the server time is behind ``time_at_least``.
        """
        group_ids = self._resolve_group_ids(keys, group_ids)

        # Read the server time once so the whole batch shares one timestamp
        # (as ``aupdate`` does) and only one extra round trip is needed.
        update_time = self.get_time()
        self._check_time(update_time, time_at_least)

        for key, group_id in zip(keys, group_ids):
            self.sync_collection.find_one_and_update(
                {"namespace": self.namespace, "key": key},
                {"$set": {"group_id": group_id, "updated_at": update_time}},
                upsert=True,
            )

    async def aupdate(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Asynchronously upsert documents into the MongoDB collection.

        Args:
            keys: Keys to upsert.
            group_ids: Optional group id per key; must match ``keys`` in length.
            time_at_least: If given, the server time must not be earlier
                than this value.

        Raises:
            ValueError: If the lengths of ``keys`` and ``group_ids`` differ,
                or if the server time is behind ``time_at_least``.
        """
        group_ids = self._resolve_group_ids(keys, group_ids)

        update_time = await self.aget_time()
        self._check_time(update_time, time_at_least)

        for key, group_id in zip(keys, group_ids):
            await self.async_collection.find_one_and_update(
                {"namespace": self.namespace, "key": key},
                {"$set": {"group_id": group_id, "updated_at": update_time}},
                upsert=True,
            )

    # ------------------------------------------------------------------
    # Server time
    # ------------------------------------------------------------------

    def get_time(self) -> float:
        """Get the current server time as a timestamp.

        Reads ``system.currentTime`` from the ``hostInfo`` command result.
        """
        server_info = self.sync_db.command("hostInfo")
        local_time = server_info["system"]["currentTime"]
        # NOTE(review): if the driver returns a naive datetime here,
        # ``timestamp()`` interprets it in the local timezone - confirm
        # whether an explicit UTC conversion is needed.
        return local_time.timestamp()

    async def aget_time(self) -> float:
        """Asynchronously get the current server time as a timestamp.

        Reads ``system.currentTime`` from the ``hostInfo`` command result.
        """
        host_info = await self.async_db.command("hostInfo")
        local_time = host_info["system"]["currentTime"]
        # NOTE(review): same naive-datetime caveat as in ``get_time``.
        return local_time.timestamp()

    # ------------------------------------------------------------------
    # Lookups
    # ------------------------------------------------------------------

    def exists(self, keys: Sequence[str]) -> List[bool]:
        """Check if the given keys exist in the MongoDB collection.

        Returns:
            One boolean per entry of ``keys``, in the same order.
        """
        cursor = self.sync_collection.find(self._keys_filter(keys), {"key": 1})
        existing_keys = {doc["key"] for doc in cursor}
        return [key in existing_keys for key in keys]

    async def aexists(self, keys: Sequence[str]) -> List[bool]:
        """Asynchronously check if the given keys exist in the MongoDB collection.

        Returns:
            One boolean per entry of ``keys``, in the same order.
        """
        cursor = self.async_collection.find(self._keys_filter(keys), {"key": 1})
        existing_keys = {doc["key"] async for doc in cursor}
        return [key in existing_keys for key in keys]

    def list_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """List documents in the MongoDB collection based on the provided date range.

        Args:
            before: Only keys with ``updated_at`` strictly below this value.
            after: Only keys with ``updated_at`` strictly above this value.
            group_ids: Only keys whose ``group_id`` is one of these values.
            limit: Maximum number of keys to return; falsy means no limit.

        Returns:
            The matching keys within this namespace.
        """
        cursor = self.sync_collection.find(
            self._list_query(before, after, group_ids), {"key": 1}
        )
        if limit:
            cursor = cursor.limit(limit)
        return [doc["key"] for doc in cursor]

    async def alist_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """
        Asynchronously list documents in the MongoDB collection
        based on the provided date range.

        Args:
            before: Only keys with ``updated_at`` strictly below this value.
            after: Only keys with ``updated_at`` strictly above this value.
            group_ids: Only keys whose ``group_id`` is one of these values.
            limit: Maximum number of keys to return; falsy means no limit.

        Returns:
            The matching keys within this namespace.
        """
        cursor = self.async_collection.find(
            self._list_query(before, after, group_ids), {"key": 1}
        )
        if limit:
            cursor = cursor.limit(limit)
        return [doc["key"] async for doc in cursor]

    # ------------------------------------------------------------------
    # Deletion
    # ------------------------------------------------------------------

    def delete_keys(self, keys: Sequence[str]) -> None:
        """Delete documents from the MongoDB collection.

        Only documents in this manager's namespace are affected.
        """
        self.sync_collection.delete_many(self._keys_filter(keys))

    async def adelete_keys(self, keys: Sequence[str]) -> None:
        """Asynchronously delete documents from the MongoDB collection.

        Only documents in this manager's namespace are affected.
        """
        await self.async_collection.delete_many(self._keys_filter(keys))
|
||||
@@ -0,0 +1,525 @@
|
||||
"""Implementation of a record management layer in SQLAlchemy.
|
||||
|
||||
The management layer uses SQLAlchemy to track upserted records.
|
||||
|
||||
Currently, this layer works with SQLite and PostgreSQL; however, it should be adaptable
|
||||
to other SQL implementations with minimal effort.
|
||||
|
||||
Currently, includes an implementation that uses SQLAlchemy which should
|
||||
allow it to work with a variety of SQL as a backend.
|
||||
|
||||
* Each key is associated with an updated_at field.
|
||||
* This field is updated whenever the key is updated.
|
||||
* Keys can be listed based on the updated at field.
|
||||
* Keys can be deleted.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import decimal
|
||||
import uuid
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Float,
|
||||
Index,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
and_,
|
||||
create_engine,
|
||||
delete,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.engine import URL, Engine
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
# `async_sessionmaker` is imported when available; otherwise a placeholder is
# bound to the same name so later `isinstance(..., async_sessionmaker)` checks
# in this module still work.
try:
    from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
    # dummy for sqlalchemy < 2
    async_sessionmaker = type("async_sessionmaker", (type,), {})  # type: ignore[assignment,misc]
|
||||
|
||||
from langchain_community.indexes.base import RecordManager
|
||||
|
||||
# Declarative base class for the ORM models defined in this module; its
# metadata is what `create_schema` / `acreate_schema` create tables from.
Base = declarative_base()
|
||||
|
||||
|
||||
class UpsertionRecord(Base):  # type: ignore[valid-type,misc]
    """Table used to keep track of when a key was last updated.

    Each row records one key within one namespace together with an optional
    group id and the timestamp of its last upsertion. The pair
    (key, namespace) is unique.
    """

    # ATTENTION:
    # Prior to modifying this table, please determine whether
    # we should create migrations for this table to make sure
    # users do not experience data loss.
    __tablename__ = "upsertion_record"

    # Surrogate primary key, generated as a random UUID4 string.
    # The lambda runs at insert time, so `uuid` inside it resolves to the
    # module-level `uuid` import, not to this class attribute of the same name.
    uuid = Column(
        String,
        index=True,
        default=lambda: str(uuid.uuid4()),
        primary_key=True,
        nullable=False,
    )
    # The record key being tracked.
    key = Column(String, index=True)
    # Using a non-normalized representation to handle `namespace` attribute.
    # If the need arises, this attribute can be pulled into a separate Collection
    # table at some time later.
    namespace = Column(String, index=True, nullable=False)
    # Optional group id associated with the key; used as a filter when listing.
    group_id = Column(String, index=True, nullable=True)

    # The timestamp associated with the last record upsertion.
    updated_at = Column(Float, index=True)

    __table_args__ = (
        # The upsert statements rely on this constraint (by columns for
        # SQLite, by name for PostgreSQL) to detect conflicts.
        UniqueConstraint("key", "namespace", name="uix_key_namespace"),
        Index("ix_key_namespace", "key", "namespace"),
    )
|
||||
|
||||
|
||||
class SQLRecordManager(RecordManager):
    """A SQL Alchemy based implementation of the record manager.

    Server time lookup and upserts are implemented for the ``sqlite`` and
    ``postgresql`` dialects; other dialects raise ``NotImplementedError``.
    Sync methods require a sync engine and async methods an async engine.
    """

    def __init__(
        self,
        namespace: str,
        *,
        engine: Optional[Union[Engine, AsyncEngine]] = None,
        db_url: Union[None, str, URL] = None,
        engine_kwargs: Optional[Dict[str, Any]] = None,
        async_mode: bool = False,
    ) -> None:
        """Initialize the SQLRecordManager.

        This class serves as a manager persistence layer that uses an SQL
        backend to track upserted records. You should specify either a db_url
        to create an engine or provide an existing engine.

        Args:
            namespace: The namespace associated with this record manager.
            engine: An already existing SQL Alchemy engine.
                Default is None.
            db_url: A database connection string used to create
                an SQL Alchemy engine. Default is None.
            engine_kwargs: Additional keyword arguments
                to be passed when creating the engine. Default is an empty dictionary.
            async_mode: Whether to create an async engine.
                Driver should support async operations.
                It only applies if db_url is provided.
                Default is False.

        Raises:
            ValueError: If both db_url and engine are provided or neither.
            AssertionError: If something unexpected happens during engine configuration.
        """
        super().__init__(namespace=namespace)
        if db_url is None and engine is None:
            raise ValueError("Must specify either db_url or engine")

        if db_url is not None and engine is not None:
            raise ValueError("Must specify either db_url or engine, not both")

        _engine: Union[Engine, AsyncEngine]
        if db_url:
            if async_mode:
                _engine = create_async_engine(db_url, **(engine_kwargs or {}))
            else:
                _engine = create_engine(db_url, **(engine_kwargs or {}))
        elif engine:
            _engine = engine
        else:
            raise AssertionError("Something went wrong with configuration of engine.")

        _session_factory: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
        if isinstance(_engine, AsyncEngine):
            _session_factory = async_sessionmaker(bind=_engine)
        else:
            _session_factory = sessionmaker(bind=_engine)

        self.engine = _engine
        self.dialect = _engine.dialect.name
        self.session_factory = _session_factory

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------

    def create_schema(self) -> None:
        """Create the database schema.

        Raises:
            AssertionError: If the manager was configured with an async engine.
        """
        if isinstance(self.engine, AsyncEngine):
            raise AssertionError("This method is not supported for async engines.")

        Base.metadata.create_all(self.engine)

    async def acreate_schema(self) -> None:
        """Create the database schema.

        Raises:
            AssertionError: If the manager was configured with a sync engine.
        """
        if not isinstance(self.engine, AsyncEngine):
            raise AssertionError("This method is not supported for sync engines.")

        async with self.engine.begin() as session:
            await session.run_sync(Base.metadata.create_all)

    # ------------------------------------------------------------------
    # Sessions
    # ------------------------------------------------------------------

    @contextlib.contextmanager
    def _make_session(self) -> Generator[Session, None, None]:
        """Create a session and close it after use."""
        if isinstance(self.session_factory, async_sessionmaker):
            raise AssertionError("This method is not supported for async engines.")

        session = self.session_factory()
        try:
            yield session
        finally:
            session.close()

    @contextlib.asynccontextmanager
    async def _amake_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Create a session and close it after use."""
        if not isinstance(self.engine, AsyncEngine):
            raise AssertionError("This method is not supported for sync engines.")

        async with cast(AsyncSession, self.session_factory()) as session:
            yield session

    # ------------------------------------------------------------------
    # Private helpers shared by the sync and async code paths.
    # ------------------------------------------------------------------

    def _server_time_query(self) -> Any:
        """Return the dialect-specific statement that yields the epoch time."""
        # * For SQLite, unlike unixepoch it will work with older versions of SQLite.
        # ----
        # julianday('now'): Julian day number for the current date and time.
        # The Julian day is a continuous count of days, starting from a
        # reference date (Julian day number 0).
        # 2440587.5 - constant represents the Julian day number for January 1, 1970
        # 86400.0 - constant represents the number of seconds
        # in a day (24 hours * 60 minutes * 60 seconds)
        if self.dialect == "sqlite":
            return text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
        if self.dialect == "postgresql":
            return text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
        raise NotImplementedError(f"Not implemented for dialect {self.dialect}")

    @staticmethod
    def _as_timestamp(dt: Any) -> float:
        """Normalize the scalar returned by the time query to a float."""
        if isinstance(dt, decimal.Decimal):
            dt = float(dt)
        if not isinstance(dt, float):
            raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
        return dt

    @staticmethod
    def _resolve_group_ids(
        keys: Sequence[str],
        group_ids: Optional[Sequence[Optional[str]]],
    ) -> Sequence[Optional[str]]:
        """Return one group id per key, defaulting to ``None`` for each key."""
        if group_ids is None:
            group_ids = [None] * len(keys)

        if len(keys) != len(group_ids):
            raise ValueError(
                f"Number of keys ({len(keys)}) does not match number of "
                f"group_ids ({len(group_ids)})"
            )
        return group_ids

    @staticmethod
    def _check_time(update_time: float, time_at_least: Optional[float]) -> None:
        """Safeguard against time sync issues between caller and server."""
        if time_at_least is not None and update_time < time_at_least:
            raise AssertionError(f"Time sync issue: {update_time} < {time_at_least}")

    def _records_to_upsert(
        self,
        keys: Sequence[str],
        group_ids: Sequence[Optional[str]],
        update_time: float,
    ) -> List[Dict[str, Any]]:
        """Build the row dictionaries for one upsert batch."""
        return [
            {
                "key": key,
                "namespace": self.namespace,
                "updated_at": update_time,
                "group_id": group_id,
            }
            for key, group_id in zip(keys, group_ids)
        ]

    def _upsert_statement(self, records: List[Dict[str, Any]]) -> Any:
        """Build the dialect-specific INSERT ... ON CONFLICT DO UPDATE statement."""
        if self.dialect == "sqlite":
            from sqlalchemy.dialects.sqlite import insert as sqlite_insert

            # Note: uses SQLite insert to make on_conflict_do_update work.
            sqlite_stmt = sqlite_insert(UpsertionRecord).values(records)
            return sqlite_stmt.on_conflict_do_update(
                [UpsertionRecord.key, UpsertionRecord.namespace],
                set_=dict(
                    updated_at=sqlite_stmt.excluded.updated_at,
                    group_id=sqlite_stmt.excluded.group_id,
                ),
            )
        if self.dialect == "postgresql":
            from sqlalchemy.dialects.postgresql import insert as pg_insert

            # Note: uses PostgreSQL insert to make on_conflict_do_update work.
            pg_stmt = pg_insert(UpsertionRecord).values(records)
            return pg_stmt.on_conflict_do_update(
                "uix_key_namespace",  # Name of constraint
                set_=dict(
                    updated_at=pg_stmt.excluded.updated_at,
                    group_id=pg_stmt.excluded.group_id,
                ),
            )
        raise NotImplementedError(f"Unsupported dialect {self.dialect}")

    def _apply_list_filters(
        self,
        query: Any,
        *,
        before: Optional[float],
        after: Optional[float],
        group_ids: Optional[Sequence[str]],
        limit: Optional[int],
    ) -> Any:
        """Apply the namespace, time, group and limit filters to ``query``.

        Works for both the ORM ``Query`` and the ``select()`` construct
        because only ``filter`` and ``limit`` are used.
        """
        query = query.filter(UpsertionRecord.namespace == self.namespace)
        if after is not None:
            query = query.filter(UpsertionRecord.updated_at > after)
        if before is not None:
            query = query.filter(UpsertionRecord.updated_at < before)
        if group_ids:
            query = query.filter(UpsertionRecord.group_id.in_(group_ids))
        if limit:
            query = query.limit(limit)
        return query

    def _keys_condition(self, keys: Sequence[str]) -> Any:
        """Build the condition matching ``keys`` inside this namespace."""
        return and_(
            UpsertionRecord.key.in_(keys),
            UpsertionRecord.namespace == self.namespace,
        )

    # ------------------------------------------------------------------
    # Server time
    # ------------------------------------------------------------------

    def get_time(self) -> float:
        """Get the current server time as a timestamp.

        Please note it's critical that time is obtained from the server since
        we want a monotonic clock.

        Raises:
            NotImplementedError: If the dialect is neither sqlite nor postgresql.
            AssertionError: If the engine is async or the result is not numeric.
        """
        with self._make_session() as session:
            dt = session.execute(self._server_time_query()).scalar()
            return self._as_timestamp(dt)

    async def aget_time(self) -> float:
        """Get the current server time as a timestamp.

        Please note it's critical that time is obtained from the server since
        we want a monotonic clock.

        Raises:
            NotImplementedError: If the dialect is neither sqlite nor postgresql.
            AssertionError: If the engine is sync or the result is not numeric.
        """
        async with self._amake_session() as session:
            result = await session.execute(self._server_time_query())
            return self._as_timestamp(result.scalar_one_or_none())

    # ------------------------------------------------------------------
    # Upserts
    # ------------------------------------------------------------------

    def update(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Upsert records into the database.

        Args:
            keys: Keys to upsert.
            group_ids: Optional group id per key; must match ``keys`` in length.
            time_at_least: If given, the server time must not be earlier
                than this value.

        Raises:
            ValueError: If the lengths of ``keys`` and ``group_ids`` differ.
            AssertionError: If the server time is behind ``time_at_least``.
            NotImplementedError: If the dialect is not supported.
        """
        group_ids = self._resolve_group_ids(keys, group_ids)

        # Get the current time from the server.
        # This makes an extra round trip to the server, should not be a big deal
        # if the batch size is large enough.
        # Getting the time here helps us compare it against the time_at_least
        # and raise an error if there is a time sync issue.
        # Here, we're just being extra careful to minimize the chance of
        # data loss due to incorrectly deleting records.
        update_time = self.get_time()
        self._check_time(update_time, time_at_least)

        if not keys:
            # Nothing to upsert; avoid building an INSERT with no rows.
            return

        records = self._records_to_upsert(keys, group_ids, update_time)
        with self._make_session() as session:
            session.execute(self._upsert_statement(records))
            session.commit()

    async def aupdate(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Upsert records into the database.

        Args:
            keys: Keys to upsert.
            group_ids: Optional group id per key; must match ``keys`` in length.
            time_at_least: If given, the server time must not be earlier
                than this value.

        Raises:
            ValueError: If the lengths of ``keys`` and ``group_ids`` differ.
            AssertionError: If the server time is behind ``time_at_least``.
            NotImplementedError: If the dialect is not supported.
        """
        group_ids = self._resolve_group_ids(keys, group_ids)

        # See ``update`` for why the server time is fetched up front.
        update_time = await self.aget_time()
        self._check_time(update_time, time_at_least)

        if not keys:
            # Nothing to upsert; avoid building an INSERT with no rows.
            return

        records = self._records_to_upsert(keys, group_ids, update_time)
        async with self._amake_session() as session:
            await session.execute(self._upsert_statement(records))
            await session.commit()

    # ------------------------------------------------------------------
    # Lookups
    # ------------------------------------------------------------------

    def exists(self, keys: Sequence[str]) -> List[bool]:
        """Check if the given keys exist in the database.

        Returns:
            One boolean per entry of ``keys``, in the same order.
        """
        with self._make_session() as session:
            records = (
                session.query(UpsertionRecord.key)
                .filter(self._keys_condition(keys))
                .all()
            )
        found_keys = {r.key for r in records}
        return [k in found_keys for k in keys]

    async def aexists(self, keys: Sequence[str]) -> List[bool]:
        """Check if the given keys exist in the database.

        Returns:
            One boolean per entry of ``keys``, in the same order.
        """
        async with self._amake_session() as session:
            result = await session.execute(
                select(UpsertionRecord.key).where(self._keys_condition(keys))
            )
            found_keys = set(result.scalars().all())
        return [k in found_keys for k in keys]

    def list_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """List records in the database based on the provided date range.

        Args:
            before: Only keys with ``updated_at`` strictly below this value.
            after: Only keys with ``updated_at`` strictly above this value.
            group_ids: Only keys whose ``group_id`` is one of these values.
            limit: Maximum number of keys to return; falsy means no limit.

        Returns:
            The matching keys within this namespace.
        """
        with self._make_session() as session:
            # Select only the key column (as ``alist_keys`` does) instead of
            # loading full ORM rows that are immediately discarded.
            query = self._apply_list_filters(
                session.query(UpsertionRecord.key),
                before=before,
                after=after,
                group_ids=group_ids,
                limit=limit,
            )
            return [r.key for r in query.all()]

    async def alist_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """List records in the database based on the provided date range.

        Args:
            before: Only keys with ``updated_at`` strictly below this value.
            after: Only keys with ``updated_at`` strictly above this value.
            group_ids: Only keys whose ``group_id`` is one of these values.
            limit: Maximum number of keys to return; falsy means no limit.

        Returns:
            The matching keys within this namespace.
        """
        async with self._amake_session() as session:
            query = self._apply_list_filters(
                select(UpsertionRecord.key),
                before=before,
                after=after,
                group_ids=group_ids,
                limit=limit,
            )
            records = (await session.execute(query)).scalars().all()
            return list(records)

    # ------------------------------------------------------------------
    # Deletion
    # ------------------------------------------------------------------

    def delete_keys(self, keys: Sequence[str]) -> None:
        """Delete records from the database.

        Only records in this manager's namespace are affected.
        """
        with self._make_session() as session:
            session.query(UpsertionRecord).filter(
                self._keys_condition(keys)
            ).delete()
            session.commit()

    async def adelete_keys(self, keys: Sequence[str]) -> None:
        """Delete records from the database.

        Only records in this manager's namespace are affected.
        """
        async with self._amake_session() as session:
            await session.execute(
                delete(UpsertionRecord).where(self._keys_condition(keys))
            )
            await session.commit()
|
||||
172
venv/Lib/site-packages/langchain_community/indexes/base.py
Normal file
172
venv/Lib/site-packages/langchain_community/indexes/base.py
Normal file
@@ -0,0 +1,172 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Sequence
|
||||
|
||||
# Fixed UUID built from the integer 1984. It is not used within this module;
# presumably it serves as a namespace for deterministic UUID generation
# elsewhere - TODO confirm against callers.
NAMESPACE_UUID = uuid.UUID(int=1984)
|
||||
|
||||
|
||||
class RecordManager(ABC):
    """Interface that every record manager backend must implement.

    A record manager tracks keys inside a namespace. Each operation is
    declared twice: a blocking variant and an ``a``-prefixed coroutine
    variant with the same parameters.
    """

    def __init__(self, namespace: str) -> None:
        """Store the namespace this manager operates in.

        Args:
            namespace (str): The namespace for the record manager.
        """
        self.namespace = namespace

    # -- schema ---------------------------------------------------------

    @abstractmethod
    def create_schema(self) -> None:
        """Set up the database schema needed by the record manager."""

    @abstractmethod
    async def acreate_schema(self) -> None:
        """Set up the database schema needed by the record manager."""

    # -- server time ----------------------------------------------------

    @abstractmethod
    def get_time(self) -> float:
        """Return the server's current time as a high resolution timestamp.

        The value must come from the server so that the clock is monotonic;
        otherwise cleaning up old documents may lose data.

        Returns:
            The current server time as a float timestamp.
        """

    @abstractmethod
    async def aget_time(self) -> float:
        """Return the server's current time as a high resolution timestamp.

        The value must come from the server so that the clock is monotonic;
        otherwise cleaning up old documents may lose data.

        Returns:
            The current server time as a float timestamp.
        """

    # -- upserts --------------------------------------------------------

    @abstractmethod
    def update(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Insert or refresh the given records.

        Args:
            keys: A list of record keys to upsert.
            group_ids: A list of group IDs corresponding to the keys.
            time_at_least: if provided, updates should only happen if the
                updated_at field is at least this time.

        Raises:
            ValueError: If the length of keys doesn't match the length of group_ids.
        """

    @abstractmethod
    async def aupdate(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Insert or refresh the given records.

        Args:
            keys: A list of record keys to upsert.
            group_ids: A list of group IDs corresponding to the keys.
            time_at_least: if provided, updates should only happen if the
                updated_at field is at least this time.

        Raises:
            ValueError: If the length of keys doesn't match the length of group_ids.
        """

    # -- lookups --------------------------------------------------------

    @abstractmethod
    def exists(self, keys: Sequence[str]) -> List[bool]:
        """Report, per key, whether it is present in the database.

        Args:
            keys: A list of keys to check.

        Returns:
            A list of boolean values indicating the existence of each key.
        """

    @abstractmethod
    async def aexists(self, keys: Sequence[str]) -> List[bool]:
        """Report, per key, whether it is present in the database.

        Args:
            keys: A list of keys to check.

        Returns:
            A list of boolean values indicating the existence of each key.
        """

    @abstractmethod
    def list_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """Return the keys of the records matching the given filters.

        Args:
            before: Filter to list records updated before this time.
            after: Filter to list records updated after this time.
            group_ids: Filter to list records with specific group IDs.
            limit: optional limit on the number of records to return.

        Returns:
            A list of keys for the matching records.
        """

    @abstractmethod
    async def alist_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """Return the keys of the records matching the given filters.

        Args:
            before: Filter to list records updated before this time.
            after: Filter to list records updated after this time.
            group_ids: Filter to list records with specific group IDs.
            limit: optional limit on the number of records to return.

        Returns:
            A list of keys for the matching records.
        """

    # -- deletion -------------------------------------------------------

    @abstractmethod
    def delete_keys(self, keys: Sequence[str]) -> None:
        """Remove the given records from the database.

        Args:
            keys: A list of keys to delete.
        """

    @abstractmethod
    async def adelete_keys(self, keys: Sequence[str]) -> None:
        """Remove the given records from the database.

        Args:
            keys: A list of keys to delete.
        """
|
||||
Reference in New Issue
Block a user