initial commit
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf."""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
216
venv/Lib/site-packages/langchain_classic/agents/mrkl/base.py
Normal file
216
venv/Lib/site-packages/langchain_classic/agents/mrkl/base.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.tools import BaseTool, Tool
|
||||
from langchain_core.tools.render import render_text_description
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic._api.deprecation import AGENT_DEPRECATION_WARNING
|
||||
from langchain_classic.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
||||
from langchain_classic.agents.agent_types import AgentType
|
||||
from langchain_classic.agents.mrkl.output_parser import MRKLOutputParser
|
||||
from langchain_classic.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||
from langchain_classic.agents.utils import validate_tools_single_input
|
||||
from langchain_classic.chains import LLMChain
|
||||
|
||||
|
||||
class ChainConfig(NamedTuple):
    """Describe a single action that a MRKL system can dispatch to.

    Args:
        action_name: Name under which the action is exposed.
        action: Callable that is invoked when the action is selected.
        action_description: Text explaining what the action does.
    """

    action_name: str
    action: Callable
    action_description: str
|
||||
|
||||
|
||||
@deprecated(
    "0.1.0",
    message=AGENT_DEPRECATION_WARNING,
    removal="1.0",
)
class ZeroShotAgent(Agent):
    """Zero-shot agent used by the MRKL chain.

    Args:
        output_parser: Parser applied to the LLM output; a ``MRKLOutputParser``
            is created when none is supplied.
    """

    output_parser: AgentOutputParser = Field(default_factory=MRKLOutputParser)

    @classmethod
    @override
    def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
        # Keyword arguments are accepted for signature compatibility only;
        # the default parser is built without them.
        return MRKLOutputParser()

    @property
    def _agent_type(self) -> str:
        """Identifier of this agent type."""
        return AgentType.ZERO_SHOT_REACT_DESCRIPTION

    @property
    def observation_prefix(self) -> str:
        """Text placed in front of each observation.

        Returns:
            "Observation: "
        """
        return "Observation: "

    @property
    def llm_prefix(self) -> str:
        """Text placed in front of each LLM call.

        Returns:
            "Thought:"
        """
        return "Thought:"

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: list[str] | None = None,
    ) -> PromptTemplate:
        """Assemble the zero-shot agent prompt.

        The template is ``prefix``, the rendered tool descriptions, the
        format instructions (with ``{tool_names}`` filled in) and ``suffix``,
        separated by blank lines.

        Args:
            tools: Tools the agent may use; their names and descriptions are
                rendered into the prompt.
            prefix: Text placed before the tool list.
            suffix: Text placed after the format instructions.
            format_instructions: Instructions describing the expected output
                format; formatted with ``tool_names``.
            input_variables: Input variables of the final prompt. When empty
                or ``None`` they are derived from the template.

        Returns:
            A PromptTemplate built from the assembled template.
        """
        rendered_tools = render_text_description(list(tools))
        joined_names = ", ".join([tool.name for tool in tools])
        instructions = format_instructions.format(tool_names=joined_names)
        template = f"{prefix}\n\n{rendered_tools}\n\n{instructions}\n\n{suffix}"
        if not input_variables:
            # No explicit variables given: let the template infer them.
            return PromptTemplate.from_template(template)
        return PromptTemplate(template=template, input_variables=input_variables)

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: BaseCallbackManager | None = None,
        output_parser: AgentOutputParser | None = None,
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: list[str] | None = None,
        **kwargs: Any,
    ) -> Agent:
        """Build an agent from an LLM and a set of tools.

        Args:
            llm: Language model that drives the agent.
            tools: Tools the agent may call; validated before use.
            callback_manager: Callback manager handed to the LLM chain.
            output_parser: Parser for the LLM output; the class default is
                used when omitted.
            prefix: Prompt prefix, forwarded to ``create_prompt``.
            suffix: Prompt suffix, forwarded to ``create_prompt``.
            format_instructions: Format instructions, forwarded to
                ``create_prompt``.
            input_variables: Prompt input variables, forwarded to
                ``create_prompt``.
            kwargs: Extra keyword arguments passed to the agent constructor.

        Returns:
            The constructed agent.
        """
        cls._validate_tools(tools)
        chain = LLMChain(
            llm=llm,
            prompt=cls.create_prompt(
                tools,
                prefix=prefix,
                suffix=suffix,
                format_instructions=format_instructions,
                input_variables=input_variables,
            ),
            callback_manager=callback_manager,
        )
        parser = output_parser or cls._get_default_output_parser()
        return cls(
            llm_chain=chain,
            allowed_tools=[tool.name for tool in tools],
            output_parser=parser,
            **kwargs,
        )

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Check that the tools are usable by this agent.

        Raises:
            ValueError: If no tools are given or a tool has no description.
        """
        validate_tools_single_input(cls.__name__, tools)
        if len(tools) == 0:
            msg = (
                f"Got no tools for {cls.__name__}. At least one tool must be provided."
            )
            raise ValueError(msg)
        # Report the first tool (in order) that lacks a description.
        undescribed = next(
            (tool for tool in tools if tool.description is None),
            None,
        )
        if undescribed is not None:
            msg = (
                f"Got a tool {undescribed.name} without a description. For this agent, "
                f"a description must always be provided."
            )
            raise ValueError(msg)
        super()._validate_tools(tools)
