681 lines
23 KiB
Python
681 lines
23 KiB
Python
"""Data models for the sandbox client.

Contains the dataclasses built from sandbox API response dicts (volumes,
templates, pools, provisioning status) and the sync/async handles used to
stream output from commands executed in a sandbox.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncIterator, Iterator
|
|
from dataclasses import dataclass, field
|
|
from typing import TYPE_CHECKING, Any, Optional
|
|
|
|
from langsmith.sandbox._exceptions import (
|
|
SandboxConnectionError,
|
|
SandboxOperationError,
|
|
SandboxServerReloadError,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from langsmith.sandbox._async_sandbox import AsyncSandbox
|
|
from langsmith.sandbox._sandbox import Sandbox
|
|
from langsmith.sandbox._ws_execute import (
|
|
_AsyncWSStreamControl,
|
|
_WSStreamControl,
|
|
)
|
|
|
|
|
|
@dataclass
class ExecutionResult:
    """Outcome of a command that was run inside a sandbox.

    Attributes:
        stdout: Captured standard output of the command.
        stderr: Captured standard error of the command.
        exit_code: Exit status reported for the command.
    """

    stdout: str
    stderr: str
    exit_code: int

    @property
    def success(self) -> bool:
        """Whether the command finished with exit status 0."""
        code = self.exit_code
        return code == 0
|
|
|
|
|
|
@dataclass
class ResourceSpec:
    """Compute and storage request attached to a sandbox.

    Attributes:
        cpu: CPU quantity string; defaults to "500m".
        memory: Memory quantity string; defaults to "512Mi".
        storage: Optional storage quantity string; None when unset.
    """

    cpu: str = field(default="500m")
    memory: str = field(default="512Mi")
    storage: Optional[str] = field(default=None)
|
|
|
|
|
|
@dataclass
class Volume:
    """A persistent volume that sandboxes can mount.

    Attributes:
        name: Display name (can be updated).
        size: Size string as reported by the API.
        storage_class: Storage class name.
        id: Unique identifier (UUID). Stays the same when the name changes;
            may be None for resources created before ID support was added.
        created_at: Creation timestamp string, if provided.
        updated_at: Last-update timestamp string, if provided.
    """

    name: str
    size: str
    storage_class: str
    id: Optional[str] = None
    created_at: Optional[str] = None
    updated_at: Optional[str] = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Volume:
        """Build a Volume from an API response dict.

        Missing keys fall back to placeholders: "" for the name, "unknown"
        for the size, "default" for the storage class, and None for the
        remaining optional fields.
        """
        lookup = data.get
        return cls(
            lookup("name", ""),
            lookup("size", "unknown"),
            lookup("storage_class", "default"),
            id=lookup("id"),
            created_at=lookup("created_at"),
            updated_at=lookup("updated_at"),
        )
|
|
|
|
|
|
@dataclass
class VolumeMountSpec:
    """Where a named volume is mounted inside a sandbox template.

    Attributes:
        volume_name: Name of the volume to mount.
        mount_path: Path inside the sandbox at which the volume appears.
    """

    volume_name: str
    mount_path: str
|
|
|
|
|
|
@dataclass
class SandboxTemplate:
    """Represents a SandboxTemplate.

    Templates define the image, resource limits, and volume mounts for sandboxes.
    All other container details are handled by the server with secure defaults.

    Attributes:
        name: Display name (can be updated).
        image: Container image for sandboxes created from this template.
        resources: CPU / memory / storage request.
        volume_mounts: Volumes mounted into sandboxes created from this
            template.
        id: Unique identifier (UUID). Remains constant even if name changes.
            May be None for resources created before ID support was added.
        created_at: Creation timestamp string, if provided.
        updated_at: Last-update timestamp string, if provided.
    """

    name: str
    image: str
    resources: ResourceSpec
    volume_mounts: list[VolumeMountSpec] = field(default_factory=list)
    id: Optional[str] = None
    created_at: Optional[str] = None
    updated_at: Optional[str] = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> SandboxTemplate:
        """Create a SandboxTemplate from API response dict.

        Missing fields fall back to placeholders ("" name, "unknown" image,
        "500m" CPU, "512Mi" memory, no storage, no volume mounts). An explicit
        ``null`` for ``resources`` or ``volume_mounts`` is treated the same as
        an absent key.
        """
        # ``or`` (instead of a ``.get`` default) also covers explicit JSON
        # nulls, which would otherwise crash on ``.get`` / iteration below.
        resources_data = data.get("resources") or {}
        volume_mounts_data = data.get("volume_mounts") or []
        return cls(
            name=data.get("name", ""),
            image=data.get("image", "unknown"),
            resources=ResourceSpec(
                cpu=resources_data.get("cpu", "500m"),
                memory=resources_data.get("memory", "512Mi"),
                storage=resources_data.get("storage"),
            ),
            volume_mounts=[
                VolumeMountSpec(
                    volume_name=vm.get("volume_name", ""),
                    mount_path=vm.get("mount_path", ""),
                )
                for vm in volume_mounts_data
            ],
            id=data.get("id"),
            created_at=data.get("created_at"),
            updated_at=data.get("updated_at"),
        )
|
|
|
|
|
|
@dataclass
class ResourceStatus:
    """Provisioning state shared by resources that are created asynchronously.

    Attributes:
        status: Lifecycle status: "provisioning", "ready", or "failed".
        status_message: Human-readable detail when status is "failed";
            None otherwise.
    """

    status: str
    status_message: Optional[str] = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ResourceStatus:
        """Build a ResourceStatus from an API response dict.

        A payload with no ``status`` key is reported as "provisioning".
        """
        current = data.get("status", "provisioning")
        detail = data.get("status_message")
        return cls(status=current, status_message=detail)
|
|
|
|
|
|
@dataclass
class Pool:
    """A pool of pre-provisioned sandboxes built from one template.

    Sandboxes can be handed out from a pool that has already been warmed up,
    instead of waiting for a brand-new sandbox to be created.

    Note: Templates with volume mounts cannot be used in pools.

    Attributes:
        name: Display name (can be updated).
        template_name: Name of the template the pool provisions from.
        replicas: Desired number of replicas.
        id: Unique identifier (UUID). Stays the same when the name changes;
            may be None for resources created before ID support was added.
        created_at: Creation timestamp string, if provided.
        updated_at: Last-update timestamp string, if provided.
    """

    name: str
    template_name: str
    replicas: int  # Desired replicas
    id: Optional[str] = None
    created_at: Optional[str] = None
    updated_at: Optional[str] = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Pool:
        """Build a Pool from an API response dict, defaulting missing keys."""
        fallbacks = {
            "name": "",
            "template_name": "",
            "replicas": 0,
            "id": None,
            "created_at": None,
            "updated_at": None,
        }
        kwargs = {key: data.get(key, fallback) for key, fallback in fallbacks.items()}
        return cls(**kwargs)
|
|
|
|
|
|
# =============================================================================
|
|
# WebSocket Command Execution Models
|
|
# =============================================================================
|
|
|
|
|
|
@dataclass
class OutputChunk:
    """One piece of streamed output produced by a running command.

    Attributes:
        stream: Which stream the text came from, "stdout" or "stderr".
        data: Text of the chunk (valid UTF-8; the server splits on character
            boundaries).
        offset: Byte offset of this chunk within its stream. The command
            handles use it to resume after a reconnect; most callers can
            ignore it.
    """

    stream: str
    data: str
    offset: int
|
|
|
|
|
|
class CommandHandle:
    """Handle to a running command with streaming output and auto-reconnect.

    Iterable, yielding OutputChunk objects (stdout and stderr interleaved
    in arrival order). Access .result after iteration to get the full
    ExecutionResult.

    Auto-reconnect behavior:
        - Server hot-reload (1001 Going Away): reconnect immediately
        - Network error / unexpected close: reconnect with exponential backoff
        - User called kill(): do NOT reconnect (propagate error)

    The auto-reconnect is transparent -- the iterator reconnects and
    continues yielding chunks without any user intervention. A reconnect
    attempt that itself fails with SandboxConnectionError is retried under
    the same budget. Once more than MAX_AUTO_RECONNECTS consecutive failures
    occur without any data arriving, SandboxConnectionError is raised.

    Construction modes (controlled by ``command_id``):
        - **New execution** (``command_id=""``, the default): the constructor
          eagerly reads the server's ``"started"`` message to populate
          ``command_id`` and ``pid`` before returning.
        - **Reconnection** (``command_id`` set): skips the started-message
          read, since reconnect streams don't emit one.

    Example:
        handle = sandbox.run("make build", timeout=600, wait=False)

        for chunk in handle:  # auto-reconnects on transient errors
            print(chunk.data, end="")

        result = handle.result
        print(f"Exit code: {result.exit_code}")
    """

    MAX_AUTO_RECONNECTS = 5
    _BACKOFF_BASE = 0.5  # seconds
    _BACKOFF_MAX = 8.0  # seconds

    def __init__(
        self,
        message_stream: Iterator[dict],
        control: Optional[_WSStreamControl],
        sandbox: Sandbox,
        *,
        command_id: str = "",
        stdout_offset: int = 0,
        stderr_offset: int = 0,
    ) -> None:
        """Wrap a command's message stream.

        Args:
            message_stream: Iterator of decoded server messages (dicts with a
                ``"type"`` key such as ``"started"``, ``"stdout"``,
                ``"stderr"`` or ``"exit"``).
            control: Object used to send kill / stdin, or None (kill() and
                send_input() then do nothing).
            sandbox: Sandbox used to open reconnect streams.
            command_id: Existing command ID when resuming; empty for a new
                execution.
            stdout_offset: Byte offset in stdout to resume from.
            stderr_offset: Byte offset in stderr to resume from.

        Raises:
            SandboxOperationError: For a new execution, if the stream ends
                before, or does not begin with, a ``"started"`` message.
        """
        self._stream = message_stream
        self._control = control
        self._sandbox = sandbox
        self._command_id: Optional[str] = None
        self._pid: Optional[int] = None
        self._result: Optional[ExecutionResult] = None
        self._stdout_parts: list[str] = []
        self._stderr_parts: list[str] = []
        self._exhausted = False
        self._last_stdout_offset = stdout_offset
        self._last_stderr_offset = stderr_offset

        # New executions (command_id=""): eager_start reads "started" message.
        # Reconnections (command_id set): skip eager_start since reconnect
        # streams don't send a "started" message.
        if command_id:
            self._command_id = command_id
        else:
            self._consume_started()

    def _consume_started(self) -> None:
        """Eagerly read the 'started' message to populate command_id and pid.

        Blocks briefly until the server sends the started message (arrives
        near-instantly after connection). After this call, command_id and
        pid are available, and the WebSocket is bound to the control object
        (so kill() works).
        """
        try:
            first_msg = next(self._stream)
        except StopIteration:
            # ``from None``: the StopIteration is an implementation detail and
            # only adds noise to the traceback.
            raise SandboxOperationError(
                "Command stream ended before 'started' message",
                operation="command",
            ) from None
        if first_msg.get("type") != "started":
            raise SandboxOperationError(
                f"Expected 'started' message, got '{first_msg.get('type')}'",
                operation="command",
            )
        self._command_id = first_msg.get("command_id")
        self._pid = first_msg.get("pid")

    @property
    def command_id(self) -> Optional[str]:
        """The server-assigned command ID. Available after construction."""
        return self._command_id

    @property
    def pid(self) -> Optional[int]:
        """The process ID on the sandbox. Available after construction."""
        return self._pid

    @property
    def result(self) -> ExecutionResult:
        """The final execution result. Blocks until the command completes.

        Drains the remaining stream if not already exhausted, then returns
        the ExecutionResult with aggregated stdout, stderr, and exit_code.

        Raises:
            SandboxOperationError: If the stream ended without an exit
                message.
        """
        if self._result is None:
            for _ in self:
                pass
        if self._result is None:
            raise SandboxOperationError(
                "Command stream ended without exit message",
                operation="command",
            )
        return self._result

    def _iter_stream(self) -> Iterator[OutputChunk]:
        """Iterate over output chunks from the current stream (no reconnect).

        Accumulates stdout/stderr text and stores the ExecutionResult when an
        "exit" message arrives. Messages of other types are skipped.
        """
        if self._exhausted:
            return
        for msg in self._stream:
            msg_type = msg.get("type")
            if msg_type in ("stdout", "stderr"):
                chunk = OutputChunk(
                    stream=msg_type,
                    data=msg["data"],
                    offset=msg.get("offset", 0),
                )
                if msg_type == "stdout":
                    self._stdout_parts.append(msg["data"])
                else:
                    self._stderr_parts.append(msg["data"])
                yield chunk
            elif msg_type == "exit":
                self._result = ExecutionResult(
                    stdout="".join(self._stdout_parts),
                    stderr="".join(self._stderr_parts),
                    exit_code=msg["exit_code"],
                )
                self._exhausted = True
                return
        self._exhausted = True

    def _advance_offsets(self, chunk: OutputChunk) -> None:
        """Record the byte offset just past ``chunk`` for its stream."""
        # Offsets are byte positions, so measure the UTF-8 encoding rather
        # than the character count.
        end = chunk.offset + len(chunk.data.encode("utf-8"))
        if chunk.stream == "stdout":
            self._last_stdout_offset = end
        else:
            self._last_stderr_offset = end

    def _resume_stream(self) -> None:
        """Swap in a fresh stream/control resumed from the last offsets."""
        new_handle = self.reconnect()
        self._stream = new_handle._stream
        self._control = new_handle._control
        self._exhausted = False

    def __iter__(self) -> Iterator[OutputChunk]:
        """Iterate over output chunks, auto-reconnecting on transient errors.

        Reconnect strategy:
            - 1001 Going Away (hot-reload): immediate reconnect, no delay
            - Other SandboxConnectionError: exponential backoff (0.5s, 1s, 2s...)
            - After kill(): no reconnect, error propagates

        Failed reconnect attempts count toward MAX_AUTO_RECONNECTS just like
        dropped streams do.
        """
        import time

        reconnect_attempts = 0
        # The reconnect itself runs inside the ``try`` below. That way a
        # reconnect that fails is retried with backoff instead of escaping
        # the loop on its first failure.
        pending_reconnect = False
        while True:
            try:
                if pending_reconnect:
                    self._resume_stream()
                    pending_reconnect = False
                for chunk in self._iter_stream():
                    reconnect_attempts = 0  # Reset on successful data
                    self._advance_offsets(chunk)
                    yield chunk
                return  # Stream ended normally (exit message received)

            except SandboxConnectionError as e:
                if self._control and self._control.killed:
                    raise

                reconnect_attempts += 1
                if reconnect_attempts > self.MAX_AUTO_RECONNECTS:
                    raise SandboxConnectionError(
                        f"Lost connection {reconnect_attempts} times in "
                        f"succession, giving up"
                    ) from e

                if not isinstance(e, SandboxServerReloadError):
                    delay = min(
                        self._BACKOFF_BASE * (2 ** (reconnect_attempts - 1)),
                        self._BACKOFF_MAX,
                    )
                    time.sleep(delay)
                pending_reconnect = True

    def kill(self) -> None:
        """Send a kill signal to the running command (SIGKILL).

        The server kills the entire process group. The stream will
        subsequently yield an exit message with a non-zero exit code.

        Has no effect if the command has already exited or the
        WebSocket connection is closed.
        """
        if self._control:
            self._control.send_kill()

    def send_input(self, data: str) -> None:
        """Write data to the command's stdin.

        Args:
            data: String data to write to stdin.

        Has no effect if the command has already exited or the
        WebSocket connection is closed.
        """
        if self._control:
            self._control.send_input(data)

    @property
    def last_stdout_offset(self) -> int:
        """Last known stdout byte offset (for manual reconnection)."""
        return self._last_stdout_offset

    @property
    def last_stderr_offset(self) -> int:
        """Last known stderr byte offset (for manual reconnection)."""
        return self._last_stderr_offset

    def reconnect(self) -> CommandHandle:
        """Reconnect to this command from the last known offsets.

        Returns a new handle that resumes output from where this one
        left off. Any output produced while disconnected is replayed
        from the server's ring buffer.

        Returns:
            A new CommandHandle.

        Raises:
            SandboxOperationError: If command_id is unknown, not found, or
                the session expired.
            SandboxConnectionError: If connection to sandbox fails.
        """
        # Explicit check instead of ``assert``, which is stripped under -O.
        if self._command_id is None:
            raise SandboxOperationError(
                "Cannot reconnect: command_id is unknown",
                operation="command",
            )
        return self._sandbox.reconnect(
            self._command_id,
            stdout_offset=self._last_stdout_offset,
            stderr_offset=self._last_stderr_offset,
        )
|
|
|
|
|
|
class AsyncCommandHandle:
    """Async handle to a running command with streaming output and auto-reconnect.

    Async iterable, yielding OutputChunk objects (stdout and stderr interleaved
    in arrival order). Access .result after iteration to get the full
    ExecutionResult.

    Auto-reconnect behavior:
        - Server hot-reload (1001 Going Away): reconnect immediately
        - Network error / unexpected close: reconnect with exponential backoff
        - User called kill(): do NOT reconnect (propagate error)

    A reconnect attempt that itself fails with SandboxConnectionError is
    retried under the same budget. Once more than MAX_AUTO_RECONNECTS
    consecutive failures occur without any data arriving,
    SandboxConnectionError is raised.

    Construction modes (controlled by ``command_id``):
        - **New execution** (``command_id=""``, the default): call
          ``await handle._ensure_started()`` after construction to read the
          server's ``"started"`` message and populate ``command_id`` / ``pid``.
        - **Reconnection** (``command_id`` set): skips the started-message
          read, since reconnect streams don't emit one.

    Example:
        handle = await sandbox.run("make build", timeout=600, wait=False)

        async for chunk in handle:  # auto-reconnects on transient errors
            print(chunk.data, end="")

        result = await handle.result
        print(f"Exit code: {result.exit_code}")
    """

    MAX_AUTO_RECONNECTS = 5
    _BACKOFF_BASE = 0.5  # seconds
    _BACKOFF_MAX = 8.0  # seconds

    def __init__(
        self,
        message_stream: AsyncIterator[dict],
        control: Optional[_AsyncWSStreamControl],
        sandbox: AsyncSandbox,
        *,
        command_id: str = "",
        stdout_offset: int = 0,
        stderr_offset: int = 0,
    ) -> None:
        """Wrap a command's async message stream.

        Args:
            message_stream: Async iterator of decoded server messages (dicts
                with a ``"type"`` key such as ``"started"``, ``"stdout"``,
                ``"stderr"`` or ``"exit"``).
            control: Object used to send kill / stdin, or None (kill() and
                send_input() then do nothing).
            sandbox: AsyncSandbox used to open reconnect streams.
            command_id: Existing command ID when resuming; empty for a new
                execution.
            stdout_offset: Byte offset in stdout to resume from.
            stderr_offset: Byte offset in stderr to resume from.
        """
        self._stream = message_stream
        self._control = control
        self._sandbox = sandbox
        self._command_id: Optional[str] = None
        self._pid: Optional[int] = None
        self._result: Optional[ExecutionResult] = None
        self._stdout_parts: list[str] = []
        self._stderr_parts: list[str] = []
        self._exhausted = False
        self._last_stdout_offset = stdout_offset
        self._last_stderr_offset = stderr_offset

        # New executions (command_id=""): _ensure_started reads "started".
        # Reconnections (command_id set): skip since reconnect streams
        # don't send a "started" message.
        if command_id:
            self._command_id = command_id
            self._started = True
        else:
            self._started = False

    async def _ensure_started(self) -> None:
        """Read the 'started' message to populate command_id and pid.

        No-op if already started (or constructed for reconnection).

        Raises:
            SandboxOperationError: If the stream ends before, or does not
                begin with, a ``"started"`` message.
        """
        if self._started:
            return
        try:
            first_msg = await self._stream.__anext__()
        except StopAsyncIteration:
            # ``from None``: the StopAsyncIteration is an implementation
            # detail and only adds noise to the traceback.
            raise SandboxOperationError(
                "Command stream ended before 'started' message",
                operation="command",
            ) from None
        if first_msg.get("type") != "started":
            raise SandboxOperationError(
                f"Expected 'started' message, got '{first_msg.get('type')}'",
                operation="command",
            )
        self._command_id = first_msg.get("command_id")
        self._pid = first_msg.get("pid")
        self._started = True

    @property
    def command_id(self) -> Optional[str]:
        """The server-assigned command ID. Available after _ensure_started."""
        return self._command_id

    @property
    def pid(self) -> Optional[int]:
        """The process ID on the sandbox. Available after _ensure_started."""
        return self._pid

    @property
    async def result(self) -> ExecutionResult:
        """The final execution result. Awaitable.

        Drains the remaining stream if needed.

        Raises:
            SandboxOperationError: If the stream ended without an exit
                message.
        """
        if self._result is None:
            async for _ in self:
                pass
        if self._result is None:
            raise SandboxOperationError(
                "Command stream ended without exit message",
                operation="command",
            )
        return self._result

    async def _aiter_stream(self) -> AsyncIterator[OutputChunk]:
        """Iterate over output chunks from the current stream (no reconnect).

        Accumulates stdout/stderr text and stores the ExecutionResult when an
        "exit" message arrives. Messages of other types are skipped.
        """
        await self._ensure_started()
        if self._exhausted:
            return
        async for msg in self._stream:
            msg_type = msg.get("type")
            if msg_type in ("stdout", "stderr"):
                chunk = OutputChunk(
                    stream=msg_type,
                    data=msg["data"],
                    offset=msg.get("offset", 0),
                )
                if msg_type == "stdout":
                    self._stdout_parts.append(msg["data"])
                else:
                    self._stderr_parts.append(msg["data"])
                yield chunk
            elif msg_type == "exit":
                self._result = ExecutionResult(
                    stdout="".join(self._stdout_parts),
                    stderr="".join(self._stderr_parts),
                    exit_code=msg["exit_code"],
                )
                self._exhausted = True
                return
        self._exhausted = True

    def _advance_offsets(self, chunk: OutputChunk) -> None:
        """Record the byte offset just past ``chunk`` for its stream."""
        # Offsets are byte positions, so measure the UTF-8 encoding rather
        # than the character count.
        end = chunk.offset + len(chunk.data.encode("utf-8"))
        if chunk.stream == "stdout":
            self._last_stdout_offset = end
        else:
            self._last_stderr_offset = end

    async def _resume_stream(self) -> None:
        """Swap in a fresh stream/control resumed from the last offsets."""
        new_handle = await self.reconnect()
        self._stream = new_handle._stream
        self._control = new_handle._control
        self._exhausted = False

    async def __aiter__(self) -> AsyncIterator[OutputChunk]:
        """Async iterate with auto-reconnect on transient errors.

        Reconnect strategy:
            - 1001 Going Away (hot-reload): immediate reconnect, no delay
            - Other SandboxConnectionError: exponential backoff (0.5s, 1s, 2s...)
            - After kill(): no reconnect, error propagates

        Failed reconnect attempts count toward MAX_AUTO_RECONNECTS just like
        dropped streams do.
        """
        import asyncio

        reconnect_attempts = 0
        # The reconnect itself runs inside the ``try`` below. That way a
        # reconnect that fails is retried with backoff instead of escaping
        # the loop on its first failure.
        pending_reconnect = False
        while True:
            try:
                if pending_reconnect:
                    await self._resume_stream()
                    pending_reconnect = False
                async for chunk in self._aiter_stream():
                    reconnect_attempts = 0  # Reset on successful data
                    self._advance_offsets(chunk)
                    yield chunk
                return  # Stream ended normally

            except SandboxConnectionError as e:
                if self._control and self._control.killed:
                    raise

                reconnect_attempts += 1
                if reconnect_attempts > self.MAX_AUTO_RECONNECTS:
                    raise SandboxConnectionError(
                        f"Lost connection {reconnect_attempts} times "
                        f"in succession, giving up"
                    ) from e

                if not isinstance(e, SandboxServerReloadError):
                    delay = min(
                        self._BACKOFF_BASE * (2 ** (reconnect_attempts - 1)),
                        self._BACKOFF_MAX,
                    )
                    await asyncio.sleep(delay)
                pending_reconnect = True

    async def kill(self) -> None:
        """Send a kill signal to the running command.

        Does nothing when no control object is attached.
        """
        if self._control:
            await self._control.send_kill()

    async def send_input(self, data: str) -> None:
        """Write data to the command's stdin.

        Does nothing when no control object is attached.
        """
        if self._control:
            await self._control.send_input(data)

    @property
    def last_stdout_offset(self) -> int:
        """Last known stdout byte offset (for manual reconnection)."""
        return self._last_stdout_offset

    @property
    def last_stderr_offset(self) -> int:
        """Last known stderr byte offset (for manual reconnection)."""
        return self._last_stderr_offset

    async def reconnect(self) -> AsyncCommandHandle:
        """Reconnect to this command from the last known offsets.

        Raises:
            SandboxOperationError: If command_id is unknown (e.g. the started
                message has not been read yet).
        """
        # Explicit check instead of ``assert``, which is stripped under -O.
        if self._command_id is None:
            raise SandboxOperationError(
                "Cannot reconnect: command_id is unknown",
                operation="command",
            )
        return await self._sandbox.reconnect(
            self._command_id,
            stdout_offset=self._last_stdout_offset,
            stderr_offset=self._last_stderr_offset,
        )
|