initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,253 @@
"""**Retriever** class returns Documents given a text **query**.
It is more general than a vector store. A retriever does not need to be able to
store documents, only to return (or retrieve) it. Vector stores can be used as
the backbone of a retriever, but there are other types of retrievers as well.
**Class hierarchy:**
.. code-block::
BaseRetriever --> <name>Retriever # Examples: ArxivRetriever, MergerRetriever
**Main helpers:**
.. code-block::
Document, Serializable, Callbacks,
CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun
"""
import importlib
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from langchain_community.retrievers.arcee import (
ArceeRetriever,
)
from langchain_community.retrievers.arxiv import (
ArxivRetriever,
)
from langchain_community.retrievers.asknews import (
AskNewsRetriever,
)
from langchain_community.retrievers.azure_ai_search import (
AzureAISearchRetriever,
AzureCognitiveSearchRetriever,
)
from langchain_community.retrievers.bedrock import (
AmazonKnowledgeBasesRetriever,
)
from langchain_community.retrievers.bm25 import (
BM25Retriever,
)
from langchain_community.retrievers.breebs import (
BreebsRetriever,
)
from langchain_community.retrievers.chaindesk import (
ChaindeskRetriever,
)
from langchain_community.retrievers.chatgpt_plugin_retriever import (
ChatGPTPluginRetriever,
)
from langchain_community.retrievers.cohere_rag_retriever import (
CohereRagRetriever,
)
from langchain_community.retrievers.docarray import (
DocArrayRetriever,
)
from langchain_community.retrievers.dria_index import (
DriaRetriever,
)
from langchain_community.retrievers.elastic_search_bm25 import (
ElasticSearchBM25Retriever,
)
from langchain_community.retrievers.embedchain import (
EmbedchainRetriever,
)
from langchain_community.retrievers.google_cloud_documentai_warehouse import (
GoogleDocumentAIWarehouseRetriever,
)
from langchain_community.retrievers.google_vertex_ai_search import (
GoogleCloudEnterpriseSearchRetriever,
GoogleVertexAIMultiTurnSearchRetriever,
GoogleVertexAISearchRetriever,
)
from langchain_community.retrievers.kay import (
KayAiRetriever,
)
from langchain_community.retrievers.kendra import (
AmazonKendraRetriever,
)
from langchain_community.retrievers.knn import (
KNNRetriever,
)
from langchain_community.retrievers.llama_index import (
LlamaIndexGraphRetriever,
LlamaIndexRetriever,
)
from langchain_community.retrievers.metal import (
MetalRetriever,
)
from langchain_community.retrievers.milvus import (
MilvusRetriever,
)
from langchain_community.retrievers.nanopq import NanoPQRetriever
from langchain_community.retrievers.needle import NeedleRetriever
from langchain_community.retrievers.outline import (
OutlineRetriever,
)
from langchain_community.retrievers.pinecone_hybrid_search import (
PineconeHybridSearchRetriever,
)
from langchain_community.retrievers.pubmed import (
PubMedRetriever,
)
from langchain_community.retrievers.qdrant_sparse_vector_retriever import (
QdrantSparseVectorRetriever,
)
from langchain_community.retrievers.rememberizer import (
RememberizerRetriever,
)
from langchain_community.retrievers.remote_retriever import (
RemoteLangChainRetriever,
)
from langchain_community.retrievers.svm import (
SVMRetriever,
)
from langchain_community.retrievers.tavily_search_api import (
TavilySearchAPIRetriever,
)
from langchain_community.retrievers.tfidf import (
TFIDFRetriever,
)
from langchain_community.retrievers.thirdai_neuraldb import NeuralDBRetriever
from langchain_community.retrievers.vespa_retriever import (
VespaRetriever,
)
from langchain_community.retrievers.weaviate_hybrid_search import (
WeaviateHybridSearchRetriever,
)
from langchain_community.retrievers.web_research import WebResearchRetriever
from langchain_community.retrievers.wikipedia import (
WikipediaRetriever,
)
from langchain_community.retrievers.you import (
YouRetriever,
)
from langchain_community.retrievers.zep import (
ZepRetriever,
)
from langchain_community.retrievers.zep_cloud import (
ZepCloudRetriever,
)
from langchain_community.retrievers.zilliz import (
ZillizRetriever,
)
_module_lookup = {
"AmazonKendraRetriever": "langchain_community.retrievers.kendra",
"AmazonKnowledgeBasesRetriever": "langchain_community.retrievers.bedrock",
"ArceeRetriever": "langchain_community.retrievers.arcee",
"ArxivRetriever": "langchain_community.retrievers.arxiv",
"AskNewsRetriever": "langchain_community.retrievers.asknews",
"AzureAISearchRetriever": "langchain_community.retrievers.azure_ai_search",
"AzureCognitiveSearchRetriever": "langchain_community.retrievers.azure_ai_search",
"BM25Retriever": "langchain_community.retrievers.bm25",
"BreebsRetriever": "langchain_community.retrievers.breebs",
"ChaindeskRetriever": "langchain_community.retrievers.chaindesk",
"ChatGPTPluginRetriever": "langchain_community.retrievers.chatgpt_plugin_retriever",
"CohereRagRetriever": "langchain_community.retrievers.cohere_rag_retriever",
"DocArrayRetriever": "langchain_community.retrievers.docarray",
"DriaRetriever": "langchain_community.retrievers.dria_index",
"ElasticSearchBM25Retriever": "langchain_community.retrievers.elastic_search_bm25",
"EmbedchainRetriever": "langchain_community.retrievers.embedchain",
"GoogleCloudEnterpriseSearchRetriever": "langchain_community.retrievers.google_vertex_ai_search", # noqa: E501
"GoogleDocumentAIWarehouseRetriever": "langchain_community.retrievers.google_cloud_documentai_warehouse", # noqa: E501
"GoogleVertexAIMultiTurnSearchRetriever": "langchain_community.retrievers.google_vertex_ai_search", # noqa: E501
"GoogleVertexAISearchRetriever": "langchain_community.retrievers.google_vertex_ai_search", # noqa: E501
"KNNRetriever": "langchain_community.retrievers.knn",
"KayAiRetriever": "langchain_community.retrievers.kay",
"LlamaIndexGraphRetriever": "langchain_community.retrievers.llama_index",
"LlamaIndexRetriever": "langchain_community.retrievers.llama_index",
"MetalRetriever": "langchain_community.retrievers.metal",
"MilvusRetriever": "langchain_community.retrievers.milvus",
"NanoPQRetriever": "langchain_community.retrievers.nanopq",
"NeedleRetriever": "langchain_community.retrievers.needle",
"OutlineRetriever": "langchain_community.retrievers.outline",
"PineconeHybridSearchRetriever": "langchain_community.retrievers.pinecone_hybrid_search", # noqa: E501
"PubMedRetriever": "langchain_community.retrievers.pubmed",
"QdrantSparseVectorRetriever": "langchain_community.retrievers.qdrant_sparse_vector_retriever", # noqa: E501
"RememberizerRetriever": "langchain_community.retrievers.rememberizer",
"RemoteLangChainRetriever": "langchain_community.retrievers.remote_retriever",
"SVMRetriever": "langchain_community.retrievers.svm",
"TFIDFRetriever": "langchain_community.retrievers.tfidf",
"TavilySearchAPIRetriever": "langchain_community.retrievers.tavily_search_api",
"VespaRetriever": "langchain_community.retrievers.vespa_retriever",
"WeaviateHybridSearchRetriever": "langchain_community.retrievers.weaviate_hybrid_search", # noqa: E501
"WebResearchRetriever": "langchain_community.retrievers.web_research",
"WikipediaRetriever": "langchain_community.retrievers.wikipedia",
"YouRetriever": "langchain_community.retrievers.you",
"ZepRetriever": "langchain_community.retrievers.zep",
"ZepCloudRetriever": "langchain_community.retrievers.zep_cloud",
"ZillizRetriever": "langchain_community.retrievers.zilliz",
"NeuralDBRetriever": "langchain_community.retrievers.thirdai_neuraldb",
}
def __getattr__(name: str) -> Any:
if name in _module_lookup:
module = importlib.import_module(_module_lookup[name])
return getattr(module, name)
raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = [
"AmazonKendraRetriever",
"AmazonKnowledgeBasesRetriever",
"ArceeRetriever",
"ArxivRetriever",
"AskNewsRetriever",
"AzureAISearchRetriever",
"AzureCognitiveSearchRetriever",
"BM25Retriever",
"BreebsRetriever",
"ChaindeskRetriever",
"ChatGPTPluginRetriever",
"CohereRagRetriever",
"DocArrayRetriever",
"DriaRetriever",
"ElasticSearchBM25Retriever",
"EmbedchainRetriever",
"GoogleCloudEnterpriseSearchRetriever",
"GoogleDocumentAIWarehouseRetriever",
"GoogleVertexAIMultiTurnSearchRetriever",
"GoogleVertexAISearchRetriever",
"KayAiRetriever",
"KNNRetriever",
"LlamaIndexGraphRetriever",
"LlamaIndexRetriever",
"MetalRetriever",
"MilvusRetriever",
"NanoPQRetriever",
"NeedleRetriever",
"NeuralDBRetriever",
"OutlineRetriever",
"PineconeHybridSearchRetriever",
"PubMedRetriever",
"QdrantSparseVectorRetriever",
"RememberizerRetriever",
"RemoteLangChainRetriever",
"SVMRetriever",
"TavilySearchAPIRetriever",
"TFIDFRetriever",
"VespaRetriever",
"WeaviateHybridSearchRetriever",
"WebResearchRetriever",
"WikipediaRetriever",
"YouRetriever",
"ZepRetriever",
"ZepCloudRetriever",
"ZillizRetriever",
]

View File

