initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,193 @@
"""LangSmith Client."""
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from langsmith._expect import expect
from langsmith.async_client import AsyncClient
from langsmith.client import Client
from langsmith.evaluation import (
aevaluate,
aevaluate_existing,
evaluate,
evaluate_existing,
)
from langsmith.evaluation.evaluator import EvaluationResult, RunEvaluator
from langsmith.prompt_cache import AsyncPromptCache, PromptCache
from langsmith.run_helpers import (
get_current_run_tree,
get_tracing_context,
set_run_metadata,
trace,
traceable,
tracing_context,
)
from langsmith.run_trees import RunTree, configure
from langsmith.testing._internal import test, unit
from langsmith.utils import ContextThreadPoolExecutor
from langsmith.uuid import uuid7, uuid7_from_datetime
# Avoid calling into importlib on every call to __version__
# (the version is a hard-coded literal rather than a runtime lookup).
__version__ = "0.7.9"
version = __version__  # for backwards compatibility
# Maps every lazily exported attribute to the submodule that defines it.
# Keeping this in one table replaces a long ``if``/``elif`` chain with a
# single dictionary lookup.
_LAZY_ATTRIBUTES = {
    "Client": "langsmith.client",
    "AsyncClient": "langsmith.async_client",
    "RunTree": "langsmith.run_trees",
    "EvaluationResult": "langsmith.evaluation.evaluator",
    "RunEvaluator": "langsmith.evaluation.evaluator",
    "trace": "langsmith.run_helpers",
    "traceable": "langsmith.run_helpers",
    "test": "langsmith.testing._internal",
    "expect": "langsmith._expect",
    "evaluate": "langsmith.evaluation",
    "evaluate_existing": "langsmith.evaluation",
    "aevaluate": "langsmith.evaluation",
    "aevaluate_existing": "langsmith.evaluation",
    "tracing_context": "langsmith.run_helpers",
    "get_tracing_context": "langsmith.run_helpers",
    "get_current_run_tree": "langsmith.run_helpers",
    "set_run_metadata": "langsmith.run_helpers",
    "unit": "langsmith.testing._internal",
    "ContextThreadPoolExecutor": "langsmith.utils",
    "configure": "langsmith.run_trees",
    "uuid7": "langsmith.uuid",
    "uuid7_from_datetime": "langsmith.uuid",
    "PromptCache": "langsmith.prompt_cache",
    "AsyncPromptCache": "langsmith.prompt_cache",
    "Cache": "langsmith.prompt_cache",
    "AsyncCache": "langsmith.prompt_cache",
    "configure_global_prompt_cache": "langsmith.prompt_cache",
    "configure_global_async_prompt_cache": "langsmith.prompt_cache",
}


def __getattr__(name: str) -> Any:
    """Resolve a public package attribute on first access (PEP 562 hook).

    The submodule that defines ``name`` is imported only when the attribute
    is requested, so importing the package itself stays cheap.

    Args:
        name: The attribute being looked up on the package.

    Returns:
        The object called ``name`` from the submodule that defines it.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported names.
    """
    if name == "__version__":
        return version
    module_name = _LAZY_ATTRIBUTES.get(name)
    if module_name is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    # Imported locally so this hook adds no import-time dependency of its own.
    import importlib

    return getattr(importlib.import_module(module_name), name)
# Names exported by ``from <package> import *``. Apart from ``__version__``,
# which is defined eagerly above, these are resolved on first access by the
# module-level ``__getattr__``.
# NOTE(review): "anonymizer" has no branch in ``__getattr__``; presumably it is
# a submodule that the import system resolves on its own — confirm.
__all__ = [
    "Client",
    "AsyncClient",
    "PromptCache",
    "AsyncPromptCache",
    "Cache",
    "AsyncCache",
    "configure_global_prompt_cache",
    "configure_global_async_prompt_cache",
    "RunTree",
    "configure",
    "__version__",
    "EvaluationResult",
    "RunEvaluator",
    "anonymizer",
    "traceable",
    "trace",
    "unit",
    "test",
    "expect",
    "evaluate",
    "evaluate_existing",
    "aevaluate_existing",
    "aevaluate",
    "tracing_context",
    "get_tracing_context",
    "get_current_run_tree",
    "set_run_metadata",
    "ContextThreadPoolExecutor",
    "uuid7",
    "uuid7_from_datetime",
]

View File

@@ -0,0 +1,465 @@
"""Make approximate assertions as "expectations" on test results.
This module is designed to be used within test cases decorated with the
`@pytest.mark.langsmith` decorator.
It allows you to log scores about a test case and optionally make assertions that log as
"expectation" feedback to LangSmith.
Example:
```python
import pytest
from langsmith import expect
@pytest.mark.langsmith
def test_output_semantically_close():
response = oai_client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Say hello!"},
],
)
response_txt = response.choices[0].message.content
# Intended usage
expect.embedding_distance(
prediction=response_txt,
reference="Hello!",
).to_be_less_than(0.9)
# Score the test case
matcher = expect.edit_distance(
prediction=response_txt,
reference="Hello!",
)
# Apply an assertion and log 'expectation' feedback to LangSmith
matcher.to_be_less_than(1)
# You can also make assertions on values directly
expect.value(response_txt).to_contain("Hello!")
# Or using a custom check
expect.value(response_txt).against(lambda x: "Hello" in x)
# You can even use this for basic metric logging within tests
expect.score(0.8)
expect.score(0.7, key="similarity").to_be_greater_than(0.7)
```
""" # noqa: E501
from __future__ import annotations
import atexit
import inspect
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Optional,
Union,
overload,
)
from langsmith import client as ls_client
from langsmith import run_helpers as rh
from langsmith import run_trees as rt
from langsmith import utils as ls_utils
if TYPE_CHECKING:
from langsmith._internal._edit_distance import EditDistanceConfig
from langsmith._internal._embedding_distance import EmbeddingConfig
# Sentinel class used until PEP 0661 is accepted
class _NULL_SENTRY:
    """Marker type meaning "this argument was not supplied".

    It lets a signature tell an omitted keyword argument apart from an
    explicit ``None``, which may carry a different meaning. Instances are
    falsy and render as ``NOT_GIVEN``.
    """

    def __repr__(self) -> str:
        return "NOT_GIVEN"

    def __bool__(self) -> Literal[False]:
        return False


NOT_GIVEN = _NULL_SENTRY()
class _Matcher:
    """A class for making assertions on expectation values.

    Each assertion method submits an ``"expectation"`` feedback entry (score
    ``1`` on success, ``0`` on failure) unless test tracking is disabled, and
    raises ``AssertionError`` when the expectation does not hold.
    """

    def __init__(
        self,
        client: Optional[ls_client.Client],
        key: str,
        value: Any,
        _executor: Optional[ls_utils.ContextThreadPoolExecutor] = None,
        run_id: Optional[str] = None,
    ):
        """Bind a value to the run that feedback should be attached to.

        Args:
            client: Client used to create feedback. When falsy, the cached
                client is looked up the first time feedback is submitted.
            key: Name of the value, used in assertion and feedback messages.
            value: The value the assertions are made against.
            _executor: Executor used to submit feedback in the background.
                A new 3-worker executor is created when omitted.
            run_id: Run to attach feedback to when there is no current run
                tree.
        """
        self._client = client
        self.key = key
        self.value = value
        self._executor = _executor or ls_utils.ContextThreadPoolExecutor(max_workers=3)
        self._rt = rh.get_current_run_tree()
        # Prefer the trace of the active run tree; fall back to the explicit id.
        self._run_id = self._rt.trace_id if self._rt else run_id

    def _submit_feedback(self, score: int, message: Optional[str] = None) -> None:
        """Queue an ``"expectation"`` feedback entry unless tracking is disabled.

        Args:
            score: ``1`` for a passed expectation, ``0`` for a failed one.
            message: Comment stored with the feedback.
        """
        if not ls_utils.test_tracking_is_disabled():
            if not self._client:
                self._client = rt.get_cached_client()
            self._executor.submit(
                self._client.create_feedback,
                run_id=self._run_id,
                key="expectation",
                score=score,
                comment=message,
                session_id=self._rt.session_id if self._rt else None,
                start_time=self._rt.start_time if self._rt else None,
            )

    def _assert(self, condition: bool, message: str, method_name: str) -> None:
        """Record the outcome of a check and raise if it failed.

        Args:
            condition: Result of the check; truthy means success.
            message: Failure message for the raised error.
            method_name: Name of the public assertion, used in the success
                comment.

        Raises:
            AssertionError: If ``condition`` is falsy.
        """
        if condition:
            self._submit_feedback(1, message=f"Success: {self.key}.{method_name}")
            return
        # Raise explicitly rather than via an ``assert`` statement: ``assert``
        # is removed under ``python -O``, which would make every expectation
        # pass silently.
        error = AssertionError(message)
        self._submit_feedback(0, repr(error))
        raise error

    @staticmethod
    def _describe_callable(func: Callable) -> str:
        """Return a printable description of ``func`` for failure messages.

        Uses the signature when it can be introspected; otherwise falls back
        to the callable's ``__name__`` or ``repr`` so that describing the
        check can never fail before the check itself runs.
        """
        try:
            return str(inspect.signature(func))
        except (TypeError, ValueError):
            return getattr(func, "__name__", repr(func))

    def to_be_less_than(self, value: float) -> None:
        """Assert that the expectation value is less than the given value.

        Args:
            value: The value to compare against.

        Raises:
            AssertionError: If the expectation value is not less than the given value.
        """
        self._assert(
            self.value < value,
            f"Expected {self.key} to be less than {value}, but got {self.value}",
            "to_be_less_than",
        )

    def to_be_greater_than(self, value: float) -> None:
        """Assert that the expectation value is greater than the given value.

        Args:
            value: The value to compare against.

        Raises:
            AssertionError: If the expectation value is not
                greater than the given value.
        """
        self._assert(
            self.value > value,
            f"Expected {self.key} to be greater than {value}, but got {self.value}",
            "to_be_greater_than",
        )

    def to_be_between(self, min_value: float, max_value: float) -> None:
        """Assert that the expectation value is between the given min and max values.

        Args:
            min_value: The minimum value (exclusive).
            max_value: The maximum value (exclusive).

        Raises:
            AssertionError: If the expectation value is not between the min and max.
        """
        self._assert(
            min_value < self.value < max_value,
            f"Expected {self.key} to be between {min_value} and {max_value},"
            f" but got {self.value}",
            "to_be_between",
        )

    def to_be_approximately(self, value: float, precision: int = 2) -> None:
        """Assert that the expectation value is approximately equal to the given value.

        Args:
            value: The value to compare against.
            precision: The number of decimal places to round to for comparison.

        Raises:
            AssertionError: If the rounded expectation value
                does not equal the rounded given value.
        """
        self._assert(
            round(self.value, precision) == round(value, precision),
            f"Expected {self.key} to be approximately {value}, but got {self.value}",
            "to_be_approximately",
        )

    def to_equal(self, value: float) -> None:
        """Assert that the expectation value equals the given value.

        Args:
            value: The value to compare against.

        Raises:
            AssertionError: If the expectation value does
                not exactly equal the given value.
        """
        self._assert(
            self.value == value,
            f"Expected {self.key} to be equal to {value}, but got {self.value}",
            "to_equal",
        )

    def to_be_none(self) -> None:
        """Assert that the expectation value is `None`.

        Raises:
            AssertionError: If the expectation value is not `None`.
        """
        self._assert(
            self.value is None,
            f"Expected {self.key} to be None, but got {self.value}",
            "to_be_none",
        )

    def to_contain(self, value: Any) -> None:
        """Assert that the expectation value contains the given value.

        Args:
            value: The value to check for containment.

        Raises:
            AssertionError: If the expectation value does not contain the given value.
        """
        self._assert(
            value in self.value,
            f"Expected {self.key} to contain {value}, but it does not",
            "to_contain",
        )

    # Custom assertions
    def against(self, func: Callable, /) -> None:
        """Assert the expectation value against a custom function.

        Args:
            func: A custom function that takes the expectation value as input.

        Raises:
            AssertionError: If the custom function returns a falsy result.
        """
        # Describe the callable first (as before), but tolerate callables whose
        # signature cannot be introspected, e.g. some builtins.
        func_signature = self._describe_callable(func)
        self._assert(
            func(self.value),
            f"Assertion {func_signature} failed for {self.key}",
            "against",
        )
class _Expect:
    """A class for setting expectations on test results."""

    def __init__(self, *, client: Optional[ls_client.Client] = None):
        """Create the expectation helper.

        Args:
            client: Client used to log feedback. When omitted, the cached
                client is looked up the first time feedback is submitted.
        """
        self._client = client
        self.executor = ls_utils.ContextThreadPoolExecutor(max_workers=3)
        # Make sure queued feedback submissions finish before the interpreter exits.
        atexit.register(self.executor.shutdown, wait=True)

    def embedding_distance(
        self,
        prediction: str,
        reference: str,
        *,
        config: Optional[EmbeddingConfig] = None,
    ) -> _Matcher:
        """Compute the embedding distance between the prediction and reference.

        This logs the embedding distance to LangSmith and returns a `_Matcher` instance
        for making assertions on the distance value.

        By default, this uses the OpenAI API for computing embeddings.

        Args:
            prediction: The predicted string to compare.
            reference: The reference string to compare against.
            config: Optional configuration for the embedding distance evaluator.

                Supported options:

                - `encoder`: A custom encoder function to encode the list of input
                    strings to embeddings.

                    Defaults to the OpenAI API.
                - `metric`: The distance metric to use for comparison.

                    Supported values: `'cosine'`, `'euclidean'`, `'manhattan'`,
                    `'chebyshev'`, `'hamming'`.

        Returns:
            A `_Matcher` instance for the embedding distance value.

        Example:
            ```python
            expect.embedding_distance(
                prediction="hello",
                reference="hi",
            ).to_be_less_than(1.0)
            ```
        """  # noqa: E501
        from langsmith._internal._embedding_distance import EmbeddingDistance

        config = config or {}
        encoder_func = "custom" if config.get("encoder") else "openai"
        evaluator = EmbeddingDistance(config=config)
        score = evaluator.evaluate(prediction=prediction, reference=reference)
        src_info = {"encoder": encoder_func, "metric": evaluator.distance}
        self._submit_feedback(
            "embedding_distance",
            {
                "score": score,
                "source_info": src_info,
                "comment": f"Using {encoder_func}, Metric: {evaluator.distance}",
            },
        )
        return _Matcher(
            self._client, "embedding_distance", score, _executor=self.executor
        )

    def edit_distance(
        self,
        prediction: str,
        reference: str,
        *,
        config: Optional[EditDistanceConfig] = None,
    ) -> _Matcher:
        """Compute the string distance between the prediction and reference.

        This logs the string distance (Damerau-Levenshtein) to LangSmith and returns
        a `_Matcher` instance for making assertions on the distance value.

        This depends on the `rapidfuzz` package for string distance computation.

        Args:
            prediction: The predicted string to compare.
            reference: The reference string to compare against.
            config: Optional configuration for the string distance evaluator.

                Supported options:

                - `metric`: The distance metric to use for comparison.

                    Supported values: `'damerau_levenshtein'`, `'levenshtein'`,
                    `'jaro'`, `'jaro_winkler'`, `'hamming'`, `'indel'`.
                - `normalize_score`: Whether to normalize the score between `0` and `1`.

        Returns:
            A `_Matcher` instance for the string distance value.

        Examples:
            ```python
            expect.edit_distance("hello", "helo").to_be_less_than(1)
            ```
        """
        from langsmith._internal._edit_distance import EditDistance

        config = config or {}
        metric = config.get("metric") or "damerau_levenshtein"
        normalize = config.get("normalize_score", True)
        evaluator = EditDistance(config=config)
        score = evaluator.evaluate(prediction=prediction, reference=reference)
        src_info = {"metric": metric, "normalize": normalize}
        self._submit_feedback(
            "edit_distance",
            {
                "score": score,
                "source_info": src_info,
                "comment": f"Using {metric}, Normalize: {normalize}",
            },
        )
        return _Matcher(
            self._client,
            "edit_distance",
            score,
            _executor=self.executor,
        )

    def value(self, value: Any) -> _Matcher:
        """Create a `_Matcher` instance for making assertions on the given value.

        Args:
            value: The value to make assertions on.

        Returns:
            A `_Matcher` instance for the given value.

        Example:
            ```python
            expect.value(10).to_be_less_than(20)
            ```
        """
        return _Matcher(self._client, "value", value, _executor=self.executor)

    def score(
        self,
        score: Union[float, int, bool],
        *,
        key: str = "score",
        source_run_id: Optional[ls_client.ID_TYPE] = None,
        comment: Optional[str] = None,
    ) -> _Matcher:
        """Log a numeric score to LangSmith.

        Args:
            score: The score value to log.
            key: The key to use for logging the score. Defaults to `'score'`.
            source_run_id: Passed through to the feedback as `source_run_id`.
            comment: Optional comment stored with the feedback.

        Returns:
            A `_Matcher` instance for the logged score.

        Example:
            ```python
            expect.score(0.8)  # doctest: +ELLIPSIS
            <langsmith._expect._Matcher object at ...>

            expect.score(0.8, key="similarity").to_be_greater_than(0.7)
            ```
        """
        self._submit_feedback(
            key,
            {
                "score": score,
                "source_info": {"method": "expect.score"},
                "source_run_id": source_run_id,
                "comment": comment,
            },
        )
        return _Matcher(self._client, key, score, _executor=self.executor)

    ## Private Methods

    @overload
    def __call__(self, value: Any, /) -> _Matcher: ...

    @overload
    def __call__(self, /, *, client: ls_client.Client) -> _Expect: ...

    def __call__(
        self,
        value: Optional[Any] = NOT_GIVEN,
        /,
        client: Optional[ls_client.Client] = None,
    ) -> Union[_Expect, _Matcher]:
        """Shorthand for `value(...)`, or bind the helper to a specific client.

        Args:
            value: When given, a `_Matcher` for this value is returned.
            client: When given, a new `_Expect` bound to this client is used.

        Returns:
            A `_Matcher` if `value` was supplied, otherwise an `_Expect`.
        """
        # Only build a new instance when a different client is requested. Each
        # instance owns a thread pool and a permanent atexit hook, so creating
        # one per `expect(value)` call leaked both, and also discarded the
        # client of an instance obtained through `expect(client=...)`.
        expected = self if client is None else _Expect(client=client)
        if value is not NOT_GIVEN:
            return expected.value(value)
        return expected

    def _submit_feedback(self, key: str, results: dict):
        """Queue feedback for the current run unless test tracking is disabled.

        Args:
            key: Feedback key.
            results: Extra keyword arguments forwarded to `create_feedback`.
        """
        current_run = rh.get_current_run_tree()
        run_id = current_run.trace_id if current_run else None
        if not ls_utils.test_tracking_is_disabled():
            if not self._client:
                self._client = rt.get_cached_client()
            self.executor.submit(
                self._client.create_feedback, run_id=run_id, key=key, **results
            )
# Shared module-level instance; it is the only name this module exports.
expect = _Expect()
__all__ = ["expect"]

View File

@@ -0,0 +1,358 @@
"""Adapted.
Original source:
https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py
MIT License
"""
import asyncio
import contextvars
import functools
import inspect
from collections import deque
from collections.abc import (
AsyncGenerator,
AsyncIterable,
AsyncIterator,
Awaitable,
Coroutine,
Iterable,
Iterator,
)
from contextlib import AbstractAsyncContextManager
from typing import (
Any,
Callable,
Generic,
Optional,
TypeVar,
Union,
cast,
overload,
)
T = TypeVar("T")

# Unique marker meaning "no default was passed to py_anext".
_no_default = object()


# https://github.com/python/cpython/blob/main/Lib/test/test_asyncgen.py#L54
# before 3.10, the builtin anext() was not available
def py_anext(
    iterator: AsyncIterator[T], default: Union[T, Any] = _no_default
) -> Awaitable[Union[T, None, Any]]:
    """Return an awaitable for the next item of ``iterator``.

    Pure-Python counterpart of the builtin ``anext()``.

    Args:
        iterator: The async iterator to advance.
        default: Value produced when the iterator is exhausted. When omitted,
            exhaustion surfaces as ``StopAsyncIteration`` from the awaitable.

    Returns:
        An awaitable resolving to the next item, or to ``default``.

    Raises:
        TypeError: If the type of ``iterator`` defines no ``__anext__``.
    """
    try:
        # Look the method up on the type, like the builtin does.
        advance = cast(
            Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
        )
    except AttributeError:
        raise TypeError(f"{iterator!r} is not an async iterator")

    if default is not _no_default:

        async def anext_impl() -> Union[T, Any]:
            # Translate exhaustion into the caller-supplied default.
            try:
                return await advance(iterator)
            except StopAsyncIteration:
                return default

        return anext_impl()

    # No default: hand back the iterator's own awaitable untouched.
    return advance(iterator)
class NoLock:
    """Async context manager with the shape of a lock that guards nothing."""

    async def __aenter__(self) -> None:
        return None

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        # False: exceptions raised inside the ``async with`` body propagate.
        return False
async def tee_peer(
    iterator: AsyncIterator[T],
    # the buffer specific to this peer
    buffer: deque[T],
    # the buffers of all peers, including our own
    peers: list[deque[T]],
    lock: AbstractAsyncContextManager[Any],
) -> AsyncGenerator[T, None]:
    """Yield the items of ``iterator`` for one child of a :py:class:`Tee`.

    Items already present in this peer's ``buffer`` are yielded first. When
    the buffer is empty, the next item is fetched from ``iterator`` under
    ``lock`` and appended to the buffer of every peer. When this generator
    finishes or is closed, its buffer is removed from ``peers``; the last
    peer to leave closes ``iterator`` if it has an ``aclose`` method.
    """
    try:
        while True:
            if not buffer:
                async with lock:
                    # Another peer produced an item while we were waiting for the lock.
                    # Proceed with the next loop iteration to yield the item.
                    if buffer:
                        continue
                    try:
                        item = await iterator.__anext__()
                    except StopAsyncIteration:
                        # Source exhausted: leave the loop; cleanup runs in ``finally``.
                        break
                    else:
                        # Append to all buffers, including our own. We'll fetch our
                        # item from the buffer again, instead of yielding it directly.
                        # This ensures the proper item ordering if any of our peers
                        # are fetching items concurrently. They may have buffered their
                        # item already.
                        for peer_buffer in peers:
                            peer_buffer.append(item)
            yield buffer.popleft()
    finally:
        async with lock:
            # this peer is done, remove its buffer
            # (matched by identity, since buffers with equal content compare equal)
            for idx, peer_buffer in enumerate(peers):  # pragma: no branch
                if peer_buffer is buffer:
                    peers.pop(idx)
                    break
            # if we are the last peer, try and close the iterator
            if not peers and hasattr(iterator, "aclose"):
                await iterator.aclose()
class Tee(Generic[T]):
    """Create ``n`` separate asynchronous iterators over ``iterable``.

    This splits a single ``iterable`` into multiple iterators, each providing
    the same items in the same order.
    All child iterators may advance separately but share the same items
    from ``iterable`` -- when the most advanced iterator retrieves an item,
    it is buffered until the least advanced iterator has yielded it as well.
    A ``tee`` works lazily and can handle an infinite ``iterable``, provided
    that all iterators advance.

    ```python
    async def derivative(sensor_data):
        previous, current = a.tee(sensor_data, n=2)
        await a.anext(previous)  # advance one iterator
        return a.map(operator.sub, previous, current)
    ```

    Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
    of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
    to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method
    immediately closes all children, and it can be used in an ``async with`` context
    for the same effect.

    If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not*
    provide these items. Also, ``tee`` must internally buffer each item until the
    last iterator has yielded it; if the most and least advanced iterator differ
    by most data, using a :py:class:`list` is more efficient (but not lazy).

    If the underlying iterable is concurrency safe (``anext`` may be awaited
    concurrently) the resulting iterators are concurrency safe as well. Otherwise,
    the iterators are safe if there is only ever one single "most advanced" iterator.
    To enforce sequential use of ``anext``, provide a ``lock``
    - e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application -
    and access is automatically synchronised.
    """

    def __init__(
        self,
        iterable: AsyncIterator[T],
        n: int = 2,
        *,
        lock: Optional[AbstractAsyncContextManager[Any]] = None,
    ):
        """Set up ``n`` child iterators that share one source iterator.

        Args:
            iterable: The async iterable to split.
            n: Number of child iterators to create.
            lock: Optional async context manager entered around every fetch
                from the source; a ``NoLock`` is used per child when omitted.
        """
        self._iterator = iterable.__aiter__()  # before 3.10 aiter() doesn't exist
        self._buffers: list[deque[T]] = [deque() for _ in range(n)]
        children = []
        for buffer in self._buffers:
            peer_lock = NoLock() if lock is None else lock
            children.append(
                tee_peer(
                    iterator=self._iterator,
                    buffer=buffer,
                    peers=self._buffers,
                    lock=peer_lock,
                )
            )
        self._children = tuple(children)

    def __len__(self) -> int:
        """Return the number of child iterators."""
        return len(self._children)

    @overload
    def __getitem__(self, item: int) -> AsyncIterator[T]: ...

    @overload
    def __getitem__(self, item: slice) -> tuple[AsyncIterator[T], ...]: ...

    def __getitem__(
        self, item: Union[int, slice]
    ) -> Union[AsyncIterator[T], tuple[AsyncIterator[T], ...]]:
        """Return one child (int index) or a tuple of children (slice)."""
        return self._children[item]

    def __iter__(self) -> Iterator[AsyncIterator[T]]:
        """Iterate over the child iterators, enabling tuple-style unpacking."""
        yield from self._children

    async def __aenter__(self) -> "Tee[T]":
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        # Close every child on exit; never suppress the exception.
        await self.aclose()
        return False

    async def aclose(self) -> None:
        """Close all child iterators, one after another."""
        for child in self._children:
            await child.aclose()


atee = Tee
async def async_zip(*async_iterables):
    """Async version of zip.

    Advances all iterables concurrently and yields one tuple per round. Stops
    as soon as any of them is exhausted. With no iterables it yields nothing,
    like ``zip()``.

    Args:
        *async_iterables: The async iterables to combine.

    Yields:
        Tuples holding one item from each iterable, in argument order.
    """
    # Before Python 3.10, aiter() was not available
    iterators = [iterable.__aiter__() for iterable in async_iterables]
    if not iterators:
        # asyncio.gather() with no awaitables resolves to [] immediately, so
        # without this guard the loop below would yield empty tuples forever.
        return
    while True:
        try:
            items = await asyncio.gather(
                *(py_anext(iterator) for iterator in iterators)
            )
            yield tuple(items)
        except StopAsyncIteration:
            break
def ensure_async_iterator(
    iterable: Union[Iterable, AsyncIterable],
) -> AsyncIterator:
    """Return an async iterator over ``iterable``.

    Objects that already have ``__anext__`` are returned unchanged, async
    iterables are converted through ``__aiter__``, and plain synchronous
    iterables are wrapped so they can be consumed with ``async for``.
    """
    if hasattr(iterable, "__anext__"):
        return cast(AsyncIterator, iterable)
    if hasattr(iterable, "__aiter__"):
        return cast(AsyncIterator, iterable.__aiter__())

    class AsyncIteratorWrapper:
        """Adapter exposing a synchronous iterator through the async protocol."""

        def __init__(self, iterable: Iterable):
            self._iterator = iter(iterable)

        def __aiter__(self):
            return self

        async def __anext__(self):
            try:
                return next(self._iterator)
            except StopIteration:
                # Translate sync exhaustion into its async counterpart.
                raise StopAsyncIteration

    return AsyncIteratorWrapper(iterable)
def aiter_with_concurrency(
    n: Optional[int],
    generator: AsyncIterator[Coroutine[None, None, T]],
    *,
    _eager_consumption_timeout: float = 0,
) -> AsyncGenerator[T, None]:
    """Process async generator with max parallelism.

    Args:
        n: The number of tasks to run concurrently. ``0`` awaits the items one
            at a time, in order, without creating tasks; ``None`` applies no
            limit.
        generator: The async generator to process.
        _eager_consumption_timeout: If set, check for completed tasks after
            each iteration and yield their results. This can be used to
            consume the generator eagerly while still respecting the concurrency
            limit.

    Yields:
        The processed items yielded by the async generator. Outside the
        ``n == 0`` path, results are yielded as their tasks complete.
    """
    if n == 0:
        # Sequential path: await each coroutine as it is produced.

        async def consume():
            async for item in generator:
                yield await item

        return consume()
    # With no limit, NoLock stands in for the semaphore so the ``async with``
    # below restricts nothing.
    semaphore = cast(
        asyncio.Semaphore, asyncio.Semaphore(n) if n is not None else NoLock()
    )

    async def process_item(ix: int, item):
        # Return the index with the result so the finished task can be
        # removed from ``tasks`` by key.
        async with semaphore:
            res = await item
            return (ix, res)

    async def process_generator():
        tasks = {}  # index -> task whose result has not been yielded yet
        # NOTE: this local shadows the module-level ``accepts_context`` function.
        accepts_context = asyncio_accepts_context()
        ix = 0
        async for item in generator:
            if accepts_context:
                # Run the task in a copy of the current context.
                context = contextvars.copy_context()
                task = asyncio.create_task(process_item(ix, item), context=context)
            else:
                task = asyncio.create_task(process_item(ix, item))
            tasks[ix] = task
            ix += 1
            if _eager_consumption_timeout > 0:
                # Yield whatever finishes within the timeout; a timeout just
                # means nothing more was ready.
                try:
                    for _fut in asyncio.as_completed(
                        tasks.values(),
                        timeout=_eager_consumption_timeout,
                    ):
                        task_idx, res = await _fut
                        yield res
                        del tasks[task_idx]
                except asyncio.TimeoutError:
                    pass
            if n is not None and len(tasks) >= n:
                # At capacity: wait for at least one task before pulling the
                # next item from the generator.
                done, _ = await asyncio.wait(
                    tasks.values(), return_when=asyncio.FIRST_COMPLETED
                )
                for task in done:
                    task_idx, res = task.result()
                    yield res
                    del tasks[task_idx]

        # Generator exhausted: drain the tasks that are still outstanding.
        for task in asyncio.as_completed(tasks.values()):
            _, res = await task
            yield res

    return process_generator()
def accepts_context(callable: Callable[..., Any]) -> bool:
    """Report whether ``callable`` declares a parameter named ``context``.

    Returns ``False`` when the signature lookup raises ``ValueError``.
    """
    try:
        parameters = inspect.signature(callable).parameters
    except ValueError:
        return False
    return parameters.get("context") is not None
# Ported from Python 3.9+ to support Python 3.8
async def aio_to_thread(
    func, /, *args, __ctx: Optional[contextvars.Context] = None, **kwargs
):
    """Asynchronously run function *func* in a separate thread.

    Any *args and **kwargs supplied for this function are directly passed
    to *func*. Also, the current :class:`contextvars.Context` is propagated,
    allowing context variables from the main thread to be accessed in the
    separate thread.

    Args:
        func: The callable to run in the loop's default executor.
        *args: Positional arguments for *func*.
        __ctx: Context to run *func* in. A copy of the current context is used
            only when this is ``None``.
        **kwargs: Keyword arguments for *func*.

    Returns:
        The eventual result of *func*.
    """
    loop = asyncio.get_running_loop()
    # Compare against None explicitly: an empty Context has length 0 and is
    # therefore falsy, so ``__ctx or ...`` would silently replace it.
    ctx = contextvars.copy_context() if __ctx is None else __ctx
    func_call = functools.partial(ctx.run, func, *args, **kwargs)
    return await loop.run_in_executor(None, func_call)
@functools.lru_cache(maxsize=1)
def asyncio_accepts_context():
    """Report whether ``asyncio.create_task`` takes a ``context`` argument.

    The answer is cached after the first call.
    """
    create_task = asyncio.create_task
    return accepts_context(create_task)

View File

@@ -0,0 +1,964 @@
from __future__ import annotations
import concurrent.futures as cf
import copy
import functools
import io
import logging
import sys
import threading
import time
import weakref
from multiprocessing import cpu_count
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from langsmith import schemas as ls_schemas
from langsmith import utils as ls_utils
from langsmith._internal._compressed_traces import ZSTD_AVAILABLE, CompressedTraces
from langsmith._internal._constants import (
_AUTO_SCALE_DOWN_NEMPTY_TRIGGER,
_AUTO_SCALE_UP_NTHREADS_LIMIT,
_AUTO_SCALE_UP_QSIZE_TRIGGER,
_BOUNDARY,
)
from langsmith._internal._operations import (
SerializedFeedbackOperation,
SerializedRunOperation,
combine_serialized_queue_operations,
)
if TYPE_CHECKING:
from opentelemetry.context.context import Context # type: ignore[import]
from langsmith.client import Client
# Log under the fixed "langsmith.client" name rather than this module's own name.
logger = logging.getLogger("langsmith.client")
# Module-level thread pool with one worker per CPU reported by cpu_count().
LANGSMITH_CLIENT_THREAD_POOL = cf.ThreadPoolExecutor(max_workers=cpu_count())
def _group_batch_by_api_endpoint(
    batch: list[TracingQueueItem],
) -> dict[
    tuple[
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
        Optional[str],
    ],
    list[TracingQueueItem],
]:
    """Bucket queue items that share the same endpoint and credentials.

    Args:
        batch: The items to group.

    Returns:
        A mapping from ``(api_url, api_key, service_key, tenant_id,
        authorization, cookie)`` to the items carrying exactly those values,
        in their original relative order.
    """
    from collections import defaultdict

    buckets = defaultdict(list)
    for entry in batch:
        destination = (
            entry.api_url,
            entry.api_key,
            entry.service_key,
            entry.tenant_id,
            entry.authorization,
            entry.cookie,
        )
        buckets[destination].append(entry)
    return buckets
