initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,335 @@
"""**Chat Models** are a variation on language models.
While Chat Models use language models under the hood, the interface they expose
is a bit different. Rather than expose a "text in, text out" API, they expose
an interface where "chat messages" are the inputs and outputs.
**Class hierarchy:**
.. code-block::
BaseLanguageModel --> BaseChatModel --> <name> # Examples: ChatOpenAI, ChatGooglePalm
**Main helpers:**
.. code-block::
AIMessage, BaseMessage, HumanMessage
""" # noqa: E501
import importlib
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from langchain_community.chat_models.anthropic import (
ChatAnthropic,
)
from langchain_community.chat_models.anyscale import (
ChatAnyscale,
)
from langchain_community.chat_models.azure_openai import (
AzureChatOpenAI,
)
from langchain_community.chat_models.baichuan import (
ChatBaichuan,
)
from langchain_community.chat_models.baidu_qianfan_endpoint import (
QianfanChatEndpoint,
)
from langchain_community.chat_models.bedrock import (
BedrockChat,
)
from langchain_community.chat_models.cohere import (
ChatCohere,
)
from langchain_community.chat_models.coze import (
ChatCoze,
)
from langchain_community.chat_models.databricks import (
ChatDatabricks,
)
from langchain_community.chat_models.deepinfra import (
ChatDeepInfra,
)
from langchain_community.chat_models.edenai import ChatEdenAI
from langchain_community.chat_models.ernie import (
ErnieBotChat,
)
from langchain_community.chat_models.everlyai import (
ChatEverlyAI,
)
from langchain_community.chat_models.fake import (
FakeListChatModel,
)
from langchain_community.chat_models.fireworks import (
ChatFireworks,
)
from langchain_community.chat_models.friendli import (
ChatFriendli,
)
from langchain_community.chat_models.gigachat import (
GigaChat,
)
from langchain_community.chat_models.google_palm import (
ChatGooglePalm,
)
from langchain_community.chat_models.gpt_router import (
GPTRouter,
)
from langchain_community.chat_models.huggingface import (
ChatHuggingFace,
)
from langchain_community.chat_models.human import (
HumanInputChatModel,
)
from langchain_community.chat_models.hunyuan import (
ChatHunyuan,
)
from langchain_community.chat_models.javelin_ai_gateway import (
ChatJavelinAIGateway,
)
from langchain_community.chat_models.jinachat import (
JinaChat,
)
from langchain_community.chat_models.kinetica import (
ChatKinetica,
)
from langchain_community.chat_models.konko import (
ChatKonko,
)
from langchain_community.chat_models.litellm import (
ChatLiteLLM,
)
from langchain_community.chat_models.litellm_router import (
ChatLiteLLMRouter,
)
from langchain_community.chat_models.llama_edge import (
LlamaEdgeChatService,
)
from langchain_community.chat_models.llamacpp import ChatLlamaCpp
from langchain_community.chat_models.maritalk import (
ChatMaritalk,
)
from langchain_community.chat_models.minimax import (
MiniMaxChat,
)
from langchain_community.chat_models.mlflow import (
ChatMlflow,
)
from langchain_community.chat_models.mlflow_ai_gateway import (
ChatMLflowAIGateway,
)
from langchain_community.chat_models.mlx import (
ChatMLX,
)
from langchain_community.chat_models.moonshot import (
MoonshotChat,
)
from langchain_community.chat_models.naver import (
ChatClovaX,
)
from langchain_community.chat_models.oci_data_science import (
ChatOCIModelDeployment,
ChatOCIModelDeploymentTGI,
ChatOCIModelDeploymentVLLM,
)
from langchain_community.chat_models.oci_generative_ai import (
ChatOCIGenAI, # noqa: F401
)
from langchain_community.chat_models.octoai import ChatOctoAI
from langchain_community.chat_models.ollama import (
ChatOllama,
)
from langchain_community.chat_models.openai import (
ChatOpenAI,
)
from langchain_community.chat_models.outlines import ChatOutlines
from langchain_community.chat_models.pai_eas_endpoint import (
PaiEasChatEndpoint,
)
from langchain_community.chat_models.perplexity import (
ChatPerplexity,
)
from langchain_community.chat_models.premai import (
ChatPremAI,
)
from langchain_community.chat_models.promptlayer_openai import (
PromptLayerChatOpenAI,
)
from langchain_community.chat_models.reka import (
ChatReka,
)
from langchain_community.chat_models.sambanova import (
ChatSambaNovaCloud,
ChatSambaStudio,
)
from langchain_community.chat_models.snowflake import (
ChatSnowflakeCortex,
)
from langchain_community.chat_models.solar import (
SolarChat,
)
from langchain_community.chat_models.sparkllm import (
ChatSparkLLM,
)
from langchain_community.chat_models.symblai_nebula import ChatNebula
from langchain_community.chat_models.tongyi import (
ChatTongyi,
)
from langchain_community.chat_models.vertexai import (
ChatVertexAI,
)
from langchain_community.chat_models.volcengine_maas import (
VolcEngineMaasChat,
)
from langchain_community.chat_models.yandex import (
ChatYandexGPT,
)
from langchain_community.chat_models.yi import (
ChatYi,
)
from langchain_community.chat_models.yuan2 import (
ChatYuan2,
)
from langchain_community.chat_models.zhipuai import (
ChatZhipuAI,
)
__all__ = [
"AzureChatOpenAI",
"BedrockChat",
"ChatAnthropic",
"ChatAnyscale",
"ChatBaichuan",
"ChatClovaX",
"ChatCohere",
"ChatCoze",
"ChatOctoAI",
"ChatDatabricks",
"ChatDeepInfra",
"ChatEdenAI",
"ChatEverlyAI",
"ChatFireworks",
"ChatFriendli",
"ChatGooglePalm",
"ChatHuggingFace",
"ChatHunyuan",
"ChatJavelinAIGateway",
"ChatKinetica",
"ChatKonko",
"ChatLiteLLM",
"ChatLiteLLMRouter",
"ChatMLX",
"ChatMLflowAIGateway",
"ChatMaritalk",
"ChatMlflow",
"ChatNebula",
"ChatOCIGenAI",
"ChatOCIModelDeployment",
"ChatOCIModelDeploymentVLLM",
"ChatOCIModelDeploymentTGI",
"ChatOllama",
"ChatOpenAI",
"ChatOutlines",
"ChatPerplexity",
"ChatReka",
"ChatPremAI",
"ChatSambaNovaCloud",
"ChatSambaStudio",
"ChatSparkLLM",
"ChatSnowflakeCortex",
"ChatTongyi",
"ChatVertexAI",
"ChatYandexGPT",
"ChatYuan2",
"ChatZhipuAI",
"ChatLlamaCpp",
"ErnieBotChat",
"FakeListChatModel",
"GPTRouter",
"GigaChat",
"HumanInputChatModel",
"JinaChat",
"LlamaEdgeChatService",
"MiniMaxChat",
"MoonshotChat",
"PaiEasChatEndpoint",
"PromptLayerChatOpenAI",
"QianfanChatEndpoint",
"SolarChat",
"VolcEngineMaasChat",
"ChatYi",
]
_module_lookup = {
"AzureChatOpenAI": "langchain_community.chat_models.azure_openai",
"BedrockChat": "langchain_community.chat_models.bedrock",
"ChatAnthropic": "langchain_community.chat_models.anthropic",
"ChatAnyscale": "langchain_community.chat_models.anyscale",
"ChatBaichuan": "langchain_community.chat_models.baichuan",
"ChatClovaX": "langchain_community.chat_models.naver",
"ChatCohere": "langchain_community.chat_models.cohere",
"ChatCoze": "langchain_community.chat_models.coze",
"ChatDatabricks": "langchain_community.chat_models.databricks",
"ChatDeepInfra": "langchain_community.chat_models.deepinfra",
"ChatEverlyAI": "langchain_community.chat_models.everlyai",
"ChatEdenAI": "langchain_community.chat_models.edenai",
"ChatFireworks": "langchain_community.chat_models.fireworks",
"ChatFriendli": "langchain_community.chat_models.friendli",
"ChatGooglePalm": "langchain_community.chat_models.google_palm",
"ChatHuggingFace": "langchain_community.chat_models.huggingface",
"ChatHunyuan": "langchain_community.chat_models.hunyuan",
"ChatJavelinAIGateway": "langchain_community.chat_models.javelin_ai_gateway",
"ChatKinetica": "langchain_community.chat_models.kinetica",
"ChatKonko": "langchain_community.chat_models.konko",
"ChatLiteLLM": "langchain_community.chat_models.litellm",
"ChatLiteLLMRouter": "langchain_community.chat_models.litellm_router",
"ChatMLflowAIGateway": "langchain_community.chat_models.mlflow_ai_gateway",
"ChatMLX": "langchain_community.chat_models.mlx",
"ChatMaritalk": "langchain_community.chat_models.maritalk",
"ChatMlflow": "langchain_community.chat_models.mlflow",
"ChatNebula": "langchain_community.chat_models.symblai_nebula",
"ChatOctoAI": "langchain_community.chat_models.octoai",
"ChatOCIGenAI": "langchain_community.chat_models.oci_generative_ai",
"ChatOCIModelDeployment": "langchain_community.chat_models.oci_data_science",
"ChatOCIModelDeploymentVLLM": "langchain_community.chat_models.oci_data_science",
"ChatOCIModelDeploymentTGI": "langchain_community.chat_models.oci_data_science",
"ChatOllama": "langchain_community.chat_models.ollama",
"ChatOpenAI": "langchain_community.chat_models.openai",
"ChatOutlines": "langchain_community.chat_models.outlines",
"ChatReka": "langchain_community.chat_models.reka",
"ChatPerplexity": "langchain_community.chat_models.perplexity",
"ChatSambaNovaCloud": "langchain_community.chat_models.sambanova",
"ChatSambaStudio": "langchain_community.chat_models.sambanova",
"ChatSnowflakeCortex": "langchain_community.chat_models.snowflake",
"ChatSparkLLM": "langchain_community.chat_models.sparkllm",
"ChatTongyi": "langchain_community.chat_models.tongyi",
"ChatVertexAI": "langchain_community.chat_models.vertexai",
"ChatYandexGPT": "langchain_community.chat_models.yandex",
"ChatYuan2": "langchain_community.chat_models.yuan2",
"ChatZhipuAI": "langchain_community.chat_models.zhipuai",
"ErnieBotChat": "langchain_community.chat_models.ernie",
"FakeListChatModel": "langchain_community.chat_models.fake",
"GPTRouter": "langchain_community.chat_models.gpt_router",
"GigaChat": "langchain_community.chat_models.gigachat",
"HumanInputChatModel": "langchain_community.chat_models.human",
"JinaChat": "langchain_community.chat_models.jinachat",
"LlamaEdgeChatService": "langchain_community.chat_models.llama_edge",
"MiniMaxChat": "langchain_community.chat_models.minimax",
"MoonshotChat": "langchain_community.chat_models.moonshot",
"PaiEasChatEndpoint": "langchain_community.chat_models.pai_eas_endpoint",
"PromptLayerChatOpenAI": "langchain_community.chat_models.promptlayer_openai",
"SolarChat": "langchain_community.chat_models.solar",
"QianfanChatEndpoint": "langchain_community.chat_models.baidu_qianfan_endpoint",
"VolcEngineMaasChat": "langchain_community.chat_models.volcengine_maas",
"ChatPremAI": "langchain_community.chat_models.premai",
"ChatLlamaCpp": "langchain_community.chat_models.llamacpp",
"ChatYi": "langchain_community.chat_models.yi",
}
def __getattr__(name: str) -> Any:
if name in _module_lookup:
module = importlib.import_module(_module_lookup[name])
return getattr(module, name)
raise AttributeError(f"module {__name__} has no attribute {name}")

View File

@@ -0,0 +1,234 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.prompt_values import PromptValue
from pydantic import ConfigDict
from langchain_community.llms.anthropic import _AnthropicCommon
def _convert_one_message_to_text(
message: BaseMessage,
human_prompt: str,
ai_prompt: str,
) -> str:
content = cast(str, message.content)
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {content}"
elif isinstance(message, HumanMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, AIMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemMessage):
message_text = content
else:
raise ValueError(f"Got unknown type {message}")
return message_text
def convert_messages_to_prompt_anthropic(
messages: List[BaseMessage],
*,
human_prompt: str = "\n\nHuman:",
ai_prompt: str = "\n\nAssistant:",
) -> str:
"""Format a list of messages into a full prompt for the Anthropic model
Args:
messages (List[BaseMessage]): List of BaseMessage to combine.
human_prompt (str, optional): Human prompt tag. Defaults to "\n\nHuman:".
ai_prompt (str, optional): AI prompt tag. Defaults to "\n\nAssistant:".
Returns:
str: Combined string with necessary human_prompt and ai_prompt tags.
"""
messages = messages.copy() # don't mutate the original list
if not isinstance(messages[-1], AIMessage):
messages.append(AIMessage(content=""))
text = "".join(
_convert_one_message_to_text(message, human_prompt, ai_prompt)
for message in messages
)
# trim off the trailing ' ' that might come from the "Assistant: "
return text.rstrip()
@deprecated(
since="0.0.28",
removal="1.0",
alternative_import="langchain_anthropic.ChatAnthropic",
)
class ChatAnthropic(BaseChatModel, _AnthropicCommon):
"""`Anthropic` chat large language models.
To use, you should have the ``anthropic`` python package installed, and the
environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
it as a named parameter to the constructor.
Example:
.. code-block:: python
import anthropic
from langchain_community.chat_models import ChatAnthropic
model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
"""
model_config = ConfigDict(
populate_by_name=True,
arbitrary_types_allowed=True,
)
@property
def lc_secrets(self) -> Dict[str, str]:
return {"anthropic_api_key": "ANTHROPIC_API_KEY"}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "anthropic-chat"
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "anthropic"]
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
"""Format a list of messages into a full prompt for the Anthropic model
Args:
messages (List[BaseMessage]): List of BaseMessage to combine.
Returns:
str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
"""
prompt_params = {}
if self.HUMAN_PROMPT:
prompt_params["human_prompt"] = self.HUMAN_PROMPT
if self.AI_PROMPT:
prompt_params["ai_prompt"] = self.AI_PROMPT
return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)
def convert_prompt(self, prompt: PromptValue) -> str:
return self._convert_messages_to_prompt(prompt.to_messages())
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
if stop:
params["stop_sequences"] = stop
stream_resp = self.client.completions.create(**params, stream=True)
for data in stream_resp:
delta = data.completion
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
if stop:
params["stop_sequences"] = stop
stream_resp = await self.async_client.completions.create(**params, stream=True)
async for data in stream_resp:
delta = data.completion
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
await run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
prompt = self._convert_messages_to_prompt(
messages,
)
params: Dict[str, Any] = {
"prompt": prompt,
**self._default_params,
**kwargs,
}
if stop:
params["stop_sequences"] = stop
response = self.client.completions.create(**params)
completion = response.completion
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
prompt = self._convert_messages_to_prompt(
messages,
)
params: Dict[str, Any] = {
"prompt": prompt,
**self._default_params,
**kwargs,
}
if stop:
params["stop_sequences"] = stop
response = await self.async_client.completions.create(**params)
completion = response.completion
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])
def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""
if not self.count_tokens:
raise NameError("Please ensure the anthropic package is loaded")
return self.count_tokens(text)

View File

