initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,358 @@
"""Adapted.
Original source:
https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py
MIT License
"""
import asyncio
import contextvars
import functools
import inspect
from collections import deque
from collections.abc import (
AsyncGenerator,
AsyncIterable,
AsyncIterator,
Awaitable,
Coroutine,
Iterable,
Iterator,
)
from contextlib import AbstractAsyncContextManager
from typing import (
Any,
Callable,
Generic,
Optional,
TypeVar,
Union,
cast,
overload,
)
# Type variable for the item type flowing through the async iterator helpers below.
T = TypeVar("T")
# Unique sentinel meaning "no default given"; py_anext compares against it by identity.
_no_default = object()
# https://github.com/python/cpython/blob/main/Lib/test/test_asyncgen.py#L54
# before 3.10, the builtin anext() was not available
def py_anext(
    iterator: AsyncIterator[T], default: Union[T, Any] = _no_default
) -> Awaitable[Union[T, None, Any]]:
    """Return an awaitable producing the next item of ``iterator``.

    This is a pure-Python stand-in for the builtin ``anext()``.

    Args:
        iterator: The async iterator to advance. ``__anext__`` is looked up
            on its type, not on the instance.
        default: Value to return once the iterator is exhausted. When it is
            omitted, ``StopAsyncIteration`` propagates to the awaiter.

    Returns:
        The awaitable returned by ``__anext__`` itself when no default is
        given, otherwise a coroutine that falls back to ``default``.

    Raises:
        TypeError: If the type of ``iterator`` has no ``__anext__``.
    """
    iterator_type = type(iterator)
    try:
        raw_anext = iterator_type.__anext__
    except AttributeError:
        raise TypeError(f"{iterator!r} is not an async iterator")
    advance = cast(Callable[[AsyncIterator[T]], Awaitable[T]], raw_anext)

    if default is not _no_default:

        async def anext_impl() -> Union[T, Any]:
            # Rely on ordinary coroutine machinery: await the next item and
            # translate exhaustion into the caller-supplied default.
            try:
                return await advance(iterator)
            except StopAsyncIteration:
                return default

        return anext_impl()

    # Without a default, hand back the raw awaitable untouched.
    return advance(iterator)
class NoLock:
    """No-op async context manager usable wherever a lock is optional.

    It implements the ``async with`` protocol but provides no mutual
    exclusion at all.
    """

    async def __aenter__(self) -> None:
        return None

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        # A falsy result lets any in-flight exception keep propagating.
        return False
async def tee_peer(
    iterator: AsyncIterator[T],
    # the buffer specific to this peer
    buffer: deque[T],
    # the buffers of all peers, including our own
    peers: list[deque[T]],
    lock: AbstractAsyncContextManager[Any],
) -> AsyncGenerator[T, None]:
    """Yield the items of ``iterator`` as one of several cooperating peers.

    Args:
        iterator: The shared upstream async iterator.
        buffer: This peer's private queue of not-yet-yielded items.
        peers: The queues of every live peer, ``buffer`` included. The list
            is mutated: this peer's queue is removed on exit.
        lock: Async context manager held while fetching from ``iterator``
            and while deregistering this peer.

    Yields:
        Items taken from ``buffer`` in the order they were appended.
    """
    try:
        while True:
            if not buffer:
                async with lock:
                    # A peer may have filled our buffer while we waited for
                    # the lock; if so, go around again and yield from it.
                    if buffer:
                        continue
                    try:
                        fetched = await iterator.__anext__()
                    except StopAsyncIteration:
                        break
                    # Fan the item out to every peer, ourselves included. We
                    # read our copy back from the buffer below rather than
                    # yielding it directly, so ordering stays correct when
                    # peers fetch concurrently and already buffered items.
                    for queue in peers:
                        queue.append(fetched)
            yield buffer.popleft()
    finally:
        async with lock:
            # This peer is finished: drop its buffer from the shared list.
            for position, queue in enumerate(peers):  # pragma: no branch
                if queue is buffer:
                    del peers[position]
                    break
            # The last peer to leave closes the upstream iterator if it can.
            if not peers and hasattr(iterator, "aclose"):
                await iterator.aclose()
class Tee(Generic[T]):
    """Split one asynchronous ``iterable`` into ``n`` independent iterators.

    Every child iterator provides the same items in the same order. Children
    may advance at different speeds but share the items of ``iterable``:
    once the most advanced child retrieves an item, it stays buffered until
    the least advanced child has yielded it too. A ``tee`` is lazy and can
    handle an infinite ``iterable`` as long as all children keep advancing.

    ```python
    async def derivative(sensor_data):
        previous, current = a.tee(sensor_data, n=2)
        await a.anext(previous)  # advance one iterator
        return a.map(operator.sub, previous, current)
    ```

    Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type
    rather than a :py:class:`tuple`. Like a tuple it can be indexed, iterated
    and unpacked to obtain the children. Additionally, its
    :py:meth:`~.tee.aclose` method closes all children, and it can be used
    in an ``async with`` block for the same effect.

    If ``iterable`` is an iterator that is also read elsewhere, ``tee`` will
    *not* provide those items. Because each item is buffered until the last
    child has yielded it, a :py:class:`list` is more efficient (but not lazy)
    when the most and least advanced children differ by most of the data.

    If the underlying iterable is concurrency safe (``anext`` may be awaited
    concurrently) the children are concurrency safe too. Otherwise they are
    safe if there is only ever one single "most advanced" child. To enforce
    sequential use of ``anext``, pass a ``lock`` - e.g. an
    :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application -
    and access is synchronised automatically.
    """

    def __init__(
        self,
        iterable: AsyncIterator[T],
        n: int = 2,
        *,
        lock: Optional[AbstractAsyncContextManager[Any]] = None,
    ):
        # before 3.10 aiter() doesn't exist, so call the dunder directly
        self._iterator = iterable.__aiter__()
        self._buffers: list[deque[T]] = [deque() for _ in range(n)]
        # Fall back to a lock that does nothing when the caller gave none.
        effective_lock = NoLock() if lock is None else lock
        peers = [
            tee_peer(
                iterator=self._iterator,
                buffer=own_buffer,
                peers=self._buffers,
                lock=effective_lock,
            )
            for own_buffer in self._buffers
        ]
        self._children = tuple(peers)

    def __len__(self) -> int:
        return len(self._children)

    @overload
    def __getitem__(self, item: int) -> AsyncIterator[T]: ...

    @overload
    def __getitem__(self, item: slice) -> tuple[AsyncIterator[T], ...]: ...

    def __getitem__(
        self, item: Union[int, slice]
    ) -> Union[AsyncIterator[T], tuple[AsyncIterator[T], ...]]:
        return self._children[item]

    def __iter__(self) -> Iterator[AsyncIterator[T]]:
        yield from self._children

    async def __aenter__(self) -> "Tee[T]":
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        await self.aclose()
        # Never suppress an exception raised inside the ``async with`` body.
        return False

    async def aclose(self) -> None:
        """Close every child iterator, one after the other."""
        for child in self._children:
            await child.aclose()


atee = Tee
async def async_zip(*async_iterables):
    """Async version of zip.

    Each round awaits the next item of every iterable via ``asyncio.gather``
    and yields them as one tuple. Iteration ends as soon as any iterable
    raises ``StopAsyncIteration``.

    Args:
        *async_iterables: Objects exposing ``__aiter__``.

    Yields:
        Tuples holding one item from each iterable, in argument order.
        Nothing is yielded when no iterables are given, matching ``zip()``.
    """
    # Before Python 3.10, aiter() was not available
    iterators = [iterable.__aiter__() for iterable in async_iterables]
    if not iterators:
        # gather() with no awaitables resolves to [] right away, so without
        # this guard the loop below would yield empty tuples forever.
        return
    while True:
        try:
            items = await asyncio.gather(
                *(py_anext(iterator) for iterator in iterators)
            )
            yield tuple(items)
        except StopAsyncIteration:
            break
def ensure_async_iterator(
    iterable: Union[Iterable, AsyncIterable],
) -> AsyncIterator:
    """Coerce a sync or async iterable into an async iterator.

    Args:
        iterable: An object with ``__anext__``, an object with ``__aiter__``,
            or a plain synchronous iterable.

    Returns:
        ``iterable`` itself if it already has ``__anext__``; the result of
        ``iterable.__aiter__()`` if it has ``__aiter__``; otherwise a wrapper
        that serves the items of ``iter(iterable)`` through ``__anext__``.
    """
    if hasattr(iterable, "__anext__"):
        return cast(AsyncIterator, iterable)
    if hasattr(iterable, "__aiter__"):
        return cast(AsyncIterator, iterable.__aiter__())

    class AsyncIteratorWrapper:
        # Adapts a synchronous iterator to the async iterator protocol.

        def __init__(self, iterable: Iterable):
            self._iterator = iter(iterable)

        def __aiter__(self):
            return self

        async def __anext__(self):
            try:
                return next(self._iterator)
            except StopIteration:
                # Translate sync exhaustion into its async counterpart.
                raise StopAsyncIteration

    return AsyncIteratorWrapper(iterable)
def aiter_with_concurrency(
    n: Optional[int],
    generator: AsyncIterator[Coroutine[None, None, T]],
    *,
    _eager_consumption_timeout: float = 0,
) -> AsyncGenerator[T, None]:
    """Await the coroutines produced by ``generator`` with bounded parallelism.

    Args:
        n: Maximum number of coroutines running at once. ``0`` awaits each
            coroutine inline, one at a time; ``None`` applies no limit.
        generator: Async iterator yielding the coroutines to run.
        _eager_consumption_timeout: When greater than zero, after scheduling
            each coroutine the function waits up to this many seconds for
            already-scheduled tasks to finish and yields their results.

    Yields:
        The results of the coroutines. Outside the ``n == 0`` path, results
        are yielded as tasks complete.
    """
    if n == 0:

        async def consume():
            async for awaitable in generator:
                yield await awaitable

        return consume()

    limiter = cast(
        asyncio.Semaphore, NoLock() if n is None else asyncio.Semaphore(n)
    )

    async def process_item(ix: int, item):
        # Tag the result with its index so the finished task can be
        # removed from the pending map afterwards.
        async with limiter:
            return (ix, await item)

    async def process_generator():
        pending = {}
        pass_context = asyncio_accepts_context()
        next_ix = 0
        async for item in generator:
            if pass_context:
                snapshot = contextvars.copy_context()
                pending[next_ix] = asyncio.create_task(
                    process_item(next_ix, item), context=snapshot
                )
            else:
                pending[next_ix] = asyncio.create_task(process_item(next_ix, item))
            next_ix += 1

            if _eager_consumption_timeout > 0:
                # Opportunistically drain whatever finishes within the
                # timeout; a timeout just means nothing more is ready yet.
                try:
                    for finished in asyncio.as_completed(
                        pending.values(),
                        timeout=_eager_consumption_timeout,
                    ):
                        done_ix, result = await finished
                        yield result
                        del pending[done_ix]
                except asyncio.TimeoutError:
                    pass

            if n is not None and len(pending) >= n:
                # At capacity: wait for at least one task before pulling
                # the next coroutine from the generator.
                completed, _ = await asyncio.wait(
                    pending.values(), return_when=asyncio.FIRST_COMPLETED
                )
                for finished_task in completed:
                    done_ix, result = finished_task.result()
                    yield result
                    del pending[done_ix]

        # The generator is exhausted; flush everything still in flight.
        for remaining in asyncio.as_completed(pending.values()):
            _, result = await remaining
            yield result

    return process_generator()
def accepts_context(callable: Callable[..., Any]) -> bool:
    """Report whether ``callable`` declares a parameter named ``context``.

    Returns ``False`` when ``inspect.signature`` raises ``ValueError`` for
    the callable.
    """
    try:
        signature = inspect.signature(callable)
    except ValueError:
        return False
    return "context" in signature.parameters
# Ported from Python 3.9+ to support Python 3.8
async def aio_to_thread(
    func, /, *args, __ctx: Optional[contextvars.Context] = None, **kwargs
):
    """Asynchronously run function *func* in a separate thread.

    Any *args and **kwargs supplied for this function are directly passed
    to *func*. The call is executed through ``Context.run`` on the loop's
    default executor, so context variables are visible to *func*.

    Args:
        func: The callable to run in the executor.
        *args: Positional arguments forwarded to *func*.
        __ctx: Context to run *func* in. When ``None`` (the default), a copy
            of the current :class:`contextvars.Context` is used.
        **kwargs: Keyword arguments forwarded to *func*.

    Returns:
        The value returned by *func*.
    """
    loop = asyncio.get_running_loop()
    # Test against None explicitly: Context is a Mapping, so an empty context
    # is falsy and ``__ctx or ...`` would silently discard it.
    ctx = contextvars.copy_context() if __ctx is None else __ctx
    func_call = functools.partial(ctx.run, func, *args, **kwargs)
    return await loop.run_in_executor(None, func_call)
@functools.lru_cache(maxsize=1)
def asyncio_accepts_context():
    """Tell whether ``asyncio.create_task`` takes a ``context`` parameter.

    The result is memoised by ``lru_cache``, so the signature inspection
    happens only on the first call.
    """
    supports_context = accepts_context(asyncio.create_task)
    return supports_context

View File

@@ -0,0 +1,964 @@
from __future__ import annotations
import concurrent.futures as cf
import copy
import functools
import io
import logging
import sys
import threading
import time
import weakref
from multiprocessing import cpu_count
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from langsmith import schemas as ls_schemas
from langsmith import utils as ls_utils
from langsmith._internal._compressed_traces import ZSTD_AVAILABLE, CompressedTraces
from langsmith._internal._constants import (
_AUTO_SCALE_DOWN_NEMPTY_TRIGGER,
_AUTO_SCALE_UP_NTHREADS_LIMIT,
_AUTO_SCALE_UP_QSIZE_TRIGGER,
_BOUNDARY,
)
from langsmith._internal._operations import (
SerializedFeedbackOperation,
SerializedRunOperation,
combine_serialized_queue_operations,
)
if TYPE_CHECKING:
from opentelemetry.context.context import Context # type: ignore[import]
from langsmith.client import Client
# Module logger; it deliberately uses the "langsmith.client" logger name.
logger = logging.getLogger("langsmith.client")
# Executor created at import time with one worker per CPU reported by
# cpu_count(); compressed multipart sends are submitted to it further down.
LANGSMITH_CLIENT_THREAD_POOL = cf.ThreadPoolExecutor(max_workers=cpu_count())
def _group_batch_by_api_endpoint(
batch: list[TracingQueueItem],
) -> dict[
tuple[
Optional[str],
Optional[str],
Optional[str],
Optional[str],
Optional[str],
Optional[str],
],
list[TracingQueueItem],
]:
"""Group batch items by endpoint and auth combination."""
from collections import defaultdict
grouped = defaultdict(list)
for item in batch:
key = (
item.api_url,
item.api_key,
item.service_key,
item.tenant_id,
item.authorization,
item.cookie,
)
grouped[key].append(item)
return grouped
@functools.total_ordering
class TracingQueueItem:
    """An item in the tracing queue.

    Items order by ``priority`` first and then by the class of ``item``.
    Equality uses the same two components, so two distinct items may compare
    equal. Because ``__eq__`` is defined without ``__hash__``, instances are
    unhashable.

    Attributes:
        priority (str): The priority of the item.
        item: The serialized run or feedback operation.
        api_url (Optional[str]): Endpoint override for this item, if any.
        api_key (Optional[str]): API key override for this item, if any.
        service_key (Optional[str]): Service key override, if any.
        tenant_id (Optional[str]): Tenant id override, if any.
        authorization (Optional[str]): Authorization value override, if any.
        cookie (Optional[str]): Cookie value override, if any.
        otel_context (Optional[Context]): The OTEL context of the item.
    """

    priority: str
    item: Union[SerializedRunOperation, SerializedFeedbackOperation]
    api_url: Optional[str]
    api_key: Optional[str]
    service_key: Optional[str]
    tenant_id: Optional[str]
    authorization: Optional[str]
    cookie: Optional[str]
    otel_context: Optional[Context]

    __slots__ = (
        "priority",
        "item",
        "api_key",
        "api_url",
        "service_key",
        "tenant_id",
        "authorization",
        "cookie",
        "otel_context",
    )

    def __init__(
        self,
        priority: str,
        item: Union[SerializedRunOperation, SerializedFeedbackOperation],
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        service_key: Optional[str] = None,
        tenant_id: Optional[str] = None,
        authorization: Optional[str] = None,
        cookie: Optional[str] = None,
        otel_context: Optional[Context] = None,
    ) -> None:
        self.priority = priority
        self.item = item
        self.api_key = api_key
        self.api_url = api_url
        self.service_key = service_key
        self.tenant_id = tenant_id
        self.authorization = authorization
        self.cookie = cookie
        self.otel_context = otel_context

    def __lt__(self, other: TracingQueueItem) -> bool:
        if not isinstance(other, TracingQueueItem):
            return NotImplemented
        if self.priority != other.priority:
            return self.priority < other.priority
        own_cls = self.item.__class__
        other_cls = other.item.__class__
        if own_cls == other_cls:
            return False
        # Class objects do not support ``<``; comparing them directly raised
        # TypeError for equal priorities with different item types. Order by
        # the classes' module and qualified name instead.
        return (own_cls.__module__, own_cls.__qualname__) < (
            other_cls.__module__,
            other_cls.__qualname__,
        )

    def __eq__(self, other: object) -> bool:
        return isinstance(other, TracingQueueItem) and (
            self.priority,
            self.item.__class__,
        ) == (other.priority, other.item.__class__)