@functools.total_ordering
class TracingQueueItem:
    """An item in the tracing queue.

    Items order by ``priority`` first and by the class of ``item`` second;
    two items are equal when both match.

    Attributes:
        priority (str): The priority of the item.
        item (Union[SerializedRunOperation, SerializedFeedbackOperation]):
            The queued operation itself.
        api_url (Optional[str]): Endpoint override for this item.
        api_key (Optional[str]): API key override for this item.
        service_key (Optional[str]): Service key override for this item.
        tenant_id (Optional[str]): Tenant ID override for this item.
        authorization (Optional[str]): Authorization value for this item.
        cookie (Optional[str]): Cookie value for this item.
        otel_context (Optional[Context]): The OTEL context of the item.
    """

    priority: str
    item: Union[SerializedRunOperation, SerializedFeedbackOperation]
    api_url: Optional[str]
    api_key: Optional[str]
    service_key: Optional[str]
    tenant_id: Optional[str]
    authorization: Optional[str]
    cookie: Optional[str]
    otel_context: Optional[Context]

    __slots__ = (
        "priority",
        "item",
        "api_key",
        "api_url",
        "service_key",
        "tenant_id",
        "authorization",
        "cookie",
        "otel_context",
    )

    def __init__(
        self,
        priority: str,
        item: Union[SerializedRunOperation, SerializedFeedbackOperation],
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        service_key: Optional[str] = None,
        tenant_id: Optional[str] = None,
        authorization: Optional[str] = None,
        cookie: Optional[str] = None,
        otel_context: Optional[Context] = None,
    ) -> None:
        self.priority = priority
        self.item = item
        self.api_key = api_key
        self.api_url = api_url
        self.service_key = service_key
        self.tenant_id = tenant_id
        self.authorization = authorization
        self.cookie = cookie
        self.otel_context = otel_context

    def __lt__(self, other: TracingQueueItem) -> bool:
        if self.priority != other.priority:
            return self.priority < other.priority
        own_cls, other_cls = self.item.__class__, other.item.__class__
        if own_cls is other_cls:
            return False
        # Class objects define no ordering, so comparing them with ``<``
        # raises TypeError whenever two items of different kinds share a
        # priority. Break the tie on the classes' names instead.
        return (own_cls.__module__, own_cls.__qualname__) < (
            other_cls.__module__,
            other_cls.__qualname__,
        )

    def __eq__(self, other: object) -> bool:
        return isinstance(other, TracingQueueItem) and (
            self.priority,
            self.item.__class__,
        ) == (other.priority, other.item.__class__)
def _tracing_thread_drain_queue(
    tracing_queue: Queue, limit: int = 100, block: bool = True, max_size_bytes: int = 0
) -> list[TracingQueueItem]:
    """Pull the next batch of items off the tracing queue.

    Args:
        tracing_queue: The queue to drain.
        limit: Maximum number of items to return; a falsy value disables the
            count limit.
        block: Whether ``get`` may wait for items (up to 250ms for the first
            item, 50ms for each later one).
        max_size_bytes: When positive, stop once the summed
            ``calculate_serialized_size()`` of the batch exceeds this value.
            The item that crosses the limit is still included.

    Returns:
        The drained items, possibly empty.
    """
    next_batch: list[TracingQueueItem] = []
    current_size = 0
    # wait 250ms for the first item, then
    # - drain the queue with a 50ms block timeout
    # - stop draining if we hit either count or size limit
    # shorter drain timeout is used instead of non-blocking calls to
    # avoid creating too many small batches
    timeout = 0.25
    try:
        while True:
            item = tracing_queue.get(block=block, timeout=timeout)
            timeout = 0.05
            # Add the item first, then check the limits, so the item that
            # reaches a limit is part of this batch.
            next_batch.append(item)
            if max_size_bytes > 0:
                current_size += item.item.calculate_serialized_size()
                if current_size > max_size_bytes:
                    break
            # Checked after every item, including the first: previously the
            # first item skipped this check, so limit=1 returned two items.
            if limit and len(next_batch) >= limit:
                break
    except Empty:
        pass
    return next_batch
def _tracing_thread_drain_compressed_buffer(
    client: Client, size_limit: int = 100, size_limit_bytes: int | None = 20_971_520
) -> tuple[Optional[io.BytesIO], Optional[tuple[int, int]]]:
    """Detach the client's compressed trace buffer once it is full enough.

    Args:
        client: Client whose ``compressed_traces`` buffer is inspected.
        size_limit: Trace count at which the buffer is drained.
        size_limit_bytes: Uncompressed size at which the buffer is drained.
            ``client._max_batch_size_bytes`` takes precedence when truthy.

    Returns:
        ``(buffer, (uncompressed_size, compressed_size))`` with the buffer
        rewound to the start, or ``(None, None)`` when there is no buffer,
        neither threshold has been reached, or any exception occurred.
        Exceptions, including the ``ValueError`` raised for invalid limits,
        are logged and not propagated.
    """
    try:
        if client.compressed_traces is None:
            return None, None
        with client.compressed_traces.lock:
            pre_compressed_size = client.compressed_traces.uncompressed_size
            # A truthy client-level byte limit overrides the argument.
            size_limit_bytes = client._max_batch_size_bytes or size_limit_bytes
            if size_limit is not None and size_limit <= 0:
                raise ValueError(f"size_limit must be positive; got {size_limit}")
            if size_limit_bytes is not None and size_limit_bytes < 0:
                raise ValueError(
                    f"size_limit_bytes must be nonnegative; got {size_limit_bytes}"
                )
            # Nothing to do while both the byte and the count threshold are unmet.
            if (
                size_limit_bytes is None or pre_compressed_size < size_limit_bytes
            ) and (
                size_limit is None or client.compressed_traces.trace_count < size_limit
            ):
                return None, None
            # Write final boundary and close compression stream
            client.compressed_traces.compressor_writer.write(
                f"--{_BOUNDARY}--\r\n".encode()
            )
            client.compressed_traces.compressor_writer.close()
            # Position after closing the writer, reported as the compressed size.
            current_size = client.compressed_traces.buffer.tell()
            filled_buffer = client.compressed_traces.buffer
            # Attach the buffer's context as an ad-hoc attribute on the buffer.
            setattr(
                cast(Any, filled_buffer),
                "context",
                client.compressed_traces._context,
            )
            compressed_traces_info = (pre_compressed_size, current_size)
            # Reset the shared state; ``filled_buffer`` keeps the old buffer.
            client.compressed_traces.reset()
        filled_buffer.seek(0)
        return (filled_buffer, compressed_traces_info)
    except Exception:
        logger.error(
            "LangSmith tracing error: Failed to submit trace data.\n"
            "This does not affect your application's runtime.\n"
            "Error details:",
            exc_info=True,
        )
        # exceptions are logged elsewhere, but we need to make sure the
        # background thread continues to run
        return None, None
def _process_buffered_run_ops_batch(
    client: Client,
    batch_to_process: list[tuple[str, dict, dict[str, Optional[str]]]],
) -> None:
    """Apply the client's ``process_buffered_run_ops`` hook to a batch and submit it.

    Args:
        client: Client providing the hook and the ``_create_run`` /
            ``_update_run`` entry points.
        batch_to_process: ``(operation, run_dict, write_context)`` triples,
            where ``operation`` is ``"post"`` or ``"patch"`` and
            ``write_context`` is forwarded as keyword arguments.

    The hook must return the same number of runs with the same ``id`` values
    in the same order; otherwise the batch is rejected. All failures are
    logged rather than raised.
    """
    try:
        # Snapshot the runs and their ids before the hook gets to touch them.
        raw_runs = [entry[1] for entry in batch_to_process]
        expected_ids = [raw.get("id") for raw in raw_runs]

        if client._process_buffered_run_ops is None:
            raise RuntimeError(
                "process_buffered_run_ops should not be None when processing batch"
            )
        transformed = list(client._process_buffered_run_ops(raw_runs))

        # The hook may rewrite run contents but not add, drop or reorder runs.
        if len(transformed) != len(raw_runs):
            raise ValueError(
                f"process_buffered_run_ops must return the same number of runs. "
                f"Expected {len(raw_runs)}, got {len(transformed)}"
            )
        actual_ids = [run.get("id") for run in transformed]
        if actual_ids != expected_ids:
            raise ValueError(
                f"process_buffered_run_ops must preserve run IDs in the same order. "
                f"Expected {expected_ids}, got {actual_ids}"
            )

        # Replay each original operation with its transformed payload.
        for index, run in enumerate(transformed):
            operation, _, write_ctx = batch_to_process[index]
            if operation == "post":
                client._create_run(run, **write_ctx)
            elif operation == "patch":
                client._update_run(run, **write_ctx)

        # Signal that new data is available.
        if client._data_available_event:
            client._data_available_event.set()
    except Exception:
        # Log errors but don't crash the background thread
        logger.error(
            "LangSmith buffered run ops processing error: Failed to process batch.\n"
            "This does not affect your application's runtime.\n"
            "Error details:",
            exc_info=True,
        )
def _tracing_thread_handle_batch(
    client: Client,
    tracing_queue: Queue,
    batch: list[TracingQueueItem],
    use_multipart: bool,
    mark_task_done: bool = True,
    ops: Optional[
        list[Union[SerializedRunOperation, SerializedFeedbackOperation]]
    ] = None,
) -> None:
    """Send a batch of tracing queue items to LangSmith.

    Items are grouped per API endpoint/credential combination and each group
    is ingested separately. Failures are logged and reported through
    ``client._invoke_tracing_error_callback``; they are not raised.

    Args:
        client: The LangSmith client to use for sending data.
        tracing_queue: The queue the items came from (used for task_done calls).
        batch: List of tracing queue items to process.
        use_multipart: Whether to use the multipart endpoint for sending data.
        mark_task_done: Whether to mark queue tasks as done after processing.
            Set to False when called from parallel execution to avoid double
            counting.
        ops: Pre-combined serialized operations to use instead of combining
            from batch. If empty or None, operations are combined from the
            batch items.
    """
    try:
        # One request per (api_url, auth) combination.
        for endpoint, group_batch in _group_batch_by_api_endpoint(batch).items():
            (
                api_url,
                api_key,
                service_key,
                tenant_id,
                authorization,
                cookie,
            ) = endpoint
            endpoint_kwargs = {
                "api_url": api_url,
                "api_key": api_key,
                "service_key": service_key,
                "tenant_id": tenant_id,
                "authorization": authorization,
                "cookie": cookie,
            }

            if ops:
                # Pick this group's share of the pre-combined operations.
                wanted_ids = {entry.item.id for entry in group_batch}
                group_ops = [op for op in ops if op.id in wanted_ids]
            else:
                group_ops = combine_serialized_queue_operations(
                    [entry.item for entry in group_batch]
                )

            if use_multipart:
                client._multipart_ingest_ops(group_ops, **endpoint_kwargs)
                continue

            # The batch endpoint only accepts run operations; drop feedback.
            run_only_ops = [
                op
                for op in group_ops
                if not isinstance(op, SerializedFeedbackOperation)
            ]
            if len(run_only_ops) != len(group_ops):
                logger.warning(
                    "Feedback operations are not supported in non-multipart mode"
                )
            client._batch_ingest_run_ops(
                cast(list[SerializedRunOperation], run_only_ops),
                **endpoint_kwargs,
            )
    except Exception as e:
        logger.error(
            "LangSmith tracing error: Failed to submit trace data.\n"
            "This does not affect your application's runtime.\n"
            "Error details:",
            exc_info=True,
        )
        client._invoke_tracing_error_callback(e)
    finally:
        if mark_task_done and tracing_queue is not None:
            # One task_done() per dequeued item, whether or not sending worked.
            for _ in batch:
                try:
                    tracing_queue.task_done()
                except ValueError as e:
                    if "task_done() called too many times" not in str(e):
                        raise
                    # This can happen during shutdown when multiple threads
                    # process the same queue items. It's harmless.
                    logger.debug(
                        f"Ignoring harmless task_done error during shutdown: {e}"
                    )
def _otel_tracing_thread_handle_batch(
    client: Client,
    tracing_queue: Queue,
    batch: list[TracingQueueItem],
    mark_task_done: bool = True,
    ops: Optional[
        list[Union[SerializedRunOperation, SerializedFeedbackOperation]]
    ] = None,
) -> None:
    """Export the run operations of a batch through the client's OTEL exporter.

    Only ``SerializedRunOperation`` entries are exported. Failures are logged
    and reported through ``client._invoke_tracing_error_callback``; they are
    not raised.

    Args:
        client: The LangSmith client containing the OTEL exporter.
        tracing_queue: The queue the items came from (used for task_done calls).
        batch: List of tracing queue items to process.
        mark_task_done: Whether to mark queue tasks as done after processing.
            Set to False when called from parallel execution to avoid double
            counting.
        ops: Pre-combined serialized operations to use instead of combining
            from batch. If None, operations are combined from the batch items.
    """
    try:
        if ops is None:
            ops = combine_serialized_queue_operations([entry.item for entry in batch])

        run_ops = []
        for op in ops:
            if isinstance(op, SerializedRunOperation):
                run_ops.append(op)

        # Map each run id to the OTEL context captured with its queue item.
        otel_context_map = {}
        for entry in batch:
            if isinstance(entry.item, SerializedRunOperation):
                otel_context_map[entry.item.id] = entry.otel_context

        if run_ops and client.otel_exporter is not None:
            client.otel_exporter.export_batch(run_ops, otel_context_map)
        elif run_ops:
            logger.error(
                "LangSmith tracing error: Failed to submit OTEL trace data.\n"
                "This does not affect your application's runtime.\n"
                "Error details: client.otel_exporter is None"
            )
    except Exception as e:
        logger.error(
            "OTEL tracing error: Failed to submit trace data.\n"
            "This does not affect your application's runtime.\n"
            "Error details:",
            exc_info=True,
        )
        client._invoke_tracing_error_callback(e)
    finally:
        if mark_task_done and tracing_queue is not None:
            # One task_done() per dequeued item, whether or not export worked.
            for _ in batch:
                try:
                    tracing_queue.task_done()
                except ValueError as e:
                    if "task_done() called too many times" not in str(e):
                        raise
                    # This can happen during shutdown when multiple threads
                    # process the same queue items. It's harmless.
                    logger.debug(
                        f"Ignoring harmless task_done error during shutdown: {e}"
                    )
def _hybrid_tracing_thread_handle_batch(
    client: Client,
    tracing_queue: Queue,
    batch: list[TracingQueueItem],
    use_multipart: bool,
    mark_task_done: bool = True,
) -> None:
    """Handle a batch of tracing queue items by sending to both LangSmith and OTEL.

    The two exports run in parallel on a two-worker thread pool. If the pool
    cannot schedule work because the interpreter is shutting down, both
    exports are run sequentially instead. Queue accounting (``task_done``)
    happens exactly once, here, and is performed even when this function
    exits with an exception so ``Queue.join()`` cannot block on this batch.

    Args:
        client: The LangSmith client to use for sending data.
        tracing_queue: The queue the items came from (used for task_done calls).
        batch: List of tracing queue items to process.
        use_multipart: Whether to use multipart endpoint for LangSmith.
        mark_task_done: Whether to mark queue tasks as done after processing.
            Set to False primarily for testing when items weren't actually
            queued.

    Raises:
        RuntimeError: Re-raised when it is not the interpreter-shutdown
            scheduling error.
    """
    try:
        # Combine operations once to avoid race conditions
        ops = combine_serialized_queue_operations([item.item for item in batch])

        # Create copies for each thread to avoid shared mutation
        langsmith_ops = copy.deepcopy(ops)
        otel_ops = copy.deepcopy(ops)

        try:
            # Use ThreadPoolExecutor for parallel execution
            with cf.ThreadPoolExecutor(max_workers=2) as executor:
                # Submit both tasks
                future_langsmith = executor.submit(
                    _tracing_thread_handle_batch,
                    client,
                    tracing_queue,
                    batch,
                    use_multipart,
                    False,  # Don't mark tasks done - we'll do it once at the end
                    langsmith_ops,
                )
                future_otel = executor.submit(
                    _otel_tracing_thread_handle_batch,
                    client,
                    tracing_queue,
                    batch,
                    False,  # Don't mark tasks done - we'll do it once at the end
                    otel_ops,
                )

                # Wait for both to complete
                future_langsmith.result()
                future_otel.result()
        except RuntimeError as e:
            if "cannot schedule new futures after interpreter shutdown" in str(e):
                # During interpreter shutdown, ThreadPoolExecutor is blocked,
                # fall back to sequential processing
                # NOTE(review): if the error is raised by the second submit,
                # the LangSmith export has already run once and is repeated
                # here - confirm that a duplicate send is acceptable.
                logger.debug(
                    "Interpreter shutting down, falling back to sequential processing"
                )
                _tracing_thread_handle_batch(
                    client, tracing_queue, batch, use_multipart, False, langsmith_ops
                )
                _otel_tracing_thread_handle_batch(
                    client, tracing_queue, batch, False, otel_ops
                )
            else:
                raise
    finally:
        # Mark all tasks as done once, only if requested. Doing this in a
        # ``finally`` mirrors the single-destination handlers and keeps the
        # queue's unfinished-task count correct on every exit path.
        if mark_task_done and tracing_queue is not None:
            for _ in batch:
                try:
                    tracing_queue.task_done()
                except ValueError as e:
                    if "task_done() called too many times" in str(e):
                        # This can happen during shutdown when multiple threads
                        # process the same queue items. It's harmless.
                        logger.debug(
                            f"Ignoring harmless task_done error during shutdown: {e}"
                        )
                    else:
                        raise
def get_size_limit_from_env() -> Optional[int]:
    """Read the ``BATCH_INGEST_SIZE_LIMIT`` override via ``ls_utils.get_env_var``.

    Returns:
        The configured value as an ``int``, or ``None`` when the variable is
        unset or cannot be parsed as an integer (a warning is logged in the
        latter case).
    """
    raw_value = ls_utils.get_env_var("BATCH_INGEST_SIZE_LIMIT")
    if raw_value is None:
        return None
    try:
        return int(raw_value)
    except ValueError:
        logger.warning(
            f"Invalid value for BATCH_INGEST_SIZE_LIMIT: {raw_value}, "
            "continuing with default"
        )
        return None
def _ensure_ingest_config(
    info: ls_schemas.LangSmithInfo,
) -> ls_schemas.BatchIngestConfig:
    """Return the batch ingest settings to use for ``info``.

    The server-provided ``info.batch_ingest_config`` is used when present,
    with its ``size_limit`` overwritten in place by the environment override
    from ``get_size_limit_from_env`` if one is set. A built-in default
    configuration is returned when ``info`` or its config is missing/empty,
    or when anything goes wrong while reading it.
    """
    fallback = ls_schemas.BatchIngestConfig(
        use_multipart_endpoint=True,
        size_limit_bytes=None,  # Note this field is not used here
        size_limit=100,
        scale_up_nthreads_limit=_AUTO_SCALE_UP_NTHREADS_LIMIT,
        scale_up_qsize_trigger=_AUTO_SCALE_UP_QSIZE_TRIGGER,
        scale_down_nempty_trigger=_AUTO_SCALE_DOWN_NEMPTY_TRIGGER,
    )
    if not info:
        return fallback
    try:
        server_config = info.batch_ingest_config
        if not server_config:
            return fallback
        size_limit_override = get_size_limit_from_env()
        if size_limit_override is not None:
            server_config["size_limit"] = size_limit_override
        return server_config
    except BaseException:
        # Best effort: any failure while reading the config means "use defaults".
        return fallback
def get_tracing_mode() -> tuple[bool, bool]:
    """Report which tracing destinations are active according to the environment.

    Both the ``OTEL_ENABLED`` and ``OTEL_ONLY`` flags are read through
    ``ls_utils.is_env_var_truish`` on every call.

    Returns:
        tuple[bool, bool]: ``(hybrid_otel_and_langsmith, is_otel_only)``.

        - ``(False, False)`` when OTEL is not enabled (LangSmith only).
        - ``(True, False)`` when OTEL is enabled and ``OTEL_ONLY`` is not
          truthy: data goes to both OTEL and LangSmith.
        - ``(False, True)`` when OTEL is enabled and ``OTEL_ONLY`` is truthy.
    """
    otel_enabled = ls_utils.is_env_var_truish("OTEL_ENABLED")
    otel_only = ls_utils.is_env_var_truish("OTEL_ONLY")

    if otel_enabled:
        # Exactly one of the two modes is active once OTEL is switched on.
        return (not otel_only, otel_only)
    # If OTEL is not enabled, neither mode should be active
    return (False, False)
def tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None:
    """Run the main background loop that drains a client's tracing queue.

    The loop:

    * optionally starts the compressed-ingest thread (multipart enabled,
      compression not disabled, and the server advertises
      ``zstd_compression_enabled``),
    * spawns extra ``_tracing_sub_thread_func`` workers while the queue is
      longer than ``scale_up_qsize_trigger`` (up to
      ``scale_up_nthreads_limit``),
    * drains batches from the queue and dispatches them to LangSmith, OTEL,
      or both, according to ``get_tracing_mode()``.

    It stops when the client was manually cleaned up, the main thread has
    died, or the client's reference count shows that only the background
    machinery still references it; the queue is then drained one last time.

    Args:
        client_ref: Weak reference to the owning client. The function returns
            immediately if the client is already gone.
    """
    client = client_ref()
    if client is None:
        return
    tracing_queue = client.tracing_queue
    assert tracing_queue is not None
    batch_ingest_config = _ensure_ingest_config(client.info)
    size_limit: int = batch_ingest_config["size_limit"]
    scale_up_nthreads_limit: int = batch_ingest_config["scale_up_nthreads_limit"]
    scale_up_qsize_trigger: int = batch_ingest_config["scale_up_qsize_trigger"]
    use_multipart = not client._multipart_disabled and batch_ingest_config.get(
        "use_multipart_endpoint", True
    )

    sub_threads: list[threading.Thread] = []
    # 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached
    num_known_refs = 3

    # Disable compression if explicitly set, using OpenTelemetry, or zstd unavailable
    if not ZSTD_AVAILABLE:
        logger.debug(
            "zstandard package is not installed. "
            "Falling back to uncompressed multipart ingestion."
        )
    disable_compression = (
        ls_utils.is_env_var_truish("DISABLE_RUN_COMPRESSION")
        or client.otel_exporter is not None
        or not ZSTD_AVAILABLE
    )
    if not disable_compression and use_multipart:
        if not (client.info.instance_flags or {}).get(
            "zstd_compression_enabled", False
        ):
            logger.warning(
                "Run compression is not enabled. Please update to the latest "
                "version of LangSmith. Falling back to regular multipart ingestion."
            )
        else:
            client._futures = weakref.WeakSet()
            client.compressed_traces = CompressedTraces()
            client._data_available_event = threading.Event()
            threading.Thread(
                target=tracing_control_thread_func_compress_parallel,
                args=(weakref.ref(client),),
                daemon=client._use_daemon_threads,
            ).start()

            # The compression thread holds one more reference to the client.
            num_known_refs += 1

    def keep_thread_active() -> bool:
        # if `client.cleanup()` was called, stop thread
        if not client or (
            hasattr(client, "_manual_cleanup") and client._manual_cleanup
        ):
            logger.debug("Client is being cleaned up, stopping tracing thread")
            return False
        if not threading.main_thread().is_alive():
            # main thread is dead. should not be active
            logger.debug("Main thread is dead, stopping tracing thread")
            return False

        if hasattr(sys, "getrefcount"):
            # check if client refs count indicates we're the only remaining
            # reference to the client; each live sub-thread accounts for one
            # expected reference
            should_keep_thread = sys.getrefcount(client) > num_known_refs + len(
                sub_threads
            )
            if not should_keep_thread:
                logger.debug(
                    "Client refs count indicates we're the only remaining reference "
                    "to the client, stopping tracing thread",
                )
            return should_keep_thread
        else:
            # in PyPy, there is no sys.getrefcount attribute
            # for now, keep thread alive
            return True

    # loop until keep_thread_active() reports that the thread should stop
    while keep_thread_active():
        # Forget finished workers. The list is rebuilt in place rather than
        # calling remove() while iterating, which skips the element after each
        # removal and would leave dead threads counted in len(sub_threads).
        sub_threads[:] = [thread for thread in sub_threads if thread.is_alive()]
        if (
            len(sub_threads) < scale_up_nthreads_limit
            and tracing_queue.qsize() > scale_up_qsize_trigger
        ):
            new_thread = threading.Thread(
                target=_tracing_sub_thread_func,
                args=(weakref.ref(client), use_multipart),
                daemon=client._use_daemon_threads,
            )
            sub_threads.append(new_thread)
            new_thread.start()

        hybrid_otel_and_langsmith, is_otel_only = get_tracing_mode()
        # 0 means "no byte limit" for _tracing_thread_drain_queue
        max_batch_size = (
            client._max_batch_size_bytes
            or batch_ingest_config.get("size_limit_bytes")
            or 0
        )
        if next_batch := _tracing_thread_drain_queue(
            tracing_queue, limit=size_limit, max_size_bytes=max_batch_size
        ):
            if hybrid_otel_and_langsmith:
                # Hybrid mode: both OTEL and LangSmith
                _hybrid_tracing_thread_handle_batch(
                    client, tracing_queue, next_batch, use_multipart
                )
            elif is_otel_only:
                # OTEL-only mode
                _otel_tracing_thread_handle_batch(client, tracing_queue, next_batch)
            else:
                # LangSmith-only mode
                _tracing_thread_handle_batch(
                    client, tracing_queue, next_batch, use_multipart
                )

    # drain the queue on exit - apply same logic
    hybrid_otel_and_langsmith, is_otel_only = get_tracing_mode()
    max_batch_size = (
        client._max_batch_size_bytes or batch_ingest_config.get("size_limit_bytes") or 0
    )
    while next_batch := _tracing_thread_drain_queue(
        tracing_queue, limit=size_limit, block=False, max_size_bytes=max_batch_size
    ):
        if hybrid_otel_and_langsmith:
            # Hybrid mode cleanup
            logger.debug("Hybrid mode cleanup")
            _hybrid_tracing_thread_handle_batch(
                client, tracing_queue, next_batch, use_multipart
            )
        elif is_otel_only:
            # OTEL-only cleanup
            logger.debug("OTEL-only cleanup")
            _otel_tracing_thread_handle_batch(client, tracing_queue, next_batch)
        else:
            # LangSmith-only cleanup
            logger.debug("LangSmith-only cleanup")
            _tracing_thread_handle_batch(
                client, tracing_queue, next_batch, use_multipart
            )
    logger.debug("Tracing control thread is shutting down")
def tracing_control_thread_func_compress_parallel(
    client_ref: weakref.ref[Client], flush_interval: float = 0.5
) -> None:
    """Background loop that ships the client's compressed trace buffer.

    The loop wakes up when ``client._data_available_event`` is set (or every
    50 ms otherwise). On a wake-up caused by new data it sends the buffer only
    if a size/count limit has been reached; on an idle wake-up it force-flushes
    whatever is buffered once ``flush_interval`` seconds have passed since the
    last flush. A final forced flush runs when the loop exits.

    Args:
        client_ref: Weak reference to the owning client. The function returns
            immediately if the client is already gone.
        flush_interval: Seconds without a flush after which buffered data is
            sent even though no limit has been reached.
    """
    client = client_ref()
    if client is None:
        return

    logger.debug("Tracing control thread func compress parallel called")
    # These attributes are expected to be set up by the caller that starts
    # this thread; bail out if any of them is missing.
    if (
        client.compressed_traces is None
        or client._data_available_event is None
        or client._futures is None
    ):
        logger.error(
            "LangSmith tracing error: Required compression attributes not "
            "initialized.\nThis may affect trace submission but does not "
            "impact your application's runtime."
        )
        return

    batch_ingest_config = _ensure_ingest_config(client.info)
    size_limit: int = batch_ingest_config["size_limit"]
    # NOTE(review): the 20_971_520 fallback only applies if the key is absent;
    # the default config built by _ensure_ingest_config sets the key to None,
    # which would yield None here (no byte limit) - confirm this is intended.
    size_limit_bytes = client._max_batch_size_bytes or batch_ingest_config.get(
        "size_limit_bytes", 20_971_520
    )
    # One for this func, one for the parent thread, one for getrefcount,
    # one for _get_data_type_cached
    num_known_refs = 4

    def keep_thread_active() -> bool:
        # if `client.cleanup()` was called, stop thread
        if not client or (
            hasattr(client, "_manual_cleanup") and client._manual_cleanup
        ):
            logger.debug("Client is being cleaned up, stopping compression thread")
            return False
        if not threading.main_thread().is_alive():
            # main thread is dead. should not be active
            logger.debug("Main thread is dead, stopping compression thread")
            return False
        if hasattr(sys, "getrefcount"):
            # check if client refs count indicates we're the only remaining
            # reference to the client
            should_keep_thread = sys.getrefcount(client) > num_known_refs
            if not should_keep_thread:
                logger.debug(
                    "Client refs count indicates we're the only remaining reference "
                    "to the client, stopping compression thread",
                )
            return should_keep_thread
        else:
            # in PyPy, there is no sys.getrefcount attribute
            # for now, keep thread alive
            return True

    last_flush_time = time.monotonic()

    while True:
        # Wait up to 50 ms for new data; `triggered` is False on timeout.
        triggered = client._data_available_event.wait(timeout=0.05)
        if not keep_thread_active():
            break

        # If data arrived, clear the event and attempt a drain
        if triggered:
            client._data_available_event.clear()

            # Returns (None, None) unless a size/count limit has been reached.
            data_stream, compressed_traces_info = (
                _tracing_thread_drain_compressed_buffer
            )(client, size_limit, size_limit_bytes)
            # If we have data, submit the send request
            if data_stream is not None:
                try:
                    # Fire and forget: the future is only tracked in
                    # client._futures, not awaited here.
                    future = LANGSMITH_CLIENT_THREAD_POOL.submit(
                        client._send_compressed_multipart_req,
                        data_stream,
                        compressed_traces_info,
                    )
                    client._futures.add(future)
                except RuntimeError:
                    # The pool refused the task; send synchronously instead.
                    client._send_compressed_multipart_req(
                        data_stream,
                        compressed_traces_info,
                    )
                last_flush_time = time.monotonic()
        else:
            # Idle wake-up: flush on the timer even if no limit was reached.
            if (time.monotonic() - last_flush_time) >= flush_interval:
                # Limits of 1 make the drain hand over any non-empty buffer.
                (
                    data_stream,
                    compressed_traces_info,
                ) = _tracing_thread_drain_compressed_buffer(
                    client, size_limit=1, size_limit_bytes=1
                )
                if data_stream is not None:
                    try:
                        # Unlike the triggered path, block until the send
                        # has completed.
                        cf.wait(
                            [
                                LANGSMITH_CLIENT_THREAD_POOL.submit(
                                    client._send_compressed_multipart_req,
                                    data_stream,
                                    compressed_traces_info,
                                )
                            ]
                        )
                    except RuntimeError:
                        # The pool refused the task; send synchronously instead.
                        client._send_compressed_multipart_req(
                            data_stream,
                            compressed_traces_info,
                        )
                    last_flush_time = time.monotonic()

    # Drain the buffer on exit (final flush)
    try:
        (
            final_data_stream,
            compressed_traces_info,
        ) = _tracing_thread_drain_compressed_buffer(
            client, size_limit=1, size_limit_bytes=1
        )
        if final_data_stream is not None:
            try:
                cf.wait(
                    [
                        LANGSMITH_CLIENT_THREAD_POOL.submit(
                            client._send_compressed_multipart_req,
                            final_data_stream,
                            compressed_traces_info,
                        )
                    ]
                )
            except RuntimeError:
                # The pool refused the task; send synchronously instead.
                client._send_compressed_multipart_req(
                    final_data_stream,
                    compressed_traces_info,
                )
    except Exception:
        logger.error(
            "LangSmith tracing error: Failed during final cleanup.\n"
            "This does not affect your application's runtime.\n"
            "Error details:",
            exc_info=True,
        )
    logger.debug("Compressed traces control thread is shutting down")
def _tracing_sub_thread_func(
    client_ref: weakref.ref[Client],
    use_multipart: bool,
) -> None:
    """Auxiliary worker that drains the tracing queue alongside the control thread.

    The worker keeps pulling batches from ``client.tracing_queue`` and
    dispatching them (LangSmith, OTEL, or both, per ``get_tracing_mode()``)
    until the main thread dies or the queue has come back empty more than
    ``scale_down_nempty_trigger`` times in a row. It then drains whatever is
    left without blocking and exits.

    Args:
        client_ref: Weak reference to the owning client. The function returns
            immediately if the client is already gone.
        use_multipart: Whether LangSmith batches go to the multipart endpoint.
    """
    client = client_ref()
    if client is None:
        return
    try:
        # Give up if the client info is falsy or cannot be obtained at all.
        if not client.info:
            return
    except BaseException as e:
        logger.debug("Error in tracing control thread: %s", e)
        return
    tracing_queue = client.tracing_queue
    assert tracing_queue is not None
    batch_ingest_config = _ensure_ingest_config(client.info)
    size_limit = batch_ingest_config.get("size_limit", 100)
    seen_successive_empty_queues = 0

    # keep looping while
    while (
        # the main thread is alive
        threading.main_thread().is_alive()
        # and the number of consecutive empty drains has not yet exceeded
        # scale_down_nempty_trigger
        and seen_successive_empty_queues
        <= batch_ingest_config["scale_down_nempty_trigger"]
    ):
        # 0 is passed when neither source provides a byte limit
        max_batch_size = (
            client._max_batch_size_bytes
            or batch_ingest_config.get("size_limit_bytes")
            or 0
        )
        if next_batch := _tracing_thread_drain_queue(
            tracing_queue, limit=size_limit, max_size_bytes=max_batch_size
        ):
            # Got work: restart the idle counter.
            seen_successive_empty_queues = 0

            hybrid_otel_and_langsmith, is_otel_only = get_tracing_mode()
            if hybrid_otel_and_langsmith:
                # Hybrid mode: both OTEL and LangSmith
                _hybrid_tracing_thread_handle_batch(
                    client, tracing_queue, next_batch, use_multipart
                )
            elif is_otel_only:
                # OTEL-only mode
                _otel_tracing_thread_handle_batch(client, tracing_queue, next_batch)
            else:
                # LangSmith-only mode
                _tracing_thread_handle_batch(
                    client, tracing_queue, next_batch, use_multipart
                )
        else:
            seen_successive_empty_queues += 1

    # drain the queue on exit - apply same logic
    hybrid_otel_and_langsmith, is_otel_only = get_tracing_mode()
    max_batch_size = (
        client._max_batch_size_bytes or batch_ingest_config.get("size_limit_bytes") or 0
    )
    # block=False: stop as soon as the queue is empty instead of waiting.
    while next_batch := _tracing_thread_drain_queue(
        tracing_queue, limit=size_limit, block=False, max_size_bytes=max_batch_size
    ):
        if hybrid_otel_and_langsmith:
            # Hybrid mode cleanup
            _hybrid_tracing_thread_handle_batch(
                client, tracing_queue, next_batch, use_multipart
            )
        elif is_otel_only:
            # OTEL-only cleanup
            logger.debug("OTEL-only cleanup")
            _otel_tracing_thread_handle_batch(client, tracing_queue, next_batch)
        else:
            # LangSmith-only cleanup
            logger.debug("LangSmith-only cleanup")
            _tracing_thread_handle_batch(
                client, tracing_queue, next_batch, use_multipart
            )
    logger.debug("Tracing control sub-thread is shutting down")

