initial commit
This commit is contained in:
@@ -0,0 +1,253 @@
|
||||
"""**Retriever** class returns Documents given a text **query**.
|
||||
|
||||
It is more general than a vector store. A retriever does not need to be able to
|
||||
store documents, only to return (or retrieve) them. Vector stores can be used as
|
||||
the backbone of a retriever, but there are other types of retrievers as well.
|
||||
|
||||
**Class hierarchy:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
BaseRetriever --> <name>Retriever # Examples: ArxivRetriever, MergerRetriever
|
||||
|
||||
**Main helpers:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
Document, Serializable, Callbacks,
|
||||
CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.retrievers.arcee import (
|
||||
ArceeRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.arxiv import (
|
||||
ArxivRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.asknews import (
|
||||
AskNewsRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.azure_ai_search import (
|
||||
AzureAISearchRetriever,
|
||||
AzureCognitiveSearchRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.bedrock import (
|
||||
AmazonKnowledgeBasesRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.bm25 import (
|
||||
BM25Retriever,
|
||||
)
|
||||
from langchain_community.retrievers.breebs import (
|
||||
BreebsRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.chaindesk import (
|
||||
ChaindeskRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.chatgpt_plugin_retriever import (
|
||||
ChatGPTPluginRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.cohere_rag_retriever import (
|
||||
CohereRagRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.docarray import (
|
||||
DocArrayRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.dria_index import (
|
||||
DriaRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.elastic_search_bm25 import (
|
||||
ElasticSearchBM25Retriever,
|
||||
)
|
||||
from langchain_community.retrievers.embedchain import (
|
||||
EmbedchainRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.google_cloud_documentai_warehouse import (
|
||||
GoogleDocumentAIWarehouseRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.google_vertex_ai_search import (
|
||||
GoogleCloudEnterpriseSearchRetriever,
|
||||
GoogleVertexAIMultiTurnSearchRetriever,
|
||||
GoogleVertexAISearchRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.kay import (
|
||||
KayAiRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.kendra import (
|
||||
AmazonKendraRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.knn import (
|
||||
KNNRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.llama_index import (
|
||||
LlamaIndexGraphRetriever,
|
||||
LlamaIndexRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.metal import (
|
||||
MetalRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.milvus import (
|
||||
MilvusRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.nanopq import NanoPQRetriever
|
||||
from langchain_community.retrievers.needle import NeedleRetriever
|
||||
from langchain_community.retrievers.outline import (
|
||||
OutlineRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.pinecone_hybrid_search import (
|
||||
PineconeHybridSearchRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.pubmed import (
|
||||
PubMedRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.qdrant_sparse_vector_retriever import (
|
||||
QdrantSparseVectorRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.rememberizer import (
|
||||
RememberizerRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.remote_retriever import (
|
||||
RemoteLangChainRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.svm import (
|
||||
SVMRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.tavily_search_api import (
|
||||
TavilySearchAPIRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.tfidf import (
|
||||
TFIDFRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.thirdai_neuraldb import NeuralDBRetriever
|
||||
from langchain_community.retrievers.vespa_retriever import (
|
||||
VespaRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.weaviate_hybrid_search import (
|
||||
WeaviateHybridSearchRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.web_research import WebResearchRetriever
|
||||
from langchain_community.retrievers.wikipedia import (
|
||||
WikipediaRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.you import (
|
||||
YouRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.zep import (
|
||||
ZepRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.zep_cloud import (
|
||||
ZepCloudRetriever,
|
||||
)
|
||||
from langchain_community.retrievers.zilliz import (
|
||||
ZillizRetriever,
|
||||
)
|
||||
|
||||
|
||||
# Maps every lazily exported retriever class name to the dotted path of the
# module that defines it. The module-level ``__getattr__`` consults this table
# and imports the mapped module only when the name is first requested.
_module_lookup = {
    "AmazonKendraRetriever": "langchain_community.retrievers.kendra",
    "AmazonKnowledgeBasesRetriever": "langchain_community.retrievers.bedrock",
    "ArceeRetriever": "langchain_community.retrievers.arcee",
    "ArxivRetriever": "langchain_community.retrievers.arxiv",
    "AskNewsRetriever": "langchain_community.retrievers.asknews",
    "AzureAISearchRetriever": "langchain_community.retrievers.azure_ai_search",
    "AzureCognitiveSearchRetriever": "langchain_community.retrievers.azure_ai_search",
    "BM25Retriever": "langchain_community.retrievers.bm25",
    "BreebsRetriever": "langchain_community.retrievers.breebs",
    "ChaindeskRetriever": "langchain_community.retrievers.chaindesk",
    "ChatGPTPluginRetriever": "langchain_community.retrievers.chatgpt_plugin_retriever",
    "CohereRagRetriever": "langchain_community.retrievers.cohere_rag_retriever",
    "DocArrayRetriever": "langchain_community.retrievers.docarray",
    "DriaRetriever": "langchain_community.retrievers.dria_index",
    "ElasticSearchBM25Retriever": "langchain_community.retrievers.elastic_search_bm25",
    "EmbedchainRetriever": "langchain_community.retrievers.embedchain",
    "GoogleCloudEnterpriseSearchRetriever": "langchain_community.retrievers.google_vertex_ai_search",  # noqa: E501
    "GoogleDocumentAIWarehouseRetriever": "langchain_community.retrievers.google_cloud_documentai_warehouse",  # noqa: E501
    "GoogleVertexAIMultiTurnSearchRetriever": "langchain_community.retrievers.google_vertex_ai_search",  # noqa: E501
    "GoogleVertexAISearchRetriever": "langchain_community.retrievers.google_vertex_ai_search",  # noqa: E501
    "KNNRetriever": "langchain_community.retrievers.knn",
    "KayAiRetriever": "langchain_community.retrievers.kay",
    "LlamaIndexGraphRetriever": "langchain_community.retrievers.llama_index",
    "LlamaIndexRetriever": "langchain_community.retrievers.llama_index",
    "MetalRetriever": "langchain_community.retrievers.metal",
    "MilvusRetriever": "langchain_community.retrievers.milvus",
    "NanoPQRetriever": "langchain_community.retrievers.nanopq",
    "NeedleRetriever": "langchain_community.retrievers.needle",
    "OutlineRetriever": "langchain_community.retrievers.outline",
    "PineconeHybridSearchRetriever": "langchain_community.retrievers.pinecone_hybrid_search",  # noqa: E501
    "PubMedRetriever": "langchain_community.retrievers.pubmed",
    "QdrantSparseVectorRetriever": "langchain_community.retrievers.qdrant_sparse_vector_retriever",  # noqa: E501
    "RememberizerRetriever": "langchain_community.retrievers.rememberizer",
    "RemoteLangChainRetriever": "langchain_community.retrievers.remote_retriever",
    "SVMRetriever": "langchain_community.retrievers.svm",
    "TFIDFRetriever": "langchain_community.retrievers.tfidf",
    "TavilySearchAPIRetriever": "langchain_community.retrievers.tavily_search_api",
    "VespaRetriever": "langchain_community.retrievers.vespa_retriever",
    "WeaviateHybridSearchRetriever": "langchain_community.retrievers.weaviate_hybrid_search",  # noqa: E501
    "WebResearchRetriever": "langchain_community.retrievers.web_research",
    "WikipediaRetriever": "langchain_community.retrievers.wikipedia",
    "YouRetriever": "langchain_community.retrievers.you",
    "ZepRetriever": "langchain_community.retrievers.zep",
    "ZepCloudRetriever": "langchain_community.retrievers.zep_cloud",
    "ZillizRetriever": "langchain_community.retrievers.zilliz",
    "NeuralDBRetriever": "langchain_community.retrievers.thirdai_neuraldb",
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve a lazily exported retriever class on first attribute access.

    Args:
        name: Attribute requested from this package.

    Returns:
        The attribute called ``name`` taken from the module that
        ``_module_lookup`` associates with it.

    Raises:
        AttributeError: If ``name`` has no entry in ``_module_lookup``.
    """
    module_path = _module_lookup.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    # Import only now, so the integration's dependencies load on demand.
    return getattr(importlib.import_module(module_path), name)
