initial commit
This commit is contained in:
577
venv/Lib/site-packages/langsmith/wrappers/_anthropic.py
Normal file
577
venv/Lib/site-packages/langsmith/wrappers/_anthropic.py
Normal file
@@ -0,0 +1,577 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import AsyncIterator, Mapping, Sequence
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from typing_extensions import Self, TypedDict
|
||||
|
||||
from langsmith import client as ls_client
|
||||
from langsmith import run_helpers
|
||||
from langsmith.schemas import InputTokenDetails, UsageMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import httpx
|
||||
from anthropic import Anthropic, AsyncAnthropic
|
||||
from anthropic.lib.streaming import AsyncMessageStream, MessageStream
|
||||
from anthropic.types import Completion, Message, MessageStreamEvent
|
||||
|
||||
# Type variable for the client handed to ``wrap_anthropic``; that function is
# annotated ``(client: C) -> C``, so the caller gets back the type it passed in.
C = TypeVar("C", bound=Union["Anthropic", "AsyncAnthropic", Any])
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def _get_not_given() -> Optional[tuple[type, ...]]:
|
||||
try:
|
||||
from anthropic._types import NotGiven, Omit
|
||||
|
||||
return (NotGiven, Omit)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def _strip_not_given(d: dict) -> dict:
    """Normalise Anthropic call kwargs into a dict suitable for tracing.

    Three things happen, in order:

    1. Values that are instances of the SDK's ``NotGiven`` / ``Omit`` sentinels
       are removed (skipped when the sentinels cannot be imported).
    2. A top-level ``system`` entry is folded into ``messages`` as a leading
       ``{"role": "system", ...}`` message and the ``system`` key is removed.
    3. Entries whose value is ``None`` are dropped.

    Args:
        d: Keyword arguments of the wrapped call. Never mutated.

    Returns:
        A new dict with the transformations above applied.
    """
    # Always work on a shallow copy: previously the caller's dict was mutated
    # whenever the sentinel filter below was skipped or raised.
    cleaned = dict(d)
    try:
        if not_given := _get_not_given():
            cleaned = {
                k: v for k, v in cleaned.items() if not isinstance(v, not_given)
            }
    except Exception as e:
        # Best effort: tracing must not break the wrapped call.
        logger.error("Error stripping NotGiven: %s", e)

    if "system" in cleaned:
        system_prompt = cleaned.pop("system")
        # ``or []`` tolerates an explicit ``messages=None``.
        cleaned["messages"] = [{"role": "system", "content": system_prompt}] + (
            cleaned.get("messages") or []
        )
    return {k: v for k, v in cleaned.items() if v is not None}
|
||||
|
||||
|
||||
def _infer_ls_params(prepopulated_invocation_params: dict, kwargs: dict):
    """Derive LangSmith ``ls_*`` run parameters from Anthropic call kwargs.

    Args:
        prepopulated_invocation_params: Invocation params supplied up front by
            the user; request-derived params override them on key collisions.
        kwargs: Keyword arguments of the wrapped Anthropic call.

    Returns:
        A dict with ``ls_provider``, ``ls_model_type``, ``ls_model_name``,
        ``ls_temperature``, ``ls_max_tokens``, ``ls_stop`` and
        ``ls_invocation_params``.
    """
    stripped = _strip_not_given(kwargs)

    # Anthropic names this request parameter ``stop_sequences``. ``stop`` is
    # still honoured as a fallback so existing callers see no change.
    stop = stripped.get("stop_sequences") or stripped.get("stop")
    if stop and isinstance(stop, str):
        stop = [stop]

    # Allowlist of safe invocation parameters to include.
    # Only known, non-sensitive parameters are recorded.
    allowed_invocation_keys = {
        "mcp_servers",
        "service_tier",
        "top_k",
        "top_p",
        "stream",
        "thinking",
    }

    invocation_params = {
        k: v for k, v in stripped.items() if k in allowed_invocation_keys
    }

    return {
        "ls_provider": "anthropic",
        "ls_model_type": "chat",
        "ls_model_name": stripped.get("model", None),
        "ls_temperature": stripped.get("temperature", None),
        "ls_max_tokens": stripped.get("max_tokens", None),
        "ls_stop": stop,
        "ls_invocation_params": {
            # ``or {}`` tolerates a ``None`` value for the prepopulated params.
            **(prepopulated_invocation_params or {}),
            **invocation_params,
        },
    }
