initial commit
This commit is contained in:
917
venv/Lib/site-packages/langchain_community/llms/bedrock.py
Normal file
917
venv/Lib/site-packages/langchain_community/llms/bedrock.py
Normal file
@@ -0,0 +1,917 @@
|
||||
import asyncio
|
||||
import json
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.outputs import GenerationChunk
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
from langchain_community.utilities.anthropic import (
|
||||
get_num_tokens_anthropic,
|
||||
get_token_ids_anthropic,
|
||||
)
|
||||
|
||||
AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace"
|
||||
GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment"
|
||||
HUMAN_PROMPT = "\n\nHuman:"
|
||||
ASSISTANT_PROMPT = "\n\nAssistant:"
|
||||
ALTERNATION_ERROR = (
|
||||
"Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'."
|
||||
)
|
||||
|
||||
|
||||
def _add_newlines_before_ha(input_text: str) -> str:
|
||||
new_text = input_text
|
||||
for word in ["Human:", "Assistant:"]:
|
||||
new_text = new_text.replace(word, "\n\n" + word)
|
||||
for i in range(2):
|
||||
new_text = new_text.replace("\n\n\n" + word, "\n\n" + word)
|
||||
return new_text
|
||||
|
||||
|
||||
def _human_assistant_format(input_text: str) -> str:
|
||||
if input_text.count("Human:") == 0 or (
|
||||
input_text.find("Human:") > input_text.find("Assistant:")
|
||||
and "Assistant:" in input_text
|
||||
):
|
||||
input_text = HUMAN_PROMPT + " " + input_text # SILENT CORRECTION
|
||||
if input_text.count("Assistant:") == 0:
|
||||
input_text = input_text + ASSISTANT_PROMPT # SILENT CORRECTION
|
||||
if input_text[: len("Human:")] == "Human:":
|
||||
input_text = "\n\n" + input_text
|
||||
input_text = _add_newlines_before_ha(input_text)
|
||||
count = 0
|
||||
# track alternation
|
||||
for i in range(len(input_text)):
|
||||
if input_text[i : i + len(HUMAN_PROMPT)] == HUMAN_PROMPT:
|
||||
if count % 2 == 0:
|
||||
count += 1
|
||||
else:
|
||||
warnings.warn(ALTERNATION_ERROR + f" Received {input_text}")
|
||||
if input_text[i : i + len(ASSISTANT_PROMPT)] == ASSISTANT_PROMPT:
|
||||
if count % 2 == 1:
|
||||
count += 1
|
||||
else:
|
||||
warnings.warn(ALTERNATION_ERROR + f" Received {input_text}")
|
||||
|
||||
if count % 2 == 1: # Only saw Human, no Assistant
|
||||
input_text = input_text + ASSISTANT_PROMPT # SILENT CORRECTION
|
||||
|
||||
return input_text
|
||||
|
||||
|
||||
def _stream_response_to_generation_chunk(
|
||||
stream_response: Dict[str, Any],
|
||||
) -> GenerationChunk:
|
||||
"""Convert a stream response to a generation chunk."""
|
||||
if not stream_response["delta"]:
|
||||
return GenerationChunk(text="")
|
||||
return GenerationChunk(
|
||||
text=stream_response["delta"]["text"],
|
||||
generation_info=dict(
|
||||
finish_reason=stream_response.get("stop_reason", None),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class LLMInputOutputAdapter:
|
||||
"""Adapter class to prepare the inputs from Langchain to a format
|
||||
that LLM model expects.
|
||||
|
||||
It also provides helper function to extract
|
||||
the generated text from the model response."""
|
||||
|
||||
provider_to_output_key_map = {
|
||||
"anthropic": "completion",
|
||||
"amazon": "outputText",
|
||||
"cohere": "text",
|
||||
"meta": "generation",
|
||||
"mistral": "outputs",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def prepare_input(
|
||||
cls,
|
||||
provider: str,
|
||||
model_kwargs: Dict[str, Any],
|
||||
prompt: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
messages: Optional[List[Dict]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
input_body = {**model_kwargs}
|
||||
if provider == "anthropic":
|
||||
if messages:
|
||||
input_body["anthropic_version"] = "bedrock-2023-05-31"
|
||||
input_body["messages"] = messages
|
||||
if system:
|
||||
input_body["system"] = system
|
||||
if "max_tokens" not in input_body:
|
||||
input_body["max_tokens"] = 1024
|
||||
if prompt:
|
||||
input_body["prompt"] = _human_assistant_format(prompt)
|
||||
if "max_tokens_to_sample" not in input_body:
|
||||
input_body["max_tokens_to_sample"] = 1024
|
||||
elif provider in ("ai21", "cohere", "meta", "mistral"):
|
||||
input_body["prompt"] = prompt
|
||||
elif provider == "amazon":
|
||||
input_body = dict()
|
||||
input_body["inputText"] = prompt
|
||||
input_body["textGenerationConfig"] = {**model_kwargs}
|
||||
else:
|
||||
input_body["inputText"] = prompt
|
||||
|
||||
return input_body
|
||||
|
||||
@classmethod
|
||||
def prepare_output(cls, provider: str, response: Any) -> dict:
|
||||
text = ""
|
||||
if provider == "anthropic":
|
||||
response_body = json.loads(response.get("body").read().decode())
|
||||
if "completion" in response_body:
|
||||
text = response_body.get("completion")
|
||||
elif "content" in response_body:
|
||||
content = response_body.get("content")
|
||||
text = content[0].get("text")
|
||||
else:
|
||||
response_body = json.loads(response.get("body").read())
|
||||
|
||||
if provider == "ai21":
|
||||
text = response_body.get("completions")[0].get("data").get("text")
|
||||
elif provider == "cohere":
|
||||
text = response_body.get("generations")[0].get("text")
|
||||
elif provider == "meta":
|
||||
text = response_body.get("generation")
|
||||
elif provider == "mistral":
|
||||
text = response_body.get("outputs")[0].get("text")
|
||||
else:
|
||||
text = response_body.get("results")[0].get("outputText")
|
||||
|
||||
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
|
||||
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
|
||||
completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
|
||||
return {
|
||||
"text": text,
|
||||
"body": response_body,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def prepare_output_stream(
|
||||
cls,
|
||||
provider: str,
|
||||
response: Any,
|
||||
stop: Optional[List[str]] = None,
|
||||
messages_api: bool = False,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
stream = response.get("body")
|
||||
|
||||
if not stream:
|
||||
return
|
||||
|
||||
if messages_api:
|
||||
output_key = "message"
|
||||
else:
|
||||
output_key = cls.provider_to_output_key_map.get(provider, "")
|
||||
|
||||
if not output_key:
|
||||
raise ValueError(
|
||||
f"Unknown streaming response output key for provider: {provider}"
|
||||
)
|
||||
|
||||
for event in stream:
|
||||
chunk = event.get("chunk")
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
chunk_obj = json.loads(chunk.get("bytes").decode())
|
||||
|
||||
if provider == "cohere" and (
|
||||
chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
|
||||
):
|
||||
return
|
||||
|
||||
elif (
|
||||
provider == "mistral"
|
||||
and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
|
||||
):
|
||||
return
|
||||
|
||||
elif messages_api and (chunk_obj.get("type") == "content_block_stop"):
|
||||
return
|
||||
|
||||
if messages_api and chunk_obj.get("type") in (
|
||||
"message_start",
|
||||
"content_block_start",
|
||||
"content_block_delta",
|
||||
):
|
||||
if chunk_obj.get("type") == "content_block_delta":
|
||||
chk = _stream_response_to_generation_chunk(chunk_obj)
|
||||
yield chk
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
# chunk obj format varies with provider
|
||||
yield GenerationChunk(
|
||||
text=(
|
||||
chunk_obj[output_key]
|
||||
if provider != "mistral"
|
||||
else chunk_obj[output_key][0]["text"]
|
||||
),
|
||||
generation_info={
|
||||
GUARDRAILS_BODY_KEY: (
|
||||
chunk_obj.get(GUARDRAILS_BODY_KEY)
|
||||
if GUARDRAILS_BODY_KEY in chunk_obj
|
||||
else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def aprepare_output_stream(
|
||||
cls, provider: str, response: Any, stop: Optional[List[str]] = None
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
stream = response.get("body")
|
||||
|
||||
if not stream:
|
||||
return
|
||||
|
||||
output_key = cls.provider_to_output_key_map.get(provider, None)
|
||||
|
||||
if not output_key:
|
||||
raise ValueError(
|
||||
f"Unknown streaming response output key for provider: {provider}"
|
||||
)
|
||||
|
||||
for event in stream:
|
||||
chunk = event.get("chunk")
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
chunk_obj = json.loads(chunk.get("bytes").decode())
|
||||
|
||||
if provider == "cohere" and (
|
||||
chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
|
||||
):
|
||||
return
|
||||
|
||||
if (
|
||||
provider == "mistral"
|
||||
and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
|
||||
):
|
||||
return
|
||||
|
||||
yield GenerationChunk(
|
||||
text=(
|
||||
chunk_obj[output_key]
|
||||
if provider != "mistral"
|
||||
else chunk_obj[output_key][0]["text"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class BedrockBase(BaseModel, ABC):
|
||||
"""Base class for Bedrock models."""
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
client: Any = Field(exclude=True) #: :meta private:
|
||||
|
||||
region_name: Optional[str] = None
|
||||
"""The aws region e.g., `us-west-2`. Fallsback to AWS_DEFAULT_REGION env variable
|
||||
or region specified in ~/.aws/config in case it is not provided here.
|
||||
"""
|
||||
|
||||
credentials_profile_name: Optional[str] = Field(default=None, exclude=True)
|
||||
"""The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
|
||||
has either access keys or role information specified.
|
||||
If not specified, the default credential profile or, if on an EC2 instance,
|
||||
credentials from IMDS will be used.
|
||||
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||
"""
|
||||
|
||||
config: Any = None
|
||||
"""An optional botocore.config.Config instance to pass to the client."""
|
||||
|
||||
provider: Optional[str] = None
|
||||
"""The model provider, e.g., amazon, cohere, ai21, etc. When not supplied, provider
|
||||
is extracted from the first part of the model_id e.g. 'amazon' in
|
||||
'amazon.titan-text-express-v1'. This value should be provided for model ids that do
|
||||
not have the provider in them, e.g., custom and provisioned models that have an ARN
|
||||
associated with them."""
|
||||
|
||||
model_id: str
|
||||
"""Id of the model to call, e.g., amazon.titan-text-express-v1, this is
|
||||
equivalent to the modelId property in the list-foundation-models api. For custom and
|
||||
provisioned models, an ARN value is expected."""
|
||||
|
||||
model_kwargs: Optional[Dict] = None
|
||||
"""Keyword arguments to pass to the model."""
|
||||
|
||||
endpoint_url: Optional[str] = None
|
||||
"""Needed if you don't want to default to us-east-1 endpoint"""
|
||||
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results."""
|
||||
|
||||
provider_stop_sequence_key_name_map: Mapping[str, str] = {
|
||||
"anthropic": "stop_sequences",
|
||||
"amazon": "stopSequences",
|
||||
"ai21": "stop_sequences",
|
||||
"cohere": "stop_sequences",
|
||||
"mistral": "stop",
|
||||
}
|
||||
|
||||
guardrails: Optional[Mapping[str, Any]] = {
|
||||
"id": None,
|
||||
"version": None,
|
||||
"trace": False,
|
||||
}
|
||||
"""
|
||||
An optional dictionary to configure guardrails for Bedrock.
|
||||
|
||||
This field 'guardrails' consists of two keys: 'id' and 'version',
|
||||
which should be strings, but are initialized to None. It's used to
|
||||
determine if specific guardrails are enabled and properly set.
|
||||
|
||||
Type:
|
||||
Optional[Mapping[str, str]]: A mapping with 'id' and 'version' keys.
|
||||
|
||||
Example:
|
||||
llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
|
||||
model_kwargs={},
|
||||
guardrails={
|
||||
"id": "<guardrail_id>",
|
||||
"version": "<guardrail_version>"})
|
||||
|
||||
To enable tracing for guardrails, set the 'trace' key to True and pass a callback handler to the
|
||||
'run_manager' parameter of the 'generate', '_call' methods.
|
||||
|
||||
Example:
|
||||
llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
|
||||
model_kwargs={},
|
||||
guardrails={
|
||||
"id": "<guardrail_id>",
|
||||
"version": "<guardrail_version>",
|
||||
"trace": True},
|
||||
callbacks=[BedrockAsyncCallbackHandler()])
|
||||
|
||||
[https://python.langchain.com/docs/modules/callbacks/] for more information on callback handlers.
|
||||
|
||||
class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
|
||||
async def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
reason = kwargs.get("reason")
|
||||
if reason == "GUARDRAIL_INTERVENED":
|
||||
...Logic to handle guardrail intervention...
|
||||
""" # noqa: E501
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that AWS credentials to and python package exists in environment."""
|
||||
|
||||
# Skip creating new client if passed in constructor
|
||||
if values.get("client") is not None:
|
||||
return values
|
||||
|
||||
try:
|
||||
import boto3
|
||||
|
||||
if values["credentials_profile_name"] is not None:
|
||||
session = boto3.Session(profile_name=values["credentials_profile_name"])
|
||||
else:
|
||||
# use default credentials
|
||||
session = boto3.Session()
|
||||
|
||||
values["region_name"] = get_from_dict_or_env(
|
||||
values,
|
||||
"region_name",
|
||||
"AWS_DEFAULT_REGION",
|
||||
default=session.region_name,
|
||||
)
|
||||
|
||||
client_params = {}
|
||||
if values["region_name"]:
|
||||
client_params["region_name"] = values["region_name"]
|
||||
if values["endpoint_url"]:
|
||||
client_params["endpoint_url"] = values["endpoint_url"]
|
||||
if values["config"]:
|
||||
client_params["config"] = values["config"]
|
||||
|
||||
values["client"] = session.client("bedrock-runtime", **client_params)
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import boto3 python package. "
|
||||
"Please install it with `pip install boto3`."
|
||||
)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Error raised by bedrock service: {e}")
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Could not load credentials to authenticate with AWS client. "
|
||||
"Please check that credentials in the specified "
|
||||
f"profile name are valid. Bedrock error: {e}"
|
||||
) from e
|
||||
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
return {
|
||||
**{"model_kwargs": _model_kwargs},
|
||||
}
|
||||
|
||||
def _get_provider(self) -> str:
|
||||
if self.provider:
|
||||
return self.provider
|
||||
if self.model_id.startswith("arn"):
|
||||
raise ValueError(
|
||||
"Model provider should be supplied when passing a model ARN as model_id"
|
||||
)
|
||||
|
||||
return self.model_id.split(".")[0]
|
||||
|
||||
@property
|
||||
def _model_is_anthropic(self) -> bool:
|
||||
return self._get_provider() == "anthropic"
|
||||
|
||||
@property
|
||||
def _guardrails_enabled(self) -> bool:
|
||||
"""
|
||||
Determines if guardrails are enabled and correctly configured.
|
||||
Checks if 'guardrails' is a dictionary with non-empty 'id' and 'version' keys.
|
||||
Checks if 'guardrails.trace' is true.
|
||||
|
||||
Returns:
|
||||
bool: True if guardrails are correctly configured, False otherwise.
|
||||
Raises:
|
||||
TypeError: If 'guardrails' lacks 'id' or 'version' keys.
|
||||
"""
|
||||
try:
|
||||
return (
|
||||
isinstance(self.guardrails, dict)
|
||||
and bool(self.guardrails["id"])
|
||||
and bool(self.guardrails["version"])
|
||||
)
|
||||
|
||||
except KeyError as e:
|
||||
raise TypeError(
|
||||
"Guardrails must be a dictionary with 'id' and 'version' keys."
|
||||
) from e
|
||||
|
||||
def _get_guardrails_canonical(self) -> Dict[str, Any]:
|
||||
"""
|
||||
The canonical way to pass in guardrails to the bedrock service
|
||||
adheres to the following format:
|
||||
|
||||
"amazon-bedrock-guardrailDetails": {
|
||||
"guardrailId": "string",
|
||||
"guardrailVersion": "string"
|
||||
}
|
||||
"""
|
||||
return {
|
||||
"amazon-bedrock-guardrailDetails": {
|
||||
"guardrailId": self.guardrails.get("id"), # type: ignore[union-attr]
|
||||
"guardrailVersion": self.guardrails.get("version"), # type: ignore[union-attr]
|
||||
}
|
||||
}
|
||||
|
||||
def _prepare_input_and_invoke(
|
||||
self,
|
||||
prompt: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
messages: Optional[List[Dict]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
|
||||
provider = self._get_provider()
|
||||
params = {**_model_kwargs, **kwargs}
|
||||
if self._guardrails_enabled:
|
||||
params.update(self._get_guardrails_canonical())
|
||||
input_body = LLMInputOutputAdapter.prepare_input(
|
||||
provider=provider,
|
||||
model_kwargs=params,
|
||||
prompt=prompt,
|
||||
system=system,
|
||||
messages=messages,
|
||||
)
|
||||
body = json.dumps(input_body)
|
||||
accept = "application/json"
|
||||
contentType = "application/json"
|
||||
|
||||
request_options = {
|
||||
"body": body,
|
||||
"modelId": self.model_id,
|
||||
"accept": accept,
|
||||
"contentType": contentType,
|
||||
}
|
||||
|
||||
if self._guardrails_enabled:
|
||||
request_options["guardrail"] = "ENABLED"
|
||||
if self.guardrails.get("trace"): # type: ignore[union-attr]
|
||||
request_options["trace"] = "ENABLED"
|
||||
|
||||
try:
|
||||
response = self.client.invoke_model(**request_options)
|
||||
|
||||
text, body, usage_info = LLMInputOutputAdapter.prepare_output(
|
||||
provider, response
|
||||
).values()
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error raised by bedrock service: {e}")
|
||||
|
||||
if stop is not None:
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
# Verify and raise a callback error if any intervention occurs or a signal is
|
||||
# sent from a Bedrock service,
|
||||
# such as when guardrails are triggered.
|
||||
services_trace = self._get_bedrock_services_signal(body)
|
||||
|
||||
if services_trace.get("signal") and run_manager is not None:
|
||||
run_manager.on_llm_error(
|
||||
Exception(
|
||||
f"Error raised by bedrock service: {services_trace.get('reason')}"
|
||||
),
|
||||
**services_trace,
|
||||
)
|
||||
|
||||
return text, usage_info
|
||||
|
||||
def _get_bedrock_services_signal(self, body: dict) -> dict:
|
||||
"""
|
||||
This function checks the response body for an interrupt flag or message that indicates
|
||||
whether any of the Bedrock services have intervened in the processing flow. It is
|
||||
primarily used to identify modifications or interruptions imposed by these services
|
||||
during the request-response cycle with a Large Language Model (LLM).
|
||||
""" # noqa: E501
|
||||
|
||||
if (
|
||||
self._guardrails_enabled
|
||||
and self.guardrails.get("trace") # type: ignore[union-attr]
|
||||
and self._is_guardrails_intervention(body)
|
||||
):
|
||||
return {
|
||||
"signal": True,
|
||||
"reason": "GUARDRAIL_INTERVENED",
|
||||
"trace": body.get(AMAZON_BEDROCK_TRACE_KEY),
|
||||
}
|
||||
|
||||
return {
|
||||
"signal": False,
|
||||
"reason": None,
|
||||
"trace": None,
|
||||
}
|
||||
|
||||
def _is_guardrails_intervention(self, body: dict) -> bool:
|
||||
return body.get(GUARDRAILS_BODY_KEY) == "GUARDRAIL_INTERVENED"
|
||||
|
||||
def _prepare_input_and_invoke_stream(
|
||||
self,
|
||||
prompt: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
messages: Optional[List[Dict]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
provider = self._get_provider()
|
||||
|
||||
if stop:
|
||||
if provider not in self.provider_stop_sequence_key_name_map:
|
||||
raise ValueError(
|
||||
f"Stop sequence key name for {provider} is not supported."
|
||||
)
|
||||
|
||||
# stop sequence from _generate() overrides
|
||||
# stop sequences in the class attribute
|
||||
_model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
|
||||
|
||||
if provider == "cohere":
|
||||
_model_kwargs["stream"] = True
|
||||
|
||||
params = {**_model_kwargs, **kwargs}
|
||||
|
||||
if self._guardrails_enabled:
|
||||
params.update(self._get_guardrails_canonical())
|
||||
|
||||
input_body = LLMInputOutputAdapter.prepare_input(
|
||||
provider=provider,
|
||||
prompt=prompt,
|
||||
system=system,
|
||||
messages=messages,
|
||||
model_kwargs=params,
|
||||
)
|
||||
body = json.dumps(input_body)
|
||||
|
||||
request_options = {
|
||||
"body": body,
|
||||
"modelId": self.model_id,
|
||||
"accept": "application/json",
|
||||
"contentType": "application/json",
|
||||
}
|
||||
|
||||
if self._guardrails_enabled:
|
||||
request_options["guardrail"] = "ENABLED"
|
||||
if self.guardrails.get("trace"): # type: ignore[union-attr]
|
||||
request_options["trace"] = "ENABLED"
|
||||
|
||||
try:
|
||||
response = self.client.invoke_model_with_response_stream(**request_options)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error raised by bedrock service: {e}")
|
||||
|
||||
for chunk in LLMInputOutputAdapter.prepare_output_stream(
|
||||
provider, response, stop, True if messages else False
|
||||
):
|
||||
# verify and raise callback error if any middleware intervened
|
||||
self._get_bedrock_services_signal(chunk.generation_info) # type: ignore[arg-type]
|
||||
|
||||
if run_manager is not None:
|
||||
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
yield chunk
|
||||
|
||||
async def _aprepare_input_and_invoke_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
provider = self._get_provider()
|
||||
|
||||
if stop:
|
||||
if provider not in self.provider_stop_sequence_key_name_map:
|
||||
raise ValueError(
|
||||
f"Stop sequence key name for {provider} is not supported."
|
||||
)
|
||||
_model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
|
||||
|
||||
if provider == "cohere":
|
||||
_model_kwargs["stream"] = True
|
||||
|
||||
params = {**_model_kwargs, **kwargs}
|
||||
input_body = LLMInputOutputAdapter.prepare_input(
|
||||
provider=provider, prompt=prompt, model_kwargs=params
|
||||
)
|
||||
body = json.dumps(input_body)
|
||||
|
||||
response = await asyncio.get_running_loop().run_in_executor(
|
||||
None,
|
||||
lambda: self.client.invoke_model_with_response_stream(
|
||||
body=body,
|
||||
modelId=self.model_id,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
),
|
||||
)
|
||||
|
||||
async for chunk in LLMInputOutputAdapter.aprepare_output_stream(
|
||||
provider, response, stop
|
||||
):
|
||||
if run_manager is not None and asyncio.iscoroutinefunction(
|
||||
run_manager.on_llm_new_token
|
||||
):
|
||||
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
elif run_manager is not None:
|
||||
run_manager.on_llm_new_token(chunk.text, chunk=chunk) # type: ignore[unused-coroutine]
|
||||
yield chunk
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.0.34", removal="1.0", alternative_import="langchain_aws.BedrockLLM"
|
||||
)
|
||||
class Bedrock(LLM, BedrockBase):
|
||||
"""Bedrock models.
|
||||
|
||||
To authenticate, the AWS client uses the following methods to
|
||||
automatically load credentials:
|
||||
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||
|
||||
If a specific credential profile should be used, you must pass
|
||||
the name of the profile from the ~/.aws/credentials file that is to be used.
|
||||
|
||||
Make sure the credentials / roles used have the required policies to
|
||||
access the Bedrock service.
|
||||
"""
|
||||
|
||||
"""
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from bedrock_langchain.bedrock_llm import BedrockLLM
|
||||
|
||||
llm = BedrockLLM(
|
||||
credentials_profile_name="default",
|
||||
model_id="amazon.titan-text-express-v1",
|
||||
streaming=True
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
model_id = values["model_id"]
|
||||
if model_id.startswith("anthropic.claude-3"):
|
||||
raise ValueError(
|
||||
"Claude v3 models are not supported by this LLM."
|
||||
"Please use `from langchain_community.chat_models import BedrockChat` "
|
||||
"instead."
|
||||
)
|
||||
return super().validate_environment(values)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "amazon_bedrock"
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this model can be serialized by Langchain."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "llms", "bedrock"]
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
|
||||
if self.region_name:
|
||||
attributes["region_name"] = self.region_name
|
||||
|
||||
return attributes
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
"""Call out to Bedrock service with streaming.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to pass into the model
|
||||
stop (Optional[List[str]], optional): Stop sequences. These will
|
||||
override any stop sequences in the `model_kwargs` attribute.
|
||||
Defaults to None.
|
||||
run_manager (Optional[CallbackManagerForLLMRun], optional): Callback
|
||||
run managers used to process the output. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Iterator[GenerationChunk]: Generator that yields the streamed responses.
|
||||
|
||||
Yields:
|
||||
Iterator[GenerationChunk]: Responses from the model.
|
||||
"""
|
||||
return self._prepare_input_and_invoke_stream(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to Bedrock service model.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
response = llm.invoke("Tell me a joke.")
|
||||
"""
|
||||
|
||||
if self.streaming:
|
||||
completion = ""
|
||||
for chunk in self._stream(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
|
||||
text, _ = self._prepare_input_and_invoke(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return text
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncGenerator[GenerationChunk, None]:
|
||||
"""Call out to Bedrock service with streaming.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to pass into the model
|
||||
stop (Optional[List[str]], optional): Stop sequences. These will
|
||||
override any stop sequences in the `model_kwargs` attribute.
|
||||
Defaults to None.
|
||||
run_manager (Optional[CallbackManagerForLLMRun], optional): Callback
|
||||
run managers used to process the output. Defaults to None.
|
||||
|
||||
Yields:
|
||||
AsyncGenerator[GenerationChunk, None]: Generator that asynchronously yields
|
||||
the streamed responses.
|
||||
"""
|
||||
async for chunk in self._aprepare_input_and_invoke_stream(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to Bedrock service model.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
response = await llm._acall("Tell me a joke.")
|
||||
"""
|
||||
|
||||
if not self.streaming:
|
||||
raise ValueError("Streaming must be set to True for async operations. ")
|
||||
|
||||
chunks = [
|
||||
chunk.text
|
||||
async for chunk in self._astream(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
]
|
||||
return "".join(chunks)
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
if self._model_is_anthropic:
|
||||
return get_num_tokens_anthropic(text)
|
||||
else:
|
||||
return super().get_num_tokens(text)
|
||||
|
||||
def get_token_ids(self, text: str) -> List[int]:
|
||||
if self._model_is_anthropic:
|
||||
return get_token_ids_anthropic(text)
|
||||
else:
|
||||
return super().get_token_ids(text)
|
||||
Reference in New Issue
Block a user