|
||||
|
||||
|
||||
# Public names of this package. Each entry has a matching key in
# ``_module_lookup`` so it can be resolved lazily by ``__getattr__``.
__all__ = [
    "AmazonKendraRetriever",
    "AmazonKnowledgeBasesRetriever",
    "ArceeRetriever",
    "ArxivRetriever",
    "AskNewsRetriever",
    "AzureAISearchRetriever",
    "AzureCognitiveSearchRetriever",
    "BM25Retriever",
    "BreebsRetriever",
    "ChaindeskRetriever",
    "ChatGPTPluginRetriever",
    "CohereRagRetriever",
    "DocArrayRetriever",
    "DriaRetriever",
    "ElasticSearchBM25Retriever",
    "EmbedchainRetriever",
    "GoogleCloudEnterpriseSearchRetriever",
    "GoogleDocumentAIWarehouseRetriever",
    "GoogleVertexAIMultiTurnSearchRetriever",
    "GoogleVertexAISearchRetriever",
    "KayAiRetriever",
    "KNNRetriever",
    "LlamaIndexGraphRetriever",
    "LlamaIndexRetriever",
    "MetalRetriever",
    "MilvusRetriever",
    "NanoPQRetriever",
    "NeedleRetriever",
    "NeuralDBRetriever",
    "OutlineRetriever",
    "PineconeHybridSearchRetriever",
    "PubMedRetriever",
    "QdrantSparseVectorRetriever",
    "RememberizerRetriever",
    "RemoteLangChainRetriever",
    "SVMRetriever",
    "TavilySearchAPIRetriever",
    "TFIDFRetriever",
    "VespaRetriever",
    "WeaviateHybridSearchRetriever",
    "WebResearchRetriever",
    "WikipediaRetriever",
    "YouRetriever",
    "ZepRetriever",
    "ZepCloudRetriever",
    "ZillizRetriever",
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
137
venv/Lib/site-packages/langchain_community/retrievers/arcee.py
Normal file
137
venv/Lib/site-packages/langchain_community/retrievers/arcee.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from pydantic import ConfigDict, SecretStr
|
||||
|
||||
from langchain_community.utilities.arcee import ArceeWrapper, DALMFilter
|
||||
|
||||
|
||||
class ArceeRetriever(BaseRetriever):
    """Retriever backed by Arcee Domain Adapted Language Models (DALMs).

    Provide the Arcee API key either through the ``ARCEE_API_KEY`` environment
    variable or through the ``arcee_api_key`` named parameter.

    Example:
        .. code-block:: python

            from langchain_community.retrievers import ArceeRetriever

            retriever = ArceeRetriever(
                model="DALM-PubMed",
                arcee_api_key="ARCEE-API-KEY"
            )

            documents = retriever.invoke("AI-driven music therapy")
    """

    _client: Optional[ArceeWrapper] = None  #: :meta private:
    """Arcee client, created in ``__init__``."""

    arcee_api_key: SecretStr
    """Arcee API key."""

    model: str
    """Name of the Arcee DALM."""

    arcee_api_url: str = "https://api.arcee.ai"
    """Arcee API URL."""

    arcee_api_version: str = "v2"
    """Arcee API version."""

    arcee_app_url: str = "https://app.arcee.ai"
    """Arcee App URL."""

    model_kwargs: Optional[Dict[str, Any]] = None
    """Keyword arguments handed to the model."""

    model_config = ConfigDict(
        extra="forbid",
    )

    def __init__(self, **data: Any) -> None:
        """Validate the fields, then build the Arcee client.

        After the client is created, its
        ``validate_model_training_status()`` is called.
        """
        super().__init__(**data)

        client = ArceeWrapper(
            arcee_api_key=self.arcee_api_key.get_secret_value(),
            arcee_api_url=self.arcee_api_url,
            arcee_api_version=self.arcee_api_version,
            model_kwargs=self.model_kwargs,
            model_name=self.model,
        )
        self._client = client
        client.validate_model_training_status()

    @pre_init
    def validate_environments(cls, values: Dict) -> Dict:
        """Resolve Arcee settings and sanity-check ``model_kwargs``.

        Each setting is read from ``values`` or, failing that, from its
        environment variable. ``model_kwargs`` may carry ``size`` (must not be
        negative) and ``filters`` (must be a list whose items are accepted by
        ``DALMFilter``).

        Raises:
            ValueError: If ``size`` is negative or ``filters`` is not a list.
        """
        # The key is the only setting wrapped as a secret.
        values["arcee_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "arcee_api_key", "ARCEE_API_KEY")
        )

        # Same resolution order as before: API URL, app URL, API version.
        plain_settings = (
            ("arcee_api_url", "ARCEE_API_URL"),
            ("arcee_app_url", "ARCEE_APP_URL"),
            ("arcee_api_version", "ARCEE_API_VERSION"),
        )
        for field_name, env_name in plain_settings:
            values[field_name] = get_from_dict_or_env(values, field_name, env_name)

        model_kwargs = values["model_kwargs"]
        if not model_kwargs:
            return values

        requested_size = model_kwargs.get("size")
        if requested_size is not None and not requested_size >= 0:
            raise ValueError("`size` must not be negative.")

        requested_filters = model_kwargs.get("filters")
        if requested_filters is not None:
            if not isinstance(requested_filters, List):
                raise ValueError("`filters` must be a list.")
            # Constructing DALMFilter validates each entry; the result is unused.
            for filter_spec in requested_filters:
                DALMFilter(**filter_spec)

        return values

    def _get_relevant_documents(
        self, query: str, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        """Retrieve contexts for ``query`` through the Arcee client.

        Args:
            query: Query to submit to the model.
            run_manager: Callback manager for this run (not used here).
            **kwargs: Forwarded unchanged to the client's ``retrieve`` call.
                The options below are interpreted by the client, not by this
                method.
            size: The max number of context results to retrieve.
                Defaults to 3. (Can be less if filters are provided).
            filters: Filters to apply to the context dataset.

        Raises:
            ValueError: On any failure, including a missing client; the
                original exception is chained as the cause.
        """
        try:
            client = self._client
            if client:
                return client.retrieve(query=query, **kwargs)
            raise ValueError("Client is not initialized.")
        except Exception as exc:
            raise ValueError(f"Error while retrieving documents: {exc}") from exc
|
||||
@@ -0,0 +1,92 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from langchain_community.utilities.arxiv import ArxivAPIWrapper
|
||||
|
||||
|
||||
class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
    """`Arxiv` retriever.

    Setup:
        Install ``arxiv``:

        .. code-block:: bash

            pip install -U arxiv

    Key init args:
        load_max_docs: int
            maximum number of documents to load
        get_full_documents: bool
            whether to return full document text or snippets

    Instantiate:
        .. code-block:: python

            from langchain_community.retrievers import ArxivRetriever

            retriever = ArxivRetriever(
                load_max_docs=2,
                get_full_documents=True,
            )

    Usage:
        .. code-block:: python

            docs = retriever.invoke("What is the ImageBind model?")
            docs[0].metadata

        .. code-block:: none

            {'Entry ID': 'http://arxiv.org/abs/2305.05665v2',
            'Published': datetime.date(2023, 5, 31),
            'Title': 'ImageBind: One Embedding Space To Bind Them All',
            'Authors': 'Rohit Girdhar, Alaaeldin El-Nouby, Zhuang Liu, Mannat Singh, Kalyan Vasudev Alwala, Armand Joulin, Ishan Misra'}

    Use within a chain:
        .. code-block:: python

            from langchain_core.output_parsers import StrOutputParser
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_core.runnables import RunnablePassthrough
            from langchain_openai import ChatOpenAI

            prompt = ChatPromptTemplate.from_template(
                \"\"\"Answer the question based only on the context provided.

            Context: {context}

            Question: {question}\"\"\"
            )

            llm = ChatOpenAI(model="gpt-3.5-turbo-0125")

            def format_docs(docs):
                return "\\n\\n".join(doc.page_content for doc in docs)

            chain = (
                {"context": retriever | format_docs, "question": RunnablePassthrough()}
                | prompt
                | llm
                | StrOutputParser()
            )

            chain.invoke("What is the ImageBind model?")

        .. code-block:: none

             'The ImageBind model is an approach to learn a joint embedding across six different modalities - images, text, audio, depth, thermal, and IMU data...'
    """  # noqa: E501

    # Selects the retrieval path: True calls ``load``, False (the default)
    # calls ``get_summaries_as_docs``. Both are presumably inherited from
    # ``ArxivAPIWrapper`` -- they are not defined in this class.
    get_full_documents: bool = False

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return documents for ``query``.

        Args:
            query: Search string.
            run_manager: Callback manager for this run (not used here).

        Returns:
            The result of ``self.load(query=query)`` when
            ``get_full_documents`` is set, otherwise the result of
            ``self.get_summaries_as_docs(query)``.
        """
        if self.get_full_documents:
            return self.load(query=query)
        return self.get_summaries_as_docs(query)
|
||||
146
venv/Lib/site-packages/langchain_community/retrievers/asknews.py
Normal file
146
venv/Lib/site-packages/langchain_community/retrievers/asknews.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class AskNewsRetriever(BaseRetriever):
    """AskNews retriever.

    Runs ``news.search_news`` through the ``asknews_sdk`` client and turns
    every ``<doc>...</doc>`` section of the response into a ``Document``.

    Credentials come from ``client_id`` / ``client_secret`` when set,
    otherwise from the ``ASKNEWS_CLIENT_ID`` / ``ASKNEWS_CLIENT_SECRET``
    environment variables.
    """

    # Sent to the search call as ``n_articles``.
    k: int = 10
    # The fields below are forwarded to the search call under the same name.
    offset: int = 0
    start_timestamp: Optional[int] = None
    end_timestamp: Optional[int] = None
    method: Literal["nl", "kw"] = "nl"
    categories: List[
        Literal[
            "All",
            "Business",
            "Crime",
            "Politics",
            "Science",
            "Sports",
            "Technology",
            "Military",
            "Health",
            "Entertainment",
            "Finance",
            "Culture",
            "Climate",
            "Environment",
            "World",
        ]
    ] = ["All"]
    historical: bool = False
    similarity_score_threshold: float = 0.5
    # Extra keyword arguments for the search call; ``None`` means "no extras".
    kwargs: Optional[Dict[str, Any]] = {}
    client_id: Optional[str] = None
    client_secret: Optional[str] = None

    def _client_params(self) -> Dict[str, Any]:
        """Build the constructor arguments shared by both SDK clients.

        Raises:
            KeyError: If a credential is neither set on the retriever nor
                present in the environment.
        """
        return {
            "client_id": self.client_id or os.environ["ASKNEWS_CLIENT_ID"],
            "client_secret": self.client_secret or os.environ["ASKNEWS_CLIENT_SECRET"],
            "scopes": ["news"],
        }

    def _search_params(self, query: str) -> Dict[str, Any]:
        """Build the fixed ``search_news`` arguments for ``query``.

        User-supplied ``kwargs`` are deliberately not merged in here: they are
        unpacked separately at the call site so that a key clashing with one
        of these arguments still raises ``TypeError`` instead of silently
        overriding it.
        """
        return {
            "query": query,
            "n_articles": self.k,
            "start_timestamp": self.start_timestamp,
            "end_timestamp": self.end_timestamp,
            "method": self.method,
            "categories": self.categories,
            "historical": self.historical,
            "similarity_score_threshold": self.similarity_score_threshold,
            "offset": self.offset,
            # The delimiters must match the pattern in ``_extract_documents``.
            "doc_start_delimiter": "<doc>",
            "doc_end_delimiter": "</doc>",
            "return_type": "both",
        }

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents

        Raises:
            ImportError: If ``asknews_sdk`` cannot be imported.
        """
        try:
            from asknews_sdk import AskNewsSDK
        except ImportError as exc:
            raise ImportError(
                "AskNews python package not found. "
                "Please install it with `pip install asknews`."
            ) from exc

        an_client = AskNewsSDK(**self._client_params())
        # ``kwargs`` is Optional, so guard against an explicit ``None``.
        response = an_client.news.search_news(
            **self._search_params(query),
            **(self.kwargs or {}),
        )

        return self._extract_documents(response)

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents

        Raises:
            ImportError: If ``asknews_sdk`` cannot be imported.
        """
        try:
            from asknews_sdk import AsyncAskNewsSDK
        except ImportError as exc:
            raise ImportError(
                "AskNews python package not found. "
                "Please install it with `pip install asknews`."
            ) from exc

        an_client = AsyncAskNewsSDK(**self._client_params())
        # ``kwargs`` is Optional, so guard against an explicit ``None``.
        response = await an_client.news.search_news(
            **self._search_params(query),
            **(self.kwargs or {}),
        )

        return self._extract_documents(response)

    def _extract_documents(self, response: Any) -> List[Document]:
        """Extract documents from an api response.

        The text between each ``<doc>`` / ``</doc>`` pair of
        ``response.as_string`` becomes the page content; the entry at the same
        position in ``response.as_dicts`` supplies the ``title``, ``source``
        and ``images`` metadata.
        """

        from asknews_sdk.dto.news import SearchResponse

        sr: SearchResponse = response
        matches = re.findall(r"<doc>(.*?)</doc>", sr.as_string, re.DOTALL)

        docs = []
        for position, text in enumerate(matches):
            # Indexing (rather than zip) keeps the IndexError that signals a
            # mismatch between the delimited text and the structured entries.
            article = sr.as_dicts[position]
            article_url = article.article_url
            docs.append(
                Document(
                    page_content=text.strip(),
                    metadata={
                        "title": article.title,
                        "source": str(article_url) if article_url else None,
                        "images": article.image_url,
                    },
                )
            )
        return docs
|
||||
@@ -0,0 +1,226 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.utils import get_from_dict_or_env, get_from_env
|
||||
from pydantic import ConfigDict, model_validator
|
||||
|
||||
DEFAULT_URL_SUFFIX = "search.windows.net"
|
||||
"""Default URL Suffix for endpoint connection - commercial cloud"""
|
||||
|
||||
|
||||
class AzureAISearchRetriever(BaseRetriever):
    """`Azure AI Search` service retriever.

    Setup:
        See here for more detail: https://python.langchain.com/docs/integrations/retrievers/azure_ai_search/

        We will need to install the below dependencies and set the required
        environment variables:

        .. code-block:: bash

            pip install -U langchain-community azure-identity azure-search-documents
            export AZURE_AI_SEARCH_SERVICE_NAME="<YOUR_SEARCH_SERVICE_NAME>"
            export AZURE_AI_SEARCH_INDEX_NAME="<YOUR_SEARCH_INDEX_NAME>"

            export AZURE_AI_SEARCH_API_KEY="<YOUR_API_KEY>"
            or
            export AZURE_AI_SEARCH_AD_TOKEN="<YOUR_AD_TOKEN>"

    Key init args:
        content_key: str
        top_k: int
        index_name: str

    Instantiate:
        .. code-block:: python

            from langchain_community.retrievers import AzureAISearchRetriever

            retriever = AzureAISearchRetriever(
                content_key="content", top_k=1, index_name="langchain-vector-demo"
            )

    Usage:
        .. code-block:: python

            retriever.invoke("here is my unstructured query string")

    Use within a chain:
        .. code-block:: python

            from langchain_core.output_parsers import StrOutputParser
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_core.runnables import RunnablePassthrough
            from langchain_openai import AzureChatOpenAI

            prompt = ChatPromptTemplate.from_template(
                \"\"\"Answer the question based only on the context provided.

            Context: {context}

            Question: {question}\"\"\"
            )

            llm = AzureChatOpenAI(azure_deployment="gpt-35-turbo")

            def format_docs(docs):
                return "\\n\\n".join(doc.page_content for doc in docs)

            chain = (
                {"context": retriever | format_docs, "question": RunnablePassthrough()}
                | prompt
                | llm
                | StrOutputParser()
            )

            chain.invoke("...")

    """  # noqa: E501

    service_name: str = ""
    """Name of Azure AI Search service"""
    index_name: str = ""
    """Name of Index inside Azure AI Search service"""
    api_key: str = ""
    """API Key. Both Admin and Query keys work, but for reading data it's
    recommended to use a Query key."""
    api_version: str = "2023-11-01"
    """API version"""
    aiosession: Optional[aiohttp.ClientSession] = None
    """ClientSession, in case we want to reuse connection for better performance."""
    azure_ad_token: str = ""
    """Your Azure Active Directory token.

    Automatically inferred from env var `AZURE_AI_SEARCH_AD_TOKEN` if not provided.

    For more:
    https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
    """
    content_key: str = "content"
    """Key in a retrieved result to set as the Document page_content."""
    top_k: Optional[int] = None
    """Number of results to retrieve. Set to None to retrieve all results."""
    filter: Optional[str] = None
    """OData $filter expression to apply to the search query."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that service name, index name and api key exists in environment.

        Each setting is taken from ``values`` or, failing that, from its
        ``AZURE_AI_SEARCH_*`` environment variable.

        Raises:
            ValueError: If neither an API key nor an Azure AD token is
                available.
        """
        values["service_name"] = get_from_dict_or_env(
            values, "service_name", "AZURE_AI_SEARCH_SERVICE_NAME"
        )
        values["index_name"] = get_from_dict_or_env(
            values, "index_name", "AZURE_AI_SEARCH_INDEX_NAME"
        )
        values["azure_ad_token"] = get_from_dict_or_env(
            values, "azure_ad_token", "AZURE_AI_SEARCH_AD_TOKEN", default=""
        )
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "AZURE_AI_SEARCH_API_KEY", default=""
        )
        if values["azure_ad_token"] == "" and values["api_key"] == "":
            raise ValueError(
                "Missing credentials. Please pass one of `api_key`, `azure_ad_token`, "
                "or the `AZURE_AI_SEARCH_API_KEY` or `AZURE_AI_SEARCH_AD_TOKEN` "
                "environment variables."
            )

        return values

    def _build_search_url(self, query: str) -> str:
        """Build the full search URL for ``query``.

        ``service_name`` may be a bare service name, a host name, or a full
        ``https://`` URL; the URL suffix and the scheme are added only when
        missing. ``query`` and ``filter`` are percent-encoded, so pass them as
        raw (not pre-encoded) text.
        """
        # Local import: the file's import block is left untouched.
        from urllib.parse import quote

        url_suffix = get_from_env("", "AZURE_AI_SEARCH_URL_SUFFIX", DEFAULT_URL_SUFFIX)

        # Both checks look at the configured name itself, exactly as the
        # former four-way if/elif chain did.
        has_scheme = "https://" in self.service_name
        has_suffix = url_suffix in self.service_name
        host = self.service_name if has_suffix else f"{self.service_name}.{url_suffix}"
        base_url = f"{host}/" if has_scheme else f"https://{host}/"

        endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}"
        top_param = f"&$top={self.top_k}" if self.top_k else ""
        # Encode user text so characters such as ``&``, ``#`` or ``+`` cannot
        # truncate the value or inject extra query parameters.
        filter_param = f"&$filter={quote(self.filter, safe='')}" if self.filter else ""
        search_param = f"&search={quote(query, safe='')}"
        return base_url + endpoint_path + search_param + top_param + filter_param

    @property
    def _headers(self) -> Dict[str, str]:
        """Request headers; the AD bearer token wins over the API key."""
        headers = {
            "Content-Type": "application/json",
        }
        if self.azure_ad_token:
            headers["Authorization"] = f"Bearer {self.azure_ad_token}"
        elif self.api_key:
            headers["api-key"] = f"{self.api_key}"
        return headers

    def _search(self, query: str) -> List[dict]:
        """Run the search synchronously and return the ``value`` list.

        Raises:
            Exception: If the service answers with a status other than 200.
        """
        search_url = self._build_search_url(query)
        response = requests.get(search_url, headers=self._headers)
        if response.status_code != 200:
            raise Exception(f"Error in search request: {response}")

        return json.loads(response.text)["value"]

    async def _fetch_results(
        self, session: aiohttp.ClientSession, search_url: str
    ) -> List[dict]:
        """GET ``search_url`` on ``session`` and return the ``value`` list.

        Raises:
            Exception: If the service answers with a status other than 200,
                mirroring ``_search``.
        """
        async with session.get(search_url, headers=self._headers) as response:
            if response.status != 200:
                raise Exception(f"Error in search request: {response}")
            response_json = await response.json()
        return response_json["value"]

    async def _asearch(self, query: str) -> List[dict]:
        """Run the search asynchronously and return the ``value`` list.

        Reuses ``aiosession`` when one was supplied; otherwise a temporary
        session is opened for this single request.
        """
        search_url = self._build_search_url(query)
        if self.aiosession:
            return await self._fetch_results(self.aiosession, search_url)
        async with aiohttp.ClientSession() as session:
            return await self._fetch_results(session, search_url)

    def _results_to_documents(self, search_results: List[dict]) -> List[Document]:
        """Convert raw results to documents.

        ``content_key`` is popped out of each result to become the page
        content; everything left in the result becomes the metadata.
        """
        return [
            Document(page_content=result.pop(self.content_key), metadata=result)
            for result in search_results
        ]

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the documents matching ``query`` (synchronous)."""
        return self._results_to_documents(self._search(query))

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the documents matching ``query`` (asynchronous)."""
        return self._results_to_documents(await self._asearch(query))
|
||||
|
||||
|
||||
# For backwards compatibility
|
||||
class AzureCognitiveSearchRetriever(AzureAISearchRetriever):
    """`Azure Cognitive Search` service retriever.

    Backwards-compatible alias of ``AzureAISearchRetriever``: it adds no
    behavior of its own. This name will soon be deprecated; please switch to
    ``AzureAISearchRetriever``.
    """
|
||||
186
venv/Lib/site-packages/langchain_community/retrievers/bedrock.py
Normal file
186
venv/Lib/site-packages/langchain_community/retrievers/bedrock.py
Normal file
@@ -0,0 +1,186 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class VectorSearchConfig(BaseModel, extra="allow"):
    """Configuration for vector search.

    Fields other than the one declared below are accepted and kept
    (``extra="allow"``).
    """

    # How many results to ask for; defaults to 4.
    numberOfResults: int = 4
|
||||
|
||||
|
||||
class RetrievalConfig(BaseModel, extra="allow"):
    """Configuration for retrieval.

    Fields other than the one declared below are accepted and kept
    (``extra="allow"``).
    """

    # Required nested vector search settings.
    vectorSearchConfiguration: VectorSearchConfig
|
||||
|
||||
|
||||
@deprecated(
    since="0.3.16",
    removal="1.0",
    alternative_import="langchain_aws.AmazonKnowledgeBasesRetriever",
)
class AmazonKnowledgeBasesRetriever(BaseRetriever):
    """Amazon Bedrock Knowledge Bases retriever.

    See https://aws.amazon.com/bedrock/knowledge-bases for more info.

    Setup:
        Install ``langchain-aws``:

        .. code-block:: bash

            pip install -U langchain-aws

    Key init args:
        knowledge_base_id: Knowledge Base ID.
        region_name: The aws region e.g., `us-west-2`.
            Fallback to AWS_DEFAULT_REGION env variable or region specified in
            ~/.aws/config.
        credentials_profile_name: The name of the profile in the ~/.aws/credentials
            or ~/.aws/config files, which has either access keys or role information
            specified. If not specified, the default credential profile or, if on an
            EC2 instance, credentials from IMDS will be used.
        client: boto3 client for bedrock agent runtime.
        retrieval_config: Configuration for retrieval.

    Instantiate:
        .. code-block:: python

            from langchain_community.retrievers import AmazonKnowledgeBasesRetriever

            retriever = AmazonKnowledgeBasesRetriever(
                knowledge_base_id="<knowledge-base-id>",
                retrieval_config={
                    "vectorSearchConfiguration": {
                        "numberOfResults": 4
                    }
                },
            )

    Usage:
        .. code-block:: python

            query = "..."

            retriever.invoke(query)

    Use within a chain:
        .. code-block:: python

            from langchain_aws import ChatBedrockConverse
            from langchain_core.output_parsers import StrOutputParser
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_core.runnables import RunnablePassthrough

            prompt = ChatPromptTemplate.from_template(
                \"\"\"Answer the question based only on the context provided.

            Context: {context}

            Question: {question}\"\"\"
            )

            llm = ChatBedrockConverse(
                model_id="anthropic.claude-3-5-sonnet-20240620-v1:0"
            )

            def format_docs(docs):
                return "\\n\\n".join(doc.page_content for doc in docs)

            chain = (
                {"context": retriever | format_docs, "question": RunnablePassthrough()}
                | prompt
                | llm
                | StrOutputParser()
            )

            chain.invoke("...")

    """  # noqa: E501

    knowledge_base_id: str
    region_name: Optional[str] = None
    credentials_profile_name: Optional[str] = None
    endpoint_url: Optional[str] = None
    client: Any
    retrieval_config: RetrievalConfig

    @model_validator(mode="before")
    @classmethod
    def create_client(cls, values: Dict[str, Any]) -> Any:
        """Create the boto3 ``bedrock-agent-runtime`` client if none was given.

        Args:
            values: Raw init values, before field validation.

        Returns:
            ``values``, with ``"client"`` filled in when it was missing.

        Raises:
            ImportError: If boto3 is not installed, or the installed boto3
                does not know the ``bedrock-agent-runtime`` service.
            ValueError: If the session or client cannot be created for any
                other reason (typically invalid credentials or profile).
        """
        if values.get("client") is not None:
            return values

        try:
            import boto3
            from botocore.client import Config
            from botocore.exceptions import UnknownServiceError
        except ImportError as e:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e

        try:
            profile_name = values.get("credentials_profile_name")
            if profile_name:
                session = boto3.Session(profile_name=profile_name)
            else:
                # use default credentials
                session = boto3.Session()

            client_params: Dict[str, Any] = {
                "config": Config(
                    connect_timeout=120, read_timeout=120, retries={"max_attempts": 0}
                )
            }
            # Only forward optional settings that were actually provided so
            # boto3 can fall back to its own defaults otherwise.
            if values.get("region_name"):
                client_params["region_name"] = values["region_name"]
            if values.get("endpoint_url"):
                client_params["endpoint_url"] = values["endpoint_url"]

            values["client"] = session.client("bedrock-agent-runtime", **client_params)
        except UnknownServiceError as e:
            raise ImportError(
                "Ensure that you have installed the latest boto3 package "
                "that contains the API for `bedrock-agent-runtime`."
            ) from e
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

        return values

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query the knowledge base and convert each hit into a Document.

        For every retrieval result the ``content.text`` value becomes the
        page content. The rest of the result becomes the metadata, with a
        ``score`` of 0 added when absent and any ``metadata`` entry renamed
        to ``source_metadata``.

        Args:
            query: Query text; surrounding whitespace is stripped.
            run_manager: Callback manager for this retriever run (unused here).

        Returns:
            One Document per entry in the response's ``retrievalResults``.
        """
        response = self.client.retrieve(
            retrievalQuery={"text": query.strip()},
            knowledgeBaseId=self.knowledge_base_id,
            # `model_dump` is the Pydantic v2 spelling of the deprecated `.dict()`.
            retrievalConfiguration=self.retrieval_config.model_dump(),
        )
        documents = []
        for result in response["retrievalResults"]:
            content = result["content"]["text"]
            result.pop("content")
            if "score" not in result:
                result["score"] = 0
            if "metadata" in result:
                result["source_metadata"] = result.pop("metadata")
            documents.append(Document(page_content=content, metadata=result))

        return documents
|
||||
116
venv/Lib/site-packages/langchain_community/retrievers/bm25.py
Normal file
116
venv/Lib/site-packages/langchain_community/retrievers/bm25.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
|
||||
def default_preprocessing_func(text: str) -> List[str]:
    """Split *text* into tokens on runs of whitespace.

    Args:
        text: Raw text to tokenize.

    Returns:
        The whitespace-separated tokens; empty for blank input.
    """
    tokens = text.split()
    return tokens
|
||||
|
||||
|
||||
class BM25Retriever(BaseRetriever):
    """`BM25` retriever without Elasticsearch."""

    vectorizer: Any = None
    """ BM25 vectorizer."""
    docs: List[Document] = Field(repr=False)
    """ List of documents."""
    k: int = 4
    """ Number of documents to return."""
    preprocess_func: Callable[[str], List[str]] = default_preprocessing_func
    """ Preprocessing function to use on the text before BM25 vectorization."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        ids: Optional[Iterable[str]] = None,
        bm25_params: Optional[Dict[str, Any]] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25Retriever:
        """
        Create a BM25Retriever from a list of texts.

        Args:
            texts: The texts to vectorize. Any iterable is accepted,
                including one-shot iterators such as generators.
            metadatas: A list of metadata dicts to associate with each text.
            ids: A list of ids to associate with each text.
            bm25_params: Parameters to pass to the BM25 vectorizer.
            preprocess_func: A function to preprocess each text before vectorization.
            **kwargs: Any other arguments to pass to the retriever.

        Returns:
            A BM25Retriever instance.

        Raises:
            ImportError: If the ``rank_bm25`` package is not installed.
        """
        try:
            from rank_bm25 import BM25Okapi
        except ImportError as e:
            raise ImportError(
                "Could not import rank_bm25, please install with `pip install "
                "rank_bm25`."
            ) from e

        # `texts` is walked several times below (tokenizing, default
        # metadata, document construction); materialize it once so a
        # generator is not silently exhausted after the first pass.
        texts = list(texts)
        texts_processed = [preprocess_func(t) for t in texts]
        bm25_params = bm25_params or {}
        vectorizer = BM25Okapi(texts_processed, **bm25_params)
        metadatas = metadatas or ({} for _ in texts)
        if ids:
            docs = [
                Document(page_content=t, metadata=m, id=i)
                for t, m, i in zip(texts, metadatas, ids)
            ]
        else:
            docs = [
                Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)
            ]
        return cls(
            vectorizer=vectorizer, docs=docs, preprocess_func=preprocess_func, **kwargs
        )

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        *,
        bm25_params: Optional[Dict[str, Any]] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25Retriever:
        """
        Create a BM25Retriever from a list of Documents.

        Args:
            documents: A list of Documents to vectorize.
            bm25_params: Parameters to pass to the BM25 vectorizer.
            preprocess_func: A function to preprocess each text before vectorization.
            **kwargs: Any other arguments to pass to the retriever.

        Returns:
            A BM25Retriever instance.

        Raises:
            ValueError: If ``documents`` is empty.
        """
        documents = list(documents)
        if not documents:
            # Previously this surfaced as an opaque tuple-unpacking
            # ValueError; keep the exception type but say what is wrong.
            raise ValueError(
                "Cannot create a BM25Retriever from an empty collection of documents."
            )
        texts, metadatas, ids = zip(
            *((d.page_content, d.metadata, d.id) for d in documents)
        )
        return cls.from_texts(
            texts=texts,
            bm25_params=bm25_params,
            metadatas=metadatas,
            ids=ids,
            preprocess_func=preprocess_func,
            **kwargs,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the ``k`` stored documents the vectorizer ranks highest.

        Args:
            query: Query text; tokenized with ``preprocess_func``.
            run_manager: Callback manager for this retriever run (unused here).

        Returns:
            The result of the vectorizer's ``get_top_n`` over ``self.docs``.
        """
        processed_query = self.preprocess_func(query)
        return self.vectorizer.get_top_n(processed_query, self.docs, n=self.k)
|
||||
@@ -0,0 +1,49 @@
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents.base import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class BreebsRetriever(BaseRetriever):
    """A retriever class for `Breebs`.

    See https://www.breebs.com/ for more info.

    Args:
        breeb_key: The key to trigger the breeb
            (specialized knowledge pill on a specific topic).

    To retrieve the list of all available Breebs : you can call https://breebs.promptbreeders.com/web/listbreebs
    """

    breeb_key: str
    url: str = "https://breebs.promptbreeders.com/knowledge"

    def __init__(self, breeb_key: str):
        """Create a retriever bound to the breeb identified by ``breeb_key``."""
        super().__init__(breeb_key=breeb_key)  # type: ignore[call-arg]
        self.breeb_key = breeb_key

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve context for given query.

        Posts the breeb key and the query to ``self.url``. Any response whose
        status code is not 200 yields an empty list. Note that for time being
        there is no score: every document gets a score of 1.
        """
        reply = requests.post(
            self.url,
            json={
                "breeb_key": self.breeb_key,
                "query": query,
            },
        )
        if reply.status_code != 200:
            return []

        documents = []
        for chunk in reply.json():
            documents.append(
                Document(
                    page_content=chunk["content"],
                    metadata={"source": chunk["source_url"], "score": 1},
                )
            )
        return documents
|
||||
@@ -0,0 +1,92 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class ChaindeskRetriever(BaseRetriever):
    """`Chaindesk API` retriever.

    Posts the query to a Chaindesk datastore URL and turns every entry of
    the response's ``results`` list into a Document.
    """

    datastore_url: str
    top_k: Optional[int] = None
    api_key: Optional[str] = None

    def __init__(
        self,
        datastore_url: str,
        top_k: Optional[int] = None,
        api_key: Optional[str] = None,
    ):
        """Create a retriever for one Chaindesk datastore.

        Args:
            datastore_url: URL the query is posted to.
            top_k: Value sent as ``topK``; omitted from the request when None.
            api_key: Bearer token; no ``Authorization`` header is sent when None.
        """
        # The fields must be handed to the base initializer: assigning
        # attributes on a Pydantic model before it has been initialized
        # fails, so the previous direct assignments could never succeed.
        super().__init__(  # type: ignore[call-arg]
            datastore_url=datastore_url,
            top_k=top_k,
            api_key=api_key,
        )

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Query the datastore synchronously.

        Args:
            query: Text to search for.
            run_manager: Callback manager for this retriever run (unused here).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            One Document per entry in the response's ``results``, carrying
            its ``source`` and ``score`` as metadata.
        """
        response = requests.post(
            self.datastore_url,
            json={
                "query": query,
                **({"topK": self.top_k} if self.top_k is not None else {}),
            },
            headers={
                "Content-Type": "application/json",
                **(
                    {"Authorization": f"Bearer {self.api_key}"}
                    if self.api_key is not None
                    else {}
                ),
            },
        )
        data = response.json()
        return [
            Document(
                page_content=r["text"],
                metadata={"source": r["source"], "score": r["score"]},
            )
            for r in data["results"]
        ]

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Query the datastore asynchronously using a temporary aiohttp session.

        Args:
            query: Text to search for.
            run_manager: Async callback manager for this retriever run (unused here).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            One Document per entry in the response's ``results``, carrying
            its ``source`` and ``score`` as metadata.
        """
        async with aiohttp.ClientSession() as session:
            async with session.request(
                "POST",
                self.datastore_url,
                json={
                    "query": query,
                    **({"topK": self.top_k} if self.top_k is not None else {}),
                },
                headers={
                    "Content-Type": "application/json",
                    **(
                        {"Authorization": f"Bearer {self.api_key}"}
                        if self.api_key is not None
                        else {}
                    ),
                },
            ) as response:
                data = await response.json()
        return [
            Document(
                page_content=r["text"],
                metadata={"source": r["source"], "score": r["score"]},
            )
            for r in data["results"]
        ]
|
||||
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
class ChatGPTPluginRetriever(BaseRetriever):
    """`ChatGPT plugin` retriever.

    Posts the query to the plugin's ``/query`` endpoint and converts the
    hits of the first returned query result into Documents.
    """

    url: str
    """URL of the ChatGPT plugin."""
    bearer_token: str
    """Bearer token for the ChatGPT plugin."""
    top_k: int = 3
    """Number of documents to return."""
    filter: Optional[dict] = None
    """Filter to apply to the results."""
    aiosession: Optional[aiohttp.ClientSession] = None
    """Aiohttp session to use for requests."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query the plugin synchronously and build Documents from the hits.

        Each hit's ``text`` becomes the page content. Its ``metadata`` entry
        (or, when absent, the rest of the hit itself) becomes the metadata,
        with a truthy ``source_id`` renamed to ``source``.
        """
        endpoint, payload, request_headers = self._create_request(query)
        reply = requests.post(endpoint, json=payload, headers=request_headers)
        documents = []
        for hit in reply.json()["results"][0]["results"]:
            text = hit.pop("text")
            meta = hit.pop("metadata", hit)
            if meta.get("source_id"):
                meta["source"] = meta.pop("source_id")
            documents.append(Document(page_content=text, metadata=meta))
        return documents

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query the plugin asynchronously and build Documents from the hits.

        Uses ``self.aiosession`` when configured, otherwise a temporary
        ``aiohttp.ClientSession``. Hits are converted exactly as in the
        synchronous variant.
        """
        endpoint, payload, request_headers = self._create_request(query)

        if self.aiosession:
            async with self.aiosession.post(
                endpoint, headers=request_headers, json=payload
            ) as reply:
                body = await reply.json()
        else:
            async with aiohttp.ClientSession() as throwaway:
                async with throwaway.post(
                    endpoint, headers=request_headers, json=payload
                ) as reply:
                    body = await reply.json()

        documents = []
        for hit in body["results"][0]["results"]:
            text = hit.pop("text")
            meta = hit.pop("metadata", hit)
            if meta.get("source_id"):
                meta["source"] = meta.pop("source_id")
            documents.append(Document(page_content=text, metadata=meta))
        return documents

    def _create_request(self, query: str) -> tuple[str, dict, dict]:
        """Assemble the URL, JSON body and headers for a plugin query.

        Args:
            query: Text to search for.

        Returns:
            A ``(url, json, headers)`` tuple ready to be posted.
        """
        single_query = {
            "query": query,
            "filter": self.filter,
            "top_k": self.top_k,
        }
        request_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.bearer_token}",
        }
        return f"{self.url}/query", {"queries": [single_query]}, request_headers
|
||||
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
def _get_docs(response: Any) -> List[Document]:
    """Convert one chat generation into a list of Documents.

    Every entry of ``generation_info["documents"]`` (when that key is
    present) becomes a Document whose content is the entry's ``snippet`` and
    whose metadata is the entry itself. A final Document holding the model's
    reply is always appended, with the citations, search results, search
    queries and token count as metadata.

    Args:
        response: A generation exposing ``generation_info`` and ``message``.

    Returns:
        The source documents followed by the model-response document.
    """
    info = response.generation_info
    if "documents" in info:
        docs = [
            Document(page_content=source["snippet"], metadata=source)
            for source in info["documents"]
        ]
    else:
        docs = []

    model_reply = Document(
        page_content=response.message.content,
        metadata={
            "type": "model_response",
            "citations": info["citations"],
            "search_results": info["search_results"],
            "search_queries": info["search_queries"],
            "token_count": info["token_count"],
        },
    )
    docs.append(model_reply)
    return docs
|
||||
|
||||
|
||||
@deprecated(
    since="0.0.30",
    removal="1.0",
    alternative_import="langchain_cohere.CohereRagRetriever",
)
class CohereRagRetriever(BaseRetriever):
    """Cohere Chat API with RAG.

    The query is sent to the chat model as a single human message together
    with the configured connectors; the first generation is then converted
    into Documents by ``_get_docs``.
    """

    connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}])
    """
    When specified, the model's reply will be enriched with information found by
    querying each of the connectors (RAG). These will be returned as langchain
    documents.

    Currently only accepts {"id": "web-search"}.
    """

    llm: BaseChatModel
    """Cohere ChatModel to use."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        """Run the chat model synchronously and return the resulting Documents.

        Args:
            query: User query, sent as one human message.
            run_manager: Callback manager; its child manager is passed on to
                the model call.
            **kwargs: Extra arguments forwarded to ``llm.generate``.
        """
        messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
        llm_result = self.llm.generate(
            messages,
            connectors=self.connectors,
            callbacks=run_manager.get_child(),
            **kwargs,
        )
        first_generation = llm_result.generations[0][0]
        return _get_docs(first_generation)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Run the chat model asynchronously and return the resulting Documents.

        Args:
            query: User query, sent as one human message.
            run_manager: Async callback manager; its child manager is passed
                on to the model call.
            **kwargs: Extra arguments forwarded to ``llm.agenerate``.
        """
        messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
        llm_result = await self.llm.agenerate(
            messages,
            connectors=self.connectors,
            callbacks=run_manager.get_child(),
            **kwargs,
        )
        first_generation = llm_result.generations[0][0]
        return _get_docs(first_generation)
|
||||
@@ -0,0 +1,74 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class DataberryRetriever(BaseRetriever):
    """`Databerry API` retriever.

    Posts the query to a Databerry datastore URL and turns every entry of
    the response's ``results`` list into a Document.
    """

    datastore_url: str
    top_k: Optional[int]
    api_key: Optional[str]

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query the datastore synchronously.

        ``topK`` is only sent when ``top_k`` is set, and the ``Authorization``
        header only when ``api_key`` is set.

        Returns:
            One Document per result, carrying its ``source`` and ``score``
            as metadata.
        """
        payload = {"query": query}
        if self.top_k is not None:
            payload["topK"] = self.top_k
        request_headers = {"Content-Type": "application/json"}
        if self.api_key is not None:
            request_headers["Authorization"] = f"Bearer {self.api_key}"

        reply = requests.post(
            self.datastore_url, json=payload, headers=request_headers
        )
        documents = []
        for hit in reply.json()["results"]:
            documents.append(
                Document(
                    page_content=hit["text"],
                    metadata={"source": hit["source"], "score": hit["score"]},
                )
            )
        return documents

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query the datastore asynchronously using a temporary aiohttp session.

        ``topK`` is only sent when ``top_k`` is set, and the ``Authorization``
        header only when ``api_key`` is set.

        Returns:
            One Document per result, carrying its ``source`` and ``score``
            as metadata.
        """
        payload = {"query": query}
        if self.top_k is not None:
            payload["topK"] = self.top_k
        request_headers = {"Content-Type": "application/json"}
        if self.api_key is not None:
            request_headers["Authorization"] = f"Bearer {self.api_key}"

        async with aiohttp.ClientSession() as session:
            async with session.request(
                "POST",
                self.datastore_url,
                json=payload,
                headers=request_headers,
            ) as reply:
                body = await reply.json()

        documents = []
        for hit in body["results"]:
            documents.append(
                Document(
                    page_content=hit["text"],
                    metadata={"source": hit["source"], "score": hit["score"]},
                )
            )
        return documents
|
||||
@@ -0,0 +1,208 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.utils.pydantic import get_fields
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
|
||||
class SearchType(str, Enum):
    """Enumerator of the types of search to perform.

    Members are also plain strings, so they compare equal to their values.
    """

    # Plain similarity search.
    similarity = "similarity"
    # Maximal marginal relevance search.
    mmr = "mmr"
