initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,52 @@
"""**Indexes**.
**Index** is used to avoid writing duplicated content
into the vectostore and to avoid over-writing content if it's unchanged.
Indexes also :
* Create knowledge graphs from data.
* Support indexing workflows from LangChain data loaders to vectorstores.
Importantly, Index keeps on working even if the content being written is derived
via a set of transformations from some source content (e.g., indexing children
documents that were derived from parent documents by chunking.)
"""
from typing import TYPE_CHECKING, Any
from langchain_core.indexing.api import IndexingResult, aindex, index
from langchain_classic._api import create_importer
from langchain_classic.indexes._sql_record_manager import SQLRecordManager
from langchain_classic.indexes.vectorstore import VectorstoreIndexCreator
if TYPE_CHECKING:
from langchain_community.graphs.index_creator import GraphIndexCreator
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
"GraphIndexCreator": "langchain_community.graphs.index_creator",
}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"GraphIndexCreator",
"IndexingResult",
"SQLRecordManager",
"VectorstoreIndexCreator",
# Keep sorted
"aindex",
"index",
]

View File

@@ -0,0 +1,5 @@
from langchain_core.indexing.api import _abatch, _batch, _HashedDocument
# Please do not use these in your application. These are private APIs.
# Here to avoid changing unit tests during a migration.
__all__ = ["_HashedDocument", "_abatch", "_batch"]

View File

