initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
"""Main entrypoint into LangChain."""
__version__ = "1.2.10"

View File

@@ -0,0 +1,9 @@
"""Entrypoint to building [Agents](https://docs.langchain.com/oss/python/langchain/agents) with LangChain.""" # noqa: E501
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentState
__all__ = [
"AgentState",
"create_agent",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,81 @@
"""Entrypoint to using [middleware](https://docs.langchain.com/oss/python/langchain/middleware) plugins with [Agents](https://docs.langchain.com/oss/python/langchain/agents).""" # noqa: E501
from langchain.agents.middleware.context_editing import ClearToolUsesEdit, ContextEditingMiddleware
from langchain.agents.middleware.file_search import FilesystemFileSearchMiddleware
from langchain.agents.middleware.human_in_the_loop import (
HumanInTheLoopMiddleware,
InterruptOnConfig,
)
from langchain.agents.middleware.model_call_limit import ModelCallLimitMiddleware
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents.middleware.model_retry import ModelRetryMiddleware
from langchain.agents.middleware.pii import PIIDetectionError, PIIMiddleware
from langchain.agents.middleware.shell_tool import (
CodexSandboxExecutionPolicy,
DockerExecutionPolicy,
HostExecutionPolicy,
RedactionRule,
ShellToolMiddleware,
)
from langchain.agents.middleware.summarization import SummarizationMiddleware
from langchain.agents.middleware.todo import TodoListMiddleware
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
from langchain.agents.middleware.tool_emulator import LLMToolEmulator
from langchain.agents.middleware.tool_retry import ToolRetryMiddleware
from langchain.agents.middleware.tool_selection import LLMToolSelectorMiddleware
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ExtendedModelResponse,
ModelCallResult,
ModelRequest,
ModelResponse,
ToolCallRequest,
after_agent,
after_model,
before_agent,
before_model,
dynamic_prompt,
hook_config,
wrap_model_call,
wrap_tool_call,
)
__all__ = [
"AgentMiddleware",
"AgentState",
"ClearToolUsesEdit",
"CodexSandboxExecutionPolicy",
"ContextEditingMiddleware",
"DockerExecutionPolicy",
"ExtendedModelResponse",
"FilesystemFileSearchMiddleware",
"HostExecutionPolicy",
"HumanInTheLoopMiddleware",
"InterruptOnConfig",
"LLMToolEmulator",
"LLMToolSelectorMiddleware",
"ModelCallLimitMiddleware",
"ModelCallResult",
"ModelFallbackMiddleware",
"ModelRequest",
"ModelResponse",
"ModelRetryMiddleware",
"PIIDetectionError",
"PIIMiddleware",
"RedactionRule",
"ShellToolMiddleware",
"SummarizationMiddleware",
"TodoListMiddleware",
"ToolCallLimitMiddleware",
"ToolCallRequest",
"ToolRetryMiddleware",
"after_agent",
"after_model",
"before_agent",
"before_model",
"dynamic_prompt",
"hook_config",
"wrap_model_call",
"wrap_tool_call",
]

View File

@@ -0,0 +1,385 @@
"""Execution policies for the persistent shell middleware."""
from __future__ import annotations
import abc
import json
import os
import shutil
import subprocess
import sys
import typing
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from pathlib import Path
try: # pragma: no cover - optional dependency on POSIX platforms
import resource
_HAS_RESOURCE = True
except ImportError: # pragma: no cover - non-POSIX systems
_HAS_RESOURCE = False
SHELL_TEMP_PREFIX = "langchain-shell-"
def _launch_subprocess(
command: Sequence[str],
*,
env: Mapping[str, str],
cwd: Path,
preexec_fn: typing.Callable[[], None] | None,
start_new_session: bool,
) -> subprocess.Popen[str]:
return subprocess.Popen( # noqa: S603
list(command),
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=cwd,
text=True,
encoding="utf-8",
errors="replace",
bufsize=1,
env=env,
preexec_fn=preexec_fn, # noqa: PLW1509
start_new_session=start_new_session,
)
if typing.TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from pathlib import Path
@dataclass
class BaseExecutionPolicy(abc.ABC):
"""Configuration contract for persistent shell sessions.
Concrete subclasses encapsulate how a shell process is launched and constrained.
Each policy documents its security guarantees and the operating environments in
which it is appropriate. Use `HostExecutionPolicy` for trusted, same-host execution;
`CodexSandboxExecutionPolicy` when the Codex CLI sandbox is available and you want
additional syscall restrictions; and `DockerExecutionPolicy` for container-level
isolation using Docker.
"""
command_timeout: float = 30.0
startup_timeout: float = 30.0
termination_timeout: float = 10.0
max_output_lines: int = 100
max_output_bytes: int | None = None
def __post_init__(self) -> None:
if self.max_output_lines <= 0:
msg = "max_output_lines must be positive."
raise ValueError(msg)
@abc.abstractmethod
def spawn(
self,
*,
workspace: Path,
env: Mapping[str, str],
command: Sequence[str],
) -> subprocess.Popen[str]:
"""Launch the persistent shell process."""
@dataclass
class HostExecutionPolicy(BaseExecutionPolicy):
"""Run the shell directly on the host process.
This policy is best suited for trusted or single-tenant environments (CI jobs,
developer workstations, pre-sandboxed containers) where the agent must access the
host filesystem and tooling without additional isolation. Enforces optional CPU and
memory limits to prevent runaway commands but offers **no** filesystem or network
sandboxing; commands can modify anything the process user can reach.
On Linux platforms resource limits are applied with `resource.prlimit` after the
shell starts. On macOS, where `prlimit` is unavailable, limits are set in a
`preexec_fn` before `exec`. In both cases the shell runs in its own process group
so timeouts can terminate the full subtree.
"""
cpu_time_seconds: int | None = None
memory_bytes: int | None = None
create_process_group: bool = True
_limits_requested: bool = field(init=False, repr=False, default=False)
def __post_init__(self) -> None:
super().__post_init__()
if self.cpu_time_seconds is not None and self.cpu_time_seconds <= 0:
msg = "cpu_time_seconds must be positive if provided."
raise ValueError(msg)
if self.memory_bytes is not None and self.memory_bytes <= 0:
msg = "memory_bytes must be positive if provided."
raise ValueError(msg)
self._limits_requested = any(
value is not None for value in (self.cpu_time_seconds, self.memory_bytes)
)
if self._limits_requested and not _HAS_RESOURCE:
msg = (
"HostExecutionPolicy cpu/memory limits require the Python 'resource' module. "
"Either remove the limits or run on a POSIX platform."
)
raise RuntimeError(msg)
def spawn(
self,
*,
workspace: Path,
env: Mapping[str, str],
command: Sequence[str],
) -> subprocess.Popen[str]:
process = _launch_subprocess(
list(command),
env=env,
cwd=workspace,
preexec_fn=self._create_preexec_fn(),
start_new_session=self.create_process_group,
)
self._apply_post_spawn_limits(process)
return process
def _create_preexec_fn(self) -> typing.Callable[[], None] | None:
if not self._limits_requested or self._can_use_prlimit():
return None
def _configure() -> None: # pragma: no cover - depends on OS
if self.cpu_time_seconds is not None:
limit = (self.cpu_time_seconds, self.cpu_time_seconds)
resource.setrlimit(resource.RLIMIT_CPU, limit)
if self.memory_bytes is not None:
limit = (self.memory_bytes, self.memory_bytes)
if hasattr(resource, "RLIMIT_AS"):
resource.setrlimit(resource.RLIMIT_AS, limit)
elif hasattr(resource, "RLIMIT_DATA"):
resource.setrlimit(resource.RLIMIT_DATA, limit)
return _configure
def _apply_post_spawn_limits(self, process: subprocess.Popen[str]) -> None:
if not self._limits_requested or not self._can_use_prlimit():
return
if not _HAS_RESOURCE: # pragma: no cover - defensive
return
pid = process.pid
try:
prlimit = typing.cast("typing.Any", resource).prlimit
if self.cpu_time_seconds is not None:
prlimit(pid, resource.RLIMIT_CPU, (self.cpu_time_seconds, self.cpu_time_seconds))
if self.memory_bytes is not None:
limit = (self.memory_bytes, self.memory_bytes)
if hasattr(resource, "RLIMIT_AS"):
prlimit(pid, resource.RLIMIT_AS, limit)
elif hasattr(resource, "RLIMIT_DATA"):
prlimit(pid, resource.RLIMIT_DATA, limit)
except OSError as exc: # pragma: no cover - depends on platform support
msg = "Failed to apply resource limits via prlimit."
raise RuntimeError(msg) from exc
@staticmethod
def _can_use_prlimit() -> bool:
return _HAS_RESOURCE and hasattr(resource, "prlimit") and sys.platform.startswith("linux")
@dataclass
class CodexSandboxExecutionPolicy(BaseExecutionPolicy):
"""Launch the shell through the Codex CLI sandbox.
Ideal when you have the Codex CLI installed and want the additional syscall and
filesystem restrictions provided by Anthropic's Seatbelt (macOS) or Landlock/seccomp
(Linux) profiles. Commands still run on the host, but within the sandbox requested by
the CLI. If the Codex binary is unavailable or the runtime lacks the required
kernel features (e.g., Landlock inside some containers), process startup fails with a
`RuntimeError`.
Configure sandbox behavior via `config_overrides` to align with your Codex CLI
profile. This policy does not add its own resource limits; combine it with
host-level guards (cgroups, container resource limits) as needed.
"""
binary: str = "codex"
platform: typing.Literal["auto", "macos", "linux"] = "auto"
config_overrides: Mapping[str, typing.Any] = field(default_factory=dict)
def spawn(
self,
*,
workspace: Path,
env: Mapping[str, str],
command: Sequence[str],
) -> subprocess.Popen[str]:
full_command = self._build_command(command)
return _launch_subprocess(
full_command,
env=env,
cwd=workspace,
preexec_fn=None,
start_new_session=False,
)
def _build_command(self, command: Sequence[str]) -> list[str]:
binary = self._resolve_binary()
platform_arg = self._determine_platform()
full_command: list[str] = [binary, "sandbox", platform_arg]
for key, value in sorted(dict(self.config_overrides).items()):
full_command.extend(["-c", f"{key}={self._format_override(value)}"])
full_command.append("--")
full_command.extend(command)
return full_command
def _resolve_binary(self) -> str:
path = shutil.which(self.binary)
if path is None:
msg = (
"Codex sandbox policy requires the '%s' CLI to be installed and available on PATH."
)
raise RuntimeError(msg % self.binary)
return path
def _determine_platform(self) -> str:
if self.platform != "auto":
return self.platform
if sys.platform.startswith("linux"):
return "linux"
if sys.platform == "darwin": # type: ignore[unreachable, unused-ignore]
return "macos"
msg = ( # type: ignore[unreachable, unused-ignore]
"Codex sandbox policy could not determine a supported platform; "
"set 'platform' explicitly."
)
raise RuntimeError(msg)
@staticmethod
def _format_override(value: typing.Any) -> str:
try:
return json.dumps(value)
except TypeError:
return str(value)
@dataclass
class DockerExecutionPolicy(BaseExecutionPolicy):
"""Run the shell inside a dedicated Docker container.
Choose this policy when commands originate from untrusted users or you require
strong isolation between sessions. By default the workspace is bind-mounted only
when it refers to an existing non-temporary directory; ephemeral sessions run
without a mount to minimise host exposure. The container's network namespace is
disabled by default (`--network none`) and you can enable further hardening via
`read_only_rootfs` and `user`.
The security guarantees depend on your Docker daemon configuration. Run the agent on
a host where Docker is locked down (rootless mode, AppArmor/SELinux, etc.) and
review any additional volumes or capabilities passed through ``extra_run_args``. The
default image is `python:3.12-alpine3.19`; supply a custom image if you need
preinstalled tooling.
"""
binary: str = "docker"
image: str = "python:3.12-alpine3.19"
remove_container_on_exit: bool = True
network_enabled: bool = False
extra_run_args: Sequence[str] | None = None
memory_bytes: int | None = None
cpu_time_seconds: typing.Any | None = None
cpus: str | None = None
read_only_rootfs: bool = False
user: str | None = None
def __post_init__(self) -> None:
super().__post_init__()
if self.memory_bytes is not None and self.memory_bytes <= 0:
msg = "memory_bytes must be positive if provided."
raise ValueError(msg)
if self.cpu_time_seconds is not None:
msg = (
"DockerExecutionPolicy does not support cpu_time_seconds; configure CPU limits "
"using Docker run options such as '--cpus'."
)
raise RuntimeError(msg)
if self.cpus is not None and not self.cpus.strip():
msg = "cpus must be a non-empty string when provided."
raise ValueError(msg)
if self.user is not None and not self.user.strip():
msg = "user must be a non-empty string when provided."
raise ValueError(msg)
self.extra_run_args = tuple(self.extra_run_args or ())
def spawn(
self,
*,
workspace: Path,
env: Mapping[str, str],
command: Sequence[str],
) -> subprocess.Popen[str]:
full_command = self._build_command(workspace, env, command)
host_env = os.environ.copy()
return _launch_subprocess(
full_command,
env=host_env,
cwd=workspace,
preexec_fn=None,
start_new_session=False,
)
def _build_command(
self,
workspace: Path,
env: Mapping[str, str],
command: Sequence[str],
) -> list[str]:
binary = self._resolve_binary()
full_command: list[str] = [binary, "run", "-i"]
if self.remove_container_on_exit:
full_command.append("--rm")
if not self.network_enabled:
full_command.extend(["--network", "none"])
if self.memory_bytes is not None:
full_command.extend(["--memory", str(self.memory_bytes)])
if self._should_mount_workspace(workspace):
host_path = str(workspace)
full_command.extend(["-v", f"{host_path}:{host_path}"])
full_command.extend(["-w", host_path])
else:
full_command.extend(["-w", "/"])
if self.read_only_rootfs:
full_command.append("--read-only")
for key, value in env.items():
full_command.extend(["-e", f"{key}={value}"])
if self.cpus is not None:
full_command.extend(["--cpus", self.cpus])
if self.user is not None:
full_command.extend(["--user", self.user])
if self.extra_run_args:
full_command.extend(self.extra_run_args)
full_command.append(self.image)
full_command.extend(command)
return full_command
@staticmethod
def _should_mount_workspace(workspace: Path) -> bool:
return not workspace.name.startswith(SHELL_TEMP_PREFIX)
def _resolve_binary(self) -> str:
path = shutil.which(self.binary)
if path is None:
msg = (
"Docker execution policy requires the '%s' CLI to be installed"
" and available on PATH."
)
raise RuntimeError(msg % self.binary)
return path
__all__ = [
"BaseExecutionPolicy",
"CodexSandboxExecutionPolicy",
"DockerExecutionPolicy",
"HostExecutionPolicy",
]

View File

@@ -0,0 +1,436 @@
"""Shared redaction utilities for middleware components."""
from __future__ import annotations
import hashlib
import ipaddress
import operator
import re
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Literal
from urllib.parse import urlparse
from typing_extensions import TypedDict
RedactionStrategy = Literal["block", "redact", "mask", "hash"]
"""Supported strategies for handling detected sensitive values."""
class PIIMatch(TypedDict):
"""Represents an individual match of sensitive data."""
type: str
value: str
start: int
end: int
class PIIDetectionError(Exception):
"""Raised when configured to block on detected sensitive values."""
def __init__(self, pii_type: str, matches: Sequence[PIIMatch]) -> None:
"""Initialize the exception with match context.
Args:
pii_type: Name of the detected sensitive type.
matches: All matches that were detected for that type.
"""
self.pii_type = pii_type
self.matches = list(matches)
count = len(matches)
msg = f"Detected {count} instance(s) of {pii_type} in text content"
super().__init__(msg)
Detector = Callable[[str], list[PIIMatch]]
"""Callable signature for detectors that locate sensitive values."""
def detect_email(content: str) -> list[PIIMatch]:
"""Detect email addresses in content.
Args:
content: The text content to scan for email addresses.
Returns:
A list of detected email matches.
"""
pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"
return [
PIIMatch(
type="email",
value=match.group(),
start=match.start(),
end=match.end(),
)
for match in re.finditer(pattern, content)
]
def detect_credit_card(content: str) -> list[PIIMatch]:
"""Detect credit card numbers in content using Luhn validation.
Args:
content: The text content to scan for credit card numbers.
Returns:
A list of detected credit card matches.
"""
pattern = r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b"
matches = []
for match in re.finditer(pattern, content):
card_number = match.group()
if _passes_luhn(card_number):
matches.append(
PIIMatch(
type="credit_card",
value=card_number,
start=match.start(),
end=match.end(),
)
)
return matches
def detect_ip(content: str) -> list[PIIMatch]:
"""Detect IPv4 or IPv6 addresses in content.
Args:
content: The text content to scan for IP addresses.
Returns:
A list of detected IP address matches.
"""
matches: list[PIIMatch] = []
ipv4_pattern = r"\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b"
for match in re.finditer(ipv4_pattern, content):
ip_candidate = match.group()
try:
ipaddress.ip_address(ip_candidate)
except ValueError:
continue
matches.append(
PIIMatch(
type="ip",
value=ip_candidate,
start=match.start(),
end=match.end(),
)
)
return matches
def detect_mac_address(content: str) -> list[PIIMatch]:
"""Detect MAC addresses in content.
Args:
content: The text content to scan for MAC addresses.
Returns:
A list of detected MAC address matches.
"""
pattern = r"\b([0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}\b"
return [
PIIMatch(
type="mac_address",
value=match.group(),
start=match.start(),
end=match.end(),
)
for match in re.finditer(pattern, content)
]
def detect_url(content: str) -> list[PIIMatch]:
"""Detect URLs in content using regex and stdlib validation.
Args:
content: The text content to scan for URLs.
Returns:
A list of detected URL matches.
"""
matches: list[PIIMatch] = []
# Pattern 1: URLs with scheme (http:// or https://)
scheme_pattern = r"https?://[^\s<>\"{}|\\^`\[\]]+"
for match in re.finditer(scheme_pattern, content):
url = match.group()
result = urlparse(url)
if result.scheme in {"http", "https"} and result.netloc:
matches.append(
PIIMatch(
type="url",
value=url,
start=match.start(),
end=match.end(),
)
)
# Pattern 2: URLs without scheme (www.example.com or example.com/path)
# More conservative to avoid false positives
bare_pattern = (
r"\b(?:www\.)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?"
r"(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)+(?:/[^\s]*)?"
)
for match in re.finditer(bare_pattern, content):
start, end = match.start(), match.end()
# Skip if already matched with scheme
if any(m["start"] <= start < m["end"] or m["start"] < end <= m["end"] for m in matches):
continue
url = match.group()
# Only accept if it has a path or starts with www
# This reduces false positives like "example.com" in prose
if "/" in url or url.startswith("www."):
# Add scheme for validation (required for urlparse to work correctly)
test_url = f"http://{url}"
result = urlparse(test_url)
if result.netloc and "." in result.netloc:
matches.append(
PIIMatch(
type="url",
value=url,
start=start,
end=end,
)
)
return matches
BUILTIN_DETECTORS: dict[str, Detector] = {
"email": detect_email,
"credit_card": detect_credit_card,
"ip": detect_ip,
"mac_address": detect_mac_address,
"url": detect_url,
}
"""Registry of built-in detectors keyed by type name."""
_CARD_NUMBER_MIN_DIGITS = 13
_CARD_NUMBER_MAX_DIGITS = 19
def _passes_luhn(card_number: str) -> bool:
"""Validate credit card number using the Luhn checksum."""
digits = [int(d) for d in card_number if d.isdigit()]
if not _CARD_NUMBER_MIN_DIGITS <= len(digits) <= _CARD_NUMBER_MAX_DIGITS:
return False
checksum = 0
for index, digit in enumerate(reversed(digits)):
value = digit
if index % 2 == 1:
value *= 2
if value > 9: # noqa: PLR2004
value -= 9
checksum += value
return checksum % 10 == 0
def _apply_redact_strategy(content: str, matches: list[PIIMatch]) -> str:
result = content
for match in sorted(matches, key=operator.itemgetter("start"), reverse=True):
replacement = f"[REDACTED_{match['type'].upper()}]"
result = result[: match["start"]] + replacement + result[match["end"] :]
return result
_UNMASKED_CHAR_NUMBER = 4
_IPV4_PARTS_NUMBER = 4
def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str:
result = content
for match in sorted(matches, key=operator.itemgetter("start"), reverse=True):
value = match["value"]
pii_type = match["type"]
if pii_type == "email":
parts = value.split("@")
if len(parts) == 2: # noqa: PLR2004
domain_parts = parts[1].split(".")
masked = (
f"{parts[0]}@****.{domain_parts[-1]}"
if len(domain_parts) > 1
else f"{parts[0]}@****"
)
else:
masked = "****"
elif pii_type == "credit_card":
digits_only = "".join(c for c in value if c.isdigit())
separator = "-" if "-" in value else " " if " " in value else ""
if separator:
masked = (
f"****{separator}****{separator}****{separator}"
f"{digits_only[-_UNMASKED_CHAR_NUMBER:]}"
)
else:
masked = f"************{digits_only[-_UNMASKED_CHAR_NUMBER:]}"
elif pii_type == "ip":
octets = value.split(".")
masked = f"*.*.*.{octets[-1]}" if len(octets) == _IPV4_PARTS_NUMBER else "****"
elif pii_type == "mac_address":
separator = ":" if ":" in value else "-"
masked = (
f"**{separator}**{separator}**{separator}**{separator}**{separator}{value[-2:]}"
)
elif pii_type == "url":
masked = "[MASKED_URL]"
else:
masked = (
f"****{value[-_UNMASKED_CHAR_NUMBER:]}"
if len(value) > _UNMASKED_CHAR_NUMBER
else "****"
)
result = result[: match["start"]] + masked + result[match["end"] :]
return result
def _apply_hash_strategy(content: str, matches: list[PIIMatch]) -> str:
result = content
for match in sorted(matches, key=operator.itemgetter("start"), reverse=True):
digest = hashlib.sha256(match["value"].encode()).hexdigest()[:8]
replacement = f"<{match['type']}_hash:{digest}>"
result = result[: match["start"]] + replacement + result[match["end"] :]
return result
def apply_strategy(
content: str,
matches: list[PIIMatch],
strategy: RedactionStrategy,
) -> str:
"""Apply the configured strategy to matches within content.
Args:
content: The content to apply strategy to.
matches: List of detected PII matches.
strategy: The redaction strategy to apply.
Returns:
The content with the strategy applied.
Raises:
PIIDetectionError: If the strategy is `'block'` and matches are found.
ValueError: If the strategy is unknown.
"""
if not matches:
return content
if strategy == "redact":
return _apply_redact_strategy(content, matches)
if strategy == "mask":
return _apply_mask_strategy(content, matches)
if strategy == "hash":
return _apply_hash_strategy(content, matches)
if strategy == "block":
raise PIIDetectionError(matches[0]["type"], matches)
msg = f"Unknown redaction strategy: {strategy}" # type: ignore[unreachable]
raise ValueError(msg)
def resolve_detector(pii_type: str, detector: Detector | str | None) -> Detector:
"""Return a callable detector for the given configuration.
Args:
pii_type: The PII type name.
detector: Optional custom detector or regex pattern. If `None`, a built-in detector
for the given PII type will be used.
Returns:
The resolved detector.
Raises:
ValueError: If an unknown PII type is specified without a custom detector or regex.
"""
if detector is None:
if pii_type not in BUILTIN_DETECTORS:
msg = (
f"Unknown PII type: {pii_type}. "
f"Must be one of {list(BUILTIN_DETECTORS.keys())} or provide a custom detector."
)
raise ValueError(msg)
return BUILTIN_DETECTORS[pii_type]
if isinstance(detector, str):
pattern = re.compile(detector)
def regex_detector(content: str) -> list[PIIMatch]:
return [
PIIMatch(
type=pii_type,
value=match.group(),
start=match.start(),
end=match.end(),
)
for match in pattern.finditer(content)
]
return regex_detector
return detector
@dataclass(frozen=True)
class RedactionRule:
"""Configuration for handling a single PII type."""
pii_type: str
strategy: RedactionStrategy = "redact"
detector: Detector | str | None = None
def resolve(self) -> ResolvedRedactionRule:
"""Resolve runtime detector and return an immutable rule.
Returns:
The resolved redaction rule.
"""
resolved_detector = resolve_detector(self.pii_type, self.detector)
return ResolvedRedactionRule(
pii_type=self.pii_type,
strategy=self.strategy,
detector=resolved_detector,
)
@dataclass(frozen=True)
class ResolvedRedactionRule:
"""Resolved redaction rule ready for execution."""
pii_type: str
strategy: RedactionStrategy
detector: Detector
def apply(self, content: str) -> tuple[str, list[PIIMatch]]:
"""Apply this rule to content, returning new content and matches.
Args:
content: The text content to scan and redact.
Returns:
A tuple of (updated content, list of detected matches).
"""
matches = self.detector(content)
if not matches:
return content, []
updated = apply_strategy(content, matches, self.strategy)
return updated, matches
__all__ = [
"PIIDetectionError",
"PIIMatch",
"RedactionRule",
"ResolvedRedactionRule",
"apply_strategy",
"detect_credit_card",
"detect_email",
"detect_ip",
"detect_mac_address",
"detect_url",
]

View File

@@ -0,0 +1,123 @@
"""Shared retry utilities for agent middleware.
This module contains common constants, utilities, and logic used by both
model and tool retry middleware implementations.
"""
from __future__ import annotations
import random
from collections.abc import Callable
from typing import Literal
# Type aliases
RetryOn = tuple[type[Exception], ...] | Callable[[Exception], bool]
"""Type for specifying which exceptions to retry on.
Can be either:
- A tuple of exception types to retry on (based on `isinstance` checks)
- A callable that takes an exception and returns `True` if it should be retried
"""
OnFailure = Literal["error", "continue"] | Callable[[Exception], str]
"""Type for specifying failure handling behavior.
Can be either:
- A literal action string (`'error'` or `'continue'`)
- `'error'`: Re-raise the exception, stopping agent execution.
- `'continue'`: Inject a message with the error details, allowing the agent to continue.
For tool retries, a `ToolMessage` with the error details will be injected.
For model retries, an `AIMessage` with the error details will be returned.
- A callable that takes an exception and returns a string for error message content
"""
def validate_retry_params(
max_retries: int,
initial_delay: float,
max_delay: float,
backoff_factor: float,
) -> None:
"""Validate retry parameters.
Args:
max_retries: Maximum number of retry attempts.
initial_delay: Initial delay in seconds before first retry.
max_delay: Maximum delay in seconds between retries.
backoff_factor: Multiplier for exponential backoff.
Raises:
ValueError: If any parameter is invalid (negative values).
"""
if max_retries < 0:
msg = "max_retries must be >= 0"
raise ValueError(msg)
if initial_delay < 0:
msg = "initial_delay must be >= 0"
raise ValueError(msg)
if max_delay < 0:
msg = "max_delay must be >= 0"
raise ValueError(msg)
if backoff_factor < 0:
msg = "backoff_factor must be >= 0"
raise ValueError(msg)
def should_retry_exception(
exc: Exception,
retry_on: RetryOn,
) -> bool:
"""Check if an exception should trigger a retry.
Args:
exc: The exception that occurred.
retry_on: Either a tuple of exception types to retry on, or a callable
that takes an exception and returns `True` if it should be retried.
Returns:
`True` if the exception should be retried, `False` otherwise.
"""
if callable(retry_on):
return retry_on(exc)
return isinstance(exc, retry_on)
def calculate_delay(
retry_number: int,
*,
backoff_factor: float,
initial_delay: float,
max_delay: float,
jitter: bool,
) -> float:
"""Calculate delay for a retry attempt with exponential backoff and optional jitter.
Args:
retry_number: The retry attempt number (0-indexed).
backoff_factor: Multiplier for exponential backoff.
Set to `0.0` for constant delay.
initial_delay: Initial delay in seconds before first retry.
max_delay: Maximum delay in seconds between retries.
Caps exponential backoff growth.
jitter: Whether to add random jitter to delay to avoid thundering herd.
Returns:
Delay in seconds before next retry.
"""
if backoff_factor == 0.0:
delay = initial_delay
else:
delay = initial_delay * (backoff_factor**retry_number)
# Cap at max_delay
delay = min(delay, max_delay)
if jitter and delay > 0:
jitter_amount = delay * 0.25 # ±25% jitter
delay += random.uniform(-jitter_amount, jitter_amount) # noqa: S311
# Ensure delay is not negative after jitter
delay = max(0, delay)
return delay