@@ -0,0 +1,137 @@
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from pydantic import ConfigDict, SecretStr
from langchain_community.utilities.arcee import ArceeWrapper, DALMFilter
class ArceeRetriever(BaseRetriever):
"""Arcee Domain Adapted Language Models (DALMs) retriever.
To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key,
or pass ``arcee_api_key`` as a named parameter.
Example:
.. code-block:: python
from langchain_community.retrievers import ArceeRetriever
retriever = ArceeRetriever(
model="DALM-PubMed",
arcee_api_key="ARCEE-API-KEY"
)
documents = retriever.invoke("AI-driven music therapy")
"""
_client: Optional[ArceeWrapper] = None #: :meta private:
"""Arcee client."""
arcee_api_key: SecretStr
"""Arcee API Key"""
model: str
"""Arcee DALM name"""
arcee_api_url: str = "https://api.arcee.ai"
"""Arcee API URL"""
arcee_api_version: str = "v2"
"""Arcee API Version"""
arcee_app_url: str = "https://app.arcee.ai"
"""Arcee App URL"""
model_kwargs: Optional[Dict[str, Any]] = None
"""Keyword arguments to pass to the model."""
model_config = ConfigDict(
extra="forbid",
)
def __init__(self, **data: Any) -> None:
"""Initializes private fields."""
super().__init__(**data)
self._client = ArceeWrapper(
arcee_api_key=self.arcee_api_key.get_secret_value(),
arcee_api_url=self.arcee_api_url,
arcee_api_version=self.arcee_api_version,
model_kwargs=self.model_kwargs,
model_name=self.model,
)
self._client.validate_model_training_status()
@pre_init
def validate_environments(cls, values: Dict) -> Dict:
"""Validate Arcee environment variables."""
# validate env vars
values["arcee_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"arcee_api_key",
"ARCEE_API_KEY",
)
)
values["arcee_api_url"] = get_from_dict_or_env(
values,
"arcee_api_url",
"ARCEE_API_URL",
)
values["arcee_app_url"] = get_from_dict_or_env(
values,
"arcee_app_url",
"ARCEE_APP_URL",
)
values["arcee_api_version"] = get_from_dict_or_env(
values,
"arcee_api_version",
"ARCEE_API_VERSION",
)
# validate model kwargs
if values["model_kwargs"]:
kw = values["model_kwargs"]
# validate size
if kw.get("size") is not None:
if not kw.get("size") >= 0:
raise ValueError("`size` must not be negative.")
# validate filters
if kw.get("filters") is not None:
if not isinstance(kw.get("filters"), List):
raise ValueError("`filters` must be a list.")
for f in kw.get("filters"):
DALMFilter(**f)
return values
def _get_relevant_documents(
self, query: str, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
"""Retrieve {size} contexts with your retriever for a given query
Args:
query: Query to submit to the model
size: The max number of context results to retrieve.
Defaults to 3. (Can be less if filters are provided).
filters: Filters to apply to the context dataset.
"""
try:
if not self._client:
raise ValueError("Client is not initialized.")
return self._client.retrieve(query=query, **kwargs)
except Exception as e:
raise ValueError(f"Error while retrieving documents: {e}") from e

View File

@@ -0,0 +1,92 @@
from typing import List
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_community.utilities.arxiv import ArxivAPIWrapper
class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
"""`Arxiv` retriever.
Setup:
Install ``arxiv``:
.. code-block:: bash
pip install -U arxiv
Key init args:
load_max_docs: int
maximum number of documents to load
get_ful_documents: bool
whether to return full document text or snippets
Instantiate:
.. code-block:: python
from langchain_community.retrievers import ArxivRetriever
retriever = ArxivRetriever(
load_max_docs=2,
get_ful_documents=True,
)
Usage:
.. code-block:: python
docs = retriever.invoke("What is the ImageBind model?")
docs[0].metadata
.. code-block:: none
{'Entry ID': 'http://arxiv.org/abs/2305.05665v2',
'Published': datetime.date(2023, 5, 31),
'Title': 'ImageBind: One Embedding Space To Bind Them All',
'Authors': 'Rohit Girdhar, Alaaeldin El-Nouby, Zhuang Liu, Mannat Singh, Kalyan Vasudev Alwala, Armand Joulin, Ishan Misra'}
Use within a chain:
.. code-block:: python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
prompt = ChatPromptTemplate.from_template(
\"\"\"Answer the question based only on the context provided.
Context: {context}
Question: {question}\"\"\"
)
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
def format_docs(docs):
return "\\n\\n".join(doc.page_content for doc in docs)
chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
chain.invoke("What is the ImageBind model?")
.. code-block:: none
'The ImageBind model is an approach to learn a joint embedding across six different modalities - images, text, audio, depth, thermal, and IMU data...'
""" # noqa: E501
get_full_documents: bool = False
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
if self.get_full_documents:
return self.load(query=query)
else:
return self.get_summaries_as_docs(query)

View File

@@ -0,0 +1,146 @@
import os
import re
from typing import Any, Dict, List, Literal, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class AskNewsRetriever(BaseRetriever):
"""AskNews retriever."""
k: int = 10
offset: int = 0
start_timestamp: Optional[int] = None
end_timestamp: Optional[int] = None
method: Literal["nl", "kw"] = "nl"
categories: List[
Literal[
"All",
"Business",
"Crime",
"Politics",
"Science",
"Sports",
"Technology",
"Military",
"Health",
"Entertainment",
"Finance",
"Culture",
"Climate",
"Environment",
"World",
]
] = ["All"]
historical: bool = False
similarity_score_threshold: float = 0.5
kwargs: Optional[Dict[str, Any]] = {}
client_id: Optional[str] = None
client_secret: Optional[str] = None
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant to a query.
Args:
query: String to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
try:
from asknews_sdk import AskNewsSDK
except ImportError:
raise ImportError(
"AskNews python package not found. "
"Please install it with `pip install asknews`."
)
an_client = AskNewsSDK(
client_id=self.client_id or os.environ["ASKNEWS_CLIENT_ID"],
client_secret=self.client_secret or os.environ["ASKNEWS_CLIENT_SECRET"],
scopes=["news"],
)
response = an_client.news.search_news(
query=query,
n_articles=self.k,
start_timestamp=self.start_timestamp,
end_timestamp=self.end_timestamp,
method=self.method,
categories=self.categories,
historical=self.historical,
similarity_score_threshold=self.similarity_score_threshold,
offset=self.offset,
doc_start_delimiter="<doc>",
doc_end_delimiter="</doc>",
return_type="both",
**self.kwargs,
)
return self._extract_documents(response)
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: String to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
try:
from asknews_sdk import AsyncAskNewsSDK
except ImportError:
raise ImportError(
"AskNews python package not found. "
"Please install it with `pip install asknews`."
)
an_client = AsyncAskNewsSDK(
client_id=self.client_id or os.environ["ASKNEWS_CLIENT_ID"],
client_secret=self.client_secret or os.environ["ASKNEWS_CLIENT_SECRET"],
scopes=["news"],
)
response = await an_client.news.search_news(
query=query,
n_articles=self.k,
start_timestamp=self.start_timestamp,
end_timestamp=self.end_timestamp,
method=self.method,
categories=self.categories,
historical=self.historical,
similarity_score_threshold=self.similarity_score_threshold,
offset=self.offset,
return_type="both",
doc_start_delimiter="<doc>",
doc_end_delimiter="</doc>",
**self.kwargs,
)
return self._extract_documents(response)
def _extract_documents(self, response: Any) -> List[Document]:
"""Extract documents from an api response."""
from asknews_sdk.dto.news import SearchResponse
sr: SearchResponse = response
matches = re.findall(r"<doc>(.*?)</doc>", sr.as_string, re.DOTALL)
docs = [
Document(
page_content=matches[i].strip(),
metadata={
"title": sr.as_dicts[i].title,
"source": str(sr.as_dicts[i].article_url)
if sr.as_dicts[i].article_url
else None,
"images": sr.as_dicts[i].image_url,
},
)
for i in range(len(matches))
]
return docs

View File

@@ -0,0 +1,226 @@
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional
import aiohttp
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.utils import get_from_dict_or_env, get_from_env
from pydantic import ConfigDict, model_validator
DEFAULT_URL_SUFFIX = "search.windows.net"
"""Default URL Suffix for endpoint connection - commercial cloud"""
class AzureAISearchRetriever(BaseRetriever):
"""`Azure AI Search` service retriever.
Setup:
See here for more detail: https://python.langchain.com/docs/integrations/retrievers/azure_ai_search/
We will need to install the below dependencies and set the required
environment variables:
.. code-block:: bash
pip install -U langchain-community azure-identity azure-search-documents
export AZURE_AI_SEARCH_SERVICE_NAME="<YOUR_SEARCH_SERVICE_NAME>"
export AZURE_AI_SEARCH_INDEX_NAME="<YOUR_SEARCH_INDEX_NAME>"
export AZURE_AI_SEARCH_API_KEY="<YOUR_API_KEY>"
or
export AZURE_AI_SEARCH_BEARER_TOKEN="<YOUR_BEARER_TOKEN>"
Key init args:
content_key: str
top_k: int
index_name: str
Instantiate:
.. code-block:: python
from langchain_community.retrievers import AzureAISearchRetriever
retriever = AzureAISearchRetriever(
content_key="content", top_k=1, index_name="langchain-vector-demo"
)
Usage:
.. code-block:: python
retriever.invoke("here is my unstructured query string")
Use within a chain:
.. code-block:: python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import AzureChatOpenAI
prompt = ChatPromptTemplate.from_template(
\"\"\"Answer the question based only on the context provided.
Context: {context}
Question: {question}\"\"\"
)
llm = AzureChatOpenAI(azure_deployment="gpt-35-turbo")
def format_docs(docs):
return "\\n\\n".join(doc.page_content for doc in docs)
chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
chain.invoke("...")
""" # noqa: E501
service_name: str = ""
"""Name of Azure AI Search service"""
index_name: str = ""
"""Name of Index inside Azure AI Search service"""
api_key: str = ""
"""API Key. Both Admin and Query keys work, but for reading data it's
recommended to use a Query key."""
api_version: str = "2023-11-01"
"""API version"""
aiosession: Optional[aiohttp.ClientSession] = None
"""ClientSession, in case we want to reuse connection for better performance."""
azure_ad_token: str = ""
"""Your Azure Active Directory token.
Automatically inferred from env var `AZURE_AI_SEARCH_AD_TOKEN` if not provided.
For more:
https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
"""
content_key: str = "content"
"""Key in a retrieved result to set as the Document page_content."""
top_k: Optional[int] = None
"""Number of results to retrieve. Set to None to retrieve all results."""
filter: Optional[str] = None
"""OData $filter expression to apply to the search query."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
"""Validate that service name, index name and api key exists in environment."""
values["service_name"] = get_from_dict_or_env(
values, "service_name", "AZURE_AI_SEARCH_SERVICE_NAME"
)
values["index_name"] = get_from_dict_or_env(
values, "index_name", "AZURE_AI_SEARCH_INDEX_NAME"
)
values["azure_ad_token"] = get_from_dict_or_env(
values, "azure_ad_token", "AZURE_AI_SEARCH_AD_TOKEN", default=""
)
values["api_key"] = get_from_dict_or_env(
values, "api_key", "AZURE_AI_SEARCH_API_KEY", default=""
)
if values["azure_ad_token"] == "" and values["api_key"] == "":
raise ValueError(
"Missing credentials. Please pass one of `api_key`, `azure_ad_token`, "
"or the `AZURE_AI_SEARCH_API_KEY` or `AZURE_AI_SEARCH_AD_TOKEN` "
"environment variables."
)
return values
def _build_search_url(self, query: str) -> str:
url_suffix = get_from_env("", "AZURE_AI_SEARCH_URL_SUFFIX", DEFAULT_URL_SUFFIX)
if url_suffix in self.service_name and "https://" in self.service_name:
base_url = f"{self.service_name}/"
elif url_suffix in self.service_name and "https://" not in self.service_name:
base_url = f"https://{self.service_name}/"
elif url_suffix not in self.service_name and "https://" in self.service_name:
base_url = f"{self.service_name}.{url_suffix}/"
elif (
url_suffix not in self.service_name and "https://" not in self.service_name
):
base_url = f"https://{self.service_name}.{url_suffix}/"
else:
# pass to Azure to throw a specific error
base_url = self.service_name
endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}"
top_param = f"&$top={self.top_k}" if self.top_k else ""
filter_param = f"&$filter={self.filter}" if self.filter else ""
return base_url + endpoint_path + f"&search={query}" + top_param + filter_param
@property
def _headers(self) -> Dict[str, str]:
headers = {
"Content-Type": "application/json",
}
if self.azure_ad_token:
headers["Authorization"] = f"Bearer {self.azure_ad_token}"
elif self.api_key:
headers["api-key"] = f"{self.api_key}"
return headers
def _search(self, query: str) -> List[dict]:
search_url = self._build_search_url(query)
response = requests.get(search_url, headers=self._headers)
if response.status_code != 200:
raise Exception(f"Error in search request: {response}")
return json.loads(response.text)["value"]
async def _asearch(self, query: str) -> List[dict]:
search_url = self._build_search_url(query)
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.get(search_url, headers=self._headers) as response:
response_json = await response.json()
else:
async with self.aiosession.get(
search_url, headers=self._headers
) as response:
response_json = await response.json()
return response_json["value"]
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
search_results = self._search(query)
return [
Document(page_content=result.pop(self.content_key), metadata=result)
for result in search_results
]
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
search_results = await self._asearch(query)
return [
Document(page_content=result.pop(self.content_key), metadata=result)
for result in search_results
]
# For backwards compatibility
class AzureCognitiveSearchRetriever(AzureAISearchRetriever):
"""`Azure Cognitive Search` service retriever.
This version of the retriever will soon be
depreciated. Please switch to AzureAISearchRetriever
"""

View File

@@ -0,0 +1,186 @@
from typing import Any, Dict, List, Optional
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import BaseModel, model_validator
class VectorSearchConfig(BaseModel, extra="allow"):
"""Configuration for vector search."""
numberOfResults: int = 4
class RetrievalConfig(BaseModel, extra="allow"):
"""Configuration for retrieval."""
vectorSearchConfiguration: VectorSearchConfig
@deprecated(
since="0.3.16",
removal="1.0",
alternative_import="langchain_aws.AmazonKnowledgeBasesRetriever",
)
class AmazonKnowledgeBasesRetriever(BaseRetriever):
"""Amazon Bedrock Knowledge Bases retriever.
See https://aws.amazon.com/bedrock/knowledge-bases for more info.
Setup:
Install ``langchain-aws``:
.. code-block:: bash
pip install -U langchain-aws
Key init args:
knowledge_base_id: Knowledge Base ID.
region_name: The aws region e.g., `us-west-2`.
Fallback to AWS_DEFAULT_REGION env variable or region specified in
~/.aws/config.
credentials_profile_name: The name of the profile in the ~/.aws/credentials
or ~/.aws/config files, which has either access keys or role information
specified. If not specified, the default credential profile or, if on an
EC2 instance, credentials from IMDS will be used.
client: boto3 client for bedrock agent runtime.
retrieval_config: Configuration for retrieval.
Instantiate:
.. code-block:: python
from langchain_community.retrievers import AmazonKnowledgeBasesRetriever
retriever = AmazonKnowledgeBasesRetriever(
knowledge_base_id="<knowledge-base-id>",
retrieval_config={
"vectorSearchConfiguration": {
"numberOfResults": 4
}
},
)
Usage:
.. code-block:: python
query = "..."
retriever.invoke(query)
Use within a chain:
.. code-block:: python
from langchain_aws import ChatBedrockConverse
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
prompt = ChatPromptTemplate.from_template(
\"\"\"Answer the question based only on the context provided.
Context: {context}
Question: {question}\"\"\"
)
llm = ChatBedrockConverse(
model_id="anthropic.claude-3-5-sonnet-20240620-v1:0"
)
def format_docs(docs):
return "\\n\\n".join(doc.page_content for doc in docs)
chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
chain.invoke("...")
""" # noqa: E501
knowledge_base_id: str
region_name: Optional[str] = None
credentials_profile_name: Optional[str] = None
endpoint_url: Optional[str] = None
client: Any
retrieval_config: RetrievalConfig
@model_validator(mode="before")
@classmethod
def create_client(cls, values: Dict[str, Any]) -> Any:
if values.get("client") is not None:
return values
try:
import boto3
from botocore.client import Config
from botocore.exceptions import UnknownServiceError
if values.get("credentials_profile_name"):
session = boto3.Session(profile_name=values["credentials_profile_name"])
else:
# use default credentials
session = boto3.Session()
client_params = {
"config": Config(
connect_timeout=120, read_timeout=120, retries={"max_attempts": 0}
)
}
if values.get("region_name"):
client_params["region_name"] = values["region_name"]
if values.get("endpoint_url"):
client_params["endpoint_url"] = values["endpoint_url"]
values["client"] = session.client("bedrock-agent-runtime", **client_params)
return values
except ImportError:
raise ImportError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
except UnknownServiceError as e:
raise ImportError(
"Ensure that you have installed the latest boto3 package "
"that contains the API for `bedrock-runtime-agent`."
) from e
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
"profile name are valid."
) from e
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
response = self.client.retrieve(
retrievalQuery={"text": query.strip()},
knowledgeBaseId=self.knowledge_base_id,
retrievalConfiguration=self.retrieval_config.dict(),
)
results = response["retrievalResults"]
documents = []
for result in results:
content = result["content"]["text"]
result.pop("content")
if "score" not in result:
result["score"] = 0
if "metadata" in result:
result["source_metadata"] = result.pop("metadata")
documents.append(
Document(
page_content=content,
metadata=result,
)
)
return documents

View File

@@ -0,0 +1,116 @@
from __future__ import annotations
from typing import Any, Callable, Dict, Iterable, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import ConfigDict, Field
def default_preprocessing_func(text: str) -> List[str]:
return text.split()
class BM25Retriever(BaseRetriever):
"""`BM25` retriever without Elasticsearch."""
vectorizer: Any = None
""" BM25 vectorizer."""
docs: List[Document] = Field(repr=False)
""" List of documents."""
k: int = 4
""" Number of documents to return."""
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func
""" Preprocessing function to use on the text before BM25 vectorization."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@classmethod
def from_texts(
cls,
texts: Iterable[str],
metadatas: Optional[Iterable[dict]] = None,
ids: Optional[Iterable[str]] = None,
bm25_params: Optional[Dict[str, Any]] = None,
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
**kwargs: Any,
) -> BM25Retriever:
"""
Create a BM25Retriever from a list of texts.
Args:
texts: A list of texts to vectorize.
metadatas: A list of metadata dicts to associate with each text.
ids: A list of ids to associate with each text.
bm25_params: Parameters to pass to the BM25 vectorizer.
preprocess_func: A function to preprocess each text before vectorization.
**kwargs: Any other arguments to pass to the retriever.
Returns:
A BM25Retriever instance.
"""
try:
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError(
"Could not import rank_bm25, please install with `pip install "
"rank_bm25`."
)
texts_processed = [preprocess_func(t) for t in texts]
bm25_params = bm25_params or {}
vectorizer = BM25Okapi(texts_processed, **bm25_params)
metadatas = metadatas or ({} for _ in texts)
if ids:
docs = [
Document(page_content=t, metadata=m, id=i)
for t, m, i in zip(texts, metadatas, ids)
]
else:
docs = [
Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)
]
return cls(
vectorizer=vectorizer, docs=docs, preprocess_func=preprocess_func, **kwargs
)
@classmethod
def from_documents(
cls,
documents: Iterable[Document],
*,
bm25_params: Optional[Dict[str, Any]] = None,
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
**kwargs: Any,
) -> BM25Retriever:
"""
Create a BM25Retriever from a list of Documents.
Args:
documents: A list of Documents to vectorize.
bm25_params: Parameters to pass to the BM25 vectorizer.
preprocess_func: A function to preprocess each text before vectorization.
**kwargs: Any other arguments to pass to the retriever.
Returns:
A BM25Retriever instance.
"""
texts, metadatas, ids = zip(
*((d.page_content, d.metadata, d.id) for d in documents)
)
return cls.from_texts(
texts=texts,
bm25_params=bm25_params,
metadatas=metadatas,
ids=ids,
preprocess_func=preprocess_func,
**kwargs,
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
processed_query = self.preprocess_func(query)
return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k)
return return_docs

View File

@@ -0,0 +1,49 @@
from typing import List
import requests
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents.base import Document
from langchain_core.retrievers import BaseRetriever
class BreebsRetriever(BaseRetriever):
"""A retriever class for `Breebs`.
See https://www.breebs.com/ for more info.
Args:
breeb_key: The key to trigger the breeb
(specialized knowledge pill on a specific topic).
To retrieve the list of all available Breebs : you can call https://breebs.promptbreeders.com/web/listbreebs
"""
breeb_key: str
url: str = "https://breebs.promptbreeders.com/knowledge"
def __init__(self, breeb_key: str):
super().__init__(breeb_key=breeb_key) # type: ignore[call-arg]
self.breeb_key = breeb_key
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Retrieve context for given query.
Note that for time being there is no score."""
r = requests.post(
self.url,
json={
"breeb_key": self.breeb_key,
"query": query,
},
)
if r.status_code != 200:
return []
else:
chunks = r.json()
return [
Document(
page_content=chunk["content"],
metadata={"source": chunk["source_url"], "score": 1},
)
for chunk in chunks
]

View File

@@ -0,0 +1,92 @@
from typing import Any, List, Optional
import aiohttp
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class ChaindeskRetriever(BaseRetriever):
"""`Chaindesk API` retriever."""
datastore_url: str
top_k: Optional[int]
api_key: Optional[str]
def __init__(
self,
datastore_url: str,
top_k: Optional[int] = None,
api_key: Optional[str] = None,
):
self.datastore_url = datastore_url
self.api_key = api_key
self.top_k = top_k
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
response = requests.post(
self.datastore_url,
json={
"query": query,
**({"topK": self.top_k} if self.top_k is not None else {}),
},
headers={
"Content-Type": "application/json",
**(
{"Authorization": f"Bearer {self.api_key}"}
if self.api_key is not None
else {}
),
},
)
data = response.json()
return [
Document(
page_content=r["text"],
metadata={"source": r["source"], "score": r["score"]},
)
for r in data["results"]
]
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
async with aiohttp.ClientSession() as session:
async with session.request(
"POST",
self.datastore_url,
json={
"query": query,
**({"topK": self.top_k} if self.top_k is not None else {}),
},
headers={
"Content-Type": "application/json",
**(
{"Authorization": f"Bearer {self.api_key}"}
if self.api_key is not None
else {}
),
},
) as response:
data = await response.json()
return [
Document(
page_content=r["text"],
metadata={"source": r["source"], "score": r["score"]},
)
for r in data["results"]
]

View File

@@ -0,0 +1,89 @@
from __future__ import annotations
from typing import List, Optional
import aiohttp
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import ConfigDict
class ChatGPTPluginRetriever(BaseRetriever):
"""`ChatGPT plugin` retriever."""
url: str
"""URL of the ChatGPT plugin."""
bearer_token: str
"""Bearer token for the ChatGPT plugin."""
top_k: int = 3
"""Number of documents to return."""
filter: Optional[dict] = None
"""Filter to apply to the results."""
aiosession: Optional[aiohttp.ClientSession] = None
"""Aiohttp session to use for requests."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
url, json, headers = self._create_request(query)
response = requests.post(url, json=json, headers=headers)
results = response.json()["results"][0]["results"]
docs = []
for d in results:
content = d.pop("text")
metadata = d.pop("metadata", d)
if metadata.get("source_id"):
metadata["source"] = metadata.pop("source_id")
docs.append(Document(page_content=content, metadata=metadata))
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
url, json, headers = self._create_request(query)
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=json) as response:
res = await response.json()
else:
async with self.aiosession.post(
url, headers=headers, json=json
) as response:
res = await response.json()
results = res["results"][0]["results"]
docs = []
for d in results:
content = d.pop("text")
metadata = d.pop("metadata", d)
if metadata.get("source_id"):
metadata["source"] = metadata.pop("source_id")
docs.append(Document(page_content=content, metadata=metadata))
return docs
def _create_request(self, query: str) -> tuple[str, dict, dict]:
url = f"{self.url}/query"
json = {
"queries": [
{
"query": query,
"filter": self.filter,
"top_k": self.top_k,
}
]
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.bearer_token}",
}
return url, json, headers

View File