@@ -0,0 +1,243 @@
"""Anyscale Endpoints chat wrapper. Relies heavily on ChatOpenAI."""
from __future__ import annotations
import logging
import os
import sys
import warnings
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Optional,
Sequence,
Set,
Type,
Union,
)
import requests
from langchain_core.messages import BaseMessage
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from pydantic import Field, SecretStr, model_validator
from langchain_community.adapters.openai import convert_message_to_dict
from langchain_community.chat_models.openai import (
ChatOpenAI,
_import_tiktoken,
)
from langchain_community.utils.openai import is_openai_v1
if TYPE_CHECKING:
import tiktoken
logger = logging.getLogger(__name__)
DEFAULT_API_BASE = "https://api.endpoints.anyscale.com/v1"
DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
class ChatAnyscale(ChatOpenAI):
"""`Anyscale` Chat large language models.
See https://www.anyscale.com/ for information about Anyscale.
To use, you should have the ``openai`` python package installed, and the
environment variable ``ANYSCALE_API_KEY`` set with your API key.
Alternatively, you can use the anyscale_api_key keyword argument.
Any parameters that are valid to be passed to the `openai.create` call can be passed
in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatAnyscale
chat = ChatAnyscale(model_name="meta-llama/Llama-2-7b-chat-hf")
"""
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "anyscale-chat"
@property
def lc_secrets(self) -> Dict[str, str]:
return {"anyscale_api_key": "ANYSCALE_API_KEY"}
@classmethod
def is_lc_serializable(cls) -> bool:
return False
anyscale_api_key: SecretStr = Field(default=SecretStr(""))
"""AnyScale Endpoints API keys."""
model_name: str = Field(default=DEFAULT_MODEL, alias="model")
"""Model name to use."""
anyscale_api_base: str = Field(default=DEFAULT_API_BASE)
"""Base URL path for API requests,
leave blank if not using a proxy or service emulator."""
anyscale_proxy: Optional[str] = None
"""To support explicit proxy for Anyscale."""
available_models: Optional[Set[str]] = None
"""Available models from Anyscale API."""
@staticmethod
def get_available_models(
anyscale_api_key: Optional[str] = None,
anyscale_api_base: str = DEFAULT_API_BASE,
) -> Set[str]:
"""Get available models from Anyscale API."""
try:
anyscale_api_key = anyscale_api_key or os.environ["ANYSCALE_API_KEY"]
except KeyError as e:
raise ValueError(
"Anyscale API key must be passed as keyword argument or "
"set in environment variable ANYSCALE_API_KEY.",
) from e
models_url = f"{anyscale_api_base}/models"
models_response = requests.get(
models_url,
headers={
"Authorization": f"Bearer {anyscale_api_key}",
},
)
if models_response.status_code != 200:
raise ValueError(
f"Error getting models from {models_url}: "
f"{models_response.status_code}",
)
return {model["id"] for model in models_response.json()["data"]}
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: dict) -> Any:
"""Validate that api key and python package exists in environment."""
values["anyscale_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"anyscale_api_key",
"ANYSCALE_API_KEY",
)
)
values["anyscale_api_base"] = get_from_dict_or_env(
values,
"anyscale_api_base",
"ANYSCALE_API_BASE",
default=DEFAULT_API_BASE,
)
values["openai_proxy"] = get_from_dict_or_env(
values,
"anyscale_proxy",
"ANYSCALE_PROXY",
default="",
)
try:
import openai
except ImportError as e:
raise ImportError(
"Could not import openai python package. "
"Please install it with `pip install openai`.",
) from e
try:
if is_openai_v1():
client_params = {
"api_key": values["anyscale_api_key"].get_secret_value(),
"base_url": values["anyscale_api_base"],
# To do: future support
# "organization": values["openai_organization"],
# "timeout": values["request_timeout"],
# "max_retries": values["max_retries"],
# "default_headers": values["default_headers"],
# "default_query": values["default_query"],
# "http_client": values["http_client"],
}
if not values.get("client"):
values["client"] = openai.OpenAI(**client_params).chat.completions
if not values.get("async_client"):
values["async_client"] = openai.AsyncOpenAI(
**client_params
).chat.completions
else:
values["openai_api_base"] = values["anyscale_api_base"]
values["openai_api_key"] = values["anyscale_api_key"].get_secret_value()
values["client"] = openai.ChatCompletion
except AttributeError as exc:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`.",
) from exc
if "model_name" not in values.keys():
values["model_name"] = DEFAULT_MODEL
model_name = values["model_name"]
available_models = cls.get_available_models(
values["anyscale_api_key"].get_secret_value(),
values["anyscale_api_base"],
)
if model_name not in available_models:
raise ValueError(
f"Model name {model_name} not found in available models: "
f"{available_models}.",
)
values["available_models"] = available_models
return values
def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
tiktoken_ = _import_tiktoken()
if self.tiktoken_model_name is not None:
model = self.tiktoken_model_name
else:
model = self.model_name
# Returns the number of tokens used by a list of messages.
try:
encoding = tiktoken_.encoding_for_model("gpt-3.5-turbo-0301")
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken_.get_encoding(model)
return model, encoding
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[
Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
] = None,
) -> int:
"""Calculate num tokens with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
"""
if tools is not None:
warnings.warn(
"Counting tokens in tool schemas is not yet supported. Ignoring tools."
)
if sys.version_info[1] <= 7:
return super().get_num_tokens_from_messages(messages)
model, encoding = self._get_encoding_model()
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [convert_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
num_tokens += len(encoding.encode(str(value)))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
return num_tokens

View File

@@ -0,0 +1,293 @@
"""Azure OpenAI chat wrapper."""
from __future__ import annotations
import logging
import os
import warnings
from typing import Any, Awaitable, Callable, Dict, List, Union
from langchain_core._api.deprecation import deprecated
from langchain_core.outputs import ChatResult
from langchain_core.utils import get_from_dict_or_env, pre_init
from pydantic import BaseModel, Field
from langchain_community.chat_models.openai import ChatOpenAI
from langchain_community.utils.openai import is_openai_v1
logger = logging.getLogger(__name__)
@deprecated(
since="0.0.10",
removal="1.0",
alternative_import="langchain_openai.AzureChatOpenAI",
)
class AzureChatOpenAI(ChatOpenAI):
"""`Azure OpenAI` Chat Completion API.
To use this class you
must have a deployed model on Azure OpenAI. Use `deployment_name` in the
constructor to refer to the "Model deployment name" in the Azure portal.
In addition, you should have the ``openai`` python package installed, and the
following environment variables set or passed in constructor in lower case:
- ``AZURE_OPENAI_API_KEY``
- ``AZURE_OPENAI_ENDPOINT``
- ``AZURE_OPENAI_AD_TOKEN``
- ``OPENAI_API_VERSION``
- ``OPENAI_PROXY``
For example, if you have `gpt-35-turbo` deployed, with the deployment name
`35-turbo-dev`, the constructor should look like:
.. code-block:: python
AzureChatOpenAI(
azure_deployment="35-turbo-dev",
openai_api_version="2023-05-15",
)
Be aware the API version may change.
You can also specify the version of the model using ``model_version`` constructor
parameter, as Azure OpenAI doesn't return model version with the response.
Default is empty. When you specify the version, it will be appended to the
model name in the response. Setting correct version will help you to calculate the
cost properly. Model version is not validated, so make sure you set it correctly
to get the correct cost.
Any parameters that are valid to be passed to the openai.create call can be passed
in, even if not explicitly saved on this class.
"""
azure_endpoint: Union[str, None] = None
"""Your Azure endpoint, including the resource.
Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided.
Example: `https://example-resource.azure.openai.com/`
"""
deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment")
"""A model deployment.
If given sets the base client URL to include `/deployments/{azure_deployment}`.
Note: this means you won't be able to use non-deployment endpoints.
"""
openai_api_version: str = Field(default="", alias="api_version")
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
openai_api_key: Union[str, None] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
azure_ad_token: Union[str, None] = None
"""Your Azure Active Directory token.
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
For more:
https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
"""
azure_ad_token_provider: Union[Callable[[], str], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every sync request. For async requests,
will be invoked if `azure_ad_async_token_provider` is not provided.
"""
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every async request.
"""
model_version: str = ""
"""Legacy, for openai<1.0.0 support."""
openai_api_type: str = ""
"""Legacy, for openai<1.0.0 support."""
validate_base_url: bool = True
"""For backwards compatibility. If legacy val openai_api_base is passed in, try to
infer if it is a base_url or azure_endpoint and update accordingly.
"""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "azure_openai"]
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
# Check OPENAI_KEY for backwards compatibility.
# TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
# other forms of azure credentials.
values["openai_api_key"] = (
values["openai_api_key"]
or os.getenv("AZURE_OPENAI_API_KEY")
or os.getenv("OPENAI_API_KEY")
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
"OPENAI_API_BASE"
)
values["openai_api_version"] = values["openai_api_version"] or os.getenv(
"OPENAI_API_VERSION"
)
# Check OPENAI_ORGANIZATION for backwards compatibility.
values["openai_organization"] = (
values["openai_organization"]
or os.getenv("OPENAI_ORG_ID")
or os.getenv("OPENAI_ORGANIZATION")
)
values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
"AZURE_OPENAI_ENDPOINT"
)
values["azure_ad_token"] = values["azure_ad_token"] or os.getenv(
"AZURE_OPENAI_AD_TOKEN"
)
values["openai_api_type"] = get_from_dict_or_env(
values, "openai_api_type", "OPENAI_API_TYPE", default="azure"
)
values["openai_proxy"] = get_from_dict_or_env(
values, "openai_proxy", "OPENAI_PROXY", default=""
)
try:
import openai
except ImportError:
raise ImportError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
if is_openai_v1():
# For backwards compatibility. Before openai v1, no distinction was made
# between azure_endpoint and base_url (openai_api_base).
openai_api_base = values["openai_api_base"]
if openai_api_base and values["validate_base_url"]:
if "/openai" not in openai_api_base:
values["openai_api_base"] = (
values["openai_api_base"].rstrip("/") + "/openai"
)
warnings.warn(
"As of openai>=1.0.0, Azure endpoints should be specified via "
f"the `azure_endpoint` param not `openai_api_base` "
f"(or alias `base_url`). Updating `openai_api_base` from "
f"{openai_api_base} to {values['openai_api_base']}."
)
if values["deployment_name"]:
warnings.warn(
"As of openai>=1.0.0, if `deployment_name` (or alias "
"`azure_deployment`) is specified then "
"`openai_api_base` (or alias `base_url`) should not be. "
"Instead use `deployment_name` (or alias `azure_deployment`) "
"and `azure_endpoint`."
)
if values["deployment_name"] not in values["openai_api_base"]:
warnings.warn(
"As of openai>=1.0.0, if `openai_api_base` "
"(or alias `base_url`) is specified it is expected to be "
"of the form "
"https://example-resource.azure.openai.com/openai/deployments/example-deployment. " # noqa: E501
f"Updating {openai_api_base} to "
f"{values['openai_api_base']}."
)
values["openai_api_base"] += (
"/deployments/" + values["deployment_name"]
)
values["deployment_name"] = None
client_params = {
"api_version": values["openai_api_version"],
"azure_endpoint": values["azure_endpoint"],
"azure_deployment": values["deployment_name"],
"api_key": values["openai_api_key"],
"azure_ad_token": values["azure_ad_token"],
"azure_ad_token_provider": values["azure_ad_token_provider"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": {
**(values["default_headers"] or {}),
"User-Agent": "langchain-comm-python-azure-openai",
},
"default_query": values["default_query"],
"http_client": values["http_client"],
}
values["client"] = openai.AzureOpenAI(**client_params).chat.completions
azure_ad_async_token_provider = values["azure_ad_async_token_provider"]
if azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = azure_ad_async_token_provider
values["async_client"] = openai.AsyncAzureOpenAI(
**client_params
).chat.completions
else:
values["client"] = openai.ChatCompletion
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
if is_openai_v1():
return super()._default_params
else:
return {
**super()._default_params,
"engine": self.deployment_name,
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {**self._default_params}
@property
def _client_params(self) -> Dict[str, Any]:
"""Get the config params used for the openai client."""
if is_openai_v1():
return super()._client_params
else:
return {
**super()._client_params,
"api_type": self.openai_api_type,
"api_version": self.openai_api_version,
}
@property
def _llm_type(self) -> str:
return "azure-openai-chat"
@property
def lc_attributes(self) -> Dict[str, Any]:
return {
"openai_api_type": self.openai_api_type,
"openai_api_version": self.openai_api_version,
}
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
if not isinstance(response, dict):
response = response.dict()
for res in response["choices"]:
if res.get("finish_reason", None) == "content_filter":
raise ValueError(
"Azure has not provided the response due to a content filter "
"being triggered"
)
chat_result = super()._create_chat_result(response)
if "model" in response:
model = response["model"]
if self.model_version:
model = f"{model}-{self.model_version}"
if chat_result.llm_output is not None and isinstance(
chat_result.llm_output, dict
):
chat_result.llm_output["model_name"] = model
return chat_result

View File

@@ -0,0 +1,426 @@
import json
import warnings
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Mapping,
Optional,
Type,
cast,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_community.llms.azureml_endpoint import (
AzureMLBaseEndpoint,
AzureMLEndpointApiType,
ContentFormatterBase,
)
class LlamaContentFormatter(ContentFormatterBase):
"""Content formatter for `LLaMA`."""
def __init__(self) -> None:
raise TypeError(
"`LlamaContentFormatter` is deprecated for chat models. Use "
"`CustomOpenAIContentFormatter` instead."
)
class CustomOpenAIChatContentFormatter(ContentFormatterBase):
"""Chat Content formatter for models with OpenAI like API scheme."""
SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]
@staticmethod
def _convert_message_to_dict(message: BaseMessage) -> Dict:
"""Converts a message to a dict according to a role"""
content = cast(str, message.content)
if isinstance(message, HumanMessage):
return {
"role": "user",
"content": ContentFormatterBase.escape_special_characters(content),
}
elif isinstance(message, AIMessage):
return {
"role": "assistant",
"content": ContentFormatterBase.escape_special_characters(content),
}
elif isinstance(message, SystemMessage):
return {
"role": "system",
"content": ContentFormatterBase.escape_special_characters(content),
}
elif (
isinstance(message, ChatMessage)
and message.role in CustomOpenAIChatContentFormatter.SUPPORTED_ROLES
):
return {
"role": message.role,
"content": ContentFormatterBase.escape_special_characters(content),
}
else:
supported = ",".join(
[role for role in CustomOpenAIChatContentFormatter.SUPPORTED_ROLES]
)
raise ValueError(
f"""Received unsupported role.
Supported roles for the LLaMa Foundation Model: {supported}"""
)
@property
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]
def format_messages_request_payload(
self,
messages: List[BaseMessage],
model_kwargs: Dict,
api_type: AzureMLEndpointApiType,
) -> bytes:
"""Formats the request according to the chosen api"""
chat_messages = [
CustomOpenAIChatContentFormatter._convert_message_to_dict(message)
for message in messages
]
if api_type in [
AzureMLEndpointApiType.dedicated,
AzureMLEndpointApiType.realtime,
]:
request_payload = json.dumps(
{
"input_data": {
"input_string": chat_messages,
"parameters": model_kwargs,
}
}
)
elif api_type == AzureMLEndpointApiType.serverless:
request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
else:
raise ValueError(
f"`api_type` {api_type} is not supported by this formatter"
)
return str.encode(request_payload)
def format_response_payload(
self,
output: bytes,
api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
) -> ChatGeneration:
"""Formats response"""
if api_type in [
AzureMLEndpointApiType.dedicated,
AzureMLEndpointApiType.realtime,
]:
try:
choice = json.loads(output)["output"]
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
return ChatGeneration(
message=AIMessage(
content=choice.strip(),
),
generation_info=None,
)
if api_type == AzureMLEndpointApiType.serverless:
try:
choice = json.loads(output)["choices"][0]
if not isinstance(choice, dict):
raise TypeError(
"Endpoint response is not well formed for a chat "
"model. Expected `dict` but `{type(choice)}` was received."
)
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
return ChatGeneration(
message=AIMessage(content=choice["message"]["content"].strip())
if choice["message"]["role"] == "assistant"
else BaseMessage(
content=choice["message"]["content"].strip(),
type=choice["message"]["role"],
),
generation_info=dict(
finish_reason=choice.get("finish_reason"),
logprobs=choice.get("logprobs"),
),
)
raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
class LlamaChatContentFormatter(CustomOpenAIChatContentFormatter):
"""Deprecated: Kept for backwards compatibility
Chat Content formatter for Llama."""
def __init__(self) -> None:
super().__init__()
warnings.warn(
"""`LlamaChatContentFormatter` will be deprecated in the future.
Please use `CustomOpenAIChatContentFormatter` instead.
"""
)
class MistralChatContentFormatter(LlamaChatContentFormatter):
"""Content formatter for `Mistral`."""
def format_messages_request_payload(
self,
messages: List[BaseMessage],
model_kwargs: Dict,
api_type: AzureMLEndpointApiType,
) -> bytes:
"""Formats the request according to the chosen api"""
chat_messages = [self._convert_message_to_dict(message) for message in messages]
if chat_messages and chat_messages[0]["role"] == "system":
# Mistral OSS models do not explicitly support system prompts, so we have to
# stash in the first user prompt
chat_messages[1]["content"] = (
chat_messages[0]["content"] + "\n\n" + chat_messages[1]["content"]
)
del chat_messages[0]
if api_type == AzureMLEndpointApiType.realtime:
request_payload = json.dumps(
{
"input_data": {
"input_string": chat_messages,
"parameters": model_kwargs,
}
}
)
elif api_type == AzureMLEndpointApiType.serverless:
request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
else:
raise ValueError(
f"`api_type` {api_type} is not supported by this formatter"
)
return str.encode(request_payload)
class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
"""Azure ML Online Endpoint chat models.
Example:
.. code-block:: python
azure_llm = AzureMLOnlineEndpoint(
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions",
endpoint_api_type=AzureMLApiType.serverless,
endpoint_api_key="my-api-key",
content_formatter=chat_content_formatter,
)
"""
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
**{"model_kwargs": _model_kwargs},
}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "azureml_chat_endpoint"
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Call out to an AzureML Managed Online endpoint.
Args:
messages: The messages in the conversation with the chat model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = azureml_model.invoke("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
_model_kwargs.update(kwargs)
if stop:
_model_kwargs["stop"] = stop
request_payload = self.content_formatter.format_messages_request_payload(
messages, _model_kwargs, self.endpoint_api_type
)
response_payload = self.http_client.call(
body=request_payload, run_manager=run_manager
)
generations = self.content_formatter.format_response_payload(
response_payload, self.endpoint_api_type
)
return ChatResult(generations=[generations])
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
self.endpoint_url = self.endpoint_url.replace("/chat/completions", "")
timeout = None if "timeout" not in kwargs else kwargs["timeout"]
import openai
params = {}
client_params = {
"api_key": self.endpoint_api_key.get_secret_value(),
"base_url": self.endpoint_url,
"timeout": timeout,
"default_headers": None,
"default_query": None,
"http_client": None,
}
client = openai.OpenAI(**client_params)
message_dicts = [
CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
for m in messages
]
params = {"stream": True, "stop": stop, "model": None, **kwargs}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
for chunk in client.chat.completions.create(messages=message_dicts, **params):
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"],
default_chunk_class,
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(
message=chunk,
generation_info=generation_info or None,
)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
yield chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
self.endpoint_url = self.endpoint_url.replace("/chat/completions", "")
timeout = None if "timeout" not in kwargs else kwargs["timeout"]
import openai
params = {}
client_params = {
"api_key": self.endpoint_api_key.get_secret_value(),
"base_url": self.endpoint_url,
"timeout": timeout,
"default_headers": None,
"default_query": None,
"http_client": None,
}
async_client = openai.AsyncOpenAI(**client_params)
message_dicts = [
CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
for m in messages
]
params = {"stream": True, "stop": stop, "model": None, **kwargs}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
async for chunk in await async_client.chat.completions.create(
messages=message_dicts,
**params,
):
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info or None
)
if run_manager:
await run_manager.on_llm_new_token(
token=chunk.text, chunk=chunk, logprobs=logprobs
)
yield chunk
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {}
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
function_call["name"] = ""
additional_kwargs["function_call"] = function_call
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content) # type: ignore[call-arg]

View File

@@ -0,0 +1,650 @@
import json
import logging
from contextlib import asynccontextmanager
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Type,
Union,
)
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
)
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from langchain_community.chat_models.llamacpp import (
_lc_invalid_tool_call_to_openai_tool_call,
_lc_tool_call_to_openai_tool_call,
)
logger = logging.getLogger(__name__)
DEFAULT_API_BASE = "https://api.baichuan-ai.com/v1/chat/completions"
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
content = message.content
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": content}
if "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
elif message.tool_calls or message.invalid_tool_calls:
message_dict["tool_calls"] = [
_lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
] + [
_lc_invalid_tool_call_to_openai_tool_call(tc)
for tc in message.invalid_tool_calls
]
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"tool_call_id": message.tool_call_id,
"content": content,
"name": message.name or message.additional_kwargs.get("name"),
}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": content}
else:
raise TypeError(f"Got unknown type {message}")
return message_dict
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
content = _dict.get("content", "")
if role == "user":
return HumanMessage(content=content)
elif role == "assistant":
tool_calls = []
invalid_tool_calls = []
additional_kwargs = {}
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
elif role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
return ToolMessage(
content=content,
tool_call_id=_dict.get("tool_call_id"),
additional_kwargs=additional_kwargs,
)
elif role == "system":
return SystemMessage(content=content)
else:
return ChatMessage(content=content, role=role)
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
else:
return default_class(content=content) # type: ignore[call-arg]
@asynccontextmanager
async def aconnect_httpx_sse(
client: Any, method: str, url: str, **kwargs: Any
) -> AsyncIterator:
"""Async context manager for connecting to an SSE stream.
Args:
client: The httpx client.
method: The HTTP method.
url: The URL to connect to.
kwargs: Additional keyword arguments to pass to the client.
Yields:
An EventSource object.
"""
from httpx_sse import EventSource
async with client.stream(method, url, **kwargs) as response:
yield EventSource(response)
class ChatBaichuan(BaseChatModel):
"""Baichuan chat model integration.
Setup:
To use, you should have the environment variable``BAICHUAN_API_KEY`` set with
your API KEY.
.. code-block:: bash
export BAICHUAN_API_KEY="your-api-key"
Key init args — completion params:
model: Optional[str]
Name of Baichuan model to use.
max_tokens: Optional[int]
Max number of tokens to generate.
streaming: Optional[bool]
Whether to stream the results or not.
temperature: Optional[float]
Sampling temperature.
top_p: Optional[float]
What probability mass to use.
top_k: Optional[int]
What search sampling control to use.
Key init args — client params:
api_key: Optional[str]
Baichuan API key. If not passed in will be read from env var BAICHUAN_API_KEY.
base_url: Optional[str]
Base URL for API requests.
See full list of supported init args and their descriptions in the params section.
Instantiate:
.. code-block:: python
from langchain_community.chat_models import ChatBaichuan
chat = ChatBaichuan(
api_key=api_key,
model='Baichuan4',
# temperature=...,
# other params...
)
Invoke:
.. code-block:: python
messages = [
("system", "你是一名专业的翻译家,可以将用户的中文翻译为英文。"),
("human", "我喜欢编程。"),
]
chat.invoke(messages)
.. code-block:: python
AIMessage(
content='I enjoy programming.',
response_metadata={
'token_usage': {
'prompt_tokens': 93,
'completion_tokens': 5,
'total_tokens': 98
},
'model': 'Baichuan4'
},
id='run-944ff552-6a93-44cf-a861-4e4d849746f9-0'
)
Stream:
.. code-block:: python
for chunk in chat.stream(messages):
print(chunk)
.. code-block:: python
content='I' id='run-f99fcd6f-dd31-46d5-be8f-0b6a22bf77d8'
content=' enjoy programming.' id='run-f99fcd6f-dd31-46d5-be8f-0b6a22bf77d8
.. code-block:: python
stream = chat.stream(messages)
full = next(stream)
for chunk in stream:
full += chunk
full
.. code-block:: python
AIMessageChunk(
content='I like programming.',
id='run-74689970-dc31-461d-b729-3b6aa93508d2'
)
Async:
.. code-block:: python
await chat.ainvoke(messages)
# stream
# async for chunk in chat.astream(messages):
# print(chunk)
# batch
# await chat.abatch([messages])
.. code-block:: python
AIMessage(
content='I enjoy programming.',
response_metadata={
'token_usage': {
'prompt_tokens': 93,
'completion_tokens': 5,
'total_tokens': 98
},
'model': 'Baichuan4'
},
id='run-952509ed-9154-4ff9-b187-e616d7ddfbba-0'
)
Tool calling:
.. code-block:: python
class get_current_weather(BaseModel):
'''Get current weather.'''
location: str = Field('City or province, such as Shanghai')
llm_with_tools = ChatBaichuan(model='Baichuan3-Turbo').bind_tools([get_current_weather])
llm_with_tools.invoke('How is the weather today?')
.. code-block:: python
[{'name': 'get_current_weather',
'args': {'location': 'New York'},
'id': '3951017OF8doB0A',
'type': 'tool_call'}]
Response metadata
.. code-block:: python
ai_msg = chat.invoke(messages)
ai_msg.response_metadata
.. code-block:: python
{
'token_usage': {
'prompt_tokens': 93,
'completion_tokens': 5,
'total_tokens': 98
},
'model': 'Baichuan4'
}
""" # noqa: E501
@property
def lc_secrets(self) -> Dict[str, str]:
return {
"baichuan_api_key": "BAICHUAN_API_KEY",
}
@property
def lc_serializable(self) -> bool:
return True
baichuan_api_base: str = Field(default=DEFAULT_API_BASE, alias="base_url")
"""Baichuan custom endpoints"""
baichuan_api_key: SecretStr = Field(alias="api_key")
"""Baichuan API Key"""
baichuan_secret_key: Optional[SecretStr] = None
"""[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key"""
streaming: bool = False
"""Whether to stream the results or not."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
request_timeout: int = Field(default=60, alias="timeout")
"""request timeout for chat http requests"""
model: str = "Baichuan2-Turbo-192K"
"""model name of Baichuan, default is `Baichuan2-Turbo-192K`,
other options include `Baichuan2-Turbo`"""
temperature: Optional[float] = Field(default=0.3)
"""What sampling temperature to use."""
top_k: int = 5
"""What search sampling control to use."""
top_p: float = 0.85
"""What probability mass to use."""
with_search_enhance: bool = False
"""[DEPRECATED, keeping it for for backward compatibility],
Whether to use search enhance, default is False."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for API call not explicitly specified."""
model_config = ConfigDict(
populate_by_name=True,
)
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
logger.warning(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
values["baichuan_api_base"] = get_from_dict_or_env(
values,
"baichuan_api_base",
"BAICHUAN_API_BASE",
DEFAULT_API_BASE,
)
values["baichuan_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
["baichuan_api_key", "api_key"],
"BAICHUAN_API_KEY",
)
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Baichuan API."""
normal_params = {
"model": self.model,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"stream": self.streaming,
"max_tokens": self.max_tokens,
}
return {**normal_params, **self.model_kwargs}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
res = self._chat(messages, **kwargs)
if res.status_code != 200:
raise ValueError(f"Error from Baichuan api response: {res}")
response = res.json()
return self._create_chat_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
res = self._chat(messages, stream=True, **kwargs)
if res.status_code != 200:
raise ValueError(f"Error from Baichuan api response: {res}")
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
for chunk in res.iter_lines():
chunk = chunk.decode("utf-8").strip("\r\n")
parts = chunk.split("data: ", 1)
chunk = parts[1] if len(parts) > 1 else None
if chunk is None:
continue
if chunk == "[DONE]":
break
response = json.loads(chunk)
for m in response.get("choices"):
chunk = _convert_delta_to_message_chunk(
m.get("delta"), default_chunk_class
)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
yield cg_chunk
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
headers = self._create_headers_parameters(**kwargs)
payload = self._create_payload_parameters(messages, **kwargs)
import httpx
async with httpx.AsyncClient(
headers=headers, timeout=self.request_timeout
) as client:
response = await client.post(self.baichuan_api_base, json=payload)
response.raise_for_status()
return self._create_chat_result(response.json())
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
headers = self._create_headers_parameters(**kwargs)
payload = self._create_payload_parameters(messages, stream=True, **kwargs)
import httpx
async with httpx.AsyncClient(
headers=headers, timeout=self.request_timeout
) as client:
async with aconnect_httpx_sse(
client, "POST", self.baichuan_api_base, json=payload
) as event_source:
async for sse in event_source.aiter_sse():
chunk = json.loads(sse.data)
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], AIMessageChunk
)
finish_reason = choice.get("finish_reason", None)
generation_info = (
{"finish_reason": finish_reason}
if finish_reason is not None
else None
)
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
if run_manager:
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
if finish_reason is not None:
break
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
payload = self._create_payload_parameters(messages, **kwargs)
url = self.baichuan_api_base
headers = self._create_headers_parameters(**kwargs)
res = requests.post(
url=url,
timeout=self.request_timeout,
headers=headers,
json=payload,
stream=self.streaming,
)
return res
def _create_payload_parameters(
self, messages: List[BaseMessage], **kwargs: Any
) -> Dict[str, Any]:
parameters = {**self._default_params, **kwargs}
temperature = parameters.pop("temperature", 0.3)
top_k = parameters.pop("top_k", 5)
top_p = parameters.pop("top_p", 0.85)
model = parameters.pop("model")
with_search_enhance = parameters.pop("with_search_enhance", False)
stream = parameters.pop("stream", False)
tools = parameters.pop("tools", [])
payload = {
"model": model,
"messages": [_convert_message_to_dict(m) for m in messages],
"top_k": top_k,
"top_p": top_p,
"temperature": temperature,
"with_search_enhance": with_search_enhance,
"stream": stream,
"tools": tools,
}
return payload
def _create_headers_parameters(self, **kwargs: Any) -> Dict[str, Any]:
parameters = {**self._default_params, **kwargs}
default_headers = parameters.pop("headers", {})
api_key = ""
if self.baichuan_api_key:
api_key = self.baichuan_api_key.get_secret_value()
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
**default_headers,
}
return headers
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for c in response["choices"]:
message = _convert_dict_to_message(c["message"])
gen = ChatGeneration(message=message)
generations.append(gen)
token_usage = response["usage"]
llm_output = {"token_usage": token_usage, "model": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
@property
def _llm_type(self) -> str:
return "baichuan-chat"
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
"""Bind tool-like objects to this chat model.
Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool.
Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)

View File

@@ -0,0 +1,841 @@
import json
import logging
import uuid
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Type,
Union,
cast,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import get_fields, is_basemodel_subclass
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
logger = logging.getLogger(__name__)
def convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a message to a dictionary that can be passed to the API."""
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if len(message.tool_calls) != 0:
tool_call = message.tool_calls[0]
message_dict["function_call"] = {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"], ensure_ascii=False),
}
# If function call only, content is None not empty string
message_dict["content"] = None
elif isinstance(message, (FunctionMessage, ToolMessage)):
message_dict = {
"role": "function",
"content": _create_tool_content(message.content),
"name": message.name or message.additional_kwargs.get("name"),
}
else:
raise TypeError(f"Got unknown type {message}")
return message_dict
def _create_tool_content(content: Union[str, List[Union[str, Dict[Any, Any]]]]) -> str:
"""Convert tool content to dict scheme."""
if isinstance(content, str):
try:
if isinstance(json.loads(content), dict):
return content
else:
return json.dumps({"tool_result": content})
except json.JSONDecodeError:
return json.dumps({"tool_result": content})
else:
return json.dumps({"tool_result": content})
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
content = _dict.get("result", "") or ""
additional_kwargs: Mapping[str, Any] = {}
if _dict.get("function_call"):
additional_kwargs = {"function_call": dict(_dict["function_call"])}
if "thoughts" in additional_kwargs["function_call"]:
# align to api sample, which affects the llm function_call output
additional_kwargs["function_call"].pop("thoughts")
# DO NOT ADD ANY NUMERIC OBJECT TO `msg_additional_kwargs` AND `additional_kwargs`
# ALONG WITH THEIRS SUB-CONTAINERS !!!
# OR IT WILL RAISE A DEADLY EXCEPTION FROM `merge_dict`
# 不要往 `msg_additional_kwargs` 和 `additional_kwargs` 里面加任何数值类对象!
# 子容器也不行!
# 不然 `merge_dict` 会报错导致代码无法运行
additional_kwargs = {**_dict.get("body", {}), **additional_kwargs}
msg_additional_kwargs = dict(
finish_reason=additional_kwargs.get("finish_reason", ""),
request_id=additional_kwargs["id"],
object=additional_kwargs.get("object", ""),
search_info=additional_kwargs.get("search_info", []),
)
if additional_kwargs.get("function_call", {}):
msg_additional_kwargs["function_call"] = additional_kwargs.get(
"function_call", {}
)
msg_additional_kwargs["tool_calls"] = [
{
"type": "function",
"function": additional_kwargs.get("function_call", {}),
"id": str(uuid.uuid4()),
}
]
ret = AIMessage(
content=content,
additional_kwargs=msg_additional_kwargs,
)
if usage := additional_kwargs.get("usage", None):
ret.usage_metadata = UsageMetadata(
input_tokens=usage.get("prompt_tokens", 0),
output_tokens=usage.get("completion_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
)
return ret
class QianfanChatEndpoint(BaseChatModel):
"""Baidu Qianfan chat model integration.
Setup:
Install ``qianfan`` and set environment variables ``QIANFAN_AK``, ``QIANFAN_SK``.
.. code-block:: bash
pip install qianfan
export QIANFAN_AK="your-api-key"
export QIANFAN_SK="your-secret_key"
Key init args — completion params:
model: str
Name of Qianfan model to use.
temperature: Optional[float]
Sampling temperature.
endpoint: Optional[str]
Endpoint of the Qianfan LLM
top_p: Optional[float]
What probability mass to use.
Key init args — client params:
timeout: Optional[int]
Timeout for requests.
api_key: Optional[str]
Qianfan API KEY. If not passed in will be read from env var QIANFAN_AK.
secret_key: Optional[str]
Qianfan SECRET KEY. If not passed in will be read from env var QIANFAN_SK.
See full list of supported init args and their descriptions in the params section.
Instantiate:
.. code-block:: python
from langchain_community.chat_models import QianfanChatEndpoint
qianfan_chat = QianfanChatEndpoint(
model="ERNIE-3.5-8K",
temperature=0.2,
timeout=30,
# api_key="...",
# secret_key="...",
# top_p="...",
# other params...
)
Invoke:
.. code-block:: python
messages = [
("system", "你是一名专业的翻译家,可以将用户的中文翻译为英文。"),
("human", "我喜欢编程。"),
]
qianfan_chat.invoke(messages)
.. code-block:: python
AIMessage(content='I enjoy programming.', additional_kwargs={'finish_reason': 'normal', 'request_id': 'as-7848zeqn1c', 'object': 'chat.completion', 'search_info': []}, response_metadata={'token_usage': {'prompt_tokens': 16, 'completion_tokens': 4, 'total_tokens': 20}, 'model_name': 'ERNIE-3.5-8K', 'finish_reason': 'normal', 'id': 'as-7848zeqn1c', 'object': 'chat.completion', 'created': 1719153606, 'result': 'I enjoy programming.', 'is_truncated': False, 'need_clear_history': False, 'usage': {'prompt_tokens': 16, 'completion_tokens': 4, 'total_tokens': 20}}, id='run-4bca0c10-5043-456b-a5be-2f62a980f3f0-0')
Stream:
.. code-block:: python
for chunk in qianfan_chat.stream(messages):
print(chunk)
.. code-block:: python
content='I enjoy' response_metadata={'finish_reason': 'normal', 'request_id': 'as-yz0yz1w1rq', 'object': 'chat.completion', 'search_info': []} id='run-0fa9da50-003e-4a26-ba16-dbfe96249b8b' role='assistant'
content=' programming.' response_metadata={'finish_reason': 'normal', 'request_id': 'as-yz0yz1w1rq', 'object': 'chat.completion', 'search_info': []} id='run-0fa9da50-003e-4a26-ba16-dbfe96249b8b' role='assistant'
.. code-block:: python
stream = chat.stream(messages)
full = next(stream)
for chunk in stream:
full += chunk
full
.. code-block::
AIMessageChunk(content='I enjoy programming.', response_metadata={'finish_reason': 'normalnormal', 'request_id': 'as-p63cnn3ppnas-p63cnn3ppn', 'object': 'chat.completionchat.completion', 'search_info': []}, id='run-09a8cbbd-5ded-4529-981d-5bc9d1206404')
Async:
.. code-block:: python
await qianfan_chat.ainvoke(messages)
# stream:
# async for chunk in qianfan_chat.astream(messages):
# print(chunk)
# batch:
# await qianfan_chat.abatch([messages])
.. code-block:: python
[AIMessage(content='I enjoy programming.', additional_kwargs={'finish_reason': 'normal', 'request_id': 'as-mpqa8qa1qb', 'object': 'chat.completion', 'search_info': []}, response_metadata={'token_usage': {'prompt_tokens': 16, 'completion_tokens': 4, 'total_tokens': 20}, 'model_name': 'ERNIE-3.5-8K', 'finish_reason': 'normal', 'id': 'as-mpqa8qa1qb', 'object': 'chat.completion', 'created': 1719155120, 'result': 'I enjoy programming.', 'is_truncated': False, 'need_clear_history': False, 'usage': {'prompt_tokens': 16, 'completion_tokens': 4, 'total_tokens': 20}}, id='run-443b2231-08f9-4725-b807-b77d0507ad44-0')]
Tool calling:
.. code-block:: python
from pydantic import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
class GetPopulation(BaseModel):
'''Get the current population in a given location'''
location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
chat_with_tools = qianfan_chat.bind_tools([GetWeather, GetPopulation])
ai_msg = chat_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?"
)
ai_msg.tool_calls
.. code-block:: python
[
{
'name': 'GetWeather',
'args': {'location': 'Los Angeles, CA'},
'id': '533e5f63-a3dc-40f2-9d9c-22b1feee62e0'
}
]
Structured output:
.. code-block:: python
from typing import Optional
from pydantic import BaseModel, Field
class Joke(BaseModel):
'''Joke to tell user.'''
setup: str = Field(description="The setup of the joke")
punchline: str = Field(description="The punchline to the joke")
rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
structured_chat = qianfan_chat.with_structured_output(Joke)
structured_chat.invoke("Tell me a joke about cats")
.. code-block:: python
Joke(
setup='A cat is sitting in front of a mirror and sees another cat. What does the cat think?',
punchline="The cat doesn't think it's another cat, it thinks it's another mirror.",
rating=None
)
Response metadata
.. code-block:: python
ai_msg = qianfan_chat.invoke(messages)
ai_msg.response_metadata
.. code-block:: python
{
'token_usage': {
'prompt_tokens': 16,
'completion_tokens': 4,
'total_tokens': 20},
'model_name': 'ERNIE-3.5-8K',
'finish_reason': 'normal',
'id': 'as-qbzwtydqmi',
'object': 'chat.completion',
'created': 1719158153,
'result': 'I enjoy programming.',
'is_truncated': False,
'need_clear_history': False,
'usage': {
'prompt_tokens': 16,
'completion_tokens': 4,
'total_tokens': 20
}
}
""" # noqa: E501
init_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""init kwargs for qianfan client init, such as `query_per_second` which is
associated with qianfan resource object to limit QPS"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""extra params for model invoke using with `do`."""
client: Any = None #: :meta private:
# It could be empty due to the use of Console API
# And they're not list here
qianfan_ak: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Qianfan API KEY"""
qianfan_sk: Optional[SecretStr] = Field(default=None, alias="secret_key")
"""Qianfan SECRET KEY"""
streaming: Optional[bool] = False
"""Whether to stream the results or not."""
request_timeout: Optional[int] = Field(60, alias="timeout")
"""request timeout for chat http requests"""
top_p: Optional[float] = 0.8
"""What probability mass to use."""
temperature: Optional[float] = 0.95
"""What sampling temperature to use."""
penalty_score: Optional[float] = 1
"""Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo.
In the case of other model, passing these params will not affect the result.
"""
model: Optional[str] = Field(default=None)
"""Model name.
you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
preset models are mapping to an endpoint.
`model` will be ignored if `endpoint` is set.
Default is set by `qianfan` SDK, not here
"""
endpoint: Optional[str] = None
"""Endpoint of the Qianfan LLM, required if custom model used."""
model_config = ConfigDict(
populate_by_name=True,
)
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
values["qianfan_ak"] = convert_to_secret_str(
get_from_dict_or_env(
values, ["qianfan_ak", "api_key"], "QIANFAN_AK", default=""
)
)
values["qianfan_sk"] = convert_to_secret_str(
get_from_dict_or_env(
values, ["qianfan_sk", "secret_key"], "QIANFAN_SK", default=""
)
)
default_values = {
name: field.default
for name, field in get_fields(cls).items()
if field.default is not None
}
default_values.update(values)
params = {
**values.get("init_kwargs", {}),
"model": default_values.get("model"),
"stream": default_values.get("streaming"),
}
if values["qianfan_ak"].get_secret_value() != "":
params["ak"] = values["qianfan_ak"].get_secret_value()
if values["qianfan_sk"].get_secret_value() != "":
params["sk"] = values["qianfan_sk"].get_secret_value()
if (
default_values.get("endpoint") is not None
and default_values["endpoint"] != ""
):
params["endpoint"] = default_values["endpoint"]
try:
import qianfan
values["client"] = qianfan.ChatCompletion(**params)
except ImportError:
raise ImportError(
"qianfan package not found, please install it with "
"`pip install qianfan`"
)
return values
@property
def _identifying_params(self) -> Dict[str, Any]:
return {
**{"endpoint": self.endpoint, "model": self.model},
**super()._identifying_params,
}
@property
def _llm_type(self) -> str:
"""Return type of chat_model."""
return "baidu-qianfan-chat"
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Qianfan API."""
normal_params = {
"model": self.model,
"endpoint": self.endpoint,
"stream": self.streaming,
"request_timeout": self.request_timeout,
"top_p": self.top_p,
"temperature": self.temperature,
"penalty_score": self.penalty_score,
}
return {**normal_params, **self.model_kwargs}
def _convert_prompt_msg_params(
self,
messages: List[BaseMessage],
**kwargs: Any,
) -> Dict[str, Any]:
"""
Converts a list of messages into a dictionary containing the message content
and default parameters.
Args:
messages (List[BaseMessage]): The list of messages.
**kwargs (Any): Optional arguments to add additional parameters to the
resulting dictionary.
Returns:
`dict` containing the message content and default parameters.
"""
messages_dict: Dict[str, Any] = {
"messages": [
convert_message_to_dict(m)
for m in messages
if not isinstance(m, SystemMessage)
]
}
for i in [i for i, m in enumerate(messages) if isinstance(m, SystemMessage)]:
if "system" not in messages_dict:
messages_dict["system"] = ""
messages_dict["system"] += cast(str, messages[i].content) + "\n"
return {
**messages_dict,
**self._default_params,
**kwargs,
}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Call out to an qianfan models endpoint for each generation with a prompt.
Args:
messages: The messages to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = qianfan_model.invoke("Tell me a joke.")
"""
if self.streaming:
completion = ""
chat_generation_info: Dict = {}
usage_metadata: Optional[UsageMetadata] = None
for chunk in self._stream(messages, stop, run_manager, **kwargs):
chat_generation_info = (
chunk.generation_info
if chunk.generation_info is not None
else chat_generation_info
)
completion += chunk.text
if isinstance(chunk.message, AIMessageChunk):
usage_metadata = chunk.message.usage_metadata
lc_msg = AIMessage(
content=completion,
additional_kwargs={},
usage_metadata=usage_metadata,
)
gen = ChatGeneration(
message=lc_msg,
generation_info=dict(finish_reason="stop"),
)
return ChatResult(
generations=[gen],
llm_output={
"token_usage": usage_metadata or {},
"model_name": self.model,
},
)
params = self._convert_prompt_msg_params(messages, **kwargs)
params["stop"] = stop
response_payload = self.client.do(**params)
lc_msg = _convert_dict_to_message(response_payload)
gen = ChatGeneration(
message=lc_msg,
generation_info={
"finish_reason": "stop",
**response_payload.get("body", {}),
},
)
token_usage = response_payload.get("usage", {})
llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=[gen], llm_output=llm_output)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
completion = ""
chat_generation_info: Dict = {}
usage_metadata: Optional[UsageMetadata] = None
async for chunk in self._astream(messages, stop, run_manager, **kwargs):
chat_generation_info = (
chunk.generation_info
if chunk.generation_info is not None
else chat_generation_info
)
completion += chunk.text
if isinstance(chunk.message, AIMessageChunk):
usage_metadata = chunk.message.usage_metadata
lc_msg = AIMessage(
content=completion,
additional_kwargs={},
usage_metadata=usage_metadata,
)
gen = ChatGeneration(
message=lc_msg,
generation_info=dict(finish_reason="stop"),
)
return ChatResult(
generations=[gen],
llm_output={
"token_usage": usage_metadata or {},
"model_name": self.model,
},
)
params = self._convert_prompt_msg_params(messages, **kwargs)
params["stop"] = stop
response_payload = await self.client.ado(**params)
lc_msg = _convert_dict_to_message(response_payload)
generations = []
gen = ChatGeneration(
message=lc_msg,
generation_info={
"finish_reason": "stop",
**response_payload.get("body", {}),
},
)
generations.append(gen)
token_usage = response_payload.get("usage", {})
llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._convert_prompt_msg_params(messages, **kwargs)
params["stop"] = stop
params["stream"] = True
for res in self.client.do(**params):
if res:
msg = _convert_dict_to_message(res)
additional_kwargs = msg.additional_kwargs.get("function_call", {})
chunk = ChatGenerationChunk(
text=res["result"],
message=AIMessageChunk( # type: ignore[call-arg]
content=msg.content,
role="assistant",
additional_kwargs=additional_kwargs,
usage_metadata=msg.usage_metadata,
tool_call_chunks=[
tool_call_chunk(
name=tc["name"],
args=json.dumps(tc["args"]),
id=tc["id"],
index=None,
)
for tc in msg.tool_calls
],
),
generation_info=msg.additional_kwargs,
)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
params = self._convert_prompt_msg_params(messages, **kwargs)
params["stop"] = stop
params["stream"] = True
async for res in await self.client.ado(**params):
if res:
msg = _convert_dict_to_message(res)
additional_kwargs = msg.additional_kwargs.get("function_call", {})
chunk = ChatGenerationChunk(
text=res["result"],
message=AIMessageChunk( # type: ignore[call-arg]
content=msg.content,
role="assistant",
additional_kwargs=additional_kwargs,
usage_metadata=msg.usage_metadata,
tool_call_chunks=[
tool_call_chunk(
name=tc["name"],
args=json.dumps(tc["args"]),
id=tc["id"],
index=None,
)
for tc in msg.tool_calls
],
),
generation_info=msg.additional_kwargs,
)
if run_manager:
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
"""Bind tool-like objects to this chat model.
Assumes model is compatible with OpenAI tool-calling API.
Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
formatted_tools = [convert_to_openai_tool(tool)["function"] for tool in tools]
return super().bind(functions=formatted_tools, **kwargs)
def with_structured_output(
self,
schema: Union[Dict, Type[BaseModel]],
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
then the model output will be an object of that class. If a dict then
the model output will be a dict. With a Pydantic class the returned
attributes will be validated, whereas with a dict they will not be. If
`method` is "function_calling" and `schema` is a dict, then the dict
must match the OpenAI function-calling spec.
include_raw: If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
Returns:
A Runnable that takes any ChatModel input and returns as output:
If include_raw is True then a dict with keys:
raw: BaseMessage
parsed: Optional[_DictOrPydantic]
parsing_error: Optional[BaseException]
If include_raw is False then just _DictOrPydantic is returned,
where _DictOrPydantic depends on the schema:
If schema is a Pydantic class then _DictOrPydantic is the Pydantic
class.
If schema is a dict then _DictOrPydantic is a dict.
Example: Function-calling, Pydantic schema (method="function_calling", include_raw=False):
.. code-block:: python
from langchain_mistralai import QianfanChatEndpoint
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
llm = QianfanChatEndpoint(endpoint="ernie-3.5-8k-0329")
structured_llm = llm.with_structured_output(AnswerWithJustification)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> AnswerWithJustification(
# answer='They weigh the same',
# justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
# )
Example: Function-calling, Pydantic schema (method="function_calling", include_raw=True):
.. code-block:: python
from langchain_mistralai import QianfanChatEndpoint
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
llm = QianfanChatEndpoint(endpoint="ernie-3.5-8k-0329")
structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> {
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
# 'parsing_error': None
# }
Example: Function-calling, dict schema (method="function_calling", include_raw=False):
.. code-block:: python
from langchain_mistralai import QianfanChatEndpoint
from pydantic import BaseModel
from langchain_core.utils.function_calling import convert_to_openai_tool
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
dict_schema = convert_to_openai_tool(AnswerWithJustification)
llm = QianfanChatEndpoint(endpoint="ernie-3.5-8k-0329")
structured_llm = llm.with_structured_output(dict_schema)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> {
# 'answer': 'They weigh the same',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
# }
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
llm = self.bind_tools([schema])
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore[list-item]
first_tool_only=True,
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
)
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser

View File

@@ -0,0 +1,337 @@
import re
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import ConfigDict
from langchain_community.chat_models.anthropic import (
convert_messages_to_prompt_anthropic,
)
from langchain_community.chat_models.meta import convert_messages_to_prompt_llama
from langchain_community.llms.bedrock import BedrockBase
from langchain_community.utilities.anthropic import (
get_num_tokens_anthropic,
get_token_ids_anthropic,
)
def _convert_one_message_to_text_mistral(message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
message_text = f"[INST] {message.content} [/INST]"
elif isinstance(message, AIMessage):
message_text = f"{message.content}"
elif isinstance(message, SystemMessage):
message_text = f"<<SYS>> {message.content} <</SYS>>"
else:
raise ValueError(f"Got unknown type {message}")
return message_text
def convert_messages_to_prompt_mistral(messages: List[BaseMessage]) -> str:
"""Convert a list of messages to a prompt for mistral."""
return "\n".join(
[_convert_one_message_to_text_mistral(message) for message in messages]
)
def _format_image(image_url: str) -> Dict:
"""
Formats an image of format data:image/jpeg;base64,{b64_string}
to a dict for anthropic api
{
"type": "base64",
"media_type": "image/jpeg",
"data": "/9j/4AAQSkZJRg...",
}
And throws an error if it's not a b64 image
"""
regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
match = re.match(regex, image_url)
if match is None:
raise ValueError(
"Anthropic only supports base64-encoded images currently."
" Example: data:image/png;base64,'/9j/4AAQSk'..."
)
return {
"type": "base64",
"media_type": match.group("media_type"),
"data": match.group("data"),
}
def _format_anthropic_messages(
messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict]]:
"""Format messages for anthropic."""
"""
[
{
"role": _message_type_lookups[m.type],
"content": [_AnthropicMessageContent(text=m.content).dict()],
}
for m in messages
]
"""
system: Optional[str] = None
formatted_messages: List[Dict] = []
for i, message in enumerate(messages):
if message.type == "system":
if i != 0:
raise ValueError("System message must be at beginning of message list.")
if not isinstance(message.content, str):
raise ValueError(
"System message must be a string, "
f"instead was: {type(message.content)}"
)
system = message.content
continue
role = _message_type_lookups[message.type]
content: Union[str, List[Dict]]
if not isinstance(message.content, str):
# parse as dict
assert isinstance(message.content, list), (
"Anthropic message content must be str or list of dicts"
)
# populate content
content = []
for item in message.content:
if isinstance(item, str):
content.append(
{
"type": "text",
"text": item,
}
)
elif isinstance(item, dict):
if "type" not in item:
raise ValueError("Dict content item must have a type key")
if item["type"] == "image_url":
# convert format
source = _format_image(item["image_url"]["url"])
content.append(
{
"type": "image",
"source": source,
}
)
else:
content.append(item)
else:
raise ValueError(
f"Content items must be str or dict, instead was: {type(item)}"
)
else:
content = message.content
formatted_messages.append(
{
"role": role,
"content": content,
}
)
return system, formatted_messages
class ChatPromptAdapter:
"""Adapter class to prepare the inputs from Langchain to prompt format
that Chat model expects.
"""
@classmethod
def convert_messages_to_prompt(
cls, provider: str, messages: List[BaseMessage]
) -> str:
if provider == "anthropic":
prompt = convert_messages_to_prompt_anthropic(messages=messages)
elif provider == "meta":
prompt = convert_messages_to_prompt_llama(messages=messages)
elif provider == "mistral":
prompt = convert_messages_to_prompt_mistral(messages=messages)
elif provider == "amazon":
prompt = convert_messages_to_prompt_anthropic(
messages=messages,
human_prompt="\n\nUser:",
ai_prompt="\n\nBot:",
)
else:
raise NotImplementedError(
f"Provider {provider} model does not support chat."
)
return prompt
@classmethod
def format_messages(
cls, provider: str, messages: List[BaseMessage]
) -> Tuple[Optional[str], List[Dict]]:
if provider == "anthropic":
return _format_anthropic_messages(messages)
raise NotImplementedError(
f"Provider {provider} not supported for format_messages"
)
_message_type_lookups = {
"human": "user",
"ai": "assistant",
"AIMessageChunk": "assistant",
"HumanMessageChunk": "user",
"function": "user",
}
@deprecated(
since="0.0.34", removal="1.0", alternative_import="langchain_aws.ChatBedrock"
)
class BedrockChat(BaseChatModel, BedrockBase):
"""Chat model that uses the Bedrock API."""
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "amazon_bedrock_chat"
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "bedrock"]
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
if self.region_name:
attributes["region_name"] = self.region_name
return attributes
model_config = ConfigDict(
extra="forbid",
)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
provider = self._get_provider()
prompt, system, formatted_messages = None, None, None
if provider == "anthropic":
system, formatted_messages = ChatPromptAdapter.format_messages(
provider, messages
)
else:
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages
)
for chunk in self._prepare_input_and_invoke_stream(
prompt=prompt,
system=system,
messages=formatted_messages,
stop=stop,
run_manager=run_manager,
**kwargs,
):
delta = chunk.text
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
completion = ""
llm_output: Dict[str, Any] = {"model_id": self.model_id}
if self.streaming:
for chunk in self._stream(messages, stop, run_manager, **kwargs):
completion += chunk.text
else:
provider = self._get_provider()
prompt, system, formatted_messages = None, None, None
params: Dict[str, Any] = {**kwargs}
if provider == "anthropic":
system, formatted_messages = ChatPromptAdapter.format_messages(
provider, messages
)
else:
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages
)
if stop:
params["stop_sequences"] = stop
completion, usage_info = self._prepare_input_and_invoke(
prompt=prompt,
stop=stop,
run_manager=run_manager,
system=system,
messages=formatted_messages,
**params,
)
llm_output["usage"] = usage_info
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=completion))],
llm_output=llm_output,
)
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
final_usage: Dict[str, int] = defaultdict(int)
final_output = {}
for output in llm_outputs:
output = output or {}
usage = output.get("usage", {})
for token_type, token_count in usage.items():
final_usage[token_type] += token_count
final_output.update(output)
final_output["usage"] = final_usage
return final_output
def get_num_tokens(self, text: str) -> int:
if self._model_is_anthropic:
return get_num_tokens_anthropic(text)
else:
return super().get_num_tokens(text)
def get_token_ids(self, text: str) -> List[int]:
if self._model_is_anthropic:
return get_token_ids_anthropic(text)
else:
return super().get_token_ids(text)