def _tracing_thread_drain_queue(
tracing_queue: Queue, limit: int = 100, block: bool = True, max_size_bytes: int = 0
) -> list[TracingQueueItem]:
next_batch: list[TracingQueueItem] = []
current_size = 0
try:
# wait 250ms for the first item, then
# - drain the queue with a 50ms block timeout
# - stop draining if we hit either count or size limit
# shorter drain timeout is used instead of non-blocking calls to
# avoid creating too many small batches
if item := tracing_queue.get(block=block, timeout=0.25):
next_batch.append(item)
if max_size_bytes > 0:
current_size += item.item.calculate_serialized_size()
# If first item already exceeds limit, return just this item
if current_size > max_size_bytes:
return next_batch
# Continue draining until we hit count limit OR size limit
while True:
try:
item = tracing_queue.get(block=block, timeout=0.05)
except Empty:
break
# Add the item first
next_batch.append(item)
# Then check size limit AFTER adding the item
if max_size_bytes > 0:
current_size += item.item.calculate_serialized_size()
# If we've exceeded size limit, stop here
# (item is included in this batch)
if current_size > max_size_bytes:
break
# Check count limit AFTER adding the item
if limit and len(next_batch) >= limit:
break
except Empty:
pass
return next_batch
def _tracing_thread_drain_compressed_buffer(
    client: Client, size_limit: int = 100, size_limit_bytes: int | None = 20_971_520
) -> tuple[Optional[io.BytesIO], Optional[tuple[int, int]]]:
    """Detach the client's compressed trace buffer once a threshold is hit.

    Args:
        client: Client whose ``compressed_traces`` should be drained.
        size_limit: Trace-count threshold; must be positive unless ``None``.
        size_limit_bytes: Uncompressed-size threshold. It is overridden by
            ``client._max_batch_size_bytes`` whenever that value is truthy.

    Returns:
        ``(buffer, (uncompressed_size, compressed_size))`` with the buffer
        rewound to position 0, or ``(None, None)`` when there is no
        compressed buffer, neither threshold is reached, or any error
        occurred (errors are logged, never raised).
    """
    try:
        traces = client.compressed_traces
        if traces is None:
            return None, None
        with traces.lock:
            uncompressed_bytes = traces.uncompressed_size
            size_limit_bytes = client._max_batch_size_bytes or size_limit_bytes

            if size_limit is not None and size_limit <= 0:
                raise ValueError(f"size_limit must be positive; got {size_limit}")
            if size_limit_bytes is not None and size_limit_bytes < 0:
                raise ValueError(
                    f"size_limit_bytes must be nonnegative; got {size_limit_bytes}"
                )

            # Flush only once the byte threshold or the count threshold has
            # been reached; a ``None`` threshold never triggers.
            bytes_reached = (
                size_limit_bytes is not None
                and uncompressed_bytes >= size_limit_bytes
            )
            if not bytes_reached and (
                size_limit is None or traces.trace_count < size_limit
            ):
                return None, None

            # Terminate the multipart body and finish the compression
            # stream before measuring the compressed size.
            writer = traces.compressor_writer
            writer.write(f"--{_BOUNDARY}--\r\n".encode())
            writer.close()
            compressed_bytes = traces.buffer.tell()

            filled_buffer = traces.buffer
            # Carry the traces' context along on the buffer object itself.
            cast(Any, filled_buffer).context = traces._context
            compressed_traces_info = (uncompressed_bytes, compressed_bytes)
            traces.reset()
        filled_buffer.seek(0)
        return (filled_buffer, compressed_traces_info)
    except Exception:
        logger.error(
            "LangSmith tracing error: Failed to submit trace data.\n"
            "This does not affect your application's runtime.\n"
            "Error details:",
            exc_info=True,
        )
        # exceptions are logged elsewhere, but we need to make sure the
        # background thread continues to run
        return None, None
def _process_buffered_run_ops_batch(
client: Client,
batch_to_process: list[tuple[str, dict, dict[str, Optional[str]]]],
) -> None:
"""Process a batch of run operations asynchronously."""
try:
# Extract just the run dictionaries for process_buffered_run_ops
run_dicts = [run_data for _, run_data, _ in batch_to_process]
original_ids = [run.get("id") for run in run_dicts]
# Apply process_buffered_run_ops transformation
if client._process_buffered_run_ops is None:
raise RuntimeError(
"process_buffered_run_ops should not be None when processing batch"
)
processed_runs = list(client._process_buffered_run_ops(run_dicts))
# Validate that the transformation preserves run count and IDs
if len(processed_runs) != len(run_dicts):
raise ValueError(
f"process_buffered_run_ops must return the same number of runs. "
f"Expected {len(run_dicts)}, got {len(processed_runs)}"
)
processed_ids = [run.get("id") for run in processed_runs]
if processed_ids != original_ids:
raise ValueError(
f"process_buffered_run_ops must preserve run IDs in the same order. "
f"Expected {original_ids}, got {processed_ids}"
)
# Process each run and add to compressed traces
for (operation, _, write_ctx), processed_run in zip(
batch_to_process, processed_runs
):
if operation == "post":
client._create_run(processed_run, **write_ctx)
elif operation == "patch":
client._update_run(processed_run, **write_ctx)
# Trigger data available event
if client._data_available_event:
client._data_available_event.set()
except Exception:
# Log errors but don't crash the background thread
logger.error(
"LangSmith buffered run ops processing error: Failed to process batch.\n"
"This does not affect your application's runtime.\n"
"Error details:",
exc_info=True,
)
def _tracing_thread_handle_batch(
    client: Client,
    tracing_queue: Queue,
    batch: list[TracingQueueItem],
    use_multipart: bool,
    mark_task_done: bool = True,
    ops: Optional[
        list[Union[SerializedRunOperation, SerializedFeedbackOperation]]
    ] = None,
) -> None:
    """Send a batch of tracing queue items to LangSmith.

    The batch is split per endpoint/credential combination and each group is
    sent separately. Failures are logged and reported through
    ``client._invoke_tracing_error_callback`` instead of being raised.

    Args:
        client: The LangSmith client used to send data.
        tracing_queue: The queue the items came from (only used for
            ``task_done`` calls).
        batch: Tracing queue items to process.
        use_multipart: Send through the multipart endpoint when true,
            otherwise through the batch endpoint (which drops feedback
            operations after logging a warning).
        mark_task_done: Whether to call ``task_done`` once per batch item
            afterwards. Pass False when the caller does that itself.
        ops: Pre-combined serialized operations. When falsy, the operations
            are combined from the items of each group instead.
    """
    try:
        # One request per distinct (api_url, auth) combination.
        for destination, group_batch in _group_batch_by_api_endpoint(batch).items():
            api_url, api_key, service_key, tenant_id, authorization, cookie = (
                destination
            )
            request_auth = dict(
                api_url=api_url,
                api_key=api_key,
                service_key=service_key,
                tenant_id=tenant_id,
                authorization=authorization,
                cookie=cookie,
            )

            if ops:
                # Keep only the pre-combined ops that belong to this group.
                wanted_ids = {queued.item.id for queued in group_batch}
                group_ops = [op for op in ops if op.id in wanted_ids]
            else:
                group_ops = combine_serialized_queue_operations(
                    [queued.item for queued in group_batch]
                )

            if use_multipart:
                client._multipart_ingest_ops(group_ops, **request_auth)
                continue

            if any(isinstance(op, SerializedFeedbackOperation) for op in group_ops):
                logger.warning(
                    "Feedback operations are not supported in non-multipart mode"
                )
                group_ops = [
                    op
                    for op in group_ops
                    if not isinstance(op, SerializedFeedbackOperation)
                ]
            client._batch_ingest_run_ops(
                cast(list[SerializedRunOperation], group_ops), **request_auth
            )
    except Exception as e:
        logger.error(
            "LangSmith tracing error: Failed to submit trace data.\n"
            "This does not affect your application's runtime.\n"
            "Error details:",
            exc_info=True,
        )
        client._invoke_tracing_error_callback(e)
    finally:
        if mark_task_done and tracing_queue is not None:
            for _ in batch:
                try:
                    tracing_queue.task_done()
                except ValueError as err:
                    if "task_done() called too many times" not in str(err):
                        raise
                    # This can happen during shutdown when multiple threads
                    # process the same queue items. It's harmless.
                    logger.debug(
                        f"Ignoring harmless task_done error during shutdown: {err}"
                    )
def _otel_tracing_thread_handle_batch(
    client: Client,
    tracing_queue: Queue,
    batch: list[TracingQueueItem],
    mark_task_done: bool = True,
    ops: Optional[
        list[Union[SerializedRunOperation, SerializedFeedbackOperation]]
    ] = None,
) -> None:
    """Export the run operations of a batch through the client's OTEL exporter.

    Only ``SerializedRunOperation`` entries are exported. Failures are logged
    and reported through ``client._invoke_tracing_error_callback`` instead of
    being raised.

    Args:
        client: The LangSmith client holding ``otel_exporter``.
        tracing_queue: The queue the items came from (only used for
            ``task_done`` calls).
        batch: Tracing queue items to process.
        mark_task_done: Whether to call ``task_done`` once per batch item
            afterwards. Pass False when the caller does that itself.
        ops: Pre-combined serialized operations. When ``None``, they are
            combined from the batch items.
    """
    try:
        if ops is None:
            ops = combine_serialized_queue_operations(
                [queued.item for queued in batch]
            )
        run_ops = [op for op in ops if isinstance(op, SerializedRunOperation)]

        # Map each run id to the OTEL context captured with its queue item.
        contexts_by_run_id = {}
        for queued in batch:
            if isinstance(queued.item, SerializedRunOperation):
                contexts_by_run_id[queued.item.id] = queued.otel_context

        if run_ops:
            exporter = client.otel_exporter
            if exporter is None:
                logger.error(
                    "LangSmith tracing error: Failed to submit OTEL trace data.\n"
                    "This does not affect your application's runtime.\n"
                    "Error details: client.otel_exporter is None"
                )
            else:
                exporter.export_batch(run_ops, contexts_by_run_id)
    except Exception as e:
        logger.error(
            "OTEL tracing error: Failed to submit trace data.\n"
            "This does not affect your application's runtime.\n"
            "Error details:",
            exc_info=True,
        )
        client._invoke_tracing_error_callback(e)
    finally:
        if mark_task_done and tracing_queue is not None:
            for _ in batch:
                try:
                    tracing_queue.task_done()
                except ValueError as err:
                    if "task_done() called too many times" not in str(err):
                        raise
                    # This can happen during shutdown when multiple threads
                    # process the same queue items. It's harmless.
                    logger.debug(
                        f"Ignoring harmless task_done error during shutdown: {err}"
                    )
def _hybrid_tracing_thread_handle_batch(
    client: Client,
    tracing_queue: Queue,
    batch: list[TracingQueueItem],
    use_multipart: bool,
    mark_task_done: bool = True,
) -> None:
    """Handle a batch of tracing queue items by sending to both LangSmith and OTEL.

    The operations are combined once, deep-copied for each destination, and
    the two handlers run in parallel on a two-worker executor. If the
    executor can no longer schedule work because the interpreter is shutting
    down, the handlers run sequentially instead.

    Args:
        client: The LangSmith client to use for sending data.
        tracing_queue: The queue containing tracing items (used for task_done calls).
        batch: List of tracing queue items to process.
        use_multipart: Whether to use multipart endpoint for LangSmith.
        mark_task_done: Whether to mark queue tasks as done after processing.
            Set to False primarily for testing when items weren't actually queued.

    Raises:
        RuntimeError: Any ``RuntimeError`` other than the interpreter-shutdown
            scheduling error is re-raised.
    """
    try:
        # Combine operations once to avoid race conditions
        ops = combine_serialized_queue_operations([item.item for item in batch])
        # Create copies for each thread to avoid shared mutation
        langsmith_ops = copy.deepcopy(ops)
        otel_ops = copy.deepcopy(ops)
        try:
            # Use ThreadPoolExecutor for parallel execution
            with cf.ThreadPoolExecutor(max_workers=2) as executor:
                # Submit both tasks
                future_langsmith = executor.submit(
                    _tracing_thread_handle_batch,
                    client,
                    tracing_queue,
                    batch,
                    use_multipart,
                    False,  # Don't mark tasks done - we'll do it once at the end
                    langsmith_ops,
                )
                future_otel = executor.submit(
                    _otel_tracing_thread_handle_batch,
                    client,
                    tracing_queue,
                    batch,
                    False,  # Don't mark tasks done - we'll do it once at the end
                    otel_ops,
                )
                # Wait for both to complete
                future_langsmith.result()
                future_otel.result()
        except RuntimeError as e:
            if "cannot schedule new futures after interpreter shutdown" in str(e):
                # During interpreter shutdown, ThreadPoolExecutor is blocked,
                # fall back to sequential processing
                logger.debug(
                    "Interpreter shutting down, falling back to sequential processing"
                )
                _tracing_thread_handle_batch(
                    client, tracing_queue, batch, use_multipart, False, langsmith_ops
                )
                _otel_tracing_thread_handle_batch(
                    client, tracing_queue, batch, False, otel_ops
                )
            else:
                raise
    finally:
        # Mark all tasks as done once, only if requested. This runs in a
        # ``finally`` (like the single-destination handlers) so that an
        # exception escaping above cannot leave the queue's unfinished-task
        # count stuck and block ``Queue.join()``.
        if mark_task_done and tracing_queue is not None:
            for _ in batch:
                try:
                    tracing_queue.task_done()
                except ValueError as e:
                    if "task_done() called too many times" in str(e):
                        # This can happen during shutdown when multiple threads
                        # process the same queue items. It's harmless.
                        logger.debug(
                            f"Ignoring harmless task_done error during shutdown: {e}"
                        )
                    else:
                        raise
def get_size_limit_from_env() -> Optional[int]:
    """Read the ``BATCH_INGEST_SIZE_LIMIT`` setting as an integer.

    The value is looked up through ``ls_utils.get_env_var``.

    Returns:
        The parsed integer, or ``None`` when the setting is absent or is not
        a valid integer (the latter case also logs a warning).
    """
    raw_limit = ls_utils.get_env_var(
        "BATCH_INGEST_SIZE_LIMIT",
    )
    if raw_limit is None:
        return None
    try:
        return int(raw_limit)
    except ValueError:
        logger.warning(
            f"Invalid value for BATCH_INGEST_SIZE_LIMIT: {raw_limit}, "
            "continuing with default"
        )
        return None
def _ensure_ingest_config(
    info: ls_schemas.LangSmithInfo,
) -> ls_schemas.BatchIngestConfig:
    """Return the batch ingest configuration to use for ``info``.

    Args:
        info: Server info object; may be falsy.

    Returns:
        ``info.batch_ingest_config`` when it is present and truthy, with its
        ``"size_limit"`` entry overwritten in place by the value from
        ``get_size_limit_from_env()`` if that returns a number. Otherwise a
        default configuration (multipart enabled, ``size_limit`` of 100, and
        the module's auto-scaling constants) is returned. The environment
        override is not applied to that default configuration.
    """
    default_config = ls_schemas.BatchIngestConfig(
        use_multipart_endpoint=True,
        size_limit_bytes=None,  # Note this field is not used here
        size_limit=100,
        scale_up_nthreads_limit=_AUTO_SCALE_UP_NTHREADS_LIMIT,
        scale_up_qsize_trigger=_AUTO_SCALE_UP_QSIZE_TRIGGER,
        scale_down_nempty_trigger=_AUTO_SCALE_DOWN_NEMPTY_TRIGGER,
    )
    if not info:
        return default_config
    try:
        if not info.batch_ingest_config:
            return default_config
        env_size_limit = get_size_limit_from_env()
        if env_size_limit is not None:
            info.batch_ingest_config["size_limit"] = env_size_limit
        return info.batch_ingest_config
    except Exception:
        # Fall back to defaults on ordinary errors only. Catching
        # BaseException here would also swallow KeyboardInterrupt and
        # SystemExit, which must be allowed to propagate.
        return default_config