|
||||
|
||||
|
||||
class DocArrayRetriever(BaseRetriever):
    """`DocArray Document Indices` retriever.

    Currently, it supports 5 backends:
    InMemoryExactNNIndex, HnswDocumentIndex, QdrantDocumentIndex,
    ElasticDocIndex, and WeaviateDocumentIndex.

    Args:
        index: One of the above-mentioned index instances
        embeddings: Embedding model to represent text as vectors
        search_field: Field to consider for searching in the documents.
            Should be an embedding/vector/tensor.
        content_field: Field that represents the main content in your document schema.
            Will be used as a `page_content`. Everything else will go into `metadata`.
        search_type: Type of search to perform (similarity / mmr)
        filters: Filters applied for document retrieval.
        top_k: Number of documents to return
    """

    index: Any = None
    embeddings: Embeddings
    search_field: str
    content_field: str
    search_type: SearchType = SearchType.similarity
    top_k: int = 1
    filters: Optional[Any] = None

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Embed the query and dispatch to the configured search type.

        Args:
            query: string to find relevant documents for
            run_manager: Callback manager for this retriever run (unused here).

        Returns:
            List of relevant documents

        Raises:
            ValueError: If ``search_type`` is neither similarity nor mmr.
        """
        query_vector = np.array(self.embeddings.embed_query(query))

        if self.search_type == SearchType.similarity:
            return self._similarity_search(query_vector)
        if self.search_type == SearchType.mmr:
            return self._mmr_search(query_vector)
        raise ValueError(
            f"Search type {self.search_type} does not exist. "
            f"Choose either 'similarity' or 'mmr'."
        )

    def _search(
        self, query_emb: np.ndarray, top_k: int
    ) -> List[Union[Dict[str, Any], Any]]:
        """
        Perform a search using the query embedding and return top_k documents.

        The filter keyword depends on the index type: ``where_filter`` for
        Weaviate (which is also searched with an empty search field),
        ``query`` for Elastic, and ``filter_query`` for everything else.

        Args:
            query_emb: Query represented as an embedding
            top_k: Number of documents to return

        Returns:
            A list of top_k documents matching the query
        """

        from docarray.index import ElasticDocIndex, WeaviateDocumentIndex

        if isinstance(self.index, WeaviateDocumentIndex):
            filter_keyword, field = "where_filter", ""
        elif isinstance(self.index, ElasticDocIndex):
            filter_keyword, field = "query", self.search_field
        else:
            filter_keyword, field = "filter_query", self.search_field

        if not self.filters:
            # No filters: a plain vector search is enough.
            return self.index.find(
                query=query_emb, search_field=field, limit=top_k
            ).documents

        builder = self.index.build_query()  # get empty query object
        builder = builder.find(
            query=query_emb, search_field=field
        )  # add vector similarity search
        builder = builder.filter(
            **{filter_keyword: self.filters}
        )  # add filter search
        combined_query = builder.build(limit=top_k)  # build the query

        # execute the combined query and return the results
        hits = self.index.execute_query(combined_query)
        # Some backends wrap the hits in a result object.
        hits = getattr(hits, "documents", hits)
        return hits[:top_k]

    def _similarity_search(self, query_emb: np.ndarray) -> List[Document]:
        """
        Perform a similarity search.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of documents most similar to the query
        """
        return [
            self._docarray_to_langchain_doc(hit)
            for hit in self._search(query_emb=query_emb, top_k=self.top_k)
        ]

    def _mmr_search(self, query_emb: np.ndarray) -> List[Document]:
        """
        Perform a maximal marginal relevance (mmr) search.

        Fetches 20 candidates and lets ``maximal_marginal_relevance`` pick
        ``top_k`` of them.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of diverse documents related to the query
        """
        candidates = self._search(query_emb=query_emb, top_k=20)

        candidate_vectors = []
        for candidate in candidates:
            if isinstance(candidate, dict):
                candidate_vectors.append(candidate[self.search_field])
            else:
                candidate_vectors.append(getattr(candidate, self.search_field))

        chosen = maximal_marginal_relevance(
            query_emb,
            candidate_vectors,
            k=self.top_k,
        )
        return [self._docarray_to_langchain_doc(candidates[idx]) for idx in chosen]

    def _docarray_to_langchain_doc(self, doc: Union[Dict[str, Any], Any]) -> Document:
        """
        Convert a DocArray document (which also might be a dict)
        to a langchain document format.

        DocArray document can contain arbitrary fields, so the mapping is done
        in the following way:

        page_content <-> content_field
        metadata <-> all other fields excluding
            tensors and embeddings (so float, int, string)

        Args:
            doc: DocArray document

        Returns:
            Document in langchain format

        Raises:
            ValueError: If the document doesn't contain the content field
        """
        is_mapping = isinstance(doc, dict)
        fields = doc.keys() if is_mapping else get_fields(doc)

        if self.content_field not in fields:
            raise ValueError(
                f"Document does not contain the content field - {self.content_field}."
            )

        def read(field_name: str) -> Any:
            # Dicts are indexed, document objects are read by attribute.
            return doc[field_name] if is_mapping else getattr(doc, field_name)

        lc_doc = Document(page_content=read(self.content_field))

        for name in fields:
            value = read(name)
            if name == self.content_field or not isinstance(
                value, (str, int, float, bool)
            ):
                continue
            lc_doc.metadata[name] = value

        return lc_doc
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Wrapper around Dria Retriever."""
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from langchain_community.utilities import DriaAPIWrapper
|
||||
|
||||
|
||||
class DriaRetriever(BaseRetriever):
    """`Dria` retriever using the DriaAPIWrapper."""

    api_wrapper: DriaAPIWrapper

    def __init__(self, api_key: str, contract_id: Optional[str] = None, **kwargs: Any):
        """
        Initialize the DriaRetriever with a DriaAPIWrapper instance.

        Args:
            api_key: The API key for Dria.
            contract_id: The contract ID of the knowledge base to interact with.
            **kwargs: Further arguments forwarded to the base retriever.
        """
        super().__init__(  # type: ignore[call-arg]
            api_wrapper=DriaAPIWrapper(api_key=api_key, contract_id=contract_id),
            **kwargs,
        )

    def create_knowledge_base(
        self,
        name: str,
        description: str,
        category: str = "Unspecified",
        embedding: str = "jina",
    ) -> str:
        """Create a new knowledge base in Dria.

        Args:
            name: The name of the knowledge base.
            description: The description of the knowledge base.
            category: The category of the knowledge base.
            embedding: The embedding model to use for the knowledge base.

        Returns:
            Whatever the wrapper's ``create_knowledge_base`` returns
            (annotated as the ID of the created knowledge base).
        """
        return self.api_wrapper.create_knowledge_base(
            name, description, category, embedding
        )

    def add_texts(
        self,
        texts: List,
    ) -> None:
        """Add texts to the Dria knowledge base.

        Args:
            texts: Items providing a ``"text"`` and a ``"metadata"`` entry;
                only these two entries are forwarded to the wrapper.

        Returns:
            None.
        """
        records = []
        for item in texts:
            records.append({"text": item["text"], "metadata": item["metadata"]})
        self.api_wrapper.insert_data(records)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve relevant documents from Dria based on a query.

        Each search result's ``metadata`` entry becomes the page content,
        while its ``id`` and ``score`` become the Document metadata.

        Args:
            query: The query string to search for in the knowledge base.
            run_manager: Callback manager for the retriever run.

        Returns:
            A list of Documents containing the search results.
        """
        documents = []
        for hit in self.api_wrapper.search(query):
            documents.append(
                Document(
                    page_content=hit["metadata"],
                    metadata={"id": hit["id"], "score": hit["score"]},
                )
            )
        return documents
|
||||
@@ -0,0 +1,137 @@
|
||||
"""Wrapper around Elasticsearch vector database."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any, Iterable, List
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class ElasticSearchBM25Retriever(BaseRetriever):
    """`Elasticsearch` retriever that uses `BM25`.

    To connect to an Elasticsearch instance that requires login credentials,
    including Elastic Cloud, use the Elasticsearch URL format
    https://username:password@es_host:9243. For example, to connect to Elastic
    Cloud, create the Elasticsearch URL with the required authentication details and
    pass it to ``ElasticSearchBM25Retriever.create`` as the named parameter
    elasticsearch_url.

    You can obtain your Elastic Cloud URL and login credentials by logging in to the
    Elastic Cloud console at https://cloud.elastic.co, selecting your deployment, and
    navigating to the "Deployments" page.

    To obtain your Elastic Cloud password for the default "elastic" user:

    1. Log in to the Elastic Cloud console at https://cloud.elastic.co
    2. Go to "Security" > "Users"
    3. Locate the "elastic" user and click "Edit"
    4. Click "Reset password"
    5. Follow the prompts to reset the password

    The format for Elastic Cloud URLs is
    https://username:password@cluster_id.region_id.gcp.cloud.es.io:9243.
    """

    client: Any
    """Elasticsearch client."""
    index_name: str
    """Name of the index to use in Elasticsearch."""

    @classmethod
    def create(
        cls, elasticsearch_url: str, index_name: str, k1: float = 2.0, b: float = 0.75
    ) -> ElasticSearchBM25Retriever:
        """Create a new BM25-scored index and return a retriever bound to it.

        Args:
            elasticsearch_url: URL of the Elasticsearch instance to connect to.
            index_name: Name of the index to create and use in Elasticsearch.
            k1: BM25 parameter k1.
            b: BM25 parameter b.

        Returns:
            A retriever holding the new client and ``index_name``.

        Raises:
            ImportError: If the ``elasticsearch`` package is not installed.
        """
        # Same friendly error as ``add_texts`` so both entry points agree.
        try:
            from elasticsearch import Elasticsearch
        except ImportError as exc:
            raise ImportError(
                "Could not import elasticsearch python package. "
                "Please install it with `pip install elasticsearch`."
            ) from exc

        # Create an Elasticsearch client instance
        es = Elasticsearch(elasticsearch_url)

        # Define the index settings and mappings
        settings = {
            "analysis": {"analyzer": {"default": {"type": "standard"}}},
            "similarity": {
                "custom_bm25": {
                    "type": "BM25",
                    "k1": k1,
                    "b": b,
                }
            },
        }
        mappings = {
            "properties": {
                "content": {
                    "type": "text",
                    "similarity": "custom_bm25",  # Use the custom BM25 similarity
                }
            }
        }

        # Create the index with the specified settings and mappings
        es.indices.create(index=index_name, mappings=mappings, settings=settings)
        return cls(client=es, index_name=index_name)

    def add_texts(
        self,
        texts: Iterable[str],
        refresh_indices: bool = True,
    ) -> List[str]:
        """Index more texts into the retriever's Elasticsearch index.

        Args:
            texts: Iterable of strings to add to the retriever.
            refresh_indices: Whether to refresh the index after the bulk insert.

        Returns:
            List of the generated document ids, in the order of ``texts``.

        Raises:
            ImportError: If the ``elasticsearch`` package is not installed.
        """
        try:
            from elasticsearch.helpers import bulk
        except ImportError:
            raise ImportError(
                "Could not import elasticsearch python package. "
                "Please install it with `pip install elasticsearch`."
            )
        requests = []
        ids = []
        for text in texts:
            # Each document gets a fresh random id that is reported to the caller.
            _id = str(uuid.uuid4())
            request = {
                "_op_type": "index",
                "_index": self.index_name,
                "content": text,
                "_id": _id,
            }
            ids.append(_id)
            requests.append(request)
        bulk(self.client, requests)

        if refresh_indices:
            self.client.indices.refresh(index=self.index_name)
        return ids

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run a ``match`` query on the ``content`` field and return the hits.

        Args:
            query: Text to match against the indexed ``content`` field.
            run_manager: Callback manager for the retriever run (unused here).

        Returns:
            One ``Document`` per hit, holding the stored ``content`` as page text.
        """
        query_dict = {"query": {"match": {"content": query}}}
        res = self.client.search(index=self.index_name, body=query_dict)

        return [
            Document(page_content=hit["_source"]["content"])
            for hit in res["hits"]["hits"]
        ]
|
||||
@@ -0,0 +1,74 @@
|
||||
"""Wrapper around Embedchain Retriever."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class EmbedchainRetriever(BaseRetriever):
    """Retriever that delegates to an `Embedchain` pipeline."""

    client: Any
    """Embedchain Pipeline."""

    @classmethod
    def create(cls, yaml_path: Optional[str] = None) -> EmbedchainRetriever:
        """Build a retriever around a new Embedchain ``Pipeline``.

        Args:
            yaml_path: Path to the YAML configuration file. When falsy, a
                ``Pipeline`` with default configuration is created instead.

        Returns:
            An instance of EmbedchainRetriever.
        """
        from embedchain import Pipeline

        # Configured pipeline when a path is given, default pipeline otherwise.
        pipeline = Pipeline.from_config(yaml_path=yaml_path) if yaml_path else Pipeline()
        return cls(client=pipeline)

    def add_texts(
        self,
        texts: Iterable[str],
    ) -> List[str]:
        """Add each text to the pipeline, one ``client.add`` call per item.

        Args:
            texts: Iterable of strings/URLs to add to the retriever.

        Returns:
            The values returned by ``client.add``, in input order.
        """
        return [self.client.add(item) for item in texts]

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search the pipeline and convert each result into a ``Document``.

        Args:
            query: The search text handed to ``client.search``.
            run_manager: Callback manager for the retriever run (unused here).

        Returns:
            One ``Document`` per result; the page text is the result's
            ``context`` and the metadata carries its ``url`` and ``doc_id``.
        """
        hits = self.client.search(query)
        return [
            Document(
                page_content=hit["context"],
                metadata={
                    "source": hit["metadata"]["url"],
                    "document_id": hit["metadata"]["doc_id"],
                },
            )
            for hit in hits
        ]
|
||||
@@ -0,0 +1,126 @@
|
||||
"""Retriever wrapper for Google Cloud Document AI Warehouse."""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
|
||||
from langchain_community.utilities.vertexai import get_client_info
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.cloud.contentwarehouse_v1 import (
|
||||
DocumentServiceClient,
|
||||
RequestMetadata,
|
||||
SearchDocumentsRequest,
|
||||
)
|
||||
from google.cloud.contentwarehouse_v1.services.document_service.pagers import (
|
||||
SearchDocumentsPager,
|
||||
)
|
||||
|
||||
|
||||
@deprecated(
    since="0.0.32",
    removal="1.0",
    alternative_import="langchain_google_community.DocumentAIWarehouseRetriever",
)
class GoogleDocumentAIWarehouseRetriever(BaseRetriever):
    """A retriever based on Document AI Warehouse.

    Documents should be created and documents should be uploaded
    in a separate flow, and this retriever uses only Document AI
    schema_id provided to search for relevant documents.

    More info: https://cloud.google.com/document-ai-warehouse.
    """

    location: str = "us"
    """Google Cloud location where Document AI Warehouse is placed."""
    project_number: str
    """Google Cloud project number, should contain digits only."""
    schema_id: Optional[str] = None
    """Document AI Warehouse schema to query against.
    If nothing is provided, all documents in the project will be searched."""
    qa_size_limit: int = 5
    """The limit on the number of documents returned."""
    client: "DocumentServiceClient" = None  #: :meta private:

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Check the optional dependency and fill in ``project_number``/``client``.

        Args:
            values: Raw field values supplied to the constructor.

        Returns:
            ``values`` with ``project_number`` resolved from the arguments or the
            ``PROJECT_NUMBER`` environment variable, and ``client`` set to a new
            ``DocumentServiceClient``.

        Raises:
            ImportError: If ``google-cloud-contentwarehouse`` is not installed.
        """
        try:
            from google.cloud.contentwarehouse_v1 import DocumentServiceClient
        except ImportError as exc:
            # The trailing space keeps the two sentences from running together.
            raise ImportError(
                "google.cloud.contentwarehouse is not installed. "
                "Please install it with pip install google-cloud-contentwarehouse"
            ) from exc

        values["project_number"] = get_from_dict_or_env(
            values, "project_number", "PROJECT_NUMBER"
        )
        values["client"] = DocumentServiceClient(
            client_info=get_client_info(module="document-ai-warehouse")
        )
        return values

    def _prepare_request_metadata(self, user_ldap: str) -> "RequestMetadata":
        """Build request metadata identifying the caller as ``user:<user_ldap>``."""
        from google.cloud.contentwarehouse_v1 import RequestMetadata, UserInfo

        user_info = UserInfo(id=f"user:{user_ldap}")
        return RequestMetadata(user_info=user_info)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        """Search Document AI Warehouse and return the matches as ``Document``s.

        Args:
            query: Natural-language query text.
            run_manager: Callback manager for the retriever run (unused here).
            **kwargs: Must contain ``user_ldap``.

        Returns:
            One ``Document`` per matching document in the response.

        Raises:
            ValueError: If ``user_ldap`` is not supplied.
        """
        request = self._prepare_search_request(query, **kwargs)
        response = self.client.search_documents(request=request)
        return self._parse_search_response(response=response)

    def _prepare_search_request(
        self, query: str, **kwargs: Any
    ) -> "SearchDocumentsRequest":
        """Assemble the ``SearchDocumentsRequest`` for ``query``.

        Args:
            query: Natural-language query text.
            **kwargs: Must contain ``user_ldap``.

        Returns:
            A request scoped to this project/location, restricted to
            ``schema_id`` when one is configured.

        Raises:
            ValueError: If ``user_ldap`` is not supplied.
        """
        from google.cloud.contentwarehouse_v1 import (
            DocumentQuery,
            SearchDocumentsRequest,
        )

        try:
            user_ldap = kwargs["user_ldap"]
        except KeyError as exc:
            raise ValueError("Argument user_ldap should be provided!") from exc

        request_metadata = self._prepare_request_metadata(user_ldap=user_ldap)
        schemas = []
        if self.schema_id:
            schemas.append(
                self.client.document_schema_path(
                    project=self.project_number,
                    location=self.location,
                    document_schema=self.schema_id,
                )
            )
        return SearchDocumentsRequest(
            parent=self.client.common_location_path(self.project_number, self.location),
            request_metadata=request_metadata,
            document_query=DocumentQuery(
                query=query, is_nl_query=True, document_schema_names=schemas
            ),
            qa_size_limit=self.qa_size_limit,
        )

    def _parse_search_response(
        self, response: "SearchDocumentsPager"
    ) -> List[Document]:
        """Convert each matching document into a ``Document``.

        The page text is the match's ``search_text_snippet``; the metadata holds
        the document's ``title`` and ``raw_document_path`` (as ``source``).
        """
        documents = []
        for doc in response.matching_documents:
            metadata = {
                "title": doc.document.title,
                "source": doc.document.raw_document_path,
            }
            documents.append(
                Document(page_content=doc.search_text_snippet, metadata=metadata)
            )
        return documents
|
||||
@@ -0,0 +1,491 @@
|
||||
"""Retriever wrapper for Google Vertex AI Search."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from langchain_community.utilities.vertexai import get_client_info
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.api_core.client_options import ClientOptions
|
||||
from google.cloud.discoveryengine_v1beta import SearchRequest, SearchResult
|
||||
|
||||
|
||||
class _BaseGoogleVertexAISearchRetriever(BaseModel):
    """Shared configuration and response converters for Vertex AI Search retrievers."""

    project_id: str
    """Google Cloud Project ID."""
    data_store_id: Optional[str] = None
    """Vertex AI Search data store ID."""
    search_engine_id: Optional[str] = None
    """Vertex AI Search app ID."""
    location_id: str = "global"
    """Vertex AI Search data store location."""
    serving_config_id: str = "default_config"
    """Vertex AI Search serving config ID."""
    credentials: Any = None
    """The default custom credentials (google.auth.credentials.Credentials) to use
    when making API calls. If not provided, credentials will be ascertained from
    the environment."""
    engine_data_type: int = Field(default=0, ge=0, le=3)
    """ Defines the Vertex AI Search app data type
    0 - Unstructured data
    1 - Structured data
    2 - Website data
    3 - Blended search
    """

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Check optional dependencies and resolve IDs from the environment.

        Args:
            values: Raw field values supplied to the constructor.

        Returns:
            ``values`` with ``project_id`` resolved (argument or ``PROJECT_ID``),
            and ``data_store_id`` / ``search_engine_id`` filled from
            ``DATA_STORE_ID`` / ``SEARCH_ENGINE_ID`` when available.

        Raises:
            ImportError: If ``google-cloud-discoveryengine`` or
                ``google-api-core`` is not installed.
        """
        try:
            from google.cloud import discoveryengine_v1beta  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "google.cloud.discoveryengine is not installed. "
                "Please install it with pip install "
                "google-cloud-discoveryengine>=0.11.10"
            ) from exc
        try:
            from google.api_core.exceptions import InvalidArgument  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "google.api_core.exceptions is not installed. "
                "Please install it with pip install google-api-core"
            ) from exc

        values["project_id"] = get_from_dict_or_env(values, "project_id", "PROJECT_ID")

        # Both IDs are optional at this level, so each one is resolved on its
        # own: a missing data_store_id must not stop SEARCH_ENGINE_ID from
        # being picked up from the environment.
        for field_name, env_name in (
            ("data_store_id", "DATA_STORE_ID"),
            ("search_engine_id", "SEARCH_ENGINE_ID"),
        ):
            try:
                values[field_name] = get_from_dict_or_env(values, field_name, env_name)
            except ValueError:
                # Neither the argument nor the environment variable is set.
                continue

        return values

    @property
    def client_options(self) -> "ClientOptions":
        """Client options with a regional endpoint for non-``global`` locations."""
        from google.api_core.client_options import ClientOptions

        return ClientOptions(
            api_endpoint=(
                f"{self.location_id}-discoveryengine.googleapis.com"
                if self.location_id != "global"
                else None
            )
        )

    def _convert_structured_search_response(
        self, results: Sequence[SearchResult]
    ) -> List[Document]:
        """Converts a sequence of search results to a list of LangChain documents.

        Each document's page text is the JSON dump of the result's
        ``struct_data``; the metadata carries the result's ``id`` and ``name``.
        """
        import json

        from google.protobuf.json_format import MessageToDict

        documents: List[Document] = []

        for result in results:
            document_dict = MessageToDict(
                result.document._pb, preserving_proto_field_name=True
            )

            documents.append(
                Document(
                    page_content=json.dumps(document_dict.get("struct_data", {})),
                    metadata={"id": document_dict["id"], "name": document_dict["name"]},
                )
            )

        return documents

    def _convert_unstructured_search_response(
        self, results: Sequence[SearchResult], chunk_type: str
    ) -> List[Document]:
        """Converts a sequence of search results to a list of LangChain documents.

        Results without ``derived_struct_data`` or without a ``chunk_type``
        entry are skipped. One document is produced per chunk.
        """
        from google.protobuf.json_format import MessageToDict

        documents: List[Document] = []

        for result in results:
            document_dict = MessageToDict(
                result.document._pb, preserving_proto_field_name=True
            )
            derived_struct_data = document_dict.get("derived_struct_data")
            if not derived_struct_data:
                continue

            doc_metadata = document_dict.get("struct_data", {})
            doc_metadata["id"] = document_dict["id"]

            if chunk_type not in derived_struct_data:
                continue

            for chunk in derived_struct_data[chunk_type]:
                # Copy so the per-chunk "source" does not leak between chunks.
                chunk_metadata = doc_metadata.copy()
                chunk_metadata["source"] = derived_struct_data.get("link", "")

                if chunk_type == "extractive_answers":
                    chunk_metadata["source"] += f":{chunk.get('pageNumber', '')}"

                documents.append(
                    Document(
                        page_content=chunk.get("content", ""), metadata=chunk_metadata
                    )
                )

        return documents

    def _convert_website_search_response(
        self, results: Sequence[SearchResult], chunk_type: str
    ) -> List[Document]:
        """Converts a sequence of search results to a list of LangChain documents.

        Results without ``derived_struct_data`` or without a ``chunk_type``
        entry are skipped. When nothing is produced, a hint is printed.
        """
        from google.protobuf.json_format import MessageToDict

        documents: List[Document] = []

        for result in results:
            document_dict = MessageToDict(
                result.document._pb, preserving_proto_field_name=True
            )
            derived_struct_data = document_dict.get("derived_struct_data")
            if not derived_struct_data:
                continue

            doc_metadata = document_dict.get("struct_data", {})
            doc_metadata["id"] = document_dict["id"]
            doc_metadata["source"] = derived_struct_data.get("link", "")
            if derived_struct_data.get("title") is not None:
                doc_metadata["title"] = derived_struct_data.get("title")

            if chunk_type not in derived_struct_data:
                continue

            # Snippet chunks keep their text under "snippet"; others use "content".
            text_field = "snippet" if chunk_type == "snippets" else "content"

            for chunk in derived_struct_data[chunk_type]:
                documents.append(
                    Document(
                        page_content=chunk.get(text_field, ""), metadata=doc_metadata
                    )
                )

        if not documents:
            print(f"No {chunk_type} could be found.")  # noqa: T201
            if chunk_type == "extractive_answers":
                print(  # noqa: T201
                    "Make sure that your data store is using Advanced Website "
                    "Indexing.\n"
                    "https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features#advanced-website-indexing"
                )

        return documents
