initial commit
This commit is contained in:
9
venv/Lib/site-packages/langchain/agents/__init__.py
Normal file
9
venv/Lib/site-packages/langchain/agents/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Entrypoint to building [Agents](https://docs.langchain.com/oss/python/langchain/agents) with LangChain.""" # noqa: E501
|
||||
|
||||
from langchain.agents.factory import create_agent
|
||||
from langchain.agents.middleware.types import AgentState
|
||||
|
||||
__all__ = [
|
||||
"AgentState",
|
||||
"create_agent",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
1817
venv/Lib/site-packages/langchain/agents/factory.py
Normal file
1817
venv/Lib/site-packages/langchain/agents/factory.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,81 @@
|
||||
"""Entrypoint to using [middleware](https://docs.langchain.com/oss/python/langchain/middleware) plugins with [Agents](https://docs.langchain.com/oss/python/langchain/agents).""" # noqa: E501
|
||||
|
||||
from langchain.agents.middleware.context_editing import ClearToolUsesEdit, ContextEditingMiddleware
|
||||
from langchain.agents.middleware.file_search import FilesystemFileSearchMiddleware
|
||||
from langchain.agents.middleware.human_in_the_loop import (
|
||||
HumanInTheLoopMiddleware,
|
||||
InterruptOnConfig,
|
||||
)
|
||||
from langchain.agents.middleware.model_call_limit import ModelCallLimitMiddleware
|
||||
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
|
||||
from langchain.agents.middleware.model_retry import ModelRetryMiddleware
|
||||
from langchain.agents.middleware.pii import PIIDetectionError, PIIMiddleware
|
||||
from langchain.agents.middleware.shell_tool import (
|
||||
CodexSandboxExecutionPolicy,
|
||||
DockerExecutionPolicy,
|
||||
HostExecutionPolicy,
|
||||
RedactionRule,
|
||||
ShellToolMiddleware,
|
||||
)
|
||||
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
||||
from langchain.agents.middleware.todo import TodoListMiddleware
|
||||
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
|
||||
from langchain.agents.middleware.tool_emulator import LLMToolEmulator
|
||||
from langchain.agents.middleware.tool_retry import ToolRetryMiddleware
|
||||
from langchain.agents.middleware.tool_selection import LLMToolSelectorMiddleware
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ExtendedModelResponse,
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ToolCallRequest,
|
||||
after_agent,
|
||||
after_model,
|
||||
before_agent,
|
||||
before_model,
|
||||
dynamic_prompt,
|
||||
hook_config,
|
||||
wrap_model_call,
|
||||
wrap_tool_call,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentMiddleware",
|
||||
"AgentState",
|
||||
"ClearToolUsesEdit",
|
||||
"CodexSandboxExecutionPolicy",
|
||||
"ContextEditingMiddleware",
|
||||
"DockerExecutionPolicy",
|
||||
"ExtendedModelResponse",
|
||||
"FilesystemFileSearchMiddleware",
|
||||
"HostExecutionPolicy",
|
||||
"HumanInTheLoopMiddleware",
|
||||
"InterruptOnConfig",
|
||||
"LLMToolEmulator",
|
||||
"LLMToolSelectorMiddleware",
|
||||
"ModelCallLimitMiddleware",
|
||||
"ModelCallResult",
|
||||
"ModelFallbackMiddleware",
|
||||
"ModelRequest",
|
||||
"ModelResponse",
|
||||
"ModelRetryMiddleware",
|
||||
"PIIDetectionError",
|
||||
"PIIMiddleware",
|
||||
"RedactionRule",
|
||||
"ShellToolMiddleware",
|
||||
"SummarizationMiddleware",
|
||||
"TodoListMiddleware",
|
||||
"ToolCallLimitMiddleware",
|
||||
"ToolCallRequest",
|
||||
"ToolRetryMiddleware",
|
||||
"after_agent",
|
||||
"after_model",
|
||||
"before_agent",
|
||||
"before_model",
|
||||
"dynamic_prompt",
|
||||
"hook_config",
|
||||
"wrap_model_call",
|
||||
"wrap_tool_call",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
385
venv/Lib/site-packages/langchain/agents/middleware/_execution.py
Normal file
385
venv/Lib/site-packages/langchain/agents/middleware/_execution.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""Execution policies for the persistent shell middleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import typing
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
try:  # pragma: no cover - optional dependency on POSIX platforms
    import resource
except ImportError:  # pragma: no cover - non-POSIX systems
    # Without the `resource` module, rlimit-based CPU/memory limits cannot be applied;
    # `HostExecutionPolicy` checks this flag before accepting such limits.
    _HAS_RESOURCE = False
else:
    _HAS_RESOURCE = True


# Directory-name prefix checked by `DockerExecutionPolicy._should_mount_workspace`:
# a workspace whose name starts with this prefix is not bind-mounted into the container.
SHELL_TEMP_PREFIX = "langchain-shell-"
|
||||
|
||||
|
||||
def _launch_subprocess(
    command: Sequence[str],
    *,
    env: Mapping[str, str],
    cwd: Path,
    preexec_fn: typing.Callable[[], None] | None,
    start_new_session: bool,
) -> subprocess.Popen[str]:
    """Start `command` as a child process wired up for text I/O over pipes.

    stdin, stdout and stderr are all pipes opened in line-buffered text mode and
    decoded as UTF-8; undecodable bytes are replaced instead of raising.

    Args:
        command: Program and arguments; copied into a list before launching.
        env: Environment passed to the child process.
        cwd: Working directory for the child process.
        preexec_fn: Optional callable forwarded unchanged to `subprocess.Popen`.
        start_new_session: Forwarded unchanged to `subprocess.Popen`.

    Returns:
        The `subprocess.Popen` handle of the started process.
    """
    argv = list(command)
    return subprocess.Popen(  # noqa: S603
        argv,
        cwd=cwd,
        env=env,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        encoding="utf-8",
        errors="replace",
        bufsize=1,
        start_new_session=start_new_session,
        preexec_fn=preexec_fn,  # noqa: PLW1509
    )
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
class BaseExecutionPolicy(abc.ABC):
    """Abstract base for policies that launch a persistent shell process.

    A policy carries the timeout and output-size settings shared by all shell
    sessions and implements `spawn`, which decides how the shell process is started
    and constrained. The concrete policies in this module are `HostExecutionPolicy`
    (runs the shell directly on the host), `CodexSandboxExecutionPolicy` (wraps the
    command in the Codex CLI sandbox) and `DockerExecutionPolicy` (runs the shell
    inside a Docker container).
    """

    # Timeout settings; they are consumed outside this class (presumably in seconds).
    command_timeout: float = 30.0
    startup_timeout: float = 30.0
    termination_timeout: float = 10.0
    # Output-size settings; `max_output_lines` must be positive, `None` bytes means unset.
    max_output_lines: int = 100
    max_output_bytes: int | None = None

    def __post_init__(self) -> None:
        """Validate the settings after dataclass initialisation.

        Raises:
            ValueError: If `max_output_lines` is zero or negative.
        """
        if self.max_output_lines > 0:
            return
        msg = "max_output_lines must be positive."
        raise ValueError(msg)

    @abc.abstractmethod
    def spawn(
        self,
        *,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> subprocess.Popen[str]:
        """Launch the persistent shell process."""
|
||||
|
||||
|
||||
@dataclass
class HostExecutionPolicy(BaseExecutionPolicy):
    """Run the shell directly on the host process.

    This policy is best suited for trusted or single-tenant environments (CI jobs,
    developer workstations, pre-sandboxed containers) where the agent must access the
    host filesystem and tooling without additional isolation. Enforces optional CPU and
    memory limits to prevent runaway commands but offers **no** filesystem or network
    sandboxing; commands can modify anything the process user can reach.

    On Linux, when `resource.prlimit` is available, resource limits are applied with
    `prlimit` after the shell starts. On other platforms that provide the `resource`
    module, limits are set in a `preexec_fn` before `exec`. When
    `create_process_group` is true (the default) the shell is started in a new
    session so that timeouts can terminate the full subtree.
    """

    cpu_time_seconds: int | None = None
    memory_bytes: int | None = None
    create_process_group: bool = True

    # Derived in `__post_init__`: true when at least one of the limits above is set.
    _limits_requested: bool = field(init=False, repr=False, default=False)

    def __post_init__(self) -> None:
        """Validate the limits and record whether any limit was requested.

        Raises:
            ValueError: If `cpu_time_seconds` or `memory_bytes` is provided but not
                positive (or the base-class validation fails).
            RuntimeError: If a limit is requested but the `resource` module is
                unavailable.
        """
        super().__post_init__()
        if self.cpu_time_seconds is not None and self.cpu_time_seconds <= 0:
            msg = "cpu_time_seconds must be positive if provided."
            raise ValueError(msg)
        if self.memory_bytes is not None and self.memory_bytes <= 0:
            msg = "memory_bytes must be positive if provided."
            raise ValueError(msg)
        self._limits_requested = any(
            value is not None for value in (self.cpu_time_seconds, self.memory_bytes)
        )
        if self._limits_requested and not _HAS_RESOURCE:
            msg = (
                "HostExecutionPolicy cpu/memory limits require the Python 'resource' module. "
                "Either remove the limits or run on a POSIX platform."
            )
            raise RuntimeError(msg)

    def spawn(
        self,
        *,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> subprocess.Popen[str]:
        """Launch the shell on the host and apply the configured resource limits.

        Args:
            workspace: Working directory for the shell process.
            env: Environment passed to the shell process.
            command: Program and arguments to execute.

        Returns:
            The running shell process.

        Raises:
            RuntimeError: If the limits cannot be applied via `prlimit`. The
                already-started shell is killed and reaped before re-raising.
        """
        process = _launch_subprocess(
            list(command),
            env=env,
            cwd=workspace,
            preexec_fn=self._create_preexec_fn(),
            start_new_session=self.create_process_group,
        )
        try:
            self._apply_post_spawn_limits(process)
        except Exception:
            # Do not leak a shell that is running without the requested limits:
            # stop it, reap it and release its pipes before propagating the error.
            process.kill()
            process.wait()
            for stream in (process.stdin, process.stdout, process.stderr):
                if stream is not None:
                    stream.close()
            raise
        return process

    def _create_preexec_fn(self) -> typing.Callable[[], None] | None:
        """Return a `preexec_fn` that sets rlimits, or `None` when not needed.

        `None` is returned when no limit was requested or when the limits will be
        applied after spawn through `prlimit` instead.
        """
        if not self._limits_requested or self._can_use_prlimit():
            return None

        def _configure() -> None:  # pragma: no cover - depends on OS
            if self.cpu_time_seconds is not None:
                limit = (self.cpu_time_seconds, self.cpu_time_seconds)
                resource.setrlimit(resource.RLIMIT_CPU, limit)
            if self.memory_bytes is not None:
                limit = (self.memory_bytes, self.memory_bytes)
                # Prefer the address-space limit; fall back to the data-segment
                # limit where RLIMIT_AS is not defined.
                if hasattr(resource, "RLIMIT_AS"):
                    resource.setrlimit(resource.RLIMIT_AS, limit)
                elif hasattr(resource, "RLIMIT_DATA"):
                    resource.setrlimit(resource.RLIMIT_DATA, limit)

        return _configure

    def _apply_post_spawn_limits(self, process: subprocess.Popen[str]) -> None:
        """Apply the requested limits to an already-running process via `prlimit`.

        Does nothing when no limit was requested or `prlimit` cannot be used.

        Args:
            process: The process whose pid receives the limits.

        Raises:
            RuntimeError: If `prlimit` fails with an `OSError`.
        """
        if not self._limits_requested or not self._can_use_prlimit():
            return
        if not _HAS_RESOURCE:  # pragma: no cover - defensive
            return
        pid = process.pid
        try:
            # `prlimit` is looked up through an `Any` cast because it is not
            # defined on every platform's `resource` module.
            prlimit = typing.cast("typing.Any", resource).prlimit
            if self.cpu_time_seconds is not None:
                prlimit(pid, resource.RLIMIT_CPU, (self.cpu_time_seconds, self.cpu_time_seconds))
            if self.memory_bytes is not None:
                limit = (self.memory_bytes, self.memory_bytes)
                if hasattr(resource, "RLIMIT_AS"):
                    prlimit(pid, resource.RLIMIT_AS, limit)
                elif hasattr(resource, "RLIMIT_DATA"):
                    prlimit(pid, resource.RLIMIT_DATA, limit)
        except OSError as exc:  # pragma: no cover - depends on platform support
            msg = "Failed to apply resource limits via prlimit."
            raise RuntimeError(msg) from exc

    @staticmethod
    def _can_use_prlimit() -> bool:
        """Return whether limits can be applied with `resource.prlimit` (Linux only)."""
        return _HAS_RESOURCE and hasattr(resource, "prlimit") and sys.platform.startswith("linux")
|
||||
|
||||
|
||||
@dataclass
class CodexSandboxExecutionPolicy(BaseExecutionPolicy):
    """Launch the shell through the Codex CLI sandbox.

    Use this policy when the Codex CLI is installed and you want the additional
    syscall and filesystem restrictions of its Seatbelt (macOS) or Landlock/seccomp
    (Linux) sandbox profiles. Commands still run on the host, wrapped as
    `<codex> sandbox <platform> [-c key=value ...] -- <command>`. If the Codex binary
    cannot be found on `PATH`, or the platform cannot be determined, spawning fails
    with a `RuntimeError`.

    Sandbox behavior is configured through `config_overrides`, which are passed to
    the CLI as `-c key=value` options. This policy does not add its own resource
    limits; combine it with host-level guards (cgroups, container resource limits) as
    needed.
    """

    binary: str = "codex"
    platform: typing.Literal["auto", "macos", "linux"] = "auto"
    config_overrides: Mapping[str, typing.Any] = field(default_factory=dict)

    def spawn(
        self,
        *,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> subprocess.Popen[str]:
        """Start `command` wrapped in the Codex sandbox invocation.

        Args:
            workspace: Working directory for the process.
            env: Environment passed to the process.
            command: Program and arguments to run inside the sandbox.

        Returns:
            The running sandbox process.
        """
        sandboxed = self._build_command(command)
        return _launch_subprocess(
            sandboxed,
            env=env,
            cwd=workspace,
            preexec_fn=None,
            start_new_session=False,
        )

    def _build_command(self, command: Sequence[str]) -> list[str]:
        """Assemble the full Codex CLI argument list around `command`."""
        overrides = dict(self.config_overrides)
        argv: list[str] = [self._resolve_binary(), "sandbox", self._determine_platform()]
        # Overrides are emitted in sorted key order so the command line is stable.
        for key in sorted(overrides):
            argv += ["-c", f"{key}={self._format_override(overrides[key])}"]
        argv.append("--")
        argv.extend(command)
        return argv

    def _resolve_binary(self) -> str:
        """Locate the configured binary on `PATH`.

        Raises:
            RuntimeError: If the binary cannot be found.
        """
        located = shutil.which(self.binary)
        if located is not None:
            return located
        msg = "Codex sandbox policy requires the '%s' CLI to be installed and available on PATH."
        raise RuntimeError(msg % self.binary)

    def _determine_platform(self) -> str:
        """Return the sandbox platform argument, detecting it when set to `'auto'`.

        Raises:
            RuntimeError: If the platform is `'auto'` and the host is neither Linux
                nor macOS.
        """
        requested = self.platform
        if requested != "auto":
            return requested
        host = sys.platform
        if host.startswith("linux"):
            return "linux"
        if host == "darwin":
            return "macos"
        msg = (
            "Codex sandbox policy could not determine a supported platform; "
            "set 'platform' explicitly."
        )
        raise RuntimeError(msg)

    @staticmethod
    def _format_override(value: typing.Any) -> str:
        """Render an override value as JSON, falling back to `str()` if not serialisable."""
        try:
            rendered = json.dumps(value)
        except TypeError:
            rendered = str(value)
        return rendered
|
||||
|
||||
|
||||
@dataclass
class DockerExecutionPolicy(BaseExecutionPolicy):
    """Run the shell inside a dedicated Docker container.

    Choose this policy when commands originate from untrusted users or you require
    strong isolation between sessions. The workspace is bind-mounted (at the same
    path inside the container) unless its directory name starts with
    `SHELL_TEMP_PREFIX`; such sessions run without a mount and use `/` as the working
    directory. Networking is disabled by default (`--network none`) and further
    hardening is available through `read_only_rootfs` and `user`.

    The security guarantees depend on your Docker daemon configuration. Run the agent on
    a host where Docker is locked down (rootless mode, AppArmor/SELinux, etc.) and
    review any additional volumes or capabilities passed through ``extra_run_args``. The
    default image is `python:3.12-alpine3.19`; supply a custom image if you need
    preinstalled tooling.
    """

    binary: str = "docker"
    image: str = "python:3.12-alpine3.19"
    remove_container_on_exit: bool = True
    network_enabled: bool = False
    extra_run_args: Sequence[str] | None = None
    memory_bytes: int | None = None
    # Accepted only so that a misconfiguration can be reported; any non-None value
    # is rejected in `__post_init__`.
    cpu_time_seconds: typing.Any | None = None
    cpus: str | None = None
    read_only_rootfs: bool = False
    user: str | None = None

    def __post_init__(self) -> None:
        """Validate the container settings and normalise `extra_run_args` to a tuple.

        Raises:
            ValueError: If `memory_bytes` is not positive, or `cpus` / `user` is blank.
            RuntimeError: If `cpu_time_seconds` is provided (unsupported here).
        """
        super().__post_init__()
        if self.memory_bytes is not None and self.memory_bytes <= 0:
            msg = "memory_bytes must be positive if provided."
            raise ValueError(msg)
        if self.cpu_time_seconds is not None:
            msg = (
                "DockerExecutionPolicy does not support cpu_time_seconds; configure CPU limits "
                "using Docker run options such as '--cpus'."
            )
            raise RuntimeError(msg)
        if self.cpus is not None and not self.cpus.strip():
            msg = "cpus must be a non-empty string when provided."
            raise ValueError(msg)
        if self.user is not None and not self.user.strip():
            msg = "user must be a non-empty string when provided."
            raise ValueError(msg)
        self.extra_run_args = tuple(self.extra_run_args or ())

    def spawn(
        self,
        *,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> subprocess.Popen[str]:
        """Start `docker run` for `command`.

        The `env` mapping is forwarded into the container through `-e` options; the
        Docker CLI itself runs with a copy of the current process environment.

        Args:
            workspace: Working directory of the CLI process (and mount source, if mounted).
            env: Environment variables to set inside the container.
            command: Program and arguments to run inside the container.

        Returns:
            The running `docker run` process.
        """
        docker_argv = self._build_command(workspace, env, command)
        cli_env = os.environ.copy()
        return _launch_subprocess(
            docker_argv,
            env=cli_env,
            cwd=workspace,
            preexec_fn=None,
            start_new_session=False,
        )

    def _build_command(
        self,
        workspace: Path,
        env: Mapping[str, str],
        command: Sequence[str],
    ) -> list[str]:
        """Assemble the complete `docker run -i ...` argument list."""
        argv: list[str] = [self._resolve_binary(), "run", "-i"]
        if self.remove_container_on_exit:
            argv.append("--rm")
        if not self.network_enabled:
            argv += ["--network", "none"]
        if self.memory_bytes is not None:
            argv += ["--memory", str(self.memory_bytes)]
        if self._should_mount_workspace(workspace):
            # Mount the workspace at the identical path so host paths stay valid inside.
            mount_point = str(workspace)
            argv += ["-v", f"{mount_point}:{mount_point}", "-w", mount_point]
        else:
            argv += ["-w", "/"]
        if self.read_only_rootfs:
            argv.append("--read-only")
        for name, value in env.items():
            argv += ["-e", f"{name}={value}"]
        if self.cpus is not None:
            argv += ["--cpus", self.cpus]
        if self.user is not None:
            argv += ["--user", self.user]
        if self.extra_run_args:
            argv.extend(self.extra_run_args)
        argv.append(self.image)
        argv.extend(command)
        return argv

    @staticmethod
    def _should_mount_workspace(workspace: Path) -> bool:
        """Return `False` only for workspaces whose name starts with `SHELL_TEMP_PREFIX`."""
        is_temporary = workspace.name.startswith(SHELL_TEMP_PREFIX)
        return not is_temporary

    def _resolve_binary(self) -> str:
        """Locate the configured Docker binary on `PATH`.

        Raises:
            RuntimeError: If the binary cannot be found.
        """
        located = shutil.which(self.binary)
        if located is not None:
            return located
        msg = (
            "Docker execution policy requires the '%s' CLI to be installed"
            " and available on PATH."
        )
        raise RuntimeError(msg % self.binary)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseExecutionPolicy",
|
||||
"CodexSandboxExecutionPolicy",
|
||||
"DockerExecutionPolicy",
|
||||
"HostExecutionPolicy",
|
||||
]
|
||||
436
venv/Lib/site-packages/langchain/agents/middleware/_redaction.py
Normal file
436
venv/Lib/site-packages/langchain/agents/middleware/_redaction.py
Normal file
@@ -0,0 +1,436 @@
|
||||
"""Shared redaction utilities for middleware components."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import ipaddress
|
||||
import operator
|
||||
import re
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
RedactionStrategy = Literal["block", "redact", "mask", "hash"]
|
||||
"""Supported strategies for handling detected sensitive values."""
|
||||
|
||||
|
||||
class PIIMatch(TypedDict):
    """One occurrence of sensitive data located inside a piece of text."""

    # Name of the PII type that produced the match (e.g. "email", "ip").
    type: str
    # The exact matched text.
    value: str
    # Offsets of the match within the scanned content (start inclusive, end exclusive).
    start: int
    end: int
|
||||
|
||||
|
||||
class PIIDetectionError(Exception):
    """Error signalling that sensitive values were found under the `'block'` strategy."""

    def __init__(self, pii_type: str, matches: Sequence[PIIMatch]) -> None:
        """Build the error and keep the match context on the instance.

        Args:
            pii_type: Name of the sensitive type that was detected.
            matches: Every match found for that type; stored as a list copy.
        """
        super().__init__(f"Detected {len(matches)} instance(s) of {pii_type} in text content")
        self.pii_type = pii_type
        self.matches = list(matches)
|
||||
|
||||
|
||||
Detector = Callable[[str], list[PIIMatch]]
|
||||
"""Callable signature for detectors that locate sensitive values."""
|
||||
|
||||
|
||||
def detect_email(content: str) -> list[PIIMatch]:
    """Detect email addresses in content.

    Args:
        content: The text content to scan for email addresses.

    Returns:
        A list of detected email matches, in order of appearance.
    """
    # The top-level-domain class must be `[A-Za-z]`. Inside a character class `|` is a
    # literal pipe (not alternation), so `[A-Z|a-z]` let a match run across `|`
    # delimiters, e.g. "a@x.com|b@y.com" matched "a@x.com|b" and hid the second address.
    pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
    return [
        PIIMatch(
            type="email",
            value=match.group(),
            start=match.start(),
            end=match.end(),
        )
        for match in re.finditer(pattern, content)
    ]
|
||||
|
||||
|
||||
def detect_credit_card(content: str) -> list[PIIMatch]:
    """Find credit card numbers in content, keeping only Luhn-valid candidates.

    Candidates are four groups of four digits, optionally separated by a single
    space or hyphen.

    Args:
        content: The text content to scan for credit card numbers.

    Returns:
        A list of detected credit card matches.
    """
    card_pattern = r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b"
    return [
        PIIMatch(
            type="credit_card",
            value=candidate.group(),
            start=candidate.start(),
            end=candidate.end(),
        )
        for candidate in re.finditer(card_pattern, content)
        if _passes_luhn(candidate.group())
    ]
|
||||
|
||||
|
||||
def detect_ip(content: str) -> list[PIIMatch]:
    """Find IPv4 addresses in content.

    Dotted-quad candidates are located with a regex and then validated with
    `ipaddress.ip_address`; candidates that fail validation are dropped. Only the
    dotted IPv4 form is scanned for.

    Args:
        content: The text content to scan for IP addresses.

    Returns:
        A list of detected IP address matches.
    """
    dotted_quad = r"\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b"
    found: list[PIIMatch] = []

    for candidate in re.finditer(dotted_quad, content):
        text = candidate.group()
        try:
            ipaddress.ip_address(text)
        except ValueError:
            # Looks like an address but is not one (e.g. an octet above 255).
            pass
        else:
            found.append(
                PIIMatch(
                    type="ip",
                    value=text,
                    start=candidate.start(),
                    end=candidate.end(),
                )
            )

    return found
|
||||
|
||||
|
||||
def detect_mac_address(content: str) -> list[PIIMatch]:
    """Find MAC addresses in content.

    A MAC address here is six two-digit hexadecimal groups separated by `:` or `-`.

    Args:
        content: The text content to scan for MAC addresses.

    Returns:
        A list of detected MAC address matches.
    """
    mac_pattern = r"\b([0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}\b"
    found: list[PIIMatch] = []
    for hit in re.finditer(mac_pattern, content):
        found.append(
            PIIMatch(
                type="mac_address",
                value=hit.group(),
                start=hit.start(),
                end=hit.end(),
            )
        )
    return found
|
||||
|
||||
|
||||
def detect_url(content: str) -> list[PIIMatch]:
    """Detect URLs in content using regex and stdlib validation.

    Two passes are made: first URLs with an explicit `http://` / `https://` scheme,
    then scheme-less URLs. A scheme-less candidate is accepted only when it contains
    a path or starts with `www.`, and only when it does not overlap a match that was
    already recorded.

    Args:
        content: The text content to scan for URLs.

    Returns:
        A list of detected URL matches (scheme matches first, then bare matches).
    """
    matches: list[PIIMatch] = []

    # Pattern 1: URLs with scheme (http:// or https://)
    scheme_pattern = r"https?://[^\s<>\"{}|\\^`\[\]]+"

    for match in re.finditer(scheme_pattern, content):
        url = match.group()
        result = urlparse(url)
        if result.scheme in {"http", "https"} and result.netloc:
            matches.append(
                PIIMatch(
                    type="url",
                    value=url,
                    start=match.start(),
                    end=match.end(),
                )
            )

    # Pattern 2: URLs without scheme (www.example.com or example.com/path)
    # More conservative to avoid false positives
    bare_pattern = (
        r"\b(?:www\.)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?"
        r"(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)+(?:/[^\s]*)?"
    )

    for match in re.finditer(bare_pattern, content):
        start, end = match.start(), match.end()
        # Skip anything that overlaps an already-recorded match. This is a full
        # interval-overlap test, so it also catches a bare candidate that strictly
        # contains a scheme match; overlapping spans would corrupt later replacement.
        if any(start < m["end"] and m["start"] < end for m in matches):
            continue

        url = match.group()
        # Only accept if it has a path or starts with www
        # This reduces false positives like "example.com" in prose
        if "/" in url or url.startswith("www."):
            # Add scheme for validation (required for urlparse to work correctly)
            test_url = f"http://{url}"
            result = urlparse(test_url)
            if result.netloc and "." in result.netloc:
                matches.append(
                    PIIMatch(
                        type="url",
                        value=url,
                        start=start,
                        end=end,
                    )
                )

    return matches
|
||||
|
||||
|
||||
BUILTIN_DETECTORS: dict[str, Detector] = {
|
||||
"email": detect_email,
|
||||
"credit_card": detect_credit_card,
|
||||
"ip": detect_ip,
|
||||
"mac_address": detect_mac_address,
|
||||
"url": detect_url,
|
||||
}
|
||||
"""Registry of built-in detectors keyed by type name."""
|
||||
|
||||
# Accepted range for the number of digits in a card number.
_CARD_NUMBER_MIN_DIGITS = 13
_CARD_NUMBER_MAX_DIGITS = 19


def _passes_luhn(card_number: str) -> bool:
    """Check a card number against the Luhn checksum.

    Non-digit characters (such as separators) are ignored. Numbers with fewer than
    `_CARD_NUMBER_MIN_DIGITS` or more than `_CARD_NUMBER_MAX_DIGITS` digits fail.

    Args:
        card_number: Candidate card number, possibly containing separators.

    Returns:
        `True` when the digit count is in range and the checksum is valid.
    """
    digits = [int(d) for d in card_number if d.isdigit()]
    digit_count = len(digits)
    if digit_count < _CARD_NUMBER_MIN_DIGITS or digit_count > _CARD_NUMBER_MAX_DIGITS:
        return False

    total = 0
    # Walk from the rightmost digit; every second digit is doubled and, if the
    # result has two digits, reduced by 9 (equivalent to summing its digits).
    for position, digit in enumerate(reversed(digits)):
        if position % 2 == 0:
            total += digit
            continue
        doubled = digit * 2
        total += doubled - 9 if doubled > 9 else doubled  # noqa: PLR2004
    return total % 10 == 0
|
||||
|
||||
|
||||
def _apply_redact_strategy(content: str, matches: list[PIIMatch]) -> str:
    """Replace every match with a `[REDACTED_<TYPE>]` placeholder.

    Args:
        content: The original text.
        matches: Matches whose `start`/`end` offsets refer to `content`.

    Returns:
        The text with each matched span replaced.
    """
    # Work from the last match backwards so earlier offsets stay valid.
    ordered = sorted(matches, key=operator.itemgetter("start"), reverse=True)
    redacted = content
    for item in ordered:
        head = redacted[: item["start"]]
        tail = redacted[item["end"] :]
        redacted = f"{head}[REDACTED_{item['type'].upper()}]{tail}"
    return redacted
|
||||
|
||||
|
||||
# Number of trailing characters left visible when masking.
_UNMASKED_CHAR_NUMBER = 4
# Number of dot-separated parts in an IPv4 address.
_IPV4_PARTS_NUMBER = 4


def _mask_value(pii_type: str, value: str) -> str:
    """Return the masked form of a single matched value for the given PII type."""
    if pii_type == "url":
        return "[MASKED_URL]"

    if pii_type == "email":
        pieces = value.split("@")
        if len(pieces) != 2:  # noqa: PLR2004
            return "****"
        local_part, domain = pieces
        labels = domain.split(".")
        # Keep the local part and the last domain label; hide the rest of the domain.
        if len(labels) > 1:
            return f"{local_part}@****.{labels[-1]}"
        return f"{local_part}@****"

    if pii_type == "credit_card":
        visible = "".join(c for c in value if c.isdigit())[-_UNMASKED_CHAR_NUMBER:]
        # Preserve the separator style of the original number, hyphen taking priority.
        if "-" in value:
            joiner = "-"
        elif " " in value:
            joiner = " "
        else:
            return f"************{visible}"
        return joiner.join(["****", "****", "****", visible])

    if pii_type == "ip":
        octets = value.split(".")
        if len(octets) == _IPV4_PARTS_NUMBER:
            return f"*.*.*.{octets[-1]}"
        return "****"

    if pii_type == "mac_address":
        joiner = ":" if ":" in value else "-"
        return joiner.join(["**"] * 5 + [value[-2:]])

    # Unknown types: show only the last few characters of sufficiently long values.
    if len(value) > _UNMASKED_CHAR_NUMBER:
        return f"****{value[-_UNMASKED_CHAR_NUMBER:]}"
    return "****"


def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str:
    """Replace every match with a partially masked version of its value.

    Args:
        content: The original text.
        matches: Matches whose `start`/`end` offsets refer to `content`.

    Returns:
        The text with each matched span replaced by its masked form.
    """
    masked_text = content
    # Work from the last match backwards so earlier offsets stay valid.
    for item in sorted(matches, key=operator.itemgetter("start"), reverse=True):
        replacement = _mask_value(item["type"], item["value"])
        masked_text = masked_text[: item["start"]] + replacement + masked_text[item["end"] :]
    return masked_text
|
||||
|
||||
|
||||
def _apply_hash_strategy(content: str, matches: list[PIIMatch]) -> str:
    """Replace every match with a `<type_hash:xxxxxxxx>` token.

    The token embeds the first eight hex characters of the SHA-256 digest of the
    matched value.

    Args:
        content: The original text.
        matches: Matches whose `start`/`end` offsets refer to `content`.

    Returns:
        The text with each matched span replaced by its hash token.
    """
    hashed_text = content
    # Work from the last match backwards so earlier offsets stay valid.
    for item in sorted(matches, key=operator.itemgetter("start"), reverse=True):
        short_digest = hashlib.sha256(item["value"].encode()).hexdigest()[:8]
        token = f"<{item['type']}_hash:{short_digest}>"
        hashed_text = hashed_text[: item["start"]] + token + hashed_text[item["end"] :]
    return hashed_text
|
||||
|
||||
|
||||
def apply_strategy(
    content: str,
    matches: list[PIIMatch],
    strategy: RedactionStrategy,
) -> str:
    """Rewrite content according to the chosen redaction strategy.

    Args:
        content: The content to apply strategy to.
        matches: List of detected PII matches.
        strategy: The redaction strategy to apply.

    Returns:
        The content with the strategy applied; the content unchanged when there are
        no matches.

    Raises:
        PIIDetectionError: If the strategy is `'block'` and matches are found.
        ValueError: If the strategy is unknown.
    """
    if not matches:
        return content

    if strategy == "redact":
        transform = _apply_redact_strategy
    elif strategy == "mask":
        transform = _apply_mask_strategy
    elif strategy == "hash":
        transform = _apply_hash_strategy
    elif strategy == "block":
        # The error is tagged with the type of the first match.
        raise PIIDetectionError(matches[0]["type"], matches)
    else:
        msg = f"Unknown redaction strategy: {strategy}"  # type: ignore[unreachable]
        raise ValueError(msg)
    return transform(content, matches)
|
||||
|
||||
|
||||
def resolve_detector(pii_type: str, detector: Detector | str | None) -> Detector:
    """Turn a detector configuration into a callable detector.

    Args:
        pii_type: The PII type name.
        detector: A custom detector callable, a regex pattern string, or `None` to
            use the built-in detector registered for `pii_type`.

    Returns:
        The resolved detector: the callable itself, a regex-backed detector built
        from the pattern, or the built-in detector.

    Raises:
        ValueError: If an unknown PII type is specified without a custom detector or regex.
    """
    if isinstance(detector, str):
        # Compile once; the closure below reuses the compiled pattern on every call.
        compiled = re.compile(detector)

        def regex_detector(content: str) -> list[PIIMatch]:
            found: list[PIIMatch] = []
            for hit in compiled.finditer(content):
                found.append(
                    PIIMatch(
                        type=pii_type,
                        value=hit.group(),
                        start=hit.start(),
                        end=hit.end(),
                    )
                )
            return found

        return regex_detector

    if detector is not None:
        return detector

    if pii_type in BUILTIN_DETECTORS:
        return BUILTIN_DETECTORS[pii_type]
    msg = (
        f"Unknown PII type: {pii_type}. "
        f"Must be one of {list(BUILTIN_DETECTORS.keys())} or provide a custom detector."
    )
    raise ValueError(msg)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RedactionRule:
    """Declarative configuration describing how one PII type is handled."""

    pii_type: str
    strategy: RedactionStrategy = "redact"
    # Custom detector callable, regex pattern string, or None for the built-in detector.
    detector: Detector | str | None = None

    def resolve(self) -> ResolvedRedactionRule:
        """Build the executable form of this rule.

        Returns:
            A `ResolvedRedactionRule` with the same type and strategy and a callable
            detector obtained from `resolve_detector`.
        """
        return ResolvedRedactionRule(
            pii_type=self.pii_type,
            strategy=self.strategy,
            detector=resolve_detector(self.pii_type, self.detector),
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ResolvedRedactionRule:
    """A redaction rule whose detector is already a callable, ready to run."""

    pii_type: str
    strategy: RedactionStrategy
    detector: Detector

    def apply(self, content: str) -> tuple[str, list[PIIMatch]]:
        """Run the detector on content and apply the rule's strategy to the matches.

        Args:
            content: The text content to scan and redact.

        Returns:
            A tuple of (updated content, list of detected matches). When nothing is
            detected the content is returned unchanged with an empty list.
        """
        found = self.detector(content)
        if found:
            return apply_strategy(content, found, self.strategy), found
        return content, []
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PIIDetectionError",
|
||||
"PIIMatch",
|
||||
"RedactionRule",
|
||||
"ResolvedRedactionRule",
|
||||
"apply_strategy",
|
||||
"detect_credit_card",
|
||||
"detect_email",
|
||||
"detect_ip",
|
||||
"detect_mac_address",
|
||||
"detect_url",
|
||||
]
|
||||
123
venv/Lib/site-packages/langchain/agents/middleware/_retry.py
Normal file
123
venv/Lib/site-packages/langchain/agents/middleware/_retry.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Shared retry utilities for agent middleware.
|
||||
|
||||
This module contains common constants, utilities, and logic used by both
|
||||
model and tool retry middleware implementations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from collections.abc import Callable
|
||||
from typing import Literal
|
||||
|
||||
# Type aliases
|
||||
RetryOn = tuple[type[Exception], ...] | Callable[[Exception], bool]
|
||||
"""Type for specifying which exceptions to retry on.
|
||||
|
||||
Can be either:
|
||||
- A tuple of exception types to retry on (based on `isinstance` checks)
|
||||
- A callable that takes an exception and returns `True` if it should be retried
|
||||
"""
|
||||
|
||||
OnFailure = Literal["error", "continue"] | Callable[[Exception], str]
|
||||
"""Type for specifying failure handling behavior.
|
||||
|
||||
Can be either:
|
||||
- A literal action string (`'error'` or `'continue'`)
|
||||
- `'error'`: Re-raise the exception, stopping agent execution.
|
||||
- `'continue'`: Inject a message with the error details, allowing the agent to continue.
|
||||
For tool retries, a `ToolMessage` with the error details will be injected.
|
||||
For model retries, an `AIMessage` with the error details will be returned.
|
||||
- A callable that takes an exception and returns a string for error message content
|
||||
"""
|
||||
|
||||
|
||||
def validate_retry_params(
    max_retries: int,
    initial_delay: float,
    max_delay: float,
    backoff_factor: float,
) -> None:
    """Ensure that none of the retry settings is negative.

    The parameters are checked in declaration order and the first offending
    one is reported.

    Args:
        max_retries: Maximum number of retry attempts.
        initial_delay: Initial delay in seconds before first retry.
        max_delay: Maximum delay in seconds between retries.
        backoff_factor: Multiplier for exponential backoff.

    Raises:
        ValueError: If any parameter is negative.
    """
    checks = (
        (max_retries, "max_retries must be >= 0"),
        (initial_delay, "initial_delay must be >= 0"),
        (max_delay, "max_delay must be >= 0"),
        (backoff_factor, "backoff_factor must be >= 0"),
    )
    for value, msg in checks:
        if value < 0:
            raise ValueError(msg)
|
||||
|
||||
|
||||
def should_retry_exception(
    exc: Exception,
    retry_on: RetryOn,
) -> bool:
    """Check if an exception should trigger a retry.

    Args:
        exc: The exception that occurred.
        retry_on: Either a tuple of exception types to retry on, or a callable
            that takes an exception and returns `True` if it should be retried.

            A single exception class is also accepted and treated like a
            one-element tuple.

    Returns:
        `True` if the exception should be retried, `False` otherwise.
    """
    # Exception classes are themselves callable. Without this guard a bare class
    # (e.g. `retry_on=ValueError`) would be invoked as a predicate and return a
    # truthy exception instance, i.e. "always retry".
    if isinstance(retry_on, type) and issubclass(retry_on, BaseException):
        return isinstance(exc, retry_on)
    if callable(retry_on):
        # Coerce so callers always receive a real bool, as annotated.
        return bool(retry_on(exc))
    return isinstance(exc, retry_on)
|
||||
|
||||
|
||||
def calculate_delay(
    retry_number: int,
    *,
    backoff_factor: float,
    initial_delay: float,
    max_delay: float,
    jitter: bool,
) -> float:
    """Calculate delay for a retry attempt with exponential backoff and optional jitter.

    Args:
        retry_number: The retry attempt number (0-indexed).
        backoff_factor: Multiplier for exponential backoff.

            Set to `0.0` for constant delay.
        initial_delay: Initial delay in seconds before first retry.
        max_delay: Maximum delay in seconds between retries.

            Caps exponential backoff growth.
        jitter: Whether to add random jitter to delay to avoid thundering herd.

    Returns:
        Delay in seconds before next retry.
    """
    if backoff_factor == 0.0:
        delay = initial_delay
    else:
        try:
            delay = initial_delay * (backoff_factor**retry_number)
        except OverflowError:
            # The exponential term no longer fits in a float. The result is
            # capped below anyway, so saturate instead of crashing the retry loop.
            delay = max_delay if initial_delay > 0 else initial_delay

    # Cap at max_delay
    delay = min(delay, max_delay)

    if jitter and delay > 0:
        jitter_amount = delay * 0.25  # ±25% jitter
        delay += random.uniform(-jitter_amount, jitter_amount)  # noqa: S311
        # Ensure delay is not negative after jitter
        delay = max(0, delay)

    return delay
|
||||
@@ -0,0 +1,298 @@
|
||||
"""Context editing middleware.
|
||||
|
||||
Mirrors Anthropic's context editing capabilities by clearing older tool results once the
|
||||
conversation grows beyond a configurable token threshold.
|
||||
|
||||
The implementation is intentionally model-agnostic so it can be used with any LangChain
|
||||
chat model.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Iterable, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.utils import count_tokens_approximately
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
|
||||
DEFAULT_TOOL_PLACEHOLDER = "[cleared]"
|
||||
|
||||
|
||||
TokenCounter = Callable[
|
||||
[Sequence[BaseMessage]],
|
||||
int,
|
||||
]
|
||||
|
||||
|
||||
class ContextEdit(Protocol):
    """Structural interface that every context editing strategy implements."""

    def apply(
        self,
        messages: list[AnyMessage],
        *,
        count_tokens: TokenCounter,
    ) -> None:
        """Mutate `messages` in place according to the strategy.

        Args:
            messages: The working message list; implementations edit it directly.
            count_tokens: Callable returning the token count of a message sequence.
        """
        ...
|
||||
|
||||
|
||||
@dataclass(slots=True)
class ClearToolUsesEdit(ContextEdit):
    """Configuration for clearing tool outputs when token limits are exceeded."""

    trigger: int = 100_000
    """Token count that triggers the edit."""

    clear_at_least: int = 0
    """Minimum number of tokens to reclaim when the edit runs."""

    keep: int = 3
    """Number of most recent tool results that must be preserved."""

    clear_tool_inputs: bool = False
    """Whether to clear the originating tool call parameters on the AI message."""

    exclude_tools: Sequence[str] = ()
    """List of tool names to exclude from clearing."""

    placeholder: str = DEFAULT_TOOL_PLACEHOLDER
    """Placeholder text inserted for cleared tool outputs."""

    def apply(
        self,
        messages: list[AnyMessage],
        *,
        count_tokens: TokenCounter,
    ) -> None:
        """Apply the clear-tool-uses strategy.

        Does nothing unless the token count of `messages` exceeds `trigger`.
        Otherwise tool results are cleared oldest-first, skipping the `keep`
        most recent ones, results that are already cleared, results without a
        matching tool call on the preceding AI message, and excluded tools.
        When `clear_at_least` is positive, clearing stops as soon as that many
        tokens have been reclaimed.

        Args:
            messages: Message list that is edited in place.
            count_tokens: Callable returning the token count of a message sequence.
        """
        tokens = count_tokens(messages)

        if tokens <= self.trigger:
            return

        candidates = [
            (idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
        ]

        # Protect the `keep` most recent tool results.
        if self.keep >= len(candidates):
            candidates = []
        elif self.keep:
            candidates = candidates[: -self.keep]

        excluded_tools = set(self.exclude_tools)

        for idx, tool_message in candidates:
            if tool_message.response_metadata.get("context_editing", {}).get("cleared"):
                continue

            # Track the *position* of the nearest preceding AI message. Looking it
            # up again later with `list.index` is an equality search and could hit
            # an earlier message that merely compares equal.
            ai_index = next(
                (pos for pos in range(idx - 1, -1, -1) if isinstance(messages[pos], AIMessage)),
                None,
            )

            if ai_index is None:
                continue

            ai_message = messages[ai_index]

            tool_call = next(
                (
                    call
                    for call in ai_message.tool_calls
                    if call.get("id") == tool_message.tool_call_id
                ),
                None,
            )

            if tool_call is None:
                continue

            if (tool_message.name or tool_call["name"]) in excluded_tools:
                continue

            messages[idx] = tool_message.model_copy(
                update={
                    "artifact": None,
                    "content": self.placeholder,
                    "response_metadata": {
                        **tool_message.response_metadata,
                        "context_editing": {
                            "cleared": True,
                            "strategy": "clear_tool_uses",
                        },
                    },
                }
            )

            if self.clear_tool_inputs:
                messages[ai_index] = self._build_cleared_tool_input_message(
                    ai_message,
                    tool_message.tool_call_id,
                )

            if self.clear_at_least > 0:
                new_token_count = count_tokens(messages)
                cleared_tokens = max(0, tokens - new_token_count)
                if cleared_tokens >= self.clear_at_least:
                    break

    @staticmethod
    def _build_cleared_tool_input_message(
        message: AIMessage,
        tool_call_id: str,
    ) -> AIMessage:
        """Return a copy of `message` with the args of one tool call emptied.

        The cleared tool call id is recorded (sorted, de-duplicated) under
        `response_metadata["context_editing"]["cleared_tool_inputs"]`.

        Args:
            message: AI message that issued the tool call.
            tool_call_id: Id of the tool call whose args should be cleared.

        Returns:
            The updated copy; the input message itself is not modified.
        """
        updated_tool_calls = []
        cleared_any = False
        for tool_call in message.tool_calls:
            updated_call = dict(tool_call)
            if updated_call.get("id") == tool_call_id:
                updated_call["args"] = {}
                cleared_any = True
            updated_tool_calls.append(updated_call)

        metadata = dict(getattr(message, "response_metadata", {}))
        context_entry = dict(metadata.get("context_editing", {}))
        if cleared_any:
            cleared_ids = set(context_entry.get("cleared_tool_inputs", []))
            cleared_ids.add(tool_call_id)
            context_entry["cleared_tool_inputs"] = sorted(cleared_ids)
            metadata["context_editing"] = context_entry

        return message.model_copy(
            update={
                "tool_calls": updated_tool_calls,
                "response_metadata": metadata,
            }
        )
|
||||
|
||||
|
||||
class ContextEditingMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
    """Automatically prune tool results to manage context size.

    The middleware applies a sequence of edits when the total input token count exceeds
    configured thresholds.

    Currently the `ClearToolUsesEdit` strategy is supported, aligning with Anthropic's
    `clear_tool_uses_20250919` behavior [(read more)](https://platform.claude.com/docs/en/agents-and-tools/tool-use/memory-tool).
    """

    edits: list[ContextEdit]
    token_count_method: Literal["approximate", "model"]

    def __init__(
        self,
        *,
        edits: Iterable[ContextEdit] | None = None,
        token_count_method: Literal["approximate", "model"] = "approximate",  # noqa: S107
    ) -> None:
        """Initialize an instance of context editing middleware.

        Args:
            edits: Sequence of edit strategies to apply.

                Defaults to a single `ClearToolUsesEdit` mirroring Anthropic defaults.
            token_count_method: Whether to use approximate token counting
                (faster, less accurate) or exact counting implemented by the
                chat model (potentially slower, more accurate).
        """
        super().__init__()
        self.edits = list(edits or (ClearToolUsesEdit(),))
        self.token_count_method = token_count_method

    def _build_token_counter(self, request: ModelRequest[ContextT]) -> TokenCounter:
        """Return the token counting callable selected by `token_count_method`."""
        if self.token_count_method == "approximate":  # noqa: S105

            def count_tokens(messages: Sequence[BaseMessage]) -> int:
                return count_tokens_approximately(messages)

            return count_tokens

        # Model-based counting also accounts for the system message and tools.
        system_msg = [request.system_message] if request.system_message else []

        def count_tokens(messages: Sequence[BaseMessage]) -> int:
            return request.model.get_num_tokens_from_messages(
                system_msg + list(messages), request.tools
            )

        return count_tokens

    def _apply_edits(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
        """Return a request whose messages are an edited deep copy of the originals."""
        count_tokens = self._build_token_counter(request)

        # Edits mutate in place, so work on a copy and leave the request's messages intact.
        edited_messages = deepcopy(list(request.messages))
        for edit in self.edits:
            edit.apply(edited_messages, count_tokens=count_tokens)

        return request.override(messages=edited_messages)

    def wrap_model_call(
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
    ) -> ModelResponse[ResponseT] | AIMessage:
        """Apply context edits before invoking the model via handler.

        Args:
            request: Model request to execute (includes state and runtime).
            handler: Callback that executes the model request and returns
                `ModelResponse`.

        Returns:
            The result of invoking the handler with potentially edited messages.
        """
        if not request.messages:
            return handler(request)

        return handler(self._apply_edits(request))

    async def awrap_model_call(
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
    ) -> ModelResponse[ResponseT] | AIMessage:
        """Apply context edits before invoking the model via handler.

        Args:
            request: Model request to execute (includes state and runtime).
            handler: Async callback that executes the model request and returns
                `ModelResponse`.

        Returns:
            The result of invoking the handler with potentially edited messages.
        """
        if not request.messages:
            return await handler(request)

        return await handler(self._apply_edits(request))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ClearToolUsesEdit",
|
||||
"ContextEditingMiddleware",
|
||||
]
|
||||
@@ -0,0 +1,387 @@
|
||||
"""File search middleware for Anthropic text editor and memory tools.
|
||||
|
||||
This module provides Glob and Grep search tools that operate on files stored
|
||||
in state or filesystem.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import fnmatch
|
||||
import json
|
||||
import re
|
||||
import subprocess
|
||||
from contextlib import suppress
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, ResponseT
|
||||
|
||||
|
||||
def _expand_include_patterns(pattern: str) -> list[str] | None:
|
||||
"""Expand brace patterns like `*.{py,pyi}` into a list of globs."""
|
||||
if "}" in pattern and "{" not in pattern:
|
||||
return None
|
||||
|
||||
expanded: list[str] = []
|
||||
|
||||
def _expand(current: str) -> None:
|
||||
start = current.find("{")
|
||||
if start == -1:
|
||||
expanded.append(current)
|
||||
return
|
||||
|
||||
end = current.find("}", start)
|
||||
if end == -1:
|
||||
raise ValueError
|
||||
|
||||
prefix = current[:start]
|
||||
suffix = current[end + 1 :]
|
||||
inner = current[start + 1 : end]
|
||||
if not inner:
|
||||
raise ValueError
|
||||
|
||||
for option in inner.split(","):
|
||||
_expand(prefix + option + suffix)
|
||||
|
||||
try:
|
||||
_expand(pattern)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return expanded
|
||||
|
||||
|
||||
def _is_valid_include_pattern(pattern: str) -> bool:
|
||||
"""Validate glob pattern used for include filters."""
|
||||
if not pattern:
|
||||
return False
|
||||
|
||||
if any(char in pattern for char in ("\x00", "\n", "\r")):
|
||||
return False
|
||||
|
||||
expanded = _expand_include_patterns(pattern)
|
||||
if expanded is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
for candidate in expanded:
|
||||
re.compile(fnmatch.translate(candidate))
|
||||
except re.error:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _match_include_pattern(basename: str, pattern: str) -> bool:
    """Check a file basename against an include pattern (braces supported).

    Returns `False` when the pattern cannot be expanded or no expanded glob
    matches.
    """
    globs = _expand_include_patterns(pattern)
    if not globs:
        return False

    for candidate in globs:
        if fnmatch.fnmatch(basename, candidate):
            return True
    return False
|
||||
|
||||
|
||||
class FilesystemFileSearchMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
    """Provides Glob and Grep search over filesystem files.

    This middleware adds two tools that search through local filesystem:

    - Glob: Fast file pattern matching by file path
    - Grep: Fast content search using ripgrep or Python fallback

    Example:
        ```python
        from langchain.agents import create_agent
        from langchain.agents.middleware import (
            FilesystemFileSearchMiddleware,
        )

        agent = create_agent(
            model=model,
            tools=[],  # Add tools as needed
            middleware=[
                FilesystemFileSearchMiddleware(root_path="/workspace"),
            ],
        )
        ```
    """

    def __init__(
        self,
        *,
        root_path: str,
        use_ripgrep: bool = True,
        max_file_size_mb: int = 10,
    ) -> None:
        """Initialize the search middleware.

        Args:
            root_path: Root directory to search.
            use_ripgrep: Whether to use `ripgrep` for search.

                Falls back to Python if `ripgrep` unavailable.
            max_file_size_mb: Maximum file size to search in MB.
        """
        super().__init__()
        self.root_path = Path(root_path).resolve()
        self.use_ripgrep = use_ripgrep
        self.max_file_size_bytes = max_file_size_mb * 1024 * 1024

        # Create tool instances as closures that capture self
        @tool
        def glob_search(pattern: str, path: str = "/") -> str:
            """Fast file pattern matching tool that works with any codebase size.

            Supports glob patterns like `**/*.js` or `src/**/*.ts`.

            Returns matching file paths sorted by modification time.

            Use this tool when you need to find files by name patterns.

            Args:
                pattern: The glob pattern to match files against.
                path: The directory to search in. If not specified, searches from root.

            Returns:
                Newline-separated list of matching file paths, sorted by modification
                time (most recently modified first). Returns `'No files found'` if no
                matches.
            """
            try:
                base_full = self._validate_and_resolve_path(path)
            except ValueError:
                return "No files found"

            if not base_full.exists() or not base_full.is_dir():
                return "No files found"

            # The pattern is model-supplied: apply the same traversal policy as for
            # `path` so a glob cannot walk out of the root directory.
            pattern_path = Path(pattern)
            if pattern_path.is_absolute() or ".." in pattern_path.parts:
                return "No files found"

            try:
                found = [match for match in base_full.glob(pattern) if match.is_file()]
            except (ValueError, NotImplementedError):
                # pathlib rejects empty and non-relative patterns.
                return "No files found"

            matching: list[tuple[datetime, str]] = []
            for match in found:
                try:
                    stat = match.stat()
                except OSError:
                    # File vanished or became unreadable between glob and stat.
                    continue
                modified_at = datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc)
                # Convert to virtual path
                virtual_path = "/" + str(match.relative_to(self.root_path))
                matching.append((modified_at, virtual_path))

            if not matching:
                return "No files found"

            # Most recently modified first, as promised by the tool description.
            matching.sort(key=lambda item: item[0], reverse=True)
            return "\n".join(virtual_path for _, virtual_path in matching)

        @tool
        def grep_search(
            pattern: str,
            path: str = "/",
            include: str | None = None,
            output_mode: Literal["files_with_matches", "content", "count"] = "files_with_matches",
        ) -> str:
            """Fast content search tool that works with any codebase size.

            Searches file contents using regular expressions. Supports full regex
            syntax and filters files by pattern with the include parameter.

            Args:
                pattern: The regular expression pattern to search for in file contents.
                path: The directory to search in. If not specified, searches from root.
                include: File pattern to filter (e.g., `'*.js'`, `'*.{ts,tsx}'`).
                output_mode: Output format:

                    - `'files_with_matches'`: Only file paths containing matches
                    - `'content'`: Matching lines with `file:line:content` format
                    - `'count'`: Count of matches per file

            Returns:
                Search results formatted according to `output_mode`.
                Returns `'No matches found'` if no results.
            """
            # Compile regex pattern (for validation)
            try:
                re.compile(pattern)
            except re.error as e:
                return f"Invalid regex pattern: {e}"

            if include and not _is_valid_include_pattern(include):
                return "Invalid include pattern"

            # Try ripgrep first if enabled
            results = None
            if self.use_ripgrep:
                with suppress(
                    FileNotFoundError,
                    subprocess.CalledProcessError,
                    subprocess.TimeoutExpired,
                ):
                    results = self._ripgrep_search(pattern, path, include)

            # Python fallback if ripgrep failed or is disabled
            if results is None:
                results = self._python_search(pattern, path, include)

            if not results:
                return "No matches found"

            # Format output based on mode
            return self._format_grep_results(results, output_mode)

        self.glob_search = glob_search
        self.grep_search = grep_search
        self.tools = [glob_search, grep_search]

    def _validate_and_resolve_path(self, path: str) -> Path:
        """Validate and resolve a virtual path to filesystem path.

        Args:
            path: Virtual path, interpreted relative to `root_path`.

        Returns:
            The resolved filesystem path, guaranteed to lie within `root_path`.

        Raises:
            ValueError: If the path contains `..` or `~`, or resolves outside
                the root directory.
        """
        # Normalize path
        if not path.startswith("/"):
            path = "/" + path

        # Check for path traversal
        if ".." in path or "~" in path:
            msg = "Path traversal not allowed"
            raise ValueError(msg)

        # Convert virtual path to filesystem path
        relative = path.lstrip("/")
        full_path = (self.root_path / relative).resolve()

        # Ensure path is within root (resolve() may have followed a symlink out of it)
        try:
            full_path.relative_to(self.root_path)
        except ValueError:
            msg = f"Path outside root directory: {path}"
            raise ValueError(msg) from None

        return full_path

    def _ripgrep_search(
        self, pattern: str, base_path: str, include: str | None
    ) -> dict[str, list[tuple[int, str]]]:
        """Search using ripgrep subprocess.

        Args:
            pattern: Regular expression to search for.
            base_path: Virtual directory to search in.
            include: Optional glob restricting which files are searched.

        Returns:
            Mapping of virtual file path to `(line_number, line_text)` matches.
            Empty when the base path is invalid or does not exist.
        """
        try:
            base_full = self._validate_and_resolve_path(base_path)
        except ValueError:
            return {}

        if not base_full.exists():
            return {}

        # Build ripgrep command
        cmd = ["rg", "--json"]

        if include:
            # Convert glob pattern to ripgrep glob
            cmd.extend(["--glob", include])

        # `--` ends option parsing so a pattern starting with `-` is not read as a flag.
        cmd.extend(["--", pattern, str(base_full)])

        try:
            result = subprocess.run(  # noqa: S603
                cmd,
                capture_output=True,
                text=True,
                timeout=30,
                check=False,
            )
        except (subprocess.TimeoutExpired, FileNotFoundError):
            # Fallback to Python search if ripgrep unavailable or times out
            return self._python_search(pattern, base_path, include)

        # Parse ripgrep JSON output
        results: dict[str, list[tuple[int, str]]] = {}
        for line in result.stdout.splitlines():
            try:
                data = json.loads(line)
                if data["type"] == "match":
                    path = data["data"]["path"]["text"]
                    # Convert to virtual path
                    virtual_path = "/" + str(Path(path).relative_to(self.root_path))
                    line_num = data["data"]["line_number"]
                    line_text = data["data"]["lines"]["text"].rstrip("\n")

                    results.setdefault(virtual_path, []).append((line_num, line_text))
            except (ValueError, KeyError):
                # ValueError covers malformed JSON (JSONDecodeError) as well as a
                # reported path that is not under the root directory.
                continue

        return results

    def _python_search(
        self, pattern: str, base_path: str, include: str | None
    ) -> dict[str, list[tuple[int, str]]]:
        """Search using Python regex (fallback).

        Args:
            pattern: Regular expression to search for.
            base_path: Virtual directory to search in.
            include: Optional glob matched against each file's basename.

        Returns:
            Mapping of virtual file path to `(line_number, line_text)` matches.
            Files that are too large, unreadable or not decodable are skipped.
        """
        try:
            base_full = self._validate_and_resolve_path(base_path)
        except ValueError:
            return {}

        if not base_full.exists():
            return {}

        regex = re.compile(pattern)
        results: dict[str, list[tuple[int, str]]] = {}

        # Walk directory tree
        for file_path in base_full.rglob("*"):
            if not file_path.is_file():
                continue

            # Check include filter
            if include and not _match_include_pattern(file_path.name, include):
                continue

            try:
                # Skip files that are too large
                if file_path.stat().st_size > self.max_file_size_bytes:
                    continue
                content = file_path.read_text()
            except (UnicodeDecodeError, OSError):
                # One unreadable file (permissions, vanished, special file) must
                # not abort the whole search.
                continue

            # Search content
            hits = [
                (line_num, line)
                for line_num, line in enumerate(content.splitlines(), 1)
                if regex.search(line)
            ]
            if hits:
                virtual_path = "/" + str(file_path.relative_to(self.root_path))
                results.setdefault(virtual_path, []).extend(hits)

        return results

    @staticmethod
    def _format_grep_results(
        results: dict[str, list[tuple[int, str]]],
        output_mode: str,
    ) -> str:
        """Format grep results based on output mode.

        Args:
            results: Mapping of file path to `(line_number, line_text)` matches.
            output_mode: `'files_with_matches'`, `'content'` or `'count'`; any
                other value is treated as `'files_with_matches'`.

        Returns:
            Newline-joined output, with files in sorted order.
        """
        ordered_paths = sorted(results)

        if output_mode == "content":
            # Return file:line:content format
            return "\n".join(
                f"{file_path}:{line_num}:{line}"
                for file_path in ordered_paths
                for line_num, line in results[file_path]
            )

        if output_mode == "count":
            # Return file:count format
            return "\n".join(
                f"{file_path}:{len(results[file_path])}" for file_path in ordered_paths
            )

        # 'files_with_matches' and the default: just return file paths
        return "\n".join(ordered_paths)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FilesystemFileSearchMiddleware",
|
||||
]
|
||||
@@ -0,0 +1,387 @@
|
||||
"""Human in the loop middleware."""
|
||||
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import interrupt
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ResponseT,
|
||||
StateT,
|
||||
)
|
||||
|
||||
|
||||
class Action(TypedDict):
    """A named action together with the arguments it should be run with."""

    name: str
    """Type or name of the requested action, e.g. `'add_numbers'`."""

    args: dict[str, Any]
    """Arguments for the action as key-value pairs, e.g. `{"a": 1, "b": 2}`."""
|
||||
|
||||
|
||||
class ActionRequest(TypedDict):
|
||||
"""Represents an action request with a name, args, and description."""
|
||||
|
||||
name: str
|
||||
"""The name of the action being requested."""
|
||||
|
||||
args: dict[str, Any]
|
||||
"""Key-value pairs of args needed for the action (e.g., `{"a": 1, "b": 2}`)."""
|
||||
|
||||
description: NotRequired[str]
|
||||
"""The description of the action to be reviewed."""
|
||||
|
||||
|
||||
DecisionType = Literal["approve", "edit", "reject"]
|
||||
|
||||
|
||||
class ReviewConfig(TypedDict):
    """Review policy attached to a human-in-the-loop request."""

    action_name: str
    """Name of the action this review configuration applies to."""

    allowed_decisions: list[DecisionType]
    """Decisions the reviewer may choose for this request."""

    args_schema: NotRequired[dict[str, Any]]
    """JSON schema of the action's args, for when edits are allowed."""
|
||||
|
||||
|
||||
class HITLRequest(TypedDict):
    """Payload asking a human to review a sequence of model-requested actions."""

    action_requests: list[ActionRequest]
    """Agent actions submitted for human review."""

    review_configs: list[ReviewConfig]
    """Review configuration for all possible actions."""
|
||||
|
||||
|
||||
class ApproveDecision(TypedDict):
    """Decision returned when the reviewer approves the action."""

    type: Literal["approve"]
    """Discriminator for this decision; always `'approve'`."""
|
||||
|
||||
|
||||
class EditDecision(TypedDict):
    """Decision returned when the reviewer edits the action."""

    type: Literal["edit"]
    """Discriminator for this decision; always `'edit'`."""

    edited_action: Action
    """The edited action the agent should perform.

    Ex: for a tool call, a human reviewer can edit the tool name and args.
    """
|
||||
|
||||
|
||||
class RejectDecision(TypedDict):
|
||||
"""Response when a human rejects the action."""
|
||||
|
||||
type: Literal["reject"]
|
||||
"""The type of response when a human rejects the action."""
|
||||
|
||||
message: NotRequired[str]
|
||||
"""The message sent to the model explaining why the action was rejected."""
|
||||
|
||||
|
||||
Decision = ApproveDecision | EditDecision | RejectDecision
|
||||
|
||||
|
||||
class HITLResponse(TypedDict):
    """Payload sent back in answer to a `HITLRequest`."""

    decisions: list[Decision]
    """Decisions made by the human reviewer."""
|
||||
|
||||
|
||||
class _DescriptionFactory(Protocol):
    """Signature of a callable that builds the description for a tool call."""

    def __call__(
        self,
        tool_call: ToolCall,
        state: AgentState[Any],
        runtime: Runtime[ContextT],
    ) -> str:
        """Return the description text for `tool_call`.

        Args:
            tool_call: The tool call awaiting review.
            state: Agent state passed to the factory.
            runtime: Runtime passed to the factory.
        """
        ...
|
||||
|
||||
|
||||
class InterruptOnConfig(TypedDict):
    """Configuration for an action requiring human in the loop.

    This is the configuration format used in the `HumanInTheLoopMiddleware.__init__`
    method.
    """

    allowed_decisions: list[DecisionType]
    """The decisions that are allowed for this action."""

    description: NotRequired[str | _DescriptionFactory]
    """The description attached to the request for human input.

    Can be either:

    - A static string describing the approval request
    - A callable that dynamically generates the description based on agent state,
        runtime, and tool call information

    Example:
        ```python
        # Static string description
        config = InterruptOnConfig(
            allowed_decisions=["approve", "reject"],
            description="Please review this tool execution"
        )

        # Dynamic callable description
        def format_tool_description(
            tool_call: ToolCall,
            state: AgentState,
            runtime: Runtime[ContextT]
        ) -> str:
            import json
            return (
                f"Tool: {tool_call['name']}\\n"
                f"Arguments:\\n{json.dumps(tool_call['args'], indent=2)}"
            )

        config = InterruptOnConfig(
            allowed_decisions=["approve", "edit", "reject"],
            description=format_tool_description
        )
        ```
    """

    args_schema: NotRequired[dict[str, Any]]
    """JSON schema for the args associated with the action, if edits are allowed."""
|
||||
|
||||
|
||||
class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
    """Middleware that interrupts the agent so a human can review tool calls."""

    def __init__(
        self,
        interrupt_on: dict[str, bool | InterruptOnConfig],
        *,
        description_prefix: str = "Tool execution requires approval",
    ) -> None:
        """Set up which tools require review and how requests are described.

        Args:
            interrupt_on: Mapping from tool name to its review policy.

                Tools that are absent from the mapping are auto-approved.

                * `True` allows every decision: approve, edit, and reject.
                * `False` auto-approves the tool.
                * An `InterruptOnConfig` lists the decisions allowed for the
                    tool; a config whose `allowed_decisions` is empty or
                    missing is ignored.

                An `InterruptOnConfig` may carry a `description` (`str` or
                `Callable`) to customise the interrupt description.
            description_prefix: Leading text of the default description that
                is built for an action request.

                Ignored for tools whose `InterruptOnConfig` supplies its own
                `description`.
        """
        super().__init__()
        normalized: dict[str, InterruptOnConfig] = {}
        for name, setting in interrupt_on.items():
            if isinstance(setting, bool):
                # Only `True` produces a config; `False` means auto-approve.
                if setting:
                    normalized[name] = InterruptOnConfig(
                        allowed_decisions=["approve", "edit", "reject"]
                    )
                continue
            if setting.get("allowed_decisions"):
                normalized[name] = setting
        self.interrupt_on = normalized
        self.description_prefix = description_prefix

    def _create_action_and_config(
        self,
        tool_call: ToolCall,
        config: InterruptOnConfig,
        state: AgentState[Any],
        runtime: Runtime[ContextT],
    ) -> tuple[ActionRequest, ReviewConfig]:
        """Build the `ActionRequest` / `ReviewConfig` pair for one tool call."""
        name = tool_call["name"]
        args = tool_call["args"]

        # Resolve the description: default text, factory result, or static string.
        custom = config.get("description")
        if custom is None:
            description = f"{self.description_prefix}\n\nTool: {name}\nArgs: {args}"
        elif callable(custom):
            description = custom(tool_call, state, runtime)
        else:
            description = custom

        # `args_schema` is not populated here yet; it could eventually be
        # derived from the tool's own information.
        return (
            ActionRequest(name=name, args=args, description=description),
            ReviewConfig(action_name=name, allowed_decisions=config["allowed_decisions"]),
        )

    @staticmethod
    def _process_decision(
        decision: Decision,
        tool_call: ToolCall,
        config: InterruptOnConfig,
    ) -> tuple[ToolCall | None, ToolMessage | None]:
        """Apply one human decision to a tool call.

        Returns:
            The tool call to keep on the AI message and, for a rejection, the
            error `ToolMessage` that answers it (otherwise `None`).

        Raises:
            ValueError: If the decision type is not allowed for this tool.
        """
        permitted = config["allowed_decisions"]
        kind = decision["type"]

        if kind == "approve" and "approve" in permitted:
            return tool_call, None

        if kind == "edit" and "edit" in permitted:
            edited = decision["edited_action"]
            # The edited call keeps the id of the call it replaces.
            replacement = ToolCall(
                type="tool_call",
                name=edited["name"],
                args=edited["args"],
                id=tool_call["id"],
            )
            return replacement, None

        if kind == "reject" and "reject" in permitted:
            # Use the human's message when given, else a generic rejection text.
            content = decision.get("message")
            if not content:
                content = (
                    f"User rejected the tool call for `{tool_call['name']}` with id {tool_call['id']}"
                )
            rejection = ToolMessage(
                content=content,
                name=tool_call["name"],
                tool_call_id=tool_call["id"],
                status="error",
            )
            return tool_call, rejection

        msg = (
            f"Unexpected human decision: {decision}. "
            f"Decision type '{decision.get('type')}' "
            f"is not allowed for tool '{tool_call['name']}'. "
            f"Expected one of {permitted} based on the tool's configuration."
        )
        raise ValueError(msg)

    def after_model(
        self, state: AgentState[Any], runtime: Runtime[ContextT]
    ) -> dict[str, Any] | None:
        """Trigger interrupt flows for relevant tool calls after an `AIMessage`.

        Args:
            state: The current agent state.
            runtime: The runtime context.

        Returns:
            The last AI message with its revised tool calls, followed by any
            rejection tool messages, or `None` when nothing needs review.

        Raises:
            ValueError: If the number of human decisions does not match the number of
                interrupted tool calls, or a decision type is not allowed for its tool.
        """
        messages = state["messages"]
        if not messages:
            return None

        # Find the most recent AI message.
        last_ai_msg = None
        for candidate in reversed(messages):
            if isinstance(candidate, AIMessage):
                last_ai_msg = candidate
                break
        if not last_ai_msg or not last_ai_msg.tool_calls:
            return None

        # First pass: collect a request and review config per configured tool call.
        flagged_positions: list[int] = []
        requests: list[ActionRequest] = []
        reviews: list[ReviewConfig] = []
        for position, call in enumerate(last_ai_msg.tool_calls):
            tool_cfg = self.interrupt_on.get(call["name"])
            if tool_cfg is None:
                continue
            request, review = self._create_action_and_config(call, tool_cfg, state, runtime)
            requests.append(request)
            reviews.append(review)
            flagged_positions.append(position)

        if not requests:
            return None

        # A single interrupt carries every pending action.
        response = interrupt(HITLRequest(action_requests=requests, review_configs=reviews))
        decisions = response["decisions"]

        decisions_len = len(decisions)
        interrupt_count = len(flagged_positions)
        if decisions_len != interrupt_count:
            msg = (
                f"Number of human decisions ({decisions_len}) does not match "
                f"number of hanging tool calls ({interrupt_count})."
            )
            raise ValueError(msg)

        # Second pass: rebuild the tool calls in their original order.
        decision_at = dict(zip(flagged_positions, decisions))
        kept_calls: list[ToolCall] = []
        synthetic_messages: list[ToolMessage] = []
        for position, call in enumerate(last_ai_msg.tool_calls):
            if position not in decision_at:
                # Not configured for review: auto-approved, keep unchanged.
                kept_calls.append(call)
                continue
            revised, note = self._process_decision(
                decision_at[position], call, self.interrupt_on[call["name"]]
            )
            if revised is not None:
                kept_calls.append(revised)
            if note:
                synthetic_messages.append(note)

        # Rejected calls stay on the message, answered by their error ToolMessage.
        last_ai_msg.tool_calls = kept_calls
        return {"messages": [last_ai_msg, *synthetic_messages]}

    async def aafter_model(
        self, state: AgentState[Any], runtime: Runtime[ContextT]
    ) -> dict[str, Any] | None:
        """Async variant of `after_model`; delegates to the synchronous hook.

        Args:
            state: The current agent state.
            runtime: The runtime context.

        Returns:
            Updated message with the revised tool calls.
        """
        return self.after_model(state, runtime)
|
||||
@@ -0,0 +1,267 @@
|
||||
"""Call tracking middleware for agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.channels.untracked_value import UntrackedValue
|
||||
from typing_extensions import NotRequired, override
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
PrivateStateAttr,
|
||||
ResponseT,
|
||||
hook_config,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
|
||||
class ModelCallLimitState(AgentState[ResponseT]):
    """Agent state used by `ModelCallLimitMiddleware`.

    Adds two model-call counters on top of `AgentState`.

    Type Parameters:
        ResponseT: The type of the structured response. Defaults to `Any`.
    """

    # Counter for the whole thread.
    thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
    # Counter for the current run; additionally marked `UntrackedValue`.
    run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
|
||||
|
||||
|
||||
def _build_limit_exceeded_message(
|
||||
thread_count: int,
|
||||
run_count: int,
|
||||
thread_limit: int | None,
|
||||
run_limit: int | None,
|
||||
) -> str:
|
||||
"""Build a message indicating which limits were exceeded.
|
||||
|
||||
Args:
|
||||
thread_count: Current thread model call count.
|
||||
run_count: Current run model call count.
|
||||
thread_limit: Thread model call limit (if set).
|
||||
run_limit: Run model call limit (if set).
|
||||
|
||||
Returns:
|
||||
A formatted message describing which limits were exceeded.
|
||||
"""
|
||||
exceeded_limits = []
|
||||
if thread_limit is not None and thread_count >= thread_limit:
|
||||
exceeded_limits.append(f"thread limit ({thread_count}/{thread_limit})")
|
||||
if run_limit is not None and run_count >= run_limit:
|
||||
exceeded_limits.append(f"run limit ({run_count}/{run_limit})")
|
||||
|
||||
return f"Model call limits exceeded: {', '.join(exceeded_limits)}"
|
||||
|
||||
|
||||
class ModelCallLimitExceededError(Exception):
    """Error signalling that a model call limit has been exceeded.

    Raised when the configured exit behavior is `'error'` and the thread or run
    model call limit has been exceeded.
    """

    def __init__(
        self,
        thread_count: int,
        run_count: int,
        thread_limit: int | None,
        run_limit: int | None,
    ) -> None:
        """Record the counts and limits and build the error message.

        Args:
            thread_count: Current thread model call count.
            run_count: Current run model call count.
            thread_limit: Thread model call limit (if set).
            run_limit: Run model call limit (if set).
        """
        # The exception text is the same summary used for the `'end'` behavior.
        super().__init__(
            _build_limit_exceeded_message(thread_count, run_count, thread_limit, run_limit)
        )
        self.thread_count = thread_count
        self.run_count = run_count
        self.thread_limit = thread_limit
        self.run_limit = run_limit
|
||||
|
||||
|
||||
class ModelCallLimitMiddleware(
    AgentMiddleware[ModelCallLimitState[ResponseT], ContextT, ResponseT]
):
    """Tracks model call counts and enforces limits.

    This middleware monitors the number of model calls made during agent execution
    and can terminate the agent when specified limits are reached. It supports
    both thread-level and run-level call counting with configurable exit behaviors.

    Thread-level: The middleware tracks the number of model calls and persists
    call count across multiple runs (invocations) of the agent.

    Run-level: The middleware tracks the number of model calls made during a single
    run (invocation) of the agent.

    Example:
        ```python
        from langchain_core.messages import HumanMessage

        from langchain.agents import create_agent
        from langchain.agents.middleware.model_call_limit import ModelCallLimitMiddleware

        # Create middleware with limits
        call_tracker = ModelCallLimitMiddleware(thread_limit=10, run_limit=5, exit_behavior="end")

        agent = create_agent("openai:gpt-4o", middleware=[call_tracker])

        # Agent will automatically jump to end when limits are exceeded
        result = agent.invoke({"messages": [HumanMessage("Help me with a task")]})
        ```
    """

    state_schema = ModelCallLimitState  # type: ignore[assignment]

    def __init__(
        self,
        *,
        thread_limit: int | None = None,
        run_limit: int | None = None,
        exit_behavior: Literal["end", "error"] = "end",
    ) -> None:
        """Initialize the call tracking middleware.

        Args:
            thread_limit: Maximum number of model calls allowed per thread.

                `None` means no limit.
            run_limit: Maximum number of model calls allowed per run.

                `None` means no limit.
            exit_behavior: What to do when limits are exceeded.

                - `'end'`: Jump to the end of the agent execution and
                    inject an artificial AI message indicating that the limit was
                    exceeded.
                - `'error'`: Raise a `ModelCallLimitExceededError`

        Raises:
            ValueError: If both limits are `None` or if `exit_behavior` is invalid.
        """
        super().__init__()

        if thread_limit is None and run_limit is None:
            msg = "At least one limit must be specified (thread_limit or run_limit)"
            raise ValueError(msg)

        if exit_behavior not in {"end", "error"}:
            msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
            raise ValueError(msg)

        self.thread_limit = thread_limit
        self.run_limit = run_limit
        self.exit_behavior = exit_behavior

    @hook_config(can_jump_to=["end"])
    @override
    def before_model(
        self, state: ModelCallLimitState[ResponseT], runtime: Runtime[ContextT]
    ) -> dict[str, Any] | None:
        """Check model call limits before making a model call.

        Args:
            state: The current agent state containing call counts.
            runtime: The langgraph runtime.

        Returns:
            If limits are exceeded and exit_behavior is `'end'`, a state update
            with `jump_to` set to `'end'` and an `AIMessage` describing the
            exceeded limits. Otherwise `None`.

        Raises:
            ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
                is `'error'`.
        """
        thread_count = state.get("thread_model_call_count", 0)
        run_count = state.get("run_model_call_count", 0)

        # A limit counts as exceeded once the recorded count has reached it,
        # i.e. the call about to be made would go over the limit.
        thread_limit_exceeded = self.thread_limit is not None and thread_count >= self.thread_limit
        run_limit_exceeded = self.run_limit is not None and run_count >= self.run_limit

        if thread_limit_exceeded or run_limit_exceeded:
            if self.exit_behavior == "error":
                raise ModelCallLimitExceededError(
                    thread_count=thread_count,
                    run_count=run_count,
                    thread_limit=self.thread_limit,
                    run_limit=self.run_limit,
                )
            if self.exit_behavior == "end":
                # Create a message indicating the limit was exceeded
                limit_message = _build_limit_exceeded_message(
                    thread_count, run_count, self.thread_limit, self.run_limit
                )
                limit_ai_message = AIMessage(content=limit_message)

                return {"jump_to": "end", "messages": [limit_ai_message]}

        return None

    @hook_config(can_jump_to=["end"])
    async def abefore_model(
        self,
        state: ModelCallLimitState[ResponseT],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Async check model call limits before making a model call.

        Delegates to the synchronous `before_model`.

        Args:
            state: The current agent state containing call counts.
            runtime: The langgraph runtime.

        Returns:
            If limits are exceeded and exit_behavior is `'end'`, a state update
            with `jump_to` set to `'end'` and an `AIMessage` describing the
            exceeded limits. Otherwise `None`.

        Raises:
            ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
                is `'error'`.
        """
        return self.before_model(state, runtime)

    @override
    def after_model(
        self, state: ModelCallLimitState[ResponseT], runtime: Runtime[ContextT]
    ) -> dict[str, Any] | None:
        """Increment model call counts after a model call.

        Args:
            state: The current agent state.
            runtime: The langgraph runtime.

        Returns:
            State updates with incremented call counts.
        """
        return {
            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
        }

    async def aafter_model(
        self,
        state: ModelCallLimitState[ResponseT],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Async increment model call counts after a model call.

        Delegates to the synchronous `after_model`.

        Args:
            state: The current agent state.
            runtime: The langgraph runtime.

        Returns:
            State updates with incremented call counts.
        """
        return self.after_model(state, runtime)
|
||||
@@ -0,0 +1,138 @@
|
||||
"""Model fallback middleware for agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
class ModelFallbackMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
    """Automatic fallback to alternative models on errors.

    Retries failed model calls with alternative models in sequence until
    success or all models exhausted. Primary model specified in `create_agent`.

    Example:
        ```python
        from langchain_core.messages import HumanMessage

        from langchain.agents import create_agent
        from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware

        fallback = ModelFallbackMiddleware(
            "openai:gpt-4o-mini",  # Try first on error
            "anthropic:claude-sonnet-4-5-20250929",  # Then this
        )

        agent = create_agent(
            model="openai:gpt-4o",  # Primary model
            middleware=[fallback],
        )

        # If primary fails: tries gpt-4o-mini, then claude-sonnet-4-5-20250929
        result = agent.invoke({"messages": [HumanMessage("Hello")]})
        ```
    """

    def __init__(
        self,
        first_model: str | BaseChatModel,
        *additional_models: str | BaseChatModel,
    ) -> None:
        """Initialize model fallback middleware.

        Args:
            first_model: First fallback model (string name or instance).
            *additional_models: Additional fallbacks in order.
        """
        super().__init__()

        # String names are resolved to model instances up front; instances are
        # stored as given. Order is the order in which fallbacks are tried.
        self.models: list[BaseChatModel] = [
            init_chat_model(model) if isinstance(model, str) else model
            for model in (first_model, *additional_models)
        ]

    def wrap_model_call(
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
    ) -> ModelResponse[ResponseT] | AIMessage:
        """Try fallback models in sequence on errors.

        Args:
            request: Initial model request.
            handler: Callback to execute the model.

        Returns:
            The result of the first `handler` call that succeeds.

        Raises:
            Exception: If all models fail, re-raises last exception.
        """
        # Try primary model first
        last_exception: Exception
        try:
            return handler(request)
        except Exception as e:
            last_exception = e

        # Try fallback models, remembering only the most recent failure
        for fallback_model in self.models:
            try:
                return handler(request.override(model=fallback_model))
            except Exception as e:
                last_exception = e

        raise last_exception

    async def awrap_model_call(
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
    ) -> ModelResponse[ResponseT] | AIMessage:
        """Try fallback models in sequence on errors (async version).

        Args:
            request: Initial model request.
            handler: Async callback to execute the model.

        Returns:
            The result of the first `handler` call that succeeds.

        Raises:
            Exception: If all models fail, re-raises last exception.
        """
        # Try primary model first
        last_exception: Exception
        try:
            return await handler(request)
        except Exception as e:
            last_exception = e

        # Try fallback models, remembering only the most recent failure
        for fallback_model in self.models:
            try:
                return await handler(request.override(model=fallback_model))
            except Exception as e:
                last_exception = e

        raise last_exception
|
||||
@@ -0,0 +1,312 @@
|
||||
"""Model retry middleware for agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from langchain.agents.middleware._retry import (
|
||||
OnFailure,
|
||||
RetryOn,
|
||||
calculate_delay,
|
||||
should_retry_exception,
|
||||
validate_retry_params,
|
||||
)
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
|
||||
class ModelRetryMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Middleware that automatically retries failed model calls with configurable backoff.
|
||||
|
||||
Supports retrying on specific exceptions and exponential backoff.
|
||||
|
||||
Examples:
|
||||
!!! example "Basic usage with default settings (2 retries, exponential backoff)"
|
||||
|
||||
```python
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import ModelRetryMiddleware
|
||||
|
||||
agent = create_agent(model, tools=[search_tool], middleware=[ModelRetryMiddleware()])
|
||||
```
|
||||
|
||||
!!! example "Retry specific exceptions only"
|
||||
|
||||
```python
|
||||
from anthropic import RateLimitError
|
||||
from openai import APITimeoutError
|
||||
|
||||
retry = ModelRetryMiddleware(
|
||||
max_retries=4,
|
||||
retry_on=(APITimeoutError, RateLimitError),
|
||||
backoff_factor=1.5,
|
||||
)
|
||||
```
|
||||
|
||||
!!! example "Custom exception filtering"
|
||||
|
||||
```python
|
||||
from anthropic import APIStatusError
|
||||
|
||||
|
||||
def should_retry(exc: Exception) -> bool:
|
||||
# Only retry on 5xx errors
|
||||
if isinstance(exc, APIStatusError):
|
||||
return 500 <= exc.status_code < 600
|
||||
return False
|
||||
|
||||
|
||||
retry = ModelRetryMiddleware(
|
||||
max_retries=3,
|
||||
retry_on=should_retry,
|
||||
)
|
||||
```
|
||||
|
||||
!!! example "Custom error handling"
|
||||
|
||||
```python
|
||||
def format_error(exc: Exception) -> str:
|
||||
return "Model temporarily unavailable. Please try again later."
|
||||
|
||||
|
||||
retry = ModelRetryMiddleware(
|
||||
max_retries=4,
|
||||
on_failure=format_error,
|
||||
)
|
||||
```
|
||||
|
||||
!!! example "Constant backoff (no exponential growth)"
|
||||
|
||||
```python
|
||||
retry = ModelRetryMiddleware(
|
||||
max_retries=5,
|
||||
backoff_factor=0.0, # No exponential growth
|
||||
initial_delay=2.0, # Always wait 2 seconds
|
||||
)
|
||||
```
|
||||
|
||||
!!! example "Raise exception on failure"
|
||||
|
||||
```python
|
||||
retry = ModelRetryMiddleware(
|
||||
max_retries=2,
|
||||
on_failure="error", # Re-raise exception instead of returning message
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
max_retries: int = 2,
|
||||
retry_on: RetryOn = (Exception,),
|
||||
on_failure: OnFailure = "continue",
|
||||
backoff_factor: float = 2.0,
|
||||
initial_delay: float = 1.0,
|
||||
max_delay: float = 60.0,
|
||||
jitter: bool = True,
|
||||
) -> None:
|
||||
"""Initialize `ModelRetryMiddleware`.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of retry attempts after the initial call.
|
||||
|
||||
Must be `>= 0`.
|
||||
retry_on: Either a tuple of exception types to retry on, or a callable
|
||||
that takes an exception and returns `True` if it should be retried.
|
||||
|
||||
Default is to retry on all exceptions.
|
||||
on_failure: Behavior when all retries are exhausted.
|
||||
|
||||
Options:
|
||||
|
||||
- `'continue'`: Return an `AIMessage` with error details,
|
||||
allowing the agent to continue with an error response.
|
||||
- `'error'`: Re-raise the exception, stopping agent execution.
|
||||
- **Custom callable:** Function that takes the exception and returns a
|
||||
string for the `AIMessage` content, allowing custom error
|
||||
formatting.
|
||||
backoff_factor: Multiplier for exponential backoff.
|
||||
|
||||
Each retry waits `initial_delay * (backoff_factor ** retry_number)`
|
||||
seconds.
|
||||
|
||||
Set to `0.0` for constant delay.
|
||||
initial_delay: Initial delay in seconds before first retry.
|
||||
max_delay: Maximum delay in seconds between retries.
|
||||
|
||||
Caps exponential backoff growth.
|
||||
jitter: Whether to add random jitter (`±25%`) to delay to avoid thundering herd.
|
||||
|
||||
Raises:
|
||||
ValueError: If `max_retries < 0` or delays are negative.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Validate parameters
|
||||
validate_retry_params(max_retries, initial_delay, max_delay, backoff_factor)
|
||||
|
||||
self.max_retries = max_retries
|
||||
self.tools = [] # No additional tools registered by this middleware
|
||||
self.retry_on = retry_on
|
||||
self.on_failure = on_failure
|
||||
self.backoff_factor = backoff_factor
|
||||
self.initial_delay = initial_delay
|
||||
self.max_delay = max_delay
|
||||
self.jitter = jitter
|
||||
|
||||
@staticmethod
|
||||
def _format_failure_message(exc: Exception, attempts_made: int) -> AIMessage:
|
||||
"""Format the failure message when retries are exhausted.
|
||||
|
||||
Args:
|
||||
exc: The exception that caused the failure.
|
||||
attempts_made: Number of attempts actually made.
|
||||
|
||||
Returns:
|
||||
`AIMessage` with formatted error message.
|
||||
"""
|
||||
exc_type = type(exc).__name__
|
||||
exc_msg = str(exc)
|
||||
attempt_word = "attempt" if attempts_made == 1 else "attempts"
|
||||
content = (
|
||||
f"Model call failed after {attempts_made} {attempt_word} with {exc_type}: {exc_msg}"
|
||||
)
|
||||
return AIMessage(content=content)
|
||||
|
||||
def _handle_failure(self, exc: Exception, attempts_made: int) -> ModelResponse[ResponseT]:
|
||||
"""Handle failure when all retries are exhausted.
|
||||
|
||||
Args:
|
||||
exc: The exception that caused the failure.
|
||||
attempts_made: Number of attempts actually made.
|
||||
|
||||
Returns:
|
||||
`ModelResponse` with error details.
|
||||
|
||||
Raises:
|
||||
Exception: If `on_failure` is `'error'`, re-raises the exception.
|
||||
"""
|
||||
if self.on_failure == "error":
|
||||
raise exc
|
||||
|
||||
if callable(self.on_failure):
|
||||
content = self.on_failure(exc)
|
||||
ai_msg = AIMessage(content=content)
|
||||
else:
|
||||
ai_msg = self._format_failure_message(exc, attempts_made)
|
||||
|
||||
return ModelResponse(result=[ai_msg])
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
"""Intercept model execution and retry on failure.
|
||||
|
||||
Args:
|
||||
request: Model request with model, messages, state, and runtime.
|
||||
handler: Callable to execute the model (can be called multiple times).
|
||||
|
||||
Returns:
|
||||
`ModelResponse` or `AIMessage` (the final result).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the retry loop completes without returning. (This should not happen.)
|
||||
"""
|
||||
# Initial attempt + retries
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception as exc:
|
||||
attempts_made = attempt + 1 # attempt is 0-indexed
|
||||
|
||||
# Check if we should retry this exception
|
||||
if not should_retry_exception(exc, self.retry_on):
|
||||
# Exception is not retryable, handle failure immediately
|
||||
return self._handle_failure(exc, attempts_made)
|
||||
|
||||
# Check if we have more retries left
|
||||
if attempt < self.max_retries:
|
||||
# Calculate and apply backoff delay
|
||||
delay = calculate_delay(
|
||||
attempt,
|
||||
backoff_factor=self.backoff_factor,
|
||||
initial_delay=self.initial_delay,
|
||||
max_delay=self.max_delay,
|
||||
jitter=self.jitter,
|
||||
)
|
||||
if delay > 0:
|
||||
time.sleep(delay)
|
||||
# Continue to next retry
|
||||
else:
|
||||
# No more retries, handle failure
|
||||
return self._handle_failure(exc, attempts_made)
|
||||
|
||||
# Unreachable: loop always returns via handler success or _handle_failure
|
||||
msg = "Unexpected: retry loop completed without returning"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
async def awrap_model_call(
    self,
    request: ModelRequest[ContextT],
    handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> ModelResponse[ResponseT] | AIMessage:
    """Await the model through `handler`, retrying when it raises.

    The handler is awaited at most `max_retries + 1` times. Between attempts the
    coroutine awaits `asyncio.sleep` for the delay computed by `calculate_delay`.

    Args:
        request: Model request with model, messages, state, and runtime.
        handler: Async callable that executes the model and returns `ModelResponse`.

    Returns:
        The handler's result on success, otherwise whatever `_handle_failure`
        returns for the last exception.

    Raises:
        RuntimeError: If the retry loop ends without returning. (This should not happen.)
    """
    for attempt_index in range(self.max_retries + 1):
        try:
            return await handler(request)
        except Exception as error:
            # Stop right away for non-retryable errors, or once the retry budget
            # is exhausted; `attempt_index` is 0-based, hence the `+ 1`.
            if (
                not should_retry_exception(error, self.retry_on)
                or attempt_index >= self.max_retries
            ):
                return self._handle_failure(error, attempt_index + 1)

            # Another attempt is allowed: back off without blocking the event loop.
            pause = calculate_delay(
                attempt_index,
                backoff_factor=self.backoff_factor,
                initial_delay=self.initial_delay,
                max_delay=self.max_delay,
                jitter=self.jitter,
            )
            if pause > 0:
                await asyncio.sleep(pause)

    # Unreachable: every iteration either returns or continues to the next attempt.
    msg = "Unexpected: retry loop completed without returning"
    raise RuntimeError(msg)
|
||||
376
venv/Lib/site-packages/langchain/agents/middleware/pii.py
Normal file
376
venv/Lib/site-packages/langchain/agents/middleware/pii.py
Normal file
@@ -0,0 +1,376 @@
|
||||
"""PII detection and handling middleware for agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.agents.middleware._redaction import (
|
||||
PIIDetectionError,
|
||||
PIIMatch,
|
||||
RedactionRule,
|
||||
ResolvedRedactionRule,
|
||||
apply_strategy,
|
||||
detect_credit_card,
|
||||
detect_email,
|
||||
detect_ip,
|
||||
detect_mac_address,
|
||||
detect_url,
|
||||
)
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ResponseT,
|
||||
hook_config,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
|
||||
class PIIMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
    """Detect and handle Personally Identifiable Information (PII) in conversations.

    Each instance watches for a single PII type and applies a single strategy to
    every match. Add several instances to an agent to cover several PII types.
    Detection can run on user input, on agent output, and on tool results.

    Built-in PII types:

    - `email`: Email addresses
    - `credit_card`: Credit card numbers (validated with Luhn algorithm)
    - `ip`: IP addresses (validated with stdlib)
    - `mac_address`: MAC addresses
    - `url`: URLs (both `http`/`https` and bare URLs)

    Strategies:

    - `block`: Raise an exception when PII is detected
    - `redact`: Replace PII with `[REDACTED_TYPE]` placeholders
    - `mask`: Partially mask PII (e.g., `****-****-****-1234` for credit card)
    - `hash`: Replace PII with deterministic hash (e.g., `<email_hash:a1b2c3d4>`)

    Strategy Selection Guide:

    | Strategy | Preserves Identity? | Best For                                |
    | -------- | ------------------- | --------------------------------------- |
    | `block`  | N/A                 | Avoid PII completely                    |
    | `redact` | No                  | General compliance, log sanitization    |
    | `mask`   | No                  | Human readability, customer service UIs |
    | `hash`   | Yes (pseudonymous)  | Analytics, debugging                    |

    Example:
        ```python
        from langchain.agents.middleware import PIIMiddleware
        from langchain.agents import create_agent

        # Redact all emails in user input
        agent = create_agent(
            "openai:gpt-5",
            middleware=[
                PIIMiddleware("email", strategy="redact"),
            ],
        )

        # Use different strategies for different PII types
        agent = create_agent(
            "openai:gpt-4o",
            middleware=[
                PIIMiddleware("credit_card", strategy="mask"),
                PIIMiddleware("url", strategy="redact"),
                PIIMiddleware("ip", strategy="hash"),
            ],
        )

        # Custom PII type with regex
        agent = create_agent(
            "openai:gpt-5",
            middleware=[
                PIIMiddleware("api_key", detector=r"sk-[a-zA-Z0-9]{32}", strategy="block"),
            ],
        )
        ```
    """

    def __init__(
        self,
        # From a typing point of view, the literals are covered by 'str'.
        # Nonetheless, we escape PYI051 to keep hints and autocompletion for the caller.
        pii_type: Literal["email", "credit_card", "ip", "mac_address", "url"] | str,  # noqa: PYI051
        *,
        strategy: Literal["block", "redact", "mask", "hash"] = "redact",
        detector: Callable[[str], list[PIIMatch]] | str | None = None,
        apply_to_input: bool = True,
        apply_to_output: bool = False,
        apply_to_tool_results: bool = False,
    ) -> None:
        """Initialize the PII detection middleware.

        Args:
            pii_type: Type of PII to detect.

                Either one of the built-in types (`email`, `credit_card`, `ip`,
                `mac_address`, `url`) or the name of a custom type.
            strategy: How to handle detected PII.

                Options:

                * `block`: Raise `PIIDetectionError` when PII is detected
                * `redact`: Replace with `[REDACTED_TYPE]` placeholders
                * `mask`: Partially mask PII (show last few characters)
                * `hash`: Replace with deterministic hash (format: `<type_hash:digest>`)

            detector: Custom detector function or regex pattern.

                * If `Callable`: Function that takes content string and returns
                    list of `PIIMatch` objects
                * If `str`: Regex pattern to match PII
                * If `None`: Uses built-in detector for the `pii_type`
            apply_to_input: Whether to check user messages before model call.
            apply_to_output: Whether to check AI messages after model call.
            apply_to_tool_results: Whether to check tool result messages after tool execution.

        Raises:
            ValueError: If `pii_type` is not built-in and no detector is provided.
        """
        super().__init__()

        self.apply_to_input = apply_to_input
        self.apply_to_output = apply_to_output
        self.apply_to_tool_results = apply_to_tool_results

        # Resolve the rule once; the public attributes mirror the resolved values.
        rule = RedactionRule(pii_type=pii_type, strategy=strategy, detector=detector)
        self._resolved_rule: ResolvedRedactionRule = rule.resolve()
        self.pii_type = self._resolved_rule.pii_type
        self.strategy = self._resolved_rule.strategy
        self.detector = self._resolved_rule.detector

    @property
    def name(self) -> str:
        """Name of the middleware, made of the class name and the PII type."""
        class_name = self.__class__.__name__
        return f"{class_name}[{self.pii_type}]"

    def _process_content(self, content: str) -> tuple[str, list[PIIMatch]]:
        """Run the detector on `content` and apply the strategy to any matches.

        Returns:
            The (possibly rewritten) text and the list of matches; the text is
            returned untouched with an empty list when nothing was detected.
        """
        found = self.detector(content)
        if found:
            return apply_strategy(content, found, self.strategy), found
        return content, []

    @hook_config(can_jump_to=["end"])
    @override
    def before_model(
        self,
        state: AgentState[Any],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Check user messages and tool results for PII before model invocation.

        Only the most recent `HumanMessage` is inspected for user input, and only
        the `ToolMessage` objects that follow the most recent `AIMessage` are
        inspected for tool results.

        Args:
            state: The current agent state.
            runtime: The langgraph runtime.

        Returns:
            Updated state with PII handled according to strategy, or `None` if no PII
                detected.

        Raises:
            PIIDetectionError: If PII is detected and strategy is `'block'`.
        """
        if not (self.apply_to_input or self.apply_to_tool_results):
            return None

        history = state["messages"]
        if not history:
            return None

        sanitized: list[AnyMessage] = list(history)
        changed = False

        if self.apply_to_input:
            # Position of the latest user message, scanning from the end.
            human_pos = next(
                (
                    pos
                    for pos in reversed(range(len(history)))
                    if isinstance(history[pos], HumanMessage)
                ),
                None,
            )
            if human_pos is not None:
                human_msg = history[human_pos]
                if human_msg and human_msg.content:
                    redacted, hits = self._process_content(str(human_msg.content))
                    if hits:
                        sanitized[human_pos] = HumanMessage(
                            content=redacted,
                            id=human_msg.id,
                            name=human_msg.name,
                        )
                        changed = True

        if self.apply_to_tool_results:
            # Tool results of interest are the ones produced after the latest AI turn.
            ai_pos = next(
                (
                    pos
                    for pos in reversed(range(len(history)))
                    if isinstance(history[pos], AIMessage)
                ),
                None,
            )
            if ai_pos is not None:
                for pos in range(ai_pos + 1, len(history)):
                    candidate = history[pos]
                    if not isinstance(candidate, ToolMessage) or not candidate.content:
                        continue

                    redacted, hits = self._process_content(str(candidate.content))
                    if not hits:
                        continue

                    sanitized[pos] = ToolMessage(
                        content=redacted,
                        id=candidate.id,
                        name=candidate.name,
                        tool_call_id=candidate.tool_call_id,
                    )
                    changed = True

        return {"messages": sanitized} if changed else None

    @hook_config(can_jump_to=["end"])
    async def abefore_model(
        self,
        state: AgentState[Any],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Async check user messages and tool results for PII before model invocation.

        Delegates to the synchronous `before_model`.

        Args:
            state: The current agent state.
            runtime: The langgraph runtime.

        Returns:
            Updated state with PII handled according to strategy, or `None` if no PII
                detected.

        Raises:
            PIIDetectionError: If PII is detected and strategy is `'block'`.
        """
        return self.before_model(state, runtime)

    @override
    def after_model(
        self,
        state: AgentState[Any],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Check AI messages for PII after model invocation.

        Only the most recent `AIMessage` is inspected.

        Args:
            state: The current agent state.
            runtime: The langgraph runtime.

        Returns:
            Updated state with PII handled according to strategy, or None if no PII
                detected.

        Raises:
            PIIDetectionError: If PII is detected and strategy is `'block'`.
        """
        if not self.apply_to_output:
            return None

        history = state["messages"]
        if not history:
            return None

        # Walk backwards to the latest AI message.
        latest_ai = None
        latest_pos = None
        for pos in reversed(range(len(history))):
            candidate = history[pos]
            if isinstance(candidate, AIMessage):
                latest_ai, latest_pos = candidate, pos
                break

        if latest_pos is None or not latest_ai or not latest_ai.content:
            return None

        redacted, hits = self._process_content(str(latest_ai.content))
        if not hits:
            return None

        # Swap the sanitized copy into a fresh list so the input list is not mutated.
        sanitized = list(history)
        sanitized[latest_pos] = AIMessage(
            content=redacted,
            id=latest_ai.id,
            name=latest_ai.name,
            tool_calls=latest_ai.tool_calls,
        )
        return {"messages": sanitized}

    async def aafter_model(
        self,
        state: AgentState[Any],
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | None:
        """Async check AI messages for PII after model invocation.

        Delegates to the synchronous `after_model`.

        Args:
            state: The current agent state.
            runtime: The langgraph runtime.

        Returns:
            Updated state with PII handled according to strategy, or None if no PII
                detected.

        Raises:
            PIIDetectionError: If PII is detected and strategy is `'block'`.
        """
        return self.after_model(state, runtime)
|
||||
|
||||
|
||||
# Public API of this module: the middleware itself plus the names re-exported
# from the shared redaction helpers imported above.
__all__ = [
    "PIIDetectionError",
    "PIIMatch",
    "PIIMiddleware",
    "detect_credit_card",
    "detect_email",
    "detect_ip",
    "detect_mac_address",
    "detect_url",
]
|
||||
882
venv/Lib/site-packages/langchain/agents/middleware/shell_tool.py
Normal file
882
venv/Lib/site-packages/langchain/agents/middleware/shell_tool.py
Normal file
@@ -0,0 +1,882 @@
|
||||
"""Middleware that exposes a persistent shell tool to agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import signal
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
import weakref
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.runnables import run_in_executor
|
||||
from langchain_core.tools.base import ToolException
|
||||
from langgraph.channels.untracked_value import UntrackedValue
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pydantic.json_schema import SkipJsonSchema
|
||||
from typing_extensions import NotRequired, override
|
||||
|
||||
from langchain.agents.middleware._execution import (
|
||||
SHELL_TEMP_PREFIX,
|
||||
BaseExecutionPolicy,
|
||||
CodexSandboxExecutionPolicy,
|
||||
DockerExecutionPolicy,
|
||||
HostExecutionPolicy,
|
||||
)
|
||||
from langchain.agents.middleware._redaction import (
|
||||
PIIDetectionError,
|
||||
PIIMatch,
|
||||
RedactionRule,
|
||||
ResolvedRedactionRule,
|
||||
)
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
PrivateStateAttr,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
|
||||
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

# Prefix of the sentinel line printed after each command; a random suffix is
# appended per command so the end of its output can be recognised.
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"

# Description shown to the model when no custom tool description is supplied.
DEFAULT_TOOL_DESCRIPTION = (
    "Execute a shell command inside a persistent session. Before running a command, "
    "confirm the working directory is correct (e.g., inspect with `ls` or `pwd`) and ensure "
    "any parent directories exist. Prefer absolute paths and quote paths containing spaces, "
    'such as `cd "/path/with spaces"`. Chain multiple commands with `&&` or `;` instead of '
    "embedding newlines. Avoid unnecessary `cd` usage unless explicitly required so the "
    "session remains stable. Outputs may be truncated when they become very large, and long "
    "running commands will be terminated once their configured timeout elapses."
)

# Default name under which the shell tool is registered.
SHELL_TOOL_NAME = "shell"
|
||||
|
||||
|
||||
def _cleanup_resources(
    session: ShellSession, tempdir: tempfile.TemporaryDirectory[str] | None, timeout: float
) -> None:
    """Best-effort teardown of a shell session and its optional temp directory.

    Any exception raised by either step is suppressed so that one failing step
    never prevents the other from running.

    Args:
        session: Session to stop.
        tempdir: Temporary workspace to remove, or `None` when there is none.
        timeout: Timeout forwarded to `session.stop`.
    """
    with contextlib.suppress(Exception):
        session.stop(timeout)

    if tempdir is None:
        return

    with contextlib.suppress(Exception):
        tempdir.cleanup()
|
||||
|
||||
|
||||
@dataclass
class _SessionResources:
    """Bundle of the shell resources that belong to one agent run.

    A `weakref.finalize` callback is registered on construction so that
    `_cleanup_resources` runs with the session, temp directory and the policy's
    termination timeout once this container is finalized.
    """

    # The persistent shell session.
    session: ShellSession
    # Temporary workspace, or `None` when no temporary directory is in use.
    tempdir: tempfile.TemporaryDirectory[str] | None
    # Execution policy; supplies `termination_timeout` for cleanup.
    policy: BaseExecutionPolicy
    # Finalizer created in `__post_init__` (not an init argument, hidden from repr).
    finalizer: weakref.finalize = field(init=False, repr=False)  # type: ignore[type-arg]

    def __post_init__(self) -> None:
        stop_timeout = self.policy.termination_timeout
        # Only the resources are handed to the callback, never `self`, so the
        # finalizer does not keep this container alive.
        self.finalizer = weakref.finalize(
            self, _cleanup_resources, self.session, self.tempdir, stop_timeout
        )
|
||||
|
||||
|
||||
class ShellToolState(AgentState[ResponseT]):
    """Agent state schema that adds a slot for the shell session resources.

    Type Parameters:
        ResponseT: The type of the structured response. Defaults to `Any`.
    """

    # Optional handle to the per-run shell resources. It is tagged with
    # `UntrackedValue` and `PrivateStateAttr`; presumably these keep it out of
    # checkpoints and out of the public input/output schema (verify against the
    # definitions of those markers).
    shell_session_resources: NotRequired[
        Annotated[_SessionResources | None, UntrackedValue, PrivateStateAttr]
    ]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CommandExecutionResult:
    """Immutable record describing the outcome of one shell command."""

    # Collected output text (stderr lines carry a "[stderr] " prefix).
    output: str
    # Exit status reported for the command, or `None` when it is unknown.
    exit_code: int | None
    # Whether the command ran past its deadline.
    timed_out: bool
    # Whether output was dropped because the line limit was exceeded.
    truncated_by_lines: bool
    # Whether output was dropped because the byte limit was exceeded.
    truncated_by_bytes: bool
    # Number of output lines seen, including dropped ones.
    total_lines: int
    # Number of UTF-8 bytes seen, including dropped ones.
    total_bytes: int
|
||||
|
||||
|
||||
class ShellSession:
    """Persistent shell session that supports sequential command execution.

    One subprocess is spawned through the execution policy and kept alive between
    commands. Two daemon threads forward its stdout and stderr lines into a single
    queue. Every command is followed by a `printf` of a unique marker plus `$?`, so
    the reader can tell where the command's output ends and what its exit status was.
    """

    def __init__(
        self,
        workspace: Path,
        policy: BaseExecutionPolicy,
        command: tuple[str, ...],
        environment: Mapping[str, str],
    ) -> None:
        """Store the session configuration; the process is not started here.

        Args:
            workspace: Working directory handed to the policy when spawning.
            policy: Execution policy used to spawn the process and to read limits
                and timeouts from.
            command: Argument vector that launches the shell.
            environment: Environment variables for the shell (copied into a dict).
        """
        self._workspace = workspace
        self._policy = policy
        self._command = command
        self._environment = dict(environment)
        self._process: subprocess.Popen[str] | None = None
        self._stdin: Any = None
        # Items are `(stream label, line)`; a `None` line signals EOF on that stream.
        self._queue: queue.Queue[tuple[str, str | None]] = queue.Queue()
        self._lock = threading.Lock()
        self._stdout_thread: threading.Thread | None = None
        self._stderr_thread: threading.Thread | None = None
        self._terminated = False

    def start(self) -> None:
        """Start the shell subprocess and reader threads.

        Does nothing when a process is already running.

        Raises:
            RuntimeError: If the shell session pipes cannot be initialized.
        """
        if self._process and self._process.poll() is None:
            return

        self._process = self._policy.spawn(
            workspace=self._workspace,
            env=self._environment,
            command=self._command,
        )
        if (
            self._process.stdin is None
            or self._process.stdout is None
            or self._process.stderr is None
        ):
            msg = "Failed to initialize shell session pipes."
            raise RuntimeError(msg)

        self._stdin = self._process.stdin
        self._terminated = False
        # Fresh queue so output of a previous process cannot leak into this one.
        self._queue = queue.Queue()

        self._stdout_thread = threading.Thread(
            target=self._enqueue_stream,
            args=(self._process.stdout, "stdout"),
            daemon=True,
        )
        self._stderr_thread = threading.Thread(
            target=self._enqueue_stream,
            args=(self._process.stderr, "stderr"),
            daemon=True,
        )
        self._stdout_thread.start()
        self._stderr_thread.start()

    def restart(self) -> None:
        """Restart the shell process."""
        self.stop(self._policy.termination_timeout)
        self.start()

    def stop(self, timeout: float) -> None:
        """Stop the shell subprocess.

        Asks a running shell to `exit`, waits up to `timeout` seconds and kills the
        process if it has not finished by then. The session is always marked as
        terminated, stdin is closed and the process handle is cleared.

        Args:
            timeout: Seconds to wait for the process to exit before killing it.
        """
        if not self._process:
            return

        if self._process.poll() is None and not self._terminated:
            try:
                self._stdin.write("exit\n")
                self._stdin.flush()
            except (BrokenPipeError, OSError):
                LOGGER.debug(
                    "Failed to write exit command; terminating shell session.",
                    exc_info=True,
                )

        try:
            if self._process.wait(timeout=timeout) is None:
                self._kill_process()
        except subprocess.TimeoutExpired:
            self._kill_process()
        finally:
            self._terminated = True
            with contextlib.suppress(Exception):
                self._stdin.close()
            self._process = None

    def execute(self, command: str, *, timeout: float) -> CommandExecutionResult:
        """Execute a command in the persistent shell.

        Args:
            command: Command text; a trailing newline is added when missing.
            timeout: Seconds allowed for the command, measured from this call.

        Returns:
            The structured result of the command.

        Raises:
            RuntimeError: If the shell process is not running.
        """
        if not self._process or self._process.poll() is not None:
            msg = "Shell session is not running."
            raise RuntimeError(msg)

        marker = f"{_DONE_MARKER_PREFIX}{uuid.uuid4().hex}"
        deadline = time.monotonic() + timeout

        with self._lock:
            # Discard leftovers from earlier commands before issuing a new one.
            self._drain_queue()
            payload = command if command.endswith("\n") else f"{command}\n"
            try:
                self._stdin.write(payload)
                self._stdin.write(f"printf '{marker} %s\\n' $?\n")
                self._stdin.flush()
            except (BrokenPipeError, OSError):
                # The shell exited before we could write the marker command.
                # This happens when commands like 'exit 1' terminate the shell.
                return self._collect_output_after_exit(deadline)

            return self._collect_output(marker, deadline, timeout)

    def _collect_output(
        self,
        marker: str,
        deadline: float,
        timeout: float,
    ) -> CommandExecutionResult:
        """Read queued output until the done marker arrives or the deadline passes.

        Lines beyond the policy's line or byte limit are counted but not kept. On
        timeout the session is restarted and an empty-output result is returned.

        Args:
            marker: Unique marker whose stdout line ends the command's output.
            deadline: Absolute `time.monotonic()` value at which to give up.
            timeout: Original timeout in seconds (used for the log message only).

        Returns:
            The structured result of the command.
        """
        collected: list[str] = []
        total_lines = 0
        total_bytes = 0
        truncated_by_lines = False
        truncated_by_bytes = False
        exit_code: int | None = None
        timed_out = False

        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                timed_out = True
                break
            try:
                source, data = self._queue.get(timeout=remaining)
            except queue.Empty:
                timed_out = True
                break

            if data is None:
                # EOF marker from a reader thread; keep waiting for the done marker.
                continue

            if source == "stdout" and data.startswith(marker):
                _, _, status = data.partition(" ")
                exit_code = self._safe_int(status.strip())
                # Drain any remaining stderr that may have arrived concurrently.
                # The stderr reader thread runs independently, so output might
                # still be in flight when the stdout marker arrives.
                self._drain_remaining_stderr(collected, deadline)
                break

            total_lines += 1
            total_bytes += len(data.encode("utf-8", "replace"))

            if total_lines > self._policy.max_output_lines:
                truncated_by_lines = True
                continue

            if (
                self._policy.max_output_bytes is not None
                and total_bytes > self._policy.max_output_bytes
            ):
                truncated_by_bytes = True
                continue

            self._append_line(collected, source, data)

        if timed_out:
            LOGGER.warning(
                "Command timed out after %.2f seconds; restarting shell session.",
                timeout,
            )
            self.restart()
            return CommandExecutionResult(
                output="",
                exit_code=None,
                timed_out=True,
                truncated_by_lines=truncated_by_lines,
                truncated_by_bytes=truncated_by_bytes,
                total_lines=total_lines,
                total_bytes=total_bytes,
            )

        return CommandExecutionResult(
            output="".join(collected),
            exit_code=exit_code,
            timed_out=False,
            truncated_by_lines=truncated_by_lines,
            truncated_by_bytes=truncated_by_bytes,
            total_lines=total_lines,
            total_bytes=total_bytes,
        )

    def _collect_output_after_exit(self, deadline: float) -> CommandExecutionResult:
        """Collect output after the shell exited unexpectedly.

        Called when a `BrokenPipeError` occurs while writing to stdin, indicating the
        shell process terminated (e.g., due to an 'exit' command).

        Args:
            deadline: Absolute time by which collection must complete.

        Returns:
            `CommandExecutionResult` with collected output and the process exit code.
        """
        collected: list[str] = []
        total_lines = 0
        total_bytes = 0
        truncated_by_lines = False
        truncated_by_bytes = False

        # Give reader threads a brief moment to enqueue any remaining output.
        drain_timeout = 0.1
        drain_deadline = min(time.monotonic() + drain_timeout, deadline)

        while True:
            remaining = drain_deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                source, data = self._queue.get(timeout=remaining)
            except queue.Empty:
                break

            if data is None:
                # EOF marker from a reader thread; continue draining.
                continue

            total_lines += 1
            total_bytes += len(data.encode("utf-8", "replace"))

            if total_lines > self._policy.max_output_lines:
                truncated_by_lines = True
                continue

            if (
                self._policy.max_output_bytes is not None
                and total_bytes > self._policy.max_output_bytes
            ):
                truncated_by_bytes = True
                continue

            self._append_line(collected, source, data)

        # Get exit code from the terminated process (`None` if it is still running).
        exit_code: int | None = None
        if self._process:
            exit_code = self._process.poll()

        return CommandExecutionResult(
            output="".join(collected),
            exit_code=exit_code,
            timed_out=False,
            truncated_by_lines=truncated_by_lines,
            truncated_by_bytes=truncated_by_bytes,
            total_lines=total_lines,
            total_bytes=total_bytes,
        )

    def _kill_process(self) -> None:
        """Forcefully kill the shell process (and its process group where safe).

        Where `os.killpg` exists, the child's whole process group receives SIGKILL
        so that programs started by the shell die with it. If the child shares the
        current process's group, signalling the group would also kill this process,
        so only the child itself is killed in that case.
        """
        process = self._process
        if not process:
            return

        if hasattr(os, "killpg"):
            with contextlib.suppress(ProcessLookupError):
                group_id = os.getpgid(process.pid)
                if group_id != os.getpgrp():
                    os.killpg(group_id, signal.SIGKILL)
                else:
                    # Same group as ourselves: never SIGKILL our own process group.
                    process.kill()
        else:  # pragma: no cover
            with contextlib.suppress(ProcessLookupError):
                process.kill()

    def _enqueue_stream(self, stream: Any, label: str) -> None:
        """Reader-thread target: forward each line of `stream` into the queue.

        A final `(label, None)` item is queued once the stream reaches EOF.
        """
        for line in iter(stream.readline, ""):
            self._queue.put((label, line))
        self._queue.put((label, None))

    def _drain_queue(self) -> None:
        """Discard everything currently waiting in the queue without blocking."""
        while True:
            try:
                self._queue.get_nowait()
            except queue.Empty:
                break

    def _drain_remaining_stderr(
        self, collected: list[str], deadline: float, drain_timeout: float = 0.05
    ) -> None:
        """Drain any stderr output that arrived concurrently with the done marker.

        The stdout and stderr reader threads run independently. When a command writes to
        stderr just before exiting, the stderr output may still be in transit when the
        done marker arrives on stdout. This method briefly polls the queue to capture
        such output.

        Args:
            collected: The list to append collected stderr lines to.
            deadline: The original command deadline (used as an upper bound).
            drain_timeout: Maximum time to wait for additional stderr output.
        """
        drain_deadline = min(time.monotonic() + drain_timeout, deadline)
        while True:
            remaining = drain_deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                source, data = self._queue.get(timeout=remaining)
            except queue.Empty:
                break
            if data is None or source != "stderr":
                continue
            self._append_line(collected, source, data)

    @staticmethod
    def _append_line(collected: list[str], source: str, data: str) -> None:
        """Append one output line to `collected`, tagging stderr lines.

        Stdout lines are appended unchanged. Stderr lines get a `[stderr] ` prefix,
        with their trailing newline re-added as a separate item when present.
        """
        if source != "stderr":
            collected.append(data)
            return
        stripped = data.rstrip("\n")
        collected.append(f"[stderr] {stripped}")
        if data.endswith("\n"):
            collected.append("\n")

    @staticmethod
    def _safe_int(value: str) -> int | None:
        """Parse `value` as an int, returning `None` when it is not a valid integer."""
        with contextlib.suppress(ValueError):
            return int(value)
        return None
|
||||
|
||||
|
||||
class _ShellToolInput(BaseModel):
    """Argument schema of the persistent shell tool.

    Exactly one of `command` or `restart` must be supplied.
    """

    command: str | None = None
    """The shell command to execute."""

    restart: bool | None = None
    """Whether to restart the shell session."""

    runtime: Annotated[Any, SkipJsonSchema()] = None
    """The runtime for the shell tool.

    Included as a workaround at the moment bc args_schema doesn't work with
    injected ToolRuntime.
    """

    @model_validator(mode="after")
    def validate_payload(self) -> _ShellToolInput:
        """Reject payloads that carry neither or both of `command` and `restart`.

        Raises:
            ValueError: If neither a command nor a restart request is given, or if
                both are given together.
        """
        has_command = self.command is not None
        wants_restart = bool(self.restart)

        # Valid payloads have exactly one of the two set, i.e. the flags differ.
        if has_command == wants_restart:
            if has_command:
                msg = "Specify only one of 'command' or 'restart'."
            else:
                msg = "Shell tool requires either 'command' or 'restart'."
            raise ValueError(msg)
        return self
|
||||
|
||||
|
||||
class ShellToolMiddleware(AgentMiddleware[ShellToolState[ResponseT], ContextT, ResponseT]):
|
||||
"""Middleware that registers a persistent shell tool for agents.
|
||||
|
||||
The middleware exposes a single long-lived shell session. Use the execution policy
|
||||
to match your deployment's security posture:
|
||||
|
||||
* `HostExecutionPolicy` – full host access; best for trusted environments where the
|
||||
agent already runs inside a container or VM that provides isolation.
|
||||
* `CodexSandboxExecutionPolicy` – reuses the Codex CLI sandbox for additional
|
||||
syscall/filesystem restrictions when the CLI is available.
|
||||
* `DockerExecutionPolicy` – launches a separate Docker container for each agent run,
|
||||
providing harder isolation, optional read-only root filesystems, and user
|
||||
remapping.
|
||||
|
||||
When no policy is provided the middleware defaults to `HostExecutionPolicy`.
|
||||
"""
|
||||
|
||||
state_schema = ShellToolState # type: ignore[assignment]
|
||||
|
||||
def __init__(
    self,
    workspace_root: str | Path | None = None,
    *,
    startup_commands: tuple[str, ...] | list[str] | str | None = None,
    shutdown_commands: tuple[str, ...] | list[str] | str | None = None,
    execution_policy: BaseExecutionPolicy | None = None,
    redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
    tool_description: str | None = None,
    tool_name: str = SHELL_TOOL_NAME,
    shell_command: Sequence[str] | str | None = None,
    env: Mapping[str, Any] | None = None,
) -> None:
    """Configure the middleware and register its shell tool.

    Args:
        workspace_root: Directory in which the shell session runs.

            When falsy, a temporary directory is created for each session and
            cleaned up with it.
        startup_commands: Command or commands run, in order, right after the
            session starts (and again after a restart).
        shutdown_commands: Command or commands run before the session is torn
            down.
        execution_policy: Policy supplying timeouts, output limits and resource
            configuration.

            `HostExecutionPolicy` is used when none is given.
        redaction_rules: Rules applied to command output before it is returned
            to the model. Each rule is resolved once, here.

            !!! warning
                Redaction happens after execution, so it does not stop secrets
                or sensitive data from being exfiltrated when using
                `HostExecutionPolicy`.

        tool_description: Replacement for the default tool description.
        tool_name: Name under which the tool is registered.

            Defaults to `SHELL_TOOL_NAME`.
        shell_command: Executable (a single string) or argument sequence used to
            launch the session.

            Defaults to `/bin/bash`.
        env: Environment variables for the session.

            Values are converted with `str()`. If omitted, the session inherits
            the parent process environment.

    Raises:
        ValueError: If `shell_command` normalizes to an empty argument list.
        TypeError: If `env` contains a non-string key.
    """
    super().__init__()
    self._workspace_root = Path(workspace_root) if workspace_root else None
    self._tool_name = tool_name
    self._shell_command = self._normalize_shell_command(shell_command)
    self._environment = self._normalize_env(env)
    self._execution_policy = (
        HostExecutionPolicy() if execution_policy is None else execution_policy
    )
    self._redaction_rules: tuple[ResolvedRedactionRule, ...] = tuple(
        rule.resolve() for rule in (redaction_rules or ())
    )
    self._startup_commands = self._normalize_commands(startup_commands)
    self._shutdown_commands = self._normalize_commands(shutdown_commands)

    # Create a proper tool that executes directly (no interception needed)
    @tool(
        self._tool_name,
        args_schema=_ShellToolInput,
        description=tool_description or DEFAULT_TOOL_DESCRIPTION,
    )
    def shell_tool(
        *,
        runtime: ToolRuntime[None, ShellToolState],
        command: str | None = None,
        restart: bool = False,
    ) -> ToolMessage | str:
        # Session resources live in agent state; they are created lazily if absent.
        session_resources = self._get_or_create_resources(runtime.state)
        request = {"command": command, "restart": restart}
        return self._run_shell_tool(
            session_resources,
            request,
            tool_call_id=runtime.tool_call_id,
        )

    self._shell_tool = shell_tool
    self.tools = [self._shell_tool]
|
||||
|
||||
@staticmethod
|
||||
def _normalize_commands(
|
||||
commands: tuple[str, ...] | list[str] | str | None,
|
||||
) -> tuple[str, ...]:
|
||||
if commands is None:
|
||||
return ()
|
||||
if isinstance(commands, str):
|
||||
return (commands,)
|
||||
return tuple(commands)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_shell_command(
|
||||
shell_command: Sequence[str] | str | None,
|
||||
) -> tuple[str, ...]:
|
||||
if shell_command is None:
|
||||
return ("/bin/bash",)
|
||||
normalized = (shell_command,) if isinstance(shell_command, str) else tuple(shell_command)
|
||||
if not normalized:
|
||||
msg = "Shell command must contain at least one argument."
|
||||
raise ValueError(msg)
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _normalize_env(env: Mapping[str, Any] | None) -> dict[str, str] | None:
|
||||
if env is None:
|
||||
return None
|
||||
normalized: dict[str, str] = {}
|
||||
for key, value in env.items():
|
||||
if not isinstance(key, str):
|
||||
msg = "Environment variable names must be strings." # type: ignore[unreachable]
|
||||
raise TypeError(msg)
|
||||
normalized[key] = str(value)
|
||||
return normalized
|
||||
|
||||
@override
def before_agent(
    self, state: ShellToolState[ResponseT], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
    """Make sure a shell session exists before the agent runs.

    Resources already present in `state` are reused; otherwise a new session
    is created.

    Args:
        state: The current agent state.
        runtime: The runtime context (not used here).

    Returns:
        A state update storing the session resources under
        `"shell_session_resources"`.
    """
    return {"shell_session_resources": self._get_or_create_resources(state)}
|
||||
|
||||
async def abefore_agent(
    self, state: ShellToolState[ResponseT], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
    """Async variant of `before_agent`.

    The synchronous implementation is dispatched through `run_in_executor`.

    Args:
        state: The current agent state.
        runtime: The runtime context.

    Returns:
        The state update produced by `before_agent`.
    """
    update = await run_in_executor(None, self.before_agent, state, runtime)
    return update
|
||||
|
||||
@override
def after_agent(self, state: ShellToolState[ResponseT], runtime: Runtime[ContextT]) -> None:
    """Run shutdown commands and release the session once the agent finishes.

    Does nothing when `state` holds no `_SessionResources`. The finalizer is
    invoked even if running the shutdown commands raises.
    """
    resources = state.get("shell_session_resources")
    if isinstance(resources, _SessionResources):
        try:
            self._run_shutdown_commands(resources.session)
        finally:
            resources.finalizer()
|
||||
|
||||
async def aafter_agent(
    self, state: ShellToolState[ResponseT], runtime: Runtime[ContextT]
) -> None:
    """Async run shutdown commands and release resources when an agent completes.

    `after_agent` executes shell commands and tears the session down, which
    blocks. It is therefore dispatched through `run_in_executor`, the same way
    `abefore_agent` dispatches `before_agent`, instead of being called directly
    from the coroutine.

    Args:
        state: The current agent state.
        runtime: The runtime context.
    """
    await run_in_executor(None, self.after_agent, state, runtime)
|
||||
|
||||
def _get_or_create_resources(self, state: ShellToolState[ResponseT]) -> _SessionResources:
    """Return the session resources held in `state`, creating them if needed.

    Reusing resources already present in the state (for example after an
    interrupt) is what makes the session resumable. Newly created resources
    are written back into `state` in place.

    Args:
        state: The agent state, which may already contain session resources.

    Returns:
        The existing or newly created session resources.
    """
    resources = state.get("shell_session_resources")
    if not isinstance(resources, _SessionResources):
        resources = self._create_resources()
        # The cast only tells the type checker that the state can be mutated like a dict.
        cast("dict[str, Any]", state)["shell_session_resources"] = resources
    return resources
|
||||
|
||||
def _create_resources(self) -> _SessionResources:
    """Start a new shell session and bundle it with its cleanup state.

    Uses the configured workspace root (created if missing) or a fresh
    temporary directory. If starting the session or running the startup
    commands fails, the session is stopped, any temporary directory is
    removed, and the original exception is re-raised.

    Returns:
        The started session together with its temporary directory (if any) and
        the execution policy.
    """
    policy = self._execution_policy
    tempdir: tempfile.TemporaryDirectory[str] | None = None
    if self._workspace_root is not None:
        workspace_path = self._workspace_root
        workspace_path.mkdir(parents=True, exist_ok=True)
    else:
        tempdir = tempfile.TemporaryDirectory(prefix=SHELL_TEMP_PREFIX)
        workspace_path = Path(tempdir.name)

    session = ShellSession(
        workspace_path,
        policy,
        self._shell_command,
        self._environment or {},
    )
    try:
        session.start()
        LOGGER.info("Started shell session in %s", workspace_path)
        self._run_startup_commands(session)
    except BaseException:
        # Cleanup-and-re-raise: nothing is swallowed here, interrupts included.
        LOGGER.exception("Starting shell session failed; cleaning up resources.")
        session.stop(policy.termination_timeout)
        if tempdir is not None:
            tempdir.cleanup()
        raise

    return _SessionResources(session=session, tempdir=tempdir, policy=policy)
|
||||
|
||||
def _run_startup_commands(self, session: ShellSession) -> None:
    """Run each configured startup command in order, stopping at the first failure.

    Raises:
        RuntimeError: If a command times out or exits with a code other than
            `0` or `None`.
    """
    for command in self._startup_commands:
        outcome = session.execute(command, timeout=self._execution_policy.startup_timeout)
        succeeded = not outcome.timed_out and outcome.exit_code in {0, None}
        if succeeded:
            continue
        msg = f"Startup command '{command}' failed with exit code {outcome.exit_code}"
        raise RuntimeError(msg)
|
||||
|
||||
def _run_shutdown_commands(self, session: ShellSession) -> None:
    """Run each configured shutdown command on a best-effort basis.

    Timeouts, non-zero exit codes, and `RuntimeError` / `ToolException` /
    `OSError` raised while running a command are logged as warnings; the
    remaining commands still run.
    """
    for command in self._shutdown_commands:
        try:
            outcome = session.execute(command, timeout=self._execution_policy.command_timeout)
            if outcome.timed_out:
                LOGGER.warning("Shutdown command '%s' timed out.", command)
            elif outcome.exit_code not in {0, None}:
                LOGGER.warning(
                    "Shutdown command '%s' exited with %s.", command, outcome.exit_code
                )
        except (RuntimeError, ToolException, OSError) as exc:
            LOGGER.warning(
                "Failed to run shutdown command '%s': %s", command, exc, exc_info=True
            )
|
||||
|
||||
def _apply_redactions(self, content: str) -> tuple[str, dict[str, list[PIIMatch]]]:
    """Run every redaction rule over `content`, in order.

    Each rule receives the output of the previous one.

    Returns:
        The redacted text and the matches found, grouped by the rule's
        `pii_type`. Rules that matched nothing contribute no key.
    """
    found: dict[str, list[PIIMatch]] = {}
    text = content
    for rule in self._redaction_rules:
        text, hits = rule.apply(text)
        if not hits:
            continue
        found.setdefault(rule.pii_type, []).extend(hits)
    return text, found
|
||||
|
||||
def _run_shell_tool(
    self,
    resources: _SessionResources,
    payload: dict[str, Any],
    *,
    tool_call_id: str | None,
) -> Any:
    """Handle one shell tool call: restart the session or run a command.

    Args:
        resources: Session resources whose `session` is used.
        payload: Tool arguments; reads the `"restart"` and `"command"` keys.
        tool_call_id: Id of the originating tool call, or `None`.

    Returns:
        The value produced by `_format_tool_message`: a `ToolMessage`, or the
        bare content string when `tool_call_id` is `None`. The status is
        `"error"` for timeouts, blocked output and non-zero exit codes.

    Raises:
        ToolException: If a restart fails, or if no non-empty `command` string
            is supplied when a restart is not requested.
    """
    session = resources.session

    if payload.get("restart"):
        LOGGER.info("Restarting shell session on request.")
        try:
            session.restart()
            self._run_startup_commands(session)
        except Exception as err:
            # Only ordinary errors become ToolException. KeyboardInterrupt,
            # SystemExit and other BaseException subclasses must propagate
            # unchanged rather than be reported as a tool failure.
            LOGGER.exception("Restarting shell session failed; session remains unavailable.")
            msg = "Failed to restart shell session."
            raise ToolException(msg) from err
        message = "Shell session restarted."
        return self._format_tool_message(message, tool_call_id, status="success")

    command = payload.get("command")
    if not command or not isinstance(command, str):
        msg = "Shell tool expects a 'command' string when restart is not requested."
        raise ToolException(msg)

    LOGGER.info("Executing shell command: %s", command)
    result = session.execute(command, timeout=self._execution_policy.command_timeout)

    if result.timed_out:
        timeout_seconds = self._execution_policy.command_timeout
        message = f"Error: Command timed out after {timeout_seconds:.1f} seconds."
        return self._format_tool_message(
            message,
            tool_call_id,
            status="error",
            artifact={
                "timed_out": True,
                "exit_code": None,
            },
        )

    try:
        sanitized_output, matches = self._apply_redactions(result.output)
    except PIIDetectionError as error:
        # A rule raised instead of redacting: withhold the output entirely.
        LOGGER.warning("Blocking command output due to detected %s.", error.pii_type)
        message = f"Output blocked: detected {error.pii_type}."
        return self._format_tool_message(
            message,
            tool_call_id,
            status="error",
            artifact={
                "timed_out": False,
                "exit_code": result.exit_code,
                "matches": {error.pii_type: error.matches},
            },
        )

    sanitized_output = sanitized_output or "<no output>"
    if result.truncated_by_lines:
        sanitized_output = (
            f"{sanitized_output.rstrip()}\n\n"
            f"... Output truncated at {self._execution_policy.max_output_lines} lines "
            f"(observed {result.total_lines})."
        )
    if result.truncated_by_bytes and self._execution_policy.max_output_bytes is not None:
        sanitized_output = (
            f"{sanitized_output.rstrip()}\n\n"
            f"... Output truncated at {self._execution_policy.max_output_bytes} bytes "
            f"(observed {result.total_bytes})."
        )

    if result.exit_code not in {0, None}:
        sanitized_output = f"{sanitized_output.rstrip()}\n\nExit code: {result.exit_code}"
        final_status: Literal["success", "error"] = "error"
    else:
        final_status = "success"

    artifact = {
        "timed_out": False,
        "exit_code": result.exit_code,
        "truncated_by_lines": result.truncated_by_lines,
        "truncated_by_bytes": result.truncated_by_bytes,
        "total_lines": result.total_lines,
        "total_bytes": result.total_bytes,
        "redaction_matches": matches,
    }

    return self._format_tool_message(
        sanitized_output,
        tool_call_id,
        status=final_status,
        artifact=artifact,
    )
|
||||
|
||||
def _format_tool_message(
    self,
    content: str,
    tool_call_id: str | None,
    *,
    status: Literal["success", "error"],
    artifact: dict[str, Any] | None = None,
) -> ToolMessage | str:
    """Wrap tool output in a `ToolMessage` when a tool call id is available.

    Args:
        content: Text to return to the model.
        tool_call_id: Id of the originating tool call, or `None`.
        status: Outcome recorded on the message.
        artifact: Optional metadata; a falsy value is replaced by `{}`.

    Returns:
        `content` unchanged when `tool_call_id` is `None`, otherwise a
        `ToolMessage` named after this middleware's tool.
    """
    if tool_call_id is None:
        return content
    return ToolMessage(
        content=content,
        tool_call_id=tool_call_id,
        name=self._tool_name,
        status=status,
        artifact=artifact or {},
    )
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CodexSandboxExecutionPolicy",
|
||||
"DockerExecutionPolicy",
|
||||
"HostExecutionPolicy",
|
||||
"RedactionRule",
|
||||
"ShellToolMiddleware",
|
||||
]
|
||||
@@ -0,0 +1,658 @@
|
||||
"""Summarization middleware."""
|
||||
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from functools import partial
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
MessageLikeRepresentation,
|
||||
RemoveMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.human import HumanMessage
|
||||
from langchain_core.messages.utils import (
|
||||
count_tokens_approximately,
|
||||
get_buffer_string,
|
||||
trim_messages,
|
||||
)
|
||||
from langgraph.graph.message import (
|
||||
REMOVE_ALL_MESSAGES,
|
||||
)
|
||||
from langgraph.runtime import Runtime
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, ResponseT
|
||||
from langchain.chat_models import BaseChatModel, init_chat_model
|
||||
|
||||
TokenCounter = Callable[[Iterable[MessageLikeRepresentation]], int]
|
||||
|
||||
DEFAULT_SUMMARY_PROMPT = """<role>
|
||||
Context Extraction Assistant
|
||||
</role>
|
||||
|
||||
<primary_objective>
|
||||
Your sole objective in this task is to extract the highest quality/most relevant context from the conversation history below.
|
||||
</primary_objective>
|
||||
|
||||
<objective_information>
|
||||
You're nearing the total number of input tokens you can accept, so you must extract the highest quality/most relevant pieces of information from your conversation history.
|
||||
This context will then overwrite the conversation history presented below. Because of this, ensure the context you extract is only the most important information to continue working toward your overall goal.
|
||||
</objective_information>
|
||||
|
||||
<instructions>
|
||||
The conversation history below will be replaced with the context you extract in this step.
|
||||
You want to ensure that you don't repeat any actions you've already completed, so the context you extract from the conversation history should be focused on the most important information to your overall goal.
|
||||
|
||||
You should structure your summary using the following sections. Each section acts as a checklist - you must populate it with relevant information or explicitly state "None" if there is nothing to report for that section:
|
||||
|
||||
## SESSION INTENT
|
||||
What is the user's primary goal or request? What overall task are you trying to accomplish? This should be concise but complete enough to understand the purpose of the entire session.
|
||||
|
||||
## SUMMARY
|
||||
Extract and record all of the most important context from the conversation history. Include important choices, conclusions, or strategies determined during this conversation. Include the reasoning behind key decisions. Document any rejected options and why they were not pursued.
|
||||
|
||||
## ARTIFACTS
|
||||
What artifacts, files, or resources were created, modified, or accessed during this conversation? For file modifications, list specific file paths and briefly describe the changes made to each. This section prevents silent loss of artifact information.
|
||||
|
||||
## NEXT STEPS
|
||||
What specific tasks remain to be completed to achieve the session intent? What should you do next?
|
||||
|
||||
</instructions>
|
||||
|
||||
The user will message you with the full message history from which you'll extract context to create a replacement. Carefully read through it all and think deeply about what information is most important to your overall goal and should be saved:
|
||||
|
||||
With all of this in mind, please carefully read over the entire conversation history, and extract the most important and relevant context to replace it so that you can free up space in the conversation history.
|
||||
Respond ONLY with the extracted context. Do not include any additional information, or text before or after the extracted context.
|
||||
|
||||
<messages>
|
||||
Messages to summarize:
|
||||
{messages}
|
||||
</messages>""" # noqa: E501
|
||||
|
||||
_DEFAULT_MESSAGES_TO_KEEP = 20
|
||||
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
|
||||
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
|
||||
|
||||
ContextFraction = tuple[Literal["fraction"], float]
|
||||
"""Fraction of model's maximum input tokens.
|
||||
|
||||
Example:
|
||||
To specify 50% of the model's max input tokens:
|
||||
|
||||
```python
|
||||
("fraction", 0.5)
|
||||
```
|
||||
"""
|
||||
|
||||
ContextTokens = tuple[Literal["tokens"], int]
|
||||
"""Absolute number of tokens.
|
||||
|
||||
Example:
|
||||
To specify 3000 tokens:
|
||||
|
||||
```python
|
||||
("tokens", 3000)
|
||||
```
|
||||
"""
|
||||
|
||||
ContextMessages = tuple[Literal["messages"], int]
|
||||
"""Absolute number of messages.
|
||||
|
||||
Example:
|
||||
To specify 50 messages:
|
||||
|
||||
```python
|
||||
("messages", 50)
|
||||
```
|
||||
"""
|
||||
|
||||
ContextSize = ContextFraction | ContextTokens | ContextMessages
|
||||
"""Union type for context size specifications.
|
||||
|
||||
Can be either:
|
||||
|
||||
- [`ContextFraction`][langchain.agents.middleware.summarization.ContextFraction]: A
|
||||
fraction of the model's maximum input tokens.
|
||||
- [`ContextTokens`][langchain.agents.middleware.summarization.ContextTokens]: An absolute
|
||||
number of tokens.
|
||||
- [`ContextMessages`][langchain.agents.middleware.summarization.ContextMessages]: An
|
||||
absolute number of messages.
|
||||
|
||||
Depending on use with `trigger` or `keep` parameters, this type indicates either
|
||||
when to trigger summarization or how much context to retain.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# ContextFraction
|
||||
context_size: ContextSize = ("fraction", 0.5)
|
||||
|
||||
# ContextTokens
|
||||
context_size: ContextSize = ("tokens", 3000)
|
||||
|
||||
# ContextMessages
|
||||
context_size: ContextSize = ("messages", 50)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def _get_approximate_token_counter(model: BaseChatModel) -> TokenCounter:
    """Build an approximate token counter tuned for the given model type.

    Usage-metadata scaling is always enabled. Models whose `_llm_type` is
    `"anthropic-chat"` also get `chars_per_token=3.3`.
    """
    tuning: dict[str, Any] = {}
    if model._llm_type == "anthropic-chat":  # noqa: SLF001
        # 3.3 was estimated in an offline experiment, comparing with Claude's token-counting
        # API: https://platform.claude.com/docs/en/build-with-claude/token-counting
        tuning["chars_per_token"] = 3.3
    return partial(count_tokens_approximately, use_usage_metadata_scaling=True, **tuning)
|
||||
|
||||
|
||||
class SummarizationMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Summarizes conversation history when token limits are approached.
|
||||
|
||||
This middleware monitors message token counts and automatically summarizes older
|
||||
messages when a threshold is reached, preserving recent messages and maintaining
|
||||
context continuity by ensuring AI/Tool message pairs remain together.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    model: str | BaseChatModel,
    *,
    trigger: ContextSize | list[ContextSize] | None = None,
    keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
    token_counter: TokenCounter = count_tokens_approximately,
    summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
    trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
    **deprecated_kwargs: Any,
) -> None:
    """Set up the summarization middleware.

    Args:
        model: Chat model used to write summaries. A string is resolved with
            `init_chat_model`.
        trigger: Threshold, or list of thresholds, at which summarization runs.

            Each threshold is a
            [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
            tuple. With a list, summarization runs as soon as any one
            threshold is met. With `None`, summarization never triggers.

            !!! example

                ```python
                # Summarize once there are 50 messages
                ("messages", 50)

                # Summarize once there are 3000 tokens
                ("tokens", 3000)

                # Summarize at 80% of the model's max input tokens or at
                # 100 messages, whichever happens first
                [("fraction", 0.8), ("messages", 100)]
                ```

        keep: How much recent history to retain after summarizing, as a single
            [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
            tuple. Unlike `trigger`, a list is not accepted.

            Defaults to the most recent `20` messages.

            !!! example

                ```python
                # Retain the latest 20 messages
                ("messages", 20)

                # Retain the latest 3000 tokens
                ("tokens", 3000)

                # Retain the latest 30% of the model's max input tokens
                ("fraction", 0.3)
                ```

        token_counter: Callable that counts tokens in a list of messages. When
            left at the default, a counter tuned for `model` is used instead.
        summary_prompt: Prompt template used to generate the summary.
        trim_tokens_to_summarize: Token cap applied to the messages passed to
            the summarization call.

            Pass `None` to skip trimming entirely.
        **deprecated_kwargs: Accepts the deprecated `max_tokens_before_summary`
            and `messages_to_keep`. Each emits a `DeprecationWarning` and is
            used only when `trigger` / `keep` were left at their defaults.

    Raises:
        ValueError: If a context size tuple is invalid, or if a fractional
            limit is requested but the model exposes no usable profile.
    """
    # Translate deprecated keyword arguments. The warnings are issued from this
    # frame so that `stacklevel=2` points at the caller.
    if "max_tokens_before_summary" in deprecated_kwargs:
        legacy_tokens = deprecated_kwargs["max_tokens_before_summary"]
        warnings.warn(
            "max_tokens_before_summary is deprecated. Use trigger=('tokens', value) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        if trigger is None and legacy_tokens is not None:
            trigger = ("tokens", legacy_tokens)

    if "messages_to_keep" in deprecated_kwargs:
        legacy_keep = deprecated_kwargs["messages_to_keep"]
        warnings.warn(
            "messages_to_keep is deprecated. Use keep=('messages', value) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        if keep == ("messages", _DEFAULT_MESSAGES_TO_KEEP):
            keep = ("messages", legacy_keep)

    super().__init__()

    self.model = init_chat_model(model) if isinstance(model, str) else model

    # `self.trigger` keeps the shape the caller passed in (None, one tuple, or
    # a list); `_trigger_conditions` is always a list.
    trigger_conditions: list[ContextSize]
    if trigger is None:
        trigger_conditions = []
        self.trigger: ContextSize | list[ContextSize] | None = None
    elif isinstance(trigger, list):
        trigger_conditions = [self._validate_context_size(item, "trigger") for item in trigger]
        self.trigger = trigger_conditions
    else:
        single_condition = self._validate_context_size(trigger, "trigger")
        trigger_conditions = [single_condition]
        self.trigger = single_condition
    self._trigger_conditions = trigger_conditions

    self.keep = self._validate_context_size(keep, "keep")

    if token_counter is not count_tokens_approximately:
        self.token_counter = token_counter
        self._partial_token_counter = token_counter
    else:
        self.token_counter = _get_approximate_token_counter(self.model)
        # Same counter, but with usage-metadata scaling switched off.
        self._partial_token_counter: TokenCounter = partial(  # type: ignore[call-arg]
            self.token_counter, use_usage_metadata_scaling=False
        )

    self.summary_prompt = summary_prompt
    self.trim_tokens_to_summarize = trim_tokens_to_summarize

    # Fractional limits are resolved against the model profile, so fail early
    # when that information is missing.
    uses_fraction = self.keep[0] == "fraction" or any(
        condition[0] == "fraction" for condition in self._trigger_conditions
    )
    if uses_fraction and self._get_profile_limits() is None:
        msg = (
            "Model profile information is required to use fractional token limits, "
            "and is unavailable for the specified model. Please use absolute token "
            "counts instead, or pass "
            '`\n\nChatModel(..., profile={"max_input_tokens": ...})`.\n\n'
            "with a desired integer value of the model's maximum input tokens."
        )
        raise ValueError(msg)
|
||||
|
||||
@override
def before_model(
    self, state: AgentState[Any], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
    """Summarize older history before the model call when a trigger fires.

    Args:
        state: The agent state; its `"messages"` entry is read, and messages
            without an id are assigned one in place.
        runtime: The runtime environment (not used here).

    Returns:
        `None` when no summarization is needed or nothing falls before the
        cutoff. Otherwise a state update that removes all messages and then
        adds the summary message followed by the preserved messages.
    """
    history = state["messages"]
    self._ensure_message_ids(history)

    if not self._should_summarize(history, self.token_counter(history)):
        return None

    cutoff_index = self._determine_cutoff_index(history)
    if cutoff_index <= 0:
        return None

    older, recent = self._partition_messages(history, cutoff_index)
    summary_messages = self._build_new_messages(self._create_summary(older))

    replacement = [RemoveMessage(id=REMOVE_ALL_MESSAGES), *summary_messages, *recent]
    return {"messages": replacement}
|
||||
|
||||
@override
async def abefore_model(
    self, state: AgentState[Any], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
    """Async variant: summarize older history before the model call when a trigger fires.

    Args:
        state: The agent state; its `"messages"` entry is read, and messages
            without an id are assigned one in place.
        runtime: The runtime environment (not used here).

    Returns:
        `None` when no summarization is needed or nothing falls before the
        cutoff. Otherwise a state update that removes all messages and then
        adds the summary message followed by the preserved messages.
    """
    history = state["messages"]
    self._ensure_message_ids(history)

    if not self._should_summarize(history, self.token_counter(history)):
        return None

    cutoff_index = self._determine_cutoff_index(history)
    if cutoff_index <= 0:
        return None

    older, recent = self._partition_messages(history, cutoff_index)
    summary_text = await self._acreate_summary(older)
    summary_messages = self._build_new_messages(summary_text)

    replacement = [RemoveMessage(id=REMOVE_ALL_MESSAGES), *summary_messages, *recent]
    return {"messages": replacement}
|
||||
|
||||
def _should_summarize_based_on_reported_tokens(
    self, messages: list[AnyMessage], threshold: float
) -> bool:
    """Tell whether the latest `AIMessage` reports token usage at or above `threshold`.

    The reported `total_tokens` is trusted only when the message's
    `model_provider` response metadata equals this model's `ls_provider`.
    """
    latest_ai = next(
        (candidate for candidate in reversed(messages) if isinstance(candidate, AIMessage)),
        None,
    )
    if latest_ai is None or latest_ai.usage_metadata is None:
        return False

    reported_tokens = latest_ai.usage_metadata.get("total_tokens", -1)
    if not reported_tokens or not (reported_tokens >= threshold):
        return False

    message_provider = latest_ai.response_metadata.get("model_provider")
    if not message_provider:
        return False

    own_provider = self.model._get_ls_params().get("ls_provider")  # noqa: SLF001
    return bool(message_provider == own_provider)
|
||||
|
||||
def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
    """Check whether any configured trigger condition is currently met.

    Token and fraction conditions are checked against both `total_tokens` and
    the usage reported on the latest AI message. Fraction conditions are
    skipped when the model profile provides no input-token limit.
    """
    for kind, value in self._trigger_conditions:
        if kind == "messages":
            if len(messages) >= value:
                return True
        elif kind == "tokens":
            if total_tokens >= value or self._should_summarize_based_on_reported_tokens(
                messages, value
            ):
                return True
        elif kind == "fraction":
            max_input_tokens = self._get_profile_limits()
            if max_input_tokens is None:
                continue
            # Never let the threshold drop below one token.
            threshold = max(int(max_input_tokens * value), 1)
            if total_tokens >= threshold or self._should_summarize_based_on_reported_tokens(
                messages, threshold
            ):
                return True
    return False
|
||||
|
||||
def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
    """Pick the index that splits messages to summarize from messages to keep.

    A message-count `keep` policy is delegated to `_find_safe_cutoff`. Token
    and fraction policies use `_find_token_based_cutoff`, falling back to the
    default message count when that returns `None`.
    """
    kind, value = self.keep
    if kind not in {"tokens", "fraction"}:
        return self._find_safe_cutoff(messages, cast("int", value))

    cutoff = self._find_token_based_cutoff(messages)
    if cutoff is None:
        # None cutoff -> model profile data not available (caught in __init__ but
        # here for safety), fallback to message count
        return self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
    return cutoff
|
||||
|
||||
def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
    """Locate the cutoff index for a token- or fraction-based `keep` policy.

    Returns:
        `0` when there are no messages or the whole history already fits the
        budget. `None` when `keep` is not token-based or the model profile has
        no input-token limit. Otherwise the candidate index, adjusted by
        `_find_safe_cutoff_point`.
    """
    if not messages:
        return 0

    kind, value = self.keep
    if kind == "tokens":
        budget = int(value)
    elif kind == "fraction":
        max_input_tokens = self._get_profile_limits()
        if max_input_tokens is None:
            return None
        budget = int(max_input_tokens * value)
    else:
        return None

    budget = max(budget, 1)

    if self.token_counter(messages) <= budget:
        return 0

    # Use binary search to identify the earliest message index that keeps the
    # suffix within the token budget.
    total = len(messages)
    low, high = 0, total
    candidate = total
    for _ in range(total.bit_length() + 1):
        if low >= high:
            break
        middle = (low + high) // 2
        if self._partial_token_counter(messages[middle:]) <= budget:
            candidate = middle
            high = middle
        else:
            low = middle + 1

    if candidate == total:
        # No suffix fit the budget during the search.
        candidate = low

    if candidate >= total:
        if total == 1:
            return 0
        candidate = total - 1

    # Advance past any ToolMessages to avoid splitting AI/Tool pairs
    return self._find_safe_cutoff_point(messages, candidate)
|
||||
|
||||
def _get_profile_limits(self) -> int | None:
|
||||
"""Retrieve max input token limit from the model profile."""
|
||||
try:
|
||||
profile = self.model.profile
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
if not isinstance(profile, Mapping):
|
||||
return None
|
||||
|
||||
max_input_tokens = profile.get("max_input_tokens")
|
||||
|
||||
if not isinstance(max_input_tokens, int):
|
||||
return None
|
||||
|
||||
return max_input_tokens
|
||||
|
||||
@staticmethod
|
||||
def _validate_context_size(context: ContextSize, parameter_name: str) -> ContextSize:
|
||||
"""Validate context configuration tuples."""
|
||||
kind, value = context
|
||||
if kind == "fraction":
|
||||
if not 0 < value <= 1:
|
||||
msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
|
||||
raise ValueError(msg)
|
||||
elif kind in {"tokens", "messages"}:
|
||||
if value <= 0:
|
||||
msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
msg = f"Unsupported context size type {kind} for {parameter_name}."
|
||||
raise ValueError(msg)
|
||||
return context
|
||||
|
||||
@staticmethod
def _build_new_messages(summary: str) -> list[HumanMessage]:
    """Wrap a summary text in a one-element list holding a `HumanMessage`.

    Args:
        summary: Summary text to embed in the message content.

    Returns:
        A list with a single `HumanMessage` whose `additional_kwargs` carry
        `lc_source="summarization"`.
    """
    summary_message = HumanMessage(
        content=f"Here is a summary of the conversation to date:\n\n{summary}",
        additional_kwargs={"lc_source": "summarization"},
    )
    return [summary_message]
|
||||
|
||||
@staticmethod
|
||||
def _ensure_message_ids(messages: list[AnyMessage]) -> None:
|
||||
"""Ensure all messages have unique IDs for the add_messages reducer."""
|
||||
for msg in messages:
|
||||
if msg.id is None:
|
||||
msg.id = str(uuid.uuid4())
|
||||
|
||||
@staticmethod
|
||||
def _partition_messages(
|
||||
conversation_messages: list[AnyMessage],
|
||||
cutoff_index: int,
|
||||
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
||||
"""Partition messages into those to summarize and those to preserve."""
|
||||
messages_to_summarize = conversation_messages[:cutoff_index]
|
||||
preserved_messages = conversation_messages[cutoff_index:]
|
||||
|
||||
return messages_to_summarize, preserved_messages
|
||||
|
||||
def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -> int:
|
||||
"""Find safe cutoff point that preserves AI/Tool message pairs.
|
||||
|
||||
Returns the index where messages can be safely cut without separating
|
||||
related AI and Tool messages. Returns `0` if no safe cutoff is found.
|
||||
|
||||
This is aggressive with summarization - if the target cutoff lands in the
|
||||
middle of tool messages, we advance past all of them (summarizing more).
|
||||
"""
|
||||
if len(messages) <= messages_to_keep:
|
||||
return 0
|
||||
|
||||
target_cutoff = len(messages) - messages_to_keep
|
||||
return self._find_safe_cutoff_point(messages, target_cutoff)
|
||||
|
||||
@staticmethod
|
||||
def _find_safe_cutoff_point(messages: list[AnyMessage], cutoff_index: int) -> int:
|
||||
"""Find a safe cutoff point that doesn't split AI/Tool message pairs.
|
||||
|
||||
If the message at `cutoff_index` is a `ToolMessage`, search backward for the
|
||||
`AIMessage` containing the corresponding `tool_calls` and adjust the cutoff to
|
||||
include it. This ensures tool call requests and responses stay together.
|
||||
|
||||
Falls back to advancing forward past `ToolMessage` objects only if no matching
|
||||
`AIMessage` is found (edge case).
|
||||
"""
|
||||
if cutoff_index >= len(messages) or not isinstance(messages[cutoff_index], ToolMessage):
|
||||
return cutoff_index
|
||||
|
||||
# Collect tool_call_ids from consecutive ToolMessages at/after cutoff
|
||||
tool_call_ids: set[str] = set()
|
||||
idx = cutoff_index
|
||||
while idx < len(messages) and isinstance(messages[idx], ToolMessage):
|
||||
tool_msg = cast("ToolMessage", messages[idx])
|
||||
if tool_msg.tool_call_id:
|
||||
tool_call_ids.add(tool_msg.tool_call_id)
|
||||
idx += 1
|
||||
|
||||
# Search backward for AIMessage with matching tool_calls
|
||||
for i in range(cutoff_index - 1, -1, -1):
|
||||
msg = messages[i]
|
||||
if isinstance(msg, AIMessage) and msg.tool_calls:
|
||||
ai_tool_call_ids = {tc.get("id") for tc in msg.tool_calls if tc.get("id")}
|
||||
if tool_call_ids & ai_tool_call_ids:
|
||||
# Found the AIMessage - move cutoff to include it
|
||||
return i
|
||||
|
||||
# Fallback: no matching AIMessage found, advance past ToolMessages to avoid
|
||||
# orphaned tool responses
|
||||
return idx
|
||||
|
||||
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
||||
"""Generate summary for the given messages.
|
||||
|
||||
Args:
|
||||
messages_to_summarize: Messages to summarize.
|
||||
"""
|
||||
if not messages_to_summarize:
|
||||
return "No previous conversation history."
|
||||
|
||||
trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
|
||||
if not trimmed_messages:
|
||||
return "Previous conversation was too long to summarize."
|
||||
|
||||
# Format messages to avoid token inflation from metadata when str() is called on
|
||||
# message objects
|
||||
formatted_messages = get_buffer_string(trimmed_messages)
|
||||
|
||||
try:
|
||||
response = self.model.invoke(
|
||||
self.summary_prompt.format(messages=formatted_messages).rstrip(),
|
||||
config={"metadata": {"lc_source": "summarization"}},
|
||||
)
|
||||
return response.text.strip()
|
||||
except Exception as e:
|
||||
return f"Error generating summary: {e!s}"
|
||||
|
||||
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
||||
"""Generate summary for the given messages.
|
||||
|
||||
Args:
|
||||
messages_to_summarize: Messages to summarize.
|
||||
"""
|
||||
if not messages_to_summarize:
|
||||
return "No previous conversation history."
|
||||
|
||||
trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
|
||||
if not trimmed_messages:
|
||||
return "Previous conversation was too long to summarize."
|
||||
|
||||
# Format messages to avoid token inflation from metadata when str() is called on
|
||||
# message objects
|
||||
formatted_messages = get_buffer_string(trimmed_messages)
|
||||
|
||||
try:
|
||||
response = await self.model.ainvoke(
|
||||
self.summary_prompt.format(messages=formatted_messages).rstrip(),
|
||||
config={"metadata": {"lc_source": "summarization"}},
|
||||
)
|
||||
return response.text.strip()
|
||||
except Exception as e:
|
||||
return f"Error generating summary: {e!s}"
|
||||
|
||||
def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
|
||||
"""Trim messages to fit within summary generation limits."""
|
||||
try:
|
||||
if self.trim_tokens_to_summarize is None:
|
||||
return messages
|
||||
return cast(
|
||||
"list[AnyMessage]",
|
||||
trim_messages(
|
||||
messages,
|
||||
max_tokens=self.trim_tokens_to_summarize,
|
||||
token_counter=self.token_counter,
|
||||
start_on="human",
|
||||
strategy="last",
|
||||
allow_partial=True,
|
||||
include_system=True,
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]
|
||||
327
venv/Lib/site-packages/langchain/agents/middleware/todo.py
Normal file
327
venv/Lib/site-packages/langchain/agents/middleware/todo.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""Planning and task management middleware for agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import Command
|
||||
from typing_extensions import NotRequired, TypedDict, override
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
OmitFromInput,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain.tools import InjectedToolCallId
|
||||
|
||||
|
||||
class Todo(TypedDict):
    """A single todo item with content and status."""

    # Text describing the work item.
    content: str
    """The content/description of the todo item."""

    # One of the three literal state names listed in the annotation.
    status: Literal["pending", "in_progress", "completed"]
    """The current status of the todo item."""
|
||||
|
||||
|
||||
class PlanningState(AgentState[ResponseT]):
    """Agent state extended with the todo list used by the todo middleware.

    Type Parameters:
        ResponseT: The type of the structured response. Defaults to `Any`.
    """

    # Optional key (`NotRequired`), additionally tagged with `OmitFromInput`.
    todos: Annotated[NotRequired[list[Todo]], OmitFromInput]
    """List of todo items for tracking task progress."""
|
||||
|
||||
|
||||
WRITE_TODOS_TOOL_DESCRIPTION = """Use this tool to create and manage a structured task list for your current work session. This helps you track progress, organize complex tasks, and demonstrate thoroughness to the user.
|
||||
|
||||
Only use this tool if you think it will be helpful in staying organized. If the user's request is trivial and takes less than 3 steps, it is better to NOT use this tool and just do the task directly.
|
||||
|
||||
## When to Use This Tool
|
||||
Use this tool in these scenarios:
|
||||
|
||||
1. Complex multi-step tasks - When a task requires 3 or more distinct steps or actions
|
||||
2. Non-trivial and complex tasks - Tasks that require careful planning or multiple operations
|
||||
3. User explicitly requests todo list - When the user directly asks you to use the todo list
|
||||
4. User provides multiple tasks - When users provide a list of things to be done (numbered or comma-separated)
|
||||
5. The plan may need future revisions or updates based on results from the first few steps
|
||||
|
||||
## How to Use This Tool
|
||||
1. When you start working on a task - Mark it as in_progress BEFORE beginning work.
|
||||
2. After completing a task - Mark it as completed and add any new follow-up tasks discovered during implementation.
|
||||
3. You can also update future tasks, such as deleting them if they are no longer necessary, or adding new tasks that are necessary. Don't change previously completed tasks.
|
||||
4. You can make several updates to the todo list at once. For example, when you complete a task, you can mark the next task you need to start as in_progress.
|
||||
|
||||
## When NOT to Use This Tool
|
||||
It is important to skip using this tool when:
|
||||
1. There is only a single, straightforward task
|
||||
2. The task is trivial and tracking it provides no benefit
|
||||
3. The task can be completed in less than 3 trivial steps
|
||||
4. The task is purely conversational or informational
|
||||
|
||||
## Task States and Management
|
||||
|
||||
1. **Task States**: Use these states to track progress:
|
||||
- pending: Task not yet started
|
||||
- in_progress: Currently working on (you can have multiple tasks in_progress at a time if they are not related to each other and can be run in parallel)
|
||||
- completed: Task finished successfully
|
||||
|
||||
2. **Task Management**:
|
||||
- Update task status in real-time as you work
|
||||
- Mark tasks complete IMMEDIATELY after finishing (don't batch completions)
|
||||
- Complete current tasks before starting new ones
|
||||
- Remove tasks that are no longer relevant from the list entirely
|
||||
- IMPORTANT: When you write this todo list, you should mark your first task (or tasks) as in_progress immediately!.
|
||||
- IMPORTANT: Unless all tasks are completed, you should always have at least one task in_progress to show the user that you are working on something.
|
||||
|
||||
3. **Task Completion Requirements**:
|
||||
- ONLY mark a task as completed when you have FULLY accomplished it
|
||||
- If you encounter errors, blockers, or cannot finish, keep the task as in_progress
|
||||
- When blocked, create a new task describing what needs to be resolved
|
||||
- Never mark a task as completed if:
|
||||
- There are unresolved issues or errors
|
||||
- Work is partial or incomplete
|
||||
- You encountered blockers that prevent completion
|
||||
- You couldn't find necessary resources or dependencies
|
||||
- Quality standards haven't been met
|
||||
|
||||
4. **Task Breakdown**:
|
||||
- Create specific, actionable items
|
||||
- Break complex tasks into smaller, manageable steps
|
||||
- Use clear, descriptive task names
|
||||
|
||||
Being proactive with task management demonstrates attentiveness and ensures you complete all requirements successfully
|
||||
Remember: If you only need to make a few tool calls to complete a task, and it is clear what you need to do, it is better to just do the task directly and NOT call this tool at all.""" # noqa: E501
|
||||
|
||||
WRITE_TODOS_SYSTEM_PROMPT = """## `write_todos`
|
||||
|
||||
You have access to the `write_todos` tool to help you manage and plan complex objectives.
|
||||
Use this tool for complex objectives to ensure that you are tracking each necessary step and giving the user visibility into your progress.
|
||||
This tool is very helpful for planning complex objectives, and for breaking down these larger complex objectives into smaller steps.
|
||||
|
||||
It is critical that you mark todos as completed as soon as you are done with a step. Do not batch up multiple steps before marking them as completed.
|
||||
For simple objectives that only require a few steps, it is better to just complete the objective directly and NOT use this tool.
|
||||
Writing todos takes time and tokens, use it when it is helpful for managing complex many-step problems! But not for simple few-step requests.
|
||||
|
||||
## Important To-Do List Usage Notes to Remember
|
||||
- The `write_todos` tool should never be called multiple times in parallel.
|
||||
- Don't be afraid to revise the To-Do list as you go. New information may reveal new tasks that need to be done, or old tasks that are irrelevant.""" # noqa: E501
|
||||
|
||||
|
||||
@tool(description=WRITE_TODOS_TOOL_DESCRIPTION)
def write_todos(
    todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
) -> Command[Any]:
    """Create and manage a structured task list for your current work session."""
    # Acknowledge the call with a ToolMessage tied to the originating tool call id.
    confirmation = ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)
    # The state update replaces `todos` wholesale and appends the confirmation.
    return Command(update={"todos": todos, "messages": [confirmation]})
|
||||
|
||||
|
||||
class TodoListMiddleware(AgentMiddleware[PlanningState[ResponseT], ContextT, ResponseT]):
    """Middleware that provides todo list management capabilities to agents.

    This middleware adds a `write_todos` tool that allows agents to create and manage
    structured task lists for complex multi-step operations. It's designed to help
    agents track progress, organize complex tasks, and provide users with visibility
    into task completion status.

    The middleware automatically injects system prompts that guide the agent on when
    and how to use the todo functionality effectively. It also enforces that the
    `write_todos` tool is called at most once per model turn, since the tool replaces
    the entire todo list and parallel calls would create ambiguity about precedence.

    Example:
        ```python
        from langchain_core.messages import HumanMessage

        from langchain.agents import create_agent
        from langchain.agents.middleware.todo import TodoListMiddleware

        agent = create_agent("openai:gpt-4o", middleware=[TodoListMiddleware()])

        # Agent now has access to write_todos tool and todo state tracking
        result = agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})

        print(result["todos"])  # Array of todo items with status tracking
        ```
    """

    state_schema = PlanningState  # type: ignore[assignment]

    def __init__(
        self,
        *,
        system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
        tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
    ) -> None:
        """Initialize the `TodoListMiddleware` with optional custom prompts.

        Args:
            system_prompt: Custom system prompt to guide the agent on using the todo
                tool.
            tool_description: Custom description for the `write_todos` tool.
        """
        super().__init__()
        self.system_prompt = system_prompt
        self.tool_description = tool_description

        # Dynamically create the write_todos tool with the custom description
        @tool(description=self.tool_description)
        def write_todos(
            todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
        ) -> Command[Any]:
            """Create and manage a structured task list for your current work session."""
            return Command(
                update={
                    "todos": todos,
                    "messages": [
                        ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)
                    ],
                }
            )

        self.tools = [write_todos]

    def _with_todo_prompt(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
        """Build the request variant whose system message includes the todo prompt.

        Shared by the sync and async model-call wrappers so both assemble the system
        message identically.

        Args:
            request: Model request whose system message should be extended.

        Returns:
            The result of `request.override(...)` with a new `SystemMessage`: the
            existing content blocks followed by the todo prompt, or the todo prompt
            alone when the request has no system message.
        """
        if request.system_message is not None:
            new_system_content = [
                *request.system_message.content_blocks,
                # Blank line separates the todo prompt from the existing content.
                {"type": "text", "text": f"\n\n{self.system_prompt}"},
            ]
        else:
            new_system_content = [{"type": "text", "text": self.system_prompt}]
        new_system_message = SystemMessage(
            content=cast("list[str | dict[str, str]]", new_system_content)
        )
        return request.override(system_message=new_system_message)

    def wrap_model_call(
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
    ) -> ModelResponse[ResponseT] | AIMessage:
        """Update the system message to include the todo system prompt.

        Args:
            request: Model request to execute (includes state and runtime).
            handler: Callback that executes the model request and returns
                `ModelResponse`.

        Returns:
            The model call result.
        """
        return handler(self._with_todo_prompt(request))

    async def awrap_model_call(
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
    ) -> ModelResponse[ResponseT] | AIMessage:
        """Update the system message to include the todo system prompt.

        Args:
            request: Model request to execute (includes state and runtime).
            handler: Async callback that executes the model request and returns
                `ModelResponse`.

        Returns:
            The model call result.
        """
        return await handler(self._with_todo_prompt(request))

    @override
    def after_model(
        self, state: PlanningState[ResponseT], runtime: Runtime[ContextT]
    ) -> dict[str, Any] | None:
        """Check for parallel write_todos tool calls and return errors if detected.

        The todo list is designed to be updated at most once per model turn. Since
        the `write_todos` tool replaces the entire todo list with each call, making
        multiple parallel calls would create ambiguity about which update should take
        precedence. This method prevents such conflicts by rejecting any response that
        contains multiple write_todos tool calls.

        Args:
            state: The current agent state containing messages.
            runtime: The LangGraph runtime instance.

        Returns:
            A dict containing error ToolMessages for each write_todos call if multiple
            parallel calls are detected, otherwise None to allow normal execution.
        """
        messages = state["messages"]
        if not messages:
            return None

        # Only the most recent AI message is inspected.
        last_ai_msg = next((msg for msg in reversed(messages) if isinstance(msg, AIMessage)), None)
        if not last_ai_msg or not last_ai_msg.tool_calls:
            return None

        # Count write_todos tool calls
        write_todos_calls = [tc for tc in last_ai_msg.tool_calls if tc["name"] == "write_todos"]

        if len(write_todos_calls) > 1:
            # Create error tool messages for all write_todos calls
            error_messages = [
                ToolMessage(
                    content=(
                        "Error: The `write_todos` tool should never be called multiple times "
                        "in parallel. Please call it only once per model invocation to update "
                        "the todo list."
                    ),
                    tool_call_id=tc["id"],
                    status="error",
                )
                for tc in write_todos_calls
            ]

            # Keep the tool calls in the AI message but return error messages
            # This follows the same pattern as HumanInTheLoopMiddleware
            return {"messages": error_messages}

        return None

    @override
    async def aafter_model(
        self, state: PlanningState[ResponseT], runtime: Runtime[ContextT]
    ) -> dict[str, Any] | None:
        """Check for parallel write_todos tool calls and return errors if detected.

        Async version of `after_model`; it delegates to the synchronous
        implementation, so the same rule applies: a response containing more than one
        `write_todos` tool call is rejected with error ToolMessages.

        Args:
            state: The current agent state containing messages.
            runtime: The LangGraph runtime instance.

        Returns:
            A dict containing error ToolMessages for each write_todos call if multiple
            parallel calls are detected, otherwise None to allow normal execution.
        """
        return self.after_model(state, runtime)
|
||||
@@ -0,0 +1,488 @@
|
||||
"""Tool call limit middleware for agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
||||
|
||||
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
|
||||
from langgraph.channels.untracked_value import UntrackedValue
|
||||
from langgraph.typing import ContextT
|
||||
from typing_extensions import NotRequired, override
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
PrivateStateAttr,
|
||||
ResponseT,
|
||||
hook_config,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
ExitBehavior = Literal["continue", "error", "end"]
|
||||
"""How to handle execution when tool call limits are exceeded.
|
||||
|
||||
- `'continue'`: Block exceeded tools with error messages, let other tools continue
|
||||
(default)
|
||||
- `'error'`: Raise a `ToolCallLimitExceededError` exception
|
||||
- `'end'`: Stop execution immediately, injecting a `ToolMessage` and an `AIMessage` for
|
||||
the single tool call that exceeded the limit. Raises `NotImplementedError` if there
|
||||
are other pending tool calls (due to parallel tool calling).
|
||||
"""
|
||||
|
||||
|
||||
class ToolCallLimitState(AgentState[ResponseT]):
    """State schema for `ToolCallLimitMiddleware`.

    Adds two tool call counters on top of `AgentState`. Each counter is a dictionary
    keyed by tool name with the number of executions as value, which lets several
    middleware instances track different tools independently. The special key
    `'__all__'` holds the count across all tools.

    Type Parameters:
        ResponseT: The type of the structured response. Defaults to `Any`.
    """

    # Counter for the thread; tagged as a private state attribute.
    thread_tool_call_count: NotRequired[Annotated[dict[str, int], PrivateStateAttr]]
    # Counter for the run; private as well and stored in an `UntrackedValue` channel.
    run_tool_call_count: NotRequired[Annotated[dict[str, int], UntrackedValue, PrivateStateAttr]]
|
||||
|
||||
|
||||
def _build_tool_message_content(tool_name: str | None) -> str:
|
||||
"""Build the error message content for `ToolMessage` when limit is exceeded.
|
||||
|
||||
This message is sent to the model, so it should not reference thread/run concepts
|
||||
that the model has no notion of.
|
||||
|
||||
Args:
|
||||
tool_name: Tool name being limited (if specific tool), or `None` for all tools.
|
||||
|
||||
Returns:
|
||||
A concise message instructing the model not to call the tool again.
|
||||
"""
|
||||
# Always instruct the model not to call again, regardless of which limit was hit
|
||||
if tool_name:
|
||||
return f"Tool call limit exceeded. Do not call '{tool_name}' again."
|
||||
return "Tool call limit exceeded. Do not make additional tool calls."
|
||||
|
||||
|
||||
def _build_final_ai_message_content(
|
||||
thread_count: int,
|
||||
run_count: int,
|
||||
thread_limit: int | None,
|
||||
run_limit: int | None,
|
||||
tool_name: str | None,
|
||||
) -> str:
|
||||
"""Build the final AI message content for `'end'` behavior.
|
||||
|
||||
This message is displayed to the user, so it should include detailed information
|
||||
about which limits were exceeded.
|
||||
|
||||
Args:
|
||||
thread_count: Current thread tool call count.
|
||||
run_count: Current run tool call count.
|
||||
thread_limit: Thread tool call limit (if set).
|
||||
run_limit: Run tool call limit (if set).
|
||||
tool_name: Tool name being limited (if specific tool), or `None` for all tools.
|
||||
|
||||
Returns:
|
||||
A formatted message describing which limits were exceeded.
|
||||
"""
|
||||
tool_desc = f"'{tool_name}' tool" if tool_name else "Tool"
|
||||
exceeded_limits = []
|
||||
|
||||
if thread_limit is not None and thread_count > thread_limit:
|
||||
exceeded_limits.append(f"thread limit exceeded ({thread_count}/{thread_limit} calls)")
|
||||
if run_limit is not None and run_count > run_limit:
|
||||
exceeded_limits.append(f"run limit exceeded ({run_count}/{run_limit} calls)")
|
||||
|
||||
limits_text = " and ".join(exceeded_limits)
|
||||
return f"{tool_desc} call limit reached: {limits_text}."
|
||||
|
||||
|
||||
class ToolCallLimitExceededError(Exception):
    """Error signalling that a configured tool call limit has been exceeded.

    Raised when the configured exit behavior is `'error'` and either the thread or
    the run tool call limit has been exceeded. The counts and limits are kept as
    attributes on the instance.
    """

    def __init__(
        self,
        thread_count: int,
        run_count: int,
        thread_limit: int | None,
        run_limit: int | None,
        tool_name: str | None = None,
    ) -> None:
        """Store the call count details and build the exception message.

        Args:
            thread_count: Current thread tool call count.
            run_count: Current run tool call count.
            thread_limit: Thread tool call limit (if set).
            run_limit: Run tool call limit (if set).
            tool_name: Tool name being limited (if specific tool), or None for all tools.
        """
        self.thread_count, self.run_count = thread_count, run_count
        self.thread_limit, self.run_limit = thread_limit, run_limit
        self.tool_name = tool_name

        # The message text is the same one produced for the final AI message.
        description = _build_final_ai_message_content(
            thread_count, run_count, thread_limit, run_limit, tool_name
        )
        super().__init__(description)
|
||||
|
||||
|
||||
class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState[ResponseT], ContextT, ResponseT]):
|
||||
"""Track tool call counts and enforces limits during agent execution.
|
||||
|
||||
This middleware monitors the number of tool calls made and can terminate or
|
||||
restrict execution when limits are exceeded. It supports both thread-level
|
||||
(persistent across runs) and run-level (per invocation) call counting.
|
||||
|
||||
Configuration:
|
||||
- `exit_behavior`: How to handle when limits are exceeded
|
||||
- `'continue'`: Block exceeded tools, let execution continue (default)
|
||||
- `'error'`: Raise an exception
|
||||
- `'end'`: Stop immediately with a `ToolMessage` + AI message for the single
|
||||
tool call that exceeded the limit (raises `NotImplementedError` if there
|
||||
are other pending tool calls (due to parallel tool calling).
|
||||
|
||||
Examples:
|
||||
!!! example "Continue execution with blocked tools (default)"
|
||||
|
||||
```python
|
||||
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
|
||||
from langchain.agents import create_agent
|
||||
|
||||
# Block exceeded tools but let other tools and model continue
|
||||
limiter = ToolCallLimitMiddleware(
|
||||
thread_limit=20,
|
||||
run_limit=10,
|
||||
exit_behavior="continue", # default
|
||||
)
|
||||
|
||||
agent = create_agent("openai:gpt-4o", middleware=[limiter])
|
||||
```
|
||||
|
||||
!!! example "Stop immediately when limit exceeded"
|
||||
|
||||
```python
|
||||
# End execution immediately with an AI message
|
||||
limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")
|
||||
|
||||
agent = create_agent("openai:gpt-4o", middleware=[limiter])
|
||||
```
|
||||
|
||||
!!! example "Raise exception on limit"
|
||||
|
||||
```python
|
||||
# Strict limit with exception handling
|
||||
limiter = ToolCallLimitMiddleware(
|
||||
tool_name="search", thread_limit=5, exit_behavior="error"
|
||||
)
|
||||
|
||||
agent = create_agent("openai:gpt-4o", middleware=[limiter])
|
||||
|
||||
try:
|
||||
result = await agent.invoke({"messages": [HumanMessage("Task")]})
|
||||
except ToolCallLimitExceededError as e:
|
||||
print(f"Search limit exceeded: {e}")
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
state_schema = ToolCallLimitState # type: ignore[assignment]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tool_name: str | None = None,
|
||||
thread_limit: int | None = None,
|
||||
run_limit: int | None = None,
|
||||
exit_behavior: ExitBehavior = "continue",
|
||||
) -> None:
|
||||
"""Initialize the tool call limit middleware.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the specific tool to limit. If `None`, limits apply
|
||||
to all tools.
|
||||
thread_limit: Maximum number of tool calls allowed per thread.
|
||||
`None` means no limit.
|
||||
run_limit: Maximum number of tool calls allowed per run.
|
||||
`None` means no limit.
|
||||
exit_behavior: How to handle when limits are exceeded.
|
||||
|
||||
- `'continue'`: Block exceeded tools with error messages, let other
|
||||
tools continue. Model decides when to end.
|
||||
- `'error'`: Raise a `ToolCallLimitExceededError` exception
|
||||
- `'end'`: Stop execution immediately with a `ToolMessage` + AI message
|
||||
for the single tool call that exceeded the limit. Raises
|
||||
`NotImplementedError` if there are multiple parallel tool
|
||||
calls to other tools or multiple pending tool calls.
|
||||
|
||||
Raises:
|
||||
ValueError: If both limits are `None`, if `exit_behavior` is invalid,
|
||||
or if `run_limit` exceeds `thread_limit`.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if thread_limit is None and run_limit is None:
|
||||
msg = "At least one limit must be specified (thread_limit or run_limit)"
|
||||
raise ValueError(msg)
|
||||
|
||||
valid_behaviors = ("continue", "error", "end")
|
||||
if exit_behavior not in valid_behaviors:
|
||||
msg = f"Invalid exit_behavior: {exit_behavior!r}. Must be one of {valid_behaviors}"
|
||||
raise ValueError(msg)
|
||||
|
||||
if thread_limit is not None and run_limit is not None and run_limit > thread_limit:
|
||||
msg = (
|
||||
f"run_limit ({run_limit}) cannot exceed thread_limit ({thread_limit}). "
|
||||
"The run limit should be less than or equal to the thread limit."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
self.tool_name = tool_name
|
||||
self.thread_limit = thread_limit
|
||||
self.run_limit = run_limit
|
||||
self.exit_behavior = exit_behavior
|
||||
|
||||
@property
def name(self) -> str:
    """Return the identifier of this middleware instance.

    When a `tool_name` is configured it is appended in square brackets to the
    class name, which lets several instances that target different tools
    carry distinct names. Otherwise the bare class name is returned.
    """
    class_name = self.__class__.__name__
    # A falsy tool name (None or empty string) means "no suffix".
    return f"{class_name}[{self.tool_name}]" if self.tool_name else class_name
|
||||
|
||||
def _would_exceed_limit(self, thread_count: int, run_count: int) -> bool:
|
||||
"""Check if incrementing the counts would exceed any configured limit.
|
||||
|
||||
Args:
|
||||
thread_count: Current thread call count.
|
||||
run_count: Current run call count.
|
||||
|
||||
Returns:
|
||||
True if either limit would be exceeded by one more call.
|
||||
"""
|
||||
return (self.thread_limit is not None and thread_count + 1 > self.thread_limit) or (
|
||||
self.run_limit is not None and run_count + 1 > self.run_limit
|
||||
)
|
||||
|
||||
def _matches_tool_filter(self, tool_call: ToolCall) -> bool:
|
||||
"""Check if a tool call matches this middleware's tool filter.
|
||||
|
||||
Args:
|
||||
tool_call: The tool call to check.
|
||||
|
||||
Returns:
|
||||
True if this middleware should track this tool call.
|
||||
"""
|
||||
return self.tool_name is None or tool_call["name"] == self.tool_name
|
||||
|
||||
def _separate_tool_calls(
|
||||
self, tool_calls: list[ToolCall], thread_count: int, run_count: int
|
||||
) -> tuple[list[ToolCall], list[ToolCall], int, int]:
|
||||
"""Separate tool calls into allowed and blocked based on limits.
|
||||
|
||||
Args:
|
||||
tool_calls: List of tool calls to evaluate.
|
||||
thread_count: Current thread call count.
|
||||
run_count: Current run call count.
|
||||
|
||||
Returns:
|
||||
Tuple of `(allowed_calls, blocked_calls, final_thread_count,
|
||||
final_run_count)`.
|
||||
"""
|
||||
allowed_calls: list[ToolCall] = []
|
||||
blocked_calls: list[ToolCall] = []
|
||||
temp_thread_count = thread_count
|
||||
temp_run_count = run_count
|
||||
|
||||
for tool_call in tool_calls:
|
||||
if not self._matches_tool_filter(tool_call):
|
||||
continue
|
||||
|
||||
if self._would_exceed_limit(temp_thread_count, temp_run_count):
|
||||
blocked_calls.append(tool_call)
|
||||
else:
|
||||
allowed_calls.append(tool_call)
|
||||
temp_thread_count += 1
|
||||
temp_run_count += 1
|
||||
|
||||
return allowed_calls, blocked_calls, temp_thread_count, temp_run_count
|
||||
|
||||
@hook_config(can_jump_to=["end"])
@override
def after_model(
    self,
    state: ToolCallLimitState[ResponseT],
    runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
    """Increment tool call counts after a model call and check limits.

    Args:
        state: The current agent state.
        runtime: The langgraph runtime. Not read by this implementation.

    Returns:
        State updates with incremented tool call counts, or `None` when there
        is no AI message with tool calls or no tracked tool call. When calls
        are blocked, the update also carries an error `ToolMessage` per
        blocked call. If `exit_behavior` is `'end'`, it additionally includes
        a jump to end and a final AI message describing the exceeded limit.

    Raises:
        ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
            is `'error'`.
        NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
            and the AI message also contains calls to other tools.
    """
    messages = state.get("messages", [])
    if not messages:
        return None

    # Most recent AIMessage, if any; that is where the pending tool calls live.
    last_ai_message = next(
        (message for message in reversed(messages) if isinstance(message, AIMessage)),
        None,
    )
    if not last_ai_message or not last_ai_message.tool_calls:
        return None

    # Counts are stored per tool name, or under "__all__" for an unfiltered instance.
    count_key = self.tool_name or "__all__"

    # Copy so the dictionaries held in `state` are not mutated in place.
    thread_counts = state.get("thread_tool_call_count", {}).copy()
    run_counts = state.get("run_tool_call_count", {}).copy()
    current_thread_count = thread_counts.get(count_key, 0)
    current_run_count = run_counts.get(count_key, 0)

    # Separate tool calls into allowed and blocked
    allowed_calls, blocked_calls, new_thread_count, new_run_count = self._separate_tool_calls(
        last_ai_message.tool_calls, current_thread_count, current_run_count
    )

    # Update counts to include only allowed calls for thread count
    # (blocked calls don't count towards thread-level tracking)
    # But run count includes blocked calls since they were attempted in this run
    thread_counts[count_key] = new_thread_count
    run_counts[count_key] = new_run_count + len(blocked_calls)

    # If no tool calls are blocked, just update counts
    if not blocked_calls:
        if allowed_calls:
            return {
                "thread_tool_call_count": thread_counts,
                "run_tool_call_count": run_counts,
            }
        return None

    # Get final counts for building messages
    final_thread_count = thread_counts[count_key]
    final_run_count = run_counts[count_key]

    # Handle different exit behaviors
    if self.exit_behavior == "error":
        # Use hypothetical thread count to show which limit was exceeded
        hypothetical_thread_count = final_thread_count + len(blocked_calls)
        raise ToolCallLimitExceededError(
            thread_count=hypothetical_thread_count,
            run_count=final_run_count,
            thread_limit=self.thread_limit,
            run_limit=self.run_limit,
            tool_name=self.tool_name,
        )

    # Build tool message content (sent to model - no thread/run details)
    tool_msg_content = _build_tool_message_content(self.tool_name)

    # Inject artificial error ToolMessages for blocked tool calls
    artificial_messages: list[ToolMessage | AIMessage] = [
        ToolMessage(
            content=tool_msg_content,
            tool_call_id=tool_call["id"],
            name=tool_call.get("name"),
            status="error",
        )
        for tool_call in blocked_calls
    ]

    if self.exit_behavior == "end":
        # Check if there are tool calls to other tools that would continue executing.
        # An unfiltered instance (tool_name is None) tracks every tool, so there
        # are no "other" tools in that case.
        other_tools = (
            [tc for tc in last_ai_message.tool_calls if tc["name"] != self.tool_name]
            if self.tool_name is not None
            else []
        )

        if other_tools:
            # Sort the de-duplicated names so the error text is deterministic.
            tool_names = ", ".join(sorted({tc["name"] for tc in other_tools}))
            msg = (
                f"Cannot end execution with other tool calls pending. "
                f"Found calls to: {tool_names}. Use 'continue' or 'error' behavior instead."
            )
            raise NotImplementedError(msg)

        # Build final AI message content (displayed to user - includes thread/run details)
        # Use hypothetical thread count (what it would have been if call wasn't blocked)
        # to show which limit was actually exceeded
        hypothetical_thread_count = final_thread_count + len(blocked_calls)
        final_msg_content = _build_final_ai_message_content(
            hypothetical_thread_count,
            final_run_count,
            self.thread_limit,
            self.run_limit,
            self.tool_name,
        )
        artificial_messages.append(AIMessage(content=final_msg_content))

        return {
            "thread_tool_call_count": thread_counts,
            "run_tool_call_count": run_counts,
            "jump_to": "end",
            "messages": artificial_messages,
        }

    # For exit_behavior="continue", return error messages to block exceeded tools
    return {
        "thread_tool_call_count": thread_counts,
        "run_tool_call_count": run_counts,
        "messages": artificial_messages,
    }
|
||||
|
||||
@hook_config(can_jump_to=["end"])
async def aafter_model(
    self,
    state: ToolCallLimitState[ResponseT],
    runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
    """Async variant of `after_model`.

    Delegates directly to the synchronous `after_model`; no awaiting takes
    place inside this coroutine.

    Args:
        state: The current agent state.
        runtime: The langgraph runtime.

    Returns:
        Whatever `after_model` returns for the same arguments: state updates
        with incremented tool call counts, optionally with injected messages
        and a jump to end, or `None`.

    Raises:
        ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
            is `'error'`.
        NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
            and there are multiple tool calls.
    """
    update = self.after_model(state, runtime)
    return update
|
||||
@@ -0,0 +1,209 @@
|
||||
"""Tool emulator middleware for testing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Generic
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import HumanMessage, ToolMessage
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT
|
||||
from langchain.chat_models.base import init_chat_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class LLMToolEmulator(AgentMiddleware[AgentState[Any], ContextT], Generic[ContextT]):
    """Emulates specified tools using an LLM instead of executing them.

    This middleware allows selective emulation of tools for testing purposes.

    By default (when `tools=None`), all tools are emulated. You can specify which
    tools to emulate by passing a list of tool names or `BaseTool` instances.

    Examples:
        !!! example "Emulate all tools (default behavior)"

            ```python
            from langchain.agents.middleware import LLMToolEmulator

            middleware = LLMToolEmulator()

            agent = create_agent(
                model="openai:gpt-4o",
                tools=[get_weather, get_user_location, calculator],
                middleware=[middleware],
            )
            ```

        !!! example "Emulate specific tools by name"

            ```python
            middleware = LLMToolEmulator(tools=["get_weather", "get_user_location"])
            ```

        !!! example "Use a custom model for emulation"

            ```python
            middleware = LLMToolEmulator(
                tools=["get_weather"], model="anthropic:claude-sonnet-4-5-20250929"
            )
            ```

        !!! example "Emulate specific tools by passing tool instances"

            ```python
            middleware = LLMToolEmulator(tools=[get_weather, get_user_location])
            ```
    """

    def __init__(
        self,
        *,
        tools: list[str | BaseTool] | None = None,
        model: str | BaseChatModel | None = None,
    ) -> None:
        """Initialize the tool emulator.

        Args:
            tools: List of tool names (`str`) or `BaseTool` instances to emulate.

                If `None`, ALL tools will be emulated.

                If empty list, no tools will be emulated.
            model: Model to use for emulation.

                Defaults to `'anthropic:claude-sonnet-4-5-20250929'`.

                Can be a model identifier string or `BaseChatModel` instance.
        """
        super().__init__()

        # None means emulate all tools; an explicit list (even empty) restricts emulation.
        self.emulate_all = tools is None
        # Non-string entries are assumed to be BaseTool-like objects exposing `.name`.
        self.tools_to_emulate: set[str] = (
            set()
            if tools is None
            else {tool if isinstance(tool, str) else tool.name for tool in tools}
        )

        # Initialize emulator model
        if model is None:
            self.model = init_chat_model("anthropic:claude-sonnet-4-5-20250929", temperature=1)
        elif isinstance(model, BaseChatModel):
            self.model = model
        else:
            self.model = init_chat_model(model, temperature=1)

    def _should_emulate(self, tool_name: str) -> bool:
        """Return whether calls to `tool_name` are emulated rather than executed."""
        return self.emulate_all or tool_name in self.tools_to_emulate

    @staticmethod
    def _build_prompt(request: ToolCallRequest, tool_name: str) -> str:
        """Build the emulator prompt from the tool's name, description and call arguments."""
        tool_args = request.tool_call["args"]
        tool_description = request.tool.description if request.tool else "No description available"
        return (
            f"You are emulating a tool call for testing purposes.\n\n"
            f"Tool: {tool_name}\n"
            f"Description: {tool_description}\n"
            f"Arguments: {tool_args}\n\n"
            f"Generate a realistic response that this tool would return "
            f"given these arguments.\n"
            f"Return ONLY the tool's output, no explanation or preamble. "
            f"Introduce variation into your responses."
        )

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
    ) -> ToolMessage | Command[Any]:
        """Emulate tool execution using LLM if tool should be emulated.

        Args:
            request: Tool call request to potentially emulate.
            handler: Callback to execute the tool (can be called multiple times).

        Returns:
            ToolMessage with emulated response if tool should be emulated,
            otherwise calls handler for normal execution.
        """
        tool_name = request.tool_call["name"]

        if not self._should_emulate(tool_name):
            # Let it execute normally by calling the handler
            return handler(request)

        # Get emulated response from LLM
        response = self.model.invoke([HumanMessage(self._build_prompt(request, tool_name))])

        # Short-circuit: return emulated result without executing real tool
        return ToolMessage(
            content=response.content,
            tool_call_id=request.tool_call["id"],
            name=tool_name,
        )

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
    ) -> ToolMessage | Command[Any]:
        """Async version of `wrap_tool_call`.

        Emulate tool execution using LLM if tool should be emulated.

        Args:
            request: Tool call request to potentially emulate.
            handler: Async callback to execute the tool (can be called multiple times).

        Returns:
            ToolMessage with emulated response if tool should be emulated,
            otherwise calls handler for normal execution.
        """
        tool_name = request.tool_call["name"]

        if not self._should_emulate(tool_name):
            # Let it execute normally by calling the handler
            return await handler(request)

        # Get emulated response from LLM (using async invoke)
        response = await self.model.ainvoke(
            [HumanMessage(self._build_prompt(request, tool_name))]
        )

        # Short-circuit: return emulated result without executing real tool
        return ToolMessage(
            content=response.content,
            tool_call_id=request.tool_call["id"],
            name=tool_name,
        )
|
||||
403
venv/Lib/site-packages/langchain/agents/middleware/tool_retry.py
Normal file
403
venv/Lib/site-packages/langchain/agents/middleware/tool_retry.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""Tool retry middleware for agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from langchain.agents.middleware._retry import (
|
||||
OnFailure,
|
||||
RetryOn,
|
||||
calculate_delay,
|
||||
should_retry_exception,
|
||||
validate_retry_params,
|
||||
)
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, ResponseT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class ToolRetryMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
    """Middleware that automatically retries failed tool calls with configurable backoff.

    Supports retrying on specific exceptions and exponential backoff.

    Examples:
        !!! example "Basic usage with default settings (2 retries, exponential backoff)"

            ```python
            from langchain.agents import create_agent
            from langchain.agents.middleware import ToolRetryMiddleware

            agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
            ```

        !!! example "Retry specific exceptions only"

            ```python
            from requests.exceptions import RequestException, Timeout

            retry = ToolRetryMiddleware(
                max_retries=4,
                retry_on=(RequestException, Timeout),
                backoff_factor=1.5,
            )
            ```

        !!! example "Custom exception filtering"

            ```python
            from requests.exceptions import HTTPError


            def should_retry(exc: Exception) -> bool:
                # Only retry on 5xx errors
                if isinstance(exc, HTTPError) and exc.response is not None:
                    return 500 <= exc.response.status_code < 600
                return False


            retry = ToolRetryMiddleware(
                max_retries=3,
                retry_on=should_retry,
            )
            ```

        !!! example "Apply to specific tools with custom error handling"

            ```python
            def format_error(exc: Exception) -> str:
                return "Database temporarily unavailable. Please try again later."


            retry = ToolRetryMiddleware(
                max_retries=4,
                tools=["search_database"],
                on_failure=format_error,
            )
            ```

        !!! example "Apply to specific tools using `BaseTool` instances"

            ```python
            from langchain_core.tools import tool


            @tool
            def search_database(query: str) -> str:
                '''Search the database.'''
                return results


            retry = ToolRetryMiddleware(
                max_retries=4,
                tools=[search_database],  # Pass BaseTool instance
            )
            ```

        !!! example "Constant backoff (no exponential growth)"

            ```python
            retry = ToolRetryMiddleware(
                max_retries=5,
                backoff_factor=0.0,  # No exponential growth
                initial_delay=2.0,  # Always wait 2 seconds
            )
            ```

        !!! example "Raise exception on failure"

            ```python
            retry = ToolRetryMiddleware(
                max_retries=2,
                on_failure="error",  # Re-raise exception instead of returning message
            )
            ```
    """

    def __init__(
        self,
        *,
        max_retries: int = 2,
        tools: list[BaseTool | str] | None = None,
        retry_on: RetryOn = (Exception,),
        on_failure: OnFailure = "continue",
        backoff_factor: float = 2.0,
        initial_delay: float = 1.0,
        max_delay: float = 60.0,
        jitter: bool = True,
    ) -> None:
        """Initialize `ToolRetryMiddleware`.

        Args:
            max_retries: Maximum number of retry attempts after the initial call.

                Must be `>= 0`.
            tools: Optional list of tools or tool names to apply retry logic to.

                Can be a list of `BaseTool` instances or tool name strings.

                If `None`, applies to all tools.
            retry_on: Either a tuple of exception types to retry on, or a callable
                that takes an exception and returns `True` if it should be retried.

                Default is to retry on all exceptions.
            on_failure: Behavior when all retries are exhausted.

                Options:

                - `'continue'`: Return a `ToolMessage` with error details,
                    allowing the LLM to handle the failure and potentially recover.
                - `'error'`: Re-raise the exception, stopping agent execution.
                - **Custom callable:** Function that takes the exception and returns a
                    string for the `ToolMessage` content, allowing custom error
                    formatting.

                **Deprecated values** (for backwards compatibility):

                - `'return_message'`: Use `'continue'` instead.
                - `'raise'`: Use `'error'` instead.
            backoff_factor: Multiplier for exponential backoff.

                Each retry waits `initial_delay * (backoff_factor ** retry_number)`
                seconds.

                Set to `0.0` for constant delay.
            initial_delay: Initial delay in seconds before first retry.
            max_delay: Maximum delay in seconds between retries.

                Caps exponential backoff growth.
            jitter: Whether to add random jitter (`±25%`) to delay to avoid thundering herd.

        Raises:
            ValueError: If `max_retries < 0` or delays are negative.
        """
        super().__init__()

        # Validate parameters
        validate_retry_params(max_retries, initial_delay, max_delay, backoff_factor)

        # Handle backwards compatibility for deprecated on_failure values
        if on_failure == "raise":  # type: ignore[comparison-overlap]
            msg = (  # type: ignore[unreachable]
                "on_failure='raise' is deprecated and will be removed in a future version. "
                "Use on_failure='error' instead."
            )
            warnings.warn(msg, DeprecationWarning, stacklevel=2)
            on_failure = "error"
        elif on_failure == "return_message":  # type: ignore[comparison-overlap]
            msg = (  # type: ignore[unreachable]
                "on_failure='return_message' is deprecated and will be removed "
                "in a future version. Use on_failure='continue' instead."
            )
            warnings.warn(msg, DeprecationWarning, stacklevel=2)
            on_failure = "continue"

        self.max_retries = max_retries

        # Extract tool names from BaseTool instances or strings
        self._tool_filter: list[str] | None
        if tools is not None:
            self._tool_filter = [tool.name if not isinstance(tool, str) else tool for tool in tools]
        else:
            self._tool_filter = None

        self.tools = []  # No additional tools registered by this middleware
        self.retry_on = retry_on
        self.on_failure = on_failure
        self.backoff_factor = backoff_factor
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.jitter = jitter

    def _should_retry_tool(self, tool_name: str) -> bool:
        """Check if retry logic should apply to this tool.

        Args:
            tool_name: Name of the tool being called.

        Returns:
            `True` if retry logic should apply, `False` otherwise.
        """
        if self._tool_filter is None:
            return True
        return tool_name in self._tool_filter

    @staticmethod
    def _format_failure_message(tool_name: str, exc: Exception, attempts_made: int) -> str:
        """Format the failure message when retries are exhausted.

        Args:
            tool_name: Name of the tool that failed.
            exc: The exception that caused the failure.
            attempts_made: Number of attempts actually made.

        Returns:
            Formatted error message string.
        """
        exc_type = type(exc).__name__
        exc_msg = str(exc)
        attempt_word = "attempt" if attempts_made == 1 else "attempts"
        return (
            f"Tool '{tool_name}' failed after {attempts_made} {attempt_word} "
            f"with {exc_type}: {exc_msg}. Please try again."
        )

    def _handle_failure(
        self, tool_name: str, tool_call_id: str | None, exc: Exception, attempts_made: int
    ) -> ToolMessage:
        """Handle failure when all retries are exhausted.

        Args:
            tool_name: Name of the tool that failed.
            tool_call_id: ID of the tool call (may be `None`).
            exc: The exception that caused the failure.
            attempts_made: Number of attempts actually made.

        Returns:
            `ToolMessage` with error details.

        Raises:
            Exception: If `on_failure` is `'error'`, re-raises the exception.
        """
        if self.on_failure == "error":
            raise exc

        if callable(self.on_failure):
            content = self.on_failure(exc)
        else:
            content = self._format_failure_message(tool_name, exc, attempts_made)

        return ToolMessage(
            content=content,
            tool_call_id=tool_call_id,
            name=tool_name,
            status="error",
        )

    def _backoff_delay(self, attempt: int) -> float:
        """Return the delay to wait after the failed 0-indexed `attempt`, per this instance's settings."""
        return calculate_delay(
            attempt,
            backoff_factor=self.backoff_factor,
            initial_delay=self.initial_delay,
            max_delay=self.max_delay,
            jitter=self.jitter,
        )

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
    ) -> ToolMessage | Command[Any]:
        """Intercept tool execution and retry on failure.

        Args:
            request: Tool call request with call dict, `BaseTool`, state, and runtime.
            handler: Callable to execute the tool (can be called multiple times).

        Returns:
            `ToolMessage` or `Command` (the final result).

        Raises:
            RuntimeError: If the retry loop completes without returning. This should not happen.
        """
        tool_name = request.tool.name if request.tool else request.tool_call["name"]

        # Check if retry should apply to this tool
        if not self._should_retry_tool(tool_name):
            return handler(request)

        tool_call_id = request.tool_call["id"]

        # Initial attempt + retries
        for attempt in range(self.max_retries + 1):
            try:
                return handler(request)
            except Exception as exc:
                attempts_made = attempt + 1  # attempt is 0-indexed

                # Give up when the exception is not retryable or no retries remain.
                # The retryability check runs first, as before.
                if not should_retry_exception(exc, self.retry_on) or attempt >= self.max_retries:
                    return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)

                # Apply backoff delay, then continue to the next retry
                delay = self._backoff_delay(attempt)
                if delay > 0:
                    time.sleep(delay)

        # Unreachable: loop always returns via handler success or _handle_failure
        msg = "Unexpected: retry loop completed without returning"
        raise RuntimeError(msg)

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
    ) -> ToolMessage | Command[Any]:
        """Intercept and control async tool execution with retry logic.

        Args:
            request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
            handler: Async callable to execute the tool and returns `ToolMessage` or
                `Command`.

        Returns:
            `ToolMessage` or `Command` (the final result).

        Raises:
            RuntimeError: If the retry loop completes without returning. This should not happen.
        """
        tool_name = request.tool.name if request.tool else request.tool_call["name"]

        # Check if retry should apply to this tool
        if not self._should_retry_tool(tool_name):
            return await handler(request)

        tool_call_id = request.tool_call["id"]

        # Initial attempt + retries
        for attempt in range(self.max_retries + 1):
            try:
                return await handler(request)
            except Exception as exc:
                attempts_made = attempt + 1  # attempt is 0-indexed

                # Give up when the exception is not retryable or no retries remain.
                # The retryability check runs first, as before.
                if not should_retry_exception(exc, self.retry_on) or attempt >= self.max_retries:
                    return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)

                # Apply backoff delay without blocking the event loop, then retry
                delay = self._backoff_delay(attempt)
                if delay > 0:
                    await asyncio.sleep(delay)

        # Unreachable: loop always returns via handler success or _handle_failure
        msg = "Unexpected: retry loop completed without returning"
        raise RuntimeError(msg)
|
||||
@@ -0,0 +1,358 @@
|
||||
"""LLM-based tool selector middleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, Union
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from pydantic import Field, TypeAdapter
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain.chat_models.base import init_chat_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"Your goal is to select the most relevant tools for answering the user's query."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class _SelectionRequest:
    """Prepared inputs for tool selection.

    Plain dataclass record; it defines no behavior of its own beyond the
    methods `@dataclass` generates.
    """

    # Tool objects that are candidates for the selection step.
    available_tools: list[BaseTool]
    # System prompt text. NOTE(review): presumably sent to the selection model;
    # confirm where this record is consumed.
    system_message: str
    # Human message the selection is based on. NOTE(review): the name suggests
    # the most recent user message; how it is located is not shown here.
    last_user_message: HumanMessage
    # Chat model instance. NOTE(review): presumably the model that performs
    # the selection; confirm against the consumer.
    model: BaseChatModel
    # Tool names. NOTE(review): presumably the names accepted as valid picks;
    # confirm against the consumer.
    valid_tool_names: list[str]
|
||||
|
||||
|
||||
def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
|
||||
"""Create a structured output schema for tool selection.
|
||||
|
||||
Args:
|
||||
tools: Available tools to include in the schema.
|
||||
|
||||
Returns:
|
||||
`TypeAdapter` for a schema where each tool name is a `Literal` with its
|
||||
description.
|
||||
|
||||
Raises:
|
||||
AssertionError: If `tools` is empty.
|
||||
"""
|
||||
if not tools:
|
||||
msg = "Invalid usage: tools must be non-empty"
|
||||
raise AssertionError(msg)
|
||||
|
||||
# Create a Union of Annotated Literal types for each tool name with description
|
||||
# For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...]
|
||||
literals = [
|
||||
Annotated[Literal[tool.name], Field(description=tool.description)] for tool in tools
|
||||
]
|
||||
selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007
|
||||
|
||||
description = "Tools to use. Place the most relevant tools first."
|
||||
|
||||
class ToolSelectionResponse(TypedDict):
|
||||
"""Use to select relevant tools."""
|
||||
|
||||
tools: Annotated[list[selected_tool_type], Field(description=description)] # type: ignore[valid-type]
|
||||
|
||||
return TypeAdapter(ToolSelectionResponse)
|
||||
|
||||
|
||||
def _render_tool_list(tools: list[BaseTool]) -> str:
|
||||
"""Format tools as markdown list.
|
||||
|
||||
Args:
|
||||
tools: Tools to format.
|
||||
|
||||
Returns:
|
||||
Markdown string with each tool on a new line.
|
||||
"""
|
||||
return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)
|
||||
|
||||
|
||||
class LLMToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
    """Uses an LLM to select relevant tools before calling the main model.

    When an agent has many tools available, this middleware filters them down
    to only the most relevant ones for the user's query. This reduces token usage
    and helps the main model focus on the right tools.

    Examples:
        !!! example "Limit to 3 tools"

            ```python
            from langchain.agents.middleware import LLMToolSelectorMiddleware

            middleware = LLMToolSelectorMiddleware(max_tools=3)

            agent = create_agent(
                model="openai:gpt-4o",
                tools=[tool1, tool2, tool3, tool4, tool5],
                middleware=[middleware],
            )
            ```

        !!! example "Use a smaller model for selection"

            ```python
            middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
            ```
    """

    def __init__(
        self,
        *,
        model: str | BaseChatModel | None = None,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        max_tools: int | None = None,
        always_include: list[str] | None = None,
    ) -> None:
        """Initialize the tool selector.

        Args:
            model: Model to use for selection.

                If not provided, uses the agent's main model.

                Can be a model identifier string or `BaseChatModel` instance.
            system_prompt: Instructions for the selection model.
            max_tools: Maximum number of tools to select.

                If the model selects more, only the first `max_tools` will be used.

                If not specified, there is no limit.
            always_include: Tool names to always include regardless of selection.

                These do not count against the `max_tools` limit.
        """
        super().__init__()
        self.system_prompt = system_prompt
        self.max_tools = max_tools
        self.always_include = always_include or []

        # Strings are resolved to a chat model once, up front; `None` is kept so the
        # request's own model can be used at call time.
        if isinstance(model, (BaseChatModel, type(None))):
            self.model: BaseChatModel | None = model
        else:
            self.model = init_chat_model(model)

    def _prepare_selection_request(
        self, request: ModelRequest[ContextT]
    ) -> _SelectionRequest | None:
        """Prepare inputs for tool selection.

        Args:
            request: The model request.

        Returns:
            `_SelectionRequest` with prepared inputs, or `None` if no selection is
            needed (the request has no tools, or every tool is always included).

        Raises:
            ValueError: If tools in `always_include` are not found in the request.
            AssertionError: If no user message is found in the request messages.
        """
        # Nothing to select from.
        if not request.tools:
            return None

        # Filter to only BaseTool instances (exclude provider-specific tool dicts)
        base_tools = [tool for tool in request.tools if not isinstance(tool, dict)]

        # Validate that always_include tools exist
        if self.always_include:
            available_tool_names = {tool.name for tool in base_tools}
            missing_tools = [
                name for name in self.always_include if name not in available_tool_names
            ]
            if missing_tools:
                msg = (
                    f"Tools in always_include not found in request: {missing_tools}. "
                    f"Available tools: {sorted(available_tool_names)}"
                )
                raise ValueError(msg)

        # Separate tools that are always included from those available for selection
        available_tools = [tool for tool in base_tools if tool.name not in self.always_include]

        # If no tools available for selection, return None
        if not available_tools:
            return None

        system_message = self.system_prompt
        # If there's a max_tools limit, append instructions to the system prompt
        if self.max_tools is not None:
            system_message += (
                f"\nIMPORTANT: List the tool names in order of relevance, "
                f"with the most relevant first. "
                f"If you exceed the maximum number of tools, "
                f"only the first {self.max_tools} will be used."
            )

        # Get the last user message from the conversation history
        last_user_message = next(
            (
                message
                for message in reversed(request.messages)
                if isinstance(message, HumanMessage)
            ),
            None,
        )
        if last_user_message is None:
            msg = "No user message found in request messages"
            raise AssertionError(msg)

        return _SelectionRequest(
            available_tools=available_tools,
            system_message=system_message,
            last_user_message=last_user_message,
            model=self.model or request.model,
            valid_tool_names=[tool.name for tool in available_tools],
        )

    @staticmethod
    def _build_selection_call(selection_request: _SelectionRequest) -> tuple[Any, list[Any]]:
        """Build the structured-output model and the messages for the selection call.

        Shared by the sync and async entry points so both issue the same call.

        Args:
            selection_request: Prepared selection inputs.

        Returns:
            A `(structured_model, messages)` pair; the caller invokes the model with
            the messages.
        """
        # Create dynamic response schema with a Literal enum of available tool names
        type_adapter = _create_tool_selection_response(selection_request.available_tools)
        schema = type_adapter.json_schema()
        structured_model = selection_request.model.with_structured_output(schema)
        messages = [
            {"role": "system", "content": selection_request.system_message},
            selection_request.last_user_message,
        ]
        return structured_model, messages

    def _apply_selection(
        self,
        response: Any,
        selection_request: _SelectionRequest,
        request: ModelRequest[ContextT],
    ) -> ModelRequest[ContextT]:
        """Check the selection model's response and apply it to the request.

        Args:
            response: Raw output of the structured selection model.
            selection_request: Prepared selection inputs used for the call.
            request: The original model request.

        Returns:
            The request with its tools filtered according to the selection.

        Raises:
            AssertionError: If `response` is not a dict.
            ValueError: If the model selected a tool name that was not offered.
        """
        # Response should be a dict since we're passing a schema (not a Pydantic model class)
        if not isinstance(response, dict):
            msg = f"Expected dict response, got {type(response)}"
            raise AssertionError(msg)  # noqa: TRY004
        return self._process_selection_response(
            response, selection_request.available_tools, selection_request.valid_tool_names, request
        )

    def _process_selection_response(
        self,
        response: dict[str, Any],
        available_tools: list[BaseTool],
        valid_tool_names: list[str],
        request: ModelRequest[ContextT],
    ) -> ModelRequest[ContextT]:
        """Process the selection response and return filtered `ModelRequest`.

        Args:
            response: Selection model output; its `"tools"` entry lists the chosen
                tool names.
            available_tools: Tools that were offered for selection.
            valid_tool_names: Names the model was allowed to return.
            request: The original model request.

        Returns:
            The result of `request.override(tools=...)`, where the tools are the
            selected tools, then the always-included tools, then any
            provider-specific tool dicts from the original request.

        Raises:
            ValueError: If the model returned a name not in `valid_tool_names`.
        """
        selected_tool_names: list[str] = []
        invalid_tool_selections = []

        for tool_name in response["tools"]:
            if tool_name not in valid_tool_names:
                invalid_tool_selections.append(tool_name)
                continue

            # Only add if not already selected and within max_tools limit
            if tool_name not in selected_tool_names and (
                self.max_tools is None or len(selected_tool_names) < self.max_tools
            ):
                selected_tool_names.append(tool_name)

        if invalid_tool_selections:
            msg = f"Model selected invalid tools: {invalid_tool_selections}"
            raise ValueError(msg)

        # Filter tools based on selection and append always-included tools
        selected_tools: list[BaseTool] = [
            tool for tool in available_tools if tool.name in selected_tool_names
        ]
        always_included_tools: list[BaseTool] = [
            tool
            for tool in request.tools
            if not isinstance(tool, dict) and tool.name in self.always_include
        ]
        selected_tools.extend(always_included_tools)

        # Also preserve any provider-specific tool dicts from the original request
        provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]

        return request.override(tools=[*selected_tools, *provider_tools])

    def wrap_model_call(
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
    ) -> ModelResponse[ResponseT] | AIMessage:
        """Filter tools based on LLM selection before invoking the model via handler.

        Args:
            request: Model request to execute (includes state and runtime).
            handler: Callback that executes the model request and returns
                `ModelResponse`.

        Returns:
            The model call result.

        Raises:
            AssertionError: If the selection model response is not a dict.
        """
        selection_request = self._prepare_selection_request(request)
        if selection_request is None:
            return handler(request)

        structured_model, messages = self._build_selection_call(selection_request)
        response = structured_model.invoke(messages)
        modified_request = self._apply_selection(response, selection_request, request)
        return handler(modified_request)

    async def awrap_model_call(
        self,
        request: ModelRequest[ContextT],
        handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
    ) -> ModelResponse[ResponseT] | AIMessage:
        """Filter tools based on LLM selection before invoking the model via handler.

        Args:
            request: Model request to execute (includes state and runtime).
            handler: Async callback that executes the model request and returns
                `ModelResponse`.

        Returns:
            The model call result.

        Raises:
            AssertionError: If the selection model response is not a dict.
        """
        selection_request = self._prepare_selection_request(request)
        if selection_request is None:
            return await handler(request)

        structured_model, messages = self._build_selection_call(selection_request)
        response = await structured_model.ainvoke(messages)
        modified_request = self._apply_selection(response, selection_request, request)
        return await handler(modified_request)
|
||||
2052
venv/Lib/site-packages/langchain/agents/middleware/types.py
Normal file
2052
venv/Lib/site-packages/langchain/agents/middleware/types.py
Normal file
File diff suppressed because it is too large
Load Diff
462
venv/Lib/site-packages/langchain/agents/structured_output.py
Normal file
462
venv/Lib/site-packages/langchain/agents/structured_output.py
Normal file
@@ -0,0 +1,462 @@
|
||||
"""Types for setting agent response formats."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from types import UnionType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from langchain_core.tools import BaseTool, StructuredTool
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from typing_extensions import Self, is_typeddict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# Accepted schema flavors: Pydantic models, dataclasses, TypedDicts, and JSON schema dicts.
SchemaT = TypeVar("SchemaT")

# Tag recording which of those flavors a given schema is.
SchemaKind = Literal[
    "pydantic",
    "dataclass",
    "typeddict",
    "json_schema",
]
|
||||
|
||||
|
||||
class StructuredOutputError(Exception):
    """Root of the structured output error hierarchy."""

    # AI message associated with the failure.
    ai_message: AIMessage
|
||||
|
||||
|
||||
class MultipleStructuredOutputsError(StructuredOutputError):
    """Raised when a model emits several structured output tool calls but one is expected."""

    def __init__(self, tool_names: list[str], ai_message: AIMessage) -> None:
        """Initialize `MultipleStructuredOutputsError`.

        Args:
            tool_names: Names of the structured output tools the model called.
            ai_message: The AI message carrying the offending tool calls.
        """
        self.tool_names = tool_names
        self.ai_message = ai_message

        called = ", ".join(tool_names)
        detail = (
            "Model incorrectly returned multiple structured responses "
            f"({called}) when only one is expected."
        )
        super().__init__(detail)
|
||||
|
||||
|
||||
class StructuredOutputValidationError(StructuredOutputError):
    """Raised when a structured output tool call's arguments do not parse against the schema."""

    def __init__(self, tool_name: str, source: Exception, ai_message: AIMessage) -> None:
        """Initialize `StructuredOutputValidationError`.

        Args:
            tool_name: Name of the tool whose arguments failed to parse.
            source: The underlying exception.
            ai_message: The AI message carrying the invalid structured output.
        """
        self.tool_name = tool_name
        self.source = source
        self.ai_message = ai_message

        detail = f"Failed to parse structured output for tool '{tool_name}': {source}."
        super().__init__(detail)
|
||||
|
||||
|
||||
def _parse_with_schema(
|
||||
schema: type[SchemaT] | dict[str, Any], schema_kind: SchemaKind, data: dict[str, Any]
|
||||
) -> Any:
|
||||
"""Parse data using for any supported schema type.
|
||||
|
||||
Args:
|
||||
schema: The schema type (Pydantic model, `dataclass`, or `TypedDict`)
|
||||
schema_kind: One of `'pydantic'`, `'dataclass'`, `'typeddict'`, or
|
||||
`'json_schema'`
|
||||
data: The data to parse
|
||||
|
||||
Returns:
|
||||
The parsed instance according to the schema type
|
||||
|
||||
Raises:
|
||||
ValueError: If parsing fails
|
||||
"""
|
||||
if schema_kind == "json_schema":
|
||||
return data
|
||||
try:
|
||||
adapter: TypeAdapter[SchemaT] = TypeAdapter(schema)
|
||||
return adapter.validate_python(data)
|
||||
except Exception as e:
|
||||
schema_name = getattr(schema, "__name__", str(schema))
|
||||
msg = f"Failed to parse data to {schema_name}: {e}"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
|
||||
@dataclass(init=False)
class _SchemaSpec(Generic[SchemaT]):
    """Normalized description of one structured output schema."""

    schema: type[SchemaT] | dict[str, Any]
    """The response schema: a Pydantic model, `dataclass`, `TypedDict`, or JSON
    schema dict.
    """

    name: str
    """Schema name, used for tool calling.

    Defaults to the class name for models/dataclasses/TypedDicts, or to the `title`
    entry for JSON schemas.

    A generated name is used when neither is available.
    """

    description: str
    """Schema description.

    Defaults to the `description` entry of a JSON schema dict, or to the schema's
    docstring otherwise.
    """

    schema_kind: SchemaKind
    """Which kind of schema this is."""

    json_schema: dict[str, Any]
    """JSON schema derived from (or identical to) `schema`."""

    strict: bool | None = None
    """Whether to enforce strict validation of the schema."""

    def __init__(
        self,
        schema: type[SchemaT] | dict[str, Any],
        *,
        name: str | None = None,
        description: str | None = None,
        strict: bool | None = None,
    ) -> None:
        """Initialize `SchemaSpec` with schema and optional parameters.

        Args:
            schema: Schema to describe.
            name: Optional name overriding the one derived from the schema.
            description: Optional description overriding the one derived from the
                schema.
            strict: Whether to enforce strict validation of the schema.

        Raises:
            ValueError: If the schema type is unsupported.
        """
        self.schema = schema

        # Explicit name wins; otherwise fall back to the schema's own name, and
        # finally to a short generated one.
        if name:
            self.name = name
        else:
            generated = f"response_format_{str(uuid.uuid4())[:4]}"
            if isinstance(schema, dict):
                self.name = str(schema.get("title", generated))
            else:
                self.name = str(getattr(schema, "__name__", generated))

        if description:
            self.description = description
        elif isinstance(schema, dict):
            self.description = schema.get("description", "")
        else:
            self.description = getattr(schema, "__doc__", None) or ""

        self.strict = strict

        # JSON schema dicts are used verbatim.
        if isinstance(schema, dict):
            self.schema_kind = "json_schema"
            self.json_schema = schema
            return

        # Pydantic models expose their own JSON schema.
        if isinstance(schema, type) and issubclass(schema, BaseModel):
            self.schema_kind = "pydantic"
            self.json_schema = schema.model_json_schema()
            return

        # Dataclasses and TypedDicts both go through a `TypeAdapter`.
        kind: SchemaKind
        if is_dataclass(schema):
            kind = "dataclass"
        elif is_typeddict(schema):
            kind = "typeddict"
        else:
            msg = (
                f"Unsupported schema type: {type(schema)}. "
                f"Supported types: Pydantic models, dataclasses, TypedDicts, and JSON schema dicts."
            )
            raise ValueError(msg)

        self.schema_kind = kind
        self.json_schema = TypeAdapter(schema).json_schema()
|
||||
|
||||
|
||||
@dataclass(init=False)
class ToolStrategy(Generic[SchemaT]):
    """Obtain structured model responses through tool calling."""

    schema: type[SchemaT] | UnionType | dict[str, Any]
    """Schema (or union of schemas) for the tool calls."""

    schema_specs: list[_SchemaSpec[Any]]
    """One schema spec per leaf variant found in `schema`."""

    tool_message_content: str | None
    """Content of the tool message returned when the model calls an artificial
    structured output tool.
    """

    handle_errors: (
        bool | str | type[Exception] | tuple[type[Exception], ...] | Callable[[Exception], str]
    )
    """How structured output errors are handled under `ToolStrategy`.

    - `True`: Catch all errors with default error template
    - `str`: Catch all errors with this custom message
    - `type[Exception]`: Only catch this exception type with default message
    - `tuple[type[Exception], ...]`: Only catch these exception types with default
        message
    - `Callable[[Exception], str]`: Custom function that returns error message
    - `False`: No retry, let exceptions propagate
    """

    def __init__(
        self,
        schema: type[SchemaT] | UnionType | dict[str, Any],
        *,
        tool_message_content: str | None = None,
        handle_errors: bool
        | str
        | type[Exception]
        | tuple[type[Exception], ...]
        | Callable[[Exception], str] = True,
    ) -> None:
        """Initialize `ToolStrategy`.

        Stores the schema, tool message content, and error handling strategy, and
        builds one schema spec for every leaf variant of the schema.
        """
        self.schema = schema
        self.tool_message_content = tool_message_content
        self.handle_errors = handle_errors

        # Depth-first, left-to-right walk that unwraps `Union`s and JSON Schema
        # `oneOf` lists. Children are pushed reversed so they pop in source order.
        leaves: list[Any] = []
        pending: list[Any] = [schema]
        while pending:
            node = pending.pop()
            origin = get_origin(node)
            if origin is Union or origin is UnionType:
                pending.extend(reversed(get_args(node)))
            elif isinstance(node, dict) and "oneOf" in node:
                pending.extend(reversed(list(node.get("oneOf", []))))
            else:
                leaves.append(node)

        self.schema_specs = [_SchemaSpec(leaf) for leaf in leaves]
|
||||
|
||||
|
||||
@dataclass(init=False)
class ProviderStrategy(Generic[SchemaT]):
    """Obtain structured output through the model provider's native mechanism."""

    schema: type[SchemaT] | dict[str, Any]
    """Schema for native mode."""

    schema_spec: _SchemaSpec[SchemaT]
    """Schema spec built from `schema`."""

    def __init__(
        self,
        schema: type[SchemaT] | dict[str, Any],
        *,
        strict: bool | None = None,
    ) -> None:
        """Initialize `ProviderStrategy` with schema.

        Args:
            schema: Schema to enforce via the provider's native structured output.
            strict: Whether to request strict provider-side schema enforcement.
        """
        self.schema = schema
        self.schema_spec = _SchemaSpec(schema, strict=strict)

    def to_model_kwargs(self) -> dict[str, Any]:
        """Convert to kwargs to bind to a model to force structured output.

        Returns:
            A dict with a single `response_format` entry of type `json_schema`,
            carrying the schema name, its JSON schema, and `strict: True` when
            strict mode was requested.
        """
        # OpenAI:
        # - see https://platform.openai.com/docs/guides/structured-outputs
        spec = self.schema_spec
        payload: dict[str, Any] = {"name": spec.name, "schema": spec.json_schema}
        # `strict` is only emitted when truthy; `None`/`False` leave the key out.
        if spec.strict:
            payload["strict"] = True

        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": payload,
            }
        }
|
||||
|
||||
|
||||
@dataclass
class OutputToolBinding(Generic[SchemaT]):
    """Metadata for one structured output tool.

    Holds what is needed to handle a structured response delivered as a tool call:
    the original schema, its kind, and the tool built from it for the tools
    strategy.
    """

    schema: type[SchemaT] | dict[str, Any]
    """The original structured output schema (Pydantic model, dataclass, TypedDict,
    or JSON schema dict).
    """

    schema_kind: SchemaKind
    """Kind of schema, used to construct the response properly."""

    tool: BaseTool
    """LangChain tool created from the schema, for binding to the model."""

    @classmethod
    def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self:
        """Create an `OutputToolBinding` instance from a `SchemaSpec`.

        Args:
            schema_spec: The `SchemaSpec` to convert.

        Returns:
            An `OutputToolBinding` whose tool is a `StructuredTool` using the spec's
            JSON schema, name, and description.
        """
        structured_tool = StructuredTool(
            args_schema=schema_spec.json_schema,
            name=schema_spec.name,
            description=schema_spec.description,
        )
        return cls(
            schema=schema_spec.schema,
            schema_kind=schema_spec.schema_kind,
            tool=structured_tool,
        )

    def parse(self, tool_args: dict[str, Any]) -> SchemaT:
        """Parse tool arguments according to the schema.

        Args:
            tool_args: The arguments from the tool call.

        Returns:
            The parsed response according to the schema type.

        Raises:
            ValueError: If parsing fails.
        """
        return _parse_with_schema(self.schema, self.schema_kind, tool_args)
|
||||
|
||||
|
||||
@dataclass
class ProviderStrategyBinding(Generic[SchemaT]):
    """Metadata for native (provider-enforced) structured output.

    Holds what is needed to handle a structured response produced by the provider's
    native output mode: the original schema, its kind, and the logic for parsing the
    provider's JSON.
    """

    schema: type[SchemaT] | dict[str, Any]
    """The original structured output schema (Pydantic model, `dataclass`,
    `TypedDict`, or JSON schema dict).
    """

    schema_kind: SchemaKind
    """Kind of schema, used to construct the response properly."""

    @classmethod
    def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self:
        """Create a `ProviderStrategyBinding` instance from a `SchemaSpec`.

        Args:
            schema_spec: The `SchemaSpec` to convert.

        Returns:
            A `ProviderStrategyBinding` carrying the spec's schema and schema kind.
        """
        return cls(schema=schema_spec.schema, schema_kind=schema_spec.schema_kind)

    def parse(self, response: AIMessage) -> SchemaT:
        """Parse `AIMessage` content according to the schema.

        Args:
            response: The `AIMessage` containing the structured output.

        Returns:
            The parsed response according to the schema.

        Raises:
            ValueError: If JSON parsing or schema validation fails.
        """
        # The message text is expected to be a JSON document.
        raw_text = self._extract_text_content_from_message(response)

        try:
            payload = json.loads(raw_text)
        except Exception as exc:
            schema_name = getattr(self.schema, "__name__", "response_format")
            msg = (
                f"Native structured output expected valid JSON for {schema_name}, "
                f"but parsing failed: {exc}."
            )
            raise ValueError(msg) from exc

        # Validate the decoded JSON against the schema.
        return _parse_with_schema(self.schema, self.schema_kind, payload)

    @staticmethod
    def _extract_text_content_from_message(message: AIMessage) -> str:
        """Extract text content from an `AIMessage`.

        Args:
            message: The AI message to extract text from.

        Returns:
            The message content itself when it is a string. Otherwise the
            concatenation of: the `text` of dict items typed `"text"`, the string
            `content` of other dict items, and `str()` of non-dict items. Dict items
            matching neither rule contribute nothing.
        """
        content = message.content
        if isinstance(content, str):
            return content

        fragments: list[str] = []
        for item in content:
            if not isinstance(item, dict):
                fragments.append(str(item))
                continue
            if item.get("type") == "text" and "text" in item:
                fragments.append(str(item["text"]))
            elif "content" in item and isinstance(item["content"], str):
                fragments.append(item["content"])
        return "".join(fragments)
|
||||
|
||||
|
||||
class AutoStrategy(Generic[SchemaT]):
    """Request that the structured output strategy be chosen automatically."""

    schema: type[SchemaT] | dict[str, Any]
    """Schema for automatic mode."""

    def __init__(self, schema: type[SchemaT] | dict[str, Any]) -> None:
        """Initialize `AutoStrategy` with schema."""
        self.schema = schema
|
||||
|
||||
|
||||
ResponseFormat = (
    ToolStrategy[SchemaT]
    | ProviderStrategy[SchemaT]
    | AutoStrategy[SchemaT]
)
"""Union of every supported response format strategy."""
|
||||
Reference in New Issue
Block a user