initial commit

This commit is contained in:
Gokul
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,169 @@
"""**Retriever** class returns Documents given a text **query**.
It is more general than a vector store. A retriever does not need to be able to
store documents, only to return (or retrieve) it. Vector stores can be used as
the backbone of a retriever, but there are other types of retrievers as well.
"""
from typing import TYPE_CHECKING, Any
from langchain_classic._api.module_import import create_importer
from langchain_classic.retrievers.contextual_compression import (
ContextualCompressionRetriever,
)
from langchain_classic.retrievers.ensemble import EnsembleRetriever
from langchain_classic.retrievers.merger_retriever import MergerRetriever
from langchain_classic.retrievers.multi_query import MultiQueryRetriever
from langchain_classic.retrievers.multi_vector import MultiVectorRetriever
from langchain_classic.retrievers.parent_document_retriever import (
ParentDocumentRetriever,
)
from langchain_classic.retrievers.re_phraser import RePhraseQueryRetriever
from langchain_classic.retrievers.self_query.base import SelfQueryRetriever
from langchain_classic.retrievers.time_weighted_retriever import (
TimeWeightedVectorStoreRetriever,
)
if TYPE_CHECKING:
from langchain_community.retrievers import (
AmazonKendraRetriever,
AmazonKnowledgeBasesRetriever,
ArceeRetriever,
ArxivRetriever,
AzureAISearchRetriever,
AzureCognitiveSearchRetriever,
BM25Retriever,
ChaindeskRetriever,
ChatGPTPluginRetriever,
CohereRagRetriever,
DocArrayRetriever,
DriaRetriever,
ElasticSearchBM25Retriever,
EmbedchainRetriever,
GoogleCloudEnterpriseSearchRetriever,
GoogleDocumentAIWarehouseRetriever,
GoogleVertexAIMultiTurnSearchRetriever,
GoogleVertexAISearchRetriever,
KayAiRetriever,
KNNRetriever,
LlamaIndexGraphRetriever,
LlamaIndexRetriever,
MetalRetriever,
MilvusRetriever,
NeuralDBRetriever,
OutlineRetriever,
PineconeHybridSearchRetriever,
PubMedRetriever,
RemoteLangChainRetriever,
SVMRetriever,
TavilySearchAPIRetriever,
TFIDFRetriever,
VespaRetriever,
WeaviateHybridSearchRetriever,
WebResearchRetriever,
WikipediaRetriever,
ZepRetriever,
ZillizRetriever,
)
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
"AmazonKendraRetriever": "langchain_community.retrievers",
"AmazonKnowledgeBasesRetriever": "langchain_community.retrievers",
"ArceeRetriever": "langchain_community.retrievers",
"ArxivRetriever": "langchain_community.retrievers",
"AzureAISearchRetriever": "langchain_community.retrievers",
"AzureCognitiveSearchRetriever": "langchain_community.retrievers",
"ChatGPTPluginRetriever": "langchain_community.retrievers",
"ChaindeskRetriever": "langchain_community.retrievers",
"CohereRagRetriever": "langchain_community.retrievers",
"ElasticSearchBM25Retriever": "langchain_community.retrievers",
"EmbedchainRetriever": "langchain_community.retrievers",
"GoogleDocumentAIWarehouseRetriever": "langchain_community.retrievers",
"GoogleCloudEnterpriseSearchRetriever": "langchain_community.retrievers",
"GoogleVertexAIMultiTurnSearchRetriever": "langchain_community.retrievers",
"GoogleVertexAISearchRetriever": "langchain_community.retrievers",
"KayAiRetriever": "langchain_community.retrievers",
"KNNRetriever": "langchain_community.retrievers",
"LlamaIndexGraphRetriever": "langchain_community.retrievers",
"LlamaIndexRetriever": "langchain_community.retrievers",
"MetalRetriever": "langchain_community.retrievers",
"MilvusRetriever": "langchain_community.retrievers",
"OutlineRetriever": "langchain_community.retrievers",
"PineconeHybridSearchRetriever": "langchain_community.retrievers",
"PubMedRetriever": "langchain_community.retrievers",
"RemoteLangChainRetriever": "langchain_community.retrievers",
"SVMRetriever": "langchain_community.retrievers",
"TavilySearchAPIRetriever": "langchain_community.retrievers",
"BM25Retriever": "langchain_community.retrievers",
"DriaRetriever": "langchain_community.retrievers",
"NeuralDBRetriever": "langchain_community.retrievers",
"TFIDFRetriever": "langchain_community.retrievers",
"VespaRetriever": "langchain_community.retrievers",
"WeaviateHybridSearchRetriever": "langchain_community.retrievers",
"WebResearchRetriever": "langchain_community.retrievers",
"WikipediaRetriever": "langchain_community.retrievers",
"ZepRetriever": "langchain_community.retrievers",
"ZillizRetriever": "langchain_community.retrievers",
"DocArrayRetriever": "langchain_community.retrievers",
}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"AmazonKendraRetriever",
"AmazonKnowledgeBasesRetriever",
"ArceeRetriever",
"ArxivRetriever",
"AzureAISearchRetriever",
"AzureCognitiveSearchRetriever",
"BM25Retriever",
"ChaindeskRetriever",
"ChatGPTPluginRetriever",
"CohereRagRetriever",
"ContextualCompressionRetriever",
"DocArrayRetriever",
"DriaRetriever",
"ElasticSearchBM25Retriever",
"EmbedchainRetriever",
"EnsembleRetriever",
"GoogleCloudEnterpriseSearchRetriever",
"GoogleDocumentAIWarehouseRetriever",
"GoogleVertexAIMultiTurnSearchRetriever",
"GoogleVertexAISearchRetriever",
"KNNRetriever",
"KayAiRetriever",
"LlamaIndexGraphRetriever",
"LlamaIndexRetriever",
"MergerRetriever",
"MetalRetriever",
"MilvusRetriever",
"MultiQueryRetriever",
"MultiVectorRetriever",
"NeuralDBRetriever",
"OutlineRetriever",
"ParentDocumentRetriever",
"PineconeHybridSearchRetriever",
"PubMedRetriever",
"RePhraseQueryRetriever",
"RemoteLangChainRetriever",
"SVMRetriever",
"SelfQueryRetriever",
"TFIDFRetriever",
"TavilySearchAPIRetriever",
"TimeWeightedVectorStoreRetriever",
"VespaRetriever",
"WeaviateHybridSearchRetriever",
"WebResearchRetriever",
"WikipediaRetriever",
"ZepRetriever",
"ZillizRetriever",
]

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import ArceeRetriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"ArceeRetriever": "langchain_community.retrievers"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"ArceeRetriever",
]

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import ArxivRetriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"ArxivRetriever": "langchain_community.retrievers"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"ArxivRetriever",
]

View File

@@ -0,0 +1,30 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import (
AzureAISearchRetriever,
AzureCognitiveSearchRetriever,
)
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
"AzureAISearchRetriever": "langchain_community.retrievers",
"AzureCognitiveSearchRetriever": "langchain_community.retrievers",
}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"AzureAISearchRetriever",
"AzureCognitiveSearchRetriever",
]

View File

@@ -0,0 +1,33 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import AmazonKnowledgeBasesRetriever
from langchain_community.retrievers.bedrock import (
RetrievalConfig,
VectorSearchConfig,
)
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
"VectorSearchConfig": "langchain_community.retrievers.bedrock",
"RetrievalConfig": "langchain_community.retrievers.bedrock",
"AmazonKnowledgeBasesRetriever": "langchain_community.retrievers",
}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"AmazonKnowledgeBasesRetriever",
"RetrievalConfig",
"VectorSearchConfig",
]