@@ -0,0 +1,96 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.retrievers import BaseRetriever
from pydantic import ConfigDict, Field
if TYPE_CHECKING:
from langchain_core.messages import BaseMessage
def _get_docs(response: Any) -> List[Document]:
docs = (
[]
if "documents" not in response.generation_info
else [
Document(page_content=doc["snippet"], metadata=doc)
for doc in response.generation_info["documents"]
]
)
docs.append(
Document(
page_content=response.message.content,
metadata={
"type": "model_response",
"citations": response.generation_info["citations"],
"search_results": response.generation_info["search_results"],
"search_queries": response.generation_info["search_queries"],
"token_count": response.generation_info["token_count"],
},
)
)
return docs
@deprecated(
since="0.0.30",
removal="1.0",
alternative_import="langchain_cohere.CohereRagRetriever",
)
class CohereRagRetriever(BaseRetriever):
"""Cohere Chat API with RAG."""
connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}])
"""
When specified, the model's reply will be enriched with information found by
querying each of the connectors (RAG). These will be returned as langchain
documents.
Currently only accepts {"id": "web-search"}.
"""
llm: BaseChatModel
"""Cohere ChatModel to use."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
res = self.llm.generate(
messages,
connectors=self.connectors,
callbacks=run_manager.get_child(),
**kwargs,
).generations[0][0]
return _get_docs(res)
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
res = (
await self.llm.agenerate(
messages,
connectors=self.connectors,
callbacks=run_manager.get_child(),
**kwargs,
)
).generations[0][0]
return _get_docs(res)

View File

@@ -0,0 +1,74 @@
from typing import List, Optional
import aiohttp
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class DataberryRetriever(BaseRetriever):
"""`Databerry API` retriever."""
datastore_url: str
top_k: Optional[int]
api_key: Optional[str]
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
response = requests.post(
self.datastore_url,
json={
"query": query,
**({"topK": self.top_k} if self.top_k is not None else {}),
},
headers={
"Content-Type": "application/json",
**(
{"Authorization": f"Bearer {self.api_key}"}
if self.api_key is not None
else {}
),
},
)
data = response.json()
return [
Document(
page_content=r["text"],
metadata={"source": r["source"], "score": r["score"]},
)
for r in data["results"]
]
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
async with aiohttp.ClientSession() as session:
async with session.request(
"POST",
self.datastore_url,
json={
"query": query,
**({"topK": self.top_k} if self.top_k is not None else {}),
},
headers={
"Content-Type": "application/json",
**(
{"Authorization": f"Bearer {self.api_key}"}
if self.api_key is not None
else {}
),
},
) as response:
data = await response.json()
return [
Document(
page_content=r["text"],
metadata={"source": r["source"], "score": r["score"]},
)
for r in data["results"]
]

View File

@@ -0,0 +1,208 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Union
import numpy as np
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from langchain_core.utils.pydantic import get_fields
from pydantic import ConfigDict
from langchain_community.vectorstores.utils import maximal_marginal_relevance
class SearchType(str, Enum):
"""Enumerator of the types of search to perform."""
similarity = "similarity"
mmr = "mmr"
class DocArrayRetriever(BaseRetriever):
"""`DocArray Document Indices` retriever.
Currently, it supports 5 backends:
InMemoryExactNNIndex, HnswDocumentIndex, QdrantDocumentIndex,
ElasticDocIndex, and WeaviateDocumentIndex.
Args:
index: One of the above-mentioned index instances
embeddings: Embedding model to represent text as vectors
search_field: Field to consider for searching in the documents.
Should be an embedding/vector/tensor.
content_field: Field that represents the main content in your document schema.
Will be used as a `page_content`. Everything else will go into `metadata`.
search_type: Type of search to perform (similarity / mmr)
filters: Filters applied for document retrieval.
top_k: Number of documents to return
"""
index: Any = None
embeddings: Embeddings
search_field: str
content_field: str
search_type: SearchType = SearchType.similarity
top_k: int = 1
filters: Optional[Any] = None
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
"""Get documents relevant for a query.
Args:
query: string to find relevant documents for
Returns:
List of relevant documents
"""
query_emb = np.array(self.embeddings.embed_query(query))
if self.search_type == SearchType.similarity:
results = self._similarity_search(query_emb)
elif self.search_type == SearchType.mmr:
results = self._mmr_search(query_emb)
else:
raise ValueError(
f"Search type {self.search_type} does not exist. "
f"Choose either 'similarity' or 'mmr'."
)
return results
def _search(
self, query_emb: np.ndarray, top_k: int
) -> List[Union[Dict[str, Any], Any]]:
"""
Perform a search using the query embedding and return top_k documents.
Args:
query_emb: Query represented as an embedding
top_k: Number of documents to return
Returns:
A list of top_k documents matching the query
"""
from docarray.index import ElasticDocIndex, WeaviateDocumentIndex
filter_args = {}
search_field = self.search_field
if isinstance(self.index, WeaviateDocumentIndex):
filter_args["where_filter"] = self.filters
search_field = ""
elif isinstance(self.index, ElasticDocIndex):
filter_args["query"] = self.filters
else:
filter_args["filter_query"] = self.filters
if self.filters:
query = (
self.index.build_query() # get empty query object
.find(
query=query_emb, search_field=search_field
) # add vector similarity search
.filter(**filter_args) # add filter search
.build(limit=top_k) # build the query
)
# execute the combined query and return the results
docs = self.index.execute_query(query)
if hasattr(docs, "documents"):
docs = docs.documents
docs = docs[:top_k]
else:
docs = self.index.find(
query=query_emb, search_field=search_field, limit=top_k
).documents
return docs
def _similarity_search(self, query_emb: np.ndarray) -> List[Document]:
"""
Perform a similarity search.
Args:
query_emb: Query represented as an embedding
Returns:
A list of documents most similar to the query
"""
docs = self._search(query_emb=query_emb, top_k=self.top_k)
results = [self._docarray_to_langchain_doc(doc) for doc in docs]
return results
def _mmr_search(self, query_emb: np.ndarray) -> List[Document]:
"""
Perform a maximal marginal relevance (mmr) search.
Args:
query_emb: Query represented as an embedding
Returns:
A list of diverse documents related to the query
"""
docs = self._search(query_emb=query_emb, top_k=20)
mmr_selected = maximal_marginal_relevance(
query_emb,
[
doc[self.search_field]
if isinstance(doc, dict)
else getattr(doc, self.search_field)
for doc in docs
],
k=self.top_k,
)
results = [self._docarray_to_langchain_doc(docs[idx]) for idx in mmr_selected]
return results
def _docarray_to_langchain_doc(self, doc: Union[Dict[str, Any], Any]) -> Document:
"""
Convert a DocArray document (which also might be a dict)
to a langchain document format.
DocArray document can contain arbitrary fields, so the mapping is done
in the following way:
page_content <-> content_field
metadata <-> all other fields excluding
tensors and embeddings (so float, int, string)
Args:
doc: DocArray document
Returns:
Document in langchain format
Raises:
ValueError: If the document doesn't contain the content field
"""
fields = doc.keys() if isinstance(doc, dict) else get_fields(doc)
if self.content_field not in fields:
raise ValueError(
f"Document does not contain the content field - {self.content_field}."
)
lc_doc = Document(
page_content=doc[self.content_field]
if isinstance(doc, dict)
else getattr(doc, self.content_field)
)
for name in fields:
value = doc[name] if isinstance(doc, dict) else getattr(doc, name)
if (
isinstance(value, (str, int, float, bool))
and name != self.content_field
):
lc_doc.metadata[name] = value
return lc_doc

View File

@@ -0,0 +1,87 @@
"""Wrapper around Dria Retriever."""
from typing import Any, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_community.utilities import DriaAPIWrapper
class DriaRetriever(BaseRetriever):
"""`Dria` retriever using the DriaAPIWrapper."""
api_wrapper: DriaAPIWrapper
def __init__(self, api_key: str, contract_id: Optional[str] = None, **kwargs: Any):
"""
Initialize the DriaRetriever with a DriaAPIWrapper instance.
Args:
api_key: The API key for Dria.
contract_id: The contract ID of the knowledge base to interact with.
"""
api_wrapper = DriaAPIWrapper(api_key=api_key, contract_id=contract_id)
super().__init__(api_wrapper=api_wrapper, **kwargs) # type: ignore[call-arg]
def create_knowledge_base(
self,
name: str,
description: str,
category: str = "Unspecified",
embedding: str = "jina",
) -> str:
"""Create a new knowledge base in Dria.
Args:
name: The name of the knowledge base.
description: The description of the knowledge base.
category: The category of the knowledge base.
embedding: The embedding model to use for the knowledge base.
Returns:
The ID of the created knowledge base.
"""
response = self.api_wrapper.create_knowledge_base(
name, description, category, embedding
)
return response
def add_texts(
self,
texts: List,
) -> None:
"""Add texts to the Dria knowledge base.
Args:
texts: An iterable of texts and metadatas to add to the knowledge base.
Returns:
List of IDs representing the added texts.
"""
data = [{"text": text["text"], "metadata": text["metadata"]} for text in texts]
self.api_wrapper.insert_data(data)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Retrieve relevant documents from Dria based on a query.
Args:
query: The query string to search for in the knowledge base.
run_manager: Callback manager for the retriever run.
Returns:
A list of Documents containing the search results.
"""
results = self.api_wrapper.search(query)
docs = [
Document(
page_content=result["metadata"],
metadata={"id": result["id"], "score": result["score"]},
)
for result in results
]
return docs

View File

@@ -0,0 +1,137 @@
"""Wrapper around Elasticsearch vector database."""
from __future__ import annotations
import uuid
from typing import Any, Iterable, List
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class ElasticSearchBM25Retriever(BaseRetriever):
"""`Elasticsearch` retriever that uses `BM25`.
To connect to an Elasticsearch instance that requires login credentials,
including Elastic Cloud, use the Elasticsearch URL format
https://username:password@es_host:9243. For example, to connect to Elastic
Cloud, create the Elasticsearch URL with the required authentication details and
pass it to the ElasticVectorSearch constructor as the named parameter
elasticsearch_url.
You can obtain your Elastic Cloud URL and login credentials by logging in to the
Elastic Cloud console at https://cloud.elastic.co, selecting your deployment, and
navigating to the "Deployments" page.
To obtain your Elastic Cloud password for the default "elastic" user:
1. Log in to the Elastic Cloud console at https://cloud.elastic.co
2. Go to "Security" > "Users"
3. Locate the "elastic" user and click "Edit"
4. Click "Reset password"
5. Follow the prompts to reset the password
The format for Elastic Cloud URLs is
https://username:password@cluster_id.region_id.gcp.cloud.es.io:9243.
"""
client: Any
"""Elasticsearch client."""
index_name: str
"""Name of the index to use in Elasticsearch."""
@classmethod
def create(
cls, elasticsearch_url: str, index_name: str, k1: float = 2.0, b: float = 0.75
) -> ElasticSearchBM25Retriever:
"""
Create a ElasticSearchBM25Retriever from a list of texts.
Args:
elasticsearch_url: URL of the Elasticsearch instance to connect to.
index_name: Name of the index to use in Elasticsearch.
k1: BM25 parameter k1.
b: BM25 parameter b.
Returns:
"""
from elasticsearch import Elasticsearch
# Create an Elasticsearch client instance
es = Elasticsearch(elasticsearch_url)
# Define the index settings and mappings
settings = {
"analysis": {"analyzer": {"default": {"type": "standard"}}},
"similarity": {
"custom_bm25": {
"type": "BM25",
"k1": k1,
"b": b,
}
},
}
mappings = {
"properties": {
"content": {
"type": "text",
"similarity": "custom_bm25", # Use the custom BM25 similarity
}
}
}
# Create the index with the specified settings and mappings
es.indices.create(index=index_name, mappings=mappings, settings=settings)
return cls(client=es, index_name=index_name)
def add_texts(
self,
texts: Iterable[str],
refresh_indices: bool = True,
) -> List[str]:
"""Run more texts through the embeddings and add to the retriever.
Args:
texts: Iterable of strings to add to the retriever.
refresh_indices: bool to refresh ElasticSearch indices
Returns:
List of ids from adding the texts into the retriever.
"""
try:
from elasticsearch.helpers import bulk
except ImportError:
raise ImportError(
"Could not import elasticsearch python package. "
"Please install it with `pip install elasticsearch`."
)
requests = []
ids = []
for i, text in enumerate(texts):
_id = str(uuid.uuid4())
request = {
"_op_type": "index",
"_index": self.index_name,
"content": text,
"_id": _id,
}
ids.append(_id)
requests.append(request)
bulk(self.client, requests)
if refresh_indices:
self.client.indices.refresh(index=self.index_name)
return ids
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
query_dict = {"query": {"match": {"content": query}}}
res = self.client.search(index=self.index_name, body=query_dict)
docs = []
for r in res["hits"]["hits"]:
docs.append(Document(page_content=r["_source"]["content"]))
return docs

View File

@@ -0,0 +1,74 @@
"""Wrapper around Embedchain Retriever."""
from __future__ import annotations
from typing import Any, Iterable, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class EmbedchainRetriever(BaseRetriever):
"""`Embedchain` retriever."""
client: Any
"""Embedchain Pipeline."""
@classmethod
def create(cls, yaml_path: Optional[str] = None) -> EmbedchainRetriever:
"""
Create a EmbedchainRetriever from a YAML configuration file.
Args:
yaml_path: Path to the YAML configuration file. If not provided,
a default configuration is used.
Returns:
An instance of EmbedchainRetriever.
"""
from embedchain import Pipeline
# Create an Embedchain Pipeline instance
if yaml_path:
client = Pipeline.from_config(yaml_path=yaml_path)
else:
client = Pipeline()
return cls(client=client)
def add_texts(
self,
texts: Iterable[str],
) -> List[str]:
"""Run more texts through the embeddings and add to the retriever.
Args:
texts: Iterable of strings/URLs to add to the retriever.
Returns:
List of ids from adding the texts into the retriever.
"""
ids = []
for text in texts:
_id = self.client.add(text)
ids.append(_id)
return ids
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
res = self.client.search(query)
docs = []
for r in res:
docs.append(
Document(
page_content=r["context"],
metadata={
"source": r["metadata"]["url"],
"document_id": r["metadata"]["doc_id"],
},
)
)
return docs

View File

