initial commit
This commit is contained in:
923
venv/Lib/site-packages/langsmith/utils.py
Normal file
923
venv/Lib/site-packages/langsmith/utils.py
Normal file
@@ -0,0 +1,923 @@
|
||||
"""Generic utility functions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import contextvars
|
||||
import copy
|
||||
import enum
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Literal,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
from urllib import parse as urllib_parse
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
from typing_extensions import ParamSpec
|
||||
from urllib3.util import Retry # type: ignore[import-untyped]
|
||||
|
||||
from langsmith import schemas as ls_schemas
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LangSmithError(Exception):
|
||||
"""An error occurred while communicating with the LangSmith API."""
|
||||
|
||||
|
||||
class LangSmithAPIError(LangSmithError):
|
||||
"""Internal server error while communicating with LangSmith."""
|
||||
|
||||
|
||||
class LangSmithRequestTimeout(LangSmithError):
|
||||
"""Client took too long to send request body."""
|
||||
|
||||
|
||||
class LangSmithUserError(LangSmithError):
|
||||
"""User error caused an exception when communicating with LangSmith."""
|
||||
|
||||
|
||||
class LangSmithRateLimitError(LangSmithError):
|
||||
"""You have exceeded the rate limit for the LangSmith API."""
|
||||
|
||||
|
||||
class LangSmithAuthError(LangSmithError):
|
||||
"""Couldn't authenticate with the LangSmith API."""
|
||||
|
||||
|
||||
class LangSmithNotFoundError(LangSmithError):
|
||||
"""Couldn't find the requested resource."""
|
||||
|
||||
|
||||
class LangSmithConflictError(LangSmithError):
|
||||
"""The resource already exists."""
|
||||
|
||||
|
||||
class LangSmithConnectionError(LangSmithError):
|
||||
"""Couldn't connect to the LangSmith API."""
|
||||
|
||||
|
||||
class LangSmithExceptionGroup(LangSmithError):
|
||||
"""Port of ExceptionGroup for Py < 3.11."""
|
||||
|
||||
def __init__(
|
||||
self, *args: Any, exceptions: Sequence[Exception], **kwargs: Any
|
||||
) -> None:
|
||||
"""Initialize."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.exceptions = exceptions
|
||||
|
||||
|
||||
def get_invalid_prompt_identifier_msg(identifier: str) -> str:
|
||||
"""Get the error message for an invalid prompt identifier.
|
||||
|
||||
Used consistently across the codebase when parsing prompt identifiers fails.
|
||||
|
||||
Args:
|
||||
identifier: The invalid identifier that was provided.
|
||||
|
||||
Returns:
|
||||
A formatted error message explaining the valid formats.
|
||||
"""
|
||||
return (
|
||||
f'Invalid prompt identifier format: "{identifier}". '
|
||||
f"Expected one of:\n"
|
||||
f' - "prompt-name" (for private prompts)\n'
|
||||
f' - "owner/prompt-name" (for prompts with explicit owner)\n'
|
||||
f' - "prompt-name:commit-hash" (with commit reference)\n'
|
||||
f' - "owner/prompt-name:commit-hash" (with owner and commit)'
|
||||
)
|
||||
|
||||
|
||||
## Warning classes
|
||||
|
||||
|
||||
class LangSmithWarning(UserWarning):
|
||||
"""Base class for warnings."""
|
||||
|
||||
|
||||
class LangSmithMissingAPIKeyWarning(LangSmithWarning):
|
||||
"""Warning for missing API key."""
|
||||
|
||||
|
||||
def tracing_is_enabled(ctx: Optional[dict] = None) -> Union[bool, Literal["local"]]:
|
||||
"""Return True if tracing is enabled."""
|
||||
# Access global fallbacks via context module to avoid stale references.
|
||||
import langsmith._internal._context as _context
|
||||
from langsmith.run_helpers import get_current_run_tree, get_tracing_context
|
||||
|
||||
tc = ctx or get_tracing_context()
|
||||
# You can manually override the environment using context vars.
|
||||
# Check that first.
|
||||
# Doing this before checking the run tree lets us
|
||||
# disable a branch within a trace.
|
||||
if tc["enabled"] is not None:
|
||||
return tc["enabled"]
|
||||
# Next check if we're mid-trace
|
||||
if get_current_run_tree():
|
||||
return True
|
||||
# If a global fallback was configured, use it next.
|
||||
if _context._GLOBAL_TRACING_ENABLED is not None:
|
||||
return _context._GLOBAL_TRACING_ENABLED
|
||||
# Finally, check the global environment
|
||||
var_result = get_env_var("TRACING_V2", default=get_env_var("TRACING", default=""))
|
||||
return var_result == "true"
|
||||
|
||||
|
||||
def test_tracking_is_disabled() -> bool:
|
||||
"""Return True if testing is enabled."""
|
||||
return get_env_var("TEST_TRACKING", default="") == "false"
|
||||
|
||||
|
||||
def xor_args(*arg_groups: tuple[str, ...]) -> Callable:
|
||||
"""Validate specified keyword args are mutually exclusive."""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Validate exactly one arg in each group is not None."""
|
||||
counts = [
|
||||
sum(1 for arg in arg_group if kwargs.get(arg) is not None)
|
||||
for arg_group in arg_groups
|
||||
]
|
||||
invalid_groups = [i for i, count in enumerate(counts) if count != 1]
|
||||
if invalid_groups:
|
||||
invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups]
|
||||
raise ValueError(
|
||||
"Exactly one argument in each of the following"
|
||||
" groups must be defined:"
|
||||
f" {', '.join(invalid_group_names)}"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def raise_for_status_with_text(
|
||||
response: Union[requests.Response, httpx.Response],
|
||||
) -> None:
|
||||
"""Raise an error with the response text."""
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
raise requests.HTTPError(str(e), response.text) from e # type: ignore[call-arg]
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise httpx.HTTPStatusError(
|
||||
f"{str(e)}: {response.text}",
|
||||
request=response.request, # type: ignore[arg-type]
|
||||
response=response, # type: ignore[arg-type]
|
||||
) from e
|
||||
|
||||
|
||||
def get_enum_value(enu: Union[enum.Enum, str]) -> str:
|
||||
"""Get the value of a string enum."""
|
||||
if isinstance(enu, enum.Enum):
|
||||
return enu.value
|
||||
return enu
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def log_once(level: int, message: str) -> None:
|
||||
"""Log a message at the specified level, but only once."""
|
||||
_LOGGER.log(level, message)
|
||||
|
||||
|
||||
def _get_message_type(message: Mapping[str, Any]) -> str:
|
||||
if not message:
|
||||
raise ValueError("Message is empty.")
|
||||
if "lc" in message:
|
||||
if "id" not in message:
|
||||
raise ValueError(
|
||||
f"Unexpected format for serialized message: {message}"
|
||||
" Message does not have an id."
|
||||
)
|
||||
return message["id"][-1].replace("Message", "").lower()
|
||||
else:
|
||||
if "type" not in message:
|
||||
raise ValueError(
|
||||
f"Unexpected format for stored message: {message}"
|
||||
" Message does not have a type."
|
||||
)
|
||||
return message["type"]
|
||||
|
||||
|
||||
def _get_message_fields(message: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
if not message:
|
||||
raise ValueError("Message is empty.")
|
||||
if "lc" in message:
|
||||
if "kwargs" not in message:
|
||||
raise ValueError(
|
||||
f"Unexpected format for serialized message: {message}"
|
||||
" Message does not have kwargs."
|
||||
)
|
||||
return message["kwargs"]
|
||||
else:
|
||||
if "data" not in message:
|
||||
raise ValueError(
|
||||
f"Unexpected format for stored message: {message}"
|
||||
" Message does not have data."
|
||||
)
|
||||
return message["data"]
|
||||
|
||||
|
||||
def _convert_message(message: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"""Extract message from a message object."""
|
||||
message_type = _get_message_type(message)
|
||||
message_data = _get_message_fields(message)
|
||||
return {"type": message_type, "data": message_data}
|
||||
|
||||
|
||||
def get_messages_from_inputs(inputs: Mapping[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Extract messages from the given inputs dictionary.
|
||||
|
||||
Args:
|
||||
inputs: The inputs dictionary.
|
||||
|
||||
Returns:
|
||||
A list of dictionaries representing the extracted messages.
|
||||
|
||||
Raises:
|
||||
ValueError: If no message(s) are found in the inputs dictionary.
|
||||
"""
|
||||
if "messages" in inputs:
|
||||
return [_convert_message(message) for message in inputs["messages"]]
|
||||
if "message" in inputs:
|
||||
return [_convert_message(inputs["message"])]
|
||||
raise ValueError(f"Could not find message(s) in run with inputs {inputs}.")
|
||||
|
||||
|
||||
def get_message_generation_from_outputs(outputs: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"""Retrieve the message generation from the given outputs.
|
||||
|
||||
Args:
|
||||
outputs: The outputs dictionary.
|
||||
|
||||
Returns:
|
||||
The message generation.
|
||||
|
||||
Raises:
|
||||
ValueError: If no generations are found or if multiple generations are present.
|
||||
"""
|
||||
if "generations" not in outputs:
|
||||
raise ValueError(f"No generations found in in run with output: {outputs}.")
|
||||
generations = outputs["generations"]
|
||||
if len(generations) != 1:
|
||||
raise ValueError(
|
||||
"Chat examples expect exactly one generation."
|
||||
f" Found {len(generations)} generations: {generations}."
|
||||
)
|
||||
first_generation = generations[0]
|
||||
if "message" not in first_generation:
|
||||
raise ValueError(
|
||||
f"Unexpected format for generation: {first_generation}."
|
||||
" Generation does not have a message."
|
||||
)
|
||||
return _convert_message(first_generation["message"])
|
||||
|
||||
|
||||
def get_prompt_from_inputs(inputs: Mapping[str, Any]) -> str:
|
||||
"""Retrieve the prompt from the given inputs.
|
||||
|
||||
Args:
|
||||
inputs: The inputs dictionary.
|
||||
|
||||
Returns:
|
||||
str: The prompt.
|
||||
|
||||
Raises:
|
||||
ValueError: If the prompt is not found or if multiple prompts are present.
|
||||
"""
|
||||
if "prompt" in inputs:
|
||||
return inputs["prompt"]
|
||||
if "prompts" in inputs:
|
||||
prompts = inputs["prompts"]
|
||||
if len(prompts) == 1:
|
||||
return prompts[0]
|
||||
raise ValueError(
|
||||
f"Multiple prompts in run with inputs {inputs}."
|
||||
" Please create example manually."
|
||||
)
|
||||
raise ValueError(f"Could not find prompt in run with inputs {inputs}.")
|
||||
|
||||
|
||||
def get_llm_generation_from_outputs(outputs: Mapping[str, Any]) -> str:
|
||||
"""Get the LLM generation from the outputs."""
|
||||
if "generations" not in outputs:
|
||||
raise ValueError(f"No generations found in in run with output: {outputs}.")
|
||||
generations = outputs["generations"]
|
||||
if len(generations) != 1:
|
||||
raise ValueError(f"Multiple generations in run: {generations}")
|
||||
first_generation = generations[0]
|
||||
if "text" not in first_generation:
|
||||
raise ValueError(f"No text in generation: {first_generation}")
|
||||
return first_generation["text"]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def get_docker_compose_command() -> list[str]:
|
||||
"""Get the correct docker compose command for this system."""
|
||||
try:
|
||||
subprocess.check_call(
|
||||
["docker", "compose", "--version"],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
return ["docker", "compose"]
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
try:
|
||||
subprocess.check_call(
|
||||
["docker-compose", "--version"],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
return ["docker-compose"]
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
raise ValueError(
|
||||
"Neither 'docker compose' nor 'docker-compose'"
|
||||
" commands are available. Please install the Docker"
|
||||
" server following the instructions for your operating"
|
||||
" system at https://docs.docker.com/engine/install/"
|
||||
)
|
||||
|
||||
|
||||
def convert_langchain_message(message: ls_schemas.BaseMessageLike) -> dict:
|
||||
"""Convert a LangChain message to an example."""
|
||||
converted: dict[str, Any] = {
|
||||
"type": message.type,
|
||||
"data": {"content": message.content},
|
||||
}
|
||||
# Check for presence of keys in additional_kwargs
|
||||
if message.additional_kwargs and len(message.additional_kwargs) > 0:
|
||||
converted["data"]["additional_kwargs"] = {**message.additional_kwargs}
|
||||
return converted
|
||||
|
||||
|
||||
def is_base_message_like(obj: object) -> bool:
|
||||
"""Check if the given object is similar to `BaseMessage`.
|
||||
|
||||
Args:
|
||||
obj: The object to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the object is similar to `BaseMessage`, `False` otherwise.
|
||||
"""
|
||||
return all(
|
||||
[
|
||||
isinstance(getattr(obj, "content", None), str),
|
||||
isinstance(getattr(obj, "additional_kwargs", None), dict),
|
||||
hasattr(obj, "type") and isinstance(getattr(obj, "type"), str),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def is_env_var_truish(value: Optional[str]) -> bool:
|
||||
"""Check if the given environment variable is truish."""
|
||||
if value is None:
|
||||
return False
|
||||
return is_truish(get_env_var(value))
|
||||
|
||||
|
||||
@overload
|
||||
def get_env_var(
|
||||
name: str,
|
||||
default: str,
|
||||
*,
|
||||
namespaces: tuple = ("LANGSMITH", "LANGCHAIN"),
|
||||
) -> str: ...
|
||||
|
||||
|
||||
@overload
|
||||
def get_env_var(
|
||||
name: str,
|
||||
default: None = None,
|
||||
*,
|
||||
namespaces: tuple = ("LANGSMITH", "LANGCHAIN"),
|
||||
) -> Optional[str]: ...
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=100)
|
||||
def get_env_var(
|
||||
name: str,
|
||||
default: Optional[str] = None,
|
||||
*,
|
||||
namespaces: tuple = ("LANGSMITH", "LANGCHAIN"),
|
||||
) -> Optional[str]:
|
||||
"""Retrieve an environment variable from a list of namespaces.
|
||||
|
||||
Args:
|
||||
name: The name of the environment variable.
|
||||
default: The default value to return if the environment variable is not found.
|
||||
namespaces: A tuple of namespaces to search for the environment variable.
|
||||
|
||||
Defaults to `('LANGSMITH', 'LANGCHAINs')`.
|
||||
|
||||
Returns:
|
||||
The value of the environment variable if found, otherwise the default value.
|
||||
"""
|
||||
names = [f"{namespace}_{name}" for namespace in namespaces]
|
||||
for name in names:
|
||||
value = os.environ.get(name)
|
||||
if value is not None and value.strip() != "":
|
||||
return value
|
||||
return default
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def get_tracer_project(return_default_value=True) -> Optional[str]:
|
||||
"""Get the project name for a LangSmith tracer."""
|
||||
return os.environ.get(
|
||||
# Hosted LangServe projects get precedence over all other defaults.
|
||||
# This is to make sure that we always use the associated project
|
||||
# for a hosted langserve deployment even if the customer sets some
|
||||
# other project name in their environment.
|
||||
"HOSTED_LANGSERVE_PROJECT_NAME",
|
||||
get_env_var(
|
||||
"PROJECT",
|
||||
# This is the legacy name for a LANGCHAIN_PROJECT, so it
|
||||
# has lower precedence than LANGCHAIN_PROJECT
|
||||
default=get_env_var(
|
||||
"SESSION", default="default" if return_default_value else None
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class FilterPoolFullWarning(logging.Filter):
|
||||
"""Filter `urllib3` warnings logged when the connection pool isn't reused."""
|
||||
|
||||
def __init__(self, name: str = "", host: str = "") -> None:
|
||||
"""Initialize the `FilterPoolFullWarning` filter.
|
||||
|
||||
Args:
|
||||
name: The name of the filter. Defaults to `""`.
|
||||
host: The host to filter. Defaults to `""`.
|
||||
"""
|
||||
super().__init__(name)
|
||||
self._host = host
|
||||
|
||||
def filter(self, record) -> bool:
|
||||
"""urllib3.connectionpool:Connection pool is full, discarding connection: ..."""
|
||||
msg = record.getMessage()
|
||||
if "Connection pool is full, discarding connection" not in msg:
|
||||
return True
|
||||
return self._host not in msg
|
||||
|
||||
|
||||
class FilterLangSmithRetry(logging.Filter):
|
||||
"""Filter for retries from this lib."""
|
||||
|
||||
def filter(self, record) -> bool:
|
||||
"""Filter retries from this library."""
|
||||
# We re-raise/log manually.
|
||||
msg = record.getMessage()
|
||||
return "LangSmithRetry" not in msg
|
||||
|
||||
|
||||
class LangSmithRetry(Retry):
|
||||
"""Wrapper to filter logs with this name."""
|
||||
|
||||
|
||||
_FILTER_LOCK = threading.RLock()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def filter_logs(
|
||||
logger: logging.Logger, filters: Sequence[logging.Filter]
|
||||
) -> Generator[None, None, None]:
|
||||
"""Temporarily adds specified filters to a logger.
|
||||
|
||||
Parameters:
|
||||
- logger: The logger to which the filters will be added.
|
||||
- filters: A sequence of `logging.Filter` objects to be temporarily added
|
||||
to the logger.
|
||||
"""
|
||||
with _FILTER_LOCK:
|
||||
for filter in filters:
|
||||
logger.addFilter(filter)
|
||||
# Not actually perfectly thread-safe, but it's only log filters
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
with _FILTER_LOCK:
|
||||
for filter in filters:
|
||||
try:
|
||||
logger.removeFilter(filter)
|
||||
except BaseException:
|
||||
_LOGGER.warning("Failed to remove filter")
|
||||
|
||||
|
||||
def get_cache_dir(cache: Optional[str]) -> Optional[str]:
|
||||
"""Get the testing cache directory.
|
||||
|
||||
Args:
|
||||
cache: The cache path.
|
||||
|
||||
Returns:
|
||||
The cache path if provided, otherwise the value from the `LANGSMITH_TEST_CACHE`
|
||||
environment variable.
|
||||
"""
|
||||
if cache is not None:
|
||||
return cache
|
||||
return get_env_var("TEST_CACHE", default=None)
|
||||
|
||||
|
||||
def filter_request_headers(
|
||||
request: Any,
|
||||
*,
|
||||
ignore_hosts: Optional[Sequence[str]] = None,
|
||||
allow_hosts: Optional[Sequence[str]] = None,
|
||||
) -> Any:
|
||||
"""Filter request headers based on `ignore_hosts` and `allow_hosts`."""
|
||||
# Legacy behavior
|
||||
if ignore_hosts and any(request.url.startswith(host) for host in ignore_hosts):
|
||||
return None
|
||||
|
||||
if allow_hosts:
|
||||
try:
|
||||
parsed_url = urllib_parse.urlparse(request.url)
|
||||
except Exception:
|
||||
# If URL parsing fails, don't cache to be safe
|
||||
return None
|
||||
request_host = parsed_url.hostname or ""
|
||||
# Check if request matches any allowed host
|
||||
host_matches = any(
|
||||
# Handle both full URLs (https://api.openai.com)
|
||||
# and hostnames (api.openai.com)
|
||||
(
|
||||
request.url.startswith(host)
|
||||
if host.startswith(("http://", "https://"))
|
||||
else request_host == host or request_host.endswith(f".{host}")
|
||||
)
|
||||
for host in allow_hosts
|
||||
)
|
||||
if not host_matches:
|
||||
return None
|
||||
|
||||
request.headers = {}
|
||||
return request
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def with_cache(
|
||||
path: Union[str, pathlib.Path],
|
||||
ignore_hosts: Optional[Sequence[str]] = None,
|
||||
allow_hosts: Optional[Sequence[str]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Use a cache for requests."""
|
||||
try:
|
||||
import vcr # type: ignore[import-untyped]
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"vcrpy is required to use caching. Install with:"
|
||||
'pip install -U "langsmith[vcr]"'
|
||||
)
|
||||
# Fix concurrency issue in vcrpy's patching
|
||||
from langsmith._internal import _patch as patch_urllib3
|
||||
|
||||
patch_urllib3.patch_urllib3()
|
||||
|
||||
cache_dir, cache_file = os.path.split(path)
|
||||
|
||||
ls_vcr = vcr.VCR(
|
||||
serializer=(
|
||||
"yaml"
|
||||
if cache_file.endswith(".yaml") or cache_file.endswith(".yml")
|
||||
else "json"
|
||||
),
|
||||
cassette_library_dir=cache_dir,
|
||||
# Replay previous requests, record new ones
|
||||
# TODO: Support other modes
|
||||
record_mode="new_episodes",
|
||||
match_on=["uri", "method", "path", "body"],
|
||||
filter_headers=["authorization", "Set-Cookie"],
|
||||
before_record_request=lambda request: filter_request_headers(
|
||||
request, ignore_hosts=ignore_hosts, allow_hosts=allow_hosts
|
||||
),
|
||||
)
|
||||
with ls_vcr.use_cassette(cache_file):
|
||||
yield
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def with_optional_cache(
|
||||
path: Optional[Union[str, pathlib.Path]],
|
||||
ignore_hosts: Optional[Sequence[str]] = None,
|
||||
allow_hosts: Optional[Sequence[str]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Use a cache for requests."""
|
||||
if path is not None:
|
||||
with with_cache(path, ignore_hosts, allow_hosts):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
def _format_exc() -> str:
|
||||
# Used internally to format exceptions without cluttering the traceback
|
||||
tb_lines = traceback.format_exception(*sys.exc_info())
|
||||
filtered_lines = [line for line in tb_lines if "langsmith/" not in line]
|
||||
return "".join(filtered_lines)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _middle_copy(
|
||||
val: T, memo: dict[int, Any], max_depth: int = 4, _depth: int = 0
|
||||
) -> T:
|
||||
cls = type(val)
|
||||
|
||||
copier = getattr(cls, "__deepcopy__", None)
|
||||
if copier is not None:
|
||||
try:
|
||||
return copier(memo)
|
||||
except BaseException:
|
||||
pass
|
||||
if _depth >= max_depth:
|
||||
return val
|
||||
if isinstance(val, dict):
|
||||
return { # type: ignore[return-value]
|
||||
_middle_copy(k, memo, max_depth, _depth + 1): _middle_copy(
|
||||
v, memo, max_depth, _depth + 1
|
||||
)
|
||||
for k, v in val.items()
|
||||
}
|
||||
if isinstance(val, list):
|
||||
return [_middle_copy(item, memo, max_depth, _depth + 1) for item in val] # type: ignore[return-value]
|
||||
if isinstance(val, tuple):
|
||||
return tuple(_middle_copy(item, memo, max_depth, _depth + 1) for item in val) # type: ignore[return-value]
|
||||
if isinstance(val, set):
|
||||
return {_middle_copy(item, memo, max_depth, _depth + 1) for item in val} # type: ignore[return-value]
|
||||
|
||||
return val
|
||||
|
||||
|
||||
def deepish_copy(val: T) -> T:
|
||||
"""Deep copy a value with a compromise for uncopyable objects.
|
||||
|
||||
Args:
|
||||
val: The value to be deep copied.
|
||||
|
||||
Returns:
|
||||
The deep copied value.
|
||||
"""
|
||||
memo: dict[int, Any] = {}
|
||||
try:
|
||||
return copy.deepcopy(val, memo)
|
||||
except BaseException as e:
|
||||
# Generators, locks, etc. cannot be copied
|
||||
# and raise a TypeError (mentioning pickling, since the dunder methods)
|
||||
# are re-used for copying. We'll try to do a compromise and copy
|
||||
# what we can
|
||||
_LOGGER.debug("Failed to deepcopy input: %s", repr(e))
|
||||
return _middle_copy(val, memo)
|
||||
|
||||
|
||||
def is_version_greater_or_equal(current_version: str, target_version: str) -> bool:
|
||||
"""Check if the current version is greater or equal to the target version."""
|
||||
from packaging import version
|
||||
|
||||
current = version.parse(current_version)
|
||||
target = version.parse(target_version)
|
||||
return current >= target
|
||||
|
||||
|
||||
def parse_prompt_identifier(identifier: str) -> tuple[str, str, str]:
|
||||
"""Parse a string in the format of `owner/name:hash`, `name:hash`, `owner/name`, or `name`.
|
||||
|
||||
Args:
|
||||
identifier: The prompt identifier to parse.
|
||||
|
||||
Returns:
|
||||
A tuple containing `(owner, name, hash)`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the identifier doesn't match the expected formats.
|
||||
""" # noqa: E501
|
||||
if (
|
||||
not identifier
|
||||
or identifier.count("/") > 1
|
||||
or identifier.startswith("/")
|
||||
or identifier.endswith("/")
|
||||
):
|
||||
raise ValueError(get_invalid_prompt_identifier_msg(identifier))
|
||||
|
||||
parts = identifier.split(":", 1)
|
||||
owner_name = parts[0]
|
||||
commit = parts[1] if len(parts) > 1 else "latest"
|
||||
|
||||
if "/" in owner_name:
|
||||
owner, name = owner_name.split("/", 1)
|
||||
if not owner or not name:
|
||||
raise ValueError(get_invalid_prompt_identifier_msg(identifier))
|
||||
return owner, name, commit
|
||||
else:
|
||||
if not owner_name:
|
||||
raise ValueError(get_invalid_prompt_identifier_msg(identifier))
|
||||
return "-", owner_name, commit
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class ContextThreadPoolExecutor(ThreadPoolExecutor):
|
||||
"""ThreadPoolExecutor that copies the context to the child thread."""
|
||||
|
||||
def submit( # type: ignore[override]
|
||||
self,
|
||||
func: Callable[P, T],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Future[T]:
|
||||
"""Submit a function to the executor.
|
||||
|
||||
Args:
|
||||
func (Callable[..., T]): The function to submit.
|
||||
*args (Any): The positional arguments to the function.
|
||||
**kwargs (Any): The keyword arguments to the function.
|
||||
|
||||
Returns:
|
||||
Future[T]: The future for the function.
|
||||
"""
|
||||
return super().submit(
|
||||
cast(
|
||||
Callable[..., T],
|
||||
functools.partial(
|
||||
contextvars.copy_context().run, func, *args, **kwargs
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def map(
|
||||
self,
|
||||
fn: Callable[..., T],
|
||||
*iterables: Iterable[Any],
|
||||
timeout: Optional[float] = None,
|
||||
chunksize: int = 1,
|
||||
) -> Iterator[T]:
|
||||
"""Return an iterator equivalent to stdlib map.
|
||||
|
||||
Each function will receive its own copy of the context from the parent thread.
|
||||
|
||||
Args:
|
||||
fn: A callable that will take as many arguments as there are
|
||||
passed iterables.
|
||||
timeout: The maximum number of seconds to wait. If None, then there
|
||||
is no limit on the wait time.
|
||||
chunksize: The size of the chunks the iterable will be broken into
|
||||
before being passed to a child process. This argument is only
|
||||
used by ProcessPoolExecutor; it is ignored by
|
||||
ThreadPoolExecutor.
|
||||
|
||||
Returns:
|
||||
An iterator equivalent to: map(func, *iterables) but the calls may
|
||||
be evaluated out-of-order.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the entire result iterator could not be generated
|
||||
before the given timeout.
|
||||
Exception: If fn(*args) raises for any values.
|
||||
"""
|
||||
contexts = [contextvars.copy_context() for _ in range(len(iterables[0]))] # type: ignore[arg-type]
|
||||
|
||||
def _wrapped_fn(*args: Any) -> T:
|
||||
return contexts.pop().run(fn, *args)
|
||||
|
||||
return super().map(
|
||||
_wrapped_fn,
|
||||
*iterables,
|
||||
timeout=timeout,
|
||||
chunksize=chunksize,
|
||||
)
|
||||
|
||||
|
||||
def get_api_url(api_url: Optional[str]) -> str:
|
||||
"""Get the LangSmith API URL from the environment or the given value."""
|
||||
_api_url = api_url or cast(
|
||||
str,
|
||||
get_env_var(
|
||||
"ENDPOINT",
|
||||
default="https://api.smith.langchain.com",
|
||||
),
|
||||
)
|
||||
if not _api_url.strip():
|
||||
raise LangSmithUserError("LangSmith API URL cannot be empty")
|
||||
return _api_url.strip().strip('"').strip("'").rstrip("/")
|
||||
|
||||
|
||||
def get_api_key(api_key: Optional[str]) -> Optional[str]:
|
||||
"""Get the API key from the environment or the given value."""
|
||||
api_key_ = api_key if api_key is not None else get_env_var("API_KEY", default=None)
|
||||
if api_key_ is None or not api_key_.strip():
|
||||
return None
|
||||
return api_key_.strip().strip('"').strip("'")
|
||||
|
||||
|
||||
def get_workspace_id(workspace_id: Optional[str]) -> Optional[str]:
|
||||
"""Get workspace ID."""
|
||||
workspace_id_ = (
|
||||
workspace_id
|
||||
if workspace_id is not None
|
||||
else get_env_var("WORKSPACE_ID", default=None)
|
||||
)
|
||||
if workspace_id_ is None or not workspace_id_.strip():
|
||||
return None
|
||||
return workspace_id_.strip().strip('"').strip("'")
|
||||
|
||||
|
||||
def _is_localhost(url: str) -> bool:
|
||||
"""Check if the URL is localhost.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
url : str
|
||||
The URL to check.
|
||||
|
||||
Returns:
|
||||
-------
|
||||
bool
|
||||
True if the URL is localhost, False otherwise.
|
||||
"""
|
||||
try:
|
||||
netloc = urllib_parse.urlsplit(url).netloc.split(":")[0]
|
||||
ip = socket.gethostbyname(netloc)
|
||||
return ip == "127.0.0.1" or ip.startswith("0.0.0.0") or ip.startswith("::")
|
||||
except socket.gaierror:
|
||||
return False
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=2)
|
||||
def get_host_url(web_url: Optional[str], api_url: str):
|
||||
"""Get the host URL based on the web URL or API URL."""
|
||||
if web_url:
|
||||
return web_url
|
||||
parsed_url = urllib_parse.urlparse(api_url)
|
||||
if _is_localhost(api_url):
|
||||
link = "http://localhost"
|
||||
elif str(parsed_url.path).endswith("/api"):
|
||||
new_path = str(parsed_url.path).rsplit("/api", 1)[0]
|
||||
link = urllib_parse.urlunparse(parsed_url._replace(path=new_path))
|
||||
elif str(parsed_url.path).endswith("/api/v1"):
|
||||
new_path = str(parsed_url.path).rsplit("/api/v1", 1)[0]
|
||||
link = urllib_parse.urlunparse(parsed_url._replace(path=new_path))
|
||||
elif str(parsed_url.netloc).startswith("eu."):
|
||||
link = "https://eu.smith.langchain.com"
|
||||
elif str(parsed_url.netloc).startswith("dev."):
|
||||
link = "https://dev.smith.langchain.com"
|
||||
elif str(parsed_url.netloc).startswith("beta."):
|
||||
link = "https://beta.smith.langchain.com"
|
||||
else:
|
||||
link = "https://smith.langchain.com"
|
||||
return link
|
||||
|
||||
|
||||
def _get_function_name(fn: Callable, depth: int = 0) -> str:
|
||||
if depth > 2 or not callable(fn):
|
||||
return str(fn)
|
||||
|
||||
if hasattr(fn, "__name__"):
|
||||
return fn.__name__
|
||||
|
||||
if isinstance(fn, functools.partial):
|
||||
return _get_function_name(fn.func, depth + 1)
|
||||
|
||||
if hasattr(fn, "__call__"):
|
||||
if hasattr(fn, "__class__") and hasattr(fn.__class__, "__name__"):
|
||||
return fn.__class__.__name__
|
||||
return _get_function_name(fn.__call__, depth + 1)
|
||||
|
||||
return str(fn)
|
||||
|
||||
|
||||
def is_truish(val: Any) -> bool:
|
||||
"""Check if the value is truish.
|
||||
|
||||
Args:
|
||||
val (Any): The value to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the value is truish, False otherwise.
|
||||
"""
|
||||
if isinstance(val, str):
|
||||
return val.lower() == "true" or val == "1"
|
||||
return bool(val)
|
||||
Reference in New Issue
Block a user