|
||||
|
||||
|
||||
@deprecated(
    since="0.0.33",
    removal="1.0",
    alternative_import="langchain_google_community.VertexAISearchRetriever",
)
class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetriever):
    """`Google Vertex AI Search` retriever.

    For a detailed explanation of the Vertex AI Search concepts
    and configuration parameters, refer to the product documentation.
    https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
    """

    filter: Optional[str] = None
    """Filter expression."""
    get_extractive_answers: bool = False
    """If True return Extractive Answers, otherwise return Extractive Segments or Snippets."""  # noqa: E501
    max_documents: int = Field(default=5, ge=1, le=100)
    """The maximum number of documents to return."""
    max_extractive_answer_count: int = Field(default=1, ge=1, le=5)
    """The maximum number of extractive answers returned in each search result.
    At most 5 answers will be returned for each SearchResult.
    """
    max_extractive_segment_count: int = Field(default=1, ge=1, le=1)
    """The maximum number of extractive segments returned in each search result.
    Currently one segment will be returned for each SearchResult.
    """
    query_expansion_condition: int = Field(default=1, ge=0, le=2)
    """Specification to determine under which conditions query expansion should occur.
    0 - Unspecified query expansion condition. In this case, server behavior defaults
        to disabled
    1 - Disabled query expansion. Only the exact search query is used, even if
        SearchResponse.total_size is zero.
    2 - Automatic query expansion built by the Search API.
    """
    spell_correction_mode: int = Field(default=2, ge=0, le=2)
    """Specification to determine under which conditions spell correction should occur.
    0 - Unspecified spell correction mode. In this case, server behavior defaults
        to auto.
    1 - Suggestion only. Search API will try to find a spell suggestion if there is any
        and put in the `SearchResponse.corrected_query`.
        The spell suggestion will not be used as the search query.
    2 - Automatic spell correction built by the Search API.
        Search will be based on the corrected query if found.
    """

    # type is SearchServiceClient but can't be set due to optional imports
    _client: Any = None
    _serving_config: str

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="ignore",
    )

    def __init__(self, **kwargs: Any) -> None:
        """Validate the fields, then create the client and the serving config path.

        Raises:
            ImportError: If ``google-cloud-discoveryengine`` is not installed.
            ValueError: If a blended app (``engine_data_type == 3``) has no
                ``search_engine_id``, or if neither ``data_store_id`` nor
                ``search_engine_id`` is set.
        """
        try:
            from google.cloud.discoveryengine_v1beta import SearchServiceClient
        except ImportError as exc:
            raise ImportError(
                "google.cloud.discoveryengine is not installed. "
                "Please install it with pip install google-cloud-discoveryengine"
            ) from exc

        super().__init__(**kwargs)

        #  For more information, refer to:
        # https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
        self._client = SearchServiceClient(
            credentials=self.credentials,
            client_options=self.client_options,
            client_info=get_client_info(module="vertex-ai-search"),
        )

        if self.engine_data_type == 3 and not self.search_engine_id:
            raise ValueError(
                "search_engine_id must be specified for blended search apps."
            )

        if self.search_engine_id:
            # NOTE(review): this path always uses "default_config" and ignores
            # serving_config_id — confirm that is intended for engine-based apps.
            self._serving_config = f"projects/{self.project_id}/locations/{self.location_id}/collections/default_collection/engines/{self.search_engine_id}/servingConfigs/default_config"  # noqa: E501
        elif self.data_store_id:
            self._serving_config = self._client.serving_config_path(
                project=self.project_id,
                location=self.location_id,
                data_store=self.data_store_id,
                serving_config=self.serving_config_id,
            )
        else:
            raise ValueError(
                "Either data_store_id or search_engine_id must be specified."
            )

    def _create_search_request(self, query: str) -> SearchRequest:
        """Prepares a SearchRequest object.

        The content search spec depends on ``engine_data_type``: extractive
        answers or segments for unstructured data (0), none for structured
        data (1), extractive answers plus snippets for website/blended (2, 3).

        Raises:
            NotImplementedError: If ``engine_data_type`` is not 0, 1, 2 or 3.
        """
        from google.cloud.discoveryengine_v1beta import SearchRequest

        query_expansion_spec = SearchRequest.QueryExpansionSpec(
            condition=self.query_expansion_condition,
        )

        spell_correction_spec = SearchRequest.SpellCorrectionSpec(
            mode=self.spell_correction_mode
        )

        if self.engine_data_type == 0:
            if self.get_extractive_answers:
                extractive_content_spec = (
                    SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
                        max_extractive_answer_count=self.max_extractive_answer_count,
                    )
                )
            else:
                extractive_content_spec = (
                    SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
                        max_extractive_segment_count=self.max_extractive_segment_count,
                    )
                )
            content_search_spec = SearchRequest.ContentSearchSpec(
                extractive_content_spec=extractive_content_spec
            )
        elif self.engine_data_type == 1:
            content_search_spec = None
        elif self.engine_data_type in (2, 3):
            content_search_spec = SearchRequest.ContentSearchSpec(
                extractive_content_spec=SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
                    max_extractive_answer_count=self.max_extractive_answer_count,
                ),
                snippet_spec=SearchRequest.ContentSearchSpec.SnippetSpec(
                    return_snippet=True
                ),
            )
        else:
            raise NotImplementedError(
                "Only data store type 0 (Unstructured), 1 (Structured), "
                "2 (Website), or 3 (Blended) are supported currently."
                + f" Got {self.engine_data_type}"
            )

        return SearchRequest(
            query=query,
            filter=self.filter,
            serving_config=self._serving_config,
            page_size=self.max_documents,
            content_search_spec=content_search_spec,
            query_expansion_spec=query_expansion_spec,
            spell_correction_spec=spell_correction_spec,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant for a query."""
        return self.get_relevant_documents_with_response(query)[0]

    def get_relevant_documents_with_response(
        self, query: str
    ) -> Tuple[List[Document], Any]:
        """Run the search and return both the documents and the raw response.

        Args:
            query: The search text.

        Returns:
            A ``(documents, response)`` tuple, where ``response`` is whatever
            the client's ``search`` call returned.

        Raises:
            google.api_core.exceptions.InvalidArgument: Re-raised with a hint
                about ``engine_data_type`` appended to the message.
            NotImplementedError: If ``engine_data_type`` is not 0, 1, 2 or 3.
        """
        from google.api_core.exceptions import InvalidArgument

        search_request = self._create_search_request(query)

        try:
            response = self._client.search(search_request)
        except InvalidArgument as exc:
            # Chain explicitly so the original API error stays in the traceback.
            raise type(exc)(
                exc.message
                + " This might be due to engine_data_type not set correctly."
            ) from exc

        if self.engine_data_type == 0:
            chunk_type = (
                "extractive_answers"
                if self.get_extractive_answers
                else "extractive_segments"
            )
            documents = self._convert_unstructured_search_response(
                response.results, chunk_type
            )
        elif self.engine_data_type == 1:
            documents = self._convert_structured_search_response(response.results)
        elif self.engine_data_type in (2, 3):
            chunk_type = (
                "extractive_answers" if self.get_extractive_answers else "snippets"
            )
            documents = self._convert_website_search_response(
                response.results, chunk_type
            )
        else:
            raise NotImplementedError(
                "Only data store type 0 (Unstructured), 1 (Structured), "
                "2 (Website), or 3 (Blended) are supported currently."
                + f" Got {self.engine_data_type}"
            )

        return documents, response
|
||||
|
||||
|
||||
@deprecated(
    since="0.0.33",
    removal="1.0",
    alternative_import="langchain_google_community.VertexAIMultiTurnSearchRetriever",
)
class GoogleVertexAIMultiTurnSearchRetriever(
    BaseRetriever, _BaseGoogleVertexAISearchRetriever
):
    """`Google Vertex AI Search` retriever for multi-turn conversations."""

    conversation_id: str = "-"
    """Vertex AI Search Conversation ID."""

    # type is ConversationalSearchServiceClient but can't be set due to optional imports
    _client: Any = None
    _serving_config: str

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="ignore",
    )

    def __init__(self, **kwargs: Any):
        """Validate the fields, then create the client and the serving config path.

        Raises:
            ValueError: If ``data_store_id`` is not set.
            NotImplementedError: If ``engine_data_type`` is 1 (Structured) or
                3 (Blended).
        """
        super().__init__(**kwargs)
        from google.cloud.discoveryengine_v1beta import (
            ConversationalSearchServiceClient,
        )

        self._client = ConversationalSearchServiceClient(
            credentials=self.credentials,
            client_options=self.client_options,
            client_info=get_client_info(module="vertex-ai-search"),
        )

        if not self.data_store_id:
            raise ValueError("data_store_id is required for MultiTurnSearchRetriever.")

        self._serving_config = self._client.serving_config_path(
            project=self.project_id,
            location=self.location_id,
            data_store=self.data_store_id,
            serving_config=self.serving_config_id,
        )

        if self.engine_data_type in (1, 3):
            # The trailing space keeps "(Blended)" and "is" from running together.
            raise NotImplementedError(
                "Data store type 1 (Structured) and 3 (Blended) "
                "is not currently supported for multi-turn search."
                + f" Got {self.engine_data_type}"
            )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant for a query.

        Sends ``query`` as the next turn of ``conversation_id`` and converts the
        response's ``search_results`` using the ``extractive_answers`` chunks:
        with the website converter for ``engine_data_type == 2``, otherwise
        with the unstructured converter.
        """
        from google.cloud.discoveryengine_v1beta import (
            ConverseConversationRequest,
            TextInput,
        )

        request = ConverseConversationRequest(
            name=self._client.conversation_path(
                self.project_id,
                self.location_id,
                self.data_store_id,
                self.conversation_id,
            ),
            serving_config=self._serving_config,
            query=TextInput(input=query),
        )
        response = self._client.converse_conversation(request)

        if self.engine_data_type == 2:
            return self._convert_website_search_response(
                response.search_results, "extractive_answers"
            )

        return self._convert_unstructured_search_response(
            response.search_results, "extractive_answers"
        )
|
||||
|
||||
|
||||
class GoogleCloudEnterpriseSearchRetriever(GoogleVertexAISearchRetriever):
    """`Google Vertex Search API` retriever alias for backwards compatibility.
    DEPRECATED: Use `GoogleVertexAISearchRetriever` instead.
    """

    def __init__(self, **data: Any):
        """Emit a ``DeprecationWarning`` and defer to the parent initialiser.

        Args:
            **data: Field values forwarded unchanged to
                ``GoogleVertexAISearchRetriever``.
        """
        import warnings

        # stacklevel=2 attributes the warning to the code that instantiates
        # this class rather than to this module.
        warnings.warn(
            "GoogleCloudEnterpriseSearchRetriever is deprecated, use GoogleVertexAISearchRetriever",  # noqa: E501
            DeprecationWarning,
            stacklevel=2,
        )

        super().__init__(**data)
|
||||
60
venv/Lib/site-packages/langchain_community/retrievers/kay.py
Normal file
60
venv/Lib/site-packages/langchain_community/retrievers/kay.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class KayAiRetriever(BaseRetriever):
    """
    Retriever for Kay.ai datasets.

    To work properly, expects you to have KAY_API_KEY env variable set.
    You can get one for free at https://kay.ai/.
    """

    client: Any
    num_contexts: int

    @classmethod
    def create(
        cls,
        dataset_id: str,
        data_types: List[str],
        num_contexts: int = 6,
    ) -> KayAiRetriever:
        """
        Build a KayAiRetriever around a new ``KayRetriever`` client.

        Args:
            dataset_id: A dataset id category in Kay, like "company"
            data_types: A list of datasources present within a dataset. For
                "company" the corresponding datasources could be
                ["10-K", "10-Q", "8-K", "PressRelease"].
            num_contexts: The number of documents to retrieve on each query.
                Defaults to 6.

        Raises:
            ImportError: If the ``kay`` package cannot be imported.
        """
        try:
            from kay.rag.retrievers import KayRetriever
        except ImportError:
            message = (
                "Could not import kay python package. Please install it with "
                "`pip install kay`."
            )
            raise ImportError(message)

        return cls(
            client=KayRetriever(dataset_id, data_types),
            num_contexts=num_contexts,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query Kay and turn each returned context into a ``Document``.

        Contexts without a ``chunk_embed_text`` entry are dropped. That entry
        becomes the page text and is removed from the context; the remaining
        entries become the document metadata.
        """
        contexts = self.client.query(query=query, num_context=self.num_contexts)
        documents = []
        for context in contexts:
            # pop() removes the text from the context so it is not repeated
            # in the metadata.
            text = context.pop("chunk_embed_text", None)
            if text is not None:
                documents.append(Document(page_content=text, metadata=dict(context)))
        return documents
|
||||
496
venv/Lib/site-packages/langchain_community/retrievers/kendra.py
Normal file
496
venv/Lib/site-packages/langchain_community/retrievers/kendra.py
Normal file
@@ -0,0 +1,496 @@
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
model_validator,
|
||||
validator,
|
||||
)
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
def clean_excerpt(excerpt: str) -> str:
    """Normalise a Kendra excerpt for display.

    Every run of whitespace is collapsed to a single space and each literal
    ``...`` is removed. Falsy input (empty string, ``None``) is returned
    untouched.

    Args:
        excerpt: The excerpt to clean.

    Returns:
        The cleaned excerpt.

    """
    if excerpt:
        collapsed = re.sub(r"\s+", " ", excerpt)
        return collapsed.replace("...", "")
    return excerpt
|
||||
|
||||
|
||||
def combined_text(item: "ResultItem") -> str:
    """Render a ResultItem's title and excerpt as one text block.

    The title line and the excerpt section are each emitted only when the
    corresponding value is non-empty; the excerpt is passed through
    ``clean_excerpt`` first.

    Args:
        item: the ResultItem of a Kendra search.

    Returns:
        A combined text of the title and excerpt of the given item.

    """
    sections = []
    title = item.get_title()
    if title:
        sections.append(f"Document Title: {title}\n")
    excerpt = clean_excerpt(item.get_excerpt())
    if excerpt:
        sections.append(f"Document Excerpt: \n{excerpt}\n")
    return "".join(sections)
|
||||
|
||||
|
||||
# Type alias used as the return type of DocumentAttributeValue.value and as
# the value type of ResultItem.get_document_attributes_dict().
DocumentAttributeValueType = Union[str, int, List[str], None]
"""Possible types of a DocumentAttributeValue.

Dates are also represented as str.
"""
|
||||
|
||||
|
||||
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
|
||||
class Highlight(BaseModel, extra="allow"):
    """Information that highlights the keywords in the excerpt.

    Keys that are not declared below are kept on the model
    (``extra="allow"``). The two optional fields default to ``None`` so a
    payload that omits them still validates.
    """

    BeginOffset: int
    """The zero-based location in the excerpt where the highlight starts."""
    EndOffset: int
    """The zero-based location in the excerpt where the highlight ends."""
    TopAnswer: Optional[bool] = None
    """Indicates whether the result is the best one."""
    Type: Optional[str] = None
    """The highlight type: STANDARD or THESAURUS_SYNONYM."""
|
||||
|
||||
|
||||
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
|
||||
class TextWithHighLights(BaseModel, extra="allow"):
    """Text with highlights.

    Undeclared keys are kept on the model (``extra="allow"``).
    """

    Text: str
    """The text."""
    # Defaults to None so that a payload without a "Highlights" key validates.
    Highlights: Optional[Any] = None
    """The highlights."""
|
||||
|
||||
|
||||
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
|
||||
class AdditionalResultAttributeValue(BaseModel, extra="allow"):
    """Value of an additional result attribute.

    Wraps a single ``TextWithHighLights`` payload; undeclared keys are kept
    on the model (``extra="allow"``).
    """

    TextWithHighlightsValue: TextWithHighLights
    """The text with highlights value."""
|
||||
|
||||
|
||||
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
|
||||
class AdditionalResultAttribute(BaseModel, extra="allow"):
    """Additional result attribute.

    Undeclared keys are kept on the model (``extra="allow"``).
    """

    Key: str
    """The key of the attribute."""
    ValueType: Literal["TEXT_WITH_HIGHLIGHTS_VALUE"]
    """The type of the value."""
    Value: AdditionalResultAttributeValue
    """The value of the attribute."""

    def get_value_text(self) -> str:
        """Return the plain text stored in this attribute's value."""
        highlighted = self.Value.TextWithHighlightsValue
        return highlighted.Text
|
||||
|
||||
|
||||
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
|
||||
class DocumentAttributeValue(BaseModel, extra="allow"):
    """Value of a document attribute.

    At most one of the four typed fields is expected to be populated; the
    ``value`` property returns whichever one is set.
    """

    DateValue: Optional[str] = None
    """The date expressed as an ISO 8601 string."""
    LongValue: Optional[int] = None
    """The long value."""
    StringListValue: Optional[List[str]] = None
    """The string list value."""
    StringValue: Optional[str] = None
    """The string value."""

    @property
    def value(self) -> DocumentAttributeValueType:
        """The only defined document attribute value or None.

        According to Amazon Kendra, you can only provide one
        value for a document attribute.

        Returns:
            The first field that is not ``None``, checked in the order
            DateValue, LongValue, StringListValue, StringValue; ``None`` when
            none of them is set.
        """
        # Compare against None rather than relying on truthiness: 0, "" and
        # [] are legitimate attribute values and must not be reported as
        # "no value".
        candidates = (
            self.DateValue,
            self.LongValue,
            self.StringListValue,
            self.StringValue,
        )
        for candidate in candidates:
            if candidate is not None:
                return candidate

        return None
|
||||
|
||||
|
||||
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
|
||||
class DocumentAttribute(BaseModel, extra="allow"):
    """Document attribute.

    A key paired with a ``DocumentAttributeValue``; undeclared keys are kept
    on the model (``extra="allow"``).
    """

    Key: str
    """The key of the attribute."""
    Value: DocumentAttributeValue
    """The value of the attribute."""
|
||||
|
||||
|
||||
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
|
||||
class ResultItem(BaseModel, ABC, extra="allow"):
    """Base class of a result item.

    Subclasses supply ``get_title`` and ``get_excerpt``; this base class turns
    an item into a ``Document`` via ``to_doc``. Every optional field defaults
    to ``None`` so that a payload omitting it still validates, and undeclared
    keys are kept on the model (``extra="allow"``).
    """

    Id: Optional[str] = None
    """The ID of the relevant result item."""
    DocumentId: Optional[str] = None
    """The document ID."""
    DocumentURI: Optional[str] = None
    """The document URI."""
    DocumentAttributes: Optional[List[DocumentAttribute]] = []
    """The document attributes."""
    ScoreAttributes: Optional[dict] = None
    """The kendra score confidence"""

    @abstractmethod
    def get_title(self) -> str:
        """Document title."""

    @abstractmethod
    def get_excerpt(self) -> str:
        """Document excerpt or passage original content as retrieved by Kendra."""

    def get_additional_metadata(self) -> dict:
        """Document additional metadata dict.
        This returns any extra metadata except these:
            * result_id
            * document_id
            * source
            * title
            * excerpt
            * document_attributes
        """
        return {}

    def get_document_attributes_dict(self) -> Dict[str, DocumentAttributeValueType]:
        """Document attributes dict.

        Returns:
            A mapping of attribute key to its single value; empty when
            ``DocumentAttributes`` is ``None`` or empty.
        """
        return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])}

    def get_score_attribute(self) -> str:
        """Document Score Confidence.

        Returns:
            The ``ScoreConfidence`` entry of ``ScoreAttributes``, or
            ``"NOT_AVAILABLE"`` when the attributes or that entry are missing.
        """
        if self.ScoreAttributes is None:
            return "NOT_AVAILABLE"
        # Use .get so a score dict without "ScoreConfidence" degrades to the
        # same sentinel as a missing dict instead of raising KeyError.
        return self.ScoreAttributes.get("ScoreConfidence", "NOT_AVAILABLE")

    def to_doc(
        self, page_content_formatter: Callable[["ResultItem"], str] = combined_text
    ) -> Document:
        """Converts this item to a Document.

        Args:
            page_content_formatter: Callable that builds the page content
                from this item.

        Returns:
            A Document whose metadata merges ``get_additional_metadata()``
            with the standard result keys; the standard keys win on clashes.
        """
        page_content = page_content_formatter(self)
        metadata = self.get_additional_metadata()
        metadata.update(
            {
                "result_id": self.Id,
                "document_id": self.DocumentId,
                "source": self.DocumentURI,
                "title": self.get_title(),
                "excerpt": self.get_excerpt(),
                "document_attributes": self.get_document_attributes_dict(),
                "score": self.get_score_attribute(),
            }
        )
        return Document(page_content=page_content, metadata=metadata)
|
||||
|
||||
|
||||
class QueryResultItem(ResultItem):
    """Query API result item.

    Optional fields default to ``None`` so that a payload omitting them still
    validates.
    """

    DocumentTitle: TextWithHighLights
    """The document title."""
    FeedbackToken: Optional[str] = None
    """Identifies a particular result from a particular query."""
    Format: Optional[str] = None
    """
    If the Type is ANSWER, then format is either:
        * TABLE: a table excerpt is returned in TableExcerpt;
        * TEXT: a text excerpt is returned in DocumentExcerpt.
    """
    Type: Optional[str] = None
    """Type of result: DOCUMENT or QUESTION_ANSWER or ANSWER"""
    AdditionalAttributes: Optional[List[AdditionalResultAttribute]] = []
    """One or more additional attributes associated with the result."""
    DocumentExcerpt: Optional[TextWithHighLights] = None
    """Excerpt of the document text."""

    def get_title(self) -> str:
        """Return the text of the document title."""
        return self.DocumentTitle.Text

    def get_attribute_value(self) -> str:
        """Return the text of the first additional attribute, or ``""``.

        An empty string is returned when there are no additional attributes
        or the first entry is falsy.
        """
        if not self.AdditionalAttributes or not self.AdditionalAttributes[0]:
            return ""
        return self.AdditionalAttributes[0].get_value_text()

    def get_excerpt(self) -> str:
        """Return the excerpt text for this item.

        The text of the first additional attribute is used when its key is
        ``"AnswerText"``; otherwise the document excerpt text when present;
        otherwise an empty string.
        """
        if (
            self.AdditionalAttributes
            and self.AdditionalAttributes[0].Key == "AnswerText"
        ):
            return self.get_attribute_value()
        if self.DocumentExcerpt:
            return self.DocumentExcerpt.Text
        return ""

    def get_additional_metadata(self) -> dict:
        """Return the result ``Type`` under the ``"type"`` metadata key."""
        return {"type": self.Type}
|
||||
|
||||
|
||||
class RetrieveResultItem(ResultItem):
    """Retrieve API result item.

    Both fields default to ``None`` so that a payload omitting them still
    validates; the accessors below map ``None`` to an empty string.
    """

    DocumentTitle: Optional[str] = None
    """The document title."""
    Content: Optional[str] = None
    """The content of the item."""

    def get_title(self) -> str:
        """Return the document title, or ``""`` when it is missing."""
        return self.DocumentTitle or ""

    def get_excerpt(self) -> str:
        """Return the item content, or ``""`` when it is missing."""
        return self.Content or ""
|
||||
|
||||
|
||||
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
|
||||
class QueryResult(BaseModel, extra="allow"):
    """`Amazon Kendra Query API` search result.

    It is composed of:
        * Relevant suggested answers: either a text excerpt or table excerpt.
        * Matching FAQs or questions-answer from your FAQ file.
        * Documents including an excerpt of each document with its title.

    Response keys other than ``ResultItems`` are kept on the model
    (``extra="allow"``).
    """

    ResultItems: List[QueryResultItem]
    """The result items."""
|
||||
|
||||
|
||||
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
|
||||
class RetrieveResult(BaseModel, extra="allow"):
    """`Amazon Kendra Retrieve API` search result.

    It is composed of:
        * relevant passages or text excerpts given an input query.

    Response keys other than those declared below are kept on the model
    (``extra="allow"``).
    """

    QueryId: str
    """The ID of the query."""
    ResultItems: List[RetrieveResultItem]
    """The result items."""