View File

@@ -0,0 +1,256 @@
import logging
from operator import itemgetter
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Type,
Union,
cast,
)
from uuid import uuid4
import requests
from langchain_classic.schema import AIMessage, ChatGeneration, ChatResult, HumanMessage
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessageChunk,
BaseMessage,
SystemMessage,
ToolCall,
ToolMessage,
)
from langchain_core.messages.tool import tool_call
from langchain_core.output_parsers import (
JsonOutputParser,
PydanticOutputParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables.base import RunnableMap
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, Field
# Initialize logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
_logger = logging.getLogger(__name__)
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and is_basemodel_subclass(obj)
def _convert_messages_to_cloudflare_messages(
messages: List[BaseMessage],
) -> List[Dict[str, Any]]:
"""Convert LangChain messages to Cloudflare Workers AI format."""
cloudflare_messages = []
msg: Dict[str, Any]
for message in messages:
# Base structure for each message
msg = {
"role": "",
"content": message.content if isinstance(message.content, str) else "",
}
# Determine role and additional fields based on message type
if isinstance(message, HumanMessage):
msg["role"] = "user"
elif isinstance(message, AIMessage):
msg["role"] = "assistant"
# If the AIMessage includes tool calls, format them as needed
if message.tool_calls:
tool_calls = [
{"name": tool_call["name"], "arguments": tool_call["args"]}
for tool_call in message.tool_calls
]
msg["tool_calls"] = tool_calls
elif isinstance(message, SystemMessage):
msg["role"] = "system"
elif isinstance(message, ToolMessage):
msg["role"] = "tool"
msg["tool_call_id"] = (
message.tool_call_id
) # Use tool_call_id if it's a ToolMessage
# Add the formatted message to the list
cloudflare_messages.append(msg)
return cloudflare_messages
def _get_tool_calls_from_response(response: requests.Response) -> List[ToolCall]:
"""Get tool calls from ollama response."""
tool_calls = []
if "tool_calls" in response.json()["result"]:
for tc in response.json()["result"]["tool_calls"]:
tool_calls.append(
tool_call(
id=str(uuid4()),
name=tc["name"],
args=tc["arguments"],
)
)
return tool_calls
@deprecated(
since="0.3.23",
removal="1.0",
alternative_import="langchain_cloudflare.ChatCloudflareWorkersAI",
)
class ChatCloudflareWorkersAI(BaseChatModel):
"""Custom chat model for Cloudflare Workers AI"""
account_id: str = Field(...)
api_token: str = Field(...)
model: str = Field(...)
ai_gateway: str = ""
url: str = ""
base_url: str = "https://api.cloudflare.com/client/v4/accounts"
gateway_url: str = "https://gateway.ai.cloudflare.com/v1"
def __init__(self, **kwargs: Any) -> None:
"""Initialize with necessary credentials."""
super().__init__(**kwargs)
if self.ai_gateway:
self.url = (
f"{self.gateway_url}/{self.account_id}/"
f"{self.ai_gateway}/workers-ai/run/{self.model}"
)
else:
self.url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Generate a response based on the messages provided."""
formatted_messages = _convert_messages_to_cloudflare_messages(messages)
headers = {"Authorization": f"Bearer {self.api_token}"}
prompt = "\n".join(
f"role: {msg['role']}, content: {msg['content']}"
+ (f", tools: {msg['tool_calls']}" if "tool_calls" in msg else "")
+ (
f", tool_call_id: {msg['tool_call_id']}"
if "tool_call_id" in msg
else ""
)
for msg in formatted_messages
)
# Initialize `data` with `prompt`
data = {
"prompt": prompt,
"tools": kwargs["tools"] if "tools" in kwargs else None,
**{key: value for key, value in kwargs.items() if key not in ["tools"]},
}
# Ensure `tools` is a list if it's included in `kwargs`
if data["tools"] is not None and not isinstance(data["tools"], list):
data["tools"] = [data["tools"]]
_logger.info(f"Sending prompt to Cloudflare Workers AI: {data}")
response = requests.post(self.url, headers=headers, json=data)
tool_calls = _get_tool_calls_from_response(response)
ai_message = AIMessage(
content=str(response.json()), tool_calls=cast(AIMessageChunk, tool_calls)
)
chat_generation = ChatGeneration(message=ai_message)
return ChatResult(generations=[chat_generation])
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable[..., Any], BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
"""Bind tools for use in model generation."""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
def with_structured_output(
self,
schema: Union[Dict, Type[BaseModel]],
*,
include_raw: bool = False,
method: Optional[Literal["json_mode", "function_calling"]] = "function_calling",
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema."""
_ = kwargs.pop("strict", None)
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = _is_pydantic_class(schema)
if method == "json_schema":
# Some applications require that incompatible parameters (e.g., unsupported
# methods) be handled.
method = "function_calling"
if method == "function_calling":
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore[list-item]
first_tool_only=True,
)
else:
output_parser = JsonOutputKeyToolsParser(
key_name=tool_name, first_tool_only=True
)
elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})
output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
if is_pydantic_schema
else JsonOutputParser()
)
else:
raise ValueError(
f"Unrecognized method argument. Expected one of 'function_calling' or "
f"'json_mode'. Received: '{method}'"
)
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
@property
def _llm_type(self) -> str:
"""Return the type of the LLM (for Langchain compatibility)."""
return "cloudflare-workers-ai"

View File

@@ -0,0 +1,251 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import ConfigDict
from langchain_community.llms.cohere import BaseCohere
def get_role(message: BaseMessage) -> str:
"""Get the role of the message.
Args:
message: The message.
Returns:
The role of the message.
Raises:
ValueError: If the message is of an unknown type.
"""
if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
return "User"
elif isinstance(message, AIMessage):
return "Chatbot"
elif isinstance(message, SystemMessage):
return "System"
else:
raise ValueError(f"Got unknown type {message}")
def get_cohere_chat_request(
messages: List[BaseMessage],
*,
connectors: Optional[List[Dict[str, str]]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Get the request for the Cohere chat API.
Args:
messages: The messages.
connectors: The connectors.
**kwargs: The keyword arguments.
Returns:
The request for the Cohere chat API.
"""
documents = (
None
if "source_documents" not in kwargs
else [
{
"snippet": doc.page_content,
"id": doc.metadata.get("id") or f"doc-{str(i)}",
}
for i, doc in enumerate(kwargs["source_documents"])
]
)
kwargs.pop("source_documents", None)
maybe_connectors = connectors if documents is None else None
# by enabling automatic prompt truncation, the probability of request failure is
# reduced with minimal impact on response quality
prompt_truncation = (
"AUTO" if documents is not None or connectors is not None else None
)
req = {
"message": messages[-1].content,
"chat_history": [
{"role": get_role(x), "message": x.content} for x in messages[:-1]
],
"documents": documents,
"connectors": maybe_connectors,
"prompt_truncation": prompt_truncation,
**kwargs,
}
return {k: v for k, v in req.items() if v is not None}
@deprecated(
since="0.0.30", removal="1.0", alternative_import="langchain_cohere.ChatCohere"
)
class ChatCohere(BaseChatModel, BaseCohere):
"""`Cohere` chat large language models.
To use, you should have the ``cohere`` python package installed, and the
environment variable ``COHERE_API_KEY`` set with your API key, or pass
it as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatCohere
from langchain_core.messages import HumanMessage
chat = ChatCohere(max_tokens=256, temperature=0.75)
messages = [HumanMessage(content="knock knock")]
chat.invoke(messages)
"""
model_config = ConfigDict(
populate_by_name=True,
arbitrary_types_allowed=True,
)
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "cohere-chat"
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Cohere API."""
return {
"temperature": self.temperature,
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {**{"model": self.model}, **self._default_params}
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
if hasattr(self.client, "chat_stream"): # detect and support sdk v5
stream = self.client.chat_stream(**request)
else:
stream = self.client.chat(**request, stream=True)
for data in stream:
if data.event_type == "text-generation":
delta = data.text
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
if hasattr(self.async_client, "chat_stream"): # detect and support sdk v5
stream = await self.async_client.chat_stream(**request)
else:
stream = await self.async_client.chat(**request, stream=True)
async for data in stream:
if data.event_type == "text-generation":
delta = data.text
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
await run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk
def _get_generation_info(self, response: Any) -> Dict[str, Any]:
"""Get the generation info from cohere API response."""
return {
"documents": response.documents,
"citations": response.citations,
"search_results": response.search_results,
"search_queries": response.search_queries,
"token_count": response.token_count,
}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
response = self.client.chat(**request)
message = AIMessage(content=response.text)
generation_info = None
if hasattr(response, "documents"):
generation_info = self._get_generation_info(response)
return ChatResult(
generations=[
ChatGeneration(message=message, generation_info=generation_info)
]
)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
response = self.client.chat(**request)
message = AIMessage(content=response.text)
generation_info = None
if hasattr(response, "documents"):
generation_info = self._get_generation_info(response)
return ChatResult(
generations=[
ChatGeneration(message=message, generation_info=generation_info)
]
)
def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""
return len(self.client.tokenize(text=text).tokens)

View File

@@ -0,0 +1,255 @@
import json
import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional, Union
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
)
from pydantic import ConfigDict, Field, SecretStr, model_validator
logger = logging.getLogger(__name__)
DEFAULT_API_BASE = "https://api.coze.com"
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, HumanMessage):
message_dict = {
"role": "user",
"content": message.content,
"content_type": "text",
}
else:
message_dict = {
"role": "assistant",
"content": message.content,
"content_type": "text",
}
return message_dict
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> Union[BaseMessage, None]:
msg_type = _dict["type"]
if msg_type != "answer":
return None
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict.get("content", "") or "")
else:
return ChatMessage(content=_dict["content"], role=role)
def _convert_delta_to_message_chunk(_dict: Mapping[str, Any]) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
if role == "user":
return HumanMessageChunk(content=content)
elif role == "assistant":
return AIMessageChunk(content=content)
else:
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
class ChatCoze(BaseChatModel):
"""ChatCoze chat models API by coze.com
For more information, see https://www.coze.com/open/docs/chat
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {
"coze_api_key": "COZE_API_KEY",
}
@property
def lc_serializable(self) -> bool:
return True
coze_api_base: str = Field(default=DEFAULT_API_BASE)
"""Coze custom endpoints"""
coze_api_key: Optional[SecretStr] = None
"""Coze API Key"""
request_timeout: int = Field(default=60, alias="timeout")
"""request timeout for chat http requests"""
bot_id: str = Field(default="")
"""The ID of the bot that the API interacts with."""
conversation_id: str = Field(default="")
"""Indicate which conversation the dialog is taking place in. If there is no need to
distinguish the context of the conversation(just a question and answer), skip this
parameter. It will be generated by the system."""
user: str = Field(default="")
"""The user who calls the API to chat with the bot."""
streaming: bool = False
"""Whether to stream the response to the client.
false: if no value is specified or set to false, a non-streaming response is
returned. "Non-streaming response" means that all responses will be returned at once
after they are all ready, and the client does not need to concatenate the content.
true: set to true, partial message deltas will be sent .
"Streaming response" will provide real-time response of the model to the client, and
the client needs to assemble the final reply based on the type of message. """
model_config = ConfigDict(
populate_by_name=True,
)
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
values["coze_api_base"] = get_from_dict_or_env(
values,
"coze_api_base",
"COZE_API_BASE",
DEFAULT_API_BASE,
)
values["coze_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"coze_api_key",
"COZE_API_KEY",
)
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Coze API."""
return {
"bot_id": self.bot_id,
"conversation_id": self.conversation_id,
"user": self.user,
"streaming": self.streaming,
}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
r = self._chat(messages, **kwargs)
res = r.json()
if res["code"] != 0:
raise ValueError(
f"Error from Coze api response: {res['code']}: {res['msg']}, "
f"logid: {r.headers.get('X-Tt-Logid')}"
)
return self._create_chat_result(res.get("messages") or [])
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
res = self._chat(messages, **kwargs)
for chunk in res.iter_lines():
chunk = chunk.decode("utf-8").strip("\r\n")
parts = chunk.split("data:", 1)
chunk = parts[1] if len(parts) > 1 else None
if chunk is None:
continue
response = json.loads(chunk)
if response["event"] == "done":
break
elif (
response["event"] != "message"
or response["message"]["type"] != "answer"
):
continue
chunk = _convert_delta_to_message_chunk(response["message"])
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
yield cg_chunk
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
parameters = {**self._default_params, **kwargs}
query = ""
chat_history = []
for msg in messages:
if isinstance(msg, HumanMessage):
query = f"{msg.content}" # overwrite, to get last user message as query
chat_history.append(_convert_message_to_dict(msg))
conversation_id = parameters.pop("conversation_id")
bot_id = parameters.pop("bot_id")
user = parameters.pop("user")
streaming = parameters.pop("streaming")
payload = {
"conversation_id": conversation_id,
"bot_id": bot_id,
"user": user,
"query": query,
"stream": streaming,
}
if chat_history:
payload["chat_history"] = chat_history
url = self.coze_api_base + "/open_api/v2/chat"
api_key = ""
if self.coze_api_key:
api_key = self.coze_api_key.get_secret_value()
res = requests.post(
url=url,
timeout=self.request_timeout,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
},
json=payload,
stream=streaming,
)
if res.status_code != 200:
logid = res.headers.get("X-Tt-Logid")
raise ValueError(f"Error from Coze api response: {res}, logid: {logid}")
return res
def _create_chat_result(self, messages: List[Mapping[str, Any]]) -> ChatResult:
generations = []
for c in messages:
msg = _convert_dict_to_message(c)
if msg:
generations.append(ChatGeneration(message=msg))
llm_output = {"token_usage": "", "model": ""}
return ChatResult(generations=generations, llm_output=llm_output)
@property
def _llm_type(self) -> str:
return "coze-chat"