View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import BM25Retriever
from langchain_community.retrievers.bm25 import default_preprocessing_func
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
"default_preprocessing_func": "langchain_community.retrievers.bm25",
"BM25Retriever": "langchain_community.retrievers",
}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"BM25Retriever",
"default_preprocessing_func",
]

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import ChaindeskRetriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"ChaindeskRetriever": "langchain_community.retrievers"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"ChaindeskRetriever",
]

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import ChatGPTPluginRetriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"ChatGPTPluginRetriever": "langchain_community.retrievers"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"ChatGPTPluginRetriever",
]

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import CohereRagRetriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"CohereRagRetriever": "langchain_community.retrievers"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"CohereRagRetriever",
]

View File

@@ -0,0 +1,68 @@
from typing import Any
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.retrievers import BaseRetriever, RetrieverLike
from pydantic import ConfigDict
from typing_extensions import override
class ContextualCompressionRetriever(BaseRetriever):
"""Retriever that wraps a base retriever and compresses the results."""
base_compressor: BaseDocumentCompressor
"""Compressor for compressing retrieved documents."""
base_retriever: RetrieverLike
"""Base Retriever to use for getting relevant documents."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@override
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> list[Document]:
docs = self.base_retriever.invoke(
query,
config={"callbacks": run_manager.get_child()},
**kwargs,
)
if docs:
compressed_docs = self.base_compressor.compress_documents(
docs,
query,
callbacks=run_manager.get_child(),
)
return list(compressed_docs)
return []
@override
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> list[Document]:
docs = await self.base_retriever.ainvoke(
query,
config={"callbacks": run_manager.get_child()},
**kwargs,
)
if docs:
compressed_docs = await self.base_compressor.acompress_documents(
docs,
query,
callbacks=run_manager.get_child(),
)
return list(compressed_docs)
return []

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers.databerry import DataberryRetriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"DataberryRetriever": "langchain_community.retrievers.databerry"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"DataberryRetriever",
]

View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import DocArrayRetriever
from langchain_community.retrievers.docarray import SearchType
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
"SearchType": "langchain_community.retrievers.docarray",
"DocArrayRetriever": "langchain_community.retrievers",
}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"DocArrayRetriever",
"SearchType",
]

View File

@@ -0,0 +1,46 @@
import importlib
from typing import Any
from langchain_classic.retrievers.document_compressors.base import (
DocumentCompressorPipeline,
)
from langchain_classic.retrievers.document_compressors.chain_extract import (
LLMChainExtractor,
)
from langchain_classic.retrievers.document_compressors.chain_filter import (
LLMChainFilter,
)
from langchain_classic.retrievers.document_compressors.cohere_rerank import CohereRerank
from langchain_classic.retrievers.document_compressors.cross_encoder_rerank import (
CrossEncoderReranker,
)
from langchain_classic.retrievers.document_compressors.embeddings_filter import (
EmbeddingsFilter,
)
from langchain_classic.retrievers.document_compressors.listwise_rerank import (
LLMListwiseRerank,
)
_module_lookup = {
"FlashrankRerank": "langchain_community.document_compressors.flashrank_rerank",
}
def __getattr__(name: str) -> Any:
if name in _module_lookup:
module = importlib.import_module(_module_lookup[name])
return getattr(module, name)
msg = f"module {__name__} has no attribute {name}"
raise AttributeError(msg)
__all__ = [
"CohereRerank",
"CrossEncoderReranker",
"DocumentCompressorPipeline",
"EmbeddingsFilter",
"FlashrankRerank",
"LLMChainExtractor",
"LLMChainFilter",
"LLMListwiseRerank",
]

View File

@@ -0,0 +1,81 @@
from collections.abc import Sequence
from inspect import signature
from langchain_core.callbacks import Callbacks
from langchain_core.documents import (
BaseDocumentCompressor,
BaseDocumentTransformer,
Document,
)
from pydantic import ConfigDict
class DocumentCompressorPipeline(BaseDocumentCompressor):
"""Document compressor that uses a pipeline of Transformers."""
transformers: list[BaseDocumentTransformer | BaseDocumentCompressor]
"""List of document filters that are chained together and run in sequence."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Callbacks | None = None,
) -> Sequence[Document]:
"""Transform a list of documents."""
for _transformer in self.transformers:
if isinstance(_transformer, BaseDocumentCompressor):
accepts_callbacks = (
signature(_transformer.compress_documents).parameters.get(
"callbacks",
)
is not None
)
if accepts_callbacks:
documents = _transformer.compress_documents(
documents,
query,
callbacks=callbacks,
)
else:
documents = _transformer.compress_documents(documents, query)
elif isinstance(_transformer, BaseDocumentTransformer):
documents = _transformer.transform_documents(documents)
else:
msg = f"Got unexpected transformer type: {_transformer}" # type: ignore[unreachable]
raise ValueError(msg) # noqa: TRY004
return documents
async def acompress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Callbacks | None = None,
) -> Sequence[Document]:
"""Compress retrieved documents given the query context."""
for _transformer in self.transformers:
if isinstance(_transformer, BaseDocumentCompressor):
accepts_callbacks = (
signature(_transformer.acompress_documents).parameters.get(
"callbacks",
)
is not None
)
if accepts_callbacks:
documents = await _transformer.acompress_documents(
documents,
query,
callbacks=callbacks,
)
else:
documents = await _transformer.acompress_documents(documents, query)
elif isinstance(_transformer, BaseDocumentTransformer):
documents = await _transformer.atransform_documents(documents)
else:
msg = f"Got unexpected transformer type: {_transformer}" # type: ignore[unreachable]
raise ValueError(msg) # noqa: TRY004
return documents

View File