|
||||
|
||||
|
||||
# Numeric value assigned to each Kendra score-confidence label. It is compared
# against AmazonKendraRetriever.min_score_confidence when filtering documents;
# labels missing from this table are treated as 0.0 by that filter.
KENDRA_CONFIDENCE_MAPPING = {
    "NOT_AVAILABLE": 0.0,
    "LOW": 0.25,
    "MEDIUM": 0.50,
    "HIGH": 0.75,
    "VERY_HIGH": 1.0,
}
|
||||
|
||||
|
||||
@deprecated(
    since="0.3.16",
    removal="1.0",
    alternative_import="langchain_aws.AmazonKendraRetriever",
)
class AmazonKendraRetriever(BaseRetriever):
    """`Amazon Kendra Index` retriever.

    Args:
        index_id: Kendra index id

        region_name: The aws region e.g., `us-west-2`.
            Fallsback to AWS_DEFAULT_REGION env variable
            or region specified in ~/.aws/config.

        credentials_profile_name: The name of the profile in the ~/.aws/credentials
            or ~/.aws/config files, which has either access keys or role information
            specified. If not specified, the default credential profile or, if on an
            EC2 instance, credentials from IMDS will be used.

        top_k: No of results to return

        attribute_filter: Additional filtering of results based on metadata
            See: https://docs.aws.amazon.com/kendra/latest/APIReference

        document_relevance_override_configurations: Overrides relevance tuning
            configurations of fields/attributes set at the index level
            See: https://docs.aws.amazon.com/kendra/latest/APIReference

        page_content_formatter: generates the Document page_content
            allowing access to all result item attributes. By default, it uses
            the item's title and excerpt.

        client: boto3 client for Kendra

        user_context: Provides information about the user context
            See: https://docs.aws.amazon.com/kendra/latest/APIReference

        min_score_confidence: Optional threshold between 0.0 and 1.0. When
            set, documents whose confidence label maps (through
            ``KENDRA_CONFIDENCE_MAPPING``) to a value below the threshold are
            dropped. ``None`` (the default) disables the filter.

    Example:
        .. code-block:: python

            retriever = AmazonKendraRetriever(
                index_id="c0806df7-e76b-4bce-9b5c-d5582f6b1a03"
            )

    """

    index_id: str
    region_name: Optional[str] = None
    credentials_profile_name: Optional[str] = None
    top_k: int = 3
    attribute_filter: Optional[Dict] = None
    document_relevance_override_configurations: Optional[List[Dict]] = None
    page_content_formatter: Callable[[ResultItem], str] = combined_text
    client: Any
    user_context: Optional[Dict] = None
    # Explicit default: without it the field is required, which would reject
    # the construction shown in the class docstring example.
    min_score_confidence: Annotated[Optional[float], Field(ge=0.0, le=1.0)] = None

    @validator("top_k")
    def validate_top_k(cls, value: int) -> int:
        """Reject negative ``top_k`` values."""
        if value < 0:
            raise ValueError(f"top_k ({value}) cannot be negative.")
        return value

    @model_validator(mode="before")
    @classmethod
    def create_client(cls, values: Dict[str, Any]) -> Any:
        """Validate ``top_k`` and create the boto3 Kendra client if needed.

        A client supplied by the caller is used as-is. Otherwise a boto3
        session is created (with ``credentials_profile_name`` when given) and
        a ``kendra`` client is built from it, passing ``region_name`` when
        given.

        Raises:
            ImportError: If boto3 is not installed.
            ValueError: If ``top_k`` is negative or the client cannot be
                created.
        """
        top_k = values.get("top_k")
        if top_k is not None and top_k < 0:
            raise ValueError(f"top_k ({top_k}) cannot be negative.")

        if values.get("client") is not None:
            return values

        try:
            import boto3

            if values.get("credentials_profile_name"):
                session = boto3.Session(profile_name=values["credentials_profile_name"])
            else:
                # use default credentials
                session = boto3.Session()

            client_params = {}
            if values.get("region_name"):
                client_params["region_name"] = values["region_name"]

            values["client"] = session.client("kendra", **client_params)

            return values
        except ImportError:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

    def _kendra_query(self, query: str) -> Sequence[ResultItem]:
        """Run the query against Kendra and return the parsed result items.

        The Retrieve API is tried first; when it yields no items the same
        request is sent to the Query API.
        """
        kendra_kwargs = {
            "IndexId": self.index_id,
            # truncate the query to ensure that
            # there is no validation exception from Kendra.
            "QueryText": query.strip()[0:999],
            "PageSize": self.top_k,
        }
        if self.attribute_filter is not None:
            kendra_kwargs["AttributeFilter"] = self.attribute_filter
        if self.document_relevance_override_configurations is not None:
            kendra_kwargs["DocumentRelevanceOverrideConfigurations"] = (
                self.document_relevance_override_configurations
            )
        if self.user_context is not None:
            kendra_kwargs["UserContext"] = self.user_context

        response = self.client.retrieve(**kendra_kwargs)
        # model_validate is the pydantic v2 replacement for the deprecated
        # parse_obj.
        r_result = RetrieveResult.model_validate(response)
        if r_result.ResultItems:
            return r_result.ResultItems

        # Retrieve API returned 0 results, fall back to Query API
        response = self.client.query(**kendra_kwargs)
        q_result = QueryResult.model_validate(response)
        return q_result.ResultItems

    def _get_top_k_docs(self, result_items: Sequence[ResultItem]) -> List[Document]:
        """Convert at most ``top_k`` leading result items to Documents."""
        return [
            item.to_doc(self.page_content_formatter)
            for item in result_items[: self.top_k]
        ]

    def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]:
        """
        Keep only the records whose score confidence is greater than or
        equal to the required threshold.

        When ``min_score_confidence`` is unset (or 0) every document is
        returned. Otherwise a document is kept only if its ``score`` metadata
        is a string label whose mapped value reaches the threshold; unknown
        labels count as 0.0.
        """
        if not self.min_score_confidence:
            return docs
        return [
            item
            for item in docs
            if (
                item.metadata.get("score") is not None
                and isinstance(item.metadata["score"], str)
                and KENDRA_CONFIDENCE_MAPPING.get(item.metadata["score"], 0.0)
                >= self.min_score_confidence
            )
        ]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Run search on Kendra index and get top k documents

        Example:
        .. code-block:: python

            docs = retriever.invoke('This is my query')

        """
        result_items = self._kendra_query(query)
        top_k_docs = self._get_top_k_docs(result_items)
        return self._filter_by_score_confidence(top_k_docs)
|
||||
107
venv/Lib/site-packages/langchain_community/retrievers/knn.py
Normal file
107
venv/Lib/site-packages/langchain_community/retrievers/knn.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""KNN Retriever.
|
||||
Largely based on
|
||||
https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
    """
    Embed every context and stack the vectors into one array.

    Each context is passed to ``embeddings.embed_query`` through a thread
    pool; the results are collected in input order.

    Args:
        contexts: List of contexts to embed.
        embeddings: Embeddings model to use.

    Returns:
        Index of embeddings.
    """
    with concurrent.futures.ThreadPoolExecutor() as pool:
        vectors = list(pool.map(embeddings.embed_query, contexts))
        return np.array(vectors)
|
||||
|
||||
|
||||
class KNNRetriever(BaseRetriever):
    """`KNN` retriever.

    Ranks the stored texts by cosine similarity between their embeddings and
    the embedding of the query.
    """

    embeddings: Embeddings
    """Embeddings model to use."""
    index: Any = None
    """Index of embeddings."""
    texts: List[str]
    """List of texts to index."""
    metadatas: Optional[List[dict]] = None
    """List of metadatas corresponding with each text."""
    k: int = 4
    """Number of results to return."""
    relevancy_threshold: Optional[float] = None
    """Threshold for relevancy."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embeddings: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> KNNRetriever:
        """Embed ``texts`` and build a retriever over them.

        Args:
            texts: Texts to index.
            embeddings: Embeddings model used for the texts and for queries.
            metadatas: Optional metadata dicts, one per text.
            **kwargs: Extra field values forwarded to the constructor.
        """
        return cls(
            embeddings=embeddings,
            index=create_index(texts, embeddings),
            texts=texts,
            metadatas=metadatas,
            **kwargs,
        )

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        embeddings: Embeddings,
        **kwargs: Any,
    ) -> KNNRetriever:
        """Build a retriever from Documents, using their content and metadata."""
        pairs = ((doc.page_content, doc.metadata) for doc in documents)
        texts, metadatas = zip(*pairs)
        return cls.from_texts(
            texts=texts, embeddings=embeddings, metadatas=metadatas, **kwargs
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return up to ``k`` documents, most similar first.

        Similarities are min-max scaled before being compared with
        ``relevancy_threshold``; with no threshold every top-k row is kept.
        """
        query_vector = np.array(self.embeddings.embed_query(query))

        # Scale both sides to unit L2 norm so the dot product is the cosine.
        row_norms = np.sqrt((self.index**2).sum(1, keepdims=True))
        unit_index = self.index / row_norms
        unit_query = query_vector / np.sqrt((query_vector**2).sum())

        similarities = unit_index.dot(unit_query)
        ranking = np.argsort(-similarities)

        # The small constant keeps the division defined when all scores match.
        spread = np.max(similarities) - np.min(similarities) + 1e-6
        scaled = (similarities - np.min(similarities)) / spread

        selected = []
        for row in ranking[0 : self.k]:
            if (
                self.relevancy_threshold is None
                or scaled[row] >= self.relevancy_threshold
            ):
                selected.append(
                    Document(
                        page_content=self.texts[row],
                        metadata=self.metadatas[row] if self.metadatas else {},
                    )
                )
        return selected
|
||||
@@ -0,0 +1,86 @@
|
||||
from typing import Any, Dict, List, cast
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class LlamaIndexRetriever(BaseRetriever):
    """`LlamaIndex` retriever.

    It is used for the question-answering with sources over
    an LlamaIndex data structure."""

    index: Any = None
    """LlamaIndex index to query."""
    query_kwargs: Dict = Field(default_factory=dict)
    """Keyword arguments to pass to the query method."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query the index and return its source nodes as Documents.

        Raises:
            ImportError: If ``llama-index`` is not installed.
        """
        try:
            from llama_index.core.base.response.schema import Response
            from llama_index.core.indices.base import BaseGPTIndex
        except ImportError:
            raise ImportError(
                "You need to install `pip install llama-index` to use this retriever."
            )

        llama_index = cast(BaseGPTIndex, self.index)
        answer = cast(Response, llama_index.query(query, **self.query_kwargs))

        # One Document per source node; a missing metadata value becomes {}.
        return [
            Document(metadata=node.metadata or {}, page_content=node.get_content())
            for node in answer.source_nodes
        ]
|
||||
|
||||
|
||||
class LlamaIndexGraphRetriever(BaseRetriever):
    """`LlamaIndex` graph data structure retriever.

    It is used for question-answering with sources over an LlamaIndex
    graph data structure."""

    graph: Any = None
    """LlamaIndex graph to query."""
    query_configs: List[Dict] = Field(default_factory=list)
    """List of query configs to pass to the query method."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant for a query.

        Every query config is sent with ``response_mode="no_text"``. The
        override is applied to copies, so the dicts stored in
        ``self.query_configs`` are not modified.

        Raises:
            ImportError: If ``llama-index`` is not installed.
        """
        try:
            from llama_index.core.base.response.schema import Response
            from llama_index.core.composability.base import (
                QUERY_CONFIG_TYPE,
                ComposableGraph,
            )
        except ImportError:
            raise ImportError(
                "You need to install `pip install llama-index` to use this retriever."
            )
        graph = cast(ComposableGraph, self.graph)

        # for now, inject response_mode="no_text" into query configs
        # (on shallow copies, to avoid mutating the caller-supplied dicts)
        no_text_configs = [
            {**query_config, "response_mode": "no_text"}
            for query_config in self.query_configs
        ]
        query_configs = cast(List[QUERY_CONFIG_TYPE], no_text_configs)
        response = graph.query(query, query_configs=query_configs)
        response = cast(Response, response)

        # parse source nodes
        docs = []
        for source_node in response.source_nodes:
            metadata = source_node.metadata or {}
            docs.append(
                Document(page_content=source_node.get_content(), metadata=metadata)
            )
        return docs
|
||||
@@ -0,0 +1,43 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import model_validator
|
||||
|
||||
|
||||
class MetalRetriever(BaseRetriever):
    """`Metal API` retriever."""

    client: Any
    """The Metal client to use."""
    params: Optional[dict] = None
    """The parameters to pass to the Metal client."""

    @model_validator(mode="before")
    @classmethod
    def validate_client(cls, values: dict) -> Any:
        """Validate that the client is of the correct type.

        Also normalises ``params`` to an empty dict when it is missing or
        ``None``.

        Raises:
            ImportError: If the ``metal_sdk`` package is not installed.
            ValueError: If ``client`` is not a ``metal_sdk.metal.Metal``.
        """
        try:
            from metal_sdk.metal import Metal
        except ImportError as e:
            raise ImportError(
                "Could not import metal_sdk python package. "
                "Please install it with `pip install metal_sdk`."
            ) from e

        if "client" in values:
            client = values["client"]
            if not isinstance(client, Metal):
                raise ValueError(
                    "Got unexpected client, should be of type metal_sdk.metal.Metal. "
                    f"Instead, got {type(client)}"
                )

        # `or {}` also covers an explicit params=None, which would otherwise
        # survive validation and break the ** expansion at query time.
        values["params"] = values.get("params") or {}

        return values

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search Metal for ``query`` and convert each hit to a Document.

        The ``"text"`` entry of a hit becomes the page content; every other
        entry becomes metadata.
        """
        # Guard against params being None (e.g. assigned after construction).
        results = self.client.search({"text": query}, **(self.params or {}))
        return [
            Document(
                page_content=r["text"],
                metadata={k: v for k, v in r.items() if k != "text"},
            )
            for r in results["data"]
        ]
|
||||
150
venv/Lib/site-packages/langchain_community/retrievers/milvus.py
Normal file
150
venv/Lib/site-packages/langchain_community/retrievers/milvus.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Milvus Retriever"""
|
||||
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import model_validator
|
||||
|
||||
from langchain_community.vectorstores.milvus import Milvus
|
||||
|
||||
# TODO: Update to MilvusClient + Hybrid Search when available
|
||||
|
||||
|
||||
class MilvusRetriever(BaseRetriever):
    """Milvus API retriever.

    Wraps a ``Milvus`` vector store created from the given fields and
    delegates queries to the retriever returned by the store's
    ``as_retriever``.

    See detailed instructions here: https://python.langchain.com/docs/integrations/retrievers/milvus_hybrid_search/

    Setup:
        Install ``langchain-milvus`` and other dependencies:

        .. code-block:: bash

            pip install -U pymilvus[model] langchain-milvus

    Key init args:
        embedding_function: Embeddings model handed to the Milvus store.
        collection_name: Name of the collection (default "LangChainCollection").
        collection_properties: Optional collection properties.
        connection_args: Optional connection arguments for the store.
        consistency_level: Consistency level (default "Session").
        search_params: Optional search parameters, passed to the store's
            retriever as ``search_kwargs={"param": search_params}``.

    Instantiate:
        .. code-block:: python

            retriever = MilvusRetriever(embedding_function=embeddings)

    Usage:
        .. code-block:: python

            query = "What are the story about ventures?"

            retriever.invoke(query)

        .. code-block:: none

            [Document(page_content="In 'The Lost Expedition' by Caspian Grey...", metadata={'doc_id': '449281835035545843'}),
            Document(page_content="In 'The Phantom Pilgrim' by Rowan Welles...", metadata={'doc_id': '449281835035545845'}),
            Document(page_content="In 'The Dreamwalker's Journey' by Lyra Snow..", metadata={'doc_id': '449281835035545846'})]

    Use within a chain:
        .. code-block:: python

            from langchain_core.output_parsers import StrOutputParser
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_core.runnables import RunnablePassthrough
            from langchain_openai import ChatOpenAI

            prompt = ChatPromptTemplate.from_template(
                \"\"\"Answer the question based only on the context provided.

            Context: {context}

            Question: {question}\"\"\"
            )

            llm = ChatOpenAI(model="gpt-3.5-turbo-0125")

            def format_docs(docs):
                return "\\n\\n".join(doc.page_content for doc in docs)

            chain = (
                {"context": retriever | format_docs, "question": RunnablePassthrough()}
                | prompt
                | llm
                | StrOutputParser()
            )

            chain.invoke("What novels has Lila written and what are their contents?")

        .. code-block:: none

             "Lila Rose has written 'The Memory Thief,' which follows a charismatic thief..."

    """  # noqa: E501

    embedding_function: Embeddings
    collection_name: str = "LangChainCollection"
    collection_properties: Optional[Dict[str, Any]] = None
    connection_args: Optional[Dict[str, Any]] = None
    consistency_level: str = "Session"
    search_params: Optional[dict] = None

    store: Milvus
    retriever: BaseRetriever

    @model_validator(mode="before")
    @classmethod
    def create_retriever(cls, values: Dict) -> Any:
        """Create the Milvus store and retriever.

        This runs before field validation, so field defaults have not been
        applied to ``values`` yet; omitted fields are therefore read with
        ``.get`` and the same defaults as declared on the class.
        """
        # Keyword arguments keep each value bound to the intended parameter
        # regardless of the positional order of the Milvus constructor.
        values["store"] = Milvus(
            embedding_function=values["embedding_function"],
            collection_name=values.get("collection_name", "LangChainCollection"),
            collection_properties=values.get("collection_properties"),
            connection_args=values.get("connection_args"),
            consistency_level=values.get("consistency_level", "Session"),
        )
        values["retriever"] = values["store"].as_retriever(
            search_kwargs={"param": values.get("search_params")}
        )
        return values

    def add_texts(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> None:
        """Add text to the Milvus store

        Args:
            texts (List[str]): The text
            metadatas (List[dict]): Metadata dicts, must line up with existing store
        """
        self.store.add_texts(texts, metadatas)

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Delegate the query to the wrapped store retriever.

        The child callback manager is passed through ``config`` so the nested
        run is linked to this one; ``invoke`` creates its own run manager, so
        ``run_manager`` is not forwarded as a plain keyword argument.
        """
        return self.retriever.invoke(
            query, config={"callbacks": run_manager.get_child()}, **kwargs
        )
|
||||
|
||||
|
||||
def MilvusRetreiver(*args: Any, **kwargs: Any) -> MilvusRetriever:
    """Deprecated MilvusRetreiver. Please use MilvusRetriever ('i' before 'e') instead.

    Emits a ``DeprecationWarning`` and forwards all arguments to
    ``MilvusRetriever``.

    Args:
        *args: Positional arguments forwarded to ``MilvusRetriever``.
        **kwargs: Keyword arguments forwarded to ``MilvusRetriever``.

    Returns:
        MilvusRetriever
    """
    warnings.warn(
        "MilvusRetreiver will be deprecated in the future. "
        "Please use MilvusRetriever ('i' before 'e') instead.",
        DeprecationWarning,
        # Point the warning at the caller using the misspelled name.
        stacklevel=2,
    )
    return MilvusRetriever(*args, **kwargs)
|
||||
125
venv/Lib/site-packages/langchain_community/retrievers/nanopq.py
Normal file
125
venv/Lib/site-packages/langchain_community/retrievers/nanopq.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
    """
    Embed every context and stack the vectors into one array.

    Each context is passed to ``embeddings.embed_query`` through a thread
    pool; the results are collected in input order.

    Args:
        contexts: List of contexts to embed.
        embeddings: Embeddings model to use.

    Returns:
        Index of embeddings.
    """
    with concurrent.futures.ThreadPoolExecutor() as pool:
        vectors = list(pool.map(embeddings.embed_query, contexts))
        return np.array(vectors)
|
||||
|
||||
|
||||
class NanoPQRetriever(BaseRetriever):
    """`NanoPQ` retriever.

    Ranks the indexed texts against a query by product-quantization distance
    (``nanopq.PQ``) computed over their embeddings.
    """

    embeddings: Embeddings
    """Embeddings model to use."""
    index: Any = None
    """Index of embeddings."""
    texts: List[str]
    """List of texts to index."""
    metadatas: Optional[List[dict]] = None
    """List of metadatas corresponding with each text."""
    k: int = 4
    """Number of results to return."""
    relevancy_threshold: Optional[float] = None
    """Threshold for relevancy. Not applied by this retriever's search."""
    subspace: int = 4
    """No of subspaces to be created; must evenly divide the embedding size."""
    clusters: int = 128
    """No of clusters to be created"""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embeddings: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> NanoPQRetriever:
        """Build a retriever by embedding ``texts`` up front.

        Args:
            texts: Texts to index.
            embeddings: Embeddings model used for the texts and for queries.
            metadatas: Optional metadata, one entry per text.
            **kwargs: Extra fields forwarded to the constructor.

        Returns:
            A retriever whose ``index`` holds the embedded texts.
        """
        index = create_index(texts, embeddings)
        return cls(
            embeddings=embeddings,
            index=index,
            texts=texts,
            metadatas=metadatas,
            **kwargs,
        )

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        embeddings: Embeddings,
        **kwargs: Any,
    ) -> NanoPQRetriever:
        """Build a retriever from documents' page content and metadata.

        Args:
            documents: Documents to index; may be empty.
            embeddings: Embeddings model used for the texts and for queries.
            **kwargs: Extra fields forwarded to the constructor.

        Returns:
            A retriever indexing the documents' page content.
        """
        # Collect into real lists: unpacking ``zip(*...)`` fails on an empty
        # iterable and yields tuples for fields declared as lists.
        texts: List[str] = []
        metadatas: List[dict] = []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return cls.from_texts(
            texts=texts, embeddings=embeddings, metadatas=metadatas, **kwargs
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the ``k`` texts closest to ``query`` by PQ distance.

        A quantizer is fitted on the stored index on every call.

        Args:
            query: Text to search for.
            run_manager: Callback manager for the run (unused here).

        Returns:
            Up to ``k`` documents ordered by ascending distance.

        Raises:
            ImportError: If ``nanopq`` is not installed.
            RuntimeError: If the quantizer rejects the parameter combination.
        """
        try:
            from nanopq import PQ
        except ImportError as err:
            raise ImportError(
                "Could not import nanopq, please install with `pip install nanopq`."
            ) from err

        # Nothing indexed means nothing to rank (and nothing to fit on).
        if not self.texts:
            return []

        query_embeds = np.array(self.embeddings.embed_query(query))
        # Cast once; both fitting and encoding need float32 input.
        corpus = self.index.astype("float32")
        try:
            pq = PQ(M=self.subspace, Ks=self.clusters, verbose=True).fit(corpus)
        except AssertionError as err:
            error_message = (
                "Received params: training_sample={training_sample}, "
                "n_cluster={n_clusters}, subspace={subspace}, "
                "embedding_shape={embedding_shape}. Issue with the combination. "
                "Please retrace back to find the exact error"
            ).format(
                training_sample=self.index.shape[0],
                n_clusters=self.clusters,
                subspace=self.subspace,
                embedding_shape=self.index.shape[1],
            )
            raise RuntimeError(error_message) from err

        index_code = pq.encode(vecs=corpus)
        dt = pq.dtable(query=query_embeds.astype("float32"))
        dists = dt.adist(codes=index_code)

        # Smallest distance first.
        sorted_ix = np.argsort(dists)

        top_k_results = [
            Document(
                page_content=self.texts[row],
                metadata=self.metadatas[row] if self.metadatas else {},
            )
            for row in sorted_ix[0 : self.k]
        ]

        return top_k_results
|
||||
101
venv/Lib/site-packages/langchain_community/retrievers/needle.py
Normal file
101
venv/Lib/site-packages/langchain_community/retrievers/needle.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from typing import Any, List, Optional # noqa: I001
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class NeedleRetriever(BaseRetriever, BaseModel):
    """
    NeedleRetriever retrieves relevant documents or context from a Needle collection
    based on a search query.

    Setup:
        Install the `needle-python` library and set your Needle API key.

        .. code-block:: bash

            pip install needle-python
            export NEEDLE_API_KEY="your-api-key"

    Key init args:
        - `needle_api_key` (Optional[str]): The API key for authenticating with Needle.
        - `collection_id` (str): The ID of the Needle collection to search in.
        - `client` (Optional[NeedleClient]): An optional instance of the NeedleClient.
        - `top_k` (Optional[int]): Maximum number of results to return.

    Usage:
        .. code-block:: python

            from langchain_community.retrievers.needle import NeedleRetriever

            retriever = NeedleRetriever(
                needle_api_key="your-api-key",
                collection_id="your-collection-id",
                top_k=10  # optional
            )

            results = retriever.invoke("example query")
            for doc in results:
                print(doc.page_content)
    """

    client: Optional[Any] = None
    """Optional instance of NeedleClient."""
    needle_api_key: Optional[str] = Field(None, description="Needle API Key")
    collection_id: Optional[str] = Field(
        ..., description="The ID of the Needle collection to search in"
    )
    top_k: Optional[int] = Field(
        default=None, description="Maximum number of search results to return"
    )

    def _initialize_client(self) -> None:
        """
        Initialize the NeedleClient with the provided API key.

        If a client instance is already provided, this method does nothing;
        in particular it does not require `needle-python` to be importable.

        Raises:
            ImportError: If a client must be created and `needle-python` is
                not installed.
        """
        if self.client:
            return

        try:
            from needle.v1 import NeedleClient
        except ImportError as err:
            raise ImportError(
                "Please install with `pip install needle-python`."
            ) from err

        self.client = NeedleClient(api_key=self.needle_api_key)

    def _search_collection(self, query: str) -> List[Document]:
        """
        Search the Needle collection for relevant documents.

        Args:
            query (str): The search query used to find relevant documents.

        Returns:
            List[Document]: A list of documents matching the search query.

        Raises:
            ValueError: If no client is available after initialization.
        """
        self._initialize_client()
        if self.client is None:
            raise ValueError("NeedleClient is not initialized. Provide an API key.")

        results = self.client.collections.search(
            collection_id=self.collection_id, text=query, top_k=self.top_k
        )
        # Only the text content of each hit is kept; no metadata is attached.
        docs = [Document(page_content=result.content) for result in results]
        return docs

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Retrieve relevant documents based on the query.

        Args:
            query (str): The query string used to search the collection.
            run_manager: Callback manager for the run (unused here).

        Returns:
            List[Document]: A list of documents relevant to the query.
        """
        # The `run_manager` parameter is included to match the superclass signature,
        # but it is not used in this implementation.
        return self._search_collection(query)
|
||||
@@ -0,0 +1,20 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from langchain_community.utilities.outline import OutlineAPIWrapper
|
||||
|
||||
|
||||
class OutlineRetriever(BaseRetriever, OutlineAPIWrapper):
    """Retriever for Outline API.

    Retrieval is a thin wrapper over ``run()``; all OutlineAPIWrapper
    arguments are used without any change.
    """

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the result of ``run()`` for ``query``.

        ``run_manager`` is accepted to satisfy the signature and is not used.
        """
        documents = self.run(query=query)
        return documents
|
||||
@@ -0,0 +1,185 @@
|
||||
"""Taken from: https://docs.pinecone.io/docs/hybrid-search"""
|
||||
|
||||
import hashlib
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.utils import pre_init
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
def hash_text(text: str) -> str:
    """Return the SHA256 hex digest of a text.

    The text is encoded as UTF-8 before hashing.

    Args:
        text: Text to hash.

    Returns:
        Hashed text.
    """
    hasher = hashlib.sha256()
    hasher.update(text.encode("utf-8"))
    return str(hasher.hexdigest())
|
||||
|
||||
|
||||
def create_index(
    contexts: List[str],
    index: Any,
    embeddings: Embeddings,
    sparse_encoder: Any,
    ids: Optional[List[str]] = None,
    metadatas: Optional[List[dict]] = None,
    namespace: Optional[str] = None,
    text_key: str = "context",
) -> None:
    """Embed contexts in batches of 32 and upsert them into ``index``.

    It modifies the index argument in-place!

    Args:
        contexts: List of contexts to embed.
        index: Index to use.
        embeddings: Embeddings model to use.
        sparse_encoder: Sparse encoder to use.
        ids: List of ids to use for the documents. When omitted, the hash of
            each context is used.
        metadatas: List of metadata to use for the documents.
        namespace: Namespace value for index partition.
        text_key: Metadata key under which each context is stored.
    """
    batch_size = 32
    batch_starts = range(0, len(contexts), batch_size)
    try:
        from tqdm.auto import tqdm

        # Progress bar is optional: used only when tqdm is installed.
        batch_starts = tqdm(batch_starts)
    except ImportError:
        pass

    if ids is None:
        # Derive ids from the text itself.
        ids = [hash_text(context) for context in contexts]

    for start in batch_starts:
        stop = min(start + batch_size, len(contexts))
        context_batch = contexts[start:stop]
        batch_ids = ids[start:stop]
        if metadatas:
            metadata_batch = metadatas[start:stop]
        else:
            metadata_batch = [{} for _ in context_batch]

        # The context text travels as metadata under ``text_key``.
        meta = [
            {text_key: context, **metadata}
            for context, metadata in zip(context_batch, metadata_batch)
        ]

        dense_embeds = embeddings.embed_documents(context_batch)
        sparse_embeds = sparse_encoder.encode_documents(context_batch)
        for sparse_entry in sparse_embeds:
            sparse_entry["values"] = [float(v) for v in sparse_entry["values"]]

        # One upsert record per document in the batch.
        vectors = [
            {
                "id": doc_id,
                "sparse_values": sparse,
                "values": dense,
                "metadata": metadata,
            }
            for doc_id, sparse, dense, metadata in zip(
                batch_ids, sparse_embeds, dense_embeds, meta
            )
        ]

        index.upsert(vectors, namespace=namespace)
|
||||
|
||||
|
||||
class PineconeHybridSearchRetriever(BaseRetriever):
    """`Pinecone Hybrid Search` retriever."""

    embeddings: Embeddings
    """Embeddings model to use."""
    sparse_encoder: Any = None
    """Sparse encoder to use."""
    index: Any = None
    """Pinecone index to use."""
    top_k: int = 4
    """Number of documents to return."""
    alpha: float = 0.5
    """Alpha value for hybrid search."""
    namespace: Optional[str] = None
    """Namespace value for index partition."""
    text_key: str = "context"
    """Metadata key under which the document text is stored and read back."""
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    def add_texts(
        self,
        texts: List[str],
        ids: Optional[List[str]] = None,
        metadatas: Optional[List[dict]] = None,
        namespace: Optional[str] = None,
    ) -> None:
        """Embed ``texts`` and upsert them into the index.

        Args:
            texts: Texts to add.
            ids: Optional ids, one per text.
            metadatas: Optional metadata, one entry per text.
            namespace: Namespace value for index partition.
        """
        create_index(
            texts,
            self.index,
            self.embeddings,
            self.sparse_encoder,
            ids=ids,
            metadatas=metadatas,
            namespace=namespace,
            text_key=self.text_key,
        )

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the pinecone_text python package exists in environment."""
        try:
            from pinecone_text.hybrid import hybrid_convex_scale  # noqa:F401
            from pinecone_text.sparse.base_sparse_encoder import (
                BaseSparseEncoder,  # noqa:F401
            )
        except ImportError as err:
            raise ImportError(
                "Could not import pinecone_text python package. "
                "Please install it with `pip install pinecone_text`."
            ) from err
        return values

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        """Query the index with dense and sparse encodings of ``query``.

        Args:
            query: Text to search for.
            run_manager: Callback manager for the run (unused here).
            **kwargs: Extra keyword arguments forwarded to ``index.query``.

        Returns:
            One document per match; the match score is added to the metadata
            under ``"score"`` unless that key is already present.
        """
        from pinecone_text.hybrid import hybrid_convex_scale

        sparse_vec = self.sparse_encoder.encode_queries(query)
        # convert the question into a dense vector
        dense_vec = self.embeddings.embed_query(query)
        # scale alpha with hybrid_scale
        dense_vec, sparse_vec = hybrid_convex_scale(dense_vec, sparse_vec, self.alpha)
        sparse_vec["values"] = [float(s1) for s1 in sparse_vec["values"]]
        # query pinecone with the query parameters
        result = self.index.query(
            vector=dense_vec,
            sparse_vector=sparse_vec,
            top_k=self.top_k,
            include_metadata=True,
            namespace=self.namespace,
            **kwargs,
        )
        final_result = []
        for res in result["matches"]:
            # The text lives in the metadata; move it out into page_content.
            context = res["metadata"].pop(self.text_key)
            metadata = res["metadata"]
            if "score" not in metadata and "score" in res:
                metadata["score"] = res["score"]
            final_result.append(Document(page_content=context, metadata=metadata))
        # return search results as Document objects
        return final_result
|
||||
@@ -0,0 +1,20 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from langchain_community.utilities.pubmed import PubMedAPIWrapper
|
||||
|
||||
|
||||
class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
    """`PubMed API` retriever.

    Retrieval is a thin wrapper over ``load_docs()``; all PubMedAPIWrapper
    arguments are used without any change.
    """

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the result of ``load_docs()`` for ``query``.

        ``run_manager`` is accepted to satisfy the signature and is not used.
        """
        documents = self.load_docs(query=query)
        return documents
|
||||
@@ -0,0 +1,5 @@
|
||||
from langchain_community.retrievers.pubmed import PubMedRetriever
|
||||
|
||||
# Public API of this module: only the imported PubMedRetriever.
__all__ = ["PubMedRetriever"]
|
||||
@@ -0,0 +1,220 @@
|
||||
import uuid
|
||||
from itertools import islice
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.utils import pre_init
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_community.vectorstores.qdrant import Qdrant, QdrantException
|
||||
|
||||
|
||||
@deprecated(
    since="0.2.16",
    alternative=(
        "Qdrant vector store now supports sparse retrievals natively. "
        "Use langchain_qdrant.QdrantVectorStore#as_retriever() instead. "
        "Reference: "
        "https://python.langchain.com/docs/integrations/vectorstores/qdrant/#sparse-vector-search"
    ),
    removal="0.5.0",
)
class QdrantSparseVectorRetriever(BaseRetriever):
    """Qdrant sparse vector retriever."""

    client: Any = None
    """'qdrant_client' instance to use."""
    collection_name: str
    """Qdrant collection name."""
    sparse_vector_name: str
    """Name of the sparse vector to use."""
    sparse_encoder: Callable[[str], Tuple[List[int], List[float]]]
    """Sparse encoder function to use."""
    k: int = 4
    """Number of documents to return per query. Defaults to 4."""
    filter: Optional[Any] = None
    """Qdrant qdrant_client.models.Filter to use for queries. Defaults to None."""
    content_payload_key: str = "content"
    """Payload field containing the document content. Defaults to 'content'"""
    metadata_payload_key: str = "metadata"
    """Payload field containing the document metadata. Defaults to 'metadata'."""
    search_options: Dict[str, Any] = {}
    """Additional search options to pass to qdrant_client.QdrantClient.search()."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that 'qdrant_client' python package exists in environment.

        Also checks the types of ``client`` and ``filter`` and that the
        collection exists and defines the requested sparse vector.

        Raises:
            ImportError: If the required packages cannot be imported.
            ValueError: If ``client`` or ``filter`` has the wrong type.
            QdrantException: If the collection cannot be fetched or lacks the
                named sparse vector.
        """
        try:
            from grpc import RpcError
            from qdrant_client import QdrantClient, models
            from qdrant_client.http.exceptions import UnexpectedResponse
        except ImportError as err:
            raise ImportError(
                "Could not import qdrant-client python package. "
                "Please install it with `pip install qdrant-client`."
            ) from err

        client = values["client"]
        if not isinstance(client, QdrantClient):
            raise ValueError(
                f"client should be an instance of qdrant_client.QdrantClient, "
                f"got {type(client)}"
            )

        query_filter = values["filter"]
        if query_filter is not None and not isinstance(query_filter, models.Filter):
            raise ValueError(
                f"filter should be an instance of qdrant_client.models.Filter, "
                f"got {type(query_filter)}"
            )

        client = cast(QdrantClient, client)

        collection_name = values["collection_name"]
        sparse_vector_name = values["sparse_vector_name"]

        try:
            collection_info = client.get_collection(collection_name)
            # A missing sparse-vector config is treated as "no sparse vectors"
            # so the membership test below cannot fail on ``None``.
            sparse_vectors_config = collection_info.config.params.sparse_vectors or {}

            if sparse_vector_name not in sparse_vectors_config:
                raise QdrantException(
                    f"Existing Qdrant collection {collection_name} does not "
                    f"contain sparse vector named {sparse_vector_name}. "
                    f"Did you mean one of {', '.join(sparse_vectors_config.keys())}?"
                )
        except (UnexpectedResponse, RpcError, ValueError) as err:
            raise QdrantException(
                f"Qdrant collection {collection_name} does not exist."
            ) from err
        return values

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Search the collection with the sparse encoding of ``query``.

        Args:
            query: Text to search for.
            run_manager: Callback manager for the run (unused here).

        Returns:
            Up to ``k`` documents built from the scored points.
        """
        from qdrant_client import QdrantClient, models

        client = cast(QdrantClient, self.client)
        query_indices, query_values = self.sparse_encoder(query)
        results = client.search(
            self.collection_name,
            query_filter=self.filter,
            query_vector=models.NamedSparseVector(
                name=self.sparse_vector_name,
                vector=models.SparseVector(
                    indices=query_indices,
                    values=query_values,
                ),
            ),
            limit=self.k,
            with_vectors=False,
            **self.search_options,
        )
        return [
            Qdrant._document_from_scored_point(
                point,
                self.collection_name,
                self.content_payload_key,
                self.metadata_payload_key,
            )
            for point in results
        ]

    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
        """Run more documents through the embeddings and add to the vectorstore.

        Args:
            documents (List[Document]): Documents to add to the vectorstore.
            **kwargs: Extra keyword arguments forwarded to ``add_texts``.

        Returns:
            List[str]: List of IDs of the added texts.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return self.add_texts(texts, metadatas, **kwargs)

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[Sequence[str]] = None,
        batch_size: int = 64,
        **kwargs: Any,
    ) -> List[str]:
        """Sparse-encode ``texts`` and upsert them batch by batch.

        Args:
            texts: Texts to add.
            metadatas: Optional metadata, one entry per text.
            ids: Optional ids, one per text; random hex ids are generated
                when omitted.
            batch_size: Number of texts per upsert call.
            **kwargs: Extra keyword arguments forwarded to ``client.upsert``.

        Returns:
            List of IDs of the added texts.
        """
        from qdrant_client import QdrantClient

        added_ids = []
        client = cast(QdrantClient, self.client)
        for batch_ids, points in self._generate_rest_batches(
            texts, metadatas, ids, batch_size
        ):
            client.upsert(self.collection_name, points=points, **kwargs)
            added_ids.extend(batch_ids)

        return added_ids

    def _generate_rest_batches(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[Sequence[str]] = None,
        batch_size: int = 64,
    ) -> Generator[Tuple[List[str], List[Any]], None, None]:
        """Yield ``(ids, points)`` pairs, one per batch of texts.

        Args:
            texts: Texts to encode; may be a one-shot iterator.
            metadatas: Optional metadata, one entry per text.
            ids: Optional ids, one per text.
            batch_size: Maximum number of texts per batch.
        """
        from qdrant_client import models as rest

        texts_iterator = iter(texts)
        metadatas_iterator = iter(metadatas or [])
        # Ids are generated per batch below when none are supplied. Generating
        # them up front by iterating ``texts`` would drain a one-shot iterator
        # before any batch is read.
        ids_iterator = iter(ids) if ids else None
        while batch_texts := list(islice(texts_iterator, batch_size)):
            # Take the corresponding metadata and id for each text in a batch
            batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
            if ids_iterator is None:
                batch_ids = [uuid.uuid4().hex for _ in batch_texts]
            else:
                batch_ids = list(islice(ids_iterator, batch_size))

            # Generate the sparse embeddings for all the texts in a batch
            batch_embeddings: List[Tuple[List[int], List[float]]] = [
                self.sparse_encoder(text) for text in batch_texts
            ]

            points = [
                rest.PointStruct(
                    id=point_id,
                    vector={
                        self.sparse_vector_name: rest.SparseVector(
                            indices=sparse_vector[0],
                            values=sparse_vector[1],
                        )
                    },
                    payload=payload,
                )
                for point_id, sparse_vector, payload in zip(
                    batch_ids,
                    batch_embeddings,
                    Qdrant._build_payloads(
                        batch_texts,
                        batch_metadatas,
                        self.content_payload_key,
                        self.metadata_payload_key,
                    ),
                )
            ]

            yield batch_ids, points
