initial commit
This commit is contained in:
@@ -0,0 +1,12 @@
|
||||
from langchain_classic.chains.router.base import MultiRouteChain, RouterChain
|
||||
from langchain_classic.chains.router.llm_router import LLMRouterChain
|
||||
from langchain_classic.chains.router.multi_prompt import MultiPromptChain
|
||||
from langchain_classic.chains.router.multi_retrieval_qa import MultiRetrievalQAChain
|
||||
|
||||
__all__ = [
|
||||
"LLMRouterChain",
|
||||
"MultiPromptChain",
|
||||
"MultiRetrievalQAChain",
|
||||
"MultiRouteChain",
|
||||
"RouterChain",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
147
venv/Lib/site-packages/langchain_classic/chains/router/base.py
Normal file
147
venv/Lib/site-packages/langchain_classic/chains/router/base.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Base classes for chain routing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
|
||||
|
||||
class Route(NamedTuple):
|
||||
"""A route to a destination chain."""
|
||||
|
||||
destination: str | None
|
||||
next_inputs: dict[str, Any]
|
||||
|
||||
|
||||
class RouterChain(Chain, ABC):
|
||||
"""Chain that outputs the name of a destination chain and the inputs to it."""
|
||||
|
||||
@property
|
||||
@override
|
||||
def output_keys(self) -> list[str]:
|
||||
return ["destination", "next_inputs"]
|
||||
|
||||
def route(self, inputs: dict[str, Any], callbacks: Callbacks = None) -> Route:
|
||||
"""Route inputs to a destination chain.
|
||||
|
||||
Args:
|
||||
inputs: inputs to the chain
|
||||
callbacks: callbacks to use for the chain
|
||||
|
||||
Returns:
|
||||
a Route object
|
||||
"""
|
||||
result = self(inputs, callbacks=callbacks)
|
||||
return Route(result["destination"], result["next_inputs"])
|
||||
|
||||
async def aroute(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
callbacks: Callbacks = None,
|
||||
) -> Route:
|
||||
"""Route inputs to a destination chain.
|
||||
|
||||
Args:
|
||||
inputs: inputs to the chain
|
||||
callbacks: callbacks to use for the chain
|
||||
|
||||
Returns:
|
||||
a Route object
|
||||
"""
|
||||
result = await self.acall(inputs, callbacks=callbacks)
|
||||
return Route(result["destination"], result["next_inputs"])
|
||||
|
||||
|
||||
class MultiRouteChain(Chain):
|
||||
"""Use a single chain to route an input to one of multiple candidate chains."""
|
||||
|
||||
router_chain: RouterChain
|
||||
"""Chain that routes inputs to destination chains."""
|
||||
destination_chains: Mapping[str, Chain]
|
||||
"""Chains that return final answer to inputs."""
|
||||
default_chain: Chain
|
||||
"""Default chain to use when none of the destination chains are suitable."""
|
||||
silent_errors: bool = False
|
||||
"""If `True`, use default_chain when an invalid destination name is provided."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> list[str]:
|
||||
"""Will be whatever keys the router chain prompt expects."""
|
||||
return self.router_chain.input_keys
|
||||
|
||||
@property
|
||||
def output_keys(self) -> list[str]:
|
||||
"""Will always return text key."""
|
||||
return []
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
run_manager: CallbackManagerForChainRun | None = None,
|
||||
) -> dict[str, Any]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
route = self.router_chain.route(inputs, callbacks=callbacks)
|
||||
|
||||
_run_manager.on_text(
|
||||
str(route.destination) + ": " + str(route.next_inputs),
|
||||
verbose=self.verbose,
|
||||
)
|
||||
if not route.destination:
|
||||
return self.default_chain(route.next_inputs, callbacks=callbacks)
|
||||
if route.destination in self.destination_chains:
|
||||
return self.destination_chains[route.destination](
|
||||
route.next_inputs,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
if self.silent_errors:
|
||||
return self.default_chain(route.next_inputs, callbacks=callbacks)
|
||||
msg = f"Received invalid destination chain name '{route.destination}'"
|
||||
raise ValueError(msg)
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
run_manager: AsyncCallbackManagerForChainRun | None = None,
|
||||
) -> dict[str, Any]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
route = await self.router_chain.aroute(inputs, callbacks=callbacks)
|
||||
|
||||
await _run_manager.on_text(
|
||||
str(route.destination) + ": " + str(route.next_inputs),
|
||||
verbose=self.verbose,
|
||||
)
|
||||
if not route.destination:
|
||||
return await self.default_chain.acall(
|
||||
route.next_inputs,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
if route.destination in self.destination_chains:
|
||||
return await self.destination_chains[route.destination].acall(
|
||||
route.next_inputs,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
if self.silent_errors:
|
||||
return await self.default_chain.acall(
|
||||
route.next_inputs,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
msg = f"Received invalid destination chain name '{route.destination}'"
|
||||
raise ValueError(msg)
|
||||
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.chains.router.base import RouterChain
|
||||
|
||||
|
||||
class EmbeddingRouterChain(RouterChain):
|
||||
"""Chain that uses embeddings to route between options."""
|
||||
|
||||
vectorstore: VectorStore
|
||||
routing_keys: list[str] = ["query"]
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> list[str]:
|
||||
"""Will be whatever keys the LLM chain prompt expects."""
|
||||
return self.routing_keys
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
run_manager: CallbackManagerForChainRun | None = None,
|
||||
) -> dict[str, Any]:
|
||||
_input = ", ".join([inputs[k] for k in self.routing_keys])
|
||||
results = self.vectorstore.similarity_search(_input, k=1)
|
||||
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
|
||||
|
||||
@override
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
run_manager: AsyncCallbackManagerForChainRun | None = None,
|
||||
) -> dict[str, Any]:
|
||||
_input = ", ".join([inputs[k] for k in self.routing_keys])
|
||||
results = await self.vectorstore.asimilarity_search(_input, k=1)
|
||||
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
|
||||
|
||||
@classmethod
|
||||
def from_names_and_descriptions(
|
||||
cls,
|
||||
names_and_descriptions: Sequence[tuple[str, Sequence[str]]],
|
||||
vectorstore_cls: type[VectorStore],
|
||||
embeddings: Embeddings,
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingRouterChain:
|
||||
"""Convenience constructor."""
|
||||
documents = []
|
||||
for name, descriptions in names_and_descriptions:
|
||||
documents.extend(
|
||||
[
|
||||
Document(page_content=description, metadata={"name": name})
|
||||
for description in descriptions
|
||||
]
|
||||
)
|
||||
vectorstore = vectorstore_cls.from_documents(documents, embeddings)
|
||||
return cls(vectorstore=vectorstore, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def afrom_names_and_descriptions(
|
||||
cls,
|
||||
names_and_descriptions: Sequence[tuple[str, Sequence[str]]],
|
||||
vectorstore_cls: type[VectorStore],
|
||||
embeddings: Embeddings,
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingRouterChain:
|
||||
"""Convenience constructor."""
|
||||
documents = []
|
||||
documents.extend(
|
||||
[
|
||||
Document(page_content=description, metadata={"name": name})
|
||||
for name, descriptions in names_and_descriptions
|
||||
for description in descriptions
|
||||
]
|
||||
)
|
||||
vectorstore = await vectorstore_cls.afrom_documents(documents, embeddings)
|
||||
return cls(vectorstore=vectorstore, **kwargs)
|
||||
@@ -0,0 +1,196 @@
|
||||
"""Base classes for LLM-powered router chains."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.utils.json import parse_and_check_json_markdown
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from langchain_classic.chains import LLMChain
|
||||
from langchain_classic.chains.router.base import RouterChain
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.2.12",
|
||||
removal="1.0",
|
||||
message=(
|
||||
"Use RunnableLambda to select from multiple prompt templates. See example "
|
||||
"in API reference: "
|
||||
"https://api.python.langchain.com/en/latest/chains/langchain.chains.router.llm_router.LLMRouterChain.html"
|
||||
),
|
||||
)
|
||||
class LLMRouterChain(RouterChain):
|
||||
"""A router chain that uses an LLM chain to perform routing.
|
||||
|
||||
This class is deprecated. See below for a replacement, which offers several
|
||||
benefits, including streaming and batch support.
|
||||
|
||||
Below is an example implementation:
|
||||
|
||||
```python
|
||||
from operator import itemgetter
|
||||
from typing import Literal
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
model = ChatOpenAI(model="gpt-4o-mini")
|
||||
|
||||
prompt_1 = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are an expert on animals."),
|
||||
("human", "{query}"),
|
||||
]
|
||||
)
|
||||
prompt_2 = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are an expert on vegetables."),
|
||||
("human", "{query}"),
|
||||
]
|
||||
)
|
||||
|
||||
chain_1 = prompt_1 | model | StrOutputParser()
|
||||
chain_2 = prompt_2 | model | StrOutputParser()
|
||||
|
||||
route_system = "Route the user's query to either the animal "
|
||||
"or vegetable expert."
|
||||
route_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", route_system),
|
||||
("human", "{query}"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class RouteQuery(TypedDict):
|
||||
\"\"\"Route query to destination.\"\"\"
|
||||
destination: Literal["animal", "vegetable"]
|
||||
|
||||
|
||||
route_chain = (
|
||||
route_prompt
|
||||
| model.with_structured_output(RouteQuery)
|
||||
| itemgetter("destination")
|
||||
)
|
||||
|
||||
chain = {
|
||||
"destination": route_chain, # "animal" or "vegetable"
|
||||
"query": lambda x: x["query"], # pass through input query
|
||||
} | RunnableLambda(
|
||||
# if animal, chain_1. otherwise, chain_2.
|
||||
lambda x: chain_1 if x["destination"] == "animal" else chain_2,
|
||||
)
|
||||
|
||||
chain.invoke({"query": "what color are carrots"})
|
||||
|
||||
```
|
||||
"""
|
||||
|
||||
llm_chain: LLMChain
|
||||
"""LLM chain used to perform routing"""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_prompt(self) -> Self:
|
||||
prompt = self.llm_chain.prompt
|
||||
if prompt.output_parser is None:
|
||||
msg = (
|
||||
"LLMRouterChain requires base llm_chain prompt to have an output"
|
||||
" parser that converts LLM text output to a dictionary with keys"
|
||||
" 'destination' and 'next_inputs'. Received a prompt with no output"
|
||||
" parser."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return self
|
||||
|
||||
@property
|
||||
def input_keys(self) -> list[str]:
|
||||
"""Will be whatever keys the LLM chain prompt expects."""
|
||||
return self.llm_chain.input_keys
|
||||
|
||||
def _validate_outputs(self, outputs: dict[str, Any]) -> None:
|
||||
super()._validate_outputs(outputs)
|
||||
if not isinstance(outputs["next_inputs"], dict):
|
||||
raise ValueError # noqa: TRY004
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
run_manager: CallbackManagerForChainRun | None = None,
|
||||
) -> dict[str, Any]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
|
||||
prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
|
||||
return cast(
|
||||
"dict[str, Any]",
|
||||
self.llm_chain.prompt.output_parser.parse(prediction),
|
||||
)
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
run_manager: AsyncCallbackManagerForChainRun | None = None,
|
||||
) -> dict[str, Any]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
return cast(
|
||||
"dict[str, Any]",
|
||||
await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate,
|
||||
**kwargs: Any,
|
||||
) -> LLMRouterChain:
|
||||
"""Convenience constructor."""
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
||||
|
||||
class RouterOutputParser(BaseOutputParser[dict[str, str]]):
|
||||
"""Parser for output of router chain in the multi-prompt chain."""
|
||||
|
||||
default_destination: str = "DEFAULT"
|
||||
next_inputs_type: type = str
|
||||
next_inputs_inner_key: str = "input"
|
||||
|
||||
@override
|
||||
def parse(self, text: str) -> dict[str, Any]:
|
||||
try:
|
||||
expected_keys = ["destination", "next_inputs"]
|
||||
parsed = parse_and_check_json_markdown(text, expected_keys)
|
||||
if not isinstance(parsed["destination"], str):
|
||||
msg = "Expected 'destination' to be a string."
|
||||
raise TypeError(msg)
|
||||
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
|
||||
msg = f"Expected 'next_inputs' to be {self.next_inputs_type}."
|
||||
raise TypeError(msg)
|
||||
parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
|
||||
if (
|
||||
parsed["destination"].strip().lower()
|
||||
== self.default_destination.lower()
|
||||
):
|
||||
parsed["destination"] = None
|
||||
else:
|
||||
parsed["destination"] = parsed["destination"].strip()
|
||||
except Exception as e:
|
||||
msg = f"Parsing text\n{text}\n raised following error:\n{e}"
|
||||
raise OutputParserException(msg) from e
|
||||
return parsed
|
||||
@@ -0,0 +1,190 @@
|
||||
"""Use a single chain to route an input to one of multiple llm chains."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.chains import ConversationChain
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_classic.chains.router.base import MultiRouteChain
|
||||
from langchain_classic.chains.router.llm_router import (
|
||||
LLMRouterChain,
|
||||
RouterOutputParser,
|
||||
)
|
||||
from langchain_classic.chains.router.multi_prompt_prompt import (
|
||||
MULTI_PROMPT_ROUTER_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.2.12",
|
||||
removal="1.0",
|
||||
message=(
|
||||
"Please see migration guide here for recommended implementation: "
|
||||
"https://python.langchain.com/docs/versions/migrating_chains/multi_prompt_chain/"
|
||||
),
|
||||
)
|
||||
class MultiPromptChain(MultiRouteChain):
|
||||
"""A multi-route chain that uses an LLM router chain to choose amongst prompts.
|
||||
|
||||
This class is deprecated. See below for a replacement, which offers several
|
||||
benefits, including streaming and batch support.
|
||||
|
||||
Below is an example implementation:
|
||||
|
||||
```python
|
||||
from operator import itemgetter
|
||||
from typing import Literal
|
||||
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
model = ChatOpenAI(model="gpt-4o-mini")
|
||||
|
||||
# Define the prompts we will route to
|
||||
prompt_1 = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are an expert on animals."),
|
||||
("human", "{input}"),
|
||||
]
|
||||
)
|
||||
prompt_2 = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are an expert on vegetables."),
|
||||
("human", "{input}"),
|
||||
]
|
||||
)
|
||||
|
||||
# Construct the chains we will route to. These format the input query
|
||||
# into the respective prompt, run it through a chat model, and cast
|
||||
# the result to a string.
|
||||
chain_1 = prompt_1 | model | StrOutputParser()
|
||||
chain_2 = prompt_2 | model | StrOutputParser()
|
||||
|
||||
|
||||
# Next: define the chain that selects which branch to route to.
|
||||
# Here we will take advantage of tool-calling features to force
|
||||
# the output to select one of two desired branches.
|
||||
route_system = "Route the user's query to either the animal "
|
||||
"or vegetable expert."
|
||||
route_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", route_system),
|
||||
("human", "{input}"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# Define schema for output:
|
||||
class RouteQuery(TypedDict):
|
||||
\"\"\"Route query to destination expert.\"\"\"
|
||||
|
||||
destination: Literal["animal", "vegetable"]
|
||||
|
||||
|
||||
route_chain = route_prompt | model.with_structured_output(RouteQuery)
|
||||
|
||||
|
||||
# For LangGraph, we will define the state of the graph to hold the query,
|
||||
# destination, and final answer.
|
||||
class State(TypedDict):
|
||||
query: str
|
||||
destination: RouteQuery
|
||||
answer: str
|
||||
|
||||
|
||||
# We define functions for each node, including routing the query:
|
||||
async def route_query(state: State, config: RunnableConfig):
|
||||
destination = await route_chain.ainvoke(state["query"], config)
|
||||
return {"destination": destination}
|
||||
|
||||
|
||||
# And one node for each prompt
|
||||
async def prompt_1(state: State, config: RunnableConfig):
|
||||
return {"answer": await chain_1.ainvoke(state["query"], config)}
|
||||
|
||||
|
||||
async def prompt_2(state: State, config: RunnableConfig):
|
||||
return {"answer": await chain_2.ainvoke(state["query"], config)}
|
||||
|
||||
|
||||
# We then define logic that selects the prompt based on the classification
|
||||
def select_node(state: State) -> Literal["prompt_1", "prompt_2"]:
|
||||
if state["destination"] == "animal":
|
||||
return "prompt_1"
|
||||
else:
|
||||
return "prompt_2"
|
||||
|
||||
|
||||
# Finally, assemble the multi-prompt chain. This is a sequence of two steps:
|
||||
# 1) Select "animal" or "vegetable" via the route_chain, and collect the
|
||||
# answer alongside the input query.
|
||||
# 2) Route the input query to chain_1 or chain_2, based on the
|
||||
# selection.
|
||||
graph = StateGraph(State)
|
||||
graph.add_node("route_query", route_query)
|
||||
graph.add_node("prompt_1", prompt_1)
|
||||
graph.add_node("prompt_2", prompt_2)
|
||||
|
||||
graph.add_edge(START, "route_query")
|
||||
graph.add_conditional_edges("route_query", select_node)
|
||||
graph.add_edge("prompt_1", END)
|
||||
graph.add_edge("prompt_2", END)
|
||||
app = graph.compile()
|
||||
|
||||
result = await app.ainvoke({"query": "what color are carrots"})
|
||||
print(result["destination"])
|
||||
print(result["answer"])
|
||||
|
||||
```
|
||||
"""
|
||||
|
||||
@property
|
||||
@override
|
||||
def output_keys(self) -> list[str]:
|
||||
return ["text"]
|
||||
|
||||
@classmethod
|
||||
def from_prompts(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt_infos: list[dict[str, str]],
|
||||
default_chain: Chain | None = None,
|
||||
**kwargs: Any,
|
||||
) -> MultiPromptChain:
|
||||
"""Convenience constructor for instantiating from destination prompts."""
|
||||
destinations = [f"{p['name']}: {p['description']}" for p in prompt_infos]
|
||||
destinations_str = "\n".join(destinations)
|
||||
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
|
||||
destinations=destinations_str,
|
||||
)
|
||||
router_prompt = PromptTemplate(
|
||||
template=router_template,
|
||||
input_variables=["input"],
|
||||
output_parser=RouterOutputParser(),
|
||||
)
|
||||
router_chain = LLMRouterChain.from_llm(llm, router_prompt)
|
||||
destination_chains = {}
|
||||
for p_info in prompt_infos:
|
||||
name = p_info["name"]
|
||||
prompt_template = p_info["prompt_template"]
|
||||
prompt = PromptTemplate(template=prompt_template, input_variables=["input"])
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
destination_chains[name] = chain
|
||||
_default_chain = default_chain or ConversationChain(llm=llm, output_key="text")
|
||||
return cls(
|
||||
router_chain=router_chain,
|
||||
destination_chains=destination_chains,
|
||||
default_chain=_default_chain,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Prompt for the router chain in the multi-prompt chain."""
|
||||
|
||||
MULTI_PROMPT_ROUTER_TEMPLATE = """\
|
||||
Given a raw text input to a language model select the model prompt best suited for \
|
||||
the input. You will be given the names of the available prompts and a description of \
|
||||
what the prompt is best suited for. You may also revise the original input if you \
|
||||
think that revising it will ultimately lead to a better response from the language \
|
||||
model.
|
||||
|
||||
<< FORMATTING >>
|
||||
Return a markdown code snippet with a JSON object formatted to look like:
|
||||
```json
|
||||
{{{{
|
||||
"destination": string \\ name of the prompt to use or "DEFAULT"
|
||||
"next_inputs": string \\ a potentially modified version of the original input
|
||||
}}}}
|
||||
```
|
||||
|
||||
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
|
||||
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
|
||||
REMEMBER: "next_inputs" can just be the original input if you don't think any \
|
||||
modifications are needed.
|
||||
|
||||
<< CANDIDATE PROMPTS >>
|
||||
{destinations}
|
||||
|
||||
<< INPUT >>
|
||||
{{input}}
|
||||
|
||||
<< OUTPUT (must include ```json at the start of the response) >>
|
||||
<< OUTPUT (must end with ```) >>
|
||||
"""
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Prompt for the router chain in the multi-retrieval qa chain."""
|
||||
|
||||
MULTI_RETRIEVAL_ROUTER_TEMPLATE = """\
|
||||
Given a query to a question answering system select the system best suited \
|
||||
for the input. You will be given the names of the available systems and a description \
|
||||
of what questions the system is best suited for. You may also revise the original \
|
||||
input if you think that revising it will ultimately lead to a better response.
|
||||
|
||||
<< FORMATTING >>
|
||||
Return a markdown code snippet with a JSON object formatted to look like:
|
||||
```json
|
||||
{{{{
|
||||
"destination": string \\ name of the question answering system to use or "DEFAULT"
|
||||
"next_inputs": string \\ a potentially modified version of the original input
|
||||
}}}}
|
||||
```
|
||||
|
||||
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
|
||||
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
|
||||
REMEMBER: "next_inputs" can just be the original input if you don't think any \
|
||||
modifications are needed.
|
||||
|
||||
<< CANDIDATE PROMPTS >>
|
||||
{destinations}
|
||||
|
||||
<< INPUT >>
|
||||
{{input}}
|
||||
|
||||
<< OUTPUT >>
|
||||
"""
|
||||
@@ -0,0 +1,134 @@
|
||||
"""Use a single chain to route an input to one of multiple retrieval qa chains."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.chains import ConversationChain
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.conversation.prompt import DEFAULT_TEMPLATE
|
||||
from langchain_classic.chains.retrieval_qa.base import BaseRetrievalQA, RetrievalQA
|
||||
from langchain_classic.chains.router.base import MultiRouteChain
|
||||
from langchain_classic.chains.router.llm_router import (
|
||||
LLMRouterChain,
|
||||
RouterOutputParser,
|
||||
)
|
||||
from langchain_classic.chains.router.multi_retrieval_prompt import (
|
||||
MULTI_RETRIEVAL_ROUTER_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
class MultiRetrievalQAChain(MultiRouteChain):
|
||||
"""Multi Retrieval QA Chain.
|
||||
|
||||
A multi-route chain that uses an LLM router chain to choose amongst retrieval
|
||||
qa chains.
|
||||
"""
|
||||
|
||||
router_chain: LLMRouterChain
|
||||
"""Chain for deciding a destination chain and the input to it."""
|
||||
destination_chains: Mapping[str, BaseRetrievalQA]
|
||||
"""Map of name to candidate chains that inputs can be routed to."""
|
||||
default_chain: Chain
|
||||
"""Default chain to use when router doesn't map input to one of the destinations."""
|
||||
|
||||
@property
|
||||
@override
|
||||
def output_keys(self) -> list[str]:
|
||||
return ["result"]
|
||||
|
||||
@classmethod
|
||||
def from_retrievers(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
retriever_infos: list[dict[str, Any]],
|
||||
default_retriever: BaseRetriever | None = None,
|
||||
default_prompt: PromptTemplate | None = None,
|
||||
default_chain: Chain | None = None,
|
||||
*,
|
||||
default_chain_llm: BaseLanguageModel | None = None,
|
||||
**kwargs: Any,
|
||||
) -> MultiRetrievalQAChain:
|
||||
"""Create a multi retrieval qa chain from an LLM and a default chain.
|
||||
|
||||
Args:
|
||||
llm: The language model to use.
|
||||
retriever_infos: Dictionaries containing retriever information.
|
||||
default_retriever: Optional default retriever to use if no default chain
|
||||
is provided.
|
||||
default_prompt: Optional prompt template to use for the default retriever.
|
||||
default_chain: Optional default chain to use when router doesn't map input
|
||||
to one of the destinations.
|
||||
default_chain_llm: Optional language model to use if no default chain and
|
||||
no default retriever are provided.
|
||||
**kwargs: Additional keyword arguments to pass to the chain.
|
||||
|
||||
Returns:
|
||||
An instance of the multi retrieval qa chain.
|
||||
"""
|
||||
if default_prompt and not default_retriever:
|
||||
msg = (
|
||||
"`default_retriever` must be specified if `default_prompt` is "
|
||||
"provided. Received only `default_prompt`."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
destinations = [f"{r['name']}: {r['description']}" for r in retriever_infos]
|
||||
destinations_str = "\n".join(destinations)
|
||||
router_template = MULTI_RETRIEVAL_ROUTER_TEMPLATE.format(
|
||||
destinations=destinations_str,
|
||||
)
|
||||
router_prompt = PromptTemplate(
|
||||
template=router_template,
|
||||
input_variables=["input"],
|
||||
output_parser=RouterOutputParser(next_inputs_inner_key="query"),
|
||||
)
|
||||
router_chain = LLMRouterChain.from_llm(llm, router_prompt)
|
||||
destination_chains = {}
|
||||
for r_info in retriever_infos:
|
||||
prompt = r_info.get("prompt")
|
||||
retriever = r_info["retriever"]
|
||||
chain = RetrievalQA.from_llm(llm, prompt=prompt, retriever=retriever)
|
||||
name = r_info["name"]
|
||||
destination_chains[name] = chain
|
||||
if default_chain:
|
||||
_default_chain = default_chain
|
||||
elif default_retriever:
|
||||
_default_chain = RetrievalQA.from_llm(
|
||||
llm,
|
||||
prompt=default_prompt,
|
||||
retriever=default_retriever,
|
||||
)
|
||||
else:
|
||||
prompt_template = DEFAULT_TEMPLATE.replace("input", "query")
|
||||
prompt = PromptTemplate(
|
||||
template=prompt_template,
|
||||
input_variables=["history", "query"],
|
||||
)
|
||||
if default_chain_llm is None:
|
||||
msg = (
|
||||
"conversation_llm must be provided if default_chain is not "
|
||||
"specified. This API has been changed to avoid instantiating "
|
||||
"default LLMs on behalf of users."
|
||||
"You can provide a conversation LLM like so:\n"
|
||||
"from langchain_openai import ChatOpenAI\n"
|
||||
"model = ChatOpenAI()"
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
_default_chain = ConversationChain(
|
||||
llm=default_chain_llm,
|
||||
prompt=prompt,
|
||||
input_key="query",
|
||||
output_key="result",
|
||||
)
|
||||
return cls(
|
||||
router_chain=router_chain,
|
||||
destination_chains=destination_chains,
|
||||
default_chain=_default_chain,
|
||||
**kwargs,
|
||||
)
|
||||
Reference in New Issue
Block a user