|
||||
|
||||
|
||||
def _accumulate_event(
    *, event: MessageStreamEvent, current_snapshot: Message | None
) -> Message | None:
    """Fold one streaming event into the message snapshot built so far.

    Args:
        event: The stream event to apply.
        current_snapshot: The message accumulated from earlier events, or
            ``None`` if no event has been applied yet.

    Returns:
        The snapshot after applying ``event``. The snapshot is returned
        unchanged when ``anthropic.types.ContentBlock`` cannot be imported or
        when the event type is not one of those handled below.

    Raises:
        RuntimeError: If there is no snapshot yet and ``event`` is not a
            ``message_start`` event.
    """
    try:
        from anthropic.types import ContentBlock
    except ImportError:
        logger.debug("Error importing ContentBlock")
        return current_snapshot

    kind = event.type

    # The very first event must carry the initial message.
    if current_snapshot is None:
        if kind != "message_start":
            raise RuntimeError(
                f'Unexpected event order, got {kind} before "message_start"'
            )
        return event.message

    if kind == "content_block_start":
        # TODO: check index <-- from anthropic SDK :)
        block_adapter: TypeAdapter = TypeAdapter(ContentBlock)
        new_block = block_adapter.validate_python(event.content_block.model_dump())
        current_snapshot.content.append(
            new_block,  # type: ignore[attr-defined]
        )
        return current_snapshot

    if kind == "content_block_delta":
        target = current_snapshot.content[event.index]
        # Only text deltas applied to text blocks are accumulated.
        if target.type == "text" and event.delta.type == "text_delta":
            target.text += event.delta.text
        return current_snapshot

    if kind == "message_delta":
        current_snapshot.stop_reason = event.delta.stop_reason
        current_snapshot.stop_sequence = event.delta.stop_sequence
        current_snapshot.usage.output_tokens = event.usage.output_tokens

    return current_snapshot