def get_tracing_mode() -> tuple[bool, bool]:
    """Work out which tracing destinations are active.

    The decision is based on the ``OTEL_ENABLED`` and ``OTEL_ONLY`` settings
    as evaluated by ``ls_utils.is_env_var_truish``.

    Returns:
        tuple[bool, bool]:
            - hybrid_otel_and_langsmith: True when OTEL is enabled and
              OTEL-only mode is not, i.e. both OTEL and LangSmith receive
              traces.
            - is_otel_only: True when OTEL is enabled and OTEL-only mode is
              requested.

        Both values are False whenever OTEL is not enabled.
    """
    enabled = ls_utils.is_env_var_truish("OTEL_ENABLED")
    exclusive = ls_utils.is_env_var_truish("OTEL_ONLY")
    if enabled:
        # Exactly one of the two flags is set once OTEL is enabled.
        return not exclusive, exclusive
    # If OTEL is not enabled, neither mode should be active
    return False, False
def tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None:
    """Run the control loop that drains the client's tracing queue.

    The loop keeps going while ``keep_thread_active`` reports that the client
    is still in use. Each pass prunes finished sub-threads, starts another
    sub-thread when the queue is longer than the configured trigger, and
    then drains and handles one batch in the currently configured tracing
    mode. After the loop ends, whatever is left in the queue is drained
    without blocking.

    When compression is applicable, a separate compression thread is started
    before the loop.

    Args:
        client_ref: Weak reference to the client. The function returns
            immediately if the client is already gone.
    """
    client = client_ref()
    if client is None:
        return
    tracing_queue = client.tracing_queue
    assert tracing_queue is not None
    batch_ingest_config = _ensure_ingest_config(client.info)
    size_limit: int = batch_ingest_config["size_limit"]
    scale_up_nthreads_limit: int = batch_ingest_config["scale_up_nthreads_limit"]
    scale_up_qsize_trigger: int = batch_ingest_config["scale_up_qsize_trigger"]
    use_multipart = not client._multipart_disabled and batch_ingest_config.get(
        "use_multipart_endpoint", True
    )

    sub_threads: list[threading.Thread] = []
    # 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached
    # NOTE: the liveness check below counts references to ``client``, so do
    # not bind the client to additional long-lived names in this function.
    num_known_refs = 3

    # Disable compression if explicitly set, using OpenTelemetry, or zstd unavailable
    if not ZSTD_AVAILABLE:
        logger.debug(
            "zstandard package is not installed. "
            "Falling back to uncompressed multipart ingestion."
        )
    disable_compression = (
        ls_utils.is_env_var_truish("DISABLE_RUN_COMPRESSION")
        or client.otel_exporter is not None
        or not ZSTD_AVAILABLE
    )
    if not disable_compression and use_multipart:
        if not (client.info.instance_flags or {}).get(
            "zstd_compression_enabled", False
        ):
            logger.warning(
                "Run compression is not enabled. Please update to the latest "
                "version of LangSmith. Falling back to regular multipart ingestion."
            )
        else:
            client._futures = weakref.WeakSet()
            client.compressed_traces = CompressedTraces()
            client._data_available_event = threading.Event()
            threading.Thread(
                target=tracing_control_thread_func_compress_parallel,
                args=(weakref.ref(client),),
                daemon=client._use_daemon_threads,
            ).start()

            # The compression thread accounts for one more known reference.
            num_known_refs += 1

    def keep_thread_active() -> bool:
        # if `client.cleanup()` was called, stop thread
        if not client or (
            hasattr(client, "_manual_cleanup") and client._manual_cleanup
        ):
            logger.debug("Client is being cleaned up, stopping tracing thread")
            return False
        if not threading.main_thread().is_alive():
            # main thread is dead. should not be active
            logger.debug("Main thread is dead, stopping tracing thread")
            return False

        if hasattr(sys, "getrefcount"):
            # check if client refs count indicates we're the only remaining
            # reference to the client
            should_keep_thread = sys.getrefcount(client) > num_known_refs + len(
                sub_threads
            )
            if not should_keep_thread:
                logger.debug(
                    "Client refs count indicates we're the only remaining reference "
                    "to the client, stopping tracing thread",
                )
            return should_keep_thread
        else:
            # in PyPy, there is no sys.getrefcount attribute
            # for now, keep thread alive
            return True

    # loop until
    while keep_thread_active():
        # Prune finished sub-threads. Rebuild the list in place rather than
        # calling remove() while iterating, which skips the entry following
        # each removed one; slice assignment keeps the same list object that
        # keep_thread_active() reads.
        sub_threads[:] = [thread for thread in sub_threads if thread.is_alive()]
        if (
            len(sub_threads) < scale_up_nthreads_limit
            and tracing_queue.qsize() > scale_up_qsize_trigger
        ):
            new_thread = threading.Thread(
                target=_tracing_sub_thread_func,
                args=(weakref.ref(client), use_multipart),
                daemon=client._use_daemon_threads,
            )
            sub_threads.append(new_thread)
            new_thread.start()

        # Re-read the mode on every pass so changes take effect while running.
        hybrid_otel_and_langsmith, is_otel_only = get_tracing_mode()
        max_batch_size = (
            client._max_batch_size_bytes
            or batch_ingest_config.get("size_limit_bytes")
            or 0
        )
        if next_batch := _tracing_thread_drain_queue(
            tracing_queue, limit=size_limit, max_size_bytes=max_batch_size
        ):
            if hybrid_otel_and_langsmith:
                # Hybrid mode: both OTEL and LangSmith
                _hybrid_tracing_thread_handle_batch(
                    client, tracing_queue, next_batch, use_multipart
                )
            elif is_otel_only:
                # OTEL-only mode
                _otel_tracing_thread_handle_batch(client, tracing_queue, next_batch)
            else:
                # LangSmith-only mode
                _tracing_thread_handle_batch(
                    client, tracing_queue, next_batch, use_multipart
                )

    # drain the queue on exit - apply same logic
    hybrid_otel_and_langsmith, is_otel_only = get_tracing_mode()
    max_batch_size = (
        client._max_batch_size_bytes or batch_ingest_config.get("size_limit_bytes") or 0
    )
    while next_batch := _tracing_thread_drain_queue(
        tracing_queue, limit=size_limit, block=False, max_size_bytes=max_batch_size
    ):
        if hybrid_otel_and_langsmith:
            # Hybrid mode cleanup
            logger.debug("Hybrid mode cleanup")
            _hybrid_tracing_thread_handle_batch(
                client, tracing_queue, next_batch, use_multipart
            )
        elif is_otel_only:
            # OTEL-only cleanup
            logger.debug("OTEL-only cleanup")
            _otel_tracing_thread_handle_batch(client, tracing_queue, next_batch)
        else:
            # LangSmith-only cleanup
            logger.debug("LangSmith-only cleanup")
            _tracing_thread_handle_batch(
                client, tracing_queue, next_batch, use_multipart
            )
    logger.debug("Tracing control thread is shutting down")
def tracing_control_thread_func_compress_parallel(
    client_ref: weakref.ref[Client], flush_interval: float = 0.5
) -> None:
    """Drain the client's compressed trace buffer and send it until told to stop.

    The loop wakes whenever ``client._data_available_event`` is set, or after
    at most 50 ms, and exits once ``keep_thread_active()`` returns ``False``.
    After the loop it performs one final drain-and-send.

    Args:
        client_ref: Weak reference to the owning client. The function returns
            immediately if the referent is already gone.
        flush_interval: Minimum number of seconds since the last send before
            an idle wake-up (no event set) forces a drain.
    """
    client = client_ref()
    if client is None:
        return
    logger.debug("Tracing control thread func compress parallel called")
    # All three attributes are used unconditionally below, so bail out early
    # (with an error log) if any of them is missing.
    if (
        client.compressed_traces is None
        or client._data_available_event is None
        or client._futures is None
    ):
        logger.error(
            "LangSmith tracing error: Required compression attributes not "
            "initialized.\nThis may affect trace submission but does not "
            "impact your application's runtime."
        )
        return
    batch_ingest_config = _ensure_ingest_config(client.info)
    size_limit: int = batch_ingest_config["size_limit"]
    # An explicit client-level byte limit wins; otherwise use the ingest
    # config value, defaulting to 20_971_520 bytes.
    size_limit_bytes = client._max_batch_size_bytes or batch_ingest_config.get(
        "size_limit_bytes", 20_971_520
    )
    # One for this func, one for the parent thread, one for getrefcount,
    # one for _get_data_type_cached
    num_known_refs = 4

    def keep_thread_active() -> bool:
        """Return ``False`` once this thread should stop looping."""
        # if `client.cleanup()` was called, stop thread
        if not client or (
            hasattr(client, "_manual_cleanup") and client._manual_cleanup
        ):
            logger.debug("Client is being cleaned up, stopping compression thread")
            return False
        if not threading.main_thread().is_alive():
            # main thread is dead. should not be active
            logger.debug("Main thread is dead, stopping compression thread")
            return False
        if hasattr(sys, "getrefcount"):
            # check if client refs count indicates we're the only remaining
            # reference to the client
            should_keep_thread = sys.getrefcount(client) > num_known_refs
            if not should_keep_thread:
                logger.debug(
                    "Client refs count indicates we're the only remaining reference "
                    "to the client, stopping compression thread",
                )
            return should_keep_thread
        else:
            # in PyPy, there is no sys.getrefcount attribute
            # for now, keep thread alive
            return True

    last_flush_time = time.monotonic()
    while True:
        # Block for at most 50 ms so the liveness check below runs regularly
        # even when no data arrives.
        triggered = client._data_available_event.wait(timeout=0.05)
        if not keep_thread_active():
            break
        # If data arrived, clear the event and attempt a drain
        if triggered:
            client._data_available_event.clear()
            data_stream, compressed_traces_info = (
                _tracing_thread_drain_compressed_buffer
            )(client, size_limit, size_limit_bytes)
            # If we have data, submit the send request
            if data_stream is not None:
                try:
                    # Fire-and-forget: the future is only recorded on the
                    # client, not awaited here.
                    future = LANGSMITH_CLIENT_THREAD_POOL.submit(
                        client._send_compressed_multipart_req,
                        data_stream,
                        compressed_traces_info,
                    )
                    client._futures.add(future)
                except RuntimeError:
                    # Could not hand the work to the pool (presumably it has
                    # been shut down — confirm); send on this thread instead.
                    client._send_compressed_multipart_req(
                        data_stream,
                        compressed_traces_info,
                    )
                last_flush_time = time.monotonic()
        else:
            # Idle wake-up: force a drain if nothing was sent for
            # `flush_interval` seconds.
            if (time.monotonic() - last_flush_time) >= flush_interval:
                # NOTE(review): limits of 1 presumably make the drain helper
                # return whatever is buffered — confirm in the helper.
                (
                    data_stream,
                    compressed_traces_info,
                ) = _tracing_thread_drain_compressed_buffer(
                    client, size_limit=1, size_limit_bytes=1
                )
                if data_stream is not None:
                    try:
                        # Unlike the event-driven path, this waits for the
                        # send to finish before continuing.
                        cf.wait(
                            [
                                LANGSMITH_CLIENT_THREAD_POOL.submit(
                                    client._send_compressed_multipart_req,
                                    data_stream,
                                    compressed_traces_info,
                                )
                            ]
                        )
                    except RuntimeError:
                        client._send_compressed_multipart_req(
                            data_stream,
                            compressed_traces_info,
                        )
                    last_flush_time = time.monotonic()
    # Drain the buffer on exit (final flush)
    try:
        (
            final_data_stream,
            compressed_traces_info,
        ) = _tracing_thread_drain_compressed_buffer(
            client, size_limit=1, size_limit_bytes=1
        )
        if final_data_stream is not None:
            try:
                cf.wait(
                    [
                        LANGSMITH_CLIENT_THREAD_POOL.submit(
                            client._send_compressed_multipart_req,
                            final_data_stream,
                            compressed_traces_info,
                        )
                    ]
                )
            except RuntimeError:
                client._send_compressed_multipart_req(
                    final_data_stream,
                    compressed_traces_info,
                )
    except Exception:
        # Best effort: a failure during shutdown is logged, never raised.
        logger.error(
            "LangSmith tracing error: Failed during final cleanup.\n"
            "This does not affect your application's runtime.\n"
            "Error details:",
            exc_info=True,
        )
    logger.debug("Compressed traces control thread is shutting down")
def _tracing_sub_thread_func(
    client_ref: weakref.ref[Client],
    use_multipart: bool,
) -> None:
    """Worker loop that drains ``client.tracing_queue`` and submits batches.

    Each drained batch is dispatched according to ``get_tracing_mode()``
    (hybrid OTEL + LangSmith, OTEL only, or LangSmith only). The loop ends
    when the main thread is no longer alive or after too many consecutive
    empty drains; whatever is still queued is then flushed without blocking.

    Args:
        client_ref: Weak reference to the owning client. Returns immediately
            if the referent is gone or ``client.info`` is falsy/unavailable.
        use_multipart: Forwarded unchanged to the LangSmith batch handlers.
    """
    client = client_ref()
    if client is None:
        return
    try:
        if not client.info:
            return
    except BaseException as e:
        # Reading `client.info` failed; log and give up on this worker.
        logger.debug("Error in tracing control thread: %s", e)
        return
    tracing_queue = client.tracing_queue
    assert tracing_queue is not None
    batch_ingest_config = _ensure_ingest_config(client.info)
    size_limit = batch_ingest_config.get("size_limit", 100)
    seen_successive_empty_queues = 0
    # keep looping while
    while (
        # the main thread is alive
        threading.main_thread().is_alive()
        # and the number of consecutive empty drains has not yet exceeded
        # the configured `scale_down_nempty_trigger`
        and seen_successive_empty_queues
        <= batch_ingest_config["scale_down_nempty_trigger"]
    ):
        # Re-read the byte limit every iteration; 0 is used when neither
        # the client nor the ingest config provides one.
        max_batch_size = (
            client._max_batch_size_bytes
            or batch_ingest_config.get("size_limit_bytes")
            or 0
        )
        if next_batch := _tracing_thread_drain_queue(
            tracing_queue, limit=size_limit, max_size_bytes=max_batch_size
        ):
            seen_successive_empty_queues = 0
            hybrid_otel_and_langsmith, is_otel_only = get_tracing_mode()
            if hybrid_otel_and_langsmith:
                # Hybrid mode: both OTEL and LangSmith
                _hybrid_tracing_thread_handle_batch(
                    client, tracing_queue, next_batch, use_multipart
                )
            elif is_otel_only:
                # OTEL-only mode
                _otel_tracing_thread_handle_batch(client, tracing_queue, next_batch)
            else:
                # LangSmith-only mode
                _tracing_thread_handle_batch(
                    client, tracing_queue, next_batch, use_multipart
                )
        else:
            seen_successive_empty_queues += 1
    # drain the queue on exit - apply same logic
    # (the tracing mode is read once here and reused for every batch)
    hybrid_otel_and_langsmith, is_otel_only = get_tracing_mode()
    max_batch_size = (
        client._max_batch_size_bytes or batch_ingest_config.get("size_limit_bytes") or 0
    )
    while next_batch := _tracing_thread_drain_queue(
        tracing_queue, limit=size_limit, block=False, max_size_bytes=max_batch_size
    ):
        if hybrid_otel_and_langsmith:
            # Hybrid mode cleanup
            _hybrid_tracing_thread_handle_batch(
                client, tracing_queue, next_batch, use_multipart
            )
        elif is_otel_only:
            # OTEL-only cleanup
            logger.debug("OTEL-only cleanup")
            _otel_tracing_thread_handle_batch(client, tracing_queue, next_batch)
        else:
            # LangSmith-only cleanup
            logger.debug("LangSmith-only cleanup")
            _tracing_thread_handle_batch(
                client, tracing_queue, next_batch, use_multipart
            )
    logger.debug("Tracing control sub-thread is shutting down")

