initial commit
This commit is contained in:
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Chains module for langchain_community
|
||||
|
||||
This module contains the community chains.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chains.pebblo_retrieval.base import PebbloRetrievalQA
|
||||
|
||||
# Public names exported by this package. They are not imported eagerly at
# runtime (the import above sits under ``TYPE_CHECKING``); the module-level
# ``__getattr__`` below resolves them on first access.
__all__ = ["PebbloRetrievalQA"]

# Maps each lazily exported attribute name to the dotted path of the module
# that defines it.
_module_lookup = {
    "PebbloRetrievalQA": "langchain_community.chains.pebblo_retrieval.base"
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
    """Resolve a lazily exported attribute of this module on first access.

    Args:
        name: The attribute name being looked up on the module.

    Returns:
        The attribute called ``name`` taken from the module registered for it
        in ``_module_lookup``.

    Raises:
        AttributeError: If ``name`` is not registered in ``_module_lookup``.
    """
    target_path = _module_lookup.get(name)
    if target_path is None:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    # Import only now, so the defining module is loaded on demand.
    defining_module = importlib.import_module(target_path)
    return getattr(defining_module, name)
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,17 @@
|
||||
from langchain_classic.chains.ernie_functions.base import (
|
||||
convert_to_ernie_function,
|
||||
create_ernie_fn_chain,
|
||||
create_ernie_fn_runnable,
|
||||
create_structured_output_chain,
|
||||
create_structured_output_runnable,
|
||||
get_ernie_output_parser,
|
||||
)
|
||||
|
||||
# Public API of this package: exactly the names re-exported from the ``base``
# module by the import statement above.
__all__ = [
    "convert_to_ernie_function",
    "create_structured_output_chain",
    "create_ernie_fn_chain",
    "create_structured_output_runnable",
    "create_ernie_fn_runnable",
    "get_ernie_output_parser",
]
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,553 @@
|
||||
"""Methods for creating chains that use Ernie function-calling APIs."""
|
||||
|
||||
import inspect
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_classic.chains import LLMChain
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import (
|
||||
BaseGenerationOutputParser,
|
||||
BaseLLMOutputParser,
|
||||
BaseOutputParser,
|
||||
)
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_community.output_parsers.ernie_functions import (
|
||||
JsonOutputFunctionsParser,
|
||||
PydanticAttrOutputFunctionsParser,
|
||||
PydanticOutputFunctionsParser,
|
||||
)
|
||||
from langchain_community.utils.ernie_functions import convert_pydantic_to_ernie_function
|
||||
|
||||
# Maps the ``__name__`` of a primitive Python type to the JSON Schema type
# name used when describing function arguments.
# Note that ``int`` maps to "number" here, not "integer".
PYTHON_TO_JSON_TYPES = {
    "str": "string",
    "int": "number",
    "float": "number",
    "bool": "boolean",
}
|
||||
|
||||
|
||||
def _get_python_function_name(function: Callable) -> str:
|
||||
"""Get the name of a Python function."""
|
||||
return function.__name__
|
||||
|
||||
|
||||
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
|
||||
"""Parse the function and argument descriptions from the docstring of a function.
|
||||
|
||||
Assumes the function docstring follows Google Python style guide.
|
||||
"""
|
||||
docstring = inspect.getdoc(function)
|
||||
if docstring:
|
||||
docstring_blocks = docstring.split("\n\n")
|
||||
descriptors = []
|
||||
args_block = None
|
||||
past_descriptors = False
|
||||
for block in docstring_blocks:
|
||||
if block.startswith("Args:"):
|
||||
args_block = block
|
||||
break
|
||||
elif block.startswith("Returns:") or block.startswith("Example:"):
|
||||
# Don't break in case Args come after
|
||||
past_descriptors = True
|
||||
elif not past_descriptors:
|
||||
descriptors.append(block)
|
||||
else:
|
||||
continue
|
||||
description = " ".join(descriptors)
|
||||
else:
|
||||
description = ""
|
||||
args_block = None
|
||||
arg_descriptions = {}
|
||||
if args_block:
|
||||
arg = None
|
||||
for line in args_block.split("\n")[1:]:
|
||||
if ":" in line:
|
||||
arg, desc = line.split(":")
|
||||
arg_descriptions[arg.strip()] = desc.strip()
|
||||
elif arg:
|
||||
arg_descriptions[arg.strip()] += " " + line.strip()
|
||||
return description, arg_descriptions
|
||||
|
||||
|
||||
def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict:
|
||||
"""Get JsonSchema describing a Python functions arguments.
|
||||
|
||||
Assumes all function arguments are of primitive types (int, float, str, bool) or
|
||||
are subclasses of pydantic.BaseModel.
|
||||
"""
|
||||
properties = {}
|
||||
annotations = inspect.getfullargspec(function).annotations
|
||||
for arg, arg_type in annotations.items():
|
||||
if arg == "return":
|
||||
continue
|
||||
if isinstance(arg_type, type) and is_basemodel_subclass(arg_type):
|
||||
# Mypy error:
|
||||
# "type" has no attribute "schema"
|
||||
properties[arg] = arg_type.schema() # type: ignore[attr-defined]
|
||||
elif arg_type.__name__ in PYTHON_TO_JSON_TYPES:
|
||||
properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]}
|
||||
if arg in arg_descriptions:
|
||||
if arg not in properties:
|
||||
properties[arg] = {}
|
||||
properties[arg]["description"] = arg_descriptions[arg]
|
||||
return properties
|
||||
|
||||
|
||||
def _get_python_function_required_args(function: Callable) -> List[str]:
|
||||
"""Get the required arguments for a Python function."""
|
||||
spec = inspect.getfullargspec(function)
|
||||
required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args
|
||||
required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})]
|
||||
|
||||
is_class = type(function) is type
|
||||
if is_class and required[0] == "self":
|
||||
required = required[1:]
|
||||
return required
|
||||
|
||||
|
||||
def convert_python_function_to_ernie_function(
    function: Callable,
) -> Dict[str, Any]:
    """Build an Ernie function-calling API compatible dict from a Python function.

    The function is expected to carry type hints and a docstring with a
    description. Google Python style argument descriptions in the docstring are
    included in the generated parameter schema.

    Args:
        function: The Python callable to describe.

    Returns:
        A dict with the keys ``"name"``, ``"description"`` and ``"parameters"``,
        the latter being an object schema with ``"properties"`` and
        ``"required"``.
    """
    summary, per_arg_docs = _parse_python_function_docstring(function)
    function_name = _get_python_function_name(function)
    parameter_schema = {
        "type": "object",
        "properties": _get_python_function_arguments(function, per_arg_docs),
        "required": _get_python_function_required_args(function),
    }
    return {
        "name": function_name,
        "description": summary,
        "parameters": parameter_schema,
    }
|
||||
|
||||
|
||||
def convert_to_ernie_function(
    function: Union[Dict[str, Any], Type[BaseModel], Callable],
) -> Dict[str, Any]:
    """Normalize a dict, pydantic model class or callable into an Ernie function.

    Args:
        function: A dictionary, a pydantic.BaseModel class, or a Python function.
            A dictionary is assumed to already be a valid Ernie function and is
            returned unchanged.

    Returns:
        A dict describing the function in the form expected by the Ernie
        function-calling API.

    Raises:
        ValueError: If ``function`` is none of the supported kinds.
    """
    # Already in the target format: hand it back untouched.
    if isinstance(function, dict):
        return function

    if isinstance(function, type) and is_basemodel_subclass(function):
        return cast(Dict, convert_pydantic_to_ernie_function(function))

    if not callable(function):
        raise ValueError(
            f"Unsupported function type {type(function)}. Functions must be passed in"
            f" as Dict, pydantic.BaseModel, or Callable."
        )

    return convert_python_function_to_ernie_function(function)
|
||||
|
||||
|
||||
def get_ernie_output_parser(
    functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
    """Get the appropriate function output parser given the user functions.

    Args:
        functions: Sequence where element is a dictionary, a pydantic.BaseModel class,
            or a Python function. If a dictionary is passed in, it is assumed to
            already be a valid Ernie function.

    Returns:
        A PydanticOutputFunctionsParser if functions are Pydantic classes, otherwise
            a JsonOutputFunctionsParser. If there's only one function and it is
            not a Pydantic class, then the output parser will automatically extract
            only the function arguments and not the function name.

    Raises:
        ValueError: If ``functions`` is empty.
    """
    # Fail with the same explicit error the chain/runnable factories use,
    # instead of an opaque IndexError from ``functions[0]`` below.
    if not functions:
        raise ValueError("Need to pass in at least one function. Received zero.")
    function_names = [convert_to_ernie_function(f)["name"] for f in functions]
    # Only the first element decides between the pydantic and the JSON parser.
    if isinstance(functions[0], type) and is_basemodel_subclass(functions[0]):
        if len(functions) > 1:
            pydantic_schema: Union[Dict, Type[BaseModel]] = {
                name: fn for name, fn in zip(function_names, functions)
            }
        else:
            pydantic_schema = functions[0]
        output_parser: Union[BaseOutputParser, BaseGenerationOutputParser] = (
            PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
        )
    else:
        output_parser = JsonOutputFunctionsParser(args_only=len(functions) <= 1)
    return output_parser
|
||||
|
||||
|
||||
def create_ernie_fn_runnable(
    functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
    llm: Runnable,
    prompt: BasePromptTemplate,
    *,
    output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
    **kwargs: Any,
) -> Runnable:
    """Compose ``prompt | llm | output_parser`` with Ernie functions bound to the llm.

    Args:
        functions: A sequence of dictionaries, pydantic.BaseModel classes, or
            Python functions. Dictionaries are assumed to already be valid Ernie
            functions. When exactly one function is given, the model is forced
            to call it. pydantic.BaseModels and Python functions should have
            docstrings describing what the function does. For best results,
            pydantic.BaseModels should describe their parameters and Python
            functions should use Google Python style args descriptions in the
            docstring. Python functions should only use primitive types
            (str, int, float, bool) or pydantic.BaseModels for arguments.
        llm: Language model to use, assumed to support the Ernie function-calling API.
        prompt: BasePromptTemplate to pass to the model.
        output_parser: Parser for the model output. When omitted it is inferred
            from the function types: pydantic.BaseModels are parsed into those
            models, anything else is parsed as JSON. If several non-pydantic
            functions are passed in, the output includes both the name of the
            returned function and the arguments to pass to it.
        **kwargs: Extra keyword arguments bound to the llm alongside the
            functions.

    Returns:
        A runnable sequence that will pass in the given functions to the model when run.

    Raises:
        ValueError: If ``functions`` is empty.

    Example:
        .. code-block:: python

                from typing import Optional

                from langchain_classic.chains.ernie_functions import create_ernie_fn_runnable
                from langchain_community.chat_models import ErnieBotChat
                from langchain_core.prompts import ChatPromptTemplate
                from pydantic import BaseModel, Field


                class RecordPerson(BaseModel):
                    \"\"\"Record some identifying information about a person.\"\"\"

                    name: str = Field(..., description="The person's name")
                    age: int = Field(..., description="The person's age")
                    fav_food: Optional[str] = Field(None, description="The person's favorite food")


                class RecordDog(BaseModel):
                    \"\"\"Record some identifying information about a dog.\"\"\"

                    name: str = Field(..., description="The dog's name")
                    color: str = Field(..., description="The dog's color")
                    fav_food: Optional[str] = Field(None, description="The dog's favorite food")


                llm = ErnieBotChat(model_name="ERNIE-Bot-4")
                prompt = ChatPromptTemplate.from_messages(
                    [
                        ("user", "Make calls to the relevant function to record the entities in the following input: {input}"),
                        ("assistant", "OK!"),
                        ("user", "Tip: Make sure to answer in the correct format"),
                    ]
                )
                chain = create_ernie_fn_runnable([RecordPerson, RecordDog], llm, prompt)
                chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
                # -> RecordDog(name="Harry", color="brown", fav_food="chicken")
    """  # noqa: E501
    if not functions:
        raise ValueError("Need to pass in at least one function. Received zero.")

    function_specs = list(map(convert_to_ernie_function, functions))

    # "functions" goes in first so caller-supplied kwargs are applied on top.
    bind_kwargs: Dict[str, Any] = dict(functions=function_specs)
    bind_kwargs.update(kwargs)
    if len(function_specs) == 1:
        # A single function is made mandatory for the model.
        bind_kwargs["function_call"] = {"name": function_specs[0]["name"]}

    parser = output_parser if output_parser else get_ernie_output_parser(functions)
    bound_llm = llm.bind(**bind_kwargs)
    return prompt | bound_llm | parser
|
||||
|
||||
|
||||
def create_structured_output_runnable(
    output_schema: Union[Dict[str, Any], Type[BaseModel]],
    llm: Runnable,
    prompt: BasePromptTemplate,
    *,
    output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
    **kwargs: Any,
) -> Runnable:
    """Build a runnable that returns structured output via a single Ernie function.

    Args:
        output_schema: Either a dictionary or pydantic.BaseModel class. A
            dictionary is assumed to already be a valid JsonSchema. For best
            results, pydantic.BaseModels should have docstrings describing what
            the schema represents and descriptions for the parameters.
        llm: Language model to use, assumed to support the Ernie function-calling API.
        prompt: BasePromptTemplate to pass to the model.
        output_parser: Parser for the model output. When omitted and a
            pydantic.BaseModel is given, the output is parsed into that model;
            otherwise the parser is inferred downstream and the output is parsed
            as JSON.
        **kwargs: Forwarded to ``create_ernie_fn_runnable``.

    Returns:
        A runnable sequence that will pass the given function to the model when run.

    Example:
        .. code-block:: python

                from typing import Optional

                from langchain_classic.chains.ernie_functions import create_structured_output_runnable
                from langchain_community.chat_models import ErnieBotChat
                from langchain_core.prompts import ChatPromptTemplate
                from pydantic import BaseModel, Field

                class Dog(BaseModel):
                    \"\"\"Identifying information about a dog.\"\"\"

                    name: str = Field(..., description="The dog's name")
                    color: str = Field(..., description="The dog's color")
                    fav_food: Optional[str] = Field(None, description="The dog's favorite food")

                llm = ErnieBotChat(model_name="ERNIE-Bot-4")
                prompt = ChatPromptTemplate.from_messages(
                    [
                        ("user", "Use the given format to extract information from the following input: {input}"),
                        ("assistant", "OK!"),
                        ("user", "Tip: Make sure to answer in the correct format"),
                    ]
                )
                chain = create_structured_output_runnable(Dog, llm, prompt)
                chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
                # -> Dog(name="Harry", color="brown", fav_food="chicken")
    """  # noqa: E501
    if not isinstance(output_schema, dict):
        # Wrap the pydantic model in a one-field formatter model so the model
        # output can be unwrapped again from its ``output`` attribute.
        class _OutputFormatter(BaseModel):
            """Output formatter. Should always be used to format your response to the user."""  # noqa: E501

            output: output_schema  # type: ignore[valid-type]

        function: Any = _OutputFormatter
        if not output_parser:
            output_parser = PydanticAttrOutputFunctionsParser(
                pydantic_schema=_OutputFormatter, attr_name="output"
            )
    else:
        # A raw JSON schema becomes the parameters of a fixed formatter function.
        function = {
            "name": "output_formatter",
            "description": (
                "Output formatter. Should always be used to format your response to the"
                " user."
            ),
            "parameters": output_schema,
        }

    return create_ernie_fn_runnable(
        [function],
        llm,
        prompt,
        output_parser=output_parser,
        **kwargs,
    )
|
||||
|
||||
|
||||
""" --- Legacy --- """
|
||||
|
||||
|
||||
def create_ernie_fn_chain(
    functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
    llm: BaseLanguageModel,
    prompt: BasePromptTemplate,
    *,
    output_key: str = "function",
    output_parser: Optional[BaseLLMOutputParser] = None,
    **kwargs: Any,
) -> LLMChain:
    """[Legacy] Build an ``LLMChain`` that passes Ernie functions to the model.

    Args:
        functions: A sequence of dictionaries, pydantic.BaseModel classes, or
            Python functions. Dictionaries are assumed to already be valid Ernie
            functions. When exactly one function is given, the model is forced
            to call it. pydantic.BaseModels and Python functions should have
            docstrings describing what the function does. For best results,
            pydantic.BaseModels should describe their parameters and Python
            functions should use Google Python style args descriptions in the
            docstring. Python functions should only use primitive types
            (str, int, float, bool) or pydantic.BaseModels for arguments.
        llm: Language model to use, assumed to support the Ernie function-calling API.
        prompt: BasePromptTemplate to pass to the model.
        output_key: The key to use when returning the output in LLMChain.__call__.
        output_parser: Parser for the model output. When omitted it is inferred
            from the function types: pydantic.BaseModels are parsed into those
            models, anything else is parsed as JSON. If several non-pydantic
            functions are passed in, the chain output includes both the name of
            the returned function and the arguments to pass to it.
        **kwargs: Extra keyword arguments forwarded to the ``LLMChain``
            constructor.

    Returns:
        An LLMChain that will pass in the given functions to the model when run.

    Raises:
        ValueError: If ``functions`` is empty.

    Example:
        .. code-block:: python

                from typing import Optional

                from langchain_classic.chains.ernie_functions import create_ernie_fn_chain
                from langchain_community.chat_models import ErnieBotChat
                from langchain_core.prompts import ChatPromptTemplate

                from pydantic import BaseModel, Field


                class RecordPerson(BaseModel):
                    \"\"\"Record some identifying information about a person.\"\"\"

                    name: str = Field(..., description="The person's name")
                    age: int = Field(..., description="The person's age")
                    fav_food: Optional[str] = Field(None, description="The person's favorite food")


                class RecordDog(BaseModel):
                    \"\"\"Record some identifying information about a dog.\"\"\"

                    name: str = Field(..., description="The dog's name")
                    color: str = Field(..., description="The dog's color")
                    fav_food: Optional[str] = Field(None, description="The dog's favorite food")


                llm = ErnieBotChat(model_name="ERNIE-Bot-4")
                prompt = ChatPromptTemplate.from_messages(
                    [
                        ("user", "Make calls to the relevant function to record the entities in the following input: {input}"),
                        ("assistant", "OK!"),
                        ("user", "Tip: Make sure to answer in the correct format"),
                    ]
                )
                chain = create_ernie_fn_chain([RecordPerson, RecordDog], llm, prompt)
                chain.run("Harry was a chubby brown beagle who loved chicken")
                # -> RecordDog(name="Harry", color="brown", fav_food="chicken")
    """  # noqa: E501
    if not functions:
        raise ValueError("Need to pass in at least one function. Received zero.")

    function_specs = list(map(convert_to_ernie_function, functions))
    parser = output_parser if output_parser else get_ernie_output_parser(functions)

    model_kwargs: Dict[str, Any] = dict(functions=function_specs)
    if len(function_specs) == 1:
        # A single function is made mandatory for the model.
        model_kwargs["function_call"] = {"name": function_specs[0]["name"]}

    return LLMChain(
        llm=llm,
        prompt=prompt,
        output_parser=parser,
        llm_kwargs=model_kwargs,
        output_key=output_key,
        **kwargs,
    )
|
||||
|
||||
|
||||
def create_structured_output_chain(
    output_schema: Union[Dict[str, Any], Type[BaseModel]],
    llm: BaseLanguageModel,
    prompt: BasePromptTemplate,
    *,
    output_key: str = "function",
    output_parser: Optional[BaseLLMOutputParser] = None,
    **kwargs: Any,
) -> LLMChain:
    """[Legacy] Build an ``LLMChain`` that returns structured output via an Ernie function.

    Args:
        output_schema: Either a dictionary or pydantic.BaseModel class. A
            dictionary is assumed to already be a valid JsonSchema. For best
            results, pydantic.BaseModels should have docstrings describing what
            the schema represents and descriptions for the parameters.
        llm: Language model to use, assumed to support the Ernie function-calling API.
        prompt: BasePromptTemplate to pass to the model.
        output_key: The key to use when returning the output in LLMChain.__call__.
        output_parser: Parser for the model output. When omitted and a
            pydantic.BaseModel is given, the output is parsed into that model;
            otherwise the parser is inferred downstream and the output is parsed
            as JSON.
        **kwargs: Forwarded to ``create_ernie_fn_chain``.

    Returns:
        An LLMChain that will pass the given function to the model.

    Example:
        .. code-block:: python

                from typing import Optional

                from langchain_classic.chains.ernie_functions import create_structured_output_chain
                from langchain_community.chat_models import ErnieBotChat
                from langchain_core.prompts import ChatPromptTemplate

                from pydantic import BaseModel, Field

                class Dog(BaseModel):
                    \"\"\"Identifying information about a dog.\"\"\"

                    name: str = Field(..., description="The dog's name")
                    color: str = Field(..., description="The dog's color")
                    fav_food: Optional[str] = Field(None, description="The dog's favorite food")

                llm = ErnieBotChat(model_name="ERNIE-Bot-4")
                prompt = ChatPromptTemplate.from_messages(
                    [
                        ("user", "Use the given format to extract information from the following input: {input}"),
                        ("assistant", "OK!"),
                        ("user", "Tip: Make sure to answer in the correct format"),
                    ]
                )
                chain = create_structured_output_chain(Dog, llm, prompt)
                chain.run("Harry was a chubby brown beagle who loved chicken")
                # -> Dog(name="Harry", color="brown", fav_food="chicken")
    """  # noqa: E501
    if not isinstance(output_schema, dict):
        # Wrap the pydantic model in a one-field formatter model so the model
        # output can be unwrapped again from its ``output`` attribute.
        class _OutputFormatter(BaseModel):
            """Output formatter. Should always be used to format your response to the user."""  # noqa: E501

            output: output_schema  # type: ignore[valid-type]

        function: Any = _OutputFormatter
        if not output_parser:
            output_parser = PydanticAttrOutputFunctionsParser(
                pydantic_schema=_OutputFormatter, attr_name="output"
            )
    else:
        # A raw JSON schema becomes the parameters of a fixed formatter function.
        function = {
            "name": "output_formatter",
            "description": (
                "Output formatter. Should always be used to format your response to the"
                " user."
            ),
            "parameters": output_schema,
        }

    return create_ernie_fn_chain(
        [function],
        llm,
        prompt,
        output_key=output_key,
        output_parser=output_parser,
        **kwargs,
    )
|
||||
@@ -0,0 +1 @@
|
||||
"""Question answering over a knowledge graph."""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,273 @@
|
||||
"""Question answering over a graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
AQL_FIX_PROMPT,
|
||||
AQL_GENERATION_PROMPT,
|
||||
AQL_QA_PROMPT,
|
||||
)
|
||||
from langchain_community.graphs.arangodb_graph import ArangoGraph
|
||||
|
||||
|
||||
class ArangoGraphQAChain(Chain):
    """Chain for question-answering against a graph by generating AQL statements.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: ArangoGraph = Field(exclude=True)
    aql_generation_chain: LLMChain
    aql_fix_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    # Specifies the maximum number of AQL Query Results to return
    top_k: int = 10

    # Specifies the set of AQL Query Examples that promote few-shot-learning
    aql_examples: str = ""

    # Specify whether to return the AQL Query in the output dictionary
    return_aql_query: bool = False

    # Specify whether to return the AQL JSON Result in the output dictionary
    return_aql_result: bool = False

    # Specify the maximum amount of AQL Generation attempts that should be made
    max_aql_generation_attempts: int = 3

    allow_dangerous_requests: bool = False
    """Forced user opt-in to acknowledge that the chain can make dangerous requests.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the chain.

        Raises:
            ValueError: If ``allow_dangerous_requests`` was not set to ``True``.
        """
        super().__init__(**kwargs)
        if self.allow_dangerous_requests is not True:
            # Each fragment ends with a space so the adjacent literals join
            # into properly separated sentences.
            raise ValueError(
                "In order to use this chain, you must acknowledge that it can make "
                "dangerous requests by setting `allow_dangerous_requests` to `True`. "
                "You must narrowly scope the permissions of the database connection "
                "to only include necessary permissions. Failure to do so may result "
                "in data corruption or loss or reading sensitive data if such data is "
                "present in the database. "
                "Only use this chain if you understand the risks and have taken the "
                "necessary precautions. "
                "See https://python.langchain.com/docs/security for more information."
            )

    @property
    def input_keys(self) -> List[str]:
        """Return the single input key of the chain."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the single output key of the chain."""
        return [self.output_key]

    @property
    def _chain_type(self) -> str:
        return "graph_aql_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = AQL_QA_PROMPT,
        aql_generation_prompt: BasePromptTemplate = AQL_GENERATION_PROMPT,
        aql_fix_prompt: BasePromptTemplate = AQL_FIX_PROMPT,
        **kwargs: Any,
    ) -> ArangoGraphQAChain:
        """Initialize from LLM.

        Builds the QA, AQL generation and AQL fix chains from the same ``llm``
        with their respective prompts; ``kwargs`` are passed to the constructor.
        """
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        aql_generation_chain = LLMChain(llm=llm, prompt=aql_generation_prompt)
        aql_fix_chain = LLMChain(llm=llm, prompt=aql_fix_prompt)

        return cls(
            qa_chain=qa_chain,
            aql_generation_chain=aql_generation_chain,
            aql_fix_chain=aql_fix_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """
        Generate an AQL statement from user input, use it to retrieve a response
        from an ArangoDB Database instance, and respond to the user input
        in natural language.

        Users can modify the following ArangoGraphQAChain Class Variables:

        :var top_k: The maximum number of AQL Query Results to return
        :type top_k: int

        :var aql_examples: A set of AQL Query Examples that are passed to
            the AQL Generation Prompt Template to promote few-shot-learning.
            Defaults to an empty string.
        :type aql_examples: str

        :var return_aql_query: Whether to return the AQL Query in the
            output dictionary. Defaults to False.
        :type return_aql_query: bool

        :var return_aql_result: Whether to return the AQL JSON Result in the
            output dictionary. Defaults to False.
        :type return_aql_result: bool

        :var max_aql_generation_attempts: The maximum amount of AQL
            Generation attempts to be made prior to raising the last
            AQL Query Execution Error. Defaults to 3.
        :type max_aql_generation_attempts: int

        :raises ValueError: If a model response contains no fenced code block,
            or if no query executed successfully within the allowed attempts.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        user_input = inputs[self.input_key]

        #########################
        # Generate AQL Query #
        aql_generation_output = self.aql_generation_chain.run(
            {
                "adb_schema": self.graph.schema,
                "aql_examples": self.aql_examples,
                "user_input": user_input,
            },
            callbacks=callbacks,
        )
        #########################

        aql_query = ""
        aql_error = ""
        aql_result = None
        aql_generation_attempt = 1

        # Matches a triple-backtick fenced block with an optional
        # case-insensitive "aql" tag. Loop-invariant, so defined once here.
        pattern = r"```(?i:aql)?(.*?)```"

        while (
            aql_result is None
            and aql_generation_attempt < self.max_aql_generation_attempts + 1
        ):
            #####################
            # Extract AQL Query #
            matches = re.findall(pattern, aql_generation_output, re.DOTALL)
            if not matches:
                _run_manager.on_text(
                    "Invalid Response: ", end="\n", verbose=self.verbose
                )
                _run_manager.on_text(
                    aql_generation_output, color="red", end="\n", verbose=self.verbose
                )
                raise ValueError(f"Response is Invalid: {aql_generation_output}")

            # Only the first fenced block is used as the query.
            aql_query = matches[0]
            #####################

            _run_manager.on_text(
                f"AQL Query ({aql_generation_attempt}):", verbose=self.verbose
            )
            _run_manager.on_text(
                aql_query, color="green", end="\n", verbose=self.verbose
            )

            #####################
            # Execute AQL Query #
            # Imported here rather than at module level; presumably so that
            # `arango` is only required once a query is actually executed.
            from arango import AQLQueryExecuteError

            try:
                aql_result = self.graph.query(aql_query, self.top_k)
            except AQLQueryExecuteError as e:
                aql_error = e.error_message

                _run_manager.on_text(
                    "AQL Query Execution Error: ", end="\n", verbose=self.verbose
                )
                _run_manager.on_text(
                    aql_error, color="yellow", end="\n\n", verbose=self.verbose
                )

                ########################
                # Retry AQL Generation #
                # Ask the fix chain for a corrected query; its output is
                # parsed at the top of the next iteration.
                aql_generation_output = self.aql_fix_chain.run(
                    {
                        "adb_schema": self.graph.schema,
                        "aql_query": aql_query,
                        "aql_error": aql_error,
                    },
                    callbacks=callbacks,
                )
                ########################

            #####################

            aql_generation_attempt += 1

        if aql_result is None:
            m = f"""
                Maximum amount of AQL Query Generation attempts reached.
                Unable to execute the AQL Query due to the following error:
                {aql_error}
            """
            raise ValueError(m)

        _run_manager.on_text("AQL Result:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(aql_result), color="green", end="\n", verbose=self.verbose
        )

        ########################
        # Interpret AQL Result #
        result = self.qa_chain(
            {
                "adb_schema": self.graph.schema,
                "user_input": user_input,
                "aql_query": aql_query,
                "aql_result": aql_result,
            },
            callbacks=callbacks,
        )
        ########################

        # Return results #
        result = {self.output_key: result[self.qa_chain.output_key]}

        if self.return_aql_query:
            result["aql_query"] = aql_query

        if self.return_aql_result:
            result["aql_result"] = aql_result

        return result
|
||||
@@ -0,0 +1,104 @@
|
||||
"""Question answering over a graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_core.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
ENTITY_EXTRACTION_PROMPT,
|
||||
GRAPH_QA_PROMPT,
|
||||
)
|
||||
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, get_entities
|
||||
|
||||
|
||||
class GraphQAChain(Chain):
    """Answer questions using knowledge triplets stored in a NetworkX entity graph.

    The chain first asks an LLM to extract entities from the question, gathers
    the graph's knowledge about each entity, and then asks a second LLM chain
    to answer the question using that knowledge as context.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: NetworkxEntityGraph = Field(exclude=True)
    entity_extraction_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Names of the inputs this chain expects.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Names of the outputs this chain produces.

        :meta private:
        """
        return [self.output_key]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        qa_prompt: BasePromptTemplate = GRAPH_QA_PROMPT,
        entity_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT,
        **kwargs: Any,
    ) -> GraphQAChain:
        """Build the chain from a single LLM used for both sub-chains.

        Args:
            llm: Model used for entity extraction and for answering.
            qa_prompt: Prompt for the answering chain.
            entity_prompt: Prompt for the entity extraction chain.
            **kwargs: Forwarded to the class constructor (e.g. ``graph``).
        """
        return cls(
            qa_chain=LLMChain(llm=llm, prompt=qa_prompt),
            entity_extraction_chain=LLMChain(llm=llm, prompt=entity_prompt),
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Extract entities, look up their graph knowledge and answer the question."""
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]

        extracted = self.entity_extraction_chain.run(question)
        manager.on_text("Entities Extracted:", end="\n", verbose=self.verbose)
        manager.on_text(extracted, color="green", end="\n", verbose=self.verbose)

        # Collect every triplet known for every extracted entity, in order.
        triplets = [
            triplet
            for entity in get_entities(extracted)
            for triplet in self.graph.get_entity_knowledge(entity)
        ]
        context = "\n".join(triplets)
        manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        manager.on_text(context, color="green", end="\n", verbose=self.verbose)

        answer = self.qa_chain(
            {"question": question, "context": context},
            callbacks=manager.get_child(),
        )
        return {self.output_key: answer[self.qa_chain.output_key]}
|
||||
@@ -0,0 +1,421 @@
|
||||
"""Question answering over a graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import (
|
||||
BasePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
)
|
||||
from langchain_core.runnables import Runnable
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.cypher_utils import (
|
||||
CypherQueryCorrector,
|
||||
Schema,
|
||||
)
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
CYPHER_GENERATION_PROMPT,
|
||||
CYPHER_QA_PROMPT,
|
||||
)
|
||||
from langchain_community.graphs.graph_store import GraphStore
|
||||
|
||||
# Output key under which GraphCypherQAChain._call returns the generated query
# and retrieved context when `return_intermediate_steps` is enabled.
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||
|
||||
FUNCTION_RESPONSE_SYSTEM = """You are an assistant that helps to form nice and human
|
||||
understandable answers based on the provided information from tools.
|
||||
Do not add any other information that wasn't present in the tools, and use
|
||||
very concise style in interpreting results!
|
||||
"""
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.8",
|
||||
removal="1.0",
|
||||
alternative_import="langchain_neo4j.chains.graph_qa.cypher.extract_cypher",
|
||||
)
|
||||
def extract_cypher(text: str) -> str:
    """Return the Cypher statement embedded in an LLM response.

    Args:
        text: Raw text that may wrap the query in a triple-backtick fence.

    Returns:
        The contents of the first fenced section, or ``text`` unchanged when
        no complete fence is present.
    """
    # DOTALL lets the fenced section span multiple lines.
    fenced = re.search(r"```(.*?)```", text, re.DOTALL)
    if fenced is None:
        return text
    return fenced.group(1)
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.8",
|
||||
removal="1.0",
|
||||
alternative_import="langchain_neo4j.chains.graph_qa.cypher.construct_schema",
|
||||
)
|
||||
def construct_schema(
    structured_schema: Dict[str, Any],
    include_types: List[str],
    exclude_types: List[str],
) -> str:
    """Render a structured graph schema as text, filtered by type name.

    Args:
        structured_schema: Mapping with optional ``node_props``, ``rel_props``
            and ``relationships`` entries.
        include_types: When non-empty, only these labels/types are kept.
        exclude_types: Used only when ``include_types`` is empty; these
            labels/types are dropped.

    Returns:
        A six-line description listing node properties, relationship
        properties and relationships.
    """

    def _is_kept(name: str) -> bool:
        # An include list, when given, takes precedence over the exclude list.
        if include_types:
            return name in include_types
        return name not in exclude_types

    def _render_props(props_by_name: Dict[str, Any]) -> List[str]:
        # "<name> {prop: type, prop: type}" for every entry.
        rendered = []
        for name, props in props_by_name.items():
            joined = ", ".join(f"{p['property']}: {p['type']}" for p in props)
            rendered.append(f"{name} {{{joined}}}")
        return rendered

    kept_node_props = {
        label: props
        for label, props in structured_schema.get("node_props", {}).items()
        if _is_kept(label)
    }
    kept_rel_props = {
        rel_type: props
        for rel_type, props in structured_schema.get("rel_props", {}).items()
        if _is_kept(rel_type)
    }
    # A relationship survives only if both endpoints and its type are kept.
    kept_relationships = [
        rel
        for rel in structured_schema.get("relationships", [])
        if all(_is_kept(rel[field]) for field in ("start", "end", "type"))
    ]

    node_section = ",".join(_render_props(kept_node_props))
    rel_prop_section = ",".join(_render_props(kept_rel_props))
    relationship_section = ",".join(
        f"(:{rel['start']})-[:{rel['type']}]->(:{rel['end']})"
        for rel in kept_relationships
    )

    lines = [
        "Node properties are the following:",
        node_section,
        "Relationship properties are the following:",
        rel_prop_section,
        "The relationships are the following:",
        relationship_section,
    ]
    return "\n".join(lines)
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.8",
|
||||
removal="1.0",
|
||||
alternative_import="langchain_neo4j.chains.graph_qa.cypher.get_function_response",
|
||||
)
|
||||
def get_function_response(
    question: str, context: List[Dict[str, Any]]
) -> List[BaseMessage]:
    """Wrap a graph query result as a synthetic tool call and tool response.

    Args:
        question: The user's question; recorded as the ``question`` argument
            of the synthetic ``GetInformation`` tool call.
        context: Rows returned by the graph query; stringified into the
            tool message content.

    Returns:
        Two messages: an ``AIMessage`` carrying the tool call, followed by the
        ``ToolMessage`` that answers it (both share the same tool-call id).
    """
    import json

    # Fixed placeholder id; it only needs to match between the two messages.
    TOOL_ID = "call_H7fABDuzEau48T10Qn0Lsh0D"
    # Serialize with json.dumps so quotes, backslashes and control characters
    # in the question are escaped and the arguments stay valid JSON. Compact
    # separators and ensure_ascii=False keep the output identical to plain
    # string concatenation for questions that need no escaping.
    arguments = json.dumps(
        {"question": question}, separators=(",", ":"), ensure_ascii=False
    )
    messages = [
        AIMessage(
            content="",
            additional_kwargs={
                "tool_calls": [
                    {
                        "id": TOOL_ID,
                        "function": {
                            "arguments": arguments,
                            "name": "GetInformation",
                        },
                        "type": "function",
                    }
                ]
            },
        ),
        ToolMessage(content=str(context), tool_call_id=TOOL_ID),
    ]
    return messages
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.8",
|
||||
removal="1.0",
|
||||
alternative_import="langchain_neo4j.GraphCypherQAChain",
|
||||
)
|
||||
class GraphCypherQAChain(Chain):
    """Chain for question-answering against a graph by generating Cypher statements.

    The chain generates a Cypher statement from the question and the graph
    schema, optionally corrects relationship directions, runs the statement
    against the graph, and (unless ``return_direct`` is set) asks a QA model
    to phrase an answer from the retrieved rows.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: GraphStore = Field(exclude=True)
    cypher_generation_chain: LLMChain
    qa_chain: Union[LLMChain, Runnable]
    graph_schema: str
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 10
    """Number of results to return from the query"""
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the graph directly."""
    cypher_query_corrector: Optional[CypherQueryCorrector] = None
    """Optional cypher validation tool"""
    use_function_response: bool = False
    """Whether to wrap the database context as tool/function response"""
    allow_dangerous_requests: bool = False
    """Forced user opt-in to acknowledge that the chain can make dangerous requests.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the chain.

        Raises:
            ValueError: If ``allow_dangerous_requests`` is not exactly ``True``.
        """
        super().__init__(**kwargs)
        if self.allow_dangerous_requests is not True:
            raise ValueError(
                "In order to use this chain, you must acknowledge that it can make "
                "dangerous requests by setting `allow_dangerous_requests` to `True`. "
                "You must narrowly scope the permissions of the database connection "
                "to only include necessary permissions. Failure to do so may result "
                "in data corruption or loss or reading sensitive data if such data is "
                "present in the database. "
                "Only use this chain if you understand the risks and have taken the "
                "necessary precautions. "
                "See https://python.langchain.com/docs/security for more information."
            )

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @property
    def _chain_type(self) -> str:
        return "graph_cypher_chain"

    @classmethod
    def from_llm(
        cls,
        llm: Optional[BaseLanguageModel] = None,
        *,
        qa_prompt: Optional[BasePromptTemplate] = None,
        cypher_prompt: Optional[BasePromptTemplate] = None,
        cypher_llm: Optional[BaseLanguageModel] = None,
        qa_llm: Optional[Union[BaseLanguageModel, Any]] = None,
        exclude_types: List[str] = [],
        include_types: List[str] = [],
        validate_cypher: bool = False,
        qa_llm_kwargs: Optional[Dict[str, Any]] = None,
        cypher_llm_kwargs: Optional[Dict[str, Any]] = None,
        use_function_response: bool = False,
        function_response_system: str = FUNCTION_RESPONSE_SYSTEM,
        **kwargs: Any,
    ) -> GraphCypherQAChain:
        """Initialize from LLM.

        Args:
            llm: Model used for any role not given a dedicated model.
            qa_prompt: Prompt for the QA chain (exclusive with ``qa_llm_kwargs``).
            cypher_prompt: Prompt for Cypher generation (exclusive with
                ``cypher_llm_kwargs``).
            cypher_llm: Model dedicated to Cypher generation.
            qa_llm: Model dedicated to answering.
            exclude_types: Labels/types to hide from the schema. Never mutated.
            include_types: Labels/types to keep in the schema. Never mutated.
            validate_cypher: Build a ``CypherQueryCorrector`` from the graph's
                relationships and apply it to generated statements.
            qa_llm_kwargs: Extra keyword arguments for the QA ``LLMChain``.
            cypher_llm_kwargs: Extra keyword arguments for the generation
                ``LLMChain``.
            use_function_response: Feed the query result to the QA model as a
                tool response instead of through ``LLMChain``.
            function_response_system: System prompt for the tool-response path.
            **kwargs: Forwarded to the constructor; must contain ``graph``.

        Raises:
            ValueError: On inconsistent model/prompt arguments, when both type
                filters are given, or when the QA model lacks tool support.
        """

        if not cypher_llm and not llm:
            raise ValueError("Either `llm` or `cypher_llm` parameters must be provided")
        if not qa_llm and not llm:
            raise ValueError("Either `llm` or `qa_llm` parameters must be provided")
        if cypher_llm and qa_llm and llm:
            raise ValueError(
                "You can specify up to two of 'cypher_llm', 'qa_llm'"
                ", and 'llm', but not all three simultaneously."
            )
        if cypher_prompt and cypher_llm_kwargs:
            raise ValueError(
                "Specifying cypher_prompt and cypher_llm_kwargs together is"
                " not allowed. Please pass prompt via cypher_llm_kwargs."
            )
        if qa_prompt and qa_llm_kwargs:
            raise ValueError(
                "Specifying qa_prompt and qa_llm_kwargs together is"
                " not allowed. Please pass prompt via qa_llm_kwargs."
            )

        # Work on copies: filling in the default prompt below must not leak
        # a "prompt" entry into dictionaries owned by the caller.
        use_qa_llm_kwargs = dict(qa_llm_kwargs) if qa_llm_kwargs is not None else {}
        use_cypher_llm_kwargs = (
            dict(cypher_llm_kwargs) if cypher_llm_kwargs is not None else {}
        )
        if "prompt" not in use_qa_llm_kwargs:
            use_qa_llm_kwargs["prompt"] = (
                qa_prompt if qa_prompt is not None else CYPHER_QA_PROMPT
            )
        if "prompt" not in use_cypher_llm_kwargs:
            use_cypher_llm_kwargs["prompt"] = (
                cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
            )

        qa_llm = qa_llm or llm
        if use_function_response:
            try:
                # Probe for tool support; models without it raise here.
                qa_llm.bind_tools({})  # type: ignore[union-attr]
                response_prompt = ChatPromptTemplate.from_messages(
                    [
                        SystemMessage(content=function_response_system),
                        HumanMessagePromptTemplate.from_template("{question}"),
                        MessagesPlaceholder(variable_name="function_response"),
                    ]
                )
                qa_chain = response_prompt | qa_llm | StrOutputParser()  # type: ignore[operator]
            except (NotImplementedError, AttributeError) as err:
                raise ValueError(
                    "Provided LLM does not support native tools/functions"
                ) from err
        else:
            qa_chain = LLMChain(llm=qa_llm, **use_qa_llm_kwargs)  # type: ignore[arg-type]

        cypher_generation_chain = LLMChain(
            llm=cypher_llm or llm,  # type: ignore[arg-type]
            **use_cypher_llm_kwargs,
        )

        if exclude_types and include_types:
            raise ValueError(
                "Either `exclude_types` or `include_types` "
                "can be provided, but not both"
            )

        structured_schema = kwargs["graph"].get_structured_schema
        graph_schema = construct_schema(structured_schema, include_types, exclude_types)

        cypher_query_corrector = None
        if validate_cypher:
            # Same schema accessor as above; a schema without relationships
            # yields a corrector with no known relationships.
            corrector_schema = [
                Schema(el["start"], el["type"], el["end"])
                for el in structured_schema.get("relationships", [])
            ]
            cypher_query_corrector = CypherQueryCorrector(corrector_schema)

        return cls(
            graph_schema=graph_schema,
            qa_chain=qa_chain,
            cypher_generation_chain=cypher_generation_chain,
            cypher_query_corrector=cypher_query_corrector,
            use_function_response=use_function_response,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate Cypher statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]
        args = {
            "question": question,
            "schema": self.graph_schema,
        }
        # Caller-supplied inputs may add to or override the defaults above.
        args.update(inputs)

        intermediate_steps: List = []

        generated_cypher = self.cypher_generation_chain.run(args, callbacks=callbacks)

        # Extract Cypher code if it is wrapped in backticks
        generated_cypher = extract_cypher(generated_cypher)

        # Correct Cypher query if enabled
        if self.cypher_query_corrector:
            generated_cypher = self.cypher_query_corrector(generated_cypher)

        _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_cypher, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_cypher})

        # Retrieve and limit the number of results.
        # The generated Cypher is empty when the query corrector finds that the
        # statement does not fit the schema; skip the database in that case.
        if generated_cypher:
            context = self.graph.query(generated_cypher)[: self.top_k]
        else:
            context = []

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})
            if self.use_function_response:
                function_response = get_function_response(question, context)
                final_result = self.qa_chain.invoke(  # type: ignore[assignment]
                    {"question": question, "function_response": function_response},
                )
            else:
                result = self.qa_chain.invoke(
                    {"question": question, "context": context},
                    callbacks=callbacks,
                )
                final_result = result[self.qa_chain.output_key]  # type: ignore[union-attr]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result