|
||||
|
||||
|
||||
def _create_usage_metadata(anthropic_token_usage: dict) -> UsageMetadata:
    """Translate an Anthropic usage dict into LangSmith ``UsageMetadata``.

    Args:
        anthropic_token_usage: The dumped ``usage`` section of a response.

    Returns:
        ``UsageMetadata`` with input, output and total token counts (missing
        or falsy counts are treated as 0). ``input_token_details`` is added
        only when at least one cache-related count is non-zero.
    """
    prompt_tokens = anthropic_token_usage.get("input_tokens") or 0
    completion_tokens = anthropic_token_usage.get("output_tokens") or 0

    # (detail key, raw count) pairs in the order they should appear.
    candidates = [
        ("cache_read", anthropic_token_usage.get("cache_read_input_tokens")),
    ]
    creation_breakdown = anthropic_token_usage.get("cache_creation")
    if creation_breakdown:
        # A per-TTL breakdown is present: report it instead of the aggregate.
        for ttl_key in ("ephemeral_5m_input_tokens", "ephemeral_1h_input_tokens"):
            candidates.append((ttl_key, creation_breakdown.get(ttl_key)))
    else:
        candidates.append(
            (
                "cache_creation",
                anthropic_token_usage.get("cache_creation_input_tokens"),
            )
        )
    details = {key: count for key, count in candidates if count}

    usage = UsageMetadata(
        input_tokens=prompt_tokens,
        output_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    if details:
        usage["input_token_details"] = InputTokenDetails(**details)
    return usage
|
||||
|
||||
|
||||
def _message_to_outputs(message: Any) -> dict:
|
||||
"""Convert an Anthropic Message to a flat outputs dict with usage_metadata."""
|
||||
rdict = message.model_dump()
|
||||
anthropic_token_usage = rdict.pop("usage", None)
|
||||
if anthropic_token_usage:
|
||||
rdict["usage_metadata"] = _create_usage_metadata(anthropic_token_usage)
|
||||
rdict.pop("type", None)
|
||||
return rdict
|
||||
|
||||
|
||||
def _reduce_chat_chunks(all_chunks: Sequence) -> dict:
|
||||
full_message = None
|
||||
for chunk in all_chunks:
|
||||
try:
|
||||
full_message = _accumulate_event(event=chunk, current_snapshot=full_message)
|
||||
except RuntimeError as e:
|
||||
logger.debug(f"Error accumulating event in Anthropic Wrapper: {e}")
|
||||
return {"output": all_chunks}
|
||||
if full_message is None:
|
||||
return {"output": all_chunks}
|
||||
return _message_to_outputs(full_message)
|
||||
|
||||
|
||||
def _reduce_completions(all_chunks: list[Completion]) -> dict:
|
||||
all_content = []
|
||||
for chunk in all_chunks:
|
||||
content = chunk.completion
|
||||
if content is not None:
|
||||
all_content.append(content)
|
||||
content = "".join(all_content)
|
||||
if all_chunks:
|
||||
d = all_chunks[-1].model_dump()
|
||||
d["choices"] = [{"text": content}]
|
||||
else:
|
||||
d = {"choices": [{"text": content}]}
|
||||
|
||||
return d
|
||||
|
||||
|
||||
def _process_chat_completion(outputs: Any):
    """Convert a (non-streaming) call result into run outputs.

    Args:
        outputs: The value returned by the wrapped call.

    Returns:
        The flattened message dict, or ``{"output": outputs}`` when the value
        cannot be converted.
    """
    try:
        # Check if outputs is a LegacyAPIResponse wrapper (from with_raw_response).
        # The Anthropic SDK's LegacyAPIResponse wraps the actual response object.
        # Call .parse() to extract the Message for tracing.
        # See: anthropics/anthropic-sdk-python _legacy_response.py#L102
        parse = getattr(outputs, "parse", None)
        if callable(parse):
            try:
                outputs = parse()
            except Exception as parse_error:
                # Keep the unparsed object and let the conversion below decide.
                logger.debug("Error parsing raw response: %s", parse_error)
        return _message_to_outputs(outputs)
    except Exception as e:
        # Deliberately not BaseException: KeyboardInterrupt and SystemExit must
        # propagate instead of being swallowed by tracing code.
        logger.debug("Error processing chat completion: %s", e)
        return {"output": outputs}
|
||||
|
||||
|
||||
def _get_wrapper(
    original_create: Callable,
    name: str,
    reduce_fn: Callable,
    prepopulated_invocation_params: dict,
    tracing_extra: TracingExtra,
) -> Callable:
    """Build a traced replacement for a ``create`` method.

    Args:
        original_create: The unpatched ``create`` callable.
        name: Run name passed to ``run_helpers.traceable``.
        reduce_fn: Reducer applied to streamed chunks; only used when the call
            passes a truthy ``stream`` keyword argument.
        prepopulated_invocation_params: Invocation params bound as the first
            argument of ``_infer_ls_params``.
        tracing_extra: Extra keyword arguments forwarded to
            ``run_helpers.traceable``.

    Returns:
        An async wrapper when ``run_helpers.is_async(original_create)`` is
        true, otherwise a sync wrapper. Both carry ``original_create``'s
        metadata via ``functools.wraps``.
    """

    def _traced(streaming: Any) -> Callable:
        # Built on every call because the reducer depends on ``stream``.
        trace = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=reduce_fn if streaming else None,
            process_inputs=_strip_not_given,
            process_outputs=_process_chat_completion,
            _invocation_params_fn=functools.partial(
                _infer_ls_params, prepopulated_invocation_params
            ),
            **tracing_extra,
        )
        return trace(original_create)

    @functools.wraps(original_create)
    def create(*args, **kwargs):
        return _traced(kwargs.get("stream"))(*args, **kwargs)

    @functools.wraps(original_create)
    async def acreate(*args, **kwargs):
        return await _traced(kwargs.get("stream"))(*args, **kwargs)

    if run_helpers.is_async(original_create):
        return acreate
    return create
