initial commit
This commit is contained in:
169
venv/Lib/site-packages/langchain_classic/retrievers/__init__.py
Normal file
169
venv/Lib/site-packages/langchain_classic/retrievers/__init__.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""**Retriever** class returns Documents given a text **query**.
|
||||
|
||||
It is more general than a vector store. A retriever does not need to be able to
|
||||
store documents, only to return (or retrieve) them. Vector stores can be used as
|
||||
the backbone of a retriever, but there are other types of retrievers as well.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api.module_import import create_importer
|
||||
from langchain_classic.retrievers.contextual_compression import (
|
||||
ContextualCompressionRetriever,
|
||||
)
|
||||
from langchain_classic.retrievers.ensemble import EnsembleRetriever
|
||||
from langchain_classic.retrievers.merger_retriever import MergerRetriever
|
||||
from langchain_classic.retrievers.multi_query import MultiQueryRetriever
|
||||
from langchain_classic.retrievers.multi_vector import MultiVectorRetriever
|
||||
from langchain_classic.retrievers.parent_document_retriever import (
|
||||
ParentDocumentRetriever,
|
||||
)
|
||||
from langchain_classic.retrievers.re_phraser import RePhraseQueryRetriever
|
||||
from langchain_classic.retrievers.self_query.base import SelfQueryRetriever
|
||||
from langchain_classic.retrievers.time_weighted_retriever import (
|
||||
TimeWeightedVectorStoreRetriever,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import (
|
||||
AmazonKendraRetriever,
|
||||
AmazonKnowledgeBasesRetriever,
|
||||
ArceeRetriever,
|
||||
ArxivRetriever,
|
||||
AzureAISearchRetriever,
|
||||
AzureCognitiveSearchRetriever,
|
||||
BM25Retriever,
|
||||
ChaindeskRetriever,
|
||||
ChatGPTPluginRetriever,
|
||||
CohereRagRetriever,
|
||||
DocArrayRetriever,
|
||||
DriaRetriever,
|
||||
ElasticSearchBM25Retriever,
|
||||
EmbedchainRetriever,
|
||||
GoogleCloudEnterpriseSearchRetriever,
|
||||
GoogleDocumentAIWarehouseRetriever,
|
||||
GoogleVertexAIMultiTurnSearchRetriever,
|
||||
GoogleVertexAISearchRetriever,
|
||||
KayAiRetriever,
|
||||
KNNRetriever,
|
||||
LlamaIndexGraphRetriever,
|
||||
LlamaIndexRetriever,
|
||||
MetalRetriever,
|
||||
MilvusRetriever,
|
||||
NeuralDBRetriever,
|
||||
OutlineRetriever,
|
||||
PineconeHybridSearchRetriever,
|
||||
PubMedRetriever,
|
||||
RemoteLangChainRetriever,
|
||||
SVMRetriever,
|
||||
TavilySearchAPIRetriever,
|
||||
TFIDFRetriever,
|
||||
VespaRetriever,
|
||||
WeaviateHybridSearchRetriever,
|
||||
WebResearchRetriever,
|
||||
WikipediaRetriever,
|
||||
ZepRetriever,
|
||||
ZillizRetriever,
|
||||
)
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"AmazonKendraRetriever": "langchain_community.retrievers",
|
||||
"AmazonKnowledgeBasesRetriever": "langchain_community.retrievers",
|
||||
"ArceeRetriever": "langchain_community.retrievers",
|
||||
"ArxivRetriever": "langchain_community.retrievers",
|
||||
"AzureAISearchRetriever": "langchain_community.retrievers",
|
||||
"AzureCognitiveSearchRetriever": "langchain_community.retrievers",
|
||||
"ChatGPTPluginRetriever": "langchain_community.retrievers",
|
||||
"ChaindeskRetriever": "langchain_community.retrievers",
|
||||
"CohereRagRetriever": "langchain_community.retrievers",
|
||||
"ElasticSearchBM25Retriever": "langchain_community.retrievers",
|
||||
"EmbedchainRetriever": "langchain_community.retrievers",
|
||||
"GoogleDocumentAIWarehouseRetriever": "langchain_community.retrievers",
|
||||
"GoogleCloudEnterpriseSearchRetriever": "langchain_community.retrievers",
|
||||
"GoogleVertexAIMultiTurnSearchRetriever": "langchain_community.retrievers",
|
||||
"GoogleVertexAISearchRetriever": "langchain_community.retrievers",
|
||||
"KayAiRetriever": "langchain_community.retrievers",
|
||||
"KNNRetriever": "langchain_community.retrievers",
|
||||
"LlamaIndexGraphRetriever": "langchain_community.retrievers",
|
||||
"LlamaIndexRetriever": "langchain_community.retrievers",
|
||||
"MetalRetriever": "langchain_community.retrievers",
|
||||
"MilvusRetriever": "langchain_community.retrievers",
|
||||
"OutlineRetriever": "langchain_community.retrievers",
|
||||
"PineconeHybridSearchRetriever": "langchain_community.retrievers",
|
||||
"PubMedRetriever": "langchain_community.retrievers",
|
||||
"RemoteLangChainRetriever": "langchain_community.retrievers",
|
||||
"SVMRetriever": "langchain_community.retrievers",
|
||||
"TavilySearchAPIRetriever": "langchain_community.retrievers",
|
||||
"BM25Retriever": "langchain_community.retrievers",
|
||||
"DriaRetriever": "langchain_community.retrievers",
|
||||
"NeuralDBRetriever": "langchain_community.retrievers",
|
||||
"TFIDFRetriever": "langchain_community.retrievers",
|
||||
"VespaRetriever": "langchain_community.retrievers",
|
||||
"WeaviateHybridSearchRetriever": "langchain_community.retrievers",
|
||||
"WebResearchRetriever": "langchain_community.retrievers",
|
||||
"WikipediaRetriever": "langchain_community.retrievers",
|
||||
"ZepRetriever": "langchain_community.retrievers",
|
||||
"ZillizRetriever": "langchain_community.retrievers",
|
||||
"DocArrayRetriever": "langchain_community.retrievers",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve a module attribute on demand.

    Python calls this hook only for names that are not already defined in
    the module; the lookup is handed to the importer that was built by
    ``create_importer`` from ``DEPRECATED_LOOKUP``.

    Args:
        name: Attribute name being requested from this package.

    Returns:
        Whatever the importer returns for ``name``.
    """
    resolved = _import_attribute(name)
    return resolved
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AmazonKendraRetriever",
|
||||
"AmazonKnowledgeBasesRetriever",
|
||||
"ArceeRetriever",
|
||||
"ArxivRetriever",
|
||||
"AzureAISearchRetriever",
|
||||
"AzureCognitiveSearchRetriever",
|
||||
"BM25Retriever",
|
||||
"ChaindeskRetriever",
|
||||
"ChatGPTPluginRetriever",
|
||||
"CohereRagRetriever",
|
||||
"ContextualCompressionRetriever",
|
||||
"DocArrayRetriever",
|
||||
"DriaRetriever",
|
||||
"ElasticSearchBM25Retriever",
|
||||
"EmbedchainRetriever",
|
||||
"EnsembleRetriever",
|
||||
"GoogleCloudEnterpriseSearchRetriever",
|
||||
"GoogleDocumentAIWarehouseRetriever",
|
||||
"GoogleVertexAIMultiTurnSearchRetriever",
|
||||
"GoogleVertexAISearchRetriever",
|
||||
"KNNRetriever",
|
||||
"KayAiRetriever",
|
||||
"LlamaIndexGraphRetriever",
|
||||
"LlamaIndexRetriever",
|
||||
"MergerRetriever",
|
||||
"MetalRetriever",
|
||||
"MilvusRetriever",
|
||||
"MultiQueryRetriever",
|
||||
"MultiVectorRetriever",
|
||||
"NeuralDBRetriever",
|
||||
"OutlineRetriever",
|
||||
"ParentDocumentRetriever",
|
||||
"PineconeHybridSearchRetriever",
|
||||
"PubMedRetriever",
|
||||
"RePhraseQueryRetriever",
|
||||
"RemoteLangChainRetriever",
|
||||
"SVMRetriever",
|
||||
"SelfQueryRetriever",
|
||||
"TFIDFRetriever",
|
||||
"TavilySearchAPIRetriever",
|
||||
"TimeWeightedVectorStoreRetriever",
|
||||
"VespaRetriever",
|
||||
"WeaviateHybridSearchRetriever",
|
||||
"WebResearchRetriever",
|
||||
"WikipediaRetriever",
|
||||
"ZepRetriever",
|
||||
"ZillizRetriever",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
23
venv/Lib/site-packages/langchain_classic/retrievers/arcee.py
Normal file
23
venv/Lib/site-packages/langchain_classic/retrievers/arcee.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import ArceeRetriever
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {"ArceeRetriever": "langchain_community.retrievers"}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve a module attribute on demand.

    The request is forwarded to the importer created by ``create_importer``
    with ``DEPRECATED_LOOKUP`` (which maps ``ArceeRetriever`` to its
    ``langchain_community`` location).

    Args:
        name: Attribute name being requested from this module.

    Returns:
        Whatever the importer returns for ``name``.
    """
    resolved = _import_attribute(name)
    return resolved
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ArceeRetriever",
|
||||
]
|
||||
23
venv/Lib/site-packages/langchain_classic/retrievers/arxiv.py
Normal file
23
venv/Lib/site-packages/langchain_classic/retrievers/arxiv.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import ArxivRetriever
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {"ArxivRetriever": "langchain_community.retrievers"}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve a module attribute on demand.

    The request is forwarded to the importer created by ``create_importer``
    with ``DEPRECATED_LOOKUP`` (which maps ``ArxivRetriever`` to its
    ``langchain_community`` location).

    Args:
        name: Attribute name being requested from this module.

    Returns:
        Whatever the importer returns for ``name``.
    """
    resolved = _import_attribute(name)
    return resolved
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ArxivRetriever",
|
||||
]
|
||||
@@ -0,0 +1,30 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import (
|
||||
AzureAISearchRetriever,
|
||||
AzureCognitiveSearchRetriever,
|
||||
)
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"AzureAISearchRetriever": "langchain_community.retrievers",
|
||||
"AzureCognitiveSearchRetriever": "langchain_community.retrievers",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve a module attribute on demand.

    The request is forwarded to the importer created by ``create_importer``
    with ``DEPRECATED_LOOKUP`` (the two Azure search retrievers and their
    ``langchain_community`` location).

    Args:
        name: Attribute name being requested from this module.

    Returns:
        Whatever the importer returns for ``name``.
    """
    resolved = _import_attribute(name)
    return resolved
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AzureAISearchRetriever",
|
||||
"AzureCognitiveSearchRetriever",
|
||||
]
|
||||
@@ -0,0 +1,33 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import AmazonKnowledgeBasesRetriever
|
||||
from langchain_community.retrievers.bedrock import (
|
||||
RetrievalConfig,
|
||||
VectorSearchConfig,
|
||||
)
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"VectorSearchConfig": "langchain_community.retrievers.bedrock",
|
||||
"RetrievalConfig": "langchain_community.retrievers.bedrock",
|
||||
"AmazonKnowledgeBasesRetriever": "langchain_community.retrievers",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve a module attribute on demand.

    The request is forwarded to the importer created by ``create_importer``
    with ``DEPRECATED_LOOKUP`` (``AmazonKnowledgeBasesRetriever``,
    ``RetrievalConfig`` and ``VectorSearchConfig``).

    Args:
        name: Attribute name being requested from this module.

    Returns:
        Whatever the importer returns for ``name``.
    """
    resolved = _import_attribute(name)
    return resolved
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AmazonKnowledgeBasesRetriever",
|
||||
"RetrievalConfig",
|
||||
"VectorSearchConfig",
|
||||
]
|
||||
28
venv/Lib/site-packages/langchain_classic/retrievers/bm25.py
Normal file
28
venv/Lib/site-packages/langchain_classic/retrievers/bm25.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import BM25Retriever
|
||||
from langchain_community.retrievers.bm25 import default_preprocessing_func
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"default_preprocessing_func": "langchain_community.retrievers.bm25",
|
||||
"BM25Retriever": "langchain_community.retrievers",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve a module attribute on demand.

    The request is forwarded to the importer created by ``create_importer``
    with ``DEPRECATED_LOOKUP`` (``BM25Retriever`` and
    ``default_preprocessing_func``).

    Args:
        name: Attribute name being requested from this module.

    Returns:
        Whatever the importer returns for ``name``.
    """
    resolved = _import_attribute(name)
    return resolved
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BM25Retriever",
|
||||
"default_preprocessing_func",
|
||||
]
|
||||
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import ChaindeskRetriever
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {"ChaindeskRetriever": "langchain_community.retrievers"}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve a module attribute on demand.

    The request is forwarded to the importer created by ``create_importer``
    with ``DEPRECATED_LOOKUP`` (which maps ``ChaindeskRetriever`` to its
    ``langchain_community`` location).

    Args:
        name: Attribute name being requested from this module.

    Returns:
        Whatever the importer returns for ``name``.
    """
    resolved = _import_attribute(name)
    return resolved
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ChaindeskRetriever",
|
||||
]
|
||||
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import ChatGPTPluginRetriever
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {"ChatGPTPluginRetriever": "langchain_community.retrievers"}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve a module attribute on demand.

    The request is forwarded to the importer created by ``create_importer``
    with ``DEPRECATED_LOOKUP`` (which maps ``ChatGPTPluginRetriever`` to its
    ``langchain_community`` location).

    Args:
        name: Attribute name being requested from this module.

    Returns:
        Whatever the importer returns for ``name``.
    """
    resolved = _import_attribute(name)
    return resolved
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ChatGPTPluginRetriever",
|
||||
]
|
||||
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import CohereRagRetriever
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {"CohereRagRetriever": "langchain_community.retrievers"}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve a module attribute on demand.

    The request is forwarded to the importer created by ``create_importer``
    with ``DEPRECATED_LOOKUP`` (which maps ``CohereRagRetriever`` to its
    ``langchain_community`` location).

    Args:
        name: Attribute name being requested from this module.

    Returns:
        Whatever the importer returns for ``name``.
    """
    resolved = _import_attribute(name)
    return resolved
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CohereRagRetriever",
|
||||
]
|
||||
@@ -0,0 +1,68 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.retrievers import BaseRetriever, RetrieverLike
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class ContextualCompressionRetriever(BaseRetriever):
    """Retriever that wraps a base retriever and compresses the results."""

    # The field docstrings below are kept verbatim: pydantic can surface
    # them at runtime, so they are treated as part of the model definition.
    base_compressor: BaseDocumentCompressor
    """Compressor for compressing retrieved documents."""

    base_retriever: RetrieverLike
    """Base Retriever to use for getting relevant documents."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @override
    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> list[Document]:
        """Retrieve with ``base_retriever`` and compress the hits.

        Args:
            query: Text query forwarded to both the retriever and the
                compressor.
            run_manager: Callback manager for this run; a child manager is
                handed to the retriever call and to the compressor call.
            **kwargs: Extra keyword arguments passed through to
                ``base_retriever.invoke``.

        Returns:
            The compressed documents as a list, or an empty list when the
            base retriever returned nothing (the compressor is not called).
        """
        retrieved = self.base_retriever.invoke(
            query,
            config={"callbacks": run_manager.get_child()},
            **kwargs,
        )
        if not retrieved:
            return []
        compressed = self.base_compressor.compress_documents(
            retrieved,
            query,
            callbacks=run_manager.get_child(),
        )
        return list(compressed)

    @override
    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> list[Document]:
        """Async counterpart of ``_get_relevant_documents``.

        Args:
            query: Text query forwarded to both the retriever and the
                compressor.
            run_manager: Async callback manager for this run; a child
                manager is handed to the retriever and compressor calls.
            **kwargs: Extra keyword arguments passed through to
                ``base_retriever.ainvoke``.

        Returns:
            The compressed documents as a list, or an empty list when the
            base retriever returned nothing (the compressor is not called).
        """
        retrieved = await self.base_retriever.ainvoke(
            query,
            config={"callbacks": run_manager.get_child()},
            **kwargs,
        )
        if not retrieved:
            return []
        compressed = await self.base_compressor.acompress_documents(
            retrieved,
            query,
            callbacks=run_manager.get_child(),
        )
        return list(compressed)