View File

@@ -0,0 +1,161 @@
from typing import Any, Dict, List, Optional, Union
from aiohttp import ClientSession
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
)
from langchain_core.messages import (
AIMessage,
BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from pydantic import ConfigDict, Field, SecretStr, model_validator
from langchain_community.utilities.requests import Requests
def _format_dappier_messages(
messages: List[BaseMessage],
) -> List[Dict[str, Union[str, List[Union[str, Dict[Any, Any]]]]]]:
formatted_messages = []
for message in messages:
if message.type == "human":
formatted_messages.append({"role": "user", "content": message.content})
elif message.type == "system":
formatted_messages.append({"role": "system", "content": message.content})
return formatted_messages
class ChatDappierAI(BaseChatModel):
"""`Dappier` chat large language models.
`Dappier` is a platform enabling access to diverse, real-time data models.
Enhance your AI applications with Dappier's pre-trained, LLM-ready data models
and ensure accurate, current responses with reduced inaccuracies.
To use one of our Dappier AI Data Models, you will need an API key.
Please visit Dappier Platform (https://platform.dappier.com/) to log in
and create an API key in your profile.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatDappierAI
from langchain_core.messages import HumanMessage
# Initialize `ChatDappierAI` with the desired configuration
chat = ChatDappierAI(
dappier_endpoint="https://api.dappier.com/app/datamodel/dm_01hpsxyfm2fwdt2zet9cg6fdxt",
dappier_api_key="<YOUR_KEY>")
# Create a list of messages to interact with the model
messages = [HumanMessage(content="hello")]
# Invoke the model with the provided messages
chat.invoke(messages)
you can find more details here : https://docs.dappier.com/introduction"""
dappier_endpoint: str = "https://api.dappier.com/app/datamodelconversation"
dappier_model: str = "dm_01hpsxyfm2fwdt2zet9cg6fdxt"
dappier_api_key: Optional[SecretStr] = Field(None, description="Dappier API Token")
model_config = ConfigDict(
extra="forbid",
)
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
"""Validate that api key exists in environment."""
values["dappier_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "dappier_api_key", "DAPPIER_API_KEY")
)
return values
@staticmethod
def get_user_agent() -> str:
from langchain_community import __version__
return f"langchain/{__version__}"
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "dappier-realtimesearch-chat"
@property
def _api_key(self) -> str:
if self.dappier_api_key:
return self.dappier_api_key.get_secret_value()
return ""
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
url = f"{self.dappier_endpoint}"
headers = {
"Authorization": f"Bearer {self._api_key}",
"User-Agent": self.get_user_agent(),
}
user_query = _format_dappier_messages(messages=messages)
payload: Dict[str, Any] = {
"model": self.dappier_model,
"conversation": user_query,
}
request = Requests(headers=headers)
response = request.post(url=url, data=payload)
response.raise_for_status()
data = response.json()
message_response = data["message"]
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=message_response))]
)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
url = f"{self.dappier_endpoint}"
headers = {
"Authorization": f"Bearer {self._api_key}",
"User-Agent": self.get_user_agent(),
}
user_query = _format_dappier_messages(messages=messages)
payload: Dict[str, Any] = {
"model": self.dappier_model,
"conversation": user_query,
}
async with ClientSession() as session:
async with session.post(url, json=payload, headers=headers) as response:
response.raise_for_status()
data = await response.json()
message_response = data["message"]
return ChatResult(
generations=[
ChatGeneration(message=AIMessage(content=message_response))
]
)

View File

@@ -0,0 +1,60 @@
import logging
from urllib.parse import urlparse
from langchain_core._api import deprecated
from langchain_community.chat_models.mlflow import ChatMlflow
logger = logging.getLogger(__name__)
@deprecated(
since="0.3.3",
removal="1.0",
alternative_import="databricks_langchain.ChatDatabricks",
)
class ChatDatabricks(ChatMlflow):
"""`Databricks` chat models API.
To use, you should have the ``mlflow`` python package installed.
For more information, see https://mlflow.org/docs/latest/llms/deployments.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatDatabricks
chat_model = ChatDatabricks(
target_uri="databricks",
endpoint="databricks-llama-2-70b-chat",
temperature=0.1,
)
# single input invocation
print(chat_model.invoke("What is MLflow?").content)
# single input invocation with streaming response
for chunk in chat_model.stream("What is MLflow?"):
print(chunk.content, end="|")
"""
target_uri: str = "databricks"
"""The target URI to use. Defaults to ``databricks``."""
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "databricks-chat"
@property
def _mlflow_extras(self) -> str:
return ""
def _validate_uri(self) -> None:
if self.target_uri == "databricks":
return
if urlparse(self.target_uri).scheme != "databricks":
raise ValueError(
"Invalid target URI. The target URI must be a valid databricks URI."
)

View File

@@ -0,0 +1,546 @@
"""deepinfra.com chat models wrapper"""
from __future__ import annotations
import json
import logging
from json import JSONDecodeError
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import aiohttp
import requests
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
)
from langchain_core.messages.tool import ToolCall
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
)
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self
from langchain_community.utilities.requests import Requests
logger = logging.getLogger(__name__)
class ChatDeepInfraException(Exception):
"""Exception raised when the DeepInfra API returns an error."""
pass
def _create_retry_decorator(
llm: ChatDeepInfra,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions."""
return create_base_retry_decorator(
error_types=[requests.exceptions.ConnectTimeout, ChatDeepInfraException],
max_retries=llm.max_retries,
run_manager=run_manager,
)
def _parse_tool_calling(tool_call: dict) -> ToolCall:
"""
Convert a tool calling response from server to a ToolCall object.
Args:
tool_call:
Returns:
"""
name = tool_call["function"].get("name", "")
try:
args = json.loads(tool_call["function"]["arguments"])
except (JSONDecodeError, TypeError):
args = {}
id = tool_call.get("id")
return create_tool_call(name=name, args=args, id=id)
def _convert_to_tool_calling(tool_call: ToolCall) -> Dict[str, Any]:
"""
Convert a ToolCall object to a tool calling request for server.
Args:
tool_call:
Returns:
"""
return {
"type": "function",
"function": {
"arguments": json.dumps(tool_call["args"]),
"name": tool_call["name"],
},
"id": tool_call.get("id"),
}
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
content = _dict.get("content", "") or ""
tool_calls_content = _dict.get("tool_calls", []) or []
tool_calls = [
_parse_tool_calling(tool_call) for tool_call in tool_calls_content
]
return AIMessage(content=content, tool_calls=tool_calls)
elif role == "system":
return SystemMessage(content=_dict["content"])
elif role == "function":
return FunctionMessage(content=_dict["content"], name=_dict["name"])
else:
return ChatMessage(content=_dict["content"], role=role)
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
tool_calls = _dict.get("tool_calls") or []
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
tool_calls = [_parse_tool_calling(tool_call) for tool_call in tool_calls]
return AIMessageChunk(content=content, tool_calls=tool_calls)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
else:
return default_class(content=content) # type: ignore[call-arg]
def _convert_message_to_dict(message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
tool_calls = [
_convert_to_tool_calling(tool_call) for tool_call in message.tool_calls
]
message_dict = {
"role": "assistant",
"content": message.content,
"tool_calls": tool_calls, # type: ignore[dict-item]
}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"content": message.content,
"name": message.name, # type: ignore[dict-item]
"tool_call_id": message.tool_call_id,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
class ChatDeepInfra(BaseChatModel):
"""A chat model that uses the DeepInfra API."""
# client: Any #: :meta private:
model_name: str = Field(default="meta-llama/Llama-2-70b-chat-hf", alias="model")
"""Model name to use."""
url: str = "https://api.deepinfra.com/v1/openai/chat/completions"
"""URL to use for the API call."""
deepinfra_api_token: Optional[str] = None
request_timeout: Optional[float] = Field(default=None, alias="timeout")
temperature: Optional[float] = 1
"""Run inference with this temperature. Must be in the closed
interval [0.0, 1.0]."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for API call not explicitly specified."""
top_p: Optional[float] = None
"""Decode using nucleus sampling: consider the smallest set of tokens whose
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
top_k: Optional[int] = None
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
Must be positive."""
n: int = 1
"""Number of chat completions to generate for each prompt. Note that the API may
not return the full n completions if duplicates are generated."""
max_tokens: int = 256
streaming: bool = False
max_retries: int = 1
model_config = ConfigDict(
populate_by_name=True,
)
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
"model": self.model_name,
"max_tokens": self.max_tokens,
"stream": self.streaming,
"n": self.n,
"temperature": self.temperature,
"request_timeout": self.request_timeout,
**self.model_kwargs,
}
@property
def _client_params(self) -> Dict[str, Any]:
"""Get the parameters used for the openai client."""
return {**self._default_params}
def completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
try:
request_timeout = kwargs.pop("request_timeout")
request = Requests(headers=self._headers())
response = request.post(
url=self._url(), data=self._body(kwargs), timeout=request_timeout
)
self._handle_status(response.status_code, response.text)
return response
except Exception as e:
print("EX", e) # noqa: T201
raise
return _completion_with_retry(**kwargs)
async def acompletion_with_retry(
self,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
try:
request_timeout = kwargs.pop("request_timeout")
request = Requests(headers=self._headers())
async with request.apost(
url=self._url(), data=self._body(kwargs), timeout=request_timeout
) as response:
self._handle_status(response.status, await response.text())
return await response.json()
except Exception as e:
print("EX", e) # noqa: T201
raise
return await _completion_with_retry(**kwargs)
@model_validator(mode="before")
@classmethod
def init_defaults(cls, values: Dict) -> Any:
"""Validate api key, python package exists, temperature, top_p, and top_k."""
# For compatibility with LiteLLM
api_key = get_from_dict_or_env(
values,
"deepinfra_api_key",
"DEEPINFRA_API_KEY",
default="",
)
values["deepinfra_api_token"] = get_from_dict_or_env(
values,
"deepinfra_api_token",
"DEEPINFRA_API_TOKEN",
default=api_key,
)
return values
@model_validator(mode="after")
def validate_environment(self) -> Self:
if self.temperature is not None and not 0 <= self.temperature <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")
if self.top_p is not None and not 0 <= self.top_p <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")
if self.top_k is not None and self.top_k <= 0:
raise ValueError("top_k must be positive")
return self
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response.json())
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=res.get("finish_reason")),
)
generations.append(gen)
token_usage = response.get("usage", {})
llm_output = {"token_usage": token_usage, "model": self.model_name}
res = ChatResult(generations=generations, llm_output=llm_output)
return res
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = self._client_params
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
for line in _parse_stream(response.iter_lines()):
chunk = _handle_sse_line(line)
if chunk:
cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
if run_manager:
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
yield cg_chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {"messages": message_dicts, "stream": True, **params, **kwargs}
request_timeout = params.pop("request_timeout")
request = Requests(headers=self._headers())
async with request.apost(
url=self._url(), data=self._body(params), timeout=request_timeout
) as response:
async for line in _parse_stream_async(response.content):
chunk = _handle_sse_line(line)
if chunk:
cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
if run_manager:
await run_manager.on_llm_new_token(
str(chunk.content), chunk=cg_chunk
)
yield cg_chunk
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {"messages": message_dicts, **params, **kwargs}
res = await self.acompletion_with_retry(run_manager=run_manager, **params)
return self._create_chat_result(res)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {
"model": self.model_name,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"n": self.n,
}
@property
def _llm_type(self) -> str:
return "deepinfra-chat"
def _handle_status(self, code: int, text: Any) -> None:
if code >= 500:
raise ChatDeepInfraException(
f"DeepInfra Server error status {code}: {text}"
)
elif code >= 400:
raise ValueError(f"DeepInfra received an invalid payload: {text}")
elif code != 200:
raise Exception(
f"DeepInfra returned an unexpected response with status {code}: {text}"
)
def _url(self) -> str:
return self.url
def _headers(self) -> Dict:
return {
"Authorization": f"bearer {self.deepinfra_api_token}",
"Content-Type": "application/json",
}
def _body(self, kwargs: Any) -> Dict:
return kwargs
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
"""Bind tool-like objects to this chat model.
Assumes model is compatible with OpenAI tool-calling API.
Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
for line in rbody:
_line = _parse_stream_helper(line)
if _line is not None:
yield _line
async def _parse_stream_async(rbody: aiohttp.StreamReader) -> AsyncIterator[str]:
async for line in rbody:
_line = _parse_stream_helper(line)
if _line is not None:
yield _line
def _parse_stream_helper(line: bytes) -> Optional[str]:
if line and line.startswith(b"data:"):
if line.startswith(b"data: "):
# SSE event may be valid when it contain whitespace
line = line[len(b"data: ") :]
else:
line = line[len(b"data:") :]
if line.strip() == b"[DONE]":
# return here will cause GeneratorExit exception in urllib3
# and it will close http connection with TCP Reset
return None
else:
return line.decode("utf-8")
return None
def _handle_sse_line(line: str) -> Optional[BaseMessageChunk]:
try:
obj = json.loads(line)
default_chunk_class = AIMessageChunk
delta = obj.get("choices", [{}])[0].get("delta", {})
return _convert_delta_to_message_chunk(delta, default_chunk_class)
except Exception:
return None

View File

@@ -0,0 +1,627 @@
import json
import warnings
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from aiohttp import ClientSession
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
InvalidToolCall,
SystemMessage,
ToolCall,
ToolMessage,
)
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
from langchain_core.messages.tool import tool_call as create_tool_call
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
)
from langchain_community.utilities.requests import Requests
def _result_to_chunked_message(generated_result: ChatResult) -> ChatGenerationChunk:
message = generated_result.generations[0].message
if isinstance(message, AIMessage) and message.tool_calls is not None:
tool_call_chunks = [
create_tool_call_chunk(
name=tool_call["name"],
args=json.dumps(tool_call["args"]),
id=tool_call["id"],
index=idx,
)
for idx, tool_call in enumerate(message.tool_calls)
]
message_chunk = AIMessageChunk(
content=message.content,
tool_call_chunks=tool_call_chunks,
)
return ChatGenerationChunk(message=message_chunk)
else:
return cast(ChatGenerationChunk, generated_result.generations[0])
def _message_role(type: str) -> str:
role_mapping = {
"ai": "assistant",
"human": "user",
"chat": "user",
"AIMessageChunk": "assistant",
}
if type in role_mapping:
return role_mapping[type]
else:
raise ValueError(f"Unknown type: {type}")
def _extract_edenai_tool_results_from_messages(
messages: List[BaseMessage],
) -> Tuple[List[Dict[str, Any]], List[BaseMessage]]:
"""
Get the last langchain tools messages to transform them into edenai tool_results
Returns tool_results and messages without the extracted tool messages
"""
tool_results: List[Dict[str, Any]] = []
other_messages = messages[:]
for msg in reversed(messages):
if isinstance(msg, ToolMessage):
tool_results = [
{"id": msg.tool_call_id, "result": msg.content},
*tool_results,
]
other_messages.pop()
else:
break
return tool_results, other_messages
def _format_edenai_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
system = None
formatted_messages = []
human_messages = list(filter(lambda msg: isinstance(msg, HumanMessage), messages))
last_human_message = human_messages[-1] if human_messages else ""
tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages)
for i, message in enumerate(other_messages):
if isinstance(message, SystemMessage):
if i != 0:
raise ValueError("System message must be at beginning of message list.")
system = message.content
elif isinstance(message, ToolMessage):
formatted_messages.append({"role": "tool", "message": message.content})
elif message != last_human_message:
formatted_messages.append(
{
"role": _message_role(message.type),
"message": message.content,
"tool_calls": _format_tool_calls_to_edenai_tool_calls(message),
}
)
return {
"text": getattr(last_human_message, "content", ""),
"previous_history": formatted_messages,
"chatbot_global_action": system,
"tool_results": tool_results,
}
def _format_tool_calls_to_edenai_tool_calls(message: BaseMessage) -> List:
tool_calls = getattr(message, "tool_calls", [])
invalid_tool_calls = getattr(message, "invalid_tool_calls", [])
edenai_tool_calls = []
for invalid_tool_call in invalid_tool_calls:
edenai_tool_calls.append(
{
"arguments": invalid_tool_call.get("args"),
"id": invalid_tool_call.get("id"),
"name": invalid_tool_call.get("name"),
}
)
for tool_call in tool_calls:
tool_args = tool_call.get("args", {})
try:
arguments = json.dumps(tool_args)
except TypeError:
arguments = str(tool_args)
edenai_tool_calls.append(
{
"arguments": arguments,
"id": tool_call["id"],
"name": tool_call["name"],
}
)
return edenai_tool_calls
def _extract_tool_calls_from_edenai_response(
provider_response: Dict[str, Any],
) -> Tuple[List[ToolCall], List[InvalidToolCall]]:
tool_calls = []
invalid_tool_calls = []
message = provider_response.get("message", {})[1]
if raw_tool_calls := message.get("tool_calls"):
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(
create_tool_call(
name=raw_tool_call["name"],
args=json.loads(raw_tool_call["arguments"]),
id=raw_tool_call["id"],
)
)
except json.JSONDecodeError as exc:
invalid_tool_calls.append(
create_invalid_tool_call(
name=raw_tool_call.get("name"),
args=raw_tool_call.get("arguments"),
id=raw_tool_call.get("id"),
error=f"Received JSONDecodeError {exc}",
)
)
return tool_calls, invalid_tool_calls
class ChatEdenAI(BaseChatModel):
"""`EdenAI` chat large language models.
`EdenAI` is a versatile platform that allows you to access various language models
from different providers such as Google, OpenAI, Cohere, Mistral and more.
To get started, make sure you have the environment variable ``EDENAI_API_KEY``
set with your API key, or pass it as a named parameter to the constructor.
Additionally, `EdenAI` provides the flexibility to choose from a variety of models,
including the ones like "gpt-4".
Example:
.. code-block:: python
from langchain_community.chat_models import ChatEdenAI
from langchain_core.messages import HumanMessage
# Initialize `ChatEdenAI` with the desired configuration
chat = ChatEdenAI(
provider="openai",
model="gpt-4",
max_tokens=256,
temperature=0.75)
# Create a list of messages to interact with the model
messages = [HumanMessage(content="hello")]
# Invoke the model with the provided messages
chat.invoke(messages)
`EdenAI` goes beyond mere model invocation. It empowers you with advanced features :
- **Multiple Providers**: access to a diverse range of llms offered by various
providers giving you the freedom to choose the best-suited model for your use case.
- **Fallback Mechanism**: Set a fallback mechanism to ensure seamless operations
even if the primary provider is unavailable, you can easily switches to an
alternative provider.
- **Usage Statistics**: Track usage statistics on a per-project
and per-API key basis.
This feature allows you to monitor and manage resource consumption effectively.
- **Monitoring and Observability**: `EdenAI` provides comprehensive monitoring
and observability tools on the platform.
Example of setting up a fallback mechanism:
.. code-block:: python
# Initialize `ChatEdenAI` with a fallback provider
chat_with_fallback = ChatEdenAI(
provider="openai",
model="gpt-4",
max_tokens=256,
temperature=0.75,
fallback_provider="google")
you can find more details here : https://docs.edenai.co/reference/text_chat_create
"""
provider: str = "openai"
"""chat provider to use (eg: openai,google etc.)"""
model: Optional[str] = None
"""
model name for above provider (eg: 'gpt-4' for openai)
available models are shown on https://docs.edenai.co/ under 'available providers'
"""
max_tokens: int = 256
"""Denotes the number of tokens to predict per generation."""
temperature: Optional[float] = 0
"""A non-negative float that tunes the degree of randomness in generation."""
streaming: bool = False
"""Whether to stream the results."""
fallback_providers: Optional[str] = None
"""Providers in this will be used as fallback if the call to provider fails."""
edenai_api_url: str = "https://api.edenai.run/v2"
edenai_api_key: Optional[SecretStr] = Field(None, description="EdenAI API Token")
model_config = ConfigDict(
extra="forbid",
)
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
values["edenai_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "edenai_api_key", "EDENAI_API_KEY")
)
return values
@staticmethod
def get_user_agent() -> str:
from langchain_community import __version__
return f"langchain/{__version__}"
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "edenai-chat"
@property
def _api_key(self) -> str:
if self.edenai_api_key:
return self.edenai_api_key.get_secret_value()
return ""
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Call out to EdenAI's chat endpoint."""
if "available_tools" in kwargs:
yield self._stream_with_tools_as_generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return
url = f"{self.edenai_api_url}/text/chat/stream"
headers = {
"Authorization": f"Bearer {self._api_key}",
"User-Agent": self.get_user_agent(),
}
formatted_data = _format_edenai_messages(messages=messages)
payload: Dict[str, Any] = {
"providers": self.provider,
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"fallback_providers": self.fallback_providers,
**formatted_data,
**kwargs,
}
payload = {k: v for k, v in payload.items() if v is not None}
if self.model is not None:
payload["settings"] = {self.provider: self.model}
request = Requests(headers=headers)
response = request.post(url=url, data=payload, stream=True)
response.raise_for_status()
for chunk_response in response.iter_lines():
chunk = json.loads(chunk_response.decode())
token = chunk["text"]
cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
if run_manager:
run_manager.on_llm_new_token(token, chunk=cg_chunk)
yield cg_chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
if "available_tools" in kwargs:
yield await self._astream_with_tools_as_agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return
url = f"{self.edenai_api_url}/text/chat/stream"
headers = {
"Authorization": f"Bearer {self._api_key}",
"User-Agent": self.get_user_agent(),
}
formatted_data = _format_edenai_messages(messages=messages)
payload: Dict[str, Any] = {
"providers": self.provider,
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"fallback_providers": self.fallback_providers,
**formatted_data,
**kwargs,
}
payload = {k: v for k, v in payload.items() if v is not None}
if self.model is not None:
payload["settings"] = {self.provider: self.model}
async with ClientSession() as session:
async with session.post(url, json=payload, headers=headers) as response:
response.raise_for_status()
async for chunk_response in response.content:
chunk = json.loads(chunk_response.decode())
token = chunk["text"]
cg_chunk = ChatGenerationChunk(
message=AIMessageChunk(content=token)
)
if run_manager:
await run_manager.on_llm_new_token(
token=chunk["text"], chunk=cg_chunk
)
yield cg_chunk
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
formatted_tools = [convert_to_openai_tool(tool)["function"] for tool in tools]
formatted_tool_choice = "required" if tool_choice == "any" else tool_choice
return super().bind(
available_tools=formatted_tools, tool_choice=formatted_tool_choice, **kwargs
)
def with_structured_output(
self,
schema: Union[Dict, Type[BaseModel]],
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
llm = self.bind_tools([schema], tool_choice="required")
if isinstance(schema, type) and is_basemodel_subclass(schema):
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
)
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Call out to EdenAI's chat endpoint."""
if self.streaming:
if "available_tools" in kwargs:
warnings.warn(
"stream: Tool use is not yet supported in streaming mode."
)
else:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
url = f"{self.edenai_api_url}/text/chat"
headers = {
"Authorization": f"Bearer {self._api_key}",
"User-Agent": self.get_user_agent(),
}
formatted_data = _format_edenai_messages(messages=messages)
payload: Dict[str, Any] = {
"providers": self.provider,
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"fallback_providers": self.fallback_providers,
**formatted_data,
**kwargs,
}
payload = {k: v for k, v in payload.items() if v is not None}
if self.model is not None:
payload["settings"] = {self.provider: self.model}
request = Requests(headers=headers)
response = request.post(url=url, data=payload)
response.raise_for_status()
data = response.json()
provider_response = data[self.provider]
if self.fallback_providers:
fallback_response = data.get(self.fallback_providers)
if fallback_response:
provider_response = fallback_response
if provider_response.get("status") == "fail":
err_msg = provider_response.get("error", {}).get("message")
raise Exception(err_msg)
tool_calls, invalid_tool_calls = _extract_tool_calls_from_edenai_response(
provider_response
)
return ChatResult(
generations=[
ChatGeneration(
message=AIMessage(
content=provider_response["generated_text"] or "",
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
)
],
llm_output=data,
)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
if "available_tools" in kwargs:
warnings.warn(
"stream: Tool use is not yet supported in streaming mode."
)
else:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
url = f"{self.edenai_api_url}/text/chat"
headers = {
"Authorization": f"Bearer {self._api_key}",
"User-Agent": self.get_user_agent(),
}
formatted_data = _format_edenai_messages(messages=messages)
payload: Dict[str, Any] = {
"providers": self.provider,
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"fallback_providers": self.fallback_providers,
**formatted_data,
**kwargs,
}
payload = {k: v for k, v in payload.items() if v is not None}
if self.model is not None:
payload["settings"] = {self.provider: self.model}
async with ClientSession() as session:
async with session.post(url, json=payload, headers=headers) as response:
response.raise_for_status()
data = await response.json()
provider_response = data[self.provider]
if self.fallback_providers:
fallback_response = data.get(self.fallback_providers)
if fallback_response:
provider_response = fallback_response
if provider_response.get("status") == "fail":
err_msg = provider_response.get("error", {}).get("message")
raise Exception(err_msg)
return ChatResult(
generations=[
ChatGeneration(
message=AIMessage(
content=provider_response["generated_text"]
)
)
],
llm_output=data,
)
def _stream_with_tools_as_generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]],
run_manager: Optional[CallbackManagerForLLMRun],
**kwargs: Any,
) -> ChatGenerationChunk:
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
result = self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
return _result_to_chunked_message(result)
async def _astream_with_tools_as_agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]],
run_manager: Optional[AsyncCallbackManagerForLLMRun],
**kwargs: Any,
) -> ChatGenerationChunk:
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return _result_to_chunked_message(result)

