initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,12 @@
from langchain_classic.chains.router.base import MultiRouteChain, RouterChain
from langchain_classic.chains.router.llm_router import LLMRouterChain
from langchain_classic.chains.router.multi_prompt import MultiPromptChain
from langchain_classic.chains.router.multi_retrieval_qa import MultiRetrievalQAChain
# Public API of the router package: every name imported above is re-exported here.
__all__ = [
"LLMRouterChain",
"MultiPromptChain",
"MultiRetrievalQAChain",
"MultiRouteChain",
"RouterChain",
]

View File

@@ -0,0 +1,147 @@
"""Base classes for chain routing."""
from __future__ import annotations
from abc import ABC
from collections.abc import Mapping
from typing import Any, NamedTuple
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from pydantic import ConfigDict
from typing_extensions import override
from langchain_classic.chains.base import Chain
class Route(NamedTuple):
    """Routing decision: which destination chain to run and what to feed it."""

    # Name of the chosen destination chain; ``None`` means no destination was picked.
    destination: str | None
    # Inputs to hand to whichever chain ends up being run.
    next_inputs: dict[str, Any]
class RouterChain(Chain, ABC):
    """Abstract chain whose output names a destination chain and carries its inputs."""

    @property
    @override
    def output_keys(self) -> list[str]:
        return ["destination", "next_inputs"]

    def route(self, inputs: dict[str, Any], callbacks: Callbacks = None) -> Route:
        """Run this chain synchronously and package its output as a ``Route``.

        Args:
            inputs: Inputs to the chain.
            callbacks: Callbacks to use for the chain run.

        Returns:
            A ``Route`` built from the ``"destination"`` and ``"next_inputs"``
            entries of the chain output.
        """
        outputs = self(inputs, callbacks=callbacks)
        return Route(
            destination=outputs["destination"],
            next_inputs=outputs["next_inputs"],
        )

    async def aroute(
        self,
        inputs: dict[str, Any],
        callbacks: Callbacks = None,
    ) -> Route:
        """Run this chain asynchronously and package its output as a ``Route``.

        Args:
            inputs: Inputs to the chain.
            callbacks: Callbacks to use for the chain run.

        Returns:
            A ``Route`` built from the ``"destination"`` and ``"next_inputs"``
            entries of the chain output.
        """
        outputs = await self.acall(inputs, callbacks=callbacks)
        return Route(
            destination=outputs["destination"],
            next_inputs=outputs["next_inputs"],
        )
class MultiRouteChain(Chain):
    """Dispatch an input to one of several candidate chains picked by a router chain."""

    router_chain: RouterChain
    """Chain that routes inputs to destination chains."""
    destination_chains: Mapping[str, Chain]
    """Chains that return final answer to inputs."""
    default_chain: Chain
    """Default chain to use when none of the destination chains are suitable."""
    silent_errors: bool = False
    """If `True`, use default_chain when an invalid destination name is provided."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> list[str]:
        """Input keys are exactly the router chain's input keys."""
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> list[str]:
        """Return an empty list; this base class declares no output keys of its own."""
        return []

    def _call(
        self,
        inputs: dict[str, Any],
        run_manager: CallbackManagerForChainRun | None = None,
    ) -> dict[str, Any]:
        """Route ``inputs`` and run the selected chain synchronously.

        Raises:
            ValueError: If the router names an unknown destination and
                ``silent_errors`` is false.
        """
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        child_callbacks = manager.get_child()

        route = self.router_chain.route(inputs, callbacks=child_callbacks)
        manager.on_text(
            f"{route.destination}: {route.next_inputs}",
            verbose=self.verbose,
        )

        destination = route.destination
        if destination and destination in self.destination_chains:
            target = self.destination_chains[destination]
        elif not destination or self.silent_errors:
            # No destination at all, or an unknown one that we were told to tolerate.
            target = self.default_chain
        else:
            msg = f"Received invalid destination chain name '{route.destination}'"
            raise ValueError(msg)
        return target(route.next_inputs, callbacks=child_callbacks)

    async def _acall(
        self,
        inputs: dict[str, Any],
        run_manager: AsyncCallbackManagerForChainRun | None = None,
    ) -> dict[str, Any]:
        """Route ``inputs`` and run the selected chain asynchronously.

        Raises:
            ValueError: If the router names an unknown destination and
                ``silent_errors`` is false.
        """
        manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        child_callbacks = manager.get_child()

        route = await self.router_chain.aroute(inputs, callbacks=child_callbacks)
        await manager.on_text(
            f"{route.destination}: {route.next_inputs}",
            verbose=self.verbose,
        )

        destination = route.destination
        if destination and destination in self.destination_chains:
            target = self.destination_chains[destination]
        elif not destination or self.silent_errors:
            # No destination at all, or an unknown one that we were told to tolerate.
            target = self.default_chain
        else:
            msg = f"Received invalid destination chain name '{route.destination}'"
            raise ValueError(msg)
        return await target.acall(route.next_inputs, callbacks=child_callbacks)

