initial commit
This commit is contained in:
157
venv/Lib/site-packages/langchain_community/callbacks/__init__.py
Normal file
157
venv/Lib/site-packages/langchain_community/callbacks/__init__.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""**Callback handlers** allow listening to events in LangChain.
|
||||
|
||||
**Class hierarchy:**
|
||||
|
||||
.. code-block::
|
||||
|
||||
BaseCallbackHandler --> <name>CallbackHandler # Example: AimCallbackHandler
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.callbacks.aim_callback import (
|
||||
AimCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.argilla_callback import (
|
||||
ArgillaCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.arize_callback import (
|
||||
ArizeCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.arthur_callback import (
|
||||
ArthurCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.clearml_callback import (
|
||||
ClearMLCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.comet_ml_callback import (
|
||||
CometCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.context_callback import (
|
||||
ContextCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.fiddler_callback import (
|
||||
FiddlerCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.flyte_callback import (
|
||||
FlyteCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.human import (
|
||||
HumanApprovalCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.infino_callback import (
|
||||
InfinoCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.labelstudio_callback import (
|
||||
LabelStudioCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.llmonitor_callback import (
|
||||
LLMonitorCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.manager import (
|
||||
get_openai_callback,
|
||||
wandb_tracing_enabled,
|
||||
)
|
||||
from langchain_community.callbacks.mlflow_callback import (
|
||||
MlflowCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.openai_info import (
|
||||
OpenAICallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.promptlayer_callback import (
|
||||
PromptLayerCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.sagemaker_callback import (
|
||||
SageMakerCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.streamlit import (
|
||||
LLMThoughtLabeler,
|
||||
StreamlitCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.trubrics_callback import (
|
||||
TrubricsCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.upstash_ratelimit_callback import (
|
||||
UpstashRatelimitError,
|
||||
UpstashRatelimitHandler, # noqa: F401
|
||||
)
|
||||
from langchain_community.callbacks.uptrain_callback import (
|
||||
UpTrainCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.wandb_callback import (
|
||||
WandbCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.whylabs_callback import (
|
||||
WhyLabsCallbackHandler,
|
||||
)
|
||||
|
||||
|
||||
# Lazy-export table: public name -> dotted path of the submodule defining it.
# The module-level ``__getattr__`` looks names up here and imports the
# submodule only when the name is first requested.
_module_lookup = {
    "AimCallbackHandler": "langchain_community.callbacks.aim_callback",
    "ArgillaCallbackHandler": "langchain_community.callbacks.argilla_callback",
    "ArizeCallbackHandler": "langchain_community.callbacks.arize_callback",
    "ArthurCallbackHandler": "langchain_community.callbacks.arthur_callback",
    "ClearMLCallbackHandler": "langchain_community.callbacks.clearml_callback",
    "CometCallbackHandler": "langchain_community.callbacks.comet_ml_callback",
    "ContextCallbackHandler": "langchain_community.callbacks.context_callback",
    "FiddlerCallbackHandler": "langchain_community.callbacks.fiddler_callback",
    "FlyteCallbackHandler": "langchain_community.callbacks.flyte_callback",
    "HumanApprovalCallbackHandler": "langchain_community.callbacks.human",
    "InfinoCallbackHandler": "langchain_community.callbacks.infino_callback",
    "LLMThoughtLabeler": "langchain_community.callbacks.streamlit",
    "LLMonitorCallbackHandler": "langchain_community.callbacks.llmonitor_callback",
    "LabelStudioCallbackHandler": "langchain_community.callbacks.labelstudio_callback",
    "MlflowCallbackHandler": "langchain_community.callbacks.mlflow_callback",
    "OpenAICallbackHandler": "langchain_community.callbacks.openai_info",
    "PromptLayerCallbackHandler": "langchain_community.callbacks.promptlayer_callback",
    "SageMakerCallbackHandler": "langchain_community.callbacks.sagemaker_callback",
    "StreamlitCallbackHandler": "langchain_community.callbacks.streamlit",
    "TrubricsCallbackHandler": "langchain_community.callbacks.trubrics_callback",
    "UpstashRatelimitError": "langchain_community.callbacks.upstash_ratelimit_callback",
    "UpstashRatelimitHandler": "langchain_community.callbacks.upstash_ratelimit_callback",  # noqa
    "UpTrainCallbackHandler": "langchain_community.callbacks.uptrain_callback",
    "WandbCallbackHandler": "langchain_community.callbacks.wandb_callback",
    "WhyLabsCallbackHandler": "langchain_community.callbacks.whylabs_callback",
    "get_openai_callback": "langchain_community.callbacks.manager",
    "wandb_tracing_enabled": "langchain_community.callbacks.manager",
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve a lazily exported attribute of this package on first access.

    Args:
        name: Attribute being looked up on the package.

    Returns:
        The object called ``name`` fetched from the submodule that
        ``_module_lookup`` registers for it.

    Raises:
        AttributeError: If ``name`` is not a key of ``_module_lookup``.
    """
    # Unknown names fail fast, before any import is attempted.
    if name not in _module_lookup:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    # Import the providing submodule only now, on demand.
    provider = importlib.import_module(_module_lookup[name])
    return getattr(provider, name)
|
||||
|
||||
|
||||
# Public API of the package; every entry is also a key of ``_module_lookup``
# and is resolved lazily through the module-level ``__getattr__``.
__all__ = [
    "AimCallbackHandler",
    "ArgillaCallbackHandler",
    "ArizeCallbackHandler",
    "ArthurCallbackHandler",
    "ClearMLCallbackHandler",
    "CometCallbackHandler",
    "ContextCallbackHandler",
    "FiddlerCallbackHandler",
    "FlyteCallbackHandler",
    "HumanApprovalCallbackHandler",
    "InfinoCallbackHandler",
    "LLMThoughtLabeler",
    "LLMonitorCallbackHandler",
    "LabelStudioCallbackHandler",
    "MlflowCallbackHandler",
    "OpenAICallbackHandler",
    "PromptLayerCallbackHandler",
    "SageMakerCallbackHandler",
    "StreamlitCallbackHandler",
    "TrubricsCallbackHandler",
    "UpstashRatelimitError",
    "UpstashRatelimitHandler",
    "UpTrainCallbackHandler",
    "WandbCallbackHandler",
    "WhyLabsCallbackHandler",
    "get_openai_callback",
    "wandb_tracing_enabled",
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,434 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.utils import guard_import
|
||||
|
||||
|
||||
def import_aim() -> Any:
    """Load the ``aim`` package via ``guard_import`` and return the module.

    Returns:
        Whatever ``guard_import("aim")`` returns (the imported ``aim`` module).
        Per the helper's contract an error is raised when ``aim`` is not
        installed.
    """
    aim_module = guard_import("aim")
    return aim_module
|
||||
|
||||
|
||||
class BaseMetadataCallbackHandler:
    """Callback handler for the metadata and associated function states for callbacks.

    Attributes:
        step (int): The current step.
        starts (int): The number of times the start method has been called.
        ends (int): The number of times the end method has been called.
        errors (int): The number of times the error method has been called.
        text_ctr (int): The number of times the text method has been called.
        ignore_llm_ (bool): Whether to ignore llm callbacks.
        ignore_chain_ (bool): Whether to ignore chain callbacks.
        ignore_agent_ (bool): Whether to ignore agent callbacks.
        ignore_retriever_ (bool): Whether to ignore retriever callbacks.
        always_verbose_ (bool): Whether to always be verbose.
        chain_starts (int): The number of times the chain start method has been called.
        chain_ends (int): The number of times the chain end method has been called.
        llm_starts (int): The number of times the llm start method has been called.
        llm_ends (int): The number of times the llm end method has been called.
        llm_streams (int): The number of times the llm new-token method has been
            called.
        tool_starts (int): The number of times the tool start method has been called.
        tool_ends (int): The number of times the tool end method has been called.
        agent_ends (int): The number of times the agent end method has been called.
    """

    def __init__(self) -> None:
        """Start with every counter at zero and every flag switched off."""
        self.step = 0

        self.starts = 0
        self.ends = 0
        self.errors = 0
        self.text_ctr = 0

        self.ignore_llm_ = False
        self.ignore_chain_ = False
        self.ignore_agent_ = False
        self.ignore_retriever_ = False
        self.always_verbose_ = False

        self.chain_starts = 0
        self.chain_ends = 0

        self.llm_starts = 0
        self.llm_ends = 0
        self.llm_streams = 0

        self.tool_starts = 0
        self.tool_ends = 0

        self.agent_ends = 0

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return self.always_verbose_

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return self.ignore_llm_

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return self.ignore_chain_

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return self.ignore_agent_

    @property
    def ignore_retriever(self) -> bool:
        """Whether to ignore retriever callbacks."""
        return self.ignore_retriever_

    def get_custom_callback_meta(self) -> Dict[str, Any]:
        """Return a snapshot of the step and call counters.

        Returns:
            A new dict holding ``step`` and every ``*_starts`` / ``*_ends`` /
            ``errors`` / ``text_ctr`` / ``llm_streams`` counter. The
            ``ignore_*`` and ``always_verbose_`` flags are not included.
        """
        return {
            "step": self.step,
            "starts": self.starts,
            "ends": self.ends,
            "errors": self.errors,
            "text_ctr": self.text_ctr,
            "chain_starts": self.chain_starts,
            "chain_ends": self.chain_ends,
            "llm_starts": self.llm_starts,
            "llm_ends": self.llm_ends,
            "llm_streams": self.llm_streams,
            "tool_starts": self.tool_starts,
            "tool_ends": self.tool_ends,
            "agent_ends": self.agent_ends,
        }

    def reset_callback_meta(self) -> None:
        """Reset the callback metadata to the state set up by ``__init__``."""
        self.step = 0

        self.starts = 0
        self.ends = 0
        self.errors = 0
        self.text_ctr = 0

        self.ignore_llm_ = False
        self.ignore_chain_ = False
        self.ignore_agent_ = False
        # Previously left out, so a retriever-ignore flag survived a reset
        # while every other flag was cleared.
        self.ignore_retriever_ = False
        self.always_verbose_ = False

        self.chain_starts = 0
        self.chain_ends = 0

        self.llm_starts = 0
        self.llm_ends = 0
        self.llm_streams = 0

        self.tool_starts = 0
        self.tool_ends = 0

        self.agent_ends = 0

        return None
|
||||
|
||||
|
||||
class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback Handler that logs to Aim.

    Parameters:
        repo (:obj:`str`, optional): Aim repository path or Repo object to which
            Run object is bound. If skipped, default Repo is used.
        experiment_name (:obj:`str`, optional): Sets Run's `experiment` property.
            'default' if not specified. Can be used later to query runs/sequences.
        system_tracking_interval (:obj:`int`, optional): Sets the tracking interval
            in seconds for system usage metrics (CPU, Memory, etc.). Set to `None`
            to disable system metrics tracking.
        log_system_params (:obj:`bool`, optional): Enable/Disable logging of system
            params such as installed packages, git info, environment variables, etc.

    This handler will utilize the associated callback method called and formats
    the input of each callback function with metadata regarding the state of LLM run
    and then logs the response to Aim.
    """

    def __init__(
        self,
        repo: Optional[str] = None,
        experiment_name: Optional[str] = None,
        system_tracking_interval: Optional[int] = 10,
        log_system_params: bool = True,
    ) -> None:
        """Initialize callback handler and open an Aim run."""

        super().__init__()

        aim = import_aim()
        self.repo = repo
        self.experiment_name = experiment_name
        self.system_tracking_interval = system_tracking_interval
        self.log_system_params = log_system_params
        self._run = aim.Run(
            repo=self.repo,
            experiment=self.experiment_name,
            system_tracking_interval=self.system_tracking_interval,
            log_system_params=self.log_system_params,
        )
        # Remember the hash so `setup` can re-open the same run later.
        self._run_hash = self._run.hash
        self.action_records: list = []

    def _context(self, action: str, **extra: Any) -> Dict[str, Any]:
        """Build the tracking context: action name, extras, then the counters."""
        context: Dict[str, Any] = {"action": action, **extra}
        context.update(self.get_custom_callback_meta())
        return context

    @staticmethod
    def _payload_text(payload: Any, key: str) -> Any:
        """Pick what to log for a chain payload.

        Returns ``payload[key]`` when ``payload`` is a dict containing ``key``.
        Otherwise falls back to a ``k=v`` rendering of the dict, or
        ``str(payload)`` for non-dict payloads, so that chains using other
        key names are still logged instead of raising ``KeyError``.
        """
        if isinstance(payload, dict):
            if key in payload:
                return payload[key]
            return ",".join(f"{k}={v}" for k, v in payload.items())
        return str(payload)

    def setup(self, **kwargs: Any) -> None:
        """Make sure a run exists, then store ``kwargs`` as run parameters.

        If ``self._run`` is falsy, the run identified by ``self._run_hash`` is
        re-opened; when no hash is known a brand-new run is created and its
        hash recorded. Each keyword argument is then written with
        ``Run.set(key, value, strict=False)``.
        """
        aim = import_aim()

        if not self._run:
            if self._run_hash:
                self._run = aim.Run(
                    self._run_hash,
                    repo=self.repo,
                    system_tracking_interval=self.system_tracking_interval,
                )
            else:
                self._run = aim.Run(
                    repo=self.repo,
                    experiment=self.experiment_name,
                    system_tracking_interval=self.system_tracking_interval,
                    log_system_params=self.log_system_params,
                )
                self._run_hash = self._run.hash

        for key, value in kwargs.items():
            self._run.set(key, value, strict=False)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts: track every prompt as an Aim text."""
        aim = import_aim()

        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        prompts_res = deepcopy(prompts)

        self._run.track(
            [aim.Text(prompt) for prompt in prompts_res],
            name="on_llm_start",
            context=self._context("on_llm_start"),
        )

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running: track every generated text."""
        aim = import_aim()
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        response_res = deepcopy(response)

        # Flatten the per-prompt generation lists into one list of texts.
        generated = [
            aim.Text(generation.text)
            for generations in response_res.generations
            for generation in generations
        ]
        self._run.track(
            generated,
            name="on_llm_end",
            context=self._context("on_llm_end"),
        )

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token (counted only, nothing is tracked)."""
        self.step += 1
        self.llm_streams += 1

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors (counted only, nothing is tracked)."""
        self.step += 1
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running.

        Tracks ``inputs["input"]`` when that key exists; otherwise a textual
        rendering of the whole ``inputs`` payload is tracked.
        """
        aim = import_aim()
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        inputs_res = deepcopy(inputs)

        self._run.track(
            aim.Text(self._payload_text(inputs_res, "input")),
            name="on_chain_start",
            context=self._context("on_chain_start"),
        )

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running.

        Tracks ``outputs["output"]`` when that key exists; otherwise a textual
        rendering of the whole ``outputs`` payload is tracked.
        """
        aim = import_aim()
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        outputs_res = deepcopy(outputs)

        self._run.track(
            aim.Text(self._payload_text(outputs_res, "output")),
            name="on_chain_end",
            context=self._context("on_chain_end"),
        )

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors (counted only, nothing is tracked)."""
        self.step += 1
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running: track the tool input."""
        aim = import_aim()
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        self._run.track(
            aim.Text(input_str),
            name="on_tool_start",
            context=self._context("on_tool_start"),
        )

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Run when tool ends running: track the stringified tool output."""
        output = str(output)
        aim = import_aim()
        self.step += 1
        self.tool_ends += 1
        self.ends += 1

        self._run.track(
            aim.Text(output),
            name="on_tool_end",
            context=self._context("on_tool_end"),
        )

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors (counted only, nothing is tracked)."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on a text event (counted only, nothing is tracked)."""
        self.step += 1
        self.text_ctr += 1

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running: track its output and log."""
        aim = import_aim()
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        finish_res = deepcopy(finish)

        text = "OUTPUT:\n{}\n\nLOG:\n{}".format(
            finish_res.return_values["output"], finish_res.log
        )
        self._run.track(
            aim.Text(text),
            name="on_agent_finish",
            context=self._context("on_agent_finish"),
        )

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action: track the tool input and log.

        The action is counted as a tool start, and the tool name is added to
        the tracking context.
        """
        aim = import_aim()
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        action_res = deepcopy(action)

        text = "TOOL INPUT:\n{}\n\nLOG:\n{}".format(
            action_res.tool_input, action_res.log
        )
        self._run.track(
            aim.Text(text),
            name="on_agent_action",
            context=self._context("on_agent_action", tool=action.tool),
        )

    def flush_tracker(
        self,
        repo: Optional[str] = None,
        experiment_name: Optional[str] = None,
        system_tracking_interval: Optional[int] = 10,
        log_system_params: bool = True,
        langchain_asset: Any = None,
        reset: bool = True,
        finish: bool = False,
    ) -> None:
        """Flush the tracker and reset the session.

        Args:
            repo (:obj:`str`, optional): Aim repository path or Repo object to which
                Run object is bound. If skipped, default Repo is used.
            experiment_name (:obj:`str`, optional): Sets Run's `experiment` property.
                'default' if not specified. Can be used later to query runs/sequences.
            system_tracking_interval (:obj:`int`, optional): Sets the tracking interval
                in seconds for system usage metrics (CPU, Memory, etc.). Set to `None`
                to disable system metrics tracking.
            log_system_params (:obj:`bool`, optional): Enable/Disable logging of system
                params such as installed packages, git info, environment variables, etc.
            langchain_asset: The langchain asset to save.
            reset: Whether to reset the session.
            finish: Whether to finish the run.

        Returns:
            None
        """

        if langchain_asset:
            try:
                for key, value in langchain_asset.dict().items():
                    self._run.set(key, value, strict=False)
            except Exception:
                # Best-effort: any failure while exporting or storing the
                # asset's fields is ignored and the flush continues.
                pass

        if finish or reset:
            self._run.close()
            self.reset_callback_meta()
        if reset:
            aim = import_aim()
            self.repo = repo if repo else self.repo
            self.experiment_name = (
                experiment_name if experiment_name else self.experiment_name
            )
            # NOTE(review): these fallbacks test truthiness, so the default of
            # 10 replaces the stored interval on every reset, and passing
            # `None` / `False` here keeps the stored value instead of
            # disabling the feature. Left as-is because fixing it requires
            # changing the parameter defaults.
            self.system_tracking_interval = (
                system_tracking_interval
                if system_tracking_interval
                else self.system_tracking_interval
            )
            self.log_system_params = (
                log_system_params if log_system_params else self.log_system_params
            )

            self._run = aim.Run(
                repo=self.repo,
                experiment=self.experiment_name,
                system_tracking_interval=self.system_tracking_interval,
                log_system_params=self.log_system_params,
            )
            self._run_hash = self._run.hash
        self.action_records = []
|
||||
@@ -0,0 +1,349 @@
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
from packaging.version import parse
|
||||
|
||||
|
||||
class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that logs into Argilla.
|
||||
|
||||
Args:
|
||||
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
|
||||
exist in advance. If you need help on how to create a `FeedbackDataset` in
|
||||
Argilla, please visit
|
||||
https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html.
|
||||
workspace_name: name of the workspace in Argilla where the specified
|
||||
`FeedbackDataset` lives in. Defaults to `None`, which means that the
|
||||
default workspace will be used.
|
||||
api_url: URL of the Argilla Server that we want to use, and where the
|
||||
`FeedbackDataset` lives in. Defaults to `None`, which means that either
|
||||
`ARGILLA_API_URL` environment variable or the default will be used.
|
||||
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
|
||||
means that either `ARGILLA_API_KEY` environment variable or the default
|
||||
will be used.
|
||||
|
||||
Raises:
|
||||
ImportError: if the `argilla` package is not installed.
|
||||
ConnectionError: if the connection to Argilla fails.
|
||||
FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
|
||||
|
||||
Examples:
|
||||
>>> from langchain_community.llms import OpenAI
|
||||
>>> from langchain_community.callbacks import ArgillaCallbackHandler
|
||||
>>> argilla_callback = ArgillaCallbackHandler(
|
||||
... dataset_name="my-dataset",
|
||||
... workspace_name="my-workspace",
|
||||
... api_url="http://localhost:6900",
|
||||
... api_key="argilla.apikey",
|
||||
... )
|
||||
>>> llm = OpenAI(
|
||||
... temperature=0,
|
||||
... callbacks=[argilla_callback],
|
||||
... verbose=True,
|
||||
... openai_api_key="API_KEY_HERE",
|
||||
... )
|
||||
>>> llm.generate([
|
||||
... "What is the best NLP-annotation tool out there? (no bias at all)",
|
||||
... ])
|
||||
"Argilla, no doubt about it."
|
||||
"""
|
||||
|
||||
REPO_URL: str = "https://github.com/argilla-io/argilla"
|
||||
ISSUES_URL: str = f"{REPO_URL}/issues"
|
||||
BLOG_URL: str = "https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html"
|
||||
|
||||
DEFAULT_API_URL: str = "http://localhost:6900"
|
||||
|
||||
def __init__(
    self,
    dataset_name: str,
    workspace_name: Optional[str] = None,
    api_url: Optional[str] = None,
    api_key: Optional[str] = None,
) -> None:
    """Initializes the `ArgillaCallbackHandler`.

    Args:
        dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
            exist in advance. If you need help on how to create a `FeedbackDataset`
            in Argilla, please visit
            https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html.
        workspace_name: name of the workspace in Argilla where the specified
            `FeedbackDataset` lives in. Defaults to `None`, which means that the
            default workspace will be used.
        api_url: URL of the Argilla Server that we want to use, and where the
            `FeedbackDataset` lives in. Defaults to `None`, which means that either
            `ARGILLA_API_URL` environment variable or the default will be used.
        api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
            means that either `ARGILLA_API_KEY` environment variable or the default
            will be used.

    Raises:
        ImportError: if the `argilla` package is not installed, or is older
            than 1.8.0.
        ConnectionError: if the connection to Argilla fails.
        FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
        ValueError: if the dataset's fields are not exactly `prompt` and
            `response`, in that order.
    """

    super().__init__()

    # Import Argilla (not via `import_argilla` to keep hints in IDEs)
    try:
        import argilla as rg

        self.ARGILLA_VERSION = rg.__version__
    except ImportError as e:
        raise ImportError(
            "To use the Argilla callback manager you need to have the `argilla` "
            "Python package installed. Please install it with `pip install argilla`"
        ) from e

    # Check whether the Argilla version is compatible
    if parse(self.ARGILLA_VERSION) < parse("1.8.0"):
        raise ImportError(
            f"The installed `argilla` version is {self.ARGILLA_VERSION} but "
            "`ArgillaCallbackHandler` requires at least version 1.8.0. Please "
            "upgrade `argilla` with `pip install --upgrade argilla`."
        )

    # Show a warning message if Argilla will assume the default values will be used
    if api_url is None and os.getenv("ARGILLA_API_URL") is None:
        warnings.warn(
            (
                "Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
                f" set, it will default to `{self.DEFAULT_API_URL}`, which is the"
                " default API URL in Argilla Quickstart."
            ),
        )
        api_url = self.DEFAULT_API_URL

    if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
        # The default key name depends on the installed Argilla version.
        self.DEFAULT_API_KEY = (
            "admin.apikey"
            if parse(self.ARGILLA_VERSION) < parse("1.11.0")
            else "owner.apikey"
        )

        warnings.warn(
            (
                "Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
                f" set, it will default to `{self.DEFAULT_API_KEY}`, which is the"
                " default API key in Argilla Quickstart."
            ),
        )
        api_key = self.DEFAULT_API_KEY

    # Connect to Argilla with the provided credentials, if applicable
    try:
        rg.init(api_key=api_key, api_url=api_url)
    except Exception as e:
        raise ConnectionError(
            f"Could not connect to Argilla with exception: '{e}'.\n"
            "Please check your `api_key` and `api_url`, and make sure that "
            "the Argilla server is up and running. If the problem persists "
            f"please report it to {self.ISSUES_URL} as an `integration` issue."
        ) from e

    # Set the Argilla variables
    self.dataset_name = dataset_name
    self.workspace_name = workspace_name or rg.get_workspace()

    # Retrieve the `FeedbackDataset` from Argilla (without existing records)
    try:
        extra_args: Dict[str, Any] = {}
        if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
            warnings.warn(
                f"You have Argilla {self.ARGILLA_VERSION}, but Argilla 1.14.0 or"
                " higher is recommended.",
                UserWarning,
            )
            # `with_records` is only passed to these older versions.
            extra_args = {"with_records": False}
        self.dataset = rg.FeedbackDataset.from_argilla(
            name=self.dataset_name,
            workspace=self.workspace_name,
            **extra_args,
        )
    except Exception as e:
        raise FileNotFoundError(
            f"`FeedbackDataset` retrieval from Argilla failed with exception `{e}`."
            f"\nPlease check that the dataset with name={self.dataset_name} in the"
            f" workspace={self.workspace_name} exists in advance. If you need help"
            " on how to create a `langchain`-compatible `FeedbackDataset` in"
            f" Argilla, please visit {self.BLOG_URL}. If the problem persists"
            f" please report it to {self.ISSUES_URL} as an `integration` issue."
        ) from e

    # The dataset must expose exactly these fields, in this order.
    supported_fields = ["prompt", "response"]
    dataset_fields = [field.name for field in self.dataset.fields]
    if supported_fields != dataset_fields:
        raise ValueError(
            f"`FeedbackDataset` with name={self.dataset_name} in the workspace="
            f"{self.workspace_name} had fields that are not supported yet for the"
            f" `langchain` integration. Supported fields are: {supported_fields},"
            f" and the current `FeedbackDataset` fields are {dataset_fields}."
            " For more information on how to create a `langchain`-compatible"
            f" `FeedbackDataset` in Argilla, please visit {self.BLOG_URL}."
        )

    # Prompts of in-flight runs, keyed by stringified run id.
    self.prompts: Dict[str, List[str]] = {}

    warnings.warn(
        (
            "The `ArgillaCallbackHandler` is currently in beta and is subject to"
            " change based on updates to `langchain`. Please report any issues to"
            f" {self.ISSUES_URL} as an `integration` issue."
        ),
    )
|
||||
|
||||
def on_llm_start(
    self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
    """Remember the prompts of an LLM run that is starting.

    The prompts are stored in ``self.prompts`` under the stringified
    ``parent_run_id`` when that keyword is truthy, otherwise under the
    stringified ``run_id``.
    """
    owner_id = kwargs["parent_run_id"] or kwargs["run_id"]
    self.prompts[str(owner_id)] = prompts
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
    """Ignore streamed tokens; this handler takes no action for them."""
    return None
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
    """Add one record per generation to the dataset when a top-level LLM run ends.

    Runs that have a ``parent_run_id`` are skipped here, since their records
    are logged when the enclosing chain ends.
    """
    if kwargs["parent_run_id"]:
        return

    run_key = str(kwargs["run_id"])

    # Pair each stored prompt with its generations and add them as records.
    for prompt, generations in zip(self.prompts[run_key], response.generations):
        records = [
            {
                "fields": {
                    "prompt": prompt,
                    "response": generation.text.strip(),
                },
            }
            for generation in generations
        ]
        self.dataset.add_records(records=records)

    # The run is finished, so its prompts are dropped from `self.prompts`.
    self.prompts.pop(run_key)

    # Only versions older than 1.14.0 get the explicit push.
    if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
        self.dataset.push_to_argilla()
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
    """Ignore LLM errors; this handler takes no action for them."""
    return None
|
||||
|
||||
def on_chain_start(
    self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
    """Store the chain's ``input`` so it can be logged when the chain ends.

    When ``inputs`` has an ``"input"`` key, its value (wrapped in a list unless
    it already is one) is saved in ``self.prompts`` under the stringified
    ``parent_run_id`` if truthy, else the stringified ``run_id``. Sharing that
    key with the LLM callbacks avoids logging the same prompt twice.
    """
    if "input" not in inputs:
        return

    run_key = str(kwargs["parent_run_id"] or kwargs["run_id"])
    chain_input = inputs["input"]
    if not isinstance(chain_input, list):
        chain_input = [chain_input]
    self.prompts[run_key] = chain_input
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
    """Log the chain outputs to Argilla for a run whose prompts were stored.

    Nothing happens unless the stringified ``parent_run_id`` or ``run_id`` is
    a key of ``self.prompts``. List outputs are paired element-wise with the
    stored prompts; any other output is logged as a single record whose
    prompt is the stored prompts joined by spaces. Both keys are then removed
    from ``self.prompts``.
    """
    parent_key = str(kwargs["parent_run_id"])
    run_key = str(kwargs["run_id"])
    if parent_key not in self.prompts and run_key not in self.prompts:
        return

    prompts: List = self.prompts.get(parent_key) or cast(
        List, self.prompts.get(run_key, [])
    )

    for chain_output_val in outputs.values():
        if isinstance(chain_output_val, list):
            # One record per (prompt, output) pair.
            records = [
                {"fields": {"prompt": prompt, "response": output["text"].strip()}}
                for prompt, output in zip(prompts, chain_output_val)
            ]
        else:
            # A single record covering all stored prompts.
            records = [
                {
                    "fields": {
                        "prompt": " ".join(prompts),
                        "response": chain_output_val.strip(),
                    },
                }
            ]
        self.dataset.add_records(records=records)

    # The run is finished, so drop its entries from `self.prompts`.
    self.prompts.pop(parent_key, None)
    self.prompts.pop(run_key, None)

    # For Argilla versions below 1.14.0 the records are pushed explicitly.
    if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
        self.dataset.push_to_argilla()
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
    """Ignore chain errors; nothing is logged for them."""
    return None
|
||||
|
||||
def on_tool_start(
    self,
    serialized: Dict[str, Any],
    input_str: str,
    **kwargs: Any,
) -> None:
    """Ignore tool starts; nothing is logged for them."""
    return None
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
    """Ignore agent actions; nothing is logged for them."""
    return None
|
||||
|
||||
def on_tool_end(
    self,
    output: Any,
    observation_prefix: Optional[str] = None,
    llm_prefix: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Ignore tool results; nothing is logged for them."""
    return None
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
    """Ignore tool errors; nothing is logged for them."""
    return None
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
    """Ignore arbitrary text events."""
    return None
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
    """Ignore the agent-finish event."""
    return None
|
||||
@@ -0,0 +1,213 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_community.callbacks.utils import import_pandas
|
||||
|
||||
|
||||
class ArizeCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs to Arize.

    Prompts are buffered in ``on_llm_start``. In ``on_llm_end`` every
    generation is paired with its prompt, both texts are embedded, and the row
    is sent through the Arize pandas client. All other callbacks are no-ops.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        model_version: Optional[str] = None,
        SPACE_KEY: Optional[str] = None,
        API_KEY: Optional[str] = None,
    ) -> None:
        """Initialize callback handler.

        Args:
            model_id: Model identifier forwarded to the client's ``log`` call.
            model_version: Model version forwarded to the client's ``log`` call.
            SPACE_KEY: Arize space key used to build the client.
            API_KEY: Arize API key used to build the client.

        Raises:
            ValueError: If a key is still the literal placeholder
                ``"SPACE_KEY"`` / ``"API_KEY"``.
        """
        super().__init__()
        self.model_id = model_id
        self.model_version = model_version
        self.space_key = SPACE_KEY
        self.api_key = API_KEY
        self.prompt_records: List[str] = []
        self.response_records: List[str] = []
        self.prediction_ids: List[str] = []
        self.pred_timestamps: List[int] = []
        self.response_embeddings: List[float] = []
        self.prompt_embeddings: List[float] = []
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0
        # Index into `prompt_records` of the next prompt awaiting its response.
        self.step = 0

        from arize.pandas.embeddings import EmbeddingGenerator, UseCases
        from arize.pandas.logger import Client

        # Reject placeholder credentials before building the generator/client.
        if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
            raise ValueError("❌ CHANGE SPACE AND API KEYS")

        self.generator = EmbeddingGenerator.from_use_case(
            use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
            model_name="distilbert-base-uncased",
            tokenizer_max_length=512,
            batch_size=256,
        )
        self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
        print("✅ Arize client setup done! Now you can start using Arize!")  # noqa: T201

    def _embed(self, pd: Any, text: str) -> Any:
        """Return the first embedding the generator produces for ``text``."""
        embeddings = pd.Series(
            self.generator.generate_embeddings(
                text_col=pd.Series(text.replace("\n", " "))
            ).reset_index(drop=True)
        )
        return embeddings[0]

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Buffer each prompt (with newlines removed) until the LLM finishes."""
        for prompt in prompts:
            self.prompt_records.append(prompt.replace("\n", ""))

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Embed every prompt/response pair of ``response`` and log it to Arize.

        Token counts are read from ``response.llm_output["token_usage"]`` when
        available and default to 0 otherwise.
        """
        pd = import_pandas()
        from arize.utils.types import (
            EmbeddingColumnNames,
            Environments,
            ModelTypes,
            Schema,
        )

        # Tolerate a missing `llm_output` or a missing/None `token_usage`.
        token_usage = (response.llm_output or {}).get("token_usage") or {}
        self.prompt_tokens = token_usage.get("prompt_tokens", 0)
        self.total_tokens = token_usage.get("total_tokens", 0)
        self.completion_tokens = token_usage.get("completion_tokens", 0)

        # Column layout and schema are identical for every logged row.
        columns = [
            "prediction_ts",
            "response",
            "prompt",
            "response_vector",
            "prompt_vector",
            "prompt_token",
            "completion_token",
            "total_token",
        ]
        prompt_columns = EmbeddingColumnNames(
            vector_column_name="prompt_vector", data_column_name="prompt"
        )
        response_columns = EmbeddingColumnNames(
            vector_column_name="response_vector", data_column_name="response"
        )
        schema = Schema(
            timestamp_column_name="prediction_ts",
            tag_column_names=[
                "prompt_token",
                "completion_token",
                "total_token",
            ],
            prompt_column_names=prompt_columns,
            response_column_names=response_columns,
        )

        for generations in response.generations:
            # `response.generations` holds one list per prompt, so the prompt
            # cursor advances once per list, not once per generation.
            prompt = self.prompt_records[self.step]
            self.step += 1
            prompt_embedding = self._embed(pd, prompt)

            for generation in generations:
                response_text = generation.text.replace("\n", " ")
                response_embedding = self._embed(pd, response_text)
                pred_timestamp = datetime.now().timestamp()

                # Values must follow the order of `columns` exactly.
                row = [
                    pred_timestamp,
                    response_text,
                    prompt,
                    response_embedding,
                    prompt_embedding,
                    self.prompt_tokens,
                    self.completion_tokens,
                    self.total_tokens,
                ]
                df = pd.DataFrame([row], columns=columns)

                response_from_arize = self.arize_client.log(
                    dataframe=df,
                    schema=schema,
                    model_id=self.model_id,
                    model_version=self.model_version,
                    model_type=ModelTypes.GENERATIVE_LLM,
                    environment=Environments.PRODUCTION,
                )
                if response_from_arize.status_code == 200:
                    print("✅ Successfully logged data to Arize!")  # noqa: T201
                else:
                    print(f'❌ Logging failed "{response_from_arize.text}"')  # noqa: T201

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Do nothing."""
        pass

    def on_tool_end(
        self,
        output: Any,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Do nothing."""
        pass
|
||||
@@ -0,0 +1,297 @@
|
||||
"""ArthurAI's Callback Handler."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import arthurai
|
||||
from arthurai.core.models import ArthurModel
|
||||
|
||||
PROMPT_TOKENS = "prompt_tokens"
|
||||
COMPLETION_TOKENS = "completion_tokens"
|
||||
TOKEN_USAGE = "token_usage"
|
||||
FINISH_REASON = "finish_reason"
|
||||
DURATION = "duration"
|
||||
|
||||
|
||||
def _lazy_load_arthur() -> arthurai:
|
||||
"""Lazy load Arthur."""
|
||||
try:
|
||||
import arthurai
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"To use the ArthurCallbackHandler you need the"
|
||||
" `arthurai` package. Please install it with"
|
||||
" `pip install arthurai`.",
|
||||
e,
|
||||
)
|
||||
|
||||
return arthurai
|
||||
|
||||
|
||||
class ArthurCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs to Arthur platform.

    Arthur helps enterprise teams optimize model operations
    and performance at scale. The Arthur API tracks model
    performance, explainability, and fairness across tabular,
    NLP, and CV models. Our API is model- and platform-agnostic,
    and continuously scales with complex and dynamic enterprise needs.
    To learn more about Arthur, visit our website at
    https://www.arthur.ai/ or read the Arthur docs at
    https://docs.arthur.ai/
    """

    def __init__(
        self,
        arthur_model: ArthurModel,
    ) -> None:
        """Initialize callback handler.

        Args:
            arthur_model: The Arthur model that inferences are sent to. Its
                attributes decide which optional fields get logged.
        """
        super().__init__()
        arthurai = _lazy_load_arthur()
        Stage = arthurai.common.constants.Stage
        ValueType = arthurai.common.constants.ValueType
        self.arthur_model = arthur_model

        # Fetch the model attributes once; they are used when preparing
        # inferences to log to Arthur in on_llm_end().
        attributes = self.arthur_model.get_attributes()
        self.attr_names = {a.name for a in attributes}
        self.input_attr = [
            x
            for x in attributes
            if x.stage == Stage.ModelPipelineInput
            and x.value_type == ValueType.Unstructured_Text
        ][0].name
        self.output_attr = [
            x
            for x in attributes
            if x.stage == Stage.PredictedValue
            and x.value_type == ValueType.Unstructured_Text
        ][0].name

        # Token likelihoods are optional: None when the model has no such attribute.
        likelihood_attrs = [
            x for x in attributes if x.value_type == ValueType.TokenLikelihoods
        ]
        self.token_likelihood_attr = (
            likelihood_attrs[0].name if likelihood_attrs else None
        )

        # run_id -> {"input_texts": [...], "start_time": float}
        self.run_map: DefaultDict[str, Any] = defaultdict(dict)

    @classmethod
    def from_credentials(
        cls,
        model_id: str,
        arthur_url: Optional[str] = "https://app.arthur.ai",
        arthur_login: Optional[str] = None,
        arthur_password: Optional[str] = None,
    ) -> ArthurCallbackHandler:
        """Initialize callback handler from Arthur credentials.

        Args:
            model_id (str): The ID of the arthur model to log to.
            arthur_url (str, optional): The URL of the Arthur instance to log to.
                Defaults to "https://app.arthur.ai".
            arthur_login (str, optional): The login to use to connect to Arthur.
                Defaults to None.
            arthur_password (str, optional): The password to use to connect to
                Arthur. Defaults to None.

        Returns:
            ArthurCallbackHandler: The initialized callback handler.

        Raises:
            ValueError: If no login is given and ``ARTHUR_API_KEY`` is not set,
                or if the model cannot be retrieved.
        """
        arthurai = _lazy_load_arthur()
        ArthurAI = arthurai.ArthurAI
        ResponseClientError = arthurai.common.exceptions.ResponseClientError

        # connect to Arthur
        if arthur_login is None:
            try:
                arthur_api_key = os.environ["ARTHUR_API_KEY"]
            except KeyError as err:
                raise ValueError(
                    "No Arthur authentication provided. Either give"
                    " a login to the ArthurCallbackHandler"
                    " or set an ARTHUR_API_KEY as an environment variable."
                ) from err
            arthur = ArthurAI(url=arthur_url, access_key=arthur_api_key)
        elif arthur_password is None:
            arthur = ArthurAI(url=arthur_url, login=arthur_login)
        else:
            arthur = ArthurAI(
                url=arthur_url, login=arthur_login, password=arthur_password
            )

        # get model from Arthur by the provided model ID
        try:
            arthur_model = arthur.get_model(model_id)
        except ResponseClientError as err:
            raise ValueError(
                f"Was unable to retrieve model with id {model_id} from Arthur."
                " Make sure the ID corresponds to a model that is currently"
                " registered with your Arthur account."
            ) from err
        return cls(arthur_model)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """On LLM start, save the input prompts and the start time of the run."""
        run_id = kwargs["run_id"]
        self.run_map[run_id]["input_texts"] = prompts
        self.run_map[run_id]["start_time"] = time()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """On LLM end, send data to Arthur.

        Raises:
            ImportError: If ``pytz`` is not installed.
            KeyError: If ``run_id`` was never registered by ``on_llm_start``.
        """
        try:
            import pytz
        except ImportError as e:
            raise ImportError(
                "Could not import pytz. Please install it with 'pip install pytz'."
            ) from e

        run_id = kwargs["run_id"]

        # `run_map` is a defaultdict: indexing an unknown run_id would silently
        # create an empty entry instead of raising, so test membership.
        if run_id not in self.run_map:
            raise KeyError(
                "This function has been called with a run_id"
                " that was never registered in on_llm_start()."
                " Restart and try running the LLM again"
            )
        # The run is over: remove its entry so finished runs do not accumulate.
        run_map_data = self.run_map.pop(run_id)

        # mark the duration time between on_llm_start() and on_llm_end()
        time_from_start_to_end = time() - run_map_data["start_time"]

        # create inferences to log to Arthur
        inferences = []
        for prompt_index, generations in enumerate(response.generations):
            for generation in generations:
                inference = {
                    "partner_inference_id": str(uuid.uuid4()),
                    "inference_timestamp": datetime.now(tz=pytz.UTC),
                    self.input_attr: run_map_data["input_texts"][prompt_index],
                    self.output_attr: generation.text,
                }

                generation_info = generation.generation_info
                if generation_info is not None:
                    # add finish reason to the inference
                    # if generation info contains a finish reason and
                    # if the ArthurModel was registered to monitor finish_reason
                    if (
                        FINISH_REASON in generation_info
                        and FINISH_REASON in self.attr_names
                    ):
                        inference[FINISH_REASON] = generation_info[FINISH_REASON]

                    # add token likelihoods data to the inference if the
                    # ArthurModel was registered to monitor token likelihoods;
                    # `.get` because not every generation carries "logprobs"
                    logprobs_data = generation_info.get("logprobs")
                    if (
                        logprobs_data is not None
                        and self.token_likelihood_attr is not None
                    ):
                        logprobs = logprobs_data["top_logprobs"]
                        likelihoods = [
                            {k: np.exp(v) for k, v in token_logprobs.items()}
                            for token_logprobs in logprobs
                        ]
                        inference[self.token_likelihood_attr] = likelihoods

                # add token usage counts to the inference if the
                # ArthurModel was registered to monitor token usage
                if (
                    isinstance(response.llm_output, dict)
                    and TOKEN_USAGE in response.llm_output
                ):
                    token_usage = response.llm_output[TOKEN_USAGE]
                    if (
                        PROMPT_TOKENS in token_usage
                        and PROMPT_TOKENS in self.attr_names
                    ):
                        inference[PROMPT_TOKENS] = token_usage[PROMPT_TOKENS]
                    if (
                        COMPLETION_TOKENS in token_usage
                        and COMPLETION_TOKENS in self.attr_names
                    ):
                        inference[COMPLETION_TOKENS] = token_usage[COMPLETION_TOKENS]

                # add inference duration to the inference if the ArthurModel
                # was registered to monitor inference duration
                if DURATION in self.attr_names:
                    inference[DURATION] = time_from_start_to_end

                inferences.append(inference)

        # send inferences to arthur
        self.arthur_model.send_inferences(inferences)

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """On chain start, do nothing."""

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """On chain end, do nothing."""

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Discard the bookkeeping of a failed run; nothing is sent to Arthur."""
        # `on_llm_end` never fires for a failed run, so clean up here.
        self.run_map.pop(kwargs.get("run_id"), None)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """On new token, pass."""

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when LLM chain outputs an error."""

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool starts."""

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Do nothing when agent takes a specific action."""

    def on_tool_end(
        self,
        output: Any,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool ends."""

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when tool outputs an error."""

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Do nothing"""

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Do nothing"""
|
||||
@@ -0,0 +1,135 @@
|
||||
import threading
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
MODEL_COST_PER_1K_INPUT_TOKENS = {
|
||||
"anthropic.claude-instant-v1": 0.0008,
|
||||
"anthropic.claude-v2": 0.008,
|
||||
"anthropic.claude-v2:1": 0.008,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0": 0.003,
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0": 0.003,
|
||||
"anthropic.claude-3-5-sonnet-20241022-v2:0": 0.003,
|
||||
"anthropic.claude-3-7-sonnet-20250219-v1:0": 0.003,
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0": 0.003,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": 0.00025,
|
||||
"anthropic.claude-3-opus-20240229-v1:0": 0.015,
|
||||
"anthropic.claude-opus-4-20250514-v1:0": 0.015,
|
||||
"anthropic.claude-3-5-haiku-20241022-v1:0": 0.0008,
|
||||
}
|
||||
|
||||
MODEL_COST_PER_1K_OUTPUT_TOKENS = {
|
||||
"anthropic.claude-instant-v1": 0.0024,
|
||||
"anthropic.claude-v2": 0.024,
|
||||
"anthropic.claude-v2:1": 0.024,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0": 0.015,
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0": 0.015,
|
||||
"anthropic.claude-3-5-sonnet-20241022-v2:0": 0.015,
|
||||
"anthropic.claude-3-7-sonnet-20250219-v1:0": 0.015,
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0": 0.015,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": 0.00125,
|
||||
"anthropic.claude-3-opus-20240229-v1:0": 0.075,
|
||||
"anthropic.claude-opus-4-20250514-v1:0": 0.075,
|
||||
"anthropic.claude-3-5-haiku-20241022-v1:0": 0.004,
|
||||
}
|
||||
|
||||
|
||||
def _get_anthropic_claude_token_cost(
|
||||
prompt_tokens: int, completion_tokens: int, model_id: Union[str, None]
|
||||
) -> float:
|
||||
if model_id:
|
||||
# The model ID can be a cross-region (system-defined) inference profile ID,
|
||||
# which has a prefix indicating the region (e.g., 'us', 'eu') but
|
||||
# shares the same token costs as the "base model".
|
||||
# By extracting the "base model ID", by taking the last two segments
|
||||
# of the model ID, we can map cross-region inference profile IDs to
|
||||
# their corresponding cost entries.
|
||||
base_model_id = model_id.split(".")[-2] + "." + model_id.split(".")[-1]
|
||||
else:
|
||||
base_model_id = None
|
||||
"""Get the cost of tokens for the Claude model."""
|
||||
if base_model_id not in MODEL_COST_PER_1K_INPUT_TOKENS:
|
||||
raise ValueError(
|
||||
f"Unknown model: {model_id}. Please provide a valid Anthropic model name."
|
||||
"Known models are: " + ", ".join(MODEL_COST_PER_1K_INPUT_TOKENS.keys())
|
||||
)
|
||||
return (prompt_tokens / 1000) * MODEL_COST_PER_1K_INPUT_TOKENS[base_model_id] + (
|
||||
completion_tokens / 1000
|
||||
) * MODEL_COST_PER_1K_OUTPUT_TOKENS[base_model_id]
|
||||
|
||||
|
||||
class BedrockAnthropicTokenUsageCallbackHandler(BaseCallbackHandler):
    """Callback handler that accumulates Bedrock Anthropic token usage and cost."""

    # Running totals; updated under `self._lock` in `on_llm_end`.
    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    successful_requests: int = 0
    total_cost: float = 0.0

    def __init__(self) -> None:
        super().__init__()
        self._lock = threading.Lock()

    def __repr__(self) -> str:
        summary = (
            f"Tokens Used: {self.total_tokens}\n"
            f"\tPrompt Tokens: {self.prompt_tokens}\n"
            f"\tCompletion Tokens: {self.completion_tokens}\n"
            f"Successful Requests: {self.successful_requests}\n"
            f"Total Cost (USD): ${self.total_cost}"
        )
        return summary

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """No-op: nothing is tracked when an LLM starts."""
        return None

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """No-op: nothing is tracked per token."""
        return None

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Add the token usage and cost of a finished request to the totals.

        A response without ``llm_output`` is ignored. A response whose
        ``llm_output`` has no ``"usage"`` entry only increments
        ``successful_requests``.
        """
        llm_output = response.llm_output
        if llm_output is None:
            return None

        if "usage" in llm_output:
            # Work out this request's numbers before taking the lock.
            usage = llm_output["usage"]
            n_completion = usage.get("completion_tokens", 0)
            n_prompt = usage.get("prompt_tokens", 0)
            n_total = usage.get("total_tokens", 0)
            request_cost = _get_anthropic_claude_token_cost(
                prompt_tokens=n_prompt,
                completion_tokens=n_completion,
                model_id=llm_output.get("model_id", None),
            )

            with self._lock:
                self.total_cost += request_cost
                self.total_tokens += n_total
                self.prompt_tokens += n_prompt
                self.completion_tokens += n_completion
                self.successful_requests += 1
        else:
            with self._lock:
                self.successful_requests += 1

        return None

    def __copy__(self) -> "BedrockAnthropicTokenUsageCallbackHandler":
        """Return this same handler instead of a copy."""
        return self

    def __deepcopy__(self, memo: Any) -> "BedrockAnthropicTokenUsageCallbackHandler":
        """Return this same handler instead of a deep copy."""
        return self
|
||||
@@ -0,0 +1,518 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.utils import guard_import
|
||||
|
||||
from langchain_community.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
flatten_dict,
|
||||
hash_string,
|
||||
import_pandas,
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
load_json,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def import_clearml() -> Any:
    """Return the ``clearml`` package, raising an error if it is not installed."""
    clearml_module = guard_import("clearml")
    return clearml_module
|
||||
|
||||
|
||||
class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
"""Callback Handler that logs to ClearML.
|
||||
|
||||
Parameters:
|
||||
task_type (str): The type of clearml task such as "inference", "testing" or "qc"
|
||||
project_name (str): The clearml project name
|
||||
tags (list): Tags to add to the task
|
||||
task_name (str): Name of the clearml task
|
||||
visualize (bool): Whether to visualize the run.
|
||||
complexity_metrics (bool): Whether to log complexity metrics
|
||||
stream_logs (bool): Whether to stream callback actions to ClearML
|
||||
|
||||
This handler will utilize the associated callback method and formats
|
||||
the input of each callback function with metadata regarding the state of LLM run,
|
||||
and adds the response to the list of records for both the {method}_records and
|
||||
action. It then logs the response to the ClearML console.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    task_type: Optional[str] = "inference",
    project_name: Optional[str] = "langchain_callback_demo",
    tags: Optional[Sequence] = None,
    task_name: Optional[str] = None,
    visualize: bool = False,
    complexity_metrics: bool = False,
    stream_logs: bool = False,
) -> None:
    """Initialize callback handler.

    Args:
        task_type: ClearML task type passed to ``Task.init``.
        project_name: ClearML project name passed to ``Task.init``.
        tags: Tags passed to ``Task.init``.
        task_name: ClearML task name passed to ``Task.init``.
        visualize: Whether to visualize the run.
        complexity_metrics: Whether to log complexity metrics.
        stream_logs: Whether to stream callback actions to ClearML.
    """
    clearml = import_clearml()
    spacy = import_spacy()
    super().__init__()

    self.task_type = task_type
    self.project_name = project_name
    self.tags = tags
    self.task_name = task_name
    self.visualize = visualize
    self.complexity_metrics = complexity_metrics
    self.stream_logs = stream_logs

    self.temp_dir = tempfile.TemporaryDirectory()

    # Reuse the ClearML task if one already exists (e.g. in a pipeline);
    # otherwise create a new one from the given settings.
    current_task = clearml.Task.current_task()
    if current_task:
        self.task = current_task
    else:
        self.task = clearml.Task.init(
            task_type=self.task_type,
            project_name=self.project_name,
            tags=self.tags,
            task_name=self.task_name,
            output_uri=True,
        )
    self.logger = self.task.get_logger()
    warning = (
        "The clearml callback is currently in beta and is subject to change "
        "based on updates to `langchain`. Please report any issues to "
        "https://github.com/allegroai/clearml/issues with the tag `langchain`."
    )
    # 30 matches the numeric value of `logging.WARNING`.
    self.logger.report_text(warning, level=30, print_console=True)
    self.callback_columns: list = []
    self.action_records: list = []
    self.nlp = spacy.load("en_core_web_sm")
|
||||
|
||||
def _init_resp(self) -> Dict:
|
||||
return {k: None for k in self.callback_columns}
|
||||
|
||||
def on_llm_start(
    self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
    """Record one ``on_llm_start`` entry per prompt when an LLM starts."""
    self.step += 1
    self.llm_starts += 1
    self.starts += 1

    # Fields shared by every prompt of this call.
    base_record = self._init_resp()
    base_record["action"] = "on_llm_start"
    base_record.update(flatten_dict(serialized))
    base_record.update(self.get_custom_callback_meta())

    for prompt in prompts:
        record = deepcopy(base_record)
        record["prompts"] = prompt
        self.on_llm_start_records.append(record)
        self.action_records.append(record)
        if self.stream_logs:
            self.logger.report_text(record)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
    """Record an ``on_llm_new_token`` entry for a newly generated token."""
    self.step += 1
    self.llm_streams += 1

    record = self._init_resp()
    record["action"] = "on_llm_new_token"
    record["token"] = token
    record.update(self.get_custom_callback_meta())

    self.on_llm_token_records.append(record)
    self.action_records.append(record)
    if self.stream_logs:
        self.logger.report_text(record)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
    """Record one ``on_llm_end`` entry per generation when an LLM finishes."""
    self.step += 1
    self.llm_ends += 1
    self.ends += 1

    # Fields shared by every generation of this response.
    base_record = self._init_resp()
    base_record["action"] = "on_llm_end"
    base_record.update(flatten_dict(response.llm_output or {}))
    base_record.update(self.get_custom_callback_meta())

    for generation_group in response.generations:
        for generation in generation_group:
            record = deepcopy(base_record)
            record.update(flatten_dict(generation.dict()))
            record.update(self.analyze_text(generation.text))
            self.on_llm_end_records.append(record)
            self.action_records.append(record)
            if self.stream_logs:
                self.logger.report_text(record)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
    """Count an LLM failure; the error itself is not recorded."""
    self.step = self.step + 1
    self.errors = self.errors + 1
|
||||
|
||||
def on_chain_start(
    self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
    """Record the start of a chain run.

    The chain input is read from ``inputs["input"]`` (falling back to
    ``inputs["human_input"]``). A string yields one record; a list yields
    one record per item, each item being merged into the record.

    Raises:
        ValueError: If the chain input is neither a string nor a list.
    """
    self.step += 1
    self.chain_starts += 1
    self.starts += 1

    template = self._init_resp()
    template.update({"action": "on_chain_start"})
    template.update(flatten_dict(serialized))
    template.update(self.get_custom_callback_meta())

    def _store(record: dict) -> None:
        # Same three sinks for every record, in the same order.
        self.on_chain_start_records.append(record)
        self.action_records.append(record)
        if self.stream_logs:
            self.logger.report_text(record)

    chain_input = inputs.get("input", inputs.get("human_input"))

    if isinstance(chain_input, str):
        record = deepcopy(template)
        record["input"] = chain_input
        _store(record)
        return

    if not isinstance(chain_input, list):
        raise ValueError("Unexpected data format provided!")

    for item in chain_input:
        record = deepcopy(template)
        record.update(item)
        _store(record)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
    """Record the end of a chain run.

    The logged output is ``outputs["output"]`` and, when that key is
    absent, ``outputs["text"]`` (``None`` if neither exists).
    """
    self.step += 1
    self.chain_ends += 1
    self.ends += 1

    record = self._init_resp()
    final_output = outputs.get("output", outputs.get("text"))
    record.update({"action": "on_chain_end", "outputs": final_output})
    record.update(self.get_custom_callback_meta())

    for sink in (self.on_chain_end_records, self.action_records):
        sink.append(record)
    if self.stream_logs:
        self.logger.report_text(record)
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
    """Count a chain failure; the error itself is not recorded."""
    self.step = self.step + 1
    self.errors = self.errors + 1
|
||||
|
||||
def on_tool_start(
    self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
    """Record the start of a tool run together with its input string."""
    self.step += 1
    self.tool_starts += 1
    self.starts += 1

    record = self._init_resp()
    record.update(action="on_tool_start", input_str=input_str)
    # Flattened serialized fields and callback metadata are merged last,
    # so they win over the two keys above on a name clash.
    record.update(flatten_dict(serialized))
    record.update(self.get_custom_callback_meta())

    for sink in (self.on_tool_start_records, self.action_records):
        sink.append(record)
    if self.stream_logs:
        self.logger.report_text(record)
|
||||
|
||||
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
    """Record the end of a tool run; the output is stored as ``str(output)``."""
    output_text = str(output)
    self.step += 1
    self.tool_ends += 1
    self.ends += 1

    record = self._init_resp()
    record.update(action="on_tool_end", output=output_text)
    record.update(self.get_custom_callback_meta())

    for sink in (self.on_tool_end_records, self.action_records):
        sink.append(record)
    if self.stream_logs:
        self.logger.report_text(record)
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
    """Count a tool failure; the error itself is not recorded."""
    self.step = self.step + 1
    self.errors = self.errors + 1
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
    """Record an arbitrary piece of text passed to the callback."""
    self.step += 1
    self.text_ctr += 1

    record = self._init_resp()
    record.update(action="on_text", text=text)
    record.update(self.get_custom_callback_meta())

    for sink in (self.on_text_records, self.action_records):
        sink.append(record)
    if self.stream_logs:
        self.logger.report_text(record)
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
    """Record the final answer of an agent run.

    Reads ``finish.return_values["output"]`` and ``finish.log``; a missing
    ``"output"`` key propagates as ``KeyError``.
    """
    self.step += 1
    self.agent_ends += 1
    self.ends += 1

    record = self._init_resp()
    summary = {
        "action": "on_agent_finish",
        "output": finish.return_values["output"],
        "log": finish.log,
    }
    record.update(summary)
    record.update(self.get_custom_callback_meta())

    for sink in (self.on_agent_finish_records, self.action_records):
        sink.append(record)
    if self.stream_logs:
        self.logger.report_text(record)
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
    """Record an agent action (tool name, tool input and log).

    Note that this advances the ``tool_starts`` and ``starts`` counters.
    """
    self.step += 1
    self.tool_starts += 1
    self.starts += 1

    record = self._init_resp()
    details = {
        "action": "on_agent_action",
        "tool": action.tool,
        "tool_input": action.tool_input,
        "log": action.log,
    }
    record.update(details)
    record.update(self.get_custom_callback_meta())

    for sink in (self.on_agent_action_records, self.action_records):
        sink.append(record)
    if self.stream_logs:
        self.logger.report_text(record)
|
||||
|
||||
def analyze_text(self, text: str) -> dict:
    """Analyze text using textstat and spacy.

    When ``self.complexity_metrics`` is set, the textstat readability
    scores are computed and returned. When ``self.visualize`` and
    ``self.nlp`` are set, dependency and entity renderings of the text are
    written as HTML files into ``self.temp_dir`` and reported as media
    through ``self.logger``.

    Parameters:
        text (str): The text to analyze.

    Returns:
        `dict` containing the complexity metrics (empty when complexity
        metrics are disabled).
    """
    resp = {}

    if self.complexity_metrics:
        # Imported only when needed so that a handler with metrics disabled
        # does not require textstat.
        textstat = import_textstat()
        text_complexity_metrics = {
            "flesch_reading_ease": textstat.flesch_reading_ease(text),
            "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
            "smog_index": textstat.smog_index(text),
            "coleman_liau_index": textstat.coleman_liau_index(text),
            "automated_readability_index": textstat.automated_readability_index(
                text
            ),
            "dale_chall_readability_score": textstat.dale_chall_readability_score(
                text
            ),
            "difficult_words": textstat.difficult_words(text),
            "linsear_write_formula": textstat.linsear_write_formula(text),
            "gunning_fog": textstat.gunning_fog(text),
            "text_standard": textstat.text_standard(text),
            "fernandez_huerta": textstat.fernandez_huerta(text),
            "szigriszt_pazos": textstat.szigriszt_pazos(text),
            "gutierrez_polini": textstat.gutierrez_polini(text),
            "crawford": textstat.crawford(text),
            "gulpease_index": textstat.gulpease_index(text),
            "osman": textstat.osman(text),
        }
        resp.update(text_complexity_metrics)

    if self.visualize and self.nlp and self.temp_dir.name is not None:
        # Same reasoning as above: spacy is only needed for rendering.
        spacy = import_spacy()
        doc = self.nlp(text)

        dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
        dep_output_path = Path(
            self.temp_dir.name, hash_string(f"dep-{text}") + ".html"
        )
        # write_text closes the file; the previous open(...).write(...) left
        # the handle to be closed by the garbage collector.
        dep_output_path.write_text(dep_out, encoding="utf-8")

        ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
        ent_output_path = Path(
            self.temp_dir.name, hash_string(f"ent-{text}") + ".html"
        )
        ent_output_path.write_text(ent_out, encoding="utf-8")

        self.logger.report_media(
            "Dependencies Plot", text, local_path=dep_output_path
        )
        self.logger.report_media("Entities Plot", text, local_path=ent_output_path)

    return resp
|
||||
|
||||
@staticmethod
def _build_llm_df(
    base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping
) -> pd.DataFrame:
    """Select, clean and rename columns of ``base_df``.

    Only the requested fields that actually exist in ``base_df`` are kept,
    columns containing missing values are dropped, and the surviving
    entries of ``rename_map`` are applied to the column names.
    """
    present = [name for name in base_df_fields if name in base_df]
    renames = {old: new for old, new in rename_map.items() if old in present}
    selected = base_df[present].dropna(axis=1)
    return selected.rename(renames, axis=1) if renames else selected
|
||||
|
||||
def _create_session_analysis_df(self) -> Any:
    """Create a dataframe with all the information from the session.

    Two frames are derived from the ``on_llm_end`` records (a prompt side
    and an output side) and concatenated column-wise.
    """
    pd = import_pandas()
    end_records_df = pd.DataFrame(self.on_llm_end_records)

    # NOTE(review): the prompt-side frame is built from the on_llm_end
    # records; "prompts" is written by on_llm_start, so this column may
    # never be present here. Confirm whether on_llm_start_records was meant.
    identifier_field = "name" if "name" in end_records_df else "id"
    prompts_df = ClearMLCallbackHandler._build_llm_df(
        base_df=end_records_df,
        base_df_fields=["step", "prompts", identifier_field],
        rename_map={"step": "prompt_step"},
    )

    metric_columns: List = []
    if self.complexity_metrics:
        metric_columns = [
            "flesch_reading_ease",
            "flesch_kincaid_grade",
            "smog_index",
            "coleman_liau_index",
            "automated_readability_index",
            "dale_chall_readability_score",
            "difficult_words",
            "linsear_write_formula",
            "gunning_fog",
            "text_standard",
            "fernandez_huerta",
            "szigriszt_pazos",
            "gutierrez_polini",
            "crawford",
            "gulpease_index",
            "osman",
        ]

    output_fields = [
        "step",
        "text",
        "token_usage_total_tokens",
        "token_usage_prompt_tokens",
        "token_usage_completion_tokens",
    ]
    output_fields.extend(metric_columns)
    outputs_df = ClearMLCallbackHandler._build_llm_df(
        end_records_df,
        output_fields,
        {"step": "output_step", "text": "output"},
    )

    return pd.concat([prompts_df, outputs_df], axis=1)
|
||||
|
||||
def flush_tracker(
    self,
    name: Optional[str] = None,
    langchain_asset: Any = None,
    finish: bool = False,
) -> None:
    """Flush the tracker and setup the session.

    Everything after this will be a new table.

    Args:
        name: Name of the performed session so far so it is identifiable
        langchain_asset: The langchain asset to save.
        finish: Whether to finish the run.

    Returns:
        None
    """
    pd = import_pandas()
    clearml = import_clearml()

    # Log the action records
    self.logger.report_table(
        "Action Records", name, table_plot=pd.DataFrame(self.action_records)
    )

    # Session analysis
    session_analysis_df = self._create_session_analysis_df()
    self.logger.report_table(
        "Session Analysis", name, table_plot=session_analysis_df
    )

    if self.stream_logs:
        self.logger.report_text(
            {
                "action_records": pd.DataFrame(self.action_records),
                "session_analysis": session_analysis_df,
            }
        )

    if langchain_asset:
        langchain_asset_path = Path(self.temp_dir.name, "model.json")

        def _attach_saved_asset() -> None:
            # Create output model and connect it to the task
            output_model = clearml.OutputModel(
                task=self.task, config_text=load_json(langchain_asset_path)
            )
            output_model.update_weights(
                weights_filename=str(langchain_asset_path),
                auto_delete_file=False,
                target_filename=name,
            )

        try:
            langchain_asset.save(langchain_asset_path)
            _attach_saved_asset()
        except ValueError:
            # NOTE(review): an exception raised by this fallback is not
            # handled by the NotImplementedError clause below.
            langchain_asset.save_agent(langchain_asset_path)
            _attach_saved_asset()
        except NotImplementedError as e:
            print("Could not save model.")  # noqa: T201
            print(repr(e))  # noqa: T201

    # Cleanup after adding everything to ClearML
    self.task.flush(wait_for_uploads=True)
    self.temp_dir.cleanup()
    self.temp_dir = tempfile.TemporaryDirectory()
    self.reset_callback_meta()

    if finish:
        self.task.close()
|
||||
@@ -0,0 +1,639 @@
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
from langchain_core.utils import guard_import
|
||||
|
||||
import langchain_community
|
||||
from langchain_community.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
flatten_dict,
|
||||
import_pandas,
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
)
|
||||
|
||||
# Model name used when logging a LangChain asset from a handler that was
# created without an explicit ``name``.
LANGCHAIN_MODEL_NAME = "langchain-model"
|
||||
|
||||
|
||||
def import_comet_ml() -> Any:
    """Return the ``comet_ml`` module, imported through ``guard_import``."""
    comet_ml_module = guard_import("comet_ml")
    return comet_ml_module
|
||||
|
||||
|
||||
def _get_experiment(
    workspace: Optional[str] = None, project_name: Optional[str] = None
) -> Any:
    """Create a new ``comet_ml.Experiment`` for the given workspace/project."""
    return import_comet_ml().Experiment(
        workspace=workspace,
        project_name=project_name,
    )
|
||||
|
||||
|
||||
def _fetch_text_complexity_metrics(text: str) -> dict:
    """Compute the textstat readability scores for ``text``.

    Each key of the returned dict is the name of the textstat function that
    produced its value.
    """
    textstat = import_textstat()
    metric_names = (
        "flesch_reading_ease",
        "flesch_kincaid_grade",
        "smog_index",
        "coleman_liau_index",
        "automated_readability_index",
        "dale_chall_readability_score",
        "difficult_words",
        "linsear_write_formula",
        "gunning_fog",
        "text_standard",
        "fernandez_huerta",
        "szigriszt_pazos",
        "gutierrez_polini",
        "crawford",
        "gulpease_index",
        "osman",
    )
    return {name: getattr(textstat, name)(text) for name in metric_names}
|
||||
|
||||
|
||||
def _summarize_metrics_for_generated_outputs(metrics: Sequence) -> dict:
    """Return ``DataFrame.describe()`` of the metric records as a dict."""
    pd = import_pandas()
    return pd.DataFrame(metrics).describe().to_dict()
|
||||
|
||||
|
||||
class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
"""Callback Handler that logs to Comet.
|
||||
|
||||
Parameters:
    task_type (str): The type of comet_ml task such as "inference",
        "testing" or "qc"
    workspace (str): The comet_ml workspace name
    project_name (str): The comet_ml project name
    tags (list): Tags to add to the experiment
    name (str): Name of the comet_ml experiment
    visualizations (list): spaCy displacy styles to render for each
        generated output
    complexity_metrics (bool): Whether to log complexity metrics
    custom_metrics (Callable): Function called with
        (generation, prompt_idx, gen_idx) that returns a dict of metrics
    stream_logs (bool): Whether to stream callback actions to Comet
|
||||
|
||||
This handler will utilize the associated callback method and formats
|
||||
the input of each callback function with metadata regarding the state of LLM run,
|
||||
and adds the response to the list of records for both the {method}_records and
|
||||
action. It then logs the response to Comet.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    task_type: Optional[str] = "inference",
    workspace: Optional[str] = None,
    project_name: Optional[str] = None,
    tags: Optional[Sequence] = None,
    name: Optional[str] = None,
    visualizations: Optional[List[str]] = None,
    complexity_metrics: bool = False,
    custom_metrics: Optional[Callable] = None,
    stream_logs: bool = True,
) -> None:
    """Initialize callback handler.

    Stores the configuration, creates a temporary directory and a new
    Comet experiment (tagged and named when ``tags`` / ``name`` are given),
    emits the beta warning, and loads the spaCy ``en_core_web_sm`` pipeline
    only when visualizations were requested.
    """
    self.comet_ml = import_comet_ml()
    super().__init__()

    self.task_type = task_type
    self.workspace = workspace
    self.project_name = project_name
    self.tags = tags
    self.visualizations = visualizations
    self.complexity_metrics = complexity_metrics
    self.custom_metrics = custom_metrics
    self.stream_logs = stream_logs
    self.temp_dir = tempfile.TemporaryDirectory()

    experiment = _get_experiment(workspace, project_name)
    self.experiment = experiment
    experiment.log_other("Created from", "langchain")
    if tags:
        experiment.add_tags(tags)
    self.name = name
    if name:
        experiment.set_name(name)

    self.comet_ml.LOGGER.warning(
        "The comet_ml callback is currently in beta and is subject to change "
        "based on updates to `langchain`. Please report any issues to "
        "https://github.com/comet-ml/issue-tracking/issues with the tag "
        "`langchain`."
    )

    self.callback_columns: list = []
    self.action_records: list = []
    if visualizations:
        self.nlp = import_spacy().load("en_core_web_sm")
    else:
        self.nlp = None
|
||||
|
||||
def _init_resp(self) -> Dict:
    """Return a fresh record with every callback column set to ``None``."""
    return dict.fromkeys(self.callback_columns)
|
||||
|
||||
def on_llm_start(
    self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
    """Record one entry per prompt when an LLM call starts.

    When ``stream_logs`` is set, each prompt is also streamed to the
    experiment together with the shared (prompt-less) metadata.
    """
    self.step += 1
    self.llm_starts += 1
    self.starts += 1

    metadata = self._init_resp()
    metadata.update({"action": "on_llm_start"})
    metadata.update(flatten_dict(serialized))
    metadata.update(self.get_custom_callback_meta())

    for prompt in prompts:
        record = deepcopy(metadata)
        record["prompts"] = prompt
        for sink in (self.on_llm_start_records, self.action_records):
            sink.append(record)

        if self.stream_logs:
            self._log_stream(prompt, metadata, self.step)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
    """Record a newly streamed LLM token in the action records."""
    self.step += 1
    self.llm_streams += 1

    record = self._init_resp()
    record.update(action="on_llm_new_token", token=token)
    record.update(self.get_custom_callback_meta())
    self.action_records.append(record)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
    """Record every generation of a finished LLM call and log its metrics.

    Each generation gets a record made of the shared metadata, its
    flattened fields and any complexity/custom metrics. The metrics
    collected over all generations are summarized and logged once at the
    end.
    """
    self.step += 1
    self.llm_ends += 1
    self.ends += 1

    metadata = self._init_resp()
    metadata.update({"action": "on_llm_end"})
    metadata.update(flatten_dict(response.llm_output or {}))
    metadata.update(self.get_custom_callback_meta())

    collected_complexity = []
    collected_custom = []

    for prompt_idx, generations in enumerate(response.generations):
        for gen_idx, generation in enumerate(generations):
            text = generation.text

            record = deepcopy(metadata)
            record.update(flatten_dict(generation.dict()))

            complexity = self._get_complexity_metrics(text)
            if complexity:
                collected_complexity.append(complexity)
                record.update(complexity)

            custom = self._get_custom_metrics(generation, prompt_idx, gen_idx)
            if custom:
                collected_custom.append(custom)
                record.update(custom)

            if self.stream_logs:
                self._log_stream(text, metadata, self.step)

            self.action_records.append(record)
            self.on_llm_end_records.append(record)

    for collected in (collected_complexity, collected_custom):
        self._log_text_metrics(collected, step=self.step)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
    """Count an LLM failure; the error itself is not recorded."""
    self.step = self.step + 1
    self.errors = self.errors + 1
|
||||
|
||||
def on_chain_start(
    self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
    """Record every string-valued chain input.

    Non-string input values are skipped with a warning on the comet_ml
    logger.
    """
    self.step += 1
    self.chain_starts += 1
    self.starts += 1

    resp = self._init_resp()
    resp.update({"action": "on_chain_start"})
    resp.update(flatten_dict(serialized))
    resp.update(self.get_custom_callback_meta())

    for chain_input_key, chain_input_val in inputs.items():
        if not isinstance(chain_input_val, str):
            self.comet_ml.LOGGER.warning(
                f"Unexpected data format provided! "
                f"Input Value for {chain_input_key} will not be logged"
            )
            continue

        record = deepcopy(resp)
        if self.stream_logs:
            # The shared metadata (without the input value) is streamed.
            self._log_stream(chain_input_val, resp, self.step)
        record.update({chain_input_key: chain_input_val})
        self.action_records.append(record)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
    """Record every string-valued chain output.

    Non-string output values are skipped with a warning on the comet_ml
    logger.
    """
    self.step += 1
    self.chain_ends += 1
    self.ends += 1

    resp = self._init_resp()
    resp.update({"action": "on_chain_end"})
    resp.update(self.get_custom_callback_meta())

    for chain_output_key, chain_output_val in outputs.items():
        if not isinstance(chain_output_val, str):
            self.comet_ml.LOGGER.warning(
                f"Unexpected data format provided! "
                f"Output Value for {chain_output_key} will not be logged"
            )
            continue

        record = deepcopy(resp)
        if self.stream_logs:
            # The shared metadata (without the output value) is streamed.
            self._log_stream(chain_output_val, resp, self.step)
        record.update({chain_output_key: chain_output_val})
        self.action_records.append(record)
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
    """Count a chain failure; the error itself is not recorded."""
    self.step = self.step + 1
    self.errors = self.errors + 1
|
||||
|
||||
def on_tool_start(
    self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
    """Record the start of a tool run together with its input string."""
    self.step += 1
    self.tool_starts += 1
    self.starts += 1

    record = self._init_resp()
    record.update({"action": "on_tool_start"})
    record.update(flatten_dict(serialized))
    record.update(self.get_custom_callback_meta())

    # Stream first: the input string is added to the record only afterwards.
    if self.stream_logs:
        self._log_stream(input_str, record, self.step)

    record["input_str"] = input_str
    self.action_records.append(record)
|
||||
|
||||
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
    """Record the end of a tool run; the output is stored as ``str(output)``."""
    output_text = str(output)
    self.step += 1
    self.tool_ends += 1
    self.ends += 1

    record = self._init_resp()
    record.update({"action": "on_tool_end"})
    record.update(self.get_custom_callback_meta())

    # Stream first: the output is added to the record only afterwards.
    if self.stream_logs:
        self._log_stream(output_text, record, self.step)

    record.update({"output": output_text})
    self.action_records.append(record)
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
    """Count a tool failure; the error itself is not recorded."""
    self.step = self.step + 1
    self.errors = self.errors + 1
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
    """Record an arbitrary piece of text passed to the callback."""
    self.step += 1
    self.text_ctr += 1

    record = self._init_resp()
    record.update({"action": "on_text"})
    record.update(self.get_custom_callback_meta())

    # Stream first: the text is added to the record only afterwards.
    if self.stream_logs:
        self._log_stream(text, record, self.step)

    record.update({"text": text})
    self.action_records.append(record)
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
    """Record the final answer of an agent run.

    Reads ``finish.return_values["output"]`` and ``finish.log``; a missing
    ``"output"`` key propagates as ``KeyError``.
    """
    self.step += 1
    self.agent_ends += 1
    self.ends += 1

    record = self._init_resp()
    final_output = finish.return_values["output"]

    record.update({"action": "on_agent_finish", "log": finish.log})
    record.update(self.get_custom_callback_meta())

    # Stream first: the output is added to the record only afterwards.
    if self.stream_logs:
        self._log_stream(final_output, record, self.step)

    record.update({"output": final_output})
    self.action_records.append(record)
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
    """Record an agent action; the tool input is stored as a string.

    Note that this advances the ``tool_starts`` and ``starts`` counters.
    """
    self.step += 1
    self.tool_starts += 1
    self.starts += 1

    tool_name = action.tool
    tool_input_text = str(action.tool_input)
    action_log = action.log

    record = self._init_resp()
    record.update(
        {"action": "on_agent_action", "log": action_log, "tool": tool_name}
    )
    record.update(self.get_custom_callback_meta())

    # Stream first: the tool input is added to the record only afterwards.
    if self.stream_logs:
        self._log_stream(tool_input_text, record, self.step)

    record.update({"tool_input": tool_input_text})
    self.action_records.append(record)
|
||||
|
||||
def _get_complexity_metrics(self, text: str) -> dict:
    """Compute text complexity metrics using textstat.

    Parameters:
        text (str): The text to analyze.

    Returns:
        `dict` containing the complexity metrics; empty when
        ``self.complexity_metrics`` is falsy.
    """
    if not self.complexity_metrics:
        return {}

    collected = {}
    collected.update(_fetch_text_complexity_metrics(text))
    return collected
|
||||
|
||||
def _get_custom_metrics(
    self, generation: Generation, prompt_idx: int, gen_idx: int
) -> dict:
    """Compute Custom Metrics for an LLM Generated Output

    Args:
        generation (Generation): Output generation from an LLM
        prompt_idx (int): List index of the input prompt
        gen_idx (int): List index of the generated output

    Returns:
        dict: `dict` containing the custom metrics; empty when no
            ``custom_metrics`` callable is configured.
    """
    if not self.custom_metrics:
        return {}

    collected = {}
    collected.update(self.custom_metrics(generation, prompt_idx, gen_idx))
    return collected
|
||||
|
||||
def flush_tracker(
    self,
    langchain_asset: Any = None,
    task_type: Optional[str] = "inference",
    workspace: Optional[str] = None,
    project_name: Optional[str] = "comet-langchain-demo",
    tags: Optional[Sequence] = None,
    name: Optional[str] = None,
    visualizations: Optional[List[str]] = None,
    complexity_metrics: bool = False,
    custom_metrics: Optional[Callable] = None,
    finish: bool = False,
    reset: bool = False,
) -> None:
    """Flush the tracker and setup the session.

    Everything after this will be a new table.

    Args:
        langchain_asset: The langchain asset to save.
        task_type, workspace, project_name, tags, name, visualizations,
            complexity_metrics, custom_metrics: Settings forwarded to the
            handler reset; only used when ``reset`` is true.
        finish: Whether to finish the run.
        reset: Whether to re-initialize the handler afterwards.

    Returns:
        None
    """
    self._log_session(langchain_asset)

    if langchain_asset:
        try:
            self._log_model(langchain_asset)
        except Exception:
            # Exporting the asset is best-effort; report and carry on.
            self.comet_ml.LOGGER.error(
                "Failed to export agent or LLM to Comet",
                exc_info=True,
                extra={"show_traceback": True},
            )

    if finish:
        self.experiment.end()

    if not reset:
        return

    self._reset(
        task_type,
        workspace,
        project_name,
        tags,
        name,
        visualizations,
        complexity_metrics,
        custom_metrics,
    )
|
||||
|
||||
def _log_stream(self, prompt: str, metadata: dict, step: int) -> None:
    """Send ``prompt`` to the experiment's ``log_text`` with its metadata and step."""
    self.experiment.log_text(prompt, metadata=metadata, step=step)
|
||||
|
||||
def _log_model(self, langchain_asset: Any) -> None:
    """Log the asset's parameters and a serialized copy of it to Comet.

    The asset is written to ``model.json`` in the handler's temporary
    directory with ``save``; if that is unavailable or raises
    ``ValueError``/``AttributeError``/``NotImplementedError``, ``save_agent``
    is tried instead. When neither works, an error is logged.

    Previously an asset without a ``save`` attribute raised nothing, so the
    ``save_agent`` fallback was never reached and nothing was reported.
    """
    model_parameters = self._get_llm_parameters(langchain_asset)
    self.experiment.log_parameters(model_parameters, prefix="model")

    langchain_asset_path = Path(self.temp_dir.name, "model.json")
    model_name = self.name if self.name else LANGCHAIN_MODEL_NAME

    save_error = None
    if hasattr(langchain_asset, "save"):
        try:
            langchain_asset.save(langchain_asset_path)
            self.experiment.log_model(model_name, str(langchain_asset_path))
        except (ValueError, AttributeError, NotImplementedError) as e:
            save_error = e
        else:
            return

    if hasattr(langchain_asset, "save_agent"):
        # Errors raised here propagate to the caller, as before.
        langchain_asset.save_agent(langchain_asset_path)
        self.experiment.log_model(model_name, str(langchain_asset_path))
        return

    if save_error is not None:
        reason = f"{save_error}"
    else:
        reason = "Asset has neither `save` nor `save_agent`."
    self.comet_ml.LOGGER.error(
        f"{reason}"
        " Could not save Langchain Asset "
        f"for {langchain_asset.__class__.__name__}"
    )
|
||||
|
||||
def _log_session(self, langchain_asset: Optional[Any] = None) -> None:
    """Log the session table, the raw action records and the visualizations.

    Each step is best-effort: a failure is reported as a warning on the
    comet_ml logger and the following steps still run. Visualizations are
    skipped when the session dataframe could not be built, instead of
    failing on an unbound variable and emitting a second, misleading
    warning.
    """
    llm_session_df = None
    try:
        llm_session_df = self._create_session_analysis_dataframe(langchain_asset)
        # Log the cleaned dataframe as a table
        self.experiment.log_table("langchain-llm-session.csv", llm_session_df)
    except Exception:
        self.comet_ml.LOGGER.warning(
            "Failed to log session data to Comet",
            exc_info=True,
            extra={"show_traceback": True},
        )

    try:
        metadata = {"langchain_version": str(langchain_community.__version__)}
        # Log the langchain low-level records as a JSON file directly
        self.experiment.log_asset_data(
            self.action_records, "langchain-action_records.json", metadata=metadata
        )
    except Exception:
        self.comet_ml.LOGGER.warning(
            "Failed to log session data to Comet",
            exc_info=True,
            extra={"show_traceback": True},
        )

    if llm_session_df is None:
        # Building the dataframe failed and was already reported above.
        return

    try:
        self._log_visualizations(llm_session_df)
    except Exception:
        self.comet_ml.LOGGER.warning(
            "Failed to log visualizations to Comet",
            exc_info=True,
            extra={"show_traceback": True},
        )
|
||||
|
||||
def _log_text_metrics(self, metrics: Sequence[dict], step: int) -> None:
    """Summarize the metric records and log each summary entry to Comet.

    Nothing is logged when ``metrics`` is empty. Each key of the summary is
    used as the prefix of its own ``log_metrics`` call.
    """
    if metrics:
        summary = _summarize_metrics_for_generated_outputs(metrics)
        for prefix, values in summary.items():
            self.experiment.log_metrics(values, prefix=prefix, step=step)
|
||||
|
||||
def _log_visualizations(self, session_df: Any) -> None:
    """Render and upload spaCy visualizations of every generated output.

    Does nothing unless both ``self.visualizations`` and ``self.nlp`` are
    set. For each prompt/output pair of ``session_df`` and each requested
    style, the rendered HTML is logged as an asset; a failure for one style
    is reported as a warning and does not stop the others.
    """
    if not self.visualizations or not self.nlp:
        return

    spacy = import_spacy()

    prompts = session_df["prompts"].tolist()
    outputs = session_df["text"].tolist()

    for idx, (prompt, output) in enumerate(zip(prompts, outputs)):
        sentence_spans = list(self.nlp(output).sents)

        for visualization in self.visualizations:
            try:
                html = spacy.displacy.render(
                    sentence_spans,
                    style=visualization,
                    options={"compact": True},
                    jupyter=False,
                    page=True,
                )
                self.experiment.log_asset_data(
                    html,
                    name=f"langchain-viz-{visualization}-{idx}.html",
                    metadata={"prompt": prompt},
                    step=idx,
                )
            except Exception as e:
                self.comet_ml.LOGGER.warning(
                    e, exc_info=True, extra={"show_traceback": True}
                )
|
||||
|
||||
def _reset(
|
||||
self,
|
||||
task_type: Optional[str] = None,
|
||||
workspace: Optional[str] = None,
|
||||
project_name: Optional[str] = None,
|
||||
tags: Optional[Sequence] = None,
|
||||
name: Optional[str] = None,
|
||||
visualizations: Optional[List[str]] = None,
|
||||
complexity_metrics: bool = False,
|
||||
custom_metrics: Optional[Callable] = None,
|
||||
) -> None:
|
||||
_task_type = task_type if task_type else self.task_type
|
||||
_workspace = workspace if workspace else self.workspace
|
||||
_project_name = project_name if project_name else self.project_name
|
||||
_tags = tags if tags else self.tags
|
||||
_name = name if name else self.name
|
||||
_visualizations = visualizations if visualizations else self.visualizations
|
||||
_complexity_metrics = (
|
||||
complexity_metrics if complexity_metrics else self.complexity_metrics
|
||||
)
|
||||
_custom_metrics = custom_metrics if custom_metrics else self.custom_metrics
|
||||
|
||||
self.__init__( # type: ignore[misc]
|
||||
task_type=_task_type,
|
||||
workspace=_workspace,
|
||||
project_name=_project_name,
|
||||
tags=_tags,
|
||||
name=_name,
|
||||
visualizations=_visualizations,
|
||||
complexity_metrics=_complexity_metrics,
|
||||
custom_metrics=_custom_metrics,
|
||||
)
|
||||
|
||||
self.reset_callback_meta()
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
def _create_session_analysis_dataframe(self, langchain_asset: Any = None) -> dict:
    """Join the recorded LLM start and end records into one table.

    NOTE(review): the annotation says ``dict`` but the value returned is
    the result of ``pandas.merge`` -- confirm and fix the annotation.

    Args:
        langchain_asset: Optional object whose LLM parameters are inspected
            for ``"n"``, the number of generations per prompt (default 1).

    Returns:
        The index-aligned merge of start and end records; overlapping
        column names get ``_llm_start`` / ``_llm_end`` suffixes.
    """
    pd = import_pandas()

    generations_per_prompt = self._get_llm_parameters(langchain_asset).get("n", 1)

    start_df = pd.DataFrame(self.on_llm_start_records)
    # Repeat each input row based on the number of outputs generated per
    # prompt so that start rows line up with end rows by position.
    repeated_index = start_df.index.repeat(generations_per_prompt)
    start_df = start_df.loc[repeated_index].reset_index(drop=True)

    end_df = pd.DataFrame(self.on_llm_end_records)

    return pd.merge(
        start_df,
        end_df,
        left_index=True,
        right_index=True,
        suffixes=["_llm_start", "_llm_end"],
    )
|
||||
|
||||
def _get_llm_parameters(self, langchain_asset: Any = None) -> dict:
|
||||
if not langchain_asset:
|
||||
return {}
|
||||
try:
|
||||
if hasattr(langchain_asset, "agent"):
|
||||
llm_parameters = langchain_asset.agent.llm_chain.llm.dict()
|
||||
elif hasattr(langchain_asset, "llm_chain"):
|
||||
llm_parameters = langchain_asset.llm_chain.llm.dict()
|
||||
elif hasattr(langchain_asset, "llm"):
|
||||
llm_parameters = langchain_asset.llm.dict()
|
||||
else:
|
||||
llm_parameters = langchain_asset.dict()
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
return llm_parameters
|
||||
@@ -0,0 +1,183 @@
|
||||
# flake8: noqa
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
|
||||
class DeepEvalCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs into deepeval.

    Args:
        implementation_name: name of the `implementation` in deepeval
        metrics: A list of metrics

    Raises:
        ImportError: if the `deepeval` package is not installed.

    Examples:
        >>> from langchain_community.llms import OpenAI
        >>> from langchain_community.callbacks import DeepEvalCallbackHandler
        >>> from deepeval.metrics import AnswerRelevancy
        >>> metric = AnswerRelevancy(minimum_score=0.3)
        >>> deepeval_callback = DeepEvalCallbackHandler(
        ...     implementation_name="exampleImplementation",
        ...     metrics=[metric],
        ... )
        >>> llm = OpenAI(
        ...     temperature=0,
        ...     callbacks=[deepeval_callback],
        ...     verbose=True,
        ...     openai_api_key="API_KEY_HERE",
        ... )
        >>> llm.generate([
        ...     "What is the best evaluation tool out there? (no bias at all)",
        ... ])
        "Deepeval, no doubt about it."
    """

    REPO_URL: str = "https://github.com/confident-ai/deepeval"
    ISSUES_URL: str = f"{REPO_URL}/issues"
    BLOG_URL: str = "https://docs.confident-ai.com"  # noqa: E501

    def __init__(
        self,
        metrics: List[Any],
        implementation_name: Optional[str] = None,
    ) -> None:
        """Initializes the `deepevalCallbackHandler`.

        Args:
            metrics: What metrics do you want to track?
            implementation_name: Name of the implementation you want.

        Raises:
            ImportError: if the `deepeval` package is not installed.
        """

        super().__init__()

        # Import deepeval (not via `import_deepeval` to keep hints in IDEs)
        try:
            import deepeval  # ignore: F401,I001
        except ImportError as err:
            raise ImportError(
                "To use the deepeval callback manager you need to have the\n"
                "`deepeval` Python package installed. Please install it with\n"
                "`pip install deepeval`"
            ) from err

        # NOTE(review): this warns when the `.deepeval` file EXISTS, which
        # looks inverted for a "you are not logged in" message -- confirm
        # against deepeval's login behaviour before changing it.
        if os.path.exists(".deepeval"):
            warnings.warn(
                "You are currently not logging anything to the dashboard, we\n"
                "recommend using `deepeval login`."
            )

        # Set the deepeval variables
        self.implementation_name = implementation_name
        self.metrics = metrics
        # Prompts of the most recent `on_llm_start`; read by `on_llm_end`.
        self.prompts: List[str] = []

        warnings.warn(
            (
                "The `DeepEvalCallbackHandler` is currently in beta and is subject to"
                " change based on updates to `langchain`. Please report any issues to"
                f" {self.ISSUES_URL} as an `integration` issue."
            ),
        )

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Store the prompts"""
        self.prompts = prompts

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing when a new token is generated."""

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Log records to deepeval when an LLM ends.

        Each configured metric is measured against the first generation of
        every prompt and the score is printed.

        Raises:
            ValueError: if a configured metric is not one of the supported
                deepeval metric types.
        """
        from deepeval.metrics.answer_relevancy import AnswerRelevancy
        from deepeval.metrics.bias_classifier import UnBiasedMetric
        from deepeval.metrics.toxic_classifier import NonToxicMetric

        for metric in self.metrics:
            for i, generation in enumerate(response.generations):
                if not generation:
                    # No candidates were produced for this prompt.
                    continue
                # Here, we only measure the first generation's output
                output = generation[0].text
                query = self.prompts[i]
                if isinstance(metric, AnswerRelevancy):
                    result = metric.measure(
                        output=output,
                        query=query,
                    )
                    print(f"Answer Relevancy: {result}")  # noqa: T201
                elif isinstance(metric, UnBiasedMetric):
                    score = metric.measure(output)
                    print(f"Bias Score: {score}")  # noqa: T201
                elif isinstance(metric, NonToxicMetric):
                    score = metric.measure(output)
                    print(f"Toxic Score: {score}")  # noqa: T201
                else:
                    # `metric` is an instance; instances have no `__name__`,
                    # so the class name must be taken from its type.
                    raise ValueError(
                        f"Metric {type(metric).__name__} is not supported by"
                        " deepeval\ncallbacks."
                    )

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when LLM outputs an error."""

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Do nothing when chain starts"""

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Do nothing when chain ends."""

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when LLM chain outputs an error."""

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool starts."""

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Do nothing when agent takes a specific action."""

    def on_tool_end(
        self,
        output: Any,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool ends."""

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when tool outputs an error."""

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Do nothing"""

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Do nothing"""
|
||||
@@ -0,0 +1,192 @@
|
||||
"""Callback handler for Context AI"""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.utils import guard_import
|
||||
|
||||
|
||||
def import_context() -> Any:
    """Import the `getcontext` package and the names the handler uses.

    Returns:
        A 6-tuple: the ``getcontext`` module, ``Credential`` from
        ``getcontext.token``, and ``Conversation``, ``Message``,
        ``MessageRole`` and ``Rating`` from ``getcontext.generated.models``.
    """
    # The distribution that provides `getcontext` is named `context-python`
    # (this is also what the handler's docstring tells users to install);
    # the name is only used for the install hint given by `guard_import`.
    pip_name = "context-python"
    models_module = "getcontext.generated.models"

    def _model(attribute: str) -> Any:
        """Fetch one attribute from the generated models module."""
        return getattr(guard_import(models_module, pip_name=pip_name), attribute)

    return (
        guard_import("getcontext", pip_name=pip_name),
        guard_import("getcontext.token", pip_name=pip_name).Credential,
        _model("Conversation"),
        _model("Message"),
        _model("MessageRole"),
        _model("Rating"),
    )
|
||||
|
||||
|
||||
class ContextCallbackHandler(BaseCallbackHandler):
    """Callback Handler that records transcripts to the Context service.

     (https://context.ai).

    Keyword Args:
        token (optional): The token with which to authenticate requests to Context.
            Visit https://with.context.ai/settings to generate a token.
            If not provided, the value of the `CONTEXT_TOKEN` environment
            variable will be used.

    Raises:
        ImportError: if the `context-python` package is not installed.

    Chat Example:
        >>> from langchain_community.chat_models import ChatOpenAI
        >>> from langchain_community.callbacks import ContextCallbackHandler
        >>> context_callback = ContextCallbackHandler(
        ...     token="<CONTEXT_TOKEN_HERE>",
        ... )
        >>> chat = ChatOpenAI(
        ...     temperature=0,
        ...     headers={"user_id": "123"},
        ...     callbacks=[context_callback],
        ...     openai_api_key="API_KEY_HERE",
        ... )
        >>> messages = [
        ...     SystemMessage(content="You translate English to French."),
        ...     HumanMessage(content="I love programming with LangChain."),
        ... ]
        >>> chat.invoke(messages)

    Chain Example:
        >>> from langchain_classic.chains import LLMChain
        >>> from langchain_community.chat_models import ChatOpenAI
        >>> from langchain_community.callbacks import ContextCallbackHandler
        >>> context_callback = ContextCallbackHandler(
        ...     token="<CONTEXT_TOKEN_HERE>",
        ... )
        >>> human_message_prompt = HumanMessagePromptTemplate(
        ...     prompt=PromptTemplate(
        ...         template="What is a good name for a company that makes {product}?",
        ...         input_variables=["product"],
        ...    ),
        ... )
        >>> chat_prompt_template = ChatPromptTemplate.from_messages(
        ...   [human_message_prompt]
        ... )
        >>> callback = ContextCallbackHandler(token)
        >>> # Note: the same callback object must be shared between the
        ...   LLM and the chain.
        >>> chat = ChatOpenAI(temperature=0.9, callbacks=[callback])
        >>> chain = LLMChain(
        ...   llm=chat,
        ...   prompt=chat_prompt_template,
        ...   callbacks=[callback]
        ... )
        >>> chain.run("colorful socks")
    """

    def __init__(self, token: str = "", verbose: bool = False, **kwargs: Any) -> None:
        """Create the Context API client and reset the transcript state.

        Args:
            token: Context API token; falls back to the ``CONTEXT_TOKEN``
                environment variable, then to an empty string.
            verbose: Accepted for compatibility; not used by this handler.
            **kwargs: Accepted for compatibility; not used by this handler.
        """
        super().__init__()
        (
            self.context,
            self.credential,
            self.conversation_model,
            self.message_model,
            self.message_role_model,
            self.rating_model,
        ) = import_context()

        token = token or os.environ.get("CONTEXT_TOKEN") or ""

        self.client = self.context.ContextAPI(credential=self.credential(token))

        # Set while a chain is running so that `on_llm_end` leaves the final
        # assistant message and the upload to `on_chain_end`.
        self.chain_run_id = None

        self.llm_model = None

        # Transcript accumulated since the last upload.
        self.messages: List[Any] = []
        self.metadata: Dict[str, str] = {}

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> Any:
        """Run when the chat model is started.

        Records the model name (if present in ``invocation_params``) and
        appends the first message list to the pending transcript.
        """
        llm_model = kwargs.get("invocation_params", {}).get("model", None)
        if llm_model is not None:
            self.metadata["model"] = llm_model

        if not messages:
            return

        roles = self.message_role_model
        role_by_type = {
            "human": roles.USER,
            "system": roles.SYSTEM,
            "ai": roles.ASSISTANT,
        }
        for message in messages[0]:
            self.messages.append(
                self.message_model(
                    message=message.content,
                    # Any other message type is recorded as SYSTEM.
                    role=role_by_type.get(message.type, roles.SYSTEM),
                )
            )

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends.

        Outside of a chain run, appends the first generation as the
        assistant reply and uploads the conversation.
        """
        if not response.generations or not response.generations[0]:
            return

        if not self.chain_run_id:
            generation = response.generations[0][0]
            self.messages.append(
                self.message_model(
                    message=generation.text,
                    role=self.message_role_model.ASSISTANT,
                )
            )

            self._log_conversation()

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts."""
        self.chain_run_id = kwargs.get("run_id", None)

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends.

        Appends the chain's ``"text"`` output (when present) as the
        assistant reply, uploads the conversation and clears the chain
        marker.
        """
        # Not every chain produces a "text" output; a missing key must not
        # abort the upload or leave `chain_run_id` set.
        text = outputs.get("text")
        if text is not None:
            self.messages.append(
                self.message_model(
                    message=text,
                    role=self.message_role_model.ASSISTANT,
                )
            )

        self._log_conversation()

        self.chain_run_id = None

    def _log_conversation(self) -> None:
        """Log the conversation to the context API."""
        if not self.messages:
            return

        self.client.log.conversation_upsert(
            body={
                "conversation": self.conversation_model(
                    messages=self.messages,
                    metadata=self.metadata,
                )
            }
        )

        # Start a fresh transcript for the next conversation.
        self.messages = []
        self.metadata = {}
|
||||
@@ -0,0 +1,335 @@
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.utils import guard_import
|
||||
|
||||
from langchain_community.callbacks.utils import import_pandas
|
||||
|
||||
# --- Constants -------------------------------------------------------------

# Keys looked up on LLMResult.llm_output
TOKEN_USAGE = "token_usage"
TOTAL_TOKENS = "total_tokens"
PROMPT_TOKENS = "prompt_tokens"
COMPLETION_TOKENS = "completion_tokens"
RUN_ID = "run_id"
MODEL_NAME = "model_name"

# Feedback labels and LLM call status labels
GOOD = "good"
BAD = "bad"
NEUTRAL = "neutral"
SUCCESS = "success"
FAILURE = "failure"

# Upper bounds used when building the sample dataset below
DEFAULT_MAX_TOKEN = 65536
DEFAULT_MAX_DURATION = 120000

# Column names of the events published to Fiddler
PROMPT = "prompt"
RESPONSE = "response"
CONTEXT = "context"
DURATION = "duration"
FEEDBACK = "feedback"
LLM_STATUS = "llm_status"

FEEDBACK_POSSIBLE_VALUES = [GOOD, BAD, NEUTRAL]

# Ten-row sample dataset with one column per published field. Text columns
# hold placeholder values; numeric columns alternate between their minimum
# and maximum.
_dataset_dict = {
    PROMPT: ["fiddler"] * 10,
    RESPONSE: ["fiddler"] * 10,
    CONTEXT: ["fiddler"] * 10,
    FEEDBACK: [GOOD] * 10,
    LLM_STATUS: [SUCCESS] * 10,
    MODEL_NAME: ["fiddler"] * 10,
    RUN_ID: ["123e4567-e89b-12d3-a456-426614174000"] * 10,
    TOTAL_TOKENS: [0, DEFAULT_MAX_TOKEN] * 5,
    PROMPT_TOKENS: [0, DEFAULT_MAX_TOKEN] * 5,
    COMPLETION_TOKENS: [0, DEFAULT_MAX_TOKEN] * 5,
    DURATION: [1, DEFAULT_MAX_DURATION] * 5,
}
|
||||
|
||||
|
||||
def import_fiddler() -> Any:
    """Return the ``fiddler`` module, imported through ``guard_import``.

    ``fiddler-client`` is passed as the pip package name for the import
    guard.
    """
    fiddler_module = guard_import("fiddler", pip_name="fiddler-client")
    return fiddler_module
|
||||
|
||||
|
||||
# First, define custom callback handler implementations
|
||||
class FiddlerCallbackHandler(BaseCallbackHandler):
    """Callback handler that publishes LLM prompt/response events to Fiddler.

    On construction the handler makes sure the target project, a sample
    dataset and the model exist in Fiddler, creating whichever is missing.
    Afterwards every LLM run is published as one event per prompt.
    """

    def __init__(
        self,
        url: str,
        org: str,
        project: str,
        model: str,
        api_key: str,
    ) -> None:
        """
        Initialize Fiddler callback handler.

        Args:
            url: Fiddler URL (e.g. https://demo.fiddler.ai).
                Make sure to include the protocol (http/https).
            org: Fiddler organization id
            project: Fiddler project name to publish events to
            model: Fiddler model name to publish events to
            api_key: Fiddler authentication token

        Raises:
            Exception: whatever the Fiddler client raises while adding the
                project, dataset or model is re-raised after being printed.
        """
        super().__init__()
        # Initialize Fiddler client and other necessary properties
        self.fdl = import_fiddler()
        self.pd = import_pandas()

        self.url = url
        self.org = org
        self.project = project
        self.model = model
        self.api_key = api_key
        self._df = self.pd.DataFrame(_dataset_dict)

        # Per-run bookkeeping, keyed by the run id passed to the callbacks.
        self.run_id_prompts: Dict[UUID, List[str]] = {}
        self.run_id_response: Dict[UUID, List[str]] = {}
        self.run_id_starttime: Dict[UUID, int] = {}

        # Initialize Fiddler client here
        self.fiddler_client = self.fdl.FiddlerApi(url, org_id=org, auth_token=api_key)

        if self.project not in self.fiddler_client.get_project_names():
            print(  # noqa: T201
                f"adding project {self.project}.This only has to be done once."
            )
            try:
                self.fiddler_client.add_project(self.project)
            except Exception as e:
                # Both fragments must be f-strings, otherwise `{e}` is
                # printed literally instead of the error.
                print(  # noqa: T201
                    f"Error adding project {self.project}:"
                    f"{e}. Fiddler integration will not work."
                )
                raise

        dataset_info = self.fdl.DatasetInfo.from_dataframe(
            self._df, max_inferred_cardinality=0
        )

        # Set feedback and status columns to categorical
        for column in dataset_info.columns:
            if column.name == FEEDBACK:
                column.data_type = self.fdl.DataType.CATEGORY
                column.possible_values = FEEDBACK_POSSIBLE_VALUES

            elif column.name == LLM_STATUS:
                column.data_type = self.fdl.DataType.CATEGORY
                column.possible_values = [SUCCESS, FAILURE]

        if self.model not in self.fiddler_client.get_model_names(self.project):
            # The sample dataset is stored under the same id as the model.
            if self.model not in self.fiddler_client.get_dataset_names(self.project):
                print(  # noqa: T201
                    f"adding dataset {self.model} to project {self.project}."
                    "This only has to be done once."
                )
                try:
                    self.fiddler_client.upload_dataset(
                        project_id=self.project,
                        dataset_id=self.model,
                        dataset={"train": self._df},
                        info=dataset_info,
                    )
                except Exception as e:
                    print(  # noqa: T201
                        f"Error adding dataset {self.model}: {e}."
                        "Fiddler integration will not work."
                    )
                    raise

            model_info = self.fdl.ModelInfo.from_dataset_info(
                dataset_info=dataset_info,
                dataset_id="train",
                model_task=self.fdl.ModelTask.LLM,
                features=[PROMPT, CONTEXT, RESPONSE],
                target=FEEDBACK,
                metadata_cols=[
                    RUN_ID,
                    TOTAL_TOKENS,
                    PROMPT_TOKENS,
                    COMPLETION_TOKENS,
                    MODEL_NAME,
                    DURATION,
                ],
                custom_features=self.custom_features,
            )
            print(  # noqa: T201
                f"adding model {self.model} to project {self.project}."
                "This only has to be done once."
            )
            try:
                self.fiddler_client.add_model(
                    project_id=self.project,
                    dataset_id=self.model,
                    model_id=self.model,
                    model_info=model_info,
                )
            except Exception as e:
                print(  # noqa: T201
                    f"Error adding model {self.model}: {e}."
                    "Fiddler integration will not work."
                )
                raise

    @property
    def custom_features(self) -> list:
        """
        Define custom features for the model to automatically enrich the data with.
        Here, we enable the following enrichments:
        - Automatic Embedding generation for prompt and response
        - Text Statistics such as:
            - Automated Readability Index
            - Coleman Liau Index
            - Dale Chall Readability Score
            - Difficult Words
            - Flesch Reading Ease
            - Flesch Kincaid Grade
            - Gunning Fog
            - Linsear Write Formula
        - PII - Personal Identifiable Information
        - Sentiment Analysis

        """

        return [
            self.fdl.Enrichment(
                name="Prompt Embedding",
                enrichment="embedding",
                columns=[PROMPT],
            ),
            self.fdl.TextEmbedding(
                name="Prompt CF",
                source_column=PROMPT,
                column="Prompt Embedding",
            ),
            self.fdl.Enrichment(
                name="Response Embedding",
                enrichment="embedding",
                columns=[RESPONSE],
            ),
            self.fdl.TextEmbedding(
                name="Response CF",
                source_column=RESPONSE,
                column="Response Embedding",
            ),
            self.fdl.Enrichment(
                name="Text Statistics",
                enrichment="textstat",
                columns=[PROMPT, RESPONSE],
                config={
                    "statistics": [
                        "automated_readability_index",
                        "coleman_liau_index",
                        "dale_chall_readability_score",
                        "difficult_words",
                        "flesch_reading_ease",
                        "flesch_kincaid_grade",
                        "gunning_fog",
                        "linsear_write_formula",
                    ]
                },
            ),
            self.fdl.Enrichment(
                name="PII",
                enrichment="pii",
                columns=[PROMPT, RESPONSE],
            ),
            self.fdl.Enrichment(
                name="Sentiment",
                enrichment="sentiment",
                columns=[PROMPT, RESPONSE],
            ),
        ]

    def _publish_events(
        self,
        run_id: UUID,
        prompt_responses: List[str],
        duration: int,
        llm_status: str,
        model_name: Optional[str] = "",
        token_usage_dict: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Publish events to fiddler

        Builds one row per prompt recorded for ``run_id`` and sends the rows
        as a batch (several rows) or as a single event (one row). Errors
        raised while publishing are printed and swallowed.

        Args:
            run_id: Run whose stored prompts are published.
            prompt_responses: One response per stored prompt.
            duration: Run duration, repeated on every row.
            llm_status: Status label, repeated on every row.
            model_name: Model name, repeated on every row.
            token_usage_dict: Extra columns; int values are repeated on
                every row, other values are assigned as given.
        """

        prompts = self.run_id_prompts[run_id]
        prompt_count = len(prompts)
        df = self.pd.DataFrame(
            {
                PROMPT: prompts,
                RESPONSE: prompt_responses,
                RUN_ID: [str(run_id)] * prompt_count,
                DURATION: [duration] * prompt_count,
                LLM_STATUS: [llm_status] * prompt_count,
                MODEL_NAME: [model_name] * prompt_count,
            }
        )

        if token_usage_dict:
            for key, value in token_usage_dict.items():
                df[key] = [value] * prompt_count if isinstance(value, int) else value

        try:
            if df.shape[0] > 1:
                self.fiddler_client.publish_events_batch(self.project, self.model, df)
            else:
                df_dict = df.to_dict(orient="records")
                self.fiddler_client.publish_event(
                    self.project, self.model, event=df_dict[0]
                )
        except Exception as e:
            # Best effort: a publishing failure must not break the LLM run.
            print(  # noqa: T201
                f"Error publishing events to fiddler: {e}. continuing..."
            )

    def _forget_run(self, run_id: UUID) -> None:
        """Drop the per-run bookkeeping so the dicts do not grow forever."""
        self.run_id_prompts.pop(run_id, None)
        self.run_id_starttime.pop(run_id, None)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        """Remember the prompts and the start time (ms) of this run."""
        run_id = kwargs[RUN_ID]
        self.run_id_prompts[run_id] = prompts
        self.run_id_starttime[run_id] = int(time.time() * 1000)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Publish the finished run to Fiddler with status ``success``."""
        flattened_llmresult = response.flatten()
        run_id = kwargs[RUN_ID]
        run_duration = int(time.time() * 1000) - self.run_id_starttime[run_id]
        model_name = ""
        token_usage_dict = {}

        if isinstance(response.llm_output, dict):
            token_usage_dict = {
                k: v
                for k, v in response.llm_output.items()
                if k in [TOTAL_TOKENS, PROMPT_TOKENS, COMPLETION_TOKENS]
            }
            model_name = response.llm_output.get(MODEL_NAME, "")

        prompt_responses = [
            llmresult.generations[0][0].text for llmresult in flattened_llmresult
        ]

        try:
            self._publish_events(
                run_id,
                prompt_responses,
                run_duration,
                SUCCESS,
                model_name,
                token_usage_dict,
            )
        finally:
            # The run is over either way; release its stored state.
            self._forget_run(run_id)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Publish the failed run to Fiddler with empty responses."""
        run_id = kwargs[RUN_ID]
        duration = int(time.time() * 1000) - self.run_id_starttime[run_id]

        try:
            self._publish_events(
                run_id, [""] * len(self.run_id_prompts[run_id]), duration, FAILURE
            )
        finally:
            self._forget_run(run_id)
|
||||
@@ -0,0 +1,364 @@
|
||||
"""FlyteKit callback handler."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.utils import guard_import
|
||||
|
||||
from langchain_community.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
flatten_dict,
|
||||
import_pandas,
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import flytekit
|
||||
from flytekitplugins.deck import renderer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def import_flytekit() -> Tuple[flytekit, renderer]:
    """Import flytekit and flytekitplugins-deck-standard.

    Returns:
        The ``flytekit`` module and the ``renderer`` attribute of
        ``flytekitplugins.deck``.
    """
    flytekit_module = guard_import("flytekit")
    deck_module = guard_import(
        "flytekitplugins.deck", pip_name="flytekitplugins-deck-standard"
    )
    return flytekit_module, deck_module.renderer
|
||||
|
||||
|
||||
def analyze_text(
    text: str,
    nlp: Any = None,
    textstat: Any = None,
) -> dict:
    """Analyze text using textstat and spacy.

    Parameters:
        text (str): The text to analyze.
        nlp (spacy.lang): The spacy language model to use for visualization.
            When ``None`` no visualizations are produced.
        textstat: Object providing the textstat metric functions. When
            ``None`` no complexity metrics are computed.

    Returns:
        `dict` containing the complexity metrics (both nested under
        ``"text_complexity_metrics"`` and as top-level keys) and the
        ``"dependency_tree"`` / ``"entities"`` visualizations serialized to
        HTML strings.
    """
    result: Dict[str, Any] = {}

    if textstat is not None:
        # Each name is both the textstat function called and the result key.
        metric_names = (
            "flesch_reading_ease",
            "flesch_kincaid_grade",
            "smog_index",
            "coleman_liau_index",
            "automated_readability_index",
            "dale_chall_readability_score",
            "difficult_words",
            "linsear_write_formula",
            "gunning_fog",
            "fernandez_huerta",
            "szigriszt_pazos",
            "gutierrez_polini",
            "crawford",
            "gulpease_index",
            "osman",
        )
        complexity = {name: getattr(textstat, name)(text) for name in metric_names}
        result["text_complexity_metrics"] = complexity
        result.update(complexity)

    if nlp is not None:
        spacy = import_spacy()
        parsed = nlp(text)
        result["dependency_tree"] = spacy.displacy.render(
            parsed, style="dep", jupyter=False, page=True
        )
        result["entities"] = spacy.displacy.render(
            parsed, style="ent", jupyter=False, page=True
        )

    return result
|
||||
|
||||
|
||||
class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
"""Callback handler that is used within a Flyte task."""
|
||||
|
||||
def __init__(self) -> None:
    """Initialize callback handler.

    ``flytekit`` and pandas are imported unconditionally. textstat and
    spaCy are optional: when either cannot be imported a warning is logged
    and the corresponding attribute stays ``None``.
    """
    flytekit_module, deck_renderer = import_flytekit()
    self.pandas = import_pandas()

    self.textstat = None
    try:
        self.textstat = import_textstat()
    except ImportError:
        logger.warning(
            "Textstat library is not installed. "
            "It may result in the inability to log "
            "certain metrics that can be captured with Textstat."
        )

    spacy_module = None
    try:
        spacy_module = import_spacy()
    except ImportError:
        logger.warning(
            "Spacy library is not installed. "
            "It may result in the inability to log "
            "certain metrics that can be captured with Spacy."
        )

    super().__init__()

    self.nlp = None
    if spacy_module:
        try:
            self.nlp = spacy_module.load("en_core_web_sm")
        except OSError:
            logger.warning(
                "FlyteCallbackHandler uses spacy's en_core_web_sm model"
                " for certain metrics. To download,"
                " run the following command in your terminal:"
                " `python -m spacy download en_core_web_sm`"
            )

    # Renderer classes are stored; instances are created on each use.
    self.table_renderer = deck_renderer.TableRenderer
    self.markdown_renderer = deck_renderer.MarkdownRenderer

    heading_html = self.markdown_renderer().to_html("## LangChain Metrics")
    self.deck = flytekit_module.Deck("LangChain Metrics", heading_html)
|
||||
|
||||
def on_llm_start(
    self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
    """Run when LLM starts.

    Bumps the step/start counters and appends an "LLM Start" heading plus a
    one-row table (action, flattened ``serialized``, callback metadata and
    the prompts) to the deck.
    """
    self.step += 1
    self.llm_starts += 1
    self.starts += 1

    record: Dict[str, Any] = {"action": "on_llm_start"}
    record.update(flatten_dict(serialized))
    record.update(self.get_custom_callback_meta())
    record["prompts"] = list(prompts)

    self.deck.append(self.markdown_renderer().to_html("### LLM Start"))
    table_html = self.table_renderer().to_html(self.pandas.DataFrame([record]))
    self.deck.append(table_html + "\n")
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
    """Run when LLM generates a new token; this handler ignores tokens."""
    return None
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
    """Record an LLM end event and per-generation details in the Flyte deck.

    A summary table (action, flattened ``llm_output``, custom callback
    metadata) is always appended. For every generation, either the text
    analysis sections are rendered (when spaCy and/or textstat are
    available) or the raw generated text is rendered.

    Args:
        response: The LLM result whose generations are rendered.
        **kwargs: Unused.
    """
    self.step += 1
    self.llm_ends += 1
    self.ends += 1

    resp: Dict[str, Any] = {}
    resp.update({"action": "on_llm_end"})
    resp.update(flatten_dict(response.llm_output or {}))
    resp.update(self.get_custom_callback_meta())

    self.deck.append(self.markdown_renderer().to_html("### LLM End"))
    self.deck.append(self.table_renderer().to_html(self.pandas.DataFrame([resp])))

    for generations in response.generations:
        for generation in generations:
            if not (self.nlp or self.textstat):
                # No analysis library available: just show the text.
                self.deck.append(
                    self.markdown_renderer().to_html("#### Generated Response")
                )
                self.deck.append(self.markdown_renderer().to_html(generation.text))
                continue

            generation_resp = deepcopy(resp)
            generation_resp.update(flatten_dict(generation.dict()))
            generation_resp.update(
                analyze_text(generation.text, nlp=self.nlp, textstat=self.textstat)
            )

            # NOTE(review): analyze_text presumably only emits the keys for
            # the libraries that are actually installed (the constructor
            # tolerates either one missing) -- confirm against its
            # definition. Each section is therefore rendered only when its
            # key is present, instead of raising KeyError.
            complexity_metrics = generation_resp.pop("text_complexity_metrics", None)
            if complexity_metrics is not None:
                self.deck.append(
                    self.markdown_renderer().to_html("#### Text Complexity Metrics")
                )
                self.deck.append(
                    self.table_renderer().to_html(
                        self.pandas.DataFrame([complexity_metrics])
                    )
                    + "\n"
                )

            dependency_tree = generation_resp.get("dependency_tree")
            if dependency_tree is not None:
                self.deck.append(
                    self.markdown_renderer().to_html("#### Dependency Tree")
                )
                self.deck.append(dependency_tree)

            entities = generation_resp.get("entities")
            if entities is not None:
                self.deck.append(self.markdown_renderer().to_html("#### Entities"))
                self.deck.append(entities)
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
    """Count an LLM failure; nothing is written to the deck."""
    for counter in ("step", "errors"):
        setattr(self, counter, getattr(self, counter) + 1)
|
||||
|
||||
def on_chain_start(
    self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
    """Record a chain start event in the Flyte deck.

    Args:
        serialized: Serialized description of the chain; flattened into the row.
        inputs: Chain inputs; rendered as a single ``k=v,k=v`` string.
        **kwargs: Unused.
    """
    self.step += 1
    self.chain_starts += 1
    self.starts += 1

    base: Dict[str, Any] = {
        "action": "on_chain_start",
        **flatten_dict(serialized),
        **self.get_custom_callback_meta(),
    }

    joined_inputs = ",".join(f"{key}={value}" for key, value in inputs.items())
    # Work on a deep copy so the rendered row never aliases the base values.
    row = deepcopy(base)
    row["inputs"] = joined_inputs

    self.deck.append(self.markdown_renderer().to_html("### Chain Start"))
    frame = self.pandas.DataFrame([row])
    self.deck.append(self.table_renderer().to_html(frame) + "\n")
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
    """Record a chain end event in the Flyte deck.

    Args:
        outputs: Chain outputs; rendered as a single ``k=v,k=v`` string.
        **kwargs: Unused.
    """
    self.step += 1
    self.chain_ends += 1
    self.ends += 1

    joined_outputs = ",".join(f"{key}={value}" for key, value in outputs.items())
    row: Dict[str, Any] = {"action": "on_chain_end", "outputs": joined_outputs}
    row.update(self.get_custom_callback_meta())

    self.deck.append(self.markdown_renderer().to_html("### Chain End"))
    frame = self.pandas.DataFrame([row])
    self.deck.append(self.table_renderer().to_html(frame) + "\n")
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
    """Count a chain failure; nothing is written to the deck."""
    for counter in ("step", "errors"):
        setattr(self, counter, getattr(self, counter) + 1)
|
||||
|
||||
def on_tool_start(
    self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
    """Record a tool start event in the Flyte deck.

    Args:
        serialized: Serialized description of the tool; flattened into the row.
        input_str: The input passed to the tool.
        **kwargs: Unused.
    """
    self.step += 1
    self.tool_starts += 1
    self.starts += 1

    row: Dict[str, Any] = {
        "action": "on_tool_start",
        "input_str": input_str,
        **flatten_dict(serialized),
        **self.get_custom_callback_meta(),
    }

    self.deck.append(self.markdown_renderer().to_html("### Tool Start"))
    frame = self.pandas.DataFrame([row])
    self.deck.append(self.table_renderer().to_html(frame) + "\n")
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
    """Record a tool end event, including the tool output, in the Flyte deck.

    Args:
        output: The output produced by the tool.
        **kwargs: Unused.
    """
    self.step += 1
    self.tool_ends += 1
    self.ends += 1

    row: Dict[str, Any] = {
        "action": "on_tool_end",
        "output": output,
        **self.get_custom_callback_meta(),
    }

    self.deck.append(self.markdown_renderer().to_html("### Tool End"))
    frame = self.pandas.DataFrame([row])
    self.deck.append(self.table_renderer().to_html(frame) + "\n")
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
    """Count a tool failure; nothing is written to the deck."""
    for counter in ("step", "errors"):
        setattr(self, counter, getattr(self, counter) + 1)
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
    """Record a text event in the Flyte deck.

    Args:
        text: The text to record.
        **kwargs: Unused.
    """
    self.step += 1
    self.text_ctr += 1

    row: Dict[str, Any] = {
        "action": "on_text",
        "text": text,
        **self.get_custom_callback_meta(),
    }

    self.deck.append(self.markdown_renderer().to_html("### On Text"))
    frame = self.pandas.DataFrame([row])
    self.deck.append(self.table_renderer().to_html(frame) + "\n")
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
    """Record an agent finish event in the Flyte deck.

    Args:
        finish: The agent's final result; its ``return_values["output"]``
            and ``log`` are rendered.
        **kwargs: Unused.
    """
    self.step += 1
    self.agent_ends += 1
    self.ends += 1

    row: Dict[str, Any] = {
        "action": "on_agent_finish",
        "output": finish.return_values["output"],
        "log": finish.log,
        **self.get_custom_callback_meta(),
    }

    self.deck.append(self.markdown_renderer().to_html("### Agent Finish"))
    frame = self.pandas.DataFrame([row])
    self.deck.append(self.table_renderer().to_html(frame) + "\n")
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
    """Record an agent action event in the Flyte deck.

    Note that this bumps the ``tool_starts`` and ``starts`` counters.

    Args:
        action: The agent action; its ``tool``, ``tool_input`` and ``log``
            are rendered.
        **kwargs: Unused.
    """
    self.step += 1
    self.tool_starts += 1
    self.starts += 1

    row: Dict[str, Any] = {
        "action": "on_agent_action",
        "tool": action.tool,
        "tool_input": action.tool_input,
        "log": action.log,
        **self.get_custom_callback_meta(),
    }

    self.deck.append(self.markdown_renderer().to_html("### Agent Action"))
    frame = self.pandas.DataFrame([row])
    self.deck.append(self.table_renderer().to_html(frame) + "\n")
|
||||
@@ -0,0 +1,88 @@
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler
|
||||
|
||||
|
||||
def _default_approve(_input: str) -> bool:
    """Ask on stdin whether ``_input`` is approved.

    Returns True only when the reply, lower-cased, is ``"yes"`` or ``"y"``.
    """
    question = "".join(
        (
            "Do you approve of the following input? ",
            "Anything except 'Y'/'Yes' (case-insensitive) will be treated as a no.",
            "\n\n",
            _input,
            "\n",
        )
    )
    answer = input(question).lower()
    return answer == "yes" or answer == "y"
|
||||
|
||||
|
||||
async def _adefault_approve(_input: str) -> bool:
    """Coroutine variant of the stdin approval prompt.

    The prompt itself is still read with the blocking built-in ``input``.
    Returns True only when the reply, lower-cased, is ``"yes"`` or ``"y"``.
    """
    header = (
        "Do you approve of the following input? "
        "Anything except 'Y'/'Yes' (case-insensitive) will be treated as a no."
    )
    question = "".join((header, "\n\n", _input, "\n"))
    reply = input(question)
    accepted = ("yes", "y")
    return reply.lower() in accepted
|
||||
|
||||
|
||||
def _default_true(_: Dict[str, Any]) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class HumanRejectedException(Exception):
    """Raised when a person manually reviews a value and rejects it."""
|
||||
|
||||
|
||||
class HumanApprovalCallbackHandler(BaseCallbackHandler):
    """Callback that asks a human to approve tool inputs before a tool runs.

    Args:
        approve: Called with the tool input; a falsy result rejects the call.
        should_check: Called with the serialized tool; approval is only
            requested when it returns a truthy value.
    """

    # Let HumanRejectedException propagate instead of being swallowed.
    raise_error: bool = True

    def __init__(
        self,
        approve: Callable[[Any], bool] = _default_approve,
        should_check: Callable[[Dict[str, Any]], bool] = _default_true,
    ):
        self._approve = approve
        self._should_check = should_check

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Raise HumanRejectedException when a checked tool input is rejected."""
        if not self._should_check(serialized):
            return None
        if self._approve(input_str):
            return None
        raise HumanRejectedException(
            f"Inputs {input_str} to tool {serialized} were rejected."
        )
|
||||
|
||||
|
||||
class AsyncHumanApprovalCallbackHandler(AsyncCallbackHandler):
    """Async callback that asks a human to approve tool inputs.

    Args:
        approve: Awaited with the tool input; a falsy result rejects the call.
        should_check: Called with the serialized tool; approval is only
            requested when it returns a truthy value.
    """

    # Let HumanRejectedException propagate instead of being swallowed.
    raise_error: bool = True

    def __init__(
        self,
        approve: Callable[[Any], Awaitable[bool]] = _adefault_approve,
        should_check: Callable[[Dict[str, Any]], bool] = _default_true,
    ):
        self._approve = approve
        self._should_check = should_check

    async def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Raise HumanRejectedException when a checked tool input is rejected."""
        if not self._should_check(serialized):
            return None
        approved = await self._approve(input_str)
        if approved:
            return None
        raise HumanRejectedException(
            f"Inputs {input_str} to tool {serialized} were rejected."
        )
|
||||
@@ -0,0 +1,251 @@
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
from langchain_core.utils import guard_import
|
||||
|
||||
|
||||
def import_infino() -> Any:
    """Import ``infinopy`` and return a freshly constructed ``InfinoClient``."""
    infinopy = guard_import("infinopy")
    return infinopy.InfinoClient()
|
||||
|
||||
|
||||
def import_tiktoken() -> Any:
    """Import and return ``tiktoken``, used to count tokens for OpenAI models."""
    tiktoken_module = guard_import("tiktoken")
    return tiktoken_module
|
||||
|
||||
|
||||
def get_num_tokens(string: str, openai_model_name: str) -> int:
    """Count the tokens of ``string`` for an OpenAI model using tiktoken.

    Official documentation: https://github.com/openai/openai-cookbook/blob/main
    /examples/How_to_count_tokens_with_tiktoken.ipynb

    Args:
        string: Text to tokenize.
        openai_model_name: Model name used to select the tiktoken encoding.

    Returns:
        Number of tokens in the encoded text.
    """
    encoder = import_tiktoken().encoding_for_model(openai_model_name)
    return len(encoder.encode(string))
|
||||
|
||||
|
||||
class InfinoCallbackHandler(BaseCallbackHandler):
    """Callback handler that logs prompts, responses and metrics to Infino.

    Prompts and responses are appended to the Infino log; latency, the error
    flag and token counts are appended as time-series points. Every payload
    is labelled with ``model_id`` and ``model_version``.

    Args:
        model_id: Label attached to every payload.
        model_version: Label attached to every payload.
        verbose: When True, print each payload before it is sent.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        model_version: Optional[str] = None,
        verbose: bool = False,
    ) -> None:
        # Set Infino client
        self.client = import_infino()
        self.model_id = model_id
        self.model_version = model_version
        self.verbose = verbose
        self.is_chat_openai_model = False
        self.chat_openai_model_name = "gpt-3.5-turbo"

        # Initialise per-run state so that on_llm_end / on_llm_error are
        # safe even if no start callback was received for the run.
        self.error = 0
        self.start_time = time.time()
        self.end_time = self.start_time

    def _send_to_infino(
        self,
        key: str,
        value: Any,
        is_ts: bool = True,
    ) -> None:
        """Send the key-value to Infino.

        Parameters:
        key (str): the key to send to Infino.
        value (Any): the value to send to Infino.
        is_ts (bool): if True, the value is part of a time series, else it
                      is sent as a log message.
        """
        payload = {
            "date": int(time.time()),
            key: value,
            "labels": {
                "model_id": self.model_id,
                "model_version": self.model_version,
            },
        }
        if self.verbose:
            print(f"Tracking {key} with Infino: {payload}")  # noqa: T201

        # Append to Infino time series only if is_ts is True, otherwise
        # append to Infino log.
        if is_ts:
            self.client.append_ts(payload)
        else:
            self.client.append_log(payload)

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        **kwargs: Any,
    ) -> None:
        """Log the prompts to Infino, and set start time and error flag."""
        # A plain (non-chat) LLM run: make sure a previous ChatOpenAI run does
        # not leave the chat-specific token accounting enabled.
        self.is_chat_openai_model = False

        for prompt in prompts:
            self._send_to_infino("prompt", prompt, is_ts=False)

        # Set the error flag to indicate no error (this will get overridden
        # in on_llm_error if an error occurs).
        self.error = 0

        # Set the start time (so that we can calculate the request
        # duration in on_llm_end).
        self.start_time = time.time()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing when a new token is generated."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Log the latency, error, token usage, and response to Infino."""
        # Calculate and track the request latency.
        self.end_time = time.time()
        duration = self.end_time - self.start_time
        self._send_to_infino("latency", duration)

        # Track success or error flag.
        self._send_to_infino("error", self.error)

        # Track prompt response.
        for generations in response.generations:
            for generation in generations:
                self._send_to_infino("prompt_response", generation.text, is_ts=False)

        # Track token usage (for non-chat models). Providers are not required
        # to report "token_usage" (or every counter in it), so only the
        # counters that are actually present are sent.
        llm_output = response.llm_output
        token_usage = (
            llm_output.get("token_usage") if isinstance(llm_output, dict) else None
        )
        if isinstance(token_usage, dict):
            for metric in ("prompt_tokens", "total_tokens", "completion_tokens"):
                if metric in token_usage:
                    self._send_to_infino(metric, token_usage[metric])

        # Track completion token usage (for openai chat models), counted over
        # every generation of the response.
        if self.is_chat_openai_model:
            messages = " ".join(
                cast(str, cast(ChatGeneration, generation).message.content)
                for generations in response.generations
                for generation in generations
            )
            completion_tokens = get_num_tokens(
                messages, openai_model_name=self.chat_openai_model_name
            )
            self._send_to_infino("completion_tokens", completion_tokens)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Set the error flag."""
        self.error = 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Do nothing when LLM chain starts."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Do nothing when LLM chain ends."""
        pass

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when LLM chain outputs an error (not logged yet)."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool starts."""
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Do nothing when agent takes a specific action."""
        pass

    def on_tool_end(
        self,
        output: str,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool ends."""
        pass

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when tool outputs an error."""
        pass

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        **kwargs: Any,
    ) -> None:
        """Log the chat prompt to Infino, and set start time and error flag.

        For ChatOpenAI models the prompt token count is additionally computed
        with tiktoken and sent as a time-series point.
        """
        # Currently, for chat models, we only support input prompts for ChatOpenAI.
        # Check if this model is a ChatOpenAI model. The flag is recomputed on
        # every run so that a reused handler does not keep a stale value.
        values = serialized.get("id")
        self.is_chat_openai_model = bool(values) and any(
            value == "ChatOpenAI" for value in values
        )

        # Track prompt tokens for ChatOpenAI model.
        if self.is_chat_openai_model:
            invocation_params = kwargs.get("invocation_params")
            if invocation_params:
                model_name = invocation_params.get("model_name")
                if model_name:
                    self.chat_openai_model_name = model_name
                    prompt_tokens = 0
                    for message_list in messages:
                        message_string = " ".join(
                            cast(str, msg.content) for msg in message_list
                        )
                        num_tokens = get_num_tokens(
                            message_string,
                            openai_model_name=self.chat_openai_model_name,
                        )
                        prompt_tokens += num_tokens

                    self._send_to_infino("prompt_tokens", prompt_tokens)

        if self.verbose:
            # Built from adjacent literals so that no line-continuation
            # whitespace leaks into the printed message.
            print(  # noqa: T201
                "on_chat_model_start: "
                f"is_chat_openai_model={self.is_chat_openai_model}, "
                f"chat_openai_model_name={self.chat_openai_model_name}"
            )

        # Send the prompt to infino
        prompt = " ".join(
            cast(str, msg.content) for sublist in messages for msg in sublist
        )
        self._send_to_infino("prompt", prompt, is_ts=False)

        # Set the error flag to indicate no error (this will get overridden
        # in on_llm_error if an error occurs).
        self.error = 0

        # Set the start time (so that we can calculate the request
        # duration in on_llm_end).
        self.start_time = time.time()
|
||||
@@ -0,0 +1,390 @@
|
||||
import os
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage, ChatMessage
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
|
||||
|
||||
class LabelStudioMode(Enum):
    """Kinds of Label Studio project the callback can feed."""

    # Plain-text prompts.
    PROMPT = "prompt"
    # Chat dialogues.
    CHAT = "chat"
|
||||
|
||||
|
||||
def get_default_label_configs(
    mode: Union[str, LabelStudioMode],
) -> Tuple[str, LabelStudioMode]:
    """Return the default Label Studio labeling config for ``mode``.

    Parameters:
        mode: Label Studio mode ("prompt" or "chat"), as a string or as a
            ``LabelStudioMode`` member.

    Returns: Tuple of the labeling config (XML) and the resolved mode.
    """
    prompt_config = """
<View>
<Style>
.prompt-box {
background-color: white;
border-radius: 10px;
box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
padding: 20px;
}
</Style>
<View className="root">
<View className="prompt-box">
<Text name="prompt" value="$prompt"/>
</View>
<TextArea name="response" toName="prompt"
maxSubmissions="1" editable="true"
required="true"/>
</View>
<Header value="Rate the response:"/>
<Rating name="rating" toName="prompt"/>
</View>"""

    chat_config = """
<View>
<View className="root">
<Paragraphs name="dialogue"
value="$prompt"
layout="dialogue"
textKey="content"
nameKey="role"
granularity="sentence"/>
<Header value="Final response:"/>
<TextArea name="response" toName="dialogue"
maxSubmissions="1" editable="true"
required="true"/>
</View>
<Header value="Rate the response:"/>
<Rating name="rating" toName="dialogue"/>
</View>"""

    configs_by_mode = {
        LabelStudioMode.PROMPT.value: prompt_config,
        LabelStudioMode.CHAT.value: chat_config,
    }

    # Strings are converted to the enum; an unknown value raises ValueError.
    resolved = LabelStudioMode(mode) if isinstance(mode, str) else mode
    return configs_by_mode[resolved.value], resolved
|
||||
|
||||
|
||||
class LabelStudioCallbackHandler(BaseCallbackHandler):
    """Label Studio callback handler.
    Provides the ability to send predictions to Label Studio
    for human evaluation, feedback and annotation.

    Parameters:
        api_key: Label Studio API key
        url: Label Studio URL
        project_id: Label Studio project ID
        project_name: Label Studio project name
        project_config: Label Studio project config (XML)
        mode: Label Studio mode ("prompt" or "chat")

    Examples:
        >>> from langchain_community.llms import OpenAI
        >>> from langchain_community.callbacks import LabelStudioCallbackHandler
        >>> handler = LabelStudioCallbackHandler(
        ...             api_key='<your_key_here>',
        ...             url='http://localhost:8080',
        ...             project_name='LangChain-%Y-%m-%d',
        ...             mode='prompt'
        ... )
        >>> llm = OpenAI(callbacks=[handler])
        >>> llm.invoke('Tell me a story about a dog.')
    """

    # strftime pattern used as the project title when none is given.
    DEFAULT_PROJECT_NAME: str = "LangChain-%Y-%m-%d"

    def __init__(
        self,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        project_id: Optional[int] = None,
        project_name: str = DEFAULT_PROJECT_NAME,
        project_config: Optional[str] = None,
        mode: Union[str, LabelStudioMode] = LabelStudioMode.PROMPT,
    ):
        """Connect to Label Studio and resolve (or create) the target project.

        Raises:
            ImportError: if ``label_studio_sdk`` is not installed.
            ValueError: if no API key is available, or if the project has no
                TextArea tag to receive predictions.
        """
        super().__init__()

        # Import LabelStudio SDK
        try:
            import label_studio_sdk as ls
        except ImportError as e:
            raise ImportError(
                f"You're using {self.__class__.__name__} in your code,"
                f" but you don't have the LabelStudio SDK "
                f"Python package installed or upgraded to the latest version. "
                f"Please run `pip install -U label-studio-sdk`"
                f" before using this callback."
            ) from e

        # Check if Label Studio API key is provided
        if not api_key:
            api_key = os.getenv("LABEL_STUDIO_API_KEY")
            if not api_key:
                raise ValueError(
                    f"You're using {self.__class__.__name__} in your code,"
                    f" Label Studio API key is not provided. "
                    f"Please provide Label Studio API key: "
                    f"go to the Label Studio instance, navigate to "
                    f"Account & Settings -> Access Token and copy the key. "
                    f"Use the key as a parameter for the callback: "
                    f"{self.__class__.__name__}"
                    f"(api_key='<your_key_here>', ...) or "
                    f"set the environment variable LABEL_STUDIO_API_KEY=<your_key_here>"
                )
        self.api_key = api_key

        if not url:
            url = os.getenv("LABEL_STUDIO_URL")
            if not url:
                warnings.warn(
                    f"Label Studio URL is not provided, "
                    f"using default URL: {ls.LABEL_STUDIO_DEFAULT_URL}. "
                    f"If you want to provide your own URL, use the parameter: "
                    f"{self.__class__.__name__}"
                    f"(url='<your_url_here>', ...) "
                    f"or set the environment variable LABEL_STUDIO_URL=<your_url_here>"
                )
                url = ls.LABEL_STUDIO_DEFAULT_URL
        self.url = url

        # Maps run_id to prompts
        self.payload: Dict[str, Dict] = {}

        self.ls_client = ls.Client(url=self.url, api_key=self.api_key)
        self.project_name = project_name
        if project_config:
            self.project_config = project_config
            self.mode = None
        else:
            self.project_config, self.mode = get_default_label_configs(mode)

        # An empty LABEL_STUDIO_PROJECT_ID is treated as "not set" rather than
        # being passed to int().
        self.project_id = project_id or os.getenv("LABEL_STUDIO_PROJECT_ID")
        if self.project_id:
            self.ls_project = self.ls_client.get_project(int(self.project_id))
        else:
            project_title = datetime.today().strftime(self.project_name)
            existing_projects = self.ls_client.get_projects(title=project_title)
            if existing_projects:
                self.ls_project = existing_projects[0]
                self.project_id = self.ls_project.id
            else:
                self.ls_project = self.ls_client.create_project(
                    title=project_title, label_config=self.project_config
                )
                self.project_id = self.ls_project.id
        self.parsed_label_config = self.ls_project.parsed_label_config

        # Find the first TextArea tag
        # "from_name", "to_name", "value" will be used to create predictions
        self.from_name, self.to_name, self.value, self.input_type = (
            None,
            None,
            None,
            None,
        )
        for tag_name, tag_info in self.parsed_label_config.items():
            if tag_info["type"] == "TextArea":
                self.from_name = tag_name
                self.to_name = tag_info["to_name"][0]
                self.value = tag_info["inputs"][0]["value"]
                self.input_type = tag_info["inputs"][0]["type"]
                break
        if not self.from_name:
            error_message = (
                f'Label Studio project "{self.project_name}" '
                f"does not have a TextArea tag. "
                f"Please add a TextArea tag to the project."
            )
            if self.mode == LabelStudioMode.PROMPT:
                error_message += (
                    "\nHINT: go to project Settings -> "
                    "Labeling Interface -> Browse Templates"
                    ' and select "Generative AI -> '
                    'Supervised Language Model Fine-tuning" template.'
                )
            else:
                error_message += (
                    "\nHINT: go to project Settings -> "
                    "Labeling Interface -> Browse Templates"
                    " and check available templates under "
                    '"Generative AI" section.'
                )
            raise ValueError(error_message)

    def add_prompts_generations(
        self, run_id: str, generations: List[List[Generation]]
    ) -> None:
        """Import one Label Studio task per (prompt, generations) pair of a run.

        Raises:
            KeyError: if no prompts were stored for ``run_id``.
        """
        # Create tasks in Label Studio
        tasks = []
        prompts = self.payload[run_id]["prompts"]
        model_version = (
            self.payload[run_id]["kwargs"]
            .get("invocation_params", {})
            .get("model_name")
        )
        for prompt, generation in zip(prompts, generations):
            tasks.append(
                {
                    "data": {
                        self.value: prompt,
                        "run_id": run_id,
                    },
                    "predictions": [
                        {
                            "result": [
                                {
                                    "from_name": self.from_name,
                                    "to_name": self.to_name,
                                    "type": "textarea",
                                    "value": {"text": [g.text for g in generation]},
                                }
                            ],
                            "model_version": model_version,
                        }
                    ],
                }
            )
        self.ls_project.import_tasks(tasks)

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        **kwargs: Any,
    ) -> None:
        """Save the prompts in memory when an LLM starts."""
        if self.input_type != "Text":
            raise ValueError(
                f'\nLabel Studio project "{self.project_name}" '
                f"has an input type <{self.input_type}>. "
                f'To make it work with the mode="prompt", '
                f"the input type should be <Text>.\n"
                f"Read more here https://labelstud.io/tags/text"
            )
        run_id = str(kwargs["run_id"])
        self.payload[run_id] = {"prompts": prompts, "kwargs": kwargs}

    def _get_message_role(self, message: BaseMessage) -> str:
        """Get the role of the message."""
        if isinstance(message, ChatMessage):
            return message.role
        else:
            return message.__class__.__name__

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Save the dialogues in memory when a chat model starts."""
        if self.input_type != "Paragraphs":
            raise ValueError(
                f'\nLabel Studio project "{self.project_name}" '
                f"has an input type <{self.input_type}>. "
                f'To make it work with the mode="chat", '
                f"the input type should be <Paragraphs>.\n"
                f"Read more here https://labelstud.io/tags/paragraphs"
            )

        prompts = []
        for message_list in messages:
            dialog = []
            for message in message_list:
                dialog.append(
                    {
                        "role": self._get_message_role(message),
                        "content": message.content,
                    }
                )
            prompts.append(dialog)
        self.payload[str(run_id)] = {
            "prompts": prompts,
            "tags": tags,
            "metadata": metadata,
            "run_id": run_id,
            "parent_run_id": parent_run_id,
            "kwargs": kwargs,
        }

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing when a new token is generated."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Create a new Label Studio task for each prompt and generation."""
        run_id = str(kwargs["run_id"])

        try:
            # Submit results to Label Studio
            self.add_prompts_generations(run_id, response.generations)
        finally:
            # Drop the finished run from `self.payload` even if the import
            # failed, so stored prompts do not accumulate.
            self.payload.pop(run_id, None)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Discard the prompts stored for a run whose LLM call failed."""
        run_id = kwargs.get("run_id")
        if run_id is not None:
            self.payload.pop(str(run_id), None)

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Do nothing when LLM chain starts."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Do nothing when LLM chain ends."""
        pass

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when LLM chain outputs an error."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool starts."""
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Do nothing when agent takes a specific action."""
        pass

    def on_tool_end(
        self,
        output: str,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool ends."""
        pass

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when tool outputs an error."""
        pass

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Do nothing"""
        pass

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Do nothing"""
        pass
|
||||
@@ -0,0 +1,681 @@
|
||||
import importlib.metadata
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
import warnings
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, List, Union, cast
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import LLMResult
|
||||
from packaging.version import parse
|
||||
|
||||
logger = logging.getLogger(__name__)

DEFAULT_API_URL = "https://app.llmonitor.com"

# Context-local identity of the current end user, set through `identify()`.
user_ctx = ContextVar[Union[str, None]]("user_ctx", default=None)
user_props_ctx = ContextVar[Union[str, None]]("user_props_ctx", default=None)

# Names of invocation parameters that are captured for reporting.
PARAMS_TO_CAPTURE = [
    "temperature",
    "top_p",
    "top_k",
    "stop",
    "presence_penalty",
    "frequency_penalty",
    # Historical misspelling, kept so existing behaviour is a strict subset.
    "frequence_penalty",
    "seed",
    "function_call",
    "functions",
    "tools",
    "tool_choice",
    "response_format",
    "max_tokens",
    "logit_bias",
]
|
||||
|
||||
|
||||
class UserContextManager:
    """Context manager for LLMonitor user context.

    The user id and properties are written to the module-level context
    variables as soon as the object is constructed (not in ``__enter__``),
    so the identity is active from the moment the object is created.

    On exit, the values that were active before construction are restored.
    For a non-nested block that means ``None``; for nested blocks the outer
    user becomes active again instead of being erased.
    """

    def __init__(self, user_id: str, user_props: Any = None) -> None:
        # Remember what was active so __exit__ can put it back.
        self._previous_user_id = user_ctx.get()
        self._previous_user_props = user_props_ctx.get()

        user_ctx.set(user_id)
        user_props_ctx.set(user_props)

    def __enter__(self) -> Any:
        """Return the manager itself so ``with ... as ctx`` binds something useful."""
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> Any:
        """Restore the previous user context; exceptions are not suppressed."""
        user_ctx.set(self._previous_user_id)
        user_props_ctx.set(self._previous_user_props)
|
||||
|
||||
|
||||
def identify(user_id: str, user_props: Any = None) -> UserContextManager:
    """Create the LLMonitor user context manager for a given user.

    Parameters:
        - `user_id`: The user id.
        - `user_props`: The user properties.

    Returns:
        A `UserContextManager` built from the two arguments.
    """
    manager = UserContextManager(user_id, user_props)
    return manager
|
||||
|
||||
|
||||
def _serialize(obj: Any) -> Union[Dict[str, Any], List[Any], Any]:
|
||||
if hasattr(obj, "to_json"):
|
||||
return obj.to_json()
|
||||
|
||||
if isinstance(obj, dict):
|
||||
return {key: _serialize(value) for key, value in obj.items()}
|
||||
|
||||
if isinstance(obj, list):
|
||||
return [_serialize(element) for element in obj]
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
def _parse_input(raw_input: Any) -> Any:
|
||||
if not raw_input:
|
||||
return None
|
||||
|
||||
# if it's an array of 1, just parse the first element
|
||||
if isinstance(raw_input, list) and len(raw_input) == 1:
|
||||
return _parse_input(raw_input[0])
|
||||
|
||||
if not isinstance(raw_input, dict):
|
||||
return _serialize(raw_input)
|
||||
|
||||
input_value = raw_input.get("input")
|
||||
inputs_value = raw_input.get("inputs")
|
||||
question_value = raw_input.get("question")
|
||||
query_value = raw_input.get("query")
|
||||
|
||||
if input_value:
|
||||
return input_value
|
||||
if inputs_value:
|
||||
return inputs_value
|
||||
if question_value:
|
||||
return question_value
|
||||
if query_value:
|
||||
return query_value
|
||||
|
||||
return _serialize(raw_input)
|
||||
|
||||
|
||||
def _parse_output(raw_output: dict) -> Any:
|
||||
if not raw_output:
|
||||
return None
|
||||
|
||||
if not isinstance(raw_output, dict):
|
||||
return _serialize(raw_output)
|
||||
|
||||
text_value = raw_output.get("text")
|
||||
output_value = raw_output.get("output")
|
||||
output_text_value = raw_output.get("output_text")
|
||||
answer_value = raw_output.get("answer")
|
||||
result_value = raw_output.get("result")
|
||||
|
||||
if text_value:
|
||||
return text_value
|
||||
if answer_value:
|
||||
return answer_value
|
||||
if output_value:
|
||||
return output_value
|
||||
if output_text_value:
|
||||
return output_text_value
|
||||
if result_value:
|
||||
return result_value
|
||||
|
||||
return _serialize(raw_output)
|
||||
|
||||
|
||||
def _parse_lc_role(
|
||||
role: str,
|
||||
) -> str:
|
||||
if role == "human":
|
||||
return "user"
|
||||
else:
|
||||
return role
|
||||
|
||||
|
||||
def _get_user_id(metadata: Any) -> Any:
    """Resolve the user id for an event.

    The context variable wins when it is set. Otherwise the id is read from
    ``metadata["user_id"]``, falling back to the legacy ``"userId"`` key.
    """
    ctx_user_id = user_ctx.get()
    if ctx_user_id is not None:
        return ctx_user_id

    metadata = metadata or {}
    user_id = metadata.get("user_id")
    # legacy, to delete in the future
    return metadata.get("userId") if user_id is None else user_id
|
||||
|
||||
|
||||
def _get_user_props(metadata: Any) -> Any:
    """Resolve the user properties for an event.

    The context variable wins when it is set; otherwise
    ``metadata["user_props"]`` is used (``None`` when absent).
    """
    ctx_user_props = user_props_ctx.get()
    if ctx_user_props is not None:
        return ctx_user_props

    return (metadata or {}).get("user_props", None)
|
||||
|
||||
|
||||
def _parse_lc_message(message: BaseMessage) -> Dict[str, Any]:
    """Convert a LangChain message into the dict shape sent to LLMonitor.

    The result always has ``text`` and ``role``. The keys ``function_call``,
    ``tool_calls``, ``tool_call_id`` and ``name`` are copied from
    ``message.additional_kwargs`` when their value is not ``None``.
    """
    parsed = {"text": message.content, "role": _parse_lc_role(message.type)}

    for key in ("function_call", "tool_calls", "tool_call_id", "name"):
        value = message.additional_kwargs.get(key)
        if value is not None:
            parsed[key] = cast(Any, value)

    return parsed
|
||||
|
||||
|
||||
def _parse_lc_messages(messages: Union[List[BaseMessage], Any]) -> List[Dict[str, Any]]:
    """Apply ``_parse_lc_message`` to every message, preserving order."""
    return list(map(_parse_lc_message, messages))
|
||||
|
||||
|
||||
class LLMonitorCallbackHandler(BaseCallbackHandler):
    """Callback Handler for LLMonitor.

    #### Parameters:
        - `app_id`: The app id of the app you want to report to. Defaults to
        `None`, which means that `LLMONITOR_APP_ID` will be used.
        - `api_url`: The url of the LLMonitor API. Defaults to `None`,
        which means that either `LLMONITOR_API_URL` environment variable
        or `https://app.llmonitor.com` will be used.
        - `verbose`: Verbosity flag; also switched on when the
        `LLMONITOR_VERBOSE` environment variable is set to a non-empty value.

    #### Error handling:
        The constructor does not raise on configuration problems. A missing
        or too old `llmonitor` package, or a missing app id, is logged as a
        warning and turns every callback into a no-op. A failed connectivity
        check against the API is logged as a warning only. Errors inside a
        callback are logged and never propagated to the caller.

    #### Example:
    ```python
    from langchain_community.llms import OpenAI
    from langchain_community.callbacks import LLMonitorCallbackHandler

    llmonitor_callback = LLMonitorCallbackHandler()
    llm = OpenAI(callbacks=[llmonitor_callback],
                 metadata={"userId": "user-123"})
    llm.invoke("Hello, how are you?")
    ```
    """

    __api_url: str
    __app_id: str
    __verbose: bool
    __llmonitor_version: str
    __has_valid_config: bool

    def __init__(
        self,
        app_id: Union[str, None] = None,
        api_url: Union[str, None] = None,
        verbose: bool = False,
    ) -> None:
        """Validate the environment and configuration; see the class docstring."""
        super().__init__()

        self.__has_valid_config = True

        try:
            import llmonitor

            self.__llmonitor_version = importlib.metadata.version("llmonitor")
            self.__track_event = llmonitor.track_event

        except ImportError:
            logger.warning(
                """[LLMonitor] To use the LLMonitor callback handler you need to
                have the `llmonitor` Python package installed. Please install it
                with `pip install llmonitor`"""
            )
            self.__has_valid_config = False
            return

        if parse(self.__llmonitor_version) < parse("0.0.32"):
            logger.warning(
                f"""[LLMonitor] The installed `llmonitor` version is
                {self.__llmonitor_version}
                but `LLMonitorCallbackHandler` requires at least version 0.0.32
                upgrade `llmonitor` with `pip install --upgrade llmonitor`"""
            )
            # The flag must stay False from here on; it used to be reset to
            # True right after this check, which disabled the version guard.
            self.__has_valid_config = False

        self.__api_url = api_url or os.getenv("LLMONITOR_API_URL") or DEFAULT_API_URL
        self.__verbose = verbose or bool(os.getenv("LLMONITOR_VERBOSE"))

        _app_id = app_id or os.getenv("LLMONITOR_APP_ID")
        if _app_id is None:
            logger.warning(
                """[LLMonitor] app_id must be provided either as an argument or
                as an environment variable"""
            )
            self.__has_valid_config = False
        else:
            self.__app_id = _app_id

        if not self.__has_valid_config:
            return None

        try:
            # Bounded wait so that an unreachable API cannot hang construction.
            res = requests.get(
                f"{self.__api_url}/api/app/{self.__app_id}", timeout=10
            )
            if not res.ok:
                raise ConnectionError()
        except Exception:
            logger.warning(
                f"""[LLMonitor] Could not connect to the LLMonitor API at
                {self.__api_url}"""
            )

    @staticmethod
    def _llm_name_and_extra(serialized: Any, kwargs: Dict[str, Any]) -> tuple:
        """Return ``(model_name, extra_params)`` for an LLM start event."""
        # Work on a copy so the caller's `invocation_params` is not mutated.
        params = dict(kwargs.get("invocation_params") or {})
        # Sometimes, for example with ChatAnthropic, `invocation_params` is empty
        params.update((serialized or {}).get("kwargs") or {})

        name = (
            params.get("model")
            or params.get("model_name")
            or params.get("model_id")
        )

        # `_type` may be missing entirely; treat that as "not anthropic".
        if not name and "anthropic" in (params.get("_type") or ""):
            name = "claude-2"

        extra = {
            param: params.get(param)
            for param in PARAMS_TO_CAPTURE
            if params.get(param) is not None
        }
        return name, extra

    @staticmethod
    def _error_payload(error: BaseException) -> Dict[str, str]:
        """Build the ``error`` payload from the exception object itself."""
        # Formatting from the exception (rather than `traceback.format_exc()`)
        # does not depend on being called inside the active `except` block.
        stack = "".join(
            traceback.format_exception(type(error), error, error.__traceback__)
        )
        return {"message": str(error), "stack": stack}

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Union[UUID, None] = None,
        tags: Union[List[str], None] = None,
        metadata: Union[Dict[str, Any], None] = None,
        **kwargs: Any,
    ) -> None:
        """Track an "llm start" event for a completion-style model call."""
        if not self.__has_valid_config:
            return
        try:
            name, extra = self._llm_name_and_extra(serialized, kwargs)

            self.__track_event(
                "llm",
                "start",
                user_id=_get_user_id(metadata),
                run_id=str(run_id),
                parent_run_id=str(parent_run_id) if parent_run_id else None,
                name=name,
                input=_parse_input(prompts),
                tags=tags,
                extra=extra,
                metadata=metadata,
                user_props=_get_user_props(metadata),
                app_id=self.__app_id,
            )
        except Exception as e:
            # Logged like every other hook; `warnings.warn` would be shown
            # only once under the default warning filter.
            logger.error(f"[LLMonitor] An error occurred in on_llm_start: {e}")

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Union[UUID, None] = None,
        tags: Union[List[str], None] = None,
        metadata: Union[Dict[str, Any], None] = None,
        **kwargs: Any,
    ) -> Any:
        """Track an "llm start" event for a chat model call.

        Only the first message list (``messages[0]``) is reported.
        """
        if not self.__has_valid_config:
            return

        try:
            name, extra = self._llm_name_and_extra(serialized, kwargs)

            self.__track_event(
                "llm",
                "start",
                user_id=_get_user_id(metadata),
                run_id=str(run_id),
                parent_run_id=str(parent_run_id) if parent_run_id else None,
                name=name,
                input=_parse_lc_messages(messages[0]),
                tags=tags,
                extra=extra,
                metadata=metadata,
                user_props=_get_user_props(metadata),
                app_id=self.__app_id,
            )
        except Exception as e:
            logger.error(f"[LLMonitor] An error occurred in on_chat_model_start: {e}")

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Union[UUID, None] = None,
        **kwargs: Any,
    ) -> None:
        """Track an "llm end" event with the output and token usage.

        Only the first generation list (``response.generations[0]``) is reported.
        """
        if not self.__has_valid_config:
            return

        try:
            # `token_usage` may be absent or explicitly None.
            token_usage = (response.llm_output or {}).get("token_usage") or {}

            parsed_output: Any = [
                _parse_lc_message(generation.message)
                if hasattr(generation, "message")
                else generation.text
                for generation in response.generations[0]
            ]

            # if it's an array of 1, just parse the first element
            if len(parsed_output) == 1:
                parsed_output = parsed_output[0]

            self.__track_event(
                "llm",
                "end",
                run_id=str(run_id),
                parent_run_id=str(parent_run_id) if parent_run_id else None,
                output=parsed_output,
                token_usage={
                    "prompt": token_usage.get("prompt_tokens"),
                    "completion": token_usage.get("completion_tokens"),
                },
                app_id=self.__app_id,
            )
        except Exception as e:
            logger.error(f"[LLMonitor] An error occurred in on_llm_end: {e}")

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Union[UUID, None] = None,
        tags: Union[List[str], None] = None,
        metadata: Union[Dict[str, Any], None] = None,
        **kwargs: Any,
    ) -> None:
        """Track a "tool start" event named after ``serialized["name"]``."""
        if not self.__has_valid_config:
            return
        try:
            self.__track_event(
                "tool",
                "start",
                user_id=_get_user_id(metadata),
                run_id=str(run_id),
                parent_run_id=str(parent_run_id) if parent_run_id else None,
                name=(serialized or {}).get("name"),
                input=input_str,
                tags=tags,
                metadata=metadata,
                user_props=_get_user_props(metadata),
                app_id=self.__app_id,
            )
        except Exception as e:
            logger.error(f"[LLMonitor] An error occurred in on_tool_start: {e}")

    def on_tool_end(
        self,
        output: Any,
        *,
        run_id: UUID,
        parent_run_id: Union[UUID, None] = None,
        tags: Union[List[str], None] = None,
        **kwargs: Any,
    ) -> None:
        """Track a "tool end" event; the output is reported as ``str(output)``."""
        if not self.__has_valid_config:
            return
        try:
            self.__track_event(
                "tool",
                "end",
                run_id=str(run_id),
                parent_run_id=str(parent_run_id) if parent_run_id else None,
                output=str(output),
                app_id=self.__app_id,
            )
        except Exception as e:
            logger.error(f"[LLMonitor] An error occurred in on_tool_end: {e}")

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Union[UUID, None] = None,
        tags: Union[List[str], None] = None,
        metadata: Union[Dict[str, Any], None] = None,
        **kwargs: Any,
    ) -> Any:
        """Track a "chain start" or "agent start" event.

        The run is reported as an agent when it has no parent and either its
        class name is ``AgentExecutor`` / ``PlanAndExecute`` or the metadata
        carries ``agent_name`` (legacy ``agentName``); otherwise as a chain.
        """
        if not self.__has_valid_config:
            return
        try:
            # The class name is the 4th element of the serialized id path;
            # shorter or missing ids yield None instead of dropping the event.
            id_path = (serialized or {}).get("id") or []
            name = id_path[3] if len(id_path) > 3 else None
            run_type = "chain"
            metadata = metadata or {}

            agent_name = metadata.get("agent_name")
            if agent_name is None:
                agent_name = metadata.get("agentName")

            if name == "AgentExecutor" or name == "PlanAndExecute":
                run_type = "agent"
            if agent_name is not None:
                run_type = "agent"
                name = agent_name
            if parent_run_id is not None:
                run_type = "chain"

            self.__track_event(
                run_type,
                "start",
                user_id=_get_user_id(metadata),
                run_id=str(run_id),
                parent_run_id=str(parent_run_id) if parent_run_id else None,
                name=name,
                input=_parse_input(inputs),
                tags=tags,
                metadata=metadata,
                user_props=_get_user_props(metadata),
                app_id=self.__app_id,
            )
        except Exception as e:
            logger.error(f"[LLMonitor] An error occurred in on_chain_start: {e}")

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Union[UUID, None] = None,
        **kwargs: Any,
    ) -> Any:
        """Track a "chain end" event with the parsed outputs."""
        if not self.__has_valid_config:
            return
        try:
            self.__track_event(
                "chain",
                "end",
                run_id=str(run_id),
                parent_run_id=str(parent_run_id) if parent_run_id else None,
                output=_parse_output(outputs),
                app_id=self.__app_id,
            )
        except Exception as e:
            logger.error(f"[LLMonitor] An error occurred in on_chain_end: {e}")

    def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: UUID,
        parent_run_id: Union[UUID, None] = None,
        **kwargs: Any,
    ) -> Any:
        """Track the agent's chosen tool call as a "tool start" event."""
        if not self.__has_valid_config:
            return
        try:
            self.__track_event(
                "tool",
                "start",
                run_id=str(run_id),
                parent_run_id=str(parent_run_id) if parent_run_id else None,
                name=action.tool,
                input=_parse_input(action.tool_input),
                app_id=self.__app_id,
            )
        except Exception as e:
            logger.error(f"[LLMonitor] An error occurred in on_agent_action: {e}")

    def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Union[UUID, None] = None,
        **kwargs: Any,
    ) -> Any:
        """Track an "agent end" event with the agent's return values."""
        if not self.__has_valid_config:
            return
        try:
            self.__track_event(
                "agent",
                "end",
                run_id=str(run_id),
                parent_run_id=str(parent_run_id) if parent_run_id else None,
                output=_parse_output(finish.return_values),
                app_id=self.__app_id,
            )
        except Exception as e:
            logger.error(f"[LLMonitor] An error occurred in on_agent_finish: {e}")

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Union[UUID, None] = None,
        **kwargs: Any,
    ) -> Any:
        """Track a "chain error" event with the message and stack trace."""
        if not self.__has_valid_config:
            return
        try:
            self.__track_event(
                "chain",
                "error",
                run_id=str(run_id),
                parent_run_id=str(parent_run_id) if parent_run_id else None,
                error=self._error_payload(error),
                app_id=self.__app_id,
            )
        except Exception as e:
            logger.error(f"[LLMonitor] An error occurred in on_chain_error: {e}")

    def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Union[UUID, None] = None,
        **kwargs: Any,
    ) -> Any:
        """Track a "tool error" event with the message and stack trace."""
        if not self.__has_valid_config:
            return
        try:
            self.__track_event(
                "tool",
                "error",
                run_id=str(run_id),
                parent_run_id=str(parent_run_id) if parent_run_id else None,
                error=self._error_payload(error),
                app_id=self.__app_id,
            )
        except Exception as e:
            logger.error(f"[LLMonitor] An error occurred in on_tool_error: {e}")

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Union[UUID, None] = None,
        **kwargs: Any,
    ) -> Any:
        """Track an "llm error" event with the message and stack trace."""
        if not self.__has_valid_config:
            return
        try:
            self.__track_event(
                "llm",
                "error",
                run_id=str(run_id),
                parent_run_id=str(parent_run_id) if parent_run_id else None,
                error=self._error_payload(error),
                app_id=self.__app_id,
            )
        except Exception as e:
            logger.error(f"[LLMonitor] An error occurred in on_llm_error: {e}")
|
||||
|
||||
|
||||
__all__ = ["LLMonitorCallbackHandler", "identify"]
|
||||
104
venv/Lib/site-packages/langchain_community/callbacks/manager.py
Normal file
104
venv/Lib/site-packages/langchain_community/callbacks/manager.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import (
|
||||
Generator,
|
||||
Optional,
|
||||
)
|
||||
|
||||
from langchain_core.tracers.context import register_configure_hook
|
||||
|
||||
from langchain_community.callbacks.bedrock_anthropic_callback import (
|
||||
BedrockAnthropicTokenUsageCallbackHandler,
|
||||
)
|
||||
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
|
||||
from langchain_community.callbacks.tracers.comet import CometTracer
|
||||
from langchain_community.callbacks.tracers.wandb import WandbTracer
|
||||
|
||||
logger = logging.getLogger(__name__)

# One context variable per handler/tracer type. Each holds the instance that
# is currently active, or None.
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
    "openai_callback", default=None
)
bedrock_anthropic_callback_var: ContextVar[
    Optional[BedrockAnthropicTokenUsageCallbackHandler]
] = ContextVar("bedrock_anthropic_callback", default=None)
wandb_tracing_callback_var: ContextVar[Optional[WandbTracer]] = ContextVar(
    "tracing_wandb_callback", default=None
)
comet_tracing_callback_var: ContextVar[Optional[CometTracer]] = ContextVar(
    "tracing_comet_callback", default=None
)

# Register every variable as a configure hook. The two tracers additionally
# pass their class and the environment variable name associated with them.
register_configure_hook(openai_callback_var, True)
register_configure_hook(bedrock_anthropic_callback_var, True)
register_configure_hook(
    wandb_tracing_callback_var,
    True,
    WandbTracer,
    "LANGCHAIN_WANDB_TRACING",
)
register_configure_hook(
    comet_tracing_callback_var,
    True,
    CometTracer,
    "LANGCHAIN_COMET_TRACING",
)
|
||||
|
||||
|
||||
@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
    """Get the OpenAI callback handler in a context manager.
    which conveniently exposes token and cost information.

    The handler is stored in ``openai_callback_var`` for the duration of the
    ``with`` block and the variable is cleared again on exit, including when
    the block raises.

    Returns:
        OpenAICallbackHandler: The OpenAI callback handler.

    Example:
        >>> with get_openai_callback() as cb:
        ...     # Use the OpenAI callback handler
    """
    cb = OpenAICallbackHandler()
    openai_callback_var.set(cb)
    try:
        yield cb
    finally:
        # Without the finally, an exception in the caller's block would leave
        # the handler registered for all later calls in this context.
        openai_callback_var.set(None)
|
||||
|
||||
|
||||
@contextmanager
def get_bedrock_anthropic_callback() -> Generator[
    BedrockAnthropicTokenUsageCallbackHandler, None, None
]:
    """Get the Bedrock anthropic callback handler in a context manager.
    which conveniently exposes token and cost information.

    The handler is stored in ``bedrock_anthropic_callback_var`` for the
    duration of the ``with`` block and the variable is cleared again on exit,
    including when the block raises.

    Returns:
        BedrockAnthropicTokenUsageCallbackHandler:
            The Bedrock anthropic callback handler.

    Example:
        >>> with get_bedrock_anthropic_callback() as cb:
        ...     # Use the Bedrock anthropic callback handler
    """
    cb = BedrockAnthropicTokenUsageCallbackHandler()
    bedrock_anthropic_callback_var.set(cb)
    try:
        yield cb
    finally:
        # Without the finally, an exception in the caller's block would leave
        # the handler registered for all later calls in this context.
        bedrock_anthropic_callback_var.set(None)
|
||||
|
||||
|
||||
@contextmanager
def wandb_tracing_enabled(
    session_name: str = "default",
) -> Generator[None, None, None]:
    """Enable the WandbTracer for the duration of a ``with`` block.

    A new ``WandbTracer`` is stored in ``wandb_tracing_callback_var`` on
    entry and the variable is cleared again on exit, including when the
    block raises.

    Args:
        session_name (str, optional): The name of the session.
            Defaults to "default". Currently not used by this function.

    Returns:
        None: the context manager yields ``None``, so there is nothing to
        bind with ``as``.

    Example:
        >>> with wandb_tracing_enabled():
        ...     # Runs inside this block are traced with WandbTracer
    """
    cb = WandbTracer()
    wandb_tracing_callback_var.set(cb)
    try:
        yield None
    finally:
        # Without the finally, an exception in the caller's block would leave
        # tracing enabled for all later calls in this context.
        wandb_tracing_callback_var.set(None)
|
||||
@@ -0,0 +1,769 @@
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import tempfile
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.utils import get_from_dict_or_env, guard_import
|
||||
|
||||
from langchain_community.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
flatten_dict,
|
||||
hash_string,
|
||||
import_pandas,
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def import_mlflow() -> Any:
    """Return the ``mlflow`` package, loaded through ``guard_import``."""
    mlflow = guard_import("mlflow")
    return mlflow
|
||||
|
||||
|
||||
def mlflow_callback_metrics() -> List[str]:
    """Return the names of the counters that are logged to MLflow.

    A new list is built on every call, so callers may mutate the result.
    """
    general = ("step", "starts", "ends", "errors", "text_ctr")
    chain = ("chain_starts", "chain_ends")
    llm = ("llm_starts", "llm_ends", "llm_streams")
    tool = ("tool_starts", "tool_ends")
    agent = ("agent_ends",)
    retriever = ("retriever_starts", "retriever_ends")
    return [*general, *chain, *llm, *tool, *agent, *retriever]
|
||||
|
||||
|
||||
def get_text_complexity_metrics() -> List[str]:
    """Return the names of the textstat functions used as complexity metrics.

    A new list is built on every call. ``text_standard`` is deliberately
    left out.
    """
    metric_names = (
        "flesch_reading_ease",
        "flesch_kincaid_grade",
        "smog_index",
        "coleman_liau_index",
        "automated_readability_index",
        "dale_chall_readability_score",
        "difficult_words",
        "linsear_write_formula",
        "gunning_fog",
        # "text_standard"
        "fernandez_huerta",
        "szigriszt_pazos",
        "gutierrez_polini",
        "crawford",
        "gulpease_index",
        "osman",
    )
    return list(metric_names)
|
||||
|
||||
|
||||
def analyze_text(
    text: str,
    nlp: Any = None,
    textstat: Any = None,
) -> dict:
    """Analyze text using textstat and spacy.

    Parameters:
        text (str): The text to analyze.
        nlp (spacy.lang): The spacy language model to use for visualization.
            Visualizations are skipped when this is None.
        textstat: The textstat library to use for complexity metrics
            calculation. Metrics are skipped when this is None.

    Returns:
        `dict` with the complexity metrics (both nested under
        ``text_complexity_metrics`` and flattened at the top level) and the
        ``dependency_tree`` / ``entities`` visualizations rendered by
        displacy. Empty when both ``nlp`` and ``textstat`` are None.
    """
    resp: Dict[str, Any] = {}

    if textstat is not None:
        text_complexity_metrics = {}
        for metric_name in get_text_complexity_metrics():
            metric_fn = getattr(textstat, metric_name)
            text_complexity_metrics[metric_name] = metric_fn(text)
        resp["text_complexity_metrics"] = text_complexity_metrics
        resp.update(text_complexity_metrics)

    if nlp is not None:
        spacy = import_spacy()
        doc = nlp(text)
        render = spacy.displacy.render
        resp["dependency_tree"] = render(doc, style="dep", jupyter=False, page=True)
        resp["entities"] = render(doc, style="ent", jupyter=False, page=True)

    return resp
|
||||
|
||||
|
||||
def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
    """Construct an html element from a prompt and a generation.

    Both texts are HTML-escaped before being embedded, so markup contained
    in a prompt or in model output is shown literally instead of being
    interpreted. Newlines are then turned into ``<br>`` tags.

    Parameters:
        prompt (str): The prompt.
        generation (str): The generation.

    Returns:
        (str): The html string."""
    from html import escape

    # Escape first, then insert the <br> tags, so the tags survive escaping.
    formatted_prompt = escape(prompt).replace("\n", "<br>")
    formatted_generation = escape(generation).replace("\n", "<br>")

    return f"""
    <p style="color:black;">{formatted_prompt}:</p>
    <blockquote>
      <p style="color:green;">
        {formatted_generation}
      </p>
    </blockquote>
    """
|
||||
|
||||
|
||||
class MlflowLogger:
    """Helper that logs metrics and artifacts to an mlflow server.

    All configuration is passed as keyword arguments:

    Keyword Args:
        tracking_uri (str): MLflow tracking server uri. Looked up with
            ``get_from_dict_or_env`` (env var ``MLFLOW_TRACKING_URI``,
            default ``""``). Not used when ``DATABRICKS_RUNTIME_VERSION`` is
            set; the tracking uri is then ``"databricks"``.
        experiment_name (str): Name of the experiment (env var
            ``MLFLOW_EXPERIMENT_NAME``). Only used when no ``run_id`` is
            given; the experiment is created if it does not exist.
        run_name (str): Name of the run. Required key.
        run_tags (dict): Tags to be attached for the run. Required key.
        run_id (str, optional): Id of an existing run to reuse.
        artifacts_dir (str, optional): Directory prefix for logged
            artifacts. Defaults to ``""``.

    This class implements the helper functions to initialize,
    log metrics and artifacts to the mlflow server.
    """

    def __init__(self, **kwargs: Any):
        self.mlflow = import_mlflow()
        if "DATABRICKS_RUNTIME_VERSION" in os.environ:
            self.mlflow.set_tracking_uri("databricks")
            self.mlf_expid = self.mlflow.tracking.fluent._get_experiment_id()
            self.mlf_exp = self.mlflow.get_experiment(self.mlf_expid)
        else:
            tracking_uri = get_from_dict_or_env(
                kwargs, "tracking_uri", "MLFLOW_TRACKING_URI", ""
            )
            self.mlflow.set_tracking_uri(tracking_uri)

            if run_id := kwargs.get("run_id"):
                # Reusing a run: take the experiment from the run itself.
                self.mlf_expid = self.mlflow.get_run(run_id).info.experiment_id
            else:
                # User can set other env variables described here
                # > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server

                experiment_name = get_from_dict_or_env(
                    kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
                )
                self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
                if self.mlf_exp is not None:
                    self.mlf_expid = self.mlf_exp.experiment_id
                else:
                    self.mlf_expid = self.mlflow.create_experiment(experiment_name)

        self.start_run(
            kwargs["run_name"], kwargs["run_tags"], kwargs.get("run_id", None)
        )
        self.dir = kwargs.get("artifacts_dir", "")

    def start_run(
        self, name: str, tags: Dict[str, str], run_id: Optional[str] = None
    ) -> None:
        """
        If run_id is provided, it will reuse the run with the given run_id.
        Otherwise, it starts a new run, auto generates the random suffix for name.

        The suffix (7 random uppercase letters/digits) replaces the trailing
        ``%`` and is only generated when ``name`` ends with ``"-%"``.
        """
        if run_id is None:
            # `name` may be None when the caller did not choose a run name.
            if name and name.endswith("-%"):
                rname = "".join(
                    random.choices(string.ascii_uppercase + string.digits, k=7)
                )
                name = name[:-1] + rname
            run = self.mlflow.MlflowClient().create_run(
                self.mlf_expid, run_name=name, tags=tags
            )
            run_id = run.info.run_id
        self.run_id = run_id

    def finish_run(self) -> None:
        """To finish the run."""
        # NOTE(review): `end_run()` is called without a run id; confirm it
        # also terminates the run created through `MlflowClient().create_run`.
        self.mlflow.end_run()

    def metric(self, key: str, value: float) -> None:
        """To log metric to mlflow server."""
        self.mlflow.log_metric(key, value, run_id=self.run_id)

    def metrics(
        self, data: Union[Dict[str, float], Dict[str, int]], step: Optional[int] = 0
    ) -> None:
        """To log all metrics in the input dict at the given step."""
        # `step` used to be accepted but silently dropped.
        self.mlflow.log_metrics(data, step=step, run_id=self.run_id)

    def jsonf(self, data: Dict[str, Any], filename: str) -> None:
        """To log the input data as json file artifact."""
        self.mlflow.log_dict(
            data, os.path.join(self.dir, f"{filename}.json"), run_id=self.run_id
        )

    def table(self, name: str, dataframe: Any) -> None:
        """To log the input pandas dataframe as a html table"""
        self.html(dataframe.to_html(), f"table_{name}")

    def html(self, html: str, filename: str) -> None:
        """To log the input html string as html file artifact."""
        self.mlflow.log_text(
            html, os.path.join(self.dir, f"{filename}.html"), run_id=self.run_id
        )

    def text(self, text: str, filename: str) -> None:
        """To log the input text as text file artifact."""
        self.mlflow.log_text(
            text, os.path.join(self.dir, f"{filename}.txt"), run_id=self.run_id
        )

    def artifact(self, path: str) -> None:
        """To upload the file from given path as artifact."""
        self.mlflow.log_artifact(path, run_id=self.run_id)

    def langchain_artifact(self, chain: Any) -> None:
        """To log the chain as a model under the name "langchain-model"."""
        self.mlflow.langchain.log_model(chain, "langchain-model", run_id=self.run_id)
|
||||
|
||||
|
||||
class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
||||
"""Callback Handler that logs metrics and artifacts to mlflow server.
|
||||
|
||||
Parameters:
|
||||
name (str): Name of the run.
|
||||
experiment (str): Name of the experiment.
|
||||
tags (dict): Tags to be attached for the run.
|
||||
tracking_uri (str): MLflow tracking server uri.
|
||||
|
||||
This handler will utilize the associated callback method called and formats
|
||||
the input of each callback function with metadata regarding the state of LLM run,
|
||||
and adds the response to the list of records for both the {method}_records and
|
||||
action. It then logs the response to mlflow server.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    name: Optional[str] = "langchainrun-%",
    experiment: Optional[str] = "langchain",
    tags: Optional[Dict] = None,
    tracking_uri: Optional[str] = None,
    run_id: Optional[str] = None,
    artifacts_dir: str = "",
) -> None:
    """Set up the MLflow logger, optional text analysers and empty state.

    Args:
        name: Run name passed to ``MlflowLogger``.
        experiment: Experiment name passed to ``MlflowLogger``.
        tags: Run tags; ``None`` is replaced by an empty dict.
        tracking_uri: Tracking URI passed to ``MlflowLogger``.
        run_id: Run id passed to ``MlflowLogger``.
        artifacts_dir: Artifacts directory passed to ``MlflowLogger``.
    """
    # Called only for their effect; the return values are discarded here.
    # NOTE(review): presumably these raise when the package is missing.
    import_pandas()
    import_mlflow()
    super().__init__()

    self.name = name
    self.experiment = experiment
    self.tags = tags or {}
    self.tracking_uri = tracking_uri
    self.run_id = run_id
    self.artifacts_dir = artifacts_dir

    # Scratch directory; flush_tracker writes ``model.json`` into it.
    self.temp_dir = tempfile.TemporaryDirectory()

    self.mlflg = MlflowLogger(
        tracking_uri=self.tracking_uri,
        experiment_name=self.experiment,
        run_name=self.name,
        run_tags=self.tags,
        run_id=self.run_id,
        artifacts_dir=self.artifacts_dir,
    )

    self.action_records: list = []

    # spaCy is optional: a missing package or model only produces a warning
    # and leaves ``self.nlp`` as None.
    self.nlp = None
    try:
        spacy_module = import_spacy()
    except ImportError as exc:
        logger.warning(exc.msg)
    else:
        try:
            self.nlp = spacy_module.load("en_core_web_sm")
        except OSError:
            logger.warning(
                "Run `python -m spacy download en_core_web_sm` "
                "to download en_core_web_sm model for text visualization."
            )

    # textstat is optional as well.
    try:
        self.textstat = import_textstat()
    except ImportError as exc:
        logger.warning(exc.msg)
        self.textstat = None

    self.metrics = dict.fromkeys(mlflow_callback_metrics(), 0)

    record_keys = (
        "on_llm_start_records",
        "on_llm_token_records",
        "on_llm_end_records",
        "on_chain_start_records",
        "on_chain_end_records",
        "on_tool_start_records",
        "on_tool_end_records",
        "on_text_records",
        "on_agent_finish_records",
        "on_agent_action_records",
        "on_retriever_start_records",
        "on_retriever_end_records",
        "action_records",
    )
    # One independent list per key.
    self.records: Dict[str, Any] = {record_key: [] for record_key in record_keys}
|
||||
|
||||
def _reset(self) -> None:
|
||||
for k, v in self.metrics.items():
|
||||
self.metrics[k] = 0
|
||||
for k, v in self.records.items():
|
||||
self.records[k] = []
|
||||
|
||||
def on_llm_start(
    self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
    """Record the start of an LLM call, one record per prompt.

    Args:
        serialized: Serialized LLM description; flattened into each record.
        prompts: Prompts sent to the LLM.
        **kwargs: Ignored.
    """
    for counter in ("step", "llm_starts", "starts"):
        self.metrics[counter] += 1
    run_number = self.metrics["llm_starts"]

    base_record: Dict[str, Any] = {"action": "on_llm_start"}
    base_record.update(flatten_dict(serialized))
    base_record.update(self.metrics)

    self.mlflg.metrics(self.metrics, step=self.metrics["step"])

    for prompt_index, prompt in enumerate(prompts):
        # Each prompt gets its own copy so the records do not share state.
        record = deepcopy(base_record)
        record["prompt"] = prompt
        self.records["on_llm_start_records"].append(record)
        self.records["action_records"].append(record)
        self.mlflg.jsonf(record, f"llm_start_{run_number}_prompt_{prompt_index}")
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
    """Record a streamed token.

    Args:
        token: The token that was produced.
        **kwargs: Ignored.
    """
    counters = self.metrics
    counters["step"] += 1
    counters["llm_streams"] += 1
    stream_number = counters["llm_streams"]

    record: Dict[str, Any] = {"action": "on_llm_new_token", "token": token}
    record.update(counters)

    self.mlflg.metrics(counters, step=counters["step"])

    for bucket in ("on_llm_token_records", "action_records"):
        self.records[bucket].append(record)
    self.mlflg.jsonf(record, f"llm_new_tokens_{stream_number}")
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
    """Record the end of an LLM call, one record per generation.

    Each generation record combines the flattened ``response.llm_output``,
    the current counters, the flattened generation and whatever
    ``analyze_text`` returns for the generation text. The record is appended
    to ``on_llm_end_records`` and ``action_records`` and logged as JSON.

    Args:
        response: Result of the LLM call.
        **kwargs: Ignored.
    """
    self.metrics["step"] += 1
    self.metrics["llm_ends"] += 1
    self.metrics["ends"] += 1

    llm_ends = self.metrics["llm_ends"]

    resp: Dict[str, Any] = {"action": "on_llm_end"}
    resp.update(flatten_dict(response.llm_output or {}))
    resp.update(self.metrics)

    self.mlflg.metrics(self.metrics, step=self.metrics["step"])

    for generations in response.generations:
        # NOTE(review): ``idx`` restarts for every inner list, so artifact
        # names repeat when ``response.generations`` has several entries.
        for idx, generation in enumerate(generations):
            generation_resp = deepcopy(resp)
            generation_resp.update(flatten_dict(generation.dict()))
            generation_resp.update(
                analyze_text(
                    generation.text,
                    nlp=self.nlp,
                    textstat=self.textstat,
                )
            )
            if "text_complexity_metrics" in generation_resp:
                # Complexity scores are logged as metrics, not kept in the
                # record.
                complexity_metrics: Dict[str, float] = generation_resp.pop(
                    "text_complexity_metrics"
                )
                self.mlflg.metrics(
                    complexity_metrics,
                    step=self.metrics["step"],
                )
            self.records["on_llm_end_records"].append(generation_resp)
            self.records["action_records"].append(generation_resp)
            # Log the per-generation record. Logging the shared ``resp`` here
            # wrote the same generation-less payload to every file.
            self.mlflg.jsonf(
                generation_resp, f"llm_end_{llm_ends}_generation_{idx}"
            )
            if "dependency_tree" in generation_resp:
                dependency_tree = generation_resp["dependency_tree"]
                self.mlflg.html(
                    dependency_tree, "dep-" + hash_string(generation.text)
                )
            if "entities" in generation_resp:
                entities = generation_resp["entities"]
                self.mlflg.html(entities, "ent-" + hash_string(generation.text))
|
||||
|
||||
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
    """Count an LLM error; the error itself is not recorded.

    Args:
        error: The raised exception (unused).
        **kwargs: Ignored.
    """
    counters = self.metrics
    counters["step"] += 1
    counters["errors"] += 1
|
||||
|
||||
def on_chain_start(
    self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
    """Record the start of a chain run.

    Args:
        serialized: Serialized chain description; flattened into the record.
        inputs: Chain inputs. A dict becomes ``k=v`` pairs joined by commas, a
            list becomes its items joined by commas, anything else ``str()``.
        **kwargs: Ignored.
    """
    for counter in ("step", "chain_starts", "starts"):
        self.metrics[counter] += 1
    run_number = self.metrics["chain_starts"]

    base_record: Dict[str, Any] = {"action": "on_chain_start"}
    base_record.update(flatten_dict(serialized))
    base_record.update(self.metrics)

    self.mlflg.metrics(self.metrics, step=self.metrics["step"])

    if isinstance(inputs, dict):
        chain_input = ",".join(f"{key}={value}" for key, value in inputs.items())
    elif isinstance(inputs, list):
        chain_input = ",".join(map(str, inputs))
    else:
        chain_input = str(inputs)

    record = deepcopy(base_record)
    record["inputs"] = chain_input
    for bucket in ("on_chain_start_records", "action_records"):
        self.records[bucket].append(record)
    self.mlflg.jsonf(record, f"chain_start_{run_number}")
|
||||
|
||||
def on_chain_end(
    self, outputs: Union[Dict[str, Any], str, List[str]], **kwargs: Any
) -> None:
    """Record the end of a chain run.

    Args:
        outputs: Chain outputs. A dict becomes ``k=v`` pairs joined by commas,
            a list becomes its items joined by commas, anything else ``str()``.
        **kwargs: Ignored.
    """
    counters = self.metrics
    counters["step"] += 1
    counters["chain_ends"] += 1
    counters["ends"] += 1
    run_number = counters["chain_ends"]

    if isinstance(outputs, dict):
        chain_output = ",".join(f"{key}={value}" for key, value in outputs.items())
    elif isinstance(outputs, list):
        chain_output = ",".join(str(item) for item in outputs)
    else:
        chain_output = str(outputs)

    record: Dict[str, Any] = {"action": "on_chain_end", "outputs": chain_output}
    record.update(counters)

    self.mlflg.metrics(counters, step=counters["step"])

    for bucket in ("on_chain_end_records", "action_records"):
        self.records[bucket].append(record)
    self.mlflg.jsonf(record, f"chain_end_{run_number}")
|
||||
|
||||
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
    """Count a chain error; the error itself is not recorded.

    Args:
        error: The raised exception (unused).
        **kwargs: Ignored.
    """
    counters = self.metrics
    counters["step"] += 1
    counters["errors"] += 1
|
||||
|
||||
def on_tool_start(
    self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
    """Record the start of a tool run.

    Args:
        serialized: Serialized tool description; flattened into the record.
        input_str: Input passed to the tool.
        **kwargs: Ignored.
    """
    for counter in ("step", "tool_starts", "starts"):
        self.metrics[counter] += 1
    run_number = self.metrics["tool_starts"]

    record: Dict[str, Any] = {"action": "on_tool_start", "input_str": input_str}
    record.update(flatten_dict(serialized))
    record.update(self.metrics)

    self.mlflg.metrics(self.metrics, step=self.metrics["step"])

    for bucket in ("on_tool_start_records", "action_records"):
        self.records[bucket].append(record)
    self.mlflg.jsonf(record, f"tool_start_{run_number}")
|
||||
|
||||
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
    """Record the end of a tool run.

    Args:
        output: Tool output; converted with ``str()`` before being stored.
        **kwargs: Ignored.
    """
    output_text = str(output)

    counters = self.metrics
    counters["step"] += 1
    counters["tool_ends"] += 1
    counters["ends"] += 1
    run_number = counters["tool_ends"]

    record: Dict[str, Any] = {"action": "on_tool_end", "output": output_text}
    record.update(counters)

    self.mlflg.metrics(counters, step=counters["step"])

    for bucket in ("on_tool_end_records", "action_records"):
        self.records[bucket].append(record)
    self.mlflg.jsonf(record, f"tool_end_{run_number}")
|
||||
|
||||
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
    """Count a tool error; the error itself is not recorded.

    Args:
        error: The raised exception (unused).
        **kwargs: Ignored.
    """
    counters = self.metrics
    counters["step"] += 1
    counters["errors"] += 1
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
    """Record a piece of text received by the handler.

    Args:
        text: The received text.
        **kwargs: Ignored.
    """
    counters = self.metrics
    counters["step"] += 1
    counters["text_ctr"] += 1
    text_number = counters["text_ctr"]

    record: Dict[str, Any] = {"action": "on_text", "text": text}
    record.update(counters)

    self.mlflg.metrics(counters, step=counters["step"])

    for bucket in ("on_text_records", "action_records"):
        self.records[bucket].append(record)
    self.mlflg.jsonf(record, f"on_text_{text_number}")
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
    """Record the end of an agent run.

    Args:
        finish: Final agent result; ``finish.return_values["output"]`` and
            ``finish.log`` are stored in the record.
        **kwargs: Ignored.
    """
    for counter in ("step", "agent_ends", "ends"):
        self.metrics[counter] += 1
    run_number = self.metrics["agent_ends"]

    record: Dict[str, Any] = {
        "action": "on_agent_finish",
        "output": finish.return_values["output"],
        "log": finish.log,
    }
    record.update(self.metrics)

    self.mlflg.metrics(self.metrics, step=self.metrics["step"])

    for bucket in ("on_agent_finish_records", "action_records"):
        self.records[bucket].append(record)
    self.mlflg.jsonf(record, f"agent_finish_{run_number}")
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
    """Record an agent action.

    This callback bumps the ``tool_starts`` counter and uses it to number the
    JSON artifact.

    Args:
        action: Agent action; its ``tool``, ``tool_input`` and ``log`` are
            stored in the record.
        **kwargs: Ignored.
    """
    for counter in ("step", "tool_starts", "starts"):
        self.metrics[counter] += 1
    run_number = self.metrics["tool_starts"]

    record: Dict[str, Any] = {
        "action": "on_agent_action",
        "tool": action.tool,
        "tool_input": action.tool_input,
        "log": action.log,
    }
    record.update(self.metrics)

    self.mlflg.metrics(self.metrics, step=self.metrics["step"])
    for bucket in ("on_agent_action_records", "action_records"):
        self.records[bucket].append(record)
    self.mlflg.jsonf(record, f"agent_action_{run_number}")
|
||||
|
||||
def on_retriever_start(
    self,
    serialized: Dict[str, Any],
    query: str,
    **kwargs: Any,
) -> Any:
    """Record the start of a retriever run.

    Args:
        serialized: Serialized retriever description; flattened into the
            record.
        query: Query passed to the retriever.
        **kwargs: Ignored.
    """
    for counter in ("step", "retriever_starts", "starts"):
        self.metrics[counter] += 1
    run_number = self.metrics["retriever_starts"]

    record: Dict[str, Any] = {"action": "on_retriever_start", "query": query}
    record.update(flatten_dict(serialized))
    record.update(self.metrics)

    self.mlflg.metrics(self.metrics, step=self.metrics["step"])

    for bucket in ("on_retriever_start_records", "action_records"):
        self.records[bucket].append(record)
    self.mlflg.jsonf(record, f"retriever_start_{run_number}")
|
||||
|
||||
def on_retriever_end(
    self,
    documents: Sequence[Document],
    **kwargs: Any,
) -> Any:
    """Record the end of a retriever run.

    Each document is stored as its ``page_content`` plus a metadata dict whose
    values are stringified. List values are joined with commas.

    Args:
        documents: Documents returned by the retriever.
        **kwargs: Ignored.
    """
    for counter in ("step", "retriever_ends", "ends"):
        self.metrics[counter] += 1
    run_number = self.metrics["retriever_ends"]

    def _as_text(value: Any) -> str:
        # Lists are flattened to a comma-separated string, the rest use str().
        if isinstance(value, list):
            return ",".join(str(item) for item in value)
        return str(value)

    retriever_documents = []
    for doc in documents:
        metadata = {key: _as_text(value) for key, value in doc.metadata.items()}
        retriever_documents.append(
            {"page_content": doc.page_content, "metadata": metadata}
        )

    record: Dict[str, Any] = {
        "action": "on_retriever_end",
        "documents": retriever_documents,
    }
    record.update(self.metrics)

    self.mlflg.metrics(self.metrics, step=self.metrics["step"])

    for bucket in ("on_retriever_end_records", "action_records"):
        self.records[bucket].append(record)
    self.mlflg.jsonf(record, f"retriever_end_{run_number}")
|
||||
|
||||
def on_retriever_error(self, error: BaseException, **kwargs: Any) -> Any:
    """Count a retriever error; the error itself is not recorded.

    Args:
        error: The raised exception (unused).
        **kwargs: Ignored.
    """
    counters = self.metrics
    counters["step"] += 1
    counters["errors"] += 1
|
||||
|
||||
def _create_session_analysis_df(self) -> Any:
    """Build one dataframe pairing LLM prompts with their outputs.

    Prompt columns come from ``on_llm_start_records`` and output columns from
    ``on_llm_end_records``; the two frames are concatenated column-wise. A
    ``chat_html`` column is added from each prompt/output pair.

    Returns:
        The combined pandas dataframe.
    """
    pd = import_pandas()
    starts_df = pd.DataFrame(self.records["on_llm_start_records"])
    ends_df = pd.DataFrame(self.records["on_llm_end_records"])

    input_columns = ["step", "prompt"]
    if "name" in starts_df.columns:
        input_columns.append("name")
    elif "id" in starts_df.columns:
        # id is llm class's full import path. For example:
        # ["langchain", "llms", "openai", "AzureOpenAI"]
        starts_df["name"] = starts_df["id"].apply(lambda id_: id_[-1])
        input_columns.append("name")

    prompts_df = starts_df[input_columns]
    prompts_df = prompts_df.dropna(axis=1)
    prompts_df = prompts_df.rename({"step": "prompt_step"}, axis=1)

    # Optional columns exist only when the matching analyser was loaded.
    if self.textstat is not None:
        complexity_columns = get_text_complexity_metrics()
    else:
        complexity_columns = []
    if self.nlp is not None:
        visualization_columns = ["dependency_tree", "entities"]
    else:
        visualization_columns = []

    candidate_token_columns = [
        "token_usage_total_tokens",
        "token_usage_prompt_tokens",
        "token_usage_completion_tokens",
    ]
    token_usage_columns = [
        column for column in candidate_token_columns if column in ends_df.columns
    ]

    output_columns = (
        ["step", "text"]
        + token_usage_columns
        + complexity_columns
        + visualization_columns
    )
    outputs_df = ends_df[output_columns]
    outputs_df = outputs_df.dropna(axis=1)
    outputs_df = outputs_df.rename({"step": "output_step", "text": "output"}, axis=1)

    session_analysis_df = pd.concat([prompts_df, outputs_df], axis=1)

    def _render_chat(row: Any) -> Any:
        return construct_html_from_prompt_and_generation(
            row["prompt"], row["output"]
        )

    session_analysis_df["chat_html"] = session_analysis_df[
        ["prompt", "output"]
    ].apply(_render_chat, axis=1)
    return session_analysis_df
|
||||
|
||||
def _contain_llm_records(self) -> bool:
|
||||
return bool(self.records["on_llm_start_records"])
|
||||
|
||||
def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
    """Upload the accumulated records (and optionally a model), then reset.

    Args:
        langchain_asset: Optional object to store with the run. An object
            whose type string contains ``langchain.chains.llm.LLMChain`` is
            logged via ``self.mlflg.langchain_artifact``. Anything else is
            saved to ``model.json`` in the handler's temp dir and uploaded as
            a file.
        finish: When true, ``self.mlflg.finish_run()`` is called before the
            handler state is reset.
    """
    pd = import_pandas()
    action_table = pd.DataFrame(self.records["action_records"])
    self.mlflg.table("action_records", action_table)

    if self._contain_llm_records():
        analysis_df = self._create_session_analysis_df()
        chat_fragments = analysis_df.pop("chat_html")
        chat_fragments = chat_fragments.replace("\n", "", regex=True)
        self.mlflg.table("session_analysis", pd.DataFrame(analysis_df))
        self.mlflg.html("".join(chat_fragments.tolist()), "chat_html")

    if langchain_asset:
        # To avoid circular import error
        # mlflow only supports LLMChain asset
        is_llm_chain = "langchain.chains.llm.LLMChain" in str(type(langchain_asset))
        if is_llm_chain:
            self.mlflg.langchain_artifact(langchain_asset)
        else:
            model_path = str(Path(self.temp_dir.name, "model.json"))
            try:
                langchain_asset.save(model_path)
                self.mlflg.artifact(model_path)
            except ValueError:
                # ``save`` refused the asset; try the agent-specific saver.
                try:
                    langchain_asset.save_agent(model_path)
                    self.mlflg.artifact(model_path)
                except (AttributeError, NotImplementedError):
                    print("Could not save model.")  # noqa: T201
                    traceback.print_exc()
            except NotImplementedError:
                print("Could not save model.")  # noqa: T201
                traceback.print_exc()

    if finish:
        self.mlflg.finish_run()
    self._reset()
|
||||
@@ -0,0 +1,555 @@
|
||||
"""Callback Handler that prints to std out."""
|
||||
|
||||
import threading
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain_core._api import warn_deprecated
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
|
||||
# Price in USD per 1,000 tokens, keyed by the name that
# ``standardize_model_name`` returns. A bare model name holds the prompt
# (input) price, ``<name>-cached`` the cached-prompt price and
# ``<name>-completion`` the output price.
# ``get_openai_token_cost_for_model`` multiplies the looked-up value by
# ``num_tokens / 1000``.
MODEL_COST_PER_1K_TOKENS = {
    # GPT-5 input
    "gpt-5": 0.00125,
    "gpt-5-cached": 0.000125,
    "gpt-5-2025-08-07": 0.00125,
    "gpt-5-2025-08-07-cached": 0.000125,
    # GPT-5 output
    "gpt-5-completion": 0.01,
    "gpt-5-2025-08-07-completion": 0.01,
    # GPT-5-mini input
    "gpt-5-mini": 0.00025,
    "gpt-5-mini-cached": 0.000025,
    "gpt-5-mini-2025-08-07": 0.00025,
    "gpt-5-mini-2025-08-07-cached": 0.000025,
    # GPT-5-mini output
    "gpt-5-mini-completion": 0.002,
    "gpt-5-mini-2025-08-07-completion": 0.002,
    # GPT-5-nano input
    "gpt-5-nano": 0.00005,
    "gpt-5-nano-cached": 0.000005,
    "gpt-5-nano-2025-08-07": 0.00005,
    "gpt-5-nano-2025-08-07-cached": 0.000005,
    # GPT-5-nano output
    "gpt-5-nano-completion": 0.0004,
    "gpt-5-nano-2025-08-07-completion": 0.0004,
    # GPT-5-chat-latest input
    "gpt-5-chat-latest": 0.00125,
    "gpt-5-chat-latest-cached": 0.000125,
    "gpt-5-chat-latest-2025-08-07": 0.00125,
    "gpt-5-chat-latest-2025-08-07-cached": 0.000125,
    # GPT-5-chat-latest output
    "gpt-5-chat-latest-completion": 0.01,
    "gpt-5-chat-latest-2025-08-07-completion": 0.01,
    # GPT-4.1 input
    "gpt-4.1": 0.002,
    "gpt-4.1-2025-04-14": 0.002,
    "gpt-4.1-cached": 0.0005,
    "gpt-4.1-2025-04-14-cached": 0.0005,
    # GPT-4.1 output
    "gpt-4.1-completion": 0.008,
    "gpt-4.1-2025-04-14-completion": 0.008,
    # GPT-4.1-mini input
    "gpt-4.1-mini": 0.0004,
    "gpt-4.1-mini-2025-04-14": 0.0004,
    "gpt-4.1-mini-cached": 0.0001,
    "gpt-4.1-mini-2025-04-14-cached": 0.0001,
    # GPT-4.1-mini output
    "gpt-4.1-mini-completion": 0.0016,
    "gpt-4.1-mini-2025-04-14-completion": 0.0016,
    # GPT-4.1-nano input
    "gpt-4.1-nano": 0.0001,
    "gpt-4.1-nano-2025-04-14": 0.0001,
    "gpt-4.1-nano-cached": 0.000025,
    "gpt-4.1-nano-2025-04-14-cached": 0.000025,
    # GPT-4.1-nano output
    "gpt-4.1-nano-completion": 0.0004,
    "gpt-4.1-nano-2025-04-14-completion": 0.0004,
    # GPT-4.5-preview input
    "gpt-4.5-preview": 0.075,
    "gpt-4.5-preview-2025-02-27": 0.075,
    "gpt-4.5-preview-cached": 0.0375,
    "gpt-4.5-preview-2025-02-27-cached": 0.0375,
    # GPT-4.5-preview output
    "gpt-4.5-preview-completion": 0.15,
    "gpt-4.5-preview-2025-02-27-completion": 0.15,
    # OpenAI o1 input
    "o1": 0.015,
    "o1-2024-12-17": 0.015,
    "o1-cached": 0.0075,
    "o1-2024-12-17-cached": 0.0075,
    # OpenAI o1 output
    "o1-completion": 0.06,
    "o1-2024-12-17-completion": 0.06,
    # OpenAI o1-pro input
    "o1-pro": 0.15,
    "o1-pro-2025-03-19": 0.15,
    # OpenAI o1-pro output
    "o1-pro-completion": 0.6,
    "o1-pro-2025-03-19-completion": 0.6,
    # OpenAI o3 input
    "o3": 0.002,
    "o3-2025-04-16": 0.002,
    "o3-cached": 0.0005,
    "o3-2025-04-16-cached": 0.0005,
    # OpenAI o3 output
    "o3-completion": 0.008,
    "o3-2025-04-16-completion": 0.008,
    # OpenAI o4-mini input
    "o4-mini": 0.0011,
    "o4-mini-2025-04-16": 0.0011,
    "o4-mini-cached": 0.000275,
    "o4-mini-2025-04-16-cached": 0.000275,
    # OpenAI o4-mini output
    "o4-mini-completion": 0.0044,
    "o4-mini-2025-04-16-completion": 0.0044,
    # OpenAI o3-mini input
    "o3-mini": 0.0011,
    "o3-mini-2025-01-31": 0.0011,
    "o3-mini-cached": 0.00055,
    "o3-mini-2025-01-31-cached": 0.00055,
    # OpenAI o3-mini output
    "o3-mini-completion": 0.0044,
    "o3-mini-2025-01-31-completion": 0.0044,
    # OpenAI o1-mini input (updated pricing)
    "o1-mini": 0.0011,
    "o1-mini-cached": 0.00055,
    "o1-mini-2024-09-12": 0.0011,
    "o1-mini-2024-09-12-cached": 0.00055,
    # OpenAI o1-mini output (updated pricing)
    "o1-mini-completion": 0.0044,
    "o1-mini-2024-09-12-completion": 0.0044,
    # OpenAI o1-preview input
    "o1-preview": 0.015,
    "o1-preview-cached": 0.0075,
    "o1-preview-2024-09-12": 0.015,
    "o1-preview-2024-09-12-cached": 0.0075,
    # OpenAI o1-preview output
    "o1-preview-completion": 0.06,
    "o1-preview-2024-09-12-completion": 0.06,
    # GPT-4o input
    "gpt-4o": 0.0025,
    "gpt-4o-cached": 0.00125,
    "gpt-4o-2024-05-13": 0.005,
    "gpt-4o-2024-08-06": 0.0025,
    "gpt-4o-2024-08-06-cached": 0.00125,
    "gpt-4o-2024-11-20": 0.0025,
    "gpt-4o-2024-11-20-cached": 0.00125,
    # GPT-4o output
    "gpt-4o-completion": 0.01,
    "gpt-4o-2024-05-13-completion": 0.015,
    "gpt-4o-2024-08-06-completion": 0.01,
    "gpt-4o-2024-11-20-completion": 0.01,
    # GPT-4o-audio-preview input
    "gpt-4o-audio-preview": 0.0025,
    "gpt-4o-audio-preview-2024-12-17": 0.0025,
    "gpt-4o-audio-preview-2024-10-01": 0.0025,
    # GPT-4o-audio-preview output
    "gpt-4o-audio-preview-completion": 0.01,
    "gpt-4o-audio-preview-2024-12-17-completion": 0.01,
    "gpt-4o-audio-preview-2024-10-01-completion": 0.01,
    # GPT-4o-realtime-preview input
    "gpt-4o-realtime-preview": 0.005,
    "gpt-4o-realtime-preview-2024-12-17": 0.005,
    "gpt-4o-realtime-preview-2024-10-01": 0.005,
    "gpt-4o-realtime-preview-cached": 0.0025,
    "gpt-4o-realtime-preview-2024-12-17-cached": 0.0025,
    "gpt-4o-realtime-preview-2024-10-01-cached": 0.0025,
    # GPT-4o-realtime-preview output
    "gpt-4o-realtime-preview-completion": 0.02,
    "gpt-4o-realtime-preview-2024-12-17-completion": 0.02,
    "gpt-4o-realtime-preview-2024-10-01-completion": 0.02,
    # GPT-4o-mini input
    "gpt-4o-mini": 0.00015,
    "gpt-4o-mini-cached": 0.000075,
    "gpt-4o-mini-2024-07-18": 0.00015,
    "gpt-4o-mini-2024-07-18-cached": 0.000075,
    # GPT-4o-mini output
    "gpt-4o-mini-completion": 0.0006,
    "gpt-4o-mini-2024-07-18-completion": 0.0006,
    # GPT-4o-mini-audio-preview input
    "gpt-4o-mini-audio-preview": 0.00015,
    "gpt-4o-mini-audio-preview-2024-12-17": 0.00015,
    # GPT-4o-mini-audio-preview output
    "gpt-4o-mini-audio-preview-completion": 0.0006,
    "gpt-4o-mini-audio-preview-2024-12-17-completion": 0.0006,
    # GPT-4o-mini-realtime-preview input
    "gpt-4o-mini-realtime-preview": 0.0006,
    "gpt-4o-mini-realtime-preview-2024-12-17": 0.0006,
    "gpt-4o-mini-realtime-preview-cached": 0.0003,
    "gpt-4o-mini-realtime-preview-2024-12-17-cached": 0.0003,
    # GPT-4o-mini-realtime-preview output
    "gpt-4o-mini-realtime-preview-completion": 0.0024,
    "gpt-4o-mini-realtime-preview-2024-12-17-completion": 0.0024,
    # GPT-4o-mini-search-preview input
    "gpt-4o-mini-search-preview": 0.00015,
    "gpt-4o-mini-search-preview-2025-03-11": 0.00015,
    # GPT-4o-mini-search-preview output
    "gpt-4o-mini-search-preview-completion": 0.0006,
    "gpt-4o-mini-search-preview-2025-03-11-completion": 0.0006,
    # GPT-4o-search-preview input
    "gpt-4o-search-preview": 0.0025,
    "gpt-4o-search-preview-2025-03-11": 0.0025,
    # GPT-4o-search-preview output
    "gpt-4o-search-preview-completion": 0.01,
    "gpt-4o-search-preview-2025-03-11-completion": 0.01,
    # Computer-use-preview input
    "computer-use-preview": 0.003,
    "computer-use-preview-2025-03-11": 0.003,
    # Computer-use-preview output
    "computer-use-preview-completion": 0.012,
    "computer-use-preview-2025-03-11-completion": 0.012,
    # GPT-4 input
    "gpt-4": 0.03,
    "gpt-4-0314": 0.03,
    "gpt-4-0613": 0.03,
    "gpt-4-32k": 0.06,
    "gpt-4-32k-0314": 0.06,
    "gpt-4-32k-0613": 0.06,
    "gpt-4-vision-preview": 0.01,
    "gpt-4-1106-preview": 0.01,
    "gpt-4-0125-preview": 0.01,
    "gpt-4-turbo-preview": 0.01,
    "gpt-4-turbo": 0.01,
    "gpt-4-turbo-2024-04-09": 0.01,
    # GPT-4 output
    "gpt-4-completion": 0.06,
    "gpt-4-0314-completion": 0.06,
    "gpt-4-0613-completion": 0.06,
    "gpt-4-32k-completion": 0.12,
    "gpt-4-32k-0314-completion": 0.12,
    "gpt-4-32k-0613-completion": 0.12,
    "gpt-4-vision-preview-completion": 0.03,
    "gpt-4-1106-preview-completion": 0.03,
    "gpt-4-0125-preview-completion": 0.03,
    "gpt-4-turbo-preview-completion": 0.03,
    "gpt-4-turbo-completion": 0.03,
    "gpt-4-turbo-2024-04-09-completion": 0.03,
    # GPT-3.5 input
    # gpt-3.5-turbo points at gpt-3.5-turbo-0613 until Feb 16, 2024.
    # Switches to gpt-3.5-turbo-0125 after.
    "gpt-3.5-turbo": 0.0015,
    "gpt-3.5-turbo-0125": 0.0005,
    "gpt-3.5-turbo-0301": 0.0015,
    "gpt-3.5-turbo-0613": 0.0015,
    "gpt-3.5-turbo-1106": 0.001,
    "gpt-3.5-turbo-instruct": 0.0015,
    "gpt-3.5-turbo-16k": 0.003,
    "gpt-3.5-turbo-16k-0613": 0.003,
    # GPT-3.5 output
    # gpt-3.5-turbo points at gpt-3.5-turbo-0613 until Feb 16, 2024.
    # Switches to gpt-3.5-turbo-0125 after.
    "gpt-3.5-turbo-completion": 0.002,
    "gpt-3.5-turbo-0125-completion": 0.0015,
    "gpt-3.5-turbo-0301-completion": 0.002,
    "gpt-3.5-turbo-0613-completion": 0.002,
    "gpt-3.5-turbo-1106-completion": 0.002,
    "gpt-3.5-turbo-instruct-completion": 0.002,
    "gpt-3.5-turbo-16k-completion": 0.004,
    "gpt-3.5-turbo-16k-0613-completion": 0.004,
    # Azure GPT-35 input
    "gpt-35-turbo": 0.0015,  # Azure OpenAI version of ChatGPT
    "gpt-35-turbo-0125": 0.0005,
    "gpt-35-turbo-0301": 0.002,  # Azure OpenAI version of ChatGPT
    "gpt-35-turbo-0613": 0.0015,
    "gpt-35-turbo-instruct": 0.0015,
    "gpt-35-turbo-16k": 0.003,
    "gpt-35-turbo-16k-0613": 0.003,
    # Azure GPT-35 output
    "gpt-35-turbo-completion": 0.002,  # Azure OpenAI version of ChatGPT
    "gpt-35-turbo-0125-completion": 0.0015,
    "gpt-35-turbo-0301-completion": 0.002,  # Azure OpenAI version of ChatGPT
    "gpt-35-turbo-0613-completion": 0.002,
    "gpt-35-turbo-instruct-completion": 0.002,
    "gpt-35-turbo-16k-completion": 0.004,
    "gpt-35-turbo-16k-0613-completion": 0.004,
    # Others
    "text-ada-001": 0.0004,
    "ada": 0.0004,
    "text-babbage-001": 0.0005,
    "babbage": 0.0005,
    "text-curie-001": 0.002,
    "curie": 0.002,
    "text-davinci-003": 0.02,
    "text-davinci-002": 0.02,
    "code-davinci-002": 0.02,
    # Fine Tuned input
    "babbage-002-finetuned": 0.0016,
    "davinci-002-finetuned": 0.012,
    "gpt-3.5-turbo-0613-finetuned": 0.003,
    "gpt-3.5-turbo-1106-finetuned": 0.003,
    "gpt-3.5-turbo-0125-finetuned": 0.003,
    "gpt-4o-mini-2024-07-18-finetuned": 0.0003,
    "gpt-4o-mini-2024-07-18-finetuned-cached": 0.00015,
    # Fine Tuned output
    "babbage-002-finetuned-completion": 0.0016,
    "davinci-002-finetuned-completion": 0.012,
    "gpt-3.5-turbo-0613-finetuned-completion": 0.006,
    "gpt-3.5-turbo-1106-finetuned-completion": 0.006,
    "gpt-3.5-turbo-0125-finetuned-completion": 0.006,
    "gpt-4o-mini-2024-07-18-finetuned-completion": 0.0012,
    # Azure Fine Tuned input
    "babbage-002-azure-finetuned": 0.0004,
    "davinci-002-azure-finetuned": 0.002,
    "gpt-35-turbo-0613-azure-finetuned": 0.0015,
    # Azure Fine Tuned output
    "babbage-002-azure-finetuned-completion": 0.0004,
    "davinci-002-azure-finetuned-completion": 0.002,
    "gpt-35-turbo-0613-azure-finetuned-completion": 0.002,
    # Legacy fine-tuned models
    "ada-finetuned-legacy": 0.0016,
    "babbage-finetuned-legacy": 0.0024,
    "curie-finetuned-legacy": 0.012,
    "davinci-finetuned-legacy": 0.12,
}
|
||||
|
||||
|
||||
class TokenType(Enum):
    """Token type enum."""

    # Prompt (input) tokens: the model name is used as-is.
    PROMPT = auto()
    # Cached prompt tokens: priced under the ``-cached`` key suffix.
    PROMPT_CACHED = auto()
    # Output tokens: priced under the ``-completion`` key suffix.
    COMPLETION = auto()


def standardize_model_name(
    model_name: str,
    is_completion: bool = False,
    *,
    token_type: TokenType = TokenType.PROMPT,
) -> str:
    """
    Standardize the model name into the key used for pricing lookups.

    The name is lower-cased, fine-tune identifiers are collapsed to a
    ``-azure-finetuned`` / ``-finetuned-legacy`` / ``-finetuned`` suffix, and a
    ``-completion`` or ``-cached`` suffix is appended for model families that
    have separate output or cached-input pricing.

    Args:
        model_name: Model name to standardize.
        is_completion: Whether the model is used for completion or not.
            Defaults to False. Deprecated in favor of ``token_type``.
        token_type: Token type. Defaults to ``TokenType.PROMPT``.

    Returns:
        Standardized model name.

    """
    if is_completion:
        warn_deprecated(
            since="0.3.13",
            message=(
                "is_completion is deprecated. Use token_type instead. Example:\n\n"
                "from langchain_community.callbacks.openai_info import TokenType\n\n"
                "standardize_model_name('gpt-4o', token_type=TokenType.COMPLETION)\n"
            ),
            removal="1.0",
        )
        token_type = TokenType.COMPLETION

    # Families with a separate ``-completion`` price. The o-series prefixes
    # have no trailing dash so the bare names ``o1`` and ``o3`` also match;
    # ``computer-use-preview`` has a completion price as well.
    completion_prefixes = (
        "gpt-5",
        "gpt-4",
        "gpt-3.5",
        "gpt-35",
        "o1",
        "o3",
        "o4",
        "computer-use-preview",
    )
    # Families with a separate ``-cached`` price. ``gpt-4.5`` is included
    # because the pricing table carries ``gpt-4.5-preview-cached``.
    cached_prefixes = ("gpt-5", "gpt-4o", "gpt-4.1", "gpt-4.5", "o1", "o3", "o4")

    model_name = model_name.lower()
    if ".ft-" in model_name:
        model_name = model_name.split(".ft-")[0] + "-azure-finetuned"
    if ":ft-" in model_name:
        model_name = model_name.split(":")[0] + "-finetuned-legacy"
    if "ft:" in model_name:
        model_name = model_name.split(":")[1] + "-finetuned"

    is_current_finetune = "finetuned" in model_name and "legacy" not in model_name
    if token_type == TokenType.COMPLETION and (
        model_name.startswith(completion_prefixes) or is_current_finetune
    ):
        return model_name + "-completion"
    if (
        token_type == TokenType.PROMPT_CACHED
        and model_name.startswith(cached_prefixes)
        # This snapshot has no cached price in the table.
        and not model_name.startswith("gpt-4o-2024-05-13")
    ):
        return model_name + "-cached"
    return model_name
|
||||
|
||||
|
||||
def get_openai_token_cost_for_model(
    model_name: str,
    num_tokens: int,
    is_completion: bool = False,
    *,
    token_type: TokenType = TokenType.PROMPT,
) -> float:
    """Get the cost in USD for a given model and number of tokens.

    Args:
        model_name: Name of the model.
        num_tokens: Number of tokens.
        is_completion: Whether the model is used for completion or not.
            Defaults to False. Deprecated in favor of ``token_type``.
        token_type: Token type. Defaults to ``TokenType.PROMPT``.

    Returns:
        Cost in USD.

    Raises:
        ValueError: If the standardized model name has no entry in
            ``MODEL_COST_PER_1K_TOKENS``.
    """
    if is_completion:
        warn_deprecated(
            since="0.3.13",
            message=(
                "is_completion is deprecated. Use token_type instead. Example:\n\n"
                "from langchain_community.callbacks.openai_info import TokenType\n\n"
                "get_openai_token_cost_for_model('gpt-4o', 10, token_type=TokenType.COMPLETION)\n"  # noqa: E501
            ),
            removal="1.0",
        )
        # The deprecated flag wins over whatever ``token_type`` was passed.
        token_type = TokenType.COMPLETION

    # The pricing table is keyed by the standardized name, which already carries
    # the "-completion" / "-cached" suffix for the requested token type.
    model_name = standardize_model_name(model_name, token_type=token_type)
    if model_name not in MODEL_COST_PER_1K_TOKENS:
        raise ValueError(
            f"Unknown model: {model_name}. Please provide a valid OpenAI model name."
            " Known models are: " + ", ".join(MODEL_COST_PER_1K_TOKENS)
        )
    return MODEL_COST_PER_1K_TOKENS[model_name] * (num_tokens / 1000)
|
||||
|
||||
|
||||
class OpenAICallbackHandler(BaseCallbackHandler):
    """Callback handler that accumulates OpenAI token usage and cost.

    The running totals are plain attributes; ``on_llm_end`` updates them while
    holding an internal lock.
    """

    total_tokens: int = 0
    prompt_tokens: int = 0
    prompt_tokens_cached: int = 0
    completion_tokens: int = 0
    reasoning_tokens: int = 0
    successful_requests: int = 0
    total_cost: float = 0.0

    def __init__(self) -> None:
        super().__init__()
        self._lock = threading.Lock()

    def __repr__(self) -> str:
        return (
            f"Tokens Used: {self.total_tokens}\n"
            f"\tPrompt Tokens: {self.prompt_tokens}\n"
            f"\t\tPrompt Tokens Cached: {self.prompt_tokens_cached}\n"
            f"\tCompletion Tokens: {self.completion_tokens}\n"
            f"\t\tReasoning Tokens: {self.reasoning_tokens}\n"
            f"Successful Requests: {self.successful_requests}\n"
            f"Total Cost (USD): ${self.total_cost}"
        )

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Do nothing; usage is only collected when the LLM call ends."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing; streamed tokens are not counted individually."""

    @staticmethod
    def _message_metadata(response: LLMResult) -> tuple:
        """Return ``(usage_metadata, response_metadata)`` of the first generation.

        Both items are ``None`` unless the first generation is a
        ``ChatGeneration`` whose message is an ``AIMessage`` exposing them.
        """
        try:
            first = response.generations[0][0]
        except IndexError:
            return None, None
        if not isinstance(first, ChatGeneration):
            return None, None
        try:
            message = first.message
            if isinstance(message, AIMessage):
                return message.usage_metadata, message.response_metadata
        except AttributeError:
            pass
        return None, None

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Collect token usage and cost for one finished LLM call."""
        # Prefer usage_metadata on the message (langchain-core >= 0.2.2).
        usage_metadata, response_metadata = self._message_metadata(response)

        prompt_tokens_cached = 0
        reasoning_tokens = 0

        if usage_metadata:
            token_usage = {"total_tokens": usage_metadata["total_tokens"]}
            completion_tokens = usage_metadata["output_tokens"]
            prompt_tokens = usage_metadata["input_tokens"]

            response_model_name = (response_metadata or {}).get("model_name")
            if response_model_name:
                model_name = standardize_model_name(response_model_name)
            elif response.llm_output is None:
                model_name = ""
            else:
                model_name = standardize_model_name(
                    response.llm_output.get("model_name", "")
                )

            input_details = usage_metadata.get("input_token_details", {})
            if "cache_read" in input_details:
                prompt_tokens_cached = input_details["cache_read"]
            output_details = usage_metadata.get("output_token_details", {})
            if "reasoning" in output_details:
                reasoning_tokens = output_details["reasoning"]
        else:
            # Fall back to the provider-level ``llm_output`` payload.
            llm_output = response.llm_output
            if llm_output is None:
                return None
            if "token_usage" not in llm_output:
                # The request succeeded but reported no usage to account for.
                with self._lock:
                    self.successful_requests += 1
                return None

            token_usage = llm_output["token_usage"]
            completion_tokens = token_usage.get("completion_tokens", 0)
            prompt_tokens = token_usage.get("prompt_tokens", 0)
            model_name = standardize_model_name(llm_output.get("model_name", ""))

        # Models missing from the pricing table contribute tokens but no cost.
        prompt_cost = 0
        completion_cost = 0
        if model_name in MODEL_COST_PER_1K_TOKENS:
            uncached_prompt_cost = get_openai_token_cost_for_model(
                model_name,
                prompt_tokens - prompt_tokens_cached,
                token_type=TokenType.PROMPT,
            )
            cached_prompt_cost = get_openai_token_cost_for_model(
                model_name, prompt_tokens_cached, token_type=TokenType.PROMPT_CACHED
            )
            prompt_cost = uncached_prompt_cost + cached_prompt_cost
            completion_cost = get_openai_token_cost_for_model(
                model_name, completion_tokens, token_type=TokenType.COMPLETION
            )

        # Update the shared totals under the lock.
        with self._lock:
            self.total_cost += prompt_cost + completion_cost
            self.total_tokens += token_usage.get("total_tokens", 0)
            self.prompt_tokens += prompt_tokens
            self.prompt_tokens_cached += prompt_tokens_cached
            self.completion_tokens += completion_tokens
            self.reasoning_tokens += reasoning_tokens
            self.successful_requests += 1

    def __copy__(self) -> "OpenAICallbackHandler":
        """Return this same instance, so a copy shares the counters."""
        return self

    def __deepcopy__(self, memo: Any) -> "OpenAICallbackHandler":
        """Return this same instance, so a deep copy shares the counters."""
        return self
|
||||
@@ -0,0 +1,163 @@
|
||||
"""Callback handler for promptlayer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
LLMResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import promptlayer
|
||||
|
||||
|
||||
def _lazy_import_promptlayer() -> promptlayer:
|
||||
"""Lazy import promptlayer to avoid circular imports."""
|
||||
try:
|
||||
import promptlayer
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The PromptLayerCallbackHandler requires the promptlayer package. "
|
||||
" Please install it with `pip install promptlayer`."
|
||||
)
|
||||
return promptlayer
|
||||
|
||||
|
||||
class PromptLayerCallbackHandler(BaseCallbackHandler):
    """Callback handler that reports finished LLM runs to promptlayer.

    Start callbacks record the request per ``run_id``; ``on_llm_end`` sends one
    promptlayer request per prompt and then forgets the run.
    """

    def __init__(
        self,
        pl_id_callback: Optional[Callable[..., Any]] = None,
        pl_tags: Optional[List[str]] = None,
    ) -> None:
        """Initialize the PromptLayerCallbackHandler.

        Args:
            pl_id_callback: Optional callable invoked with the value returned by
                ``promptlayer_api_request`` for every logged generation.
            pl_tags: Tags forwarded to ``promptlayer_api_request``. Defaults to
                an empty list.

        Raises:
            ImportError: If the ``promptlayer`` package is not installed.
        """
        _lazy_import_promptlayer()
        self.pl_id_callback = pl_id_callback
        self.pl_tags = pl_tags or []
        self.runs: Dict[UUID, Dict[str, Any]] = {}

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Record the chat messages and start time for ``run_id``."""
        self.runs[run_id] = {
            "messages": [self._create_message_dicts(m)[0] for m in messages],
            "invocation_params": kwargs.get("invocation_params", {}),
            "name": ".".join(serialized["id"]),
            "request_start_time": datetime.datetime.now().timestamp(),
            "tags": tags,
        }

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Record the text prompts and start time for ``run_id``."""
        self.runs[run_id] = {
            "prompts": prompts,
            "invocation_params": kwargs.get("invocation_params", {}),
            "name": ".".join(serialized["id"]),
            "request_start_time": datetime.datetime.now().timestamp(),
            "tags": tags,
        }

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Send the recorded request and its response to promptlayer.

        Does nothing when no start callback was recorded for ``run_id``.
        """
        from promptlayer.utils import get_api_key, promptlayer_api_request

        # Pop rather than get: this is the last place the entry is read, and
        # leaving it behind makes ``self.runs`` grow with every request.
        run_info = self.runs.pop(run_id, None)
        if not run_info:
            return
        run_info["request_end_time"] = datetime.datetime.now().timestamp()

        # These do not depend on the generation being processed.
        model_params = run_info.get("invocation_params", {})
        is_chat_model = run_info.get("messages", None) is not None
        metadata = {
            "_langchain_run_id": str(run_id),
            "_langchain_parent_run_id": str(parent_run_id),
            "_langchain_tags": str(run_info.get("tags", [])),
        }

        for i, generations in enumerate(response.generations):
            # Only the first candidate of each prompt is reported.
            generation = generations[0]

            resp = {
                "text": generation.text,
                "llm_output": response.llm_output,
            }
            model_input = (
                run_info.get("messages", [])[i]
                if is_chat_model
                else [run_info.get("prompts", [])[i]]
            )
            model_response = (
                [self._convert_message_to_dict(generation.message)]
                if is_chat_model and isinstance(generation, ChatGeneration)
                else resp
            )

            pl_request_id = promptlayer_api_request(
                run_info.get("name"),
                "langchain",
                model_input,
                model_params,
                self.pl_tags,
                model_response,
                run_info.get("request_start_time"),
                run_info.get("request_end_time"),
                get_api_key(),
                return_pl_id=self.pl_id_callback is not None,
                metadata=dict(metadata),
            )

            if self.pl_id_callback:
                self.pl_id_callback(pl_request_id)

    def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
        """Convert a message into a ``{"role", "content"[, "name"]}`` dict.

        Raises:
            ValueError: If the message is not a human, AI, system or chat message.
        """
        if isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")
        if "name" in message.additional_kwargs:
            message_dict["name"] = message.additional_kwargs["name"]
        return message_dict

    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Convert messages to dicts; the second item is always an empty dict."""
        params: Dict[str, Any] = {}
        message_dicts = [self._convert_message_to_dict(m) for m in messages]
        return message_dicts, params
|
||||
@@ -0,0 +1,277 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_community.callbacks.utils import (
|
||||
flatten_dict,
|
||||
)
|
||||
|
||||
|
||||
def save_json(data: dict, file_path: str) -> None:
    """Write a dictionary to a local file as JSON.

    Parameters:
        data (dict): The dictionary to be saved.
        file_path (str): Local file path; it is opened in write mode.
    """
    with open(file_path, "w") as handle:
        json.dump(obj=data, fp=handle)
|
||||
|
||||
|
||||
class SageMakerCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs prompt artifacts and metrics to SageMaker Experiments.

    Each callback bumps the matching counters in ``self.metrics``, writes a JSON
    snapshot into a temporary directory and logs that file on ``run``.

    Parameters:
        run (sagemaker.experiments.run.Run): Run object where the experiment is logged.
    """

    def __init__(self, run: Any) -> None:
        """Initialize callback handler."""
        super().__init__()

        self.run = run

        self.metrics = {
            "step": 0,
            "starts": 0,
            "ends": 0,
            "errors": 0,
            "text_ctr": 0,
            "chain_starts": 0,
            "chain_ends": 0,
            "llm_starts": 0,
            "llm_ends": 0,
            "llm_streams": 0,
            "tool_starts": 0,
            "tool_ends": 0,
            "agent_ends": 0,
        }

        # Create a temporary directory that holds the JSON artifacts.
        self.temp_dir = tempfile.mkdtemp()

    def _reset(self) -> None:
        """Set every metric counter back to zero."""
        for key in self.metrics:
            self.metrics[key] = 0

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts; logs one artifact per prompt."""
        self.metrics["step"] += 1
        self.metrics["llm_starts"] += 1
        self.metrics["starts"] += 1

        llm_starts = self.metrics["llm_starts"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_llm_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.metrics)

        for idx, prompt in enumerate(prompts):
            prompt_resp = deepcopy(resp)
            prompt_resp["prompt"] = prompt
            self.jsonf(
                prompt_resp,
                self.temp_dir,
                f"llm_start_{llm_starts}_prompt_{idx}",
            )

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        self.metrics["step"] += 1
        self.metrics["llm_streams"] += 1

        llm_streams = self.metrics["llm_streams"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_llm_new_token", "token": token})
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"llm_new_tokens_{llm_streams}")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running; logs one artifact per generation."""
        self.metrics["step"] += 1
        self.metrics["llm_ends"] += 1
        self.metrics["ends"] += 1

        llm_ends = self.metrics["llm_ends"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_llm_end"})
        resp.update(flatten_dict(response.llm_output or {}))

        resp.update(self.metrics)

        # NOTE(review): ``idx`` restarts for every prompt, so generations of
        # different prompts share artifact names -- confirm whether intended.
        for generations in response.generations:
            for idx, generation in enumerate(generations):
                generation_resp = deepcopy(resp)
                generation_resp.update(flatten_dict(generation.dict()))

                # Log the per-generation payload, not the shared base dict;
                # otherwise the generation content never reaches the artifact.
                self.jsonf(
                    generation_resp,
                    self.temp_dir,
                    f"llm_end_{llm_ends}_generation_{idx}",
                )

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        self.metrics["step"] += 1
        self.metrics["chain_starts"] += 1
        self.metrics["starts"] += 1

        chain_starts = self.metrics["chain_starts"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_chain_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.metrics)

        chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
        input_resp = deepcopy(resp)
        input_resp["inputs"] = chain_input

        self.jsonf(input_resp, self.temp_dir, f"chain_start_{chain_starts}")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        self.metrics["step"] += 1
        self.metrics["chain_ends"] += 1
        self.metrics["ends"] += 1

        chain_ends = self.metrics["chain_ends"]

        resp: Dict[str, Any] = {}
        chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
        resp.update({"action": "on_chain_end", "outputs": chain_output})
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.metrics["step"] += 1
        self.metrics["tool_starts"] += 1
        self.metrics["starts"] += 1

        tool_starts = self.metrics["tool_starts"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_tool_start", "input_str": input_str})
        resp.update(flatten_dict(serialized))
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}")

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Run when tool ends running."""
        output = str(output)
        self.metrics["step"] += 1
        self.metrics["tool_ends"] += 1
        self.metrics["ends"] += 1

        tool_ends = self.metrics["tool_ends"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_tool_end", "output": output})
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}")

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run when arbitrary text is received; logs it as an artifact."""
        self.metrics["step"] += 1
        self.metrics["text_ctr"] += 1

        text_ctr = self.metrics["text_ctr"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_text", "text": text})
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"on_text_{text_ctr}")

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
        self.metrics["step"] += 1
        self.metrics["agent_ends"] += 1
        self.metrics["ends"] += 1

        agent_ends = self.metrics["agent_ends"]
        resp: Dict[str, Any] = {}
        resp.update(
            {
                "action": "on_agent_finish",
                "output": finish.return_values["output"],
                "log": finish.log,
            }
        )
        resp.update(self.metrics)

        self.jsonf(resp, self.temp_dir, f"agent_finish_{agent_ends}")

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        self.metrics["step"] += 1
        # Agent actions are counted in the tool-start counters.
        self.metrics["tool_starts"] += 1
        self.metrics["starts"] += 1

        tool_starts = self.metrics["tool_starts"]
        resp: Dict[str, Any] = {}
        resp.update(
            {
                "action": "on_agent_action",
                "tool": action.tool,
                "tool_input": action.tool_input,
                "log": action.log,
            }
        )
        resp.update(self.metrics)
        self.jsonf(resp, self.temp_dir, f"agent_action_{tool_starts}")

    def jsonf(
        self,
        data: Dict[str, Any],
        data_dir: str,
        filename: str,
        is_output: Optional[bool] = True,
    ) -> None:
        """Write ``data`` to ``<data_dir>/<filename>.json`` and log it on the run.

        Args:
            data: JSON-serializable payload.
            data_dir: Directory in which the file is created.
            filename: File stem, also used as the artifact name.
            is_output: Forwarded to ``run.log_file``. Defaults to True.
        """
        file_path = os.path.join(data_dir, f"{filename}.json")
        save_json(data, file_path)
        self.run.log_file(file_path, name=filename, is_output=is_output)

    def flush_tracker(self) -> None:
        """Reset the steps and delete the temporary local directory."""
        self._reset()
        shutil.rmtree(self.temp_dir)
|
||||
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
|
||||
from langchain_community.callbacks.streamlit.streamlit_callback_handler import (
|
||||
LLMThoughtLabeler as LLMThoughtLabeler,
|
||||
)
|
||||
from langchain_community.callbacks.streamlit.streamlit_callback_handler import (
|
||||
StreamlitCallbackHandler as _InternalStreamlitCallbackHandler,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.delta_generator import DeltaGenerator
|
||||
|
||||
|
||||
def StreamlitCallbackHandler(
    parent_container: DeltaGenerator,
    *,
    max_thought_containers: int = 4,
    expand_new_thoughts: bool = True,
    collapse_completed_thoughts: bool = True,
    thought_labeler: Optional[LLMThoughtLabeler] = None,
) -> BaseCallbackHandler:
    """Build a callback handler that renders into a Streamlit app.

    The handler is meant for LangChain Agents: the Agent's LLM and tool-usage
    "thoughts" are shown inside a series of Streamlit expanders.

    Parameters
    ----------
    parent_container
        The `st.container` that will contain all the Streamlit elements that the
        Handler creates.
    max_thought_containers
        The max number of completed LLM thought containers to show at once. When this
        threshold is reached, a new thought will cause the oldest thoughts to be
        collapsed into a "History" expander. Defaults to 4.
    expand_new_thoughts
        Each LLM "thought" gets its own `st.expander`. This param controls whether that
        expander is expanded by default. Defaults to True.
    collapse_completed_thoughts
        If True, LLM thought expanders will be collapsed when completed.
        Defaults to True.
    thought_labeler
        An optional custom LLMThoughtLabeler instance. If unspecified, the handler
        will use the default thought labeling logic. Defaults to None.

    Returns
    -------
    A new StreamlitCallbackHandler instance.

    Note that this is an "auto-updating" API: if the installed version of Streamlit
    has a more recent StreamlitCallbackHandler implementation, an instance of that class
    will be used.

    """
    # Both implementations receive exactly the same keyword arguments.
    options = {
        "max_thought_containers": max_thought_containers,
        "expand_new_thoughts": expand_new_thoughts,
        "collapse_completed_thoughts": collapse_completed_thoughts,
        "thought_labeler": thought_labeler,
    }

    # Prefer the handler shipped with Streamlit when it can be imported, and
    # fall back to the built-in one on ImportError.
    try:
        from streamlit.external.langchain import (
            StreamlitCallbackHandler as OfficialStreamlitCallbackHandler,
        )

        return OfficialStreamlitCallbackHandler(parent_container, **options)
    except ImportError:
        return _InternalStreamlitCallbackHandler(parent_container, **options)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.delta_generator import DeltaGenerator
|
||||
from streamlit.type_util import SupportsStr
|
||||
|
||||
|
||||
class ChildType(Enum):
    """Kinds of child element a ``MutableExpander`` can record and replay."""

    # A Markdown element.
    MARKDOWN = "MARKDOWN"
    # An exception element.
    EXCEPTION = "EXCEPTION"
|
||||
|
||||
|
||||
class ChildRecord(NamedTuple):
    """Immutable record describing one child element of an expander."""

    # Which kind of element was created.
    type: ChildType
    # Keyword arguments the element was created with; reused to replay it.
    kwargs: Dict[str, Any]
    # DeltaGenerator returned when the element was created.
    dg: DeltaGenerator
|
||||
|
||||
|
||||
class MutableExpander:
    """Streamlit expander that can be renamed and dynamically expanded/collapsed."""

    def __init__(self, parent_container: DeltaGenerator, label: str, expanded: bool):
        """Create a new MutableExpander.

        Parameters
        ----------
        parent_container
            The `st.container` that the expander will be created inside.

            The expander transparently deletes and recreates its underlying
            `st.expander` instance when its label changes, and it uses
            `parent_container` to ensure it recreates this underlying expander in the
            same location onscreen.
        label
            The expander's initial label.
        expanded
            The expander's initial `expanded` value.
        """
        self._label = label
        self._expanded = expanded
        # Placeholder slot in the parent; the expander is (re)created inside it.
        self._parent_cursor = parent_container.empty()
        self._container = self._parent_cursor.expander(label, expanded)
        self._child_records: List[ChildRecord] = []

    @property
    def label(self) -> str:
        """Current label of the expander."""
        return self._label

    @property
    def expanded(self) -> bool:
        """Expanded state most recently given to the expander."""
        return self._expanded

    def clear(self) -> None:
        """Remove the container and its contents entirely. A cleared container can't
        be reused.
        """
        self._container = self._parent_cursor.empty()
        self._child_records.clear()

    def append_copy(self, other: MutableExpander) -> None:
        """Re-create every child of ``other`` at the end of this expander."""
        # Iterate over a snapshot of the other expander's records.
        for record in list(other._child_records):
            self._create_child(record.type, record.kwargs)

    def update(
        self, *, new_label: Optional[str] = None, new_expanded: Optional[bool] = None
    ) -> None:
        """Change the expander's label and/or expanded state.

        Arguments left as ``None`` keep their current value. If nothing changes
        the call is a no-op; otherwise the expander is recreated and all recorded
        children are replayed into it.
        """
        label = self._label if new_label is None else new_label
        expanded = self._expanded if new_expanded is None else new_expanded

        if self._label == label and self._expanded == expanded:
            # No change!
            return

        self._label, self._expanded = label, expanded
        self._container = self._parent_cursor.expander(label, expanded)

        # Replaying goes through markdown()/exception(), which append fresh
        # records, so start again from an empty list.
        replay, self._child_records = self._child_records, []
        for record in replay:
            self._create_child(record.type, record.kwargs)

    def markdown(
        self,
        body: SupportsStr,
        unsafe_allow_html: bool = False,
        *,
        help: Optional[str] = None,
        index: Optional[int] = None,
    ) -> int:
        """Add a Markdown element to the container and return its index.

        When ``index`` is given, the element at that index is replaced instead.
        """
        kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help}
        target = self._get_dg(index)
        created = target.markdown(**kwargs)
        return self._add_record(ChildRecord(ChildType.MARKDOWN, kwargs, created), index)

    def exception(
        self, exception: BaseException, *, index: Optional[int] = None
    ) -> int:
        """Add an Exception element to the container and return its index.

        When ``index`` is given, the element at that index is replaced instead.
        """
        kwargs = {"exception": exception}
        target = self._get_dg(index)
        created = target.exception(**kwargs)
        return self._add_record(ChildRecord(ChildType.EXCEPTION, kwargs, created), index)

    def _create_child(self, type: ChildType, kwargs: Dict[str, Any]) -> None:
        """Create a new child of the given type from its recorded kwargs."""
        if type == ChildType.MARKDOWN:
            self.markdown(**kwargs)
            return
        if type == ChildType.EXCEPTION:
            self.exception(**kwargs)
            return
        raise RuntimeError(f"Unexpected child type {type}")

    def _add_record(self, record: ChildRecord, index: Optional[int]) -> int:
        """Store ``record`` and return the index it occupies.

        With ``index`` set, the record at that position is replaced; otherwise
        the record is appended to the end of the list.
        """
        if index is None:
            # Append new child
            self._child_records.append(record)
            return len(self._child_records) - 1

        # Replace existing child
        self._child_records[index] = record
        return index

    def _get_dg(self, index: Optional[int]) -> DeltaGenerator:
        """Return the DeltaGenerator to draw into for ``index``."""
        if index is None:
            # No index: use container's DeltaGenerator
            return self._container

        # Existing index: reuse child's DeltaGenerator
        assert 0 <= index < len(self._child_records), f"Bad index: {index}"
        return self._child_records[index].dg
|
||||
@@ -0,0 +1,419 @@
|
||||
"""Callback Handler that prints to streamlit."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_community.callbacks.streamlit.mutable_expander import MutableExpander
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.delta_generator import DeltaGenerator
|
||||
|
||||
|
||||
def _convert_newlines(text: str) -> str:
|
||||
"""Convert newline characters to markdown newline sequences
|
||||
(space, space, newline).
|
||||
"""
|
||||
return text.replace("\n", " \n")
|
||||
|
||||
|
||||
CHECKMARK_EMOJI = "✅"  # label prefix for completed thoughts
|
||||
THINKING_EMOJI = ":thinking_face:"  # label prefix for thoughts still in progress
|
||||
HISTORY_EMOJI = ":books:"  # label prefix for the "History" container
|
||||
EXCEPTION_EMOJI = "⚠️"  # label prefix when the tool name is "_Exception"
|
||||
|
||||
|
||||
class LLMThoughtState(Enum):
    """Lifecycle states of an LLMThought."""

    # No tool chosen yet: the LLM is still deciding what to do next.
    THINKING = "THINKING"
    # A tool has been chosen and started, but its results are not in yet.
    RUNNING_TOOL = "RUNNING_TOOL"
    # The tool's results have been received.
    COMPLETE = "COMPLETE"
|
||||
|
||||
|
||||
class ToolRecord(NamedTuple):
    """Name and input string of a tool, stored as an immutable tuple."""

    # Name of the tool.
    name: str
    # Input string associated with the tool.
    input_str: str
|
||||
|
||||
|
||||
class LLMThoughtLabeler:
    """
    Produces the markdown labels shown on LLMThought containers. Subclass it
    and pass an instance to StreamlitCallbackHandler to customize the
    labeling logic.
    """

    @staticmethod
    def get_initial_label() -> str:
        """Return the markdown label for a new LLMThought that has no
        associated tool yet.
        """
        return f"{THINKING_EMOJI} **Thinking...**"

    @staticmethod
    def get_tool_label(tool: ToolRecord, is_complete: bool) -> str:
        """Return the markdown label for an LLMThought that has an
        associated tool.

        Parameters
        ----------
        tool
            The tool's ToolRecord

        is_complete
            True if the thought is complete; False if the thought
            is still receiving input.

        Returns
        -------
        The markdown label for the thought's container.

        """
        full_input = tool.input_str

        # "_Exception" is shown as a parsing error regardless of completion.
        if tool.name == "_Exception":
            emoji, name = EXCEPTION_EMOJI, "Parsing error"
        else:
            emoji = CHECKMARK_EMOJI if is_complete else THINKING_EMOJI
            name = tool.name

        # Show at most 60 characters of the input, marking any truncation.
        limit = min(60, len(full_input))
        shown = full_input[:limit]
        if len(full_input) > limit:
            shown += "..."

        # Keep the label on one line.
        shown = shown.replace("\n", " ")
        return f"{emoji} **{name}:** {shown}"

    @staticmethod
    def get_history_label() -> str:
        """Return the markdown label for the special 'history' container
        that holds overflow thoughts.
        """
        return f"{HISTORY_EMOJI} **History**"

    @staticmethod
    def get_final_agent_thought_label() -> str:
        """Return the markdown label for the agent's final thought -
        the "Now I have the answer" thought, which involves no tool.
        """
        return f"{CHECKMARK_EMOJI} **Complete!**"
|
||||
|
||||
|
||||
class LLMThought:
    """One thought in the LLM's thought stream, shown in its own expander."""

    def __init__(
        self,
        parent_container: DeltaGenerator,
        labeler: LLMThoughtLabeler,
        expanded: bool,
        collapse_on_complete: bool,
    ):
        """Set up the thought and create its expander.

        Args:
            parent_container: The container the thought's expander is created in.
            labeler: Supplies the markdown labels for this thought.
            expanded: Whether the expander starts out expanded.
            collapse_on_complete: Whether the expander is collapsed when the
                thought completes.
        """
        self._labeler = labeler
        self._collapse_on_complete = collapse_on_complete
        self._state = LLMThoughtState.THINKING
        self._last_tool: Optional[ToolRecord] = None
        self._llm_token_stream = ""
        self._llm_token_writer_idx: Optional[int] = None
        self._container = MutableExpander(
            parent_container=parent_container,
            label=labeler.get_initial_label(),
            expanded=expanded,
        )

    @property
    def container(self) -> MutableExpander:
        """The expander this thought writes into."""
        return self._container

    @property
    def last_tool(self) -> Optional[ToolRecord]:
        """The most recent tool started by this thought, if any."""
        return self._last_tool

    def _reset_llm_token_stream(self) -> None:
        """Forget the accumulated token text and the element it was written to."""
        self._llm_token_stream = ""
        self._llm_token_writer_idx = None

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str]) -> None:
        """Start a fresh token stream for a new LLM call."""
        self._reset_llm_token_stream()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Append a streamed token and re-render the accumulated text.

        Only invoked when the LLM is initialized with `streaming=True`.
        """
        self._llm_token_stream += _convert_newlines(token)
        # Passing the previous index back in lets the container rewrite the
        # same markdown element; the returned index is kept for the next token.
        rendered_at = self._container.markdown(
            self._llm_token_stream, index=self._llm_token_writer_idx
        )
        self._llm_token_writer_idx = rendered_at

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Reset the token stream once the LLM call finishes.

        `response` is the concatenation of all tokens the LLM produced, so it
        is redundant when tokens already arrived through `on_llm_new_token`.
        """
        self._reset_llm_token_stream()

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Show an LLM failure notice followed by the exception."""
        container = self._container
        container.markdown("**LLM encountered an error...**")
        container.exception(error)

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Record the tool being started and relabel the container after it.

        The tool's name is read from `serialized["name"]`.
        """
        self._state = LLMThoughtState.RUNNING_TOOL
        record = ToolRecord(name=serialized["name"], input_str=input_str)
        self._last_tool = record
        in_progress_label = self._labeler.get_tool_label(record, is_complete=False)
        self._container.update(new_label=in_progress_label)

    def on_tool_end(
        self,
        output: Any,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Write the tool's output, stringified and in bold, into the container."""
        self._container.markdown(f"**{output!s}**")

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Show a tool failure notice followed by the exception."""
        container = self._container
        container.markdown("**Tool encountered an error...**")
        container.exception(error)

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """Do nothing.

        Called when a new tool is about to be kicked off; `action` names the
        tool and its input. Nothing is rendered here because the same data
        arrives through `on_tool_start` immediately afterwards.
        """

    def complete(self, final_label: Optional[str] = None) -> None:
        """Finish the thought."""
        needs_tool_label = (
            final_label is None and self._state == LLMThoughtState.RUNNING_TOOL
        )
        if needs_tool_label:
            assert self._last_tool is not None, (
                "_last_tool should never be null when _state == RUNNING_TOOL"
            )
            final_label = self._labeler.get_tool_label(
                self._last_tool, is_complete=True
            )

        self._state = LLMThoughtState.COMPLETE

        # `new_expanded` is only passed when the thought should collapse.
        update_kwargs: Dict[str, Any] = {"new_label": final_label}
        if self._collapse_on_complete:
            update_kwargs["new_expanded"] = False
        self._container.update(**update_kwargs)

    def clear(self) -> None:
        """Remove the thought from the screen. A cleared thought can't be reused."""
        self._container.clear()
|
||||
|
||||
|
||||
class StreamlitCallbackHandler(BaseCallbackHandler):
    """Callback handler that renders LLM/agent activity into a Streamlit app."""

    def __init__(
        self,
        parent_container: DeltaGenerator,
        *,
        max_thought_containers: int = 4,
        expand_new_thoughts: bool = True,
        collapse_completed_thoughts: bool = True,
        thought_labeler: Optional[LLMThoughtLabeler] = None,
    ):
        """Create a StreamlitCallbackHandler instance.

        Parameters
        ----------
        parent_container
            The `st.container` that will hold every Streamlit element the
            handler creates.
        max_thought_containers
            The maximum number of completed LLM thought containers shown at
            once. Once the limit is reached, a new thought causes the oldest
            thoughts to be folded into a "History" expander. Values below 1
            are treated as 1. Defaults to 4.
        expand_new_thoughts
            Each LLM "thought" gets its own `st.expander`; this controls
            whether that expander starts out expanded. Defaults to True.
        collapse_completed_thoughts
            If True, thought expanders are collapsed when they complete.
            Defaults to True.
        thought_labeler
            An optional custom LLMThoughtLabeler instance. When omitted, the
            default labeling logic is used. Defaults to None.
        """
        self._parent_container = parent_container
        # Dedicated child container in which the history expander is created.
        self._history_parent = parent_container.container()
        self._history_container: Optional[MutableExpander] = None
        self._current_thought: Optional[LLMThought] = None
        self._completed_thoughts: List[LLMThought] = []
        self._max_thought_containers = max(max_thought_containers, 1)
        self._expand_new_thoughts = expand_new_thoughts
        self._collapse_completed_thoughts = collapse_completed_thoughts
        self._thought_labeler = thought_labeler or LLMThoughtLabeler()

    def _require_current_thought(self) -> LLMThought:
        """Return the current LLMThought, raising RuntimeError when there is
        none.
        """
        thought = self._current_thought
        if thought is None:
            raise RuntimeError("Current LLMThought is unexpectedly None!")
        return thought

    def _get_last_completed_thought(self) -> Optional[LLMThought]:
        """Return the most recently completed LLMThought, or None if none exist."""
        return self._completed_thoughts[-1] if self._completed_thoughts else None

    @property
    def _num_thought_containers(self) -> int:
        """How many 'thought containers' are currently shown: the completed
        thoughts, plus the history container and the current thought when
        they exist.
        """
        return (
            len(self._completed_thoughts)
            + (self._history_container is not None)
            + (self._current_thought is not None)
        )

    def _complete_current_thought(self, final_label: Optional[str] = None) -> None:
        """Complete the current thought, optionally with a new label, and
        move it onto the list of completed thoughts.
        """
        finished = self._require_current_thought()
        finished.complete(final_label)
        self._completed_thoughts.append(finished)
        self._current_thought = None

    def _prune_old_thought_containers(self) -> None:
        """Move the oldest completed thoughts into the 'history container'
        while too many thought containers are on screen.
        """
        while (
            self._completed_thoughts
            and self._num_thought_containers > self._max_thought_containers
        ):
            # The history container is created lazily, and only when more
            # than one container is allowed (with a limit of 1 there is no
            # room to show history).
            can_show_history = self._max_thought_containers > 1
            if can_show_history and self._history_container is None:
                self._history_container = MutableExpander(
                    self._history_parent,
                    label=self._thought_labeler.get_history_label(),
                    expanded=False,
                )

            evicted = self._completed_thoughts.pop(0)
            history = self._history_container
            if history is not None:
                history.markdown(evicted.container.label)
                history.append_copy(evicted.container)
            evicted.clear()

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Create a new thought if none is active, then forward the event."""
        if self._current_thought is None:
            self._current_thought = LLMThought(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )

        self._current_thought.on_llm_start(serialized, prompts)

        # No pruning here: the new container isn't visible until it has a child.

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Forward a streamed token to the current thought, then prune."""
        self._require_current_thought().on_llm_new_token(token, **kwargs)
        self._prune_old_thought_containers()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Forward the end of an LLM call to the current thought, then prune."""
        self._require_current_thought().on_llm_end(response, **kwargs)
        self._prune_old_thought_containers()

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Forward an LLM error to the current thought, then prune."""
        self._require_current_thought().on_llm_error(error, **kwargs)
        self._prune_old_thought_containers()

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Forward the start of a tool run to the current thought, then prune."""
        self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)
        self._prune_old_thought_containers()

    def on_tool_end(
        self,
        output: Any,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Forward the stringified tool output to the current thought and
        mark that thought as completed.
        """
        output = str(output)
        thought = self._require_current_thought()
        thought.on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)
        self._complete_current_thought()

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Forward a tool error to the current thought, then prune."""
        self._require_current_thought().on_tool_error(error, **kwargs)
        self._prune_old_thought_containers()

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Any,
    ) -> None:
        """No-op: arbitrary text events are not rendered."""

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """No-op: chain starts are not rendered."""

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """No-op: chain ends are not rendered."""

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """No-op: chain errors are not rendered."""

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """Forward an agent action to the current thought, then prune."""
        self._require_current_thought().on_agent_action(action, color, **kwargs)
        self._prune_old_thought_containers()

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Complete the current thought, if any, with the final-thought label."""
        thought = self._current_thought
        if thought is None:
            return
        thought.complete(self._thought_labeler.get_final_agent_thought_label())
        self._current_thought = None
|
||||
@@ -0,0 +1,16 @@
|
||||
"""Tracers that record execution of LangChain runs."""
|
||||
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
from langchain_core.tracers.stdout import (
|
||||
ConsoleCallbackHandler,
|
||||
FunctionCallbackHandler,
|
||||
)
|
||||
|
||||
from langchain_community.callbacks.tracers.wandb import WandbTracer
|
||||
|
||||
__all__ = [
|
||||
"ConsoleCallbackHandler",
|
||||
"FunctionCallbackHandler",
|
||||
"LangChainTracer",
|
||||
"WandbTracer",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,135 @@
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict
|
||||
|
||||
from langchain_core.tracers import BaseTracer
|
||||
from langchain_core.utils import guard_import
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
from comet_llm import Span
|
||||
from comet_llm.chains.chain import Chain
|
||||
|
||||
from langchain_community.callbacks.tracers.schemas import Run
|
||||
|
||||
|
||||
def _get_run_type(run: "Run") -> str:
|
||||
if isinstance(run.run_type, str):
|
||||
return run.run_type
|
||||
elif hasattr(run.run_type, "value"):
|
||||
return run.run_type.value
|
||||
else:
|
||||
return str(run.run_type)
|
||||
|
||||
|
||||
def import_comet_llm_api() -> SimpleNamespace:
    """Import the comet_llm API and bundle the pieces the tracer uses.

    Both ``comet_llm`` and ``comet_llm.chains`` are loaded through
    ``guard_import``, which is relied on to raise when the package is not
    installed.

    Returns:
        A namespace with the attributes ``chain``, ``span``, ``chain_api``,
        ``experiment_info`` and ``flush``.
    """
    llm_module = guard_import("comet_llm")
    chains_module = guard_import("comet_llm.chains")

    api = SimpleNamespace()
    api.chain = chains_module.chain
    api.span = chains_module.span
    api.chain_api = chains_module.api
    api.experiment_info = llm_module.experiment_info
    api.flush = llm_module.flush
    return api
|
||||
|
||||
|
||||
class CometTracer(BaseTracer):
    """Tracer that records LangChain runs as Comet chains and spans."""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Comet Tracer."""
        super().__init__(**kwargs)
        # Run id -> span created for that (non-root) run.
        self._span_map: Dict["UUID", "Span"] = {}
        # Run id -> chain the run belongs to (children share their root's chain).
        self._chains_map: Dict["UUID", "Chain"] = {}
        self._initialize_comet_modules()

    def _initialize_comet_modules(self) -> None:
        """Load the comet_llm API and keep references to the parts used here."""
        api = import_comet_llm_api()
        self._chain: ModuleType = api.chain
        self._span: ModuleType = api.span
        self._chain_api: ModuleType = api.chain_api
        self._experiment_info: ModuleType = api.experiment_info
        self._flush: Callable[[], None] = api.flush

    def _persist_run(self, run: "Run") -> None:
        """Attach the run's outputs to its chain and log the chain."""
        run_dict: Dict[str, Any] = run.dict()
        chain = self._chains_map[run.id]
        chain.set_outputs(outputs=run_dict["outputs"])
        self._chain_api.log_chain(chain)

    def _process_start_trace(self, run: "Run") -> None:
        """Create a chain for a root run, or a span for a child run."""
        run_dict: Dict[str, Any] = run.dict()

        if not run.parent_run_id:
            # A run without a parent is the first run; it maps to a chain.
            metadata = run_dict["extra"].get("metadata", None)
            root_chain: "Chain" = self._chain.Chain(
                inputs=run_dict["inputs"],
                metadata=metadata,
                experiment_info=self._experiment_info.get(),
            )
            self._chains_map[run.id] = root_chain
            return

        child_span: "Span" = self._span.Span(
            inputs=run_dict["inputs"],
            category=_get_run_type(run),
            metadata=run_dict["extra"],
            name=run.name,
        )
        parent_chain = self._chains_map[run.parent_run_id]
        child_span.__api__start__(parent_chain)
        # The child is registered under its parent's chain.
        self._chains_map[run.id] = parent_chain
        self._span_map[run.id] = child_span

    def _process_end_trace(self, run: "Run") -> None:
        """Close the span of a child run; root runs need nothing here."""
        run_dict: Dict[str, Any] = run.dict()
        if not run.parent_run_id:
            # Langchain will call _persist_run for us
            return
        finished_span = self._span_map[run.id]
        finished_span.set_outputs(outputs=run_dict["outputs"])
        finished_span.__api__end__()

    def flush(self) -> None:
        """Call the ``flush`` function obtained from comet_llm."""
        self._flush()

    def _on_llm_start(self, run: "Run") -> None:
        """Process the LLM Run upon start."""
        self._process_start_trace(run)

    def _on_llm_end(self, run: "Run") -> None:
        """Process the LLM Run."""
        self._process_end_trace(run)

    def _on_llm_error(self, run: "Run") -> None:
        """Process the LLM Run upon error."""
        self._process_end_trace(run)

    def _on_chain_start(self, run: "Run") -> None:
        """Process the Chain Run upon start."""
        self._process_start_trace(run)

    def _on_chain_end(self, run: "Run") -> None:
        """Process the Chain Run."""
        self._process_end_trace(run)

    def _on_chain_error(self, run: "Run") -> None:
        """Process the Chain Run upon error."""
        self._process_end_trace(run)

    def _on_tool_start(self, run: "Run") -> None:
        """Process the Tool Run upon start."""
        self._process_start_trace(run)

    def _on_tool_end(self, run: "Run") -> None:
        """Process the Tool Run."""
        self._process_end_trace(run)

    def _on_tool_error(self, run: "Run") -> None:
        """Process the Tool Run upon error."""
        self._process_end_trace(run)
|
||||
@@ -0,0 +1,507 @@
|
||||
"""A Tracer Implementation that records activity to Weights & Biases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypedDict,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core._api import warn_deprecated
|
||||
from langchain_core.output_parsers.pydantic import PydanticBaseModel
|
||||
from langchain_core.tracers.base import BaseTracer
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from wandb import Settings as WBSettings
|
||||
from wandb.sdk.data_types.trace_tree import Trace
|
||||
from wandb.sdk.lib.paths import StrPath
|
||||
from wandb.wandb_run import Run as WBRun
|
||||
|
||||
# Module-level switch: when truthy, WandbTracer.process_model_dict reports a
# failed model serialization through wandb.termerror.
PRINT_WARNINGS = True
|
||||
|
||||
|
||||
def _serialize_io(run_io: Optional[dict]) -> dict:
|
||||
"""Utility to serialize the input and output of a run to store in wandb.
|
||||
Currently, supports serializing pydantic models and protobuf messages.
|
||||
|
||||
:param run_io: The inputs and outputs of the run.
|
||||
:return: The serialized inputs and outputs.
|
||||
|
||||
|
||||
"""
|
||||
if not run_io:
|
||||
return {}
|
||||
from google.protobuf.json_format import MessageToJson
|
||||
from google.protobuf.message import Message
|
||||
|
||||
serialized_inputs = {}
|
||||
for key, value in run_io.items():
|
||||
if isinstance(value, Message):
|
||||
serialized_inputs[key] = MessageToJson(value)
|
||||
|
||||
elif isinstance(value, PydanticBaseModel):
|
||||
serialized_inputs[key] = (
|
||||
value.model_dump_json()
|
||||
if hasattr(value, "model_dump_json")
|
||||
else value.json()
|
||||
)
|
||||
|
||||
elif key == "input_documents":
|
||||
serialized_inputs.update(
|
||||
{f"input_document_{i}": doc.json() for i, doc in enumerate(value)}
|
||||
)
|
||||
else:
|
||||
serialized_inputs[key] = value
|
||||
return serialized_inputs
|
||||
|
||||
|
||||
def flatten_run(run: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten a nested run object into a list of runs.

    The tree is walked depth-first, parents before their children. Each
    visited run has its ``"child_runs"`` entry removed in place and is then
    added to the result.

    :param run: The base run to flatten.
    :return: The flattened list of runs.
    """
    flattened: List[Dict[str, Any]] = []
    pending: List[Dict[str, Any]] = [run]
    while pending:
        node = pending.pop()
        children = node.pop("child_runs", [])
        flattened.append(node)
        if children:
            # Reversed push so children are popped in their original order.
            pending.extend(reversed(children))
    return flattened
|
||||
|
||||
|
||||
def truncate_run_iterative(
    runs: List[Dict[str, Any]], keep_keys: Tuple[str, ...] = ()
) -> List[Dict[str, Any]]:
    """Reduce every run dictionary in a list to the specified keys.

    A new dictionary is built per run; the surviving keys keep the order they
    have in the original run.

    :param runs: The list of runs to truncate.
    :param keep_keys: The keys to keep in each run.
    :return: The truncated list of runs.
    """
    return [
        {key: value for key, value in run.items() if key in keep_keys}
        for run in runs
    ]
|
||||
|
||||
|
||||
# Rewrites each run dict into a single-entry {name: data} mapping: the run's
# "serialized" sub-dict is merged into the run itself, "id"/"kwargs" entries
# of nested objects are turned into "_kind" plus flattened kwargs, and keys
# matching exact_keys / partial_keys are dropped. Implemented by the nested
# helpers below, applied in transform_run to every element of `runs`.
# NOTE(review): indentation is not preserved in this listing, so the exact
# nesting of the statements inside the helpers must be confirmed against the
# upstream file before any restructuring.
def modify_serialized_iterative(
|
||||
runs: List[Dict[str, Any]],
|
||||
exact_keys: Tuple[str, ...] = (),
|
||||
partial_keys: Tuple[str, ...] = (),
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Utility to modify the serialized field of a list of runs dictionaries.
|
||||
removes any keys that match the exact_keys and any keys that contain any of the
|
||||
partial_keys.
|
||||
recursively moves the dictionaries under the kwargs key to the top level.
|
||||
changes the "id" field to a string "_kind" field that tells WBTraceTree how to
|
||||
visualize the run. promotes the "serialized" field to the top level.
|
||||
:param runs: The list of runs to modify.
|
||||
:param exact_keys: A tuple of keys to remove from the serialized field.
|
||||
:param partial_keys: A tuple of partial keys to remove from the serialized
|
||||
field.
|
||||
:return: The modified list of runs.
|
||||
"""
|
||||
|
||||
# Helper 1: key filtering. A dict is rebuilt without the keys that equal an
# entry of exact_keys or contain any entry of partial_keys as a substring;
# the surviving values (and list elements) are processed recursively.
# Non-dict, non-list values are returned unchanged.
def remove_exact_and_partial_keys(obj: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Recursively removes exact and partial keys from a dictionary.
|
||||
:param obj: The dictionary to remove keys from.
|
||||
:return: The modified dictionary.
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
obj = {
|
||||
k: v
|
||||
for k, v in obj.items()
|
||||
if k not in exact_keys
|
||||
and not any(partial in k for partial in partial_keys)
|
||||
}
|
||||
for k, v in obj.items():
|
||||
obj[k] = remove_exact_and_partial_keys(v)
|
||||
elif isinstance(obj, list):
|
||||
obj = [remove_exact_and_partial_keys(x) for x in obj]
|
||||
return obj
|
||||
|
||||
# Helper 2: id/kwargs handling. A dict whose "data" entry is itself a dict is
# replaced by that entry. For non-root dicts carrying "id" or "name", "_kind"
# is taken from the last element of the "id" list (falling back to the
# "name"), "id"/"name" are popped, and the entries of "kwargs" are copied
# onto the dict itself. Recursive calls never pass root=True, so only the
# outermost dict is treated as root. Dicts are modified in place.
def handle_id_and_kwargs(obj: Dict[str, Any], root: bool = False) -> Dict[str, Any]:
|
||||
"""Recursively handles the id and kwargs fields of a dictionary.
|
||||
changes the id field to a string "_kind" field that tells WBTraceTree how
|
||||
to visualize the run. recursively moves the dictionaries under the kwargs
|
||||
key to the top level.
|
||||
:param obj: a run dictionary with id and kwargs fields.
|
||||
:param root: whether this is the root dictionary or the serialized
|
||||
dictionary.
|
||||
:return: The modified dictionary.
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
if "data" in obj and isinstance(obj["data"], dict):
|
||||
obj = obj["data"]
|
||||
if ("id" in obj or "name" in obj) and not root:
|
||||
_kind = obj.get("id")
|
||||
if not _kind:
|
||||
_kind = [obj.get("name")]
|
||||
if isinstance(_kind, list):
|
||||
obj["_kind"] = _kind[-1]
|
||||
obj.pop("id", None)
|
||||
obj.pop("name", None)
|
||||
if "kwargs" in obj:
|
||||
kwargs = obj.pop("kwargs")
|
||||
for k, v in kwargs.items():
|
||||
obj[k] = v
|
||||
for k, v in obj.items():
|
||||
obj[k] = handle_id_and_kwargs(v)
|
||||
elif isinstance(obj, list):
|
||||
obj = [handle_id_and_kwargs(x) for x in obj]
|
||||
return obj
|
||||
|
||||
# Helper 3: runs helper 2 (with root=True, so the run's own "id"/"name" are
# left alone) and then helper 1 on the result.
def transform_serialized(serialized: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Transforms the serialized field of a run dictionary to be compatible
|
||||
with WBTraceTree.
|
||||
:param serialized: The serialized field of a run dictionary.
|
||||
:return: The transformed serialized field.
|
||||
"""
|
||||
serialized = handle_id_and_kwargs(serialized, root=True)
|
||||
|
||||
serialized = remove_exact_and_partial_keys(serialized)
|
||||
|
||||
return serialized
|
||||
|
||||
|
||||
# Helper 4: per-run transformation. The transformed "serialized" dict is
# merged into the run, and the result is wrapped as {name: run_data}, where
# name is the run's "name" or, when that is falsy, its "_kind".
# NOTE(review): "serialized" is popped without a default and then iterated
# with .items(), so a run whose "serialized" entry is missing or not a dict
# raises here — confirm callers guarantee it.
def transform_run(run: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Transforms a run dictionary to be compatible with WBTraceTree.
|
||||
:param run: The run dictionary to transform.
|
||||
:return: The transformed run dictionary.
|
||||
"""
|
||||
transformed_dict = transform_serialized(run)
|
||||
|
||||
|
||||
serialized = transformed_dict.pop("serialized")
|
||||
for k, v in serialized.items():
|
||||
transformed_dict[k] = v
|
||||
|
||||
|
||||
_kind = transformed_dict.get("_kind", None)
|
||||
name = transformed_dict.pop("name", None)
|
||||
|
||||
|
||||
if not name:
|
||||
name = _kind
|
||||
|
||||
|
||||
output_dict = {
|
||||
f"{name}": transformed_dict,
|
||||
}
|
||||
return output_dict
|
||||
|
||||
|
||||
# One transformed {name: data} mapping per input run, in input order.
return list(map(transform_run, runs))
|
||||
|
||||
|
||||
def build_tree(runs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Assemble a nested dictionary from a flat list of runs.

    Every element of ``runs`` maps a label to that run's data. The ``"id"``
    and ``"parent_run_id"`` entries are popped from the data (in place) and
    used to hang each child's data under its parent's data, keyed by the
    child's label.

    :param runs: The list of runs to build the tree from.
    :return: The nested dictionary representing the langchain Run in a tree
        structure compatible with WBTraceTree: the first registered run that
        has no parent.
    """
    nodes_by_id: Dict[Any, Dict[str, Any]] = {}
    parent_of: Dict[Any, Any] = {}

    # Pass 1: index every run by id and remember who its parent is.
    for entry in runs:
        for label, payload in entry.items():
            run_id = payload.pop("id", None)
            parent_id = payload.pop("parent_run_id", None)
            nodes_by_id[run_id] = {label: payload}
            if parent_id:
                parent_of[run_id] = parent_id

    # Pass 2: attach each child's payload to its parent's payload. Payloads
    # are shared objects, so later attachments show up in earlier ones.
    for run_id, parent_id in parent_of.items():
        parent_node = nodes_by_id[parent_id]
        child_node = nodes_by_id[run_id]
        parent_label = next(iter(parent_node))
        child_label = next(iter(child_node))
        parent_node[parent_label][child_label] = child_node[child_label]

    # The root is the first indexed run that never appeared as a child.
    return next(
        node for run_id, node in nodes_by_id.items() if run_id not in parent_of
    )
|
||||
|
||||
|
||||
class WandbRunArgs(TypedDict):
    """Typed dictionary of the run arguments accepted by the WandbTracer.

    The tracer passes these on as keyword arguments to ``wandb.init()``;
    refer to the ``wandb.init`` documentation for what each entry means.
    """

    # Run identity and organisation.
    job_type: Optional[str]
    dir: Optional[StrPath]
    config: Union[Dict, str, None]
    project: Optional[str]
    entity: Optional[str]
    reinit: Optional[bool]
    tags: Optional[Sequence]
    group: Optional[str]
    name: Optional[str]
    notes: Optional[str]
    magic: Optional[Union[dict, str, bool]]
    # Config key filtering.
    config_exclude_keys: Optional[List[str]]
    config_include_keys: Optional[List[str]]
    # Remaining wandb.init options.
    anonymous: Optional[str]
    mode: Optional[str]
    allow_val_change: Optional[bool]
    resume: Optional[Union[bool, str]]
    force: Optional[bool]
    tensorboard: Optional[bool]
    sync_tensorboard: Optional[bool]
    monitor_gym: Optional[bool]
    save_code: Optional[bool]
    id: Optional[str]
    settings: Union[WBSettings, Dict[str, Any], None]
|
||||
|
||||
|
||||
class WandbTracer(BaseTracer):
|
||||
"""Callback Handler that logs to Weights and Biases.
|
||||
|
||||
This handler will log the model architecture and run traces to Weights and Biases.
|
||||
This will ensure that all LangChain activity is logged to W&B.
|
||||
"""
|
||||
|
||||
_run: Optional[WBRun] = None
|
||||
_run_args: Optional[WandbRunArgs] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
run_args: Optional[WandbRunArgs] = None,
|
||||
io_serializer: Callable = _serialize_io,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initializes the WandbTracer.
|
||||
|
||||
Parameters:
|
||||
run_args: (dict, optional) Arguments to pass to `wandb.init()`. If not
|
||||
provided, `wandb.init()` will be called with no arguments. Please
|
||||
refer to the `wandb.init` for more details.
|
||||
io_serializer: callable A function that serializes the input and outputs
|
||||
of a run to store in wandb. Defaults to "_serialize_io"
|
||||
|
||||
To use W&B to monitor all LangChain activity, add this tracer like any other
|
||||
LangChain callback:
|
||||
```
|
||||
from wandb.integration.langchain import WandbTracer
|
||||
|
||||
tracer = WandbTracer()
|
||||
chain = LLMChain(llm, callbacks=[tracer])
|
||||
# ...end of notebook / script:
|
||||
tracer.finish()
|
||||
```
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
try:
|
||||
import wandb
|
||||
from wandb.sdk.data_types import trace_tree
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import wandb python package."
|
||||
"Please install it with `pip install -U wandb`."
|
||||
) from e
|
||||
self._wandb = wandb
|
||||
self._trace_tree = trace_tree
|
||||
self._run_args = run_args
|
||||
self._ensure_run(should_print_url=(wandb.run is None))
|
||||
self._io_serializer = io_serializer
|
||||
warn_deprecated(
|
||||
"0.3.8",
|
||||
pending=False,
|
||||
message=(
|
||||
"Please use the `WeaveTracer` from the `weave` package instead of this."
|
||||
"The `WeaveTracer` is a more flexible and powerful tool for logging "
|
||||
"and tracing your LangChain callables."
|
||||
"Find more information at https://weave-docs.wandb.ai/guides/integrations/langchain"
|
||||
),
|
||||
alternative=(
|
||||
"Please instantiate the WeaveTracer from "
|
||||
"`weave.integrations.langchain import WeaveTracer` ."
|
||||
"For autologging simply use `weave.init()` and log all traces "
|
||||
"from your LangChain callables."
|
||||
),
|
||||
)
|
||||
|
||||
def finish(self) -> None:
|
||||
"""Waits for all asynchronous processes to finish and data to upload.
|
||||
|
||||
Proxy for `wandb.finish()`.
|
||||
"""
|
||||
self._wandb.finish()
|
||||
|
||||
def _ensure_run(self, should_print_url: bool = False) -> None:
    """Make sure a W&B run is active, starting one when none exists.

    A new run is started with the stored `run_args`; a silent `settings`
    entry is supplied when the caller did not provide one.

    Args:
        should_print_url: When a run had to be started, also announce its
            URL through `wandb.termlog`.
    """
    if self._wandb.run is not None:
        # A run is already active; there is nothing to start.
        return

    init_kwargs: Dict = {**(self._run_args or {})}
    init_kwargs.setdefault("settings", {"silent": True})
    self._wandb.init(**init_kwargs)

    if self._wandb.run is None:
        # `init` did not produce a run, so there is nothing to label.
        return

    if should_print_url:
        url = self._wandb.run.settings.run_url
        announcement = (
            f"Streaming LangChain activity to W&B at {url}\n"
            "`WandbTracer` is currently in beta.\n"
            "Please report any issues to "
            "https://github.com/wandb/wandb/issues with the tag "
            "`langchain`."
        )
        self._wandb.termlog(announcement)

    self._wandb.run._label(repo="langchain")
|
||||
|
||||
def process_model_dict(self, run: Run) -> Optional[Dict[str, Any]]:
    """Turn a run into the model_dict handed to WBTraceTree.

    :param run: The run to process.
    :return: The converted model_dict, or None when any step fails (a
        message is emitted through `wandb.termerror` if PRINT_WARNINGS is
        set).
    """
    # Forwarded as `keep_keys` to `truncate_run_iterative`.
    retained_keys = (
        "id",
        "name",
        "serialized",
        "parent_run_id",
    )
    # Forwarded as `exact_keys` / `partial_keys` to
    # `modify_serialized_iterative`.
    exact_keys = ("lc", "type", "graph")
    partial_keys = (
        "api_key",
        "input",
        "output",
    )
    try:
        raw = json.loads(run.json())
        flattened = flatten_run(raw)
        truncated = truncate_run_iterative(flattened, keep_keys=retained_keys)
        cleaned = modify_serialized_iterative(
            truncated, exact_keys=exact_keys, partial_keys=partial_keys
        )
        return build_tree(cleaned)
    except Exception as e:
        if PRINT_WARNINGS:
            self._wandb.termerror(f"WARNING: Failed to serialize model: {e}")
        return None
|
||||
|
||||
def _log_trace_from_run(self, run: Run) -> None:
    """Log a LangChain Run to W&B as a W&B Trace.

    Ensures a W&B run is active, converts `run` and (recursively) its child
    runs into a trace tree, attaches the serialized model dict when one
    could be produced, and logs the trace under the key "langchain_trace".

    Args:
        run: The root run to log.
    """
    self._ensure_run()

    def create_trace(
        run: "Run", parent: Optional["Trace"] = None
    ) -> Optional["Trace"]:
        """
        Create a trace for a given run and its child runs.

        Args:
            run (Run): The run for which to create a trace.
            parent (Optional[Trace]): The parent trace.
                If provided, the created trace is added as a child to the
                parent trace.

        Returns:
            The created trace when `parent` is None, otherwise `parent`
            (with the new trace attached). None is returned if an error
            occurs while building the trace.

        Note:
            This function never raises: any exception is caught and, when
            PRINT_WARNINGS is set, reported through `wandb.termwarn`.
        """

        def get_metadata_dict(r: "Run") -> Dict[str, Any]:
            """
            Extract metadata from a given run.

            Args:
                r (Run): The run from which to extract metadata.

            Returns:
                `dict` containing the run's "metadata" entries plus
                `run_id`, `parent_run_id`, `tags` and `execution_order`
                (the number of "." characters in `dotted_order`).
            """
            run_dict = json.loads(r.json())
            # `dict.get(key, default)` only falls back when the key is
            # absent. A key present with a JSON null would otherwise raise
            # here and the whole trace would be dropped, so coerce None to
            # an empty value instead.
            metadata_dict = run_dict.get("metadata") or {}
            metadata_dict["run_id"] = run_dict.get("id")
            metadata_dict["parent_run_id"] = run_dict.get("parent_run_id")
            metadata_dict["tags"] = run_dict.get("tags")
            dotted_order = run_dict.get("dotted_order") or ""
            metadata_dict["execution_order"] = dotted_order.count(".")
            return metadata_dict

        try:
            if run.run_type in ["llm", "tool"]:
                run_type = run.run_type
            elif run.run_type == "chain":
                run_type = "agent" if "agent" in run.name.lower() else "chain"
            else:
                run_type = None

            metadata = get_metadata_dict(run)
            trace_tree = self._trace_tree.Trace(
                name=run.name,
                kind=run_type,
                status_code="error" if run.error else "success",
                start_time_ms=int(run.start_time.timestamp() * 1000)
                if run.start_time is not None
                else None,
                end_time_ms=int(run.end_time.timestamp() * 1000)
                if run.end_time is not None
                else None,
                metadata=metadata,
                inputs=self._io_serializer(run.inputs),
                outputs=self._io_serializer(run.outputs),
            )

            # If the run has child runs, recursively create traces for them
            for child_run in run.child_runs:
                create_trace(child_run, trace_tree)

            if parent is None:
                return trace_tree
            else:
                parent.add_child(trace_tree)
                return parent
        except Exception as e:
            # Best effort by design: tracing must not break the traced code.
            if PRINT_WARNINGS:
                self._wandb.termwarn(
                    f"WARNING: Failed to serialize trace for run due to: {e}"
                )
            return None

    run_trace = create_trace(run)
    model_dict = self.process_model_dict(run)
    if model_dict is not None and run_trace is not None:
        run_trace._model_dict = model_dict
    if self._wandb.run is not None and run_trace is not None:
        run_trace.log("langchain_trace")
|
||||
|
||||
def _persist_run(self, run: "Run") -> None:
    """Persist a run.

    Delegates to `_log_trace_from_run`, which logs the run to W&B as a
    trace.

    Args:
        run: The run to persist.
    """
    self._log_trace_from_run(run)
|
||||
@@ -0,0 +1,125 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message into a role/content dictionary.

    Args:
        message: The message to convert.

    Returns:
        A dict with "role" and "content" keys, plus "function_call" and/or
        "name" when the message carries them.

    Raises:
        TypeError: If the message is not one of the handled message types.
    """
    # Checked in this order, after ChatMessage (which carries its own role).
    fixed_roles = (
        (HumanMessage, "user"),
        (AIMessage, "assistant"),
        (SystemMessage, "system"),
        (FunctionMessage, "function"),
    )

    matched_type = None
    if isinstance(message, ChatMessage):
        role = message.role
    else:
        for candidate_type, candidate_role in fixed_roles:
            if isinstance(message, candidate_type):
                matched_type, role = candidate_type, candidate_role
                break
        else:
            raise TypeError(f"Got unknown type {message}")

    message_dict: Dict[str, Any] = {"role": role, "content": message.content}

    if matched_type is AIMessage:
        if "function_call" in message.additional_kwargs:
            message_dict["function_call"] = message.additional_kwargs["function_call"]
            # If function call only, content is None not empty string
            if message_dict["content"] == "":
                message_dict["content"] = None
    elif matched_type is FunctionMessage:
        message_dict["name"] = message.name

    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict
|
||||
|
||||
|
||||
class TrubricsCallbackHandler(BaseCallbackHandler):
    """
    Callback handler for Trubrics.

    Args:
        project: a trubrics project, default project is "default"
        email: a trubrics account email, can equally be set in env variables
        password: a trubrics account password, can equally be set in env variables
        **kwargs: all other kwargs are parsed and set to trubrics prompt variables,
            or added to the `metadata` dict
    """

    def __init__(
        self,
        project: str = "default",
        email: Optional[str] = None,
        password: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Create the Trubrics client and remember the extra keyword arguments.

        Raises:
            ImportError: If the `trubrics` package is not installed.
            KeyError: If `email` / `password` is not given and the matching
                `TRUBRICS_EMAIL` / `TRUBRICS_PASSWORD` environment variable
                is not set.
        """
        super().__init__()
        try:
            from trubrics import Trubrics
        except ImportError as e:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "The TrubricsCallbackHandler requires installation of "
                "the trubrics package. "
                "Please install it with `pip install trubrics`."
            ) from e

        self.trubrics = Trubrics(
            project=project,
            email=email or os.environ["TRUBRICS_EMAIL"],
            password=password or os.environ["TRUBRICS_PASSWORD"],
        )
        self.config_model: dict = {}
        self.prompt: Optional[str] = None
        self.messages: Optional[list] = None
        self.trubrics_kwargs: Optional[dict] = kwargs if kwargs else None

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Remember the first prompt so it can be logged when the LLM ends."""
        self.prompt = prompts[0]

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        **kwargs: Any,
    ) -> None:
        """Store the first conversation as dicts; its last message is the prompt."""
        self.messages = [_convert_message_to_dict(message) for message in messages[0]]
        self.prompt = self.messages[-1]["content"]

    def on_llm_end(self, response: LLMResult, run_id: UUID, **kwargs: Any) -> None:
        """Log one Trubrics prompt per generation list in `response`.

        The extra keyword arguments given at construction time supply
        `tags`, `user_id` and `session_id`; everything else is merged into
        the logged `metadata`.
        """
        tags = ["langchain"]
        user_id = None
        session_id = None
        metadata: dict = {"langchain_run_id": run_id}
        if self.messages:
            metadata["messages"] = self.messages
        if self.trubrics_kwargs:
            # Work on a copy: popping from `self.trubrics_kwargs` itself would
            # strip tags / user_id / session_id from every later LLM call.
            extra = dict(self.trubrics_kwargs)
            if extra.get("tags"):
                extra_tags = extra.pop("tags")
                if isinstance(extra_tags, str):
                    tags.append(extra_tags)
                else:
                    # `append(*extra_tags)` raised TypeError unless there was
                    # exactly one tag; `extend` accepts any number.
                    tags.extend(extra_tags)
            user_id = extra.pop("user_id", None)
            session_id = extra.pop("session_id", None)
            metadata.update(extra)

        # The model name does not change between generations.
        model_name = (
            response.llm_output.get("model_name") if response.llm_output else "NA"
        )
        for generation in response.generations:
            self.trubrics.log_prompt(
                config_model={"model": model_name},
                prompt=self.prompt,
                generation=generation[0].text,
                user_id=user_id,
                session_id=session_id,
                tags=tags,
                metadata=metadata,
            )
|
||||
@@ -0,0 +1,206 @@
|
||||
"""Ratelimiting Handler to limit requests or tokens"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
logger = logging.getLogger(__name__)
try:
    from upstash_ratelimit import Ratelimit
except ImportError:
    # `upstash_ratelimit` is optional at import time. In this module the name
    # `Ratelimit` is only used inside `Optional[Ratelimit]` annotations, so
    # binding it to None keeps those annotations evaluable.
    Ratelimit = None
|
||||
|
||||
|
||||
class UpstashRatelimitError(Exception):
    """
    Error signalling that an Upstash rate limit has been hit.

    `UpstashRatelimitHandler` raises it when the identifier is rate limited.
    """

    def __init__(
        self,
        message: str,
        type: Literal["token", "request"],
        limit: Optional[int] = None,
        reset: Optional[float] = None,
    ):
        """
        Args:
            message (str): human-readable error message
            type (str): which kind of limit was hit, either "token" or
                "request"
            limit (Optional[int]): the limit that was reached; supplied for
                request limits
            reset (Optional[float]): unix timestamp in milliseconds at which
                the limits reset; supplied for request limits
        """
        # Only the message goes to the base Exception; the rest is metadata.
        super().__init__(message)
        self.reset = reset
        self.limit = limit
        self.type = type
|
||||
|
||||
|
||||
class UpstashRatelimitHandler(BaseCallbackHandler):
    """
    Callback to handle rate limiting based on the number of requests
    or the number of tokens in the input.

    It uses Upstash Ratelimit to track the ratelimit which utilizes
    Upstash Redis to track the state.

    Should not be passed to the chain when initialising the chain.
    This is because the handler has a state which should be fresh
    every time invoke is called. Instead, initialise and pass a handler
    every time you invoke.
    """

    raise_error: bool = True
    # Class-level default; set to True on the instance once the request
    # limit has been checked in `on_chain_start`.
    _checked: bool = False

    def __init__(
        self,
        identifier: str,
        *,
        token_ratelimit: Optional[Ratelimit] = None,
        request_ratelimit: Optional[Ratelimit] = None,
        include_output_tokens: bool = False,
    ):
        """
        Creates UpstashRatelimitHandler. Must be passed an identifier to
        ratelimit like a user id or an ip address.

        Additionally, it must be passed at least one of token_ratelimit
        or request_ratelimit parameters.

        Args:
            identifier str: the identifier
            token_ratelimit Optional[Ratelimit]: Ratelimit to limit the
                number of tokens. Only works with OpenAI models since only
                these models provide the number of tokens as information
                in their output.
            request_ratelimit Optional[Ratelimit]: Ratelimit to limit the
                number of requests
            include_output_tokens bool: Whether to count output tokens when
                rate limiting based on number of tokens. Only used when
                `token_ratelimit` is passed. False by default.

        Raises:
            ValueError: If neither `token_ratelimit` nor `request_ratelimit`
                is provided.

        Example:
            .. code-block:: python

                from upstash_redis import Redis
                from upstash_ratelimit import Ratelimit, FixedWindow

                redis = Redis.from_env()
                ratelimit = Ratelimit(
                    redis=redis,
                    # fixed window to allow 10 requests every 10 seconds:
                    limiter=FixedWindow(max_requests=10, window=10),
                )

                user_id = "foo"
                handler = UpstashRatelimitHandler(
                    identifier=user_id,
                    request_ratelimit=ratelimit
                )

                # Initialize a simple runnable to test
                chain = RunnableLambda(str)

                # pass handler as callback:
                output = chain.invoke(
                    "input",
                    config={
                        "callbacks": [handler]
                    }
                )

        """
        if not any([token_ratelimit, request_ratelimit]):
            # The message names the actual keyword arguments of this method.
            raise ValueError(
                "You must pass at least one of token_ratelimit or"
                " request_ratelimit parameters for handler to work."
            )

        self.identifier = identifier
        self.token_ratelimit = token_ratelimit
        self.request_ratelimit = request_ratelimit
        self.include_output_tokens = include_output_tokens

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> Any:
        """
        Run when chain starts running.

        on_chain_start runs multiple times during a chain execution. To make
        sure that it's only called once, we keep a bool state `_checked`. If
        not `self._checked`, we call limit with `request_ratelimit` and raise
        `UpstashRatelimitError` if the identifier is rate limited.
        """
        if self.request_ratelimit and not self._checked:
            response = self.request_ratelimit.limit(self.identifier)
            if not response.allowed:
                raise UpstashRatelimitError(
                    "Request limit reached!", "request", response.limit, response.reset
                )
            self._checked = True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """
        Run when LLM starts running

        Raises `UpstashRatelimitError` when a token ratelimit is configured
        and no tokens remain for the identifier.
        """
        if self.token_ratelimit:
            remaining = self.token_ratelimit.get_remaining(self.identifier)
            if remaining <= 0:
                raise UpstashRatelimitError("Token limit reached!", "token")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """
        Run when LLM ends running

        If the `include_output_tokens` is set to True, number of tokens
        in LLM completion are counted for rate limiting

        Raises:
            ValueError: If a token ratelimit is configured but the LLM
                response carries no usable `token_usage` information.
        """
        if self.token_ratelimit:
            try:
                llm_output = response.llm_output or {}
                token_usage = llm_output["token_usage"]
                token_count = (
                    token_usage["total_tokens"]
                    if self.include_output_tokens
                    else token_usage["prompt_tokens"]
                )
            except (KeyError, TypeError) as err:
                # TypeError covers `token_usage` being present but None.
                raise ValueError(
                    "LLM response doesn't include"
                    " `token_usage: {total_tokens: int, prompt_tokens: int}`"
                    " field. To use UpstashRatelimitHandler with token_ratelimit,"
                    " either use a model which returns token_usage (like"
                    " OpenAI models) or rate limit only with request_ratelimit."
                ) from err

            # call limit to add the completion tokens to rate limit
            # but don't raise exception since we already generated
            # the tokens and would rather continue execution.
            self.token_ratelimit.limit(self.identifier, rate=token_count)

    def reset(self, identifier: Optional[str] = None) -> "UpstashRatelimitHandler":
        """
        Creates a new UpstashRatelimitHandler object with the same
        ratelimit configurations but with a new identifier if it's
        provided.

        Also resets the state of the handler.
        """
        return UpstashRatelimitHandler(
            identifier=identifier or self.identifier,
            token_ratelimit=self.token_ratelimit,
            request_ratelimit=self.request_ratelimit,
            include_output_tokens=self.include_output_tokens,
        )
|
||||
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
UpTrain Callback Handler
|
||||
|
||||
UpTrain is an open-source platform to evaluate and improve LLM applications. It provides
|
||||
grades for 20+ preconfigured checks (covering language, code, embedding use cases),
|
||||
performs root cause analyses on instances of failure cases and provides guidance for
|
||||
resolving them.
|
||||
|
||||
This module contains a callback handler for integrating UpTrain seamlessly into your
|
||||
pipeline and facilitating diverse evaluations. The callback handler automates various
|
||||
evaluations to assess the performance and effectiveness of the components within the
|
||||
pipeline.
|
||||
|
||||
The evaluations conducted include:
|
||||
|
||||
1. RAG:
|
||||
- Context Relevance: Determines the relevance of the context extracted from the query
|
||||
to the response.
|
||||
- Factual Accuracy: Assesses if the Language Model (LLM) is providing accurate
|
||||
information or hallucinating.
|
||||
- Response Completeness: Checks if the response contains all the information
|
||||
requested by the query.
|
||||
|
||||
2. Multi Query Generation:
|
||||
MultiQueryRetriever generates multiple variants of a question with similar meanings
|
||||
to the original question. This evaluation includes previous assessments and adds:
|
||||
- Multi Query Accuracy: Ensures that the multi-queries generated convey the same
|
||||
meaning as the original query.
|
||||
|
||||
3. Context Compression and Reranking:
|
||||
Re-ranking involves reordering nodes based on relevance to the query and selecting
|
||||
top n nodes.
|
||||
Due to the potential reduction in the number of nodes after re-ranking, the following
|
||||
evaluations
|
||||
are performed in addition to the RAG evaluations:
|
||||
- Context Reranking: Determines if the order of re-ranked nodes is more relevant to
|
||||
the query than the original order.
|
||||
- Context Conciseness: Examines whether the reduced number of nodes still provides
|
||||
all the required information.
|
||||
|
||||
These evaluations collectively ensure the robustness and effectiveness of the RAG query
|
||||
engine, MultiQueryRetriever, and the re-ranking process within the pipeline.
|
||||
|
||||
Useful links:
|
||||
Github: https://github.com/uptrain-ai/uptrain
|
||||
Website: https://uptrain.ai/
|
||||
Docs: https://docs.uptrain.ai/getting-started/introduction
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
Any,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.utils import guard_import
|
||||
|
||||
logger = logging.getLogger(__name__)
# Attach a dedicated stdout handler with a bare "%(message)s" format, so
# records emitted through this module's logger are written to stdout as the
# raw message text. This runs when the module is imported.
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
|
||||
|
||||
|
||||
def import_uptrain() -> Any:
    """Import the `uptrain` package.

    Returns:
        Whatever `guard_import("uptrain")` returns (presumably the imported
        module; callers here access `Settings`, `APIClient`, `EvalLLM` and
        `Evals` on it).
    """
    return guard_import("uptrain")
|
||||
|
||||
|
||||
class UpTrainDataSchema:
    """Mutable state container used while collecting UpTrain evaluations.

    Args:
        project_name (str): The project name to be shown in UpTrain dashboard.

    Attributes:
        project_name (str): The project name to be shown in UpTrain dashboard.
        uptrain_results (DefaultDict[str, Any]): Evaluation results; missing
            keys start as an empty list.
        eval_types (Set[str]): The types of evaluations seen so far.
        query (str): Query for the RAG evaluation.
        context (str): Context for the RAG evaluation.
        response (str): Response for the RAG evaluation.
        old_context (List[str]): Old context nodes for Context Conciseness
            evaluation.
        new_context (List[str]): New context nodes for Context Conciseness
            evaluation.
        context_conciseness_run_id (UUID): Run ID for Context Conciseness
            evaluation.
        multi_queries (List[str]): List of multi queries for Multi Query
            evaluation.
        multi_query_run_id (UUID): Run ID for Multi Query evaluation.
        multi_query_daugher_run_id (UUID): Run ID for Multi Query daughter
            evaluation.

    Every run ID starts out as the all-zero UUID.
    """

    def __init__(self, project_name: str) -> None:
        """Set every field to its empty starting value."""
        # Project name and accumulated results
        self.project_name: str = project_name
        self.uptrain_results: DefaultDict[str, Any] = defaultdict(list)

        # Evaluation types encountered so far
        self.eval_types: Set[str] = set()

        # RAG: all three text fields start empty
        self.query: str = ""
        self.context: str = ""
        self.response: str = ""

        # Context conciseness
        self.old_context: List[str] = list()
        self.new_context: List[str] = list()
        self.context_conciseness_run_id: UUID = UUID(int=0)

        # Multi query
        self.multi_queries: List[str] = list()
        self.multi_query_run_id: UUID = UUID(int=0)
        self.multi_query_daugher_run_id: UUID = UUID(int=0)
|
||||
|
||||
|
||||
class UpTrainCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs evaluation results to uptrain and the console.

    Args:
        project_name (str): The project name to be shown in UpTrain dashboard.
        key_type (str): Type of key to use. Must be 'uptrain' or 'openai'.
        api_key (str): API key for the UpTrain or OpenAI API.
            (This key is required to perform evaluations using GPT.)
        model (str): The model to use for evaluation.
        log_results (bool): Whether evaluation results are also printed
            through this module's logger.

    Raises:
        ValueError: If the key type is invalid.
        ImportError: If the `uptrain` package is not installed.

    """

    def __init__(
        self,
        *,
        project_name: str = "langchain",
        key_type: str = "openai",
        api_key: str = "sk-****************",  # The API key to use for evaluation
        model: str = "gpt-3.5-turbo",  # The model to use for evaluation
        log_results: bool = True,
    ) -> None:
        """Initializes the `UpTrainCallbackHandler`."""
        super().__init__()

        uptrain = import_uptrain()

        self.log_results = log_results

        # Set uptrain variables
        self.schema = UpTrainDataSchema(project_name=project_name)
        self.first_score_printed_flag = False

        if key_type == "uptrain":
            settings = uptrain.Settings(uptrain_access_token=api_key, model=model)
            self.uptrain_client = uptrain.APIClient(settings=settings)
        elif key_type == "openai":
            settings = uptrain.Settings(
                openai_api_key=api_key, evaluate_locally=True, model=model
            )
            self.uptrain_client = uptrain.EvalLLM(settings=settings)
        else:
            raise ValueError("Invalid key type: Must be 'uptrain' or 'openai'")

    def uptrain_evaluate(
        self,
        evaluation_name: str,
        data: List[Dict[str, Any]],
        checks: List[str],
    ) -> None:
        """Run an evaluation through the UpTrain client and record the result.

        The result is appended to `self.schema.uptrain_results` under the
        project name and, when `log_results` is set, printed row by row.

        Args:
            evaluation_name: Name under which the evaluation is submitted.
            data: Rows to evaluate.
            checks: The UpTrain checks to run on `data`.
        """
        if self.uptrain_client.__class__.__name__ == "APIClient":
            uptrain_result = self.uptrain_client.log_and_evaluate(
                project_name=self.schema.project_name,
                evaluation_name=evaluation_name,
                data=data,
                checks=checks,
            )
        else:
            uptrain_result = self.uptrain_client.evaluate(
                project_name=self.schema.project_name,
                evaluation_name=evaluation_name,
                data=data,
                checks=checks,
            )
        self.schema.uptrain_results[self.schema.project_name].append(uptrain_result)

        score_name_map = {
            "score_context_relevance": "Context Relevance Score",
            "score_factual_accuracy": "Factual Accuracy Score",
            "score_response_completeness": "Response Completeness Score",
            "score_sub_query_completeness": "Sub Query Completeness Score",
            "score_context_reranking": "Context Reranking Score",
            "score_context_conciseness": "Context Conciseness Score",
            "score_multi_query_accuracy": "Multi Query Accuracy Score",
        }

        if self.log_results:
            # Set logger level to INFO to print the evaluation results
            logger.setLevel(logging.INFO)

        try:
            for row in uptrain_result:
                columns = list(row.keys())
                for column in columns:
                    if column == "question":
                        logger.info(f"\nQuestion: {row[column]}")
                        self.first_score_printed_flag = False
                    elif column == "response":
                        logger.info(f"Response: {row[column]}")
                        self.first_score_printed_flag = False
                    elif column == "variants":
                        logger.info("Multi Queries:")
                        for variant in row[column]:
                            logger.info(f" - {variant}")
                        self.first_score_printed_flag = False
                    elif column.startswith("score"):
                        # Blank line before the first score of each group.
                        if not self.first_score_printed_flag:
                            logger.info("")
                            self.first_score_printed_flag = True
                        if column in score_name_map:
                            logger.info(f"{score_name_map[column]}: {row[column]}")
                        else:
                            logger.info(f"{column}: {row[column]}")
        finally:
            # Restore the level even if a malformed result row raises above;
            # otherwise the logger would be left at INFO.
            if self.log_results:
                # Set logger level back to WARNING
                # (We are doing this to avoid printing the logs from HTTP requests)
                logger.setLevel(logging.WARNING)

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Log records to uptrain when an LLM ends."""
        uptrain = import_uptrain()
        self.schema.response = response.generations[0][0].text
        if (
            "qa_rag" in self.schema.eval_types
            and parent_run_id != self.schema.multi_query_daugher_run_id
        ):
            data = [
                {
                    "question": self.schema.query,
                    "context": self.schema.context,
                    "response": self.schema.response,
                }
            ]

            self.uptrain_evaluate(
                evaluation_name="rag",
                data=data,
                checks=[
                    uptrain.Evals.CONTEXT_RELEVANCE,
                    uptrain.Evals.FACTUAL_ACCURACY,
                    uptrain.Evals.RESPONSE_COMPLETENESS,
                ],
            )

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        tags: Optional[List[str]] = None,
        parent_run_id: Optional[UUID] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_type: Optional[str] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Capture RAG inputs and track the multi-query child run.

        A chain started directly under the multi-query retriever run is
        remembered as its daughter run. A chain whose inputs are exactly
        `{"context", "question"}` enables the "qa_rag" evaluation and its
        context and question are stored on the schema.
        """
        if parent_run_id == self.schema.multi_query_run_id:
            self.schema.multi_query_daugher_run_id = run_id
        if isinstance(inputs, dict) and set(inputs.keys()) == {"context", "question"}:
            self.schema.eval_types.add("qa_rag")

            raw_context = inputs["context"]
            context = ""
            if isinstance(raw_context, Document):
                context = raw_context.page_content
            elif isinstance(raw_context, list):
                # One document per line, each followed by a newline.
                context = "".join(doc.page_content + "\n" for doc in raw_context)
            elif isinstance(raw_context, str):
                context = raw_context
            self.schema.context = context
            self.schema.query = inputs["question"]

    def on_retriever_start(
        self,
        serialized: Dict[str, Any],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Record which retriever-based evaluations apply to this run.

        The retriever kind is detected from the entries of
        `serialized["id"]`.
        """
        # Tolerate a missing payload or a payload without an "id" entry
        # instead of raising from inside a callback.
        component_id = (serialized or {}).get("id") or []

        if "contextual_compression" in component_id:
            self.schema.eval_types.add("contextual_compression")
            self.schema.query = query
            self.schema.context_conciseness_run_id = run_id

        if "multi_query" in component_id:
            self.schema.eval_types.add("multi_query")
            self.schema.multi_query_run_id = run_id
            self.schema.query = query
        elif "multi_query" in self.schema.eval_types:
            self.schema.multi_queries.append(query)

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when Retriever ends running."""
        uptrain = import_uptrain()
        if run_id == self.schema.multi_query_run_id:
            data = [
                {
                    "question": self.schema.query,
                    "variants": self.schema.multi_queries,
                }
            ]

            self.uptrain_evaluate(
                evaluation_name="multi_query",
                data=data,
                checks=[uptrain.Evals.MULTI_QUERY_ACCURACY],
            )
        if "contextual_compression" in self.schema.eval_types:
            # NOTE(review): old_context / new_context are only ever appended
            # to, so they keep growing across invocations of the same
            # handler - confirm whether they should be cleared after each
            # evaluation.
            if parent_run_id == self.schema.context_conciseness_run_id:
                for doc in documents:
                    self.schema.old_context.append(doc.page_content)
            elif run_id == self.schema.context_conciseness_run_id:
                for doc in documents:
                    self.schema.new_context.append(doc.page_content)
                context = "\n".join(
                    [
                        f"{index}. {string}"
                        for index, string in enumerate(self.schema.old_context, start=1)
                    ]
                )
                reranked_context = "\n".join(
                    [
                        f"{index}. {string}"
                        for index, string in enumerate(self.schema.new_context, start=1)
                    ]
                )
                data = [
                    {
                        "question": self.schema.query,
                        "context": context,
                        "concise_context": reranked_context,
                        "reranked_context": reranked_context,
                    }
                ]
                self.uptrain_evaluate(
                    evaluation_name="context_reranking",
                    data=data,
                    checks=[
                        uptrain.Evals.CONTEXT_CONCISENESS,
                        uptrain.Evals.CONTEXT_RERANKING,
                    ],
                )
|
||||
239
venv/Lib/site-packages/langchain_community/callbacks/utils.py
Normal file
239
venv/Lib/site-packages/langchain_community/callbacks/utils.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, Tuple, Union
|
||||
|
||||
from langchain_core.utils import guard_import
|
||||
|
||||
|
||||
def import_spacy() -> Any:
    """Import the spacy python package and raise an error if it is not installed.

    Returns:
        The result of `guard_import("spacy")` (presumably the imported
        module).
    """
    return guard_import("spacy")
|
||||
|
||||
|
||||
def import_pandas() -> Any:
    """Import the pandas python package and raise an error if it is not installed.

    Returns:
        The result of `guard_import("pandas")` (presumably the imported
        module).
    """
    return guard_import("pandas")
|
||||
|
||||
|
||||
def import_textstat() -> Any:
    """Import the textstat python package and raise an error if it is not installed.

    Returns:
        The result of `guard_import("textstat")` (presumably the imported
        module).
    """
    return guard_import("textstat")
|
||||
|
||||
|
||||
def _flatten_dict(
|
||||
nested_dict: Dict[str, Any], parent_key: str = "", sep: str = "_"
|
||||
) -> Iterable[Tuple[str, Any]]:
|
||||
"""
|
||||
Generator that yields flattened items from a nested dictionary for a flat dict.
|
||||
|
||||
Parameters:
|
||||
nested_dict (dict): The nested dictionary to flatten.
|
||||
parent_key (str): The prefix to prepend to the keys of the flattened dict.
|
||||
sep (str): The separator to use between the parent key and the key of the
|
||||
flattened dictionary.
|
||||
|
||||
Yields:
|
||||
(str, any): A key-value pair from the flattened dictionary.
|
||||
"""
|
||||
for key, value in nested_dict.items():
|
||||
new_key = parent_key + sep + key if parent_key else key
|
||||
if isinstance(value, dict):
|
||||
yield from _flatten_dict(value, new_key, sep)
|
||||
else:
|
||||
yield new_key, value
|
||||
|
||||
|
||||
def flatten_dict(
    nested_dict: Dict[str, Any], parent_key: str = "", sep: str = "_"
) -> Dict[str, Any]:
    """Collapse a nested dictionary into a single-level dictionary.

    Parameters:
        nested_dict (dict): The nested dictionary to flatten.
        parent_key (str): The prefix to prepend to the keys of the flattened dict.
        sep (str): The separator placed between the parent key and each child
            key when building a flattened key.

    Returns:
        (dict): A flat dictionary.

    """
    # Later pairs overwrite earlier ones when two paths flatten to the same key.
    return dict(_flatten_dict(nested_dict, parent_key, sep))
|
||||
|
||||
|
||||
def hash_string(s: str) -> str:
    """Return the hexadecimal sha1 digest of a string.

    Parameters:
        s (str): The string to hash.

    Returns:
        (str): The hashed string.
    """
    hasher = hashlib.sha1()
    # The text is hashed as its UTF-8 byte encoding.
    hasher.update(s.encode("utf-8"))
    return hasher.hexdigest()
|
||||
|
||||
|
||||
def load_json(json_path: Union[str, Path]) -> str:
    """Load json file to a string.

    The file is decoded as UTF-8 (the standard JSON encoding) instead of the
    platform's locale encoding, so the result does not depend on the host OS.

    Parameters:
        json_path (str): The path to the json file.

    Returns:
        (str): The string representation of the json file.
    """
    return Path(json_path).read_text(encoding="utf-8")
|
||||
|
||||
|
||||
class BaseMetadataCallbackHandler:
    """Handle the metadata and associated function states for callbacks.

    This class only stores counters, flags and record lists; it is up to the
    callback handler using it to update them.

    Attributes:
        step (int): The current step.
        starts (int): The number of times the start method has been called.
        ends (int): The number of times the end method has been called.
        errors (int): The number of times the error method has been called.
        text_ctr (int): The number of times the text method has been called.
        ignore_llm_ (bool): Whether to ignore llm callbacks.
        ignore_chain_ (bool): Whether to ignore chain callbacks.
        ignore_agent_ (bool): Whether to ignore agent callbacks.
        ignore_retriever_ (bool): Whether to ignore retriever callbacks.
        always_verbose_ (bool): Whether to always be verbose.
        chain_starts (int): The number of times the chain start method has been called.
        chain_ends (int): The number of times the chain end method has been called.
        llm_starts (int): The number of times the llm start method has been called.
        llm_ends (int): The number of times the llm end method has been called.
        llm_streams (int): Counter for llm streaming (new token) events.
        tool_starts (int): The number of times the tool start method has been called.
        tool_ends (int): The number of times the tool end method has been called.
        agent_ends (int): The number of times the agent end method has been called.
        on_llm_start_records (list): A list of records of the on_llm_start method.
        on_llm_token_records (list): A list of records of the on_llm_token method.
        on_llm_end_records (list): A list of records of the on_llm_end method.
        on_chain_start_records (list): A list of records of the on_chain_start method.
        on_chain_end_records (list): A list of records of the on_chain_end method.
        on_tool_start_records (list): A list of records of the on_tool_start method.
        on_tool_end_records (list): A list of records of the on_tool_end method.
        on_text_records (list): A list of records of the on_text method.
        on_agent_finish_records (list): A list of records of the on_agent_finish
            method.
        on_agent_action_records (list): A list of records of the on_agent_action
            method.
    """

    def __init__(self) -> None:
        """Initialize every counter to zero, every flag to False, every list empty."""
        self.step = 0

        self.starts = 0
        self.ends = 0
        self.errors = 0
        self.text_ctr = 0

        self.ignore_llm_ = False
        self.ignore_chain_ = False
        self.ignore_agent_ = False
        self.ignore_retriever_ = False
        self.always_verbose_ = False

        self.chain_starts = 0
        self.chain_ends = 0

        self.llm_starts = 0
        self.llm_ends = 0
        self.llm_streams = 0

        self.tool_starts = 0
        self.tool_ends = 0

        self.agent_ends = 0

        self.on_llm_start_records: list = []
        self.on_llm_token_records: list = []
        self.on_llm_end_records: list = []

        self.on_chain_start_records: list = []
        self.on_chain_end_records: list = []

        self.on_tool_start_records: list = []
        self.on_tool_end_records: list = []

        self.on_text_records: list = []
        self.on_agent_finish_records: list = []
        self.on_agent_action_records: list = []

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return self.always_verbose_

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return self.ignore_llm_

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return self.ignore_chain_

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return self.ignore_agent_

    def get_custom_callback_meta(self) -> Dict[str, Any]:
        """Return a snapshot of the current counters keyed by attribute name.

        Only the integer counters are included; the ignore/verbose flags and
        the record lists are not part of the snapshot.
        """
        return {
            "step": self.step,
            "starts": self.starts,
            "ends": self.ends,
            "errors": self.errors,
            "text_ctr": self.text_ctr,
            "chain_starts": self.chain_starts,
            "chain_ends": self.chain_ends,
            "llm_starts": self.llm_starts,
            "llm_ends": self.llm_ends,
            "llm_streams": self.llm_streams,
            "tool_starts": self.tool_starts,
            "tool_ends": self.tool_ends,
            "agent_ends": self.agent_ends,
        }

    def reset_callback_meta(self) -> None:
        """Reset the callback metadata.

        Restores every counter, flag and record list to the value assigned in
        ``__init__``. The record lists are replaced by new empty lists rather
        than cleared in place.
        """
        self.step = 0

        self.starts = 0
        self.ends = 0
        self.errors = 0
        self.text_ctr = 0

        self.ignore_llm_ = False
        self.ignore_chain_ = False
        self.ignore_agent_ = False
        # Reset together with the other ignore flags so the state after a reset
        # matches a freshly constructed instance.
        self.ignore_retriever_ = False
        self.always_verbose_ = False

        self.chain_starts = 0
        self.chain_ends = 0

        self.llm_starts = 0
        self.llm_ends = 0
        self.llm_streams = 0

        self.tool_starts = 0
        self.tool_ends = 0

        self.agent_ends = 0

        self.on_llm_start_records = []
        self.on_llm_token_records = []
        self.on_llm_end_records = []

        self.on_chain_start_records = []
        self.on_chain_end_records = []

        self.on_tool_start_records = []
        self.on_tool_end_records = []

        self.on_text_records = []
        self.on_agent_finish_records = []
        self.on_agent_action_records = []
|
||||
@@ -0,0 +1,597 @@
|
||||
import json
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from langchain_core._api import warn_deprecated
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.utils import guard_import
|
||||
|
||||
from langchain_community.callbacks.utils import (
|
||||
BaseMetadataCallbackHandler,
|
||||
flatten_dict,
|
||||
hash_string,
|
||||
import_pandas,
|
||||
import_spacy,
|
||||
import_textstat,
|
||||
)
|
||||
|
||||
|
||||
def import_wandb() -> Any:
    """Return the ``wandb`` module, raising an error if it is not installed.

    The import is delegated to ``guard_import``.
    """
    wandb_module = guard_import("wandb")
    return wandb_module
|
||||
|
||||
|
||||
def load_json_to_dict(json_path: Union[str, Path]) -> dict:
    """Load json file to a dictionary.

    The file is decoded as UTF-8 (the standard JSON encoding) instead of the
    platform's locale encoding, so non-ASCII content parses the same way on
    every host OS.

    Parameters:
        json_path (str): The path to the json file.

    Returns:
        (dict): The dictionary representation of the json file.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
def analyze_text(
    text: str,
    complexity_metrics: bool = True,
    visualize: bool = True,
    nlp: Any = None,
    output_dir: Optional[Union[str, Path]] = None,
) -> dict:
    """Analyze text using textstat and spacy.

    Parameters:
        text (str): The text to analyze.
        complexity_metrics (bool): Whether to compute complexity metrics.
        visualize (bool): Whether to visualize the text.
        nlp (spacy.lang): The spacy language model to use for visualization.
        output_dir (str): The directory to save the visualization files to.

    Returns:
        `dict` containing the complexity metrics and visualization
            files serialized in a wandb.Html element.
    """
    resp: Dict[str, Any] = {}
    # All three packages are imported up front, whichever options are enabled.
    textstat = import_textstat()
    wandb = import_wandb()
    spacy = import_spacy()
    if complexity_metrics:
        text_complexity_metrics = {
            "flesch_reading_ease": textstat.flesch_reading_ease(text),
            "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
            "smog_index": textstat.smog_index(text),
            "coleman_liau_index": textstat.coleman_liau_index(text),
            "automated_readability_index": textstat.automated_readability_index(text),
            "dale_chall_readability_score": textstat.dale_chall_readability_score(text),
            "difficult_words": textstat.difficult_words(text),
            "linsear_write_formula": textstat.linsear_write_formula(text),
            "gunning_fog": textstat.gunning_fog(text),
            "text_standard": textstat.text_standard(text),
            "fernandez_huerta": textstat.fernandez_huerta(text),
            "szigriszt_pazos": textstat.szigriszt_pazos(text),
            "gutierrez_polini": textstat.gutierrez_polini(text),
            "crawford": textstat.crawford(text),
            "gulpease_index": textstat.gulpease_index(text),
            "osman": textstat.osman(text),
        }
        resp.update(text_complexity_metrics)

    # Visualizations need both a loaded pipeline and somewhere to write files.
    if visualize and nlp and output_dir is not None:
        doc = nlp(text)

        dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
        # File names are derived from the text, so identical text maps to the
        # same file.
        dep_output_path = Path(output_dir, hash_string(f"dep-{text}") + ".html")
        # write_text closes the file before wandb.Html reads it back below.
        dep_output_path.write_text(dep_out, encoding="utf-8")

        ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
        ent_output_path = Path(output_dir, hash_string(f"ent-{text}") + ".html")
        ent_output_path.write_text(ent_out, encoding="utf-8")

        text_visualizations = {
            "dependency_tree": wandb.Html(str(dep_output_path)),
            "entities": wandb.Html(str(ent_output_path)),
        }
        resp.update(text_visualizations)

    return resp
|
||||
|
||||
|
||||
def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
    """Build a wandb HTML element showing a prompt and its generation.

    Newlines in both texts are turned into ``<br>`` tags; the texts are
    otherwise inserted into the markup unchanged (no HTML escaping).

    Parameters:
        prompt (str): The prompt.
        generation (str): The generation.

    Returns:
        (wandb.Html): The html element."""
    wandb = import_wandb()
    formatted_prompt = "<br>".join(prompt.split("\n"))
    formatted_generation = "<br>".join(generation.split("\n"))

    markup = f"""
<p style="color:black;">{formatted_prompt}:</p>
<blockquote>
<p style="color:green;">
{formatted_generation}
</p>
</blockquote>
"""
    return wandb.Html(markup, inject=False)
|
||||
|
||||
|
||||
class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback Handler that logs to Weights and Biases.

    Parameters:
        job_type (str): The type of job.
        project (str): The project to log to.
        entity (str): The entity to log to.
        tags (list): The tags to log.
        group (str): The group to log to.
        name (str): The name of the run.
        notes (str): The notes to log.
        visualize (bool): Whether to visualize the run.
        complexity_metrics (bool): Whether to log complexity metrics.
        stream_logs (bool): Whether to stream callback actions to W&B

    This handler will utilize the associated callback method called and formats
    the input of each callback function with metadata regarding the state of LLM run,
    and adds the response to the list of records for both the {method}_records and
    action. It then logs the response using the run.log() method to Weights and Biases.
    The ``*_error`` callbacks only update the step and error counters.
    """

    def __init__(
        self,
        job_type: Optional[str] = None,
        project: Optional[str] = "langchain_callback_demo",
        entity: Optional[str] = None,
        tags: Optional[Sequence] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        notes: Optional[str] = None,
        visualize: bool = False,
        complexity_metrics: bool = False,
        stream_logs: bool = False,
    ) -> None:
        """Initialize callback handler.

        Checks that wandb, pandas, textstat and spacy are importable, starts a
        wandb run, creates a temporary directory for generated files, loads the
        ``en_core_web_sm`` spacy pipeline and emits deprecation warnings.
        """

        wandb = import_wandb()
        # pandas and textstat are imported here only to fail early when missing.
        import_pandas()
        import_textstat()
        spacy = import_spacy()
        super().__init__()

        self.job_type = job_type
        self.project = project
        self.entity = entity
        self.tags = tags
        self.group = group
        self.name = name
        self.notes = notes
        self.visualize = visualize
        self.complexity_metrics = complexity_metrics
        self.stream_logs = stream_logs

        self.temp_dir = tempfile.TemporaryDirectory()
        self.run = wandb.init(
            job_type=self.job_type,
            project=self.project,
            entity=self.entity,
            tags=self.tags,
            group=self.group,
            name=self.name,
            notes=self.notes,
        )
        warning = (
            "DEPRECATION: The `WandbCallbackHandler` will soon be deprecated in favor "
            "of the `WandbTracer`. Please update your code to use the `WandbTracer` "
            "instead."
        )
        wandb.termwarn(
            warning,
            repeat=False,
        )
        self.callback_columns: list = []
        self.action_records: list = []
        self.nlp = spacy.load("en_core_web_sm")
        warn_deprecated(
            "0.3.8",
            pending=False,
            message=(
                "Please use the WeaveTracer instead of the WandbCallbackHandler. "
                "The WeaveTracer is a more flexible and powerful tool for logging "
                "and tracing your LangChain callables."
                "Find more information at https://weave-docs.wandb.ai/guides/integrations/langchain"
            ),
            alternative=(
                "Please instantiate the WeaveTracer from "
                "weave.integrations.langchain import WeaveTracer ."
                "For autologging simply use weave.init() and log all traces "
                "from your LangChain callables."
            ),
        )

    def _init_resp(self) -> Dict:
        """Return a new record with every configured callback column set to None."""
        return {k: None for k in self.callback_columns}

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts.

        One record is stored (and optionally streamed) per prompt.
        """
        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        for prompt in prompts:
            # Copy so each prompt gets an independent record.
            prompt_resp = deepcopy(resp)
            prompt_resp["prompts"] = prompt
            self.on_llm_start_records.append(prompt_resp)
            self.action_records.append(prompt_resp)
            if self.stream_logs:
                self.run.log(prompt_resp)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        self.step += 1
        self.llm_streams += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_new_token", "token": token})
        resp.update(self.get_custom_callback_meta())

        self.on_llm_token_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running.

        One record is stored per generation, enriched with the output of
        ``analyze_text`` for the generated text.
        """
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_end"})
        resp.update(flatten_dict(response.llm_output or {}))
        resp.update(self.get_custom_callback_meta())

        for generations in response.generations:
            for generation in generations:
                generation_resp = deepcopy(resp)
                generation_resp.update(flatten_dict(generation.dict()))
                generation_resp.update(
                    analyze_text(
                        generation.text,
                        complexity_metrics=self.complexity_metrics,
                        visualize=self.visualize,
                        nlp=self.nlp,
                        output_dir=self.temp_dir.name,
                    )
                )
                self.on_llm_end_records.append(generation_resp)
                self.action_records.append(generation_resp)
                if self.stream_logs:
                    self.run.log(generation_resp)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors."""
        self.step += 1
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running.

        Reads ``inputs["input"]``, which must be either a string or a list of
        mappings; one record is stored per input.

        Raises:
            KeyError: If ``inputs`` has no ``"input"`` key.
            ValueError: If ``inputs["input"]`` is neither a string nor a list.
        """
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        chain_input = inputs["input"]

        if isinstance(chain_input, str):
            input_resp = deepcopy(resp)
            input_resp["input"] = chain_input
            self.on_chain_start_records.append(input_resp)
            self.action_records.append(input_resp)
            if self.stream_logs:
                self.run.log(input_resp)
        elif isinstance(chain_input, list):
            for inp in chain_input:
                input_resp = deepcopy(resp)
                # Each list element is merged into the record as key/value pairs.
                input_resp.update(inp)
                self.on_chain_start_records.append(input_resp)
                self.action_records.append(input_resp)
                if self.stream_logs:
                    self.run.log(input_resp)
        else:
            raise ValueError("Unexpected data format provided!")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running.

        Raises:
            KeyError: If ``outputs`` has no ``"output"`` key.
        """
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_end", "outputs": outputs["output"]})
        resp.update(self.get_custom_callback_meta())

        self.on_chain_end_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors."""
        self.step += 1
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_start", "input_str": input_str})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        self.on_tool_start_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Run when tool ends running.

        The output is converted with ``str()`` before being recorded.
        """
        output = str(output)
        self.step += 1
        self.tool_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_end", "output": output})
        resp.update(self.get_custom_callback_meta())

        self.on_tool_end_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run when arbitrary text is reported to the handler."""
        self.step += 1
        self.text_ctr += 1

        resp = self._init_resp()
        resp.update({"action": "on_text", "text": text})
        resp.update(self.get_custom_callback_meta())

        self.on_text_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running.

        Raises:
            KeyError: If ``finish.return_values`` has no ``"output"`` key.
        """
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_agent_finish",
                "output": finish.return_values["output"],
                "log": finish.log,
            }
        )
        resp.update(self.get_custom_callback_meta())

        self.on_agent_finish_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action.

        Counted as a tool start (``tool_starts`` and ``starts`` are incremented).
        """
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_agent_action",
                "tool": action.tool,
                "tool_input": action.tool_input,
                "log": action.log,
            }
        )
        resp.update(self.get_custom_callback_meta())
        self.on_agent_action_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def _create_session_analysis_df(self) -> Any:
        """Create a dataframe with all the information from the session.

        Places the recorded prompts next to the recorded generations (column
        wise) and adds a ``chat_html`` column rendering each prompt/output pair.
        """
        pd = import_pandas()
        on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
        on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)

        llm_input_prompts_df = (
            on_llm_start_records_df[["step", "prompts", "name"]]
            .dropna(axis=1)
            .rename({"step": "prompt_step"}, axis=1)
        )
        complexity_metrics_columns = []
        visualizations_columns = []

        if self.complexity_metrics:
            complexity_metrics_columns = [
                "flesch_reading_ease",
                "flesch_kincaid_grade",
                "smog_index",
                "coleman_liau_index",
                "automated_readability_index",
                "dale_chall_readability_score",
                "difficult_words",
                "linsear_write_formula",
                "gunning_fog",
                "text_standard",
                "fernandez_huerta",
                "szigriszt_pazos",
                "gutierrez_polini",
                "crawford",
                "gulpease_index",
                "osman",
            ]

        if self.visualize:
            visualizations_columns = ["dependency_tree", "entities"]

        llm_outputs_df = (
            on_llm_end_records_df[
                [
                    "step",
                    "text",
                    "token_usage_total_tokens",
                    "token_usage_prompt_tokens",
                    "token_usage_completion_tokens",
                ]
                + complexity_metrics_columns
                + visualizations_columns
            ]
            .dropna(axis=1)
            .rename({"step": "output_step", "text": "output"}, axis=1)
        )
        session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
        session_analysis_df["chat_html"] = session_analysis_df[
            ["prompts", "output"]
        ].apply(
            lambda row: construct_html_from_prompt_and_generation(
                row["prompts"], row["output"]
            ),
            axis=1,
        )
        return session_analysis_df

    def flush_tracker(
        self,
        langchain_asset: Any = None,
        reset: bool = True,
        finish: bool = False,
        job_type: Optional[str] = None,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        tags: Optional[Sequence] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        notes: Optional[str] = None,
        visualize: Optional[bool] = None,
        complexity_metrics: Optional[bool] = None,
    ) -> None:
        """Flush the tracker and reset the session.

        Logs the action records and the session analysis as wandb tables,
        optionally stores ``langchain_asset`` as a model artifact, and then
        finishes and/or re-initializes the run.

        Args:
            langchain_asset: The langchain asset to save.
            reset: Whether to reset the session. Resetting finishes the current
                run and re-runs ``__init__``; any argument left at ``None``
                falls back to the handler's current setting.
            finish: Whether to finish the run.
            job_type: The job type.
            project: The project.
            entity: The entity.
            tags: The tags.
            group: The group.
            name: The name.
            notes: The notes.
            visualize: Whether to visualize. An explicit ``False`` is honoured.
            complexity_metrics: Whether to compute complexity metrics. An
                explicit ``False`` is honoured.

        Returns:
            None
        """
        pd = import_pandas()
        wandb = import_wandb()
        action_records_table = wandb.Table(dataframe=pd.DataFrame(self.action_records))
        session_analysis_table = wandb.Table(
            dataframe=self._create_session_analysis_df()
        )
        self.run.log(
            {
                "action_records": action_records_table,
                "session_analysis": session_analysis_table,
            }
        )

        if langchain_asset:
            langchain_asset_path = Path(self.temp_dir.name, "model.json")
            model_artifact = wandb.Artifact(name="model", type="model")
            model_artifact.add(action_records_table, name="action_records")
            model_artifact.add(session_analysis_table, name="session_analysis")
            try:
                langchain_asset.save(langchain_asset_path)
                model_artifact.add_file(str(langchain_asset_path))
                model_artifact.metadata = load_json_to_dict(langchain_asset_path)
            except ValueError:
                # save() raised ValueError: retry with the asset's save_agent().
                langchain_asset.save_agent(langchain_asset_path)
                model_artifact.add_file(str(langchain_asset_path))
                model_artifact.metadata = load_json_to_dict(langchain_asset_path)
            except NotImplementedError as e:
                print("Could not save model.")  # noqa: T201
                print(repr(e))  # noqa: T201
            # The artifact (with the two tables) is logged even if saving failed.
            self.run.log_artifact(model_artifact)

        if finish or reset:
            self.run.finish()
            self.temp_dir.cleanup()
            self.reset_callback_meta()
        if reset:
            self.__init__(  # type: ignore[misc]
                job_type=job_type if job_type else self.job_type,
                project=project if project else self.project,
                entity=entity if entity else self.entity,
                tags=tags if tags else self.tags,
                group=group if group else self.group,
                name=name if name else self.name,
                notes=notes if notes else self.notes,
                # Compare against None so an explicit False is not replaced by
                # the previous setting.
                visualize=visualize if visualize is not None else self.visualize,
                complexity_metrics=(
                    complexity_metrics
                    if complexity_metrics is not None
                    else self.complexity_metrics
                ),
                # Carry the streaming setting over; otherwise re-initializing
                # would silently switch it back to the default (False).
                stream_logs=self.stream_logs,
            )
|
||||
@@ -0,0 +1,187 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.utils import get_from_env, guard_import
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from whylogs.api.logger.logger import Logger
|
||||
|
||||
diagnostic_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def import_langkit(
    sentiment: bool = False,
    toxicity: bool = False,
    themes: bool = False,
) -> Any:
    """Return the ``langkit`` package, raising an error if it is not installed.

    ``langkit.regexes`` and ``langkit.textstat`` are always imported; the
    remaining submodules are imported only when their flag is set.

    Args:
        sentiment: Whether to import the langkit.sentiment module. Defaults to False.
        toxicity: Whether to import the langkit.toxicity module. Defaults to False.
        themes: Whether to import the langkit.themes module. Defaults to False.

    Returns:
        The imported langkit module.
    """
    langkit = guard_import("langkit")
    for required_module in ("langkit.regexes", "langkit.textstat"):
        guard_import(required_module)

    # (flag, module) pairs, processed in a fixed order.
    optional_modules = (
        (sentiment, "langkit.sentiment"),
        (toxicity, "langkit.toxicity"),
        (themes, "langkit.themes"),
    )
    for enabled, module_name in optional_modules:
        if enabled:
            guard_import(module_name)
    return langkit
|
||||
|
||||
|
||||
class WhyLabsCallbackHandler(BaseCallbackHandler):
|
||||
"""
|
||||
Callback Handler for logging to WhyLabs. This callback handler utilizes
|
||||
`langkit` to extract features from the prompts & responses when interacting with
|
||||
an LLM. These features can be used to guardrail, evaluate, and observe interactions
|
||||
over time to detect issues relating to hallucinations, prompt engineering,
|
||||
or output validation. LangKit is an LLM monitoring toolkit developed by WhyLabs.
|
||||
|
||||
Here are some examples of what can be monitored with LangKit:
|
||||
* Text Quality
|
||||
- readability score
|
||||
- complexity and grade scores
|
||||
* Text Relevance
|
||||
- Similarity scores between prompt/responses
|
||||
- Similarity scores against user-defined themes
|
||||
- Topic classification
|
||||
* Security and Privacy
|
||||
- patterns - count of strings matching a user-defined regex pattern group
|
||||
- jailbreaks - similarity scores with respect to known jailbreak attempts
|
||||
- prompt injection - similarity scores with respect to known prompt attacks
|
||||
- refusals - similarity scores with respect to known LLM refusal responses
|
||||
* Sentiment and Toxicity
|
||||
- sentiment analysis
|
||||
- toxicity analysis
|
||||
|
||||
For more information, see https://docs.whylabs.ai/docs/language-model-monitoring
|
||||
or check out the LangKit repo here: https://github.com/whylabs/langkit
|
||||
|
||||
---
|
||||
Args:
|
||||
api_key (Optional[str]): WhyLabs API key. Optional because the preferred
|
||||
way to specify the API key is with environment variable
|
||||
WHYLABS_API_KEY.
|
||||
org_id (Optional[str]): WhyLabs organization id to write profiles to.
|
||||
Optional because the preferred way to specify the organization id is
|
||||
with environment variable WHYLABS_DEFAULT_ORG_ID.
|
||||
dataset_id (Optional[str]): WhyLabs dataset id to write profiles to.
|
||||
Optional because the preferred way to specify the dataset id is
|
||||
with environment variable WHYLABS_DEFAULT_DATASET_ID.
|
||||
sentiment (bool): Whether to enable sentiment analysis. Defaults to False.
|
||||
toxicity (bool): Whether to enable toxicity analysis. Defaults to False.
|
||||
themes (bool): Whether to enable theme analysis. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self, logger: Logger, handler: Any):
|
||||
"""Initiate the rolling logger."""
|
||||
super().__init__()
|
||||
if hasattr(handler, "init"):
|
||||
handler.init(self)
|
||||
if hasattr(handler, "_get_callbacks"):
|
||||
self._callbacks = handler._get_callbacks()
|
||||
else:
|
||||
self._callbacks = dict()
|
||||
diagnostic_logger.warning("initialized handler without callbacks.")
|
||||
self._logger = logger
|
||||
|
||||
def flush(self) -> None:
    """Force the current profile to be written when rollover is supported.

    Does nothing when no logger is bound or when the bound logger has no
    ``_do_rollover`` method.
    """
    if not self._logger or not hasattr(self._logger, "_do_rollover"):
        return

    self._logger._do_rollover()
    diagnostic_logger.info("Flushing WhyLabs logger, writing profile...")
|
||||
|
||||
def close(self) -> None:
    """Close the bound logger so pending profiles can be written out.

    Does nothing when no logger is bound or when the bound logger has no
    ``close`` method.
    """
    if not self._logger or not hasattr(self._logger, "close"):
        return

    self._logger.close()
    diagnostic_logger.info("Closing WhyLabs logger, see you next time!")
|
||||
|
||||
def __enter__(self) -> WhyLabsCallbackHandler:
    """Enter a ``with`` block; the handler itself is the context value."""
    return self
|
||||
|
||||
def __exit__(
    self,
    exception_type: Any,
    exception_value: Any,
    traceback: Any,
) -> None:
    """Leave a ``with`` block by closing the handler.

    The exception details are ignored. Returning ``None`` means an
    exception raised inside the block is not suppressed.
    """
    self.close()
|
||||
|
||||
@classmethod
def from_params(
    cls,
    *,
    api_key: Optional[str] = None,
    org_id: Optional[str] = None,
    dataset_id: Optional[str] = None,
    sentiment: bool = False,
    toxicity: bool = False,
    themes: bool = False,
    logger: Optional[Logger] = None,
) -> WhyLabsCallbackHandler:
    """Instantiate whylogs Logger from params.

    Args:
        api_key (Optional[str]): WhyLabs API key. Optional because the preferred
            way to specify the API key is with environment variable
            WHYLABS_API_KEY.
        org_id (Optional[str]): WhyLabs organization id to write profiles to.
            If not set must be specified in environment variable
            WHYLABS_DEFAULT_ORG_ID.
        dataset_id (Optional[str]): The model or dataset this callback is gathering
            telemetry for. If not set must be specified in environment variable
            WHYLABS_DEFAULT_DATASET_ID.
        sentiment (bool): If True will initialize a model to perform
            sentiment analysis compound score. Defaults to False and will not gather
            this metric.
        toxicity (bool): If True will initialize a model to score
            toxicity. Defaults to False and will not gather this metric.
        themes (bool): If True will initialize a model to calculate
            distance to configured themes. Defaults to False and will not gather
            this metric.
        logger (Optional[Logger]): If specified will bind the configured logger as
            the telemetry gathering agent; ``api_key``, ``org_id`` and
            ``dataset_id`` are then not used. Defaults to LangKit schema with
            periodic WhyLabs writer.

    Returns:
        The object produced by LangKit's ``get_callback_instance`` for the
        chosen logger, built with ``impl=cls``.
    """
    # langkit library will import necessary whylogs libraries
    import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes)

    why = guard_import("whylogs")
    get_callback_instance = guard_import(
        "langkit.callback_handler"
    ).get_callback_instance
    WhyLabsWriter = guard_import("whylogs.api.writer.whylabs").WhyLabsWriter
    udf_schema = guard_import("whylogs.experimental.core.udf_schema").udf_schema

    if logger is None:
        # Explicit arguments win; otherwise fall back to the environment.
        api_key = api_key or get_from_env("api_key", "WHYLABS_API_KEY")
        org_id = org_id or get_from_env("org_id", "WHYLABS_DEFAULT_ORG_ID")
        dataset_id = dataset_id or get_from_env(
            "dataset_id", "WHYLABS_DEFAULT_DATASET_ID"
        )
        whylabs_writer = WhyLabsWriter(
            api_key=api_key, org_id=org_id, dataset_id=dataset_id
        )

        # Default telemetry agent: rolling logger (interval=5, when="M")
        # using the UDF schema, with the WhyLabs writer attached.
        whylabs_logger = why.logger(
            mode="rolling", interval=5, when="M", schema=udf_schema()
        )

        whylabs_logger.append_writer(writer=whylabs_writer)
    else:
        # Lazy %-formatting so the logger actually appears in the message;
        # the previous literal "{logger}" was never interpolated.
        diagnostic_logger.info("Using passed in whylogs logger %s", logger)
        whylabs_logger = logger

    callback_handler_cls = get_callback_instance(logger=whylabs_logger, impl=cls)
    diagnostic_logger.info(
        "Started whylogs Logger with WhyLabsWriter and initialized LangKit. 📝"
    )
    return callback_handler_cls
|
||||
Reference in New Issue
Block a user