initial commit
This commit is contained in:
@@ -0,0 +1,82 @@
|
||||
"""**OutputParser** classes parse the output of an LLM call."""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.output_parsers import (
|
||||
CommaSeparatedListOutputParser,
|
||||
ListOutputParser,
|
||||
MarkdownListOutputParser,
|
||||
NumberedListOutputParser,
|
||||
PydanticOutputParser,
|
||||
XMLOutputParser,
|
||||
)
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
JsonOutputToolsParser,
|
||||
PydanticToolsParser,
|
||||
)
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
from langchain_classic.output_parsers.boolean import BooleanOutputParser
|
||||
from langchain_classic.output_parsers.combining import CombiningOutputParser
|
||||
from langchain_classic.output_parsers.datetime import DatetimeOutputParser
|
||||
from langchain_classic.output_parsers.enum import EnumOutputParser
|
||||
from langchain_classic.output_parsers.fix import OutputFixingParser
|
||||
from langchain_classic.output_parsers.pandas_dataframe import (
|
||||
PandasDataFrameOutputParser,
|
||||
)
|
||||
from langchain_classic.output_parsers.regex import RegexParser
|
||||
from langchain_classic.output_parsers.regex_dict import RegexDictParser
|
||||
from langchain_classic.output_parsers.retry import (
|
||||
RetryOutputParser,
|
||||
RetryWithErrorOutputParser,
|
||||
)
|
||||
from langchain_classic.output_parsers.structured import (
|
||||
ResponseSchema,
|
||||
StructuredOutputParser,
|
||||
)
|
||||
from langchain_classic.output_parsers.yaml import YamlOutputParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.output_parsers.rail_parser import GuardrailsOutputParser
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"GuardrailsOutputParser": "langchain_community.output_parsers.rail_parser",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BooleanOutputParser",
|
||||
"CombiningOutputParser",
|
||||
"CommaSeparatedListOutputParser",
|
||||
"DatetimeOutputParser",
|
||||
"EnumOutputParser",
|
||||
"GuardrailsOutputParser",
|
||||
"JsonOutputKeyToolsParser",
|
||||
"JsonOutputToolsParser",
|
||||
"ListOutputParser",
|
||||
"MarkdownListOutputParser",
|
||||
"NumberedListOutputParser",
|
||||
"OutputFixingParser",
|
||||
"PandasDataFrameOutputParser",
|
||||
"PydanticOutputParser",
|
||||
"PydanticToolsParser",
|
||||
"RegexDictParser",
|
||||
"RegexParser",
|
||||
"ResponseSchema",
|
||||
"RetryOutputParser",
|
||||
"RetryWithErrorOutputParser",
|
||||
"StructuredOutputParser",
|
||||
"XMLOutputParser",
|
||||
"YamlOutputParser",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,54 @@
|
||||
import re
|
||||
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
|
||||
|
||||
class BooleanOutputParser(BaseOutputParser[bool]):
|
||||
"""Parse the output of an LLM call to a boolean."""
|
||||
|
||||
true_val: str = "YES"
|
||||
"""The string value that should be parsed as True."""
|
||||
false_val: str = "NO"
|
||||
"""The string value that should be parsed as False."""
|
||||
|
||||
def parse(self, text: str) -> bool:
|
||||
"""Parse the output of an LLM call to a boolean.
|
||||
|
||||
Args:
|
||||
text: output of a language model
|
||||
|
||||
Returns:
|
||||
boolean
|
||||
"""
|
||||
regexp = rf"\b({self.true_val}|{self.false_val})\b"
|
||||
|
||||
truthy = {
|
||||
val.upper()
|
||||
for val in re.findall(regexp, text, flags=re.IGNORECASE | re.MULTILINE)
|
||||
}
|
||||
if self.true_val.upper() in truthy:
|
||||
if self.false_val.upper() in truthy:
|
||||
msg = (
|
||||
f"Ambiguous response. Both {self.true_val} and {self.false_val} "
|
||||
f"in received: {text}."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return True
|
||||
if self.false_val.upper() in truthy:
|
||||
if self.true_val.upper() in truthy:
|
||||
msg = (
|
||||
f"Ambiguous response. Both {self.true_val} and {self.false_val} "
|
||||
f"in received: {text}."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return False
|
||||
msg = (
|
||||
f"BooleanOutputParser expected output value to include either "
|
||||
f"{self.true_val} or {self.false_val}. Received {text}."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""Snake-case string identifier for an output parser type."""
|
||||
return "boolean_output_parser"
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.utils import pre_init
|
||||
from typing_extensions import override
|
||||
|
||||
_MIN_PARSERS = 2
|
||||
|
||||
|
||||
class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
|
||||
"""Combine multiple output parsers into one."""
|
||||
|
||||
parsers: list[BaseOutputParser]
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@pre_init
|
||||
def validate_parsers(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate the parsers."""
|
||||
parsers = values["parsers"]
|
||||
if len(parsers) < _MIN_PARSERS:
|
||||
msg = "Must have at least two parsers"
|
||||
raise ValueError(msg)
|
||||
for parser in parsers:
|
||||
if parser._type == "combining": # noqa: SLF001
|
||||
msg = "Cannot nest combining parsers"
|
||||
raise ValueError(msg)
|
||||
if parser._type == "list": # noqa: SLF001
|
||||
msg = "Cannot combine list parsers"
|
||||
raise ValueError(msg)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""Return the type key."""
|
||||
return "combining"
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
"""Instructions on how the LLM output should be formatted."""
|
||||
initial = f"For your first output: {self.parsers[0].get_format_instructions()}"
|
||||
subsequent = "\n".join(
|
||||
f"Complete that output fully. Then produce another output, separated by two newline characters: {p.get_format_instructions()}" # noqa: E501
|
||||
for p in self.parsers[1:]
|
||||
)
|
||||
return f"{initial}\n{subsequent}"
|
||||
|
||||
def parse(self, text: str) -> dict[str, Any]:
|
||||
"""Parse the output of an LLM call."""
|
||||
texts = text.split("\n\n")
|
||||
output = {}
|
||||
for txt, parser in zip(texts, self.parsers, strict=False):
|
||||
output.update(parser.parse(txt.strip()))
|
||||
return output
|
||||
@@ -0,0 +1,58 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.utils import comma_list
|
||||
|
||||
|
||||
class DatetimeOutputParser(BaseOutputParser[datetime]):
|
||||
"""Parse the output of an LLM call to a datetime."""
|
||||
|
||||
format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
|
||||
"""The string value that is used as the datetime format.
|
||||
|
||||
Update this to match the desired datetime format for your application.
|
||||
"""
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
"""Returns the format instructions for the given format."""
|
||||
if self.format == "%Y-%m-%dT%H:%M:%S.%fZ":
|
||||
examples = comma_list(
|
||||
[
|
||||
"2023-07-04T14:30:00.000000Z",
|
||||
"1999-12-31T23:59:59.999999Z",
|
||||
"2025-01-01T00:00:00.000000Z",
|
||||
],
|
||||
)
|
||||
else:
|
||||
try:
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
examples = comma_list(
|
||||
[
|
||||
now.strftime(self.format),
|
||||
(now.replace(year=now.year - 1)).strftime(self.format),
|
||||
(now - timedelta(days=1)).strftime(self.format),
|
||||
],
|
||||
)
|
||||
except ValueError:
|
||||
# Fallback if the format is very unusual
|
||||
examples = f"e.g., a valid string in the format {self.format}"
|
||||
|
||||
return (
|
||||
f"Write a datetime string that matches the "
|
||||
f"following pattern: '{self.format}'.\n\n"
|
||||
f"Examples: {examples}\n\n"
|
||||
f"Return ONLY this string, no other words!"
|
||||
)
|
||||
|
||||
def parse(self, response: str) -> datetime:
|
||||
"""Parse a string into a datetime object."""
|
||||
try:
|
||||
return datetime.strptime(response.strip(), self.format) # noqa: DTZ007
|
||||
except ValueError as e:
|
||||
msg = f"Could not parse datetime string: {response}"
|
||||
raise OutputParserException(msg) from e
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "datetime"
|
||||
@@ -0,0 +1,45 @@
|
||||
from enum import Enum
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.utils import pre_init
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class EnumOutputParser(BaseOutputParser[Enum]):
|
||||
"""Parse an output that is one of a set of values."""
|
||||
|
||||
enum: type[Enum]
|
||||
"""The enum to parse. Its values must be strings."""
|
||||
|
||||
@pre_init
|
||||
def _raise_deprecation(cls, values: dict) -> dict:
|
||||
enum = values["enum"]
|
||||
if not all(isinstance(e.value, str) for e in enum):
|
||||
msg = "Enum values must be strings"
|
||||
raise ValueError(msg)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _valid_values(self) -> list[str]:
|
||||
return [e.value for e in self.enum]
|
||||
|
||||
@override
|
||||
def parse(self, response: str) -> Enum:
|
||||
try:
|
||||
return self.enum(response.strip())
|
||||
except ValueError as e:
|
||||
msg = (
|
||||
f"Response '{response}' is not one of the "
|
||||
f"expected values: {self._valid_values}"
|
||||
)
|
||||
raise OutputParserException(msg) from e
|
||||
|
||||
@override
|
||||
def get_format_instructions(self) -> str:
|
||||
return f"Select one of the following options: {', '.join(self._valid_values)}"
|
||||
|
||||
@property
|
||||
@override
|
||||
def OutputType(self) -> type[Enum]:
|
||||
return self.enum
|
||||
@@ -0,0 +1,45 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.output_parsers.ernie_functions import (
|
||||
JsonKeyOutputFunctionsParser,
|
||||
JsonOutputFunctionsParser,
|
||||
OutputFunctionsParser,
|
||||
PydanticAttrOutputFunctionsParser,
|
||||
PydanticOutputFunctionsParser,
|
||||
)
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"JsonKeyOutputFunctionsParser": (
|
||||
"langchain_community.output_parsers.ernie_functions"
|
||||
),
|
||||
"JsonOutputFunctionsParser": "langchain_community.output_parsers.ernie_functions",
|
||||
"OutputFunctionsParser": "langchain_community.output_parsers.ernie_functions",
|
||||
"PydanticAttrOutputFunctionsParser": (
|
||||
"langchain_community.output_parsers.ernie_functions"
|
||||
),
|
||||
"PydanticOutputFunctionsParser": (
|
||||
"langchain_community.output_parsers.ernie_functions"
|
||||
),
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"JsonKeyOutputFunctionsParser",
|
||||
"JsonOutputFunctionsParser",
|
||||
"OutputFunctionsParser",
|
||||
"PydanticAttrOutputFunctionsParser",
|
||||
"PydanticOutputFunctionsParser",
|
||||
]
|
||||
156
venv/Lib/site-packages/langchain_classic/output_parsers/fix.py
Normal file
156
venv/Lib/site-packages/langchain_classic/output_parsers/fix.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any, TypeVar
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.runnables import Runnable, RunnableSerializable
|
||||
from pydantic import SkipValidation
|
||||
from typing_extensions import TypedDict, override
|
||||
|
||||
from langchain_classic.output_parsers.prompts import NAIVE_FIX_PROMPT
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class OutputFixingParserRetryChainInput(TypedDict, total=False):
|
||||
"""Input for the retry chain of the OutputFixingParser."""
|
||||
|
||||
instructions: str
|
||||
completion: str
|
||||
error: str
|
||||
|
||||
|
||||
class OutputFixingParser(BaseOutputParser[T]):
|
||||
"""Wrap a parser and try to fix parsing errors."""
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
parser: Annotated[Any, SkipValidation()]
|
||||
"""The parser to use to parse the output."""
|
||||
# Should be an LLMChain but we want to avoid top-level imports from
|
||||
# langchain_classic.chains
|
||||
retry_chain: Annotated[
|
||||
RunnableSerializable[OutputFixingParserRetryChainInput, str] | Any,
|
||||
SkipValidation(),
|
||||
]
|
||||
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
|
||||
max_retries: int = 1
|
||||
"""The maximum number of times to retry the parse."""
|
||||
legacy: bool = True
|
||||
"""Whether to use the run or arun method of the retry_chain."""
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: Runnable,
|
||||
parser: BaseOutputParser[T],
|
||||
prompt: BasePromptTemplate = NAIVE_FIX_PROMPT,
|
||||
max_retries: int = 1,
|
||||
) -> OutputFixingParser[T]:
|
||||
"""Create an OutputFixingParser from a language model and a parser.
|
||||
|
||||
Args:
|
||||
llm: llm to use for fixing
|
||||
parser: parser to use for parsing
|
||||
prompt: prompt to use for fixing
|
||||
max_retries: Maximum number of retries to parse.
|
||||
|
||||
Returns:
|
||||
OutputFixingParser
|
||||
"""
|
||||
chain = prompt | llm | StrOutputParser()
|
||||
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
||||
|
||||
@override
|
||||
def parse(self, completion: str) -> T:
|
||||
retries = 0
|
||||
|
||||
while retries <= self.max_retries:
|
||||
try:
|
||||
return self.parser.parse(completion)
|
||||
except OutputParserException as e:
|
||||
if retries == self.max_retries:
|
||||
raise
|
||||
retries += 1
|
||||
if self.legacy and hasattr(self.retry_chain, "run"):
|
||||
completion = self.retry_chain.run(
|
||||
instructions=self.parser.get_format_instructions(),
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
else:
|
||||
try:
|
||||
completion = self.retry_chain.invoke(
|
||||
{
|
||||
"instructions": self.parser.get_format_instructions(),
|
||||
"completion": completion,
|
||||
"error": repr(e),
|
||||
},
|
||||
)
|
||||
except (NotImplementedError, AttributeError):
|
||||
# Case: self.parser does not have get_format_instructions
|
||||
completion = self.retry_chain.invoke(
|
||||
{
|
||||
"completion": completion,
|
||||
"error": repr(e),
|
||||
},
|
||||
)
|
||||
|
||||
msg = "Failed to parse"
|
||||
raise OutputParserException(msg)
|
||||
|
||||
@override
|
||||
async def aparse(self, completion: str) -> T:
|
||||
retries = 0
|
||||
|
||||
while retries <= self.max_retries:
|
||||
try:
|
||||
return await self.parser.aparse(completion)
|
||||
except OutputParserException as e:
|
||||
if retries == self.max_retries:
|
||||
raise
|
||||
retries += 1
|
||||
if self.legacy and hasattr(self.retry_chain, "arun"):
|
||||
completion = await self.retry_chain.arun(
|
||||
instructions=self.parser.get_format_instructions(),
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
else:
|
||||
try:
|
||||
completion = await self.retry_chain.ainvoke(
|
||||
{
|
||||
"instructions": self.parser.get_format_instructions(),
|
||||
"completion": completion,
|
||||
"error": repr(e),
|
||||
},
|
||||
)
|
||||
except (NotImplementedError, AttributeError):
|
||||
# Case: self.parser does not have get_format_instructions
|
||||
completion = await self.retry_chain.ainvoke(
|
||||
{
|
||||
"completion": completion,
|
||||
"error": repr(e),
|
||||
},
|
||||
)
|
||||
|
||||
msg = "Failed to parse"
|
||||
raise OutputParserException(msg)
|
||||
|
||||
@override
|
||||
def get_format_instructions(self) -> str:
|
||||
return self.parser.get_format_instructions()
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "output_fixing"
|
||||
|
||||
@property
|
||||
@override
|
||||
def OutputType(self) -> type[T]:
|
||||
return self.parser.OutputType
|
||||
@@ -0,0 +1,79 @@
|
||||
STRUCTURED_FORMAT_INSTRUCTIONS = """The output should be a markdown code snippet formatted in the following schema, including the leading and trailing "```json" and "```":
|
||||
|
||||
```json
|
||||
{{
|
||||
{format}
|
||||
}}
|
||||
```""" # noqa: E501
|
||||
|
||||
STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS = """
|
||||
```json
|
||||
{{
|
||||
{format}
|
||||
}}
|
||||
```"""
|
||||
|
||||
|
||||
PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
|
||||
|
||||
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
|
||||
the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.
|
||||
|
||||
Here is the output schema:
|
||||
```
|
||||
{schema}
|
||||
```""" # noqa: E501
|
||||
|
||||
YAML_FORMAT_INSTRUCTIONS = """The output should be formatted as a YAML instance that conforms to the given JSON schema below.
|
||||
|
||||
# Examples
|
||||
## Schema
|
||||
```
|
||||
{{"title": "Players", "description": "A list of players", "type": "array", "items": {{"$ref": "#/definitions/Player"}}, "definitions": {{"Player": {{"title": "Player", "type": "object", "properties": {{"name": {{"title": "Name", "description": "Player name", "type": "string"}}, "avg": {{"title": "Avg", "description": "Batting average", "type": "number"}}}}, "required": ["name", "avg"]}}}}}}
|
||||
```
|
||||
## Well formatted instance
|
||||
```
|
||||
- name: John Doe
|
||||
avg: 0.3
|
||||
- name: Jane Maxfield
|
||||
avg: 1.4
|
||||
```
|
||||
|
||||
## Schema
|
||||
```
|
||||
{{"properties": {{"habit": {{ "description": "A common daily habit", "type": "string" }}, "sustainable_alternative": {{ "description": "An environmentally friendly alternative to the habit", "type": "string"}}}}, "required": ["habit", "sustainable_alternative"]}}
|
||||
```
|
||||
## Well formatted instance
|
||||
```
|
||||
habit: Using disposable water bottles for daily hydration.
|
||||
sustainable_alternative: Switch to a reusable water bottle to reduce plastic waste and decrease your environmental footprint.
|
||||
```
|
||||
|
||||
Please follow the standard YAML formatting conventions with an indent of 2 spaces and make sure that the data types adhere strictly to the following JSON schema:
|
||||
```
|
||||
{schema}
|
||||
```
|
||||
|
||||
Make sure to always enclose the YAML output in triple backticks (```). Please do not add anything other than valid YAML output!""" # noqa: E501
|
||||
|
||||
|
||||
PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS = """The output should be formatted as a string as the operation, followed by a colon, followed by the column or row to be queried on, followed by optional array parameters.
|
||||
1. The column names are limited to the possible columns below.
|
||||
2. Arrays must either be a comma-separated list of numbers formatted as [1,3,5], or it must be in range of numbers formatted as [0..4].
|
||||
3. Remember that arrays are optional and not necessarily required.
|
||||
4. If the column is not in the possible columns or the operation is not a valid Pandas DataFrame operation, return why it is invalid as a sentence starting with either "Invalid column" or "Invalid operation".
|
||||
|
||||
As an example, for the formats:
|
||||
1. String "column:num_legs" is a well-formatted instance which gets the column num_legs, where num_legs is a possible column.
|
||||
2. String "row:1" is a well-formatted instance which gets row 1.
|
||||
3. String "column:num_legs[1,2]" is a well-formatted instance which gets the column num_legs for rows 1 and 2, where num_legs is a possible column.
|
||||
4. String "row:1[num_legs]" is a well-formatted instance which gets row 1, but for just column num_legs, where num_legs is a possible column.
|
||||
5. String "mean:num_legs[1..3]" is a well-formatted instance which takes the mean of num_legs from rows 1 to 3, where num_legs is a possible column and mean is a valid Pandas DataFrame operation.
|
||||
6. String "do_something:num_legs" is a badly-formatted instance, where do_something is not a valid Pandas DataFrame operation.
|
||||
7. String "mean:invalid_col" is a badly-formatted instance, where invalid_col is not a possible column.
|
||||
|
||||
Here are the possible columns:
|
||||
```
|
||||
{columns}
|
||||
```
|
||||
""" # noqa: E501
|
||||
@@ -0,0 +1,15 @@
|
||||
from langchain_core.output_parsers.json import (
|
||||
SimpleJsonOutputParser,
|
||||
)
|
||||
from langchain_core.utils.json import (
|
||||
parse_and_check_json_markdown,
|
||||
parse_json_markdown,
|
||||
parse_partial_json,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SimpleJsonOutputParser",
|
||||
"parse_and_check_json_markdown",
|
||||
"parse_json_markdown",
|
||||
"parse_partial_json",
|
||||
]
|
||||
@@ -0,0 +1,13 @@
|
||||
from langchain_core.output_parsers.list import (
|
||||
CommaSeparatedListOutputParser,
|
||||
ListOutputParser,
|
||||
MarkdownListOutputParser,
|
||||
NumberedListOutputParser,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CommaSeparatedListOutputParser",
|
||||
"ListOutputParser",
|
||||
"MarkdownListOutputParser",
|
||||
"NumberedListOutputParser",
|
||||
]
|
||||
@@ -0,0 +1,22 @@
|
||||
from langchain_classic.output_parsers.regex import RegexParser
|
||||
|
||||
|
||||
def load_output_parser(config: dict) -> dict:
|
||||
"""Load an output parser.
|
||||
|
||||
Args:
|
||||
config: config dict
|
||||
|
||||
Returns:
|
||||
config dict with output parser loaded
|
||||
"""
|
||||
if "output_parsers" in config and config["output_parsers"] is not None:
|
||||
_config = config["output_parsers"]
|
||||
output_parser_type = _config["_type"]
|
||||
if output_parser_type == "regex_parser":
|
||||
output_parser = RegexParser(**_config)
|
||||
else:
|
||||
msg = f"Unsupported output parser {output_parser_type}"
|
||||
raise ValueError(msg)
|
||||
config["output_parsers"] = output_parser
|
||||
return config
|
||||
@@ -0,0 +1,13 @@
|
||||
from langchain_core.output_parsers.openai_functions import (
|
||||
JsonKeyOutputFunctionsParser,
|
||||
JsonOutputFunctionsParser,
|
||||
PydanticAttrOutputFunctionsParser,
|
||||
PydanticOutputFunctionsParser,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"JsonKeyOutputFunctionsParser",
|
||||
"JsonOutputFunctionsParser",
|
||||
"PydanticAttrOutputFunctionsParser",
|
||||
"PydanticOutputFunctionsParser",
|
||||
]
|
||||
@@ -0,0 +1,7 @@
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
JsonOutputToolsParser,
|
||||
PydanticToolsParser,
|
||||
)
|
||||
|
||||
__all__ = ["JsonOutputKeyToolsParser", "JsonOutputToolsParser", "PydanticToolsParser"]
|
||||
@@ -0,0 +1,171 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers.base import BaseOutputParser
|
||||
from pydantic import field_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.output_parsers.format_instructions import (
|
||||
PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS,
|
||||
)
|
||||
|
||||
|
||||
class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
|
||||
"""Parse an output using Pandas DataFrame format."""
|
||||
|
||||
"""The Pandas DataFrame to parse."""
|
||||
dataframe: Any
|
||||
|
||||
@field_validator("dataframe")
|
||||
@classmethod
|
||||
def _validate_dataframe(cls, val: Any) -> Any:
|
||||
import pandas as pd
|
||||
|
||||
if issubclass(type(val), pd.DataFrame):
|
||||
return val
|
||||
if pd.DataFrame(val).empty:
|
||||
msg = "DataFrame cannot be empty."
|
||||
raise ValueError(msg)
|
||||
|
||||
msg = "Wrong type for 'dataframe', must be a subclass \
|
||||
of Pandas DataFrame (pd.DataFrame)"
|
||||
raise TypeError(msg)
|
||||
|
||||
def parse_array(
|
||||
self,
|
||||
array: str,
|
||||
original_request_params: str,
|
||||
) -> tuple[list[int | str], str]:
|
||||
"""Parse the array from the request parameters.
|
||||
|
||||
Args:
|
||||
array: The array string to parse.
|
||||
original_request_params: The original request parameters string.
|
||||
|
||||
Returns:
|
||||
A tuple containing the parsed array and the stripped request parameters.
|
||||
|
||||
Raises:
|
||||
OutputParserException: If the array format is invalid or cannot be parsed.
|
||||
"""
|
||||
parsed_array: list[int | str] = []
|
||||
|
||||
# Check if the format is [1,3,5]
|
||||
if re.match(r"\[\d+(,\s*\d+)*\]", array):
|
||||
parsed_array = [int(i) for i in re.findall(r"\d+", array)]
|
||||
# Check if the format is [1..5]
|
||||
elif re.match(r"\[(\d+)\.\.(\d+)\]", array):
|
||||
match = re.match(r"\[(\d+)\.\.(\d+)\]", array)
|
||||
if match:
|
||||
start, end = map(int, match.groups())
|
||||
parsed_array = list(range(start, end + 1))
|
||||
else:
|
||||
msg = f"Unable to parse the array provided in {array}. \
|
||||
Please check the format instructions."
|
||||
raise OutputParserException(msg)
|
||||
# Check if the format is ["column_name"]
|
||||
elif re.match(r"\[[a-zA-Z0-9_]+(?:,[a-zA-Z0-9_]+)*\]", array):
|
||||
match = re.match(r"\[[a-zA-Z0-9_]+(?:,[a-zA-Z0-9_]+)*\]", array)
|
||||
if match:
|
||||
parsed_array = list(map(str, match.group().strip("[]").split(",")))
|
||||
else:
|
||||
msg = f"Unable to parse the array provided in {array}. \
|
||||
Please check the format instructions."
|
||||
raise OutputParserException(msg)
|
||||
|
||||
# Validate the array
|
||||
if not parsed_array:
|
||||
msg = f"Invalid array format in '{original_request_params}'. \
|
||||
Please check the format instructions."
|
||||
raise OutputParserException(msg)
|
||||
if (
|
||||
isinstance(parsed_array[0], int)
|
||||
and parsed_array[-1] > self.dataframe.index.max()
|
||||
):
|
||||
msg = f"The maximum index {parsed_array[-1]} exceeds the maximum index of \
|
||||
the Pandas DataFrame {self.dataframe.index.max()}."
|
||||
raise OutputParserException(msg)
|
||||
|
||||
return parsed_array, original_request_params.split("[", maxsplit=1)[0]
|
||||
|
||||
@override
|
||||
def parse(self, request: str) -> dict[str, Any]:
|
||||
stripped_request_params = None
|
||||
splitted_request = request.strip().split(":")
|
||||
if len(splitted_request) != 2: # noqa: PLR2004
|
||||
msg = f"Request '{request}' is not correctly formatted. \
|
||||
Please refer to the format instructions."
|
||||
raise OutputParserException(msg)
|
||||
result = {}
|
||||
try:
|
||||
request_type, request_params = splitted_request
|
||||
if request_type in {"Invalid column", "Invalid operation"}:
|
||||
msg = f"{request}. Please check the format instructions."
|
||||
raise OutputParserException(msg)
|
||||
array_exists = re.search(r"(\[.*?\])", request_params)
|
||||
if array_exists:
|
||||
parsed_array, stripped_request_params = self.parse_array(
|
||||
array_exists.group(1),
|
||||
request_params,
|
||||
)
|
||||
if request_type == "column":
|
||||
filtered_df = self.dataframe[
|
||||
self.dataframe.index.isin(parsed_array)
|
||||
]
|
||||
if len(parsed_array) == 1:
|
||||
result[stripped_request_params] = filtered_df[
|
||||
stripped_request_params
|
||||
].iloc[parsed_array[0]]
|
||||
else:
|
||||
result[stripped_request_params] = filtered_df[
|
||||
stripped_request_params
|
||||
]
|
||||
elif request_type == "row":
|
||||
filtered_df = self.dataframe[
|
||||
self.dataframe.columns.intersection(parsed_array)
|
||||
]
|
||||
if len(parsed_array) == 1:
|
||||
result[stripped_request_params] = filtered_df.iloc[
|
||||
int(stripped_request_params)
|
||||
][parsed_array[0]]
|
||||
else:
|
||||
result[stripped_request_params] = filtered_df.iloc[
|
||||
int(stripped_request_params)
|
||||
]
|
||||
else:
|
||||
filtered_df = self.dataframe[
|
||||
self.dataframe.index.isin(parsed_array)
|
||||
]
|
||||
result[request_type] = getattr(
|
||||
filtered_df[stripped_request_params],
|
||||
request_type,
|
||||
)()
|
||||
elif request_type == "column":
|
||||
result[request_params] = self.dataframe[request_params]
|
||||
elif request_type == "row":
|
||||
result[request_params] = self.dataframe.iloc[int(request_params)]
|
||||
else:
|
||||
result[request_type] = getattr(
|
||||
self.dataframe[request_params],
|
||||
request_type,
|
||||
)()
|
||||
except (AttributeError, IndexError, KeyError) as e:
|
||||
if request_type not in {"column", "row"}:
|
||||
msg = f"Unsupported request type '{request_type}'. \
|
||||
Please check the format instructions."
|
||||
raise OutputParserException(msg) from e
|
||||
msg = f"""Requested index {
|
||||
request_params
|
||||
if stripped_request_params is None
|
||||
else stripped_request_params
|
||||
} is out of bounds."""
|
||||
raise OutputParserException(msg) from e
|
||||
|
||||
return result
|
||||
|
||||
@override
|
||||
def get_format_instructions(self) -> str:
|
||||
return PANDAS_DATAFRAME_FORMAT_INSTRUCTIONS.format(
|
||||
columns=", ".join(self.dataframe.columns),
|
||||
)
|
||||
@@ -0,0 +1,21 @@
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
|
||||
NAIVE_FIX = """Instructions:
|
||||
--------------
|
||||
{instructions}
|
||||
--------------
|
||||
Completion:
|
||||
--------------
|
||||
{completion}
|
||||
--------------
|
||||
|
||||
Above, the Completion did not satisfy the constraints given in the Instructions.
|
||||
Error:
|
||||
--------------
|
||||
{error}
|
||||
--------------
|
||||
|
||||
Please try again. Please only respond with an answer that satisfies the constraints laid out in the Instructions:""" # noqa: E501
|
||||
|
||||
|
||||
NAIVE_FIX_PROMPT = PromptTemplate.from_template(NAIVE_FIX)
|
||||
@@ -0,0 +1,3 @@
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
|
||||
__all__ = ["PydanticOutputParser"]
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.output_parsers.rail_parser import GuardrailsOutputParser
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"GuardrailsOutputParser": "langchain_community.output_parsers.rail_parser",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"GuardrailsOutputParser",
|
||||
]
|
||||
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class RegexParser(BaseOutputParser[dict[str, str]]):
|
||||
"""Parse the output of an LLM call using a regex."""
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
regex: str
|
||||
"""The regex to use to parse the output."""
|
||||
output_keys: list[str]
|
||||
"""The keys to use for the output."""
|
||||
default_output_key: str | None = None
|
||||
"""The default key to use for the output."""
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""Return the type key."""
|
||||
return "regex_parser"
|
||||
|
||||
def parse(self, text: str) -> dict[str, str]:
|
||||
"""Parse the output of an LLM call."""
|
||||
match = re.search(self.regex, text)
|
||||
if match:
|
||||
return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)}
|
||||
if self.default_output_key is None:
|
||||
msg = f"Could not parse output: {text}"
|
||||
raise ValueError(msg)
|
||||
return {
|
||||
key: text if key == self.default_output_key else ""
|
||||
for key in self.output_keys
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
|
||||
|
||||
class RegexDictParser(BaseOutputParser[dict[str, str]]):
|
||||
"""Parse the output of an LLM call into a Dictionary using a regex."""
|
||||
|
||||
regex_pattern: str = r"{}:\s?([^.'\n']*)\.?"
|
||||
"""The regex pattern to use to parse the output."""
|
||||
output_key_to_format: dict[str, str]
|
||||
"""The keys to use for the output."""
|
||||
no_update_value: str | None = None
|
||||
"""The default key to use for the output."""
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""Return the type key."""
|
||||
return "regex_dict_parser"
|
||||
|
||||
def parse(self, text: str) -> dict[str, str]:
|
||||
"""Parse the output of an LLM call."""
|
||||
result = {}
|
||||
for output_key, expected_format in self.output_key_to_format.items():
|
||||
specific_regex = self.regex_pattern.format(re.escape(expected_format))
|
||||
matches = re.findall(specific_regex, text)
|
||||
if not matches:
|
||||
msg = (
|
||||
f"No match found for output key: {output_key} with expected format \
|
||||
{expected_format} on text {text}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if len(matches) > 1:
|
||||
msg = f"Multiple matches found for output key: {output_key} with \
|
||||
expected format {expected_format} on text {text}"
|
||||
raise ValueError(msg)
|
||||
if self.no_update_value is not None and matches[0] == self.no_update_value:
|
||||
continue
|
||||
result[output_key] = matches[0]
|
||||
return result
|
||||
315
venv/Lib/site-packages/langchain_classic/output_parsers/retry.py
Normal file
315
venv/Lib/site-packages/langchain_classic/output_parsers/retry.py
Normal file
@@ -0,0 +1,315 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any, TypeVar
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.runnables import RunnableSerializable
|
||||
from pydantic import SkipValidation
|
||||
from typing_extensions import TypedDict, override
|
||||
|
||||
NAIVE_COMPLETION_RETRY = """Prompt:
|
||||
{prompt}
|
||||
Completion:
|
||||
{completion}
|
||||
|
||||
Above, the Completion did not satisfy the constraints given in the Prompt.
|
||||
Please try again:"""
|
||||
|
||||
NAIVE_COMPLETION_RETRY_WITH_ERROR = """Prompt:
|
||||
{prompt}
|
||||
Completion:
|
||||
{completion}
|
||||
|
||||
Above, the Completion did not satisfy the constraints given in the Prompt.
|
||||
Details: {error}
|
||||
Please try again:"""
|
||||
|
||||
NAIVE_RETRY_PROMPT = PromptTemplate.from_template(NAIVE_COMPLETION_RETRY)
|
||||
NAIVE_RETRY_WITH_ERROR_PROMPT = PromptTemplate.from_template(
|
||||
NAIVE_COMPLETION_RETRY_WITH_ERROR,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RetryOutputParserRetryChainInput(TypedDict):
|
||||
"""Retry chain input for RetryOutputParser."""
|
||||
|
||||
prompt: str
|
||||
completion: str
|
||||
|
||||
|
||||
class RetryWithErrorOutputParserRetryChainInput(TypedDict):
|
||||
"""Retry chain input for RetryWithErrorOutputParser."""
|
||||
|
||||
prompt: str
|
||||
completion: str
|
||||
error: str
|
||||
|
||||
|
||||
class RetryOutputParser(BaseOutputParser[T]):
|
||||
"""Wrap a parser and try to fix parsing errors.
|
||||
|
||||
Does this by passing the original prompt and the completion to another
|
||||
LLM, and telling it the completion did not satisfy criteria in the prompt.
|
||||
"""
|
||||
|
||||
parser: Annotated[BaseOutputParser[T], SkipValidation()]
|
||||
"""The parser to use to parse the output."""
|
||||
# Should be an LLMChain but we want to avoid top-level imports from
|
||||
# langchain_classic.chains
|
||||
retry_chain: Annotated[
|
||||
RunnableSerializable[RetryOutputParserRetryChainInput, str] | Any,
|
||||
SkipValidation(),
|
||||
]
|
||||
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
|
||||
max_retries: int = 1
|
||||
"""The maximum number of times to retry the parse."""
|
||||
legacy: bool = True
|
||||
"""Whether to use the run or arun method of the retry_chain."""
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
parser: BaseOutputParser[T],
|
||||
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
|
||||
max_retries: int = 1,
|
||||
) -> RetryOutputParser[T]:
|
||||
"""Create an RetryOutputParser from a language model and a parser.
|
||||
|
||||
Args:
|
||||
llm: llm to use for fixing
|
||||
parser: parser to use for parsing
|
||||
prompt: prompt to use for fixing
|
||||
max_retries: Maximum number of retries to parse.
|
||||
|
||||
Returns:
|
||||
RetryOutputParser
|
||||
"""
|
||||
chain = prompt | llm | StrOutputParser()
|
||||
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
||||
|
||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||
"""Parse the output of an LLM call using a wrapped parser.
|
||||
|
||||
Args:
|
||||
completion: The chain completion to parse.
|
||||
prompt_value: The prompt to use to parse the completion.
|
||||
|
||||
Returns:
|
||||
The parsed completion.
|
||||
"""
|
||||
retries = 0
|
||||
|
||||
while retries <= self.max_retries:
|
||||
try:
|
||||
return self.parser.parse(completion)
|
||||
except OutputParserException:
|
||||
if retries == self.max_retries:
|
||||
raise
|
||||
retries += 1
|
||||
if self.legacy and hasattr(self.retry_chain, "run"):
|
||||
completion = self.retry_chain.run(
|
||||
prompt=prompt_value.to_string(),
|
||||
completion=completion,
|
||||
)
|
||||
else:
|
||||
completion = self.retry_chain.invoke(
|
||||
{
|
||||
"prompt": prompt_value.to_string(),
|
||||
"completion": completion,
|
||||
},
|
||||
)
|
||||
|
||||
msg = "Failed to parse"
|
||||
raise OutputParserException(msg)
|
||||
|
||||
async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||
"""Parse the output of an LLM call using a wrapped parser.
|
||||
|
||||
Args:
|
||||
completion: The chain completion to parse.
|
||||
prompt_value: The prompt to use to parse the completion.
|
||||
|
||||
Returns:
|
||||
The parsed completion.
|
||||
"""
|
||||
retries = 0
|
||||
|
||||
while retries <= self.max_retries:
|
||||
try:
|
||||
return await self.parser.aparse(completion)
|
||||
except OutputParserException as e:
|
||||
if retries == self.max_retries:
|
||||
raise
|
||||
retries += 1
|
||||
if self.legacy and hasattr(self.retry_chain, "arun"):
|
||||
completion = await self.retry_chain.arun(
|
||||
prompt=prompt_value.to_string(),
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
else:
|
||||
completion = await self.retry_chain.ainvoke(
|
||||
{
|
||||
"prompt": prompt_value.to_string(),
|
||||
"completion": completion,
|
||||
},
|
||||
)
|
||||
|
||||
msg = "Failed to parse"
|
||||
raise OutputParserException(msg)
|
||||
|
||||
@override
|
||||
def parse(self, completion: str) -> T:
|
||||
msg = "This OutputParser can only be called by the `parse_with_prompt` method."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
@override
|
||||
def get_format_instructions(self) -> str:
|
||||
return self.parser.get_format_instructions()
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "retry"
|
||||
|
||||
@property
|
||||
@override
|
||||
def OutputType(self) -> type[T]:
|
||||
return self.parser.OutputType
|
||||
|
||||
|
||||
class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||
"""Wrap a parser and try to fix parsing errors.
|
||||
|
||||
Does this by passing the original prompt, the completion, AND the error
|
||||
that was raised to another language model and telling it that the completion
|
||||
did not work, and raised the given error. Differs from RetryOutputParser
|
||||
in that this implementation provides the error that was raised back to the
|
||||
LLM, which in theory should give it more information on how to fix it.
|
||||
"""
|
||||
|
||||
parser: Annotated[BaseOutputParser[T], SkipValidation()]
|
||||
"""The parser to use to parse the output."""
|
||||
# Should be an LLMChain but we want to avoid top-level imports from
|
||||
# langchain_classic.chains
|
||||
retry_chain: Annotated[
|
||||
RunnableSerializable[RetryWithErrorOutputParserRetryChainInput, str] | Any,
|
||||
SkipValidation(),
|
||||
]
|
||||
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
|
||||
max_retries: int = 1
|
||||
"""The maximum number of times to retry the parse."""
|
||||
legacy: bool = True
|
||||
"""Whether to use the run or arun method of the retry_chain."""
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
parser: BaseOutputParser[T],
|
||||
prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
|
||||
max_retries: int = 1,
|
||||
) -> RetryWithErrorOutputParser[T]:
|
||||
"""Create a RetryWithErrorOutputParser from an LLM.
|
||||
|
||||
Args:
|
||||
llm: The LLM to use to retry the completion.
|
||||
parser: The parser to use to parse the output.
|
||||
prompt: The prompt to use to retry the completion.
|
||||
max_retries: The maximum number of times to retry the completion.
|
||||
|
||||
Returns:
|
||||
A RetryWithErrorOutputParser.
|
||||
"""
|
||||
chain = prompt | llm | StrOutputParser()
|
||||
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
||||
|
||||
@override
|
||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||
retries = 0
|
||||
|
||||
while retries <= self.max_retries:
|
||||
try:
|
||||
return self.parser.parse(completion)
|
||||
except OutputParserException as e:
|
||||
if retries == self.max_retries:
|
||||
raise
|
||||
retries += 1
|
||||
if self.legacy and hasattr(self.retry_chain, "run"):
|
||||
completion = self.retry_chain.run(
|
||||
prompt=prompt_value.to_string(),
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
else:
|
||||
completion = self.retry_chain.invoke(
|
||||
{
|
||||
"completion": completion,
|
||||
"prompt": prompt_value.to_string(),
|
||||
"error": repr(e),
|
||||
},
|
||||
)
|
||||
|
||||
msg = "Failed to parse"
|
||||
raise OutputParserException(msg)
|
||||
|
||||
async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||
"""Parse the output of an LLM call using a wrapped parser.
|
||||
|
||||
Args:
|
||||
completion: The chain completion to parse.
|
||||
prompt_value: The prompt to use to parse the completion.
|
||||
|
||||
Returns:
|
||||
The parsed completion.
|
||||
"""
|
||||
retries = 0
|
||||
|
||||
while retries <= self.max_retries:
|
||||
try:
|
||||
return await self.parser.aparse(completion)
|
||||
except OutputParserException as e:
|
||||
if retries == self.max_retries:
|
||||
raise
|
||||
retries += 1
|
||||
if self.legacy and hasattr(self.retry_chain, "arun"):
|
||||
completion = await self.retry_chain.arun(
|
||||
prompt=prompt_value.to_string(),
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
else:
|
||||
completion = await self.retry_chain.ainvoke(
|
||||
{
|
||||
"prompt": prompt_value.to_string(),
|
||||
"completion": completion,
|
||||
"error": repr(e),
|
||||
},
|
||||
)
|
||||
|
||||
msg = "Failed to parse"
|
||||
raise OutputParserException(msg)
|
||||
|
||||
@override
|
||||
def parse(self, completion: str) -> T:
|
||||
msg = "This OutputParser can only be called by the `parse_with_prompt` method."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
@override
|
||||
def get_format_instructions(self) -> str:
|
||||
return self.parser.get_format_instructions()
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "retry_with_error"
|
||||
|
||||
@property
|
||||
@override
|
||||
def OutputType(self) -> type[T]:
|
||||
return self.parser.OutputType
|
||||
@@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.output_parsers.json import parse_and_check_json_markdown
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.output_parsers.format_instructions import (
|
||||
STRUCTURED_FORMAT_INSTRUCTIONS,
|
||||
STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS,
|
||||
)
|
||||
|
||||
line_template = '\t"{name}": {type} // {description}'
|
||||
|
||||
|
||||
class ResponseSchema(BaseModel):
|
||||
"""Schema for a response from a structured output parser."""
|
||||
|
||||
name: str
|
||||
"""The name of the schema."""
|
||||
description: str
|
||||
"""The description of the schema."""
|
||||
type: str = "string"
|
||||
"""The type of the response."""
|
||||
|
||||
|
||||
def _get_sub_string(schema: ResponseSchema) -> str:
|
||||
return line_template.format(
|
||||
name=schema.name,
|
||||
description=schema.description,
|
||||
type=schema.type,
|
||||
)
|
||||
|
||||
|
||||
class StructuredOutputParser(BaseOutputParser[dict[str, Any]]):
|
||||
"""Parse the output of an LLM call to a structured output."""
|
||||
|
||||
response_schemas: list[ResponseSchema]
|
||||
"""The schemas for the response."""
|
||||
|
||||
@classmethod
|
||||
def from_response_schemas(
|
||||
cls,
|
||||
response_schemas: list[ResponseSchema],
|
||||
) -> StructuredOutputParser:
|
||||
"""Create a StructuredOutputParser from a list of ResponseSchema.
|
||||
|
||||
Args:
|
||||
response_schemas: The schemas for the response.
|
||||
|
||||
Returns:
|
||||
An instance of StructuredOutputParser.
|
||||
"""
|
||||
return cls(response_schemas=response_schemas)
|
||||
|
||||
def get_format_instructions(
|
||||
self,
|
||||
only_json: bool = False, # noqa: FBT001,FBT002
|
||||
) -> str:
|
||||
"""Get format instructions for the output parser.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from langchain_classic.output_parsers.structured import (
|
||||
StructuredOutputParser, ResponseSchema
|
||||
)
|
||||
|
||||
response_schemas = [
|
||||
ResponseSchema(
|
||||
name="foo",
|
||||
description="a list of strings",
|
||||
type="List[string]"
|
||||
),
|
||||
ResponseSchema(
|
||||
name="bar",
|
||||
description="a string",
|
||||
type="string"
|
||||
),
|
||||
]
|
||||
|
||||
parser = StructuredOutputParser.from_response_schemas(response_schemas)
|
||||
|
||||
print(parser.get_format_instructions()) # noqa: T201
|
||||
|
||||
output:
|
||||
# The output should be a Markdown code snippet formatted in the following
|
||||
# schema, including the leading and trailing "```json" and "```":
|
||||
#
|
||||
# ```json
|
||||
# {
|
||||
# "foo": List[string] // a list of strings
|
||||
# "bar": string // a string
|
||||
# }
|
||||
# ```
|
||||
|
||||
Args:
|
||||
only_json: If `True`, only the json in the Markdown code snippet
|
||||
will be returned, without the introducing text.
|
||||
"""
|
||||
schema_str = "\n".join(
|
||||
[_get_sub_string(schema) for schema in self.response_schemas],
|
||||
)
|
||||
if only_json:
|
||||
return STRUCTURED_FORMAT_SIMPLE_INSTRUCTIONS.format(format=schema_str)
|
||||
return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
|
||||
|
||||
@override
|
||||
def parse(self, text: str) -> dict[str, Any]:
|
||||
expected_keys = [rs.name for rs in self.response_schemas]
|
||||
return parse_and_check_json_markdown(text, expected_keys)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "structured"
|
||||
@@ -0,0 +1,3 @@
|
||||
from langchain_core.output_parsers.xml import XMLOutputParser
|
||||
|
||||
__all__ = ["XMLOutputParser"]
|
||||
@@ -0,0 +1,69 @@
|
||||
import json
|
||||
import re
|
||||
from typing import TypeVar
|
||||
|
||||
import yaml
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.output_parsers.format_instructions import (
|
||||
YAML_FORMAT_INSTRUCTIONS,
|
||||
)
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class YamlOutputParser(BaseOutputParser[T]):
|
||||
"""Parse YAML output using a Pydantic model."""
|
||||
|
||||
pydantic_object: type[T]
|
||||
"""The Pydantic model to parse."""
|
||||
pattern: re.Pattern = re.compile(
|
||||
r"^```(?:ya?ml)?(?P<yaml>[^`]*)",
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
"""Regex pattern to match yaml code blocks
|
||||
within triple backticks with optional yaml or yml prefix."""
|
||||
|
||||
@override
|
||||
def parse(self, text: str) -> T:
|
||||
try:
|
||||
# Greedy search for 1st yaml candidate.
|
||||
match = re.search(self.pattern, text.strip())
|
||||
# If no backticks were present, try to parse the entire output as yaml.
|
||||
yaml_str = match.group("yaml") if match else text
|
||||
|
||||
json_object = yaml.safe_load(yaml_str)
|
||||
return self.pydantic_object.model_validate(json_object)
|
||||
|
||||
except (yaml.YAMLError, ValidationError) as e:
|
||||
name = self.pydantic_object.__name__
|
||||
msg = f"Failed to parse {name} from completion {text}. Got: {e}"
|
||||
raise OutputParserException(msg, llm_output=text) from e
|
||||
|
||||
@override
|
||||
def get_format_instructions(self) -> str:
|
||||
# Copy schema to avoid altering original Pydantic schema.
|
||||
schema = dict(self.pydantic_object.model_json_schema().items())
|
||||
|
||||
# Remove extraneous fields.
|
||||
reduced_schema = schema
|
||||
if "title" in reduced_schema:
|
||||
del reduced_schema["title"]
|
||||
if "type" in reduced_schema:
|
||||
del reduced_schema["type"]
|
||||
# Ensure yaml in context is well-formed with double quotes.
|
||||
schema_str = json.dumps(reduced_schema)
|
||||
|
||||
return YAML_FORMAT_INSTRUCTIONS.format(schema=schema_str)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "yaml"
|
||||
|
||||
@property
|
||||
@override
|
||||
def OutputType(self) -> type[T]:
|
||||
return self.pydantic_object
|
||||
Reference in New Issue
Block a user