initial commit
This commit is contained in:
126
venv/Lib/site-packages/langchain_classic/memory/__init__.py
Normal file
126
venv/Lib/site-packages/langchain_classic/memory/__init__.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""**Memory** maintains Chain state, incorporating context from past runs."""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
from langchain_classic.memory.buffer import (
|
||||
ConversationBufferMemory,
|
||||
ConversationStringBufferMemory,
|
||||
)
|
||||
from langchain_classic.memory.buffer_window import ConversationBufferWindowMemory
|
||||
from langchain_classic.memory.combined import CombinedMemory
|
||||
from langchain_classic.memory.entity import (
|
||||
ConversationEntityMemory,
|
||||
InMemoryEntityStore,
|
||||
RedisEntityStore,
|
||||
SQLiteEntityStore,
|
||||
UpstashRedisEntityStore,
|
||||
)
|
||||
from langchain_classic.memory.readonly import ReadOnlySharedMemory
|
||||
from langchain_classic.memory.simple import SimpleMemory
|
||||
from langchain_classic.memory.summary import ConversationSummaryMemory
|
||||
from langchain_classic.memory.summary_buffer import ConversationSummaryBufferMemory
|
||||
from langchain_classic.memory.token_buffer import ConversationTokenBufferMemory
|
||||
from langchain_classic.memory.vectorstore import VectorStoreRetrieverMemory
|
||||
from langchain_classic.memory.vectorstore_token_buffer_memory import (
|
||||
ConversationVectorStoreTokenBufferMemory, # avoid circular import
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import (
|
||||
AstraDBChatMessageHistory,
|
||||
CassandraChatMessageHistory,
|
||||
ChatMessageHistory,
|
||||
CosmosDBChatMessageHistory,
|
||||
DynamoDBChatMessageHistory,
|
||||
ElasticsearchChatMessageHistory,
|
||||
FileChatMessageHistory,
|
||||
MomentoChatMessageHistory,
|
||||
MongoDBChatMessageHistory,
|
||||
PostgresChatMessageHistory,
|
||||
RedisChatMessageHistory,
|
||||
SingleStoreDBChatMessageHistory,
|
||||
SQLChatMessageHistory,
|
||||
StreamlitChatMessageHistory,
|
||||
UpstashRedisChatMessageHistory,
|
||||
XataChatMessageHistory,
|
||||
ZepChatMessageHistory,
|
||||
)
|
||||
from langchain_community.memory.kg import ConversationKGMemory
|
||||
from langchain_community.memory.motorhead_memory import MotorheadMemory
|
||||
from langchain_community.memory.zep_memory import ZepMemory
|
||||
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"MotorheadMemory": "langchain_community.memory.motorhead_memory",
|
||||
"ConversationKGMemory": "langchain_community.memory.kg",
|
||||
"ZepMemory": "langchain_community.memory.zep_memory",
|
||||
"AstraDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"CassandraChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"ChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"CosmosDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"DynamoDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"ElasticsearchChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"FileChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"MomentoChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"MongoDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"PostgresChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"RedisChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"SingleStoreDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"SQLChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"StreamlitChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"UpstashRedisChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"XataChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"ZepChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AstraDBChatMessageHistory",
|
||||
"CassandraChatMessageHistory",
|
||||
"ChatMessageHistory",
|
||||
"CombinedMemory",
|
||||
"ConversationBufferMemory",
|
||||
"ConversationBufferWindowMemory",
|
||||
"ConversationEntityMemory",
|
||||
"ConversationKGMemory",
|
||||
"ConversationStringBufferMemory",
|
||||
"ConversationSummaryBufferMemory",
|
||||
"ConversationSummaryMemory",
|
||||
"ConversationTokenBufferMemory",
|
||||
"ConversationVectorStoreTokenBufferMemory",
|
||||
"CosmosDBChatMessageHistory",
|
||||
"DynamoDBChatMessageHistory",
|
||||
"ElasticsearchChatMessageHistory",
|
||||
"FileChatMessageHistory",
|
||||
"InMemoryEntityStore",
|
||||
"MomentoChatMessageHistory",
|
||||
"MongoDBChatMessageHistory",
|
||||
"MotorheadMemory",
|
||||
"PostgresChatMessageHistory",
|
||||
"ReadOnlySharedMemory",
|
||||
"RedisChatMessageHistory",
|
||||
"RedisEntityStore",
|
||||
"SQLChatMessageHistory",
|
||||
"SQLiteEntityStore",
|
||||
"SimpleMemory",
|
||||
"SingleStoreDBChatMessageHistory",
|
||||
"StreamlitChatMessageHistory",
|
||||
"UpstashRedisChatMessageHistory",
|
||||
"UpstashRedisEntityStore",
|
||||
"VectorStoreRetrieverMemory",
|
||||
"XataChatMessageHistory",
|
||||
"ZepChatMessageHistory",
|
||||
"ZepMemory",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
173
venv/Lib/site-packages/langchain_classic/memory/buffer.py
Normal file
173
venv/Lib/site-packages/langchain_classic/memory/buffer.py
Normal file
@@ -0,0 +1,173 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.utils import pre_init
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.base_memory import BaseMemory
|
||||
from langchain_classic.memory.chat_memory import BaseChatMemory
|
||||
from langchain_classic.memory.utils import get_prompt_input_key
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.1",
|
||||
removal="1.0.0",
|
||||
message=(
|
||||
"Please see the migration guide at: "
|
||||
"https://python.langchain.com/docs/versions/migrating_memory/"
|
||||
),
|
||||
)
|
||||
class ConversationBufferMemory(BaseChatMemory):
|
||||
"""A basic memory implementation that simply stores the conversation history.
|
||||
|
||||
This stores the entire conversation history in memory without any
|
||||
additional processing.
|
||||
|
||||
Note that additional processing may be required in some situations when the
|
||||
conversation history is too large to fit in the context window of the model.
|
||||
"""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
memory_key: str = "history"
|
||||
|
||||
@property
|
||||
def buffer(self) -> Any:
|
||||
"""String buffer of memory."""
|
||||
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
||||
|
||||
async def abuffer(self) -> Any:
|
||||
"""String buffer of memory."""
|
||||
return (
|
||||
await self.abuffer_as_messages()
|
||||
if self.return_messages
|
||||
else await self.abuffer_as_str()
|
||||
)
|
||||
|
||||
def _buffer_as_str(self, messages: list[BaseMessage]) -> str:
|
||||
return get_buffer_string(
|
||||
messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def buffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is True."""
|
||||
return self._buffer_as_str(self.chat_memory.messages)
|
||||
|
||||
async def abuffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is True."""
|
||||
messages = await self.chat_memory.aget_messages()
|
||||
return self._buffer_as_str(messages)
|
||||
|
||||
@property
|
||||
def buffer_as_messages(self) -> list[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
||||
return self.chat_memory.messages
|
||||
|
||||
async def abuffer_as_messages(self) -> list[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is False."""
|
||||
return await self.chat_memory.aget_messages()
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> list[str]:
|
||||
"""Will always return list of memory variables."""
|
||||
return [self.memory_key]
|
||||
|
||||
@override
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
|
||||
@override
|
||||
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return key-value pairs given the text input to the chain."""
|
||||
buffer = await self.abuffer()
|
||||
return {self.memory_key: buffer}
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.1",
|
||||
removal="1.0.0",
|
||||
message=(
|
||||
"Please see the migration guide at: "
|
||||
"https://python.langchain.com/docs/versions/migrating_memory/"
|
||||
),
|
||||
)
|
||||
class ConversationStringBufferMemory(BaseMemory):
|
||||
"""A basic memory implementation that simply stores the conversation history.
|
||||
|
||||
This stores the entire conversation history in memory without any
|
||||
additional processing.
|
||||
|
||||
Equivalent to ConversationBufferMemory but tailored more specifically
|
||||
for string-based conversations rather than chat models.
|
||||
|
||||
Note that additional processing may be required in some situations when the
|
||||
conversation history is too large to fit in the context window of the model.
|
||||
"""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
"""Prefix to use for AI generated responses."""
|
||||
buffer: str = ""
|
||||
output_key: str | None = None
|
||||
input_key: str | None = None
|
||||
memory_key: str = "history"
|
||||
|
||||
@pre_init
|
||||
def validate_chains(cls, values: dict) -> dict:
|
||||
"""Validate that return messages is not True."""
|
||||
if values.get("return_messages", False):
|
||||
msg = "return_messages must be False for ConversationStringBufferMemory"
|
||||
raise ValueError(msg)
|
||||
return values
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> list[str]:
|
||||
"""Will always return list of memory variables."""
|
||||
return [self.memory_key]
|
||||
|
||||
@override
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
|
||||
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]:
|
||||
"""Return history buffer."""
|
||||
return self.load_memory_variables(inputs)
|
||||
|
||||
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
if self.input_key is None:
|
||||
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
|
||||
else:
|
||||
prompt_input_key = self.input_key
|
||||
if self.output_key is None:
|
||||
if len(outputs) != 1:
|
||||
msg = f"One output key expected, got {outputs.keys()}"
|
||||
raise ValueError(msg)
|
||||
output_key = next(iter(outputs.keys()))
|
||||
else:
|
||||
output_key = self.output_key
|
||||
human = f"{self.human_prefix}: " + inputs[prompt_input_key]
|
||||
ai = f"{self.ai_prefix}: " + outputs[output_key]
|
||||
self.buffer += f"\n{human}\n{ai}"
|
||||
|
||||
async def asave_context(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
outputs: dict[str, str],
|
||||
) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
return self.save_context(inputs, outputs)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
self.buffer = ""
|
||||
|
||||
@override
|
||||
async def aclear(self) -> None:
|
||||
self.clear()
|
||||
@@ -0,0 +1,59 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.memory.chat_memory import BaseChatMemory
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.1",
|
||||
removal="1.0.0",
|
||||
message=(
|
||||
"Please see the migration guide at: "
|
||||
"https://python.langchain.com/docs/versions/migrating_memory/"
|
||||
),
|
||||
)
|
||||
class ConversationBufferWindowMemory(BaseChatMemory):
|
||||
"""Use to keep track of the last k turns of a conversation.
|
||||
|
||||
If the number of messages in the conversation is more than the maximum number
|
||||
of messages to keep, the oldest messages are dropped.
|
||||
"""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
memory_key: str = "history"
|
||||
k: int = 5
|
||||
"""Number of messages to store in buffer."""
|
||||
|
||||
@property
|
||||
def buffer(self) -> str | list[BaseMessage]:
|
||||
"""String buffer of memory."""
|
||||
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
||||
|
||||
@property
|
||||
def buffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is False."""
|
||||
messages = self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
|
||||
return get_buffer_string(
|
||||
messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def buffer_as_messages(self) -> list[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is True."""
|
||||
return self.chat_memory.messages[-self.k * 2 :] if self.k > 0 else []
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> list[str]:
|
||||
"""Will always return list of memory variables."""
|
||||
return [self.memory_key]
|
||||
|
||||
@override
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
104
venv/Lib/site-packages/langchain_classic/memory/chat_memory.py
Normal file
104
venv/Lib/site-packages/langchain_classic/memory/chat_memory.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.chat_history import (
|
||||
BaseChatMessageHistory,
|
||||
InMemoryChatMessageHistory,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_classic.base_memory import BaseMemory
|
||||
from langchain_classic.memory.utils import get_prompt_input_key
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.1",
|
||||
removal="1.0.0",
|
||||
message=(
|
||||
"Please see the migration guide at: "
|
||||
"https://python.langchain.com/docs/versions/migrating_memory/"
|
||||
),
|
||||
)
|
||||
class BaseChatMemory(BaseMemory, ABC):
|
||||
"""Abstract base class for chat memory.
|
||||
|
||||
**ATTENTION** This abstraction was created prior to when chat models had
|
||||
native tool calling capabilities.
|
||||
It does **NOT** support native tool calling capabilities for chat models and
|
||||
will fail SILENTLY if used with a chat model that has native tool calling.
|
||||
|
||||
DO NOT USE THIS ABSTRACTION FOR NEW CODE.
|
||||
"""
|
||||
|
||||
chat_memory: BaseChatMessageHistory = Field(
|
||||
default_factory=InMemoryChatMessageHistory,
|
||||
)
|
||||
output_key: str | None = None
|
||||
input_key: str | None = None
|
||||
return_messages: bool = False
|
||||
|
||||
def _get_input_output(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
outputs: dict[str, str],
|
||||
) -> tuple[str, str]:
|
||||
if self.input_key is None:
|
||||
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
|
||||
else:
|
||||
prompt_input_key = self.input_key
|
||||
if self.output_key is None:
|
||||
if len(outputs) == 1:
|
||||
output_key = next(iter(outputs.keys()))
|
||||
elif "output" in outputs:
|
||||
output_key = "output"
|
||||
warnings.warn(
|
||||
f"'{self.__class__.__name__}' got multiple output keys:"
|
||||
f" {outputs.keys()}. The default 'output' key is being used."
|
||||
f" If this is not desired, please manually set 'output_key'.",
|
||||
stacklevel=3,
|
||||
)
|
||||
else:
|
||||
msg = (
|
||||
f"Got multiple output keys: {outputs.keys()}, cannot "
|
||||
f"determine which to store in memory. Please set the "
|
||||
f"'output_key' explicitly."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
output_key = self.output_key
|
||||
return inputs[prompt_input_key], outputs[output_key]
|
||||
|
||||
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||
self.chat_memory.add_messages(
|
||||
[
|
||||
HumanMessage(content=input_str),
|
||||
AIMessage(content=output_str),
|
||||
],
|
||||
)
|
||||
|
||||
async def asave_context(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
outputs: dict[str, str],
|
||||
) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||
await self.chat_memory.aadd_messages(
|
||||
[
|
||||
HumanMessage(content=input_str),
|
||||
AIMessage(content=output_str),
|
||||
],
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
self.chat_memory.clear()
|
||||
|
||||
async def aclear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
await self.chat_memory.aclear()
|
||||
@@ -0,0 +1,84 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import (
|
||||
AstraDBChatMessageHistory,
|
||||
CassandraChatMessageHistory,
|
||||
ChatMessageHistory,
|
||||
CosmosDBChatMessageHistory,
|
||||
DynamoDBChatMessageHistory,
|
||||
ElasticsearchChatMessageHistory,
|
||||
FileChatMessageHistory,
|
||||
FirestoreChatMessageHistory,
|
||||
MomentoChatMessageHistory,
|
||||
MongoDBChatMessageHistory,
|
||||
Neo4jChatMessageHistory,
|
||||
PostgresChatMessageHistory,
|
||||
RedisChatMessageHistory,
|
||||
RocksetChatMessageHistory,
|
||||
SingleStoreDBChatMessageHistory,
|
||||
SQLChatMessageHistory,
|
||||
StreamlitChatMessageHistory,
|
||||
UpstashRedisChatMessageHistory,
|
||||
XataChatMessageHistory,
|
||||
ZepChatMessageHistory,
|
||||
)
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"AstraDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"CassandraChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"ChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"CosmosDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"DynamoDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"ElasticsearchChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"FileChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"FirestoreChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"MomentoChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"MongoDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"Neo4jChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"PostgresChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"RedisChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"RocksetChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"SQLChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"SingleStoreDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"StreamlitChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"UpstashRedisChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"XataChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
"ZepChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AstraDBChatMessageHistory",
|
||||
"CassandraChatMessageHistory",
|
||||
"ChatMessageHistory",
|
||||
"CosmosDBChatMessageHistory",
|
||||
"DynamoDBChatMessageHistory",
|
||||
"ElasticsearchChatMessageHistory",
|
||||
"FileChatMessageHistory",
|
||||
"FirestoreChatMessageHistory",
|
||||
"MomentoChatMessageHistory",
|
||||
"MongoDBChatMessageHistory",
|
||||
"Neo4jChatMessageHistory",
|
||||
"PostgresChatMessageHistory",
|
||||
"RedisChatMessageHistory",
|
||||
"RocksetChatMessageHistory",
|
||||
"SQLChatMessageHistory",
|
||||
"SingleStoreDBChatMessageHistory",
|
||||
"StreamlitChatMessageHistory",
|
||||
"UpstashRedisChatMessageHistory",
|
||||
"XataChatMessageHistory",
|
||||
"ZepChatMessageHistory",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import AstraDBChatMessageHistory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"AstraDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AstraDBChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import CassandraChatMessageHistory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"CassandraChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CassandraChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import CosmosDBChatMessageHistory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"CosmosDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CosmosDBChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import DynamoDBChatMessageHistory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"DynamoDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DynamoDBChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,27 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import (
|
||||
ElasticsearchChatMessageHistory,
|
||||
)
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"ElasticsearchChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ElasticsearchChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import FileChatMessageHistory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"FileChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FileChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import FirestoreChatMessageHistory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"FirestoreChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FirestoreChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,5 @@
|
||||
from langchain_core.chat_history import InMemoryChatMessageHistory as ChatMessageHistory
|
||||
|
||||
__all__ = [
|
||||
"ChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import MomentoChatMessageHistory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"MomentoChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MomentoChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import MongoDBChatMessageHistory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"MongoDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MongoDBChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import Neo4jChatMessageHistory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"Neo4jChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Neo4jChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import PostgresChatMessageHistory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"PostgresChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PostgresChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import RedisChatMessageHistory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"RedisChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RedisChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import RocksetChatMessageHistory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"RocksetChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RocksetChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,27 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import (
|
||||
SingleStoreDBChatMessageHistory,
|
||||
)
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"SingleStoreDBChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SingleStoreDBChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,33 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import SQLChatMessageHistory
|
||||
from langchain_community.chat_message_histories.sql import (
|
||||
BaseMessageConverter,
|
||||
DefaultMessageConverter,
|
||||
)
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"BaseMessageConverter": "langchain_community.chat_message_histories.sql",
|
||||
"DefaultMessageConverter": "langchain_community.chat_message_histories.sql",
|
||||
"SQLChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseMessageConverter",
|
||||
"DefaultMessageConverter",
|
||||
"SQLChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"StreamlitChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"StreamlitChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,27 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import (
|
||||
UpstashRedisChatMessageHistory,
|
||||
)
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"UpstashRedisChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"UpstashRedisChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import XataChatMessageHistory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"XataChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"XataChatMessageHistory",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.chat_message_histories import ZepChatMessageHistory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {
|
||||
"ZepChatMessageHistory": "langchain_community.chat_message_histories",
|
||||
}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ZepChatMessageHistory",
|
||||
]
|
||||
85
venv/Lib/site-packages/langchain_classic/memory/combined.py
Normal file
85
venv/Lib/site-packages/langchain_classic/memory/combined.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from pydantic import field_validator
|
||||
|
||||
from langchain_classic.base_memory import BaseMemory
|
||||
from langchain_classic.memory.chat_memory import BaseChatMemory
|
||||
|
||||
|
||||
class CombinedMemory(BaseMemory):
|
||||
"""Combining multiple memories' data together."""
|
||||
|
||||
memories: list[BaseMemory]
|
||||
"""For tracking all the memories that should be accessed."""
|
||||
|
||||
@field_validator("memories")
|
||||
@classmethod
|
||||
def _check_repeated_memory_variable(
|
||||
cls,
|
||||
value: list[BaseMemory],
|
||||
) -> list[BaseMemory]:
|
||||
all_variables: set[str] = set()
|
||||
for val in value:
|
||||
overlap = all_variables.intersection(val.memory_variables)
|
||||
if overlap:
|
||||
msg = (
|
||||
f"The same variables {overlap} are found in multiple"
|
||||
"memory object, which is not allowed by CombinedMemory."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
all_variables |= set(val.memory_variables)
|
||||
|
||||
return value
|
||||
|
||||
@field_validator("memories")
|
||||
@classmethod
|
||||
def check_input_key(cls, value: list[BaseMemory]) -> list[BaseMemory]:
|
||||
"""Check that if memories are of type BaseChatMemory that input keys exist."""
|
||||
for val in value:
|
||||
if isinstance(val, BaseChatMemory) and val.input_key is None:
|
||||
warnings.warn(
|
||||
"When using CombinedMemory, "
|
||||
"input keys should be so the input is known. "
|
||||
f" Was not set on {val}",
|
||||
stacklevel=5,
|
||||
)
|
||||
return value
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> list[str]:
|
||||
"""All the memory variables that this instance provides."""
|
||||
"""Collected from the all the linked memories."""
|
||||
|
||||
memory_variables = []
|
||||
|
||||
for memory in self.memories:
|
||||
memory_variables.extend(memory.memory_variables)
|
||||
|
||||
return memory_variables
|
||||
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]:
|
||||
"""Load all vars from sub-memories."""
|
||||
memory_data: dict[str, Any] = {}
|
||||
|
||||
# Collect vars from all sub-memories
|
||||
for memory in self.memories:
|
||||
data = memory.load_memory_variables(inputs)
|
||||
for key, value in data.items():
|
||||
if key in memory_data:
|
||||
msg = f"The variable {key} is repeated in the CombinedMemory."
|
||||
raise ValueError(msg)
|
||||
memory_data[key] = value
|
||||
|
||||
return memory_data
|
||||
|
||||
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
|
||||
"""Save context from this session for every memory."""
|
||||
# Save context for all sub-memories
|
||||
for memory in self.memories:
|
||||
memory.save_context(inputs, outputs)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear context from this session for every memory."""
|
||||
for memory in self.memories:
|
||||
memory.clear()
|
||||
611
venv/Lib/site-packages/langchain_classic/memory/entity.py
Normal file
611
venv/Lib/site-packages/langchain_classic/memory/entity.py
Normal file
@@ -0,0 +1,611 @@
|
||||
"""Deprecated as of LangChain v0.3.4 and will be removed in LangChain v1.0.0."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_classic.memory.chat_memory import BaseChatMemory
|
||||
from langchain_classic.memory.prompt import (
|
||||
ENTITY_EXTRACTION_PROMPT,
|
||||
ENTITY_SUMMARIZATION_PROMPT,
|
||||
)
|
||||
from langchain_classic.memory.utils import get_prompt_input_key
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import sqlite3
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.1",
|
||||
removal="1.0.0",
|
||||
message=(
|
||||
"Please see the migration guide at: "
|
||||
"https://python.langchain.com/docs/versions/migrating_memory/"
|
||||
),
|
||||
)
|
||||
class BaseEntityStore(BaseModel, ABC):
|
||||
"""Abstract base class for Entity store."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str, default: str | None = None) -> str | None:
|
||||
"""Get entity value from store."""
|
||||
|
||||
@abstractmethod
|
||||
def set(self, key: str, value: str | None) -> None:
|
||||
"""Set entity value in store."""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
"""Delete entity value from store."""
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
"""Check if entity exists in store."""
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Delete all entities from store."""
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.1",
|
||||
removal="1.0.0",
|
||||
message=(
|
||||
"Please see the migration guide at: "
|
||||
"https://python.langchain.com/docs/versions/migrating_memory/"
|
||||
),
|
||||
)
|
||||
class InMemoryEntityStore(BaseEntityStore):
|
||||
"""In-memory Entity store."""
|
||||
|
||||
store: dict[str, str | None] = {}
|
||||
|
||||
@override
|
||||
def get(self, key: str, default: str | None = None) -> str | None:
|
||||
return self.store.get(key, default)
|
||||
|
||||
@override
|
||||
def set(self, key: str, value: str | None) -> None:
|
||||
self.store[key] = value
|
||||
|
||||
@override
|
||||
def delete(self, key: str) -> None:
|
||||
del self.store[key]
|
||||
|
||||
@override
|
||||
def exists(self, key: str) -> bool:
|
||||
return key in self.store
|
||||
|
||||
@override
|
||||
def clear(self) -> None:
|
||||
return self.store.clear()
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.1",
|
||||
removal="1.0.0",
|
||||
message=(
|
||||
"Please see the migration guide at: "
|
||||
"https://python.langchain.com/docs/versions/migrating_memory/"
|
||||
),
|
||||
)
|
||||
class UpstashRedisEntityStore(BaseEntityStore):
|
||||
"""Upstash Redis backed Entity store.
|
||||
|
||||
Entities get a TTL of 1 day by default, and
|
||||
that TTL is extended by 3 days every time the entity is read back.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str = "default",
|
||||
url: str = "",
|
||||
token: str = "",
|
||||
key_prefix: str = "memory_store",
|
||||
ttl: int | None = 60 * 60 * 24,
|
||||
recall_ttl: int | None = 60 * 60 * 24 * 3,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initializes the RedisEntityStore.
|
||||
|
||||
Args:
|
||||
session_id: Unique identifier for the session.
|
||||
url: URL of the Redis server.
|
||||
token: Authentication token for the Redis server.
|
||||
key_prefix: Prefix for keys in the Redis store.
|
||||
ttl: Time-to-live for keys in seconds (default 1 day).
|
||||
recall_ttl: Time-to-live extension for keys when recalled (default 3 days).
|
||||
*args: Additional positional arguments.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
try:
|
||||
from upstash_redis import Redis
|
||||
except ImportError as e:
|
||||
msg = (
|
||||
"Could not import upstash_redis python package. "
|
||||
"Please install it with `pip install upstash_redis`."
|
||||
)
|
||||
raise ImportError(msg) from e
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
try:
|
||||
self.redis_client = Redis(url=url, token=token)
|
||||
except Exception as exc:
|
||||
error_msg = "Upstash Redis instance could not be initiated"
|
||||
logger.exception(error_msg)
|
||||
raise RuntimeError(error_msg) from exc
|
||||
|
||||
self.session_id = session_id
|
||||
self.key_prefix = key_prefix
|
||||
self.ttl = ttl
|
||||
self.recall_ttl = recall_ttl or ttl
|
||||
|
||||
@property
|
||||
def full_key_prefix(self) -> str:
|
||||
"""Returns the full key prefix with session ID."""
|
||||
return f"{self.key_prefix}:{self.session_id}"
|
||||
|
||||
@override
|
||||
def get(self, key: str, default: str | None = None) -> str | None:
|
||||
res = (
|
||||
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
|
||||
or default
|
||||
or ""
|
||||
)
|
||||
logger.debug(
|
||||
"Upstash Redis MEM get '%s:%s': '%s'", self.full_key_prefix, key, res
|
||||
)
|
||||
return res
|
||||
|
||||
@override
|
||||
def set(self, key: str, value: str | None) -> None:
|
||||
if not value:
|
||||
return self.delete(key)
|
||||
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
|
||||
logger.debug(
|
||||
"Redis MEM set '%s:%s': '%s' EX %s",
|
||||
self.full_key_prefix,
|
||||
key,
|
||||
value,
|
||||
self.ttl,
|
||||
)
|
||||
return None
|
||||
|
||||
@override
|
||||
def delete(self, key: str) -> None:
|
||||
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
|
||||
|
||||
@override
|
||||
def exists(self, key: str) -> bool:
|
||||
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
|
||||
|
||||
@override
|
||||
def clear(self) -> None:
|
||||
def scan_and_delete(cursor: int) -> int:
|
||||
cursor, keys_to_delete = self.redis_client.scan(
|
||||
cursor,
|
||||
f"{self.full_key_prefix}:*",
|
||||
)
|
||||
self.redis_client.delete(*keys_to_delete)
|
||||
return cursor
|
||||
|
||||
cursor = scan_and_delete(0)
|
||||
while cursor != 0:
|
||||
scan_and_delete(cursor)
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.1",
|
||||
removal="1.0.0",
|
||||
message=(
|
||||
"Please see the migration guide at: "
|
||||
"https://python.langchain.com/docs/versions/migrating_memory/"
|
||||
),
|
||||
)
|
||||
class RedisEntityStore(BaseEntityStore):
|
||||
"""Redis-backed Entity store.
|
||||
|
||||
Entities get a TTL of 1 day by default, and
|
||||
that TTL is extended by 3 days every time the entity is read back.
|
||||
"""
|
||||
|
||||
redis_client: Any
|
||||
session_id: str = "default"
|
||||
key_prefix: str = "memory_store"
|
||||
ttl: int | None = 60 * 60 * 24
|
||||
recall_ttl: int | None = 60 * 60 * 24 * 3
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str = "default",
|
||||
url: str = "redis://localhost:6379/0",
|
||||
key_prefix: str = "memory_store",
|
||||
ttl: int | None = 60 * 60 * 24,
|
||||
recall_ttl: int | None = 60 * 60 * 24 * 3,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initializes the RedisEntityStore.
|
||||
|
||||
Args:
|
||||
session_id: Unique identifier for the session.
|
||||
url: URL of the Redis server.
|
||||
key_prefix: Prefix for keys in the Redis store.
|
||||
ttl: Time-to-live for keys in seconds (default 1 day).
|
||||
recall_ttl: Time-to-live extension for keys when recalled (default 3 days).
|
||||
*args: Additional positional arguments.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
try:
|
||||
import redis
|
||||
except ImportError as e:
|
||||
msg = (
|
||||
"Could not import redis python package. "
|
||||
"Please install it with `pip install redis`."
|
||||
)
|
||||
raise ImportError(msg) from e
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
try:
|
||||
from langchain_community.utilities.redis import get_client
|
||||
except ImportError as e:
|
||||
msg = (
|
||||
"Could not import langchain_community.utilities.redis.get_client. "
|
||||
"Please install it with `pip install langchain-community`."
|
||||
)
|
||||
raise ImportError(msg) from e
|
||||
|
||||
try:
|
||||
self.redis_client = get_client(redis_url=url, decode_responses=True)
|
||||
except redis.exceptions.ConnectionError:
|
||||
logger.exception("Redis client could not connect")
|
||||
|
||||
self.session_id = session_id
|
||||
self.key_prefix = key_prefix
|
||||
self.ttl = ttl
|
||||
self.recall_ttl = recall_ttl or ttl
|
||||
|
||||
@property
|
||||
def full_key_prefix(self) -> str:
|
||||
"""Returns the full key prefix with session ID."""
|
||||
return f"{self.key_prefix}:{self.session_id}"
|
||||
|
||||
@override
|
||||
def get(self, key: str, default: str | None = None) -> str | None:
|
||||
res = (
|
||||
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
|
||||
or default
|
||||
or ""
|
||||
)
|
||||
logger.debug("REDIS MEM get '%s:%s': '%s'", self.full_key_prefix, key, res)
|
||||
return res
|
||||
|
||||
@override
|
||||
def set(self, key: str, value: str | None) -> None:
|
||||
if not value:
|
||||
return self.delete(key)
|
||||
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
|
||||
logger.debug(
|
||||
"REDIS MEM set '%s:%s': '%s' EX %s",
|
||||
self.full_key_prefix,
|
||||
key,
|
||||
value,
|
||||
self.ttl,
|
||||
)
|
||||
return None
|
||||
|
||||
@override
|
||||
def delete(self, key: str) -> None:
|
||||
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
|
||||
|
||||
@override
|
||||
def exists(self, key: str) -> bool:
|
||||
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
|
||||
|
||||
@override
|
||||
def clear(self) -> None:
|
||||
# iterate a list in batches of size batch_size
|
||||
def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
|
||||
iterator = iter(iterable)
|
||||
while batch := list(islice(iterator, batch_size)):
|
||||
yield batch
|
||||
|
||||
for keybatch in batched(
|
||||
self.redis_client.scan_iter(f"{self.full_key_prefix}:*"),
|
||||
500,
|
||||
):
|
||||
self.redis_client.delete(*keybatch)
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.1",
|
||||
removal="1.0.0",
|
||||
message=(
|
||||
"Please see the migration guide at: "
|
||||
"https://python.langchain.com/docs/versions/migrating_memory/"
|
||||
),
|
||||
)
|
||||
class SQLiteEntityStore(BaseEntityStore):
|
||||
"""SQLite-backed Entity store with safe query construction."""
|
||||
|
||||
session_id: str = "default"
|
||||
table_name: str = "memory_store"
|
||||
conn: Any = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str = "default",
|
||||
db_file: str = "entities.db",
|
||||
table_name: str = "memory_store",
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initializes the SQLiteEntityStore.
|
||||
|
||||
Args:
|
||||
session_id: Unique identifier for the session.
|
||||
db_file: Path to the SQLite database file.
|
||||
table_name: Name of the table to store entities.
|
||||
*args: Additional positional arguments.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
try:
|
||||
import sqlite3
|
||||
except ImportError as e:
|
||||
msg = (
|
||||
"Could not import sqlite3 python package. "
|
||||
"Please install it with `pip install sqlite3`."
|
||||
)
|
||||
raise ImportError(msg) from e
|
||||
|
||||
# Basic validation to prevent obviously malicious table/session names
|
||||
if not table_name.isidentifier() or not session_id.isidentifier():
|
||||
# Since we validate here, we can safely suppress the S608 bandit warning
|
||||
msg = "Table name and session ID must be valid Python identifiers."
|
||||
raise ValueError(msg)
|
||||
|
||||
self.conn = sqlite3.connect(db_file)
|
||||
self.session_id = session_id
|
||||
self.table_name = table_name
|
||||
self._create_table_if_not_exists()
|
||||
|
||||
@property
|
||||
def full_table_name(self) -> str:
|
||||
"""Returns the full table name with session ID."""
|
||||
return f"{self.table_name}_{self.session_id}"
|
||||
|
||||
def _execute_query(self, query: str, params: tuple = ()) -> "sqlite3.Cursor":
|
||||
"""Executes a query with proper connection handling."""
|
||||
with self.conn:
|
||||
return self.conn.execute(query, params)
|
||||
|
||||
def _create_table_if_not_exists(self) -> None:
|
||||
"""Creates the entity table if it doesn't exist, using safe quoting."""
|
||||
# Use standard SQL double quotes for the table name identifier
|
||||
create_table_query = f"""
|
||||
CREATE TABLE IF NOT EXISTS "{self.full_table_name}" (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT
|
||||
)
|
||||
"""
|
||||
self._execute_query(create_table_query)
|
||||
|
||||
def get(self, key: str, default: str | None = None) -> str | None:
|
||||
"""Retrieves a value, safely quoting the table name."""
|
||||
# `?` placeholder is used for the value to prevent SQL injection
|
||||
# Ignore S608 since we validate for malicious table/session names in `__init__`
|
||||
query = f'SELECT value FROM "{self.full_table_name}" WHERE key = ?' # noqa: S608
|
||||
cursor = self._execute_query(query, (key,))
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result is not None else default
|
||||
|
||||
def set(self, key: str, value: str | None) -> None:
|
||||
"""Inserts or replaces a value, safely quoting the table name."""
|
||||
if not value:
|
||||
return self.delete(key)
|
||||
# Ignore S608 since we validate for malicious table/session names in `__init__`
|
||||
query = (
|
||||
"INSERT OR REPLACE INTO " # noqa: S608
|
||||
f'"{self.full_table_name}" (key, value) VALUES (?, ?)'
|
||||
)
|
||||
self._execute_query(query, (key, value))
|
||||
return None
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
"""Deletes a key-value pair, safely quoting the table name."""
|
||||
# Ignore S608 since we validate for malicious table/session names in `__init__`
|
||||
query = f'DELETE FROM "{self.full_table_name}" WHERE key = ?' # noqa: S608
|
||||
self._execute_query(query, (key,))
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
"""Checks for the existence of a key, safely quoting the table name."""
|
||||
# Ignore S608 since we validate for malicious table/session names in `__init__`
|
||||
query = f'SELECT 1 FROM "{self.full_table_name}" WHERE key = ? LIMIT 1' # noqa: S608
|
||||
cursor = self._execute_query(query, (key,))
|
||||
return cursor.fetchone() is not None
|
||||
|
||||
@override
|
||||
def clear(self) -> None:
|
||||
# Ignore S608 since we validate for malicious table/session names in `__init__`
|
||||
query = f"""
|
||||
DELETE FROM {self.full_table_name}
|
||||
""" # noqa: S608
|
||||
with self.conn:
|
||||
self.conn.execute(query)
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.1",
|
||||
removal="1.0.0",
|
||||
message=(
|
||||
"Please see the migration guide at: "
|
||||
"https://python.langchain.com/docs/versions/migrating_memory/"
|
||||
),
|
||||
)
|
||||
class ConversationEntityMemory(BaseChatMemory):
|
||||
"""Entity extractor & summarizer memory.
|
||||
|
||||
Extracts named entities from the recent chat history and generates summaries.
|
||||
With a swappable entity store, persisting entities across conversations.
|
||||
Defaults to an in-memory entity store, and can be swapped out for a Redis,
|
||||
SQLite, or other entity store.
|
||||
"""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
llm: BaseLanguageModel
|
||||
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
|
||||
entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT
|
||||
|
||||
# Cache of recently detected entity names, if any
|
||||
# It is updated when load_memory_variables is called:
|
||||
entity_cache: list[str] = []
|
||||
|
||||
# Number of recent message pairs to consider when updating entities:
|
||||
k: int = 3
|
||||
|
||||
chat_history_key: str = "history"
|
||||
|
||||
# Store to manage entity-related data:
|
||||
entity_store: BaseEntityStore = Field(default_factory=InMemoryEntityStore)
|
||||
|
||||
@property
|
||||
def buffer(self) -> list[BaseMessage]:
|
||||
"""Access chat memory messages."""
|
||||
return self.chat_memory.messages
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> list[str]:
|
||||
"""Will always return list of memory variables."""
|
||||
return ["entities", self.chat_history_key]
|
||||
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Load memory variables.
|
||||
|
||||
Returns chat history and all generated entities with summaries if available,
|
||||
and updates or clears the recent entity cache.
|
||||
|
||||
New entity name can be found when calling this method, before the entity
|
||||
summaries are generated, so the entity cache values may be empty if no entity
|
||||
descriptions are generated yet.
|
||||
"""
|
||||
# Create an LLMChain for predicting entity names from the recent chat history:
|
||||
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
|
||||
|
||||
if self.input_key is None:
|
||||
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
|
||||
else:
|
||||
prompt_input_key = self.input_key
|
||||
|
||||
# Extract an arbitrary window of the last message pairs from
|
||||
# the chat history, where the hyperparameter k is the
|
||||
# number of message pairs:
|
||||
buffer_string = get_buffer_string(
|
||||
self.buffer[-self.k * 2 :],
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
# Generates a comma-separated list of named entities,
|
||||
# e.g. "Jane, White House, UFO"
|
||||
# or "NONE" if no named entities are extracted:
|
||||
output = chain.predict(
|
||||
history=buffer_string,
|
||||
input=inputs[prompt_input_key],
|
||||
)
|
||||
|
||||
# If no named entities are extracted, assigns an empty list.
|
||||
if output.strip() == "NONE":
|
||||
entities = []
|
||||
else:
|
||||
# Make a list of the extracted entities:
|
||||
entities = [w.strip() for w in output.split(",")]
|
||||
|
||||
# Make a dictionary of entities with summary if exists:
|
||||
entity_summaries = {}
|
||||
|
||||
for entity in entities:
|
||||
entity_summaries[entity] = self.entity_store.get(entity, "")
|
||||
|
||||
# Replaces the entity name cache with the most recently discussed entities,
|
||||
# or if no entities were extracted, clears the cache:
|
||||
self.entity_cache = entities
|
||||
|
||||
# Should we return as message objects or as a string?
|
||||
if self.return_messages:
|
||||
# Get last `k` pair of chat messages:
|
||||
buffer: Any = self.buffer[-self.k * 2 :]
|
||||
else:
|
||||
# Reuse the string we made earlier:
|
||||
buffer = buffer_string
|
||||
|
||||
return {
|
||||
self.chat_history_key: buffer,
|
||||
"entities": entity_summaries,
|
||||
}
|
||||
|
||||
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
|
||||
"""Save context from this conversation history to the entity store.
|
||||
|
||||
Generates a summary for each entity in the entity cache by prompting
|
||||
the model, and saves these summaries to the entity store.
|
||||
"""
|
||||
super().save_context(inputs, outputs)
|
||||
|
||||
if self.input_key is None:
|
||||
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
|
||||
else:
|
||||
prompt_input_key = self.input_key
|
||||
|
||||
# Extract an arbitrary window of the last message pairs from
|
||||
# the chat history, where the hyperparameter k is the
|
||||
# number of message pairs:
|
||||
buffer_string = get_buffer_string(
|
||||
self.buffer[-self.k * 2 :],
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
input_data = inputs[prompt_input_key]
|
||||
|
||||
# Create an LLMChain for predicting entity summarization from the context
|
||||
chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
|
||||
|
||||
# Generate new summaries for entities and save them in the entity store
|
||||
for entity in self.entity_cache:
|
||||
# Get existing summary if it exists
|
||||
existing_summary = self.entity_store.get(entity, "")
|
||||
output = chain.predict(
|
||||
summary=existing_summary,
|
||||
entity=entity,
|
||||
history=buffer_string,
|
||||
input=input_data,
|
||||
)
|
||||
# Save the updated summary to the entity store
|
||||
self.entity_store.set(entity, output.strip())
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
self.chat_memory.clear()
|
||||
self.entity_cache.clear()
|
||||
self.entity_store.clear()
|
||||
23
venv/Lib/site-packages/langchain_classic/memory/kg.py
Normal file
23
venv/Lib/site-packages/langchain_classic/memory/kg.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.memory.kg import ConversationKGMemory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {"ConversationKGMemory": "langchain_community.memory.kg"}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ConversationKGMemory",
|
||||
]
|
||||
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.memory.motorhead_memory import MotorheadMemory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {"MotorheadMemory": "langchain_community.memory.motorhead_memory"}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MotorheadMemory",
|
||||
]
|
||||
164
venv/Lib/site-packages/langchain_classic/memory/prompt.py
Normal file
164
venv/Lib/site-packages/langchain_classic/memory/prompt.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
|
||||
_DEFAULT_ENTITY_MEMORY_CONVERSATION_TEMPLATE = """You are an assistant to a human, powered by a large language model trained by OpenAI.
|
||||
|
||||
You are designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, you are able to generate human-like text based on the input you receive, allowing you to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
|
||||
|
||||
You are constantly learning and improving, and your capabilities are constantly evolving. You are able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. You have access to some personalized information provided by the human in the Context section below. Additionally, you are able to generate your own text based on the input you receive, allowing you to engage in discussions and provide explanations and descriptions on a wide range of topics.
|
||||
|
||||
Overall, you are a powerful tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether the human needs help with a specific question or just wants to have a conversation about a particular topic, you are here to assist.
|
||||
|
||||
Context:
|
||||
{entities}
|
||||
|
||||
Current conversation:
|
||||
{history}
|
||||
Last line:
|
||||
Human: {input}
|
||||
You:""" # noqa: E501
|
||||
|
||||
ENTITY_MEMORY_CONVERSATION_TEMPLATE = PromptTemplate(
|
||||
input_variables=["entities", "history", "input"],
|
||||
template=_DEFAULT_ENTITY_MEMORY_CONVERSATION_TEMPLATE,
|
||||
)
|
||||
|
||||
_DEFAULT_SUMMARIZER_TEMPLATE = """Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary.
|
||||
|
||||
EXAMPLE
|
||||
Current summary:
|
||||
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good.
|
||||
|
||||
New lines of conversation:
|
||||
Human: Why do you think artificial intelligence is a force for good?
|
||||
AI: Because artificial intelligence will help humans reach their full potential.
|
||||
|
||||
New summary:
|
||||
The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help humans reach their full potential.
|
||||
END OF EXAMPLE
|
||||
|
||||
Current summary:
|
||||
{summary}
|
||||
|
||||
New lines of conversation:
|
||||
{new_lines}
|
||||
|
||||
New summary:""" # noqa: E501
|
||||
SUMMARY_PROMPT = PromptTemplate(
|
||||
input_variables=["summary", "new_lines"], template=_DEFAULT_SUMMARIZER_TEMPLATE
|
||||
)
|
||||
|
||||
_DEFAULT_ENTITY_EXTRACTION_TEMPLATE = """You are an AI assistant reading the transcript of a conversation between an AI and a human. Extract all of the proper nouns from the last line of conversation. As a guideline, a proper noun is generally capitalized. You should definitely extract all names and places.
|
||||
|
||||
The conversation history is provided just in case of a coreference (e.g. "What do you know about him" where "him" is defined in a previous line) -- ignore items mentioned there that are not in the last line.
|
||||
|
||||
Return the output as a single comma-separated list, or NONE if there is nothing of note to return (e.g. the user is just issuing a greeting or having a simple conversation).
|
||||
|
||||
EXAMPLE
|
||||
Conversation history:
|
||||
Person #1: how's it going today?
|
||||
AI: "It's going great! How about you?"
|
||||
Person #1: good! busy working on Langchain. lots to do.
|
||||
AI: "That sounds like a lot of work! What kind of things are you doing to make Langchain better?"
|
||||
Last line:
|
||||
Person #1: i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff.
|
||||
Output: Langchain
|
||||
END OF EXAMPLE
|
||||
|
||||
EXAMPLE
|
||||
Conversation history:
|
||||
Person #1: how's it going today?
|
||||
AI: "It's going great! How about you?"
|
||||
Person #1: good! busy working on Langchain. lots to do.
|
||||
AI: "That sounds like a lot of work! What kind of things are you doing to make Langchain better?"
|
||||
Last line:
|
||||
Person #1: i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff. I'm working with Person #2.
|
||||
Output: Langchain, Person #2
|
||||
END OF EXAMPLE
|
||||
|
||||
Conversation history (for reference only):
|
||||
{history}
|
||||
Last line of conversation (for extraction):
|
||||
Human: {input}
|
||||
|
||||
Output:""" # noqa: E501
|
||||
ENTITY_EXTRACTION_PROMPT = PromptTemplate(
|
||||
input_variables=["history", "input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE
|
||||
)
|
||||
|
||||
_DEFAULT_ENTITY_SUMMARIZATION_TEMPLATE = """You are an AI assistant helping a human keep track of facts about relevant people, places, and concepts in their life. Update the summary of the provided entity in the "Entity" section based on the last line of your conversation with the human. If you are writing the summary for the first time, return a single sentence.
|
||||
The update should only include facts that are relayed in the last line of conversation about the provided entity, and should only contain facts about the provided entity.
|
||||
|
||||
If there is no new information about the provided entity or the information is not worth noting (not an important or relevant fact to remember long-term), return the existing summary unchanged.
|
||||
|
||||
Full conversation history (for context):
|
||||
{history}
|
||||
|
||||
Entity to summarize:
|
||||
{entity}
|
||||
|
||||
Existing summary of {entity}:
|
||||
{summary}
|
||||
|
||||
Last line of conversation:
|
||||
Human: {input}
|
||||
Updated summary:""" # noqa: E501
|
||||
|
||||
ENTITY_SUMMARIZATION_PROMPT = PromptTemplate(
|
||||
input_variables=["entity", "summary", "history", "input"],
|
||||
template=_DEFAULT_ENTITY_SUMMARIZATION_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
KG_TRIPLE_DELIMITER = "<|>"
|
||||
_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE = (
|
||||
"You are a networked intelligence helping a human track knowledge triples"
|
||||
" about all relevant people, things, concepts, etc. and integrating"
|
||||
" them with your knowledge stored within your weights"
|
||||
" as well as that stored in a knowledge graph."
|
||||
" Extract all of the knowledge triples from the last line of conversation."
|
||||
" A knowledge triple is a clause that contains a subject, a predicate,"
|
||||
" and an object. The subject is the entity being described,"
|
||||
" the predicate is the property of the subject that is being"
|
||||
" described, and the object is the value of the property.\n\n"
|
||||
"EXAMPLE\n"
|
||||
"Conversation history:\n"
|
||||
"Person #1: Did you hear aliens landed in Area 51?\n"
|
||||
"AI: No, I didn't hear that. What do you know about Area 51?\n"
|
||||
"Person #1: It's a secret military base in Nevada.\n"
|
||||
"AI: What do you know about Nevada?\n"
|
||||
"Last line of conversation:\n"
|
||||
"Person #1: It's a state in the US. It's also the number 1 producer of gold in the US.\n\n" # noqa: E501
|
||||
f"Output: (Nevada, is a, state){KG_TRIPLE_DELIMITER}(Nevada, is in, US)"
|
||||
f"{KG_TRIPLE_DELIMITER}(Nevada, is the number 1 producer of, gold)\n"
|
||||
"END OF EXAMPLE\n\n"
|
||||
"EXAMPLE\n"
|
||||
"Conversation history:\n"
|
||||
"Person #1: Hello.\n"
|
||||
"AI: Hi! How are you?\n"
|
||||
"Person #1: I'm good. How are you?\n"
|
||||
"AI: I'm good too.\n"
|
||||
"Last line of conversation:\n"
|
||||
"Person #1: I'm going to the store.\n\n"
|
||||
"Output: NONE\n"
|
||||
"END OF EXAMPLE\n\n"
|
||||
"EXAMPLE\n"
|
||||
"Conversation history:\n"
|
||||
"Person #1: What do you know about Descartes?\n"
|
||||
"AI: Descartes was a French philosopher, mathematician, and scientist who lived in the 17th century.\n" # noqa: E501
|
||||
"Person #1: The Descartes I'm referring to is a standup comedian and interior designer from Montreal.\n" # noqa: E501
|
||||
"AI: Oh yes, He is a comedian and an interior designer. He has been in the industry for 30 years. His favorite food is baked bean pie.\n" # noqa: E501
|
||||
"Last line of conversation:\n"
|
||||
"Person #1: Oh huh. I know Descartes likes to drive antique scooters and play the mandolin.\n" # noqa: E501
|
||||
f"Output: (Descartes, likes to drive, antique scooters){KG_TRIPLE_DELIMITER}(Descartes, plays, mandolin)\n" # noqa: E501
|
||||
"END OF EXAMPLE\n\n"
|
||||
"Conversation history (for reference only):\n"
|
||||
"{history}"
|
||||
"\nLast line of conversation (for extraction):\n"
|
||||
"Human: {input}\n\n"
|
||||
"Output:"
|
||||
)
|
||||
|
||||
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT = PromptTemplate(
|
||||
input_variables=["history", "input"],
|
||||
template=_DEFAULT_KNOWLEDGE_TRIPLE_EXTRACTION_TEMPLATE,
|
||||
)
|
||||
24
venv/Lib/site-packages/langchain_classic/memory/readonly.py
Normal file
24
venv/Lib/site-packages/langchain_classic/memory/readonly.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_classic.base_memory import BaseMemory
|
||||
|
||||
|
||||
class ReadOnlySharedMemory(BaseMemory):
|
||||
"""Memory wrapper that is read-only and cannot be changed."""
|
||||
|
||||
memory: BaseMemory
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> list[str]:
|
||||
"""Return memory variables."""
|
||||
return self.memory.memory_variables
|
||||
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]:
|
||||
"""Load memory variables from memory."""
|
||||
return self.memory.load_memory_variables(inputs)
|
||||
|
||||
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
|
||||
"""Nothing should be saved or changed."""
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear, got a memory like a vault."""
|
||||
30
venv/Lib/site-packages/langchain_classic/memory/simple.py
Normal file
30
venv/Lib/site-packages/langchain_classic/memory/simple.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.base_memory import BaseMemory
|
||||
|
||||
|
||||
class SimpleMemory(BaseMemory):
|
||||
"""Simple Memory.
|
||||
|
||||
Simple memory for storing context or other information that shouldn't
|
||||
ever change between prompts.
|
||||
"""
|
||||
|
||||
memories: dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
@override
|
||||
def memory_variables(self) -> list[str]:
|
||||
return list(self.memories.keys())
|
||||
|
||||
@override
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]:
|
||||
return self.memories
|
||||
|
||||
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
|
||||
"""Nothing should be saved or changed, my memory is set in stone."""
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear, got a memory like a vault."""
|
||||
168
venv/Lib/site-packages/langchain_classic/memory/summary.py
Normal file
168
venv/Lib/site-packages/langchain_classic/memory/summary.py
Normal file
@@ -0,0 +1,168 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.utils import pre_init
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.chains.llm import LLMChain
|
||||
from langchain_classic.memory.chat_memory import BaseChatMemory
|
||||
from langchain_classic.memory.prompt import SUMMARY_PROMPT
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.2.12",
|
||||
removal="1.0",
|
||||
message=(
|
||||
"Refer here for how to incorporate summaries of conversation history: "
|
||||
"https://docs.langchain.com/oss/python/langgraph/add-memory#summarize-messages"
|
||||
),
|
||||
)
|
||||
class SummarizerMixin(BaseModel):
|
||||
"""Mixin for summarizer."""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
llm: BaseLanguageModel
|
||||
prompt: BasePromptTemplate = SUMMARY_PROMPT
|
||||
summary_message_cls: type[BaseMessage] = SystemMessage
|
||||
|
||||
def predict_new_summary(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
existing_summary: str,
|
||||
) -> str:
|
||||
"""Predict a new summary based on the messages and existing summary.
|
||||
|
||||
Args:
|
||||
messages: List of messages to summarize.
|
||||
existing_summary: Existing summary to build upon.
|
||||
|
||||
Returns:
|
||||
A new summary string.
|
||||
"""
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
chain = LLMChain(llm=self.llm, prompt=self.prompt)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
async def apredict_new_summary(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
existing_summary: str,
|
||||
) -> str:
|
||||
"""Predict a new summary based on the messages and existing summary.
|
||||
|
||||
Args:
|
||||
messages: List of messages to summarize.
|
||||
existing_summary: Existing summary to build upon.
|
||||
|
||||
Returns:
|
||||
A new summary string.
|
||||
"""
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
chain = LLMChain(llm=self.llm, prompt=self.prompt)
|
||||
return await chain.apredict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.1",
|
||||
removal="1.0.0",
|
||||
message=(
|
||||
"Please see the migration guide at: "
|
||||
"https://python.langchain.com/docs/versions/migrating_memory/"
|
||||
),
|
||||
)
|
||||
class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
|
||||
"""Continually summarizes the conversation history.
|
||||
|
||||
The summary is updated after each conversation turn.
|
||||
The implementations returns a summary of the conversation history which
|
||||
can be used to provide context to the model.
|
||||
"""
|
||||
|
||||
buffer: str = ""
|
||||
memory_key: str = "history"
|
||||
|
||||
@classmethod
|
||||
def from_messages(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
chat_memory: BaseChatMessageHistory,
|
||||
*,
|
||||
summarize_step: int = 2,
|
||||
**kwargs: Any,
|
||||
) -> ConversationSummaryMemory:
|
||||
"""Create a ConversationSummaryMemory from a list of messages.
|
||||
|
||||
Args:
|
||||
llm: The language model to use for summarization.
|
||||
chat_memory: The chat history to summarize.
|
||||
summarize_step: Number of messages to summarize at a time.
|
||||
**kwargs: Additional keyword arguments to pass to the class.
|
||||
|
||||
Returns:
|
||||
An instance of ConversationSummaryMemory with the summarized history.
|
||||
"""
|
||||
obj = cls(llm=llm, chat_memory=chat_memory, **kwargs)
|
||||
for i in range(0, len(obj.chat_memory.messages), summarize_step):
|
||||
obj.buffer = obj.predict_new_summary(
|
||||
obj.chat_memory.messages[i : i + summarize_step],
|
||||
obj.buffer,
|
||||
)
|
||||
return obj
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> list[str]:
|
||||
"""Will always return list of memory variables."""
|
||||
return [self.memory_key]
|
||||
|
||||
@override
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
if self.return_messages:
|
||||
buffer: Any = [self.summary_message_cls(content=self.buffer)]
|
||||
else:
|
||||
buffer = self.buffer
|
||||
return {self.memory_key: buffer}
|
||||
|
||||
@pre_init
|
||||
def validate_prompt_input_variables(cls, values: dict) -> dict:
|
||||
"""Validate that prompt input variables are consistent."""
|
||||
prompt_variables = values["prompt"].input_variables
|
||||
expected_keys = {"summary", "new_lines"}
|
||||
if expected_keys != set(prompt_variables):
|
||||
msg = (
|
||||
"Got unexpected prompt input variables. The prompt expects "
|
||||
f"{prompt_variables}, but it should have {expected_keys}."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return values
|
||||
|
||||
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
super().save_context(inputs, outputs)
|
||||
self.buffer = self.predict_new_summary(
|
||||
self.chat_memory.messages[-2:],
|
||||
self.buffer,
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
super().clear()
|
||||
self.buffer = ""
|
||||
@@ -0,0 +1,148 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.utils import pre_init
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.memory.chat_memory import BaseChatMemory
|
||||
from langchain_classic.memory.summary import SummarizerMixin
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.1",
|
||||
removal="1.0.0",
|
||||
message=(
|
||||
"Please see the migration guide at: "
|
||||
"https://python.langchain.com/docs/versions/migrating_memory/"
|
||||
),
|
||||
)
|
||||
class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
|
||||
"""Buffer with summarizer for storing conversation memory.
|
||||
|
||||
Provides a running summary of the conversation together with the most recent
|
||||
messages in the conversation under the constraint that the total number of
|
||||
tokens in the conversation does not exceed a certain limit.
|
||||
"""
|
||||
|
||||
max_token_limit: int = 2000
|
||||
moving_summary_buffer: str = ""
|
||||
memory_key: str = "history"
|
||||
|
||||
@property
|
||||
def buffer(self) -> str | list[BaseMessage]:
|
||||
"""String buffer of memory."""
|
||||
return self.load_memory_variables({})[self.memory_key]
|
||||
|
||||
async def abuffer(self) -> str | list[BaseMessage]:
|
||||
"""Async memory buffer."""
|
||||
memory_variables = await self.aload_memory_variables({})
|
||||
return memory_variables[self.memory_key]
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> list[str]:
|
||||
"""Will always return list of memory variables."""
|
||||
return [self.memory_key]
|
||||
|
||||
@override
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
buffer = self.chat_memory.messages
|
||||
if self.moving_summary_buffer != "":
|
||||
first_messages: list[BaseMessage] = [
|
||||
self.summary_message_cls(content=self.moving_summary_buffer),
|
||||
]
|
||||
buffer = first_messages + buffer
|
||||
if self.return_messages:
|
||||
final_buffer: Any = buffer
|
||||
else:
|
||||
final_buffer = get_buffer_string(
|
||||
buffer,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
return {self.memory_key: final_buffer}
|
||||
|
||||
@override
|
||||
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Asynchronously return key-value pairs given the text input to the chain."""
|
||||
buffer = await self.chat_memory.aget_messages()
|
||||
if self.moving_summary_buffer != "":
|
||||
first_messages: list[BaseMessage] = [
|
||||
self.summary_message_cls(content=self.moving_summary_buffer),
|
||||
]
|
||||
buffer = first_messages + buffer
|
||||
if self.return_messages:
|
||||
final_buffer: Any = buffer
|
||||
else:
|
||||
final_buffer = get_buffer_string(
|
||||
buffer,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
return {self.memory_key: final_buffer}
|
||||
|
||||
@pre_init
|
||||
def validate_prompt_input_variables(cls, values: dict) -> dict:
|
||||
"""Validate that prompt input variables are consistent."""
|
||||
prompt_variables = values["prompt"].input_variables
|
||||
expected_keys = {"summary", "new_lines"}
|
||||
if expected_keys != set(prompt_variables):
|
||||
msg = (
|
||||
"Got unexpected prompt input variables. The prompt expects "
|
||||
f"{prompt_variables}, but it should have {expected_keys}."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return values
|
||||
|
||||
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
super().save_context(inputs, outputs)
|
||||
self.prune()
|
||||
|
||||
async def asave_context(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
outputs: dict[str, str],
|
||||
) -> None:
|
||||
"""Asynchronously save context from this conversation to buffer."""
|
||||
await super().asave_context(inputs, outputs)
|
||||
await self.aprune()
|
||||
|
||||
def prune(self) -> None:
|
||||
"""Prune buffer if it exceeds max token limit."""
|
||||
buffer = self.chat_memory.messages
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
if curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory.append(buffer.pop(0))
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
self.moving_summary_buffer = self.predict_new_summary(
|
||||
pruned_memory,
|
||||
self.moving_summary_buffer,
|
||||
)
|
||||
|
||||
async def aprune(self) -> None:
|
||||
"""Asynchronously prune buffer if it exceeds max token limit."""
|
||||
buffer = self.chat_memory.messages
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
if curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory.append(buffer.pop(0))
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
self.moving_summary_buffer = await self.apredict_new_summary(
|
||||
pruned_memory,
|
||||
self.moving_summary_buffer,
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
super().clear()
|
||||
self.moving_summary_buffer = ""
|
||||
|
||||
async def aclear(self) -> None:
|
||||
"""Asynchronously clear memory contents."""
|
||||
await super().aclear()
|
||||
self.moving_summary_buffer = ""
|
||||
@@ -0,0 +1,71 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_classic.memory.chat_memory import BaseChatMemory
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.1",
|
||||
removal="1.0.0",
|
||||
message=(
|
||||
"Please see the migration guide at: "
|
||||
"https://python.langchain.com/docs/versions/migrating_memory/"
|
||||
),
|
||||
)
|
||||
class ConversationTokenBufferMemory(BaseChatMemory):
|
||||
"""Conversation chat memory with token limit.
|
||||
|
||||
Keeps only the most recent messages in the conversation under the constraint
|
||||
that the total number of tokens in the conversation does not exceed a certain limit.
|
||||
"""
|
||||
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
llm: BaseLanguageModel
|
||||
memory_key: str = "history"
|
||||
max_token_limit: int = 2000
|
||||
|
||||
@property
|
||||
def buffer(self) -> Any:
|
||||
"""String buffer of memory."""
|
||||
return self.buffer_as_messages if self.return_messages else self.buffer_as_str
|
||||
|
||||
@property
|
||||
def buffer_as_str(self) -> str:
|
||||
"""Exposes the buffer as a string in case return_messages is False."""
|
||||
return get_buffer_string(
|
||||
self.chat_memory.messages,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def buffer_as_messages(self) -> list[BaseMessage]:
|
||||
"""Exposes the buffer as a list of messages in case return_messages is True."""
|
||||
return self.chat_memory.messages
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> list[str]:
|
||||
"""Will always return list of memory variables."""
|
||||
return [self.memory_key]
|
||||
|
||||
@override
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
|
||||
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer. Pruned."""
|
||||
super().save_context(inputs, outputs)
|
||||
# Prune buffer if it exceeds max token limit
|
||||
buffer = self.chat_memory.messages
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
if curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_buffer_length > self.max_token_limit:
|
||||
pruned_memory.append(buffer.pop(0))
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
20
venv/Lib/site-packages/langchain_classic/memory/utils.py
Normal file
20
venv/Lib/site-packages/langchain_classic/memory/utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
def get_prompt_input_key(inputs: dict[str, Any], memory_variables: list[str]) -> str:
|
||||
"""Get the prompt input key.
|
||||
|
||||
Args:
|
||||
inputs: Dict[str, Any]
|
||||
memory_variables: List[str]
|
||||
|
||||
Returns:
|
||||
A prompt input key.
|
||||
"""
|
||||
# "stop" is a special key that can be passed as input but is not used to
|
||||
# format the prompt.
|
||||
prompt_input_keys = list(set(inputs).difference([*memory_variables, "stop"]))
|
||||
if len(prompt_input_keys) != 1:
|
||||
msg = f"One input key expected got {prompt_input_keys}"
|
||||
raise ValueError(msg)
|
||||
return prompt_input_keys[0]
|
||||
122
venv/Lib/site-packages/langchain_classic/memory/vectorstore.py
Normal file
122
venv/Lib/site-packages/langchain_classic/memory/vectorstore.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Class for a VectorStore-backed memory object."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_classic.base_memory import BaseMemory
|
||||
from langchain_classic.memory.utils import get_prompt_input_key
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="0.3.1",
|
||||
removal="1.0.0",
|
||||
message=(
|
||||
"Please see the migration guide at: "
|
||||
"https://python.langchain.com/docs/versions/migrating_memory/"
|
||||
),
|
||||
)
|
||||
class VectorStoreRetrieverMemory(BaseMemory):
|
||||
"""Vector Store Retriever Memory.
|
||||
|
||||
Store the conversation history in a vector store and retrieves the relevant
|
||||
parts of past conversation based on the input.
|
||||
"""
|
||||
|
||||
retriever: VectorStoreRetriever = Field(exclude=True)
|
||||
"""VectorStoreRetriever object to connect to."""
|
||||
|
||||
memory_key: str = "history"
|
||||
"""Key name to locate the memories in the result of load_memory_variables."""
|
||||
|
||||
input_key: str | None = None
|
||||
"""Key name to index the inputs to load_memory_variables."""
|
||||
|
||||
return_docs: bool = False
|
||||
"""Whether or not to return the result of querying the database directly."""
|
||||
|
||||
exclude_input_keys: Sequence[str] = Field(default_factory=tuple)
|
||||
"""Input keys to exclude in addition to memory key when constructing the document"""
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> list[str]:
|
||||
"""The list of keys emitted from the load_memory_variables method."""
|
||||
return [self.memory_key]
|
||||
|
||||
def _get_prompt_input_key(self, inputs: dict[str, Any]) -> str:
|
||||
"""Get the input key for the prompt."""
|
||||
if self.input_key is None:
|
||||
return get_prompt_input_key(inputs, self.memory_variables)
|
||||
return self.input_key
|
||||
|
||||
def _documents_to_memory_variables(
|
||||
self,
|
||||
docs: list[Document],
|
||||
) -> dict[str, list[Document] | str]:
|
||||
result: list[Document] | str
|
||||
if not self.return_docs:
|
||||
result = "\n".join([doc.page_content for doc in docs])
|
||||
else:
|
||||
result = docs
|
||||
return {self.memory_key: result}
|
||||
|
||||
def load_memory_variables(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
) -> dict[str, list[Document] | str]:
|
||||
"""Return history buffer."""
|
||||
input_key = self._get_prompt_input_key(inputs)
|
||||
query = inputs[input_key]
|
||||
docs = self.retriever.invoke(query)
|
||||
return self._documents_to_memory_variables(docs)
|
||||
|
||||
async def aload_memory_variables(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
) -> dict[str, list[Document] | str]:
|
||||
"""Return history buffer."""
|
||||
input_key = self._get_prompt_input_key(inputs)
|
||||
query = inputs[input_key]
|
||||
docs = await self.retriever.ainvoke(query)
|
||||
return self._documents_to_memory_variables(docs)
|
||||
|
||||
def _form_documents(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
outputs: dict[str, str],
|
||||
) -> list[Document]:
|
||||
"""Format context from this conversation to buffer."""
|
||||
# Each document should only include the current turn, not the chat history
|
||||
exclude = set(self.exclude_input_keys)
|
||||
exclude.add(self.memory_key)
|
||||
filtered_inputs = {k: v for k, v in inputs.items() if k not in exclude}
|
||||
texts = [
|
||||
f"{k}: {v}"
|
||||
for k, v in list(filtered_inputs.items()) + list(outputs.items())
|
||||
]
|
||||
page_content = "\n".join(texts)
|
||||
return [Document(page_content=page_content)]
|
||||
|
||||
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
documents = self._form_documents(inputs, outputs)
|
||||
self.retriever.add_documents(documents)
|
||||
|
||||
async def asave_context(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
outputs: dict[str, str],
|
||||
) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
documents = self._form_documents(inputs, outputs)
|
||||
await self.retriever.aadd_documents(documents)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear."""
|
||||
|
||||
async def aclear(self) -> None:
|
||||
"""Nothing to clear."""
|
||||
@@ -0,0 +1,183 @@
|
||||
"""Class for a conversation memory buffer with older messages stored in a vectorstore .
|
||||
|
||||
This implements a conversation memory in which the messages are stored in a memory
|
||||
buffer up to a specified token limit. When the limit is exceeded, older messages are
|
||||
saved to a `VectorStore` backing database. The `VectorStore` can be made persistent
|
||||
across sessions.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.prompts.chat import SystemMessagePromptTemplate
|
||||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
from langchain_classic.memory import (
|
||||
ConversationTokenBufferMemory,
|
||||
VectorStoreRetrieverMemory,
|
||||
)
|
||||
from langchain_classic.memory.chat_memory import BaseChatMemory
|
||||
from langchain_classic.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
DEFAULT_HISTORY_TEMPLATE = """
|
||||
Current date and time: {current_time}.
|
||||
|
||||
Potentially relevant timestamped excerpts of previous conversations (you
|
||||
do not need to use these if irrelevant):
|
||||
{previous_history}
|
||||
|
||||
"""
|
||||
|
||||
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S %Z"
|
||||
|
||||
|
||||
class ConversationVectorStoreTokenBufferMemory(ConversationTokenBufferMemory):
|
||||
"""Conversation chat memory with token limit and vectordb backing.
|
||||
|
||||
load_memory_variables() will return a dict with the key "history".
|
||||
It contains background information retrieved from the vector store
|
||||
plus recent lines of the current conversation.
|
||||
|
||||
To help the LLM understand the part of the conversation stored in the
|
||||
vectorstore, each interaction is timestamped and the current date and
|
||||
time is also provided in the history. A side effect of this is that the
|
||||
LLM will have access to the current date and time.
|
||||
|
||||
Initialization arguments:
|
||||
|
||||
This class accepts all the initialization arguments of
|
||||
ConversationTokenBufferMemory, such as `llm`. In addition, it
|
||||
accepts the following additional arguments
|
||||
|
||||
retriever: (required) A VectorStoreRetriever object to use
|
||||
as the vector backing store
|
||||
|
||||
split_chunk_size: (optional, 1000) Token chunk split size
|
||||
for long messages generated by the AI
|
||||
|
||||
previous_history_template: (optional) Template used to format
|
||||
the contents of the prompt history
|
||||
|
||||
|
||||
Example using ChromaDB:
|
||||
|
||||
```python
|
||||
from langchain_classic.memory.token_buffer_vectorstore_memory import (
|
||||
ConversationVectorStoreTokenBufferMemory,
|
||||
)
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
|
||||
from langchain_openai import OpenAI
|
||||
|
||||
embedder = HuggingFaceInstructEmbeddings(
|
||||
query_instruction="Represent the query for retrieval: "
|
||||
)
|
||||
chroma = Chroma(
|
||||
collection_name="demo",
|
||||
embedding_function=embedder,
|
||||
collection_metadata={"hnsw:space": "cosine"},
|
||||
)
|
||||
|
||||
retriever = chroma.as_retriever(
|
||||
search_type="similarity_score_threshold",
|
||||
search_kwargs={
|
||||
"k": 5,
|
||||
"score_threshold": 0.75,
|
||||
},
|
||||
)
|
||||
|
||||
conversation_memory = ConversationVectorStoreTokenBufferMemory(
|
||||
return_messages=True,
|
||||
llm=OpenAI(),
|
||||
retriever=retriever,
|
||||
max_token_limit=1000,
|
||||
)
|
||||
|
||||
conversation_memory.save_context({"Human": "Hi there"}, {"AI": "Nice to meet you!"})
|
||||
conversation_memory.save_context(
|
||||
{"Human": "Nice day isn't it?"}, {"AI": "I love Wednesdays."}
|
||||
)
|
||||
conversation_memory.load_memory_variables({"input": "What time is it?"})
|
||||
```
|
||||
"""
|
||||
|
||||
retriever: VectorStoreRetriever = Field(exclude=True)
|
||||
memory_key: str = "history"
|
||||
previous_history_template: str = DEFAULT_HISTORY_TEMPLATE
|
||||
split_chunk_size: int = 1000
|
||||
|
||||
_memory_retriever: VectorStoreRetrieverMemory | None = PrivateAttr(default=None)
|
||||
_timestamps: list[datetime] = PrivateAttr(default_factory=list)
|
||||
|
||||
@property
|
||||
def memory_retriever(self) -> VectorStoreRetrieverMemory:
|
||||
"""Return a memory retriever from the passed retriever object."""
|
||||
if self._memory_retriever is not None:
|
||||
return self._memory_retriever
|
||||
self._memory_retriever = VectorStoreRetrieverMemory(retriever=self.retriever)
|
||||
return self._memory_retriever
|
||||
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return history and memory buffer."""
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
memory_variables = self.memory_retriever.load_memory_variables(inputs)
|
||||
previous_history = memory_variables[self.memory_retriever.memory_key]
|
||||
except AssertionError: # happens when db is empty
|
||||
previous_history = ""
|
||||
current_history = super().load_memory_variables(inputs)
|
||||
template = SystemMessagePromptTemplate.from_template(
|
||||
self.previous_history_template,
|
||||
)
|
||||
messages = [
|
||||
template.format(
|
||||
previous_history=previous_history,
|
||||
current_time=datetime.now().astimezone().strftime(TIMESTAMP_FORMAT),
|
||||
),
|
||||
]
|
||||
messages.extend(current_history[self.memory_key])
|
||||
return {self.memory_key: messages}
|
||||
|
||||
def save_context(self, inputs: dict[str, Any], outputs: dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer. Pruned."""
|
||||
BaseChatMemory.save_context(self, inputs, outputs)
|
||||
self._timestamps.append(datetime.now().astimezone())
|
||||
# Prune buffer if it exceeds max token limit
|
||||
buffer = self.chat_memory.messages
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
if curr_buffer_length > self.max_token_limit:
|
||||
while curr_buffer_length > self.max_token_limit:
|
||||
self._pop_and_store_interaction(buffer)
|
||||
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
|
||||
|
||||
def save_remainder(self) -> None:
|
||||
"""Save the remainder of the conversation buffer to the vector store.
|
||||
|
||||
Useful if you have made the VectorStore persistent, in which
|
||||
case this can be called before the end of the session to store the
|
||||
remainder of the conversation.
|
||||
"""
|
||||
buffer = self.chat_memory.messages
|
||||
while len(buffer) > 0:
|
||||
self._pop_and_store_interaction(buffer)
|
||||
|
||||
def _pop_and_store_interaction(self, buffer: list[BaseMessage]) -> None:
|
||||
input_ = buffer.pop(0)
|
||||
output = buffer.pop(0)
|
||||
timestamp = self._timestamps.pop(0).strftime(TIMESTAMP_FORMAT)
|
||||
# Split AI output into smaller chunks to avoid creating documents
|
||||
# that will overflow the context window
|
||||
ai_chunks = self._split_long_ai_text(str(output.content))
|
||||
for index, chunk in enumerate(ai_chunks):
|
||||
self.memory_retriever.save_context(
|
||||
{"Human": f"<{timestamp}/00> {input_.content!s}"},
|
||||
{"AI": f"<{timestamp}/{index:02}> {chunk}"},
|
||||
)
|
||||
|
||||
def _split_long_ai_text(self, text: str) -> list[str]:
|
||||
splitter = RecursiveCharacterTextSplitter(chunk_size=self.split_chunk_size)
|
||||
return [chunk.page_content for chunk in splitter.create_documents([text])]
|
||||
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_classic._api import create_importer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.memory.zep_memory import ZepMemory
|
||||
|
||||
# Create a way to dynamically look up deprecated imports.
|
||||
# Used to consolidate logic for raising deprecation warnings and
|
||||
# handling optional imports.
|
||||
DEPRECATED_LOOKUP = {"ZepMemory": "langchain_community.memory.zep_memory"}
|
||||
|
||||
_import_attribute = create_importer(__package__, deprecated_lookups=DEPRECATED_LOOKUP)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Look up attributes dynamically."""
|
||||
return _import_attribute(name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ZepMemory",
|
||||
]
|
||||
Reference in New Issue
Block a user