|
||||
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers.databerry import DataberryRetriever
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {"DataberryRetriever": "langchain_community.retrievers.databerry"}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve a module attribute on demand.

    The request is forwarded to the importer created by ``create_importer``
    with ``DEPRECATED_LOOKUP`` (which maps ``DataberryRetriever`` to
    ``langchain_community.retrievers.databerry``).

    Args:
        name: Attribute name being requested from this module.

    Returns:
        Whatever the importer returns for ``name``.
    """
    resolved = _import_attribute(name)
    return resolved
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DataberryRetriever",
|
||||
]
|
||||
@@ -0,0 +1,28 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import DocArrayRetriever
|
||||
from langchain_community.retrievers.docarray import SearchType
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"SearchType": "langchain_community.retrievers.docarray",
|
||||
"DocArrayRetriever": "langchain_community.retrievers",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve a module attribute on demand.

    The request is forwarded to the importer created by ``create_importer``
    with ``DEPRECATED_LOOKUP`` (``DocArrayRetriever`` and ``SearchType``).

    Args:
        name: Attribute name being requested from this module.

    Returns:
        Whatever the importer returns for ``name``.
    """
    resolved = _import_attribute(name)
    return resolved
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DocArrayRetriever",
|
||||
"SearchType",
|
||||
]
|
||||
@@ -0,0 +1,46 @@
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
from langchain_classic.retrievers.document_compressors.base import (
|
||||
DocumentCompressorPipeline,
|
||||
)
|
||||
from langchain_classic.retrievers.document_compressors.chain_extract import (
|
||||
LLMChainExtractor,
|
||||
)
|
||||
from langchain_classic.retrievers.document_compressors.chain_filter import (
|
||||
LLMChainFilter,
|
||||
)
|
||||
from langchain_classic.retrievers.document_compressors.cohere_rerank import CohereRerank
|
||||
from langchain_classic.retrievers.document_compressors.cross_encoder_rerank import (
|
||||
CrossEncoderReranker,
|
||||
)
|
||||
from langchain_classic.retrievers.document_compressors.embeddings_filter import (
|
||||
EmbeddingsFilter,
|
||||
)
|
||||
from langchain_classic.retrievers.document_compressors.listwise_rerank import (
|
||||
LLMListwiseRerank,
|
||||
)
|
||||
|
||||
# Names served lazily by ``__getattr__`` below, mapped to the module that
# actually defines them.
_module_lookup = {
    "FlashrankRerank": "langchain_community.document_compressors.flashrank_rerank",
}