|
||||
|
||||
|
||||
@deprecated(
    "0.1.0",
    message=AGENT_DEPRECATION_WARNING,
    removal="1.0",
)
class MRKLChain(AgentExecutor):
    """Agent executor implementing the MRKL system."""

    @classmethod
    def from_chains(
        cls,
        llm: BaseLanguageModel,
        chains: list[ChainConfig],
        **kwargs: Any,
    ) -> AgentExecutor:
        """Create a MRKL chain from a list of chain configurations.

        Convenience constructor: each configuration is wrapped in a ``Tool``
        and a ``ZeroShotAgent`` is built over those tools.

        Args:
            llm: Language model that drives the agent.
            chains: Configurations describing the actions available to the
                MRKL system.
            **kwargs: Extra keyword arguments passed to the executor
                constructor.

        Returns:
            An initialized MRKL chain.
        """
        wrapped_tools = [
            Tool(
                name=config.action_name,
                func=config.action,
                description=config.action_description,
            )
            for config in chains
        ]
        return cls(
            agent=ZeroShotAgent.from_llm_and_tools(llm, wrapped_tools),
            tools=wrapped_tools,
            **kwargs,
        )
|
||||
@@ -0,0 +1,99 @@
|
||||
import re
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
|
||||
from langchain_classic.agents.agent import AgentOutputParser
|
||||
from langchain_classic.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
|
||||
# Marker that introduces the final answer in the LLM output.
FINAL_ANSWER_ACTION = "Final Answer:"

# Observation used when no "Action:" line can be found.
# NOTE: the closing quote after 'Thought: is absent in the message; it is kept
# as-is so the emitted text does not change.
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = (
    "Invalid Format: "
    "Missing 'Action:' after 'Thought:"
)

# Observation used when an "Action:" line has no "Action Input:" line.
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE = (
    "Invalid Format: "
    "Missing 'Action Input:' after 'Action:'"
)