View File

@@ -0,0 +1,21 @@
import functools
import warnings
from typing import Callable
class LangSmithBetaWarning(UserWarning):
    """Warning category used to flag LangSmith functionality that is in beta."""
@functools.lru_cache(maxsize=100)
def _warn_once(message: str, stacklevel: int = 2) -> None:
    """Emit ``message`` as a ``LangSmithBetaWarning``.

    The call is memoised on ``(message, stacklevel)``, so repeating the same
    arguments does not warn again while the entry stays in the 100-slot cache.
    """
    warnings.warn(
        message=message,
        category=LangSmithBetaWarning,
        stacklevel=stacklevel,
    )
def warn_beta(func: Callable) -> Callable:
    """Decorate ``func`` so that calling it issues a one-time beta warning.

    The returned wrapper carries ``func``'s metadata and forwards all
    positional and keyword arguments, returning ``func``'s result unchanged.
    """

    def wrapper(*args, **kwargs):
        # stacklevel=3 attributes the warning to the caller of the decorated
        # function rather than to this wrapper or `_warn_once`.
        _warn_once(f"Function {func.__name__} is in beta.", stacklevel=3)
        return func(*args, **kwargs)

    return functools.wraps(func)(wrapper)

View File

@@ -0,0 +1,56 @@
import io
import threading
from typing import Optional
from langsmith import utils as ls_utils
try:
from zstandard import ZstdCompressor # type: ignore[import]
ZSTD_AVAILABLE = True
except ImportError:
ZSTD_AVAILABLE = False
# zstd compression level, read once at import time from the
# RUN_COMPRESSION_LEVEL setting; falls back to 1 when the lookup is falsy.
compression_level = int(ls_utils.get_env_var("RUN_COMPRESSION_LEVEL") or 1)
# Thread count passed to ZstdCompressor, read from RUN_COMPRESSION_THREADS;
# falls back to -1 when the lookup is falsy.
# NOTE(review): -1 presumably lets zstandard pick the CPU count — confirm.
compression_threads = int(ls_utils.get_env_var("RUN_COMPRESSION_THREADS") or -1)
# Default cap on the uncompressed bytes held by a CompressedTraces buffer.
DEFAULT_MAX_UNCOMPRESSED_QUEUE_BYTES = 1024 * 1024 * 1024  # 1 GiB
class CompressedTraces:
    """In-memory buffer that accumulates trace data through a zstd stream writer.

    Attributes are plain and public: ``buffer`` receives the compressed bytes,
    ``compressor_writer`` is the zstd stream writer bound to it,
    ``uncompressed_size`` and ``trace_count`` are running counters, ``_context``
    collects context strings, and ``lock`` is a ``threading.Lock`` for callers
    to coordinate access (this class does not acquire it itself).
    """

    def __init__(self, max_uncompressed_size_bytes: Optional[int] = None) -> None:
        """Create an empty buffer.

        Args:
            max_uncompressed_size_bytes: Cap on total uncompressed bytes. When
                ``None``, the ``MAX_INGEST_MEMORY_BYTES`` setting is used if
                present, otherwise ``DEFAULT_MAX_UNCOMPRESSED_QUEUE_BYTES``.

        Raises:
            ImportError: If the ``zstandard`` package is not installed.
        """
        if not ZSTD_AVAILABLE:
            raise ImportError(
                "zstandard is required for compressed trace ingestion. "
                "Install it with `pip install zstandard` or set the environment "
                "variable LANGSMITH_DISABLE_RUN_COMPRESSION=true to disable "
                "compression."
            )
        # Configure the maximum total uncompressed size for the in-memory queue.
        if max_uncompressed_size_bytes is None:
            max_bytes_str = ls_utils.get_env_var("MAX_INGEST_MEMORY_BYTES")
            if max_bytes_str is not None:
                max_uncompressed_size_bytes = int(max_bytes_str)
            else:
                max_uncompressed_size_bytes = DEFAULT_MAX_UNCOMPRESSED_QUEUE_BYTES
        self.max_uncompressed_size_bytes = max_uncompressed_size_bytes
        self.buffer: io.BytesIO = io.BytesIO()
        self.trace_count: int = 0
        self.lock = threading.Lock()
        self.uncompressed_size: int = 0
        self._context: list[str] = []
        self.compressor_writer = self._new_compressor_writer()

    def _new_compressor_writer(self):
        """Return a zstd stream writer bound to the current ``self.buffer``."""
        # Single construction site so __init__ and reset() always agree on the
        # configured level and thread count.
        return ZstdCompressor(
            level=compression_level, threads=compression_threads
        ).stream_writer(self.buffer, closefd=False)

    def reset(self) -> None:
        """Discard buffered data and counters and start a fresh zstd stream."""
        self.buffer = io.BytesIO()
        self.trace_count = 0
        self.uncompressed_size = 0
        self._context = []
        # Previously hard-coded threads=-1 here, silently ignoring
        # RUN_COMPRESSION_THREADS after the first reset.
        self.compressor_writer = self._new_compressor_writer()

View File

@@ -0,0 +1,9 @@
import uuid
_SIZE_LIMIT_BYTES = 20_971_520  # 20 MiB (20 * 1024 * 1024) by default
# NOTE(review): judging by the names, the next three tune worker-thread
# auto-scaling (queue size that triggers a scale-up, thread ceiling, and the
# number of empty polls before scaling down) — confirm at the use sites.
_AUTO_SCALE_UP_QSIZE_TRIGGER = 200
_AUTO_SCALE_UP_NTHREADS_LIMIT = 32
_AUTO_SCALE_DOWN_NEMPTY_TRIGGER = 4
_BLOCKSIZE_BYTES = 1024 * 1024  # 1 MiB
# Random hex string generated once per process at import time.
# NOTE(review): presumably the multipart boundary — confirm at the use sites.
_BOUNDARY = uuid.uuid4().hex
_TRACING_QUEUE_MAX_SIZE = 10_000

View File

@@ -0,0 +1,29 @@
"""Shared context (ContextVars and global defaults) that configure tracing."""
import contextvars
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
if TYPE_CHECKING:
from langsmith.client import Client
from langsmith.run_trees import RunTree
else:
Client = Any # type: ignore[assignment]
RunTree = Any # type: ignore[assignment]
# Context-local tracing settings. Every variable defaults to None, meaning
# "not set in the current context".
_PROJECT_NAME = contextvars.ContextVar[Optional[str]]("_PROJECT_NAME", default=None)
_TAGS = contextvars.ContextVar[Optional[list[str]]]("_TAGS", default=None)
_METADATA = contextvars.ContextVar[Optional[dict[str, Any]]]("_METADATA", default=None)
# Tri-state toggle: True, False, or the literal "local".
_TRACING_ENABLED = contextvars.ContextVar[Optional[Union[bool, Literal["local"]]]](
    "_TRACING_ENABLED", default=None
)
_CLIENT = contextvars.ContextVar[Optional["Client"]]("_CLIENT", default=None)
_PARENT_RUN_TREE = contextvars.ContextVar[Optional["RunTree"]](
    "_PARENT_RUN_TREE", default=None
)
# Not thread-local, so you can set this process-wide (before asyncio.run, etc.)
# These are plain module globals mirroring the context variables above.
_GLOBAL_PROJECT_NAME: Optional[str] = None
_GLOBAL_TAGS: Optional[list[str]] = None
_GLOBAL_METADATA: Optional[dict[str, Any]] = None
_GLOBAL_TRACING_ENABLED: Optional[Union[bool, Literal["local"]]] = None
_GLOBAL_CLIENT: Optional["Client"] = None

View File

@@ -0,0 +1,67 @@
from typing import Any, Callable, Literal, Optional
from typing_extensions import TypedDict
# Names of the supported edit-distance metrics; each one is a key of the
# lookup table built in `EditDistance._get_metric`.
METRICS = Literal[
    "damerau_levenshtein",
    "levenshtein",
    "jaro",
    "jaro_winkler",
    "hamming",
    "indel",
]
class EditDistanceConfig(TypedDict, total=False):
    """Optional settings for an edit-distance evaluator; every key may be omitted."""

    # Which distance metric to use (one of METRICS).
    metric: METRICS
    # Whether to use the metric's normalized variant instead of the raw one.
    normalize_score: bool
class EditDistance:
    """String edit-distance evaluator backed by ``rapidfuzz``.

    The metric is resolved once at construction time and stored as a callable
    on ``self.metric``.
    """

    def __init__(
        self,
        config: Optional[EditDistanceConfig] = None,
    ):
        """Resolve the configured metric.

        Args:
            config: Optional settings. ``metric`` defaults to
                ``"damerau_levenshtein"`` and ``normalize_score`` to ``True``.

        Raises:
            ImportError: If ``rapidfuzz`` is not installed.
            ValueError: If the metric name is not supported.
        """
        config = config or {}
        metric = config.get("metric") or "damerau_levenshtein"
        self.metric = self._get_metric(
            metric, normalize_score=config.get("normalize_score", True)
        )

    def evaluate(
        self,
        prediction: str,
        reference: Optional[str] = None,
    ) -> float:
        """Return the configured distance between ``prediction`` and ``reference``."""
        return self.metric(prediction, reference)

    @staticmethod
    def _get_metric(distance: str, normalize_score: bool = True) -> Callable:
        """Look up the rapidfuzz distance function for ``distance``.

        Args:
            distance: Metric name; must be one of the keys of the table below.
            normalize_score: Return the metric's ``normalized_distance``
                function when true, its raw ``distance`` function otherwise.

        Raises:
            ImportError: If ``rapidfuzz`` cannot be imported.
            ValueError: If ``distance`` is not a known metric name.
        """
        try:
            from rapidfuzz import (  # type: ignore[import-not-found]
                distance as rf_distance,
            )
        except ImportError as e:
            # The two sentences used to be glued together ("to use.Please").
            raise ImportError(
                "This operation requires the rapidfuzz library to use. "
                "Please install it with `pip install -U rapidfuzz`."
            ) from e
        module_map: dict[str, Any] = {
            "damerau_levenshtein": rf_distance.DamerauLevenshtein,
            "levenshtein": rf_distance.Levenshtein,
            "jaro": rf_distance.Jaro,
            "jaro_winkler": rf_distance.JaroWinkler,
            "hamming": rf_distance.Hamming,
            "indel": rf_distance.Indel,
        }
        if distance not in module_map:
            raise ValueError(
                f"Invalid distance metric: {distance}"
                f"\nMust be one of: {list(module_map)}"
            )
        module = module_map[distance]
        if normalize_score:
            return module.normalized_distance
        else:
            return module.distance

View File

@@ -0,0 +1,190 @@
from __future__ import annotations
import logging
from collections.abc import Sequence
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Optional,
Union,
)
from typing_extensions import TypedDict
if TYPE_CHECKING:
import numpy as np # type: ignore
logger = logging.getLogger(__name__)
# Loose alias for matrix-like inputs (nested lists or array objects); the
# trailing ``Any`` means it places no real constraint on the argument.
Matrix = Union[list[list[float]], list[Any], Any]
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Compute cosine similarity between every row of ``X`` and every row of ``Y``.

    Uses ``simsimd`` when it can be imported and a NumPy implementation
    otherwise. Returns an empty array when either input is empty.

    Raises:
        ValueError: If ``X`` and ``Y`` have different numbers of columns.
    """
    import numpy as np

    if len(X) == 0 or len(Y) == 0:
        return np.array([])

    lhs = np.array(X)
    rhs = np.array(Y)
    if lhs.shape[1] != rhs.shape[1]:
        raise ValueError(
            f"Number of columns in X and Y must be the same. X has shape {lhs.shape} "
            f"and Y has shape {rhs.shape}."
        )

    try:
        import simsimd as simd  # type: ignore

        lhs = np.array(lhs, dtype=np.float32)
        rhs = np.array(rhs, dtype=np.float32)
        scores = 1 - simd.cdist(lhs, rhs, metric="cosine")
        # A scalar result is wrapped so callers always receive an array.
        return np.array([scores]) if isinstance(scores, float) else np.array(scores)
    except ImportError:
        logger.debug(
            "Unable to import simsimd, defaulting to NumPy implementation. If you want "
            "to use simsimd please install with `pip install simsimd`."
        )
        lhs_norms = np.linalg.norm(lhs, axis=1)
        rhs_norms = np.linalg.norm(rhs, axis=1)
        # Zero-norm rows divide by zero; silence the warnings here and patch
        # the resulting NaN/inf entries to 0.0 right after.
        with np.errstate(divide="ignore", invalid="ignore"):
            similarity = np.dot(lhs, rhs.T) / np.outer(lhs_norms, rhs_norms)
        similarity[~np.isfinite(similarity)] = 0.0
        return similarity
def _get_openai_encoder() -> Callable[[Sequence[str]], Sequence[Sequence[float]]]:
"""Get the OpenAI GPT-3 encoder."""
try:
from openai import Client as OpenAIClient
except ImportError:
raise ImportError(
"THe default encoder for the EmbeddingDistance class uses the OpenAI API. "
"Please either install the openai library with `pip install openai` or "
"provide a custom encoder function (Callable[[str], Sequence[float]])."
)
def encode_text(texts: Sequence[str]) -> Sequence[Sequence[float]]:
client = OpenAIClient()
response = client.embeddings.create(
input=list(texts), model="text-embedding-3-small"
)
return [d.embedding for d in response.data]
return encode_text
class EmbeddingConfig(TypedDict, total=False):
    """Optional settings for an embedding-distance evaluator; every key may be omitted."""

    # Function mapping a list of texts to one embedding vector per text.
    encoder: Callable[[list[str]], Sequence[Sequence[float]]]
    # Name of the distance used to compare two embedding vectors.
    metric: Literal["cosine", "euclidean", "manhattan", "chebyshev", "hamming"]
class EmbeddingDistance:
    """Score a prediction against a reference by the distance of their embeddings.

    NumPy is imported lazily inside each method that needs it: at module
    level ``np`` exists only for type checking, so referring to it directly
    at runtime raised ``NameError`` for every metric except cosine.
    """

    def __init__(
        self,
        config: Optional[EmbeddingConfig] = None,
    ):
        """Store the metric name and encoder.

        Args:
            config: Optional settings. ``metric`` defaults to ``"cosine"``;
                ``encoder`` defaults to the OpenAI encoder.
        """
        config = config or {}
        self.distance = config.get("metric") or "cosine"
        self.encoder = config.get("encoder") or _get_openai_encoder()

    def evaluate(
        self,
        prediction: str,
        reference: str,
    ) -> float:
        """Embed both strings in one encoder call and return their distance.

        Raises:
            ImportError: If NumPy is not installed.
            ValueError: If the configured metric name is not supported.
        """
        try:
            import numpy as np
        except ImportError as e:
            raise ImportError(
                "The EmbeddingDistance class requires NumPy. Please install it with "
                "`pip install numpy`."
            ) from e
        embeddings = self.encoder([prediction, reference])
        vector = np.array(embeddings)
        # `.item()` converts the NumPy scalar / one-element array to a float.
        return self._compute_distance(vector[0], vector[1]).item()

    def _compute_distance(self, a: np.ndarray, b: np.ndarray) -> np.floating:
        """Dispatch to the distance function selected by ``self.distance``."""
        if self.distance == "cosine":
            return self._cosine_distance(a, b)  # type: ignore
        elif self.distance == "euclidean":
            return self._euclidean_distance(a, b)
        elif self.distance == "manhattan":
            return self._manhattan_distance(a, b)
        elif self.distance == "chebyshev":
            return self._chebyshev_distance(a, b)
        elif self.distance == "hamming":
            return self._hamming_distance(a, b)
        else:
            raise ValueError(f"Invalid distance metric: {self.distance}")

    @staticmethod
    def _cosine_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Compute the cosine distance between two vectors.

        Args:
            a (np.ndarray): The first vector.
            b (np.ndarray): The second vector.

        Returns:
            np.ndarray: The cosine distance.
        """
        return 1.0 - cosine_similarity([a], [b])

    @staticmethod
    def _euclidean_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
        """Compute the Euclidean distance between two vectors.

        Args:
            a (np.ndarray): The first vector.
            b (np.ndarray): The second vector.

        Returns:
            np.floating: The Euclidean distance.
        """
        import numpy as np

        return np.linalg.norm(a - b)

    @staticmethod
    def _manhattan_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
        """Compute the Manhattan distance between two vectors.

        Args:
            a (np.ndarray): The first vector.
            b (np.ndarray): The second vector.

        Returns:
            np.floating: The Manhattan distance.
        """
        import numpy as np

        return np.sum(np.abs(a - b))

    @staticmethod
    def _chebyshev_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
        """Compute the Chebyshev distance between two vectors.

        Args:
            a (np.ndarray): The first vector.
            b (np.ndarray): The second vector.

        Returns:
            np.floating: The Chebyshev distance.
        """
        import numpy as np

        return np.max(np.abs(a - b))

    @staticmethod
    def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
        """Compute the Hamming distance between two vectors.

        Args:
            a (np.ndarray): The first vector.
            b (np.ndarray): The second vector.

        Returns:
            np.floating: The fraction of positions at which the vectors differ.
        """
        import numpy as np

        return np.mean(a != b)

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from collections.abc import Iterable
from io import BufferedReader
from typing import Union
# One multipart form part: (part name, (filename — None in this alias,
# payload bytes or an open binary reader, content type, extra headers)).
MultipartPart = tuple[
    str, tuple[None, Union[bytes, BufferedReader], str, dict[str, str]]
]
class MultipartPartsAndContext:
    """Pair a list of multipart parts with a context string describing them."""

    __slots__ = ("parts", "context")

    parts: list[MultipartPart]
    context: str

    def __init__(self, parts: list[MultipartPart], context: str) -> None:
        self.parts, self.context = parts, context