View File

@@ -0,0 +1,21 @@
import functools
import warnings
from typing import Callable
class LangSmithBetaWarning(UserWarning):
    """Warning category for LangSmith functionality that is still in beta.

    Being a ``UserWarning`` subclass, it can be filtered separately from other
    warnings via the standard ``warnings`` filters.
    """
@functools.lru_cache(maxsize=100)
def _warn_once(message: str, stacklevel: int = 2) -> None:
    """Emit ``message`` as a ``LangSmithBetaWarning``, once per distinct call.

    The ``lru_cache`` wrapper memoizes on the arguments, so repeating a call
    with the same ``message`` and ``stacklevel`` is answered from the cache
    and does not warn again while the entry stays among the 100 cached ones.
    """
    warnings.warn(message, category=LangSmithBetaWarning, stacklevel=stacklevel)
def warn_beta(func: Callable) -> Callable:
    """Decorate ``func`` so that calling it issues a beta warning.

    The warning goes through ``_warn_once``, which deduplicates repeated
    identical warnings. Arguments and the return value of ``func`` are passed
    through unchanged, and ``functools.wraps`` preserves its metadata.
    """

    @functools.wraps(func)
    def _beta_guarded(*args, **kwargs):
        notice = f"Function {func.__name__} is in beta."
        # stacklevel=3 is forwarded so the warning is attributed further up
        # the stack than this wrapper.
        _warn_once(notice, stacklevel=3)
        result = func(*args, **kwargs)
        return result

    return _beta_guarded

View File

@@ -0,0 +1,56 @@
import io
import threading
from typing import Optional
from langsmith import utils as ls_utils
# zstandard is an optional dependency: record whether it could be imported so
# that other code can check ZSTD_AVAILABLE before relying on compression.
try:
    from zstandard import ZstdCompressor  # type: ignore[import]

    ZSTD_AVAILABLE = True
except ImportError:
    ZSTD_AVAILABLE = False

# Compression settings are read from the environment once, at import time.
# A missing/empty value falls back to level 1 and threads -1; int() raises
# ValueError here if a variable holds a non-integer string.
compression_level = int(ls_utils.get_env_var("RUN_COMPRESSION_LEVEL") or 1)
compression_threads = int(ls_utils.get_env_var("RUN_COMPRESSION_THREADS") or -1)
# Default cap used by CompressedTraces when no explicit maximum is configured.
DEFAULT_MAX_UNCOMPRESSED_QUEUE_BYTES = 1024 * 1024 * 1024  # 1GB
class CompressedTraces:
    """In-memory buffer that accumulates zstd-compressed trace data.

    Data written to ``compressor_writer`` is compressed into ``buffer``.
    ``trace_count``, ``uncompressed_size`` and ``_context`` are bookkeeping
    fields maintained by the code that writes into the stream, and ``lock``
    is provided for callers to serialize access to the whole object.
    """

    def __init__(self, max_uncompressed_size_bytes: Optional[int] = None) -> None:
        """Create an empty buffer with a fresh compression stream.

        Args:
            max_uncompressed_size_bytes: Maximum total uncompressed size for
                the in-memory queue. When ``None``, the
                ``MAX_INGEST_MEMORY_BYTES`` environment setting is used if
                present, otherwise ``DEFAULT_MAX_UNCOMPRESSED_QUEUE_BYTES``.

        Raises:
            ImportError: If the ``zstandard`` package is not installed.
            ValueError: If ``MAX_INGEST_MEMORY_BYTES`` is set to a value that
                is not an integer.
        """
        if not ZSTD_AVAILABLE:
            raise ImportError(
                "zstandard is required for compressed trace ingestion. "
                "Install it with `pip install zstandard` or set the environment "
                "variable LANGSMITH_DISABLE_RUN_COMPRESSION=true to disable "
                "compression."
            )
        # Configure the maximum total uncompressed size for the in-memory queue.
        if max_uncompressed_size_bytes is None:
            max_bytes_str = ls_utils.get_env_var("MAX_INGEST_MEMORY_BYTES")
            if max_bytes_str is not None:
                max_uncompressed_size_bytes = int(max_bytes_str)
            else:
                max_uncompressed_size_bytes = DEFAULT_MAX_UNCOMPRESSED_QUEUE_BYTES

        self.max_uncompressed_size_bytes = max_uncompressed_size_bytes

        self.buffer: io.BytesIO = io.BytesIO()
        self.trace_count: int = 0
        self.lock = threading.Lock()
        self.uncompressed_size: int = 0
        self._context: list[str] = []

        self.compressor_writer = self._new_compressor_writer()

    def _new_compressor_writer(self):
        """Return a zstd stream writer targeting the current ``self.buffer``."""
        # closefd=False keeps the BytesIO usable after the writer is closed.
        return ZstdCompressor(
            level=compression_level, threads=compression_threads
        ).stream_writer(self.buffer, closefd=False)

    def reset(self) -> None:
        """Replace the buffer and compression stream and zero the counters.

        The previous ``buffer`` object is left untouched so a caller that kept
        a reference to it can still read the data. The new stream uses the
        same configured level and thread count as the initial one (previously
        the thread count was hard-coded to ``-1`` here, ignoring the
        ``RUN_COMPRESSION_THREADS`` setting after the first flush).
        """
        self.buffer = io.BytesIO()
        self.trace_count = 0
        self.uncompressed_size = 0
        self._context = []

        self.compressor_writer = self._new_compressor_writer()

View File

@@ -0,0 +1,9 @@
import uuid
# Default byte-size limit: 20 * 1024 * 1024.
_SIZE_LIMIT_BYTES = 20_971_520  # 20MB by default
# Defaults for the background-thread auto-scaling settings (queue length that
# triggers scale-up, maximum number of worker threads, and number of
# consecutive empty polls that triggers scale-down).
_AUTO_SCALE_UP_QSIZE_TRIGGER = 200
_AUTO_SCALE_UP_NTHREADS_LIMIT = 32
_AUTO_SCALE_DOWN_NEMPTY_TRIGGER = 4
_BLOCKSIZE_BYTES = 1024 * 1024  # 1MB
# Multipart boundary token: a random hex string generated once per process at
# import time.
_BOUNDARY = uuid.uuid4().hex
# NOTE(review): presumably the capacity of the tracing queue - confirm at the
# site where the queue is created.
_TRACING_QUEUE_MAX_SIZE = 10_000

View File

@@ -0,0 +1,29 @@
"""Shared context (ContextVars and global defaults) that configure tracing."""
import contextvars
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
# Client and RunTree are imported for type checkers only; at runtime the names
# are bound to ``Any`` so this module does not import those modules.
if TYPE_CHECKING:
    from langsmith.client import Client
    from langsmith.run_trees import RunTree
else:
    Client = Any  # type: ignore[assignment]
    RunTree = Any  # type: ignore[assignment]


# Context-local settings. Every variable defaults to ``None``.
_PROJECT_NAME = contextvars.ContextVar[Optional[str]]("_PROJECT_NAME", default=None)
_TAGS = contextvars.ContextVar[Optional[list[str]]]("_TAGS", default=None)
_METADATA = contextvars.ContextVar[Optional[dict[str, Any]]]("_METADATA", default=None)

# Tri-state plus a special mode: True, False, the literal "local", or None.
_TRACING_ENABLED = contextvars.ContextVar[Optional[Union[bool, Literal["local"]]]](
    "_TRACING_ENABLED", default=None
)
_CLIENT = contextvars.ContextVar[Optional["Client"]]("_CLIENT", default=None)
_PARENT_RUN_TREE = contextvars.ContextVar[Optional["RunTree"]](
    "_PARENT_RUN_TREE", default=None
)

# Not thread-local, so you can set this process-wide (before asyncio.run, etc.)
# These plain module globals mirror the context variables above.
_GLOBAL_PROJECT_NAME: Optional[str] = None
_GLOBAL_TAGS: Optional[list[str]] = None
_GLOBAL_METADATA: Optional[dict[str, Any]] = None
_GLOBAL_TRACING_ENABLED: Optional[Union[bool, Literal["local"]]] = None
_GLOBAL_CLIENT: Optional["Client"] = None

View File

@@ -0,0 +1,67 @@
from typing import Any, Callable, Literal, Optional
from typing_extensions import TypedDict
# Names of the edit-distance metrics accepted by EditDistance.
METRICS = Literal[
    "damerau_levenshtein",
    "levenshtein",
    "jaro",
    "jaro_winkler",
    "hamming",
    "indel",
]


class EditDistanceConfig(TypedDict, total=False):
    """Optional settings for ``EditDistance``; every key may be omitted."""

    # Which metric to use; EditDistance falls back to "damerau_levenshtein".
    metric: METRICS
    # Whether to use the normalized form of the distance; EditDistance
    # defaults this to True.
    normalize_score: bool
class EditDistance:
    """String edit-distance evaluator backed by ``rapidfuzz``.

    The metric is resolved once at construction time and reused by
    ``evaluate``.
    """

    def __init__(
        self,
        config: Optional[EditDistanceConfig] = None,
    ):
        """Resolve the distance function described by ``config``.

        Args:
            config: Optional settings. ``metric`` defaults to
                ``"damerau_levenshtein"`` and ``normalize_score`` to ``True``.

        Raises:
            ImportError: If ``rapidfuzz`` is not installed.
            ValueError: If the configured metric name is not supported.
        """
        config = config or {}
        metric = config.get("metric") or "damerau_levenshtein"
        self.metric = self._get_metric(
            metric, normalize_score=config.get("normalize_score", True)
        )

    def evaluate(
        self,
        prediction: str,
        reference: Optional[str] = None,
    ) -> float:
        """Return the configured distance between ``prediction`` and ``reference``."""
        return self.metric(prediction, reference)

    @staticmethod
    def _get_metric(distance: str, normalize_score: bool = True) -> Callable:
        """Look up the ``rapidfuzz`` distance callable for ``distance``.

        Args:
            distance: Metric name; one of the keys of the mapping below.
            normalize_score: If true, return the metric's
                ``normalized_distance`` function, otherwise its raw
                ``distance`` function.

        Raises:
            ImportError: If ``rapidfuzz`` is not installed.
            ValueError: If ``distance`` is not a known metric name.
        """
        try:
            from rapidfuzz import (  # type: ignore[import-not-found]
                distance as rf_distance,
            )
        except ImportError as err:
            # The two sentences previously ran together ("use.Please").
            raise ImportError(
                "This operation requires the rapidfuzz library to use. "
                "Please install it with `pip install -U rapidfuzz`."
            ) from err

        module_map: dict[str, Any] = {
            "damerau_levenshtein": rf_distance.DamerauLevenshtein,
            "levenshtein": rf_distance.Levenshtein,
            "jaro": rf_distance.Jaro,
            "jaro_winkler": rf_distance.JaroWinkler,
            "hamming": rf_distance.Hamming,
            "indel": rf_distance.Indel,
        }
        if distance not in module_map:
            raise ValueError(
                f"Invalid distance metric: {distance}"
                f"\nMust be one of: {list(module_map)}"
            )
        module = module_map[distance]
        if normalize_score:
            return module.normalized_distance
        return module.distance

View File

@@ -0,0 +1,190 @@
from __future__ import annotations
import logging
from collections.abc import Sequence
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Optional,
Union,
)
from typing_extensions import TypedDict
if TYPE_CHECKING:
import numpy as np # type: ignore
logger = logging.getLogger(__name__)
Matrix = Union[list[list[float]], list[Any], Any]
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Cosine similarity between every row of ``X`` and every row of ``Y``.

    ``simsimd`` is used when it can be imported; otherwise the result is
    computed with NumPy, where entries that come out as NaN or infinite
    (e.g. for zero-length vectors) are replaced by ``0.0``.

    Args:
        X: Matrix-like input with one vector per row.
        Y: Matrix-like input with the same number of columns as ``X``.

    Returns:
        An empty array if either input is empty, otherwise the similarities.

    Raises:
        ValueError: If ``X`` and ``Y`` have a different number of columns.
    """
    import numpy as np

    if len(X) == 0 or len(Y) == 0:
        return np.array([])

    left = np.array(X)
    right = np.array(Y)
    if left.shape[1] != right.shape[1]:
        raise ValueError(
            f"Number of columns in X and Y must be the same. X has shape {left.shape} "
            f"and Y has shape {right.shape}."
        )

    try:
        import simsimd as simd  # type: ignore

        left = np.array(left, dtype=np.float32)
        right = np.array(right, dtype=np.float32)
        # simsimd reports cosine *distance*; convert to similarity.
        fast_result = 1 - simd.cdist(left, right, metric="cosine")
        if isinstance(fast_result, float):
            return np.array([fast_result])
        return np.array(fast_result)
    except ImportError:
        logger.debug(
            "Unable to import simsimd, defaulting to NumPy implementation. If you want "
            "to use simsimd please install with `pip install simsimd`."
        )
        left_norms = np.linalg.norm(left, axis=1)
        right_norms = np.linalg.norm(right, axis=1)
        # Division by a zero norm is expected here; the resulting non-finite
        # entries are zeroed out right after, so silence the warnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            similarity = np.dot(left, right.T) / np.outer(left_norms, right_norms)
        similarity[~np.isfinite(similarity)] = 0.0
        return similarity
def _get_openai_encoder() -> Callable[[Sequence[str]], Sequence[Sequence[float]]]:
"""Get the OpenAI GPT-3 encoder."""
try:
from openai import Client as OpenAIClient
except ImportError:
raise ImportError(
"THe default encoder for the EmbeddingDistance class uses the OpenAI API. "
"Please either install the openai library with `pip install openai` or "
"provide a custom encoder function (Callable[[str], Sequence[float]])."
)
def encode_text(texts: Sequence[str]) -> Sequence[Sequence[float]]:
client = OpenAIClient()
response = client.embeddings.create(
input=list(texts), model="text-embedding-3-small"
)
return [d.embedding for d in response.data]
return encode_text
class EmbeddingConfig(TypedDict, total=False):
    """Optional settings for ``EmbeddingDistance``; every key may be omitted."""

    # Maps a list of texts to one embedding vector per text. EmbeddingDistance
    # falls back to the OpenAI encoder when this is not provided.
    encoder: Callable[[list[str]], Sequence[Sequence[float]]]
    # Distance function to apply; EmbeddingDistance defaults to "cosine".
    metric: Literal["cosine", "euclidean", "manhattan", "chebyshev", "hamming"]
class EmbeddingDistance:
    """Distance between the embeddings of a prediction and a reference text."""

    def __init__(
        self,
        config: Optional[EmbeddingConfig] = None,
    ):
        """Store the metric name and encoder from ``config``.

        Args:
            config: Optional settings. ``metric`` defaults to ``"cosine"``;
                ``encoder`` defaults to the OpenAI encoder, which requires
                the ``openai`` package.
        """
        config = config or {}
        self.distance = config.get("metric") or "cosine"
        self.encoder = config.get("encoder") or _get_openai_encoder()

    def evaluate(
        self,
        prediction: str,
        reference: str,
    ) -> float:
        """Embed both texts and return their distance as a Python scalar.

        Raises:
            ImportError: If NumPy is not installed.
            ValueError: If the configured metric name is not supported.
        """
        try:
            import numpy as np
        except ImportError as err:
            raise ImportError(
                "The EmbeddingDistance class requires NumPy. Please install it with "
                "`pip install numpy`."
            ) from err
        # Both texts are encoded in a single call; row 0 is the prediction,
        # row 1 the reference.
        embeddings = self.encoder([prediction, reference])
        vector = np.array(embeddings)
        return self._compute_distance(vector[0], vector[1]).item()

    def _compute_distance(self, a: np.ndarray, b: np.ndarray) -> np.floating:
        """Dispatch to the distance function named by ``self.distance``."""
        if self.distance == "cosine":
            return self._cosine_distance(a, b)  # type: ignore
        elif self.distance == "euclidean":
            return self._euclidean_distance(a, b)
        elif self.distance == "manhattan":
            return self._manhattan_distance(a, b)
        elif self.distance == "chebyshev":
            return self._chebyshev_distance(a, b)
        elif self.distance == "hamming":
            return self._hamming_distance(a, b)
        else:
            raise ValueError(f"Invalid distance metric: {self.distance}")

    # NOTE: the helpers below import NumPy locally. At module level ``np`` is
    # only imported under TYPE_CHECKING, so referring to it without a local
    # import raises NameError at runtime.

    @staticmethod
    def _cosine_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Compute the cosine distance between two vectors.

        Args:
            a (np.ndarray): The first vector.
            b (np.ndarray): The second vector.

        Returns:
            np.ndarray: The cosine distance.
        """
        return 1.0 - cosine_similarity([a], [b])

    @staticmethod
    def _euclidean_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
        """Compute the Euclidean distance between two vectors.

        Args:
            a (np.ndarray): The first vector.
            b (np.ndarray): The second vector.

        Returns:
            np.floating: The Euclidean distance.
        """
        import numpy as np

        return np.linalg.norm(a - b)

    @staticmethod
    def _manhattan_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
        """Compute the Manhattan distance between two vectors.

        Args:
            a (np.ndarray): The first vector.
            b (np.ndarray): The second vector.

        Returns:
            np.floating: The Manhattan distance.
        """
        import numpy as np

        return np.sum(np.abs(a - b))

    @staticmethod
    def _chebyshev_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
        """Compute the Chebyshev distance between two vectors.

        Args:
            a (np.ndarray): The first vector.
            b (np.ndarray): The second vector.

        Returns:
            np.floating: The Chebyshev distance.
        """
        import numpy as np

        return np.max(np.abs(a - b))

    @staticmethod
    def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
        """Compute the Hamming distance between two vectors.

        Args:
            a (np.ndarray): The first vector.
            b (np.ndarray): The second vector.

        Returns:
            np.floating: The fraction of positions at which the vectors differ.
        """
        import numpy as np

        return np.mean(a != b)

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from collections.abc import Iterable
from io import BufferedReader
from typing import Union
# One multipart form part as a ``(name, (None, payload, str, dict))`` pair,
# where the payload is raw bytes or an open buffered reader.
# NOTE(review): the trailing ``str`` and ``dict[str, str]`` look like the
# content type and extra headers - confirm against the code that builds parts.
MultipartPart = tuple[
    str, tuple[None, Union[bytes, BufferedReader], str, dict[str, str]]
]
class MultipartPartsAndContext:
    """A list of multipart parts together with a context string for them.

    Instances use ``__slots__``, so only ``parts`` and ``context`` can be set.
    """

    __slots__ = ("parts", "context")

    parts: list[MultipartPart]
    context: str

    def __init__(self, parts: list[MultipartPart], context: str) -> None:
        """Store ``parts`` (by reference, not copied) and ``context``."""
        self.parts, self.context = parts, context
def join_multipart_parts_and_context(
    parts_and_contexts: Iterable[MultipartPartsAndContext],
) -> MultipartPartsAndContext:
    """Merge several part bundles into a single one.

    Args:
        parts_and_contexts: Bundles to merge, consumed once, in order.

    Returns:
        A new bundle whose ``parts`` is the concatenation of all input parts
        and whose ``context`` is the input contexts joined with ``"; "``. An
        empty input yields no parts and an empty context string.
    """
    bundles = list(parts_and_contexts)
    merged_parts: list[MultipartPart] = [
        part for bundle in bundles for part in bundle.parts
    ]
    merged_context = "; ".join([bundle.context for bundle in bundles])
    return MultipartPartsAndContext(merged_parts, merged_context)

View File

@@ -0,0 +1,446 @@
from __future__ import annotations
import itertools
import logging
import os
import uuid
from collections.abc import Iterable
from io import BufferedReader
from typing import Literal, Optional, Union, cast
from langsmith import schemas as ls_schemas
from langsmith._internal import _orjson
from langsmith._internal._compressed_traces import CompressedTraces
from langsmith._internal._multipart import MultipartPart, MultipartPartsAndContext
from langsmith._internal._serde import dumps_json as _dumps_json
logger = logging.getLogger(__name__)
class SerializedRunOperation:
    """A run "post" or "patch" whose payload fields are already serialized.

    The large fields are held as separate byte strings; ``_none`` holds the
    remainder of the run object.
    """

    __slots__ = (
        "operation",
        "id",
        "trace_id",
        "_none",
        "inputs",
        "outputs",
        "events",
        "extra",
        "error",
        "serialized",
        "attachments",
    )

    operation: Literal["post", "patch"]
    id: uuid.UUID
    trace_id: uuid.UUID
    # this is the whole object, minus the other fields which
    # are popped (inputs/outputs/events/attachments)
    _none: bytes
    inputs: Optional[bytes]
    outputs: Optional[bytes]
    events: Optional[bytes]
    extra: Optional[bytes]
    error: Optional[bytes]
    serialized: Optional[bytes]
    attachments: Optional[ls_schemas.Attachments]

    def __init__(
        self,
        operation: Literal["post", "patch"],
        id: uuid.UUID,
        trace_id: uuid.UUID,
        _none: bytes,
        inputs: Optional[bytes] = None,
        outputs: Optional[bytes] = None,
        events: Optional[bytes] = None,
        extra: Optional[bytes] = None,
        error: Optional[bytes] = None,
        serialized: Optional[bytes] = None,
        attachments: Optional[ls_schemas.Attachments] = None,
    ) -> None:
        """Store every field verbatim; no validation or copying is done."""
        self.operation = operation
        self.id = id
        self.trace_id = trace_id
        self._none = _none
        self.inputs = inputs
        self.outputs = outputs
        self.events = events
        self.extra = extra
        self.error = error
        self.serialized = serialized
        self.attachments = attachments

    def calculate_serialized_size(self) -> int:
        """Calculate actual serialized size of this operation.

        Sums the lengths of all byte fields plus any attachment whose data is
        held in memory as ``bytes``; attachments given in any other form
        contribute nothing.
        """
        byte_fields = (
            self._none,
            self.inputs,
            self.outputs,
            self.events,
            self.extra,
            self.error,
            self.serialized,
        )
        total = sum(len(blob) for blob in byte_fields if blob)
        if self.attachments:
            for _content_type, data_or_path in self.attachments.values():
                if isinstance(data_or_path, bytes):
                    total += len(data_or_path)
        return total

    def deserialize_run_info(self) -> dict:
        """Deserialize the main run info (_none and extra, error and serialized)."""
        run_info = _orjson.loads(self._none)
        for field in ("extra", "error", "serialized"):
            raw = getattr(self, field)
            if raw is not None:
                run_info[field] = _orjson.loads(raw)
        return run_info

    def __eq__(self, other: object) -> bool:
        """Two operations are equal when every slot compares equal."""
        if not isinstance(other, SerializedRunOperation):
            return False
        fields = SerializedRunOperation.__slots__
        mine = tuple(getattr(self, field) for field in fields)
        theirs = tuple(getattr(other, field) for field in fields)
        return mine == theirs
class SerializedFeedbackOperation:
    """A feedback record held as serialized bytes together with its ids."""

    __slots__ = ("id", "trace_id", "feedback")

    id: uuid.UUID
    trace_id: uuid.UUID
    feedback: bytes

    def __init__(self, id: uuid.UUID, trace_id: uuid.UUID, feedback: bytes) -> None:
        """Store the ids and the serialized feedback verbatim."""
        self.id = id
        self.trace_id = trace_id
        self.feedback = feedback

    def calculate_serialized_size(self) -> int:
        """Calculate actual serialized size of this operation."""
        payload = self.feedback
        return len(payload)

    def __eq__(self, other: object) -> bool:
        """Two operations are equal when id, trace id and payload all match."""
        if not isinstance(other, SerializedFeedbackOperation):
            return False
        mine = (self.id, self.trace_id, self.feedback)
        theirs = (other.id, other.trace_id, other.feedback)
        return mine == theirs
def serialize_feedback_dict(
    feedback: Union[ls_schemas.FeedbackCreate, dict],
) -> SerializedFeedbackOperation:
    """Turn a feedback model or dict into a ``SerializedFeedbackOperation``.

    Objects exposing a callable ``model_dump`` are dumped to a dict first;
    anything else is used as a dict directly (and is therefore updated in
    place). Missing ``id`` / ``trace_id`` entries are filled with fresh
    ``uuid4`` values and string ids are parsed into ``uuid.UUID``.
    """
    dump = getattr(feedback, "model_dump", None)
    if callable(dump):
        feedback_create: dict = dump()  # type: ignore
    else:
        feedback_create = cast(dict, feedback)
    for key in ("id", "trace_id"):
        if key not in feedback_create:
            feedback_create[key] = uuid.uuid4()
        elif isinstance(feedback_create[key], str):
            feedback_create[key] = uuid.UUID(feedback_create[key])
    return SerializedFeedbackOperation(
        id=feedback_create["id"],
        trace_id=feedback_create["trace_id"],
        feedback=_dumps_json(feedback_create),
    )
def serialize_run_dict(
    operation: Literal["post", "patch"], payload: dict
) -> SerializedRunOperation:
    """Split a run payload into separately serialized fields.

    The large fields are popped out of ``payload`` (so the caller's dict is
    modified) and serialized on their own; whatever remains is serialized as
    the main ``_none`` body. Attachments are passed through unserialized.
    """

    def _encode(value):
        # Absent fields stay None rather than becoming JSON null bytes.
        return _dumps_json(value) if value is not None else None

    popped = {
        name: payload.pop(name, None)
        for name in (
            "inputs",
            "outputs",
            "events",
            "extra",
            "error",
            "serialized",
            "attachments",
        )
    }
    return SerializedRunOperation(
        operation=operation,
        id=payload["id"],
        trace_id=payload["trace_id"],
        _none=_dumps_json(payload),
        inputs=_encode(popped["inputs"]),
        outputs=_encode(popped["outputs"]),
        events=_encode(popped["events"]),
        extra=_encode(popped["extra"]),
        error=_encode(popped["error"]),
        serialized=_encode(popped["serialized"]),
        attachments=popped["attachments"],
    )
def combine_serialized_queue_operations(
    ops: list[Union[SerializedRunOperation, SerializedFeedbackOperation]],
) -> list[Union[SerializedRunOperation, SerializedFeedbackOperation]]:
    """Fold run patches into the matching run posts of the same batch.

    Every "patch" whose id matches a "post" in ``ops`` is merged into that
    post (the post object is modified in place). Patches without a matching
    post, and all non-run operations, are passed through unchanged. The result
    lists the posts first, followed by the passthrough operations.
    """
    posts_by_id = {}
    for candidate in ops:
        if (
            isinstance(candidate, SerializedRunOperation)
            and candidate.operation == "post"
        ):
            posts_by_id[candidate.id] = candidate

    leftovers: list[Union[SerializedRunOperation, SerializedFeedbackOperation]] = []
    for op in ops:
        if not isinstance(op, SerializedRunOperation):
            leftovers.append(op)
            continue
        if op.operation == "post":
            continue

        # must be patch
        target = posts_by_id.get(op.id)
        if target is None:
            leftovers.append(op)
            continue

        if op._none is not None and op._none != target._none:
            # TODO optimize this more - this would currently be slowest
            # for large payloads
            merged = _orjson.loads(target._none)
            merged.update(
                {k: v for k, v in _orjson.loads(op._none).items() if v is not None}
            )
            target._none = _orjson.dumps(merged)

        # A field present on the patch replaces the one on the post.
        for field in ("inputs", "outputs", "events", "extra", "error", "serialized"):
            patched = getattr(op, field)
            if patched is not None:
                setattr(target, field, patched)

        if op.attachments is not None:
            if target.attachments is None:
                target.attachments = {}
            target.attachments.update(op.attachments)

    return [*posts_by_id.values(), *leftovers]
def serialized_feedback_operation_to_multipart_parts_and_context(
    op: SerializedFeedbackOperation,
) -> MultipartPartsAndContext:
    """Wrap a serialized feedback operation as a single JSON multipart part.

    The part is named ``feedback.<id>`` and the context string records the
    trace id and feedback id.
    """
    payload = op.feedback
    part = (
        f"feedback.{op.id}",
        (
            None,
            payload,
            "application/json",
            {"Content-Length": str(len(payload))},
        ),
    )
    context = f"trace={op.trace_id},id={op.id}"
    return MultipartPartsAndContext([part], context)
def serialized_run_operation_to_multipart_parts_and_context(
    op: SerializedRunOperation,
) -> tuple[MultipartPartsAndContext, dict[str, BufferedReader]]:
    """Convert a serialized run operation into multipart parts.

    Returns the parts bundle together with a dict of the file handles opened
    for path-based attachments. This function does not close those handles;
    they are handed back to the caller.
    """
    prefix = f"{op.operation}.{op.id}"
    opened_files: dict[str, BufferedReader] = {}

    # this is main object, minus inputs/outputs/events/attachments
    parts: list[MultipartPart] = [
        (
            prefix,
            (
                None,
                op._none,
                "application/json",
                {"Content-Length": str(len(op._none))},
            ),
        )
    ]

    for field in ("inputs", "outputs", "events", "extra", "error", "serialized"):
        payload = getattr(op, field)
        if payload is None:
            continue
        parts.append(
            (
                f"{prefix}.{field}",
                (
                    None,
                    payload,
                    "application/json",
                    {"Content-Length": str(len(payload))},
                ),
            ),
        )

    for name, (content_type, data_or_path) in (op.attachments or {}).items():
        if "." in name:
            logger.warning(
                f"Skipping logging of attachment '{name}' "
                f"for run {op.id}:"
                " Invalid attachment name.  Attachment names must not contain"
                " periods ('.'). Please rename the attachment and try again."
            )
            continue

        part_name = f"attachment.{op.id}.{name}"
        if isinstance(data_or_path, bytes):
            parts.append(
                (
                    part_name,
                    (
                        None,
                        data_or_path,
                        content_type,
                        {"Content-Length": str(len(data_or_path))},
                    ),
                )
            )
            continue

        # Anything that is not bytes is treated as a filesystem path.
        try:
            file_size = os.path.getsize(data_or_path)
            handle = open(data_or_path, "rb")
        except FileNotFoundError:
            logger.warning(
                "Attachment file not found for run %s: %s", op.id, data_or_path
            )
            continue
        # The random suffix keeps keys unique when one path is attached twice.
        opened_files[str(data_or_path) + str(uuid.uuid4())] = handle
        parts.append(
            (
                part_name,
                (
                    None,
                    handle,
                    f"{content_type}; length={file_size}",
                    {},
                ),
            )
        )

    return (
        MultipartPartsAndContext(parts, f"trace={op.trace_id},id={op.id}"),
        opened_files,
    )
def encode_multipart_parts_and_context(
    parts_and_context: MultipartPartsAndContext,
    boundary: str,
) -> Iterable[tuple[bytes, Union[bytes, BufferedReader]]]:
    """Yield ``(encoded header block, payload)`` for every part.

    The header block contains the boundary line, the Content-Disposition
    (with a filename only when one is set), the Content-Type, any extra
    headers, and the blank line that ends the headers. The payload is yielded
    untouched.
    """
    for part_name, (filename, data, content_type, headers) in parts_and_context.parts:
        disposition = f'Content-Disposition: form-data; name="{part_name}"'
        if filename:
            disposition += f'; filename="{filename}"'
        extra_headers = "".join(f"{k}: {v}\r\n" for k, v in headers.items())
        header_block = (
            f"--{boundary}\r\n"
            f"{disposition}"
            f"\r\nContent-Type: {content_type}\r\n"
            f"{extra_headers}"
            "\r\n"
        )
        yield (header_block.encode(), data)
def compress_multipart_parts_and_context(
    parts_and_context: MultipartPartsAndContext,
    compressed_traces: CompressedTraces,
    boundary: str,
) -> bool:
    """Compress multipart parts into the shared compressed buffer.

    Returns True if the parts were enqueued into the compressed buffer, or False
    if they were rejected because the configured in-memory size limit would be
    exceeded.
    """
    write = compressed_traces.compressor_writer.write

    # Encode everything up front so the size check sees the whole operation.
    encoded: list[tuple[bytes, bytes]] = []
    pending_bytes = 0
    for header, body in encode_multipart_parts_and_context(
        parts_and_context, boundary
    ):
        # Normalise to bytes
        if isinstance(body, BufferedReader):
            body = body.read()
        elif not isinstance(body, (bytes, bytearray)):
            body = str(body).encode()
        encoded.append((header, body))
        pending_bytes += len(body)

    limit = getattr(compressed_traces, "max_uncompressed_size_bytes", None)
    if limit is not None and limit > 0:
        buffered = compressed_traces.uncompressed_size
        # Nothing is rejected while the buffer is still empty.
        if buffered > 0 and buffered + pending_bytes > limit:
            from langsmith.client import _log_tracing_drop

            _log_tracing_drop(
                f"compressed traces buffer full ({buffered}/{limit} bytes)"
            )
            return False

    for header, body in encoded:
        write(header)
        compressed_traces.uncompressed_size += len(body)
        write(body)
        write(b"\r\n")  # part terminator

    compressed_traces._context.append(parts_and_context.context)
    return True