View File

@@ -0,0 +1,93 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from pydantic import ConfigDict
from typing_extensions import override
from langchain_classic.chains.router.base import RouterChain
class EmbeddingRouterChain(RouterChain):
    """Chain that uses embeddings to route between options.

    The values of ``routing_keys`` are joined into one search string; the single
    nearest document in ``vectorstore`` decides the destination via its
    ``"name"`` metadata entry.
    """

    # Vector store that is searched for the closest destination description.
    vectorstore: VectorStore
    # Input keys whose values are joined (with ", ") to form the search text.
    routing_keys: list[str] = ["query"]

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> list[str]:
        """Return the routing keys, which are the inputs this chain reads."""
        return self.routing_keys

    def _search_text(self, inputs: dict[str, Any]) -> str:
        """Join the routing-key values of ``inputs`` into one search string."""
        return ", ".join([inputs[k] for k in self.routing_keys])

    @staticmethod
    def _to_output(
        inputs: dict[str, Any],
        results: Sequence[Document],
    ) -> dict[str, Any]:
        """Turn the search results into the router output dictionary.

        Raises:
            IndexError: If the vector store returned no documents.
        """
        if not results:
            # Same exception type as indexing an empty list, but with a message
            # that tells the caller what actually went wrong.
            msg = (
                "The vector store returned no documents, so no destination could "
                "be selected."
            )
            raise IndexError(msg)
        return {"next_inputs": inputs, "destination": results[0].metadata["name"]}

    @staticmethod
    def _build_documents(
        names_and_descriptions: Sequence[tuple[str, Sequence[str]]],
    ) -> list[Document]:
        """Create one document per description, tagged with its destination name."""
        return [
            Document(page_content=description, metadata={"name": name})
            for name, descriptions in names_and_descriptions
            for description in descriptions
        ]

    @override
    def _call(
        self,
        inputs: dict[str, Any],
        run_manager: CallbackManagerForChainRun | None = None,
    ) -> dict[str, Any]:
        """Pick the destination whose description is closest to the inputs."""
        results = self.vectorstore.similarity_search(self._search_text(inputs), k=1)
        return self._to_output(inputs, results)

    @override
    async def _acall(
        self,
        inputs: dict[str, Any],
        run_manager: AsyncCallbackManagerForChainRun | None = None,
    ) -> dict[str, Any]:
        """Async variant of ``_call``."""
        results = await self.vectorstore.asimilarity_search(
            self._search_text(inputs),
            k=1,
        )
        return self._to_output(inputs, results)

    @classmethod
    def from_names_and_descriptions(
        cls,
        names_and_descriptions: Sequence[tuple[str, Sequence[str]]],
        vectorstore_cls: type[VectorStore],
        embeddings: Embeddings,
        **kwargs: Any,
    ) -> EmbeddingRouterChain:
        """Build the chain from ``(name, descriptions)`` pairs.

        Args:
            names_and_descriptions: Destination names, each with the descriptions
                that should route to it.
            vectorstore_cls: Vector store class used to index the descriptions.
            embeddings: Embeddings passed to ``vectorstore_cls.from_documents``.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A router chain backed by the newly created vector store.
        """
        documents = cls._build_documents(names_and_descriptions)
        vectorstore = vectorstore_cls.from_documents(documents, embeddings)
        return cls(vectorstore=vectorstore, **kwargs)

    @classmethod
    async def afrom_names_and_descriptions(
        cls,
        names_and_descriptions: Sequence[tuple[str, Sequence[str]]],
        vectorstore_cls: type[VectorStore],
        embeddings: Embeddings,
        **kwargs: Any,
    ) -> EmbeddingRouterChain:
        """Async variant of ``from_names_and_descriptions``.

        Args:
            names_and_descriptions: Destination names, each with the descriptions
                that should route to it.
            vectorstore_cls: Vector store class used to index the descriptions.
            embeddings: Embeddings passed to ``vectorstore_cls.afrom_documents``.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A router chain backed by the newly created vector store.
        """
        documents = cls._build_documents(names_and_descriptions)
        vectorstore = await vectorstore_cls.afrom_documents(documents, embeddings)
        return cls(vectorstore=vectorstore, **kwargs)