# Error prefix used when the output holds both a final answer and an action.
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
    "Parsing LLM output produced both "
    "a final answer and a parse-able action:"
)
|
||||
|
||||
|
||||
class MRKLOutputParser(AgentOutputParser):
    """MRKL Output parser for the chat agent."""

    format_instructions: str = FORMAT_INSTRUCTIONS
    """Default formatting instructions"""

    def get_format_instructions(self) -> str:
        """Returns formatting instructions for the given output parser."""
        return self.format_instructions

    def parse(self, text: str) -> AgentAction | AgentFinish:
        """Parse the output from the agent into an AgentAction or AgentFinish object.

        The text is inspected in this order:

        1. It contains both ``Final Answer:`` and a parsable
           ``Action: ... Action Input: ...`` pair. If the final answer comes
           first, the answer up to the next blank line (or the end of the
           text when there is none) is returned; otherwise an error is raised.
        2. It contains only a parsable action: an ``AgentAction`` is returned.
        3. It contains only ``Final Answer:``: everything after the last
           marker is returned as the answer.
        4. Otherwise an error is raised, with an observation for the LLM when
           the ``Action:`` or ``Action Input:`` part is missing.

        Args:
            text: The text to parse.

        Returns:
            An AgentAction or AgentFinish object.

        Raises:
            OutputParserException: If the output could not be parsed.
        """
        includes_answer = FINAL_ANSWER_ACTION in text
        regex = r"Action\s*\d*\s*:[\s]*(.*?)Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
        action_match = re.search(regex, text, re.DOTALL)
        if action_match and includes_answer:
            if text.find(FINAL_ANSWER_ACTION) < text.find(action_match.group(0)):
                # if final answer is before the hallucination, return final answer
                start_index = text.find(FINAL_ANSWER_ACTION) + len(FINAL_ANSWER_ACTION)
                end_index = text.find("\n\n", start_index)
                if end_index == -1:
                    # str.find returns -1 when no blank line follows; slicing
                    # with -1 would drop the last character, so use the end
                    # of the text instead.
                    end_index = len(text)
                return AgentFinish(
                    {"output": text[start_index:end_index].strip()},
                    text[:end_index],
                )
            msg = f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
            raise OutputParserException(msg)

        if action_match:
            action = action_match.group(1).strip()
            action_input = action_match.group(2)
            # Only spaces are trimmed here; other whitespace is left in place.
            tool_input = action_input.strip(" ")
            # ensure if its a well formed SQL query we don't remove any trailing " chars
            if not tool_input.startswith("SELECT "):
                tool_input = tool_input.strip('"')

            return AgentAction(action, tool_input, text)

        if includes_answer:
            return AgentFinish(
                {"output": text.rsplit(FINAL_ANSWER_ACTION, maxsplit=1)[-1].strip()},
                text,
            )

        # Nothing parsable: work out which part is missing so the observation
        # sent back to the LLM is specific.
        if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL):
            msg = f"Could not parse LLM output: `{text}`"
            raise OutputParserException(
                msg,
                observation=MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE,
                llm_output=text,
                send_to_llm=True,
            )
        if not re.search(
            r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)",
            text,
            re.DOTALL,
        ):
            msg = f"Could not parse LLM output: `{text}`"
            raise OutputParserException(
                msg,
                observation=MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
                llm_output=text,
                send_to_llm=True,
            )
        # Both parts exist but not in a form the combined pattern accepts.
        msg = f"Could not parse LLM output: `{text}`"
        raise OutputParserException(msg)

    @property
    def _type(self) -> str:
        """Type identifier of this parser."""
        return "mrkl"
|
||||
@@ -0,0 +1,15 @@
|
||||
PREFIX = """Answer the following questions as best you can. You have access to the following tools:""" # noqa: E501
|
||||
FORMAT_INSTRUCTIONS = """Use the following format:
|
||||
|
||||
Question: the input question you must answer
|
||||
Thought: you should always think about what to do
|
||||
Action: the action to take, should be one of [{tool_names}]
|
||||
Action Input: the input to the action
|
||||
Observation: the result of the action
|
||||
... (this Thought/Action/Action Input/Observation can repeat N times)
|
||||
Thought: I now know the final answer
|
||||
Final Answer: the final answer to the original input question"""
|
||||
SUFFIX = """Begin!
|
||||
|
||||
Question: {input}
|
||||
Thought:{agent_scratchpad}"""
|
||||
Reference in New Issue
Block a user