View File

@@ -0,0 +1,229 @@
import logging
import threading
from typing import Any, Dict, List, Mapping, Optional
import requests
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.utils import get_from_dict_or_env
from pydantic import model_validator
logger = logging.getLogger(__name__)
def _convert_message_to_dict(message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
@deprecated(
since="0.0.13",
alternative="langchain_community.chat_models.QianfanChatEndpoint",
)
class ErnieBotChat(BaseChatModel):
"""`ERNIE-Bot` large language model.
ERNIE-Bot is a large language model developed by Baidu,
covering a huge amount of Chinese data.
To use, you should have the `ernie_client_id` and `ernie_client_secret` set,
or set the environment variable `ERNIE_CLIENT_ID` and `ERNIE_CLIENT_SECRET`.
Note:
access_token will be automatically generated based on client_id and client_secret,
and will be regenerated after expiration (30 days).
Default model is `ERNIE-Bot-turbo`,
currently supported models are `ERNIE-Bot-turbo`, `ERNIE-Bot`, `ERNIE-Bot-8K`,
`ERNIE-Bot-4`, `ERNIE-Bot-turbo-AI`.
Example:
.. code-block:: python
from langchain_community.chat_models import ErnieBotChat
chat = ErnieBotChat(model_name='ERNIE-Bot')
Deprecated Note:
Please use `QianfanChatEndpoint` instead of this class.
`QianfanChatEndpoint` is a more suitable choice for production.
Always test your code after changing to `QianfanChatEndpoint`.
Example of `QianfanChatEndpoint`:
.. code-block:: python
from langchain_community.chat_models import QianfanChatEndpoint
qianfan_chat = QianfanChatEndpoint(model="ERNIE-Bot",
endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
"""
ernie_api_base: Optional[str] = None
"""Baidu application custom endpoints"""
ernie_client_id: Optional[str] = None
"""Baidu application client id"""
ernie_client_secret: Optional[str] = None
"""Baidu application client secret"""
access_token: Optional[str] = None
"""access token is generated by client id and client secret,
setting this value directly will cause an error"""
model_name: str = "ERNIE-Bot-turbo"
"""model name of ernie, default is `ERNIE-Bot-turbo`.
Currently supported `ERNIE-Bot-turbo`, `ERNIE-Bot`"""
system: Optional[str] = None
"""system is mainly used for model character design,
for example, you are an AI assistant produced by xxx company.
The length of the system is limiting of 1024 characters."""
request_timeout: Optional[int] = 60
"""request timeout for chat http requests"""
streaming: Optional[bool] = False
"""streaming mode. not supported yet."""
top_p: Optional[float] = 0.8
temperature: Optional[float] = 0.95
penalty_score: Optional[float] = 1
_lock = threading.Lock()
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
values["ernie_api_base"] = get_from_dict_or_env(
values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
)
values["ernie_client_id"] = get_from_dict_or_env(
values,
"ernie_client_id",
"ERNIE_CLIENT_ID",
)
values["ernie_client_secret"] = get_from_dict_or_env(
values,
"ernie_client_secret",
"ERNIE_CLIENT_SECRET",
)
return values
def _chat(self, payload: object) -> dict:
base_url = f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/chat"
model_paths = {
"ERNIE-Bot-turbo": "eb-instant",
"ERNIE-Bot": "completions",
"ERNIE-Bot-8K": "ernie_bot_8k",
"ERNIE-Bot-4": "completions_pro",
"ERNIE-Bot-turbo-AI": "ai_apaas",
"BLOOMZ-7B": "bloomz_7b1",
"Llama-2-7b-chat": "llama_2_7b",
"Llama-2-13b-chat": "llama_2_13b",
"Llama-2-70b-chat": "llama_2_70b",
}
if self.model_name in model_paths:
url = f"{base_url}/{model_paths[self.model_name]}"
else:
raise ValueError(f"Got unknown model_name {self.model_name}")
resp = requests.post(
url,
timeout=self.request_timeout,
headers={
"Content-Type": "application/json",
},
params={"access_token": self.access_token},
json=payload,
)
return resp.json()
def _refresh_access_token_with_lock(self) -> None:
with self._lock:
logger.debug("Refreshing access token")
base_url: str = f"{self.ernie_api_base}/oauth/2.0/token"
resp = requests.post(
base_url,
timeout=10,
headers={
"Content-Type": "application/json",
"Accept": "application/json",
},
params={
"grant_type": "client_credentials",
"client_id": self.ernie_client_id,
"client_secret": self.ernie_client_secret,
},
)
self.access_token = str(resp.json().get("access_token"))
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
raise ValueError("`streaming` option currently unsupported.")
if not self.access_token:
self._refresh_access_token_with_lock()
payload = {
"messages": [_convert_message_to_dict(m) for m in messages],
"top_p": self.top_p,
"temperature": self.temperature,
"penalty_score": self.penalty_score,
"system": self.system,
**kwargs,
}
logger.debug(f"Payload for ernie api is {payload}")
resp = self._chat(payload)
if resp.get("error_code"):
if resp.get("error_code") == 111:
logger.debug("access_token expired, refresh it")
self._refresh_access_token_with_lock()
resp = self._chat(payload)
else:
raise ValueError(f"Error from ErnieChat api response: {resp}")
return self._create_chat_result(resp)
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
if "function_call" in response:
additional_kwargs = {
"function_call": dict(response.get("function_call", {}))
}
else:
additional_kwargs = {}
generations = [
ChatGeneration(
message=AIMessage(
content=response.get("result", ""),
additional_kwargs={**additional_kwargs},
)
)
]
token_usage = response.get("usage", {})
llm_output = {"token_usage": token_usage, "model_name": self.model_name}
return ChatResult(generations=generations, llm_output=llm_output)
@property
def _llm_type(self) -> str:
return "ernie-bot-chat"

View File

@@ -0,0 +1,185 @@
"""EverlyAI Endpoints chat wrapper. Relies heavily on ChatOpenAI."""
from __future__ import annotations
import logging
import sys
import warnings
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Optional,
Sequence,
Set,
Type,
Union,
)
from langchain_core.messages import BaseMessage
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from pydantic import Field, model_validator
from langchain_community.adapters.openai import convert_message_to_dict
from langchain_community.chat_models.openai import (
ChatOpenAI,
_import_tiktoken,
)
if TYPE_CHECKING:
import tiktoken
logger = logging.getLogger(__name__)
DEFAULT_API_BASE = "https://everlyai.xyz/hosted"
DEFAULT_MODEL = "meta-llama/Llama-2-7b-chat-hf"
class ChatEverlyAI(ChatOpenAI):
"""`EverlyAI` Chat large language models.
To use, you should have the ``openai`` python package installed, and the
environment variable ``EVERLYAI_API_KEY`` set with your API key.
Alternatively, you can use the everlyai_api_key keyword argument.
Any parameters that are valid to be passed to the `openai.create` call can be passed
in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatEverlyAI
chat = ChatEverlyAI(model_name="meta-llama/Llama-2-7b-chat-hf")
"""
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "everlyai-chat"
@property
def lc_secrets(self) -> Dict[str, str]:
return {"everlyai_api_key": "EVERLYAI_API_KEY"}
@classmethod
def is_lc_serializable(cls) -> bool:
return False
everlyai_api_key: Optional[str] = None
"""EverlyAI Endpoints API keys."""
model_name: str = Field(default=DEFAULT_MODEL, alias="model")
"""Model name to use."""
everlyai_api_base: str = DEFAULT_API_BASE
"""Base URL path for API requests."""
available_models: Optional[Set[str]] = None
"""Available models from EverlyAI API."""
@staticmethod
def get_available_models() -> Set[str]:
"""Get available models from EverlyAI API."""
# EverlyAI doesn't yet support dynamically query for available models.
return set(
[
"meta-llama/Llama-2-7b-chat-hf",
"meta-llama/Llama-2-13b-chat-hf-quantized",
]
)
@model_validator(mode="before")
@classmethod
def validate_environment_override(cls, values: dict) -> Any:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"everlyai_api_key",
"EVERLYAI_API_KEY",
)
)
values["openai_api_base"] = DEFAULT_API_BASE
try:
import openai
except ImportError as e:
raise ImportError(
"Could not import openai python package. "
"Please install it with `pip install openai`.",
) from e
try:
values["client"] = openai.ChatCompletion
except AttributeError as exc:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`.",
) from exc
if "model_name" not in values.keys():
values["model_name"] = DEFAULT_MODEL
model_name = values["model_name"]
available_models = cls.get_available_models()
if model_name not in available_models:
raise ValueError(
f"Model name {model_name} not found in available models: "
f"{available_models}.",
)
values["available_models"] = available_models
return values
def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
tiktoken_ = _import_tiktoken()
if self.tiktoken_model_name is not None:
model = self.tiktoken_model_name
else:
model = self.model_name
# Returns the number of tokens used by a list of messages.
try:
encoding = tiktoken_.encoding_for_model("gpt-3.5-turbo-0301")
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken_.get_encoding(model)
return model, encoding
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[
Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
] = None,
) -> int:
"""Calculate num tokens with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if tools is not None:
warnings.warn(
"Counting tokens in tool schemas is not yet supported. Ignoring tools."
)
if sys.version_info[1] <= 7:
return super().get_num_tokens_from_messages(messages)
model, encoding = self._get_encoding_model()
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [convert_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
num_tokens += len(encoding.encode(str(value)))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
return num_tokens

View File

@@ -0,0 +1,105 @@
"""Fake ChatModel for testing purposes."""
import asyncio
import time
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
class FakeMessagesListChatModel(BaseChatModel):
"""Fake ChatModel for testing purposes."""
responses: List[BaseMessage]
sleep: Optional[float] = None
i: int = 0
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
generation = ChatGeneration(message=response)
return ChatResult(generations=[generation])
@property
def _llm_type(self) -> str:
return "fake-messages-list-chat-model"
class FakeListChatModel(SimpleChatModel):
"""Fake ChatModel for testing purposes."""
responses: List
sleep: Optional[float] = None
i: int = 0
@property
def _llm_type(self) -> str:
return "fake-list-chat-model"
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""First try to lookup in queries, else return 'foo' or 'bar'."""
response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
return response
def _stream(
self,
messages: List[BaseMessage],
stop: Union[List[str], None] = None,
run_manager: Union[CallbackManagerForLLMRun, None] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
for c in response:
if self.sleep is not None:
time.sleep(self.sleep)
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
async def _astream(
self,
messages: List[BaseMessage],
stop: Union[List[str], None] = None,
run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
for c in response:
if self.sleep is not None:
await asyncio.sleep(self.sleep)
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
@property
def _identifying_params(self) -> Dict[str, Any]:
return {"responses": self.responses}

View File

@@ -0,0 +1,372 @@
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
Type,
Union,
)
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import convert_to_secret_str
from langchain_core.utils.env import get_from_dict_or_env
from pydantic import Field, SecretStr, model_validator
from langchain_community.adapters.openai import convert_message_to_dict
def _convert_delta_to_message_chunk(
_dict: Any, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
"""Convert a delta response to a message chunk."""
role = _dict.role
content = _dict.content or ""
additional_kwargs: Dict = {}
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict.name)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content) # type: ignore[call-arg]
def convert_dict_to_message(_dict: Any) -> BaseMessage:
"""Convert a dict response to a message."""
role = _dict.role
content = _dict.content or ""
if role == "user":
return HumanMessage(content=content)
elif role == "assistant":
content = _dict.content
additional_kwargs: Dict = {}
return AIMessage(content=content, additional_kwargs=additional_kwargs)
elif role == "system":
return SystemMessage(content=content)
elif role == "function":
return FunctionMessage(content=content, name=_dict.name)
else:
return ChatMessage(content=content, role=role)
@deprecated(
since="0.0.26",
removal="1.0",
alternative_import="langchain_fireworks.ChatFireworks",
)
class ChatFireworks(BaseChatModel):
"""Fireworks Chat models."""
model: str = "accounts/fireworks/models/llama-v2-7b-chat"
model_kwargs: dict = Field(
default_factory=lambda: {
"temperature": 0.7,
"max_tokens": 512,
"top_p": 1,
}.copy()
)
fireworks_api_key: Optional[SecretStr] = None
max_retries: int = 20
use_retry: bool = True
@property
def lc_secrets(self) -> Dict[str, str]:
return {"fireworks_api_key": "FIREWORKS_API_KEY"}
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "fireworks"]
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
"""Validate that api key in environment."""
try:
import fireworks.client
except ImportError as e:
raise ImportError(
"Could not import fireworks-ai python package. "
"Please install it with `pip install fireworks-ai`."
) from e
fireworks_api_key = convert_to_secret_str(
get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
)
fireworks.client.api_key = fireworks_api_key.get_secret_value()
return values
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "fireworks-chat"
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts = self._create_message_dicts(messages)
params = {
"model": self.model,
"messages": message_dicts,
**self.model_kwargs,
**kwargs,
}
response = completion_with_retry(
self,
self.use_retry,
run_manager=run_manager,
stop=stop,
**params,
)
return self._create_chat_result(response)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts = self._create_message_dicts(messages)
params = {
"model": self.model,
"messages": message_dicts,
**self.model_kwargs,
**kwargs,
}
response = await acompletion_with_retry(
self, self.use_retry, run_manager=run_manager, stop=stop, **params
)
return self._create_chat_result(response)
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
if llm_outputs[0] is None:
return {}
return llm_outputs[0]
def _create_chat_result(self, response: Any) -> ChatResult:
generations = []
for res in response.choices:
message = convert_dict_to_message(res.message)
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=res.finish_reason),
)
generations.append(gen)
llm_output = {"model": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts(
self, messages: List[BaseMessage]
) -> List[Dict[str, Any]]:
message_dicts = [convert_message_to_dict(m) for m in messages]
return message_dicts
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts = self._create_message_dicts(messages)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
params = {
"model": self.model,
"messages": message_dicts,
"stream": True,
**self.model_kwargs,
**kwargs,
}
for chunk in completion_with_retry(
self, self.use_retry, run_manager=run_manager, stop=stop, **params
):
choice = chunk.choices[0]
chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
finish_reason = choice.finish_reason
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
if run_manager:
run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
yield cg_chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts = self._create_message_dicts(messages)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
params = {
"model": self.model,
"messages": message_dicts,
"stream": True,
**self.model_kwargs,
**kwargs,
}
async for chunk in await acompletion_with_retry_streaming(
self, self.use_retry, run_manager=run_manager, stop=stop, **params
):
choice = chunk.choices[0]
chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
finish_reason = choice.finish_reason
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
if run_manager:
await run_manager.on_llm_new_token(token=cg_chunk.text, chunk=cg_chunk)
yield cg_chunk
def conditional_decorator(
condition: bool, decorator: Callable[[Any], Any]
) -> Callable[[Any], Any]:
"""Define conditional decorator.
Args:
condition: The condition.
decorator: The decorator.
Returns:
The decorated function.
"""
def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
if condition:
return decorator(func)
return func
return actual_decorator
def completion_with_retry(
llm: ChatFireworks,
use_retry: bool,
*,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
import fireworks.client
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@conditional_decorator(use_retry, retry_decorator)
def _completion_with_retry(**kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
return fireworks.client.ChatCompletion.create(
**kwargs,
)
return _completion_with_retry(**kwargs)
async def acompletion_with_retry(
llm: ChatFireworks,
use_retry: bool,
*,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the async completion call."""
import fireworks.client
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@conditional_decorator(use_retry, retry_decorator)
async def _completion_with_retry(**kwargs: Any) -> Any:
return await fireworks.client.ChatCompletion.acreate(
**kwargs,
)
return await _completion_with_retry(**kwargs)
async def acompletion_with_retry_streaming(
llm: ChatFireworks,
use_retry: bool,
*,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call for streaming."""
import fireworks.client
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@conditional_decorator(use_retry, retry_decorator)
async def _completion_with_retry(**kwargs: Any) -> Any:
return fireworks.client.ChatCompletion.acreate(
**kwargs,
)
return await _completion_with_retry(**kwargs)
def _create_retry_decorator(
llm: ChatFireworks,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Define retry mechanism."""
import fireworks.client
errors = [
fireworks.client.error.RateLimitError,
fireworks.client.error.InternalServerError,
fireworks.client.error.BadGatewayError,
fireworks.client.error.ServiceUnavailableError,
]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)

View File

@@ -0,0 +1,217 @@
from __future__ import annotations
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_community.llms.friendli import BaseFriendli
def get_role(message: BaseMessage) -> str:
"""Get role of the message.
Args:
message (BaseMessage): The message object.
Raises:
ValueError: Raised when the message is of an unknown type.
Returns:
str: The role of the message.
"""
if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
return "user"
if isinstance(message, AIMessage):
return "assistant"
if isinstance(message, SystemMessage):
return "system"
raise ValueError(f"Got unknown type {message}")
def get_chat_request(messages: List[BaseMessage]) -> Dict[str, Any]:
"""Get a request of the Friendli chat API.
Args:
messages (List[BaseMessage]): Messages comprising the conversation so far.
Returns:
Dict[str, Any]: The request for the Friendli chat API.
"""
return {
"messages": [
{"role": get_role(message), "content": message.content}
for message in messages
]
}
class ChatFriendli(BaseChatModel, BaseFriendli):
"""Friendli LLM for chat.
``friendli-client`` package should be installed with `pip install friendli-client`.
You must set ``FRIENDLI_TOKEN`` environment variable or provide the value of your
personal access token for the ``friendli_token`` argument.
Example:
.. code-block:: python
from langchain_community.chat_models import FriendliChat
chat = Friendli(
model="meta-llama-3.1-8b-instruct", friendli_token="YOUR FRIENDLI TOKEN"
)
chat.invoke("What is generative AI?")
"""
model: str = "meta-llama-3.1-8b-instruct"
@property
def lc_secrets(self) -> Dict[str, str]:
return {"friendli_token": "FRIENDLI_TOKEN"}
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Friendli completions API."""
return {
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"max_tokens": self.max_tokens,
"stop": self.stop,
"temperature": self.temperature,
"top_p": self.top_p,
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {"model": self.model, **self._default_params}
@property
def _llm_type(self) -> str:
return "friendli-chat"
def _get_invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
"""Get the parameters used to invoke the model."""
params = self._default_params
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
params["stop"] = self.stop
else:
params["stop"] = stop
return {**params, **kwargs}
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._get_invocation_params(stop=stop, **kwargs)
stream = self.client.chat.completions.create(
**get_chat_request(messages), stream=True, model=self.model, **params
)
for chunk in stream:
delta = chunk.choices[0].delta.content
if delta:
if run_manager:
run_manager.on_llm_new_token(delta)
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
params = self._get_invocation_params(stop=stop, **kwargs)
stream = await self.async_client.chat.completions.create(
**get_chat_request(messages), stream=True, model=self.model, **params
)
async for chunk in stream:
delta = chunk.choices[0].delta.content
if delta:
if run_manager:
await run_manager.on_llm_new_token(delta)
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
params = self._get_invocation_params(stop=stop, **kwargs)
response = self.client.chat.completions.create(
messages=[
{
"role": get_role(message),
"content": message.content,
}
for message in messages
],
stream=False,
model=self.model,
**params,
)
message = AIMessage(content=response.choices[0].message.content)
return ChatResult(generations=[ChatGeneration(message=message)])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
params = self._get_invocation_params(stop=stop, **kwargs)
response = await self.async_client.chat.completions.create(
messages=[
{
"role": get_role(message),
"content": message.content,
}
for message in messages
],
stream=False,
model=self.model,
**params,
)
message = AIMessage(content=response.choices[0].message.content)
return ChatResult(generations=[ChatGeneration(message=message)])

View File

@@ -0,0 +1,280 @@
from __future__ import annotations
import logging
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterator,
List,
Mapping,
Optional,
Type,
)
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_community.llms.gigachat import _BaseGigaChat
if TYPE_CHECKING:
import gigachat.models as gm
logger = logging.getLogger(__name__)
def _convert_dict_to_message(message: gm.Messages) -> BaseMessage:
from gigachat.models import FunctionCall, MessagesRole
additional_kwargs: Dict = {}
if function_call := message.function_call:
if isinstance(function_call, FunctionCall):
additional_kwargs["function_call"] = dict(function_call)
elif isinstance(function_call, dict):
additional_kwargs["function_call"] = function_call
if message.role == MessagesRole.SYSTEM:
return SystemMessage(content=message.content)
elif message.role == MessagesRole.USER:
return HumanMessage(content=message.content)
elif message.role == MessagesRole.ASSISTANT:
return AIMessage(content=message.content, additional_kwargs=additional_kwargs)
else:
raise TypeError(f"Got unknown role {message.role} {message}")
def _convert_message_to_dict(message: gm.BaseMessage) -> gm.Messages:
from gigachat.models import Messages, MessagesRole
if isinstance(message, SystemMessage):
return Messages(role=MessagesRole.SYSTEM, content=message.content)
elif isinstance(message, HumanMessage):
return Messages(role=MessagesRole.USER, content=message.content)
elif isinstance(message, AIMessage):
return Messages(
role=MessagesRole.ASSISTANT,
content=message.content,
function_call=message.additional_kwargs.get("function_call", None),
)
elif isinstance(message, ChatMessage):
return Messages(role=MessagesRole(message.role), content=message.content)
elif isinstance(message, FunctionMessage):
return Messages(role=MessagesRole.FUNCTION, content=message.content)
else:
raise TypeError(f"Got unknown type {message}")
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
additional_kwargs: Dict = {}
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
function_call["name"] = ""
additional_kwargs["function_call"] = function_call
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
else:
return default_class(content=content) # type: ignore[call-arg]
@deprecated(
since="0.3.5",
removal="1.0",
alternative_import="langchain_gigachat.GigaChat",
)
class GigaChat(_BaseGigaChat, BaseChatModel):
"""`GigaChat` large language models API.
To use, you should pass login and password to access GigaChat API or use token.
Example:
.. code-block:: python
from langchain_community.chat_models import GigaChat
giga = GigaChat(credentials=..., scope=..., verify_ssl_certs=...)
"""
def _build_payload(self, messages: List[BaseMessage], **kwargs: Any) -> gm.Chat:
from gigachat.models import Chat
payload = Chat(
messages=[_convert_message_to_dict(m) for m in messages],
)
payload.functions = kwargs.get("functions", None)
payload.model = self.model
if self.profanity_check is not None:
payload.profanity_check = self.profanity_check
if self.temperature is not None:
payload.temperature = self.temperature
if self.top_p is not None:
payload.top_p = self.top_p
if self.max_tokens is not None:
payload.max_tokens = self.max_tokens
if self.repetition_penalty is not None:
payload.repetition_penalty = self.repetition_penalty
if self.update_interval is not None:
payload.update_interval = self.update_interval
if self.verbose:
logger.warning("Giga request: %s", payload.dict())
return payload
def _create_chat_result(self, response: Any) -> ChatResult:
generations = []
for res in response.choices:
message = _convert_dict_to_message(res.message)
finish_reason = res.finish_reason
gen = ChatGeneration(
message=message,
generation_info={"finish_reason": finish_reason},
)
generations.append(gen)
if finish_reason != "stop":
logger.warning(
"Giga generation stopped with reason: %s",
finish_reason,
)
if self.verbose:
logger.warning("Giga response: %s", message.content)
llm_output = {"token_usage": response.usage, "model_name": response.model}
return ChatResult(generations=generations, llm_output=llm_output)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
payload = self._build_payload(messages, **kwargs)
response = self._client.chat(payload)
return self._create_chat_result(response)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
payload = self._build_payload(messages, **kwargs)
response = await self._client.achat(payload)
return self._create_chat_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
payload = self._build_payload(messages, **kwargs)
for chunk in self._client.stream(payload):
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
content = choice.get("delta", {}).get("content", {})
chunk = _convert_delta_to_message_chunk(choice["delta"], AIMessageChunk)
finish_reason = choice.get("finish_reason")
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
if run_manager:
run_manager.on_llm_new_token(content)
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
payload = self._build_payload(messages, **kwargs)
async for chunk in self._client.astream(payload):
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
content = choice.get("delta", {}).get("content", {})
chunk = _convert_delta_to_message_chunk(choice["delta"], AIMessageChunk)
finish_reason = choice.get("finish_reason")
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
if run_manager:
await run_manager.on_llm_new_token(content)
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)

View File

@@ -0,0 +1,355 @@
"""Wrapper around Google's PaLM Chat API."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatResult,
)
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from pydantic import BaseModel, SecretStr
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
if TYPE_CHECKING:
import google.generativeai as genai
logger = logging.getLogger(__name__)
class ChatGooglePalmError(Exception):
"""Error with the `Google PaLM` API."""
def _truncate_at_stop_tokens(
text: str,
stop: Optional[List[str]],
) -> str:
"""Truncates text at the earliest stop token found."""
if stop is None:
return text
for stop_token in stop:
stop_token_idx = text.find(stop_token)
if stop_token_idx != -1:
text = text[:stop_token_idx]
return text
def _response_to_result(
response: genai.types.ChatResponse,
stop: Optional[List[str]],
) -> ChatResult:
"""Converts a PaLM API response into a LangChain ChatResult."""
if not response.candidates:
raise ChatGooglePalmError("ChatResponse must have at least one candidate.")
generations: List[ChatGeneration] = []
for candidate in response.candidates:
author = candidate.get("author")
if author is None:
raise ChatGooglePalmError(f"ChatResponse must have an author: {candidate}")
content = _truncate_at_stop_tokens(candidate.get("content", ""), stop)
if content is None:
raise ChatGooglePalmError(f"ChatResponse must have a content: {candidate}")
if author == "ai":
generations.append(
ChatGeneration(text=content, message=AIMessage(content=content))
)
elif author == "human":
generations.append(
ChatGeneration(
text=content,
message=HumanMessage(content=content),
)
)
else:
generations.append(
ChatGeneration(
text=content,
message=ChatMessage(role=author, content=content),
)
)
return ChatResult(generations=generations)
def _messages_to_prompt_dict(
input_messages: List[BaseMessage],
) -> genai.types.MessagePromptDict:
"""Converts a list of LangChain messages into a PaLM API MessagePrompt structure."""
import google.generativeai as genai
context: str = ""
examples: List[genai.types.MessageDict] = []
messages: List[genai.types.MessageDict] = []
remaining = list(enumerate(input_messages))
while remaining:
index, input_message = remaining.pop(0)
if isinstance(input_message, SystemMessage):
if index != 0:
raise ChatGooglePalmError("System message must be first input message.")
context = cast(str, input_message.content)
elif isinstance(
input_message, HumanMessage
) and input_message.additional_kwargs.get("example"):
if messages:
raise ChatGooglePalmError(
"Message examples must come before other messages."
)
_, next_input_message = remaining.pop(0)
if isinstance(
next_input_message, AIMessage
) and next_input_message.additional_kwargs.get("example"):
examples.extend(
[
genai.types.MessageDict(
author="human", content=input_message.content
),
genai.types.MessageDict(
author="ai", content=next_input_message.content
),
]
)
else:
raise ChatGooglePalmError(
"Human example message must be immediately followed by an "
" AI example response."
)
elif isinstance(
input_message, AIMessage
) and input_message.additional_kwargs.get("example"):
raise ChatGooglePalmError(
"AI example message must be immediately preceded by a Human "
"example message."
)
elif isinstance(input_message, AIMessage):
messages.append(
genai.types.MessageDict(author="ai", content=input_message.content)
)
elif isinstance(input_message, HumanMessage):
messages.append(
genai.types.MessageDict(author="human", content=input_message.content)
)
elif isinstance(input_message, ChatMessage):
messages.append(
genai.types.MessageDict(
author=input_message.role, content=input_message.content
)
)
else:
raise ChatGooglePalmError(
"Messages without an explicit role not supported by PaLM API."
)
return genai.types.MessagePromptDict(
context=context,
examples=examples,
messages=messages,
)
def _create_retry_decorator() -> Callable[[Any], Any]:
"""Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
import google.api_core.exceptions
multiplier = 2
min_seconds = 1
max_seconds = 60
max_retries = 10
return retry(
reraise=True,
stop=stop_after_attempt(max_retries),
wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
| retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
| retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator()
@retry_decorator
def _chat_with_retry(**kwargs: Any) -> Any:
return llm.client.chat(**kwargs)
return _chat_with_retry(**kwargs)
async def achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator()
@retry_decorator
async def _achat_with_retry(**kwargs: Any) -> Any:
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
return await llm.client.chat_async(**kwargs)
return await _achat_with_retry(**kwargs)
class ChatGooglePalm(BaseChatModel, BaseModel):
"""`Google PaLM` Chat models API.
To use you must have the google.generativeai Python package installed and
either:
1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
2. Pass your API key using the google_api_key kwarg to the ChatGoogle
constructor.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatGooglePalm
chat = ChatGooglePalm()
"""
client: Any #: :meta private:
model_name: str = "models/chat-bison-001"
"""Model name to use."""
google_api_key: Optional[SecretStr] = None
temperature: Optional[float] = None
"""Run inference with this temperature. Must be in the closed
interval [0.0, 1.0]."""
top_p: Optional[float] = None
"""Decode using nucleus sampling: consider the smallest set of tokens whose
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
top_k: Optional[int] = None
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
Must be positive."""
n: int = 1
"""Number of chat completions to generate for each prompt. Note that the API may
not return the full n completions if duplicates are generated."""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"google_api_key": "GOOGLE_API_KEY"}
@classmethod
def is_lc_serializable(self) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "google_palm"]
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, top_p, and top_k."""
google_api_key = convert_to_secret_str(
get_from_dict_or_env(values, "google_api_key", "GOOGLE_API_KEY")
)
try:
import google.generativeai as genai
genai.configure(api_key=google_api_key.get_secret_value())
except ImportError:
raise ChatGooglePalmError(
"Could not import google.generativeai python package. "
"Please install it with `pip install google-generativeai`"
)
values["client"] = genai
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")
if values["top_k"] is not None and values["top_k"] <= 0:
raise ValueError("top_k must be positive")
return values
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
prompt = _messages_to_prompt_dict(messages)
response: genai.types.ChatResponse = chat_with_retry(
self,
model=self.model_name,
prompt=prompt,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
candidate_count=self.n,
**kwargs,
)
return _response_to_result(response, stop)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
prompt = _messages_to_prompt_dict(messages)
response: genai.types.ChatResponse = await achat_with_retry(
self,
model=self.model_name,
prompt=prompt,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
candidate_count=self.n,
)
return _response_to_result(response, stop)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {
"model_name": self.model_name,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"n": self.n,
}
@property
def _llm_type(self) -> str:
return "google-palm-chat"

View File

@@ -0,0 +1,401 @@
from __future__ import annotations
import logging
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Dict,
Generator,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from pydantic import BaseModel, Field, SecretStr, model_validator
from typing_extensions import Self
from langchain_community.adapters.openai import (
convert_dict_to_message,
convert_message_to_dict,
)
from langchain_community.chat_models.openai import _convert_delta_to_message_chunk
if TYPE_CHECKING:
from gpt_router.models import ChunkedGenerationResponse, GenerationResponse
logger = logging.getLogger(__name__)
DEFAULT_API_BASE_URL = "https://gpt-router-preview.writesonic.com"
class GPTRouterException(Exception):
"""Error with the `GPTRouter APIs`"""
class GPTRouterModel(BaseModel):
"""GPTRouter model."""
name: str
provider_name: str
def get_ordered_generation_requests(
models_priority_list: List[GPTRouterModel], **kwargs: Any
) -> List:
"""
Return the body for the model router input.
"""
from gpt_router.models import GenerationParams, ModelGenerationRequest
return [
ModelGenerationRequest(
model_name=model.name,
provider_name=model.provider_name,
order=index + 1,
prompt_params=GenerationParams(**kwargs),
)
for index, model in enumerate(models_priority_list)
]
def _create_retry_decorator(
llm: GPTRouter,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
from gpt_router import exceptions
errors = [
exceptions.GPTRouterApiTimeoutError,
exceptions.GPTRouterInternalServerError,
exceptions.GPTRouterNotAvailableError,
exceptions.GPTRouterTooManyRequestsError,
]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)
def completion_with_retry(
llm: GPTRouter,
models_priority_list: List[GPTRouterModel],
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse, None, None]]:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
ordered_generation_requests = get_ordered_generation_requests(
models_priority_list, **kwargs
)
return llm.client.generate(
ordered_generation_requests=ordered_generation_requests,
is_stream=kwargs.get("stream", False),
)
return _completion_with_retry(**kwargs)
async def acompletion_with_retry(
llm: GPTRouter,
models_priority_list: List[GPTRouterModel],
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse, None]]:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
ordered_generation_requests = get_ordered_generation_requests(
models_priority_list, **kwargs
)
return await llm.client.agenerate(
ordered_generation_requests=ordered_generation_requests,
is_stream=kwargs.get("stream", False),
)
return await _completion_with_retry(**kwargs)
class GPTRouter(BaseChatModel):
"""GPTRouter by Writesonic Inc.
For more information, see https://gpt-router.writesonic.com/docs
"""
client: Any = Field(default=None, exclude=True) #: :meta private:
models_priority_list: List[GPTRouterModel] = Field(min_length=1)
gpt_router_api_base: str = Field(default="")
"""WriteSonic GPTRouter custom endpoint"""
gpt_router_api_key: Optional[SecretStr] = None
"""WriteSonic GPTRouter API Key"""
temperature: float = 0.7
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
max_retries: int = 4
"""Maximum number of retries to make when generating."""
streaming: bool = False
"""Whether to stream the results or not."""
n: int = 1
"""Number of chat completions to generate for each prompt."""
max_tokens: int = 256
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
values["gpt_router_api_base"] = get_from_dict_or_env(
values,
"gpt_router_api_base",
"GPT_ROUTER_API_BASE",
DEFAULT_API_BASE_URL,
)
values["gpt_router_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"gpt_router_api_key",
"GPT_ROUTER_API_KEY",
)
)
return values
@model_validator(mode="after")
def post_init(self) -> Self:
try:
from gpt_router.client import GPTRouterClient
except ImportError:
raise GPTRouterException(
"Could not import GPTRouter python package. "
"Please install it with `pip install GPTRouter`."
)
gpt_router_client = GPTRouterClient(
self.gpt_router_api_base,
self.gpt_router_api_key.get_secret_value()
if self.gpt_router_api_key
else None,
)
self.client = gpt_router_client
return self
@property
def lc_secrets(self) -> Dict[str, str]:
return {"gpt_router_api_key": "GPT_ROUTER_API_KEY"}
@property
def lc_serializable(self) -> bool:
return True
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "gpt-router-chat"
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {
**{"models_priority_list": self.models_priority_list},
**self._default_params,
}
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling GPTRouter API."""
return {
"max_tokens": self.max_tokens,
"stream": self.streaming,
"n": self.n,
"temperature": self.temperature,
**self.model_kwargs,
}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": False}
response = completion_with_retry(
self,
messages=message_dicts,
models_priority_list=self.models_priority_list,
run_manager=run_manager,
**params,
)
return self._create_chat_result(response)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": False}
response = await acompletion_with_retry(
self,
messages=message_dicts,
models_priority_list=self.models_priority_list,
run_manager=run_manager,
**params,
)
return self._create_chat_result(response)
def _create_chat_generation_chunk(
self, data: Mapping[str, Any], default_chunk_class: Type[BaseMessageChunk]
) -> Tuple[ChatGenerationChunk, Type[BaseMessageChunk]]:
chunk = _convert_delta_to_message_chunk(
{"content": data.get("text", "")}, default_chunk_class
)
finish_reason = data.get("finish_reason")
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
gen_chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
return gen_chunk, default_chunk_class
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
generator_response = completion_with_retry(
self,
messages=message_dicts,
models_priority_list=self.models_priority_list,
run_manager=run_manager,
**params,
)
for chunk in generator_response:
if chunk.event != "update":
continue
chunk, default_chunk_class = self._create_chat_generation_chunk(
chunk.data, default_chunk_class
)
if run_manager:
run_manager.on_llm_new_token(
token=str(chunk.message.content), chunk=chunk
)
yield chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
generator_response = acompletion_with_retry(
self,
messages=message_dicts,
models_priority_list=self.models_priority_list,
run_manager=run_manager,
**params,
)
async for chunk in await generator_response:
if chunk.event != "update":
continue
chunk, default_chunk_class = self._create_chat_generation_chunk(
chunk.data, default_chunk_class
)
if run_manager:
await run_manager.on_llm_new_token(
token=str(chunk.message.content), chunk=chunk
)
yield chunk
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = self._default_params
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _create_chat_result(self, response: GenerationResponse) -> ChatResult:
generations = []
for res in response.choices:
message = convert_dict_to_message(
{
"role": "assistant",
"content": res.text,
}
)
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=res.finish_reason),
)
generations.append(gen)
llm_output = {"token_usage": response.meta, "model": response.model}
return ChatResult(generations=generations, llm_output=llm_output)

View File

@@ -0,0 +1,235 @@
"""Hugging Face Chat Wrapper."""
from typing import Any, AsyncIterator, Iterator, List, Optional
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
LLMResult,
)
from pydantic import model_validator
from typing_extensions import Self
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_community.llms.huggingface_hub import HuggingFaceHub
from langchain_community.llms.huggingface_text_gen_inference import (
HuggingFaceTextGenInference,
)
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
@deprecated(
since="0.0.37",
removal="1.0",
alternative_import="langchain_huggingface.ChatHuggingFace",
)
class ChatHuggingFace(BaseChatModel):
"""
Wrapper for using Hugging Face LLM's as ChatModels.
Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
and `HuggingFaceHub` LLMs.
Upon instantiating this class, the model_id is resolved from the url
provided to the LLM, and the appropriate tokenizer is loaded from
the HuggingFace Hub.
Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat
"""
llm: Any
"""LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint, or
HuggingFaceHub."""
system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
tokenizer: Any = None
model_id: Optional[str] = None
streaming: bool = False
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
from transformers import AutoTokenizer
self._resolve_model_id()
self.tokenizer = (
AutoTokenizer.from_pretrained(self.model_id)
if self.tokenizer is None
else self.tokenizer
)
@model_validator(mode="after")
def validate_llm(self) -> Self:
if not isinstance(
self.llm,
(HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub),
):
raise TypeError(
"Expected llm to be one of HuggingFaceTextGenInference, "
f"HuggingFaceEndpoint, HuggingFaceHub, received {type(self.llm)}"
)
return self
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
request = self._to_chat_prompt(messages)
for data in self.llm.stream(request, **kwargs):
delta = data
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
request = self._to_chat_prompt(messages)
async for data in self.llm.astream(request, **kwargs):
delta = data
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
await run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
llm_input = self._to_chat_prompt(messages)
llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
llm_input = self._to_chat_prompt(messages)
llm_result = await self.llm._agenerate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
def _to_chat_prompt(
self,
messages: List[BaseMessage],
) -> str:
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
if not messages:
raise ValueError("At least one HumanMessage must be provided!")
if not isinstance(messages[-1], HumanMessage):
raise ValueError("Last message must be a HumanMessage!")
messages_dicts = [self._to_chatml_format(m) for m in messages]
return self.tokenizer.apply_chat_template(
messages_dicts, tokenize=False, add_generation_prompt=True
)
def _to_chatml_format(self, message: BaseMessage) -> dict:
"""Convert LangChain message to ChatML format."""
if isinstance(message, SystemMessage):
role = "system"
elif isinstance(message, AIMessage):
role = "assistant"
elif isinstance(message, HumanMessage):
role = "user"
else:
raise ValueError(f"Unknown message type: {type(message)}")
return {"role": role, "content": message.content}
@staticmethod
def _to_chat_result(llm_result: LLMResult) -> ChatResult:
chat_generations = []
for g in llm_result.generations[0]:
chat_generation = ChatGeneration(
message=AIMessage(content=g.text), generation_info=g.generation_info
)
chat_generations.append(chat_generation)
return ChatResult(
generations=chat_generations, llm_output=llm_result.llm_output
)
def _resolve_model_id(self) -> None:
"""Resolve the model_id from the LLM's inference_server_url"""
from huggingface_hub import list_inference_endpoints
available_endpoints = list_inference_endpoints("*")
if isinstance(self.llm, HuggingFaceHub) or (
hasattr(self.llm, "repo_id") and self.llm.repo_id
):
self.model_id = self.llm.repo_id
return
elif isinstance(self.llm, HuggingFaceTextGenInference):
endpoint_url: Optional[str] = self.llm.inference_server_url
else:
endpoint_url = self.llm.endpoint_url
for endpoint in available_endpoints:
if endpoint.url == endpoint_url:
self.model_id = endpoint.repository
if not self.model_id:
raise ValueError(
"Failed to resolve model_id:"
f"Could not find model id for inference server: {endpoint_url}"
"Make sure that your Hugging Face token has access to the endpoint."
)
@property
def _llm_type(self) -> str:
return "huggingface-chat-wrapper"

View File

@@ -0,0 +1,111 @@
"""ChatModel wrapper which returns user input as the response.."""
from io import StringIO
from typing import Any, Callable, Dict, List, Mapping, Optional
import yaml
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
BaseMessage,
HumanMessage,
_message_from_dict,
messages_to_dict,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from pydantic import Field
from langchain_community.llms.utils import enforce_stop_tokens
def _display_messages(messages: List[BaseMessage]) -> None:
dict_messages = messages_to_dict(messages)
for message in dict_messages:
yaml_string = yaml.dump(
message,
default_flow_style=False,
sort_keys=False,
allow_unicode=True,
width=10000,
line_break=None,
)
print("\n", "======= start of message =======", "\n\n") # noqa: T201
print(yaml_string) # noqa: T201
print("======= end of message =======", "\n\n") # noqa: T201
def _collect_yaml_input(
messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> BaseMessage:
"""Collects and returns user input as a single string."""
lines = []
while True:
line = input()
if not line.strip():
break
if stop and any(seq in line for seq in stop):
break
lines.append(line)
yaml_string = "\n".join(lines)
# Try to parse the input string as YAML
try:
message = _message_from_dict(yaml.safe_load(StringIO(yaml_string)))
if message is None:
return HumanMessage(content="")
if stop:
if isinstance(message.content, str):
message.content = enforce_stop_tokens(message.content, stop)
else:
raise ValueError("Cannot use when output is not a string.")
return message
except yaml.YAMLError:
raise ValueError("Invalid YAML string entered.")
except ValueError:
raise ValueError("Invalid message entered.")
class HumanInputChatModel(BaseChatModel):
"""ChatModel which returns user input as the response."""
input_func: Callable = Field(default_factory=lambda: _collect_yaml_input)
message_func: Callable = Field(default_factory=lambda: _display_messages)
separator: str = "\n"
input_kwargs: Mapping[str, Any] = {}
message_kwargs: Mapping[str, Any] = {}
@property
def _identifying_params(self) -> Dict[str, Any]:
return {
"input_func": self.input_func.__name__,
"message_func": self.message_func.__name__,
}
@property
def _llm_type(self) -> str:
"""Returns the type of LLM."""
return "human-input-chat-model"
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""
Displays the messages to the user and returns their input as a response.
Args:
messages (List[BaseMessage]): The messages to be displayed to the user.
stop (Optional[List[str]]): A list of stop strings.
run_manager (Optional[CallbackManagerForLLMRun]): Currently not used.
Returns:
ChatResult: The user's input as a response.
"""
self.message_func(messages, **self.message_kwargs)
user_input = self.input_func(messages, stop=stop, **self.input_kwargs)
return ChatResult(generations=[ChatGeneration(message=user_input)])

View File

@@ -0,0 +1,280 @@
import json
import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
pre_init,
)
from pydantic import ConfigDict, Field, SecretStr, model_validator
logger = logging.getLogger(__name__)
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"Role": message.role, "Content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"Role": "system", "Content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"Role": "user", "Content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"Role": "assistant", "Content": message.content}
else:
raise TypeError(f"Got unknown type {message}")
return message_dict
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["Role"]
if role == "system":
return SystemMessage(content=_dict.get("Content", "") or "")
elif role == "user":
return HumanMessage(content=_dict["Content"])
elif role == "assistant":
return AIMessage(content=_dict.get("Content", "") or "")
else:
return ChatMessage(content=_dict["Content"], role=role)
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("Role")
content = _dict.get("Content") or ""
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
else:
return default_class(content=content) # type: ignore[call-arg]
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
generations = []
for choice in response["Choices"]:
message = _convert_dict_to_message(choice["Message"])
message.id = response.get("Id", "")
generations.append(ChatGeneration(message=message))
token_usage = response["Usage"]
llm_output = {"token_usage": token_usage}
return ChatResult(generations=generations, llm_output=llm_output)
class ChatHunyuan(BaseChatModel):
"""Tencent Hunyuan chat models API by Tencent.
For more information, see https://cloud.tencent.com/document/product/1729
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {
"hunyuan_app_id": "HUNYUAN_APP_ID",
"hunyuan_secret_id": "HUNYUAN_SECRET_ID",
"hunyuan_secret_key": "HUNYUAN_SECRET_KEY",
}
@property
def lc_serializable(self) -> bool:
return True
hunyuan_app_id: Optional[int] = None
"""Hunyuan App ID"""
hunyuan_secret_id: Optional[str] = None
"""Hunyuan Secret ID"""
hunyuan_secret_key: Optional[SecretStr] = None
"""Hunyuan Secret Key"""
streaming: bool = False
"""Whether to stream the results or not."""
request_timeout: int = 60
"""Timeout for requests to Hunyuan API. Default is 60 seconds."""
temperature: float = 1.0
"""What sampling temperature to use."""
top_p: float = 1.0
"""What probability mass to use."""
model: str = "hunyuan-lite"
"""What Model to use.
Optional model:
- hunyuan-lite
- hunyuan-standard
- hunyuan-standard-256K
- hunyuan-pro
- hunyuan-code
- hunyuan-role
- hunyuan-functioncall
- hunyuan-vision
"""
stream_moderation: bool = False
"""Whether to review the results or not when streaming is true."""
enable_enhancement: bool = True
"""Whether to enhancement the results or not."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for API call not explicitly specified."""
model_config = ConfigDict(
populate_by_name=True,
)
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
logger.warning(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
values["hunyuan_app_id"] = get_from_dict_or_env(
values,
"hunyuan_app_id",
"HUNYUAN_APP_ID",
)
values["hunyuan_secret_id"] = get_from_dict_or_env(
values,
"hunyuan_secret_id",
"HUNYUAN_SECRET_ID",
)
values["hunyuan_secret_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"hunyuan_secret_key",
"HUNYUAN_SECRET_KEY",
)
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Hunyuan API."""
normal_params = {
"Temperature": self.temperature,
"TopP": self.top_p,
"Model": self.model,
"Stream": self.streaming,
"StreamModeration": self.stream_moderation,
"EnableEnhancement": self.enable_enhancement,
}
return {**normal_params, **self.model_kwargs}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
res = self._chat(messages, **kwargs)
return _create_chat_result(json.loads(res.to_json_string()))
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
res = self._chat(messages, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
for chunk in res:
chunk = chunk.get("data", "")
if len(chunk) == 0:
continue
response = json.loads(chunk)
if "error" in response:
raise ValueError(f"Error from Hunyuan api response: {response}")
for choice in response["Choices"]:
chunk = _convert_delta_to_message_chunk(
choice["Delta"], default_chunk_class
)
chunk.id = response.get("Id", "")
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
yield cg_chunk
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> Any:
if self.hunyuan_secret_key is None:
raise ValueError("Hunyuan secret key is not set.")
try:
from tencentcloud.common import credential
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
except ImportError:
raise ImportError(
"Could not import tencentcloud python package. "
"Please install it with `pip install tencentcloud-sdk-python`."
)
parameters = {**self._default_params, **kwargs}
cred = credential.Credential(
self.hunyuan_secret_id, str(self.hunyuan_secret_key.get_secret_value())
)
client = hunyuan_client.HunyuanClient(cred, "")
req = models.ChatCompletionsRequest()
params = {
"Messages": [_convert_message_to_dict(m) for m in messages],
**parameters,
}
req.from_json_string(json.dumps(params))
resp = client.ChatCompletions(req)
return resp
@property
def _llm_type(self) -> str:
return "hunyuan-chat"

View File

@@ -0,0 +1,228 @@
import logging
from typing import Any, Dict, List, Mapping, Optional, cast
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatResult,
)
from pydantic import BaseModel, ConfigDict, Field, SecretStr
logger = logging.getLogger(__name__)
# Ignoring type because below is valid pydantic code
# Unexpected keyword argument "extra" for "__init_subclass__" of "object" [call-arg]
class ChatParams(BaseModel, extra="allow"):
"""Parameters for the `Javelin AI Gateway` LLM."""
temperature: float = 0.0
stop: Optional[List[str]] = None
max_tokens: Optional[int] = None
class ChatJavelinAIGateway(BaseChatModel):
"""`Javelin AI Gateway` chat models API.
To use, you should have the ``javelin_sdk`` python package installed.
For more information, see https://docs.getjavelin.io
Example:
.. code-block:: python
from langchain_community.chat_models import ChatJavelinAIGateway
chat = ChatJavelinAIGateway(
gateway_uri="<javelin-ai-gateway-uri>",
route="<javelin-ai-gateway-chat-route>",
params={
"temperature": 0.1
}
)
"""
route: str
"""The route to use for the Javelin AI Gateway API."""
gateway_uri: Optional[str] = None
"""The URI for the Javelin AI Gateway API."""
params: Optional[ChatParams] = None
"""Parameters for the Javelin AI Gateway LLM."""
client: Any = None
"""javelin client."""
javelin_api_key: Optional[SecretStr] = Field(None, alias="api_key")
"""The API key for the Javelin AI Gateway."""
model_config = ConfigDict(
populate_by_name=True,
)
def __init__(self, **kwargs: Any):
try:
from javelin_sdk import (
JavelinClient,
UnauthorizedError,
)
except ImportError:
raise ImportError(
"Could not import javelin_sdk python package. "
"Please install it with `pip install javelin_sdk`."
)
super().__init__(**kwargs)
if self.gateway_uri:
try:
self.client = JavelinClient(
base_url=self.gateway_uri,
api_key=cast(SecretStr, self.javelin_api_key).get_secret_value(),
)
except UnauthorizedError as e:
raise ValueError("Javelin: Incorrect API Key.") from e
@property
def _default_params(self) -> Dict[str, Any]:
params: Dict[str, Any] = {
"gateway_uri": self.gateway_uri,
"javelin_api_key": cast(SecretStr, self.javelin_api_key).get_secret_value(),
"route": self.route,
**(self.params.dict() if self.params else {}),
}
return params
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts = [
ChatJavelinAIGateway._convert_message_to_dict(message)
for message in messages
]
data: Dict[str, Any] = {
"messages": message_dicts,
**(self.params.dict() if self.params else {}),
}
resp = self.client.query_route(self.route, query_body=data)
return ChatJavelinAIGateway._create_chat_result(resp.dict())
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts = [
ChatJavelinAIGateway._convert_message_to_dict(message)
for message in messages
]
data: Dict[str, Any] = {
"messages": message_dicts,
**(self.params.dict() if self.params else {}),
}
resp = await self.client.aquery_route(self.route, query_body=data)
return ChatJavelinAIGateway._create_chat_result(resp.dict())
@property
def _identifying_params(self) -> Dict[str, Any]:
return self._default_params
def _get_invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
"""Get the parameters used to invoke the model FOR THE CALLBACKS."""
return {
**self._default_params,
**super()._get_invocation_params(stop=stop, **kwargs),
}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "javelin-ai-gateway-chat"
@staticmethod
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
content = _dict["content"]
if role == "user":
return HumanMessage(content=content)
elif role == "assistant":
return AIMessage(content=content)
elif role == "system":
return SystemMessage(content=content)
else:
return ChatMessage(content=content, role=role)
@staticmethod
def _raise_functions_not_supported() -> None:
raise ValueError(
"Function messages are not supported by the Javelin AI Gateway. Please"
" create a feature request at https://docs.getjavelin.io"
)
@staticmethod
def _convert_message_to_dict(message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
raise ValueError(
"Function messages are not supported by the Javelin AI Gateway. Please"
" create a feature request at https://docs.getjavelin.io"
)
else:
raise ValueError(f"Got unknown message type: {message}")
if "function_call" in message.additional_kwargs:
ChatJavelinAIGateway._raise_functions_not_supported()
if message.additional_kwargs:
logger.warning(
"Additional message arguments are unsupported by Javelin AI Gateway "
" and will be ignored: %s",
message.additional_kwargs,
)
return message_dict
@staticmethod
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
generations = []
for candidate in response["llm_response"]["choices"]:
message = ChatJavelinAIGateway._convert_dict_to_message(
candidate["message"]
)
message_metadata = candidate.get("metadata", {})
gen = ChatGeneration(
message=message,
generation_info=dict(message_metadata),
)
generations.append(gen)
response_metadata = response.get("metadata", {})
return ChatResult(generations=generations, llm_output=response_metadata)

View File

@@ -0,0 +1,414 @@
"""JinaChat wrapper."""
from __future__ import annotations
import logging
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
pre_init,
)
from pydantic import ConfigDict, Field, SecretStr, model_validator
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
logger = logging.getLogger(__name__)
def _create_retry_decorator(llm: JinaChat) -> Callable[[Any], Any]:
import openai
min_seconds = 1
max_seconds = 60
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(llm.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def acompletion_with_retry(llm: JinaChat, **kwargs: Any) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
return await llm.client.acreate(**kwargs)
return await _completion_with_retry(**kwargs)
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
else:
return default_class(content=content) # type: ignore[call-arg]
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
content = _dict["content"] or ""
return AIMessage(content=content)
elif role == "system":
return SystemMessage(content=_dict["content"])
else:
return ChatMessage(content=_dict["content"], role=role)
def _convert_message_to_dict(message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"name": message.name,
"content": message.content,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
class JinaChat(BaseChatModel):
"""`Jina AI` Chat models API.
To use, you should have the ``openai`` python package installed, and the
environment variable ``JINACHAT_API_KEY`` set to your API key, which you
can generate at https://chat.jina.ai/api.
Any parameters that are valid to be passed to the openai.create call can be passed
in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain_community.chat_models import JinaChat
chat = JinaChat()
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"jinachat_api_key": "JINACHAT_API_KEY"}
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return False
client: Any = None #: :meta private:
temperature: float = 0.7
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
jinachat_api_key: Optional[SecretStr] = None
"""Base URL path for API requests,
leave blank if not using a proxy or service emulator."""
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for requests to JinaChat completion API. Default is 600 seconds."""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
streaming: bool = False
"""Whether to stream the results or not."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
model_config = ConfigDict(
populate_by_name=True,
)
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
logger.warning(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["jinachat_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "jinachat_api_key", "JINACHAT_API_KEY")
)
try:
import openai
except ImportError:
raise ImportError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
try:
values["client"] = openai.ChatCompletion
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling JinaChat API."""
return {
"request_timeout": self.request_timeout,
"max_tokens": self.max_tokens,
"stream": self.streaming,
"temperature": self.temperature,
**self.model_kwargs,
}
def _create_retry_decorator(self) -> Callable[[Any], Any]:
import openai
min_seconds = 1
max_seconds = 60
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def completion_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = self._create_retry_decorator()
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return self.client.create(**kwargs)
return _completion_with_retry(**kwargs)
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
return {"token_usage": overall_token_usage}
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
for chunk in self.completion_with_retry(messages=message_dicts, **params):
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
yield cg_chunk
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = dict(self._invocation_params)
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
gen = ChatGeneration(message=message)
generations.append(gen)
llm_output = {"token_usage": response["usage"]}
return ChatResult(generations=generations, llm_output=llm_output)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, **params
):
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
await run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
yield cg_chunk
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(self, messages=message_dicts, **params)
return self._create_chat_result(response)
@property
def _invocation_params(self) -> Mapping[str, Any]:
"""Get the parameters used to invoke the model."""
jinachat_creds: Dict[str, Any] = {
"api_key": self.jinachat_api_key
and self.jinachat_api_key.get_secret_value(),
"api_base": "https://api.chat.jina.ai/v1",
"model": "jinachat",
}
return {**jinachat_creds, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "jinachat"

View File

@@ -0,0 +1,603 @@
##
# Copyright (c) 2024, Chad Juliano, Kinetica DB Inc.
##
"""Kinetica SQL generation LLM API."""
import json
import logging
import os
import re
from importlib.metadata import version
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, cast
from langchain_core.utils import pre_init
if TYPE_CHECKING:
import gpudb
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.output_parsers.transform import BaseOutputParser
from langchain_core.outputs import ChatGeneration, ChatResult, Generation
from pydantic import BaseModel, ConfigDict, Field
LOG = logging.getLogger(__name__)
# Kinetica pydantic API datatypes
class _KdtSuggestContext(BaseModel):
"""pydantic API request type"""
table: Optional[str] = Field(default=None, title="Name of table")
description: Optional[str] = Field(default=None, title="Table description")
columns: List[str] = Field(default=[], title="Table columns list")
rules: Optional[List[str]] = Field(
default=None, title="Rules that apply to the table."
)
samples: Optional[Dict] = Field(
default=None, title="Samples that apply to the entire context."
)
def to_system_str(self) -> str:
lines = []
lines.append(f"CREATE TABLE {self.table} AS")
lines.append("(")
if not self.columns or len(self.columns) == 0:
ValueError("columns list can't be null.")
columns = []
for column in self.columns:
column = column.replace('"', "").strip()
columns.append(f" {column}")
lines.append(",\n".join(columns))
lines.append(");")
if self.description:
lines.append(f"COMMENT ON TABLE {self.table} IS '{self.description}';")
if self.rules and len(self.rules) > 0:
lines.append(
f"-- When querying table {self.table} the following rules apply:"
)
for rule in self.rules:
lines.append(f"-- * {rule}")
result = "\n".join(lines)
return result
class _KdtSuggestPayload(BaseModel):
"""pydantic API request type"""
question: Optional[str] = None
context: List[_KdtSuggestContext]
def get_system_str(self) -> str:
lines = []
for table_context in self.context:
if table_context.table is None:
continue
context_str = table_context.to_system_str()
lines.append(context_str)
return "\n\n".join(lines)
def get_messages(self) -> List[Dict]:
messages = []
for context in self.context:
if context.samples is None:
continue
for question, answer in context.samples.items():
# unescape double quotes
answer = answer.replace("''", "'")
messages.append(dict(role="user", content=question or ""))
messages.append(dict(role="assistant", content=answer))
return messages
def to_completion(self) -> Dict:
messages = []
messages.append(dict(role="system", content=self.get_system_str()))
messages.extend(self.get_messages())
messages.append(dict(role="user", content=self.question or ""))
response = dict(messages=messages)
return response
class _KdtoSuggestRequest(BaseModel):
"""pydantic API request type"""
payload: _KdtSuggestPayload
class _KdtMessage(BaseModel):
"""pydantic API response type"""
role: str = Field(default="", title="One of [user|assistant|system]")
content: str
class _KdtChoice(BaseModel):
"""pydantic API response type"""
index: int
message: Optional[_KdtMessage] = Field(default=None, title="The generated SQL")
finish_reason: str
class _KdtUsage(BaseModel):
"""pydantic API response type"""
prompt_tokens: int
completion_tokens: int
total_tokens: int
class _KdtSqlResponse(BaseModel):
"""pydantic API response type"""
id: str
object: str
created: int
model: str
choices: List[_KdtChoice]
usage: _KdtUsage
prompt: str = Field(default="", title="The input question")
class _KdtCompletionResponse(BaseModel):
"""pydantic API response type"""
status: str
data: _KdtSqlResponse
class _KineticaLlmFileContextParser:
"""Parser for Kinetica LLM context datafiles."""
# parse line into a dict containing role and content
PARSER: Pattern = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL)
@classmethod
def _removesuffix(cls, text: str, suffix: str) -> str:
if suffix and text.endswith(suffix):
return text[: -len(suffix)]
return text
@classmethod
def parse_dialogue_file(cls, input_file: os.PathLike) -> Dict:
path = Path(input_file)
# schema = path.name.removesuffix(".txt") python 3.9
schema = cls._removesuffix(path.name, ".txt")
lines = open(input_file).read()
return cls.parse_dialogue(lines, schema)
@classmethod
def parse_dialogue(cls, text: str, schema: str) -> Dict:
messages = []
system = None
lines = text.split("<|end|>")
user_message = None
for idx, line in enumerate(lines):
line = line.strip()
if len(line) == 0:
continue
match = cls.PARSER.match(line)
if match is None:
raise ValueError(f"Could not find starting token in: {line}")
groupdict = match.groupdict()
role = groupdict["role"]
if role == "system":
if system is not None:
raise ValueError(f"Only one system token allowed in: {line}")
system = groupdict["content"]
elif role == "user":
if user_message is not None:
raise ValueError(
f"Found user token without assistant token: {line}"
)
user_message = groupdict
elif role == "assistant":
if user_message is None:
raise Exception(f"Found assistant token without user token: {line}")
messages.append(user_message)
messages.append(groupdict)
user_message = None
else:
raise ValueError(f"Unknown token: {role}")
return {"schema": schema, "system": system, "messages": messages}
class KineticaUtil:
"""Kinetica utility functions."""
@classmethod
def create_kdbc(
cls,
url: Optional[str] = None,
user: Optional[str] = None,
passwd: Optional[str] = None,
) -> "gpudb.GPUdb":
"""Create a connectica connection object and verify connectivity.
If None is passed for one or more of the parameters then an attempt will be made
to retrieve the value from the related environment variable.
Args:
url: The Kinetica URL or ``KINETICA_URL`` if None.
user: The Kinetica user or ``KINETICA_USER`` if None.
passwd: The Kinetica password or ``KINETICA_PASSWD`` if None.
Returns:
The Kinetica connection object.
"""
try:
import gpudb
except ModuleNotFoundError:
raise ImportError(
"Could not import Kinetica python package. "
"Please install it with `pip install gpudb`."
)
url = cls._get_env("KINETICA_URL", url)
user = cls._get_env("KINETICA_USER", user)
passwd = cls._get_env("KINETICA_PASSWD", passwd)
options = gpudb.GPUdb.Options()
options.username = user
options.password = passwd
options.skip_ssl_cert_verification = True
options.disable_failover = True
options.logging_level = "INFO"
kdbc = gpudb.GPUdb(host=url, options=options)
LOG.info(
"Connected to Kinetica: {}. (api={}, server={})".format(
kdbc.get_url(), version("gpudb"), kdbc.server_version
)
)
return kdbc
@classmethod
def _get_env(cls, name: str, default: Optional[str]) -> str:
"""Get an environment variable or use a default."""
if default is not None:
return default
result = os.getenv(name)
if result is not None:
return result
raise ValueError(
f"Parameter was not passed and not found in the environment: {name}"
)
class ChatKinetica(BaseChatModel):
"""Kinetica LLM Chat Model API.
Prerequisites for using this API:
* The ``gpudb`` and ``typeguard`` packages installed.
* A Kinetica DB instance.
* Kinetica host specified in ``KINETICA_URL``
* Kinetica login specified ``KINETICA_USER``, and ``KINETICA_PASSWD``.
* An LLM context that specifies the tables and samples to use for inferencing.
This API is intended to interact with the Kinetica SqlAssist LLM that supports
generation of SQL from natural language.
In the Kinetica LLM workflow you create an LLM context in the database that provides
information needed for infefencing that includes tables, annotations, rules, and
samples. Invoking ``load_messages_from_context()`` will retrieve the contxt
information from the database so that it can be used to create a chat prompt.
The chat prompt consists of a ``SystemMessage`` and pairs of
``HumanMessage``/``AIMessage`` that contain the samples which are question/SQL
pairs. You can append pairs samples to this list but it is not intended to
facilitate a typical natural language conversation.
When you create a chain from the chat prompt and execute it, the Kinetica LLM will
generate SQL from the input. Optionally you can use ``KineticaSqlOutputParser`` to
execute the SQL and return the result as a dataframe.
The following example creates an LLM using the environment variables for the
Kinetica connection. This will fail if the API is unable to connect to the database.
Example:
.. code-block:: python
from langchain_community.chat_models.kinetica import KineticaChatLLM
kinetica_llm = KineticaChatLLM()
If you prefer to pass connection information directly then you can create a
connection using ``KineticaUtil.create_kdbc()``.
Example:
.. code-block:: python
from langchain_community.chat_models.kinetica import (
KineticaChatLLM, KineticaUtil)
kdbc = KineticaUtil._create_kdbc(url=url, user=user, passwd=passwd)
kinetica_llm = KineticaChatLLM(kdbc=kdbc)
"""
kdbc: Any = Field(exclude=True)
""" Kinetica DB connection. """
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Pydantic object validator."""
kdbc = values.get("kdbc", None)
if kdbc is None:
kdbc = KineticaUtil.create_kdbc()
values["kdbc"] = kdbc
return values
@property
def _llm_type(self) -> str:
return "kinetica-sqlassist"
@property
def _identifying_params(self) -> Dict[str, Any]:
return dict(
kinetica_version=str(self.kdbc.server_version), api_version=version("gpudb")
)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if stop is not None:
raise ValueError("stop kwargs are not permitted.")
dict_messages = [self._convert_message_to_dict(m) for m in messages]
sql_response = self._submit_completion(dict_messages)
response_message = cast(_KdtMessage, sql_response.choices[0].message)
generated_dict = response_message.model_dump()
generated_message = self._convert_message_from_dict(generated_dict)
llm_output = dict(
input_tokens=sql_response.usage.prompt_tokens,
output_tokens=sql_response.usage.completion_tokens,
model_name=sql_response.model,
)
return ChatResult(
generations=[ChatGeneration(message=generated_message)],
llm_output=llm_output,
)
def load_messages_from_context(self, context_name: str) -> List:
"""Load a lanchain prompt from a Kinetica context.
A Kinetica Context is an object created with the Kinetica Workbench UI or with
SQL syntax. This function will convert the data in the context to a list of
messages that can be used as a prompt. The messages will contain a
``SystemMessage`` followed by pairs of ``HumanMessage``/``AIMessage`` that
contain the samples.
Args:
context_name: The name of an LLM context in the database.
Returns:
A list of messages containing the information from the context.
"""
# query kinetica for the prompt
sql = f"GENERATE PROMPT WITH OPTIONS (CONTEXT_NAMES = '{context_name}')"
result = self._execute_sql(sql)
prompt = result["Prompt"]
prompt_json = json.loads(prompt)
# convert the prompt to messages
# request = SuggestRequest.model_validate(prompt_json) # pydantic v2
request = _KdtoSuggestRequest.model_validate(prompt_json)
payload = request.payload
dict_messages = []
dict_messages.append(dict(role="system", content=payload.get_system_str()))
dict_messages.extend(payload.get_messages())
messages = [self._convert_message_from_dict(m) for m in dict_messages]
return messages
def _submit_completion(self, messages: List[Dict]) -> _KdtSqlResponse:
"""Submit a /chat/completions request to Kinetica."""
request = dict(messages=messages)
request_json = json.dumps(request)
response_raw = self.kdbc._GPUdb__submit_request_json(
"/chat/completions", request_json
)
response_json = json.loads(response_raw)
status = response_json["status"]
if status != "OK":
message = response_json["message"]
match_resp = re.compile(r"response:({.*})")
result = match_resp.search(message)
if result is not None:
response = result.group(1)
response_json = json.loads(response)
message = response_json["message"]
raise ValueError(message)
data = response_json["data"]
# response = CompletionResponse.model_validate(data) # pydantic v2
response = _KdtCompletionResponse.model_validate(data)
if response.status != "OK":
raise ValueError("SQL Generation failed")
return response.data
def _execute_sql(self, sql: str) -> Dict:
"""Execute an SQL query and return the result."""
response = self.kdbc.execute_sql_and_decode(
sql, limit=1, get_column_major=False
)
status_info = response["status_info"]
if status_info["status"] != "OK":
message = status_info["message"]
raise ValueError(message)
records = response["records"]
if len(records) != 1:
raise ValueError("No records returned.")
record = records[0]
response_dict = {}
for col, val in record.items():
response_dict[col] = val
return response_dict
@classmethod
def load_messages_from_datafile(cls, sa_datafile: Path) -> List[BaseMessage]:
"""Load a lanchain prompt from a Kinetica context datafile."""
datafile_dict = _KineticaLlmFileContextParser.parse_dialogue_file(sa_datafile)
messages = cls._convert_dict_to_messages(datafile_dict)
return messages
@classmethod
def _convert_message_to_dict(cls, message: BaseMessage) -> Dict:
"""Convert a single message to a BaseMessage."""
content = cast(str, message.content)
if isinstance(message, HumanMessage):
role = "user"
elif isinstance(message, AIMessage):
role = "assistant"
elif isinstance(message, SystemMessage):
role = "system"
else:
raise ValueError(f"Got unsupported message type: {message}")
result_message = dict(role=role, content=content)
return result_message
@classmethod
def _convert_message_from_dict(cls, message: Dict) -> BaseMessage:
"""Convert a single message from a BaseMessage."""
role = message["role"]
content = message["content"]
if role == "user":
return HumanMessage(content=content)
elif role == "assistant":
return AIMessage(content=content)
elif role == "system":
return SystemMessage(content=content)
else:
raise ValueError(f"Got unsupported role: {role}")
@classmethod
def _convert_dict_to_messages(cls, sa_data: Dict) -> List[BaseMessage]:
"""Convert a dict to a list of BaseMessages."""
schema = sa_data["schema"]
system = sa_data["system"]
messages = sa_data["messages"]
LOG.info(f"Importing prompt for schema: {schema}")
result_list: List[BaseMessage] = []
result_list.append(SystemMessage(content=system))
result_list.extend([cls._convert_message_from_dict(m) for m in messages])
return result_list
class KineticaSqlResponse(BaseModel):
"""Response containing SQL and the fetched data.
This object is returned by a chain with ``KineticaSqlOutputParser`` and it contains
the generated SQL and related Pandas Dataframe fetched from the database.
"""
sql: str = Field(default="")
"""The generated SQL."""
# dataframe: "pd.DataFrame" = Field(default=None)
dataframe: Any = Field(default=None)
"""The Pandas dataframe containing the fetched data."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
class KineticaSqlOutputParser(BaseOutputParser[KineticaSqlResponse]):
"""Fetch and return data from the Kinetica LLM.
This object is used as the last element of a chain to execute generated SQL and it
will output a ``KineticaSqlResponse`` containing the SQL and a pandas dataframe with
the fetched data.
Example:
.. code-block:: python
from langchain_community.chat_models.kinetica import (
KineticaChatLLM, KineticaSqlOutputParser)
kinetica_llm = KineticaChatLLM()
# create chain
ctx_messages = kinetica_llm.load_messages_from_context(self.context_name)
ctx_messages.append(("human", "{input}"))
prompt_template = ChatPromptTemplate.from_messages(ctx_messages)
chain = (
prompt_template
| kinetica_llm
| KineticaSqlOutputParser(kdbc=kinetica_llm.kdbc)
)
sql_response: KineticaSqlResponse = chain.invoke(
{"input": "What are the female users ordered by username?"}
)
assert isinstance(sql_response, KineticaSqlResponse)
LOG.info(f"SQL Response: {sql_response.sql}")
assert isinstance(sql_response.dataframe, pd.DataFrame)
"""
kdbc: Any = Field(exclude=True)
""" Kinetica DB connection. """
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def parse(self, text: str) -> KineticaSqlResponse:
df = self.kdbc.to_df(text)
return KineticaSqlResponse(sql=text, dataframe=df)
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> KineticaSqlResponse:
return self.parse(result[0].text)
@property
def _type(self) -> str:
return "kinetica_sql_output_parser"

View File

@@ -0,0 +1,285 @@
"""KonkoAI chat wrapper."""
from __future__ import annotations
import logging
import os
import warnings
from typing import (
Any,
Dict,
Iterator,
List,
Optional,
Set,
Tuple,
Type,
Union,
cast,
)
import requests
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from pydantic import Field, SecretStr
from langchain_community.adapters.openai import (
convert_message_to_dict,
)
from langchain_community.chat_models.openai import (
ChatOpenAI,
_convert_delta_to_message_chunk,
generate_from_stream,
)
from langchain_community.utils.openai import is_openai_v1
DEFAULT_API_BASE = "https://api.konko.ai/v1"
DEFAULT_MODEL = "meta-llama/Llama-2-13b-chat-hf"
logger = logging.getLogger(__name__)
class ChatKonko(ChatOpenAI):
"""`ChatKonko` Chat large language models API.
To use, you should have the ``konko`` python package installed, and the
environment variable ``KONKO_API_KEY`` and ``OPENAI_API_KEY`` set with your API key.
Any parameters that are valid to be passed to the konko.create call can be passed
in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatKonko
llm = ChatKonko(model="meta-llama/Llama-2-13b-chat-hf")
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"konko_api_key": "KONKO_API_KEY", "openai_api_key": "OPENAI_API_KEY"}
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return False
client: Any = None #: :meta private:
model: str = Field(default=DEFAULT_MODEL, alias="model")
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[str] = None
konko_api_key: Optional[str] = None
max_retries: int = 6
"""Maximum number of retries to make when generating."""
streaming: bool = False
"""Whether to stream the results or not."""
n: int = 1
"""Number of chat completions to generate for each prompt."""
max_tokens: int = 20
"""Maximum number of tokens to generate."""
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["konko_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "konko_api_key", "KONKO_API_KEY")
)
try:
import konko
except ImportError:
raise ImportError(
"Could not import konko python package. "
"Please install it with `pip install konko`."
)
try:
if is_openai_v1():
values["client"] = konko.chat.completions
else:
values["client"] = konko.ChatCompletion
except AttributeError:
raise ValueError(
"`konko` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the konko package. Try upgrading it "
"with `pip install --upgrade konko`."
)
if not hasattr(konko, "_is_legacy_openai"):
warnings.warn(
"You are using an older version of the 'konko' package. "
"Please consider upgrading to access new features."
)
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Konko API."""
return {
"model": self.model,
"max_tokens": self.max_tokens,
"stream": self.streaming,
"n": self.n,
"temperature": self.temperature,
**self.model_kwargs,
}
@staticmethod
def get_available_models(
konko_api_key: Union[str, SecretStr, None] = None,
openai_api_key: Union[str, SecretStr, None] = None,
konko_api_base: str = DEFAULT_API_BASE,
) -> Set[str]:
"""Get available models from Konko API."""
# Try to retrieve the OpenAI API key if it's not passed as an argument
if not openai_api_key:
try:
openai_api_key = convert_to_secret_str(os.environ["OPENAI_API_KEY"])
except KeyError:
pass # It's okay if it's not set, we just won't use it
elif isinstance(openai_api_key, str):
openai_api_key = convert_to_secret_str(openai_api_key)
# Try to retrieve the Konko API key if it's not passed as an argument
if not konko_api_key:
try:
konko_api_key = convert_to_secret_str(os.environ["KONKO_API_KEY"])
except KeyError:
raise ValueError(
"Konko API key must be passed as keyword argument or "
"set in environment variable KONKO_API_KEY."
)
elif isinstance(konko_api_key, str):
konko_api_key = convert_to_secret_str(konko_api_key)
models_url = f"{konko_api_base}/models"
headers = {
"Authorization": f"Bearer {konko_api_key.get_secret_value()}",
}
if openai_api_key:
headers["X-OpenAI-Api-Key"] = cast(
SecretStr, openai_api_key
).get_secret_value()
models_response = requests.get(models_url, headers=headers)
if models_response.status_code != 200:
raise ValueError(
f"Error getting models from {models_url}: {models_response.status_code}"
)
return {model["id"] for model in models_response.json()["data"]}
def completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
def _completion_with_retry(**kwargs: Any) -> Any:
return self.client.create(**kwargs)
return _completion_with_retry(**kwargs)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
finish_reason = choice.get("finish_reason")
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
if run_manager:
run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
yield cg_chunk
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = self._client_params
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [convert_message_to_dict(m) for m in messages]
return message_dicts, params
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {**{"model_name": self.model}, **self._default_params}
@property
def _client_params(self) -> Dict[str, Any]:
"""Get the parameters used for the konko client."""
return {**self._default_params}
def _get_invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
"""Get the parameters used to invoke the model."""
return {
"model": self.model,
**super()._get_invocation_params(stop=stop),
**self._default_params,
**kwargs,
}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "konko-chat"

View File

@@ -0,0 +1,632 @@
"""
Deprecated LiteLLM wrapper.
⭐ Use `pip install langchain-litellm` and import
`from langchain_litellm import ChatLiteLLM` instead.
"""
from __future__ import annotations
import json
import logging
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolCall,
ToolCallChunk,
ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
)
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env, pre_init
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class ChatLiteLLMException(Exception):
"""Error with the `LiteLLM I/O` library"""
def _create_retry_decorator(
llm: ChatLiteLLM,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
import litellm
errors = [
litellm.Timeout,
litellm.APIError,
litellm.APIConnectionError,
litellm.RateLimitError,
]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
# Fix for azure
# Also OpenAI returns None for tool invocations
content = _dict.get("content", "") or ""
additional_kwargs = {}
if _dict.get("function_call"):
additional_kwargs["function_call"] = dict(_dict["function_call"])
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]
return AIMessage(content=content, additional_kwargs=additional_kwargs)
elif role == "system":
return SystemMessage(content=_dict["content"])
elif role == "function":
return FunctionMessage(content=_dict["content"], name=_dict["name"])
else:
return ChatMessage(content=_dict["content"], role=role)
async def acompletion_with_retry(
llm: ChatLiteLLM,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
return await llm.client.acreate(**kwargs)
return await _completion_with_retry(**kwargs)
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
if _dict.get("function_call"):
additional_kwargs = {"function_call": dict(_dict["function_call"])}
elif _dict.get("reasoning_content"):
additional_kwargs = {"reasoning_content": _dict["reasoning_content"]}
else:
additional_kwargs = {}
tool_call_chunks = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
try:
tool_call_chunks = [
ToolCallChunk(
name=rtc["function"].get("name"),
args=rtc["function"].get("arguments"),
id=rtc.get("id"),
index=rtc["index"],
)
for rtc in raw_tool_calls
]
except KeyError:
pass
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
else:
return default_class(content=content) # type: ignore[call-arg]
def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
return {
"type": "function",
"id": tool_call["id"],
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
},
}
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any] = {"content": message.content}
if isinstance(message, ChatMessage):
message_dict["role"] = message.role
elif isinstance(message, HumanMessage):
message_dict["role"] = "user"
elif isinstance(message, AIMessage):
message_dict["role"] = "assistant"
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
if message.tool_calls:
message_dict["tool_calls"] = [
_lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
]
elif "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
elif isinstance(message, SystemMessage):
message_dict["role"] = "system"
elif isinstance(message, FunctionMessage):
message_dict["role"] = "function"
message_dict["name"] = message.name
elif isinstance(message, ToolMessage):
message_dict["role"] = "tool"
message_dict["tool_call_id"] = message.tool_call_id
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
_OPENAI_MODELS = [
"o1-mini",
"o1-preview",
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
"gpt-4o",
"gpt-4o-2024-08-06",
"gpt-4o-2024-05-13",
"gpt-4-turbo",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-16k-0613",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-32k-0613",
]
@deprecated(
since="0.3.24",
removal="1.0",
alternative_import="langchain_litellm.ChatLiteLLM",
)
class ChatLiteLLM(BaseChatModel):
"""DEPRECATED use `langchain_litellm.ChatLiteLLM` instead."""
client: Any = None #: :meta private:
model: str = "gpt-3.5-turbo"
model_name: Optional[str] = None
"""Model name to use."""
openai_api_key: Optional[str] = None
azure_api_key: Optional[str] = None
anthropic_api_key: Optional[str] = None
replicate_api_key: Optional[str] = None
cohere_api_key: Optional[str] = None
openrouter_api_key: Optional[str] = None
api_key: Optional[str] = None
streaming: bool = False
api_base: Optional[str] = None
organization: Optional[str] = None
custom_llm_provider: Optional[str] = None
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
temperature: Optional[float] = None
"""Run inference with this temperature. Must be in the closed
interval [0.0, 1.0]."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for API call not explicitly specified."""
top_p: Optional[float] = None
"""Decode using nucleus sampling: consider the smallest set of tokens whose
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
top_k: Optional[int] = None
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
Must be positive."""
n: Optional[int] = None
"""Number of chat completions to generate for each prompt. Note that the API may
not return the full n completions if duplicates are generated."""
max_tokens: Optional[int] = None
max_retries: int = 1
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
set_model_value = self.model
if self.model_name is not None:
set_model_value = self.model_name
return {
"model": set_model_value,
"force_timeout": self.request_timeout,
"max_tokens": self.max_tokens,
"stream": self.streaming,
"n": self.n,
"temperature": self.temperature,
"custom_llm_provider": self.custom_llm_provider,
**self.model_kwargs,
}
@property
def _client_params(self) -> Dict[str, Any]:
"""Get the parameters used for the openai client."""
set_model_value = self.model
if self.model_name is not None:
set_model_value = self.model_name
self.client.api_base = self.api_base
self.client.api_key = self.api_key
for named_api_key in [
"openai_api_key",
"azure_api_key",
"anthropic_api_key",
"replicate_api_key",
"cohere_api_key",
"openrouter_api_key",
]:
if api_key_value := getattr(self, named_api_key):
setattr(
self.client,
named_api_key.replace("_api_key", "_key"),
api_key_value,
)
self.client.organization = self.organization
creds: Dict[str, Any] = {
"model": set_model_value,
"force_timeout": self.request_timeout,
"api_base": self.api_base,
}
return {**self._default_params, **creds}
def completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return self.client.completion(**kwargs)
return _completion_with_retry(**kwargs)
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, top_p, and top_k."""
try:
import litellm
except ImportError:
raise ChatLiteLLMException(
"Could not import litellm python package. "
"Please install it with `pip install litellm`"
)
values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY", default=""
)
values["azure_api_key"] = get_from_dict_or_env(
values, "azure_api_key", "AZURE_API_KEY", default=""
)
values["anthropic_api_key"] = get_from_dict_or_env(
values, "anthropic_api_key", "ANTHROPIC_API_KEY", default=""
)
values["replicate_api_key"] = get_from_dict_or_env(
values, "replicate_api_key", "REPLICATE_API_KEY", default=""
)
values["openrouter_api_key"] = get_from_dict_or_env(
values, "openrouter_api_key", "OPENROUTER_API_KEY", default=""
)
values["cohere_api_key"] = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY", default=""
)
values["huggingface_api_key"] = get_from_dict_or_env(
values, "huggingface_api_key", "HUGGINGFACE_API_KEY", default=""
)
values["together_ai_api_key"] = get_from_dict_or_env(
values, "together_ai_api_key", "TOGETHERAI_API_KEY", default=""
)
values["client"] = litellm
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")
if values["top_k"] is not None and values["top_k"] <= 0:
raise ValueError("top_k must be positive")
return values
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
token_usage = response.get("usage", {})
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
if isinstance(message, AIMessage):
message.response_metadata = {
"model_name": self.model_name or self.model
}
message.usage_metadata = _create_usage_metadata(token_usage)
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=res.get("finish_reason")),
)
generations.append(gen)
set_model_value = self.model
if self.model_name is not None:
set_model_value = self.model_name
llm_output = {"token_usage": token_usage, "model": set_model_value}
return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = self._client_params
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
added_model_name = False
for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0]["delta"]
usage = chunk.get("usage", {})
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
if isinstance(chunk, AIMessageChunk):
if not added_model_name:
chunk.response_metadata = {
"model_name": self.model_name or self.model
}
added_model_name = True
chunk.usage_metadata = _create_usage_metadata(usage)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
yield cg_chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
added_model_name = False
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
):
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0]["delta"]
usage = chunk.get("usage", {})
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
if isinstance(chunk, AIMessageChunk):
if not added_model_name:
chunk.response_metadata = {
"model_name": self.model_name or self.model
}
added_model_name = True
chunk.usage_metadata = _create_usage_metadata(usage)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
await run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
yield cg_chunk
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
"""Bind tool-like objects to this chat model.
LiteLLM expects tools argument in OpenAI format.
Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
tool_choice: Which tool to require the model to call. Options are:
- str of the form ``"<<tool_name>>"``: calls <<tool_name>> tool.
- ``"auto"``:
automatically selects a tool (including no tool).
- ``"none"``:
does not call a tool.
- ``"any"`` or ``"required"`` or ``True``:
forces least one tool to be called.
- dict of the form:
``{"type": "function", "function": {"name": <<tool_name>>}}``
- ``False`` or ``None``: no effect
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
# In case of openai if tool_choice is `any` or if bool has been provided we
# change it to `required` as that is suppored by openai.
if (
(self.model is not None and "azure" in self.model)
or (self.model_name is not None and "azure" in self.model_name)
or (self.model is not None and self.model in _OPENAI_MODELS)
or (self.model_name is not None and self.model_name in _OPENAI_MODELS)
) and (tool_choice == "any" or isinstance(tool_choice, bool)):
tool_choice = "required"
# If tool_choice is bool apart from openai we make it `any`
elif isinstance(tool_choice, bool):
tool_choice = "any"
elif isinstance(tool_choice, dict):
tool_names = [
formatted_tool["function"]["name"] for formatted_tool in formatted_tools
]
if not any(
tool_name == tool_choice["function"]["name"] for tool_name in tool_names
):
raise ValueError(
f"Tool choice {tool_choice} was specified, but the only "
f"provided tools were {tool_names}."
)
return super().bind(tools=formatted_tools, tool_choice=tool_choice, **kwargs)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
set_model_value = self.model
if self.model_name is not None:
set_model_value = self.model_name
return {
"model": set_model_value,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"n": self.n,
}
@property
def _llm_type(self) -> str:
return "litellm-chat"
def _create_usage_metadata(token_usage: Mapping[str, Any]) -> UsageMetadata:
input_tokens = token_usage.get("prompt_tokens", 0)
output_tokens = token_usage.get("completion_tokens", 0)
return UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
)

View File

@@ -0,0 +1,230 @@
"""
Deprecated LiteLLM wrapper.
⭐ Use `pip install langchain-litellm` and import
`from langchain_litellm import ChatLiteLLMRouter` instead.
"""
from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional, Type
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_community.chat_models.litellm import (
ChatLiteLLM,
_convert_delta_to_message_chunk,
_convert_dict_to_message,
)
token_usage_key_name = "token_usage" # nosec # incorrectly flagged as password
model_extra_key_name = "model_extra" # nosec # incorrectly flagged as password
def get_llm_output(usage: Any, **params: Any) -> dict:
"""Get llm output from usage and params."""
llm_output = {token_usage_key_name: usage}
# copy over metadata (metadata came from router completion call)
metadata = params["metadata"]
for key in metadata:
if key not in llm_output:
# if token usage in metadata, prefer metadata's copy of it
llm_output[key] = metadata[key]
return llm_output
@deprecated(
since="0.3.24",
removal="1.0",
alternative_import="langchain_litellm.ChatLiteLLMRouter",
)
class ChatLiteLLMRouter(ChatLiteLLM):
"""DEPRECATED use `langchain_litellm.ChatLiteLLMRouter` instead."""
router: Any
def __init__(self, *, router: Any, **kwargs: Any) -> None:
"""Construct Chat LiteLLM Router."""
super().__init__(router=router, **kwargs) # type: ignore[call-arg]
self.router = router
@property
def _llm_type(self) -> str:
return "LiteLLMRouter"
def _prepare_params_for_router(self, params: Any) -> None:
# allow the router to set api_base based on its model choice
api_base_key_name = "api_base"
if api_base_key_name in params and params[api_base_key_name] is None:
del params[api_base_key_name]
# add metadata so router can fill it below
params.setdefault("metadata", {})
def set_default_model(self, model_name: str) -> None:
"""Set the default model to use for completion calls.
Sets `self.model` to `model_name` if it is in the litellm router's
(`self.router`) model list. This provides the default model to use
for completion calls if no `model` kwarg is provided.
"""
model_list = self.router.model_list
if not model_list:
raise ValueError("model_list is None or empty.")
for entry in model_list:
if entry["model_name"] == model_name:
self.model = model_name
return
raise ValueError(f"Model {model_name} not found in model_list.")
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
self._prepare_params_for_router(params)
response = self.router.completion(
messages=message_dicts,
**params,
)
return self._create_chat_result(response, **params)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
self._prepare_params_for_router(params)
for chunk in self.router.completion(messages=message_dicts, **params):
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(
str(chunk.content), chunk=cg_chunk, **params
)
yield cg_chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
self._prepare_params_for_router(params)
async for chunk in await self.router.acompletion(
messages=message_dicts, **params
):
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
await run_manager.on_llm_new_token(
str(chunk.content), chunk=cg_chunk, **params
)
yield cg_chunk
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
self._prepare_params_for_router(params)
response = await self.router.acompletion(
messages=message_dicts,
**params,
)
return self._create_chat_result(response, **params)
# from
# https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/chat_models/openai.py
# but modified to handle LiteLLM Usage class
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
system_fingerprint = None
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
if token_usage is not None:
# get dict from LiteLLM Usage class
for k, v in token_usage.model_dump().items():
if k in overall_token_usage and overall_token_usage[k] is not None:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
if system_fingerprint is None:
system_fingerprint = output.get("system_fingerprint")
combined = {"token_usage": overall_token_usage, "model_name": self.model}
if system_fingerprint:
combined["system_fingerprint"] = system_fingerprint
return combined
def _create_chat_result(
self, response: Mapping[str, Any], **params: Any
) -> ChatResult:
from litellm.utils import Usage
generations = []
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
gen = ChatGeneration(
message=message,
generation_info=dict(finish_reason=res.get("finish_reason")),
)
generations.append(gen)
token_usage = response.get("usage", Usage(prompt_tokens=0, total_tokens=0))
llm_output = get_llm_output(token_usage, **params)
return ChatResult(generations=generations, llm_output=llm_output)

View File

@@ -0,0 +1,241 @@
import json
import logging
import re
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import get_pydantic_field_names
from pydantic import ConfigDict, model_validator
logger = logging.getLogger(__name__)
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict.get("content", "") or "")
else:
return ChatMessage(content=_dict["content"], role=role)
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
else:
raise TypeError(f"Got unknown type {message}")
return message_dict
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
else:
return default_class(content=content) # type: ignore[call-arg]
class LlamaEdgeChatService(BaseChatModel):
"""Chat with LLMs via `llama-api-server`
For the information about `llama-api-server`, visit https://github.com/second-state/LlamaEdge
"""
request_timeout: int = 60
"""request timeout for chat http requests"""
service_url: Optional[str] = None
"""URL of WasmChat service"""
model: str = "NA"
"""model name, default is `NA`."""
streaming: bool = False
"""Whether to stream the results or not."""
model_config = ConfigDict(
populate_by_name=True,
)
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
logger.warning(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
res = self._chat(messages, **kwargs)
if res.status_code != 200:
raise ValueError(f"Error code: {res.status_code}, reason: {res.reason}")
response = res.json()
return self._create_chat_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
res = self._chat(messages, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
substring = '"object":"chat.completion.chunk"}'
for line in res.iter_lines():
chunks = []
if line:
json_string = line.decode("utf-8")
# Find all positions of the substring
positions = [m.start() for m in re.finditer(substring, json_string)]
positions = [-1 * len(substring)] + positions
for i in range(len(positions) - 1):
chunk = json.loads(
json_string[
positions[i] + len(substring) : positions[i + 1]
+ len(substring)
]
)
chunks.append(chunk)
for chunk in chunks:
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
if (
choice.get("finish_reason") is not None
and choice.get("finish_reason") == "stop"
):
break
finish_reason = choice.get("finish_reason")
generation_info = (
dict(finish_reason=finish_reason)
if finish_reason is not None
else None
)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
if run_manager:
run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
yield cg_chunk
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
if self.service_url is None:
res = requests.models.Response()
res.status_code = 503
res.reason = "The IP address or port of the chat service is incorrect."
return res
service_url = f"{self.service_url}/v1/chat/completions"
if self.streaming:
payload = {
"model": self.model,
"messages": [_convert_message_to_dict(m) for m in messages],
"stream": self.streaming,
}
else:
payload = {
"model": self.model,
"messages": [_convert_message_to_dict(m) for m in messages],
}
res = requests.post(
url=service_url,
timeout=self.request_timeout,
headers={
"accept": "application/json",
"Content-Type": "application/json",
},
data=json.dumps(payload),
)
return res
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
message = _convert_dict_to_message(response["choices"][0].get("message"))
generations = [ChatGeneration(message=message)]
token_usage = response["usage"]
llm_output = {"token_usage": token_usage, "model": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
@property
def _llm_type(self) -> str:
return "wasm-chat"

View File

@@ -0,0 +1,818 @@
import json
from operator import itemgetter
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Type,
Union,
cast,
)
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
ToolMessageChunk,
)
from langchain_core.messages.tool import InvalidToolCall, ToolCall, ToolCallChunk
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import (
BaseModel,
Field,
model_validator,
)
from typing_extensions import Self
class ChatLlamaCpp(BaseChatModel):
"""llama.cpp model.
To use, you should have the llama-cpp-python library installed, and provide the
path to the Llama model as a named parameter to the constructor.
Check out: https://github.com/abetlen/llama-cpp-python
"""
client: Any = None #: :meta private:
model_path: str
"""The path to the Llama model file."""
lora_base: Optional[str] = None
"""The path to the Llama LoRA base model."""
lora_path: Optional[str] = None
"""The path to the Llama LoRA. If None, no LoRa is loaded."""
n_ctx: int = 512
"""Token context window."""
n_parts: int = -1
"""Number of parts to split the model into.
If -1, the number of parts is automatically determined."""
seed: int = -1
"""Seed. If -1, a random seed is used."""
f16_kv: bool = True
"""Use half-precision for key/value cache."""
logits_all: bool = False
"""Return logits for all tokens, not just the last token."""
vocab_only: bool = False
"""Only load the vocabulary, no weights."""
use_mlock: bool = False
"""Force system to keep model in RAM."""
n_threads: Optional[int] = None
"""Number of threads to use.
If None, the number of threads is automatically determined."""
n_batch: int = 8
"""Number of tokens to process in parallel.
Should be a number between 1 and n_ctx."""
n_gpu_layers: Optional[int] = None
"""Number of layers to be loaded into gpu memory. Default None."""
suffix: Optional[str] = None
"""A suffix to append to the generated text. If None, no suffix is appended."""
max_tokens: int = 256
"""The maximum number of tokens to generate."""
temperature: float = 0.8
"""The temperature to use for sampling."""
top_p: float = 0.95
"""The top-p value to use for sampling."""
logprobs: Optional[int] = None
"""The number of logprobs to return. If None, no logprobs are returned."""
echo: bool = False
"""Whether to echo the prompt."""
stop: Optional[List[str]] = None
"""A list of strings to stop generation when encountered."""
repeat_penalty: float = 1.1
"""The penalty to apply to repeated tokens."""
top_k: int = 40
"""The top-k value to use for sampling."""
last_n_tokens_size: int = 64
"""The number of tokens to look back when applying the repeat_penalty."""
use_mmap: bool = True
"""Whether to keep the model loaded in RAM"""
rope_freq_scale: float = 1.0
"""Scale factor for rope sampling."""
rope_freq_base: float = 10000.0
"""Base frequency for rope sampling."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Any additional parameters to pass to llama_cpp.Llama."""
streaming: bool = True
"""Whether to stream the results, token by token."""
grammar_path: Optional[Union[str, Path]] = None
"""
grammar_path: Path to the .gbnf file that defines formal grammars
for constraining model outputs. For instance, the grammar can be used
to force the model to generate valid JSON or to speak exclusively in emojis. At most
one of grammar_path and grammar should be passed in.
"""
grammar: Any = None
"""
grammar: formal grammar for constraining model outputs. For instance, the grammar
can be used to force the model to generate valid JSON or to speak exclusively in
emojis. At most one of grammar_path and grammar should be passed in.
"""
verbose: bool = True
"""Print verbose output to stderr."""
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that llama-cpp-python library is installed."""
try:
from llama_cpp import Llama, LlamaGrammar
except ImportError:
raise ImportError(
"Could not import llama-cpp-python library. "
"Please install the llama-cpp-python library to "
"use this embedding model: pip install llama-cpp-python"
)
model_path = self.model_path
model_param_names = [
"rope_freq_scale",
"rope_freq_base",
"lora_path",
"lora_base",
"n_ctx",
"n_parts",
"seed",
"f16_kv",
"logits_all",
"vocab_only",
"use_mlock",
"n_threads",
"n_batch",
"use_mmap",
"last_n_tokens_size",
"verbose",
]
model_params = {k: getattr(self, k) for k in model_param_names}
# For backwards compatibility, only include if non-null.
if self.n_gpu_layers is not None:
model_params["n_gpu_layers"] = self.n_gpu_layers
model_params.update(self.model_kwargs)
try:
self.client = Llama(model_path, **model_params)
except Exception as e:
raise ValueError(
f"Could not load Llama model from path: {model_path}. "
f"Received error {e}"
)
if self.grammar and self.grammar_path:
grammar = self.grammar
grammar_path = self.grammar_path
raise ValueError(
"Can only pass in one of grammar and grammar_path. Received "
f"{grammar=} and {grammar_path=}."
)
elif isinstance(self.grammar, str):
self.grammar = LlamaGrammar.from_string(self.grammar)
elif self.grammar_path:
self.grammar = LlamaGrammar.from_file(self.grammar_path)
else:
pass
return self
def _get_parameters(self, stop: Optional[List[str]]) -> Dict[str, Any]:
"""
Performs sanity check, preparing parameters in format needed by llama_cpp.
Returns:
Dictionary containing the combined parameters.
"""
params = self._default_params
# llama_cpp expects the "stop" key not this, so we remove it:
stop_sequences = params.pop("stop_sequences")
# then sets it as configured, or default to an empty list:
params["stop"] = stop or stop_sequences or self.stop or []
return params
def _create_message_dicts(
self, messages: List[BaseMessage]
) -> List[Dict[str, Any]]:
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts
def _create_chat_result(self, response: dict) -> ChatResult:
generations = []
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
generation_info = dict(finish_reason=res.get("finish_reason"))
if "logprobs" in res:
generation_info["logprobs"] = res["logprobs"]
gen = ChatGeneration(message=message, generation_info=generation_info)
generations.append(gen)
token_usage = response.get("usage", {})
llm_output = {
"token_usage": token_usage,
# "system_fingerprint": response.get("system_fingerprint", ""),
}
return ChatResult(generations=generations, llm_output=llm_output)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = {**self._get_parameters(stop), **kwargs}
# Check tool_choice is whether available, if yes then run no stream with tool
# calling
if self.streaming and not params.get("tool_choice"):
stream_iter = self._stream(messages, run_manager=run_manager, **kwargs)
return generate_from_stream(stream_iter)
message_dicts = self._create_message_dicts(messages)
response = self.client.create_chat_completion(messages=message_dicts, **params)
return self._create_chat_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = {**self._get_parameters(stop), **kwargs}
message_dicts = self._create_message_dicts(messages)
result = self.client.create_chat_completion(
messages=message_dicts, stream=True, **params
)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
count = 0
for chunk in result:
count += 1
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
if choice["delta"] is None:
continue
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info or None
)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
yield chunk
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[Union[dict, bool, str]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
"""Bind tool-like objects to this chat model
tool_choice: does not currently support "any", "auto" choices like OpenAI
tool-calling API. should be a dict of the form to force this tool
{"type": "function", "function": {"name": <<tool_name>>}}.
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
tool_names = [ft["function"]["name"] for ft in formatted_tools]
if tool_choice:
if isinstance(tool_choice, dict):
if not any(
tool_choice["function"]["name"] == name for name in tool_names
):
raise ValueError(
f"Tool choice {tool_choice=} was specified, but the only "
f"provided tools were {tool_names}."
)
elif isinstance(tool_choice, str):
chosen = [
f for f in formatted_tools if f["function"]["name"] == tool_choice
]
if not chosen:
raise ValueError(
f"Tool choice {tool_choice=} was specified, but the only "
f"provided tools were {tool_names}."
)
elif isinstance(tool_choice, bool):
if len(formatted_tools) > 1:
raise ValueError(
"tool_choice=True can only be specified when a single tool is "
f"passed in. Received {len(tools)} tools."
)
tool_choice = formatted_tools[0]
else:
raise ValueError(
"""Unrecognized tool_choice type. Expected dict having format like
this {"type": "function", "function": {"name": <<tool_name>>}}"""
f"Received: {tool_choice}"
)
kwargs["tool_choice"] = tool_choice
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
def with_structured_output(
self,
schema: Optional[Union[Dict, Type[BaseModel]]] = None,
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
then the model output will be an object of that class. If a dict then
the model output will be a dict. With a Pydantic class the returned
attributes will be validated, whereas with a dict they will not be. If
`method` is "function_calling" and `schema` is a dict, then the dict
must match the OpenAI function-calling spec or be a valid JSON schema
with top level 'title' and 'description' keys specified.
include_raw: If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
kwargs: Any other args to bind to model, ``self.bind(..., **kwargs)``.
Returns:
A Runnable that takes any ChatModel input and returns as output:
If include_raw is True then a dict with keys:
raw: BaseMessage
parsed: Optional[_DictOrPydantic]
parsing_error: Optional[BaseException]
If include_raw is False then just _DictOrPydantic is returned,
where _DictOrPydantic depends on the schema:
If schema is a Pydantic class then _DictOrPydantic is the Pydantic
class.
If schema is a dict then _DictOrPydantic is a dict.
Example: Pydantic schema (include_raw=False):
.. code-block:: python
from langchain_community.chat_models import ChatLlamaCpp
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
llm = ChatLlamaCpp(
temperature=0.,
model_path="./SanctumAI-meta-llama-3-8b-instruct.Q8_0.gguf",
n_ctx=10000,
n_gpu_layers=4,
n_batch=200,
max_tokens=512,
n_threads=multiprocessing.cpu_count() - 1,
repeat_penalty=1.5,
top_p=0.5,
stop=["<|end_of_text|>", "<|eot_id|>"],
)
structured_llm = llm.with_structured_output(AnswerWithJustification)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> AnswerWithJustification(
# answer='They weigh the same',
# justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
# )
Example: Pydantic schema (include_raw=True):
.. code-block:: python
from langchain_community.chat_models import ChatLlamaCpp
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
llm = ChatLlamaCpp(
temperature=0.,
model_path="./SanctumAI-meta-llama-3-8b-instruct.Q8_0.gguf",
n_ctx=10000,
n_gpu_layers=4,
n_batch=200,
max_tokens=512,
n_threads=multiprocessing.cpu_count() - 1,
repeat_penalty=1.5,
top_p=0.5,
stop=["<|end_of_text|>", "<|eot_id|>"],
)
structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> {
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
# 'parsing_error': None
# }
Example: dict schema (include_raw=False):
.. code-block:: python
from langchain_community.chat_models import ChatLlamaCpp
from pydantic import BaseModel
from langchain_core.utils.function_calling import convert_to_openai_tool
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
dict_schema = convert_to_openai_tool(AnswerWithJustification)
llm = ChatLlamaCpp(
temperature=0.,
model_path="./SanctumAI-meta-llama-3-8b-instruct.Q8_0.gguf",
n_ctx=10000,
n_gpu_layers=4,
n_batch=200,
max_tokens=512,
n_threads=multiprocessing.cpu_count() - 1,
repeat_penalty=1.5,
top_p=0.5,
stop=["<|end_of_text|>", "<|eot_id|>"],
)
structured_llm = llm.with_structured_output(dict_schema)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> {
# 'answer': 'They weigh the same',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
# }
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
tool_choice = {"type": "function", "function": {"name": tool_name}}
llm = self.bind_tools([schema], tool_choice=tool_choice)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[cast(Type, schema)], first_tool_only=True
)
else:
output_parser = JsonOutputKeyToolsParser(
key_name=tool_name, first_tool_only=True
)
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Return a dictionary of identifying parameters.
This information is used by the LangChain callback system, which
is used for tracing purposes make it possible to monitor LLMs.
"""
return {
# The model name allows users to specify custom token counting
# rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor
# costs for the given LLM.)
**{"model_path": self.model_path},
**self._default_params,
}
@property
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model."""
return "llama-cpp-python"
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling create_chat_completion."""
params: Dict = {
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"logprobs": self.logprobs,
"stop_sequences": self.stop, # key here is convention among LLM classes
"repeat_penalty": self.repeat_penalty,
}
if self.grammar:
params["grammar"] = self.grammar
return params
def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
return {
"type": "function",
"id": tool_call["id"],
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
},
}
def _lc_invalid_tool_call_to_openai_tool_call(
invalid_tool_call: InvalidToolCall,
) -> dict:
return {
"type": "function",
"id": invalid_tool_call["id"],
"function": {
"name": invalid_tool_call["name"],
"arguments": invalid_tool_call["args"],
},
}
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dictionary to a LangChain message.
Args:
_dict: The dictionary.
Returns:
The LangChain message.
"""
role = _dict.get("role")
name = _dict.get("name")
id_ = _dict.get("id")
if role == "user":
return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
elif role == "assistant":
# Fix for azure
# Also OpenAI returns None for tool invocations
content = _dict.get("content", "") or ""
additional_kwargs: Dict = {}
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tc = parse_tool_call(raw_tool_call, return_id=True)
except Exception as e:
invalid_tc = make_invalid_tool_call(raw_tool_call, str(e))
invalid_tool_calls.append(invalid_tc)
else:
if not tc:
continue
else:
tool_calls.append(tc)
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
name=name,
id=id_,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
elif role == "function":
return FunctionMessage(
content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_
)
elif role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
return ToolMessage(
content=_dict.get("content", ""),
tool_call_id=cast(str, _dict.get("tool_call_id")),
additional_kwargs=additional_kwargs,
name=name,
id=id_,
)
else:
return ChatMessage(
content=_dict.get("content", ""), role=cast(str, role), id=id_
)
def _format_message_content(content: Any) -> Any:
"""Format message content."""
if content and isinstance(content, list):
# Remove unexpected block types
formatted_content = []
for block in content:
if (
isinstance(block, dict)
and "type" in block
and block["type"] == "tool_use"
):
continue
else:
formatted_content.append(block)
else:
formatted_content = content
return formatted_content
def _convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a LangChain message to a dictionary.
Args:
message: The LangChain message.
Returns:
The dictionary.
"""
message_dict: Dict[str, Any] = {
"content": _format_message_content(message.content),
}
if (name := message.name or message.additional_kwargs.get("name")) is not None:
message_dict["name"] = name
# populate role and additional message data
if isinstance(message, ChatMessage):
message_dict["role"] = message.role
elif isinstance(message, HumanMessage):
message_dict["role"] = "user"
elif isinstance(message, AIMessage):
message_dict["role"] = "assistant"
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
if message.tool_calls or message.invalid_tool_calls:
message_dict["tool_calls"] = [
_lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
] + [
_lc_invalid_tool_call_to_openai_tool_call(tc)
for tc in message.invalid_tool_calls
]
elif "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
tool_call_supported_props = {"id", "type", "function"}
message_dict["tool_calls"] = [
{k: v for k, v in tool_call.items() if k in tool_call_supported_props}
for tool_call in message_dict["tool_calls"]
]
else:
pass
# If tool calls present, content null value should be None not empty string.
if "function_call" in message_dict or "tool_calls" in message_dict:
message_dict["content"] = message_dict["content"] or None
elif isinstance(message, SystemMessage):
message_dict["role"] = "system"
elif isinstance(message, FunctionMessage):
message_dict["role"] = "function"
elif isinstance(message, ToolMessage):
message_dict["role"] = "tool"
message_dict["tool_call_id"] = message.tool_call_id
supported_props = {"content", "role", "tool_call_id"}
message_dict = {k: v for k, v in message_dict.items() if k in supported_props}
else:
raise TypeError(f"Got unknown type {message}")
return message_dict
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
id_ = _dict.get("id")
role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {}
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
function_call["name"] = ""
additional_kwargs["function_call"] = function_call
tool_call_chunks = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for rtc in raw_tool_calls:
try:
tool_call = ToolCallChunk(
name=rtc["function"].get("name"),
args=rtc["function"].get("arguments"),
id=rtc.get("id"),
index=rtc["index"],
)
tool_call_chunks.append(tool_call)
except KeyError:
pass
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content, id=id_)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
id=id_,
tool_call_chunks=tool_call_chunks,
)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content, id=id_)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
elif role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(
content=content, tool_call_id=_dict["tool_call_id"], id=id_
)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role, id=id_)
else:
return default_class(content=content, id=id_) # type: ignore[call-arg]

Some files were not shown because too many files have changed in this diff Show More