initial commit
This commit is contained in:
@@ -0,0 +1,83 @@
|
||||
"""LangSmith integration for Claude Agent SDK.
|
||||
|
||||
This module provides automatic tracing for the Claude Agent SDK by instrumenting
|
||||
`ClaudeSDKClient` and injecting hooks to trace all tool calls.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
from ._client import instrument_claude_client
|
||||
from ._config import set_tracing_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["configure_claude_agent_sdk"]
|
||||
|
||||
|
||||
def configure_claude_agent_sdk(
    name: Optional[str] = None,
    project_name: Optional[str] = None,
    metadata: Optional[dict] = None,
    tags: Optional[list[str]] = None,
) -> bool:
    """Enable LangSmith tracing for the Claude Agent SDK by patching entry points.

    This function instruments the Claude Agent SDK to automatically trace:

    - Chain runs for each conversation stream (via `ClaudeSDKClient`)
    - Model runs for each assistant turn
    - All tool calls including built-in tools, external MCP tools, and SDK MCP tools

    Tool tracing is implemented via `PreToolUse`, `PostToolUse` and
    `PostToolUseFailure` hooks.

    Calling this function more than once is safe: an already instrumented
    client class is returned unchanged by `instrument_claude_client`, and the
    stored tracing configuration is simply replaced.

    Args:
        name: Name of the root trace.
        project_name: LangSmith project to trace to.
        metadata: Metadata to associate with all traces.
        tags: Tags to associate with all traces.

    Returns:
        `True` if configuration was successful, `False` otherwise (the SDK is
        not installed or does not expose `ClaudeSDKClient`).

    Example:
        >>> from langsmith.integrations.claude_agent_sdk import (
        ...     configure_claude_agent_sdk,
        ... )
        >>> configure_claude_agent_sdk(
        ...     project_name="my-project", tags=["production"]
        ... )  # doctest: +SKIP
        >>> # Now use claude_agent_sdk as normal - tracing is automatic
    """  # noqa: E501
    try:
        import claude_agent_sdk  # type: ignore[import-not-found]
    except ImportError:
        logger.warning("Claude Agent SDK not installed.")
        return False

    # Single lookup: a missing (or None) attribute is reported once, and the
    # global tracing config is left untouched when nothing can be patched.
    original = getattr(claude_agent_sdk, "ClaudeSDKClient", None)
    if original is None:
        logger.warning("Claude Agent SDK missing ClaudeSDKClient.")
        return False

    set_tracing_config(
        name=name,
        project_name=project_name,
        metadata=metadata,
        tags=tags,
    )

    wrapped = instrument_claude_client(original)
    claude_agent_sdk.ClaudeSDKClient = wrapped

    # Modules that did `from claude_agent_sdk import ClaudeSDKClient` before
    # this call hold their own reference to the original class; rebind those.
    for module in list(sys.modules.values()):
        try:
            if (
                module is not None
                and getattr(module, "ClaudeSDKClient", None) is original
            ):
                setattr(module, "ClaudeSDKClient", wrapped)
        except Exception:
            # Deliberately best-effort: attribute access on arbitrary modules
            # (lazy loaders, proxies) may raise, and one such module must not
            # prevent the remaining ones from being patched.
            continue

    return True
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,497 @@
|
||||
"""Client instrumentation for Claude Agent SDK."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterable
|
||||
from datetime import datetime, timezone
|
||||
from functools import cache
|
||||
from typing import Any, Optional
|
||||
|
||||
from langsmith.run_helpers import get_current_run_tree, trace
|
||||
|
||||
from ._hooks import (
|
||||
clear_active_tool_runs,
|
||||
post_tool_use_failure_hook,
|
||||
post_tool_use_hook,
|
||||
pre_tool_use_hook,
|
||||
)
|
||||
from ._messages import (
|
||||
build_llm_input,
|
||||
extract_usage_from_result_message,
|
||||
flatten_content_blocks,
|
||||
)
|
||||
from ._tools import clear_parent_run_tree, get_parent_run_tree, set_parent_run_tree
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TRACE_CHAIN_NAME = "claude.conversation"
|
||||
|
||||
|
||||
@cache
|
||||
def _get_package_version(package_name: str) -> str | None:
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
return version(package_name)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
LLM_RUN_NAME = "claude.assistant.turn"
|
||||
|
||||
|
||||
class TurnLifecycle:
    """Track ongoing model runs so consecutive messages are recorded correctly.

    At most one LLM run is open at a time. Starting a new run (or calling
    `close`) ends and patches the previous one.
    """

    def __init__(self, query_start_time: Optional[float] = None):
        """Create a tracker.

        Args:
            query_start_time: Epoch timestamp used as the start time of the
                first model run; when omitted the run starts "now".
        """
        # The currently open LLM run, if any.
        self.current_run: Optional[Any] = None
        # Start timestamp to use for the next model run, if one was marked.
        self.next_start_time: Optional[float] = query_start_time

    def start_llm_run(
        self,
        message: Any,
        prompt: Any,
        history: list[dict[str, Any]],
        parent: Optional[Any] = None,
    ) -> Optional[dict[str, Any]]:
        """Begin a new model run, ending any existing one.

        Args:
            message: The assistant message that opens the run.
            prompt: Original prompt, forwarded to the run's inputs.
            history: Messages collected so far, forwarded to the run's inputs.
            parent: Optional explicit parent run for the new run.

        Returns:
            The value returned by `begin_llm_run_from_assistant_messages` for
            the message content (``None`` when no run was created).
        """
        start = self.next_start_time or time.time()

        # Finish the previous turn first; `close` also clears `current_run`,
        # so a failure below cannot leave an already-ended run registered.
        self.close()

        final_output, run = begin_llm_run_from_assistant_messages(
            [message], prompt, history, start_time=start, parent=parent
        )
        self.current_run = run
        self.next_start_time = None
        return final_output

    def mark_next_start(self) -> None:
        """Mark when the next assistant message will start."""
        self.next_start_time = time.time()

    def add_usage(self, metrics: dict[str, Any]) -> None:
        """Attach token usage details to the current run.

        Does nothing when there is no open run or ``metrics`` is empty.
        """
        if not (self.current_run and metrics):
            return
        meta = self.current_run.extra.setdefault("metadata", {}).setdefault(
            "usage_metadata", {}
        )
        meta.update(metrics)

    def close(self) -> None:
        """End any open run gracefully.

        Failures while ending or patching the run are logged rather than
        raised: this method is called from cleanup code, and an exception
        here would prevent the rest of that cleanup from running.
        """
        run, self.current_run = self.current_run, None
        if not run:
            return
        try:
            run.end()
            run.patch()
        except Exception as e:
            logger.warning(f"Failed to finalize LLM run: {e}")
|
||||
|
||||
|
||||
def begin_llm_run_from_assistant_messages(
    messages: list[Any],
    prompt: Any,
    history: list[dict[str, Any]],
    start_time: Optional[float] = None,
    parent: Optional[Any] = None,
) -> tuple[Optional[dict[str, Any]], Optional[Any]]:
    """Create a traced model run from assistant messages.

    Args:
        messages: Candidate messages; the last one must be an object whose
            class is named ``AssistantMessage`` for a run to be created.
        prompt: Prompt forwarded to `build_llm_input`.
        history: Prior messages forwarded to `build_llm_input`.
        start_time: Optional epoch timestamp for the run's start time.
        parent: Run to attach the new run to. When ``None``, the stored
            parent run tree or the current run tree is used instead.

    Returns:
        ``(final_content, llm_run)``. Both are ``None`` when the messages do
        not end in an assistant message or no parent run is available.
        ``final_content`` is ``None`` when the last message has no
        ``content`` attribute.
    """
    if not messages:
        return None, None

    final_message = messages[-1]
    # The SDK type is matched by class name so the SDK need not be imported.
    if type(final_message).__name__ != "AssistantMessage":
        return None, None

    model_name = getattr(final_message, "model", None)

    if parent is None:
        parent = get_parent_run_tree() or get_current_run_tree()
    if not parent:
        return None, None

    llm_inputs = build_llm_input(prompt, history)

    assistant_outputs = []
    for candidate in messages:
        if hasattr(candidate, "content"):
            assistant_outputs.append(
                {
                    "content": flatten_content_blocks(candidate.content),
                    "role": "assistant",
                }
            )

    run_inputs = {"messages": llm_inputs} if llm_inputs else {}
    run_extra = {"metadata": {"ls_model_name": model_name}} if model_name else {}
    started_at = None
    if start_time:
        started_at = datetime.fromtimestamp(start_time, tz=timezone.utc)

    llm_run = parent.create_child(
        name=LLM_RUN_NAME,
        run_type="llm",
        inputs=run_inputs,
        extra=run_extra,
        start_time=started_at,
    )

    try:
        llm_run.post()
    except Exception as e:
        logger.warning(f"Failed to post LLM run: {e}")

    # Set outputs after posting so they are sent with end_time on the patch.
    if len(assistant_outputs) == 1:
        llm_run.outputs = assistant_outputs[-1]
    else:
        llm_run.outputs = {"content": assistant_outputs}

    final_content = None
    if hasattr(final_message, "content"):
        final_content = {
            "content": flatten_content_blocks(final_message.content),
            "role": "assistant",
        }
    return final_content, llm_run