def join_multipart_parts_and_context(
    parts_and_contexts: Iterable[MultipartPartsAndContext],
) -> MultipartPartsAndContext:
    """Merge several part/context pairs into a single one.

    Parts are concatenated in input order and the context strings are joined
    with ``"; "``.
    """
    # Materialise first: the input may be a one-shot iterator and is read twice.
    items = list(parts_and_contexts)
    merged_parts: list[MultipartPart] = [
        part for item in items for part in item.parts
    ]
    merged_context = "; ".join([item.context for item in items])
    return MultipartPartsAndContext(merged_parts, merged_context)

View File

@@ -0,0 +1,446 @@
from __future__ import annotations
import itertools
import logging
import os
import uuid
from collections.abc import Iterable
from io import BufferedReader
from typing import Literal, Optional, Union, cast
from langsmith import schemas as ls_schemas
from langsmith._internal import _orjson
from langsmith._internal._compressed_traces import CompressedTraces
from langsmith._internal._multipart import MultipartPart, MultipartPartsAndContext
from langsmith._internal._serde import dumps_json as _dumps_json
logger = logging.getLogger(__name__)
class SerializedRunOperation:
    """A run create ("post") or update ("patch") held as pre-serialized bytes.

    ``_none`` is the whole run object minus the fields that are stored
    separately (inputs/outputs/events/extra/error/serialized/attachments).
    """

    __slots__ = (
        "operation",
        "id",
        "trace_id",
        "_none",
        "inputs",
        "outputs",
        "events",
        "extra",
        "error",
        "serialized",
        "attachments",
    )

    operation: Literal["post", "patch"]
    id: uuid.UUID
    trace_id: uuid.UUID
    _none: bytes
    inputs: Optional[bytes]
    outputs: Optional[bytes]
    events: Optional[bytes]
    extra: Optional[bytes]
    error: Optional[bytes]
    serialized: Optional[bytes]
    attachments: Optional[ls_schemas.Attachments]

    def __init__(
        self,
        operation: Literal["post", "patch"],
        id: uuid.UUID,
        trace_id: uuid.UUID,
        _none: bytes,
        inputs: Optional[bytes] = None,
        outputs: Optional[bytes] = None,
        events: Optional[bytes] = None,
        extra: Optional[bytes] = None,
        error: Optional[bytes] = None,
        serialized: Optional[bytes] = None,
        attachments: Optional[ls_schemas.Attachments] = None,
    ) -> None:
        self.operation, self.id, self.trace_id = operation, id, trace_id
        self._none = _none
        self.inputs, self.outputs, self.events = inputs, outputs, events
        self.extra, self.error, self.serialized = extra, error, serialized
        self.attachments = attachments

    def calculate_serialized_size(self) -> int:
        """Calculate actual serialized size of this operation.

        Sums the lengths of all non-empty byte payloads plus every attachment
        whose data is held in memory as ``bytes``; attachments given as
        anything else (e.g. a path) add nothing.
        """
        blobs = (
            self._none,
            self.inputs,
            self.outputs,
            self.events,
            self.extra,
            self.error,
            self.serialized,
        )
        total = sum(len(blob) for blob in blobs if blob)
        if self.attachments:
            total += sum(
                len(payload)
                for _, payload in self.attachments.values()
                if isinstance(payload, bytes)
            )
        return total

    def deserialize_run_info(self) -> dict:
        """Deserialize the main run info (_none and extra, error and serialized)."""
        info = _orjson.loads(self._none)
        for field_name in ("extra", "error", "serialized"):
            raw = getattr(self, field_name)
            if raw is not None:
                info[field_name] = _orjson.loads(raw)
        return info

    def __eq__(self, other: object) -> bool:
        """Two operations are equal when every slot compares equal."""
        if not isinstance(other, SerializedRunOperation):
            return False
        return all(
            getattr(self, slot) == getattr(other, slot)
            for slot in SerializedRunOperation.__slots__
        )
class SerializedFeedbackOperation:
    """A feedback record held as pre-serialized bytes, with its ids."""

    __slots__ = ("id", "trace_id", "feedback")

    id: uuid.UUID
    trace_id: uuid.UUID
    feedback: bytes

    def __init__(self, id: uuid.UUID, trace_id: uuid.UUID, feedback: bytes) -> None:
        self.id, self.trace_id, self.feedback = id, trace_id, feedback

    def calculate_serialized_size(self) -> int:
        """Calculate actual serialized size of this operation."""
        payload = self.feedback
        return len(payload)

    def __eq__(self, other: object) -> bool:
        """Equal when the other object is the same type with identical fields."""
        if not isinstance(other, SerializedFeedbackOperation):
            return False
        mine = (self.id, self.trace_id, self.feedback)
        theirs = (other.id, other.trace_id, other.feedback)
        return mine == theirs
def serialize_feedback_dict(
    feedback: Union[ls_schemas.FeedbackCreate, dict],
) -> SerializedFeedbackOperation:
    """Turn a feedback model or dict into a ``SerializedFeedbackOperation``.

    Objects exposing a callable ``model_dump`` are dumped to a dict; anything
    else is treated as a dict and used (and updated) in place. Missing ``id``
    / ``trace_id`` values are filled with fresh UUIDs and string values are
    parsed into ``uuid.UUID`` before the dict is serialized to JSON.
    """
    dump = getattr(feedback, "model_dump", None)
    if callable(dump):
        feedback_create: dict = dump()  # type: ignore
    else:
        feedback_create = cast(dict, feedback)

    for id_key in ("id", "trace_id"):
        if id_key not in feedback_create:
            feedback_create[id_key] = uuid.uuid4()
        elif isinstance(feedback_create[id_key], str):
            feedback_create[id_key] = uuid.UUID(feedback_create[id_key])

    return SerializedFeedbackOperation(
        id=feedback_create["id"],
        trace_id=feedback_create["trace_id"],
        feedback=_dumps_json(feedback_create),
    )
def serialize_run_dict(
    operation: Literal["post", "patch"], payload: dict
) -> SerializedRunOperation:
    """Split a run payload into separately serialized parts.

    The large fields are popped out of ``payload`` (which is therefore
    mutated) and JSON-encoded individually; what remains is encoded as the
    main ``_none`` blob. Attachments are passed through unserialized.
    """

    def _encode(value):
        # Absent fields stay None instead of becoming the JSON literal null.
        return None if value is None else _dumps_json(value)

    detached = {
        field_name: payload.pop(field_name, None)
        for field_name in (
            "inputs",
            "outputs",
            "events",
            "extra",
            "error",
            "serialized",
            "attachments",
        )
    }
    return SerializedRunOperation(
        operation=operation,
        id=payload["id"],
        trace_id=payload["trace_id"],
        _none=_dumps_json(payload),
        inputs=_encode(detached["inputs"]),
        outputs=_encode(detached["outputs"]),
        events=_encode(detached["events"]),
        extra=_encode(detached["extra"]),
        error=_encode(detached["error"]),
        serialized=_encode(detached["serialized"]),
        attachments=detached["attachments"],
    )
def combine_serialized_queue_operations(
    ops: list[Union[SerializedRunOperation, SerializedFeedbackOperation]],
) -> list[Union[SerializedRunOperation, SerializedFeedbackOperation]]:
    """Fold run patches into the matching run creates found in the same list.

    A "patch" whose run id also has a "post" in ``ops`` is merged into that
    post (the post object is mutated in place). Patches without a matching
    post and all non-run operations pass through untouched. The result lists
    the posts first, followed by the pass-through operations in input order.
    """
    posts_by_id = {}
    for candidate in ops:
        if (
            isinstance(candidate, SerializedRunOperation)
            and candidate.operation == "post"
        ):
            posts_by_id[candidate.id] = candidate

    leftovers: list[Union[SerializedRunOperation, SerializedFeedbackOperation]] = []
    for current in ops:
        if not isinstance(current, SerializedRunOperation):
            leftovers.append(current)
            continue
        if current.operation == "post":
            continue
        # must be patch
        target = posts_by_id.get(current.id)
        if target is None:
            leftovers.append(current)
            continue

        if current._none is not None and current._none != target._none:
            # TODO optimize this more - this would currently be slowest
            # for large payloads
            merged = _orjson.loads(target._none)
            patch_values = {
                key: value
                for key, value in _orjson.loads(current._none).items()
                if value is not None
            }
            merged.update(patch_values)
            target._none = _orjson.dumps(merged)

        # Any payload present on the patch replaces the one on the post.
        for field_name in (
            "inputs",
            "outputs",
            "events",
            "extra",
            "error",
            "serialized",
        ):
            replacement = getattr(current, field_name)
            if replacement is not None:
                setattr(target, field_name, replacement)

        if current.attachments is not None:
            if target.attachments is None:
                target.attachments = {}
            target.attachments.update(current.attachments)

    return [*posts_by_id.values(), *leftovers]
def serialized_feedback_operation_to_multipart_parts_and_context(
    op: SerializedFeedbackOperation,
) -> MultipartPartsAndContext:
    """Wrap a feedback operation as a single JSON multipart part plus context."""
    payload = op.feedback
    part_headers = {"Content-Length": str(len(payload))}
    feedback_part = (
        f"feedback.{op.id}",
        (None, payload, "application/json", part_headers),
    )
    context = f"trace={op.trace_id},id={op.id}"
    return MultipartPartsAndContext([feedback_part], context)
def serialized_run_operation_to_multipart_parts_and_context(
    op: SerializedRunOperation,
) -> tuple[MultipartPartsAndContext, dict[str, BufferedReader]]:
    """Convert a run operation into multipart parts.

    Returns the parts (with a context string) together with a dict of the
    file handles opened for path-based attachments; those handles are left
    open for the caller. Attachments whose name contains a period, or whose
    file does not exist, are skipped with a warning.
    """
    run_key = f"{op.operation}.{op.id}"

    def _json_part(name: str, payload: bytes) -> MultipartPart:
        headers = {"Content-Length": str(len(payload))}
        return (name, (None, payload, "application/json", headers))

    # this is main object, minus inputs/outputs/events/attachments
    multipart: list[MultipartPart] = [_json_part(run_key, op._none)]
    file_handles: dict[str, BufferedReader] = {}

    json_fields = (
        ("inputs", op.inputs),
        ("outputs", op.outputs),
        ("events", op.events),
        ("extra", op.extra),
        ("error", op.error),
        ("serialized", op.serialized),
    )
    for field_name, payload in json_fields:
        if payload is not None:
            multipart.append(_json_part(f"{run_key}.{field_name}", payload))

    for n, (content_type, data_or_path) in (op.attachments or {}).items():
        if "." in n:
            logger.warning(
                f"Skipping logging of attachment '{n}' "
                f"for run {op.id}:"
                " Invalid attachment name. Attachment names must not contain"
                " periods ('.'). Please rename the attachment and try again."
            )
            continue

        part_name = f"attachment.{op.id}.{n}"
        if isinstance(data_or_path, bytes):
            inline_headers = {"Content-Length": str(len(data_or_path))}
            multipart.append(
                (part_name, (None, data_or_path, content_type, inline_headers))
            )
            continue

        # Path-based attachment: stream it from disk.
        try:
            file_size = os.path.getsize(data_or_path)
            file = open(data_or_path, "rb")
        except FileNotFoundError:
            logger.warning(
                "Attachment file not found for run %s: %s", op.id, data_or_path
            )
            continue
        # A random suffix keeps the key unique even if one path repeats.
        file_handles[str(data_or_path) + str(uuid.uuid4())] = file
        multipart.append(
            (part_name, (None, file, f"{content_type}; length={file_size}", {}))
        )

    return (
        MultipartPartsAndContext(multipart, f"trace={op.trace_id},id={op.id}"),
        file_handles,
    )
def encode_multipart_parts_and_context(
    parts_and_context: MultipartPartsAndContext,
    boundary: str,
) -> Iterable[tuple[bytes, Union[bytes, BufferedReader]]]:
    """Yield ``(encoded_header_block, data)`` for every part.

    The header block contains the boundary line, the Content-Disposition
    (with a filename when one is set), the Content-Type, any extra headers,
    and the blank line that separates headers from the body. The data is
    yielded untouched.
    """
    for part_name, (filename, data, content_type, headers) in parts_and_context.parts:
        disposition = f'Content-Disposition: form-data; name="{part_name}"'
        if filename:
            disposition += f'; filename="{filename}"'
        extra_headers = "".join(f"{k}: {v}\r\n" for k, v in headers.items())
        preamble = (
            f"--{boundary}\r\n"
            f"{disposition}"
            f"\r\nContent-Type: {content_type}\r\n"
            f"{extra_headers}"
            "\r\n"
        )
        yield (preamble.encode(), data)
def compress_multipart_parts_and_context(
    parts_and_context: MultipartPartsAndContext,
    compressed_traces: CompressedTraces,
    boundary: str,
) -> bool:
    """Compress multipart parts into the shared compressed buffer.

    All parts are encoded up front so their combined payload size is known
    before anything is written. If the buffer already holds data and adding
    these parts would exceed its configured uncompressed-size limit, nothing
    is written, the drop is logged, and ``False`` is returned. Otherwise the
    parts are written, the context is recorded, and ``True`` is returned.
    """
    emit = compressed_traces.compressor_writer.write

    encoded: list[tuple[bytes, bytes]] = []
    pending_bytes = 0
    for header_bytes, payload in encode_multipart_parts_and_context(
        parts_and_context, boundary
    ):
        # Normalise to bytes
        if not isinstance(payload, (bytes, bytearray)):
            if isinstance(payload, BufferedReader):
                payload = payload.read()
            else:
                payload = str(payload).encode()
        encoded.append((header_bytes, payload))
        pending_bytes += len(payload)

    limit = getattr(compressed_traces, "max_uncompressed_size_bytes", None)
    if limit is not None and limit > 0:
        buffered = compressed_traces.uncompressed_size
        # An empty buffer always accepts the operation, however large.
        if buffered > 0 and buffered + pending_bytes > limit:
            from langsmith.client import _log_tracing_drop

            _log_tracing_drop(
                f"compressed traces buffer full ({buffered}/{limit} bytes)"
            )
            return False

    for header_bytes, payload in encoded:
        emit(header_bytes)
        compressed_traces.uncompressed_size += len(payload)
        emit(payload)
        emit(b"\r\n")  # part terminator
    compressed_traces._context.append(parts_and_context.context)
    return True

View File

@@ -0,0 +1,88 @@
"""Stubs for orjson operations, compatible with PyPy via a json fallback."""
try:
    from orjson import (
        OPT_NON_STR_KEYS,
        OPT_SERIALIZE_DATACLASS,
        OPT_SERIALIZE_NUMPY,
        OPT_SERIALIZE_UUID,
        Fragment,
        JSONDecodeError,
        dumps,
        loads,
    )