def __getattr__(name: str) -> Any:
    """Lazily import attributes listed in ``_module_lookup``.

    Args:
        name: Attribute name being requested from this package.

    Returns:
        The attribute called ``name`` taken from the module registered for
        it in ``_module_lookup``; that module is imported at call time.

    Raises:
        AttributeError: If ``name`` is not a key of ``_module_lookup``.
    """
    module_path = _module_lookup.get(name)
    if module_path is None:
        msg = f"module {__name__} has no attribute {name}"
        raise AttributeError(msg)
    return getattr(importlib.import_module(module_path), name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CohereRerank",
|
||||
"CrossEncoderReranker",
|
||||
"DocumentCompressorPipeline",
|
||||
"EmbeddingsFilter",
|
||||
"FlashrankRerank",
|
||||
"LLMChainExtractor",
|
||||
"LLMChainFilter",
|
||||
"LLMListwiseRerank",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,81 @@
|
||||
from collections.abc import Sequence
|
||||
from inspect import signature
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import (
|
||||
BaseDocumentCompressor,
|
||||
BaseDocumentTransformer,
|
||||
Document,
|
||||
)
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
class DocumentCompressorPipeline(BaseDocumentCompressor):
    """Document compressor that uses a pipeline of Transformers."""

    # Field docstring kept verbatim: pydantic can surface it at runtime.
    transformers: list[BaseDocumentTransformer | BaseDocumentCompressor]
    """List of document filters that are chained together and run in sequence."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Callbacks | None = None,
    ) -> Sequence[Document]:
        """Run the documents through every pipeline stage in order.

        Args:
            documents: Documents fed to the first stage; each stage receives
                the previous stage's output.
            query: Query passed to compressor stages (transformer stages do
                not receive it).
            callbacks: Forwarded to a compressor stage only when its
                ``compress_documents`` signature has a ``callbacks``
                parameter.

        Returns:
            The output of the last stage (the input unchanged if there are
            no stages).

        Raises:
            ValueError: If a stage is neither a ``BaseDocumentCompressor``
                nor a ``BaseDocumentTransformer``.
        """
        for stage in self.transformers:
            if isinstance(stage, BaseDocumentCompressor):
                # Only pass ``callbacks`` to implementations that declare it.
                declared = signature(stage.compress_documents).parameters
                extra = {"callbacks": callbacks} if "callbacks" in declared else {}
                documents = stage.compress_documents(documents, query, **extra)
                continue
            if isinstance(stage, BaseDocumentTransformer):
                documents = stage.transform_documents(documents)
                continue
            msg = f"Got unexpected transformer type: {stage}"  # type: ignore[unreachable]
            raise ValueError(msg)  # noqa: TRY004
        return documents

    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Callbacks | None = None,
    ) -> Sequence[Document]:
        """Async variant: await every pipeline stage in order.

        Args:
            documents: Documents fed to the first stage; each stage receives
                the previous stage's output.
            query: Query passed to compressor stages (transformer stages do
                not receive it).
            callbacks: Forwarded to a compressor stage only when its
                ``acompress_documents`` signature has a ``callbacks``
                parameter.

        Returns:
            The output of the last stage (the input unchanged if there are
            no stages).

        Raises:
            ValueError: If a stage is neither a ``BaseDocumentCompressor``
                nor a ``BaseDocumentTransformer``.
        """
        for stage in self.transformers:
            if isinstance(stage, BaseDocumentCompressor):
                # Only pass ``callbacks`` to implementations that declare it.
                declared = signature(stage.acompress_documents).parameters
                extra = {"callbacks": callbacks} if "callbacks" in declared else {}
                documents = await stage.acompress_documents(documents, query, **extra)
                continue
            if isinstance(stage, BaseDocumentTransformer):
                documents = await stage.atransform_documents(documents)
                continue
            msg = f"Got unexpected transformer type: {stage}"  # type: ignore[unreachable]
            raise ValueError(msg)  # noqa: TRY004
        return documents
|
||||
@@ -0,0 +1,126 @@
|
||||
"""DocumentFilter that uses an LLM chain to extract the relevant parts of documents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.runnables import Runnable
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_classic.retrievers.document_compressors.chain_extract_prompt import (
|
||||
prompt_template,
|
||||
)
|
||||
|
||||
|
||||
def default_get_input(query: str, doc: Document) -> dict[str, Any]:
    """Build the default input mapping for the compression chain.

    Args:
        query: The user question; stored under ``"question"``.
        doc: Document whose ``page_content`` is stored under ``"context"``.

    Returns:
        A dict with the keys ``"question"`` and ``"context"``.
    """
    page_text = doc.page_content
    return {"question": query, "context": page_text}
|
||||
|
||||
|
||||
class NoOutputParser(BaseOutputParser[str]):
    """Parse outputs that could return a null string of some sort."""

    # Sentinel the model is asked to emit when nothing is relevant.
    no_output_str: str = "NO_OUTPUT"

    @override
    def parse(self, text: str) -> str:
        """Strip ``text`` and map the sentinel to an empty string.

        Args:
            text: Raw model output.

        Returns:
            ``""`` when the stripped text equals ``no_output_str``,
            otherwise the stripped text.
        """
        stripped = text.strip()
        return "" if stripped == self.no_output_str else stripped
|
||||
|
||||
|
||||
def _get_default_chain_prompt() -> PromptTemplate:
    """Build the default extraction prompt.

    The module-level ``prompt_template`` is formatted with the parser's
    ``no_output_str`` sentinel, and the same ``NoOutputParser`` instance is
    attached to the prompt as its output parser.

    Returns:
        A ``PromptTemplate`` over the variables ``question`` and ``context``.
    """
    parser = NoOutputParser()
    return PromptTemplate(
        template=prompt_template.format(no_output_str=parser.no_output_str),
        input_variables=["question", "context"],
        output_parser=parser,
    )
|
||||
|
||||
|
||||
class LLMChainExtractor(BaseDocumentCompressor):
    """LLM Chain Extractor.

    Document compressor that uses an LLM chain to extract
    the relevant parts of documents.
    """

    llm_chain: Runnable
    """LLM wrapper to use for compressing documents."""

    get_input: Callable[[str, Document], dict] = default_get_input
    """Callable for constructing the chain input from the query and a Document."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def _extract_text(self, raw_output: Any) -> str:
        """Normalise one chain result to the extracted text.

        Shared by the sync and async paths so both treat a legacy
        ``LLMChain`` the same way.

        Args:
            raw_output: A single result produced by ``llm_chain``.

        Returns:
            For an ``LLMChain``: the value stored under its ``output_key``,
            run through the prompt's output parser when one is set.
            For any other runnable: ``raw_output`` unchanged.
        """
        if isinstance(self.llm_chain, LLMChain):
            # ``LLMChain`` results are indexed by ``output_key`` and are not
            # parsed by the chain itself, so apply the prompt's parser here.
            text = raw_output[self.llm_chain.output_key]
            parser = self.llm_chain.prompt.output_parser
            if parser is not None:
                text = parser.parse(text)
            return cast("str", text)
        return cast("str", raw_output)

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Callbacks | None = None,
    ) -> Sequence[Document]:
        """Compress page content of raw documents.

        Args:
            documents: Documents to extract from; the chain is invoked once
                per document.
            query: Query used to build each chain input via ``get_input``.
            callbacks: Callbacks passed to the chain through its config.

        Returns:
            New documents holding the extracted text and the original
            metadata; documents whose extraction is empty are dropped.
        """
        compressed_docs = []
        for doc in documents:
            chain_input = self.get_input(query, doc)
            raw_output = self.llm_chain.invoke(
                chain_input,
                config={"callbacks": callbacks},
            )
            text = self._extract_text(raw_output)
            if len(text) == 0:
                continue
            compressed_docs.append(
                Document(page_content=text, metadata=doc.metadata),
            )
        return compressed_docs

    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Callbacks | None = None,
    ) -> Sequence[Document]:
        """Compress page content of raw documents asynchronously.

        All chain inputs are submitted together through ``abatch``.

        Args:
            documents: Documents to extract from.
            query: Query used to build each chain input via ``get_input``.
            callbacks: Callbacks passed to the chain through its config.

        Returns:
            New documents holding the extracted text and the original
            metadata; documents whose extraction is empty are dropped.
        """
        inputs = [self.get_input(query, doc) for doc in documents]
        outputs = await self.llm_chain.abatch(inputs, {"callbacks": callbacks})
        compressed_docs = []
        for raw_output, doc in zip(outputs, documents, strict=False):
            # Same normalisation as the sync path, so a legacy ``LLMChain``
            # result is unwrapped and parsed instead of being used as-is.
            text = self._extract_text(raw_output)
            if len(text) == 0:
                continue
            compressed_docs.append(
                Document(page_content=text, metadata=doc.metadata),
            )
        return compressed_docs

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: PromptTemplate | None = None,
        get_input: Callable[[str, Document], dict] | None = None,
        llm_chain_kwargs: dict | None = None,  # noqa: ARG003
    ) -> LLMChainExtractor:
        """Initialize from LLM.

        Args:
            llm: Language model placed in the middle of the chain.
            prompt: Prompt to use; defaults to the module's default
                extraction prompt.
            get_input: Builds the chain input from ``(query, document)``;
                defaults to ``default_get_input``.
            llm_chain_kwargs: Accepted for backward compatibility; not used.

        Returns:
            An extractor whose chain is ``prompt | llm | parser``, where the
            parser is the prompt's output parser or, when it has none, a
            ``StrOutputParser``.
        """
        _prompt = prompt if prompt is not None else _get_default_chain_prompt()
        _get_input = get_input if get_input is not None else default_get_input
        parser = (
            _prompt.output_parser
            if _prompt.output_parser is not None
            else StrOutputParser()
        )
        llm_chain = _prompt | llm | parser
        return cls(llm_chain=llm_chain, get_input=_get_input)
|
||||
@@ -0,0 +1,10 @@
|
||||
prompt_template = """Given the following question and context, extract any part of the context *AS IS* that is relevant to answer the question. If none of the context is relevant return {no_output_str}.
|
||||
|
||||
Remember, *DO NOT* edit the extracted parts of the context.
|
||||
|
||||
> Question: {{question}}
|
||||
> Context:
|
||||
>>>
|
||||
{{context}}
|
||||
>>>
|
||||
Extracted relevant parts:""" # noqa: E501
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Filter that uses an LLM to drop documents that aren't relevant to the query."""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_classic.chains import LLMChain
|
||||
from langchain_classic.output_parsers.boolean import BooleanOutputParser
|
||||
from langchain_classic.retrievers.document_compressors.chain_filter_prompt import (
|
||||
prompt_template,
|
||||
)
|
||||
|
||||
|
||||
def _get_default_chain_prompt() -> PromptTemplate:
    """Build the default relevance-filter prompt.

    Returns:
        A ``PromptTemplate`` over the module-level ``prompt_template`` with
        the variables ``question`` and ``context`` and a
        ``BooleanOutputParser`` attached as its output parser.
    """
    boolean_parser = BooleanOutputParser()
    variables = ["question", "context"]
    return PromptTemplate(
        template=prompt_template,
        input_variables=variables,
        output_parser=boolean_parser,
    )
|
||||
|
||||
|
||||
def default_get_input(query: str, doc: Document) -> dict[str, Any]:
    """Assemble the filter chain's input from a query and a document.

    Args:
        query: The user question; placed under ``"question"``.
        doc: Document whose ``page_content`` is placed under ``"context"``.

    Returns:
        A two-key dict: ``"question"`` and ``"context"``.
    """
    chain_input: dict[str, Any] = {
        "question": query,
        "context": doc.page_content,
    }
    return chain_input
|
||||
|
||||
|
||||
class LLMChainFilter(BaseDocumentCompressor):
    """Document compressor that keeps only documents an LLM chain marks relevant.

    Each document is combined with the query via `get_input`, the resulting
    inputs are sent through `llm_chain` as one batch, and a document is kept
    only when its chain output resolves to a truthy value.
    """

    llm_chain: Runnable
    """Runnable executed for every document to decide whether to keep it.

    It is expected to end in a `BooleanOutputParser` (or, for a legacy
    `LLMChain`, to have one attached to its prompt)."""

    get_input: Callable[[str, Document], dict] = default_get_input
    """Function turning a `(query, document)` pair into the chain input."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def _is_relevant(self, raw_output: Any) -> bool:
        """Translate one chain output into a keep/drop decision.

        Args:
            raw_output: The element produced by `llm_chain` for one document.

        Returns:
            For a legacy `LLMChain`, the truthiness of the prompt's output
            parser applied to `raw_output[output_key]` (`False` when the prompt
            has no parser). For any other runnable, `raw_output` itself when it
            is a `bool`, otherwise `False`.
        """
        chain = self.llm_chain
        if isinstance(chain, LLMChain):
            # Legacy chains return a dict; the answer text sits under output_key.
            answer_text = raw_output[chain.output_key]
            parser = chain.prompt.output_parser
            if parser is None:
                return False
            return bool(parser.parse(answer_text))
        if isinstance(raw_output, bool):
            return raw_output
        # Non-boolean output from a plain runnable is treated as "not relevant".
        return False

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Callbacks | None = None,
    ) -> Sequence[Document]:
        """Return the documents the chain judges relevant to the query.

        Args:
            documents: Candidate documents, evaluated in one `batch` call.
            query: The query every document is checked against.
            callbacks: Callbacks passed to the chain through its run config.

        Returns:
            The retained documents, in their original order.
        """
        run_config = RunnableConfig(callbacks=callbacks)
        chain_inputs = [self.get_input(query, doc) for doc in documents]
        verdicts = self.llm_chain.batch(chain_inputs, config=run_config)
        return [
            doc
            for verdict, doc in zip(verdicts, documents, strict=False)
            if self._is_relevant(verdict)
        ]

    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Callbacks | None = None,
    ) -> Sequence[Document]:
        """Asynchronously return the documents the chain judges relevant.

        Args:
            documents: Candidate documents, evaluated in one `abatch` call.
            query: The query every document is checked against.
            callbacks: Callbacks passed to the chain through its run config.

        Returns:
            The retained documents, in their original order.
        """
        run_config = RunnableConfig(callbacks=callbacks)
        chain_inputs = [self.get_input(query, doc) for doc in documents]
        verdicts = await self.llm_chain.abatch(chain_inputs, config=run_config)
        return [
            doc
            for verdict, doc in zip(verdicts, documents, strict=False)
            if self._is_relevant(verdict)
        ]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: BasePromptTemplate | None = None,
        **kwargs: Any,
    ) -> "LLMChainFilter":
        """Create a LLMChainFilter from a language model.

        The chain is assembled as `prompt | llm | parser`, where the parser is
        the prompt's own output parser when it has one and a `StrOutputParser`
        otherwise.

        Args:
            llm: The language model to use for filtering.
            prompt: The prompt to use for the filter. When omitted, the default
                relevance prompt is used.
            kwargs: Additional arguments to pass to the constructor.

        Returns:
            A LLMChainFilter that uses the given language model.
        """
        filter_prompt = _get_default_chain_prompt() if prompt is None else prompt
        output_parser = (
            StrOutputParser()
            if filter_prompt.output_parser is None
            else filter_prompt.output_parser
        )
        return cls(llm_chain=filter_prompt | llm | output_parser, **kwargs)
|
||||
@@ -0,0 +1,8 @@
|
||||
prompt_template = """Given the following question and context, return YES if the context is relevant to the question and NO if it isn't.
|
||||
|
||||
> Question: {question}
|
||||
> Context:
|
||||
>>>
|
||||
{context}
|
||||
>>>
|
||||
> Relevant (YES / NO):""" # noqa: E501
|
||||
@@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import ConfigDict, model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@deprecated(
    since="0.0.30",
    removal="1.0",
    alternative_import="langchain_cohere.CohereRerank",
)
class CohereRerank(BaseDocumentCompressor):
    """Document compressor backed by the `Cohere Rerank API`."""

    client: Any = None
    """Cohere client used to call the rerank endpoint.

    Created during validation when not supplied."""
    top_n: int | None = 3
    """Number of documents to return."""
    model: str = "rerank-english-v2.0"
    """Name of the rerank model to use."""
    cohere_api_key: str | None = None
    """Cohere API key. Must be specified directly or via environment variable
    COHERE_API_KEY."""
    user_agent: str = "langchain"
    """Identifier for the application making the request."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: dict) -> Any:
        """Ensure a Cohere client is available, building one if necessary.

        Args:
            values: The raw constructor values.

        Returns:
            `values`, with `client` populated when it was missing or falsy.

        Raises:
            ImportError: If the `cohere` package is not installed.
        """
        if values.get("client"):
            # A caller-provided client is used as-is.
            return values

        try:
            import cohere
        except ImportError as err:
            msg = (
                "Could not import cohere python package. "
                "Please install it with `pip install cohere`."
            )
            raise ImportError(msg) from err

        api_key = get_from_dict_or_env(
            values,
            "cohere_api_key",
            "COHERE_API_KEY",
        )
        agent_name = values.get("user_agent", "langchain")
        values["client"] = cohere.Client(api_key, client_name=agent_name)
        return values

    def rerank(
        self,
        documents: Sequence[str | Document | dict],
        query: str,
        *,
        model: str | None = None,
        top_n: int | None = -1,
        max_chunks_per_doc: int | None = None,
    ) -> list[dict[str, Any]]:
        """Return the documents' indices and scores as reported by the rerank API.

        Args:
            documents: A sequence of documents to rerank. `Document` instances
                are reduced to their `page_content`; other items are sent as-is.
            query: The query to use for reranking.
            model: The model to use for re-ranking. Defaults to `self.model`.
            top_n: The number of results to return. `None` is forwarded
                unchanged; a non-positive value falls back to `self.top_n`.
            max_chunks_per_doc: The maximum number of chunks derived from a
                document.

        Returns:
            One dict per result with the keys `index` (position in
            `documents`) and `relevance_score`.
        """
        if len(documents) == 0:  # to avoid empty api call
            return []

        payload = [
            item.page_content if isinstance(item, Document) else item
            for item in documents
        ]
        chosen_model = model or self.model
        if top_n is not None and not top_n > 0:
            # The -1 default (or any non-positive value) means "use the field".
            top_n = self.top_n

        response = self.client.rerank(
            query=query,
            documents=payload,
            model=chosen_model,
            top_n=top_n,
            max_chunks_per_doc=max_chunks_per_doc,
        )
        # Some client versions wrap the result list in a `.results` attribute.
        hits = getattr(response, "results", response)
        return [
            {"index": hit.index, "relevance_score": hit.relevance_score}
            for hit in hits
        ]

    @override
    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Callbacks | None = None,
    ) -> Sequence[Document]:
        """Compress documents using Cohere's rerank API.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.

        Returns:
            Copies of the selected documents in the order returned by
            `rerank`, each with `relevance_score` added to its metadata. The
            input documents are not modified.
        """
        reranked = []
        for hit in self.rerank(documents, query):
            source = documents[hit["index"]]
            # Deep-copy the metadata so the score never leaks into the input.
            scored = Document(source.page_content, metadata=deepcopy(source.metadata))
            scored.metadata["relevance_score"] = hit["relevance_score"]
            reranked.append(scored)
        return reranked