@@ -0,0 +1,126 @@
"""Retriever wrapper for Google Cloud Document AI Warehouse."""
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.utils import get_from_dict_or_env, pre_init
from langchain_community.utilities.vertexai import get_client_info
if TYPE_CHECKING:
from google.cloud.contentwarehouse_v1 import (
DocumentServiceClient,
RequestMetadata,
SearchDocumentsRequest,
)
from google.cloud.contentwarehouse_v1.services.document_service.pagers import (
SearchDocumentsPager,
)
@deprecated(
since="0.0.32",
removal="1.0",
alternative_import="langchain_google_community.DocumentAIWarehouseRetriever",
)
class GoogleDocumentAIWarehouseRetriever(BaseRetriever):
"""A retriever based on Document AI Warehouse.
Documents should be created and documents should be uploaded
in a separate flow, and this retriever uses only Document AI
schema_id provided to search for relevant documents.
More info: https://cloud.google.com/document-ai-warehouse.
"""
location: str = "us"
"""Google Cloud location where Document AI Warehouse is placed."""
project_number: str
"""Google Cloud project number, should contain digits only."""
schema_id: Optional[str] = None
"""Document AI Warehouse schema to query against.
If nothing is provided, all documents in the project will be searched."""
qa_size_limit: int = 5
"""The limit on the number of documents returned."""
client: "DocumentServiceClient" = None #: :meta private:
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validates the environment."""
try:
from google.cloud.contentwarehouse_v1 import DocumentServiceClient
except ImportError as exc:
raise ImportError(
"google.cloud.contentwarehouse is not installed."
"Please install it with pip install google-cloud-contentwarehouse"
) from exc
values["project_number"] = get_from_dict_or_env(
values, "project_number", "PROJECT_NUMBER"
)
values["client"] = DocumentServiceClient(
client_info=get_client_info(module="document-ai-warehouse")
)
return values
def _prepare_request_metadata(self, user_ldap: str) -> "RequestMetadata":
from google.cloud.contentwarehouse_v1 import RequestMetadata, UserInfo
user_info = UserInfo(id=f"user:{user_ldap}")
return RequestMetadata(user_info=user_info)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
request = self._prepare_search_request(query, **kwargs)
response = self.client.search_documents(request=request)
return self._parse_search_response(response=response)
def _prepare_search_request(
self, query: str, **kwargs: Any
) -> "SearchDocumentsRequest":
from google.cloud.contentwarehouse_v1 import (
DocumentQuery,
SearchDocumentsRequest,
)
try:
user_ldap = kwargs["user_ldap"]
except KeyError:
raise ValueError("Argument user_ldap should be provided!")
request_metadata = self._prepare_request_metadata(user_ldap=user_ldap)
schemas = []
if self.schema_id:
schemas.append(
self.client.document_schema_path(
project=self.project_number,
location=self.location,
document_schema=self.schema_id,
)
)
return SearchDocumentsRequest(
parent=self.client.common_location_path(self.project_number, self.location),
request_metadata=request_metadata,
document_query=DocumentQuery(
query=query, is_nl_query=True, document_schema_names=schemas
),
qa_size_limit=self.qa_size_limit,
)
def _parse_search_response(
self, response: "SearchDocumentsPager"
) -> List[Document]:
documents = []
for doc in response.matching_documents:
metadata = {
"title": doc.document.title,
"source": doc.document.raw_document_path,
}
documents.append(
Document(page_content=doc.search_text_snippet, metadata=metadata)
)
return documents

View File

@@ -0,0 +1,491 @@
"""Retriever wrapper for Google Vertex AI Search."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.utils import get_from_dict_or_env
from pydantic import BaseModel, ConfigDict, Field, model_validator
from langchain_community.utilities.vertexai import get_client_info
if TYPE_CHECKING:
from google.api_core.client_options import ClientOptions
from google.cloud.discoveryengine_v1beta import SearchRequest, SearchResult
class _BaseGoogleVertexAISearchRetriever(BaseModel):
project_id: str
"""Google Cloud Project ID."""
data_store_id: Optional[str] = None
"""Vertex AI Search data store ID."""
search_engine_id: Optional[str] = None
"""Vertex AI Search app ID."""
location_id: str = "global"
"""Vertex AI Search data store location."""
serving_config_id: str = "default_config"
"""Vertex AI Search serving config ID."""
credentials: Any = None
"""The default custom credentials (google.auth.credentials.Credentials) to use
when making API calls. If not provided, credentials will be ascertained from
the environment."""
engine_data_type: int = Field(default=0, ge=0, le=3)
""" Defines the Vertex AI Search app data type
0 - Unstructured data
1 - Structured data
2 - Website data
3 - Blended search
"""
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
"""Validates the environment."""
try:
from google.cloud import discoveryengine_v1beta # noqa: F401
except ImportError as exc:
raise ImportError(
"google.cloud.discoveryengine is not installed."
"Please install it with pip install "
"google-cloud-discoveryengine>=0.11.10"
) from exc
try:
from google.api_core.exceptions import InvalidArgument # noqa: F401
except ImportError as exc:
raise ImportError(
"google.api_core.exceptions is not installed. "
"Please install it with pip install google-api-core"
) from exc
values["project_id"] = get_from_dict_or_env(values, "project_id", "PROJECT_ID")
try:
values["data_store_id"] = get_from_dict_or_env(
values, "data_store_id", "DATA_STORE_ID"
)
values["search_engine_id"] = get_from_dict_or_env(
values, "search_engine_id", "SEARCH_ENGINE_ID"
)
except Exception:
pass
return values
@property
def client_options(self) -> "ClientOptions":
from google.api_core.client_options import ClientOptions
return ClientOptions(
api_endpoint=(
f"{self.location_id}-discoveryengine.googleapis.com"
if self.location_id != "global"
else None
)
)
def _convert_structured_search_response(
self, results: Sequence[SearchResult]
) -> List[Document]:
"""Converts a sequence of search results to a list of LangChain documents."""
import json
from google.protobuf.json_format import MessageToDict
documents: List[Document] = []
for result in results:
document_dict = MessageToDict(
result.document._pb, preserving_proto_field_name=True
)
documents.append(
Document(
page_content=json.dumps(document_dict.get("struct_data", {})),
metadata={"id": document_dict["id"], "name": document_dict["name"]},
)
)
return documents
def _convert_unstructured_search_response(
self, results: Sequence[SearchResult], chunk_type: str
) -> List[Document]:
"""Converts a sequence of search results to a list of LangChain documents."""
from google.protobuf.json_format import MessageToDict
documents: List[Document] = []
for result in results:
document_dict = MessageToDict(
result.document._pb, preserving_proto_field_name=True
)
derived_struct_data = document_dict.get("derived_struct_data")
if not derived_struct_data:
continue
doc_metadata = document_dict.get("struct_data", {})
doc_metadata["id"] = document_dict["id"]
if chunk_type not in derived_struct_data:
continue
for chunk in derived_struct_data[chunk_type]:
chunk_metadata = doc_metadata.copy()
chunk_metadata["source"] = derived_struct_data.get("link", "")
if chunk_type == "extractive_answers":
chunk_metadata["source"] += f":{chunk.get('pageNumber', '')}"
documents.append(
Document(
page_content=chunk.get("content", ""), metadata=chunk_metadata
)
)
return documents
def _convert_website_search_response(
self, results: Sequence[SearchResult], chunk_type: str
) -> List[Document]:
"""Converts a sequence of search results to a list of LangChain documents."""
from google.protobuf.json_format import MessageToDict
documents: List[Document] = []
for result in results:
document_dict = MessageToDict(
result.document._pb, preserving_proto_field_name=True
)
derived_struct_data = document_dict.get("derived_struct_data")
if not derived_struct_data:
continue
doc_metadata = document_dict.get("struct_data", {})
doc_metadata["id"] = document_dict["id"]
doc_metadata["source"] = derived_struct_data.get("link", "")
if derived_struct_data.get("title") is not None:
doc_metadata["title"] = derived_struct_data.get("title")
if chunk_type not in derived_struct_data:
continue
text_field = "snippet" if chunk_type == "snippets" else "content"
for chunk in derived_struct_data[chunk_type]:
documents.append(
Document(
page_content=chunk.get(text_field, ""), metadata=doc_metadata
)
)
if not documents:
print(f"No {chunk_type} could be found.") # noqa: T201
if chunk_type == "extractive_answers":
print( # noqa: T201
"Make sure that your data store is using Advanced Website "
"Indexing.\n"
"https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features#advanced-website-indexing"
)
return documents
@deprecated(
since="0.0.33",
removal="1.0",
alternative_import="langchain_google_community.VertexAISearchRetriever",
)
class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetriever):
"""`Google Vertex AI Search` retriever.
For a detailed explanation of the Vertex AI Search concepts
and configuration parameters, refer to the product documentation.
https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
"""
filter: Optional[str] = None
"""Filter expression."""
get_extractive_answers: bool = False
"""If True return Extractive Answers, otherwise return Extractive Segments or Snippets.""" # noqa: E501
max_documents: int = Field(default=5, ge=1, le=100)
"""The maximum number of documents to return."""
max_extractive_answer_count: int = Field(default=1, ge=1, le=5)
"""The maximum number of extractive answers returned in each search result.
At most 5 answers will be returned for each SearchResult.
"""
max_extractive_segment_count: int = Field(default=1, ge=1, le=1)
"""The maximum number of extractive segments returned in each search result.
Currently one segment will be returned for each SearchResult.
"""
query_expansion_condition: int = Field(default=1, ge=0, le=2)
"""Specification to determine under which conditions query expansion should occur.
0 - Unspecified query expansion condition. In this case, server behavior defaults
to disabled
1 - Disabled query expansion. Only the exact search query is used, even if
SearchResponse.total_size is zero.
2 - Automatic query expansion built by the Search API.
"""
spell_correction_mode: int = Field(default=2, ge=0, le=2)
"""Specification to determine under which conditions query expansion should occur.
0 - Unspecified spell correction mode. In this case, server behavior defaults
to auto.
1 - Suggestion only. Search API will try to find a spell suggestion if there is any
and put in the `SearchResponse.corrected_query`.
The spell suggestion will not be used as the search query.
2 - Automatic spell correction built by the Search API.
Search will be based on the corrected query if found.
"""
# type is SearchServiceClient but can't be set due to optional imports
_client: Any = None
_serving_config: str
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="ignore",
)
def __init__(self, **kwargs: Any) -> None:
"""Initializes private fields."""
try:
from google.cloud.discoveryengine_v1beta import SearchServiceClient
except ImportError as exc:
raise ImportError(
"google.cloud.discoveryengine is not installed."
"Please install it with pip install google-cloud-discoveryengine"
) from exc
super().__init__(**kwargs)
# For more information, refer to:
# https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
self._client = SearchServiceClient(
credentials=self.credentials,
client_options=self.client_options,
client_info=get_client_info(module="vertex-ai-search"),
)
if self.engine_data_type == 3 and not self.search_engine_id:
raise ValueError(
"search_engine_id must be specified for blended search apps."
)
if self.search_engine_id:
self._serving_config = f"projects/{self.project_id}/locations/{self.location_id}/collections/default_collection/engines/{self.search_engine_id}/servingConfigs/default_config" # noqa: E501
elif self.data_store_id:
self._serving_config = self._client.serving_config_path(
project=self.project_id,
location=self.location_id,
data_store=self.data_store_id,
serving_config=self.serving_config_id,
)
else:
raise ValueError(
"Either data_store_id or search_engine_id must be specified."
)
def _create_search_request(self, query: str) -> SearchRequest:
"""Prepares a SearchRequest object."""
from google.cloud.discoveryengine_v1beta import SearchRequest
query_expansion_spec = SearchRequest.QueryExpansionSpec(
condition=self.query_expansion_condition,
)
spell_correction_spec = SearchRequest.SpellCorrectionSpec(
mode=self.spell_correction_mode
)
if self.engine_data_type == 0:
if self.get_extractive_answers:
extractive_content_spec = (
SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
max_extractive_answer_count=self.max_extractive_answer_count,
)
)
else:
extractive_content_spec = (
SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
max_extractive_segment_count=self.max_extractive_segment_count,
)
)
content_search_spec = SearchRequest.ContentSearchSpec(
extractive_content_spec=extractive_content_spec
)
elif self.engine_data_type == 1:
content_search_spec = None
elif self.engine_data_type in (2, 3):
content_search_spec = SearchRequest.ContentSearchSpec(
extractive_content_spec=SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
max_extractive_answer_count=self.max_extractive_answer_count,
),
snippet_spec=SearchRequest.ContentSearchSpec.SnippetSpec(
return_snippet=True
),
)
else:
raise NotImplementedError(
"Only data store type 0 (Unstructured), 1 (Structured),"
"2 (Website), or 3 (Blended) are supported currently."
+ f" Got {self.engine_data_type}"
)
return SearchRequest(
query=query,
filter=self.filter,
serving_config=self._serving_config,
page_size=self.max_documents,
content_search_spec=content_search_spec,
query_expansion_spec=query_expansion_spec,
spell_correction_spec=spell_correction_spec,
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query."""
return self.get_relevant_documents_with_response(query)[0]
def get_relevant_documents_with_response(
self, query: str
) -> Tuple[List[Document], Any]:
from google.api_core.exceptions import InvalidArgument
search_request = self._create_search_request(query)
try:
response = self._client.search(search_request)
except InvalidArgument as exc:
raise type(exc)(
exc.message
+ " This might be due to engine_data_type not set correctly."
)
if self.engine_data_type == 0:
chunk_type = (
"extractive_answers"
if self.get_extractive_answers
else "extractive_segments"
)
documents = self._convert_unstructured_search_response(
response.results, chunk_type
)
elif self.engine_data_type == 1:
documents = self._convert_structured_search_response(response.results)
elif self.engine_data_type in (2, 3):
chunk_type = (
"extractive_answers" if self.get_extractive_answers else "snippets"
)
documents = self._convert_website_search_response(
response.results, chunk_type
)
else:
raise NotImplementedError(
"Only data store type 0 (Unstructured), 1 (Structured),"
"2 (Website), or 3 (Blended) are supported currently."
+ f" Got {self.engine_data_type}"
)
return documents, response
@deprecated(
since="0.0.33",
removal="1.0",
alternative_import="langchain_google_community.VertexAIMultiTurnSearchRetriever",
)
class GoogleVertexAIMultiTurnSearchRetriever(
BaseRetriever, _BaseGoogleVertexAISearchRetriever
):
"""`Google Vertex AI Search` retriever for multi-turn conversations."""
conversation_id: str = "-"
"""Vertex AI Search Conversation ID."""
# type is ConversationalSearchServiceClient but can't be set due to optional imports
_client: Any = None
_serving_config: str
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="ignore",
)
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
from google.cloud.discoveryengine_v1beta import (
ConversationalSearchServiceClient,
)
self._client = ConversationalSearchServiceClient(
credentials=self.credentials,
client_options=self.client_options,
client_info=get_client_info(module="vertex-ai-search"),
)
if not self.data_store_id:
raise ValueError("data_store_id is required for MultiTurnSearchRetriever.")
self._serving_config = self._client.serving_config_path(
project=self.project_id,
location=self.location_id,
data_store=self.data_store_id,
serving_config=self.serving_config_id,
)
if self.engine_data_type == 1 or self.engine_data_type == 3:
raise NotImplementedError(
"Data store type 1 (Structured) and 3 (Blended)"
"is not currently supported for multi-turn search."
+ f" Got {self.engine_data_type}"
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query."""
from google.cloud.discoveryengine_v1beta import (
ConverseConversationRequest,
TextInput,
)
request = ConverseConversationRequest(
name=self._client.conversation_path(
self.project_id,
self.location_id,
self.data_store_id,
self.conversation_id,
),
serving_config=self._serving_config,
query=TextInput(input=query),
)
response = self._client.converse_conversation(request)
if self.engine_data_type == 2:
return self._convert_website_search_response(
response.search_results, "extractive_answers"
)
return self._convert_unstructured_search_response(
response.search_results, "extractive_answers"
)
class GoogleCloudEnterpriseSearchRetriever(GoogleVertexAISearchRetriever):
"""`Google Vertex Search API` retriever alias for backwards compatibility.
DEPRECATED: Use `GoogleVertexAISearchRetriever` instead.
"""
def __init__(self, **data: Any):
import warnings
warnings.warn(
"GoogleCloudEnterpriseSearchRetriever is deprecated, use GoogleVertexAISearchRetriever", # noqa: E501
DeprecationWarning,
)
super().__init__(**data)

View File

@@ -0,0 +1,60 @@
from __future__ import annotations
from typing import Any, List
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class KayAiRetriever(BaseRetriever):
"""
Retriever for Kay.ai datasets.
To work properly, expects you to have KAY_API_KEY env variable set.
You can get one for free at https://kay.ai/.
"""
client: Any
num_contexts: int
@classmethod
def create(
cls,
dataset_id: str,
data_types: List[str],
num_contexts: int = 6,
) -> KayAiRetriever:
"""
Create a KayRetriever given a Kay dataset id and a list of datasources.
Args:
dataset_id: A dataset id category in Kay, like "company"
data_types: A list of datasources present within a dataset. For
"company" the corresponding datasources could be
["10-K", "10-Q", "8-K", "PressRelease"].
num_contexts: The number of documents to retrieve on each query.
Defaults to 6.
"""
try:
from kay.rag.retrievers import KayRetriever
except ImportError:
raise ImportError(
"Could not import kay python package. Please install it with "
"`pip install kay`.",
)
client = KayRetriever(dataset_id, data_types)
return cls(client=client, num_contexts=num_contexts)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
ctxs = self.client.query(query=query, num_context=self.num_contexts)
docs = []
for ctx in ctxs:
page_content = ctx.pop("chunk_embed_text", None)
if page_content is None:
continue
docs.append(Document(page_content=page_content, metadata={**ctx}))
return docs

View File

@@ -0,0 +1,496 @@
import re
from abc import ABC, abstractmethod
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import (
BaseModel,
Field,
model_validator,
validator,
)
from typing_extensions import Annotated
def clean_excerpt(excerpt: str) -> str:
"""Clean an excerpt from Kendra.
Args:
excerpt: The excerpt to clean.
Returns:
The cleaned excerpt.
"""
if not excerpt:
return excerpt
res = re.sub(r"\s+", " ", excerpt).replace("...", "")
return res
def combined_text(item: "ResultItem") -> str:
"""Combine a ResultItem title and excerpt into a single string.
Args:
item: the ResultItem of a Kendra search.
Returns:
A combined text of the title and excerpt of the given item.
"""
text = ""
title = item.get_title()
if title:
text += f"Document Title: {title}\n"
excerpt = clean_excerpt(item.get_excerpt())
if excerpt:
text += f"Document Excerpt: \n{excerpt}\n"
return text
DocumentAttributeValueType = Union[str, int, List[str], None]
"""Possible types of a DocumentAttributeValue.
Dates are also represented as str.
"""
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class Highlight(BaseModel, extra="allow"):
"""Information that highlights the keywords in the excerpt."""
BeginOffset: int
"""The zero-based location in the excerpt where the highlight starts."""
EndOffset: int
"""The zero-based location in the excerpt where the highlight ends."""
TopAnswer: Optional[bool]
"""Indicates whether the result is the best one."""
Type: Optional[str]
"""The highlight type: STANDARD or THESAURUS_SYNONYM."""
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class TextWithHighLights(BaseModel, extra="allow"):
"""Text with highlights."""
Text: str
"""The text."""
Highlights: Optional[Any]
"""The highlights."""
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class AdditionalResultAttributeValue(BaseModel, extra="allow"):
"""Value of an additional result attribute."""
TextWithHighlightsValue: TextWithHighLights
"""The text with highlights value."""
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class AdditionalResultAttribute(BaseModel, extra="allow"):
"""Additional result attribute."""
Key: str
"""The key of the attribute."""
ValueType: Literal["TEXT_WITH_HIGHLIGHTS_VALUE"]
"""The type of the value."""
Value: AdditionalResultAttributeValue
"""The value of the attribute."""
def get_value_text(self) -> str:
return self.Value.TextWithHighlightsValue.Text
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class DocumentAttributeValue(BaseModel, extra="allow"):
"""Value of a document attribute."""
DateValue: Optional[str] = None
"""The date expressed as an ISO 8601 string."""
LongValue: Optional[int] = None
"""The long value."""
StringListValue: Optional[List[str]] = None
"""The string list value."""
StringValue: Optional[str] = None
"""The string value."""
@property
def value(self) -> DocumentAttributeValueType:
"""The only defined document attribute value or None.
According to Amazon Kendra, you can only provide one
value for a document attribute.
"""
if self.DateValue:
return self.DateValue
if self.LongValue:
return self.LongValue
if self.StringListValue:
return self.StringListValue
if self.StringValue:
return self.StringValue
return None
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class DocumentAttribute(BaseModel, extra="allow"):
"""Document attribute."""
Key: str
"""The key of the attribute."""
Value: DocumentAttributeValue
"""The value of the attribute."""
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class ResultItem(BaseModel, ABC, extra="allow"):
"""Base class of a result item."""
Id: Optional[str]
"""The ID of the relevant result item."""
DocumentId: Optional[str]
"""The document ID."""
DocumentURI: Optional[str]
"""The document URI."""
DocumentAttributes: Optional[List[DocumentAttribute]] = []
"""The document attributes."""
ScoreAttributes: Optional[dict]
"""The kendra score confidence"""
@abstractmethod
def get_title(self) -> str:
"""Document title."""
@abstractmethod
def get_excerpt(self) -> str:
"""Document excerpt or passage original content as retrieved by Kendra."""
def get_additional_metadata(self) -> dict:
"""Document additional metadata dict.
This returns any extra metadata except these:
* result_id
* document_id
* source
* title
* excerpt
* document_attributes
"""
return {}
def get_document_attributes_dict(self) -> Dict[str, DocumentAttributeValueType]:
"""Document attributes dict."""
return {attr.Key: attr.Value.value for attr in (self.DocumentAttributes or [])}
def get_score_attribute(self) -> str:
"""Document Score Confidence"""
if self.ScoreAttributes is not None:
return self.ScoreAttributes["ScoreConfidence"]
else:
return "NOT_AVAILABLE"
def to_doc(
self, page_content_formatter: Callable[["ResultItem"], str] = combined_text
) -> Document:
"""Converts this item to a Document."""
page_content = page_content_formatter(self)
metadata = self.get_additional_metadata()
metadata.update(
{
"result_id": self.Id,
"document_id": self.DocumentId,
"source": self.DocumentURI,
"title": self.get_title(),
"excerpt": self.get_excerpt(),
"document_attributes": self.get_document_attributes_dict(),
"score": self.get_score_attribute(),
}
)
return Document(page_content=page_content, metadata=metadata)
class QueryResultItem(ResultItem):
"""Query API result item."""
DocumentTitle: TextWithHighLights
"""The document title."""
FeedbackToken: Optional[str]
"""Identifies a particular result from a particular query."""
Format: Optional[str]
"""
If the Type is ANSWER, then format is either:
* TABLE: a table excerpt is returned in TableExcerpt;
* TEXT: a text excerpt is returned in DocumentExcerpt.
"""
Type: Optional[str]
"""Type of result: DOCUMENT or QUESTION_ANSWER or ANSWER"""
AdditionalAttributes: Optional[List[AdditionalResultAttribute]] = []
"""One or more additional attributes associated with the result."""
DocumentExcerpt: Optional[TextWithHighLights]
"""Excerpt of the document text."""
def get_title(self) -> str:
return self.DocumentTitle.Text
def get_attribute_value(self) -> str:
if not self.AdditionalAttributes:
return ""
if not self.AdditionalAttributes[0]:
return ""
else:
return self.AdditionalAttributes[0].get_value_text()
def get_excerpt(self) -> str:
if (
self.AdditionalAttributes
and self.AdditionalAttributes[0].Key == "AnswerText"
):
excerpt = self.get_attribute_value()
elif self.DocumentExcerpt:
excerpt = self.DocumentExcerpt.Text
else:
excerpt = ""
return excerpt
def get_additional_metadata(self) -> dict:
additional_metadata = {"type": self.Type}
return additional_metadata
class RetrieveResultItem(ResultItem):
"""Retrieve API result item."""
DocumentTitle: Optional[str]
"""The document title."""
Content: Optional[str]
"""The content of the item."""
def get_title(self) -> str:
return self.DocumentTitle or ""
def get_excerpt(self) -> str:
return self.Content or ""
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class QueryResult(BaseModel, extra="allow"):
"""`Amazon Kendra Query API` search result.
It is composed of:
* Relevant suggested answers: either a text excerpt or table excerpt.
* Matching FAQs or questions-answer from your FAQ file.
* Documents including an excerpt of each document with its title.
"""
ResultItems: List[QueryResultItem]
"""The result items."""
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class RetrieveResult(BaseModel, extra="allow"):
"""`Amazon Kendra Retrieve API` search result.
It is composed of:
* relevant passages or text excerpts given an input query.
"""
QueryId: str
"""The ID of the query."""
ResultItems: List[RetrieveResultItem]
"""The result items."""
KENDRA_CONFIDENCE_MAPPING = {
"NOT_AVAILABLE": 0.0,
"LOW": 0.25,
"MEDIUM": 0.50,
"HIGH": 0.75,
"VERY_HIGH": 1.0,
}
@deprecated(
since="0.3.16",
removal="1.0",
alternative_import="langchain_aws.AmazonKendraRetriever",
)
class AmazonKendraRetriever(BaseRetriever):
"""`Amazon Kendra Index` retriever.
Args:
index_id: Kendra index id
region_name: The aws region e.g., `us-west-2`.
Fallsback to AWS_DEFAULT_REGION env variable
or region specified in ~/.aws/config.
credentials_profile_name: The name of the profile in the ~/.aws/credentials
or ~/.aws/config files, which has either access keys or role information
specified. If not specified, the default credential profile or, if on an
EC2 instance, credentials from IMDS will be used.
top_k: No of results to return
attribute_filter: Additional filtering of results based on metadata
See: https://docs.aws.amazon.com/kendra/latest/APIReference
document_relevance_override_configurations: Overrides relevance tuning
configurations of fields/attributes set at the index level
See: https://docs.aws.amazon.com/kendra/latest/APIReference
page_content_formatter: generates the Document page_content
allowing access to all result item attributes. By default, it uses
the item's title and excerpt.
client: boto3 client for Kendra
user_context: Provides information about the user context
See: https://docs.aws.amazon.com/kendra/latest/APIReference
Example:
.. code-block:: python
retriever = AmazonKendraRetriever(
index_id="c0806df7-e76b-4bce-9b5c-d5582f6b1a03"
)
"""
index_id: str
region_name: Optional[str] = None
credentials_profile_name: Optional[str] = None
top_k: int = 3
attribute_filter: Optional[Dict] = None
document_relevance_override_configurations: Optional[List[Dict]] = None
page_content_formatter: Callable[[ResultItem], str] = combined_text
client: Any
user_context: Optional[Dict] = None
min_score_confidence: Annotated[Optional[float], Field(ge=0.0, le=1.0)]
@validator("top_k")
def validate_top_k(cls, value: int) -> int:
if value < 0:
raise ValueError(f"top_k ({value}) cannot be negative.")
return value
@model_validator(mode="before")
@classmethod
def create_client(cls, values: Dict[str, Any]) -> Any:
top_k = values.get("top_k")
if top_k is not None and top_k < 0:
raise ValueError(f"top_k ({top_k}) cannot be negative.")
if values.get("client") is not None:
return values
try:
import boto3
if values.get("credentials_profile_name"):
session = boto3.Session(profile_name=values["credentials_profile_name"])
else:
# use default credentials
session = boto3.Session()
client_params = {}
if values.get("region_name"):
client_params["region_name"] = values["region_name"]
values["client"] = session.client("kendra", **client_params)
return values
except ImportError:
raise ImportError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
"profile name are valid."
) from e
def _kendra_query(self, query: str) -> Sequence[ResultItem]:
kendra_kwargs = {
"IndexId": self.index_id,
# truncate the query to ensure that
# there is no validation exception from Kendra.
"QueryText": query.strip()[0:999],
"PageSize": self.top_k,
}
if self.attribute_filter is not None:
kendra_kwargs["AttributeFilter"] = self.attribute_filter
if self.document_relevance_override_configurations is not None:
kendra_kwargs["DocumentRelevanceOverrideConfigurations"] = (
self.document_relevance_override_configurations
)
if self.user_context is not None:
kendra_kwargs["UserContext"] = self.user_context
response = self.client.retrieve(**kendra_kwargs)
r_result = RetrieveResult.parse_obj(response)
if r_result.ResultItems:
return r_result.ResultItems
# Retrieve API returned 0 results, fall back to Query API
response = self.client.query(**kendra_kwargs)
q_result = QueryResult.parse_obj(response)
return q_result.ResultItems
def _get_top_k_docs(self, result_items: Sequence[ResultItem]) -> List[Document]:
top_docs = [
item.to_doc(self.page_content_formatter)
for item in result_items[: self.top_k]
]
return top_docs
def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]:
"""
Filter out the records that have a score confidence
greater than the required threshold.
"""
if not self.min_score_confidence:
return docs
filtered_docs = [
item
for item in docs
if (
item.metadata.get("score") is not None
and isinstance(item.metadata["score"], str)
and KENDRA_CONFIDENCE_MAPPING.get(item.metadata["score"], 0.0)
>= self.min_score_confidence
)
]
return filtered_docs
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
"""Run search on Kendra index and get top k documents
Example:
.. code-block:: python
docs = retriever.invoke('This is my query')
"""
result_items = self._kendra_query(query)
top_k_docs = self._get_top_k_docs(result_items)
return self._filter_by_score_confidence(top_k_docs)