except ImportError:
    # orjson is unavailable: provide stand-ins with the same names, built on
    # the standard-library json module.
    import dataclasses
    import json
    import uuid
    from typing import Any, Callable, Optional, Union

    DefaultFunc = Optional[Callable[[Any], Any]]

    # Bit flags accepted by the fallback `dumps(option=...)`.
    OPT_NON_STR_KEYS = 1
    OPT_SERIALIZE_DATACLASS = 2
    OPT_SERIALIZE_NUMPY = 4
    OPT_SERIALIZE_UUID = 8

    class Fragment:  # type: ignore
        """Holder for an already-serialized JSON document."""

        def __init__(self, payloadb: bytes):
            self.payloadb = payloadb

    from json import JSONDecodeError  # type: ignore

    def dumps(
        obj: Any,
        /,
        default: DefaultFunc = None,
        option: Optional[int] = None,
    ) -> bytes:
        """Serialize ``obj`` to UTF-8 encoded JSON bytes.

        Args:
            obj: Object to serialize.
            default: Optional callable invoked for objects that none of the
                enabled options can handle.
            option: Bitwise OR of the ``OPT_*`` flags; ``None`` means no flags.

        Returns:
            The JSON document as bytes.
        """
        # for now, don't do anything for this case because `json.dumps`
        # automatically encodes non-str keys as str by default, unlike orjson
        # enable_non_str_keys = bool(option & OPT_NON_STR_KEYS)
        if option is None:
            option = 0
        enable_serialize_numpy = bool(option & OPT_SERIALIZE_NUMPY)
        enable_serialize_dataclass = bool(option & OPT_SERIALIZE_DATACLASS)
        enable_serialize_uuid = bool(option & OPT_SERIALIZE_UUID)

        class CustomEncoder(json.JSONEncoder):  # type: ignore
            def encode(self, o: Any) -> str:
                # A top-level fragment is emitted verbatim.
                if isinstance(o, Fragment):
                    payload = o.payloadb
                    if isinstance(payload, str):
                        return payload
                    return bytes(payload).decode("utf-8")  # type: ignore
                return super().encode(o)

            def default(self, o: Any) -> Any:
                # Fragments nested inside containers never reach `encode`;
                # they arrive here. Parse them so the encoder can embed the
                # document instead of failing with "not JSON serializable".
                if isinstance(o, Fragment):
                    return json.loads(o.payloadb)
                if enable_serialize_uuid and isinstance(o, uuid.UUID):
                    return str(o)
                if enable_serialize_numpy and hasattr(o, "tolist"):
                    # even objects like np.uint16(15) have a .tolist() function
                    return o.tolist()
                if (
                    enable_serialize_dataclass
                    and dataclasses.is_dataclass(o)
                    and not isinstance(o, type)
                ):
                    return dataclasses.asdict(o)
                if default is not None:
                    return default(o)
                return super().default(o)

        return json.dumps(obj, cls=CustomEncoder).encode("utf-8")

    def loads(payload: Union[bytes, bytearray, memoryview, str], /) -> Any:
        """Parse a JSON document."""
        return json.loads(payload)
# Public names; each is bound either by the orjson import or by the fallback
# definitions in the `except ImportError` branch.
__all__ = [
    "loads",
    "dumps",
    "Fragment",
    "JSONDecodeError",
    "OPT_SERIALIZE_NUMPY",
    "OPT_SERIALIZE_DATACLASS",
    "OPT_SERIALIZE_UUID",
    "OPT_NON_STR_KEYS",
]

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from uuid import UUID
def get_otel_trace_id_from_uuid(uuid_val: UUID) -> int:
    """Get OpenTelemetry trace ID as integer from UUID.

    Args:
        uuid_val: The UUID to convert.

    Returns:
        The UUID's full 128-bit value as an integer.
    """
    # UUID.int is the same number that parsing UUID.hex as base 16 yields.
    return uuid_val.int
def get_otel_span_id_from_uuid(uuid_val: UUID) -> int:
    """Get OpenTelemetry span ID as integer from UUID.

    Args:
        uuid_val: The UUID to convert.

    Returns:
        The first 8 bytes of the UUID read as a big-endian integer.
    """
    leading_bytes = uuid_val.bytes[:8]
    return int.from_bytes(leading_bytes, "big")

View File

@@ -0,0 +1,93 @@
import functools
from urllib3 import __version__ as urllib3version # type: ignore[import-untyped]
from urllib3 import connection # type: ignore[import-untyped]
def _ensure_str(s, encoding="utf-8", errors="strict") -> str:
if isinstance(s, str):
return s
if isinstance(s, bytes):
return s.decode(encoding, errors)
return str(s)
# Copied from https://github.com/urllib3/urllib3/blob/1c994dfc8c5d5ecaee8ed3eb585d4785f5febf6e/src/urllib3/connection.py#L231
def request(self, method, url, body=None, headers=None):
    """Make the request.

    This function is based on the urllib3 request method, with modifications
    to handle potential issues when using vcrpy in concurrent workloads.
    Instead of calling ``super().request``, it delegates to
    ``self._parent_request``, which must have been set on the instance
    beforehand.

    Args:
        self: The HTTPConnection instance.
        method (str): The HTTP method (e.g., 'GET', 'POST').
        url (str): The URL for the request.
        body (Optional[Any]): The body of the request.
        headers (Optional[dict]): Headers to send with the request. The
            mapping passed in is copied, never modified.

    Returns:
        The result of calling the parent request method.
    """
    # Update the inner socket's timeout value to send the request.
    # This only triggers if the connection is re-used.
    if getattr(self, "sock", None) is not None:
        self.sock.settimeout(self.timeout)
    if headers is None:
        headers = {}
    else:
        # Avoid modifying the headers passed into .request()
        headers = headers.copy()
    # Case-insensitive check; header names may be bytes or str.
    if "user-agent" not in (_ensure_str(k.lower()) for k in headers):
        headers["User-Agent"] = connection._get_default_user_agent()
    # The above is all the same ^^^
    # The following is different:
    return self._parent_request(method, url, body=body, headers=headers)
# Module-level guard: set to True after the first `patch_urllib3()` call
# completes, whether or not a patch was actually installed.
_PATCHED = False


def patch_urllib3():
    """Patch the request method of urllib3 to avoid type errors when using vcrpy.

    In concurrent workloads (such as the tracing background queue), the
    connection pool can get in a state where an HTTPConnection is created
    before vcrpy patches the HTTPConnection class. In urllib3 >= 2.0 this isn't
    a problem since they use the proper super().request(...) syntax, but in older
    versions, super(HTTPConnection, self).request is used, resulting in a TypeError
    since self is no longer a subclass of "HTTPConnection" (which at this point
    is vcr.stubs.VCRConnection).

    This method patches the class to fix the super() syntax to avoid mixed inheritance.
    In the case of the LangSmith tracing logic, it doesn't really matter since we always
    exclude cache checks for calls to LangSmith.

    The patch is only applied for urllib3 versions older than 2.0. Calling
    this function again after the first call is a no-op.
    """
    global _PATCHED
    if _PATCHED:
        return
    # Imported lazily so `packaging` is only needed when patching is attempted.
    from packaging import version

    if version.parse(urllib3version) >= version.parse("2.0"):
        # Nothing to fix on urllib3 2.x; just remember that we checked.
        _PATCHED = True
        return
    # Lookup the parent class and its request method
    parent_class = connection.HTTPConnection.__bases__[0]
    parent_request = parent_class.request

    def new_request(self, *args, **kwargs):
        """Handle parent request.

        This method binds the parent's request method to self and then
        calls our modified request function.
        """
        self._parent_request = functools.partial(parent_request, self)
        return request(self, *args, **kwargs)

    # Replaces the method on the class itself, so every HTTPConnection
    # instance in the process is affected.
    connection.HTTPConnection.request = new_request
    _PATCHED = True

View File

@@ -0,0 +1,163 @@
from __future__ import annotations
import base64
import collections
import datetime
import decimal
import ipaddress
import json
import logging
import pathlib
import re
import uuid
from typing import Any
from langsmith._internal import _orjson
try:
    from zoneinfo import ZoneInfo  # type: ignore[import-not-found]
except ImportError:
    # Fallback so that the name ``ZoneInfo`` always exists in this module,
    # even when the standard-library ``zoneinfo`` module cannot be imported.

    class ZoneInfo:  # type: ignore[no-redef]
        """Empty stand-in for ``zoneinfo.ZoneInfo`` (introduced in Python 3.9)."""
logger = logging.getLogger(__name__)
def _simple_default(obj):
try:
# Only need to handle types that orjson doesn't serialize by default
# https://github.com/ijl/orjson#serialize
if isinstance(obj, datetime.datetime):
return obj.isoformat()
elif isinstance(obj, uuid.UUID):
return str(obj)
elif isinstance(obj, BaseException):
return {"error": type(obj).__name__, "message": str(obj)}
elif isinstance(obj, (set, frozenset, collections.deque)):
return list(obj)
elif isinstance(obj, (datetime.timezone, ZoneInfo)):
return obj.tzname(None)
elif isinstance(obj, datetime.timedelta):
return obj.total_seconds()
elif isinstance(obj, decimal.Decimal):
if obj.as_tuple().exponent >= 0:
return int(obj)
else:
return float(obj)
elif isinstance(
obj,
(
ipaddress.IPv4Address,
ipaddress.IPv4Interface,
ipaddress.IPv4Network,
ipaddress.IPv6Address,
ipaddress.IPv6Interface,
ipaddress.IPv6Network,
pathlib.Path,
),
):
return str(obj)
elif isinstance(obj, re.Pattern):
return obj.pattern
elif isinstance(obj, (bytes, bytearray)):
return base64.b64encode(obj).decode()
return str(obj)
except BaseException as e:
logger.debug(f"Failed to serialize {type(obj)} to JSON: {e}")
return str(obj)
_serialization_methods = [
(
"model_dump",
{"exclude_none": True, "mode": "json"},
), # Pydantic V2 with non-serializable fields
("dict", {}), # Pydantic V1 with non-serializable field
("to_dict", {}), # dataclasses-json
]
# IMPORTANT: This function is used from Rust code in `langsmith-pyo3` serialization,
# in order to handle serializing these tricky Python types *from Rust*.
# Do not cause this function to become inaccessible (e.g. by deleting
# or renaming it) without also fixing the corresponding Rust code found in:
# rust/crates/langsmith-pyo3/src/serialization/mod.rs
def _serialize_json(obj: Any) -> Any:
try:
if isinstance(obj, (set, tuple)):
if hasattr(obj, "_asdict") and callable(obj._asdict):
# NamedTuple
return obj._asdict()
return list(obj)
for attr, kwargs in _serialization_methods:
if (
hasattr(obj, attr)
and callable(getattr(obj, attr))
and not isinstance(obj, type)
):
try:
method = getattr(obj, attr)
response = method(**kwargs)
if not isinstance(response, dict):
return str(response)
return response
except Exception as e:
logger.debug(
f"Failed to use {attr} to serialize {type(obj)} to"
f" JSON: {repr(e)}"
)
pass
return _simple_default(obj)
except BaseException as e:
logger.debug(f"Failed to serialize {type(obj)} to JSON: {e}")
return str(obj)
def _elide_surrogates(s: bytes) -> bytes:
pattern = re.compile(rb"\\ud[89a-f][0-9a-f]{2}", re.IGNORECASE)
result = pattern.sub(b"", s)
return result
def dumps_json(obj: Any) -> bytes:
    """Serialize an object to UTF-8 encoded JSON bytes.

    orjson is tried first, with ``_serialize_json`` as the fallback for
    unsupported types. If orjson raises ``TypeError`` (usually caused by UTF
    surrogate characters), the object is serialized with the standard
    ``json`` module in ASCII-escaped form and then re-encoded through orjson;
    if orjson cannot parse that output, surrogate escapes are stripped from
    it instead.

    Parameters
    ----------
    obj : Any
        The object to serialize.

    Returns:
    -------
    bytes
        The JSON document.
    """
    orjson_options = (
        _orjson.OPT_SERIALIZE_NUMPY
        | _orjson.OPT_SERIALIZE_DATACLASS
        | _orjson.OPT_SERIALIZE_UUID
        | _orjson.OPT_NON_STR_KEYS
    )
    try:
        return _orjson.dumps(obj, default=_serialize_json, option=orjson_options)
    except TypeError as e:
        # Usually caused by UTF surrogate characters
        logger.debug(f"Orjson serialization failed: {repr(e)}. Falling back to json.")
        ascii_json = json.dumps(
            obj,
            default=_serialize_json,
            ensure_ascii=True,
        ).encode("utf-8")
        try:
            decoded = ascii_json.decode("utf-8", errors="surrogateescape")
            return _orjson.dumps(_orjson.loads(decoded))
        except _orjson.JSONDecodeError:
            return _elide_surrogates(ascii_json)

View File

@@ -0,0 +1,155 @@
"""UUID helpers backed by uuid-utils."""
from __future__ import annotations
import time
import uuid
import warnings
from typing import Final
import xxhash
from uuid_utils.compat import uuid7 as _uuid_utils_uuid7
_NANOS_PER_SECOND: Final = 1_000_000_000
def _to_timestamp_and_nanos(nanoseconds: int) -> tuple[int, int]:
"""Split a nanosecond timestamp into seconds and remaining nanoseconds."""
seconds, nanos = divmod(nanoseconds, _NANOS_PER_SECOND)
return seconds, nanos
def uuid7(nanoseconds: int | None = None) -> uuid.UUID:
    """Generate a UUID from a Unix timestamp in nanoseconds and random bits.

    UUIDv7 objects feature monotonicity within a millisecond.

    Args:
        nanoseconds: Optional ns timestamp. If not provided, uses current time.

    Returns:
        The UUID produced by the ``uuid_utils`` implementation.
    """
    # Reference layout (RFC 9562, §6.2, Method 1):
    #
    #   --- 48 ---   -- 4 --   --- 12 ---   -- 2 --   --- 30 ---   - 32 -
    #   unix_ts_ms | version | counter_hi | variant | counter_lo | random
    #
    # 'counter = counter_hi | counter_lo' is a 42-bit counter whose MSB is 0;
    # 'random' is a 32-bit value regenerated for every new UUID. When several
    # UUIDs fall in the same millisecond the counter's LSB is incremented; on
    # overflow the timestamp advances and the counter is re-seeded.
    #
    # For now generation is delegated entirely to uuid_utils.
    if nanoseconds is not None:
        seconds, nanos = _to_timestamp_and_nanos(nanoseconds)
        return _uuid_utils_uuid7(timestamp=seconds, nanos=nanos)
    return _uuid_utils_uuid7()
def is_uuid_v7(uuid_obj: uuid.UUID) -> bool:
    """Tell whether ``uuid_obj`` reports itself as a version 7 UUID.

    Args:
        uuid_obj: The UUID to check.

    Returns:
        True if the UUID is version 7, False otherwise.
    """
    reported_version = uuid_obj.version
    return reported_version == 7
_UUID_V7_WARNING_EMITTED = False


def warn_if_not_uuid_v7(uuid_obj: uuid.UUID, id_type: str) -> None:
    """Emit a one-time ``UserWarning`` when a UUID is not version 7.

    The warning is raised at most once per process, tracked through the
    module-level ``_UUID_V7_WARNING_EMITTED`` flag.

    Args:
        uuid_obj: The UUID to check.
        id_type: The type of ID (e.g., "run_id", "trace_id"). Currently not
            included in the emitted message.
    """
    global _UUID_V7_WARNING_EMITTED
    if is_uuid_v7(uuid_obj) or _UUID_V7_WARNING_EMITTED:
        return

    _UUID_V7_WARNING_EMITTED = True
    message = (
        "LangSmith now uses UUID v7 for run and trace identifiers. "
        "This warning appears when passing custom IDs. "
        "Please use: from langsmith import uuid7\n"
        "            id = uuid7()\n"
        "Future versions will require UUID v7."
    )
    warnings.warn(message, UserWarning, stacklevel=3)