|
||||
|
||||
|
||||
def _inject_tracing_hooks(options: Any) -> None:
|
||||
"""Inject LangSmith tracing hooks into ClaudeAgentOptions."""
|
||||
if not hasattr(options, "hooks"):
|
||||
return
|
||||
|
||||
# Initialize hooks dict if not present
|
||||
if options.hooks is None:
|
||||
options.hooks = {}
|
||||
|
||||
for event in ("PreToolUse", "PostToolUse", "PostToolUseFailure"):
|
||||
if event not in options.hooks:
|
||||
options.hooks[event] = []
|
||||
|
||||
try:
|
||||
from claude_agent_sdk import HookMatcher # type: ignore[import-not-found]
|
||||
|
||||
langsmith_pre_matcher = HookMatcher(matcher=None, hooks=[pre_tool_use_hook])
|
||||
langsmith_post_matcher = HookMatcher(matcher=None, hooks=[post_tool_use_hook])
|
||||
langsmith_failure_matcher = HookMatcher(
|
||||
matcher=None, hooks=[post_tool_use_failure_hook]
|
||||
)
|
||||
|
||||
options.hooks["PreToolUse"].insert(0, langsmith_pre_matcher)
|
||||
options.hooks["PostToolUse"].insert(0, langsmith_post_matcher)
|
||||
options.hooks["PostToolUseFailure"].insert(0, langsmith_failure_matcher)
|
||||
|
||||
logger.debug("Injected LangSmith tracing hooks into ClaudeAgentOptions")
|
||||
except ImportError:
|
||||
logger.warning("Failed to import HookMatcher from claude_agent_sdk")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to inject tracing hooks: {e}")
|
||||
|
||||
|
||||
def _unwrap_streamed_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Unwrap streaming input messages for trace display."""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
formatted = []
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
formatted.append(msg)
|
||||
continue
|
||||
|
||||
if "message" in msg:
|
||||
inner = msg["message"]
|
||||
if isinstance(inner, dict):
|
||||
formatted.append(
|
||||
{
|
||||
"role": inner.get("role", "user"),
|
||||
"content": inner.get("content", ""),
|
||||
}
|
||||
)
|
||||
else:
|
||||
formatted.append(msg)
|
||||
else:
|
||||
formatted.append(msg)
|
||||
|
||||
return formatted
|
||||
|
||||
|
||||
def instrument_claude_client(original_class: Any) -> Any:
    """Wrap `ClaudeSDKClient` to trace both `query()` and `receive_response()`.

    The returned subclass injects the LangSmith tool hooks into the options
    given to the constructor, remembers the prompt passed to `query()`, and
    records a chain run (with child model, tool and subagent runs) around
    every `receive_response()` stream.

    Args:
        original_class: The client class to wrap.

    Returns:
        The traced subclass, or ``original_class`` itself when it has already
        been instrumented.
    """
    if getattr(original_class, "_langsmith_instrumented", False):
        return original_class  # Already wrapped, avoid double-tracing

    class TracedClaudeSDKClient(original_class):
        _langsmith_instrumented = True

        def __init__(self, *args: Any, **kwargs: Any):
            # Inject LangSmith tracing hooks into options before initialization
            options = kwargs.get("options") or (args[0] if args else None)
            if options:
                _inject_tracing_hooks(options)

            super().__init__(*args, **kwargs)
            self._prompt: Optional[str] = None
            self._start_time: Optional[float] = None
            self._streamed_input: Optional[list[dict[str, Any]]] = None

        async def query(self, *args: Any, **kwargs: Any) -> Any:
            """Capture prompt and start time, wrapping generators if needed."""
            self._start_time = time.time()
            self._streamed_input = None
            prompt = args[0] if args else kwargs.get("prompt")

            if prompt is None:
                pass
            elif isinstance(prompt, str):
                self._prompt = prompt
            elif isinstance(prompt, AsyncIterable):
                collector: list[dict[str, Any]] = []
                self._streamed_input = collector
                self._prompt = None

                # Pass-through wrapper that records every streamed message
                # as the wrapped client consumes it.
                async def _gen_wrapper() -> AsyncGenerator[dict[str, Any], None]:
                    async for msg in prompt:
                        collector.append(msg)
                        yield msg

                if args:
                    args = (_gen_wrapper(),) + args[1:]
                else:
                    kwargs["prompt"] = _gen_wrapper()
            else:
                self._prompt = str(prompt)

            return await super().query(*args, **kwargs)

        def _handle_assistant_tool_uses(
            self,
            msg: Any,
            run: Any,
            subagent_sessions: dict[str, Any],
        ) -> None:
            """Process tool uses for an assistant message.

            A top-level ``Task`` tool use opens a subagent chain run under
            ``run``; a tool use whose message belongs to a known subagent
            opens a tool run under that subagent. Both are registered as
            client-managed so the hooks finish them instead of creating
            their own runs. Other tool uses are left to the hooks.
            """
            if not hasattr(msg, "content"):
                return

            from ._hooks import _client_managed_runs

            parent_tool_use_id = getattr(msg, "parent_tool_use_id", None)

            for block in msg.content:
                if type(block).__name__ != "ToolUseBlock":
                    continue

                try:
                    tool_use_id = getattr(block, "id", None)
                    tool_name = getattr(block, "name", "unknown_tool")
                    # A block whose `input` is None is treated as empty input
                    # instead of failing on `.get` below.
                    tool_input = getattr(block, "input", None) or {}

                    if not tool_use_id:
                        continue

                    start_time = time.time()

                    # Check if this is a Task tool (subagent)
                    if tool_name == "Task" and not parent_tool_use_id:
                        # Extract subagent name. `split()` yields no words for
                        # a whitespace-only description, so guard the index.
                        description_words = (
                            tool_input.get("description") or ""
                        ).split()
                        subagent_name = (
                            tool_input.get("subagent_type")
                            or (description_words[0] if description_words else None)
                            or "unknown-agent"
                        )

                        subagent_session = run.create_child(
                            name=subagent_name,
                            run_type="chain",
                            inputs=tool_input,
                            start_time=datetime.fromtimestamp(
                                start_time, tz=timezone.utc
                            ),
                        )
                        subagent_session.post()
                        subagent_sessions[tool_use_id] = subagent_session

                        _client_managed_runs[tool_use_id] = subagent_session

                    # Check if tool use is within a subagent
                    elif parent_tool_use_id and parent_tool_use_id in subagent_sessions:
                        subagent_session = subagent_sessions[parent_tool_use_id]
                        # Create tool run as child of subagent
                        tool_run = subagent_session.create_child(
                            name=tool_name,
                            run_type="tool",
                            inputs={"input": tool_input} if tool_input else {},
                            start_time=datetime.fromtimestamp(
                                start_time,
                                tz=timezone.utc,
                            ),
                        )
                        tool_run.post()
                        _client_managed_runs[tool_use_id] = tool_run

                except Exception as e:
                    logger.warning(f"Failed to create client-managed tool run: {e}")

        def _trace_message(
            self,
            msg: Any,
            run: Any,
            tracker: Any,
            collected: list[dict[str, Any]],
            subagent_sessions: dict[str, Any],
            prompt_for_llm: Any,
        ) -> None:
            """Record one streamed message on the trace (dispatch by class name)."""
            msg_type = type(msg).__name__
            if msg_type == "AssistantMessage":
                # Check if this message belongs to a subagent
                parent_tool_use_id = getattr(msg, "parent_tool_use_id", None)
                llm_parent = (
                    subagent_sessions.get(parent_tool_use_id)
                    if parent_tool_use_id
                    else None
                )

                content = tracker.start_llm_run(
                    msg, prompt_for_llm, collected, parent=llm_parent
                )
                if content:
                    collected.append(content)

                # Process tool uses in this AssistantMessage
                self._handle_assistant_tool_uses(msg, run, subagent_sessions)
            elif msg_type == "UserMessage":
                if hasattr(msg, "content"):
                    # Check if this is a tool result message
                    flattened = flatten_content_blocks(msg.content)
                    if (
                        isinstance(flattened, list)
                        and flattened
                        and isinstance(flattened[0], dict)
                        and flattened[0].get("type") == "tool_result"
                    ):
                        # Format each tool result as a separate message
                        for block in flattened:
                            collected.append(
                                {
                                    "role": "tool",
                                    "content": block.get("content", ""),
                                    "tool_call_id": block.get("tool_use_id"),
                                }
                            )
                    else:
                        collected.append({"content": flattened, "role": "user"})
                tracker.mark_next_start()
            elif msg_type == "ResultMessage":
                # Add usage metrics including cost
                if hasattr(msg, "usage"):
                    usage = extract_usage_from_result_message(msg)
                    # Add total_cost to usage_metadata if available
                    if (
                        hasattr(msg, "total_cost_usd")
                        and msg.total_cost_usd is not None
                    ):
                        usage["total_cost"] = msg.total_cost_usd
                    tracker.add_usage(usage)

                # Add conversation-level metadata
                candidates = {
                    "num_turns": getattr(msg, "num_turns", None),
                    "session_id": getattr(msg, "session_id", None),
                    "duration_ms": getattr(msg, "duration_ms", None),
                    "duration_api_ms": getattr(msg, "duration_api_ms", None),
                    "is_error": getattr(msg, "is_error", None),
                }
                meta = {k: v for k, v in candidates.items() if v is not None}
                if meta:
                    run.metadata.update(meta)

        async def receive_response(self) -> AsyncGenerator[Any, None]:
            """Intercept message stream and record chain run activity.

            Every message from the wrapped client is yielded unchanged.
            Failures inside the tracing code are logged and do not interrupt
            the stream; errors raised by the wrapped stream itself propagate
            to the caller.
            """
            from ._config import get_tracing_config

            messages = super().receive_response()

            # Settings stored by `set_tracing_config`.
            config = get_tracing_config()

            # Capture configuration in inputs and metadata
            trace_inputs: dict[str, Any] = {}
            # User-supplied metadata first, so integration keys are not
            # overridden by it.
            trace_metadata: dict[str, Any] = {
                **(config.get("metadata") or {}),
                "ls_integration": "claude-agent-sdk",
                "ls_integration_version": _get_package_version("claude_agent_sdk"),
            }

            # Track if we need to update input from captured streaming messages
            awaiting_streamed_input = self._streamed_input is not None

            # Add prompt to inputs (for string prompts)
            if self._prompt:
                trace_inputs["prompt"] = self._prompt

            # Add system_prompt to inputs if available
            if hasattr(self, "options") and self.options:
                system_prompt = getattr(self.options, "system_prompt", None)
                if system_prompt:
                    if isinstance(system_prompt, str):
                        trace_inputs["system"] = system_prompt
                    elif (
                        isinstance(system_prompt, dict)
                        and system_prompt.get("type") == "preset"
                    ):
                        # Handle SystemPromptPreset format
                        preset_text = (
                            f"preset: {system_prompt.get('preset', 'claude_code')}"
                        )
                        if "append" in system_prompt:
                            preset_text += f"\nappend: {system_prompt['append']}"
                        trace_inputs["system"] = preset_text
                    else:
                        # Any other shape is recorded as-is.
                        trace_inputs["system"] = system_prompt

                # Add other config to metadata
                for attr in ["model", "permission_mode", "max_turns"]:
                    val = getattr(self.options, attr, None)
                    if val is not None:
                        trace_metadata[attr] = val

            async with trace(
                name=config.get("name") or TRACE_CHAIN_NAME,
                run_type="chain",
                inputs=trace_inputs,
                project_name=config.get("project_name"),
                metadata=trace_metadata,
                tags=config.get("tags"),
            ) as run:
                set_parent_run_tree(run)
                tracker = TurnLifecycle(self._start_time)
                collected: list[dict[str, Any]] = []

                # Track subagent sessions by Task tool_use_id
                subagent_sessions: dict[str, Any] = {}

                prompt_for_llm: Any = self._prompt

                try:
                    async for msg in messages:
                        # Only the tracing work is guarded. The `async for`
                        # and the `yield` stay outside so stream errors and
                        # exceptions thrown in by the consumer are not hidden.
                        try:
                            if awaiting_streamed_input and self._streamed_input:
                                unwrapped_messages = _unwrap_streamed_messages(
                                    self._streamed_input
                                )
                                if unwrapped_messages:
                                    run.inputs["messages"] = unwrapped_messages
                                prompt_for_llm = self._streamed_input
                                awaiting_streamed_input = False

                            self._trace_message(
                                msg,
                                run,
                                tracker,
                                collected,
                                subagent_sessions,
                                prompt_for_llm,
                            )
                        except Exception:
                            logger.exception(
                                "Error while tracing Claude Agent stream"
                            )

                        yield msg

                    try:
                        run.end(outputs=collected[-1] if collected else None)
                    except Exception:
                        logger.exception("Error while tracing Claude Agent stream")
                finally:
                    tracker.close()
                    clear_parent_run_tree()
                    clear_active_tool_runs()

        async def __aenter__(self) -> "TracedClaudeSDKClient":
            await super().__aenter__()
            return self

        async def __aexit__(self, *args: Any) -> None:
            await super().__aexit__(*args)

    return TracedClaudeSDKClient
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Configuration management for Claude Agent SDK tracing."""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
# Global configuration for tracing
|
||||
# Process-wide tracing settings, replaced wholesale by `set_tracing_config`.
_tracing_config: dict[str, Any] = {
    "name": None,
    "project_name": None,
    "metadata": None,
    "tags": None,
}


def set_tracing_config(
    name: Optional[str] = None,
    project_name: Optional[str] = None,
    metadata: Optional[dict] = None,
    tags: Optional[list[str]] = None,
) -> None:
    """Set the global tracing configuration for Claude Agent SDK.

    All four settings are replaced on every call; omitted arguments reset the
    corresponding setting to ``None``. ``metadata`` and ``tags`` are copied,
    so later changes to the caller's objects do not alter the configuration.

    Args:
        name: Name of the root trace.
        project_name: LangSmith project to trace to.
        metadata: Metadata to associate with all traces.
        tags: Tags to associate with all traces.
    """
    global _tracing_config
    _tracing_config = {
        "name": name,
        "project_name": project_name,
        "metadata": dict(metadata) if metadata is not None else None,
        "tags": list(tags) if tags is not None else None,
    }


def get_tracing_config() -> dict[str, Any]:
    """Get the current tracing configuration.

    Returns:
        A snapshot with the keys ``name``, ``project_name``, ``metadata`` and
        ``tags``. The ``metadata`` dict and ``tags`` list are copies, so
        modifying the result does not change the stored configuration.
    """
    snapshot = _tracing_config.copy()
    if snapshot.get("metadata") is not None:
        snapshot["metadata"] = dict(snapshot["metadata"])
    if snapshot.get("tags") is not None:
        snapshot["tags"] = list(snapshot["tags"])
    return snapshot