|
||||
@@ -0,0 +1,267 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
|
||||
# One allowed relationship in the graph schema. CypherQueryCorrector.verify_schema
# reads the fields by position: [0] source node label, [1] relationship type,
# [2] target node label.
Schema = namedtuple("Schema", ["left_node", "relation", "right_node"])
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.8",
|
||||
removal="1.0",
|
||||
alternative_import="langchain_neo4j.chains.graph_qa.cypher_utils.CypherQueryCorrector",
|
||||
)
|
||||
class CypherQueryCorrector:
    """
    Used to correct relationship direction in generated Cypher statements.
    This code is copied from the winner's submission to the Cypher competition:
    https://github.com/sakusaku-rich/cypher-direction-competition

    Each ``(node)-[rel]-(node)`` step of the query is checked against the known
    schema triples. A step whose direction is reversed is flipped in place; a
    step that matches the schema in neither direction makes the whole query
    invalid, which is reported by returning an empty string.
    """

    property_pattern = re.compile(r"\{.+?\}")
    node_pattern = re.compile(r"\(.+?\)")
    path_pattern = re.compile(
        r"(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))(<?-)(\[.*?\])?(->?)(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))"
    )
    node_relation_node_pattern = re.compile(
        r"(\()+(?P<left_node>[^()]*?)\)(?P<relation>.*?)\((?P<right_node>[^()]*?)(\))+"
    )
    relation_type_pattern = re.compile(r":(?P<relation_type>.+?)?(\{.+\})?]")

    def __init__(self, schemas: "List[Schema]"):
        """
        Args:
            schemas: list of schemas; each item is indexed as
                (source label, relationship type, target label)
        """
        self.schemas = schemas

    def clean_node(self, node: str) -> str:
        """Strip property maps, parentheses and surrounding whitespace.

        Args:
            node: node in string format

        Returns:
            The remaining ``variable:Label`` text.
        """
        node = self.property_pattern.sub("", node)
        node = node.replace("(", "")
        node = node.replace(")", "")
        return node.strip()

    def detect_node_variables(self, query: str) -> Dict[str, List[str]]:
        """Map every node variable in the query to the labels declared for it.

        Anonymous nodes are collected under the empty-string key.

        Args:
            query: cypher query
        """
        res: Dict[str, List[str]] = {}
        for raw_node in self.node_pattern.findall(query):
            # "var:LabelA:LabelB" -> variable "var", labels ["LabelA", "LabelB"]
            variable, *labels = self.clean_node(raw_node).split(":")
            res.setdefault(variable, []).extend(labels)
        return res

    def extract_paths(self, query: str) -> "List[str]":
        """Return every ``(node)-[rel]-(node)`` step found in the query.

        Consecutive steps of a longer path overlap on their shared node.

        Args:
            query: cypher query
        """
        paths: List[str] = []
        idx = 0
        while found := self.path_pattern.findall(query[idx:]):
            groups = found[0]
            # Groups 1 and 6 are the property maps nested inside the two node
            # groups; dropping them leaves pieces that join back into the
            # exact matched text.
            segments = [
                g for i, g in enumerate(groups) if i not in (1, len(groups) - 1)
            ]
            path = "".join(segments)
            # Search from ``idx`` so that a path text which also occurs earlier
            # in the query cannot pull the cursor backwards (which previously
            # made this loop run forever on repeated identical paths). Resume
            # at the right-hand node so it can start the next step.
            idx = query.find(path, idx) + len(path) - len(segments[-1])
            paths.append(path)
        return paths

    def judge_direction(self, relation: str) -> str:
        """Classify a relation as INCOMING, OUTGOING or BIDIRECTIONAL.

        Args:
            relation: relation in string format
        """
        direction = "BIDIRECTIONAL"
        if relation[0] == "<":
            direction = "INCOMING"
        if relation[-1] == ">":
            direction = "OUTGOING"
        return direction

    def extract_node_variable(self, part: str) -> Optional[str]:
        """Return the variable name of a node, or None for an anonymous node.

        Args:
            part: node in string format
        """
        part = part.lstrip("(").rstrip(")")
        idx = part.find(":")
        if idx != -1:
            part = part[:idx]
        return None if part == "" else part

    def detect_labels(
        self, str_node: str, node_variable_dict: Dict[str, Any]
    ) -> List[str]:
        """Return the labels that apply to a node.

        Labels recorded for the node's variable take precedence; otherwise an
        anonymous node contributes the labels written inline.

        Args:
            str_node: node in string format
            node_variable_dict: dictionary of node variables
        """
        splitted_node = str_node.split(":")
        variable = splitted_node[0]
        labels = []
        if variable in node_variable_dict:
            labels = node_variable_dict[variable]
        elif variable == "" and len(splitted_node) > 1:
            labels = splitted_node[1:]
        return labels

    def verify_schema(
        self,
        from_node_labels: List[str],
        relation_types: List[str],
        to_node_labels: List[str],
    ) -> bool:
        """Check whether any known schema triple fits the given constraints.

        An empty list places no constraint on that position.

        Args:
            from_node_labels: labels of the from node
            relation_types: types of the relation
            to_node_labels: labels of the to node
        """
        valid_schemas = self.schemas
        if from_node_labels != []:
            from_node_labels = [label.strip("`") for label in from_node_labels]
            valid_schemas = [
                schema for schema in valid_schemas if schema[0] in from_node_labels
            ]
        if to_node_labels != []:
            to_node_labels = [label.strip("`") for label in to_node_labels]
            valid_schemas = [
                schema for schema in valid_schemas if schema[2] in to_node_labels
            ]
        if relation_types != []:
            relation_types = [rel_type.strip("`") for rel_type in relation_types]
            valid_schemas = [
                schema for schema in valid_schemas if schema[1] in relation_types
            ]
        return valid_schemas != []

    def detect_relation_types(self, str_relation: str) -> Tuple[str, List[str]]:
        """Return the direction of a relation and the types it names.

        Args:
            str_relation: relation in string format
        """
        relation_direction = self.judge_direction(str_relation)
        relation_type = self.relation_type_pattern.search(str_relation)
        if relation_type is None or relation_type.group("relation_type") is None:
            return relation_direction, []
        relation_types = [
            t.strip().strip("!")
            for t in relation_type.group("relation_type").split("|")
        ]
        return relation_direction, relation_types

    def correct_query(self, query: str) -> str:
        """Flip mis-directed relationships so the query matches the schema.

        Args:
            query: cypher query

        Returns:
            The (possibly corrected) query, or an empty string when some step
            matches the schema in neither direction.
        """
        node_variable_dict = self.detect_node_variables(query)
        for path in self.extract_paths(query):
            start_idx = 0
            while start_idx < len(path):
                match_res = self.node_relation_node_pattern.match(path[start_idx:])
                if match_res is None:
                    break
                start_idx += match_res.start()
                left_node = match_res.group("left_node")
                relation = match_res.group("relation")
                right_node = match_res.group("right_node")

                left_node_labels = self.detect_labels(left_node, node_variable_dict)
                right_node_labels = self.detect_labels(right_node, node_variable_dict)
                # The constant 4 accounts for the two pairs of parentheses.
                end_idx = (
                    start_idx + 4 + len(left_node) + len(relation) + len(right_node)
                )
                original_partial_path = path[start_idx : end_idx + 1]
                relation_direction, relation_types = self.detect_relation_types(
                    relation
                )
                # Advance past "(" + left node + ")" + relation so the right
                # node becomes the left node of the next step.
                step = len(left_node) + len(relation) + 2

                # Variable-length relationships are left untouched.
                if any("*" in rel_type for rel_type in relation_types):
                    start_idx += step
                    continue

                if relation_direction == "OUTGOING":
                    if not self.verify_schema(
                        left_node_labels, relation_types, right_node_labels
                    ):
                        if not self.verify_schema(
                            right_node_labels, relation_types, left_node_labels
                        ):
                            return ""
                        # Valid only when reversed: "-[...]->" becomes "<-[...]-".
                        corrected_relation = "<" + relation[:-1]
                        corrected_partial_path = original_partial_path.replace(
                            relation, corrected_relation
                        )
                        query = query.replace(
                            original_partial_path, corrected_partial_path
                        )
                elif relation_direction == "INCOMING":
                    if not self.verify_schema(
                        right_node_labels, relation_types, left_node_labels
                    ):
                        if not self.verify_schema(
                            left_node_labels, relation_types, right_node_labels
                        ):
                            return ""
                        # Valid only when reversed: "<-[...]-" becomes "-[...]->".
                        corrected_relation = relation[1:] + ">"
                        corrected_partial_path = original_partial_path.replace(
                            relation, corrected_relation
                        )
                        query = query.replace(
                            original_partial_path, corrected_partial_path
                        )
                else:
                    # Undirected: either orientation may satisfy the schema.
                    is_legal = self.verify_schema(
                        left_node_labels, relation_types, right_node_labels
                    ) or self.verify_schema(
                        right_node_labels, relation_types, left_node_labels
                    )
                    if not is_legal:
                        return ""

                start_idx += step
        return query

    def __call__(self, query: str) -> str:
        """Correct the query to make it valid.

        Args:
            query: cypher query

        Returns:
            The corrected query, or an empty string if it cannot be made to
            match the schema.
        """
        return self.correct_query(query)