def uuid7_deterministic(original_id: uuid.UUID, key: str) -> uuid.UUID:
    """Generate a deterministic UUID7 derived from an original UUID and a key.

    The resulting UUID:
    - keeps the 48-bit timestamp of ``original_id`` when that is a UUID v7;
    - otherwise carries the current time in milliseconds;
    - fills the remaining non-fixed bits from the XXH3-128 hash of
      ``"{original_id}:{key}"``;
    - has the version nibble set to 7 and the variant bits set to ``10``.

    This is used for creating replica IDs that maintain time-ordering
    properties while being deterministic across distributed systems.

    Args:
        original_id: The source UUID (ideally UUID v7 to preserve timestamp).
        key: A string key used for deterministic derivation (e.g., project name).

    Returns:
        A new UUID v7 with preserved timestamp (if original is v7) and
        deterministic random bits.

    Example:
        >>> original = uuid7()
        >>> replica_id = uuid7_deterministic(original, "replica-project")
        >>> # Same inputs always produce same output
        >>> assert uuid7_deterministic(original, "replica-project") == replica_id
    """
    digest = xxhash.xxh3_128(f"{original_id}:{key}".encode()).digest()

    # UUID7 byte layout (RFC 9562):
    #   [0-5]   unix_ts_ms (48 bits)
    #   [6]     version nibble (0111) + 4 bits rand_a
    #   [7]     rand_a continued (8 bits)
    #   [8]     variant (10) + 6 bits rand_b
    #   [9-15]  rand_b continued (56 bits)
    if is_uuid_v7(original_id):
        # Preserve the original's timestamp.
        timestamp_bytes = original_id.bytes[0:6]
    else:
        # Non-v7 input: stamp with the current time, truncated to 48 bits
        # and encoded big-endian.
        now_ms = time.time_ns() // 1_000_000
        timestamp_bytes = (now_ms & 0xFFFF_FFFF_FFFF).to_bytes(6, "big")

    version_and_rand = 0x70 | (digest[0] & 0x0F)
    variant_and_rand = 0x80 | (digest[2] & 0x3F)

    raw = bytearray(16)
    raw[0:6] = timestamp_bytes
    raw[6] = version_and_rand
    raw[7] = digest[1]
    raw[8] = variant_and_rand
    raw[9:16] = digest[3:10]
    return uuid.UUID(bytes=bytes(raw))

View File

@@ -0,0 +1,106 @@
"""Client configuration for OpenTelemetry integration with LangSmith."""
import os
import warnings
from typing import TYPE_CHECKING
if TYPE_CHECKING:
try:
from opentelemetry.sdk.trace import TracerProvider # type: ignore[import]
except ImportError:
TracerProvider = object # type: ignore[assignment, misc]
from langsmith import utils as ls_utils
def _import_otel_client():
"""Dynamically import OTEL client modules when needed."""
try:
from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( # type: ignore[import]
OTLPSpanExporter,
)
from opentelemetry.sdk.resources import ( # type: ignore[import]
SERVICE_NAME,
Resource,
)
from opentelemetry.sdk.trace import TracerProvider # type: ignore[import]
from opentelemetry.sdk.trace.export import ( # type: ignore[import]
BatchSpanProcessor,
)
return (
OTLPSpanExporter,
SERVICE_NAME,
Resource,
TracerProvider,
BatchSpanProcessor,
)
except ImportError as e:
warnings.warn(
f"OTEL_ENABLED is set but OpenTelemetry packages are not installed: {e}"
)
return None
def get_otlp_tracer_provider() -> "TracerProvider":
    """Get the OTLP tracer provider for LangSmith.

    This function creates a tracer provider that exports spans using the OTLP
    protocol with LangSmith-specific defaults:

    - OTEL_EXPORTER_OTLP_ENDPOINT: the LangSmith API URL followed by ``/otel``
    - OTEL_EXPORTER_OTLP_HEADERS: ``x-api-key`` from the LangSmith API key and
      a ``Langsmith-Project`` header if a project is configured

    Each default is only written to ``os.environ`` when the corresponding
    variable is not already set, so they can be overridden by setting the
    environment variables before calling this function. Headers without a
    value are omitted; if no header has a value the variable is left unset.

    Returns:
        TracerProvider: The OTLP tracer provider.

    Raises:
        ImportError: If the OpenTelemetry packages are not installed.
    """
    # Import OTEL modules dynamically
    otel_imports = _import_otel_client()
    if otel_imports is None:
        raise ImportError(
            "OpenTelemetry packages are required to use this function. "
            "Please install with `pip install langsmith[otel]`"
        )
    (
        OTLPSpanExporter,
        SERVICE_NAME,
        Resource,
        TracerProvider,
        BatchSpanProcessor,
    ) = otel_imports

    if "OTEL_EXPORTER_OTLP_ENDPOINT" not in os.environ:
        ls_endpoint = ls_utils.get_api_url(None)
        os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = f"{ls_endpoint}/otel"

    # Configure headers with API key and project if available
    if "OTEL_EXPORTER_OTLP_HEADERS" not in os.environ:
        header_pairs = []
        api_key = ls_utils.get_api_key(None)
        # Skip the key when none was resolved; formatting it unconditionally
        # would export the literal header "x-api-key=None".
        if api_key:
            header_pairs.append(f"x-api-key={api_key}")
        project = ls_utils.get_tracer_project()
        if project:
            header_pairs.append(f"Langsmith-Project={project}")
        if header_pairs:
            os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = ",".join(header_pairs)

    service_name = os.environ.get("OTEL_SERVICE_NAME", "langsmith")
    resource = Resource(
        attributes={
            SERVICE_NAME: service_name,
            # Marker to identify LangSmith's internal provider
            "langsmith.internal_provider": True,
        }
    )

    tracer_provider = TracerProvider(resource=resource)
    # The exporter is built with no arguments; it is expected to pick up the
    # endpoint/headers from the environment variables prepared above.
    otlp_exporter = OTLPSpanExporter()
    span_processor = BatchSpanProcessor(otlp_exporter)
    tracer_provider.add_span_processor(span_processor)
    return tracer_provider

View File

@@ -0,0 +1,845 @@
"""OpenTelemetry exporter for LangSmith runs."""
from __future__ import annotations
import datetime
import logging
import time
import uuid
import warnings
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
try:
from opentelemetry.context.context import Context # type: ignore[import]
from opentelemetry.trace import Span # type: ignore[import]
except ImportError:
Context = Any # type: ignore[assignment, misc]
Span = Any # type: ignore[assignment, misc]
from langsmith import utils as ls_utils
from langsmith._internal import _orjson
from langsmith._internal._operations import (
SerializedRunOperation,
)
from langsmith._internal._otel_utils import (
get_otel_span_id_from_uuid,
get_otel_trace_id_from_uuid,
)
def _import_otel_exporter():
"""Dynamically import OTEL exporter modules when needed."""
try:
from opentelemetry import trace # type: ignore[import]
from opentelemetry.context.context import Context # type: ignore[import]
from opentelemetry.trace import ( # type: ignore[import]
NonRecordingSpan,
Span,
SpanContext,
TraceFlags,
TraceState,
set_span_in_context,
)
return (
trace,
Context,
NonRecordingSpan,
Span,
SpanContext,
TraceFlags,
TraceState,
set_span_in_context,
)
except ImportError as e:
warnings.warn(
f"OTEL_ENABLED is set but OpenTelemetry packages are not installed: {e}"
)
return None
logger = logging.getLogger(__name__)
# OpenTelemetry GenAI semconv attribute names
# (string keys passed to ``span.set_attribute`` by the exporter below).
GEN_AI_OPERATION_NAME = "gen_ai.operation.name"
GEN_AI_SYSTEM = "gen_ai.system"
GEN_AI_REQUEST_MODEL = "gen_ai.request.model"
GEN_AI_RESPONSE_MODEL = "gen_ai.response.model"
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens"
GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature"
GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p"
GEN_AI_REQUEST_FREQUENCY_PENALTY = "gen_ai.request.frequency_penalty"
GEN_AI_REQUEST_PRESENCE_PENALTY = "gen_ai.request.presence_penalty"
GEN_AI_RESPONSE_FINISH_REASONS = "gen_ai.response.finish_reasons"
# Raw serialized inputs / outputs of a run.
GENAI_PROMPT = "gen_ai.prompt"
GENAI_COMPLETION = "gen_ai.completion"
GEN_AI_REQUEST_EXTRA_QUERY = "gen_ai.request.extra_query"
GEN_AI_REQUEST_EXTRA_BODY = "gen_ai.request.extra_body"
# Fields copied from a run's ``serialized`` dict.
GEN_AI_SERIALIZED_NAME = "gen_ai.serialized.name"
GEN_AI_SERIALIZED_SIGNATURE = "gen_ai.serialized.signature"
GEN_AI_SERIALIZED_DOC = "gen_ai.serialized.doc"
GEN_AI_RESPONSE_ID = "gen_ai.response.id"
GEN_AI_RESPONSE_SERVICE_TIER = "gen_ai.response.service_tier"
GEN_AI_RESPONSE_SYSTEM_FINGERPRINT = "gen_ai.response.system_fingerprint"
GEN_AI_USAGE_INPUT_TOKEN_DETAILS = "gen_ai.usage.input_token_details"
GEN_AI_USAGE_OUTPUT_TOKEN_DETAILS = "gen_ai.usage.output_token_details"
# LangSmith custom attributes
LANGSMITH_SESSION_ID = "langsmith.trace.session_id"
LANGSMITH_SESSION_NAME = "langsmith.trace.session_name"
LANGSMITH_RUN_TYPE = "langsmith.span.kind"
LANGSMITH_NAME = "langsmith.trace.name"
# Used as a prefix: each metadata entry is exported as "<prefix>.<key>".
LANGSMITH_METADATA = "langsmith.metadata"
LANGSMITH_TAGS = "langsmith.span.tags"
LANGSMITH_RUNTIME = "langsmith.span.runtime"
LANGSMITH_REQUEST_STREAMING = "langsmith.request.streaming"
LANGSMITH_REQUEST_HEADERS = "langsmith.request.headers"
# GenAI event names
GEN_AI_SYSTEM_MESSAGE = "gen_ai.system.message"
GEN_AI_USER_MESSAGE = "gen_ai.user.message"
GEN_AI_ASSISTANT_MESSAGE = "gen_ai.assistant.message"
GEN_AI_CHOICE = "gen_ai.choice"
WELL_KNOWN_OPERATION_NAMES = {
"llm": "chat",
"tool": "execute_tool",
"retriever": "embeddings",
"embedding": "embeddings",
"prompt": "chat",
}
def _get_operation_name(run_type: str) -> str:
return WELL_KNOWN_OPERATION_NAMES.get(run_type, run_type)
class OTELExporter:
    """OpenTelemetry exporter for LangSmith runs."""

    # The docstring must be the first statement in the class body to become
    # ``__doc__``; placed after ``__slots__`` it is a discarded expression.
    __slots__ = [
        "_tracer",
        "_span_info",
        "_otel_available",
        "_trace",
        "_span_ttl_seconds",
        "_last_cleanup",
    ]
def __init__(self, tracer_provider=None, span_ttl_seconds=None):
    """Initialize the OTEL exporter.

    Args:
        tracer_provider: Optional tracer provider to use. If not provided,
            the global tracer provider will be used.
        span_ttl_seconds: TTL for incomplete traces in seconds. If None,
            the value is read through
            ``ls_utils.get_env_var("OTEL_SPAN_TTL_SECONDS")`` (default: 3600s).
    """
    if span_ttl_seconds is None:
        span_ttl_seconds = int(
            ls_utils.get_env_var("OTEL_SPAN_TTL_SECONDS", default="3600")
        )

    otel_imports = _import_otel_exporter()
    if otel_imports is None:
        # OpenTelemetry is unavailable: keep the exporter inert.
        trace_module = None
        tracer = None
    else:
        trace_module = otel_imports[0]
        tracer = trace_module.get_tracer(
            "langsmith", tracer_provider=tracer_provider
        )

    self._tracer = tracer
    # run id -> {"span": <span>, "created_at": <epoch seconds>}
    self._span_info = {}
    self._otel_available = otel_imports is not None
    self._trace = trace_module
    self._span_ttl_seconds = span_ttl_seconds
    self._last_cleanup = 0.0
def export_batch(
    self,
    operations: list[SerializedRunOperation],
    otel_context_map: dict[uuid.UUID, Optional[Context]],
) -> None:
    """Export a batch of serialized run operations to OTEL.

    "post" operations create a span, which is then tracked under the
    operation's id; every other operation updates the tracked span. A failure
    on one operation is logged and does not stop the rest of the batch.

    Args:
        operations: List of serialized run operations to export.
        otel_context_map: Per-run OTEL context, looked up by operation id
            when a span is created.
    """
    # Proactive cleanup of expired spans before new operations
    self._cleanup_stale_spans()

    for op in operations:
        try:
            run_info = self._deserialize_run_info(op)
            if not run_info:
                continue

            if op.operation != "post":
                self._update_span_for_run(op, run_info)
                continue

            span = self._create_span_for_run(
                op, run_info, otel_context_map.get(op.id)
            )
            if span:
                self._span_info[op.id] = {
                    "span": span,
                    "created_at": time.time(),
                }
        except Exception as e:
            logger.exception(f"Error processing operation {op.id}: {e}")
def _deserialize_run_info(self, op: SerializedRunOperation) -> Optional[dict]:
"""Deserialize the run info from the operation.
Args:
op: The serialized run operation.
Returns:
The deserialized run info as a dictionary, or None if deserialization
failed.
"""
try:
return op.deserialize_run_info()
except Exception as e:
logger.exception(f"Failed to deserialize run info for {op.id}: {e}")
return None
def _create_span_for_run(
    self,
    op: SerializedRunOperation,
    run_info: dict,
    otel_context: Optional[Context] = None,
) -> Optional[Span]:
    """Create an OpenTelemetry span for a run operation.

    The span is parented, in order of preference, on:
      1. the tracked span of ``run_info["parent_run_id"]``, when present;
      2. ``otel_context``, when provided;
      3. a context built from IDs derived deterministically from
         ``op.trace_id`` and ``op.id``.

    Attributes and status are set immediately, and the span is ended right
    away when the run already carries an ``end_time``.

    Args:
        op: The serialized run operation.
        run_info: The deserialized run info.
        otel_context: Optional OTEL context to inherit from when the run has
            no tracked parent span.

    Returns:
        The created span, or None if OpenTelemetry is unavailable or span
        creation failed (failures are logged).
    """
    try:
        start_time_utc_nano = self._as_utc_nano(run_info.get("start_time"))

        # Create deterministic trace and span IDs to match user OpenTelemetry spans
        trace_id_int = get_otel_trace_id_from_uuid(op.trace_id)
        span_id_int = get_otel_span_id_from_uuid(op.id)

        # Get OTEL imports for this operation
        otel_imports = _import_otel_exporter()
        if otel_imports is None:
            return None
        (
            trace,
            _context_cls,
            NonRecordingSpan,
            _span_cls,
            SpanContext,
            TraceFlags,
            TraceState,
            set_span_in_context,
        ) = otel_imports

        # Context carrying the deterministic IDs, wrapped in a non-recording
        # span so it can act as a parent context.
        span_context = SpanContext(
            trace_id=trace_id_int,
            span_id=span_id_int,
            is_remote=False,
            trace_flags=TraceFlags(TraceFlags.SAMPLED),
            trace_state=TraceState(),
        )
        deterministic_context = set_span_in_context(NonRecordingSpan(span_context))

        # Pick the context the new span starts in. The parent id is parsed
        # once and looked up in the tracked spans.
        parent_run_id = run_info.get("parent_run_id")
        parent_entry = (
            self._span_info.get(uuid.UUID(parent_run_id))
            if parent_run_id is not None
            else None
        )
        if parent_entry is not None:
            start_context = set_span_in_context(parent_entry["span"])
        else:
            # Root span (or unknown parent): inherit an existing OTEL context
            # if one was supplied, otherwise use the deterministic one.
            start_context = otel_context if otel_context else deterministic_context

        span = self._tracer.start_span(
            run_info.get("name"),
            context=start_context,
            start_time=start_time_utc_nano,
        )

        # Set all attributes
        self._set_span_attributes(span, run_info, op)

        # Set status based on error
        error = run_info.get("error")
        if error:
            span.set_status(trace.StatusCode.ERROR)
            span.record_exception(Exception(error))
        else:
            span.set_status(trace.StatusCode.OK)

        # End the span if end_time is present
        end_time = run_info.get("end_time")
        if end_time:
            end_time_utc_nano = self._as_utc_nano(end_time)
            if end_time_utc_nano:
                span.end(end_time=end_time_utc_nano)
            else:
                # Timestamp could not be converted: end the span without an
                # explicit end time.
                span.end()

        return span
    except Exception as e:
        logger.exception(f"Failed to create span for run {op.id}: {e}")
        return None