|
||||
|
||||
|
||||
def _get_stream_wrapper(
    original_stream: Callable,
    name: str,
    prepopulated_invocation_params: dict,
    tracing_extra: TracingExtra,
) -> Callable:
    """Create a wrapper for Anthropic's streaming context manager.

    The returned class stands in for ``messages.stream``. Instantiating it
    with the request kwargs gives a context manager; entering it opens the
    original stream and returns a proxy that traces event iteration and
    ``text_stream`` consumption as LangSmith ``llm`` runs.

    Args:
        original_stream: The unpatched ``messages.stream`` callable.
        name: Run name passed to ``run_helpers.traceable``.
        prepopulated_invocation_params: Invocation params bound as the first
            argument of ``_infer_ls_params``.
        tracing_extra: Extra keyword arguments forwarded to
            ``run_helpers.traceable``.

    Returns:
        A stream-manager wrapper class (async or sync, matching
        ``original_stream``) that accepts keyword arguments only.
    """
    # The async flavour is recognised from the callable's string form; the
    # stream callable is invoked without ``await`` in both flavours below.
    is_async = "async" in str(original_stream).lower()

    # Traces raw event iteration; events are reduced into a final message.
    configured_traceable = run_helpers.traceable(
        name=name,
        reduce_fn=_reduce_chat_chunks,
        run_type="llm",
        process_inputs=_strip_not_given,
        _invocation_params_fn=functools.partial(
            _infer_ls_params, prepopulated_invocation_params
        ),
        **tracing_extra,
    )
    # Traces ``text_stream``; outputs are set explicitly from the final message.
    configured_traceable_text = run_helpers.traceable(
        name=name,
        run_type="llm",
        process_inputs=_strip_not_given,
        process_outputs=_process_chat_completion,
        _invocation_params_fn=functools.partial(
            _infer_ls_params, prepopulated_invocation_params
        ),
        **tracing_extra,
    )

    def _record_final_message(run_tree: Any, final_message: Any) -> None:
        """Store the final message and its usage metadata on ``run_tree``."""
        outputs = _message_to_outputs(final_message)
        run_tree.outputs = outputs
        if usage := outputs.get("usage_metadata"):
            run_tree.metadata["usage_metadata"] = usage

    if is_async:

        class AsyncMessageStreamWrapper:
            """Traced proxy around an ``AsyncMessageStream``."""

            def __init__(
                self,
                wrapped: AsyncMessageStream,
                **kwargs,
            ) -> None:
                self._wrapped = wrapped
                self._kwargs = kwargs
                # Created on first ``__anext__`` so repeated calls share one
                # traced iterator instead of starting a new run per event.
                self._event_iterator = None

            @property
            def text_stream(self):
                """Traced async iterator over the stream's text chunks."""

                @configured_traceable_text
                async def _text_stream(**_):
                    async for chunk in self._wrapped.text_stream:
                        yield chunk
                    run_tree = run_helpers.get_current_run_tree()
                    # No run tree may be active; in that case there is nothing
                    # to record and the user's loop must not fail.
                    if run_tree is not None:
                        final_message = await self._wrapped.get_final_message()
                        _record_final_message(run_tree, final_message)

                return _text_stream(**self._kwargs)

            @property
            def response(self) -> httpx.Response:
                return self._wrapped.response

            @property
            def request_id(self) -> str | None:
                return self._wrapped.request_id

            async def __anext__(self) -> MessageStreamEvent:
                if self._event_iterator is None:
                    self._event_iterator = self.__aiter__()
                return await self._event_iterator.__anext__()

            async def __aiter__(self) -> AsyncIterator[MessageStreamEvent]:
                """Yield stream events through a traced iterator."""

                @configured_traceable
                def traced_iter(**_):
                    return self._wrapped.__aiter__()

                async for chunk in traced_iter(**self._kwargs):
                    yield chunk

            async def __aenter__(self) -> Self:
                await self._wrapped.__aenter__()
                return self

            async def __aexit__(self, *exc) -> None:
                await self._wrapped.__aexit__(*exc)

            async def close(self) -> None:
                await self._wrapped.close()

            async def get_final_message(self) -> Message:
                return await self._wrapped.get_final_message()

            async def get_final_text(self) -> str:
                return await self._wrapped.get_final_text()

            async def until_done(self) -> None:
                await self._wrapped.until_done()

            @property
            def current_message_snapshot(self) -> Message:
                return self._wrapped.current_message_snapshot

        class AsyncMessagesStreamManagerWrapper:
            """Async context manager that opens the stream and wraps it."""

            def __init__(self, **kwargs):
                self._kwargs = kwargs

            async def __aenter__(self):
                self._manager = original_stream(**self._kwargs)
                stream = await self._manager.__aenter__()
                return AsyncMessageStreamWrapper(stream, **self._kwargs)

            async def __aexit__(self, *exc):
                await self._manager.__aexit__(*exc)

        return AsyncMessagesStreamManagerWrapper
    else:

        class MessageStreamWrapper:
            """Traced proxy around a ``MessageStream``."""

            def __init__(
                self,
                wrapped: MessageStream,
                **kwargs,
            ) -> None:
                self._wrapped = wrapped
                self._kwargs = kwargs
                # Created on first ``__next__`` so repeated calls share one
                # traced iterator instead of starting a new run per event.
                self._event_iterator = None

            @property
            def response(self) -> Any:
                return self._wrapped.response

            @property
            def request_id(self) -> str | None:
                return self._wrapped.request_id  # type: ignore[no-any-return]

            @property
            def text_stream(self):
                """Traced iterator over the stream's text chunks."""

                @configured_traceable_text
                def _text_stream(**_):
                    yield from self._wrapped.text_stream
                    run_tree = run_helpers.get_current_run_tree()
                    # No run tree may be active; in that case there is nothing
                    # to record and the user's loop must not fail.
                    if run_tree is not None:
                        final_message = self._wrapped.get_final_message()
                        _record_final_message(run_tree, final_message)

                return _text_stream(**self._kwargs)

            def __next__(self) -> MessageStreamEvent:
                if self._event_iterator is None:
                    self._event_iterator = self.__iter__()
                return self._event_iterator.__next__()

            def __iter__(self):
                """Return a traced iterator over the stream events."""

                @configured_traceable
                def traced_iter(**_):
                    return self._wrapped.__iter__()

                return traced_iter(**self._kwargs)

            def __enter__(self) -> Self:
                self._wrapped.__enter__()
                return self

            def __exit__(self, *exc) -> None:
                self._wrapped.__exit__(*exc)

            def close(self) -> None:
                self._wrapped.close()

            def get_final_message(self) -> Message:
                return self._wrapped.get_final_message()

            def get_final_text(self) -> str:
                return self._wrapped.get_final_text()

            def until_done(self) -> None:
                return self._wrapped.until_done()

            @property
            def current_message_snapshot(self) -> Message:
                return self._wrapped.current_message_snapshot

        class MessagesStreamManagerWrapper:
            """Context manager that opens the stream and wraps it."""

            def __init__(self, **kwargs):
                self._kwargs = kwargs

            def __enter__(self):
                self._manager = original_stream(**self._kwargs)
                return MessageStreamWrapper(self._manager.__enter__(), **self._kwargs)

            def __exit__(self, *exc):
                self._manager.__exit__(*exc)

        return MessagesStreamManagerWrapper
