initial commit
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Adapted from https://github.com/jzbjyb/FLARE."""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
311
venv/Lib/site-packages/langchain_classic/chains/flare/base.py
Normal file
311
venv/Lib/site-packages/langchain_classic/chains/flare/base.py
Normal file
@@ -0,0 +1,311 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import Runnable
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.flare.prompts import (
|
||||
PROMPT,
|
||||
QUESTION_GENERATOR_PROMPT,
|
||||
FinishedOutputParser,
|
||||
)
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_tokens_and_log_probs(response: AIMessage) -> tuple[list[str], list[float]]:
|
||||
"""Extract tokens and log probabilities from chat model response."""
|
||||
tokens = []
|
||||
log_probs = []
|
||||
for token in response.response_metadata["logprobs"]["content"]:
|
||||
tokens.append(token["token"])
|
||||
log_probs.append(token["logprob"])
|
||||
return tokens, log_probs
|
||||
|
||||
|
||||
class QuestionGeneratorChain(LLMChain):
|
||||
"""Chain that generates questions from uncertain spans."""
|
||||
|
||||
prompt: BasePromptTemplate = QUESTION_GENERATOR_PROMPT
|
||||
"""Prompt template for the chain."""
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def input_keys(self) -> list[str]:
|
||||
"""Input keys for the chain."""
|
||||
return ["user_input", "context", "response"]
|
||||
|
||||
|
||||
def _low_confidence_spans(
|
||||
tokens: Sequence[str],
|
||||
log_probs: Sequence[float],
|
||||
min_prob: float,
|
||||
min_token_gap: int,
|
||||
num_pad_tokens: int,
|
||||
) -> list[str]:
|
||||
try:
|
||||
import numpy as np
|
||||
|
||||
_low_idx = np.where(np.exp(log_probs) < min_prob)[0]
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"NumPy not found in the current Python environment. FlareChain will use a "
|
||||
"pure Python implementation for internal calculations, which may "
|
||||
"significantly impact performance, especially for large datasets. For "
|
||||
"optimal speed and efficiency, consider installing NumPy: pip install "
|
||||
"numpy",
|
||||
)
|
||||
import math
|
||||
|
||||
_low_idx = [ # type: ignore[assignment]
|
||||
idx
|
||||
for idx, log_prob in enumerate(log_probs)
|
||||
if math.exp(log_prob) < min_prob
|
||||
]
|
||||
low_idx = [i for i in _low_idx if re.search(r"\w", tokens[i])]
|
||||
if len(low_idx) == 0:
|
||||
return []
|
||||
spans = [[low_idx[0], low_idx[0] + num_pad_tokens + 1]]
|
||||
for i, idx in enumerate(low_idx[1:]):
|
||||
end = idx + num_pad_tokens + 1
|
||||
if idx - low_idx[i] < min_token_gap:
|
||||
spans[-1][1] = end
|
||||
else:
|
||||
spans.append([idx, end])
|
||||
return ["".join(tokens[start:end]) for start, end in spans]
|
||||
|
||||
|
||||
class FlareChain(Chain):
|
||||
"""Flare chain.
|
||||
|
||||
Chain that combines a retriever, a question generator,
|
||||
and a response generator.
|
||||
|
||||
See [Active Retrieval Augmented Generation](https://arxiv.org/abs/2305.06983) paper.
|
||||
"""
|
||||
|
||||
question_generator_chain: Runnable
|
||||
"""Chain that generates questions from uncertain spans."""
|
||||
response_chain: Runnable
|
||||
"""Chain that generates responses from user input and context."""
|
||||
output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
|
||||
"""Parser that determines whether the chain is finished."""
|
||||
retriever: BaseRetriever
|
||||
"""Retriever that retrieves relevant documents from a user input."""
|
||||
min_prob: float = 0.2
|
||||
"""Minimum probability for a token to be considered low confidence."""
|
||||
min_token_gap: int = 5
|
||||
"""Minimum number of tokens between two low confidence spans."""
|
||||
num_pad_tokens: int = 2
|
||||
"""Number of tokens to pad around a low confidence span."""
|
||||
max_iter: int = 10
|
||||
"""Maximum number of iterations."""
|
||||
start_with_retrieval: bool = True
|
||||
"""Whether to start with retrieval."""
|
||||
|
||||
@property
|
||||
def input_keys(self) -> list[str]:
|
||||
"""Input keys for the chain."""
|
||||
return ["user_input"]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> list[str]:
|
||||
"""Output keys for the chain."""
|
||||
return ["response"]
|
||||
|
||||
def _do_generation(
|
||||
self,
|
||||
questions: list[str],
|
||||
user_input: str,
|
||||
response: str,
|
||||
_run_manager: CallbackManagerForChainRun,
|
||||
) -> tuple[str, bool]:
|
||||
callbacks = _run_manager.get_child()
|
||||
docs = []
|
||||
for question in questions:
|
||||
docs.extend(self.retriever.invoke(question))
|
||||
context = "\n\n".join(d.page_content for d in docs)
|
||||
result = self.response_chain.invoke(
|
||||
{
|
||||
"user_input": user_input,
|
||||
"context": context,
|
||||
"response": response,
|
||||
},
|
||||
{"callbacks": callbacks},
|
||||
)
|
||||
if isinstance(result, AIMessage):
|
||||
result = result.content
|
||||
marginal, finished = self.output_parser.parse(result)
|
||||
return marginal, finished
|
||||
|
||||
def _do_retrieval(
|
||||
self,
|
||||
low_confidence_spans: list[str],
|
||||
_run_manager: CallbackManagerForChainRun,
|
||||
user_input: str,
|
||||
response: str,
|
||||
initial_response: str,
|
||||
) -> tuple[str, bool]:
|
||||
question_gen_inputs = [
|
||||
{
|
||||
"user_input": user_input,
|
||||
"current_response": initial_response,
|
||||
"uncertain_span": span,
|
||||
}
|
||||
for span in low_confidence_spans
|
||||
]
|
||||
callbacks = _run_manager.get_child()
|
||||
if isinstance(self.question_generator_chain, LLMChain):
|
||||
question_gen_outputs = self.question_generator_chain.apply(
|
||||
question_gen_inputs,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
questions = [
|
||||
output[self.question_generator_chain.output_keys[0]]
|
||||
for output in question_gen_outputs
|
||||
]
|
||||
else:
|
||||
questions = self.question_generator_chain.batch(
|
||||
question_gen_inputs,
|
||||
config={"callbacks": callbacks},
|
||||
)
|
||||
_run_manager.on_text(
|
||||
f"Generated Questions: {questions}",
|
||||
color="yellow",
|
||||
end="\n",
|
||||
)
|
||||
return self._do_generation(questions, user_input, response, _run_manager)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
run_manager: CallbackManagerForChainRun | None = None,
|
||||
) -> dict[str, Any]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
|
||||
user_input = inputs[self.input_keys[0]]
|
||||
|
||||
response = ""
|
||||
|
||||
for _i in range(self.max_iter):
|
||||
_run_manager.on_text(
|
||||
f"Current Response: {response}",
|
||||
color="blue",
|
||||
end="\n",
|
||||
)
|
||||
_input = {"user_input": user_input, "context": "", "response": response}
|
||||
tokens, log_probs = _extract_tokens_and_log_probs(
|
||||
self.response_chain.invoke(
|
||||
_input,
|
||||
{"callbacks": _run_manager.get_child()},
|
||||
),
|
||||
)
|
||||
low_confidence_spans = _low_confidence_spans(
|
||||
tokens,
|
||||
log_probs,
|
||||
self.min_prob,
|
||||
self.min_token_gap,
|
||||
self.num_pad_tokens,
|
||||
)
|
||||
initial_response = response.strip() + " " + "".join(tokens)
|
||||
if not low_confidence_spans:
|
||||
response = initial_response
|
||||
final_response, finished = self.output_parser.parse(response)
|
||||
if finished:
|
||||
return {self.output_keys[0]: final_response}
|
||||
continue
|
||||
|
||||
marginal, finished = self._do_retrieval(
|
||||
low_confidence_spans,
|
||||
_run_manager,
|
||||
user_input,
|
||||
response,
|
||||
initial_response,
|
||||
)
|
||||
response = response.strip() + " " + marginal
|
||||
if finished:
|
||||
break
|
||||
return {self.output_keys[0]: response}
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel | None,
|
||||
max_generation_len: int = 32,
|
||||
**kwargs: Any,
|
||||
) -> FlareChain:
|
||||
"""Creates a FlareChain from a language model.
|
||||
|
||||
Args:
|
||||
llm: Language model to use.
|
||||
max_generation_len: Maximum length of the generated response.
|
||||
kwargs: Additional arguments to pass to the constructor.
|
||||
|
||||
Returns:
|
||||
FlareChain class with the given language model.
|
||||
"""
|
||||
try:
|
||||
from langchain_openai import ChatOpenAI
|
||||
except ImportError as e:
|
||||
msg = (
|
||||
"OpenAI is required for FlareChain. "
|
||||
"Please install langchain-openai."
|
||||
"pip install langchain-openai"
|
||||
)
|
||||
raise ImportError(msg) from e
|
||||
# Preserve supplied llm instead of always creating a new ChatOpenAI.
|
||||
# Enforce ChatOpenAI requirement (token logprobs needed for FLARE).
|
||||
if llm is None:
|
||||
llm = ChatOpenAI(
|
||||
max_completion_tokens=max_generation_len,
|
||||
logprobs=True,
|
||||
temperature=0,
|
||||
)
|
||||
else:
|
||||
if not isinstance(llm, ChatOpenAI):
|
||||
msg = (
|
||||
f"FlareChain.from_llm requires ChatOpenAI; got "
|
||||
f"{type(llm).__name__}."
|
||||
)
|
||||
raise TypeError(msg)
|
||||
if not getattr(llm, "logprobs", False): # attribute presence may vary
|
||||
msg = (
|
||||
"Provided ChatOpenAI instance must be constructed with "
|
||||
"logprobs=True for FlareChain."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
current_max = getattr(llm, "max_completion_tokens", None)
|
||||
if current_max is not None and current_max != max_generation_len:
|
||||
logger.debug(
|
||||
"FlareChain.from_llm: supplied llm max_completion_tokens=%s "
|
||||
"differs from requested max_generation_len=%s; "
|
||||
"leaving model unchanged.",
|
||||
current_max,
|
||||
max_generation_len,
|
||||
)
|
||||
response_chain = PROMPT | llm
|
||||
question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
|
||||
return cls(
|
||||
question_generator_chain=question_gen_chain,
|
||||
response_chain=response_chain,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -0,0 +1,46 @@
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class FinishedOutputParser(BaseOutputParser[tuple[str, bool]]):
|
||||
"""Output parser that checks if the output is finished."""
|
||||
|
||||
finished_value: str = "FINISHED"
|
||||
"""Value that indicates the output is finished."""
|
||||
|
||||
@override
|
||||
def parse(self, text: str) -> tuple[str, bool]:
|
||||
cleaned = text.strip()
|
||||
finished = self.finished_value in cleaned
|
||||
return cleaned.replace(self.finished_value, ""), finished
|
||||
|
||||
|
||||
PROMPT_TEMPLATE = """\
|
||||
Respond to the user message using any relevant context. \
|
||||
If context is provided, you should ground your answer in that context. \
|
||||
Once you're done responding return FINISHED.
|
||||
|
||||
>>> CONTEXT: {context}
|
||||
>>> USER INPUT: {user_input}
|
||||
>>> RESPONSE: {response}\
|
||||
"""
|
||||
|
||||
PROMPT = PromptTemplate(
|
||||
template=PROMPT_TEMPLATE,
|
||||
input_variables=["user_input", "context", "response"],
|
||||
)
|
||||
|
||||
|
||||
QUESTION_GENERATOR_PROMPT_TEMPLATE = """\
|
||||
Given a user input and an existing partial response as context, \
|
||||
ask a question to which the answer is the given term/entity/phrase:
|
||||
|
||||
>>> USER INPUT: {user_input}
|
||||
>>> EXISTING PARTIAL RESPONSE: {current_response}
|
||||
|
||||
The question to which the answer is the term/entity/phrase "{uncertain_span}" is:"""
|
||||
QUESTION_GENERATOR_PROMPT = PromptTemplate(
|
||||
template=QUESTION_GENERATOR_PROMPT_TEMPLATE,
|
||||
input_variables=["user_input", "current_response", "uncertain_span"],
|
||||
)
|
||||
Reference in New Issue
Block a user