|
||||
@@ -0,0 +1,189 @@
|
||||
"""Question answering over a graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
CYPHER_GENERATION_PROMPT,
|
||||
CYPHER_QA_PROMPT,
|
||||
)
|
||||
from langchain_community.graphs import FalkorDBGraph
|
||||
|
||||
# NOTE(review): presumably the output key under which FalkorDBQAChain returns
# its intermediate steps; the usage is not visible in this part of the file.
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||
|
||||
|
||||
def extract_cypher(text: str) -> str:
    """
    Pull the Cypher statement out of a model response.

    Args:
        text: Response text, possibly wrapping the query in triple backticks.

    Returns:
        The contents of the first triple-backtick section, or the input
        unchanged if there is none.
    """
    # Only the first fenced section matters; DOTALL allows multi-line queries.
    first_fence = next(re.finditer(r"```(.*?)```", text, re.DOTALL), None)
    return first_fence.group(1) if first_fence is not None else text
|
||||
|
||||
|
||||
class FalkorDBQAChain(Chain):
    """Question-answering chain that queries a FalkorDB graph via generated Cypher.

    A first LLM chain turns the question into a Cypher statement, the statement
    is run against the graph, and (unless ``return_direct`` is set) a second
    LLM chain phrases an answer from the query results.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: FalkorDBGraph = Field(exclude=True)
    cypher_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 10
    """Maximum number of query results kept as context."""
    return_intermediate_steps: bool = False
    """Whether the intermediate steps are returned next to the final answer."""
    return_direct: bool = False
    """Whether the raw graph query results are returned instead of an answer."""

    allow_dangerous_requests: bool = False
    """Forced user opt-in to acknowledge that the chain can make dangerous requests.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Build the chain and enforce the explicit dangerous-requests opt-in.

        Raises:
            ValueError: If ``allow_dangerous_requests`` is not exactly ``True``.
        """
        super().__init__(**kwargs)
        if self.allow_dangerous_requests is True:
            return
        raise ValueError(
            "In order to use this chain, you must acknowledge that it can make "
            "dangerous requests by setting `allow_dangerous_requests` to `True`."
            "You must narrowly scope the permissions of the database connection "
            "to only include necessary permissions. Failure to do so may result "
            "in data corruption or loss or reading sensitive data if such data is "
            "present in the database."
            "Only use this chain if you understand the risks and have taken the "
            "necessary precautions. "
            "See https://python.langchain.com/docs/security for more information."
        )

    @property
    def input_keys(self) -> List[str]:
        """Names of the inputs this chain expects.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Names of the outputs this chain produces.

        :meta private:
        """
        return [self.output_key]

    @property
    def _chain_type(self) -> str:
        return "graph_cypher_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> FalkorDBQAChain:
        """Create the chain with both sub-chains driven by the same LLM.

        Args:
            llm: Model used for Cypher generation and for answering.
            qa_prompt: Prompt of the answering chain.
            cypher_prompt: Prompt of the Cypher generation chain.
            **kwargs: Forwarded to the class constructor.
        """
        return cls(
            qa_chain=LLMChain(llm=llm, prompt=qa_prompt),
            cypher_generation_chain=LLMChain(llm=llm, prompt=cypher_prompt),
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate a Cypher statement, run it on the graph and answer the question."""
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        child_callbacks = manager.get_child()
        user_question = inputs[self.input_key]

        raw_generation = self.cypher_generation_chain.run(
            {"question": user_question, "schema": self.graph.schema},
            callbacks=child_callbacks,
        )

        # Keep only the fenced part when the model wrapped the query in backticks
        cypher = extract_cypher(raw_generation)

        manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        manager.on_text(cypher, color="green", end="\n", verbose=self.verbose)

        steps: List = [{"query": cypher}]

        # Run the query and cap the number of rows used as context
        rows = self.graph.query(cypher)[: self.top_k]

        if not self.return_direct:
            manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            manager.on_text(str(rows), color="green", end="\n", verbose=self.verbose)

            steps.append({"context": rows})

            qa_output = self.qa_chain(
                {"question": user_question, "context": rows},
                callbacks=child_callbacks,
            )
            answer = qa_output[self.qa_chain.output_key]
        else:
            answer = rows

        outputs: Dict[str, Any] = {self.output_key: answer}
        if self.return_intermediate_steps:
            outputs[INTERMEDIATE_STEPS_KEY] = steps

        return outputs
|
||||
@@ -0,0 +1,253 @@
|
||||
"""Question answering over a graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
CYPHER_QA_PROMPT,
|
||||
GRAPHDB_SPARQL_FIX_TEMPLATE,
|
||||
GREMLIN_GENERATION_PROMPT,
|
||||
)
|
||||
from langchain_community.graphs import GremlinGraph
|
||||
|
||||
# Output-dict key under which `GremlinQAChain._call` stores the intermediate
# steps when `return_intermediate_steps` is enabled.
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||
|
||||
|
||||
def extract_gremlin(text: str) -> str:
    """Reduce a text to a bare single-line Gremlin statement.

    All backticks are dropped, a leading ``gremlin`` marker is cut off and
    every newline is removed.

    Args:
        text: Text to extract Gremlin code from.

    Returns:
        Gremlin code extracted from the text.
    """
    marker = "gremlin"
    without_ticks = text.replace("`", "")
    body = (
        without_ticks[len(marker) :]
        if without_ticks.startswith(marker)
        else without_ticks
    )
    return body.replace("\n", "")
|
||||
|
||||
|
||||
class GremlinQAChain(Chain):
    """Chain for question-answering against a graph by generating gremlin statements.

    A generation chain turns the question into a Gremlin statement, the
    statement is executed against the graph (with LLM-assisted repair attempts
    when execution fails), and a QA chain phrases the answer from the results.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: GremlinGraph = Field(exclude=True)
    gremlin_generation_chain: LLMChain
    qa_chain: LLMChain
    gremlin_fix_chain: LLMChain
    max_fix_retries: int = 3
    """Maximum number of repair attempts made after a query fails to execute."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 100
    """Maximum number of query results kept as context."""
    return_direct: bool = False
    """Whether to return the query results directly instead of a generated answer."""
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""

    allow_dangerous_requests: bool = False
    """Forced user opt-in to acknowledge that the chain can make dangerous requests.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the chain.

        Raises:
            ValueError: If ``allow_dangerous_requests`` is not exactly ``True``.
        """
        super().__init__(**kwargs)
        if self.allow_dangerous_requests is not True:
            raise ValueError(
                "In order to use this chain, you must acknowledge that it can make "
                "dangerous requests by setting `allow_dangerous_requests` to `True`."
                "You must narrowly scope the permissions of the database connection "
                "to only include necessary permissions. Failure to do so may result "
                "in data corruption or loss or reading sensitive data if such data is "
                "present in the database."
                "Only use this chain if you understand the risks and have taken the "
                "necessary precautions. "
                "See https://python.langchain.com/docs/security for more information."
            )

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        gremlin_fix_prompt: BasePromptTemplate = PromptTemplate(
            input_variables=["error_message", "generated_sparql", "schema"],
            template=GRAPHDB_SPARQL_FIX_TEMPLATE.replace("SPARQL", "Gremlin").replace(
                "in Turtle format", ""
            ),
        ),
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> GremlinQAChain:
        """Initialize from LLM.

        Args:
            llm: Model used by the generation, fix and QA sub-chains.
            gremlin_fix_prompt: Prompt of the query-repair chain. The default
                reuses the SPARQL fix template with its wording adapted, so
                its input variable is still named ``generated_sparql``.
            qa_prompt: Prompt of the answering chain.
            gremlin_prompt: Prompt of the Gremlin generation chain.
            **kwargs: Forwarded to the class constructor.
        """
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt)
        gremlin_fix_chain = LLMChain(llm=llm, prompt=gremlin_fix_prompt)
        return cls(
            qa_chain=qa_chain,
            gremlin_generation_chain=gremlin_generation_chain,
            gremlin_fix_chain=gremlin_fix_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate gremlin statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List = []

        chain_response = self.gremlin_generation_chain.invoke(
            {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        generated_gremlin = extract_gremlin(
            chain_response[self.gremlin_generation_chain.output_key]
        )

        _run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_gremlin, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_gremlin})

        # An empty generation is not sent to the database; the context is
        # simply empty in that case.
        if generated_gremlin:
            context = self.execute_with_retry(
                _run_manager, callbacks, generated_gremlin
            )[: self.top_k]
        else:
            context = []

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain.invoke(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result

    def execute_query(self, query: str) -> List[Any]:
        """Run ``query`` against the graph, normalizing failures to ``ValueError``.

        Args:
            query: Gremlin statement to execute.

        Returns:
            The result of ``self.graph.query(query)``.

        Raises:
            ValueError: If the graph raises. The message is the exception's
                ``status_message`` attribute when it has one, otherwise
                ``str`` of the exception. The original exception is chained
                as the cause.
        """
        try:
            return self.graph.query(query)
        except Exception as e:
            if hasattr(e, "status_message"):
                raise ValueError(e.status_message) from e
            raise ValueError(str(e)) from e

    def execute_with_retry(
        self,
        _run_manager: CallbackManagerForChainRun,
        callbacks: CallbackManager,
        generated_gremlin: str,
    ) -> List[Any]:
        """Execute a query, asking the fix chain to repair it when it fails.

        Every repair attempt sends the original query and the original error
        message to ``gremlin_fix_chain`` together with the graph schema, then
        executes the chain's output.

        Args:
            _run_manager: Run manager used to log invalid queries.
            callbacks: Callbacks handed to the fix chain.
            generated_gremlin: Gremlin statement to execute.

        Returns:
            The results of the first query that executes successfully.

        Raises:
            ValueError: If the query and all ``max_fix_retries`` repair
                attempts fail.
        """
        try:
            return self.execute_query(generated_gremlin)
        except Exception as first_error:
            error_message = str(first_error)
            self.log_invalid_query(_run_manager, generated_gremlin, error_message)

            for _ in range(self.max_fix_retries):
                # Until the fix chain has produced a replacement, the query
                # under consideration is still the original one; this keeps
                # the logging below valid when the fix chain itself fails.
                attempted_gremlin = generated_gremlin
                try:
                    fix_chain_result = self.gremlin_fix_chain.invoke(
                        {
                            "error_message": error_message,
                            # we are borrowing template from sparql
                            "generated_sparql": generated_gremlin,
                            # Same schema source as the generation step; the
                            # chain itself has no `schema` field.
                            "schema": self.graph.get_schema,
                        },
                        callbacks=callbacks,
                    )
                    attempted_gremlin = fix_chain_result[
                        self.gremlin_fix_chain.output_key
                    ]
                    return self.execute_query(attempted_gremlin)
                except Exception as retry_error:
                    self.log_invalid_query(
                        _run_manager, attempted_gremlin, str(retry_error)
                    )

            raise ValueError(
                "The generated Gremlin query is invalid."
            ) from first_error

    def log_invalid_query(
        self,
        _run_manager: CallbackManagerForChainRun,
        generated_query: str,
        error_message: str,
    ) -> None:
        """Report a failing query and its error through the run manager.

        Args:
            _run_manager: Run manager whose ``on_text`` receives the messages.
            generated_query: The query that failed.
            error_message: The error text to report for it.
        """
        _run_manager.on_text("Invalid Gremlin query: ", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_query, color="red", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            "Gremlin Query Parse Error: ", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            error_message, color="red", end="\n\n", verbose=self.verbose
        )
|
||||
@@ -0,0 +1,138 @@
|
||||
"""Question answering over a graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
CYPHER_QA_PROMPT,
|
||||
GREMLIN_GENERATION_PROMPT,
|
||||
)
|
||||
from langchain_community.graphs.hugegraph import HugeGraph
|
||||
|
||||
|
||||
class HugeGraphQAChain(Chain):
    """Question-answering chain that queries a HugeGraph graph via generated gremlin.

    A first LLM chain turns the question into a gremlin statement, the
    statement is run against the graph, and a second LLM chain phrases an
    answer from the query results.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: HugeGraph = Field(exclude=True)
    gremlin_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    allow_dangerous_requests: bool = False
    """Forced user opt-in to acknowledge that the chain can make dangerous requests.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Build the chain and enforce the explicit dangerous-requests opt-in.

        Raises:
            ValueError: If ``allow_dangerous_requests`` is not exactly ``True``.
        """
        super().__init__(**kwargs)
        if self.allow_dangerous_requests is True:
            return
        raise ValueError(
            "In order to use this chain, you must acknowledge that it can make "
            "dangerous requests by setting `allow_dangerous_requests` to `True`."
            "You must narrowly scope the permissions of the database connection "
            "to only include necessary permissions. Failure to do so may result "
            "in data corruption or loss or reading sensitive data if such data is "
            "present in the database."
            "Only use this chain if you understand the risks and have taken the "
            "necessary precautions. "
            "See https://python.langchain.com/docs/security for more information."
        )

    @property
    def input_keys(self) -> List[str]:
        """Names of the inputs this chain expects.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Names of the outputs this chain produces.

        :meta private:
        """
        return [self.output_key]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> HugeGraphQAChain:
        """Create the chain with both sub-chains driven by the same LLM.

        Args:
            llm: Model used for gremlin generation and for answering.
            qa_prompt: Prompt of the answering chain.
            gremlin_prompt: Prompt of the gremlin generation chain.
            **kwargs: Forwarded to the class constructor.
        """
        return cls(
            qa_chain=LLMChain(llm=llm, prompt=qa_prompt),
            gremlin_generation_chain=LLMChain(llm=llm, prompt=gremlin_prompt),
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate a gremlin statement, run it on the graph and answer the question."""
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        child_callbacks = manager.get_child()
        user_question = inputs[self.input_key]

        gremlin = self.gremlin_generation_chain.run(
            {"question": user_question, "schema": self.graph.get_schema},
            callbacks=child_callbacks,
        )

        manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
        manager.on_text(gremlin, color="green", end="\n", verbose=self.verbose)
        rows = self.graph.query(gremlin)

        manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        manager.on_text(str(rows), color="green", end="\n", verbose=self.verbose)

        qa_output = self.qa_chain(
            {"question": user_question, "context": rows},
            callbacks=child_callbacks,
        )
        answer = qa_output[self.qa_chain.output_key]
        return {self.output_key: answer}
|
||||
@@ -0,0 +1,196 @@
|
||||
"""Question answering over a graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
CYPHER_QA_PROMPT,
|
||||
KUZU_GENERATION_PROMPT,
|
||||
)
|
||||
from langchain_community.graphs.kuzu_graph import KuzuGraph
|
||||
|
||||
|
||||
def remove_prefix(text: str, prefix: str) -> str:
    """Strip ``prefix`` from the start of ``text`` when it is present.

    Args:
        text: Text to remove the prefix from.
        prefix: Prefix to remove from the text.

    Returns:
        ``text`` without the leading ``prefix``, or ``text`` unchanged when
        it does not start with ``prefix``.
    """
    has_prefix = text.startswith(prefix)
    return text[len(prefix) :] if has_prefix else text
|
||||
|
||||
|
||||
def extract_cypher(text: str) -> str:
    """Return the first triple-backtick fenced section of a text.

    Args:
        text: Text that may contain Cypher wrapped in triple backticks.

    Returns:
        The content between the first pair of triple-backtick fences, or
        ``text`` itself when it contains no such pair.
    """
    # Non-greedy group so only the first fenced section is captured;
    # DOTALL so the section may span several lines.
    first_fence = re.search(r"```(.*?)```", text, re.DOTALL)
    return text if first_fence is None else first_fence.group(1)
|
||||
|
||||
|
||||
class KuzuQAChain(Chain):
    """Chain answering questions over a Kùzu graph by generating Cypher statements.

    A first LLM chain turns the question into a Cypher statement, the statement
    is run against the graph, and a second LLM chain phrases an answer from the
    query results.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: KuzuGraph = Field(exclude=True)
    cypher_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    allow_dangerous_requests: bool = False
    """Forced user opt-in to acknowledge that the chain can make dangerous requests.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Build the chain and enforce the explicit dangerous-requests opt-in.

        Raises:
            ValueError: If ``allow_dangerous_requests`` is not exactly ``True``.
        """
        super().__init__(**kwargs)
        if self.allow_dangerous_requests is True:
            return
        raise ValueError(
            "In order to use this chain, you must acknowledge that it can make "
            "dangerous requests by setting `allow_dangerous_requests` to `True`."
            "You must narrowly scope the permissions of the database connection "
            "to only include necessary permissions. Failure to do so may result "
            "in data corruption or loss or reading sensitive data if such data is "
            "present in the database."
            "Only use this chain if you understand the risks and have taken the "
            "necessary precautions. "
            "See https://python.langchain.com/docs/security for more information."
        )

    @property
    def input_keys(self) -> List[str]:
        """Names of the inputs this chain expects.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Names of the outputs this chain produces.

        :meta private:
        """
        return [self.output_key]

    @classmethod
    def from_llm(
        cls,
        llm: Optional[BaseLanguageModel] = None,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        cypher_prompt: BasePromptTemplate = KUZU_GENERATION_PROMPT,
        cypher_llm: Optional[BaseLanguageModel] = None,
        qa_llm: Optional[BaseLanguageModel] = None,
        **kwargs: Any,
    ) -> KuzuQAChain:
        """Create the chain from one shared LLM or from dedicated LLMs.

        Args:
            llm: Fallback model for whichever of ``cypher_llm`` / ``qa_llm``
                is not given.
            qa_prompt: Prompt of the answering chain.
            cypher_prompt: Prompt of the Cypher generation chain.
            cypher_llm: Model dedicated to Cypher generation.
            qa_llm: Model dedicated to answering.
            **kwargs: Forwarded to the class constructor.

        Raises:
            ValueError: If no model is available for one of the two
                sub-chains, or if all three models are supplied at once.
        """
        if not (cypher_llm or llm):
            raise ValueError("Either `llm` or `cypher_llm` parameters must be provided")
        if not (qa_llm or llm):
            raise ValueError(
                "Either `llm` or `qa_llm` parameters must be provided along with"
                " `cypher_llm`"
            )
        if cypher_llm and qa_llm and llm:
            raise ValueError(
                "You can specify up to two of 'cypher_llm', 'qa_llm'"
                ", and 'llm', but not all three simultaneously."
            )

        # A dedicated model wins over the shared fallback for each sub-chain.
        answering_chain = LLMChain(
            llm=qa_llm or llm,  # type: ignore[arg-type]
            prompt=qa_prompt,
        )
        generation_chain = LLMChain(
            llm=cypher_llm or llm,  # type: ignore[arg-type]
            prompt=cypher_prompt,
        )

        return cls(
            qa_chain=answering_chain,
            cypher_generation_chain=generation_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate a Cypher statement, run it on the graph and answer the question."""
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        child_callbacks = manager.get_child()
        user_question = inputs[self.input_key]

        raw_generation = self.cypher_generation_chain.run(
            {"question": user_question, "schema": self.graph.get_schema},
            callbacks=child_callbacks,
        )
        # Take the fenced section if there is one, then drop a leading
        # "cypher" language marker.
        fenced = extract_cypher(raw_generation)
        cypher = remove_prefix(fenced, "cypher")

        manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        manager.on_text(cypher, color="green", end="\n", verbose=self.verbose)
        rows = self.graph.query(cypher)

        manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        manager.on_text(str(rows), color="green", end="\n", verbose=self.verbose)

        qa_output = self.qa_chain(
            {"question": user_question, "context": rows},
            callbacks=child_callbacks,
        )
        answer = qa_output[self.qa_chain.output_key]
        return {self.output_key: answer}
|
||||
@@ -0,0 +1,316 @@
|
||||
"""Question answering over a graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import (
|
||||
BasePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
)
|
||||
from langchain_core.runnables import Runnable
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
MEMGRAPH_GENERATION_PROMPT,
|
||||
MEMGRAPH_QA_PROMPT,
|
||||
)
|
||||
from langchain_community.graphs.memgraph_graph import MemgraphGraph
|
||||
|
||||
# Key name for the list of intermediate steps.
# NOTE(review): its use is not visible in this excerpt — presumably a key of
# the chain's output dict; confirm against `MemgraphQAChain._call`.
INTERMEDIATE_STEPS_KEY = "intermediate_steps"

# Default value of the `function_response_system` parameter of
# `MemgraphQAChain.from_llm`.
FUNCTION_RESPONSE_SYSTEM = """You are an assistant that helps to form nice and human
understandable answers based on the provided information from tools.
Do not add any other information that wasn't present in the tools, and use
very concise style in interpreting results!
"""
|
||||
|
||||
|
||||
def extract_cypher(text: str) -> str:
    """Get the Cypher snippet enclosed in the first triple-backtick fence.

    Args:
        text: Text that may contain Cypher wrapped in triple backticks.

    Returns:
        The content of the first fenced section; ``text`` unchanged if the
        text has no fenced section.
    """
    fence_pattern = r"```(.*?)```"

    # DOTALL: a fenced section usually spans several lines.
    hit = re.search(fence_pattern, text, re.DOTALL)
    if hit:
        return hit.group(1)
    return text
|
||||
|
||||
|
||||
def get_function_response(
    question: str, context: List[Dict[str, Any]]
) -> List[BaseMessage]:
    """Wrap a question and its database context as a tool-call message pair.

    Args:
        question: The user question, embedded as the ``question`` argument of
            the tool call.
        context: The database results, stringified into the tool response.

    Returns:
        Two messages: an ``AIMessage`` carrying a single ``GetInformation``
        tool call, followed by the ``ToolMessage`` answering that call with
        ``str(context)``.
    """
    import json

    # Fixed id that ties the tool call to its response message.
    tool_call_id = "call_H7fABDuzEau48T10Qn0Lsh0D"

    # Serialize with json so quotes, backslashes and newlines in the question
    # are escaped. Compact separators and ensure_ascii=False keep the output
    # identical to plain concatenation for questions that need no escaping.
    arguments = json.dumps(
        {"question": question}, separators=(",", ":"), ensure_ascii=False
    )

    messages = [
        AIMessage(
            content="",
            additional_kwargs={
                "tool_calls": [
                    {
                        "id": tool_call_id,
                        "function": {
                            "arguments": arguments,
                            "name": "GetInformation",
                        },
                        "type": "function",
                    }
                ]
            },
        ),
        ToolMessage(content=str(context), tool_call_id=tool_call_id),
    ]
    return messages
|
||||
|
||||
|
||||
class MemgraphQAChain(Chain):
|
||||
"""Chain for question-answering against a graph by generating Cypher statements.
|
||||
|
||||
*Security note*: Make sure that the database connection uses credentials
|
||||
that are narrowly-scoped to only include necessary permissions.
|
||||
Failure to do so may result in data corruption or loss, since the calling
|
||||
code may attempt commands that would result in deletion, mutation
|
||||
of data if appropriately prompted or reading sensitive data if such
|
||||
data is present in the database.
|
||||
The best way to guard against such negative outcomes is to (as appropriate)
|
||||
limit the permissions granted to the credentials used with this tool.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
graph: MemgraphGraph = Field(exclude=True)
|
||||
cypher_generation_chain: Runnable
|
||||
qa_chain: Runnable
|
||||
graph_schema: str
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
top_k: int = 10
|
||||
"""Number of results to return from the query"""
|
||||
return_intermediate_steps: bool = False
|
||||
"""Whether or not to return the intermediate steps along with the final answer."""
|
||||
return_direct: bool = False
|
||||
"""Optional cypher validation tool"""
|
||||
use_function_response: bool = False
|
||||
"""Whether to wrap the database context as tool/function response"""
|
||||
allow_dangerous_requests: bool = False
|
||||
"""Forced user opt-in to acknowledge that the chain can make dangerous requests.
|
||||
|
||||
*Security note*: Make sure that the database connection uses credentials
|
||||
that are narrowly-scoped to only include necessary permissions.
|
||||
Failure to do so may result in data corruption or loss, since the calling
|
||||
code may attempt commands that would result in deletion, mutation
|
||||
of data if appropriately prompted or reading sensitive data if such
|
||||
data is present in the database.
|
||||
The best way to guard against such negative outcomes is to (as appropriate)
|
||||
limit the permissions granted to the credentials used with this tool.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
    """Initialize the chain and enforce the explicit safety opt-in.

    Args:
        **kwargs: Field values, forwarded unchanged to the parent initializer.

    Raises:
        ValueError: If ``allow_dangerous_requests`` is not exactly ``True``
            after the parent initializer has run.
    """
    super().__init__(**kwargs)
    # Identity check on purpose: only the literal ``True`` counts as consent.
    if self.allow_dangerous_requests is not True:
        # Adjacent literals are concatenated, so every fragment that ends a
        # sentence needs its own trailing space.
        raise ValueError(
            "In order to use this chain, you must acknowledge that it can make "
            "dangerous requests by setting `allow_dangerous_requests` to `True`. "
            "You must narrowly scope the permissions of the database connection "
            "to only include necessary permissions. Failure to do so may result "
            "in data corruption or loss or reading sensitive data if such data is "
            "present in the database. "
            "Only use this chain if you understand the risks and have taken the "
            "necessary precautions. "
            "See https://python.langchain.com/docs/security for more information."
        )
|
||||
|
||||
@property
def input_keys(self) -> List[str]:
    """List the single input key this chain reads.

    :meta private:
    """
    key = self.input_key
    return [key]
|
||||
|
||||
@property
def output_keys(self) -> List[str]:
    """List the single output key this chain produces.

    :meta private:
    """
    return [self.output_key]
|
||||
|
||||
@property
def _chain_type(self) -> str:
    """Return the fixed type identifier string of this chain."""
    chain_type = "graph_cypher_chain"
    return chain_type
|
||||
|
||||
@classmethod
def from_llm(
    cls,
    llm: Optional[BaseLanguageModel] = None,
    *,
    qa_prompt: Optional[BasePromptTemplate] = None,
    cypher_prompt: Optional[BasePromptTemplate] = None,
    cypher_llm: Optional[BaseLanguageModel] = None,
    qa_llm: Optional[Union[BaseLanguageModel, Any]] = None,
    qa_llm_kwargs: Optional[Dict[str, Any]] = None,
    cypher_llm_kwargs: Optional[Dict[str, Any]] = None,
    use_function_response: bool = False,
    function_response_system: str = FUNCTION_RESPONSE_SYSTEM,
    **kwargs: Any,
) -> MemgraphQAChain:
    """Build the chain from one or two language models.

    Args:
        llm: Fallback model, used for whichever of the two steps (Cypher
            generation, answering) has no dedicated model.
        qa_prompt: Prompt for the answering step; ``MEMGRAPH_QA_PROMPT`` is
            used when omitted. Not read when ``use_function_response`` is
            true. Must not be combined with ``qa_llm_kwargs``.
        cypher_prompt: Prompt for Cypher generation;
            ``MEMGRAPH_GENERATION_PROMPT`` is used when omitted. Must not be
            combined with ``cypher_llm_kwargs``.
        cypher_llm: Model dedicated to Cypher generation.
        qa_llm: Model dedicated to answering.
        qa_llm_kwargs: Optional mapping; only its ``"prompt"`` entry is read
            here. The mapping is copied and never modified.
        cypher_llm_kwargs: Optional mapping; only its ``"prompt"`` entry is
            read here. The mapping is copied and never modified.
        use_function_response: When true, the answering chain is built from
            a chat prompt that carries the database context in a
            ``function_response`` messages placeholder.
        function_response_system: System message content for that chat
            prompt.
        **kwargs: Forwarded to the class constructor. Must contain
            ``"graph"``; its ``get_schema`` attribute becomes
            ``graph_schema``.

    Returns:
        A configured chain instance.

    Raises:
        ValueError: If no model is available for one of the steps, if all
            three of ``llm``, ``cypher_llm`` and ``qa_llm`` are given, if a
            prompt and the matching kwargs mapping are both given, or if
            ``use_function_response`` is set and ``bind_tools`` is missing
            or not implemented on the answering model.
    """

    if not cypher_llm and not llm:
        raise ValueError("Either `llm` or `cypher_llm` parameters must be provided")
    if not qa_llm and not llm:
        raise ValueError("Either `llm` or `qa_llm` parameters must be provided")
    if cypher_llm and qa_llm and llm:
        raise ValueError(
            "You can specify up to two of 'cypher_llm', 'qa_llm'"
            ", and 'llm', but not all three simultaneously."
        )
    if cypher_prompt and cypher_llm_kwargs:
        raise ValueError(
            "Specifying cypher_prompt and cypher_llm_kwargs together is"
            " not allowed. Please pass prompt via cypher_llm_kwargs."
        )
    if qa_prompt and qa_llm_kwargs:
        raise ValueError(
            "Specifying qa_prompt and qa_llm_kwargs together is"
            " not allowed. Please pass prompt via qa_llm_kwargs."
        )
    # Work on copies: a default "prompt" entry may be inserted below and the
    # caller's dictionaries must not be modified as a side effect.
    use_qa_llm_kwargs = dict(qa_llm_kwargs) if qa_llm_kwargs is not None else {}
    use_cypher_llm_kwargs = (
        dict(cypher_llm_kwargs) if cypher_llm_kwargs is not None else {}
    )
    if "prompt" not in use_qa_llm_kwargs:
        use_qa_llm_kwargs["prompt"] = (
            qa_prompt if qa_prompt is not None else MEMGRAPH_QA_PROMPT
        )
    if "prompt" not in use_cypher_llm_kwargs:
        use_cypher_llm_kwargs["prompt"] = (
            cypher_prompt
            if cypher_prompt is not None
            else MEMGRAPH_GENERATION_PROMPT
        )

    qa_llm = qa_llm or llm
    if use_function_response:
        try:
            # Called only as a capability check; the return value is discarded.
            qa_llm.bind_tools({})  # type: ignore[union-attr]
            response_prompt = ChatPromptTemplate.from_messages(
                [
                    SystemMessage(content=function_response_system),
                    HumanMessagePromptTemplate.from_template("{question}"),
                    MessagesPlaceholder(variable_name="function_response"),
                ]
            )
            qa_chain = response_prompt | qa_llm | StrOutputParser()  # type: ignore[operator]
        except (NotImplementedError, AttributeError) as err:
            # Keep the original exception as the cause for easier debugging.
            raise ValueError(
                "Provided LLM does not support native tools/functions"
            ) from err
    else:
        qa_chain = use_qa_llm_kwargs["prompt"] | qa_llm | StrOutputParser()

    prompt = use_cypher_llm_kwargs["prompt"]
    llm_to_use = cypher_llm if cypher_llm is not None else llm

    if prompt is not None and llm_to_use is not None:
        cypher_generation_chain = prompt | llm_to_use | StrOutputParser()
    else:
        raise ValueError(
            "Missing required components for the cypher generation chain: "
            "'prompt' or 'llm'"
        )

    # The schema is read once here and stored on the instance.
    graph_schema = kwargs["graph"].get_schema

    return cls(
        graph_schema=graph_schema,
        qa_chain=qa_chain,
        cypher_generation_chain=cypher_generation_chain,
        use_function_response=use_function_response,
        **kwargs,
    )
|
||||
|
||||
def _call(
    self,
    inputs: Dict[str, Any],
    run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
    """Generate a Cypher statement, run it against the graph and answer.

    Args:
        inputs: Chain inputs. The question is read from ``self.input_key``;
            every entry is also passed to the Cypher generation chain.
        run_manager: Optional callback manager; a no-op manager is used when
            omitted.

    Returns:
        A dict holding the answer under ``self.output_key`` and, when
        ``return_intermediate_steps`` is set, the collected steps under
        ``INTERMEDIATE_STEPS_KEY``. With ``return_direct`` the answer is the
        truncated query result itself.
    """
    _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
    callbacks = _run_manager.get_child()
    question = inputs[self.input_key]
    args = {
        "question": question,
        "schema": self.graph_schema,
    }
    # Caller-supplied inputs win over the two defaults above on a key clash.
    args.update(inputs)

    intermediate_steps: List = []

    generated_cypher = self.cypher_generation_chain.invoke(
        args, callbacks=callbacks
    )
    # Extract Cypher code if it is wrapped in backticks
    generated_cypher = extract_cypher(generated_cypher)

    _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
    _run_manager.on_text(
        generated_cypher, color="green", end="\n", verbose=self.verbose
    )

    intermediate_steps.append({"query": generated_cypher})

    # Retrieve and limit the number of results.
    # An empty generated statement is not sent to the database; the context
    # is then an empty list.
    if generated_cypher:
        context = self.graph.query(generated_cypher)[: self.top_k]
    else:
        context = []

    if self.return_direct:
        result = context
    else:
        _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(context), color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"context": context})
        if self.use_function_response:
            function_response = get_function_response(question, context)
            result = self.qa_chain.invoke(
                {"question": question, "function_response": function_response},
            )
        else:
            result = self.qa_chain.invoke(
                {"question": question, "context": context},
                callbacks=callbacks,
            )

    # Use the configured output key so the result agrees with ``output_keys``
    # (identical to the previous hard-coded "result" for the default value).
    chain_result: Dict[str, Any] = {self.output_key: result}
    if self.return_intermediate_steps:
        chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

    return chain_result
|
||||
@@ -0,0 +1,138 @@
|
||||
"""Question answering over a graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
CYPHER_QA_PROMPT,
|
||||
NGQL_GENERATION_PROMPT,
|
||||
)
|
||||
from langchain_community.graphs.nebula_graph import NebulaGraph
|
||||
|
||||
|
||||
class NebulaGraphQAChain(Chain):
    """Chain for question-answering against a graph by generating nGQL statements.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    # Graph the generated nGQL is executed against.
    graph: NebulaGraph = Field(exclude=True)
    # Turns the question plus the graph schema into an nGQL statement.
    ngql_generation_chain: LLMChain
    # Turns the question plus the query result into the final answer.
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    allow_dangerous_requests: bool = False
    """Forced user opt-in to acknowledge that the chain can make dangerous requests.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the chain and enforce the explicit safety opt-in.

        Raises:
            ValueError: If ``allow_dangerous_requests`` is not exactly ``True``.
        """
        super().__init__(**kwargs)
        if self.allow_dangerous_requests is not True:
            # Adjacent literals are concatenated, so every fragment that ends
            # a sentence needs its own trailing space.
            raise ValueError(
                "In order to use this chain, you must acknowledge that it can make "
                "dangerous requests by setting `allow_dangerous_requests` to `True`. "
                "You must narrowly scope the permissions of the database connection "
                "to only include necessary permissions. Failure to do so may result "
                "in data corruption or loss or reading sensitive data if such data is "
                "present in the database. "
                "Only use this chain if you understand the risks and have taken the "
                "necessary precautions. "
                "See https://python.langchain.com/docs/security for more information."
            )

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        ngql_prompt: BasePromptTemplate = NGQL_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> NebulaGraphQAChain:
        """Initialize from LLM.

        Args:
            llm: Model used for both nGQL generation and answering.
            qa_prompt: Prompt for the answering step.
            ngql_prompt: Prompt for the nGQL generation step.
            **kwargs: Forwarded to the class constructor.
        """
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        ngql_generation_chain = LLMChain(llm=llm, prompt=ngql_prompt)

        return cls(
            qa_chain=qa_chain,
            ngql_generation_chain=ngql_generation_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate nGQL statement, use it to look up in db and answer question.

        Args:
            inputs: Chain inputs; the question is read from ``self.input_key``.
            run_manager: Optional callback manager; a no-op manager is used
                when omitted.

        Returns:
            A dict holding the answer under ``self.output_key``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        generated_ngql = self.ngql_generation_chain.run(
            {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        _run_manager.on_text("Generated nGQL:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_ngql, color="green", end="\n", verbose=self.verbose
        )
        # The generated statement is passed to the graph exactly as returned.
        context = self.graph.query(generated_ngql)

        _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(context), color="green", end="\n", verbose=self.verbose
        )

        result = self.qa_chain(
            {"question": question, "context": context},
            callbacks=callbacks,
        )
        return {self.output_key: result[self.qa_chain.output_key]}
|
||||
@@ -0,0 +1,254 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_classic.chains.prompt_selector import ConditionalPromptSelector
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
CYPHER_QA_PROMPT,
|
||||
NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
|
||||
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT,
|
||||
)
|
||||
from langchain_community.graphs import BaseNeptuneGraph
|
||||
|
||||
# Key under which a chain result stores its list of intermediate steps.
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||
|
||||
|
||||
def trim_query(query: str) -> str:
    """Keep only the lines of ``query`` that start with a Cypher keyword.

    A line is kept when, after stripping surrounding whitespace and
    upper-casing, it starts with one of the listed keywords or with ``//``.
    Kept lines are emitted unmodified, each followed by a newline.

    Args:
        query: Text that may mix Cypher statements with other lines.

    Returns:
        The kept lines joined together, or an empty string when none match.
    """
    keywords = (
        "CALL",
        "CREATE",
        "DELETE",
        "DETACH",
        "LIMIT",
        "MATCH",
        "MERGE",
        "OPTIONAL",
        "ORDER",
        "REMOVE",
        "RETURN",
        "SET",
        "SKIP",
        "UNWIND",
        "WITH",
        "WHERE",
        "//",
    )

    # ``str.startswith`` accepts the whole tuple, and a single join avoids
    # rebuilding the result string on every kept line.
    return "".join(
        line + "\n"
        for line in query.split("\n")
        if line.strip().upper().startswith(keywords)
    )
|
||||
|
||||
|
||||
def extract_cypher(text: str) -> str:
    """Return the content of the first triple-backtick block in ``text``.

    When ``text`` holds no complete triple-backtick block it is returned
    unchanged.
    """
    # Non-greedy capture between two fences; DOTALL lets it span newlines.
    fenced = re.search(r"```(.*?)```", text, re.DOTALL)
    if fenced is None:
        return text
    return fenced.group(1)
|
||||
|
||||
|
||||
def use_simple_prompt(llm: BaseLanguageModel) -> bool:
    """Decide whether to use the simple prompt.

    Args:
        llm: Model to inspect. Only its ``_llm_type`` and ``model_id``
            attributes are read, and either may be missing or ``None``.

    Returns:
        True when either attribute is a string containing ``"anthropic"``,
        False otherwise.
    """
    # Read both attributes defensively so a model that lacks ``_llm_type`` or
    # carries ``model_id=None`` yields False instead of raising.
    llm_type = getattr(llm, "_llm_type", None)
    if isinstance(llm_type, str) and "anthropic" in llm_type:
        return True

    # Bedrock anthropic
    model_id = getattr(llm, "model_id", None)
    return isinstance(model_id, str) and "anthropic" in model_id
|
||||
|
||||
|
||||
# Prompt selector for openCypher generation: pairs `use_simple_prompt` with the
# simple prompt and names the regular prompt as the default.
# NOTE(review): the exact selection rule lives in ConditionalPromptSelector,
# which is not visible here; confirm that the first matching conditional wins.
PROMPT_SELECTOR = ConditionalPromptSelector(
    default_prompt=NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
    conditionals=[(use_simple_prompt, NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT)],
)
|
||||
|
||||
|
||||
@deprecated(
    since="0.3.15",
    removal="1.0",
    alternative_import="langchain_aws.create_neptune_opencypher_qa_chain",
)
class NeptuneOpenCypherQAChain(Chain):
    """Chain for question-answering against a Neptune graph
    by generating openCypher statements.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.

    Example:
        .. code-block:: python

        chain = NeptuneOpenCypherQAChain.from_llm(
            llm=llm,
            graph=graph
        )
        response = chain.run(query)
    """

    graph: BaseNeptuneGraph = Field(exclude=True)
    cypher_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    # NOTE(review): declared here but not read by any method of this class.
    top_k: int = 10
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the graph directly."""
    extra_instructions: Optional[str] = None
    """Extra instructions to be appended to the query generation prompt."""

    allow_dangerous_requests: bool = False
    """Forced user opt-in to acknowledge that the chain can make dangerous requests.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the chain and enforce the explicit safety opt-in.

        Raises:
            ValueError: If ``allow_dangerous_requests`` is not exactly ``True``.
        """
        super().__init__(**kwargs)
        if self.allow_dangerous_requests is not True:
            # Adjacent literals are concatenated, so every fragment that ends
            # a sentence needs its own trailing space.
            raise ValueError(
                "In order to use this chain, you must acknowledge that it can make "
                "dangerous requests by setting `allow_dangerous_requests` to `True`. "
                "You must narrowly scope the permissions of the database connection "
                "to only include necessary permissions. Failure to do so may result "
                "in data corruption or loss or reading sensitive data if such data is "
                "present in the database. "
                "Only use this chain if you understand the risks and have taken the "
                "necessary precautions. "
                "See https://python.langchain.com/docs/security for more information."
            )

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        cypher_prompt: Optional[BasePromptTemplate] = None,
        extra_instructions: Optional[str] = None,
        **kwargs: Any,
    ) -> NeptuneOpenCypherQAChain:
        """Initialize from LLM.

        Args:
            llm: Model used for both Cypher generation and answering.
            qa_prompt: Prompt for the answering step.
            cypher_prompt: Prompt for Cypher generation; when omitted (or
                falsy) ``PROMPT_SELECTOR`` picks one for ``llm``.
            extra_instructions: Stored on the chain and passed to the
                generation prompt on every call.
            **kwargs: Forwarded to the class constructor.
        """
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)

        _cypher_prompt = cypher_prompt or PROMPT_SELECTOR.get_prompt(llm)
        cypher_generation_chain = LLMChain(llm=llm, prompt=_cypher_prompt)

        return cls(
            qa_chain=qa_chain,
            cypher_generation_chain=cypher_generation_chain,
            extra_instructions=extra_instructions,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate Cypher statement, use it to look up in db and answer question.

        Args:
            inputs: Chain inputs; the question is read from ``self.input_key``.
            run_manager: Optional callback manager; a no-op manager is used
                when omitted.

        Returns:
            A dict holding the answer (or, with ``return_direct``, the raw
            query result) under ``self.output_key`` and, when
            ``return_intermediate_steps`` is set, the collected steps under
            ``INTERMEDIATE_STEPS_KEY``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List = []

        generated_cypher = self.cypher_generation_chain.run(
            {
                "question": question,
                "schema": self.graph.get_schema,
                "extra_instructions": self.extra_instructions or "",
            },
            callbacks=callbacks,
        )

        # Extract Cypher code if it is wrapped in backticks
        generated_cypher = extract_cypher(generated_cypher)
        # Then drop every line that does not start with a Cypher keyword.
        generated_cypher = trim_query(generated_cypher)

        _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_cypher, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_cypher})

        context = self.graph.query(generated_cypher)

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result
|
||||
@@ -0,0 +1,242 @@
|
||||
"""
|
||||
Question answering over an RDF or OWL graph using SPARQL.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import SPARQL_QA_PROMPT
|
||||
from langchain_community.graphs import NeptuneRdfGraph
|
||||
|
||||
# Key under which a chain result stores its list of intermediate steps.
INTERMEDIATE_STEPS_KEY = "intermediate_steps"

# Prompt text for SPARQL generation. It has two placeholders, {schema} and
# {prompt}; doubled braces are literal braces. The "Examples:" line is the
# anchor that `NeptuneSparqlQAChain.from_llm` replaces when examples are given.
SPARQL_GENERATION_TEMPLATE = """
Task: Generate a SPARQL SELECT statement for querying a graph database.
For instance, to find all email addresses of John Doe, the following
query in backticks would be suitable:
```
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
SELECT ?email
WHERE {{
?person foaf:name "John Doe" .
?person foaf:mbox ?email .
}}
```
Instructions:
Use only the node types and properties provided in the schema.
Do not use any node types and properties that are not explicitly provided.
Include all necessary prefixes.