View File

@@ -0,0 +1,88 @@
"""Stubs for orjson operations, compatible with PyPy via a json fallback."""
try:
    from orjson import (
        OPT_NON_STR_KEYS,
        OPT_SERIALIZE_DATACLASS,
        OPT_SERIALIZE_NUMPY,
        OPT_SERIALIZE_UUID,
        Fragment,
        JSONDecodeError,
        dumps,
        loads,
    )
except ImportError:
    # orjson is unavailable: provide stand-ins built on the stdlib json module.
    import dataclasses
    import json
    import uuid
    from json import JSONDecodeError  # type: ignore
    from typing import Any, Callable, Optional, Union

    DefaultFunc = Optional[Callable[[Any], Any]]

    # Bit flags with the same names as the orjson options used by this package.
    OPT_NON_STR_KEYS = 1
    OPT_SERIALIZE_DATACLASS = 2
    OPT_SERIALIZE_NUMPY = 4
    OPT_SERIALIZE_UUID = 8

    class Fragment:  # type: ignore
        """Holder for already-serialized JSON bytes."""

        def __init__(self, payloadb: bytes):
            self.payloadb = payloadb

    def dumps(
        obj: Any,
        /,
        default: DefaultFunc = None,
        option: Optional[int] = None,
    ) -> bytes:
        """Serialize ``obj`` to UTF-8 encoded JSON bytes using ``json``.

        ``option`` is a bitwise OR of the ``OPT_*`` flags above; ``default``
        is called for objects none of the enabled flags handle.
        """
        # for now, don't do anything for this case because `json.dumps`
        # automatically encodes non-str keys as str by default, unlike orjson
        # enable_non_str_keys = bool(option & OPT_NON_STR_KEYS)
        flags = 0 if option is None else option
        numpy_enabled = bool(flags & OPT_SERIALIZE_NUMPY)
        dataclass_enabled = bool(flags & OPT_SERIALIZE_DATACLASS)
        uuid_enabled = bool(flags & OPT_SERIALIZE_UUID)

        class CustomEncoder(json.JSONEncoder):  # type: ignore
            def encode(self, o: Any) -> str:
                # A Fragment passed as the top-level object is emitted verbatim.
                if isinstance(o, Fragment):
                    return o.payloadb.decode("utf-8")  # type: ignore
                return super().encode(o)

            def default(self, o: Any) -> Any:
                if uuid_enabled and isinstance(o, uuid.UUID):
                    return str(o)
                if numpy_enabled and hasattr(o, "tolist"):
                    # even objects like np.uint16(15) have a .tolist() function
                    return o.tolist()
                is_dataclass_instance = dataclasses.is_dataclass(
                    o
                ) and not isinstance(o, type)
                if dataclass_enabled and is_dataclass_instance:
                    return dataclasses.asdict(o)
                if default is None:
                    return super().default(o)
                return default(o)

        text = json.dumps(obj, cls=CustomEncoder)
        return text.encode("utf-8")

    def loads(payload: Union[bytes, bytearray, memoryview, str], /) -> Any:
        """Parse a JSON document with the stdlib ``json`` module."""
        return json.loads(payload)


__all__ = [
    "loads",
    "dumps",
    "Fragment",
    "JSONDecodeError",
    "OPT_SERIALIZE_NUMPY",
    "OPT_SERIALIZE_DATACLASS",
    "OPT_SERIALIZE_UUID",
    "OPT_NON_STR_KEYS",
]

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from uuid import UUID
def get_otel_trace_id_from_uuid(uuid_val: UUID) -> int:
    """Get OpenTelemetry trace ID as integer from UUID.

    Args:
        uuid_val: The UUID to convert.

    Returns:
        The full 128-bit value of the UUID as an integer.
    """
    # UUID.int is the same number the 32-digit hex string encodes.
    return uuid_val.int
def get_otel_span_id_from_uuid(uuid_val: UUID) -> int:
    """Get OpenTelemetry span ID as integer from UUID.

    Args:
        uuid_val: The UUID to convert.

    Returns:
        The first 8 bytes of the UUID read as a big-endian integer.
    """
    leading_bytes = uuid_val.bytes[:8]
    return int.from_bytes(leading_bytes, "big")

View File

@@ -0,0 +1,93 @@
import functools
from urllib3 import __version__ as urllib3version # type: ignore[import-untyped]
from urllib3 import connection # type: ignore[import-untyped]
def _ensure_str(s, encoding="utf-8", errors="strict") -> str:
if isinstance(s, str):
return s
if isinstance(s, bytes):
return s.decode(encoding, errors)
return str(s)
# Copied from https://github.com/urllib3/urllib3/blob/1c994dfc8c5d5ecaee8ed3eb585d4785f5febf6e/src/urllib3/connection.py#L231
def request(self, method, url, body=None, headers=None):
    """Make the request.

    This function is based on the urllib3 request method, with modifications
    to handle potential issues when using vcrpy in concurrent workloads.

    Args:
        self: The HTTPConnection instance.
        method (str): The HTTP method (e.g., 'GET', 'POST').
        url (str): The URL for the request.
        body (Optional[Any]): The body of the request.
        headers (Optional[dict]): Headers to send with the request.

    Returns:
        The result of calling the parent request method.
    """
    # Update the inner socket's timeout value to send the request.
    # This only triggers if the connection is re-used.
    sock = getattr(self, "sock", None)
    if sock is not None:
        sock.settimeout(self.timeout)

    # Avoid modifying the headers passed into .request()
    headers = {} if headers is None else headers.copy()

    has_user_agent = any(
        _ensure_str(key.lower()) == "user-agent" for key in headers
    )
    if not has_user_agent:
        headers["User-Agent"] = connection._get_default_user_agent()

    # The above is all the same ^^^
    # The following is different:
    return self._parent_request(method, url, body=body, headers=headers)
# Set once patch_urllib3() has run, so later calls return immediately.
_PATCHED = False


def patch_urllib3():
    """Patch the request method of urllib3 to avoid type errors when using vcrpy.

    In concurrent workloads (such as the tracing background queue), the
    connection pool can get in a state where an HTTPConnection is created
    before vcrpy patches the HTTPConnection class. In urllib3 >= 2.0 this isn't
    a problem since they use the proper super().request(...) syntax, but in older
    versions, super(HTTPConnection, self).request is used, resulting in a TypeError
    since self is no longer a subclass of "HTTPConnection" (which at this point
    is vcr.stubs.VCRConnection).

    This method patches the class to fix the super() syntax to avoid mixed inheritance.
    In the case of the LangSmith tracing logic, it doesn't really matter since we always
    exclude cache checks for calls to LangSmith.

    The patch is only applied for urllib3 versions older than 2.0.
    """
    global _PATCHED
    if _PATCHED:
        return

    from packaging import version

    if version.parse(urllib3version) < version.parse("2.0"):
        # Lookup the parent class and its request method
        base_cls = connection.HTTPConnection.__bases__[0]
        parent_request = base_cls.request

        def new_request(self, *args, **kwargs):
            """Handle parent request.

            This method binds the parent's request method to self and then
            calls our modified request function.
            """
            self._parent_request = functools.partial(parent_request, self)
            return request(self, *args, **kwargs)

        connection.HTTPConnection.request = new_request

    _PATCHED = True

View File

@@ -0,0 +1,163 @@
from __future__ import annotations
import base64
import collections
import datetime
import decimal
import ipaddress
import json
import logging
import pathlib
import re
import uuid
from typing import Any
from langsmith._internal import _orjson
try:
from zoneinfo import ZoneInfo # type: ignore[import-not-found]
except ImportError:
class ZoneInfo: # type: ignore[no-redef]
"""Introduced in python 3.9."""
logger = logging.getLogger(__name__)
def _simple_default(obj):
try:
# Only need to handle types that orjson doesn't serialize by default
# https://github.com/ijl/orjson#serialize
if isinstance(obj, datetime.datetime):
return obj.isoformat()
elif isinstance(obj, uuid.UUID):
return str(obj)
elif isinstance(obj, BaseException):
return {"error": type(obj).__name__, "message": str(obj)}
elif isinstance(obj, (set, frozenset, collections.deque)):
return list(obj)
elif isinstance(obj, (datetime.timezone, ZoneInfo)):
return obj.tzname(None)
elif isinstance(obj, datetime.timedelta):
return obj.total_seconds()
elif isinstance(obj, decimal.Decimal):
if obj.as_tuple().exponent >= 0:
return int(obj)
else:
return float(obj)
elif isinstance(
obj,
(
ipaddress.IPv4Address,
ipaddress.IPv4Interface,
ipaddress.IPv4Network,
ipaddress.IPv6Address,
ipaddress.IPv6Interface,
ipaddress.IPv6Network,
pathlib.Path,
),
):
return str(obj)
elif isinstance(obj, re.Pattern):
return obj.pattern
elif isinstance(obj, (bytes, bytearray)):
return base64.b64encode(obj).decode()
return str(obj)
except BaseException as e:
logger.debug(f"Failed to serialize {type(obj)} to JSON: {e}")
return str(obj)
# Ordered (method name, keyword arguments) pairs that _serialize_json tries on
# an object before falling back to _simple_default.
_serialization_methods = [
    (
        "model_dump",
        {"exclude_none": True, "mode": "json"},
    ),  # Pydantic V2 with non-serializable fields
    ("dict", {}),  # Pydantic V1 with non-serializable field
    ("to_dict", {}),  # dataclasses-json
]
# IMPORTANT: This function is used from Rust code in `langsmith-pyo3` serialization,
# in order to handle serializing these tricky Python types *from Rust*.
# Do not cause this function to become inaccessible (e.g. by deleting
# or renaming it) without also fixing the corresponding Rust code found in:
# rust/crates/langsmith-pyo3/src/serialization/mod.rs
def _serialize_json(obj: Any) -> Any:
try:
if isinstance(obj, (set, tuple)):
if hasattr(obj, "_asdict") and callable(obj._asdict):
# NamedTuple
return obj._asdict()
return list(obj)
for attr, kwargs in _serialization_methods:
if (
hasattr(obj, attr)
and callable(getattr(obj, attr))
and not isinstance(obj, type)
):
try:
method = getattr(obj, attr)
response = method(**kwargs)
if not isinstance(response, dict):
return str(response)
return response
except Exception as e:
logger.debug(
f"Failed to use {attr} to serialize {type(obj)} to"
f" JSON: {repr(e)}"
)
pass
return _simple_default(obj)
except BaseException as e:
logger.debug(f"Failed to serialize {type(obj)} to JSON: {e}")
return str(obj)
def _elide_surrogates(s: bytes) -> bytes:
pattern = re.compile(rb"\\ud[89a-f][0-9a-f]{2}", re.IGNORECASE)
result = pattern.sub(b"", s)
return result
def dumps_json(obj: Any) -> bytes:
    """Serialize an object to JSON-encoded bytes.

    Parameters
    ----------
    obj : Any
        The object to serialize.

    Returns:
    -------
    bytes
        The JSON document. If the primary serializer raises ``TypeError``,
        the object is re-serialized with ``json`` as ASCII and re-encoded;
        when even that cannot be parsed back, escaped surrogates are
        stripped from the ASCII output instead.
    """
    orjson_options = (
        _orjson.OPT_SERIALIZE_NUMPY
        | _orjson.OPT_SERIALIZE_DATACLASS
        | _orjson.OPT_SERIALIZE_UUID
        | _orjson.OPT_NON_STR_KEYS
    )
    try:
        return _orjson.dumps(obj, default=_serialize_json, option=orjson_options)
    except TypeError as e:
        # Usually caused by UTF surrogate characters
        logger.debug(f"Orjson serialization failed: {repr(e)}. Falling back to json.")
        ascii_json = json.dumps(
            obj,
            default=_serialize_json,
            ensure_ascii=True,
        ).encode("utf-8")
        try:
            reparsed = _orjson.loads(
                ascii_json.decode("utf-8", errors="surrogateescape")
            )
            return _orjson.dumps(reparsed)
        except _orjson.JSONDecodeError:
            return _elide_surrogates(ascii_json)

View File

@@ -0,0 +1,155 @@
"""UUID helpers backed by uuid-utils."""
from __future__ import annotations
import time
import uuid
import warnings
from typing import Final
import xxhash
from uuid_utils.compat import uuid7 as _uuid_utils_uuid7
_NANOS_PER_SECOND: Final = 1_000_000_000
def _to_timestamp_and_nanos(nanoseconds: int) -> tuple[int, int]:
"""Split a nanosecond timestamp into seconds and remaining nanoseconds."""
seconds, nanos = divmod(nanoseconds, _NANOS_PER_SECOND)
return seconds, nanos
def uuid7(nanoseconds: int | None = None) -> uuid.UUID:
    """Generate a UUID from a Unix timestamp in nanoseconds and random bits.

    UUIDv7 objects feature monotonicity within a millisecond.

    Args:
        nanoseconds: Optional ns timestamp. If not provided, uses current time.
    """
    # --- 48 ---   -- 4 --   --- 12 ---   -- 2 --   --- 30 ---   - 32 -
    # unix_ts_ms | version | counter_hi | variant | counter_lo | random
    #
    # 'counter = counter_hi | counter_lo' is a 42-bit counter constructed
    # with Method 1 of RFC 9562, §6.2, and its MSB is set to 0.
    #
    # 'random' is a 32-bit random value regenerated for every new UUID.
    #
    # If multiple UUIDs are generated within the same millisecond, the LSB
    # of 'counter' is incremented by 1. When overflowing, the timestamp is
    # advanced and the counter is reset to a random 42-bit integer with MSB
    # set to 0.

    # For now, just delegate to the uuid_utils implementation
    if nanoseconds is not None:
        seconds, nanos = _to_timestamp_and_nanos(nanoseconds)
        return _uuid_utils_uuid7(timestamp=seconds, nanos=nanos)
    return _uuid_utils_uuid7()
def is_uuid_v7(uuid_obj: uuid.UUID) -> bool:
    """Check if a UUID is version 7.

    Args:
        uuid_obj: The UUID to check.

    Returns:
        True if the UUID is version 7, False otherwise.
    """
    version = uuid_obj.version
    return version == 7
# Process-wide latch so the UUID v7 warning is emitted at most once.
_UUID_V7_WARNING_EMITTED = False


def warn_if_not_uuid_v7(uuid_obj: uuid.UUID, id_type: str) -> None:
    """Warn if a UUID is not version 7.

    The warning is emitted only for the first non-v7 UUID seen.

    Args:
        uuid_obj: The UUID to check.
        id_type: The type of ID (e.g., "run_id", "trace_id") for the warning message.
    """
    global _UUID_V7_WARNING_EMITTED
    if is_uuid_v7(uuid_obj) or _UUID_V7_WARNING_EMITTED:
        return
    _UUID_V7_WARNING_EMITTED = True
    message = (
        "LangSmith now uses UUID v7 for run and trace identifiers. "
        "This warning appears when passing custom IDs. "
        "Please use: from langsmith import uuid7\n"
        "            id = uuid7()\n"
        "Future versions will require UUID v7."
    )
    warnings.warn(message, UserWarning, stacklevel=3)
def uuid7_deterministic(original_id: uuid.UUID, key: str) -> uuid.UUID:
    """Generate a deterministic UUID7 derived from an original UUID and a key.

    This function creates a new UUID that:
    - Preserves the timestamp from the original UUID if it's UUID v7
    - Uses current time if the original is not UUID v7
    - Uses deterministic bits derived from hashing the original + key with XXH3-128
    - Is valid UUID v7 format

    This is used for creating replica IDs that maintain time-ordering properties
    while being deterministic across distributed systems.

    Args:
        original_id: The source UUID (ideally UUID v7 to preserve timestamp).
        key: A string key used for deterministic derivation (e.g., project name).

    Returns:
        A new UUID v7 with preserved timestamp (if original is v7) and
        deterministic random bits.

    Example:
        >>> original = uuid7()
        >>> replica_id = uuid7_deterministic(original, "replica-project")
        >>> # Same inputs always produce same output
        >>> assert uuid7_deterministic(original, "replica-project") == replica_id
    """
    # Deterministic bits come from the XXH3-128 hash of "<original>:<key>".
    digest = xxhash.xxh3_128(f"{original_id}:{key}".encode()).digest()

    # UUID7 layout (RFC 9562):
    #   bytes 0-5   unix_ts_ms (48-bit millisecond timestamp)
    #   byte  6     version nibble (0111) + 4 bits rand_a
    #   byte  7     rand_a (continued)
    #   byte  8     variant bits (10) + 6 bits rand_b
    #   bytes 9-15  rand_b (continued)
    if is_uuid_v7(original_id):
        # Preserve the timestamp of a v7 original.
        timestamp_bytes = original_id.bytes[0:6]
    else:
        # Non-v7 input: stamp with the current time, masked to 48 bits.
        timestamp_ms = time.time_ns() // 1_000_000
        timestamp_bytes = (timestamp_ms & 0xFFFF_FFFF_FFFF).to_bytes(6, "big")

    version_and_rand_a = bytes([0x70 | (digest[0] & 0x0F), digest[1]])
    variant_and_rand_b = bytes([0x80 | (digest[2] & 0x3F)]) + digest[3:10]

    raw = bytes(timestamp_bytes) + version_and_rand_a + variant_and_rand_b
    return uuid.UUID(bytes=raw)

View File

@@ -0,0 +1,106 @@
"""Client configuration for OpenTelemetry integration with LangSmith."""
import os
import warnings
from typing import TYPE_CHECKING
if TYPE_CHECKING:
try:
from opentelemetry.sdk.trace import TracerProvider # type: ignore[import]
except ImportError:
TracerProvider = object # type: ignore[assignment, misc]
from langsmith import utils as ls_utils
def _import_otel_client():
"""Dynamically import OTEL client modules when needed."""
try:
from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( # type: ignore[import]
OTLPSpanExporter,
)
from opentelemetry.sdk.resources import ( # type: ignore[import]
SERVICE_NAME,
Resource,
)
from opentelemetry.sdk.trace import TracerProvider # type: ignore[import]
from opentelemetry.sdk.trace.export import ( # type: ignore[import]
BatchSpanProcessor,
)
return (
OTLPSpanExporter,
SERVICE_NAME,
Resource,
TracerProvider,
BatchSpanProcessor,
)
except ImportError as e:
warnings.warn(
f"OTEL_ENABLED is set but OpenTelemetry packages are not installed: {e}"
)
return None
def get_otlp_tracer_provider() -> "TracerProvider":
    """Get the OTLP tracer provider for LangSmith.

    This function creates a tracer provider that exports spans using the OTLP protocol
    with LangSmith-specific defaults:

    - OTEL_EXPORTER_OTLP_ENDPOINT: https://api.smith.langchain.com/otel
    - OTEL_EXPORTER_OTLP_HEADERS: Contains x-api-key from LangSmith API key and
      Langsmith-Project header if project is configured

    These defaults can be overridden by setting the environment variables before
    calling this function. When a variable is unset, this function writes the
    default into ``os.environ``.

    Returns:
        TracerProvider: The OTLP tracer provider.

    Raises:
        ImportError: If the OpenTelemetry packages cannot be imported.
    """
    # Import OTEL modules dynamically
    otel_imports = _import_otel_client()
    if otel_imports is None:
        raise ImportError(
            "OpenTelemetry packages are required to use this function. "
            "Please install with `pip install langsmith[otel]`"
        )
    (
        OTLPSpanExporter,
        SERVICE_NAME,
        Resource,
        TracerProvider,
        BatchSpanProcessor,
    ) = otel_imports

    env = os.environ
    if "OTEL_EXPORTER_OTLP_ENDPOINT" not in env:
        env["OTEL_EXPORTER_OTLP_ENDPOINT"] = f"{ls_utils.get_api_url(None)}/otel"

    # Configure headers with API key and project if available
    if "OTEL_EXPORTER_OTLP_HEADERS" not in env:
        header_items = [f"x-api-key={ls_utils.get_api_key(None)}"]
        project = ls_utils.get_tracer_project()
        if project:
            header_items.append(f"Langsmith-Project={project}")
        env["OTEL_EXPORTER_OTLP_HEADERS"] = ",".join(header_items)

    resource = Resource(
        attributes={
            SERVICE_NAME: env.get("OTEL_SERVICE_NAME", "langsmith"),
            # Marker to identify LangSmith's internal provider
            "langsmith.internal_provider": True,
        }
    )
    provider = TracerProvider(resource=resource)
    provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter()))
    return provider

View File

@@ -0,0 +1,845 @@
"""OpenTelemetry exporter for LangSmith runs."""
from __future__ import annotations
import datetime
import logging
import time
import uuid
import warnings
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
try:
from opentelemetry.context.context import Context # type: ignore[import]
from opentelemetry.trace import Span # type: ignore[import]
except ImportError:
Context = Any # type: ignore[assignment, misc]
Span = Any # type: ignore[assignment, misc]
from langsmith import utils as ls_utils
from langsmith._internal import _orjson
from langsmith._internal._operations import (
SerializedRunOperation,
)
from langsmith._internal._otel_utils import (
get_otel_span_id_from_uuid,
get_otel_trace_id_from_uuid,
)
def _import_otel_exporter():
"""Dynamically import OTEL exporter modules when needed."""
try:
from opentelemetry import trace # type: ignore[import]
from opentelemetry.context.context import Context # type: ignore[import]
from opentelemetry.trace import ( # type: ignore[import]
NonRecordingSpan,
Span,
SpanContext,
TraceFlags,
TraceState,
set_span_in_context,
)
return (
trace,
Context,
NonRecordingSpan,
Span,
SpanContext,
TraceFlags,
TraceState,
set_span_in_context,
)
except ImportError as e:
warnings.warn(
f"OTEL_ENABLED is set but OpenTelemetry packages are not installed: {e}"
)
return None
logger = logging.getLogger(__name__)
# String keys used by this module, grouped by namespace prefix.

# OpenTelemetry GenAI semconv attribute names
GEN_AI_OPERATION_NAME = "gen_ai.operation.name"
GEN_AI_SYSTEM = "gen_ai.system"
GEN_AI_REQUEST_MODEL = "gen_ai.request.model"
GEN_AI_RESPONSE_MODEL = "gen_ai.response.model"
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens"
GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature"
GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p"
GEN_AI_REQUEST_FREQUENCY_PENALTY = "gen_ai.request.frequency_penalty"
GEN_AI_REQUEST_PRESENCE_PENALTY = "gen_ai.request.presence_penalty"
GEN_AI_RESPONSE_FINISH_REASONS = "gen_ai.response.finish_reasons"
GENAI_PROMPT = "gen_ai.prompt"
GENAI_COMPLETION = "gen_ai.completion"
GEN_AI_REQUEST_EXTRA_QUERY = "gen_ai.request.extra_query"
GEN_AI_REQUEST_EXTRA_BODY = "gen_ai.request.extra_body"
GEN_AI_SERIALIZED_NAME = "gen_ai.serialized.name"
GEN_AI_SERIALIZED_SIGNATURE = "gen_ai.serialized.signature"
GEN_AI_SERIALIZED_DOC = "gen_ai.serialized.doc"
GEN_AI_RESPONSE_ID = "gen_ai.response.id"
GEN_AI_RESPONSE_SERVICE_TIER = "gen_ai.response.service_tier"
GEN_AI_RESPONSE_SYSTEM_FINGERPRINT = "gen_ai.response.system_fingerprint"
GEN_AI_USAGE_INPUT_TOKEN_DETAILS = "gen_ai.usage.input_token_details"
GEN_AI_USAGE_OUTPUT_TOKEN_DETAILS = "gen_ai.usage.output_token_details"
# LangSmith custom attributes (all under the "langsmith." prefix)
LANGSMITH_SESSION_ID = "langsmith.trace.session_id"
LANGSMITH_SESSION_NAME = "langsmith.trace.session_name"
LANGSMITH_RUN_TYPE = "langsmith.span.kind"
LANGSMITH_NAME = "langsmith.trace.name"
LANGSMITH_METADATA = "langsmith.metadata"
LANGSMITH_TAGS = "langsmith.span.tags"
LANGSMITH_RUNTIME = "langsmith.span.runtime"
LANGSMITH_REQUEST_STREAMING = "langsmith.request.streaming"
LANGSMITH_REQUEST_HEADERS = "langsmith.request.headers"
# GenAI event names
GEN_AI_SYSTEM_MESSAGE = "gen_ai.system.message"
GEN_AI_USER_MESSAGE = "gen_ai.user.message"
GEN_AI_ASSISTANT_MESSAGE = "gen_ai.assistant.message"
GEN_AI_CHOICE = "gen_ai.choice"
WELL_KNOWN_OPERATION_NAMES = {
"llm": "chat",
"tool": "execute_tool",
"retriever": "embeddings",
"embedding": "embeddings",
"prompt": "chat",
}
def _get_operation_name(run_type: str) -> str:
return WELL_KNOWN_OPERATION_NAMES.get(run_type, run_type)
class OTELExporter:
__slots__ = [
"_tracer",
"_span_info",
"_otel_available",
"_trace",
"_span_ttl_seconds",
"_last_cleanup",
]
"""OpenTelemetry exporter for LangSmith runs."""
def __init__(self, tracer_provider=None, span_ttl_seconds=None):
    """Set up the exporter state and, when available, the OTEL tracer.

    Args:
        tracer_provider: Optional tracer provider passed to
            ``trace.get_tracer``. If not provided, the global tracer
            provider will be used.
        span_ttl_seconds: TTL for incomplete traces in seconds. If None,
            the value is read with
            ``ls_utils.get_env_var("OTEL_SPAN_TTL_SECONDS")``
            (LANGSMITH_OTEL_SPAN_TTL_SECONDS, default: 3600s).
    """
    if span_ttl_seconds is None:
        # No explicit TTL: fall back to the environment (default one hour).
        span_ttl_seconds = int(
            ls_utils.get_env_var("OTEL_SPAN_TTL_SECONDS", default="3600")
        )

    otel_imports = _import_otel_exporter()
    otel_available = otel_imports is not None
    if otel_available:
        # Only the ``trace`` module is needed here; the remaining names are
        # unpacked to keep the expected 8-item shape of the import tuple.
        (
            trace,
            _context_cls,
            _non_recording_span_cls,
            _span_cls,
            _span_context_cls,
            _trace_flags_cls,
            _trace_state_cls,
            _set_span_in_context,
        ) = otel_imports
        tracer = trace.get_tracer("langsmith", tracer_provider=tracer_provider)
    else:
        trace = None
        tracer = None

    self._tracer = tracer
    self._span_info = {}
    self._otel_available = otel_available
    self._trace = trace
    self._span_ttl_seconds = span_ttl_seconds
    self._last_cleanup = 0.0
def export_batch(
    self,
    operations: list[SerializedRunOperation],
    otel_context_map: dict[uuid.UUID, Optional[Context]],
) -> None:
    """Export a batch of serialized run operations to OTEL.

    "post" operations create a span that is then tracked in ``_span_info``;
    every other operation updates the span already tracked for that run.
    Errors are logged per operation and do not stop the batch.

    Args:
        operations: List of serialized run operations to export.
        otel_context_map: Optional OTEL context per run id, looked up with
            ``op.id`` when a span is created.
    """
    # Proactive cleanup of expired spans before handling new operations
    self._cleanup_stale_spans()

    for op in operations:
        try:
            run_info = self._deserialize_run_info(op)
            if not run_info:
                continue

            if op.operation != "post":
                self._update_span_for_run(op, run_info)
                continue

            span = self._create_span_for_run(
                op, run_info, otel_context_map.get(op.id)
            )
            if span:
                # Remember the creation time so stale spans can be expired.
                self._span_info[op.id] = {
                    "span": span,
                    "created_at": time.time(),
                }
        except Exception as e:
            logger.exception(f"Error processing operation {op.id}: {e}")
def _deserialize_run_info(self, op: SerializedRunOperation) -> Optional[dict]:
    """Decode the run info carried by an operation.

    Args:
        op: The serialized run operation.

    Returns:
        Whatever ``op.deserialize_run_info()`` returns, or None when that
        call raises (the failure is logged with its traceback).
    """
    try:
        run_info = op.deserialize_run_info()
    except Exception as e:
        logger.exception(f"Failed to deserialize run info for {op.id}: {e}")
        run_info = None
    return run_info
def _create_span_for_run(
    self,
    op: SerializedRunOperation,
    run_info: dict,
    otel_context: Optional[Context] = None,
) -> Optional[Span]:
    """Create an OpenTelemetry span for a run operation.

    The span is parented to the tracked span of ``parent_run_id`` when that
    parent is present in ``_span_info``. Otherwise ``otel_context`` is used
    if given, falling back to a context built from IDs derived
    deterministically from the run's trace id and run id.

    Args:
        op: The serialized run operation.
        run_info: The deserialized run info.
        otel_context: Optional OpenTelemetry context for this run. Only
            consulted when the run's parent span is not tracked.

    Returns:
        The created span, or None if creation failed.
    """
    try:
        # Parse both timestamps once, before the span exists, so a failure
        # here cannot leave a started span behind.
        start_time_utc_nano = self._as_utc_nano(run_info.get("start_time"))
        end_time = run_info.get("end_time")
        end_time_utc_nano = self._as_utc_nano(end_time)

        # Create deterministic trace and span IDs to match user OpenTelemetry spans
        trace_id_int = get_otel_trace_id_from_uuid(op.trace_id)
        span_id_int = get_otel_span_id_from_uuid(op.id)

        # Get OTEL imports for this operation
        otel_imports = _import_otel_exporter()
        if otel_imports is None:
            return None
        (
            trace,
            _Context,
            NonRecordingSpan,
            _Span,
            SpanContext,
            TraceFlags,
            TraceState,
            set_span_in_context,
        ) = otel_imports

        # Context carrying the deterministic IDs, used when nothing better
        # is available for a root span.
        span_context = SpanContext(
            trace_id=trace_id_int,
            span_id=span_id_int,
            is_remote=False,
            trace_flags=TraceFlags(TraceFlags.SAMPLED),
            trace_state=TraceState(),
        )
        deterministic_context = set_span_in_context(NonRecordingSpan(span_context))

        # Resolve the parent key once instead of re-parsing the UUID.
        parent_run_id = run_info.get("parent_run_id")
        parent_key = uuid.UUID(parent_run_id) if parent_run_id is not None else None

        if parent_key is not None and parent_key in self._span_info:
            # Use the parent span context
            span_parent_context = set_span_in_context(
                self._span_info[parent_key]["span"]
            )
        else:
            # For root spans, inherit from an existing OpenTelemetry context
            # if there is one; otherwise use our deterministic context
            span_parent_context = (
                otel_context if otel_context else deterministic_context
            )

        span = self._tracer.start_span(
            run_info.get("name"),
            context=span_parent_context,
            start_time=start_time_utc_nano,
        )

        # Set all attributes
        self._set_span_attributes(span, run_info, op)

        # Set status based on error
        error = run_info.get("error")
        if error:
            span.set_status(trace.StatusCode.ERROR)
            span.record_exception(Exception(error))
        else:
            span.set_status(trace.StatusCode.OK)

        # End the span if end_time is present. When the timestamp could not
        # be parsed the span is still ended, using the current time.
        if end_time:
            if end_time_utc_nano:
                span.end(end_time=end_time_utc_nano)
            else:
                span.end()
        return span
    except Exception as e:
        logger.exception(f"Failed to create span for run {op.id}: {e}")
        return None
def _update_span_for_run(self, op: SerializedRunOperation, run_info: dict) -> None:
    """Apply a non-"post" operation to the span tracked for its run.

    Attributes and status are refreshed; if the run info carries an
    ``end_time`` the span is ended and dropped from ``_span_info``.

    Args:
        op: The serialized run operation.
        run_info: The deserialized run info.
    """
    try:
        tracked = self._span_info
        if op.id not in tracked:
            logger.debug(f"No span found for run {op.id} during update")
            return
        span = tracked[op.id]["span"]

        # Refresh attributes and status
        self._set_span_attributes(span, run_info, op)
        error = run_info.get("error")
        if error:
            span.set_status(self._trace.StatusCode.ERROR)
            span.record_exception(Exception(error))
        else:
            span.set_status(self._trace.StatusCode.OK)

        end_time = run_info.get("end_time")
        if not end_time:
            # Span exists but no end_time - this is normal for ongoing operations
            logger.debug("Updated span (no end_time yet)")
            return

        # Close the span, with the reported end time when it parses
        end_time_utc_nano = self._as_utc_nano(end_time)
        if end_time_utc_nano:
            span.end(end_time=end_time_utc_nano)
        else:
            span.end()

        # The run is finished: stop tracking its span
        del tracked[op.id]
        logger.debug(f"Completed span, remaining spans: {len(self._span_info)}")
    except Exception as e:
        logger.exception(f"Failed to update span for run {op.id}: {e}")
