initial commit
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
"""Shell tool."""
|
||||
|
||||
from langchain_community.tools.shell.tool import ShellTool
|
||||
|
||||
__all__ = ["ShellTool"]
|
||||
Binary file not shown.
Binary file not shown.
103
venv/Lib/site-packages/langchain_community/tools/shell/tool.py
Normal file
103
venv/Lib/site-packages/langchain_community/tools/shell/tool.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import logging
|
||||
import platform
|
||||
import warnings
|
||||
from typing import Any, List, Optional, Type, Union
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ShellInput(BaseModel):
|
||||
"""Commands for the Bash Shell tool."""
|
||||
|
||||
commands: Union[str, List[str]] = Field(
|
||||
...,
|
||||
description="List of shell commands to run. Deserialized using json.loads",
|
||||
)
|
||||
"""List of shell commands to run."""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _validate_commands(cls, values: dict) -> Any:
|
||||
"""Validate commands."""
|
||||
# TODO: Add real validators
|
||||
commands = values.get("commands")
|
||||
if not isinstance(commands, list):
|
||||
values["commands"] = [commands]
|
||||
# Warn that the bash tool is not safe
|
||||
warnings.warn(
|
||||
"The shell tool has no safeguards by default. Use at your own risk."
|
||||
)
|
||||
return values
|
||||
|
||||
|
||||
def _get_default_bash_process() -> Any:
|
||||
"""Get default bash process."""
|
||||
try:
|
||||
from langchain_experimental.llm_bash.bash import BashProcess
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"BashProcess has been moved to langchain experimental."
|
||||
"To use this tool, install langchain-experimental "
|
||||
"with `pip install langchain-experimental`."
|
||||
)
|
||||
return BashProcess(return_err_output=True)
|
||||
|
||||
|
||||
def _get_platform() -> str:
|
||||
"""Get platform."""
|
||||
system = platform.system()
|
||||
if system == "Darwin":
|
||||
return "MacOS"
|
||||
return system
|
||||
|
||||
|
||||
class ShellTool(BaseTool):
|
||||
"""Tool to run shell commands."""
|
||||
|
||||
process: Any = Field(default_factory=_get_default_bash_process)
|
||||
"""Bash process to run commands."""
|
||||
|
||||
name: str = "terminal"
|
||||
"""Name of tool."""
|
||||
|
||||
description: str = f"Run shell commands on this {_get_platform()} machine."
|
||||
"""Description of tool."""
|
||||
|
||||
args_schema: Type[BaseModel] = ShellInput
|
||||
"""Schema for input arguments."""
|
||||
|
||||
ask_human_input: bool = False
|
||||
"""
|
||||
If True, prompts the user for confirmation (y/n) before executing
|
||||
a command generated by the language model in the bash shell.
|
||||
"""
|
||||
|
||||
def _run(
|
||||
self,
|
||||
commands: Union[str, List[str]],
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Run commands and return final output."""
|
||||
|
||||
print(f"Executing command:\n {commands}") # noqa: T201
|
||||
|
||||
try:
|
||||
if self.ask_human_input:
|
||||
user_input = input("Proceed with command execution? (y/n): ").lower()
|
||||
if user_input == "y":
|
||||
return self.process.run(commands)
|
||||
else:
|
||||
logger.info("Invalid input. User aborted command execution.")
|
||||
return None # type: ignore[return-value]
|
||||
else:
|
||||
return self.process.run(commands)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during command execution: {e}")
|
||||
return None # type: ignore[return-value]
|
||||
Reference in New Issue
Block a user