Examples:

Schema:
{schema}
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than
for you to construct a SPARQL query.
Do not include any text except the SPARQL query generated.

The question is:
{prompt}"""

# Default generation prompt built from the template above.
SPARQL_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "prompt"], template=SPARQL_GENERATION_TEMPLATE
)
|
||||
|
||||
|
||||
def extract_sparql(query: str) -> str:
    """Pull the SPARQL statement out of a model response.

    Args:
        query: Raw text that may wrap the statement in one triple-backtick
            block or in ``<sparql>`` tags.

    Returns:
        The content of the fenced block (minus a leading ``sparql`` tag) when
        the stripped text contains exactly two fences; the text between the
        tags when the stripped text starts and ends with them; otherwise the
        stripped text.
    """
    text = query.strip()
    pieces = text.split("```")

    # Anything other than exactly one fenced block: only the tag form applies.
    if len(pieces) != 3:
        open_tag, close_tag = "<sparql>", "</sparql>"
        if text.startswith(open_tag) and text.endswith(close_tag):
            return text[len(open_tag) : -len(close_tag)]
        return text

    fenced = pieces[1]
    lang_tag = "sparql"
    return fenced[len(lang_tag) :] if fenced.startswith(lang_tag) else fenced
|
||||
|
||||
|
||||
@deprecated(
    since="0.3.15",
    removal="1.0",
    alternative_import="langchain_aws.create_neptune_sparql_qa_chain",
)
class NeptuneSparqlQAChain(Chain):
    """Chain for question-answering against a Neptune graph
    by generating SPARQL statements.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.

    Example:
        .. code-block:: python

        chain = NeptuneSparqlQAChain.from_llm(
            llm=llm,
            graph=graph
        )
        response = chain.invoke(query)
    """

    graph: NeptuneRdfGraph = Field(exclude=True)
    sparql_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    # NOTE(review): declared here but not read by any method of this class.
    top_k: int = 10
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the graph directly."""
    # NOTE(review): not read by any method of this class.
    extra_instructions: Optional[str] = None
    """Extra instructions to be appended to the query generation prompt."""

    allow_dangerous_requests: bool = False
    """Forced user opt-in to acknowledge that the chain can make dangerous requests.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the chain and enforce the explicit safety opt-in.

        Raises:
            ValueError: If ``allow_dangerous_requests`` is not exactly ``True``.
        """
        super().__init__(**kwargs)
        if self.allow_dangerous_requests is not True:
            # Adjacent literals are concatenated, so every fragment that ends
            # a sentence needs its own trailing space.
            raise ValueError(
                "In order to use this chain, you must acknowledge that it can make "
                "dangerous requests by setting `allow_dangerous_requests` to `True`. "
                "You must narrowly scope the permissions of the database connection "
                "to only include necessary permissions. Failure to do so may result "
                "in data corruption or loss or reading sensitive data if such data is "
                "present in the database. "
                "Only use this chain if you understand the risks and have taken the "
                "necessary precautions. "
                "See https://python.langchain.com/docs/security for more information."
            )

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys."""
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT,
        sparql_prompt: BasePromptTemplate = SPARQL_GENERATION_PROMPT,
        examples: Optional[str] = None,
        **kwargs: Any,
    ) -> NeptuneSparqlQAChain:
        """Initialize from LLM.

        Args:
            llm: Model used for both SPARQL generation and answering.
            qa_prompt: Prompt for the answering step.
            sparql_prompt: Prompt for SPARQL generation. It is replaced by a
                prompt built from ``SPARQL_GENERATION_TEMPLATE`` whenever
                ``examples`` is non-empty.
            examples: Text inserted after every ``"Examples:"`` occurrence in
                the generation template.
            **kwargs: Forwarded to the class constructor.
        """
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        template_to_use = SPARQL_GENERATION_TEMPLATE
        # NOTE(review): indentation was lost in this view. The prompt rebuild
        # is placed inside the `if` so that a caller-supplied `sparql_prompt`
        # is honoured when no examples are given; confirm against the original.
        if examples:
            template_to_use = template_to_use.replace(
                "Examples:", "Examples: " + examples
            )
            sparql_prompt = PromptTemplate(
                input_variables=["schema", "prompt"], template=template_to_use
            )
        sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_prompt)

        # NOTE(review): no `examples` field is declared on this class; confirm
        # how the base model treats the unknown keyword below.
        return cls(
            qa_chain=qa_chain,
            sparql_generation_chain=sparql_generation_chain,
            examples=examples,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """
        Generate SPARQL query, use it to retrieve a response from the gdb and answer
        the question.

        Args:
            inputs: Chain inputs; the question is read from ``self.input_key``.
            run_manager: Optional callback manager; a no-op manager is used
                when omitted.

        Returns:
            A dict holding the answer (or, with ``return_direct``, the raw
            query result) under ``self.output_key`` and, when
            ``return_intermediate_steps`` is set, the collected steps under
            ``INTERMEDIATE_STEPS_KEY``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        prompt = inputs[self.input_key]

        intermediate_steps: List = []

        generated_sparql = self.sparql_generation_chain.run(
            {"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        # Extract SPARQL
        generated_sparql = extract_sparql(generated_sparql)

        _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_sparql, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_sparql})

        context = self.graph.query(generated_sparql)

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain(
                {"prompt": prompt, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result
|
||||
@@ -0,0 +1,222 @@
|
||||
"""Question answering over a graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import rdflib
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
GRAPHDB_QA_PROMPT,
|
||||
GRAPHDB_SPARQL_FIX_PROMPT,
|
||||
GRAPHDB_SPARQL_GENERATION_PROMPT,
|
||||
)
|
||||
from langchain_community.graphs import OntotextGraphDBGraph
|
||||
|
||||
|
||||
class OntotextGraphDBQAChain(Chain):
    """Question-answering against Ontotext GraphDB by generating SPARQL queries.

    See https://graphdb.ontotext.com/ for the database itself.

    The chain asks an LLM for a SPARQL query, validates it with ``rdflib``
    (asking the LLM to repair it up to ``max_fix_retries`` times when it does
    not parse), runs it against ``graph`` and asks the LLM to phrase an answer
    from the query results.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    # Graph used to obtain the ontology schema and to run the generated query.
    graph: OntotextGraphDBGraph = Field(exclude=True)
    # LLM chain that writes the initial SPARQL query.
    sparql_generation_chain: LLMChain
    # LLM chain asked to repair a query that failed validation.
    sparql_fix_chain: LLMChain
    # Maximum number of repair attempts before the chain gives up.
    max_fix_retries: int
    # LLM chain that turns the query results into the final answer.
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    allow_dangerous_requests: bool = False
    """Forced user opt-in to acknowledge that the chain can make dangerous requests.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the chain.

        Raises:
            ValueError: If ``allow_dangerous_requests`` was not set to ``True``.
        """
        super().__init__(**kwargs)
        if self.allow_dangerous_requests is not True:
            # Each fragment ends with a space so the sentences do not run
            # together once the adjacent literals are concatenated.
            raise ValueError(
                "In order to use this chain, you must acknowledge that it can make "
                "dangerous requests by setting `allow_dangerous_requests` to `True`. "
                "You must narrowly scope the permissions of the database connection "
                "to only include necessary permissions. Failure to do so may result "
                "in data corruption or loss or reading sensitive data if such data is "
                "present in the database. "
                "Only use this chain if you understand the risks and have taken the "
                "necessary precautions. "
                "See https://python.langchain.com/docs/security for more information."
            )

    @property
    def input_keys(self) -> List[str]:
        """Return the single input key the chain reads."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the single output key the chain writes."""
        return [self.output_key]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        sparql_generation_prompt: BasePromptTemplate = GRAPHDB_SPARQL_GENERATION_PROMPT,
        sparql_fix_prompt: BasePromptTemplate = GRAPHDB_SPARQL_FIX_PROMPT,
        max_fix_retries: int = 5,
        qa_prompt: BasePromptTemplate = GRAPHDB_QA_PROMPT,
        **kwargs: Any,
    ) -> OntotextGraphDBQAChain:
        """Build the chain from one LLM, wrapping each prompt in an ``LLMChain``.

        Args:
            llm: Language model shared by the generation, fix and QA chains.
            sparql_generation_prompt: Prompt used to generate the SPARQL query.
            sparql_fix_prompt: Prompt used to repair an invalid SPARQL query.
            max_fix_retries: Maximum number of repair attempts.
            qa_prompt: Prompt used to phrase the final answer.
            **kwargs: Forwarded to the constructor (e.g. ``graph``,
                ``allow_dangerous_requests``).

        Returns:
            The configured chain.
        """
        return cls(
            qa_chain=LLMChain(llm=llm, prompt=qa_prompt),
            sparql_generation_chain=LLMChain(llm=llm, prompt=sparql_generation_prompt),
            sparql_fix_chain=LLMChain(llm=llm, prompt=sparql_fix_prompt),
            max_fix_retries=max_fix_retries,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """
        Generate a SPARQL query, use it to retrieve a response from GraphDB and answer
        the question.

        Args:
            inputs: Chain inputs; the question is read from ``self.input_key``.
            run_manager: Optional callback manager for this run; a no-op manager
                is used when omitted.

        Returns:
            A dict mapping ``self.output_key`` to the QA chain's answer.

        Raises:
            ValueError: If no valid query could be produced or the query
                could not be executed.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        prompt = inputs[self.input_key]
        ontology_schema = self.graph.get_schema

        sparql_generation_chain_result = self.sparql_generation_chain.invoke(
            {"prompt": prompt, "schema": ontology_schema}, callbacks=callbacks
        )
        generated_sparql = sparql_generation_chain_result[
            self.sparql_generation_chain.output_key
        ]

        # Validate (and, if needed, repair) the query before touching the database.
        generated_sparql = self._get_prepared_sparql_query(
            _run_manager, callbacks, generated_sparql, ontology_schema
        )
        query_results = self._execute_query(generated_sparql)

        qa_chain_result = self.qa_chain.invoke(
            {"prompt": prompt, "context": query_results}, callbacks=callbacks
        )
        result = qa_chain_result[self.qa_chain.output_key]
        return {self.output_key: result}

    def _get_prepared_sparql_query(
        self,
        _run_manager: CallbackManagerForChainRun,
        callbacks: CallbackManager,
        generated_sparql: str,
        ontology_schema: str,
    ) -> str:
        """Return a query that parses, repairing it with the fix chain if needed.

        Args:
            _run_manager: Callback manager used for verbose logging.
            callbacks: Child callbacks passed to the fix chain.
            generated_sparql: Query produced by the generation chain.
            ontology_schema: Schema handed to the fix chain.

        Returns:
            The first query (original or repaired) that validates.

        Raises:
            ValueError: If the query is still invalid after
                ``max_fix_retries`` repair attempts.
        """
        try:
            return self._prepare_sparql_query(_run_manager, generated_sparql)
        except Exception as e:
            last_error: Exception = e
            error_message = str(e)
            self._log_invalid_sparql_query(
                _run_manager, generated_sparql, error_message
            )

        for _ in range(self.max_fix_retries):
            try:
                sparql_fix_chain_result = self.sparql_fix_chain.invoke(
                    {
                        "error_message": error_message,
                        "generated_sparql": generated_sparql,
                        "schema": ontology_schema,
                    },
                    callbacks=callbacks,
                )
                candidate_sparql = sparql_fix_chain_result[
                    self.sparql_fix_chain.output_key
                ]
            except Exception as e:
                # The fix chain itself failed: the query/error pair is unchanged,
                # so log it and spend one retry.
                last_error = e
                self._log_invalid_sparql_query(_run_manager, generated_sparql, str(e))
                continue

            try:
                return self._prepare_sparql_query(_run_manager, candidate_sparql)
            except Exception as e:
                # Keep the query and its error in sync so the next repair
                # attempt sees the error that belongs to the latest query.
                last_error = e
                generated_sparql = candidate_sparql
                error_message = str(e)
                self._log_invalid_sparql_query(
                    _run_manager, generated_sparql, error_message
                )

        raise ValueError("The generated SPARQL query is invalid.") from last_error

    def _prepare_sparql_query(
        self, _run_manager: CallbackManagerForChainRun, generated_sparql: str
    ) -> str:
        """Validate the query with ``rdflib`` and log it.

        Returns the query unchanged; any exception raised by ``prepareQuery``
        propagates to the caller.
        """
        # Imported lazily so rdflib is only required when the chain is used.
        from rdflib.plugins.sparql import prepareQuery

        prepareQuery(generated_sparql)
        self._log_prepared_sparql_query(_run_manager, generated_sparql)
        return generated_sparql

    def _log_prepared_sparql_query(
        self, _run_manager: CallbackManagerForChainRun, generated_query: str
    ) -> None:
        """Emit the validated query through the run manager (in green)."""
        _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_query, color="green", end="\n", verbose=self.verbose
        )

    def _log_invalid_sparql_query(
        self,
        _run_manager: CallbackManagerForChainRun,
        generated_query: str,
        error_message: str,
    ) -> None:
        """Emit a rejected query and its error through the run manager (in red)."""
        _run_manager.on_text("Invalid SPARQL query: ", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_query, color="red", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            "SPARQL Query Parse Error: ", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            error_message, color="red", end="\n\n", verbose=self.verbose
        )

    def _execute_query(self, query: str) -> List[rdflib.query.ResultRow]:
        """Run the query against the graph.

        Raises:
            ValueError: If the graph raises while executing the query; the
                original exception is attached as the cause.
        """
        try:
            return self.graph.query(query)
        except Exception as e:
            raise ValueError("Failed to execute the generated SPARQL query.") from e
|
||||
@@ -0,0 +1,468 @@
|
||||
# flake8: noqa
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
|
||||
# Prompt that asks the model to extract entities from `{input}` and return
# them as one comma-separated list, or NONE when there is nothing to return.
_DEFAULT_ENTITY_EXTRACTION_TEMPLATE = """Extract all entities from the following text. As a guideline, a proper noun is generally capitalized. You should definitely extract all names and places.

Return the output as a single comma-separated list, or NONE if there is nothing of note to return.

EXAMPLE
i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff.
Output: Langchain
END OF EXAMPLE

EXAMPLE
i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff. I'm working with Sam.
Output: Langchain, Sam
END OF EXAMPLE

Begin!

{input}
Output:"""
ENTITY_EXTRACTION_PROMPT = PromptTemplate(
    input_variables=["input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE
)
|
||||
|
||||
# Prompt that answers `{question}` from the knowledge triplets supplied in
# `{context}`, telling the model to admit when it does not know.
_DEFAULT_GRAPH_QA_TEMPLATE = """Use the following knowledge triplets to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

{context}

Question: {question}
Helpful Answer:"""
GRAPH_QA_PROMPT = PromptTemplate(
    template=_DEFAULT_GRAPH_QA_TEMPLATE, input_variables=["context", "question"]
)
|
||||
|
||||
# Base Cypher-generation prompt: takes `{schema}` and `{question}`.
# Other templates in this module are derived from it via `str.replace`, so the
# literal substrings "Generate Cypher", "Cypher" and "Instructions:" are
# load-bearing.
CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database.
Instructions:
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Schema:
{schema}
Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
Do not include any text except the generated Cypher statement.

The question is:
{question}"""
CYPHER_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE
)
|
||||
|
||||
# Extra rules for the NebulaGraph Cypher dialect; substituted for the
# "Instructions:" heading of the base Cypher template below.
NEBULAGRAPH_EXTRA_INSTRUCTIONS = """
Instructions:

First, generate cypher then convert it to NebulaGraph Cypher dialect(rather than standard):
1. it requires explicit label specification only when referring to node properties: v.`Foo`.name
2. note explicit label specification is not needed for edge properties, so it's e.name instead of e.`Bar`.name
3. it uses double equals sign for comparison: `==` rather than `=`
For instance:
```diff
< MATCH (p:person)-[e:directed]->(m:movie) WHERE m.name = 'The Godfather II'
< RETURN p.name, e.year, m.name;
---
> MATCH (p:`person`)-[e:directed]->(m:`movie`) WHERE m.`movie`.`name` == 'The Godfather II'
> RETURN p.`person`.`name`, e.year, m.`movie`.`name`;
```\n"""

# Derived from the base Cypher template by renaming the task and injecting
# the NebulaGraph-specific instructions.
NGQL_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
    "Generate Cypher", "Generate NebulaGraph Cypher"
).replace("Instructions:", NEBULAGRAPH_EXTRA_INSTRUCTIONS)