|
||||
@@ -0,0 +1,16 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseCrossEncoder(ABC):
    """Abstract interface that every cross encoder model must implement."""

    @abstractmethod
    def score(self, text_pairs: list[tuple[str, str]]) -> list[float]:
        """Compute a similarity score for each pair of texts.

        Args:
            text_pairs: The text pairs to score.

        Returns:
            The scores as a list of floats.
        """
|
||||
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.retrievers.document_compressors.cross_encoder import (
|
||||
BaseCrossEncoder,
|
||||
)
|
||||
|
||||
|
||||
class CrossEncoderReranker(BaseDocumentCompressor):
    """Document compressor that reranks documents with a cross encoder."""

    model: BaseCrossEncoder
    """Cross encoder used to score each `(query, document text)` pair."""
    top_n: int = 3
    """Number of documents to return."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @override
    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Callbacks | None = None,
    ) -> Sequence[Document]:
        """Rerank documents using CrossEncoder.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.

        Returns:
            At most `top_n` documents, ordered from highest to lowest score.
        """
        query_doc_pairs = [(query, doc.page_content) for doc in documents]
        pair_scores = self.model.score(query_doc_pairs)
        ranked = sorted(
            zip(documents, pair_scores, strict=False),
            key=operator.itemgetter(1),
            reverse=True,
        )
        best = ranked[: self.top_n]
        return [doc for doc, _score in best]
|
||||
@@ -0,0 +1,141 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils import pre_init
|
||||
from pydantic import ConfigDict, Field
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def _get_similarity_function() -> Callable:
    """Return `cosine_similarity` from `langchain_community`, imported lazily.

    Raises:
        ImportError: If `langchain-community` is not installed.
    """
    try:
        from langchain_community.utils.math import cosine_similarity
    except ImportError as err:
        msg = "To use please install langchain-community with `pip install langchain-community`."  # noqa: E501
        raise ImportError(msg) from err
    else:
        return cosine_similarity
|
||||
|
||||
|
||||
class EmbeddingsFilter(BaseDocumentCompressor):
    """Embeddings Filter.

    Document compressor that uses embeddings to drop documents unrelated to the query.

    The query and the documents are embedded with `embeddings` and compared
    with `similarity_fn`. The result is limited to the `k` most similar
    documents and/or to the documents scoring above `similarity_threshold`.
    """

    embeddings: Embeddings
    """Embeddings to use for embedding document contents and queries."""
    similarity_fn: Callable = Field(default_factory=_get_similarity_function)
    """Similarity function for comparing documents. Function expected to take as input
    two matrices (List[List[float]]) and return a matrix of scores where higher values
    indicate greater similarity."""
    k: int | None = 20
    """The number of relevant documents to return. Can be set to `None`, in which case
    `similarity_threshold` must be specified."""
    similarity_threshold: float | None = None
    """Minimum similarity to the query that a document must exceed to be kept.
    Defaults to `None`, must be specified if `k` is set to None."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @pre_init
    def validate_params(cls, values: dict) -> dict:
        """Validate similarity parameters.

        Raises:
            ValueError: If both `k` and `similarity_threshold` are `None`.
        """
        if values["k"] is None and values["similarity_threshold"] is None:
            msg = "Must specify one of `k` or `similarity_threshold`."
            raise ValueError(msg)
        return values

    @override
    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Callbacks | None = None,
    ) -> Sequence[Document]:
        """Filter documents based on similarity of their embeddings to the query.

        Args:
            documents: The documents to filter.
            query: The query the documents are compared against.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            The retained documents as stateful documents, most similar first
            when `k` is set, each with `query_similarity_score` in its state.
            An empty list when `documents` is empty.

        Raises:
            ImportError: If `langchain-community` or `numpy` is not installed.
        """
        try:
            from langchain_community.document_transformers.embeddings_redundant_filter import (  # noqa: E501
                _get_embeddings_from_stateful_docs,
                get_stateful_documents,
            )
        except ImportError as e:
            msg = (
                "To use please install langchain-community "
                "with `pip install langchain-community`."
            )
            raise ImportError(msg) from e

        try:
            import numpy as np
        except ImportError as e:
            msg = "Could not import numpy, please install with `pip install numpy`."
            raise ImportError(msg) from e

        if not documents:
            # Nothing to score. Without this guard the `[0]` lookup below can
            # fail, because a similarity function may return an empty result
            # for an empty document matrix.
            return []

        stateful_documents = get_stateful_documents(documents)
        embedded_documents = _get_embeddings_from_stateful_docs(
            self.embeddings,
            stateful_documents,
        )
        embedded_query = self.embeddings.embed_query(query)
        # One query row against all documents: take that single row of scores.
        similarity = self.similarity_fn([embedded_query], embedded_documents)[0]
        included_idxs: np.ndarray = np.arange(len(embedded_documents))
        if self.k is not None:
            # Indices of the k highest scores, best first.
            included_idxs = np.argsort(similarity)[::-1][: self.k]
        if self.similarity_threshold is not None:
            # Strict comparison: a score equal to the threshold is dropped.
            similar_enough = np.where(
                similarity[included_idxs] > self.similarity_threshold,
            )
            included_idxs = included_idxs[similar_enough]
        for i in included_idxs:
            stateful_documents[i].state["query_similarity_score"] = similarity[i]
        return [stateful_documents[i] for i in included_idxs]

    @override
    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Callbacks | None = None,
    ) -> Sequence[Document]:
        """Asynchronously filter documents by embedding similarity to the query.

        Args:
            documents: The documents to filter.
            query: The query the documents are compared against.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            The retained documents as stateful documents, most similar first
            when `k` is set, each with `query_similarity_score` in its state.
            An empty list when `documents` is empty.

        Raises:
            ImportError: If `langchain-community` or `numpy` is not installed.
        """
        try:
            from langchain_community.document_transformers.embeddings_redundant_filter import (  # noqa: E501
                _aget_embeddings_from_stateful_docs,
                get_stateful_documents,
            )
        except ImportError as e:
            msg = (
                "To use please install langchain-community "
                "with `pip install langchain-community`."
            )
            raise ImportError(msg) from e

        try:
            import numpy as np
        except ImportError as e:
            msg = "Could not import numpy, please install with `pip install numpy`."
            raise ImportError(msg) from e

        if not documents:
            # Nothing to score. Without this guard the `[0]` lookup below can
            # fail, because a similarity function may return an empty result
            # for an empty document matrix.
            return []

        stateful_documents = get_stateful_documents(documents)
        embedded_documents = await _aget_embeddings_from_stateful_docs(
            self.embeddings,
            stateful_documents,
        )
        embedded_query = await self.embeddings.aembed_query(query)
        # One query row against all documents: take that single row of scores.
        similarity = self.similarity_fn([embedded_query], embedded_documents)[0]
        included_idxs: np.ndarray = np.arange(len(embedded_documents))
        if self.k is not None:
            # Indices of the k highest scores, best first.
            included_idxs = np.argsort(similarity)[::-1][: self.k]
        if self.similarity_threshold is not None:
            # Strict comparison: a score equal to the threshold is dropped.
            similar_enough = np.where(
                similarity[included_idxs] > self.similarity_threshold,
            )
            included_idxs = included_idxs[similar_enough]
        for i in included_idxs:
            stateful_documents[i].state["query_similarity_score"] = similarity[i]
        return [stateful_documents[i] for i in included_idxs]
|
||||
@@ -0,0 +1,27 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.document_compressors.flashrank_rerank import (
|
||||
FlashrankRerank,
|
||||
)
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
# Maps each public name of this module to the module it is imported from.
DEPRECATED_LOOKUP = {
    "FlashrankRerank": "langchain_community.document_compressors.flashrank_rerank",
}

# Importer bound to this package and the lookup table above.
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
    """Look up attributes dynamically.

    Python calls a module-level `__getattr__` only for names that are not
    found in the module's namespace; the lookup is delegated to
    `_import_attribute`.

    Args:
        name: The attribute name requested from this module.

    Returns:
        Whatever `_import_attribute` resolves for `name`.
    """
    return _import_attribute(name)


__all__ = [
    "FlashrankRerank",
]
|
||||
@@ -0,0 +1,146 @@
|
||||
"""Filter that uses an LLM to rerank documents listwise and select top-k."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
|
||||
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# System message: the `{context}` placeholder is followed by the sorting
# instruction given to the model.
_default_system_tmpl = """{context}