@@ -0,0 +1,126 @@
"""DocumentFilter that uses an LLM chain to extract the relevant parts of documents."""
from __future__ import annotations
from collections.abc import Callable, Sequence
from typing import Any, cast
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import Runnable
from pydantic import ConfigDict
from typing_extensions import override
from langchain_classic.chains.llm import LLMChain
from langchain_classic.retrievers.document_compressors.chain_extract_prompt import (
prompt_template,
)
def default_get_input(query: str, doc: Document) -> dict[str, Any]:
"""Return the compression chain input."""
return {"question": query, "context": doc.page_content}
class NoOutputParser(BaseOutputParser[str]):
"""Parse outputs that could return a null string of some sort."""
no_output_str: str = "NO_OUTPUT"
@override
def parse(self, text: str) -> str:
cleaned_text = text.strip()
if cleaned_text == self.no_output_str:
return ""
return cleaned_text
def _get_default_chain_prompt() -> PromptTemplate:
output_parser = NoOutputParser()
template = prompt_template.format(no_output_str=output_parser.no_output_str)
return PromptTemplate(
template=template,
input_variables=["question", "context"],
output_parser=output_parser,
)
class LLMChainExtractor(BaseDocumentCompressor):
"""LLM Chain Extractor.
Document compressor that uses an LLM chain to extract
the relevant parts of documents.
"""
llm_chain: Runnable
"""LLM wrapper to use for compressing documents."""
get_input: Callable[[str, Document], dict] = default_get_input
"""Callable for constructing the chain input from the query and a Document."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Callbacks | None = None,
) -> Sequence[Document]:
"""Compress page content of raw documents."""
compressed_docs = []
for doc in documents:
_input = self.get_input(query, doc)
output_ = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
if isinstance(self.llm_chain, LLMChain):
output = output_[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
output = self.llm_chain.prompt.output_parser.parse(output)
else:
output = output_
if len(output) == 0:
continue
compressed_docs.append(
Document(page_content=cast("str", output), metadata=doc.metadata),
)
return compressed_docs
async def acompress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Callbacks | None = None,
) -> Sequence[Document]:
"""Compress page content of raw documents asynchronously."""
inputs = [self.get_input(query, doc) for doc in documents]
outputs = await self.llm_chain.abatch(inputs, {"callbacks": callbacks})
compressed_docs = []
for i, doc in enumerate(documents):
if len(outputs[i]) == 0:
continue
compressed_docs.append(
Document(page_content=outputs[i], metadata=doc.metadata),
)
return compressed_docs
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: PromptTemplate | None = None,
get_input: Callable[[str, Document], str] | None = None,
llm_chain_kwargs: dict | None = None, # noqa: ARG003
) -> LLMChainExtractor:
"""Initialize from LLM."""
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
_get_input = get_input if get_input is not None else default_get_input
if _prompt.output_parser is not None:
parser = _prompt.output_parser
else:
parser = StrOutputParser()
llm_chain = _prompt | llm | parser
return cls(llm_chain=llm_chain, get_input=_get_input)

View File

@@ -0,0 +1,10 @@
prompt_template = """Given the following question and context, extract any part of the context *AS IS* that is relevant to answer the question. If none of the context is relevant return {no_output_str}.
Remember, *DO NOT* edit the extracted parts of the context.
> Question: {{question}}
> Context:
>>>
{{context}}
>>>
Extracted relevant parts:""" # noqa: E501

View File

@@ -0,0 +1,135 @@
"""Filter that uses an LLM to drop documents that aren't relevant to the query."""
from collections.abc import Callable, Sequence
from typing import Any
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.runnables import Runnable
from langchain_core.runnables.config import RunnableConfig
from pydantic import ConfigDict
from langchain_classic.chains import LLMChain
from langchain_classic.output_parsers.boolean import BooleanOutputParser
from langchain_classic.retrievers.document_compressors.chain_filter_prompt import (
prompt_template,
)
def _get_default_chain_prompt() -> PromptTemplate:
return PromptTemplate(
template=prompt_template,
input_variables=["question", "context"],
output_parser=BooleanOutputParser(),
)
def default_get_input(query: str, doc: Document) -> dict[str, Any]:
"""Return the compression chain input."""
return {"question": query, "context": doc.page_content}
class LLMChainFilter(BaseDocumentCompressor):
"""Filter that drops documents that aren't relevant to the query."""
llm_chain: Runnable
"""LLM wrapper to use for filtering documents.
The chain prompt is expected to have a BooleanOutputParser."""
get_input: Callable[[str, Document], dict] = default_get_input
"""Callable for constructing the chain input from the query and a Document."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Callbacks | None = None,
) -> Sequence[Document]:
"""Filter down documents based on their relevance to the query."""
filtered_docs = []
config = RunnableConfig(callbacks=callbacks)
outputs = zip(
self.llm_chain.batch(
[self.get_input(query, doc) for doc in documents],
config=config,
),
documents,
strict=False,
)
for output_, doc in outputs:
include_doc = None
if isinstance(self.llm_chain, LLMChain):
output = output_[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
include_doc = self.llm_chain.prompt.output_parser.parse(output)
elif isinstance(output_, bool):
include_doc = output_
if include_doc:
filtered_docs.append(doc)
return filtered_docs
async def acompress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Callbacks | None = None,
) -> Sequence[Document]:
"""Filter down documents based on their relevance to the query."""
filtered_docs = []
config = RunnableConfig(callbacks=callbacks)
outputs = zip(
await self.llm_chain.abatch(
[self.get_input(query, doc) for doc in documents],
config=config,
),
documents,
strict=False,
)
for output_, doc in outputs:
include_doc = None
if isinstance(self.llm_chain, LLMChain):
output = output_[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
include_doc = self.llm_chain.prompt.output_parser.parse(output)
elif isinstance(output_, bool):
include_doc = output_
if include_doc:
filtered_docs.append(doc)
return filtered_docs
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate | None = None,
**kwargs: Any,
) -> "LLMChainFilter":
"""Create a LLMChainFilter from a language model.
Args:
llm: The language model to use for filtering.
prompt: The prompt to use for the filter.
kwargs: Additional arguments to pass to the constructor.
Returns:
A LLMChainFilter that uses the given language model.
"""
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
if _prompt.output_parser is not None:
parser = _prompt.output_parser
else:
parser = StrOutputParser()
llm_chain = _prompt | llm | parser
return cls(llm_chain=llm_chain, **kwargs)

View File

@@ -0,0 +1,8 @@
prompt_template = """Given the following question and context, return YES if the context is relevant to the question and NO if it isn't.
> Question: {question}
> Context:
>>>
{context}
>>>
> Relevant (YES / NO):""" # noqa: E501

View File

@@ -0,0 +1,124 @@
from __future__ import annotations
from collections.abc import Sequence
from copy import deepcopy
from typing import Any
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.utils import get_from_dict_or_env
from pydantic import ConfigDict, model_validator
from typing_extensions import override
@deprecated(
since="0.0.30",
removal="1.0",
alternative_import="langchain_cohere.CohereRerank",
)
class CohereRerank(BaseDocumentCompressor):
"""Document compressor that uses `Cohere Rerank API`."""
client: Any = None
"""Cohere client to use for compressing documents."""
top_n: int | None = 3
"""Number of documents to return."""
model: str = "rerank-english-v2.0"
"""Model to use for reranking."""
cohere_api_key: str | None = None
"""Cohere API key. Must be specified directly or via environment variable
COHERE_API_KEY."""
user_agent: str = "langchain"
"""Identifier for the application making the request."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: dict) -> Any:
"""Validate that api key and python package exists in environment."""
if not values.get("client"):
try:
import cohere
except ImportError as e:
msg = (
"Could not import cohere python package. "
"Please install it with `pip install cohere`."
)
raise ImportError(msg) from e
cohere_api_key = get_from_dict_or_env(
values,
"cohere_api_key",
"COHERE_API_KEY",
)
client_name = values.get("user_agent", "langchain")
values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
return values
def rerank(
self,
documents: Sequence[str | Document | dict],
query: str,
*,
model: str | None = None,
top_n: int | None = -1,
max_chunks_per_doc: int | None = None,
) -> list[dict[str, Any]]:
"""Returns an ordered list of documents ordered by their relevance to the provided query.
Args:
query: The query to use for reranking.
documents: A sequence of documents to rerank.
model: The model to use for re-ranking. Default to self.model.
top_n : The number of results to return. If `None` returns all results.
max_chunks_per_doc : The maximum number of chunks derived from a document.
""" # noqa: E501
if len(documents) == 0: # to avoid empty api call
return []
docs = [
doc.page_content if isinstance(doc, Document) else doc for doc in documents
]
model = model or self.model
top_n = top_n if (top_n is None or top_n > 0) else self.top_n
results = self.client.rerank(
query=query,
documents=docs,
model=model,
top_n=top_n,
max_chunks_per_doc=max_chunks_per_doc,
)
if hasattr(results, "results"):
results = results.results
return [
{"index": res.index, "relevance_score": res.relevance_score}
for res in results
]
@override
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Callbacks | None = None,
) -> Sequence[Document]:
"""Compress documents using Cohere's rerank API.
Args:
documents: A sequence of documents to compress.
query: The query to use for compressing the documents.
callbacks: Callbacks to run during the compression process.
Returns:
A sequence of compressed documents.
"""
compressed = []
for res in self.rerank(documents, query):
doc = documents[res["index"]]
doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
doc_copy.metadata["relevance_score"] = res["relevance_score"]
compressed.append(doc_copy)
return compressed

View File