NGQL_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"], template=NGQL_GENERATION_TEMPLATE
)
|
||||
|
||||
# Extra rules for the Kùzu dialect of Cypher; substituted for the
# "Instructions:" heading of the base Cypher template below.
KUZU_EXTRA_INSTRUCTIONS = """
Instructions:
Generate the Kùzu dialect of Cypher with the following rules in mind:
1. Do not omit the relationship pattern. Always use `()-[]->()` instead of `()->()`.
2. Do not include triple backticks ``` in your response. Return only Cypher.
3. Do not return any notes or comments in your response.
\n"""

# Derived from the base Cypher template by renaming the task and injecting
# the Kùzu-specific instructions.
KUZU_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
    "Generate Cypher", "Generate Kùzu Cypher"
).replace("Instructions:", KUZU_EXTRA_INSTRUCTIONS)

KUZU_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"], template=KUZU_GENERATION_TEMPLATE
)
|
||||
|
||||
# Gremlin variant of the base template: every occurrence of "Cypher" is
# replaced with "Gremlin"; the placeholders stay `{schema}` and `{question}`.
GREMLIN_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace("Cypher", "Gremlin")

GREMLIN_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"], template=GREMLIN_GENERATION_TEMPLATE
)
|
||||
|
||||
# Prompt that phrases an answer to `{question}` using only the query results
# in `{context}`; includes one worked example for the model to follow.
CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers.
The information part contains the provided information that you must use to construct an answer.
The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
Here is an example:

Question: Which managers own Neo4j stocks?
Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.

Follow this example when generating answers.
If the provided information is empty, say that you don't know the answer.
Information:
{context}

Question: {question}
Helpful Answer:"""
CYPHER_QA_PROMPT = PromptTemplate(
    input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE
)
|
||||
|
||||
# Prompt that classifies `{prompt}` as needing a SPARQL 'SELECT' or 'UPDATE'
# query; the model is told to return only that keyword.
SPARQL_INTENT_TEMPLATE = """Task: Identify the intent of a prompt and return the appropriate SPARQL query type.
You are an assistant that distinguishes different types of prompts and returns the corresponding SPARQL query types.
Consider only the following query types:
* SELECT: this query type corresponds to questions
* UPDATE: this query type corresponds to all requests for deleting, inserting, or changing triples
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than for you to identify a SPARQL query type.
Do not include any unnecessary whitespaces or any text except the query type, i.e., either return 'SELECT' or 'UPDATE'.

The prompt is:
{prompt}
Helpful Answer:"""
SPARQL_INTENT_PROMPT = PromptTemplate(
    input_variables=["prompt"], template=SPARQL_INTENT_TEMPLATE
)
|
||||
|
||||
# Prompt that generates a SPARQL SELECT query from `{schema}` and `{prompt}`.
# The doubled braces `{{` / `}}` are escapes so the example query's literal
# braces survive template formatting.
SPARQL_GENERATION_SELECT_TEMPLATE = """Task: Generate a SPARQL SELECT statement for querying a graph database.
For instance, to find all email addresses of John Doe, the following query in backticks would be suitable:
```
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
SELECT ?email
WHERE {{
?person foaf:name "John Doe" .
?person foaf:mbox ?email .
}}
```
Instructions:
Use only the node types and properties provided in the schema.
Do not use any node types and properties that are not explicitly provided.
Include all necessary prefixes.
Schema:
{schema}
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than for you to construct a SPARQL query.
Do not include any text except the SPARQL query generated.

The question is:
{prompt}"""
SPARQL_GENERATION_SELECT_PROMPT = PromptTemplate(
    input_variables=["schema", "prompt"], template=SPARQL_GENERATION_SELECT_TEMPLATE
)
|
||||
|
||||
# Prompt that generates a SPARQL UPDATE statement from `{schema}` and
# `{prompt}`. The doubled braces `{{` / `}}` are escapes for literal braces
# in the example query.
SPARQL_GENERATION_UPDATE_TEMPLATE = """Task: Generate a SPARQL UPDATE statement for updating a graph database.
For instance, to add 'jane.doe@foo.bar' as a new email address for Jane Doe, the following query in backticks would be suitable:
```
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
INSERT {{
?person foaf:mbox <mailto:jane.doe@foo.bar> .
}}
WHERE {{
?person foaf:name "Jane Doe" .
}}
```
Instructions:
Make the query as short as possible and avoid adding unnecessary triples.
Use only the node types and properties provided in the schema.
Do not use any node types and properties that are not explicitly provided.
Include all necessary prefixes.
Schema:
{schema}
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than for you to construct a SPARQL query.
Return only the generated SPARQL query, nothing else.

The information to be inserted is:
{prompt}"""
SPARQL_GENERATION_UPDATE_PROMPT = PromptTemplate(
    input_variables=["schema", "prompt"], template=SPARQL_GENERATION_UPDATE_TEMPLATE
)
|
||||
|
||||
# Prompt that turns SPARQL query results (`{context}`) into a natural-language
# answer to the original `{prompt}`.
SPARQL_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query.
You are an assistant that creates well-written and human understandable answers.
The information part contains the information provided, which you can use to construct an answer.
The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make your response sound like the information is coming from an AI assistant, but don't add any information.
Information:
{context}

Question: {prompt}
Helpful Answer:"""
SPARQL_QA_PROMPT = PromptTemplate(
    input_variables=["context", "prompt"], template=SPARQL_QA_TEMPLATE
)
|
||||
|
||||
# GraphDB prompt that asks for a SPARQL SELECT query given the ontology
# `{schema}` (Turtle) and the user's `{prompt}`; the model is told not to wrap
# the query in backticks.
GRAPHDB_SPARQL_GENERATION_TEMPLATE = """
Write a SPARQL SELECT query for querying a graph database.
The ontology schema delimited by triple backticks in Turtle format is:
```
{schema}
```
Use only the classes and properties provided in the schema to construct the SPARQL query.
Do not use any classes or properties that are not explicitly provided in the SPARQL query.
Include all necessary prefixes.
Do not include any explanations or apologies in your responses.
Do not wrap the query in backticks.
Do not include any text except the SPARQL query generated.
The question delimited by triple backticks is:
```
{prompt}
```
"""
GRAPHDB_SPARQL_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "prompt"],
    template=GRAPHDB_SPARQL_GENERATION_TEMPLATE,
)
|
||||
|
||||
# GraphDB prompt that asks the model to repair an invalid query: it receives
# the failing `{generated_sparql}`, the `{error_message}` and the ontology
# `{schema}`.
GRAPHDB_SPARQL_FIX_TEMPLATE = """
This following SPARQL query delimited by triple backticks
```
{generated_sparql}
```
is not valid.
The error delimited by triple backticks is
```
{error_message}
```
Give me a correct version of the SPARQL query.
Do not change the logic of the query.
Do not include any explanations or apologies in your responses.
Do not wrap the query in backticks.
Do not include any text except the SPARQL query generated.
The ontology schema delimited by triple backticks in Turtle format is:
```
{schema}
```
"""

GRAPHDB_SPARQL_FIX_PROMPT = PromptTemplate(
    input_variables=["error_message", "generated_sparql", "schema"],
    template=GRAPHDB_SPARQL_FIX_TEMPLATE,
)
|
||||
|
||||
# GraphDB prompt that phrases an answer to `{prompt}` from the query results in
# `{context}`; unlike SPARQL_QA_TEMPLATE it also tells the model to say it does
# not know when no information is available.
GRAPHDB_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query.
You are an assistant that creates well-written and human understandable answers.
The information part contains the information provided, which you can use to construct an answer.
The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make your response sound like the information is coming from an AI assistant, but don't add any information.
Don't use internal knowledge to answer the question, just say you don't know if no information is available.
Information:
{context}

Question: {prompt}
Helpful Answer:"""
GRAPHDB_QA_PROMPT = PromptTemplate(
    input_variables=["context", "prompt"], template=GRAPHDB_QA_TEMPLATE
)
|
||||
|
||||
# Prompt that translates `{user_input}` into an AQL query using `{adb_schema}`
# and optional `{aql_examples}`; the model is told to wrap the query in three
# backticks and never to generate a query that deletes data.
AQL_GENERATION_TEMPLATE = """Task: Generate an ArangoDB Query Language (AQL) query from a User Input.

You are an ArangoDB Query Language (AQL) expert responsible for translating a `User Input` into an ArangoDB Query Language (AQL) query.

You are given an `ArangoDB Schema`. It is a JSON Object containing:
1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships.
2. `Collection Schema`: Lists all Collections within the ArangoDB Database Instance, along with their document/edge properties and a document/edge example.

You may also be given a set of `AQL Query Examples` to help you create the `AQL Query`. If provided, the `AQL Query Examples` should be used as a reference, similar to how `ArangoDB Schema` should be used.

Things you should do:
- Think step by step.
- Rely on `ArangoDB Schema` and `AQL Query Examples` (if provided) to generate the query.
- Begin the `AQL Query` by the `WITH` AQL keyword to specify all of the ArangoDB Collections required.
- Return the `AQL Query` wrapped in 3 backticks (```).
- Use only the provided relationship types and properties in the `ArangoDB Schema` and any `AQL Query Examples` queries.
- Only answer to requests related to generating an AQL Query.
- If a request is unrelated to generating AQL Query, say that you cannot help the user.

Things you should not do:
- Do not use any properties/relationships that can't be inferred from the `ArangoDB Schema` or the `AQL Query Examples`.
- Do not include any text except the generated AQL Query.
- Do not provide explanations or apologies in your responses.
- Do not generate an AQL Query that removes or deletes any data.

Under no circumstance should you generate an AQL Query that deletes any data whatsoever.

ArangoDB Schema:
{adb_schema}

AQL Query Examples (Optional):
{aql_examples}

User Input:
{user_input}

AQL Query:
"""

AQL_GENERATION_PROMPT = PromptTemplate(
    input_variables=["adb_schema", "aql_examples", "user_input"],
    template=AQL_GENERATION_TEMPLATE,
)
|
||||
|
||||
# Prompt that asks the model to correct `{aql_query}` given the database error
# `{aql_error}` and `{adb_schema}`; the corrected query is to be returned
# wrapped in three backticks.
AQL_FIX_TEMPLATE = """Task: Address the ArangoDB Query Language (AQL) error message of an ArangoDB Query Language query.

You are an ArangoDB Query Language (AQL) expert responsible for correcting the provided `AQL Query` based on the provided `AQL Error`.

The `AQL Error` explains why the `AQL Query` could not be executed in the database.
The `AQL Error` may also contain the position of the error relative to the total number of lines of the `AQL Query`.
For example, 'error X at position 2:5' denotes that the error X occurs on line 2, column 5 of the `AQL Query`.

You are also given the `ArangoDB Schema`. It is a JSON Object containing:
1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships.
2. `Collection Schema`: Lists all Collections within the ArangoDB Database Instance, along with their document/edge properties and a document/edge example.

You will output the `Corrected AQL Query` wrapped in 3 backticks (```). Do not include any text except the Corrected AQL Query.

Remember to think step by step.

ArangoDB Schema:
{adb_schema}

AQL Query:
{aql_query}

AQL Error:
{aql_error}

Corrected AQL Query:
"""

AQL_FIX_PROMPT = PromptTemplate(
    input_variables=[
        "adb_schema",
        "aql_query",
        "aql_error",
    ],
    template=AQL_FIX_TEMPLATE,
)
|
||||
|
||||
# Prompt that summarizes `{aql_result}` as a response to `{user_input}`, with
# `{adb_schema}` and the executed `{aql_query}` supplied as context.
AQL_QA_TEMPLATE = """Task: Generate a natural language `Summary` from the results of an ArangoDB Query Language query.

You are an ArangoDB Query Language (AQL) expert responsible for creating a well-written `Summary` from the `User Input` and associated `AQL Result`.

A user has executed an ArangoDB Query Language query, which has returned the AQL Result in JSON format.
You are responsible for creating an `Summary` based on the AQL Result.

You are given the following information:
- `ArangoDB Schema`: contains a schema representation of the user's ArangoDB Database.
- `User Input`: the original question/request of the user, which has been translated into an AQL Query.
- `AQL Query`: the AQL equivalent of the `User Input`, translated by another AI Model. Should you deem it to be incorrect, suggest a different AQL Query.
- `AQL Result`: the JSON output returned by executing the `AQL Query` within the ArangoDB Database.

Remember to think step by step.

Your `Summary` should sound like it is a response to the `User Input`.
Your `Summary` should not include any mention of the `AQL Query` or the `AQL Result`.

ArangoDB Schema:
{adb_schema}

User Input:
{user_input}

AQL Query:
{aql_query}

AQL Result:
{aql_result}
"""
AQL_QA_PROMPT = PromptTemplate(
    input_variables=["adb_schema", "user_input", "aql_query", "aql_result"],
    template=AQL_QA_TEMPLATE,
)
|
||||
|
||||
|
||||
# openCypher-specific rules substituted for the "Instructions:" heading of the
# base Cypher template; adds an `{extra_instructions}` placeholder so callers
# can append their own rules.
NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS = """
Instructions:
Generate the query in openCypher format and follow these rules:
Do not use `NONE`, `ALL` or `ANY` predicate functions, rather use list comprehensions.
Do not use `REDUCE` function. Rather use a combination of list comprehension and the `UNWIND` clause to achieve similar results.
Do not use `FOREACH` clause. Rather use a combination of `WITH` and `UNWIND` clauses to achieve similar results.{extra_instructions}
\n"""

NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
    "Instructions:", NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS
)

# The derived template needs `extra_instructions` in addition to the base
# template's `schema` and `question`.
NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question", "extra_instructions"],
    template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE,
)
|
||||
|
||||
# Shorter openCypher prompt: question first, then the property graph schema,
# with the same `{extra_instructions}` placeholder as the full Neptune prompt.
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """
Write an openCypher query to answer the following question. Do not explain the answer. Only return the query.{extra_instructions}
Question: "{question}".
Here is the property graph schema:
{schema}
\n"""

NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate(
    input_variables=["schema", "question", "extra_instructions"],
    template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE,
)
|
||||
|
||||
# Prompt that translates `{question}` into a Cypher query for Memgraph using
# `{schema}`; tells the model to prefer Memgraph MAGE procedures over Neo4j
# APOC and to answer out-of-scope requests with a `RETURN "..."` query.
MEMGRAPH_GENERATION_TEMPLATE = """Your task is to directly translate natural language inquiry into precise and executable Cypher query for Memgraph database.
You will utilize a provided database schema to understand the structure, nodes and relationships within the Memgraph database.
Instructions:
- Use provided node and relationship labels and property names from the
schema which describes the database's structure. Upon receiving a user
question, synthesize the schema to craft a precise Cypher query that
directly corresponds to the user's intent.
- Generate valid executable Cypher queries on top of Memgraph database.
Any explanation, context, or additional information that is not a part
of the Cypher query syntax should be omitted entirely.
- Use Memgraph MAGE procedures instead of Neo4j APOC procedures.
- Do not include any explanations or apologies in your responses.
- Do not include any text except the generated Cypher statement.
- For queries that ask for information or functionalities outside the direct
generation of Cypher queries, use the Cypher query format to communicate
limitations or capabilities. For example: RETURN "I am designed to generate
Cypher queries based on the provided schema only."
Schema:
{schema}

With all the above information and instructions, generate Cypher query for the
user question.

The question is:
{question}"""

MEMGRAPH_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"], template=MEMGRAPH_GENERATION_TEMPLATE
)
|
||||
|
||||
|
||||
# Prompt that phrases an answer to `{question}` from the Memgraph query results
# in `{context}`; includes one worked example for the model to follow.
MEMGRAPH_QA_TEMPLATE = """Your task is to form nice and human
understandable answers. The information part contains the provided
information that you must use to construct an answer.
The provided information is authoritative, you must never doubt it or try to
use your internal knowledge to correct it. Make the answer sound as a
response to the question. Do not mention that you based the result on the
given information. Here is an example:

Question: Which managers own Neo4j stocks?
Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.

Follow this example when generating answers. If the provided information is
empty, say that you don't know the answer.

Information:
{context}