Sort the Documents by their relevance to the Query."""
# Default chat prompt: the system message above plus a human message that
# carries the `{query}` placeholder.
_DEFAULT_PROMPT = ChatPromptTemplate.from_messages(
    [("system", _default_system_tmpl), ("human", "{query}")],
)
|
||||
|
||||
|
||||
def _get_prompt_input(input_: dict) -> dict[str, Any]:
    """Build the prompt variables for the listwise reranking prompt.

    Args:
        input_: A dict with a `documents` sequence (items exposing
            `page_content`) and a `query` entry.

    Returns:
        A dict with the original `query` and a `context` string that lists
        every document under a zero-based `Document ID`, followed by a line
        summarising the range of valid IDs.
    """
    docs = input_["documents"]
    # One fenced section per document, labelled with its position in the list.
    sections = [
        f"Document ID: {position}\n```{doc.page_content}```\n\n"
        for position, doc in enumerate(docs)
    ]
    if len(docs) > 0:
        id_range = f"Document ID: 0, ..., Document ID: {len(docs) - 1}"
    else:
        id_range = "empty list"
    sections.append(f"Documents = [{id_range}]")
    return {"query": input_["query"], "context": "".join(sections)}
|
||||
|
||||
|
||||
def _parse_ranking(results: dict) -> list[Document]:
    """Reorder the input documents according to the model's ranking.

    The IDs come from model output and are therefore not trusted: an ID outside
    `0 .. len(documents) - 1` is skipped instead of raising `IndexError` (or,
    for negative values, silently selecting a document from the end), and an
    ID that was already used is skipped so no document is returned twice.

    Args:
        results: A dict with `documents` (the candidate documents) and
            `ranking` (an object exposing `ranked_document_ids`).

    Returns:
        The documents referenced by the valid IDs, in ranking order.
    """
    docs = results["documents"]
    doc_count = len(docs)
    used_ids: set[int] = set()
    ordered = []
    for doc_id in results["ranking"].ranked_document_ids:
        if not 0 <= doc_id < doc_count or doc_id in used_ids:
            continue
        used_ids.add(doc_id)
        ordered.append(docs[doc_id])
    return ordered
|
||||
|
||||
|
||||
class LLMListwiseRerank(BaseDocumentCompressor):
    """Document compressor that uses `Zero-Shot Listwise Document Reranking`.

    Adapted from: https://arxiv.org/pdf/2305.02156.pdf

    A language model is shown all candidate documents at once and asked to
    order them by relevance to the query; the first `top_n` documents of that
    ordering are returned.

    !!! note
        Requires that underlying model implement `with_structured_output`.

    Example usage:
        ```python
        from langchain_classic.retrievers.document_compressors.listwise_rerank import (
            LLMListwiseRerank,
        )
        from langchain_core.documents import Document
        from langchain_openai import ChatOpenAI

        documents = [
            Document("Sally is my friend from school"),
            Document("Steve is my friend from home"),
            Document("I didn't always like yogurt"),
            Document("I wonder why it's called football"),
            Document("Where's waldo"),
        ]

        reranker = LLMListwiseRerank.from_llm(
            llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3
        )
        compressed_docs = reranker.compress_documents(documents, "Who is steve")
        assert len(compressed_docs) == 3
        assert "Steve" in compressed_docs[0].page_content
        ```
    """

    reranker: Runnable[dict, list[Document]]
    """Runnable that performs the reranking. It receives a dict with the keys
    'documents' (Sequence[Document]) and 'query' (str) and must return a
    List[Document]."""

    top_n: int = 3
    """Number of documents to return."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Callbacks | None = None,
    ) -> Sequence[Document]:
        """Rerank the documents and keep the first `top_n` of the result.

        Args:
            documents: The candidate documents.
            query: The query the documents are ranked against.
            callbacks: Callbacks passed to the reranker through its run config.

        Returns:
            At most `top_n` documents, in the order produced by `reranker`.
        """
        reranker_input = {"documents": documents, "query": query}
        ranked_docs = self.reranker.invoke(
            reranker_input,
            config={"callbacks": callbacks},
        )
        return ranked_docs[: self.top_n]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        prompt: BasePromptTemplate | None = None,
        **kwargs: Any,
    ) -> "LLMListwiseRerank":
        """Create a LLMListwiseRerank document compressor from a language model.

        Args:
            llm: The language model to use for filtering. **Must implement
                BaseLanguageModel.with_structured_output().**
            prompt: The prompt to use for the filter. Defaults to the module's
                built-in prompt.
            kwargs: Additional arguments to pass to the constructor.

        Returns:
            A LLMListwiseRerank document compressor that uses the given language model.

        Raises:
            ValueError: If the model's class does not override
                `with_structured_output`.
        """
        # An un-overridden method means the model cannot emit structured output.
        inherited = BaseLanguageModel.with_structured_output
        if type(llm).with_structured_output == inherited:
            msg = (
                f"llm of type {type(llm)} does not implement `with_structured_output`."
            )
            raise ValueError(msg)

        class RankDocuments(BaseModel):
            """Rank the documents by their relevance to the user question.

            Rank from most to least relevant.
            """

            ranked_document_ids: list[int] = Field(
                ...,
                description=(
                    "The integer IDs of the documents, sorted from most to least "
                    "relevant to the user question."
                ),
            )

        ranking_prompt = _DEFAULT_PROMPT if prompt is None else prompt
        # Prompt the model for a ranking, keep it next to the original input
        # under "ranking", then map the ranked IDs back to documents.
        ranking_chain = (
            RunnableLambda(_get_prompt_input)
            | ranking_prompt
            | llm.with_structured_output(RankDocuments)
        )
        pipeline = RunnablePassthrough.assign(ranking=ranking_chain) | RunnableLambda(
            _parse_ranking
        )
        return cls(reranker=pipeline, **kwargs)
|
||||
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import ElasticSearchBM25Retriever
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
# Maps each public name of this module to the module it is imported from.
DEPRECATED_LOOKUP = {"ElasticSearchBM25Retriever": "langchain_community.retrievers"}

# Importer bound to this package and the lookup table above.
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
    """Look up attributes dynamically.

    Python calls a module-level `__getattr__` only for names that are not
    found in the module's namespace; the lookup is delegated to
    `_import_attribute`.

    Args:
        name: The attribute name requested from this module.

    Returns:
        Whatever `_import_attribute` resolves for `name`.
    """
    return _import_attribute(name)


__all__ = [
    "ElasticSearchBM25Retriever",
]
|
||||
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import EmbedchainRetriever
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and
# handling optional imports.
# Maps each public name of this module to the module it is imported from.
DEPRECATED_LOOKUP = {"EmbedchainRetriever": "langchain_community.retrievers"}

# Importer bound to this package and the lookup table above.
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
    """Look up attributes dynamically.

    Python calls a module-level `__getattr__` only for names that are not
    found in the module's namespace; the lookup is delegated to
    `_import_attribute`.

    Args:
        name: The attribute name requested from this module.

    Returns:
        Whatever `_import_attribute` resolves for `name`.
    """
    return _import_attribute(name)


__all__ = [
    "EmbedchainRetriever",
]
|
||||
352
venv/Lib/site-packages/langchain_classic/retrievers/ensemble.py
Normal file
352
venv/Lib/site-packages/langchain_classic/retrievers/ensemble.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""Ensemble Retriever.
|
||||
|
||||
Ensemble retriever that ensembles the results of
|
||||
multiple retrievers by using weighted Reciprocal Rank Fusion.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Hashable, Iterable, Iterator
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
Any,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever, RetrieverLike
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.runnables.config import ensure_config, patch_config
|
||||
from langchain_core.runnables.utils import (
|
||||
ConfigurableFieldSpec,
|
||||
get_unique_config_specs,
|
||||
)
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
T = TypeVar("T")
H = TypeVar("H", bound=Hashable)


def unique_by_key(iterable: Iterable[T], key: Callable[[T], H]) -> Iterator[T]:
    """Lazily yield the first element seen for each distinct key.

    Args:
        iterable: The elements to deduplicate.
        key: Function mapping an element to the hashable value that identifies it.

    Yields:
        Elements of `iterable` in their original order, skipping any element
        whose key was already produced by an earlier element.
    """
    seen_keys: set[H] = set()
    for element in iterable:
        element_key = key(element)
        if element_key in seen_keys:
            continue
        seen_keys.add(element_key)
        yield element
|
||||
|
||||
|
||||
class EnsembleRetriever(BaseRetriever):
|
||||
"""Retriever that ensembles the multiple retrievers.
|
||||
|
||||
It uses a rank fusion.
|
||||
|
||||
Args:
|
||||
retrievers: A list of retrievers to ensemble.
|
||||
weights: A list of weights corresponding to the retrievers. Defaults to equal
|
||||
weighting for all retrievers.
|
||||
c: A constant added to the rank, controlling the balance between the importance
|
||||
of high-ranked items and the consideration given to lower-ranked items.
|
||||
id_key: The key in the document's metadata used to determine unique documents.
|
||||
If not specified, page_content is used.
|
||||
"""
|
||||
|
||||
retrievers: list[RetrieverLike]
|
||||
weights: list[float]
|
||||
c: int = 60
|
||||
id_key: str | None = None
|
||||
|
||||
@property
def config_specs(self) -> list[ConfigurableFieldSpec]:
    """Return the de-duplicated configurable fields of all wrapped retrievers."""
    all_specs = chain.from_iterable(
        retriever.config_specs for retriever in self.retrievers
    )
    return get_unique_config_specs(all_specs)
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def _set_weights(cls, values: dict[str, Any]) -> Any:
    """Default and validate the ensemble weights before model construction.

    When no weights are given, every retriever receives the same weight
    `1 / len(retrievers)`. Explicit weights must match the retrievers in
    number and contain at least one positive value.

    Args:
        values: The raw constructor values.

    Returns:
        `values`, with `weights` filled in when it was missing or empty.

    Raises:
        ValueError: If `retrievers` is empty, if the number of weights differs
            from the number of retrievers, or if no weight is positive.
    """
    retrievers = values.get("retrievers")
    if retrievers is None:
        # Leave the missing-field error to the regular field validation
        # instead of failing here with a bare KeyError.
        return values

    weights = values.get("weights")
    n_retrievers = len(retrievers)

    if not weights:
        if n_retrievers == 0:
            # Equal weighting would otherwise divide by zero.
            msg = "At least one retriever is required to build an ensemble."
            raise ValueError(msg)
        values["weights"] = [1 / n_retrievers] * n_retrievers
        return values

    if len(weights) != n_retrievers:
        msg = (
            "Length of weights must match number of retrievers "
            f"(got {len(weights)} weights for {n_retrievers} retrievers)."
        )
        raise ValueError(msg)

    if not any(w > 0 for w in weights):
        msg = "At least one ensemble weight must be greater than zero."
        raise ValueError(msg)

    return values
|
||||
|
||||
@override
def invoke(
    self,
    input: str,
    config: RunnableConfig | None = None,
    **kwargs: Any,
) -> list[Document]:
    """Run the ensemble for `input` and report the run through callbacks.

    Args:
        input: The query string.
        config: Optional runnable config; callbacks, tags, metadata and the
            run name are read from it.
        **kwargs: Forwarded to the retriever start/end callbacks. `verbose`
            is additionally used to configure the callback manager.

    Returns:
        The fused documents produced by `rank_fusion`.
    """
    from langchain_core.callbacks import CallbackManager

    resolved_config = ensure_config(config)
    manager = CallbackManager.configure(
        resolved_config.get("callbacks"),
        None,
        verbose=kwargs.get("verbose", False),
        inheritable_tags=resolved_config.get("tags", []),
        local_tags=self.tags,
        inheritable_metadata=resolved_config.get("metadata", {}),
        local_metadata=self.metadata,
    )
    run_name = resolved_config.get("run_name") or self.get_name()
    run_manager = manager.on_retriever_start(
        None,
        input,
        name=run_name,
        **kwargs,
    )
    try:
        fused = self.rank_fusion(
            input,
            run_manager=run_manager,
            config=resolved_config,
        )
    except Exception as exc:
        # Report the failure to the callbacks, then let it propagate.
        run_manager.on_retriever_error(exc)
        raise
    run_manager.on_retriever_end(
        fused,
        **kwargs,
    )
    return fused
|
||||
|
||||
@override
async def ainvoke(
    self,
    input: str,
    config: RunnableConfig | None = None,
    **kwargs: Any,
) -> list[Document]:
    """Asynchronously run the ensemble for `input`, reporting through callbacks.

    Args:
        input: The query string.
        config: Optional runnable config; callbacks, tags, metadata and the
            run name are read from it.
        **kwargs: Forwarded to the retriever start/end callbacks. `verbose`
            is additionally used to configure the callback manager.

    Returns:
        The fused documents produced by `arank_fusion`.
    """
    from langchain_core.callbacks import AsyncCallbackManager

    resolved_config = ensure_config(config)
    manager = AsyncCallbackManager.configure(
        resolved_config.get("callbacks"),
        None,
        verbose=kwargs.get("verbose", False),
        inheritable_tags=resolved_config.get("tags", []),
        local_tags=self.tags,
        inheritable_metadata=resolved_config.get("metadata", {}),
        local_metadata=self.metadata,
    )
    run_name = resolved_config.get("run_name") or self.get_name()
    run_manager = await manager.on_retriever_start(
        None,
        input,
        name=run_name,
        **kwargs,
    )
    try:
        fused = await self.arank_fusion(
            input,
            run_manager=run_manager,
            config=resolved_config,
        )
    except Exception as exc:
        # Report the failure to the callbacks, then let it propagate.
        await run_manager.on_retriever_error(exc)
        raise
    await run_manager.on_retriever_end(
        fused,
        **kwargs,
    )
    return fused
|
||||
|
||||
def _get_relevant_documents(
    self,
    query: str,
    *,
    run_manager: CallbackManagerForRetrieverRun,
) -> list[Document]:
    """Retrieve documents for `query` by fusing the retrievers' results.

    Args:
        query: The query to search for.
        run_manager: The callback handler to use.

    Returns:
        The fused, reranked documents returned by `rank_fusion`.
    """
    fused_documents = self.rank_fusion(query, run_manager)
    return fused_documents
|
||||
|
||||
async def _aget_relevant_documents(
    self,
    query: str,
    *,
    run_manager: AsyncCallbackManagerForRetrieverRun,
) -> list[Document]:
    """Asynchronously retrieve documents by fusing the retrievers' results.

    Args:
        query: The query to search for.
        run_manager: The callback handler to use.

    Returns:
        The fused, reranked documents returned by `arank_fusion`.
    """
    fused_documents = await self.arank_fusion(query, run_manager)
    return fused_documents