@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
class BaseCrossEncoder(ABC):
"""Interface for cross encoder models."""
@abstractmethod
def score(self, text_pairs: list[tuple[str, str]]) -> list[float]:
"""Score pairs' similarity.
Args:
text_pairs: List of pairs of texts.
Returns:
List of scores.
"""

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
import operator
from collections.abc import Sequence
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from pydantic import ConfigDict
from typing_extensions import override
from langchain_classic.retrievers.document_compressors.cross_encoder import (
BaseCrossEncoder,
)
class CrossEncoderReranker(BaseDocumentCompressor):
"""Document compressor that uses CrossEncoder for reranking."""
model: BaseCrossEncoder
"""CrossEncoder model to use for scoring similarity
between the query and documents."""
top_n: int = 3
"""Number of documents to return."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
@override
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Callbacks | None = None,
) -> Sequence[Document]:
"""Rerank documents using CrossEncoder.
Args:
documents: A sequence of documents to compress.
query: The query to use for compressing the documents.
callbacks: Callbacks to run during the compression process.
Returns:
A sequence of compressed documents.
"""
scores = self.model.score([(query, doc.page_content) for doc in documents])
docs_with_scores = list(zip(documents, scores, strict=False))
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
return [doc for doc, _ in result[: self.top_n]]

View File

@@ -0,0 +1,141 @@
from collections.abc import Callable, Sequence
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import pre_init
from pydantic import ConfigDict, Field
from typing_extensions import override
def _get_similarity_function() -> Callable:
try:
from langchain_community.utils.math import cosine_similarity
except ImportError as e:
msg = (
"To use please install langchain-community "
"with `pip install langchain-community`."
)
raise ImportError(msg) from e
return cosine_similarity
class EmbeddingsFilter(BaseDocumentCompressor):
"""Embeddings Filter.
Document compressor that uses embeddings to drop documents unrelated to the query.
"""
embeddings: Embeddings
"""Embeddings to use for embedding document contents and queries."""
similarity_fn: Callable = Field(default_factory=_get_similarity_function)
"""Similarity function for comparing documents. Function expected to take as input
two matrices (List[List[float]]) and return a matrix of scores where higher values
indicate greater similarity."""
k: int | None = 20
"""The number of relevant documents to return. Can be set to `None`, in which case
`similarity_threshold` must be specified."""
similarity_threshold: float | None = None
"""Threshold for determining when two documents are similar enough
to be considered redundant. Defaults to `None`, must be specified if `k` is set
to None."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@pre_init
def validate_params(cls, values: dict) -> dict:
"""Validate similarity parameters."""
if values["k"] is None and values["similarity_threshold"] is None:
msg = "Must specify one of `k` or `similarity_threshold`."
raise ValueError(msg)
return values
@override
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Callbacks | None = None,
) -> Sequence[Document]:
"""Filter documents based on similarity of their embeddings to the query."""
try:
from langchain_community.document_transformers.embeddings_redundant_filter import ( # noqa: E501
_get_embeddings_from_stateful_docs,
get_stateful_documents,
)
except ImportError as e:
msg = (
"To use please install langchain-community "
"with `pip install langchain-community`."
)
raise ImportError(msg) from e
try:
import numpy as np
except ImportError as e:
msg = "Could not import numpy, please install with `pip install numpy`."
raise ImportError(msg) from e
stateful_documents = get_stateful_documents(documents)
embedded_documents = _get_embeddings_from_stateful_docs(
self.embeddings,
stateful_documents,
)
embedded_query = self.embeddings.embed_query(query)
similarity = self.similarity_fn([embedded_query], embedded_documents)[0]
included_idxs: np.ndarray = np.arange(len(embedded_documents))
if self.k is not None:
included_idxs = np.argsort(similarity)[::-1][: self.k]
if self.similarity_threshold is not None:
similar_enough = np.where(
similarity[included_idxs] > self.similarity_threshold,
)
included_idxs = included_idxs[similar_enough]
for i in included_idxs:
stateful_documents[i].state["query_similarity_score"] = similarity[i]
return [stateful_documents[i] for i in included_idxs]
@override
async def acompress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Callbacks | None = None,
) -> Sequence[Document]:
"""Filter documents based on similarity of their embeddings to the query."""
try:
from langchain_community.document_transformers.embeddings_redundant_filter import ( # noqa: E501
_aget_embeddings_from_stateful_docs,
get_stateful_documents,
)
except ImportError as e:
msg = (
"To use please install langchain-community "
"with `pip install langchain-community`."
)
raise ImportError(msg) from e
try:
import numpy as np
except ImportError as e:
msg = "Could not import numpy, please install with `pip install numpy`."
raise ImportError(msg) from e
stateful_documents = get_stateful_documents(documents)
embedded_documents = await _aget_embeddings_from_stateful_docs(
self.embeddings,
stateful_documents,
)
embedded_query = await self.embeddings.aembed_query(query)
similarity = self.similarity_fn([embedded_query], embedded_documents)[0]
included_idxs: np.ndarray = np.arange(len(embedded_documents))
if self.k is not None:
included_idxs = np.argsort(similarity)[::-1][: self.k]
if self.similarity_threshold is not None:
similar_enough = np.where(
similarity[included_idxs] > self.similarity_threshold,
)
included_idxs = included_idxs[similar_enough]
for i in included_idxs:
stateful_documents[i].state["query_similarity_score"] = similarity[i]
return [stateful_documents[i] for i in included_idxs]

View File

@@ -0,0 +1,27 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.document_compressors.flashrank_rerank import (
FlashrankRerank,
)
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
"FlashrankRerank": "langchain_community.document_compressors.flashrank_rerank",
}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"FlashrankRerank",
]

View File