View File

@@ -0,0 +1,107 @@
"""KNN Retriever.
Largely based on
https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"""
from __future__ import annotations
import concurrent.futures
from typing import Any, Iterable, List, Optional
import numpy as np
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from pydantic import ConfigDict
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
"""
Create an index of embeddings for a list of contexts.
Args:
contexts: List of contexts to embed.
embeddings: Embeddings model to use.
Returns:
Index of embeddings.
"""
with concurrent.futures.ThreadPoolExecutor() as executor:
return np.array(list(executor.map(embeddings.embed_query, contexts)))
class KNNRetriever(BaseRetriever):
"""`KNN` retriever."""
embeddings: Embeddings
"""Embeddings model to use."""
index: Any = None
"""Index of embeddings."""
texts: List[str]
"""List of texts to index."""
metadatas: Optional[List[dict]] = None
"""List of metadatas corresponding with each text."""
k: int = 4
"""Number of results to return."""
relevancy_threshold: Optional[float] = None
"""Threshold for relevancy."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@classmethod
def from_texts(
cls,
texts: List[str],
embeddings: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> KNNRetriever:
index = create_index(texts, embeddings)
return cls(
embeddings=embeddings,
index=index,
texts=texts,
metadatas=metadatas,
**kwargs,
)
@classmethod
def from_documents(
cls,
documents: Iterable[Document],
embeddings: Embeddings,
**kwargs: Any,
) -> KNNRetriever:
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
return cls.from_texts(
texts=texts, embeddings=embeddings, metadatas=metadatas, **kwargs
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
query_embeds = np.array(self.embeddings.embed_query(query))
# calc L2 norm
index_embeds = self.index / np.sqrt((self.index**2).sum(1, keepdims=True))
query_embeds = query_embeds / np.sqrt((query_embeds**2).sum())
similarities = index_embeds.dot(query_embeds)
sorted_ix = np.argsort(-similarities)
denominator = np.max(similarities) - np.min(similarities) + 1e-6
normalized_similarities = (similarities - np.min(similarities)) / denominator
top_k_results = [
Document(
page_content=self.texts[row],
metadata=self.metadatas[row] if self.metadatas else {},
)
for row in sorted_ix[0 : self.k]
if (
self.relevancy_threshold is None
or normalized_similarities[row] >= self.relevancy_threshold
)
]
return top_k_results

View File

@@ -0,0 +1,86 @@
from typing import Any, Dict, List, cast
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import Field
class LlamaIndexRetriever(BaseRetriever):
"""`LlamaIndex` retriever.
It is used for the question-answering with sources over
an LlamaIndex data structure."""
index: Any = None
"""LlamaIndex index to query."""
query_kwargs: Dict = Field(default_factory=dict)
"""Keyword arguments to pass to the query method."""
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query."""
try:
from llama_index.core.base.response.schema import Response
from llama_index.core.indices.base import BaseGPTIndex
except ImportError:
raise ImportError(
"You need to install `pip install llama-index` to use this retriever."
)
index = cast(BaseGPTIndex, self.index)
response = index.query(query, **self.query_kwargs)
response = cast(Response, response)
# parse source nodes
docs = []
for source_node in response.source_nodes:
metadata = source_node.metadata or {}
docs.append(
Document(page_content=source_node.get_content(), metadata=metadata)
)
return docs
class LlamaIndexGraphRetriever(BaseRetriever):
"""`LlamaIndex` graph data structure retriever.
It is used for question-answering with sources over an LlamaIndex
graph data structure."""
graph: Any = None
"""LlamaIndex graph to query."""
query_configs: List[Dict] = Field(default_factory=list)
"""List of query configs to pass to the query method."""
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query."""
try:
from llama_index.core.base.response.schema import Response
from llama_index.core.composability.base import (
QUERY_CONFIG_TYPE,
ComposableGraph,
)
except ImportError:
raise ImportError(
"You need to install `pip install llama-index` to use this retriever."
)
graph = cast(ComposableGraph, self.graph)
# for now, inject response_mode="no_text" into query configs
for query_config in self.query_configs:
query_config["response_mode"] = "no_text"
query_configs = cast(List[QUERY_CONFIG_TYPE], self.query_configs)
response = graph.query(query, query_configs=query_configs)
response = cast(Response, response)
# parse source nodes
docs = []
for source_node in response.source_nodes:
metadata = source_node.metadata or {}
docs.append(
Document(page_content=source_node.get_content(), metadata=metadata)
)
return docs

View File

@@ -0,0 +1,43 @@
from typing import Any, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import model_validator
class MetalRetriever(BaseRetriever):
"""`Metal API` retriever."""
client: Any
"""The Metal client to use."""
params: Optional[dict] = None
"""The parameters to pass to the Metal client."""
@model_validator(mode="before")
@classmethod
def validate_client(cls, values: dict) -> Any:
"""Validate that the client is of the correct type."""
from metal_sdk.metal import Metal
if "client" in values:
client = values["client"]
if not isinstance(client, Metal):
raise ValueError(
"Got unexpected client, should be of type metal_sdk.metal.Metal. "
f"Instead, got {type(client)}"
)
values["params"] = values.get("params", {})
return values
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
results = self.client.search({"text": query}, **self.params)
final_results = []
for r in results["data"]:
metadata = {k: v for k, v in r.items() if k != "text"}
final_results.append(Document(page_content=r["text"], metadata=metadata))
return final_results

View File

@@ -0,0 +1,150 @@
"""Milvus Retriever"""
import warnings
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from pydantic import model_validator
from langchain_community.vectorstores.milvus import Milvus
# TODO: Update to MilvusClient + Hybrid Search when available
class MilvusRetriever(BaseRetriever):
"""Milvus API retriever.
See detailed instructions here: https://python.langchain.com/docs/integrations/retrievers/milvus_hybrid_search/
Setup:
Install ``langchain-milvus`` and other dependencies:
.. code-block:: bash
pip install -U pymilvus[model] langchain-milvus
Key init args:
collection: Milvus Collection
Instantiate:
.. code-block:: python
retriever = MilvusCollectionHybridSearchRetriever(collection=collection)
Usage:
.. code-block:: python
query = "What are the story about ventures?"
retriever.invoke(query)
.. code-block:: none
[Document(page_content="In 'The Lost Expedition' by Caspian Grey...", metadata={'doc_id': '449281835035545843'}),
Document(page_content="In 'The Phantom Pilgrim' by Rowan Welles...", metadata={'doc_id': '449281835035545845'}),
Document(page_content="In 'The Dreamwalker's Journey' by Lyra Snow..", metadata={'doc_id': '449281835035545846'})]
Use within a chain:
.. code-block:: python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
prompt = ChatPromptTemplate.from_template(
\"\"\"Answer the question based only on the context provided.
Context: {context}
Question: {question}\"\"\"
)
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
def format_docs(docs):
return "\\n\\n".join(doc.page_content for doc in docs)
chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
chain.invoke("What novels has Lila written and what are their contents?")
.. code-block:: none
"Lila Rose has written 'The Memory Thief,' which follows a charismatic thief..."
""" # noqa: E501
embedding_function: Embeddings
collection_name: str = "LangChainCollection"
collection_properties: Optional[Dict[str, Any]] = None
connection_args: Optional[Dict[str, Any]] = None
consistency_level: str = "Session"
search_params: Optional[dict] = None
store: Milvus
retriever: BaseRetriever
@model_validator(mode="before")
@classmethod
def create_retriever(cls, values: Dict) -> Any:
"""Create the Milvus store and retriever."""
values["store"] = Milvus(
values["embedding_function"],
values["collection_name"],
values["collection_properties"],
values["connection_args"],
values["consistency_level"],
)
values["retriever"] = values["store"].as_retriever(
search_kwargs={"param": values["search_params"]}
)
return values
def add_texts(
self, texts: List[str], metadatas: Optional[List[dict]] = None
) -> None:
"""Add text to the Milvus store
Args:
texts (List[str]): The text
metadatas (List[dict]): Metadata dicts, must line up with existing store
"""
self.store.add_texts(texts, metadatas)
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
return self.retriever.invoke(
query, run_manager=run_manager.get_child(), **kwargs
)
def MilvusRetreiver(*args: Any, **kwargs: Any) -> MilvusRetriever:
"""Deprecated MilvusRetreiver. Please use MilvusRetriever ('i' before 'e') instead.
Args:
*args:
**kwargs:
Returns:
MilvusRetriever
"""
warnings.warn(
"MilvusRetreiver will be deprecated in the future. "
"Please use MilvusRetriever ('i' before 'e') instead.",
DeprecationWarning,
)
return MilvusRetriever(*args, **kwargs)

View File

@@ -0,0 +1,125 @@
from __future__ import annotations
import concurrent.futures
from typing import Any, Iterable, List, Optional
import numpy as np
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from pydantic import ConfigDict
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
"""
Create an index of embeddings for a list of contexts.
Args:
contexts: List of contexts to embed.
embeddings: Embeddings model to use.
Returns:
Index of embeddings.
"""
with concurrent.futures.ThreadPoolExecutor() as executor:
return np.array(list(executor.map(embeddings.embed_query, contexts)))
class NanoPQRetriever(BaseRetriever):
"""`NanoPQ retriever."""
embeddings: Embeddings
"""Embeddings model to use."""
index: Any = None
"""Index of embeddings."""
texts: List[str]
"""List of texts to index."""
metadatas: Optional[List[dict]] = None
"""List of metadatas corresponding with each text."""
k: int = 4
"""Number of results to return."""
relevancy_threshold: Optional[float] = None
"""Threshold for relevancy."""
subspace: int = 4
"""No of subspaces to be created, should be a multiple of embedding shape"""
clusters: int = 128
"""No of clusters to be created"""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@classmethod
def from_texts(
cls,
texts: List[str],
embeddings: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> NanoPQRetriever:
index = create_index(texts, embeddings)
return cls(
embeddings=embeddings,
index=index,
texts=texts,
metadatas=metadatas,
**kwargs,
)
@classmethod
def from_documents(
cls,
documents: Iterable[Document],
embeddings: Embeddings,
**kwargs: Any,
) -> NanoPQRetriever:
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
return cls.from_texts(
texts=texts, embeddings=embeddings, metadatas=metadatas, **kwargs
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
try:
from nanopq import PQ
except ImportError:
raise ImportError(
"Could not import nanopq, please install with `pip install nanopq`."
)
query_embeds = np.array(self.embeddings.embed_query(query))
try:
pq = PQ(M=self.subspace, Ks=self.clusters, verbose=True).fit(
self.index.astype("float32")
)
except AssertionError:
error_message = (
"Received params: training_sample={training_sample}, "
"n_cluster={n_clusters}, subspace={subspace}, "
"embedding_shape={embedding_shape}. Issue with the combination. "
"Please retrace back to find the exact error"
).format(
training_sample=self.index.shape[0],
n_clusters=self.clusters,
subspace=self.subspace,
embedding_shape=self.index.shape[1],
)
raise RuntimeError(error_message)
index_code = pq.encode(vecs=self.index.astype("float32"))
dt = pq.dtable(query=query_embeds.astype("float32"))
dists = dt.adist(codes=index_code)
sorted_ix = np.argsort(dists)
top_k_results = [
Document(
page_content=self.texts[row],
metadata=self.metadatas[row] if self.metadatas else {},
)
for row in sorted_ix[0 : self.k]
]
return top_k_results

View File

@@ -0,0 +1,101 @@
from typing import Any, List, Optional # noqa: I001
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import BaseModel, Field
class NeedleRetriever(BaseRetriever, BaseModel):
"""
NeedleRetriever retrieves relevant documents or context from a Needle collection
based on a search query.
Setup:
Install the `needle-python` library and set your Needle API key.
.. code-block:: bash
pip install needle-python
export NEEDLE_API_KEY="your-api-key"
Key init args:
- `needle_api_key` (Optional[str]): The API key for authenticating with Needle.
- `collection_id` (str): The ID of the Needle collection to search in.
- `client` (Optional[NeedleClient]): An optional instance of the NeedleClient.
- `top_k` (Optional[int]): Maximum number of results to return.
Usage:
.. code-block:: python
from langchain_community.retrievers.needle import NeedleRetriever
retriever = NeedleRetriever(
needle_api_key="your-api-key",
collection_id="your-collection-id",
top_k=10 # optional
)
results = retriever.retrieve("example query")
for doc in results:
print(doc.page_content)
"""
client: Optional[Any] = None
"""Optional instance of NeedleClient."""
needle_api_key: Optional[str] = Field(None, description="Needle API Key")
collection_id: Optional[str] = Field(
..., description="The ID of the Needle collection to search in"
)
top_k: Optional[int] = Field(
default=None, description="Maximum number of search results to return"
)
def _initialize_client(self) -> None:
"""
Initialize the NeedleClient with the provided API key.
If a client instance is already provided, this method does nothing.
"""
try:
from needle.v1 import NeedleClient
except ImportError:
raise ImportError("Please install with `pip install needle-python`.")
if not self.client:
self.client = NeedleClient(api_key=self.needle_api_key)
def _search_collection(self, query: str) -> List[Document]:
"""
Search the Needle collection for relevant documents.
Args:
query (str): The search query used to find relevant documents.
Returns:
List[Document]: A list of documents matching the search query.
"""
self._initialize_client()
if self.client is None:
raise ValueError("NeedleClient is not initialized. Provide an API key.")
results = self.client.collections.search(
collection_id=self.collection_id, text=query, top_k=self.top_k
)
docs = [Document(page_content=result.content) for result in results]
return docs
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""
Retrieve relevant documents based on the query.
Args:
query (str): The query string used to search the collection.
Returns:
List[Document]: A list of documents relevant to the query.
"""
# The `run_manager` parameter is included to match the superclass signature,
# but it is not used in this implementation.
return self._search_collection(query)

View File

@@ -0,0 +1,20 @@
from typing import List
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_community.utilities.outline import OutlineAPIWrapper
class OutlineRetriever(BaseRetriever, OutlineAPIWrapper):
"""Retriever for Outline API.
It wraps run() to get_relevant_documents().
It uses all OutlineAPIWrapper arguments without any change.
"""
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
return self.run(query=query)

View File

@@ -0,0 +1,185 @@
"""Taken from: https://docs.pinecone.io/docs/hybrid-search"""
import hashlib
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from langchain_core.utils import pre_init
from pydantic import ConfigDict
def hash_text(text: str) -> str:
"""Hash a text using SHA256.
Args:
text: Text to hash.
Returns:
Hashed text.
"""
return str(hashlib.sha256(text.encode("utf-8")).hexdigest())
def create_index(
contexts: List[str],
index: Any,
embeddings: Embeddings,
sparse_encoder: Any,
ids: Optional[List[str]] = None,
metadatas: Optional[List[dict]] = None,
namespace: Optional[str] = None,
text_key: str = "context",
) -> None:
"""Create an index from a list of contexts.
It modifies the index argument in-place!
Args:
contexts: List of contexts to embed.
index: Index to use.
embeddings: Embeddings model to use.
sparse_encoder: Sparse encoder to use.
ids: List of ids to use for the documents.
metadatas: List of metadata to use for the documents.
namespace: Namespace value for index partition.
"""
batch_size = 32
_iterator = range(0, len(contexts), batch_size)
try:
from tqdm.auto import tqdm
_iterator = tqdm(_iterator)
except ImportError:
pass
if ids is None:
# create unique ids using hash of the text
ids = [hash_text(context) for context in contexts]
for i in _iterator:
# find end of batch
i_end = min(i + batch_size, len(contexts))
# extract batch
context_batch = contexts[i:i_end]
batch_ids = ids[i:i_end]
metadata_batch = (
metadatas[i:i_end] if metadatas else [{} for _ in context_batch]
)
# add context passages as metadata
meta = [
{text_key: context, **metadata}
for context, metadata in zip(context_batch, metadata_batch)
]
# create dense vectors
dense_embeds = embeddings.embed_documents(context_batch)
# create sparse vectors
sparse_embeds = sparse_encoder.encode_documents(context_batch)
for s in sparse_embeds:
s["values"] = [float(s1) for s1 in s["values"]]
vectors = []
# loop through the data and create dictionaries for upserts
for doc_id, sparse, dense, metadata in zip(
batch_ids, sparse_embeds, dense_embeds, meta
):
vectors.append(
{
"id": doc_id,
"sparse_values": sparse,
"values": dense,
"metadata": metadata,
}
)
# upload the documents to the new hybrid index
index.upsert(vectors, namespace=namespace)
class PineconeHybridSearchRetriever(BaseRetriever):
"""`Pinecone Hybrid Search` retriever."""
embeddings: Embeddings
"""Embeddings model to use."""
"""description"""
sparse_encoder: Any = None
"""Sparse encoder to use."""
index: Any = None
"""Pinecone index to use."""
top_k: int = 4
"""Number of documents to return."""
alpha: float = 0.5
"""Alpha value for hybrid search."""
namespace: Optional[str] = None
"""Namespace value for index partition."""
text_key: str = "context"
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
def add_texts(
self,
texts: List[str],
ids: Optional[List[str]] = None,
metadatas: Optional[List[dict]] = None,
namespace: Optional[str] = None,
) -> None:
create_index(
texts,
self.index,
self.embeddings,
self.sparse_encoder,
ids=ids,
metadatas=metadatas,
namespace=namespace,
text_key=self.text_key,
)
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
from pinecone_text.hybrid import hybrid_convex_scale # noqa:F401
from pinecone_text.sparse.base_sparse_encoder import (
BaseSparseEncoder, # noqa:F401
)
except ImportError:
raise ImportError(
"Could not import pinecone_text python package. "
"Please install it with `pip install pinecone_text`."
)
return values
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
from pinecone_text.hybrid import hybrid_convex_scale
sparse_vec = self.sparse_encoder.encode_queries(query)
# convert the question into a dense vector
dense_vec = self.embeddings.embed_query(query)
# scale alpha with hybrid_scale
dense_vec, sparse_vec = hybrid_convex_scale(dense_vec, sparse_vec, self.alpha)
sparse_vec["values"] = [float(s1) for s1 in sparse_vec["values"]]
# query pinecone with the query parameters
result = self.index.query(
vector=dense_vec,
sparse_vector=sparse_vec,
top_k=self.top_k,
include_metadata=True,
namespace=self.namespace,
**kwargs,
)
final_result = []
for res in result["matches"]:
context = res["metadata"].pop(self.text_key)
metadata = res["metadata"]
if "score" not in metadata and "score" in res:
metadata["score"] = res["score"]
final_result.append(Document(page_content=context, metadata=metadata))
# return search results as json
return final_result

View File

@@ -0,0 +1,20 @@
from typing import List
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_community.utilities.pubmed import PubMedAPIWrapper
class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
"""`PubMed API` retriever.
It wraps load() to get_relevant_documents().
It uses all PubMedAPIWrapper arguments without any change.
"""
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
return self.load_docs(query=query)

View File

@@ -0,0 +1,5 @@
from langchain_community.retrievers.pubmed import PubMedRetriever
__all__ = [
"PubMedRetriever",
]

View File

@@ -0,0 +1,220 @@
import uuid
from itertools import islice
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Tuple,
cast,
)
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.utils import pre_init
from pydantic import ConfigDict
from langchain_community.vectorstores.qdrant import Qdrant, QdrantException
@deprecated(
since="0.2.16",
alternative=(
"Qdrant vector store now supports sparse retrievals natively. "
"Use langchain_qdrant.QdrantVectorStore#as_retriever() instead. "
"Reference: "
"https://python.langchain.com/docs/integrations/vectorstores/qdrant/#sparse-vector-search"
),
removal="0.5.0",
)
class QdrantSparseVectorRetriever(BaseRetriever):
"""Qdrant sparse vector retriever."""
client: Any = None
"""'qdrant_client' instance to use."""
collection_name: str
"""Qdrant collection name."""
sparse_vector_name: str
"""Name of the sparse vector to use."""
sparse_encoder: Callable[[str], Tuple[List[int], List[float]]]
"""Sparse encoder function to use."""
k: int = 4
"""Number of documents to return per query. Defaults to 4."""
filter: Optional[Any] = None
"""Qdrant qdrant_client.models.Filter to use for queries. Defaults to None."""
content_payload_key: str = "content"
"""Payload field containing the document content. Defaults to 'content'"""
metadata_payload_key: str = "metadata"
"""Payload field containing the document metadata. Defaults to 'metadata'."""
search_options: Dict[str, Any] = {}
"""Additional search options to pass to qdrant_client.QdrantClient.search()."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that 'qdrant_client' python package exists in environment."""
try:
from grpc import RpcError
from qdrant_client import QdrantClient, models
from qdrant_client.http.exceptions import UnexpectedResponse
except ImportError:
raise ImportError(
"Could not import qdrant-client python package. "
"Please install it with `pip install qdrant-client`."
)
client = values["client"]
if not isinstance(client, QdrantClient):
raise ValueError(
f"client should be an instance of qdrant_client.QdrantClient, "
f"got {type(client)}"
)
filter = values["filter"]
if filter is not None and not isinstance(filter, models.Filter):
raise ValueError(
f"filter should be an instance of qdrant_client.models.Filter, "
f"got {type(filter)}"
)
client = cast(QdrantClient, client)
collection_name = values["collection_name"]
sparse_vector_name = values["sparse_vector_name"]
try:
collection_info = client.get_collection(collection_name)
sparse_vectors_config = collection_info.config.params.sparse_vectors
if sparse_vector_name not in sparse_vectors_config:
raise QdrantException(
f"Existing Qdrant collection {collection_name} does not "
f"contain sparse vector named {sparse_vector_name}."
f"Did you mean one of {', '.join(sparse_vectors_config.keys())}?"
)
except (UnexpectedResponse, RpcError, ValueError):
raise QdrantException(
f"Qdrant collection {collection_name} does not exist."
)
return values
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
from qdrant_client import QdrantClient, models
client = cast(QdrantClient, self.client)
query_indices, query_values = self.sparse_encoder(query)
results = client.search(
self.collection_name,
query_filter=self.filter,
query_vector=models.NamedSparseVector(
name=self.sparse_vector_name,
vector=models.SparseVector(
indices=query_indices,
values=query_values,
),
),
limit=self.k,
with_vectors=False,
**self.search_options,
)
return [
Qdrant._document_from_scored_point(
point,
self.collection_name,
self.content_payload_key,
self.metadata_payload_key,
)
for point in results
]
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Run more documents through the embeddings and add to the vectorstore.
Args:
documents (List[Document]: Documents to add to the vectorstore.
Returns:
List[str]: List of IDs of the added texts.
"""
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return self.add_texts(texts, metadatas, **kwargs)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
**kwargs: Any,
) -> List[str]:
from qdrant_client import QdrantClient
added_ids = []
client = cast(QdrantClient, self.client)
for batch_ids, points in self._generate_rest_batches(
texts, metadatas, ids, batch_size
):
client.upsert(self.collection_name, points=points, **kwargs)
added_ids.extend(batch_ids)
return added_ids
def _generate_rest_batches(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
) -> Generator[Tuple[List[str], List[Any]], None, None]:
from qdrant_client import models as rest
texts_iterator = iter(texts)
metadatas_iterator = iter(metadatas or [])
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata and id for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
batch_ids = list(islice(ids_iterator, batch_size))
# Generate the sparse embeddings for all the texts in a batch
batch_embeddings: List[Tuple[List[int], List[float]]] = [
self.sparse_encoder(text) for text in batch_texts
]
points = [
rest.PointStruct(
id=point_id,
vector={
self.sparse_vector_name: rest.SparseVector(
indices=sparse_vector[0],
values=sparse_vector[1],
)
},
payload=payload,
)
for point_id, sparse_vector, payload in zip(
batch_ids,
batch_embeddings,
Qdrant._build_payloads(
batch_texts,
batch_metadatas,
self.content_payload_key,
self.metadata_payload_key,
),
)
]
yield batch_ids, points

View File

@@ -0,0 +1,20 @@
from typing import List
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_community.utilities.rememberizer import RememberizerAPIWrapper
class RememberizerRetriever(BaseRetriever, RememberizerAPIWrapper):
"""`Rememberizer` retriever.
It wraps load() to get_relevant_documents().
It uses all RememberizerAPIWrapper arguments without any change.
"""
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
return self.load(query=query)

View File

@@ -0,0 +1,56 @@
from typing import List, Optional
import aiohttp
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class RemoteLangChainRetriever(BaseRetriever):
"""`LangChain API` retriever."""
url: str
"""URL of the remote LangChain API."""
headers: Optional[dict] = None
"""Headers to use for the request."""
input_key: str = "message"
"""Key to use for the input in the request."""
response_key: str = "response"
"""Key to use for the response in the request."""
page_content_key: str = "page_content"
"""Key to use for the page content in the response."""
metadata_key: str = "metadata"
"""Key to use for the metadata in the response."""
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
response = requests.post(
self.url, json={self.input_key: query}, headers=self.headers
)
result = response.json()
return [
Document(
page_content=r[self.page_content_key], metadata=r[self.metadata_key]
)
for r in result[self.response_key]
]
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
async with aiohttp.ClientSession() as session:
async with session.request(
"POST", self.url, headers=self.headers, json={self.input_key: query}
) as response:
result = await response.json()
return [
Document(
page_content=r[self.page_content_key], metadata=r[self.metadata_key]
)
for r in result[self.response_key]
]

View File

@@ -0,0 +1,127 @@
from __future__ import annotations
import concurrent.futures
from typing import Any, Iterable, List, Optional
import numpy as np
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from pydantic import ConfigDict
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
"""
Create an index of embeddings for a list of contexts.
Args:
contexts: List of contexts to embed.
embeddings: Embeddings model to use.
Returns:
Index of embeddings.
"""
with concurrent.futures.ThreadPoolExecutor() as executor:
return np.array(list(executor.map(embeddings.embed_query, contexts)))
class SVMRetriever(BaseRetriever):
"""`SVM` retriever.
Largely based on
https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb
"""
embeddings: Embeddings
"""Embeddings model to use."""
index: Any = None
"""Index of embeddings."""
texts: List[str]
"""List of texts to index."""
metadatas: Optional[List[dict]] = None
"""List of metadatas corresponding with each text."""
k: int = 4
"""Number of results to return."""
relevancy_threshold: Optional[float] = None
"""Threshold for relevancy."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@classmethod
def from_texts(
cls,
texts: List[str],
embeddings: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> SVMRetriever:
index = create_index(texts, embeddings)
return cls(
embeddings=embeddings,
index=index,
texts=texts,
metadatas=metadatas,
**kwargs,
)
@classmethod
def from_documents(
cls,
documents: Iterable[Document],
embeddings: Embeddings,
**kwargs: Any,
) -> SVMRetriever:
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
return cls.from_texts(
texts=texts, embeddings=embeddings, metadatas=metadatas, **kwargs
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
try:
from sklearn import svm
except ImportError:
raise ImportError(
"Could not import scikit-learn, please install with `pip install "
"scikit-learn`."
)
query_embeds = np.array(self.embeddings.embed_query(query))
x = np.concatenate([query_embeds[None, ...], self.index])
y = np.zeros(x.shape[0])
y[0] = 1
clf = svm.LinearSVC(
class_weight="balanced", verbose=False, max_iter=10000, tol=1e-6, C=0.1
)
clf.fit(x, y)
similarities = clf.decision_function(x)
sorted_ix = np.argsort(-similarities)
# svm.LinearSVC in scikit-learn is non-deterministic.
# if a text is the same as a query, there is no guarantee
# the query will be in the first index.
# this performs a simple swap, this works because anything
# left of the 0 should be equivalent.
zero_index = np.where(sorted_ix == 0)[0][0]
if zero_index != 0:
sorted_ix[0], sorted_ix[zero_index] = sorted_ix[zero_index], sorted_ix[0]
denominator = np.max(similarities) - np.min(similarities) + 1e-6
normalized_similarities = (similarities - np.min(similarities)) / denominator
top_k_results = []
for row in sorted_ix[1 : self.k + 1]:
if (
self.relevancy_threshold is None
or normalized_similarities[row] >= self.relevancy_threshold
):
metadata = self.metadatas[row - 1] if self.metadatas else {}
doc = Document(page_content=self.texts[row - 1], metadata=metadata)
top_k_results.append(doc)
return top_k_results

