initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,157 @@
"""**Callback handlers** allow listening to events in LangChain.
**Class hierarchy:**
.. code-block::
BaseCallbackHandler --> <name>CallbackHandler # Example: AimCallbackHandler
"""
import importlib
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from langchain_community.callbacks.aim_callback import (
AimCallbackHandler,
)
from langchain_community.callbacks.argilla_callback import (
ArgillaCallbackHandler,
)
from langchain_community.callbacks.arize_callback import (
ArizeCallbackHandler,
)
from langchain_community.callbacks.arthur_callback import (
ArthurCallbackHandler,
)
from langchain_community.callbacks.clearml_callback import (
ClearMLCallbackHandler,
)
from langchain_community.callbacks.comet_ml_callback import (
CometCallbackHandler,
)
from langchain_community.callbacks.context_callback import (
ContextCallbackHandler,
)
from langchain_community.callbacks.fiddler_callback import (
FiddlerCallbackHandler,
)
from langchain_community.callbacks.flyte_callback import (
FlyteCallbackHandler,
)
from langchain_community.callbacks.human import (
HumanApprovalCallbackHandler,
)
from langchain_community.callbacks.infino_callback import (
InfinoCallbackHandler,
)
from langchain_community.callbacks.labelstudio_callback import (
LabelStudioCallbackHandler,
)
from langchain_community.callbacks.llmonitor_callback import (
LLMonitorCallbackHandler,
)
from langchain_community.callbacks.manager import (
get_openai_callback,
wandb_tracing_enabled,
)
from langchain_community.callbacks.mlflow_callback import (
MlflowCallbackHandler,
)
from langchain_community.callbacks.openai_info import (
OpenAICallbackHandler,
)
from langchain_community.callbacks.promptlayer_callback import (
PromptLayerCallbackHandler,
)
from langchain_community.callbacks.sagemaker_callback import (
SageMakerCallbackHandler,
)
from langchain_community.callbacks.streamlit import (
LLMThoughtLabeler,
StreamlitCallbackHandler,
)
from langchain_community.callbacks.trubrics_callback import (
TrubricsCallbackHandler,
)
from langchain_community.callbacks.upstash_ratelimit_callback import (
UpstashRatelimitError,
UpstashRatelimitHandler, # noqa: F401
)
from langchain_community.callbacks.uptrain_callback import (
UpTrainCallbackHandler,
)
from langchain_community.callbacks.wandb_callback import (
WandbCallbackHandler,
)
from langchain_community.callbacks.whylabs_callback import (
WhyLabsCallbackHandler,
)
_module_lookup = {
"AimCallbackHandler": "langchain_community.callbacks.aim_callback",
"ArgillaCallbackHandler": "langchain_community.callbacks.argilla_callback",
"ArizeCallbackHandler": "langchain_community.callbacks.arize_callback",
"ArthurCallbackHandler": "langchain_community.callbacks.arthur_callback",
"ClearMLCallbackHandler": "langchain_community.callbacks.clearml_callback",
"CometCallbackHandler": "langchain_community.callbacks.comet_ml_callback",
"ContextCallbackHandler": "langchain_community.callbacks.context_callback",
"FiddlerCallbackHandler": "langchain_community.callbacks.fiddler_callback",
"FlyteCallbackHandler": "langchain_community.callbacks.flyte_callback",
"HumanApprovalCallbackHandler": "langchain_community.callbacks.human",
"InfinoCallbackHandler": "langchain_community.callbacks.infino_callback",
"LLMThoughtLabeler": "langchain_community.callbacks.streamlit",
"LLMonitorCallbackHandler": "langchain_community.callbacks.llmonitor_callback",
"LabelStudioCallbackHandler": "langchain_community.callbacks.labelstudio_callback",
"MlflowCallbackHandler": "langchain_community.callbacks.mlflow_callback",
"OpenAICallbackHandler": "langchain_community.callbacks.openai_info",
"PromptLayerCallbackHandler": "langchain_community.callbacks.promptlayer_callback",
"SageMakerCallbackHandler": "langchain_community.callbacks.sagemaker_callback",
"StreamlitCallbackHandler": "langchain_community.callbacks.streamlit",
"TrubricsCallbackHandler": "langchain_community.callbacks.trubrics_callback",
"UpstashRatelimitError": "langchain_community.callbacks.upstash_ratelimit_callback",
"UpstashRatelimitHandler": "langchain_community.callbacks.upstash_ratelimit_callback", # noqa
"UpTrainCallbackHandler": "langchain_community.callbacks.uptrain_callback",
"WandbCallbackHandler": "langchain_community.callbacks.wandb_callback",
"WhyLabsCallbackHandler": "langchain_community.callbacks.whylabs_callback",
"get_openai_callback": "langchain_community.callbacks.manager",
"wandb_tracing_enabled": "langchain_community.callbacks.manager",
}
def __getattr__(name: str) -> Any:
if name in _module_lookup:
module = importlib.import_module(_module_lookup[name])
return getattr(module, name)
raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = [
"AimCallbackHandler",
"ArgillaCallbackHandler",
"ArizeCallbackHandler",
"ArthurCallbackHandler",
"ClearMLCallbackHandler",
"CometCallbackHandler",
"ContextCallbackHandler",
"FiddlerCallbackHandler",
"FlyteCallbackHandler",
"HumanApprovalCallbackHandler",
"InfinoCallbackHandler",
"LLMThoughtLabeler",
"LLMonitorCallbackHandler",
"LabelStudioCallbackHandler",
"MlflowCallbackHandler",
"OpenAICallbackHandler",
"PromptLayerCallbackHandler",
"SageMakerCallbackHandler",
"StreamlitCallbackHandler",
"TrubricsCallbackHandler",
"UpstashRatelimitError",
"UpstashRatelimitHandler",
"UpTrainCallbackHandler",
"WandbCallbackHandler",
"WhyLabsCallbackHandler",
"get_openai_callback",
"wandb_tracing_enabled",
]

View File