View File

@@ -0,0 +1,196 @@
"""Base classes for LLM-powered router chains."""
from __future__ import annotations
from typing import Any, cast
from langchain_core._api import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.utils.json import parse_and_check_json_markdown
from pydantic import model_validator
from typing_extensions import Self, override
from langchain_classic.chains import LLMChain
from langchain_classic.chains.router.base import RouterChain
@deprecated(
    since="0.2.12",
    removal="1.0",
    message=(
        "Use RunnableLambda to select from multiple prompt templates. See example "
        "in API reference: "
        "https://api.python.langchain.com/en/latest/chains/langchain.chains.router.llm_router.LLMRouterChain.html"
    ),
)
class LLMRouterChain(RouterChain):
    """A router chain that uses an LLM chain to perform routing.

    This class is deprecated. See below for a replacement, which offers several
    benefits, including streaming and batch support.

    Below is an example implementation:

    ```python
    from operator import itemgetter
    from typing import Literal
    from typing_extensions import TypedDict

    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.runnables import RunnableLambda, RunnablePassthrough
    from langchain_openai import ChatOpenAI

    model = ChatOpenAI(model="gpt-4o-mini")

    prompt_1 = ChatPromptTemplate.from_messages(
        [
            ("system", "You are an expert on animals."),
            ("human", "{query}"),
        ]
    )
    prompt_2 = ChatPromptTemplate.from_messages(
        [
            ("system", "You are an expert on vegetables."),
            ("human", "{query}"),
        ]
    )

    chain_1 = prompt_1 | model | StrOutputParser()
    chain_2 = prompt_2 | model | StrOutputParser()

    route_system = (
        "Route the user's query to either the animal "
        "or vegetable expert."
    )
    route_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", route_system),
            ("human", "{query}"),
        ]
    )


    class RouteQuery(TypedDict):
        \"\"\"Route query to destination.\"\"\"

        destination: Literal["animal", "vegetable"]


    route_chain = (
        route_prompt
        | model.with_structured_output(RouteQuery)
        | itemgetter("destination")
    )

    chain = {
        "destination": route_chain,  # "animal" or "vegetable"
        "query": lambda x: x["query"],  # pass through input query
    } | RunnableLambda(
        # if animal, chain_1. otherwise, chain_2.
        lambda x: chain_1 if x["destination"] == "animal" else chain_2,
    )

    chain.invoke({"query": "what color are carrots"})
    ```
    """

    llm_chain: LLMChain
    """LLM chain used to perform routing"""

    @model_validator(mode="after")
    def _validate_prompt(self) -> Self:
        """Reject an ``llm_chain`` whose prompt has no output parser.

        Raises:
            ValueError: If ``llm_chain.prompt.output_parser`` is ``None``.
        """
        prompt = self.llm_chain.prompt
        if prompt.output_parser is None:
            msg = (
                "LLMRouterChain requires base llm_chain prompt to have an output"
                " parser that converts LLM text output to a dictionary with keys"
                " 'destination' and 'next_inputs'. Received a prompt with no output"
                " parser."
            )
            raise ValueError(msg)
        return self

    @property
    def input_keys(self) -> list[str]:
        """Will be whatever keys the LLM chain prompt expects."""
        return self.llm_chain.input_keys

    def _validate_outputs(self, outputs: dict[str, Any]) -> None:
        """Run the base validation and require ``next_inputs`` to be a dict.

        Raises:
            ValueError: If ``outputs["next_inputs"]`` is not a dictionary.
        """
        super()._validate_outputs(outputs)
        next_inputs = outputs["next_inputs"]
        if not isinstance(next_inputs, dict):
            msg = (
                "Expected 'next_inputs' to be a dict, got "
                f"{type(next_inputs).__name__}."
            )
            raise ValueError(msg)  # noqa: TRY004

    def _call(
        self,
        inputs: dict[str, Any],
        run_manager: CallbackManagerForChainRun | None = None,
    ) -> dict[str, Any]:
        """Ask the LLM chain for a routing decision and parse it."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()

        prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
        return cast(
            "dict[str, Any]",
            self.llm_chain.prompt.output_parser.parse(prediction),
        )

    async def _acall(
        self,
        inputs: dict[str, Any],
        run_manager: AsyncCallbackManagerForChainRun | None = None,
    ) -> dict[str, Any]:
        """Async variant of ``_call``."""
        # The fallback must be the async no-op manager so that child callbacks
        # handed to the async LLM call are of the async flavour.
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()

        # Mirror the sync path: predict, then apply the prompt's output parser
        # (its presence is enforced by ``_validate_prompt``).
        prediction = await self.llm_chain.apredict(callbacks=callbacks, **inputs)
        return cast(
            "dict[str, Any]",
            self.llm_chain.prompt.output_parser.parse(prediction),
        )

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: BasePromptTemplate,
        **kwargs: Any,
    ) -> LLMRouterChain:
        """Build the router from a language model and a routing prompt.

        Args:
            llm: Language model used for routing.
            prompt: Routing prompt; it must carry an output parser.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A router chain wrapping a new ``LLMChain`` of ``llm`` and ``prompt``.
        """
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
class RouterOutputParser(BaseOutputParser[dict[str, str]]):
    """Parser for output of router chain in the multi-prompt chain."""

    # Destination name (compared case-insensitively) that is mapped to ``None``.
    default_destination: str = "DEFAULT"
    # Type the raw "next_inputs" value must have.
    next_inputs_type: type = str
    # Key under which the raw "next_inputs" value is wrapped in the output.
    next_inputs_inner_key: str = "input"

    @override
    def parse(self, text: str) -> dict[str, Any]:
        """Parse the router's markdown-JSON reply into a routing dictionary.

        Args:
            text: Raw model output expected to contain a JSON object with the
                keys ``"destination"`` and ``"next_inputs"``.

        Returns:
            The parsed dictionary, with ``"next_inputs"`` wrapped under
            ``next_inputs_inner_key`` and ``"destination"`` stripped, or ``None``
            when it matches ``default_destination``.

        Raises:
            OutputParserException: If anything goes wrong while parsing or
                validating ``text``.
        """
        try:
            parsed = parse_and_check_json_markdown(
                text,
                ["destination", "next_inputs"],
            )

            raw_destination = parsed["destination"]
            if not isinstance(raw_destination, str):
                msg = "Expected 'destination' to be a string."
                raise TypeError(msg)

            raw_next_inputs = parsed["next_inputs"]
            if not isinstance(raw_next_inputs, self.next_inputs_type):
                msg = f"Expected 'next_inputs' to be {self.next_inputs_type}."
                raise TypeError(msg)

            parsed["next_inputs"] = {self.next_inputs_inner_key: raw_next_inputs}

            cleaned = raw_destination.strip()
            is_default = cleaned.lower() == self.default_destination.lower()
            parsed["destination"] = None if is_default else cleaned
        except Exception as e:
            msg = f"Parsing text\n{text}\n raised following error:\n{e}"
            raise OutputParserException(msg) from e
        return parsed

View File

@@ -0,0 +1,190 @@
"""Use a single chain to route an input to one of multiple llm chains."""
from __future__ import annotations
from typing import Any
from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from typing_extensions import override
from langchain_classic.chains import ConversationChain
from langchain_classic.chains.base import Chain
from langchain_classic.chains.llm import LLMChain
from langchain_classic.chains.router.base import MultiRouteChain
from langchain_classic.chains.router.llm_router import (
LLMRouterChain,
RouterOutputParser,
)
from langchain_classic.chains.router.multi_prompt_prompt import (
MULTI_PROMPT_ROUTER_TEMPLATE,
)
@deprecated(
    since="0.2.12",
    removal="1.0",
    message=(
        "Please see migration guide here for recommended implementation: "
        "https://python.langchain.com/docs/versions/migrating_chains/multi_prompt_chain/"
    ),
)
class MultiPromptChain(MultiRouteChain):
    """A multi-route chain that uses an LLM router chain to choose amongst prompts.

    This class is deprecated. See below for a replacement, which offers several
    benefits, including streaming and batch support.

    Below is an example implementation:

    ```python
    from operator import itemgetter
    from typing import Literal

    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.runnables import RunnableConfig
    from langchain_openai import ChatOpenAI
    from langgraph.graph import END, START, StateGraph
    from typing_extensions import TypedDict

    model = ChatOpenAI(model="gpt-4o-mini")

    # Define the prompts we will route to
    prompt_1 = ChatPromptTemplate.from_messages(
        [
            ("system", "You are an expert on animals."),
            ("human", "{input}"),
        ]
    )
    prompt_2 = ChatPromptTemplate.from_messages(
        [
            ("system", "You are an expert on vegetables."),
            ("human", "{input}"),
        ]
    )

    # Construct the chains we will route to. These format the input query
    # into the respective prompt, run it through a chat model, and cast
    # the result to a string.
    chain_1 = prompt_1 | model | StrOutputParser()
    chain_2 = prompt_2 | model | StrOutputParser()


    # Next: define the chain that selects which branch to route to.
    # Here we will take advantage of tool-calling features to force
    # the output to select one of two desired branches.
    route_system = (
        "Route the user's query to either the animal "
        "or vegetable expert."
    )
    route_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", route_system),
            ("human", "{input}"),
        ]
    )


    # Define schema for output:
    class RouteQuery(TypedDict):
        \"\"\"Route query to destination expert.\"\"\"

        destination: Literal["animal", "vegetable"]


    route_chain = route_prompt | model.with_structured_output(RouteQuery)


    # For LangGraph, we will define the state of the graph to hold the query,
    # destination, and final answer.
    class State(TypedDict):
        query: str
        destination: RouteQuery
        answer: str


    # We define functions for each node, including routing the query:
    async def route_query(state: State, config: RunnableConfig):
        destination = await route_chain.ainvoke(state["query"], config)
        return {"destination": destination}


    # And one node for each prompt
    async def prompt_1(state: State, config: RunnableConfig):
        return {"answer": await chain_1.ainvoke(state["query"], config)}


    async def prompt_2(state: State, config: RunnableConfig):
        return {"answer": await chain_2.ainvoke(state["query"], config)}


    # We then define logic that selects the prompt based on the classification.
    # ``state["destination"]`` is the whole RouteQuery dict, so read its field.
    def select_node(state: State) -> Literal["prompt_1", "prompt_2"]:
        if state["destination"]["destination"] == "animal":
            return "prompt_1"
        else:
            return "prompt_2"


    # Finally, assemble the multi-prompt chain. This is a sequence of two steps:
    # 1) Select "animal" or "vegetable" via the route_chain, and collect the
    # answer alongside the input query.
    # 2) Route the input query to chain_1 or chain_2, based on the
    # selection.
    graph = StateGraph(State)
    graph.add_node("route_query", route_query)
    graph.add_node("prompt_1", prompt_1)
    graph.add_node("prompt_2", prompt_2)

    graph.add_edge(START, "route_query")
    graph.add_conditional_edges("route_query", select_node)
    graph.add_edge("prompt_1", END)
    graph.add_edge("prompt_2", END)
    app = graph.compile()

    result = await app.ainvoke({"query": "what color are carrots"})
    print(result["destination"])
    print(result["answer"])
    ```
    """

    @property
    @override
    def output_keys(self) -> list[str]:
        return ["text"]

    @classmethod
    def from_prompts(
        cls,
        llm: BaseLanguageModel,
        prompt_infos: list[dict[str, str]],
        default_chain: Chain | None = None,
        **kwargs: Any,
    ) -> MultiPromptChain:
        """Convenience constructor for instantiating from destination prompts.

        Args:
            llm: Language model used for the router and every destination chain.
            prompt_infos: One dictionary per destination with the keys ``"name"``,
                ``"description"`` and ``"prompt_template"``.
            default_chain: Chain to use when no destination fits. When omitted, a
                ``ConversationChain`` over ``llm`` with output key ``"text"`` is
                created.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A multi-prompt chain wired to the router and destination chains.
        """
        # The router prompt lists every destination as "name: description".
        destinations_str = "\n".join(
            f"{p['name']}: {p['description']}" for p in prompt_infos
        )
        router_prompt = PromptTemplate(
            template=MULTI_PROMPT_ROUTER_TEMPLATE.format(
                destinations=destinations_str,
            ),
            input_variables=["input"],
            output_parser=RouterOutputParser(),
        )
        router_chain = LLMRouterChain.from_llm(llm, router_prompt)

        destination_chains = {
            p_info["name"]: LLMChain(
                llm=llm,
                prompt=PromptTemplate(
                    template=p_info["prompt_template"],
                    input_variables=["input"],
                ),
            )
            for p_info in prompt_infos
        }

        _default_chain = default_chain or ConversationChain(llm=llm, output_key="text")
        return cls(
            router_chain=router_chain,
            destination_chains=destination_chains,
            default_chain=_default_chain,
            **kwargs,
        )

View File

@@ -0,0 +1,32 @@
"""Prompt for the router chain in the multi-prompt chain."""
# Router prompt template that goes through two formatting passes:
#   1. ``str.format(destinations=...)`` fills ``{destinations}``; in that pass
#      ``{{input}}`` collapses to ``{input}`` and ``{{{{`` / ``}}}}`` to ``{{`` / ``}}``.
#   2. The result is later used as a prompt template, where ``{input}`` is the
#      remaining variable and ``{{`` / ``}}`` render as literal braces.
# A trailing backslash inside the literal joins the next line without a newline.
MULTI_PROMPT_ROUTER_TEMPLATE = """\
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
what the prompt is best suited for. You may also revise the original input if you \
think that revising it will ultimately lead to a better response from the language \
model.
<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like:
```json
{{{{
"destination": string \\ name of the prompt to use or "DEFAULT"
"next_inputs": string \\ a potentially modified version of the original input
}}}}
```
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any \
modifications are needed.
<< CANDIDATE PROMPTS >>
{destinations}
<< INPUT >>
{{input}}
<< OUTPUT (must include ```json at the start of the response) >>
<< OUTPUT (must end with ```) >>
"""

View File

@@ -0,0 +1,30 @@
"""Prompt for the router chain in the multi-retrieval qa chain."""
# Router prompt template that goes through two formatting passes:
#   1. ``str.format(destinations=...)`` fills ``{destinations}``; in that pass
#      ``{{input}}`` collapses to ``{input}`` and ``{{{{`` / ``}}}}`` to ``{{`` / ``}}``.
#   2. The result is later used as a prompt template, where ``{input}`` is the
#      remaining variable and ``{{`` / ``}}`` render as literal braces.
# A trailing backslash inside the literal joins the next line without a newline.
MULTI_RETRIEVAL_ROUTER_TEMPLATE = """\
Given a query to a question answering system select the system best suited \
for the input. You will be given the names of the available systems and a description \
of what questions the system is best suited for. You may also revise the original \
input if you think that revising it will ultimately lead to a better response.
<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like:
```json
{{{{
"destination": string \\ name of the question answering system to use or "DEFAULT"
"next_inputs": string \\ a potentially modified version of the original input
}}}}
```
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any \
modifications are needed.
<< CANDIDATE PROMPTS >>
{destinations}
<< INPUT >>
{{input}}
<< OUTPUT >>
"""

View File

@@ -0,0 +1,134 @@
"""Use a single chain to route an input to one of multiple retrieval qa chains."""
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from typing_extensions import override
from langchain_classic.chains import ConversationChain
from langchain_classic.chains.base import Chain
from langchain_classic.chains.conversation.prompt import DEFAULT_TEMPLATE
from langchain_classic.chains.retrieval_qa.base import BaseRetrievalQA, RetrievalQA
from langchain_classic.chains.router.base import MultiRouteChain
from langchain_classic.chains.router.llm_router import (
LLMRouterChain,
RouterOutputParser,
)
from langchain_classic.chains.router.multi_retrieval_prompt import (
MULTI_RETRIEVAL_ROUTER_TEMPLATE,
)
class MultiRetrievalQAChain(MultiRouteChain):
    """Multi Retrieval QA Chain.

    A multi-route chain that uses an LLM router chain to choose amongst retrieval
    qa chains.
    """

    router_chain: LLMRouterChain
    """Chain for deciding a destination chain and the input to it."""
    destination_chains: Mapping[str, BaseRetrievalQA]
    """Map of name to candidate chains that inputs can be routed to."""
    default_chain: Chain
    """Default chain to use when router doesn't map input to one of the destinations."""

    @property
    @override
    def output_keys(self) -> list[str]:
        return ["result"]

    @classmethod
    def from_retrievers(
        cls,
        llm: BaseLanguageModel,
        retriever_infos: list[dict[str, Any]],
        default_retriever: BaseRetriever | None = None,
        default_prompt: PromptTemplate | None = None,
        default_chain: Chain | None = None,
        *,
        default_chain_llm: BaseLanguageModel | None = None,
        **kwargs: Any,
    ) -> MultiRetrievalQAChain:
        """Create a multi retrieval qa chain from an LLM and a default chain.

        The default chain is chosen in this order: ``default_chain`` if given,
        else a retrieval QA chain over ``default_retriever`` if given, else a
        conversation chain over ``default_chain_llm``.

        Args:
            llm: The language model to use for routing and for the retrieval
                chains.
            retriever_infos: Dictionaries containing retriever information. Each
                needs the keys ``"name"``, ``"description"`` and ``"retriever"``
                and may carry an optional ``"prompt"``.
            default_retriever: Optional default retriever to use if no default chain
                is provided.
            default_prompt: Optional prompt template to use for the default retriever.
            default_chain: Optional default chain to use when router doesn't map input
                to one of the destinations.
            default_chain_llm: Language model for the fallback conversation chain.
                Required when neither ``default_chain`` nor ``default_retriever``
                is provided.
            **kwargs: Additional keyword arguments to pass to the chain.

        Returns:
            An instance of the multi retrieval qa chain.

        Raises:
            ValueError: If ``default_prompt`` is given without
                ``default_retriever``.
            NotImplementedError: If no default chain can be built because
                ``default_chain``, ``default_retriever`` and ``default_chain_llm``
                are all missing.
        """
        if default_prompt and not default_retriever:
            msg = (
                "`default_retriever` must be specified if `default_prompt` is "
                "provided. Received only `default_prompt`."
            )
            raise ValueError(msg)

        # The router prompt lists every destination as "name: description".
        destinations_str = "\n".join(
            f"{r['name']}: {r['description']}" for r in retriever_infos
        )
        router_prompt = PromptTemplate(
            template=MULTI_RETRIEVAL_ROUTER_TEMPLATE.format(
                destinations=destinations_str,
            ),
            input_variables=["input"],
            # Retrieval chains are fed their question under the "query" key.
            output_parser=RouterOutputParser(next_inputs_inner_key="query"),
        )
        router_chain = LLMRouterChain.from_llm(llm, router_prompt)

        destination_chains = {
            r_info["name"]: RetrievalQA.from_llm(
                llm,
                prompt=r_info.get("prompt"),
                retriever=r_info["retriever"],
            )
            for r_info in retriever_infos
        }

        if default_chain:
            _default_chain = default_chain
        elif default_retriever:
            _default_chain = RetrievalQA.from_llm(
                llm,
                prompt=default_prompt,
                retriever=default_retriever,
            )
        else:
            # Check for the LLM first so nothing is built only to be discarded.
            if default_chain_llm is None:
                msg = (
                    "default_chain_llm must be provided if neither default_chain "
                    "nor default_retriever is specified. This API has been changed "
                    "to avoid instantiating default LLMs on behalf of users. "
                    "You can provide a conversation LLM like so:\n"
                    "from langchain_openai import ChatOpenAI\n"
                    "model = ChatOpenAI()"
                )
                raise NotImplementedError(msg)
            # Reuse the conversation template, renaming its input variable so the
            # fallback chain accepts the same "query" key as the retrieval chains.
            prompt_template = DEFAULT_TEMPLATE.replace("input", "query")
            prompt = PromptTemplate(
                template=prompt_template,
                input_variables=["history", "query"],
            )
            _default_chain = ConversationChain(
                llm=default_chain_llm,
                prompt=prompt,
                input_key="query",
                output_key="result",
            )

        return cls(
            router_chain=router_chain,
            destination_chains=destination_chains,
            default_chain=_default_chain,
            **kwargs,
        )