|
||||
@@ -0,0 +1,20 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from langchain_community.utilities.rememberizer import RememberizerAPIWrapper
|
||||
|
||||
|
||||
class RememberizerRetriever(BaseRetriever, RememberizerAPIWrapper):
    """`Rememberizer` retriever.

    Retrieval is a thin wrapper over ``load()``; all RememberizerAPIWrapper
    arguments are used without any change.
    """

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the result of ``load()`` for ``query``.

        ``run_manager`` is accepted to satisfy the signature and is not used.
        """
        documents = self.load(query=query)
        return documents
|
||||
@@ -0,0 +1,56 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class RemoteLangChainRetriever(BaseRetriever):
    """`LangChain API` retriever."""

    url: str
    """URL of the remote LangChain API."""
    headers: Optional[dict] = None
    """Headers to use for the request."""
    input_key: str = "message"
    """Key to use for the input in the request."""
    response_key: str = "response"
    """Key to use for the response in the request."""
    page_content_key: str = "page_content"
    """Key to use for the page content in the response."""
    metadata_key: str = "metadata"
    """Key to use for the metadata in the response."""

    def _documents_from_payload(self, payload: dict) -> List[Document]:
        """Build documents from a decoded JSON response body."""
        return [
            Document(
                page_content=r[self.page_content_key], metadata=r[self.metadata_key]
            )
            for r in payload[self.response_key]
        ]

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """POST the query to the remote API and parse the returned documents.

        Args:
            query: Text to search for.
            run_manager: Callback manager for the run (unused here).

        Returns:
            Documents listed under ``response_key`` in the JSON response.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        response = requests.post(
            self.url, json={self.input_key: query}, headers=self.headers
        )
        # Surface HTTP failures directly instead of as a parsing error later.
        response.raise_for_status()
        return self._documents_from_payload(response.json())

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Async variant of ``_get_relevant_documents`` using aiohttp.

        Args:
            query: Text to search for.
            run_manager: Callback manager for the run (unused here).

        Returns:
            Documents listed under ``response_key`` in the JSON response.

        Raises:
            aiohttp.ClientResponseError: If the server answers with an error
                status.
        """
        async with aiohttp.ClientSession() as session:
            async with session.request(
                "POST", self.url, headers=self.headers, json={self.input_key: query}
            ) as response:
                # Surface HTTP failures directly, as in the sync path.
                response.raise_for_status()
                result = await response.json()
        return self._documents_from_payload(result)
|
||||
127
venv/Lib/site-packages/langchain_community/retrievers/svm.py
Normal file
127
venv/Lib/site-packages/langchain_community/retrievers/svm.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
    """Build the embedding index for a list of contexts.

    ``embeddings.embed_query`` is applied to each context on a thread pool.

    Args:
        contexts: List of contexts to embed.
        embeddings: Embeddings model to use.

    Returns:
        Index of embeddings, one entry per context in input order.
    """
    with concurrent.futures.ThreadPoolExecutor() as workers:
        vectors = [*workers.map(embeddings.embed_query, contexts)]
    return np.array(vectors)
|
||||
|
||||
|
||||
class SVMRetriever(BaseRetriever):
    """`SVM` retriever.

    Largely based on
    https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb
    """

    embeddings: Embeddings
    """Embeddings model to use."""
    index: Any = None
    """Index of embeddings."""
    texts: List[str]
    """List of texts to index."""
    metadatas: Optional[List[dict]] = None
    """List of metadatas corresponding with each text."""
    k: int = 4
    """Number of results to return."""
    relevancy_threshold: Optional[float] = None
    """Threshold for relevancy."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embeddings: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> SVMRetriever:
        """Create a retriever whose index is the embedded ``texts``."""
        embedded_index = create_index(texts, embeddings)
        return cls(
            embeddings=embeddings,
            index=embedded_index,
            texts=texts,
            metadatas=metadatas,
            **kwargs,
        )

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        embeddings: Embeddings,
        **kwargs: Any,
    ) -> SVMRetriever:
        """Create a retriever from documents' page content and metadata."""
        content_and_meta = ((d.page_content, d.metadata) for d in documents)
        texts, metadatas = zip(*content_and_meta)
        return cls.from_texts(
            texts=texts, embeddings=embeddings, metadatas=metadatas, **kwargs
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Rank indexed texts with a linear SVM trained on the query.

        The query is the only positive example and every indexed text a
        negative one; texts are ordered by the classifier's decision value.
        """
        try:
            from sklearn import svm
        except ImportError:
            raise ImportError(
                "Could not import scikit-learn, please install with `pip install "
                "scikit-learn`."
            )

        query_vec = np.array(self.embeddings.embed_query(query))
        # Row 0 is the query, rows 1.. are the indexed texts.
        x = np.concatenate([query_vec[None, ...], self.index])
        y = np.zeros(x.shape[0])
        y[0] = 1

        clf = svm.LinearSVC(
            class_weight="balanced", verbose=False, max_iter=10000, tol=1e-6, C=0.1
        )
        clf.fit(x, y)

        similarities = clf.decision_function(x)
        sorted_ix = np.argsort(-similarities)

        # svm.LinearSVC in scikit-learn is non-deterministic, so when a text
        # equals the query the query row is not guaranteed to rank first.
        # Swap it to the front; anything ranked ahead of it should be
        # equivalent.
        zero_index = np.where(sorted_ix == 0)[0][0]
        if zero_index != 0:
            front = sorted_ix[0]
            sorted_ix[0] = sorted_ix[zero_index]
            sorted_ix[zero_index] = front

        # Min-max scale the scores; the epsilon avoids division by zero.
        denominator = np.max(similarities) - np.min(similarities) + 1e-6
        normalized_similarities = (similarities - np.min(similarities)) / denominator

        top_k_results = []
        # Position 0 is the query itself, so results start at position 1.
        for row in sorted_ix[1 : self.k + 1]:
            passes = (
                self.relevancy_threshold is None
                or normalized_similarities[row] >= self.relevancy_threshold
            )
            if not passes:
                continue
            # Row numbers are offset by one because of the prepended query.
            text_ix = row - 1
            metadata = self.metadatas[text_ix] if self.metadatas else {}
            top_k_results.append(
                Document(page_content=self.texts[text_ix], metadata=metadata)
            )
        return top_k_results
|
||||
@@ -0,0 +1,152 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class SearchDepth(Enum):
    """Enumeration of the available search depths."""

    # Member values are the plain strings "basic" and "advanced".
    BASIC = "basic"
    ADVANCED = "advanced"
|
||||
|
||||
|
||||
class TavilySearchAPIRetriever(BaseRetriever):
|
||||
"""Tavily Search API retriever.
|
||||
|
||||
Setup:
|
||||
Install ``langchain-community`` and set environment variable ``TAVILY_API_KEY``.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U langchain-community
|
||||
export TAVILY_API_KEY="your-api-key"
|
||||
|
||||
Key init args:
|
||||
k: int
|
||||
Number of results to include.
|
||||
include_generated_answer: bool
|
||||
Include a generated answer with results
|
||||
include_raw_content: bool
|
||||
Include raw content with results.
|
||||
include_images: bool
|
||||
Return images in addition to text.
|
||||
|
||||
Instantiate:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.retrievers import TavilySearchAPIRetriever
|
||||
|
||||
retriever = TavilySearchAPIRetriever(k=3)
|
||||
|
||||
Usage:
|
||||
.. code-block:: python
|
||||
|
||||
query = "what year was breath of the wild released?"
|
||||
|
||||
retriever.invoke(query)
|
||||
|
||||
Use within a chain:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
prompt = ChatPromptTemplate.from_template(
|
||||
\"\"\"Answer the question based only on the context provided.
|
||||
|
||||
Context: {context}
|
||||
|
||||
Question: {question}\"\"\"
|
||||
)
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
|
||||
|
||||
def format_docs(docs):
|
||||
return "\n\n".join(doc.page_content for doc in docs)
|
||||
|
||||
chain = (
|
||||
{"context": retriever | format_docs, "question": RunnablePassthrough()}
|
||||
| prompt
|
||||
| llm
|
||||
| StrOutputParser()
|
||||
)
|
||||
|
||||
chain.invoke("how many units did bretch of the wild sell in 2020")
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
k: int = 10
|
||||
include_generated_answer: bool = False
|
||||
include_raw_content: bool = False
|
||||
include_images: bool = False
|
||||
search_depth: SearchDepth = SearchDepth.BASIC
|
||||
include_domains: Optional[List[str]] = None
|
||||
exclude_domains: Optional[List[str]] = None
|
||||
kwargs: Optional[Dict[str, Any]] = {}
|
||||
api_key: Optional[str] = None
|
||||
|
||||
def _get_relevant_documents(
    self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
    """Query the Tavily Search API and wrap each hit in a ``Document``.

    Args:
        query: Search query forwarded to Tavily.
        run_manager: Callback manager for this retriever run (not used here).

    Returns:
        One ``Document`` per search result. When ``include_generated_answer``
        is set, a "Suggested Answer" document built from the response's
        ``answer`` field is placed first.

    Raises:
        ImportError: If the ``tavily`` package cannot be imported.
        KeyError: If ``api_key`` is unset and ``TAVILY_API_KEY`` is missing
            from the environment.
    """
    try:
        try:
            from tavily import TavilyClient
        except ImportError:
            # Older versions of tavily exposed the client as ``Client``.
            from tavily import Client as TavilyClient
    except ImportError as exc:
        raise ImportError(
            "Tavily python package not found. "
            "Please install it with `pip install tavily-python`."
        ) from exc

    tavily = TavilyClient(api_key=self.api_key or os.environ["TAVILY_API_KEY"])
    # The generated answer is prepended below, so it takes one of the ``k`` slots.
    max_results = self.k if not self.include_generated_answer else self.k - 1
    response = tavily.search(
        query=query,
        max_results=max_results,
        search_depth=self.search_depth.value,
        include_answer=self.include_generated_answer,
        include_domains=self.include_domains,
        exclude_domains=self.exclude_domains,
        include_raw_content=self.include_raw_content,
        include_images=self.include_images,
        # ``kwargs`` is declared Optional, so an explicit None must not be unpacked.
        **(self.kwargs or {}),
    )
    docs = [
        Document(
            page_content=result.get("content", "")
            if not self.include_raw_content
            else (result.get("raw_content") or ""),
            metadata={
                "title": result.get("title", ""),
                "source": result.get("url", ""),
                # Carry over every remaining result field untouched.
                **{
                    k: v
                    for k, v in result.items()
                    if k not in ("content", "title", "url", "raw_content")
                },
                "images": response.get("images"),
            },
        )
        # A response without a "results" key yields no documents instead of
        # failing on iteration over None.
        for result in response.get("results") or []
    ]
    if self.include_generated_answer:
        docs = [
            Document(
                page_content=response.get("answer", ""),
                metadata={
                    "title": "Suggested Answer",
                    "source": "https://tavily.com/",
                },
            ),
            *docs,
        ]

    return docs
|
||||
159
venv/Lib/site-packages/langchain_community/retrievers/tfidf.py
Normal file
159
venv/Lib/site-packages/langchain_community/retrievers/tfidf.py
Normal file
@@ -0,0 +1,159 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
class TFIDFRetriever(BaseRetriever):
    """`TF-IDF` retriever.

    Largely based on
    https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb
    """

    vectorizer: Any = None
    """TF-IDF vectorizer."""
    docs: List[Document]
    """Documents."""
    tfidf_array: Any = None
    """TF-IDF array."""
    k: int = 4
    """Number of documents to return."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        tfidf_params: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> TFIDFRetriever:
        """Fit a TF-IDF vectorizer on ``texts`` and build a retriever from it.

        Args:
            texts: Texts to index. Any iterable is accepted, including
                one-shot generators.
            metadatas: Optional metadata dicts, paired with ``texts`` in order.
                When omitted (or empty) every document gets empty metadata.
            tfidf_params: Keyword arguments forwarded to ``TfidfVectorizer``.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A retriever holding the fitted vectorizer, the documents and
            their TF-IDF matrix.

        Raises:
            ImportError: If scikit-learn is not installed.
        """
        try:
            from sklearn.feature_extraction.text import TfidfVectorizer
        except ImportError as exc:
            raise ImportError(
                "Could not import scikit-learn, please install with `pip install "
                "scikit-learn`."
            ) from exc

        # ``texts`` is consumed three times below (fit, default metadata,
        # document construction); materialize it so a generator input is not
        # exhausted after the first pass.
        texts = list(texts)
        tfidf_params = tfidf_params or {}
        vectorizer = TfidfVectorizer(**tfidf_params)
        tfidf_array = vectorizer.fit_transform(texts)
        metadatas = metadatas or ({} for _ in texts)
        docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
        return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array, **kwargs)

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        *,
        tfidf_params: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> TFIDFRetriever:
        """Build a retriever from existing documents.

        Splits each document into its text and metadata and delegates to
        ``from_texts``.

        Args:
            documents: Documents to index.
            tfidf_params: Keyword arguments forwarded to ``TfidfVectorizer``.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            The fitted retriever.
        """
        texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
        return cls.from_texts(
            texts=texts, tfidf_params=tfidf_params, metadatas=metadatas, **kwargs
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the ``k`` stored documents most similar to ``query``.

        Args:
            query: Text to search for.
            run_manager: Callback manager for this retriever run (not used here).

        Returns:
            Up to ``k`` documents ordered from most to least similar.
        """
        from sklearn.metrics.pairwise import cosine_similarity

        # ``[-0:]`` is the whole array, so a non-positive ``k`` must be handled
        # explicitly to return nothing rather than everything.
        if self.k <= 0:
            return []

        # The query is vectorized as a single-item batch.
        query_vec = self.vectorizer.transform([query])
        # One similarity per stored document, flattened to a 1-D array.
        results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
        # argsort is ascending: take the last ``k`` indices and reverse them.
        return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
        return return_docs

    def save_local(
        self,
        folder_path: str,
        file_name: str = "tfidf_vectorizer",
    ) -> None:
        """Persist the retriever to ``folder_path``.

        Writes ``<file_name>.joblib`` (the vectorizer) and ``<file_name>.pkl``
        (the documents and TF-IDF array). The folder is created if needed.

        Args:
            folder_path: Folder to write into.
            file_name: Base name of the two files. Defaults to
                "tfidf_vectorizer".

        Raises:
            ImportError: If joblib is not installed.
        """
        try:
            import joblib
        except ImportError as exc:
            raise ImportError(
                "Could not import joblib, please install with `pip install joblib`."
            ) from exc

        path = Path(folder_path)
        path.mkdir(exist_ok=True, parents=True)

        # Save vectorizer with joblib dump.
        joblib.dump(self.vectorizer, path / f"{file_name}.joblib")

        # Save docs and tfidf array as pickle.
        with open(path / f"{file_name}.pkl", "wb") as f:
            pickle.dump((self.docs, self.tfidf_array), f)

    @classmethod
    def load_local(
        cls,
        folder_path: str,
        *,
        allow_dangerous_deserialization: bool = False,
        file_name: str = "tfidf_vectorizer",
    ) -> TFIDFRetriever:
        """Load the retriever from local storage.

        Args:
            folder_path: Folder path to load from.
            allow_dangerous_deserialization: Whether to allow dangerous deserialization.
                Defaults to False.
                The deserialization relies on .joblib and .pkl files, which can be
                modified to deliver a malicious payload that results in execution of
                arbitrary code on your machine. You will need to set this to `True` to
                use deserialization. If you do this, make sure you trust the source of
                the file.
            file_name: File name to load from. Defaults to "tfidf_vectorizer".

        Returns:
            TFIDFRetriever: Loaded retriever.

        Raises:
            ImportError: If joblib is not installed.
            ValueError: If ``allow_dangerous_deserialization`` is not True.
        """
        try:
            import joblib
        except ImportError as exc:
            raise ImportError(
                "Could not import joblib, please install with `pip install joblib`."
            ) from exc

        if not allow_dangerous_deserialization:
            # Adjacent literals are concatenated verbatim, so each sentence
            # carries its own trailing space.
            raise ValueError(
                "The de-serialization of this retriever is based on .joblib and "
                ".pkl files. "
                "Such files can be modified to deliver a malicious payload that "
                "results in execution of arbitrary code on your machine. "
                "You will need to set `allow_dangerous_deserialization` to `True` to "
                "load this retriever. If you do this, make sure you trust the source "
                "of the file, and you are responsible for validating the file "
                "came from a trusted source."
            )

        path = Path(folder_path)

        # Load vectorizer with joblib load.
        vectorizer = joblib.load(path / f"{file_name}.joblib")

        # Load docs and tfidf array as pickle.
        with open(path / f"{file_name}.pkl", "rb") as f:
            # This code path can only be triggered if the user
            # passed allow_dangerous_deserialization=True
            docs, tfidf_array = pickle.load(f)  # ignore[pickle]: explicit-opt-in

        return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array)
|
||||
@@ -0,0 +1,258 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from pydantic import ConfigDict, SecretStr
|
||||
|
||||
|
||||
class NeuralDBRetriever(BaseRetriever):
    """Document retriever that uses ThirdAI's NeuralDB."""

    thirdai_key: SecretStr
    """ThirdAI API Key"""

    db: Any = None  #: :meta private:
    """NeuralDB instance"""

    model_config = ConfigDict(
        extra="forbid",
    )

    @staticmethod
    def _verify_thirdai_library(thirdai_key: Optional[str] = None) -> None:
        """Check that ``thirdai`` with NeuralDB is importable and activate it.

        Args:
            thirdai_key: License key. Falls back to the ``THIRDAI_KEY``
                environment variable when not given.

        Raises:
            ImportError: If ``thirdai`` or its ``neural_db`` submodule is
                missing.
        """
        # Imported by name: a bare ``import importlib`` does not guarantee
        # that the ``importlib.util`` submodule has been loaded.
        from importlib.util import find_spec

        try:
            from thirdai import licensing

            # ``find_spec`` signals a missing submodule by returning None
            # rather than raising, so the result has to be checked.
            if find_spec("thirdai.neural_db") is None:
                raise ImportError("thirdai.neural_db is not available")

            licensing.activate(thirdai_key or os.getenv("THIRDAI_KEY"))
        except ImportError as exc:
            raise ImportError(
                "Could not import thirdai python package and neuraldb dependencies. "
                "Please install it with `pip install thirdai[neural_db]`."
            ) from exc

    @classmethod
    def from_scratch(
        cls,
        thirdai_key: Optional[str] = None,
        **model_kwargs: dict,
    ) -> NeuralDBRetriever:
        """
        Create a NeuralDBRetriever from scratch.

        To use, set the ``THIRDAI_KEY`` environment variable with your ThirdAI
        API key, or pass ``thirdai_key`` as a named parameter.

        Example:
            .. code-block:: python

                from langchain_community.retrievers import NeuralDBRetriever

                retriever = NeuralDBRetriever.from_scratch(
                    thirdai_key="your-thirdai-key",
                )

                retriever.insert([
                    "/path/to/doc.pdf",
                    "/path/to/doc.docx",
                    "/path/to/doc.csv",
                ])

                documents = retriever.invoke("AI-driven music therapy")
        """
        NeuralDBRetriever._verify_thirdai_library(thirdai_key)
        from thirdai import neural_db as ndb

        return cls(thirdai_key=thirdai_key, db=ndb.NeuralDB(**model_kwargs))  # type: ignore[arg-type]

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint: Union[str, Path],
        thirdai_key: Optional[str] = None,
    ) -> NeuralDBRetriever:
        """
        Create a NeuralDBRetriever with a base model from a saved checkpoint.

        To use, set the ``THIRDAI_KEY`` environment variable with your ThirdAI
        API key, or pass ``thirdai_key`` as a named parameter.

        Example:
            .. code-block:: python

                from langchain_community.retrievers import NeuralDBRetriever

                retriever = NeuralDBRetriever.from_checkpoint(
                    checkpoint="/path/to/checkpoint.ndb",
                    thirdai_key="your-thirdai-key",
                )

                retriever.insert([
                    "/path/to/doc.pdf",
                    "/path/to/doc.docx",
                    "/path/to/doc.csv",
                ])

                documents = retriever.invoke("AI-driven music therapy")
        """
        NeuralDBRetriever._verify_thirdai_library(thirdai_key)
        from thirdai import neural_db as ndb

        return cls(thirdai_key=thirdai_key, db=ndb.NeuralDB.from_checkpoint(checkpoint))  # type: ignore[arg-type]

    @pre_init
    def validate_environments(cls, values: Dict) -> Dict:
        """Validate ThirdAI environment variables."""
        # Resolve the key from the passed values or ``THIRDAI_KEY`` and wrap
        # it as a secret string.
        values["thirdai_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values,
                "thirdai_key",
                "THIRDAI_KEY",
            )
        )
        return values

    def insert(
        self,
        sources: List[Any],
        train: bool = True,
        fast_mode: bool = True,
        **kwargs: dict,
    ) -> None:
        """Inserts files / document sources into the retriever.

        Args:
            sources: String paths to PDF, DOCX or CSV files, or NeuralDB
                document objects.
            train: When True this means that the underlying model in the
                NeuralDB will undergo unsupervised pretraining on the inserted files.
                Defaults to True.
            fast_mode: Much faster insertion with a slight drop in performance.
                Defaults to True.
            **kwargs: Extra keyword arguments forwarded to ``db.insert``.
        """
        sources = self._preprocess_sources(sources)
        self.db.insert(
            sources=sources,
            train=train,
            fast_approximation=fast_mode,
            **kwargs,
        )

    def _preprocess_sources(self, sources: list) -> list:
        """Checks if the provided sources are string paths. If they are, convert
        to NeuralDB document objects.

        Args:
            sources: list of either string paths to PDF, DOCX or CSV files, or
                NeuralDB document objects.

        Returns:
            The sources with every string path replaced by the matching
            NeuralDB document object; non-string items pass through unchanged.

        Raises:
            RuntimeError: If a string path has an unsupported extension.
        """
        from thirdai import neural_db as ndb

        if not sources:
            return sources
        preprocessed_sources = []
        for doc in sources:
            if not isinstance(doc, str):
                preprocessed_sources.append(doc)
                continue
            # Extension matching is case-insensitive.
            lowered = doc.lower()
            if lowered.endswith(".pdf"):
                preprocessed_sources.append(ndb.PDF(doc))
            elif lowered.endswith(".docx"):
                preprocessed_sources.append(ndb.DOCX(doc))
            elif lowered.endswith(".csv"):
                preprocessed_sources.append(ndb.CSV(doc))
            else:
                raise RuntimeError(
                    f"Could not automatically load {doc}. Only files "
                    "with .pdf, .docx, or .csv extensions can be loaded "
                    "automatically. For other formats, please use the "
                    "appropriate document object from the ThirdAI library."
                )
        return preprocessed_sources

    def upvote(self, query: str, document_id: int) -> None:
        """The retriever upweights the score of a document for a specific query.
        This is useful for fine-tuning the retriever to user behavior.

        Args:
            query: text to associate with `document_id`
            document_id: id of the document to associate query with.
        """
        self.db.text_to_result(query, document_id)

    def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]) -> None:
        """Given a batch of (query, document id) pairs, the retriever upweights
        the scores of the document for the corresponding queries.
        This is useful for fine-tuning the retriever to user behavior.

        Args:
            query_id_pairs: list of (query, document id) pairs. For each pair in
                this list, the model will upweight the document id for the query.
        """
        self.db.text_to_result_batch(query_id_pairs)

    def associate(self, source: str, target: str) -> None:
        """The retriever associates a source phrase with a target phrase.
        When the retriever sees the source phrase, it will also consider results
        that are relevant to the target phrase.

        Args:
            source: text to associate to `target`.
            target: text to associate `source` to.
        """
        self.db.associate(source, target)

    def associate_batch(self, text_pairs: List[Tuple[str, str]]) -> None:
        """Given a batch of (source, target) pairs, the retriever associates
        each source phrase with the corresponding target phrase.

        Args:
            text_pairs: list of (source, target) text pairs. For each pair in
                this list, the source will be associated with the target.
        """
        self.db.associate_batch(text_pairs)

    def _get_relevant_documents(
        self, query: str, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        """Retrieve {top_k} contexts with your retriever for a given query

        Args:
            query: Query to submit to the model
            top_k: The max number of context results to retrieve. Defaults to 10.

        Raises:
            ValueError: If the search or the conversion of its results fails;
                the original exception is chained as the cause.
        """
        try:
            if "top_k" not in kwargs:
                kwargs["top_k"] = 10
            references = self.db.search(query=query, **kwargs)
            return [
                Document(
                    page_content=ref.text,
                    metadata={
                        "id": ref.id,
                        "upvote_ids": ref.upvote_ids,
                        "source": ref.source,
                        "metadata": ref.metadata,
                        "score": ref.score,
                        "context": ref.context(1),
                    },
                )
                for ref in references
            ]
        except Exception as e:
            raise ValueError(f"Error while retrieving documents: {e}") from e

    def save(self, path: str) -> None:
        """Saves a NeuralDB instance to disk. Can be loaded into memory by
        calling NeuralDB.from_checkpoint(path)

        Args:
            path: path on disk to save the NeuralDB instance to.
        """
        self.db.save(path)
|
||||
@@ -0,0 +1,126 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence, Union
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class VespaRetriever(BaseRetriever):
    """`Vespa` retriever."""

    app: Any
    """Vespa application to query."""
    body: Dict
    """Body of the query."""
    content_field: str
    """Name of the content field."""
    metadata_fields: Sequence[str]
    """Names of the metadata fields."""

    def _query(self, body: Dict) -> List[Document]:
        """Run ``body`` against the Vespa app and convert hits to documents.

        Args:
            body: Query body passed to ``app.query``.

        Returns:
            One ``Document`` per hit. The content field becomes
            ``page_content``; metadata holds the selected fields plus the
            hit ``id``.

        Raises:
            RuntimeError: If the status code is not 2xx or the response
                root contains ``errors``.
        """
        response = self.app.query(body)

        if not str(response.status_code).startswith("2"):
            raise RuntimeError(
                "Could not retrieve data from Vespa. Error code: {}".format(
                    response.status_code
                )
            )

        root = response.json["root"]
        if "errors" in root:
            raise RuntimeError(json.dumps(root["errors"]))

        docs = []
        for child in response.hits:
            # Popping removes the content field so it is not repeated in the
            # metadata when all fields ("*") are kept.
            page_content = child["fields"].pop(self.content_field, "")
            if self.metadata_fields == "*":
                metadata = child["fields"]
            else:
                metadata = {mf: child["fields"].get(mf) for mf in self.metadata_fields}
            metadata["id"] = child["id"]
            docs.append(Document(page_content=page_content, metadata=metadata))
        return docs

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the documents Vespa finds for ``query``.

        Args:
            query: User query, stored under ``"query"`` in a copy of the
                configured body.
            run_manager: Callback manager for this retriever run (not used here).
        """
        body = self.body.copy()
        body["query"] = query
        return self._query(body)

    def get_relevant_documents_with_filter(
        self, query: str, *, _filter: Optional[str] = None
    ) -> List[Document]:
        """Return the documents for ``query`` with an extra YQL condition.

        Args:
            query: User query.
            _filter: Optional condition appended to the configured YQL as
                `` and <_filter>``.
        """
        body = self.body.copy()
        _filter = f" and {_filter}" if _filter else ""
        body["yql"] = body["yql"] + _filter
        body["query"] = query
        return self._query(body)

    @classmethod
    def from_params(
        cls,
        url: str,
        content_field: str,
        *,
        k: Optional[int] = None,
        metadata_fields: Union[Sequence[str], Literal["*"]] = (),
        sources: Union[Sequence[str], Literal["*"], None] = None,
        _filter: Optional[str] = None,
        yql: Optional[str] = None,
        **kwargs: Any,
    ) -> VespaRetriever:
        """Instantiate retriever from params.

        Args:
            url (str): Vespa app URL.
            content_field (str): Field in results to return as Document page_content.
            k (Optional[int]): Number of Documents to return. Defaults to None.
            metadata_fields(Sequence[str] or "*"): Fields in results to include in
                document metadata. Defaults to empty tuple ().
            sources (Sequence[str] or "*" or None): Sources to retrieve
                from. Defaults to None.
            _filter (Optional[str]): Document filter condition expressed in YQL.
                Defaults to None.
            yql (Optional[str]): Full YQL query to be used. Should not be specified
                if _filter or sources are specified. Defaults to None.
            kwargs (Any): Keyword arguments added to query body.

        Returns:
            VespaRetriever: Instantiated VespaRetriever.

        Raises:
            ImportError: If pyvespa is not installed.
            ValueError: If ``yql`` is combined with ``sources`` or ``_filter``.
        """
        try:
            from vespa.application import Vespa
        except ImportError as exc:
            raise ImportError(
                "pyvespa is not installed, please install with `pip install pyvespa`"
            ) from exc
        app = Vespa(url)
        body = kwargs.copy()
        if yql and (sources or _filter):
            raise ValueError(
                "yql should only be specified if both sources and _filter are not "
                "specified."
            )

        if metadata_fields == "*":
            body["summary"] = "short"

        # Only generate a query when the caller did not supply one; otherwise
        # the explicit ``yql`` would be silently replaced.
        if not yql:
            if metadata_fields == "*":
                _fields = "*"
            else:
                _fields = ", ".join([content_field] + list(metadata_fields or []))
            if isinstance(sources, str):
                # A plain string (e.g. "*") is one source expression; joining
                # it would split it into characters.
                _sources = sources
            elif isinstance(sources, Sequence):
                _sources = ", ".join(sources)
            else:
                _sources = "*"
            _filter = f" and {_filter}" if _filter else ""
            yql = f"select {_fields} from sources {_sources} where userQuery(){_filter}"
        body["yql"] = yql
        if k:
            body["hits"] = k
        return cls(
            app=app,
            body=body,
            content_field=content_field,
            metadata_fields=metadata_fields,
        )
|
||||
@@ -0,0 +1,168 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import ConfigDict, model_validator
|
||||
|
||||
|
||||
@deprecated(
    since="0.3.18",
    removal="1.0",
    alternative_import="langchain_weaviate.WeaviateVectorStore",
)
class WeaviateHybridSearchRetriever(BaseRetriever):
    """`Weaviate hybrid search` retriever.

    See the documentation:
      https://weaviate.io/blog/hybrid-search-explained
    """

    client: Any = None
    """The Weaviate client; must be a ``weaviate.Client`` instance."""
    index_name: str
    """The name of the index to use."""
    text_key: str
    """The name of the text key to use."""
    alpha: float = 0.5
    """The weight of the text key in the hybrid search."""
    k: int = 4
    """The number of results to return."""
    attributes: List[str]
    """The attributes to return in the results."""
    create_schema_if_missing: bool = True
    """Whether to create the schema if it doesn't exist."""

    @model_validator(mode="before")
    @classmethod
    def validate_client(
        cls,
        values: Dict[str, Any],
    ) -> Any:
        """Validate the client and prepare attributes and schema.

        Checks that ``client`` is a ``weaviate.Client``, adds ``text_key`` to
        the requested attributes and, when ``create_schema_if_missing`` is
        truthy, creates the ``index_name`` class if the client reports it
        does not exist.

        Args:
            values: Raw constructor values.

        Returns:
            The updated values.

        Raises:
            ImportError: If the ``weaviate`` package is not installed.
            ValueError: If ``client`` is missing or not a ``weaviate.Client``.
        """
        try:
            import weaviate
        except ImportError as exc:
            raise ImportError(
                "Could not import weaviate python package. "
                "Please install it with `pip install weaviate-client`."
            ) from exc
        # ``.get`` so an omitted client gives the descriptive ValueError below
        # rather than a bare KeyError.
        client = values.get("client")
        if not isinstance(client, weaviate.Client):
            raise ValueError(
                f"client should be an instance of weaviate.Client, got {type(client)}"
            )

        # Build a new list so the caller's own ``attributes`` list is not
        # mutated when the text key is added.
        attributes = list(values.get("attributes") or [])
        attributes.append(values["text_key"])
        values["attributes"] = attributes

        if values.get("create_schema_if_missing", True):
            class_obj = {
                "class": values["index_name"],
                "properties": [{"name": values["text_key"], "dataType": ["text"]}],
                "vectorizer": "text2vec-openai",
            }

            if not client.schema.exists(values["index_name"]):
                client.schema.create_class(class_obj)

        return values

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    # added text_key
    def add_documents(self, docs: List[Document], **kwargs: Any) -> List[str]:
        """Upload documents to Weaviate.

        Args:
            docs: Documents to upload. Each document's text is stored under
                ``text_key`` alongside its metadata.
            **kwargs: May contain ``uuids``, a sequence of ids indexed in
                step with ``docs``; otherwise a random UUID is generated per
                document.

        Returns:
            The ids of the uploaded objects, in input order.
        """
        from weaviate.util import get_valid_uuid

        with self.client.batch as batch:
            ids = []
            for i, doc in enumerate(docs):
                metadata = doc.metadata or {}
                data_properties = {self.text_key: doc.page_content, **metadata}

                # If the UUID of one of the objects already exists
                # then the existing object will be replaced by the new object.
                if "uuids" in kwargs:
                    _id = kwargs["uuids"][i]
                else:
                    _id = get_valid_uuid(uuid4())

                batch.add_data_object(data_properties, self.index_name, _id)
                ids.append(_id)
        return ids

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        where_filter: Optional[Dict[str, object]] = None,
        score: bool = False,
        hybrid_search_kwargs: Optional[Dict[str, object]] = None,
    ) -> List[Document]:
        """Look up similar documents in Weaviate.

        query: The query to search for relevant documents
         using Weaviate hybrid search.

        where_filter: A filter to apply to the query.
            https://weaviate.io/developers/weaviate/guides/querying/#filtering

        score: Whether to include the score, and score explanation
            in the returned Documents meta_data.

        hybrid_search_kwargs: Used to pass additional arguments
         to the .with_hybrid() method.
            The primary use cases for this are:
            1)  Search specific properties only -
                specify which properties to be used during hybrid search portion.
                Note: this is not the same as the (self.attributes) to be returned.
                Example - hybrid_search_kwargs={"properties": ["question", "answer"]}
            https://weaviate.io/developers/weaviate/search/hybrid#selected-properties-only

            2) Weight boosted searched properties -
                Boost the weight of certain properties during the hybrid search portion.
                Example - hybrid_search_kwargs={"properties": ["question^2", "answer"]}
            https://weaviate.io/developers/weaviate/search/hybrid#weight-boost-searched-properties

            3) Search with a custom vector - Define a different vector
                to be used during the hybrid search portion.
                Example - hybrid_search_kwargs={"vector": [0.1, 0.2, 0.3, ...]}
            https://weaviate.io/developers/weaviate/search/hybrid#with-a-custom-vector

            4) Use Fusion ranking method
                Example - from weaviate.gql.get import HybridFusion
                hybrid_search_kwargs={"fusion_type": HybridFusion.RELATIVE_SCORE}
            https://weaviate.io/developers/weaviate/search/hybrid#fusion-ranking-method

        Raises:
            ValueError: If the query result contains ``errors``.
        """
        query_obj = self.client.query.get(self.index_name, self.attributes)
        if where_filter:
            query_obj = query_obj.with_where(where_filter)

        if score:
            query_obj = query_obj.with_additional(["score", "explainScore"])

        if hybrid_search_kwargs is None:
            hybrid_search_kwargs = {}

        result = (
            query_obj.with_hybrid(query, alpha=self.alpha, **hybrid_search_kwargs)
            .with_limit(self.k)
            .do()
        )
        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")

        docs = []

        for res in result["data"]["Get"][self.index_name]:
            # The text becomes the page content; what remains is the metadata.
            text = res.pop(self.text_key)
            docs.append(Document(page_content=text, metadata=res))
        return docs