|
||||
@@ -0,0 +1,253 @@
|
||||
"""Hook-based tool tracing for Claude Agent SDK."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from langsmith.run_helpers import get_current_run_tree
|
||||
from langsmith.run_trees import RunTree
|
||||
|
||||
from ._tools import get_parent_run_tree
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from claude_agent_sdk import (
|
||||
HookContext,
|
||||
HookInput,
|
||||
HookJSONOutput,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Key: tool_use_id, Value: (run_tree, start_time)
|
||||
_active_tool_runs: dict[str, tuple[Any, float]] = {}
|
||||
|
||||
# Storage for tool or subagent runs managed by client
|
||||
# Key: tool_use_id, Value: run_tree
|
||||
_client_managed_runs: dict[str, RunTree] = {}
|
||||
|
||||
|
||||
async def pre_tool_use_hook(
    input_data: "HookInput",
    tool_use_id: Optional[str],
    context: "HookContext",
) -> "HookJSONOutput":
    """Trace tool execution before it starts.

    Opens a ``tool`` run under the stored parent run tree (falling back to
    the current run tree) and registers it in ``_active_tool_runs`` under the
    tool use id. Nothing is recorded when the id is missing, when the run is
    already client-managed, or when no parent run is available.

    Args:
        input_data: Hook payload; `tool_name` and `tool_input` are read.
        tool_use_id: Unique identifier for this tool invocation
        context: Hook context (unused here)

    Returns:
        Always an empty dict, which allows execution to proceed.
    """
    # Nothing to trace without an id; client-managed runs are already open.
    if not tool_use_id or tool_use_id in _client_managed_runs:
        return {}

    tool_name: str = str(input_data.get("tool_name", "unknown_tool"))
    tool_input = input_data.get("tool_input", {})

    try:
        parent = get_parent_run_tree() or get_current_run_tree()
        if parent:
            started = time.time()
            run_inputs = {"input": tool_input} if tool_input else {}
            tool_run = parent.create_child(
                name=tool_name,
                run_type="tool",
                inputs=run_inputs,
                start_time=datetime.fromtimestamp(started, tz=timezone.utc),
            )

            # A failed post must not stop the run from being registered.
            try:
                tool_run.post()
            except Exception as e:
                logger.warning(f"Failed to post tool run for {tool_name}: {e}")

            _active_tool_runs[tool_use_id] = (tool_run, started)
    except Exception as e:
        logger.warning(f"Error in PreToolUse hook for {tool_name}: {e}", exc_info=True)

    return {}
|
||||
|
||||
|
||||
def _tool_response_to_outputs(
    tool_response: Any,
) -> tuple[dict[str, Any], Optional[str]]:
    """Convert a hook ``tool_response`` into run outputs and an optional error.

    Dicts are used as outputs directly, lists are wrapped under ``content``,
    and any other truthy value is stringified under ``output``. An error is
    reported only for a dict response with a truthy ``is_error``.
    """
    if isinstance(tool_response, dict):
        if not tool_response.get("is_error", False):
            return tool_response, None
        # The response dict rarely has an "output" key, so relying on it
        # alone would leave failed runs unmarked. "error" and "content" are
        # plausible carriers of the message (NOTE(review): confirm against
        # the SDK's response schema); the generic text guarantees the run is
        # flagged regardless.
        detail = (
            tool_response.get("output")
            or tool_response.get("error")
            or tool_response.get("content")
        )
        return tool_response, str(detail) if detail else "Tool reported an error"
    if isinstance(tool_response, list):
        return {"content": tool_response}, None
    return ({"output": str(tool_response)} if tool_response else {}), None


async def post_tool_use_hook(
    input_data: "HookInput",
    tool_use_id: Optional[str],
    context: "HookContext",
) -> "HookJSONOutput":
    """Trace tool execution after it completes.

    Ends and patches the run registered for ``tool_use_id``: a client-managed
    run if one exists, otherwise the run opened by the pre-tool hook. A dict
    response with a truthy ``is_error`` marks the run as failed.

    Args:
        input_data: Hook payload; `tool_name` and `tool_response` are read.
        tool_use_id: Unique identifier for this tool invocation
        context: Hook context (unused here)

    Returns:
        Always an empty `dict`.
    """  # noqa: E501
    if not tool_use_id:
        return {}

    tool_name: str = str(input_data.get("tool_name", "unknown_tool"))
    tool_response = input_data.get("tool_response")

    # Check if this is a client-managed run
    run_tree = _client_managed_runs.pop(tool_use_id, None)
    if run_tree:
        # This run is managed by the client (subagent session or its tools)
        try:
            outputs, error = _tool_response_to_outputs(tool_response)
            run_tree.end(outputs=outputs, error=error)
            run_tree.patch()
        except Exception as e:
            logger.warning(f"Failed to update client-managed run: {e}")
        return {}

    try:
        run_info = _active_tool_runs.pop(tool_use_id, None)
        if not run_info:
            return {}

        tool_run, _start_time = run_info

        outputs, error = _tool_response_to_outputs(tool_response)
        tool_run.end(outputs=outputs, error=error)

        try:
            tool_run.patch()
        except Exception as e:
            logger.warning(f"Failed to patch tool run for {tool_name}: {e}")

    except Exception as e:
        logger.warning(f"Error in PostToolUse hook for {tool_name}: {e}", exc_info=True)

    return {}
|
||||
|
||||
|
||||
async def post_tool_use_failure_hook(
    input_data: "HookInput",
    tool_use_id: Optional[str],
    context: "HookContext",
) -> "HookJSONOutput":
    """Trace tool execution when it fails.

    Ends the run registered for ``tool_use_id`` with the reported error: a
    client-managed run (subagent or one of its tools) if present, otherwise
    the run opened by the pre-tool hook. Unknown ids are ignored.

    Args:
        input_data: Hook payload; ``tool_name`` and ``error`` are read.
        tool_use_id: Unique identifier for this tool invocation
        context: Hook context (unused here)

    Returns:
        Always an empty dict.
    """
    if not tool_use_id:
        return {}

    tool_name: str = str(input_data.get("tool_name", "unknown_tool"))
    error: str = str(input_data.get("error", "Unknown error"))

    # Client-managed runs take precedence over hook-created runs.
    managed_run = _client_managed_runs.pop(tool_use_id, None)
    if managed_run:
        try:
            managed_run.end(outputs={"error": error}, error=error)
            managed_run.patch()
        except Exception as e:
            logger.warning(f"Failed to update client-managed run on failure: {e}")
        return {}

    try:
        entry = _active_tool_runs.pop(tool_use_id, None)
        if entry:
            tool_run, _started = entry
            tool_run.end(outputs={"error": error}, error=error)

            try:
                tool_run.patch()
            except Exception as e:
                logger.warning(f"Failed to patch failed tool run for {tool_name}: {e}")
    except Exception as e:
        logger.warning(
            f"Error in PostToolUseFailure hook for {tool_name}: {e}", exc_info=True
        )

    return {}
|
||||
|
||||
|
||||
def clear_active_tool_runs() -> None:
    """Clear all active tool runs.

    Every run still registered (client-managed runs first, then hook-created
    tool runs) is ended with an error and patched, after which both
    registries are emptied in place. Meant to be called when a conversation
    ends so unfinished runs are not left open or kept in memory.
    """
    # (id, run, error text for the run, label used in the debug log)
    orphans = [
        (
            run_id,
            run,
            "Client-managed run not completed (conversation ended)",
            "client-managed",
        )
        for run_id, run in _client_managed_runs.items()
    ]
    orphans += [
        (run_id, run, "Tool run not completed (conversation ended)", "tool")
        for run_id, (run, _) in _active_tool_runs.items()
    ]

    for run_id, run, reason, label in orphans:
        try:
            run.end(error=reason)
            run.patch()
        except Exception as e:
            logger.debug(f"Failed to clean up orphaned {label} run {run_id}: {e}")

    # Clear in place: other modules hold references to these dict objects.
    _active_tool_runs.clear()
    _client_managed_runs.clear()
|
||||
@@ -0,0 +1,116 @@
|
||||
"""Message processing and content serialization for Claude Agent SDK."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _extract_tool_result_text(content: Any) -> str:
|
||||
"""Extract text content from tool result content blocks."""
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
texts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
texts.append(item.get("text", ""))
|
||||
elif hasattr(item, "text"):
|
||||
texts.append(getattr(item, "text", ""))
|
||||
return "\n".join(texts) if texts else str(content)
|
||||
return str(content)
|
||||
|
||||
|
||||
def flatten_content_blocks(content: Any) -> Any:
    """Convert SDK content blocks into serializable dicts using explicit type checks.

    Non-list content is returned unchanged. Within a list, blocks whose class
    is named ``TextBlock``, ``ThinkingBlock``, ``ToolUseBlock`` or
    ``ToolResultBlock`` are converted to plain dicts; anything else is kept
    as-is.
    """
    if not isinstance(content, list):
        return content

    def _text(block: Any) -> dict:
        return {"type": "text", "text": getattr(block, "text", "")}

    def _thinking(block: Any) -> dict:
        return {
            "type": "thinking",
            "thinking": getattr(block, "thinking", ""),
            "signature": getattr(block, "signature", ""),
        }

    def _tool_use(block: Any) -> dict:
        return {
            "type": "tool_use",
            "id": getattr(block, "id", None),
            "name": getattr(block, "name", None),
            "input": getattr(block, "input", None),
        }

    def _tool_result(block: Any) -> dict:
        # Extract text from nested content for tool results
        nested = getattr(block, "content", None)
        return {
            "type": "tool_result",
            "tool_use_id": getattr(block, "tool_use_id", None),
            "content": _extract_tool_result_text(nested),
            "is_error": getattr(block, "is_error", False),
        }

    # Converters keyed by class name, so the SDK types need not be imported.
    converters = {
        "TextBlock": _text,
        "ThinkingBlock": _thinking,
        "ToolUseBlock": _tool_use,
        "ToolResultBlock": _tool_result,
    }

    flattened = []
    for block in content:
        convert = converters.get(type(block).__name__)
        flattened.append(block if convert is None else convert(block))
    return flattened