def _cleanup_stale_spans(self) -> None:
"""Clean up spans older than TTL threshold."""
if not self._span_info:
return
current_time = time.time()
# Only run cleanup every 10 seconds to reduce overhead
if current_time - self._last_cleanup < 10.0:
return
self._last_cleanup = current_time
cutoff_time = current_time - self._span_ttl_seconds
# Remove spans older than TTL in one pass
stale_span_ids = [
span_id
for span_id, info in self._span_info.items()
if info["created_at"] < cutoff_time
]
if stale_span_ids:
logger.info(
f" LangSmith OTEL Cleanup: Removing {len(stale_span_ids)} stale spans"
)
for span_id in stale_span_ids:
self._remove_span(span_id)
def _remove_span(self, span_id: uuid.UUID) -> None:
"""Remove a single span and clean up resources.
Note:
We call `span.end()` here because spans in `_span_info` are orphaned -
they never received their patch operation and will never naturally complete.
Ending them gracefully is better than leaving them open indefinitely.
"""
if span_id not in self._span_info:
return
try:
# End the orphaned span gracefully
span = self._span_info[span_id]["span"]
# Check if span is still active before ending it
if (
hasattr(span, "end")
and hasattr(span, "is_recording")
and span.is_recording()
):
span.end()
logger.debug(f"Ended orphaned span {span_id}")
elif hasattr(span, "end"):
# Span already ended, just log it
logger.debug(f"Span {span_id} already ended, skipping end() call")
# Remove from tracking regardless
del self._span_info[span_id]
except Exception as e:
logger.debug(f"Error removing span {span_id}: {e}")
# Still try to remove from tracking even if ending failed
try:
del self._span_info[span_id]
except KeyError:
pass
def _extract_model_name(self, run_info: dict) -> Optional[str]:
"""Extract model name from run info.
Args:
run_info: The run info.
Returns:
The model name, or None if not found.
"""
# Try to get model name from metadata
if run_info.get("extra") and run_info["extra"].get("metadata"):
metadata = run_info["extra"]["metadata"]
# First check for ls_model_name in metadata
if metadata.get("ls_model_name"):
return metadata["ls_model_name"]
# Then check invocation_params for model info
if "invocation_params" in metadata:
invocation_params = metadata["invocation_params"]
# Check model first, then model_name
if invocation_params.get("model"):
return invocation_params["model"]
elif invocation_params.get("model_name"):
return invocation_params["model_name"]
return None
def _set_span_attributes(
    self,
    span: Span,
    run_info: dict,
    op: SerializedRunOperation,
) -> None:
    """Set attributes on the span.

    Covers LangSmith identifiers, the GenAI operation/system/model, token
    usage, invocation parameters, metadata, tags, selected "serialized"
    fields and finally the inputs/outputs carried by ``op``.

    Args:
        span: The span to set attributes on.
        run_info: The deserialized run info.
        op: The serialized run operation.
    """
    # Set LangSmith-specific attributes (skipped when the value is empty)
    for key, attribute in (
        ("run_type", LANGSMITH_RUN_TYPE),
        ("name", LANGSMITH_NAME),
        ("session_id", LANGSMITH_SESSION_ID),
        ("session_name", LANGSMITH_SESSION_NAME),
    ):
        value = run_info.get(key)
        if value:
            span.set_attribute(attribute, str(value))

    # Set gen_ai.operation.name (only known when the run is first posted).
    # ``or`` rather than a .get default: the key may be present with None.
    if op.operation == "post":
        operation_name = _get_operation_name(run_info.get("run_type") or "chain")
        span.set_attribute(GEN_AI_OPERATION_NAME, operation_name)

    # Set gen_ai.system
    self._set_gen_ai_system(span, run_info)

    # Set model name if available
    model_name = self._extract_model_name(run_info)
    if model_name:
        span.set_attribute(GEN_AI_REQUEST_MODEL, model_name)

    # Set token usage information
    for key, attribute in (
        ("prompt_tokens", GEN_AI_USAGE_INPUT_TOKENS),
        ("completion_tokens", GEN_AI_USAGE_OUTPUT_TOKENS),
        ("total_tokens", GEN_AI_USAGE_TOTAL_TOKENS),
    ):
        token_count = run_info.get(key)
        if token_count is not None:
            span.set_attribute(attribute, int(token_count))

    # Set other parameters from invocation_params
    self._set_invocation_parameters(span, run_info)

    # Set metadata and tags if available. "extra" / "metadata" may be
    # present but None, in which case a .get default would not apply.
    extra = run_info.get("extra") or {}
    metadata = extra.get("metadata") or {}
    for key, value in metadata.items():
        if value is not None:
            span.set_attribute(f"{LANGSMITH_METADATA}.{key}", value)

    tags = run_info.get("tags")
    if tags:
        if isinstance(tags, list):
            # str() keeps a stray non-string tag from aborting the export
            span.set_attribute(LANGSMITH_TAGS, ", ".join(str(tag) for tag in tags))
        else:
            span.set_attribute(LANGSMITH_TAGS, tags)

    # Support additional serialized attributes, if present
    serialized = run_info.get("serialized")
    if serialized and isinstance(serialized, dict):
        for key, attribute in (
            ("name", GEN_AI_SERIALIZED_NAME),
            ("signature", GEN_AI_SERIALIZED_SIGNATURE),
            ("doc", GEN_AI_SERIALIZED_DOC),
        ):
            value = serialized.get(key)
            if value is not None:
                span.set_attribute(attribute, value)

    # Set inputs/outputs if available
    self._set_io_attributes(span, op)
def _set_gen_ai_system(self, span: Span, run_info: dict) -> None:
    """Set the gen_ai.system attribute on the span based on the model provider.

    The provider is guessed from the lower-cased model name: the first
    matching rule wins, and "langchain" is used when no rule matches or no
    model name is available. The chosen value is also stored on the span
    object as ``_gen_ai_system``.

    Args:
        span: The span to set attributes on.
        run_info: The deserialized run info.
    """
    # Ordered (system, predicate) pairs. Order matters because a model name
    # can contain more than one of these substrings.
    provider_rules = (
        ("anthropic", lambda m: "anthropic" in m or m.startswith("claude")),
        ("aws.bedrock", lambda m: "bedrock" in m),
        ("az.ai.openai", lambda m: "azure" in m and "openai" in m),
        ("az.ai.inference", lambda m: "azure" in m and "inference" in m),
        ("cohere", lambda m: "cohere" in m),
        ("deepseek", lambda m: "deepseek" in m),
        ("gemini", lambda m: "gemini" in m),
        ("groq", lambda m: "groq" in m),
        ("ibm.watsonx.ai", lambda m: "watson" in m or "ibm" in m),
        ("mistral_ai", lambda m: "mistral" in m),
        ("openai", lambda m: "gpt" in m or "openai" in m),
        ("perplexity", lambda m: "perplexity" in m or "sonar" in m),
        ("vertex_ai", lambda m: "vertex" in m),
        ("xai", lambda m: "xai" in m or "grok" in m),
        ("qwen", lambda m: "qwen" in m),
    )

    # Default to "langchain" if we can't determine the system
    system = "langchain"
    model_name = self._extract_model_name(run_info)
    if model_name:
        model_lower = model_name.lower()
        for candidate, matches in provider_rules:
            if matches(model_lower):
                system = candidate
                break

    span.set_attribute(GEN_AI_SYSTEM, system)
    setattr(span, "_gen_ai_system", system)
def _set_invocation_parameters(self, span: Span, run_info: dict) -> None:
    """Copy selected invocation parameters onto the span.

    Reads ``run_info["extra"]["metadata"]["invocation_params"]`` and does
    nothing when any part of that path is missing or empty.

    Args:
        span: The span to set attributes on.
        run_info: The deserialized run info.
    """
    extra = run_info.get("extra")
    metadata = extra.get("metadata") if extra else None
    if not metadata:
        return
    if "invocation_params" not in metadata:
        return
    invocation_params = metadata["invocation_params"]

    # Parameter name -> span attribute key; a value is copied whenever the
    # parameter is present, whatever its value.
    for param, attribute in (
        ("max_tokens", GEN_AI_REQUEST_MAX_TOKENS),
        ("temperature", GEN_AI_REQUEST_TEMPERATURE),
        ("top_p", GEN_AI_REQUEST_TOP_P),
        ("frequency_penalty", GEN_AI_REQUEST_FREQUENCY_PENALTY),
        ("presence_penalty", GEN_AI_REQUEST_PRESENCE_PENALTY),
    ):
        if param in invocation_params:
            span.set_attribute(attribute, invocation_params[param])
def _set_io_attributes(self, span: Span, op: SerializedRunOperation) -> None:
    """Set input/output attributes on the span.

    ``op.inputs`` and ``op.outputs`` are JSON payloads. Each is decoded,
    well-known fields are copied to dedicated attributes when the decoded
    value is a dict, and the raw payload is then attached as the
    prompt/completion attribute. Any failure while handling one payload is
    logged at debug level and does not affect the other payload.

    Args:
        span: The span to set attributes on.
        op: The serialized run operation.
    """
    if op.inputs:
        try:
            inputs = _orjson.loads(op.inputs)
            if isinstance(inputs, dict):
                if (
                    inputs.get("model") is not None
                    and isinstance(inputs.get("messages"), list)
                ):
                    span.set_attribute(GEN_AI_REQUEST_MODEL, inputs["model"])
                # Set additional request attributes if available.
                for key, attribute in (
                    ("stream", LANGSMITH_REQUEST_STREAMING),
                    ("extra_headers", LANGSMITH_REQUEST_HEADERS),
                    ("extra_query", GEN_AI_REQUEST_EXTRA_QUERY),
                    ("extra_body", GEN_AI_REQUEST_EXTRA_BODY),
                ):
                    value = inputs.get(key)
                    if value is not None:
                        span.set_attribute(attribute, value)
            span.set_attribute(GENAI_PROMPT, op.inputs)
        except Exception:
            logger.debug(
                "Failed to process inputs for run %s", op.id, exc_info=True
            )

    if op.outputs:
        try:
            outputs = _orjson.loads(op.outputs)
            # Field extraction only makes sense for dict payloads. Guarding
            # all of it keeps list/scalar outputs from raising before the
            # raw completion below is attached.
            if isinstance(outputs, dict):
                # Extract token usage from outputs (for LLM runs)
                token_usage = self.get_unified_run_tokens(outputs)
                if token_usage:
                    input_tokens, output_tokens = token_usage
                    span.set_attribute(GEN_AI_USAGE_INPUT_TOKENS, input_tokens)
                    span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
                    span.set_attribute(
                        GEN_AI_USAGE_TOTAL_TOKENS, input_tokens + output_tokens
                    )

                # Skip a null model rather than emitting the string "None"
                if outputs.get("model") is not None:
                    span.set_attribute(GEN_AI_RESPONSE_MODEL, str(outputs["model"]))

                # Extract additional response attributes.
                if outputs.get("id") is not None:
                    span.set_attribute(GEN_AI_RESPONSE_ID, outputs["id"])

                choices = outputs.get("choices")
                if isinstance(choices, list):
                    finish_reasons = [
                        str(choice["finish_reason"])
                        for choice in choices
                        if isinstance(choice, dict)
                        and choice.get("finish_reason") is not None
                    ]
                    if finish_reasons:
                        span.set_attribute(
                            GEN_AI_RESPONSE_FINISH_REASONS,
                            ", ".join(finish_reasons),
                        )

                for key, attribute in (
                    ("service_tier", GEN_AI_RESPONSE_SERVICE_TIER),
                    ("system_fingerprint", GEN_AI_RESPONSE_SYSTEM_FINGERPRINT),
                ):
                    value = outputs.get(key)
                    if value is not None:
                        span.set_attribute(attribute, value)

                usage_metadata = outputs.get("usage_metadata")
                if isinstance(usage_metadata, dict):
                    # The detail objects are stringified before being set
                    for key, attribute in (
                        ("input_token_details", GEN_AI_USAGE_INPUT_TOKEN_DETAILS),
                        ("output_token_details", GEN_AI_USAGE_OUTPUT_TOKEN_DETAILS),
                    ):
                        details = usage_metadata.get(key)
                        if details is not None:
                            span.set_attribute(attribute, str(details))

            span.set_attribute(GENAI_COMPLETION, op.outputs)
        except Exception:
            logger.debug(
                "Failed to process outputs for run %s", op.id, exc_info=True
            )
def _as_utc_nano(self, timestamp: Optional[str]) -> Optional[int]:
if not timestamp:
return None
try:
dt = datetime.datetime.fromisoformat(timestamp)
return int(dt.astimezone(datetime.timezone.utc).timestamp() * 1_000_000_000)
except ValueError:
logger.exception(f"Failed to parse timestamp {timestamp}")
return None
def get_unified_run_tokens(
self, outputs: Optional[dict]
) -> Optional[tuple[int, int]]:
if not outputs:
return None
# search in non-generations lists
if output := self._extract_unified_run_tokens(outputs.get("usage_metadata")):
return output
# find if direct kwarg in outputs
keys = outputs.keys()
for key in keys:
haystack = outputs[key]
if not haystack or not isinstance(haystack, dict):
continue
if output := self._extract_unified_run_tokens(
haystack.get("usage_metadata")
):
return output
if (
haystack.get("lc") == 1
and "kwargs" in haystack
and isinstance(haystack["kwargs"], dict)
and (
output := self._extract_unified_run_tokens(
haystack["kwargs"].get("usage_metadata")
)
)
):
return output
# find in generations
generations = outputs.get("generations") or []
if not isinstance(generations, list):
return None
if generations and not isinstance(generations[0], list):
generations = [generations]
for generation in [x for xs in generations for x in xs]:
if (
isinstance(generation, dict)
and "message" in generation
and isinstance(generation["message"], dict)
and "kwargs" in generation["message"]
and isinstance(generation["message"]["kwargs"], dict)
and (
output := self._extract_unified_run_tokens(
generation["message"]["kwargs"].get("usage_metadata")
)
)
):
return output
return None
def _extract_unified_run_tokens(
self, outputs: Optional[Any]
) -> Optional[tuple[int, int]]:
if not outputs or not isinstance(outputs, dict):
return None
if "input_tokens" not in outputs or "output_tokens" not in outputs:
return None
if not isinstance(outputs["input_tokens"], int) or not isinstance(
outputs["output_tokens"], int
):
return None
return outputs["input_tokens"], outputs["output_tokens"]

View File

@@ -0,0 +1,201 @@
import re # noqa
import inspect
from abc import abstractmethod
from collections import defaultdict
from typing import Any, Callable, Optional, TypedDict, Union
class _ExtractOptions(TypedDict):
max_depth: Optional[int]
"""
Maximum depth to traverse to to extract string nodes
"""
class StringNode(TypedDict):
"""String node extracted from the data."""
value: str
"""String value."""
path: list[Union[str, int]]
"""Path to the string node in the data."""
def _extract_string_nodes(data: Any, options: _ExtractOptions) -> list[StringNode]:
max_depth = options.get("max_depth") or 10
queue: list[tuple[Any, int, list[Union[str, int]]]] = [(data, 0, [])]
result: list[StringNode] = []
while queue:
task = queue.pop(0)
if task is None:
continue
value, depth, path = task
if isinstance(value, (dict, defaultdict)):
if depth >= max_depth:
continue
for key, nested_value in value.items():
queue.append((nested_value, depth + 1, path + [key]))
elif isinstance(value, list):
if depth >= max_depth:
continue
for i, item in enumerate(value):
queue.append((item, depth + 1, path + [i]))
elif isinstance(value, str):
result.append(StringNode(value=value, path=path))
return result
class StringNodeProcessor:
    """Processes a list of string nodes for masking.

    Subclasses implement `mask_nodes`.
    """

    # NOTE(review): this class does not derive from abc.ABC, so the
    # decorator below does not by itself prevent instantiation. Confirm
    # whether that is intended before relying on it.
    @abstractmethod
    def mask_nodes(self, nodes: list[StringNode]) -> list[StringNode]:
        """Accept and return a list of string nodes to be masked."""
class ReplacerOptions(TypedDict):
    """Configuration options for replacing sensitive data."""

    # Traversal depth limit for string extraction.
    max_depth: Optional[int]
    """Maximum depth to traverse to to extract string nodes."""

    # NOTE(review): no code in the visible part of this module reads
    # ``deep_clone`` - confirm where (or whether) it is honoured.
    deep_clone: Optional[bool]
    """Deep clone the data before replacing."""
class StringNodeRule(TypedDict):
    """Declarative rule used for replacing sensitive data."""

    # RuleNodeProcessor passes a value that is not already a compiled
    # pattern through re.compile.
    pattern: re.Pattern
    """Regex pattern to match."""

    # Any non-string value here (including None) selects the default.
    replace: Optional[str]
    """Replacement value. Defaults to `[redacted]` if not specified."""
class RuleNodeProcessor(StringNodeProcessor):
    """String node processor that uses a list of rules to replace sensitive data."""

    rules: list[StringNodeRule]
    """List of rules to apply for replacing sensitive data.
    Each rule is a StringNodeRule, which contains a regex pattern to match
    and an optional replacement string.
    """

    def __init__(self, rules: list[StringNodeRule]):
        """Normalise the rules: compile patterns and fill in replacements."""
        normalized = []
        for rule in rules:
            pattern = rule["pattern"]
            if not isinstance(pattern, re.Pattern):
                # Accept plain pattern strings as well as compiled patterns
                pattern = re.compile(pattern)

            replacement = rule.get("replace")
            if not isinstance(replacement, str):
                replacement = "[redacted]"

            normalized.append({"pattern": pattern, "replace": replacement})
        self.rules = normalized

    def mask_nodes(self, nodes: list[StringNode]) -> list[StringNode]:
        """Apply every rule, in order, and return only the nodes that changed."""
        masked = []
        for node in nodes:
            original = node["value"]
            updated = original
            for rule in self.rules:
                updated = rule["pattern"].sub(rule["replace"], updated)
            if updated != original:
                masked.append(StringNode(value=updated, path=node["path"]))
        return masked
class CallableNodeProcessor(StringNodeProcessor):
    """String node processor that uses a callable function to replace sensitive data."""

    func: Union[Callable[[str], str], Callable[[str, list[Union[str, int]]], str]]
    """The callable function used to replace sensitive data.
    It can be either a function that takes a single string argument and returns a string,
    or a function that takes a string and a list of path elements (strings or integers)
    and returns a string."""

    accepts_path: bool
    """Indicates whether the callable function accepts a path argument.
    If True, the function expects two arguments: the string to be processed and the path to that string.
    If False, the function expects only the string to be processed."""

    def __init__(
        self,
        func: Union[Callable[[str], str], Callable[[str, list[Union[str, int]]], str]],
    ):
        """Store the callable and detect whether it takes a path argument."""
        self.func = func
        # Exactly two parameters means the callable is given (value, path)
        parameter_count = len(inspect.signature(func).parameters)
        self.accepts_path = parameter_count == 2

    def mask_nodes(self, nodes: list[StringNode]) -> list[StringNode]:
        """Run the callable on each node and return only the nodes it changed."""
        retval: list[StringNode] = []
        for node in nodes:
            value = node["value"]
            if self.accepts_path:
                candidate = self.func(value, node["path"])  # type: ignore[call-arg]
            else:
                candidate = self.func(value)  # type: ignore[call-arg]
            if candidate != value:
                retval.append(StringNode(value=candidate, path=node["path"]))
        return retval
# Accepted replacer forms: a callable, a list of rules, or a ready processor.
ReplacerType = Union[
    Callable[[str, list[Union[str, int]]], str],
    list[StringNodeRule],
    StringNodeProcessor,
]


def _get_node_processor(replacer: ReplacerType) -> StringNodeProcessor:
    """Turn a replacer into a processor.

    Lists become a `RuleNodeProcessor`, callables a `CallableNodeProcessor`,
    and anything else is returned unchanged.
    """
    if isinstance(replacer, list):
        return RuleNodeProcessor(rules=replacer)
    if callable(replacer):
        return CallableNodeProcessor(func=replacer)
    return replacer
def create_anonymizer(
    replacer: ReplacerType,
    *,
    max_depth: Optional[int] = None,
) -> Callable[[Any], Any]:
    """Create an anonymizer function.

    Args:
        replacer: A callable, a list of rules, or a processor used to mask
            string values.
        max_depth: Maximum depth searched for strings; None or 0 means 10.

    Returns:
        A function that masks the strings found in its argument. Nested
        containers are updated in place and the same object is returned;
        when the argument is itself a string that gets masked, the new
        string is returned instead.
    """
    processor = _get_node_processor(replacer)
    effective_depth = max_depth or 10

    def anonymizer(data: Any) -> Any:
        nodes = _extract_string_nodes(data, {"max_depth": effective_depth})
        masked_value = data
        for node in processor.mask_nodes(nodes):
            path = node["path"]
            if not path:
                # The top-level value itself is the string being replaced
                masked_value = node["value"]
                continue
            # Walk down to the container holding the string, then overwrite
            container = masked_value
            for part in path[:-1]:
                container = container[part]
            container[path[-1]] = node["value"]
        return masked_value

    return anonymizer

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
"""Beta functionality prone to change."""
from langsmith._internal._beta_decorator import warn_beta
from langsmith.beta._evals import compute_test_metrics, convert_runs_to_test
__all__ = ["convert_runs_to_test", "compute_test_metrics", "warn_beta"]

View File

@@ -0,0 +1,243 @@
"""Beta utility functions to assist in common eval workflows.
These functions may change in the future.
"""
import collections
import datetime
import itertools
import uuid
from collections.abc import Sequence
from typing import Optional, TypeVar
import langsmith.run_trees as rt
import langsmith.schemas as ls_schemas
from langsmith import evaluation as ls_eval
from langsmith._internal._beta_decorator import warn_beta
from langsmith.client import Client
def _convert_ids(run_dict: dict, id_map: dict) -> dict:
"""Convert the IDs in the run dictionary using the provided ID map.
Parameters:
- run_dict: The dictionary representing a run.
- id_map: The dictionary mapping old IDs to new IDs.
Returns:
- dict: The updated run dictionary.
"""
do = run_dict["dotted_order"]
for k, v in id_map.items():
do = do.replace(str(k), str(v))
run_dict["dotted_order"] = do
if run_dict.get("parent_run_id"):
run_dict["parent_run_id"] = id_map[run_dict["parent_run_id"]]
if not run_dict.get("extra"):
run_dict["extra"] = {}
return run_dict
def _convert_root_run(root: ls_schemas.Run, run_to_example_map: dict) -> list[dict]:
    """Clone a root run and its descendants as dictionaries with fresh IDs.

    Parameters:
    - root: The root run to convert.
    - run_to_example_map: The dictionary mapping run IDs to example IDs.

    Returns:
    - The list of converted run dictionaries; the first entry is the root
      and carries ``reference_example_id``.
    """
    new_trace_id = uuid.uuid4()
    id_map = {root.trace_id: new_trace_id}
    pending = [root]
    converted = []
    # Depth-first walk; the root is always handled first
    while pending:
        current = pending.pop()
        current_dict = current.dict(
            exclude={"parent_run_ids", "child_run_ids", "session_id"}
        )
        old_id = current_dict["id"]
        # Reuse an ID that is already mapped (e.g. the trace ID)
        id_map[old_id] = id_map.get(old_id, uuid.uuid4())
        current_dict["id"] = id_map[old_id]
        current_dict["trace_id"] = id_map[current_dict["trace_id"]]
        if current.child_runs:
            pending.extend(current.child_runs)
        converted.append(current_dict)

    # Remaining references can only be rewritten once every ID is mapped
    result = [_convert_ids(entry, id_map) for entry in converted]
    result[0]["reference_example_id"] = run_to_example_map[root.id]
    return result
@warn_beta
def convert_runs_to_test(
    runs: Sequence[ls_schemas.Run],
    *,
    dataset_name: str,
    test_project_name: Optional[str] = None,
    client: Optional[Client] = None,
    load_child_runs: bool = False,
    include_outputs: bool = False,
) -> ls_schemas.TracerSession:
    """Convert the following runs to a dataset + test.

    This makes it easy to sample prod runs into a new regression testing
    workflow and compare against a candidate system.

    Internally, this function does the following:
        1. Create a dataset from the provided production run inputs.
        2. Create a new test project.
        3. Clone the production runs and re-upload against the dataset.

    Parameters:
    - runs: The runs to be executed as a test. Any iterable is accepted; it
        is materialised into a list once.
    - dataset_name: The name of the dataset to associate with the test runs.
    - test_project_name: Name of the test project to create. Defaults to
        ``prod-baseline-<6 hex characters>``.
    - client: An optional LangSmith client instance. If not provided, the
        cached client is used.
    - load_child_runs: Whether to load child runs when copying runs.
    - include_outputs: Whether to store the runs' outputs as the example
        outputs of the dataset.

    Returns:
    - The project containing the cloned runs.

    Raises:
    - ValueError: If ``runs`` is empty.

    Example:
    --------
    ```python
    import langsmith
    import random

    client = langsmith.Client()

    # Randomly sample 100 runs from a prod project
    runs = list(client.list_runs(project_name="My Project", execution_order=1))
    sampled_runs = random.sample(runs, min(len(runs), 100))

    convert_runs_to_test(sampled_runs, dataset_name="Random Runs")

    # Select runs named "extractor" whose root traces received good feedback
    runs = client.list_runs(
        project_name="<your_project>",
        filter='eq(name, "extractor")',
        trace_filter='and(eq(feedback_key, "user_score"), eq(feedback_score, 1))',
    )
    convert_runs_to_test(runs, dataset_name="Extraction Good")
    ```
    """
    # ``runs`` is iterated several times below. Materialise it so one-shot
    # iterables (e.g. the generator from ``client.list_runs``) are not
    # exhausted after the first pass and are caught by the emptiness check.
    runs = list(runs)
    if not runs:
        raise ValueError(f"""Expected a non-empty sequence of runs. Received: {runs}""")
    client = client or rt.get_cached_client()
    ds = client.create_dataset(dataset_name=dataset_name)
    outputs = [r.outputs for r in runs] if include_outputs else None
    client.create_examples(
        inputs=[r.inputs for r in runs],
        outputs=outputs,
        source_run_ids=[r.id for r in runs],
        dataset_id=ds.id,
    )

    if not load_child_runs:
        runs_to_copy = runs
    else:
        runs_to_copy = [
            client.read_run(r.id, load_child_runs=load_child_runs) for r in runs
        ]

    test_project_name = test_project_name or f"prod-baseline-{uuid.uuid4().hex[:6]}"

    # Map each source run to the example created for it above
    examples = list(client.list_examples(dataset_name=dataset_name))
    run_to_example_map = {e.source_run_id: e.id for e in examples}
    dataset_version = (
        examples[0].modified_at if examples[0].modified_at else examples[0].created_at
    )

    to_create = [
        run_dict
        for root_run in runs_to_copy
        for run_dict in _convert_root_run(root_run, run_to_example_map)
    ]

    project = client.create_project(
        project_name=test_project_name,
        reference_dataset_id=ds.id,
        metadata={
            "which": "prod-baseline",
            "dataset_version": dataset_version.isoformat(),
        },
    )

    for new_run in to_create:
        # Keep each run's original latency but shift it to start now
        latency = new_run["end_time"] - new_run["start_time"]
        new_run["start_time"] = datetime.datetime.now(tz=datetime.timezone.utc)
        new_run["end_time"] = new_run["start_time"] + latency
        client.create_run(**new_run, project_name=test_project_name)

    # NOTE(review): no fields are passed to this update - confirm intent.
    _ = client.update_project(
        project.id,
    )
    return project
def _load_nested_traces(project_name: str, client: Client) -> list[ls_schemas.Run]:
    """Load a project's runs and rebuild the parent/child tree.

    Returns the runs without a parent; each run that has children gets them
    attached as ``child_runs``, sorted by ``dotted_order``.
    """
    runs = client.list_runs(project_name=project_name)
    children_by_parent: collections.defaultdict[uuid.UUID, list[ls_schemas.Run]] = (
        collections.defaultdict(list)
    )
    roots = []
    runs_by_id = {}
    for run in runs:
        parent_id = run.parent_run_id
        if parent_id is None:
            roots.append(run)
        else:
            children_by_parent[parent_id].append(run)
        runs_by_id[run.id] = run

    # Attach the children only after every run has been indexed
    for parent_id, children in children_by_parent.items():
        runs_by_id[parent_id].child_runs = sorted(
            children, key=lambda r: r.dotted_order
        )
    return roots
T = TypeVar("T")
U = TypeVar("U")
def _outer_product(list1: list[T], list2: list[U]) -> list[tuple[T, U]]:
return list(itertools.product(list1, list2))
@warn_beta
def compute_test_metrics(
    project_name: str,
    *,
    evaluators: list,
    max_concurrency: Optional[int] = 10,
    client: Optional[Client] = None,
) -> None:
    """Run every evaluator against every root trace of a test project.

    Args:
        project_name (str): The name of the test project to evaluate.
        evaluators (list): Evaluators to compute metrics with. Each item is
            either a ``RunEvaluator`` or a callable, which is wrapped with
            ``run_evaluator``.
        max_concurrency (Optional[int], optional): The maximum number of
            concurrent evaluations. Defaults to 10.
        client (Optional[Client], optional): The client to use for
            evaluations. Defaults to the cached client.

    Returns:
        None: This function does not return any value.

    Raises:
        NotImplementedError: If an evaluator is neither a ``RunEvaluator``
            nor callable.
    """
    from langsmith import ContextThreadPoolExecutor

    evaluators_: list[ls_eval.RunEvaluator] = []
    for func in evaluators:
        if isinstance(func, ls_eval.RunEvaluator):
            evaluators_.append(func)
            continue
        if not callable(func):
            raise NotImplementedError(
                f"Evaluation not yet implemented for evaluator of type {type(func)}"
            )
        evaluators_.append(ls_eval.run_evaluator(func))

    client = client or rt.get_cached_client()
    traces = _load_nested_traces(project_name, client)
    with ContextThreadPoolExecutor(max_workers=max_concurrency) as executor:
        # One task per (trace, evaluator) pair; zip(*pairs) splits the pairs
        # into the two parallel argument sequences map() expects.
        pairs = _outer_product(traces, evaluators_)
        results = executor.map(client.evaluate_run, *zip(*pairs))
        # Drain the iterator so every evaluation completes
        for _ in results:
            pass

View File

