initial commit
This commit is contained in:
13
venv/Lib/site-packages/langsmith/wrappers/__init__.py
Normal file
13
venv/Lib/site-packages/langsmith/wrappers/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""This module provides convenient tracing wrappers for popular libraries."""
|
||||
|
||||
from langsmith.wrappers._anthropic import wrap_anthropic
|
||||
from langsmith.wrappers._gemini import wrap_gemini # BETA
|
||||
from langsmith.wrappers._openai import wrap_openai
|
||||
from langsmith.wrappers._openai_agents import OpenAIAgentsTracingProcessor
|
||||
|
||||
__all__ = [
|
||||
"wrap_anthropic",
|
||||
"wrap_gemini", # BETA
|
||||
"wrap_openai",
|
||||
"OpenAIAgentsTracingProcessor",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
577
venv/Lib/site-packages/langsmith/wrappers/_anthropic.py
Normal file
577
venv/Lib/site-packages/langsmith/wrappers/_anthropic.py
Normal file
@@ -0,0 +1,577 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import AsyncIterator, Mapping, Sequence
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from typing_extensions import Self, TypedDict
|
||||
|
||||
from langsmith import client as ls_client
|
||||
from langsmith import run_helpers
|
||||
from langsmith.schemas import InputTokenDetails, UsageMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import httpx
|
||||
from anthropic import Anthropic, AsyncAnthropic
|
||||
from anthropic.lib.streaming import AsyncMessageStream, MessageStream
|
||||
from anthropic.types import Completion, Message, MessageStreamEvent
|
||||
|
||||
C = TypeVar("C", bound=Union["Anthropic", "AsyncAnthropic", Any])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def _get_not_given() -> Optional[tuple[type, ...]]:
|
||||
try:
|
||||
from anthropic._types import NotGiven, Omit
|
||||
|
||||
return (NotGiven, Omit)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def _strip_not_given(d: dict) -> dict:
|
||||
try:
|
||||
if not_given := _get_not_given():
|
||||
d = {
|
||||
k: v
|
||||
for k, v in d.items()
|
||||
if not any(isinstance(v, t) for t in not_given)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error stripping NotGiven: {e}")
|
||||
|
||||
if "system" in d:
|
||||
d["messages"] = [{"role": "system", "content": d["system"]}] + d.get(
|
||||
"messages", []
|
||||
)
|
||||
d.pop("system")
|
||||
return {k: v for k, v in d.items() if v is not None}
|
||||
|
||||
|
||||
def _infer_ls_params(prepopulated_invocation_params: dict, kwargs: dict):
|
||||
stripped = _strip_not_given(kwargs)
|
||||
|
||||
stop = stripped.get("stop")
|
||||
if stop and isinstance(stop, str):
|
||||
stop = [stop]
|
||||
|
||||
# Allowlist of safe invocation parameters to include
|
||||
# Only include known, non-sensitive parameters
|
||||
allowed_invocation_keys = {
|
||||
"mcp_servers",
|
||||
"service_tier",
|
||||
"top_k",
|
||||
"top_p",
|
||||
"stream",
|
||||
"thinking",
|
||||
}
|
||||
|
||||
# Only include allowlisted parameters
|
||||
invocation_params = {
|
||||
k: v for k, v in stripped.items() if k in allowed_invocation_keys
|
||||
}
|
||||
|
||||
return {
|
||||
"ls_provider": "anthropic",
|
||||
"ls_model_type": "chat",
|
||||
"ls_model_name": stripped.get("model", None),
|
||||
"ls_temperature": stripped.get("temperature", None),
|
||||
"ls_max_tokens": stripped.get("max_tokens", None),
|
||||
"ls_stop": stop,
|
||||
"ls_invocation_params": {
|
||||
**prepopulated_invocation_params,
|
||||
**invocation_params,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _accumulate_event(
|
||||
*, event: MessageStreamEvent, current_snapshot: Message | None
|
||||
) -> Message | None:
|
||||
try:
|
||||
from anthropic.types import ContentBlock
|
||||
except ImportError:
|
||||
logger.debug("Error importing ContentBlock")
|
||||
return current_snapshot
|
||||
|
||||
if current_snapshot is None:
|
||||
if event.type == "message_start":
|
||||
return event.message
|
||||
|
||||
raise RuntimeError(
|
||||
f'Unexpected event order, got {event.type} before "message_start"'
|
||||
)
|
||||
|
||||
if event.type == "content_block_start":
|
||||
# TODO: check index <-- from anthropic SDK :)
|
||||
adapter: TypeAdapter = TypeAdapter(ContentBlock)
|
||||
content_block_instance = adapter.validate_python(
|
||||
event.content_block.model_dump()
|
||||
)
|
||||
current_snapshot.content.append(
|
||||
content_block_instance, # type: ignore[attr-defined]
|
||||
)
|
||||
elif event.type == "content_block_delta":
|
||||
content = current_snapshot.content[event.index]
|
||||
if content.type == "text" and event.delta.type == "text_delta":
|
||||
content.text += event.delta.text
|
||||
elif event.type == "message_delta":
|
||||
current_snapshot.stop_reason = event.delta.stop_reason
|
||||
current_snapshot.stop_sequence = event.delta.stop_sequence
|
||||
current_snapshot.usage.output_tokens = event.usage.output_tokens
|
||||
|
||||
return current_snapshot
|
||||
|
||||
|
||||
def _create_usage_metadata(anthropic_token_usage: dict) -> UsageMetadata:
|
||||
input_tokens = anthropic_token_usage.get("input_tokens") or 0
|
||||
output_tokens = anthropic_token_usage.get("output_tokens") or 0
|
||||
total_tokens = input_tokens + output_tokens
|
||||
|
||||
input_token_details: dict = {}
|
||||
cache_read = anthropic_token_usage.get("cache_read_input_tokens") or 0
|
||||
if cache_read:
|
||||
input_token_details["cache_read"] = cache_read
|
||||
|
||||
cache_creation_obj = anthropic_token_usage.get("cache_creation") or {}
|
||||
if cache_creation_obj:
|
||||
ephemeral_5m = cache_creation_obj.get("ephemeral_5m_input_tokens") or 0
|
||||
ephemeral_1h = cache_creation_obj.get("ephemeral_1h_input_tokens") or 0
|
||||
if ephemeral_5m:
|
||||
input_token_details["ephemeral_5m_input_tokens"] = ephemeral_5m
|
||||
if ephemeral_1h:
|
||||
input_token_details["ephemeral_1h_input_tokens"] = ephemeral_1h
|
||||
else:
|
||||
cache_creation = anthropic_token_usage.get("cache_creation_input_tokens") or 0
|
||||
if cache_creation:
|
||||
input_token_details["cache_creation"] = cache_creation
|
||||
|
||||
result = UsageMetadata(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
if input_token_details:
|
||||
result["input_token_details"] = InputTokenDetails(**input_token_details)
|
||||
return result
|
||||
|
||||
|
||||
def _message_to_outputs(message: Any) -> dict:
|
||||
"""Convert an Anthropic Message to a flat outputs dict with usage_metadata."""
|
||||
rdict = message.model_dump()
|
||||
anthropic_token_usage = rdict.pop("usage", None)
|
||||
if anthropic_token_usage:
|
||||
rdict["usage_metadata"] = _create_usage_metadata(anthropic_token_usage)
|
||||
rdict.pop("type", None)
|
||||
return rdict
|
||||
|
||||
|
||||
def _reduce_chat_chunks(all_chunks: Sequence) -> dict:
|
||||
full_message = None
|
||||
for chunk in all_chunks:
|
||||
try:
|
||||
full_message = _accumulate_event(event=chunk, current_snapshot=full_message)
|
||||
except RuntimeError as e:
|
||||
logger.debug(f"Error accumulating event in Anthropic Wrapper: {e}")
|
||||
return {"output": all_chunks}
|
||||
if full_message is None:
|
||||
return {"output": all_chunks}
|
||||
return _message_to_outputs(full_message)
|
||||
|
||||
|
||||
def _reduce_completions(all_chunks: list[Completion]) -> dict:
|
||||
all_content = []
|
||||
for chunk in all_chunks:
|
||||
content = chunk.completion
|
||||
if content is not None:
|
||||
all_content.append(content)
|
||||
content = "".join(all_content)
|
||||
if all_chunks:
|
||||
d = all_chunks[-1].model_dump()
|
||||
d["choices"] = [{"text": content}]
|
||||
else:
|
||||
d = {"choices": [{"text": content}]}
|
||||
|
||||
return d
|
||||
|
||||
|
||||
def _process_chat_completion(outputs: Any):
|
||||
try:
|
||||
# Check if outputs is a LegacyAPIResponse wrapper (from with_raw_response).
|
||||
# The Anthropic SDK's LegacyAPIResponse wraps the actual response object.
|
||||
# Call .parse() to extract the Message for tracing.
|
||||
# See: anthropics/anthropic-sdk-python _legacy_response.py#L102
|
||||
if hasattr(outputs, "parse") and callable(outputs.parse):
|
||||
try:
|
||||
outputs = outputs.parse()
|
||||
except Exception:
|
||||
pass
|
||||
return _message_to_outputs(outputs)
|
||||
except BaseException as e:
|
||||
logger.debug(f"Error processing chat completion: {e}")
|
||||
return {"output": outputs}
|
||||
|
||||
|
||||
def _get_wrapper(
|
||||
original_create: Callable,
|
||||
name: str,
|
||||
reduce_fn: Callable,
|
||||
prepopulated_invocation_params: dict,
|
||||
tracing_extra: TracingExtra,
|
||||
) -> Callable:
|
||||
@functools.wraps(original_create)
|
||||
def create(*args, **kwargs):
|
||||
stream = kwargs.get("stream")
|
||||
decorator = run_helpers.traceable(
|
||||
name=name,
|
||||
run_type="llm",
|
||||
reduce_fn=reduce_fn if stream else None,
|
||||
process_inputs=_strip_not_given,
|
||||
process_outputs=_process_chat_completion,
|
||||
_invocation_params_fn=functools.partial(
|
||||
_infer_ls_params, prepopulated_invocation_params
|
||||
),
|
||||
**tracing_extra,
|
||||
)
|
||||
|
||||
result = decorator(original_create)(*args, **kwargs)
|
||||
return result
|
||||
|
||||
@functools.wraps(original_create)
|
||||
async def acreate(*args, **kwargs):
|
||||
stream = kwargs.get("stream")
|
||||
decorator = run_helpers.traceable(
|
||||
name=name,
|
||||
run_type="llm",
|
||||
reduce_fn=reduce_fn if stream else None,
|
||||
process_inputs=_strip_not_given,
|
||||
process_outputs=_process_chat_completion,
|
||||
_invocation_params_fn=functools.partial(
|
||||
_infer_ls_params, prepopulated_invocation_params
|
||||
),
|
||||
**tracing_extra,
|
||||
)
|
||||
result = await decorator(original_create)(*args, **kwargs)
|
||||
return result
|
||||
|
||||
return acreate if run_helpers.is_async(original_create) else create
|
||||
|
||||
|
||||
def _get_stream_wrapper(
|
||||
original_stream: Callable,
|
||||
name: str,
|
||||
prepopulated_invocation_params: dict,
|
||||
tracing_extra: TracingExtra,
|
||||
) -> Callable:
|
||||
"""Create a wrapper for Anthropic's streaming context manager."""
|
||||
is_async = "async" in str(original_stream).lower()
|
||||
configured_traceable = run_helpers.traceable(
|
||||
name=name,
|
||||
reduce_fn=_reduce_chat_chunks,
|
||||
run_type="llm",
|
||||
process_inputs=_strip_not_given,
|
||||
_invocation_params_fn=functools.partial(
|
||||
_infer_ls_params, prepopulated_invocation_params
|
||||
),
|
||||
**tracing_extra,
|
||||
)
|
||||
configured_traceable_text = run_helpers.traceable(
|
||||
name=name,
|
||||
run_type="llm",
|
||||
process_inputs=_strip_not_given,
|
||||
process_outputs=_process_chat_completion,
|
||||
_invocation_params_fn=functools.partial(
|
||||
_infer_ls_params, prepopulated_invocation_params
|
||||
),
|
||||
**tracing_extra,
|
||||
)
|
||||
|
||||
if is_async:
|
||||
|
||||
class AsyncMessageStreamWrapper:
|
||||
def __init__(
|
||||
self,
|
||||
wrapped: AsyncMessageStream,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self._wrapped = wrapped
|
||||
self._kwargs = kwargs
|
||||
|
||||
@property
|
||||
def text_stream(self):
|
||||
@configured_traceable_text
|
||||
async def _text_stream(**_):
|
||||
async for chunk in self._wrapped.text_stream:
|
||||
yield chunk
|
||||
run_tree = run_helpers.get_current_run_tree()
|
||||
final_message = await self._wrapped.get_final_message()
|
||||
outputs = _message_to_outputs(final_message)
|
||||
run_tree.outputs = outputs
|
||||
if usage := outputs.get("usage_metadata"):
|
||||
run_tree.metadata["usage_metadata"] = usage
|
||||
|
||||
return _text_stream(**self._kwargs)
|
||||
|
||||
@property
|
||||
def response(self) -> httpx.Response:
|
||||
return self._wrapped.response
|
||||
|
||||
@property
|
||||
def request_id(self) -> str | None:
|
||||
return self._wrapped.request_id
|
||||
|
||||
async def __anext__(self) -> MessageStreamEvent:
|
||||
aiter = self.__aiter__()
|
||||
return await aiter.__anext__()
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[MessageStreamEvent]:
|
||||
@configured_traceable
|
||||
def traced_iter(**_):
|
||||
return self._wrapped.__aiter__()
|
||||
|
||||
async for chunk in traced_iter(**self._kwargs):
|
||||
yield chunk
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
await self._wrapped.__aenter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc) -> None:
|
||||
await self._wrapped.__aexit__(*exc)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._wrapped.close()
|
||||
|
||||
async def get_final_message(self) -> Message:
|
||||
return await self._wrapped.get_final_message()
|
||||
|
||||
async def get_final_text(self) -> str:
|
||||
return await self._wrapped.get_final_text()
|
||||
|
||||
async def until_done(self) -> None:
|
||||
await self._wrapped.until_done()
|
||||
|
||||
@property
|
||||
def current_message_snapshot(self) -> Message:
|
||||
return self._wrapped.current_message_snapshot
|
||||
|
||||
class AsyncMessagesStreamManagerWrapper:
|
||||
def __init__(self, **kwargs):
|
||||
self._kwargs = kwargs
|
||||
|
||||
async def __aenter__(self):
|
||||
self._manager = original_stream(**self._kwargs)
|
||||
stream = await self._manager.__aenter__()
|
||||
return AsyncMessageStreamWrapper(stream, **self._kwargs)
|
||||
|
||||
async def __aexit__(self, *exc):
|
||||
await self._manager.__aexit__(*exc)
|
||||
|
||||
return AsyncMessagesStreamManagerWrapper
|
||||
else:
|
||||
|
||||
class MessageStreamWrapper:
|
||||
def __init__(
|
||||
self,
|
||||
wrapped: MessageStream,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self._wrapped = wrapped
|
||||
self._kwargs = kwargs
|
||||
|
||||
@property
|
||||
def response(self) -> Any:
|
||||
return self._wrapped.response
|
||||
|
||||
@property
|
||||
def request_id(self) -> str | None:
|
||||
return self._wrapped.request_id # type: ignore[no-any-return]
|
||||
|
||||
@property
|
||||
def text_stream(self):
|
||||
@configured_traceable_text
|
||||
def _text_stream(**_):
|
||||
yield from self._wrapped.text_stream
|
||||
run_tree = run_helpers.get_current_run_tree()
|
||||
final_message = self._wrapped.get_final_message()
|
||||
outputs = _message_to_outputs(final_message)
|
||||
run_tree.outputs = outputs
|
||||
if usage := outputs.get("usage_metadata"):
|
||||
run_tree.metadata["usage_metadata"] = usage
|
||||
|
||||
return _text_stream(**self._kwargs)
|
||||
|
||||
def __next__(self) -> MessageStreamEvent:
|
||||
return self.__iter__().__next__()
|
||||
|
||||
def __iter__(self):
|
||||
@configured_traceable
|
||||
def traced_iter(**_):
|
||||
return self._wrapped.__iter__()
|
||||
|
||||
return traced_iter(**self._kwargs)
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
self._wrapped.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc) -> None:
|
||||
self._wrapped.__exit__(*exc)
|
||||
|
||||
def close(self) -> None:
|
||||
self._wrapped.close()
|
||||
|
||||
def get_final_message(self) -> Message:
|
||||
return self._wrapped.get_final_message()
|
||||
|
||||
def get_final_text(self) -> str:
|
||||
return self._wrapped.get_final_text()
|
||||
|
||||
def until_done(self) -> None:
|
||||
return self._wrapped.until_done()
|
||||
|
||||
@property
|
||||
def current_message_snapshot(self) -> Message:
|
||||
return self._wrapped.current_message_snapshot
|
||||
|
||||
class MessagesStreamManagerWrapper:
|
||||
def __init__(self, **kwargs):
|
||||
self._kwargs = kwargs
|
||||
|
||||
def __enter__(self):
|
||||
self._manager = original_stream(**self._kwargs)
|
||||
return MessageStreamWrapper(self._manager.__enter__(), **self._kwargs)
|
||||
|
||||
def __exit__(self, *exc):
|
||||
self._manager.__exit__(*exc)
|
||||
|
||||
return MessagesStreamManagerWrapper
|
||||
|
||||
|
||||
class TracingExtra(TypedDict, total=False):
|
||||
metadata: Optional[Mapping[str, Any]]
|
||||
tags: Optional[list[str]]
|
||||
client: Optional[ls_client.Client]
|
||||
|
||||
|
||||
def wrap_anthropic(
|
||||
client: C,
|
||||
*,
|
||||
tracing_extra: Optional[TracingExtra] = None,
|
||||
chat_name: str = "ChatAnthropic",
|
||||
completions_name: str = "Anthropic",
|
||||
) -> C:
|
||||
"""Patch the Anthropic client to make it traceable.
|
||||
|
||||
Args:
|
||||
client: The client to patch.
|
||||
tracing_extra: Extra tracing information.
|
||||
chat_name: The run name for the messages endpoint.
|
||||
completions_name: The run name for the completions endpoint.
|
||||
|
||||
Returns:
|
||||
The patched client.
|
||||
|
||||
Example:
|
||||
```python
|
||||
import anthropic
|
||||
from langsmith import wrappers
|
||||
|
||||
client = wrappers.wrap_anthropic(anthropic.Anthropic())
|
||||
|
||||
# Use Anthropic client same as you normally would:
|
||||
system = "You are a helpful assistant."
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What physics breakthroughs do you predict will happen by 2300?",
|
||||
}
|
||||
]
|
||||
completion = client.messages.create(
|
||||
model="claude-3-5-sonnet-latest",
|
||||
messages=messages,
|
||||
max_tokens=1000,
|
||||
system=system,
|
||||
)
|
||||
print(completion.content)
|
||||
|
||||
# With raw response to access headers:
|
||||
raw_response = client.messages.with_raw_response.create(
|
||||
model="claude-3-5-sonnet-latest",
|
||||
messages=messages,
|
||||
max_tokens=1000,
|
||||
system=system,
|
||||
)
|
||||
print(raw_response.headers) # Access HTTP headers
|
||||
message = raw_response.parse() # Get parsed response
|
||||
|
||||
# You can also use the streaming context manager:
|
||||
with client.messages.stream(
|
||||
model="claude-3-5-sonnet-latest",
|
||||
messages=messages,
|
||||
max_tokens=1000,
|
||||
system=system,
|
||||
) as stream:
|
||||
for text in stream.text_stream:
|
||||
print(text, end="", flush=True)
|
||||
message = stream.get_final_message()
|
||||
```
|
||||
""" # noqa: E501
|
||||
tracing_extra = tracing_extra or {}
|
||||
|
||||
# Extract ls_invocation_params from metadata
|
||||
metadata = dict(tracing_extra.get("metadata") or {})
|
||||
prepopulated_invocation_params = metadata.pop("ls_invocation_params", {})
|
||||
|
||||
# Create new tracing_extra without ls_invocation_params in metadata
|
||||
tracing_extra_rest: TracingExtra = { # type: ignore[assignment]
|
||||
k: v for k, v in tracing_extra.items() if k != "metadata"
|
||||
}
|
||||
if metadata:
|
||||
tracing_extra_rest["metadata"] = metadata # type: ignore[typeddict-item]
|
||||
|
||||
client.messages.create = _get_wrapper( # type: ignore[method-assign]
|
||||
client.messages.create,
|
||||
chat_name,
|
||||
_reduce_chat_chunks,
|
||||
prepopulated_invocation_params,
|
||||
tracing_extra_rest,
|
||||
)
|
||||
|
||||
client.messages.stream = _get_stream_wrapper( # type: ignore[method-assign]
|
||||
client.messages.stream,
|
||||
chat_name,
|
||||
prepopulated_invocation_params,
|
||||
tracing_extra_rest,
|
||||
)
|
||||
client.completions.create = _get_wrapper( # type: ignore[method-assign]
|
||||
client.completions.create,
|
||||
completions_name,
|
||||
_reduce_completions,
|
||||
prepopulated_invocation_params,
|
||||
tracing_extra_rest,
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(client, "beta")
|
||||
and hasattr(client.beta, "messages")
|
||||
and hasattr(client.beta.messages, "create")
|
||||
):
|
||||
client.beta.messages.create = _get_wrapper( # type: ignore[method-assign]
|
||||
client.beta.messages.create, # type: ignore
|
||||
chat_name,
|
||||
_reduce_chat_chunks,
|
||||
prepopulated_invocation_params,
|
||||
tracing_extra_rest,
|
||||
)
|
||||
return client
|
||||
689
venv/Lib/site-packages/langsmith/wrappers/_gemini.py
Normal file
689
venv/Lib/site-packages/langsmith/wrappers/_gemini.py
Normal file
@@ -0,0 +1,689 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langsmith import client as ls_client
|
||||
from langsmith import run_helpers
|
||||
from langsmith._internal._beta_decorator import warn_beta
|
||||
from langsmith.schemas import InputTokenDetails, OutputTokenDetails, UsageMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google import genai # type: ignore[import-untyped, attr-defined]
|
||||
|
||||
C = TypeVar("C", bound=Union["genai.Client", Any])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _strip_none(d: dict) -> dict:
|
||||
"""Remove `None` values from dictionary."""
|
||||
return {k: v for k, v in d.items() if v is not None}
|
||||
|
||||
|
||||
def _convert_config_for_tracing(kwargs: dict) -> None:
|
||||
"""Convert `GenerateContentConfig` to `dict` for LangSmith compatibility."""
|
||||
if "config" in kwargs and not isinstance(kwargs["config"], dict):
|
||||
kwargs["config"] = vars(kwargs["config"])
|
||||
|
||||
|
||||
def _process_gemini_inputs(inputs: dict) -> dict:
|
||||
r"""Process Gemini inputs to normalize them for LangSmith tracing.
|
||||
|
||||
Example:
|
||||
```txt
|
||||
{"contents": "Hello", "model": "gemini-pro"}
|
||||
→ {"messages": [{"role": "user", "content": "Hello"}], "model": "gemini-pro"}
|
||||
{"contents": [{"role": "user", "parts": [{"text": "What is AI?"}]}], "model": "gemini-pro"}
|
||||
→ {"messages": [{"role": "user", "content": "What is AI?"}], "model": "gemini-pro"}
|
||||
```
|
||||
""" # noqa: E501
|
||||
# If contents is not present or not in list format, return as-is
|
||||
contents = inputs.get("contents")
|
||||
if not contents:
|
||||
return inputs
|
||||
|
||||
# Handle string input (simple case)
|
||||
if isinstance(contents, str):
|
||||
return {
|
||||
"messages": [{"role": "user", "content": contents}],
|
||||
"model": inputs.get("model"),
|
||||
**({k: v for k, v in inputs.items() if k not in ("contents", "model")}),
|
||||
}
|
||||
|
||||
# Handle list of content objects (multimodal case)
|
||||
if isinstance(contents, list):
|
||||
# Check if it's a simple list of strings
|
||||
if all(isinstance(item, str) for item in contents):
|
||||
# Each string becomes a separate user message (matches Gemini's behavior)
|
||||
return {
|
||||
"messages": [{"role": "user", "content": item} for item in contents],
|
||||
"model": inputs.get("model"),
|
||||
**({k: v for k, v in inputs.items() if k not in ("contents", "model")}),
|
||||
}
|
||||
# Handle complex multimodal case
|
||||
messages = []
|
||||
for content in contents:
|
||||
if isinstance(content, dict):
|
||||
role = content.get("role", "user")
|
||||
parts = content.get("parts", [])
|
||||
|
||||
# Extract text and other parts
|
||||
text_parts = []
|
||||
content_parts = []
|
||||
|
||||
for part in parts:
|
||||
if isinstance(part, dict):
|
||||
# Handle text parts
|
||||
if "text" in part and part["text"]:
|
||||
text_parts.append(part["text"])
|
||||
content_parts.append({"type": "text", "text": part["text"]})
|
||||
# Handle inline data (images)
|
||||
elif "inline_data" in part:
|
||||
inline_data = part["inline_data"]
|
||||
mime_type = inline_data.get("mime_type", "image/jpeg")
|
||||
data = inline_data.get("data", b"")
|
||||
|
||||
# Convert bytes to base64 string if needed
|
||||
if isinstance(data, bytes):
|
||||
data_b64 = base64.b64encode(data).decode("utf-8")
|
||||
else:
|
||||
data_b64 = data # Already a string
|
||||
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{mime_type};base64,{data_b64}",
|
||||
"detail": "high",
|
||||
},
|
||||
}
|
||||
)
|
||||
# Handle function responses
|
||||
elif "functionResponse" in part:
|
||||
function_response = part["functionResponse"]
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "function_response",
|
||||
"function_response": {
|
||||
"name": function_response.get("name"),
|
||||
"response": function_response.get(
|
||||
"response", {}
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
# Handle function calls (for conversation history)
|
||||
elif "function_call" in part or "functionCall" in part:
|
||||
function_call = part.get("function_call") or part.get(
|
||||
"functionCall"
|
||||
)
|
||||
|
||||
if function_call is not None:
|
||||
# Normalize to dict (FunctionCall is a Pydantic model)
|
||||
if not isinstance(function_call, dict):
|
||||
function_call = function_call.to_dict()
|
||||
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "function_call",
|
||||
"function_call": {
|
||||
"id": function_call.get("id"),
|
||||
"name": function_call.get("name"),
|
||||
"arguments": function_call.get("args", {}),
|
||||
},
|
||||
}
|
||||
)
|
||||
elif isinstance(part, str):
|
||||
# Handle simple string parts
|
||||
text_parts.append(part)
|
||||
content_parts.append({"type": "text", "text": part})
|
||||
|
||||
# If only text parts, use simple string format
|
||||
if content_parts and all(
|
||||
p.get("type") == "text" for p in content_parts
|
||||
):
|
||||
message_content: Union[str, list[dict[str, Any]]] = "\n".join(
|
||||
text_parts
|
||||
)
|
||||
else:
|
||||
message_content = content_parts if content_parts else ""
|
||||
|
||||
messages.append({"role": role, "content": message_content})
|
||||
return {
|
||||
"messages": messages,
|
||||
"model": inputs.get("model"),
|
||||
**({k: v for k, v in inputs.items() if k not in ("contents", "model")}),
|
||||
}
|
||||
|
||||
# Fallback: return original inputs
|
||||
return inputs
|
||||
|
||||
|
||||
def _infer_invocation_params(
|
||||
prepopulated_invocation_params: dict, kwargs: dict
|
||||
) -> dict:
|
||||
"""Extract invocation parameters for tracing."""
|
||||
stripped = _strip_none(kwargs)
|
||||
config = stripped.get("config", {})
|
||||
|
||||
# Handle both dict config and GenerateContentConfig object
|
||||
if hasattr(config, "temperature"):
|
||||
temperature = config.temperature
|
||||
max_tokens = getattr(config, "max_output_tokens", None)
|
||||
stop = getattr(config, "stop_sequences", None)
|
||||
else:
|
||||
temperature = config.get("temperature")
|
||||
max_tokens = config.get("max_output_tokens")
|
||||
stop = config.get("stop_sequences")
|
||||
|
||||
return {
|
||||
"ls_provider": "google",
|
||||
"ls_model_type": "chat",
|
||||
"ls_model_name": stripped.get("model"),
|
||||
"ls_temperature": temperature,
|
||||
"ls_max_tokens": max_tokens,
|
||||
"ls_stop": stop,
|
||||
"ls_invocation_params": prepopulated_invocation_params,
|
||||
}
|
||||
|
||||
|
||||
def _create_usage_metadata(gemini_usage_metadata: dict) -> UsageMetadata:
|
||||
"""Convert Gemini usage metadata to LangSmith format."""
|
||||
prompt_token_count = gemini_usage_metadata.get("prompt_token_count") or 0
|
||||
candidates_token_count = gemini_usage_metadata.get("candidates_token_count") or 0
|
||||
cached_content_token_count = (
|
||||
gemini_usage_metadata.get("cached_content_token_count") or 0
|
||||
)
|
||||
thoughts_token_count = gemini_usage_metadata.get("thoughts_token_count") or 0
|
||||
total_token_count = (
|
||||
gemini_usage_metadata.get("total_token_count")
|
||||
or prompt_token_count + candidates_token_count
|
||||
)
|
||||
|
||||
input_token_details: dict = {}
|
||||
if cached_content_token_count:
|
||||
input_token_details["cache_read"] = cached_content_token_count
|
||||
input_token_details["cache_read_over_200k"] = max(
|
||||
0, cached_content_token_count - 200000
|
||||
)
|
||||
input_token_details["over_200k"] = max(0, prompt_token_count - 200000)
|
||||
|
||||
output_token_details: dict = {}
|
||||
if thoughts_token_count:
|
||||
output_token_details["reasoning"] = thoughts_token_count
|
||||
|
||||
if candidates_token_count:
|
||||
output_token_details["over_200k"] = max(0, candidates_token_count - 200000)
|
||||
|
||||
return UsageMetadata(
|
||||
input_tokens=prompt_token_count,
|
||||
output_tokens=candidates_token_count,
|
||||
total_tokens=total_token_count,
|
||||
input_token_details=InputTokenDetails(
|
||||
**{k: v for k, v in input_token_details.items() if v is not None}
|
||||
),
|
||||
output_token_details=OutputTokenDetails(
|
||||
**{k: v for k, v in output_token_details.items() if v is not None}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _process_generate_content_response(response: Any) -> dict:
|
||||
"""Process Gemini response for tracing."""
|
||||
try:
|
||||
# Convert response to dictionary
|
||||
if hasattr(response, "to_dict"):
|
||||
rdict = response.to_dict()
|
||||
elif hasattr(response, "model_dump"):
|
||||
rdict = response.model_dump()
|
||||
else:
|
||||
rdict = {"text": getattr(response, "text", str(response))}
|
||||
|
||||
# Extract content from candidates if available
|
||||
content_result = ""
|
||||
content_parts = []
|
||||
finish_reason: Optional[str] = None
|
||||
if "candidates" in rdict and rdict["candidates"]:
|
||||
candidate = rdict["candidates"][0]
|
||||
if "content" in candidate:
|
||||
content = candidate["content"]
|
||||
if "parts" in content and content["parts"]:
|
||||
for part in content["parts"]:
|
||||
# Handle text parts
|
||||
if "text" in part and part["text"]:
|
||||
content_result += part["text"]
|
||||
content_parts.append({"type": "text", "text": part["text"]})
|
||||
# Handle inline data (images) in response
|
||||
elif "inline_data" in part and part["inline_data"] is not None:
|
||||
inline_data = part["inline_data"]
|
||||
mime_type = inline_data.get("mime_type", "image/jpeg")
|
||||
data = inline_data.get("data", b"")
|
||||
|
||||
# Convert bytes to base64 string if needed
|
||||
if isinstance(data, bytes):
|
||||
data_b64 = base64.b64encode(data).decode("utf-8")
|
||||
else:
|
||||
data_b64 = data # Already a string
|
||||
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{mime_type};base64,{data_b64}",
|
||||
"detail": "high",
|
||||
},
|
||||
}
|
||||
)
|
||||
# Handle function calls in response
|
||||
elif "function_call" in part or "functionCall" in part:
|
||||
function_call = part.get("function_call") or part.get(
|
||||
"functionCall"
|
||||
)
|
||||
|
||||
if function_call is not None:
|
||||
# Normalize to dict (FunctionCall is a Pydantic model)
|
||||
if not isinstance(function_call, dict):
|
||||
function_call = function_call.to_dict()
|
||||
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "function_call",
|
||||
"function_call": {
|
||||
"id": function_call.get("id"),
|
||||
"name": function_call.get("name"),
|
||||
"arguments": function_call.get("args", {}),
|
||||
},
|
||||
}
|
||||
)
|
||||
if "finish_reason" in candidate and candidate["finish_reason"]:
|
||||
finish_reason = candidate["finish_reason"]
|
||||
elif "text" in rdict:
|
||||
content_result = rdict["text"]
|
||||
content_parts.append({"type": "text", "text": content_result})
|
||||
|
||||
# Build chat-like response format - use OpenAI-compatible format for tool calls
|
||||
tool_calls = [p for p in content_parts if p.get("type") == "function_call"]
|
||||
if tool_calls:
|
||||
# OpenAI-compatible format for LangSmith UI
|
||||
result = {
|
||||
"content": content_result or None,
|
||||
"role": "assistant",
|
||||
"finish_reason": finish_reason,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc["function_call"].get("id") or f"call_{i}",
|
||||
"type": "function",
|
||||
"index": i,
|
||||
"function": {
|
||||
"name": tc["function_call"]["name"],
|
||||
"arguments": json.dumps(tc["function_call"]["arguments"]),
|
||||
},
|
||||
}
|
||||
for i, tc in enumerate(tool_calls)
|
||||
],
|
||||
}
|
||||
elif len(content_parts) > 1 or (
|
||||
content_parts and content_parts[0]["type"] != "text"
|
||||
):
|
||||
# Use structured format for mixed non-tool content
|
||||
result = {
|
||||
"content": content_parts,
|
||||
"role": "assistant",
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
else:
|
||||
# Use simple string format for text-only responses
|
||||
result = {
|
||||
"content": content_result,
|
||||
"role": "assistant",
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
|
||||
# Extract and convert usage metadata
|
||||
usage_metadata = rdict.get("usage_metadata")
|
||||
usage_dict: UsageMetadata = UsageMetadata(
|
||||
input_tokens=0, output_tokens=0, total_tokens=0
|
||||
)
|
||||
if usage_metadata:
|
||||
usage_dict = _create_usage_metadata(usage_metadata)
|
||||
|
||||
# Return in a format that avoids stringification by LangSmith
|
||||
if result.get("tool_calls"):
|
||||
# For responses with tool calls, return structured format
|
||||
return {
|
||||
"content": result["content"],
|
||||
"role": "assistant",
|
||||
"finish_reason": finish_reason,
|
||||
"tool_calls": result["tool_calls"],
|
||||
"usage_metadata": usage_dict,
|
||||
}
|
||||
else:
|
||||
# For simple text responses, return minimal structure with usage metadata
|
||||
if isinstance(result["content"], str):
|
||||
return {
|
||||
"content": result["content"],
|
||||
"role": "assistant",
|
||||
"finish_reason": finish_reason,
|
||||
"usage_metadata": usage_dict,
|
||||
}
|
||||
else:
|
||||
# For multimodal content, return structured format with usage metadata
|
||||
return {
|
||||
"content": result["content"],
|
||||
"role": "assistant",
|
||||
"finish_reason": finish_reason,
|
||||
"usage_metadata": usage_dict,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug(f"Error processing Gemini response: {e}")
|
||||
return {"output": response}
|
||||
|
||||
|
||||
def _reduce_generate_content_chunks(all_chunks: list) -> dict:
|
||||
"""Reduce streaming chunks into a single response."""
|
||||
if not all_chunks:
|
||||
return {
|
||||
"content": "",
|
||||
"usage_metadata": UsageMetadata(
|
||||
input_tokens=0, output_tokens=0, total_tokens=0
|
||||
),
|
||||
}
|
||||
|
||||
# Accumulate text from all chunks
|
||||
full_text = ""
|
||||
last_chunk = None
|
||||
|
||||
for chunk in all_chunks:
|
||||
try:
|
||||
if hasattr(chunk, "text") and chunk.text:
|
||||
full_text += chunk.text
|
||||
last_chunk = chunk
|
||||
except Exception as e:
|
||||
logger.debug(f"Error processing chunk: {e}")
|
||||
|
||||
# Extract usage metadata from the last chunk
|
||||
usage_metadata: UsageMetadata = UsageMetadata(
|
||||
input_tokens=0, output_tokens=0, total_tokens=0
|
||||
)
|
||||
if last_chunk:
|
||||
try:
|
||||
if hasattr(last_chunk, "usage_metadata") and last_chunk.usage_metadata:
|
||||
if hasattr(last_chunk.usage_metadata, "to_dict"):
|
||||
usage_dict = last_chunk.usage_metadata.to_dict()
|
||||
elif hasattr(last_chunk.usage_metadata, "model_dump"):
|
||||
usage_dict = last_chunk.usage_metadata.model_dump()
|
||||
else:
|
||||
usage_dict = {
|
||||
"prompt_token_count": getattr(
|
||||
last_chunk.usage_metadata, "prompt_token_count", 0
|
||||
),
|
||||
"candidates_token_count": getattr(
|
||||
last_chunk.usage_metadata, "candidates_token_count", 0
|
||||
),
|
||||
"cached_content_token_count": getattr(
|
||||
last_chunk.usage_metadata, "cached_content_token_count", 0
|
||||
),
|
||||
"thoughts_token_count": getattr(
|
||||
last_chunk.usage_metadata, "thoughts_token_count", 0
|
||||
),
|
||||
"total_token_count": getattr(
|
||||
last_chunk.usage_metadata, "total_token_count", 0
|
||||
),
|
||||
}
|
||||
# Add usage_metadata to both run.extra AND outputs
|
||||
usage_metadata = _create_usage_metadata(usage_dict)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error extracting metadata from last chunk: {e}")
|
||||
|
||||
# Return minimal structure with usage_metadata in outputs
|
||||
return {
|
||||
"content": full_text,
|
||||
"usage_metadata": usage_metadata,
|
||||
}
|
||||
|
||||
|
||||
def _get_wrapper(
|
||||
original_generate: Callable,
|
||||
name: str,
|
||||
prepopulated_invocation_params: dict,
|
||||
tracing_extra: Optional[TracingExtra] = None,
|
||||
is_streaming: bool = False,
|
||||
) -> Callable:
|
||||
"""Create a wrapper for Gemini's `generate_content` methods."""
|
||||
textra = tracing_extra or {}
|
||||
|
||||
@functools.wraps(original_generate)
|
||||
def generate(*args, **kwargs):
|
||||
# Handle config object before tracing setup
|
||||
_convert_config_for_tracing(kwargs)
|
||||
|
||||
decorator = run_helpers.traceable(
|
||||
name=name,
|
||||
run_type="llm",
|
||||
reduce_fn=_reduce_generate_content_chunks if is_streaming else None,
|
||||
process_inputs=_process_gemini_inputs,
|
||||
process_outputs=(
|
||||
_process_generate_content_response if not is_streaming else None
|
||||
),
|
||||
_invocation_params_fn=functools.partial(
|
||||
_infer_invocation_params, prepopulated_invocation_params
|
||||
),
|
||||
**textra,
|
||||
)
|
||||
|
||||
return decorator(original_generate)(*args, **kwargs)
|
||||
|
||||
@functools.wraps(original_generate)
|
||||
async def agenerate(*args, **kwargs):
|
||||
# Handle config object before tracing setup
|
||||
_convert_config_for_tracing(kwargs)
|
||||
|
||||
decorator = run_helpers.traceable(
|
||||
name=name,
|
||||
run_type="llm",
|
||||
reduce_fn=_reduce_generate_content_chunks if is_streaming else None,
|
||||
process_inputs=_process_gemini_inputs,
|
||||
process_outputs=(
|
||||
_process_generate_content_response if not is_streaming else None
|
||||
),
|
||||
_invocation_params_fn=functools.partial(
|
||||
_infer_invocation_params, prepopulated_invocation_params
|
||||
),
|
||||
**textra,
|
||||
)
|
||||
|
||||
return await decorator(original_generate)(*args, **kwargs)
|
||||
|
||||
return agenerate if run_helpers.is_async(original_generate) else generate
|
||||
|
||||
|
||||
class TracingExtra(TypedDict, total=False):
|
||||
metadata: Optional[Mapping[str, Any]]
|
||||
tags: Optional[list[str]]
|
||||
client: Optional[ls_client.Client]
|
||||
|
||||
|
||||
@warn_beta
|
||||
def wrap_gemini(
|
||||
client: C,
|
||||
*,
|
||||
tracing_extra: Optional[TracingExtra] = None,
|
||||
chat_name: str = "ChatGoogleGenerativeAI",
|
||||
) -> C:
|
||||
"""Patch the Google Gen AI client to make it traceable.
|
||||
|
||||
!!! warning
|
||||
|
||||
**BETA**: This wrapper is in beta.
|
||||
|
||||
Supports:
|
||||
- `generate_content` and `generate_content_stream` methods
|
||||
- Sync and async clients
|
||||
- Streaming and non-streaming responses
|
||||
- Tool/function calling with proper UI rendering
|
||||
- Multimodal inputs (text + images)
|
||||
- Image generation with `inline_data` support
|
||||
- Token usage tracking including reasoning tokens
|
||||
|
||||
Args:
|
||||
client: The Google Gen AI client to patch.
|
||||
tracing_extra: Extra tracing information.
|
||||
chat_name: The run name for the chat endpoint.
|
||||
|
||||
Returns:
|
||||
The patched client.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from langsmith import wrappers
|
||||
|
||||
# Use Google Gen AI client same as you normally would.
|
||||
client = wrappers.wrap_gemini(genai.Client(api_key="your-api-key"))
|
||||
|
||||
# Basic text generation:
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.5-flash",
|
||||
contents="Why is the sky blue?",
|
||||
)
|
||||
print(response.text)
|
||||
|
||||
# Streaming:
|
||||
for chunk in client.models.generate_content_stream(
|
||||
model="gemini-2.5-flash",
|
||||
contents="Tell me a story",
|
||||
):
|
||||
print(chunk.text, end="")
|
||||
|
||||
# Tool/Function calling:
|
||||
schedule_meeting_function = {
|
||||
"name": "schedule_meeting",
|
||||
"description": "Schedules a meeting with specified attendees.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"attendees": {"type": "array", "items": {"type": "string"}},
|
||||
"date": {"type": "string"},
|
||||
"time": {"type": "string"},
|
||||
"topic": {"type": "string"},
|
||||
},
|
||||
"required": ["attendees", "date", "time", "topic"],
|
||||
},
|
||||
}
|
||||
|
||||
tools = types.Tool(function_declarations=[schedule_meeting_function])
|
||||
config = types.GenerateContentConfig(tools=[tools])
|
||||
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.5-flash",
|
||||
contents="Schedule a meeting with Bob and Alice tomorrow at 2 PM.",
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Image generation:
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.5-flash-image",
|
||||
contents=["Create a picture of a futuristic city"],
|
||||
)
|
||||
|
||||
# Save generated image
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part.inline_data is not None:
|
||||
image = Image.open(BytesIO(part.inline_data.data))
|
||||
image.save("generated_image.png")
|
||||
```
|
||||
|
||||
!!! version-added "Added in `langsmith` 0.4.33"
|
||||
|
||||
Initial beta release of Google Gemini wrapper.
|
||||
|
||||
"""
|
||||
tracing_extra = tracing_extra or {}
|
||||
|
||||
# Extract ls_invocation_params from metadata
|
||||
metadata = dict(tracing_extra.get("metadata") or {})
|
||||
prepopulated_invocation_params = metadata.pop("ls_invocation_params", {})
|
||||
|
||||
# Create new tracing_extra without ls_invocation_params in metadata
|
||||
tracing_extra_rest: TracingExtra = { # type: ignore[assignment]
|
||||
k: v for k, v in tracing_extra.items() if k != "metadata"
|
||||
}
|
||||
if metadata:
|
||||
tracing_extra_rest["metadata"] = metadata # type: ignore[typeddict-item]
|
||||
|
||||
# Check if already wrapped to prevent double-wrapping
|
||||
if (
|
||||
hasattr(client, "models")
|
||||
and hasattr(client.models, "generate_content")
|
||||
and hasattr(client.models.generate_content, "__wrapped__")
|
||||
):
|
||||
raise ValueError(
|
||||
"This Google Gen AI client has already been wrapped. "
|
||||
"Wrapping a client multiple times is not supported."
|
||||
)
|
||||
|
||||
# Wrap synchronous methods
|
||||
if hasattr(client, "models") and hasattr(client.models, "generate_content"):
|
||||
client.models.generate_content = _get_wrapper( # type: ignore[method-assign]
|
||||
client.models.generate_content,
|
||||
chat_name,
|
||||
prepopulated_invocation_params,
|
||||
tracing_extra=tracing_extra_rest,
|
||||
is_streaming=False,
|
||||
)
|
||||
|
||||
if hasattr(client, "models") and hasattr(client.models, "generate_content_stream"):
|
||||
client.models.generate_content_stream = _get_wrapper( # type: ignore[method-assign]
|
||||
client.models.generate_content_stream,
|
||||
chat_name,
|
||||
prepopulated_invocation_params,
|
||||
tracing_extra=tracing_extra_rest,
|
||||
is_streaming=True,
|
||||
)
|
||||
|
||||
# Wrap async methods (aio namespace)
|
||||
if (
|
||||
hasattr(client, "aio")
|
||||
and hasattr(client.aio, "models")
|
||||
and hasattr(client.aio.models, "generate_content")
|
||||
):
|
||||
client.aio.models.generate_content = _get_wrapper( # type: ignore[method-assign]
|
||||
client.aio.models.generate_content,
|
||||
chat_name,
|
||||
prepopulated_invocation_params,
|
||||
tracing_extra=tracing_extra_rest,
|
||||
is_streaming=False,
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(client, "aio")
|
||||
and hasattr(client.aio, "models")
|
||||
and hasattr(client.aio.models, "generate_content_stream")
|
||||
):
|
||||
client.aio.models.generate_content_stream = _get_wrapper( # type: ignore[method-assign]
|
||||
client.aio.models.generate_content_stream,
|
||||
chat_name,
|
||||
prepopulated_invocation_params,
|
||||
tracing_extra=tracing_extra_rest,
|
||||
is_streaming=True,
|
||||
)
|
||||
|
||||
return client
|
||||
648
venv/Lib/site-packages/langsmith/wrappers/_openai.py
Normal file
648
venv/Lib/site-packages/langsmith/wrappers/_openai.py
Normal file
@@ -0,0 +1,648 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langsmith import client as ls_client
|
||||
from langsmith import run_helpers
|
||||
from langsmith.schemas import InputTokenDetails, OutputTokenDetails, UsageMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk,
|
||||
Choice,
|
||||
ChoiceDeltaToolCall,
|
||||
)
|
||||
from openai.types.completion import Completion
|
||||
from openai.types.responses import ResponseStreamEvent # type: ignore
|
||||
|
||||
# Any is used since it may work with Azure or other providers
|
||||
C = TypeVar("C", bound=Union["OpenAI", "AsyncOpenAI", Any])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def _get_omit_types() -> tuple[type, ...]:
|
||||
"""Get NotGiven/Omit sentinel types used by OpenAI SDK."""
|
||||
types: list[type[Any]] = []
|
||||
try:
|
||||
from openai._types import NotGiven, Omit
|
||||
|
||||
types.append(NotGiven)
|
||||
types.append(Omit)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return tuple(types)
|
||||
|
||||
|
||||
def _strip_not_given(d: dict) -> dict:
|
||||
try:
|
||||
omit_types = _get_omit_types()
|
||||
if not omit_types:
|
||||
return d
|
||||
return {
|
||||
k: v
|
||||
for k, v in d.items()
|
||||
if not (isinstance(v, omit_types) or (k.startswith("extra_") and v is None))
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error stripping NotGiven: {e}")
|
||||
return d
|
||||
|
||||
|
||||
def _process_inputs(d: dict) -> dict:
|
||||
"""Strip `NotGiven` values and serialize `text_format` to JSON schema."""
|
||||
d = _strip_not_given(d)
|
||||
|
||||
# Convert text_format (Pydantic model) to JSON schema if present
|
||||
if "text_format" in d:
|
||||
text_format = d["text_format"]
|
||||
if hasattr(text_format, "model_json_schema"):
|
||||
try:
|
||||
return {
|
||||
**d,
|
||||
"text_format": text_format.model_json_schema(),
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
return d
|
||||
|
||||
|
||||
def _infer_invocation_params(
|
||||
model_type: str,
|
||||
provider: str,
|
||||
prepopulated_invocation_params: dict,
|
||||
use_responses_api: bool,
|
||||
kwargs: dict,
|
||||
):
|
||||
stripped = _strip_not_given(kwargs)
|
||||
|
||||
stop = stripped.get("stop")
|
||||
if stop and isinstance(stop, str):
|
||||
stop = [stop]
|
||||
|
||||
# Allowlist of safe invocation parameters to include
|
||||
# Only include known, non-sensitive parameters
|
||||
allowed_invocation_keys = {
|
||||
"frequency_penalty",
|
||||
"n",
|
||||
"logit_bias",
|
||||
"logprobs",
|
||||
"modalities",
|
||||
"parallel_tool_calls",
|
||||
"prediction",
|
||||
"presence_penalty",
|
||||
"prompt_cache_key",
|
||||
"reasoning",
|
||||
"reasoning_effort",
|
||||
"response_format",
|
||||
"seed",
|
||||
"service_tier",
|
||||
"stream_options",
|
||||
"top_logprobs",
|
||||
"top_p",
|
||||
"truncation",
|
||||
"user",
|
||||
"verbosity",
|
||||
"web_search_options",
|
||||
}
|
||||
|
||||
# Only include allowlisted parameters
|
||||
invocation_params = {
|
||||
k: v for k, v in stripped.items() if k in allowed_invocation_keys
|
||||
}
|
||||
|
||||
if use_responses_api:
|
||||
invocation_params["use_responses_api"] = True
|
||||
|
||||
return {
|
||||
"ls_provider": provider,
|
||||
"ls_model_type": model_type,
|
||||
"ls_model_name": stripped.get("model"),
|
||||
"ls_temperature": stripped.get("temperature"),
|
||||
"ls_max_tokens": stripped.get("max_tokens")
|
||||
or stripped.get("max_completion_tokens")
|
||||
or stripped.get("max_output_tokens"),
|
||||
"ls_stop": stop,
|
||||
"ls_invocation_params": {
|
||||
**prepopulated_invocation_params,
|
||||
**invocation_params,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _reduce_choices(choices: list[Choice]) -> dict:
|
||||
reversed_choices = list(reversed(choices))
|
||||
message: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
}
|
||||
for c in reversed_choices:
|
||||
if hasattr(c, "delta") and getattr(c.delta, "role", None):
|
||||
message["role"] = c.delta.role
|
||||
break
|
||||
tool_calls: defaultdict[int, list[ChoiceDeltaToolCall]] = defaultdict(list)
|
||||
for c in choices:
|
||||
if hasattr(c, "delta"):
|
||||
if getattr(c.delta, "content", None):
|
||||
message["content"] += c.delta.content
|
||||
if getattr(c.delta, "function_call", None):
|
||||
if not message.get("function_call"):
|
||||
message["function_call"] = {"name": "", "arguments": ""}
|
||||
name_ = getattr(c.delta.function_call, "name", None)
|
||||
if name_:
|
||||
message["function_call"]["name"] += name_
|
||||
arguments_ = getattr(c.delta.function_call, "arguments", None)
|
||||
if arguments_:
|
||||
message["function_call"]["arguments"] += arguments_
|
||||
if getattr(c.delta, "tool_calls", None):
|
||||
tool_calls_list = c.delta.tool_calls
|
||||
if tool_calls_list is not None:
|
||||
for tool_call in tool_calls_list:
|
||||
tool_calls[tool_call.index].append(tool_call)
|
||||
if tool_calls:
|
||||
message["tool_calls"] = [None for _ in range(max(tool_calls.keys()) + 1)]
|
||||
for index, tool_call_chunks in tool_calls.items():
|
||||
message["tool_calls"][index] = {
|
||||
"index": index,
|
||||
"id": next((c.id for c in tool_call_chunks if c.id), None),
|
||||
"type": next((c.type for c in tool_call_chunks if c.type), None),
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
for chunk in tool_call_chunks:
|
||||
if getattr(chunk, "function", None):
|
||||
name_ = getattr(chunk.function, "name", None)
|
||||
if name_:
|
||||
message["tool_calls"][index]["function"]["name"] += name_
|
||||
arguments_ = getattr(chunk.function, "arguments", None)
|
||||
if arguments_:
|
||||
message["tool_calls"][index]["function"]["arguments"] += (
|
||||
arguments_
|
||||
)
|
||||
return {
|
||||
"index": getattr(choices[0], "index", 0) if choices else 0,
|
||||
"finish_reason": next(
|
||||
(
|
||||
c.finish_reason
|
||||
for c in reversed_choices
|
||||
if getattr(c, "finish_reason", None)
|
||||
),
|
||||
None,
|
||||
),
|
||||
"message": message,
|
||||
}
|
||||
|
||||
|
||||
def _reduce_chat(all_chunks: list[ChatCompletionChunk]) -> dict:
|
||||
choices_by_index: defaultdict[int, list[Choice]] = defaultdict(list)
|
||||
for chunk in all_chunks:
|
||||
for choice in chunk.choices:
|
||||
choices_by_index[choice.index].append(choice)
|
||||
if all_chunks:
|
||||
d = all_chunks[-1].model_dump()
|
||||
d["choices"] = [
|
||||
_reduce_choices(choices) for choices in choices_by_index.values()
|
||||
]
|
||||
else:
|
||||
d = {"choices": [{"message": {"role": "assistant", "content": ""}}]}
|
||||
# streamed outputs don't go through `process_outputs`
|
||||
# so we need to flatten metadata here
|
||||
oai_token_usage = d.pop("usage", None)
|
||||
d["usage_metadata"] = (
|
||||
_create_usage_metadata(oai_token_usage) if oai_token_usage else None
|
||||
)
|
||||
return d
|
||||
|
||||
|
||||
def _reduce_completions(all_chunks: list[Completion]) -> dict:
|
||||
all_content = []
|
||||
for chunk in all_chunks:
|
||||
content = chunk.choices[0].text
|
||||
if content is not None:
|
||||
all_content.append(content)
|
||||
content = "".join(all_content)
|
||||
if all_chunks:
|
||||
d = all_chunks[-1].model_dump()
|
||||
d["choices"] = [{"text": content}]
|
||||
else:
|
||||
d = {"choices": [{"text": content}]}
|
||||
|
||||
return d
|
||||
|
||||
|
||||
def _create_usage_metadata(
|
||||
oai_token_usage: dict, service_tier: Optional[str] = None
|
||||
) -> UsageMetadata:
|
||||
recognized_service_tier = (
|
||||
service_tier if service_tier in ["priority", "flex"] else None
|
||||
)
|
||||
service_tier_prefix = (
|
||||
f"{recognized_service_tier}_" if recognized_service_tier else ""
|
||||
)
|
||||
|
||||
input_tokens = (
|
||||
oai_token_usage.get("prompt_tokens") or oai_token_usage.get("input_tokens") or 0
|
||||
)
|
||||
output_tokens = (
|
||||
oai_token_usage.get("completion_tokens")
|
||||
or oai_token_usage.get("output_tokens")
|
||||
or 0
|
||||
)
|
||||
total_tokens = oai_token_usage.get("total_tokens") or input_tokens + output_tokens
|
||||
input_token_details: dict = {
|
||||
"audio": (
|
||||
oai_token_usage.get("prompt_tokens_details")
|
||||
or oai_token_usage.get("input_tokens_details")
|
||||
or {}
|
||||
).get("audio_tokens"),
|
||||
f"{service_tier_prefix}cache_read": (
|
||||
oai_token_usage.get("prompt_tokens_details")
|
||||
or oai_token_usage.get("input_tokens_details")
|
||||
or {}
|
||||
).get("cached_tokens"),
|
||||
}
|
||||
output_token_details: dict = {
|
||||
"audio": (
|
||||
oai_token_usage.get("completion_tokens_details")
|
||||
or oai_token_usage.get("output_tokens_details")
|
||||
or {}
|
||||
).get("audio_tokens"),
|
||||
f"{service_tier_prefix}reasoning": (
|
||||
oai_token_usage.get("completion_tokens_details")
|
||||
or oai_token_usage.get("output_tokens_details")
|
||||
or {}
|
||||
).get("reasoning_tokens"),
|
||||
}
|
||||
|
||||
if recognized_service_tier:
|
||||
# Avoid counting cache read and reasoning tokens towards the
|
||||
# service tier token count since service tier tokens are already
|
||||
# priced differently
|
||||
input_token_details[recognized_service_tier] = input_tokens - (
|
||||
input_token_details.get(f"{service_tier_prefix}cache_read") or 0
|
||||
)
|
||||
output_token_details[recognized_service_tier] = output_tokens - (
|
||||
output_token_details.get(f"{service_tier_prefix}reasoning") or 0
|
||||
)
|
||||
|
||||
return UsageMetadata(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
total_tokens=total_tokens,
|
||||
input_token_details=InputTokenDetails(
|
||||
**{k: v for k, v in input_token_details.items() if v is not None}
|
||||
),
|
||||
output_token_details=OutputTokenDetails(
|
||||
**{k: v for k, v in output_token_details.items() if v is not None}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _process_chat_completion(outputs: Any):
|
||||
try:
|
||||
# Check if outputs is an APIResponse wrapper (from with_raw_response).
|
||||
# The OpenAI SDK's APIResponse wraps the actual response object.
|
||||
# Call .parse() to extract the ChatCompletion/Completion for tracing.
|
||||
# See: github.com/openai/openai-python/blob/main/src/openai/_response.py#L285
|
||||
if hasattr(outputs, "parse") and callable(outputs.parse):
|
||||
try:
|
||||
outputs = outputs.parse()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
rdict = outputs.model_dump()
|
||||
oai_token_usage = rdict.pop("usage", None)
|
||||
rdict["usage_metadata"] = (
|
||||
_create_usage_metadata(oai_token_usage, rdict.get("service_tier"))
|
||||
if oai_token_usage
|
||||
else None
|
||||
)
|
||||
return rdict
|
||||
except BaseException as e:
|
||||
logger.debug(f"Error processing chat completion: {e}")
|
||||
return {"output": outputs}
|
||||
|
||||
|
||||
def _get_wrapper(
|
||||
original_create: Callable,
|
||||
name: str,
|
||||
reduce_fn: Callable,
|
||||
tracing_extra: Optional[TracingExtra] = None,
|
||||
invocation_params_fn: Optional[Callable] = None,
|
||||
process_outputs: Optional[Callable] = None,
|
||||
) -> Callable:
|
||||
textra = tracing_extra or {}
|
||||
|
||||
@functools.wraps(original_create)
|
||||
def create(*args, **kwargs):
|
||||
decorator = run_helpers.traceable(
|
||||
name=name,
|
||||
run_type="llm",
|
||||
reduce_fn=reduce_fn if kwargs.get("stream") is True else None,
|
||||
process_inputs=_process_inputs,
|
||||
_invocation_params_fn=invocation_params_fn,
|
||||
process_outputs=process_outputs,
|
||||
**textra,
|
||||
)
|
||||
|
||||
return decorator(original_create)(*args, **kwargs)
|
||||
|
||||
@functools.wraps(original_create)
|
||||
async def acreate(*args, **kwargs):
|
||||
decorator = run_helpers.traceable(
|
||||
name=name,
|
||||
run_type="llm",
|
||||
reduce_fn=reduce_fn if kwargs.get("stream") is True else None,
|
||||
process_inputs=_process_inputs,
|
||||
_invocation_params_fn=invocation_params_fn,
|
||||
process_outputs=process_outputs,
|
||||
**textra,
|
||||
)
|
||||
return await decorator(original_create)(*args, **kwargs)
|
||||
|
||||
return acreate if run_helpers.is_async(original_create) else create
|
||||
|
||||
|
||||
def _get_parse_wrapper(
|
||||
original_parse: Callable,
|
||||
name: str,
|
||||
process_outputs: Callable,
|
||||
tracing_extra: Optional[TracingExtra] = None,
|
||||
invocation_params_fn: Optional[Callable] = None,
|
||||
) -> Callable:
|
||||
textra = tracing_extra or {}
|
||||
|
||||
@functools.wraps(original_parse)
|
||||
def parse(*args, **kwargs):
|
||||
decorator = run_helpers.traceable(
|
||||
name=name,
|
||||
run_type="llm",
|
||||
reduce_fn=None,
|
||||
process_inputs=_process_inputs,
|
||||
_invocation_params_fn=invocation_params_fn,
|
||||
process_outputs=process_outputs,
|
||||
**textra,
|
||||
)
|
||||
return decorator(original_parse)(*args, **kwargs)
|
||||
|
||||
@functools.wraps(original_parse)
|
||||
async def aparse(*args, **kwargs):
|
||||
decorator = run_helpers.traceable(
|
||||
name=name,
|
||||
run_type="llm",
|
||||
reduce_fn=None,
|
||||
process_inputs=_process_inputs,
|
||||
_invocation_params_fn=invocation_params_fn,
|
||||
process_outputs=process_outputs,
|
||||
**textra,
|
||||
)
|
||||
return await decorator(original_parse)(*args, **kwargs)
|
||||
|
||||
return aparse if run_helpers.is_async(original_parse) else parse
|
||||
|
||||
|
||||
def _reduce_response_events(events: list[ResponseStreamEvent]) -> dict:
|
||||
for event in events:
|
||||
if event.type == "response.completed":
|
||||
return _process_responses_api_output(event.response)
|
||||
return {}
|
||||
|
||||
|
||||
class TracingExtra(TypedDict, total=False):
|
||||
metadata: Optional[Mapping[str, Any]]
|
||||
tags: Optional[list[str]]
|
||||
client: Optional[ls_client.Client]
|
||||
|
||||
|
||||
def wrap_openai(
|
||||
client: C,
|
||||
*,
|
||||
tracing_extra: Optional[TracingExtra] = None,
|
||||
chat_name: str = "ChatOpenAI",
|
||||
completions_name: str = "OpenAI",
|
||||
) -> C:
|
||||
"""Patch the OpenAI client to make it traceable.
|
||||
|
||||
Supports:
|
||||
- Chat and Responses API's
|
||||
- Sync and async OpenAI clients
|
||||
- `create` and `parse` methods
|
||||
- With and without streaming
|
||||
- `with_raw_response` API for accessing HTTP headers
|
||||
|
||||
Args:
|
||||
client: The client to patch.
|
||||
tracing_extra: Extra tracing information.
|
||||
chat_name: The run name for the chat completions endpoint.
|
||||
completions_name: The run name for the completions endpoint.
|
||||
|
||||
Returns:
|
||||
The patched client.
|
||||
|
||||
Example:
|
||||
```python
|
||||
import openai
|
||||
from langsmith import wrappers
|
||||
|
||||
# Use OpenAI client same as you normally would.
|
||||
client = wrappers.wrap_openai(openai.OpenAI())
|
||||
|
||||
# Chat API:
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What physics breakthroughs do you predict will happen by 2300?",
|
||||
},
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
model="gpt-4o-mini", messages=messages
|
||||
)
|
||||
print(completion.choices[0].message.content)
|
||||
|
||||
# Responses API:
|
||||
response = client.responses.create(
|
||||
model="gpt-4o-mini",
|
||||
messages=messages,
|
||||
)
|
||||
print(response.output_text)
|
||||
|
||||
# With raw response to access headers:
|
||||
raw_response = client.chat.completions.with_raw_response.create(
|
||||
model="gpt-4o-mini", messages=messages
|
||||
)
|
||||
print(raw_response.headers) # Access HTTP headers
|
||||
completion = raw_response.parse() # Get parsed response
|
||||
```
|
||||
|
||||
!!! warning "Behavior changed in `langsmith` 0.3.16"
|
||||
|
||||
Support for Responses API added.
|
||||
|
||||
!!! warning "Behavior changed in `langsmith` 0.3.x"
|
||||
|
||||
Support for `with_raw_response` API added.
|
||||
""" # noqa: E501
|
||||
tracing_extra = tracing_extra or {}
|
||||
|
||||
# Extract ls_invocation_params from metadata
|
||||
metadata = dict(tracing_extra.get("metadata") or {})
|
||||
prepopulated_invocation_params = metadata.pop("ls_invocation_params", {})
|
||||
|
||||
# Create new tracing_extra without ls_invocation_params in metadata
|
||||
tracing_extra_rest: TracingExtra = { # type: ignore[assignment]
|
||||
k: v for k, v in tracing_extra.items() if k != "metadata"
|
||||
}
|
||||
if metadata:
|
||||
tracing_extra_rest["metadata"] = metadata # type: ignore[typeddict-item]
|
||||
|
||||
ls_provider = "openai"
|
||||
try:
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
|
||||
if isinstance(client, AzureOpenAI) or isinstance(client, AsyncAzureOpenAI):
|
||||
ls_provider = "azure"
|
||||
chat_name = "AzureChatOpenAI"
|
||||
completions_name = "AzureOpenAI"
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# First wrap the create methods - these handle non-streaming cases
|
||||
client.chat.completions.create = _get_wrapper( # type: ignore[method-assign]
|
||||
client.chat.completions.create,
|
||||
chat_name,
|
||||
_reduce_chat,
|
||||
tracing_extra=tracing_extra_rest,
|
||||
invocation_params_fn=functools.partial(
|
||||
_infer_invocation_params,
|
||||
"chat",
|
||||
ls_provider,
|
||||
prepopulated_invocation_params,
|
||||
False,
|
||||
),
|
||||
process_outputs=_process_chat_completion,
|
||||
)
|
||||
|
||||
client.completions.create = _get_wrapper( # type: ignore[method-assign]
|
||||
client.completions.create,
|
||||
completions_name,
|
||||
_reduce_completions,
|
||||
tracing_extra=tracing_extra_rest,
|
||||
invocation_params_fn=functools.partial(
|
||||
_infer_invocation_params,
|
||||
"llm",
|
||||
ls_provider,
|
||||
prepopulated_invocation_params,
|
||||
False,
|
||||
),
|
||||
)
|
||||
|
||||
# Wrap beta.chat.completions.parse if it exists
|
||||
if (
|
||||
hasattr(client, "beta")
|
||||
and hasattr(client.beta, "chat")
|
||||
and hasattr(client.beta.chat, "completions")
|
||||
and hasattr(client.beta.chat.completions, "parse")
|
||||
):
|
||||
client.beta.chat.completions.parse = _get_parse_wrapper( # type: ignore[method-assign]
|
||||
client.beta.chat.completions.parse, # type: ignore
|
||||
chat_name,
|
||||
_process_chat_completion,
|
||||
tracing_extra=tracing_extra_rest,
|
||||
invocation_params_fn=functools.partial(
|
||||
_infer_invocation_params,
|
||||
"chat",
|
||||
ls_provider,
|
||||
prepopulated_invocation_params,
|
||||
False,
|
||||
),
|
||||
)
|
||||
|
||||
# Wrap chat.completions.parse if it exists
|
||||
if (
|
||||
hasattr(client, "chat")
|
||||
and hasattr(client.chat, "completions")
|
||||
and hasattr(client.chat.completions, "parse")
|
||||
):
|
||||
client.chat.completions.parse = _get_parse_wrapper( # type: ignore[method-assign]
|
||||
client.chat.completions.parse, # type: ignore
|
||||
chat_name,
|
||||
_process_chat_completion,
|
||||
tracing_extra=tracing_extra_rest,
|
||||
invocation_params_fn=functools.partial(
|
||||
_infer_invocation_params,
|
||||
"chat",
|
||||
ls_provider,
|
||||
prepopulated_invocation_params,
|
||||
False,
|
||||
),
|
||||
)
|
||||
|
||||
# For the responses API: "client.responses.create(**kwargs)"
|
||||
if hasattr(client, "responses"):
|
||||
if hasattr(client.responses, "create"):
|
||||
client.responses.create = _get_wrapper( # type: ignore[method-assign]
|
||||
client.responses.create,
|
||||
chat_name,
|
||||
_reduce_response_events,
|
||||
process_outputs=_process_responses_api_output,
|
||||
tracing_extra=tracing_extra_rest,
|
||||
invocation_params_fn=functools.partial(
|
||||
_infer_invocation_params,
|
||||
"chat",
|
||||
ls_provider,
|
||||
prepopulated_invocation_params,
|
||||
True,
|
||||
),
|
||||
)
|
||||
if hasattr(client.responses, "parse"):
|
||||
client.responses.parse = _get_parse_wrapper( # type: ignore[method-assign]
|
||||
client.responses.parse,
|
||||
chat_name,
|
||||
_process_responses_api_output,
|
||||
tracing_extra=tracing_extra_rest,
|
||||
invocation_params_fn=functools.partial(
|
||||
_infer_invocation_params,
|
||||
"chat",
|
||||
ls_provider,
|
||||
prepopulated_invocation_params,
|
||||
True,
|
||||
),
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
|
||||
def _process_responses_api_output(response: Any) -> dict:
|
||||
if response:
|
||||
try:
|
||||
# Unwrap APIResponse from with_raw_response for tracing
|
||||
if hasattr(response, "parse") and callable(response.parse):
|
||||
try:
|
||||
response = response.parse()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
output = response.model_dump(exclude_none=True, mode="json")
|
||||
if usage := output.pop("usage", None):
|
||||
output["usage_metadata"] = _create_usage_metadata(
|
||||
usage, output.get("service_tier")
|
||||
)
|
||||
return output
|
||||
except Exception:
|
||||
return {"output": response}
|
||||
return {}
|
||||
19
venv/Lib/site-packages/langsmith/wrappers/_openai_agents.py
Normal file
19
venv/Lib/site-packages/langsmith/wrappers/_openai_agents.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Tombstone module for backward compatibility.
|
||||
|
||||
This module has been moved to `langsmith.integrations.openai_agents`.
|
||||
|
||||
Imports from this location are deprecated but will continue to work.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from langsmith.integrations.openai_agents_sdk import OpenAIAgentsTracingProcessor
|
||||
|
||||
warnings.warn(
|
||||
"langsmith.wrappers._openai_agents is deprecated and has been moved to "
|
||||
"langsmith.integrations.openai_agents_sdk. Please update your imports.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
__all__ = ["OpenAIAgentsTracingProcessor"]
|
||||
Reference in New Issue
Block a user