initial commit
This commit is contained in:
@@ -0,0 +1,335 @@
|
||||
"""**Chat Models** are a variation on language models.
|
||||
|
||||
While Chat Models use language models under the hood, the interface they expose
|
||||
is a bit different. Rather than expose a "text in, text out" API, they expose
|
||||
an interface where "chat messages" are the inputs and outputs.
|
||||
|
||||
**Class hierarchy:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
BaseLanguageModel --> BaseChatModel --> <name> # Examples: ChatOpenAI, ChatGooglePalm
|
||||
|
||||
**Main helpers:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
AIMessage, BaseMessage, HumanMessage
|
||||
""" # noqa: E501
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_models.anthropic import (
|
||||
ChatAnthropic,
|
||||
)
|
||||
from langchain_community.chat_models.anyscale import (
|
||||
ChatAnyscale,
|
||||
)
|
||||
from langchain_community.chat_models.azure_openai import (
|
||||
AzureChatOpenAI,
|
||||
)
|
||||
from langchain_community.chat_models.baichuan import (
|
||||
ChatBaichuan,
|
||||
)
|
||||
from langchain_community.chat_models.baidu_qianfan_endpoint import (
|
||||
QianfanChatEndpoint,
|
||||
)
|
||||
from langchain_community.chat_models.bedrock import (
|
||||
BedrockChat,
|
||||
)
|
||||
from langchain_community.chat_models.cohere import (
|
||||
ChatCohere,
|
||||
)
|
||||
from langchain_community.chat_models.coze import (
|
||||
ChatCoze,
|
||||
)
|
||||
from langchain_community.chat_models.databricks import (
|
||||
ChatDatabricks,
|
||||
)
|
||||
from langchain_community.chat_models.deepinfra import (
|
||||
ChatDeepInfra,
|
||||
)
|
||||
from langchain_community.chat_models.edenai import ChatEdenAI
|
||||
from langchain_community.chat_models.ernie import (
|
||||
ErnieBotChat,
|
||||
)
|
||||
from langchain_community.chat_models.everlyai import (
|
||||
ChatEverlyAI,
|
||||
)
|
||||
from langchain_community.chat_models.fake import (
|
||||
FakeListChatModel,
|
||||
)
|
||||
from langchain_community.chat_models.fireworks import (
|
||||
ChatFireworks,
|
||||
)
|
||||
from langchain_community.chat_models.friendli import (
|
||||
ChatFriendli,
|
||||
)
|
||||
from langchain_community.chat_models.gigachat import (
|
||||
GigaChat,
|
||||
)
|
||||
from langchain_community.chat_models.google_palm import (
|
||||
ChatGooglePalm,
|
||||
)
|
||||
from langchain_community.chat_models.gpt_router import (
|
||||
GPTRouter,
|
||||
)
|
||||
from langchain_community.chat_models.huggingface import (
|
||||
ChatHuggingFace,
|
||||
)
|
||||
from langchain_community.chat_models.human import (
|
||||
HumanInputChatModel,
|
||||
)
|
||||
from langchain_community.chat_models.hunyuan import (
|
||||
ChatHunyuan,
|
||||
)
|
||||
from langchain_community.chat_models.javelin_ai_gateway import (
|
||||
ChatJavelinAIGateway,
|
||||
)
|
||||
from langchain_community.chat_models.jinachat import (
|
||||
JinaChat,
|
||||
)
|
||||
from langchain_community.chat_models.kinetica import (
|
||||
ChatKinetica,
|
||||
)
|
||||
from langchain_community.chat_models.konko import (
|
||||
ChatKonko,
|
||||
)
|
||||
from langchain_community.chat_models.litellm import (
|
||||
ChatLiteLLM,
|
||||
)
|
||||
from langchain_community.chat_models.litellm_router import (
|
||||
ChatLiteLLMRouter,
|
||||
)
|
||||
from langchain_community.chat_models.llama_edge import (
|
||||
LlamaEdgeChatService,
|
||||
)
|
||||
from langchain_community.chat_models.llamacpp import ChatLlamaCpp
|
||||
from langchain_community.chat_models.maritalk import (
|
||||
ChatMaritalk,
|
||||
)
|
||||
from langchain_community.chat_models.minimax import (
|
||||
MiniMaxChat,
|
||||
)
|
||||
from langchain_community.chat_models.mlflow import (
|
||||
ChatMlflow,
|
||||
)
|
||||
from langchain_community.chat_models.mlflow_ai_gateway import (
|
||||
ChatMLflowAIGateway,
|
||||
)
|
||||
from langchain_community.chat_models.mlx import (
|
||||
ChatMLX,
|
||||
)
|
||||
from langchain_community.chat_models.moonshot import (
|
||||
MoonshotChat,
|
||||
)
|
||||
from langchain_community.chat_models.naver import (
|
||||
ChatClovaX,
|
||||
)
|
||||
from langchain_community.chat_models.oci_data_science import (
|
||||
ChatOCIModelDeployment,
|
||||
ChatOCIModelDeploymentTGI,
|
||||
ChatOCIModelDeploymentVLLM,
|
||||
)
|
||||
from langchain_community.chat_models.oci_generative_ai import (
|
||||
ChatOCIGenAI, # noqa: F401
|
||||
)
|
||||
from langchain_community.chat_models.octoai import ChatOctoAI
|
||||
from langchain_community.chat_models.ollama import (
|
||||
ChatOllama,
|
||||
)
|
||||
from langchain_community.chat_models.openai import (
|
||||
ChatOpenAI,
|
||||
)
|
||||
from langchain_community.chat_models.outlines import ChatOutlines
|
||||
from langchain_community.chat_models.pai_eas_endpoint import (
|
||||
PaiEasChatEndpoint,
|
||||
)
|
||||
from langchain_community.chat_models.perplexity import (
|
||||
ChatPerplexity,
|
||||
)
|
||||
from langchain_community.chat_models.premai import (
|
||||
ChatPremAI,
|
||||
)
|
||||
from langchain_community.chat_models.promptlayer_openai import (
|
||||
PromptLayerChatOpenAI,
|
||||
)
|
||||
from langchain_community.chat_models.reka import (
|
||||
ChatReka,
|
||||
)
|
||||
from langchain_community.chat_models.sambanova import (
|
||||
ChatSambaNovaCloud,
|
||||
ChatSambaStudio,
|
||||
)
|
||||
from langchain_community.chat_models.snowflake import (
|
||||
ChatSnowflakeCortex,
|
||||
)
|
||||
from langchain_community.chat_models.solar import (
|
||||
SolarChat,
|
||||
)
|
||||
from langchain_community.chat_models.sparkllm import (
|
||||
ChatSparkLLM,
|
||||
)
|
||||
from langchain_community.chat_models.symblai_nebula import ChatNebula
|
||||
from langchain_community.chat_models.tongyi import (
|
||||
ChatTongyi,
|
||||
)
|
||||
from langchain_community.chat_models.vertexai import (
|
||||
ChatVertexAI,
|
||||
)
|
||||
from langchain_community.chat_models.volcengine_maas import (
|
||||
VolcEngineMaasChat,
|
||||
)
|
||||
from langchain_community.chat_models.yandex import (
|
||||
ChatYandexGPT,
|
||||
)
|
||||
from langchain_community.chat_models.yi import (
|
||||
ChatYi,
|
||||
)
|
||||
from langchain_community.chat_models.yuan2 import (
|
||||
ChatYuan2,
|
||||
)
|
||||
from langchain_community.chat_models.zhipuai import (
|
||||
ChatZhipuAI,
|
||||
)
|
||||
__all__ = [
|
||||
"AzureChatOpenAI",
|
||||
"BedrockChat",
|
||||
"ChatAnthropic",
|
||||
"ChatAnyscale",
|
||||
"ChatBaichuan",
|
||||
"ChatClovaX",
|
||||
"ChatCohere",
|
||||
"ChatCoze",
|
||||
"ChatOctoAI",
|
||||
"ChatDatabricks",
|
||||
"ChatDeepInfra",
|
||||
"ChatEdenAI",
|
||||
"ChatEverlyAI",
|
||||
"ChatFireworks",
|
||||
"ChatFriendli",
|
||||
"ChatGooglePalm",
|
||||
"ChatHuggingFace",
|
||||
"ChatHunyuan",
|
||||
"ChatJavelinAIGateway",
|
||||
"ChatKinetica",
|
||||
"ChatKonko",
|
||||
"ChatLiteLLM",
|
||||
"ChatLiteLLMRouter",
|
||||
"ChatMLX",
|
||||
"ChatMLflowAIGateway",
|
||||
"ChatMaritalk",
|
||||
"ChatMlflow",
|
||||
"ChatNebula",
|
||||
"ChatOCIGenAI",
|
||||
"ChatOCIModelDeployment",
|
||||
"ChatOCIModelDeploymentVLLM",
|
||||
"ChatOCIModelDeploymentTGI",
|
||||
"ChatOllama",
|
||||
"ChatOpenAI",
|
||||
"ChatOutlines",
|
||||
"ChatPerplexity",
|
||||
"ChatReka",
|
||||
"ChatPremAI",
|
||||
"ChatSambaNovaCloud",
|
||||
"ChatSambaStudio",
|
||||
"ChatSparkLLM",
|
||||
"ChatSnowflakeCortex",
|
||||
"ChatTongyi",
|
||||
"ChatVertexAI",
|
||||
"ChatYandexGPT",
|
||||
"ChatYuan2",
|
||||
"ChatZhipuAI",
|
||||
"ChatLlamaCpp",
|
||||
"ErnieBotChat",
|
||||
"FakeListChatModel",
|
||||
"GPTRouter",
|
||||
"GigaChat",
|
||||
"HumanInputChatModel",
|
||||
"JinaChat",
|
||||
"LlamaEdgeChatService",
|
||||
"MiniMaxChat",
|
||||
"MoonshotChat",
|
||||
"PaiEasChatEndpoint",
|
||||
"PromptLayerChatOpenAI",
|
||||
"QianfanChatEndpoint",
|
||||
"SolarChat",
|
||||
"VolcEngineMaasChat",
|
||||
"ChatYi",
|
||||
]
|
||||
|
||||
|
||||
_module_lookup = {
|
||||
"AzureChatOpenAI": "langchain_community.chat_models.azure_openai",
|
||||
"BedrockChat": "langchain_community.chat_models.bedrock",
|
||||
"ChatAnthropic": "langchain_community.chat_models.anthropic",
|
||||
"ChatAnyscale": "langchain_community.chat_models.anyscale",
|
||||
"ChatBaichuan": "langchain_community.chat_models.baichuan",
|
||||
"ChatClovaX": "langchain_community.chat_models.naver",
|
||||
"ChatCohere": "langchain_community.chat_models.cohere",
|
||||
"ChatCoze": "langchain_community.chat_models.coze",
|
||||
"ChatDatabricks": "langchain_community.chat_models.databricks",
|
||||
"ChatDeepInfra": "langchain_community.chat_models.deepinfra",
|
||||
"ChatEverlyAI": "langchain_community.chat_models.everlyai",
|
||||
"ChatEdenAI": "langchain_community.chat_models.edenai",
|
||||
"ChatFireworks": "langchain_community.chat_models.fireworks",
|
||||
"ChatFriendli": "langchain_community.chat_models.friendli",
|
||||
"ChatGooglePalm": "langchain_community.chat_models.google_palm",
|
||||
"ChatHuggingFace": "langchain_community.chat_models.huggingface",
|
||||
"ChatHunyuan": "langchain_community.chat_models.hunyuan",
|
||||
"ChatJavelinAIGateway": "langchain_community.chat_models.javelin_ai_gateway",
|
||||
"ChatKinetica": "langchain_community.chat_models.kinetica",
|
||||
"ChatKonko": "langchain_community.chat_models.konko",
|
||||
"ChatLiteLLM": "langchain_community.chat_models.litellm",
|
||||
"ChatLiteLLMRouter": "langchain_community.chat_models.litellm_router",
|
||||
"ChatMLflowAIGateway": "langchain_community.chat_models.mlflow_ai_gateway",
|
||||
"ChatMLX": "langchain_community.chat_models.mlx",
|
||||
"ChatMaritalk": "langchain_community.chat_models.maritalk",
|
||||
"ChatMlflow": "langchain_community.chat_models.mlflow",
|
||||
"ChatNebula": "langchain_community.chat_models.symblai_nebula",
|
||||
"ChatOctoAI": "langchain_community.chat_models.octoai",
|
||||
"ChatOCIGenAI": "langchain_community.chat_models.oci_generative_ai",
|
||||
"ChatOCIModelDeployment": "langchain_community.chat_models.oci_data_science",
|
||||
"ChatOCIModelDeploymentVLLM": "langchain_community.chat_models.oci_data_science",
|
||||
"ChatOCIModelDeploymentTGI": "langchain_community.chat_models.oci_data_science",
|
||||
"ChatOllama": "langchain_community.chat_models.ollama",
|
||||
"ChatOpenAI": "langchain_community.chat_models.openai",
|
||||
"ChatOutlines": "langchain_community.chat_models.outlines",
|
||||
"ChatReka": "langchain_community.chat_models.reka",
|
||||
"ChatPerplexity": "langchain_community.chat_models.perplexity",
|
||||
"ChatSambaNovaCloud": "langchain_community.chat_models.sambanova",
|
||||
"ChatSambaStudio": "langchain_community.chat_models.sambanova",
|
||||
"ChatSnowflakeCortex": "langchain_community.chat_models.snowflake",
|
||||
"ChatSparkLLM": "langchain_community.chat_models.sparkllm",
|
||||
"ChatTongyi": "langchain_community.chat_models.tongyi",
|
||||
"ChatVertexAI": "langchain_community.chat_models.vertexai",
|
||||
"ChatYandexGPT": "langchain_community.chat_models.yandex",
|
||||
"ChatYuan2": "langchain_community.chat_models.yuan2",
|
||||
"ChatZhipuAI": "langchain_community.chat_models.zhipuai",
|
||||
"ErnieBotChat": "langchain_community.chat_models.ernie",
|
||||
"FakeListChatModel": "langchain_community.chat_models.fake",
|
||||
"GPTRouter": "langchain_community.chat_models.gpt_router",
|
||||
"GigaChat": "langchain_community.chat_models.gigachat",
|
||||
"HumanInputChatModel": "langchain_community.chat_models.human",
|
||||
"JinaChat": "langchain_community.chat_models.jinachat",
|
||||
"LlamaEdgeChatService": "langchain_community.chat_models.llama_edge",
|
||||
"MiniMaxChat": "langchain_community.chat_models.minimax",
|
||||
"MoonshotChat": "langchain_community.chat_models.moonshot",
|
||||
"PaiEasChatEndpoint": "langchain_community.chat_models.pai_eas_endpoint",
|
||||
"PromptLayerChatOpenAI": "langchain_community.chat_models.promptlayer_openai",
|
||||
"SolarChat": "langchain_community.chat_models.solar",
|
||||
"QianfanChatEndpoint": "langchain_community.chat_models.baidu_qianfan_endpoint",
|
||||
"VolcEngineMaasChat": "langchain_community.chat_models.volcengine_maas",
|
||||
"ChatPremAI": "langchain_community.chat_models.premai",
|
||||
"ChatLlamaCpp": "langchain_community.chat_models.llamacpp",
|
||||
"ChatYi": "langchain_community.chat_models.yi",
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in _module_lookup:
|
||||
module = importlib.import_module(_module_lookup[name])
|
||||
return getattr(module, name)
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,234 @@
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_community.llms.anthropic import _AnthropicCommon
|
||||
|
||||
|
||||
def _convert_one_message_to_text(
|
||||
message: BaseMessage,
|
||||
human_prompt: str,
|
||||
ai_prompt: str,
|
||||
) -> str:
|
||||
content = cast(str, message.content)
|
||||
if isinstance(message, ChatMessage):
|
||||
message_text = f"\n\n{message.role.capitalize()}: {content}"
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
elif isinstance(message, AIMessage):
|
||||
message_text = f"{ai_prompt} {content}"
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_text = content
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_text
|
||||
|
||||
|
||||
def convert_messages_to_prompt_anthropic(
|
||||
messages: List[BaseMessage],
|
||||
*,
|
||||
human_prompt: str = "\n\nHuman:",
|
||||
ai_prompt: str = "\n\nAssistant:",
|
||||
) -> str:
|
||||
"""Format a list of messages into a full prompt for the Anthropic model
|
||||
Args:
|
||||
messages (List[BaseMessage]): List of BaseMessage to combine.
|
||||
human_prompt (str, optional): Human prompt tag. Defaults to "\n\nHuman:".
|
||||
ai_prompt (str, optional): AI prompt tag. Defaults to "\n\nAssistant:".
|
||||
Returns:
|
||||
str: Combined string with necessary human_prompt and ai_prompt tags.
|
||||
"""
|
||||
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
if not isinstance(messages[-1], AIMessage):
|
||||
messages.append(AIMessage(content=""))
|
||||
|
||||
text = "".join(
|
||||
_convert_one_message_to_text(message, human_prompt, ai_prompt)
|
||||
for message in messages
|
||||
)
|
||||
|
||||
# trim off the trailing ' ' that might come from the "Assistant: "
|
||||
return text.rstrip()
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.0.28",
|
||||
removal="1.0",
|
||||
alternative_import="langchain_anthropic.ChatAnthropic",
|
||||
)
|
||||
class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
||||
"""`Anthropic` chat large language models.
|
||||
|
||||
To use, you should have the ``anthropic`` python package installed, and the
|
||||
environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
|
||||
it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
import anthropic
|
||||
from langchain_community.chat_models import ChatAnthropic
|
||||
model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"anthropic_api_key": "ANTHROPIC_API_KEY"}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "anthropic-chat"
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this model can be serialized by Langchain."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "anthropic"]
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
|
||||
"""Format a list of messages into a full prompt for the Anthropic model
|
||||
Args:
|
||||
messages (List[BaseMessage]): List of BaseMessage to combine.
|
||||
Returns:
|
||||
str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
|
||||
"""
|
||||
prompt_params = {}
|
||||
if self.HUMAN_PROMPT:
|
||||
prompt_params["human_prompt"] = self.HUMAN_PROMPT
|
||||
if self.AI_PROMPT:
|
||||
prompt_params["ai_prompt"] = self.AI_PROMPT
|
||||
return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)
|
||||
|
||||
def convert_prompt(self, prompt: PromptValue) -> str:
|
||||
return self._convert_messages_to_prompt(prompt.to_messages())
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
prompt = self._convert_messages_to_prompt(messages)
|
||||
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
|
||||
if stop:
|
||||
params["stop_sequences"] = stop
|
||||
|
||||
stream_resp = self.client.completions.create(**params, stream=True)
|
||||
for data in stream_resp:
|
||||
delta = data.completion
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(delta, chunk=chunk)
|
||||
yield chunk
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
prompt = self._convert_messages_to_prompt(messages)
|
||||
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
|
||||
if stop:
|
||||
params["stop_sequences"] = stop
|
||||
|
||||
stream_resp = await self.async_client.completions.create(**params, stream=True)
|
||||
async for data in stream_resp:
|
||||
delta = data.completion
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(delta, chunk=chunk)
|
||||
yield chunk
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
prompt = self._convert_messages_to_prompt(
|
||||
messages,
|
||||
)
|
||||
params: Dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
**self._default_params,
|
||||
**kwargs,
|
||||
}
|
||||
if stop:
|
||||
params["stop_sequences"] = stop
|
||||
response = self.client.completions.create(**params)
|
||||
completion = response.completion
|
||||
message = AIMessage(content=completion)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
prompt = self._convert_messages_to_prompt(
|
||||
messages,
|
||||
)
|
||||
params: Dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
**self._default_params,
|
||||
**kwargs,
|
||||
}
|
||||
if stop:
|
||||
params["stop_sequences"] = stop
|
||||
response = await self.async_client.completions.create(**params)
|
||||
completion = response.completion
|
||||
message = AIMessage(content=completion)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Calculate number of tokens."""
|
||||
if not self.count_tokens:
|
||||
raise NameError("Please ensure the anthropic package is loaded")
|
||||
return self.count_tokens(text)
|
||||
@@ -0,0 +1,243 @@
|
||||
"""Anyscale Endpoints chat wrapper. Relies heavily on ChatOpenAI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import requests
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from pydantic import Field, SecretStr, model_validator
|
||||
|
||||
from langchain_community.adapters.openai import convert_message_to_dict
|
||||
from langchain_community.chat_models.openai import (
|
||||
ChatOpenAI,
|
||||
_import_tiktoken,
|
||||
)
|
||||
from langchain_community.utils.openai import is_openai_v1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import tiktoken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_API_BASE = "https://api.endpoints.anyscale.com/v1"
|
||||
DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
|
||||
|
||||
class ChatAnyscale(ChatOpenAI):
|
||||
"""`Anyscale` Chat large language models.
|
||||
|
||||
See https://www.anyscale.com/ for information about Anyscale.
|
||||
|
||||
To use, you should have the ``openai`` python package installed, and the
|
||||
environment variable ``ANYSCALE_API_KEY`` set with your API key.
|
||||
Alternatively, you can use the anyscale_api_key keyword argument.
|
||||
|
||||
Any parameters that are valid to be passed to the `openai.create` call can be passed
|
||||
in, even if not explicitly saved on this class.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatAnyscale
|
||||
chat = ChatAnyscale(model_name="meta-llama/Llama-2-7b-chat-hf")
|
||||
"""
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "anyscale-chat"
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"anyscale_api_key": "ANYSCALE_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return False
|
||||
|
||||
anyscale_api_key: SecretStr = Field(default=SecretStr(""))
|
||||
"""AnyScale Endpoints API keys."""
|
||||
model_name: str = Field(default=DEFAULT_MODEL, alias="model")
|
||||
"""Model name to use."""
|
||||
anyscale_api_base: str = Field(default=DEFAULT_API_BASE)
|
||||
"""Base URL path for API requests,
|
||||
leave blank if not using a proxy or service emulator."""
|
||||
anyscale_proxy: Optional[str] = None
|
||||
"""To support explicit proxy for Anyscale."""
|
||||
available_models: Optional[Set[str]] = None
|
||||
"""Available models from Anyscale API."""
|
||||
|
||||
@staticmethod
|
||||
def get_available_models(
|
||||
anyscale_api_key: Optional[str] = None,
|
||||
anyscale_api_base: str = DEFAULT_API_BASE,
|
||||
) -> Set[str]:
|
||||
"""Get available models from Anyscale API."""
|
||||
try:
|
||||
anyscale_api_key = anyscale_api_key or os.environ["ANYSCALE_API_KEY"]
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
"Anyscale API key must be passed as keyword argument or "
|
||||
"set in environment variable ANYSCALE_API_KEY.",
|
||||
) from e
|
||||
|
||||
models_url = f"{anyscale_api_base}/models"
|
||||
models_response = requests.get(
|
||||
models_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {anyscale_api_key}",
|
||||
},
|
||||
)
|
||||
|
||||
if models_response.status_code != 200:
|
||||
raise ValueError(
|
||||
f"Error getting models from {models_url}: "
|
||||
f"{models_response.status_code}",
|
||||
)
|
||||
|
||||
return {model["id"] for model in models_response.json()["data"]}
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: dict) -> Any:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["anyscale_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
"anyscale_api_key",
|
||||
"ANYSCALE_API_KEY",
|
||||
)
|
||||
)
|
||||
values["anyscale_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"anyscale_api_base",
|
||||
"ANYSCALE_API_BASE",
|
||||
default=DEFAULT_API_BASE,
|
||||
)
|
||||
values["openai_proxy"] = get_from_dict_or_env(
|
||||
values,
|
||||
"anyscale_proxy",
|
||||
"ANYSCALE_PROXY",
|
||||
default="",
|
||||
)
|
||||
try:
|
||||
import openai
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`.",
|
||||
) from e
|
||||
try:
|
||||
if is_openai_v1():
|
||||
client_params = {
|
||||
"api_key": values["anyscale_api_key"].get_secret_value(),
|
||||
"base_url": values["anyscale_api_base"],
|
||||
# To do: future support
|
||||
# "organization": values["openai_organization"],
|
||||
# "timeout": values["request_timeout"],
|
||||
# "max_retries": values["max_retries"],
|
||||
# "default_headers": values["default_headers"],
|
||||
# "default_query": values["default_query"],
|
||||
# "http_client": values["http_client"],
|
||||
}
|
||||
if not values.get("client"):
|
||||
values["client"] = openai.OpenAI(**client_params).chat.completions
|
||||
if not values.get("async_client"):
|
||||
values["async_client"] = openai.AsyncOpenAI(
|
||||
**client_params
|
||||
).chat.completions
|
||||
else:
|
||||
values["openai_api_base"] = values["anyscale_api_base"]
|
||||
values["openai_api_key"] = values["anyscale_api_key"].get_secret_value()
|
||||
values["client"] = openai.ChatCompletion
|
||||
except AttributeError as exc:
|
||||
raise ValueError(
|
||||
"`openai` has no `ChatCompletion` attribute, this is likely "
|
||||
"due to an old version of the openai package. Try upgrading it "
|
||||
"with `pip install --upgrade openai`.",
|
||||
) from exc
|
||||
|
||||
if "model_name" not in values.keys():
|
||||
values["model_name"] = DEFAULT_MODEL
|
||||
|
||||
model_name = values["model_name"]
|
||||
available_models = cls.get_available_models(
|
||||
values["anyscale_api_key"].get_secret_value(),
|
||||
values["anyscale_api_base"],
|
||||
)
|
||||
|
||||
if model_name not in available_models:
|
||||
raise ValueError(
|
||||
f"Model name {model_name} not found in available models: "
|
||||
f"{available_models}.",
|
||||
)
|
||||
|
||||
values["available_models"] = available_models
|
||||
|
||||
return values
|
||||
|
||||
def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
|
||||
tiktoken_ = _import_tiktoken()
|
||||
if self.tiktoken_model_name is not None:
|
||||
model = self.tiktoken_model_name
|
||||
else:
|
||||
model = self.model_name
|
||||
# Returns the number of tokens used by a list of messages.
|
||||
try:
|
||||
encoding = tiktoken_.encoding_for_model("gpt-3.5-turbo-0301")
|
||||
except KeyError:
|
||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken_.get_encoding(model)
|
||||
return model, encoding
|
||||
|
||||
def get_num_tokens_from_messages(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
tools: Optional[
|
||||
Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
|
||||
] = None,
|
||||
) -> int:
|
||||
"""Calculate num tokens with tiktoken package.
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
|
||||
"""
|
||||
if tools is not None:
|
||||
warnings.warn(
|
||||
"Counting tokens in tool schemas is not yet supported. Ignoring tools."
|
||||
)
|
||||
if sys.version_info[1] <= 7:
|
||||
return super().get_num_tokens_from_messages(messages)
|
||||
model, encoding = self._get_encoding_model()
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
num_tokens = 0
|
||||
messages_dict = [convert_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
# Cast str(value) in case the message value is not a string
|
||||
# This occurs with function messages
|
||||
num_tokens += len(encoding.encode(str(value)))
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
return num_tokens
|
||||
@@ -0,0 +1,293 @@
|
||||
"""Azure OpenAI chat wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Union
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.outputs import ChatResult
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain_community.chat_models.openai import ChatOpenAI
|
||||
from langchain_community.utils.openai import is_openai_v1
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.0.10",
|
||||
removal="1.0",
|
||||
alternative_import="langchain_openai.AzureChatOpenAI",
|
||||
)
|
||||
class AzureChatOpenAI(ChatOpenAI):
|
||||
"""`Azure OpenAI` Chat Completion API.
|
||||
|
||||
To use this class you
|
||||
must have a deployed model on Azure OpenAI. Use `deployment_name` in the
|
||||
constructor to refer to the "Model deployment name" in the Azure portal.
|
||||
|
||||
In addition, you should have the ``openai`` python package installed, and the
|
||||
following environment variables set or passed in constructor in lower case:
|
||||
- ``AZURE_OPENAI_API_KEY``
|
||||
- ``AZURE_OPENAI_ENDPOINT``
|
||||
- ``AZURE_OPENAI_AD_TOKEN``
|
||||
- ``OPENAI_API_VERSION``
|
||||
- ``OPENAI_PROXY``
|
||||
|
||||
For example, if you have `gpt-35-turbo` deployed, with the deployment name
|
||||
`35-turbo-dev`, the constructor should look like:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
AzureChatOpenAI(
|
||||
azure_deployment="35-turbo-dev",
|
||||
openai_api_version="2023-05-15",
|
||||
)
|
||||
|
||||
Be aware the API version may change.
|
||||
|
||||
You can also specify the version of the model using ``model_version`` constructor
|
||||
parameter, as Azure OpenAI doesn't return model version with the response.
|
||||
|
||||
Default is empty. When you specify the version, it will be appended to the
|
||||
model name in the response. Setting correct version will help you to calculate the
|
||||
cost properly. Model version is not validated, so make sure you set it correctly
|
||||
to get the correct cost.
|
||||
|
||||
Any parameters that are valid to be passed to the openai.create call can be passed
|
||||
in, even if not explicitly saved on this class.
|
||||
"""
|
||||
|
||||
azure_endpoint: Union[str, None] = None
|
||||
"""Your Azure endpoint, including the resource.
|
||||
|
||||
Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided.
|
||||
|
||||
Example: `https://example-resource.azure.openai.com/`
|
||||
"""
|
||||
deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment")
|
||||
"""A model deployment.
|
||||
|
||||
If given sets the base client URL to include `/deployments/{azure_deployment}`.
|
||||
Note: this means you won't be able to use non-deployment endpoints.
|
||||
"""
|
||||
openai_api_version: str = Field(default="", alias="api_version")
|
||||
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
|
||||
openai_api_key: Union[str, None] = Field(default=None, alias="api_key")
|
||||
"""Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
|
||||
azure_ad_token: Union[str, None] = None
|
||||
"""Your Azure Active Directory token.
|
||||
|
||||
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
|
||||
|
||||
For more:
|
||||
https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id.
|
||||
"""
|
||||
azure_ad_token_provider: Union[Callable[[], str], None] = None
|
||||
"""A function that returns an Azure Active Directory token.
|
||||
|
||||
Will be invoked on every sync request. For async requests,
|
||||
will be invoked if `azure_ad_async_token_provider` is not provided.
|
||||
"""
|
||||
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
|
||||
"""A function that returns an Azure Active Directory token.
|
||||
|
||||
Will be invoked on every async request.
|
||||
"""
|
||||
model_version: str = ""
|
||||
"""Legacy, for openai<1.0.0 support."""
|
||||
openai_api_type: str = ""
|
||||
"""Legacy, for openai<1.0.0 support."""
|
||||
validate_base_url: bool = True
|
||||
"""For backwards compatibility. If legacy val openai_api_base is passed in, try to
|
||||
infer if it is a base_url or azure_endpoint and update accordingly.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "azure_openai"]
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if values["n"] < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
if values["n"] > 1 and values["streaming"]:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
|
||||
# Check OPENAI_KEY for backwards compatibility.
|
||||
# TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
|
||||
# other forms of azure credentials.
|
||||
values["openai_api_key"] = (
|
||||
values["openai_api_key"]
|
||||
or os.getenv("AZURE_OPENAI_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
|
||||
"OPENAI_API_BASE"
|
||||
)
|
||||
values["openai_api_version"] = values["openai_api_version"] or os.getenv(
|
||||
"OPENAI_API_VERSION"
|
||||
)
|
||||
# Check OPENAI_ORGANIZATION for backwards compatibility.
|
||||
values["openai_organization"] = (
|
||||
values["openai_organization"]
|
||||
or os.getenv("OPENAI_ORG_ID")
|
||||
or os.getenv("OPENAI_ORGANIZATION")
|
||||
)
|
||||
values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
|
||||
"AZURE_OPENAI_ENDPOINT"
|
||||
)
|
||||
values["azure_ad_token"] = values["azure_ad_token"] or os.getenv(
|
||||
"AZURE_OPENAI_AD_TOKEN"
|
||||
)
|
||||
|
||||
values["openai_api_type"] = get_from_dict_or_env(
|
||||
values, "openai_api_type", "OPENAI_API_TYPE", default="azure"
|
||||
)
|
||||
values["openai_proxy"] = get_from_dict_or_env(
|
||||
values, "openai_proxy", "OPENAI_PROXY", default=""
|
||||
)
|
||||
|
||||
try:
|
||||
import openai
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
if is_openai_v1():
|
||||
# For backwards compatibility. Before openai v1, no distinction was made
|
||||
# between azure_endpoint and base_url (openai_api_base).
|
||||
openai_api_base = values["openai_api_base"]
|
||||
if openai_api_base and values["validate_base_url"]:
|
||||
if "/openai" not in openai_api_base:
|
||||
values["openai_api_base"] = (
|
||||
values["openai_api_base"].rstrip("/") + "/openai"
|
||||
)
|
||||
warnings.warn(
|
||||
"As of openai>=1.0.0, Azure endpoints should be specified via "
|
||||
f"the `azure_endpoint` param not `openai_api_base` "
|
||||
f"(or alias `base_url`). Updating `openai_api_base` from "
|
||||
f"{openai_api_base} to {values['openai_api_base']}."
|
||||
)
|
||||
if values["deployment_name"]:
|
||||
warnings.warn(
|
||||
"As of openai>=1.0.0, if `deployment_name` (or alias "
|
||||
"`azure_deployment`) is specified then "
|
||||
"`openai_api_base` (or alias `base_url`) should not be. "
|
||||
"Instead use `deployment_name` (or alias `azure_deployment`) "
|
||||
"and `azure_endpoint`."
|
||||
)
|
||||
if values["deployment_name"] not in values["openai_api_base"]:
|
||||
warnings.warn(
|
||||
"As of openai>=1.0.0, if `openai_api_base` "
|
||||
"(or alias `base_url`) is specified it is expected to be "
|
||||
"of the form "
|
||||
"https://example-resource.azure.openai.com/openai/deployments/example-deployment. " # noqa: E501
|
||||
f"Updating {openai_api_base} to "
|
||||
f"{values['openai_api_base']}."
|
||||
)
|
||||
values["openai_api_base"] += (
|
||||
"/deployments/" + values["deployment_name"]
|
||||
)
|
||||
values["deployment_name"] = None
|
||||
client_params = {
|
||||
"api_version": values["openai_api_version"],
|
||||
"azure_endpoint": values["azure_endpoint"],
|
||||
"azure_deployment": values["deployment_name"],
|
||||
"api_key": values["openai_api_key"],
|
||||
"azure_ad_token": values["azure_ad_token"],
|
||||
"azure_ad_token_provider": values["azure_ad_token_provider"],
|
||||
"organization": values["openai_organization"],
|
||||
"base_url": values["openai_api_base"],
|
||||
"timeout": values["request_timeout"],
|
||||
"max_retries": values["max_retries"],
|
||||
"default_headers": {
|
||||
**(values["default_headers"] or {}),
|
||||
"User-Agent": "langchain-comm-python-azure-openai",
|
||||
},
|
||||
"default_query": values["default_query"],
|
||||
"http_client": values["http_client"],
|
||||
}
|
||||
values["client"] = openai.AzureOpenAI(**client_params).chat.completions
|
||||
|
||||
azure_ad_async_token_provider = values["azure_ad_async_token_provider"]
|
||||
|
||||
if azure_ad_async_token_provider:
|
||||
client_params["azure_ad_token_provider"] = azure_ad_async_token_provider
|
||||
|
||||
values["async_client"] = openai.AsyncAzureOpenAI(
|
||||
**client_params
|
||||
).chat.completions
|
||||
else:
|
||||
values["client"] = openai.ChatCompletion
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
if is_openai_v1():
|
||||
return super()._default_params
|
||||
else:
|
||||
return {
|
||||
**super()._default_params,
|
||||
"engine": self.deployment_name,
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**self._default_params}
|
||||
|
||||
@property
|
||||
def _client_params(self) -> Dict[str, Any]:
|
||||
"""Get the config params used for the openai client."""
|
||||
if is_openai_v1():
|
||||
return super()._client_params
|
||||
else:
|
||||
return {
|
||||
**super()._client_params,
|
||||
"api_type": self.openai_api_type,
|
||||
"api_version": self.openai_api_version,
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "azure-openai-chat"
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"openai_api_type": self.openai_api_type,
|
||||
"openai_api_version": self.openai_api_version,
|
||||
}
|
||||
|
||||
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
|
||||
if not isinstance(response, dict):
|
||||
response = response.dict()
|
||||
for res in response["choices"]:
|
||||
if res.get("finish_reason", None) == "content_filter":
|
||||
raise ValueError(
|
||||
"Azure has not provided the response due to a content filter "
|
||||
"being triggered"
|
||||
)
|
||||
chat_result = super()._create_chat_result(response)
|
||||
|
||||
if "model" in response:
|
||||
model = response["model"]
|
||||
if self.model_version:
|
||||
model = f"{model}-{self.model_version}"
|
||||
|
||||
if chat_result.llm_output is not None and isinstance(
|
||||
chat_result.llm_output, dict
|
||||
):
|
||||
chat_result.llm_output["model_name"] = model
|
||||
|
||||
return chat_result
|
||||
@@ -0,0 +1,426 @@
|
||||
import json
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Type,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
|
||||
from langchain_community.llms.azureml_endpoint import (
|
||||
AzureMLBaseEndpoint,
|
||||
AzureMLEndpointApiType,
|
||||
ContentFormatterBase,
|
||||
)
|
||||
|
||||
|
||||
class LlamaContentFormatter(ContentFormatterBase):
|
||||
"""Content formatter for `LLaMA`."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise TypeError(
|
||||
"`LlamaContentFormatter` is deprecated for chat models. Use "
|
||||
"`CustomOpenAIContentFormatter` instead."
|
||||
)
|
||||
|
||||
|
||||
class CustomOpenAIChatContentFormatter(ContentFormatterBase):
|
||||
"""Chat Content formatter for models with OpenAI like API scheme."""
|
||||
|
||||
SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]
|
||||
|
||||
@staticmethod
|
||||
def _convert_message_to_dict(message: BaseMessage) -> Dict:
|
||||
"""Converts a message to a dict according to a role"""
|
||||
content = cast(str, message.content)
|
||||
if isinstance(message, HumanMessage):
|
||||
return {
|
||||
"role": "user",
|
||||
"content": ContentFormatterBase.escape_special_characters(content),
|
||||
}
|
||||
elif isinstance(message, AIMessage):
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": ContentFormatterBase.escape_special_characters(content),
|
||||
}
|
||||
elif isinstance(message, SystemMessage):
|
||||
return {
|
||||
"role": "system",
|
||||
"content": ContentFormatterBase.escape_special_characters(content),
|
||||
}
|
||||
elif (
|
||||
isinstance(message, ChatMessage)
|
||||
and message.role in CustomOpenAIChatContentFormatter.SUPPORTED_ROLES
|
||||
):
|
||||
return {
|
||||
"role": message.role,
|
||||
"content": ContentFormatterBase.escape_special_characters(content),
|
||||
}
|
||||
else:
|
||||
supported = ",".join(
|
||||
[role for role in CustomOpenAIChatContentFormatter.SUPPORTED_ROLES]
|
||||
)
|
||||
raise ValueError(
|
||||
f"""Received unsupported role.
|
||||
Supported roles for the LLaMa Foundation Model: {supported}"""
|
||||
)
|
||||
|
||||
@property
|
||||
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
|
||||
return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]
|
||||
|
||||
def format_messages_request_payload(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
model_kwargs: Dict,
|
||||
api_type: AzureMLEndpointApiType,
|
||||
) -> bytes:
|
||||
"""Formats the request according to the chosen api"""
|
||||
chat_messages = [
|
||||
CustomOpenAIChatContentFormatter._convert_message_to_dict(message)
|
||||
for message in messages
|
||||
]
|
||||
if api_type in [
|
||||
AzureMLEndpointApiType.dedicated,
|
||||
AzureMLEndpointApiType.realtime,
|
||||
]:
|
||||
request_payload = json.dumps(
|
||||
{
|
||||
"input_data": {
|
||||
"input_string": chat_messages,
|
||||
"parameters": model_kwargs,
|
||||
}
|
||||
}
|
||||
)
|
||||
elif api_type == AzureMLEndpointApiType.serverless:
|
||||
request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`api_type` {api_type} is not supported by this formatter"
|
||||
)
|
||||
return str.encode(request_payload)
|
||||
|
||||
def format_response_payload(
|
||||
self,
|
||||
output: bytes,
|
||||
api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
|
||||
) -> ChatGeneration:
|
||||
"""Formats response"""
|
||||
if api_type in [
|
||||
AzureMLEndpointApiType.dedicated,
|
||||
AzureMLEndpointApiType.realtime,
|
||||
]:
|
||||
try:
|
||||
choice = json.loads(output)["output"]
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
|
||||
return ChatGeneration(
|
||||
message=AIMessage(
|
||||
content=choice.strip(),
|
||||
),
|
||||
generation_info=None,
|
||||
)
|
||||
if api_type == AzureMLEndpointApiType.serverless:
|
||||
try:
|
||||
choice = json.loads(output)["choices"][0]
|
||||
if not isinstance(choice, dict):
|
||||
raise TypeError(
|
||||
"Endpoint response is not well formed for a chat "
|
||||
"model. Expected `dict` but `{type(choice)}` was received."
|
||||
)
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
|
||||
return ChatGeneration(
|
||||
message=AIMessage(content=choice["message"]["content"].strip())
|
||||
if choice["message"]["role"] == "assistant"
|
||||
else BaseMessage(
|
||||
content=choice["message"]["content"].strip(),
|
||||
type=choice["message"]["role"],
|
||||
),
|
||||
generation_info=dict(
|
||||
finish_reason=choice.get("finish_reason"),
|
||||
logprobs=choice.get("logprobs"),
|
||||
),
|
||||
)
|
||||
raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
|
||||
|
||||
|
||||
class LlamaChatContentFormatter(CustomOpenAIChatContentFormatter):
|
||||
"""Deprecated: Kept for backwards compatibility
|
||||
|
||||
Chat Content formatter for Llama."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
warnings.warn(
|
||||
"""`LlamaChatContentFormatter` will be deprecated in the future.
|
||||
Please use `CustomOpenAIChatContentFormatter` instead.
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
class MistralChatContentFormatter(LlamaChatContentFormatter):
|
||||
"""Content formatter for `Mistral`."""
|
||||
|
||||
def format_messages_request_payload(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
model_kwargs: Dict,
|
||||
api_type: AzureMLEndpointApiType,
|
||||
) -> bytes:
|
||||
"""Formats the request according to the chosen api"""
|
||||
chat_messages = [self._convert_message_to_dict(message) for message in messages]
|
||||
|
||||
if chat_messages and chat_messages[0]["role"] == "system":
|
||||
# Mistral OSS models do not explicitly support system prompts, so we have to
|
||||
# stash in the first user prompt
|
||||
chat_messages[1]["content"] = (
|
||||
chat_messages[0]["content"] + "\n\n" + chat_messages[1]["content"]
|
||||
)
|
||||
del chat_messages[0]
|
||||
|
||||
if api_type == AzureMLEndpointApiType.realtime:
|
||||
request_payload = json.dumps(
|
||||
{
|
||||
"input_data": {
|
||||
"input_string": chat_messages,
|
||||
"parameters": model_kwargs,
|
||||
}
|
||||
}
|
||||
)
|
||||
elif api_type == AzureMLEndpointApiType.serverless:
|
||||
request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`api_type` {api_type} is not supported by this formatter"
|
||||
)
|
||||
return str.encode(request_payload)
|
||||
|
||||
|
||||
class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
|
||||
"""Azure ML Online Endpoint chat models.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
azure_llm = AzureMLOnlineEndpoint(
|
||||
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions",
|
||||
endpoint_api_type=AzureMLApiType.serverless,
|
||||
endpoint_api_key="my-api-key",
|
||||
content_formatter=chat_content_formatter,
|
||||
)
|
||||
"""
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
return {
|
||||
**{"model_kwargs": _model_kwargs},
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "azureml_chat_endpoint"
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Call out to an AzureML Managed Online endpoint.
|
||||
Args:
|
||||
messages: The messages in the conversation with the chat model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
Example:
|
||||
.. code-block:: python
|
||||
response = azureml_model.invoke("Tell me a joke.")
|
||||
"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
_model_kwargs.update(kwargs)
|
||||
if stop:
|
||||
_model_kwargs["stop"] = stop
|
||||
|
||||
request_payload = self.content_formatter.format_messages_request_payload(
|
||||
messages, _model_kwargs, self.endpoint_api_type
|
||||
)
|
||||
response_payload = self.http_client.call(
|
||||
body=request_payload, run_manager=run_manager
|
||||
)
|
||||
generations = self.content_formatter.format_response_payload(
|
||||
response_payload, self.endpoint_api_type
|
||||
)
|
||||
return ChatResult(generations=[generations])
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
self.endpoint_url = self.endpoint_url.replace("/chat/completions", "")
|
||||
timeout = None if "timeout" not in kwargs else kwargs["timeout"]
|
||||
|
||||
import openai
|
||||
|
||||
params = {}
|
||||
client_params = {
|
||||
"api_key": self.endpoint_api_key.get_secret_value(),
|
||||
"base_url": self.endpoint_url,
|
||||
"timeout": timeout,
|
||||
"default_headers": None,
|
||||
"default_query": None,
|
||||
"http_client": None,
|
||||
}
|
||||
|
||||
client = openai.OpenAI(**client_params)
|
||||
message_dicts = [
|
||||
CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
|
||||
for m in messages
|
||||
]
|
||||
params = {"stream": True, "stop": stop, "model": None, **kwargs}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
for chunk in client.chat.completions.create(messages=message_dicts, **params):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.dict()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
choice = chunk["choices"][0]
|
||||
chunk = _convert_delta_to_message_chunk(
|
||||
choice["delta"],
|
||||
default_chunk_class,
|
||||
)
|
||||
generation_info = {}
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
generation_info["finish_reason"] = finish_reason
|
||||
logprobs = choice.get("logprobs")
|
||||
if logprobs:
|
||||
generation_info["logprobs"] = logprobs
|
||||
default_chunk_class = chunk.__class__
|
||||
chunk = ChatGenerationChunk(
|
||||
message=chunk,
|
||||
generation_info=generation_info or None,
|
||||
)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
|
||||
yield chunk
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
self.endpoint_url = self.endpoint_url.replace("/chat/completions", "")
|
||||
timeout = None if "timeout" not in kwargs else kwargs["timeout"]
|
||||
|
||||
import openai
|
||||
|
||||
params = {}
|
||||
client_params = {
|
||||
"api_key": self.endpoint_api_key.get_secret_value(),
|
||||
"base_url": self.endpoint_url,
|
||||
"timeout": timeout,
|
||||
"default_headers": None,
|
||||
"default_query": None,
|
||||
"http_client": None,
|
||||
}
|
||||
|
||||
async_client = openai.AsyncOpenAI(**client_params)
|
||||
message_dicts = [
|
||||
CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
|
||||
for m in messages
|
||||
]
|
||||
params = {"stream": True, "stop": stop, "model": None, **kwargs}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
async for chunk in await async_client.chat.completions.create(
|
||||
messages=message_dicts,
|
||||
**params,
|
||||
):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.dict()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
choice = chunk["choices"][0]
|
||||
chunk = _convert_delta_to_message_chunk(
|
||||
choice["delta"], default_chunk_class
|
||||
)
|
||||
generation_info = {}
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
generation_info["finish_reason"] = finish_reason
|
||||
logprobs = choice.get("logprobs")
|
||||
if logprobs:
|
||||
generation_info["logprobs"] = logprobs
|
||||
default_chunk_class = chunk.__class__
|
||||
chunk = ChatGenerationChunk(
|
||||
message=chunk, generation_info=generation_info or None
|
||||
)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
token=chunk.text, chunk=chunk, logprobs=logprobs
|
||||
)
|
||||
yield chunk
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = cast(str, _dict.get("role"))
|
||||
content = cast(str, _dict.get("content") or "")
|
||||
additional_kwargs: Dict = {}
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
function_call["name"] = ""
|
||||
additional_kwargs["function_call"] = function_call
|
||||
if _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = _dict["tool_calls"]
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role == "function" or default_class == FunctionMessageChunk:
|
||||
return FunctionMessageChunk(content=content, name=_dict["name"])
|
||||
elif role == "tool" or default_class == ToolMessageChunk:
|
||||
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role)
|
||||
else:
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
@@ -0,0 +1,650 @@
|
||||
import json
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
make_invalid_tool_call,
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from langchain_community.chat_models.llamacpp import (
|
||||
_lc_invalid_tool_call_to_openai_tool_call,
|
||||
_lc_tool_call_to_openai_tool_call,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_API_BASE = "https://api.baichuan-ai.com/v1/chat/completions"
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict: Dict[str, Any]
|
||||
content = message.content
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": content}
|
||||
if "tool_calls" in message.additional_kwargs:
|
||||
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
|
||||
|
||||
elif message.tool_calls or message.invalid_tool_calls:
|
||||
message_dict["tool_calls"] = [
|
||||
_lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
|
||||
] + [
|
||||
_lc_invalid_tool_call_to_openai_tool_call(tc)
|
||||
for tc in message.invalid_tool_calls
|
||||
]
|
||||
elif isinstance(message, ToolMessage):
|
||||
message_dict = {
|
||||
"role": "tool",
|
||||
"tool_call_id": message.tool_call_id,
|
||||
"content": content,
|
||||
"name": message.name or message.additional_kwargs.get("name"),
|
||||
}
|
||||
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": content}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
|
||||
return message_dict
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
content = _dict.get("content", "")
|
||||
if role == "user":
|
||||
return HumanMessage(content=content)
|
||||
elif role == "assistant":
|
||||
tool_calls = []
|
||||
invalid_tool_calls = []
|
||||
additional_kwargs = {}
|
||||
|
||||
if raw_tool_calls := _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
try:
|
||||
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
|
||||
except Exception as e:
|
||||
invalid_tool_calls.append(
|
||||
make_invalid_tool_call(raw_tool_call, str(e))
|
||||
)
|
||||
|
||||
return AIMessage(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_calls=tool_calls,
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
)
|
||||
elif role == "tool":
|
||||
additional_kwargs = {}
|
||||
if "name" in _dict:
|
||||
additional_kwargs["name"] = _dict["name"]
|
||||
return ToolMessage(
|
||||
content=content,
|
||||
tool_call_id=_dict.get("tool_call_id"),
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=content)
|
||||
else:
|
||||
return ChatMessage(content=content, role=role)
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = _dict.get("role")
|
||||
content = _dict.get("content") or ""
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(content=content)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
|
||||
else:
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def aconnect_httpx_sse(
|
||||
client: Any, method: str, url: str, **kwargs: Any
|
||||
) -> AsyncIterator:
|
||||
"""Async context manager for connecting to an SSE stream.
|
||||
|
||||
Args:
|
||||
client: The httpx client.
|
||||
method: The HTTP method.
|
||||
url: The URL to connect to.
|
||||
kwargs: Additional keyword arguments to pass to the client.
|
||||
|
||||
Yields:
|
||||
An EventSource object.
|
||||
"""
|
||||
from httpx_sse import EventSource
|
||||
|
||||
async with client.stream(method, url, **kwargs) as response:
|
||||
yield EventSource(response)
|
||||
|
||||
|
||||
class ChatBaichuan(BaseChatModel):
|
||||
"""Baichuan chat model integration.
|
||||
|
||||
Setup:
|
||||
To use, you should have the environment variable``BAICHUAN_API_KEY`` set with
|
||||
your API KEY.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
export BAICHUAN_API_KEY="your-api-key"
|
||||
|
||||
Key init args — completion params:
|
||||
model: Optional[str]
|
||||
Name of Baichuan model to use.
|
||||
max_tokens: Optional[int]
|
||||
Max number of tokens to generate.
|
||||
streaming: Optional[bool]
|
||||
Whether to stream the results or not.
|
||||
temperature: Optional[float]
|
||||
Sampling temperature.
|
||||
top_p: Optional[float]
|
||||
What probability mass to use.
|
||||
top_k: Optional[int]
|
||||
What search sampling control to use.
|
||||
|
||||
Key init args — client params:
|
||||
api_key: Optional[str]
|
||||
Baichuan API key. If not passed in will be read from env var BAICHUAN_API_KEY.
|
||||
base_url: Optional[str]
|
||||
Base URL for API requests.
|
||||
|
||||
See full list of supported init args and their descriptions in the params section.
|
||||
|
||||
Instantiate:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatBaichuan
|
||||
|
||||
chat = ChatBaichuan(
|
||||
api_key=api_key,
|
||||
model='Baichuan4',
|
||||
# temperature=...,
|
||||
# other params...
|
||||
)
|
||||
|
||||
Invoke:
|
||||
.. code-block:: python
|
||||
|
||||
messages = [
|
||||
("system", "你是一名专业的翻译家,可以将用户的中文翻译为英文。"),
|
||||
("human", "我喜欢编程。"),
|
||||
]
|
||||
chat.invoke(messages)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
AIMessage(
|
||||
content='I enjoy programming.',
|
||||
response_metadata={
|
||||
'token_usage': {
|
||||
'prompt_tokens': 93,
|
||||
'completion_tokens': 5,
|
||||
'total_tokens': 98
|
||||
},
|
||||
'model': 'Baichuan4'
|
||||
},
|
||||
id='run-944ff552-6a93-44cf-a861-4e4d849746f9-0'
|
||||
)
|
||||
|
||||
Stream:
|
||||
.. code-block:: python
|
||||
|
||||
for chunk in chat.stream(messages):
|
||||
print(chunk)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
content='I' id='run-f99fcd6f-dd31-46d5-be8f-0b6a22bf77d8'
|
||||
content=' enjoy programming.' id='run-f99fcd6f-dd31-46d5-be8f-0b6a22bf77d8
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
stream = chat.stream(messages)
|
||||
full = next(stream)
|
||||
for chunk in stream:
|
||||
full += chunk
|
||||
full
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
AIMessageChunk(
|
||||
content='I like programming.',
|
||||
id='run-74689970-dc31-461d-b729-3b6aa93508d2'
|
||||
)
|
||||
|
||||
Async:
|
||||
.. code-block:: python
|
||||
|
||||
await chat.ainvoke(messages)
|
||||
|
||||
# stream
|
||||
# async for chunk in chat.astream(messages):
|
||||
# print(chunk)
|
||||
|
||||
# batch
|
||||
# await chat.abatch([messages])
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
AIMessage(
|
||||
content='I enjoy programming.',
|
||||
response_metadata={
|
||||
'token_usage': {
|
||||
'prompt_tokens': 93,
|
||||
'completion_tokens': 5,
|
||||
'total_tokens': 98
|
||||
},
|
||||
'model': 'Baichuan4'
|
||||
},
|
||||
id='run-952509ed-9154-4ff9-b187-e616d7ddfbba-0'
|
||||
)
|
||||
Tool calling:
|
||||
|
||||
.. code-block:: python
|
||||
class get_current_weather(BaseModel):
|
||||
'''Get current weather.'''
|
||||
|
||||
location: str = Field('City or province, such as Shanghai')
|
||||
|
||||
|
||||
llm_with_tools = ChatBaichuan(model='Baichuan3-Turbo').bind_tools([get_current_weather])
|
||||
llm_with_tools.invoke('How is the weather today?')
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
[{'name': 'get_current_weather',
|
||||
'args': {'location': 'New York'},
|
||||
'id': '3951017OF8doB0A',
|
||||
'type': 'tool_call'}]
|
||||
|
||||
Response metadata
|
||||
.. code-block:: python
|
||||
|
||||
ai_msg = chat.invoke(messages)
|
||||
ai_msg.response_metadata
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'token_usage': {
|
||||
'prompt_tokens': 93,
|
||||
'completion_tokens': 5,
|
||||
'total_tokens': 98
|
||||
},
|
||||
'model': 'Baichuan4'
|
||||
}
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {
|
||||
"baichuan_api_key": "BAICHUAN_API_KEY",
|
||||
}
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
baichuan_api_base: str = Field(default=DEFAULT_API_BASE, alias="base_url")
|
||||
"""Baichuan custom endpoints"""
|
||||
baichuan_api_key: SecretStr = Field(alias="api_key")
|
||||
"""Baichuan API Key"""
|
||||
baichuan_secret_key: Optional[SecretStr] = None
|
||||
"""[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key"""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
max_tokens: Optional[int] = None
|
||||
"""Maximum number of tokens to generate."""
|
||||
request_timeout: int = Field(default=60, alias="timeout")
|
||||
"""request timeout for chat http requests"""
|
||||
model: str = "Baichuan2-Turbo-192K"
|
||||
"""model name of Baichuan, default is `Baichuan2-Turbo-192K`,
|
||||
other options include `Baichuan2-Turbo`"""
|
||||
temperature: Optional[float] = Field(default=0.3)
|
||||
"""What sampling temperature to use."""
|
||||
top_k: int = 5
|
||||
"""What search sampling control to use."""
|
||||
top_p: float = 0.85
|
||||
"""What probability mass to use."""
|
||||
with_search_enhance: bool = False
|
||||
"""[DEPRECATED, keeping it for for backward compatibility],
|
||||
Whether to use search enhance, default is False."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for API call not explicitly specified."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
logger.warning(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
values["baichuan_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"baichuan_api_base",
|
||||
"BAICHUAN_API_BASE",
|
||||
DEFAULT_API_BASE,
|
||||
)
|
||||
values["baichuan_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
["baichuan_api_key", "api_key"],
|
||||
"BAICHUAN_API_KEY",
|
||||
)
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Baichuan API."""
|
||||
normal_params = {
|
||||
"model": self.model,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
"stream": self.streaming,
|
||||
"max_tokens": self.max_tokens,
|
||||
}
|
||||
|
||||
return {**normal_params, **self.model_kwargs}
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
res = self._chat(messages, **kwargs)
|
||||
if res.status_code != 200:
|
||||
raise ValueError(f"Error from Baichuan api response: {res}")
|
||||
response = res.json()
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
res = self._chat(messages, stream=True, **kwargs)
|
||||
if res.status_code != 200:
|
||||
raise ValueError(f"Error from Baichuan api response: {res}")
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
for chunk in res.iter_lines():
|
||||
chunk = chunk.decode("utf-8").strip("\r\n")
|
||||
parts = chunk.split("data: ", 1)
|
||||
chunk = parts[1] if len(parts) > 1 else None
|
||||
if chunk is None:
|
||||
continue
|
||||
if chunk == "[DONE]":
|
||||
break
|
||||
response = json.loads(chunk)
|
||||
for m in response.get("choices"):
|
||||
chunk = _convert_delta_to_message_chunk(
|
||||
m.get("delta"), default_chunk_class
|
||||
)
|
||||
default_chunk_class = chunk.__class__
|
||||
cg_chunk = ChatGenerationChunk(message=chunk)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
headers = self._create_headers_parameters(**kwargs)
|
||||
payload = self._create_payload_parameters(messages, **kwargs)
|
||||
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
headers=headers, timeout=self.request_timeout
|
||||
) as client:
|
||||
response = await client.post(self.baichuan_api_base, json=payload)
|
||||
response.raise_for_status()
|
||||
return self._create_chat_result(response.json())
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
headers = self._create_headers_parameters(**kwargs)
|
||||
payload = self._create_payload_parameters(messages, stream=True, **kwargs)
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
headers=headers, timeout=self.request_timeout
|
||||
) as client:
|
||||
async with aconnect_httpx_sse(
|
||||
client, "POST", self.baichuan_api_base, json=payload
|
||||
) as event_source:
|
||||
async for sse in event_source.aiter_sse():
|
||||
chunk = json.loads(sse.data)
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
choice = chunk["choices"][0]
|
||||
chunk = _convert_delta_to_message_chunk(
|
||||
choice["delta"], AIMessageChunk
|
||||
)
|
||||
finish_reason = choice.get("finish_reason", None)
|
||||
|
||||
generation_info = (
|
||||
{"finish_reason": finish_reason}
|
||||
if finish_reason is not None
|
||||
else None
|
||||
)
|
||||
chunk = ChatGenerationChunk(
|
||||
message=chunk, generation_info=generation_info
|
||||
)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
yield chunk
|
||||
if finish_reason is not None:
|
||||
break
|
||||
|
||||
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
|
||||
payload = self._create_payload_parameters(messages, **kwargs)
|
||||
url = self.baichuan_api_base
|
||||
headers = self._create_headers_parameters(**kwargs)
|
||||
|
||||
res = requests.post(
|
||||
url=url,
|
||||
timeout=self.request_timeout,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
stream=self.streaming,
|
||||
)
|
||||
return res
|
||||
|
||||
def _create_payload_parameters(
|
||||
self, messages: List[BaseMessage], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
parameters = {**self._default_params, **kwargs}
|
||||
temperature = parameters.pop("temperature", 0.3)
|
||||
top_k = parameters.pop("top_k", 5)
|
||||
top_p = parameters.pop("top_p", 0.85)
|
||||
model = parameters.pop("model")
|
||||
with_search_enhance = parameters.pop("with_search_enhance", False)
|
||||
stream = parameters.pop("stream", False)
|
||||
tools = parameters.pop("tools", [])
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [_convert_message_to_dict(m) for m in messages],
|
||||
"top_k": top_k,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature,
|
||||
"with_search_enhance": with_search_enhance,
|
||||
"stream": stream,
|
||||
"tools": tools,
|
||||
}
|
||||
|
||||
return payload
|
||||
|
||||
def _create_headers_parameters(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
parameters = {**self._default_params, **kwargs}
|
||||
default_headers = parameters.pop("headers", {})
|
||||
api_key = ""
|
||||
if self.baichuan_api_key:
|
||||
api_key = self.baichuan_api_key.get_secret_value()
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
**default_headers,
|
||||
}
|
||||
return headers
|
||||
|
||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||
generations = []
|
||||
for c in response["choices"]:
|
||||
message = _convert_dict_to_message(c["message"])
|
||||
gen = ChatGeneration(message=message)
|
||||
generations.append(gen)
|
||||
|
||||
token_usage = response["usage"]
|
||||
llm_output = {"token_usage": token_usage, "model": self.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "baichuan-chat"
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, AIMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
|
||||
Args:
|
||||
tools: A list of tool definitions to bind to this chat model.
|
||||
Can be a dictionary, pydantic model, callable, or BaseTool.
|
||||
Pydantic
|
||||
models, callables, and BaseTools will be automatically converted to
|
||||
their schema dictionary representation.
|
||||
**kwargs: Any additional parameters to pass to the
|
||||
:class:`~langchain.runnable.Runnable` constructor.
|
||||
"""
|
||||
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
@@ -0,0 +1,841 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.ai import UsageMetadata
|
||||
from langchain_core.messages.tool import tool_call_chunk
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import get_fields, is_basemodel_subclass
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
"""Convert a message to a dictionary that can be passed to the API."""
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if len(message.tool_calls) != 0:
|
||||
tool_call = message.tool_calls[0]
|
||||
message_dict["function_call"] = {
|
||||
"name": tool_call["name"],
|
||||
"arguments": json.dumps(tool_call["args"], ensure_ascii=False),
|
||||
}
|
||||
# If function call only, content is None not empty string
|
||||
message_dict["content"] = None
|
||||
elif isinstance(message, (FunctionMessage, ToolMessage)):
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": _create_tool_content(message.content),
|
||||
"name": message.name or message.additional_kwargs.get("name"),
|
||||
}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
|
||||
return message_dict
|
||||
|
||||
|
||||
def _create_tool_content(content: Union[str, List[Union[str, Dict[Any, Any]]]]) -> str:
|
||||
"""Convert tool content to dict scheme."""
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
if isinstance(json.loads(content), dict):
|
||||
return content
|
||||
else:
|
||||
return json.dumps({"tool_result": content})
|
||||
except json.JSONDecodeError:
|
||||
return json.dumps({"tool_result": content})
|
||||
else:
|
||||
return json.dumps({"tool_result": content})
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
|
||||
content = _dict.get("result", "") or ""
|
||||
additional_kwargs: Mapping[str, Any] = {}
|
||||
if _dict.get("function_call"):
|
||||
additional_kwargs = {"function_call": dict(_dict["function_call"])}
|
||||
if "thoughts" in additional_kwargs["function_call"]:
|
||||
# align to api sample, which affects the llm function_call output
|
||||
additional_kwargs["function_call"].pop("thoughts")
|
||||
|
||||
# DO NOT ADD ANY NUMERIC OBJECT TO `msg_additional_kwargs` AND `additional_kwargs`
|
||||
# ALONG WITH THEIRS SUB-CONTAINERS !!!
|
||||
# OR IT WILL RAISE A DEADLY EXCEPTION FROM `merge_dict`
|
||||
# 不要往 `msg_additional_kwargs` 和 `additional_kwargs` 里面加任何数值类对象!
|
||||
# 子容器也不行!
|
||||
# 不然 `merge_dict` 会报错导致代码无法运行
|
||||
additional_kwargs = {**_dict.get("body", {}), **additional_kwargs}
|
||||
msg_additional_kwargs = dict(
|
||||
finish_reason=additional_kwargs.get("finish_reason", ""),
|
||||
request_id=additional_kwargs["id"],
|
||||
object=additional_kwargs.get("object", ""),
|
||||
search_info=additional_kwargs.get("search_info", []),
|
||||
)
|
||||
|
||||
if additional_kwargs.get("function_call", {}):
|
||||
msg_additional_kwargs["function_call"] = additional_kwargs.get(
|
||||
"function_call", {}
|
||||
)
|
||||
msg_additional_kwargs["tool_calls"] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": additional_kwargs.get("function_call", {}),
|
||||
"id": str(uuid.uuid4()),
|
||||
}
|
||||
]
|
||||
|
||||
ret = AIMessage(
|
||||
content=content,
|
||||
additional_kwargs=msg_additional_kwargs,
|
||||
)
|
||||
|
||||
if usage := additional_kwargs.get("usage", None):
|
||||
ret.usage_metadata = UsageMetadata(
|
||||
input_tokens=usage.get("prompt_tokens", 0),
|
||||
output_tokens=usage.get("completion_tokens", 0),
|
||||
total_tokens=usage.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class QianfanChatEndpoint(BaseChatModel):
|
||||
"""Baidu Qianfan chat model integration.
|
||||
|
||||
Setup:
|
||||
Install ``qianfan`` and set environment variables ``QIANFAN_AK``, ``QIANFAN_SK``.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install qianfan
|
||||
export QIANFAN_AK="your-api-key"
|
||||
export QIANFAN_SK="your-secret_key"
|
||||
|
||||
Key init args — completion params:
|
||||
model: str
|
||||
Name of Qianfan model to use.
|
||||
temperature: Optional[float]
|
||||
Sampling temperature.
|
||||
endpoint: Optional[str]
|
||||
Endpoint of the Qianfan LLM
|
||||
top_p: Optional[float]
|
||||
What probability mass to use.
|
||||
|
||||
Key init args — client params:
|
||||
timeout: Optional[int]
|
||||
Timeout for requests.
|
||||
api_key: Optional[str]
|
||||
Qianfan API KEY. If not passed in will be read from env var QIANFAN_AK.
|
||||
secret_key: Optional[str]
|
||||
Qianfan SECRET KEY. If not passed in will be read from env var QIANFAN_SK.
|
||||
|
||||
See full list of supported init args and their descriptions in the params section.
|
||||
|
||||
Instantiate:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import QianfanChatEndpoint
|
||||
|
||||
qianfan_chat = QianfanChatEndpoint(
|
||||
model="ERNIE-3.5-8K",
|
||||
temperature=0.2,
|
||||
timeout=30,
|
||||
# api_key="...",
|
||||
# secret_key="...",
|
||||
# top_p="...",
|
||||
# other params...
|
||||
)
|
||||
|
||||
Invoke:
|
||||
.. code-block:: python
|
||||
|
||||
messages = [
|
||||
("system", "你是一名专业的翻译家,可以将用户的中文翻译为英文。"),
|
||||
("human", "我喜欢编程。"),
|
||||
]
|
||||
qianfan_chat.invoke(messages)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
AIMessage(content='I enjoy programming.', additional_kwargs={'finish_reason': 'normal', 'request_id': 'as-7848zeqn1c', 'object': 'chat.completion', 'search_info': []}, response_metadata={'token_usage': {'prompt_tokens': 16, 'completion_tokens': 4, 'total_tokens': 20}, 'model_name': 'ERNIE-3.5-8K', 'finish_reason': 'normal', 'id': 'as-7848zeqn1c', 'object': 'chat.completion', 'created': 1719153606, 'result': 'I enjoy programming.', 'is_truncated': False, 'need_clear_history': False, 'usage': {'prompt_tokens': 16, 'completion_tokens': 4, 'total_tokens': 20}}, id='run-4bca0c10-5043-456b-a5be-2f62a980f3f0-0')
|
||||
|
||||
Stream:
|
||||
.. code-block:: python
|
||||
|
||||
for chunk in qianfan_chat.stream(messages):
|
||||
print(chunk)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
content='I enjoy' response_metadata={'finish_reason': 'normal', 'request_id': 'as-yz0yz1w1rq', 'object': 'chat.completion', 'search_info': []} id='run-0fa9da50-003e-4a26-ba16-dbfe96249b8b' role='assistant'
|
||||
content=' programming.' response_metadata={'finish_reason': 'normal', 'request_id': 'as-yz0yz1w1rq', 'object': 'chat.completion', 'search_info': []} id='run-0fa9da50-003e-4a26-ba16-dbfe96249b8b' role='assistant'
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
stream = chat.stream(messages)
|
||||
full = next(stream)
|
||||
for chunk in stream:
|
||||
full += chunk
|
||||
full
|
||||
|
||||
.. code-block::
|
||||
|
||||
AIMessageChunk(content='I enjoy programming.', response_metadata={'finish_reason': 'normalnormal', 'request_id': 'as-p63cnn3ppnas-p63cnn3ppn', 'object': 'chat.completionchat.completion', 'search_info': []}, id='run-09a8cbbd-5ded-4529-981d-5bc9d1206404')
|
||||
|
||||
Async:
|
||||
.. code-block:: python
|
||||
|
||||
await qianfan_chat.ainvoke(messages)
|
||||
|
||||
# stream:
|
||||
# async for chunk in qianfan_chat.astream(messages):
|
||||
# print(chunk)
|
||||
|
||||
# batch:
|
||||
# await qianfan_chat.abatch([messages])
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
[AIMessage(content='I enjoy programming.', additional_kwargs={'finish_reason': 'normal', 'request_id': 'as-mpqa8qa1qb', 'object': 'chat.completion', 'search_info': []}, response_metadata={'token_usage': {'prompt_tokens': 16, 'completion_tokens': 4, 'total_tokens': 20}, 'model_name': 'ERNIE-3.5-8K', 'finish_reason': 'normal', 'id': 'as-mpqa8qa1qb', 'object': 'chat.completion', 'created': 1719155120, 'result': 'I enjoy programming.', 'is_truncated': False, 'need_clear_history': False, 'usage': {'prompt_tokens': 16, 'completion_tokens': 4, 'total_tokens': 20}}, id='run-443b2231-08f9-4725-b807-b77d0507ad44-0')]
|
||||
|
||||
Tool calling:
|
||||
.. code-block:: python
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GetWeather(BaseModel):
|
||||
'''Get the current weather in a given location'''
|
||||
|
||||
location: str = Field(
|
||||
..., description="The city and state, e.g. San Francisco, CA"
|
||||
)
|
||||
|
||||
|
||||
class GetPopulation(BaseModel):
|
||||
'''Get the current population in a given location'''
|
||||
|
||||
location: str = Field(
|
||||
..., description="The city and state, e.g. San Francisco, CA"
|
||||
)
|
||||
|
||||
chat_with_tools = qianfan_chat.bind_tools([GetWeather, GetPopulation])
|
||||
ai_msg = chat_with_tools.invoke(
|
||||
"Which city is hotter today and which is bigger: LA or NY?"
|
||||
)
|
||||
ai_msg.tool_calls
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
[
|
||||
{
|
||||
'name': 'GetWeather',
|
||||
'args': {'location': 'Los Angeles, CA'},
|
||||
'id': '533e5f63-a3dc-40f2-9d9c-22b1feee62e0'
|
||||
}
|
||||
]
|
||||
|
||||
Structured output:
|
||||
.. code-block:: python
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Joke(BaseModel):
|
||||
'''Joke to tell user.'''
|
||||
|
||||
setup: str = Field(description="The setup of the joke")
|
||||
punchline: str = Field(description="The punchline to the joke")
|
||||
rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
|
||||
|
||||
|
||||
structured_chat = qianfan_chat.with_structured_output(Joke)
|
||||
structured_chat.invoke("Tell me a joke about cats")
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
Joke(
|
||||
setup='A cat is sitting in front of a mirror and sees another cat. What does the cat think?',
|
||||
punchline="The cat doesn't think it's another cat, it thinks it's another mirror.",
|
||||
rating=None
|
||||
)
|
||||
|
||||
Response metadata
|
||||
.. code-block:: python
|
||||
|
||||
ai_msg = qianfan_chat.invoke(messages)
|
||||
ai_msg.response_metadata
|
||||
|
||||
.. code-block:: python
|
||||
{
|
||||
'token_usage': {
|
||||
'prompt_tokens': 16,
|
||||
'completion_tokens': 4,
|
||||
'total_tokens': 20},
|
||||
'model_name': 'ERNIE-3.5-8K',
|
||||
'finish_reason': 'normal',
|
||||
'id': 'as-qbzwtydqmi',
|
||||
'object': 'chat.completion',
|
||||
'created': 1719158153,
|
||||
'result': 'I enjoy programming.',
|
||||
'is_truncated': False,
|
||||
'need_clear_history': False,
|
||||
'usage': {
|
||||
'prompt_tokens': 16,
|
||||
'completion_tokens': 4,
|
||||
'total_tokens': 20
|
||||
}
|
||||
}
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
init_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""init kwargs for qianfan client init, such as `query_per_second` which is
|
||||
associated with qianfan resource object to limit QPS"""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""extra params for model invoke using with `do`."""
|
||||
|
||||
client: Any = None #: :meta private:
|
||||
|
||||
# It could be empty due to the use of Console API
|
||||
# And they're not list here
|
||||
qianfan_ak: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||
"""Qianfan API KEY"""
|
||||
qianfan_sk: Optional[SecretStr] = Field(default=None, alias="secret_key")
|
||||
"""Qianfan SECRET KEY"""
|
||||
streaming: Optional[bool] = False
|
||||
"""Whether to stream the results or not."""
|
||||
|
||||
request_timeout: Optional[int] = Field(60, alias="timeout")
|
||||
"""request timeout for chat http requests"""
|
||||
|
||||
top_p: Optional[float] = 0.8
|
||||
"""What probability mass to use."""
|
||||
temperature: Optional[float] = 0.95
|
||||
"""What sampling temperature to use."""
|
||||
penalty_score: Optional[float] = 1
|
||||
"""Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo.
|
||||
In the case of other model, passing these params will not affect the result.
|
||||
"""
|
||||
|
||||
model: Optional[str] = Field(default=None)
|
||||
"""Model name.
|
||||
you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
|
||||
|
||||
preset models are mapping to an endpoint.
|
||||
`model` will be ignored if `endpoint` is set.
|
||||
Default is set by `qianfan` SDK, not here
|
||||
"""
|
||||
|
||||
endpoint: Optional[str] = None
|
||||
"""Endpoint of the Qianfan LLM, required if custom model used."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
values["qianfan_ak"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values, ["qianfan_ak", "api_key"], "QIANFAN_AK", default=""
|
||||
)
|
||||
)
|
||||
values["qianfan_sk"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values, ["qianfan_sk", "secret_key"], "QIANFAN_SK", default=""
|
||||
)
|
||||
)
|
||||
|
||||
default_values = {
|
||||
name: field.default
|
||||
for name, field in get_fields(cls).items()
|
||||
if field.default is not None
|
||||
}
|
||||
default_values.update(values)
|
||||
params = {
|
||||
**values.get("init_kwargs", {}),
|
||||
"model": default_values.get("model"),
|
||||
"stream": default_values.get("streaming"),
|
||||
}
|
||||
if values["qianfan_ak"].get_secret_value() != "":
|
||||
params["ak"] = values["qianfan_ak"].get_secret_value()
|
||||
if values["qianfan_sk"].get_secret_value() != "":
|
||||
params["sk"] = values["qianfan_sk"].get_secret_value()
|
||||
if (
|
||||
default_values.get("endpoint") is not None
|
||||
and default_values["endpoint"] != ""
|
||||
):
|
||||
params["endpoint"] = default_values["endpoint"]
|
||||
try:
|
||||
import qianfan
|
||||
|
||||
values["client"] = qianfan.ChatCompletion(**params)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"qianfan package not found, please install it with "
|
||||
"`pip install qianfan`"
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
return {
|
||||
**{"endpoint": self.endpoint, "model": self.model},
|
||||
**super()._identifying_params,
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat_model."""
|
||||
return "baidu-qianfan-chat"
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Qianfan API."""
|
||||
normal_params = {
|
||||
"model": self.model,
|
||||
"endpoint": self.endpoint,
|
||||
"stream": self.streaming,
|
||||
"request_timeout": self.request_timeout,
|
||||
"top_p": self.top_p,
|
||||
"temperature": self.temperature,
|
||||
"penalty_score": self.penalty_score,
|
||||
}
|
||||
|
||||
return {**normal_params, **self.model_kwargs}
|
||||
|
||||
def _convert_prompt_msg_params(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Converts a list of messages into a dictionary containing the message content
|
||||
and default parameters.
|
||||
|
||||
Args:
|
||||
messages (List[BaseMessage]): The list of messages.
|
||||
**kwargs (Any): Optional arguments to add additional parameters to the
|
||||
resulting dictionary.
|
||||
|
||||
Returns:
|
||||
`dict` containing the message content and default parameters.
|
||||
|
||||
"""
|
||||
messages_dict: Dict[str, Any] = {
|
||||
"messages": [
|
||||
convert_message_to_dict(m)
|
||||
for m in messages
|
||||
if not isinstance(m, SystemMessage)
|
||||
]
|
||||
}
|
||||
for i in [i for i, m in enumerate(messages) if isinstance(m, SystemMessage)]:
|
||||
if "system" not in messages_dict:
|
||||
messages_dict["system"] = ""
|
||||
messages_dict["system"] += cast(str, messages[i].content) + "\n"
|
||||
|
||||
return {
|
||||
**messages_dict,
|
||||
**self._default_params,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Call out to an qianfan models endpoint for each generation with a prompt.
|
||||
Args:
|
||||
messages: The messages to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
response = qianfan_model.invoke("Tell me a joke.")
|
||||
"""
|
||||
if self.streaming:
|
||||
completion = ""
|
||||
chat_generation_info: Dict = {}
|
||||
usage_metadata: Optional[UsageMetadata] = None
|
||||
for chunk in self._stream(messages, stop, run_manager, **kwargs):
|
||||
chat_generation_info = (
|
||||
chunk.generation_info
|
||||
if chunk.generation_info is not None
|
||||
else chat_generation_info
|
||||
)
|
||||
completion += chunk.text
|
||||
if isinstance(chunk.message, AIMessageChunk):
|
||||
usage_metadata = chunk.message.usage_metadata
|
||||
|
||||
lc_msg = AIMessage(
|
||||
content=completion,
|
||||
additional_kwargs={},
|
||||
usage_metadata=usage_metadata,
|
||||
)
|
||||
gen = ChatGeneration(
|
||||
message=lc_msg,
|
||||
generation_info=dict(finish_reason="stop"),
|
||||
)
|
||||
return ChatResult(
|
||||
generations=[gen],
|
||||
llm_output={
|
||||
"token_usage": usage_metadata or {},
|
||||
"model_name": self.model,
|
||||
},
|
||||
)
|
||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||
params["stop"] = stop
|
||||
response_payload = self.client.do(**params)
|
||||
lc_msg = _convert_dict_to_message(response_payload)
|
||||
gen = ChatGeneration(
|
||||
message=lc_msg,
|
||||
generation_info={
|
||||
"finish_reason": "stop",
|
||||
**response_payload.get("body", {}),
|
||||
},
|
||||
)
|
||||
token_usage = response_payload.get("usage", {})
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model}
|
||||
return ChatResult(generations=[gen], llm_output=llm_output)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
completion = ""
|
||||
chat_generation_info: Dict = {}
|
||||
usage_metadata: Optional[UsageMetadata] = None
|
||||
async for chunk in self._astream(messages, stop, run_manager, **kwargs):
|
||||
chat_generation_info = (
|
||||
chunk.generation_info
|
||||
if chunk.generation_info is not None
|
||||
else chat_generation_info
|
||||
)
|
||||
completion += chunk.text
|
||||
|
||||
if isinstance(chunk.message, AIMessageChunk):
|
||||
usage_metadata = chunk.message.usage_metadata
|
||||
|
||||
lc_msg = AIMessage(
|
||||
content=completion,
|
||||
additional_kwargs={},
|
||||
usage_metadata=usage_metadata,
|
||||
)
|
||||
gen = ChatGeneration(
|
||||
message=lc_msg,
|
||||
generation_info=dict(finish_reason="stop"),
|
||||
)
|
||||
return ChatResult(
|
||||
generations=[gen],
|
||||
llm_output={
|
||||
"token_usage": usage_metadata or {},
|
||||
"model_name": self.model,
|
||||
},
|
||||
)
|
||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||
params["stop"] = stop
|
||||
response_payload = await self.client.ado(**params)
|
||||
lc_msg = _convert_dict_to_message(response_payload)
|
||||
generations = []
|
||||
gen = ChatGeneration(
|
||||
message=lc_msg,
|
||||
generation_info={
|
||||
"finish_reason": "stop",
|
||||
**response_payload.get("body", {}),
|
||||
},
|
||||
)
|
||||
generations.append(gen)
|
||||
token_usage = response_payload.get("usage", {})
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||
params["stop"] = stop
|
||||
params["stream"] = True
|
||||
for res in self.client.do(**params):
|
||||
if res:
|
||||
msg = _convert_dict_to_message(res)
|
||||
additional_kwargs = msg.additional_kwargs.get("function_call", {})
|
||||
chunk = ChatGenerationChunk(
|
||||
text=res["result"],
|
||||
message=AIMessageChunk( # type: ignore[call-arg]
|
||||
content=msg.content,
|
||||
role="assistant",
|
||||
additional_kwargs=additional_kwargs,
|
||||
usage_metadata=msg.usage_metadata,
|
||||
tool_call_chunks=[
|
||||
tool_call_chunk(
|
||||
name=tc["name"],
|
||||
args=json.dumps(tc["args"]),
|
||||
id=tc["id"],
|
||||
index=None,
|
||||
)
|
||||
for tc in msg.tool_calls
|
||||
],
|
||||
),
|
||||
generation_info=msg.additional_kwargs,
|
||||
)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
yield chunk
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||
params["stop"] = stop
|
||||
params["stream"] = True
|
||||
async for res in await self.client.ado(**params):
|
||||
if res:
|
||||
msg = _convert_dict_to_message(res)
|
||||
additional_kwargs = msg.additional_kwargs.get("function_call", {})
|
||||
chunk = ChatGenerationChunk(
|
||||
text=res["result"],
|
||||
message=AIMessageChunk( # type: ignore[call-arg]
|
||||
content=msg.content,
|
||||
role="assistant",
|
||||
additional_kwargs=additional_kwargs,
|
||||
usage_metadata=msg.usage_metadata,
|
||||
tool_call_chunks=[
|
||||
tool_call_chunk(
|
||||
name=tc["name"],
|
||||
args=json.dumps(tc["args"]),
|
||||
id=tc["id"],
|
||||
index=None,
|
||||
)
|
||||
for tc in msg.tool_calls
|
||||
],
|
||||
),
|
||||
generation_info=msg.additional_kwargs,
|
||||
)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
yield chunk
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, AIMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
|
||||
Assumes model is compatible with OpenAI tool-calling API.
|
||||
|
||||
Args:
|
||||
tools: A list of tool definitions to bind to this chat model.
|
||||
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
|
||||
models, callables, and BaseTools will be automatically converted to
|
||||
their schema dictionary representation.
|
||||
**kwargs: Any additional parameters to pass to the
|
||||
:class:`~langchain.runnable.Runnable` constructor.
|
||||
"""
|
||||
|
||||
formatted_tools = [convert_to_openai_tool(tool)["function"] for tool in tools]
|
||||
return super().bind(functions=formatted_tools, **kwargs)
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Union[Dict, Type[BaseModel]],
|
||||
*,
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
|
||||
then the model output will be an object of that class. If a dict then
|
||||
the model output will be a dict. With a Pydantic class the returned
|
||||
attributes will be validated, whereas with a dict they will not be. If
|
||||
`method` is "function_calling" and `schema` is a dict, then the dict
|
||||
must match the OpenAI function-calling spec.
|
||||
include_raw: If False then only the parsed structured output is returned. If
|
||||
an error occurs during model output parsing it will be raised. If True
|
||||
then both the raw model response (a BaseMessage) and the parsed model
|
||||
response will be returned. If an error occurs during output parsing it
|
||||
will be caught and returned as well. The final output is always a dict
|
||||
with keys "raw", "parsed", and "parsing_error".
|
||||
|
||||
Returns:
|
||||
A Runnable that takes any ChatModel input and returns as output:
|
||||
|
||||
If include_raw is True then a dict with keys:
|
||||
raw: BaseMessage
|
||||
parsed: Optional[_DictOrPydantic]
|
||||
parsing_error: Optional[BaseException]
|
||||
|
||||
If include_raw is False then just _DictOrPydantic is returned,
|
||||
where _DictOrPydantic depends on the schema:
|
||||
|
||||
If schema is a Pydantic class then _DictOrPydantic is the Pydantic
|
||||
class.
|
||||
|
||||
If schema is a dict then _DictOrPydantic is a dict.
|
||||
|
||||
Example: Function-calling, Pydantic schema (method="function_calling", include_raw=False):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_mistralai import QianfanChatEndpoint
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
answer: str
|
||||
justification: str
|
||||
|
||||
llm = QianfanChatEndpoint(endpoint="ernie-3.5-8k-0329")
|
||||
structured_llm = llm.with_structured_output(AnswerWithJustification)
|
||||
|
||||
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
|
||||
|
||||
# -> AnswerWithJustification(
|
||||
# answer='They weigh the same',
|
||||
# justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
|
||||
# )
|
||||
|
||||
Example: Function-calling, Pydantic schema (method="function_calling", include_raw=True):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_mistralai import QianfanChatEndpoint
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
answer: str
|
||||
justification: str
|
||||
|
||||
llm = QianfanChatEndpoint(endpoint="ernie-3.5-8k-0329")
|
||||
structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)
|
||||
|
||||
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
|
||||
# -> {
|
||||
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
|
||||
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
|
||||
# 'parsing_error': None
|
||||
# }
|
||||
|
||||
Example: Function-calling, dict schema (method="function_calling", include_raw=False):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_mistralai import QianfanChatEndpoint
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
answer: str
|
||||
justification: str
|
||||
|
||||
dict_schema = convert_to_openai_tool(AnswerWithJustification)
|
||||
llm = QianfanChatEndpoint(endpoint="ernie-3.5-8k-0329")
|
||||
structured_llm = llm.with_structured_output(dict_schema)
|
||||
|
||||
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
|
||||
# -> {
|
||||
# 'answer': 'They weigh the same',
|
||||
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
|
||||
# }
|
||||
|
||||
""" # noqa: E501
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
|
||||
llm = self.bind_tools([schema])
|
||||
if is_pydantic_schema:
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[schema], # type: ignore[list-item]
|
||||
first_tool_only=True,
|
||||
)
|
||||
else:
|
||||
key_name = convert_to_openai_tool(schema)["function"]["name"]
|
||||
output_parser = JsonOutputKeyToolsParser(
|
||||
key_name=key_name, first_tool_only=True
|
||||
)
|
||||
|
||||
if include_raw:
|
||||
parser_assign = RunnablePassthrough.assign(
|
||||
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
|
||||
)
|
||||
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||
parser_with_fallback = parser_assign.with_fallbacks(
|
||||
[parser_none], exception_key="parsing_error"
|
||||
)
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
else:
|
||||
return llm | output_parser
|
||||
@@ -0,0 +1,337 @@
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_community.chat_models.anthropic import (
|
||||
convert_messages_to_prompt_anthropic,
|
||||
)
|
||||
from langchain_community.chat_models.meta import convert_messages_to_prompt_llama
|
||||
from langchain_community.llms.bedrock import BedrockBase
|
||||
from langchain_community.utilities.anthropic import (
|
||||
get_num_tokens_anthropic,
|
||||
get_token_ids_anthropic,
|
||||
)
|
||||
|
||||
|
||||
def _convert_one_message_to_text_mistral(message: BaseMessage) -> str:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_text = f"[INST] {message.content} [/INST]"
|
||||
elif isinstance(message, AIMessage):
|
||||
message_text = f"{message.content}"
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_text = f"<<SYS>> {message.content} <</SYS>>"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_text
|
||||
|
||||
|
||||
def convert_messages_to_prompt_mistral(messages: List[BaseMessage]) -> str:
|
||||
"""Convert a list of messages to a prompt for mistral."""
|
||||
return "\n".join(
|
||||
[_convert_one_message_to_text_mistral(message) for message in messages]
|
||||
)
|
||||
|
||||
|
||||
def _format_image(image_url: str) -> Dict:
|
||||
"""
|
||||
Formats an image of format data:image/jpeg;base64,{b64_string}
|
||||
to a dict for anthropic api
|
||||
|
||||
{
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": "/9j/4AAQSkZJRg...",
|
||||
}
|
||||
|
||||
And throws an error if it's not a b64 image
|
||||
"""
|
||||
regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
|
||||
match = re.match(regex, image_url)
|
||||
if match is None:
|
||||
raise ValueError(
|
||||
"Anthropic only supports base64-encoded images currently."
|
||||
" Example: data:image/png;base64,'/9j/4AAQSk'..."
|
||||
)
|
||||
return {
|
||||
"type": "base64",
|
||||
"media_type": match.group("media_type"),
|
||||
"data": match.group("data"),
|
||||
}
|
||||
|
||||
|
||||
def _format_anthropic_messages(
|
||||
messages: List[BaseMessage],
|
||||
) -> Tuple[Optional[str], List[Dict]]:
|
||||
"""Format messages for anthropic."""
|
||||
|
||||
"""
|
||||
[
|
||||
{
|
||||
"role": _message_type_lookups[m.type],
|
||||
"content": [_AnthropicMessageContent(text=m.content).dict()],
|
||||
}
|
||||
for m in messages
|
||||
]
|
||||
"""
|
||||
system: Optional[str] = None
|
||||
formatted_messages: List[Dict] = []
|
||||
for i, message in enumerate(messages):
|
||||
if message.type == "system":
|
||||
if i != 0:
|
||||
raise ValueError("System message must be at beginning of message list.")
|
||||
if not isinstance(message.content, str):
|
||||
raise ValueError(
|
||||
"System message must be a string, "
|
||||
f"instead was: {type(message.content)}"
|
||||
)
|
||||
system = message.content
|
||||
continue
|
||||
|
||||
role = _message_type_lookups[message.type]
|
||||
content: Union[str, List[Dict]]
|
||||
|
||||
if not isinstance(message.content, str):
|
||||
# parse as dict
|
||||
assert isinstance(message.content, list), (
|
||||
"Anthropic message content must be str or list of dicts"
|
||||
)
|
||||
|
||||
# populate content
|
||||
content = []
|
||||
for item in message.content:
|
||||
if isinstance(item, str):
|
||||
content.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": item,
|
||||
}
|
||||
)
|
||||
elif isinstance(item, dict):
|
||||
if "type" not in item:
|
||||
raise ValueError("Dict content item must have a type key")
|
||||
if item["type"] == "image_url":
|
||||
# convert format
|
||||
source = _format_image(item["image_url"]["url"])
|
||||
content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": source,
|
||||
}
|
||||
)
|
||||
else:
|
||||
content.append(item)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Content items must be str or dict, instead was: {type(item)}"
|
||||
)
|
||||
else:
|
||||
content = message.content
|
||||
|
||||
formatted_messages.append(
|
||||
{
|
||||
"role": role,
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
return system, formatted_messages
|
||||
|
||||
|
||||
class ChatPromptAdapter:
|
||||
"""Adapter class to prepare the inputs from Langchain to prompt format
|
||||
that Chat model expects.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def convert_messages_to_prompt(
|
||||
cls, provider: str, messages: List[BaseMessage]
|
||||
) -> str:
|
||||
if provider == "anthropic":
|
||||
prompt = convert_messages_to_prompt_anthropic(messages=messages)
|
||||
elif provider == "meta":
|
||||
prompt = convert_messages_to_prompt_llama(messages=messages)
|
||||
elif provider == "mistral":
|
||||
prompt = convert_messages_to_prompt_mistral(messages=messages)
|
||||
elif provider == "amazon":
|
||||
prompt = convert_messages_to_prompt_anthropic(
|
||||
messages=messages,
|
||||
human_prompt="\n\nUser:",
|
||||
ai_prompt="\n\nBot:",
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Provider {provider} model does not support chat."
|
||||
)
|
||||
return prompt
|
||||
|
||||
@classmethod
|
||||
def format_messages(
|
||||
cls, provider: str, messages: List[BaseMessage]
|
||||
) -> Tuple[Optional[str], List[Dict]]:
|
||||
if provider == "anthropic":
|
||||
return _format_anthropic_messages(messages)
|
||||
|
||||
raise NotImplementedError(
|
||||
f"Provider {provider} not supported for format_messages"
|
||||
)
|
||||
|
||||
|
||||
_message_type_lookups = {
|
||||
"human": "user",
|
||||
"ai": "assistant",
|
||||
"AIMessageChunk": "assistant",
|
||||
"HumanMessageChunk": "user",
|
||||
"function": "user",
|
||||
}
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.0.34", removal="1.0", alternative_import="langchain_aws.ChatBedrock"
|
||||
)
|
||||
class BedrockChat(BaseChatModel, BedrockBase):
|
||||
"""Chat model that uses the Bedrock API."""
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "amazon_bedrock_chat"
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this model can be serialized by Langchain."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "bedrock"]
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
|
||||
if self.region_name:
|
||||
attributes["region_name"] = self.region_name
|
||||
|
||||
return attributes
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
provider = self._get_provider()
|
||||
prompt, system, formatted_messages = None, None, None
|
||||
|
||||
if provider == "anthropic":
|
||||
system, formatted_messages = ChatPromptAdapter.format_messages(
|
||||
provider, messages
|
||||
)
|
||||
else:
|
||||
prompt = ChatPromptAdapter.convert_messages_to_prompt(
|
||||
provider=provider, messages=messages
|
||||
)
|
||||
|
||||
for chunk in self._prepare_input_and_invoke_stream(
|
||||
prompt=prompt,
|
||||
system=system,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
run_manager=run_manager,
|
||||
**kwargs,
|
||||
):
|
||||
delta = chunk.text
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
completion = ""
|
||||
llm_output: Dict[str, Any] = {"model_id": self.model_id}
|
||||
|
||||
if self.streaming:
|
||||
for chunk in self._stream(messages, stop, run_manager, **kwargs):
|
||||
completion += chunk.text
|
||||
else:
|
||||
provider = self._get_provider()
|
||||
prompt, system, formatted_messages = None, None, None
|
||||
params: Dict[str, Any] = {**kwargs}
|
||||
|
||||
if provider == "anthropic":
|
||||
system, formatted_messages = ChatPromptAdapter.format_messages(
|
||||
provider, messages
|
||||
)
|
||||
else:
|
||||
prompt = ChatPromptAdapter.convert_messages_to_prompt(
|
||||
provider=provider, messages=messages
|
||||
)
|
||||
|
||||
if stop:
|
||||
params["stop_sequences"] = stop
|
||||
|
||||
completion, usage_info = self._prepare_input_and_invoke(
|
||||
prompt=prompt,
|
||||
stop=stop,
|
||||
run_manager=run_manager,
|
||||
system=system,
|
||||
messages=formatted_messages,
|
||||
**params,
|
||||
)
|
||||
|
||||
llm_output["usage"] = usage_info
|
||||
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=AIMessage(content=completion))],
|
||||
llm_output=llm_output,
|
||||
)
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
final_usage: Dict[str, int] = defaultdict(int)
|
||||
final_output = {}
|
||||
for output in llm_outputs:
|
||||
output = output or {}
|
||||
usage = output.get("usage", {})
|
||||
for token_type, token_count in usage.items():
|
||||
final_usage[token_type] += token_count
|
||||
final_output.update(output)
|
||||
final_output["usage"] = final_usage
|
||||
return final_output
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
if self._model_is_anthropic:
|
||||
return get_num_tokens_anthropic(text)
|
||||
else:
|
||||
return super().get_num_tokens(text)
|
||||
|
||||
def get_token_ids(self, text: str) -> List[int]:
|
||||
if self._model_is_anthropic:
|
||||
return get_token_ids_anthropic(text)
|
||||
else:
|
||||
return super().get_token_ids(text)
|
||||
@@ -0,0 +1,256 @@
|
||||
import logging
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
from langchain_classic.schema import AIMessage, ChatGeneration, ChatResult, HumanMessage
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.tool import tool_call
|
||||
from langchain_core.output_parsers import (
|
||||
JsonOutputParser,
|
||||
PydanticOutputParser,
|
||||
)
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||
from langchain_core.runnables.base import RunnableMap
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Initialize logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_pydantic_class(obj: Any) -> bool:
|
||||
return isinstance(obj, type) and is_basemodel_subclass(obj)
|
||||
|
||||
|
||||
def _convert_messages_to_cloudflare_messages(
|
||||
messages: List[BaseMessage],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert LangChain messages to Cloudflare Workers AI format."""
|
||||
cloudflare_messages = []
|
||||
msg: Dict[str, Any]
|
||||
for message in messages:
|
||||
# Base structure for each message
|
||||
msg = {
|
||||
"role": "",
|
||||
"content": message.content if isinstance(message.content, str) else "",
|
||||
}
|
||||
|
||||
# Determine role and additional fields based on message type
|
||||
if isinstance(message, HumanMessage):
|
||||
msg["role"] = "user"
|
||||
elif isinstance(message, AIMessage):
|
||||
msg["role"] = "assistant"
|
||||
# If the AIMessage includes tool calls, format them as needed
|
||||
if message.tool_calls:
|
||||
tool_calls = [
|
||||
{"name": tool_call["name"], "arguments": tool_call["args"]}
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
msg["tool_calls"] = tool_calls
|
||||
elif isinstance(message, SystemMessage):
|
||||
msg["role"] = "system"
|
||||
elif isinstance(message, ToolMessage):
|
||||
msg["role"] = "tool"
|
||||
msg["tool_call_id"] = (
|
||||
message.tool_call_id
|
||||
) # Use tool_call_id if it's a ToolMessage
|
||||
|
||||
# Add the formatted message to the list
|
||||
cloudflare_messages.append(msg)
|
||||
|
||||
return cloudflare_messages
|
||||
|
||||
|
||||
def _get_tool_calls_from_response(response: requests.Response) -> List[ToolCall]:
|
||||
"""Get tool calls from ollama response."""
|
||||
tool_calls = []
|
||||
if "tool_calls" in response.json()["result"]:
|
||||
for tc in response.json()["result"]["tool_calls"]:
|
||||
tool_calls.append(
|
||||
tool_call(
|
||||
id=str(uuid4()),
|
||||
name=tc["name"],
|
||||
args=tc["arguments"],
|
||||
)
|
||||
)
|
||||
return tool_calls
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.23",
|
||||
removal="1.0",
|
||||
alternative_import="langchain_cloudflare.ChatCloudflareWorkersAI",
|
||||
)
|
||||
class ChatCloudflareWorkersAI(BaseChatModel):
|
||||
"""Custom chat model for Cloudflare Workers AI"""
|
||||
|
||||
account_id: str = Field(...)
|
||||
api_token: str = Field(...)
|
||||
model: str = Field(...)
|
||||
ai_gateway: str = ""
|
||||
url: str = ""
|
||||
base_url: str = "https://api.cloudflare.com/client/v4/accounts"
|
||||
gateway_url: str = "https://gateway.ai.cloudflare.com/v1"
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize with necessary credentials."""
|
||||
super().__init__(**kwargs)
|
||||
if self.ai_gateway:
|
||||
self.url = (
|
||||
f"{self.gateway_url}/{self.account_id}/"
|
||||
f"{self.ai_gateway}/workers-ai/run/{self.model}"
|
||||
)
|
||||
else:
|
||||
self.url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Generate a response based on the messages provided."""
|
||||
formatted_messages = _convert_messages_to_cloudflare_messages(messages)
|
||||
|
||||
headers = {"Authorization": f"Bearer {self.api_token}"}
|
||||
prompt = "\n".join(
|
||||
f"role: {msg['role']}, content: {msg['content']}"
|
||||
+ (f", tools: {msg['tool_calls']}" if "tool_calls" in msg else "")
|
||||
+ (
|
||||
f", tool_call_id: {msg['tool_call_id']}"
|
||||
if "tool_call_id" in msg
|
||||
else ""
|
||||
)
|
||||
for msg in formatted_messages
|
||||
)
|
||||
|
||||
# Initialize `data` with `prompt`
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"tools": kwargs["tools"] if "tools" in kwargs else None,
|
||||
**{key: value for key, value in kwargs.items() if key not in ["tools"]},
|
||||
}
|
||||
|
||||
# Ensure `tools` is a list if it's included in `kwargs`
|
||||
if data["tools"] is not None and not isinstance(data["tools"], list):
|
||||
data["tools"] = [data["tools"]]
|
||||
|
||||
_logger.info(f"Sending prompt to Cloudflare Workers AI: {data}")
|
||||
|
||||
response = requests.post(self.url, headers=headers, json=data)
|
||||
tool_calls = _get_tool_calls_from_response(response)
|
||||
ai_message = AIMessage(
|
||||
content=str(response.json()), tool_calls=cast(AIMessageChunk, tool_calls)
|
||||
)
|
||||
chat_generation = ChatGeneration(message=ai_message)
|
||||
return ChatResult(generations=[chat_generation])
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type, Callable[..., Any], BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, AIMessage]:
|
||||
"""Bind tools for use in model generation."""
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Union[Dict, Type[BaseModel]],
|
||||
*,
|
||||
include_raw: bool = False,
|
||||
method: Optional[Literal["json_mode", "function_calling"]] = "function_calling",
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema."""
|
||||
|
||||
_ = kwargs.pop("strict", None)
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
is_pydantic_schema = _is_pydantic_class(schema)
|
||||
if method == "json_schema":
|
||||
# Some applications require that incompatible parameters (e.g., unsupported
|
||||
# methods) be handled.
|
||||
method = "function_calling"
|
||||
if method == "function_calling":
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
"schema must be specified when method is 'function_calling'. "
|
||||
"Received None."
|
||||
)
|
||||
tool_name = convert_to_openai_tool(schema)["function"]["name"]
|
||||
llm = self.bind_tools([schema], tool_choice=tool_name)
|
||||
if is_pydantic_schema:
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[schema], # type: ignore[list-item]
|
||||
first_tool_only=True,
|
||||
)
|
||||
else:
|
||||
output_parser = JsonOutputKeyToolsParser(
|
||||
key_name=tool_name, first_tool_only=True
|
||||
)
|
||||
elif method == "json_mode":
|
||||
llm = self.bind(response_format={"type": "json_object"})
|
||||
output_parser = (
|
||||
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
|
||||
if is_pydantic_schema
|
||||
else JsonOutputParser()
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unrecognized method argument. Expected one of 'function_calling' or "
|
||||
f"'json_mode'. Received: '{method}'"
|
||||
)
|
||||
|
||||
if include_raw:
|
||||
parser_assign = RunnablePassthrough.assign(
|
||||
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
|
||||
)
|
||||
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||
parser_with_fallback = parser_assign.with_fallbacks(
|
||||
[parser_none], exception_key="parsing_error"
|
||||
)
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
else:
|
||||
return llm | output_parser
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return the type of the LLM (for Langchain compatibility)."""
|
||||
return "cloudflare-workers-ai"
|
||||
251
venv/Lib/site-packages/langchain_community/chat_models/cohere.py
Normal file
251
venv/Lib/site-packages/langchain_community/chat_models/cohere.py
Normal file
@@ -0,0 +1,251 @@
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_community.llms.cohere import BaseCohere
|
||||
|
||||
|
||||
def get_role(message: BaseMessage) -> str:
|
||||
"""Get the role of the message.
|
||||
|
||||
Args:
|
||||
message: The message.
|
||||
|
||||
Returns:
|
||||
The role of the message.
|
||||
|
||||
Raises:
|
||||
ValueError: If the message is of an unknown type.
|
||||
"""
|
||||
if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
|
||||
return "User"
|
||||
elif isinstance(message, AIMessage):
|
||||
return "Chatbot"
|
||||
elif isinstance(message, SystemMessage):
|
||||
return "System"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
|
||||
def get_cohere_chat_request(
|
||||
messages: List[BaseMessage],
|
||||
*,
|
||||
connectors: Optional[List[Dict[str, str]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get the request for the Cohere chat API.
|
||||
|
||||
Args:
|
||||
messages: The messages.
|
||||
connectors: The connectors.
|
||||
**kwargs: The keyword arguments.
|
||||
|
||||
Returns:
|
||||
The request for the Cohere chat API.
|
||||
"""
|
||||
documents = (
|
||||
None
|
||||
if "source_documents" not in kwargs
|
||||
else [
|
||||
{
|
||||
"snippet": doc.page_content,
|
||||
"id": doc.metadata.get("id") or f"doc-{str(i)}",
|
||||
}
|
||||
for i, doc in enumerate(kwargs["source_documents"])
|
||||
]
|
||||
)
|
||||
kwargs.pop("source_documents", None)
|
||||
maybe_connectors = connectors if documents is None else None
|
||||
|
||||
# by enabling automatic prompt truncation, the probability of request failure is
|
||||
# reduced with minimal impact on response quality
|
||||
prompt_truncation = (
|
||||
"AUTO" if documents is not None or connectors is not None else None
|
||||
)
|
||||
|
||||
req = {
|
||||
"message": messages[-1].content,
|
||||
"chat_history": [
|
||||
{"role": get_role(x), "message": x.content} for x in messages[:-1]
|
||||
],
|
||||
"documents": documents,
|
||||
"connectors": maybe_connectors,
|
||||
"prompt_truncation": prompt_truncation,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
return {k: v for k, v in req.items() if v is not None}
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.0.30", removal="1.0", alternative_import="langchain_cohere.ChatCohere"
|
||||
)
|
||||
class ChatCohere(BaseChatModel, BaseCohere):
|
||||
"""`Cohere` chat large language models.
|
||||
|
||||
To use, you should have the ``cohere`` python package installed, and the
|
||||
environment variable ``COHERE_API_KEY`` set with your API key, or pass
|
||||
it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatCohere
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
chat = ChatCohere(max_tokens=256, temperature=0.75)
|
||||
|
||||
messages = [HumanMessage(content="knock knock")]
|
||||
chat.invoke(messages)
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "cohere-chat"
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Cohere API."""
|
||||
return {
|
||||
"temperature": self.temperature,
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**{"model": self.model}, **self._default_params}
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||
|
||||
if hasattr(self.client, "chat_stream"): # detect and support sdk v5
|
||||
stream = self.client.chat_stream(**request)
|
||||
else:
|
||||
stream = self.client.chat(**request, stream=True)
|
||||
|
||||
for data in stream:
|
||||
if data.event_type == "text-generation":
|
||||
delta = data.text
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(delta, chunk=chunk)
|
||||
yield chunk
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||
|
||||
if hasattr(self.async_client, "chat_stream"): # detect and support sdk v5
|
||||
stream = await self.async_client.chat_stream(**request)
|
||||
else:
|
||||
stream = await self.async_client.chat(**request, stream=True)
|
||||
|
||||
async for data in stream:
|
||||
if data.event_type == "text-generation":
|
||||
delta = data.text
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(delta, chunk=chunk)
|
||||
yield chunk
|
||||
|
||||
def _get_generation_info(self, response: Any) -> Dict[str, Any]:
|
||||
"""Get the generation info from cohere API response."""
|
||||
return {
|
||||
"documents": response.documents,
|
||||
"citations": response.citations,
|
||||
"search_results": response.search_results,
|
||||
"search_queries": response.search_queries,
|
||||
"token_count": response.token_count,
|
||||
}
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||
response = self.client.chat(**request)
|
||||
|
||||
message = AIMessage(content=response.text)
|
||||
generation_info = None
|
||||
if hasattr(response, "documents"):
|
||||
generation_info = self._get_generation_info(response)
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=message, generation_info=generation_info)
|
||||
]
|
||||
)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||
response = self.client.chat(**request)
|
||||
|
||||
message = AIMessage(content=response.text)
|
||||
generation_info = None
|
||||
if hasattr(response, "documents"):
|
||||
generation_info = self._get_generation_info(response)
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=message, generation_info=generation_info)
|
||||
]
|
||||
)
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Calculate number of tokens."""
|
||||
return len(self.client.tokenize(text=text).tokens)
|
||||
255
venv/Lib/site-packages/langchain_community/chat_models/coze.py
Normal file
255
venv/Lib/site-packages/langchain_community/chat_models/coze.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional, Union
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
)
|
||||
from pydantic import ConfigDict, Field, SecretStr, model_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_API_BASE = "https://api.coze.com"
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, HumanMessage):
|
||||
message_dict = {
|
||||
"role": "user",
|
||||
"content": message.content,
|
||||
"content_type": "text",
|
||||
}
|
||||
else:
|
||||
message_dict = {
|
||||
"role": "assistant",
|
||||
"content": message.content,
|
||||
"content_type": "text",
|
||||
}
|
||||
return message_dict
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> Union[BaseMessage, None]:
|
||||
msg_type = _dict["type"]
|
||||
if msg_type != "answer":
|
||||
return None
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
return AIMessage(content=_dict.get("content", "") or "")
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(_dict: Mapping[str, Any]) -> BaseMessageChunk:
|
||||
role = _dict.get("role")
|
||||
content = _dict.get("content") or ""
|
||||
|
||||
if role == "user":
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant":
|
||||
return AIMessageChunk(content=content)
|
||||
else:
|
||||
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class ChatCoze(BaseChatModel):
|
||||
"""ChatCoze chat models API by coze.com
|
||||
|
||||
For more information, see https://www.coze.com/open/docs/chat
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {
|
||||
"coze_api_key": "COZE_API_KEY",
|
||||
}
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
coze_api_base: str = Field(default=DEFAULT_API_BASE)
|
||||
"""Coze custom endpoints"""
|
||||
coze_api_key: Optional[SecretStr] = None
|
||||
"""Coze API Key"""
|
||||
request_timeout: int = Field(default=60, alias="timeout")
|
||||
"""request timeout for chat http requests"""
|
||||
bot_id: str = Field(default="")
|
||||
"""The ID of the bot that the API interacts with."""
|
||||
conversation_id: str = Field(default="")
|
||||
"""Indicate which conversation the dialog is taking place in. If there is no need to
|
||||
distinguish the context of the conversation(just a question and answer), skip this
|
||||
parameter. It will be generated by the system."""
|
||||
user: str = Field(default="")
|
||||
"""The user who calls the API to chat with the bot."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the response to the client.
|
||||
false: if no value is specified or set to false, a non-streaming response is
|
||||
returned. "Non-streaming response" means that all responses will be returned at once
|
||||
after they are all ready, and the client does not need to concatenate the content.
|
||||
true: set to true, partial message deltas will be sent .
|
||||
"Streaming response" will provide real-time response of the model to the client, and
|
||||
the client needs to assemble the final reply based on the type of message. """
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
values["coze_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"coze_api_base",
|
||||
"COZE_API_BASE",
|
||||
DEFAULT_API_BASE,
|
||||
)
|
||||
values["coze_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
"coze_api_key",
|
||||
"COZE_API_KEY",
|
||||
)
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Coze API."""
|
||||
return {
|
||||
"bot_id": self.bot_id,
|
||||
"conversation_id": self.conversation_id,
|
||||
"user": self.user,
|
||||
"streaming": self.streaming,
|
||||
}
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
r = self._chat(messages, **kwargs)
|
||||
res = r.json()
|
||||
if res["code"] != 0:
|
||||
raise ValueError(
|
||||
f"Error from Coze api response: {res['code']}: {res['msg']}, "
|
||||
f"logid: {r.headers.get('X-Tt-Logid')}"
|
||||
)
|
||||
|
||||
return self._create_chat_result(res.get("messages") or [])
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
res = self._chat(messages, **kwargs)
|
||||
for chunk in res.iter_lines():
|
||||
chunk = chunk.decode("utf-8").strip("\r\n")
|
||||
parts = chunk.split("data:", 1)
|
||||
chunk = parts[1] if len(parts) > 1 else None
|
||||
if chunk is None:
|
||||
continue
|
||||
response = json.loads(chunk)
|
||||
if response["event"] == "done":
|
||||
break
|
||||
elif (
|
||||
response["event"] != "message"
|
||||
or response["message"]["type"] != "answer"
|
||||
):
|
||||
continue
|
||||
chunk = _convert_delta_to_message_chunk(response["message"])
|
||||
cg_chunk = ChatGenerationChunk(message=chunk)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
|
||||
parameters = {**self._default_params, **kwargs}
|
||||
|
||||
query = ""
|
||||
chat_history = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, HumanMessage):
|
||||
query = f"{msg.content}" # overwrite, to get last user message as query
|
||||
chat_history.append(_convert_message_to_dict(msg))
|
||||
|
||||
conversation_id = parameters.pop("conversation_id")
|
||||
bot_id = parameters.pop("bot_id")
|
||||
user = parameters.pop("user")
|
||||
streaming = parameters.pop("streaming")
|
||||
|
||||
payload = {
|
||||
"conversation_id": conversation_id,
|
||||
"bot_id": bot_id,
|
||||
"user": user,
|
||||
"query": query,
|
||||
"stream": streaming,
|
||||
}
|
||||
if chat_history:
|
||||
payload["chat_history"] = chat_history
|
||||
|
||||
url = self.coze_api_base + "/open_api/v2/chat"
|
||||
api_key = ""
|
||||
if self.coze_api_key:
|
||||
api_key = self.coze_api_key.get_secret_value()
|
||||
|
||||
res = requests.post(
|
||||
url=url,
|
||||
timeout=self.request_timeout,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
},
|
||||
json=payload,
|
||||
stream=streaming,
|
||||
)
|
||||
if res.status_code != 200:
|
||||
logid = res.headers.get("X-Tt-Logid")
|
||||
raise ValueError(f"Error from Coze api response: {res}, logid: {logid}")
|
||||
return res
|
||||
|
||||
def _create_chat_result(self, messages: List[Mapping[str, Any]]) -> ChatResult:
|
||||
generations = []
|
||||
for c in messages:
|
||||
msg = _convert_dict_to_message(c)
|
||||
if msg:
|
||||
generations.append(ChatGeneration(message=msg))
|
||||
|
||||
llm_output = {"token_usage": "", "model": ""}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "coze-chat"
|
||||
@@ -0,0 +1,161 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from pydantic import ConfigDict, Field, SecretStr, model_validator
|
||||
|
||||
from langchain_community.utilities.requests import Requests
|
||||
|
||||
|
||||
def _format_dappier_messages(
|
||||
messages: List[BaseMessage],
|
||||
) -> List[Dict[str, Union[str, List[Union[str, Dict[Any, Any]]]]]]:
|
||||
formatted_messages = []
|
||||
|
||||
for message in messages:
|
||||
if message.type == "human":
|
||||
formatted_messages.append({"role": "user", "content": message.content})
|
||||
elif message.type == "system":
|
||||
formatted_messages.append({"role": "system", "content": message.content})
|
||||
|
||||
return formatted_messages
|
||||
|
||||
|
||||
class ChatDappierAI(BaseChatModel):
|
||||
"""`Dappier` chat large language models.
|
||||
|
||||
`Dappier` is a platform enabling access to diverse, real-time data models.
|
||||
Enhance your AI applications with Dappier's pre-trained, LLM-ready data models
|
||||
and ensure accurate, current responses with reduced inaccuracies.
|
||||
|
||||
To use one of our Dappier AI Data Models, you will need an API key.
|
||||
Please visit Dappier Platform (https://platform.dappier.com/) to log in
|
||||
and create an API key in your profile.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatDappierAI
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
# Initialize `ChatDappierAI` with the desired configuration
|
||||
chat = ChatDappierAI(
|
||||
dappier_endpoint="https://api.dappier.com/app/datamodel/dm_01hpsxyfm2fwdt2zet9cg6fdxt",
|
||||
dappier_api_key="<YOUR_KEY>")
|
||||
|
||||
# Create a list of messages to interact with the model
|
||||
messages = [HumanMessage(content="hello")]
|
||||
|
||||
# Invoke the model with the provided messages
|
||||
chat.invoke(messages)
|
||||
|
||||
|
||||
you can find more details here : https://docs.dappier.com/introduction"""
|
||||
|
||||
dappier_endpoint: str = "https://api.dappier.com/app/datamodelconversation"
|
||||
|
||||
dappier_model: str = "dm_01hpsxyfm2fwdt2zet9cg6fdxt"
|
||||
|
||||
dappier_api_key: Optional[SecretStr] = Field(None, description="Dappier API Token")
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate that api key exists in environment."""
|
||||
values["dappier_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "dappier_api_key", "DAPPIER_API_KEY")
|
||||
)
|
||||
return values
|
||||
|
||||
@staticmethod
|
||||
def get_user_agent() -> str:
|
||||
from langchain_community import __version__
|
||||
|
||||
return f"langchain/{__version__}"
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "dappier-realtimesearch-chat"
|
||||
|
||||
@property
|
||||
def _api_key(self) -> str:
|
||||
if self.dappier_api_key:
|
||||
return self.dappier_api_key.get_secret_value()
|
||||
return ""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
url = f"{self.dappier_endpoint}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"User-Agent": self.get_user_agent(),
|
||||
}
|
||||
user_query = _format_dappier_messages(messages=messages)
|
||||
payload: Dict[str, Any] = {
|
||||
"model": self.dappier_model,
|
||||
"conversation": user_query,
|
||||
}
|
||||
|
||||
request = Requests(headers=headers)
|
||||
response = request.post(url=url, data=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
|
||||
message_response = data["message"]
|
||||
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=AIMessage(content=message_response))]
|
||||
)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
url = f"{self.dappier_endpoint}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"User-Agent": self.get_user_agent(),
|
||||
}
|
||||
user_query = _format_dappier_messages(messages=messages)
|
||||
payload: Dict[str, Any] = {
|
||||
"model": self.dappier_model,
|
||||
"conversation": user_query,
|
||||
}
|
||||
|
||||
async with ClientSession() as session:
|
||||
async with session.post(url, json=payload, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
message_response = data["message"]
|
||||
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=AIMessage(content=message_response))
|
||||
]
|
||||
)
|
||||
@@ -0,0 +1,60 @@
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
|
||||
from langchain_community.chat_models.mlflow import ChatMlflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.3",
|
||||
removal="1.0",
|
||||
alternative_import="databricks_langchain.ChatDatabricks",
|
||||
)
|
||||
class ChatDatabricks(ChatMlflow):
|
||||
"""`Databricks` chat models API.
|
||||
|
||||
To use, you should have the ``mlflow`` python package installed.
|
||||
For more information, see https://mlflow.org/docs/latest/llms/deployments.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatDatabricks
|
||||
|
||||
chat_model = ChatDatabricks(
|
||||
target_uri="databricks",
|
||||
endpoint="databricks-llama-2-70b-chat",
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
# single input invocation
|
||||
print(chat_model.invoke("What is MLflow?").content)
|
||||
|
||||
# single input invocation with streaming response
|
||||
for chunk in chat_model.stream("What is MLflow?"):
|
||||
print(chunk.content, end="|")
|
||||
"""
|
||||
|
||||
target_uri: str = "databricks"
|
||||
"""The target URI to use. Defaults to ``databricks``."""
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "databricks-chat"
|
||||
|
||||
@property
|
||||
def _mlflow_extras(self) -> str:
|
||||
return ""
|
||||
|
||||
def _validate_uri(self) -> None:
|
||||
if self.target_uri == "databricks":
|
||||
return
|
||||
|
||||
if urlparse(self.target_uri).scheme != "databricks":
|
||||
raise ValueError(
|
||||
"Invalid target URI. The target URI must be a valid databricks URI."
|
||||
)
|
||||
@@ -0,0 +1,546 @@
|
||||
"""deepinfra.com chat models wrapper"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.language_models.llms import create_base_retry_decorator
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_community.utilities.requests import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatDeepInfraException(Exception):
|
||||
"""Exception raised when the DeepInfra API returns an error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _create_retry_decorator(
|
||||
llm: ChatDeepInfra,
|
||||
run_manager: Optional[
|
||||
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
||||
] = None,
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions."""
|
||||
return create_base_retry_decorator(
|
||||
error_types=[requests.exceptions.ConnectTimeout, ChatDeepInfraException],
|
||||
max_retries=llm.max_retries,
|
||||
run_manager=run_manager,
|
||||
)
|
||||
|
||||
|
||||
def _parse_tool_calling(tool_call: dict) -> ToolCall:
|
||||
"""
|
||||
Convert a tool calling response from server to a ToolCall object.
|
||||
Args:
|
||||
tool_call:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
name = tool_call["function"].get("name", "")
|
||||
try:
|
||||
args = json.loads(tool_call["function"]["arguments"])
|
||||
except (JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
id = tool_call.get("id")
|
||||
return create_tool_call(name=name, args=args, id=id)
|
||||
|
||||
|
||||
def _convert_to_tool_calling(tool_call: ToolCall) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a ToolCall object to a tool calling request for server.
|
||||
Args:
|
||||
tool_call:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"arguments": json.dumps(tool_call["args"]),
|
||||
"name": tool_call["name"],
|
||||
},
|
||||
"id": tool_call.get("id"),
|
||||
}
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
content = _dict.get("content", "") or ""
|
||||
tool_calls_content = _dict.get("tool_calls", []) or []
|
||||
tool_calls = [
|
||||
_parse_tool_calling(tool_call) for tool_call in tool_calls_content
|
||||
]
|
||||
return AIMessage(content=content, tool_calls=tool_calls)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict["content"])
|
||||
elif role == "function":
|
||||
return FunctionMessage(content=_dict["content"], name=_dict["name"])
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = _dict.get("role")
|
||||
content = _dict.get("content") or ""
|
||||
tool_calls = _dict.get("tool_calls") or []
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
tool_calls = [_parse_tool_calling(tool_call) for tool_call in tool_calls]
|
||||
return AIMessageChunk(content=content, tool_calls=tool_calls)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role == "function" or default_class == FunctionMessageChunk:
|
||||
return FunctionMessageChunk(content=content, name=_dict["name"])
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
|
||||
else:
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
tool_calls = [
|
||||
_convert_to_tool_calling(tool_call) for tool_call in message.tool_calls
|
||||
]
|
||||
message_dict = {
|
||||
"role": "assistant",
|
||||
"content": message.content,
|
||||
"tool_calls": tool_calls, # type: ignore[dict-item]
|
||||
}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"name": message.name,
|
||||
}
|
||||
elif isinstance(message, ToolMessage):
|
||||
message_dict = {
|
||||
"role": "tool",
|
||||
"content": message.content,
|
||||
"name": message.name, # type: ignore[dict-item]
|
||||
"tool_call_id": message.tool_call_id,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
message_dict["name"] = message.additional_kwargs["name"]
|
||||
return message_dict
|
||||
|
||||
|
||||
class ChatDeepInfra(BaseChatModel):
|
||||
"""A chat model that uses the DeepInfra API."""
|
||||
|
||||
# client: Any #: :meta private:
|
||||
model_name: str = Field(default="meta-llama/Llama-2-70b-chat-hf", alias="model")
|
||||
"""Model name to use."""
|
||||
|
||||
url: str = "https://api.deepinfra.com/v1/openai/chat/completions"
|
||||
"""URL to use for the API call."""
|
||||
|
||||
deepinfra_api_token: Optional[str] = None
|
||||
request_timeout: Optional[float] = Field(default=None, alias="timeout")
|
||||
temperature: Optional[float] = 1
|
||||
"""Run inference with this temperature. Must be in the closed
|
||||
interval [0.0, 1.0]."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for API call not explicitly specified."""
|
||||
top_p: Optional[float] = None
|
||||
"""Decode using nucleus sampling: consider the smallest set of tokens whose
|
||||
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
|
||||
top_k: Optional[int] = None
|
||||
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
|
||||
Must be positive."""
|
||||
n: int = 1
|
||||
"""Number of chat completions to generate for each prompt. Note that the API may
|
||||
not return the full n completions if duplicates are generated."""
|
||||
max_tokens: int = 256
|
||||
streaming: bool = False
|
||||
max_retries: int = 1
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
"model": self.model_name,
|
||||
"max_tokens": self.max_tokens,
|
||||
"stream": self.streaming,
|
||||
"n": self.n,
|
||||
"temperature": self.temperature,
|
||||
"request_timeout": self.request_timeout,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
@property
|
||||
def _client_params(self) -> Dict[str, Any]:
|
||||
"""Get the parameters used for the openai client."""
|
||||
return {**self._default_params}
|
||||
|
||||
def completion_with_retry(
|
||||
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
try:
|
||||
request_timeout = kwargs.pop("request_timeout")
|
||||
request = Requests(headers=self._headers())
|
||||
response = request.post(
|
||||
url=self._url(), data=self._body(kwargs), timeout=request_timeout
|
||||
)
|
||||
self._handle_status(response.status_code, response.text)
|
||||
return response
|
||||
except Exception as e:
|
||||
print("EX", e) # noqa: T201
|
||||
raise
|
||||
|
||||
return _completion_with_retry(**kwargs)
|
||||
|
||||
async def acompletion_with_retry(
|
||||
self,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the async completion call."""
|
||||
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
async def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
try:
|
||||
request_timeout = kwargs.pop("request_timeout")
|
||||
request = Requests(headers=self._headers())
|
||||
async with request.apost(
|
||||
url=self._url(), data=self._body(kwargs), timeout=request_timeout
|
||||
) as response:
|
||||
self._handle_status(response.status, await response.text())
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
print("EX", e) # noqa: T201
|
||||
raise
|
||||
|
||||
return await _completion_with_retry(**kwargs)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def init_defaults(cls, values: Dict) -> Any:
|
||||
"""Validate api key, python package exists, temperature, top_p, and top_k."""
|
||||
# For compatibility with LiteLLM
|
||||
api_key = get_from_dict_or_env(
|
||||
values,
|
||||
"deepinfra_api_key",
|
||||
"DEEPINFRA_API_KEY",
|
||||
default="",
|
||||
)
|
||||
values["deepinfra_api_token"] = get_from_dict_or_env(
|
||||
values,
|
||||
"deepinfra_api_token",
|
||||
"DEEPINFRA_API_TOKEN",
|
||||
default=api_key,
|
||||
)
|
||||
return values
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
if self.temperature is not None and not 0 <= self.temperature <= 1:
|
||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||
|
||||
if self.top_p is not None and not 0 <= self.top_p <= 1:
|
||||
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
||||
|
||||
if self.top_k is not None and self.top_k <= 0:
|
||||
raise ValueError("top_k must be positive")
|
||||
|
||||
return self
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
response = self.completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
)
|
||||
return self._create_chat_result(response.json())
|
||||
|
||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||
generations = []
|
||||
for res in response["choices"]:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
gen = ChatGeneration(
|
||||
message=message,
|
||||
generation_info=dict(finish_reason=res.get("finish_reason")),
|
||||
)
|
||||
generations.append(gen)
|
||||
token_usage = response.get("usage", {})
|
||||
llm_output = {"token_usage": token_usage, "model": self.model_name}
|
||||
res = ChatResult(generations=generations, llm_output=llm_output)
|
||||
return res
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
params = self._client_params
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
response = self.completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
)
|
||||
for line in _parse_stream(response.iter_lines()):
|
||||
chunk = _handle_sse_line(line)
|
||||
if chunk:
|
||||
cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {"messages": message_dicts, "stream": True, **params, **kwargs}
|
||||
|
||||
request_timeout = params.pop("request_timeout")
|
||||
request = Requests(headers=self._headers())
|
||||
async with request.apost(
|
||||
url=self._url(), data=self._body(params), timeout=request_timeout
|
||||
) as response:
|
||||
async for line in _parse_stream_async(response.content):
|
||||
chunk = _handle_sse_line(line)
|
||||
if chunk:
|
||||
cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
str(chunk.content), chunk=cg_chunk
|
||||
)
|
||||
yield cg_chunk
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {"messages": message_dicts, **params, **kwargs}
|
||||
|
||||
res = await self.acompletion_with_retry(run_manager=run_manager, **params)
|
||||
return self._create_chat_result(res)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
"model": self.model_name,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
"n": self.n,
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "deepinfra-chat"
|
||||
|
||||
def _handle_status(self, code: int, text: Any) -> None:
|
||||
if code >= 500:
|
||||
raise ChatDeepInfraException(
|
||||
f"DeepInfra Server error status {code}: {text}"
|
||||
)
|
||||
elif code >= 400:
|
||||
raise ValueError(f"DeepInfra received an invalid payload: {text}")
|
||||
elif code != 200:
|
||||
raise Exception(
|
||||
f"DeepInfra returned an unexpected response with status {code}: {text}"
|
||||
)
|
||||
|
||||
def _url(self) -> str:
|
||||
return self.url
|
||||
|
||||
def _headers(self) -> Dict:
|
||||
return {
|
||||
"Authorization": f"bearer {self.deepinfra_api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
def _body(self, kwargs: Any) -> Dict:
|
||||
return kwargs
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, AIMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
|
||||
Assumes model is compatible with OpenAI tool-calling API.
|
||||
|
||||
Args:
|
||||
tools: A list of tool definitions to bind to this chat model.
|
||||
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
|
||||
models, callables, and BaseTools will be automatically converted to
|
||||
their schema dictionary representation.
|
||||
**kwargs: Any additional parameters to pass to the
|
||||
:class:`~langchain.runnable.Runnable` constructor.
|
||||
"""
|
||||
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
|
||||
def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
|
||||
for line in rbody:
|
||||
_line = _parse_stream_helper(line)
|
||||
if _line is not None:
|
||||
yield _line
|
||||
|
||||
|
||||
async def _parse_stream_async(rbody: aiohttp.StreamReader) -> AsyncIterator[str]:
|
||||
async for line in rbody:
|
||||
_line = _parse_stream_helper(line)
|
||||
if _line is not None:
|
||||
yield _line
|
||||
|
||||
|
||||
def _parse_stream_helper(line: bytes) -> Optional[str]:
|
||||
if line and line.startswith(b"data:"):
|
||||
if line.startswith(b"data: "):
|
||||
# SSE event may be valid when it contain whitespace
|
||||
line = line[len(b"data: ") :]
|
||||
else:
|
||||
line = line[len(b"data:") :]
|
||||
if line.strip() == b"[DONE]":
|
||||
# return here will cause GeneratorExit exception in urllib3
|
||||
# and it will close http connection with TCP Reset
|
||||
return None
|
||||
else:
|
||||
return line.decode("utf-8")
|
||||
return None
|
||||
|
||||
|
||||
def _handle_sse_line(line: str) -> Optional[BaseMessageChunk]:
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
default_chunk_class = AIMessageChunk
|
||||
delta = obj.get("choices", [{}])[0].get("delta", {})
|
||||
return _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
except Exception:
|
||||
return None
|
||||
627
venv/Lib/site-packages/langchain_community/chat_models/edenai.py
Normal file
627
venv/Lib/site-packages/langchain_community/chat_models/edenai.py
Normal file
@@ -0,0 +1,627 @@
|
||||
import json
|
||||
import warnings
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
InvalidToolCall,
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
SecretStr,
|
||||
)
|
||||
|
||||
from langchain_community.utilities.requests import Requests
|
||||
|
||||
|
||||
def _result_to_chunked_message(generated_result: ChatResult) -> ChatGenerationChunk:
|
||||
message = generated_result.generations[0].message
|
||||
if isinstance(message, AIMessage) and message.tool_calls is not None:
|
||||
tool_call_chunks = [
|
||||
create_tool_call_chunk(
|
||||
name=tool_call["name"],
|
||||
args=json.dumps(tool_call["args"]),
|
||||
id=tool_call["id"],
|
||||
index=idx,
|
||||
)
|
||||
for idx, tool_call in enumerate(message.tool_calls)
|
||||
]
|
||||
message_chunk = AIMessageChunk(
|
||||
content=message.content,
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
)
|
||||
return ChatGenerationChunk(message=message_chunk)
|
||||
else:
|
||||
return cast(ChatGenerationChunk, generated_result.generations[0])
|
||||
|
||||
|
||||
def _message_role(type: str) -> str:
|
||||
role_mapping = {
|
||||
"ai": "assistant",
|
||||
"human": "user",
|
||||
"chat": "user",
|
||||
"AIMessageChunk": "assistant",
|
||||
}
|
||||
|
||||
if type in role_mapping:
|
||||
return role_mapping[type]
|
||||
else:
|
||||
raise ValueError(f"Unknown type: {type}")
|
||||
|
||||
|
||||
def _extract_edenai_tool_results_from_messages(
|
||||
messages: List[BaseMessage],
|
||||
) -> Tuple[List[Dict[str, Any]], List[BaseMessage]]:
|
||||
"""
|
||||
Get the last langchain tools messages to transform them into edenai tool_results
|
||||
Returns tool_results and messages without the extracted tool messages
|
||||
"""
|
||||
tool_results: List[Dict[str, Any]] = []
|
||||
other_messages = messages[:]
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, ToolMessage):
|
||||
tool_results = [
|
||||
{"id": msg.tool_call_id, "result": msg.content},
|
||||
*tool_results,
|
||||
]
|
||||
other_messages.pop()
|
||||
else:
|
||||
break
|
||||
return tool_results, other_messages
|
||||
|
||||
|
||||
def _format_edenai_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
|
||||
system = None
|
||||
formatted_messages = []
|
||||
|
||||
human_messages = list(filter(lambda msg: isinstance(msg, HumanMessage), messages))
|
||||
last_human_message = human_messages[-1] if human_messages else ""
|
||||
|
||||
tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages)
|
||||
for i, message in enumerate(other_messages):
|
||||
if isinstance(message, SystemMessage):
|
||||
if i != 0:
|
||||
raise ValueError("System message must be at beginning of message list.")
|
||||
system = message.content
|
||||
elif isinstance(message, ToolMessage):
|
||||
formatted_messages.append({"role": "tool", "message": message.content})
|
||||
elif message != last_human_message:
|
||||
formatted_messages.append(
|
||||
{
|
||||
"role": _message_role(message.type),
|
||||
"message": message.content,
|
||||
"tool_calls": _format_tool_calls_to_edenai_tool_calls(message),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"text": getattr(last_human_message, "content", ""),
|
||||
"previous_history": formatted_messages,
|
||||
"chatbot_global_action": system,
|
||||
"tool_results": tool_results,
|
||||
}
|
||||
|
||||
|
||||
def _format_tool_calls_to_edenai_tool_calls(message: BaseMessage) -> List:
|
||||
tool_calls = getattr(message, "tool_calls", [])
|
||||
invalid_tool_calls = getattr(message, "invalid_tool_calls", [])
|
||||
edenai_tool_calls = []
|
||||
|
||||
for invalid_tool_call in invalid_tool_calls:
|
||||
edenai_tool_calls.append(
|
||||
{
|
||||
"arguments": invalid_tool_call.get("args"),
|
||||
"id": invalid_tool_call.get("id"),
|
||||
"name": invalid_tool_call.get("name"),
|
||||
}
|
||||
)
|
||||
|
||||
for tool_call in tool_calls:
|
||||
tool_args = tool_call.get("args", {})
|
||||
try:
|
||||
arguments = json.dumps(tool_args)
|
||||
except TypeError:
|
||||
arguments = str(tool_args)
|
||||
edenai_tool_calls.append(
|
||||
{
|
||||
"arguments": arguments,
|
||||
"id": tool_call["id"],
|
||||
"name": tool_call["name"],
|
||||
}
|
||||
)
|
||||
return edenai_tool_calls
|
||||
|
||||
|
||||
def _extract_tool_calls_from_edenai_response(
|
||||
provider_response: Dict[str, Any],
|
||||
) -> Tuple[List[ToolCall], List[InvalidToolCall]]:
|
||||
tool_calls = []
|
||||
invalid_tool_calls = []
|
||||
|
||||
message = provider_response.get("message", {})[1]
|
||||
|
||||
if raw_tool_calls := message.get("tool_calls"):
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
try:
|
||||
tool_calls.append(
|
||||
create_tool_call(
|
||||
name=raw_tool_call["name"],
|
||||
args=json.loads(raw_tool_call["arguments"]),
|
||||
id=raw_tool_call["id"],
|
||||
)
|
||||
)
|
||||
except json.JSONDecodeError as exc:
|
||||
invalid_tool_calls.append(
|
||||
create_invalid_tool_call(
|
||||
name=raw_tool_call.get("name"),
|
||||
args=raw_tool_call.get("arguments"),
|
||||
id=raw_tool_call.get("id"),
|
||||
error=f"Received JSONDecodeError {exc}",
|
||||
)
|
||||
)
|
||||
|
||||
return tool_calls, invalid_tool_calls
|
||||
|
||||
|
||||
class ChatEdenAI(BaseChatModel):
|
||||
"""`EdenAI` chat large language models.
|
||||
|
||||
`EdenAI` is a versatile platform that allows you to access various language models
|
||||
from different providers such as Google, OpenAI, Cohere, Mistral and more.
|
||||
|
||||
To get started, make sure you have the environment variable ``EDENAI_API_KEY``
|
||||
set with your API key, or pass it as a named parameter to the constructor.
|
||||
|
||||
Additionally, `EdenAI` provides the flexibility to choose from a variety of models,
|
||||
including the ones like "gpt-4".
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatEdenAI
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
# Initialize `ChatEdenAI` with the desired configuration
|
||||
chat = ChatEdenAI(
|
||||
provider="openai",
|
||||
model="gpt-4",
|
||||
max_tokens=256,
|
||||
temperature=0.75)
|
||||
|
||||
# Create a list of messages to interact with the model
|
||||
messages = [HumanMessage(content="hello")]
|
||||
|
||||
# Invoke the model with the provided messages
|
||||
chat.invoke(messages)
|
||||
|
||||
`EdenAI` goes beyond mere model invocation. It empowers you with advanced features :
|
||||
|
||||
- **Multiple Providers**: access to a diverse range of llms offered by various
|
||||
providers giving you the freedom to choose the best-suited model for your use case.
|
||||
|
||||
- **Fallback Mechanism**: Set a fallback mechanism to ensure seamless operations
|
||||
even if the primary provider is unavailable, you can easily switches to an
|
||||
alternative provider.
|
||||
|
||||
- **Usage Statistics**: Track usage statistics on a per-project
|
||||
and per-API key basis.
|
||||
This feature allows you to monitor and manage resource consumption effectively.
|
||||
|
||||
- **Monitoring and Observability**: `EdenAI` provides comprehensive monitoring
|
||||
and observability tools on the platform.
|
||||
|
||||
Example of setting up a fallback mechanism:
|
||||
.. code-block:: python
|
||||
|
||||
# Initialize `ChatEdenAI` with a fallback provider
|
||||
chat_with_fallback = ChatEdenAI(
|
||||
provider="openai",
|
||||
model="gpt-4",
|
||||
max_tokens=256,
|
||||
temperature=0.75,
|
||||
fallback_provider="google")
|
||||
|
||||
you can find more details here : https://docs.edenai.co/reference/text_chat_create
|
||||
"""
|
||||
|
||||
provider: str = "openai"
|
||||
"""chat provider to use (eg: openai,google etc.)"""
|
||||
|
||||
model: Optional[str] = None
|
||||
"""
|
||||
model name for above provider (eg: 'gpt-4' for openai)
|
||||
available models are shown on https://docs.edenai.co/ under 'available providers'
|
||||
"""
|
||||
|
||||
max_tokens: int = 256
|
||||
"""Denotes the number of tokens to predict per generation."""
|
||||
|
||||
temperature: Optional[float] = 0
|
||||
"""A non-negative float that tunes the degree of randomness in generation."""
|
||||
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results."""
|
||||
|
||||
fallback_providers: Optional[str] = None
|
||||
"""Providers in this will be used as fallback if the call to provider fails."""
|
||||
|
||||
edenai_api_url: str = "https://api.edenai.run/v2"
|
||||
|
||||
edenai_api_key: Optional[SecretStr] = Field(None, description="EdenAI API Token")
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key exists in environment."""
|
||||
values["edenai_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "edenai_api_key", "EDENAI_API_KEY")
|
||||
)
|
||||
return values
|
||||
|
||||
@staticmethod
|
||||
def get_user_agent() -> str:
|
||||
from langchain_community import __version__
|
||||
|
||||
return f"langchain/{__version__}"
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "edenai-chat"
|
||||
|
||||
@property
|
||||
def _api_key(self) -> str:
|
||||
if self.edenai_api_key:
|
||||
return self.edenai_api_key.get_secret_value()
|
||||
return ""
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
"""Call out to EdenAI's chat endpoint."""
|
||||
if "available_tools" in kwargs:
|
||||
yield self._stream_with_tools_as_generate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return
|
||||
url = f"{self.edenai_api_url}/text/chat/stream"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"User-Agent": self.get_user_agent(),
|
||||
}
|
||||
formatted_data = _format_edenai_messages(messages=messages)
|
||||
payload: Dict[str, Any] = {
|
||||
"providers": self.provider,
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"fallback_providers": self.fallback_providers,
|
||||
**formatted_data,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
|
||||
if self.model is not None:
|
||||
payload["settings"] = {self.provider: self.model}
|
||||
|
||||
request = Requests(headers=headers)
|
||||
response = request.post(url=url, data=payload, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
for chunk_response in response.iter_lines():
|
||||
chunk = json.loads(chunk_response.decode())
|
||||
token = chunk["text"]
|
||||
cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(token, chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
if "available_tools" in kwargs:
|
||||
yield await self._astream_with_tools_as_agenerate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return
|
||||
url = f"{self.edenai_api_url}/text/chat/stream"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"User-Agent": self.get_user_agent(),
|
||||
}
|
||||
formatted_data = _format_edenai_messages(messages=messages)
|
||||
payload: Dict[str, Any] = {
|
||||
"providers": self.provider,
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"fallback_providers": self.fallback_providers,
|
||||
**formatted_data,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
|
||||
if self.model is not None:
|
||||
payload["settings"] = {self.provider: self.model}
|
||||
|
||||
async with ClientSession() as session:
|
||||
async with session.post(url, json=payload, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
async for chunk_response in response.content:
|
||||
chunk = json.loads(chunk_response.decode())
|
||||
token = chunk["text"]
|
||||
cg_chunk = ChatGenerationChunk(
|
||||
message=AIMessageChunk(content=token)
|
||||
)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
token=chunk["text"], chunk=cg_chunk
|
||||
)
|
||||
yield cg_chunk
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, AIMessage]:
|
||||
formatted_tools = [convert_to_openai_tool(tool)["function"] for tool in tools]
|
||||
formatted_tool_choice = "required" if tool_choice == "any" else tool_choice
|
||||
return super().bind(
|
||||
available_tools=formatted_tools, tool_choice=formatted_tool_choice, **kwargs
|
||||
)
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Union[Dict, Type[BaseModel]],
|
||||
*,
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
llm = self.bind_tools([schema], tool_choice="required")
|
||||
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[schema], first_tool_only=True
|
||||
)
|
||||
else:
|
||||
key_name = convert_to_openai_tool(schema)["function"]["name"]
|
||||
output_parser = JsonOutputKeyToolsParser(
|
||||
key_name=key_name, first_tool_only=True
|
||||
)
|
||||
|
||||
if include_raw:
|
||||
parser_assign = RunnablePassthrough.assign(
|
||||
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
|
||||
)
|
||||
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||
parser_with_fallback = parser_assign.with_fallbacks(
|
||||
[parser_none], exception_key="parsing_error"
|
||||
)
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
else:
|
||||
return llm | output_parser
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Call out to EdenAI's chat endpoint."""
|
||||
if self.streaming:
|
||||
if "available_tools" in kwargs:
|
||||
warnings.warn(
|
||||
"stream: Tool use is not yet supported in streaming mode."
|
||||
)
|
||||
else:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
url = f"{self.edenai_api_url}/text/chat"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"User-Agent": self.get_user_agent(),
|
||||
}
|
||||
formatted_data = _format_edenai_messages(messages=messages)
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"providers": self.provider,
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"fallback_providers": self.fallback_providers,
|
||||
**formatted_data,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
|
||||
if self.model is not None:
|
||||
payload["settings"] = {self.provider: self.model}
|
||||
|
||||
request = Requests(headers=headers)
|
||||
response = request.post(url=url, data=payload)
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
provider_response = data[self.provider]
|
||||
|
||||
if self.fallback_providers:
|
||||
fallback_response = data.get(self.fallback_providers)
|
||||
if fallback_response:
|
||||
provider_response = fallback_response
|
||||
|
||||
if provider_response.get("status") == "fail":
|
||||
err_msg = provider_response.get("error", {}).get("message")
|
||||
raise Exception(err_msg)
|
||||
|
||||
tool_calls, invalid_tool_calls = _extract_tool_calls_from_edenai_response(
|
||||
provider_response
|
||||
)
|
||||
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(
|
||||
message=AIMessage(
|
||||
content=provider_response["generated_text"] or "",
|
||||
tool_calls=tool_calls,
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
)
|
||||
)
|
||||
],
|
||||
llm_output=data,
|
||||
)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
if "available_tools" in kwargs:
|
||||
warnings.warn(
|
||||
"stream: Tool use is not yet supported in streaming mode."
|
||||
)
|
||||
else:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
url = f"{self.edenai_api_url}/text/chat"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"User-Agent": self.get_user_agent(),
|
||||
}
|
||||
formatted_data = _format_edenai_messages(messages=messages)
|
||||
payload: Dict[str, Any] = {
|
||||
"providers": self.provider,
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"fallback_providers": self.fallback_providers,
|
||||
**formatted_data,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
|
||||
if self.model is not None:
|
||||
payload["settings"] = {self.provider: self.model}
|
||||
|
||||
async with ClientSession() as session:
|
||||
async with session.post(url, json=payload, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
provider_response = data[self.provider]
|
||||
|
||||
if self.fallback_providers:
|
||||
fallback_response = data.get(self.fallback_providers)
|
||||
if fallback_response:
|
||||
provider_response = fallback_response
|
||||
|
||||
if provider_response.get("status") == "fail":
|
||||
err_msg = provider_response.get("error", {}).get("message")
|
||||
raise Exception(err_msg)
|
||||
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(
|
||||
message=AIMessage(
|
||||
content=provider_response["generated_text"]
|
||||
)
|
||||
)
|
||||
],
|
||||
llm_output=data,
|
||||
)
|
||||
|
||||
def _stream_with_tools_as_generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]],
|
||||
run_manager: Optional[CallbackManagerForLLMRun],
|
||||
**kwargs: Any,
|
||||
) -> ChatGenerationChunk:
|
||||
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
|
||||
result = self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
return _result_to_chunked_message(result)
|
||||
|
||||
async def _astream_with_tools_as_agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]],
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun],
|
||||
**kwargs: Any,
|
||||
) -> ChatGenerationChunk:
|
||||
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
|
||||
result = await self._agenerate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return _result_to_chunked_message(result)
|
||||
229
venv/Lib/site-packages/langchain_community/chat_models/ernie.py
Normal file
229
venv/Lib/site-packages/langchain_community/chat_models/ernie.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import model_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_dict
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.0.13",
|
||||
alternative="langchain_community.chat_models.QianfanChatEndpoint",
|
||||
)
|
||||
class ErnieBotChat(BaseChatModel):
|
||||
"""`ERNIE-Bot` large language model.
|
||||
|
||||
ERNIE-Bot is a large language model developed by Baidu,
|
||||
covering a huge amount of Chinese data.
|
||||
|
||||
To use, you should have the `ernie_client_id` and `ernie_client_secret` set,
|
||||
or set the environment variable `ERNIE_CLIENT_ID` and `ERNIE_CLIENT_SECRET`.
|
||||
|
||||
Note:
|
||||
access_token will be automatically generated based on client_id and client_secret,
|
||||
and will be regenerated after expiration (30 days).
|
||||
|
||||
Default model is `ERNIE-Bot-turbo`,
|
||||
currently supported models are `ERNIE-Bot-turbo`, `ERNIE-Bot`, `ERNIE-Bot-8K`,
|
||||
`ERNIE-Bot-4`, `ERNIE-Bot-turbo-AI`.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ErnieBotChat
|
||||
chat = ErnieBotChat(model_name='ERNIE-Bot')
|
||||
|
||||
|
||||
Deprecated Note:
|
||||
Please use `QianfanChatEndpoint` instead of this class.
|
||||
`QianfanChatEndpoint` is a more suitable choice for production.
|
||||
|
||||
Always test your code after changing to `QianfanChatEndpoint`.
|
||||
|
||||
Example of `QianfanChatEndpoint`:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import QianfanChatEndpoint
|
||||
qianfan_chat = QianfanChatEndpoint(model="ERNIE-Bot",
|
||||
endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
|
||||
|
||||
"""
|
||||
|
||||
ernie_api_base: Optional[str] = None
|
||||
"""Baidu application custom endpoints"""
|
||||
|
||||
ernie_client_id: Optional[str] = None
|
||||
"""Baidu application client id"""
|
||||
|
||||
ernie_client_secret: Optional[str] = None
|
||||
"""Baidu application client secret"""
|
||||
|
||||
access_token: Optional[str] = None
|
||||
"""access token is generated by client id and client secret,
|
||||
setting this value directly will cause an error"""
|
||||
|
||||
model_name: str = "ERNIE-Bot-turbo"
|
||||
"""model name of ernie, default is `ERNIE-Bot-turbo`.
|
||||
Currently supported `ERNIE-Bot-turbo`, `ERNIE-Bot`"""
|
||||
|
||||
system: Optional[str] = None
|
||||
"""system is mainly used for model character design,
|
||||
for example, you are an AI assistant produced by xxx company.
|
||||
The length of the system is limiting of 1024 characters."""
|
||||
|
||||
request_timeout: Optional[int] = 60
|
||||
"""request timeout for chat http requests"""
|
||||
|
||||
streaming: Optional[bool] = False
|
||||
"""streaming mode. not supported yet."""
|
||||
|
||||
top_p: Optional[float] = 0.8
|
||||
temperature: Optional[float] = 0.95
|
||||
penalty_score: Optional[float] = 1
|
||||
|
||||
_lock = threading.Lock()
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
values["ernie_api_base"] = get_from_dict_or_env(
|
||||
values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
|
||||
)
|
||||
values["ernie_client_id"] = get_from_dict_or_env(
|
||||
values,
|
||||
"ernie_client_id",
|
||||
"ERNIE_CLIENT_ID",
|
||||
)
|
||||
values["ernie_client_secret"] = get_from_dict_or_env(
|
||||
values,
|
||||
"ernie_client_secret",
|
||||
"ERNIE_CLIENT_SECRET",
|
||||
)
|
||||
return values
|
||||
|
||||
def _chat(self, payload: object) -> dict:
|
||||
base_url = f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/chat"
|
||||
model_paths = {
|
||||
"ERNIE-Bot-turbo": "eb-instant",
|
||||
"ERNIE-Bot": "completions",
|
||||
"ERNIE-Bot-8K": "ernie_bot_8k",
|
||||
"ERNIE-Bot-4": "completions_pro",
|
||||
"ERNIE-Bot-turbo-AI": "ai_apaas",
|
||||
"BLOOMZ-7B": "bloomz_7b1",
|
||||
"Llama-2-7b-chat": "llama_2_7b",
|
||||
"Llama-2-13b-chat": "llama_2_13b",
|
||||
"Llama-2-70b-chat": "llama_2_70b",
|
||||
}
|
||||
if self.model_name in model_paths:
|
||||
url = f"{base_url}/{model_paths[self.model_name]}"
|
||||
else:
|
||||
raise ValueError(f"Got unknown model_name {self.model_name}")
|
||||
|
||||
resp = requests.post(
|
||||
url,
|
||||
timeout=self.request_timeout,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
params={"access_token": self.access_token},
|
||||
json=payload,
|
||||
)
|
||||
return resp.json()
|
||||
|
||||
def _refresh_access_token_with_lock(self) -> None:
|
||||
with self._lock:
|
||||
logger.debug("Refreshing access token")
|
||||
base_url: str = f"{self.ernie_api_base}/oauth/2.0/token"
|
||||
resp = requests.post(
|
||||
base_url,
|
||||
timeout=10,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
params={
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.ernie_client_id,
|
||||
"client_secret": self.ernie_client_secret,
|
||||
},
|
||||
)
|
||||
self.access_token = str(resp.json().get("access_token"))
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
raise ValueError("`streaming` option currently unsupported.")
|
||||
|
||||
if not self.access_token:
|
||||
self._refresh_access_token_with_lock()
|
||||
payload = {
|
||||
"messages": [_convert_message_to_dict(m) for m in messages],
|
||||
"top_p": self.top_p,
|
||||
"temperature": self.temperature,
|
||||
"penalty_score": self.penalty_score,
|
||||
"system": self.system,
|
||||
**kwargs,
|
||||
}
|
||||
logger.debug(f"Payload for ernie api is {payload}")
|
||||
resp = self._chat(payload)
|
||||
if resp.get("error_code"):
|
||||
if resp.get("error_code") == 111:
|
||||
logger.debug("access_token expired, refresh it")
|
||||
self._refresh_access_token_with_lock()
|
||||
resp = self._chat(payload)
|
||||
else:
|
||||
raise ValueError(f"Error from ErnieChat api response: {resp}")
|
||||
return self._create_chat_result(resp)
|
||||
|
||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||
if "function_call" in response:
|
||||
additional_kwargs = {
|
||||
"function_call": dict(response.get("function_call", {}))
|
||||
}
|
||||
else:
|
||||
additional_kwargs = {}
|
||||
generations = [
|
||||
ChatGeneration(
|
||||
message=AIMessage(
|
||||
content=response.get("result", ""),
|
||||
additional_kwargs={**additional_kwargs},
|
||||
)
|
||||
)
|
||||
]
|
||||
token_usage = response.get("usage", {})
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model_name}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "ernie-bot-chat"
|
||||
@@ -0,0 +1,185 @@
|
||||
"""EverlyAI Endpoints chat wrapper. Relies heavily on ChatOpenAI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import warnings
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from langchain_community.adapters.openai import convert_message_to_dict
|
||||
from langchain_community.chat_models.openai import (
|
||||
ChatOpenAI,
|
||||
_import_tiktoken,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import tiktoken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_API_BASE = "https://everlyai.xyz/hosted"
|
||||
DEFAULT_MODEL = "meta-llama/Llama-2-7b-chat-hf"
|
||||
|
||||
|
||||
class ChatEverlyAI(ChatOpenAI):
|
||||
"""`EverlyAI` Chat large language models.
|
||||
|
||||
To use, you should have the ``openai`` python package installed, and the
|
||||
environment variable ``EVERLYAI_API_KEY`` set with your API key.
|
||||
Alternatively, you can use the everlyai_api_key keyword argument.
|
||||
|
||||
Any parameters that are valid to be passed to the `openai.create` call can be passed
|
||||
in, even if not explicitly saved on this class.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatEverlyAI
|
||||
chat = ChatEverlyAI(model_name="meta-llama/Llama-2-7b-chat-hf")
|
||||
"""
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "everlyai-chat"
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"everlyai_api_key": "EVERLYAI_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return False
|
||||
|
||||
everlyai_api_key: Optional[str] = None
|
||||
"""EverlyAI Endpoints API keys."""
|
||||
model_name: str = Field(default=DEFAULT_MODEL, alias="model")
|
||||
"""Model name to use."""
|
||||
everlyai_api_base: str = DEFAULT_API_BASE
|
||||
"""Base URL path for API requests."""
|
||||
available_models: Optional[Set[str]] = None
|
||||
"""Available models from EverlyAI API."""
|
||||
|
||||
@staticmethod
|
||||
def get_available_models() -> Set[str]:
|
||||
"""Get available models from EverlyAI API."""
|
||||
# EverlyAI doesn't yet support dynamically query for available models.
|
||||
return set(
|
||||
[
|
||||
"meta-llama/Llama-2-7b-chat-hf",
|
||||
"meta-llama/Llama-2-13b-chat-hf-quantized",
|
||||
]
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment_override(cls, values: dict) -> Any:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["openai_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
"everlyai_api_key",
|
||||
"EVERLYAI_API_KEY",
|
||||
)
|
||||
)
|
||||
values["openai_api_base"] = DEFAULT_API_BASE
|
||||
|
||||
try:
|
||||
import openai
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`.",
|
||||
) from e
|
||||
try:
|
||||
values["client"] = openai.ChatCompletion
|
||||
except AttributeError as exc:
|
||||
raise ValueError(
|
||||
"`openai` has no `ChatCompletion` attribute, this is likely "
|
||||
"due to an old version of the openai package. Try upgrading it "
|
||||
"with `pip install --upgrade openai`.",
|
||||
) from exc
|
||||
|
||||
if "model_name" not in values.keys():
|
||||
values["model_name"] = DEFAULT_MODEL
|
||||
|
||||
model_name = values["model_name"]
|
||||
|
||||
available_models = cls.get_available_models()
|
||||
|
||||
if model_name not in available_models:
|
||||
raise ValueError(
|
||||
f"Model name {model_name} not found in available models: "
|
||||
f"{available_models}.",
|
||||
)
|
||||
|
||||
values["available_models"] = available_models
|
||||
|
||||
return values
|
||||
|
||||
def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
|
||||
tiktoken_ = _import_tiktoken()
|
||||
if self.tiktoken_model_name is not None:
|
||||
model = self.tiktoken_model_name
|
||||
else:
|
||||
model = self.model_name
|
||||
# Returns the number of tokens used by a list of messages.
|
||||
try:
|
||||
encoding = tiktoken_.encoding_for_model("gpt-3.5-turbo-0301")
|
||||
except KeyError:
|
||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken_.get_encoding(model)
|
||||
return model, encoding
|
||||
|
||||
def get_num_tokens_from_messages(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
tools: Optional[
|
||||
Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
|
||||
] = None,
|
||||
) -> int:
|
||||
"""Calculate num tokens with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
if tools is not None:
|
||||
warnings.warn(
|
||||
"Counting tokens in tool schemas is not yet supported. Ignoring tools."
|
||||
)
|
||||
if sys.version_info[1] <= 7:
|
||||
return super().get_num_tokens_from_messages(messages)
|
||||
model, encoding = self._get_encoding_model()
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
num_tokens = 0
|
||||
messages_dict = [convert_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
# Cast str(value) in case the message value is not a string
|
||||
# This occurs with function messages
|
||||
num_tokens += len(encoding.encode(str(value)))
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
return num_tokens
|
||||
105
venv/Lib/site-packages/langchain_community/chat_models/fake.py
Normal file
105
venv/Lib/site-packages/langchain_community/chat_models/fake.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Fake ChatModel for testing purposes."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
|
||||
|
||||
class FakeMessagesListChatModel(BaseChatModel):
|
||||
"""Fake ChatModel for testing purposes."""
|
||||
|
||||
responses: List[BaseMessage]
|
||||
sleep: Optional[float] = None
|
||||
i: int = 0
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
response = self.responses[self.i]
|
||||
if self.i < len(self.responses) - 1:
|
||||
self.i += 1
|
||||
else:
|
||||
self.i = 0
|
||||
generation = ChatGeneration(message=response)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-messages-list-chat-model"
|
||||
|
||||
|
||||
class FakeListChatModel(SimpleChatModel):
|
||||
"""Fake ChatModel for testing purposes."""
|
||||
|
||||
responses: List
|
||||
sleep: Optional[float] = None
|
||||
i: int = 0
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-list-chat-model"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""First try to lookup in queries, else return 'foo' or 'bar'."""
|
||||
response = self.responses[self.i]
|
||||
if self.i < len(self.responses) - 1:
|
||||
self.i += 1
|
||||
else:
|
||||
self.i = 0
|
||||
return response
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Union[List[str], None] = None,
|
||||
run_manager: Union[CallbackManagerForLLMRun, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
response = self.responses[self.i]
|
||||
if self.i < len(self.responses) - 1:
|
||||
self.i += 1
|
||||
else:
|
||||
self.i = 0
|
||||
for c in response:
|
||||
if self.sleep is not None:
|
||||
time.sleep(self.sleep)
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Union[List[str], None] = None,
|
||||
run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
response = self.responses[self.i]
|
||||
if self.i < len(self.responses) - 1:
|
||||
self.i += 1
|
||||
else:
|
||||
self.i = 0
|
||||
for c in response:
|
||||
if self.sleep is not None:
|
||||
await asyncio.sleep(self.sleep)
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=c))
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
return {"responses": self.responses}
|
||||
@@ -0,0 +1,372 @@
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.language_models.llms import create_base_retry_decorator
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.utils import convert_to_secret_str
|
||||
from langchain_core.utils.env import get_from_dict_or_env
|
||||
from pydantic import Field, SecretStr, model_validator
|
||||
|
||||
from langchain_community.adapters.openai import convert_message_to_dict
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Any, default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
"""Convert a delta response to a message chunk."""
|
||||
role = _dict.role
|
||||
content = _dict.content or ""
|
||||
additional_kwargs: Dict = {}
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role == "function" or default_class == FunctionMessageChunk:
|
||||
return FunctionMessageChunk(content=content, name=_dict.name)
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role)
|
||||
else:
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def convert_dict_to_message(_dict: Any) -> BaseMessage:
|
||||
"""Convert a dict response to a message."""
|
||||
role = _dict.role
|
||||
content = _dict.content or ""
|
||||
if role == "user":
|
||||
return HumanMessage(content=content)
|
||||
elif role == "assistant":
|
||||
content = _dict.content
|
||||
additional_kwargs: Dict = {}
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=content)
|
||||
elif role == "function":
|
||||
return FunctionMessage(content=content, name=_dict.name)
|
||||
else:
|
||||
return ChatMessage(content=content, role=role)
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.0.26",
|
||||
removal="1.0",
|
||||
alternative_import="langchain_fireworks.ChatFireworks",
|
||||
)
|
||||
class ChatFireworks(BaseChatModel):
|
||||
"""Fireworks Chat models."""
|
||||
|
||||
model: str = "accounts/fireworks/models/llama-v2-7b-chat"
|
||||
model_kwargs: dict = Field(
|
||||
default_factory=lambda: {
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 512,
|
||||
"top_p": 1,
|
||||
}.copy()
|
||||
)
|
||||
fireworks_api_key: Optional[SecretStr] = None
|
||||
max_retries: int = 20
|
||||
use_retry: bool = True
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"fireworks_api_key": "FIREWORKS_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "fireworks"]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
"""Validate that api key in environment."""
|
||||
try:
|
||||
import fireworks.client
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import fireworks-ai python package. "
|
||||
"Please install it with `pip install fireworks-ai`."
|
||||
) from e
|
||||
fireworks_api_key = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
|
||||
)
|
||||
fireworks.client.api_key = fireworks_api_key.get_secret_value()
|
||||
return values
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "fireworks-chat"
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
message_dicts = self._create_message_dicts(messages)
|
||||
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": message_dicts,
|
||||
**self.model_kwargs,
|
||||
**kwargs,
|
||||
}
|
||||
response = completion_with_retry(
|
||||
self,
|
||||
self.use_retry,
|
||||
run_manager=run_manager,
|
||||
stop=stop,
|
||||
**params,
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
message_dicts = self._create_message_dicts(messages)
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": message_dicts,
|
||||
**self.model_kwargs,
|
||||
**kwargs,
|
||||
}
|
||||
response = await acompletion_with_retry(
|
||||
self, self.use_retry, run_manager=run_manager, stop=stop, **params
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
if llm_outputs[0] is None:
|
||||
return {}
|
||||
return llm_outputs[0]
|
||||
|
||||
def _create_chat_result(self, response: Any) -> ChatResult:
|
||||
generations = []
|
||||
for res in response.choices:
|
||||
message = convert_dict_to_message(res.message)
|
||||
gen = ChatGeneration(
|
||||
message=message,
|
||||
generation_info=dict(finish_reason=res.finish_reason),
|
||||
)
|
||||
generations.append(gen)
|
||||
llm_output = {"model": self.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage]
|
||||
) -> List[Dict[str, Any]]:
|
||||
message_dicts = [convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts = self._create_message_dicts(messages)
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": message_dicts,
|
||||
"stream": True,
|
||||
**self.model_kwargs,
|
||||
**kwargs,
|
||||
}
|
||||
for chunk in completion_with_retry(
|
||||
self, self.use_retry, run_manager=run_manager, stop=stop, **params
|
||||
):
|
||||
choice = chunk.choices[0]
|
||||
chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
|
||||
finish_reason = choice.finish_reason
|
||||
generation_info = (
|
||||
dict(finish_reason=finish_reason) if finish_reason is not None else None
|
||||
)
|
||||
default_chunk_class = chunk.__class__
|
||||
cg_chunk = ChatGenerationChunk(
|
||||
message=chunk, generation_info=generation_info
|
||||
)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
message_dicts = self._create_message_dicts(messages)
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": message_dicts,
|
||||
"stream": True,
|
||||
**self.model_kwargs,
|
||||
**kwargs,
|
||||
}
|
||||
async for chunk in await acompletion_with_retry_streaming(
|
||||
self, self.use_retry, run_manager=run_manager, stop=stop, **params
|
||||
):
|
||||
choice = chunk.choices[0]
|
||||
chunk = _convert_delta_to_message_chunk(choice.delta, default_chunk_class)
|
||||
finish_reason = choice.finish_reason
|
||||
generation_info = (
|
||||
dict(finish_reason=finish_reason) if finish_reason is not None else None
|
||||
)
|
||||
default_chunk_class = chunk.__class__
|
||||
cg_chunk = ChatGenerationChunk(
|
||||
message=chunk, generation_info=generation_info
|
||||
)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(token=cg_chunk.text, chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
|
||||
def conditional_decorator(
|
||||
condition: bool, decorator: Callable[[Any], Any]
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Define conditional decorator.
|
||||
|
||||
Args:
|
||||
condition: The condition.
|
||||
decorator: The decorator.
|
||||
|
||||
Returns:
|
||||
The decorated function.
|
||||
"""
|
||||
|
||||
def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
|
||||
if condition:
|
||||
return decorator(func)
|
||||
return func
|
||||
|
||||
return actual_decorator
|
||||
|
||||
|
||||
def completion_with_retry(
|
||||
llm: ChatFireworks,
|
||||
use_retry: bool,
|
||||
*,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
import fireworks.client
|
||||
|
||||
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||
|
||||
@conditional_decorator(use_retry, retry_decorator)
|
||||
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
return fireworks.client.ChatCompletion.create(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return _completion_with_retry(**kwargs)
|
||||
|
||||
|
||||
async def acompletion_with_retry(
|
||||
llm: ChatFireworks,
|
||||
use_retry: bool,
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the async completion call."""
|
||||
import fireworks.client
|
||||
|
||||
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||
|
||||
@conditional_decorator(use_retry, retry_decorator)
|
||||
async def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
return await fireworks.client.ChatCompletion.acreate(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return await _completion_with_retry(**kwargs)
|
||||
|
||||
|
||||
async def acompletion_with_retry_streaming(
|
||||
llm: ChatFireworks,
|
||||
use_retry: bool,
|
||||
*,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the completion call for streaming."""
|
||||
import fireworks.client
|
||||
|
||||
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||
|
||||
@conditional_decorator(use_retry, retry_decorator)
|
||||
async def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
return fireworks.client.ChatCompletion.acreate(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return await _completion_with_retry(**kwargs)
|
||||
|
||||
|
||||
def _create_retry_decorator(
|
||||
llm: ChatFireworks,
|
||||
run_manager: Optional[
|
||||
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
||||
] = None,
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Define retry mechanism."""
|
||||
import fireworks.client
|
||||
|
||||
errors = [
|
||||
fireworks.client.error.RateLimitError,
|
||||
fireworks.client.error.InternalServerError,
|
||||
fireworks.client.error.BadGatewayError,
|
||||
fireworks.client.error.ServiceUnavailableError,
|
||||
]
|
||||
return create_base_retry_decorator(
|
||||
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
|
||||
)
|
||||
@@ -0,0 +1,217 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
|
||||
from langchain_community.llms.friendli import BaseFriendli
|
||||
|
||||
|
||||
def get_role(message: BaseMessage) -> str:
|
||||
"""Get role of the message.
|
||||
|
||||
Args:
|
||||
message (BaseMessage): The message object.
|
||||
|
||||
Raises:
|
||||
ValueError: Raised when the message is of an unknown type.
|
||||
|
||||
Returns:
|
||||
str: The role of the message.
|
||||
"""
|
||||
if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
|
||||
return "user"
|
||||
if isinstance(message, AIMessage):
|
||||
return "assistant"
|
||||
if isinstance(message, SystemMessage):
|
||||
return "system"
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
|
||||
def get_chat_request(messages: List[BaseMessage]) -> Dict[str, Any]:
|
||||
"""Get a request of the Friendli chat API.
|
||||
|
||||
Args:
|
||||
messages (List[BaseMessage]): Messages comprising the conversation so far.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The request for the Friendli chat API.
|
||||
"""
|
||||
return {
|
||||
"messages": [
|
||||
{"role": get_role(message), "content": message.content}
|
||||
for message in messages
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class ChatFriendli(BaseChatModel, BaseFriendli):
|
||||
"""Friendli LLM for chat.
|
||||
|
||||
``friendli-client`` package should be installed with `pip install friendli-client`.
|
||||
You must set ``FRIENDLI_TOKEN`` environment variable or provide the value of your
|
||||
personal access token for the ``friendli_token`` argument.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import FriendliChat
|
||||
|
||||
chat = Friendli(
|
||||
model="meta-llama-3.1-8b-instruct", friendli_token="YOUR FRIENDLI TOKEN"
|
||||
)
|
||||
chat.invoke("What is generative AI?")
|
||||
"""
|
||||
|
||||
model: str = "meta-llama-3.1-8b-instruct"
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"friendli_token": "FRIENDLI_TOKEN"}
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Friendli completions API."""
|
||||
return {
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"max_tokens": self.max_tokens,
|
||||
"stop": self.stop,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {"model": self.model, **self._default_params}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "friendli-chat"
|
||||
|
||||
def _get_invocation_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
params = self._default_params
|
||||
if self.stop is not None and stop is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop is not None:
|
||||
params["stop"] = self.stop
|
||||
else:
|
||||
params["stop"] = stop
|
||||
return {**params, **kwargs}
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
stream = self.client.chat.completions.create(
|
||||
**get_chat_request(messages), stream=True, model=self.model, **params
|
||||
)
|
||||
for chunk in stream:
|
||||
delta = chunk.choices[0].delta.content
|
||||
if delta:
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(delta)
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
stream = await self.async_client.chat.completions.create(
|
||||
**get_chat_request(messages), stream=True, model=self.model, **params
|
||||
)
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta.content
|
||||
if delta:
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(delta)
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
response = self.client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
"role": get_role(message),
|
||||
"content": message.content,
|
||||
}
|
||||
for message in messages
|
||||
],
|
||||
stream=False,
|
||||
model=self.model,
|
||||
**params,
|
||||
)
|
||||
|
||||
message = AIMessage(content=response.choices[0].message.content)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
response = await self.async_client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
"role": get_role(message),
|
||||
"content": message.content,
|
||||
}
|
||||
for message in messages
|
||||
],
|
||||
stream=False,
|
||||
model=self.model,
|
||||
**params,
|
||||
)
|
||||
|
||||
message = AIMessage(content=response.choices[0].message.content)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
@@ -0,0 +1,280 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Type,
|
||||
)
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
|
||||
from langchain_community.llms.gigachat import _BaseGigaChat
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import gigachat.models as gm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_dict_to_message(message: gm.Messages) -> BaseMessage:
|
||||
from gigachat.models import FunctionCall, MessagesRole
|
||||
|
||||
additional_kwargs: Dict = {}
|
||||
if function_call := message.function_call:
|
||||
if isinstance(function_call, FunctionCall):
|
||||
additional_kwargs["function_call"] = dict(function_call)
|
||||
elif isinstance(function_call, dict):
|
||||
additional_kwargs["function_call"] = function_call
|
||||
|
||||
if message.role == MessagesRole.SYSTEM:
|
||||
return SystemMessage(content=message.content)
|
||||
elif message.role == MessagesRole.USER:
|
||||
return HumanMessage(content=message.content)
|
||||
elif message.role == MessagesRole.ASSISTANT:
|
||||
return AIMessage(content=message.content, additional_kwargs=additional_kwargs)
|
||||
else:
|
||||
raise TypeError(f"Got unknown role {message.role} {message}")
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: gm.BaseMessage) -> gm.Messages:
|
||||
from gigachat.models import Messages, MessagesRole
|
||||
|
||||
if isinstance(message, SystemMessage):
|
||||
return Messages(role=MessagesRole.SYSTEM, content=message.content)
|
||||
elif isinstance(message, HumanMessage):
|
||||
return Messages(role=MessagesRole.USER, content=message.content)
|
||||
elif isinstance(message, AIMessage):
|
||||
return Messages(
|
||||
role=MessagesRole.ASSISTANT,
|
||||
content=message.content,
|
||||
function_call=message.additional_kwargs.get("function_call", None),
|
||||
)
|
||||
elif isinstance(message, ChatMessage):
|
||||
return Messages(role=MessagesRole(message.role), content=message.content)
|
||||
elif isinstance(message, FunctionMessage):
|
||||
return Messages(role=MessagesRole.FUNCTION, content=message.content)
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = _dict.get("role")
|
||||
content = _dict.get("content") or ""
|
||||
additional_kwargs: Dict = {}
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
function_call["name"] = ""
|
||||
additional_kwargs["function_call"] = function_call
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role == "function" or default_class == FunctionMessageChunk:
|
||||
return FunctionMessageChunk(content=content, name=_dict["name"])
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
|
||||
else:
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.5",
|
||||
removal="1.0",
|
||||
alternative_import="langchain_gigachat.GigaChat",
|
||||
)
|
||||
class GigaChat(_BaseGigaChat, BaseChatModel):
|
||||
"""`GigaChat` large language models API.
|
||||
|
||||
To use, you should pass login and password to access GigaChat API or use token.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import GigaChat
|
||||
giga = GigaChat(credentials=..., scope=..., verify_ssl_certs=...)
|
||||
"""
|
||||
|
||||
def _build_payload(self, messages: List[BaseMessage], **kwargs: Any) -> gm.Chat:
|
||||
from gigachat.models import Chat
|
||||
|
||||
payload = Chat(
|
||||
messages=[_convert_message_to_dict(m) for m in messages],
|
||||
)
|
||||
|
||||
payload.functions = kwargs.get("functions", None)
|
||||
payload.model = self.model
|
||||
|
||||
if self.profanity_check is not None:
|
||||
payload.profanity_check = self.profanity_check
|
||||
if self.temperature is not None:
|
||||
payload.temperature = self.temperature
|
||||
if self.top_p is not None:
|
||||
payload.top_p = self.top_p
|
||||
if self.max_tokens is not None:
|
||||
payload.max_tokens = self.max_tokens
|
||||
if self.repetition_penalty is not None:
|
||||
payload.repetition_penalty = self.repetition_penalty
|
||||
if self.update_interval is not None:
|
||||
payload.update_interval = self.update_interval
|
||||
|
||||
if self.verbose:
|
||||
logger.warning("Giga request: %s", payload.dict())
|
||||
|
||||
return payload
|
||||
|
||||
def _create_chat_result(self, response: Any) -> ChatResult:
|
||||
generations = []
|
||||
for res in response.choices:
|
||||
message = _convert_dict_to_message(res.message)
|
||||
finish_reason = res.finish_reason
|
||||
gen = ChatGeneration(
|
||||
message=message,
|
||||
generation_info={"finish_reason": finish_reason},
|
||||
)
|
||||
generations.append(gen)
|
||||
if finish_reason != "stop":
|
||||
logger.warning(
|
||||
"Giga generation stopped with reason: %s",
|
||||
finish_reason,
|
||||
)
|
||||
if self.verbose:
|
||||
logger.warning("Giga response: %s", message.content)
|
||||
llm_output = {"token_usage": response.usage, "model_name": response.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
payload = self._build_payload(messages, **kwargs)
|
||||
response = self._client.chat(payload)
|
||||
|
||||
return self._create_chat_result(response)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
payload = self._build_payload(messages, **kwargs)
|
||||
response = await self._client.achat(payload)
|
||||
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
payload = self._build_payload(messages, **kwargs)
|
||||
|
||||
for chunk in self._client.stream(payload):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.dict()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
|
||||
choice = chunk["choices"][0]
|
||||
content = choice.get("delta", {}).get("content", {})
|
||||
chunk = _convert_delta_to_message_chunk(choice["delta"], AIMessageChunk)
|
||||
|
||||
finish_reason = choice.get("finish_reason")
|
||||
|
||||
generation_info = (
|
||||
dict(finish_reason=finish_reason) if finish_reason is not None else None
|
||||
)
|
||||
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(content)
|
||||
|
||||
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
payload = self._build_payload(messages, **kwargs)
|
||||
|
||||
async for chunk in self._client.astream(payload):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.dict()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
|
||||
choice = chunk["choices"][0]
|
||||
content = choice.get("delta", {}).get("content", {})
|
||||
chunk = _convert_delta_to_message_chunk(choice["delta"], AIMessageChunk)
|
||||
|
||||
finish_reason = choice.get("finish_reason")
|
||||
|
||||
generation_info = (
|
||||
dict(finish_reason=finish_reason) if finish_reason is not None else None
|
||||
)
|
||||
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(content)
|
||||
|
||||
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
|
||||
@@ -0,0 +1,355 @@
|
||||
"""Wrapper around Google's PaLM Chat API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from pydantic import BaseModel, SecretStr
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import google.generativeai as genai
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatGooglePalmError(Exception):
|
||||
"""Error with the `Google PaLM` API."""
|
||||
|
||||
|
||||
def _truncate_at_stop_tokens(
|
||||
text: str,
|
||||
stop: Optional[List[str]],
|
||||
) -> str:
|
||||
"""Truncates text at the earliest stop token found."""
|
||||
if stop is None:
|
||||
return text
|
||||
|
||||
for stop_token in stop:
|
||||
stop_token_idx = text.find(stop_token)
|
||||
if stop_token_idx != -1:
|
||||
text = text[:stop_token_idx]
|
||||
return text
|
||||
|
||||
|
||||
def _response_to_result(
|
||||
response: genai.types.ChatResponse,
|
||||
stop: Optional[List[str]],
|
||||
) -> ChatResult:
|
||||
"""Converts a PaLM API response into a LangChain ChatResult."""
|
||||
if not response.candidates:
|
||||
raise ChatGooglePalmError("ChatResponse must have at least one candidate.")
|
||||
|
||||
generations: List[ChatGeneration] = []
|
||||
for candidate in response.candidates:
|
||||
author = candidate.get("author")
|
||||
if author is None:
|
||||
raise ChatGooglePalmError(f"ChatResponse must have an author: {candidate}")
|
||||
|
||||
content = _truncate_at_stop_tokens(candidate.get("content", ""), stop)
|
||||
if content is None:
|
||||
raise ChatGooglePalmError(f"ChatResponse must have a content: {candidate}")
|
||||
|
||||
if author == "ai":
|
||||
generations.append(
|
||||
ChatGeneration(text=content, message=AIMessage(content=content))
|
||||
)
|
||||
elif author == "human":
|
||||
generations.append(
|
||||
ChatGeneration(
|
||||
text=content,
|
||||
message=HumanMessage(content=content),
|
||||
)
|
||||
)
|
||||
else:
|
||||
generations.append(
|
||||
ChatGeneration(
|
||||
text=content,
|
||||
message=ChatMessage(role=author, content=content),
|
||||
)
|
||||
)
|
||||
|
||||
return ChatResult(generations=generations)
|
||||
|
||||
|
||||
def _messages_to_prompt_dict(
|
||||
input_messages: List[BaseMessage],
|
||||
) -> genai.types.MessagePromptDict:
|
||||
"""Converts a list of LangChain messages into a PaLM API MessagePrompt structure."""
|
||||
import google.generativeai as genai
|
||||
|
||||
context: str = ""
|
||||
examples: List[genai.types.MessageDict] = []
|
||||
messages: List[genai.types.MessageDict] = []
|
||||
|
||||
remaining = list(enumerate(input_messages))
|
||||
|
||||
while remaining:
|
||||
index, input_message = remaining.pop(0)
|
||||
|
||||
if isinstance(input_message, SystemMessage):
|
||||
if index != 0:
|
||||
raise ChatGooglePalmError("System message must be first input message.")
|
||||
context = cast(str, input_message.content)
|
||||
elif isinstance(
|
||||
input_message, HumanMessage
|
||||
) and input_message.additional_kwargs.get("example"):
|
||||
if messages:
|
||||
raise ChatGooglePalmError(
|
||||
"Message examples must come before other messages."
|
||||
)
|
||||
_, next_input_message = remaining.pop(0)
|
||||
if isinstance(
|
||||
next_input_message, AIMessage
|
||||
) and next_input_message.additional_kwargs.get("example"):
|
||||
examples.extend(
|
||||
[
|
||||
genai.types.MessageDict(
|
||||
author="human", content=input_message.content
|
||||
),
|
||||
genai.types.MessageDict(
|
||||
author="ai", content=next_input_message.content
|
||||
),
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ChatGooglePalmError(
|
||||
"Human example message must be immediately followed by an "
|
||||
" AI example response."
|
||||
)
|
||||
elif isinstance(
|
||||
input_message, AIMessage
|
||||
) and input_message.additional_kwargs.get("example"):
|
||||
raise ChatGooglePalmError(
|
||||
"AI example message must be immediately preceded by a Human "
|
||||
"example message."
|
||||
)
|
||||
elif isinstance(input_message, AIMessage):
|
||||
messages.append(
|
||||
genai.types.MessageDict(author="ai", content=input_message.content)
|
||||
)
|
||||
elif isinstance(input_message, HumanMessage):
|
||||
messages.append(
|
||||
genai.types.MessageDict(author="human", content=input_message.content)
|
||||
)
|
||||
elif isinstance(input_message, ChatMessage):
|
||||
messages.append(
|
||||
genai.types.MessageDict(
|
||||
author=input_message.role, content=input_message.content
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ChatGooglePalmError(
|
||||
"Messages without an explicit role not supported by PaLM API."
|
||||
)
|
||||
|
||||
return genai.types.MessagePromptDict(
|
||||
context=context,
|
||||
examples=examples,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
|
||||
def _create_retry_decorator() -> Callable[[Any], Any]:
|
||||
"""Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
|
||||
import google.api_core.exceptions
|
||||
|
||||
multiplier = 2
|
||||
min_seconds = 1
|
||||
max_seconds = 60
|
||||
max_retries = 10
|
||||
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(max_retries),
|
||||
wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
|
||||
retry=(
|
||||
retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
|
||||
| retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
|
||||
| retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
|
||||
),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
def chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator()
|
||||
|
||||
@retry_decorator
|
||||
def _chat_with_retry(**kwargs: Any) -> Any:
|
||||
return llm.client.chat(**kwargs)
|
||||
|
||||
return _chat_with_retry(**kwargs)
|
||||
|
||||
|
||||
async def achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the async completion call."""
|
||||
retry_decorator = _create_retry_decorator()
|
||||
|
||||
@retry_decorator
|
||||
async def _achat_with_retry(**kwargs: Any) -> Any:
|
||||
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
|
||||
return await llm.client.chat_async(**kwargs)
|
||||
|
||||
return await _achat_with_retry(**kwargs)
|
||||
|
||||
|
||||
class ChatGooglePalm(BaseChatModel, BaseModel):
|
||||
"""`Google PaLM` Chat models API.
|
||||
|
||||
To use you must have the google.generativeai Python package installed and
|
||||
either:
|
||||
|
||||
1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
|
||||
2. Pass your API key using the google_api_key kwarg to the ChatGoogle
|
||||
constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatGooglePalm
|
||||
chat = ChatGooglePalm()
|
||||
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
model_name: str = "models/chat-bison-001"
|
||||
"""Model name to use."""
|
||||
google_api_key: Optional[SecretStr] = None
|
||||
temperature: Optional[float] = None
|
||||
"""Run inference with this temperature. Must be in the closed
|
||||
interval [0.0, 1.0]."""
|
||||
top_p: Optional[float] = None
|
||||
"""Decode using nucleus sampling: consider the smallest set of tokens whose
|
||||
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
|
||||
top_k: Optional[int] = None
|
||||
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
|
||||
Must be positive."""
|
||||
n: int = 1
|
||||
"""Number of chat completions to generate for each prompt. Note that the API may
|
||||
not return the full n completions if duplicates are generated."""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"google_api_key": "GOOGLE_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "google_palm"]
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate api key, python package exists, temperature, top_p, and top_k."""
|
||||
google_api_key = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "google_api_key", "GOOGLE_API_KEY")
|
||||
)
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
|
||||
genai.configure(api_key=google_api_key.get_secret_value())
|
||||
except ImportError:
|
||||
raise ChatGooglePalmError(
|
||||
"Could not import google.generativeai python package. "
|
||||
"Please install it with `pip install google-generativeai`"
|
||||
)
|
||||
|
||||
values["client"] = genai
|
||||
|
||||
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
|
||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
|
||||
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_k"] is not None and values["top_k"] <= 0:
|
||||
raise ValueError("top_k must be positive")
|
||||
|
||||
return values
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
prompt = _messages_to_prompt_dict(messages)
|
||||
|
||||
response: genai.types.ChatResponse = chat_with_retry(
|
||||
self,
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
top_k=self.top_k,
|
||||
candidate_count=self.n,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return _response_to_result(response, stop)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
prompt = _messages_to_prompt_dict(messages)
|
||||
|
||||
response: genai.types.ChatResponse = await achat_with_retry(
|
||||
self,
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
top_k=self.top_k,
|
||||
candidate_count=self.n,
|
||||
)
|
||||
|
||||
return _response_to_result(response, stop)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
"n": self.n,
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "google-palm-chat"
|
||||
@@ -0,0 +1,401 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.language_models.llms import create_base_retry_decorator
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from pydantic import BaseModel, Field, SecretStr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_community.adapters.openai import (
|
||||
convert_dict_to_message,
|
||||
convert_message_to_dict,
|
||||
)
|
||||
from langchain_community.chat_models.openai import _convert_delta_to_message_chunk
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gpt_router.models import ChunkedGenerationResponse, GenerationResponse
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_API_BASE_URL = "https://gpt-router-preview.writesonic.com"
|
||||
|
||||
|
||||
class GPTRouterException(Exception):
|
||||
"""Error with the `GPTRouter APIs`"""
|
||||
|
||||
|
||||
class GPTRouterModel(BaseModel):
|
||||
"""GPTRouter model."""
|
||||
|
||||
name: str
|
||||
provider_name: str
|
||||
|
||||
|
||||
def get_ordered_generation_requests(
|
||||
models_priority_list: List[GPTRouterModel], **kwargs: Any
|
||||
) -> List:
|
||||
"""
|
||||
Return the body for the model router input.
|
||||
"""
|
||||
|
||||
from gpt_router.models import GenerationParams, ModelGenerationRequest
|
||||
|
||||
return [
|
||||
ModelGenerationRequest(
|
||||
model_name=model.name,
|
||||
provider_name=model.provider_name,
|
||||
order=index + 1,
|
||||
prompt_params=GenerationParams(**kwargs),
|
||||
)
|
||||
for index, model in enumerate(models_priority_list)
|
||||
]
|
||||
|
||||
|
||||
def _create_retry_decorator(
|
||||
llm: GPTRouter,
|
||||
run_manager: Optional[
|
||||
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
||||
] = None,
|
||||
) -> Callable[[Any], Any]:
|
||||
from gpt_router import exceptions
|
||||
|
||||
errors = [
|
||||
exceptions.GPTRouterApiTimeoutError,
|
||||
exceptions.GPTRouterInternalServerError,
|
||||
exceptions.GPTRouterNotAvailableError,
|
||||
exceptions.GPTRouterTooManyRequestsError,
|
||||
]
|
||||
return create_base_retry_decorator(
|
||||
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
|
||||
)
|
||||
|
||||
|
||||
def completion_with_retry(
|
||||
llm: GPTRouter,
|
||||
models_priority_list: List[GPTRouterModel],
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse, None, None]]:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
ordered_generation_requests = get_ordered_generation_requests(
|
||||
models_priority_list, **kwargs
|
||||
)
|
||||
return llm.client.generate(
|
||||
ordered_generation_requests=ordered_generation_requests,
|
||||
is_stream=kwargs.get("stream", False),
|
||||
)
|
||||
|
||||
return _completion_with_retry(**kwargs)
|
||||
|
||||
|
||||
async def acompletion_with_retry(
|
||||
llm: GPTRouter,
|
||||
models_priority_list: List[GPTRouterModel],
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse, None]]:
|
||||
"""Use tenacity to retry the async completion call."""
|
||||
|
||||
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
async def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
ordered_generation_requests = get_ordered_generation_requests(
|
||||
models_priority_list, **kwargs
|
||||
)
|
||||
return await llm.client.agenerate(
|
||||
ordered_generation_requests=ordered_generation_requests,
|
||||
is_stream=kwargs.get("stream", False),
|
||||
)
|
||||
|
||||
return await _completion_with_retry(**kwargs)
|
||||
|
||||
|
||||
class GPTRouter(BaseChatModel):
|
||||
"""GPTRouter by Writesonic Inc.
|
||||
|
||||
For more information, see https://gpt-router.writesonic.com/docs
|
||||
"""
|
||||
|
||||
client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||
models_priority_list: List[GPTRouterModel] = Field(min_length=1)
|
||||
gpt_router_api_base: str = Field(default="")
|
||||
"""WriteSonic GPTRouter custom endpoint"""
|
||||
gpt_router_api_key: Optional[SecretStr] = None
|
||||
"""WriteSonic GPTRouter API Key"""
|
||||
temperature: float = 0.7
|
||||
"""What sampling temperature to use."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
max_retries: int = 4
|
||||
"""Maximum number of retries to make when generating."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
n: int = 1
|
||||
"""Number of chat completions to generate for each prompt."""
|
||||
max_tokens: int = 256
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Any:
|
||||
values["gpt_router_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"gpt_router_api_base",
|
||||
"GPT_ROUTER_API_BASE",
|
||||
DEFAULT_API_BASE_URL,
|
||||
)
|
||||
|
||||
values["gpt_router_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
"gpt_router_api_key",
|
||||
"GPT_ROUTER_API_KEY",
|
||||
)
|
||||
)
|
||||
return values
|
||||
|
||||
@model_validator(mode="after")
|
||||
def post_init(self) -> Self:
|
||||
try:
|
||||
from gpt_router.client import GPTRouterClient
|
||||
|
||||
except ImportError:
|
||||
raise GPTRouterException(
|
||||
"Could not import GPTRouter python package. "
|
||||
"Please install it with `pip install GPTRouter`."
|
||||
)
|
||||
|
||||
gpt_router_client = GPTRouterClient(
|
||||
self.gpt_router_api_base,
|
||||
self.gpt_router_api_key.get_secret_value()
|
||||
if self.gpt_router_api_key
|
||||
else None,
|
||||
)
|
||||
self.client = gpt_router_client
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"gpt_router_api_key": "GPT_ROUTER_API_KEY"}
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "gpt-router-chat"
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
**{"models_priority_list": self.models_priority_list},
|
||||
**self._default_params,
|
||||
}
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling GPTRouter API."""
|
||||
return {
|
||||
"max_tokens": self.max_tokens,
|
||||
"stream": self.streaming,
|
||||
"n": self.n,
|
||||
"temperature": self.temperature,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": False}
|
||||
response = completion_with_retry(
|
||||
self,
|
||||
messages=message_dicts,
|
||||
models_priority_list=self.models_priority_list,
|
||||
run_manager=run_manager,
|
||||
**params,
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": False}
|
||||
response = await acompletion_with_retry(
|
||||
self,
|
||||
messages=message_dicts,
|
||||
models_priority_list=self.models_priority_list,
|
||||
run_manager=run_manager,
|
||||
**params,
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _create_chat_generation_chunk(
|
||||
self, data: Mapping[str, Any], default_chunk_class: Type[BaseMessageChunk]
|
||||
) -> Tuple[ChatGenerationChunk, Type[BaseMessageChunk]]:
|
||||
chunk = _convert_delta_to_message_chunk(
|
||||
{"content": data.get("text", "")}, default_chunk_class
|
||||
)
|
||||
finish_reason = data.get("finish_reason")
|
||||
generation_info = (
|
||||
dict(finish_reason=finish_reason) if finish_reason is not None else None
|
||||
)
|
||||
default_chunk_class = chunk.__class__
|
||||
gen_chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
|
||||
return gen_chunk, default_chunk_class
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
generator_response = completion_with_retry(
|
||||
self,
|
||||
messages=message_dicts,
|
||||
models_priority_list=self.models_priority_list,
|
||||
run_manager=run_manager,
|
||||
**params,
|
||||
)
|
||||
for chunk in generator_response:
|
||||
if chunk.event != "update":
|
||||
continue
|
||||
|
||||
chunk, default_chunk_class = self._create_chat_generation_chunk(
|
||||
chunk.data, default_chunk_class
|
||||
)
|
||||
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
token=str(chunk.message.content), chunk=chunk
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
generator_response = acompletion_with_retry(
|
||||
self,
|
||||
messages=message_dicts,
|
||||
models_priority_list=self.models_priority_list,
|
||||
run_manager=run_manager,
|
||||
**params,
|
||||
)
|
||||
async for chunk in await generator_response:
|
||||
if chunk.event != "update":
|
||||
continue
|
||||
|
||||
chunk, default_chunk_class = self._create_chat_generation_chunk(
|
||||
chunk.data, default_chunk_class
|
||||
)
|
||||
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
token=str(chunk.message.content), chunk=chunk
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
params = self._default_params
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
def _create_chat_result(self, response: GenerationResponse) -> ChatResult:
|
||||
generations = []
|
||||
for res in response.choices:
|
||||
message = convert_dict_to_message(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": res.text,
|
||||
}
|
||||
)
|
||||
gen = ChatGeneration(
|
||||
message=message,
|
||||
generation_info=dict(finish_reason=res.finish_reason),
|
||||
)
|
||||
generations.append(gen)
|
||||
llm_output = {"token_usage": response.meta, "model": response.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
@@ -0,0 +1,235 @@
|
||||
"""Hugging Face Chat Wrapper."""
|
||||
|
||||
from typing import Any, AsyncIterator, Iterator, List, Optional
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
ChatResult,
|
||||
LLMResult,
|
||||
)
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
|
||||
from langchain_community.llms.huggingface_hub import HuggingFaceHub
|
||||
from langchain_community.llms.huggingface_text_gen_inference import (
|
||||
HuggingFaceTextGenInference,
|
||||
)
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.0.37",
|
||||
removal="1.0",
|
||||
alternative_import="langchain_huggingface.ChatHuggingFace",
|
||||
)
|
||||
class ChatHuggingFace(BaseChatModel):
|
||||
"""
|
||||
Wrapper for using Hugging Face LLM's as ChatModels.
|
||||
|
||||
Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
|
||||
and `HuggingFaceHub` LLMs.
|
||||
|
||||
Upon instantiating this class, the model_id is resolved from the url
|
||||
provided to the LLM, and the appropriate tokenizer is loaded from
|
||||
the HuggingFace Hub.
|
||||
|
||||
Adapted from: https://python.langchain.com/docs/integrations/chat/llama2_chat
|
||||
"""
|
||||
|
||||
llm: Any
|
||||
"""LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint, or
|
||||
HuggingFaceHub."""
|
||||
system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
|
||||
tokenizer: Any = None
|
||||
model_id: Optional[str] = None
|
||||
streaming: bool = False
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
self._resolve_model_id()
|
||||
|
||||
self.tokenizer = (
|
||||
AutoTokenizer.from_pretrained(self.model_id)
|
||||
if self.tokenizer is None
|
||||
else self.tokenizer
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_llm(self) -> Self:
|
||||
if not isinstance(
|
||||
self.llm,
|
||||
(HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub),
|
||||
):
|
||||
raise TypeError(
|
||||
"Expected llm to be one of HuggingFaceTextGenInference, "
|
||||
f"HuggingFaceEndpoint, HuggingFaceHub, received {type(self.llm)}"
|
||||
)
|
||||
return self
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
request = self._to_chat_prompt(messages)
|
||||
|
||||
for data in self.llm.stream(request, **kwargs):
|
||||
delta = data
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(delta, chunk=chunk)
|
||||
yield chunk
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
request = self._to_chat_prompt(messages)
|
||||
async for data in self.llm.astream(request, **kwargs):
|
||||
delta = data
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(delta, chunk=chunk)
|
||||
yield chunk
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
llm_input = self._to_chat_prompt(messages)
|
||||
llm_result = self.llm._generate(
|
||||
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return self._to_chat_result(llm_result)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
llm_input = self._to_chat_prompt(messages)
|
||||
llm_result = await self.llm._agenerate(
|
||||
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return self._to_chat_result(llm_result)
|
||||
|
||||
def _to_chat_prompt(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
) -> str:
|
||||
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
|
||||
if not messages:
|
||||
raise ValueError("At least one HumanMessage must be provided!")
|
||||
|
||||
if not isinstance(messages[-1], HumanMessage):
|
||||
raise ValueError("Last message must be a HumanMessage!")
|
||||
|
||||
messages_dicts = [self._to_chatml_format(m) for m in messages]
|
||||
|
||||
return self.tokenizer.apply_chat_template(
|
||||
messages_dicts, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
def _to_chatml_format(self, message: BaseMessage) -> dict:
|
||||
"""Convert LangChain message to ChatML format."""
|
||||
|
||||
if isinstance(message, SystemMessage):
|
||||
role = "system"
|
||||
elif isinstance(message, AIMessage):
|
||||
role = "assistant"
|
||||
elif isinstance(message, HumanMessage):
|
||||
role = "user"
|
||||
else:
|
||||
raise ValueError(f"Unknown message type: {type(message)}")
|
||||
|
||||
return {"role": role, "content": message.content}
|
||||
|
||||
@staticmethod
|
||||
def _to_chat_result(llm_result: LLMResult) -> ChatResult:
|
||||
chat_generations = []
|
||||
|
||||
for g in llm_result.generations[0]:
|
||||
chat_generation = ChatGeneration(
|
||||
message=AIMessage(content=g.text), generation_info=g.generation_info
|
||||
)
|
||||
chat_generations.append(chat_generation)
|
||||
|
||||
return ChatResult(
|
||||
generations=chat_generations, llm_output=llm_result.llm_output
|
||||
)
|
||||
|
||||
def _resolve_model_id(self) -> None:
|
||||
"""Resolve the model_id from the LLM's inference_server_url"""
|
||||
|
||||
from huggingface_hub import list_inference_endpoints
|
||||
|
||||
available_endpoints = list_inference_endpoints("*")
|
||||
if isinstance(self.llm, HuggingFaceHub) or (
|
||||
hasattr(self.llm, "repo_id") and self.llm.repo_id
|
||||
):
|
||||
self.model_id = self.llm.repo_id
|
||||
return
|
||||
elif isinstance(self.llm, HuggingFaceTextGenInference):
|
||||
endpoint_url: Optional[str] = self.llm.inference_server_url
|
||||
else:
|
||||
endpoint_url = self.llm.endpoint_url
|
||||
|
||||
for endpoint in available_endpoints:
|
||||
if endpoint.url == endpoint_url:
|
||||
self.model_id = endpoint.repository
|
||||
|
||||
if not self.model_id:
|
||||
raise ValueError(
|
||||
"Failed to resolve model_id:"
|
||||
f"Could not find model id for inference server: {endpoint_url}"
|
||||
"Make sure that your Hugging Face token has access to the endpoint."
|
||||
)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "huggingface-chat-wrapper"
|
||||
111
venv/Lib/site-packages/langchain_community/chat_models/human.py
Normal file
111
venv/Lib/site-packages/langchain_community/chat_models/human.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""ChatModel wrapper which returns user input as the response.."""
|
||||
|
||||
from io import StringIO
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional
|
||||
|
||||
import yaml
|
||||
from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
_message_from_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
|
||||
def _display_messages(messages: List[BaseMessage]) -> None:
|
||||
dict_messages = messages_to_dict(messages)
|
||||
for message in dict_messages:
|
||||
yaml_string = yaml.dump(
|
||||
message,
|
||||
default_flow_style=False,
|
||||
sort_keys=False,
|
||||
allow_unicode=True,
|
||||
width=10000,
|
||||
line_break=None,
|
||||
)
|
||||
print("\n", "======= start of message =======", "\n\n") # noqa: T201
|
||||
print(yaml_string) # noqa: T201
|
||||
print("======= end of message =======", "\n\n") # noqa: T201
|
||||
|
||||
|
||||
def _collect_yaml_input(
|
||||
messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
) -> BaseMessage:
|
||||
"""Collects and returns user input as a single string."""
|
||||
lines = []
|
||||
while True:
|
||||
line = input()
|
||||
if not line.strip():
|
||||
break
|
||||
if stop and any(seq in line for seq in stop):
|
||||
break
|
||||
lines.append(line)
|
||||
yaml_string = "\n".join(lines)
|
||||
|
||||
# Try to parse the input string as YAML
|
||||
try:
|
||||
message = _message_from_dict(yaml.safe_load(StringIO(yaml_string)))
|
||||
if message is None:
|
||||
return HumanMessage(content="")
|
||||
if stop:
|
||||
if isinstance(message.content, str):
|
||||
message.content = enforce_stop_tokens(message.content, stop)
|
||||
else:
|
||||
raise ValueError("Cannot use when output is not a string.")
|
||||
return message
|
||||
except yaml.YAMLError:
|
||||
raise ValueError("Invalid YAML string entered.")
|
||||
except ValueError:
|
||||
raise ValueError("Invalid message entered.")
|
||||
|
||||
|
||||
class HumanInputChatModel(BaseChatModel):
|
||||
"""ChatModel which returns user input as the response."""
|
||||
|
||||
input_func: Callable = Field(default_factory=lambda: _collect_yaml_input)
|
||||
message_func: Callable = Field(default_factory=lambda: _display_messages)
|
||||
separator: str = "\n"
|
||||
input_kwargs: Mapping[str, Any] = {}
|
||||
message_kwargs: Mapping[str, Any] = {}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"input_func": self.input_func.__name__,
|
||||
"message_func": self.message_func.__name__,
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Returns the type of LLM."""
|
||||
return "human-input-chat-model"
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""
|
||||
Displays the messages to the user and returns their input as a response.
|
||||
|
||||
Args:
|
||||
messages (List[BaseMessage]): The messages to be displayed to the user.
|
||||
stop (Optional[List[str]]): A list of stop strings.
|
||||
run_manager (Optional[CallbackManagerForLLMRun]): Currently not used.
|
||||
|
||||
Returns:
|
||||
ChatResult: The user's input as a response.
|
||||
"""
|
||||
self.message_func(messages, **self.message_kwargs)
|
||||
user_input = self.input_func(messages, stop=stop, **self.input_kwargs)
|
||||
return ChatResult(generations=[ChatGeneration(message=user_input)])
|
||||
@@ -0,0 +1,280 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
pre_init,
|
||||
)
|
||||
from pydantic import ConfigDict, Field, SecretStr, model_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"Role": message.role, "Content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"Role": "system", "Content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"Role": "user", "Content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"Role": "assistant", "Content": message.content}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
|
||||
return message_dict
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
role = _dict["Role"]
|
||||
if role == "system":
|
||||
return SystemMessage(content=_dict.get("Content", "") or "")
|
||||
elif role == "user":
|
||||
return HumanMessage(content=_dict["Content"])
|
||||
elif role == "assistant":
|
||||
return AIMessage(content=_dict.get("Content", "") or "")
|
||||
else:
|
||||
return ChatMessage(content=_dict["Content"], role=role)
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = _dict.get("Role")
|
||||
content = _dict.get("Content") or ""
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(content=content)
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
|
||||
else:
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
|
||||
generations = []
|
||||
for choice in response["Choices"]:
|
||||
message = _convert_dict_to_message(choice["Message"])
|
||||
message.id = response.get("Id", "")
|
||||
generations.append(ChatGeneration(message=message))
|
||||
|
||||
token_usage = response["Usage"]
|
||||
llm_output = {"token_usage": token_usage}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
|
||||
class ChatHunyuan(BaseChatModel):
|
||||
"""Tencent Hunyuan chat models API by Tencent.
|
||||
|
||||
For more information, see https://cloud.tencent.com/document/product/1729
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {
|
||||
"hunyuan_app_id": "HUNYUAN_APP_ID",
|
||||
"hunyuan_secret_id": "HUNYUAN_SECRET_ID",
|
||||
"hunyuan_secret_key": "HUNYUAN_SECRET_KEY",
|
||||
}
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
hunyuan_app_id: Optional[int] = None
|
||||
"""Hunyuan App ID"""
|
||||
hunyuan_secret_id: Optional[str] = None
|
||||
"""Hunyuan Secret ID"""
|
||||
hunyuan_secret_key: Optional[SecretStr] = None
|
||||
"""Hunyuan Secret Key"""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
request_timeout: int = 60
|
||||
"""Timeout for requests to Hunyuan API. Default is 60 seconds."""
|
||||
temperature: float = 1.0
|
||||
"""What sampling temperature to use."""
|
||||
top_p: float = 1.0
|
||||
"""What probability mass to use."""
|
||||
model: str = "hunyuan-lite"
|
||||
"""What Model to use.
|
||||
Optional model:
|
||||
- hunyuan-lite
|
||||
- hunyuan-standard
|
||||
- hunyuan-standard-256K
|
||||
- hunyuan-pro
|
||||
- hunyuan-code
|
||||
- hunyuan-role
|
||||
- hunyuan-functioncall
|
||||
- hunyuan-vision
|
||||
"""
|
||||
stream_moderation: bool = False
|
||||
"""Whether to review the results or not when streaming is true."""
|
||||
enable_enhancement: bool = True
|
||||
"""Whether to enhancement the results or not."""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for API call not explicitly specified."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
logger.warning(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
values["hunyuan_app_id"] = get_from_dict_or_env(
|
||||
values,
|
||||
"hunyuan_app_id",
|
||||
"HUNYUAN_APP_ID",
|
||||
)
|
||||
values["hunyuan_secret_id"] = get_from_dict_or_env(
|
||||
values,
|
||||
"hunyuan_secret_id",
|
||||
"HUNYUAN_SECRET_ID",
|
||||
)
|
||||
values["hunyuan_secret_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
"hunyuan_secret_key",
|
||||
"HUNYUAN_SECRET_KEY",
|
||||
)
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Hunyuan API."""
|
||||
normal_params = {
|
||||
"Temperature": self.temperature,
|
||||
"TopP": self.top_p,
|
||||
"Model": self.model,
|
||||
"Stream": self.streaming,
|
||||
"StreamModeration": self.stream_moderation,
|
||||
"EnableEnhancement": self.enable_enhancement,
|
||||
}
|
||||
return {**normal_params, **self.model_kwargs}
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
res = self._chat(messages, **kwargs)
|
||||
return _create_chat_result(json.loads(res.to_json_string()))
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
res = self._chat(messages, **kwargs)
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
for chunk in res:
|
||||
chunk = chunk.get("data", "")
|
||||
if len(chunk) == 0:
|
||||
continue
|
||||
response = json.loads(chunk)
|
||||
if "error" in response:
|
||||
raise ValueError(f"Error from Hunyuan api response: {response}")
|
||||
|
||||
for choice in response["Choices"]:
|
||||
chunk = _convert_delta_to_message_chunk(
|
||||
choice["Delta"], default_chunk_class
|
||||
)
|
||||
chunk.id = response.get("Id", "")
|
||||
default_chunk_class = chunk.__class__
|
||||
cg_chunk = ChatGenerationChunk(message=chunk)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> Any:
|
||||
if self.hunyuan_secret_key is None:
|
||||
raise ValueError("Hunyuan secret key is not set.")
|
||||
|
||||
try:
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import tencentcloud python package. "
|
||||
"Please install it with `pip install tencentcloud-sdk-python`."
|
||||
)
|
||||
|
||||
parameters = {**self._default_params, **kwargs}
|
||||
cred = credential.Credential(
|
||||
self.hunyuan_secret_id, str(self.hunyuan_secret_key.get_secret_value())
|
||||
)
|
||||
client = hunyuan_client.HunyuanClient(cred, "")
|
||||
req = models.ChatCompletionsRequest()
|
||||
params = {
|
||||
"Messages": [_convert_message_to_dict(m) for m in messages],
|
||||
**parameters,
|
||||
}
|
||||
req.from_json_string(json.dumps(params))
|
||||
resp = client.ChatCompletions(req)
|
||||
return resp
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "hunyuan-chat"
|
||||
@@ -0,0 +1,228 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Mapping, Optional, cast
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Ignoring type because below is valid pydantic code
|
||||
# Unexpected keyword argument "extra" for "__init_subclass__" of "object" [call-arg]
|
||||
class ChatParams(BaseModel, extra="allow"):
|
||||
"""Parameters for the `Javelin AI Gateway` LLM."""
|
||||
|
||||
temperature: float = 0.0
|
||||
stop: Optional[List[str]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
|
||||
|
||||
class ChatJavelinAIGateway(BaseChatModel):
|
||||
"""`Javelin AI Gateway` chat models API.
|
||||
|
||||
To use, you should have the ``javelin_sdk`` python package installed.
|
||||
For more information, see https://docs.getjavelin.io
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatJavelinAIGateway
|
||||
|
||||
chat = ChatJavelinAIGateway(
|
||||
gateway_uri="<javelin-ai-gateway-uri>",
|
||||
route="<javelin-ai-gateway-chat-route>",
|
||||
params={
|
||||
"temperature": 0.1
|
||||
}
|
||||
)
|
||||
"""
|
||||
|
||||
route: str
|
||||
"""The route to use for the Javelin AI Gateway API."""
|
||||
|
||||
gateway_uri: Optional[str] = None
|
||||
"""The URI for the Javelin AI Gateway API."""
|
||||
|
||||
params: Optional[ChatParams] = None
|
||||
"""Parameters for the Javelin AI Gateway LLM."""
|
||||
|
||||
client: Any = None
|
||||
"""javelin client."""
|
||||
|
||||
javelin_api_key: Optional[SecretStr] = Field(None, alias="api_key")
|
||||
"""The API key for the Javelin AI Gateway."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
try:
|
||||
from javelin_sdk import (
|
||||
JavelinClient,
|
||||
UnauthorizedError,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import javelin_sdk python package. "
|
||||
"Please install it with `pip install javelin_sdk`."
|
||||
)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
if self.gateway_uri:
|
||||
try:
|
||||
self.client = JavelinClient(
|
||||
base_url=self.gateway_uri,
|
||||
api_key=cast(SecretStr, self.javelin_api_key).get_secret_value(),
|
||||
)
|
||||
except UnauthorizedError as e:
|
||||
raise ValueError("Javelin: Incorrect API Key.") from e
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
params: Dict[str, Any] = {
|
||||
"gateway_uri": self.gateway_uri,
|
||||
"javelin_api_key": cast(SecretStr, self.javelin_api_key).get_secret_value(),
|
||||
"route": self.route,
|
||||
**(self.params.dict() if self.params else {}),
|
||||
}
|
||||
return params
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
message_dicts = [
|
||||
ChatJavelinAIGateway._convert_message_to_dict(message)
|
||||
for message in messages
|
||||
]
|
||||
data: Dict[str, Any] = {
|
||||
"messages": message_dicts,
|
||||
**(self.params.dict() if self.params else {}),
|
||||
}
|
||||
|
||||
resp = self.client.query_route(self.route, query_body=data)
|
||||
|
||||
return ChatJavelinAIGateway._create_chat_result(resp.dict())
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
message_dicts = [
|
||||
ChatJavelinAIGateway._convert_message_to_dict(message)
|
||||
for message in messages
|
||||
]
|
||||
data: Dict[str, Any] = {
|
||||
"messages": message_dicts,
|
||||
**(self.params.dict() if self.params else {}),
|
||||
}
|
||||
|
||||
resp = await self.client.aquery_route(self.route, query_body=data)
|
||||
|
||||
return ChatJavelinAIGateway._create_chat_result(resp.dict())
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
return self._default_params
|
||||
|
||||
def _get_invocation_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
"""Get the parameters used to invoke the model FOR THE CALLBACKS."""
|
||||
return {
|
||||
**self._default_params,
|
||||
**super()._get_invocation_params(stop=stop, **kwargs),
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "javelin-ai-gateway-chat"
|
||||
|
||||
@staticmethod
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
content = _dict["content"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=content)
|
||||
elif role == "assistant":
|
||||
return AIMessage(content=content)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=content)
|
||||
else:
|
||||
return ChatMessage(content=content, role=role)
|
||||
|
||||
@staticmethod
|
||||
def _raise_functions_not_supported() -> None:
|
||||
raise ValueError(
|
||||
"Function messages are not supported by the Javelin AI Gateway. Please"
|
||||
" create a feature request at https://docs.getjavelin.io"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
raise ValueError(
|
||||
"Function messages are not supported by the Javelin AI Gateway. Please"
|
||||
" create a feature request at https://docs.getjavelin.io"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Got unknown message type: {message}")
|
||||
|
||||
if "function_call" in message.additional_kwargs:
|
||||
ChatJavelinAIGateway._raise_functions_not_supported()
|
||||
if message.additional_kwargs:
|
||||
logger.warning(
|
||||
"Additional message arguments are unsupported by Javelin AI Gateway "
|
||||
" and will be ignored: %s",
|
||||
message.additional_kwargs,
|
||||
)
|
||||
return message_dict
|
||||
|
||||
@staticmethod
|
||||
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
|
||||
generations = []
|
||||
for candidate in response["llm_response"]["choices"]:
|
||||
message = ChatJavelinAIGateway._convert_dict_to_message(
|
||||
candidate["message"]
|
||||
)
|
||||
message_metadata = candidate.get("metadata", {})
|
||||
gen = ChatGeneration(
|
||||
message=message,
|
||||
generation_info=dict(message_metadata),
|
||||
)
|
||||
generations.append(gen)
|
||||
|
||||
response_metadata = response.get("metadata", {})
|
||||
return ChatResult(generations=generations, llm_output=response_metadata)
|
||||
@@ -0,0 +1,414 @@
|
||||
"""JinaChat wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
pre_init,
|
||||
)
|
||||
from pydantic import ConfigDict, Field, SecretStr, model_validator
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_retry_decorator(llm: JinaChat) -> Callable[[Any], Any]:
|
||||
import openai
|
||||
|
||||
min_seconds = 1
|
||||
max_seconds = 60
|
||||
# Wait 2^x * 1 second between each retry starting with
|
||||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(llm.max_retries),
|
||||
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
||||
retry=(
|
||||
retry_if_exception_type(openai.error.Timeout)
|
||||
| retry_if_exception_type(openai.error.APIError)
|
||||
| retry_if_exception_type(openai.error.APIConnectionError)
|
||||
| retry_if_exception_type(openai.error.RateLimitError)
|
||||
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
||||
),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
async def acompletion_with_retry(llm: JinaChat, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the async completion call."""
|
||||
retry_decorator = _create_retry_decorator(llm)
|
||||
|
||||
@retry_decorator
|
||||
async def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
|
||||
return await llm.client.acreate(**kwargs)
|
||||
|
||||
return await _completion_with_retry(**kwargs)
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = _dict.get("role")
|
||||
content = _dict.get("content") or ""
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(content=content)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
|
||||
else:
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
content = _dict["content"] or ""
|
||||
return AIMessage(content=content)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict["content"])
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"name": message.name,
|
||||
"content": message.content,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
message_dict["name"] = message.additional_kwargs["name"]
|
||||
return message_dict
|
||||
|
||||
|
||||
class JinaChat(BaseChatModel):
|
||||
"""`Jina AI` Chat models API.
|
||||
|
||||
To use, you should have the ``openai`` python package installed, and the
|
||||
environment variable ``JINACHAT_API_KEY`` set to your API key, which you
|
||||
can generate at https://chat.jina.ai/api.
|
||||
|
||||
Any parameters that are valid to be passed to the openai.create call can be passed
|
||||
in, even if not explicitly saved on this class.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import JinaChat
|
||||
chat = JinaChat()
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"jinachat_api_key": "JINACHAT_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this model can be serialized by Langchain."""
|
||||
return False
|
||||
|
||||
client: Any = None #: :meta private:
|
||||
temperature: float = 0.7
|
||||
"""What sampling temperature to use."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
jinachat_api_key: Optional[SecretStr] = None
|
||||
"""Base URL path for API requests,
|
||||
leave blank if not using a proxy or service emulator."""
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
|
||||
"""Timeout for requests to JinaChat completion API. Default is 600 seconds."""
|
||||
max_retries: int = 6
|
||||
"""Maximum number of retries to make when generating."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
max_tokens: Optional[int] = None
|
||||
"""Maximum number of tokens to generate."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
logger.warning(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["jinachat_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "jinachat_api_key", "JINACHAT_API_KEY")
|
||||
)
|
||||
try:
|
||||
import openai
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
try:
|
||||
values["client"] = openai.ChatCompletion
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"`openai` has no `ChatCompletion` attribute, this is likely "
|
||||
"due to an old version of the openai package. Try upgrading it "
|
||||
"with `pip install --upgrade openai`."
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling JinaChat API."""
|
||||
return {
|
||||
"request_timeout": self.request_timeout,
|
||||
"max_tokens": self.max_tokens,
|
||||
"stream": self.streaming,
|
||||
"temperature": self.temperature,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
def _create_retry_decorator(self) -> Callable[[Any], Any]:
|
||||
import openai
|
||||
|
||||
min_seconds = 1
|
||||
max_seconds = 60
|
||||
# Wait 2^x * 1 second between each retry starting with
|
||||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
||||
retry=(
|
||||
retry_if_exception_type(openai.error.Timeout)
|
||||
| retry_if_exception_type(openai.error.APIError)
|
||||
| retry_if_exception_type(openai.error.APIConnectionError)
|
||||
| retry_if_exception_type(openai.error.RateLimitError)
|
||||
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
||||
),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
def completion_with_retry(self, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = self._create_retry_decorator()
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
return self.client.create(**kwargs)
|
||||
|
||||
return _completion_with_retry(**kwargs)
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
for output in llm_outputs:
|
||||
if output is None:
|
||||
# Happens in streaming
|
||||
continue
|
||||
token_usage = output["token_usage"]
|
||||
for k, v in token_usage.items():
|
||||
if k in overall_token_usage:
|
||||
overall_token_usage[k] += v
|
||||
else:
|
||||
overall_token_usage[k] = v
|
||||
return {"token_usage": overall_token_usage}
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
for chunk in self.completion_with_retry(messages=message_dicts, **params):
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
default_chunk_class = chunk.__class__
|
||||
cg_chunk = ChatGenerationChunk(message=chunk)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
response = self.completion_with_retry(messages=message_dicts, **params)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
params = dict(self._invocation_params)
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||
generations = []
|
||||
for res in response["choices"]:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
gen = ChatGeneration(message=message)
|
||||
generations.append(gen)
|
||||
llm_output = {"token_usage": response["usage"]}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
async for chunk in await acompletion_with_retry(
|
||||
self, messages=message_dicts, **params
|
||||
):
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
default_chunk_class = chunk.__class__
|
||||
cg_chunk = ChatGenerationChunk(message=chunk)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._astream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
response = await acompletion_with_retry(self, messages=message_dicts, **params)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Mapping[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
jinachat_creds: Dict[str, Any] = {
|
||||
"api_key": self.jinachat_api_key
|
||||
and self.jinachat_api_key.get_secret_value(),
|
||||
"api_base": "https://api.chat.jina.ai/v1",
|
||||
"model": "jinachat",
|
||||
}
|
||||
return {**jinachat_creds, **self._default_params}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "jinachat"
|
||||
@@ -0,0 +1,603 @@
|
||||
##
|
||||
# Copyright (c) 2024, Chad Juliano, Kinetica DB Inc.
|
||||
##
|
||||
"""Kinetica SQL generation LLM API."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from importlib.metadata import version
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, cast
|
||||
|
||||
from langchain_core.utils import pre_init
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import gpudb
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.output_parsers.transform import BaseOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult, Generation
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
# Kinetica pydantic API datatypes
|
||||
|
||||
|
||||
class _KdtSuggestContext(BaseModel):
|
||||
"""pydantic API request type"""
|
||||
|
||||
table: Optional[str] = Field(default=None, title="Name of table")
|
||||
description: Optional[str] = Field(default=None, title="Table description")
|
||||
columns: List[str] = Field(default=[], title="Table columns list")
|
||||
rules: Optional[List[str]] = Field(
|
||||
default=None, title="Rules that apply to the table."
|
||||
)
|
||||
samples: Optional[Dict] = Field(
|
||||
default=None, title="Samples that apply to the entire context."
|
||||
)
|
||||
|
||||
def to_system_str(self) -> str:
|
||||
lines = []
|
||||
lines.append(f"CREATE TABLE {self.table} AS")
|
||||
lines.append("(")
|
||||
|
||||
if not self.columns or len(self.columns) == 0:
|
||||
ValueError("columns list can't be null.")
|
||||
|
||||
columns = []
|
||||
for column in self.columns:
|
||||
column = column.replace('"', "").strip()
|
||||
columns.append(f" {column}")
|
||||
lines.append(",\n".join(columns))
|
||||
lines.append(");")
|
||||
|
||||
if self.description:
|
||||
lines.append(f"COMMENT ON TABLE {self.table} IS '{self.description}';")
|
||||
|
||||
if self.rules and len(self.rules) > 0:
|
||||
lines.append(
|
||||
f"-- When querying table {self.table} the following rules apply:"
|
||||
)
|
||||
for rule in self.rules:
|
||||
lines.append(f"-- * {rule}")
|
||||
|
||||
result = "\n".join(lines)
|
||||
return result
|
||||
|
||||
|
||||
class _KdtSuggestPayload(BaseModel):
|
||||
"""pydantic API request type"""
|
||||
|
||||
question: Optional[str] = None
|
||||
context: List[_KdtSuggestContext]
|
||||
|
||||
def get_system_str(self) -> str:
|
||||
lines = []
|
||||
for table_context in self.context:
|
||||
if table_context.table is None:
|
||||
continue
|
||||
context_str = table_context.to_system_str()
|
||||
lines.append(context_str)
|
||||
return "\n\n".join(lines)
|
||||
|
||||
def get_messages(self) -> List[Dict]:
|
||||
messages = []
|
||||
for context in self.context:
|
||||
if context.samples is None:
|
||||
continue
|
||||
for question, answer in context.samples.items():
|
||||
# unescape double quotes
|
||||
answer = answer.replace("''", "'")
|
||||
|
||||
messages.append(dict(role="user", content=question or ""))
|
||||
messages.append(dict(role="assistant", content=answer))
|
||||
return messages
|
||||
|
||||
def to_completion(self) -> Dict:
|
||||
messages = []
|
||||
messages.append(dict(role="system", content=self.get_system_str()))
|
||||
messages.extend(self.get_messages())
|
||||
messages.append(dict(role="user", content=self.question or ""))
|
||||
response = dict(messages=messages)
|
||||
return response
|
||||
|
||||
|
||||
class _KdtoSuggestRequest(BaseModel):
|
||||
"""pydantic API request type"""
|
||||
|
||||
payload: _KdtSuggestPayload
|
||||
|
||||
|
||||
class _KdtMessage(BaseModel):
|
||||
"""pydantic API response type"""
|
||||
|
||||
role: str = Field(default="", title="One of [user|assistant|system]")
|
||||
content: str
|
||||
|
||||
|
||||
class _KdtChoice(BaseModel):
|
||||
"""pydantic API response type"""
|
||||
|
||||
index: int
|
||||
message: Optional[_KdtMessage] = Field(default=None, title="The generated SQL")
|
||||
finish_reason: str
|
||||
|
||||
|
||||
class _KdtUsage(BaseModel):
|
||||
"""pydantic API response type"""
|
||||
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class _KdtSqlResponse(BaseModel):
|
||||
"""pydantic API response type"""
|
||||
|
||||
id: str
|
||||
object: str
|
||||
created: int
|
||||
model: str
|
||||
choices: List[_KdtChoice]
|
||||
usage: _KdtUsage
|
||||
prompt: str = Field(default="", title="The input question")
|
||||
|
||||
|
||||
class _KdtCompletionResponse(BaseModel):
|
||||
"""pydantic API response type"""
|
||||
|
||||
status: str
|
||||
data: _KdtSqlResponse
|
||||
|
||||
|
||||
class _KineticaLlmFileContextParser:
|
||||
"""Parser for Kinetica LLM context datafiles."""
|
||||
|
||||
# parse line into a dict containing role and content
|
||||
PARSER: Pattern = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL)
|
||||
|
||||
@classmethod
|
||||
def _removesuffix(cls, text: str, suffix: str) -> str:
|
||||
if suffix and text.endswith(suffix):
|
||||
return text[: -len(suffix)]
|
||||
return text
|
||||
|
||||
@classmethod
|
||||
def parse_dialogue_file(cls, input_file: os.PathLike) -> Dict:
|
||||
path = Path(input_file)
|
||||
# schema = path.name.removesuffix(".txt") python 3.9
|
||||
schema = cls._removesuffix(path.name, ".txt")
|
||||
|
||||
lines = open(input_file).read()
|
||||
return cls.parse_dialogue(lines, schema)
|
||||
|
||||
@classmethod
|
||||
def parse_dialogue(cls, text: str, schema: str) -> Dict:
|
||||
messages = []
|
||||
system = None
|
||||
|
||||
lines = text.split("<|end|>")
|
||||
user_message = None
|
||||
|
||||
for idx, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
|
||||
if len(line) == 0:
|
||||
continue
|
||||
|
||||
match = cls.PARSER.match(line)
|
||||
if match is None:
|
||||
raise ValueError(f"Could not find starting token in: {line}")
|
||||
|
||||
groupdict = match.groupdict()
|
||||
role = groupdict["role"]
|
||||
|
||||
if role == "system":
|
||||
if system is not None:
|
||||
raise ValueError(f"Only one system token allowed in: {line}")
|
||||
system = groupdict["content"]
|
||||
elif role == "user":
|
||||
if user_message is not None:
|
||||
raise ValueError(
|
||||
f"Found user token without assistant token: {line}"
|
||||
)
|
||||
user_message = groupdict
|
||||
elif role == "assistant":
|
||||
if user_message is None:
|
||||
raise Exception(f"Found assistant token without user token: {line}")
|
||||
messages.append(user_message)
|
||||
messages.append(groupdict)
|
||||
user_message = None
|
||||
else:
|
||||
raise ValueError(f"Unknown token: {role}")
|
||||
|
||||
return {"schema": schema, "system": system, "messages": messages}
|
||||
|
||||
|
||||
class KineticaUtil:
|
||||
"""Kinetica utility functions."""
|
||||
|
||||
@classmethod
|
||||
def create_kdbc(
|
||||
cls,
|
||||
url: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
passwd: Optional[str] = None,
|
||||
) -> "gpudb.GPUdb":
|
||||
"""Create a connectica connection object and verify connectivity.
|
||||
|
||||
If None is passed for one or more of the parameters then an attempt will be made
|
||||
to retrieve the value from the related environment variable.
|
||||
|
||||
Args:
|
||||
url: The Kinetica URL or ``KINETICA_URL`` if None.
|
||||
user: The Kinetica user or ``KINETICA_USER`` if None.
|
||||
passwd: The Kinetica password or ``KINETICA_PASSWD`` if None.
|
||||
|
||||
Returns:
|
||||
The Kinetica connection object.
|
||||
"""
|
||||
|
||||
try:
|
||||
import gpudb
|
||||
except ModuleNotFoundError:
|
||||
raise ImportError(
|
||||
"Could not import Kinetica python package. "
|
||||
"Please install it with `pip install gpudb`."
|
||||
)
|
||||
|
||||
url = cls._get_env("KINETICA_URL", url)
|
||||
user = cls._get_env("KINETICA_USER", user)
|
||||
passwd = cls._get_env("KINETICA_PASSWD", passwd)
|
||||
|
||||
options = gpudb.GPUdb.Options()
|
||||
options.username = user
|
||||
options.password = passwd
|
||||
options.skip_ssl_cert_verification = True
|
||||
options.disable_failover = True
|
||||
options.logging_level = "INFO"
|
||||
kdbc = gpudb.GPUdb(host=url, options=options)
|
||||
|
||||
LOG.info(
|
||||
"Connected to Kinetica: {}. (api={}, server={})".format(
|
||||
kdbc.get_url(), version("gpudb"), kdbc.server_version
|
||||
)
|
||||
)
|
||||
|
||||
return kdbc
|
||||
|
||||
@classmethod
|
||||
def _get_env(cls, name: str, default: Optional[str]) -> str:
|
||||
"""Get an environment variable or use a default."""
|
||||
if default is not None:
|
||||
return default
|
||||
|
||||
result = os.getenv(name)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
raise ValueError(
|
||||
f"Parameter was not passed and not found in the environment: {name}"
|
||||
)
|
||||
|
||||
|
||||
class ChatKinetica(BaseChatModel):
|
||||
"""Kinetica LLM Chat Model API.
|
||||
|
||||
Prerequisites for using this API:
|
||||
|
||||
* The ``gpudb`` and ``typeguard`` packages installed.
|
||||
* A Kinetica DB instance.
|
||||
* Kinetica host specified in ``KINETICA_URL``
|
||||
* Kinetica login specified ``KINETICA_USER``, and ``KINETICA_PASSWD``.
|
||||
* An LLM context that specifies the tables and samples to use for inferencing.
|
||||
|
||||
This API is intended to interact with the Kinetica SqlAssist LLM that supports
|
||||
generation of SQL from natural language.
|
||||
|
||||
In the Kinetica LLM workflow you create an LLM context in the database that provides
|
||||
information needed for infefencing that includes tables, annotations, rules, and
|
||||
samples. Invoking ``load_messages_from_context()`` will retrieve the contxt
|
||||
information from the database so that it can be used to create a chat prompt.
|
||||
|
||||
The chat prompt consists of a ``SystemMessage`` and pairs of
|
||||
``HumanMessage``/``AIMessage`` that contain the samples which are question/SQL
|
||||
pairs. You can append pairs samples to this list but it is not intended to
|
||||
facilitate a typical natural language conversation.
|
||||
|
||||
When you create a chain from the chat prompt and execute it, the Kinetica LLM will
|
||||
generate SQL from the input. Optionally you can use ``KineticaSqlOutputParser`` to
|
||||
execute the SQL and return the result as a dataframe.
|
||||
|
||||
The following example creates an LLM using the environment variables for the
|
||||
Kinetica connection. This will fail if the API is unable to connect to the database.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models.kinetica import KineticaChatLLM
|
||||
kinetica_llm = KineticaChatLLM()
|
||||
|
||||
If you prefer to pass connection information directly then you can create a
|
||||
connection using ``KineticaUtil.create_kdbc()``.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models.kinetica import (
|
||||
KineticaChatLLM, KineticaUtil)
|
||||
kdbc = KineticaUtil._create_kdbc(url=url, user=user, passwd=passwd)
|
||||
kinetica_llm = KineticaChatLLM(kdbc=kdbc)
|
||||
"""
|
||||
|
||||
kdbc: Any = Field(exclude=True)
|
||||
""" Kinetica DB connection. """
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Pydantic object validator."""
|
||||
|
||||
kdbc = values.get("kdbc", None)
|
||||
if kdbc is None:
|
||||
kdbc = KineticaUtil.create_kdbc()
|
||||
values["kdbc"] = kdbc
|
||||
return values
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "kinetica-sqlassist"
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
return dict(
|
||||
kinetica_version=str(self.kdbc.server_version), api_version=version("gpudb")
|
||||
)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if stop is not None:
|
||||
raise ValueError("stop kwargs are not permitted.")
|
||||
|
||||
dict_messages = [self._convert_message_to_dict(m) for m in messages]
|
||||
sql_response = self._submit_completion(dict_messages)
|
||||
|
||||
response_message = cast(_KdtMessage, sql_response.choices[0].message)
|
||||
generated_dict = response_message.model_dump()
|
||||
|
||||
generated_message = self._convert_message_from_dict(generated_dict)
|
||||
|
||||
llm_output = dict(
|
||||
input_tokens=sql_response.usage.prompt_tokens,
|
||||
output_tokens=sql_response.usage.completion_tokens,
|
||||
model_name=sql_response.model,
|
||||
)
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=generated_message)],
|
||||
llm_output=llm_output,
|
||||
)
|
||||
|
||||
def load_messages_from_context(self, context_name: str) -> List:
|
||||
"""Load a lanchain prompt from a Kinetica context.
|
||||
|
||||
A Kinetica Context is an object created with the Kinetica Workbench UI or with
|
||||
SQL syntax. This function will convert the data in the context to a list of
|
||||
messages that can be used as a prompt. The messages will contain a
|
||||
``SystemMessage`` followed by pairs of ``HumanMessage``/``AIMessage`` that
|
||||
contain the samples.
|
||||
|
||||
Args:
|
||||
context_name: The name of an LLM context in the database.
|
||||
|
||||
Returns:
|
||||
A list of messages containing the information from the context.
|
||||
"""
|
||||
|
||||
# query kinetica for the prompt
|
||||
sql = f"GENERATE PROMPT WITH OPTIONS (CONTEXT_NAMES = '{context_name}')"
|
||||
|
||||
result = self._execute_sql(sql)
|
||||
prompt = result["Prompt"]
|
||||
prompt_json = json.loads(prompt)
|
||||
|
||||
# convert the prompt to messages
|
||||
# request = SuggestRequest.model_validate(prompt_json) # pydantic v2
|
||||
|
||||
request = _KdtoSuggestRequest.model_validate(prompt_json)
|
||||
payload = request.payload
|
||||
|
||||
dict_messages = []
|
||||
dict_messages.append(dict(role="system", content=payload.get_system_str()))
|
||||
|
||||
dict_messages.extend(payload.get_messages())
|
||||
messages = [self._convert_message_from_dict(m) for m in dict_messages]
|
||||
return messages
|
||||
|
||||
def _submit_completion(self, messages: List[Dict]) -> _KdtSqlResponse:
|
||||
"""Submit a /chat/completions request to Kinetica."""
|
||||
|
||||
request = dict(messages=messages)
|
||||
request_json = json.dumps(request)
|
||||
response_raw = self.kdbc._GPUdb__submit_request_json(
|
||||
"/chat/completions", request_json
|
||||
)
|
||||
response_json = json.loads(response_raw)
|
||||
|
||||
status = response_json["status"]
|
||||
if status != "OK":
|
||||
message = response_json["message"]
|
||||
match_resp = re.compile(r"response:({.*})")
|
||||
result = match_resp.search(message)
|
||||
if result is not None:
|
||||
response = result.group(1)
|
||||
response_json = json.loads(response)
|
||||
message = response_json["message"]
|
||||
raise ValueError(message)
|
||||
|
||||
data = response_json["data"]
|
||||
# response = CompletionResponse.model_validate(data) # pydantic v2
|
||||
response = _KdtCompletionResponse.model_validate(data)
|
||||
if response.status != "OK":
|
||||
raise ValueError("SQL Generation failed")
|
||||
return response.data
|
||||
|
||||
def _execute_sql(self, sql: str) -> Dict:
|
||||
"""Execute an SQL query and return the result."""
|
||||
|
||||
response = self.kdbc.execute_sql_and_decode(
|
||||
sql, limit=1, get_column_major=False
|
||||
)
|
||||
|
||||
status_info = response["status_info"]
|
||||
if status_info["status"] != "OK":
|
||||
message = status_info["message"]
|
||||
raise ValueError(message)
|
||||
|
||||
records = response["records"]
|
||||
if len(records) != 1:
|
||||
raise ValueError("No records returned.")
|
||||
|
||||
record = records[0]
|
||||
response_dict = {}
|
||||
for col, val in record.items():
|
||||
response_dict[col] = val
|
||||
return response_dict
|
||||
|
||||
@classmethod
|
||||
def load_messages_from_datafile(cls, sa_datafile: Path) -> List[BaseMessage]:
|
||||
"""Load a lanchain prompt from a Kinetica context datafile."""
|
||||
datafile_dict = _KineticaLlmFileContextParser.parse_dialogue_file(sa_datafile)
|
||||
messages = cls._convert_dict_to_messages(datafile_dict)
|
||||
return messages
|
||||
|
||||
@classmethod
|
||||
def _convert_message_to_dict(cls, message: BaseMessage) -> Dict:
|
||||
"""Convert a single message to a BaseMessage."""
|
||||
|
||||
content = cast(str, message.content)
|
||||
if isinstance(message, HumanMessage):
|
||||
role = "user"
|
||||
elif isinstance(message, AIMessage):
|
||||
role = "assistant"
|
||||
elif isinstance(message, SystemMessage):
|
||||
role = "system"
|
||||
else:
|
||||
raise ValueError(f"Got unsupported message type: {message}")
|
||||
|
||||
result_message = dict(role=role, content=content)
|
||||
return result_message
|
||||
|
||||
@classmethod
|
||||
def _convert_message_from_dict(cls, message: Dict) -> BaseMessage:
|
||||
"""Convert a single message from a BaseMessage."""
|
||||
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=content)
|
||||
elif role == "assistant":
|
||||
return AIMessage(content=content)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=content)
|
||||
else:
|
||||
raise ValueError(f"Got unsupported role: {role}")
|
||||
|
||||
@classmethod
|
||||
def _convert_dict_to_messages(cls, sa_data: Dict) -> List[BaseMessage]:
|
||||
"""Convert a dict to a list of BaseMessages."""
|
||||
|
||||
schema = sa_data["schema"]
|
||||
system = sa_data["system"]
|
||||
messages = sa_data["messages"]
|
||||
LOG.info(f"Importing prompt for schema: {schema}")
|
||||
|
||||
result_list: List[BaseMessage] = []
|
||||
result_list.append(SystemMessage(content=system))
|
||||
result_list.extend([cls._convert_message_from_dict(m) for m in messages])
|
||||
return result_list
|
||||
|
||||
|
||||
class KineticaSqlResponse(BaseModel):
|
||||
"""Response containing SQL and the fetched data.
|
||||
|
||||
This object is returned by a chain with ``KineticaSqlOutputParser`` and it contains
|
||||
the generated SQL and related Pandas Dataframe fetched from the database.
|
||||
"""
|
||||
|
||||
sql: str = Field(default="")
|
||||
"""The generated SQL."""
|
||||
|
||||
# dataframe: "pd.DataFrame" = Field(default=None)
|
||||
dataframe: Any = Field(default=None)
|
||||
"""The Pandas dataframe containing the fetched data."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
|
||||
class KineticaSqlOutputParser(BaseOutputParser[KineticaSqlResponse]):
|
||||
"""Fetch and return data from the Kinetica LLM.
|
||||
|
||||
This object is used as the last element of a chain to execute generated SQL and it
|
||||
will output a ``KineticaSqlResponse`` containing the SQL and a pandas dataframe with
|
||||
the fetched data.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models.kinetica import (
|
||||
KineticaChatLLM, KineticaSqlOutputParser)
|
||||
kinetica_llm = KineticaChatLLM()
|
||||
|
||||
# create chain
|
||||
ctx_messages = kinetica_llm.load_messages_from_context(self.context_name)
|
||||
ctx_messages.append(("human", "{input}"))
|
||||
prompt_template = ChatPromptTemplate.from_messages(ctx_messages)
|
||||
chain = (
|
||||
prompt_template
|
||||
| kinetica_llm
|
||||
| KineticaSqlOutputParser(kdbc=kinetica_llm.kdbc)
|
||||
)
|
||||
sql_response: KineticaSqlResponse = chain.invoke(
|
||||
{"input": "What are the female users ordered by username?"}
|
||||
)
|
||||
|
||||
assert isinstance(sql_response, KineticaSqlResponse)
|
||||
LOG.info(f"SQL Response: {sql_response.sql}")
|
||||
assert isinstance(sql_response.dataframe, pd.DataFrame)
|
||||
"""
|
||||
|
||||
kdbc: Any = Field(exclude=True)
|
||||
""" Kinetica DB connection. """
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> KineticaSqlResponse:
|
||||
df = self.kdbc.to_df(text)
|
||||
return KineticaSqlResponse(sql=text, dataframe=df)
|
||||
|
||||
def parse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> KineticaSqlResponse:
|
||||
return self.parse(result[0].text)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "kinetica_sql_output_parser"
|
||||
285
venv/Lib/site-packages/langchain_community/chat_models/konko.py
Normal file
285
venv/Lib/site-packages/langchain_community/chat_models/konko.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""KonkoAI chat wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from langchain_community.adapters.openai import (
|
||||
convert_message_to_dict,
|
||||
)
|
||||
from langchain_community.chat_models.openai import (
|
||||
ChatOpenAI,
|
||||
_convert_delta_to_message_chunk,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_community.utils.openai import is_openai_v1
|
||||
|
||||
DEFAULT_API_BASE = "https://api.konko.ai/v1"
|
||||
DEFAULT_MODEL = "meta-llama/Llama-2-13b-chat-hf"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatKonko(ChatOpenAI):
|
||||
"""`ChatKonko` Chat large language models API.
|
||||
|
||||
To use, you should have the ``konko`` python package installed, and the
|
||||
environment variable ``KONKO_API_KEY`` and ``OPENAI_API_KEY`` set with your API key.
|
||||
|
||||
Any parameters that are valid to be passed to the konko.create call can be passed
|
||||
in, even if not explicitly saved on this class.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatKonko
|
||||
llm = ChatKonko(model="meta-llama/Llama-2-13b-chat-hf")
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"konko_api_key": "KONKO_API_KEY", "openai_api_key": "OPENAI_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this model can be serialized by Langchain."""
|
||||
return False
|
||||
|
||||
client: Any = None #: :meta private:
|
||||
model: str = Field(default=DEFAULT_MODEL, alias="model")
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.7
|
||||
"""What sampling temperature to use."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
openai_api_key: Optional[str] = None
|
||||
konko_api_key: Optional[str] = None
|
||||
max_retries: int = 6
|
||||
"""Maximum number of retries to make when generating."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
n: int = 1
|
||||
"""Number of chat completions to generate for each prompt."""
|
||||
max_tokens: int = 20
|
||||
"""Maximum number of tokens to generate."""
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["konko_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "konko_api_key", "KONKO_API_KEY")
|
||||
)
|
||||
try:
|
||||
import konko
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import konko python package. "
|
||||
"Please install it with `pip install konko`."
|
||||
)
|
||||
try:
|
||||
if is_openai_v1():
|
||||
values["client"] = konko.chat.completions
|
||||
else:
|
||||
values["client"] = konko.ChatCompletion
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"`konko` has no `ChatCompletion` attribute, this is likely "
|
||||
"due to an old version of the konko package. Try upgrading it "
|
||||
"with `pip install --upgrade konko`."
|
||||
)
|
||||
|
||||
if not hasattr(konko, "_is_legacy_openai"):
|
||||
warnings.warn(
|
||||
"You are using an older version of the 'konko' package. "
|
||||
"Please consider upgrading to access new features."
|
||||
)
|
||||
|
||||
if values["n"] < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
if values["n"] > 1 and values["streaming"]:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Konko API."""
|
||||
return {
|
||||
"model": self.model,
|
||||
"max_tokens": self.max_tokens,
|
||||
"stream": self.streaming,
|
||||
"n": self.n,
|
||||
"temperature": self.temperature,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_available_models(
|
||||
konko_api_key: Union[str, SecretStr, None] = None,
|
||||
openai_api_key: Union[str, SecretStr, None] = None,
|
||||
konko_api_base: str = DEFAULT_API_BASE,
|
||||
) -> Set[str]:
|
||||
"""Get available models from Konko API."""
|
||||
|
||||
# Try to retrieve the OpenAI API key if it's not passed as an argument
|
||||
if not openai_api_key:
|
||||
try:
|
||||
openai_api_key = convert_to_secret_str(os.environ["OPENAI_API_KEY"])
|
||||
except KeyError:
|
||||
pass # It's okay if it's not set, we just won't use it
|
||||
elif isinstance(openai_api_key, str):
|
||||
openai_api_key = convert_to_secret_str(openai_api_key)
|
||||
|
||||
# Try to retrieve the Konko API key if it's not passed as an argument
|
||||
if not konko_api_key:
|
||||
try:
|
||||
konko_api_key = convert_to_secret_str(os.environ["KONKO_API_KEY"])
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"Konko API key must be passed as keyword argument or "
|
||||
"set in environment variable KONKO_API_KEY."
|
||||
)
|
||||
elif isinstance(konko_api_key, str):
|
||||
konko_api_key = convert_to_secret_str(konko_api_key)
|
||||
|
||||
models_url = f"{konko_api_base}/models"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {konko_api_key.get_secret_value()}",
|
||||
}
|
||||
|
||||
if openai_api_key:
|
||||
headers["X-OpenAI-Api-Key"] = cast(
|
||||
SecretStr, openai_api_key
|
||||
).get_secret_value()
|
||||
|
||||
models_response = requests.get(models_url, headers=headers)
|
||||
|
||||
if models_response.status_code != 200:
|
||||
raise ValueError(
|
||||
f"Error getting models from {models_url}: {models_response.status_code}"
|
||||
)
|
||||
|
||||
return {model["id"] for model in models_response.json()["data"]}
|
||||
|
||||
def completion_with_retry(
|
||||
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
return self.client.create(**kwargs)
|
||||
|
||||
return _completion_with_retry(**kwargs)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
for chunk in self.completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
choice = chunk["choices"][0]
|
||||
chunk = _convert_delta_to_message_chunk(
|
||||
choice["delta"], default_chunk_class
|
||||
)
|
||||
finish_reason = choice.get("finish_reason")
|
||||
generation_info = (
|
||||
dict(finish_reason=finish_reason) if finish_reason is not None else None
|
||||
)
|
||||
default_chunk_class = chunk.__class__
|
||||
cg_chunk = ChatGenerationChunk(
|
||||
message=chunk, generation_info=generation_info
|
||||
)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
response = self.completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
params = self._client_params
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**{"model_name": self.model}, **self._default_params}
|
||||
|
||||
@property
|
||||
def _client_params(self) -> Dict[str, Any]:
|
||||
"""Get the parameters used for the konko client."""
|
||||
return {**self._default_params}
|
||||
|
||||
def _get_invocation_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
return {
|
||||
"model": self.model,
|
||||
**super()._get_invocation_params(stop=stop),
|
||||
**self._default_params,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "konko-chat"
|
||||
@@ -0,0 +1,632 @@
|
||||
"""
|
||||
Deprecated LiteLLM wrapper.
|
||||
|
||||
⭐ Use `pip install langchain-litellm` and import
|
||||
`from langchain_litellm import ChatLiteLLM` instead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.language_models.llms import create_base_retry_decorator
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolCall,
|
||||
ToolCallChunk,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.ai import UsageMetadata
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatLiteLLMException(Exception):
|
||||
"""Error with the `LiteLLM I/O` library"""
|
||||
|
||||
|
||||
def _create_retry_decorator(
|
||||
llm: ChatLiteLLM,
|
||||
run_manager: Optional[
|
||||
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
||||
] = None,
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
|
||||
import litellm
|
||||
|
||||
errors = [
|
||||
litellm.Timeout,
|
||||
litellm.APIError,
|
||||
litellm.APIConnectionError,
|
||||
litellm.RateLimitError,
|
||||
]
|
||||
return create_base_retry_decorator(
|
||||
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
|
||||
)
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
# Fix for azure
|
||||
# Also OpenAI returns None for tool invocations
|
||||
content = _dict.get("content", "") or ""
|
||||
|
||||
additional_kwargs = {}
|
||||
if _dict.get("function_call"):
|
||||
additional_kwargs["function_call"] = dict(_dict["function_call"])
|
||||
|
||||
if _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = _dict["tool_calls"]
|
||||
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict["content"])
|
||||
elif role == "function":
|
||||
return FunctionMessage(content=_dict["content"], name=_dict["name"])
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
|
||||
async def acompletion_with_retry(
|
||||
llm: ChatLiteLLM,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the async completion call."""
|
||||
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
async def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
|
||||
return await llm.client.acreate(**kwargs)
|
||||
|
||||
return await _completion_with_retry(**kwargs)
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = _dict.get("role")
|
||||
content = _dict.get("content") or ""
|
||||
if _dict.get("function_call"):
|
||||
additional_kwargs = {"function_call": dict(_dict["function_call"])}
|
||||
elif _dict.get("reasoning_content"):
|
||||
additional_kwargs = {"reasoning_content": _dict["reasoning_content"]}
|
||||
else:
|
||||
additional_kwargs = {}
|
||||
|
||||
tool_call_chunks = []
|
||||
if raw_tool_calls := _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
try:
|
||||
tool_call_chunks = [
|
||||
ToolCallChunk(
|
||||
name=rtc["function"].get("name"),
|
||||
args=rtc["function"].get("arguments"),
|
||||
id=rtc.get("id"),
|
||||
index=rtc["index"],
|
||||
)
|
||||
for rtc in raw_tool_calls
|
||||
]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role == "function" or default_class == FunctionMessageChunk:
|
||||
return FunctionMessageChunk(content=content, name=_dict["name"])
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
|
||||
else:
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"id": tool_call["id"],
|
||||
"function": {
|
||||
"name": tool_call["name"],
|
||||
"arguments": json.dumps(tool_call["args"]),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict: Dict[str, Any] = {"content": message.content}
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict["role"] = message.role
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict["role"] = "user"
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict["role"] = "assistant"
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
if message.tool_calls:
|
||||
message_dict["tool_calls"] = [
|
||||
_lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
|
||||
]
|
||||
elif "tool_calls" in message.additional_kwargs:
|
||||
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict["role"] = "system"
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict["role"] = "function"
|
||||
message_dict["name"] = message.name
|
||||
elif isinstance(message, ToolMessage):
|
||||
message_dict["role"] = "tool"
|
||||
message_dict["tool_call_id"] = message.tool_call_id
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
message_dict["name"] = message.additional_kwargs["name"]
|
||||
return message_dict
|
||||
|
||||
|
||||
_OPENAI_MODELS = [
|
||||
"o1-mini",
|
||||
"o1-preview",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4o-mini-2024-07-18",
|
||||
"gpt-4o",
|
||||
"gpt-4o-2024-08-06",
|
||||
"gpt-4o-2024-05-13",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4-0125-preview",
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-3.5-turbo-1106",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-0301",
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-32k-0613",
|
||||
]
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.24",
|
||||
removal="1.0",
|
||||
alternative_import="langchain_litellm.ChatLiteLLM",
|
||||
)
|
||||
class ChatLiteLLM(BaseChatModel):
|
||||
"""DEPRECATED – use `langchain_litellm.ChatLiteLLM` instead."""
|
||||
|
||||
client: Any = None #: :meta private:
|
||||
model: str = "gpt-3.5-turbo"
|
||||
model_name: Optional[str] = None
|
||||
"""Model name to use."""
|
||||
openai_api_key: Optional[str] = None
|
||||
azure_api_key: Optional[str] = None
|
||||
anthropic_api_key: Optional[str] = None
|
||||
replicate_api_key: Optional[str] = None
|
||||
cohere_api_key: Optional[str] = None
|
||||
openrouter_api_key: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
streaming: bool = False
|
||||
api_base: Optional[str] = None
|
||||
organization: Optional[str] = None
|
||||
custom_llm_provider: Optional[str] = None
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
|
||||
temperature: Optional[float] = None
|
||||
"""Run inference with this temperature. Must be in the closed
|
||||
interval [0.0, 1.0]."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for API call not explicitly specified."""
|
||||
top_p: Optional[float] = None
|
||||
"""Decode using nucleus sampling: consider the smallest set of tokens whose
|
||||
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
|
||||
top_k: Optional[int] = None
|
||||
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
|
||||
Must be positive."""
|
||||
n: Optional[int] = None
|
||||
"""Number of chat completions to generate for each prompt. Note that the API may
|
||||
not return the full n completions if duplicates are generated."""
|
||||
max_tokens: Optional[int] = None
|
||||
|
||||
max_retries: int = 1
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
set_model_value = self.model
|
||||
if self.model_name is not None:
|
||||
set_model_value = self.model_name
|
||||
return {
|
||||
"model": set_model_value,
|
||||
"force_timeout": self.request_timeout,
|
||||
"max_tokens": self.max_tokens,
|
||||
"stream": self.streaming,
|
||||
"n": self.n,
|
||||
"temperature": self.temperature,
|
||||
"custom_llm_provider": self.custom_llm_provider,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
@property
|
||||
def _client_params(self) -> Dict[str, Any]:
|
||||
"""Get the parameters used for the openai client."""
|
||||
set_model_value = self.model
|
||||
if self.model_name is not None:
|
||||
set_model_value = self.model_name
|
||||
self.client.api_base = self.api_base
|
||||
self.client.api_key = self.api_key
|
||||
for named_api_key in [
|
||||
"openai_api_key",
|
||||
"azure_api_key",
|
||||
"anthropic_api_key",
|
||||
"replicate_api_key",
|
||||
"cohere_api_key",
|
||||
"openrouter_api_key",
|
||||
]:
|
||||
if api_key_value := getattr(self, named_api_key):
|
||||
setattr(
|
||||
self.client,
|
||||
named_api_key.replace("_api_key", "_key"),
|
||||
api_key_value,
|
||||
)
|
||||
self.client.organization = self.organization
|
||||
creds: Dict[str, Any] = {
|
||||
"model": set_model_value,
|
||||
"force_timeout": self.request_timeout,
|
||||
"api_base": self.api_base,
|
||||
}
|
||||
return {**self._default_params, **creds}
|
||||
|
||||
def completion_with_retry(
|
||||
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
return self.client.completion(**kwargs)
|
||||
|
||||
return _completion_with_retry(**kwargs)
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate api key, python package exists, temperature, top_p, and top_k."""
|
||||
try:
|
||||
import litellm
|
||||
except ImportError:
|
||||
raise ChatLiteLLMException(
|
||||
"Could not import litellm python package. "
|
||||
"Please install it with `pip install litellm`"
|
||||
)
|
||||
|
||||
values["openai_api_key"] = get_from_dict_or_env(
|
||||
values, "openai_api_key", "OPENAI_API_KEY", default=""
|
||||
)
|
||||
values["azure_api_key"] = get_from_dict_or_env(
|
||||
values, "azure_api_key", "AZURE_API_KEY", default=""
|
||||
)
|
||||
values["anthropic_api_key"] = get_from_dict_or_env(
|
||||
values, "anthropic_api_key", "ANTHROPIC_API_KEY", default=""
|
||||
)
|
||||
values["replicate_api_key"] = get_from_dict_or_env(
|
||||
values, "replicate_api_key", "REPLICATE_API_KEY", default=""
|
||||
)
|
||||
values["openrouter_api_key"] = get_from_dict_or_env(
|
||||
values, "openrouter_api_key", "OPENROUTER_API_KEY", default=""
|
||||
)
|
||||
values["cohere_api_key"] = get_from_dict_or_env(
|
||||
values, "cohere_api_key", "COHERE_API_KEY", default=""
|
||||
)
|
||||
values["huggingface_api_key"] = get_from_dict_or_env(
|
||||
values, "huggingface_api_key", "HUGGINGFACE_API_KEY", default=""
|
||||
)
|
||||
values["together_ai_api_key"] = get_from_dict_or_env(
|
||||
values, "together_ai_api_key", "TOGETHERAI_API_KEY", default=""
|
||||
)
|
||||
values["client"] = litellm
|
||||
|
||||
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
|
||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
|
||||
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
||||
|
||||
if values["top_k"] is not None and values["top_k"] <= 0:
|
||||
raise ValueError("top_k must be positive")
|
||||
|
||||
return values
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
response = self.completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||
generations = []
|
||||
token_usage = response.get("usage", {})
|
||||
for res in response["choices"]:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
if isinstance(message, AIMessage):
|
||||
message.response_metadata = {
|
||||
"model_name": self.model_name or self.model
|
||||
}
|
||||
message.usage_metadata = _create_usage_metadata(token_usage)
|
||||
gen = ChatGeneration(
|
||||
message=message,
|
||||
generation_info=dict(finish_reason=res.get("finish_reason")),
|
||||
)
|
||||
generations.append(gen)
|
||||
set_model_value = self.model
|
||||
if self.model_name is not None:
|
||||
set_model_value = self.model_name
|
||||
llm_output = {"token_usage": token_usage, "model": set_model_value}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
params = self._client_params
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
added_model_name = False
|
||||
for chunk in self.completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.model_dump()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
usage = chunk.get("usage", {})
|
||||
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
if isinstance(chunk, AIMessageChunk):
|
||||
if not added_model_name:
|
||||
chunk.response_metadata = {
|
||||
"model_name": self.model_name or self.model
|
||||
}
|
||||
added_model_name = True
|
||||
chunk.usage_metadata = _create_usage_metadata(usage)
|
||||
default_chunk_class = chunk.__class__
|
||||
cg_chunk = ChatGenerationChunk(message=chunk)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
added_model_name = False
|
||||
async for chunk in await acompletion_with_retry(
|
||||
self, messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.model_dump()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
usage = chunk.get("usage", {})
|
||||
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
if isinstance(chunk, AIMessageChunk):
|
||||
if not added_model_name:
|
||||
chunk.response_metadata = {
|
||||
"model_name": self.model_name or self.model
|
||||
}
|
||||
added_model_name = True
|
||||
chunk.usage_metadata = _create_usage_metadata(usage)
|
||||
default_chunk_class = chunk.__class__
|
||||
cg_chunk = ChatGenerationChunk(message=chunk)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._astream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
response = await acompletion_with_retry(
|
||||
self, messages=message_dicts, run_manager=run_manager, **params
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, AIMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
|
||||
LiteLLM expects tools argument in OpenAI format.
|
||||
|
||||
Args:
|
||||
tools: A list of tool definitions to bind to this chat model.
|
||||
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
|
||||
models, callables, and BaseTools will be automatically converted to
|
||||
their schema dictionary representation.
|
||||
tool_choice: Which tool to require the model to call. Options are:
|
||||
- str of the form ``"<<tool_name>>"``: calls <<tool_name>> tool.
|
||||
- ``"auto"``:
|
||||
automatically selects a tool (including no tool).
|
||||
- ``"none"``:
|
||||
does not call a tool.
|
||||
- ``"any"`` or ``"required"`` or ``True``:
|
||||
forces least one tool to be called.
|
||||
- dict of the form:
|
||||
``{"type": "function", "function": {"name": <<tool_name>>}}``
|
||||
- ``False`` or ``None``: no effect
|
||||
**kwargs: Any additional parameters to pass to the
|
||||
:class:`~langchain.runnable.Runnable` constructor.
|
||||
"""
|
||||
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
|
||||
# In case of openai if tool_choice is `any` or if bool has been provided we
|
||||
# change it to `required` as that is suppored by openai.
|
||||
if (
|
||||
(self.model is not None and "azure" in self.model)
|
||||
or (self.model_name is not None and "azure" in self.model_name)
|
||||
or (self.model is not None and self.model in _OPENAI_MODELS)
|
||||
or (self.model_name is not None and self.model_name in _OPENAI_MODELS)
|
||||
) and (tool_choice == "any" or isinstance(tool_choice, bool)):
|
||||
tool_choice = "required"
|
||||
# If tool_choice is bool apart from openai we make it `any`
|
||||
elif isinstance(tool_choice, bool):
|
||||
tool_choice = "any"
|
||||
elif isinstance(tool_choice, dict):
|
||||
tool_names = [
|
||||
formatted_tool["function"]["name"] for formatted_tool in formatted_tools
|
||||
]
|
||||
if not any(
|
||||
tool_name == tool_choice["function"]["name"] for tool_name in tool_names
|
||||
):
|
||||
raise ValueError(
|
||||
f"Tool choice {tool_choice} was specified, but the only "
|
||||
f"provided tools were {tool_names}."
|
||||
)
|
||||
return super().bind(tools=formatted_tools, tool_choice=tool_choice, **kwargs)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
set_model_value = self.model
|
||||
if self.model_name is not None:
|
||||
set_model_value = self.model_name
|
||||
return {
|
||||
"model": set_model_value,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
"n": self.n,
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "litellm-chat"
|
||||
|
||||
|
||||
def _create_usage_metadata(token_usage: Mapping[str, Any]) -> UsageMetadata:
|
||||
input_tokens = token_usage.get("prompt_tokens", 0)
|
||||
output_tokens = token_usage.get("completion_tokens", 0)
|
||||
return UsageMetadata(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
total_tokens=input_tokens + output_tokens,
|
||||
)
|
||||
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
Deprecated LiteLLM wrapper.
|
||||
|
||||
⭐ Use `pip install langchain-litellm` and import
|
||||
`from langchain_litellm import ChatLiteLLMRouter` instead.
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional, Type
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
|
||||
from langchain_community.chat_models.litellm import (
|
||||
ChatLiteLLM,
|
||||
_convert_delta_to_message_chunk,
|
||||
_convert_dict_to_message,
|
||||
)
|
||||
|
||||
token_usage_key_name = "token_usage" # nosec # incorrectly flagged as password
|
||||
model_extra_key_name = "model_extra" # nosec # incorrectly flagged as password
|
||||
|
||||
|
||||
def get_llm_output(usage: Any, **params: Any) -> dict:
|
||||
"""Get llm output from usage and params."""
|
||||
llm_output = {token_usage_key_name: usage}
|
||||
# copy over metadata (metadata came from router completion call)
|
||||
metadata = params["metadata"]
|
||||
for key in metadata:
|
||||
if key not in llm_output:
|
||||
# if token usage in metadata, prefer metadata's copy of it
|
||||
llm_output[key] = metadata[key]
|
||||
return llm_output
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.24",
|
||||
removal="1.0",
|
||||
alternative_import="langchain_litellm.ChatLiteLLMRouter",
|
||||
)
|
||||
class ChatLiteLLMRouter(ChatLiteLLM):
|
||||
"""DEPRECATED – use `langchain_litellm.ChatLiteLLMRouter` instead."""
|
||||
|
||||
router: Any
|
||||
|
||||
def __init__(self, *, router: Any, **kwargs: Any) -> None:
|
||||
"""Construct Chat LiteLLM Router."""
|
||||
super().__init__(router=router, **kwargs) # type: ignore[call-arg]
|
||||
self.router = router
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "LiteLLMRouter"
|
||||
|
||||
def _prepare_params_for_router(self, params: Any) -> None:
|
||||
# allow the router to set api_base based on its model choice
|
||||
api_base_key_name = "api_base"
|
||||
if api_base_key_name in params and params[api_base_key_name] is None:
|
||||
del params[api_base_key_name]
|
||||
|
||||
# add metadata so router can fill it below
|
||||
params.setdefault("metadata", {})
|
||||
|
||||
def set_default_model(self, model_name: str) -> None:
|
||||
"""Set the default model to use for completion calls.
|
||||
|
||||
Sets `self.model` to `model_name` if it is in the litellm router's
|
||||
(`self.router`) model list. This provides the default model to use
|
||||
for completion calls if no `model` kwarg is provided.
|
||||
"""
|
||||
model_list = self.router.model_list
|
||||
if not model_list:
|
||||
raise ValueError("model_list is None or empty.")
|
||||
for entry in model_list:
|
||||
if entry["model_name"] == model_name:
|
||||
self.model = model_name
|
||||
return
|
||||
raise ValueError(f"Model {model_name} not found in model_list.")
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
self._prepare_params_for_router(params)
|
||||
|
||||
response = self.router.completion(
|
||||
messages=message_dicts,
|
||||
**params,
|
||||
)
|
||||
return self._create_chat_result(response, **params)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
self._prepare_params_for_router(params)
|
||||
|
||||
for chunk in self.router.completion(messages=message_dicts, **params):
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
default_chunk_class = chunk.__class__
|
||||
cg_chunk = ChatGenerationChunk(message=chunk)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
str(chunk.content), chunk=cg_chunk, **params
|
||||
)
|
||||
yield cg_chunk
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
self._prepare_params_for_router(params)
|
||||
|
||||
async for chunk in await self.router.acompletion(
|
||||
messages=message_dicts, **params
|
||||
):
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
default_chunk_class = chunk.__class__
|
||||
cg_chunk = ChatGenerationChunk(message=chunk)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
str(chunk.content), chunk=cg_chunk, **params
|
||||
)
|
||||
yield cg_chunk
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._astream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
self._prepare_params_for_router(params)
|
||||
|
||||
response = await self.router.acompletion(
|
||||
messages=message_dicts,
|
||||
**params,
|
||||
)
|
||||
return self._create_chat_result(response, **params)
|
||||
|
||||
# from
|
||||
# https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/chat_models/openai.py
|
||||
# but modified to handle LiteLLM Usage class
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
system_fingerprint = None
|
||||
for output in llm_outputs:
|
||||
if output is None:
|
||||
# Happens in streaming
|
||||
continue
|
||||
token_usage = output["token_usage"]
|
||||
if token_usage is not None:
|
||||
# get dict from LiteLLM Usage class
|
||||
for k, v in token_usage.model_dump().items():
|
||||
if k in overall_token_usage and overall_token_usage[k] is not None:
|
||||
overall_token_usage[k] += v
|
||||
else:
|
||||
overall_token_usage[k] = v
|
||||
if system_fingerprint is None:
|
||||
system_fingerprint = output.get("system_fingerprint")
|
||||
combined = {"token_usage": overall_token_usage, "model_name": self.model}
|
||||
if system_fingerprint:
|
||||
combined["system_fingerprint"] = system_fingerprint
|
||||
return combined
|
||||
|
||||
def _create_chat_result(
|
||||
self, response: Mapping[str, Any], **params: Any
|
||||
) -> ChatResult:
|
||||
from litellm.utils import Usage
|
||||
|
||||
generations = []
|
||||
for res in response["choices"]:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
gen = ChatGeneration(
|
||||
message=message,
|
||||
generation_info=dict(finish_reason=res.get("finish_reason")),
|
||||
)
|
||||
generations.append(gen)
|
||||
token_usage = response.get("usage", Usage(prompt_tokens=0, total_tokens=0))
|
||||
llm_output = get_llm_output(token_usage, **params)
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
@@ -0,0 +1,241 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.utils import get_pydantic_field_names
|
||||
from pydantic import ConfigDict, model_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
return AIMessage(content=_dict.get("content", "") or "")
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
|
||||
return message_dict
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = _dict.get("role")
|
||||
content = _dict.get("content") or ""
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(content=content)
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
|
||||
else:
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
|
||||
|
||||
class LlamaEdgeChatService(BaseChatModel):
|
||||
"""Chat with LLMs via `llama-api-server`
|
||||
|
||||
For the information about `llama-api-server`, visit https://github.com/second-state/LlamaEdge
|
||||
"""
|
||||
|
||||
request_timeout: int = 60
|
||||
"""request timeout for chat http requests"""
|
||||
service_url: Optional[str] = None
|
||||
"""URL of WasmChat service"""
|
||||
model: str = "NA"
|
||||
"""model name, default is `NA`."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
if field_name not in all_required_field_names:
|
||||
logger.warning(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
res = self._chat(messages, **kwargs)
|
||||
|
||||
if res.status_code != 200:
|
||||
raise ValueError(f"Error code: {res.status_code}, reason: {res.reason}")
|
||||
|
||||
response = res.json()
|
||||
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
res = self._chat(messages, **kwargs)
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
substring = '"object":"chat.completion.chunk"}'
|
||||
for line in res.iter_lines():
|
||||
chunks = []
|
||||
if line:
|
||||
json_string = line.decode("utf-8")
|
||||
|
||||
# Find all positions of the substring
|
||||
positions = [m.start() for m in re.finditer(substring, json_string)]
|
||||
positions = [-1 * len(substring)] + positions
|
||||
|
||||
for i in range(len(positions) - 1):
|
||||
chunk = json.loads(
|
||||
json_string[
|
||||
positions[i] + len(substring) : positions[i + 1]
|
||||
+ len(substring)
|
||||
]
|
||||
)
|
||||
chunks.append(chunk)
|
||||
|
||||
for chunk in chunks:
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.dict()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
|
||||
choice = chunk["choices"][0]
|
||||
chunk = _convert_delta_to_message_chunk(
|
||||
choice["delta"], default_chunk_class
|
||||
)
|
||||
if (
|
||||
choice.get("finish_reason") is not None
|
||||
and choice.get("finish_reason") == "stop"
|
||||
):
|
||||
break
|
||||
finish_reason = choice.get("finish_reason")
|
||||
generation_info = (
|
||||
dict(finish_reason=finish_reason)
|
||||
if finish_reason is not None
|
||||
else None
|
||||
)
|
||||
default_chunk_class = chunk.__class__
|
||||
cg_chunk = ChatGenerationChunk(
|
||||
message=chunk, generation_info=generation_info
|
||||
)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
||||
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
|
||||
if self.service_url is None:
|
||||
res = requests.models.Response()
|
||||
res.status_code = 503
|
||||
res.reason = "The IP address or port of the chat service is incorrect."
|
||||
return res
|
||||
|
||||
service_url = f"{self.service_url}/v1/chat/completions"
|
||||
|
||||
if self.streaming:
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [_convert_message_to_dict(m) for m in messages],
|
||||
"stream": self.streaming,
|
||||
}
|
||||
else:
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [_convert_message_to_dict(m) for m in messages],
|
||||
}
|
||||
|
||||
res = requests.post(
|
||||
url=service_url,
|
||||
timeout=self.request_timeout,
|
||||
headers={
|
||||
"accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
data=json.dumps(payload),
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||
message = _convert_dict_to_message(response["choices"][0].get("message"))
|
||||
generations = [ChatGeneration(message=message)]
|
||||
|
||||
token_usage = response["usage"]
|
||||
llm_output = {"token_usage": token_usage, "model": self.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "wasm-chat"
|
||||
@@ -0,0 +1,818 @@
|
||||
import json
|
||||
from operator import itemgetter
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolMessage,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.messages.tool import InvalidToolCall, ToolCall, ToolCallChunk
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
PydanticToolsParser,
|
||||
make_invalid_tool_call,
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class ChatLlamaCpp(BaseChatModel):
|
||||
"""llama.cpp model.
|
||||
|
||||
To use, you should have the llama-cpp-python library installed, and provide the
|
||||
path to the Llama model as a named parameter to the constructor.
|
||||
Check out: https://github.com/abetlen/llama-cpp-python
|
||||
|
||||
"""
|
||||
|
||||
client: Any = None #: :meta private:
|
||||
|
||||
model_path: str
|
||||
"""The path to the Llama model file."""
|
||||
|
||||
lora_base: Optional[str] = None
|
||||
"""The path to the Llama LoRA base model."""
|
||||
|
||||
lora_path: Optional[str] = None
|
||||
"""The path to the Llama LoRA. If None, no LoRa is loaded."""
|
||||
|
||||
n_ctx: int = 512
|
||||
"""Token context window."""
|
||||
|
||||
n_parts: int = -1
|
||||
"""Number of parts to split the model into.
|
||||
If -1, the number of parts is automatically determined."""
|
||||
|
||||
seed: int = -1
|
||||
"""Seed. If -1, a random seed is used."""
|
||||
|
||||
f16_kv: bool = True
|
||||
"""Use half-precision for key/value cache."""
|
||||
|
||||
logits_all: bool = False
|
||||
"""Return logits for all tokens, not just the last token."""
|
||||
|
||||
vocab_only: bool = False
|
||||
"""Only load the vocabulary, no weights."""
|
||||
|
||||
use_mlock: bool = False
|
||||
"""Force system to keep model in RAM."""
|
||||
|
||||
n_threads: Optional[int] = None
|
||||
"""Number of threads to use.
|
||||
If None, the number of threads is automatically determined."""
|
||||
|
||||
n_batch: int = 8
|
||||
"""Number of tokens to process in parallel.
|
||||
Should be a number between 1 and n_ctx."""
|
||||
|
||||
n_gpu_layers: Optional[int] = None
|
||||
"""Number of layers to be loaded into gpu memory. Default None."""
|
||||
|
||||
suffix: Optional[str] = None
|
||||
"""A suffix to append to the generated text. If None, no suffix is appended."""
|
||||
|
||||
max_tokens: int = 256
|
||||
"""The maximum number of tokens to generate."""
|
||||
|
||||
temperature: float = 0.8
|
||||
"""The temperature to use for sampling."""
|
||||
|
||||
top_p: float = 0.95
|
||||
"""The top-p value to use for sampling."""
|
||||
|
||||
logprobs: Optional[int] = None
|
||||
"""The number of logprobs to return. If None, no logprobs are returned."""
|
||||
|
||||
echo: bool = False
|
||||
"""Whether to echo the prompt."""
|
||||
|
||||
stop: Optional[List[str]] = None
|
||||
"""A list of strings to stop generation when encountered."""
|
||||
|
||||
repeat_penalty: float = 1.1
|
||||
"""The penalty to apply to repeated tokens."""
|
||||
|
||||
top_k: int = 40
|
||||
"""The top-k value to use for sampling."""
|
||||
|
||||
last_n_tokens_size: int = 64
|
||||
"""The number of tokens to look back when applying the repeat_penalty."""
|
||||
|
||||
use_mmap: bool = True
|
||||
"""Whether to keep the model loaded in RAM"""
|
||||
|
||||
rope_freq_scale: float = 1.0
|
||||
"""Scale factor for rope sampling."""
|
||||
|
||||
rope_freq_base: float = 10000.0
|
||||
"""Base frequency for rope sampling."""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Any additional parameters to pass to llama_cpp.Llama."""
|
||||
|
||||
streaming: bool = True
|
||||
"""Whether to stream the results, token by token."""
|
||||
|
||||
grammar_path: Optional[Union[str, Path]] = None
|
||||
"""
|
||||
grammar_path: Path to the .gbnf file that defines formal grammars
|
||||
for constraining model outputs. For instance, the grammar can be used
|
||||
to force the model to generate valid JSON or to speak exclusively in emojis. At most
|
||||
one of grammar_path and grammar should be passed in.
|
||||
"""
|
||||
grammar: Any = None
|
||||
"""
|
||||
grammar: formal grammar for constraining model outputs. For instance, the grammar
|
||||
can be used to force the model to generate valid JSON or to speak exclusively in
|
||||
emojis. At most one of grammar_path and grammar should be passed in.
|
||||
"""
|
||||
|
||||
verbose: bool = True
|
||||
"""Print verbose output to stderr."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate that llama-cpp-python library is installed."""
|
||||
try:
|
||||
from llama_cpp import Llama, LlamaGrammar
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import llama-cpp-python library. "
|
||||
"Please install the llama-cpp-python library to "
|
||||
"use this embedding model: pip install llama-cpp-python"
|
||||
)
|
||||
|
||||
model_path = self.model_path
|
||||
model_param_names = [
|
||||
"rope_freq_scale",
|
||||
"rope_freq_base",
|
||||
"lora_path",
|
||||
"lora_base",
|
||||
"n_ctx",
|
||||
"n_parts",
|
||||
"seed",
|
||||
"f16_kv",
|
||||
"logits_all",
|
||||
"vocab_only",
|
||||
"use_mlock",
|
||||
"n_threads",
|
||||
"n_batch",
|
||||
"use_mmap",
|
||||
"last_n_tokens_size",
|
||||
"verbose",
|
||||
]
|
||||
model_params = {k: getattr(self, k) for k in model_param_names}
|
||||
# For backwards compatibility, only include if non-null.
|
||||
if self.n_gpu_layers is not None:
|
||||
model_params["n_gpu_layers"] = self.n_gpu_layers
|
||||
|
||||
model_params.update(self.model_kwargs)
|
||||
|
||||
try:
|
||||
self.client = Llama(model_path, **model_params)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Could not load Llama model from path: {model_path}. "
|
||||
f"Received error {e}"
|
||||
)
|
||||
|
||||
if self.grammar and self.grammar_path:
|
||||
grammar = self.grammar
|
||||
grammar_path = self.grammar_path
|
||||
raise ValueError(
|
||||
"Can only pass in one of grammar and grammar_path. Received "
|
||||
f"{grammar=} and {grammar_path=}."
|
||||
)
|
||||
elif isinstance(self.grammar, str):
|
||||
self.grammar = LlamaGrammar.from_string(self.grammar)
|
||||
elif self.grammar_path:
|
||||
self.grammar = LlamaGrammar.from_file(self.grammar_path)
|
||||
else:
|
||||
pass
|
||||
return self
|
||||
|
||||
def _get_parameters(self, stop: Optional[List[str]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Performs sanity check, preparing parameters in format needed by llama_cpp.
|
||||
|
||||
Returns:
|
||||
Dictionary containing the combined parameters.
|
||||
"""
|
||||
|
||||
params = self._default_params
|
||||
|
||||
# llama_cpp expects the "stop" key not this, so we remove it:
|
||||
stop_sequences = params.pop("stop_sequences")
|
||||
|
||||
# then sets it as configured, or default to an empty list:
|
||||
params["stop"] = stop or stop_sequences or self.stop or []
|
||||
|
||||
return params
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage]
|
||||
) -> List[Dict[str, Any]]:
|
||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||
|
||||
return message_dicts
|
||||
|
||||
def _create_chat_result(self, response: dict) -> ChatResult:
|
||||
generations = []
|
||||
for res in response["choices"]:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
generation_info = dict(finish_reason=res.get("finish_reason"))
|
||||
if "logprobs" in res:
|
||||
generation_info["logprobs"] = res["logprobs"]
|
||||
gen = ChatGeneration(message=message, generation_info=generation_info)
|
||||
generations.append(gen)
|
||||
token_usage = response.get("usage", {})
|
||||
llm_output = {
|
||||
"token_usage": token_usage,
|
||||
# "system_fingerprint": response.get("system_fingerprint", ""),
|
||||
}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
params = {**self._get_parameters(stop), **kwargs}
|
||||
|
||||
# Check tool_choice is whether available, if yes then run no stream with tool
|
||||
# calling
|
||||
if self.streaming and not params.get("tool_choice"):
|
||||
stream_iter = self._stream(messages, run_manager=run_manager, **kwargs)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
message_dicts = self._create_message_dicts(messages)
|
||||
|
||||
response = self.client.create_chat_completion(messages=message_dicts, **params)
|
||||
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
params = {**self._get_parameters(stop), **kwargs}
|
||||
message_dicts = self._create_message_dicts(messages)
|
||||
|
||||
result = self.client.create_chat_completion(
|
||||
messages=message_dicts, stream=True, **params
|
||||
)
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
count = 0
|
||||
for chunk in result:
|
||||
count += 1
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.model_dump()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
choice = chunk["choices"][0]
|
||||
if choice["delta"] is None:
|
||||
continue
|
||||
chunk = _convert_delta_to_message_chunk(
|
||||
choice["delta"], default_chunk_class
|
||||
)
|
||||
generation_info = {}
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
generation_info["finish_reason"] = finish_reason
|
||||
logprobs = choice.get("logprobs")
|
||||
if logprobs:
|
||||
generation_info["logprobs"] = logprobs
|
||||
default_chunk_class = chunk.__class__
|
||||
chunk = ChatGenerationChunk(
|
||||
message=chunk, generation_info=generation_info or None
|
||||
)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
|
||||
yield chunk
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[Union[dict, bool, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, AIMessage]:
|
||||
"""Bind tool-like objects to this chat model
|
||||
|
||||
tool_choice: does not currently support "any", "auto" choices like OpenAI
|
||||
tool-calling API. should be a dict of the form to force this tool
|
||||
{"type": "function", "function": {"name": <<tool_name>>}}.
|
||||
"""
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
tool_names = [ft["function"]["name"] for ft in formatted_tools]
|
||||
if tool_choice:
|
||||
if isinstance(tool_choice, dict):
|
||||
if not any(
|
||||
tool_choice["function"]["name"] == name for name in tool_names
|
||||
):
|
||||
raise ValueError(
|
||||
f"Tool choice {tool_choice=} was specified, but the only "
|
||||
f"provided tools were {tool_names}."
|
||||
)
|
||||
elif isinstance(tool_choice, str):
|
||||
chosen = [
|
||||
f for f in formatted_tools if f["function"]["name"] == tool_choice
|
||||
]
|
||||
if not chosen:
|
||||
raise ValueError(
|
||||
f"Tool choice {tool_choice=} was specified, but the only "
|
||||
f"provided tools were {tool_names}."
|
||||
)
|
||||
elif isinstance(tool_choice, bool):
|
||||
if len(formatted_tools) > 1:
|
||||
raise ValueError(
|
||||
"tool_choice=True can only be specified when a single tool is "
|
||||
f"passed in. Received {len(tools)} tools."
|
||||
)
|
||||
tool_choice = formatted_tools[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"""Unrecognized tool_choice type. Expected dict having format like
|
||||
this {"type": "function", "function": {"name": <<tool_name>>}}"""
|
||||
f"Received: {tool_choice}"
|
||||
)
|
||||
|
||||
kwargs["tool_choice"] = tool_choice
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Optional[Union[Dict, Type[BaseModel]]] = None,
|
||||
*,
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
|
||||
then the model output will be an object of that class. If a dict then
|
||||
the model output will be a dict. With a Pydantic class the returned
|
||||
attributes will be validated, whereas with a dict they will not be. If
|
||||
`method` is "function_calling" and `schema` is a dict, then the dict
|
||||
must match the OpenAI function-calling spec or be a valid JSON schema
|
||||
with top level 'title' and 'description' keys specified.
|
||||
include_raw: If False then only the parsed structured output is returned. If
|
||||
an error occurs during model output parsing it will be raised. If True
|
||||
then both the raw model response (a BaseMessage) and the parsed model
|
||||
response will be returned. If an error occurs during output parsing it
|
||||
will be caught and returned as well. The final output is always a dict
|
||||
with keys "raw", "parsed", and "parsing_error".
|
||||
kwargs: Any other args to bind to model, ``self.bind(..., **kwargs)``.
|
||||
|
||||
Returns:
|
||||
A Runnable that takes any ChatModel input and returns as output:
|
||||
|
||||
If include_raw is True then a dict with keys:
|
||||
raw: BaseMessage
|
||||
parsed: Optional[_DictOrPydantic]
|
||||
parsing_error: Optional[BaseException]
|
||||
|
||||
If include_raw is False then just _DictOrPydantic is returned,
|
||||
where _DictOrPydantic depends on the schema:
|
||||
|
||||
If schema is a Pydantic class then _DictOrPydantic is the Pydantic
|
||||
class.
|
||||
|
||||
If schema is a dict then _DictOrPydantic is a dict.
|
||||
|
||||
Example: Pydantic schema (include_raw=False):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatLlamaCpp
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
answer: str
|
||||
justification: str
|
||||
|
||||
llm = ChatLlamaCpp(
|
||||
temperature=0.,
|
||||
model_path="./SanctumAI-meta-llama-3-8b-instruct.Q8_0.gguf",
|
||||
n_ctx=10000,
|
||||
n_gpu_layers=4,
|
||||
n_batch=200,
|
||||
max_tokens=512,
|
||||
n_threads=multiprocessing.cpu_count() - 1,
|
||||
repeat_penalty=1.5,
|
||||
top_p=0.5,
|
||||
stop=["<|end_of_text|>", "<|eot_id|>"],
|
||||
)
|
||||
structured_llm = llm.with_structured_output(AnswerWithJustification)
|
||||
|
||||
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
|
||||
|
||||
# -> AnswerWithJustification(
|
||||
# answer='They weigh the same',
|
||||
# justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
|
||||
# )
|
||||
|
||||
Example: Pydantic schema (include_raw=True):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatLlamaCpp
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
answer: str
|
||||
justification: str
|
||||
|
||||
llm = ChatLlamaCpp(
|
||||
temperature=0.,
|
||||
model_path="./SanctumAI-meta-llama-3-8b-instruct.Q8_0.gguf",
|
||||
n_ctx=10000,
|
||||
n_gpu_layers=4,
|
||||
n_batch=200,
|
||||
max_tokens=512,
|
||||
n_threads=multiprocessing.cpu_count() - 1,
|
||||
repeat_penalty=1.5,
|
||||
top_p=0.5,
|
||||
stop=["<|end_of_text|>", "<|eot_id|>"],
|
||||
)
|
||||
structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)
|
||||
|
||||
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
|
||||
# -> {
|
||||
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
|
||||
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
|
||||
# 'parsing_error': None
|
||||
# }
|
||||
|
||||
Example: dict schema (include_raw=False):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatLlamaCpp
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
class AnswerWithJustification(BaseModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
answer: str
|
||||
justification: str
|
||||
|
||||
dict_schema = convert_to_openai_tool(AnswerWithJustification)
|
||||
llm = ChatLlamaCpp(
|
||||
temperature=0.,
|
||||
model_path="./SanctumAI-meta-llama-3-8b-instruct.Q8_0.gguf",
|
||||
n_ctx=10000,
|
||||
n_gpu_layers=4,
|
||||
n_batch=200,
|
||||
max_tokens=512,
|
||||
n_threads=multiprocessing.cpu_count() - 1,
|
||||
repeat_penalty=1.5,
|
||||
top_p=0.5,
|
||||
stop=["<|end_of_text|>", "<|eot_id|>"],
|
||||
)
|
||||
structured_llm = llm.with_structured_output(dict_schema)
|
||||
|
||||
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
|
||||
# -> {
|
||||
# 'answer': 'They weigh the same',
|
||||
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
|
||||
# }
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
"schema must be specified when method is 'function_calling'. "
|
||||
"Received None."
|
||||
)
|
||||
tool_name = convert_to_openai_tool(schema)["function"]["name"]
|
||||
tool_choice = {"type": "function", "function": {"name": tool_name}}
|
||||
llm = self.bind_tools([schema], tool_choice=tool_choice)
|
||||
if is_pydantic_schema:
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[cast(Type, schema)], first_tool_only=True
|
||||
)
|
||||
else:
|
||||
output_parser = JsonOutputKeyToolsParser(
|
||||
key_name=tool_name, first_tool_only=True
|
||||
)
|
||||
|
||||
if include_raw:
|
||||
parser_assign = RunnablePassthrough.assign(
|
||||
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
|
||||
)
|
||||
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||
parser_with_fallback = parser_assign.with_fallbacks(
|
||||
[parser_none], exception_key="parsing_error"
|
||||
)
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
else:
|
||||
return llm | output_parser
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Return a dictionary of identifying parameters.
|
||||
|
||||
This information is used by the LangChain callback system, which
|
||||
is used for tracing purposes make it possible to monitor LLMs.
|
||||
"""
|
||||
return {
|
||||
# The model name allows users to specify custom token counting
|
||||
# rules in LLM monitoring applications (e.g., in LangSmith users
|
||||
# can provide per token pricing for their model and monitor
|
||||
# costs for the given LLM.)
|
||||
**{"model_path": self.model_path},
|
||||
**self._default_params,
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Get the type of language model used by this chat model."""
|
||||
return "llama-cpp-python"
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling create_chat_completion."""
|
||||
params: Dict = {
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
"logprobs": self.logprobs,
|
||||
"stop_sequences": self.stop, # key here is convention among LLM classes
|
||||
"repeat_penalty": self.repeat_penalty,
|
||||
}
|
||||
if self.grammar:
|
||||
params["grammar"] = self.grammar
|
||||
return params
|
||||
|
||||
|
||||
def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"id": tool_call["id"],
|
||||
"function": {
|
||||
"name": tool_call["name"],
|
||||
"arguments": json.dumps(tool_call["args"]),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _lc_invalid_tool_call_to_openai_tool_call(
|
||||
invalid_tool_call: InvalidToolCall,
|
||||
) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"id": invalid_tool_call["id"],
|
||||
"function": {
|
||||
"name": invalid_tool_call["name"],
|
||||
"arguments": invalid_tool_call["args"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
"""Convert a dictionary to a LangChain message.
|
||||
|
||||
Args:
|
||||
_dict: The dictionary.
|
||||
|
||||
Returns:
|
||||
The LangChain message.
|
||||
"""
|
||||
role = _dict.get("role")
|
||||
name = _dict.get("name")
|
||||
id_ = _dict.get("id")
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
|
||||
elif role == "assistant":
|
||||
# Fix for azure
|
||||
# Also OpenAI returns None for tool invocations
|
||||
content = _dict.get("content", "") or ""
|
||||
additional_kwargs: Dict = {}
|
||||
if function_call := _dict.get("function_call"):
|
||||
additional_kwargs["function_call"] = dict(function_call)
|
||||
tool_calls = []
|
||||
invalid_tool_calls = []
|
||||
if raw_tool_calls := _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
try:
|
||||
tc = parse_tool_call(raw_tool_call, return_id=True)
|
||||
except Exception as e:
|
||||
invalid_tc = make_invalid_tool_call(raw_tool_call, str(e))
|
||||
invalid_tool_calls.append(invalid_tc)
|
||||
else:
|
||||
if not tc:
|
||||
continue
|
||||
else:
|
||||
tool_calls.append(tc)
|
||||
return AIMessage(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
name=name,
|
||||
id=id_,
|
||||
tool_calls=tool_calls,
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
|
||||
elif role == "function":
|
||||
return FunctionMessage(
|
||||
content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_
|
||||
)
|
||||
elif role == "tool":
|
||||
additional_kwargs = {}
|
||||
if "name" in _dict:
|
||||
additional_kwargs["name"] = _dict["name"]
|
||||
return ToolMessage(
|
||||
content=_dict.get("content", ""),
|
||||
tool_call_id=cast(str, _dict.get("tool_call_id")),
|
||||
additional_kwargs=additional_kwargs,
|
||||
name=name,
|
||||
id=id_,
|
||||
)
|
||||
else:
|
||||
return ChatMessage(
|
||||
content=_dict.get("content", ""), role=cast(str, role), id=id_
|
||||
)
|
||||
|
||||
|
||||
def _format_message_content(content: Any) -> Any:
|
||||
"""Format message content."""
|
||||
if content and isinstance(content, list):
|
||||
# Remove unexpected block types
|
||||
formatted_content = []
|
||||
for block in content:
|
||||
if (
|
||||
isinstance(block, dict)
|
||||
and "type" in block
|
||||
and block["type"] == "tool_use"
|
||||
):
|
||||
continue
|
||||
else:
|
||||
formatted_content.append(block)
|
||||
else:
|
||||
formatted_content = content
|
||||
|
||||
return formatted_content
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
"""Convert a LangChain message to a dictionary.
|
||||
|
||||
Args:
|
||||
message: The LangChain message.
|
||||
|
||||
Returns:
|
||||
The dictionary.
|
||||
"""
|
||||
message_dict: Dict[str, Any] = {
|
||||
"content": _format_message_content(message.content),
|
||||
}
|
||||
if (name := message.name or message.additional_kwargs.get("name")) is not None:
|
||||
message_dict["name"] = name
|
||||
|
||||
# populate role and additional message data
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict["role"] = message.role
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict["role"] = "user"
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict["role"] = "assistant"
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
if message.tool_calls or message.invalid_tool_calls:
|
||||
message_dict["tool_calls"] = [
|
||||
_lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
|
||||
] + [
|
||||
_lc_invalid_tool_call_to_openai_tool_call(tc)
|
||||
for tc in message.invalid_tool_calls
|
||||
]
|
||||
elif "tool_calls" in message.additional_kwargs:
|
||||
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
|
||||
tool_call_supported_props = {"id", "type", "function"}
|
||||
message_dict["tool_calls"] = [
|
||||
{k: v for k, v in tool_call.items() if k in tool_call_supported_props}
|
||||
for tool_call in message_dict["tool_calls"]
|
||||
]
|
||||
else:
|
||||
pass
|
||||
# If tool calls present, content null value should be None not empty string.
|
||||
if "function_call" in message_dict or "tool_calls" in message_dict:
|
||||
message_dict["content"] = message_dict["content"] or None
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict["role"] = "system"
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict["role"] = "function"
|
||||
elif isinstance(message, ToolMessage):
|
||||
message_dict["role"] = "tool"
|
||||
message_dict["tool_call_id"] = message.tool_call_id
|
||||
|
||||
supported_props = {"content", "role", "tool_call_id"}
|
||||
message_dict = {k: v for k, v in message_dict.items() if k in supported_props}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
return message_dict
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
id_ = _dict.get("id")
|
||||
role = cast(str, _dict.get("role"))
|
||||
content = cast(str, _dict.get("content") or "")
|
||||
additional_kwargs: Dict = {}
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
function_call["name"] = ""
|
||||
additional_kwargs["function_call"] = function_call
|
||||
tool_call_chunks = []
|
||||
if raw_tool_calls := _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
for rtc in raw_tool_calls:
|
||||
try:
|
||||
tool_call = ToolCallChunk(
|
||||
name=rtc["function"].get("name"),
|
||||
args=rtc["function"].get("arguments"),
|
||||
id=rtc.get("id"),
|
||||
index=rtc["index"],
|
||||
)
|
||||
tool_call_chunks.append(tool_call)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content, id=id_)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
id=id_,
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content, id=id_)
|
||||
elif role == "function" or default_class == FunctionMessageChunk:
|
||||
return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
|
||||
elif role == "tool" or default_class == ToolMessageChunk:
|
||||
return ToolMessageChunk(
|
||||
content=content, tool_call_id=_dict["tool_call_id"], id=id_
|
||||
)
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role, id=id_)
|
||||
else:
|
||||
return default_class(content=content, id=id_) # type: ignore[call-arg]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user