|
||||
|
||||
|
||||
def build_llm_input(prompt: Any, history: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Construct a combined prompt + history message list.

    A string prompt becomes a single user message. A list prompt is treated
    as streamed input: dict entries carrying a dict under ``"message"`` are
    reduced to ``{"role", "content"}`` and all other entries are kept as-is.
    The history is appended after the prompt messages.

    Args:
        prompt: The original prompt (string, list of messages, or anything
            else, in which case only the history is used).
        history: Messages exchanged so far; may be empty or ``None``.

    Returns:
        A new list on every call. It never aliases ``history``, so messages
        the caller appends to ``history`` later do not change the result.
    """
    past = list(history) if history else []

    if isinstance(prompt, str):
        return [{"content": prompt, "role": "user"}, *past]

    if isinstance(prompt, list):
        formatted = []
        for msg in prompt:
            inner = msg.get("message") if isinstance(msg, dict) else None
            if isinstance(inner, dict):
                formatted.append(
                    {
                        "role": inner.get("role", "user"),
                        "content": inner.get("content", ""),
                    }
                )
            else:
                formatted.append(msg)
        return [*formatted, *past]

    return past
|
||||
|
||||
|
||||
def extract_usage_from_result_message(msg: Any) -> dict[str, Any]:
    """Normalize and merge token usage metrics from a `ResultMessage`.

    Returns an empty dict when the message has no (or empty) ``usage`` or
    when `extract_usage_metadata` yields nothing; otherwise returns the
    result of passing the extracted metrics through `sum_anthropic_tokens`.
    """
    from ._usage import extract_usage_metadata, sum_anthropic_tokens

    if not getattr(msg, "usage", None):
        return {}

    normalized = extract_usage_metadata(msg.usage)
    if not normalized:
        return {}
    return sum_anthropic_tokens(normalized)
|
||||
@@ -0,0 +1,34 @@
|
||||
"""Thread-local storage utilities for Claude Agent SDK tracing.
|
||||
|
||||
This module provides thread-local storage for the parent run tree,
|
||||
which is used by hooks to maintain trace context when async context
|
||||
propagation is broken.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Thread-local store for passing the parent run tree into hooks.
|
||||
# Claude's async event loop breaks trace-context propagation by default:
|
||||
# contextvars start empty within new anyio threads. The parent run tree is threaded
|
||||
# via thread-local as a fallback when context propagation isn't available.
|
||||
_thread_local = threading.local()
|
||||
|
||||
|
||||
def set_parent_run_tree(run_tree: Any) -> None:
    """Store ``run_tree`` as the parent run tree in thread-local storage."""
    setattr(_thread_local, "parent_run_tree", run_tree)
|
||||
|
||||
|
||||
def clear_parent_run_tree() -> None:
    """Remove the parent run tree from thread-local storage.

    Does nothing when no parent run tree is currently stored.
    """
    try:
        del _thread_local.parent_run_tree
    except AttributeError:
        # Nothing was stored; clearing is a no-op.
        pass
|
||||
|
||||
|
||||
def get_parent_run_tree() -> Any:
    """Return the parent run tree from thread-local storage, or ``None``."""
    try:
        return _thread_local.parent_run_tree
    except AttributeError:
        return None
|
||||
@@ -0,0 +1,63 @@
|
||||
"""Token usage utilities for Claude Agent SDK."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def extract_usage_metadata(usage: Any) -> dict[str, Any]:
    """Extract and normalize usage metrics from a Claude usage object or dict.

    Args:
        usage: A dict, or an object read via attribute access, that may carry
            ``input_tokens``, ``output_tokens``, ``cache_read_input_tokens``
            and ``cache_creation_input_tokens``.

    Returns:
        A dict containing ``input_tokens`` / ``output_tokens`` for each value
        that could be coerced to ``int``, plus an ``input_token_details`` dict
        with ``cache_read`` / ``cache_creation`` when either cache count is
        available. Returns an empty dict when ``usage`` is falsy.
    """
    if not usage:
        return {}

    get = usage.get if isinstance(usage, dict) else lambda k: getattr(usage, k, None)

    def to_int(value):
        # Token counts are whole numbers. Values that cannot be coerced
        # (None, non-numeric strings, NaN or infinite floats) count as absent.
        try:
            return int(value)
        except (ValueError, TypeError, OverflowError):
            return None

    meta: dict[str, Any] = {}
    if (v := to_int(get("input_tokens"))) is not None:
        meta["input_tokens"] = v
    if (v := to_int(get("output_tokens"))) is not None:
        meta["output_tokens"] = v

    # Cache counts are coerced to int like the other counts so that sums built
    # from them (input/total tokens) stay integers.
    details: dict[str, int] = {}
    if (v := to_int(get("cache_read_input_tokens"))) is not None:
        details["cache_read"] = v
    if (v := to_int(get("cache_creation_input_tokens"))) is not None:
        details["cache_creation"] = v
    if details:
        meta["input_token_details"] = details

    return meta
|
||||
|
||||
|
||||
def sum_anthropic_tokens(usage_metadata: dict[str, Any]) -> dict[str, int]:
    """Fold Anthropic cache tokens into ``input_tokens`` and add ``total_tokens``.

    Cache counts are read from ``input_token_details`` (``cache_read`` /
    ``cache_creation``), falling back to the top-level
    ``cache_read_input_tokens`` / ``cache_creation_input_tokens`` keys.
    Missing or falsy counts are treated as zero.

    Args:
        usage_metadata: Usage dict to summarize; it is not modified.

    Returns:
        A copy of ``usage_metadata`` whose ``input_tokens`` is the prompt total
        (input + cache read + cache creation) and whose ``total_tokens`` is
        that prompt total plus ``output_tokens``.
    """
    token_details = usage_metadata.get("input_token_details") or {}

    def _cache_count(detail_key: str, raw_key: str) -> Any:
        # Prefer the nested detail entry; fall back to the raw top-level key.
        return token_details.get(detail_key, usage_metadata.get(raw_key)) or 0

    prompt_total = (
        (usage_metadata.get("input_tokens") or 0)
        + _cache_count("cache_read", "cache_read_input_tokens")
        + _cache_count("cache_creation", "cache_creation_input_tokens")
    )

    summed = dict(usage_metadata)
    summed["input_tokens"] = prompt_total
    summed["total_tokens"] = prompt_total + (usage_metadata.get("output_tokens") or 0)
    return summed
|
||||
@@ -0,0 +1,127 @@
|
||||
"""LangSmith integration for Google ADK (Agent Development Kit)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from ._config import set_tracing_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["configure_google_adk", "create_traced_session_context"]
|
||||
|
||||
_patched = False
|
||||
|
||||
|
||||
def configure_google_adk(
    name: Optional[str] = None,
    project_name: Optional[str] = None,
    metadata: Optional[dict] = None,
    tags: Optional[list[str]] = None,
) -> bool:
    """Enable LangSmith tracing for Google ADK.

    Can be called before or after importing Runner (import-order agnostic).
    The ADK entry points are wrapped at most once; once patching has been
    done, later calls only replace the stored tracing configuration.

    Args:
        name: Name of the root trace. Defaults to "google_adk.session".
        project_name: LangSmith project to trace to.
        metadata: Metadata to associate with all traces.
        tags: Tags to associate with all traces.

    Returns:
        True if configuration was successful, False otherwise (``google.adk``
        or ``wrapt`` could not be imported).
    """
    global _patched

    if _patched:
        # Already instrumented: only refresh the stored configuration.
        set_tracing_config(
            name=name, project_name=project_name, metadata=metadata, tags=tags
        )
        return True

    try:
        import google.adk  # noqa: F401
        from wrapt import wrap_function_wrapper
    except ImportError as e:
        logger.warning(f"Missing dependency: {e}")
        return False

    set_tracing_config(
        name=name, project_name=project_name, metadata=metadata, tags=tags
    )

    from ._client import (
        wrap_agent_run_async,
        wrap_flow_call_llm_async,
        wrap_runner_run,
        wrap_runner_run_async,
        wrap_tool_run_async,
    )

    # (module path, dotted attribute to wrap, wrapper) triples.
    _wraps = [
        (
            "google.adk.runners",
            "Runner.run",
            wrap_runner_run,
        ),
        (
            "google.adk.runners",
            "Runner.run_async",
            wrap_runner_run_async,
        ),
        (
            "google.adk.agents.base_agent",
            "BaseAgent.run_async",
            wrap_agent_run_async,
        ),
        (
            "google.adk.flows.llm_flows.base_llm_flow",
            "BaseLlmFlow._call_llm_async",
            wrap_flow_call_llm_async,
        ),
        (
            "google.adk.tools.base_tool",
            "BaseTool.run_async",
            wrap_tool_run_async,
        ),
        (
            "google.adk.tools.function_tool",
            "FunctionTool.run_async",
            wrap_tool_run_async,
        ),
        (
            "google.adk.tools.mcp_tool.mcp_tool",
            "McpTool.run_async",
            wrap_tool_run_async,
        ),
    ]

    # The loop variable is deliberately not called ``name`` so it cannot
    # clobber the ``name`` parameter of this function.
    for module, target, wrapper in _wraps:
        try:
            wrap_function_wrapper(module, target, wrapper)
        except Exception as e:
            # Best-effort: one target failing to wrap must not prevent the
            # remaining targets from being instrumented.
            logger.warning(f"Failed to wrap {target}: {e}")

    _patched = True
    return True
|
||||
|
||||
|
||||
def create_traced_session_context(
    name: Optional[str] = None,
    project_name: Optional[str] = None,
    metadata: Optional[dict] = None,
    tags: Optional[list[str]] = None,
    inputs: Optional[dict] = None,
):
    """Create a trace context for manual session tracing.

    Thin public entry point: every argument is forwarded unchanged to the
    implementation in ``._client``, which is imported lazily on first use.
    """
    from ._client import create_traced_session_context as _delegate

    forwarded = {
        "name": name,
        "project_name": project_name,
        "metadata": metadata,
        "tags": tags,
        "inputs": inputs,
    }
    return _delegate(**forwarded)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,489 @@
|
||||
"""Client instrumentation for Google ADK using wrapt."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import aclosing
|
||||
from datetime import datetime, timezone
|
||||
from functools import cache
|
||||
from typing import Any, Optional
|
||||
|
||||
from langsmith.run_helpers import get_current_run_tree, set_tracing_parent, trace
|
||||
|
||||
from ._config import get_tracing_config
|
||||
from ._messages import convert_llm_request_to_messages, has_function_calls
|
||||
from ._usage import extract_model_name, extract_usage_from_response
|
||||
|
||||
_LS_PROVIDER_VERTEXAI = "google_vertexai"
|
||||
_LS_PROVIDER_GOOGLE_AI = "google_ai"
|
||||
|
||||
|
||||
def extract_tools_from_llm_request(llm_request: Any) -> list[dict[str, Any]]:
    """Extract tool definitions from LlmRequest and convert to OpenAI format.

    Args:
        llm_request: Object whose ``config.tools`` entries may each expose
            ``function_declarations``.

    Returns:
        One ``{"type": "function", "function": <dumped declaration>}`` entry
        per declaration that could be dumped with
        ``model_dump(exclude_none=True)``. Returns an empty list when the
        request has no config or no tools.
    """
    config = getattr(llm_request, "config", None)
    if not config:
        return []

    tools_list = getattr(config, "tools", None)
    if not tools_list:
        return []

    result = []
    for tool in tools_list:
        for func_decl in getattr(tool, "function_declarations", None) or []:
            try:
                dumped = func_decl.model_dump(exclude_none=True)
            except Exception as e:
                # Best-effort: skip declarations that cannot be serialized,
                # but leave a trace of why instead of dropping them silently.
                logger.debug(f"Failed to serialize function declaration: {e}")
                continue
            result.append(
                {
                    "type": "function",
                    "function": dumped,
                }
            )

    return result
|
||||
|
||||
|
||||
def _get_ls_provider() -> str:
    """Detect provider based on GOOGLE_GENAI_USE_VERTEXAI env var.

    The variable is compared case-insensitively; "1", "true" and "yes" select
    the Vertex AI provider, any other value (or an unset variable) selects the
    Google AI provider.
    """
    import os

    flag = os.environ.get("GOOGLE_GENAI_USE_VERTEXAI", "0").lower()
    if flag in ("1", "true", "yes"):
        return _LS_PROVIDER_VERTEXAI
    return _LS_PROVIDER_GOOGLE_AI
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TRACE_CHAIN_NAME = "google_adk.session"
|
||||
|
||||
|
||||
@cache
def _get_package_version(package_name: str) -> str | None:
    """Return the installed version of ``package_name``, or ``None``.

    Any failure while looking the version up is swallowed and reported as
    ``None``. Results are memoized per package name.
    """
    try:
        from importlib import metadata as _metadata

        return _metadata.version(package_name)
    except Exception:
        return None
|
||||
|
||||
|
||||
# Attribute name used to bridge the root run from Runner.run (sync) into the
|
||||
# background thread where Runner.run_async executes. Runner.run spins up a
|
||||
# new thread for its internal asyncio event loop, so context vars don't
|
||||
# propagate automatically. Storing the run on the instance (a plain object
|
||||
# attribute) crosses the thread boundary, and wrap_runner_run_async picks it
|
||||
# up and re-establishes it as a context var.
|
||||
_SYNC_ROOT_RUN_ATTR = "_langsmith_root_run"
|
||||
|
||||
|
||||
def _extract_text_from_content(content: Any) -> Optional[str]:
    """Join the text of all parts of ``content`` with single spaces.

    Parts without a truthy ``text`` attribute are ignored. Returns ``None``
    when ``content`` is ``None``, has no parts, or none of its parts carry
    text.
    """
    segments = getattr(content, "parts", None) if content is not None else None
    if not segments:
        return None

    collected = []
    for segment in segments:
        if getattr(segment, "text", None):
            collected.append(str(segment.text))

    if not collected:
        return None
    return " ".join(collected)
|
||||
|
||||
|
||||
def _iter_invocation_events(ctx: Any) -> list[Any]:
    """Get session events for the current invocation.

    Returns the events of ``ctx.session`` whose ``invocation_id`` equals
    ``ctx.invocation_id``. When the context has no invocation id, all session
    events are returned; when it has no session, the result is empty.
    """
    session = getattr(ctx, "session", None)
    if session is None:
        return []

    session_events = getattr(session, "events", None) or []
    current_id = getattr(ctx, "invocation_id", None)
    if current_id is None:
        return list(session_events)

    matching = []
    for candidate in session_events:
        if getattr(candidate, "invocation_id", None) == current_id:
            matching.append(candidate)
    return matching
|
||||
|
||||
|
||||
def _extract_latest_invocation_text(ctx: Any) -> Optional[str]:
    """Get the latest text from session events for the current invocation.

    Walks the invocation's events from newest to oldest and returns the first
    non-empty text found, or ``None`` when no event carries text.
    """
    texts = (
        _extract_text_from_content(getattr(event, "content", None))
        for event in reversed(_iter_invocation_events(ctx))
    )
    return next((text for text in texts if text), None)
|
||||
|
||||
|
||||
def wrap_runner_run(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any:
    """Wrap Runner.run to create a root trace for synchronous execution.

    Runner.run internally starts a new thread to run its async event loop, so
    context vars set here would not be visible to code running in that thread.
    We bridge the gap by storing the root run on the instance (a plain object
    attribute that IS visible across threads) so that wrap_runner_run_async can
    re-establish it as a context var inside the async event loop.

    Events are forwarded to the caller as soon as the wrapped call produces
    them, rather than being collected first, so callers keep the incremental
    behavior of the underlying generator.

    Args:
        wrapped: The original callable being wrapped.
        instance: The object the call is bound to; ``app_name`` is read from
            it and the root run is temporarily stored on it.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call; ``new_message``, ``user_id``
            and ``session_id`` are read for trace inputs and metadata.

    Returns:
        A generator yielding the events produced by ``wrapped``.
    """
    config = get_tracing_config()
    trace_name = config.get("name") or TRACE_CHAIN_NAME

    trace_inputs: dict[str, Any] = {}
    if new_message := kwargs.get("new_message"):
        if text := _extract_text_from_content(new_message):
            trace_inputs["input"] = text

    trace_metadata: dict[str, Any] = {
        "ls_provider": _get_ls_provider(),
        "ls_integration": "google-adk",
        "ls_integration_version": _get_package_version("google-adk"),
        **(config.get("metadata") or {}),
    }
    if app_name := getattr(instance, "app_name", None):
        trace_metadata["app_name"] = app_name
    if user_id := kwargs.get("user_id"):
        trace_metadata["user_id"] = user_id
    if session_id := kwargs.get("session_id"):
        trace_metadata["session_id"] = session_id

    def _trace_run():
        with trace(
            name=trace_name,
            run_type="chain",
            inputs=trace_inputs,
            project_name=config.get("project_name"),
            tags=config.get("tags"),
            metadata=trace_metadata,
        ) as root_run:
            # Publish the root run for wrap_runner_run_async (other thread).
            setattr(instance, _SYNC_ROOT_RUN_ATTR, root_run)
            try:
                final_output = None
                for event in wrapped(*args, **kwargs):
                    # Remember the most recent event text as the run output.
                    if content := getattr(event, "content", None):
                        if event_text := _extract_text_from_content(content):
                            final_output = event_text
                    yield event
                root_run.end(outputs={"output": final_output} if final_output else None)
            except Exception as e:
                root_run.end(error=str(e))
                raise
            finally:
                # Always drop the bridge attribute, including on early close.
                setattr(instance, _SYNC_ROOT_RUN_ATTR, None)

    return _trace_run()
|
||||
|
||||
|
||||
async def wrap_runner_run_async(
    wrapped: Any, instance: Any, args: Any, kwargs: Any
) -> Any:
    """Wrap Runner.run_async to create a root trace for asynchronous execution.

    When called from the background thread spawned by Runner.run, the root run
    stored on the instance is re-established as a context var so that
    wrap_agent_run_async and wrap_flow_call_llm_async can find the parent via
    get_current_run_tree().

    Otherwise a new root "chain" trace is opened around the wrapped call. The
    trace is entered directly in this generator (not in a nested one) so that
    closing this generator early also exits the trace context.

    Args:
        wrapped: The original callable being wrapped.
        instance: The object the call is bound to; ``app_name`` and the
            bridged root run are read from it.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call; ``new_message``, ``user_id``
            and ``session_id`` are read for trace inputs and metadata.

    Yields:
        The events produced by ``wrapped``, unchanged.
    """
    root_run = getattr(instance, _SYNC_ROOT_RUN_ATTR, None)
    if root_run is not None:
        # sync bridge: re-establish root run as context var in this thread
        with set_tracing_parent(root_run):
            async with aclosing(wrapped(*args, **kwargs)) as agen:
                async for event in agen:
                    yield event
        return

    config = get_tracing_config()
    trace_name = config.get("name") or TRACE_CHAIN_NAME

    trace_inputs: dict[str, Any] = {}
    if new_message := kwargs.get("new_message"):
        if text := _extract_text_from_content(new_message):
            trace_inputs["input"] = text

    trace_metadata: dict[str, Any] = {
        "ls_provider": _get_ls_provider(),
        "ls_integration": "google-adk",
        "ls_integration_version": _get_package_version("google-adk"),
        **(config.get("metadata") or {}),
    }
    if app_name := getattr(instance, "app_name", None):
        trace_metadata["app_name"] = app_name
    if user_id := kwargs.get("user_id"):
        trace_metadata["user_id"] = user_id
    if session_id := kwargs.get("session_id"):
        trace_metadata["session_id"] = session_id

    async with trace(
        name=trace_name,
        run_type="chain",
        inputs=trace_inputs,
        project_name=config.get("project_name"),
        tags=config.get("tags"),
        metadata=trace_metadata,
    ) as run:
        try:
            final_output: Optional[str] = None
            async with aclosing(wrapped(*args, **kwargs)) as agen:
                async for event in agen:
                    # Remember the most recent event text as the run output.
                    if content := getattr(event, "content", None):
                        if text := _extract_text_from_content(content):
                            final_output = text
                    yield event
            run.end(outputs={"output": final_output} if final_output else None)
        except Exception as e:
            run.end(error=str(e))
            raise
|
||||
|
||||
|
||||
async def wrap_agent_run_async(
    wrapped: Any, instance: Any, args: Any, kwargs: Any
) -> Any:
    """Wrap BaseAgent.run_async to create a chain span for each agent invocation.

    Without an active run tree the wrapped generator is passed through
    untouched. Otherwise a "chain" span named after the agent is opened; its
    input is the latest text found for the current invocation (when a context
    is supplied) and its output is the text of the last event that carried
    any.

    Yields:
        The events produced by ``wrapped``, unchanged.
    """
    if not get_current_run_tree():
        # No active trace: forward events without creating a span.
        async with aclosing(wrapped(*args, **kwargs)) as passthrough:
            async for item in passthrough:
                yield item
        return

    invocation_ctx = args[0] if args else kwargs.get("parent_context")
    span_name = getattr(instance, "name", None) or type(instance).__name__

    span_inputs: dict[str, Any] = {}
    if invocation_ctx is not None:
        latest_text = _extract_latest_invocation_text(invocation_ctx)
        if latest_text:
            span_inputs["input"] = latest_text

    async with trace(name=span_name, run_type="chain", inputs=span_inputs) as agent_run:
        try:
            last_text: Optional[str] = None
            async with aclosing(wrapped(*args, **kwargs)) as stream:
                async for item in stream:
                    body = getattr(item, "content", None)
                    if body:
                        extracted = _extract_text_from_content(body)
                        if extracted:
                            last_text = extracted
                    yield item
            agent_run.end(outputs={"output": last_text} if last_text else None)
        except Exception as e:
            agent_run.end(error=str(e))
            raise
|
||||
|
||||
|
||||
async def wrap_tool_run_async(
    wrapped: Any, instance: Any, args: Any, kwargs: Any
) -> Any:
    """Wrap BaseTool.run_async (all tool subclasses) to trace tool invocations.

    Without an active run tree the wrapped call is awaited directly. Otherwise
    a child "tool" run is created under the current run, posted before the
    call and patched afterwards. Failures to post or patch are only logged at
    debug level; errors raised by the tool itself are recorded on the run and
    re-raised.

    Returns:
        The wrapped tool's result, unchanged.
    """
    parent_run = get_current_run_tree()
    if not parent_run:
        return await wrapped(*args, **kwargs)

    span_name = getattr(instance, "name", None) or type(instance).__name__
    call_args = kwargs.get("args") or (args[0] if args else {})
    if isinstance(call_args, dict):
        span_inputs = call_args
    else:
        span_inputs = {"args": call_args}

    started_at = time.time()
    tool_run = parent_run.create_child(
        name=span_name,
        run_type="tool",
        inputs=span_inputs,
        extra={"metadata": {"ls_provider": _get_ls_provider()}},
        start_time=datetime.fromtimestamp(started_at, tz=timezone.utc),
    )

    try:
        tool_run.post()
    except Exception as e:
        logger.debug(f"Failed to post tool run: {e}")

    try:
        result = await wrapped(*args, **kwargs)

        # Shape the result into a dict suitable for run outputs.
        if result is None:
            outputs = {}
        elif isinstance(result, dict):
            outputs = result
        elif isinstance(result, list):
            outputs = {"content": result}
        else:
            outputs = {"output": str(result)}

        tool_run.end(outputs=outputs)
        try:
            tool_run.patch()
        except Exception as e:
            logger.debug(f"Failed to patch tool run: {e}")
        return result
    except Exception as e:
        tool_run.end(error=str(e))
        try:
            tool_run.patch()
        except Exception as patch_e:
            logger.debug(f"Failed to patch tool run on error: {patch_e}")
        raise
|
||||
|
||||
|
||||
def _determine_llm_call_type(llm_request: Any, llm_response: Any) -> str:
    """Classify an LLM call from its request and response.

    Returns:
        "response_generation" when any request part carries a truthy
        ``function_response``; otherwise "tool_selection" when the response
        contains function calls, else "direct_response". Any exception raised
        during classification yields "unknown".
    """
    try:
        request_parts = (
            part
            for content in (getattr(llm_request, "contents", None) or [])
            for part in (getattr(content, "parts", None) or [])
        )
        if any(getattr(part, "function_response", None) for part in request_parts):
            return "response_generation"
        return "tool_selection" if has_function_calls(llm_response) else "direct_response"
    except Exception:
        return "unknown"
|
||||
|
||||
|
||||
async def wrap_flow_call_llm_async(
    wrapped: Any, instance: Any, args: Any, kwargs: Any
) -> Any:
    """Wrap BaseLlmFlow._call_llm_async to capture LLM calls with TTFT tracking.

    Without an active run tree the wrapped generator is passed through
    untouched. Otherwise a child "llm" run is created under the current run
    with the request messages as inputs and, when tools are declared, the
    tool definitions as invocation params. While events stream through, the
    time of the first partial event is recorded as a ``new_token`` event.
    After the stream ends the run receives the assistant output (text and
    tool calls), usage metadata, time to first token and the call type.

    Failures to post or patch the run are only logged at debug level; errors
    raised while streaming are recorded on the run and re-raised.

    Yields:
        The events produced by ``wrapped``, unchanged.
    """
    parent = get_current_run_tree()
    if not parent:
        # No active trace: forward events, closing the wrapped generator
        # deterministically just like the traced path below does.
        async with aclosing(wrapped(*args, **kwargs)) as agen:
            async for event in agen:
                yield event
        return

    llm_request = args[1] if len(args) > 1 else kwargs.get("llm_request")
    model_name = extract_model_name(llm_request) if llm_request else None
    messages = convert_llm_request_to_messages(llm_request) if llm_request else None
    tools = extract_tools_from_llm_request(llm_request) if llm_request else []

    inputs: dict[str, Any] = {}
    if messages:
        inputs["messages"] = messages

    metadata: dict[str, Any] = {"ls_provider": _get_ls_provider()}
    if model_name:
        metadata["ls_model_name"] = model_name

    # Build extra dict with invocation_params if tools exist
    extra: dict[str, Any] = {"metadata": metadata}
    if tools:
        extra["invocation_params"] = {"tools": tools}

    start_time = time.time()
    llm_run = parent.create_child(
        name=model_name or "google_adk_llm",
        run_type="llm",
        inputs=inputs,
        extra=extra,
        start_time=datetime.fromtimestamp(start_time, tz=timezone.utc),
    )

    try:
        llm_run.post()
    except Exception as e:
        logger.debug(f"Failed to post LLM run: {e}")

    first_token_time: Optional[float] = None
    last_event = None
    event_with_content = None

    try:
        async with aclosing(wrapped(*args, **kwargs)) as agen:
            async for event in agen:
                is_partial = getattr(event, "partial", False)

                # The first partial event marks the time to first token.
                if first_token_time is None and is_partial:
                    first_token_time = time.time()
                    try:
                        llm_run.add_event(
                            {
                                "name": "new_token",
                                "time": datetime.fromtimestamp(
                                    first_token_time, tz=timezone.utc
                                ).isoformat(),
                            }
                        )
                    except Exception as e:
                        logger.debug(f"Failed to add new_token event: {e}")

                last_event = event
                if hasattr(event, "content") and event.content is not None:
                    event_with_content = event
                yield event

        outputs: dict[str, Any] = {"role": "assistant"}
        # Prefer the last event that actually carried content.
        content_source = event_with_content or last_event

        if (
            content_source
            and hasattr(content_source, "content")
            and content_source.content
        ):
            parts = getattr(content_source.content, "parts", None) or []
            text_parts, tool_calls = [], []

            for i, part in enumerate(parts):
                if hasattr(part, "text") and part.text:
                    text_parts.append(str(part.text))
                elif hasattr(part, "function_call") and part.function_call:
                    fc = part.function_call
                    tool_calls.append(
                        {
                            "id": f"call_{i}",
                            "type": "function",
                            "function": {
                                "name": getattr(fc, "name", ""),
                                "arguments": json.dumps(
                                    dict(fc.args) if getattr(fc, "args", None) else {}
                                ),
                            },
                        }
                    )

            outputs["content"] = " ".join(text_parts) if text_parts else None
            if tool_calls:
                outputs["tool_calls"] = tool_calls

        if last_event:
            if usage := extract_usage_from_response(last_event):
                llm_run.extra.setdefault("metadata", {})["usage_metadata"] = usage

        if first_token_time is not None:
            llm_run.extra.setdefault("metadata", {})["time_to_first_token"] = (
                first_token_time - start_time
            )

        if last_event and llm_request:
            llm_run.extra.setdefault("metadata", {})["llm_call_type"] = (
                _determine_llm_call_type(llm_request, last_event)
            )

        llm_run.end(outputs=outputs)
        try:
            llm_run.patch()
        except Exception as e:
            logger.debug(f"Failed to patch LLM run: {e}")

    except Exception as e:
        llm_run.end(error=str(e))
        try:
            llm_run.patch()
        except Exception as patch_e:
            logger.debug(f"Failed to patch LLM run on error: {patch_e}")
        raise
|
||||
|
||||
|
||||
def create_traced_session_context(
    name: Optional[str] = None,
    project_name: Optional[str] = None,
    metadata: Optional[dict[str, Any]] = None,
    tags: Optional[list[str]] = None,
    inputs: Optional[dict[str, Any]] = None,
):
    """Create a trace context for manual session tracing.

    Explicit arguments take precedence over the stored tracing configuration.
    Metadata from both sources is merged, with explicit keys overriding
    configured ones.

    Returns:
        The "chain" ``trace`` context built from the resolved settings.
    """
    defaults = get_tracing_config()

    resolved_name = name or defaults.get("name") or TRACE_CHAIN_NAME
    resolved_project = project_name or defaults.get("project_name")
    resolved_tags = tags or defaults.get("tags")
    merged_metadata = {**(defaults.get("metadata") or {}), **(metadata or {})}

    return trace(
        name=resolved_name,
        run_type="chain",
        inputs=inputs or {},
        project_name=resolved_project,
        tags=resolved_tags,
        metadata=merged_metadata,
    )
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Configuration for Google ADK tracing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
_tracing_config: dict[str, Any] = {
|
||||
"name": None,
|
||||
"project_name": None,
|
||||
"metadata": None,
|
||||
"tags": None,
|
||||
}
|
||||
|
||||
|
||||
def set_tracing_config(
    name: Optional[str] = None,
    project_name: Optional[str] = None,
    metadata: Optional[dict] = None,
    tags: Optional[list[str]] = None,
) -> None:
    """Replace the module-level tracing configuration.

    All four settings are overwritten on every call; omitted arguments reset
    the corresponding setting to ``None``.
    """
    global _tracing_config
    _tracing_config = dict(
        name=name,
        project_name=project_name,
        metadata=metadata,
        tags=tags,
    )
|
||||
|
||||
|
||||
def get_tracing_config() -> dict[str, Any]:
    """Return a shallow copy of the current tracing configuration."""
    return dict(_tracing_config)
|
||||
@@ -0,0 +1,200 @@
|
||||
"""Message serialization for Google ADK."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
def convert_adk_content_to_langsmith(content: Any) -> list[dict[str, Any]]:
    """Convert ADK Content/Part objects to serializable format.

    Args:
        content: ``None``, an object with a ``parts`` attribute, a list of
            parts, or a single part.

    Returns:
        One serialized entry per non-``None`` part. ``None`` content, and
        content whose ``parts`` attribute is unset (``None``) or empty, yields
        an empty list.
    """
    if content is None:
        return []
    if hasattr(content, "parts"):
        # ``parts`` can be present but unset; iterating None would raise.
        parts = content.parts or []
    elif isinstance(content, list):
        parts = content
    else:
        return [_serialize_part(content)]
    return [_serialize_part(part) for part in parts if part is not None]
|
||||
|
||||
|
||||
def _serialize_part(part: Any) -> dict[str, Any]:
    """Serialize a single Part.

    Dicts are returned unchanged. Otherwise the first matching attribute, in
    this order, decides the shape: ``inline_data`` (with data),
    ``file_data``, ``function_call``, ``function_response``, ``text``,
    ``executable_code``, ``code_execution_result``, ``thought``. Parts
    matching none of these go through ``_safe_serialize``.
    """
    if isinstance(part, dict):
        return part

    inline = getattr(part, "inline_data", None)
    if inline:
        payload = getattr(inline, "data", None)
        mime_type = getattr(inline, "mime_type", "application/octet-stream")
        # Inline data without a payload falls through to the later checks.
        if payload is not None:
            if isinstance(payload, bytes):
                encoded = base64.b64encode(payload).decode("utf-8")
            else:
                encoded = str(payload)
            return {"type": "image", "data": encoded, "mime_type": mime_type}

    file_ref = getattr(part, "file_data", None)
    if file_ref:
        return {
            "type": "file",
            "file_uri": getattr(file_ref, "file_uri", None),
            "mime_type": getattr(file_ref, "mime_type", None),
        }

    call = getattr(part, "function_call", None)
    if call:
        return {
            "type": "tool_use",
            "name": getattr(call, "name", "unknown"),
            "input": dict(getattr(call, "args", None) or {}),
        }

    response = getattr(part, "function_response", None)
    if response:
        return {
            "type": "tool_result",
            "name": getattr(response, "name", "unknown"),
            "content": _safe_serialize(getattr(response, "response", None)),
        }

    text = getattr(part, "text", None)
    if text is not None:
        return {"type": "text", "text": str(text)}

    code = getattr(part, "executable_code", None)
    if code:
        return {
            "type": "executable_code",
            "language": getattr(code, "language", "python"),
            "code": getattr(code, "code", ""),
        }

    execution = getattr(part, "code_execution_result", None)
    if execution:
        return {
            "type": "code_execution_result",
            "outcome": getattr(execution, "outcome", "unknown"),
            "output": getattr(execution, "output", ""),
        }

    thought = getattr(part, "thought", None)
    if thought is not None:
        return {"type": "thinking", "thinking": str(thought)}

    return _safe_serialize(part)
|
||||
|
||||
|
||||
def _safe_serialize(obj: Any) -> Any:
    """Safely serialize an object to JSON-compatible format.

    Primitives pass through, bytes become base64 text, and dicts, lists and
    tuples are converted recursively (tuples become lists). Other objects are
    dumped via ``model_dump()`` when available, else via their ``__dict__``,
    and finally fall back to ``str(obj)``.
    """
    if isinstance(obj, (str, int, float, bool, type(None))):
        return obj
    if isinstance(obj, bytes):
        return base64.b64encode(obj).decode("utf-8")
    if isinstance(obj, dict):
        return {key: _safe_serialize(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return list(map(_safe_serialize, obj))

    if hasattr(obj, "model_dump"):
        try:
            return obj.model_dump()
        except Exception:
            # Fall through to the next strategy.
            pass

    if hasattr(obj, "__dict__"):
        try:
            return {key: _safe_serialize(value) for key, value in vars(obj).items()}
        except Exception:
            # Fall through to the string fallback.
            pass

    return str(obj)
|
||||
|
||||
|
||||
def convert_llm_request_to_messages(llm_request: Any) -> list[dict[str, Any]]:
    """Convert LlmRequest to OpenAI-compatible message format.

    A system message is emitted first when ``config.system_instruction`` is
    set. Each content entry then becomes either one assistant message with
    ``tool_calls``, one ``tool`` message per tool result, or one plain text
    message. The role "model" is mapped to "assistant".
    """
    messages: list[dict[str, Any]] = []

    # Extract system instruction from config
    config = getattr(llm_request, "config", None)
    system_instruction = getattr(config, "system_instruction", None) if config else None
    if system_instruction:
        messages.append({"role": "system", "content": str(system_instruction)})

    for content in getattr(llm_request, "contents", None) or []:
        role = getattr(content, "role", "user")
        if role == "model":
            role = "assistant"

        # Bucket the serialized parts by kind; unknown kinds count as text.
        texts, calls, results = [], [], []
        for part in convert_adk_content_to_langsmith(content):
            kind = part.get("type")
            if kind == "text":
                texts.append(part.get("text", ""))
            elif kind == "tool_use":
                calls.append(part)
            elif kind == "tool_result":
                results.append(part)
            else:
                texts.append(str(part))

        if calls and role == "assistant":
            call_entries = []
            for i, tc in enumerate(calls):
                call_entries.append(
                    {
                        "id": f"call_{i}",
                        "type": "function",
                        "function": {
                            "name": tc.get("name", ""),
                            "arguments": json.dumps(tc.get("input", {})),
                        },
                    }
                )
            messages.append(
                {
                    "role": "assistant",
                    "content": " ".join(texts) if texts else None,
                    "tool_calls": call_entries,
                }
            )
        elif results:
            for tr in results:
                c = tr.get("content")
                if isinstance(c, dict):
                    rendered = json.dumps(c)
                else:
                    rendered = str(c or "")
                messages.append(
                    {
                        "role": "tool",
                        "name": tr.get("name", ""),
                        "content": rendered,
                    }
                )
        else:
            messages.append(
                {
                    "role": role,
                    "content": " ".join(texts) if texts else "",
                }
            )

    return messages
|
||||
|
||||
|
||||
def has_function_calls(llm_response: Any) -> bool:
    """Check if LlmResponse contains function calls.

    Returns ``True`` when any serialized part of the response content has
    type "tool_use"; ``False`` when there is none or the response has no
    content.
    """
    content = getattr(llm_response, "content", None)
    if not content:
        return False
    for part in convert_adk_content_to_langsmith(content):
        if part.get("type") == "tool_use":
            return True
    return False
|
||||
|
||||
|
||||
def has_function_response_in_request(llm_request: Any) -> bool:
    """Check if LlmRequest contains function responses (tool results).

    Returns ``True`` as soon as any serialized part of any request content
    has type "tool_result".
    """
    return any(
        part.get("type") == "tool_result"
        for content in (getattr(llm_request, "contents", None) or [])
        for part in convert_adk_content_to_langsmith(content)
    )
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Token usage extraction for Google ADK."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
def extract_usage_from_response(llm_response: Any) -> dict[str, Any]:
    """Build a token-usage dict from an LlmResponse's ``usage_metadata``.

    Args:
        llm_response: Response object; only ``usage_metadata`` is read.

    Returns:
        A dict containing whichever of these are present (not ``None``) on the
        usage metadata, each coerced with ``int``: ``input_tokens``,
        ``output_tokens``, ``total_tokens``,
        ``input_token_details["cache_read"]`` and
        ``output_token_details["reasoning"]``. Empty when there is no usage
        metadata.
    """
    result: dict[str, Any] = {}
    meta = getattr(llm_response, "usage_metadata", None)
    if not meta:
        return result

    # (attribute on usage_metadata, key in the result)
    flat_fields = (
        ("prompt_token_count", "input_tokens"),
        ("candidates_token_count", "output_tokens"),
        ("total_token_count", "total_tokens"),
    )
    for attr, key in flat_fields:
        count = getattr(meta, attr, None)
        if count is not None:
            result[key] = int(count)

    # (attribute on usage_metadata, nested dict name, key inside that dict)
    detail_fields = (
        ("cached_content_token_count", "input_token_details", "cache_read"),
        ("thoughts_token_count", "output_token_details", "reasoning"),
    )
    for attr, bucket, key in detail_fields:
        count = getattr(meta, attr, None)
        if count is not None:
            result.setdefault(bucket, {})[key] = int(count)

    return result
|
||||
|
||||
|
||||
def extract_model_name(llm_request: Any) -> Optional[str]:
    """Return the model name found on an LlmRequest, if any.

    ``llm_request.config.model`` is preferred; ``llm_request.model`` is the
    fallback. Falsy values are treated as absent.

    Returns:
        The model coerced to ``str``, or ``None`` when neither location holds
        a truthy value.
    """
    config = getattr(llm_request, "config", None)
    name = getattr(config, "model", None) if config else None
    if not name:
        name = getattr(llm_request, "model", None)
    return str(name) if name else None
|
||||
@@ -0,0 +1,8 @@
|
||||
"""LangSmith integration for OpenAI Agents SDK.
|
||||
|
||||
This module provides tracing support for the OpenAI Agents SDK.
|
||||
"""
|
||||
|
||||
from ._openai_agents import OpenAIAgentsTracingProcessor
|
||||
|
||||
__all__ = ["OpenAIAgentsTracingProcessor"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,228 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
try:
|
||||
from agents import tracing # type: ignore[import]
|
||||
|
||||
HAVE_AGENTS = True
|
||||
except ImportError:
|
||||
HAVE_AGENTS = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RunTypeT = Literal["tool", "chain", "llm", "retriever", "embedding", "prompt", "parser"]
|
||||
|
||||
if HAVE_AGENTS:
|
||||
|
||||
def parse_io(data: Any, default_key: str = "output") -> dict:
    """Normalize span inputs or outputs into a dictionary.

    Args:
        data: The value to normalize (inputs or outputs).
        default_key: Key used to wrap values that are not already dict-shaped
            (`'input'` or `'output'`).

    Returns:
        - ``{}`` for an empty list.
        - The single element itself for a one-item list holding a dict that
          has no ``"type"`` key.
        - The dict itself for dict input.
        - The decoded object for a string containing a JSON object.
        - The ``model_dump(exclude_none=True, mode="json")`` result for
          instances exposing a callable ``model_dump``.
        - ``{default_key: data}`` in every other case.
    """
    if isinstance(data, list):
        if not data:
            return {}
        head = data[0]
        # A lone, untyped dict is unwrapped; lists of typed output blocks
        # (reasoning, message, etc.) and everything else stay wrapped.
        lone_plain_dict = (
            isinstance(head, dict) and "type" not in head and len(data) == 1
        )
        return head if lone_plain_dict else {default_key: data}

    if isinstance(data, dict):
        return data

    if isinstance(data, str):
        try:
            decoded = json.loads(data)
        except json.JSONDecodeError:
            return {default_key: data}
        # Only JSON objects are unpacked; other JSON values keep the raw text.
        return decoded if isinstance(decoded, dict) else {default_key: data}

    dumper = getattr(data, "model_dump", None)
    # Classes themselves are excluded: only instances are dumped.
    if callable(dumper) and not isinstance(data, type):
        try:
            return dumper(exclude_none=True, mode="json")
        except Exception as e:
            logger.debug(
                f"Failed to use model_dump to serialize {type(data)} to JSON: {e}"
            )

    return {default_key: data}
|
||||
|
||||
def get_run_type(span: tracing.Span) -> RunTypeT:
    """Map a span's ``span_data.type`` to a LangSmith run type.

    ``function`` and ``guardrail`` spans become ``"tool"``; ``generation`` and
    ``response`` spans become ``"llm"``; everything else (including ``agent``,
    ``handoff``, ``custom`` and a missing type) becomes ``"chain"``.
    """
    kind = getattr(span.span_data, "type", None)
    if kind in ["function", "guardrail"]:
        return "tool"
    if kind in ["generation", "response"]:
        return "llm"
    return "chain"
|
||||
|
||||
def get_run_name(span: tracing.Span) -> str:
    """Choose a display name for the run created from ``span``.

    A truthy ``span_data.name`` is returned as is. Otherwise the name is
    derived from ``span_data.type``: ``"Generation"``, ``"Response"`` or
    ``"Handoff"`` for those types, and ``"Span"`` for anything else.
    """
    payload = span.span_data
    if hasattr(payload, "name") and payload.name:
        return payload.name

    kind = getattr(payload, "type", None)
    fallback_names = (
        ("generation", "Generation"),
        ("response", "Response"),
        ("handoff", "Handoff"),
    )
    for known_kind, label in fallback_names:
        if kind == known_kind:
            return label
    return "Span"
|
||||
|
||||
def _extract_function_span_data(
    span_data: tracing.FunctionSpanData,
) -> dict[str, Any]:
    """Return the function span's input and output, normalized by ``parse_io``."""
    parsed_inputs = parse_io(span_data.input, "input")
    parsed_outputs = parse_io(span_data.output, "output")
    return {"inputs": parsed_inputs, "outputs": parsed_outputs}
|
||||
|
||||
def _extract_generation_span_data(
    span_data: tracing.GenerationSpanData,
) -> dict[str, Any]:
    """Extract inputs, outputs, invocation params and usage from a generation span.

    Returns:
        A dict with ``inputs`` and ``outputs`` (normalized by ``parse_io``) and
        ``invocation_params`` (``model`` and ``model_config``). When the span
        reports usage, ``metadata["usage_metadata"]`` is added as well.
    """
    result: dict[str, Any] = {
        "inputs": parse_io(span_data.input, "input"),
        "outputs": parse_io(span_data.output, "output"),
        "invocation_params": {
            "model": span_data.model,
            "model_config": span_data.model_config,
        },
    }
    if not span_data.usage:
        return result

    # Imported lazily, only when there is usage to convert.
    from langsmith.wrappers._openai import _create_usage_metadata

    result.setdefault("metadata", {})["usage_metadata"] = _create_usage_metadata(
        span_data.usage
    )
    return result
|
||||
|
||||
def _extract_response_span_data(
    span_data: tracing.ResponseSpanData,
) -> dict[str, Any]:
    """Extract inputs, outputs, invocation params and metadata from a response span.

    Returns:
        A dict that may contain:

        - ``inputs``: the span input plus the response's instructions (empty
          string when unavailable); only set when the span input is not
          ``None``.
        - ``outputs``: the dumped response's ``output`` list, normalized by
          ``parse_io``.
        - ``invocation_params``: the allow-listed request parameters.
        - ``metadata``: all remaining dumped response fields plus ``ls_*``
          keys and, when present, ``usage_metadata``.

        The last three keys are only set when the span has a response.
    """
    result: dict[str, Any] = {}
    response_obj = span_data.response

    if span_data.input is not None:
        instructions = ""
        if response_obj is not None and response_obj.instructions:
            instructions = response_obj.instructions
        result["inputs"] = {
            "input": span_data.input,
            "instructions": instructions,
        }

    if response_obj is None:
        return result

    dumped = response_obj.model_dump(exclude_none=True, mode="json")
    result["outputs"] = parse_io(dumped.pop("output", []), "output")

    # Response fields reported as invocation parameters rather than metadata.
    param_keys = (
        "max_output_tokens",
        "model",
        "parallel_tool_calls",
        "reasoning",
        "temperature",
        "text",
        "tool_choice",
        "tools",
        "top_p",
        "truncation",
    )
    invocation_params = {k: v for k, v in dumped.items() if k in param_keys}
    result["invocation_params"] = invocation_params

    # Everything that is not a parameter, the output, usage or instructions
    # is kept as plain metadata.
    excluded = {"output", "usage", "instructions"}.union(invocation_params)
    metadata = {k: v for k, v in dumped.items() if k not in excluded}
    metadata["ls_model_name"] = invocation_params.get("model")
    metadata["ls_max_tokens"] = invocation_params.get("max_output_tokens")
    metadata["ls_temperature"] = invocation_params.get("temperature")
    metadata["ls_model_type"] = "chat"
    metadata["ls_provider"] = "openai"

    usage = dumped.pop("usage", None)
    if usage:
        from langsmith.wrappers._openai import _create_usage_metadata

        metadata["usage_metadata"] = _create_usage_metadata(usage)

    result["metadata"] = metadata
    return result
|
||||
|
||||
def _extract_agent_span_data(span_data: tracing.AgentSpanData) -> dict[str, Any]:
    """Report an agent span's tools and handoffs as params, output type as metadata."""
    invocation_params = {
        "tools": span_data.tools,
        "handoffs": span_data.handoffs,
    }
    metadata = {"output_type": span_data.output_type}
    return {"invocation_params": invocation_params, "metadata": metadata}
|
||||
|
||||
def _extract_handoff_span_data(
    span_data: tracing.HandoffSpanData,
) -> dict[str, Any]:
    """Report a handoff span's source and target agents as its inputs."""
    handoff_inputs = {
        "from_agent": span_data.from_agent,
        "to_agent": span_data.to_agent,
    }
    return {"inputs": handoff_inputs}
|
||||
|
||||
def _extract_guardrail_span_data(
    span_data: tracing.GuardrailSpanData,
) -> dict[str, Any]:
    """Report a guardrail span's ``triggered`` value as metadata."""
    triggered = span_data.triggered
    return {"metadata": {"triggered": triggered}}
|
||||
|
||||
def _extract_custom_span_data(span_data: tracing.CustomSpanData) -> dict[str, Any]:
    """Report a custom span's ``data`` attribute, unchanged, as metadata."""
    custom_payload = span_data.data
    return {"metadata": custom_payload}
|
||||
|
||||
def extract_span_data(span: tracing.Span) -> dict[str, Any]:
    """Dispatch to the extractor matching the concrete type of ``span.span_data``.

    Returns:
        A new dict with the matching extractor's result, or ``{}`` when the
        span data is of no recognized type.
    """
    payload = span.span_data

    # The checks run in a fixed order and the first match wins.
    if isinstance(payload, tracing.FunctionSpanData):
        return dict(_extract_function_span_data(payload))
    if isinstance(payload, tracing.GenerationSpanData):
        return dict(_extract_generation_span_data(payload))
    if isinstance(payload, tracing.ResponseSpanData):
        return dict(_extract_response_span_data(payload))
    if isinstance(payload, tracing.AgentSpanData):
        return dict(_extract_agent_span_data(payload))
    if isinstance(payload, tracing.HandoffSpanData):
        return dict(_extract_handoff_span_data(payload))
    if isinstance(payload, tracing.GuardrailSpanData):
        return dict(_extract_guardrail_span_data(payload))
    if isinstance(payload, tracing.CustomSpanData):
        return dict(_extract_custom_span_data(payload))
    return {}
|
||||
@@ -0,0 +1,403 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from functools import cache
|
||||
from typing import Optional
|
||||
|
||||
from langsmith import run_trees as rt
|
||||
from langsmith._internal import _context
|
||||
from langsmith.run_helpers import get_current_run_tree
|
||||
|
||||
try:
|
||||
from agents import tracing # type: ignore[import]
|
||||
|
||||
required = (
|
||||
"TracingProcessor",
|
||||
"Trace",
|
||||
"Span",
|
||||
"ResponseSpanData",
|
||||
)
|
||||
if not all(hasattr(tracing, name) for name in required):
|
||||
raise ImportError("The `agents` package is not installed.")
|
||||
|
||||
from langsmith.integrations.openai_agents_sdk import (
|
||||
_openai_agent_utils as agent_utils,
|
||||
)
|
||||
|
||||
HAVE_AGENTS = True
|
||||
except ImportError:
|
||||
HAVE_AGENTS = False
|
||||
|
||||
class OpenAIAgentsTracingProcessor:
    """Tracing processor for the [OpenAI Agents SDK](https://openai.github.io/openai-agents-python/).

    Traces all intermediate steps of your OpenAI Agent to LangSmith.

    This definition is the stand-in whose constructor always raises
    `ImportError`, telling the user to install the `agents` package.

    Requirements: Make sure to install `pip install -U langsmith[openai-agents]`.

    Args:
        client: An instance of `langsmith.client.Client`. If not provided, a default
            client is created.

    Example:
        ```python
        from agents import (
            Agent,
            FileSearchTool,
            Runner,
            WebSearchTool,
            function_tool,
            set_trace_processors,
        )

        from langsmith.wrappers import OpenAIAgentsTracingProcessor

        set_trace_processors([OpenAIAgentsTracingProcessor()])


        @function_tool
        def get_weather(city: str) -> str:
            return f"The weather in {city} is sunny"


        haiku_agent = Agent(
            name="Haiku agent",
            instructions="Always respond in haiku form",
            model="o3-mini",
            tools=[get_weather],
        )
        agent = Agent(
            name="Assistant",
            tools=[WebSearchTool()],
            instructions="speak in spanish. use Haiku agent if they ask for a haiku or for the weather",
            handoffs=[haiku_agent],
        )

        result = await Runner.run(
            agent,
            "write a haiku about the weather today and tell me a recent news story about new york",
        )
        print(result.final_output)
        ```
    """  # noqa: E501

    def __init__(self, *args, **kwargs):
        """Always raise `ImportError`; every argument is ignored."""
        message = (
            "The `agents` package is not installed. "
            "Please install it with `pip install langsmith[openai-agents]`."
        )
        raise ImportError(message)
|
||||
|
||||
|
||||
from langsmith import client as ls_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@cache
def _get_package_version(package_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or ``None`` if unavailable.

    The result is memoized per ``package_name`` for the life of the process.

    Args:
        package_name: Distribution name as known to ``importlib.metadata``.

    Returns:
        The version string, or ``None`` when the lookup fails for any reason.
        Version reporting is best-effort and must never break tracing.
    """
    # ``Optional[str]`` rather than ``str | None``: the annotation is evaluated
    # at definition time, and the rest of this module uses ``Optional``.
    try:
        from importlib.metadata import version

        return version(package_name)
    except Exception:
        return None
|
||||
|
||||
|
||||
if HAVE_AGENTS:
|
||||
|
||||
class OpenAIAgentsTracingProcessor(tracing.TracingProcessor): # type: ignore[no-redef]
|
||||
"""Tracing processor for the [OpenAI Agents SDK](https://openai.github.io/openai-agents-python/).
|
||||
|
||||
Traces all intermediate steps of your OpenAI Agent to LangSmith.
|
||||
|
||||
Requirements: Make sure to install `pip install -U langsmith[openai-agents]`.
|
||||
|
||||
Args:
|
||||
client: An instance of `langsmith.client.Client`. If not provided,
|
||||
a default client is created.
|
||||
metadata: Metadata to associate with all traces.
|
||||
tags: Tags to associate with all traces.
|
||||
project_name: LangSmith project to trace to.
|
||||
name: Name of the root trace.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from agents import (
|
||||
Agent,
|
||||
FileSearchTool,
|
||||
Runner,
|
||||
WebSearchTool,
|
||||
function_tool,
|
||||
set_trace_processors,
|
||||
)
|
||||
|
||||
from langsmith.wrappers import OpenAIAgentsTracingProcessor
|
||||
|
||||
set_trace_processors([OpenAIAgentsTracingProcessor()])
|
||||
|
||||
|
||||
@function_tool
|
||||
def get_weather(city: str) -> str:
|
||||
return f"The weather in {city} is sunny"
|
||||
|
||||
|
||||
haiku_agent = Agent(
|
||||
name="Haiku agent",
|
||||
instructions="Always respond in haiku form",
|
||||
model="o3-mini",
|
||||
tools=[get_weather],
|
||||
)
|
||||
agent = Agent(
|
||||
name="Assistant",
|
||||
tools=[WebSearchTool()],
|
||||
instructions="speak in spanish. use Haiku agent if they ask for a haiku or for the weather",
|
||||
handoffs=[haiku_agent],
|
||||
)
|
||||
|
||||
result = await Runner.run(
|
||||
agent,
|
||||
"write a haiku about the weather today and tell me a recent news story about new york",
|
||||
)
|
||||
print(result.final_output)
|
||||
```
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(
    self,
    client: Optional[ls_client.Client] = None,
    *,
    metadata: Optional[dict] = None,
    tags: Optional[list[str]] = None,
    project_name: Optional[str] = None,
    name: Optional[str] = None,
):
    """Store the trace configuration and create empty bookkeeping state.

    Args:
        client: LangSmith client; ``rt.get_cached_client()`` is used when
            this is falsy.
        metadata: Metadata to associate with all traces.
        tags: Tags to associate with all traces.
        project_name: LangSmith project to trace to.
        name: Name of the root trace.
    """
    self.client = client if client else rt.get_cached_client()

    # Configuration applied to every trace.
    self._name = name
    self._project_name = project_name
    self._tags = tags
    self._metadata = metadata

    # Runs currently open, keyed by trace id or span id.
    self._runs: dict[str, rt.RunTree] = {}
    # Ids of traces / spans created locally but not yet posted.
    self._unposted_traces: set[str] = set()
    self._unposted_spans: set[str] = set()

    # Per-trace inputs of the first and outputs of the last response or
    # generation span, used to fill in the root run.
    self._first_response_inputs: dict = {}
    self._last_response_outputs: dict = {}
|
||||
|
||||
def on_trace_start(self, trace: tracing.Trace) -> None:
    """Create (but do not yet post) the root run for a newly started trace.

    The run is nested under the current LangSmith run tree when one exists,
    otherwise it becomes a new root ``RunTree``. It is registered in
    ``self._runs`` under the trace id and set as the parent run in context.
    Errors while creating the run are logged, not raised.
    """
    enclosing_run = get_current_run_tree()

    # Explicit processor name first, then the SDK trace name, then a default.
    run_name = self._name or trace.name or "Agent workflow"

    metadata = dict(self._metadata or {})
    metadata["ls_integration"] = "openai-agents-sdk"
    metadata["ls_integration_version"] = _get_package_version("openai-agents")

    exported = trace.export() or {}
    group_id = exported.get("group_id")
    if group_id is not None:
        metadata["thread_id"] = group_id
    run_extra = {"metadata": metadata}

    try:
        if enclosing_run is None:
            # No surrounding trace: start a new root run.
            root_kwargs = {
                "name": run_name,
                "run_type": "chain",
                "inputs": {},
                "extra": run_extra,
                "tags": self._tags,
                "client": self.client,
            }
            if self._project_name is not None:
                root_kwargs["project_name"] = self._project_name
            new_run = rt.RunTree(**root_kwargs)  # type: ignore[arg-type]
        else:
            # Nest under the run that is already active.
            new_run = enclosing_run.create_child(
                name=run_name,
                run_type="chain",
                inputs={},
                extra=run_extra,
                tags=self._tags,
            )

        # Posting is deferred until the first response/generation span ends,
        # so that its inputs can be included in the POST.
        self._unposted_traces.add(trace.trace_id)
        _context._PARENT_RUN_TREE.set(new_run)
        self._runs[trace.trace_id] = new_run
    except Exception as e:
        logger.exception(f"Error creating trace run: {e}")
|
||||
|
||||
def on_trace_end(self, trace: tracing.Trace) -> None:
    """Finalize the root run of a trace and send it to LangSmith.

    The run receives the outputs of the last response/generation span and
    the merged trace + processor metadata, then is ended. If the run was
    never posted (no response/generation span ended) it is posted now with
    the first recorded inputs; otherwise it is patched without inputs.

    Errors are logged, not raised. Regardless of errors, the per-trace
    bookkeeping entries are removed and the parent-run context variable is
    reset to this run's parent, so a failed update cannot leave later code
    nested under a finished run or leak state for this trace id.
    """
    trace_id = trace.trace_id
    run = self._runs.pop(trace_id, None)

    # Drain the bookkeeping up front: nothing below may leave stale entries.
    last_outputs = self._last_response_outputs.pop(trace_id, {})
    first_inputs = self._first_response_inputs.pop(trace_id, {})
    never_posted = trace_id in self._unposted_traces
    self._unposted_traces.discard(trace_id)

    if not run:
        return

    try:
        trace_dict = trace.export() or {}
        # Processor-level metadata wins over the trace's own metadata.
        metadata = {**(trace_dict.get("metadata") or {}), **(self._metadata or {})}

        run.outputs = last_outputs
        if "metadata" not in run.extra:
            run.extra["metadata"] = {}
        run.extra["metadata"].update(metadata)

        run.end()

        if never_posted:
            # No response/generation span ended, so the run was never sent.
            run.inputs = first_inputs
            run.post()
        else:
            run.patch(exclude_inputs=True)
    except Exception as e:
        logger.exception(f"Error updating trace run: {e}")
    finally:
        # Always hand the context back to whatever run enclosed this trace.
        _context._PARENT_RUN_TREE.set(run.parent_run)
|
||||
|
||||
def on_span_start(self, span: tracing.Span) -> None:
    """Create a child run for a starting span under its parent run.

    The parent is looked up in ``self._runs`` by the span's ``parent_id``
    or, when that is falsy, by its ``trace_id``. Spans without a known
    parent are skipped with a warning. Generation, response and function
    spans are only registered here and posted when they end; every other
    span is posted immediately. Errors are logged, not raised.
    """
    parent_run = self._runs.get(span.parent_id or span.trace_id)
    if parent_run is None:
        logger.warning(f"No trace info found for span, skipping: {span.span_id}")
        return

    run_name = agent_utils.get_run_name(span)
    if isinstance(span.span_data, tracing.ResponseSpanData):
        # Response runs are named "<parent name> <span name>".
        explicit_name = getattr(span, "name", None) or getattr(
            span.span_data, "name", None
        )
        label = str(explicit_name) if explicit_name else run_name
        prefix = parent_run.name
        run_name = f"{prefix} {label}".strip() if prefix else label

    run_type = agent_utils.get_run_type(span)
    extracted = agent_utils.extract_span_data(span)

    try:
        started = (
            datetime.fromisoformat(span.started_at) if span.started_at else None
        )
        child_run = parent_run.create_child(
            name=run_name,
            run_type=run_type,
            inputs=extracted.get("inputs", {}),
            extra=extracted,
            start_time=started,
        )

        # Inputs of these span types are not available at start, so their
        # POST is delayed until the span ends.
        post_on_end = isinstance(
            span.span_data,
            (
                tracing.GenerationSpanData,
                tracing.ResponseSpanData,
                tracing.FunctionSpanData,
            ),
        )
        if post_on_end:
            self._unposted_spans.add(span.span_id)
        else:
            child_run.post()
        self._runs[span.span_id] = child_run
    except Exception as e:
        logger.exception(f"Error creating span run: {e}")
|
||||
|
||||
def on_span_end(self, span: tracing.Span) -> None:
    """Finalize the run of an ended span and send it to LangSmith.

    Re-extracts the span data to fill in outputs, inputs, error, metadata
    and invocation params. For response and generation spans it also
    records the trace's first inputs / last outputs and posts the root
    trace run if that is still pending. The run is posted if it was
    deferred at start, otherwise patched without inputs. Spans with no
    registered run are ignored; errors are logged, not raised.
    """
    run = self._runs.pop(span.span_id, None)
    if not run:
        return

    try:
        extracted = agent_utils.extract_span_data(span)
        outputs = extracted.pop("outputs", {})
        inputs = extracted.pop("inputs", {})

        run.outputs = outputs
        if inputs:
            run.inputs = inputs
        error = span.error
        if error:
            run.error = str(error)

        # Record the OpenAI-side identifiers, then any extractor metadata.
        run_metadata = run.extra.setdefault("metadata", {})
        run_metadata.update(
            {
                "openai_parent_id": span.parent_id,
                "openai_trace_id": span.trace_id,
                "openai_span_id": span.span_id,
            }
        )
        extra_metadata = extracted.get("metadata")
        if extra_metadata:
            run_metadata.update(extra_metadata)
        invocation_params = extracted.get("invocation_params")
        if invocation_params:
            run.extra["invocation_params"] = invocation_params

        if isinstance(
            span.span_data,
            (tracing.ResponseSpanData, tracing.GenerationSpanData),
        ):
            trace_id = span.trace_id
            # Keep the first truthy inputs seen for the trace, and always
            # the most recent outputs.
            self._first_response_inputs[trace_id] = (
                self._first_response_inputs.get(trace_id) or inputs
            )
            self._last_response_outputs[trace_id] = outputs
            self._maybe_post_trace(trace_id, inputs)

        if span.ended_at:
            run.end_time = datetime.fromisoformat(span.ended_at)
        else:
            run.end()

        if span.span_id in self._unposted_spans:
            self._unposted_spans.discard(span.span_id)
            run.post()
        else:
            run.patch(exclude_inputs=True)
    except Exception as e:
        logger.exception(f"Error updating span run: {e}")
|
||||
|
||||
def _maybe_post_trace(self, trace_id: str, inputs: dict) -> None:
    """Post the root run of ``trace_id`` if it is still pending.

    Does nothing unless the trace id is in ``self._unposted_traces`` and a
    run is registered for it in ``self._runs``. Otherwise the run's inputs
    are set to ``inputs``, the run is posted, and the id is removed from
    the pending set.
    """
    if trace_id not in self._unposted_traces:
        return
    pending_run = self._runs.get(trace_id)
    if not pending_run:
        return
    pending_run.inputs = inputs
    pending_run.post()
    self._unposted_traces.discard(trace_id)
|
||||
|
||||
def shutdown(self) -> None:
    """Flush the LangSmith client."""
    client = self.client
    client.flush()
|
||||
|
||||
def force_flush(self) -> None:
    """Flush the LangSmith client."""
    client = self.client
    client.flush()
|
||||
114
venv/Lib/site-packages/langsmith/integrations/otel/__init__.py
Normal file
114
venv/Lib/site-packages/langsmith/integrations/otel/__init__.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""OpenTelemetry integration for LangSmith."""
|
||||
|
||||
import logging
|
||||
from typing import Optional, cast
|
||||
|
||||
from langsmith import utils as ls_utils
|
||||
|
||||
from .processor import OtelExporter, OtelSpanProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["configure", "OtelSpanProcessor", "OtelExporter"]
|
||||
|
||||
|
||||
def configure(
    api_key: Optional[str] = None,
    project_name: Optional[str] = None,
    SpanProcessor: Optional[type] = None,
) -> bool:
    """Configure OpenTelemetry with LangSmith as the `TracerProvider`.

    Initializes OpenTelemetry with LangSmith as the primary and only `TracerProvider`.

    Usage:
        >>> from langsmith.integrations.otel import configure
        >>> configure(  # doctest: +SKIP
        ...     api_key="your-api-key", project_name="your-project"
        ... )

    Using environment variables:
        >>> # Set LANGSMITH_API_KEY and LANGSMITH_PROJECT
        >>> configure()  # Will use env vars # doctest: +SKIP

    !!! warning

        This function is only for when LangSmith is your ONLY OpenTelemetry source.

        It sets the global TracerProvider, which can only be done once per application.

        This function will fail if OpenTelemetry is already initialized with another
        `TracerProvider` (you cannot override an existing `TracerProvider`).

    If you already have OpenTelemetry set up with other tools, use `OtelSpanProcessor`
    directly to add LangSmith to your existing setup:

    !!! example "Adding LangSmith to existing OTEL setup"

        ```python
        from opentelemetry import trace
        from langsmith.integrations.otel.processor import OtelSpanProcessor

        # Use your existing provider (already initialized)
        provider = trace.get_tracer_provider()

        # Add LangSmith processor to existing provider
        langsmith_processor = OtelSpanProcessor(
            api_key="your-api-key", project="your-project"
        )
        provider.add_span_processor(langsmith_processor)
        ```

    Args:
        api_key: LangSmith API key. Defaults to `LANGSMITH_API_KEY` env var.
        project_name: Project name. Defaults to `LANGSMITH_PROJECT` env var.
        SpanProcessor: Span processor class to use. Defaults to `BatchSpanProcessor`.

    Returns:
        `True` if configuration succeeded. `False` if a `TracerProvider` already
        exists, if no API key could be resolved, or if initialization failed; in
        all of these cases the global `TracerProvider` is left untouched.
    """
    try:
        from opentelemetry import trace
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.trace import NoOpTracer, ProxyTracer, ProxyTracerProvider

        existing_provider = cast(TracerProvider, trace.get_tracer_provider())
        tracer = existing_provider.get_tracer(__name__)

        # OpenTelemetry is in its default uninitialized state when the provider
        # is a ProxyTracerProvider whose tracer still wraps a NoOpTracer
        # (i.e. no real TracerProvider was set).
        uninitialized = (
            isinstance(existing_provider, ProxyTracerProvider)
            and hasattr(tracer, "_tracer")
            and isinstance(
                cast(
                    ProxyTracer,  # type: ignore[attr-defined, name-defined]
                    tracer,
                )._tracer,
                NoOpTracer,
            )
        )
        if not uninitialized:
            logger.warning(
                "OpenTelemetry TracerProvider is already set. "
                "Cannot override existing TracerProvider. Use OtelSpanProcessor "
                "directly to add LangSmith to your existing provider instead."
            )
            return False

        # Resolve credentials and build the processor BEFORE touching global
        # state: the global TracerProvider can only be set once, so it must
        # not be installed on a path that ends up returning False.
        api_key = api_key or ls_utils.get_api_key(None)
        if not api_key:
            logger.warning(
                "No LangSmith API key found; OpenTelemetry was not configured."
            )
            return False

        project_name = project_name or ls_utils.get_tracer_project()

        processor = OtelSpanProcessor(
            api_key=api_key, project=project_name, SpanProcessor=SpanProcessor
        )

        # Safe to set the TracerProvider since none exists yet.
        provider = TracerProvider()
        provider.add_span_processor(processor)  # type: ignore
        trace.set_tracer_provider(provider)
        return True
    except Exception as e:
        # Lazy %-formatting: passing ``e`` without a placeholder makes the
        # logging module report a formatting error instead of the message.
        logger.warning("Failed to initialize Otel for LangSmith: %s", e)
        return False
|
||||
Binary file not shown.
Binary file not shown.
222
venv/Lib/site-packages/langsmith/integrations/otel/processor.py
Normal file
222
venv/Lib/site-packages/langsmith/integrations/otel/processor.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""OpenTelemetry span processor and exporter for LangSmith."""
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from langsmith import utils as ls_utils
|
||||
|
||||
try:
    from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
    from opentelemetry.sdk.trace.export import BatchSpanProcessor
except ImportError:
    # One message shared by the import-time warning and every stand-in below.
    _OTEL_MISSING_MESSAGE = (
        "OpenTelemetry packages are not installed. "
        "Install optional OpenTelemetry dependencies with: "
        "pip install langsmith[otel]"
    )

    warnings.warn(_OTEL_MISSING_MESSAGE, UserWarning, stacklevel=2)

    class OTLPSpanExporter:  # type: ignore[no-redef]
        """Stand-in for the OTLP span exporter when OpenTelemetry is missing."""

        def __init__(self, *args, **kwargs):
            """Always raise ``ImportError``."""
            raise ImportError(_OTEL_MISSING_MESSAGE)

    class BatchSpanProcessor:  # type: ignore[no-redef]
        """Stand-in for the batch span processor when OpenTelemetry is missing."""

        def __init__(self, *args, **kwargs):
            """Always raise ``ImportError``."""
            raise ImportError(_OTEL_MISSING_MESSAGE)

    class trace:
        """Stand-in for the ``trace`` module when OpenTelemetry is missing."""

        @staticmethod
        def get_tracer_provider():
            """Always raise ``ImportError``."""
            raise ImportError(_OTEL_MISSING_MESSAGE)

    OTEL_AVAILABLE = False
else:
    OTEL_AVAILABLE = True
|
||||
|
||||
|
||||
class OtelExporter(OTLPSpanExporter):
    """A subclass of `OTLPSpanExporter` configured for LangSmith.

    Environment Variables:

    - `LANGSMITH_API_KEY`: Your LangSmith API key.
    - `LANGSMITH_ENDPOINT`: Base URL for LangSmith API (defaults to `https://api.smith.langchain.com`).
    - `LANGSMITH_PROJECT`: Project identifier.
    """

    def __init__(
        self,
        url: Optional[str] = None,
        api_key: Optional[str] = None,
        project: Optional[str] = None,
        headers: Optional[dict[str, str]] = None,
        **kwargs,
    ):
        """Initialize the `OtelExporter`.

        Args:
            url: OTLP endpoint URL. Defaults to `{LANGSMITH_ENDPOINT}/otel/v1/traces`.
            api_key: LangSmith API key. Defaults to `LANGSMITH_API_KEY` env var.
            project: Project identifier, sent as the `Langsmith-Project` header.
                Defaults to the project resolved by
                `ls_utils.get_tracer_project()`, and to `"default"` if that
                is empty.
            headers: Additional headers to include in requests. They are
                merged after `x-api-key`, so they can override it.
            **kwargs: Additional arguments passed to `OTLPSpanExporter`.

        Raises:
            ValueError: If no API key is given and none can be resolved.
        """
        base_url = ls_utils.get_api_url(None)
        # Ensure base_url ends with / so urljoin appends instead of replacing
        # the last path segment.
        if not base_url.endswith("/"):
            base_url += "/"
        endpoint = url or urljoin(base_url, "otel/v1/traces")
        api_key = api_key or ls_utils.get_api_key(None)
        project = project or ls_utils.get_tracer_project()
        headers = headers or {}

        if not api_key:
            raise ValueError(
                "API key is required. Provide it via api_key parameter or "
                "LANGSMITH_API_KEY environment variable."
            )

        if not project:
            project = "default"
            # Use this module's logger: the module-level ``logging.info``
            # logs on the root logger and may implicitly configure it.
            logging.getLogger(__name__).info(
                "No project specified, using default. "
                "Configure with LANGSMITH_PROJECT environment variable or "
                "project parameter."
            )

        # ``project`` is always non-empty here, so the header is always sent.
        exporter_headers = {
            "x-api-key": api_key,
            **headers,
            "Langsmith-Project": project,
        }

        self.project = project

        super().__init__(endpoint=endpoint, headers=exporter_headers, **kwargs)
|
||||
|
||||
|
||||
class OtelSpanProcessor:
|
||||
"""A span processor for adding LangSmith to OpenTelemetry setups.
|
||||
|
||||
This class combines the `OtelExporter` and `BatchSpanProcessor`
|
||||
into a single processor that can be added to any `TracerProvider`.
|
||||
|
||||
Use this when:
|
||||
|
||||
1. You already have OpenTelemetry initialized with other tools
|
||||
2. You want to add LangSmith alongside existing OTEL exporters
|
||||
|
||||
Examples:
|
||||
# Fresh OpenTelemetry setup (LangSmith only):
|
||||
from langsmith.integrations.otel import configure
|
||||
configure(api_key="your-key", project_name="your-project")
|
||||
|
||||
# Add LangSmith to existing OpenTelemetry setup:
|
||||
from opentelemetry import trace
|
||||
from langsmith.integrations.otel.processor import OtelSpanProcessor
|
||||
|
||||
# Get your existing TracerProvider (already set by other tools)
|
||||
provider = trace.get_tracer_provider()
|
||||
|
||||
# Add LangSmith processor alongside existing processors
|
||||
langsmith_processor = OtelSpanProcessor(
|
||||
project="your-project",
|
||||
)
|
||||
provider.add_span_processor(langsmith_processor)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
SpanProcessor: Optional[type] = None,
|
||||
):
|
||||
"""Initialize the `OtelSpanProcessor`.
|
||||
|
||||
Args:
|
||||
api_key: LangSmith API key. Defaults to `LANGSMITH_API_KEY` env var.
|
||||
project: Project identifier. Defaults to `LANGSMITH_PROJECT` env var.
|
||||
url: Base URL for LangSmith API. Defaults to `LANGSMITH_ENDPOINT` env var
|
||||
or `https://api.smith.langchain.com`.
|
||||
headers: Additional headers to include in requests.
|
||||
SpanProcessor: Optional span processor class. Defaults to
|
||||
`BatchSpanProcessor`.
|
||||
"""
|
||||
# Create the exporter
|
||||
# Convert url to the full endpoint URL that OtelExporter expects
|
||||
exporter_url = None
|
||||
if url:
|
||||
exporter_url = f"{url.rstrip('/')}/otel/v1/traces"
|
||||
|
||||
self._exporter = OtelExporter(
|
||||
url=exporter_url, api_key=api_key, project=project, headers=headers
|
||||
)
|
||||
|
||||
# Create the processor chain
|
||||
if not OTEL_AVAILABLE:
|
||||
raise ImportError(
|
||||
"OpenTelemetry packages are not installed. "
|
||||
"Install optional OpenTelemetry dependencies with: "
|
||||
"pip install langsmith[otel]"
|
||||
)
|
||||
|
||||
if SpanProcessor is None:
|
||||
SpanProcessor = BatchSpanProcessor
|
||||
|
||||
self._processor = SpanProcessor(self._exporter)
|
||||
|
||||
def on_start(self, span, parent_context=None):
|
||||
"""Forward span start events to the inner processor."""
|
||||
self._processor.on_start(span, parent_context)
|
||||
|
||||
def on_end(self, span):
|
||||
"""Forward span end events to the inner processor."""
|
||||
self._processor.on_end(span)
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown processor."""
|
||||
self._processor.shutdown()
|
||||
|
||||
def force_flush(self, timeout_millis=30000):
|
||||
"""Force flush the inner processor."""
|
||||
return self._processor.force_flush(timeout_millis)
|
||||
|
||||
@property
|
||||
def exporter(self):
|
||||
"""The underlying OtelExporter."""
|
||||
return self._exporter
|
||||
|
||||
@property
|
||||
def processor(self):
|
||||
"""The underlying span processor."""
|
||||
return self._processor
|
||||
Reference in New Issue
Block a user