@@ -0,0 +1,434 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import
def import_aim() -> Any:
"""Import the aim python package and raise an error if it is not installed."""
return guard_import("aim")
class BaseMetadataCallbackHandler:
"""Callback handler for the metadata and associated function states for callbacks.
Attributes:
step (int): The current step.
starts (int): The number of times the start method has been called.
ends (int): The number of times the end method has been called.
errors (int): The number of times the error method has been called.
text_ctr (int): The number of times the text method has been called.
ignore_llm_ (bool): Whether to ignore llm callbacks.
ignore_chain_ (bool): Whether to ignore chain callbacks.
ignore_agent_ (bool): Whether to ignore agent callbacks.
ignore_retriever_ (bool): Whether to ignore retriever callbacks.
always_verbose_ (bool): Whether to always be verbose.
chain_starts (int): The number of times the chain start method has been called.
chain_ends (int): The number of times the chain end method has been called.
llm_starts (int): The number of times the llm start method has been called.
llm_ends (int): The number of times the llm end method has been called.
llm_streams (int): The number of times the text method has been called.
tool_starts (int): The number of times the tool start method has been called.
tool_ends (int): The number of times the tool end method has been called.
agent_ends (int): The number of times the agent end method has been called.
"""
def __init__(self) -> None:
self.step = 0
self.starts = 0
self.ends = 0
self.errors = 0
self.text_ctr = 0
self.ignore_llm_ = False
self.ignore_chain_ = False
self.ignore_agent_ = False
self.ignore_retriever_ = False
self.always_verbose_ = False
self.chain_starts = 0
self.chain_ends = 0
self.llm_starts = 0
self.llm_ends = 0
self.llm_streams = 0
self.tool_starts = 0
self.tool_ends = 0
self.agent_ends = 0
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return self.always_verbose_
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return self.ignore_llm_
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return self.ignore_chain_
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
@property
def ignore_retriever(self) -> bool:
"""Whether to ignore retriever callbacks."""
return self.ignore_retriever_
def get_custom_callback_meta(self) -> Dict[str, Any]:
return {
"step": self.step,
"starts": self.starts,
"ends": self.ends,
"errors": self.errors,
"text_ctr": self.text_ctr,
"chain_starts": self.chain_starts,
"chain_ends": self.chain_ends,
"llm_starts": self.llm_starts,
"llm_ends": self.llm_ends,
"llm_streams": self.llm_streams,
"tool_starts": self.tool_starts,
"tool_ends": self.tool_ends,
"agent_ends": self.agent_ends,
}
def reset_callback_meta(self) -> None:
"""Reset the callback metadata."""
self.step = 0
self.starts = 0
self.ends = 0
self.errors = 0
self.text_ctr = 0
self.ignore_llm_ = False
self.ignore_chain_ = False
self.ignore_agent_ = False
self.always_verbose_ = False
self.chain_starts = 0
self.chain_ends = 0
self.llm_starts = 0
self.llm_ends = 0
self.llm_streams = 0
self.tool_starts = 0
self.tool_ends = 0
self.agent_ends = 0
return None
class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""Callback Handler that logs to Aim.
Parameters:
repo (:obj:`str`, optional): Aim repository path or Repo object to which
Run object is bound. If skipped, default Repo is used.
experiment_name (:obj:`str`, optional): Sets Run's `experiment` property.
'default' if not specified. Can be used later to query runs/sequences.
system_tracking_interval (:obj:`int`, optional): Sets the tracking interval
in seconds for system usage metrics (CPU, Memory, etc.). Set to `None`
to disable system metrics tracking.
log_system_params (:obj:`bool`, optional): Enable/Disable logging of system
params such as installed packages, git info, environment variables, etc.
This handler will utilize the associated callback method called and formats
the input of each callback function with metadata regarding the state of LLM run
and then logs the response to Aim.
"""
def __init__(
self,
repo: Optional[str] = None,
experiment_name: Optional[str] = None,
system_tracking_interval: Optional[int] = 10,
log_system_params: bool = True,
) -> None:
"""Initialize callback handler."""
super().__init__()
aim = import_aim()
self.repo = repo
self.experiment_name = experiment_name
self.system_tracking_interval = system_tracking_interval
self.log_system_params = log_system_params
self._run = aim.Run(
repo=self.repo,
experiment=self.experiment_name,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
)
self._run_hash = self._run.hash
self.action_records: list = []
def setup(self, **kwargs: Any) -> None:
aim = import_aim()
if not self._run:
if self._run_hash:
self._run = aim.Run(
self._run_hash,
repo=self.repo,
system_tracking_interval=self.system_tracking_interval,
)
else:
self._run = aim.Run(
repo=self.repo,
experiment=self.experiment_name,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
)
self._run_hash = self._run.hash
if kwargs:
for key, value in kwargs.items():
self._run.set(key, value, strict=False)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
aim = import_aim()
self.step += 1
self.llm_starts += 1
self.starts += 1
resp = {"action": "on_llm_start"}
resp.update(self.get_custom_callback_meta())
prompts_res = deepcopy(prompts)
self._run.track(
[aim.Text(prompt) for prompt in prompts_res],
name="on_llm_start",
context=resp,
)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
aim = import_aim()
self.step += 1
self.llm_ends += 1
self.ends += 1
resp = {"action": "on_llm_end"}
resp.update(self.get_custom_callback_meta())
response_res = deepcopy(response)
generated = [
aim.Text(generation.text)
for generations in response_res.generations
for generation in generations
]
self._run.track(
generated,
name="on_llm_end",
context=resp,
)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.step += 1
self.llm_streams += 1
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
aim = import_aim()
self.step += 1
self.chain_starts += 1
self.starts += 1
resp = {"action": "on_chain_start"}
resp.update(self.get_custom_callback_meta())
inputs_res = deepcopy(inputs)
self._run.track(
aim.Text(inputs_res["input"]), name="on_chain_start", context=resp
)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
aim = import_aim()
self.step += 1
self.chain_ends += 1
self.ends += 1
resp = {"action": "on_chain_end"}
resp.update(self.get_custom_callback_meta())
outputs_res = deepcopy(outputs)
self._run.track(
aim.Text(outputs_res["output"]), name="on_chain_end", context=resp
)
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
aim = import_aim()
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = {"action": "on_tool_start"}
resp.update(self.get_custom_callback_meta())
self._run.track(aim.Text(input_str), name="on_tool_start", context=resp)
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
output = str(output)
aim = import_aim()
self.step += 1
self.tool_ends += 1
self.ends += 1
resp = {"action": "on_tool_end"}
resp.update(self.get_custom_callback_meta())
self._run.track(aim.Text(output), name="on_tool_end", context=resp)
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.step += 1
self.text_ctr += 1
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
aim = import_aim()
self.step += 1
self.agent_ends += 1
self.ends += 1
resp = {"action": "on_agent_finish"}
resp.update(self.get_custom_callback_meta())
finish_res = deepcopy(finish)
text = "OUTPUT:\n{}\n\nLOG:\n{}".format(
finish_res.return_values["output"], finish_res.log
)
self._run.track(aim.Text(text), name="on_agent_finish", context=resp)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
aim = import_aim()
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = {
"action": "on_agent_action",
"tool": action.tool,
}
resp.update(self.get_custom_callback_meta())
action_res = deepcopy(action)
text = "TOOL INPUT:\n{}\n\nLOG:\n{}".format(
action_res.tool_input, action_res.log
)
self._run.track(aim.Text(text), name="on_agent_action", context=resp)
def flush_tracker(
self,
repo: Optional[str] = None,
experiment_name: Optional[str] = None,
system_tracking_interval: Optional[int] = 10,
log_system_params: bool = True,
langchain_asset: Any = None,
reset: bool = True,
finish: bool = False,
) -> None:
"""Flush the tracker and reset the session.
Args:
repo (:obj:`str`, optional): Aim repository path or Repo object to which
Run object is bound. If skipped, default Repo is used.
experiment_name (:obj:`str`, optional): Sets Run's `experiment` property.
'default' if not specified. Can be used later to query runs/sequences.
system_tracking_interval (:obj:`int`, optional): Sets the tracking interval
in seconds for system usage metrics (CPU, Memory, etc.). Set to `None`
to disable system metrics tracking.
log_system_params (:obj:`bool`, optional): Enable/Disable logging of system
params such as installed packages, git info, environment variables, etc.
langchain_asset: The langchain asset to save.
reset: Whether to reset the session.
finish: Whether to finish the run.
Returns:
None
"""
if langchain_asset:
try:
for key, value in langchain_asset.dict().items():
self._run.set(key, value, strict=False)
except Exception:
pass
if finish or reset:
self._run.close()
self.reset_callback_meta()
if reset:
aim = import_aim()
self.repo = repo if repo else self.repo
self.experiment_name = (
experiment_name if experiment_name else self.experiment_name
)
self.system_tracking_interval = (
system_tracking_interval
if system_tracking_interval
else self.system_tracking_interval
)
self.log_system_params = (
log_system_params if log_system_params else self.log_system_params
)
self._run = aim.Run(
repo=self.repo,
experiment=self.experiment_name,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
)
self._run_hash = self._run.hash
self.action_records = []

View File

@@ -0,0 +1,349 @@
import os
import warnings
from typing import Any, Dict, List, Optional, cast
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from packaging.version import parse
class ArgillaCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs into Argilla.
Args:
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
exist in advance. If you need help on how to create a `FeedbackDataset` in
Argilla, please visit
https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html.
workspace_name: name of the workspace in Argilla where the specified
`FeedbackDataset` lives in. Defaults to `None`, which means that the
default workspace will be used.
api_url: URL of the Argilla Server that we want to use, and where the
`FeedbackDataset` lives in. Defaults to `None`, which means that either
`ARGILLA_API_URL` environment variable or the default will be used.
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
means that either `ARGILLA_API_KEY` environment variable or the default
will be used.
Raises:
ImportError: if the `argilla` package is not installed.
ConnectionError: if the connection to Argilla fails.
FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
Examples:
>>> from langchain_community.llms import OpenAI
>>> from langchain_community.callbacks import ArgillaCallbackHandler
>>> argilla_callback = ArgillaCallbackHandler(
... dataset_name="my-dataset",
... workspace_name="my-workspace",
... api_url="http://localhost:6900",
... api_key="argilla.apikey",
... )
>>> llm = OpenAI(
... temperature=0,
... callbacks=[argilla_callback],
... verbose=True,
... openai_api_key="API_KEY_HERE",
... )
>>> llm.generate([
... "What is the best NLP-annotation tool out there? (no bias at all)",
... ])
"Argilla, no doubt about it."
"""
REPO_URL: str = "https://github.com/argilla-io/argilla"
ISSUES_URL: str = f"{REPO_URL}/issues"
BLOG_URL: str = "https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html"
DEFAULT_API_URL: str = "http://localhost:6900"
def __init__(
self,
dataset_name: str,
workspace_name: Optional[str] = None,
api_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> None:
"""Initializes the `ArgillaCallbackHandler`.
Args:
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
exist in advance. If you need help on how to create a `FeedbackDataset`
in Argilla, please visit
https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html.
workspace_name: name of the workspace in Argilla where the specified
`FeedbackDataset` lives in. Defaults to `None`, which means that the
default workspace will be used.
api_url: URL of the Argilla Server that we want to use, and where the
`FeedbackDataset` lives in. Defaults to `None`, which means that either
`ARGILLA_API_URL` environment variable or the default will be used.
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
means that either `ARGILLA_API_KEY` environment variable or the default
will be used.
Raises:
ImportError: if the `argilla` package is not installed.
ConnectionError: if the connection to Argilla fails.
FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
"""
super().__init__()
# Import Argilla (not via `import_argilla` to keep hints in IDEs)
try:
import argilla as rg
self.ARGILLA_VERSION = rg.__version__
except ImportError:
raise ImportError(
"To use the Argilla callback manager you need to have the `argilla` "
"Python package installed. Please install it with `pip install argilla`"
)
# Check whether the Argilla version is compatible
if parse(self.ARGILLA_VERSION) < parse("1.8.0"):
raise ImportError(
f"The installed `argilla` version is {self.ARGILLA_VERSION} but "
"`ArgillaCallbackHandler` requires at least version 1.8.0. Please "
"upgrade `argilla` with `pip install --upgrade argilla`."
)
# Show a warning message if Argilla will assume the default values will be used
if api_url is None and os.getenv("ARGILLA_API_URL") is None:
warnings.warn(
(
"Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
f" set, it will default to `{self.DEFAULT_API_URL}`, which is the"
" default API URL in Argilla Quickstart."
),
)
api_url = self.DEFAULT_API_URL
if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
self.DEFAULT_API_KEY = (
"admin.apikey"
if parse(self.ARGILLA_VERSION) < parse("1.11.0")
else "owner.apikey"
)
warnings.warn(
(
"Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
f" set, it will default to `{self.DEFAULT_API_KEY}`, which is the"
" default API key in Argilla Quickstart."
),
)
api_key = self.DEFAULT_API_KEY
# Connect to Argilla with the provided credentials, if applicable
try:
rg.init(api_key=api_key, api_url=api_url)
except Exception as e:
raise ConnectionError(
f"Could not connect to Argilla with exception: '{e}'.\n"
"Please check your `api_key` and `api_url`, and make sure that "
"the Argilla server is up and running. If the problem persists "
f"please report it to {self.ISSUES_URL} as an `integration` issue."
) from e
# Set the Argilla variables
self.dataset_name = dataset_name
self.workspace_name = workspace_name or rg.get_workspace()
# Retrieve the `FeedbackDataset` from Argilla (without existing records)
try:
extra_args = {}
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
warnings.warn(
f"You have Argilla {self.ARGILLA_VERSION}, but Argilla 1.14.0 or"
" higher is recommended.",
UserWarning,
)
extra_args = {"with_records": False}
self.dataset = rg.FeedbackDataset.from_argilla(
name=self.dataset_name,
workspace=self.workspace_name,
**extra_args,
)
except Exception as e:
raise FileNotFoundError(
f"`FeedbackDataset` retrieval from Argilla failed with exception `{e}`."
f"\nPlease check that the dataset with name={self.dataset_name} in the"
f" workspace={self.workspace_name} exists in advance. If you need help"
" on how to create a `langchain`-compatible `FeedbackDataset` in"
f" Argilla, please visit {self.BLOG_URL}. If the problem persists"
f" please report it to {self.ISSUES_URL} as an `integration` issue."
) from e
supported_fields = ["prompt", "response"]
if supported_fields != [field.name for field in self.dataset.fields]:
raise ValueError(
f"`FeedbackDataset` with name={self.dataset_name} in the workspace="
f"{self.workspace_name} had fields that are not supported yet for the"
f"`langchain` integration. Supported fields are: {supported_fields},"
f" and the current `FeedbackDataset` fields are {[field.name for field in self.dataset.fields]}." # noqa: E501
" For more information on how to create a `langchain`-compatible"
f" `FeedbackDataset` in Argilla, please visit {self.BLOG_URL}."
)
self.prompts: Dict[str, List[str]] = {}
warnings.warn(
(
"The `ArgillaCallbackHandler` is currently in beta and is subject to"
" change based on updates to `langchain`. Please report any issues to"
f" {self.ISSUES_URL} as an `integration` issue."
),
)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Save the prompts in memory when an LLM starts."""
self.prompts.update({str(kwargs["parent_run_id"] or kwargs["run_id"]): prompts})
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing when a new token is generated."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Log records to Argilla when an LLM ends."""
# Do nothing if there's a parent_run_id, since we will log the records when
# the chain ends
if kwargs["parent_run_id"]:
return
# Creates the records and adds them to the `FeedbackDataset`
prompts = self.prompts[str(kwargs["run_id"])]
for prompt, generations in zip(prompts, response.generations):
self.dataset.add_records(
records=[
{
"fields": {
"prompt": prompt,
"response": generation.text.strip(),
},
}
for generation in generations
]
)
# Pop current run from `self.runs`
self.prompts.pop(str(kwargs["run_id"]))
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
# Push the records to Argilla
self.dataset.push_to_argilla()
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM outputs an error."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""If the key `input` is in `inputs`, then save it in `self.prompts` using
either the `parent_run_id` or the `run_id` as the key. This is done so that
we don't log the same input prompt twice, once when the LLM starts and once
when the chain starts.
"""
if "input" in inputs:
self.prompts.update(
{
str(kwargs["parent_run_id"] or kwargs["run_id"]): (
inputs["input"]
if isinstance(inputs["input"], list)
else [inputs["input"]]
)
}
)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""If either the `parent_run_id` or the `run_id` is in `self.prompts`, then
log the outputs to Argilla, and pop the run from `self.prompts`. The behavior
differs if the output is a list or not.
"""
if not any(
key in self.prompts
for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
):
return
prompts: List = self.prompts.get(str(kwargs["parent_run_id"])) or cast(
List, self.prompts.get(str(kwargs["run_id"]), [])
)
for chain_output_key, chain_output_val in outputs.items():
if isinstance(chain_output_val, list):
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
{
"fields": {
"prompt": prompt,
"response": output["text"].strip(),
},
}
for prompt, output in zip(prompts, chain_output_val)
]
)
else:
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
{
"fields": {
"prompt": " ".join(prompts),
"response": chain_output_val.strip(),
},
}
]
)
# Pop current run from `self.runs`
if str(kwargs["parent_run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["parent_run_id"]))
if str(kwargs["run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["run_id"]))
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
# Push the records to Argilla
self.dataset.push_to_argilla()
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM chain outputs an error."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing when tool starts."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing when agent takes a specific action."""
pass
def on_tool_end(
self,
output: Any,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing when tool ends."""
pass
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing"""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Do nothing"""
pass

View File

@@ -0,0 +1,213 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_community.callbacks.utils import import_pandas
class ArizeCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs to Arize."""
def __init__(
self,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
SPACE_KEY: Optional[str] = None,
API_KEY: Optional[str] = None,
) -> None:
"""Initialize callback handler."""
super().__init__()
self.model_id = model_id
self.model_version = model_version
self.space_key = SPACE_KEY
self.api_key = API_KEY
self.prompt_records: List[str] = []
self.response_records: List[str] = []
self.prediction_ids: List[str] = []
self.pred_timestamps: List[int] = []
self.response_embeddings: List[float] = []
self.prompt_embeddings: List[float] = []
self.prompt_tokens = 0
self.completion_tokens = 0
self.total_tokens = 0
self.step = 0
from arize.pandas.embeddings import EmbeddingGenerator, UseCases
from arize.pandas.logger import Client
self.generator = EmbeddingGenerator.from_use_case(
use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
model_name="distilbert-base-uncased",
tokenizer_max_length=512,
batch_size=256,
)
self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
raise ValueError("❌ CHANGE SPACE AND API KEYS")
else:
print("✅ Arize client setup done! Now you can start using Arize!") # noqa: T201
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
for prompt in prompts:
self.prompt_records.append(prompt.replace("\n", ""))
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pd = import_pandas()
from arize.utils.types import (
EmbeddingColumnNames,
Environments,
ModelTypes,
Schema,
)
# Safe check if 'llm_output' and 'token_usage' exist
if response.llm_output and "token_usage" in response.llm_output:
self.prompt_tokens = response.llm_output["token_usage"].get(
"prompt_tokens", 0
)
self.total_tokens = response.llm_output["token_usage"].get(
"total_tokens", 0
)
self.completion_tokens = response.llm_output["token_usage"].get(
"completion_tokens", 0
)
else:
self.prompt_tokens = self.total_tokens = self.completion_tokens = (
0 # assign default value
)
for generations in response.generations:
for generation in generations:
prompt = self.prompt_records[self.step]
self.step = self.step + 1
prompt_embedding = pd.Series(
self.generator.generate_embeddings(
text_col=pd.Series(prompt.replace("\n", " "))
).reset_index(drop=True)
)
# Assigning text to response_text instead of response
response_text = generation.text.replace("\n", " ")
response_embedding = pd.Series(
self.generator.generate_embeddings(
text_col=pd.Series(generation.text.replace("\n", " "))
).reset_index(drop=True)
)
pred_timestamp = datetime.now().timestamp()
# Define the columns and data
columns = [
"prediction_ts",
"response",
"prompt",
"response_vector",
"prompt_vector",
"prompt_token",
"completion_token",
"total_token",
]
data = [
[
pred_timestamp,
response_text,
prompt,
response_embedding[0],
prompt_embedding[0],
self.prompt_tokens,
self.total_tokens,
self.completion_tokens,
]
]
# Create the DataFrame
df = pd.DataFrame(data, columns=columns)
# Declare prompt and response columns
prompt_columns = EmbeddingColumnNames(
vector_column_name="prompt_vector", data_column_name="prompt"
)
response_columns = EmbeddingColumnNames(
vector_column_name="response_vector", data_column_name="response"
)
schema = Schema(
timestamp_column_name="prediction_ts",
tag_column_names=[
"prompt_token",
"completion_token",
"total_token",
],
prompt_column_names=prompt_columns,
response_column_names=response_columns,
)
response_from_arize = self.arize_client.log(
dataframe=df,
schema=schema,
model_id=self.model_id,
model_version=self.model_version,
model_type=ModelTypes.GENERATIVE_LLM,
environment=Environments.PRODUCTION,
)
if response_from_arize.status_code == 200:
print("✅ Successfully logged data to Arize!") # noqa: T201
else:
print(f'❌ Logging failed "{response_from_arize.text}"') # noqa: T201
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing."""
pass
def on_tool_end(
self,
output: Any,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
pass
def on_text(self, text: str, **kwargs: Any) -> None:
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
pass

View File

@@ -0,0 +1,297 @@
"""ArthurAI's Callback Handler."""
from __future__ import annotations
import os
import uuid
from collections import defaultdict
from datetime import datetime
from time import time
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional
import numpy as np
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
if TYPE_CHECKING:
import arthurai
from arthurai.core.models import ArthurModel
PROMPT_TOKENS = "prompt_tokens"
COMPLETION_TOKENS = "completion_tokens"
TOKEN_USAGE = "token_usage"
FINISH_REASON = "finish_reason"
DURATION = "duration"
def _lazy_load_arthur() -> arthurai:
"""Lazy load Arthur."""
try:
import arthurai
except ImportError as e:
raise ImportError(
"To use the ArthurCallbackHandler you need the"
" `arthurai` package. Please install it with"
" `pip install arthurai`.",
e,
)
return arthurai
class ArthurCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs to Arthur platform.
Arthur helps enterprise teams optimize model operations
and performance at scale. The Arthur API tracks model
performance, explainability, and fairness across tabular,
NLP, and CV models. Our API is model- and platform-agnostic,
and continuously scales with complex and dynamic enterprise needs.
To learn more about Arthur, visit our website at
https://www.arthur.ai/ or read the Arthur docs at
https://docs.arthur.ai/
"""
def __init__(
self,
arthur_model: ArthurModel,
) -> None:
"""Initialize callback handler."""
super().__init__()
arthurai = _lazy_load_arthur()
Stage = arthurai.common.constants.Stage
ValueType = arthurai.common.constants.ValueType
self.arthur_model = arthur_model
# save the attributes of this model to be used when preparing
# inferences to log to Arthur in on_llm_end()
self.attr_names = set([a.name for a in self.arthur_model.get_attributes()])
self.input_attr = [
x
for x in self.arthur_model.get_attributes()
if x.stage == Stage.ModelPipelineInput
and x.value_type == ValueType.Unstructured_Text
][0].name
self.output_attr = [
x
for x in self.arthur_model.get_attributes()
if x.stage == Stage.PredictedValue
and x.value_type == ValueType.Unstructured_Text
][0].name
self.token_likelihood_attr = None
if (
len(
[
x
for x in self.arthur_model.get_attributes()
if x.value_type == ValueType.TokenLikelihoods
]
)
> 0
):
self.token_likelihood_attr = [
x
for x in self.arthur_model.get_attributes()
if x.value_type == ValueType.TokenLikelihoods
][0].name
self.run_map: DefaultDict[str, Any] = defaultdict(dict)
@classmethod
def from_credentials(
cls,
model_id: str,
arthur_url: Optional[str] = "https://app.arthur.ai",
arthur_login: Optional[str] = None,
arthur_password: Optional[str] = None,
) -> ArthurCallbackHandler:
"""Initialize callback handler from Arthur credentials.
Args:
model_id (str): The ID of the arthur model to log to.
arthur_url (str, optional): The URL of the Arthur instance to log to.
Defaults to "https://app.arthur.ai".
arthur_login (str, optional): The login to use to connect to Arthur.
Defaults to None.
arthur_password (str, optional): The password to use to connect to
Arthur. Defaults to None.
Returns:
ArthurCallbackHandler: The initialized callback handler.
"""
arthurai = _lazy_load_arthur()
ArthurAI = arthurai.ArthurAI
ResponseClientError = arthurai.common.exceptions.ResponseClientError
# connect to Arthur
if arthur_login is None:
try:
arthur_api_key = os.environ["ARTHUR_API_KEY"]
except KeyError:
raise ValueError(
"No Arthur authentication provided. Either give"
" a login to the ArthurCallbackHandler"
" or set an ARTHUR_API_KEY as an environment variable."
)
arthur = ArthurAI(url=arthur_url, access_key=arthur_api_key)
else:
if arthur_password is None:
arthur = ArthurAI(url=arthur_url, login=arthur_login)
else:
arthur = ArthurAI(
url=arthur_url, login=arthur_login, password=arthur_password
)
# get model from Arthur by the provided model ID
try:
arthur_model = arthur.get_model(model_id)
except ResponseClientError:
raise ValueError(
f"Was unable to retrieve model with id {model_id} from Arthur."
" Make sure the ID corresponds to a model that is currently"
" registered with your Arthur account."
)
return cls(arthur_model)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""On LLM start, save the input prompts"""
run_id = kwargs["run_id"]
self.run_map[run_id]["input_texts"] = prompts
self.run_map[run_id]["start_time"] = time()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""On LLM end, send data to Arthur."""
try:
import pytz
except ImportError as e:
raise ImportError(
"Could not import pytz. Please install it with 'pip install pytz'."
) from e
run_id = kwargs["run_id"]
# get the run params from this run ID,
# or raise an error if this run ID has no corresponding metadata in self.run_map
try:
run_map_data = self.run_map[run_id]
except KeyError as e:
raise KeyError(
"This function has been called with a run_id"
" that was never registered in on_llm_start()."
" Restart and try running the LLM again"
) from e
# mark the duration time between on_llm_start() and on_llm_end()
time_from_start_to_end = time() - run_map_data["start_time"]
# create inferences to log to Arthur
inferences = []
for i, generations in enumerate(response.generations):
for generation in generations:
inference = {
"partner_inference_id": str(uuid.uuid4()),
"inference_timestamp": datetime.now(tz=pytz.UTC),
self.input_attr: run_map_data["input_texts"][i],
self.output_attr: generation.text,
}
if generation.generation_info is not None:
# add finish reason to the inference
# if generation info contains a finish reason and
# if the ArthurModel was registered to monitor finish_reason
if (
FINISH_REASON in generation.generation_info
and FINISH_REASON in self.attr_names
):
inference[FINISH_REASON] = generation.generation_info[
FINISH_REASON
]
# add token likelihoods data to the inference if the ArthurModel
# was registered to monitor token likelihoods
logprobs_data = generation.generation_info["logprobs"]
if (
logprobs_data is not None
and self.token_likelihood_attr is not None
):
logprobs = logprobs_data["top_logprobs"]
likelihoods = [
{k: np.exp(v) for k, v in logprobs[i].items()}
for i in range(len(logprobs))
]
inference[self.token_likelihood_attr] = likelihoods
# add token usage counts to the inference if the
# ArthurModel was registered to monitor token usage
if (
isinstance(response.llm_output, dict)
and TOKEN_USAGE in response.llm_output
):
token_usage = response.llm_output[TOKEN_USAGE]
if (
PROMPT_TOKENS in token_usage
and PROMPT_TOKENS in self.attr_names
):
inference[PROMPT_TOKENS] = token_usage[PROMPT_TOKENS]
if (
COMPLETION_TOKENS in token_usage
and COMPLETION_TOKENS in self.attr_names
):
inference[COMPLETION_TOKENS] = token_usage[COMPLETION_TOKENS]
# add inference duration to the inference if the ArthurModel
# was registered to monitor inference duration
if DURATION in self.attr_names:
inference[DURATION] = time_from_start_to_end
inferences.append(inference)
# send inferences to arthur
self.arthur_model.send_inferences(inferences)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""On chain start, do nothing."""
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""On chain end, do nothing."""
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM outputs an error."""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""On new token, pass."""
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM chain outputs an error."""
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing when tool starts."""
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing when agent takes a specific action."""
def on_tool_end(
self,
output: Any,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing when tool ends."""
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing"""
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Do nothing"""

View File

@@ -0,0 +1,135 @@
import threading
from typing import Any, Dict, List, Union
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
MODEL_COST_PER_1K_INPUT_TOKENS = {
"anthropic.claude-instant-v1": 0.0008,
"anthropic.claude-v2": 0.008,
"anthropic.claude-v2:1": 0.008,
"anthropic.claude-3-sonnet-20240229-v1:0": 0.003,
"anthropic.claude-3-5-sonnet-20240620-v1:0": 0.003,
"anthropic.claude-3-5-sonnet-20241022-v2:0": 0.003,
"anthropic.claude-3-7-sonnet-20250219-v1:0": 0.003,
"anthropic.claude-sonnet-4-20250514-v1:0": 0.003,
"anthropic.claude-3-haiku-20240307-v1:0": 0.00025,
"anthropic.claude-3-opus-20240229-v1:0": 0.015,
"anthropic.claude-opus-4-20250514-v1:0": 0.015,
"anthropic.claude-3-5-haiku-20241022-v1:0": 0.0008,
}
MODEL_COST_PER_1K_OUTPUT_TOKENS = {
"anthropic.claude-instant-v1": 0.0024,
"anthropic.claude-v2": 0.024,
"anthropic.claude-v2:1": 0.024,
"anthropic.claude-3-sonnet-20240229-v1:0": 0.015,
"anthropic.claude-3-5-sonnet-20240620-v1:0": 0.015,
"anthropic.claude-3-5-sonnet-20241022-v2:0": 0.015,
"anthropic.claude-3-7-sonnet-20250219-v1:0": 0.015,
"anthropic.claude-sonnet-4-20250514-v1:0": 0.015,
"anthropic.claude-3-haiku-20240307-v1:0": 0.00125,
"anthropic.claude-3-opus-20240229-v1:0": 0.075,
"anthropic.claude-opus-4-20250514-v1:0": 0.075,
"anthropic.claude-3-5-haiku-20241022-v1:0": 0.004,
}
def _get_anthropic_claude_token_cost(
prompt_tokens: int, completion_tokens: int, model_id: Union[str, None]
) -> float:
if model_id:
# The model ID can be a cross-region (system-defined) inference profile ID,
# which has a prefix indicating the region (e.g., 'us', 'eu') but
# shares the same token costs as the "base model".
# By extracting the "base model ID", by taking the last two segments
# of the model ID, we can map cross-region inference profile IDs to
# their corresponding cost entries.
base_model_id = model_id.split(".")[-2] + "." + model_id.split(".")[-1]
else:
base_model_id = None
"""Get the cost of tokens for the Claude model."""
if base_model_id not in MODEL_COST_PER_1K_INPUT_TOKENS:
raise ValueError(
f"Unknown model: {model_id}. Please provide a valid Anthropic model name."
"Known models are: " + ", ".join(MODEL_COST_PER_1K_INPUT_TOKENS.keys())
)
return (prompt_tokens / 1000) * MODEL_COST_PER_1K_INPUT_TOKENS[base_model_id] + (
completion_tokens / 1000
) * MODEL_COST_PER_1K_OUTPUT_TOKENS[base_model_id]
class BedrockAnthropicTokenUsageCallbackHandler(BaseCallbackHandler):
"""Callback Handler that tracks bedrock anthropic info."""
total_tokens: int = 0
prompt_tokens: int = 0
completion_tokens: int = 0
successful_requests: int = 0
total_cost: float = 0.0
def __init__(self) -> None:
super().__init__()
self._lock = threading.Lock()
def __repr__(self) -> str:
return (
f"Tokens Used: {self.total_tokens}\n"
f"\tPrompt Tokens: {self.prompt_tokens}\n"
f"\tCompletion Tokens: {self.completion_tokens}\n"
f"Successful Requests: {self.successful_requests}\n"
f"Total Cost (USD): ${self.total_cost}"
)
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Print out the token."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collect token usage."""
if response.llm_output is None:
return None
if "usage" not in response.llm_output:
with self._lock:
self.successful_requests += 1
return None
# compute tokens and cost for this request
token_usage = response.llm_output["usage"]
completion_tokens = token_usage.get("completion_tokens", 0)
prompt_tokens = token_usage.get("prompt_tokens", 0)
total_tokens = token_usage.get("total_tokens", 0)
model_id = response.llm_output.get("model_id", None)
total_cost = _get_anthropic_claude_token_cost(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
model_id=model_id,
)
# update shared state behind lock
with self._lock:
self.total_cost += total_cost
self.total_tokens += total_tokens
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
self.successful_requests += 1
def __copy__(self) -> "BedrockAnthropicTokenUsageCallbackHandler":
"""Return a copy of the callback handler."""
return self
def __deepcopy__(self, memo: Any) -> "BedrockAnthropicTokenUsageCallbackHandler":
"""Return a deep copy of the callback handler."""
return self

View File

@@ -0,0 +1,518 @@
from __future__ import annotations
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import
from langchain_community.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
hash_string,
import_pandas,
import_spacy,
import_textstat,
load_json,
)
if TYPE_CHECKING:
import pandas as pd
def import_clearml() -> Any:
"""Import the clearml python package and raise an error if it is not installed."""
return guard_import("clearml")
class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""Callback Handler that logs to ClearML.
Parameters:
job_type (str): The type of clearml task such as "inference", "testing" or "qc"
project_name (str): The clearml project name
tags (list): Tags to add to the task
task_name (str): Name of the clearml task
visualize (bool): Whether to visualize the run.
complexity_metrics (bool): Whether to log complexity metrics
stream_logs (bool): Whether to stream callback actions to ClearML
This handler will utilize the associated callback method and formats
the input of each callback function with metadata regarding the state of LLM run,
and adds the response to the list of records for both the {method}_records and
action. It then logs the response to the ClearML console.
"""
def __init__(
self,
task_type: Optional[str] = "inference",
project_name: Optional[str] = "langchain_callback_demo",
tags: Optional[Sequence] = None,
task_name: Optional[str] = None,
visualize: bool = False,
complexity_metrics: bool = False,
stream_logs: bool = False,
) -> None:
"""Initialize callback handler."""
clearml = import_clearml()
spacy = import_spacy()
super().__init__()
self.task_type = task_type
self.project_name = project_name
self.tags = tags
self.task_name = task_name
self.visualize = visualize
self.complexity_metrics = complexity_metrics
self.stream_logs = stream_logs
self.temp_dir = tempfile.TemporaryDirectory()
# Check if ClearML task already exists (e.g. in pipeline)
if clearml.Task.current_task():
self.task = clearml.Task.current_task()
else:
self.task = clearml.Task.init(
task_type=self.task_type,
project_name=self.project_name,
tags=self.tags,
task_name=self.task_name,
output_uri=True,
)
self.logger = self.task.get_logger()
warning = (
"The clearml callback is currently in beta and is subject to change "
"based on updates to `langchain`. Please report any issues to "
"https://github.com/allegroai/clearml/issues with the tag `langchain`."
)
self.logger.report_text(warning, level=30, print_console=True)
self.callback_columns: list = []
self.action_records: list = []
self.complexity_metrics = complexity_metrics
self.visualize = visualize
self.nlp = spacy.load("en_core_web_sm")
def _init_resp(self) -> Dict:
return {k: None for k in self.callback_columns}
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
self.step += 1
self.llm_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_llm_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
for prompt in prompts:
prompt_resp = deepcopy(resp)
prompt_resp["prompts"] = prompt
self.on_llm_start_records.append(prompt_resp)
self.action_records.append(prompt_resp)
if self.stream_logs:
self.logger.report_text(prompt_resp)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.step += 1
self.llm_streams += 1
resp = self._init_resp()
resp.update({"action": "on_llm_new_token", "token": token})
resp.update(self.get_custom_callback_meta())
self.on_llm_token_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.step += 1
self.llm_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_llm_end"})
resp.update(flatten_dict(response.llm_output or {}))
resp.update(self.get_custom_callback_meta())
for generations in response.generations:
for generation in generations:
generation_resp = deepcopy(resp)
generation_resp.update(flatten_dict(generation.dict()))
generation_resp.update(self.analyze_text(generation.text))
self.on_llm_end_records.append(generation_resp)
self.action_records.append(generation_resp)
if self.stream_logs:
self.logger.report_text(generation_resp)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.step += 1
self.chain_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
chain_input = inputs.get("input", inputs.get("human_input"))
if isinstance(chain_input, str):
input_resp = deepcopy(resp)
input_resp["input"] = chain_input
self.on_chain_start_records.append(input_resp)
self.action_records.append(input_resp)
if self.stream_logs:
self.logger.report_text(input_resp)
elif isinstance(chain_input, list):
for inp in chain_input:
input_resp = deepcopy(resp)
input_resp.update(inp)
self.on_chain_start_records.append(input_resp)
self.action_records.append(input_resp)
if self.stream_logs:
self.logger.report_text(input_resp)
else:
raise ValueError("Unexpected data format provided!")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.step += 1
self.chain_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update(
{
"action": "on_chain_end",
"outputs": outputs.get("output", outputs.get("text")),
}
)
resp.update(self.get_custom_callback_meta())
self.on_chain_end_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_tool_start", "input_str": input_str})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
self.on_tool_start_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
output = str(output)
self.step += 1
self.tool_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_tool_end", "output": output})
resp.update(self.get_custom_callback_meta())
self.on_tool_end_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.step += 1
self.text_ctr += 1
resp = self._init_resp()
resp.update({"action": "on_text", "text": text})
resp.update(self.get_custom_callback_meta())
self.on_text_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.step += 1
self.agent_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update(
{
"action": "on_agent_finish",
"output": finish.return_values["output"],
"log": finish.log,
}
)
resp.update(self.get_custom_callback_meta())
self.on_agent_finish_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update(
{
"action": "on_agent_action",
"tool": action.tool,
"tool_input": action.tool_input,
"log": action.log,
}
)
resp.update(self.get_custom_callback_meta())
self.on_agent_action_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def analyze_text(self, text: str) -> dict:
"""Analyze text using textstat and spacy.
Parameters:
text (str): The text to analyze.
Returns:
`dict` containing the complexity metrics.
"""
resp = {}
textstat = import_textstat()
spacy = import_spacy()
if self.complexity_metrics:
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
"smog_index": textstat.smog_index(text),
"coleman_liau_index": textstat.coleman_liau_index(text),
"automated_readability_index": textstat.automated_readability_index(
text
),
"dale_chall_readability_score": textstat.dale_chall_readability_score(
text
),
"difficult_words": textstat.difficult_words(text),
"linsear_write_formula": textstat.linsear_write_formula(text),
"gunning_fog": textstat.gunning_fog(text),
"text_standard": textstat.text_standard(text),
"fernandez_huerta": textstat.fernandez_huerta(text),
"szigriszt_pazos": textstat.szigriszt_pazos(text),
"gutierrez_polini": textstat.gutierrez_polini(text),
"crawford": textstat.crawford(text),
"gulpease_index": textstat.gulpease_index(text),
"osman": textstat.osman(text),
}
resp.update(text_complexity_metrics)
if self.visualize and self.nlp and self.temp_dir.name is not None:
doc = self.nlp(text)
dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
dep_output_path = Path(
self.temp_dir.name, hash_string(f"dep-{text}") + ".html"
)
dep_output_path.open("w", encoding="utf-8").write(dep_out)
ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
ent_output_path = Path(
self.temp_dir.name, hash_string(f"ent-{text}") + ".html"
)
ent_output_path.open("w", encoding="utf-8").write(ent_out)
self.logger.report_media(
"Dependencies Plot", text, local_path=dep_output_path
)
self.logger.report_media("Entities Plot", text, local_path=ent_output_path)
return resp
@staticmethod
def _build_llm_df(
base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping
) -> pd.DataFrame:
base_df_fields = [field for field in base_df_fields if field in base_df]
rename_map = {
map_entry_k: map_entry_v
for map_entry_k, map_entry_v in rename_map.items()
if map_entry_k in base_df_fields
}
llm_df = base_df[base_df_fields].dropna(axis=1)
if rename_map:
llm_df = llm_df.rename(rename_map, axis=1)
return llm_df
def _create_session_analysis_df(self) -> Any:
"""Create a dataframe with all the information from the session."""
pd = import_pandas()
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
base_df=on_llm_end_records_df,
base_df_fields=["step", "prompts"]
+ (["name"] if "name" in on_llm_end_records_df else ["id"]),
rename_map={"step": "prompt_step"},
)
complexity_metrics_columns = []
visualizations_columns: List = []
if self.complexity_metrics:
complexity_metrics_columns = [
"flesch_reading_ease",
"flesch_kincaid_grade",
"smog_index",
"coleman_liau_index",
"automated_readability_index",
"dale_chall_readability_score",
"difficult_words",
"linsear_write_formula",
"gunning_fog",
"text_standard",
"fernandez_huerta",
"szigriszt_pazos",
"gutierrez_polini",
"crawford",
"gulpease_index",
"osman",
]
llm_outputs_df = ClearMLCallbackHandler._build_llm_df(
on_llm_end_records_df,
[
"step",
"text",
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
]
+ complexity_metrics_columns
+ visualizations_columns,
{"step": "output_step", "text": "output"},
)
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
return session_analysis_df
def flush_tracker(
self,
name: Optional[str] = None,
langchain_asset: Any = None,
finish: bool = False,
) -> None:
"""Flush the tracker and setup the session.
Everything after this will be a new table.
Args:
name: Name of the performed session so far so it is identifiable
langchain_asset: The langchain asset to save.
finish: Whether to finish the run.
Returns:
None
"""
pd = import_pandas()
clearml = import_clearml()
# Log the action records
self.logger.report_table(
"Action Records", name, table_plot=pd.DataFrame(self.action_records)
)
# Session analysis
session_analysis_df = self._create_session_analysis_df()
self.logger.report_table(
"Session Analysis", name, table_plot=session_analysis_df
)
if self.stream_logs:
self.logger.report_text(
{
"action_records": pd.DataFrame(self.action_records),
"session_analysis": session_analysis_df,
}
)
if langchain_asset:
langchain_asset_path = Path(self.temp_dir.name, "model.json")
try:
langchain_asset.save(langchain_asset_path)
# Create output model and connect it to the task
output_model = clearml.OutputModel(
task=self.task, config_text=load_json(langchain_asset_path)
)
output_model.update_weights(
weights_filename=str(langchain_asset_path),
auto_delete_file=False,
target_filename=name,
)
except ValueError:
langchain_asset.save_agent(langchain_asset_path)
output_model = clearml.OutputModel(
task=self.task, config_text=load_json(langchain_asset_path)
)
output_model.update_weights(
weights_filename=str(langchain_asset_path),
auto_delete_file=False,
target_filename=name,
)
except NotImplementedError as e:
print("Could not save model.") # noqa: T201
print(repr(e)) # noqa: T201
pass
# Cleanup after adding everything to ClearML
self.task.flush(wait_for_uploads=True)
self.temp_dir.cleanup()
self.temp_dir = tempfile.TemporaryDirectory()
self.reset_callback_meta()
if finish:
self.task.close()

View File

@@ -0,0 +1,639 @@
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import Generation, LLMResult
from langchain_core.utils import guard_import
import langchain_community
from langchain_community.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
import_pandas,
import_spacy,
import_textstat,
)
LANGCHAIN_MODEL_NAME = "langchain-model"
def import_comet_ml() -> Any:
"""Import comet_ml and raise an error if it is not installed."""
return guard_import("comet_ml")
def _get_experiment(
workspace: Optional[str] = None, project_name: Optional[str] = None
) -> Any:
comet_ml = import_comet_ml()
experiment = comet_ml.Experiment(
workspace=workspace,
project_name=project_name,
)
return experiment
def _fetch_text_complexity_metrics(text: str) -> dict:
textstat = import_textstat()
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
"smog_index": textstat.smog_index(text),
"coleman_liau_index": textstat.coleman_liau_index(text),
"automated_readability_index": textstat.automated_readability_index(text),
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
"difficult_words": textstat.difficult_words(text),
"linsear_write_formula": textstat.linsear_write_formula(text),
"gunning_fog": textstat.gunning_fog(text),
"text_standard": textstat.text_standard(text),
"fernandez_huerta": textstat.fernandez_huerta(text),
"szigriszt_pazos": textstat.szigriszt_pazos(text),
"gutierrez_polini": textstat.gutierrez_polini(text),
"crawford": textstat.crawford(text),
"gulpease_index": textstat.gulpease_index(text),
"osman": textstat.osman(text),
}
return text_complexity_metrics
def _summarize_metrics_for_generated_outputs(metrics: Sequence) -> dict:
pd = import_pandas()
metrics_df = pd.DataFrame(metrics)
metrics_summary = metrics_df.describe()
return metrics_summary.to_dict()
class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""Callback Handler that logs to Comet.
Parameters:
job_type (str): The type of comet_ml task such as "inference",
"testing" or "qc"
project_name (str): The comet_ml project name
tags (list): Tags to add to the task
task_name (str): Name of the comet_ml task
visualize (bool): Whether to visualize the run.
complexity_metrics (bool): Whether to log complexity metrics
stream_logs (bool): Whether to stream callback actions to Comet
This handler will utilize the associated callback method and formats
the input of each callback function with metadata regarding the state of LLM run,
and adds the response to the list of records for both the {method}_records and
action. It then logs the response to Comet.
"""
def __init__(
self,
task_type: Optional[str] = "inference",
workspace: Optional[str] = None,
project_name: Optional[str] = None,
tags: Optional[Sequence] = None,
name: Optional[str] = None,
visualizations: Optional[List[str]] = None,
complexity_metrics: bool = False,
custom_metrics: Optional[Callable] = None,
stream_logs: bool = True,
) -> None:
"""Initialize callback handler."""
self.comet_ml = import_comet_ml()
super().__init__()
self.task_type = task_type
self.workspace = workspace
self.project_name = project_name
self.tags = tags
self.visualizations = visualizations
self.complexity_metrics = complexity_metrics
self.custom_metrics = custom_metrics
self.stream_logs = stream_logs
self.temp_dir = tempfile.TemporaryDirectory()
self.experiment = _get_experiment(workspace, project_name)
self.experiment.log_other("Created from", "langchain")
if tags:
self.experiment.add_tags(tags)
self.name = name
if self.name:
self.experiment.set_name(self.name)
warning = (
"The comet_ml callback is currently in beta and is subject to change "
"based on updates to `langchain`. Please report any issues to "
"https://github.com/comet-ml/issue-tracking/issues with the tag "
"`langchain`."
)
self.comet_ml.LOGGER.warning(warning)
self.callback_columns: list = []
self.action_records: list = []
self.complexity_metrics = complexity_metrics
if self.visualizations:
spacy = import_spacy()
self.nlp = spacy.load("en_core_web_sm")
else:
self.nlp = None
def _init_resp(self) -> Dict:
return {k: None for k in self.callback_columns}
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
self.step += 1
self.llm_starts += 1
self.starts += 1
metadata = self._init_resp()
metadata.update({"action": "on_llm_start"})
metadata.update(flatten_dict(serialized))
metadata.update(self.get_custom_callback_meta())
for prompt in prompts:
prompt_resp = deepcopy(metadata)
prompt_resp["prompts"] = prompt
self.on_llm_start_records.append(prompt_resp)
self.action_records.append(prompt_resp)
if self.stream_logs:
self._log_stream(prompt, metadata, self.step)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.step += 1
self.llm_streams += 1
resp = self._init_resp()
resp.update({"action": "on_llm_new_token", "token": token})
resp.update(self.get_custom_callback_meta())
self.action_records.append(resp)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.step += 1
self.llm_ends += 1
self.ends += 1
metadata = self._init_resp()
metadata.update({"action": "on_llm_end"})
metadata.update(flatten_dict(response.llm_output or {}))
metadata.update(self.get_custom_callback_meta())
output_complexity_metrics = []
output_custom_metrics = []
for prompt_idx, generations in enumerate(response.generations):
for gen_idx, generation in enumerate(generations):
text = generation.text
generation_resp = deepcopy(metadata)
generation_resp.update(flatten_dict(generation.dict()))
complexity_metrics = self._get_complexity_metrics(text)
if complexity_metrics:
output_complexity_metrics.append(complexity_metrics)
generation_resp.update(complexity_metrics)
custom_metrics = self._get_custom_metrics(
generation, prompt_idx, gen_idx
)
if custom_metrics:
output_custom_metrics.append(custom_metrics)
generation_resp.update(custom_metrics)
if self.stream_logs:
self._log_stream(text, metadata, self.step)
self.action_records.append(generation_resp)
self.on_llm_end_records.append(generation_resp)
self._log_text_metrics(output_complexity_metrics, step=self.step)
self._log_text_metrics(output_custom_metrics, step=self.step)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.step += 1
self.chain_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
for chain_input_key, chain_input_val in inputs.items():
if isinstance(chain_input_val, str):
input_resp = deepcopy(resp)
if self.stream_logs:
self._log_stream(chain_input_val, resp, self.step)
input_resp.update({chain_input_key: chain_input_val})
self.action_records.append(input_resp)
else:
self.comet_ml.LOGGER.warning(
f"Unexpected data format provided! "
f"Input Value for {chain_input_key} will not be logged"
)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.step += 1
self.chain_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_chain_end"})
resp.update(self.get_custom_callback_meta())
for chain_output_key, chain_output_val in outputs.items():
if isinstance(chain_output_val, str):
output_resp = deepcopy(resp)
if self.stream_logs:
self._log_stream(chain_output_val, resp, self.step)
output_resp.update({chain_output_key: chain_output_val})
self.action_records.append(output_resp)
else:
self.comet_ml.LOGGER.warning(
f"Unexpected data format provided! "
f"Output Value for {chain_output_key} will not be logged"
)
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_tool_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
if self.stream_logs:
self._log_stream(input_str, resp, self.step)
resp.update({"input_str": input_str})
self.action_records.append(resp)
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
output = str(output)
self.step += 1
self.tool_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_tool_end"})
resp.update(self.get_custom_callback_meta())
if self.stream_logs:
self._log_stream(output, resp, self.step)
resp.update({"output": output})
self.action_records.append(resp)
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.step += 1
self.text_ctr += 1
resp = self._init_resp()
resp.update({"action": "on_text"})
resp.update(self.get_custom_callback_meta())
if self.stream_logs:
self._log_stream(text, resp, self.step)
resp.update({"text": text})
self.action_records.append(resp)
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.step += 1
self.agent_ends += 1
self.ends += 1
resp = self._init_resp()
output = finish.return_values["output"]
log = finish.log
resp.update({"action": "on_agent_finish", "log": log})
resp.update(self.get_custom_callback_meta())
if self.stream_logs:
self._log_stream(output, resp, self.step)
resp.update({"output": output})
self.action_records.append(resp)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.step += 1
self.tool_starts += 1
self.starts += 1
tool = action.tool
tool_input = str(action.tool_input)
log = action.log
resp = self._init_resp()
resp.update({"action": "on_agent_action", "log": log, "tool": tool})
resp.update(self.get_custom_callback_meta())
if self.stream_logs:
self._log_stream(tool_input, resp, self.step)
resp.update({"tool_input": tool_input})
self.action_records.append(resp)
def _get_complexity_metrics(self, text: str) -> dict:
"""Compute text complexity metrics using textstat.
Parameters:
text (str): The text to analyze.
Returns:
`dict` containing the complexity metrics.
"""
resp = {}
if self.complexity_metrics:
text_complexity_metrics = _fetch_text_complexity_metrics(text)
resp.update(text_complexity_metrics)
return resp
def _get_custom_metrics(
self, generation: Generation, prompt_idx: int, gen_idx: int
) -> dict:
"""Compute Custom Metrics for an LLM Generated Output
Args:
generation (LLMResult): Output generation from an LLM
prompt_idx (int): List index of the input prompt
gen_idx (int): List index of the generated output
Returns:
dict: `dict` containing the custom metrics.
"""
resp = {}
if self.custom_metrics:
custom_metrics = self.custom_metrics(generation, prompt_idx, gen_idx)
resp.update(custom_metrics)
return resp
def flush_tracker(
self,
langchain_asset: Any = None,
task_type: Optional[str] = "inference",
workspace: Optional[str] = None,
project_name: Optional[str] = "comet-langchain-demo",
tags: Optional[Sequence] = None,
name: Optional[str] = None,
visualizations: Optional[List[str]] = None,
complexity_metrics: bool = False,
custom_metrics: Optional[Callable] = None,
finish: bool = False,
reset: bool = False,
) -> None:
"""Flush the tracker and setup the session.
Everything after this will be a new table.
Args:
name: Name of the performed session so far so it is identifiable
langchain_asset: The langchain asset to save.
finish: Whether to finish the run.
Returns:
None
"""
self._log_session(langchain_asset)
if langchain_asset:
try:
self._log_model(langchain_asset)
except Exception:
self.comet_ml.LOGGER.error(
"Failed to export agent or LLM to Comet",
exc_info=True,
extra={"show_traceback": True},
)
if finish:
self.experiment.end()
if reset:
self._reset(
task_type,
workspace,
project_name,
tags,
name,
visualizations,
complexity_metrics,
custom_metrics,
)
def _log_stream(self, prompt: str, metadata: dict, step: int) -> None:
self.experiment.log_text(prompt, metadata=metadata, step=step)
def _log_model(self, langchain_asset: Any) -> None:
model_parameters = self._get_llm_parameters(langchain_asset)
self.experiment.log_parameters(model_parameters, prefix="model")
langchain_asset_path = Path(self.temp_dir.name, "model.json")
model_name = self.name if self.name else LANGCHAIN_MODEL_NAME
try:
if hasattr(langchain_asset, "save"):
langchain_asset.save(langchain_asset_path)
self.experiment.log_model(model_name, str(langchain_asset_path))
except (ValueError, AttributeError, NotImplementedError) as e:
if hasattr(langchain_asset, "save_agent"):
langchain_asset.save_agent(langchain_asset_path)
self.experiment.log_model(model_name, str(langchain_asset_path))
else:
self.comet_ml.LOGGER.error(
f"{e}"
" Could not save Langchain Asset "
f"for {langchain_asset.__class__.__name__}"
)
def _log_session(self, langchain_asset: Optional[Any] = None) -> None:
try:
llm_session_df = self._create_session_analysis_dataframe(langchain_asset)
# Log the cleaned dataframe as a table
self.experiment.log_table("langchain-llm-session.csv", llm_session_df)
except Exception:
self.comet_ml.LOGGER.warning(
"Failed to log session data to Comet",
exc_info=True,
extra={"show_traceback": True},
)
try:
metadata = {"langchain_version": str(langchain_community.__version__)}
# Log the langchain low-level records as a JSON file directly
self.experiment.log_asset_data(
self.action_records, "langchain-action_records.json", metadata=metadata
)
except Exception:
self.comet_ml.LOGGER.warning(
"Failed to log session data to Comet",
exc_info=True,
extra={"show_traceback": True},
)
try:
self._log_visualizations(llm_session_df)
except Exception:
self.comet_ml.LOGGER.warning(
"Failed to log visualizations to Comet",
exc_info=True,
extra={"show_traceback": True},
)
def _log_text_metrics(self, metrics: Sequence[dict], step: int) -> None:
if not metrics:
return
metrics_summary = _summarize_metrics_for_generated_outputs(metrics)
for key, value in metrics_summary.items():
self.experiment.log_metrics(value, prefix=key, step=step)
def _log_visualizations(self, session_df: Any) -> None:
if not (self.visualizations and self.nlp):
return
spacy = import_spacy()
prompts = session_df["prompts"].tolist()
outputs = session_df["text"].tolist()
for idx, (prompt, output) in enumerate(zip(prompts, outputs)):
doc = self.nlp(output)
sentence_spans = list(doc.sents)
for visualization in self.visualizations:
try:
html = spacy.displacy.render(
sentence_spans,
style=visualization,
options={"compact": True},
jupyter=False,
page=True,
)
self.experiment.log_asset_data(
html,
name=f"langchain-viz-{visualization}-{idx}.html",
metadata={"prompt": prompt},
step=idx,
)
except Exception as e:
self.comet_ml.LOGGER.warning(
e, exc_info=True, extra={"show_traceback": True}
)
return
def _reset(
self,
task_type: Optional[str] = None,
workspace: Optional[str] = None,
project_name: Optional[str] = None,
tags: Optional[Sequence] = None,
name: Optional[str] = None,
visualizations: Optional[List[str]] = None,
complexity_metrics: bool = False,
custom_metrics: Optional[Callable] = None,
) -> None:
_task_type = task_type if task_type else self.task_type
_workspace = workspace if workspace else self.workspace
_project_name = project_name if project_name else self.project_name
_tags = tags if tags else self.tags
_name = name if name else self.name
_visualizations = visualizations if visualizations else self.visualizations
_complexity_metrics = (
complexity_metrics if complexity_metrics else self.complexity_metrics
)
_custom_metrics = custom_metrics if custom_metrics else self.custom_metrics
self.__init__( # type: ignore[misc]
task_type=_task_type,
workspace=_workspace,
project_name=_project_name,
tags=_tags,
name=_name,
visualizations=_visualizations,
complexity_metrics=_complexity_metrics,
custom_metrics=_custom_metrics,
)
self.reset_callback_meta()
self.temp_dir = tempfile.TemporaryDirectory()
def _create_session_analysis_dataframe(self, langchain_asset: Any = None) -> dict:
pd = import_pandas()
llm_parameters = self._get_llm_parameters(langchain_asset)
num_generations_per_prompt = llm_parameters.get("n", 1)
llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
# Repeat each input row based on the number of outputs generated per prompt
llm_start_records_df = llm_start_records_df.loc[
llm_start_records_df.index.repeat(num_generations_per_prompt)
].reset_index(drop=True)
llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
llm_session_df = pd.merge(
llm_start_records_df,
llm_end_records_df,
left_index=True,
right_index=True,
suffixes=["_llm_start", "_llm_end"],
)
return llm_session_df
def _get_llm_parameters(self, langchain_asset: Any = None) -> dict:
if not langchain_asset:
return {}
try:
if hasattr(langchain_asset, "agent"):
llm_parameters = langchain_asset.agent.llm_chain.llm.dict()
elif hasattr(langchain_asset, "llm_chain"):
llm_parameters = langchain_asset.llm_chain.llm.dict()
elif hasattr(langchain_asset, "llm"):
llm_parameters = langchain_asset.llm.dict()
else:
llm_parameters = langchain_asset.dict()
except Exception:
return {}
return llm_parameters

View File

@@ -0,0 +1,183 @@
# flake8: noqa
import os
import warnings
from typing import Any, Dict, List, Optional, Union
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult
class DeepEvalCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs into deepeval.
Args:
implementation_name: name of the `implementation` in deepeval
metrics: A list of metrics
Raises:
ImportError: if the `deepeval` package is not installed.
Examples:
>>> from langchain_community.llms import OpenAI
>>> from langchain_community.callbacks import DeepEvalCallbackHandler
>>> from deepeval.metrics import AnswerRelevancy
>>> metric = AnswerRelevancy(minimum_score=0.3)
>>> deepeval_callback = DeepEvalCallbackHandler(
... implementation_name="exampleImplementation",
... metrics=[metric],
... )
>>> llm = OpenAI(
... temperature=0,
... callbacks=[deepeval_callback],
... verbose=True,
... openai_api_key="API_KEY_HERE",
... )
>>> llm.generate([
... "What is the best evaluation tool out there? (no bias at all)",
... ])
"Deepeval, no doubt about it."
"""
REPO_URL: str = "https://github.com/confident-ai/deepeval"
ISSUES_URL: str = f"{REPO_URL}/issues"
BLOG_URL: str = "https://docs.confident-ai.com" # noqa: E501
def __init__(
self,
metrics: List[Any],
implementation_name: Optional[str] = None,
) -> None:
"""Initializes the `deepevalCallbackHandler`.
Args:
implementation_name: Name of the implementation you want.
metrics: What metrics do you want to track?
Raises:
ImportError: if the `deepeval` package is not installed.
ConnectionError: if the connection to deepeval fails.
"""
super().__init__()
# Import deepeval (not via `import_deepeval` to keep hints in IDEs)
try:
import deepeval # ignore: F401,I001
except ImportError:
raise ImportError(
"""To use the deepeval callback manager you need to have the
`deepeval` Python package installed. Please install it with
`pip install deepeval`"""
)
if os.path.exists(".deepeval"):
warnings.warn(
"""You are currently not logging anything to the dashboard, we
recommend using `deepeval login`."""
)
# Set the deepeval variables
self.implementation_name = implementation_name
self.metrics = metrics
warnings.warn(
(
"The `DeepEvalCallbackHandler` is currently in beta and is subject to"
" change based on updates to `langchain`. Please report any issues to"
f" {self.ISSUES_URL} as an `integration` issue."
),
)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Store the prompts"""
self.prompts = prompts
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing when a new token is generated."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Log records to deepeval when an LLM ends."""
from deepeval.metrics.answer_relevancy import AnswerRelevancy
from deepeval.metrics.bias_classifier import UnBiasedMetric
from deepeval.metrics.metric import Metric
from deepeval.metrics.toxic_classifier import NonToxicMetric
for metric in self.metrics:
for i, generation in enumerate(response.generations):
# Here, we only measure the first generation's output
output = generation[0].text
query = self.prompts[i]
if isinstance(metric, AnswerRelevancy):
result = metric.measure(
output=output,
query=query,
)
print(f"Answer Relevancy: {result}") # noqa: T201
elif isinstance(metric, UnBiasedMetric):
score = metric.measure(output)
print(f"Bias Score: {score}") # noqa: T201
elif isinstance(metric, NonToxicMetric):
score = metric.measure(output)
print(f"Toxic Score: {score}") # noqa: T201
else:
raise ValueError(
f"""Metric {metric.__name__} is not supported by deepeval
callbacks."""
)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM outputs an error."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Do nothing when chain starts"""
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Do nothing when chain ends."""
pass
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM chain outputs an error."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing when tool starts."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing when agent takes a specific action."""
pass
def on_tool_end(
self,
output: Any,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing when tool ends."""
pass
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing"""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Do nothing"""
pass

View File

@@ -0,0 +1,192 @@
"""Callback handler for Context AI"""
import os
from typing import Any, Dict, List
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import
def import_context() -> Any:
"""Import the `getcontext` package."""
return (
guard_import("getcontext", pip_name="python-context"),
guard_import("getcontext.token", pip_name="python-context").Credential,
guard_import(
"getcontext.generated.models", pip_name="python-context"
).Conversation,
guard_import("getcontext.generated.models", pip_name="python-context").Message,
guard_import(
"getcontext.generated.models", pip_name="python-context"
).MessageRole,
guard_import("getcontext.generated.models", pip_name="python-context").Rating,
)
class ContextCallbackHandler(BaseCallbackHandler):
"""Callback Handler that records transcripts to the Context service.
(https://context.ai).
Keyword Args:
token (optional): The token with which to authenticate requests to Context.
Visit https://with.context.ai/settings to generate a token.
If not provided, the value of the `CONTEXT_TOKEN` environment
variable will be used.
Raises:
ImportError: if the `context-python` package is not installed.
Chat Example:
>>> from langchain_community.llms import ChatOpenAI
>>> from langchain_community.callbacks import ContextCallbackHandler
>>> context_callback = ContextCallbackHandler(
... token="<CONTEXT_TOKEN_HERE>",
... )
>>> chat = ChatOpenAI(
... temperature=0,
... headers={"user_id": "123"},
... callbacks=[context_callback],
... openai_api_key="API_KEY_HERE",
... )
>>> messages = [
... SystemMessage(content="You translate English to French."),
... HumanMessage(content="I love programming with LangChain."),
... ]
>>> chat.invoke(messages)
Chain Example:
>>> from langchain_classic.chains import LLMChain
>>> from langchain_community.chat_models import ChatOpenAI
>>> from langchain_community.callbacks import ContextCallbackHandler
>>> context_callback = ContextCallbackHandler(
... token="<CONTEXT_TOKEN_HERE>",
... )
>>> human_message_prompt = HumanMessagePromptTemplate(
... prompt=PromptTemplate(
... template="What is a good name for a company that makes {product}?",
... input_variables=["product"],
... ),
... )
>>> chat_prompt_template = ChatPromptTemplate.from_messages(
... [human_message_prompt]
... )
>>> callback = ContextCallbackHandler(token)
>>> # Note: the same callback object must be shared between the
... LLM and the chain.
>>> chat = ChatOpenAI(temperature=0.9, callbacks=[callback])
>>> chain = LLMChain(
... llm=chat,
... prompt=chat_prompt_template,
... callbacks=[callback]
... )
>>> chain.run("colorful socks")
"""
def __init__(self, token: str = "", verbose: bool = False, **kwargs: Any) -> None:
(
self.context,
self.credential,
self.conversation_model,
self.message_model,
self.message_role_model,
self.rating_model,
) = import_context()
token = token or os.environ.get("CONTEXT_TOKEN") or ""
self.client = self.context.ContextAPI(credential=self.credential(token))
self.chain_run_id = None
self.llm_model = None
self.messages: List[Any] = []
self.metadata: Dict[str, str] = {}
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
**kwargs: Any,
) -> Any:
"""Run when the chat model is started."""
llm_model = kwargs.get("invocation_params", {}).get("model", None)
if llm_model is not None:
self.metadata["model"] = llm_model
if len(messages) == 0:
return
for message in messages[0]:
role = self.message_role_model.SYSTEM
if message.type == "human":
role = self.message_role_model.USER
elif message.type == "system":
role = self.message_role_model.SYSTEM
elif message.type == "ai":
role = self.message_role_model.ASSISTANT
self.messages.append(
self.message_model(
message=message.content,
role=role,
)
)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends."""
if len(response.generations) == 0 or len(response.generations[0]) == 0:
return
if not self.chain_run_id:
generation = response.generations[0][0]
self.messages.append(
self.message_model(
message=generation.text,
role=self.message_role_model.ASSISTANT,
)
)
self._log_conversation()
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts."""
self.chain_run_id = kwargs.get("run_id", None)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends."""
self.messages.append(
self.message_model(
message=outputs["text"],
role=self.message_role_model.ASSISTANT,
)
)
self._log_conversation()
self.chain_run_id = None
def _log_conversation(self) -> None:
"""Log the conversation to the context API."""
if len(self.messages) == 0:
return
self.client.log.conversation_upsert(
body={
"conversation": self.conversation_model(
messages=self.messages,
metadata=self.metadata,
)
}
)
self.messages = []
self.metadata = {}

View File

@@ -0,0 +1,335 @@
import time
from typing import Any, Dict, List, Optional
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import
from langchain_community.callbacks.utils import import_pandas
# Define constants
# LLMResult keys
TOKEN_USAGE = "token_usage"
TOTAL_TOKENS = "total_tokens"
PROMPT_TOKENS = "prompt_tokens"
COMPLETION_TOKENS = "completion_tokens"
RUN_ID = "run_id"
MODEL_NAME = "model_name"
GOOD = "good"
BAD = "bad"
NEUTRAL = "neutral"
SUCCESS = "success"
FAILURE = "failure"
# Default values
DEFAULT_MAX_TOKEN = 65536
DEFAULT_MAX_DURATION = 120000
# Fiddler specific constants
PROMPT = "prompt"
RESPONSE = "response"
CONTEXT = "context"
DURATION = "duration"
FEEDBACK = "feedback"
LLM_STATUS = "llm_status"
FEEDBACK_POSSIBLE_VALUES = [GOOD, BAD, NEUTRAL]
# Define a dataset dictionary
_dataset_dict = {
PROMPT: ["fiddler"] * 10,
RESPONSE: ["fiddler"] * 10,
CONTEXT: ["fiddler"] * 10,
FEEDBACK: ["good"] * 10,
LLM_STATUS: ["success"] * 10,
MODEL_NAME: ["fiddler"] * 10,
RUN_ID: ["123e4567-e89b-12d3-a456-426614174000"] * 10,
TOTAL_TOKENS: [0, DEFAULT_MAX_TOKEN] * 5,
PROMPT_TOKENS: [0, DEFAULT_MAX_TOKEN] * 5,
COMPLETION_TOKENS: [0, DEFAULT_MAX_TOKEN] * 5,
DURATION: [1, DEFAULT_MAX_DURATION] * 5,
}
def import_fiddler() -> Any:
"""Import the fiddler python package and raise an error if it is not installed."""
return guard_import("fiddler", pip_name="fiddler-client")
# First, define custom callback handler implementations
class FiddlerCallbackHandler(BaseCallbackHandler):
def __init__(
self,
url: str,
org: str,
project: str,
model: str,
api_key: str,
) -> None:
"""
Initialize Fiddler callback handler.
Args:
url: Fiddler URL (e.g. https://demo.fiddler.ai).
Make sure to include the protocol (http/https).
org: Fiddler organization id
project: Fiddler project name to publish events to
model: Fiddler model name to publish events to
api_key: Fiddler authentication token
"""
super().__init__()
# Initialize Fiddler client and other necessary properties
self.fdl = import_fiddler()
self.pd = import_pandas()
self.url = url
self.org = org
self.project = project
self.model = model
self.api_key = api_key
self._df = self.pd.DataFrame(_dataset_dict)
self.run_id_prompts: Dict[UUID, List[str]] = {}
self.run_id_response: Dict[UUID, List[str]] = {}
self.run_id_starttime: Dict[UUID, int] = {}
# Initialize Fiddler client here
self.fiddler_client = self.fdl.FiddlerApi(url, org_id=org, auth_token=api_key)
if self.project not in self.fiddler_client.get_project_names():
print( # noqa: T201
f"adding project {self.project}.This only has to be done once."
)
try:
self.fiddler_client.add_project(self.project)
except Exception as e:
print( # noqa: T201
f"Error adding project {self.project}:"
"{e}. Fiddler integration will not work."
)
raise e
dataset_info = self.fdl.DatasetInfo.from_dataframe(
self._df, max_inferred_cardinality=0
)
# Set feedback column to categorical
for i in range(len(dataset_info.columns)):
if dataset_info.columns[i].name == FEEDBACK:
dataset_info.columns[i].data_type = self.fdl.DataType.CATEGORY
dataset_info.columns[i].possible_values = FEEDBACK_POSSIBLE_VALUES
elif dataset_info.columns[i].name == LLM_STATUS:
dataset_info.columns[i].data_type = self.fdl.DataType.CATEGORY
dataset_info.columns[i].possible_values = [SUCCESS, FAILURE]
if self.model not in self.fiddler_client.get_model_names(self.project):
if self.model not in self.fiddler_client.get_dataset_names(self.project):
print( # noqa: T201
f"adding dataset {self.model} to project {self.project}."
"This only has to be done once."
)
try:
self.fiddler_client.upload_dataset(
project_id=self.project,
dataset_id=self.model,
dataset={"train": self._df},
info=dataset_info,
)
except Exception as e:
print( # noqa: T201
f"Error adding dataset {self.model}: {e}."
"Fiddler integration will not work."
)
raise e
model_info = self.fdl.ModelInfo.from_dataset_info(
dataset_info=dataset_info,
dataset_id="train",
model_task=self.fdl.ModelTask.LLM,
features=[PROMPT, CONTEXT, RESPONSE],
target=FEEDBACK,
metadata_cols=[
RUN_ID,
TOTAL_TOKENS,
PROMPT_TOKENS,
COMPLETION_TOKENS,
MODEL_NAME,
DURATION,
],
custom_features=self.custom_features,
)
print( # noqa: T201
f"adding model {self.model} to project {self.project}."
"This only has to be done once."
)
try:
self.fiddler_client.add_model(
project_id=self.project,
dataset_id=self.model,
model_id=self.model,
model_info=model_info,
)
except Exception as e:
print( # noqa: T201
f"Error adding model {self.model}: {e}."
"Fiddler integration will not work."
)
raise e
@property
def custom_features(self) -> list:
"""
Define custom features for the model to automatically enrich the data with.
Here, we enable the following enrichments:
- Automatic Embedding generation for prompt and response
- Text Statistics such as:
- Automated Readability Index
- Coleman Liau Index
- Dale Chall Readability Score
- Difficult Words
- Flesch Reading Ease
- Flesch Kincaid Grade
- Gunning Fog
- Linsear Write Formula
- PII - Personal Identifiable Information
- Sentiment Analysis
"""
return [
self.fdl.Enrichment(
name="Prompt Embedding",
enrichment="embedding",
columns=[PROMPT],
),
self.fdl.TextEmbedding(
name="Prompt CF",
source_column=PROMPT,
column="Prompt Embedding",
),
self.fdl.Enrichment(
name="Response Embedding",
enrichment="embedding",
columns=[RESPONSE],
),
self.fdl.TextEmbedding(
name="Response CF",
source_column=RESPONSE,
column="Response Embedding",
),
self.fdl.Enrichment(
name="Text Statistics",
enrichment="textstat",
columns=[PROMPT, RESPONSE],
config={
"statistics": [
"automated_readability_index",
"coleman_liau_index",
"dale_chall_readability_score",
"difficult_words",
"flesch_reading_ease",
"flesch_kincaid_grade",
"gunning_fog",
"linsear_write_formula",
]
},
),
self.fdl.Enrichment(
name="PII",
enrichment="pii",
columns=[PROMPT, RESPONSE],
),
self.fdl.Enrichment(
name="Sentiment",
enrichment="sentiment",
columns=[PROMPT, RESPONSE],
),
]
def _publish_events(
self,
run_id: UUID,
prompt_responses: List[str],
duration: int,
llm_status: str,
model_name: Optional[str] = "",
token_usage_dict: Optional[Dict[str, Any]] = None,
) -> None:
"""
Publish events to fiddler
"""
prompt_count = len(self.run_id_prompts[run_id])
df = self.pd.DataFrame(
{
PROMPT: self.run_id_prompts[run_id],
RESPONSE: prompt_responses,
RUN_ID: [str(run_id)] * prompt_count,
DURATION: [duration] * prompt_count,
LLM_STATUS: [llm_status] * prompt_count,
MODEL_NAME: [model_name] * prompt_count,
}
)
if token_usage_dict:
for key, value in token_usage_dict.items():
df[key] = [value] * prompt_count if isinstance(value, int) else value
try:
if df.shape[0] > 1:
self.fiddler_client.publish_events_batch(self.project, self.model, df)
else:
df_dict = df.to_dict(orient="records")
self.fiddler_client.publish_event(
self.project, self.model, event=df_dict[0]
)
except Exception as e:
print( # noqa: T201
f"Error publishing events to fiddler: {e}. continuing..."
)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> Any:
run_id = kwargs[RUN_ID]
self.run_id_prompts[run_id] = prompts
self.run_id_starttime[run_id] = int(time.time() * 1000)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
flattened_llmresult = response.flatten()
run_id = kwargs[RUN_ID]
run_duration = int(time.time() * 1000) - self.run_id_starttime[run_id]
model_name = ""
token_usage_dict = {}
if isinstance(response.llm_output, dict):
token_usage_dict = {
k: v
for k, v in response.llm_output.items()
if k in [TOTAL_TOKENS, PROMPT_TOKENS, COMPLETION_TOKENS]
}
model_name = response.llm_output.get(MODEL_NAME, "")
prompt_responses = [
llmresult.generations[0][0].text for llmresult in flattened_llmresult
]
self._publish_events(
run_id,
prompt_responses,
run_duration,
SUCCESS,
model_name,
token_usage_dict,
)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
run_id = kwargs[RUN_ID]
duration = int(time.time() * 1000) - self.run_id_starttime[run_id]
self._publish_events(
run_id, [""] * len(self.run_id_prompts[run_id]), duration, FAILURE
)

View File

@@ -0,0 +1,364 @@
"""FlyteKit callback handler."""
from __future__ import annotations
import logging
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import
from langchain_community.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
import_pandas,
import_spacy,
import_textstat,
)
if TYPE_CHECKING:
import flytekit
from flytekitplugins.deck import renderer
logger = logging.getLogger(__name__)
def import_flytekit() -> Tuple[flytekit, renderer]:
"""Import flytekit and flytekitplugins-deck-standard."""
return (
guard_import("flytekit"),
guard_import(
"flytekitplugins.deck", pip_name="flytekitplugins-deck-standard"
).renderer,
)
def analyze_text(
text: str,
nlp: Any = None,
textstat: Any = None,
) -> dict:
"""Analyze text using textstat and spacy.
Parameters:
text (str): The text to analyze.
nlp (spacy.lang): The spacy language model to use for visualization.
Returns:
`dict` containing the complexity metrics and visualization
files serialized to HTML string.
"""
resp: Dict[str, Any] = {}
if textstat is not None:
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
"smog_index": textstat.smog_index(text),
"coleman_liau_index": textstat.coleman_liau_index(text),
"automated_readability_index": textstat.automated_readability_index(text),
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
"difficult_words": textstat.difficult_words(text),
"linsear_write_formula": textstat.linsear_write_formula(text),
"gunning_fog": textstat.gunning_fog(text),
"fernandez_huerta": textstat.fernandez_huerta(text),
"szigriszt_pazos": textstat.szigriszt_pazos(text),
"gutierrez_polini": textstat.gutierrez_polini(text),
"crawford": textstat.crawford(text),
"gulpease_index": textstat.gulpease_index(text),
"osman": textstat.osman(text),
}
resp.update({"text_complexity_metrics": text_complexity_metrics})
resp.update(text_complexity_metrics)
if nlp is not None:
spacy = import_spacy()
doc = nlp(text)
dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
text_visualizations = {
"dependency_tree": dep_out,
"entities": ent_out,
}
resp.update(text_visualizations)
return resp
class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""Callback handler that is used within a Flyte task."""
def __init__(self) -> None:
"""Initialize callback handler."""
flytekit, renderer = import_flytekit()
self.pandas = import_pandas()
self.textstat = None
try:
self.textstat = import_textstat()
except ImportError:
logger.warning(
"Textstat library is not installed. \
It may result in the inability to log \
certain metrics that can be captured with Textstat."
)
spacy = None
try:
spacy = import_spacy()
except ImportError:
logger.warning(
"Spacy library is not installed. \
It may result in the inability to log \
certain metrics that can be captured with Spacy."
)
super().__init__()
self.nlp = None
if spacy:
try:
self.nlp = spacy.load("en_core_web_sm")
except OSError:
logger.warning(
"FlyteCallbackHandler uses spacy's en_core_web_sm model"
" for certain metrics. To download,"
" run the following command in your terminal:"
" `python -m spacy download en_core_web_sm`"
)
self.table_renderer = renderer.TableRenderer
self.markdown_renderer = renderer.MarkdownRenderer
self.deck = flytekit.Deck(
"LangChain Metrics",
self.markdown_renderer().to_html("## LangChain Metrics"),
)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
self.step += 1
self.llm_starts += 1
self.starts += 1
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
prompt_responses = []
for prompt in prompts:
prompt_responses.append(prompt)
resp.update({"prompts": prompt_responses})
self.deck.append(self.markdown_renderer().to_html("### LLM Start"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.step += 1
self.llm_ends += 1
self.ends += 1
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_end"})
resp.update(flatten_dict(response.llm_output or {}))
resp.update(self.get_custom_callback_meta())
self.deck.append(self.markdown_renderer().to_html("### LLM End"))
self.deck.append(self.table_renderer().to_html(self.pandas.DataFrame([resp])))
for generations in response.generations:
for generation in generations:
generation_resp = deepcopy(resp)
generation_resp.update(flatten_dict(generation.dict()))
if self.nlp or self.textstat:
generation_resp.update(
analyze_text(
generation.text, nlp=self.nlp, textstat=self.textstat
)
)
complexity_metrics: Dict[str, float] = generation_resp.pop(
"text_complexity_metrics"
)
self.deck.append(
self.markdown_renderer().to_html("#### Text Complexity Metrics")
)
self.deck.append(
self.table_renderer().to_html(
self.pandas.DataFrame([complexity_metrics])
)
+ "\n"
)
dependency_tree = generation_resp["dependency_tree"]
self.deck.append(
self.markdown_renderer().to_html("#### Dependency Tree")
)
self.deck.append(dependency_tree)
entities = generation_resp["entities"]
self.deck.append(self.markdown_renderer().to_html("#### Entities"))
self.deck.append(entities)
else:
self.deck.append(
self.markdown_renderer().to_html("#### Generated Response")
)
self.deck.append(self.markdown_renderer().to_html(generation.text))
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.step += 1
self.chain_starts += 1
self.starts += 1
resp: Dict[str, Any] = {}
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
input_resp = deepcopy(resp)
input_resp["inputs"] = chain_input
self.deck.append(self.markdown_renderer().to_html("### Chain Start"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([input_resp])) + "\n"
)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.step += 1
self.chain_ends += 1
self.ends += 1
resp: Dict[str, Any] = {}
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
resp.update({"action": "on_chain_end", "outputs": chain_output})
resp.update(self.get_custom_callback_meta())
self.deck.append(self.markdown_renderer().to_html("### Chain End"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_start", "input_str": input_str})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
self.deck.append(self.markdown_renderer().to_html("### Tool Start"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
self.step += 1
self.tool_ends += 1
self.ends += 1
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_end", "output": output})
resp.update(self.get_custom_callback_meta())
self.deck.append(self.markdown_renderer().to_html("### Tool End"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.step += 1
self.text_ctr += 1
resp: Dict[str, Any] = {}
resp.update({"action": "on_text", "text": text})
resp.update(self.get_custom_callback_meta())
self.deck.append(self.markdown_renderer().to_html("### On Text"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.step += 1
self.agent_ends += 1
self.ends += 1
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_finish",
"output": finish.return_values["output"],
"log": finish.log,
}
)
resp.update(self.get_custom_callback_meta())
self.deck.append(self.markdown_renderer().to_html("### Agent Finish"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_action",
"tool": action.tool,
"tool_input": action.tool_input,
"log": action.log,
}
)
resp.update(self.get_custom_callback_meta())
self.deck.append(self.markdown_renderer().to_html("### Agent Action"))
self.deck.append(
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)

View File

@@ -0,0 +1,88 @@
from typing import Any, Awaitable, Callable, Dict, Optional
from uuid import UUID
from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler
def _default_approve(_input: str) -> bool:
msg = (
"Do you approve of the following input? "
"Anything except 'Y'/'Yes' (case-insensitive) will be treated as a no."
)
msg += "\n\n" + _input + "\n"
resp = input(msg)
return resp.lower() in ("yes", "y")
async def _adefault_approve(_input: str) -> bool:
msg = (
"Do you approve of the following input? "
"Anything except 'Y'/'Yes' (case-insensitive) will be treated as a no."
)
msg += "\n\n" + _input + "\n"
resp = input(msg)
return resp.lower() in ("yes", "y")
def _default_true(_: Dict[str, Any]) -> bool:
return True
class HumanRejectedException(Exception):
"""Exception to raise when a person manually review and rejects a value."""
class HumanApprovalCallbackHandler(BaseCallbackHandler):
"""Callback for manually validating values."""
raise_error: bool = True
def __init__(
self,
approve: Callable[[Any], bool] = _default_approve,
should_check: Callable[[Dict[str, Any]], bool] = _default_true,
):
self._approve = approve
self._should_check = should_check
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if self._should_check(serialized) and not self._approve(input_str):
raise HumanRejectedException(
f"Inputs {input_str} to tool {serialized} were rejected."
)
class AsyncHumanApprovalCallbackHandler(AsyncCallbackHandler):
"""Asynchronous callback for manually validating values."""
raise_error: bool = True
def __init__(
self,
approve: Callable[[Any], Awaitable[bool]] = _adefault_approve,
should_check: Callable[[Dict[str, Any]], bool] = _default_true,
):
self._approve = approve
self._should_check = should_check
async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if self._should_check(serialized) and not await self._approve(input_str):
raise HumanRejectedException(
f"Inputs {input_str} to tool {serialized} were rejected."
)

View File

@@ -0,0 +1,251 @@
import time
from typing import Any, Dict, List, Optional, cast
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.utils import guard_import
def import_infino() -> Any:
"""Import the infino client."""
return guard_import("infinopy").InfinoClient()
def import_tiktoken() -> Any:
"""Import tiktoken for counting tokens for OpenAI models."""
return guard_import("tiktoken")
def get_num_tokens(string: str, openai_model_name: str) -> int:
"""Calculate num tokens for OpenAI with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/main
/examples/How_to_count_tokens_with_tiktoken.ipynb
"""
tiktoken = import_tiktoken()
encoding = tiktoken.encoding_for_model(openai_model_name)
num_tokens = len(encoding.encode(string))
return num_tokens
class InfinoCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs to Infino."""
def __init__(
self,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
verbose: bool = False,
) -> None:
# Set Infino client
self.client = import_infino()
self.model_id = model_id
self.model_version = model_version
self.verbose = verbose
self.is_chat_openai_model = False
self.chat_openai_model_name = "gpt-3.5-turbo"
def _send_to_infino(
self,
key: str,
value: Any,
is_ts: bool = True,
) -> None:
"""Send the key-value to Infino.
Parameters:
key (str): the key to send to Infino.
value (Any): the value to send to Infino.
is_ts (bool): if True, the value is part of a time series, else it
is sent as a log message.
"""
payload = {
"date": int(time.time()),
key: value,
"labels": {
"model_id": self.model_id,
"model_version": self.model_version,
},
}
if self.verbose:
print(f"Tracking {key} with Infino: {payload}") # noqa: T201
# Append to Infino time series only if is_ts is True, otherwise
# append to Infino log.
if is_ts:
self.client.append_ts(payload)
else:
self.client.append_log(payload)
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
**kwargs: Any,
) -> None:
"""Log the prompts to Infino, and set start time and error flag."""
for prompt in prompts:
self._send_to_infino("prompt", prompt, is_ts=False)
# Set the error flag to indicate no error (this will get overridden
# in on_llm_error if an error occurs).
self.error = 0
# Set the start time (so that we can calculate the request
# duration in on_llm_end).
self.start_time = time.time()
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing when a new token is generated."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Log the latency, error, token usage, and response to Infino."""
# Calculate and track the request latency.
self.end_time = time.time()
duration = self.end_time - self.start_time
self._send_to_infino("latency", duration)
# Track success or error flag.
self._send_to_infino("error", self.error)
# Track prompt response.
for generations in response.generations:
for generation in generations:
self._send_to_infino("prompt_response", generation.text, is_ts=False)
# Track token usage (for non-chat models).
if (response.llm_output is not None) and isinstance(response.llm_output, Dict):
token_usage = response.llm_output["token_usage"]
if token_usage is not None:
prompt_tokens = token_usage["prompt_tokens"]
total_tokens = token_usage["total_tokens"]
completion_tokens = token_usage["completion_tokens"]
self._send_to_infino("prompt_tokens", prompt_tokens)
self._send_to_infino("total_tokens", total_tokens)
self._send_to_infino("completion_tokens", completion_tokens)
# Track completion token usage (for openai chat models).
if self.is_chat_openai_model:
messages = " ".join(
cast(str, cast(ChatGeneration, generation).message.content)
for generation in generations
)
completion_tokens = get_num_tokens(
messages, openai_model_name=self.chat_openai_model_name
)
self._send_to_infino("completion_tokens", completion_tokens)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Set the error flag."""
self.error = 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Do nothing when LLM chain starts."""
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Do nothing when LLM chain ends."""
pass
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Need to log the error."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing when tool starts."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing when agent takes a specific action."""
pass
def on_tool_end(
self,
output: str,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing when tool ends."""
pass
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any,
) -> None:
"""Run when LLM starts running."""
# Currently, for chat models, we only support input prompts for ChatOpenAI.
# Check if this model is a ChatOpenAI model.
values = serialized.get("id")
if values:
for value in values:
if value == "ChatOpenAI":
self.is_chat_openai_model = True
break
# Track prompt tokens for ChatOpenAI model.
if self.is_chat_openai_model:
invocation_params = kwargs.get("invocation_params")
if invocation_params:
model_name = invocation_params.get("model_name")
if model_name:
self.chat_openai_model_name = model_name
prompt_tokens = 0
for message_list in messages:
message_string = " ".join(
cast(str, msg.content) for msg in message_list
)
num_tokens = get_num_tokens(
message_string,
openai_model_name=self.chat_openai_model_name,
)
prompt_tokens += num_tokens
self._send_to_infino("prompt_tokens", prompt_tokens)
if self.verbose:
print( # noqa: T201
f"on_chat_model_start: is_chat_openai_model= \
{self.is_chat_openai_model}, \
chat_openai_model_name={self.chat_openai_model_name}"
)
# Send the prompt to infino
prompt = " ".join(
cast(str, msg.content) for sublist in messages for msg in sublist
)
self._send_to_infino("prompt", prompt, is_ts=False)
# Set the error flag to indicate no error (this will get overridden
# in on_llm_error if an error occurs).
self.error = 0
# Set the start time (so that we can calculate the request
# duration in on_llm_end).
self.start_time = time.time()

View File

@@ -0,0 +1,390 @@
import os
import warnings
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import UUID
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage, ChatMessage
from langchain_core.outputs import Generation, LLMResult
class LabelStudioMode(Enum):
"""Label Studio mode enumerator."""
PROMPT = "prompt"
CHAT = "chat"
def get_default_label_configs(
mode: Union[str, LabelStudioMode],
) -> Tuple[str, LabelStudioMode]:
"""Get default Label Studio configs for the given mode.
Parameters:
mode: Label Studio mode ("prompt" or "chat")
Returns: Tuple of Label Studio config and mode
"""
_default_label_configs = {
LabelStudioMode.PROMPT.value: """
<View>
<Style>
.prompt-box {
background-color: white;
border-radius: 10px;
box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
padding: 20px;
}
</Style>
<View className="root">
<View className="prompt-box">
<Text name="prompt" value="$prompt"/>
</View>
<TextArea name="response" toName="prompt"
maxSubmissions="1" editable="true"
required="true"/>
</View>
<Header value="Rate the response:"/>
<Rating name="rating" toName="prompt"/>
</View>""",
LabelStudioMode.CHAT.value: """
<View>
<View className="root">
<Paragraphs name="dialogue"
value="$prompt"
layout="dialogue"
textKey="content"
nameKey="role"
granularity="sentence"/>
<Header value="Final response:"/>
<TextArea name="response" toName="dialogue"
maxSubmissions="1" editable="true"
required="true"/>
</View>
<Header value="Rate the response:"/>
<Rating name="rating" toName="dialogue"/>
</View>""",
}
if isinstance(mode, str):
mode = LabelStudioMode(mode)
return _default_label_configs[mode.value], mode
class LabelStudioCallbackHandler(BaseCallbackHandler):
"""Label Studio callback handler.
Provides the ability to send predictions to Label Studio
for human evaluation, feedback and annotation.
Parameters:
api_key: Label Studio API key
url: Label Studio URL
project_id: Label Studio project ID
project_name: Label Studio project name
project_config: Label Studio project config (XML)
mode: Label Studio mode ("prompt" or "chat")
Examples:
>>> from langchain_community.llms import OpenAI
>>> from langchain_community.callbacks import LabelStudioCallbackHandler
>>> handler = LabelStudioCallbackHandler(
... api_key='<your_key_here>',
... url='http://localhost:8080',
... project_name='LangChain-%Y-%m-%d',
... mode='prompt'
... )
>>> llm = OpenAI(callbacks=[handler])
>>> llm.invoke('Tell me a story about a dog.')
"""
DEFAULT_PROJECT_NAME: str = "LangChain-%Y-%m-%d"
def __init__(
self,
api_key: Optional[str] = None,
url: Optional[str] = None,
project_id: Optional[int] = None,
project_name: str = DEFAULT_PROJECT_NAME,
project_config: Optional[str] = None,
mode: Union[str, LabelStudioMode] = LabelStudioMode.PROMPT,
):
super().__init__()
# Import LabelStudio SDK
try:
import label_studio_sdk as ls
except ImportError:
raise ImportError(
f"You're using {self.__class__.__name__} in your code,"
f" but you don't have the LabelStudio SDK "
f"Python package installed or upgraded to the latest version. "
f"Please run `pip install -U label-studio-sdk`"
f" before using this callback."
)
# Check if Label Studio API key is provided
if not api_key:
if os.getenv("LABEL_STUDIO_API_KEY"):
api_key = str(os.getenv("LABEL_STUDIO_API_KEY"))
else:
raise ValueError(
f"You're using {self.__class__.__name__} in your code,"
f" Label Studio API key is not provided. "
f"Please provide Label Studio API key: "
f"go to the Label Studio instance, navigate to "
f"Account & Settings -> Access Token and copy the key. "
f"Use the key as a parameter for the callback: "
f"{self.__class__.__name__}"
f"(label_studio_api_key='<your_key_here>', ...) or "
f"set the environment variable LABEL_STUDIO_API_KEY=<your_key_here>"
)
self.api_key = api_key
if not url:
if os.getenv("LABEL_STUDIO_URL"):
url = os.getenv("LABEL_STUDIO_URL")
else:
warnings.warn(
f"Label Studio URL is not provided, "
f"using default URL: {ls.LABEL_STUDIO_DEFAULT_URL}"
f"If you want to provide your own URL, use the parameter: "
f"{self.__class__.__name__}"
f"(label_studio_url='<your_url_here>', ...) "
f"or set the environment variable LABEL_STUDIO_URL=<your_url_here>"
)
url = ls.LABEL_STUDIO_DEFAULT_URL
self.url = url
# Maps run_id to prompts
self.payload: Dict[str, Dict] = {}
self.ls_client = ls.Client(url=self.url, api_key=self.api_key)
self.project_name = project_name
if project_config:
self.project_config = project_config
self.mode = None
else:
self.project_config, self.mode = get_default_label_configs(mode)
self.project_id = project_id or os.getenv("LABEL_STUDIO_PROJECT_ID")
if self.project_id is not None:
self.ls_project = self.ls_client.get_project(int(self.project_id))
else:
project_title = datetime.today().strftime(self.project_name)
existing_projects = self.ls_client.get_projects(title=project_title)
if existing_projects:
self.ls_project = existing_projects[0]
self.project_id = self.ls_project.id
else:
self.ls_project = self.ls_client.create_project(
title=project_title, label_config=self.project_config
)
self.project_id = self.ls_project.id
self.parsed_label_config = self.ls_project.parsed_label_config
# Find the first TextArea tag
# "from_name", "to_name", "value" will be used to create predictions
self.from_name, self.to_name, self.value, self.input_type = (
None,
None,
None,
None,
)
for tag_name, tag_info in self.parsed_label_config.items():
if tag_info["type"] == "TextArea":
self.from_name = tag_name
self.to_name = tag_info["to_name"][0]
self.value = tag_info["inputs"][0]["value"]
self.input_type = tag_info["inputs"][0]["type"]
break
if not self.from_name:
error_message = (
f'Label Studio project "{self.project_name}" '
f"does not have a TextArea tag. "
f"Please add a TextArea tag to the project."
)
if self.mode == LabelStudioMode.PROMPT:
error_message += (
"\nHINT: go to project Settings -> "
"Labeling Interface -> Browse Templates"
' and select "Generative AI -> '
'Supervised Language Model Fine-tuning" template.'
)
else:
error_message += (
"\nHINT: go to project Settings -> "
"Labeling Interface -> Browse Templates"
" and check available templates under "
'"Generative AI" section.'
)
raise ValueError(error_message)
def add_prompts_generations(
self, run_id: str, generations: List[List[Generation]]
) -> None:
# Create tasks in Label Studio
tasks = []
prompts = self.payload[run_id]["prompts"]
model_version = (
self.payload[run_id]["kwargs"]
.get("invocation_params", {})
.get("model_name")
)
for prompt, generation in zip(prompts, generations):
tasks.append(
{
"data": {
self.value: prompt,
"run_id": run_id,
},
"predictions": [
{
"result": [
{
"from_name": self.from_name,
"to_name": self.to_name,
"type": "textarea",
"value": {"text": [g.text for g in generation]},
}
],
"model_version": model_version,
}
],
}
)
self.ls_project.import_tasks(tasks)
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
**kwargs: Any,
) -> None:
"""Save the prompts in memory when an LLM starts."""
if self.input_type != "Text":
raise ValueError(
f'\nLabel Studio project "{self.project_name}" '
f"has an input type <{self.input_type}>. "
f'To make it work with the mode="chat", '
f"the input type should be <Text>.\n"
f"Read more here https://labelstud.io/tags/text"
)
run_id = str(kwargs["run_id"])
self.payload[run_id] = {"prompts": prompts, "kwargs": kwargs}
def _get_message_role(self, message: BaseMessage) -> str:
"""Get the role of the message."""
if isinstance(message, ChatMessage):
return message.role
else:
return message.__class__.__name__
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Save the prompts in memory when an LLM starts."""
if self.input_type != "Paragraphs":
raise ValueError(
f'\nLabel Studio project "{self.project_name}" '
f"has an input type <{self.input_type}>. "
f'To make it work with the mode="chat", '
f"the input type should be <Paragraphs>.\n"
f"Read more here https://labelstud.io/tags/paragraphs"
)
prompts = []
for message_list in messages:
dialog = []
for message in message_list:
dialog.append(
{
"role": self._get_message_role(message),
"content": message.content,
}
)
prompts.append(dialog)
self.payload[str(run_id)] = {
"prompts": prompts,
"tags": tags,
"metadata": metadata,
"run_id": run_id,
"parent_run_id": parent_run_id,
"kwargs": kwargs,
}
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing when a new token is generated."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Create a new Label Studio task for each prompt and generation."""
run_id = str(kwargs["run_id"])
# Submit results to Label Studio
self.add_prompts_generations(run_id, response.generations)
# Pop current run from `self.runs`
self.payload.pop(run_id)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM outputs an error."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM chain outputs an error."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing when tool starts."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing when agent takes a specific action."""
pass
def on_tool_end(
self,
output: str,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing when tool ends."""
pass
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing"""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Do nothing"""
pass

View File

@@ -0,0 +1,681 @@
import importlib.metadata
import logging
import os
import traceback
import warnings
from contextvars import ContextVar
from typing import Any, Dict, List, Union, cast
from uuid import UUID
import requests
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult
from packaging.version import parse
logger = logging.getLogger(__name__)
DEFAULT_API_URL = "https://app.llmonitor.com"
user_ctx = ContextVar[Union[str, None]]("user_ctx", default=None)
user_props_ctx = ContextVar[Union[str, None]]("user_props_ctx", default=None)
PARAMS_TO_CAPTURE = [
"temperature",
"top_p",
"top_k",
"stop",
"presence_penalty",
"frequence_penalty",
"seed",
"function_call",
"functions",
"tools",
"tool_choice",
"response_format",
"max_tokens",
"logit_bias",
]
class UserContextManager:
"""Context manager for LLMonitor user context."""
def __init__(self, user_id: str, user_props: Any = None) -> None:
user_ctx.set(user_id)
user_props_ctx.set(user_props)
def __enter__(self) -> Any:
pass
def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> Any:
user_ctx.set(None)
user_props_ctx.set(None)
def identify(user_id: str, user_props: Any = None) -> UserContextManager:
"""Builds an LLMonitor UserContextManager
Parameters:
- `user_id`: The user id.
- `user_props`: The user properties.
Returns:
A context manager that sets the user context.
"""
return UserContextManager(user_id, user_props)
def _serialize(obj: Any) -> Union[Dict[str, Any], List[Any], Any]:
if hasattr(obj, "to_json"):
return obj.to_json()
if isinstance(obj, dict):
return {key: _serialize(value) for key, value in obj.items()}
if isinstance(obj, list):
return [_serialize(element) for element in obj]
return obj
def _parse_input(raw_input: Any) -> Any:
if not raw_input:
return None
# if it's an array of 1, just parse the first element
if isinstance(raw_input, list) and len(raw_input) == 1:
return _parse_input(raw_input[0])
if not isinstance(raw_input, dict):
return _serialize(raw_input)
input_value = raw_input.get("input")
inputs_value = raw_input.get("inputs")
question_value = raw_input.get("question")
query_value = raw_input.get("query")
if input_value:
return input_value
if inputs_value:
return inputs_value
if question_value:
return question_value
if query_value:
return query_value
return _serialize(raw_input)
def _parse_output(raw_output: dict) -> Any:
if not raw_output:
return None
if not isinstance(raw_output, dict):
return _serialize(raw_output)
text_value = raw_output.get("text")
output_value = raw_output.get("output")
output_text_value = raw_output.get("output_text")
answer_value = raw_output.get("answer")
result_value = raw_output.get("result")
if text_value:
return text_value
if answer_value:
return answer_value
if output_value:
return output_value
if output_text_value:
return output_text_value
if result_value:
return result_value
return _serialize(raw_output)
def _parse_lc_role(
role: str,
) -> str:
if role == "human":
return "user"
else:
return role
def _get_user_id(metadata: Any) -> Any:
if user_ctx.get() is not None:
return user_ctx.get()
metadata = metadata or {}
user_id = metadata.get("user_id")
if user_id is None:
user_id = metadata.get("userId") # legacy, to delete in the future
return user_id
def _get_user_props(metadata: Any) -> Any:
if user_props_ctx.get() is not None:
return user_props_ctx.get()
metadata = metadata or {}
return metadata.get("user_props", None)
def _parse_lc_message(message: BaseMessage) -> Dict[str, Any]:
keys = ["function_call", "tool_calls", "tool_call_id", "name"]
parsed = {"text": message.content, "role": _parse_lc_role(message.type)}
parsed.update(
{
key: cast(Any, message.additional_kwargs.get(key))
for key in keys
if message.additional_kwargs.get(key) is not None
}
)
return parsed
def _parse_lc_messages(messages: Union[List[BaseMessage], Any]) -> List[Dict[str, Any]]:
return [_parse_lc_message(message) for message in messages]
class LLMonitorCallbackHandler(BaseCallbackHandler):
"""Callback Handler for LLMonitor`.
#### Parameters:
- `app_id`: The app id of the app you want to report to. Defaults to
`None`, which means that `LLMONITOR_APP_ID` will be used.
- `api_url`: The url of the LLMonitor API. Defaults to `None`,
which means that either `LLMONITOR_API_URL` environment variable
or `https://app.llmonitor.com` will be used.
#### Raises:
- `ValueError`: if `app_id` is not provided either as an
argument or as an environment variable.
- `ConnectionError`: if the connection to the API fails.
#### Example:
```python
from langchain_community.llms import OpenAI
from langchain_community.callbacks import LLMonitorCallbackHandler
llmonitor_callback = LLMonitorCallbackHandler()
llm = OpenAI(callbacks=[llmonitor_callback],
metadata={"userId": "user-123"})
llm.invoke("Hello, how are you?")
```
"""
__api_url: str
__app_id: str
__verbose: bool
__llmonitor_version: str
__has_valid_config: bool
def __init__(
self,
app_id: Union[str, None] = None,
api_url: Union[str, None] = None,
verbose: bool = False,
) -> None:
super().__init__()
self.__has_valid_config = True
try:
import llmonitor
self.__llmonitor_version = importlib.metadata.version("llmonitor")
self.__track_event = llmonitor.track_event
except ImportError:
logger.warning(
"""[LLMonitor] To use the LLMonitor callback handler you need to
have the `llmonitor` Python package installed. Please install it
with `pip install llmonitor`"""
)
self.__has_valid_config = False
return
if parse(self.__llmonitor_version) < parse("0.0.32"):
logger.warning(
f"""[LLMonitor] The installed `llmonitor` version is
{self.__llmonitor_version}
but `LLMonitorCallbackHandler` requires at least version 0.0.32
upgrade `llmonitor` with `pip install --upgrade llmonitor`"""
)
self.__has_valid_config = False
self.__has_valid_config = True
self.__api_url = api_url or os.getenv("LLMONITOR_API_URL") or DEFAULT_API_URL
self.__verbose = verbose or bool(os.getenv("LLMONITOR_VERBOSE"))
_app_id = app_id or os.getenv("LLMONITOR_APP_ID")
if _app_id is None:
logger.warning(
"""[LLMonitor] app_id must be provided either as an argument or
as an environment variable"""
)
self.__has_valid_config = False
else:
self.__app_id = _app_id
if self.__has_valid_config is False:
return None
try:
res = requests.get(f"{self.__api_url}/api/app/{self.__app_id}")
if not res.ok:
raise ConnectionError()
except Exception:
logger.warning(
f"""[LLMonitor] Could not connect to the LLMonitor API at
{self.__api_url}"""
)
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
tags: Union[List[str], None] = None,
metadata: Union[Dict[str, Any], None] = None,
**kwargs: Any,
) -> None:
if self.__has_valid_config is False:
return
try:
user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
params = kwargs.get("invocation_params", {})
params.update(
serialized.get("kwargs", {})
) # Sometimes, for example with ChatAnthropic, `invocation_params` is empty
name = (
params.get("model")
or params.get("model_name")
or params.get("model_id")
)
if not name and "anthropic" in params.get("_type"):
name = "claude-2"
extra = {
param: params.get(param)
for param in PARAMS_TO_CAPTURE
if params.get(param) is not None
}
input = _parse_input(prompts)
self.__track_event(
"llm",
"start",
user_id=user_id,
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input,
tags=tags,
extra=extra,
metadata=metadata,
user_props=user_props,
app_id=self.__app_id,
)
except Exception as e:
warnings.warn(f"[LLMonitor] An error occurred in on_llm_start: {e}")
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
tags: Union[List[str], None] = None,
metadata: Union[Dict[str, Any], None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
params = kwargs.get("invocation_params", {})
params.update(
serialized.get("kwargs", {})
) # Sometimes, for example with ChatAnthropic, `invocation_params` is empty
name = (
params.get("model")
or params.get("model_name")
or params.get("model_id")
)
if not name and "anthropic" in params.get("_type"):
name = "claude-2"
extra = {
param: params.get(param)
for param in PARAMS_TO_CAPTURE
if params.get(param) is not None
}
input = _parse_lc_messages(messages[0])
self.__track_event(
"llm",
"start",
user_id=user_id,
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input,
tags=tags,
extra=extra,
metadata=metadata,
user_props=user_props,
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_chat_model_start: {e}")
def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> None:
if self.__has_valid_config is False:
return
try:
token_usage = (response.llm_output or {}).get("token_usage", {})
parsed_output: Any = [
_parse_lc_message(generation.message)
if hasattr(generation, "message")
else generation.text
for generation in response.generations[0]
]
# if it's an array of 1, just parse the first element
if len(parsed_output) == 1:
parsed_output = parsed_output[0]
self.__track_event(
"llm",
"end",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
output=parsed_output,
token_usage={
"prompt": token_usage.get("prompt_tokens"),
"completion": token_usage.get("completion_tokens"),
},
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_llm_end: {e}")
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
tags: Union[List[str], None] = None,
metadata: Union[Dict[str, Any], None] = None,
**kwargs: Any,
) -> None:
if self.__has_valid_config is False:
return
try:
user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
name = serialized.get("name")
self.__track_event(
"tool",
"start",
user_id=user_id,
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input_str,
tags=tags,
metadata=metadata,
user_props=user_props,
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_tool_start: {e}")
def on_tool_end(
self,
output: Any,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
tags: Union[List[str], None] = None,
**kwargs: Any,
) -> None:
output = str(output)
if self.__has_valid_config is False:
return
try:
self.__track_event(
"tool",
"end",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
output=output,
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_tool_end: {e}")
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
tags: Union[List[str], None] = None,
metadata: Union[Dict[str, Any], None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
name = serialized.get("id", [None, None, None, None])[3]
type = "chain"
metadata = metadata or {}
agentName = metadata.get("agent_name")
if agentName is None:
agentName = metadata.get("agentName")
if name == "AgentExecutor" or name == "PlanAndExecute":
type = "agent"
if agentName is not None:
type = "agent"
name = agentName
if parent_run_id is not None:
type = "chain"
user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
input = _parse_input(inputs)
self.__track_event(
type,
"start",
user_id=user_id,
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input,
tags=tags,
metadata=metadata,
user_props=user_props,
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_chain_start: {e}")
def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
output = _parse_output(outputs)
self.__track_event(
"chain",
"end",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
output=output,
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_chain_end: {e}")
def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
name = action.tool
input = _parse_input(action.tool_input)
self.__track_event(
"tool",
"start",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input,
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_agent_action: {e}")
def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
output = _parse_output(finish.return_values)
self.__track_event(
"agent",
"end",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
output=output,
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_agent_finish: {e}")
def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
self.__track_event(
"chain",
"error",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
error={"message": str(error), "stack": traceback.format_exc()},
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_chain_error: {e}")
def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
self.__track_event(
"tool",
"error",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
error={"message": str(error), "stack": traceback.format_exc()},
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_tool_error: {e}")
def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
self.__track_event(
"llm",
"error",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
error={"message": str(error), "stack": traceback.format_exc()},
app_id=self.__app_id,
)
except Exception as e:
logger.error(f"[LLMonitor] An error occurred in on_llm_error: {e}")
__all__ = ["LLMonitorCallbackHandler", "identify"]

View File

@@ -0,0 +1,104 @@
from __future__ import annotations
import logging
from contextlib import contextmanager
from contextvars import ContextVar
from typing import (
Generator,
Optional,
)
from langchain_core.tracers.context import register_configure_hook
from langchain_community.callbacks.bedrock_anthropic_callback import (
BedrockAnthropicTokenUsageCallbackHandler,
)
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from langchain_community.callbacks.tracers.comet import CometTracer
from langchain_community.callbacks.tracers.wandb import WandbTracer
logger = logging.getLogger(__name__)
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
"openai_callback", default=None
)
bedrock_anthropic_callback_var: (ContextVar)[
Optional[BedrockAnthropicTokenUsageCallbackHandler]
] = ContextVar("bedrock_anthropic_callback", default=None)
wandb_tracing_callback_var: ContextVar[Optional[WandbTracer]] = ContextVar(
"tracing_wandb_callback", default=None
)
comet_tracing_callback_var: ContextVar[Optional[CometTracer]] = ContextVar(
"tracing_comet_callback", default=None
)
register_configure_hook(openai_callback_var, True)
register_configure_hook(bedrock_anthropic_callback_var, True)
register_configure_hook(
wandb_tracing_callback_var, True, WandbTracer, "LANGCHAIN_WANDB_TRACING"
)
register_configure_hook(
comet_tracing_callback_var, True, CometTracer, "LANGCHAIN_COMET_TRACING"
)
@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
"""Get the OpenAI callback handler in a context manager.
which conveniently exposes token and cost information.
Returns:
OpenAICallbackHandler: The OpenAI callback handler.
Example:
>>> with get_openai_callback() as cb:
... # Use the OpenAI callback handler
"""
cb = OpenAICallbackHandler()
openai_callback_var.set(cb)
yield cb
openai_callback_var.set(None)
@contextmanager
def get_bedrock_anthropic_callback() -> Generator[
BedrockAnthropicTokenUsageCallbackHandler, None, None
]:
"""Get the Bedrock anthropic callback handler in a context manager.
which conveniently exposes token and cost information.
Returns:
BedrockAnthropicTokenUsageCallbackHandler:
The Bedrock anthropic callback handler.
Example:
>>> with get_bedrock_anthropic_callback() as cb:
... # Use the Bedrock anthropic callback handler
"""
cb = BedrockAnthropicTokenUsageCallbackHandler()
bedrock_anthropic_callback_var.set(cb)
yield cb
bedrock_anthropic_callback_var.set(None)
@contextmanager
def wandb_tracing_enabled(
session_name: str = "default",
) -> Generator[None, None, None]:
"""Get the WandbTracer in a context manager.
Args:
session_name (str, optional): The name of the session.
Defaults to "default".
Returns:
None
Example:
>>> with wandb_tracing_enabled() as session:
... # Use the WandbTracer session
"""
cb = WandbTracer()
wandb_tracing_callback_var.set(cb)
yield None
wandb_tracing_callback_var.set(None)

View File

@@ -0,0 +1,769 @@
import logging
import os
import random
import string
import tempfile
import traceback
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.documents import Document
from langchain_core.outputs import LLMResult
from langchain_core.utils import get_from_dict_or_env, guard_import
from langchain_community.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
hash_string,
import_pandas,
import_spacy,
import_textstat,
)
logger = logging.getLogger(__name__)
def import_mlflow() -> Any:
"""Import the mlflow python package and raise an error if it is not installed."""
return guard_import("mlflow")
def mlflow_callback_metrics() -> List[str]:
"""Get the metrics to log to MLFlow."""
return [
"step",
"starts",
"ends",
"errors",
"text_ctr",
"chain_starts",
"chain_ends",
"llm_starts",
"llm_ends",
"llm_streams",
"tool_starts",
"tool_ends",
"agent_ends",
"retriever_starts",
"retriever_ends",
]
def get_text_complexity_metrics() -> List[str]:
"""Get the text complexity metrics from textstat."""
return [
"flesch_reading_ease",
"flesch_kincaid_grade",
"smog_index",
"coleman_liau_index",
"automated_readability_index",
"dale_chall_readability_score",
"difficult_words",
"linsear_write_formula",
"gunning_fog",
# "text_standard"
"fernandez_huerta",
"szigriszt_pazos",
"gutierrez_polini",
"crawford",
"gulpease_index",
"osman",
]
def analyze_text(
text: str,
nlp: Any = None,
textstat: Any = None,
) -> dict:
"""Analyze text using textstat and spacy.
Parameters:
text (str): The text to analyze.
nlp (spacy.lang): The spacy language model to use for visualization.
textstat: The textstat library to use for complexity metrics calculation.
Returns:
`dict` containing the complexity metrics and visualization
files serialized to HTML string.
"""
resp: Dict[str, Any] = {}
if textstat is not None:
text_complexity_metrics = {
key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
}
resp.update({"text_complexity_metrics": text_complexity_metrics})
resp.update(text_complexity_metrics)
if nlp is not None:
spacy = import_spacy()
doc = nlp(text)
dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
text_visualizations = {
"dependency_tree": dep_out,
"entities": ent_out,
}
resp.update(text_visualizations)
return resp
def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
"""Construct an html element from a prompt and a generation.
Parameters:
prompt (str): The prompt.
generation (str): The generation.
Returns:
(str): The html string."""
formatted_prompt = prompt.replace("\n", "<br>")
formatted_generation = generation.replace("\n", "<br>")
return f"""
<p style="color:black;">{formatted_prompt}:</p>
<blockquote>
<p style="color:green;">
{formatted_generation}
</p>
</blockquote>
"""
class MlflowLogger:
"""Callback Handler that logs metrics and artifacts to mlflow server.
Parameters:
name (str): Name of the run.
experiment (str): Name of the experiment.
tags (dict): Tags to be attached for the run.
tracking_uri (str): MLflow tracking server uri.
This handler implements the helper functions to initialize,
log metrics and artifacts to the mlflow server.
"""
def __init__(self, **kwargs: Any):
self.mlflow = import_mlflow()
if "DATABRICKS_RUNTIME_VERSION" in os.environ:
self.mlflow.set_tracking_uri("databricks")
self.mlf_expid = self.mlflow.tracking.fluent._get_experiment_id()
self.mlf_exp = self.mlflow.get_experiment(self.mlf_expid)
else:
tracking_uri = get_from_dict_or_env(
kwargs, "tracking_uri", "MLFLOW_TRACKING_URI", ""
)
self.mlflow.set_tracking_uri(tracking_uri)
if run_id := kwargs.get("run_id"):
self.mlf_expid = self.mlflow.get_run(run_id).info.experiment_id
else:
# User can set other env variables described here
# > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
experiment_name = get_from_dict_or_env(
kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
)
self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
if self.mlf_exp is not None:
self.mlf_expid = self.mlf_exp.experiment_id
else:
self.mlf_expid = self.mlflow.create_experiment(experiment_name)
self.start_run(
kwargs["run_name"], kwargs["run_tags"], kwargs.get("run_id", None)
)
self.dir = kwargs.get("artifacts_dir", "")
def start_run(
self, name: str, tags: Dict[str, str], run_id: Optional[str] = None
) -> None:
"""
If run_id is provided, it will reuse the run with the given run_id.
Otherwise, it starts a new run, auto generates the random suffix for name.
"""
if run_id is None:
if name.endswith("-%"):
rname = "".join(
random.choices(string.ascii_uppercase + string.digits, k=7)
)
name = name[:-1] + rname
run = self.mlflow.MlflowClient().create_run(
self.mlf_expid, run_name=name, tags=tags
)
run_id = run.info.run_id
self.run_id = run_id
def finish_run(self) -> None:
"""To finish the run."""
self.mlflow.end_run()
def metric(self, key: str, value: float) -> None:
"""To log metric to mlflow server."""
self.mlflow.log_metric(key, value, run_id=self.run_id)
def metrics(
self, data: Union[Dict[str, float], Dict[str, int]], step: Optional[int] = 0
) -> None:
"""To log all metrics in the input dict."""
self.mlflow.log_metrics(data, run_id=self.run_id)
def jsonf(self, data: Dict[str, Any], filename: str) -> None:
"""To log the input data as json file artifact."""
self.mlflow.log_dict(
data, os.path.join(self.dir, f"{filename}.json"), run_id=self.run_id
)
def table(self, name: str, dataframe: Any) -> None:
"""To log the input pandas dataframe as a html table"""
self.html(dataframe.to_html(), f"table_{name}")
def html(self, html: str, filename: str) -> None:
"""To log the input html string as html file artifact."""
self.mlflow.log_text(
html, os.path.join(self.dir, f"{filename}.html"), run_id=self.run_id
)
def text(self, text: str, filename: str) -> None:
"""To log the input text as text file artifact."""
self.mlflow.log_text(
text, os.path.join(self.dir, f"{filename}.txt"), run_id=self.run_id
)
def artifact(self, path: str) -> None:
"""To upload the file from given path as artifact."""
self.mlflow.log_artifact(path, run_id=self.run_id)
def langchain_artifact(self, chain: Any) -> None:
self.mlflow.langchain.log_model(chain, "langchain-model", run_id=self.run_id)
class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""Callback Handler that logs metrics and artifacts to mlflow server.
Parameters:
name (str): Name of the run.
experiment (str): Name of the experiment.
tags (dict): Tags to be attached for the run.
tracking_uri (str): MLflow tracking server uri.
This handler will utilize the associated callback method called and formats
the input of each callback function with metadata regarding the state of LLM run,
and adds the response to the list of records for both the {method}_records and
action. It then logs the response to mlflow server.
"""
def __init__(
self,
name: Optional[str] = "langchainrun-%",
experiment: Optional[str] = "langchain",
tags: Optional[Dict] = None,
tracking_uri: Optional[str] = None,
run_id: Optional[str] = None,
artifacts_dir: str = "",
) -> None:
"""Initialize callback handler."""
import_pandas()
import_mlflow()
super().__init__()
self.name = name
self.experiment = experiment
self.tags = tags or {}
self.tracking_uri = tracking_uri
self.run_id = run_id
self.artifacts_dir = artifacts_dir
self.temp_dir = tempfile.TemporaryDirectory()
self.mlflg = MlflowLogger(
tracking_uri=self.tracking_uri,
experiment_name=self.experiment,
run_name=self.name,
run_tags=self.tags,
run_id=self.run_id,
artifacts_dir=self.artifacts_dir,
)
self.action_records: list = []
self.nlp = None
try:
spacy = import_spacy()
except ImportError as e:
logger.warning(e.msg)
else:
try:
self.nlp = spacy.load("en_core_web_sm")
except OSError:
logger.warning(
"Run `python -m spacy download en_core_web_sm` "
"to download en_core_web_sm model for text visualization."
)
try:
self.textstat = import_textstat()
except ImportError as e:
logger.warning(e.msg)
self.textstat = None
self.metrics = {key: 0 for key in mlflow_callback_metrics()}
self.records: Dict[str, Any] = {
"on_llm_start_records": [],
"on_llm_token_records": [],
"on_llm_end_records": [],
"on_chain_start_records": [],
"on_chain_end_records": [],
"on_tool_start_records": [],
"on_tool_end_records": [],
"on_text_records": [],
"on_agent_finish_records": [],
"on_agent_action_records": [],
"on_retriever_start_records": [],
"on_retriever_end_records": [],
"action_records": [],
}
def _reset(self) -> None:
for k, v in self.metrics.items():
self.metrics[k] = 0
for k, v in self.records.items():
self.records[k] = []
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
self.metrics["step"] += 1
self.metrics["llm_starts"] += 1
self.metrics["starts"] += 1
llm_starts = self.metrics["llm_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_start"})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
for idx, prompt in enumerate(prompts):
prompt_resp = deepcopy(resp)
prompt_resp["prompt"] = prompt
self.records["on_llm_start_records"].append(prompt_resp)
self.records["action_records"].append(prompt_resp)
self.mlflg.jsonf(prompt_resp, f"llm_start_{llm_starts}_prompt_{idx}")
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.metrics["step"] += 1
self.metrics["llm_streams"] += 1
llm_streams = self.metrics["llm_streams"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_new_token", "token": token})
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_llm_token_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"llm_new_tokens_{llm_streams}")
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.metrics["step"] += 1
self.metrics["llm_ends"] += 1
self.metrics["ends"] += 1
llm_ends = self.metrics["llm_ends"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_end"})
resp.update(flatten_dict(response.llm_output or {}))
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
for generations in response.generations:
for idx, generation in enumerate(generations):
generation_resp = deepcopy(resp)
generation_resp.update(flatten_dict(generation.dict()))
generation_resp.update(
analyze_text(
generation.text,
nlp=self.nlp,
textstat=self.textstat,
)
)
if "text_complexity_metrics" in generation_resp:
complexity_metrics: Dict[str, float] = generation_resp.pop(
"text_complexity_metrics"
)
self.mlflg.metrics(
complexity_metrics,
step=self.metrics["step"],
)
self.records["on_llm_end_records"].append(generation_resp)
self.records["action_records"].append(generation_resp)
self.mlflg.jsonf(resp, f"llm_end_{llm_ends}_generation_{idx}")
if "dependency_tree" in generation_resp:
dependency_tree = generation_resp["dependency_tree"]
self.mlflg.html(
dependency_tree, "dep-" + hash_string(generation.text)
)
if "entities" in generation_resp:
entities = generation_resp["entities"]
self.mlflg.html(entities, "ent-" + hash_string(generation.text))
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.metrics["step"] += 1
self.metrics["chain_starts"] += 1
self.metrics["starts"] += 1
chain_starts = self.metrics["chain_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
if isinstance(inputs, dict):
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
elif isinstance(inputs, list):
chain_input = ",".join([str(input) for input in inputs])
else:
chain_input = str(inputs)
input_resp = deepcopy(resp)
input_resp["inputs"] = chain_input
self.records["on_chain_start_records"].append(input_resp)
self.records["action_records"].append(input_resp)
self.mlflg.jsonf(input_resp, f"chain_start_{chain_starts}")
def on_chain_end(
self, outputs: Union[Dict[str, Any], str, List[str]], **kwargs: Any
) -> None:
"""Run when chain ends running."""
self.metrics["step"] += 1
self.metrics["chain_ends"] += 1
self.metrics["ends"] += 1
chain_ends = self.metrics["chain_ends"]
resp: Dict[str, Any] = {}
if isinstance(outputs, dict):
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
elif isinstance(outputs, list):
chain_output = ",".join(map(str, outputs))
else:
chain_output = str(outputs)
resp.update({"action": "on_chain_end", "outputs": chain_output})
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_chain_end_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"chain_end_{chain_ends}")
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.metrics["step"] += 1
self.metrics["tool_starts"] += 1
self.metrics["starts"] += 1
tool_starts = self.metrics["tool_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_start", "input_str": input_str})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_tool_start_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"tool_start_{tool_starts}")
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
output = str(output)
self.metrics["step"] += 1
self.metrics["tool_ends"] += 1
self.metrics["ends"] += 1
tool_ends = self.metrics["tool_ends"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_end", "output": output})
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_tool_end_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"tool_end_{tool_ends}")
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when text is received.
"""
self.metrics["step"] += 1
self.metrics["text_ctr"] += 1
text_ctr = self.metrics["text_ctr"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_text", "text": text})
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_text_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"on_text_{text_ctr}")
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.metrics["step"] += 1
self.metrics["agent_ends"] += 1
self.metrics["ends"] += 1
agent_ends = self.metrics["agent_ends"]
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_finish",
"output": finish.return_values["output"],
"log": finish.log,
}
)
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_agent_finish_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"agent_finish_{agent_ends}")
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.metrics["step"] += 1
self.metrics["tool_starts"] += 1
self.metrics["starts"] += 1
tool_starts = self.metrics["tool_starts"]
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_action",
"tool": action.tool,
"tool_input": action.tool_input,
"log": action.log,
}
)
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_agent_action_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"agent_action_{tool_starts}")
def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
**kwargs: Any,
) -> Any:
"""Run when Retriever starts running."""
self.metrics["step"] += 1
self.metrics["retriever_starts"] += 1
self.metrics["starts"] += 1
retriever_starts = self.metrics["retriever_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_retriever_start", "query": query})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_retriever_start_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"retriever_start_{retriever_starts}")
def on_retriever_end(
self,
documents: Sequence[Document],
**kwargs: Any,
) -> Any:
"""Run when Retriever ends running."""
self.metrics["step"] += 1
self.metrics["retriever_ends"] += 1
self.metrics["ends"] += 1
retriever_ends = self.metrics["retriever_ends"]
resp: Dict[str, Any] = {}
retriever_documents = [
{
"page_content": doc.page_content,
"metadata": {
k: (
str(v)
if not isinstance(v, list)
else ",".join(str(x) for x in v)
)
for k, v in doc.metadata.items()
},
}
for doc in documents
]
resp.update({"action": "on_retriever_end", "documents": retriever_documents})
resp.update(self.metrics)
self.mlflg.metrics(self.metrics, step=self.metrics["step"])
self.records["on_retriever_end_records"].append(resp)
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"retriever_end_{retriever_ends}")
def on_retriever_error(self, error: BaseException, **kwargs: Any) -> Any:
"""Run when Retriever errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def _create_session_analysis_df(self) -> Any:
"""Create a dataframe with all the information from the session."""
pd = import_pandas()
on_llm_start_records_df = pd.DataFrame(self.records["on_llm_start_records"])
on_llm_end_records_df = pd.DataFrame(self.records["on_llm_end_records"])
llm_input_columns = ["step", "prompt"]
if "name" in on_llm_start_records_df.columns:
llm_input_columns.append("name")
elif "id" in on_llm_start_records_df.columns:
# id is llm class's full import path. For example:
# ["langchain", "llms", "openai", "AzureOpenAI"]
on_llm_start_records_df["name"] = on_llm_start_records_df["id"].apply(
lambda id_: id_[-1]
)
llm_input_columns.append("name")
llm_input_prompts_df = (
on_llm_start_records_df[llm_input_columns]
.dropna(axis=1)
.rename({"step": "prompt_step"}, axis=1)
)
complexity_metrics_columns = (
get_text_complexity_metrics() if self.textstat is not None else []
)
visualizations_columns = (
["dependency_tree", "entities"] if self.nlp is not None else []
)
token_usage_columns = [
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
]
token_usage_columns = [
x for x in token_usage_columns if x in on_llm_end_records_df.columns
]
llm_outputs_df = (
on_llm_end_records_df[
[
"step",
"text",
]
+ token_usage_columns
+ complexity_metrics_columns
+ visualizations_columns
]
.dropna(axis=1)
.rename({"step": "output_step", "text": "output"}, axis=1)
)
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
session_analysis_df["chat_html"] = session_analysis_df[
["prompt", "output"]
].apply(
lambda row: construct_html_from_prompt_and_generation(
row["prompt"], row["output"]
),
axis=1,
)
return session_analysis_df
def _contain_llm_records(self) -> bool:
return bool(self.records["on_llm_start_records"])
def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
pd = import_pandas()
self.mlflg.table("action_records", pd.DataFrame(self.records["action_records"]))
if self._contain_llm_records():
session_analysis_df = self._create_session_analysis_df()
chat_html = session_analysis_df.pop("chat_html")
chat_html = chat_html.replace("\n", "", regex=True)
self.mlflg.table("session_analysis", pd.DataFrame(session_analysis_df))
self.mlflg.html("".join(chat_html.tolist()), "chat_html")
if langchain_asset:
# To avoid circular import error
# mlflow only supports LLMChain asset
if "langchain.chains.llm.LLMChain" in str(type(langchain_asset)):
self.mlflg.langchain_artifact(langchain_asset)
else:
langchain_asset_path = str(Path(self.temp_dir.name, "model.json"))
try:
langchain_asset.save(langchain_asset_path)
self.mlflg.artifact(langchain_asset_path)
except ValueError:
try:
langchain_asset.save_agent(langchain_asset_path)
self.mlflg.artifact(langchain_asset_path)
except AttributeError:
print("Could not save model.") # noqa: T201
traceback.print_exc()
pass
except NotImplementedError:
print("Could not save model.") # noqa: T201
traceback.print_exc()
pass
except NotImplementedError:
print("Could not save model.") # noqa: T201
traceback.print_exc()
pass
if finish:
self.mlflg.finish_run()
self._reset()

View File

@@ -0,0 +1,555 @@
"""Callback Handler that prints to std out."""
import threading
from enum import Enum, auto
from typing import Any, Dict, List
from langchain_core._api import warn_deprecated
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult
MODEL_COST_PER_1K_TOKENS = {
# GPT-5 input
"gpt-5": 0.00125,
"gpt-5-cached": 0.000125,
"gpt-5-2025-08-07": 0.00125,
"gpt-5-2025-08-07-cached": 0.000125,
# GPT-5 output
"gpt-5-completion": 0.01,
"gpt-5-2025-08-07-completion": 0.01,
# GPT-5-mini input
"gpt-5-mini": 0.00025,
"gpt-5-mini-cached": 0.000025,
"gpt-5-mini-2025-08-07": 0.00025,
"gpt-5-mini-2025-08-07-cached": 0.000025,
# GPT-5-mini output
"gpt-5-mini-completion": 0.002,
"gpt-5-mini-2025-08-07-completion": 0.002,
# GPT-5-nano input
"gpt-5-nano": 0.00005,
"gpt-5-nano-cached": 0.000005,
"gpt-5-nano-2025-08-07": 0.00005,
"gpt-5-nano-2025-08-07-cached": 0.000005,
# GPT-5-nano output
"gpt-5-nano-completion": 0.0004,
"gpt-5-nano-2025-08-07-completion": 0.0004,
# GPT-5-chat-latest input
"gpt-5-chat-latest": 0.00125,
"gpt-5-chat-latest-cached": 0.000125,
"gpt-5-chat-latest-2025-08-07": 0.00125,
"gpt-5-chat-latest-2025-08-07-cached": 0.000125,
# GPT-5-chat-latest output
"gpt-5-chat-latest-completion": 0.01,
"gpt-5-chat-latest-2025-08-07-completion": 0.01,
# GPT-4.1 input
"gpt-4.1": 0.002,
"gpt-4.1-2025-04-14": 0.002,
"gpt-4.1-cached": 0.0005,
"gpt-4.1-2025-04-14-cached": 0.0005,
# GPT-4.1 output
"gpt-4.1-completion": 0.008,
"gpt-4.1-2025-04-14-completion": 0.008,
# GPT-4.1-mini input
"gpt-4.1-mini": 0.0004,
"gpt-4.1-mini-2025-04-14": 0.0004,
"gpt-4.1-mini-cached": 0.0001,
"gpt-4.1-mini-2025-04-14-cached": 0.0001,
# GPT-4.1-mini output
"gpt-4.1-mini-completion": 0.0016,
"gpt-4.1-mini-2025-04-14-completion": 0.0016,
# GPT-4.1-nano input
"gpt-4.1-nano": 0.0001,
"gpt-4.1-nano-2025-04-14": 0.0001,
"gpt-4.1-nano-cached": 0.000025,
"gpt-4.1-nano-2025-04-14-cached": 0.000025,
# GPT-4.1-nano output
"gpt-4.1-nano-completion": 0.0004,
"gpt-4.1-nano-2025-04-14-completion": 0.0004,
# GPT-4.5-preview input
"gpt-4.5-preview": 0.075,
"gpt-4.5-preview-2025-02-27": 0.075,
"gpt-4.5-preview-cached": 0.0375,
"gpt-4.5-preview-2025-02-27-cached": 0.0375,
# GPT-4.5-preview output
"gpt-4.5-preview-completion": 0.15,
"gpt-4.5-preview-2025-02-27-completion": 0.15,
# OpenAI o1 input
"o1": 0.015,
"o1-2024-12-17": 0.015,
"o1-cached": 0.0075,
"o1-2024-12-17-cached": 0.0075,
# OpenAI o1 output
"o1-completion": 0.06,
"o1-2024-12-17-completion": 0.06,
# OpenAI o1-pro input
"o1-pro": 0.15,
"o1-pro-2025-03-19": 0.15,
# OpenAI o1-pro output
"o1-pro-completion": 0.6,
"o1-pro-2025-03-19-completion": 0.6,
# OpenAI o3 input
"o3": 0.002,
"o3-2025-04-16": 0.002,
"o3-cached": 0.0005,
"o3-2025-04-16-cached": 0.0005,
# OpenAI o3 output
"o3-completion": 0.008,
"o3-2025-04-16-completion": 0.008,
# OpenAI o4-mini input
"o4-mini": 0.0011,
"o4-mini-2025-04-16": 0.0011,
"o4-mini-cached": 0.000275,
"o4-mini-2025-04-16-cached": 0.000275,
# OpenAI o4-mini output
"o4-mini-completion": 0.0044,
"o4-mini-2025-04-16-completion": 0.0044,
# OpenAI o3-mini input
"o3-mini": 0.0011,
"o3-mini-2025-01-31": 0.0011,
"o3-mini-cached": 0.00055,
"o3-mini-2025-01-31-cached": 0.00055,
# OpenAI o3-mini output
"o3-mini-completion": 0.0044,
"o3-mini-2025-01-31-completion": 0.0044,
# OpenAI o1-mini input (updated pricing)
"o1-mini": 0.0011,
"o1-mini-cached": 0.00055,
"o1-mini-2024-09-12": 0.0011,
"o1-mini-2024-09-12-cached": 0.00055,
# OpenAI o1-mini output (updated pricing)
"o1-mini-completion": 0.0044,
"o1-mini-2024-09-12-completion": 0.0044,
# OpenAI o1-preview input
"o1-preview": 0.015,
"o1-preview-cached": 0.0075,
"o1-preview-2024-09-12": 0.015,
"o1-preview-2024-09-12-cached": 0.0075,
# OpenAI o1-preview output
"o1-preview-completion": 0.06,
"o1-preview-2024-09-12-completion": 0.06,
# GPT-4o input
"gpt-4o": 0.0025,
"gpt-4o-cached": 0.00125,
"gpt-4o-2024-05-13": 0.005,
"gpt-4o-2024-08-06": 0.0025,
"gpt-4o-2024-08-06-cached": 0.00125,
"gpt-4o-2024-11-20": 0.0025,
"gpt-4o-2024-11-20-cached": 0.00125,
# GPT-4o output
"gpt-4o-completion": 0.01,
"gpt-4o-2024-05-13-completion": 0.015,
"gpt-4o-2024-08-06-completion": 0.01,
"gpt-4o-2024-11-20-completion": 0.01,
# GPT-4o-audio-preview input
"gpt-4o-audio-preview": 0.0025,
"gpt-4o-audio-preview-2024-12-17": 0.0025,
"gpt-4o-audio-preview-2024-10-01": 0.0025,
# GPT-4o-audio-preview output
"gpt-4o-audio-preview-completion": 0.01,
"gpt-4o-audio-preview-2024-12-17-completion": 0.01,
"gpt-4o-audio-preview-2024-10-01-completion": 0.01,
# GPT-4o-realtime-preview input
"gpt-4o-realtime-preview": 0.005,
"gpt-4o-realtime-preview-2024-12-17": 0.005,
"gpt-4o-realtime-preview-2024-10-01": 0.005,
"gpt-4o-realtime-preview-cached": 0.0025,
"gpt-4o-realtime-preview-2024-12-17-cached": 0.0025,
"gpt-4o-realtime-preview-2024-10-01-cached": 0.0025,
# GPT-4o-realtime-preview output
"gpt-4o-realtime-preview-completion": 0.02,
"gpt-4o-realtime-preview-2024-12-17-completion": 0.02,
"gpt-4o-realtime-preview-2024-10-01-completion": 0.02,
# GPT-4o-mini input
"gpt-4o-mini": 0.00015,
"gpt-4o-mini-cached": 0.000075,
"gpt-4o-mini-2024-07-18": 0.00015,
"gpt-4o-mini-2024-07-18-cached": 0.000075,
# GPT-4o-mini output
"gpt-4o-mini-completion": 0.0006,
"gpt-4o-mini-2024-07-18-completion": 0.0006,
# GPT-4o-mini-audio-preview input
"gpt-4o-mini-audio-preview": 0.00015,
"gpt-4o-mini-audio-preview-2024-12-17": 0.00015,
# GPT-4o-mini-audio-preview output
"gpt-4o-mini-audio-preview-completion": 0.0006,
"gpt-4o-mini-audio-preview-2024-12-17-completion": 0.0006,
# GPT-4o-mini-realtime-preview input
"gpt-4o-mini-realtime-preview": 0.0006,
"gpt-4o-mini-realtime-preview-2024-12-17": 0.0006,
"gpt-4o-mini-realtime-preview-cached": 0.0003,
"gpt-4o-mini-realtime-preview-2024-12-17-cached": 0.0003,
# GPT-4o-mini-realtime-preview output
"gpt-4o-mini-realtime-preview-completion": 0.0024,
"gpt-4o-mini-realtime-preview-2024-12-17-completion": 0.0024,
# GPT-4o-mini-search-preview input
"gpt-4o-mini-search-preview": 0.00015,
"gpt-4o-mini-search-preview-2025-03-11": 0.00015,
# GPT-4o-mini-search-preview output
"gpt-4o-mini-search-preview-completion": 0.0006,
"gpt-4o-mini-search-preview-2025-03-11-completion": 0.0006,
# GPT-4o-search-preview input
"gpt-4o-search-preview": 0.0025,
"gpt-4o-search-preview-2025-03-11": 0.0025,
# GPT-4o-search-preview output
"gpt-4o-search-preview-completion": 0.01,
"gpt-4o-search-preview-2025-03-11-completion": 0.01,
# Computer-use-preview input
"computer-use-preview": 0.003,
"computer-use-preview-2025-03-11": 0.003,
# Computer-use-preview output
"computer-use-preview-completion": 0.012,
"computer-use-preview-2025-03-11-completion": 0.012,
# GPT-4 input
"gpt-4": 0.03,
"gpt-4-0314": 0.03,
"gpt-4-0613": 0.03,
"gpt-4-32k": 0.06,
"gpt-4-32k-0314": 0.06,
"gpt-4-32k-0613": 0.06,
"gpt-4-vision-preview": 0.01,
"gpt-4-1106-preview": 0.01,
"gpt-4-0125-preview": 0.01,
"gpt-4-turbo-preview": 0.01,
"gpt-4-turbo": 0.01,
"gpt-4-turbo-2024-04-09": 0.01,
# GPT-4 output
"gpt-4-completion": 0.06,
"gpt-4-0314-completion": 0.06,
"gpt-4-0613-completion": 0.06,
"gpt-4-32k-completion": 0.12,
"gpt-4-32k-0314-completion": 0.12,
"gpt-4-32k-0613-completion": 0.12,
"gpt-4-vision-preview-completion": 0.03,
"gpt-4-1106-preview-completion": 0.03,
"gpt-4-0125-preview-completion": 0.03,
"gpt-4-turbo-preview-completion": 0.03,
"gpt-4-turbo-completion": 0.03,
"gpt-4-turbo-2024-04-09-completion": 0.03,
# GPT-3.5 input
# gpt-3.5-turbo points at gpt-3.5-turbo-0613 until Feb 16, 2024.
# Switches to gpt-3.5-turbo-0125 after.
"gpt-3.5-turbo": 0.0015,
"gpt-3.5-turbo-0125": 0.0005,
"gpt-3.5-turbo-0301": 0.0015,
"gpt-3.5-turbo-0613": 0.0015,
"gpt-3.5-turbo-1106": 0.001,
"gpt-3.5-turbo-instruct": 0.0015,
"gpt-3.5-turbo-16k": 0.003,
"gpt-3.5-turbo-16k-0613": 0.003,
# GPT-3.5 output
# gpt-3.5-turbo points at gpt-3.5-turbo-0613 until Feb 16, 2024.
# Switches to gpt-3.5-turbo-0125 after.
"gpt-3.5-turbo-completion": 0.002,
"gpt-3.5-turbo-0125-completion": 0.0015,
"gpt-3.5-turbo-0301-completion": 0.002,
"gpt-3.5-turbo-0613-completion": 0.002,
"gpt-3.5-turbo-1106-completion": 0.002,
"gpt-3.5-turbo-instruct-completion": 0.002,
"gpt-3.5-turbo-16k-completion": 0.004,
"gpt-3.5-turbo-16k-0613-completion": 0.004,
# Azure GPT-35 input
"gpt-35-turbo": 0.0015, # Azure OpenAI version of ChatGPT
"gpt-35-turbo-0125": 0.0005,
"gpt-35-turbo-0301": 0.002, # Azure OpenAI version of ChatGPT
"gpt-35-turbo-0613": 0.0015,
"gpt-35-turbo-instruct": 0.0015,
"gpt-35-turbo-16k": 0.003,
"gpt-35-turbo-16k-0613": 0.003,
# Azure GPT-35 output
"gpt-35-turbo-completion": 0.002, # Azure OpenAI version of ChatGPT
"gpt-35-turbo-0125-completion": 0.0015,
"gpt-35-turbo-0301-completion": 0.002, # Azure OpenAI version of ChatGPT
"gpt-35-turbo-0613-completion": 0.002,
"gpt-35-turbo-instruct-completion": 0.002,
"gpt-35-turbo-16k-completion": 0.004,
"gpt-35-turbo-16k-0613-completion": 0.004,
# Others
"text-ada-001": 0.0004,
"ada": 0.0004,
"text-babbage-001": 0.0005,
"babbage": 0.0005,
"text-curie-001": 0.002,
"curie": 0.002,
"text-davinci-003": 0.02,
"text-davinci-002": 0.02,
"code-davinci-002": 0.02,
# Fine Tuned input
"babbage-002-finetuned": 0.0016,
"davinci-002-finetuned": 0.012,
"gpt-3.5-turbo-0613-finetuned": 0.003,
"gpt-3.5-turbo-1106-finetuned": 0.003,
"gpt-3.5-turbo-0125-finetuned": 0.003,
"gpt-4o-mini-2024-07-18-finetuned": 0.0003,
"gpt-4o-mini-2024-07-18-finetuned-cached": 0.00015,
# Fine Tuned output
"babbage-002-finetuned-completion": 0.0016,
"davinci-002-finetuned-completion": 0.012,
"gpt-3.5-turbo-0613-finetuned-completion": 0.006,
"gpt-3.5-turbo-1106-finetuned-completion": 0.006,
"gpt-3.5-turbo-0125-finetuned-completion": 0.006,
"gpt-4o-mini-2024-07-18-finetuned-completion": 0.0012,
# Azure Fine Tuned input
"babbage-002-azure-finetuned": 0.0004,
"davinci-002-azure-finetuned": 0.002,
"gpt-35-turbo-0613-azure-finetuned": 0.0015,
# Azure Fine Tuned output
"babbage-002-azure-finetuned-completion": 0.0004,
"davinci-002-azure-finetuned-completion": 0.002,
"gpt-35-turbo-0613-azure-finetuned-completion": 0.002,
# Legacy fine-tuned models
"ada-finetuned-legacy": 0.0016,
"babbage-finetuned-legacy": 0.0024,
"curie-finetuned-legacy": 0.012,
"davinci-finetuned-legacy": 0.12,
}
class TokenType(Enum):
"""Token type enum."""
PROMPT = auto()
PROMPT_CACHED = auto()
COMPLETION = auto()
def standardize_model_name(
model_name: str,
is_completion: bool = False,
*,
token_type: TokenType = TokenType.PROMPT,
) -> str:
"""
Standardize the model name to a format that can be used in the OpenAI API.
Args:
model_name: Model name to standardize.
is_completion: Whether the model is used for completion or not.
Defaults to False. Deprecated in favor of ``token_type``.
token_type: Token type. Defaults to ``TokenType.PROMPT``.
Returns:
Standardized model name.
"""
if is_completion:
warn_deprecated(
since="0.3.13",
message=(
"is_completion is deprecated. Use token_type instead. Example:\n\n"
"from langchain_community.callbacks.openai_info import TokenType\n\n"
"standardize_model_name('gpt-4o', token_type=TokenType.COMPLETION)\n"
),
removal="1.0",
)
token_type = TokenType.COMPLETION
model_name = model_name.lower()
if ".ft-" in model_name:
model_name = model_name.split(".ft-")[0] + "-azure-finetuned"
if ":ft-" in model_name:
model_name = model_name.split(":")[0] + "-finetuned-legacy"
if "ft:" in model_name:
model_name = model_name.split(":")[1] + "-finetuned"
if token_type == TokenType.COMPLETION and (
model_name.startswith("gpt-5")
or model_name.startswith("gpt-4")
or model_name.startswith("gpt-3.5")
or model_name.startswith("gpt-35")
or model_name.startswith("o1-")
or model_name.startswith("o3-")
or model_name.startswith("o4-")
or ("finetuned" in model_name and "legacy" not in model_name)
):
return model_name + "-completion"
if (
token_type == TokenType.PROMPT_CACHED
and (
model_name.startswith("gpt-5")
or model_name.startswith("gpt-4o")
or model_name.startswith("gpt-4.1")
or model_name.startswith("o1")
or model_name.startswith("o3")
or model_name.startswith("o4")
)
and not (model_name.startswith("gpt-4o-2024-05-13"))
):
return model_name + "-cached"
else:
return model_name
def get_openai_token_cost_for_model(
model_name: str,
num_tokens: int,
is_completion: bool = False,
*,
token_type: TokenType = TokenType.PROMPT,
) -> float:
"""
Get the cost in USD for a given model and number of tokens.
Args:
model_name: Name of the model
num_tokens: Number of tokens.
is_completion: Whether the model is used for completion or not.
Defaults to False. Deprecated in favor of ``token_type``.
token_type: Token type. Defaults to ``TokenType.PROMPT``.
Returns:
Cost in USD.
"""
if is_completion:
warn_deprecated(
since="0.3.13",
message=(
"is_completion is deprecated. Use token_type instead. Example:\n\n"
"from langchain_community.callbacks.openai_info import TokenType\n\n"
"get_openai_token_cost_for_model('gpt-4o', 10, token_type=TokenType.COMPLETION)\n" # noqa: E501
),
removal="1.0",
)
token_type = TokenType.COMPLETION
model_name = standardize_model_name(model_name, token_type=token_type)
if model_name not in MODEL_COST_PER_1K_TOKENS:
raise ValueError(
f"Unknown model: {model_name}. Please provide a valid OpenAI model name."
"Known models are: " + ", ".join(MODEL_COST_PER_1K_TOKENS.keys())
)
return MODEL_COST_PER_1K_TOKENS[model_name] * (num_tokens / 1000)
class OpenAICallbackHandler(BaseCallbackHandler):
"""Callback Handler that tracks OpenAI info."""
total_tokens: int = 0
prompt_tokens: int = 0
prompt_tokens_cached: int = 0
completion_tokens: int = 0
reasoning_tokens: int = 0
successful_requests: int = 0
total_cost: float = 0.0
def __init__(self) -> None:
super().__init__()
self._lock = threading.Lock()
def __repr__(self) -> str:
return (
f"Tokens Used: {self.total_tokens}\n"
f"\tPrompt Tokens: {self.prompt_tokens}\n"
f"\t\tPrompt Tokens Cached: {self.prompt_tokens_cached}\n"
f"\tCompletion Tokens: {self.completion_tokens}\n"
f"\t\tReasoning Tokens: {self.reasoning_tokens}\n"
f"Successful Requests: {self.successful_requests}\n"
f"Total Cost (USD): ${self.total_cost}"
)
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Print out the token."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collect token usage."""
# Check for usage_metadata (langchain-core >= 0.2.2)
try:
generation = response.generations[0][0]
except IndexError:
generation = None
if isinstance(generation, ChatGeneration):
try:
message = generation.message
if isinstance(message, AIMessage):
usage_metadata = message.usage_metadata
response_metadata = message.response_metadata
else:
usage_metadata = None
response_metadata = None
except AttributeError:
usage_metadata = None
response_metadata = None
else:
usage_metadata = None
response_metadata = None
prompt_tokens_cached = 0
reasoning_tokens = 0
if usage_metadata:
token_usage = {"total_tokens": usage_metadata["total_tokens"]}
completion_tokens = usage_metadata["output_tokens"]
prompt_tokens = usage_metadata["input_tokens"]
if response_model_name := (response_metadata or {}).get("model_name"):
model_name = standardize_model_name(response_model_name)
elif response.llm_output is None:
model_name = ""
else:
model_name = standardize_model_name(
response.llm_output.get("model_name", "")
)
if "cache_read" in usage_metadata.get("input_token_details", {}):
prompt_tokens_cached = usage_metadata["input_token_details"][
"cache_read"
]
if "reasoning" in usage_metadata.get("output_token_details", {}):
reasoning_tokens = usage_metadata["output_token_details"]["reasoning"]
else:
if response.llm_output is None:
return None
if "token_usage" not in response.llm_output:
with self._lock:
self.successful_requests += 1
return None
# compute tokens and cost for this request
token_usage = response.llm_output["token_usage"]
completion_tokens = token_usage.get("completion_tokens", 0)
prompt_tokens = token_usage.get("prompt_tokens", 0)
model_name = standardize_model_name(
response.llm_output.get("model_name", "")
)
if model_name in MODEL_COST_PER_1K_TOKENS:
uncached_prompt_tokens = prompt_tokens - prompt_tokens_cached
uncached_prompt_cost = get_openai_token_cost_for_model(
model_name, uncached_prompt_tokens, token_type=TokenType.PROMPT
)
cached_prompt_cost = get_openai_token_cost_for_model(
model_name, prompt_tokens_cached, token_type=TokenType.PROMPT_CACHED
)
prompt_cost = uncached_prompt_cost + cached_prompt_cost
completion_cost = get_openai_token_cost_for_model(
model_name, completion_tokens, token_type=TokenType.COMPLETION
)
else:
completion_cost = 0
prompt_cost = 0
# update shared state behind lock
with self._lock:
self.total_cost += prompt_cost + completion_cost
self.total_tokens += token_usage.get("total_tokens", 0)
self.prompt_tokens += prompt_tokens
self.prompt_tokens_cached += prompt_tokens_cached
self.completion_tokens += completion_tokens
self.reasoning_tokens += reasoning_tokens
self.successful_requests += 1
def __copy__(self) -> "OpenAICallbackHandler":
"""Return a copy of the callback handler."""
return self
def __deepcopy__(self, memo: Any) -> "OpenAICallbackHandler":
"""Return a deep copy of the callback handler."""
return self

View File

@@ -0,0 +1,163 @@
"""Callback handler for promptlayer."""
from __future__ import annotations
import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import (
ChatGeneration,
LLMResult,
)
if TYPE_CHECKING:
import promptlayer
def _lazy_import_promptlayer() -> promptlayer:
"""Lazy import promptlayer to avoid circular imports."""
try:
import promptlayer
except ImportError:
raise ImportError(
"The PromptLayerCallbackHandler requires the promptlayer package. "
" Please install it with `pip install promptlayer`."
)
return promptlayer
class PromptLayerCallbackHandler(BaseCallbackHandler):
"""Callback handler for promptlayer."""
def __init__(
self,
pl_id_callback: Optional[Callable[..., Any]] = None,
pl_tags: Optional[List[str]] = None,
) -> None:
"""Initialize the PromptLayerCallbackHandler."""
_lazy_import_promptlayer()
self.pl_id_callback = pl_id_callback
self.pl_tags = pl_tags or []
self.runs: Dict[UUID, Dict[str, Any]] = {}
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> Any:
self.runs[run_id] = {
"messages": [self._create_message_dicts(m)[0] for m in messages],
"invocation_params": kwargs.get("invocation_params", {}),
"name": ".".join(serialized["id"]),
"request_start_time": datetime.datetime.now().timestamp(),
"tags": tags,
}
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> Any:
self.runs[run_id] = {
"prompts": prompts,
"invocation_params": kwargs.get("invocation_params", {}),
"name": ".".join(serialized["id"]),
"request_start_time": datetime.datetime.now().timestamp(),
"tags": tags,
}
def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
from promptlayer.utils import get_api_key, promptlayer_api_request
run_info = self.runs.get(run_id, {})
if not run_info:
return
run_info["request_end_time"] = datetime.datetime.now().timestamp()
for i in range(len(response.generations)):
generation = response.generations[i][0]
resp = {
"text": generation.text,
"llm_output": response.llm_output,
}
model_params = run_info.get("invocation_params", {})
is_chat_model = run_info.get("messages", None) is not None
model_input = (
run_info.get("messages", [])[i]
if is_chat_model
else [run_info.get("prompts", [])[i]]
)
model_response = (
[self._convert_message_to_dict(generation.message)]
if is_chat_model and isinstance(generation, ChatGeneration)
else resp
)
pl_request_id = promptlayer_api_request(
run_info.get("name"),
"langchain",
model_input,
model_params,
self.pl_tags,
model_response,
run_info.get("request_start_time"),
run_info.get("request_end_time"),
get_api_key(),
return_pl_id=bool(self.pl_id_callback is not None),
metadata={
"_langchain_run_id": str(run_id),
"_langchain_parent_run_id": str(parent_run_id),
"_langchain_tags": str(run_info.get("tags", [])),
},
)
if self.pl_id_callback:
self.pl_id_callback(pl_request_id)
def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
if isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
def _create_message_dicts(
self, messages: List[BaseMessage]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params: Dict[str, Any] = {}
message_dicts = [self._convert_message_to_dict(m) for m in messages]
return message_dicts, params

View File

@@ -0,0 +1,277 @@
import json
import os
import shutil
import tempfile
from copy import deepcopy
from typing import Any, Dict, List, Optional
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_community.callbacks.utils import (
flatten_dict,
)
def save_json(data: dict, file_path: str) -> None:
"""Save dict to local file path.
Parameters:
data (dict): The dictionary to be saved.
file_path (str): Local file path.
"""
with open(file_path, "w") as outfile:
json.dump(data, outfile)
class SageMakerCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs prompt artifacts and metrics to SageMaker Experiments.
Parameters:
run (sagemaker.experiments.run.Run): Run object where the experiment is logged.
"""
def __init__(self, run: Any) -> None:
"""Initialize callback handler."""
super().__init__()
self.run = run
self.metrics = {
"step": 0,
"starts": 0,
"ends": 0,
"errors": 0,
"text_ctr": 0,
"chain_starts": 0,
"chain_ends": 0,
"llm_starts": 0,
"llm_ends": 0,
"llm_streams": 0,
"tool_starts": 0,
"tool_ends": 0,
"agent_ends": 0,
}
# Create a temporary directory
self.temp_dir = tempfile.mkdtemp()
def _reset(self) -> None:
for k, v in self.metrics.items():
self.metrics[k] = 0
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
self.metrics["step"] += 1
self.metrics["llm_starts"] += 1
self.metrics["starts"] += 1
llm_starts = self.metrics["llm_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_start"})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
for idx, prompt in enumerate(prompts):
prompt_resp = deepcopy(resp)
prompt_resp["prompt"] = prompt
self.jsonf(
prompt_resp,
self.temp_dir,
f"llm_start_{llm_starts}_prompt_{idx}",
)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.metrics["step"] += 1
self.metrics["llm_streams"] += 1
llm_streams = self.metrics["llm_streams"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_new_token", "token": token})
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"llm_new_tokens_{llm_streams}")
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.metrics["step"] += 1
self.metrics["llm_ends"] += 1
self.metrics["ends"] += 1
llm_ends = self.metrics["llm_ends"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_llm_end"})
resp.update(flatten_dict(response.llm_output or {}))
resp.update(self.metrics)
for generations in response.generations:
for idx, generation in enumerate(generations):
generation_resp = deepcopy(resp)
generation_resp.update(flatten_dict(generation.dict()))
self.jsonf(
resp,
self.temp_dir,
f"llm_end_{llm_ends}_generation_{idx}",
)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.metrics["step"] += 1
self.metrics["chain_starts"] += 1
self.metrics["starts"] += 1
chain_starts = self.metrics["chain_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
input_resp = deepcopy(resp)
input_resp["inputs"] = chain_input
self.jsonf(input_resp, self.temp_dir, f"chain_start_{chain_starts}")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.metrics["step"] += 1
self.metrics["chain_ends"] += 1
self.metrics["ends"] += 1
chain_ends = self.metrics["chain_ends"]
resp: Dict[str, Any] = {}
chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
resp.update({"action": "on_chain_end", "outputs": chain_output})
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.metrics["step"] += 1
self.metrics["tool_starts"] += 1
self.metrics["starts"] += 1
tool_starts = self.metrics["tool_starts"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_start", "input_str": input_str})
resp.update(flatten_dict(serialized))
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}")
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
output = str(output)
self.metrics["step"] += 1
self.metrics["tool_ends"] += 1
self.metrics["ends"] += 1
tool_ends = self.metrics["tool_ends"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_tool_end", "output": output})
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}")
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.metrics["step"] += 1
self.metrics["text_ctr"] += 1
text_ctr = self.metrics["text_ctr"]
resp: Dict[str, Any] = {}
resp.update({"action": "on_text", "text": text})
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"on_text_{text_ctr}")
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.metrics["step"] += 1
self.metrics["agent_ends"] += 1
self.metrics["ends"] += 1
agent_ends = self.metrics["agent_ends"]
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_finish",
"output": finish.return_values["output"],
"log": finish.log,
}
)
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"agent_finish_{agent_ends}")
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.metrics["step"] += 1
self.metrics["tool_starts"] += 1
self.metrics["starts"] += 1
tool_starts = self.metrics["tool_starts"]
resp: Dict[str, Any] = {}
resp.update(
{
"action": "on_agent_action",
"tool": action.tool,
"tool_input": action.tool_input,
"log": action.log,
}
)
resp.update(self.metrics)
self.jsonf(resp, self.temp_dir, f"agent_action_{tool_starts}")
def jsonf(
self,
data: Dict[str, Any],
data_dir: str,
filename: str,
is_output: Optional[bool] = True,
) -> None:
"""To log the input data as json file artifact."""
file_path = os.path.join(data_dir, f"{filename}.json")
save_json(data, file_path)
self.run.log_file(file_path, name=filename, is_output=is_output)
def flush_tracker(self) -> None:
"""Reset the steps and delete the temporary local directory."""
self._reset()
shutil.rmtree(self.temp_dir)

View File

@@ -0,0 +1,82 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from langchain_core.callbacks import BaseCallbackHandler
from langchain_community.callbacks.streamlit.streamlit_callback_handler import (
LLMThoughtLabeler as LLMThoughtLabeler,
)
from langchain_community.callbacks.streamlit.streamlit_callback_handler import (
StreamlitCallbackHandler as _InternalStreamlitCallbackHandler,
)
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
def StreamlitCallbackHandler(
parent_container: DeltaGenerator,
*,
max_thought_containers: int = 4,
expand_new_thoughts: bool = True,
collapse_completed_thoughts: bool = True,
thought_labeler: Optional[LLMThoughtLabeler] = None,
) -> BaseCallbackHandler:
"""Callback Handler that writes to a Streamlit app.
This CallbackHandler is geared towards
use with a LangChain Agent; it displays the Agent's LLM and tool-usage "thoughts"
inside a series of Streamlit expanders.
Parameters
----------
parent_container
The `st.container` that will contain all the Streamlit elements that the
Handler creates.
max_thought_containers
The max number of completed LLM thought containers to show at once. When this
threshold is reached, a new thought will cause the oldest thoughts to be
collapsed into a "History" expander. Defaults to 4.
expand_new_thoughts
Each LLM "thought" gets its own `st.expander`. This param controls whether that
expander is expanded by default. Defaults to True.
collapse_completed_thoughts
If True, LLM thought expanders will be collapsed when completed.
Defaults to True.
thought_labeler
An optional custom LLMThoughtLabeler instance. If unspecified, the handler
will use the default thought labeling logic. Defaults to None.
Returns
-------
A new StreamlitCallbackHandler instance.
Note that this is an "auto-updating" API: if the installed version of Streamlit
has a more recent StreamlitCallbackHandler implementation, an instance of that class
will be used.
"""
# If we're using a version of Streamlit that implements StreamlitCallbackHandler,
# delegate to it instead of using our built-in handler. The official handler is
# guaranteed to support the same set of kwargs.
try:
from streamlit.external.langchain import (
StreamlitCallbackHandler as OfficialStreamlitCallbackHandler,
)
return OfficialStreamlitCallbackHandler(
parent_container,
max_thought_containers=max_thought_containers,
expand_new_thoughts=expand_new_thoughts,
collapse_completed_thoughts=collapse_completed_thoughts,
thought_labeler=thought_labeler,
)
except ImportError:
return _InternalStreamlitCallbackHandler(
parent_container,
max_thought_containers=max_thought_containers,
expand_new_thoughts=expand_new_thoughts,
collapse_completed_thoughts=collapse_completed_thoughts,
thought_labeler=thought_labeler,
)

View File

@@ -0,0 +1,156 @@
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
from streamlit.type_util import SupportsStr
class ChildType(Enum):
"""Enumerator of the child type."""
MARKDOWN = "MARKDOWN"
EXCEPTION = "EXCEPTION"
class ChildRecord(NamedTuple):
"""Child record as a NamedTuple."""
type: ChildType
kwargs: Dict[str, Any]
dg: DeltaGenerator
class MutableExpander:
"""Streamlit expander that can be renamed and dynamically expanded/collapsed."""
def __init__(self, parent_container: DeltaGenerator, label: str, expanded: bool):
"""Create a new MutableExpander.
Parameters
----------
parent_container
The `st.container` that the expander will be created inside.
The expander transparently deletes and recreates its underlying
`st.expander` instance when its label changes, and it uses
`parent_container` to ensure it recreates this underlying expander in the
same location onscreen.
label
The expander's initial label.
expanded
The expander's initial `expanded` value.
"""
self._label = label
self._expanded = expanded
self._parent_cursor = parent_container.empty()
self._container = self._parent_cursor.expander(label, expanded)
self._child_records: List[ChildRecord] = []
@property
def label(self) -> str:
"""Expander's label string."""
return self._label
@property
def expanded(self) -> bool:
"""True if the expander was created with `expanded=True`."""
return self._expanded
def clear(self) -> None:
"""Remove the container and its contents entirely. A cleared container can't
be reused.
"""
self._container = self._parent_cursor.empty()
self._child_records.clear()
def append_copy(self, other: MutableExpander) -> None:
"""Append a copy of another MutableExpander's children to this
MutableExpander.
"""
other_records = other._child_records.copy()
for record in other_records:
self._create_child(record.type, record.kwargs)
def update(
self, *, new_label: Optional[str] = None, new_expanded: Optional[bool] = None
) -> None:
"""Change the expander's label and expanded state"""
if new_label is None:
new_label = self._label
if new_expanded is None:
new_expanded = self._expanded
if self._label == new_label and self._expanded == new_expanded:
# No change!
return
self._label = new_label
self._expanded = new_expanded
self._container = self._parent_cursor.expander(new_label, new_expanded)
prev_records = self._child_records
self._child_records = []
# Replay all children into the new container
for record in prev_records:
self._create_child(record.type, record.kwargs)
def markdown(
self,
body: SupportsStr,
unsafe_allow_html: bool = False,
*,
help: Optional[str] = None,
index: Optional[int] = None,
) -> int:
"""Add a Markdown element to the container and return its index."""
kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help}
new_dg = self._get_dg(index).markdown(**kwargs)
record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg)
return self._add_record(record, index)
def exception(
self, exception: BaseException, *, index: Optional[int] = None
) -> int:
"""Add an Exception element to the container and return its index."""
kwargs = {"exception": exception}
new_dg = self._get_dg(index).exception(**kwargs)
record = ChildRecord(ChildType.EXCEPTION, kwargs, new_dg)
return self._add_record(record, index)
def _create_child(self, type: ChildType, kwargs: Dict[str, Any]) -> None:
"""Create a new child with the given params"""
if type == ChildType.MARKDOWN:
self.markdown(**kwargs)
elif type == ChildType.EXCEPTION:
self.exception(**kwargs)
else:
raise RuntimeError(f"Unexpected child type {type}")
def _add_record(self, record: ChildRecord, index: Optional[int]) -> int:
"""Add a ChildRecord to self._children. If `index` is specified, replace
the existing record at that index. Otherwise, append the record to the
end of the list.
Return the index of the added record.
"""
if index is not None:
# Replace existing child
self._child_records[index] = record
return index
# Append new child
self._child_records.append(record)
return len(self._child_records) - 1
def _get_dg(self, index: Optional[int]) -> DeltaGenerator:
if index is not None:
# Existing index: reuse child's DeltaGenerator
assert 0 <= index < len(self._child_records), f"Bad index: {index}"
return self._child_records[index].dg
# No index: use container's DeltaGenerator
return self._container

View File

@@ -0,0 +1,419 @@
"""Callback Handler that prints to streamlit."""
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_community.callbacks.streamlit.mutable_expander import MutableExpander
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
def _convert_newlines(text: str) -> str:
"""Convert newline characters to markdown newline sequences
(space, space, newline).
"""
return text.replace("\n", " \n")
CHECKMARK_EMOJI = ""
THINKING_EMOJI = ":thinking_face:"
HISTORY_EMOJI = ":books:"
EXCEPTION_EMOJI = "⚠️"
class LLMThoughtState(Enum):
"""Enumerator of the LLMThought state."""
# The LLM is thinking about what to do next. We don't know which tool we'll run.
THINKING = "THINKING"
# The LLM has decided to run a tool. We don't have results from the tool yet.
RUNNING_TOOL = "RUNNING_TOOL"
# We have results from the tool.
COMPLETE = "COMPLETE"
class ToolRecord(NamedTuple):
"""Tool record as a NamedTuple."""
name: str
input_str: str
class LLMThoughtLabeler:
"""
Generates markdown labels for LLMThought containers. Pass a custom
subclass of this to StreamlitCallbackHandler to override its default
labeling logic.
"""
@staticmethod
def get_initial_label() -> str:
"""Return the markdown label for a new LLMThought that doesn't have
an associated tool yet.
"""
return f"{THINKING_EMOJI} **Thinking...**"
@staticmethod
def get_tool_label(tool: ToolRecord, is_complete: bool) -> str:
"""Return the label for an LLMThought that has an associated
tool.
Parameters
----------
tool
The tool's ToolRecord
is_complete
True if the thought is complete; False if the thought
is still receiving input.
Returns
-------
The markdown label for the thought's container.
"""
input = tool.input_str
name = tool.name
emoji = CHECKMARK_EMOJI if is_complete else THINKING_EMOJI
if name == "_Exception":
emoji = EXCEPTION_EMOJI
name = "Parsing error"
idx = min([60, len(input)])
input = input[0:idx]
if len(tool.input_str) > idx:
input = input + "..."
input = input.replace("\n", " ")
label = f"{emoji} **{name}:** {input}"
return label
@staticmethod
def get_history_label() -> str:
"""Return a markdown label for the special 'history' container
that contains overflow thoughts.
"""
return f"{HISTORY_EMOJI} **History**"
@staticmethod
def get_final_agent_thought_label() -> str:
"""Return the markdown label for the agent's final thought -
the "Now I have the answer" thought, that doesn't involve
a tool.
"""
return f"{CHECKMARK_EMOJI} **Complete!**"
class LLMThought:
"""A thought in the LLM's thought stream."""
def __init__(
self,
parent_container: DeltaGenerator,
labeler: LLMThoughtLabeler,
expanded: bool,
collapse_on_complete: bool,
):
"""Initialize the LLMThought.
Args:
parent_container: The container we're writing into.
labeler: The labeler to use for this thought.
expanded: Whether the thought should be expanded by default.
collapse_on_complete: Whether the thought should be collapsed.
"""
self._container = MutableExpander(
parent_container=parent_container,
label=labeler.get_initial_label(),
expanded=expanded,
)
self._state = LLMThoughtState.THINKING
self._llm_token_stream = ""
self._llm_token_writer_idx: Optional[int] = None
self._last_tool: Optional[ToolRecord] = None
self._collapse_on_complete = collapse_on_complete
self._labeler = labeler
@property
def container(self) -> MutableExpander:
"""The container we're writing into."""
return self._container
@property
def last_tool(self) -> Optional[ToolRecord]:
"""The last tool executed by this thought"""
return self._last_tool
def _reset_llm_token_stream(self) -> None:
self._llm_token_stream = ""
self._llm_token_writer_idx = None
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str]) -> None:
self._reset_llm_token_stream()
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
# This is only called when the LLM is initialized with `streaming=True`
self._llm_token_stream += _convert_newlines(token)
self._llm_token_writer_idx = self._container.markdown(
self._llm_token_stream, index=self._llm_token_writer_idx
)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
# `response` is the concatenation of all the tokens received by the LLM.
# If we're receiving streaming tokens from `on_llm_new_token`, this response
# data is redundant
self._reset_llm_token_stream()
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self._container.markdown("**LLM encountered an error...**")
self._container.exception(error)
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
# Called with the name of the tool we're about to run (in `serialized[name]`),
# and its input. We change our container's label to be the tool name.
self._state = LLMThoughtState.RUNNING_TOOL
tool_name = serialized["name"]
self._last_tool = ToolRecord(name=tool_name, input_str=input_str)
self._container.update(
new_label=self._labeler.get_tool_label(self._last_tool, is_complete=False)
)
def on_tool_end(
self,
output: Any,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
self._container.markdown(f"**{str(output)}**")
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self._container.markdown("**Tool encountered an error...**")
self._container.exception(error)
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
# Called when we're about to kick off a new tool. The `action` data
# tells us the tool we're about to use, and the input we'll give it.
# We don't output anything here, because we'll receive this same data
# when `on_tool_start` is called immediately after.
pass
def complete(self, final_label: Optional[str] = None) -> None:
"""Finish the thought."""
if final_label is None and self._state == LLMThoughtState.RUNNING_TOOL:
assert self._last_tool is not None, (
"_last_tool should never be null when _state == RUNNING_TOOL"
)
final_label = self._labeler.get_tool_label(
self._last_tool, is_complete=True
)
self._state = LLMThoughtState.COMPLETE
if self._collapse_on_complete:
self._container.update(new_label=final_label, new_expanded=False)
else:
self._container.update(new_label=final_label)
def clear(self) -> None:
"""Remove the thought from the screen. A cleared thought can't be reused."""
self._container.clear()
class StreamlitCallbackHandler(BaseCallbackHandler):
"""Callback handler that writes to a Streamlit app."""
def __init__(
self,
parent_container: DeltaGenerator,
*,
max_thought_containers: int = 4,
expand_new_thoughts: bool = True,
collapse_completed_thoughts: bool = True,
thought_labeler: Optional[LLMThoughtLabeler] = None,
):
"""Create a StreamlitCallbackHandler instance.
Parameters
----------
parent_container
The `st.container` that will contain all the Streamlit elements that the
Handler creates.
max_thought_containers
The max number of completed LLM thought containers to show at once. When
this threshold is reached, a new thought will cause the oldest thoughts to
be collapsed into a "History" expander. Defaults to 4.
expand_new_thoughts
Each LLM "thought" gets its own `st.expander`. This param controls whether
that expander is expanded by default. Defaults to True.
collapse_completed_thoughts
If True, LLM thought expanders will be collapsed when completed.
Defaults to True.
thought_labeler
An optional custom LLMThoughtLabeler instance. If unspecified, the handler
will use the default thought labeling logic. Defaults to None.
"""
self._parent_container = parent_container
self._history_parent = parent_container.container()
self._history_container: Optional[MutableExpander] = None
self._current_thought: Optional[LLMThought] = None
self._completed_thoughts: List[LLMThought] = []
self._max_thought_containers = max(max_thought_containers, 1)
self._expand_new_thoughts = expand_new_thoughts
self._collapse_completed_thoughts = collapse_completed_thoughts
self._thought_labeler = thought_labeler or LLMThoughtLabeler()
def _require_current_thought(self) -> LLMThought:
"""Return our current LLMThought. Raise an error if we have no current
thought.
"""
if self._current_thought is None:
raise RuntimeError("Current LLMThought is unexpectedly None!")
return self._current_thought
def _get_last_completed_thought(self) -> Optional[LLMThought]:
"""Return our most recent completed LLMThought, or None if we don't have one."""
if len(self._completed_thoughts) > 0:
return self._completed_thoughts[len(self._completed_thoughts) - 1]
return None
@property
def _num_thought_containers(self) -> int:
"""The number of 'thought containers' we're currently showing: the
number of completed thought containers, the history container (if it exists),
and the current thought container (if it exists).
"""
count = len(self._completed_thoughts)
if self._history_container is not None:
count += 1
if self._current_thought is not None:
count += 1
return count
def _complete_current_thought(self, final_label: Optional[str] = None) -> None:
"""Complete the current thought, optionally assigning it a new label.
Add it to our _completed_thoughts list.
"""
thought = self._require_current_thought()
thought.complete(final_label)
self._completed_thoughts.append(thought)
self._current_thought = None
def _prune_old_thought_containers(self) -> None:
"""If we have too many thoughts onscreen, move older thoughts to the
'history container.'
"""
while (
self._num_thought_containers > self._max_thought_containers
and len(self._completed_thoughts) > 0
):
# Create our history container if it doesn't exist, and if
# max_thought_containers is > 1. (if max_thought_containers is 1, we don't
# have room to show history.)
if self._history_container is None and self._max_thought_containers > 1:
self._history_container = MutableExpander(
self._history_parent,
label=self._thought_labeler.get_history_label(),
expanded=False,
)
oldest_thought = self._completed_thoughts.pop(0)
if self._history_container is not None:
self._history_container.markdown(oldest_thought.container.label)
self._history_container.append_copy(oldest_thought.container)
oldest_thought.clear()
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
if self._current_thought is None:
self._current_thought = LLMThought(
parent_container=self._parent_container,
expanded=self._expand_new_thoughts,
collapse_on_complete=self._collapse_completed_thoughts,
labeler=self._thought_labeler,
)
self._current_thought.on_llm_start(serialized, prompts)
# We don't prune_old_thought_containers here, because our container won't
# be visible until it has a child.
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self._require_current_thought().on_llm_new_token(token, **kwargs)
self._prune_old_thought_containers()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
self._require_current_thought().on_llm_end(response, **kwargs)
self._prune_old_thought_containers()
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self._require_current_thought().on_llm_error(error, **kwargs)
self._prune_old_thought_containers()
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)
self._prune_old_thought_containers()
def on_tool_end(
self,
output: Any,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
output = str(output)
self._require_current_thought().on_tool_end(
output, color, observation_prefix, llm_prefix, **kwargs
)
self._complete_current_thought()
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self._require_current_thought().on_tool_error(error, **kwargs)
self._prune_old_thought_containers()
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Any,
) -> None:
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
self._require_current_thought().on_agent_action(action, color, **kwargs)
self._prune_old_thought_containers()
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
if self._current_thought is not None:
self._current_thought.complete(
self._thought_labeler.get_final_agent_thought_label()
)
self._current_thought = None

View File

@@ -0,0 +1,16 @@
"""Tracers that record execution of LangChain runs."""
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.stdout import (
ConsoleCallbackHandler,
FunctionCallbackHandler,
)
from langchain_community.callbacks.tracers.wandb import WandbTracer
__all__ = [
"ConsoleCallbackHandler",
"FunctionCallbackHandler",
"LangChainTracer",
"WandbTracer",
]

View File

@@ -0,0 +1,135 @@
from types import ModuleType, SimpleNamespace
from typing import TYPE_CHECKING, Any, Callable, Dict
from langchain_core.tracers import BaseTracer
from langchain_core.utils import guard_import
if TYPE_CHECKING:
from uuid import UUID
from comet_llm import Span
from comet_llm.chains.chain import Chain
from langchain_community.callbacks.tracers.schemas import Run
def _get_run_type(run: "Run") -> str:
if isinstance(run.run_type, str):
return run.run_type
elif hasattr(run.run_type, "value"):
return run.run_type.value
else:
return str(run.run_type)
def import_comet_llm_api() -> SimpleNamespace:
"""Import comet_llm api and raise an error if it is not installed."""
comet_llm = guard_import("comet_llm")
comet_llm_chains = guard_import("comet_llm.chains")
return SimpleNamespace(
chain=comet_llm_chains.chain,
span=comet_llm_chains.span,
chain_api=comet_llm_chains.api,
experiment_info=comet_llm.experiment_info,
flush=comet_llm.flush,
)
class CometTracer(BaseTracer):
"""Comet Tracer."""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the Comet Tracer."""
super().__init__(**kwargs)
self._span_map: Dict["UUID", "Span"] = {}
"""Map from run id to span."""
self._chains_map: Dict["UUID", "Chain"] = {}
"""Map from run id to chain."""
self._initialize_comet_modules()
def _initialize_comet_modules(self) -> None:
comet_llm_api = import_comet_llm_api()
self._chain: ModuleType = comet_llm_api.chain
self._span: ModuleType = comet_llm_api.span
self._chain_api: ModuleType = comet_llm_api.chain_api
self._experiment_info: ModuleType = comet_llm_api.experiment_info
self._flush: Callable[[], None] = comet_llm_api.flush
def _persist_run(self, run: "Run") -> None:
run_dict: Dict[str, Any] = run.dict()
chain_ = self._chains_map[run.id]
chain_.set_outputs(outputs=run_dict["outputs"])
self._chain_api.log_chain(chain_)
def _process_start_trace(self, run: "Run") -> None:
run_dict: Dict[str, Any] = run.dict()
if not run.parent_run_id:
# This is the first run, which maps to a chain
metadata = run_dict["extra"].get("metadata", None)
chain_: "Chain" = self._chain.Chain(
inputs=run_dict["inputs"],
metadata=metadata,
experiment_info=self._experiment_info.get(),
)
self._chains_map[run.id] = chain_
else:
span: "Span" = self._span.Span(
inputs=run_dict["inputs"],
category=_get_run_type(run),
metadata=run_dict["extra"],
name=run.name,
)
span.__api__start__(self._chains_map[run.parent_run_id])
self._chains_map[run.id] = self._chains_map[run.parent_run_id]
self._span_map[run.id] = span
def _process_end_trace(self, run: "Run") -> None:
run_dict: Dict[str, Any] = run.dict()
if not run.parent_run_id:
pass
# Langchain will call _persist_run for us
else:
span = self._span_map[run.id]
span.set_outputs(outputs=run_dict["outputs"])
span.__api__end__()
def flush(self) -> None:
self._flush()
def _on_llm_start(self, run: "Run") -> None:
"""Process the LLM Run upon start."""
self._process_start_trace(run)
def _on_llm_end(self, run: "Run") -> None:
"""Process the LLM Run."""
self._process_end_trace(run)
def _on_llm_error(self, run: "Run") -> None:
"""Process the LLM Run upon error."""
self._process_end_trace(run)
def _on_chain_start(self, run: "Run") -> None:
"""Process the Chain Run upon start."""
self._process_start_trace(run)
def _on_chain_end(self, run: "Run") -> None:
"""Process the Chain Run."""
self._process_end_trace(run)
def _on_chain_error(self, run: "Run") -> None:
"""Process the Chain Run upon error."""
self._process_end_trace(run)
def _on_tool_start(self, run: "Run") -> None:
"""Process the Tool Run upon start."""
self._process_start_trace(run)
def _on_tool_end(self, run: "Run") -> None:
"""Process the Tool Run."""
self._process_end_trace(run)
def _on_tool_error(self, run: "Run") -> None:
"""Process the Tool Run upon error."""
self._process_end_trace(run)

View File

@@ -0,0 +1,507 @@
"""A Tracer Implementation that records activity to Weights & Biases."""
from __future__ import annotations
import json
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
TypedDict,
Union,
)
from langchain_core._api import warn_deprecated
from langchain_core.output_parsers.pydantic import PydanticBaseModel
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
if TYPE_CHECKING:
from wandb import Settings as WBSettings
from wandb.sdk.data_types.trace_tree import Trace
from wandb.sdk.lib.paths import StrPath
from wandb.wandb_run import Run as WBRun
PRINT_WARNINGS = True
def _serialize_io(run_io: Optional[dict]) -> dict:
"""Utility to serialize the input and output of a run to store in wandb.
Currently, supports serializing pydantic models and protobuf messages.
:param run_io: The inputs and outputs of the run.
:return: The serialized inputs and outputs.
"""
if not run_io:
return {}
from google.protobuf.json_format import MessageToJson
from google.protobuf.message import Message
serialized_inputs = {}
for key, value in run_io.items():
if isinstance(value, Message):
serialized_inputs[key] = MessageToJson(value)
elif isinstance(value, PydanticBaseModel):
serialized_inputs[key] = (
value.model_dump_json()
if hasattr(value, "model_dump_json")
else value.json()
)
elif key == "input_documents":
serialized_inputs.update(
{f"input_document_{i}": doc.json() for i, doc in enumerate(value)}
)
else:
serialized_inputs[key] = value
return serialized_inputs
def flatten_run(run: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Utility to flatten a nest run object into a list of runs.
:param run: The base run to flatten.
:return: The flattened list of runs.
"""
def flatten(child_runs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Utility to recursively flatten a list of child runs in a run.
:param child_runs: The list of child runs to flatten.
:return: The flattened list of runs.
"""
if child_runs is None:
return []
result = []
for item in child_runs:
child_runs = item.pop("child_runs", [])
result.append(item)
result.extend(flatten(child_runs))
return result
return flatten([run])
def truncate_run_iterative(
runs: List[Dict[str, Any]], keep_keys: Tuple[str, ...] = ()
) -> List[Dict[str, Any]]:
"""Utility to truncate a list of runs dictionaries to only keep the specified
keys in each run.
:param runs: The list of runs to truncate.
:param keep_keys: The keys to keep in each run.
:return: The truncated list of runs.
"""
def truncate_single(run: Dict[str, Any]) -> Dict[str, Any]:
"""Utility to truncate a single run dictionary to only keep the specified
keys.
:param run: The run dictionary to truncate.
:return: The truncated run dictionary
"""
new_dict = {}
for key in run:
if key in keep_keys:
new_dict[key] = run.get(key)
return new_dict
return list(map(truncate_single, runs))
def modify_serialized_iterative(
runs: List[Dict[str, Any]],
exact_keys: Tuple[str, ...] = (),
partial_keys: Tuple[str, ...] = (),
) -> List[Dict[str, Any]]:
"""Utility to modify the serialized field of a list of runs dictionaries.
removes any keys that match the exact_keys and any keys that contain any of the
partial_keys.
recursively moves the dictionaries under the kwargs key to the top level.
changes the "id" field to a string "_kind" field that tells WBTraceTree how to
visualize the run. promotes the "serialized" field to the top level.
:param runs: The list of runs to modify.
:param exact_keys: A tuple of keys to remove from the serialized field.
:param partial_keys: A tuple of partial keys to remove from the serialized
field.
:return: The modified list of runs.
"""
def remove_exact_and_partial_keys(obj: Dict[str, Any]) -> Dict[str, Any]:
"""Recursively removes exact and partial keys from a dictionary.
:param obj: The dictionary to remove keys from.
:return: The modified dictionary.
"""
if isinstance(obj, dict):
obj = {
k: v
for k, v in obj.items()
if k not in exact_keys
and not any(partial in k for partial in partial_keys)
}
for k, v in obj.items():
obj[k] = remove_exact_and_partial_keys(v)
elif isinstance(obj, list):
obj = [remove_exact_and_partial_keys(x) for x in obj]
return obj
def handle_id_and_kwargs(obj: Dict[str, Any], root: bool = False) -> Dict[str, Any]:
"""Recursively handles the id and kwargs fields of a dictionary.
changes the id field to a string "_kind" field that tells WBTraceTree how
to visualize the run. recursively moves the dictionaries under the kwargs
key to the top level.
:param obj: a run dictionary with id and kwargs fields.
:param root: whether this is the root dictionary or the serialized
dictionary.
:return: The modified dictionary.
"""
if isinstance(obj, dict):
if "data" in obj and isinstance(obj["data"], dict):
obj = obj["data"]
if ("id" in obj or "name" in obj) and not root:
_kind = obj.get("id")
if not _kind:
_kind = [obj.get("name")]
if isinstance(_kind, list):
obj["_kind"] = _kind[-1]
obj.pop("id", None)
obj.pop("name", None)
if "kwargs" in obj:
kwargs = obj.pop("kwargs")
for k, v in kwargs.items():
obj[k] = v
for k, v in obj.items():
obj[k] = handle_id_and_kwargs(v)
elif isinstance(obj, list):
obj = [handle_id_and_kwargs(x) for x in obj]
return obj
def transform_serialized(serialized: Dict[str, Any]) -> Dict[str, Any]:
"""Transforms the serialized field of a run dictionary to be compatible
with WBTraceTree.
:param serialized: The serialized field of a run dictionary.
:return: The transformed serialized field.
"""
serialized = handle_id_and_kwargs(serialized, root=True)
serialized = remove_exact_and_partial_keys(serialized)
return serialized
def transform_run(run: Dict[str, Any]) -> Dict[str, Any]:
"""Transforms a run dictionary to be compatible with WBTraceTree.
:param run: The run dictionary to transform.
:return: The transformed run dictionary.
"""
transformed_dict = transform_serialized(run)
serialized = transformed_dict.pop("serialized")
for k, v in serialized.items():
transformed_dict[k] = v
_kind = transformed_dict.get("_kind", None)
name = transformed_dict.pop("name", None)
if not name:
name = _kind
output_dict = {
f"{name}": transformed_dict,
}
return output_dict
return list(map(transform_run, runs))
def build_tree(runs: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Builds a nested dictionary from a list of runs.
:param runs: The list of runs to build the tree from.
:return: The nested dictionary representing the langchain Run in a tree
structure compatible with WBTraceTree.
"""
id_to_data = {}
child_to_parent = {}
for entity in runs:
for key, data in entity.items():
id_val = data.pop("id", None)
parent_run_id = data.pop("parent_run_id", None)
id_to_data[id_val] = {key: data}
if parent_run_id:
child_to_parent[id_val] = parent_run_id
for child_id, parent_id in child_to_parent.items():
parent_dict = id_to_data[parent_id]
parent_dict[next(iter(parent_dict))][next(iter(id_to_data[child_id]))] = (
id_to_data[child_id][next(iter(id_to_data[child_id]))]
)
root_dict = next(
data for id_val, data in id_to_data.items() if id_val not in child_to_parent
)
return root_dict
class WandbRunArgs(TypedDict):
"""Arguments for the WandbTracer."""
job_type: Optional[str]
dir: Optional[StrPath]
config: Union[Dict, str, None]
project: Optional[str]
entity: Optional[str]
reinit: Optional[bool]
tags: Optional[Sequence]
group: Optional[str]
name: Optional[str]
notes: Optional[str]
magic: Optional[Union[dict, str, bool]]
config_exclude_keys: Optional[List[str]]
config_include_keys: Optional[List[str]]
anonymous: Optional[str]
mode: Optional[str]
allow_val_change: Optional[bool]
resume: Optional[Union[bool, str]]
force: Optional[bool]
tensorboard: Optional[bool]
sync_tensorboard: Optional[bool]
monitor_gym: Optional[bool]
save_code: Optional[bool]
id: Optional[str]
settings: Union[WBSettings, Dict[str, Any], None]
class WandbTracer(BaseTracer):
"""Callback Handler that logs to Weights and Biases.
This handler will log the model architecture and run traces to Weights and Biases.
This will ensure that all LangChain activity is logged to W&B.
"""
_run: Optional[WBRun] = None
_run_args: Optional[WandbRunArgs] = None
def __init__(
self,
run_args: Optional[WandbRunArgs] = None,
io_serializer: Callable = _serialize_io,
**kwargs: Any,
) -> None:
"""Initializes the WandbTracer.
Parameters:
run_args: (dict, optional) Arguments to pass to `wandb.init()`. If not
provided, `wandb.init()` will be called with no arguments. Please
refer to the `wandb.init` for more details.
io_serializer: callable A function that serializes the input and outputs
of a run to store in wandb. Defaults to "_serialize_io"
To use W&B to monitor all LangChain activity, add this tracer like any other
LangChain callback:
```
from wandb.integration.langchain import WandbTracer
tracer = WandbTracer()
chain = LLMChain(llm, callbacks=[tracer])
# ...end of notebook / script:
tracer.finish()
```
"""
super().__init__(**kwargs)
try:
import wandb
from wandb.sdk.data_types import trace_tree
except ImportError as e:
raise ImportError(
"Could not import wandb python package."
"Please install it with `pip install -U wandb`."
) from e
self._wandb = wandb
self._trace_tree = trace_tree
self._run_args = run_args
self._ensure_run(should_print_url=(wandb.run is None))
self._io_serializer = io_serializer
warn_deprecated(
"0.3.8",
pending=False,
message=(
"Please use the `WeaveTracer` from the `weave` package instead of this."
"The `WeaveTracer` is a more flexible and powerful tool for logging "
"and tracing your LangChain callables."
"Find more information at https://weave-docs.wandb.ai/guides/integrations/langchain"
),
alternative=(
"Please instantiate the WeaveTracer from "
"`weave.integrations.langchain import WeaveTracer` ."
"For autologging simply use `weave.init()` and log all traces "
"from your LangChain callables."
),
)
def finish(self) -> None:
"""Waits for all asynchronous processes to finish and data to upload.
Proxy for `wandb.finish()`.
"""
self._wandb.finish()
def _ensure_run(self, should_print_url: bool = False) -> None:
"""Ensures an active W&B run exists.
If not, will start a new run with the provided run_args.
"""
if self._wandb.run is None:
run_args: Dict = {**(self._run_args or {})}
if "settings" not in run_args:
run_args["settings"] = {"silent": True}
self._wandb.init(**run_args)
if self._wandb.run is not None:
if should_print_url:
run_url = self._wandb.run.settings.run_url
self._wandb.termlog(
f"Streaming LangChain activity to W&B at {run_url}\n"
"`WandbTracer` is currently in beta.\n"
"Please report any issues to "
"https://github.com/wandb/wandb/issues with the tag "
"`langchain`."
)
self._wandb.run._label(repo="langchain")
def process_model_dict(self, run: Run) -> Optional[Dict[str, Any]]:
"""Utility to process a run for wandb model_dict serialization.
:param run: The run to process.
:return: The convert model_dict to pass to WBTraceTree.
"""
try:
data = json.loads(run.json())
processed = flatten_run(data)
keep_keys = (
"id",
"name",
"serialized",
"parent_run_id",
)
processed = truncate_run_iterative(processed, keep_keys=keep_keys)
exact_keys, partial_keys = (
("lc", "type", "graph"),
(
"api_key",
"input",
"output",
),
)
processed = modify_serialized_iterative(
processed, exact_keys=exact_keys, partial_keys=partial_keys
)
output = build_tree(processed)
return output
except Exception as e:
if PRINT_WARNINGS:
self._wandb.termerror(f"WARNING: Failed to serialize model: {e}")
return None
def _log_trace_from_run(self, run: Run) -> None:
"""Logs a LangChain Run to W*B as a W&B Trace."""
self._ensure_run()
def create_trace(
run: "Run", parent: Optional["Trace"] = None
) -> Optional["Trace"]:
"""
Create a trace for a given run and its child runs.
Args:
run (Run): The run for which to create a trace.
parent (Optional[Trace]): The parent trace.
If provided, the created trace is added as a child to the parent trace.
Returns:
The created trace. If an error occurs during the creation of the trace,
None is returned.
Raises:
Exception: If an error occurs during the creation of the trace,
no exception is raised and a warning is printed.
"""
def get_metadata_dict(r: "Run") -> Dict[str, Any]:
"""
Extract metadata from a given run.
This function extracts metadata from a given run
and returns it as a dictionary.
Args:
r (Run): The run from which to extract metadata.
Returns:
`dict` containing the extracted metadata.
"""
run_dict = json.loads(r.json())
metadata_dict = run_dict.get("metadata", {})
metadata_dict["run_id"] = run_dict.get("id")
metadata_dict["parent_run_id"] = run_dict.get("parent_run_id")
metadata_dict["tags"] = run_dict.get("tags")
metadata_dict["execution_order"] = run_dict.get(
"dotted_order", ""
).count(".")
return metadata_dict
try:
if run.run_type in ["llm", "tool"]:
run_type = run.run_type
elif run.run_type == "chain":
run_type = "agent" if "agent" in run.name.lower() else "chain"
else:
run_type = None
metadata = get_metadata_dict(run)
trace_tree = self._trace_tree.Trace(
name=run.name,
kind=run_type,
status_code="error" if run.error else "success",
start_time_ms=int(run.start_time.timestamp() * 1000)
if run.start_time is not None
else None,
end_time_ms=int(run.end_time.timestamp() * 1000)
if run.end_time is not None
else None,
metadata=metadata,
inputs=self._io_serializer(run.inputs),
outputs=self._io_serializer(run.outputs),
)
# If the run has child runs, recursively create traces for them
for child_run in run.child_runs:
create_trace(child_run, trace_tree)
if parent is None:
return trace_tree
else:
parent.add_child(trace_tree)
return parent
except Exception as e:
if PRINT_WARNINGS:
self._wandb.termwarn(
f"WARNING: Failed to serialize trace for run due to: {e}"
)
return None
run_trace = create_trace(run)
model_dict = self.process_model_dict(run)
if model_dict is not None and run_trace is not None:
run_trace._model_dict = model_dict
if self._wandb.run is not None and run_trace is not None:
run_trace.log("langchain_trace")
def _persist_run(self, run: "Run") -> None:
"""Persist a run."""
self._log_trace_from_run(run)

View File

@@ -0,0 +1,125 @@
import os
from typing import Any, Dict, List, Optional
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import LLMResult
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise TypeError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
class TrubricsCallbackHandler(BaseCallbackHandler):
"""
Callback handler for Trubrics.
Args:
project: a trubrics project, default project is "default"
email: a trubrics account email, can equally be set in env variables
password: a trubrics account password, can equally be set in env variables
**kwargs: all other kwargs are parsed and set to trubrics prompt variables,
or added to the `metadata` dict
"""
def __init__(
self,
project: str = "default",
email: Optional[str] = None,
password: Optional[str] = None,
**kwargs: Any,
) -> None:
super().__init__()
try:
from trubrics import Trubrics
except ImportError:
raise ImportError(
"The TrubricsCallbackHandler requires installation of "
"the trubrics package. "
"Please install it with `pip install trubrics`."
)
self.trubrics = Trubrics(
project=project,
email=email or os.environ["TRUBRICS_EMAIL"],
password=password or os.environ["TRUBRICS_PASSWORD"],
)
self.config_model: dict = {}
self.prompt: Optional[str] = None
self.messages: Optional[list] = None
self.trubrics_kwargs: Optional[dict] = kwargs if kwargs else None
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
self.prompt = prompts[0]
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any,
) -> None:
self.messages = [_convert_message_to_dict(message) for message in messages[0]]
self.prompt = self.messages[-1]["content"]
def on_llm_end(self, response: LLMResult, run_id: UUID, **kwargs: Any) -> None:
tags = ["langchain"]
user_id = None
session_id = None
metadata: dict = {"langchain_run_id": run_id}
if self.messages:
metadata["messages"] = self.messages
if self.trubrics_kwargs:
if self.trubrics_kwargs.get("tags"):
tags.append(*self.trubrics_kwargs.pop("tags"))
user_id = self.trubrics_kwargs.pop("user_id", None)
session_id = self.trubrics_kwargs.pop("session_id", None)
metadata.update(self.trubrics_kwargs)
for generation in response.generations:
self.trubrics.log_prompt(
config_model={
"model": response.llm_output.get("model_name")
if response.llm_output
else "NA"
},
prompt=self.prompt,
generation=generation[0].text,
user_id=user_id,
session_id=session_id,
tags=tags,
metadata=metadata,
)

View File

@@ -0,0 +1,206 @@
"""Ratelimiting Handler to limit requests or tokens"""
import logging
from typing import Any, Dict, List, Literal, Optional
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
logger = logging.getLogger(__name__)
try:
from upstash_ratelimit import Ratelimit
except ImportError:
Ratelimit = None
class UpstashRatelimitError(Exception):
"""
Upstash Ratelimit Error
Raised when the rate limit is reached in `UpstashRatelimitHandler`
"""
def __init__(
self,
message: str,
type: Literal["token", "request"],
limit: Optional[int] = None,
reset: Optional[float] = None,
):
"""
Args:
message (str): error message
type (str): The kind of the limit which was reached. One of
"token" or "request"
limit (Optional[int]): The limit which was reached. Passed when type
is request
reset (Optional[int]): unix timestamp in milliseconds when the limits
are reset. Passed when type is request
"""
# Call the base class constructor with the parameters it needs
super().__init__(message)
self.type = type
self.limit = limit
self.reset = reset
class UpstashRatelimitHandler(BaseCallbackHandler):
"""
Callback to handle rate limiting based on the number of requests
or the number of tokens in the input.
It uses Upstash Ratelimit to track the ratelimit which utilizes
Upstash Redis to track the state.
Should not be passed to the chain when initialising the chain.
This is because the handler has a state which should be fresh
every time invoke is called. Instead, initialise and pass a handler
every time you invoke.
"""
raise_error: bool = True
_checked: bool = False
def __init__(
self,
identifier: str,
*,
token_ratelimit: Optional[Ratelimit] = None,
request_ratelimit: Optional[Ratelimit] = None,
include_output_tokens: bool = False,
):
"""
Creates UpstashRatelimitHandler. Must be passed an identifier to
ratelimit like a user id or an ip address.
Additionally, it must be passed at least one of token_ratelimit
or request_ratelimit parameters.
Args:
identifier Union[int, str]: the identifier
token_ratelimit Optional[Ratelimit]: Ratelimit to limit the
number of tokens. Only works with OpenAI models since only
these models provide the number of tokens as information
in their output.
request_ratelimit Optional[Ratelimit]: Ratelimit to limit the
number of requests
include_output_tokens bool: Whether to count output tokens when
rate limiting based on number of tokens. Only used when
`token_ratelimit` is passed. False by default.
Example:
.. code-block:: python
from upstash_redis import Redis
from upstash_ratelimit import Ratelimit, FixedWindow
redis = Redis.from_env()
ratelimit = Ratelimit(
redis=redis,
# fixed window to allow 10 requests every 10 seconds:
limiter=FixedWindow(max_requests=10, window=10),
)
user_id = "foo"
handler = UpstashRatelimitHandler(
identifier=user_id,
request_ratelimit=ratelimit
)
# Initialize a simple runnable to test
chain = RunnableLambda(str)
# pass handler as callback:
output = chain.invoke(
"input",
config={
"callbacks": [handler]
}
)
"""
if not any([token_ratelimit, request_ratelimit]):
raise ValueError(
"You must pass at least one of input_token_ratelimit or"
" request_ratelimit parameters for handler to work."
)
self.identifier = identifier
self.token_ratelimit = token_ratelimit
self.request_ratelimit = request_ratelimit
self.include_output_tokens = include_output_tokens
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> Any:
"""
Run when chain starts running.
on_chain_start runs multiple times during a chain execution. To make
sure that it's only called once, we keep a bool state `_checked`. If
not `self._checked`, we call limit with `request_ratelimit` and raise
`UpstashRatelimitError` if the identifier is rate limited.
"""
if self.request_ratelimit and not self._checked:
response = self.request_ratelimit.limit(self.identifier)
if not response.allowed:
raise UpstashRatelimitError(
"Request limit reached!", "request", response.limit, response.reset
)
self._checked = True
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""
Run when LLM starts running
"""
if self.token_ratelimit:
remaining = self.token_ratelimit.get_remaining(self.identifier)
if remaining <= 0:
raise UpstashRatelimitError("Token limit reached!", "token")
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""
Run when LLM ends running
If the `include_output_tokens` is set to True, number of tokens
in LLM completion are counted for rate limiting
"""
if self.token_ratelimit:
try:
llm_output = response.llm_output or {}
token_usage = llm_output["token_usage"]
token_count = (
token_usage["total_tokens"]
if self.include_output_tokens
else token_usage["prompt_tokens"]
)
except KeyError:
raise ValueError(
"LLM response doesn't include"
" `token_usage: {total_tokens: int, prompt_tokens: int}`"
" field. To use UpstashRatelimitHandler with token_ratelimit,"
" either use a model which returns token_usage (like "
" OpenAI models) or rate limit only with request_ratelimit."
)
# call limit to add the completion tokens to rate limit
# but don't raise exception since we already generated
# the tokens and would rather continue execution.
self.token_ratelimit.limit(self.identifier, rate=token_count)
def reset(self, identifier: Optional[str] = None) -> "UpstashRatelimitHandler":
"""
Creates a new UpstashRatelimitHandler object with the same
ratelimit configurations but with a new identifier if it's
provided.
Also resets the state of the handler.
"""
return UpstashRatelimitHandler(
identifier=identifier or self.identifier,
token_ratelimit=self.token_ratelimit,
request_ratelimit=self.request_ratelimit,
include_output_tokens=self.include_output_tokens,
)

View File

@@ -0,0 +1,384 @@
"""
UpTrain Callback Handler
UpTrain is an open-source platform to evaluate and improve LLM applications. It provides
grades for 20+ preconfigured checks (covering language, code, embedding use cases),
performs root cause analyses on instances of failure cases and provides guidance for
resolving them.
This module contains a callback handler for integrating UpTrain seamlessly into your
pipeline and facilitating diverse evaluations. The callback handler automates various
evaluations to assess the performance and effectiveness of the components within the
pipeline.
The evaluations conducted include:
1. RAG:
- Context Relevance: Determines the relevance of the context extracted from the query
to the response.
- Factual Accuracy: Assesses if the Language Model (LLM) is providing accurate
information or hallucinating.
- Response Completeness: Checks if the response contains all the information
requested by the query.
2. Multi Query Generation:
MultiQueryRetriever generates multiple variants of a question with similar meanings
to the original question. This evaluation includes previous assessments and adds:
- Multi Query Accuracy: Ensures that the multi-queries generated convey the same
meaning as the original query.
3. Context Compression and Reranking:
Re-ranking involves reordering nodes based on relevance to the query and selecting
top n nodes.
Due to the potential reduction in the number of nodes after re-ranking, the following
evaluations
are performed in addition to the RAG evaluations:
- Context Reranking: Determines if the order of re-ranked nodes is more relevant to
the query than the original order.
- Context Conciseness: Examines whether the reduced number of nodes still provides
all the required information.
These evaluations collectively ensure the robustness and effectiveness of the RAG query
engine, MultiQueryRetriever, and the re-ranking process within the pipeline.
Useful links:
Github: https://github.com/uptrain-ai/uptrain
Website: https://uptrain.ai/
Docs: https://docs.uptrain.ai/getting-started/introduction
"""
import logging
import sys
from collections import defaultdict
from typing import (
Any,
DefaultDict,
Dict,
List,
Optional,
Sequence,
Set,
)
from uuid import UUID
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.documents import Document
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import
logger = logging.getLogger(__name__)
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
def import_uptrain() -> Any:
"""Import the `uptrain` package."""
return guard_import("uptrain")
class UpTrainDataSchema:
"""The UpTrain data schema for tracking evaluation results.
Args:
project_name (str): The project name to be shown in UpTrain dashboard.
Attributes:
project_name (str): The project name to be shown in UpTrain dashboard.
uptrain_results (DefaultDict[str, Any]): Dictionary to store evaluation results.
eval_types (Set[str]): Set to store the types of evaluations.
query (str): Query for the RAG evaluation.
context (str): Context for the RAG evaluation.
response (str): Response for the RAG evaluation.
old_context (List[str]): Old context nodes for Context Conciseness evaluation.
new_context (List[str]): New context nodes for Context Conciseness evaluation.
context_conciseness_run_id (str): Run ID for Context Conciseness evaluation.
multi_queries (List[str]): List of multi queries for Multi Query evaluation.
multi_query_run_id (str): Run ID for Multi Query evaluation.
multi_query_daugher_run_id (str): Run ID for Multi Query daughter evaluation.
"""
def __init__(self, project_name: str) -> None:
"""Initialize the UpTrain data schema."""
# For tracking project name and results
self.project_name: str = project_name
self.uptrain_results: DefaultDict[str, Any] = defaultdict(list)
# For tracking event types
self.eval_types: Set[str] = set()
## RAG
self.query: str = ""
self.context: str = ""
self.response: str = ""
## CONTEXT CONCISENESS
self.old_context: List[str] = []
self.new_context: List[str] = []
self.context_conciseness_run_id: UUID = UUID(int=0)
# MULTI QUERY
self.multi_queries: List[str] = []
self.multi_query_run_id: UUID = UUID(int=0)
self.multi_query_daugher_run_id: UUID = UUID(int=0)
class UpTrainCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs evaluation results to uptrain and the console.
Args:
project_name (str): The project name to be shown in UpTrain dashboard.
key_type (str): Type of key to use. Must be 'uptrain' or 'openai'.
api_key (str): API key for the UpTrain or OpenAI API.
(This key is required to perform evaluations using GPT.)
Raises:
ValueError: If the key type is invalid.
ImportError: If the `uptrain` package is not installed.
"""
def __init__(
self,
*,
project_name: str = "langchain",
key_type: str = "openai",
api_key: str = "sk-****************", # The API key to use for evaluation
model: str = "gpt-3.5-turbo", # The model to use for evaluation
log_results: bool = True,
) -> None:
"""Initializes the `UpTrainCallbackHandler`."""
super().__init__()
uptrain = import_uptrain()
self.log_results = log_results
# Set uptrain variables
self.schema = UpTrainDataSchema(project_name=project_name)
self.first_score_printed_flag = False
if key_type == "uptrain":
settings = uptrain.Settings(uptrain_access_token=api_key, model=model)
self.uptrain_client = uptrain.APIClient(settings=settings)
elif key_type == "openai":
settings = uptrain.Settings(
openai_api_key=api_key, evaluate_locally=True, model=model
)
self.uptrain_client = uptrain.EvalLLM(settings=settings)
else:
raise ValueError("Invalid key type: Must be 'uptrain' or 'openai'")
def uptrain_evaluate(
self,
evaluation_name: str,
data: List[Dict[str, Any]],
checks: List[str],
) -> None:
"""Run an evaluation on the UpTrain server using UpTrain client."""
if self.uptrain_client.__class__.__name__ == "APIClient":
uptrain_result = self.uptrain_client.log_and_evaluate(
project_name=self.schema.project_name,
evaluation_name=evaluation_name,
data=data,
checks=checks,
)
else:
uptrain_result = self.uptrain_client.evaluate(
project_name=self.schema.project_name,
evaluation_name=evaluation_name,
data=data,
checks=checks,
)
self.schema.uptrain_results[self.schema.project_name].append(uptrain_result)
score_name_map = {
"score_context_relevance": "Context Relevance Score",
"score_factual_accuracy": "Factual Accuracy Score",
"score_response_completeness": "Response Completeness Score",
"score_sub_query_completeness": "Sub Query Completeness Score",
"score_context_reranking": "Context Reranking Score",
"score_context_conciseness": "Context Conciseness Score",
"score_multi_query_accuracy": "Multi Query Accuracy Score",
}
if self.log_results:
# Set logger level to INFO to print the evaluation results
logger.setLevel(logging.INFO)
for row in uptrain_result:
columns = list(row.keys())
for column in columns:
if column == "question":
logger.info(f"\nQuestion: {row[column]}")
self.first_score_printed_flag = False
elif column == "response":
logger.info(f"Response: {row[column]}")
self.first_score_printed_flag = False
elif column == "variants":
logger.info("Multi Queries:")
for variant in row[column]:
logger.info(f" - {variant}")
self.first_score_printed_flag = False
elif column.startswith("score"):
if not self.first_score_printed_flag:
logger.info("")
self.first_score_printed_flag = True
if column in score_name_map:
logger.info(f"{score_name_map[column]}: {row[column]}")
else:
logger.info(f"{column}: {row[column]}")
if self.log_results:
# Set logger level back to WARNING
# (We are doing this to avoid printing the logs from HTTP requests)
logger.setLevel(logging.WARNING)
def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Log records to uptrain when an LLM ends."""
uptrain = import_uptrain()
self.schema.response = response.generations[0][0].text
if (
"qa_rag" in self.schema.eval_types
and parent_run_id != self.schema.multi_query_daugher_run_id
):
data = [
{
"question": self.schema.query,
"context": self.schema.context,
"response": self.schema.response,
}
]
self.uptrain_evaluate(
evaluation_name="rag",
data=data,
checks=[
uptrain.Evals.CONTEXT_RELEVANCE,
uptrain.Evals.FACTUAL_ACCURACY,
uptrain.Evals.RESPONSE_COMPLETENESS,
],
)
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
run_type: Optional[str] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing when chain starts"""
if parent_run_id == self.schema.multi_query_run_id:
self.schema.multi_query_daugher_run_id = run_id
if isinstance(inputs, dict) and set(inputs.keys()) == {"context", "question"}:
self.schema.eval_types.add("qa_rag")
context = ""
if isinstance(inputs["context"], Document):
context = inputs["context"].page_content
elif isinstance(inputs["context"], list):
for doc in inputs["context"]:
context += doc.page_content + "\n"
elif isinstance(inputs["context"], str):
context = inputs["context"]
self.schema.context = context
self.schema.query = inputs["question"]
pass
def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
if "contextual_compression" in serialized["id"]:
self.schema.eval_types.add("contextual_compression")
self.schema.query = query
self.schema.context_conciseness_run_id = run_id
if "multi_query" in serialized["id"]:
self.schema.eval_types.add("multi_query")
self.schema.multi_query_run_id = run_id
self.schema.query = query
elif "multi_query" in self.schema.eval_types:
self.schema.multi_queries.append(query)
def on_retriever_end(
self,
documents: Sequence[Document],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever ends running."""
uptrain = import_uptrain()
if run_id == self.schema.multi_query_run_id:
data = [
{
"question": self.schema.query,
"variants": self.schema.multi_queries,
}
]
self.uptrain_evaluate(
evaluation_name="multi_query",
data=data,
checks=[uptrain.Evals.MULTI_QUERY_ACCURACY],
)
if "contextual_compression" in self.schema.eval_types:
if parent_run_id == self.schema.context_conciseness_run_id:
for doc in documents:
self.schema.old_context.append(doc.page_content)
elif run_id == self.schema.context_conciseness_run_id:
for doc in documents:
self.schema.new_context.append(doc.page_content)
context = "\n".join(
[
f"{index}. {string}"
for index, string in enumerate(self.schema.old_context, start=1)
]
)
reranked_context = "\n".join(
[
f"{index}. {string}"
for index, string in enumerate(self.schema.new_context, start=1)
]
)
data = [
{
"question": self.schema.query,
"context": context,
"concise_context": reranked_context,
"reranked_context": reranked_context,
}
]
self.uptrain_evaluate(
evaluation_name="context_reranking",
data=data,
checks=[
uptrain.Evals.CONTEXT_CONCISENESS,
uptrain.Evals.CONTEXT_RERANKING,
],
)

View File

@@ -0,0 +1,239 @@
import hashlib
from pathlib import Path
from typing import Any, Dict, Iterable, Tuple, Union
from langchain_core.utils import guard_import
def import_spacy() -> Any:
"""Import the spacy python package and raise an error if it is not installed."""
return guard_import("spacy")
def import_pandas() -> Any:
"""Import the pandas python package and raise an error if it is not installed."""
return guard_import("pandas")
def import_textstat() -> Any:
"""Import the textstat python package and raise an error if it is not installed."""
return guard_import("textstat")
def _flatten_dict(
nested_dict: Dict[str, Any], parent_key: str = "", sep: str = "_"
) -> Iterable[Tuple[str, Any]]:
"""
Generator that yields flattened items from a nested dictionary for a flat dict.
Parameters:
nested_dict (dict): The nested dictionary to flatten.
parent_key (str): The prefix to prepend to the keys of the flattened dict.
sep (str): The separator to use between the parent key and the key of the
flattened dictionary.
Yields:
(str, any): A key-value pair from the flattened dictionary.
"""
for key, value in nested_dict.items():
new_key = parent_key + sep + key if parent_key else key
if isinstance(value, dict):
yield from _flatten_dict(value, new_key, sep)
else:
yield new_key, value
def flatten_dict(
nested_dict: Dict[str, Any], parent_key: str = "", sep: str = "_"
) -> Dict[str, Any]:
"""Flatten a nested dictionary into a flat dictionary.
Parameters:
nested_dict (dict): The nested dictionary to flatten.
parent_key (str): The prefix to prepend to the keys of the flattened dict.
sep (str): The separator to use between the parent key and the key of the
flattened dictionary.
Returns:
(dict): A flat dictionary.
"""
flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
return flat_dict
def hash_string(s: str) -> str:
"""Hash a string using sha1.
Parameters:
s (str): The string to hash.
Returns:
(str): The hashed string.
"""
return hashlib.sha1(s.encode("utf-8")).hexdigest()
def load_json(json_path: Union[str, Path]) -> str:
"""Load json file to a string.
Parameters:
json_path (str): The path to the json file.
Returns:
(str): The string representation of the json file.
"""
with open(json_path, "r") as f:
data = f.read()
return data
class BaseMetadataCallbackHandler:
"""Handle the metadata and associated function states for callbacks.
Attributes:
step (int): The current step.
starts (int): The number of times the start method has been called.
ends (int): The number of times the end method has been called.
errors (int): The number of times the error method has been called.
text_ctr (int): The number of times the text method has been called.
ignore_llm_ (bool): Whether to ignore llm callbacks.
ignore_chain_ (bool): Whether to ignore chain callbacks.
ignore_agent_ (bool): Whether to ignore agent callbacks.
ignore_retriever_ (bool): Whether to ignore retriever callbacks.
always_verbose_ (bool): Whether to always be verbose.
chain_starts (int): The number of times the chain start method has been called.
chain_ends (int): The number of times the chain end method has been called.
llm_starts (int): The number of times the llm start method has been called.
llm_ends (int): The number of times the llm end method has been called.
llm_streams (int): The number of times the text method has been called.
tool_starts (int): The number of times the tool start method has been called.
tool_ends (int): The number of times the tool end method has been called.
agent_ends (int): The number of times the agent end method has been called.
on_llm_start_records (list): A list of records of the on_llm_start method.
on_llm_token_records (list): A list of records of the on_llm_token method.
on_llm_end_records (list): A list of records of the on_llm_end method.
on_chain_start_records (list): A list of records of the on_chain_start method.
on_chain_end_records (list): A list of records of the on_chain_end method.
on_tool_start_records (list): A list of records of the on_tool_start method.
on_tool_end_records (list): A list of records of the on_tool_end method.
on_agent_finish_records (list): A list of records of the on_agent_end method.
"""
def __init__(self) -> None:
self.step = 0
self.starts = 0
self.ends = 0
self.errors = 0
self.text_ctr = 0
self.ignore_llm_ = False
self.ignore_chain_ = False
self.ignore_agent_ = False
self.ignore_retriever_ = False
self.always_verbose_ = False
self.chain_starts = 0
self.chain_ends = 0
self.llm_starts = 0
self.llm_ends = 0
self.llm_streams = 0
self.tool_starts = 0
self.tool_ends = 0
self.agent_ends = 0
self.on_llm_start_records: list = []
self.on_llm_token_records: list = []
self.on_llm_end_records: list = []
self.on_chain_start_records: list = []
self.on_chain_end_records: list = []
self.on_tool_start_records: list = []
self.on_tool_end_records: list = []
self.on_text_records: list = []
self.on_agent_finish_records: list = []
self.on_agent_action_records: list = []
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return self.always_verbose_
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return self.ignore_llm_
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return self.ignore_chain_
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
def get_custom_callback_meta(self) -> Dict[str, Any]:
return {
"step": self.step,
"starts": self.starts,
"ends": self.ends,
"errors": self.errors,
"text_ctr": self.text_ctr,
"chain_starts": self.chain_starts,
"chain_ends": self.chain_ends,
"llm_starts": self.llm_starts,
"llm_ends": self.llm_ends,
"llm_streams": self.llm_streams,
"tool_starts": self.tool_starts,
"tool_ends": self.tool_ends,
"agent_ends": self.agent_ends,
}
def reset_callback_meta(self) -> None:
"""Reset the callback metadata."""
self.step = 0
self.starts = 0
self.ends = 0
self.errors = 0
self.text_ctr = 0
self.ignore_llm_ = False
self.ignore_chain_ = False
self.ignore_agent_ = False
self.always_verbose_ = False
self.chain_starts = 0
self.chain_ends = 0
self.llm_starts = 0
self.llm_ends = 0
self.llm_streams = 0
self.tool_starts = 0
self.tool_ends = 0
self.agent_ends = 0
self.on_llm_start_records = []
self.on_llm_token_records = []
self.on_llm_end_records = []
self.on_chain_start_records = []
self.on_chain_end_records = []
self.on_tool_start_records = []
self.on_tool_end_records = []
self.on_text_records = []
self.on_agent_finish_records = []
self.on_agent_action_records = []
return None

View File

@@ -0,0 +1,597 @@
import json
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union
from langchain_core._api import warn_deprecated
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import
from langchain_community.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
hash_string,
import_pandas,
import_spacy,
import_textstat,
)
def import_wandb() -> Any:
"""Import the wandb python package and raise an error if it is not installed."""
return guard_import("wandb")
def load_json_to_dict(json_path: Union[str, Path]) -> dict:
"""Load json file to a dictionary.
Parameters:
json_path (str): The path to the json file.
Returns:
(dict): The dictionary representation of the json file.
"""
with open(json_path, "r") as f:
data = json.load(f)
return data
def analyze_text(
text: str,
complexity_metrics: bool = True,
visualize: bool = True,
nlp: Any = None,
output_dir: Optional[Union[str, Path]] = None,
) -> dict:
"""Analyze text using textstat and spacy.
Parameters:
text (str): The text to analyze.
complexity_metrics (bool): Whether to compute complexity metrics.
visualize (bool): Whether to visualize the text.
nlp (spacy.lang): The spacy language model to use for visualization.
output_dir (str): The directory to save the visualization files to.
Returns:
`dict` containing the complexity metrics and visualization
files serialized in a wandb.Html element.
"""
resp = {}
textstat = import_textstat()
wandb = import_wandb()
spacy = import_spacy()
if complexity_metrics:
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
"smog_index": textstat.smog_index(text),
"coleman_liau_index": textstat.coleman_liau_index(text),
"automated_readability_index": textstat.automated_readability_index(text),
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
"difficult_words": textstat.difficult_words(text),
"linsear_write_formula": textstat.linsear_write_formula(text),
"gunning_fog": textstat.gunning_fog(text),
"text_standard": textstat.text_standard(text),
"fernandez_huerta": textstat.fernandez_huerta(text),
"szigriszt_pazos": textstat.szigriszt_pazos(text),
"gutierrez_polini": textstat.gutierrez_polini(text),
"crawford": textstat.crawford(text),
"gulpease_index": textstat.gulpease_index(text),
"osman": textstat.osman(text),
}
resp.update(text_complexity_metrics)
if visualize and nlp and output_dir is not None:
doc = nlp(text)
dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
dep_output_path = Path(output_dir, hash_string(f"dep-{text}") + ".html")
dep_output_path.open("w", encoding="utf-8").write(dep_out)
ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
ent_output_path = Path(output_dir, hash_string(f"ent-{text}") + ".html")
ent_output_path.open("w", encoding="utf-8").write(ent_out)
text_visualizations = {
"dependency_tree": wandb.Html(str(dep_output_path)),
"entities": wandb.Html(str(ent_output_path)),
}
resp.update(text_visualizations)
return resp
def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
"""Construct an html element from a prompt and a generation.
Parameters:
prompt (str): The prompt.
generation (str): The generation.
Returns:
(wandb.Html): The html element."""
wandb = import_wandb()
formatted_prompt = prompt.replace("\n", "<br>")
formatted_generation = generation.replace("\n", "<br>")
return wandb.Html(
f"""
<p style="color:black;">{formatted_prompt}:</p>
<blockquote>
<p style="color:green;">
{formatted_generation}
</p>
</blockquote>
""",
inject=False,
)
class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""Callback Handler that logs to Weights and Biases.
Parameters:
job_type (str): The type of job.
project (str): The project to log to.
entity (str): The entity to log to.
tags (list): The tags to log.
group (str): The group to log to.
name (str): The name of the run.
notes (str): The notes to log.
visualize (bool): Whether to visualize the run.
complexity_metrics (bool): Whether to log complexity metrics.
stream_logs (bool): Whether to stream callback actions to W&B
This handler will utilize the associated callback method called and formats
the input of each callback function with metadata regarding the state of LLM run,
and adds the response to the list of records for both the {method}_records and
action. It then logs the response using the run.log() method to Weights and Biases.
"""
def __init__(
self,
job_type: Optional[str] = None,
project: Optional[str] = "langchain_callback_demo",
entity: Optional[str] = None,
tags: Optional[Sequence] = None,
group: Optional[str] = None,
name: Optional[str] = None,
notes: Optional[str] = None,
visualize: bool = False,
complexity_metrics: bool = False,
stream_logs: bool = False,
) -> None:
"""Initialize callback handler."""
wandb = import_wandb()
import_pandas()
import_textstat()
spacy = import_spacy()
super().__init__()
self.job_type = job_type
self.project = project
self.entity = entity
self.tags = tags
self.group = group
self.name = name
self.notes = notes
self.visualize = visualize
self.complexity_metrics = complexity_metrics
self.stream_logs = stream_logs
self.temp_dir = tempfile.TemporaryDirectory()
self.run = wandb.init(
job_type=self.job_type,
project=self.project,
entity=self.entity,
tags=self.tags,
group=self.group,
name=self.name,
notes=self.notes,
)
warning = (
"DEPRECATION: The `WandbCallbackHandler` will soon be deprecated in favor "
"of the `WandbTracer`. Please update your code to use the `WandbTracer` "
"instead."
)
wandb.termwarn(
warning,
repeat=False,
)
self.callback_columns: list = []
self.action_records: list = []
self.complexity_metrics = complexity_metrics
self.visualize = visualize
self.nlp = spacy.load("en_core_web_sm")
warn_deprecated(
"0.3.8",
pending=False,
message=(
"Please use the WeaveTracer instead of the WandbCallbackHandler. "
"The WeaveTracer is a more flexible and powerful tool for logging "
"and tracing your LangChain callables."
"Find more information at https://weave-docs.wandb.ai/guides/integrations/langchain"
),
alternative=(
"Please instantiate the WeaveTracer from "
"weave.integrations.langchain import WeaveTracer ."
"For autologging simply use weave.init() and log all traces "
"from your LangChain callables."
),
)
def _init_resp(self) -> Dict:
return {k: None for k in self.callback_columns}
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
self.step += 1
self.llm_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_llm_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
for prompt in prompts:
prompt_resp = deepcopy(resp)
prompt_resp["prompts"] = prompt
self.on_llm_start_records.append(prompt_resp)
self.action_records.append(prompt_resp)
if self.stream_logs:
self.run.log(prompt_resp)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.step += 1
self.llm_streams += 1
resp = self._init_resp()
resp.update({"action": "on_llm_new_token", "token": token})
resp.update(self.get_custom_callback_meta())
self.on_llm_token_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.run.log(resp)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.step += 1
self.llm_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_llm_end"})
resp.update(flatten_dict(response.llm_output or {}))
resp.update(self.get_custom_callback_meta())
for generations in response.generations:
for generation in generations:
generation_resp = deepcopy(resp)
generation_resp.update(flatten_dict(generation.dict()))
generation_resp.update(
analyze_text(
generation.text,
complexity_metrics=self.complexity_metrics,
visualize=self.visualize,
nlp=self.nlp,
output_dir=self.temp_dir.name,
)
)
self.on_llm_end_records.append(generation_resp)
self.action_records.append(generation_resp)
if self.stream_logs:
self.run.log(generation_resp)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.step += 1
self.chain_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
chain_input = inputs["input"]
if isinstance(chain_input, str):
input_resp = deepcopy(resp)
input_resp["input"] = chain_input
self.on_chain_start_records.append(input_resp)
self.action_records.append(input_resp)
if self.stream_logs:
self.run.log(input_resp)
elif isinstance(chain_input, list):
for inp in chain_input:
input_resp = deepcopy(resp)
input_resp.update(inp)
self.on_chain_start_records.append(input_resp)
self.action_records.append(input_resp)
if self.stream_logs:
self.run.log(input_resp)
else:
raise ValueError("Unexpected data format provided!")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.step += 1
self.chain_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_chain_end", "outputs": outputs["output"]})
resp.update(self.get_custom_callback_meta())
self.on_chain_end_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.run.log(resp)
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_tool_start", "input_str": input_str})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
self.on_tool_start_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.run.log(resp)
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
output = str(output)
self.step += 1
self.tool_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_tool_end", "output": output})
resp.update(self.get_custom_callback_meta())
self.on_tool_end_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.run.log(resp)
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.step += 1
self.text_ctr += 1
resp = self._init_resp()
resp.update({"action": "on_text", "text": text})
resp.update(self.get_custom_callback_meta())
self.on_text_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.run.log(resp)
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.step += 1
self.agent_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update(
{
"action": "on_agent_finish",
"output": finish.return_values["output"],
"log": finish.log,
}
)
resp.update(self.get_custom_callback_meta())
self.on_agent_finish_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.run.log(resp)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update(
{
"action": "on_agent_action",
"tool": action.tool,
"tool_input": action.tool_input,
"log": action.log,
}
)
resp.update(self.get_custom_callback_meta())
self.on_agent_action_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.run.log(resp)
def _create_session_analysis_df(self) -> Any:
"""Create a dataframe with all the information from the session."""
pd = import_pandas()
on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
llm_input_prompts_df = (
on_llm_start_records_df[["step", "prompts", "name"]]
.dropna(axis=1)
.rename({"step": "prompt_step"}, axis=1)
)
complexity_metrics_columns = []
visualizations_columns = []
if self.complexity_metrics:
complexity_metrics_columns = [
"flesch_reading_ease",
"flesch_kincaid_grade",
"smog_index",
"coleman_liau_index",
"automated_readability_index",
"dale_chall_readability_score",
"difficult_words",
"linsear_write_formula",
"gunning_fog",
"text_standard",
"fernandez_huerta",
"szigriszt_pazos",
"gutierrez_polini",
"crawford",
"gulpease_index",
"osman",
]
if self.visualize:
visualizations_columns = ["dependency_tree", "entities"]
llm_outputs_df = (
on_llm_end_records_df[
[
"step",
"text",
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
]
+ complexity_metrics_columns
+ visualizations_columns
]
.dropna(axis=1)
.rename({"step": "output_step", "text": "output"}, axis=1)
)
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
session_analysis_df["chat_html"] = session_analysis_df[
["prompts", "output"]
].apply(
lambda row: construct_html_from_prompt_and_generation(
row["prompts"], row["output"]
),
axis=1,
)
return session_analysis_df
def flush_tracker(
self,
langchain_asset: Any = None,
reset: bool = True,
finish: bool = False,
job_type: Optional[str] = None,
project: Optional[str] = None,
entity: Optional[str] = None,
tags: Optional[Sequence] = None,
group: Optional[str] = None,
name: Optional[str] = None,
notes: Optional[str] = None,
visualize: Optional[bool] = None,
complexity_metrics: Optional[bool] = None,
) -> None:
"""Flush the tracker and reset the session.
Args:
langchain_asset: The langchain asset to save.
reset: Whether to reset the session.
finish: Whether to finish the run.
job_type: The job type.
project: The project.
entity: The entity.
tags: The tags.
group: The group.
name: The name.
notes: The notes.
visualize: Whether to visualize.
complexity_metrics: Whether to compute complexity metrics.
Returns:
None
"""
pd = import_pandas()
wandb = import_wandb()
action_records_table = wandb.Table(dataframe=pd.DataFrame(self.action_records))
session_analysis_table = wandb.Table(
dataframe=self._create_session_analysis_df()
)
self.run.log(
{
"action_records": action_records_table,
"session_analysis": session_analysis_table,
}
)
if langchain_asset:
langchain_asset_path = Path(self.temp_dir.name, "model.json")
model_artifact = wandb.Artifact(name="model", type="model")
model_artifact.add(action_records_table, name="action_records")
model_artifact.add(session_analysis_table, name="session_analysis")
try:
langchain_asset.save(langchain_asset_path)
model_artifact.add_file(str(langchain_asset_path))
model_artifact.metadata = load_json_to_dict(langchain_asset_path)
except ValueError:
langchain_asset.save_agent(langchain_asset_path)
model_artifact.add_file(str(langchain_asset_path))
model_artifact.metadata = load_json_to_dict(langchain_asset_path)
except NotImplementedError as e:
print("Could not save model.") # noqa: T201
print(repr(e)) # noqa: T201
pass
self.run.log_artifact(model_artifact)
if finish or reset:
self.run.finish()
self.temp_dir.cleanup()
self.reset_callback_meta()
if reset:
self.__init__( # type: ignore[misc]
job_type=job_type if job_type else self.job_type,
project=project if project else self.project,
entity=entity if entity else self.entity,
tags=tags if tags else self.tags,
group=group if group else self.group,
name=name if name else self.name,
notes=notes if notes else self.notes,
visualize=visualize if visualize else self.visualize,
complexity_metrics=(
complexity_metrics
if complexity_metrics
else self.complexity_metrics
),
)

View File

@@ -0,0 +1,187 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Optional
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.utils import get_from_env, guard_import
if TYPE_CHECKING:
from whylogs.api.logger.logger import Logger
diagnostic_logger = logging.getLogger(__name__)
def import_langkit(
sentiment: bool = False,
toxicity: bool = False,
themes: bool = False,
) -> Any:
"""Import the langkit python package and raise an error if it is not installed.
Args:
sentiment: Whether to import the langkit.sentiment module. Defaults to False.
toxicity: Whether to import the langkit.toxicity module. Defaults to False.
themes: Whether to import the langkit.themes module. Defaults to False.
Returns:
The imported langkit module.
"""
langkit = guard_import("langkit")
guard_import("langkit.regexes")
guard_import("langkit.textstat")
if sentiment:
guard_import("langkit.sentiment")
if toxicity:
guard_import("langkit.toxicity")
if themes:
guard_import("langkit.themes")
return langkit
class WhyLabsCallbackHandler(BaseCallbackHandler):
"""
Callback Handler for logging to WhyLabs. This callback handler utilizes
`langkit` to extract features from the prompts & responses when interacting with
an LLM. These features can be used to guardrail, evaluate, and observe interactions
over time to detect issues relating to hallucinations, prompt engineering,
or output validation. LangKit is an LLM monitoring toolkit developed by WhyLabs.
Here are some examples of what can be monitored with LangKit:
* Text Quality
- readability score
- complexity and grade scores
* Text Relevance
- Similarity scores between prompt/responses
- Similarity scores against user-defined themes
- Topic classification
* Security and Privacy
- patterns - count of strings matching a user-defined regex pattern group
- jailbreaks - similarity scores with respect to known jailbreak attempts
- prompt injection - similarity scores with respect to known prompt attacks
- refusals - similarity scores with respect to known LLM refusal responses
* Sentiment and Toxicity
- sentiment analysis
- toxicity analysis
For more information, see https://docs.whylabs.ai/docs/language-model-monitoring
or check out the LangKit repo here: https://github.com/whylabs/langkit
---
Args:
api_key (Optional[str]): WhyLabs API key. Optional because the preferred
way to specify the API key is with environment variable
WHYLABS_API_KEY.
org_id (Optional[str]): WhyLabs organization id to write profiles to.
Optional because the preferred way to specify the organization id is
with environment variable WHYLABS_DEFAULT_ORG_ID.
dataset_id (Optional[str]): WhyLabs dataset id to write profiles to.
Optional because the preferred way to specify the dataset id is
with environment variable WHYLABS_DEFAULT_DATASET_ID.
sentiment (bool): Whether to enable sentiment analysis. Defaults to False.
toxicity (bool): Whether to enable toxicity analysis. Defaults to False.
themes (bool): Whether to enable theme analysis. Defaults to False.
"""
def __init__(self, logger: Logger, handler: Any):
"""Initiate the rolling logger."""
super().__init__()
if hasattr(handler, "init"):
handler.init(self)
if hasattr(handler, "_get_callbacks"):
self._callbacks = handler._get_callbacks()
else:
self._callbacks = dict()
diagnostic_logger.warning("initialized handler without callbacks.")
self._logger = logger
def flush(self) -> None:
"""Explicitly write current profile if using a rolling logger."""
if self._logger and hasattr(self._logger, "_do_rollover"):
self._logger._do_rollover()
diagnostic_logger.info("Flushing WhyLabs logger, writing profile...")
def close(self) -> None:
"""Close any loggers to allow writing out of any profiles before exiting."""
if self._logger and hasattr(self._logger, "close"):
self._logger.close()
diagnostic_logger.info("Closing WhyLabs logger, see you next time!")
def __enter__(self) -> WhyLabsCallbackHandler:
return self
def __exit__(
self, exception_type: Any, exception_value: Any, traceback: Any
) -> None:
self.close()
@classmethod
def from_params(
cls,
*,
api_key: Optional[str] = None,
org_id: Optional[str] = None,
dataset_id: Optional[str] = None,
sentiment: bool = False,
toxicity: bool = False,
themes: bool = False,
logger: Optional[Logger] = None,
) -> WhyLabsCallbackHandler:
"""Instantiate whylogs Logger from params.
Args:
api_key (Optional[str]): WhyLabs API key. Optional because the preferred
way to specify the API key is with environment variable
WHYLABS_API_KEY.
org_id (Optional[str]): WhyLabs organization id to write profiles to.
If not set must be specified in environment variable
WHYLABS_DEFAULT_ORG_ID.
dataset_id (Optional[str]): The model or dataset this callback is gathering
telemetry for. If not set must be specified in environment variable
WHYLABS_DEFAULT_DATASET_ID.
sentiment (bool): If True will initialize a model to perform
sentiment analysis compound score. Defaults to False and will not gather
this metric.
toxicity (bool): If True will initialize a model to score
toxicity. Defaults to False and will not gather this metric.
themes (bool): If True will initialize a model to calculate
distance to configured themes. Defaults to None and will not gather this
metric.
logger (Optional[Logger]): If specified will bind the configured logger as
the telemetry gathering agent. Defaults to LangKit schema with periodic
WhyLabs writer.
"""
# langkit library will import necessary whylogs libraries
import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes)
why = guard_import("whylogs")
get_callback_instance = guard_import(
"langkit.callback_handler"
).get_callback_instance
WhyLabsWriter = guard_import("whylogs.api.writer.whylabs").WhyLabsWriter
udf_schema = guard_import("whylogs.experimental.core.udf_schema").udf_schema
if logger is None:
api_key = api_key or get_from_env("api_key", "WHYLABS_API_KEY")
org_id = org_id or get_from_env("org_id", "WHYLABS_DEFAULT_ORG_ID")
dataset_id = dataset_id or get_from_env(
"dataset_id", "WHYLABS_DEFAULT_DATASET_ID"
)
whylabs_writer = WhyLabsWriter(
api_key=api_key, org_id=org_id, dataset_id=dataset_id
)
whylabs_logger = why.logger(
mode="rolling", interval=5, when="M", schema=udf_schema()
)
whylabs_logger.append_writer(writer=whylabs_writer)
else:
diagnostic_logger.info("Using passed in whylogs logger {logger}")
whylabs_logger = logger
callback_handler_cls = get_callback_instance(logger=whylabs_logger, impl=cls)
diagnostic_logger.info(
"Started whylogs Logger with WhyLabsWriter and initialized LangKit. 📝"
)
return callback_handler_cls