|
||||
@@ -0,0 +1,267 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain_classic.chains import LLMChain
|
||||
from langchain_classic.chains.prompt_selector import ConditionalPromptSelector
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain_community.document_loaders import AsyncHtmlLoader
|
||||
from langchain_community.document_transformers import Html2TextTransformer
|
||||
from langchain_community.llms import LlamaCpp
|
||||
from langchain_community.utilities import GoogleSearchAPIWrapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SearchQueries(BaseModel):
    """Search queries to research for the user's goal."""

    # Required field: the ellipsis default marks it as having no default value.
    queries: List[str] = Field(
        default=...,
        description="List of search queries to look up on Google",
    )
|
||||
|
||||
|
||||
# Prompt for Llama-style models: the instruction is wrapped in the
# <<SYS>> / [INST] markers and asks for three numbered search queries.
DEFAULT_LLAMA_SEARCH_PROMPT = PromptTemplate(
    template="""<<SYS>> \n You are an assistant tasked with improving Google search \
results. \n <</SYS>> \n\n [INST] Generate THREE Google search queries that \
are similar to this question. The output should be a numbered list of questions \
and each should have a question mark at the end: \n\n {question} [/INST]""",
    input_variables=["question"],
)

# Generic prompt asking for three numbered search queries.
DEFAULT_SEARCH_PROMPT = PromptTemplate(
    template="""You are an assistant tasked with improving Google search \
results. Generate THREE Google search queries that are similar to \
this question. The output should be a numbered list of questions and each \
should have a question mark at the end: {question}""",
    input_variables=["question"],
)
|
||||
|
||||
|
||||
class QuestionListOutputParser(BaseOutputParser[List[str]]):
    """Output parser for a list of numbered questions."""

    def parse(self, text: str) -> List[str]:
        """Extract every numbered item from ``text``.

        An item starts at digits followed by a dot and runs to the end of
        its line; the trailing newline, when present, is kept in the item.

        Args:
            text: Raw model output.

        Returns:
            The matched items in order of appearance.
        """
        numbered_item = re.compile(r"\d+\..*?(?:\n|$)")
        return numbered_item.findall(text)
