from __future__ import annotations import functools import logging from collections.abc import AsyncIterator, Mapping, Sequence from typing import ( TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, ) from pydantic import TypeAdapter from typing_extensions import Self, TypedDict from langsmith import client as ls_client from langsmith import run_helpers from langsmith.schemas import InputTokenDetails, UsageMetadata if TYPE_CHECKING: import httpx from anthropic import Anthropic, AsyncAnthropic from anthropic.lib.streaming import AsyncMessageStream, MessageStream from anthropic.types import Completion, Message, MessageStreamEvent C = TypeVar("C", bound=Union["Anthropic", "AsyncAnthropic", Any]) logger = logging.getLogger(__name__) @functools.lru_cache def _get_not_given() -> Optional[tuple[type, ...]]: try: from anthropic._types import NotGiven, Omit return (NotGiven, Omit) except ImportError: return None def _strip_not_given(d: dict) -> dict: try: if not_given := _get_not_given(): d = { k: v for k, v in d.items() if not any(isinstance(v, t) for t in not_given) } except Exception as e: logger.error(f"Error stripping NotGiven: {e}") if "system" in d: d["messages"] = [{"role": "system", "content": d["system"]}] + d.get( "messages", [] ) d.pop("system") return {k: v for k, v in d.items() if v is not None} def _infer_ls_params(prepopulated_invocation_params: dict, kwargs: dict): stripped = _strip_not_given(kwargs) stop = stripped.get("stop") if stop and isinstance(stop, str): stop = [stop] # Allowlist of safe invocation parameters to include # Only include known, non-sensitive parameters allowed_invocation_keys = { "mcp_servers", "service_tier", "top_k", "top_p", "stream", "thinking", } # Only include allowlisted parameters invocation_params = { k: v for k, v in stripped.items() if k in allowed_invocation_keys } return { "ls_provider": "anthropic", "ls_model_type": "chat", "ls_model_name": stripped.get("model", None), "ls_temperature": stripped.get("temperature", None), "ls_max_tokens": stripped.get("max_tokens", None), "ls_stop": stop, "ls_invocation_params": { **prepopulated_invocation_params, **invocation_params, }, } def _accumulate_event( *, event: MessageStreamEvent, current_snapshot: Message | None ) -> Message | None: try: from anthropic.types import ContentBlock except ImportError: logger.debug("Error importing ContentBlock") return current_snapshot if current_snapshot is None: if event.type == "message_start": return event.message raise RuntimeError( f'Unexpected event order, got {event.type} before "message_start"' ) if event.type == "content_block_start": # TODO: check index <-- from anthropic SDK :) adapter: TypeAdapter = TypeAdapter(ContentBlock) content_block_instance = adapter.validate_python( event.content_block.model_dump() ) current_snapshot.content.append( content_block_instance, # type: ignore[attr-defined] ) elif event.type == "content_block_delta": content = current_snapshot.content[event.index] if content.type == "text" and event.delta.type == "text_delta": content.text += event.delta.text elif event.type == "message_delta": current_snapshot.stop_reason = event.delta.stop_reason current_snapshot.stop_sequence = event.delta.stop_sequence current_snapshot.usage.output_tokens = event.usage.output_tokens return current_snapshot def _create_usage_metadata(anthropic_token_usage: dict) -> UsageMetadata: input_tokens = anthropic_token_usage.get("input_tokens") or 0 output_tokens = anthropic_token_usage.get("output_tokens") or 0 total_tokens = input_tokens + output_tokens input_token_details: dict = {} cache_read = anthropic_token_usage.get("cache_read_input_tokens") or 0 if cache_read: input_token_details["cache_read"] = cache_read cache_creation_obj = anthropic_token_usage.get("cache_creation") or {} if cache_creation_obj: ephemeral_5m = cache_creation_obj.get("ephemeral_5m_input_tokens") or 0 ephemeral_1h = cache_creation_obj.get("ephemeral_1h_input_tokens") or 0 if ephemeral_5m: input_token_details["ephemeral_5m_input_tokens"] = ephemeral_5m if ephemeral_1h: input_token_details["ephemeral_1h_input_tokens"] = ephemeral_1h else: cache_creation = anthropic_token_usage.get("cache_creation_input_tokens") or 0 if cache_creation: input_token_details["cache_creation"] = cache_creation result = UsageMetadata( input_tokens=input_tokens, output_tokens=output_tokens, total_tokens=total_tokens, ) if input_token_details: result["input_token_details"] = InputTokenDetails(**input_token_details) return result def _message_to_outputs(message: Any) -> dict: """Convert an Anthropic Message to a flat outputs dict with usage_metadata.""" rdict = message.model_dump() anthropic_token_usage = rdict.pop("usage", None) if anthropic_token_usage: rdict["usage_metadata"] = _create_usage_metadata(anthropic_token_usage) rdict.pop("type", None) return rdict def _reduce_chat_chunks(all_chunks: Sequence) -> dict: full_message = None for chunk in all_chunks: try: full_message = _accumulate_event(event=chunk, current_snapshot=full_message) except RuntimeError as e: logger.debug(f"Error accumulating event in Anthropic Wrapper: {e}") return {"output": all_chunks} if full_message is None: return {"output": all_chunks} return _message_to_outputs(full_message) def _reduce_completions(all_chunks: list[Completion]) -> dict: all_content = [] for chunk in all_chunks: content = chunk.completion if content is not None: all_content.append(content) content = "".join(all_content) if all_chunks: d = all_chunks[-1].model_dump() d["choices"] = [{"text": content}] else: d = {"choices": [{"text": content}]} return d def _process_chat_completion(outputs: Any): try: # Check if outputs is a LegacyAPIResponse wrapper (from with_raw_response). # The Anthropic SDK's LegacyAPIResponse wraps the actual response object. # Call .parse() to extract the Message for tracing. # See: anthropics/anthropic-sdk-python _legacy_response.py#L102 if hasattr(outputs, "parse") and callable(outputs.parse): try: outputs = outputs.parse() except Exception: pass return _message_to_outputs(outputs) except BaseException as e: logger.debug(f"Error processing chat completion: {e}") return {"output": outputs} def _get_wrapper( original_create: Callable, name: str, reduce_fn: Callable, prepopulated_invocation_params: dict, tracing_extra: TracingExtra, ) -> Callable: @functools.wraps(original_create) def create(*args, **kwargs): stream = kwargs.get("stream") decorator = run_helpers.traceable( name=name, run_type="llm", reduce_fn=reduce_fn if stream else None, process_inputs=_strip_not_given, process_outputs=_process_chat_completion, _invocation_params_fn=functools.partial( _infer_ls_params, prepopulated_invocation_params ), **tracing_extra, ) result = decorator(original_create)(*args, **kwargs) return result @functools.wraps(original_create) async def acreate(*args, **kwargs): stream = kwargs.get("stream") decorator = run_helpers.traceable( name=name, run_type="llm", reduce_fn=reduce_fn if stream else None, process_inputs=_strip_not_given, process_outputs=_process_chat_completion, _invocation_params_fn=functools.partial( _infer_ls_params, prepopulated_invocation_params ), **tracing_extra, ) result = await decorator(original_create)(*args, **kwargs) return result return acreate if run_helpers.is_async(original_create) else create def _get_stream_wrapper( original_stream: Callable, name: str, prepopulated_invocation_params: dict, tracing_extra: TracingExtra, ) -> Callable: """Create a wrapper for Anthropic's streaming context manager.""" is_async = "async" in str(original_stream).lower() configured_traceable = run_helpers.traceable( name=name, reduce_fn=_reduce_chat_chunks, run_type="llm", process_inputs=_strip_not_given, _invocation_params_fn=functools.partial( _infer_ls_params, prepopulated_invocation_params ), **tracing_extra, ) configured_traceable_text = run_helpers.traceable( name=name, run_type="llm", process_inputs=_strip_not_given, process_outputs=_process_chat_completion, _invocation_params_fn=functools.partial( _infer_ls_params, prepopulated_invocation_params ), **tracing_extra, ) if is_async: class AsyncMessageStreamWrapper: def __init__( self, wrapped: AsyncMessageStream, **kwargs, ) -> None: self._wrapped = wrapped self._kwargs = kwargs @property def text_stream(self): @configured_traceable_text async def _text_stream(**_): async for chunk in self._wrapped.text_stream: yield chunk run_tree = run_helpers.get_current_run_tree() final_message = await self._wrapped.get_final_message() outputs = _message_to_outputs(final_message) run_tree.outputs = outputs if usage := outputs.get("usage_metadata"): run_tree.metadata["usage_metadata"] = usage return _text_stream(**self._kwargs) @property def response(self) -> httpx.Response: return self._wrapped.response @property def request_id(self) -> str | None: return self._wrapped.request_id async def __anext__(self) -> MessageStreamEvent: aiter = self.__aiter__() return await aiter.__anext__() async def __aiter__(self) -> AsyncIterator[MessageStreamEvent]: @configured_traceable def traced_iter(**_): return self._wrapped.__aiter__() async for chunk in traced_iter(**self._kwargs): yield chunk async def __aenter__(self) -> Self: await self._wrapped.__aenter__() return self async def __aexit__(self, *exc) -> None: await self._wrapped.__aexit__(*exc) async def close(self) -> None: await self._wrapped.close() async def get_final_message(self) -> Message: return await self._wrapped.get_final_message() async def get_final_text(self) -> str: return await self._wrapped.get_final_text() async def until_done(self) -> None: await self._wrapped.until_done() @property def current_message_snapshot(self) -> Message: return self._wrapped.current_message_snapshot class AsyncMessagesStreamManagerWrapper: def __init__(self, **kwargs): self._kwargs = kwargs async def __aenter__(self): self._manager = original_stream(**self._kwargs) stream = await self._manager.__aenter__() return AsyncMessageStreamWrapper(stream, **self._kwargs) async def __aexit__(self, *exc): await self._manager.__aexit__(*exc) return AsyncMessagesStreamManagerWrapper else: class MessageStreamWrapper: def __init__( self, wrapped: MessageStream, **kwargs, ) -> None: self._wrapped = wrapped self._kwargs = kwargs @property def response(self) -> Any: return self._wrapped.response @property def request_id(self) -> str | None: return self._wrapped.request_id # type: ignore[no-any-return] @property def text_stream(self): @configured_traceable_text def _text_stream(**_): yield from self._wrapped.text_stream run_tree = run_helpers.get_current_run_tree() final_message = self._wrapped.get_final_message() outputs = _message_to_outputs(final_message) run_tree.outputs = outputs if usage := outputs.get("usage_metadata"): run_tree.metadata["usage_metadata"] = usage return _text_stream(**self._kwargs) def __next__(self) -> MessageStreamEvent: return self.__iter__().__next__() def __iter__(self): @configured_traceable def traced_iter(**_): return self._wrapped.__iter__() return traced_iter(**self._kwargs) def __enter__(self) -> Self: self._wrapped.__enter__() return self def __exit__(self, *exc) -> None: self._wrapped.__exit__(*exc) def close(self) -> None: self._wrapped.close() def get_final_message(self) -> Message: return self._wrapped.get_final_message() def get_final_text(self) -> str: return self._wrapped.get_final_text() def until_done(self) -> None: return self._wrapped.until_done() @property def current_message_snapshot(self) -> Message: return self._wrapped.current_message_snapshot class MessagesStreamManagerWrapper: def __init__(self, **kwargs): self._kwargs = kwargs def __enter__(self): self._manager = original_stream(**self._kwargs) return MessageStreamWrapper(self._manager.__enter__(), **self._kwargs) def __exit__(self, *exc): self._manager.__exit__(*exc) return MessagesStreamManagerWrapper class TracingExtra(TypedDict, total=False): metadata: Optional[Mapping[str, Any]] tags: Optional[list[str]] client: Optional[ls_client.Client] def wrap_anthropic( client: C, *, tracing_extra: Optional[TracingExtra] = None, chat_name: str = "ChatAnthropic", completions_name: str = "Anthropic", ) -> C: """Patch the Anthropic client to make it traceable. Args: client: The client to patch. tracing_extra: Extra tracing information. chat_name: The run name for the messages endpoint. completions_name: The run name for the completions endpoint. Returns: The patched client. Example: ```python import anthropic from langsmith import wrappers client = wrappers.wrap_anthropic(anthropic.Anthropic()) # Use Anthropic client same as you normally would: system = "You are a helpful assistant." messages = [ { "role": "user", "content": "What physics breakthroughs do you predict will happen by 2300?", } ] completion = client.messages.create( model="claude-3-5-sonnet-latest", messages=messages, max_tokens=1000, system=system, ) print(completion.content) # With raw response to access headers: raw_response = client.messages.with_raw_response.create( model="claude-3-5-sonnet-latest", messages=messages, max_tokens=1000, system=system, ) print(raw_response.headers) # Access HTTP headers message = raw_response.parse() # Get parsed response # You can also use the streaming context manager: with client.messages.stream( model="claude-3-5-sonnet-latest", messages=messages, max_tokens=1000, system=system, ) as stream: for text in stream.text_stream: print(text, end="", flush=True) message = stream.get_final_message() ``` """ # noqa: E501 tracing_extra = tracing_extra or {} # Extract ls_invocation_params from metadata metadata = dict(tracing_extra.get("metadata") or {}) prepopulated_invocation_params = metadata.pop("ls_invocation_params", {}) # Create new tracing_extra without ls_invocation_params in metadata tracing_extra_rest: TracingExtra = { # type: ignore[assignment] k: v for k, v in tracing_extra.items() if k != "metadata" } if metadata: tracing_extra_rest["metadata"] = metadata # type: ignore[typeddict-item] client.messages.create = _get_wrapper( # type: ignore[method-assign] client.messages.create, chat_name, _reduce_chat_chunks, prepopulated_invocation_params, tracing_extra_rest, ) client.messages.stream = _get_stream_wrapper( # type: ignore[method-assign] client.messages.stream, chat_name, prepopulated_invocation_params, tracing_extra_rest, ) client.completions.create = _get_wrapper( # type: ignore[method-assign] client.completions.create, completions_name, _reduce_completions, prepopulated_invocation_params, tracing_extra_rest, ) if ( hasattr(client, "beta") and hasattr(client.beta, "messages") and hasattr(client.beta.messages, "create") ): client.beta.messages.create = _get_wrapper( # type: ignore[method-assign] client.beta.messages.create, # type: ignore chat_name, _reduce_chat_chunks, prepopulated_invocation_params, tracing_extra_rest, ) return client