@@ -0,0 +1,146 @@
"""Filter that uses an LLM to rerank documents listwise and select top-k."""
from collections.abc import Sequence
from typing import Any
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from pydantic import BaseModel, ConfigDict, Field
_default_system_tmpl = """{context}
Sort the Documents by their relevance to the Query."""
_DEFAULT_PROMPT = ChatPromptTemplate.from_messages(
[("system", _default_system_tmpl), ("human", "{query}")],
)
def _get_prompt_input(input_: dict) -> dict[str, Any]:
"""Return the compression chain input."""
documents = input_["documents"]
context = ""
for index, doc in enumerate(documents):
context += f"Document ID: {index}\n```{doc.page_content}```\n\n"
document_range = "empty list"
if len(documents) > 0:
document_range = f"Document ID: 0, ..., Document ID: {len(documents) - 1}"
context += f"Documents = [{document_range}]"
return {"query": input_["query"], "context": context}
def _parse_ranking(results: dict) -> list[Document]:
ranking = results["ranking"]
docs = results["documents"]
return [docs[i] for i in ranking.ranked_document_ids]
class LLMListwiseRerank(BaseDocumentCompressor):
"""Document compressor that uses `Zero-Shot Listwise Document Reranking`.
Adapted from: https://arxiv.org/pdf/2305.02156.pdf
`LLMListwiseRerank` uses a language model to rerank a list of documents based on
their relevance to a query.
!!! note
Requires that underlying model implement `with_structured_output`.
Example usage:
```python
from langchain_classic.retrievers.document_compressors.listwise_rerank import (
LLMListwiseRerank,
)
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
documents = [
Document("Sally is my friend from school"),
Document("Steve is my friend from home"),
Document("I didn't always like yogurt"),
Document("I wonder why it's called football"),
Document("Where's waldo"),
]
reranker = LLMListwiseRerank.from_llm(
llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3
)
compressed_docs = reranker.compress_documents(documents, "Who is steve")
assert len(compressed_docs) == 3
assert "Steve" in compressed_docs[0].page_content
```
"""
reranker: Runnable[dict, list[Document]]
"""LLM-based reranker to use for filtering documents. Expected to take in a dict
with 'documents: Sequence[Document]' and 'query: str' keys and output a
List[Document]."""
top_n: int = 3
"""Number of documents to return."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Callbacks | None = None,
) -> Sequence[Document]:
"""Filter down documents based on their relevance to the query."""
results = self.reranker.invoke(
{"documents": documents, "query": query},
config={"callbacks": callbacks},
)
return results[: self.top_n]
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
prompt: BasePromptTemplate | None = None,
**kwargs: Any,
) -> "LLMListwiseRerank":
"""Create a LLMListwiseRerank document compressor from a language model.
Args:
llm: The language model to use for filtering. **Must implement
BaseLanguageModel.with_structured_output().**
prompt: The prompt to use for the filter.
kwargs: Additional arguments to pass to the constructor.
Returns:
A LLMListwiseRerank document compressor that uses the given language model.
"""
if type(llm).with_structured_output == BaseLanguageModel.with_structured_output:
msg = (
f"llm of type {type(llm)} does not implement `with_structured_output`."
)
raise ValueError(msg)
class RankDocuments(BaseModel):
"""Rank the documents by their relevance to the user question.
Rank from most to least relevant.
"""
ranked_document_ids: list[int] = Field(
...,
description=(
"The integer IDs of the documents, sorted from most to least "
"relevant to the user question."
),
)
_prompt = prompt if prompt is not None else _DEFAULT_PROMPT
reranker = RunnablePassthrough.assign(
ranking=RunnableLambda(_get_prompt_input)
| _prompt
| llm.with_structured_output(RankDocuments),
) | RunnableLambda(_parse_ranking)
return cls(reranker=reranker, **kwargs)

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import ElasticSearchBM25Retriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"ElasticSearchBM25Retriever": "langchain_community.retrievers"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"ElasticSearchBM25Retriever",
]

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import EmbedchainRetriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"EmbedchainRetriever": "langchain_community.retrievers"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"EmbedchainRetriever",
]

View File

@@ -0,0 +1,352 @@
"""Ensemble Retriever.
Ensemble retriever that ensemble the results of
multiple retrievers by using weighted Reciprocal Rank Fusion.
"""
import asyncio
from collections import defaultdict
from collections.abc import Callable, Hashable, Iterable, Iterator
from itertools import chain
from typing import (
Any,
TypeVar,
cast,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever, RetrieverLike
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import ensure_config, patch_config
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
get_unique_config_specs,
)
from pydantic import model_validator
from typing_extensions import override
T = TypeVar("T")
H = TypeVar("H", bound=Hashable)
def unique_by_key(iterable: Iterable[T], key: Callable[[T], H]) -> Iterator[T]:
"""Yield unique elements of an iterable based on a key function.
Args:
iterable: The iterable to filter.
key: A function that returns a hashable key for each element.
Yields:
Unique elements of the iterable based on the key function.
"""
seen = set()
for e in iterable:
if (k := key(e)) not in seen:
seen.add(k)
yield e
class EnsembleRetriever(BaseRetriever):
"""Retriever that ensembles the multiple retrievers.
It uses a rank fusion.
Args:
retrievers: A list of retrievers to ensemble.
weights: A list of weights corresponding to the retrievers. Defaults to equal
weighting for all retrievers.
c: A constant added to the rank, controlling the balance between the importance
of high-ranked items and the consideration given to lower-ranked items.
id_key: The key in the document's metadata used to determine unique documents.
If not specified, page_content is used.
"""
retrievers: list[RetrieverLike]
weights: list[float]
c: int = 60
id_key: str | None = None
@property
def config_specs(self) -> list[ConfigurableFieldSpec]:
"""List configurable fields for this runnable."""
return get_unique_config_specs(
spec for retriever in self.retrievers for spec in retriever.config_specs
)
@model_validator(mode="before")
@classmethod
def _set_weights(cls, values: dict[str, Any]) -> Any:
weights = values.get("weights")
if not weights:
n_retrievers = len(values["retrievers"])
values["weights"] = [1 / n_retrievers] * n_retrievers
return values
retrievers = values["retrievers"]
if len(weights) != len(retrievers):
msg = (
"Length of weights must match number of retrievers "
f"(got {len(weights)} weights for {len(retrievers)} retrievers)."
)
raise ValueError(msg)
if not any(w > 0 for w in weights):
msg = "At least one ensemble weight must be greater than zero."
raise ValueError(msg)
return values
@override
def invoke(
self,
input: str,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> list[Document]:
from langchain_core.callbacks import CallbackManager
config = ensure_config(config)
callback_manager = CallbackManager.configure(
config.get("callbacks"),
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=config.get("tags", []),
local_tags=self.tags,
inheritable_metadata=config.get("metadata", {}),
local_metadata=self.metadata,
)
run_manager = callback_manager.on_retriever_start(
None,
input,
name=config.get("run_name") or self.get_name(),
**kwargs,
)
try:
result = self.rank_fusion(input, run_manager=run_manager, config=config)
except Exception as e:
run_manager.on_retriever_error(e)
raise
else:
run_manager.on_retriever_end(
result,
**kwargs,
)
return result
@override
async def ainvoke(
self,
input: str,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> list[Document]:
from langchain_core.callbacks import AsyncCallbackManager
config = ensure_config(config)
callback_manager = AsyncCallbackManager.configure(
config.get("callbacks"),
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=config.get("tags", []),
local_tags=self.tags,
inheritable_metadata=config.get("metadata", {}),
local_metadata=self.metadata,
)
run_manager = await callback_manager.on_retriever_start(
None,
input,
name=config.get("run_name") or self.get_name(),
**kwargs,
)
try:
result = await self.arank_fusion(
input,
run_manager=run_manager,
config=config,
)
except Exception as e:
await run_manager.on_retriever_error(e)
raise
else:
await run_manager.on_retriever_end(
result,
**kwargs,
)
return result
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> list[Document]:
"""Get the relevant documents for a given query.
Args:
query: The query to search for.
run_manager: The callback handler to use.
Returns:
A list of reranked documents.
"""
# Get fused result of the retrievers.
return self.rank_fusion(query, run_manager)
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> list[Document]:
"""Asynchronously get the relevant documents for a given query.
Args:
query: The query to search for.
run_manager: The callback handler to use.
Returns:
A list of reranked documents.
"""
# Get fused result of the retrievers.
return await self.arank_fusion(query, run_manager)
def rank_fusion(
self,
query: str,
run_manager: CallbackManagerForRetrieverRun,
*,
config: RunnableConfig | None = None,
) -> list[Document]:
"""Rank fusion.
Retrieve the results of the retrievers and use rank_fusion_func to get
the final result.
Args:
query: The query to search for.
run_manager: The callback handler to use.
config: Optional configuration for the retrievers.
Returns:
A list of reranked documents.
"""
# Get the results of all retrievers.
retriever_docs = [
retriever.invoke(
query,
patch_config(
config,
callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"),
),
)
for i, retriever in enumerate(self.retrievers)
]
# Enforce that retrieved docs are Documents for each list in retriever_docs
for i in range(len(retriever_docs)):
retriever_docs[i] = [
Document(page_content=cast("str", doc)) if isinstance(doc, str) else doc # type: ignore[unreachable]
for doc in retriever_docs[i]
]
# apply rank fusion
return self.weighted_reciprocal_rank(retriever_docs)
async def arank_fusion(
self,
query: str,
run_manager: AsyncCallbackManagerForRetrieverRun,
*,
config: RunnableConfig | None = None,
) -> list[Document]:
"""Rank fusion.
Asynchronously retrieve the results of the retrievers
and use rank_fusion_func to get the final result.
Args:
query: The query to search for.
run_manager: The callback handler to use.
config: Optional configuration for the retrievers.
Returns:
A list of reranked documents.
"""
# Get the results of all retrievers.
retriever_docs = await asyncio.gather(
*[
retriever.ainvoke(
query,
patch_config(
config,
callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"),
),
)
for i, retriever in enumerate(self.retrievers)
],
)
# Enforce that retrieved docs are Documents for each list in retriever_docs
for i in range(len(retriever_docs)):
retriever_docs[i] = [
Document(page_content=doc) if not isinstance(doc, Document) else doc
for doc in retriever_docs[i]
]
# apply rank fusion
return self.weighted_reciprocal_rank(retriever_docs)
def weighted_reciprocal_rank(
self,
doc_lists: list[list[Document]],
) -> list[Document]:
"""Perform weighted Reciprocal Rank Fusion on multiple rank lists.
You can find more details about RRF here:
https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf.
Args:
doc_lists: A list of rank lists, where each rank list contains unique items.
Returns:
The final aggregated list of items sorted by their weighted RRF
scores in descending order.
"""
if len(doc_lists) != len(self.weights):
msg = "Number of rank lists must be equal to the number of weights."
raise ValueError(msg)
# Associate each doc's content with its RRF score for later sorting by it
# Duplicated contents across retrievers are collapsed & scored cumulatively
rrf_score: dict[str, float] = defaultdict(float)
for doc_list, weight in zip(doc_lists, self.weights, strict=False):
for rank, doc in enumerate(doc_list, start=1):
rrf_score[
(
doc.page_content
if self.id_key is None
else doc.metadata[self.id_key]
)
] += weight / (rank + self.c)
# Docs are deduplicated by their contents then sorted by their scores
all_docs = chain.from_iterable(doc_lists)
return sorted(
unique_by_key(
all_docs,
lambda doc: (
doc.page_content
if self.id_key is None
else doc.metadata[self.id_key]
),
),
reverse=True,
key=lambda doc: rrf_score[
doc.page_content if self.id_key is None else doc.metadata[self.id_key]
],
)

View File

@@ -0,0 +1,25 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import GoogleDocumentAIWarehouseRetriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
"GoogleDocumentAIWarehouseRetriever": "langchain_community.retrievers",
}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"GoogleDocumentAIWarehouseRetriever",
]

View File

@@ -0,0 +1,33 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import (
GoogleCloudEnterpriseSearchRetriever,
GoogleVertexAIMultiTurnSearchRetriever,
GoogleVertexAISearchRetriever,
)
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
"GoogleVertexAISearchRetriever": "langchain_community.retrievers",
"GoogleVertexAIMultiTurnSearchRetriever": "langchain_community.retrievers",
"GoogleCloudEnterpriseSearchRetriever": "langchain_community.retrievers",
}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"GoogleCloudEnterpriseSearchRetriever",
"GoogleVertexAIMultiTurnSearchRetriever",
"GoogleVertexAISearchRetriever",
]

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import KayAiRetriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"KayAiRetriever": "langchain_community.retrievers"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"KayAiRetriever",
]

View File

@@ -0,0 +1,66 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import AmazonKendraRetriever
from langchain_community.retrievers.kendra import (
AdditionalResultAttribute,
AdditionalResultAttributeValue,
DocumentAttribute,
DocumentAttributeValue,
Highlight,
QueryResult,
QueryResultItem,
ResultItem,
RetrieveResult,
RetrieveResultItem,
TextWithHighLights,
clean_excerpt,
combined_text,
)
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
"clean_excerpt": "langchain_community.retrievers.kendra",
"combined_text": "langchain_community.retrievers.kendra",
"Highlight": "langchain_community.retrievers.kendra",
"TextWithHighLights": "langchain_community.retrievers.kendra",
"AdditionalResultAttributeValue": "langchain_community.retrievers.kendra",
"AdditionalResultAttribute": "langchain_community.retrievers.kendra",
"DocumentAttributeValue": "langchain_community.retrievers.kendra",
"DocumentAttribute": "langchain_community.retrievers.kendra",
"ResultItem": "langchain_community.retrievers.kendra",
"QueryResultItem": "langchain_community.retrievers.kendra",
"RetrieveResultItem": "langchain_community.retrievers.kendra",
"QueryResult": "langchain_community.retrievers.kendra",
"RetrieveResult": "langchain_community.retrievers.kendra",
"AmazonKendraRetriever": "langchain_community.retrievers",
}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"AdditionalResultAttribute",
"AdditionalResultAttributeValue",
"AmazonKendraRetriever",
"DocumentAttribute",
"DocumentAttributeValue",
"Highlight",
"QueryResult",
"QueryResultItem",
"ResultItem",
"RetrieveResult",
"RetrieveResultItem",
"TextWithHighLights",
"clean_excerpt",
"combined_text",
]

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import KNNRetriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"KNNRetriever": "langchain_community.retrievers"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"KNNRetriever",
]

View File

@@ -0,0 +1,30 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import (
LlamaIndexGraphRetriever,
LlamaIndexRetriever,
)
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
"LlamaIndexRetriever": "langchain_community.retrievers",
"LlamaIndexGraphRetriever": "langchain_community.retrievers",
}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"LlamaIndexGraphRetriever",
"LlamaIndexRetriever",
]

View File

@@ -0,0 +1,119 @@
import asyncio
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class MergerRetriever(BaseRetriever):
"""Retriever that merges the results of multiple retrievers."""
retrievers: list[BaseRetriever]
"""A list of retrievers to merge."""
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> list[Document]:
"""Get the relevant documents for a given query.
Args:
query: The query to search for.
run_manager: The callback handler to use.
Returns:
A list of relevant documents.
"""
# Merge the results of the retrievers.
return self.merge_documents(query, run_manager)
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> list[Document]:
"""Asynchronously get the relevant documents for a given query.
Args:
query: The query to search for.
run_manager: The callback handler to use.
Returns:
A list of relevant documents.
"""
# Merge the results of the retrievers.
return await self.amerge_documents(query, run_manager)
def merge_documents(
self,
query: str,
run_manager: CallbackManagerForRetrieverRun,
) -> list[Document]:
"""Merge the results of the retrievers.
Args:
query: The query to search for.
run_manager: The callback handler to use.
Returns:
A list of merged documents.
"""
# Get the results of all retrievers.
retriever_docs = [
retriever.invoke(
query,
config={"callbacks": run_manager.get_child(f"retriever_{i + 1}")},
)
for i, retriever in enumerate(self.retrievers)
]
# Merge the results of the retrievers.
merged_documents = []
max_docs = max(map(len, retriever_docs), default=0)
for i in range(max_docs):
for _retriever, doc in zip(self.retrievers, retriever_docs, strict=False):
if i < len(doc):
merged_documents.append(doc[i])
return merged_documents
async def amerge_documents(
self,
query: str,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> list[Document]:
"""Asynchronously merge the results of the retrievers.
Args:
query: The query to search for.
run_manager: The callback handler to use.
Returns:
A list of merged documents.
"""
# Get the results of all retrievers.
retriever_docs = await asyncio.gather(
*(
retriever.ainvoke(
query,
config={"callbacks": run_manager.get_child(f"retriever_{i + 1}")},
)
for i, retriever in enumerate(self.retrievers)
),
)
# Merge the results of the retrievers.
merged_documents = []
max_docs = max(map(len, retriever_docs), default=0)
for i in range(max_docs):
for _retriever, doc in zip(self.retrievers, retriever_docs, strict=False):
if i < len(doc):
merged_documents.append(doc[i])
return merged_documents

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import MetalRetriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"MetalRetriever": "langchain_community.retrievers"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"MetalRetriever",
]

View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import MilvusRetriever
from langchain_community.retrievers.milvus import MilvusRetreiver
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {
"MilvusRetriever": "langchain_community.retrievers",
"MilvusRetreiver": "langchain_community.retrievers.milvus",
}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"MilvusRetreiver",
"MilvusRetriever",
]

View File

@@ -0,0 +1,240 @@
import asyncio
import logging
from collections.abc import Sequence
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from typing_extensions import override
from langchain_classic.chains.llm import LLMChain
logger = logging.getLogger(__name__)
class LineListOutputParser(BaseOutputParser[list[str]]):
"""Output parser for a list of lines."""
@override
def parse(self, text: str) -> list[str]:
lines = text.strip().split("\n")
return list(filter(None, lines)) # Remove empty lines
# Default prompt
DEFAULT_QUERY_PROMPT = PromptTemplate(
input_variables=["question"],
template="""You are an AI language model assistant. Your task is
to generate 3 different versions of the given user
question to retrieve relevant documents from a vector database.
By generating multiple perspectives on the user question,
your goal is to help the user overcome some of the limitations
of distance-based similarity search. Provide these alternative
questions separated by newlines. Original question: {question}""",
)
def _unique_documents(documents: Sequence[Document]) -> list[Document]:
return [doc for i, doc in enumerate(documents) if doc not in documents[:i]]
class MultiQueryRetriever(BaseRetriever):
"""Given a query, use an LLM to write a set of queries.
Retrieve docs for each query. Return the unique union of all retrieved docs.
"""
retriever: BaseRetriever
llm_chain: Runnable
verbose: bool = True
parser_key: str = "lines"
"""DEPRECATED. parser_key is no longer used and should not be specified."""
include_original: bool = False
"""Whether to include the original query in the list of generated queries."""
@classmethod
def from_llm(
cls,
retriever: BaseRetriever,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = DEFAULT_QUERY_PROMPT,
parser_key: str | None = None, # noqa: ARG003
include_original: bool = False, # noqa: FBT001,FBT002
) -> "MultiQueryRetriever":
"""Initialize from llm using default template.
Args:
retriever: retriever to query documents from
llm: llm for query generation using DEFAULT_QUERY_PROMPT
prompt: The prompt which aims to generate several different versions
of the given user query
parser_key: DEPRECATED. `parser_key` is no longer used and should not be
specified.
include_original: Whether to include the original query in the list of
generated queries.
Returns:
MultiQueryRetriever
"""
output_parser = LineListOutputParser()
llm_chain = prompt | llm | output_parser
return cls(
retriever=retriever,
llm_chain=llm_chain,
include_original=include_original,
)
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> list[Document]:
"""Get relevant documents given a user query.
Args:
query: user query
run_manager: the callback handler to use.
Returns:
Unique union of relevant documents from all generated queries
"""
queries = await self.agenerate_queries(query, run_manager)
if self.include_original:
queries.append(query)
documents = await self.aretrieve_documents(queries, run_manager)
return self.unique_union(documents)
async def agenerate_queries(
self,
question: str,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> list[str]:
"""Generate queries based upon user input.
Args:
question: user query
run_manager: the callback handler to use.
Returns:
List of LLM generated queries that are similar to the user input
"""
response = await self.llm_chain.ainvoke(
{"question": question},
config={"callbacks": run_manager.get_child()},
)
lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
if self.verbose:
logger.info("Generated queries: %s", lines)
return lines
async def aretrieve_documents(
self,
queries: list[str],
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> list[Document]:
"""Run all LLM generated queries.
Args:
queries: query list
run_manager: the callback handler to use
Returns:
List of retrieved Documents
"""
document_lists = await asyncio.gather(
*(
self.retriever.ainvoke(
query,
config={"callbacks": run_manager.get_child()},
)
for query in queries
),
)
return [doc for docs in document_lists for doc in docs]
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> list[Document]:
"""Get relevant documents given a user query.
Args:
query: user query
run_manager: the callback handler to use.
Returns:
Unique union of relevant documents from all generated queries
"""
queries = self.generate_queries(query, run_manager)
if self.include_original:
queries.append(query)
documents = self.retrieve_documents(queries, run_manager)
return self.unique_union(documents)
def generate_queries(
self,
question: str,
run_manager: CallbackManagerForRetrieverRun,
) -> list[str]:
"""Generate queries based upon user input.
Args:
question: user query
run_manager: run manager for callbacks
Returns:
List of LLM generated queries that are similar to the user input
"""
response = self.llm_chain.invoke(
{"question": question},
config={"callbacks": run_manager.get_child()},
)
lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
if self.verbose:
logger.info("Generated queries: %s", lines)
return lines
def retrieve_documents(
self,
queries: list[str],
run_manager: CallbackManagerForRetrieverRun,
) -> list[Document]:
"""Run all LLM generated queries.
Args:
queries: query list
run_manager: run manager for callbacks
Returns:
List of retrieved Documents
"""
documents = []
for query in queries:
docs = self.retriever.invoke(
query,
config={"callbacks": run_manager.get_child()},
)
documents.extend(docs)
return documents
def unique_union(self, documents: list[Document]) -> list[Document]:
"""Get unique Documents.
Args:
documents: List of retrieved Documents
Returns:
List of unique retrieved Documents
"""
return _unique_documents(documents)

View File

@@ -0,0 +1,155 @@
from enum import Enum
from typing import Any
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.stores import BaseStore, ByteStore
from langchain_core.vectorstores import VectorStore
from pydantic import Field, model_validator
from typing_extensions import override
from langchain_classic.storage._lc_store import create_kv_docstore
class SearchType(str, Enum):
"""Enumerator of the types of search to perform."""
similarity = "similarity"
"""Similarity search."""
similarity_score_threshold = "similarity_score_threshold"
"""Similarity search with a score threshold."""
mmr = "mmr"
"""Maximal Marginal Relevance reranking of similarity search."""
class MultiVectorRetriever(BaseRetriever):
"""Retriever that supports multiple embeddings per parent document.
This retriever is designed for scenarios where documents are split into
smaller chunks for embedding and vector search, but retrieval returns
the original parent documents rather than individual chunks.
It works by:
- Performing similarity (or MMR) search over embedded child chunks
- Collecting unique parent document IDs from chunk metadata
- Fetching and returning the corresponding parent documents from the docstore
This pattern is commonly used in RAG pipelines to improve answer grounding
while preserving full document context.
"""
vectorstore: VectorStore
"""The underlying `VectorStore` to use to store small chunks
and their embedding vectors"""
byte_store: ByteStore | None = None
"""The lower-level backing storage layer for the parent documents"""
docstore: BaseStore[str, Document]
"""The storage interface for the parent documents"""
id_key: str = "doc_id"
search_kwargs: dict = Field(default_factory=dict)
"""Keyword arguments to pass to the search function."""
search_type: SearchType = SearchType.similarity
"""Type of search to perform (similarity / mmr)"""
@model_validator(mode="before")
@classmethod
def _shim_docstore(cls, values: dict) -> Any:
byte_store = values.get("byte_store")
docstore = values.get("docstore")
if byte_store is not None:
docstore = create_kv_docstore(byte_store)
elif docstore is None:
msg = "You must pass a `byte_store` parameter."
raise ValueError(msg)
values["docstore"] = docstore
return values
@override
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> list[Document]:
"""Get documents relevant to a query.
Args:
query: String to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents.
"""
if self.search_type == SearchType.mmr:
sub_docs = self.vectorstore.max_marginal_relevance_search(
query,
**self.search_kwargs,
)
elif self.search_type == SearchType.similarity_score_threshold:
sub_docs_and_similarities = (
self.vectorstore.similarity_search_with_relevance_scores(
query,
**self.search_kwargs,
)
)
sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
else:
sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
# We do this to maintain the order of the IDs that are returned
ids = []
for d in sub_docs:
if self.id_key in d.metadata and d.metadata[self.id_key] not in ids:
ids.append(d.metadata[self.id_key])
docs = self.docstore.mget(ids)
return [d for d in docs if d is not None]
@override
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> list[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: String to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents.
"""
if self.search_type == SearchType.mmr:
sub_docs = await self.vectorstore.amax_marginal_relevance_search(
query,
**self.search_kwargs,
)
elif self.search_type == SearchType.similarity_score_threshold:
sub_docs_and_similarities = (
await self.vectorstore.asimilarity_search_with_relevance_scores(
query,
**self.search_kwargs,
)
)
sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
else:
sub_docs = await self.vectorstore.asimilarity_search(
query,
**self.search_kwargs,
)
# We do this to maintain the order of the IDs that are returned
ids = []
for d in sub_docs:
if self.id_key in d.metadata and d.metadata[self.id_key] not in ids:
ids.append(d.metadata[self.id_key])
docs = await self.docstore.amget(ids)
return [d for d in docs if d is not None]

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import OutlineRetriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"OutlineRetriever": "langchain_community.retrievers"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"OutlineRetriever",
]