|
||||
|
||||
|
||||
class WebResearchRetriever(BaseRetriever):
|
||||
"""`Google Search API` retriever."""
|
||||
|
||||
# Inputs
|
||||
vectorstore: VectorStore = Field(
|
||||
..., description="Vector store for storing web pages"
|
||||
)
|
||||
llm_chain: LLMChain
|
||||
search: GoogleSearchAPIWrapper = Field(..., description="Google Search API Wrapper")
|
||||
num_search_results: int = Field(1, description="Number of pages per Google search")
|
||||
text_splitter: TextSplitter = Field(
|
||||
RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=50),
|
||||
description="Text splitter for splitting web pages into chunks",
|
||||
)
|
||||
url_database: List[str] = Field(
|
||||
default_factory=list, description="List of processed URLs"
|
||||
)
|
||||
trust_env: bool = Field(
|
||||
False,
|
||||
description="Whether to use the http_proxy/https_proxy env variables or "
|
||||
"check .netrc for proxy configuration",
|
||||
)
|
||||
|
||||
allow_dangerous_requests: bool = False
|
||||
"""A flag to force users to acknowledge the risks of SSRF attacks when using
|
||||
this retriever.
|
||||
|
||||
Users should set this flag to `True` if they have taken the necessary precautions
|
||||
to prevent SSRF attacks when using this retriever.
|
||||
|
||||
For example, users can run the requests through a properly configured
|
||||
proxy and prevent the crawler from accidentally crawling internal resources.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
    """Initialize the retriever.

    Args:
        **kwargs: Field values forwarded to the base initializer. Must
            include a truthy ``allow_dangerous_requests``.

    Raises:
        ValueError: If ``allow_dangerous_requests`` is missing or falsy.
    """
    allow_dangerous_requests = kwargs.get("allow_dangerous_requests", False)
    if not allow_dangerous_requests:
        # Adjacent literals are concatenated verbatim, so each fragment
        # carries its own trailing space.
        raise ValueError(
            "WebResearchRetriever crawls URLs surfaced through "
            "the provided search engine. It is possible that some of those URLs "
            "will end up pointing to machines residing on an internal network, "
            "leading "
            "to an SSRF (Server-Side Request Forgery) attack. "
            "To protect yourself against that risk, you can run the requests "
            "through a proxy and prevent the crawler from accidentally crawling "
            "internal resources. "
            "If you've taken the necessary precautions, you can set "
            "`allow_dangerous_requests` to `True`."
        )
    super().__init__(**kwargs)
|
||||
|
||||
@classmethod
def from_llm(
    cls,
    vectorstore: VectorStore,
    llm: BaseLLM,
    search: GoogleSearchAPIWrapper,
    prompt: Optional[BasePromptTemplate] = None,
    num_search_results: int = 1,
    text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
        chunk_size=1500, chunk_overlap=150
    ),
    trust_env: bool = False,
    allow_dangerous_requests: bool = False,
) -> "WebResearchRetriever":
    """Build a retriever around an LLM, choosing a default prompt if needed.

    Args:
        vectorstore: Vector store for storing web pages
        llm: llm for search question generation
        search: GoogleSearchAPIWrapper
        prompt: prompt used to generate search questions; when falsy, a
            default is selected based on the type of ``llm``
        num_search_results: Number of pages per Google search
        text_splitter: Text splitter for splitting web pages into chunks
        trust_env: Whether to use the http_proxy/https_proxy env variables
            or check .netrc for proxy configuration
        allow_dangerous_requests: A flag to force users to acknowledge
            the risks of SSRF attacks when using this retriever

    Returns:
        WebResearchRetriever
    """
    if not prompt:
        # LlamaCpp models get their own prompt; everything else uses the default.
        selector = ConditionalPromptSelector(
            default_prompt=DEFAULT_SEARCH_PROMPT,
            conditionals=[
                (
                    lambda model: isinstance(model, LlamaCpp),
                    DEFAULT_LLAMA_SEARCH_PROMPT,
                )
            ],
        )
        prompt = selector.get_prompt(llm)

    return cls(
        vectorstore=vectorstore,
        # The chain's output is parsed into a list of search questions.
        llm_chain=LLMChain(
            llm=llm,
            prompt=prompt,
            output_parser=QuestionListOutputParser(),
        ),
        search=search,
        num_search_results=num_search_results,
        text_splitter=text_splitter,
        trust_env=trust_env,
        allow_dangerous_requests=allow_dangerous_requests,
    )
|
||||
|
||||
def clean_search_query(self, query: str) -> str:
    """Normalize a generated search question before sending it to the engine.

    Some search tools (e.g., Google) fail to return results if the query has
    a leading digit, as in ``1. "LangCh..."``. For such queries the text up to
    and including the first double quote is dropped, along with a trailing
    double quote if present.

    Args:
        query: Raw search question; may be empty.

    Returns:
        The cleaned query with surrounding whitespace stripped.
    """
    # The truthiness check keeps an empty query from raising IndexError.
    if query and query[0].isdigit():
        first_quote_pos = query.find('"')
        if first_quote_pos != -1:
            # Keep only what follows the opening quote.
            query = query[first_quote_pos + 1 :]
            if query.endswith('"'):
                query = query[:-1]
    return query.strip()
|
||||
|
||||
def search_tool(self, query: str, num_search_results: int = 1) -> List[dict]:
    """Run one search and return up to ``num_search_results`` result entries.

    The query is passed through ``clean_search_query`` before being handed to
    the configured search wrapper.
    """
    return self.search.results(self.clean_search_query(query), num_search_results)
|
||||
|
||||
def _get_relevant_documents(
    self,
    query: str,
    *,
    run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
    """Search the web for the query, index unseen pages, and return matches.

    Args:
        query: user query

    Returns:
        De-duplicated documents retrieved from the vector store for every
        generated search question.
    """
    # Ask the LLM chain to turn the user query into search questions.
    logger.info("Generating questions for Google Search ...")
    chain_output = self.llm_chain({"question": query})
    logger.info(f"Questions for Google Search (raw): {chain_output}")
    questions = chain_output["text"]
    logger.info(f"Questions for Google Search: {questions}")

    # Collect the result links for every generated question.
    logger.info("Searching for relevant urls...")
    candidate_links = []
    for question in questions:
        hits = self.search_tool(question, self.num_search_results)
        logger.info("Searching for relevant urls...")
        logger.info(f"Search results: {hits}")
        candidate_links.extend(hit["link"] for hit in hits if hit.get("link"))

    # Only URLs that were not processed by an earlier call are loaded.
    new_urls = list(set(candidate_links).difference(self.url_database))
    logger.info(f"New URLs to load: {new_urls}")

    if new_urls:
        # Load, convert to text, split, and index the new pages.
        loader = AsyncHtmlLoader(
            new_urls, ignore_load_errors=True, trust_env=self.trust_env
        )
        html2text = Html2TextTransformer()
        logger.info("Indexing new urls...")
        pages = loader.load()
        pages = list(html2text.transform_documents(pages))
        chunks = self.text_splitter.split_documents(pages)
        self.vectorstore.add_documents(chunks)
        self.url_database.extend(new_urls)

    # Search for relevant splits
    # TODO: make this async
    logger.info("Grabbing most relevant splits from urls...")
    matches = [
        doc
        for question in questions
        for doc in self.vectorstore.similarity_search(question)
    ]

    # Two documents are duplicates when content and metadata both match.
    deduplicated = {}
    for doc in matches:
        key = (doc.page_content, tuple(sorted(doc.metadata.items())))
        deduplicated[key] = doc
    return list(deduplicated.values())
|
||||
|
||||
async def _aget_relevant_documents(
    self,
    query: str,
    *,
    run_manager: AsyncCallbackManagerForRetrieverRun,
) -> List[Document]:
    """Asynchronous retrieval is not implemented for this retriever.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
|
||||
@@ -0,0 +1,77 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
|
||||
|
||||
|
||||
class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
    """`Wikipedia API` retriever.

    Setup:
        Install the ``wikipedia`` dependency:

        .. code-block:: bash

            pip install -U wikipedia

    Instantiate:
        .. code-block:: python

            from langchain_community.retrievers import WikipediaRetriever

            retriever = WikipediaRetriever()

    Usage:
        .. code-block:: python

            docs = retriever.invoke("TOKYO GHOUL")
            print(docs[0].page_content[:100])

        .. code-block:: none

            Tokyo Ghoul (Japanese: 東京喰種(トーキョーグール), Hepburn: Tōkyō Gūru) is a Japanese dark fantasy

    Use within a chain:
        .. code-block:: python

            from langchain_core.output_parsers import StrOutputParser
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_core.runnables import RunnablePassthrough
            from langchain_openai import ChatOpenAI

            prompt = ChatPromptTemplate.from_template(
                \"\"\"Answer the question based only on the context provided.

            Context: {context}

            Question: {question}\"\"\"
            )

            llm = ChatOpenAI(model="gpt-3.5-turbo-0125")

            def format_docs(docs):
                return "\\n\\n".join(doc.page_content for doc in docs)

            chain = (
                {"context": retriever | format_docs, "question": RunnablePassthrough()}
                | prompt
                | llm
                | StrOutputParser()
            )

            chain.invoke(
                "Who is the main character in `Tokyo Ghoul` and does he transform into a ghoul?"
            )

        .. code-block:: none

             'The main character in Tokyo Ghoul is Ken Kaneki, who transforms into a ghoul after receiving an organ transplant from a ghoul named Rize.'
    """  # noqa: E501

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return whatever the inherited ``load`` yields for ``query``.

        ``run_manager`` is accepted for interface compatibility and not used.
        """
        documents = self.load(query=query)
        return documents
|
||||
39
venv/Lib/site-packages/langchain_community/retrievers/you.py
Normal file
39
venv/Lib/site-packages/langchain_community/retrievers/you.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from langchain_community.utilities import YouSearchAPIWrapper
|
||||
|
||||
|
||||
class YouRetriever(BaseRetriever, YouSearchAPIWrapper):
    """You.com Search API retriever.

    Exposes the wrapper's ``results()`` / ``results_async()`` through the
    retriever interface.
    It uses all YouSearchAPIWrapper arguments without any change.
    """

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Return ``results()`` for ``query``, passing along a child run manager."""
        child_manager = run_manager.get_child()
        return self.results(query, run_manager=child_manager, **kwargs)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Async counterpart: await ``results_async()`` with a child run manager."""
        child_manager = run_manager.get_child()
        return await self.results_async(query, run_manager=child_manager, **kwargs)
|
||||
183
venv/Lib/site-packages/langchain_community/retrievers/zep.py
Normal file
183
venv/Lib/site-packages/langchain_community/retrievers/zep.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from zep_python.memory import MemorySearchResult
|
||||
|
||||
|
||||
class SearchScope(str, Enum):
    """Part of the chat history a search runs against: messages or summaries."""

    messages = "messages"
    """Run the search over chat history messages."""

    summary = "summary"
    """Run the search over chat history summaries."""
|
||||
|
||||
|
||||
class SearchType(str, Enum):
    """Kinds of search that can be requested."""

    similarity = "similarity"
    """Plain similarity search."""

    mmr = "mmr"
    """Similarity search reranked with Maximal Marginal Relevance."""
|
||||
|
||||
|
||||
class ZepRetriever(BaseRetriever):
    """`Zep` MemoryStore Retriever.

    Search your user's long-term chat history with Zep.

    Zep offers both simple semantic search and Maximal Marginal Relevance (MMR)
    reranking of search results.

    Note: You will need to provide the user's `session_id` to use this retriever.

    Args:
        url: URL of your Zep server (required)
        api_key: Your Zep API key (optional)
        session_id: Identifies your user or a user's session (required)
        top_k: Number of documents to return (optional; when omitted, ``None``
            is passed as the search limit)
        search_type: Type of search to perform (similarity / mmr) (default: similarity,
            optional)
        mmr_lambda: Lambda value for MMR search (optional; the field defaults
            to ``None``)

    Zep - Fast, scalable building blocks for LLM Apps
    =========
    Zep is an open source platform for productionizing LLM apps. Go from a prototype
    built in LangChain or LlamaIndex, or a custom app, to production in minutes without
    rewriting code.

    For server installation instructions, see:
    https://docs.getzep.com/deployment/quickstart/
    """

    zep_client: Optional[Any] = None
    """Zep client."""
    url: str
    """URL of your Zep server."""
    api_key: Optional[str] = None
    """Your Zep API key."""
    session_id: str
    """Zep session ID."""
    # An explicit default is required: a bare ``Optional[int]`` annotation is a
    # *required* field under pydantic v2, contradicting the documented contract.
    top_k: Optional[int] = None
    """Number of items to return."""
    search_scope: SearchScope = SearchScope.messages
    """Which documents to search. Messages or Summaries?"""
    search_type: SearchType = SearchType.similarity
    """Type of search to perform (similarity / mmr)"""
    mmr_lambda: Optional[float] = None
    """Lambda value for MMR search."""

    @model_validator(mode="before")
    @classmethod
    def create_client(cls, values: dict) -> Any:
        """Ensure ``values["zep_client"]`` is set before field validation.

        A client is built from ``url`` / ``api_key`` only when the caller did
        not supply ``zep_client``.

        Raises:
            ImportError: If the ``zep-python`` package is not installed.
        """
        try:
            from zep_python import ZepClient
        except ImportError as exc:
            raise ImportError(
                "Could not import zep-python package. "
                "Please install it with `pip install zep-python`."
            ) from exc
        # Construct lazily so a caller-supplied client does not trigger the
        # creation of a second, unused one.
        if "zep_client" not in values:
            values["zep_client"] = ZepClient(
                base_url=values["url"], api_key=values.get("api_key")
            )
        return values

    def _messages_search_result_to_doc(
        self, results: List[MemorySearchResult]
    ) -> List[Document]:
        """Convert message search results to documents; skip empty messages."""
        documents = []
        for r in results:
            if not r.message:
                continue
            # Work on a copy so the caller's search results are not mutated.
            message = dict(r.message)
            content = message.pop("content")
            documents.append(
                Document(
                    page_content=content,
                    metadata={"score": r.dist, **message},
                )
            )
        return documents

    def _summary_search_result_to_doc(
        self, results: List[MemorySearchResult]
    ) -> List[Document]:
        """Convert summary search results to documents; skip empty summaries."""
        return [
            Document(
                page_content=r.summary.content,
                metadata={
                    "score": r.dist,
                    "uuid": r.summary.uuid,
                    "created_at": r.summary.created_at,
                    "token_count": r.summary.token_count,
                },
            )
            for r in results
            if r.summary
        ]

    def _build_payload(self, query: str, metadata: Optional[Dict[str, Any]]) -> Any:
        """Build the search payload shared by the sync and async code paths."""
        from zep_python.memory import MemorySearchPayload

        return MemorySearchPayload(
            text=query,
            metadata=metadata,
            search_scope=self.search_scope,
            search_type=self.search_type,
            mmr_lambda=self.mmr_lambda,
        )

    def _to_documents(self, results: List[MemorySearchResult]) -> List[Document]:
        """Pick the converter that matches the configured search scope."""
        if self.search_scope == SearchScope.summary:
            return self._summary_search_result_to_doc(results)
        return self._messages_search_result_to_doc(results)

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> List[Document]:
        """Search the session's memory and return the hits as documents.

        Args:
            query: Text to search for.
            run_manager: Callback manager (not used by this method).
            metadata: Optional metadata passed through in the search payload.

        Raises:
            RuntimeError: If no Zep client is set.
        """
        if not self.zep_client:
            raise RuntimeError("Zep client not initialized.")

        results: List[MemorySearchResult] = self.zep_client.memory.search_memory(
            self.session_id, self._build_payload(query, metadata), limit=self.top_k
        )
        return self._to_documents(results)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> List[Document]:
        """Async variant of ``_get_relevant_documents``.

        Raises:
            RuntimeError: If no Zep client is set.
        """
        if not self.zep_client:
            raise RuntimeError("Zep client not initialized.")

        results: List[MemorySearchResult] = await self.zep_client.memory.asearch_memory(
            self.session_id, self._build_payload(query, metadata), limit=self.top_k
        )
        return self._to_documents(results)
|
||||
@@ -0,0 +1,163 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from zep_cloud import MemorySearchResult, SearchScope, SearchType
|
||||
from zep_cloud.client import AsyncZep, Zep
|
||||
|
||||
|
||||
class ZepCloudRetriever(BaseRetriever):
    """`Zep Cloud` MemoryStore Retriever.

    Search your user's long-term chat history with Zep.

    Zep offers both simple semantic search and Maximal Marginal Relevance (MMR)
    reranking of search results.

    Note: You will need to provide the user's `session_id` to use this retriever.

    Args:
        api_key: Your Zep API key
        session_id: Identifies your user or a user's session (required)
        top_k: Number of documents to return (optional; when omitted, ``None``
            is passed as the search limit)
        search_type: Type of search to perform (similarity / mmr)
            (default: similarity, optional)
        mmr_lambda: Lambda value for MMR search (optional; the field defaults
            to ``None``)

    Zep - Recall, understand, and extract data from chat histories.
    Power personalized AI experiences.
    =========
    Zep is a long-term memory service for AI Assistant apps.
    With Zep, you can provide AI assistants with the ability
    to recall past conversations,
    no matter how distant, while also reducing hallucinations, latency, and cost.

    see Zep Cloud Docs: https://help.getzep.com
    """

    api_key: str
    """Your Zep API key."""
    zep_client: Zep
    """Zep client used for making API requests."""
    zep_client_async: AsyncZep
    """Async Zep client used for making API requests."""
    session_id: str
    """Zep session ID."""
    # An explicit default is required: a bare ``Optional[int]`` annotation is a
    # *required* field under pydantic v2, contradicting the documented contract.
    top_k: Optional[int] = None
    """Number of items to return."""
    search_scope: SearchScope = "messages"
    """Which documents to search. Messages or Summaries?"""
    search_type: SearchType = "similarity"
    """Type of search to perform (similarity / mmr)"""
    mmr_lambda: Optional[float] = None
    """Lambda value for MMR search."""

    @model_validator(mode="before")
    @classmethod
    def create_client(cls, values: dict) -> Any:
        """Create the sync and async Zep clients from ``api_key``.

        Raises:
            ImportError: If the ``zep-cloud`` package is not installed.
            ValueError: If no ``api_key`` was provided.
        """
        try:
            from zep_cloud.client import AsyncZep, Zep
        except ImportError as exc:
            raise ImportError(
                "Could not import zep-cloud package. "
                "Please install it with `pip install zep-cloud`."
            ) from exc
        api_key = values.get("api_key")
        if api_key is None:
            raise ValueError("Zep API key is required.")
        values["zep_client"] = Zep(api_key=api_key)
        values["zep_client_async"] = AsyncZep(api_key=api_key)
        return values

    def _messages_search_result_to_doc(
        self, results: Optional[List[MemorySearchResult]]
    ) -> List[Document]:
        """Convert message search results to documents; skip empty messages."""
        return [
            Document(
                page_content=str(r.message.content),
                metadata={
                    "score": r.score,
                    "uuid": r.message.uuid_,
                    "created_at": r.message.created_at,
                    "token_count": r.message.token_count,
                    "role": r.message.role or r.message.role_type,
                },
            )
            for r in results or []
            if r.message
        ]

    def _summary_search_result_to_doc(
        self, results: Optional[List[MemorySearchResult]]
    ) -> List[Document]:
        """Convert summary search results to documents; skip empty summaries."""
        return [
            Document(
                page_content=str(r.summary.content),
                metadata={
                    "score": r.score,
                    "uuid": r.summary.uuid_,
                    "created_at": r.summary.created_at,
                    "token_count": r.summary.token_count,
                },
            )
            # Same ``or []`` guard as the messages converter, so a ``None``
            # result yields an empty list instead of a TypeError.
            for r in results or []
            if r.summary
        ]

    def _search_kwargs(
        self, query: str, metadata: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Keyword arguments shared by the sync and async search calls."""
        return {
            "text": query,
            "metadata": metadata,
            "search_scope": self.search_scope,
            "search_type": self.search_type,
            "mmr_lambda": self.mmr_lambda,
            "limit": self.top_k,
        }

    def _to_documents(
        self, results: Optional[List[MemorySearchResult]]
    ) -> List[Document]:
        """Pick the converter that matches the configured search scope."""
        if self.search_scope == "summary":
            return self._summary_search_result_to_doc(results)
        return self._messages_search_result_to_doc(results)

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> List[Document]:
        """Search the session's memory and return the hits as documents.

        Args:
            query: Text to search for.
            run_manager: Callback manager (not used by this method).
            metadata: Optional metadata passed through to the search call.

        Raises:
            RuntimeError: If the sync Zep client is not set.
        """
        if not self.zep_client:
            raise RuntimeError("Zep client not initialized.")

        results = self.zep_client.memory.search(
            self.session_id, **self._search_kwargs(query, metadata)
        )
        return self._to_documents(results)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> List[Document]:
        """Async variant of ``_get_relevant_documents`` using the async client.

        Raises:
            RuntimeError: If the async Zep client is not set.
        """
        if not self.zep_client_async:
            raise RuntimeError("Zep client not initialized.")

        results = await self.zep_client_async.memory.search(
            self.session_id, **self._search_kwargs(query, metadata)
        )
        return self._to_documents(results)
|
||||
@@ -0,0 +1,87 @@
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import model_validator
|
||||
|
||||
from langchain_community.vectorstores.zilliz import Zilliz
|
||||
|
||||
# TODO: Update to ZillizClient + Hybrid Search when available
|
||||
|
||||
|
||||
class ZillizRetriever(BaseRetriever):
    """`Zilliz API` retriever."""

    embedding_function: Embeddings
    """The underlying embedding function from which documents will be retrieved."""
    collection_name: str = "LangChainCollection"
    """The name of the collection in Zilliz."""
    connection_args: Optional[Dict[str, Any]] = None
    """The connection arguments for the Zilliz client."""
    consistency_level: str = "Session"
    """The consistency level for the Zilliz client."""
    search_params: Optional[dict] = None
    """The search parameters for the Zilliz client."""
    store: Zilliz
    """The underlying Zilliz store."""
    retriever: BaseRetriever
    """The underlying retriever."""

    @model_validator(mode="before")
    @classmethod
    def create_client(cls, values: dict) -> Any:
        """Create the Zilliz store and its retriever from the raw input values.

        This runs before field validation, so field defaults have not been
        applied yet: optional settings are read with ``.get`` and fall back to
        the same defaults the fields declare, instead of raising ``KeyError``
        when the caller omits them.
        """
        values["store"] = Zilliz(
            values["embedding_function"],
            values.get("collection_name", "LangChainCollection"),
            values.get("connection_args"),
            values.get("consistency_level", "Session"),
        )
        values["retriever"] = values["store"].as_retriever(
            search_kwargs={"param": values.get("search_params")}
        )
        return values

    def add_texts(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> None:
        """Add text to the Zilliz store

        Args:
            texts (List[str]): The text
            metadatas (List[dict]): Metadata dicts, must line up with existing store
        """
        self.store.add_texts(texts, metadatas)

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Delegate the query to the retriever built on the Zilliz store.

        Callbacks are propagated through ``config`` rather than a
        ``run_manager`` keyword, which ``invoke`` does not accept as such.
        """
        return self.retriever.invoke(
            query, config={"callbacks": run_manager.get_child()}, **kwargs
        )
|
||||
|
||||
|
||||
def ZillizRetreiver(*args: Any, **kwargs: Any) -> ZillizRetriever:
    """Deprecated ZillizRetreiver.

    Please use ZillizRetriever ('i' before 'e') instead.

    Args:
        *args: Positional arguments forwarded to ``ZillizRetriever``.
        **kwargs: Keyword arguments forwarded to ``ZillizRetriever``.

    Returns:
        ZillizRetriever
    """
    warnings.warn(
        "ZillizRetreiver will be deprecated in the future. "
        "Please use ZillizRetriever ('i' before 'e') instead.",
        DeprecationWarning,
        # Attribute the warning to the caller's line, not to this shim.
        stacklevel=2,
    )
    return ZillizRetriever(*args, **kwargs)
|
||||
Reference in New Issue
Block a user