Question: {question}
Helpful Answer:"""
MEMGRAPH_QA_PROMPT = PromptTemplate(
    input_variables=["context", "question"], template=MEMGRAPH_QA_TEMPLATE
)
|
||||
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
Question answering over an RDF or OWL graph using SPARQL.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.chains.graph_qa.prompts import (
|
||||
SPARQL_GENERATION_SELECT_PROMPT,
|
||||
SPARQL_GENERATION_UPDATE_PROMPT,
|
||||
SPARQL_INTENT_PROMPT,
|
||||
SPARQL_QA_PROMPT,
|
||||
)
|
||||
from langchain_community.graphs.rdf_graph import RdfGraph
|
||||
|
||||
|
||||
class GraphSparqlQAChain(Chain):
|
||||
"""Question-answering against an RDF or OWL graph by generating SPARQL statements.
|
||||
|
||||
*Security note*: Make sure that the database connection uses credentials
|
||||
that are narrowly-scoped to only include necessary permissions.
|
||||
Failure to do so may result in data corruption or loss, since the calling
|
||||
code may attempt commands that would result in deletion, mutation
|
||||
of data if appropriately prompted or reading sensitive data if such
|
||||
data is present in the database.
|
||||
The best way to guard against such negative outcomes is to (as appropriate)
|
||||
limit the permissions granted to the credentials used with this tool.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
graph: RdfGraph = Field(exclude=True)
|
||||
sparql_generation_select_chain: LLMChain
|
||||
sparql_generation_update_chain: LLMChain
|
||||
sparql_intent_chain: LLMChain
|
||||
qa_chain: LLMChain
|
||||
return_sparql_query: bool = False
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
sparql_query_key: str = "sparql_query" #: :meta private:
|
||||
|
||||
allow_dangerous_requests: bool = False
|
||||
"""Forced user opt-in to acknowledge that the chain can make dangerous requests.
|
||||
|
||||
*Security note*: Make sure that the database connection uses credentials
|
||||
that are narrowly-scoped to only include necessary permissions.
|
||||
Failure to do so may result in data corruption or loss, since the calling
|
||||
code may attempt commands that would result in deletion, mutation
|
||||
of data if appropriately prompted or reading sensitive data if such
|
||||
data is present in the database.
|
||||
The best way to guard against such negative outcomes is to (as appropriate)
|
||||
limit the permissions granted to the credentials used with this tool.
|
||||
|
||||
See https://python.langchain.com/docs/security for more information.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
    """Initialize the chain and enforce the dangerous-request opt-in.

    Args:
        **kwargs: Field values forwarded unchanged to the parent initializer.

    Raises:
        ValueError: If ``allow_dangerous_requests`` is not ``True`` once the
            parent initializer has run.
    """
    super().__init__(**kwargs)
    # Identity comparison: anything other than ``True`` is rejected.
    if self.allow_dangerous_requests is not True:
        # Adjacent literals are concatenated verbatim, so every sentence
        # carries its own trailing space to keep the message readable.
        raise ValueError(
            "In order to use this chain, you must acknowledge that it can make "
            "dangerous requests by setting `allow_dangerous_requests` to `True`. "
            "You must narrowly scope the permissions of the database connection "
            "to only include necessary permissions. Failure to do so may result "
            "in data corruption or loss or reading sensitive data if such data is "
            "present in the database. "
            "Only use this chain if you understand the risks and have taken the "
            "necessary precautions. "
            "See https://python.langchain.com/docs/security for more information."
        )
|
||||
|
||||
@property
def input_keys(self) -> List[str]:
    """List the single input name this chain reads, taken from ``input_key``.

    :meta private:
    """
    expected_inputs = [self.input_key]
    return expected_inputs
|
||||
|
||||
@property
def output_keys(self) -> List[str]:
    """List the single output name this chain declares, taken from ``output_key``.

    :meta private:
    """
    return [self.output_key]
|
||||
|
||||
@classmethod
def from_llm(
    cls,
    llm: BaseLanguageModel,
    *,
    qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT,
    sparql_select_prompt: BasePromptTemplate = SPARQL_GENERATION_SELECT_PROMPT,
    sparql_update_prompt: BasePromptTemplate = SPARQL_GENERATION_UPDATE_PROMPT,
    sparql_intent_prompt: BasePromptTemplate = SPARQL_INTENT_PROMPT,
    **kwargs: Any,
) -> GraphSparqlQAChain:
    """Build the chain from one LLM shared by all four sub-chains.

    Args:
        llm: Language model wrapped in an ``LLMChain`` for each prompt.
        qa_prompt: Prompt for the answering chain.
        sparql_select_prompt: Prompt for the SELECT-generation chain.
        sparql_update_prompt: Prompt for the UPDATE-generation chain.
        sparql_intent_prompt: Prompt for the intent-classification chain.
        **kwargs: Remaining fields passed straight to the constructor.

    Returns:
        A new instance of ``cls``.
    """
    # Sub-chains are created in the same order as before: qa, select,
    # update, intent (keyword arguments evaluate left to right).
    return cls(
        qa_chain=LLMChain(llm=llm, prompt=qa_prompt),
        sparql_generation_select_chain=LLMChain(
            llm=llm, prompt=sparql_select_prompt
        ),
        sparql_generation_update_chain=LLMChain(
            llm=llm, prompt=sparql_update_prompt
        ),
        sparql_intent_chain=LLMChain(llm=llm, prompt=sparql_intent_prompt),
        **kwargs,
    )
|
||||
|
||||
def _call(
    self,
    inputs: Dict[str, Any],
    run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
    """Classify the question, generate SPARQL, execute it and build the result.

    The intent chain's output must mention exactly one of ``SELECT`` or
    ``UPDATE``. A SELECT query is run against the graph and its result is
    passed to the QA chain as context; an UPDATE query is applied to the
    graph and a fixed confirmation string is returned.

    Args:
        inputs: Chain inputs; the question is read from ``self.input_key``.
        run_manager: Callback manager; a no-op manager is used when omitted.

    Returns:
        A dict with the answer under ``self.output_key`` and, when
        ``return_sparql_query`` is set, the generated query under
        ``self.sparql_query_key``.

    Raises:
        ValueError: If the classified intent mentions both keywords or neither.
    """
    manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
    child_callbacks = manager.get_child()
    question = inputs[self.input_key]

    raw_intent = self.sparql_intent_chain.run(
        {"prompt": question}, callbacks=child_callbacks
    )
    classified = raw_intent.strip()

    mentions_select = "SELECT" in classified
    mentions_update = "UPDATE" in classified
    # Both keywords or neither: the intent is ambiguous or unsupported.
    if mentions_select == mentions_update:
        raise ValueError(
            "I am sorry, but this prompt seems to fit none of the currently "
            "supported SPARQL query types, i.e., SELECT and UPDATE."
        )
    if mentions_select:
        intent = "SELECT"
        generation_chain = self.sparql_generation_select_chain
    else:
        intent = "UPDATE"
        generation_chain = self.sparql_generation_update_chain

    manager.on_text("Identified intent:", end="\n", verbose=self.verbose)
    manager.on_text(intent, color="green", end="\n", verbose=self.verbose)

    generated_sparql = generation_chain.run(
        {"prompt": question, "schema": self.graph.get_schema},
        callbacks=child_callbacks,
    )

    manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
    manager.on_text(
        generated_sparql, color="green", end="\n", verbose=self.verbose
    )

    # ``intent`` can only be "SELECT" or "UPDATE" at this point.
    if mentions_select:
        context = self.graph.query(generated_sparql)

        manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        manager.on_text(
            str(context), color="green", end="\n", verbose=self.verbose
        )
        qa_output = self.qa_chain(
            {"prompt": question, "context": context},
            callbacks=child_callbacks,
        )
        answer = qa_output[self.qa_chain.output_key]
    else:
        self.graph.update(generated_sparql)
        answer = "Successfully inserted triples into the graph."

    outputs: Dict[str, Any] = {self.output_key: answer}
    if self.return_sparql_query:
        outputs[self.sparql_query_key] = generated_sparql
    return outputs
|
||||
@@ -0,0 +1,98 @@
|
||||
"""Chain that hits a URL and then uses an LLM to parse results."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_classic.chains import LLMChain
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
from langchain_community.utilities.requests import TextRequestsWrapper
|
||||
|
||||
# Default HTTP headers: a desktop-browser User-Agent string.
DEFAULT_HEADERS = {
    "User-Agent": (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/87.0.4280.88 Safari/537.36"
    )
}
|
||||
|
||||
|
||||
class LLMRequestsChain(Chain):
    """Chain that requests a URL and then uses an LLM to parse results.

    **Security Note**: This chain can make GET requests to arbitrary URLs,
        including internal URLs.

        Control access to who can run this chain and what network access
        this chain has.

        See https://python.langchain.com/docs/security for more information.
    """

    # LLM chain that receives the fetched page text plus any extra inputs.
    llm_chain: LLMChain
    # HTTP client used for the GET request; excluded from serialization.
    requests_wrapper: TextRequestsWrapper = Field(
        default_factory=lambda: TextRequestsWrapper(headers=DEFAULT_HEADERS),
        exclude=True,
    )
    # Maximum number of characters of extracted page text passed to the LLM.
    text_length: int = 8000
    requests_key: str = "requests_result"  #: :meta private:
    input_key: str = "url"  #: :meta private:
    output_key: str = "output"  #: :meta private:

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> List[str]:
        """Return the single declared input key (``input_key``).

        Any additional inputs supplied at call time are forwarded to the
        LLM chain by ``_call``.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the single output key (``output_key``).

        :meta private:
        """
        return [self.output_key]

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Check that the ``bs4`` package is importable.

        Args:
            values: Raw field values; returned unchanged.

        Raises:
            ImportError: If ``bs4`` cannot be imported. The original import
                failure is attached as the cause.
        """
        try:
            from bs4 import BeautifulSoup  # noqa: F401
        except ImportError as exc:
            # Chain the cause so the underlying import failure stays visible.
            raise ImportError(
                "Could not import bs4 python package. "
                "Please install it with `pip install bs4`."
            ) from exc
        return values

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Fetch the URL, reduce the page to text and run the LLM chain on it.

        Args:
            inputs: Must contain the URL under ``input_key``; every other
                entry is passed to the LLM chain as a keyword argument.
            run_manager: Callback manager; a no-op manager is used when omitted.

        Returns:
            A dict with the LLM prediction under ``output_key``.
        """
        from bs4 import BeautifulSoup

        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        # Other keys are assumed to be needed for LLM prediction
        other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
        url = inputs[self.input_key]
        res = self.requests_wrapper.get(url)
        # extract the text from the html
        soup = BeautifulSoup(res, "html.parser")  # type: ignore[arg-type]
        # Truncate so the prompt handed to the LLM stays bounded.
        other_keys[self.requests_key] = soup.get_text()[: self.text_length]
        result = self.llm_chain.predict(
            callbacks=_run_manager.get_child(), **other_keys
        )
        return {self.output_key: result}

    @property
    def _chain_type(self) -> str:
        """Return the identifier string for this chain type."""
        return "llm_requests_chain"