|
||||
|
||||
def rank_fusion(
    self,
    query: str,
    run_manager: CallbackManagerForRetrieverRun,
    *,
    config: RunnableConfig | None = None,
) -> list[Document]:
    """Query every retriever and fuse the result lists.

    Each retriever is invoked in turn with a child callback manager tagged
    `retriever_<n>` (1-based). Plain strings in a result list are wrapped in
    `Document` objects before the lists are passed to
    `weighted_reciprocal_rank`.

    Args:
        query: The query to search for.
        run_manager: The callback handler to use.
        config: Optional configuration for the retrievers.

    Returns:
        A list of reranked documents.
    """
    # Phase 1: collect the raw results of all retrievers.
    raw_results = [
        retriever.invoke(
            query,
            patch_config(
                config,
                callbacks=run_manager.get_child(tag=f"retriever_{position}"),
            ),
        )
        for position, retriever in enumerate(self.retrievers, start=1)
    ]

    # Phase 2: make sure string results become Documents.
    doc_lists = [
        [
            Document(page_content=cast("str", item)) if isinstance(item, str) else item  # type: ignore[unreachable]
            for item in result
        ]
        for result in raw_results
    ]

    return self.weighted_reciprocal_rank(doc_lists)
|
||||
|
||||
async def arank_fusion(
    self,
    query: str,
    run_manager: AsyncCallbackManagerForRetrieverRun,
    *,
    config: RunnableConfig | None = None,
) -> list[Document]:
    """Asynchronously query every retriever and fuse their results.

    All retrievers in ``self.retrievers`` are awaited together through
    ``asyncio.gather``; the per-retriever result lists are then combined by
    ``weighted_reciprocal_rank``.

    Args:
        query: The query to search for.
        run_manager: The callback manager for this retriever run.
        config: Optional configuration for the retrievers.

    Returns:
        A list of reranked documents.
    """
    # Build every awaitable up front. Each one gets its own child callback
    # manager tagged with the retriever's 1-based position.
    pending = [
        retriever.ainvoke(
            query,
            patch_config(
                config,
                callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"),
            ),
        )
        for i, retriever in enumerate(self.retrievers)
    ]
    gathered = await asyncio.gather(*pending)

    # Anything that is not already a Document is wrapped in one.
    normalized = [
        [
            item if isinstance(item, Document) else Document(page_content=item)
            for item in results
        ]
        for results in gathered
    ]

    return self.weighted_reciprocal_rank(normalized)
|
||||
|
||||
def weighted_reciprocal_rank(
    self,
    doc_lists: list[list[Document]],
) -> list[Document]:
    """Fuse several ranked lists with weighted Reciprocal Rank Fusion (RRF).

    You can find more details about RRF here:
    https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf.

    Args:
        doc_lists: A list of rank lists, one per entry in ``self.weights``,
            where each rank list contains unique items.

    Returns:
        The deduplicated documents sorted by their weighted RRF scores in
        descending order.

    Raises:
        ValueError: If ``doc_lists`` and ``self.weights`` differ in length.
    """
    if len(self.weights) != len(doc_lists):
        msg = "Number of rank lists must be equal to the number of weights."
        raise ValueError(msg)

    def fusion_key(doc: Document):
        # Identity used for both scoring and deduplication: the page content,
        # or the metadata value under ``id_key`` when one is configured.
        if self.id_key is None:
            return doc.page_content
        return doc.metadata[self.id_key]

    # A document at 1-based position ``rank`` contributes
    # ``weight / (rank + c)``; documents sharing a key accumulate their
    # contributions across lists.
    scores: dict[str, float] = defaultdict(float)
    for weight, ranked_docs in zip(self.weights, doc_lists, strict=False):
        for rank, doc in enumerate(ranked_docs, start=1):
            scores[fusion_key(doc)] += weight / (rank + self.c)

    # Deduplicate by key, then order by accumulated score, highest first.
    deduplicated = unique_by_key(chain.from_iterable(doc_lists), fusion_key)
    return sorted(
        deduplicated,
        key=lambda doc: scores[fusion_key(doc)],
        reverse=True,
    )
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import GoogleDocumentAIWarehouseRetriever
|
||||
|
||||
# Deprecated names served by this module, each mapped to the module path
# passed to `create_importer`. Routing them through one importer keeps the
# deprecation-warning and optional-import handling in a single place.
DEPRECATED_LOOKUP = {"GoogleDocumentAIWarehouseRetriever": "langchain_community.retrievers"}

_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
    """Delegate the lookup of ``name`` to the importer built from ``DEPRECATED_LOOKUP``."""
    return _import_attribute(name)


__all__ = ["GoogleDocumentAIWarehouseRetriever"]
|
||||
@@ -0,0 +1,33 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import (
|
||||
GoogleCloudEnterpriseSearchRetriever,
|
||||
GoogleVertexAIMultiTurnSearchRetriever,
|
||||
GoogleVertexAISearchRetriever,
|
||||
)
|
||||
|
||||
# Deprecated names served by this module, each mapped to the module path
# passed to `create_importer`. Routing them through one importer keeps the
# deprecation-warning and optional-import handling in a single place.
DEPRECATED_LOOKUP = {
    "GoogleVertexAISearchRetriever": "langchain_community.retrievers",
    "GoogleVertexAIMultiTurnSearchRetriever": "langchain_community.retrievers",
    "GoogleCloudEnterpriseSearchRetriever": "langchain_community.retrievers",
}

_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
    """Delegate the lookup of ``name`` to the importer built from ``DEPRECATED_LOOKUP``."""
    return _import_attribute(name)


__all__ = [
    "GoogleCloudEnterpriseSearchRetriever",
    "GoogleVertexAIMultiTurnSearchRetriever",
    "GoogleVertexAISearchRetriever",
]
|
||||
23
venv/Lib/site-packages/langchain_classic/retrievers/kay.py
Normal file
23
venv/Lib/site-packages/langchain_classic/retrievers/kay.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import KayAiRetriever
|
||||
|
||||
# Deprecated names served by this module, each mapped to the module path
# passed to `create_importer`. Routing them through one importer keeps the
# deprecation-warning and optional-import handling in a single place.
DEPRECATED_LOOKUP = {
    "KayAiRetriever": "langchain_community.retrievers",
}

_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
    """Delegate the lookup of ``name`` to the importer built from ``DEPRECATED_LOOKUP``."""
    return _import_attribute(name)


__all__ = ["KayAiRetriever"]
|
||||
@@ -0,0 +1,66 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import AmazonKendraRetriever
|
||||
from langchain_community.retrievers.kendra import (
|
||||
AdditionalResultAttribute,
|
||||
AdditionalResultAttributeValue,
|
||||
DocumentAttribute,
|
||||
DocumentAttributeValue,
|
||||
Highlight,
|
||||
QueryResult,
|
||||
QueryResultItem,
|
||||
ResultItem,
|
||||
RetrieveResult,
|
||||
RetrieveResultItem,
|
||||
TextWithHighLights,
|
||||
clean_excerpt,
|
||||
combined_text,
|
||||
)
|
||||
|
||||
# Deprecated names served by this module, each mapped to the module path
# passed to `create_importer`. Routing them through one importer keeps the
# deprecation-warning and optional-import handling in a single place.
# The Kendra helper names all share one module path, so they are expanded from
# a single tuple; the retriever itself maps to the package root. Key order
# matches the original literal.
DEPRECATED_LOOKUP = {
    **dict.fromkeys(
        (
            "clean_excerpt",
            "combined_text",
            "Highlight",
            "TextWithHighLights",
            "AdditionalResultAttributeValue",
            "AdditionalResultAttribute",
            "DocumentAttributeValue",
            "DocumentAttribute",
            "ResultItem",
            "QueryResultItem",
            "RetrieveResultItem",
            "QueryResult",
            "RetrieveResult",
        ),
        "langchain_community.retrievers.kendra",
    ),
    "AmazonKendraRetriever": "langchain_community.retrievers",
}

_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
    """Delegate the lookup of ``name`` to the importer built from ``DEPRECATED_LOOKUP``."""
    return _import_attribute(name)


__all__ = [
    "AdditionalResultAttribute",
    "AdditionalResultAttributeValue",
    "AmazonKendraRetriever",
    "DocumentAttribute",
    "DocumentAttributeValue",
    "Highlight",
    "QueryResult",
    "QueryResultItem",
    "ResultItem",
    "RetrieveResult",
    "RetrieveResultItem",
    "TextWithHighLights",
    "clean_excerpt",
    "combined_text",
]
|
||||
23
venv/Lib/site-packages/langchain_classic/retrievers/knn.py
Normal file
23
venv/Lib/site-packages/langchain_classic/retrievers/knn.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import KNNRetriever
|
||||
|
||||
# Deprecated names served by this module, each mapped to the module path
# passed to `create_importer`. Routing them through one importer keeps the
# deprecation-warning and optional-import handling in a single place.
DEPRECATED_LOOKUP = {
    "KNNRetriever": "langchain_community.retrievers",
}

_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
    """Delegate the lookup of ``name`` to the importer built from ``DEPRECATED_LOOKUP``."""
    return _import_attribute(name)


__all__ = ["KNNRetriever"]
|
||||
@@ -0,0 +1,30 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import (
|
||||
LlamaIndexGraphRetriever,
|
||||
LlamaIndexRetriever,
|
||||
)
|
||||
|
||||
# Deprecated names served by this module, each mapped to the module path
# passed to `create_importer`. Routing them through one importer keeps the
# deprecation-warning and optional-import handling in a single place.
DEPRECATED_LOOKUP = {
    "LlamaIndexRetriever": "langchain_community.retrievers",
    "LlamaIndexGraphRetriever": "langchain_community.retrievers",
}

_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
    """Delegate the lookup of ``name`` to the importer built from ``DEPRECATED_LOOKUP``."""
    return _import_attribute(name)