def _update_span_for_run(self, op: SerializedRunOperation, run_info: dict) -> None:
    """Apply a follow-up operation to the span tracked for ``op.id``.

    Attributes and status are refreshed; when the run now has an
    ``end_time`` the span is ended and dropped from tracking. Unknown run ids
    are ignored, and any failure is logged rather than raised.

    Args:
        op: The serialized run operation.
        run_info: The deserialized run info.
    """
    try:
        if op.id not in self._span_info:
            logger.debug(f"No span found for run {op.id} during update")
            return
        span = self._span_info[op.id]["span"]

        self._set_span_attributes(span, run_info, op)

        error = run_info.get("error")
        if error:
            span.set_status(self._trace.StatusCode.ERROR)
            span.record_exception(Exception(error))
        else:
            span.set_status(self._trace.StatusCode.OK)

        end_time = run_info.get("end_time")
        if not end_time:
            # Span exists but no end_time - this is normal for ongoing operations
            logger.debug("Updated span (no end_time yet)")
            return

        end_time_utc_nano = self._as_utc_nano(end_time)
        if end_time_utc_nano:
            span.end(end_time=end_time_utc_nano)
        else:
            span.end()

        # The run is finished: stop tracking its span.
        del self._span_info[op.id]
        logger.debug(f"Completed span, remaining spans: {len(self._span_info)}")
    except Exception as e:
        logger.exception(f"Failed to update span for run {op.id}: {e}")
def _cleanup_stale_spans(self) -> None:
"""Clean up spans older than TTL threshold."""
if not self._span_info:
return
current_time = time.time()
# Only run cleanup every 10 seconds to reduce overhead
if current_time - self._last_cleanup < 10.0:
return
self._last_cleanup = current_time
cutoff_time = current_time - self._span_ttl_seconds
# Remove spans older than TTL in one pass
stale_span_ids = [
span_id
for span_id, info in self._span_info.items()
if info["created_at"] < cutoff_time
]
if stale_span_ids:
logger.info(
f" LangSmith OTEL Cleanup: Removing {len(stale_span_ids)} stale spans"
)
for span_id in stale_span_ids:
self._remove_span(span_id)
def _remove_span(self, span_id: uuid.UUID) -> None:
"""Remove a single span and clean up resources.
Note:
We call `span.end()` here because spans in `_span_info` are orphaned -
they never received their patch operation and will never naturally complete.
Ending them gracefully is better than leaving them open indefinitely.
"""
if span_id not in self._span_info:
return
try:
# End the orphaned span gracefully
span = self._span_info[span_id]["span"]
# Check if span is still active before ending it
if (
hasattr(span, "end")
and hasattr(span, "is_recording")
and span.is_recording()
):
span.end()
logger.debug(f"Ended orphaned span {span_id}")
elif hasattr(span, "end"):
# Span already ended, just log it
logger.debug(f"Span {span_id} already ended, skipping end() call")
# Remove from tracking regardless
del self._span_info[span_id]
except Exception as e:
logger.debug(f"Error removing span {span_id}: {e}")
# Still try to remove from tracking even if ending failed
try:
del self._span_info[span_id]
except KeyError:
pass
def _extract_model_name(self, run_info: dict) -> Optional[str]:
"""Extract model name from run info.
Args:
run_info: The run info.
Returns:
The model name, or None if not found.
"""
# Try to get model name from metadata
if run_info.get("extra") and run_info["extra"].get("metadata"):
metadata = run_info["extra"]["metadata"]
# First check for ls_model_name in metadata
if metadata.get("ls_model_name"):
return metadata["ls_model_name"]
# Then check invocation_params for model info
if "invocation_params" in metadata:
invocation_params = metadata["invocation_params"]
# Check model first, then model_name
if invocation_params.get("model"):
return invocation_params["model"]
elif invocation_params.get("model_name"):
return invocation_params["model_name"]
return None
def _set_span_attributes(
    self,
    span: Span,
    run_info: dict,
    op: SerializedRunOperation,
) -> None:
    """Copy everything known about a run onto its span.

    Covers LangSmith identifiers, GenAI operation/system/model, token usage,
    invocation parameters, metadata, tags, ``serialized`` details and finally
    the inputs/outputs carried by ``op``.

    Args:
        span: The span to set attributes on.
        run_info: The deserialized run info.
        op: The serialized run operation.
    """
    # LangSmith-specific attributes: only truthy values are exported, as str.
    for field, attribute in (
        ("run_type", LANGSMITH_RUN_TYPE),
        ("name", LANGSMITH_NAME),
        ("session_id", LANGSMITH_SESSION_ID),
        ("session_name", LANGSMITH_SESSION_NAME),
    ):
        if run_info.get(field):
            span.set_attribute(attribute, str(run_info.get(field)))

    # gen_ai.operation.name is only written when the span is created.
    if op.operation == "post":
        span.set_attribute(
            GEN_AI_OPERATION_NAME,
            _get_operation_name(run_info.get("run_type", "chain")),
        )

    # gen_ai.system
    self._set_gen_ai_system(span, run_info)

    # Model name, if available
    model_name = self._extract_model_name(run_info)
    if model_name:
        span.set_attribute(GEN_AI_REQUEST_MODEL, model_name)

    # Token usage: zero is a valid count, so only None is skipped.
    for field, attribute in (
        ("prompt_tokens", GEN_AI_USAGE_INPUT_TOKENS),
        ("completion_tokens", GEN_AI_USAGE_OUTPUT_TOKENS),
        ("total_tokens", GEN_AI_USAGE_TOTAL_TOKENS),
    ):
        if run_info.get(field) is not None:
            span.set_attribute(attribute, int(run_info[field]))

    # Other parameters from invocation_params
    self._set_invocation_parameters(span, run_info)

    # Metadata entries are exported one attribute per key.
    metadata = run_info.get("extra", {}).get("metadata", {})
    for key, value in metadata.items():
        if value is not None:
            span.set_attribute(f"{LANGSMITH_METADATA}.{key}", value)

    # Tags: a list is flattened into one comma-separated string.
    tags = run_info.get("tags")
    if tags:
        tag_value = ", ".join(tags) if isinstance(tags, list) else tags
        span.set_attribute(LANGSMITH_TAGS, tag_value)

    # Additional serialized attributes, if present
    serialized = run_info.get("serialized")
    if serialized and isinstance(serialized, dict):
        for field, attribute in (
            ("name", GEN_AI_SERIALIZED_NAME),
            ("signature", GEN_AI_SERIALIZED_SIGNATURE),
            ("doc", GEN_AI_SERIALIZED_DOC),
        ):
            if serialized.get(field) is not None:
                span.set_attribute(attribute, serialized[field])

    # Inputs/outputs, if available
    self._set_io_attributes(span, op)
def _set_gen_ai_system(self, span: Span, run_info: dict) -> None:
    """Set the gen_ai.system attribute on the span based on the model provider.

    The provider is guessed from substrings of the lower-cased model name;
    the first matching rule wins and "langchain" is used when nothing
    matches or no model name is available. The chosen value is also stored
    on the span object as ``_gen_ai_system``.

    Args:
        span: The span to set attributes on.
        run_info: The deserialized run info.
    """
    # Ordered (system, predicate) rules - order matters because several
    # rules can match the same model name.
    provider_rules = (
        ("anthropic", lambda m: "anthropic" in m or m.startswith("claude")),
        ("aws.bedrock", lambda m: "bedrock" in m),
        ("az.ai.openai", lambda m: "azure" in m and "openai" in m),
        ("az.ai.inference", lambda m: "azure" in m and "inference" in m),
        ("cohere", lambda m: "cohere" in m),
        ("deepseek", lambda m: "deepseek" in m),
        ("gemini", lambda m: "gemini" in m),
        ("groq", lambda m: "groq" in m),
        ("ibm.watsonx.ai", lambda m: "watson" in m or "ibm" in m),
        ("mistral_ai", lambda m: "mistral" in m),
        ("openai", lambda m: "gpt" in m or "openai" in m),
        ("perplexity", lambda m: "perplexity" in m or "sonar" in m),
        ("vertex_ai", lambda m: "vertex" in m),
        ("xai", lambda m: "xai" in m or "grok" in m),
        ("qwen", lambda m: "qwen" in m),
    )

    # Default to "langchain" if we can't determine the system
    system = "langchain"
    model_name = self._extract_model_name(run_info)
    if model_name:
        model_lower = model_name.lower()
        system = next(
            (name for name, matches in provider_rules if matches(model_lower)),
            system,
        )

    span.set_attribute(GEN_AI_SYSTEM, system)
    setattr(span, "_gen_ai_system", system)
def _set_invocation_parameters(self, span: Span, run_info: dict) -> None:
"""Set invocation parameters on the span.
Args:
span: The span to set attributes on.
run_info: The deserialized run info.
"""
if not (run_info.get("extra") and run_info["extra"].get("metadata")):
return
metadata = run_info["extra"]["metadata"]
if "invocation_params" not in metadata:
return
invocation_params = metadata["invocation_params"]
# Set relevant invocation parameters
if "max_tokens" in invocation_params:
span.set_attribute(
GEN_AI_REQUEST_MAX_TOKENS, invocation_params["max_tokens"]
)
if "temperature" in invocation_params:
span.set_attribute(
GEN_AI_REQUEST_TEMPERATURE, invocation_params["temperature"]
)
if "top_p" in invocation_params:
span.set_attribute(GEN_AI_REQUEST_TOP_P, invocation_params["top_p"])
if "frequency_penalty" in invocation_params:
span.set_attribute(
GEN_AI_REQUEST_FREQUENCY_PENALTY, invocation_params["frequency_penalty"]
)
if "presence_penalty" in invocation_params:
span.set_attribute(
GEN_AI_REQUEST_PRESENCE_PENALTY, invocation_params["presence_penalty"]
)
def _set_io_attributes(self, span: Span, op: SerializedRunOperation) -> None:
    """Set input/output attributes on the span.

    Inputs and outputs are handled independently and on a best-effort
    basis: each side is parsed inside its own ``try`` and any failure is
    logged at debug level rather than raised. Structured fields are only
    read when the parsed payload is a dict; the raw serialized payload is
    recorded as the prompt/completion attribute whatever its JSON type.

    Args:
        span: The span to set attributes on.
        op: The serialized run operation.
    """
    if op.inputs:
        try:
            inputs = _orjson.loads(op.inputs)
            if isinstance(inputs, dict):
                # The request model is only recorded for payloads that
                # also carry a ``messages`` list.
                if (
                    isinstance(inputs.get("messages"), list)
                    and inputs.get("model") is not None
                ):
                    span.set_attribute(GEN_AI_REQUEST_MODEL, inputs["model"])

                # Set additional request attributes if available.
                request_attributes = (
                    ("stream", LANGSMITH_REQUEST_STREAMING),
                    ("extra_headers", LANGSMITH_REQUEST_HEADERS),
                    ("extra_query", GEN_AI_REQUEST_EXTRA_QUERY),
                    ("extra_body", GEN_AI_REQUEST_EXTRA_BODY),
                )
                for key, attribute in request_attributes:
                    value = inputs.get(key)
                    if value is not None:
                        span.set_attribute(attribute, value)
            span.set_attribute(GENAI_PROMPT, op.inputs)
        except Exception:
            logger.debug(
                "Failed to process inputs for run %s", op.id, exc_info=True
            )

    if op.outputs:
        try:
            outputs = _orjson.loads(op.outputs)
            # Structured fields exist only on dict payloads. Guarding here
            # keeps string/list/number outputs from raising before the raw
            # completion below is recorded.
            if isinstance(outputs, dict):
                # Extract token usage from outputs (for LLM runs)
                token_usage = self.get_unified_run_tokens(outputs)
                if token_usage:
                    input_tokens, output_tokens = token_usage
                    span.set_attribute(GEN_AI_USAGE_INPUT_TOKENS, input_tokens)
                    span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
                    span.set_attribute(
                        GEN_AI_USAGE_TOTAL_TOKENS, input_tokens + output_tokens
                    )
                if "model" in outputs:
                    span.set_attribute(GEN_AI_RESPONSE_MODEL, str(outputs["model"]))

                # Extract additional response attributes.
                if outputs.get("id") is not None:
                    span.set_attribute(GEN_AI_RESPONSE_ID, outputs["id"])

                choices = outputs.get("choices")
                if isinstance(choices, list):
                    # Non-dict entries are ignored instead of aborting the
                    # whole outputs section.
                    finish_reasons = [
                        str(choice["finish_reason"])
                        for choice in choices
                        if isinstance(choice, dict)
                        and choice.get("finish_reason") is not None
                    ]
                    if finish_reasons:
                        span.set_attribute(
                            GEN_AI_RESPONSE_FINISH_REASONS,
                            ", ".join(finish_reasons),
                        )

                if outputs.get("service_tier") is not None:
                    span.set_attribute(
                        GEN_AI_RESPONSE_SERVICE_TIER, outputs["service_tier"]
                    )
                if outputs.get("system_fingerprint") is not None:
                    span.set_attribute(
                        GEN_AI_RESPONSE_SYSTEM_FINGERPRINT,
                        outputs["system_fingerprint"],
                    )

                usage_metadata = outputs.get("usage_metadata")
                if isinstance(usage_metadata, dict):
                    detail_attributes = (
                        ("input_token_details", GEN_AI_USAGE_INPUT_TOKEN_DETAILS),
                        ("output_token_details", GEN_AI_USAGE_OUTPUT_TOKEN_DETAILS),
                    )
                    for key, attribute in detail_attributes:
                        details = usage_metadata.get(key)
                        if details is not None:
                            # Recorded as the string form of the details.
                            span.set_attribute(attribute, str(details))
            span.set_attribute(GENAI_COMPLETION, op.outputs)
        except Exception:
            logger.debug(
                "Failed to process outputs for run %s", op.id, exc_info=True
            )
def _as_utc_nano(self, timestamp: Optional[str]) -> Optional[int]:
if not timestamp:
return None
try:
dt = datetime.datetime.fromisoformat(timestamp)
return int(dt.astimezone(datetime.timezone.utc).timestamp() * 1_000_000_000)
except ValueError:
logger.exception(f"Failed to parse timestamp {timestamp}")
return None
def get_unified_run_tokens(
    self, outputs: Optional[dict]
) -> Optional[tuple[int, int]]:
    """Find input/output token counts in a run's outputs.

    A ``usage_metadata`` mapping is looked for in these places, in order,
    and the first one holding valid counts wins:

    1. ``outputs["usage_metadata"]``.
    2. Each dict-valued top-level entry: its own ``usage_metadata``, or,
       when the entry has ``lc == 1``, its ``kwargs["usage_metadata"]``.
    3. Each entry of ``outputs["generations"]`` (a flat list or a list of
       lists): ``generation["message"]["kwargs"]["usage_metadata"]``.

    Args:
        outputs: The deserialized run outputs. Anything that is empty or
            not a dict yields ``None``.

    Returns:
        ``(input_tokens, output_tokens)``, or ``None`` if no usage is found.
    """
    # Non-dict payloads (strings, lists, ...) have nothing to search.
    if not outputs or not isinstance(outputs, dict):
        return None

    extract = self._extract_unified_run_tokens

    # search in non-generations lists
    if usage := extract(outputs.get("usage_metadata")):
        return usage

    # find if direct kwarg in outputs
    for candidate in outputs.values():
        if not candidate or not isinstance(candidate, dict):
            continue
        if usage := extract(candidate.get("usage_metadata")):
            return usage
        nested_kwargs = candidate.get("kwargs")
        if (
            candidate.get("lc") == 1
            and isinstance(nested_kwargs, dict)
            and (usage := extract(nested_kwargs.get("usage_metadata")))
        ):
            return usage

    # find in generations
    generations = outputs.get("generations") or []
    if not isinstance(generations, list):
        return None
    # Normalise a flat list of generations to a list of lists.
    if generations and not isinstance(generations[0], list):
        generations = [generations]
    for batch in generations:
        # A malformed (non-list) entry is skipped instead of raising.
        if not isinstance(batch, list):
            continue
        for generation in batch:
            message = (
                generation.get("message") if isinstance(generation, dict) else None
            )
            message_kwargs = (
                message.get("kwargs") if isinstance(message, dict) else None
            )
            if isinstance(message_kwargs, dict) and (
                usage := extract(message_kwargs.get("usage_metadata"))
            ):
                return usage
    return None
def _extract_unified_run_tokens(
self, outputs: Optional[Any]
) -> Optional[tuple[int, int]]:
if not outputs or not isinstance(outputs, dict):
return None
if "input_tokens" not in outputs or "output_tokens" not in outputs:
return None
if not isinstance(outputs["input_tokens"], int) or not isinstance(
outputs["output_tokens"], int
):
return None
return outputs["input_tokens"], outputs["output_tokens"]