View File

@@ -0,0 +1,176 @@
import uuid
from collections.abc import Sequence
from typing import Any
from langchain_core.documents import Document
from langchain_text_splitters import TextSplitter
from langchain_classic.retrievers import MultiVectorRetriever
class ParentDocumentRetriever(MultiVectorRetriever):
"""Retrieve small chunks then retrieve their parent documents.
When splitting documents for retrieval, there are often conflicting desires:
1. You may want to have small documents, so that their embeddings can most
accurately reflect their meaning. If too long, then the embeddings can
lose meaning.
2. You want to have long enough documents that the context of each chunk is
retained.
The ParentDocumentRetriever strikes that balance by splitting and storing
small chunks of data. During retrieval, it first fetches the small chunks
but then looks up the parent IDs for those chunks and returns those larger
documents.
Note that "parent document" refers to the document that a small chunk
originated from. This can either be the whole raw document OR a larger
chunk.
Examples:
```python
from langchain_chroma import Chroma
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_classic.storage import InMemoryStore
# This text splitter is used to create the parent documents
parent_splitter = RecursiveCharacterTextSplitter(
chunk_size=2000, add_start_index=True
)
# This text splitter is used to create the child documents
# It should create documents smaller than the parent
child_splitter = RecursiveCharacterTextSplitter(
chunk_size=400, add_start_index=True
)
# The VectorStore to use to index the child chunks
vectorstore = Chroma(embedding_function=OpenAIEmbeddings())
# The storage layer for the parent documents
store = InMemoryStore()
# Initialize the retriever
retriever = ParentDocumentRetriever(
vectorstore=vectorstore,
docstore=store,
child_splitter=child_splitter,
parent_splitter=parent_splitter,
)
```
"""
child_splitter: TextSplitter
"""The text splitter to use to create child documents."""
"""The key to use to track the parent id. This will be stored in the
metadata of child documents."""
parent_splitter: TextSplitter | None = None
"""The text splitter to use to create parent documents.
If none, then the parent documents will be the raw documents passed in."""
child_metadata_fields: Sequence[str] | None = None
"""Metadata fields to leave in child documents. If `None`, leave all parent document
metadata.
"""
def _split_docs_for_adding(
self,
documents: list[Document],
ids: list[str] | None = None,
*,
add_to_docstore: bool = True,
) -> tuple[list[Document], list[tuple[str, Document]]]:
if self.parent_splitter is not None:
documents = self.parent_splitter.split_documents(documents)
if ids is None:
doc_ids = [str(uuid.uuid4()) for _ in documents]
if not add_to_docstore:
msg = "If IDs are not passed in, `add_to_docstore` MUST be True"
raise ValueError(msg)
else:
if len(documents) != len(ids):
msg = (
"Got uneven list of documents and ids. "
"If `ids` is provided, should be same length as `documents`."
)
raise ValueError(msg)
doc_ids = ids
docs = []
full_docs = []
for i, doc in enumerate(documents):
_id = doc_ids[i]
sub_docs = self.child_splitter.split_documents([doc])
if self.child_metadata_fields is not None:
for _doc in sub_docs:
_doc.metadata = {
k: _doc.metadata[k] for k in self.child_metadata_fields
}
for _doc in sub_docs:
_doc.metadata[self.id_key] = _id
docs.extend(sub_docs)
full_docs.append((_id, doc))
return docs, full_docs
def add_documents(
self,
documents: list[Document],
ids: list[str] | None = None,
add_to_docstore: bool = True, # noqa: FBT001,FBT002
**kwargs: Any,
) -> None:
"""Adds documents to the docstore and vectorstores.
Args:
documents: List of documents to add
ids: Optional list of IDs for documents. If provided should be the same
length as the list of documents. Can be provided if parent documents
are already in the document store and you don't want to re-add
to the docstore. If not provided, random UUIDs will be used as
IDs.
add_to_docstore: Boolean of whether to add documents to docstore.
This can be false if and only if `ids` are provided. You may want
to set this to False if the documents are already in the docstore
and you don't want to re-add them.
**kwargs: additional keyword arguments passed to the `VectorStore`.
"""
docs, full_docs = self._split_docs_for_adding(
documents,
ids,
add_to_docstore=add_to_docstore,
)
self.vectorstore.add_documents(docs, **kwargs)
if add_to_docstore:
self.docstore.mset(full_docs)
async def aadd_documents(
self,
documents: list[Document],
ids: list[str] | None = None,
add_to_docstore: bool = True, # noqa: FBT001,FBT002
**kwargs: Any,
) -> None:
"""Adds documents to the docstore and vectorstores.
Args:
documents: List of documents to add
ids: Optional list of IDs for documents. If provided should be the same
length as the list of documents. Can be provided if parent documents
are already in the document store and you don't want to re-add
to the docstore. If not provided, random UUIDs will be used as
idIDss.
add_to_docstore: Boolean of whether to add documents to docstore.
This can be false if and only if `ids` are provided. You may want
to set this to False if the documents are already in the docstore
and you don't want to re-add them.
**kwargs: additional keyword arguments passed to the `VectorStore`.
"""
docs, full_docs = self._split_docs_for_adding(
documents,
ids,
add_to_docstore=add_to_docstore,
)
await self.vectorstore.aadd_documents(docs, **kwargs)
if add_to_docstore:
await self.docstore.amset(full_docs)

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import PineconeHybridSearchRetriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"PineconeHybridSearchRetriever": "langchain_community.retrievers"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"PineconeHybridSearchRetriever",
]

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import PubMedRetriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"PubMedRetriever": "langchain_community.retrievers"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"PubMedRetriever",
]

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING, Any
from langchain_classic._api import create_importer
if TYPE_CHECKING:
from langchain_community.retrievers import PubMedRetriever
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
DEPRECATED_LOOKUP = {"PubMedRetriever": "langchain_community.retrievers"}
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
def __getattr__(name: str) -> Any:
"""Look up attributes dynamically."""
return _import_attribute(name)
__all__ = [
"PubMedRetriever",
]

