initial commit
This commit is contained in:
@@ -0,0 +1,49 @@
|
||||
"""**Cross encoders** are wrappers around cross encoder models from different APIs and
|
||||
services.
|
||||
|
||||
**Cross encoder models** can be LLMs or not.
|
||||
|
||||
**Class hierarchy:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
BaseCrossEncoder --> <name>CrossEncoder # Examples: SagemakerEndpointCrossEncoder
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.cross_encoders.base import (
|
||||
BaseCrossEncoder,
|
||||
)
|
||||
from langchain_community.cross_encoders.fake import (
|
||||
FakeCrossEncoder,
|
||||
)
|
||||
from langchain_community.cross_encoders.huggingface import (
|
||||
HuggingFaceCrossEncoder,
|
||||
)
|
||||
from langchain_community.cross_encoders.sagemaker_endpoint import (
|
||||
SagemakerEndpointCrossEncoder,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseCrossEncoder",
|
||||
"FakeCrossEncoder",
|
||||
"HuggingFaceCrossEncoder",
|
||||
"SagemakerEndpointCrossEncoder",
|
||||
]
|
||||
|
||||
_module_lookup = {
|
||||
"BaseCrossEncoder": "langchain_community.cross_encoders.base",
|
||||
"FakeCrossEncoder": "langchain_community.cross_encoders.fake",
|
||||
"HuggingFaceCrossEncoder": "langchain_community.cross_encoders.huggingface",
|
||||
"SagemakerEndpointCrossEncoder": "langchain_community.cross_encoders.sagemaker_endpoint", # noqa: E501
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in _module_lookup:
|
||||
module = importlib.import_module(_module_lookup[name])
|
||||
return getattr(module, name)
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,5 @@
|
||||
from langchain_classic.retrievers.document_compressors.cross_encoder import (
|
||||
BaseCrossEncoder,
|
||||
)
|
||||
|
||||
__all__ = ["BaseCrossEncoder"]
|
||||
@@ -0,0 +1,18 @@
|
||||
from difflib import SequenceMatcher
|
||||
from typing import List, Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_community.cross_encoders.base import BaseCrossEncoder
|
||||
|
||||
|
||||
class FakeCrossEncoder(BaseCrossEncoder, BaseModel):
|
||||
"""Fake cross encoder model."""
|
||||
|
||||
def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
|
||||
scores = list(
|
||||
map(
|
||||
lambda pair: SequenceMatcher(None, pair[0], pair[1]).ratio(), text_pairs
|
||||
)
|
||||
)
|
||||
return scores
|
||||
@@ -0,0 +1,64 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from langchain_community.cross_encoders.base import BaseCrossEncoder
|
||||
|
||||
DEFAULT_MODEL_NAME = "BAAI/bge-reranker-base"
|
||||
|
||||
|
||||
class HuggingFaceCrossEncoder(BaseModel, BaseCrossEncoder):
|
||||
"""HuggingFace cross encoder models.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
||||
|
||||
model_name = "BAAI/bge-reranker-base"
|
||||
model_kwargs = {'device': 'cpu'}
|
||||
hf = HuggingFaceCrossEncoder(
|
||||
model_name=model_name,
|
||||
model_kwargs=model_kwargs
|
||||
)
|
||||
"""
|
||||
|
||||
client: Any = None #: :meta private:
|
||||
model_name: str = DEFAULT_MODEL_NAME
|
||||
"""Model name to use."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass to the model."""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize the sentence_transformer."""
|
||||
super().__init__(**kwargs)
|
||||
try:
|
||||
import sentence_transformers
|
||||
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Could not import sentence_transformers python package. "
|
||||
"Please install it with `pip install sentence-transformers`."
|
||||
) from exc
|
||||
|
||||
self.client = sentence_transformers.CrossEncoder(
|
||||
self.model_name, **self.model_kwargs
|
||||
)
|
||||
|
||||
model_config = ConfigDict(extra="forbid", protected_namespaces=())
|
||||
|
||||
def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
|
||||
"""Compute similarity scores using a HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
text_pairs: The list of text text_pairs to score the similarity.
|
||||
|
||||
Returns:
|
||||
List of scores, one for each pair.
|
||||
"""
|
||||
scores = self.client.predict(text_pairs)
|
||||
# Some models e.g bert-multilingual-passage-reranking-msmarco
|
||||
# gives two score not_relevant and relevant as compare with the query.
|
||||
if len(scores.shape) > 1: # we are going to get the relevant scores
|
||||
scores = map(lambda x: x[1], scores)
|
||||
return scores
|
||||
@@ -0,0 +1,150 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from langchain_community.cross_encoders.base import BaseCrossEncoder
|
||||
|
||||
|
||||
class CrossEncoderContentHandler:
|
||||
"""Content handler for CrossEncoder class."""
|
||||
|
||||
content_type = "application/json"
|
||||
accepts = "application/json"
|
||||
|
||||
def transform_input(self, text_pairs: List[Tuple[str, str]]) -> bytes:
|
||||
input_str = json.dumps({"text_pairs": text_pairs})
|
||||
return input_str.encode("utf-8")
|
||||
|
||||
def transform_output(self, output: Any) -> List[float]:
|
||||
response_json = json.loads(output.read().decode("utf-8"))
|
||||
scores = response_json["scores"]
|
||||
return scores
|
||||
|
||||
|
||||
class SagemakerEndpointCrossEncoder(BaseModel, BaseCrossEncoder):
|
||||
"""SageMaker Inference CrossEncoder endpoint.
|
||||
|
||||
To use, you must supply the endpoint name from your deployed
|
||||
Sagemaker model & the region where it is deployed.
|
||||
|
||||
To authenticate, the AWS client uses the following methods to
|
||||
automatically load credentials:
|
||||
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||
|
||||
If a specific credential profile should be used, you must pass
|
||||
the name of the profile from the ~/.aws/credentials file that is to be used.
|
||||
|
||||
Make sure the credentials / roles used have the required policies to
|
||||
access the Sagemaker endpoint.
|
||||
See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
|
||||
"""
|
||||
|
||||
"""
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
|
||||
from langchain_classic.embeddings import SagemakerEndpointCrossEncoder
|
||||
endpoint_name = (
|
||||
"my-endpoint-name"
|
||||
)
|
||||
region_name = (
|
||||
"us-west-2"
|
||||
)
|
||||
credentials_profile_name = (
|
||||
"default"
|
||||
)
|
||||
se = SagemakerEndpointCrossEncoder(
|
||||
endpoint_name=endpoint_name,
|
||||
region_name=region_name,
|
||||
credentials_profile_name=credentials_profile_name
|
||||
)
|
||||
"""
|
||||
client: Any = None #: :meta private:
|
||||
|
||||
endpoint_name: str = ""
|
||||
"""The name of the endpoint from the deployed Sagemaker model.
|
||||
Must be unique within an AWS Region."""
|
||||
|
||||
region_name: str = ""
|
||||
"""The aws region where the Sagemaker model is deployed, eg. `us-west-2`."""
|
||||
|
||||
credentials_profile_name: Optional[str] = None
|
||||
"""The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
|
||||
has either access keys or role information specified.
|
||||
If not specified, the default credential profile or, if on an EC2 instance,
|
||||
credentials from IMDS will be used.
|
||||
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||
"""
|
||||
|
||||
content_handler: CrossEncoderContentHandler = CrossEncoderContentHandler()
|
||||
|
||||
model_kwargs: Optional[Dict] = None
|
||||
"""Keyword arguments to pass to the model."""
|
||||
|
||||
endpoint_kwargs: Optional[Dict] = None
|
||||
"""Optional attributes passed to the invoke_endpoint
|
||||
function. See `boto3`_. docs for more info.
|
||||
.. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True, extra="forbid", protected_namespaces=()
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate that AWS credentials to and python package exists in environment."""
|
||||
try:
|
||||
import boto3
|
||||
|
||||
try:
|
||||
if values.get("credentials_profile_name"):
|
||||
session = boto3.Session(
|
||||
profile_name=values["credentials_profile_name"]
|
||||
)
|
||||
else:
|
||||
# use default credentials
|
||||
session = boto3.Session()
|
||||
|
||||
values["client"] = session.client(
|
||||
"sagemaker-runtime", region_name=values["region_name"]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Could not load credentials to authenticate with AWS client. "
|
||||
"Please check that credentials in the specified "
|
||||
"profile name are valid."
|
||||
) from e
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import boto3 python package. "
|
||||
"Please install it with `pip install boto3`."
|
||||
)
|
||||
return values
|
||||
|
||||
def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
|
||||
"""Call out to SageMaker Inference CrossEncoder endpoint."""
|
||||
_endpoint_kwargs = self.endpoint_kwargs or {}
|
||||
|
||||
body = self.content_handler.transform_input(text_pairs)
|
||||
content_type = self.content_handler.content_type
|
||||
accepts = self.content_handler.accepts
|
||||
|
||||
# send request
|
||||
try:
|
||||
response = self.client.invoke_endpoint(
|
||||
EndpointName=self.endpoint_name,
|
||||
Body=body,
|
||||
ContentType=content_type,
|
||||
Accept=accepts,
|
||||
**_endpoint_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||
|
||||
return self.content_handler.transform_output(response["Body"])
|
||||
Reference in New Issue
Block a user