View File

@@ -0,0 +1,298 @@
"""Context editing middleware.
Mirrors Anthropic's context editing capabilities by clearing older tool results once the
conversation grows beyond a configurable token threshold.
The implementation is intentionally model-agnostic so it can be used with any LangChain
chat model.
"""
from __future__ import annotations
from collections.abc import Awaitable, Callable, Iterable, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Literal
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
ToolMessage,
)
from langchain_core.messages.utils import count_tokens_approximately
from typing_extensions import Protocol
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
DEFAULT_TOOL_PLACEHOLDER = "[cleared]"
TokenCounter = Callable[
[Sequence[BaseMessage]],
int,
]
class ContextEdit(Protocol):
"""Protocol describing a context editing strategy."""
def apply(
self,
messages: list[AnyMessage],
*,
count_tokens: TokenCounter,
) -> None:
"""Apply an edit to the message list in place."""
...
@dataclass(slots=True)
class ClearToolUsesEdit(ContextEdit):
"""Configuration for clearing tool outputs when token limits are exceeded."""
trigger: int = 100_000
"""Token count that triggers the edit."""
clear_at_least: int = 0
"""Minimum number of tokens to reclaim when the edit runs."""
keep: int = 3
"""Number of most recent tool results that must be preserved."""
clear_tool_inputs: bool = False
"""Whether to clear the originating tool call parameters on the AI message."""
exclude_tools: Sequence[str] = ()
"""List of tool names to exclude from clearing."""
placeholder: str = DEFAULT_TOOL_PLACEHOLDER
"""Placeholder text inserted for cleared tool outputs."""
def apply(
self,
messages: list[AnyMessage],
*,
count_tokens: TokenCounter,
) -> None:
"""Apply the clear-tool-uses strategy."""
tokens = count_tokens(messages)
if tokens <= self.trigger:
return
candidates = [
(idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
]
if self.keep >= len(candidates):
candidates = []
elif self.keep:
candidates = candidates[: -self.keep]
cleared_tokens = 0
excluded_tools = set(self.exclude_tools)
for idx, tool_message in candidates:
if tool_message.response_metadata.get("context_editing", {}).get("cleared"):
continue
ai_message = next(
(m for m in reversed(messages[:idx]) if isinstance(m, AIMessage)), None
)
if ai_message is None:
continue
tool_call = next(
(
call
for call in ai_message.tool_calls
if call.get("id") == tool_message.tool_call_id
),
None,
)
if tool_call is None:
continue
if (tool_message.name or tool_call["name"]) in excluded_tools:
continue
messages[idx] = tool_message.model_copy(
update={
"artifact": None,
"content": self.placeholder,
"response_metadata": {
**tool_message.response_metadata,
"context_editing": {
"cleared": True,
"strategy": "clear_tool_uses",
},
},
}
)
if self.clear_tool_inputs:
messages[messages.index(ai_message)] = self._build_cleared_tool_input_message(
ai_message,
tool_message.tool_call_id,
)
if self.clear_at_least > 0:
new_token_count = count_tokens(messages)
cleared_tokens = max(0, tokens - new_token_count)
if cleared_tokens >= self.clear_at_least:
break
return
@staticmethod
def _build_cleared_tool_input_message(
message: AIMessage,
tool_call_id: str,
) -> AIMessage:
updated_tool_calls = []
cleared_any = False
for tool_call in message.tool_calls:
updated_call = dict(tool_call)
if updated_call.get("id") == tool_call_id:
updated_call["args"] = {}
cleared_any = True
updated_tool_calls.append(updated_call)
metadata = dict(getattr(message, "response_metadata", {}))
context_entry = dict(metadata.get("context_editing", {}))
if cleared_any:
cleared_ids = set(context_entry.get("cleared_tool_inputs", []))
cleared_ids.add(tool_call_id)
context_entry["cleared_tool_inputs"] = sorted(cleared_ids)
metadata["context_editing"] = context_entry
return message.model_copy(
update={
"tool_calls": updated_tool_calls,
"response_metadata": metadata,
}
)
class ContextEditingMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Automatically prune tool results to manage context size.
The middleware applies a sequence of edits when the total input token count exceeds
configured thresholds.
Currently the `ClearToolUsesEdit` strategy is supported, aligning with Anthropic's
`clear_tool_uses_20250919` behavior [(read more)](https://platform.claude.com/docs/en/agents-and-tools/tool-use/memory-tool).
"""
edits: list[ContextEdit]
token_count_method: Literal["approximate", "model"]
def __init__(
self,
*,
edits: Iterable[ContextEdit] | None = None,
token_count_method: Literal["approximate", "model"] = "approximate", # noqa: S107
) -> None:
"""Initialize an instance of context editing middleware.
Args:
edits: Sequence of edit strategies to apply.
Defaults to a single `ClearToolUsesEdit` mirroring Anthropic defaults.
token_count_method: Whether to use approximate token counting
(faster, less accurate) or exact counting implemented by the
chat model (potentially slower, more accurate).
"""
super().__init__()
self.edits = list(edits or (ClearToolUsesEdit(),))
self.token_count_method = token_count_method
def wrap_model_call(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> ModelResponse[ResponseT] | AIMessage:
"""Apply context edits before invoking the model via handler.
Args:
request: Model request to execute (includes state and runtime).
handler: Async callback that executes the model request and returns
`ModelResponse`.
Returns:
The result of invoking the handler with potentially edited messages.
"""
if not request.messages:
return handler(request)
if self.token_count_method == "approximate": # noqa: S105
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return count_tokens_approximately(messages)
else:
system_msg = [request.system_message] if request.system_message else []
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return request.model.get_num_tokens_from_messages(
system_msg + list(messages), request.tools
)
edited_messages = deepcopy(list(request.messages))
for edit in self.edits:
edit.apply(edited_messages, count_tokens=count_tokens)
return handler(request.override(messages=edited_messages))
async def awrap_model_call(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> ModelResponse[ResponseT] | AIMessage:
"""Apply context edits before invoking the model via handler.
Args:
request: Model request to execute (includes state and runtime).
handler: Async callback that executes the model request and returns
`ModelResponse`.
Returns:
The result of invoking the handler with potentially edited messages.
"""
if not request.messages:
return await handler(request)
if self.token_count_method == "approximate": # noqa: S105
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return count_tokens_approximately(messages)
else:
system_msg = [request.system_message] if request.system_message else []
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return request.model.get_num_tokens_from_messages(
system_msg + list(messages), request.tools
)
edited_messages = deepcopy(list(request.messages))
for edit in self.edits:
edit.apply(edited_messages, count_tokens=count_tokens)
return await handler(request.override(messages=edited_messages))
__all__ = [
"ClearToolUsesEdit",
"ContextEditingMiddleware",
]

View File

@@ -0,0 +1,387 @@
"""File search middleware for Anthropic text editor and memory tools.
This module provides Glob and Grep search tools that operate on files stored
in state or filesystem.
"""
from __future__ import annotations
import fnmatch
import json
import re
import subprocess
from contextlib import suppress
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal
from langchain_core.tools import tool
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, ResponseT
def _expand_include_patterns(pattern: str) -> list[str] | None:
"""Expand brace patterns like `*.{py,pyi}` into a list of globs."""
if "}" in pattern and "{" not in pattern:
return None
expanded: list[str] = []
def _expand(current: str) -> None:
start = current.find("{")
if start == -1:
expanded.append(current)
return
end = current.find("}", start)
if end == -1:
raise ValueError
prefix = current[:start]
suffix = current[end + 1 :]
inner = current[start + 1 : end]
if not inner:
raise ValueError
for option in inner.split(","):
_expand(prefix + option + suffix)
try:
_expand(pattern)
except ValueError:
return None
return expanded
def _is_valid_include_pattern(pattern: str) -> bool:
"""Validate glob pattern used for include filters."""
if not pattern:
return False
if any(char in pattern for char in ("\x00", "\n", "\r")):
return False
expanded = _expand_include_patterns(pattern)
if expanded is None:
return False
try:
for candidate in expanded:
re.compile(fnmatch.translate(candidate))
except re.error:
return False
return True
def _match_include_pattern(basename: str, pattern: str) -> bool:
"""Return True if the basename matches the include pattern."""
expanded = _expand_include_patterns(pattern)
if not expanded:
return False
return any(fnmatch.fnmatch(basename, candidate) for candidate in expanded)
class FilesystemFileSearchMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Provides Glob and Grep search over filesystem files.
This middleware adds two tools that search through local filesystem:
- Glob: Fast file pattern matching by file path
- Grep: Fast content search using ripgrep or Python fallback
Example:
```python
from langchain.agents import create_agent
from langchain.agents.middleware import (
FilesystemFileSearchMiddleware,
)
agent = create_agent(
model=model,
tools=[], # Add tools as needed
middleware=[
FilesystemFileSearchMiddleware(root_path="/workspace"),
],
)
```
"""
def __init__(
self,
*,
root_path: str,
use_ripgrep: bool = True,
max_file_size_mb: int = 10,
) -> None:
"""Initialize the search middleware.
Args:
root_path: Root directory to search.
use_ripgrep: Whether to use `ripgrep` for search.
Falls back to Python if `ripgrep` unavailable.
max_file_size_mb: Maximum file size to search in MB.
"""
self.root_path = Path(root_path).resolve()
self.use_ripgrep = use_ripgrep
self.max_file_size_bytes = max_file_size_mb * 1024 * 1024
# Create tool instances as closures that capture self
@tool
def glob_search(pattern: str, path: str = "/") -> str:
"""Fast file pattern matching tool that works with any codebase size.
Supports glob patterns like `**/*.js` or `src/**/*.ts`.
Returns matching file paths sorted by modification time.
Use this tool when you need to find files by name patterns.
Args:
pattern: The glob pattern to match files against.
path: The directory to search in. If not specified, searches from root.
Returns:
Newline-separated list of matching file paths, sorted by modification
time (most recently modified first). Returns `'No files found'` if no
matches.
"""
try:
base_full = self._validate_and_resolve_path(path)
except ValueError:
return "No files found"
if not base_full.exists() or not base_full.is_dir():
return "No files found"
# Use pathlib glob
matching: list[tuple[str, str]] = []
for match in base_full.glob(pattern):
if match.is_file():
# Convert to virtual path
virtual_path = "/" + str(match.relative_to(self.root_path))
stat = match.stat()
modified_at = datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat()
matching.append((virtual_path, modified_at))
if not matching:
return "No files found"
file_paths = [p for p, _ in matching]
return "\n".join(file_paths)
@tool
def grep_search(
pattern: str,
path: str = "/",
include: str | None = None,
output_mode: Literal["files_with_matches", "content", "count"] = "files_with_matches",
) -> str:
"""Fast content search tool that works with any codebase size.
Searches file contents using regular expressions. Supports full regex
syntax and filters files by pattern with the include parameter.
Args:
pattern: The regular expression pattern to search for in file contents.
path: The directory to search in. If not specified, searches from root.
include: File pattern to filter (e.g., `'*.js'`, `'*.{ts,tsx}'`).
output_mode: Output format:
- `'files_with_matches'`: Only file paths containing matches
- `'content'`: Matching lines with `file:line:content` format
- `'count'`: Count of matches per file
Returns:
Search results formatted according to `output_mode`.
Returns `'No matches found'` if no results.
"""
# Compile regex pattern (for validation)
try:
re.compile(pattern)
except re.error as e:
return f"Invalid regex pattern: {e}"
if include and not _is_valid_include_pattern(include):
return "Invalid include pattern"
# Try ripgrep first if enabled
results = None
if self.use_ripgrep:
with suppress(
FileNotFoundError,
subprocess.CalledProcessError,
subprocess.TimeoutExpired,
):
results = self._ripgrep_search(pattern, path, include)
# Python fallback if ripgrep failed or is disabled
if results is None:
results = self._python_search(pattern, path, include)
if not results:
return "No matches found"
# Format output based on mode
return self._format_grep_results(results, output_mode)
self.glob_search = glob_search
self.grep_search = grep_search
self.tools = [glob_search, grep_search]
def _validate_and_resolve_path(self, path: str) -> Path:
"""Validate and resolve a virtual path to filesystem path."""
# Normalize path
if not path.startswith("/"):
path = "/" + path
# Check for path traversal
if ".." in path or "~" in path:
msg = "Path traversal not allowed"
raise ValueError(msg)
# Convert virtual path to filesystem path
relative = path.lstrip("/")
full_path = (self.root_path / relative).resolve()
# Ensure path is within root
try:
full_path.relative_to(self.root_path)
except ValueError:
msg = f"Path outside root directory: {path}"
raise ValueError(msg) from None
return full_path
def _ripgrep_search(
self, pattern: str, base_path: str, include: str | None
) -> dict[str, list[tuple[int, str]]]:
"""Search using ripgrep subprocess."""
try:
base_full = self._validate_and_resolve_path(base_path)
except ValueError:
return {}
if not base_full.exists():
return {}
# Build ripgrep command
cmd = ["rg", "--json"]
if include:
# Convert glob pattern to ripgrep glob
cmd.extend(["--glob", include])
cmd.extend(["--", pattern, str(base_full)])
try:
result = subprocess.run( # noqa: S603
cmd,
capture_output=True,
text=True,
timeout=30,
check=False,
)
except (subprocess.TimeoutExpired, FileNotFoundError):
# Fallback to Python search if ripgrep unavailable or times out
return self._python_search(pattern, base_path, include)
# Parse ripgrep JSON output
results: dict[str, list[tuple[int, str]]] = {}
for line in result.stdout.splitlines():
try:
data = json.loads(line)
if data["type"] == "match":
path = data["data"]["path"]["text"]
# Convert to virtual path
virtual_path = "/" + str(Path(path).relative_to(self.root_path))
line_num = data["data"]["line_number"]
line_text = data["data"]["lines"]["text"].rstrip("\n")
if virtual_path not in results:
results[virtual_path] = []
results[virtual_path].append((line_num, line_text))
except (json.JSONDecodeError, KeyError):
continue
return results
def _python_search(
self, pattern: str, base_path: str, include: str | None
) -> dict[str, list[tuple[int, str]]]:
"""Search using Python regex (fallback)."""
try:
base_full = self._validate_and_resolve_path(base_path)
except ValueError:
return {}
if not base_full.exists():
return {}
regex = re.compile(pattern)
results: dict[str, list[tuple[int, str]]] = {}
# Walk directory tree
for file_path in base_full.rglob("*"):
if not file_path.is_file():
continue
# Check include filter
if include and not _match_include_pattern(file_path.name, include):
continue
# Skip files that are too large
if file_path.stat().st_size > self.max_file_size_bytes:
continue
try:
content = file_path.read_text()
except (UnicodeDecodeError, PermissionError):
continue
# Search content
for line_num, line in enumerate(content.splitlines(), 1):
if regex.search(line):
virtual_path = "/" + str(file_path.relative_to(self.root_path))
if virtual_path not in results:
results[virtual_path] = []
results[virtual_path].append((line_num, line))
return results
@staticmethod
def _format_grep_results(
results: dict[str, list[tuple[int, str]]],
output_mode: str,
) -> str:
"""Format grep results based on output mode."""
if output_mode == "files_with_matches":
# Just return file paths
return "\n".join(sorted(results.keys()))
if output_mode == "content":
# Return file:line:content format
lines = []
for file_path in sorted(results.keys()):
for line_num, line in results[file_path]:
lines.append(f"{file_path}:{line_num}:{line}")
return "\n".join(lines)
if output_mode == "count":
# Return file:count format
lines = []
for file_path in sorted(results.keys()):
count = len(results[file_path])
lines.append(f"{file_path}:{count}")
return "\n".join(lines)
# Default to files_with_matches
return "\n".join(sorted(results.keys()))
__all__ = [
"FilesystemFileSearchMiddleware",
]

View File

@@ -0,0 +1,387 @@
"""Human in the loop middleware."""
from typing import Any, Literal, Protocol
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
from langgraph.runtime import Runtime
from langgraph.types import interrupt
from typing_extensions import NotRequired, TypedDict
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ResponseT,
StateT,
)
class Action(TypedDict):
"""Represents an action with a name and args."""
name: str
"""The type or name of action being requested (e.g., `'add_numbers'`)."""
args: dict[str, Any]
"""Key-value pairs of args needed for the action (e.g., `{"a": 1, "b": 2}`)."""
class ActionRequest(TypedDict):
"""Represents an action request with a name, args, and description."""
name: str
"""The name of the action being requested."""
args: dict[str, Any]
"""Key-value pairs of args needed for the action (e.g., `{"a": 1, "b": 2}`)."""
description: NotRequired[str]
"""The description of the action to be reviewed."""
DecisionType = Literal["approve", "edit", "reject"]
class ReviewConfig(TypedDict):
"""Policy for reviewing a HITL request."""
action_name: str
"""Name of the action associated with this review configuration."""
allowed_decisions: list[DecisionType]
"""The decisions that are allowed for this request."""
args_schema: NotRequired[dict[str, Any]]
"""JSON schema for the args associated with the action, if edits are allowed."""
class HITLRequest(TypedDict):
"""Request for human feedback on a sequence of actions requested by a model."""
action_requests: list[ActionRequest]
"""A list of agent actions for human review."""
review_configs: list[ReviewConfig]
"""Review configuration for all possible actions."""
class ApproveDecision(TypedDict):
"""Response when a human approves the action."""
type: Literal["approve"]
"""The type of response when a human approves the action."""
class EditDecision(TypedDict):
"""Response when a human edits the action."""
type: Literal["edit"]
"""The type of response when a human edits the action."""
edited_action: Action
"""Edited action for the agent to perform.
Ex: for a tool call, a human reviewer can edit the tool name and args.
"""
class RejectDecision(TypedDict):
"""Response when a human rejects the action."""
type: Literal["reject"]
"""The type of response when a human rejects the action."""
message: NotRequired[str]
"""The message sent to the model explaining why the action was rejected."""
Decision = ApproveDecision | EditDecision | RejectDecision
class HITLResponse(TypedDict):
"""Response payload for a HITLRequest."""
decisions: list[Decision]
"""The decisions made by the human."""
class _DescriptionFactory(Protocol):
"""Callable that generates a description for a tool call."""
def __call__(
self, tool_call: ToolCall, state: AgentState[Any], runtime: Runtime[ContextT]
) -> str:
"""Generate a description for a tool call."""
...
class InterruptOnConfig(TypedDict):
"""Configuration for an action requiring human in the loop.
This is the configuration format used in the `HumanInTheLoopMiddleware.__init__`
method.
"""
allowed_decisions: list[DecisionType]
"""The decisions that are allowed for this action."""
description: NotRequired[str | _DescriptionFactory]
"""The description attached to the request for human input.
Can be either:
- A static string describing the approval request
- A callable that dynamically generates the description based on agent state,
runtime, and tool call information
Example:
```python
# Static string description
config = ToolConfig(
allowed_decisions=["approve", "reject"],
description="Please review this tool execution"
)
# Dynamic callable description
def format_tool_description(
tool_call: ToolCall,
state: AgentState,
runtime: Runtime[ContextT]
) -> str:
import json
return (
f"Tool: {tool_call['name']}\\n"
f"Arguments:\\n{json.dumps(tool_call['args'], indent=2)}"
)
config = InterruptOnConfig(
allowed_decisions=["approve", "edit", "reject"],
description=format_tool_description
)
```
"""
args_schema: NotRequired[dict[str, Any]]
"""JSON schema for the args associated with the action, if edits are allowed."""
class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]):
"""Human in the loop middleware."""
def __init__(
self,
interrupt_on: dict[str, bool | InterruptOnConfig],
*,
description_prefix: str = "Tool execution requires approval",
) -> None:
"""Initialize the human in the loop middleware.
Args:
interrupt_on: Mapping of tool name to allowed actions.
If a tool doesn't have an entry, it's auto-approved by default.
* `True` indicates all decisions are allowed: approve, edit, and reject.
* `False` indicates that the tool is auto-approved.
* `InterruptOnConfig` indicates the specific decisions allowed for this
tool.
The `InterruptOnConfig` can include a `description` field (`str` or
`Callable`) for custom formatting of the interrupt description.
description_prefix: The prefix to use when constructing action requests.
This is used to provide context about the tool call and the action being
requested.
Not used if a tool has a `description` in its `InterruptOnConfig`.
"""
super().__init__()
resolved_configs: dict[str, InterruptOnConfig] = {}
for tool_name, tool_config in interrupt_on.items():
if isinstance(tool_config, bool):
if tool_config is True:
resolved_configs[tool_name] = InterruptOnConfig(
allowed_decisions=["approve", "edit", "reject"]
)
elif tool_config.get("allowed_decisions"):
resolved_configs[tool_name] = tool_config
self.interrupt_on = resolved_configs
self.description_prefix = description_prefix
def _create_action_and_config(
self,
tool_call: ToolCall,
config: InterruptOnConfig,
state: AgentState[Any],
runtime: Runtime[ContextT],
) -> tuple[ActionRequest, ReviewConfig]:
"""Create an ActionRequest and ReviewConfig for a tool call."""
tool_name = tool_call["name"]
tool_args = tool_call["args"]
# Generate description using the description field (str or callable)
description_value = config.get("description")
if callable(description_value):
description = description_value(tool_call, state, runtime)
elif description_value is not None:
description = description_value
else:
description = f"{self.description_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
# Create ActionRequest with description
action_request = ActionRequest(
name=tool_name,
args=tool_args,
description=description,
)
# Create ReviewConfig
# eventually can get tool information and populate args_schema from there
review_config = ReviewConfig(
action_name=tool_name,
allowed_decisions=config["allowed_decisions"],
)
return action_request, review_config
@staticmethod
def _process_decision(
decision: Decision,
tool_call: ToolCall,
config: InterruptOnConfig,
) -> tuple[ToolCall | None, ToolMessage | None]:
"""Process a single decision and return the revised tool call and optional tool message."""
allowed_decisions = config["allowed_decisions"]
if decision["type"] == "approve" and "approve" in allowed_decisions:
return tool_call, None
if decision["type"] == "edit" and "edit" in allowed_decisions:
edited_action = decision["edited_action"]
return (
ToolCall(
type="tool_call",
name=edited_action["name"],
args=edited_action["args"],
id=tool_call["id"],
),
None,
)
if decision["type"] == "reject" and "reject" in allowed_decisions:
# Create a tool message with the human's text response
content = decision.get("message") or (
f"User rejected the tool call for `{tool_call['name']}` with id {tool_call['id']}"
)
tool_message = ToolMessage(
content=content,
name=tool_call["name"],
tool_call_id=tool_call["id"],
status="error",
)
return tool_call, tool_message
msg = (
f"Unexpected human decision: {decision}. "
f"Decision type '{decision.get('type')}' "
f"is not allowed for tool '{tool_call['name']}'. "
f"Expected one of {allowed_decisions} based on the tool's configuration."
)
raise ValueError(msg)
def after_model(
self, state: AgentState[Any], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Trigger interrupt flows for relevant tool calls after an `AIMessage`.
Args:
state: The current agent state.
runtime: The runtime context.
Returns:
Updated message with the revised tool calls.
Raises:
ValueError: If the number of human decisions does not match the number of
interrupted tool calls.
"""
messages = state["messages"]
if not messages:
return None
last_ai_msg = next((msg for msg in reversed(messages) if isinstance(msg, AIMessage)), None)
if not last_ai_msg or not last_ai_msg.tool_calls:
return None
# Create action requests and review configs for tools that need approval
action_requests: list[ActionRequest] = []
review_configs: list[ReviewConfig] = []
interrupt_indices: list[int] = []
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
if (config := self.interrupt_on.get(tool_call["name"])) is not None:
action_request, review_config = self._create_action_and_config(
tool_call, config, state, runtime
)
action_requests.append(action_request)
review_configs.append(review_config)
interrupt_indices.append(idx)
# If no interrupts needed, return early
if not action_requests:
return None
# Create single HITLRequest with all actions and configs
hitl_request = HITLRequest(
action_requests=action_requests,
review_configs=review_configs,
)
# Send interrupt and get response
decisions = interrupt(hitl_request)["decisions"]
# Validate that the number of decisions matches the number of interrupt tool calls
if (decisions_len := len(decisions)) != (interrupt_count := len(interrupt_indices)):
msg = (
f"Number of human decisions ({decisions_len}) does not match "
f"number of hanging tool calls ({interrupt_count})."
)
raise ValueError(msg)
# Process decisions and rebuild tool calls in original order
revised_tool_calls: list[ToolCall] = []
artificial_tool_messages: list[ToolMessage] = []
decision_idx = 0
for idx, tool_call in enumerate(last_ai_msg.tool_calls):
if idx in interrupt_indices:
# This was an interrupt tool call - process the decision
config = self.interrupt_on[tool_call["name"]]
decision = decisions[decision_idx]
decision_idx += 1
revised_tool_call, tool_message = self._process_decision(
decision, tool_call, config
)
if revised_tool_call is not None:
revised_tool_calls.append(revised_tool_call)
if tool_message:
artificial_tool_messages.append(tool_message)
else:
# This was auto-approved - keep original
revised_tool_calls.append(tool_call)
# Update the AI message to only include approved tool calls
last_ai_msg.tool_calls = revised_tool_calls
return {"messages": [last_ai_msg, *artificial_tool_messages]}
async def aafter_model(
self, state: AgentState[Any], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Async trigger interrupt flows for relevant tool calls after an `AIMessage`.
Args:
state: The current agent state.
runtime: The runtime context.
Returns:
Updated message with the revised tool calls.
"""
return self.after_model(state, runtime)

View File

@@ -0,0 +1,267 @@
"""Call tracking middleware for agents."""
from __future__ import annotations
from typing import TYPE_CHECKING, Annotated, Any, Literal
from langchain_core.messages import AIMessage
from langgraph.channels.untracked_value import UntrackedValue
from typing_extensions import NotRequired, override
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
PrivateStateAttr,
ResponseT,
hook_config,
)
if TYPE_CHECKING:
from langgraph.runtime import Runtime
class ModelCallLimitState(AgentState[ResponseT]):
"""State schema for `ModelCallLimitMiddleware`.
Extends `AgentState` with model call tracking fields.
Type Parameters:
ResponseT: The type of the structured response. Defaults to `Any`.
"""
thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
def _build_limit_exceeded_message(
thread_count: int,
run_count: int,
thread_limit: int | None,
run_limit: int | None,
) -> str:
"""Build a message indicating which limits were exceeded.
Args:
thread_count: Current thread model call count.
run_count: Current run model call count.
thread_limit: Thread model call limit (if set).
run_limit: Run model call limit (if set).
Returns:
A formatted message describing which limits were exceeded.
"""
exceeded_limits = []
if thread_limit is not None and thread_count >= thread_limit:
exceeded_limits.append(f"thread limit ({thread_count}/{thread_limit})")
if run_limit is not None and run_count >= run_limit:
exceeded_limits.append(f"run limit ({run_count}/{run_limit})")
return f"Model call limits exceeded: {', '.join(exceeded_limits)}"
class ModelCallLimitExceededError(Exception):
"""Exception raised when model call limits are exceeded.
This exception is raised when the configured exit behavior is `'error'` and either
the thread or run model call limit has been exceeded.
"""
def __init__(
self,
thread_count: int,
run_count: int,
thread_limit: int | None,
run_limit: int | None,
) -> None:
"""Initialize the exception with call count information.
Args:
thread_count: Current thread model call count.
run_count: Current run model call count.
thread_limit: Thread model call limit (if set).
run_limit: Run model call limit (if set).
"""
self.thread_count = thread_count
self.run_count = run_count
self.thread_limit = thread_limit
self.run_limit = run_limit
msg = _build_limit_exceeded_message(thread_count, run_count, thread_limit, run_limit)
super().__init__(msg)
class ModelCallLimitMiddleware(
AgentMiddleware[ModelCallLimitState[ResponseT], ContextT, ResponseT]
):
"""Tracks model call counts and enforces limits.
This middleware monitors the number of model calls made during agent execution
and can terminate the agent when specified limits are reached. It supports
both thread-level and run-level call counting with configurable exit behaviors.
Thread-level: The middleware tracks the number of model calls and persists
call count across multiple runs (invocations) of the agent.
Run-level: The middleware tracks the number of model calls made during a single
run (invocation) of the agent.
Example:
```python
from langchain.agents.middleware.call_tracking import ModelCallLimitMiddleware
from langchain.agents import create_agent
# Create middleware with limits
call_tracker = ModelCallLimitMiddleware(thread_limit=10, run_limit=5, exit_behavior="end")
agent = create_agent("openai:gpt-4o", middleware=[call_tracker])
# Agent will automatically jump to end when limits are exceeded
result = await agent.invoke({"messages": [HumanMessage("Help me with a task")]})
```
"""
state_schema = ModelCallLimitState # type: ignore[assignment]
def __init__(
self,
*,
thread_limit: int | None = None,
run_limit: int | None = None,
exit_behavior: Literal["end", "error"] = "end",
) -> None:
"""Initialize the call tracking middleware.
Args:
thread_limit: Maximum number of model calls allowed per thread.
`None` means no limit.
run_limit: Maximum number of model calls allowed per run.
`None` means no limit.
exit_behavior: What to do when limits are exceeded.
- `'end'`: Jump to the end of the agent execution and
inject an artificial AI message indicating that the limit was
exceeded.
- `'error'`: Raise a `ModelCallLimitExceededError`
Raises:
ValueError: If both limits are `None` or if `exit_behavior` is invalid.
"""
super().__init__()
if thread_limit is None and run_limit is None:
msg = "At least one limit must be specified (thread_limit or run_limit)"
raise ValueError(msg)
if exit_behavior not in {"end", "error"}:
msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
raise ValueError(msg)
self.thread_limit = thread_limit
self.run_limit = run_limit
self.exit_behavior = exit_behavior
@hook_config(can_jump_to=["end"])
@override
def before_model(
self, state: ModelCallLimitState[ResponseT], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Check model call limits before making a model call.
Args:
state: The current agent state containing call counts.
runtime: The langgraph runtime.
Returns:
If limits are exceeded and exit_behavior is `'end'`, returns
a `Command` to jump to the end with a limit exceeded message. Otherwise
returns `None`.
Raises:
ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
"""
thread_count = state.get("thread_model_call_count", 0)
run_count = state.get("run_model_call_count", 0)
# Check if any limits will be exceeded after the next call
thread_limit_exceeded = self.thread_limit is not None and thread_count >= self.thread_limit
run_limit_exceeded = self.run_limit is not None and run_count >= self.run_limit
if thread_limit_exceeded or run_limit_exceeded:
if self.exit_behavior == "error":
raise ModelCallLimitExceededError(
thread_count=thread_count,
run_count=run_count,
thread_limit=self.thread_limit,
run_limit=self.run_limit,
)
if self.exit_behavior == "end":
# Create a message indicating the limit was exceeded
limit_message = _build_limit_exceeded_message(
thread_count, run_count, self.thread_limit, self.run_limit
)
limit_ai_message = AIMessage(content=limit_message)
return {"jump_to": "end", "messages": [limit_ai_message]}
return None
@hook_config(can_jump_to=["end"])
async def abefore_model(
self,
state: ModelCallLimitState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Async check model call limits before making a model call.
Args:
state: The current agent state containing call counts.
runtime: The langgraph runtime.
Returns:
If limits are exceeded and exit_behavior is `'end'`, returns
a `Command` to jump to the end with a limit exceeded message. Otherwise
returns `None`.
Raises:
ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
"""
return self.before_model(state, runtime)
@override
def after_model(
self, state: ModelCallLimitState[ResponseT], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Increment model call counts after a model call.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
State updates with incremented call counts.
"""
return {
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
}
async def aafter_model(
self,
state: ModelCallLimitState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Async increment model call counts after a model call.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
State updates with incremented call counts.
"""
return self.after_model(state, runtime)

View File

@@ -0,0 +1,138 @@
"""Model fallback middleware for agents."""
from __future__ import annotations
from typing import TYPE_CHECKING
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
from langchain.chat_models import init_chat_model
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage
class ModelFallbackMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Automatic fallback to alternative models on errors.
Retries failed model calls with alternative models in sequence until
success or all models exhausted. Primary model specified in `create_agent`.
Example:
```python
from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
from langchain.agents import create_agent
fallback = ModelFallbackMiddleware(
"openai:gpt-4o-mini", # Try first on error
"anthropic:claude-sonnet-4-5-20250929", # Then this
)
agent = create_agent(
model="openai:gpt-4o", # Primary model
middleware=[fallback],
)
# If primary fails: tries gpt-4o-mini, then claude-sonnet-4-5-20250929
result = await agent.invoke({"messages": [HumanMessage("Hello")]})
```
"""
def __init__(
self,
first_model: str | BaseChatModel,
*additional_models: str | BaseChatModel,
) -> None:
"""Initialize model fallback middleware.
Args:
first_model: First fallback model (string name or instance).
*additional_models: Additional fallbacks in order.
"""
super().__init__()
# Initialize all fallback models
all_models = (first_model, *additional_models)
self.models: list[BaseChatModel] = []
for model in all_models:
if isinstance(model, str):
self.models.append(init_chat_model(model))
else:
self.models.append(model)
def wrap_model_call(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> ModelResponse[ResponseT] | AIMessage:
"""Try fallback models in sequence on errors.
Args:
request: Initial model request.
handler: Callback to execute the model.
Returns:
AIMessage from successful model call.
Raises:
Exception: If all models fail, re-raises last exception.
"""
# Try primary model first
last_exception: Exception
try:
return handler(request)
except Exception as e:
last_exception = e
# Try fallback models
for fallback_model in self.models:
try:
return handler(request.override(model=fallback_model))
except Exception as e:
last_exception = e
continue
raise last_exception
async def awrap_model_call(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> ModelResponse[ResponseT] | AIMessage:
"""Try fallback models in sequence on errors (async version).
Args:
request: Initial model request.
handler: Async callback to execute the model.
Returns:
AIMessage from successful model call.
Raises:
Exception: If all models fail, re-raises last exception.
"""
# Try primary model first
last_exception: Exception
try:
return await handler(request)
except Exception as e:
last_exception = e
# Try fallback models
for fallback_model in self.models:
try:
return await handler(request.override(model=fallback_model))
except Exception as e:
last_exception = e
continue
raise last_exception

View File

@@ -0,0 +1,312 @@
"""Model retry middleware for agents."""
from __future__ import annotations
import asyncio
import time
from typing import TYPE_CHECKING
from langchain_core.messages import AIMessage
from langchain.agents.middleware._retry import (
OnFailure,
RetryOn,
calculate_delay,
should_retry_exception,
validate_retry_params,
)
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
class ModelRetryMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Middleware that automatically retries failed model calls with configurable backoff.
Supports retrying on specific exceptions and exponential backoff.
Examples:
!!! example "Basic usage with default settings (2 retries, exponential backoff)"
```python
from langchain.agents import create_agent
from langchain.agents.middleware import ModelRetryMiddleware
agent = create_agent(model, tools=[search_tool], middleware=[ModelRetryMiddleware()])
```
!!! example "Retry specific exceptions only"
```python
from anthropic import RateLimitError
from openai import APITimeoutError
retry = ModelRetryMiddleware(
max_retries=4,
retry_on=(APITimeoutError, RateLimitError),
backoff_factor=1.5,
)
```
!!! example "Custom exception filtering"
```python
from anthropic import APIStatusError
def should_retry(exc: Exception) -> bool:
# Only retry on 5xx errors
if isinstance(exc, APIStatusError):
return 500 <= exc.status_code < 600
return False
retry = ModelRetryMiddleware(
max_retries=3,
retry_on=should_retry,
)
```
!!! example "Custom error handling"
```python
def format_error(exc: Exception) -> str:
return "Model temporarily unavailable. Please try again later."
retry = ModelRetryMiddleware(
max_retries=4,
on_failure=format_error,
)
```
!!! example "Constant backoff (no exponential growth)"
```python
retry = ModelRetryMiddleware(
max_retries=5,
backoff_factor=0.0, # No exponential growth
initial_delay=2.0, # Always wait 2 seconds
)
```
!!! example "Raise exception on failure"
```python
retry = ModelRetryMiddleware(
max_retries=2,
on_failure="error", # Re-raise exception instead of returning message
)
```
"""
def __init__(
self,
*,
max_retries: int = 2,
retry_on: RetryOn = (Exception,),
on_failure: OnFailure = "continue",
backoff_factor: float = 2.0,
initial_delay: float = 1.0,
max_delay: float = 60.0,
jitter: bool = True,
) -> None:
"""Initialize `ModelRetryMiddleware`.
Args:
max_retries: Maximum number of retry attempts after the initial call.
Must be `>= 0`.
retry_on: Either a tuple of exception types to retry on, or a callable
that takes an exception and returns `True` if it should be retried.
Default is to retry on all exceptions.
on_failure: Behavior when all retries are exhausted.
Options:
- `'continue'`: Return an `AIMessage` with error details,
allowing the agent to continue with an error response.
- `'error'`: Re-raise the exception, stopping agent execution.
- **Custom callable:** Function that takes the exception and returns a
string for the `AIMessage` content, allowing custom error
formatting.
backoff_factor: Multiplier for exponential backoff.
Each retry waits `initial_delay * (backoff_factor ** retry_number)`
seconds.
Set to `0.0` for constant delay.
initial_delay: Initial delay in seconds before first retry.
max_delay: Maximum delay in seconds between retries.
Caps exponential backoff growth.
jitter: Whether to add random jitter (`±25%`) to delay to avoid thundering herd.
Raises:
ValueError: If `max_retries < 0` or delays are negative.
"""
super().__init__()
# Validate parameters
validate_retry_params(max_retries, initial_delay, max_delay, backoff_factor)
self.max_retries = max_retries
self.tools = [] # No additional tools registered by this middleware
self.retry_on = retry_on
self.on_failure = on_failure
self.backoff_factor = backoff_factor
self.initial_delay = initial_delay
self.max_delay = max_delay
self.jitter = jitter
@staticmethod
def _format_failure_message(exc: Exception, attempts_made: int) -> AIMessage:
"""Format the failure message when retries are exhausted.
Args:
exc: The exception that caused the failure.
attempts_made: Number of attempts actually made.
Returns:
`AIMessage` with formatted error message.
"""
exc_type = type(exc).__name__
exc_msg = str(exc)
attempt_word = "attempt" if attempts_made == 1 else "attempts"
content = (
f"Model call failed after {attempts_made} {attempt_word} with {exc_type}: {exc_msg}"
)
return AIMessage(content=content)
def _handle_failure(self, exc: Exception, attempts_made: int) -> ModelResponse[ResponseT]:
"""Handle failure when all retries are exhausted.
Args:
exc: The exception that caused the failure.
attempts_made: Number of attempts actually made.
Returns:
`ModelResponse` with error details.
Raises:
Exception: If `on_failure` is `'error'`, re-raises the exception.
"""
if self.on_failure == "error":
raise exc
if callable(self.on_failure):
content = self.on_failure(exc)
ai_msg = AIMessage(content=content)
else:
ai_msg = self._format_failure_message(exc, attempts_made)
return ModelResponse(result=[ai_msg])
def wrap_model_call(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> ModelResponse[ResponseT] | AIMessage:
"""Intercept model execution and retry on failure.
Args:
request: Model request with model, messages, state, and runtime.
handler: Callable to execute the model (can be called multiple times).
Returns:
`ModelResponse` or `AIMessage` (the final result).
Raises:
RuntimeError: If the retry loop completes without returning. (This should not happen.)
"""
# Initial attempt + retries
for attempt in range(self.max_retries + 1):
try:
return handler(request)
except Exception as exc:
attempts_made = attempt + 1 # attempt is 0-indexed
# Check if we should retry this exception
if not should_retry_exception(exc, self.retry_on):
# Exception is not retryable, handle failure immediately
return self._handle_failure(exc, attempts_made)
# Check if we have more retries left
if attempt < self.max_retries:
# Calculate and apply backoff delay
delay = calculate_delay(
attempt,
backoff_factor=self.backoff_factor,
initial_delay=self.initial_delay,
max_delay=self.max_delay,
jitter=self.jitter,
)
if delay > 0:
time.sleep(delay)
# Continue to next retry
else:
# No more retries, handle failure
return self._handle_failure(exc, attempts_made)
# Unreachable: loop always returns via handler success or _handle_failure
msg = "Unexpected: retry loop completed without returning"
raise RuntimeError(msg)
async def awrap_model_call(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> ModelResponse[ResponseT] | AIMessage:
"""Intercept and control async model execution with retry logic.
Args:
request: Model request with model, messages, state, and runtime.
handler: Async callable to execute the model and returns `ModelResponse`.
Returns:
`ModelResponse` or `AIMessage` (the final result).
Raises:
RuntimeError: If the retry loop completes without returning. (This should not happen.)
"""
# Initial attempt + retries
for attempt in range(self.max_retries + 1):
try:
return await handler(request)
except Exception as exc:
attempts_made = attempt + 1 # attempt is 0-indexed
# Check if we should retry this exception
if not should_retry_exception(exc, self.retry_on):
# Exception is not retryable, handle failure immediately
return self._handle_failure(exc, attempts_made)
# Check if we have more retries left
if attempt < self.max_retries:
# Calculate and apply backoff delay
delay = calculate_delay(
attempt,
backoff_factor=self.backoff_factor,
initial_delay=self.initial_delay,
max_delay=self.max_delay,
jitter=self.jitter,
)
if delay > 0:
await asyncio.sleep(delay)
# Continue to next retry
else:
# No more retries, handle failure
return self._handle_failure(exc, attempts_made)
# Unreachable: loop always returns via handler success or _handle_failure
msg = "Unexpected: retry loop completed without returning"
raise RuntimeError(msg)

View File

@@ -0,0 +1,376 @@
"""PII detection and handling middleware for agents."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
from typing_extensions import override
from langchain.agents.middleware._redaction import (
PIIDetectionError,
PIIMatch,
RedactionRule,
ResolvedRedactionRule,
apply_strategy,
detect_credit_card,
detect_email,
detect_ip,
detect_mac_address,
detect_url,
)
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ResponseT,
hook_config,
)
if TYPE_CHECKING:
from collections.abc import Callable
from langgraph.runtime import Runtime
class PIIMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Detect and handle Personally Identifiable Information (PII) in conversations.
This middleware detects common PII types and applies configurable strategies
to handle them. It can detect emails, credit cards, IP addresses, MAC addresses, and
URLs in both user input and agent output.
Built-in PII types:
- `email`: Email addresses
- `credit_card`: Credit card numbers (validated with Luhn algorithm)
- `ip`: IP addresses (validated with stdlib)
- `mac_address`: MAC addresses
- `url`: URLs (both `http`/`https` and bare URLs)
Strategies:
- `block`: Raise an exception when PII is detected
- `redact`: Replace PII with `[REDACTED_TYPE]` placeholders
- `mask`: Partially mask PII (e.g., `****-****-****-1234` for credit card)
- `hash`: Replace PII with deterministic hash (e.g., `<email_hash:a1b2c3d4>`)
Strategy Selection Guide:
| Strategy | Preserves Identity? | Best For |
| -------- | ------------------- | --------------------------------------- |
| `block` | N/A | Avoid PII completely |
| `redact` | No | General compliance, log sanitization |
| `mask` | No | Human readability, customer service UIs |
| `hash` | Yes (pseudonymous) | Analytics, debugging |
Example:
```python
from langchain.agents.middleware import PIIMiddleware
from langchain.agents import create_agent
# Redact all emails in user input
agent = create_agent(
"openai:gpt-5",
middleware=[
PIIMiddleware("email", strategy="redact"),
],
)
# Use different strategies for different PII types
agent = create_agent(
"openai:gpt-4o",
middleware=[
PIIMiddleware("credit_card", strategy="mask"),
PIIMiddleware("url", strategy="redact"),
PIIMiddleware("ip", strategy="hash"),
],
)
# Custom PII type with regex
agent = create_agent(
"openai:gpt-5",
middleware=[
PIIMiddleware("api_key", detector=r"sk-[a-zA-Z0-9]{32}", strategy="block"),
],
)
```
"""
def __init__(
self,
# From a typing point of view, the literals are covered by 'str'.
# Nonetheless, we escape PYI051 to keep hints and autocompletion for the caller.
pii_type: Literal["email", "credit_card", "ip", "mac_address", "url"] | str, # noqa: PYI051
*,
strategy: Literal["block", "redact", "mask", "hash"] = "redact",
detector: Callable[[str], list[PIIMatch]] | str | None = None,
apply_to_input: bool = True,
apply_to_output: bool = False,
apply_to_tool_results: bool = False,
) -> None:
"""Initialize the PII detection middleware.
Args:
pii_type: Type of PII to detect.
Can be a built-in type (`email`, `credit_card`, `ip`, `mac_address`,
`url`) or a custom type name.
strategy: How to handle detected PII.
Options:
* `block`: Raise `PIIDetectionError` when PII is detected
* `redact`: Replace with `[REDACTED_TYPE]` placeholders
* `mask`: Partially mask PII (show last few characters)
* `hash`: Replace with deterministic hash (format: `<type_hash:digest>`)
detector: Custom detector function or regex pattern.
* If `Callable`: Function that takes content string and returns
list of `PIIMatch` objects
* If `str`: Regex pattern to match PII
* If `None`: Uses built-in detector for the `pii_type`
apply_to_input: Whether to check user messages before model call.
apply_to_output: Whether to check AI messages after model call.
apply_to_tool_results: Whether to check tool result messages after tool execution.
Raises:
ValueError: If `pii_type` is not built-in and no detector is provided.
"""
super().__init__()
self.apply_to_input = apply_to_input
self.apply_to_output = apply_to_output
self.apply_to_tool_results = apply_to_tool_results
self._resolved_rule: ResolvedRedactionRule = RedactionRule(
pii_type=pii_type,
strategy=strategy,
detector=detector,
).resolve()
self.pii_type = self._resolved_rule.pii_type
self.strategy = self._resolved_rule.strategy
self.detector = self._resolved_rule.detector
@property
def name(self) -> str:
"""Name of the middleware."""
return f"{self.__class__.__name__}[{self.pii_type}]"
def _process_content(self, content: str) -> tuple[str, list[PIIMatch]]:
"""Apply the configured redaction rule to the provided content."""
matches = self.detector(content)
if not matches:
return content, []
sanitized = apply_strategy(content, matches, self.strategy)
return sanitized, matches
@hook_config(can_jump_to=["end"])
@override
def before_model(
self,
state: AgentState[Any],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Check user messages and tool results for PII before model invocation.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
Updated state with PII handled according to strategy, or `None` if no PII
detected.
Raises:
PIIDetectionError: If PII is detected and strategy is `'block'`.
"""
if not self.apply_to_input and not self.apply_to_tool_results:
return None
messages = state["messages"]
if not messages:
return None
new_messages = list(messages)
any_modified = False
# Check user input if enabled
if self.apply_to_input:
# Get last user message
last_user_msg = None
last_user_idx = None
for i in range(len(messages) - 1, -1, -1):
if isinstance(messages[i], HumanMessage):
last_user_msg = messages[i]
last_user_idx = i
break
if last_user_idx is not None and last_user_msg and last_user_msg.content:
# Detect PII in message content
content = str(last_user_msg.content)
new_content, matches = self._process_content(content)
if matches:
updated_message: AnyMessage = HumanMessage(
content=new_content,
id=last_user_msg.id,
name=last_user_msg.name,
)
new_messages[last_user_idx] = updated_message
any_modified = True
# Check tool results if enabled
if self.apply_to_tool_results:
# Find the last AIMessage, then process all `ToolMessage` objects after it
last_ai_idx = None
for i in range(len(messages) - 1, -1, -1):
if isinstance(messages[i], AIMessage):
last_ai_idx = i
break
if last_ai_idx is not None:
# Get all tool messages after the last AI message
for i in range(last_ai_idx + 1, len(messages)):
msg = messages[i]
if isinstance(msg, ToolMessage):
tool_msg = msg
if not tool_msg.content:
continue
content = str(tool_msg.content)
new_content, matches = self._process_content(content)
if not matches:
continue
# Create updated tool message
updated_message = ToolMessage(
content=new_content,
id=tool_msg.id,
name=tool_msg.name,
tool_call_id=tool_msg.tool_call_id,
)
new_messages[i] = updated_message
any_modified = True
if any_modified:
return {"messages": new_messages}
return None
@hook_config(can_jump_to=["end"])
async def abefore_model(
self,
state: AgentState[Any],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Async check user messages and tool results for PII before model invocation.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
Updated state with PII handled according to strategy, or `None` if no PII
detected.
Raises:
PIIDetectionError: If PII is detected and strategy is `'block'`.
"""
return self.before_model(state, runtime)
@override
def after_model(
self,
state: AgentState[Any],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Check AI messages for PII after model invocation.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
Updated state with PII handled according to strategy, or None if no PII
detected.
Raises:
PIIDetectionError: If PII is detected and strategy is `'block'`.
"""
if not self.apply_to_output:
return None
messages = state["messages"]
if not messages:
return None
# Get last AI message
last_ai_msg = None
last_ai_idx = None
for i in range(len(messages) - 1, -1, -1):
msg = messages[i]
if isinstance(msg, AIMessage):
last_ai_msg = msg
last_ai_idx = i
break
if last_ai_idx is None or not last_ai_msg or not last_ai_msg.content:
return None
# Detect PII in message content
content = str(last_ai_msg.content)
new_content, matches = self._process_content(content)
if not matches:
return None
# Create updated message
updated_message = AIMessage(
content=new_content,
id=last_ai_msg.id,
name=last_ai_msg.name,
tool_calls=last_ai_msg.tool_calls,
)
# Return updated messages
new_messages = list(messages)
new_messages[last_ai_idx] = updated_message
return {"messages": new_messages}
async def aafter_model(
self,
state: AgentState[Any],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Async check AI messages for PII after model invocation.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
Updated state with PII handled according to strategy, or None if no PII
detected.
Raises:
PIIDetectionError: If PII is detected and strategy is `'block'`.
"""
return self.after_model(state, runtime)
__all__ = [
"PIIDetectionError",
"PIIMatch",
"PIIMiddleware",
"detect_credit_card",
"detect_email",
"detect_ip",
"detect_mac_address",
"detect_url",
]

View File

@@ -0,0 +1,882 @@
"""Middleware that exposes a persistent shell tool to agents."""
from __future__ import annotations
import contextlib
import logging
import os
import queue
import signal
import subprocess
import tempfile
import threading
import time
import uuid
import weakref
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
from langchain_core.messages import ToolMessage
from langchain_core.runnables import run_in_executor
from langchain_core.tools.base import ToolException
from langgraph.channels.untracked_value import UntrackedValue
from pydantic import BaseModel, model_validator
from pydantic.json_schema import SkipJsonSchema
from typing_extensions import NotRequired, override
from langchain.agents.middleware._execution import (
SHELL_TEMP_PREFIX,
BaseExecutionPolicy,
CodexSandboxExecutionPolicy,
DockerExecutionPolicy,
HostExecutionPolicy,
)
from langchain.agents.middleware._redaction import (
PIIDetectionError,
PIIMatch,
RedactionRule,
ResolvedRedactionRule,
)
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
PrivateStateAttr,
ResponseT,
)
from langchain.tools import ToolRuntime, tool
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from langgraph.runtime import Runtime
LOGGER = logging.getLogger(__name__)
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
DEFAULT_TOOL_DESCRIPTION = (
"Execute a shell command inside a persistent session. Before running a command, "
"confirm the working directory is correct (e.g., inspect with `ls` or `pwd`) and ensure "
"any parent directories exist. Prefer absolute paths and quote paths containing spaces, "
'such as `cd "/path/with spaces"`. Chain multiple commands with `&&` or `;` instead of '
"embedding newlines. Avoid unnecessary `cd` usage unless explicitly required so the "
"session remains stable. Outputs may be truncated when they become very large, and long "
"running commands will be terminated once their configured timeout elapses."
)
SHELL_TOOL_NAME = "shell"
def _cleanup_resources(
session: ShellSession, tempdir: tempfile.TemporaryDirectory[str] | None, timeout: float
) -> None:
with contextlib.suppress(Exception):
session.stop(timeout)
if tempdir is not None:
with contextlib.suppress(Exception):
tempdir.cleanup()
@dataclass
class _SessionResources:
"""Container for per-run shell resources."""
session: ShellSession
tempdir: tempfile.TemporaryDirectory[str] | None
policy: BaseExecutionPolicy
finalizer: weakref.finalize = field(init=False, repr=False) # type: ignore[type-arg]
def __post_init__(self) -> None:
self.finalizer = weakref.finalize(
self,
_cleanup_resources,
self.session,
self.tempdir,
self.policy.termination_timeout,
)
class ShellToolState(AgentState[ResponseT]):
"""Agent state extension for tracking shell session resources.
Type Parameters:
ResponseT: The type of the structured response. Defaults to `Any`.
"""
shell_session_resources: NotRequired[
Annotated[_SessionResources | None, UntrackedValue, PrivateStateAttr]
]
@dataclass(frozen=True)
class CommandExecutionResult:
"""Structured result from command execution."""
output: str
exit_code: int | None
timed_out: bool
truncated_by_lines: bool
truncated_by_bytes: bool
total_lines: int
total_bytes: int
class ShellSession:
"""Persistent shell session that supports sequential command execution."""
def __init__(
self,
workspace: Path,
policy: BaseExecutionPolicy,
command: tuple[str, ...],
environment: Mapping[str, str],
) -> None:
self._workspace = workspace
self._policy = policy
self._command = command
self._environment = dict(environment)
self._process: subprocess.Popen[str] | None = None
self._stdin: Any = None
self._queue: queue.Queue[tuple[str, str | None]] = queue.Queue()
self._lock = threading.Lock()
self._stdout_thread: threading.Thread | None = None
self._stderr_thread: threading.Thread | None = None
self._terminated = False
def start(self) -> None:
"""Start the shell subprocess and reader threads.
Raises:
RuntimeError: If the shell session pipes cannot be initialized.
"""
if self._process and self._process.poll() is None:
return
self._process = self._policy.spawn(
workspace=self._workspace,
env=self._environment,
command=self._command,
)
if (
self._process.stdin is None
or self._process.stdout is None
or self._process.stderr is None
):
msg = "Failed to initialize shell session pipes."
raise RuntimeError(msg)
self._stdin = self._process.stdin
self._terminated = False
self._queue = queue.Queue()
self._stdout_thread = threading.Thread(
target=self._enqueue_stream,
args=(self._process.stdout, "stdout"),
daemon=True,
)
self._stderr_thread = threading.Thread(
target=self._enqueue_stream,
args=(self._process.stderr, "stderr"),
daemon=True,
)
self._stdout_thread.start()
self._stderr_thread.start()
def restart(self) -> None:
"""Restart the shell process."""
self.stop(self._policy.termination_timeout)
self.start()
def stop(self, timeout: float) -> None:
"""Stop the shell subprocess."""
if not self._process:
return
if self._process.poll() is None and not self._terminated:
try:
self._stdin.write("exit\n")
self._stdin.flush()
except (BrokenPipeError, OSError):
LOGGER.debug(
"Failed to write exit command; terminating shell session.",
exc_info=True,
)
try:
if self._process.wait(timeout=timeout) is None:
self._kill_process()
except subprocess.TimeoutExpired:
self._kill_process()
finally:
self._terminated = True
with contextlib.suppress(Exception):
self._stdin.close()
self._process = None
def execute(self, command: str, *, timeout: float) -> CommandExecutionResult:
"""Execute a command in the persistent shell."""
if not self._process or self._process.poll() is not None:
msg = "Shell session is not running."
raise RuntimeError(msg)
marker = f"{_DONE_MARKER_PREFIX}{uuid.uuid4().hex}"
deadline = time.monotonic() + timeout
with self._lock:
self._drain_queue()
payload = command if command.endswith("\n") else f"{command}\n"
try:
self._stdin.write(payload)
self._stdin.write(f"printf '{marker} %s\\n' $?\n")
self._stdin.flush()
except (BrokenPipeError, OSError):
# The shell exited before we could write the marker command.
# This happens when commands like 'exit 1' terminate the shell.
return self._collect_output_after_exit(deadline)
return self._collect_output(marker, deadline, timeout)
def _collect_output(
self,
marker: str,
deadline: float,
timeout: float,
) -> CommandExecutionResult:
collected: list[str] = []
total_lines = 0
total_bytes = 0
truncated_by_lines = False
truncated_by_bytes = False
exit_code: int | None = None
timed_out = False
while True:
remaining = deadline - time.monotonic()
if remaining <= 0:
timed_out = True
break
try:
source, data = self._queue.get(timeout=remaining)
except queue.Empty:
timed_out = True
break
if data is None:
continue
if source == "stdout" and data.startswith(marker):
_, _, status = data.partition(" ")
exit_code = self._safe_int(status.strip())
# Drain any remaining stderr that may have arrived concurrently.
# The stderr reader thread runs independently, so output might
# still be in flight when the stdout marker arrives.
self._drain_remaining_stderr(collected, deadline)
break
total_lines += 1
encoded = data.encode("utf-8", "replace")
total_bytes += len(encoded)
if total_lines > self._policy.max_output_lines:
truncated_by_lines = True
continue
if (
self._policy.max_output_bytes is not None
and total_bytes > self._policy.max_output_bytes
):
truncated_by_bytes = True
continue
if source == "stderr":
stripped = data.rstrip("\n")
collected.append(f"[stderr] {stripped}")
if data.endswith("\n"):
collected.append("\n")
else:
collected.append(data)
if timed_out:
LOGGER.warning(
"Command timed out after %.2f seconds; restarting shell session.",
timeout,
)
self.restart()
return CommandExecutionResult(
output="",
exit_code=None,
timed_out=True,
truncated_by_lines=truncated_by_lines,
truncated_by_bytes=truncated_by_bytes,
total_lines=total_lines,
total_bytes=total_bytes,
)
output = "".join(collected)
return CommandExecutionResult(
output=output,
exit_code=exit_code,
timed_out=False,
truncated_by_lines=truncated_by_lines,
truncated_by_bytes=truncated_by_bytes,
total_lines=total_lines,
total_bytes=total_bytes,
)
def _collect_output_after_exit(self, deadline: float) -> CommandExecutionResult:
"""Collect output after the shell exited unexpectedly.
Called when a `BrokenPipeError` occurs while writing to stdin, indicating the
shell process terminated (e.g., due to an 'exit' command).
Args:
deadline: Absolute time by which collection must complete.
Returns:
`CommandExecutionResult` with collected output and the process exit code.
"""
collected: list[str] = []
total_lines = 0
total_bytes = 0
truncated_by_lines = False
truncated_by_bytes = False
# Give reader threads a brief moment to enqueue any remaining output.
drain_timeout = 0.1
drain_deadline = min(time.monotonic() + drain_timeout, deadline)
while True:
remaining = drain_deadline - time.monotonic()
if remaining <= 0:
break
try:
source, data = self._queue.get(timeout=remaining)
except queue.Empty:
break
if data is None:
# EOF marker from a reader thread; continue draining.
continue
total_lines += 1
encoded = data.encode("utf-8", "replace")
total_bytes += len(encoded)
if total_lines > self._policy.max_output_lines:
truncated_by_lines = True
continue
if (
self._policy.max_output_bytes is not None
and total_bytes > self._policy.max_output_bytes
):
truncated_by_bytes = True
continue
if source == "stderr":
stripped = data.rstrip("\n")
collected.append(f"[stderr] {stripped}")
if data.endswith("\n"):
collected.append("\n")
else:
collected.append(data)
# Get exit code from the terminated process.
exit_code: int | None = None
if self._process:
exit_code = self._process.poll()
output = "".join(collected)
return CommandExecutionResult(
output=output,
exit_code=exit_code,
timed_out=False,
truncated_by_lines=truncated_by_lines,
truncated_by_bytes=truncated_by_bytes,
total_lines=total_lines,
total_bytes=total_bytes,
)
def _kill_process(self) -> None:
if not self._process:
return
if hasattr(os, "killpg"):
with contextlib.suppress(ProcessLookupError):
os.killpg(os.getpgid(self._process.pid), signal.SIGKILL)
else: # pragma: no cover
with contextlib.suppress(ProcessLookupError):
self._process.kill()
def _enqueue_stream(self, stream: Any, label: str) -> None:
for line in iter(stream.readline, ""):
self._queue.put((label, line))
self._queue.put((label, None))
def _drain_queue(self) -> None:
while True:
try:
self._queue.get_nowait()
except queue.Empty:
break
def _drain_remaining_stderr(
self, collected: list[str], deadline: float, drain_timeout: float = 0.05
) -> None:
"""Drain any stderr output that arrived concurrently with the done marker.
The stdout and stderr reader threads run independently. When a command writes to
stderr just before exiting, the stderr output may still be in transit when the
done marker arrives on stdout. This method briefly polls the queue to capture
such output.
Args:
collected: The list to append collected stderr lines to.
deadline: The original command deadline (used as an upper bound).
drain_timeout: Maximum time to wait for additional stderr output.
"""
drain_deadline = min(time.monotonic() + drain_timeout, deadline)
while True:
remaining = drain_deadline - time.monotonic()
if remaining <= 0:
break
try:
source, data = self._queue.get(timeout=remaining)
except queue.Empty:
break
if data is None or source != "stderr":
continue
stripped = data.rstrip("\n")
collected.append(f"[stderr] {stripped}")
if data.endswith("\n"):
collected.append("\n")
@staticmethod
def _safe_int(value: str) -> int | None:
with contextlib.suppress(ValueError):
return int(value)
return None
class _ShellToolInput(BaseModel):
"""Input schema for the persistent shell tool."""
command: str | None = None
"""The shell command to execute."""
restart: bool | None = None
"""Whether to restart the shell session."""
runtime: Annotated[Any, SkipJsonSchema()] = None
"""The runtime for the shell tool.
Included as a workaround at the moment bc args_schema doesn't work with
injected ToolRuntime.
"""
@model_validator(mode="after")
def validate_payload(self) -> _ShellToolInput:
if self.command is None and not self.restart:
msg = "Shell tool requires either 'command' or 'restart'."
raise ValueError(msg)
if self.command is not None and self.restart:
msg = "Specify only one of 'command' or 'restart'."
raise ValueError(msg)
return self
class ShellToolMiddleware(AgentMiddleware[ShellToolState[ResponseT], ContextT, ResponseT]):
"""Middleware that registers a persistent shell tool for agents.
The middleware exposes a single long-lived shell session. Use the execution policy
to match your deployment's security posture:
* `HostExecutionPolicy` full host access; best for trusted environments where the
agent already runs inside a container or VM that provides isolation.
* `CodexSandboxExecutionPolicy` reuses the Codex CLI sandbox for additional
syscall/filesystem restrictions when the CLI is available.
* `DockerExecutionPolicy` launches a separate Docker container for each agent run,
providing harder isolation, optional read-only root filesystems, and user
remapping.
When no policy is provided the middleware defaults to `HostExecutionPolicy`.
"""
state_schema = ShellToolState # type: ignore[assignment]
def __init__(
self,
workspace_root: str | Path | None = None,
*,
startup_commands: tuple[str, ...] | list[str] | str | None = None,
shutdown_commands: tuple[str, ...] | list[str] | str | None = None,
execution_policy: BaseExecutionPolicy | None = None,
redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
tool_description: str | None = None,
tool_name: str = SHELL_TOOL_NAME,
shell_command: Sequence[str] | str | None = None,
env: Mapping[str, Any] | None = None,
) -> None:
"""Initialize an instance of `ShellToolMiddleware`.
Args:
workspace_root: Base directory for the shell session.
If omitted, a temporary directory is created when the agent starts and
removed when it ends.
startup_commands: Optional commands executed sequentially after the session
starts.
shutdown_commands: Optional commands executed before the session shuts down.
execution_policy: Execution policy controlling timeouts, output limits, and
resource configuration.
Defaults to `HostExecutionPolicy` for native execution.
redaction_rules: Optional redaction rules to sanitize command output before
returning it to the model.
!!! warning
Redaction rules are applied post execution and do not prevent
exfiltration of secrets or sensitive data when using
`HostExecutionPolicy`.
tool_description: Optional override for the registered shell tool
description.
tool_name: Name for the registered shell tool.
Defaults to `"shell"`.
shell_command: Optional shell executable (string) or argument sequence used
to launch the persistent session.
Defaults to an implementation-defined bash command.
env: Optional environment variables to supply to the shell session.
Values are coerced to strings before command execution. If omitted, the
session inherits the parent process environment.
"""
super().__init__()
self._workspace_root = Path(workspace_root) if workspace_root else None
self._tool_name = tool_name
self._shell_command = self._normalize_shell_command(shell_command)
self._environment = self._normalize_env(env)
if execution_policy is not None:
self._execution_policy = execution_policy
else:
self._execution_policy = HostExecutionPolicy()
rules = redaction_rules or ()
self._redaction_rules: tuple[ResolvedRedactionRule, ...] = tuple(
rule.resolve() for rule in rules
)
self._startup_commands = self._normalize_commands(startup_commands)
self._shutdown_commands = self._normalize_commands(shutdown_commands)
# Create a proper tool that executes directly (no interception needed)
description = tool_description or DEFAULT_TOOL_DESCRIPTION
@tool(self._tool_name, args_schema=_ShellToolInput, description=description)
def shell_tool(
*,
runtime: ToolRuntime[None, ShellToolState],
command: str | None = None,
restart: bool = False,
) -> ToolMessage | str:
resources = self._get_or_create_resources(runtime.state)
return self._run_shell_tool(
resources,
{"command": command, "restart": restart},
tool_call_id=runtime.tool_call_id,
)
self._shell_tool = shell_tool
self.tools = [self._shell_tool]
@staticmethod
def _normalize_commands(
commands: tuple[str, ...] | list[str] | str | None,
) -> tuple[str, ...]:
if commands is None:
return ()
if isinstance(commands, str):
return (commands,)
return tuple(commands)
@staticmethod
def _normalize_shell_command(
shell_command: Sequence[str] | str | None,
) -> tuple[str, ...]:
if shell_command is None:
return ("/bin/bash",)
normalized = (shell_command,) if isinstance(shell_command, str) else tuple(shell_command)
if not normalized:
msg = "Shell command must contain at least one argument."
raise ValueError(msg)
return normalized
@staticmethod
def _normalize_env(env: Mapping[str, Any] | None) -> dict[str, str] | None:
if env is None:
return None
normalized: dict[str, str] = {}
for key, value in env.items():
if not isinstance(key, str):
msg = "Environment variable names must be strings." # type: ignore[unreachable]
raise TypeError(msg)
normalized[key] = str(value)
return normalized
@override
def before_agent(
self, state: ShellToolState[ResponseT], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Start the shell session and run startup commands.
Args:
state: The current agent state.
runtime: The runtime context.
Returns:
Shell session resources to be stored in the agent state.
"""
resources = self._get_or_create_resources(state)
return {"shell_session_resources": resources}
async def abefore_agent(
self, state: ShellToolState[ResponseT], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Async start the shell session and run startup commands.
Args:
state: The current agent state.
runtime: The runtime context.
Returns:
Shell session resources to be stored in the agent state.
"""
return await run_in_executor(None, self.before_agent, state, runtime)
@override
def after_agent(self, state: ShellToolState[ResponseT], runtime: Runtime[ContextT]) -> None:
"""Run shutdown commands and release resources when an agent completes."""
resources = state.get("shell_session_resources")
if not isinstance(resources, _SessionResources):
# Resources were never created, nothing to clean up
return
try:
self._run_shutdown_commands(resources.session)
finally:
resources.finalizer()
async def aafter_agent(
self, state: ShellToolState[ResponseT], runtime: Runtime[ContextT]
) -> None:
"""Async run shutdown commands and release resources when an agent completes."""
return self.after_agent(state, runtime)
def _get_or_create_resources(self, state: ShellToolState[ResponseT]) -> _SessionResources:
"""Get existing resources from state or create new ones if they don't exist.
This method enables resumability by checking if resources already exist in the state
(e.g., after an interrupt), and only creating new resources if they're not present.
Args:
state: The agent state which may contain shell session resources.
Returns:
Session resources, either retrieved from state or newly created.
"""
resources = state.get("shell_session_resources")
if isinstance(resources, _SessionResources):
return resources
new_resources = self._create_resources()
# Cast needed to make state dict-like for mutation
cast("dict[str, Any]", state)["shell_session_resources"] = new_resources
return new_resources
def _create_resources(self) -> _SessionResources:
workspace = self._workspace_root
tempdir: tempfile.TemporaryDirectory[str] | None = None
if workspace is None:
tempdir = tempfile.TemporaryDirectory(prefix=SHELL_TEMP_PREFIX)
workspace_path = Path(tempdir.name)
else:
workspace_path = workspace
workspace_path.mkdir(parents=True, exist_ok=True)
session = ShellSession(
workspace_path,
self._execution_policy,
self._shell_command,
self._environment or {},
)
try:
session.start()
LOGGER.info("Started shell session in %s", workspace_path)
self._run_startup_commands(session)
except BaseException:
LOGGER.exception("Starting shell session failed; cleaning up resources.")
session.stop(self._execution_policy.termination_timeout)
if tempdir is not None:
tempdir.cleanup()
raise
return _SessionResources(session=session, tempdir=tempdir, policy=self._execution_policy)
def _run_startup_commands(self, session: ShellSession) -> None:
if not self._startup_commands:
return
for command in self._startup_commands:
result = session.execute(command, timeout=self._execution_policy.startup_timeout)
if result.timed_out or (result.exit_code not in {0, None}):
msg = f"Startup command '{command}' failed with exit code {result.exit_code}"
raise RuntimeError(msg)
def _run_shutdown_commands(self, session: ShellSession) -> None:
if not self._shutdown_commands:
return
for command in self._shutdown_commands:
try:
result = session.execute(command, timeout=self._execution_policy.command_timeout)
if result.timed_out:
LOGGER.warning("Shutdown command '%s' timed out.", command)
elif result.exit_code not in {0, None}:
LOGGER.warning(
"Shutdown command '%s' exited with %s.", command, result.exit_code
)
except (RuntimeError, ToolException, OSError) as exc:
LOGGER.warning(
"Failed to run shutdown command '%s': %s", command, exc, exc_info=True
)
def _apply_redactions(self, content: str) -> tuple[str, dict[str, list[PIIMatch]]]:
"""Apply configured redaction rules to command output."""
matches_by_type: dict[str, list[PIIMatch]] = {}
updated = content
for rule in self._redaction_rules:
updated, matches = rule.apply(updated)
if matches:
matches_by_type.setdefault(rule.pii_type, []).extend(matches)
return updated, matches_by_type
def _run_shell_tool(
self,
resources: _SessionResources,
payload: dict[str, Any],
*,
tool_call_id: str | None,
) -> Any:
session = resources.session
if payload.get("restart"):
LOGGER.info("Restarting shell session on request.")
try:
session.restart()
self._run_startup_commands(session)
except BaseException as err:
LOGGER.exception("Restarting shell session failed; session remains unavailable.")
msg = "Failed to restart shell session."
raise ToolException(msg) from err
message = "Shell session restarted."
return self._format_tool_message(message, tool_call_id, status="success")
command = payload.get("command")
if not command or not isinstance(command, str):
msg = "Shell tool expects a 'command' string when restart is not requested."
raise ToolException(msg)
LOGGER.info("Executing shell command: %s", command)
result = session.execute(command, timeout=self._execution_policy.command_timeout)
if result.timed_out:
timeout_seconds = self._execution_policy.command_timeout
message = f"Error: Command timed out after {timeout_seconds:.1f} seconds."
return self._format_tool_message(
message,
tool_call_id,
status="error",
artifact={
"timed_out": True,
"exit_code": None,
},
)
try:
sanitized_output, matches = self._apply_redactions(result.output)
except PIIDetectionError as error:
LOGGER.warning("Blocking command output due to detected %s.", error.pii_type)
message = f"Output blocked: detected {error.pii_type}."
return self._format_tool_message(
message,
tool_call_id,
status="error",
artifact={
"timed_out": False,
"exit_code": result.exit_code,
"matches": {error.pii_type: error.matches},
},
)
sanitized_output = sanitized_output or "<no output>"
if result.truncated_by_lines:
sanitized_output = (
f"{sanitized_output.rstrip()}\n\n"
f"... Output truncated at {self._execution_policy.max_output_lines} lines "
f"(observed {result.total_lines})."
)
if result.truncated_by_bytes and self._execution_policy.max_output_bytes is not None:
sanitized_output = (
f"{sanitized_output.rstrip()}\n\n"
f"... Output truncated at {self._execution_policy.max_output_bytes} bytes "
f"(observed {result.total_bytes})."
)
if result.exit_code not in {0, None}:
sanitized_output = f"{sanitized_output.rstrip()}\n\nExit code: {result.exit_code}"
final_status: Literal["success", "error"] = "error"
else:
final_status = "success"
artifact = {
"timed_out": False,
"exit_code": result.exit_code,
"truncated_by_lines": result.truncated_by_lines,
"truncated_by_bytes": result.truncated_by_bytes,
"total_lines": result.total_lines,
"total_bytes": result.total_bytes,
"redaction_matches": matches,
}
return self._format_tool_message(
sanitized_output,
tool_call_id,
status=final_status,
artifact=artifact,
)
def _format_tool_message(
self,
content: str,
tool_call_id: str | None,
*,
status: Literal["success", "error"],
artifact: dict[str, Any] | None = None,
) -> ToolMessage | str:
artifact = artifact or {}
if tool_call_id is None:
return content
return ToolMessage(
content=content,
tool_call_id=tool_call_id,
name=self._tool_name,
status=status,
artifact=artifact,
)
__all__ = [
"CodexSandboxExecutionPolicy",
"DockerExecutionPolicy",
"HostExecutionPolicy",
"RedactionRule",
"ShellToolMiddleware",
]

View File

@@ -0,0 +1,658 @@
"""Summarization middleware."""
import uuid
import warnings
from collections.abc import Callable, Iterable, Mapping
from functools import partial
from typing import Any, Literal, cast
from langchain_core.messages import (
AIMessage,
AnyMessage,
MessageLikeRepresentation,
RemoveMessage,
ToolMessage,
)
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.utils import (
count_tokens_approximately,
get_buffer_string,
trim_messages,
)
from langgraph.graph.message import (
REMOVE_ALL_MESSAGES,
)
from langgraph.runtime import Runtime
from typing_extensions import override
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, ResponseT
from langchain.chat_models import BaseChatModel, init_chat_model
TokenCounter = Callable[[Iterable[MessageLikeRepresentation]], int]
DEFAULT_SUMMARY_PROMPT = """<role>
Context Extraction Assistant
</role>
<primary_objective>
Your sole objective in this task is to extract the highest quality/most relevant context from the conversation history below.
</primary_objective>
<objective_information>
You're nearing the total number of input tokens you can accept, so you must extract the highest quality/most relevant pieces of information from your conversation history.
This context will then overwrite the conversation history presented below. Because of this, ensure the context you extract is only the most important information to continue working toward your overall goal.
</objective_information>
<instructions>
The conversation history below will be replaced with the context you extract in this step.
You want to ensure that you don't repeat any actions you've already completed, so the context you extract from the conversation history should be focused on the most important information to your overall goal.
You should structure your summary using the following sections. Each section acts as a checklist - you must populate it with relevant information or explicitly state "None" if there is nothing to report for that section:
## SESSION INTENT
What is the user's primary goal or request? What overall task are you trying to accomplish? This should be concise but complete enough to understand the purpose of the entire session.
## SUMMARY
Extract and record all of the most important context from the conversation history. Include important choices, conclusions, or strategies determined during this conversation. Include the reasoning behind key decisions. Document any rejected options and why they were not pursued.
## ARTIFACTS
What artifacts, files, or resources were created, modified, or accessed during this conversation? For file modifications, list specific file paths and briefly describe the changes made to each. This section prevents silent loss of artifact information.
## NEXT STEPS
What specific tasks remain to be completed to achieve the session intent? What should you do next?
</instructions>
The user will message you with the full message history from which you'll extract context to create a replacement. Carefully read through it all and think deeply about what information is most important to your overall goal and should be saved:
With all of this in mind, please carefully read over the entire conversation history, and extract the most important and relevant context to replace it so that you can free up space in the conversation history.
Respond ONLY with the extracted context. Do not include any additional information, or text before or after the extracted context.
<messages>
Messages to summarize:
{messages}
</messages>""" # noqa: E501
_DEFAULT_MESSAGES_TO_KEEP = 20
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
ContextFraction = tuple[Literal["fraction"], float]
"""Fraction of model's maximum input tokens.
Example:
To specify 50% of the model's max input tokens:
```python
("fraction", 0.5)
```
"""
ContextTokens = tuple[Literal["tokens"], int]
"""Absolute number of tokens.
Example:
To specify 3000 tokens:
```python
("tokens", 3000)
```
"""
ContextMessages = tuple[Literal["messages"], int]
"""Absolute number of messages.
Example:
To specify 50 messages:
```python
("messages", 50)
```
"""
ContextSize = ContextFraction | ContextTokens | ContextMessages
"""Union type for context size specifications.
Can be either:
- [`ContextFraction`][langchain.agents.middleware.summarization.ContextFraction]: A
fraction of the model's maximum input tokens.
- [`ContextTokens`][langchain.agents.middleware.summarization.ContextTokens]: An absolute
number of tokens.
- [`ContextMessages`][langchain.agents.middleware.summarization.ContextMessages]: An
absolute number of messages.
Depending on use with `trigger` or `keep` parameters, this type indicates either
when to trigger summarization or how much context to retain.
Example:
```python
# ContextFraction
context_size: ContextSize = ("fraction", 0.5)
# ContextTokens
context_size: ContextSize = ("tokens", 3000)
# ContextMessages
context_size: ContextSize = ("messages", 50)
```
"""
def _get_approximate_token_counter(model: BaseChatModel) -> TokenCounter:
"""Tune parameters of approximate token counter based on model type."""
if model._llm_type == "anthropic-chat": # noqa: SLF001
# 3.3 was estimated in an offline experiment, comparing with Claude's token-counting
# API: https://platform.claude.com/docs/en/build-with-claude/token-counting
return partial(
count_tokens_approximately, use_usage_metadata_scaling=True, chars_per_token=3.3
)
return partial(count_tokens_approximately, use_usage_metadata_scaling=True)
class SummarizationMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Summarizes conversation history when token limits are approached.
This middleware monitors message token counts and automatically summarizes older
messages when a threshold is reached, preserving recent messages and maintaining
context continuity by ensuring AI/Tool message pairs remain together.
"""
def __init__(
self,
model: str | BaseChatModel,
*,
trigger: ContextSize | list[ContextSize] | None = None,
keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
token_counter: TokenCounter = count_tokens_approximately,
summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
**deprecated_kwargs: Any,
) -> None:
"""Initialize summarization middleware.
Args:
model: The language model to use for generating summaries.
trigger: One or more thresholds that trigger summarization.
Provide a single
[`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
tuple or a list of tuples, in which case summarization runs when any
threshold is met.
!!! example
```python
# Trigger summarization when 50 messages is reached
("messages", 50)
# Trigger summarization when 3000 tokens is reached
("tokens", 3000)
# Trigger summarization either when 80% of model's max input tokens
# is reached or when 100 messages is reached (whichever comes first)
[("fraction", 0.8), ("messages", 100)]
```
See [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
for more details.
keep: Context retention policy applied after summarization.
Provide a [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
tuple to specify how much history to preserve.
Defaults to keeping the most recent `20` messages.
Does not support multiple values like `trigger`.
!!! example
```python
# Keep the most recent 20 messages
("messages", 20)
# Keep the most recent 3000 tokens
("tokens", 3000)
# Keep the most recent 30% of the model's max input tokens
("fraction", 0.3)
```
token_counter: Function to count tokens in messages.
summary_prompt: Prompt template for generating summaries.
trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for
the summarization call.
Pass `None` to skip trimming entirely.
"""
# Handle deprecated parameters
if "max_tokens_before_summary" in deprecated_kwargs:
value = deprecated_kwargs["max_tokens_before_summary"]
warnings.warn(
"max_tokens_before_summary is deprecated. Use trigger=('tokens', value) instead.",
DeprecationWarning,
stacklevel=2,
)
if trigger is None and value is not None:
trigger = ("tokens", value)
if "messages_to_keep" in deprecated_kwargs:
value = deprecated_kwargs["messages_to_keep"]
warnings.warn(
"messages_to_keep is deprecated. Use keep=('messages', value) instead.",
DeprecationWarning,
stacklevel=2,
)
if keep == ("messages", _DEFAULT_MESSAGES_TO_KEEP):
keep = ("messages", value)
super().__init__()
if isinstance(model, str):
model = init_chat_model(model)
self.model = model
if trigger is None:
self.trigger: ContextSize | list[ContextSize] | None = None
trigger_conditions: list[ContextSize] = []
elif isinstance(trigger, list):
validated_list = [self._validate_context_size(item, "trigger") for item in trigger]
self.trigger = validated_list
trigger_conditions = validated_list
else:
validated = self._validate_context_size(trigger, "trigger")
self.trigger = validated
trigger_conditions = [validated]
self._trigger_conditions = trigger_conditions
self.keep = self._validate_context_size(keep, "keep")
if token_counter is count_tokens_approximately:
self.token_counter = _get_approximate_token_counter(self.model)
self._partial_token_counter: TokenCounter = partial( # type: ignore[call-arg]
self.token_counter, use_usage_metadata_scaling=False
)
else:
self.token_counter = token_counter
self._partial_token_counter = token_counter
self.summary_prompt = summary_prompt
self.trim_tokens_to_summarize = trim_tokens_to_summarize
requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
if self.keep[0] == "fraction":
requires_profile = True
if requires_profile and self._get_profile_limits() is None:
msg = (
"Model profile information is required to use fractional token limits, "
"and is unavailable for the specified model. Please use absolute token "
"counts instead, or pass "
'`\n\nChatModel(..., profile={"max_input_tokens": ...})`.\n\n'
"with a desired integer value of the model's maximum input tokens."
)
raise ValueError(msg)
@override
def before_model(
self, state: AgentState[Any], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Process messages before model invocation, potentially triggering summarization.
Args:
state: The agent state.
runtime: The runtime environment.
Returns:
An updated state with summarized messages if summarization was performed.
"""
messages = state["messages"]
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
if not self._should_summarize(messages, total_tokens):
return None
cutoff_index = self._determine_cutoff_index(messages)
if cutoff_index <= 0:
return None
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
summary = self._create_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages,
*preserved_messages,
]
}
@override
async def abefore_model(
self, state: AgentState[Any], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Process messages before model invocation, potentially triggering summarization.
Args:
state: The agent state.
runtime: The runtime environment.
Returns:
An updated state with summarized messages if summarization was performed.
"""
messages = state["messages"]
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
if not self._should_summarize(messages, total_tokens):
return None
cutoff_index = self._determine_cutoff_index(messages)
if cutoff_index <= 0:
return None
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
summary = await self._acreate_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages,
*preserved_messages,
]
}
def _should_summarize_based_on_reported_tokens(
self, messages: list[AnyMessage], threshold: float
) -> bool:
"""Check if reported token usage from last AIMessage exceeds threshold."""
last_ai_message = next(
(msg for msg in reversed(messages) if isinstance(msg, AIMessage)),
None,
)
if ( # noqa: SIM103
isinstance(last_ai_message, AIMessage)
and last_ai_message.usage_metadata is not None
and (reported_tokens := last_ai_message.usage_metadata.get("total_tokens", -1))
and reported_tokens >= threshold
and (message_provider := last_ai_message.response_metadata.get("model_provider"))
and message_provider == self.model._get_ls_params().get("ls_provider") # noqa: SLF001
):
return True
return False
def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
"""Determine whether summarization should run for the current token usage."""
if not self._trigger_conditions:
return False
for kind, value in self._trigger_conditions:
if kind == "messages" and len(messages) >= value:
return True
if kind == "tokens" and total_tokens >= value:
return True
if kind == "tokens" and self._should_summarize_based_on_reported_tokens(
messages, value
):
return True
if kind == "fraction":
max_input_tokens = self._get_profile_limits()
if max_input_tokens is None:
continue
threshold = int(max_input_tokens * value)
if threshold <= 0:
threshold = 1
if total_tokens >= threshold:
return True
if self._should_summarize_based_on_reported_tokens(messages, threshold):
return True
return False
def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
"""Choose cutoff index respecting retention configuration."""
kind, value = self.keep
if kind in {"tokens", "fraction"}:
token_based_cutoff = self._find_token_based_cutoff(messages)
if token_based_cutoff is not None:
return token_based_cutoff
# None cutoff -> model profile data not available (caught in __init__ but
# here for safety), fallback to message count
return self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
return self._find_safe_cutoff(messages, cast("int", value))
def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
"""Find cutoff index based on target token retention."""
if not messages:
return 0
kind, value = self.keep
if kind == "fraction":
max_input_tokens = self._get_profile_limits()
if max_input_tokens is None:
return None
target_token_count = int(max_input_tokens * value)
elif kind == "tokens":
target_token_count = int(value)
else:
return None
if target_token_count <= 0:
target_token_count = 1
if self.token_counter(messages) <= target_token_count:
return 0
# Use binary search to identify the earliest message index that keeps the
# suffix within the token budget.
left, right = 0, len(messages)
cutoff_candidate = len(messages)
max_iterations = len(messages).bit_length() + 1
for _ in range(max_iterations):
if left >= right:
break
mid = (left + right) // 2
if self._partial_token_counter(messages[mid:]) <= target_token_count:
cutoff_candidate = mid
right = mid
else:
left = mid + 1
if cutoff_candidate == len(messages):
cutoff_candidate = left
if cutoff_candidate >= len(messages):
if len(messages) == 1:
return 0
cutoff_candidate = len(messages) - 1
# Advance past any ToolMessages to avoid splitting AI/Tool pairs
return self._find_safe_cutoff_point(messages, cutoff_candidate)
def _get_profile_limits(self) -> int | None:
"""Retrieve max input token limit from the model profile."""
try:
profile = self.model.profile
except AttributeError:
return None
if not isinstance(profile, Mapping):
return None
max_input_tokens = profile.get("max_input_tokens")
if not isinstance(max_input_tokens, int):
return None
return max_input_tokens
@staticmethod
def _validate_context_size(context: ContextSize, parameter_name: str) -> ContextSize:
"""Validate context configuration tuples."""
kind, value = context
if kind == "fraction":
if not 0 < value <= 1:
msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
raise ValueError(msg)
elif kind in {"tokens", "messages"}:
if value <= 0:
msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
raise ValueError(msg)
else:
msg = f"Unsupported context size type {kind} for {parameter_name}."
raise ValueError(msg)
return context
@staticmethod
def _build_new_messages(summary: str) -> list[HumanMessage]:
return [
HumanMessage(
content=f"Here is a summary of the conversation to date:\n\n{summary}",
additional_kwargs={"lc_source": "summarization"},
)
]
@staticmethod
def _ensure_message_ids(messages: list[AnyMessage]) -> None:
"""Ensure all messages have unique IDs for the add_messages reducer."""
for msg in messages:
if msg.id is None:
msg.id = str(uuid.uuid4())
@staticmethod
def _partition_messages(
conversation_messages: list[AnyMessage],
cutoff_index: int,
) -> tuple[list[AnyMessage], list[AnyMessage]]:
"""Partition messages into those to summarize and those to preserve."""
messages_to_summarize = conversation_messages[:cutoff_index]
preserved_messages = conversation_messages[cutoff_index:]
return messages_to_summarize, preserved_messages
def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -> int:
"""Find safe cutoff point that preserves AI/Tool message pairs.
Returns the index where messages can be safely cut without separating
related AI and Tool messages. Returns `0` if no safe cutoff is found.
This is aggressive with summarization - if the target cutoff lands in the
middle of tool messages, we advance past all of them (summarizing more).
"""
if len(messages) <= messages_to_keep:
return 0
target_cutoff = len(messages) - messages_to_keep
return self._find_safe_cutoff_point(messages, target_cutoff)
@staticmethod
def _find_safe_cutoff_point(messages: list[AnyMessage], cutoff_index: int) -> int:
"""Find a safe cutoff point that doesn't split AI/Tool message pairs.
If the message at `cutoff_index` is a `ToolMessage`, search backward for the
`AIMessage` containing the corresponding `tool_calls` and adjust the cutoff to
include it. This ensures tool call requests and responses stay together.
Falls back to advancing forward past `ToolMessage` objects only if no matching
`AIMessage` is found (edge case).
"""
if cutoff_index >= len(messages) or not isinstance(messages[cutoff_index], ToolMessage):
return cutoff_index
# Collect tool_call_ids from consecutive ToolMessages at/after cutoff
tool_call_ids: set[str] = set()
idx = cutoff_index
while idx < len(messages) and isinstance(messages[idx], ToolMessage):
tool_msg = cast("ToolMessage", messages[idx])
if tool_msg.tool_call_id:
tool_call_ids.add(tool_msg.tool_call_id)
idx += 1
# Search backward for AIMessage with matching tool_calls
for i in range(cutoff_index - 1, -1, -1):
msg = messages[i]
if isinstance(msg, AIMessage) and msg.tool_calls:
ai_tool_call_ids = {tc.get("id") for tc in msg.tool_calls if tc.get("id")}
if tool_call_ids & ai_tool_call_ids:
# Found the AIMessage - move cutoff to include it
return i
# Fallback: no matching AIMessage found, advance past ToolMessages to avoid
# orphaned tool responses
return idx
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Generate summary for the given messages.
Args:
messages_to_summarize: Messages to summarize.
"""
if not messages_to_summarize:
return "No previous conversation history."
trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
if not trimmed_messages:
return "Previous conversation was too long to summarize."
# Format messages to avoid token inflation from metadata when str() is called on
# message objects
formatted_messages = get_buffer_string(trimmed_messages)
try:
response = self.model.invoke(
self.summary_prompt.format(messages=formatted_messages).rstrip(),
config={"metadata": {"lc_source": "summarization"}},
)
return response.text.strip()
except Exception as e:
return f"Error generating summary: {e!s}"
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Generate summary for the given messages.
Args:
messages_to_summarize: Messages to summarize.
"""
if not messages_to_summarize:
return "No previous conversation history."
trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
if not trimmed_messages:
return "Previous conversation was too long to summarize."
# Format messages to avoid token inflation from metadata when str() is called on
# message objects
formatted_messages = get_buffer_string(trimmed_messages)
try:
response = await self.model.ainvoke(
self.summary_prompt.format(messages=formatted_messages).rstrip(),
config={"metadata": {"lc_source": "summarization"}},
)
return response.text.strip()
except Exception as e:
return f"Error generating summary: {e!s}"
def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
"""Trim messages to fit within summary generation limits."""
try:
if self.trim_tokens_to_summarize is None:
return messages
return cast(
"list[AnyMessage]",
trim_messages(
messages,
max_tokens=self.trim_tokens_to_summarize,
token_counter=self.token_counter,
start_on="human",
strategy="last",
allow_partial=True,
include_system=True,
),
)
except Exception:
return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]

View File

@@ -0,0 +1,327 @@
"""Planning and task management middleware for agents."""
from __future__ import annotations
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from langgraph.runtime import Runtime
from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.types import Command
from typing_extensions import NotRequired, TypedDict, override
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
OmitFromInput,
ResponseT,
)
from langchain.tools import InjectedToolCallId
class Todo(TypedDict):
"""A single todo item with content and status."""
content: str
"""The content/description of the todo item."""
status: Literal["pending", "in_progress", "completed"]
"""The current status of the todo item."""
class PlanningState(AgentState[ResponseT]):
"""State schema for the todo middleware.
Type Parameters:
ResponseT: The type of the structured response. Defaults to `Any`.
"""
todos: Annotated[NotRequired[list[Todo]], OmitFromInput]
"""List of todo items for tracking task progress."""
WRITE_TODOS_TOOL_DESCRIPTION = """Use this tool to create and manage a structured task list for your current work session. This helps you track progress, organize complex tasks, and demonstrate thoroughness to the user.
Only use this tool if you think it will be helpful in staying organized. If the user's request is trivial and takes less than 3 steps, it is better to NOT use this tool and just do the task directly.
## When to Use This Tool
Use this tool in these scenarios:
1. Complex multi-step tasks - When a task requires 3 or more distinct steps or actions
2. Non-trivial and complex tasks - Tasks that require careful planning or multiple operations
3. User explicitly requests todo list - When the user directly asks you to use the todo list
4. User provides multiple tasks - When users provide a list of things to be done (numbered or comma-separated)
5. The plan may need future revisions or updates based on results from the first few steps
## How to Use This Tool
1. When you start working on a task - Mark it as in_progress BEFORE beginning work.
2. After completing a task - Mark it as completed and add any new follow-up tasks discovered during implementation.
3. You can also update future tasks, such as deleting them if they are no longer necessary, or adding new tasks that are necessary. Don't change previously completed tasks.
4. You can make several updates to the todo list at once. For example, when you complete a task, you can mark the next task you need to start as in_progress.
## When NOT to Use This Tool
It is important to skip using this tool when:
1. There is only a single, straightforward task
2. The task is trivial and tracking it provides no benefit
3. The task can be completed in less than 3 trivial steps
4. The task is purely conversational or informational
## Task States and Management
1. **Task States**: Use these states to track progress:
- pending: Task not yet started
- in_progress: Currently working on (you can have multiple tasks in_progress at a time if they are not related to each other and can be run in parallel)
- completed: Task finished successfully
2. **Task Management**:
- Update task status in real-time as you work
- Mark tasks complete IMMEDIATELY after finishing (don't batch completions)
- Complete current tasks before starting new ones
- Remove tasks that are no longer relevant from the list entirely
- IMPORTANT: When you write this todo list, you should mark your first task (or tasks) as in_progress immediately!.
- IMPORTANT: Unless all tasks are completed, you should always have at least one task in_progress to show the user that you are working on something.
3. **Task Completion Requirements**:
- ONLY mark a task as completed when you have FULLY accomplished it
- If you encounter errors, blockers, or cannot finish, keep the task as in_progress
- When blocked, create a new task describing what needs to be resolved
- Never mark a task as completed if:
- There are unresolved issues or errors
- Work is partial or incomplete
- You encountered blockers that prevent completion
- You couldn't find necessary resources or dependencies
- Quality standards haven't been met
4. **Task Breakdown**:
- Create specific, actionable items
- Break complex tasks into smaller, manageable steps
- Use clear, descriptive task names
Being proactive with task management demonstrates attentiveness and ensures you complete all requirements successfully
Remember: If you only need to make a few tool calls to complete a task, and it is clear what you need to do, it is better to just do the task directly and NOT call this tool at all.""" # noqa: E501
WRITE_TODOS_SYSTEM_PROMPT = """## `write_todos`
You have access to the `write_todos` tool to help you manage and plan complex objectives.
Use this tool for complex objectives to ensure that you are tracking each necessary step and giving the user visibility into your progress.
This tool is very helpful for planning complex objectives, and for breaking down these larger complex objectives into smaller steps.
It is critical that you mark todos as completed as soon as you are done with a step. Do not batch up multiple steps before marking them as completed.
For simple objectives that only require a few steps, it is better to just complete the objective directly and NOT use this tool.
Writing todos takes time and tokens, use it when it is helpful for managing complex many-step problems! But not for simple few-step requests.
## Important To-Do List Usage Notes to Remember
- The `write_todos` tool should never be called multiple times in parallel.
- Don't be afraid to revise the To-Do list as you go. New information may reveal new tasks that need to be done, or old tasks that are irrelevant.""" # noqa: E501
@tool(description=WRITE_TODOS_TOOL_DESCRIPTION)
def write_todos(
todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
) -> Command[Any]:
"""Create and manage a structured task list for your current work session."""
return Command(
update={
"todos": todos,
"messages": [ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)],
}
)
class TodoListMiddleware(AgentMiddleware[PlanningState[ResponseT], ContextT, ResponseT]):
"""Middleware that provides todo list management capabilities to agents.
This middleware adds a `write_todos` tool that allows agents to create and manage
structured task lists for complex multi-step operations. It's designed to help
agents track progress, organize complex tasks, and provide users with visibility
into task completion status.
The middleware automatically injects system prompts that guide the agent on when
and how to use the todo functionality effectively. It also enforces that the
`write_todos` tool is called at most once per model turn, since the tool replaces
the entire todo list and parallel calls would create ambiguity about precedence.
Example:
```python
from langchain.agents.middleware.todo import TodoListMiddleware
from langchain.agents import create_agent
agent = create_agent("openai:gpt-4o", middleware=[TodoListMiddleware()])
# Agent now has access to write_todos tool and todo state tracking
result = await agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})
print(result["todos"]) # Array of todo items with status tracking
```
"""
state_schema = PlanningState # type: ignore[assignment]
def __init__(
self,
*,
system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
) -> None:
"""Initialize the `TodoListMiddleware` with optional custom prompts.
Args:
system_prompt: Custom system prompt to guide the agent on using the todo
tool.
tool_description: Custom description for the `write_todos` tool.
"""
super().__init__()
self.system_prompt = system_prompt
self.tool_description = tool_description
# Dynamically create the write_todos tool with the custom description
@tool(description=self.tool_description)
def write_todos(
todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
) -> Command[Any]:
"""Create and manage a structured task list for your current work session."""
return Command(
update={
"todos": todos,
"messages": [
ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)
],
}
)
self.tools = [write_todos]
def wrap_model_call(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> ModelResponse[ResponseT] | AIMessage:
"""Update the system message to include the todo system prompt.
Args:
request: Model request to execute (includes state and runtime).
handler: Async callback that executes the model request and returns
`ModelResponse`.
Returns:
The model call result.
"""
if request.system_message is not None:
new_system_content = [
*request.system_message.content_blocks,
{"type": "text", "text": f"\n\n{self.system_prompt}"},
]
else:
new_system_content = [{"type": "text", "text": self.system_prompt}]
new_system_message = SystemMessage(
content=cast("list[str | dict[str, str]]", new_system_content)
)
return handler(request.override(system_message=new_system_message))
async def awrap_model_call(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> ModelResponse[ResponseT] | AIMessage:
"""Update the system message to include the todo system prompt.
Args:
request: Model request to execute (includes state and runtime).
handler: Async callback that executes the model request and returns
`ModelResponse`.
Returns:
The model call result.
"""
if request.system_message is not None:
new_system_content = [
*request.system_message.content_blocks,
{"type": "text", "text": f"\n\n{self.system_prompt}"},
]
else:
new_system_content = [{"type": "text", "text": self.system_prompt}]
new_system_message = SystemMessage(
content=cast("list[str | dict[str, str]]", new_system_content)
)
return await handler(request.override(system_message=new_system_message))
@override
def after_model(
self, state: PlanningState[ResponseT], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Check for parallel write_todos tool calls and return errors if detected.
The todo list is designed to be updated at most once per model turn. Since
the `write_todos` tool replaces the entire todo list with each call, making
multiple parallel calls would create ambiguity about which update should take
precedence. This method prevents such conflicts by rejecting any response that
contains multiple write_todos tool calls.
Args:
state: The current agent state containing messages.
runtime: The LangGraph runtime instance.
Returns:
A dict containing error ToolMessages for each write_todos call if multiple
parallel calls are detected, otherwise None to allow normal execution.
"""
messages = state["messages"]
if not messages:
return None
last_ai_msg = next((msg for msg in reversed(messages) if isinstance(msg, AIMessage)), None)
if not last_ai_msg or not last_ai_msg.tool_calls:
return None
# Count write_todos tool calls
write_todos_calls = [tc for tc in last_ai_msg.tool_calls if tc["name"] == "write_todos"]
if len(write_todos_calls) > 1:
# Create error tool messages for all write_todos calls
error_messages = [
ToolMessage(
content=(
"Error: The `write_todos` tool should never be called multiple times "
"in parallel. Please call it only once per model invocation to update "
"the todo list."
),
tool_call_id=tc["id"],
status="error",
)
for tc in write_todos_calls
]
# Keep the tool calls in the AI message but return error messages
# This follows the same pattern as HumanInTheLoopMiddleware
return {"messages": error_messages}
return None
@override
async def aafter_model(
self, state: PlanningState[ResponseT], runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Check for parallel write_todos tool calls and return errors if detected.
Async version of `after_model`. The todo list is designed to be updated at
most once per model turn. Since the `write_todos` tool replaces the entire
todo list with each call, making multiple parallel calls would create ambiguity
about which update should take precedence. This method prevents such conflicts
by rejecting any response that contains multiple write_todos tool calls.
Args:
state: The current agent state containing messages.
runtime: The LangGraph runtime instance.
Returns:
A dict containing error ToolMessages for each write_todos call if multiple
parallel calls are detected, otherwise None to allow normal execution.
"""
return self.after_model(state, runtime)

View File

@@ -0,0 +1,488 @@
"""Tool call limit middleware for agents."""
from __future__ import annotations
from typing import TYPE_CHECKING, Annotated, Any, Literal
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
from langgraph.channels.untracked_value import UntrackedValue
from langgraph.typing import ContextT
from typing_extensions import NotRequired, override
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
PrivateStateAttr,
ResponseT,
hook_config,
)
if TYPE_CHECKING:
from langgraph.runtime import Runtime
ExitBehavior = Literal["continue", "error", "end"]
"""How to handle execution when tool call limits are exceeded.
- `'continue'`: Block exceeded tools with error messages, let other tools continue
(default)
- `'error'`: Raise a `ToolCallLimitExceededError` exception
- `'end'`: Stop execution immediately, injecting a `ToolMessage` and an `AIMessage` for
the single tool call that exceeded the limit. Raises `NotImplementedError` if there
are other pending tool calls (due to parallel tool calling).
"""
class ToolCallLimitState(AgentState[ResponseT]):
"""State schema for `ToolCallLimitMiddleware`.
Extends `AgentState` with tool call tracking fields.
The count fields are dictionaries mapping tool names to execution counts. This
allows multiple middleware instances to track different tools independently. The
special key `'__all__'` is used for tracking all tool calls globally.
Type Parameters:
ResponseT: The type of the structured response. Defaults to `Any`.
"""
thread_tool_call_count: NotRequired[Annotated[dict[str, int], PrivateStateAttr]]
run_tool_call_count: NotRequired[Annotated[dict[str, int], UntrackedValue, PrivateStateAttr]]
def _build_tool_message_content(tool_name: str | None) -> str:
"""Build the error message content for `ToolMessage` when limit is exceeded.
This message is sent to the model, so it should not reference thread/run concepts
that the model has no notion of.
Args:
tool_name: Tool name being limited (if specific tool), or `None` for all tools.
Returns:
A concise message instructing the model not to call the tool again.
"""
# Always instruct the model not to call again, regardless of which limit was hit
if tool_name:
return f"Tool call limit exceeded. Do not call '{tool_name}' again."
return "Tool call limit exceeded. Do not make additional tool calls."
def _build_final_ai_message_content(
thread_count: int,
run_count: int,
thread_limit: int | None,
run_limit: int | None,
tool_name: str | None,
) -> str:
"""Build the final AI message content for `'end'` behavior.
This message is displayed to the user, so it should include detailed information
about which limits were exceeded.
Args:
thread_count: Current thread tool call count.
run_count: Current run tool call count.
thread_limit: Thread tool call limit (if set).
run_limit: Run tool call limit (if set).
tool_name: Tool name being limited (if specific tool), or `None` for all tools.
Returns:
A formatted message describing which limits were exceeded.
"""
tool_desc = f"'{tool_name}' tool" if tool_name else "Tool"
exceeded_limits = []
if thread_limit is not None and thread_count > thread_limit:
exceeded_limits.append(f"thread limit exceeded ({thread_count}/{thread_limit} calls)")
if run_limit is not None and run_count > run_limit:
exceeded_limits.append(f"run limit exceeded ({run_count}/{run_limit} calls)")
limits_text = " and ".join(exceeded_limits)
return f"{tool_desc} call limit reached: {limits_text}."
class ToolCallLimitExceededError(Exception):
"""Exception raised when tool call limits are exceeded.
This exception is raised when the configured exit behavior is `'error'` and either
the thread or run tool call limit has been exceeded.
"""
def __init__(
self,
thread_count: int,
run_count: int,
thread_limit: int | None,
run_limit: int | None,
tool_name: str | None = None,
) -> None:
"""Initialize the exception with call count information.
Args:
thread_count: Current thread tool call count.
run_count: Current run tool call count.
thread_limit: Thread tool call limit (if set).
run_limit: Run tool call limit (if set).
tool_name: Tool name being limited (if specific tool), or None for all tools.
"""
self.thread_count = thread_count
self.run_count = run_count
self.thread_limit = thread_limit
self.run_limit = run_limit
self.tool_name = tool_name
msg = _build_final_ai_message_content(
thread_count, run_count, thread_limit, run_limit, tool_name
)
super().__init__(msg)
class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState[ResponseT], ContextT, ResponseT]):
"""Track tool call counts and enforces limits during agent execution.
This middleware monitors the number of tool calls made and can terminate or
restrict execution when limits are exceeded. It supports both thread-level
(persistent across runs) and run-level (per invocation) call counting.
Configuration:
- `exit_behavior`: How to handle when limits are exceeded
- `'continue'`: Block exceeded tools, let execution continue (default)
- `'error'`: Raise an exception
- `'end'`: Stop immediately with a `ToolMessage` + AI message for the single
tool call that exceeded the limit (raises `NotImplementedError` if there
are other pending tool calls (due to parallel tool calling).
Examples:
!!! example "Continue execution with blocked tools (default)"
```python
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
from langchain.agents import create_agent
# Block exceeded tools but let other tools and model continue
limiter = ToolCallLimitMiddleware(
thread_limit=20,
run_limit=10,
exit_behavior="continue", # default
)
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```
!!! example "Stop immediately when limit exceeded"
```python
# End execution immediately with an AI message
limiter = ToolCallLimitMiddleware(run_limit=5, exit_behavior="end")
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```
!!! example "Raise exception on limit"
```python
# Strict limit with exception handling
limiter = ToolCallLimitMiddleware(
tool_name="search", thread_limit=5, exit_behavior="error"
)
agent = create_agent("openai:gpt-4o", middleware=[limiter])
try:
result = await agent.invoke({"messages": [HumanMessage("Task")]})
except ToolCallLimitExceededError as e:
print(f"Search limit exceeded: {e}")
```
"""
state_schema = ToolCallLimitState # type: ignore[assignment]
def __init__(
self,
*,
tool_name: str | None = None,
thread_limit: int | None = None,
run_limit: int | None = None,
exit_behavior: ExitBehavior = "continue",
) -> None:
"""Initialize the tool call limit middleware.
Args:
tool_name: Name of the specific tool to limit. If `None`, limits apply
to all tools.
thread_limit: Maximum number of tool calls allowed per thread.
`None` means no limit.
run_limit: Maximum number of tool calls allowed per run.
`None` means no limit.
exit_behavior: How to handle when limits are exceeded.
- `'continue'`: Block exceeded tools with error messages, let other
tools continue. Model decides when to end.
- `'error'`: Raise a `ToolCallLimitExceededError` exception
- `'end'`: Stop execution immediately with a `ToolMessage` + AI message
for the single tool call that exceeded the limit. Raises
`NotImplementedError` if there are multiple parallel tool
calls to other tools or multiple pending tool calls.
Raises:
ValueError: If both limits are `None`, if `exit_behavior` is invalid,
or if `run_limit` exceeds `thread_limit`.
"""
super().__init__()
if thread_limit is None and run_limit is None:
msg = "At least one limit must be specified (thread_limit or run_limit)"
raise ValueError(msg)
valid_behaviors = ("continue", "error", "end")
if exit_behavior not in valid_behaviors:
msg = f"Invalid exit_behavior: {exit_behavior!r}. Must be one of {valid_behaviors}"
raise ValueError(msg)
if thread_limit is not None and run_limit is not None and run_limit > thread_limit:
msg = (
f"run_limit ({run_limit}) cannot exceed thread_limit ({thread_limit}). "
"The run limit should be less than or equal to the thread limit."
)
raise ValueError(msg)
self.tool_name = tool_name
self.thread_limit = thread_limit
self.run_limit = run_limit
self.exit_behavior = exit_behavior
@property
def name(self) -> str:
"""The name of the middleware instance.
Includes the tool name if specified to allow multiple instances
of this middleware with different tool names.
"""
base_name = self.__class__.__name__
if self.tool_name:
return f"{base_name}[{self.tool_name}]"
return base_name
def _would_exceed_limit(self, thread_count: int, run_count: int) -> bool:
"""Check if incrementing the counts would exceed any configured limit.
Args:
thread_count: Current thread call count.
run_count: Current run call count.
Returns:
True if either limit would be exceeded by one more call.
"""
return (self.thread_limit is not None and thread_count + 1 > self.thread_limit) or (
self.run_limit is not None and run_count + 1 > self.run_limit
)
def _matches_tool_filter(self, tool_call: ToolCall) -> bool:
"""Check if a tool call matches this middleware's tool filter.
Args:
tool_call: The tool call to check.
Returns:
True if this middleware should track this tool call.
"""
return self.tool_name is None or tool_call["name"] == self.tool_name
def _separate_tool_calls(
self, tool_calls: list[ToolCall], thread_count: int, run_count: int
) -> tuple[list[ToolCall], list[ToolCall], int, int]:
"""Separate tool calls into allowed and blocked based on limits.
Args:
tool_calls: List of tool calls to evaluate.
thread_count: Current thread call count.
run_count: Current run call count.
Returns:
Tuple of `(allowed_calls, blocked_calls, final_thread_count,
final_run_count)`.
"""
allowed_calls: list[ToolCall] = []
blocked_calls: list[ToolCall] = []
temp_thread_count = thread_count
temp_run_count = run_count
for tool_call in tool_calls:
if not self._matches_tool_filter(tool_call):
continue
if self._would_exceed_limit(temp_thread_count, temp_run_count):
blocked_calls.append(tool_call)
else:
allowed_calls.append(tool_call)
temp_thread_count += 1
temp_run_count += 1
return allowed_calls, blocked_calls, temp_thread_count, temp_run_count
@hook_config(can_jump_to=["end"])
@override
def after_model(
self,
state: ToolCallLimitState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Increment tool call counts after a model call and check limits.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
State updates with incremented tool call counts. If limits are exceeded
and exit_behavior is `'end'`, also includes a jump to end with a
`ToolMessage` and AI message for the single exceeded tool call.
Raises:
ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
and there are multiple tool calls.
"""
# Get the last AIMessage to check for tool calls
messages = state.get("messages", [])
if not messages:
return None
# Find the last AIMessage
last_ai_message = None
for message in reversed(messages):
if isinstance(message, AIMessage):
last_ai_message = message
break
if not last_ai_message or not last_ai_message.tool_calls:
return None
# Get the count key for this middleware instance
count_key = self.tool_name or "__all__"
# Get current counts
thread_counts = state.get("thread_tool_call_count", {}).copy()
run_counts = state.get("run_tool_call_count", {}).copy()
current_thread_count = thread_counts.get(count_key, 0)
current_run_count = run_counts.get(count_key, 0)
# Separate tool calls into allowed and blocked
allowed_calls, blocked_calls, new_thread_count, new_run_count = self._separate_tool_calls(
last_ai_message.tool_calls, current_thread_count, current_run_count
)
# Update counts to include only allowed calls for thread count
# (blocked calls don't count towards thread-level tracking)
# But run count includes blocked calls since they were attempted in this run
thread_counts[count_key] = new_thread_count
run_counts[count_key] = new_run_count + len(blocked_calls)
# If no tool calls are blocked, just update counts
if not blocked_calls:
if allowed_calls:
return {
"thread_tool_call_count": thread_counts,
"run_tool_call_count": run_counts,
}
return None
# Get final counts for building messages
final_thread_count = thread_counts[count_key]
final_run_count = run_counts[count_key]
# Handle different exit behaviors
if self.exit_behavior == "error":
# Use hypothetical thread count to show which limit was exceeded
hypothetical_thread_count = final_thread_count + len(blocked_calls)
raise ToolCallLimitExceededError(
thread_count=hypothetical_thread_count,
run_count=final_run_count,
thread_limit=self.thread_limit,
run_limit=self.run_limit,
tool_name=self.tool_name,
)
# Build tool message content (sent to model - no thread/run details)
tool_msg_content = _build_tool_message_content(self.tool_name)
# Inject artificial error ToolMessages for blocked tool calls
artificial_messages: list[ToolMessage | AIMessage] = [
ToolMessage(
content=tool_msg_content,
tool_call_id=tool_call["id"],
name=tool_call.get("name"),
status="error",
)
for tool_call in blocked_calls
]
if self.exit_behavior == "end":
# Check if there are tool calls to other tools that would continue executing
other_tools = [
tc
for tc in last_ai_message.tool_calls
if self.tool_name is not None and tc["name"] != self.tool_name
]
if other_tools:
tool_names = ", ".join({tc["name"] for tc in other_tools})
msg = (
f"Cannot end execution with other tool calls pending. "
f"Found calls to: {tool_names}. Use 'continue' or 'error' behavior instead."
)
raise NotImplementedError(msg)
# Build final AI message content (displayed to user - includes thread/run details)
# Use hypothetical thread count (what it would have been if call wasn't blocked)
# to show which limit was actually exceeded
hypothetical_thread_count = final_thread_count + len(blocked_calls)
final_msg_content = _build_final_ai_message_content(
hypothetical_thread_count,
final_run_count,
self.thread_limit,
self.run_limit,
self.tool_name,
)
artificial_messages.append(AIMessage(content=final_msg_content))
return {
"thread_tool_call_count": thread_counts,
"run_tool_call_count": run_counts,
"jump_to": "end",
"messages": artificial_messages,
}
# For exit_behavior="continue", return error messages to block exceeded tools
return {
"thread_tool_call_count": thread_counts,
"run_tool_call_count": run_counts,
"messages": artificial_messages,
}
@hook_config(can_jump_to=["end"])
async def aafter_model(
self,
state: ToolCallLimitState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Async increment tool call counts after a model call and check limits.
Args:
state: The current agent state.
runtime: The langgraph runtime.
Returns:
State updates with incremented tool call counts. If limits are exceeded
and exit_behavior is `'end'`, also includes a jump to end with a
`ToolMessage` and AI message for the single exceeded tool call.
Raises:
ToolCallLimitExceededError: If limits are exceeded and `exit_behavior`
is `'error'`.
NotImplementedError: If limits are exceeded, `exit_behavior` is `'end'`,
and there are multiple tool calls.
"""
return self.after_model(state, runtime)

View File

@@ -0,0 +1,209 @@
"""Tool emulator middleware for testing."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Generic
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage, ToolMessage
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT
from langchain.chat_models.base import init_chat_model
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from langgraph.types import Command
from langchain.agents.middleware.types import ToolCallRequest
from langchain.tools import BaseTool
class LLMToolEmulator(AgentMiddleware[AgentState[Any], ContextT], Generic[ContextT]):
"""Emulates specified tools using an LLM instead of executing them.
This middleware allows selective emulation of tools for testing purposes.
By default (when `tools=None`), all tools are emulated. You can specify which
tools to emulate by passing a list of tool names or `BaseTool` instances.
Examples:
!!! example "Emulate all tools (default behavior)"
```python
from langchain.agents.middleware import LLMToolEmulator
middleware = LLMToolEmulator()
agent = create_agent(
model="openai:gpt-4o",
tools=[get_weather, get_user_location, calculator],
middleware=[middleware],
)
```
!!! example "Emulate specific tools by name"
```python
middleware = LLMToolEmulator(tools=["get_weather", "get_user_location"])
```
!!! example "Use a custom model for emulation"
```python
middleware = LLMToolEmulator(
tools=["get_weather"], model="anthropic:claude-sonnet-4-5-20250929"
)
```
!!! example "Emulate specific tools by passing tool instances"
```python
middleware = LLMToolEmulator(tools=[get_weather, get_user_location])
```
"""
def __init__(
self,
*,
tools: list[str | BaseTool] | None = None,
model: str | BaseChatModel | None = None,
) -> None:
"""Initialize the tool emulator.
Args:
tools: List of tool names (`str`) or `BaseTool` instances to emulate.
If `None`, ALL tools will be emulated.
If empty list, no tools will be emulated.
model: Model to use for emulation.
Defaults to `'anthropic:claude-sonnet-4-5-20250929'`.
Can be a model identifier string or `BaseChatModel` instance.
"""
super().__init__()
# Extract tool names from tools
# None means emulate all tools
self.emulate_all = tools is None
self.tools_to_emulate: set[str] = set()
if not self.emulate_all and tools is not None:
for tool in tools:
if isinstance(tool, str):
self.tools_to_emulate.add(tool)
else:
# Assume BaseTool with .name attribute
self.tools_to_emulate.add(tool.name)
# Initialize emulator model
if model is None:
self.model = init_chat_model("anthropic:claude-sonnet-4-5-20250929", temperature=1)
elif isinstance(model, BaseChatModel):
self.model = model
else:
self.model = init_chat_model(model, temperature=1)
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
"""Emulate tool execution using LLM if tool should be emulated.
Args:
request: Tool call request to potentially emulate.
handler: Callback to execute the tool (can be called multiple times).
Returns:
ToolMessage with emulated response if tool should be emulated,
otherwise calls handler for normal execution.
"""
tool_name = request.tool_call["name"]
# Check if this tool should be emulated
should_emulate = self.emulate_all or tool_name in self.tools_to_emulate
if not should_emulate:
# Let it execute normally by calling the handler
return handler(request)
# Extract tool information for emulation
tool_args = request.tool_call["args"]
tool_description = request.tool.description if request.tool else "No description available"
# Build prompt for emulator LLM
prompt = (
f"You are emulating a tool call for testing purposes.\n\n"
f"Tool: {tool_name}\n"
f"Description: {tool_description}\n"
f"Arguments: {tool_args}\n\n"
f"Generate a realistic response that this tool would return "
f"given these arguments.\n"
f"Return ONLY the tool's output, no explanation or preamble. "
f"Introduce variation into your responses."
)
# Get emulated response from LLM
response = self.model.invoke([HumanMessage(prompt)])
# Short-circuit: return emulated result without executing real tool
return ToolMessage(
content=response.content,
tool_call_id=request.tool_call["id"],
name=tool_name,
)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
"""Async version of `wrap_tool_call`.
Emulate tool execution using LLM if tool should be emulated.
Args:
request: Tool call request to potentially emulate.
handler: Async callback to execute the tool (can be called multiple times).
Returns:
ToolMessage with emulated response if tool should be emulated,
otherwise calls handler for normal execution.
"""
tool_name = request.tool_call["name"]
# Check if this tool should be emulated
should_emulate = self.emulate_all or tool_name in self.tools_to_emulate
if not should_emulate:
# Let it execute normally by calling the handler
return await handler(request)
# Extract tool information for emulation
tool_args = request.tool_call["args"]
tool_description = request.tool.description if request.tool else "No description available"
# Build prompt for emulator LLM
prompt = (
f"You are emulating a tool call for testing purposes.\n\n"
f"Tool: {tool_name}\n"
f"Description: {tool_description}\n"
f"Arguments: {tool_args}\n\n"
f"Generate a realistic response that this tool would return "
f"given these arguments.\n"
f"Return ONLY the tool's output, no explanation or preamble. "
f"Introduce variation into your responses."
)
# Get emulated response from LLM (using async invoke)
response = await self.model.ainvoke([HumanMessage(prompt)])
# Short-circuit: return emulated result without executing real tool
return ToolMessage(
content=response.content,
tool_call_id=request.tool_call["id"],
name=tool_name,
)

View File

@@ -0,0 +1,403 @@
"""Tool retry middleware for agents."""
from __future__ import annotations
import asyncio
import time
import warnings
from typing import TYPE_CHECKING, Any
from langchain_core.messages import ToolMessage
from langchain.agents.middleware._retry import (
OnFailure,
RetryOn,
calculate_delay,
should_retry_exception,
validate_retry_params,
)
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, ResponseT
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from langgraph.types import Command
from langchain.agents.middleware.types import ToolCallRequest
from langchain.tools import BaseTool
class ToolRetryMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Middleware that automatically retries failed tool calls with configurable backoff.
Supports retrying on specific exceptions and exponential backoff.
Examples:
!!! example "Basic usage with default settings (2 retries, exponential backoff)"
```python
from langchain.agents import create_agent
from langchain.agents.middleware import ToolRetryMiddleware
agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
```
!!! example "Retry specific exceptions only"
```python
from requests.exceptions import RequestException, Timeout
retry = ToolRetryMiddleware(
max_retries=4,
retry_on=(RequestException, Timeout),
backoff_factor=1.5,
)
```
!!! example "Custom exception filtering"
```python
from requests.exceptions import HTTPError
def should_retry(exc: Exception) -> bool:
# Only retry on 5xx errors
if isinstance(exc, HTTPError):
return 500 <= exc.status_code < 600
return False
retry = ToolRetryMiddleware(
max_retries=3,
retry_on=should_retry,
)
```
!!! example "Apply to specific tools with custom error handling"
```python
def format_error(exc: Exception) -> str:
return "Database temporarily unavailable. Please try again later."
retry = ToolRetryMiddleware(
max_retries=4,
tools=["search_database"],
on_failure=format_error,
)
```
!!! example "Apply to specific tools using `BaseTool` instances"
```python
from langchain_core.tools import tool
@tool
def search_database(query: str) -> str:
'''Search the database.'''
return results
retry = ToolRetryMiddleware(
max_retries=4,
tools=[search_database], # Pass BaseTool instance
)
```
!!! example "Constant backoff (no exponential growth)"
```python
retry = ToolRetryMiddleware(
max_retries=5,
backoff_factor=0.0, # No exponential growth
initial_delay=2.0, # Always wait 2 seconds
)
```
!!! example "Raise exception on failure"
```python
retry = ToolRetryMiddleware(
max_retries=2,
on_failure="error", # Re-raise exception instead of returning message
)
```
"""
def __init__(
self,
*,
max_retries: int = 2,
tools: list[BaseTool | str] | None = None,
retry_on: RetryOn = (Exception,),
on_failure: OnFailure = "continue",
backoff_factor: float = 2.0,
initial_delay: float = 1.0,
max_delay: float = 60.0,
jitter: bool = True,
) -> None:
"""Initialize `ToolRetryMiddleware`.
Args:
max_retries: Maximum number of retry attempts after the initial call.
Must be `>= 0`.
tools: Optional list of tools or tool names to apply retry logic to.
Can be a list of `BaseTool` instances or tool name strings.
If `None`, applies to all tools.
retry_on: Either a tuple of exception types to retry on, or a callable
that takes an exception and returns `True` if it should be retried.
Default is to retry on all exceptions.
on_failure: Behavior when all retries are exhausted.
Options:
- `'continue'`: Return a `ToolMessage` with error details,
allowing the LLM to handle the failure and potentially recover.
- `'error'`: Re-raise the exception, stopping agent execution.
- **Custom callable:** Function that takes the exception and returns a
string for the `ToolMessage` content, allowing custom error
formatting.
**Deprecated values** (for backwards compatibility):
- `'return_message'`: Use `'continue'` instead.
- `'raise'`: Use `'error'` instead.
backoff_factor: Multiplier for exponential backoff.
Each retry waits `initial_delay * (backoff_factor ** retry_number)`
seconds.
Set to `0.0` for constant delay.
initial_delay: Initial delay in seconds before first retry.
max_delay: Maximum delay in seconds between retries.
Caps exponential backoff growth.
jitter: Whether to add random jitter (`±25%`) to delay to avoid thundering herd.
Raises:
ValueError: If `max_retries < 0` or delays are negative.
"""
super().__init__()
# Validate parameters
validate_retry_params(max_retries, initial_delay, max_delay, backoff_factor)
# Handle backwards compatibility for deprecated on_failure values
if on_failure == "raise": # type: ignore[comparison-overlap]
msg = ( # type: ignore[unreachable]
"on_failure='raise' is deprecated and will be removed in a future version. "
"Use on_failure='error' instead."
)
warnings.warn(msg, DeprecationWarning, stacklevel=2)
on_failure = "error"
elif on_failure == "return_message": # type: ignore[comparison-overlap]
msg = ( # type: ignore[unreachable]
"on_failure='return_message' is deprecated and will be removed "
"in a future version. Use on_failure='continue' instead."
)
warnings.warn(msg, DeprecationWarning, stacklevel=2)
on_failure = "continue"
self.max_retries = max_retries
# Extract tool names from BaseTool instances or strings
self._tool_filter: list[str] | None
if tools is not None:
self._tool_filter = [tool.name if not isinstance(tool, str) else tool for tool in tools]
else:
self._tool_filter = None
self.tools = [] # No additional tools registered by this middleware
self.retry_on = retry_on
self.on_failure = on_failure
self.backoff_factor = backoff_factor
self.initial_delay = initial_delay
self.max_delay = max_delay
self.jitter = jitter
def _should_retry_tool(self, tool_name: str) -> bool:
"""Check if retry logic should apply to this tool.
Args:
tool_name: Name of the tool being called.
Returns:
`True` if retry logic should apply, `False` otherwise.
"""
if self._tool_filter is None:
return True
return tool_name in self._tool_filter
@staticmethod
def _format_failure_message(tool_name: str, exc: Exception, attempts_made: int) -> str:
"""Format the failure message when retries are exhausted.
Args:
tool_name: Name of the tool that failed.
exc: The exception that caused the failure.
attempts_made: Number of attempts actually made.
Returns:
Formatted error message string.
"""
exc_type = type(exc).__name__
exc_msg = str(exc)
attempt_word = "attempt" if attempts_made == 1 else "attempts"
return (
f"Tool '{tool_name}' failed after {attempts_made} {attempt_word} "
f"with {exc_type}: {exc_msg}. Please try again."
)
def _handle_failure(
self, tool_name: str, tool_call_id: str | None, exc: Exception, attempts_made: int
) -> ToolMessage:
"""Handle failure when all retries are exhausted.
Args:
tool_name: Name of the tool that failed.
tool_call_id: ID of the tool call (may be `None`).
exc: The exception that caused the failure.
attempts_made: Number of attempts actually made.
Returns:
`ToolMessage` with error details.
Raises:
Exception: If `on_failure` is `'error'`, re-raises the exception.
"""
if self.on_failure == "error":
raise exc
if callable(self.on_failure):
content = self.on_failure(exc)
else:
content = self._format_failure_message(tool_name, exc, attempts_made)
return ToolMessage(
content=content,
tool_call_id=tool_call_id,
name=tool_name,
status="error",
)
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
"""Intercept tool execution and retry on failure.
Args:
request: Tool call request with call dict, `BaseTool`, state, and runtime.
handler: Callable to execute the tool (can be called multiple times).
Returns:
`ToolMessage` or `Command` (the final result).
Raises:
RuntimeError: If the retry loop completes without returning. This should not happen.
"""
tool_name = request.tool.name if request.tool else request.tool_call["name"]
# Check if retry should apply to this tool
if not self._should_retry_tool(tool_name):
return handler(request)
tool_call_id = request.tool_call["id"]
# Initial attempt + retries
for attempt in range(self.max_retries + 1):
try:
return handler(request)
except Exception as exc:
attempts_made = attempt + 1 # attempt is 0-indexed
# Check if we should retry this exception
if not should_retry_exception(exc, self.retry_on):
# Exception is not retryable, handle failure immediately
return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)
# Check if we have more retries left
if attempt < self.max_retries:
# Calculate and apply backoff delay
delay = calculate_delay(
attempt,
backoff_factor=self.backoff_factor,
initial_delay=self.initial_delay,
max_delay=self.max_delay,
jitter=self.jitter,
)
if delay > 0:
time.sleep(delay)
# Continue to next retry
else:
# No more retries, handle failure
return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)
# Unreachable: loop always returns via handler success or _handle_failure
msg = "Unexpected: retry loop completed without returning"
raise RuntimeError(msg)
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
"""Intercept and control async tool execution with retry logic.
Args:
request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
handler: Async callable to execute the tool and returns `ToolMessage` or
`Command`.
Returns:
`ToolMessage` or `Command` (the final result).
Raises:
RuntimeError: If the retry loop completes without returning. This should not happen.
"""
tool_name = request.tool.name if request.tool else request.tool_call["name"]
# Check if retry should apply to this tool
if not self._should_retry_tool(tool_name):
return await handler(request)
tool_call_id = request.tool_call["id"]
# Initial attempt + retries
for attempt in range(self.max_retries + 1):
try:
return await handler(request)
except Exception as exc:
attempts_made = attempt + 1 # attempt is 0-indexed
# Check if we should retry this exception
if not should_retry_exception(exc, self.retry_on):
# Exception is not retryable, handle failure immediately
return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)
# Check if we have more retries left
if attempt < self.max_retries:
# Calculate and apply backoff delay
delay = calculate_delay(
attempt,
backoff_factor=self.backoff_factor,
initial_delay=self.initial_delay,
max_delay=self.max_delay,
jitter=self.jitter,
)
if delay > 0:
await asyncio.sleep(delay)
# Continue to next retry
else:
# No more retries, handle failure
return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)
# Unreachable: loop always returns via handler success or _handle_failure
msg = "Unexpected: retry loop completed without returning"
raise RuntimeError(msg)

View File

@@ -0,0 +1,358 @@
"""LLM-based tool selector middleware."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Annotated, Any, Literal, Union
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from pydantic import Field, TypeAdapter
from typing_extensions import TypedDict
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
from langchain.chat_models.base import init_chat_model
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from langchain.tools import BaseTool
logger = logging.getLogger(__name__)
DEFAULT_SYSTEM_PROMPT = (
"Your goal is to select the most relevant tools for answering the user's query."
)
@dataclass
class _SelectionRequest:
"""Prepared inputs for tool selection."""
available_tools: list[BaseTool]
system_message: str
last_user_message: HumanMessage
model: BaseChatModel
valid_tool_names: list[str]
def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
"""Create a structured output schema for tool selection.
Args:
tools: Available tools to include in the schema.
Returns:
`TypeAdapter` for a schema where each tool name is a `Literal` with its
description.
Raises:
AssertionError: If `tools` is empty.
"""
if not tools:
msg = "Invalid usage: tools must be non-empty"
raise AssertionError(msg)
# Create a Union of Annotated Literal types for each tool name with description
# For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...]
literals = [
Annotated[Literal[tool.name], Field(description=tool.description)] for tool in tools
]
selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007
description = "Tools to use. Place the most relevant tools first."
class ToolSelectionResponse(TypedDict):
"""Use to select relevant tools."""
tools: Annotated[list[selected_tool_type], Field(description=description)] # type: ignore[valid-type]
return TypeAdapter(ToolSelectionResponse)
def _render_tool_list(tools: list[BaseTool]) -> str:
"""Format tools as markdown list.
Args:
tools: Tools to format.
Returns:
Markdown string with each tool on a new line.
"""
return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)
class LLMToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Uses an LLM to select relevant tools before calling the main model.
When an agent has many tools available, this middleware filters them down
to only the most relevant ones for the user's query. This reduces token usage
and helps the main model focus on the right tools.
Examples:
!!! example "Limit to 3 tools"
```python
from langchain.agents.middleware import LLMToolSelectorMiddleware
middleware = LLMToolSelectorMiddleware(max_tools=3)
agent = create_agent(
model="openai:gpt-4o",
tools=[tool1, tool2, tool3, tool4, tool5],
middleware=[middleware],
)
```
!!! example "Use a smaller model for selection"
```python
middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
```
"""
def __init__(
self,
*,
model: str | BaseChatModel | None = None,
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
max_tools: int | None = None,
always_include: list[str] | None = None,
) -> None:
"""Initialize the tool selector.
Args:
model: Model to use for selection.
If not provided, uses the agent's main model.
Can be a model identifier string or `BaseChatModel` instance.
system_prompt: Instructions for the selection model.
max_tools: Maximum number of tools to select.
If the model selects more, only the first `max_tools` will be used.
If not specified, there is no limit.
always_include: Tool names to always include regardless of selection.
These do not count against the `max_tools` limit.
"""
super().__init__()
self.system_prompt = system_prompt
self.max_tools = max_tools
self.always_include = always_include or []
if isinstance(model, (BaseChatModel, type(None))):
self.model: BaseChatModel | None = model
else:
self.model = init_chat_model(model)
def _prepare_selection_request(
self, request: ModelRequest[ContextT]
) -> _SelectionRequest | None:
"""Prepare inputs for tool selection.
Args:
request: the model request.
Returns:
`SelectionRequest` with prepared inputs, or `None` if no selection is
needed.
Raises:
ValueError: If tools in `always_include` are not found in the request.
AssertionError: If no user message is found in the request messages.
"""
# If no tools available, return None
if not request.tools or len(request.tools) == 0:
return None
# Filter to only BaseTool instances (exclude provider-specific tool dicts)
base_tools = [tool for tool in request.tools if not isinstance(tool, dict)]
# Validate that always_include tools exist
if self.always_include:
available_tool_names = {tool.name for tool in base_tools}
missing_tools = [
name for name in self.always_include if name not in available_tool_names
]
if missing_tools:
msg = (
f"Tools in always_include not found in request: {missing_tools}. "
f"Available tools: {sorted(available_tool_names)}"
)
raise ValueError(msg)
# Separate tools that are always included from those available for selection
available_tools = [tool for tool in base_tools if tool.name not in self.always_include]
# If no tools available for selection, return None
if not available_tools:
return None
system_message = self.system_prompt
# If there's a max_tools limit, append instructions to the system prompt
if self.max_tools is not None:
system_message += (
f"\nIMPORTANT: List the tool names in order of relevance, "
f"with the most relevant first. "
f"If you exceed the maximum number of tools, "
f"only the first {self.max_tools} will be used."
)
# Get the last user message from the conversation history
last_user_message: HumanMessage
for message in reversed(request.messages):
if isinstance(message, HumanMessage):
last_user_message = message
break
else:
msg = "No user message found in request messages"
raise AssertionError(msg)
model = self.model or request.model
valid_tool_names = [tool.name for tool in available_tools]
return _SelectionRequest(
available_tools=available_tools,
system_message=system_message,
last_user_message=last_user_message,
model=model,
valid_tool_names=valid_tool_names,
)
def _process_selection_response(
self,
response: dict[str, Any],
available_tools: list[BaseTool],
valid_tool_names: list[str],
request: ModelRequest[ContextT],
) -> ModelRequest[ContextT]:
"""Process the selection response and return filtered `ModelRequest`."""
selected_tool_names: list[str] = []
invalid_tool_selections = []
for tool_name in response["tools"]:
if tool_name not in valid_tool_names:
invalid_tool_selections.append(tool_name)
continue
# Only add if not already selected and within max_tools limit
if tool_name not in selected_tool_names and (
self.max_tools is None or len(selected_tool_names) < self.max_tools
):
selected_tool_names.append(tool_name)
if invalid_tool_selections:
msg = f"Model selected invalid tools: {invalid_tool_selections}"
raise ValueError(msg)
# Filter tools based on selection and append always-included tools
selected_tools: list[BaseTool] = [
tool for tool in available_tools if tool.name in selected_tool_names
]
always_included_tools: list[BaseTool] = [
tool
for tool in request.tools
if not isinstance(tool, dict) and tool.name in self.always_include
]
selected_tools.extend(always_included_tools)
# Also preserve any provider-specific tool dicts from the original request
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
return request.override(tools=[*selected_tools, *provider_tools])
def wrap_model_call(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> ModelResponse[ResponseT] | AIMessage:
"""Filter tools based on LLM selection before invoking the model via handler.
Args:
request: Model request to execute (includes state and runtime).
handler: Async callback that executes the model request and returns
`ModelResponse`.
Returns:
The model call result.
Raises:
AssertionError: If the selection model response is not a dict.
"""
selection_request = self._prepare_selection_request(request)
if selection_request is None:
return handler(request)
# Create dynamic response model with Literal enum of available tool names
type_adapter = _create_tool_selection_response(selection_request.available_tools)
schema = type_adapter.json_schema()
structured_model = selection_request.model.with_structured_output(schema)
response = structured_model.invoke(
[
{"role": "system", "content": selection_request.system_message},
selection_request.last_user_message,
]
)
# Response should be a dict since we're passing a schema (not a Pydantic model class)
if not isinstance(response, dict):
msg = f"Expected dict response, got {type(response)}"
raise AssertionError(msg) # noqa: TRY004
modified_request = self._process_selection_response(
response, selection_request.available_tools, selection_request.valid_tool_names, request
)
return handler(modified_request)
async def awrap_model_call(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> ModelResponse[ResponseT] | AIMessage:
"""Filter tools based on LLM selection before invoking the model via handler.
Args:
request: Model request to execute (includes state and runtime).
handler: Async callback that executes the model request and returns
`ModelResponse`.
Returns:
The model call result.
Raises:
AssertionError: If the selection model response is not a dict.
"""
selection_request = self._prepare_selection_request(request)
if selection_request is None:
return await handler(request)
# Create dynamic response model with Literal enum of available tool names
type_adapter = _create_tool_selection_response(selection_request.available_tools)
schema = type_adapter.json_schema()
structured_model = selection_request.model.with_structured_output(schema)
response = await structured_model.ainvoke(
[
{"role": "system", "content": selection_request.system_message},
selection_request.last_user_message,
]
)
# Response should be a dict since we're passing a schema (not a Pydantic model class)
if not isinstance(response, dict):
msg = f"Expected dict response, got {type(response)}"
raise AssertionError(msg) # noqa: TRY004
modified_request = self._process_selection_response(
response, selection_request.available_tools, selection_request.valid_tool_names, request
)
return await handler(modified_request)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,462 @@
"""Types for setting agent response formats."""
from __future__ import annotations
import json
import uuid
from dataclasses import dataclass, is_dataclass
from types import UnionType
from typing import (
TYPE_CHECKING,
Any,
Generic,
Literal,
TypeVar,
Union,
get_args,
get_origin,
)
from langchain_core.tools import BaseTool, StructuredTool
from pydantic import BaseModel, TypeAdapter
from typing_extensions import Self, is_typeddict
if TYPE_CHECKING:
from collections.abc import Callable, Iterable
from langchain_core.messages import AIMessage
# Supported schema types: Pydantic models, dataclasses, TypedDict, JSON schema dicts
SchemaT = TypeVar("SchemaT")
SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"]
class StructuredOutputError(Exception):
"""Base class for structured output errors."""
ai_message: AIMessage
class MultipleStructuredOutputsError(StructuredOutputError):
"""Raised when model returns multiple structured output tool calls when only one is expected."""
def __init__(self, tool_names: list[str], ai_message: AIMessage) -> None:
"""Initialize `MultipleStructuredOutputsError`.
Args:
tool_names: The names of the tools called for structured output.
ai_message: The AI message that contained the invalid multiple tool calls.
"""
self.tool_names = tool_names
self.ai_message = ai_message
super().__init__(
"Model incorrectly returned multiple structured responses "
f"({', '.join(tool_names)}) when only one is expected."
)
class StructuredOutputValidationError(StructuredOutputError):
"""Raised when structured output tool call arguments fail to parse according to the schema."""
def __init__(self, tool_name: str, source: Exception, ai_message: AIMessage) -> None:
"""Initialize `StructuredOutputValidationError`.
Args:
tool_name: The name of the tool that failed.
source: The exception that occurred.
ai_message: The AI message that contained the invalid structured output.
"""
self.tool_name = tool_name
self.source = source
self.ai_message = ai_message
super().__init__(f"Failed to parse structured output for tool '{tool_name}': {source}.")
def _parse_with_schema(
schema: type[SchemaT] | dict[str, Any], schema_kind: SchemaKind, data: dict[str, Any]
) -> Any:
"""Parse data using for any supported schema type.
Args:
schema: The schema type (Pydantic model, `dataclass`, or `TypedDict`)
schema_kind: One of `'pydantic'`, `'dataclass'`, `'typeddict'`, or
`'json_schema'`
data: The data to parse
Returns:
The parsed instance according to the schema type
Raises:
ValueError: If parsing fails
"""
if schema_kind == "json_schema":
return data
try:
adapter: TypeAdapter[SchemaT] = TypeAdapter(schema)
return adapter.validate_python(data)
except Exception as e:
schema_name = getattr(schema, "__name__", str(schema))
msg = f"Failed to parse data to {schema_name}: {e}"
raise ValueError(msg) from e
@dataclass(init=False)
class _SchemaSpec(Generic[SchemaT]):
"""Describes a structured output schema."""
schema: type[SchemaT] | dict[str, Any]
"""The schema for the response, can be a Pydantic model, `dataclass`, `TypedDict`,
or JSON schema dict.
"""
name: str
"""Name of the schema, used for tool calling.
If not provided, the name will be the class name for models/dataclasses/TypedDicts,
or the `title` field for JSON schemas.
Falls back to a generated name if unavailable.
"""
description: str
"""Custom description of the schema.
If not provided, will use the model's docstring.
"""
schema_kind: SchemaKind
"""The kind of schema."""
json_schema: dict[str, Any]
"""JSON schema associated with the schema."""
strict: bool | None = None
"""Whether to enforce strict validation of the schema."""
def __init__(
self,
schema: type[SchemaT] | dict[str, Any],
*,
name: str | None = None,
description: str | None = None,
strict: bool | None = None,
) -> None:
"""Initialize `SchemaSpec` with schema and optional parameters.
Args:
schema: Schema to describe.
name: Optional name for the schema.
description: Optional description for the schema.
strict: Whether to enforce strict validation of the schema.
Raises:
ValueError: If the schema type is unsupported.
"""
self.schema = schema
if name:
self.name = name
elif isinstance(schema, dict):
self.name = str(schema.get("title", f"response_format_{str(uuid.uuid4())[:4]}"))
else:
self.name = str(getattr(schema, "__name__", f"response_format_{str(uuid.uuid4())[:4]}"))
self.description = description or (
schema.get("description", "")
if isinstance(schema, dict)
else getattr(schema, "__doc__", None) or ""
)
self.strict = strict
if isinstance(schema, dict):
self.schema_kind = "json_schema"
self.json_schema = schema
elif isinstance(schema, type) and issubclass(schema, BaseModel):
self.schema_kind = "pydantic"
self.json_schema = schema.model_json_schema()
elif is_dataclass(schema):
self.schema_kind = "dataclass"
self.json_schema = TypeAdapter(schema).json_schema()
elif is_typeddict(schema):
self.schema_kind = "typeddict"
self.json_schema = TypeAdapter(schema).json_schema()
else:
msg = (
f"Unsupported schema type: {type(schema)}. "
f"Supported types: Pydantic models, dataclasses, TypedDicts, and JSON schema dicts."
)
raise ValueError(msg)
@dataclass(init=False)
class ToolStrategy(Generic[SchemaT]):
"""Use a tool calling strategy for model responses."""
schema: type[SchemaT] | UnionType | dict[str, Any]
"""Schema for the tool calls."""
schema_specs: list[_SchemaSpec[Any]]
"""Schema specs for the tool calls."""
tool_message_content: str | None
"""The content of the tool message to be returned when the model calls
an artificial structured output tool.
"""
handle_errors: (
bool | str | type[Exception] | tuple[type[Exception], ...] | Callable[[Exception], str]
)
"""Error handling strategy for structured output via `ToolStrategy`.
- `True`: Catch all errors with default error template
- `str`: Catch all errors with this custom message
- `type[Exception]`: Only catch this exception type with default message
- `tuple[type[Exception], ...]`: Only catch these exception types with default
message
- `Callable[[Exception], str]`: Custom function that returns error message
- `False`: No retry, let exceptions propagate
"""
def __init__(
self,
schema: type[SchemaT] | UnionType | dict[str, Any],
*,
tool_message_content: str | None = None,
handle_errors: bool
| str
| type[Exception]
| tuple[type[Exception], ...]
| Callable[[Exception], str] = True,
) -> None:
"""Initialize `ToolStrategy`.
Initialize `ToolStrategy` with schemas, tool message content, and error handling
strategy.
"""
self.schema = schema
self.tool_message_content = tool_message_content
self.handle_errors = handle_errors
def _iter_variants(schema: Any) -> Iterable[Any]:
"""Yield leaf variants from Union and JSON Schema oneOf."""
if get_origin(schema) in {UnionType, Union}:
for arg in get_args(schema):
yield from _iter_variants(arg)
return
if isinstance(schema, dict) and "oneOf" in schema:
for sub in schema.get("oneOf", []):
yield from _iter_variants(sub)
return
yield schema
self.schema_specs = [_SchemaSpec(s) for s in _iter_variants(schema)]
@dataclass(init=False)
class ProviderStrategy(Generic[SchemaT]):
"""Use the model provider's native structured output method."""
schema: type[SchemaT] | dict[str, Any]
"""Schema for native mode."""
schema_spec: _SchemaSpec[SchemaT]
"""Schema spec for native mode."""
def __init__(
self,
schema: type[SchemaT] | dict[str, Any],
*,
strict: bool | None = None,
) -> None:
"""Initialize `ProviderStrategy` with schema.
Args:
schema: Schema to enforce via the provider's native structured output.
strict: Whether to request strict provider-side schema enforcement.
"""
self.schema = schema
self.schema_spec = _SchemaSpec(schema, strict=strict)
def to_model_kwargs(self) -> dict[str, Any]:
"""Convert to kwargs to bind to a model to force structured output.
Returns:
The kwargs to bind to a model.
"""
# OpenAI:
# - see https://platform.openai.com/docs/guides/structured-outputs
json_schema: dict[str, Any] = {
"name": self.schema_spec.name,
"schema": self.schema_spec.json_schema,
}
if self.schema_spec.strict:
json_schema["strict"] = True
response_format: dict[str, Any] = {
"type": "json_schema",
"json_schema": json_schema,
}
return {"response_format": response_format}
@dataclass
class OutputToolBinding(Generic[SchemaT]):
"""Information for tracking structured output tool metadata.
This contains all necessary information to handle structured responses generated via
tool calls, including the original schema, its type classification, and the
corresponding tool implementation used by the tools strategy.
"""
schema: type[SchemaT] | dict[str, Any]
"""The original schema provided for structured output (Pydantic model, dataclass,
TypedDict, or JSON schema dict).
"""
schema_kind: SchemaKind
"""Classification of the schema type for proper response construction."""
tool: BaseTool
"""LangChain tool instance created from the schema for model binding."""
@classmethod
def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self:
"""Create an `OutputToolBinding` instance from a `SchemaSpec`.
Args:
schema_spec: The `SchemaSpec` to convert
Returns:
An `OutputToolBinding` instance with the appropriate tool created
"""
return cls(
schema=schema_spec.schema,
schema_kind=schema_spec.schema_kind,
tool=StructuredTool(
args_schema=schema_spec.json_schema,
name=schema_spec.name,
description=schema_spec.description,
),
)
def parse(self, tool_args: dict[str, Any]) -> SchemaT:
"""Parse tool arguments according to the schema.
Args:
tool_args: The arguments from the tool call
Returns:
The parsed response according to the schema type
Raises:
ValueError: If parsing fails
"""
return _parse_with_schema(self.schema, self.schema_kind, tool_args)
@dataclass
class ProviderStrategyBinding(Generic[SchemaT]):
"""Information for tracking native structured output metadata.
This contains all necessary information to handle structured responses generated via
native provider output, including the original schema, its type classification, and
parsing logic for provider-enforced JSON.
"""
schema: type[SchemaT] | dict[str, Any]
"""The original schema provided for structured output (Pydantic model, `dataclass`,
`TypedDict`, or JSON schema dict).
"""
schema_kind: SchemaKind
"""Classification of the schema type for proper response construction."""
@classmethod
def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self:
"""Create a `ProviderStrategyBinding` instance from a `SchemaSpec`.
Args:
schema_spec: The `SchemaSpec` to convert
Returns:
A `ProviderStrategyBinding` instance for parsing native structured output
"""
return cls(
schema=schema_spec.schema,
schema_kind=schema_spec.schema_kind,
)
def parse(self, response: AIMessage) -> SchemaT:
"""Parse `AIMessage` content according to the schema.
Args:
response: The `AIMessage` containing the structured output
Returns:
The parsed response according to the schema
Raises:
ValueError: If text extraction, JSON parsing or schema validation fails
"""
# Extract text content from AIMessage and parse as JSON
raw_text = self._extract_text_content_from_message(response)
try:
data = json.loads(raw_text)
except Exception as e:
schema_name = getattr(self.schema, "__name__", "response_format")
msg = (
f"Native structured output expected valid JSON for {schema_name}, "
f"but parsing failed: {e}."
)
raise ValueError(msg) from e
# Parse according to schema
return _parse_with_schema(self.schema, self.schema_kind, data)
@staticmethod
def _extract_text_content_from_message(message: AIMessage) -> str:
"""Extract text content from an `AIMessage`.
Args:
message: The AI message to extract text from
Returns:
The extracted text content
"""
content = message.content
if isinstance(content, str):
return content
parts: list[str] = []
for c in content:
if isinstance(c, dict):
if c.get("type") == "text" and "text" in c:
parts.append(str(c["text"]))
elif "content" in c and isinstance(c["content"], str):
parts.append(c["content"])
else:
parts.append(str(c))
return "".join(parts)
class AutoStrategy(Generic[SchemaT]):
"""Automatically select the best strategy for structured output."""
schema: type[SchemaT] | dict[str, Any]
"""Schema for automatic mode."""
def __init__(
self,
schema: type[SchemaT] | dict[str, Any],
) -> None:
"""Initialize `AutoStrategy` with schema."""
self.schema = schema
ResponseFormat = ToolStrategy[SchemaT] | ProviderStrategy[SchemaT] | AutoStrategy[SchemaT]
"""Union type for all supported response format strategies."""

View File

@@ -0,0 +1,7 @@
"""Entrypoint to using [chat models](https://docs.langchain.com/oss/python/langchain/models) in LangChain.""" # noqa: E501
from langchain_core.language_models import BaseChatModel
from langchain.chat_models.base import init_chat_model
__all__ = ["BaseChatModel", "init_chat_model"]

View File

@@ -0,0 +1,994 @@
"""Factory functions for chat models."""
from __future__ import annotations
import functools
import importlib
import warnings
from typing import (
TYPE_CHECKING,
Any,
Literal,
TypeAlias,
cast,
overload,
)
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import AIMessage, AnyMessage
from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from typing_extensions import override
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from types import ModuleType
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import BaseTool
from langchain_core.tracers import RunLog, RunLogPatch
from pydantic import BaseModel
def _call(cls: type[BaseChatModel], **kwargs: Any) -> BaseChatModel:
# TODO: replace with operator.call when lower bounding to Python 3.11
return cls(**kwargs)
_BUILTIN_PROVIDERS: dict[str, tuple[str, str, Callable[..., BaseChatModel]]] = {
"anthropic": ("langchain_anthropic", "ChatAnthropic", _call),
"azure_ai": ("langchain_azure_ai.chat_models", "AzureAIChatCompletionsModel", _call),
"azure_openai": ("langchain_openai", "AzureChatOpenAI", _call),
"bedrock": ("langchain_aws", "ChatBedrock", _call),
"bedrock_converse": ("langchain_aws", "ChatBedrockConverse", _call),
"cohere": ("langchain_cohere", "ChatCohere", _call),
"deepseek": ("langchain_deepseek", "ChatDeepSeek", _call),
"fireworks": ("langchain_fireworks", "ChatFireworks", _call),
"google_anthropic_vertex": (
"langchain_google_vertexai.model_garden",
"ChatAnthropicVertex",
_call,
),
"google_genai": ("langchain_google_genai", "ChatGoogleGenerativeAI", _call),
"google_vertexai": ("langchain_google_vertexai", "ChatVertexAI", _call),
"groq": ("langchain_groq", "ChatGroq", _call),
"huggingface": (
"langchain_huggingface",
"ChatHuggingFace",
lambda cls, model, **kwargs: cls.from_model_id(model_id=model, **kwargs),
),
"ibm": (
"langchain_ibm",
"ChatWatsonx",
lambda cls, model, **kwargs: cls(model_id=model, **kwargs),
),
"mistralai": ("langchain_mistralai", "ChatMistralAI", _call),
"nvidia": ("langchain_nvidia_ai_endpoints", "ChatNVIDIA", _call),
"ollama": ("langchain_ollama", "ChatOllama", _call),
"openai": ("langchain_openai", "ChatOpenAI", _call),
"perplexity": ("langchain_perplexity", "ChatPerplexity", _call),
"together": ("langchain_together", "ChatTogether", _call),
"upstage": ("langchain_upstage", "ChatUpstage", _call),
"xai": ("langchain_xai", "ChatXAI", _call),
}
"""Registry mapping provider names to their import configuration.
Each entry maps a provider key to a tuple of:
- `module_path`: The Python module path containing the chat model class.
This may be a submodule (e.g., `'langchain_azure_ai.chat_models'`) if the class is
not exported from the package root.
- `class_name`: The name of the chat model class to import.
- `creator_func`: A callable that instantiates the class with provided kwargs.
!!! note
This dict is not exhaustive of all providers supported by LangChain, but is
meant to cover the most popular ones and serve as a template for adding more
providers in the future. If a provider is not in this dict, it can still be
used with `init_chat_model` as long as its integration package is installed,
but the provider key will not be inferred from the model name and must be
specified explicitly via the `model_provider` parameter.
Refer to the LangChain [integration documentation](https://docs.langchain.com/oss/python/integrations/providers/overview)
for a full list of supported providers and their corresponding packages.
"""
def _import_module(module: str, class_name: str) -> ModuleType:
"""Import a module by name.
Args:
module: The fully qualified module name to import (e.g., `'langchain_openai'`).
class_name: The name of the class being imported, used for error messages.
Returns:
The imported module.
Raises:
ImportError: If the module cannot be imported, with a message suggesting
the pip package to install.
"""
try:
return importlib.import_module(module)
except ImportError as e:
# Extract package name from module path (e.g., "langchain_azure_ai.chat_models"
# becomes "langchain-azure-ai")
pkg = module.split(".", maxsplit=1)[0].replace("_", "-")
msg = (
f"Initializing {class_name} requires the {pkg} package. Please install it "
f"with `pip install {pkg}`"
)
raise ImportError(msg) from e
@functools.lru_cache(maxsize=len(_BUILTIN_PROVIDERS))
def _get_chat_model_creator(
provider: str,
) -> Callable[..., BaseChatModel]:
"""Return a factory function that creates a chat model for the given provider.
This function is cached to avoid repeated module imports.
Args:
provider: The name of the model provider (e.g., `'openai'`, `'anthropic'`).
Must be a key in `_BUILTIN_PROVIDERS`.
Returns:
A callable that accepts model kwargs and returns a `BaseChatModel` instance for
the specified provider.
Raises:
ValueError: If the provider is not in `_BUILTIN_PROVIDERS`.
ImportError: If the provider's integration package is not installed.
"""
if provider not in _BUILTIN_PROVIDERS:
supported = ", ".join(_BUILTIN_PROVIDERS.keys())
msg = f"Unsupported {provider=}.\n\nSupported model providers are: {supported}"
raise ValueError(msg)
pkg, class_name, creator_func = _BUILTIN_PROVIDERS[provider]
try:
module = _import_module(pkg, class_name)
except ImportError as e:
if provider != "ollama":
raise
# For backwards compatibility
try:
module = _import_module("langchain_community.chat_models", class_name)
except ImportError:
# If both langchain-ollama and langchain-community aren't available,
# raise an error related to langchain-ollama
raise e from None
cls = getattr(module, class_name)
return functools.partial(creator_func, cls=cls)
@overload
def init_chat_model(
model: str,
*,
model_provider: str | None = None,
configurable_fields: None = None,
config_prefix: str | None = None,
**kwargs: Any,
) -> BaseChatModel: ...
@overload
def init_chat_model(
model: None = None,
*,
model_provider: str | None = None,
configurable_fields: None = None,
config_prefix: str | None = None,
**kwargs: Any,
) -> _ConfigurableModel: ...
@overload
def init_chat_model(
model: str | None = None,
*,
model_provider: str | None = None,
configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = ...,
config_prefix: str | None = None,
**kwargs: Any,
) -> _ConfigurableModel: ...
# FOR CONTRIBUTORS: If adding support for a new provider, please append the provider
# name to the supported list in the docstring below. Do *not* change the order of the
# existing providers.
def init_chat_model(
model: str | None = None,
*,
model_provider: str | None = None,
configurable_fields: Literal["any"] | list[str] | tuple[str, ...] | None = None,
config_prefix: str | None = None,
**kwargs: Any,
) -> BaseChatModel | _ConfigurableModel:
"""Initialize a chat model from any supported provider using a unified interface.
**Two main use cases:**
1. **Fixed model** specify the model upfront and get a ready-to-use chat model.
2. **Configurable model** choose to specify parameters (including model name) at
runtime via `config`. Makes it easy to switch between models/providers without
changing your code
!!! note "Installation requirements"
Requires the integration package for the chosen model provider to be installed.
See the `model_provider` parameter below for specific package names
(e.g., `pip install langchain-openai`).
Refer to the [provider integration's API reference](https://docs.langchain.com/oss/python/integrations/providers)
for supported model parameters to use as `**kwargs`.
Args:
model: The model name, optionally prefixed with provider (e.g., `'openai:gpt-4o'`).
Prefer exact model IDs from provider docs over aliases for reliable behavior
(e.g., dated versions like `'...-20250514'` instead of `'...-latest'`).
Will attempt to infer `model_provider` from model if not specified.
The following providers will be inferred based on these model prefixes:
- `gpt-...` | `o1...` | `o3...` -> `openai`
- `claude...` -> `anthropic`
- `amazon...` -> `bedrock`
- `gemini...` -> `google_vertexai`
- `command...` -> `cohere`
- `accounts/fireworks...` -> `fireworks`
- `mistral...` -> `mistralai`
- `deepseek...` -> `deepseek`
- `grok...` -> `xai`
- `sonar...` -> `perplexity`
- `solar...` -> `upstage`
model_provider: The model provider if not specified as part of the model arg
(see above).
Supported `model_provider` values and the corresponding integration package
are:
- `openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
- `anthropic` -> [`langchain-anthropic`](https://docs.langchain.com/oss/python/integrations/providers/anthropic)
- `azure_openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
- `azure_ai` -> [`langchain-azure-ai`](https://docs.langchain.com/oss/python/integrations/providers/microsoft)
- `google_vertexai` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `google_genai` -> [`langchain-google-genai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `bedrock` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
- `bedrock_converse` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
- `cohere` -> [`langchain-cohere`](https://docs.langchain.com/oss/python/integrations/providers/cohere)
- `fireworks` -> [`langchain-fireworks`](https://docs.langchain.com/oss/python/integrations/providers/fireworks)
- `together` -> [`langchain-together`](https://docs.langchain.com/oss/python/integrations/providers/together)
- `mistralai` -> [`langchain-mistralai`](https://docs.langchain.com/oss/python/integrations/providers/mistralai)
- `huggingface` -> [`langchain-huggingface`](https://docs.langchain.com/oss/python/integrations/providers/huggingface)
- `groq` -> [`langchain-groq`](https://docs.langchain.com/oss/python/integrations/providers/groq)
- `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama)
- `google_anthropic_vertex` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `deepseek` -> [`langchain-deepseek`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/ibm)
- `nvidia` -> [`langchain-nvidia-ai-endpoints`](https://docs.langchain.com/oss/python/integrations/providers/nvidia)
- `xai` -> [`langchain-xai`](https://docs.langchain.com/oss/python/integrations/providers/xai)
- `perplexity` -> [`langchain-perplexity`](https://docs.langchain.com/oss/python/integrations/providers/perplexity)
- `upstage` -> [`langchain-upstage`](https://docs.langchain.com/oss/python/integrations/providers/upstage)
configurable_fields: Which model parameters are configurable at runtime:
- `None`: No configurable fields (i.e., a fixed model).
- `'any'`: All fields are configurable. **See security note below.**
- `list[str] | Tuple[str, ...]`: Specified fields are configurable.
Fields are assumed to have `config_prefix` stripped if a `config_prefix` is
specified.
If `model` is specified, then defaults to `None`.
If `model` is not specified, then defaults to `("model", "model_provider")`.
!!! warning "Security note"
Setting `configurable_fields="any"` means fields like `api_key`,
`base_url`, etc., can be altered at runtime, potentially redirecting
model requests to a different service/user.
Make sure that if you're accepting untrusted configurations that you
enumerate the `configurable_fields=(...)` explicitly.
config_prefix: Optional prefix for configuration keys.
Useful when you have multiple configurable models in the same application.
If `'config_prefix'` is a non-empty string then `model` will be configurable
at runtime via the `config["configurable"]["{config_prefix}_{param}"]` keys.
See examples below.
If `'config_prefix'` is an empty string then model will be configurable via
`config["configurable"]["{param}"]`.
**kwargs: Additional model-specific keyword args to pass to the underlying
chat model's `__init__` method. Common parameters include:
- `temperature`: Model temperature for controlling randomness.
- `max_tokens`: Maximum number of output tokens.
- `timeout`: Maximum time (in seconds) to wait for a response.
- `max_retries`: Maximum number of retry attempts for failed requests.
- `base_url`: Custom API endpoint URL.
- `rate_limiter`: A
[`BaseRateLimiter`][langchain_core.rate_limiters.BaseRateLimiter]
instance to control request rate.
Refer to the specific model provider's
[integration reference](https://reference.langchain.com/python/integrations/)
for all available parameters.
Returns:
A `BaseChatModel` corresponding to the `model_name` and `model_provider`
specified if configurability is inferred to be `False`. If configurable, a
chat model emulator that initializes the underlying model at runtime once a
config is passed in.
Raises:
ValueError: If `model_provider` cannot be inferred or isn't supported.
ImportError: If the model provider integration package is not installed.
???+ example "Initialize a non-configurable model"
```python
# pip install langchain langchain-openai langchain-anthropic langchain-google-vertexai
from langchain.chat_models import init_chat_model
o3_mini = init_chat_model("openai:o3-mini", temperature=0)
claude_sonnet = init_chat_model("anthropic:claude-sonnet-4-5-20250929", temperature=0)
gemini_2-5_flash = init_chat_model("google_vertexai:gemini-2.5-flash", temperature=0)
o3_mini.invoke("what's your name")
claude_sonnet.invoke("what's your name")
gemini_2-5_flash.invoke("what's your name")
```
??? example "Partially configurable model with no default"
```python
# pip install langchain langchain-openai langchain-anthropic
from langchain.chat_models import init_chat_model
# (We don't need to specify configurable=True if a model isn't specified.)
configurable_model = init_chat_model(temperature=0)
configurable_model.invoke("what's your name", config={"configurable": {"model": "gpt-4o"}})
# Use GPT-4o to generate the response
configurable_model.invoke(
"what's your name",
config={"configurable": {"model": "claude-sonnet-4-5-20250929"}},
)
```
??? example "Fully configurable model with a default"
```python
# pip install langchain langchain-openai langchain-anthropic
from langchain.chat_models import init_chat_model
configurable_model_with_default = init_chat_model(
"openai:gpt-4o",
configurable_fields="any", # This allows us to configure other params like temperature, max_tokens, etc at runtime.
config_prefix="foo",
temperature=0,
)
configurable_model_with_default.invoke("what's your name")
# GPT-4o response with temperature 0 (as set in default)
configurable_model_with_default.invoke(
"what's your name",
config={
"configurable": {
"foo_model": "anthropic:claude-sonnet-4-5-20250929",
"foo_temperature": 0.6,
}
},
)
# Override default to use Sonnet 4.5 with temperature 0.6 to generate response
```
??? example "Bind tools to a configurable model"
You can call any chat model declarative methods on a configurable model in the
same way that you would with a normal model:
```python
# pip install langchain langchain-openai langchain-anthropic
from langchain.chat_models import init_chat_model
from pydantic import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
class GetPopulation(BaseModel):
'''Get the current population in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
configurable_model = init_chat_model(
"gpt-4o", configurable_fields=("model", "model_provider"), temperature=0
)
configurable_model_with_tools = configurable_model.bind_tools(
[
GetWeather,
GetPopulation,
]
)
configurable_model_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?"
)
# Use GPT-4o
configurable_model_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?",
config={"configurable": {"model": "claude-sonnet-4-5-20250929"}},
)
# Use Sonnet 4.5
```
""" # noqa: E501
if not model and not configurable_fields:
configurable_fields = ("model", "model_provider")
config_prefix = config_prefix or ""
if config_prefix and not configurable_fields:
warnings.warn(
f"{config_prefix=} has been set but no fields are configurable. Set "
f"`configurable_fields=(...)` to specify the model params that are "
f"configurable.",
stacklevel=2,
)
if not configurable_fields:
return _init_chat_model_helper(
cast("str", model),
model_provider=model_provider,
**kwargs,
)
if model:
kwargs["model"] = model
if model_provider:
kwargs["model_provider"] = model_provider
return _ConfigurableModel(
default_config=kwargs,
config_prefix=config_prefix,
configurable_fields=configurable_fields,
)
def _init_chat_model_helper(
model: str,
*,
model_provider: str | None = None,
**kwargs: Any,
) -> BaseChatModel:
model, model_provider = _parse_model(model, model_provider)
creator_func = _get_chat_model_creator(model_provider)
return creator_func(model=model, **kwargs)
def _attempt_infer_model_provider(model_name: str) -> str | None:
"""Attempt to infer model provider from model name.
Args:
model_name: The name of the model to infer provider for.
Returns:
The inferred provider name, or `None` if no provider could be inferred.
"""
model_lower = model_name.lower()
# OpenAI models (including newer models and aliases)
if any(
model_lower.startswith(pre)
for pre in (
"gpt-",
"o1",
"o3",
"chatgpt",
"text-davinci",
)
):
return "openai"
# Anthropic models
if model_lower.startswith("claude"):
return "anthropic"
# Cohere models
if model_lower.startswith("command"):
return "cohere"
# Fireworks models
if model_lower.startswith("accounts/fireworks"):
return "fireworks"
# Google models
if model_lower.startswith("gemini"):
return "google_vertexai"
# AWS Bedrock models
if model_lower.startswith(("amazon.", "anthropic.", "meta.")):
return "bedrock"
# Mistral models
if model_lower.startswith(("mistral", "mixtral")):
return "mistralai"
# DeepSeek models
if model_lower.startswith("deepseek"):
return "deepseek"
# xAI models
if model_lower.startswith("grok"):
return "xai"
# Perplexity models
if model_lower.startswith("sonar"):
return "perplexity"
# Upstage models
if model_lower.startswith("solar"):
return "upstage"
return None
def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
"""Parse model name and provider, inferring provider if necessary."""
# Handle provider:model format
if (
not model_provider
and ":" in model
and model.split(":", maxsplit=1)[0] in _BUILTIN_PROVIDERS
):
model_provider = model.split(":", maxsplit=1)[0]
model = ":".join(model.split(":")[1:])
# Attempt to infer provider if not specified
model_provider = model_provider or _attempt_infer_model_provider(model)
if not model_provider:
# Enhanced error message with suggestions
supported_list = ", ".join(sorted(_BUILTIN_PROVIDERS))
msg = (
f"Unable to infer model provider for {model=}. "
f"Please specify 'model_provider' directly.\n\n"
f"Supported providers: {supported_list}\n\n"
f"For help with specific providers, see: "
f"https://docs.langchain.com/oss/python/integrations/providers"
)
raise ValueError(msg)
# Normalize provider name
model_provider = model_provider.replace("-", "_").lower()
return model, model_provider
def _remove_prefix(s: str, prefix: str) -> str:
return s.removeprefix(prefix)
_DECLARATIVE_METHODS = ("bind_tools", "with_structured_output")
class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
def __init__(
self,
*,
default_config: dict[str, Any] | None = None,
configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = "any",
config_prefix: str = "",
queued_declarative_operations: Sequence[tuple[str, tuple[Any, ...], dict[str, Any]]] = (),
) -> None:
self._default_config: dict[str, Any] = default_config or {}
self._configurable_fields: Literal["any"] | list[str] = (
"any" if configurable_fields == "any" else list(configurable_fields)
)
self._config_prefix = (
config_prefix + "_"
if config_prefix and not config_prefix.endswith("_")
else config_prefix
)
self._queued_declarative_operations: list[tuple[str, tuple[Any, ...], dict[str, Any]]] = (
list(
queued_declarative_operations,
)
)
def __getattr__(self, name: str) -> Any:
if name in _DECLARATIVE_METHODS:
# Declarative operations that cannot be applied until after an actual model
# object is instantiated. So instead of returning the actual operation,
# we record the operation and its arguments in a queue. This queue is
# then applied in order whenever we actually instantiate the model (in
# self._model()).
def queue(*args: Any, **kwargs: Any) -> _ConfigurableModel:
queued_declarative_operations = list(
self._queued_declarative_operations,
)
queued_declarative_operations.append((name, args, kwargs))
return _ConfigurableModel(
default_config=dict(self._default_config),
configurable_fields=list(self._configurable_fields)
if isinstance(self._configurable_fields, list)
else self._configurable_fields,
config_prefix=self._config_prefix,
queued_declarative_operations=queued_declarative_operations,
)
return queue
if self._default_config and (model := self._model()) and hasattr(model, name):
return getattr(model, name)
msg = f"{name} is not a BaseChatModel attribute"
if self._default_config:
msg += " and is not implemented on the default model"
msg += "."
raise AttributeError(msg)
def _model(self, config: RunnableConfig | None = None) -> Runnable[Any, Any]:
params = {**self._default_config, **self._model_params(config)}
model = _init_chat_model_helper(**params)
for name, args, kwargs in self._queued_declarative_operations:
model = getattr(model, name)(*args, **kwargs)
return model
def _model_params(self, config: RunnableConfig | None) -> dict[str, Any]:
config = ensure_config(config)
model_params = {
_remove_prefix(k, self._config_prefix): v
for k, v in config.get("configurable", {}).items()
if k.startswith(self._config_prefix)
}
if self._configurable_fields != "any":
model_params = {k: v for k, v in model_params.items() if k in self._configurable_fields}
return model_params
def with_config(
self,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> _ConfigurableModel:
config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
# Ensure config is not None after creation
config = ensure_config(config)
model_params = self._model_params(config)
remaining_config = {k: v for k, v in config.items() if k != "configurable"}
remaining_config["configurable"] = {
k: v
for k, v in config.get("configurable", {}).items()
if _remove_prefix(k, self._config_prefix) not in model_params
}
queued_declarative_operations = list(self._queued_declarative_operations)
if remaining_config:
queued_declarative_operations.append(
(
"with_config",
(),
{"config": remaining_config},
),
)
return _ConfigurableModel(
default_config={**self._default_config, **model_params},
configurable_fields=list(self._configurable_fields)
if isinstance(self._configurable_fields, list)
else self._configurable_fields,
config_prefix=self._config_prefix,
queued_declarative_operations=queued_declarative_operations,
)
@property
@override
def InputType(self) -> TypeAlias:
"""Get the input type for this `Runnable`."""
# This is a version of LanguageModelInput which replaces the abstract
# base class BaseMessage with a union of its subclasses, which makes
# for a much better schema.
return str | StringPromptValue | ChatPromptValueConcrete | list[AnyMessage]
@override
def invoke(
self,
input: LanguageModelInput,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> Any:
return self._model(config).invoke(input, config=config, **kwargs)
@override
async def ainvoke(
self,
input: LanguageModelInput,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> Any:
return await self._model(config).ainvoke(input, config=config, **kwargs)
@override
def stream(
self,
input: LanguageModelInput,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Iterator[Any]:
yield from self._model(config).stream(input, config=config, **kwargs)
@override
async def astream(
self,
input: LanguageModelInput,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> AsyncIterator[Any]:
async for x in self._model(config).astream(input, config=config, **kwargs):
yield x
def batch(
self,
inputs: list[LanguageModelInput],
config: RunnableConfig | list[RunnableConfig] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Any | None,
) -> list[Any]:
config = config or None
# If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
return self._model(config).batch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
# If multiple configs default to Runnable.batch which uses executor to invoke
# in parallel.
return super().batch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
async def abatch(
self,
inputs: list[LanguageModelInput],
config: RunnableConfig | list[RunnableConfig] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Any | None,
) -> list[Any]:
config = config or None
# If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
return await self._model(config).abatch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
# If multiple configs default to Runnable.batch which uses executor to invoke
# in parallel.
return await super().abatch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
def batch_as_completed(
self,
inputs: Sequence[LanguageModelInput],
config: RunnableConfig | Sequence[RunnableConfig] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> Iterator[tuple[int, Any | Exception]]:
config = config or None
# If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
yield from self._model(cast("RunnableConfig", config)).batch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
# If multiple configs default to Runnable.batch which uses executor to invoke
# in parallel.
else:
yield from super().batch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
async def abatch_as_completed(
self,
inputs: Sequence[LanguageModelInput],
config: RunnableConfig | Sequence[RunnableConfig] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> AsyncIterator[tuple[int, Any]]:
config = config or None
# If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
async for x in self._model(
cast("RunnableConfig", config),
).abatch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
):
yield x
# If multiple configs default to Runnable.batch which uses executor to invoke
# in parallel.
else:
async for x in super().abatch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
):
yield x
@override
def transform(
self,
input: Iterator[LanguageModelInput],
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Iterator[Any]:
yield from self._model(config).transform(input, config=config, **kwargs)
@override
async def atransform(
self,
input: AsyncIterator[LanguageModelInput],
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> AsyncIterator[Any]:
async for x in self._model(config).atransform(input, config=config, **kwargs):
yield x
@overload
@override
def astream_log(
self,
input: Any,
config: RunnableConfig | None = None,
*,
diff: Literal[True] = True,
with_streamed_output_list: bool = True,
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> AsyncIterator[RunLogPatch]: ...
@overload
@override
def astream_log(
self,
input: Any,
config: RunnableConfig | None = None,
*,
diff: Literal[False],
with_streamed_output_list: bool = True,
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> AsyncIterator[RunLog]: ...
@override
async def astream_log(
self,
input: Any,
config: RunnableConfig | None = None,
*,
diff: bool = True,
with_streamed_output_list: bool = True,
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> AsyncIterator[RunLogPatch] | AsyncIterator[RunLog]:
async for x in self._model(config).astream_log( # type: ignore[call-overload, misc]
input,
config=config,
diff=diff,
with_streamed_output_list=with_streamed_output_list,
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_tags=exclude_tags,
exclude_types=exclude_types,
exclude_names=exclude_names,
**kwargs,
):
yield x
@override
async def astream_events(
self,
input: Any,
config: RunnableConfig | None = None,
*,
version: Literal["v1", "v2"] = "v2",
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> AsyncIterator[StreamEvent]:
async for x in self._model(config).astream_events(
input,
config=config,
version=version,
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_tags=exclude_tags,
exclude_types=exclude_types,
exclude_names=exclude_names,
**kwargs,
):
yield x
# Explicitly added to satisfy downstream linters.
def bind_tools(
self,
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable[..., Any] | BaseTool],
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
return self.__getattr__("bind_tools")(tools, **kwargs)
# Explicitly added to satisfy downstream linters.
def with_structured_output(
self,
schema: dict[str, Any] | type[BaseModel],
**kwargs: Any,
) -> Runnable[LanguageModelInput, dict[str, Any] | BaseModel]:
return self.__getattr__("with_structured_output")(schema, **kwargs)

View File

@@ -0,0 +1,18 @@
"""Embeddings models.
!!! warning "Modules moved"
With the release of `langchain 1.0.0`, several embeddings modules were moved to
`langchain-classic`, such as `CacheBackedEmbeddings` and all community
embeddings. See [list](https://github.com/langchain-ai/langchain/blob/bdf1cd383ce36dc18381a3bf3fb0a579337a32b5/libs/langchain/langchain/embeddings/__init__.py)
of moved modules to inform your migration.
"""
from langchain_core.embeddings import Embeddings
from langchain.embeddings.base import init_embeddings
__all__ = [
"Embeddings",
"init_embeddings",
]

View File

@@ -0,0 +1,273 @@
"""Factory functions for embeddings."""
import functools
import importlib
from collections.abc import Callable
from typing import Any
from langchain_core.embeddings import Embeddings
def _call(cls: type[Embeddings], **kwargs: Any) -> Embeddings:
return cls(**kwargs)
_BUILTIN_PROVIDERS: dict[str, tuple[str, str, Callable[..., Embeddings]]] = {
"azure_openai": ("langchain_openai", "AzureOpenAIEmbeddings", _call),
"bedrock": (
"langchain_aws",
"BedrockEmbeddings",
lambda cls, model, **kwargs: cls(model_id=model, **kwargs),
),
"cohere": ("langchain_cohere", "CohereEmbeddings", _call),
"google_genai": ("langchain_google_genai", "GoogleGenerativeAIEmbeddings", _call),
"google_vertexai": ("langchain_google_vertexai", "VertexAIEmbeddings", _call),
"huggingface": (
"langchain_huggingface",
"HuggingFaceEmbeddings",
lambda cls, model, **kwargs: cls(model_name=model, **kwargs),
),
"mistralai": ("langchain_mistralai", "MistralAIEmbeddings", _call),
"ollama": ("langchain_ollama", "OllamaEmbeddings", _call),
"openai": ("langchain_openai", "OpenAIEmbeddings", _call),
}
"""Registry mapping provider names to their import configuration.
Each entry maps a provider key to a tuple of:
- `module_path`: The Python module path containing the embeddings class.
- `class_name`: The name of the embeddings class to import.
- `creator_func`: A callable that instantiates the class with provided kwargs.
!!! note
This dict is not exhaustive of all providers supported by LangChain, but is
meant to cover the most popular ones and serve as a template for adding more
providers in the future. If a provider is not in this dict, it can still be
used with `init_chat_model` as long as its integration package is installed,
but the provider key will not be inferred from the model name and must be
specified explicitly via the `model_provider` parameter.
Refer to the LangChain [integration documentation](https://docs.langchain.com/oss/python/integrations/providers/overview)
for a full list of supported providers and their corresponding packages.
"""
@functools.lru_cache(maxsize=len(_BUILTIN_PROVIDERS))
def _get_embeddings_class_creator(provider: str) -> Callable[..., Embeddings]:
"""Return a factory function that creates an embeddings model for the given provider.
This function is cached to avoid repeated module imports.
Args:
provider: The name of the model provider (e.g., `'openai'`, `'cohere'`).
Must be a key in `_BUILTIN_PROVIDERS`.
Returns:
A callable that accepts model kwargs and returns an `Embeddings` instance for
the specified provider.
Raises:
ValueError: If the provider is not in `_BUILTIN_PROVIDERS`.
ImportError: If the provider's integration package is not installed.
"""
if provider not in _BUILTIN_PROVIDERS:
msg = (
f"Provider '{provider}' is not supported.\n"
f"Supported providers and their required packages:\n"
f"{_get_provider_list()}"
)
raise ValueError(msg)
module_name, class_name, creator_func = _BUILTIN_PROVIDERS[provider]
try:
module = importlib.import_module(module_name)
except ImportError as e:
pkg = module_name.replace("_", "-")
msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
raise ImportError(msg) from e
cls = getattr(module, class_name)
return functools.partial(creator_func, cls=cls)
def _get_provider_list() -> str:
"""Get formatted list of providers and their packages."""
return "\n".join(
f" - {p}: {pkg[0].replace('_', '-')}" for p, pkg in _BUILTIN_PROVIDERS.items()
)
def _parse_model_string(model_name: str) -> tuple[str, str]:
"""Parse a model string into provider and model name components.
The model string should be in the format 'provider:model-name', where provider
is one of the supported providers.
Args:
model_name: A model string in the format 'provider:model-name'
Returns:
A tuple of (provider, model_name)
Example:
```python
_parse_model_string("openai:text-embedding-3-small")
# Returns: ("openai", "text-embedding-3-small")
_parse_model_string("bedrock:amazon.titan-embed-text-v1")
# Returns: ("bedrock", "amazon.titan-embed-text-v1")
```
Raises:
ValueError: If the model string is not in the correct format or
the provider is unsupported
"""
if ":" not in model_name:
msg = (
f"Invalid model format '{model_name}'.\n"
f"Model name must be in format 'provider:model-name'\n"
f"Example valid model strings:\n"
f" - openai:text-embedding-3-small\n"
f" - bedrock:amazon.titan-embed-text-v1\n"
f" - cohere:embed-english-v3.0\n"
f"Supported providers: {_BUILTIN_PROVIDERS.keys()}"
)
raise ValueError(msg)
provider, model = model_name.split(":", 1)
provider = provider.lower().strip()
model = model.strip()
if provider not in _BUILTIN_PROVIDERS:
msg = (
f"Provider '{provider}' is not supported.\n"
f"Supported providers and their required packages:\n"
f"{_get_provider_list()}"
)
raise ValueError(msg)
if not model:
msg = "Model name cannot be empty"
raise ValueError(msg)
return provider, model
def _infer_model_and_provider(
model: str,
*,
provider: str | None = None,
) -> tuple[str, str]:
if not model.strip():
msg = "Model name cannot be empty"
raise ValueError(msg)
if provider is None and ":" in model:
provider, model_name = _parse_model_string(model)
else:
model_name = model
if not provider:
msg = (
"Must specify either:\n"
"1. A model string in format 'provider:model-name'\n"
" Example: 'openai:text-embedding-3-small'\n"
"2. Or explicitly set provider from: "
f"{_BUILTIN_PROVIDERS.keys()}"
)
raise ValueError(msg)
if provider not in _BUILTIN_PROVIDERS:
msg = (
f"Provider '{provider}' is not supported.\n"
f"Supported providers and their required packages:\n"
f"{_get_provider_list()}"
)
raise ValueError(msg)
return provider, model_name
def init_embeddings(
model: str,
*,
provider: str | None = None,
**kwargs: Any,
) -> Embeddings:
"""Initialize an embedding model from a model name and optional provider.
!!! note
Requires the integration package for the chosen model provider to be installed.
See the `model_provider` parameter below for specific package names
(e.g., `pip install langchain-openai`).
Refer to the [provider integration's API reference](https://docs.langchain.com/oss/python/integrations/providers)
for supported model parameters to use as `**kwargs`.
Args:
model: The name of the model, e.g. `'openai:text-embedding-3-small'`.
You can also specify model and model provider in a single argument using
`'{model_provider}:{model}'` format, e.g. `'openai:text-embedding-3-small'`.
provider: The model provider if not specified as part of the model arg
(see above).
Supported `provider` values and the corresponding integration package
are:
- `openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
- `azure_openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
- `bedrock` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
- `cohere` -> [`langchain-cohere`](https://docs.langchain.com/oss/python/integrations/providers/cohere)
- `google_vertexai` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `huggingface` -> [`langchain-huggingface`](https://docs.langchain.com/oss/python/integrations/providers/huggingface)
- `mistralai` -> [`langchain-mistralai`](https://docs.langchain.com/oss/python/integrations/providers/mistralai)
- `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama)
**kwargs: Additional model-specific parameters passed to the embedding model.
These vary by provider. Refer to the specific model provider's
[integration reference](https://reference.langchain.com/python/integrations/)
for all available parameters.
Returns:
An `Embeddings` instance that can generate embeddings for text.
Raises:
ValueError: If the model provider is not supported or cannot be determined
ImportError: If the required provider package is not installed
???+ example
```python
# pip install langchain langchain-openai
# Using a model string
model = init_embeddings("openai:text-embedding-3-small")
model.embed_query("Hello, world!")
# Using explicit provider
model = init_embeddings(model="text-embedding-3-small", provider="openai")
model.embed_documents(["Hello, world!", "Goodbye, world!"])
# With additional parameters
model = init_embeddings("openai:text-embedding-3-small", api_key="sk-...")
```
!!! version-added "Added in `langchain` 0.3.9"
"""
if not model:
providers = _BUILTIN_PROVIDERS.keys()
msg = f"Must specify model name. Supported providers are: {', '.join(providers)}"
raise ValueError(msg)
provider, model_name = _infer_model_and_provider(model, provider=provider)
return _get_embeddings_class_creator(provider)(model=model_name, **kwargs)
__all__ = [
"Embeddings", # This one is for backwards compatibility
"init_embeddings",
]

View File

@@ -0,0 +1,73 @@
"""Message and message content types.
Includes message types for different roles (e.g., human, AI, system), as well as types
for message content blocks (e.g., text, image, audio) and tool calls.
"""
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
Annotation,
AnyMessage,
AudioContentBlock,
Citation,
ContentBlock,
DataContentBlock,
FileContentBlock,
HumanMessage,
ImageContentBlock,
InputTokenDetails,
InvalidToolCall,
MessageLikeRepresentation,
NonStandardAnnotation,
NonStandardContentBlock,
OutputTokenDetails,
PlainTextContentBlock,
ReasoningContentBlock,
RemoveMessage,
ServerToolCall,
ServerToolCallChunk,
ServerToolResult,
SystemMessage,
TextContentBlock,
ToolCall,
ToolCallChunk,
ToolMessage,
UsageMetadata,
VideoContentBlock,
trim_messages,
)
__all__ = [
"AIMessage",
"AIMessageChunk",
"Annotation",
"AnyMessage",
"AudioContentBlock",
"Citation",
"ContentBlock",
"DataContentBlock",
"FileContentBlock",
"HumanMessage",
"ImageContentBlock",
"InputTokenDetails",
"InvalidToolCall",
"MessageLikeRepresentation",
"NonStandardAnnotation",
"NonStandardContentBlock",
"OutputTokenDetails",
"PlainTextContentBlock",
"ReasoningContentBlock",
"RemoveMessage",
"ServerToolCall",
"ServerToolCallChunk",
"ServerToolResult",
"SystemMessage",
"TextContentBlock",
"ToolCall",
"ToolCallChunk",
"ToolMessage",
"UsageMetadata",
"VideoContentBlock",
"trim_messages",
]

View File

@@ -0,0 +1,13 @@
"""Base abstraction and in-memory implementation of rate limiters.
These rate limiters can be used to limit the rate of requests to an API.
The rate limiters can be used together with `BaseChatModel`.
"""
from langchain_core.rate_limiters import BaseRateLimiter, InMemoryRateLimiter
__all__ = [
"BaseRateLimiter",
"InMemoryRateLimiter",
]

View File

@@ -0,0 +1,22 @@
"""Tools."""
from langchain_core.tools import (
BaseTool,
InjectedToolArg,
InjectedToolCallId,
ToolException,
tool,
)
from langchain.tools.tool_node import InjectedState, InjectedStore, ToolRuntime
__all__ = [
"BaseTool",
"InjectedState",
"InjectedStore",
"InjectedToolArg",
"InjectedToolCallId",
"ToolException",
"ToolRuntime",
"tool",
]

View File

@@ -0,0 +1,20 @@
"""Utils file included for backwards compat imports."""
from langgraph.prebuilt import InjectedState, InjectedStore, ToolRuntime
from langgraph.prebuilt.tool_node import (
ToolCallRequest,
ToolCallWithContext,
ToolCallWrapper,
)
from langgraph.prebuilt.tool_node import (
ToolNode as _ToolNode, # noqa: F401
)
__all__ = [
"InjectedState",
"InjectedStore",
"ToolCallRequest",
"ToolCallWithContext",
"ToolCallWrapper",
"ToolRuntime",
]