initial commit
This commit is contained in:
116
venv/Lib/site-packages/langchain_classic/chains/mapreduce.py
Normal file
116
venv/Lib/site-packages/langchain_classic/chains/mapreduce.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Map-reduce chain.
|
||||
|
||||
Splits up a document, sends the smaller parts to the LLM with one prompt,
|
||||
then combines the results with another one.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun, Callbacks
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_text_splitters import TextSplitter
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from langchain_classic.chains import ReduceDocumentsChain
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain_classic.chains.combine_documents.map_reduce import (
|
||||
MapReduceDocumentsChain,
|
||||
)
|
||||
from langchain_classic.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.2.13",
|
||||
removal="1.0",
|
||||
message=(
|
||||
"Refer to migration guide here for a recommended implementation using "
|
||||
"LangGraph: https://docs.langchain.com/oss/python/langgraph/graph-api#map-reduce-and-the-send-api"
|
||||
"."
|
||||
),
|
||||
)
|
||||
class MapReduceChain(Chain):
|
||||
"""Map-reduce chain."""
|
||||
|
||||
combine_documents_chain: BaseCombineDocumentsChain
|
||||
"""Chain to use to combine documents."""
|
||||
text_splitter: TextSplitter
|
||||
"""Text splitter to use."""
|
||||
input_key: str = "input_text"
|
||||
output_key: str = "output_text"
|
||||
|
||||
@classmethod
|
||||
def from_params(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate,
|
||||
text_splitter: TextSplitter,
|
||||
callbacks: Callbacks = None,
|
||||
combine_chain_kwargs: Mapping[str, Any] | None = None,
|
||||
reduce_chain_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> MapReduceChain:
|
||||
"""Construct a map-reduce chain that uses the chain for map and reduce."""
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks)
|
||||
stuff_chain = StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
callbacks=callbacks,
|
||||
**(reduce_chain_kwargs or {}),
|
||||
)
|
||||
reduce_documents_chain = ReduceDocumentsChain(
|
||||
combine_documents_chain=stuff_chain,
|
||||
)
|
||||
combine_documents_chain = MapReduceDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
reduce_documents_chain=reduce_documents_chain,
|
||||
callbacks=callbacks,
|
||||
**(combine_chain_kwargs or {}),
|
||||
)
|
||||
return cls(
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
text_splitter=text_splitter,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> list[str]:
|
||||
"""Expect input key."""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> list[str]:
|
||||
"""Return output key."""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, str],
|
||||
run_manager: CallbackManagerForChainRun | None = None,
|
||||
) -> dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
# Split the larger text into smaller chunks.
|
||||
doc_text = inputs.pop(self.input_key)
|
||||
texts = self.text_splitter.split_text(doc_text)
|
||||
docs = [Document(page_content=text) for text in texts]
|
||||
_inputs: dict[str, Any] = {
|
||||
**inputs,
|
||||
self.combine_documents_chain.input_key: docs,
|
||||
}
|
||||
outputs = self.combine_documents_chain.run(
|
||||
_inputs,
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return {self.output_key: outputs}
|
||||
Reference in New Issue
Block a user