__all__ = ["LlamaIndexGraphRetriever", "LlamaIndexRetriever"]
|
||||
@@ -0,0 +1,119 @@
|
||||
import asyncio
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class MergerRetriever(BaseRetriever):
    """Retriever that merges the results of multiple retrievers."""

    retrievers: list[BaseRetriever]
    """A list of retrievers to merge."""

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Return the merged documents for ``query``.

        Args:
            query: The query to search for.
            run_manager: The callback manager for this retriever run.

        Returns:
            A list of relevant documents.
        """
        merged = self.merge_documents(query, run_manager)
        return merged

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Asynchronously return the merged documents for ``query``.

        Args:
            query: The query to search for.
            run_manager: The callback manager for this retriever run.

        Returns:
            A list of relevant documents.
        """
        merged = await self.amerge_documents(query, run_manager)
        return merged

    def merge_documents(
        self,
        query: str,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Query every retriever and interleave their results.

        The output takes the first document of each retriever in order, then
        the second of each, and so on; retrievers that run out of documents
        are skipped.

        Args:
            query: The query to search for.
            run_manager: The callback manager for this retriever run.

        Returns:
            A list of merged documents.
        """
        # One result list per retriever, each call tagged with the
        # retriever's 1-based position.
        results_per_retriever = []
        for i, retriever in enumerate(self.retrievers):
            child_callbacks = run_manager.get_child(f"retriever_{i + 1}")
            results_per_retriever.append(
                retriever.invoke(query, config={"callbacks": child_callbacks}),
            )

        # Round-robin over ranks until the longest list is exhausted.
        merged_documents = []
        deepest = max((len(results) for results in results_per_retriever), default=0)
        for rank in range(deepest):
            merged_documents.extend(
                results[rank]
                for results in results_per_retriever
                if rank < len(results)
            )
        return merged_documents

    async def amerge_documents(
        self,
        query: str,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Asynchronously query every retriever and interleave their results.

        The retrievers are awaited together through ``asyncio.gather``; the
        output ordering is the same round-robin as ``merge_documents``.

        Args:
            query: The query to search for.
            run_manager: The callback manager for this retriever run.

        Returns:
            A list of merged documents.
        """
        pending = [
            retriever.ainvoke(
                query,
                config={"callbacks": run_manager.get_child(f"retriever_{i + 1}")},
            )
            for i, retriever in enumerate(self.retrievers)
        ]
        results_per_retriever = await asyncio.gather(*pending)

        # Round-robin over ranks until the longest list is exhausted.
        merged_documents = []
        deepest = max((len(results) for results in results_per_retriever), default=0)
        for rank in range(deepest):
            merged_documents.extend(
                results[rank]
                for results in results_per_retriever
                if rank < len(results)
            )
        return merged_documents
|
||||
23
venv/Lib/site-packages/langchain_classic/retrievers/metal.py
Normal file
23
venv/Lib/site-packages/langchain_classic/retrievers/metal.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import MetalRetriever
|
||||
|
||||
# Deprecated names served by this module, each mapped to the module path
# passed to `create_importer`. Routing them through one importer keeps the
# deprecation-warning and optional-import handling in a single place.
DEPRECATED_LOOKUP = {
    "MetalRetriever": "langchain_community.retrievers",
}

_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
    """Delegate the lookup of ``name`` to the importer built from ``DEPRECATED_LOOKUP``."""
    return _import_attribute(name)


__all__ = ["MetalRetriever"]
|
||||
@@ -0,0 +1,28 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import MilvusRetriever
|
||||
from langchain_community.retrievers.milvus import MilvusRetreiver
|
||||
|
||||
# Deprecated names served by this module, each mapped to the module path
# passed to `create_importer`. Routing them through one importer keeps the
# deprecation-warning and optional-import handling in a single place.
# NOTE: "MilvusRetreiver" (sic) is a distinct exported name with its own
# module path; the spelling is part of the lookup key and must stay as is.
DEPRECATED_LOOKUP = {
    "MilvusRetriever": "langchain_community.retrievers",
    "MilvusRetreiver": "langchain_community.retrievers.milvus",
}

_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
    """Delegate the lookup of ``name`` to the importer built from ``DEPRECATED_LOOKUP``."""
    return _import_attribute(name)


__all__ = ["MilvusRetreiver", "MilvusRetriever"]
|
||||
@@ -0,0 +1,240 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import Runnable
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LineListOutputParser(BaseOutputParser[list[str]]):
    """Output parser for a list of lines."""

    @override
    def parse(self, text: str) -> list[str]:
        """Split ``text`` on newlines and return its non-blank lines.

        Every line is stripped of surrounding whitespace before the blank
        check, so whitespace-only lines (for example the ``"\\r"`` left over
        from CRLF line endings) are dropped instead of being returned as
        entries, and the kept lines carry no stray padding.

        Args:
            text: The raw text to split.

        Returns:
            The stripped, non-empty lines in their original order.
        """
        stripped_lines = (raw_line.strip() for raw_line in text.split("\n"))
        return [line for line in stripped_lines if line]
|
||||
|
||||
|
||||
# Default prompt: asks the model for 3 alternative versions of {question}, separated by newlines.
|
||||
DEFAULT_QUERY_PROMPT = PromptTemplate(
|
||||
input_variables=["question"],
|
||||
template="""You are an AI language model assistant. Your task is
|
||||
to generate 3 different versions of the given user
|
||||
question to retrieve relevant documents from a vector database.
|
||||
By generating multiple perspectives on the user question,
|
||||
your goal is to help the user overcome some of the limitations
|
||||
of distance-based similarity search. Provide these alternative
|
||||
questions separated by newlines. Original question: {question}""",
|
||||
)
|
||||
|
||||
|
||||
def _unique_documents(documents: Sequence[Document]) -> list[Document]:
|
||||
return [doc for i, doc in enumerate(documents) if doc not in documents[:i]]
|
||||
|
||||
|
||||
class MultiQueryRetriever(BaseRetriever):
    """Given a query, use an LLM to write a set of queries.

    Retrieve docs for each query. Return the unique union of all retrieved docs.
    """

    retriever: BaseRetriever
    """The retriever that each generated query is run against."""
    llm_chain: Runnable
    """The runnable invoked with ``{"question": ...}`` to generate the queries."""
    verbose: bool = True
    """Whether to log the generated queries at INFO level."""
    parser_key: str = "lines"
    """DEPRECATED. parser_key is no longer used and should not be specified."""
    include_original: bool = False
    """Whether to include the original query in the list of generated queries."""

    @classmethod
    def from_llm(
        cls,
        retriever: BaseRetriever,
        llm: BaseLanguageModel,
        prompt: BasePromptTemplate = DEFAULT_QUERY_PROMPT,
        parser_key: str | None = None,  # noqa: ARG003
        include_original: bool = False,  # noqa: FBT001,FBT002
    ) -> "MultiQueryRetriever":
        """Initialize from llm using default template.

        Args:
            retriever: retriever to query documents from
            llm: llm for query generation using DEFAULT_QUERY_PROMPT
            prompt: The prompt which aims to generate several different versions
                of the given user query
            parser_key: DEPRECATED. `parser_key` is no longer used and should not be
                specified.
            include_original: Whether to include the original query in the list of
                generated queries.

        Returns:
            MultiQueryRetriever
        """
        output_parser = LineListOutputParser()
        llm_chain = prompt | llm | output_parser
        return cls(
            retriever=retriever,
            llm_chain=llm_chain,
            include_original=include_original,
        )

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Get relevant documents given a user query.

        Args:
            query: user query
            run_manager: the callback handler to use.

        Returns:
            Unique union of relevant documents from all generated queries
        """
        queries = await self.agenerate_queries(query, run_manager)
        if self.include_original:
            # Build a new list instead of appending in place: the chain's
            # output object may be shared or cached by the chain, and it is
            # not guaranteed to be a mutable list.
            queries = [*queries, query]
        documents = await self.aretrieve_documents(queries, run_manager)
        return self.unique_union(documents)

    async def agenerate_queries(
        self,
        question: str,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> list[str]:
        """Generate queries based upon user input.

        Args:
            question: user query
            run_manager: the callback handler to use.

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        response = await self.llm_chain.ainvoke(
            {"question": question},
            config={"callbacks": run_manager.get_child()},
        )
        # A legacy LLMChain returns a dict with the queries under "text";
        # any other runnable is expected to return the queries directly.
        lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
        if self.verbose:
            logger.info("Generated queries: %s", lines)
        return lines

    async def aretrieve_documents(
        self,
        queries: list[str],
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Run all LLM generated queries.

        Args:
            queries: query list
            run_manager: the callback handler to use

        Returns:
            List of retrieved Documents
        """
        document_lists = await asyncio.gather(
            *(
                self.retriever.ainvoke(
                    query,
                    config={"callbacks": run_manager.get_child()},
                )
                for query in queries
            ),
        )
        # Flatten the per-query result lists, preserving query order.
        return [doc for docs in document_lists for doc in docs]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Get relevant documents given a user query.

        Args:
            query: user query
            run_manager: the callback handler to use.

        Returns:
            Unique union of relevant documents from all generated queries
        """
        queries = self.generate_queries(query, run_manager)
        if self.include_original:
            # Build a new list instead of appending in place: the chain's
            # output object may be shared or cached by the chain, and it is
            # not guaranteed to be a mutable list.
            queries = [*queries, query]
        documents = self.retrieve_documents(queries, run_manager)
        return self.unique_union(documents)

    def generate_queries(
        self,
        question: str,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[str]:
        """Generate queries based upon user input.

        Args:
            question: user query
            run_manager: run manager for callbacks

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        response = self.llm_chain.invoke(
            {"question": question},
            config={"callbacks": run_manager.get_child()},
        )
        # A legacy LLMChain returns a dict with the queries under "text";
        # any other runnable is expected to return the queries directly.
        lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
        if self.verbose:
            logger.info("Generated queries: %s", lines)
        return lines

    def retrieve_documents(
        self,
        queries: list[str],
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Run all LLM generated queries.

        Args:
            queries: query list
            run_manager: run manager for callbacks

        Returns:
            List of retrieved Documents
        """
        documents = []
        for query in queries:
            docs = self.retriever.invoke(
                query,
                config={"callbacks": run_manager.get_child()},
            )
            documents.extend(docs)
        return documents

    def unique_union(self, documents: list[Document]) -> list[Document]:
        """Get unique Documents.

        Args:
            documents: List of retrieved Documents

        Returns:
            List of unique retrieved Documents
        """
        return _unique_documents(documents)
|
||||
@@ -0,0 +1,155 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.stores import BaseStore, ByteStore
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.storage._lc_store import create_kv_docstore
|
||||
|
||||
|
||||
class SearchType(str, Enum):
    """Kinds of vector-store search a retriever can be configured to run."""

    similarity = "similarity"
    """Plain similarity search."""

    similarity_score_threshold = "similarity_score_threshold"
    """Similarity search filtered by a score threshold."""

    mmr = "mmr"
    """Similarity search reranked with Maximal Marginal Relevance."""
|
||||
|
||||
|
||||
class MultiVectorRetriever(BaseRetriever):
    """Retriever that supports multiple embeddings per parent document.

    This retriever is designed for scenarios where documents are split into
    smaller chunks for embedding and vector search, but retrieval returns
    the original parent documents rather than individual chunks.

    It works by:
    - Performing similarity (or MMR) search over embedded child chunks
    - Collecting unique parent document IDs from chunk metadata
    - Fetching and returning the corresponding parent documents from the docstore

    This pattern is commonly used in RAG pipelines to improve answer grounding
    while preserving full document context.
    """

    vectorstore: VectorStore
    """The underlying `VectorStore` to use to store small chunks
    and their embedding vectors"""

    byte_store: ByteStore | None = None
    """The lower-level backing storage layer for the parent documents"""

    docstore: BaseStore[str, Document]
    """The storage interface for the parent documents"""

    id_key: str = "doc_id"
    """Metadata key on each child chunk that holds its parent document's ID."""

    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the search function."""

    search_type: SearchType = SearchType.similarity
    """Type of search to perform (similarity / mmr)"""

    @model_validator(mode="before")
    @classmethod
    def _shim_docstore(cls, values: dict) -> Any:
        """Derive ``docstore`` from ``byte_store`` when one is supplied.

        A ``byte_store`` takes precedence: it is wrapped with
        ``create_kv_docstore`` and replaces any ``docstore`` value. Otherwise
        an explicit ``docstore`` is required.

        Raises:
            ValueError: If neither ``byte_store`` nor ``docstore`` is given.
        """
        byte_store = values.get("byte_store")
        docstore = values.get("docstore")
        if byte_store is not None:
            docstore = create_kv_docstore(byte_store)
        elif docstore is None:
            msg = "You must pass a `byte_store` parameter."
            raise ValueError(msg)
        values["docstore"] = docstore
        return values

    def _parent_ids(self, sub_docs: list[Document]) -> list[str]:
        """Return the parent IDs referenced by ``sub_docs`` without repeats.

        Chunks whose metadata lacks ``id_key`` are ignored. IDs keep the order
        in which they are first seen, so parents come back in the same order
        as their best-ranked chunk.
        """
        # dict.fromkeys deduplicates while preserving insertion order.
        return list(
            dict.fromkeys(
                d.metadata[self.id_key] for d in sub_docs if self.id_key in d.metadata
            ),
        )

    @override
    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents.
        """
        if self.search_type == SearchType.mmr:
            sub_docs = self.vectorstore.max_marginal_relevance_search(
                query,
                **self.search_kwargs,
            )
        elif self.search_type == SearchType.similarity_score_threshold:
            sub_docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query,
                    **self.search_kwargs,
                )
            )
            # Only the documents are needed; the scores are discarded.
            sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)

        docs = self.docstore.mget(self._parent_ids(sub_docs))
        # IDs with no entry in the docstore come back as None and are dropped.
        return [d for d in docs if d is not None]

    @override
    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Asynchronously get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents.
        """
        if self.search_type == SearchType.mmr:
            sub_docs = await self.vectorstore.amax_marginal_relevance_search(
                query,
                **self.search_kwargs,
            )
        elif self.search_type == SearchType.similarity_score_threshold:
            sub_docs_and_similarities = (
                await self.vectorstore.asimilarity_search_with_relevance_scores(
                    query,
                    **self.search_kwargs,
                )
            )
            # Only the documents are needed; the scores are discarded.
            sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            sub_docs = await self.vectorstore.asimilarity_search(
                query,
                **self.search_kwargs,
            )

        docs = await self.docstore.amget(self._parent_ids(sub_docs))
        # IDs with no entry in the docstore come back as None and are dropped.
        return [d for d in docs if d is not None]
|
||||
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import OutlineRetriever
|
||||
|
||||
# Deprecated names served by this module, each mapped to the module path
# passed to `create_importer`. Routing them through one importer keeps the
# deprecation-warning and optional-import handling in a single place.
DEPRECATED_LOOKUP = {
    "OutlineRetriever": "langchain_community.retrievers",
}

_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)


def __getattr__(name: str) -> Any:
    """Delegate the lookup of ``name`` to the importer built from ``DEPRECATED_LOOKUP``."""
    return _import_attribute(name)


__all__ = ["OutlineRetriever"]
|
||||
@@ -0,0 +1,176 @@
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_text_splitters import TextSplitter
|
||||
|
||||
from langchain_classic.retrievers import MultiVectorRetriever
|
||||
|
||||
|
||||
class ParentDocumentRetriever(MultiVectorRetriever):
    """Retrieve small chunks then retrieve their parent documents.

    When splitting documents for retrieval, there are often conflicting desires:

    1. You may want to have small documents, so that their embeddings can most
        accurately reflect their meaning. If too long, then the embeddings can
        lose meaning.
    2. You want to have long enough documents that the context of each chunk is
        retained.

    The ParentDocumentRetriever strikes that balance by splitting and storing
    small chunks of data. During retrieval, it first fetches the small chunks
    but then looks up the parent IDs for those chunks and returns those larger
    documents.

    Note that "parent document" refers to the document that a small chunk
    originated from. This can either be the whole raw document OR a larger
    chunk.

    Examples:
        ```python
        from langchain_chroma import Chroma
        from langchain_community.embeddings import OpenAIEmbeddings
        from langchain_text_splitters import RecursiveCharacterTextSplitter
        from langchain_classic.storage import InMemoryStore

        # This text splitter is used to create the parent documents
        parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=2000, add_start_index=True
        )
        # This text splitter is used to create the child documents
        # It should create documents smaller than the parent
        child_splitter = RecursiveCharacterTextSplitter(
            chunk_size=400, add_start_index=True
        )
        # The VectorStore to use to index the child chunks
        vectorstore = Chroma(embedding_function=OpenAIEmbeddings())
        # The storage layer for the parent documents
        store = InMemoryStore()

        # Initialize the retriever
        retriever = ParentDocumentRetriever(
            vectorstore=vectorstore,
            docstore=store,
            child_splitter=child_splitter,
            parent_splitter=parent_splitter,
        )
        ```
    """

    child_splitter: TextSplitter
    """The text splitter to use to create child documents."""

    # NOTE: the key used to track the parent id (``self.id_key``) is stored in
    # the metadata of child documents. It is not declared in this class;
    # presumably it is inherited from ``MultiVectorRetriever`` -- confirm there.

    parent_splitter: TextSplitter | None = None
    """The text splitter to use to create parent documents.
    If none, then the parent documents will be the raw documents passed in."""

    child_metadata_fields: Sequence[str] | None = None
    """Metadata fields to leave in child documents. If `None`, leave all parent document
    metadata.
    """

    def _split_docs_for_adding(
        self,
        documents: list[Document],
        ids: list[str] | None = None,
        *,
        add_to_docstore: bool = True,
    ) -> tuple[list[Document], list[tuple[str, Document]]]:
        """Split documents into child chunks tagged with their parent's ID.

        Args:
            documents: Documents to split. If `parent_splitter` is set they are
                first split into parent documents.
            ids: Optional parent IDs. When provided, the length must match the
                number of parent documents, i.e. the count *after*
                `parent_splitter` has been applied. When omitted, random UUIDs
                are generated.
            add_to_docstore: Whether the caller intends to store the parent
                documents. Only used here to validate the `ids` precondition.

        Returns:
            A tuple `(docs, full_docs)` where `docs` are the child documents
            (each carrying the parent ID under `self.id_key` in its metadata)
            and `full_docs` is a list of `(parent_id, parent_document)` pairs.

        Raises:
            ValueError: If `ids` is not given while `add_to_docstore` is False,
                or if `ids` and the parent documents differ in length.
        """
        # Fail fast: without caller-supplied IDs the generated UUIDs would be
        # unknown to the docstore, so the children could never be resolved.
        if ids is None and not add_to_docstore:
            msg = "If IDs are not passed in, `add_to_docstore` MUST be True"
            raise ValueError(msg)

        if self.parent_splitter is not None:
            documents = self.parent_splitter.split_documents(documents)

        if ids is None:
            doc_ids = [str(uuid.uuid4()) for _ in documents]
        else:
            if len(documents) != len(ids):
                msg = (
                    "Got uneven list of documents and ids. "
                    "If `ids` is provided, should be same length as `documents`."
                )
                raise ValueError(msg)
            doc_ids = ids

        docs: list[Document] = []
        full_docs: list[tuple[str, Document]] = []
        # Lengths are guaranteed equal by the checks above.
        for _id, doc in zip(doc_ids, documents, strict=True):
            sub_docs = self.child_splitter.split_documents([doc])
            for _doc in sub_docs:
                if self.child_metadata_fields is not None:
                    # Keep only the configured metadata fields on the child.
                    _doc.metadata = {
                        k: _doc.metadata[k] for k in self.child_metadata_fields
                    }
                # Link the child chunk back to its parent document.
                _doc.metadata[self.id_key] = _id
            docs.extend(sub_docs)
            full_docs.append((_id, doc))

        return docs, full_docs

    def add_documents(
        self,
        documents: list[Document],
        ids: list[str] | None = None,
        add_to_docstore: bool = True,  # noqa: FBT001,FBT002
        **kwargs: Any,
    ) -> None:
        """Adds documents to the docstore and vectorstores.

        Args:
            documents: List of documents to add
            ids: Optional list of IDs for documents. If provided should be the same
                length as the list of documents. Can be provided if parent documents
                are already in the document store and you don't want to re-add
                to the docstore. If not provided, random UUIDs will be used as
                IDs.
            add_to_docstore: Boolean of whether to add documents to docstore.
                This can be false if and only if `ids` are provided. You may want
                to set this to False if the documents are already in the docstore
                and you don't want to re-add them.
            **kwargs: additional keyword arguments passed to the `VectorStore`.

        Raises:
            ValueError: If `ids` is not provided while `add_to_docstore` is
                False, or if `ids` and the documents differ in length.
        """
        docs, full_docs = self._split_docs_for_adding(
            documents,
            ids,
            add_to_docstore=add_to_docstore,
        )
        self.vectorstore.add_documents(docs, **kwargs)
        if add_to_docstore:
            self.docstore.mset(full_docs)

    async def aadd_documents(
        self,
        documents: list[Document],
        ids: list[str] | None = None,
        add_to_docstore: bool = True,  # noqa: FBT001,FBT002
        **kwargs: Any,
    ) -> None:
        """Asynchronously adds documents to the docstore and vectorstores.

        Args:
            documents: List of documents to add
            ids: Optional list of IDs for documents. If provided should be the same
                length as the list of documents. Can be provided if parent documents
                are already in the document store and you don't want to re-add
                to the docstore. If not provided, random UUIDs will be used as
                IDs.
            add_to_docstore: Boolean of whether to add documents to docstore.
                This can be false if and only if `ids` are provided. You may want
                to set this to False if the documents are already in the docstore
                and you don't want to re-add them.
            **kwargs: additional keyword arguments passed to the `VectorStore`.

        Raises:
            ValueError: If `ids` is not provided while `add_to_docstore` is
                False, or if `ids` and the documents differ in length.
        """
        # Splitting is done synchronously; only the store writes are awaited.
        docs, full_docs = self._split_docs_for_adding(
            documents,
            ids,
            add_to_docstore=add_to_docstore,
        )
        await self.vectorstore.aadd_documents(docs, **kwargs)
        if add_to_docstore:
            await self.docstore.amset(full_docs)
|
||||
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import PineconeHybridSearchRetriever
|
||||
|
||||
# Deprecated names are resolved lazily through a single importer so that the
# deprecation-warning and optional-import handling lives in one place.
DEPRECATED_LOOKUP = {
    "PineconeHybridSearchRetriever": "langchain_community.retrievers",
}

_import_attribute = create_importer(
    __package__,
    deprecated_lookups=DEPRECATED_LOOKUP,
)


def __getattr__(name: str) -> Any:
    """Resolve ``name`` on demand through the deprecated-import importer.

    Python calls a module-level ``__getattr__`` only when normal attribute
    lookup on the module fails (PEP 562), so this is the hook through which
    the names listed in ``DEPRECATED_LOOKUP`` are served.
    """
    resolved = _import_attribute(name)
    return resolved


__all__ = ["PineconeHybridSearchRetriever"]
|
||||
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import PubMedRetriever
|
||||
|
||||
# Deprecated names are resolved lazily through a single importer so that the
# deprecation-warning and optional-import handling lives in one place.
DEPRECATED_LOOKUP = {
    "PubMedRetriever": "langchain_community.retrievers",
}

_import_attribute = create_importer(
    __package__,
    deprecated_lookups=DEPRECATED_LOOKUP,
)


def __getattr__(name: str) -> Any:
    """Resolve ``name`` on demand through the deprecated-import importer.

    Python calls a module-level ``__getattr__`` only when normal attribute
    lookup on the module fails (PEP 562), so this is the hook through which
    the names listed in ``DEPRECATED_LOOKUP`` are served.
    """
    resolved = _import_attribute(name)
    return resolved


__all__ = ["PubMedRetriever"]
|
||||
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers import PubMedRetriever
|
||||
|
||||
# Deprecated names are resolved lazily through a single importer so that the
# deprecation-warning and optional-import handling lives in one place.
DEPRECATED_LOOKUP = {
    "PubMedRetriever": "langchain_community.retrievers",
}

_import_attribute = create_importer(
    __package__,
    deprecated_lookups=DEPRECATED_LOOKUP,
)


def __getattr__(name: str) -> Any:
    """Resolve ``name`` on demand through the deprecated-import importer.

    Python calls a module-level ``__getattr__`` only when normal attribute
    lookup on the module fails (PEP 562), so this is the hook through which
    the names listed in ``DEPRECATED_LOOKUP`` are served.
    """
    resolved = _import_attribute(name)
    return resolved


__all__ = ["PubMedRetriever"]
|
||||
@@ -0,0 +1,92 @@
|
||||
import logging
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
logger = logging.getLogger(__name__)

# Instruction text for the re-phrasing step. It takes a single input
# variable, ``question``, which carries the raw user query.
DEFAULT_TEMPLATE = (
    "You are an assistant tasked with taking a natural language "
    "query from a user and converting it into a query for a vectorstore. "
    "In this process, you strip out information that is not relevant for "
    "the retrieval task. Here is the user query: {question}"
)

# Prompt object built from the default template above.
DEFAULT_QUERY_PROMPT = PromptTemplate.from_template(DEFAULT_TEMPLATE)
|
||||
|
||||
|
||||
class RePhraseQueryRetriever(BaseRetriever):
    """Given a query, use an LLM to re-phrase it.

    Then, retrieve docs for the re-phrased query.
    """

    retriever: BaseRetriever
    llm_chain: Runnable

    @classmethod
    def from_llm(
        cls,
        retriever: BaseRetriever,
        llm: BaseLLM,
        prompt: BasePromptTemplate = DEFAULT_QUERY_PROMPT,
    ) -> "RePhraseQueryRetriever":
        """Initialize from llm using default template.

        The prompt used here expects a single input: `question`

        Args:
            retriever: retriever to query documents from
            llm: llm used to generate the re-phrased query
            prompt: prompt template for query generation; defaults to
                `DEFAULT_QUERY_PROMPT`

        Returns:
            RePhraseQueryRetriever
        """
        # Prompt -> LLM -> plain string, so the chain output can be passed
        # straight to the wrapped retriever as its query.
        llm_chain = prompt | llm | StrOutputParser()
        return cls(
            retriever=retriever,
            llm_chain=llm_chain,
        )

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Get relevant documents given a user question.

        Args:
            query: user question
            run_manager: callback handler to use

        Returns:
            Relevant documents for re-phrased question
        """
        re_phrased_question = self.llm_chain.invoke(
            query,
            {"callbacks": run_manager.get_child()},
        )
        logger.info("Re-phrased question: %s", re_phrased_question)
        return self.retriever.invoke(
            re_phrased_question,
            config={"callbacks": run_manager.get_child()},
        )

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Asynchronously get relevant documents given a user question.

        Async counterpart of `_get_relevant_documents`: the query is
        re-phrased with `llm_chain.ainvoke` and the result is passed to
        `retriever.ainvoke`.

        Args:
            query: user question
            run_manager: async callback handler to use

        Returns:
            Relevant documents for re-phrased question
        """
        re_phrased_question = await self.llm_chain.ainvoke(
            query,
            {"callbacks": run_manager.get_child()},
        )
        logger.info("Re-phrased question: %s", re_phrased_question)
        return await self.retriever.ainvoke(
            re_phrased_question,
            config={"callbacks": run_manager.get_child()},
        )
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user