|
||||
|
||||
|
||||
class TracingExtra(TypedDict, total=False):
    """Optional tracing settings accepted by ``wrap_anthropic``.

    Every key may be omitted (``total=False``). The entries are forwarded as
    keyword arguments to ``run_helpers.traceable`` by the wrappers in this
    module.
    """

    # Run metadata. ``wrap_anthropic`` pops an ``ls_invocation_params`` entry
    # out of this mapping and forwards the remaining entries as run metadata.
    metadata: Optional[Mapping[str, Any]]
    # Tags forwarded with each traced run.
    tags: Optional[list[str]]
    # LangSmith client forwarded to ``run_helpers.traceable``.
    client: Optional[ls_client.Client]
|
||||
|
||||
|
||||
def wrap_anthropic(
    client: C,
    *,
    tracing_extra: Optional[TracingExtra] = None,
    chat_name: str = "ChatAnthropic",
    completions_name: str = "Anthropic",
) -> C:
    """Patch the Anthropic client to make it traceable.

    ``messages.create`` and ``messages.stream`` are always patched.
    ``completions.create`` and ``beta.messages.create`` are patched only when
    the client exposes them. The client is modified in place.

    Args:
        client: The client to patch.
        tracing_extra: Extra tracing information. An ``ls_invocation_params``
            entry inside ``metadata`` is removed from the metadata and merged
            into each run's invocation params instead.
        chat_name: The run name for the messages endpoint.
        completions_name: The run name for the completions endpoint.

    Returns:
        The patched client.

    Example:
        ```python
        import anthropic
        from langsmith import wrappers

        client = wrappers.wrap_anthropic(anthropic.Anthropic())

        # Use Anthropic client same as you normally would:
        system = "You are a helpful assistant."
        messages = [
            {
                "role": "user",
                "content": "What physics breakthroughs do you predict will happen by 2300?",
            }
        ]
        completion = client.messages.create(
            model="claude-3-5-sonnet-latest",
            messages=messages,
            max_tokens=1000,
            system=system,
        )
        print(completion.content)

        # With raw response to access headers:
        raw_response = client.messages.with_raw_response.create(
            model="claude-3-5-sonnet-latest",
            messages=messages,
            max_tokens=1000,
            system=system,
        )
        print(raw_response.headers)  # Access HTTP headers
        message = raw_response.parse()  # Get parsed response

        # You can also use the streaming context manager:
        with client.messages.stream(
            model="claude-3-5-sonnet-latest",
            messages=messages,
            max_tokens=1000,
            system=system,
        ) as stream:
            for text in stream.text_stream:
                print(text, end="", flush=True)
            message = stream.get_final_message()
        ```
    """  # noqa: E501
    tracing_extra = tracing_extra or {}

    # Extract ls_invocation_params from a copy of the metadata so the caller's
    # mapping is left untouched. ``or {}`` tolerates an explicit ``None``.
    metadata = dict(tracing_extra.get("metadata") or {})
    prepopulated_invocation_params = metadata.pop("ls_invocation_params", None) or {}

    # Create new tracing_extra without ls_invocation_params in metadata
    tracing_extra_rest: TracingExtra = {  # type: ignore[assignment]
        k: v for k, v in tracing_extra.items() if k != "metadata"
    }
    if metadata:
        tracing_extra_rest["metadata"] = metadata  # type: ignore[typeddict-item]

    client.messages.create = _get_wrapper(  # type: ignore[method-assign]
        client.messages.create,
        chat_name,
        _reduce_chat_chunks,
        prepopulated_invocation_params,
        tracing_extra_rest,
    )

    client.messages.stream = _get_stream_wrapper(  # type: ignore[method-assign]
        client.messages.stream,
        chat_name,
        prepopulated_invocation_params,
        tracing_extra_rest,
    )

    # Clients without a completions resource are skipped rather than failing
    # with AttributeError, mirroring the guard used for ``beta`` below.
    if hasattr(client, "completions") and hasattr(client.completions, "create"):
        client.completions.create = _get_wrapper(  # type: ignore[method-assign]
            client.completions.create,
            completions_name,
            _reduce_completions,
            prepopulated_invocation_params,
            tracing_extra_rest,
        )

    if (
        hasattr(client, "beta")
        and hasattr(client.beta, "messages")
        and hasattr(client.beta.messages, "create")
    ):
        client.beta.messages.create = _get_wrapper(  # type: ignore[method-assign]
            client.beta.messages.create,  # type: ignore
            chat_name,
            _reduce_chat_chunks,
            prepopulated_invocation_params,
            tracing_extra_rest,
        )
    return client
|
||||
Reference in New Issue
Block a user