@@ -0,0 +1,532 @@
"""Implementation of a record management layer in SQLAlchemy.
The management layer uses SQLAlchemy to track upserted records.
Currently, this layer only works with SQLite; hopwever, should be adaptable
to other SQL implementations with minimal effort.
Currently, includes an implementation that uses SQLAlchemy which should
allow it to work with a variety of SQL as a backend.
* Each key is associated with an updated_at field.
* This filed is updated whenever the key is updated.
* Keys can be listed based on the updated at field.
* Keys can be deleted.
"""
import contextlib
import decimal
import uuid
from collections.abc import AsyncGenerator, Generator, Sequence
from typing import Any
from langchain_core.indexing import RecordManager
from sqlalchemy import (
Column,
Float,
Index,
String,
UniqueConstraint,
and_,
create_engine,
delete,
select,
text,
)
from sqlalchemy.engine import URL, Engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
create_async_engine,
)
from sqlalchemy.orm import Query, Session, declarative_base, sessionmaker
try:
from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
# dummy for sqlalchemy < 2
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore[assignment,misc]
Base = declarative_base()
class UpsertionRecord(Base): # type: ignore[valid-type,misc]
"""Table used to keep track of when a key was last updated."""
# ATTENTION:
# Prior to modifying this table, please determine whether
# we should create migrations for this table to make sure
# users do not experience data loss.
__tablename__ = "upsertion_record"
uuid = Column(
String,
index=True,
default=lambda: str(uuid.uuid4()),
primary_key=True,
nullable=False,
)
key = Column(String, index=True)
# Using a non-normalized representation to handle `namespace` attribute.
# If the need arises, this attribute can be pulled into a separate Collection
# table at some time later.
namespace = Column(String, index=True, nullable=False)
group_id = Column(String, index=True, nullable=True)
# The timestamp associated with the last record upsertion.
updated_at = Column(Float, index=True)
__table_args__ = (
UniqueConstraint("key", "namespace", name="uix_key_namespace"),
Index("ix_key_namespace", "key", "namespace"),
)
class SQLRecordManager(RecordManager):
"""A SQL Alchemy based implementation of the record manager."""
def __init__(
self,
namespace: str,
*,
engine: Engine | AsyncEngine | None = None,
db_url: None | str | URL = None,
engine_kwargs: dict[str, Any] | None = None,
async_mode: bool = False,
) -> None:
"""Initialize the SQLRecordManager.
This class serves as a manager persistence layer that uses an SQL
backend to track upserted records. You should specify either a `db_url`
to create an engine or provide an existing engine.
Args:
namespace: The namespace associated with this record manager.
engine: An already existing SQL Alchemy engine.
db_url: A database connection string used to create an SQL Alchemy engine.
engine_kwargs: Additional keyword arguments to be passed when creating the
engine.
async_mode: Whether to create an async engine. Driver should support async
operations. It only applies if `db_url` is provided.
Raises:
ValueError: If both db_url and engine are provided or neither.
AssertionError: If something unexpected happens during engine configuration.
"""
super().__init__(namespace=namespace)
if db_url is None and engine is None:
msg = "Must specify either db_url or engine"
raise ValueError(msg)
if db_url is not None and engine is not None:
msg = "Must specify either db_url or engine, not both"
raise ValueError(msg)
_engine: Engine | AsyncEngine
if db_url:
if async_mode:
_engine = create_async_engine(db_url, **(engine_kwargs or {}))
else:
_engine = create_engine(db_url, **(engine_kwargs or {}))
elif engine:
_engine = engine
else:
msg = "Something went wrong with configuration of engine."
raise AssertionError(msg)
_session_factory: sessionmaker[Session] | async_sessionmaker[AsyncSession]
if isinstance(_engine, AsyncEngine):
_session_factory = async_sessionmaker(bind=_engine)
else:
_session_factory = sessionmaker(bind=_engine)
self.engine = _engine
self.dialect = _engine.dialect.name
self.session_factory = _session_factory
def create_schema(self) -> None:
"""Create the database schema."""
if isinstance(self.engine, AsyncEngine):
msg = "This method is not supported for async engines."
raise AssertionError(msg) # noqa: TRY004
Base.metadata.create_all(self.engine)
async def acreate_schema(self) -> None:
"""Create the database schema."""
if not isinstance(self.engine, AsyncEngine):
msg = "This method is not supported for sync engines."
raise AssertionError(msg) # noqa: TRY004
async with self.engine.begin() as session:
await session.run_sync(Base.metadata.create_all)
@contextlib.contextmanager
def _make_session(self) -> Generator[Session, None, None]:
"""Create a session and close it after use."""
if isinstance(self.session_factory, async_sessionmaker):
msg = "This method is not supported for async engines."
raise AssertionError(msg) # noqa: TRY004
session = self.session_factory()
try:
yield session
finally:
session.close()
@contextlib.asynccontextmanager
async def _amake_session(self) -> AsyncGenerator[AsyncSession, None]:
"""Create a session and close it after use."""
if not isinstance(self.session_factory, async_sessionmaker):
msg = "This method is not supported for sync engines."
raise AssertionError(msg) # noqa: TRY004
async with self.session_factory() as session:
yield session
def get_time(self) -> float:
"""Get the current server time as a timestamp.
Please note it's critical that time is obtained from the server since
we want a monotonic clock.
"""
with self._make_session() as session:
# * SQLite specific implementation, can be changed based on dialect.
# * For SQLite, unlike unixepoch it will work with older versions of SQLite.
# ----
# julianday('now'): Julian day number for the current date and time.
# The Julian day is a continuous count of days, starting from a
# reference date (Julian day number 0).
# 2440587.5 - constant represents the Julian day number for January 1, 1970
# 86400.0 - constant represents the number of seconds
# in a day (24 hours * 60 minutes * 60 seconds)
if self.dialect == "sqlite":
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
elif self.dialect == "postgresql":
query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
else:
msg = f"Not implemented for dialect {self.dialect}"
raise NotImplementedError(msg)
dt = session.execute(query).scalar()
if isinstance(dt, decimal.Decimal):
dt = float(dt)
if not isinstance(dt, float):
msg = f"Unexpected type for datetime: {type(dt)}"
raise AssertionError(msg) # noqa: TRY004
return dt
async def aget_time(self) -> float:
"""Get the current server time as a timestamp.
Please note it's critical that time is obtained from the server since
we want a monotonic clock.
"""
async with self._amake_session() as session:
# * SQLite specific implementation, can be changed based on dialect.
# * For SQLite, unlike unixepoch it will work with older versions of SQLite.
# ----
# julianday('now'): Julian day number for the current date and time.
# The Julian day is a continuous count of days, starting from a
# reference date (Julian day number 0).
# 2440587.5 - constant represents the Julian day number for January 1, 1970
# 86400.0 - constant represents the number of seconds
# in a day (24 hours * 60 minutes * 60 seconds)
if self.dialect == "sqlite":
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
elif self.dialect == "postgresql":
query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
else:
msg = f"Not implemented for dialect {self.dialect}"
raise NotImplementedError(msg)
dt = (await session.execute(query)).scalar_one_or_none()
if isinstance(dt, decimal.Decimal):
dt = float(dt)
if not isinstance(dt, float):
msg = f"Unexpected type for datetime: {type(dt)}"
raise AssertionError(msg) # noqa: TRY004
return dt
def update(
self,
keys: Sequence[str],
*,
group_ids: Sequence[str | None] | None = None,
time_at_least: float | None = None,
) -> None:
"""Upsert records into the SQLite database."""
if group_ids is None:
group_ids = [None] * len(keys)
if len(keys) != len(group_ids):
msg = (
f"Number of keys ({len(keys)}) does not match number of "
f"group_ids ({len(group_ids)})"
)
raise ValueError(msg)
# Get the current time from the server.
# This makes an extra round trip to the server, should not be a big deal
# if the batch size is large enough.
# Getting the time here helps us compare it against the time_at_least
# and raise an error if there is a time sync issue.
# Here, we're just being extra careful to minimize the chance of
# data loss due to incorrectly deleting records.
update_time = self.get_time()
if time_at_least and update_time < time_at_least:
# Safeguard against time sync issues
msg = f"Time sync issue: {update_time} < {time_at_least}"
raise AssertionError(msg)
records_to_upsert = [
{
"key": key,
"namespace": self.namespace,
"updated_at": update_time,
"group_id": group_id,
}
for key, group_id in zip(keys, group_ids, strict=False)
]
with self._make_session() as session:
if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
sqlite_insert_stmt: SqliteInsertType = sqlite_insert(
UpsertionRecord,
).values(records_to_upsert)
stmt = sqlite_insert_stmt.on_conflict_do_update(
[UpsertionRecord.key, UpsertionRecord.namespace],
set_={
"updated_at": sqlite_insert_stmt.excluded.updated_at,
"group_id": sqlite_insert_stmt.excluded.group_id,
},
)
elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
from sqlalchemy.dialects.postgresql import insert as pg_insert
# Note: uses postgresql insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values(
records_to_upsert,
)
stmt = pg_insert_stmt.on_conflict_do_update( # type: ignore[assignment]
constraint="uix_key_namespace", # Name of constraint
set_={
"updated_at": pg_insert_stmt.excluded.updated_at,
"group_id": pg_insert_stmt.excluded.group_id,
},
)
else:
msg = f"Unsupported dialect {self.dialect}"
raise NotImplementedError(msg)
session.execute(stmt)
session.commit()
async def aupdate(
self,
keys: Sequence[str],
*,
group_ids: Sequence[str | None] | None = None,
time_at_least: float | None = None,
) -> None:
"""Upsert records into the SQLite database."""
if group_ids is None:
group_ids = [None] * len(keys)
if len(keys) != len(group_ids):
msg = (
f"Number of keys ({len(keys)}) does not match number of "
f"group_ids ({len(group_ids)})"
)
raise ValueError(msg)
# Get the current time from the server.
# This makes an extra round trip to the server, should not be a big deal
# if the batch size is large enough.
# Getting the time here helps us compare it against the time_at_least
# and raise an error if there is a time sync issue.
# Here, we're just being extra careful to minimize the chance of
# data loss due to incorrectly deleting records.
update_time = await self.aget_time()
if time_at_least and update_time < time_at_least:
# Safeguard against time sync issues
msg = f"Time sync issue: {update_time} < {time_at_least}"
raise AssertionError(msg)
records_to_upsert = [
{
"key": key,
"namespace": self.namespace,
"updated_at": update_time,
"group_id": group_id,
}
for key, group_id in zip(keys, group_ids, strict=False)
]
async with self._amake_session() as session:
if self.dialect == "sqlite":
from sqlalchemy.dialects.sqlite import Insert as SqliteInsertType
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
sqlite_insert_stmt: SqliteInsertType = sqlite_insert(
UpsertionRecord,
).values(records_to_upsert)
stmt = sqlite_insert_stmt.on_conflict_do_update(
[UpsertionRecord.key, UpsertionRecord.namespace],
set_={
"updated_at": sqlite_insert_stmt.excluded.updated_at,
"group_id": sqlite_insert_stmt.excluded.group_id,
},
)
elif self.dialect == "postgresql":
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
from sqlalchemy.dialects.postgresql import insert as pg_insert
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
pg_insert_stmt: PgInsertType = pg_insert(UpsertionRecord).values(
records_to_upsert,
)
stmt = pg_insert_stmt.on_conflict_do_update( # type: ignore[assignment]
constraint="uix_key_namespace", # Name of constraint
set_={
"updated_at": pg_insert_stmt.excluded.updated_at,
"group_id": pg_insert_stmt.excluded.group_id,
},
)
else:
msg = f"Unsupported dialect {self.dialect}"
raise NotImplementedError(msg)
await session.execute(stmt)
await session.commit()
def exists(self, keys: Sequence[str]) -> list[bool]:
"""Check if the given keys exist in the SQLite database."""
session: Session
with self._make_session() as session:
filtered_query: Query = session.query(UpsertionRecord.key).filter(
and_(
UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace,
),
)
records = filtered_query.all()
found_keys = {r.key for r in records}
return [k in found_keys for k in keys]
async def aexists(self, keys: Sequence[str]) -> list[bool]:
"""Check if the given keys exist in the SQLite database."""
async with self._amake_session() as session:
records = (
(
await session.execute(
select(UpsertionRecord.key).where(
and_(
UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace,
),
),
)
)
.scalars()
.all()
)
found_keys = set(records)
return [k in found_keys for k in keys]
def list_keys(
self,
*,
before: float | None = None,
after: float | None = None,
group_ids: Sequence[str] | None = None,
limit: int | None = None,
) -> list[str]:
"""List records in the SQLite database based on the provided date range."""
session: Session
with self._make_session() as session:
query: Query = session.query(UpsertionRecord).filter(
UpsertionRecord.namespace == self.namespace,
)
if after:
query = query.filter(UpsertionRecord.updated_at > after)
if before:
query = query.filter(UpsertionRecord.updated_at < before)
if group_ids:
query = query.filter(UpsertionRecord.group_id.in_(group_ids))
if limit:
query = query.limit(limit)
records = query.all()
return [r.key for r in records]
async def alist_keys(
self,
*,
before: float | None = None,
after: float | None = None,
group_ids: Sequence[str] | None = None,
limit: int | None = None,
) -> list[str]:
"""List records in the SQLite database based on the provided date range."""
session: AsyncSession
async with self._amake_session() as session:
query: Query = select(UpsertionRecord.key).filter( # type: ignore[assignment]
UpsertionRecord.namespace == self.namespace,
)
# mypy does not recognize .all() or .filter()
if after:
query = query.filter(UpsertionRecord.updated_at > after)
if before:
query = query.filter(UpsertionRecord.updated_at < before)
if group_ids:
query = query.filter(UpsertionRecord.group_id.in_(group_ids))
if limit:
query = query.limit(limit)
records = (await session.execute(query)).scalars().all()
return list(records)
def delete_keys(self, keys: Sequence[str]) -> None:
"""Delete records from the SQLite database."""
session: Session
with self._make_session() as session:
filtered_query: Query = session.query(UpsertionRecord).filter(
and_(
UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace,
),
)
filtered_query.delete()
session.commit()
async def adelete_keys(self, keys: Sequence[str]) -> None:
"""Delete records from the SQLite database."""
async with self._amake_session() as session:
await session.execute(
delete(UpsertionRecord).where(
and_(
UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace,
),
),
)
await session.commit()