View File

@@ -0,0 +1,152 @@
import os
from enum import Enum
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class SearchDepth(Enum):
"""Search depth as enumerator."""
BASIC = "basic"
ADVANCED = "advanced"
class TavilySearchAPIRetriever(BaseRetriever):
"""Tavily Search API retriever.
Setup:
Install ``langchain-community`` and set environment variable ``TAVILY_API_KEY``.
.. code-block:: bash
pip install -U langchain-community
export TAVILY_API_KEY="your-api-key"
Key init args:
k: int
Number of results to include.
include_generated_answer: bool
Include a generated answer with results
include_raw_content: bool
Include raw content with results.
include_images: bool
Return images in addition to text.
Instantiate:
.. code-block:: python
from langchain_community.retrievers import TavilySearchAPIRetriever
retriever = TavilySearchAPIRetriever(k=3)
Usage:
.. code-block:: python
query = "what year was breath of the wild released?"
retriever.invoke(query)
Use within a chain:
.. code-block:: python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
prompt = ChatPromptTemplate.from_template(
\"\"\"Answer the question based only on the context provided.
Context: {context}
Question: {question}\"\"\"
)
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
chain.invoke("how many units did bretch of the wild sell in 2020")
""" # noqa: E501
k: int = 10
include_generated_answer: bool = False
include_raw_content: bool = False
include_images: bool = False
search_depth: SearchDepth = SearchDepth.BASIC
include_domains: Optional[List[str]] = None
exclude_domains: Optional[List[str]] = None
kwargs: Optional[Dict[str, Any]] = {}
api_key: Optional[str] = None
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
try:
try:
from tavily import TavilyClient
except ImportError:
# Older of tavily used Client
from tavily import Client as TavilyClient
except ImportError:
raise ImportError(
"Tavily python package not found. "
"Please install it with `pip install tavily-python`."
)
tavily = TavilyClient(api_key=self.api_key or os.environ["TAVILY_API_KEY"])
max_results = self.k if not self.include_generated_answer else self.k - 1
response = tavily.search(
query=query,
max_results=max_results,
search_depth=self.search_depth.value,
include_answer=self.include_generated_answer,
include_domains=self.include_domains,
exclude_domains=self.exclude_domains,
include_raw_content=self.include_raw_content,
include_images=self.include_images,
**self.kwargs,
)
docs = [
Document(
page_content=result.get("content", "")
if not self.include_raw_content
else (result.get("raw_content") or ""),
metadata={
"title": result.get("title", ""),
"source": result.get("url", ""),
**{
k: v
for k, v in result.items()
if k not in ("content", "title", "url", "raw_content")
},
"images": response.get("images"),
},
)
for result in response.get("results")
]
if self.include_generated_answer:
docs = [
Document(
page_content=response.get("answer", ""),
metadata={
"title": "Suggested Answer",
"source": "https://tavily.com/",
},
),
*docs,
]
return docs

View File