View File

@@ -0,0 +1,92 @@
import logging
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
logger = logging.getLogger(__name__)
# Default template
DEFAULT_TEMPLATE = """You are an assistant tasked with taking a natural language \
query from a user and converting it into a query for a vectorstore. \
In this process, you strip out information that is not relevant for \
the retrieval task. Here is the user query: {question}"""
# Default prompt
DEFAULT_QUERY_PROMPT = PromptTemplate.from_template(DEFAULT_TEMPLATE)
class RePhraseQueryRetriever(BaseRetriever):
"""Given a query, use an LLM to re-phrase it.
Then, retrieve docs for the re-phrased query.
"""
retriever: BaseRetriever
llm_chain: Runnable
@classmethod
def from_llm(
cls,
retriever: BaseRetriever,
llm: BaseLLM,
prompt: BasePromptTemplate = DEFAULT_QUERY_PROMPT,
) -> "RePhraseQueryRetriever":
"""Initialize from llm using default template.
The prompt used here expects a single input: `question`
Args:
retriever: retriever to query documents from
llm: llm for query generation using DEFAULT_QUERY_PROMPT
prompt: prompt template for query generation
Returns:
RePhraseQueryRetriever
"""
llm_chain = prompt | llm | StrOutputParser()
return cls(
retriever=retriever,
llm_chain=llm_chain,
)
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> list[Document]:
"""Get relevant documents given a user question.
Args:
query: user question
run_manager: callback handler to use
Returns:
Relevant documents for re-phrased question
"""
re_phrased_question = self.llm_chain.invoke(
query,
{"callbacks": run_manager.get_child()},
)
logger.info("Re-phrased question: %s", re_phrased_question)
return self.retriever.invoke(
re_phrased_question,
config={"callbacks": run_manager.get_child()},
)
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> list[Document]:
raise NotImplementedError

Some files were not shown because too many files have changed in this diff Show More