View File

@@ -0,0 +1,28 @@
"""**Graphs** provide a natural language interface to graph databases."""
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.graphs.index_creator import GraphIndexCreator
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
"GraphIndexCreator": "langchain_community.graphs.index_creator",
"NetworkxEntityGraph": "langchain_community.graphs.networkx_graph",
}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = ["GraphIndexCreator", "NetworkxEntityGraph"]

View File

@@ -0,0 +1,13 @@
"""Relevant prompts for constructing indexes."""
from langchain_core._api import warn_deprecated
warn_deprecated(
since="0.1.47",
message=(
"langchain.indexes.prompts will be removed in the future."
"If you're relying on these prompts, please open an issue on "
"GitHub to explain your use case."
),
pending=True,
)

View File

@@ -0,0 +1,39 @@
from langchain_core.prompts.prompt import PromptTemplate
_DEFAULT_ENTITY_EXTRACTION_TEMPLATE = """You are an AI assistant reading the transcript of a conversation between an AI and a human. Extract all of the proper nouns from the last line of conversation. As a guideline, a proper noun is generally capitalized. You should definitely extract all names and places.
The conversation history is provided just in case of a coreference (e.g. "What do you know about him" where "him" is defined in a previous line) -- ignore items mentioned there that are not in the last line.
Return the output as a single comma-separated list, or NONE if there is nothing of note to return (e.g. the user is just issuing a greeting or having a simple conversation).
EXAMPLE
Conversation history:
Person #1: how's it going today?
AI: "It's going great! How about you?"
Person #1: good! busy working on Langchain. lots to do.
AI: "That sounds like a lot of work! What kind of things are you doing to make Langchain better?"
Last line:
Person #1: i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff.
Output: Langchain
END OF EXAMPLE
EXAMPLE
Conversation history:
Person #1: how's it going today?
AI: "It's going great! How about you?"
Person #1: good! busy working on Langchain. lots to do.
AI: "That sounds like a lot of work! What kind of things are you doing to make Langchain better?"
Last line:
Person #1: i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff. I'm working with Person #2.
Output: Langchain, Person #2
END OF EXAMPLE
Conversation history (for reference only):
{history}
Last line of conversation (for extraction):
Human: {input}
Output:""" # noqa: E501
ENTITY_EXTRACTION_PROMPT = PromptTemplate(
input_variables=["history", "input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE
)

View File

@@ -0,0 +1,24 @@
from langchain_core.prompts.prompt import PromptTemplate
_DEFAULT_ENTITY_SUMMARIZATION_TEMPLATE = """You are an AI assistant helping a human keep track of facts about relevant people, places, and concepts in their life. Update the summary of the provided entity in the "Entity" section based on the last line of your conversation with the human. If you are writing the summary for the first time, return a single sentence.
The update should only include facts that are relayed in the last line of conversation about the provided entity, and should only contain facts about the provided entity.
If there is no new information about the provided entity or the information is not worth noting (not an important or relevant fact to remember long-term), return the existing summary unchanged.
Full conversation history (for context):
{history}
Entity to summarize:
{entity}
Existing summary of {entity}:
{summary}
Last line of conversation:
Human: {input}
Updated summary:""" # noqa: E501
ENTITY_SUMMARIZATION_PROMPT = PromptTemplate(
input_variables=["entity", "summary", "history", "input"],
template=_DEFAULT_ENTITY_SUMMARIZATION_TEMPLATE,
)

View File

@@ -0,0 +1,36 @@
from langchain_core.prompts.prompt import PromptTemplate
KG_TRIPLE_DELIMITER = "<|>"
_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE = (
"You are a networked intelligence helping a human track knowledge triples"
" about all relevant people, things, concepts, etc. and integrating"
" them with your knowledge stored within your weights"
" as well as that stored in a knowledge graph."
" Extract all of the knowledge triples from the text."
" A knowledge triple is a clause that contains a subject, a predicate,"
" and an object. The subject is the entity being described,"
" the predicate is the property of the subject that is being"
" described, and the object is the value of the property.\n\n"
"EXAMPLE\n"
"It's a state in the US. It's also the number 1 producer of gold in the US.\n\n"
f"Output: (Nevada, is a, state){KG_TRIPLE_DELIMITER}(Nevada, is in, US)"
f"{KG_TRIPLE_DELIMITER}(Nevada, is the number 1 producer of, gold)\n"
"END OF EXAMPLE\n\n"
"EXAMPLE\n"
"I'm going to the store.\n\n"
"Output: NONE\n"
"END OF EXAMPLE\n\n"
"EXAMPLE\n"
"Oh huh. I know Descartes likes to drive antique scooters and play the mandolin.\n"
f"Output: (Descartes, likes to drive, antique scooters){KG_TRIPLE_DELIMITER}(Descartes, plays, mandolin)\n" # noqa: E501
"END OF EXAMPLE\n\n"
"EXAMPLE\n"
"{text}"
"Output:"
)
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT = PromptTemplate(
input_variables=["text"],
template=_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE,
)

View File

@@ -0,0 +1,271 @@
"""Vectorstore stubs for the indexing api."""
from typing import Any
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.vectorstores import VectorStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from pydantic import BaseModel, ConfigDict, Field
from langchain_classic.chains.qa_with_sources.retrieval import (
RetrievalQAWithSourcesChain,
)
from langchain_classic.chains.retrieval_qa.base import RetrievalQA
def _get_default_text_splitter() -> TextSplitter:
"""Return the default text splitter used for chunking documents."""
return RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
class VectorStoreIndexWrapper(BaseModel):
"""Wrapper around a `VectorStore` for easy access."""
vectorstore: VectorStore
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
def query(
self,
question: str,
llm: BaseLanguageModel | None = None,
retriever_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> str:
"""Query the `VectorStore` using the provided LLM.
Args:
question: The question or prompt to query.
llm: The language model to use. Must not be `None`.
retriever_kwargs: Optional keyword arguments for the retriever.
**kwargs: Additional keyword arguments forwarded to the chain.
Returns:
The result string from the RetrievalQA chain.
"""
if llm is None:
msg = (
"This API has been changed to require an LLM. "
"Please provide an llm to use for querying the vectorstore.\n"
"For example,\n"
"from langchain_openai import OpenAI\n"
"model = OpenAI(temperature=0)"
)
raise NotImplementedError(msg)
retriever_kwargs = retriever_kwargs or {}
chain = RetrievalQA.from_chain_type(
llm,
retriever=self.vectorstore.as_retriever(**retriever_kwargs),
**kwargs,
)
return chain.invoke({chain.input_key: question})[chain.output_key]
async def aquery(
self,
question: str,
llm: BaseLanguageModel | None = None,
retriever_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> str:
"""Asynchronously query the `VectorStore` using the provided LLM.
Args:
question: The question or prompt to query.
llm: The language model to use. Must not be `None`.
retriever_kwargs: Optional keyword arguments for the retriever.
**kwargs: Additional keyword arguments forwarded to the chain.
Returns:
The asynchronous result string from the RetrievalQA chain.
"""
if llm is None:
msg = (
"This API has been changed to require an LLM. "
"Please provide an llm to use for querying the vectorstore.\n"
"For example,\n"
"from langchain_openai import OpenAI\n"
"model = OpenAI(temperature=0)"
)
raise NotImplementedError(msg)
retriever_kwargs = retriever_kwargs or {}
chain = RetrievalQA.from_chain_type(
llm,
retriever=self.vectorstore.as_retriever(**retriever_kwargs),
**kwargs,
)
return (await chain.ainvoke({chain.input_key: question}))[chain.output_key]
def query_with_sources(
self,
question: str,
llm: BaseLanguageModel | None = None,
retriever_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> dict:
"""Query the `VectorStore` and retrieve the answer along with sources.
Args:
question: The question or prompt to query.
llm: The language model to use. Must not be `None`.
retriever_kwargs: Optional keyword arguments for the retriever.
**kwargs: Additional keyword arguments forwarded to the chain.
Returns:
`dict` containing the answer and source documents.
"""
if llm is None:
msg = (
"This API has been changed to require an LLM. "
"Please provide an llm to use for querying the vectorstore.\n"
"For example,\n"
"from langchain_openai import OpenAI\n"
"model = OpenAI(temperature=0)"
)
raise NotImplementedError(msg)
retriever_kwargs = retriever_kwargs or {}
chain = RetrievalQAWithSourcesChain.from_chain_type(
llm,
retriever=self.vectorstore.as_retriever(**retriever_kwargs),
**kwargs,
)
return chain.invoke({chain.question_key: question})
async def aquery_with_sources(
self,
question: str,
llm: BaseLanguageModel | None = None,
retriever_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> dict:
"""Asynchronously query the `VectorStore` and retrieve the answer and sources.
Args:
question: The question or prompt to query.
llm: The language model to use. Must not be `None`.
retriever_kwargs: Optional keyword arguments for the retriever.
**kwargs: Additional keyword arguments forwarded to the chain.
Returns:
`dict` containing the answer and source documents.
"""
if llm is None:
msg = (
"This API has been changed to require an LLM. "
"Please provide an llm to use for querying the vectorstore.\n"
"For example,\n"
"from langchain_openai import OpenAI\n"
"model = OpenAI(temperature=0)"
)
raise NotImplementedError(msg)
retriever_kwargs = retriever_kwargs or {}
chain = RetrievalQAWithSourcesChain.from_chain_type(
llm,
retriever=self.vectorstore.as_retriever(**retriever_kwargs),
**kwargs,
)
return await chain.ainvoke({chain.question_key: question})
def _get_in_memory_vectorstore() -> type[VectorStore]:
"""Get the `InMemoryVectorStore`."""
import warnings
try:
from langchain_community.vectorstores.inmemory import InMemoryVectorStore
except ImportError as e:
msg = "Please install langchain-community to use the InMemoryVectorStore."
raise ImportError(msg) from e
warnings.warn(
"Using InMemoryVectorStore as the default vectorstore."
"This memory store won't persist data. You should explicitly"
"specify a VectorStore when using VectorstoreIndexCreator",
stacklevel=3,
)
return InMemoryVectorStore
class VectorstoreIndexCreator(BaseModel):
"""Logic for creating indexes."""
vectorstore_cls: type[VectorStore] = Field(
default_factory=_get_in_memory_vectorstore,
)
embedding: Embeddings
text_splitter: TextSplitter = Field(default_factory=_get_default_text_splitter)
vectorstore_kwargs: dict = Field(default_factory=dict)
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
def from_loaders(self, loaders: list[BaseLoader]) -> VectorStoreIndexWrapper:
"""Create a `VectorStore` index from a list of loaders.
Args:
loaders: A list of `BaseLoader` instances to load documents.
Returns:
A `VectorStoreIndexWrapper` containing the constructed vectorstore.
"""
docs = []
for loader in loaders:
docs.extend(loader.load())
return self.from_documents(docs)
async def afrom_loaders(self, loaders: list[BaseLoader]) -> VectorStoreIndexWrapper:
"""Asynchronously create a `VectorStore` index from a list of loaders.
Args:
loaders: A list of `BaseLoader` instances to load documents.
Returns:
A `VectorStoreIndexWrapper` containing the constructed vectorstore.
"""
docs = []
for loader in loaders:
docs.extend([doc async for doc in loader.alazy_load()])
return await self.afrom_documents(docs)
def from_documents(self, documents: list[Document]) -> VectorStoreIndexWrapper:
"""Create a `VectorStore` index from a list of documents.
Args:
documents: A list of `Document` objects.
Returns:
A `VectorStoreIndexWrapper` containing the constructed vectorstore.
"""
sub_docs = self.text_splitter.split_documents(documents)
vectorstore = self.vectorstore_cls.from_documents(
sub_docs,
self.embedding,
**self.vectorstore_kwargs,
)
return VectorStoreIndexWrapper(vectorstore=vectorstore)
async def afrom_documents(
self,
documents: list[Document],
) -> VectorStoreIndexWrapper:
"""Asynchronously create a `VectorStore` index from a list of documents.
Args:
documents: A list of `Document` objects.
Returns:
A `VectorStoreIndexWrapper` containing the constructed vectorstore.
"""
sub_docs = self.text_splitter.split_documents(documents)
vectorstore = await self.vectorstore_cls.afrom_documents(
sub_docs,
self.embedding,
**self.vectorstore_kwargs,
)
return VectorStoreIndexWrapper(vectorstore=vectorstore)