|
||||
@@ -0,0 +1,8 @@
|
||||
"""Implement a GPT-3 driven browser.
|
||||
|
||||
Heavily influenced from https://github.com/nat/natbot
|
||||
"""
|
||||
|
||||
from langchain_community.chains.natbot.base import NatBotChain
|
||||
|
||||
__all__ = ["NatBotChain"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,3 @@
|
||||
from langchain_classic.chains import NatBotChain
|
||||
|
||||
__all__ = ["NatBotChain"]
|
||||
@@ -0,0 +1,7 @@
|
||||
from langchain_classic.chains.natbot.crawler import (
|
||||
Crawler,
|
||||
ElementInViewPort,
|
||||
black_listed_elements,
|
||||
)
|
||||
|
||||
__all__ = ["ElementInViewPort", "Crawler", "black_listed_elements"]
|
||||
@@ -0,0 +1,3 @@
|
||||
from langchain_classic.chains.natbot.prompt import PROMPT
|
||||
|
||||
__all__ = ["PROMPT"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,230 @@
|
||||
"""Chain that makes API calls and summarizes the responses to answer a question."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, cast
|
||||
|
||||
from langchain_classic.chains.api.openapi.requests_chain import APIRequesterChain
|
||||
from langchain_classic.chains.api.openapi.response_chain import APIResponderChain
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun, Callbacks
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from pydantic import BaseModel, Field
|
||||
from requests import Response
|
||||
|
||||
from langchain_community.tools.openapi.utils.api_models import APIOperation
|
||||
from langchain_community.utilities.requests import Requests
|
||||
|
||||
|
||||
class _ParamMapping(NamedTuple):
    """Parameter names of an API operation, grouped as query, body and path."""

    # Names looked up in the parsed arguments to build the request's query params.
    query_params: List[str]
    # Names looked up in the parsed arguments to build the request body.
    body_params: List[str]
    # Names whose ``{name}`` placeholders are substituted into the URL path.
    path_params: List[str]
|
||||
|
||||
|
||||
class OpenAPIEndpointChain(Chain, BaseModel):
    """Chain interacts with an OpenAPI endpoint using natural language.

    A request chain turns the instructions into serialized JSON arguments,
    the arguments are sent to the endpoint described by ``api_operation``,
    and an optional response chain turns the raw response into an answer.
    """

    api_request_chain: LLMChain
    api_response_chain: Optional[LLMChain] = None
    api_operation: APIOperation
    requests: Requests = Field(exclude=True, default_factory=Requests)
    param_mapping: _ParamMapping = Field(alias="param_mapping")
    return_intermediate_steps: bool = False
    instructions_key: str = "instructions"  #: :meta private:
    output_key: str = "output"  #: :meta private:
    # Explicit ``None`` default: without it the field is required under
    # pydantic v2 even though it is ``Optional``, and ``from_api_operation``
    # does not pass a value. ``None`` makes the ``[:max_text_length]`` slices
    # below a no-op.
    max_text_length: Optional[int] = Field(default=None, ge=0)  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.instructions_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        ``"intermediate_steps"`` is added when ``return_intermediate_steps``
        is set.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, "intermediate_steps"]

    def _construct_path(self, args: Dict[str, str]) -> str:
        """Construct the path from the deserialized input.

        Path parameters are popped from ``args``; a missing one is replaced
        by an empty string.
        """
        path = self.api_operation.base_url + self.api_operation.path
        for param in self.param_mapping.path_params:
            path = path.replace(f"{{{param}}}", str(args.pop(param, "")))
        return path

    def _extract_query_params(self, args: Dict[str, str]) -> Dict[str, str]:
        """Extract the query params from the deserialized input.

        Matching entries are popped from ``args``.
        """
        query_params = {}
        for param in self.param_mapping.query_params:
            if param in args:
                query_params[param] = args.pop(param)
        return query_params

    def _extract_body_params(self, args: Dict[str, str]) -> Optional[Dict[str, str]]:
        """Extract the request body params from the deserialized input.

        Returns ``None`` when the operation declares no body params;
        otherwise a dict of the declared params found in ``args`` (popped).
        """
        body_params = None
        if self.param_mapping.body_params:
            body_params = {}
            for param in self.param_mapping.body_params:
                if param in args:
                    body_params[param] = args.pop(param)
        return body_params

    def deserialize_json_input(self, serialized_args: str) -> dict:
        """Use the serialized typescript dictionary.

        Resolve the path, query params dict, and optional requestBody dict.

        Returns:
            A dict with keys ``"url"``, ``"data"`` and ``"params"``.
        """
        args: dict = json.loads(serialized_args)
        # Path first, then body, then query: each step pops what it consumes.
        path = self._construct_path(args)
        body_params = self._extract_body_params(args)
        query_params = self._extract_query_params(args)
        return {
            "url": path,
            "data": body_params,
            "params": query_params,
        }

    def _get_output(self, output: str, intermediate_steps: dict) -> dict:
        """Return the output from the API call.

        The intermediate steps are included only when
        ``return_intermediate_steps`` is set.
        """
        if self.return_intermediate_steps:
            return {
                self.output_key: output,
                "intermediate_steps": intermediate_steps,
            }
        else:
            return {self.output_key: output}

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate request arguments, call the endpoint and build the output.

        Args:
            inputs: Must contain the instructions under ``instructions_key``.
            run_manager: Callback manager; a no-op manager is used when omitted.

        Returns:
            The dict produced by ``_get_output``. If the request chain's
            output starts with ``ERROR`` or ``MESSAGE:``, it is returned
            without calling the endpoint.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        intermediate_steps = {}
        instructions = inputs[self.instructions_key]
        instructions = instructions[: self.max_text_length]
        _api_arguments = self.api_request_chain.predict_and_parse(
            instructions=instructions, callbacks=_run_manager.get_child()
        )
        api_arguments = cast(str, _api_arguments)
        intermediate_steps["request_args"] = api_arguments
        _run_manager.on_text(
            api_arguments, color="green", end="\n", verbose=self.verbose
        )
        if api_arguments.startswith("ERROR"):
            return self._get_output(api_arguments, intermediate_steps)
        elif api_arguments.startswith("MESSAGE:"):
            return self._get_output(
                api_arguments[len("MESSAGE:") :], intermediate_steps
            )
        try:
            request_args = self.deserialize_json_input(api_arguments)
            method = getattr(self.requests, self.api_operation.method.value)
            api_response: Response = method(**request_args)
            if api_response.status_code != 200:
                method_str = str(self.api_operation.method.value)
                response_text = (
                    f"{api_response.status_code}: {api_response.reason}"
                    + f"\nFor {method_str.upper()}  {request_args['url']}\n"
                    + f"Called with args: {request_args['params']}"
                )
            else:
                response_text = api_response.text
        except Exception as e:
            # Best effort: any failure becomes text handed on to the caller
            # or the response chain instead of aborting the run.
            response_text = f"Error with message {str(e)}"
        response_text = response_text[: self.max_text_length]
        intermediate_steps["response_text"] = response_text
        _run_manager.on_text(
            response_text, color="blue", end="\n", verbose=self.verbose
        )
        if self.api_response_chain is not None:
            _answer = self.api_response_chain.predict_and_parse(
                response=response_text,
                instructions=instructions,
                callbacks=_run_manager.get_child(),
            )
            answer = cast(str, _answer)
            _run_manager.on_text(answer, color="yellow", end="\n", verbose=self.verbose)
            return self._get_output(answer, intermediate_steps)
        else:
            return self._get_output(response_text, intermediate_steps)

    @classmethod
    def from_url_and_method(
        cls,
        spec_url: str,
        path: str,
        method: str,
        llm: BaseLanguageModel,
        requests: Optional[Requests] = None,
        return_intermediate_steps: bool = False,
        **kwargs: Any,
        # TODO: Handle async
    ) -> "OpenAPIEndpointChain":
        """Create an OpenAPIEndpoint from a spec at the specified url.

        The operation is loaded with ``APIOperation.from_openapi_url`` and
        the rest is delegated to ``from_api_operation``.
        """
        operation = APIOperation.from_openapi_url(spec_url, path, method)
        return cls.from_api_operation(
            operation,
            requests=requests,
            llm=llm,
            return_intermediate_steps=return_intermediate_steps,
            **kwargs,
        )

    @classmethod
    def from_api_operation(
        cls,
        operation: APIOperation,
        llm: BaseLanguageModel,
        requests: Optional[Requests] = None,
        verbose: bool = False,
        return_intermediate_steps: bool = False,
        raw_response: bool = False,
        callbacks: Callbacks = None,
        **kwargs: Any,
        # TODO: Handle async
    ) -> "OpenAPIEndpointChain":
        """Create an OpenAPIEndpointChain from an operation and a spec.

        Args:
            operation: The API operation to call.
            llm: Language model used for the request and response chains.
            requests: HTTP client; a default ``Requests()`` is used when omitted.
            verbose: Passed to both sub-chains and to the chain itself.
            return_intermediate_steps: Whether outputs include the steps.
            raw_response: When true no response chain is built, so the raw
                response text becomes the output.
            callbacks: Passed to both sub-chains and to the chain itself.
            **kwargs: Extra fields for the constructor (e.g. ``max_text_length``).
        """
        param_mapping = _ParamMapping(
            query_params=operation.query_params,
            body_params=operation.body_params,
            path_params=operation.path_params,
        )
        requests_chain = APIRequesterChain.from_llm_and_typescript(
            llm,
            typescript_definition=operation.to_typescript(),
            verbose=verbose,
            callbacks=callbacks,
        )
        if raw_response:
            response_chain = None
        else:
            response_chain = APIResponderChain.from_llm(
                llm, verbose=verbose, callbacks=callbacks
            )
        _requests = requests or Requests()
        return cls(
            api_request_chain=requests_chain,
            api_response_chain=response_chain,
            api_operation=operation,
            requests=_requests,
            param_mapping=param_mapping,
            verbose=verbose,
            return_intermediate_steps=return_intermediate_steps,
            callbacks=callbacks,
            **kwargs,
        )
|
||||
@@ -0,0 +1,57 @@
|
||||
# flake8: noqa
|
||||
REQUEST_TEMPLATE = """You are a helpful AI Assistant. Please provide JSON arguments to agentFunc() based on the user's instructions.
|
||||
|
||||
API_SCHEMA: ```typescript
|
||||
{schema}
|
||||
```
|
||||
|
||||
USER_INSTRUCTIONS: "{instructions}"
|
||||
|
||||
Your arguments must be plain json provided in a markdown block:
|
||||
|
||||
ARGS: ```json
|
||||
{{valid json conforming to API_SCHEMA}}
|
||||
```
|
||||
|
||||
Example
|
||||
-----
|
||||
|
||||
ARGS: ```json
|
||||
{{"foo": "bar", "baz": {{"qux": "quux"}}}}
|
||||
```
|
||||
|
||||
The block must be no more than 1 line long, and all arguments must be valid JSON. All string arguments must be wrapped in double quotes.
|
||||
You MUST strictly comply to the types indicated by the provided schema, including all required args.
|
||||
|
||||
If you don't have sufficient information to call the function due to things like requiring specific uuid's, you can reply with the following message:
|
||||
|
||||
Message: ```text
|
||||
Concise response requesting the additional information that would make calling the function successful.
|
||||
```
|
||||
|
||||
Begin
|
||||
-----
|
||||
ARGS:
|
||||
"""
|
||||
RESPONSE_TEMPLATE = """You are a helpful AI assistant trained to answer user queries from API responses.
|
||||
You attempted to call an API, which resulted in:
|
||||
API_RESPONSE: {response}
|
||||
|
||||
USER_COMMENT: "{instructions}"
|
||||
|
||||
|
||||
If the API_RESPONSE can answer the USER_COMMENT respond with the following markdown json block:
|
||||
Response: ```json
|
||||
{{"response": "Human-understandable synthesis of the API_RESPONSE"}}
|
||||
```
|
||||
|
||||
Otherwise respond with the following markdown json block:
|
||||
Response Error: ```json
|
||||
{{"response": "What you did and a concise statement of the resulting error. If it can be easily fixed, provide a suggestion."}}
|
||||
```
|
||||
|
||||
You MUST respond as a markdown json code block. The person you are responding to CANNOT see the API_RESPONSE, so if there is any relevant information there you must include it in your response.
|
||||
|
||||
Begin:
|
||||
---
|
||||
"""
|
||||
@@ -0,0 +1,62 @@
|
||||
"""request parser."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_classic.chains.api.openapi.prompts import REQUEST_TEMPLATE
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
class APIRequesterOutputParser(BaseOutputParser):
    """Turn raw LLM output into request JSON, a message, or an error marker."""

    def _load_json_block(self, serialized_block: str) -> str:
        """Re-serialize the block through ``json``.

        Returns the re-dumped JSON text, or ``"ERROR serializing request."``
        when the block cannot be decoded.
        """
        try:
            decoded = json.loads(serialized_block, strict=False)
        except json.JSONDecodeError:
            return "ERROR serializing request."
        return json.dumps(decoded)

    def parse(self, llm_output: str) -> str:
        """Extract the first fenced ``json`` block, else the first ``text`` block.

        A ``json`` block is returned as normalized JSON, a ``text`` block is
        returned prefixed with ``MESSAGE: ``, and anything else yields
        ``"ERROR making request"``.
        """
        fenced_json = re.search(r"```json(.*?)```", llm_output, re.DOTALL)
        if fenced_json is not None:
            return self._load_json_block(fenced_json.group(1).strip())

        fenced_text = re.search(r"```text(.*?)```", llm_output, re.DOTALL)
        if fenced_text is None:
            return "ERROR making request"
        return f"MESSAGE: {fenced_text.group(1).strip()}"

    @property
    def _type(self) -> str:
        """Identifier string of this parser."""
        return "api_requester"
|
||||
|
||||
|
||||
class APIRequesterChain(LLMChain):
    """LLM chain whose prompt carries an ``APIRequesterOutputParser``."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report that this chain is not serializable."""
        return False

    @classmethod
    def from_llm_and_typescript(
        cls,
        llm: BaseLanguageModel,
        typescript_definition: str,
        verbose: bool = True,
        **kwargs: Any,
    ) -> LLMChain:
        """Build the chain from an LLM and a typescript schema definition.

        The schema is bound to the prompt's ``schema`` variable, leaving
        ``instructions`` as the only input variable.
        """
        request_prompt = PromptTemplate(
            template=REQUEST_TEMPLATE,
            output_parser=APIRequesterOutputParser(),
            partial_variables={"schema": typescript_definition},
            input_variables=["instructions"],
        )
        return cls(prompt=request_prompt, llm=llm, verbose=verbose, **kwargs)
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Response parser."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_classic.chains.api.openapi.prompts import RESPONSE_TEMPLATE
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
class APIResponderOutputParser(BaseOutputParser):
    """Parse the response and error tags."""

    def _load_json_block(self, serialized_block: str) -> str:
        """Decode the JSON block and return its ``"response"`` entry.

        Returns ``"ERROR parsing response."`` when the block is not valid
        JSON, is not a JSON object, or has no ``"response"`` key.
        """
        try:
            response_content = json.loads(serialized_block, strict=False)
        except json.JSONDecodeError:
            return "ERROR parsing response."
        # Valid JSON need not be an object (e.g. a list or a bare string);
        # report that like any other unparseable response instead of
        # failing on ``.get``.
        if not isinstance(response_content, dict):
            return "ERROR parsing response."
        return response_content.get("response", "ERROR parsing response.")

    def parse(self, llm_output: str) -> str:
        """Parse the response and error tags.

        Raises:
            ValueError: If the output contains no fenced ``json`` block.
        """
        json_match = re.search(r"```json(.*?)```", llm_output, re.DOTALL)
        if json_match:
            return self._load_json_block(json_match.group(1).strip())
        else:
            raise ValueError(f"No response found in output: {llm_output}.")

    @property
    def _type(self) -> str:
        """Identifier string of this parser."""
        return "api_responder"
|
||||
|
||||
|
||||
class APIResponderChain(LLMChain):
    """LLM chain whose prompt carries an ``APIResponderOutputParser``."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report that this chain is not serializable."""
        return False

    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, verbose: bool = True, **kwargs: Any
    ) -> LLMChain:
        """Build the chain from an LLM using ``RESPONSE_TEMPLATE``.

        The prompt takes the ``response`` and ``instructions`` variables.
        """
        response_prompt = PromptTemplate(
            template=RESPONSE_TEMPLATE,
            output_parser=APIResponderOutputParser(),
            input_variables=["response", "instructions"],
        )
        return cls(prompt=response_prompt, llm=llm, verbose=verbose, **kwargs)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
Pebblo Retrieval Chain with Identity & Semantic Enforcement for question-answering
|
||||
against a vector database.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import inspect
|
||||
import logging
|
||||
from importlib.metadata import version
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_classic.chains.base import Chain
|
||||
from langchain_classic.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
from pydantic import ConfigDict, Field, validator
|
||||
|
||||
from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
|
||||
SUPPORTED_VECTORSTORES,
|
||||
set_enforcement_filters,
|
||||
)
|
||||
from langchain_community.chains.pebblo_retrieval.models import (
|
||||
App,
|
||||
AuthContext,
|
||||
ChainInfo,
|
||||
Framework,
|
||||
Model,
|
||||
SemanticContext,
|
||||
VectorDB,
|
||||
)
|
||||
from langchain_community.chains.pebblo_retrieval.utilities import (
|
||||
PLUGIN_VERSION,
|
||||
PebbloRetrievalAPIWrapper,
|
||||
get_runtime,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PebbloRetrievalQA(Chain):
|
||||
"""
|
||||
Retrieval Chain with Identity & Semantic Enforcement for question-answering
|
||||
against a vector database.
|
||||
"""
|
||||
|
||||
combine_documents_chain: BaseCombineDocumentsChain
|
||||
"""Chain to use to combine the documents."""
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
return_source_documents: bool = False
|
||||
"""Return the source documents or not."""
|
||||
|
||||
retriever: VectorStoreRetriever = Field(exclude=True)
|
||||
"""VectorStore to use for retrieval."""
|
||||
auth_context_key: str = "auth_context" #: :meta private:
|
||||
"""Authentication context for identity enforcement."""
|
||||
semantic_context_key: str = "semantic_context" #: :meta private:
|
||||
"""Semantic context for semantic enforcement."""
|
||||
app_name: str #: :meta private:
|
||||
"""App name."""
|
||||
owner: str #: :meta private:
|
||||
"""Owner of app."""
|
||||
description: str #: :meta private:
|
||||
"""Description of app."""
|
||||
api_key: Optional[str] = None #: :meta private:
|
||||
"""Pebblo cloud API key for app."""
|
||||
classifier_url: Optional[str] = None #: :meta private:
|
||||
"""Classifier endpoint."""
|
||||
classifier_location: str = "local" #: :meta private:
|
||||
"""Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
|
||||
_discover_sent: bool = False #: :meta private:
|
||||
"""Flag to check if discover payload has been sent."""
|
||||
enable_prompt_gov: bool = True #: :meta private:
|
||||
"""Flag to check if prompt governance is enabled or not"""
|
||||
pb_client: PebbloRetrievalAPIWrapper = Field(
|
||||
default_factory=PebbloRetrievalAPIWrapper
|
||||
)
|
||||
"""Pebblo Retrieval API client"""
|
||||
|
||||
def _call(
    self,
    inputs: Dict[str, Any],
    run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
    """Retrieve documents for the query, answer it, and report the prompt.

    When ``return_source_documents`` is true the retrieved documents are
    returned as well, under the key ``'source_documents'``.

    Example:
    .. code-block:: python

    res = indexqa({'query': 'This is my query'})
    answer, docs = res['result'], res['source_documents']
    """
    prompt_time = datetime.datetime.now().isoformat()
    manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
    question = inputs[self.input_key]
    auth_context = inputs.get(self.auth_context_key)
    semantic_context = inputs.get(self.semantic_context_key)
    _, prompt_entities = self.pb_client.check_prompt_validity(question)

    # Only pass ``run_manager`` to ``_get_docs`` if its signature declares it.
    get_docs_params = inspect.signature(self._get_docs).parameters
    if "run_manager" in get_docs_params:
        docs = self._get_docs(
            question, auth_context, semantic_context, run_manager=manager
        )
    else:
        docs = self._get_docs(question, auth_context, semantic_context)  # type: ignore[call-arg]

    answer = self.combine_documents_chain.run(
        input_documents=docs, question=question, callbacks=manager.get_child()
    )

    self.pb_client.send_prompt(
        self.app_name,
        self.retriever,
        question,
        answer,
        auth_context,
        docs,
        prompt_entities,
        prompt_time,
        self.enable_prompt_gov,
    )

    outputs = {self.output_key: answer}
    if self.return_source_documents:
        outputs["source_documents"] = docs
    return outputs
|
||||
|
||||
async def _acall(
    self,
    inputs: Dict[str, Any],
    run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
    """Async variant: retrieve documents, answer the query, report the prompt.

    When ``return_source_documents`` is true the retrieved documents are
    returned as well, under the key ``'source_documents'``.

    Example:
    .. code-block:: python

    res = indexqa({'query': 'This is my query'})
    answer, docs = res['result'], res['source_documents']
    """
    prompt_time = datetime.datetime.now().isoformat()
    manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
    question = inputs[self.input_key]
    auth_context = inputs.get(self.auth_context_key)
    semantic_context = inputs.get(self.semantic_context_key)
    # Only pass ``run_manager`` to ``_aget_docs`` if its signature declares it.
    aget_docs_params = inspect.signature(self._aget_docs).parameters
    pass_run_manager = "run_manager" in aget_docs_params

    _, prompt_entities = await self.pb_client.acheck_prompt_validity(question)

    if pass_run_manager:
        docs = await self._aget_docs(
            question, auth_context, semantic_context, run_manager=manager
        )
    else:
        docs = await self._aget_docs(question, auth_context, semantic_context)  # type: ignore[call-arg]

    answer = await self.combine_documents_chain.arun(
        input_documents=docs, question=question, callbacks=manager.get_child()
    )

    await self.pb_client.asend_prompt(
        self.app_name,
        self.retriever,
        question,
        answer,
        auth_context,
        docs,
        prompt_entities,
        prompt_time,
        self.enable_prompt_gov,
    )

    outputs = {self.output_key: answer}
    if self.return_source_documents:
        outputs["source_documents"] = docs
    return outputs
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@property
def input_keys(self) -> List[str]:
    """List the query key followed by the auth and semantic context keys.

    :meta private:
    """
    keys = [self.input_key]
    keys.extend((self.auth_context_key, self.semantic_context_key))
    return keys
|
||||
|
||||
@property
def output_keys(self) -> List[str]:
    """List the answer key, plus ``"source_documents"`` when they are returned.

    :meta private:
    """
    if self.return_source_documents:
        return [self.output_key, "source_documents"]
    return [self.output_key]
|
||||
|
||||
@property
def _chain_type(self) -> str:
    """Identifier string for this chain type."""
    chain_type_name = "pebblo_retrieval_qa"
    return chain_type_name
|
||||
|
||||
@classmethod
def from_chain_type(
    cls,
    llm: BaseLanguageModel,
    app_name: str,
    description: str,
    owner: str,
    chain_type: str = "stuff",
    chain_type_kwargs: Optional[dict] = None,
    api_key: Optional[str] = None,
    classifier_url: Optional[str] = None,
    classifier_location: str = "local",
    **kwargs: Any,
) -> "PebbloRetrievalQA":
    """Build the chain for a QA chain type and send the app discovery request.

    Args:
        llm: Language model used by the combine-documents chain.
        app_name: App name.
        description: Description of the app.
        owner: Owner of the app.
        chain_type: Chain type handed to ``load_qa_chain``.
        chain_type_kwargs: Extra keyword arguments for ``load_qa_chain``.
        api_key: API key given to the Pebblo API client.
        classifier_url: Classifier endpoint given to the client.
        classifier_location: Classifier location given to the client.
        **kwargs: Passed both to ``_get_app_details`` and to the constructor.

    Returns:
        A new instance of ``cls``.
    """
    from langchain_classic.chains.question_answering import load_qa_chain

    qa_chain_options = chain_type_kwargs or {}
    documents_chain = load_qa_chain(llm, chain_type=chain_type, **qa_chain_options)

    # Describe the app that is announced to Pebblo below.
    app: App = PebbloRetrievalQA._get_app_details(
        app_name=app_name,
        description=description,
        owner=owner,
        llm=llm,
        **kwargs,
    )

    # The client and the chain fields receive the same three settings.
    client_settings = {
        "api_key": api_key,
        "classifier_location": classifier_location,
        "classifier_url": classifier_url,
    }
    client = PebbloRetrievalAPIWrapper(**client_settings)
    # Discovery is sent before the chain instance is constructed.
    client.send_app_discover(app)

    return cls(
        combine_documents_chain=documents_chain,
        app_name=app_name,
        owner=owner,
        description=description,
        pb_client=client,
        **client_settings,
        **kwargs,
    )
|
||||
|
||||
@validator("retriever", pre=True, always=True)
def validate_vectorstore(
    cls, retriever: VectorStoreRetriever
) -> VectorStoreRetriever:
    """
    Check that the retriever's vectorstore class is one of the supported ones.

    The check compares the vectorstore's class name against
    SUPPORTED_VECTORSTORES and raises ValueError when it is not listed.
    """
    store_name = retriever.vectorstore.__class__.__name__
    if store_name in SUPPORTED_VECTORSTORES:
        return retriever
    raise ValueError(
        f"Vectorstore must be an instance of one of the supported "
        f"vectorstores: {SUPPORTED_VECTORSTORES}. "
        f"Got '{retriever.vectorstore.__class__.__name__}' instead."
    )
|
||||
|
||||
def _get_docs(
    self,
    question: str,
    auth_context: Optional[AuthContext],
    semantic_context: Optional[SemanticContext],
    *,
    run_manager: CallbackManagerForChainRun,
) -> List[Document]:
    """Set the enforcement filters on the retriever, then fetch documents.

    The retriever is invoked with the question and with the run manager's
    child callbacks.
    """
    set_enforcement_filters(self.retriever, auth_context, semantic_context)
    retrieval_config = {"callbacks": run_manager.get_child()}
    return self.retriever.invoke(question, config=retrieval_config)
|
||||
|
||||
async def _aget_docs(
    self,
    question: str,
    auth_context: Optional[AuthContext],
    semantic_context: Optional[SemanticContext],
    *,
    run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
    """Set the enforcement filters on the retriever, then fetch documents.

    Async counterpart of ``_get_docs``: the retriever is awaited via
    ``ainvoke`` with the run manager's child callbacks.
    """
    set_enforcement_filters(self.retriever, auth_context, semantic_context)
    retrieval_config = {"callbacks": run_manager.get_child()}
    return await self.retriever.ainvoke(question, config=retrieval_config)
|
||||
|
||||
@staticmethod
def _get_app_details(
    app_name: str,
    owner: str,
    description: str,
    llm: BaseLanguageModel,
    **kwargs: Any,
) -> App:
    """Assemble the App details. Internal method.

    Combines the runtime/framework information from ``get_runtime()`` with
    the chain details derived from ``llm`` and ``kwargs``.

    Returns:
        App: App details.
    """
    framework, runtime = get_runtime()
    chain_details = PebbloRetrievalQA.get_chain_details(llm, **kwargs)
    client_version = Framework(
        name="langchain_community",
        version=version("langchain_community"),
    )
    return App(
        name=app_name,
        owner=owner,
        description=description,
        runtime=runtime,
        framework=framework,
        chains=chain_details,
        plugin_version=PLUGIN_VERSION,
        client_version=client_version,
    )
|
||||
|
||||
@classmethod
def set_discover_sent(cls) -> None:
    """Set the class-level ``_discover_sent`` flag to True."""
    cls._discover_sent = True
    return None
|
||||
|
||||
@classmethod
def get_chain_details(
    cls, llm: BaseLanguageModel, **kwargs: Any
) -> List[ChainInfo]:
    """
    Get chain details.

    Args:
        llm (BaseLanguageModel): Language model instance. Its ``model_name``
            attribute (or ``model`` as a fallback) and class name are recorded.
        **kwargs: Additional keyword arguments. ``kwargs["retriever"]`` is
            required; its vectorstore supplies the vector DB details.

    Returns:
        List[ChainInfo]: A single-element list describing this chain.
    """
    llm_attrs = llm.__dict__
    model = Model(
        name=llm_attrs.get("model_name", llm_attrs.get("model")),
        vendor=llm.__class__.__name__,
    )

    vectorstore = kwargs["retriever"].vectorstore
    # The embeddings object is looked up under `_embeddings` first and
    # `_embedding` second; with neither, no embedding model is reported.
    if hasattr(vectorstore, "_embeddings"):
        embedding_model = str(vectorstore._embeddings.model)
    elif hasattr(vectorstore, "_embedding"):
        embedding_model = str(vectorstore._embedding.model)
    else:
        embedding_model = None

    vector_db = VectorDB(
        name=vectorstore.__class__.__name__,
        embedding_model=embedding_model,
    )
    return [ChainInfo(name=cls.__name__, model=model, vector_dbs=[vector_db])]
|
||||
@@ -0,0 +1,532 @@
|
||||
"""
|
||||
Identity & Semantic Enforcement filters for PebbloRetrievalQA chain:
|
||||
|
||||
This module contains methods for applying Identity and Semantic Enforcement filters
|
||||
in the PebbloRetrievalQA chain.
|
||||
These filters are used to control the retrieval of documents based on authorization and
|
||||
semantic context.
|
||||
The Identity Enforcement filter ensures that only authorized identities can access
|
||||
certain documents, while the Semantic Enforcement filter controls document retrieval
|
||||
based on semantic context.
|
||||
|
||||
The methods in this module are designed to work with different types of vector stores.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
|
||||
from langchain_community.chains.pebblo_retrieval.models import (
|
||||
AuthContext,
|
||||
SemanticContext,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)

# Vectorstore class names, compared against `vectorstore.__class__.__name__`.
PINECONE = "Pinecone"
QDRANT = "Qdrant"
PGVECTOR = "PGVector"
PINECONE_VECTOR_STORE = "PineconeVectorStore"

# Class names of all vectorstores the enforcement filters support.
SUPPORTED_VECTORSTORES = {
    PINECONE,
    QDRANT,
    PGVECTOR,
    PINECONE_VECTOR_STORE,
}
|
||||
|
||||
|
||||
def clear_enforcement_filters(retriever: VectorStoreRetriever) -> None:
    """
    Clear the identity and semantic enforcement filters in the retriever search_kwargs.

    Only PGVector retrievers are processed here; for any other vectorstore
    this function does nothing.

    The current filter is re-read from search_kwargs before each clearing pass.
    A pass may replace search_kwargs["filter"] with a different object (for
    example when an `$and` list collapses to its only remaining element), so
    reusing the object captured before the first pass could leave a Pebblo
    semantic filter from an earlier request in place.
    """
    if retriever.vectorstore.__class__.__name__ != PGVECTOR:
        return

    search_kwargs = retriever.search_kwargs
    if "filter" not in search_kwargs:
        return

    pebblo_filter_keys = (
        "authorized_identities",
        "pebblo_semantic_topics",
        "pebblo_semantic_entities",
    )
    for pebblo_filter_key in pebblo_filter_keys:
        _pgvector_clear_pebblo_filters(
            search_kwargs, search_kwargs.get("filter"), pebblo_filter_key
        )
|
||||
|
||||
|
||||
def set_enforcement_filters(
    retriever: VectorStoreRetriever,
    auth_context: Optional[AuthContext],
    semantic_context: Optional[SemanticContext],
) -> None:
    """
    Reset and re-apply the identity and semantic enforcement filters.

    Existing enforcement filters are cleared first; the identity filter is
    then set when auth_context is given and the semantic filter when
    semantic_context is given.
    """
    clear_enforcement_filters(retriever)

    if auth_context is None and semantic_context is None:
        return
    if auth_context is not None:
        _set_identity_enforcement_filter(retriever, auth_context)
    if semantic_context is not None:
        _set_semantic_enforcement_filter(retriever, semantic_context)
|
||||
|
||||
|
||||
def _apply_qdrant_semantic_filter(
    search_kwargs: dict, semantic_context: Optional[SemanticContext]
) -> None:
    """
    Put the semantic deny conditions into the Qdrant filter's `must_not` list.

    One condition is built per denied category (topics, entities) present in
    semantic_context. Conditions on the same metadata keys that already sit
    in an existing `must_not` list are replaced.

    Raises:
        ValueError: If `qdrant_client` cannot be imported.
        TypeError: If search_kwargs["filter"] exists but is not a
            qdrant-client Filter.
    """
    try:
        from qdrant_client.http import models as rest
    except ImportError as e:
        raise ValueError(
            "Could not import `qdrant-client.http` python package. "
            "Please install it with `pip install qdrant-client`."
        ) from e

    topics_key = "metadata.pebblo_semantic_topics"
    entities_key = "metadata.pebblo_semantic_entities"

    # Build the semantic enforcement conditions
    semantic_filters: List[
        Union[
            rest.FieldCondition,
            rest.IsEmptyCondition,
            rest.IsNullCondition,
            rest.HasIdCondition,
            rest.NestedCondition,
            rest.Filter,
        ]
    ] = []
    if semantic_context is not None:
        denied_topics = semantic_context.pebblo_semantic_topics
        if denied_topics is not None:
            semantic_filters.append(
                rest.FieldCondition(
                    key=topics_key,
                    match=rest.MatchAny(any=denied_topics.deny),
                )
            )
        denied_entities = semantic_context.pebblo_semantic_entities
        if denied_entities is not None:
            semantic_filters.append(
                rest.FieldCondition(
                    key=entities_key,
                    match=rest.MatchAny(any=denied_entities.deny),
                )
            )

    # No filter yet: create one holding only the semantic conditions
    if "filter" not in search_kwargs:
        search_kwargs["filter"] = rest.Filter(must_not=semantic_filters)
        return

    existing_filter: rest.Filter = search_kwargs["filter"]
    if not isinstance(existing_filter, rest.Filter):
        raise TypeError(
            "Using dict as a `filter` is deprecated. "
            "Please use qdrant-client filters directly: "
            "https://qdrant.tech/documentation/concepts/filtering/"
        )

    if not isinstance(existing_filter.must_not, list):
        # No 'must_not' list so far: the semantic conditions become the list
        existing_filter.must_not = semantic_filters
        return

    # Keep every condition except earlier semantic ones, then add the new ones
    retained_conditions = [
        condition
        for condition in existing_filter.must_not
        if getattr(condition, "key", None) not in (topics_key, entities_key)
    ]
    existing_filter.must_not = retained_conditions
    existing_filter.must_not.extend(semantic_filters)
|
||||
|
||||
|
||||
def _apply_qdrant_authorization_filter(
    search_kwargs: dict, auth_context: Optional[AuthContext]
) -> None:
    """
    Put the identity condition into the Qdrant filter's `must` list.

    Nothing is changed when auth_context is None. A condition on
    `metadata.authorized_identities` that already sits in an existing `must`
    list is replaced.

    Raises:
        ValueError: If `qdrant_client` cannot be imported.
        TypeError: If search_kwargs["filter"] exists but is not a
            qdrant-client Filter.
    """
    try:
        from qdrant_client.http import models as rest
    except ImportError as e:
        raise ValueError(
            "Could not import `qdrant-client.http` python package. "
            "Please install it with `pip install qdrant-client`."
        ) from e

    if auth_context is None:
        return

    identity_key = "metadata.authorized_identities"
    identity_enforcement_filter = rest.FieldCondition(
        key=identity_key,
        match=rest.MatchAny(any=auth_context.user_auth),
    )

    # No filter yet: create one holding only the identity condition
    if "filter" not in search_kwargs:
        search_kwargs["filter"] = rest.Filter(must=[identity_enforcement_filter])
        return

    existing_filter: rest.Filter = search_kwargs["filter"]
    if not isinstance(existing_filter, rest.Filter):
        raise TypeError(
            "Using dict as a `filter` is deprecated. "
            "Please use qdrant-client filters directly: "
            "https://qdrant.tech/documentation/concepts/filtering/"
        )

    if not existing_filter.must:
        # Missing or empty 'must': the identity condition becomes the list
        existing_filter.must = [identity_enforcement_filter]
        return

    # Keep every condition except an earlier identity one, then add the new one
    retained_conditions = [
        condition
        for condition in existing_filter.must
        if not (hasattr(condition, "key") and condition.key == identity_key)
    ]
    existing_filter.must = retained_conditions
    existing_filter.must.append(identity_enforcement_filter)
|
||||
|
||||
|
||||
def _apply_pinecone_semantic_filter(
    search_kwargs: dict, semantic_context: Optional[SemanticContext]
) -> None:
    """
    Write the semantic deny lists into the Pinecone filter in search_kwargs.

    Each denied category present in semantic_context becomes a `$nin` entry
    under its own key; the filter dict is created if it does not exist yet.
    """
    if semantic_context is None:
        return

    denied_topics = semantic_context.pebblo_semantic_topics
    if denied_topics is not None:
        topics_rule = {"$nin": denied_topics.deny}
        search_kwargs.setdefault("filter", {})["pebblo_semantic_topics"] = topics_rule

    denied_entities = semantic_context.pebblo_semantic_entities
    if denied_entities is not None:
        entities_rule = {"$nin": denied_entities.deny}
        search_kwargs.setdefault("filter", {})[
            "pebblo_semantic_entities"
        ] = entities_rule
|
||||
|
||||
|
||||
def _apply_pinecone_authorization_filter(
    search_kwargs: dict, auth_context: Optional[AuthContext]
) -> None:
    """
    Write the identity rule into the Pinecone filter in search_kwargs.

    The rule is an `$in` match on `authorized_identities` against the
    context's `user_auth`; nothing is changed when auth_context is None.
    """
    if auth_context is None:
        return
    identity_rule = {"$in": auth_context.user_auth}
    search_kwargs.setdefault("filter", {})["authorized_identities"] = identity_rule
|
||||
|
||||
|
||||
def _apply_pgvector_filter(
    search_kwargs: dict, filters: Optional[Any], pebblo_filter: dict
) -> None:
    """
    Apply pebblo filters in the search_kwargs filters.

    `pebblo_filter` is either a single-field dict (identity filter) or a
    `{"$not": [...]}` dict (semantic filter). How it is merged depends on the
    shape of the existing `filters`:

    - None or empty dict: `pebblo_filter` is written into the filter.
    - `$and`: `pebblo_filter` is appended to the `$and` list.
    - `$or`: the existing filter and `pebblo_filter` are wrapped in a new `$and`.
    - `$not`: a `$not` pebblo filter has its conditions added to the existing
      `$not` list; a field pebblo filter is combined through a new `$and`.
    - plain fields: a field pebblo filter is merged into the dict; a `$not`
      pebblo filter is combined through a new `$and`.

    Args:
        search_kwargs: Retriever search kwargs; its "filter" entry may be replaced.
        filters: The current filter (dict or None); may be mutated in place.
        pebblo_filter: The filter to add.

    Raises:
        ValueError: If `filters` is neither a dict nor None, uses an unsupported
            or misplaced operator, has a non-list operator value, or if
            `pebblo_filter` uses an operator other than `$not`.
    """
    if filters is None:
        # No filter yet, set pebblo_filter as a new filter
        search_kwargs.setdefault("filter", {}).update(pebblo_filter)
        return
    if not isinstance(filters, dict):
        raise ValueError(
            f"Invalid filter. Expected a dictionary/None but got type: {type(filters)}"
        )
    if not filters:
        # Got an empty dictionary for filters, set pebblo_filter in filter
        search_kwargs.setdefault("filter", {}).update(pebblo_filter)
        return

    key, value = next(iter(filters.items()))

    if len(filters) == 1 and key.startswith("$"):
        # The only operators allowed at the top level are $and, $or, and $not
        operator = key.lower()
        if operator not in ["$and", "$or", "$not"]:
            raise ValueError(
                f"Invalid filter condition. Expected $and, $or or $not "
                f"but got: {key}"
            )
        if not isinstance(value, list):
            raise ValueError(
                f"Expected a list, but got {type(value)} for value: {value}"
            )

        if operator == "$and":
            # Add pebblo_filter to the $and list as it is
            value.append(pebblo_filter)
            return
        if operator == "$or":
            search_kwargs["filter"] = {"$and": [filters, pebblo_filter]}
            return

        # Existing $not: check if pebblo_filter is an operator or a field
        _key, _value = list(pebblo_filter.items())[0]
        if not _key.startswith("$"):
            # It's a field (Auth filter), move filters into $and
            search_kwargs["filter"] = {"$and": [filters, pebblo_filter]}
            return
        if _key.lower() != "$not":
            # Only $not operator is supported in pebblo_filter
            raise ValueError(f"Invalid filter key. Expected '$not' but got: {_key}")
        # Semantic filter: merge its conditions into the existing $not list.
        # They are added one by one (not as a nested list) so every element
        # of the $not list stays a filter dict that can be found and removed
        # again by key.
        value.extend(_value)
        logger.warning("Adding $not operator to the existing $not operator")
        return

    if len(filters) > 1:
        # Then all keys have to be fields (they cannot be operators)
        for field in filters:
            if field.startswith("$"):
                raise ValueError(
                    f"Invalid filter condition. Expected a field but got: {field}"
                )

    # filters holds only fields; check if pebblo_filter is an operator or a field
    _key, _ = list(pebblo_filter.items())[0]
    if not _key.startswith("$"):
        # It's a field (Auth filter), merge it into the existing fields
        filters.update(pebblo_filter)
        return
    if _key.lower() != "$not":
        # Only $not operator is supported in pebblo_filter
        raise ValueError(f"Invalid filter key. Expected '$not' but got: {_key}")
    # It's a $not operator (Semantic filter), move filters into $and
    search_kwargs["filter"] = {"$and": [filters, pebblo_filter]}
|
||||
|
||||
|
||||
def _pgvector_clear_pebblo_filters(
    search_kwargs: dict, filters: dict, pebblo_filter_key: str
) -> None:
    """
    Remove the pebblo filter stored under pebblo_filter_key from the filters.

    At most one occurrence is removed. The filter is searched as a plain
    field, as an element of a top-level `$and` or `$not` list, and inside a
    `$not` list nested in `$and`. A top-level `$or` is left untouched.
    When an `$and` list is left with one element that element becomes the
    filter; when a `$not` list becomes empty the filter is reset to {}.

    Raises:
        ValueError: If filters is neither a dict nor None, or a single
            top-level operator is unsupported or has a non-list value.
    """
    if filters is None:
        # Nothing to clear
        return
    if not isinstance(filters, dict):
        raise ValueError(
            f"Invalid filter. Expected a dictionary/None but got type: {type(filters)}"
        )
    if len(filters) != 1:
        # Empty dict: nothing to do. Several keys: they are all fields, so
        # drop the pebblo field if it is among them.
        filters.pop(pebblo_filter_key, None)
        return

    key, value = next(iter(filters.items()))
    if not key.startswith("$"):
        # A single field: remove it only if it is the pebblo filter
        if key == pebblo_filter_key:
            filters.pop(key)
        return

    # A single operator: validate its name and value type
    operator = key.lower()
    if operator not in ["$and", "$or", "$not"]:
        raise ValueError(
            f"Invalid filter condition. Expected $and, $or or $not "
            f"but got: {key}"
        )
    if not isinstance(value, list):
        raise ValueError(
            f"Expected a list, but got {type(value)} for value: {value}"
        )

    if operator == "$or":
        # pebblo filters are never placed inside $or
        return

    if operator == "$not":
        # Semantic filter sitting directly in the $not list
        for position, entry in enumerate(value):
            if pebblo_filter_key in entry:
                del value[position]
                break
        if not value:
            # No filter is left, unset the filter
            search_kwargs["filter"] = {}
        return

    # operator == "$and"
    for position, entry in enumerate(value):
        if pebblo_filter_key in entry:
            # Auth filter stored directly in the $and list
            del value[position]
            break
        if "$not" not in entry:
            continue
        # Semantic filter stored inside a nested $not list
        negated = entry["$not"]
        nested_position = next(
            (j for j, nested in enumerate(negated) if pebblo_filter_key in nested),
            None,
        )
        if nested_position is None:
            continue
        if len(negated) == 1:
            # It was the only negated filter, drop the whole $not entry
            del value[position]
        else:
            del negated[nested_position]
        break
    if len(value) == 1:
        # Only one filter is left, so the $and wrapper is removed
        search_kwargs["filter"] = value[0]
|
||||
|
||||
|
||||
def _apply_pgvector_semantic_filter(
    search_kwargs: dict, semantic_context: Optional[SemanticContext]
) -> None:
    """
    Add the semantic deny lists as a `$not` filter for PGVector.

    One `$eq` clause is built per denied category present in
    semantic_context; the clauses are wrapped in `$not` and merged into the
    current filter. Nothing is changed when semantic_context is None.
    """
    if semantic_context is None:
        return

    deny_clauses = []
    filters = search_kwargs.get("filter")
    for field in ("pebblo_semantic_topics", "pebblo_semantic_entities"):
        denied = getattr(semantic_context, field)
        if denied is not None:
            clause: dict = {field: {"$eq": denied.deny}}
            deny_clauses.append(clause)

    if deny_clauses:
        semantic_filter: dict = {"$not": deny_clauses}
        _apply_pgvector_filter(search_kwargs, filters, semantic_filter)
|
||||
|
||||
|
||||
def _apply_pgvector_authorization_filter(
    search_kwargs: dict, auth_context: Optional[AuthContext]
) -> None:
    """
    Add the identity filter for PGVector to search_kwargs.

    The filter is an `$eq` match on `authorized_identities` against the
    context's `user_auth`; nothing is changed when auth_context is None.
    """
    if auth_context is None:
        return
    auth_filter: dict = {"authorized_identities": {"$eq": auth_context.user_auth}}
    current_filter = search_kwargs.get("filter")
    _apply_pgvector_filter(search_kwargs, current_filter, auth_filter)
|
||||
|
||||
|
||||
def _set_identity_enforcement_filter(
    retriever: VectorStoreRetriever, auth_context: Optional[AuthContext]
) -> None:
    """
    Set identity enforcement filter in search_kwargs.

    Dispatches on the class name of the retriever's vectorstore to the
    Pinecone, Qdrant or PGVector implementation. Other vectorstores are
    left unchanged.
    """
    search_kwargs = retriever.search_kwargs
    store_name = retriever.vectorstore.__class__.__name__
    if store_name in [PINECONE, PINECONE_VECTOR_STORE]:
        _apply_pinecone_authorization_filter(search_kwargs, auth_context)
        return
    if store_name == QDRANT:
        _apply_qdrant_authorization_filter(search_kwargs, auth_context)
        return
    if store_name == PGVECTOR:
        _apply_pgvector_authorization_filter(search_kwargs, auth_context)
|
||||
|
||||
|
||||
def _set_semantic_enforcement_filter(
    retriever: VectorStoreRetriever, semantic_context: Optional[SemanticContext]
) -> None:
    """
    Set semantic enforcement filter in search_kwargs.

    This method sets the semantic enforcement filter in the search_kwargs
    of the retriever based on the type of the vectorstore. Other
    vectorstores are left unchanged.

    Both Pinecone class names (PINECONE and PINECONE_VECTOR_STORE) use the
    Pinecone implementation, matching the identity enforcement dispatch, so
    the semantic filter is not silently skipped for `PineconeVectorStore`.
    """
    search_kwargs = retriever.search_kwargs
    vectorstore_name = retriever.vectorstore.__class__.__name__
    if vectorstore_name in [PINECONE, PINECONE_VECTOR_STORE]:
        _apply_pinecone_semantic_filter(search_kwargs, semantic_context)
    elif vectorstore_name == QDRANT:
        _apply_qdrant_semantic_filter(search_kwargs, semantic_context)
    elif vectorstore_name == PGVECTOR:
        _apply_pgvector_semantic_filter(search_kwargs, semantic_context)
|
||||
@@ -0,0 +1,151 @@
|
||||
"""Models for the PebbloRetrievalQA chain."""
|
||||
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AuthContext(BaseModel):
    """Class for an authorization context."""

    # Optional display name; defaults to None.
    name: Optional[str] = None
    # Identifier of the user (required).
    user_id: str
    # Authorizations of the user (required); see the note below.
    user_auth: List[str]
    """List of user authorizations, which may include their User ID and
    the groups they are part of"""
|
||||
|
||||
|
||||
class SemanticEntities(BaseModel):
    """Class for a semantic entity filter."""

    # Values read by the enforcement filters as the entity deny list.
    deny: List[str]
|
||||
|
||||
|
||||
class SemanticTopics(BaseModel):
    """Class for a semantic topic filter."""

    # Values read by the enforcement filters as the topic deny list.
    deny: List[str]
|
||||
|
||||
|
||||
class SemanticContext(BaseModel):
    """Class for a semantic context."""

    # Both filters are optional individually, but at least one of them must
    # be provided (enforced in __init__).
    pebblo_semantic_entities: Optional[SemanticEntities] = None
    pebblo_semantic_topics: Optional[SemanticTopics] = None

    def __init__(self, **data: Any) -> None:
        # Run the regular model initialisation, then reject a context that
        # carries neither an entities nor a topics filter.
        super().__init__(**data)

        has_entities = self.pebblo_semantic_entities is not None
        has_topics = self.pebblo_semantic_topics is not None
        if not (has_entities or has_topics):
            raise ValueError(
                "semantic_context must contain 'pebblo_semantic_entities' or "
                "'pebblo_semantic_topics'"
            )
|
||||
|
||||
|
||||
class ChainInput(BaseModel):
    """Input for PebbloRetrievalQA chain."""

    # The question text (required).
    query: str
    # Optional contexts; both default to None.
    auth_context: Optional[AuthContext] = None
    semantic_context: Optional[SemanticContext] = None

    def dict(self, **kwargs: Any) -> dict:
        # Start from the regular dict export, then put the original model
        # objects back for the two contexts so they stay Pydantic models.
        exported = super().dict(**kwargs)
        exported.update(
            auth_context=self.auth_context,
            semantic_context=self.semantic_context,
        )
        return exported
|
||||
|
||||
|
||||
class Runtime(BaseModel):
    """
    OS, language details
    """

    # `type`, `ip` and `runtime` default to an empty string; every other
    # field has no default.
    type: Optional[str] = ""
    host: str
    path: str
    ip: Optional[str] = ""
    platform: str
    os: str
    os_version: str
    language: str
    language_version: str
    runtime: Optional[str] = ""
|
||||
|
||||
|
||||
class Framework(BaseModel):
    """
    Langchain framework details
    """

    # Name and version string of the framework or client package.
    name: str
    version: str
|
||||
|
||||
|
||||
# Language model details: vendor and model name. Both fields accept None
# and are declared without a default.
class Model(BaseModel):
    vendor: Optional[str]
    name: Optional[str]
|
||||
|
||||
|
||||
# Package information. All fields accept None and are declared without a
# default. NOTE(review): `liscence_type` looks like a misspelling of
# "license"; it is kept because the field name is part of the model's
# interface.
class PkgInfo(BaseModel):
    project_home_page: Optional[str]
    documentation_url: Optional[str]
    pypi_url: Optional[str]
    liscence_type: Optional[str]
    installed_via: Optional[str]
    location: Optional[str]
|
||||
|
||||
|
||||
# Vector database details. Every field is optional and defaults to None.
class VectorDB(BaseModel):
    name: Optional[str] = None
    version: Optional[str] = None
    location: Optional[str] = None
    embedding_model: Optional[str] = None
|
||||
|
||||
|
||||
# Details of one chain: its name, the language model it uses and the vector
# databases it reads from.
class ChainInfo(BaseModel):
    name: str
    model: Optional[Model]
    vector_dbs: Optional[List[VectorDB]]
|
||||
|
||||
|
||||
# App details: identity fields (name, owner, description), runtime and
# framework information, the chains of the app, and the plugin and client
# versions.
class App(BaseModel):
    name: str
    owner: str
    description: Optional[str]
    runtime: Runtime
    framework: Framework
    chains: List[ChainInfo]
    plugin_version: str
    client_version: Framework
|
||||
|
||||
|
||||
# One context entry. `vector_db` is a plain string; the other fields also
# accept None.
class Context(BaseModel):
    retrieved_from: Optional[str]
    doc: Optional[str]
    vector_db: str
    pb_checksum: Optional[str]
|
||||
|
||||
|
||||
# Prompt (or response) payload: the data itself plus optional entity
# information and a prompt-governance flag, which default to None.
class Prompt(BaseModel):
    data: Optional[Union[list, str]]
    entityCount: Optional[int] = None
    entities: Optional[dict] = None
    prompt_gov_enabled: Optional[bool] = None
|
||||
|
||||
|
||||
# One question/answer record: the name, retrieved context(s), prompt and
# response, prompt time, user details and classifier location.
class Qa(BaseModel):
    name: str
    # Either a single context or a list of contexts; entries may be None.
    context: Union[List[Optional[Context]], Optional[Context]]
    prompt: Optional[Prompt]
    response: Optional[Prompt]
    prompt_time: str
    user: str
    user_identities: Optional[List[str]]
    classifier_location: str
|
||||
@@ -0,0 +1,542 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import ClientTimeout
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.env import get_runtime_environment
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
from pydantic import BaseModel
|
||||
from requests import Response, request
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
from langchain_community.chains.pebblo_retrieval.models import (
|
||||
App,
|
||||
AuthContext,
|
||||
Context,
|
||||
Framework,
|
||||
Prompt,
|
||||
Qa,
|
||||
Runtime,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)

# Version string reported for this plugin.
PLUGIN_VERSION = "0.1.1"

# Fallback URLs for the classifier and for Pebblo Cloud.
_DEFAULT_CLASSIFIER_URL = "http://localhost:8000"
_DEFAULT_PEBBLO_CLOUD_URL = "https://api.daxa.ai"
|
||||
|
||||
|
||||
class Routes(str, Enum):
    """Routes available for the Pebblo API as enumerator."""

    # Each value is a URL path that is appended to a base URL.
    retrieval_app_discover = "/v1/app/discover"
    prompt = "/v1/prompt"
    prompt_governance = "/v1/prompt/governance"
|
||||
|
||||
|
||||
def get_runtime() -> Tuple[Framework, Runtime]:
    """Fetch the current Framework and Runtime details.

    The runtime path is read from the ``PWD`` environment variable. When that
    variable is unset or empty, the current working directory from
    ``os.getcwd()`` is used instead of raising ``KeyError``.

    Returns:
        Tuple[Framework, Runtime]: Framework and Runtime for the current app instance.
    """
    runtime_env = get_runtime_environment()
    framework = Framework(
        name="langchain", version=runtime_env.get("library_version", "unknown")
    )
    uname = platform.uname()
    runtime = Runtime(
        host=uname.node,
        # PWD is not exported in every environment; fall back to the cwd.
        path=os.environ.get("PWD") or os.getcwd(),
        platform=runtime_env.get("platform", "unknown"),
        os=uname.system,
        os_version=uname.version,
        ip=get_ip(),
        language=runtime_env.get("runtime", "unknown"),
        language_version=runtime_env.get("runtime_version", "unknown"),
    )

    # On Darwin systems, mark the runtime as a desktop machine.
    if "Darwin" in runtime.os:
        runtime.type = "desktop"
        runtime.runtime = "Mac OSX"

    logger.debug(f"framework {framework}")
    logger.debug(f"runtime {runtime}")
    return framework, runtime
|
||||
|
||||
|
||||
def get_ip() -> str:
    """Fetch local runtime ip address.

    The address is resolved from this machine's hostname; if that lookup
    fails for any reason, the address of "localhost" is returned instead.

    Returns:
        str: IP address
    """
    import socket  # lazy imports

    hostname = socket.gethostname()
    try:
        return socket.gethostbyname(hostname)
    except Exception:
        return socket.gethostbyname("localhost")
|
||||
|
||||
|
||||
class PebbloRetrievalAPIWrapper(BaseModel):
    """Wrapper for Pebblo Retrieval API.

    Sends app-discovery and prompt/QA payloads to the Pebblo classifier
    (``classifier_url``) when ``classifier_location`` is ``"local"`` and to
    Pebblo cloud (``cloud_url``) when ``api_key`` is set.  The request helpers
    log failures and return ``None`` rather than raising.
    """

    api_key: Optional[str]  # Use SecretStr
    """API key for Pebblo Cloud"""
    classifier_location: str = "local"
    """Location of the classifier, local or cloud. Defaults to 'local'"""
    classifier_url: Optional[str]
    """URL of the Pebblo Classifier"""
    cloud_url: Optional[str]
    """URL of the Pebblo Cloud"""

    def __init__(self, **kwargs: Any):
        """Validate that api key in environment.

        ``api_key``, ``classifier_url`` and ``cloud_url`` are each resolved via
        ``get_from_dict_or_env`` from ``kwargs``, the named ``PEBBLO_*``
        environment variable, or the default passed here, before the model is
        initialised.
        """
        kwargs["api_key"] = get_from_dict_or_env(
            kwargs, "api_key", "PEBBLO_API_KEY", ""
        )
        kwargs["classifier_url"] = get_from_dict_or_env(
            kwargs, "classifier_url", "PEBBLO_CLASSIFIER_URL", _DEFAULT_CLASSIFIER_URL
        )
        kwargs["cloud_url"] = get_from_dict_or_env(
            kwargs, "cloud_url", "PEBBLO_CLOUD_URL", _DEFAULT_PEBBLO_CLOUD_URL
        )
        super().__init__(**kwargs)

    def send_app_discover(self, app: App) -> None:
        """
        Send app discovery request to Pebblo server & cloud.

        Args:
            app (App): App instance to be discovered.
        """
        pebblo_resp = None
        payload = app.dict(exclude_unset=True)

        if self.classifier_location == "local":
            # Send app details to local classifier
            headers = self._make_headers()
            app_discover_url = (
                f"{self.classifier_url}{Routes.retrieval_app_discover.value}"
            )
            pebblo_resp = self.make_request("POST", app_discover_url, headers, payload)

        if self.api_key:
            # Send app details to Pebblo cloud if api_key is present
            headers = self._make_headers(cloud_request=True)
            if pebblo_resp:
                # A malformed classifier body must not prevent the cloud
                # request, so decoding is guarded and defaults to "unknown"
                # (None) server version.
                server_info = self._response_json(pebblo_resp) or {}
                pebblo_server_version = server_info.get("pebblo_server_version")
                payload.update({"pebblo_server_version": pebblo_server_version})

            payload.update({"pebblo_client_version": PLUGIN_VERSION})
            pebblo_cloud_url = f"{self.cloud_url}{Routes.retrieval_app_discover.value}"
            _ = self.make_request("POST", pebblo_cloud_url, headers, payload)

    def send_prompt(
        self,
        app_name: str,
        retriever: VectorStoreRetriever,
        question: str,
        answer: str,
        auth_context: Optional[AuthContext],
        docs: List[Document],
        prompt_entities: Dict[str, Any],
        prompt_time: str,
        prompt_gov_enabled: bool = False,
    ) -> None:
        """
        Send prompt to Pebblo server for classification.
        Then send prompt to Daxa cloud(If api_key is present).

        Args:
            app_name (str): Name of the app.
            retriever (VectorStoreRetriever): Retriever instance.
            question (str): Question asked in the prompt.
            answer (str): Answer generated by the model.
            auth_context (Optional[AuthContext]): Authentication context.
            docs (List[Document]): List of documents retrieved.
            prompt_entities (Dict[str, Any]): Entities present in the prompt.
            prompt_time (str): Time when the prompt was generated.
            prompt_gov_enabled (bool): Whether prompt governance is enabled.

        Raises:
            NameError: If ``classifier_location`` is ``"pebblo-cloud"`` and no
                API key is configured.
        """
        pebblo_resp = None
        payload = self.build_prompt_qa_payload(
            app_name,
            retriever,
            question,
            answer,
            auth_context,
            docs,
            prompt_entities,
            prompt_time,
            prompt_gov_enabled,
        )

        if self.classifier_location == "local":
            # Send prompt to local classifier
            headers = self._make_headers()
            prompt_url = f"{self.classifier_url}{Routes.prompt.value}"
            pebblo_resp = self.make_request("POST", prompt_url, headers, payload)

        if self.api_key:
            # Send prompt to Pebblo cloud if api_key is present
            if self.classifier_location == "local":
                # If classifier location is local, then response, context and prompt
                # should be fetched from pebblo_resp and replaced in payload.
                classifier_data = (
                    self._response_json(pebblo_resp) if pebblo_resp else None
                )
                self.update_cloud_payload(payload, classifier_data)

            headers = self._make_headers(cloud_request=True)
            pebblo_cloud_prompt_url = f"{self.cloud_url}{Routes.prompt.value}"
            _ = self.make_request("POST", pebblo_cloud_prompt_url, headers, payload)
        elif self.classifier_location == "pebblo-cloud":
            logger.warning("API key is missing for sending prompt to Pebblo cloud.")
            raise NameError("API key is missing for sending prompt to Pebblo cloud.")

    async def asend_prompt(
        self,
        app_name: str,
        retriever: VectorStoreRetriever,
        question: str,
        answer: str,
        auth_context: Optional[AuthContext],
        docs: List[Document],
        prompt_entities: Dict[str, Any],
        prompt_time: str,
        prompt_gov_enabled: bool = False,
    ) -> None:
        """
        Send prompt to Pebblo server for classification.
        Then send prompt to Daxa cloud(If api_key is present).

        Async counterpart of ``send_prompt``.

        Args:
            app_name (str): Name of the app.
            retriever (VectorStoreRetriever): Retriever instance.
            question (str): Question asked in the prompt.
            answer (str): Answer generated by the model.
            auth_context (Optional[AuthContext]): Authentication context.
            docs (List[Document]): List of documents retrieved.
            prompt_entities (Dict[str, Any]): Entities present in the prompt.
            prompt_time (str): Time when the prompt was generated.
            prompt_gov_enabled (bool): Whether prompt governance is enabled.

        Raises:
            NameError: If ``classifier_location`` is ``"pebblo-cloud"`` and no
                API key is configured.
        """
        pebblo_resp = None
        payload = self.build_prompt_qa_payload(
            app_name,
            retriever,
            question,
            answer,
            auth_context,
            docs,
            prompt_entities,
            prompt_time,
            prompt_gov_enabled,
        )

        if self.classifier_location == "local":
            # Send prompt to local classifier
            headers = self._make_headers()
            prompt_url = f"{self.classifier_url}{Routes.prompt.value}"
            pebblo_resp = await self.amake_request("POST", prompt_url, headers, payload)

        if self.api_key:
            # Send prompt to Pebblo cloud if api_key is present
            if self.classifier_location == "local":
                # If classifier location is local, then response, context and prompt
                # should be fetched from pebblo_resp and replaced in payload.
                # amake_request already returns decoded JSON (or None).
                self.update_cloud_payload(payload, pebblo_resp)

            headers = self._make_headers(cloud_request=True)
            pebblo_cloud_prompt_url = f"{self.cloud_url}{Routes.prompt.value}"
            _ = await self.amake_request(
                "POST", pebblo_cloud_prompt_url, headers, payload
            )
        elif self.classifier_location == "pebblo-cloud":
            logger.warning("API key is missing for sending prompt to Pebblo cloud.")
            raise NameError("API key is missing for sending prompt to Pebblo cloud.")

    def check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
        """
        Check the validity of the given prompt using a remote classification service.

        This method sends a prompt to a remote classifier service and return entities
        present in prompt or not.

        Args:
            question (str): The prompt question to be validated.

        Returns:
            bool: True if the prompt is valid (does not contain deny list entities),
            False otherwise. NOTE: this implementation never sets the flag to
            False; it is always returned as True.
            dict: The entities present in the prompt
        """
        prompt_payload = {"prompt": question}
        prompt_entities: dict = {"entities": {}, "entityCount": 0}
        is_valid_prompt: bool = True
        if self.classifier_location == "local":
            headers = self._make_headers()
            prompt_gov_api_url = (
                f"{self.classifier_url}{Routes.prompt_governance.value}"
            )
            pebblo_resp = self.make_request(
                "POST", prompt_gov_api_url, headers, prompt_payload
            )
            if pebblo_resp:
                # Decode the body once; an undecodable body leaves the
                # empty defaults in place instead of raising.
                resp_data = self._response_json(pebblo_resp) or {}
                prompt_entities["entities"] = resp_data.get("entities", {})
                prompt_entities["entityCount"] = resp_data.get("entityCount", 0)
        return is_valid_prompt, prompt_entities

    async def acheck_prompt_validity(
        self, question: str
    ) -> Tuple[bool, Dict[str, Any]]:
        """
        Check the validity of the given prompt using a remote classification service.

        This method sends a prompt to a remote classifier service and return entities
        present in prompt or not.

        Args:
            question (str): The prompt question to be validated.

        Returns:
            bool: True if the prompt is valid (does not contain deny list entities),
            False otherwise. NOTE: this implementation never sets the flag to
            False; it is always returned as True.
            dict: The entities present in the prompt
        """
        prompt_payload = {"prompt": question}
        prompt_entities: dict = {"entities": {}, "entityCount": 0}
        is_valid_prompt: bool = True
        if self.classifier_location == "local":
            headers = self._make_headers()
            prompt_gov_api_url = (
                f"{self.classifier_url}{Routes.prompt_governance.value}"
            )
            pebblo_resp = await self.amake_request(
                "POST", prompt_gov_api_url, headers, prompt_payload
            )
            # amake_request returns decoded JSON; only a JSON object carries
            # the entity fields.
            if pebblo_resp and isinstance(pebblo_resp, dict):
                prompt_entities["entities"] = pebblo_resp.get("entities", {})
                prompt_entities["entityCount"] = pebblo_resp.get("entityCount", 0)
        return is_valid_prompt, prompt_entities

    def _make_headers(self, cloud_request: bool = False) -> dict:
        """
        Generate headers for the request.

        Args:
            cloud_request (bool): flag indicating whether the request is for Pebblo
                cloud.

        Returns:
            dict: Headers for the request.
        """
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        if cloud_request:
            # Add API key for Pebblo cloud request
            if self.api_key:
                headers.update({"x-api-key": self.api_key})
            else:
                logger.warning("API key is missing for Pebblo cloud request.")
        return headers

    @staticmethod
    def _response_json(response: Response) -> Optional[dict]:
        """
        Decode the body of a ``requests`` response as a JSON object.

        Args:
            response (Response): Response whose body should be decoded.

        Returns:
            Optional[dict]: The decoded object, or None when the body is not
            valid JSON or is not a JSON object.
        """
        try:
            data = response.json()
        except ValueError:
            # Every JSON decode error raised by requests is a ValueError.
            logger.warning("Pebblo returned a response that is not valid JSON.")
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    def make_request(
        method: str,
        url: str,
        headers: dict,
        payload: Optional[dict] = None,
        timeout: int = 20,
    ) -> Optional[Response]:
        """
        Make a request to the Pebblo server/cloud API.

        Args:
            method (str): HTTP method (GET, POST, PUT, DELETE, etc.).
            url (str): URL for the request.
            headers (dict): Headers for the request.
            payload (Optional[dict]): Payload for the request (for POST, PUT, etc.).
            timeout (int): Timeout for the request in seconds.

        Returns:
            Optional[Response]: Response object if the request completed
            (whatever its status code); None if it raised.
        """
        try:
            response = request(
                method=method, url=url, headers=headers, json=payload, timeout=timeout
            )
            logger.debug(
                "Request: method %s, url %s, len %s response status %s",
                method,
                response.request.url,
                str(len(response.request.body or b"")),
                str(response.status_code),
            )

            # Non-200 statuses are only logged; the response is still returned
            # so the caller can decide what to do with it.
            if response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
                logger.warning("Pebblo Server: Error %s", response.status_code)
            elif response.status_code >= HTTPStatus.BAD_REQUEST:
                logger.warning("Pebblo received an invalid payload: %s", response.text)
            elif response.status_code != HTTPStatus.OK:
                logger.warning(
                    "Pebblo returned an unexpected response code: %s",
                    response.status_code,
                )

            return response
        except RequestException:
            logger.warning("Unable to reach server %s", url)
        except Exception as e:
            logger.warning("An Exception caught in make_request: %s", e)
        return None

    @staticmethod
    def update_cloud_payload(payload: dict, pebblo_resp: Optional[dict]) -> None:
        """
        Update the payload with response, prompt and context from Pebblo response.

        The payload is modified in place. Raw text (the ``data`` of the
        response and prompt, and the ``doc`` of each context entry) is removed.
        When no classifier response is available the three sections are
        emptied entirely.

        Args:
            payload (dict): Payload to be updated.
            pebblo_resp (Optional[dict]): Response from Pebblo server.
        """
        if pebblo_resp:
            # Update response, prompt and context from pebblo response
            retrieval_data = pebblo_resp.get("retrieval_data") or {}
            response = payload.get("response", {})
            response.update(retrieval_data.get("response", {}))
            response.pop("data", None)
            prompt = payload.get("prompt", {})
            prompt.update(retrieval_data.get("prompt", {}))
            prompt.pop("data", None)
            context = payload.get("context", [])
            for context_data in context:
                context_data.pop("doc", None)
        else:
            payload["response"] = {}
            payload["prompt"] = {}
            payload["context"] = []

    @staticmethod
    async def amake_request(
        method: str,
        url: str,
        headers: dict,
        payload: Optional[dict] = None,
        timeout: int = 20,
    ) -> Any:
        """
        Make a async request to the Pebblo server/cloud API.

        Args:
            method (str): HTTP method (GET, POST, PUT, DELETE, etc.).
            url (str): URL for the request.
            headers (dict): Headers for the request.
            payload (Optional[dict]): Payload for the request (for POST, PUT, etc.).
            timeout (int): Timeout for the request in seconds.

        Returns:
            Any: Response json if the request is successful, None if the
            request or the JSON decoding raised.
        """
        import asyncio  # lazy import; only needed to recognise timeouts

        try:
            client_timeout = ClientTimeout(total=timeout)
            async with aiohttp.ClientSession() as asession:
                async with asession.request(
                    method=method,
                    url=url,
                    json=payload,
                    headers=headers,
                    timeout=client_timeout,
                ) as response:
                    if response.status >= HTTPStatus.INTERNAL_SERVER_ERROR:
                        logger.warning("Pebblo Server: Error %s", response.status)
                    elif response.status >= HTTPStatus.BAD_REQUEST:
                        # aiohttp's ``text`` is a coroutine method: it has to
                        # be awaited to obtain the body.
                        logger.warning(
                            "Pebblo received an invalid payload: %s",
                            await response.text(),
                        )
                    elif response.status != HTTPStatus.OK:
                        logger.warning(
                            "Pebblo returned an unexpected response code: %s",
                            response.status,
                        )
                    response_json = await response.json()
                    return response_json
        except (aiohttp.ClientConnectionError, asyncio.TimeoutError):
            # Connection-level failures and timeouts mean the server could
            # not be reached; other errors (e.g. a non-JSON body) fall through
            # to the generic handler below.
            logger.warning("Unable to reach server %s", url)
        except Exception as e:
            logger.warning("An Exception caught in amake_request: %s", e)
        return None

    def build_prompt_qa_payload(
        self,
        app_name: str,
        retriever: VectorStoreRetriever,
        question: str,
        answer: str,
        auth_context: Optional[AuthContext],
        docs: List[Document],
        prompt_entities: Dict[str, Any],
        prompt_time: str,
        prompt_gov_enabled: bool = False,
    ) -> dict:
        """
        Build the QA payload for the prompt.

        Args:
            app_name (str): Name of the app.
            retriever (VectorStoreRetriever): Retriever instance.
            question (str): Question asked in the prompt.
            answer (str): Answer generated by the model.
            auth_context (Optional[AuthContext]): Authentication context.
            docs (List[Document]): List of documents retrieved.
            prompt_entities (Dict[str, Any]): Entities present in the prompt.
            prompt_time (str): Time when the prompt was generated.
            prompt_gov_enabled (bool): Whether prompt governance is enabled.

        Returns:
            dict: The QA payload for the prompt.
        """
        qa = Qa(
            name=app_name,
            # Entries of ``docs`` that are not Document instances are skipped.
            context=[
                Context(
                    # Prefer "full_path" metadata, falling back to "source".
                    retrieved_from=doc.metadata.get(
                        "full_path", doc.metadata.get("source")
                    ),
                    doc=doc.page_content,
                    vector_db=retriever.vectorstore.__class__.__name__,
                    pb_checksum=doc.metadata.get("pb_checksum"),
                )
                for doc in docs
                if isinstance(doc, Document)
            ],
            prompt=Prompt(
                data=question,
                entities=prompt_entities.get("entities", {}),
                entityCount=prompt_entities.get("entityCount", 0),
                prompt_gov_enabled=prompt_gov_enabled,
            ),
            response=Prompt(data=answer),
            prompt_time=prompt_time,
            user=auth_context.user_id if auth_context else "unknown",
            user_identities=auth_context.user_auth
            if auth_context and hasattr(auth_context, "user_auth")
            else [],
            classifier_location=self.classifier_location,
        )
        return qa.dict(exclude_unset=True)
|
||||
Reference in New Issue
Block a user