@@ -0,0 +1,3 @@
# DOCKER-COMPOSE MOVED
All documentation for `docker-compose` has been moved to the [helm repository](https://github.com/langchain-ai/helm/tree/main/charts/langsmith).

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,30 @@
"""Utilities to get information about the runtime environment."""
from langsmith.env._git import get_git_info
from langsmith.env._runtime_env import (
get_docker_compose_command,
get_docker_compose_version,
get_docker_environment,
get_docker_version,
get_langchain_env_var_metadata,
get_langchain_env_vars,
get_langchain_environment,
get_release_shas,
get_runtime_and_metrics,
get_runtime_environment,
get_system_metrics,
)
# Public API of this package: every name listed here is imported at the top of
# this module, and this list is what a star-import of the package exposes.
__all__ = [
    "get_docker_compose_command",
    "get_docker_compose_version",
    "get_docker_environment",
    "get_docker_version",
    "get_langchain_env_var_metadata",
    "get_langchain_env_vars",
    "get_langchain_environment",
    "get_release_shas",
    "get_runtime_and_metrics",
    "get_runtime_environment",
    "get_system_metrics",
    "get_git_info",
]

View File

@@ -0,0 +1,64 @@
"""Fetch information about any current git repo."""
import functools
import logging
import subprocess
from typing import List, Optional, TypeVar
from typing_extensions import TypedDict
logger = logging.getLogger(__name__)
T = TypeVar("T")
def exec_git(command: List[str]) -> Optional[str]:
    """Run ``git <command>`` and return its stripped stdout.

    This is a best-effort helper: stderr is discarded and any ordinary failure
    (git not installed, non-zero exit status, undecodable output, ...) results
    in ``None`` rather than an exception.

    Args:
        command: Arguments to pass after the ``git`` executable.

    Returns:
        The command's standard output decoded as UTF-8 with surrounding
        whitespace removed, or ``None`` if the command could not be run.
    """
    try:
        return subprocess.check_output(
            ["git"] + command, encoding="utf-8", stderr=subprocess.DEVNULL
        ).strip()
    except Exception:
        # Deliberately not ``BaseException``: KeyboardInterrupt and SystemExit
        # must still propagate to the caller.
        return None
class GitInfo(TypedDict, total=False):
    """Git metadata as produced by ``get_git_info``.

    Declared with ``total=False``, so every key is optional; values are
    ``None`` when ``get_git_info`` is called outside a git work tree.
    """

    # Last path component of `git rev-parse --show-toplevel`.
    repo_name: Optional[str]
    # Output of `git remote get-url <remote>`.
    remote_url: Optional[str]
    # Output of `git rev-parse HEAD`.
    commit: Optional[str]
    # Output of `git rev-parse --abbrev-ref HEAD`.
    branch: Optional[str]
    # Output of `git log -1 --format=%an`.
    author_name: Optional[str]
    # Output of `git log -1 --format=%ae`.
    author_email: Optional[str]
    # Output of `git log -1 --format=%ct`, kept as a string.
    commit_time: Optional[str]
    # Derived from whether `git status --porcelain` prints anything.
    dirty: Optional[bool]
    # Output of `git describe --tags --exact-match --always --dirty`.
    tags: Optional[str]
@functools.lru_cache(maxsize=1)
def get_git_info(remote: str = "origin") -> GitInfo:
    """Get information about the git repository.

    The result is cached (one entry), so git is only queried for the most
    recently used ``remote``.

    Args:
        remote: Name of the git remote whose URL is reported.

    Returns:
        A ``GitInfo`` mapping. All values are ``None`` when the current
        directory is not inside a git work tree; individual values are
        ``None`` when the corresponding git command fails.
    """
    if not exec_git(["rev-parse", "--is-inside-work-tree"]):
        return GitInfo(
            remote_url=None,
            commit=None,
            branch=None,
            author_name=None,
            author_email=None,
            commit_time=None,
            dirty=None,
            tags=None,
            repo_name=None,
        )
    # ``exec_git`` returns None on failure; comparing that directly with ""
    # would report a dirty tree, so keep "unknown" distinct from "dirty".
    status = exec_git(["status", "--porcelain"])
    return {
        "remote_url": exec_git(["remote", "get-url", remote]),
        "commit": exec_git(["rev-parse", "HEAD"]),
        "commit_time": exec_git(["log", "-1", "--format=%ct"]),
        "branch": exec_git(["rev-parse", "--abbrev-ref", "HEAD"]),
        "tags": exec_git(
            ["describe", "--tags", "--exact-match", "--always", "--dirty"]
        ),
        "dirty": None if status is None else status != "",
        "author_name": exec_git(["log", "-1", "--format=%an"]),
        "author_email": exec_git(["log", "-1", "--format=%ae"]),
        "repo_name": (exec_git(["rev-parse", "--show-toplevel"]) or "").split("/")[-1],
    }

View File

@@ -0,0 +1,236 @@
"""Environment information."""
import functools
import logging
import os
import platform
import subprocess
from typing import Dict, List, Optional, Union
from langsmith.utils import get_docker_compose_command
from langsmith.env._git import exec_git
try:
# psutil is an optional dependency
import psutil
_PSUTIL_AVAILABLE = True
except ImportError:
_PSUTIL_AVAILABLE = False
logger = logging.getLogger(__name__)
def get_runtime_and_metrics() -> dict:
    """Return the runtime environment merged with the current system metrics.

    Metric keys take precedence if a key appears in both mappings.
    """
    combined = dict(get_runtime_environment())
    combined.update(get_system_metrics())
    return combined
def get_system_metrics() -> Dict[str, Union[float, dict]]:
    """Collect thread, memory and CPU metrics for the current process.

    Returns an empty dict when psutil is unavailable. If psutil raises, the
    module-level availability flag is cleared so later calls return
    immediately.
    """
    global _PSUTIL_AVAILABLE
    if not _PSUTIL_AVAILABLE:
        return {}
    try:
        proc = psutil.Process(os.getpid())
        # Read everything inside one ``oneshot`` context, then assemble.
        with proc.oneshot():
            memory = proc.memory_info()
            thread_count = float(proc.num_threads())
            switches = proc.num_ctx_switches()
            times = proc.cpu_times()
            cpu_section = {
                "time": {
                    "sys": times.system,
                    "user": times.user,
                },
                "ctx_switches": {
                    "voluntary": float(switches.voluntary),
                    "involuntary": float(switches.involuntary),
                },
                "percent": proc.cpu_percent(),
            }
        collected: Dict[str, Union[float, dict]] = {
            "thread_count": thread_count,
            "mem": {"rss": float(memory.rss)},
            "cpu": cpu_section,
        }
        return collected
    except Exception as e:
        # If psutil is installed but not compatible with the build,
        # stop trying to use it from now on.
        _PSUTIL_AVAILABLE = False
        logger.debug("Failed to get system metrics: %s", e)
        return {}
@functools.lru_cache(maxsize=1)
def get_runtime_environment() -> dict:
    """Describe the SDK, interpreter and platform this process runs on.

    The result also includes any release SHAs found in the environment and is
    cached after the first call.
    """
    # Imported lazily to avoid a circular import with the package root.
    from langsmith import __version__

    release_shas = get_release_shas()
    environment = {
        "sdk": "langsmith-py",
        "sdk_version": __version__,
        "library": "langsmith",
        "platform": platform.platform(),
        "runtime": "python",
        "py_implementation": platform.python_implementation(),
        "runtime_version": platform.python_version(),
        "langchain_version": get_langchain_environment(),
        "langchain_core_version": get_langchain_core_version(),
    }
    # Release SHAs go last so they override a same-named key, as before.
    environment.update(release_shas)
    return environment
@functools.lru_cache(maxsize=1)
def get_langchain_environment() -> Optional[str]:
    """Return the installed ``langchain`` version, or ``None`` if unavailable.

    Any ordinary failure while importing the package or reading its version
    is treated as "not available". The result is cached.
    """
    try:
        import langchain  # type: ignore

        return langchain.__version__
    except Exception:
        # Not a bare ``except``: KeyboardInterrupt/SystemExit must propagate.
        return None
@functools.lru_cache(maxsize=1)
def get_langchain_core_version() -> Optional[str]:
    """Return the installed ``langchain_core`` version, or ``None`` if unavailable.

    Like ``get_langchain_environment``, this is best-effort: a missing package,
    a missing ``__version__`` attribute or an import-time failure all yield
    ``None`` instead of breaking environment collection. The result is cached.
    """
    try:
        import langchain_core  # type: ignore

        return langchain_core.__version__
    except Exception:
        return None
@functools.lru_cache(maxsize=1)
def get_docker_version() -> Optional[str]:
    """Return the output of ``docker --version``.

    Returns:
        The stripped version string, ``"unknown"`` when the ``docker``
        executable cannot be found, or ``None`` when the command fails for any
        other reason. The result is cached.
    """
    try:
        return (
            subprocess.check_output(["docker", "--version"]).decode("utf-8").strip()
        )
    except FileNotFoundError:
        return "unknown"
    except Exception:
        # Not a bare ``except``: KeyboardInterrupt/SystemExit must propagate.
        return None
@functools.lru_cache(maxsize=1)
def get_docker_compose_version() -> Optional[str]:
    """Return the output of ``docker-compose --version``.

    Returns:
        The stripped version string, ``"unknown"`` when the ``docker-compose``
        executable cannot be found, or ``None`` when the command fails for any
        other reason. The result is cached.
    """
    try:
        return (
            subprocess.check_output(["docker-compose", "--version"])
            .decode("utf-8")
            .strip()
        )
    except FileNotFoundError:
        return "unknown"
    except Exception:
        # Not a bare ``except``: KeyboardInterrupt/SystemExit must propagate.
        return None
@functools.lru_cache(maxsize=1)
def _get_compose_command() -> Optional[List[str]]:
    """Resolve the docker compose command, tolerating lookup failures.

    Returns:
        The command as returned by ``get_docker_compose_command``; a
        single-element list ``["NOT INSTALLED: <reason>"]`` when that helper
        raises ``ValueError``; or ``None`` on any other ordinary error. The
        result is cached.
    """
    try:
        return get_docker_compose_command()
    except ValueError as e:
        return [f"NOT INSTALLED: {e}"]
    except Exception:
        # Not a bare ``except``: KeyboardInterrupt/SystemExit must propagate.
        return None
@functools.lru_cache(maxsize=1)
def get_docker_environment() -> dict:
    """Summarize the local docker tooling.

    Returns a dict with the docker version, the compose command joined into a
    single space-separated string (``None`` if it could not be resolved) and
    the docker-compose version. The result is cached.
    """
    command_parts = _get_compose_command()
    docker_version = get_docker_version()
    if command_parts is None:
        command_text = None
    else:
        command_text = " ".join(command_parts)
    return {
        "docker_version": docker_version,
        "docker_compose_command": command_text,
        "docker_compose_version": get_docker_compose_version(),
    }
def get_langchain_env_vars() -> dict:
    """Retrieve the langchain environment variables.

    Only variables whose name starts with ``LANGCHAIN_`` are returned. Values
    of variables whose name contains "key" (case-insensitive) are masked: the
    first and last two characters are kept and everything in between is
    replaced with ``*``. Values of four characters or fewer are masked
    completely, because keeping two characters at each end would reveal the
    whole value.

    Returns:
        A new dict mapping variable names to (possibly masked) values.
    """
    env_vars = {}
    for name, value in os.environ.items():
        if not name.startswith("LANGCHAIN_"):
            continue
        if "key" in name.lower():
            if len(value) > 4:
                value = value[:2] + "*" * (len(value) - 4) + value[-2:]
            else:
                value = "*" * len(value)
        env_vars[name] = value
    return env_vars
@functools.lru_cache(maxsize=1)
def get_langchain_env_var_metadata() -> dict:
    """Collect non-sensitive ``LANGCHAIN_*`` / ``LANGSMITH_*`` variables as metadata.

    Variables in the exclusion set, and any variable whose name contains
    "key", "secret" or "token" (case-insensitive), are left out. A
    ``revision_id`` entry is added from ``LANGCHAIN_REVISION_ID`` when set and
    non-empty, otherwise from the default git-based revision when one is
    available. The result is cached.
    """
    excluded = {
        "LANGCHAIN_API_KEY",
        "LANGCHAIN_ENDPOINT",
        "LANGCHAIN_TRACING_V2",
        "LANGCHAIN_PROJECT",
        "LANGCHAIN_SESSION",
        "LANGSMITH_RUNS_ENDPOINTS",
    }
    sensitive_fragments = ("key", "secret", "token")
    langchain_metadata = {}
    for name, value in os.environ.items():
        if not name.startswith(("LANGCHAIN_", "LANGSMITH_")):
            continue
        if name in excluded:
            continue
        lowered = name.lower()
        if any(fragment in lowered for fragment in sensitive_fragments):
            continue
        langchain_metadata[name] = value
    # The raw variable is always removed; it is re-added under "revision_id".
    env_revision_id = langchain_metadata.pop("LANGCHAIN_REVISION_ID", None)
    if env_revision_id:
        langchain_metadata["revision_id"] = env_revision_id
    elif default_revision_id := _get_default_revision_id():
        langchain_metadata["revision_id"] = default_revision_id
    return langchain_metadata
@functools.lru_cache(maxsize=1)
def _get_default_revision_id() -> Optional[str]:
    """Get the default revision ID based on `git describe`.

    Returns:
        The output of ``git describe --tags --always --dirty`` as returned by
        ``exec_git``, or ``None`` if the call raises. The result is cached.
    """
    try:
        return exec_git(["describe", "--tags", "--always", "--dirty"])
    except Exception:
        # Not ``BaseException``: KeyboardInterrupt/SystemExit must propagate.
        return None
@functools.lru_cache(maxsize=1)
def get_release_shas() -> Dict[str, str]:
    """Return the commit-SHA environment variables that are currently set.

    A fixed list of variable names is checked in order; each one that is
    present in ``os.environ`` is included under its own name. The result is
    cached.
    """
    release_env_names = (
        "VERCEL_GIT_COMMIT_SHA",
        "NEXT_PUBLIC_VERCEL_GIT_COMMIT_SHA",
        "COMMIT_REF",
        "RENDER_GIT_COMMIT",
        "CI_COMMIT_SHA",
        "CIRCLE_SHA1",
        "CF_PAGES_COMMIT_SHA",
        "REACT_APP_GIT_SHA",
        "SOURCE_VERSION",
        "GITHUB_SHA",
        "TRAVIS_COMMIT",
        "GIT_COMMIT",
        "BUILD_VCS_NUMBER",
        "bamboo_planRepository_revision",
        "Build.SourceVersion",
        "BITBUCKET_COMMIT",
        "DRONE_COMMIT_SHA",
        "SEMAPHORE_GIT_SHA",
        "BUILDKITE_COMMIT",
    )
    return {
        name: value
        for name in release_env_names
        if (value := os.environ.get(name)) is not None
    }

View File

@@ -0,0 +1,89 @@
"""Evaluation Helpers."""
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from langsmith.evaluation._arunner import (
aevaluate,
aevaluate_existing,
)
from langsmith.evaluation._runner import (
evaluate,
evaluate_comparative,
evaluate_existing,
)
from langsmith.evaluation.evaluator import (
EvaluationResult,
EvaluationResults,
RunEvaluator,
run_evaluator,
)
def __getattr__(
    name: str,
) -> Any:
    """Resolve the package's public names lazily on first attribute access.

    .. deprecated:: 0.5.0.

    Importing from langsmith.evaluation is deprecated. Use client.evaluate() instead.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported names.
    """
    # Each exported name mapped to the submodule that defines it.
    lazy_exports = {
        "evaluate": "langsmith.evaluation._runner",
        "evaluate_existing": "langsmith.evaluation._runner",
        "aevaluate": "langsmith.evaluation._arunner",
        "aevaluate_existing": "langsmith.evaluation._arunner",
        "evaluate_comparative": "langsmith.evaluation._runner",
        "EvaluationResult": "langsmith.evaluation.evaluator",
        "EvaluationResults": "langsmith.evaluation.evaluator",
        "RunEvaluator": "langsmith.evaluation.evaluator",
        "run_evaluator": "langsmith.evaluation.evaluator",
        "StringEvaluator": "langsmith.evaluation.string_evaluator",
    }
    module_path = lazy_exports.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    # Same effect as ``from <module_path> import <name>``: a non-empty
    # ``fromlist`` makes ``__import__`` return the leaf submodule.
    module = __import__(module_path, fromlist=[name])
    return getattr(module, name)
# Public names of this module. Each one is resolved on first access by the
# module-level ``__getattr__`` defined above rather than imported eagerly.
__all__ = [
    "run_evaluator",
    "EvaluationResult",
    "EvaluationResults",
    "RunEvaluator",
    "StringEvaluator",
    "aevaluate",
    "aevaluate_existing",
    "evaluate",
    "evaluate_existing",
    "evaluate_comparative",
]


def __dir__() -> list[str]:
    """Return ``__all__`` so ``dir()`` on this module lists the lazy exports."""
    # Note: this returns the ``__all__`` list object itself, not a copy.
    return __all__

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,727 @@
import random
# Pool of adjectives; ``random_name`` picks one as the first part of a name.
adjectives = [
    "abandoned",
    "aching",
    "advanced",
    "ample",
    "artistic",
    "back",
    "best",
    "bold",
    "brief",
    "clear",
    "cold",
    "complicated",
    "cooked",
    "crazy",
    "crushing",
    "damp",
    "dear",
    "definite",
    "dependable",
    "diligent",
    "drab",
    "earnest",
    "elderly",
    "enchanted",
    "essential",
    "excellent",
    "extraneous",
    "fixed",
    "flowery",
    "formal",
    "fresh",
    "frosty",
    "giving",
    "glossy",
    "healthy",
    "helpful",
    "impressionable",
    "kind",
    "large",
    "left",
    "long",
    "loyal",
    "mealy",
    "memorable",
    "monthly",
    "new",
    "notable",
    "only",
    "ordinary",
    "passionate",
    "perfect",
    "pertinent",
    "proper",
    "puzzled",
    "reflecting",
    "respectful",
    "roasted",
    "scholarly",
    "shiny",
    "slight",
    "sparkling",
    "spotless",
    "stupendous",
    "sunny",
    "tart",
    "terrific",
    "timely",
    "unique",
    "upbeat",
    "vacant",
    "virtual",
    "warm",
    "weary",
    "whispered",
    "worthwhile",
    "yellow",
]
# Pool of nouns; ``random_name`` picks one as the second part of a name.
nouns = [
    "account",
    "acknowledgment",
    "address",
    "advertising",
    "airplane",
    "animal",
    "appointment",
    "arrival",
    "artist",
    "attachment",
    "attitude",
    "availability",
    "backpack",
    "bag",
    "balance",
    "bass",
    "bean",
    "beauty",
    "bibliography",
    "bill",
    "bite",
    "blossom",
    "boat",
    "book",
    "box",
    "boy",
    "bread",
    "bridge",
    "broccoli",
    "building",
    "butter",
    "button",
    "cabbage",
    "cake",
    "camera",
    "camp",
    "candle",
    "candy",
    "canvas",
    "car",
    "card",
    "carrot",
    "cart",
    "case",
    "cat",
    "chain",
    "chair",
    "chalk",
    "chance",
    "change",
    "channel",
    "character",
    "charge",
    "charm",
    "chart",
    "check",
    "cheek",
    "cheese",
    "chef",
    "cherry",
    "chicken",
    "child",
    "church",
    "circle",
    "class",
    "clay",
    "click",
    "clock",
    "cloth",
    "cloud",
    "clove",
    "club",
    "coach",
    "coal",
    "coast",
    "coat",
    "cod",
    "coffee",
    "collar",
    "color",
    "comb",
    "comfort",
    "comic",
    "committee",
    "community",
    "company",
    "comparison",
    "competition",
    "condition",
    "connection",
    "control",
    "cook",
    "copper",
    "copy",
    "corn",
    "cough",
    "country",
    "cover",
    "crate",
    "crayon",
    "cream",
    "creator",
    "crew",
    "crown",
    "current",
    "curtain",
    "curve",
    "cushion",
    "dad",
    "daughter",
    "day",
    "death",
    "debt",
    "decision",
    "deer",
    "degree",
    "design",
    "desire",
    "desk",
    "detail",
    "development",
    "digestion",
    "dime",
    "dinner",
    "direction",
    "dirt",
    "discovery",
    "discussion",
    "disease",
    "disgust",
    "distance",
    "distribution",
    "division",
    "doctor",
    "dog",
    "door",
    "drain",
    "drawer",
    "dress",
    "drink",
    "driving",
    "dust",
    "ear",
    "earth",
    "edge",
    "education",
    "effect",
    "egg",
    "end",
    "energy",
    "engine",
    "error",
    "event",
    "example",
    "exchange",
    "existence",
    "expansion",
    "experience",
    "expert",
    "eye",
    "face",
    "fact",
    "fall",
    "family",
    "farm",
    "father",
    "fear",
    "feeling",
    "field",
    "finger",
    "fire",
    "fish",
    "flag",
    "flight",
    "floor",
    "flower",
    "fold",
    "food",
    "football",
    "force",
    "form",
    "frame",
    "friend",
    "frog",
    "fruit",
    "fuel",
    "furniture",
    "game",
    "garden",
    "gate",
    "girl",
    "glass",
    "glove",
    "goat",
    "gold",
    "government",
    "grade",
    "grain",
    "grass",
    "green",
    "grip",
    "group",
    "growth",
    "guide",
    "guitar",
    "hair",
    "hall",
    "hand",
    "harbor",
    "harmony",
    "hat",
    "head",
    "health",
    "heart",
    "heat",
    "hill",
    "history",
    "hobbies",
    "hole",
    "hope",
    "horn",
    "horse",
    "hospital",
    "hour",
    "house",
    "humor",
    "idea",
    "impulse",
    "income",
    "increase",
    "industry",
    "ink",
    "insect",
    "instrument",
    "insurance",
    "interest",
    "invention",
    "iron",
    "island",
    "jelly",
    "jet",
    "jewel",
    "join",
    "judge",
    "juice",
    "jump",
    "kettle",
    "key",
    "kick",
    "kiss",
    "kitten",
    "knee",
    "knife",
    "knowledge",
    "land",
    "language",
    "laugh",
    "law",
    "lead",
    "learning",
    "leather",
    "leg",
    "lettuce",
    "level",
    "library",
    "lift",
    "light",
    "limit",
    "line",
    "linen",
    "lip",
    "liquid",
    "list",
    "look",
    "loss",
    "love",
    "lunch",
    "machine",
    "man",
    "manager",
    "map",
    "marble",
    "mark",
    "market",
    "mass",
    "match",
    "meal",
    "measure",
    "meat",
    "meeting",
    "memory",
    "metal",
    "middle",
    "milk",
    "mind",
    "mine",
    "minute",
    "mist",
    "mitten",
    "mom",
    "money",
    "monkey",
    "month",
    "moon",
    "morning",
    "mother",
    "motion",
    "mountain",
    "mouth",
    "muscle",
    "music",
    "nail",
    "name",
    "nation",
    "neck",
    "need",
    "news",
    "night",
    "noise",
    "note",
    "number",
    "nut",
    "observation",
    "offer",
    "oil",
    "operation",
    "opinion",
    "orange",
    "order",
    "organization",
    "ornament",
    "oven",
    "page",
    "pail",
    "pain",
    "paint",
    "pan",
    "pancake",
    "paper",
    "parcel",
    "parent",
    "part",
    "passenger",
    "paste",
    "payment",
    "peace",
    "pear",
    "pen",
    "pencil",
    "person",
    "pest",
    "pet",
    "picture",
    "pie",
    "pin",
    "pipe",
    "pizza",
    "place",
    "plane",
    "plant",
    "plastic",
    "plate",
    "play",
    "pleasure",
    "plot",
    "plough",
    "pocket",
    "point",
    "poison",
    "police",
    "pollution",
    "popcorn",
    "porter",
    "position",
    "pot",
    "potato",
    "powder",
    "power",
    "price",
    "print",
    "process",
    "produce",
    "product",
    "profit",
    "property",
    "prose",
    "protest",
    "pull",
    "pump",
    "punishment",
    "purpose",
    "push",
    "quarter",
    "question",
    "quiet",
    "quill",
    "quilt",
    "quince",
    "rabbit",
    "rail",
    "rain",
    "range",
    "rat",
    "rate",
    "ray",
    "reaction",
    "reading",
    "reason",
    "record",
    "regret",
    "relation",
    "religion",
    "representative",
    "request",
    "respect",
    "rest",
    "reward",
    "rhythm",
    "rice",
    "river",
    "road",
    "roll",
    "room",
    "root",
    "rose",
    "route",
    "rub",
    "rule",
    "run",
    "sack",
    "sail",
    "salt",
    "sand",
    "scale",
    "scarecrow",
    "scarf",
    "scene",
    "scent",
    "school",
    "science",
    "scissors",
    "screw",
    "sea",
    "seat",
    "secretary",
    "seed",
    "selection",
    "self",
    "sense",
    "servant",
    "shade",
    "shake",
    "shame",
    "shape",
    "sheep",
    "sheet",
    "shelf",
    "ship",
    "shirt",
    "shock",
    "shoe",
    "shop",
    "show",
    "side",
    "sign",
    "silk",
    "sink",
    "sister",
    "size",
    "sky",
    "sleep",
    "smash",
    "smell",
    "smile",
    "smoke",
    "snail",
    "snake",
    "sneeze",
    "snow",
    "soap",
    "society",
    "sock",
    "soda",
    "sofa",
    "son",
    "song",
    "sort",
    "sound",
    "soup",
    "space",
    "spark",
    "speed",
    "sponge",
    "spoon",
    "spray",
    "spring",
    "spy",
    "square",
    "stamp",
    "star",
    "start",
    "statement",
    "station",
    "steam",
    "steel",
    "stem",
    "step",
    "stew",
    "stick",
    "stitch",
    "stocking",
    "stomach",
    "stone",
    "stop",
    "store",
    "story",
    "stove",
    "stranger",
    "straw",
    "stream",
    "street",
    "stretch",
    "string",
    "structure",
    "substance",
    "sugar",
    "suggestion",
    "suit",
    "summer",
    "sun",
    "support",
    "surprise",
    "sweater",
    "swim",
    "system",
    "table",
    "tail",
    "talk",
    "tank",
    "taste",
    "tax",
    "tea",
    "teaching",
    "team",
    "tendency",
    "test",
    "texture",
    "theory",
    "thing",
    "thought",
    "thread",
    "throat",
    "thumb",
    "thunder",
    "ticket",
    "time",
    "tin",
    "title",
    "toad",
    "toe",
    "tooth",
    "toothpaste",
    "touch",
    "town",
    "toy",
    "trade",
    "train",
    "transport",
    "tray",
    "treatment",
    "tree",
    "trick",
    "trip",
    "trouble",
    "trousers",
    "truck",
    "tub",
    "turkey",
    "turn",
    "twist",
    "umbrella",
    "uncle",
    "underwear",
    "unit",
    "use",
    "vacation",
    "value",
    "van",
    "vase",
    "vegetable",
    "veil",
    "vein",
    "verse",
    "vessel",
    "view",
    "visitor",
    "voice",
    "volcano",
    "walk",
    "wall",
    "war",
    "wash",
    "waste",
    "watch",
    "water",
    "wave",
    "wax",
    "way",
    "wealth",
    "weather",
    "week",
    "weight",
    "wheel",
    "whip",
    "whistle",
    "window",
    "wine",
    "wing",
    "winter",
    "wire",
    "wish",
    "woman",
    "wood",
    "wool",
    "word",
    "work",
    "worm",
    "wound",
    "wrist",
    "writer",
    "yard",
    "yoke",
    "zebra",
    "zinc",
    "zipper",
    "zone",
]
def random_name() -> str:
    """Build a random ``adjective-noun-number`` name.

    The adjective and noun are drawn from the module's word lists and the
    number is an integer from 1 to 100 inclusive.
    """
    parts = (
        random.choice(adjectives),
        random.choice(nouns),
        str(random.randint(1, 100)),
    )
    return "-".join(parts)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,302 @@
"""Contains the `LLMEvaluator` class for building LLM-as-a-judge evaluators."""
from typing import Any, Callable, Optional, Union, cast
from pydantic import BaseModel
from langsmith._internal._beta_decorator import warn_beta
from langsmith.evaluation import EvaluationResult, EvaluationResults, RunEvaluator
from langsmith.schemas import Example, Run
class CategoricalScoreConfig(BaseModel):
    """Configuration for a categorical score."""

    # Used as the feedback key and as the title of the generated JSON schema.
    key: str
    # Allowed score values; they become the schema's ``enum`` for "score".
    choices: list[str]
    # Used as the description of the generated JSON schema.
    description: str
    # When True, the schema also requires a string "explanation" field.
    include_explanation: bool = False
    # Description of the "explanation" field; a generic one is used when None.
    explanation_description: Optional[str] = None
class ContinuousScoreConfig(BaseModel):
    """Configuration for a continuous score."""

    # Used as the feedback key and as the title of the generated JSON schema.
    key: str
    # Lower bound for the score; becomes the schema's ``minimum``.
    min: float = 0
    # Upper bound for the score; becomes the schema's ``maximum``.
    max: float = 1
    # Used as the description of the generated JSON schema.
    description: str
    # When True, the schema also requires a string "explanation" field.
    include_explanation: bool = False
    # Description of the "explanation" field; a generic one is used when None.
    explanation_description: Optional[str] = None
def _create_score_json_schema(
    score_config: Union[CategoricalScoreConfig, ContinuousScoreConfig],
) -> dict:
    """Build the JSON schema the judge model must fill in for ``score_config``.

    The schema always has a required "score" property (an enum of strings for
    categorical configs, a bounded number for continuous ones) and, when the
    config asks for it, a required string "explanation" property.

    Raises:
        ValueError: If ``score_config`` is neither supported config type.
    """
    if isinstance(score_config, CategoricalScoreConfig):
        choices_text = ", ".join(score_config.choices)
        score_property: dict[str, Any] = {
            "type": "string",
            "enum": score_config.choices,
            "description": f"The score for the evaluation, one of {choices_text}.",
        }
    elif isinstance(score_config, ContinuousScoreConfig):
        low, high = score_config.min, score_config.max
        score_property = {
            "type": "number",
            "minimum": low,
            "maximum": high,
            "description": (
                f"The score for the evaluation, between {low} and {high}, inclusive."
            ),
        }
    else:
        raise ValueError("Invalid score type. Must be 'categorical' or 'continuous'")

    properties: dict[str, Any] = {"score": score_property}
    required = ["score"]
    if score_config.include_explanation:
        explanation_text = score_config.explanation_description
        if explanation_text is None:
            explanation_text = "The explanation for the score."
        properties["explanation"] = {
            "type": "string",
            "description": explanation_text,
        }
        required.append("explanation")

    return {
        "title": score_config.key,
        "description": score_config.description,
        "type": "object",
        "properties": properties,
        "required": required,
    }
class LLMEvaluator(RunEvaluator):
    """A class for building LLM-as-a-judge evaluators.

    .. deprecated:: 0.5.0

       LLMEvaluator is deprecated. Use openevals instead: https://github.com/langchain-ai/openevals
    """

    def __init__(
        self,
        *,
        prompt_template: Union[str, list[tuple[str, str]]],
        score_config: Union[CategoricalScoreConfig, ContinuousScoreConfig],
        map_variables: Optional[Callable[[Run, Optional[Example]], dict]] = None,
        model_name: str = "gpt-4o",
        model_provider: str = "openai",
        **kwargs,
    ):
        """Initialize the `LLMEvaluator`.

        Args:
            prompt_template (Union[str, List[Tuple[str, str]]): The prompt
                template to use for the evaluation. If a string is provided, it is
                assumed to be a human / user message.
            score_config (Union[CategoricalScoreConfig, ContinuousScoreConfig]):
                The configuration for the score, either categorical or continuous.
            map_variables (Optional[Callable[[Run, Example], dict]], optional):
                A function that maps the run and example to the variables in the
                prompt.

                If `None`, it is assumed that the prompt only requires 'input',
                'output', and 'expected'.
            model_name (Optional[str], optional): The model to use for the evaluation.
            model_provider (Optional[str], optional): The model provider to use
                for the evaluation.
            **kwargs: Extra keyword arguments forwarded to `init_chat_model`.

        Raises:
            ImportError: If `langchain` is not installed.
        """
        try:
            from langchain.chat_models import (  # type: ignore[import-not-found]
                init_chat_model,
            )
        except ImportError as e:
            raise ImportError(
                "LLMEvaluator requires langchain to be installed. "
                "Please install langchain by running `pip install langchain`."
            ) from e

        chat_model = init_chat_model(
            model=model_name, model_provider=model_provider, **kwargs
        )

        self._initialize(prompt_template, score_config, map_variables, chat_model)

    @classmethod
    def from_model(
        cls,
        model: Any,
        *,
        prompt_template: Union[str, list[tuple[str, str]]],
        score_config: Union[CategoricalScoreConfig, ContinuousScoreConfig],
        map_variables: Optional[Callable[[Run, Optional[Example]], dict]] = None,
    ):
        """Create an `LLMEvaluator` instance from a `BaseChatModel` instance.

        Args:
            model (BaseChatModel): The chat model instance to use for the evaluation.
            prompt_template (Union[str, List[Tuple[str, str]]): The prompt
                template to use for the evaluation. If a string is provided, it is
                assumed to be a human / user message.
            score_config (Union[CategoricalScoreConfig, ContinuousScoreConfig]):
                The configuration for the score, either categorical or continuous.
            map_variables (Optional[Callable[[Run, Example]], dict]], optional):
                A function that maps the run and example to the variables in the
                prompt.

                If `None`, it is assumed that the prompt only requires 'input',
                'output', and 'expected'.

        Returns:
            LLMEvaluator: An instance of `LLMEvaluator`.
        """
        # Bypass ``__init__`` so no model is created via ``init_chat_model``.
        instance = cls.__new__(cls)
        instance._initialize(prompt_template, score_config, map_variables, model)
        return instance

    def _initialize(
        self,
        prompt_template: Union[str, list[tuple[str, str]]],
        score_config: Union[CategoricalScoreConfig, ContinuousScoreConfig],
        map_variables: Optional[Callable[[Run, Optional[Example]], dict]],
        chat_model: Any,
    ):
        """Shared initialization code for `__init__` and `from_model`.

        Args:
            prompt_template (Union[str, List[Tuple[str, str]]): The prompt template.
            score_config (Union[CategoricalScoreConfig, ContinuousScoreConfig]):
                The score configuration.
            map_variables (Optional[Callable[[Run, Example]], dict]]):
                Function to map variables.
            chat_model (BaseChatModel): The chat model instance.

        Raises:
            ImportError: If `langchain-core` is not installed.
            ValueError: If `chat_model` is not a `BaseChatModel` supporting
                structured output, or if the prompt uses variables other than
                'input', 'output' and 'expected' without `map_variables`.
        """
        try:
            from langchain_core.language_models.chat_models import BaseChatModel
            from langchain_core.prompts import ChatPromptTemplate
        except ImportError as e:
            raise ImportError(
                "LLMEvaluator requires langchain-core to be installed. "
                "Please install langchain-core by running `pip install langchain-core`."
            ) from e

        if not (
            isinstance(chat_model, BaseChatModel)
            and hasattr(chat_model, "with_structured_output")
        ):
            raise ValueError(
                "chat_model must be an instance of "
                "BaseChatModel and support structured output."
            )

        if isinstance(prompt_template, str):
            # A bare string is treated as a single human message.
            self.prompt = ChatPromptTemplate.from_messages([("human", prompt_template)])
        else:
            self.prompt = ChatPromptTemplate.from_messages(prompt_template)

        if set(self.prompt.input_variables) - {"input", "output", "expected"}:
            if not map_variables:
                raise ValueError(
                    "map_variables must be provided if the prompt template contains "
                    "variables other than 'input', 'output', and 'expected'"
                )
        self.map_variables = map_variables

        self.score_config = score_config
        self.score_schema = _create_score_json_schema(self.score_config)

        chat_model = chat_model.with_structured_output(self.score_schema)
        self.runnable = self.prompt | chat_model

    @warn_beta
    def evaluate_run(
        self, run: Run, example: Optional[Example] = None
    ) -> Union[EvaluationResult, EvaluationResults]:
        """Evaluate a run."""
        variables = self._prepare_variables(run, example)
        output: dict = cast(dict, self.runnable.invoke(variables))
        return self._parse_output(output)

    @warn_beta
    async def aevaluate_run(
        self, run: Run, example: Optional[Example] = None
    ) -> Union[EvaluationResult, EvaluationResults]:
        """Asynchronously evaluate a run."""
        variables = self._prepare_variables(run, example)
        output: dict = cast(dict, await self.runnable.ainvoke(variables))
        return self._parse_output(output)

    def _prepare_variables(self, run: Run, example: Optional[Example]) -> dict:
        """Prepare variables for model invocation.

        Uses `map_variables` when provided. Otherwise fills 'input', 'output'
        and 'expected' (whichever the prompt uses) from the single value in
        `run.inputs`, `run.outputs` and `example.outputs` respectively.

        Raises:
            ValueError: If a required mapping is missing, empty, or has more
                than one key.
        """
        if self.map_variables:
            return self.map_variables(run, example)

        variables = {}
        if "input" in self.prompt.input_variables:
            if len(run.inputs) == 0:
                raise ValueError(
                    "No input keys are present in run.inputs but the prompt "
                    "requires 'input'."
                )
            if len(run.inputs) != 1:
                raise ValueError(
                    "Multiple input keys are present in run.inputs. Please provide "
                    "a map_variables function."
                )
            variables["input"] = list(run.inputs.values())[0]

        if "output" in self.prompt.input_variables:
            # ``not run.outputs`` covers both ``None`` and an empty mapping.
            if not run.outputs:
                raise ValueError(
                    "No output keys are present in run.outputs but the prompt "
                    "requires 'output'."
                )
            if len(run.outputs) != 1:
                raise ValueError(
                    "Multiple output keys are present in run.outputs. Please "
                    "provide a map_variables function."
                )
            variables["output"] = list(run.outputs.values())[0]

        if "expected" in self.prompt.input_variables:
            # Covers a missing example as well as ``None`` or empty outputs.
            if not example or not example.outputs:
                raise ValueError(
                    "No example or example outputs is provided but the prompt "
                    "requires 'expected'."
                )
            if len(example.outputs) != 1:
                raise ValueError(
                    "Multiple output keys are present in example.outputs. Please "
                    "provide a map_variables function."
                )
            variables["expected"] = list(example.outputs.values())[0]

        return variables

    def _parse_output(self, output: dict) -> Union[EvaluationResult, EvaluationResults]:
        """Parse the model output into an evaluation result.

        Categorical scores are reported through `value`, continuous scores
        through `score`; the optional explanation becomes the comment.
        """
        if isinstance(self.score_config, CategoricalScoreConfig):
            value = output["score"]
            explanation = output.get("explanation", None)
            return EvaluationResult(
                key=self.score_config.key, value=value, comment=explanation
            )
        elif isinstance(self.score_config, ContinuousScoreConfig):
            score = output["score"]
            explanation = output.get("explanation", None)
            return EvaluationResult(
                key=self.score_config.key, score=score, comment=explanation
            )

View File

@@ -0,0 +1,47 @@
"""This module contains the StringEvaluator class."""
import uuid
from typing import Callable, Optional
from pydantic import BaseModel
from langsmith.evaluation.evaluator import EvaluationResult, RunEvaluator
from langsmith.schemas import Example, Run
class StringEvaluator(RunEvaluator, BaseModel):
    """Grades the run's string input, output, and optional answer.

    .. deprecated:: 0.5.0

       StringEvaluator is deprecated. Use openevals instead: https://github.com/langchain-ai/openevals
    """

    evaluation_name: Optional[str] = None
    """The name of the evaluation, such as `'Accuracy'` or `'Salience'`."""
    input_key: str = "input"
    """The key in the run inputs from which to extract the input string."""
    prediction_key: str = "output"
    """The key in the run outputs from which to extract the prediction string."""
    answer_key: Optional[str] = "output"
    """The key in the example outputs that holds the answer string."""
    grading_function: Callable[[str, str, Optional[str]], dict]
    """Function that grades the run output against the example output."""

    def evaluate_run(
        self,
        run: Run,
        example: Optional[Example] = None,
        evaluator_run_id: Optional[uuid.UUID] = None,
    ) -> EvaluationResult:
        """Grade one run with ``grading_function``.

        The reference answer is ``None`` when there is no example, the example
        has no outputs, or ``answer_key`` is ``None``. The grading function's
        dict is merged over ``{"key": evaluation_name}`` to build the result.

        Raises:
            ValueError: If ``run.outputs`` is ``None``.
        """
        if run.outputs is None:
            raise ValueError("Run outputs cannot be None.")

        has_reference = bool(example) and (
            example.outputs is not None and self.answer_key is not None
        )
        answer = example.outputs.get(self.answer_key) if has_reference else None

        graded = self.grading_function(
            run.inputs[self.input_key],
            run.outputs[self.prediction_key],
            answer,
        )
        # Keys returned by the grading function override the default "key".
        result_fields = {"key": self.evaluation_name}
        result_fields.update(graded)
        return EvaluationResult(**result_fields)

View File

@@ -0,0 +1,83 @@
"""LangSmith integration for Claude Agent SDK.
This module provides automatic tracing for the Claude Agent SDK by instrumenting
`ClaudeSDKClient` and injecting hooks to trace all tool calls.
"""
import logging
import sys
from typing import Optional
from ._client import instrument_claude_client
from ._config import set_tracing_config
logger = logging.getLogger(__name__)
__all__ = ["configure_claude_agent_sdk"]
def configure_claude_agent_sdk(
    name: Optional[str] = None,
    project_name: Optional[str] = None,
    metadata: Optional[dict] = None,
    tags: Optional[list[str]] = None,
) -> bool:
    """Enable LangSmith tracing for the Claude Agent SDK by patching entry points.

    The SDK's `ClaudeSDKClient` is replaced with an instrumented version, both
    on the `claude_agent_sdk` package and on every already-imported module
    that holds a reference to the original class.

    This function instruments the Claude Agent SDK to automatically trace:

    - Chain runs for each conversation stream (via `ClaudeSDKClient`)
    - Model runs for each assistant turn
    - All tool calls including built-in tools, external MCP tools, and SDK MCP tools

    Tool tracing is implemented via `PreToolUse` and `PostToolUse` hooks.

    Args:
        name: Name of the root trace.
        project_name: LangSmith project to trace to.
        metadata: Metadata to associate with all traces.
        tags: Tags to associate with all traces.

    Returns:
        `True` if configuration was successful, `False` otherwise.

    Example:
        >>> from langsmith.integrations.claude_agent_sdk import (
        ...     configure_claude_agent_sdk,
        ... )
        >>> configure_claude_agent_sdk(
        ...     project_name="my-project", tags=["production"]
        ... )  # doctest: +SKIP
        >>> # Now use claude_agent_sdk as normal - tracing is automatic
    """  # noqa: E501
    try:
        import claude_agent_sdk  # type: ignore[import-not-found]
    except ImportError:
        logger.warning("Claude Agent SDK not installed.")
        return False

    if not hasattr(claude_agent_sdk, "ClaudeSDKClient"):
        logger.warning("Claude Agent SDK missing ClaudeSDKClient.")
        return False

    # The tracing configuration is stored before the class is patched.
    set_tracing_config(
        name=name,
        project_name=project_name,
        metadata=metadata,
        tags=tags,
    )

    original_client = getattr(claude_agent_sdk, "ClaudeSDKClient", None)
    if not original_client:
        return False

    instrumented_client = instrument_claude_client(original_client)
    claude_agent_sdk.ClaudeSDKClient = instrumented_client

    # Also rebind the name in modules that did
    # ``from claude_agent_sdk import ClaudeSDKClient`` before this call.
    for loaded_module in list(sys.modules.values()):
        try:
            if not loaded_module:
                continue
            if getattr(loaded_module, "ClaudeSDKClient", None) is original_client:
                loaded_module.ClaudeSDKClient = instrumented_client
        except Exception:
            # Attribute access on some modules can raise; skip those modules.
            continue

    return True

View File

@@ -0,0 +1,497 @@
"""Client instrumentation for Claude Agent SDK."""
import logging
import time
from collections.abc import AsyncGenerator, AsyncIterable
from datetime import datetime, timezone
from functools import cache
from typing import Any, Optional
from langsmith.run_helpers import get_current_run_tree, trace
from ._hooks import (
clear_active_tool_runs,
post_tool_use_failure_hook,
post_tool_use_hook,
pre_tool_use_hook,
)
from ._messages import (
build_llm_input,
extract_usage_from_result_message,
flatten_content_blocks,
)
from ._tools import clear_parent_run_tree, get_parent_run_tree, set_parent_run_tree
logger = logging.getLogger(__name__)
TRACE_CHAIN_NAME = "claude.conversation"
@cache
def _get_package_version(package_name: str) -> Optional[str]:
    """Return the installed version of ``package_name``, or ``None``.

    The lookup uses ``importlib.metadata.version``; any failure (for example
    the distribution not being installed) yields ``None``. Results are cached
    per package name.
    """
    try:
        from importlib.metadata import version

        return version(package_name)
    except Exception:
        return None
LLM_RUN_NAME = "claude.assistant.turn"
class TurnLifecycle:
    """Track ongoing model runs so consecutive messages are recorded correctly.

    At most one model run is open at a time. Starting a new run, or calling
    ``close``, ends and patches the currently open one.
    """

    def __init__(self, query_start_time: Optional[float] = None):
        """Create a tracker with no open run.

        Args:
            query_start_time: Optional timestamp (seconds since the epoch) to
                use as the start time of the first model run.
        """
        self.current_run: Optional[Any] = None
        self.next_start_time: Optional[float] = query_start_time

    def start_llm_run(
        self,
        message: Any,
        prompt: Any,
        history: list[dict[str, Any]],
        parent: Optional[Any] = None,
    ) -> Optional[dict[str, Any]]:
        """Begin a new model run, ending any existing one.

        Args:
            message: The assistant message that the new run records.
            prompt: The prompt passed through to build the run inputs.
            history: Prior messages passed through to build the run inputs.
            parent: Optional parent run under which the model run is created.

        Returns:
            The final output produced for ``message``, or ``None`` when no run
            was created.
        """
        start = self.next_start_time or time.time()

        # End the previous run through ``close`` so ``current_run`` is cleared
        # first: if creating the next run raises, a later ``close`` must not
        # end and patch the same run a second time.
        self.close()

        final_output, run = begin_llm_run_from_assistant_messages(
            [message], prompt, history, start_time=start, parent=parent
        )
        self.current_run = run
        self.next_start_time = None
        return final_output

    def mark_next_start(self) -> None:
        """Mark when the next assistant message will start."""
        self.next_start_time = time.time()

    def add_usage(self, metrics: dict[str, Any]) -> None:
        """Attach token usage details to the current run.

        The metrics are merged into ``extra["metadata"]["usage_metadata"]`` of
        the open run. Does nothing if there is no open run or ``metrics`` is
        empty.
        """
        if not (self.current_run and metrics):
            return
        meta = self.current_run.extra.setdefault("metadata", {}).setdefault(
            "usage_metadata", {}
        )
        meta.update(metrics)

    def close(self) -> None:
        """End any open run gracefully."""
        if self.current_run:
            run, self.current_run = self.current_run, None
            run.end()
            run.patch()
def begin_llm_run_from_assistant_messages(
    messages: list[Any],
    prompt: Any,
    history: list[dict[str, Any]],
    start_time: Optional[float] = None,
    parent: Optional[Any] = None,
) -> tuple[Optional[dict[str, Any]], Optional[Any]]:
    """Create and post a child "llm" run for a batch of assistant messages.

    Args:
        messages: Messages to record; the last one must be an
            ``AssistantMessage`` (matched by class name).
        prompt: Prompt passed to ``build_llm_input``.
        history: Prior messages passed to ``build_llm_input``.
        start_time: Optional epoch seconds used as the run's start time.
        parent: Parent run; when ``None`` the thread-local parent or the
            current run tree is used.

    Returns:
        ``(final_content, run)``. Both are ``None`` when the last message is
        not an assistant message or no parent run can be found.
    """
    last_msg = messages[-1] if messages else None
    if last_msg is None or type(last_msg).__name__ != "AssistantMessage":
        return None, None

    model = getattr(last_msg, "model", None)

    if parent is None:
        parent = get_parent_run_tree() or get_current_run_tree()
    if not parent:
        return None, None

    inputs = build_llm_input(prompt, history)
    outputs = []
    for candidate in messages:
        if hasattr(candidate, "content"):
            outputs.append(
                {
                    "content": flatten_content_blocks(candidate.content),
                    "role": "assistant",
                }
            )

    run_inputs = {"messages": inputs} if inputs else {}
    run_extra = {"metadata": {"ls_model_name": model}} if model else {}
    started = (
        datetime.fromtimestamp(start_time, tz=timezone.utc) if start_time else None
    )
    llm_run = parent.create_child(
        name=LLM_RUN_NAME,
        run_type="llm",
        inputs=run_inputs,
        extra=run_extra,
        start_time=started,
    )

    try:
        llm_run.post()
    except Exception as e:
        logger.warning(f"Failed to post LLM run: {e}")

    # Outputs are assigned only after posting, so they travel with end_time
    # on the later patch.
    if len(outputs) == 1:
        llm_run.outputs = outputs[-1]
    else:
        llm_run.outputs = {"content": outputs}

    final_content = None
    if hasattr(last_msg, "content"):
        final_content = {
            "content": flatten_content_blocks(last_msg.content),
            "role": "assistant",
        }
    return final_content, llm_run
def _inject_tracing_hooks(options: Any) -> None:
    """Inject LangSmith tracing hooks into ClaudeAgentOptions.

    ``options`` is mutated in place: ``options.hooks`` is created if it is
    ``None``, the ``PreToolUse`` / ``PostToolUse`` / ``PostToolUseFailure``
    entries are created if missing, and a LangSmith ``HookMatcher`` is inserted
    at the front of each list.

    The injection is idempotent. An event that already has a matcher carrying
    the LangSmith callback is left alone, so reusing one options object for
    several clients does not stack duplicate hooks (which would create
    duplicate tool runs for every tool call).

    Args:
        options: Options object; ignored if it has no ``hooks`` attribute.
    """
    if not hasattr(options, "hooks"):
        return

    # Initialize hooks dict if not present
    if options.hooks is None:
        options.hooks = {}

    for event in ("PreToolUse", "PostToolUse", "PostToolUseFailure"):
        if event not in options.hooks:
            options.hooks[event] = []

    try:
        from claude_agent_sdk import HookMatcher  # type: ignore[import-not-found]

        callbacks = (
            ("PreToolUse", pre_tool_use_hook),
            ("PostToolUse", post_tool_use_hook),
            ("PostToolUseFailure", post_tool_use_failure_hook),
        )

        # Build every matcher before touching the lists so a construction
        # failure leaves the options untouched.
        pending = []
        for event, callback in callbacks:
            already_present = any(
                callback in (getattr(existing, "hooks", None) or ())
                for existing in options.hooks[event]
            )
            if already_present:
                continue
            pending.append((event, HookMatcher(matcher=None, hooks=[callback])))

        for event, matcher in pending:
            options.hooks[event].insert(0, matcher)

        logger.debug("Injected LangSmith tracing hooks into ClaudeAgentOptions")
    except ImportError:
        logger.warning("Failed to import HookMatcher from claude_agent_sdk")
    except Exception as e:
        logger.warning(f"Failed to inject tracing hooks: {e}")
def _unwrap_streamed_messages(
    messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Reduce streamed input envelopes to plain role/content entries.

    A dict entry whose ``"message"`` value is itself a dict is replaced by
    ``{"role": ..., "content": ...}`` taken from that inner dict (defaults
    ``"user"`` and ``""``). Every other entry is passed through unchanged.

    Args:
        messages: Captured streaming input; may be empty or ``None``.

    Returns:
        A new list, empty when ``messages`` is falsy.
    """

    def _unwrap(entry: Any) -> Any:
        if isinstance(entry, dict) and "message" in entry:
            payload = entry["message"]
            if isinstance(payload, dict):
                return {
                    "role": payload.get("role", "user"),
                    "content": payload.get("content", ""),
                }
        return entry

    if not messages:
        return []
    return [_unwrap(entry) for entry in messages]
def instrument_claude_client(original_class: Any) -> Any:
    """Wrap `ClaudeSDKClient` to trace both `query()` and `receive_response()`.

    Args:
        original_class: Client class to subclass.

    Returns:
        ``original_class`` itself when its ``_langsmith_instrumented`` attribute
        is truthy; otherwise a new ``TracedClaudeSDKClient`` subclass of it.
    """
    if getattr(original_class, "_langsmith_instrumented", False):
        return original_class  # Already wrapped, avoid double-tracing

    class TracedClaudeSDKClient(original_class):
        """Subclass that records a chain run around each response stream."""

        # Marker read by ``instrument_claude_client`` to skip re-wrapping.
        _langsmith_instrumented = True

        def __init__(self, *args: Any, **kwargs: Any):
            """Inject tracing hooks into the options, then initialise the client."""
            # Inject LangSmith tracing hooks into options before initialization.
            # Options come from the ``options`` keyword or, failing that, the
            # first positional argument; the object is mutated in place.
            options = kwargs.get("options") or (args[0] if args else None)
            if options:
                _inject_tracing_hooks(options)
            super().__init__(*args, **kwargs)
            # State captured by query() and read back by receive_response().
            self._prompt: Optional[str] = None
            self._start_time: Optional[float] = None
            self._streamed_input: Optional[list[dict[str, Any]]] = None

        async def query(self, *args: Any, **kwargs: Any) -> Any:
            """Capture prompt and start time, wrapping generators if needed.

            The prompt is taken from the first positional argument or the
            ``prompt`` keyword. String prompts are stored as-is; async-iterable
            prompts are wrapped so each yielded message is also collected;
            any other non-None prompt is stored as ``str(prompt)``.
            """
            self._start_time = time.time()
            self._streamed_input = None
            prompt = args[0] if args else kwargs.get("prompt")
            if prompt is None:
                # NOTE(review): self._prompt keeps whatever an earlier query()
                # stored in this case - confirm that is intended.
                pass
            elif isinstance(prompt, str):
                self._prompt = prompt
            elif isinstance(prompt, AsyncIterable):
                # The collector is shared with self._streamed_input, so it fills
                # up as the wrapped generator is consumed.
                collector: list[dict[str, Any]] = []
                self._streamed_input = collector
                self._prompt = None

                async def _gen_wrapper() -> AsyncGenerator[dict[str, Any], None]:
                    async for msg in prompt:
                        collector.append(msg)
                        yield msg

                # Substitute the wrapper wherever the prompt was passed.
                if args:
                    args = (_gen_wrapper(),) + args[1:]
                else:
                    kwargs["prompt"] = _gen_wrapper()
            else:
                self._prompt = str(prompt)
            return await super().query(*args, **kwargs)

        def _handle_assistant_tool_uses(
            self,
            msg: Any,
            run: Any,
            subagent_sessions: dict[str, Any],
        ) -> None:
            """Process tool uses for an assistant message.

            Creates client-managed runs for ``ToolUseBlock`` entries (matched by
            class name) in two cases: a top-level ``Task`` tool becomes a
            "chain" run under ``run``, and a tool used inside a known subagent
            becomes a "tool" run under that subagent's run. Other tool uses are
            not handled here.

            Args:
                msg: Assistant message; ignored if it has no ``content``.
                run: Run used as parent for subagent session runs.
                subagent_sessions: Mapping of Task ``tool_use_id`` to subagent
                    run; updated in place.
            """
            if not hasattr(msg, "content"):
                return
            from ._hooks import _client_managed_runs

            parent_tool_use_id = getattr(msg, "parent_tool_use_id", None)
            for block in msg.content:
                if type(block).__name__ != "ToolUseBlock":
                    continue
                try:
                    tool_use_id = getattr(block, "id", None)
                    tool_name = getattr(block, "name", "unknown_tool")
                    tool_input = getattr(block, "input", {})
                    if not tool_use_id:
                        continue
                    start_time = time.time()
                    # Check if this is a Task tool (subagent)
                    if tool_name == "Task" and not parent_tool_use_id:
                        # Extract subagent name: ``subagent_type``, else the first
                        # word of ``description``, else "unknown-agent".
                        subagent_name = (
                            tool_input.get("subagent_type")
                            or (
                                tool_input.get("description", "").split()[0]
                                if tool_input.get("description")
                                else None
                            )
                            or "unknown-agent"
                        )
                        subagent_session = run.create_child(
                            name=subagent_name,
                            run_type="chain",
                            inputs=tool_input,
                            start_time=datetime.fromtimestamp(
                                start_time, tz=timezone.utc
                            ),
                        )
                        subagent_session.post()
                        subagent_sessions[tool_use_id] = subagent_session
                        # Registered so the tool hooks treat this id as client-managed.
                        _client_managed_runs[tool_use_id] = subagent_session
                    # Check if tool use is within a subagent
                    elif parent_tool_use_id and parent_tool_use_id in subagent_sessions:
                        subagent_session = subagent_sessions[parent_tool_use_id]
                        # Create tool run as child of subagent
                        tool_run = subagent_session.create_child(
                            name=tool_name,
                            run_type="tool",
                            inputs={"input": tool_input} if tool_input else {},
                            start_time=datetime.fromtimestamp(
                                start_time,
                                tz=timezone.utc,
                            ),
                        )
                        tool_run.post()
                        _client_managed_runs[tool_use_id] = tool_run
                except Exception as e:
                    # Best effort: a failure for one block does not stop the others.
                    logger.warning(f"Failed to create client-managed tool run: {e}")

        async def receive_response(self) -> AsyncGenerator[Any, None]:
            """Intercept message stream and record chain run activity.

            Yields every message from the parent's ``receive_response()``
            unchanged while recording a chain run named ``TRACE_CHAIN_NAME``,
            one "llm" run per assistant message, and usage/metadata from the
            result message.
            """
            messages = super().receive_response()
            # Capture configuration in inputs and metadata
            trace_inputs: dict[str, Any] = {}
            trace_metadata: dict[str, Any] = {
                "ls_integration": "claude-agent-sdk",
                "ls_integration_version": _get_package_version("claude_agent_sdk"),
            }
            # Track if we need to update input from captured streaming messages
            awaiting_streamed_input = self._streamed_input is not None
            # Add prompt to inputs (for string prompts)
            if self._prompt:
                trace_inputs["prompt"] = self._prompt
            # Add system_prompt to inputs if available
            if hasattr(self, "options") and self.options:
                if (
                    hasattr(self.options, "system_prompt")
                    and self.options.system_prompt
                ):
                    system_prompt = self.options.system_prompt
                    if isinstance(system_prompt, str):
                        trace_inputs["system"] = system_prompt
                    elif isinstance(system_prompt, dict):
                        # Handle SystemPromptPreset format
                        if system_prompt.get("type") == "preset":
                            preset_text = (
                                f"preset: {system_prompt.get('preset', 'claude_code')}"
                            )
                            if "append" in system_prompt:
                                preset_text += f"\nappend: {system_prompt['append']}"
                            trace_inputs["system"] = preset_text
                        else:
                            trace_inputs["system"] = system_prompt
                # Add other config to metadata
                for attr in ["model", "permission_mode", "max_turns"]:
                    if hasattr(self.options, attr):
                        val = getattr(self.options, attr)
                        if val is not None:
                            trace_metadata[attr] = val
            async with trace(
                name=TRACE_CHAIN_NAME,
                run_type="chain",
                inputs=trace_inputs,
                metadata=trace_metadata,
            ) as run:
                # Expose the chain run to the tool hooks through thread-local storage.
                set_parent_run_tree(run)
                tracker = TurnLifecycle(self._start_time)
                # Conversation history accumulated so far (assistant, user, tool).
                collected: list[dict[str, Any]] = []
                # Track subagent sessions by Task tool_use_id
                subagent_sessions: dict[str, Any] = {}
                prompt_for_llm: Any = self._prompt
                try:
                    async for msg in messages:
                        # Streamed input is only known once the wrapped prompt
                        # generator has been consumed, so the run inputs are
                        # filled in when the first captured messages are seen.
                        if awaiting_streamed_input and self._streamed_input:
                            unwrapped_messages = _unwrap_streamed_messages(
                                self._streamed_input
                            )
                            if unwrapped_messages:
                                run.inputs["messages"] = unwrapped_messages
                                prompt_for_llm = self._streamed_input
                            awaiting_streamed_input = False
                        msg_type = type(msg).__name__
                        if msg_type == "AssistantMessage":
                            # Check if this message belongs to a subagent
                            parent_tool_use_id = getattr(
                                msg, "parent_tool_use_id", None
                            )
                            llm_parent = (
                                subagent_sessions.get(parent_tool_use_id)
                                if parent_tool_use_id
                                else None
                            )
                            content = tracker.start_llm_run(
                                msg, prompt_for_llm, collected, parent=llm_parent
                            )
                            if content:
                                collected.append(content)
                            # Process tool uses in this AssistantMessage
                            self._handle_assistant_tool_uses(
                                msg,
                                run,
                                subagent_sessions,
                            )
                        elif msg_type == "UserMessage":
                            if hasattr(msg, "content"):
                                # Check if this is a tool result message
                                flattened = flatten_content_blocks(msg.content)
                                if (
                                    isinstance(flattened, list)
                                    and flattened
                                    and isinstance(flattened[0], dict)
                                    and flattened[0].get("type") == "tool_result"
                                ):
                                    # Format each tool result as a separate message
                                    for block in flattened:
                                        collected.append(
                                            {
                                                "role": "tool",
                                                "content": block.get("content", ""),
                                                "tool_call_id": block.get(
                                                    "tool_use_id"
                                                ),
                                            }
                                        )
                                else:
                                    collected.append(
                                        {
                                            "content": flattened,
                                            "role": "user",
                                        }
                                    )
                            # The next assistant run is timed from this point.
                            tracker.mark_next_start()
                        elif msg_type == "ResultMessage":
                            # Add usage metrics including cost
                            if hasattr(msg, "usage"):
                                usage = extract_usage_from_result_message(msg)
                                # Add total_cost to usage_metadata if available
                                if (
                                    hasattr(msg, "total_cost_usd")
                                    and msg.total_cost_usd is not None
                                ):
                                    usage["total_cost"] = msg.total_cost_usd
                                tracker.add_usage(usage)
                            # Add conversation-level metadata (None values dropped)
                            meta = {
                                k: v
                                for k, v in {
                                    "num_turns": getattr(msg, "num_turns", None),
                                    "session_id": getattr(msg, "session_id", None),
                                    "duration_ms": getattr(msg, "duration_ms", None),
                                    "duration_api_ms": getattr(
                                        msg, "duration_api_ms", None
                                    ),
                                    "is_error": getattr(msg, "is_error", None),
                                }.items()
                                if v is not None
                            }
                            if meta:
                                run.metadata.update(meta)
                        yield msg
                    # The last collected history entry becomes the chain output.
                    run.end(outputs=collected[-1] if collected else None)
                except Exception:
                    # NOTE(review): any Exception raised inside the loop, including
                    # one coming from the wrapped stream itself, is logged here and
                    # not re-raised, so the consumer just sees the stream end.
                    # Confirm this is intended.
                    logger.exception("Error while tracing Claude Agent stream")
                finally:
                    # Close the open llm run and drop per-conversation state.
                    tracker.close()
                    clear_parent_run_tree()
                    clear_active_tool_runs()

        async def __aenter__(self) -> "TracedClaudeSDKClient":
            """Enter the parent's async context and return this instance."""
            await super().__aenter__()
            return self

        async def __aexit__(self, *args: Any) -> None:
            """Exit the parent's async context; its return value is discarded."""
            await super().__aexit__(*args)

    return TracedClaudeSDKClient

View File

@@ -0,0 +1,39 @@
"""Configuration management for Claude Agent SDK tracing."""
from typing import Any, Optional
# Global configuration for tracing.
# ``set_tracing_config`` rebinds this name to a fresh dict with the same four
# keys; ``get_tracing_config`` returns a shallow copy of it.
_tracing_config: dict[str, Any] = {
    "name": None,
    "project_name": None,
    "metadata": None,
    "tags": None,
}
def set_tracing_config(
    name: Optional[str] = None,
    project_name: Optional[str] = None,
    metadata: Optional[dict] = None,
    tags: Optional[list[str]] = None,
) -> None:
    """Replace the global tracing configuration for Claude Agent SDK.

    Every call overwrites all four settings; arguments left out reset their
    setting to ``None``.

    Args:
        name: Name of the root trace.
        project_name: LangSmith project to trace to.
        metadata: Metadata to associate with all traces.
        tags: Tags to associate with all traces.
    """
    global _tracing_config

    _tracing_config = dict(
        name=name,
        project_name=project_name,
        metadata=metadata,
        tags=tags,
    )
def get_tracing_config() -> dict[str, Any]:
    """Return a shallow copy of the current tracing configuration."""
    snapshot = dict(_tracing_config)
    return snapshot

View File

@@ -0,0 +1,253 @@
"""Hook-based tool tracing for Claude Agent SDK."""
import logging
import time
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Optional
from langsmith.run_helpers import get_current_run_tree
from langsmith.run_trees import RunTree
from ._tools import get_parent_run_tree
if TYPE_CHECKING:
from claude_agent_sdk import (
HookContext,
HookInput,
HookJSONOutput,
)
logger = logging.getLogger(__name__)
# Tool runs opened by ``pre_tool_use_hook`` and not yet closed by a post hook.
# Key: tool_use_id, Value: (run_tree, start_time)
_active_tool_runs: dict[str, tuple[Any, float]] = {}
# Storage for tool or subagent runs managed by client.
# The hooks in this module skip creating a run for these ids and only end/patch
# the stored run.
# Key: tool_use_id, Value: run_tree
_client_managed_runs: dict[str, RunTree] = {}
async def pre_tool_use_hook(
    input_data: "HookInput",
    tool_use_id: Optional[str],
    context: "HookContext",
) -> "HookJSONOutput":
    """Open a "tool" run just before a tool executes.

    Nothing is created when there is no ``tool_use_id``, when the id is already
    client-managed, or when no parent run is available. Errors are logged and
    never propagate to the caller.

    Args:
        input_data: Contains `tool_name`, `tool_input`, `session_id`
        tool_use_id: Unique identifier for this tool invocation
        context: Hook context (currently contains only signal)

    Returns:
        Hook output (empty dict allows execution to proceed)
    """
    # No id to key the run on, or the client wrapper already created the run.
    if not tool_use_id or tool_use_id in _client_managed_runs:
        return {}

    tool_name: str = str(input_data.get("tool_name", "unknown_tool"))
    tool_input = input_data.get("tool_input", {})

    try:
        parent = get_parent_run_tree() or get_current_run_tree()
        if parent:
            started_at = time.time()
            run_inputs = {"input": tool_input} if tool_input else {}
            tool_run = parent.create_child(
                name=tool_name,
                run_type="tool",
                inputs=run_inputs,
                start_time=datetime.fromtimestamp(started_at, tz=timezone.utc),
            )

            try:
                tool_run.post()
            except Exception as e:
                logger.warning(f"Failed to post tool run for {tool_name}: {e}")

            # Registered even if posting failed, so the post hook can close it.
            _active_tool_runs[tool_use_id] = (tool_run, started_at)
    except Exception as e:
        logger.warning(f"Error in PreToolUse hook for {tool_name}: {e}", exc_info=True)

    return {}
def _tool_response_to_outputs(
    tool_response: Any,
) -> tuple[dict[str, Any], Optional[Any]]:
    """Translate a hook ``tool_response`` into ``(outputs, error)`` for a run."""
    if isinstance(tool_response, dict):
        outputs = tool_response
    elif isinstance(tool_response, list):
        outputs = {"content": tool_response}
    else:
        outputs = {"output": str(tool_response)} if tool_response else {}

    # Only dict responses can flag an error.
    if not (isinstance(tool_response, dict) and tool_response.get("is_error", False)):
        return outputs, None

    error = tool_response.get("output")
    if not error:
        # An error response need not carry an "output" key. Fall back to other
        # fields so the run is never recorded as successful with error=None.
        detail = tool_response.get("error") or tool_response.get("content")
        error = str(detail) if detail else "Tool execution failed"
    return outputs, error


async def post_tool_use_hook(
    input_data: "HookInput",
    tool_use_id: Optional[str],
    context: "HookContext",
) -> "HookJSONOutput":
    """Trace tool execution after it completes.

    Ends and patches the run registered for ``tool_use_id``, whether it is a
    client-managed run or one opened by the pre-tool hook. A dict response with
    a truthy ``is_error`` always ends the run with a non-empty error. Failures
    are logged and never propagate to the caller.

    Args:
        input_data: Contains `tool_name`, `tool_input`, `tool_response`, `session_id`, etc.
        tool_use_id: Unique identifier for this tool invocation
        context: Hook context (currently contains only signal)

    Returns:
        Hook output (empty `dict` by default)
    """  # noqa: E501
    if not tool_use_id:
        return {}

    tool_name: str = str(input_data.get("tool_name", "unknown_tool"))
    tool_response = input_data.get("tool_response")

    # Check if this is a client-managed run
    run_tree = _client_managed_runs.pop(tool_use_id, None)
    if run_tree:
        # This run is managed by the client (subagent session or its tools)
        try:
            outputs, error = _tool_response_to_outputs(tool_response)
            run_tree.end(outputs=outputs, error=error)
            run_tree.patch()
        except Exception as e:
            logger.warning(f"Failed to update client-managed run: {e}")
        return {}

    try:
        run_info = _active_tool_runs.pop(tool_use_id, None)
        if not run_info:
            return {}

        tool_run, _ = run_info
        outputs, error = _tool_response_to_outputs(tool_response)
        tool_run.end(outputs=outputs, error=error)

        try:
            tool_run.patch()
        except Exception as e:
            logger.warning(f"Failed to patch tool run for {tool_name}: {e}")
    except Exception as e:
        logger.warning(f"Error in PostToolUse hook for {tool_name}: {e}", exc_info=True)

    return {}
async def post_tool_use_failure_hook(
    input_data: "HookInput",
    tool_use_id: Optional[str],
    context: "HookContext",
) -> "HookJSONOutput":
    """Close the run for a tool invocation that failed.

    This hook fires for built-in tool failures (Bash, Read, Write, etc.)
    and is mutually exclusive with :func:`post_tool_use_hook` — when a
    built-in tool fails, only ``PostToolUseFailure`` fires.

    The registered run (client-managed or opened by the pre-tool hook) is
    ended with the reported error and patched. Failures while doing so are
    logged and never propagate to the caller.

    Args:
        input_data: Contains ``tool_name``, ``tool_input``, ``error``,
            and optionally ``is_interrupt``.
        tool_use_id: Unique identifier for this tool invocation
        context: Hook context (currently contains only signal)

    Returns:
        Hook output (empty dict)
    """
    if not tool_use_id:
        return {}

    tool_name: str = str(input_data.get("tool_name", "unknown_tool"))
    error: str = str(input_data.get("error", "Unknown error"))

    # Client-managed runs (subagent or its tools) take precedence.
    managed_run = _client_managed_runs.pop(tool_use_id, None)
    if managed_run:
        try:
            managed_run.end(outputs={"error": error}, error=error)
            managed_run.patch()
        except Exception as e:
            logger.warning(f"Failed to update client-managed run on failure: {e}")
        return {}

    try:
        entry = _active_tool_runs.pop(tool_use_id, None)
        if entry:
            tool_run, _started_at = entry
            tool_run.end(outputs={"error": error}, error=error)

            try:
                tool_run.patch()
            except Exception as e:
                logger.warning(f"Failed to patch failed tool run for {tool_name}: {e}")
    except Exception as e:
        logger.warning(
            f"Error in PostToolUseFailure hook for {tool_name}: {e}", exc_info=True
        )

    return {}
def clear_active_tool_runs() -> None:
    """End every run still registered and empty both registries.

    This should be called when a conversation ends to avoid memory leaks
    and to clean up any orphaned tool runs. Each leftover run is ended with an
    error message and patched; a failure for one run is logged at debug level
    and does not stop the others.
    """
    # Both registries are only mutated in place, so no ``global`` rebinding
    # is needed here.

    # Client-managed runs first ...
    for managed_id, managed_run in _client_managed_runs.items():
        try:
            managed_run.end(
                error="Client-managed run not completed (conversation ended)"
            )
            managed_run.patch()
        except Exception as e:
            logger.debug(
                f"Failed to clean up orphaned client-managed run {managed_id}: {e}"
            )

    # ... then runs opened by the pre-tool hook.
    for pending_id, (pending_run, _) in _active_tool_runs.items():
        try:
            pending_run.end(error="Tool run not completed (conversation ended)")
            pending_run.patch()
        except Exception as e:
            logger.debug(f"Failed to clean up orphaned tool run {pending_id}: {e}")

    _active_tool_runs.clear()
    _client_managed_runs.clear()

View File

@@ -0,0 +1,116 @@
"""Message processing and content serialization for Claude Agent SDK."""
from typing import Any
def _extract_tool_result_text(content: Any) -> str:
    """Collapse tool result content into a single string.

    ``None`` becomes ``""`` and strings are returned unchanged. For a list,
    the text of every ``{"type": "text"}`` dict and of every non-dict item
    with a ``text`` attribute is joined with newlines; when nothing is
    collected, ``str(content)`` is returned. Any other value is passed
    through ``str``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    fragments = [
        item.get("text", "") if isinstance(item, dict) else getattr(item, "text", "")
        for item in content
        if (
            item.get("type") == "text"
            if isinstance(item, dict)
            else hasattr(item, "text")
        )
    ]
    if fragments:
        return "\n".join(fragments)
    return str(content)
def flatten_content_blocks(content: Any) -> Any:
    """Turn SDK content blocks into plain dicts.

    Blocks are recognised by class name (``TextBlock``, ``ThinkingBlock``,
    ``ToolUseBlock``, ``ToolResultBlock``); anything else is kept as-is.

    Args:
        content: A list of blocks, or any other value.

    Returns:
        A new list of converted blocks, or ``content`` unchanged when it is
        not a list.
    """
    if not isinstance(content, list):
        return content

    def _text(block: Any) -> dict[str, Any]:
        return {"type": "text", "text": getattr(block, "text", "")}

    def _thinking(block: Any) -> dict[str, Any]:
        return {
            "type": "thinking",
            "thinking": getattr(block, "thinking", ""),
            "signature": getattr(block, "signature", ""),
        }

    def _tool_use(block: Any) -> dict[str, Any]:
        return {
            "type": "tool_use",
            "id": getattr(block, "id", None),
            "name": getattr(block, "name", None),
            "input": getattr(block, "input", None),
        }

    def _tool_result(block: Any) -> dict[str, Any]:
        # Nested result content is reduced to text.
        nested = getattr(block, "content", None)
        return {
            "type": "tool_result",
            "tool_use_id": getattr(block, "tool_use_id", None),
            "content": _extract_tool_result_text(nested),
            "is_error": getattr(block, "is_error", False),
        }

    converters = {
        "TextBlock": _text,
        "ThinkingBlock": _thinking,
        "ToolUseBlock": _tool_use,
        "ToolResultBlock": _tool_result,
    }

    flattened = []
    for block in content:
        convert = converters.get(type(block).__name__)
        flattened.append(convert(block) if convert is not None else block)
    return flattened
def build_llm_input(prompt: Any, history: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Combine the prompt and the history into one message list.

    A string prompt becomes a single user message. A list prompt is walked
    entry by entry: a dict whose ``"message"`` value is a dict is reduced to
    its role/content (defaults ``"user"`` and ``""``), and every other entry
    is kept unchanged. For any other prompt only the history is returned.
    """
    if isinstance(prompt, str):
        first = {"content": prompt, "role": "user"}
        return [first, *history] if history else [first]

    if not isinstance(prompt, list):
        return history or []

    normalized = []
    for item in prompt:
        has_envelope = isinstance(item, dict) and "message" in item
        inner = item["message"] if has_envelope else None
        if isinstance(inner, dict):
            normalized.append(
                {
                    "role": inner.get("role", "user"),
                    "content": inner.get("content", ""),
                }
            )
        else:
            normalized.append(item)

    if history:
        return [*normalized, *history]
    return normalized
def extract_usage_from_result_message(msg: Any) -> dict[str, Any]:
    """Return normalized token usage for a `ResultMessage`.

    Returns an empty dict when the message has no (or falsy) ``usage`` or
    when no metrics can be extracted from it.
    """
    from ._usage import extract_usage_metadata, sum_anthropic_tokens

    raw_usage = getattr(msg, "usage", None)
    if not raw_usage:
        return {}

    metrics = extract_usage_metadata(raw_usage)
    if not metrics:
        return {}
    return sum_anthropic_tokens(metrics)

View File

@@ -0,0 +1,34 @@
"""Thread-local storage utilities for Claude Agent SDK tracing.
This module provides thread-local storage for the parent run tree,
which is used by hooks to maintain trace context when async context
propagation is broken.
"""
import logging
import threading
from typing import Any
logger = logging.getLogger(__name__)
# Thread-local store for passing the parent run tree into hooks.
# Claude's async event loop breaks tracing by default, because contextvars start
# empty within new anyio threads. The parent run tree is therefore also kept in
# thread-local storage, as a fallback when context propagation isn't available.
_thread_local = threading.local()
def set_parent_run_tree(run_tree: Any) -> None:
    """Store ``run_tree`` as the parent run for the calling thread."""
    setattr(_thread_local, "parent_run_tree", run_tree)
def clear_parent_run_tree() -> None:
    """Remove the calling thread's stored parent run, if one is set."""
    try:
        del _thread_local.parent_run_tree
    except AttributeError:
        # Nothing stored for this thread.
        pass
def get_parent_run_tree() -> Any:
    """Return the calling thread's stored parent run, or ``None`` if unset."""
    try:
        return _thread_local.parent_run_tree
    except AttributeError:
        return None

View File

@@ -0,0 +1,63 @@
"""Token usage utilities for Claude Agent SDK."""
from typing import Any
def extract_usage_metadata(usage: Any) -> dict[str, Any]:
    """Extract and normalize usage metrics from a Claude usage object or dict.

    All token counts, including the cache counts stored under
    ``input_token_details``, are returned as ``int`` so that totals computed
    from them stay integral. A value that cannot be converted to an integer
    (missing, non-numeric, NaN or infinite) is omitted.

    Args:
        usage: Mapping or object exposing ``input_tokens``, ``output_tokens``,
            ``cache_read_input_tokens`` and ``cache_creation_input_tokens``.

    Returns:
        A dict with any of ``input_tokens``, ``output_tokens`` and
        ``input_token_details`` (``cache_read`` / ``cache_creation``); empty
        when ``usage`` is falsy or holds no usable values.
    """
    if not usage:
        return {}

    if isinstance(usage, dict):
        get = usage.get
    else:

        def get(key):
            return getattr(usage, key, None)

    def to_int(value):
        try:
            return int(value)
        except (ValueError, TypeError, OverflowError):
            # OverflowError covers infinities; ValueError covers NaN and bad text.
            return None

    meta: dict[str, Any] = {}
    for key in ("input_tokens", "output_tokens"):
        if (count := to_int(get(key))) is not None:
            meta[key] = count

    details: dict[str, int] = {}
    for source_key, detail_key in (
        ("cache_read_input_tokens", "cache_read"),
        ("cache_creation_input_tokens", "cache_creation"),
    ):
        if (count := to_int(get(source_key))) is not None:
            details[detail_key] = count
    if details:
        meta["input_token_details"] = details

    return meta
def sum_anthropic_tokens(usage_metadata: dict[str, Any]) -> dict[str, int]:
    """Fold cache tokens into `input_tokens` and add `total_tokens`.

    Cache counts are read from ``input_token_details`` first and from the
    top-level ``cache_*_input_tokens`` keys otherwise. Missing or falsy
    counts are treated as zero.

    Returns:
        A copy of ``usage_metadata`` with ``input_tokens`` replaced by the
        prompt total and ``total_tokens`` set to prompt plus output tokens.
    """
    details = usage_metadata.get("input_token_details") or {}
    cache_read = details.get(
        "cache_read", usage_metadata.get("cache_read_input_tokens")
    )
    cache_create = details.get(
        "cache_creation", usage_metadata.get("cache_creation_input_tokens")
    )

    prompt_total = (
        (usage_metadata.get("input_tokens") or 0)
        + (cache_read or 0)
        + (cache_create or 0)
    )

    merged = dict(usage_metadata)
    merged["input_tokens"] = prompt_total
    merged["total_tokens"] = prompt_total + (usage_metadata.get("output_tokens") or 0)
    return merged

View File

@@ -0,0 +1,127 @@
"""LangSmith integration for Google ADK (Agent Development Kit)."""
from __future__ import annotations
import logging
from typing import Optional
from ._config import set_tracing_config
logger = logging.getLogger(__name__)
__all__ = ["configure_google_adk", "create_traced_session_context"]
_patched = False
def configure_google_adk(
    name: Optional[str] = None,
    project_name: Optional[str] = None,
    metadata: Optional[dict] = None,
    tags: Optional[list[str]] = None,
) -> bool:
    """Enable LangSmith tracing for Google ADK.

    Can be called before or after importing Runner (import-order agnostic).
    The ADK entry points are wrapped only on the first successful call; later
    calls just replace the stored tracing configuration.

    Args:
        name: Name of the root trace. Defaults to "google_adk.session".
        project_name: LangSmith project to trace to.
        metadata: Metadata to associate with all traces.
        tags: Tags to associate with all traces.

    Returns:
        True if configuration was successful, False otherwise.
    """
    global _patched

    if _patched:
        # Already instrumented: only refresh the configuration.
        set_tracing_config(
            name=name, project_name=project_name, metadata=metadata, tags=tags
        )
        return True

    try:
        import google.adk  # noqa: F401
        from wrapt import wrap_function_wrapper
    except ImportError as e:
        logger.warning(f"Missing dependency: {e}")
        return False

    set_tracing_config(
        name=name, project_name=project_name, metadata=metadata, tags=tags
    )

    from ._client import (
        wrap_agent_run_async,
        wrap_flow_call_llm_async,
        wrap_runner_run,
        wrap_runner_run_async,
        wrap_tool_run_async,
    )

    # (module path, dotted attribute, wrapper) for every patched entry point.
    targets = (
        ("google.adk.runners", "Runner.run", wrap_runner_run),
        ("google.adk.runners", "Runner.run_async", wrap_runner_run_async),
        ("google.adk.agents.base_agent", "BaseAgent.run_async", wrap_agent_run_async),
        (
            "google.adk.flows.llm_flows.base_llm_flow",
            "BaseLlmFlow._call_llm_async",
            wrap_flow_call_llm_async,
        ),
        ("google.adk.tools.base_tool", "BaseTool.run_async", wrap_tool_run_async),
        (
            "google.adk.tools.function_tool",
            "FunctionTool.run_async",
            wrap_tool_run_async,
        ),
        (
            "google.adk.tools.mcp_tool.mcp_tool",
            "McpTool.run_async",
            wrap_tool_run_async,
        ),
    )

    for module_path, attribute, wrapper in targets:
        try:
            wrap_function_wrapper(module_path, attribute, wrapper)
        except Exception as e:
            # A target that cannot be wrapped is reported and skipped.
            logger.warning(f"Failed to wrap {attribute}: {e}")

    _patched = True
    return True
def create_traced_session_context(
    name: Optional[str] = None,
    project_name: Optional[str] = None,
    metadata: Optional[dict] = None,
    tags: Optional[list[str]] = None,
    inputs: Optional[dict] = None,
):
    """Create a trace context for manual session tracing.

    Forwards all arguments unchanged to the implementation in ``._client``,
    which is imported on first use.
    """
    from ._client import create_traced_session_context as _create_context

    forwarded = dict(
        name=name,
        project_name=project_name,
        metadata=metadata,
        tags=tags,
        inputs=inputs,
    )
    return _create_context(**forwarded)

View File

@@ -0,0 +1,489 @@
"""Client instrumentation for Google ADK using wrapt."""
from __future__ import annotations
import json
import logging
import time
from collections.abc import AsyncGenerator
from contextlib import aclosing
from datetime import datetime, timezone
from functools import cache
from typing import Any, Optional
from langsmith.run_helpers import get_current_run_tree, set_tracing_parent, trace
from ._config import get_tracing_config
from ._messages import convert_llm_request_to_messages, has_function_calls
from ._usage import extract_model_name, extract_usage_from_response
# Provider identifiers returned by ``_get_ls_provider`` depending on the
# GOOGLE_GENAI_USE_VERTEXAI environment variable.
_LS_PROVIDER_VERTEXAI = "google_vertexai"
_LS_PROVIDER_GOOGLE_AI = "google_ai"
def extract_tools_from_llm_request(llm_request: Any) -> list[dict[str, Any]]:
    """Extract tool definitions from LlmRequest and convert to OpenAI format.

    Every function declaration found under ``llm_request.config.tools`` is
    serialised with ``model_dump(exclude_none=True)`` and wrapped as
    ``{"type": "function", "function": ...}``. A declaration that cannot be
    serialised is skipped and reported at debug level.

    Args:
        llm_request: Request object; ``config`` and ``config.tools`` may be
            missing or empty.

    Returns:
        The converted tool definitions, possibly empty.
    """
    config = getattr(llm_request, "config", None)
    if not config:
        return []

    tools_list = getattr(config, "tools", None)
    if not tools_list:
        return []

    result = []
    for tool in tools_list:
        for func_decl in getattr(tool, "function_declarations", None) or []:
            try:
                dumped = func_decl.model_dump(exclude_none=True)
            except Exception:
                # Best effort: keep the remaining tools, but leave a trace of
                # what was dropped instead of swallowing it silently.
                logger.debug(
                    "Skipping function declaration that could not be serialized",
                    exc_info=True,
                )
                continue
            result.append(
                {
                    "type": "function",
                    "function": dumped,
                }
            )
    return result
def _get_ls_provider() -> str:
    """Pick the provider name from the GOOGLE_GENAI_USE_VERTEXAI env var.

    The Vertex AI provider is returned when the variable is ``1``, ``true``
    or ``yes`` (case-insensitive); otherwise the Google AI provider.
    """
    import os

    flag = os.environ.get("GOOGLE_GENAI_USE_VERTEXAI", "0").lower()
    if flag in ("1", "true", "yes"):
        return _LS_PROVIDER_VERTEXAI
    return _LS_PROVIDER_GOOGLE_AI
logger = logging.getLogger(__name__)
TRACE_CHAIN_NAME = "google_adk.session"
@cache
def _get_package_version(package_name: str) -> str | None:
try:
from importlib.metadata import version
return version(package_name)
except Exception:
return None
# Attribute name used to bridge the root run from Runner.run (sync) into the
# background thread where Runner.run_async executes. Runner.run spins up a
# new thread for its internal asyncio event loop, so context vars don't
# propagate automatically. Storing the run on the instance (a plain object
# attribute) crosses the thread boundary, and wrap_runner_run_async picks it
# up and re-establishes it as a context var.
_SYNC_ROOT_RUN_ATTR = "_langsmith_root_run"
def _extract_text_from_content(content: Any) -> Optional[str]:
    """Join the text of all parts of ``content`` with single spaces.

    Returns ``None`` when ``content`` is ``None``, has no parts, or none of
    its parts carries truthy ``text``.
    """
    parts = getattr(content, "parts", None) if content is not None else None
    if not parts:
        return None

    pieces = []
    for part in parts:
        if getattr(part, "text", None):
            pieces.append(str(part.text))

    if not pieces:
        return None
    return " ".join(pieces)
def _iter_invocation_events(ctx: Any) -> list[Any]:
    """List the session events that belong to the context's invocation.

    Returns an empty list when ``ctx`` has no session, every session event
    when ``ctx`` has no invocation id, and otherwise only the events whose
    ``invocation_id`` equals the context's.
    """
    session = getattr(ctx, "session", None)
    if session is None:
        return []

    wanted_id = getattr(ctx, "invocation_id", None)
    all_events = getattr(session, "events", None) or []

    if wanted_id is not None:
        return [
            event
            for event in all_events
            if getattr(event, "invocation_id", None) == wanted_id
        ]
    return list(all_events)
def _extract_latest_invocation_text(ctx: Any) -> Optional[str]:
    """Return the text of the most recent event of the invocation that has any.

    Events are scanned newest first; ``None`` is returned when no event
    yields text.
    """
    texts = (
        _extract_text_from_content(getattr(event, "content", None))
        for event in reversed(_iter_invocation_events(ctx))
    )
    # Lazy scan: stops at the first event that produces non-empty text.
    return next((text for text in texts if text), None)
def wrap_runner_run(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any:
    """Wrap Runner.run to create a root trace for synchronous execution.

    Runner.run internally starts a new thread to run its async event loop, so
    context vars set here would not be visible to code running in that thread.
    We bridge the gap by storing the root run on the instance (a plain object
    attribute that IS visible across threads) so that wrap_runner_run_async can
    re-establish it as a context var inside the async event loop.

    Args:
        wrapped: The original ``Runner.run``.
        instance: The runner the call is bound to.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call; ``new_message``, ``user_id``
            and ``session_id`` are read for trace inputs and metadata.

    Returns:
        A generator that, once iterated, runs ``wrapped`` to completion inside
        the trace and then yields the collected events.
    """
    config = get_tracing_config()
    trace_name = config.get("name") or TRACE_CHAIN_NAME

    trace_inputs: dict[str, Any] = {}
    new_message = kwargs.get("new_message")
    if new_message:
        message_text = _extract_text_from_content(new_message)
        if message_text:
            trace_inputs["input"] = message_text

    trace_metadata: dict[str, Any] = {
        "ls_provider": _get_ls_provider(),
        "ls_integration": "google-adk",
        "ls_integration_version": _get_package_version("google-adk"),
    }
    # User-configured metadata may override the defaults above.
    trace_metadata.update(config.get("metadata") or {})

    app_name = getattr(instance, "app_name", None)
    if app_name:
        trace_metadata["app_name"] = app_name
    user_id = kwargs.get("user_id")
    if user_id:
        trace_metadata["user_id"] = user_id
    session_id = kwargs.get("session_id")
    if session_id:
        trace_metadata["session_id"] = session_id

    def _latest_text(events: list[Any]) -> Optional[str]:
        # Text of the newest event that has both content and extractable text.
        for event in reversed(events):
            content = getattr(event, "content", None)
            if not content:
                continue
            text = _extract_text_from_content(content)
            if text:
                return text
        return None

    def _trace_run():
        with trace(
            name=trace_name,
            run_type="chain",
            inputs=trace_inputs,
            project_name=config.get("project_name"),
            tags=config.get("tags"),
            metadata=trace_metadata,
        ) as root_run:
            # Publish the root run on the instance for the async wrapper.
            setattr(instance, _SYNC_ROOT_RUN_ATTR, root_run)
            try:
                # The wrapped run is drained completely before anything is
                # yielded, so the root run is ended with its final output first.
                events = list(wrapped(*args, **kwargs))
                final_output = _latest_text(events)
                if final_output:
                    root_run.end(outputs={"output": final_output})
                else:
                    root_run.end(outputs=None)
                yield from events
            except Exception as e:
                root_run.end(error=str(e))
                raise
            finally:
                setattr(instance, _SYNC_ROOT_RUN_ATTR, None)

    return _trace_run()
async def wrap_runner_run_async(
    wrapped: Any, instance: Any, args: Any, kwargs: Any
) -> Any:
    """Wrap Runner.run_async to create a root trace for asynchronous execution.

    When called from the background thread spawned by Runner.run, the root run
    stored on the instance is re-established as a context var so that
    wrap_agent_run_async and wrap_flow_call_llm_async can find the parent via
    get_current_run_tree(). In that case no new root run is created here.

    Otherwise a new root "chain" run is opened around the wrapped async
    generator. Its input is the text of ``kwargs["new_message"]`` (if any) and
    its output is the text of the last yielded event that carried text.

    Args:
        wrapped: The original ``run_async`` callable; ``wrapped(*args,
            **kwargs)`` must return an async generator.
        instance: The object ``run_async`` was called on. Read for the
            sync-bridge root run and for an optional ``app_name``.
        args: Positional arguments forwarded unchanged to ``wrapped``.
        kwargs: Keyword arguments forwarded unchanged to ``wrapped``;
            ``new_message``, ``user_id`` and ``session_id`` are also read.

    Yields:
        Every event produced by the wrapped generator, unchanged and in order.

    Raises:
        Exception: Anything raised by the wrapped generator is recorded on the
            root run and re-raised.
    """
    root_run = getattr(instance, _SYNC_ROOT_RUN_ATTR, None)
    if root_run is not None:
        # sync bridge: re-establish root run as context var in this thread
        with set_tracing_parent(root_run):
            async with aclosing(wrapped(*args, **kwargs)) as agen:
                async for event in agen:
                    yield event
        return

    config = get_tracing_config()
    trace_name = config.get("name") or TRACE_CHAIN_NAME

    trace_inputs: dict[str, Any] = {}
    if new_message := kwargs.get("new_message"):
        if text := _extract_text_from_content(new_message):
            trace_inputs["input"] = text

    # Configured metadata is spread last so it can override the defaults.
    trace_metadata: dict[str, Any] = {
        "ls_provider": _get_ls_provider(),
        "ls_integration": "google-adk",
        "ls_integration_version": _get_package_version("google-adk"),
        **(config.get("metadata") or {}),
    }
    if app_name := getattr(instance, "app_name", None):
        trace_metadata["app_name"] = app_name
    if user_id := kwargs.get("user_id"):
        trace_metadata["user_id"] = user_id
    if session_id := kwargs.get("session_id"):
        trace_metadata["session_id"] = session_id

    # The trace is opened directly in this async generator rather than in a
    # nested one iterated with a bare ``async for``. That way, closing this
    # generator early unwinds the trace context and the wrapped generator
    # immediately instead of leaving an inner generator to be finalized later.
    async with trace(
        name=trace_name,
        run_type="chain",
        inputs=trace_inputs,
        project_name=config.get("project_name"),
        tags=config.get("tags"),
        metadata=trace_metadata,
    ) as run:
        try:
            final_output: Optional[str] = None
            async with aclosing(wrapped(*args, **kwargs)) as agen:
                async for event in agen:
                    # Remember the most recent event text as the run output.
                    if content := getattr(event, "content", None):
                        if text := _extract_text_from_content(content):
                            final_output = text
                    yield event
            run.end(outputs={"output": final_output} if final_output else None)
        except Exception as e:
            run.end(error=str(e))
            raise
async def wrap_agent_run_async(
    wrapped: Any, instance: Any, args: Any, kwargs: Any
) -> Any:
    """Open a "chain" span around one BaseAgent.run_async invocation.

    If there is no current run tree, the wrapped generator is passed through
    without tracing. Otherwise the span is named after ``instance.name``
    (falling back to the instance's class name), takes the latest invocation
    text from the context as its input, and records the text of the last
    yielded event that carried text as its output.

    Args:
        wrapped: The original ``run_async``; ``wrapped(*args, **kwargs)`` must
            return an async generator.
        instance: The agent object the call was made on.
        args: Positional arguments; the first one, when present, is taken as
            the invocation context.
        kwargs: Keyword arguments; ``parent_context`` is used as the
            invocation context when no positional argument was given.

    Yields:
        Every event produced by the wrapped generator, unchanged and in order.

    Raises:
        Exception: Errors from the wrapped generator are recorded on the span
            and re-raised.
    """
    if not get_current_run_tree():
        # Nothing to attach to: forward events untouched.
        async with aclosing(wrapped(*args, **kwargs)) as passthrough:
            async for item in passthrough:
                yield item
        return

    invocation_ctx = args[0] if args else kwargs.get("parent_context")
    span_name = getattr(instance, "name", None) or type(instance).__name__

    span_inputs: dict[str, Any] = {}
    latest_text = (
        _extract_latest_invocation_text(invocation_ctx)
        if invocation_ctx is not None
        else None
    )
    if latest_text:
        span_inputs["input"] = latest_text

    async with trace(name=span_name, run_type="chain", inputs=span_inputs) as span:
        try:
            last_text: Optional[str] = None
            async with aclosing(wrapped(*args, **kwargs)) as stream:
                async for item in stream:
                    item_content = getattr(item, "content", None)
                    extracted = (
                        _extract_text_from_content(item_content)
                        if item_content
                        else None
                    )
                    if extracted:
                        last_text = extracted
                    yield item
            span.end(outputs={"output": last_text} if last_text else None)
        except Exception as e:
            span.end(error=str(e))
            raise
async def wrap_tool_run_async(
    wrapped: Any, instance: Any, args: Any, kwargs: Any
) -> Any:
    """Record one BaseTool.run_async call as a "tool" child of the current run.

    Without a current run tree the wrapped coroutine is awaited untraced.
    Otherwise a child run is created and posted before the tool executes and
    patched afterwards with either its outputs or its error. Failures to post
    or patch are logged at debug level and never interrupt the tool call.

    Args:
        wrapped: The original ``run_async`` coroutine function.
        instance: The tool object; its ``name`` (or class name) names the run.
        args: Positional arguments; the first one is used as the tool
            arguments when ``kwargs["args"]`` is missing or falsy.
        kwargs: Keyword arguments forwarded unchanged to ``wrapped``.

    Returns:
        Whatever the wrapped coroutine returns, unchanged.

    Raises:
        Exception: Errors from the wrapped coroutine are recorded on the run
            and re-raised.
    """
    current_run = get_current_run_tree()
    if not current_run:
        return await wrapped(*args, **kwargs)

    span_name = getattr(instance, "name", None) or type(instance).__name__
    call_args = kwargs.get("args") or (args[0] if args else {})
    # Non-dict arguments are wrapped so the run inputs are always a dict.
    span_inputs = call_args if isinstance(call_args, dict) else {"args": call_args}

    started_at = time.time()
    span = current_run.create_child(
        name=span_name,
        run_type="tool",
        inputs=span_inputs,
        extra={"metadata": {"ls_provider": _get_ls_provider()}},
        start_time=datetime.fromtimestamp(started_at, tz=timezone.utc),
    )
    try:
        span.post()
    except Exception as e:
        logger.debug(f"Failed to post tool run: {e}")

    try:
        result = await wrapped(*args, **kwargs)

        # Normalize the tool result into a dict of run outputs.
        if result is None:
            span_outputs = {}
        elif isinstance(result, dict):
            span_outputs = result
        elif isinstance(result, list):
            span_outputs = {"content": result}
        else:
            span_outputs = {"output": str(result)}

        span.end(outputs=span_outputs)
        try:
            span.patch()
        except Exception as e:
            logger.debug(f"Failed to patch tool run: {e}")
        return result
    except Exception as e:
        span.end(error=str(e))
        try:
            span.patch()
        except Exception as patch_e:
            logger.debug(f"Failed to patch tool run on error: {patch_e}")
        raise
def _determine_llm_call_type(llm_request: Any, llm_response: Any) -> str:
    """Classify an LLM call from its request and response.

    Returns:
        ``"response_generation"`` if any part of any request content has a
        truthy ``function_response``; otherwise ``"tool_selection"`` if
        ``has_function_calls(llm_response)`` is truthy; otherwise
        ``"direct_response"``. Any exception raised while inspecting the
        objects yields ``"unknown"``.
    """
    try:
        # Flatten request.contents[*].parts[*]; missing or None attributes
        # are treated as empty.
        request_parts = (
            part
            for content in getattr(llm_request, "contents", None) or []
            for part in getattr(content, "parts", None) or []
        )
        if any(getattr(part, "function_response", None) for part in request_parts):
            return "response_generation"
        if has_function_calls(llm_response):
            return "tool_selection"
        return "direct_response"
    except Exception:
        return "unknown"
def _build_llm_outputs(content_source: Any) -> dict[str, Any]:
    """Build the assistant-message outputs dict for an LLM run.

    Args:
        content_source: The event chosen to represent the response, or None.

    Returns:
        A dict that always contains ``"role": "assistant"``. When the event
        has truthy ``content``, ``"content"`` is the space-joined text of its
        parts (None if there is no text), and ``"tool_calls"`` is added when
        any part carries a function call.
    """
    outputs: dict[str, Any] = {"role": "assistant"}
    content = getattr(content_source, "content", None) if content_source else None
    if not content:
        return outputs

    text_parts: list[str] = []
    tool_calls: list[dict[str, Any]] = []
    for i, part in enumerate(getattr(content, "parts", None) or []):
        if hasattr(part, "text") and part.text:
            text_parts.append(str(part.text))
        elif hasattr(part, "function_call") and part.function_call:
            fc = part.function_call
            tool_calls.append(
                {
                    # The id is derived from the part's position in the content.
                    "id": f"call_{i}",
                    "type": "function",
                    "function": {
                        "name": getattr(fc, "name", ""),
                        "arguments": json.dumps(
                            dict(fc.args) if getattr(fc, "args", None) else {}
                        ),
                    },
                }
            )

    outputs["content"] = " ".join(text_parts) if text_parts else None
    if tool_calls:
        outputs["tool_calls"] = tool_calls
    return outputs


async def wrap_flow_call_llm_async(
    wrapped: Any, instance: Any, args: Any, kwargs: Any
) -> Any:
    """Wrap BaseLlmFlow._call_llm_async to capture LLM calls with TTFT tracking.

    Without a current run tree the wrapped generator is passed through
    untraced. Otherwise an "llm" child run is created and posted before the
    call, every event is forwarded to the caller as it arrives, and once the
    stream is exhausted the run is ended and patched with the assistant
    message, usage metadata, time to first token and LLM call type. Failures
    to post, patch or add events are logged at debug level only.

    Args:
        wrapped: The original ``_call_llm_async``; ``wrapped(*args, **kwargs)``
            must return an async generator.
        instance: The flow object the call was made on (not read here).
        args: Positional arguments; the second one, when present, is taken as
            the LLM request.
        kwargs: Keyword arguments; ``llm_request`` is used when fewer than two
            positional arguments were given.

    Yields:
        Every event produced by the wrapped generator, unchanged and in order.

    Raises:
        Exception: Errors raised while streaming or while assembling the run
            outputs are recorded on the run and re-raised.
    """
    parent = get_current_run_tree()
    if not parent:
        # aclosing ensures the wrapped generator is closed as soon as the
        # consumer stops iterating, matching the traced path below.
        async with aclosing(wrapped(*args, **kwargs)) as agen:
            async for event in agen:
                yield event
        return

    llm_request = args[1] if len(args) > 1 else kwargs.get("llm_request")
    model_name = extract_model_name(llm_request) if llm_request else None
    messages = convert_llm_request_to_messages(llm_request) if llm_request else None
    tools = extract_tools_from_llm_request(llm_request) if llm_request else []

    inputs: dict[str, Any] = {}
    if messages:
        inputs["messages"] = messages

    metadata: dict[str, Any] = {"ls_provider": _get_ls_provider()}
    if model_name:
        metadata["ls_model_name"] = model_name

    # Build extra dict with invocation_params if tools exist
    extra: dict[str, Any] = {"metadata": metadata}
    if tools:
        extra["invocation_params"] = {"tools": tools}

    start_time = time.time()
    llm_run = parent.create_child(
        name=model_name or "google_adk_llm",
        run_type="llm",
        inputs=inputs,
        extra=extra,
        start_time=datetime.fromtimestamp(start_time, tz=timezone.utc),
    )
    try:
        llm_run.post()
    except Exception as e:
        logger.debug(f"Failed to post LLM run: {e}")

    first_token_time: Optional[float] = None
    last_event = None
    event_with_content = None

    try:
        async with aclosing(wrapped(*args, **kwargs)) as agen:
            async for event in agen:
                is_partial = getattr(event, "partial", False)
                # The first partial event marks the first token.
                if first_token_time is None and is_partial:
                    first_token_time = time.time()
                    try:
                        llm_run.add_event(
                            {
                                "name": "new_token",
                                "time": datetime.fromtimestamp(
                                    first_token_time, tz=timezone.utc
                                ).isoformat(),
                            }
                        )
                    except Exception as e:
                        logger.debug(f"Failed to add new_token event: {e}")

                last_event = event
                if hasattr(event, "content") and event.content is not None:
                    event_with_content = event
                yield event

        # Prefer the last event that had content; otherwise the last event.
        outputs = _build_llm_outputs(event_with_content or last_event)

        if last_event:
            if usage := extract_usage_from_response(last_event):
                llm_run.extra.setdefault("metadata", {})["usage_metadata"] = usage

        if first_token_time is not None:
            # Difference of two time.time() readings, i.e. seconds.
            llm_run.extra.setdefault("metadata", {})["time_to_first_token"] = (
                first_token_time - start_time
            )

        if last_event and llm_request:
            llm_run.extra.setdefault("metadata", {})["llm_call_type"] = (
                _determine_llm_call_type(llm_request, last_event)
            )

        llm_run.end(outputs=outputs)
        try:
            llm_run.patch()
        except Exception as e:
            logger.debug(f"Failed to patch LLM run: {e}")
    except Exception as e:
        llm_run.end(error=str(e))
        try:
            llm_run.patch()
        except Exception as patch_e:
            logger.debug(f"Failed to patch LLM run on error: {patch_e}")
        raise
def create_traced_session_context(
    name: Optional[str] = None,
    project_name: Optional[str] = None,
    metadata: Optional[dict[str, Any]] = None,
    tags: Optional[list[str]] = None,
    inputs: Optional[dict[str, Any]] = None,
):
    """Return a "chain" trace context for tracing a session manually.

    Each argument that is omitted (or falsy) falls back to the value from the
    tracing config; the name finally falls back to ``TRACE_CHAIN_NAME``.

    Args:
        name: Name of the trace.
        project_name: Project to log the trace to.
        metadata: Extra metadata; merged over the configured metadata, so
            keys given here win on conflict.
        tags: Tags for the trace.
        inputs: Inputs recorded on the trace; defaults to an empty dict.

    Returns:
        The object returned by ``trace(...)``.
    """
    defaults = get_tracing_config()
    trace_kwargs: dict[str, Any] = {
        "name": name or defaults.get("name") or TRACE_CHAIN_NAME,
        "run_type": "chain",
        "inputs": inputs or {},
        "project_name": project_name or defaults.get("project_name"),
        "tags": tags or defaults.get("tags"),
        # Caller-supplied metadata overrides configured metadata.
        "metadata": {**(defaults.get("metadata") or {}), **(metadata or {})},
    }
    return trace(**trace_kwargs)

Some files were not shown because too many files have changed in this diff Show More