@@ -0,0 +1,159 @@
from __future__ import annotations
import pickle
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import ConfigDict
class TFIDFRetriever(BaseRetriever):
"""`TF-IDF` retriever.
Largely based on
https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb
"""
vectorizer: Any = None
"""TF-IDF vectorizer."""
docs: List[Document]
"""Documents."""
tfidf_array: Any = None
"""TF-IDF array."""
k: int = 4
"""Number of documents to return."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@classmethod
def from_texts(
cls,
texts: Iterable[str],
metadatas: Optional[Iterable[dict]] = None,
tfidf_params: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> TFIDFRetriever:
try:
from sklearn.feature_extraction.text import TfidfVectorizer
except ImportError:
raise ImportError(
"Could not import scikit-learn, please install with `pip install "
"scikit-learn`."
)
tfidf_params = tfidf_params or {}
vectorizer = TfidfVectorizer(**tfidf_params)
tfidf_array = vectorizer.fit_transform(texts)
metadatas = metadatas or ({} for _ in texts)
docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array, **kwargs)
@classmethod
def from_documents(
cls,
documents: Iterable[Document],
*,
tfidf_params: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> TFIDFRetriever:
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
return cls.from_texts(
texts=texts, tfidf_params=tfidf_params, metadatas=metadatas, **kwargs
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
from sklearn.metrics.pairwise import cosine_similarity
query_vec = self.vectorizer.transform(
[query]
) # Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
results = cosine_similarity(self.tfidf_array, query_vec).reshape(
(-1,)
) # Op -- (n_docs,1) -- Cosine Sim with each doc
return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
return return_docs
def save_local(
self,
folder_path: str,
file_name: str = "tfidf_vectorizer",
) -> None:
try:
import joblib
except ImportError:
raise ImportError(
"Could not import joblib, please install with `pip install joblib`."
)
path = Path(folder_path)
path.mkdir(exist_ok=True, parents=True)
# Save vectorizer with joblib dump.
joblib.dump(self.vectorizer, path / f"{file_name}.joblib")
# Save docs and tfidf array as pickle.
with open(path / f"{file_name}.pkl", "wb") as f:
pickle.dump((self.docs, self.tfidf_array), f)
@classmethod
def load_local(
cls,
folder_path: str,
*,
allow_dangerous_deserialization: bool = False,
file_name: str = "tfidf_vectorizer",
) -> TFIDFRetriever:
"""Load the retriever from local storage.
Args:
folder_path: Folder path to load from.
allow_dangerous_deserialization: Whether to allow dangerous deserialization.
Defaults to False.
The deserialization relies on .joblib and .pkl files, which can be
modified to deliver a malicious payload that results in execution of
arbitrary code on your machine. You will need to set this to `True` to
use deserialization. If you do this, make sure you trust the source of
the file.
file_name: File name to load from. Defaults to "tfidf_vectorizer".
Returns:
TFIDFRetriever: Loaded retriever.
"""
try:
import joblib
except ImportError:
raise ImportError(
"Could not import joblib, please install with `pip install joblib`."
)
if not allow_dangerous_deserialization:
raise ValueError(
"The de-serialization of this retriever is based on .joblib and "
".pkl files."
"Such files can be modified to deliver a malicious payload that "
"results in execution of arbitrary code on your machine."
"You will need to set `allow_dangerous_deserialization` to `True` to "
"load this retriever. If you do this, make sure you trust the source "
"of the file, and you are responsible for validating the file "
"came from a trusted source."
)
path = Path(folder_path)
# Load vectorizer with joblib load.
vectorizer = joblib.load(path / f"{file_name}.joblib")
# Load docs and tfidf array as pickle.
with open(path / f"{file_name}.pkl", "rb") as f:
# This code path can only be triggered if the user
# passed allow_dangerous_deserialization=True
docs, tfidf_array = pickle.load(f) # ignore[pickle]: explicit-opt-in
return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array)

View File

@@ -0,0 +1,258 @@
from __future__ import annotations
import importlib
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from pydantic import ConfigDict, SecretStr
class NeuralDBRetriever(BaseRetriever):
"""Document retriever that uses ThirdAI's NeuralDB."""
thirdai_key: SecretStr
"""ThirdAI API Key"""
db: Any = None #: :meta private:
"""NeuralDB instance"""
model_config = ConfigDict(
extra="forbid",
)
@staticmethod
def _verify_thirdai_library(thirdai_key: Optional[str] = None) -> None:
try:
from thirdai import licensing
importlib.util.find_spec("thirdai.neural_db")
licensing.activate(thirdai_key or os.getenv("THIRDAI_KEY"))
except ImportError:
raise ImportError(
"Could not import thirdai python package and neuraldb dependencies. "
"Please install it with `pip install thirdai[neural_db]`."
)
@classmethod
def from_scratch(
cls,
thirdai_key: Optional[str] = None,
**model_kwargs: dict,
) -> NeuralDBRetriever:
"""
Create a NeuralDBRetriever from scratch.
To use, set the ``THIRDAI_KEY`` environment variable with your ThirdAI
API key, or pass ``thirdai_key`` as a named parameter.
Example:
.. code-block:: python
from langchain_community.retrievers import NeuralDBRetriever
retriever = NeuralDBRetriever.from_scratch(
thirdai_key="your-thirdai-key",
)
retriever.insert([
"/path/to/doc.pdf",
"/path/to/doc.docx",
"/path/to/doc.csv",
])
documents = retriever.invoke("AI-driven music therapy")
"""
NeuralDBRetriever._verify_thirdai_library(thirdai_key)
from thirdai import neural_db as ndb
return cls(thirdai_key=thirdai_key, db=ndb.NeuralDB(**model_kwargs)) # type: ignore[arg-type]
@classmethod
def from_checkpoint(
cls,
checkpoint: Union[str, Path],
thirdai_key: Optional[str] = None,
) -> NeuralDBRetriever:
"""
Create a NeuralDBRetriever with a base model from a saved checkpoint
To use, set the ``THIRDAI_KEY`` environment variable with your ThirdAI
API key, or pass ``thirdai_key`` as a named parameter.
Example:
.. code-block:: python
from langchain_community.retrievers import NeuralDBRetriever
retriever = NeuralDBRetriever.from_checkpoint(
checkpoint="/path/to/checkpoint.ndb",
thirdai_key="your-thirdai-key",
)
retriever.insert([
"/path/to/doc.pdf",
"/path/to/doc.docx",
"/path/to/doc.csv",
])
documents = retriever.invoke("AI-driven music therapy")
"""
NeuralDBRetriever._verify_thirdai_library(thirdai_key)
from thirdai import neural_db as ndb
return cls(thirdai_key=thirdai_key, db=ndb.NeuralDB.from_checkpoint(checkpoint)) # type: ignore[arg-type]
@pre_init
def validate_environments(cls, values: Dict) -> Dict:
"""Validate ThirdAI environment variables."""
values["thirdai_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"thirdai_key",
"THIRDAI_KEY",
)
)
return values
def insert(
self,
sources: List[Any],
train: bool = True,
fast_mode: bool = True,
**kwargs: dict,
) -> None:
"""Inserts files / document sources into the retriever.
Args:
train: When True this means that the underlying model in the
NeuralDB will undergo unsupervised pretraining on the inserted files.
Defaults to True.
fast_mode: Much faster insertion with a slight drop in performance.
Defaults to True.
"""
sources = self._preprocess_sources(sources)
self.db.insert(
sources=sources,
train=train,
fast_approximation=fast_mode,
**kwargs,
)
def _preprocess_sources(self, sources: list) -> list:
"""Checks if the provided sources are string paths. If they are, convert
to NeuralDB document objects.
Args:
sources: list of either string paths to PDF, DOCX or CSV files, or
NeuralDB document objects.
"""
from thirdai import neural_db as ndb
if not sources:
return sources
preprocessed_sources = []
for doc in sources:
if not isinstance(doc, str):
preprocessed_sources.append(doc)
else:
if doc.lower().endswith(".pdf"):
preprocessed_sources.append(ndb.PDF(doc))
elif doc.lower().endswith(".docx"):
preprocessed_sources.append(ndb.DOCX(doc))
elif doc.lower().endswith(".csv"):
preprocessed_sources.append(ndb.CSV(doc))
else:
raise RuntimeError(
f"Could not automatically load {doc}. Only files "
"with .pdf, .docx, or .csv extensions can be loaded "
"automatically. For other formats, please use the "
"appropriate document object from the ThirdAI library."
)
return preprocessed_sources
def upvote(self, query: str, document_id: int) -> None:
"""The retriever upweights the score of a document for a specific query.
This is useful for fine-tuning the retriever to user behavior.
Args:
query: text to associate with `document_id`
document_id: id of the document to associate query with.
"""
self.db.text_to_result(query, document_id)
def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]) -> None:
"""Given a batch of (query, document id) pairs, the retriever upweights
the scores of the document for the corresponding queries.
This is useful for fine-tuning the retriever to user behavior.
Args:
query_id_pairs: list of (query, document id) pairs. For each pair in
this list, the model will upweight the document id for the query.
"""
self.db.text_to_result_batch(query_id_pairs)
def associate(self, source: str, target: str) -> None:
"""The retriever associates a source phrase with a target phrase.
When the retriever sees the source phrase, it will also consider results
that are relevant to the target phrase.
Args:
source: text to associate to `target`.
target: text to associate `source` to.
"""
self.db.associate(source, target)
def associate_batch(self, text_pairs: List[Tuple[str, str]]) -> None:
"""Given a batch of (source, target) pairs, the retriever associates
each source phrase with the corresponding target phrase.
Args:
text_pairs: list of (source, target) text pairs. For each pair in
this list, the source will be associated with the target.
"""
self.db.associate_batch(text_pairs)
def _get_relevant_documents(
self, query: str, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
"""Retrieve {top_k} contexts with your retriever for a given query
Args:
query: Query to submit to the model
top_k: The max number of context results to retrieve. Defaults to 10.
"""
try:
if "top_k" not in kwargs:
kwargs["top_k"] = 10
references = self.db.search(query=query, **kwargs)
return [
Document(
page_content=ref.text,
metadata={
"id": ref.id,
"upvote_ids": ref.upvote_ids,
"source": ref.source,
"metadata": ref.metadata,
"score": ref.score,
"context": ref.context(1),
},
)
for ref in references
]
except Exception as e:
raise ValueError(f"Error while retrieving documents: {e}") from e
def save(self, path: str) -> None:
"""Saves a NeuralDB instance to disk. Can be loaded into memory by
calling NeuralDB.from_checkpoint(path)
Args:
path: path on disk to save the NeuralDB instance to.
"""
self.db.save(path)

View File

@@ -0,0 +1,126 @@
from __future__ import annotations
import json
from typing import Any, Dict, List, Literal, Optional, Sequence, Union
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class VespaRetriever(BaseRetriever):
"""`Vespa` retriever."""
app: Any
"""Vespa application to query."""
body: Dict
"""Body of the query."""
content_field: str
"""Name of the content field."""
metadata_fields: Sequence[str]
"""Names of the metadata fields."""
def _query(self, body: Dict) -> List[Document]:
response = self.app.query(body)
if not str(response.status_code).startswith("2"):
raise RuntimeError(
"Could not retrieve data from Vespa. Error code: {}".format(
response.status_code
)
)
root = response.json["root"]
if "errors" in root:
raise RuntimeError(json.dumps(root["errors"]))
docs = []
for child in response.hits:
page_content = child["fields"].pop(self.content_field, "")
if self.metadata_fields == "*":
metadata = child["fields"]
else:
metadata = {mf: child["fields"].get(mf) for mf in self.metadata_fields}
metadata["id"] = child["id"]
docs.append(Document(page_content=page_content, metadata=metadata))
return docs
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
body = self.body.copy()
body["query"] = query
return self._query(body)
def get_relevant_documents_with_filter(
self, query: str, *, _filter: Optional[str] = None
) -> List[Document]:
body = self.body.copy()
_filter = f" and {_filter}" if _filter else ""
body["yql"] = body["yql"] + _filter
body["query"] = query
return self._query(body)
@classmethod
def from_params(
cls,
url: str,
content_field: str,
*,
k: Optional[int] = None,
metadata_fields: Union[Sequence[str], Literal["*"]] = (),
sources: Union[Sequence[str], Literal["*"], None] = None,
_filter: Optional[str] = None,
yql: Optional[str] = None,
**kwargs: Any,
) -> VespaRetriever:
"""Instantiate retriever from params.
Args:
url (str): Vespa app URL.
content_field (str): Field in results to return as Document page_content.
k (Optional[int]): Number of Documents to return. Defaults to None.
metadata_fields(Sequence[str] or "*"): Fields in results to include in
document metadata. Defaults to empty tuple ().
sources (Sequence[str] or "*" or None): Sources to retrieve
from. Defaults to None.
_filter (Optional[str]): Document filter condition expressed in YQL.
Defaults to None.
yql (Optional[str]): Full YQL query to be used. Should not be specified
if _filter or sources are specified. Defaults to None.
kwargs (Any): Keyword arguments added to query body.
Returns:
VespaRetriever: Instantiated VespaRetriever.
"""
try:
from vespa.application import Vespa
except ImportError:
raise ImportError(
"pyvespa is not installed, please install with `pip install pyvespa`"
)
app = Vespa(url)
body = kwargs.copy()
if yql and (sources or _filter):
raise ValueError(
"yql should only be specified if both sources and _filter are not "
"specified."
)
else:
if metadata_fields == "*":
_fields = "*"
body["summary"] = "short"
else:
_fields = ", ".join([content_field] + list(metadata_fields or []))
_sources = ", ".join(sources) if isinstance(sources, Sequence) else "*"
_filter = f" and {_filter}" if _filter else ""
yql = f"select {_fields} from sources {_sources} where userQuery(){_filter}"
body["yql"] = yql
if k:
body["hits"] = k
return cls(
app=app,
body=body,
content_field=content_field,
metadata_fields=metadata_fields,
)

View File

@@ -0,0 +1,168 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, cast
from uuid import uuid4
from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import ConfigDict, model_validator
@deprecated(
since="0.3.18",
removal="1.0",
alternative_import="langchain_weaviate.WeaviateVectorStore",
)
class WeaviateHybridSearchRetriever(BaseRetriever):
"""`Weaviate hybrid search` retriever.
See the documentation:
https://weaviate.io/blog/hybrid-search-explained
"""
client: Any = None
"""keyword arguments to pass to the Weaviate client."""
index_name: str
"""The name of the index to use."""
text_key: str
"""The name of the text key to use."""
alpha: float = 0.5
"""The weight of the text key in the hybrid search."""
k: int = 4
"""The number of results to return."""
attributes: List[str]
"""The attributes to return in the results."""
create_schema_if_missing: bool = True
"""Whether to create the schema if it doesn't exist."""
@model_validator(mode="before")
@classmethod
def validate_client(
cls,
values: Dict[str, Any],
) -> Any:
try:
import weaviate
except ImportError:
raise ImportError(
"Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`."
)
if not isinstance(values["client"], weaviate.Client):
client = values["client"]
raise ValueError(
f"client should be an instance of weaviate.Client, got {type(client)}"
)
if values.get("attributes") is None:
values["attributes"] = []
cast(List, values["attributes"]).append(values["text_key"])
if values.get("create_schema_if_missing", True):
class_obj = {
"class": values["index_name"],
"properties": [{"name": values["text_key"], "dataType": ["text"]}],
"vectorizer": "text2vec-openai",
}
if not values["client"].schema.exists(values["index_name"]):
values["client"].schema.create_class(class_obj)
return values
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
# added text_key
def add_documents(self, docs: List[Document], **kwargs: Any) -> List[str]:
"""Upload documents to Weaviate."""
from weaviate.util import get_valid_uuid
with self.client.batch as batch:
ids = []
for i, doc in enumerate(docs):
metadata = doc.metadata or {}
data_properties = {self.text_key: doc.page_content, **metadata}
# If the UUID of one of the objects already exists
# then the existing objectwill be replaced by the new object.
if "uuids" in kwargs:
_id = kwargs["uuids"][i]
else:
_id = get_valid_uuid(uuid4())
batch.add_data_object(data_properties, self.index_name, _id)
ids.append(_id)
return ids
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
where_filter: Optional[Dict[str, object]] = None,
score: bool = False,
hybrid_search_kwargs: Optional[Dict[str, object]] = None,
) -> List[Document]:
"""Look up similar documents in Weaviate.
query: The query to search for relevant documents
of using weviate hybrid search.
where_filter: A filter to apply to the query.
https://weaviate.io/developers/weaviate/guides/querying/#filtering
score: Whether to include the score, and score explanation
in the returned Documents meta_data.
hybrid_search_kwargs: Used to pass additional arguments
to the .with_hybrid() method.
The primary uses cases for this are:
1) Search specific properties only -
specify which properties to be used during hybrid search portion.
Note: this is not the same as the (self.attributes) to be returned.
Example - hybrid_search_kwargs={"properties": ["question", "answer"]}
https://weaviate.io/developers/weaviate/search/hybrid#selected-properties-only
2) Weight boosted searched properties -
Boost the weight of certain properties during the hybrid search portion.
Example - hybrid_search_kwargs={"properties": ["question^2", "answer"]}
https://weaviate.io/developers/weaviate/search/hybrid#weight-boost-searched-properties
3) Search with a custom vector - Define a different vector
to be used during the hybrid search portion.
Example - hybrid_search_kwargs={"vector": [0.1, 0.2, 0.3, ...]}
https://weaviate.io/developers/weaviate/search/hybrid#with-a-custom-vector
4) Use Fusion ranking method
Example - from weaviate.gql.get import HybridFusion
hybrid_search_kwargs={"fusion": fusion_type=HybridFusion.RELATIVE_SCORE}
https://weaviate.io/developers/weaviate/search/hybrid#fusion-ranking-method
"""
query_obj = self.client.query.get(self.index_name, self.attributes)
if where_filter:
query_obj = query_obj.with_where(where_filter)
if score:
query_obj = query_obj.with_additional(["score", "explainScore"])
if hybrid_search_kwargs is None:
hybrid_search_kwargs = {}
result = (
query_obj.with_hybrid(query, alpha=self.alpha, **hybrid_search_kwargs)
.with_limit(self.k)
.do()
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][self.index_name]:
text = res.pop(self.text_key)
docs.append(Document(page_content=text, metadata=res))
return docs

View File

@@ -0,0 +1,267 @@
import logging
import re
from typing import Any, List, Optional
from langchain_classic.chains import LLMChain
from langchain_classic.chains.prompt_selector import ConditionalPromptSelector
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLLM
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from pydantic import BaseModel, Field
from langchain_community.document_loaders import AsyncHtmlLoader
from langchain_community.document_transformers import Html2TextTransformer
from langchain_community.llms import LlamaCpp
from langchain_community.utilities import GoogleSearchAPIWrapper
logger = logging.getLogger(__name__)
class SearchQueries(BaseModel):
"""Search queries to research for the user's goal."""
queries: List[str] = Field(
..., description="List of search queries to look up on Google"
)
DEFAULT_LLAMA_SEARCH_PROMPT = PromptTemplate(
input_variables=["question"],
template="""<<SYS>> \n You are an assistant tasked with improving Google search \
results. \n <</SYS>> \n\n [INST] Generate THREE Google search queries that \
are similar to this question. The output should be a numbered list of questions \
and each should have a question mark at the end: \n\n {question} [/INST]""",
)
DEFAULT_SEARCH_PROMPT = PromptTemplate(
input_variables=["question"],
template="""You are an assistant tasked with improving Google search \
results. Generate THREE Google search queries that are similar to \
this question. The output should be a numbered list of questions and each \
should have a question mark at the end: {question}""",
)
class QuestionListOutputParser(BaseOutputParser[List[str]]):
"""Output parser for a list of numbered questions."""
def parse(self, text: str) -> List[str]:
lines = re.findall(r"\d+\..*?(?:\n|$)", text)
return lines
class WebResearchRetriever(BaseRetriever):
"""`Google Search API` retriever."""
# Inputs
vectorstore: VectorStore = Field(
..., description="Vector store for storing web pages"
)
llm_chain: LLMChain
search: GoogleSearchAPIWrapper = Field(..., description="Google Search API Wrapper")
num_search_results: int = Field(1, description="Number of pages per Google search")
text_splitter: TextSplitter = Field(
RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=50),
description="Text splitter for splitting web pages into chunks",
)
url_database: List[str] = Field(
default_factory=list, description="List of processed URLs"
)
trust_env: bool = Field(
False,
description="Whether to use the http_proxy/https_proxy env variables or "
"check .netrc for proxy configuration",
)
allow_dangerous_requests: bool = False
"""A flag to force users to acknowledge the risks of SSRF attacks when using
this retriever.
Users should set this flag to `True` if they have taken the necessary precautions
to prevent SSRF attacks when using this retriever.
For example, users can run the requests through a properly configured
proxy and prevent the crawler from accidentally crawling internal resources.
"""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the retriever."""
allow_dangerous_requests = kwargs.get("allow_dangerous_requests", False)
if not allow_dangerous_requests:
raise ValueError(
"WebResearchRetriever crawls URLs surfaced through "
"the provided search engine. It is possible that some of those URLs "
"will end up pointing to machines residing on an internal network, "
"leading"
"to an SSRF (Server-Side Request Forgery) attack. "
"To protect yourself against that risk, you can run the requests "
"through a proxy and prevent the crawler from accidentally crawling "
"internal resources."
"If've taken the necessary precautions, you can set "
"`allow_dangerous_requests` to `True`."
)
super().__init__(**kwargs)
@classmethod
def from_llm(
cls,
vectorstore: VectorStore,
llm: BaseLLM,
search: GoogleSearchAPIWrapper,
prompt: Optional[BasePromptTemplate] = None,
num_search_results: int = 1,
text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
chunk_size=1500, chunk_overlap=150
),
trust_env: bool = False,
allow_dangerous_requests: bool = False,
) -> "WebResearchRetriever":
"""Initialize from llm using default template.
Args:
vectorstore: Vector store for storing web pages
llm: llm for search question generation
search: GoogleSearchAPIWrapper
prompt: prompt to generating search questions
num_search_results: Number of pages per Google search
text_splitter: Text splitter for splitting web pages into chunks
trust_env: Whether to use the http_proxy/https_proxy env variables
or check .netrc for proxy configuration
allow_dangerous_requests: A flag to force users to acknowledge
the risks of SSRF attacks when using this retriever
Returns:
WebResearchRetriever
"""
if not prompt:
QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=DEFAULT_SEARCH_PROMPT,
conditionals=[
(lambda llm: isinstance(llm, LlamaCpp), DEFAULT_LLAMA_SEARCH_PROMPT)
],
)
prompt = QUESTION_PROMPT_SELECTOR.get_prompt(llm)
# Use chat model prompt
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
output_parser=QuestionListOutputParser(),
)
return cls(
vectorstore=vectorstore,
llm_chain=llm_chain,
search=search,
num_search_results=num_search_results,
text_splitter=text_splitter,
trust_env=trust_env,
allow_dangerous_requests=allow_dangerous_requests,
)
def clean_search_query(self, query: str) -> str:
# Some search tools (e.g., Google) will
# fail to return results if query has a
# leading digit: 1. "LangCh..."
# Check if the first character is a digit
if query[0].isdigit():
# Find the position of the first quote
first_quote_pos = query.find('"')
if first_quote_pos != -1:
# Extract the part of the string after the quote
query = query[first_quote_pos + 1 :]
# Remove the trailing quote if present
if query.endswith('"'):
query = query[:-1]
return query.strip()
def search_tool(self, query: str, num_search_results: int = 1) -> List[dict]:
"""Returns num_search_results pages per Google search."""
query_clean = self.clean_search_query(query)
result = self.search.results(query_clean, num_search_results)
return result
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
"""Search Google for documents related to the query input.
Args:
query: user query
Returns:
Relevant documents from all various urls.
"""
# Get search questions
logger.info("Generating questions for Google Search ...")
result = self.llm_chain({"question": query})
logger.info(f"Questions for Google Search (raw): {result}")
questions = result["text"]
logger.info(f"Questions for Google Search: {questions}")
# Get urls
logger.info("Searching for relevant urls...")
urls_to_look = []
for query in questions:
# Google search
search_results = self.search_tool(query, self.num_search_results)
logger.info("Searching for relevant urls...")
logger.info(f"Search results: {search_results}")
for res in search_results:
if res.get("link", None):
urls_to_look.append(res["link"])
# Relevant urls
urls = set(urls_to_look)
# Check for any new urls that we have not processed
new_urls = list(urls.difference(self.url_database))
logger.info(f"New URLs to load: {new_urls}")
# Load, split, and add new urls to vectorstore
if new_urls:
loader = AsyncHtmlLoader(
new_urls, ignore_load_errors=True, trust_env=self.trust_env
)
html2text = Html2TextTransformer()
logger.info("Indexing new urls...")
docs = loader.load()
docs = list(html2text.transform_documents(docs))
docs = self.text_splitter.split_documents(docs)
self.vectorstore.add_documents(docs)
self.url_database.extend(new_urls)
# Search for relevant splits
# TODO: make this async
logger.info("Grabbing most relevant splits from urls...")
docs = []
for query in questions:
docs.extend(self.vectorstore.similarity_search(query))
# Get unique docs
unique_documents_dict = {
(doc.page_content, tuple(sorted(doc.metadata.items()))): doc for doc in docs
}
unique_documents = list(unique_documents_dict.values())
return unique_documents
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> List[Document]:
raise NotImplementedError

View File

@@ -0,0 +1,77 @@
from typing import List
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
"""`Wikipedia API` retriever.
Setup:
Install the ``wikipedia`` dependency:
.. code-block:: bash
pip install -U wikipedia
Instantiate:
.. code-block:: python
from langchain_community.retrievers import WikipediaRetriever
retriever = WikipediaRetriever()
Usage:
.. code-block:: python
docs = retriever.invoke("TOKYO GHOUL")
print(docs[0].page_content[:100])
.. code-block:: none
Tokyo Ghoul (Japanese: 東京喰種(トーキョーグール), Hepburn: Tōkyō Gūru) is a Japanese dark fantasy
Use within a chain:
.. code-block:: python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
prompt = ChatPromptTemplate.from_template(
\"\"\"Answer the question based only on the context provided.
Context: {context}
Question: {question}\"\"\"
)
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
def format_docs(docs):
return "\\n\\n".join(doc.page_content for doc in docs)
chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
chain.invoke(
"Who is the main character in `Tokyo Ghoul` and does he transform into a ghoul?"
)
.. code-block:: none
'The main character in Tokyo Ghoul is Ken Kaneki, who transforms into a ghoul after receiving an organ transplant from a ghoul named Rize.'
""" # noqa: E501
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
return self.load(query=query)

View File

@@ -0,0 +1,39 @@
from typing import Any, List
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_community.utilities import YouSearchAPIWrapper
class YouRetriever(BaseRetriever, YouSearchAPIWrapper):
"""You.com Search API retriever.
It wraps results() to get_relevant_documents
It uses all YouSearchAPIWrapper arguments without any change.
"""
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
return self.results(query, run_manager=run_manager.get_child(), **kwargs)
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
results = await self.results_async(
query, run_manager=run_manager.get_child(), **kwargs
)
return results

View File

@@ -0,0 +1,183 @@
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import model_validator
if TYPE_CHECKING:
from zep_python.memory import MemorySearchResult
class SearchScope(str, Enum):
"""Which documents to search. Messages or Summaries?"""
messages = "messages"
"""Search chat history messages."""
summary = "summary"
"""Search chat history summaries."""
class SearchType(str, Enum):
"""Enumerator of the types of search to perform."""
similarity = "similarity"
"""Similarity search."""
mmr = "mmr"
"""Maximal Marginal Relevance reranking of similarity search."""
class ZepRetriever(BaseRetriever):
"""`Zep` MemoryStore Retriever.
Search your user's long-term chat history with Zep.
Zep offers both simple semantic search and Maximal Marginal Relevance (MMR)
reranking of search results.
Note: You will need to provide the user's `session_id` to use this retriever.
Args:
url: URL of your Zep server (required)
api_key: Your Zep API key (optional)
session_id: Identifies your user or a user's session (required)
top_k: Number of documents to return (default: 3, optional)
search_type: Type of search to perform (similarity / mmr) (default: similarity,
optional)
mmr_lambda: Lambda value for MMR search. Defaults to 0.5 (optional)
Zep - Fast, scalable building blocks for LLM Apps
=========
Zep is an open source platform for productionizing LLM apps. Go from a prototype
built in LangChain or LlamaIndex, or a custom app, to production in minutes without
rewriting code.
For server installation instructions, see:
https://docs.getzep.com/deployment/quickstart/
"""
zep_client: Optional[Any] = None
"""Zep client."""
url: str
"""URL of your Zep server."""
api_key: Optional[str] = None
"""Your Zep API key."""
session_id: str
"""Zep session ID."""
top_k: Optional[int]
"""Number of items to return."""
search_scope: SearchScope = SearchScope.messages
"""Which documents to search. Messages or Summaries?"""
search_type: SearchType = SearchType.similarity
"""Type of search to perform (similarity / mmr)"""
mmr_lambda: Optional[float] = None
"""Lambda value for MMR search."""
@model_validator(mode="before")
@classmethod
def create_client(cls, values: dict) -> Any:
try:
from zep_python import ZepClient
except ImportError:
raise ImportError(
"Could not import zep-python package. "
"Please install it with `pip install zep-python`."
)
values["zep_client"] = values.get(
"zep_client",
ZepClient(base_url=values["url"], api_key=values.get("api_key")),
)
return values
def _messages_search_result_to_doc(
self, results: List[MemorySearchResult]
) -> List[Document]:
return [
Document(
page_content=r.message.pop("content"),
metadata={"score": r.dist, **r.message},
)
for r in results
if r.message
]
def _summary_search_result_to_doc(
self, results: List[MemorySearchResult]
) -> List[Document]:
return [
Document(
page_content=r.summary.content,
metadata={
"score": r.dist,
"uuid": r.summary.uuid,
"created_at": r.summary.created_at,
"token_count": r.summary.token_count,
},
)
for r in results
if r.summary
]
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
metadata: Optional[Dict[str, Any]] = None,
) -> List[Document]:
from zep_python.memory import MemorySearchPayload
if not self.zep_client:
raise RuntimeError("Zep client not initialized.")
payload = MemorySearchPayload(
text=query,
metadata=metadata,
search_scope=self.search_scope,
search_type=self.search_type,
mmr_lambda=self.mmr_lambda,
)
results: List[MemorySearchResult] = self.zep_client.memory.search_memory(
self.session_id, payload, limit=self.top_k
)
if self.search_scope == SearchScope.summary:
return self._summary_search_result_to_doc(results)
return self._messages_search_result_to_doc(results)
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
metadata: Optional[Dict[str, Any]] = None,
) -> List[Document]:
from zep_python.memory import MemorySearchPayload
if not self.zep_client:
raise RuntimeError("Zep client not initialized.")
payload = MemorySearchPayload(
text=query,
metadata=metadata,
search_scope=self.search_scope,
search_type=self.search_type,
mmr_lambda=self.mmr_lambda,
)
results: List[MemorySearchResult] = await self.zep_client.memory.asearch_memory(
self.session_id, payload, limit=self.top_k
)
if self.search_scope == SearchScope.summary:
return self._summary_search_result_to_doc(results)
return self._messages_search_result_to_doc(results)

View File

@@ -0,0 +1,163 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import model_validator
if TYPE_CHECKING:
from zep_cloud import MemorySearchResult, SearchScope, SearchType
from zep_cloud.client import AsyncZep, Zep
class ZepCloudRetriever(BaseRetriever):
"""`Zep Cloud` MemoryStore Retriever.
Search your user's long-term chat history with Zep.
Zep offers both simple semantic search and Maximal Marginal Relevance (MMR)
reranking of search results.
Note: You will need to provide the user's `session_id` to use this retriever.
Args:
api_key: Your Zep API key
session_id: Identifies your user or a user's session (required)
top_k: Number of documents to return (default: 3, optional)
search_type: Type of search to perform (similarity / mmr)
(default: similarity, optional)
mmr_lambda: Lambda value for MMR search. Defaults to 0.5 (optional)
Zep - Recall, understand, and extract data from chat histories.
Power personalized AI experiences.
=========
Zep is a long-term memory service for AI Assistant apps.
With Zep, you can provide AI assistants with the ability
to recall past conversations,
no matter how distant, while also reducing hallucinations, latency, and cost.
see Zep Cloud Docs: https://help.getzep.com
"""
api_key: str
"""Your Zep API key."""
zep_client: Zep
"""Zep client used for making API requests."""
zep_client_async: AsyncZep
"""Async Zep client used for making API requests."""
session_id: str
"""Zep session ID."""
top_k: Optional[int]
"""Number of items to return."""
search_scope: SearchScope = "messages"
"""Which documents to search. Messages or Summaries?"""
search_type: SearchType = "similarity"
"""Type of search to perform (similarity / mmr)"""
mmr_lambda: Optional[float] = None
"""Lambda value for MMR search."""
@model_validator(mode="before")
@classmethod
def create_client(cls, values: dict) -> Any:
try:
from zep_cloud.client import AsyncZep, Zep
except ImportError:
raise ImportError(
"Could not import zep-cloud package. "
"Please install it with `pip install zep-cloud`."
)
if values.get("api_key") is None:
raise ValueError("Zep API key is required.")
values["zep_client"] = Zep(api_key=values.get("api_key"))
values["zep_client_async"] = AsyncZep(api_key=values.get("api_key"))
return values
def _messages_search_result_to_doc(
self, results: List[MemorySearchResult]
) -> List[Document]:
return [
Document(
page_content=str(r.message.content),
metadata={
"score": r.score,
"uuid": r.message.uuid_,
"created_at": r.message.created_at,
"token_count": r.message.token_count,
"role": r.message.role or r.message.role_type,
},
)
for r in results or []
if r.message
]
def _summary_search_result_to_doc(
self, results: List[MemorySearchResult]
) -> List[Document]:
return [
Document(
page_content=str(r.summary.content),
metadata={
"score": r.score,
"uuid": r.summary.uuid_,
"created_at": r.summary.created_at,
"token_count": r.summary.token_count,
},
)
for r in results
if r.summary
]
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
metadata: Optional[Dict[str, Any]] = None,
) -> List[Document]:
if not self.zep_client:
raise RuntimeError("Zep client not initialized.")
results = self.zep_client.memory.search(
self.session_id,
text=query,
metadata=metadata,
search_scope=self.search_scope,
search_type=self.search_type,
mmr_lambda=self.mmr_lambda,
limit=self.top_k,
)
if self.search_scope == "summary":
return self._summary_search_result_to_doc(results)
return self._messages_search_result_to_doc(results)
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
metadata: Optional[Dict[str, Any]] = None,
) -> List[Document]:
if not self.zep_client_async:
raise RuntimeError("Zep client not initialized.")
results = await self.zep_client_async.memory.search(
self.session_id,
text=query,
metadata=metadata,
search_scope=self.search_scope,
search_type=self.search_type,
mmr_lambda=self.mmr_lambda,
limit=self.top_k,
)
if self.search_scope == "summary":
return self._summary_search_result_to_doc(results)
return self._messages_search_result_to_doc(results)

View File

@@ -0,0 +1,87 @@
import warnings
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from pydantic import model_validator
from langchain_community.vectorstores.zilliz import Zilliz
# TODO: Update to ZillizClient + Hybrid Search when available
class ZillizRetriever(BaseRetriever):
"""`Zilliz API` retriever."""
embedding_function: Embeddings
"""The underlying embedding function from which documents will be retrieved."""
collection_name: str = "LangChainCollection"
"""The name of the collection in Zilliz."""
connection_args: Optional[Dict[str, Any]] = None
"""The connection arguments for the Zilliz client."""
consistency_level: str = "Session"
"""The consistency level for the Zilliz client."""
search_params: Optional[dict] = None
"""The search parameters for the Zilliz client."""
store: Zilliz
"""The underlying Zilliz store."""
retriever: BaseRetriever
"""The underlying retriever."""
@model_validator(mode="before")
@classmethod
def create_client(cls, values: dict) -> Any:
values["store"] = Zilliz(
values["embedding_function"],
values["collection_name"],
values["connection_args"],
values["consistency_level"],
)
values["retriever"] = values["store"].as_retriever(
search_kwargs={"param": values["search_params"]}
)
return values
def add_texts(
self, texts: List[str], metadatas: Optional[List[dict]] = None
) -> None:
"""Add text to the Zilliz store
Args:
texts (List[str]): The text
metadatas (List[dict]): Metadata dicts, must line up with existing store
"""
self.store.add_texts(texts, metadatas)
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
return self.retriever.invoke(
query, run_manager=run_manager.get_child(), **kwargs
)
def ZillizRetreiver(*args: Any, **kwargs: Any) -> ZillizRetriever:
"""Deprecated ZillizRetreiver.
Please use ZillizRetriever ('i' before 'e') instead.
Args:
*args:
**kwargs:
Returns:
ZillizRetriever
"""
warnings.warn(
"ZillizRetreiver will be deprecated in the future. "
"Please use ZillizRetriever ('i' before 'e') instead.",
DeprecationWarning,
)
return ZillizRetriever(*args, **kwargs)